You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

106 lines
3.2 KiB

5 months ago
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from kafka import KafkaProducer, KafkaAdminClient
from kafka.admin import NewTopic
import uuid
import json
5 months ago
from AiService.Config.Config import KAFKA_TOPIC, KAFKA_BOOTSTRAP_SERVERS
5 months ago
from AiService.Model.TaskModel import *
5 months ago
# Request body model.
class TaskRequest(BaseModel):
    """JSON body accepted by the ``POST /create-task`` endpoint."""

    # Prompt text entered by the user.
    prompt: str
# FastAPI application instance; the route handlers below register on it.
app = FastAPI()


def _serialize_message_value(value):
    """Encode a message value as UTF-8 JSON bytes for Kafka."""
    encoded_text = json.dumps(value)
    return encoded_text.encode('utf-8')


# Shared Kafka producer used by the request handlers.
producer = KafkaProducer(
    value_serializer=_serialize_message_value,
    bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
)
# Check for the Kafka topic and create it if it is missing.
def ensure_topic_exists():
    """Make sure ``KAFKA_TOPIC`` exists, creating it when it is absent.

    The topic is created with one partition and a replication factor of 1.
    This is best-effort: any error is printed rather than raised, so a Kafka
    problem does not abort the caller.
    """
    from kafka.errors import TopicAlreadyExistsError

    try:
        admin_client = KafkaAdminClient(bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS)
        try:
            if KAFKA_TOPIC in admin_client.list_topics():
                print(f"Topic '{KAFKA_TOPIC}' already exists.")
            else:
                topic = NewTopic(
                    name=KAFKA_TOPIC,
                    num_partitions=1,  # partition count
                    replication_factor=1,  # replica count
                )
                try:
                    admin_client.create_topics([topic])
                except TopicAlreadyExistsError:
                    # Another process created the topic between list_topics()
                    # and create_topics(); that is not a failure.
                    print(f"Topic '{KAFKA_TOPIC}' already exists.")
                else:
                    print(f"Topic '{KAFKA_TOPIC}' created successfully.")
        finally:
            # Always release the admin connection, even when a call above raised.
            admin_client.close()
    except Exception as e:
        print(f"Failed to ensure topic exists: {str(e)}")
5 months ago
# Task status query endpoint.
@app.get("/task-status/{task_id}")
def get_task_status(task_id: str):
    """Return the stored status of ``task_id``.

    The response body is whatever ``TaskModel.get_task_status`` returns.
    Its exact shape is defined in TaskModel (TODO confirm).

    Raises:
        HTTPException: 404 if no status is found for ``task_id``; 500 if the
            lookup itself fails.
    """
    try:
        tm = TaskModel()
        try:
            task_status = tm.get_task_status(task_id)
        finally:
            # Release the model's resources even if the query raised.
            tm.close()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get task status: {str(e)}") from e

    # Raised outside the try/except so the 404 is not rewrapped as a 500.
    if not task_status:
        raise HTTPException(status_code=404, detail="Task not found")
    return task_status
5 months ago
# Application lifespan handler.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Ensure the Kafka topic exists at startup; close the producer at shutdown."""
    ensure_topic_exists()
    try:
        yield
    finally:
        # Flush any buffered messages and release the Kafka producer.
        producer.close()
        print("Application shutdown.")


# ``app`` was created without ``lifespan=``, so attach the handler to its
# router explicitly; without this line FastAPI never invokes ``lifespan``.
app.router.lifespan_context = lifespan
# Task creation endpoint.
@app.post("/create-task")
def create_task(task_request: TaskRequest):
    """Create a task, persist it, and publish it to ``KAFKA_TOPIC``.

    The Kafka message value is ``{"task_id": ..., "prompt": ...}``.

    Returns:
        ``{"task_id": <uuid4 string>}`` identifying the new task.

    Raises:
        HTTPException: 500 if persisting or publishing the task fails.
    """
    try:
        task_id = str(uuid.uuid4())
        task_data = {
            "task_id": task_id,
            "prompt": task_request.prompt,
        }

        # Persist before publishing. A failed insert then never leaves a queued
        # message with no row, and the row exists before any consumer sees the
        # message. (Presumably the worker updates this row -- confirm.)
        tm = TaskModel()
        try:
            tm.insert_task(task_id, task_request.prompt)
        finally:
            tm.close()

        # flush() does not raise when delivery fails; the send future does,
        # so check it after flushing to surface Kafka errors to the caller.
        future = producer.send(KAFKA_TOPIC, value=task_data)
        producer.flush()
        future.get()

        return {"task_id": task_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to create task: {str(e)}") from e
# Script entry point: serve the app with uvicorn when run directly.
if __name__ == "__main__":
    import uvicorn

    # Bind to all IPv4 interfaces on port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")