You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

218 lines
8.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
import jieba # 导入 jieba 分词库
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel, Field, ValidationError
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from sse_starlette.sse import EventSourceResponse
from gensim.models import KeyedVectors
from Config import Config
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import *
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# 1. 加载预训练的 Word2Vec 模型
model_path = MS_MODEL_PATH # 替换为你的 Word2Vec 模型路径
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
@asynccontextmanager
async def lifespan(app: FastAPI):
# 初始化Milvus连接池
app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
# 初始化集合管理器
app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME)
app.state.collection_manager.load_collection()
# 初始化DeepSeek客户端
app.state.deepseek_client = OpenAI(
api_key=Config.DEEPSEEK_API_KEY,
base_url=Config.DEEPSEEK_URL
)
yield
# 关闭Milvus连接池
app.state.milvus_pool.close()
app = FastAPI(lifespan=lifespan)
# 挂载静态文件目录
app.mount("/static", StaticFiles(directory="Static"), name="static")
# 将文本转换为嵌入向量
def text_to_embedding(text):
words = jieba.lcut(text) # 使用 jieba 分词
print(f"文本: {text}, 分词结果: {words}")
embeddings = [model[word] for word in words if word in model]
logger.info(f"有效词向量数量: {len(embeddings)}")
if embeddings:
avg_embedding = sum(embeddings) / len(embeddings)
logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
return avg_embedding
else:
logger.warning("未找到有效词,返回零向量")
return [0.0] * model.vector_size
async def generate_stream(client, milvus_pool, collection_manager, query):
"""生成SSE流"""
# 从连接池获取连接
connection = milvus_pool.get_connection()
try:
# 1. 将查询文本转换为向量
current_embedding = text_to_embedding(query)
# 2. 搜索相关数据
search_params = {
"metric_type": "L2", # 使用 L2 距离度量方式
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
}
# 7. 将文本转换为嵌入向量
results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果
# 3. 处理搜索结果
logger.info("最相关的历史对话:")
context = ""
if results:
for hits in results:
for hit in hits:
try:
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
logger.info(f"ID: {hit.id}")
logger.info(f"标签: {record['tags']}")
logger.info(f"用户问题: {record['user_input']}")
context = context + record['user_input']
logger.info(f"时间: {record['timestamp']}")
logger.info(f"距离: {hit.distance}")
logger.info("-" * 40) # 分隔线
except Exception as e:
logger.error(f"查询失败: {e}")
else:
logger.warning("未找到相关历史对话,请检查查询参数或数据。")
prompt = f"""根据以下关于'{query}'的相关信息,# Role: 信息检索与回答助手
## Profile
- language: 中文
- description: 这是一个专门设计来根据提供的材料检索信息并回答相关问题的助手。它能够快速准确地从大量文本中提取关键信息,并以清晰、简洁的方式回答用户的问题。
- background: 该助手基于先进的自然语言处理技术,能够理解和处理复杂的查询,提供准确的信息。
- personality: 冷静、客观、高效
- expertise: 信息检索、文本分析、问答系统
- target_audience: 需要快速获取信息的用户,如研究人员、学生、专业人士等。
## Skills
1. 信息检索
- 文本搜索: 能够在大量文本中快速搜索关键词或短语。
- 语义理解: 理解用户的查询意图,即使查询语句不完全符合标准格式。
- 结果筛选: 从搜索结果中筛选出最相关的信息。
2. 信息处理
- 文本摘要: 提供文本的简要摘要,帮助用户快速了解主要内容。
- 关键信息提取: 提取文本中的关键信息,如日期、地点、人物等。
- 数据整合: 将来自不同来源的信息进行整合,提供全面的回答。
## Rules
1. 基本原则:
- 准确性: 提供的信息必须准确无误,确保来源可靠。
- 客观性: 回答问题时保持客观,避免主观判断。
- 完整性: 尽可能提供完整的信息,满足用户的需求。
2. 行为准则:
- 及时响应: 快速响应用户的查询,提供及时的信息。
- 清晰表达: 使用简洁明了的语言,确保用户能够理解回答。
- 保密性: 严格遵守保密协议,不泄露用户的个人信息。
3. 限制条件:
- 不提供猜测性信息: 只提供有据可查的信息,不进行猜测。
- 不传播不实信息: 确保提供的信息真实可靠,不传播不实信息。
- 不涉及敏感内容: 避免回答涉及敏感内容的问题。
## Workflows
- 目标: 根据提供的材料回答用户的问题。
- 步骤 1: 接收并理解用户的查询,确定查询意图。
- 步骤 2: 在提供的材料中搜索相关信息,筛选出最相关的信息。
- 步骤 3: 对搜索到的信息进行处理,提取关键信息并进行整合。
- 预期结果: 提供准确、清晰、完整的回答,满足用户的需求。
## Initialization
作为信息检索与回答助手你必须遵守上述Rules按照Workflows执行任务。
相关信息:
{context}"""
response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "你是一个专业的文档整理助手"},
{"role": "user", "content": prompt}
],
temperature=0.3,
stream=True
)
for chunk in response:
if chunk.choices[0].delta.content:
yield {"data": chunk.choices[0].delta.content}
await asyncio.sleep(0.01)
except Exception as e:
yield {"data": f"生成报告时出错: {str(e)}"}
finally:
# 释放连接
milvus_pool.release_connection(connection)
"""
http://10.10.21.22:8000/static/chat.html
小学数学中有哪些模型?
"""
class QueryRequest(BaseModel):
query: str = Field(..., description="用户查询的问题")
@app.post("/api/rag")
async def rag_stream(request: Request):
try:
data = await request.json()
query_request = QueryRequest(**data)
except ValidationError as e:
logger.error(f"请求体验证失败: {e.errors()}")
raise HTTPException(status_code=422, detail=e.errors())
except Exception as e:
logger.error(f"请求解析失败: {str(e)}")
raise HTTPException(status_code=400, detail="无效的请求格式")
"""RAG+DeepSeek流式接口"""
return EventSourceResponse(
generate_stream(
request.app.state.deepseek_client,
request.app.state.milvus_pool,
request.app.state.collection_manager,
query_request.query
)
)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)