# --- Web code-viewer residue below; kept as comments so the file parses ---
# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 868 lines
# 34 KiB
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.
import asyncio
import base64
import datetime  # NOTE(review): shadowed by the `from datetime import ...` line below
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional

from alibabacloud_sts20150401 import models as sts_20150401_models
from alibabacloud_sts20150401.client import Client as Sts20150401Client
from alibabacloud_tea_openapi.models import Config
from fastapi import Query, Depends, status, Form, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from openai import AsyncOpenAI
from passlib.context import CryptContext
from starlette.responses import StreamingResponse

from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.EmbeddingUtil import text_to_embedding
from WxMini.Utils.ImageUtil import *
from WxMini.Utils.MySQLUtil import init_mysql_pool, get_chat_log_by_session, get_user_by_login_name, \
    get_chat_logs_by_risk_flag, get_chat_logs_summary, save_chat_to_mysql
from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id
from WxMini.Utils.NewsUtil import *
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, hmacsha256
from WxMini.Utils.TianQiUtil import get_weather
from WxMini.Utils.TtsUtil import TTS
# Logging configuration for the whole service.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Password hashing context (bcrypt).
# NOTE(review): the /aichat/login handler below compares passwords as plain
# text and never uses this context — confirm what the user table stores.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 "password" flow; tokenUrl is what the docs UI posts credentials to.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Milvus connection pool (MS_* constants come from a star import).
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
# Manager for the vector-store collection used by /aichat/reply.
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
# Application startup/shutdown handled via FastAPI lifespan events.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Milvus collection and create the MySQL pool on startup;
    close and drain both pools on shutdown."""
    # Load the collection into memory so searches can be served immediately.
    collection_manager.load_collection()
    logger.info(f"集合 '{collection_name}' 已加载到内存。")
    # Initialise the async MySQL connection pool (presumably aiomysql — verify).
    app.state.mysql_pool = await init_mysql_pool()
    logger.info("MySQL 连接池已初始化。")
    yield
    # Shutdown: release both pools.
    milvus_pool.close()
    app.state.mysql_pool.close()
    await app.state.mysql_pool.wait_closed()
    logger.info("Milvus 和 MySQL 连接池已关闭。")
# After a session ends, check whether the conversation needs intervention.
async def on_session_end(person_id):
    """Fetch the latest chat record for *person_id*, ask the model to classify
    it against the reference document in Input.txt, and flag the user via
    update_risk when the model's answer starts with "NO" (problem found).
    The model call itself runs as a fire-and-forget background task.
    """
    # Id of the most recent chat-log row (falsy when there is no history).
    last_id = await get_last_chat_log_id(app.state.mysql_pool, person_id)
    if last_id:
        # Load the question/answer pair for that record.
        async with app.state.mysql_pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "SELECT user_input, model_response FROM t_chat_log WHERE id = %s",
                    (last_id,)
                )
                last_record = await cur.fetchone()
                if last_record:
                    history = f"问题:{last_record[0]}\n回答:{last_record[1]}"
                else:
                    history = "无聊天记录"
    else:
        history = "无聊天记录"
    # Reference classification document embedded into the analysis prompt.
    # NOTE(review): blocking file read inside an async function, and the path
    # is relative to the process CWD — confirm Input.txt location.
    with open("Input.txt", "r", encoding="utf-8") as file:
        input_word = file.read()
    prompt = (
        "分析用户是否存在心理健康方面的问题:"
        f"参考分类文档内容如下:{input_word},注意:只有情节比较严重的才认为有健康问题,轻微的不算。"
        "如果没有健康问题请回复: OK否则回复NO换行后输出问题类型的名称"
        f"\n\n聊天记录:{history}"
    )

    # Run the model analysis without blocking the caller.
    async def analyze_mental_health():
        # Ask the chat model to classify the conversation.
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1000
        )
        # A reply starting with "NO" means a problem was found: persist it.
        if response.choices and response.choices[0].message.content:
            analysis_result = response.choices[0].message.content.strip()
            if analysis_result.startswith("NO"):
                # Persist the risk flag asynchronously.
                await update_risk(app.state.mysql_pool, person_id, analysis_result)
                logger.info(f"已异步更新 person_id={person_id} 的风险状态。")
            else:
                logger.info(f"AI大模型没有发现任何心理健康问题用户会话 {person_id} 没有风险。")

    # Fire and forget.
    asyncio.create_task(analyze_mental_health())
# FastAPI application instance; startup/shutdown handled by `lifespan` above.
app = FastAPI(lifespan=lifespan)
# Async OpenAI-compatible client pointed at Alibaba DashScope.
client = AsyncOpenAI(
    api_key=MODEL_API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# Password verification helper.
def verify_password(plain_password, hashed_password):
    """Return True when *plain_password* matches the bcrypt *hashed_password*."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
