You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

266 lines
10 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import subprocess
import tempfile
import urllib.parse
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler
from typing import List
import jieba # 导入 jieba 分词库
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from gensim.models import KeyedVectors
from pydantic import BaseModel, Field, ValidationError
from starlette.responses import StreamingResponse
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, \
MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import *
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
from Util.ALiYunUtil import ALiYunUtil
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('../Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# 1. 加载预训练的 Word2Vec 模型
model = KeyedVectors.load_word2vec_format(MS_MODEL_PATH, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
# 将HTML文件转换为Word文件
def html_to_word_pandoc(html_file, output_file):
subprocess.run(['pandoc', html_file, '-o', output_file])
@asynccontextmanager
async def lifespan(app: FastAPI):
# 初始化Milvus连接池
app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
# 初始化集合管理器
app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME)
app.state.collection_manager.load_collection()
# 初始化阿里云大模型工具
app.state.aliyun_util = ALiYunUtil()
yield
# 关闭Milvus连接池
app.state.milvus_pool.close()
app = FastAPI(lifespan=lifespan)
# 挂载静态文件目录
app.mount("../static", StaticFiles(directory="Static"), name="static")
# 将文本转换为嵌入向量
def text_to_embedding(text):
words = jieba.lcut(text) # 使用 jieba 分词
print(f"文本: {text}, 分词结果: {words}")
embeddings = [model[word] for word in words if word in model]
logger.info(f"有效词向量数量: {len(embeddings)}")
if embeddings:
avg_embedding = sum(embeddings) / len(embeddings)
logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维
return avg_embedding
else:
logger.warning("未找到有效词,返回零向量")
return [0.0] * model.vector_size
async def generate_stream(client, milvus_pool, collection_manager, query, documents):
# 从连接池获取连接
connection = milvus_pool.get_connection()
try:
# 1. 将查询文本转换为向量
current_embedding = text_to_embedding(query)
# 2. 搜索相关数据
search_params = {
"metric_type": "L2", # 使用 L2 距离度量方式
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
}
# 动态生成expr表达式
if documents:
conditions = [f"array_contains(tags['tags'], '{doc}')" for doc in documents]
expr = " OR ".join(conditions)
else:
expr = "" # 如果没有选择文档,返回空字符串
# 7. 将文本转换为嵌入向量
results = collection_manager.search(current_embedding,
search_params,
expr=expr, # 使用in操作符
limit=5) # 返回 5 条结果
# 3. 处理搜索结果
logger.info("最相关的知识库内容:")
context = ""
if results:
for hits in results:
for hit in hits:
try:
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
if hit.distance < 0.88: # 设置距离阈值
logger.info(f"ID: {hit.id}")
logger.info(f"标签: {record['tags']}")
logger.info(f"用户问题: {record['user_input']}")
logger.info(f"时间: {record['timestamp']}")
logger.info(f"距离: {hit.distance}")
logger.info("-" * 40) # 分隔线
# 获取完整内容
full_content = record['tags'].get('full_content', record['user_input'])
context = context + full_content
else:
logger.warning(f"距离太远,忽略此结果: {hit.id}")
logger.info(f"标签: {record['tags']}")
logger.info(f"用户问题: {record['user_input']}")
logger.info(f"时间: {record['timestamp']}")
logger.info(f"距离: {hit.distance}")
continue
except Exception as e:
logger.error(f"查询失败: {e}")
else:
logger.warning("未找到相关历史对话,请检查查询参数或数据。")
prompt = f"""
信息检索与回答助手
根据以下关于'{query}'的相关信息:
基本信息
- 语言: 中文
- 描述: 根据提供的材料检索信息并回答问题
- 特点: 快速准确提取关键信息,清晰简洁地回答
相关信息
{context}
回答要求
1. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容。
2. 如果未提供相关信息,请不要回答。
3. 如果发现相关信息与原来的问题契合度低,也不要回答
4. 使用HTML格式返回包含适当的段落、列表和标题标签
5. 确保内容结构清晰,便于前端展示
"""
# 调用阿里云大模型
if len(context) > 0:
html_content = client.chat(prompt)
yield {"data": html_content}
else:
yield {"data": "没有在知识库中找到相关的信息,无法回答此问题。"}
except Exception as e:
yield {"data": f"生成报告时出错: {str(e)}"}
finally:
# 释放连接
milvus_pool.release_connection(connection)
"""
http://10.10.21.22:8000/static/ai.html
知识库中有的内容:
小学数学中有哪些模型?
帮我写一下 “如何理解点、线、面、体、角”的教学设计
知识库中没有的内容:
你知道黄海是谁吗?
"""
class QueryRequest(BaseModel):
query: str = Field(..., description="用户查询的问题")
documents: List[str] = Field(..., description="用户上传的文档")
class SaveWordRequest(BaseModel):
html: str = Field(..., description="要保存为Word的HTML内容")
@app.post("/api/save-word")
async def save_to_word(request: Request):
temp_html = None
output_file = None
try:
# Parse request data
try:
data = await request.json()
html_content = data.get('html_content', '')
if not html_content:
raise ValueError("Empty HTML content")
except Exception as e:
logger.error(f"Request parsing failed: {str(e)}")
raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")
# 创建临时HTML文件
temp_html = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".html")
with open(temp_html, "w", encoding="utf-8") as f:
f.write(html_content)
# 使用pandoc转换
output_file = os.path.join(tempfile.gettempdir(), "小学数学问答.docx")
subprocess.run(['pandoc', temp_html, '-o', output_file], check=True)
# 读取生成的Word文件
with open(output_file, "rb") as f:
stream = BytesIO(f.read())
# 返回响应
encoded_filename = urllib.parse.quote("小学数学问答.docx")
return StreamingResponse(
stream,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
# 清理临时文件
try:
if temp_html and os.path.exists(temp_html):
os.remove(temp_html)
if output_file and os.path.exists(output_file):
os.remove(output_file)
except Exception as e:
logger.warning(f"Failed to clean up temp files: {str(e)}")
@app.post("/api/rag")
async def rag_stream(request: Request):
try:
data = await request.json()
query_request = QueryRequest(**data)
except ValidationError as e:
logger.error(f"请求体验证失败: {e.errors()}")
raise HTTPException(status_code=422, detail=e.errors())
except Exception as e:
logger.error(f"请求解析失败: {str(e)}")
raise HTTPException(status_code=400, detail="无效的请求格式")
"""RAG+ALiYun接口"""
async for chunk in generate_stream(
request.app.state.aliyun_util,
request.app.state.milvus_pool,
request.app.state.collection_manager,
query_request.query,
query_request.documents
):
return chunk
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)