main
HuangHai 4 months ago
parent 8d375f97ed
commit 43d7d215f6

@ -3,7 +3,7 @@ import uuid
import time
import jieba
from fastapi import FastAPI, Form, HTTPException
from openai import AsyncOpenAI # 使用异步客户端
from openai import OpenAI
from gensim.models import KeyedVectors
from contextlib import asynccontextmanager
from TtsConfig import *
@ -12,15 +12,9 @@ from WxMini.TtsUtil import TTS
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
import asyncio # 引入异步支持
import logging # 增加日志记录
import jieba.analyse
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# 提取用户输入的关键词
def extract_keywords(text, topK=3):
"""
@ -35,7 +29,7 @@ def extract_keywords(text, topK=3):
# 初始化 Word2Vec 模型
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
print(f"模型加载成功,词向量维度: {model.vector_size}")
# 初始化 Milvus 连接池
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
@ -47,15 +41,15 @@ collection_manager = MilvusCollectionManager(collection_name)
# 将文本转换为嵌入向量
def text_to_embedding(text):
words = jieba.lcut(text) # 使用 jieba 分词
logger.info(f"文本: {text}, 分词结果: {words}")
print(f"文本: {text}, 分词结果: {words}")
embeddings = [model[word] for word in words if word in model]
logger.info(f"有效词向量数量: {len(embeddings)}")
print(f"有效词向量数量: {len(embeddings)}")
if embeddings:
avg_embedding = sum(embeddings) / len(embeddings)
logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
print(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
return avg_embedding
else:
logger.warning("未找到有效词,返回零向量")
print("未找到有效词,返回零向量")
return [0.0] * model.vector_size
# 使用 Lifespan Events 处理应用启动和关闭逻辑
@ -63,17 +57,17 @@ def text_to_embedding(text):
async def lifespan(app: FastAPI):
# 应用启动时加载集合到内存
collection_manager.load_collection()
logger.info(f"集合 '{collection_name}' 已加载到内存。")
print(f"集合 '{collection_name}' 已加载到内存。")
yield
# 应用关闭时释放连接池
milvus_pool.close()
logger.info("Milvus 连接池已关闭。")
print("Milvus 连接池已关闭。")
# 初始化 FastAPI 应用
app = FastAPI(lifespan=lifespan)
# 初始化异步 OpenAI 客户端
client = AsyncOpenAI(
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
@ -87,7 +81,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
:return: 大模型的回复
"""
try:
logger.info(f"收到用户输入: {prompt}")
print(f"收到用户输入: {prompt}")
# 从连接池中获取一个连接
connection = milvus_pool.get_connection()
@ -100,8 +94,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
}
start_time = time.time()
results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行
collection_manager.search,
results = collection_manager.search(
data=current_embedding, # 输入向量
search_params=search_params, # 搜索参数
expr=f"session_id == '{session_id}'", # 按 session_id 过滤
@ -116,18 +109,18 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
for hit in hits:
try:
# 查询非向量字段
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
record = collection_manager.query_by_id(hit.id)
if record:
logger.info(f"查询到的记录: {record}")
print(f"查询到的记录: {record}")
# 添加历史交互
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e:
logger.error(f"查询失败: {e}")
print(f"查询失败: {e}")
logger.info(f"历史交互提示词: {history_prompt}")
print(f"历史交互提示词: {history_prompt}")
# 调用大模型,将历史交互作为提示词
response = await client.chat.completions.create( # 使用异步调用
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"},
@ -139,7 +132,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
# 提取生成的回复
if response.choices and response.choices[0].message.content:
result = response.choices[0].message.content.strip()
logger.info(f"大模型回复: {result}")
print(f"大模型回复: {result}")
# 记录用户输入和大模型反馈到向量数据库
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@ -150,22 +143,22 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
[timestamp], # timestamp
[current_embedding] # embedding
]
await asyncio.to_thread(collection_manager.insert_data, entities)
logger.info("用户输入和大模型反馈已记录到向量数据库。")
collection_manager.insert_data(entities)
print("用户输入和大模型反馈已记录到向量数据库。")
# 调用 TTS 生成 MP3
uuid_str = str(uuid.uuid4())
tts_file = "audio/" + uuid_str + ".mp3"
t = TTS(tts_file)
await asyncio.to_thread(t.start, result) # 将 TTS 生成放到线程池中执行
t.start(result)
# 文件上传到 OSS
await asyncio.to_thread(upload_mp3_to_oss, tts_file, tts_file)
upload_mp3_to_oss(tts_file, tts_file)
# 删除临时文件
try:
os.remove(tts_file)
logger.info(f"临时文件 {tts_file} 已删除")
print(f"临时文件 {tts_file} 已删除")
except Exception as e:
logger.error(f"删除临时文件失败: {e}")
print(f"删除临时文件失败: {e}")
# 完整的 URL
url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file
return {
@ -177,7 +170,6 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
else:
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
except Exception as e:
logger.error(f"调用大模型失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
finally:
# 释放连接

@ -13,7 +13,7 @@ def getToken():
# 从Redis获取 TTS Token
retrieved_token = get_tts_token()
if retrieved_token:
print("使用Redis中的Token:", retrieved_token)
#print("使用Redis中的Token:", retrieved_token)
return retrieved_token
# 创建AcsClient实例
client = AcsClient(
@ -32,16 +32,16 @@ def getToken():
jss = json.loads(response)
if 'Token' in jss and 'Id' in jss['Token']:
token = jss['Token']['Id']
expireTime = jss['Token']['ExpireTime']
#expireTime = jss['Token']['ExpireTime']
# 转换为本地时间
expire_date = datetime.fromtimestamp(expireTime)
#expire_date = datetime.fromtimestamp(expireTime)
# 格式化输出
formatted_date = expire_date.strftime("%Y-%m-%d %H:%M:%S")
print("过期时间:", formatted_date)
#formatted_date = expire_date.strftime("%Y-%m-%d %H:%M:%S")
#print("过期时间:", formatted_date)
# 计算时间差(秒数)
now = datetime.now()
time_diff = (expire_date - now).total_seconds()
print("距离过期还有(秒):", time_diff)
#now = datetime.now()
#time_diff = (expire_date - now).total_seconds()
#print("距离过期还有(秒):", time_diff)
# 设置 TTS Token
set_tts_token(token)
return token

@ -52,7 +52,6 @@ class MilvusCollectionManager:
if self.collection is None:
raise Exception("集合未加载,请检查集合是否存在。")
self.collection.insert(entities)
print("数据插入成功。")
def load_collection(self):
"""
加载集合到内存

@ -3,7 +3,7 @@ import uuid
import time
import jieba
from fastapi import FastAPI, Form, HTTPException
from openai import OpenAI
from openai import AsyncOpenAI # 使用异步客户端
from gensim.models import KeyedVectors
from contextlib import asynccontextmanager
from TtsConfig import *
@ -12,9 +12,15 @@ from WxMini.TtsUtil import TTS
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
import asyncio # 引入异步支持
import logging # 增加日志记录
import jieba.analyse
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# 提取用户输入的关键词
def extract_keywords(text, topK=3):
"""
@ -29,7 +35,7 @@ def extract_keywords(text, topK=3):
# 初始化 Word2Vec 模型
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
print(f"模型加载成功,词向量维度: {model.vector_size}")
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
# 初始化 Milvus 连接池
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
@ -41,15 +47,15 @@ collection_manager = MilvusCollectionManager(collection_name)
# 将文本转换为嵌入向量
def text_to_embedding(text):
words = jieba.lcut(text) # 使用 jieba 分词
print(f"文本: {text}, 分词结果: {words}")
logger.info(f"文本: {text}, 分词结果: {words}")
embeddings = [model[word] for word in words if word in model]
print(f"有效词向量数量: {len(embeddings)}")
logger.info(f"有效词向量数量: {len(embeddings)}")
if embeddings:
avg_embedding = sum(embeddings) / len(embeddings)
print(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
return avg_embedding
else:
print("未找到有效词,返回零向量")
logger.warning("未找到有效词,返回零向量")
return [0.0] * model.vector_size
# 使用 Lifespan Events 处理应用启动和关闭逻辑
@ -57,17 +63,17 @@ def text_to_embedding(text):
async def lifespan(app: FastAPI):
# 应用启动时加载集合到内存
collection_manager.load_collection()
print(f"集合 '{collection_name}' 已加载到内存。")
logger.info(f"集合 '{collection_name}' 已加载到内存。")
yield
# 应用关闭时释放连接池
milvus_pool.close()
print("Milvus 连接池已关闭。")
logger.info("Milvus 连接池已关闭。")
# 初始化 FastAPI 应用
app = FastAPI(lifespan=lifespan)
# 初始化 OpenAI 客户端
client = OpenAI(
# 初始化异步 OpenAI 客户端
client = AsyncOpenAI(
api_key=MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
@ -81,7 +87,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
:return: 大模型的回复
"""
try:
print(f"收到用户输入: {prompt}")
logger.info(f"收到用户输入: {prompt}")
# 从连接池中获取一个连接
connection = milvus_pool.get_connection()
@ -94,7 +100,8 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
}
start_time = time.time()
results = collection_manager.search(
results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行
collection_manager.search,
data=current_embedding, # 输入向量
search_params=search_params, # 搜索参数
expr=f"session_id == '{session_id}'", # 按 session_id 过滤
@ -109,18 +116,18 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
for hit in hits:
try:
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
if record:
print(f"查询到的记录: {record}")
logger.info(f"查询到的记录: {record}")
# 添加历史交互
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e:
print(f"查询失败: {e}")
logger.error(f"查询失败: {e}")
print(f"历史交互提示词: {history_prompt}")
logger.info(f"历史交互提示词: {history_prompt}")
# 调用大模型,将历史交互作为提示词
response = client.chat.completions.create(
response = await client.chat.completions.create( # 使用异步调用
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"},
@ -132,7 +139,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
# 提取生成的回复
if response.choices and response.choices[0].message.content:
result = response.choices[0].message.content.strip()
print(f"大模型回复: {result}")
logger.info(f"大模型回复: {result}")
# 记录用户输入和大模型反馈到向量数据库
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@ -143,22 +150,22 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
[timestamp], # timestamp
[current_embedding] # embedding
]
collection_manager.insert_data(entities)
print("用户输入和大模型反馈已记录到向量数据库。")
await asyncio.to_thread(collection_manager.insert_data, entities)
logger.info("用户输入和大模型反馈已记录到向量数据库。")
# 调用 TTS 生成 MP3
uuid_str = str(uuid.uuid4())
tts_file = "audio/" + uuid_str + ".mp3"
t = TTS(tts_file)
t.start(result)
await asyncio.to_thread(t.start, result) # 将 TTS 生成放到线程池中执行
# 文件上传到 OSS
upload_mp3_to_oss(tts_file, tts_file)
await asyncio.to_thread(upload_mp3_to_oss, tts_file, tts_file)
# 删除临时文件
try:
os.remove(tts_file)
print(f"临时文件 {tts_file} 已删除")
logger.info(f"临时文件 {tts_file} 已删除")
except Exception as e:
print(f"删除临时文件失败: {e}")
logger.error(f"删除临时文件失败: {e}")
# 完整的 URL
url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file
return {
@ -170,6 +177,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
else:
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
except Exception as e:
logger.error(f"调用大模型失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
finally:
# 释放连接

@ -14,8 +14,8 @@ REDIS_PASSWORD = None # Redis 密码(如果没有密码,设置为 None
# 阿里云中用来调用 deepseek v3 的密钥
MODEL_API_KEY = "sk-01d13a39e09844038322108ecdbd1bbc"
MODEL_NAME = "qwen-plus"
#MODEL_NAME = "deepseek-v3"
#MODEL_NAME = "qwen-plus"
MODEL_NAME = "deepseek-v3"
# TTS的APPKEY
APPKEY = "90RJcqjlN4ZqymGd" # 获取Appkey请前往控制台https://nls-portal.console.aliyun.com/applist

Loading…
Cancel
Save