# Create a signed JWT access token.
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build and sign a JWT containing *data* plus an "exp" claim.

    :param data: claims to embed (this app puts the login name under "sub")
    :param expires_delta: optional token lifetime; defaults to 600 minutes
    :return: the encoded JWT string
    """
    to_encode = data.copy()
    # Use an aware UTC timestamp: datetime.utcnow() is deprecated and returns
    # a naive datetime, which is error-prone when mixed with aware values.
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=600)
    to_encode.update({"exp": expire})
    # JWT_SECRET_KEY / ALGORITHM come from a star import at module level.
    encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
# Resolve the current user from the bearer token.
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Decode the JWT, load the matching user record, or raise HTTP 401."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭证",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Keep the try body to just the call that can raise JWTError.
    try:
        claims = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    login_name: str = claims.get("sub")
    if login_name is None:
        raise credentials_exception
    user = await get_user_by_login_name(app.state.mysql_pool, login_name)
    if user is None:
        raise credentials_exception
    return user
# Login endpoint.
@app.post("/aichat/login")
async def login(
        login_name: str = Form(..., description="用户名"),
        password: str = Form(..., description="密码")
):
    """
    Authenticate a user and return profile data plus a JWT access token.

    :param login_name: user name
    :param password: password
    :return: login result dict; success=False with a failure message otherwise
    """
    # Look the user up first; every failure case shares the response below.
    user = await get_user_by_login_name(app.state.mysql_pool, login_name)
    # NOTE(review): this compares the stored password as plain text and never
    # calls verify_password / pwd_context — confirm what the user table stores
    # before hardening this.
    authenticated = bool(login_name) and bool(password) and bool(user) \
        and user['password'] == password
    if not authenticated:
        return {
            "code": 200,
            "message": "登录失败",
            "success": False
        }
    # Issue the JWT with the configured lifetime.
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user["login_name"]}, expires_delta=access_token_expires
    )
    # Profile fields plus the bearer token.
    return {
        "message": "登录成功",
        "success": True,
        "data": {
            "person_id": user["person_id"],
            "login_name": user["login_name"],
            "identity_id": user["identity_id"],
            "person_name": user["person_name"],
            "xb_name": user["xb_name"],
            "city_name": user["city_name"],
            "area_name": user["area_name"],
            "school_name": user["school_name"],
            "grade_name": user["grade_name"],
            "class_name": user["class_name"],
            "access_token": access_token,
            "token_type": "bearer"
        }
    }
# Chat endpoint: answer the user with vector + recent history as context.
@app.post("/aichat/reply")
async def reply(person_id: str = Form(...),
                prompt: str = Form(...),
                current_user: dict = Depends(get_current_user)
                ):
    """
    Take the user's prompt, build a context from the Milvus vector store and
    recent MySQL chat logs, call the chat model, synthesise TTS audio, upload
    it to OSS, persist the exchange, and return the reply plus the audio URL.

    :param person_id: conversation/user id
    :param prompt: the user's message
    :param current_user: injected by the JWT dependency
    :return: dict with success flag, audio url, search time, audio duration and reply
    """
    logger.info(f"current_user= {current_user}")
    try:
        logger.info(f"收到用户输入: {prompt}")
        if not prompt:
            return {
                "code": 200,
                "message": "请输入内容",
                "success": False
            }
        # Borrow a Milvus connection from the pool (released in `finally`).
        connection = milvus_pool.get_connection()
        # Embed the user input for the similarity search.
        current_embedding = text_to_embedding(prompt)
        search_params = {
            "metric_type": "L2",  # L2 distance metric
            "params": {"nprobe": MS_NPROBE}  # IVF_FLAT nprobe parameter
        }
        start_time = time.time()
        # The Milvus client blocks; run the search in a worker thread.
        results = await asyncio.to_thread(
            collection_manager.search,
            data=current_embedding,
            search_params=search_params,
            expr=f"person_id == '{person_id}'",  # filter by person_id
            limit=6  # top-6 similar interactions
        )
        end_time = time.time()
        # Build the history prompt from the most similar past interactions.
        history_prompt = ""
        if results:
            for hits in results:
                for hit in hits:
                    try:
                        # Fetch the non-vector fields for this hit.
                        record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
                        if record:
                            history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
                    except Exception as e:
                        logger.error(f"查询失败: {e}")
        # Append this person's most recent chat logs from MySQL.
        try:
            recent_logs = await get_chat_log_by_session(app.state.mysql_pool, person_id)
            data = recent_logs["data"]
            for record in data:
                history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
        except Exception as e:
            logger.error(f"获取交互记录时出错:{e}")
        # Cap the context length.
        history_prompt = history_prompt[:3000]
        # Weather questions: replace the context with live weather info.
        # BUG FIX: the condition previously ended with `or '' in prompt`, which
        # is always True ("" is a substring of every string), so this branch
        # fired on every request and wiped the chat-history context.
        if '天气' in prompt or '降温' in prompt or '气温' in prompt or '下雨' in prompt or '下雪' in prompt:
            weather_info = await get_weather('长春')
            history_prompt = f"天气信息: {weather_info}\n"
        logger.info(f"历史交互提示词: {history_prompt}")
        # News (NBA/CBA etc.): when the helper produces content, use it as context.
        result = await get_news(client, prompt)
        if result is not None:
            history_prompt = result
            logger.info("新闻返回了下面的内容:" + result)
        # Call the chat model with the assembled context.
        try:
            response = await asyncio.wait_for(
                client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system",
                         "content": "你是一个和你聊天人的好朋友,疏导情绪,让他开心,亲切一些,不要使用哎呀这样的语气词。聊天的回复内容不要超过100字。"},
                        {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"}
                    ],
                    max_tokens=4000
                ),
                timeout=60  # 60-second timeout for the model call
            )
        except asyncio.TimeoutError:
            logger.error("大模型调用超时")
            # NOTE(review): this HTTPException is caught by the broad handler
            # below and re-wrapped — behaviour preserved from the original.
            raise HTTPException(status_code=500, detail="大模型调用超时")
        # Extract the generated reply.
        if response.choices and response.choices[0].message.content:
            result = response.choices[0].message.content.strip()
            logger.info(f"大模型回复: {result}")
            # Record the exchange in the vector store (fields capped at 500 chars).
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            entities = [
                [person_id],  # person_id
                [prompt[:500]],  # user_input, truncated to 500 chars
                [result[:500]],  # model_response, truncated to 500 chars
                [timestamp],  # timestamp
                [current_embedding]  # embedding
            ]
            if len(prompt) > 500:
                logger.warning(f"用户输入被截断,原始长度: {len(prompt)}")
            if len(result) > 500:
                logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
            await asyncio.to_thread(collection_manager.insert_data, entities)
            logger.info("用户输入和大模型反馈已记录到向量数据库。")
            # Synthesise TTS audio in memory and upload it straight to OSS.
            uuid_str = str(uuid.uuid4())
            timestamp = int(time.time())
            # Date-based directory for the audio object key.
            audio_dir = f"audio/{time.strftime('%Y%m%d', time.localtime())}"
            tts_file = f"{audio_dir}/{uuid_str}_{timestamp}.mp3"
            t = TTS(None)  # None: do not write a local file
            audio_data, duration = await asyncio.to_thread(t.generate_audio, result)
            await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
            logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}")
            # Public URL of the uploaded audio.
            url = OSS_PREFIX + tts_file
            # Persist the exchange to MySQL.
            await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
            logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
            # Fire-and-forget mental-health analysis of this session.
            asyncio.create_task(on_session_end(person_id))
            return {
                "success": True,
                "url": url,
                "search_time": end_time - start_time,  # vector search time
                "duration": duration,  # audio duration
                "response": result,  # model reply
            }
        else:
            raise HTTPException(status_code=500, detail="大模型未返回有效结果")
    except Exception as e:
        logger.error(f"调用大模型失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
    finally:
        # Release the Milvus connection if it was acquired.
        if 'connection' in locals() and connection:
            milvus_pool.release_connection(connection)
# Chat-history endpoint.
@app.get("/aichat/get_chat_log")
async def get_chat_log(
        person_id: str,
        page: int = Query(default=1, ge=1, description="当前页码"),
        page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
        current_user: dict = Depends(get_current_user)
):
    """
    Return the chat log for one conversation; the helper defaults to the
    newest records (presumably resolving the last page — verify in MySQLUtil).

    :param person_id: conversation/user id
    :param page: page number (defaults to 1)
    :param page_size: records per page
    :return: paginated data
    """
    logger.info(f"current_user={current_user['login_name']}")
    # Pagination is handled entirely by the MySQL helper.
    return await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size)
# Risky chat-log listing endpoint.
@app.get("/aichat/get_risk_chat_logs")
async def get_risk_chat_logs(
        risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕)"),
        page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
        page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
        person_id: str = Query(..., description="用户会话 ID"),
        current_user: dict = Depends(get_current_user)
):
    """
    List chat logs filtered by risk flag, paginated.

    :param risk_flag: 0 = no risk, 1 = at risk, 2 = handled
    :param page: page number (1-based)
    :param page_size: records per page
    :param person_id: conversation/user id
    :return: paginated data
    """
    logger.info(f"current_user={current_user['login_name']}")
    # Validate risk_flag the same way /aichat/chat_logs_summary does; previously
    # an out-of-range value fell straight through to the query.
    if risk_flag not in {0, 1, 2}:
        raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2")
    # Pagination offset.
    offset = (page - 1) * page_size
    logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, person_id, offset, page_size)
    if not logs:
        return {
            "success": False,
            "message": "没有找到相关记录",
            "data": {
            }
        }
    # Paginated result.
    return {
        "success": True,
        "message": "查询成功",
        "data": {
            "total": total,
            "page": page,
            "page_size": page_size,
            "logs": logs
        }
    }
# Risk-statistics endpoint.
@app.get("/aichat/chat_logs_summary")
async def chat_logs_summary(
        risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕)"),
        page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
        page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
        current_user: dict = Depends(get_current_user)
):
    """
    Summarise chat logs by risk flag with pagination.

    :param risk_flag: 0 = no risk, 1 = at risk, 2 = handled
    :param page: page number (1-based)
    :param page_size: records per page
    :param current_user: injected by the JWT dependency
    :return: paginated summary data
    """
    # Reject anything outside the known flag values.
    if risk_flag not in {0, 1, 2}:
        raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2")
    offset = (page - 1) * page_size
    logs, total = await get_chat_logs_summary(app.state.mysql_pool, risk_flag, offset, page_size)
    # Friendly empty result when nothing matched.
    if not logs:
        return {
            "success": True,
            "message": "未找到符合条件的记录",
            "data": {
                "total": 0,
                "page": page,
                "page_size": page_size,
                "total_pages": 0,
                "logs": []
            }
        }
    # Ceiling division for the page count.
    page_count = -(-total // page_size)
    return {
        "success": True,
        "message": "查询成功",
        "data": {
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": page_count,
            "logs": logs,
            "login_name": current_user["login_name"],
            "person_name": current_user["person_name"]
        }
    }
# Issue a signed POST policy (via STS temporary credentials) for direct OSS upload.
@app.get("/aichat/get_post_signature_for_oss_upload")
async def generate_upload_params(current_user: dict = Depends(get_current_user)):
    """Assume a RAM role via STS and return the OSS V4 POST-policy form fields
    (policy, credential, date, signature, security token, host, dir) a client
    needs to upload a file straight to the bucket."""
    logger.info(f"current_user={current_user['login_name']}")
    # RAM sub-account AK/SK and role ARN.
    # SECURITY(review): credentials are hard-coded in source — move to secret
    # storage and rotate these keys.
    access_key_id = "LTAI5tJrhwuBzF2X9USrzubX"
    access_key_secret = "I6ezLuYhk9z9MRjXD2q99STSpTONwW"
    role_arn_for_oss_upload = "acs:ram::1546399445482588:role/huanghai-create-role"
    # Bucket name
    oss_bucket = 'hzkc'
    # Region
    region_id = 'cn-beijing'
    host = f'http://{oss_bucket}.oss-cn-beijing.aliyuncs.com'
    upload_dir = 'Upload'  # key prefix for uploaded objects
    role_session_name = 'role_session_name'  # arbitrary STS session name
    # STS client config carrying the sub-account credentials.
    config = Config(
        region_id=region_id,
        access_key_id=access_key_id,
        access_key_secret=access_key_secret
    )
    # Assume the role to obtain temporary credentials.
    # NOTE(review): this SDK call is synchronous and blocks the event loop.
    sts_client = Sts20150401Client(config=config)
    assume_role_request = sts_20150401_models.AssumeRoleRequest(
        role_arn=role_arn_for_oss_upload,
        role_session_name=role_session_name
    )
    response = sts_client.assume_role(assume_role_request)
    token_data = response.body.credentials.to_map()
    # Temporary credentials returned by STS.
    sts_access_key_id = token_data['AccessKeyId']
    sts_access_key_secret = token_data['AccessKeySecret']
    security_token = token_data['SecurityToken']
    now = int(time.time())
    # Current UTC time.  NOTE(review): datetime.utcfromtimestamp is deprecated
    # (3.12+) and naive — consider datetime.fromtimestamp(now, tz=timezone.utc).
    dt_obj = datetime.utcfromtimestamp(now)
    # Policy expires 1 hour from now (the original comment claimed 3 hours,
    # but the code adds hours=1; only the variable name says "3h").
    dt_obj_plus_3h = dt_obj + timedelta(hours=1)
    # Request timestamp, ISO basic format, e.g. 20240101T120000Z.
    dt_obj_1 = dt_obj.strftime('%Y%m%dT%H%M%S') + 'Z'
    # Request date, yyyymmdd.
    dt_obj_2 = dt_obj.strftime('%Y%m%d')
    # Policy expiration timestamp.
    expiration_time = dt_obj_plus_3h.strftime('%Y-%m-%dT%H:%M:%S.000Z')
    # Build the POST policy (OSS V4 signature scheme).
    policy = {
        "expiration": expiration_time,
        "conditions": [
            ["eq", "$success_action_status", "200"],
            {"x-oss-signature-version": "OSS4-HMAC-SHA256"},
            {"x-oss-credential": f"{sts_access_key_id}/{dt_obj_2}/{region_id}/oss/aliyun_v4_request"},
            {"x-oss-security-token": security_token},
            {"x-oss-date": dt_obj_1},
        ]
    }
    policy_str = json.dumps(policy).strip()
    # Step 2: the string to sign is the base64-encoded policy.
    stringToSign = base64.b64encode(policy_str.encode()).decode()
    # Step 3: derive the signing key (date -> region -> service -> request).
    dateKey = hmacsha256(("aliyun_v4" + sts_access_key_secret).encode(), dt_obj_2)
    dateRegionKey = hmacsha256(dateKey, region_id)
    dateRegionServiceKey = hmacsha256(dateRegionKey, "oss")
    signingKey = hmacsha256(dateRegionServiceKey, "aliyun_v4_request")
    # Step 4: compute the signature over the string-to-sign.
    result = hmacsha256(signingKey, stringToSign)
    signature = result.hex()
    # Form fields the browser/client posts to OSS.
    response_data = {
        'policy': stringToSign,  # form field
        'x_oss_signature_version': "OSS4-HMAC-SHA256",  # fixed signature version
        'x_oss_credential': f"{sts_access_key_id}/{dt_obj_2}/{region_id}/oss/aliyun_v4_request",  # key-derivation params
        'x_oss_date': dt_obj_1,  # request time
        'signature': signature,  # computed signature
        'host': host,
        'dir': upload_dir,
        'security_token': security_token  # STS security token
    }
    return response_data
# Describe an image ("what is this") with streamed output.
@app.get("/aichat/recognize_content")
async def web_recognize_content(image_url: str,
                                current_user: dict = Depends(get_current_user)
                                ):
    """Stream a qwen-vl-plus description of *image_url* character by character,
    then persist the full transcript to the chat log."""
    logger.info(f"current_user:{current_user['login_name']}")
    person_id = current_user['person_id']
    # Image dimensions, stored alongside the chat record.
    image_width, image_height = getImgWidthHeight(image_url)
    # Ask the vision model for a streamed description.
    response = await client.chat.completions.create(
        model="qwen-vl-plus",
        messages=[{"role": "user", "content": [
            {"type": "text", "text": "这是什么"},
            {"type": "image_url", "image_url": {"url": image_url}}
        ]}],
        stream=True
    )

    # Generator that re-streams the model output one character at a time.
    async def generate_stream():
        summary = ""  # accumulates the full reply for the DB record
        try:
            # Process the streamed chunks as they arrive.
            async for chunk in response:
                if chunk.choices[0].delta.content:
                    chunk_content = chunk.choices[0].delta.content
                    for char in chunk_content:
                        print(char, end="", flush=True)  # echo to console
                        yield char.encode("utf-8")  # stream to the client
                        summary += char
            # Stream finished: persist the transcript.
            # NOTE(review): the trailing 0, 2, 2, 2 values are positional flags
            # for save_chat_to_mysql — confirm their meaning in MySQLUtil.
            await save_chat_to_mysql(
                app.state.mysql_pool, person_id, f'{image_url}', summary, "", 0, 2, 2, 2, image_width, image_height
            )
        except asyncio.CancelledError:
            # Client disconnected early; nothing to clean up.
            print("客户端断开连接")
        except Exception as e:
            error_response = json.dumps({
                "success": False,
                "message": f"生成内容失败: {str(e)}"
            })
            print(error_response)
            yield error_response.encode("utf-8")

    # Stream the generator back with buffering disabled.
    return StreamingResponse(
        generate_stream(),
        media_type="text/plain; charset=utf-8",  # explicit UTF-8
        headers={
            "Cache-Control": "no-cache",  # disable caching
            "Content-Type": "text/event-stream; charset=utf-8",
            "Transfer-Encoding": "chunked",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx buffering
        }
    )
# OCR endpoint: read all text in an image, streamed.
@app.get("/aichat/recognize_text")
async def web_recognize_text(image_url: str,
                             current_user: dict = Depends(get_current_user)
                             ):
    """Stream qwen-vl-ocr's transcription of the text in *image_url* character
    by character, then persist the transcript to the chat log."""
    logger.info(f"current_user:{current_user['login_name']}")
    person_id = current_user['person_id']
    # Image dimensions, stored alongside the chat record.
    image_width, image_height = getImgWidthHeight(image_url)
    # Ask the OCR model for a streamed transcription.
    response = await client.chat.completions.create(
        model="qwen-vl-ocr",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": image_url,
                        # Input-image pixel bounds passed to qwen-vl-ocr.
                        "min_pixels": 28 * 28 * 4,
                        "max_pixels": 28 * 28 * 1280
                    },
                    {"type": "text", "text": "Read all the text in the image."},
                ]
            }
        ],
        stream=True
    )

    # Generator that re-streams the model output one character at a time.
    async def generate_stream():
        summary = ""  # accumulates the full reply for the DB record
        try:
            # Process the streamed chunks as they arrive.
            async for chunk in response:
                if chunk.choices[0].delta.content:
                    chunk_content = chunk.choices[0].delta.content
                    for char in chunk_content:
                        print(char, end="", flush=True)  # echo to console
                        yield char.encode("utf-8")  # stream to the client
                        summary += char
            # Stream finished: persist the transcript.
            # NOTE(review): the trailing 0, 2, 2, 1 values are positional flags
            # for save_chat_to_mysql — confirm their meaning in MySQLUtil.
            await save_chat_to_mysql(app.state.mysql_pool, person_id, f'{image_url}', summary, "", 0, 2, 2, 1,
                                     image_width,
                                     image_height)
        except asyncio.CancelledError:
            # Client disconnected early; nothing to clean up.
            print("客户端断开连接")
        except Exception as e:
            error_response = json.dumps({
                "success": False,
                "message": f"生成内容失败: {str(e)}"
            })
            print(error_response)
            yield error_response.encode("utf-8")

    # Stream the generator back with buffering disabled.
    return StreamingResponse(
        generate_stream(),
        media_type="text/plain; charset=utf-8",  # explicit UTF-8
        headers={
            "Cache-Control": "no-cache",  # disable caching
            "Content-Type": "text/event-stream; charset=utf-8",
            "Transfer-Encoding": "chunked",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx buffering
        }
    )
# Solve the maths problem in an image, streaming the explanation.
@app.get("/aichat/recognize_math")
async def web_recognize_math(image_url: str,
                             current_user: dict = Depends(get_current_user)
                             ):
    """Stream a step-by-step Chinese solution for the maths problem shown in
    *image_url* using Qwen2.5-VL via the ModelScope endpoint, then persist the
    transcript to the chat log.

    :param image_url: image URL
    :param current_user: injected by the JWT dependency
    :return: streaming response of the model's answer
    """
    logger.info(f"current_user:{current_user['login_name']}")
    person_id = current_user['person_id']
    # Image dimensions, stored alongside the chat record.
    image_width, image_height = getImgWidthHeight(image_url)
    # NOTE(review): this local `client` shadows the module-level DashScope
    # client and is rebuilt on every request.
    client = AsyncOpenAI(
        api_key=MODELSCOPE_ACCESS_TOKEN,
        base_url="https://api-inference.modelscope.cn/v1"
    )
    # System prompt used for Qwen step-by-step reasoning.
    prompt = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
    response = await client.chat.completions.create(
        model="Qwen/Qwen2.5-VL-32B-Instruct",
        messages=[
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": prompt}
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": image_url}
                    },
                    {"type": "text", "text": "请使用中文回答:如何作答?"},
                ],
            }
        ],
        stream=True
    )

    # Generator that re-streams the model output one character at a time.
    async def generate_stream():
        summary = ""  # accumulates the full reply for the DB record
        try:
            # Process the streamed chunks as they arrive.
            async for chunk in response:
                if chunk.choices[0].delta.content:
                    chunk_content = chunk.choices[0].delta.content
                    for char in chunk_content:
                        print(char, end="", flush=True)  # echo to console
                        yield char.encode("utf-8")  # stream to the client
                        summary += char
            # Stream finished: persist the transcript.
            # NOTE(review): the trailing 0, 2, 2, 1 values are positional flags
            # for save_chat_to_mysql — confirm their meaning in MySQLUtil.
            await save_chat_to_mysql(app.state.mysql_pool, person_id, f'{image_url}', summary, "", 0, 2, 2, 1,
                                     image_width, image_height)
        except asyncio.CancelledError:
            # Client disconnected early; nothing to clean up.
            print("客户端断开连接")
        except Exception as e:
            error_response = json.dumps({
                "success": False,
                "message": f"生成内容失败: {str(e)}"
            })
            print(error_response)
            yield error_response.encode("utf-8")

    # Stream the generator back with buffering disabled.
    return StreamingResponse(
        generate_stream(),
        media_type="text/plain; charset=utf-8",  # explicit UTF-8
        headers={
            "Cache-Control": "no-cache",  # disable caching
            "Content-Type": "text/event-stream; charset=utf-8",
            "Transfer-Encoding": "chunked",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx buffering
        }
    )
# Run the FastAPI application.
if __name__ == "__main__":
    import uvicorn

    # Uvicorn logging: one stderr handler for everything, ASGI logger silenced.
    uvicorn_log_config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {
                "()": "uvicorn.logging.DefaultFormatter",
                "fmt": "%(asctime)s %(levelprefix)s %(message)s",
                "datefmt": "%Y-%m-%d %H:%M:%S",
            }
        },
        "handlers": {
            "default": {
                "formatter": "default",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            }
        },
        "loggers": {
            "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
            "uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
            "uvicorn.error": {"handlers": ["default"], "level": "INFO", "propagate": False},
            "uvicorn.asgi": {"handlers": [], "level": "INFO", "propagate": False},  # silence ASGI logs
        },
    }
    # Force-override uvicorn's default logging config.
    # NOTE(review): "Start:app" implies this module is named Start.py — confirm.
    uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1, log_config=uvicorn_log_config)