diff --git a/dsRag/Milvus/Utils/MilvusCollectionManager.py b/dsRag/Milvus/Utils/MilvusCollectionManager.py index afbafc3e..389b7e3a 100644 --- a/dsRag/Milvus/Utils/MilvusCollectionManager.py +++ b/dsRag/Milvus/Utils/MilvusCollectionManager.py @@ -17,7 +17,7 @@ class MilvusCollectionManager: """ if utility.has_collection(self.collection_name): self.collection = Collection(name=self.collection_name) - #print(f"集合 '{self.collection_name}' 已加载。") + # print(f"集合 '{self.collection_name}' 已加载。") else: print(f"集合 '{self.collection_name}' 不存在。") @@ -52,6 +52,7 @@ class MilvusCollectionManager: if self.collection is None: raise Exception("集合未加载,请检查集合是否存在。") self.collection.insert(entities) + def load_collection(self): """ 加载集合到内存 @@ -70,7 +71,7 @@ class MilvusCollectionManager: # 使用 Milvus 的 query 方法查询指定 ID 的记录 results = self.collection.query( expr=f"id == {id}", # 查询条件 - output_fields=["id", "person_id", "user_input", "model_response", "timestamp"] # 返回的字段 + output_fields=["id", "document_id", "user_input", "timestamp"] # 返回的字段 ) if results: return results[0] # 返回第一条记录 @@ -79,6 +80,7 @@ class MilvusCollectionManager: except Exception as e: print(f"查询失败: {e}") return None + def search(self, data, search_params, expr=None, limit=5): """ 在集合中搜索与输入向量最相似的数据 @@ -121,4 +123,4 @@ class MilvusCollectionManager: if result: return result[0]["text"] else: - return None \ No newline at end of file + return None diff --git a/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc b/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc index 95eb9b18..b42a2a2b 100644 Binary files a/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc and b/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc differ diff --git a/dsRag/Milvus/X5_select_all_data.py b/dsRag/Milvus/X5_select_all_data.py index 2ca5528c..60cdfaa6 100644 --- a/dsRag/Milvus/X5_select_all_data.py +++ b/dsRag/Milvus/X5_select_all_data.py @@ -21,7 +21,7 @@ try: # 使用 Milvus 的 query 方法查询所有数据 results = collection_manager.collection.query( expr="", # 空表达式表示查询所有数据 - output_fields=["id", "person_id", "user_input", "model_response", "timestamp", "embedding"], # 指定返回的字段 + output_fields=["id", "document_id", "user_input", "timestamp", "embedding"], # 指定返回的字段 limit=1000 # 设置最大返回记录数 ) print("查询结果:") @@ -29,16 +29,14 @@ try: for result in results: try: # 获取字段值 - person_id = result["person_id"] + document_id = result["document_id"] user_input = result["user_input"] - model_response = result["model_response"] timestamp = result["timestamp"] embedding = result["embedding"] # 打印结果 print(f"ID: {result['id']}") - print(f"会话 ID: {person_id}") + print(f"文档 ID: {document_id}") print(f"用户问题: {user_input}") - print(f"大模型回复: {model_response}") print(f"时间: {timestamp}") print(f"向量: {embedding[:5]}...") # 只打印前 5 维向量 print("-" * 40) # 分隔线 diff --git a/dsRag/Milvus/X6_search_near_data.py b/dsRag/Milvus/X6_search_near_data.py index 80738407..de8d1bf7 100644 --- a/dsRag/Milvus/X6_search_near_data.py +++ b/dsRag/Milvus/X6_search_near_data.py @@ -62,9 +62,8 @@ if results: # 查询非向量字段 record = collection_manager.query_by_id(hit.id) print(f"ID: {hit.id}") - print(f"会话 ID: {record['person_id']}") + print(f"文档 ID: {record['document_id']}") print(f"用户问题: {record['user_input']}") - print(f"大模型回复: {record['model_response']}") print(f"时间: {record['timestamp']}") print(f"距离: {hit.distance}") print("-" * 40) # 分隔线 diff --git a/dsRag/Start.py b/dsRag/Start.py index 61c45016..a20ccc31 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -32,11 +32,11 @@ print(f"模型加载成功,词向量维度: {model.vector_size}") async def lifespan(app: FastAPI): # 初始化Milvus连接池 app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS) - + # 初始化集合管理器 app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME) app.state.collection_manager.load_collection() - + # 初始化DeepSeek客户端 app.state.deepseek_client = OpenAI( api_key=Config.DEEPSEEK_API_KEY, @@ -47,8 +47,10 @@ async def lifespan(app: FastAPI): # 关闭Milvus连接池 app.state.milvus_pool.close() + app = FastAPI(lifespan=lifespan) + # 将文本转换为嵌入向量 def text_to_embedding(text): words = jieba.lcut(text) # 使用 jieba 分词 @@ -63,11 +65,12 @@ def text_to_embedding(text): print("未找到有效词,返回零向量") return [0.0] * model.vector_size + async def generate_stream(client, milvus_pool, collection_manager, query): """生成SSE流""" # 从连接池获取连接 connection = milvus_pool.get_connection() - + try: # 1. 将查询文本转换为向量 current_embedding = text_to_embedding(query) @@ -79,10 +82,10 @@ async def generate_stream(client, milvus_pool, collection_manager, query): } # 7. 将文本转换为嵌入向量 results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果 - + # 3. 处理搜索结果 print("最相关的历史对话:") - context="" + context = "" if results: for hits in results: for hit in hits: @@ -90,10 +93,9 @@ async def generate_stream(client, milvus_pool, collection_manager, query): # 查询非向量字段 record = collection_manager.query_by_id(hit.id) print(f"ID: {hit.id}") - print(f"会话 ID: {record['person_id']}") + print(f"文档 ID: {record['document_id']}") print(f"用户问题: {record['user_input']}") - context=context+record['user_input'] - print(f"大模型回复: {record['model_response']}") + context = context + record['user_input'] print(f"时间: {record['timestamp']}") print(f"距离: {hit.distance}") print("-" * 40) # 分隔线 @@ -101,7 +103,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): print(f"查询失败: {e}") else: print("未找到相关历史对话,请检查查询参数或数据。") - + prompt = f"""根据以下关于'{query}'的相关信息,# Role: 信息检索与回答助手 ## Profile @@ -154,7 +156,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): 相关信息: {context}""" - + response = client.chat.completions.create( model="deepseek-chat", messages=[ @@ -164,7 +166,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): temperature=0.3, stream=True ) - + for chunk in response: if chunk.choices[0].delta.content: yield {"data": chunk.choices[0].delta.content} @@ -174,9 +176,13 @@ async def generate_stream(client, milvus_pool, collection_manager, query): finally: # 释放连接 milvus_pool.release_connection(connection) + + """ http://10.10.21.22:8000/api/rag?query=小学数学中有哪些模型 """ + + @app.post("/api/rag") async def rag_stream(request: Request, query: str = Body(...)): """RAG+DeepSeek流式接口""" @@ -189,7 +195,8 @@ async def rag_stream(request: Request, query: str = Body(...)): ) ) + app.mount("/static", StaticFiles(directory="Static"), name="static") if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8000)