|
|
import base64
|
|
|
import base64
|
|
|
import datetime
|
|
|
import json
|
|
|
import time
|
|
|
import uuid
|
|
|
from contextlib import asynccontextmanager
|
|
|
from datetime import datetime, timedelta
|
|
|
from typing import Optional
|
|
|
|
|
|
from alibabacloud_sts20150401 import models as sts_20150401_models
|
|
|
from alibabacloud_sts20150401.client import Client as Sts20150401Client
|
|
|
from alibabacloud_tea_openapi.models import Config
|
|
|
from fastapi import Query, Depends, status, Form, FastAPI
|
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
|
from jose import JWTError, jwt
|
|
|
from openai import AsyncOpenAI
|
|
|
from passlib.context import CryptContext
|
|
|
from starlette.responses import StreamingResponse
|
|
|
|
|
|
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
|
|
|
from WxMini.Milvus.Utils.MilvusConnectionPool import *
|
|
|
from WxMini.Utils.EmbeddingUtil import text_to_embedding
|
|
|
from WxMini.Utils.ImageUtil import *
|
|
|
from WxMini.Utils.MySQLUtil import init_mysql_pool, get_chat_log_by_session, get_user_by_login_name, \
|
|
|
get_chat_logs_by_risk_flag, get_chat_logs_summary, save_chat_to_mysql
|
|
|
from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id
|
|
|
from WxMini.Utils.NewsUtil import *
|
|
|
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, hmacsha256
|
|
|
from WxMini.Utils.TianQiUtil import get_weather
|
|
|
from WxMini.Utils.TtsUtil import TTS
|
|
|
|
|
|
# 配置日志
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# 密码加密上下文
|
|
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
# OAuth2 密码模式
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
|
# 初始化 Milvus 连接池
|
|
|
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
|
|
|
|
|
|
# 初始化集合管理器
|
|
|
collection_name = MS_COLLECTION_NAME
|
|
|
collection_manager = MilvusCollectionManager(collection_name)
|
|
|
|
|
|
|
|
|
# 使用 Lifespan Events 处理应用启动和关闭逻辑
|
|
|
@asynccontextmanager
|
|
|
async def lifespan(app: FastAPI):
|
|
|
# 应用启动时加载集合到内存
|
|
|
collection_manager.load_collection()
|
|
|
logger.info(f"集合 '{collection_name}' 已加载到内存。")
|
|
|
# 初始化 MySQL 连接池
|
|
|
app.state.mysql_pool = await init_mysql_pool()
|
|
|
logger.info("MySQL 连接池已初始化。")
|
|
|
yield
|
|
|
# 应用关闭时释放连接池
|
|
|
milvus_pool.close()
|
|
|
app.state.mysql_pool.close()
|
|
|
await app.state.mysql_pool.wait_closed()
|
|
|
logger.info("Milvus 和 MySQL 连接池已关闭。")
|
|
|
|
|
|
|
|
|
# 会话结束后,调用检查方法,判断是不是有需要介入的问题出现
|
|
|
async def on_session_end(person_id):
|
|
|
# 获取最后一条聊天记录
|
|
|
last_id = await get_last_chat_log_id(app.state.mysql_pool, person_id)
|
|
|
if last_id:
|
|
|
# 查询最后一条记录的详细信息
|
|
|
async with app.state.mysql_pool.acquire() as conn:
|
|
|
async with conn.cursor() as cur:
|
|
|
await cur.execute(
|
|
|
"SELECT user_input, model_response FROM t_chat_log WHERE id = %s",
|
|
|
(last_id,)
|
|
|
)
|
|
|
last_record = await cur.fetchone()
|
|
|
|
|
|
if last_record:
|
|
|
history = f"问题:{last_record[0]}\n回答:{last_record[1]}"
|
|
|
else:
|
|
|
history = "无聊天记录"
|
|
|
else:
|
|
|
history = "无聊天记录"
|
|
|
|
|
|
# 将历史聊天记录发给大模型,让它帮我分析一下
|
|
|
with open("Input.txt", "r", encoding="utf-8") as file:
|
|
|
input_word = file.read()
|
|
|
prompt = (
|
|
|
"分析用户是否存在心理健康方面的问题:"
|
|
|
f"参考分类文档内容如下:{input_word},注意:只有情节比较严重的才认为有健康问题,轻微的不算。"
|
|
|
"如果没有健康问题请回复: OK;否则回复:NO,换行后输出问题类型的名称"
|
|
|
f"\n\n聊天记录:{history}"
|
|
|
)
|
|
|
|
|
|
# 使用 asyncio.create_task 异步执行大模型调用
|
|
|
async def analyze_mental_health():
|
|
|
response = await client.chat.completions.create(
|
|
|
model=MODEL_NAME,
|
|
|
messages=[
|
|
|
{"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"},
|
|
|
{"role": "user", "content": prompt}
|
|
|
],
|
|
|
max_tokens=1000
|
|
|
)
|
|
|
|
|
|
# 处理分析结果
|
|
|
if response.choices and response.choices[0].message.content:
|
|
|
analysis_result = response.choices[0].message.content.strip()
|
|
|
if analysis_result.startswith("NO"):
|
|
|
# 异步执行 update_risk
|
|
|
await update_risk(app.state.mysql_pool, person_id, analysis_result)
|
|
|
logger.info(f"已异步更新 person_id={person_id} 的风险状态。")
|
|
|
else:
|
|
|
logger.info(f"AI大模型没有发现任何心理健康问题,用户会话 {person_id} 没有风险。")
|
|
|
|
|
|
# 创建异步任务
|
|
|
asyncio.create_task(analyze_mental_health())
|
|
|
|
|
|
|
|
|
# 初始化 FastAPI 应用
|
|
|
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
# 初始化异步 OpenAI 客户端
|
|
|
client = AsyncOpenAI(
|
|
|
api_key=MODEL_API_KEY,
|
|
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
|
|
)
|
|
|
|
|
|
|
|
|
# 验证密码
|
|
|
def verify_password(plain_password, hashed_password):
|
|
|
return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
|
|
|
|
|
# 创建 JWT
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
|
|
to_encode = data.copy()
|
|
|
if expires_delta:
|
|
|
expire = datetime.utcnow() + expires_delta
|
|
|
else:
|
|
|
expire = datetime.utcnow() + timedelta(minutes=600)
|
|
|
to_encode.update({"exp": expire})
|
|
|
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
|
|
return encoded_jwt
|
|
|
|
|
|
|
|
|
# 获取当前用户
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
|
credentials_exception = HTTPException(
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
detail="无法验证凭证",
|
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
|
)
|
|
|
try:
|
|
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
login_name: str = payload.get("sub")
|
|
|
if login_name is None:
|
|
|
raise credentials_exception
|
|
|
except JWTError:
|
|
|
raise credentials_exception
|
|
|
user = await get_user_by_login_name(app.state.mysql_pool, login_name)
|
|
|
if user is None:
|
|
|
raise credentials_exception
|
|
|
return user
|
|
|
|
|
|
|
|
|
# 登录接口
|
|
|
@app.post("/aichat/login")
|
|
|
async def login(
|
|
|
login_name: str = Form(..., description="用户名"),
|
|
|
password: str = Form(..., description="密码")
|
|
|
):
|
|
|
"""
|
|
|
用户登录接口
|
|
|
:param login_name: 用户名
|
|
|
:param password: 密码
|
|
|
:return: 登录结果
|
|
|
"""
|
|
|
flag = True
|
|
|
if not login_name or not password:
|
|
|
flag = False
|
|
|
|
|
|
# 调用 get_user_by_login_name 方法
|
|
|
user = await get_user_by_login_name(app.state.mysql_pool, login_name)
|
|
|
if not user:
|
|
|
flag = False
|
|
|
if user and user['password'] != password:
|
|
|
flag = False
|
|
|
|
|
|
if not flag:
|
|
|
return {
|
|
|
"code": 200,
|
|
|
"message": "登录失败",
|
|
|
"success": False
|
|
|
}
|
|
|
|
|
|
# 生成 JWT
|
|
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
access_token = create_access_token(
|
|
|
data={"sub": user["login_name"]}, expires_delta=access_token_expires
|
|
|
)
|
|
|
|
|
|
# 返回带字段名称的数据
|
|
|
return {
|
|
|
"message": "登录成功",
|
|
|
"success": True,
|
|
|
"data": {
|
|
|
"person_id": user["person_id"],
|
|
|
"login_name": user["login_name"],
|
|
|
"identity_id": user["identity_id"],
|
|
|
"person_name": user["person_name"],
|
|
|
"xb_name": user["xb_name"],
|
|
|
"city_name": user["city_name"],
|
|
|
"area_name": user["area_name"],
|
|
|
"school_name": user["school_name"],
|
|
|
"grade_name": user["grade_name"],
|
|
|
"class_name": user["class_name"],
|
|
|
"access_token": access_token,
|
|
|
"token_type": "bearer"
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
# 与用户交流聊天
|
|
|
@app.post("/aichat/reply")
|
|
|
async def reply(person_id: str = Form(...),
|
|
|
prompt: str = Form(...),
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
"""
|
|
|
接收用户输入的 prompt,调用大模型并返回结果
|
|
|
:param person_id: 用户会话 ID
|
|
|
:param prompt: 用户输入的 prompt
|
|
|
:return: 大模型的回复
|
|
|
"""
|
|
|
logger.info(f"current_user= {current_user}")
|
|
|
try:
|
|
|
logger.info(f"收到用户输入: {prompt}")
|
|
|
|
|
|
if not prompt:
|
|
|
return {
|
|
|
"code": 200,
|
|
|
"message": "请输入内容",
|
|
|
"success": False
|
|
|
}
|
|
|
|
|
|
# 从连接池中获取一个连接
|
|
|
connection = milvus_pool.get_connection()
|
|
|
|
|
|
# 将用户输入转换为嵌入向量
|
|
|
current_embedding = text_to_embedding(prompt)
|
|
|
|
|
|
# 查询与当前对话最相关的历史交互
|
|
|
search_params = {
|
|
|
"metric_type": "L2", # 使用 L2 距离度量方式
|
|
|
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
|
|
|
}
|
|
|
start_time = time.time()
|
|
|
results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行
|
|
|
collection_manager.search,
|
|
|
data=current_embedding, # 输入向量
|
|
|
search_params=search_params, # 搜索参数
|
|
|
expr=f"person_id == '{person_id}'", # 按 person_id 过滤
|
|
|
limit=6 # 返回 6 条结果
|
|
|
)
|
|
|
end_time = time.time()
|
|
|
|
|
|
# 构建历史交互提示词
|
|
|
history_prompt = ""
|
|
|
if results:
|
|
|
for hits in results:
|
|
|
for hit in hits:
|
|
|
try:
|
|
|
# 查询非向量字段
|
|
|
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
|
|
|
if record:
|
|
|
# logger.info(f"查询到的记录: {record}")
|
|
|
# 添加历史交互
|
|
|
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
|
|
|
except Exception as e:
|
|
|
logger.error(f"查询失败: {e}")
|
|
|
|
|
|
# 在最后增加此人最近几条的交互记录数据
|
|
|
try:
|
|
|
recent_logs = await get_chat_log_by_session(app.state.mysql_pool, person_id)
|
|
|
data = recent_logs["data"]
|
|
|
for record in data:
|
|
|
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
|
|
|
except Exception as e:
|
|
|
logger.error(f"获取交互记录时出错:{e}")
|
|
|
|
|
|
# 限制历史交互提示词长度
|
|
|
history_prompt = history_prompt[:3000]
|
|
|
|
|
|
# 拼接交互提示词
|
|
|
if '天气' in prompt or '降温' in prompt or '气温' in prompt or '下雨' in prompt or '下雪' in prompt or '风' in prompt:
|
|
|
weather_info = await get_weather('长春')
|
|
|
history_prompt = f"天气信息: {weather_info}\n"
|
|
|
|
|
|
logger.info(f"历史交互提示词: {history_prompt}")
|
|
|
|
|
|
# NBA与CBA
|
|
|
result = await get_news(client, prompt)
|
|
|
if result is not None:
|
|
|
history_prompt = result
|
|
|
print("新闻返回了下面的内容:" + result)
|
|
|
|
|
|
# 调用大模型,将历史交互作为提示词
|
|
|
try:
|
|
|
response = await asyncio.wait_for(
|
|
|
client.chat.completions.create(
|
|
|
model=MODEL_NAME,
|
|
|
messages=[
|
|
|
{"role": "system",
|
|
|
"content": "你是一个和你聊天人的好朋友,疏导情绪,让他开心,亲切一些,不要使用哎呀这样的语气词。聊天的回复内容不要超过100字。"},
|
|
|
{"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"}
|
|
|
],
|
|
|
max_tokens=4000
|
|
|
),
|
|
|
timeout=60 # 设置超时时间为 60 秒
|
|
|
)
|
|
|
except asyncio.TimeoutError:
|
|
|
logger.error("大模型调用超时")
|
|
|
raise HTTPException(status_code=500, detail="大模型调用超时")
|
|
|
|
|
|
# 提取生成的回复
|
|
|
if response.choices and response.choices[0].message.content:
|
|
|
result = response.choices[0].message.content.strip()
|
|
|
logger.info(f"大模型回复: {result}")
|
|
|
|
|
|
# 记录用户输入和大模型反馈到向量数据库
|
|
|
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
|
|
entities = [
|
|
|
[person_id], # person_id
|
|
|
[prompt[:500]], # user_input,截断到 500 字符
|
|
|
[result[:500]], # model_response,截断到 500 字符
|
|
|
[timestamp], # timestamp
|
|
|
[current_embedding] # embedding
|
|
|
]
|
|
|
if len(prompt) > 500:
|
|
|
logger.warning(f"用户输入被截断,原始长度: {len(prompt)}")
|
|
|
if len(result) > 500:
|
|
|
logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
|
|
|
await asyncio.to_thread(collection_manager.insert_data, entities)
|
|
|
logger.info("用户输入和大模型反馈已记录到向量数据库。")
|
|
|
|
|
|
# 调用 TTS 生成 MP3
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
timestamp = int(time.time())
|
|
|
# 生成年月日的目录名称
|
|
|
audio_dir = f"audio/{time.strftime('%Y%m%d', time.localtime())}"
|
|
|
tts_file = f"{audio_dir}/{uuid_str}_{timestamp}.mp3"
|
|
|
|
|
|
# 生成 TTS 音频数据(不落盘)
|
|
|
t = TTS(None) # 传入 None 表示不保存到本地文件
|
|
|
audio_data, duration = await asyncio.to_thread(t.generate_audio,
|
|
|
result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
|
|
|
# print(f"音频时长: {duration} 秒")
|
|
|
|
|
|
# 将音频数据直接上传到 OSS
|
|
|
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
|
|
|
logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}")
|
|
|
|
|
|
# 完整的 URL
|
|
|
url = OSS_PREFIX + tts_file
|
|
|
|
|
|
# 记录聊天数据到 MySQL
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
|
|
|
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
|
|
|
|
|
|
# 调用会话检查机制,异步执行
|
|
|
asyncio.create_task(on_session_end(person_id))
|
|
|
|
|
|
# 返回数据
|
|
|
return {
|
|
|
"success": True,
|
|
|
"url": url,
|
|
|
"search_time": end_time - start_time, # 返回查询耗时
|
|
|
"duration": duration, # 返回大模型的回复时长
|
|
|
"response": result, # 返回大模型的回复
|
|
|
}
|
|
|
else:
|
|
|
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
|
|
|
except Exception as e:
|
|
|
logger.error(f"调用大模型失败: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
|
|
|
finally:
|
|
|
# 释放连接
|
|
|
if 'connection' in locals() and connection: # 检查 connection 是否已定义且不为 None
|
|
|
milvus_pool.release_connection(connection)
|
|
|
|
|
|
|
|
|
# 获取聊天记录
|
|
|
@app.get("/aichat/get_chat_log")
|
|
|
async def get_chat_log(
|
|
|
person_id: str,
|
|
|
page: int = Query(default=1, ge=1, description="当前页码"),
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
"""
|
|
|
获取指定会话的聊天记录,默认返回最新的记录(最后一页)
|
|
|
:param person_id: 用户会话 ID
|
|
|
:param page: 当前页码(默认值为 1,但会动态计算为最后一页)
|
|
|
:param page_size: 每页记录数
|
|
|
:return: 分页数据
|
|
|
"""
|
|
|
logger.info(f"current_user={current_user['login_name']}")
|
|
|
# 调用 get_chat_log_by_session 方法
|
|
|
result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size)
|
|
|
return result
|
|
|
|
|
|
|
|
|
# 获取风险聊天记录接口
|
|
|
@app.get("/aichat/get_risk_chat_logs")
|
|
|
async def get_risk_chat_logs(
|
|
|
risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险 ,2:处理完毕)"),
|
|
|
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1)"),
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)"),
|
|
|
person_id: str = Query(..., description="用户会话 ID"),
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
"""
|
|
|
获取聊天记录,支持分页和风险标志过滤
|
|
|
:param risk_flag: 风险标志
|
|
|
:param page: 当前页码
|
|
|
:param page_size: 每页记录数
|
|
|
:return: 分页数据
|
|
|
"""
|
|
|
logger.info(f"current_user={current_user['login_name']}")
|
|
|
# 计算分页偏移量
|
|
|
offset = (page - 1) * page_size
|
|
|
|
|
|
# 调用 get_chat_logs_by_risk_flag 方法
|
|
|
logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, person_id, offset, page_size)
|
|
|
if not logs:
|
|
|
return {
|
|
|
"success": False,
|
|
|
"message": "没有找到相关记录",
|
|
|
"data": {
|
|
|
|
|
|
}
|
|
|
}
|
|
|
|
|
|
# 返回分页数据
|
|
|
return {
|
|
|
"success": True,
|
|
|
"message": "查询成功",
|
|
|
"data": {
|
|
|
"total": total,
|
|
|
"page": page,
|
|
|
"page_size": page_size,
|
|
|
"logs": logs
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
# 获取风险统计接口
|
|
|
@app.get("/aichat/chat_logs_summary")
|
|
|
async def chat_logs_summary(
|
|
|
risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险 ,2:处理完毕)"),
|
|
|
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1)"),
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)"),
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
"""
|
|
|
获取风险统计接口,支持分页和风险标志过滤
|
|
|
:param risk_flag: 风险标志
|
|
|
:param page: 当前页码
|
|
|
:param page_size: 每页记录数
|
|
|
:param current_user: 当前用户信息
|
|
|
:return: 分页数据
|
|
|
"""
|
|
|
# 验证 risk_flag 的值
|
|
|
if risk_flag not in {0, 1, 2}:
|
|
|
raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2")
|
|
|
|
|
|
# 计算分页偏移量
|
|
|
offset = (page - 1) * page_size
|
|
|
|
|
|
# 调用 get_chat_logs_summary 方法
|
|
|
logs, total = await get_chat_logs_summary(app.state.mysql_pool, risk_flag, offset, page_size)
|
|
|
|
|
|
# 如果未找到记录,返回友好提示
|
|
|
if not logs:
|
|
|
return {
|
|
|
"success": True,
|
|
|
"message": "未找到符合条件的记录",
|
|
|
"data": {
|
|
|
"total": 0,
|
|
|
"page": page,
|
|
|
"page_size": page_size,
|
|
|
"total_pages": 0,
|
|
|
"logs": []
|
|
|
}
|
|
|
}
|
|
|
|
|
|
# 计算总页数
|
|
|
total_pages = (total + page_size - 1) // page_size
|
|
|
|
|
|
# 返回分页数据
|
|
|
return {
|
|
|
"success": True,
|
|
|
"message": "查询成功",
|
|
|
"data": {
|
|
|
"total": total,
|
|
|
"page": page,
|
|
|
"page_size": page_size,
|
|
|
"total_pages": total_pages,
|
|
|
"logs": logs,
|
|
|
"login_name": current_user["login_name"],
|
|
|
"person_name": current_user["person_name"]
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
# 获取上传OSS的授权Token
|
|
|
@app.get("/aichat/get_post_signature_for_oss_upload")
|
|
|
async def generate_upload_params(current_user: dict = Depends(get_current_user)):
|
|
|
logger.info(f"current_user={current_user['login_name']}")
|
|
|
# 子账号的AK,SK,ARN
|
|
|
access_key_id = "LTAI5tJrhwuBzF2X9USrzubX"
|
|
|
access_key_secret = "I6ezLuYhk9z9MRjXD2q99STSpTONwW"
|
|
|
role_arn_for_oss_upload = "acs:ram::1546399445482588:role/huanghai-create-role"
|
|
|
# 桶名
|
|
|
oss_bucket = 'hzkc'
|
|
|
# 区域
|
|
|
region_id = 'cn-beijing'
|
|
|
host = f'http://{oss_bucket}.oss-cn-beijing.aliyuncs.com'
|
|
|
upload_dir = 'Upload' # 指定上传到OSS的文件前缀。
|
|
|
role_session_name = 'role_session_name' # 自定义会话名称
|
|
|
|
|
|
# 初始化配置,直接传递凭据
|
|
|
config = Config(
|
|
|
region_id=region_id,
|
|
|
access_key_id=access_key_id,
|
|
|
access_key_secret=access_key_secret
|
|
|
)
|
|
|
|
|
|
# 创建 STS 客户端并获取临时凭证
|
|
|
sts_client = Sts20150401Client(config=config)
|
|
|
assume_role_request = sts_20150401_models.AssumeRoleRequest(
|
|
|
role_arn=role_arn_for_oss_upload,
|
|
|
role_session_name=role_session_name
|
|
|
)
|
|
|
response = sts_client.assume_role(assume_role_request)
|
|
|
token_data = response.body.credentials.to_map()
|
|
|
|
|
|
# 使用 STS 返回的临时凭据
|
|
|
sts_access_key_id = token_data['AccessKeyId']
|
|
|
sts_access_key_secret = token_data['AccessKeySecret']
|
|
|
security_token = token_data['SecurityToken']
|
|
|
|
|
|
now = int(time.time())
|
|
|
# 将时间戳转换为datetime对象
|
|
|
dt_obj = datetime.utcfromtimestamp(now)
|
|
|
# 在当前时间增加3小时,设置为请求的过期时间
|
|
|
dt_obj_plus_3h = dt_obj + timedelta(hours=1)
|
|
|
|
|
|
# 请求时间
|
|
|
dt_obj_1 = dt_obj.strftime('%Y%m%dT%H%M%S') + 'Z'
|
|
|
# 请求日期
|
|
|
dt_obj_2 = dt_obj.strftime('%Y%m%d')
|
|
|
# 请求过期时间
|
|
|
expiration_time = dt_obj_plus_3h.strftime('%Y-%m-%dT%H:%M:%S.000Z')
|
|
|
|
|
|
# 构建 Policy 并生成签名
|
|
|
policy = {
|
|
|
"expiration": expiration_time,
|
|
|
"conditions": [
|
|
|
["eq", "$success_action_status", "200"],
|
|
|
{"x-oss-signature-version": "OSS4-HMAC-SHA256"},
|
|
|
{"x-oss-credential": f"{sts_access_key_id}/{dt_obj_2}/{region_id}/oss/aliyun_v4_request"},
|
|
|
{"x-oss-security-token": security_token},
|
|
|
{"x-oss-date": dt_obj_1},
|
|
|
]
|
|
|
}
|
|
|
policy_str = json.dumps(policy).strip()
|
|
|
|
|
|
# 步骤2:构造待签名字符串(StringToSign)
|
|
|
stringToSign = base64.b64encode(policy_str.encode()).decode()
|
|
|
|
|
|
# 步骤3:计算SigningKey
|
|
|
dateKey = hmacsha256(("aliyun_v4" + sts_access_key_secret).encode(), dt_obj_2)
|
|
|
dateRegionKey = hmacsha256(dateKey, region_id)
|
|
|
dateRegionServiceKey = hmacsha256(dateRegionKey, "oss")
|
|
|
signingKey = hmacsha256(dateRegionServiceKey, "aliyun_v4_request")
|
|
|
|
|
|
# 步骤4:计算Signature
|
|
|
result = hmacsha256(signingKey, stringToSign)
|
|
|
signature = result.hex()
|
|
|
|
|
|
# 组织返回数据
|
|
|
response_data = {
|
|
|
'policy': stringToSign, # 表单域
|
|
|
'x_oss_signature_version': "OSS4-HMAC-SHA256", # 指定签名的版本和算法,固定值为OSS4-HMAC-SHA256
|
|
|
'x_oss_credential': f"{sts_access_key_id}/{dt_obj_2}/{region_id}/oss/aliyun_v4_request", # 指明派生密钥的参数集
|
|
|
'x_oss_date': dt_obj_1, # 请求的时间
|
|
|
'signature': signature, # 签名认证描述信息
|
|
|
'host': host,
|
|
|
'dir': upload_dir,
|
|
|
'security_token': security_token # 安全令牌
|
|
|
|
|
|
}
|
|
|
return response_data
|
|
|
|
|
|
|
|
|
@app.get("/aichat/recognize_content")
|
|
|
async def web_recognize_content(image_url: str,
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
logger.info(f"current_user:{current_user['login_name']}")
|
|
|
person_id = current_user['person_id']
|
|
|
# 获取图片宽高
|
|
|
image_width, image_height = getImgWidthHeight(image_url)
|
|
|
|
|
|
# 调用 AI 模型生成内容(流式输出)
|
|
|
response = await client.chat.completions.create(
|
|
|
model="qwen-vl-plus",
|
|
|
messages=[{"role": "user", "content": [
|
|
|
{"type": "text", "text": "这是什么"},
|
|
|
{"type": "image_url", "image_url": {"url": image_url}}
|
|
|
]}],
|
|
|
stream=True
|
|
|
)
|
|
|
|
|
|
# 定义一个生成器函数,用于逐字返回流式结果
|
|
|
async def generate_stream():
|
|
|
summary = "" # 用于存储最终拼接的字符串
|
|
|
try:
|
|
|
# 逐块处理 AI 返回的内容
|
|
|
async for chunk in response:
|
|
|
if chunk.choices[0].delta.content:
|
|
|
chunk_content = chunk.choices[0].delta.content
|
|
|
# 逐字返回
|
|
|
for char in chunk_content:
|
|
|
print(char, end="", flush=True) # 打印到控制台
|
|
|
yield char.encode("utf-8") # 逐字返回
|
|
|
summary += char # 拼接字符
|
|
|
|
|
|
# 流式传输完成后,记录到数据库
|
|
|
await save_chat_to_mysql(
|
|
|
app.state.mysql_pool, person_id, f'{image_url}', summary, "", 0, 2, 2, 2, image_width, image_height
|
|
|
)
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
# 客户端提前断开连接,无需处理
|
|
|
print("客户端断开连接")
|
|
|
except Exception as e:
|
|
|
error_response = json.dumps({
|
|
|
"success": False,
|
|
|
"message": f"生成内容失败: {str(e)}"
|
|
|
})
|
|
|
print(error_response)
|
|
|
yield error_response.encode("utf-8")
|
|
|
|
|
|
# 使用 StreamingResponse 返回流式结果
|
|
|
return StreamingResponse(
|
|
|
generate_stream(),
|
|
|
media_type="text/plain; charset=utf-8", # 明确指定字符编码为 UTF-8
|
|
|
headers={
|
|
|
"Cache-Control": "no-cache", # 禁用缓存
|
|
|
"Content-Type": "text/event-stream; charset=utf-8", # 设置内容类型和字符编码
|
|
|
"Transfer-Encoding": "chunked",
|
|
|
"Connection": "keep-alive",
|
|
|
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲(如果使用 Nginx)
|
|
|
}
|
|
|
)
|
|
|
|
|
|
|
|
|
@app.get("/aichat/recognize_text")
|
|
|
async def web_recognize_text(image_url: str,
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
logger.info(f"current_user:{current_user['login_name']}")
|
|
|
person_id = current_user['person_id']
|
|
|
# 获取图片宽高
|
|
|
image_width, image_height = getImgWidthHeight(image_url)
|
|
|
|
|
|
# 调用 AI 模型生成内容(流式输出)
|
|
|
response = await client.chat.completions.create(
|
|
|
model="qwen-vl-ocr",
|
|
|
messages=[
|
|
|
{
|
|
|
"role": "user",
|
|
|
"content": [
|
|
|
{
|
|
|
"type": "image_url",
|
|
|
"image_url": image_url,
|
|
|
"min_pixels": 28 * 28 * 4,
|
|
|
"max_pixels": 28 * 28 * 1280
|
|
|
},
|
|
|
{"type": "text", "text": "Read all the text in the image."},
|
|
|
]
|
|
|
}
|
|
|
],
|
|
|
stream=True
|
|
|
)
|
|
|
|
|
|
# 定义一个生成器函数,用于逐字返回流式结果
|
|
|
async def generate_stream():
|
|
|
summary = "" # 用于存储最终拼接的字符串
|
|
|
try:
|
|
|
# 逐块处理 AI 返回的内容
|
|
|
async for chunk in response:
|
|
|
if chunk.choices[0].delta.content:
|
|
|
chunk_content = chunk.choices[0].delta.content
|
|
|
# 逐字返回
|
|
|
for char in chunk_content:
|
|
|
print(char, end="", flush=True) # 打印到控制台
|
|
|
yield char.encode("utf-8") # 逐字返回
|
|
|
summary += char # 拼接字符
|
|
|
|
|
|
# 流式传输完成后,记录到数据库
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, person_id, f'{image_url}', summary, "", 0, 2, 2, 1,
|
|
|
image_width,
|
|
|
image_height)
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
# 客户端提前断开连接,无需处理
|
|
|
print("客户端断开连接")
|
|
|
except Exception as e:
|
|
|
error_response = json.dumps({
|
|
|
"success": False,
|
|
|
"message": f"生成内容失败: {str(e)}"
|
|
|
})
|
|
|
print(error_response)
|
|
|
yield error_response.encode("utf-8")
|
|
|
|
|
|
# 使用 StreamingResponse 返回流式结果
|
|
|
return StreamingResponse(
|
|
|
generate_stream(),
|
|
|
media_type="text/plain; charset=utf-8", # 明确指定字符编码为 UTF-8
|
|
|
headers={
|
|
|
"Cache-Control": "no-cache", # 禁用缓存
|
|
|
"Content-Type": "text/event-stream; charset=utf-8", # 设置内容类型和字符编码
|
|
|
"Transfer-Encoding": "chunked",
|
|
|
"Connection": "keep-alive",
|
|
|
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲(如果使用 Nginx)
|
|
|
}
|
|
|
)
|
|
|
|
|
|
|
|
|
@app.get("/aichat/recognize_math")
|
|
|
async def web_recognize_math(image_url: str,
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
):
|
|
|
logger.info(f"current_user:{current_user['login_name']}")
|
|
|
person_id = current_user['person_id']
|
|
|
# 获取图片宽高
|
|
|
image_width, image_height = getImgWidthHeight(image_url)
|
|
|
|
|
|
client = AsyncOpenAI(
|
|
|
api_key=MODELSCOPE_ACCESS_TOKEN,
|
|
|
base_url="https://api-inference.modelscope.cn/v1"
|
|
|
)
|
|
|
"""
|
|
|
识别图片中的数学题,流式输出,并将结果记录到数据库
|
|
|
:param client: AsyncOpenAI 客户端
|
|
|
:param pool: 数据库连接池
|
|
|
:param person_id: 用户 ID
|
|
|
:param image_url: 图片 URL
|
|
|
:return: 最终拼接的字符串
|
|
|
"""
|
|
|
# 提示词
|
|
|
prompt = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
|
|
|
|
|
|
response = await client.chat.completions.create(
|
|
|
model="Qwen/Qwen2.5-VL-32B-Instruct",
|
|
|
messages=[
|
|
|
{
|
|
|
"role": "system",
|
|
|
"content": [
|
|
|
{"type": "text", "text": prompt}
|
|
|
],
|
|
|
},
|
|
|
{
|
|
|
"role": "user",
|
|
|
"content": [
|
|
|
{
|
|
|
"type": "image_url",
|
|
|
"image_url": {"url": image_url}
|
|
|
},
|
|
|
{"type": "text", "text": "请使用中文回答:如何作答?"},
|
|
|
],
|
|
|
}
|
|
|
],
|
|
|
stream=True
|
|
|
)
|
|
|
|
|
|
# 定义一个生成器函数,用于逐字返回流式结果
|
|
|
async def generate_stream():
|
|
|
summary = "" # 用于存储最终拼接的字符串
|
|
|
try:
|
|
|
# 逐块处理 AI 返回的内容
|
|
|
async for chunk in response:
|
|
|
if chunk.choices[0].delta.content:
|
|
|
chunk_content = chunk.choices[0].delta.content
|
|
|
# 逐字返回
|
|
|
for char in chunk_content:
|
|
|
print(char, end="", flush=True) # 打印到控制台
|
|
|
yield char.encode("utf-8") # 逐字返回
|
|
|
summary += char # 拼接字符
|
|
|
|
|
|
# 流式传输完成后,记录到数据库
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, person_id, f'{image_url}', summary, "", 0, 2, 2, 1,
|
|
|
image_width, image_height)
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
# 客户端提前断开连接,无需处理
|
|
|
print("客户端断开连接")
|
|
|
except Exception as e:
|
|
|
error_response = json.dumps({
|
|
|
"success": False,
|
|
|
"message": f"生成内容失败: {str(e)}"
|
|
|
})
|
|
|
print(error_response)
|
|
|
yield error_response.encode("utf-8")
|
|
|
|
|
|
# 使用 StreamingResponse 返回流式结果
|
|
|
return StreamingResponse(
|
|
|
generate_stream(),
|
|
|
media_type="text/plain; charset=utf-8", # 明确指定字符编码为 UTF-8
|
|
|
headers={
|
|
|
"Cache-Control": "no-cache", # 禁用缓存
|
|
|
"Content-Type": "text/event-stream; charset=utf-8", # 设置内容类型和字符编码
|
|
|
"Transfer-Encoding": "chunked",
|
|
|
"Connection": "keep-alive",
|
|
|
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲(如果使用 Nginx)
|
|
|
}
|
|
|
)
|
|
|
|
|
|
|
|
|
# 运行 FastAPI 应用
|
|
|
if __name__ == "__main__":
|
|
|
import uvicorn
|
|
|
|
|
|
uvicorn_log_config = {
|
|
|
"version": 1,
|
|
|
"disable_existing_loggers": False,
|
|
|
"formatters": {
|
|
|
"default": {
|
|
|
"()": "uvicorn.logging.DefaultFormatter",
|
|
|
"fmt": "%(asctime)s %(levelprefix)s %(message)s",
|
|
|
"datefmt": "%Y-%m-%d %H:%M:%S",
|
|
|
}
|
|
|
},
|
|
|
"handlers": {
|
|
|
"default": {
|
|
|
"formatter": "default",
|
|
|
"class": "logging.StreamHandler",
|
|
|
"stream": "ext://sys.stderr",
|
|
|
}
|
|
|
},
|
|
|
"loggers": {
|
|
|
"uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
|
|
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
|
|
"uvicorn.error": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
|
|
"uvicorn.asgi": {"handlers": [], "level": "INFO", "propagate": False}, # 禁用 ASGI 日志
|
|
|
},
|
|
|
}
|
|
|
|
|
|
# 强制覆盖Uvicorn的默认配置
|
|
|
uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1, log_config=uvicorn_log_config)
|