@@ -1,5 +1,6 @@
import asyncio
import logging
import re
import time
import uuid
from contextlib import asynccontextmanager
@@ -11,6 +12,7 @@ from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from openai import AsyncOpenAI
from passlib.context import CryptContext
from starlette.responses import StreamingResponse

from WxMini.Milvus.Config.MulvusConfig import *
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
@@ -267,7 +269,7 @@ async def reply(person_id: str = Form(...),
            # Query the non-vector fields for this hit
            record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
            if record:
                #logger.info(f"查询到的记录: {record}")
                # logger.info(f"查询到的记录: {record}")
                # Append this past exchange to the history
                history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
    except Exception as e:
@@ -275,7 +277,7 @@ async def reply(person_id: str = Form(...),

    # Limit the length of the history prompt
    history_prompt = history_prompt[:2000]
    #logger.info(f"历史交互提示词: {history_prompt}")
    # logger.info(f"历史交互提示词: {history_prompt}")

    # Call the LLM, passing the history as part of the prompt
    try:
@@ -314,7 +316,7 @@ async def reply(person_id: str = Form(...),
        if len(result) > 500:
            logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
        await asyncio.to_thread(collection_manager.insert_data, entities)
        #logger.info("用户输入和大模型反馈已记录到向量数据库。")
        # logger.info("用户输入和大模型反馈已记录到向量数据库。")

        # Call TTS to generate the MP3
        uuid_str = str(uuid.uuid4())
@@ -327,7 +329,7 @@ async def reply(person_id: str = Form(...),
        t = TTS(None)  # Passing None means the audio is not saved to a local file
        audio_data, duration = await asyncio.to_thread(t.generate_audio,
                                                       result)  # Assumes the TTS class has a generate_audio method that returns the audio data
        #print(f"音频时长: {duration} 秒")
        # print(f"音频时长: {duration} 秒")

        # Upload the audio data to OSS directly from memory
        await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
@@ -338,7 +340,7 @@ async def reply(person_id: str = Form(...),

        # Save the chat record to MySQL
        await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
        #logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
        # logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")

        # Trigger the session-check mechanism asynchronously
        asyncio.create_task(on_session_end(person_id))
@@ -402,7 +404,8 @@ async def get_risk_chat_logs(
    offset = (page - 1) * page_size

    # Call get_chat_logs_by_risk_flag
    logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag,current_user["person_id"], offset, page_size)
    logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, current_user["person_id"], offset,
                                                   page_size)
    if not logs:
        return {
            "success": False,
@@ -425,7 +428,6 @@ async def get_risk_chat_logs(
    }


# Endpoint: risk statistics for chat logs
@app.get("/aichat/chat_logs_summary")
async def chat_logs_summary(
@@ -484,6 +486,7 @@ async def chat_logs_summary(
        }
    }


# Endpoint: obtain an authorization token for uploading to OSS
@app.get("/aichat/get_oss_upload_token")
async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):
@@ -504,6 +507,107 @@ async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):
    }


async def is_text_dominant(image_url):
    """
    Determine whether an image is mostly text.
    :param image_url: image URL
    :return: True (mostly text) / False (mostly objects/scenery)
    """
    completion = await client.chat.completions.create(
        model="qwen-vl-ocr",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": image_url,
                        "min_pixels": 28 * 28 * 4,
                        "max_pixels": 28 * 28 * 1280
                    },
                    {"type": "text", "text": "Read all the text in the image."},
                ]
            }
        ],
        stream=False
    )
    text = completion.choices[0].message.content

    # If the OCR result contains only English letters, digits and whitespace,
    # treat it as meaningless characters and fall back to content recognition.
    if re.match(r'^[A-Za-z0-9\s]+$', text):
        print("识别到的内容只有英文和数字,可能是无意义的字符,调用识别内容功能。")
        return False
    return True


async def recognize_text(image_url):
    """
    Recognize the text in an image and stream it out.
    """
    completion = await client.chat.completions.create(
        model="qwen-vl-ocr",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": image_url,
                        "min_pixels": 28 * 28 * 4,
                        "max_pixels": 28 * 28 * 1280
                    },
                    {"type": "text", "text": "Read all the text in the image."},
                ]
            }
        ],
        stream=True
    )
    async for chunk in completion:
        if chunk.choices[0].delta.content is not None:
            for char in chunk.choices[0].delta.content:
                if char != ' ':
                    yield char
                    # Pace the output; use asyncio.sleep instead of time.sleep
                    # so the event loop is not blocked.
                    await asyncio.sleep(0.1)


async def recognize_content(image_url):
    """
    Recognize the content of an image and stream the description.
    """
    completion = await client.chat.completions.create(
        model="qwen-vl-plus",
        messages=[{"role": "user", "content": [
            {"type": "text", "text": "这是什么"},
            {"type": "image_url", "image_url": {"url": image_url}}
        ]}],
        stream=True
    )
    async for chunk in completion:
        if chunk.choices[0].delta.content is not None:
            for char in chunk.choices[0].delta.content:
                yield char
                # Pace the output without blocking the event loop.
                await asyncio.sleep(0.1)


@app.get("/aichat/process_image")
async def process_image(image_url: str, current_user: dict = Depends(get_current_user)):
    """
    Process an image and automatically decide which recognizer to call.
    :param image_url: image URL
    :return: streamed result
    """
    logger.info(f"current_user:{current_user['login_name']}")
    try:
        if await is_text_dominant(image_url):
            print("检测到图片主要是文字内容,开始识别文字:")
            return StreamingResponse(recognize_text(image_url), media_type="text/plain")
        else:
            print("检测到图片主要是物体/场景,开始识别内容:")
            return StreamingResponse(recognize_content(image_url), media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# Run the FastAPI application
if __name__ == "__main__":
    import uvicorn
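
The diff does not show how a caller consumes the new streaming endpoint, so below is a minimal client sketch. Only the /aichat/process_image path, the image_url query parameter, and the bearer-token requirement (implied by Depends(get_current_user)) come from the code above; the host, port, token value, and sample image URL are placeholders.

# Minimal client sketch for the streaming endpoint.
# Assumptions: the service listens on http://localhost:8000 and an access token
# has already been obtained from the existing OAuth2 login flow.
import requests

BASE_URL = "http://localhost:8000"   # assumed host/port, not part of the diff
TOKEN = "<access-token>"             # assumed: bearer token from the login endpoint


def stream_process_image(image_url: str) -> None:
    """Print the streamed recognition result as it arrives."""
    with requests.get(
        f"{BASE_URL}/aichat/process_image",
        params={"image_url": image_url},
        headers={"Authorization": f"Bearer {TOKEN}"},
        stream=True,      # read the text/plain response incrementally
        timeout=60,
    ) as resp:
        resp.raise_for_status()
        resp.encoding = "utf-8"  # the stream may contain Chinese text
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                print(chunk, end="", flush=True)


if __name__ == "__main__":
    stream_process_image("https://example.com/sample.jpg")  # placeholder image URL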