From 4598a3132b589e9e2e2029573cc09e71b1d93cfe Mon Sep 17 00:00:00 2001 From: HuangHai <10402852@qq.com> Date: Thu, 26 Jun 2025 15:38:15 +0800 Subject: [PATCH] 'commit' --- dsRag/Start.py | 54 +++++--- dsRag/Static/chat.html | 10 -- dsRag/{Static => static}/ai.html | 0 dsRag/static/chat.html | 206 +++++++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 28 deletions(-) delete mode 100644 dsRag/Static/chat.html rename dsRag/{Static => static}/ai.html (100%) create mode 100644 dsRag/static/chat.html diff --git a/dsRag/Start.py b/dsRag/Start.py index 05e547b5..483a8420 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -4,12 +4,14 @@ from logging.handlers import RotatingFileHandler import jieba # 导入 jieba 分词库 import uvicorn -from fastapi import FastAPI, Request, Body +from fastapi import FastAPI, Request, HTTPException +from pydantic import BaseModel, Field, ValidationError +from fastapi.staticfiles import StaticFiles from openai import OpenAI from sse_starlette.sse import EventSourceResponse from gensim.models import KeyedVectors from Config import Config -from Config.Config import * +from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from Milvus.Utils.MilvusConnectionPool import * from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool @@ -24,7 +26,7 @@ logger.addHandler(handler) # 1. 加载预训练的 Word2Vec 模型 model_path = MS_MODEL_PATH # 替换为你的 Word2Vec 模型路径 model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT) -print(f"模型加载成功,词向量维度: {model.vector_size}") +logger.info(f"模型加载成功,词向量维度: {model.vector_size}") @asynccontextmanager @@ -49,19 +51,22 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) +# 挂载静态文件目录 +app.mount("/static", StaticFiles(directory="Static"), name="static") + # 将文本转换为嵌入向量 def text_to_embedding(text): words = jieba.lcut(text) # 使用 jieba 分词 print(f"文本: {text}, 分词结果: {words}") embeddings = [model[word] for word in words if word in model] - print(f"有效词向量数量: {len(embeddings)}") + logger.info(f"有效词向量数量: {len(embeddings)}") if embeddings: avg_embedding = sum(embeddings) / len(embeddings) - print(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维 + logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维 return avg_embedding else: - print("未找到有效词,返回零向量") + logger.warning("未找到有效词,返回零向量") return [0.0] * model.vector_size @@ -83,7 +88,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果 # 3. 处理搜索结果 - print("最相关的历史对话:") + logger.info("最相关的历史对话:") context = "" if results: for hits in results: @@ -91,17 +96,17 @@ async def generate_stream(client, milvus_pool, collection_manager, query): try: # 查询非向量字段 record = collection_manager.query_by_id(hit.id) - print(f"ID: {hit.id}") - print(f"标签: {record['tags']}") - print(f"用户问题: {record['user_input']}") + logger.info(f"ID: {hit.id}") + logger.info(f"标签: {record['tags']}") + logger.info(f"用户问题: {record['user_input']}") context = context + record['user_input'] - print(f"时间: {record['timestamp']}") - print(f"距离: {hit.distance}") - print("-" * 40) # 分隔线 + logger.info(f"时间: {record['timestamp']}") + logger.info(f"距离: {hit.distance}") + logger.info("-" * 40) # 分隔线 except Exception as e: - print(f"查询失败: {e}") + logger.error(f"查询失败: {e}") else: - print("未找到相关历史对话,请检查查询参数或数据。") + logger.warning("未找到相关历史对话,请检查查询参数或数据。") prompt = f"""根据以下关于'{query}'的相关信息,# Role: 信息检索与回答助手 @@ -178,19 +183,32 @@ async def generate_stream(client, milvus_pool, collection_manager, query): """ -http://10.10.21.22:8000/api/rag?query=小学数学中有哪些模型 +http://10.10.21.22:8000/static/chat.html +小学数学中有哪些模型? """ +class QueryRequest(BaseModel): + query: str = Field(..., description="用户查询的问题") + @app.post("/api/rag") -async def rag_stream(request: Request, query: str = Body(...)): +async def rag_stream(request: Request): + try: + data = await request.json() + query_request = QueryRequest(**data) + except ValidationError as e: + logger.error(f"请求体验证失败: {e.errors()}") + raise HTTPException(status_code=422, detail=e.errors()) + except Exception as e: + logger.error(f"请求解析失败: {str(e)}") + raise HTTPException(status_code=400, detail="无效的请求格式") """RAG+DeepSeek流式接口""" return EventSourceResponse( generate_stream( request.app.state.deepseek_client, request.app.state.milvus_pool, request.app.state.collection_manager, - query + query_request.query ) ) diff --git a/dsRag/Static/chat.html b/dsRag/Static/chat.html deleted file mode 100644 index 93257211..00000000 --- a/dsRag/Static/chat.html +++ /dev/null @@ -1,10 +0,0 @@ - - -
- -