You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
89 lines
2.6 KiB
89 lines
2.6 KiB
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from kafka import KafkaProducer, KafkaAdminClient
|
|
from kafka.admin import NewTopic
|
|
import uuid
|
|
import json
|
|
from AiService.Model.TaskModel import *
|
|
|
|
class TaskRequest(BaseModel):
    """Request body accepted by the ``POST /create-task`` endpoint."""

    # Prompt text entered by the user.
    prompt: str
|
|
|
|
# The FastAPI application instance that the endpoints below are registered on.
app = FastAPI()


def _to_json_bytes(value):
    """Serialize *value* to UTF-8 encoded JSON bytes for Kafka message values."""
    return json.dumps(value).encode('utf-8')


# Module-level Kafka producer shared by the request handlers; every message
# value passed to send() is JSON-encoded by _to_json_bytes.
producer = KafkaProducer(
    value_serializer=_to_json_bytes,
    bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
)
|
|
|
|
def ensure_topic_exists(*, num_partitions=1, replication_factor=1):
    """Create ``KAFKA_TOPIC`` on the Kafka cluster if it does not exist yet.

    This is best-effort: any failure is printed and swallowed so that the
    caller (application startup) is not aborted by it.

    Args:
        num_partitions: Partition count used only if the topic must be created.
        replication_factor: Replication factor used only if the topic must be
            created.
    """
    # Local imports keep this helper self-contained; both come from packages
    # this module already depends on.
    from contextlib import closing

    from kafka.errors import TopicAlreadyExistsError

    try:
        # closing() releases the admin connection on every path, including
        # when list_topics()/create_topics() raise (previously leaked).
        with closing(
            KafkaAdminClient(bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS)
        ) as admin_client:
            if KAFKA_TOPIC in admin_client.list_topics():
                print(f"Topic '{KAFKA_TOPIC}' already exists.")
            else:
                topic = NewTopic(
                    name=KAFKA_TOPIC,
                    num_partitions=num_partitions,
                    replication_factor=replication_factor,
                )
                try:
                    admin_client.create_topics([topic])
                except TopicAlreadyExistsError:
                    # Another process created the topic between list_topics()
                    # and create_topics(); the desired state is reached anyway.
                    print(f"Topic '{KAFKA_TOPIC}' already exists.")
                else:
                    print(f"Topic '{KAFKA_TOPIC}' created successfully.")
    except Exception as e:
        print(f"Failed to ensure topic exists: {str(e)}")
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown logic around the application's lifetime.

    Startup: make sure the Kafka topic exists.
    Shutdown: close the module-level Kafka producer so its connections are
    released, then log the shutdown.
    """
    ensure_topic_exists()
    try:
        yield
    finally:
        producer.close()
        print("Application shutdown.")


# ``app`` is constructed in this module without a ``lifespan=`` argument, so
# FastAPI would never invoke the hook above. Attaching it to the router has the
# same effect as ``FastAPI(lifespan=lifespan)``.
app.router.lifespan_context = lifespan
|
|
|
|
@app.post("/create-task")
def create_task(task_request: TaskRequest):
    """Register a new task and queue it for processing.

    The task is first stored through ``TaskModel`` (MySQL) and then published
    to ``KAFKA_TOPIC`` as ``{"task_id": ..., "prompt": ...}`` (JSON-encoded by
    the producer's value serializer).

    Args:
        task_request: Request body carrying the user's prompt.

    Returns:
        ``{"task_id": <uuid4 string>}`` identifying the new task.

    Raises:
        HTTPException: status 500 if storing or publishing the task fails.
    """
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "prompt": task_request.prompt,
    }

    try:
        # Persist first so the task row already exists by the time any
        # consumer receives the message. Publishing first allowed the message
        # to be consumed before the row existed, or to outlive a failed insert.
        tm = TaskModel()
        try:
            tm.insert_task(task_id, task_request.prompt)
        finally:
            # Release the DB resource even when the insert fails.
            tm.close()

        # send() is asynchronous and flush() does not report delivery errors;
        # blocking on the returned future raises if the publish failed, so the
        # client no longer gets a task_id for a message that was never sent.
        producer.send(KAFKA_TOPIC, value=task_data).get(timeout=10)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to create task: {str(e)}"
        ) from e

    return {"task_id": task_id}
|
|
|
|
# Entry point: serve the application with uvicorn when run as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")