You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

616 lines
23 KiB

4 months ago
import asyncio
import hmac
import logging
import re
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Optional

from fastapi import Query, Depends, HTTPException, status, Form, FastAPI
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from openai import AsyncOpenAI
from passlib.context import CryptContext
from starlette.responses import StreamingResponse

from WxMini.Milvus.Config.MulvusConfig import *
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.EmbeddingUtil import text_to_embedding
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, get_user_by_login_name, \
    get_chat_logs_by_risk_flag, get_chat_logs_summary
from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, get_sts_token
from WxMini.Utils.TtsUtil import TTS
4 months ago
4 months ago
# Configure logging: INFO level, each record rendered as "<time> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
4 months ago
# Password-hashing context using the bcrypt scheme (used by verify_password).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 password-flow bearer scheme; get_current_user depends on it to obtain the token.
# NOTE(review): tokenUrl is "token" while the login route in this file is "/aichat/login" — confirm
# this mismatch is intentional.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
4 months ago
# Initialise the Milvus connection pool from the MS_* settings (brought in by the star imports).
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
# Initialise the collection manager for the configured collection name.
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
4 months ago
4 months ago
# Use FastAPI lifespan events to handle application startup and shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the resources that live as long as the application.

    Startup: load the Milvus collection into memory and create the MySQL pool,
    which is stored on ``app.state.mysql_pool``.
    Shutdown: close the Milvus pool and the MySQL pool.

    The cleanup runs in a ``finally`` block so that the pools are released even
    when startup fails half-way or an exception is thrown into the generator.

    :param app: the FastAPI application whose state receives the MySQL pool
    """
    mysql_pool = None
    try:
        # Load the collection into memory when the application starts.
        collection_manager.load_collection()
        logger.info(f"集合 '{collection_name}' 已加载到内存。")

        # Initialise the MySQL connection pool.
        mysql_pool = await init_mysql_pool()
        app.state.mysql_pool = mysql_pool
        logger.info("MySQL 连接池已初始化。")

        yield
    finally:
        # Release the pools when the application shuts down.
        milvus_pool.close()
        if mysql_pool is not None:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        logger.info("Milvus 和 MySQL 连接池已关闭。")
4 months ago
4 months ago
4 months ago
# Strong references to in-flight analysis tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected mid-run.
_mental_health_tasks = set()


# After a chat turn, check whether the conversation shows a problem that needs intervention.
async def on_session_end(person_id):
    """Schedule a mental-health risk analysis of the user's latest chat record.

    Reads the most recent ``t_chat_log`` row for ``person_id``, builds a prompt
    from it and the contents of ``Input.txt``, and starts a background task that
    asks the model for a verdict. When the model's answer starts with ``NO`` the
    answer is passed to ``update_risk``.

    The function returns as soon as the background task has been scheduled; it
    does not wait for the model call.

    :param person_id: identifier whose latest chat record is analysed
    :return: None
    """
    # Fetch the id of the latest chat record.
    last_id = await get_last_chat_log_id(app.state.mysql_pool, person_id)

    history = "无聊天记录"
    if last_id:
        # Look up the details of that latest record.
        async with app.state.mysql_pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "SELECT user_input, model_response FROM t_chat_log WHERE id = %s",
                    (last_id,)
                )
                last_record = await cur.fetchone()
        if last_record:
            history = f"问题:{last_record[0]}\n回答:{last_record[1]}"

    # Send the chat history to the model together with the classification reference document.
    with open("Input.txt", "r", encoding="utf-8") as file:
        input_word = file.read()

    prompt = (
        "分析用户是否存在心理健康方面的问题:"
        f"参考分类文档内容如下:{input_word},注意:只有情节比较严重的才认为有健康问题,轻微的不算。"
        "如果没有健康问题请回复: OK否则回复NO换行后输出问题类型的名称"
        f"\n\n聊天记录:{history}"
    )

    async def analyze_mental_health():
        """Call the model and persist the risk flag; never raises."""
        try:
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1000
            )
            # Handle the analysis result.
            if response.choices and response.choices[0].message.content:
                analysis_result = response.choices[0].message.content.strip()
                if analysis_result.startswith("NO"):
                    await update_risk(app.state.mysql_pool, person_id, analysis_result)
                    logger.info(f"已异步更新 person_id={person_id} 的风险状态。")
                else:
                    logger.info(f"AI大模型没有发现任何心理健康问题用户会话 {person_id} 没有风险。")
        except Exception:
            # Nobody awaits this task, so an escaping exception would go unreported.
            logger.exception(f"心理健康分析失败 person_id={person_id}")

    # Schedule the analysis and keep a reference until it has finished.
    task = asyncio.create_task(analyze_mental_health())
    _mental_health_tasks.add(task)
    task.add_done_callback(_mental_health_tasks.discard)
4 months ago
4 months ago
# Initialise the FastAPI application; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Initialise the asynchronous OpenAI client, pointed at the DashScope compatible-mode endpoint.
client = AsyncOpenAI(
    api_key=MODEL_API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
4 months ago
4 months ago
# Verify a password.
def verify_password(plain_password, hashed_password):
    """Check a plain-text password against its stored hash.

    :param plain_password: password as typed by the user
    :param hashed_password: hash to check it against
    :return: whatever ``pwd_context.verify`` returns for the pair
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
# Create a JWT.
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` as a signed JWT with an ``exp`` claim.

    :param data: claims to embed; the mapping itself is not modified
    :param expires_delta: lifetime of the token; a missing or zero value falls
        back to 15 minutes
    :return: the encoded token produced by ``jwt.encode``
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    # Build a fresh claims dict so the caller's mapping stays untouched.
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    return jwt.encode(claims, JWT_SECRET_KEY, algorithm=ALGORITHM)
# Resolve the current user.
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Dependency that turns a bearer token into the matching user record.

    :param token: bearer token supplied through ``oauth2_scheme``
    :return: the record returned by ``get_user_by_login_name`` for the token's ``sub`` claim
    :raises HTTPException: 401 when the token cannot be decoded, carries no
        ``sub`` claim, or names a user that cannot be found
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭证",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    login_name: str = claims.get("sub")
    if login_name is None:
        raise unauthorized

    account = await get_user_by_login_name(app.state.mysql_pool, login_name)
    if account is None:
        raise unauthorized
    return account
4 months ago
# Login endpoint.
@app.post("/aichat/login")
async def login(
        login_name: str = Form(..., description="用户名"),
        password: str = Form(..., description="密码")
):
    """Authenticate a user and issue a JWT access token.

    :param login_name: login name submitted in the form
    :param password: password submitted in the form
    :return: on success, a dict with the user's profile fields plus
        ``access_token`` and ``token_type``; on any failure, a dict with
        ``success`` set to False (every failure returns the same payload, so the
        response does not reveal whether the login name exists)
    """
    failure = {
        "code": 200,
        "message": "登录失败",
        "success": False
    }

    # Reject empty credentials before touching the database.
    if not login_name or not password:
        return failure

    user = await get_user_by_login_name(app.state.mysql_pool, login_name)
    if not user:
        return failure

    # NOTE(review): the stored password is compared as plain text, while this
    # file also defines a bcrypt `pwd_context` / `verify_password` that is not
    # used here. Presumably passwords are stored unhashed — confirm and migrate.
    stored_password = user['password']
    # Constant-time comparison avoids leaking how many leading characters match.
    if not isinstance(stored_password, str) or not hmac.compare_digest(
            stored_password.encode("utf-8"), password.encode("utf-8")):
        return failure

    # Generate the JWT.
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user["login_name"]}, expires_delta=access_token_expires
    )

    # Return the profile with named fields.
    return {
        "message": "登录成功",
        "success": True,
        "data": {
            "person_id": user["person_id"],
            "login_name": user["login_name"],
            "identity_id": user["identity_id"],
            "person_name": user["person_name"],
            "xb_name": user["xb_name"],
            "city_name": user["city_name"],
            "area_name": user["area_name"],
            "school_name": user["school_name"],
            "grade_name": user["grade_name"],
            "class_name": user["class_name"],
            "access_token": access_token,
            "token_type": "bearer"
        }
    }
4 months ago
4 months ago
# Strong references to background session checks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected mid-run.
_session_check_tasks = set()


def _on_session_check_done(task):
    """Forget a finished session-check task and log the exception it raised, if any."""
    _session_check_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.error(f"会话检查任务失败: {error}")


async def _build_history_prompt(results):
    """Render vector-search hits as a user/model transcript capped at 2000 characters."""
    turns = []
    for hits in results or []:
        for hit in hits:
            try:
                # Fetch the non-vector fields of the hit.
                record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
                if record:
                    turns.append(f"用户: {record['user_input']}\n大模型: {record['model_response']}\n")
            except Exception as e:
                # Best effort: one unreadable record must not fail the whole reply.
                logger.error(f"查询失败: {e}")
    # Limit the length of the history section of the prompt.
    return "".join(turns)[:2000]


# Chat with the user.
@app.post("/aichat/reply")
async def reply(person_id: str = Form(...),
                prompt: str = Form(...),
                current_user: dict = Depends(get_current_user)
                ):
    """Answer a user prompt with the model and return the reply plus a TTS audio URL.

    Steps: search Milvus for related earlier turns of ``person_id``, call the
    chat model with that history, store the new turn in Milvus, synthesise the
    reply to MP3, upload it to OSS, log the turn in MySQL, and schedule the
    background risk check.

    :param person_id: user/session id used to filter history and to store the turn
    :param prompt: text typed by the user
    :param current_user: authenticated user record (injected)
    :return: dict with ``success``, ``url``, ``search_time``, ``duration``,
        ``response``, ``login_name`` and ``person_name``; or a ``success: False``
        dict when ``prompt`` is empty
    :raises HTTPException: 400 when ``person_id`` contains quote or backslash
        characters; 500 when the model times out, returns nothing, or any other
        step fails
    """
    logger.info(f"收到用户输入: {prompt}")
    if not prompt:
        return {
            "code": 200,
            "message": "请输入内容",
            "success": False
        }

    # person_id is interpolated into the Milvus filter expression below, so
    # refuse characters that could terminate or escape the quoted literal.
    if any(ch in person_id for ch in ("'", '"', "\\")):
        raise HTTPException(status_code=400, detail="person_id 包含非法字符")
    # NOTE(review): person_id is not checked against current_user["person_id"];
    # confirm whether an authenticated user may act on another person's id.

    connection = None
    try:
        # Take a connection from the pool.
        connection = milvus_pool.get_connection()
        # Convert the user input to an embedding vector.
        # NOTE(review): this call is synchronous and runs on the event loop; if it
        # is slow and thread-safe, consider asyncio.to_thread like the calls below.
        current_embedding = text_to_embedding(prompt)

        # Find the stored turns most related to the current input.
        search_params = {
            "metric_type": "L2",  # L2 distance metric
            "params": {"nprobe": MS_NPROBE}  # nprobe parameter for IVF_FLAT
        }
        start_time = time.time()
        results = await asyncio.to_thread(  # run the blocking search in a worker thread
            collection_manager.search,
            data=current_embedding,  # query vector
            search_params=search_params,  # search parameters
            expr=f"person_id == '{person_id}'",  # filter by person_id
            limit=6  # return 6 results
        )
        end_time = time.time()

        history_prompt = await _build_history_prompt(results)

        # Call the model with the history as context.
        try:
            response = await asyncio.wait_for(
                client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system",
                         "content": "你是一个和你聊天人的好朋友,疏导情绪,让他开心,亲切一些,不要使用哎呀这样的语气词。聊天的回复内容不要超过100字。"},
                        {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"}
                    ],
                    max_tokens=2000
                ),
                timeout=60  # give up after 60 seconds
            )
        except asyncio.TimeoutError:
            logger.error("大模型调用超时")
            raise HTTPException(status_code=500, detail="大模型调用超时")

        # Extract the generated reply.
        if not (response.choices and response.choices[0].message.content):
            raise HTTPException(status_code=500, detail="大模型未返回有效结果")
        result = response.choices[0].message.content.strip()
        logger.info(f"大模型回复: {result}")

        # Store the user input and the model reply in the vector database.
        recorded_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        entities = [
            [person_id],  # person_id
            [prompt[:500]],  # user_input, truncated to 500 characters
            [result[:500]],  # model_response, truncated to 500 characters
            [recorded_at],  # timestamp
            [current_embedding]  # embedding
        ]
        if len(prompt) > 500:
            logger.warning(f"用户输入被截断,原始长度: {len(prompt)}")
        if len(result) > 500:
            logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
        await asyncio.to_thread(collection_manager.insert_data, entities)

        # Object key for the audio: audio/<YYYYMMDD>/<uuid>_<epoch seconds>.mp3
        uuid_str = str(uuid.uuid4())
        epoch_seconds = int(time.time())
        audio_dir = f"audio/{time.strftime('%Y%m%d', time.localtime())}"
        tts_file = f"{audio_dir}/{uuid_str}_{epoch_seconds}.mp3"

        # Generate the TTS audio in memory (None: do not save to a local file).
        t = TTS(None)
        audio_data, duration = await asyncio.to_thread(t.generate_audio, result)

        # Upload the audio bytes straight to OSS.
        await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
        logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}")

        # Full URL of the uploaded audio.
        url = OSS_PREFIX + tts_file

        # Record the chat turn in MySQL.
        await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)

        # Run the session check in the background, keeping a reference to the task.
        check_task = asyncio.create_task(on_session_end(person_id))
        _session_check_tasks.add(check_task)
        check_task.add_done_callback(_on_session_check_done)

        return {
            "success": True,
            "url": url,
            "search_time": end_time - start_time,  # time spent in the vector search
            "duration": duration,  # duration returned by the TTS step
            "response": result,  # the model's reply
            "login_name": current_user["login_name"],
            "person_name": current_user["person_name"]
        }
    except HTTPException:
        # Deliberate HTTP errors raised above keep their own detail message.
        raise
    except Exception as e:
        logger.error(f"调用大模型失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
    finally:
        # Release the connection only if one was actually acquired.
        if connection is not None:
            milvus_pool.release_connection(connection)
4 months ago
4 months ago
# Fetch chat records.
@app.get("/aichat/get_chat_log")
async def get_chat_log(
        person_id: str,
        page: int = Query(default=1, ge=1, description="当前页码"),
        page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
        current_user: dict = Depends(get_current_user)
):
    """Return one page of chat records for a person.

    The work is delegated entirely to ``get_chat_log_by_session``; its result is
    returned unchanged.

    :param person_id: user/session id whose records are requested
    :param page: page number, starting at 1
    :param page_size: number of records per page (1-100)
    :param current_user: authenticated user record (injected)
    :return: whatever ``get_chat_log_by_session`` returns
    """
    pool = app.state.mysql_pool
    return await get_chat_log_by_session(pool, current_user, person_id, page, page_size)
4 months ago
4 months ago
4 months ago
# Endpoint: chat records filtered by risk flag.
@app.get("/aichat/get_risk_chat_logs")
async def get_risk_chat_logs(
        risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕)"),
        page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
        page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
        current_user: dict = Depends(get_current_user)
):
    """Return one page of the current user's chat records with the given risk flag.

    :param risk_flag: risk flag to filter by; must be 0, 1 or 2
    :param page: page number, starting at 1
    :param page_size: number of records per page (1-100)
    :param current_user: authenticated user record (injected); its ``person_id``
        is passed to the query
    :return: dict with ``success``, ``message`` and ``data``; ``data`` is empty
        when nothing matched, otherwise it holds ``total``, ``page``,
        ``page_size``, ``total_pages`` and ``logs``
    :raises HTTPException: 400 when ``risk_flag`` is not 0, 1 or 2
    """
    # Same validation as the chat_logs_summary endpoint.
    if risk_flag not in {0, 1, 2}:
        raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2")

    # Compute the paging offset.
    offset = (page - 1) * page_size
    logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, current_user["person_id"], offset,
                                                   page_size)
    if not logs:
        return {
            "success": False,
            "message": "没有找到相关记录",
            "data": {
            }
        }

    # Ceiling division; additive field so existing clients are unaffected.
    total_pages = (total + page_size - 1) // page_size
    return {
        "success": True,
        "message": "查询成功",
        "data": {
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": total_pages,
            "logs": logs
        }
    }
# Endpoint: risk statistics.
@app.get("/aichat/chat_logs_summary")
async def chat_logs_summary(
        risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕)"),
        page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
        page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
        current_user: dict = Depends(get_current_user)
):
    """Return one page of the chat-log summary for a risk flag.

    :param risk_flag: risk flag to filter by; must be 0, 1 or 2
    :param page: page number, starting at 1
    :param page_size: number of records per page (1-100)
    :param current_user: authenticated user record (injected)
    :return: dict with ``success``, ``message`` and a ``data`` page; when rows
        were found ``data`` also carries the caller's ``login_name`` and
        ``person_name``
    :raises HTTPException: 400 when ``risk_flag`` is not 0, 1 or 2
    """
    # Reject anything outside the three known flag values.
    if risk_flag not in {0, 1, 2}:
        raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2")

    skip = (page - 1) * page_size
    rows, row_count = await get_chat_logs_summary(app.state.mysql_pool, risk_flag, skip, page_size)

    if rows:
        # Round the page count up when the last page is only partly filled.
        page_count, remainder = divmod(row_count, page_size)
        if remainder:
            page_count += 1
        found_page = {
            "total": row_count,
            "page": page,
            "page_size": page_size,
            "total_pages": page_count,
            "logs": rows,
            "login_name": current_user["login_name"],
            "person_name": current_user["person_name"]
        }
        return {"success": True, "message": "查询成功", "data": found_page}

    # Nothing matched: answer with an empty page rather than an error.
    empty_page = {
        "total": 0,
        "page": page,
        "page_size": page_size,
        "total_pages": 0,
        "logs": []
    }
    return {"success": True, "message": "未找到符合条件的记录", "data": empty_page}
4 months ago
4 months ago
4 months ago
# Endpoint: authorisation token for uploading to OSS.
@app.get("/aichat/get_oss_upload_token")
async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):
    """Return temporary STS credentials the client can use to upload to OSS.

    :param current_user: authenticated user record (injected)
    :return: dict with ``success``, ``message`` and ``data``; ``data`` holds the
        STS key id, key secret and security token, the bucket name, the
        endpoint, and the caller's ``login_name`` and ``person_name``
    """
    # get_sts_token is synchronous; run it in a worker thread, as this file does
    # for its other blocking helpers, so the event loop is not held up.
    sts_token = await asyncio.to_thread(get_sts_token)
    return {
        "success": True,
        "message": "获取上传凭证成功",
        "data": {
            "access_key_id": sts_token['AccessKeyId'],
            "access_key_secret": sts_token['AccessKeySecret'],
            "security_token": sts_token['SecurityToken'],
            "bucket_name": BUCKET_NAME,
            "endpoint": ENDPOINT,
            "login_name": current_user["login_name"],
            "person_name": current_user["person_name"]
        }
    }
4 months ago
4 months ago
async def is_text_dominant(image_url):
    """Decide whether an image is mainly text.

    Runs the OCR model on the image and inspects the recognised text. The image
    is treated as NOT text-dominant when the model returns no text at all, or
    when the text consists only of ASCII letters, digits and whitespace (which
    the original author treats as probably meaningless characters).

    :param image_url: URL of the image
    :return: True when the image is mainly text, False when it is mainly an
        object or scene
    """
    completion = await client.chat.completions.create(
        model="qwen-vl-ocr",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": image_url,
                        "min_pixels": 28 * 28 * 4,
                        "max_pixels": 28 * 28 * 1280
                    },
                    {"type": "text", "text": "Read all the text in the image."},
                ]
            }
        ],
        stream=False
    )
    # The content may be missing; guard before handing it to the regex.
    text = completion.choices[0].message.content if completion.choices else None
    if not text or not text.strip():
        # No recognisable text means there is nothing for the text path to read.
        return False

    # Only ASCII letters, digits and whitespace?
    if re.match(r'^[A-Za-z0-9\s]+$', text):
        logger.info("识别到的内容只有英文和数字,可能是无意义的字符,调用识别内容功能。")
        return False
    return True
async def recognize_text(image_url):
    """Stream the text recognised in an image, one character at a time.

    Space characters are dropped, and a 0.1 s pause follows every emitted
    character to pace the output.

    :param image_url: URL of the image
    :return: async generator yielding single characters
    """
    completion = await client.chat.completions.create(
        model="qwen-vl-ocr",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": image_url,
                        "min_pixels": 28 * 28 * 4,
                        "max_pixels": 28 * 28 * 1280
                    },
                    {"type": "text", "text": "Read all the text in the image."},
                ]
            }
        ],
        stream=True
    )
    async for chunk in completion:
        # Skip chunks that carry no choices instead of failing on choices[0].
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece is None:
            continue
        for char in piece:
            if char != ' ':
                yield char
                # Non-blocking pause: time.sleep here would stall the whole event loop.
                await asyncio.sleep(0.1)
async def recognize_content(image_url):
    """Stream a description of what an image shows, one character at a time.

    A 0.1 s pause follows every emitted character to pace the output.

    :param image_url: URL of the image
    :return: async generator yielding single characters
    """
    completion = await client.chat.completions.create(
        model="qwen-vl-plus",
        messages=[{"role": "user", "content": [
            {"type": "text", "text": "这是什么"},
            {"type": "image_url", "image_url": {"url": image_url}}
        ]}],
        stream=True
    )
    async for chunk in completion:
        # Skip chunks that carry no choices instead of failing on choices[0].
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece is None:
            continue
        for char in piece:
            yield char
            # Non-blocking pause: time.sleep here would stall the whole event loop.
            await asyncio.sleep(0.1)
@app.get("/aichat/process_image")
async def process_image(image_url: str, current_user: dict = Depends(get_current_user)):
    """Process an image, choosing automatically between text and content recognition.

    :param image_url: URL of the image
    :param current_user: authenticated user record (injected)
    :return: a plain-text ``StreamingResponse`` fed by ``recognize_text`` when
        the image is mainly text, otherwise by ``recognize_content``
    :raises HTTPException: 500 when the text/content classification fails.
        Errors raised later, while the stream is being consumed, are not caught
        by this handler.
    """
    logger.info(f"current_user:{current_user['login_name']}")
    try:
        if await is_text_dominant(image_url):
            logger.info("检测到图片主要是文字内容,开始识别文字:")
            return StreamingResponse(recognize_text(image_url), media_type="text/plain")
        else:
            logger.info("检测到图片主要是物体/场景,开始识别内容:")
            return StreamingResponse(recognize_content(image_url), media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
4 months ago
# Run the FastAPI application when this file is executed as a script.
if __name__ == "__main__":
    import uvicorn

    # "Start:app" is an import string naming the `app` object in a module called
    # Start — presumably this file is Start.py; confirm before renaming it.
    uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1)