You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

222 lines
7.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Standard library
import asyncio
import logging
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler
from urllib.parse import quote

# Third-party
import html2text
import jieba  # Chinese word segmentation
import uvicorn
from docx import Document
from docx.shared import Inches
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from gensim.models import KeyedVectors
from openai import OpenAI
from pydantic import BaseModel, Field, ValidationError
from sse_starlette.sse import EventSourceResponse

# Local
from Config import Config
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import *
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
# --- Logging: rotating file log under Logs/ (1 MiB per file, 5 backups) ---
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# --- Load the pre-trained Word2Vec model used for query embedding ---
# NOTE(review): runs at import time — process startup blocks until the vectors load.
model = KeyedVectors.load_word2vec_format(MS_MODEL_PATH, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: create shared clients at startup, tear them down at shutdown.

    Everything is stored on ``app.state`` so request handlers can reach it via
    ``request.app.state``.
    """
    # Milvus connection pool shared by all requests.
    app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
    # Collection manager for vector search; load the collection into memory.
    app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME)
    app.state.collection_manager.load_collection()
    # DeepSeek client (OpenAI-compatible API surface).
    app.state.deepseek_client = OpenAI(
        api_key=Config.DEEPSEEK_API_KEY,
        base_url=Config.DEEPSEEK_URL
    )
    yield
    # Shutdown: release all pooled Milvus connections.
    app.state.milvus_pool.close()
app = FastAPI(lifespan=lifespan)
# Serve the frontend (e.g. /static/ai.html) from the local Static/ directory.
app.mount("/static", StaticFiles(directory="Static"), name="static")
def text_to_embedding(text):
    """Convert *text* to a single embedding vector by averaging word vectors.

    The text is segmented with jieba; the vectors of every in-vocabulary word
    are averaged. A zero vector is returned when no word is in the model.

    :param text: input text (Chinese supported via jieba segmentation)
    :return: averaged vector of length ``model.vector_size`` (a zero-filled
        list when no word was found)
    """
    words = jieba.lcut(text)  # segment text into words
    # Fix: was a bare print(); use the module logger like the rest of the file.
    logger.info(f"文本: {text}, 分词结果: {words}")
    embeddings = [model[word] for word in words if word in model]
    logger.info(f"有效词向量数量: {len(embeddings)}")
    # Guard clause: no known word -> zero vector fallback.
    if not embeddings:
        logger.warning("未找到有效词,返回零向量")
        return [0.0] * model.vector_size
    avg_embedding = sum(embeddings) / len(embeddings)
    logger.info(f"生成的平均向量: {avg_embedding[:5]}...")  # log first 5 dims
    return avg_embedding
async def generate_stream(client, milvus_pool, collection_manager, query):
    """Build a RAG answer for *query* and yield it as a single SSE event dict.

    Steps: embed the query, vector-search Milvus for similar records, splice
    their full content into a prompt, and ask DeepSeek for an HTML answer.

    :param client: OpenAI-compatible DeepSeek client
    :param milvus_pool: Milvus connection pool — a connection is borrowed here
        and released in the ``finally`` block
    :param collection_manager: wrapper providing ``search`` / ``query_by_id``
    :param query: user question text
    :yields: ``{"data": <html answer or error message>}`` — note the DeepSeek
        call below is non-streaming, so exactly one event is yielded
    """
    # Borrow a pooled connection for the duration of the search.
    connection = milvus_pool.get_connection()
    try:
        # 1. Embed the query text.
        current_embedding = text_to_embedding(query)
        # 2. Vector-search parameters (IVF_FLAT index).
        search_params = {
            "metric_type": "L2",  # L2 (Euclidean) distance
            "params": {"nprobe": MS_NPROBE}  # number of index clusters probed
        }
        # 3. Retrieve the top-5 most similar records.
        results = collection_manager.search(current_embedding, search_params, limit=5)
        # 4. Concatenate the hits' full content into the prompt context.
        logger.info("最相关的历史对话:")
        context = ""
        if results:
            for hits in results:
                for hit in hits:
                    try:
                        # Fetch the non-vector fields for this hit by id.
                        record = collection_manager.query_by_id(hit.id)
                        logger.info(f"ID: {hit.id}")
                        logger.info(f"标签: {record['tags']}")
                        logger.info(f"用户问题: {record['user_input']}")
                        # Prefer the stored full content; fall back to the raw user input.
                        full_content = record['tags'].get('full_content', record['user_input'])
                        context = context + full_content
                        logger.info(f"时间: {record['timestamp']}")
                        logger.info(f"距离: {hit.distance}")
                        logger.info("-" * 40)  # visual separator in the log
                    except Exception as e:
                        # Best-effort: skip hits whose detail lookup fails.
                        logger.error(f"查询失败: {e}")
        else:
            logger.warning("未找到相关历史对话,请检查查询参数或数据。")
        # 5. Build the prompt (instructs the model to answer in HTML).
        prompt = f"""
信息检索与回答助手
根据以下关于'{query}'的相关信息:
基本信息
- 语言: 中文
- 描述: 根据提供的材料检索信息并回答问题
- 特点: 快速准确提取关键信息,清晰简洁地回答
相关信息
{context}
回答要求
1. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容。
2. 使用HTML格式返回包含适当的段落、列表和标题标签
3. 确保内容结构清晰,便于前端展示
"""
        # 6. Ask DeepSeek for the final answer.
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是一个专业的文档整理助手"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,  # low temperature for factual answers
            stream=False  # NOTE(review): not actually streaming — one event only
        )
        yield {"data": response.choices[0].message.content}
    except Exception as e:
        # Surface errors to the client as an SSE event instead of crashing.
        yield {"data": f"生成报告时出错: {str(e)}"}
    finally:
        # Always return the borrowed connection to the pool.
        milvus_pool.release_connection(connection)
"""
http://10.10.21.22:8000/static/ai.html
小学数学中有哪些模型?
"""
class QueryRequest(BaseModel):
    """Request body for POST /api/rag."""
    # The user's question text (required).
    query: str = Field(..., description="用户查询的问题")
class SaveWordRequest(BaseModel):
    """Request body for POST /api/save-word."""
    # HTML fragment to convert into a Word document (required).
    html: str = Field(..., description="要保存为Word的HTML内容")
@app.post("/api/save-word")
async def save_to_word(request: Request):
    """Convert posted HTML to a Word (.docx) document and return it as a download.

    :param request: JSON body matching :class:`SaveWordRequest`
    :raises HTTPException: 422 on validation failure, 400 on malformed JSON
    :return: StreamingResponse carrying the in-memory .docx file
    """
    try:
        data = await request.json()
        save_request = SaveWordRequest(**data)
    except ValidationError as e:
        logger.error(f"请求体验证失败: {e.errors()}")
        raise HTTPException(status_code=422, detail=e.errors())
    except Exception as e:
        logger.error(f"请求解析失败: {str(e)}")
        raise HTTPException(status_code=400, detail="无效的请求格式")
    # Convert HTML to plain text (links stripped) before writing to Word.
    h = html2text.HTML2Text()
    h.ignore_links = True
    plain_text = h.handle(save_request.html)
    # Build the Word document entirely in memory.
    doc = Document()
    doc.add_heading('小学数学问答', 0)
    doc.add_paragraph(plain_text)
    file_stream = BytesIO()
    doc.save(file_stream)
    file_stream.seek(0)
    # Fix: HTTP headers must be latin-1-encodable — a raw non-ASCII filename
    # raised UnicodeEncodeError. Percent-encode it per RFC 5987 (filename*).
    encoded_filename = quote('小学数学问答.docx')
    return StreamingResponse(
        file_stream,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f"attachment; filename*=utf-8''{encoded_filename}"}
    )
@app.post("/api/rag")
async def rag_stream(request: Request):
    """RAG + DeepSeek endpoint: answer a query using retrieved Milvus context.

    :param request: JSON body matching :class:`QueryRequest`
    :raises HTTPException: 422 on validation failure, 400 on malformed JSON
    :return: the single ``{"data": ...}`` chunk produced by
        :func:`generate_stream` (the generator currently yields exactly once)
    """
    try:
        data = await request.json()
        query_request = QueryRequest(**data)
    except ValidationError as e:
        logger.error(f"请求体验证失败: {e.errors()}")
        raise HTTPException(status_code=422, detail=e.errors())
    except Exception as e:
        logger.error(f"请求解析失败: {str(e)}")
        raise HTTPException(status_code=400, detail="无效的请求格式")
    gen = generate_stream(
        request.app.state.deepseek_client,
        request.app.state.milvus_pool,
        request.app.state.collection_manager,
        query_request.query
    )
    try:
        # Consume the first (and only) chunk. Returning mid-iteration used to
        # abandon the generator, so its finally-block (which releases the
        # Milvus connection) only ran at garbage collection.
        async for chunk in gen:
            return chunk
    finally:
        # Close the generator explicitly so the pooled connection is released
        # immediately. NOTE(review): for true SSE streaming, wrap the generator
        # in EventSourceResponse instead of returning a single chunk.
        await gen.aclose()
# Dev entry point: run the API directly with `python <this file>` on port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)