main
HuangHai 4 months ago
parent d2243aa457
commit 703e9b1ddb

@ -1,5 +1,5 @@
# Milvus 服务器的主机地址
MS_HOST = "10.10.14.101"
MS_HOST = "10.10.14.205"
# Milvus 服务器的端口号
MS_PORT = "19530"

@ -1,5 +1,4 @@
import time
import jieba # 导入 jieba 分词库
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
@ -13,14 +12,11 @@ print(f"模型加载成功,词向量维度: {model.vector_size}")
# 将文本转换为嵌入向量
def text_to_embedding(text):
words = jieba.lcut(text) # 使用 jieba 分词
print(f"文本: {text}, 分词结果: {words}")
embeddings = [model[word] for word in words if word in model]
print(f"有效词向量数量: {len(embeddings)}")
if embeddings:
avg_embedding = sum(embeddings) / len(embeddings)
print(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
return avg_embedding
# 直接使用全句进行向量计算
if text in model:
embedding = model[text]
print(f"生成的全句向量: {embedding[:5]}...") # 打印前 5 维
return embedding
else:
print("未找到有效词,返回零向量")
return [0.0] * model.vector_size
@ -74,4 +70,4 @@ print(f"查询耗时: {end_time - start_time:.4f} 秒")
milvus_pool.release_connection(connection)
# 12. 关闭连接池
milvus_pool.close()
milvus_pool.close()

@ -81,22 +81,28 @@ class MilvusCollectionManager:
except Exception as e:
print(f"查询失败: {e}")
return None
def search(self, query_embedding, search_params, limit=2):
def search(self, data, search_params, expr=None, limit=5):
"""
查询数据
:param query_embedding: 查询向量
:param search_params: 查询参数
:param limit: 返回结果数量
:return: 查询结果
在集合中搜索与输入向量最相似的数据
:param data: 输入向量
:param search_params: 搜索参数
:param expr: 过滤条件可选
:param limit: 返回结果的数量
:return: 搜索结果
"""
if self.collection is None:
raise Exception("集合未加载,请检查集合是否存在。")
return self.collection.search(
data=[query_embedding],
anns_field="embedding",
param=search_params,
limit=limit
)
try:
# 构建搜索参数
search_result = self.collection.search(
data=[data], # 输入向量
anns_field="embedding", # 向量字段名称
param=search_params, # 搜索参数
limit=limit, # 返回结果的数量
expr=expr # 过滤条件
)
return search_result
except Exception as e:
print(f"搜索失败: {e}")
return None
def query_text_by_id(self, id):
"""

@ -1,7 +1,7 @@
import os
import uuid
from fastapi import FastAPI, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware # 导入 CORSMiddleware
from openai import OpenAI
from TtsConfig import *
from WxMini.OssUtil import upload_mp3_to_oss
@ -10,6 +10,15 @@ from WxMini.TtsUtil import TTS
# 初始化 FastAPI 应用
app = FastAPI()
# 添加跨域支持
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源,也可以指定具体的域名
allow_credentials=True,
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有 HTTP 头
)
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
@ -65,4 +74,4 @@ async def reply(prompt: str = Form(...)):
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5500)
uvicorn.run(app, host="0.0.0.0", port=5500)

@ -0,0 +1,229 @@
import os
import uuid
import time
import jieba
from fastapi import FastAPI, Form, HTTPException
from openai import OpenAI
from gensim.models import KeyedVectors
from contextlib import asynccontextmanager
from TtsConfig import *
from WxMini.OssUtil import upload_mp3_to_oss
from WxMini.TtsUtil import TTS
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
import jieba.analyse
# 提取用户输入的关键词
def extract_keywords(text, topK=3):
"""
提取用户输入的关键词
:param text: 用户输入的文本
:param topK: 返回的关键词数量
:return: 关键词列表
"""
keywords = jieba.analyse.extract_tags(text, topK=topK)
return keywords
# 构建查询条件
def build_query_expr(session_id, keywords):
"""
构建查询条件
:param session_id: 用户会话 ID
:param keywords: 关键词列表
:return: 查询条件表达式
"""
# 基础条件:过滤 session_id
expr = f"session_id == '{session_id}'"
# 添加关键词条件
if keywords:
keyword_conditions = []
for keyword in keywords:
if len(keyword) > 1: # 过滤过短的关键词
# 使用前缀匹配
keyword_conditions.append(f"user_input like '{keyword}%'")
if keyword_conditions:
expr += " and (" + " or ".join(keyword_conditions) + ")"
return expr
# 初始化 Word2Vec 模型
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
print(f"模型加载成功,词向量维度: {model.vector_size}")
# 初始化 Milvus 连接池
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
# 初始化集合管理器
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
# 将文本转换为嵌入向量
def text_to_embedding(text):
words = jieba.lcut(text) # 使用 jieba 分词
print(f"文本: {text}, 分词结果: {words}")
embeddings = [model[word] for word in words if word in model]
print(f"有效词向量数量: {len(embeddings)}")
if embeddings:
avg_embedding = sum(embeddings) / len(embeddings)
print(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
return avg_embedding
else:
print("未找到有效词,返回零向量")
return [0.0] * model.vector_size
# 使用 Lifespan Events 处理应用启动和关闭逻辑
@asynccontextmanager
async def lifespan(app: FastAPI):
# 应用启动时加载集合到内存
collection_manager.load_collection()
print(f"集合 '{collection_name}' 已加载到内存。")
yield
# 应用关闭时释放连接池
milvus_pool.close()
print("Milvus 连接池已关闭。")
# 初始化 FastAPI 应用
app = FastAPI(lifespan=lifespan)
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# 设置相似度阈值
SIMILARITY_THRESHOLD = 0.5 # 距离小于 0.5 的结果被认为是高相似度
# 在 /reply 接口中优化查询逻辑和提示词构建方式
# 在 /reply 接口中优化提示词传递和系统提示词
# 在 /reply 接口中优化提示词传递和系统提示词
# 在 /reply 接口中优化查询逻辑和提示词构建方式
@app.post("/reply")
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
"""
接收用户输入的 prompt调用大模型并返回结果
:param session_id: 用户会话 ID
:param prompt: 用户输入的 prompt
:return: 大模型的回复
"""
try:
# 从连接池中获取一个连接
connection = milvus_pool.get_connection()
# 将用户输入转换为嵌入向量
current_embedding = text_to_embedding(prompt)
# 提取用户输入的关键词
keywords = extract_keywords(prompt)
print(f"提取的关键词: {keywords}")
# 构建查询条件
expr = f"session_id == '{session_id}'" # 只过滤 session_id
print(f"查询条件: {expr}")
# 查询与当前对话最相关的五条历史交互
search_params = {
"metric_type": "L2", # 使用 L2 距离度量方式
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
}
start_time = time.time()
results = collection_manager.search(
data=current_embedding, # 输入向量
search_params=search_params, # 搜索参数
expr=expr, # 查询条件
limit=5 # 返回 5 条结果
)
end_time = time.time()
# 调试:输出查询结果
print(f"查询结果: {results}")
# 构建历史交互提示词
history_prompt = ""
if results:
for hits in results:
for hit in hits:
try:
# 过滤低相似度的结果
if hit.distance > SIMILARITY_THRESHOLD:
print(f"跳过低相似度结果,距离: {hit.distance}")
continue
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
if record:
print(f"查询到的记录: {record}")
# 只添加与当前问题高度相关的历史交互
if any(keyword in record['user_input'] for keyword in keywords):
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e:
print(f"查询失败: {e}")
print(f"历史交互提示词: {history_prompt}")
# 调用大模型,将历史交互作为提示词
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息。"},
{"role": "user", "content": f"{history_prompt}用户: {prompt}"} # 将历史交互和当前输入一起发送
],
max_tokens=500
)
# 提取生成的回复
if response.choices and response.choices[0].message.content:
result = response.choices[0].message.content.strip()
# 记录用户输入和大模型反馈到向量数据库
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
entities = [
[session_id], # session_id
[prompt], # user_input
[result], # model_response
[timestamp], # timestamp
[current_embedding] # embedding
]
collection_manager.insert_data(entities)
print("用户输入和大模型反馈已记录到向量数据库。")
# 调用tts进行生成mp3
uuid_str = str(uuid.uuid4())
tts_file = "audio/" + uuid_str + ".mp3"
t = TTS(tts_file)
t.start(result)
# 文件上传到oss
upload_mp3_to_oss(tts_file, tts_file)
# 删除临时文件
try:
os.remove(tts_file)
print(f"临时文件 {tts_file} 已删除")
except Exception as e:
print(f"删除临时文件失败: {e}")
# 完整的url
url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file
return {
"success": True,
"url": url,
"search_time": end_time - start_time, # 返回查询耗时
"response": result # 返回大模型的回复
}
else:
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
except Exception as e:
raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
finally:
# 释放连接
milvus_pool.release_connection(connection)
# 运行 FastAPI 应用
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5600)
Loading…
Cancel
Save