diff --git a/dsRag/Milvus/Utils/MilvusCollectionManager.py b/dsRag/Milvus/Utils/MilvusCollectionManager.py index afbafc3e..389b7e3a 100644 --- a/dsRag/Milvus/Utils/MilvusCollectionManager.py +++ b/dsRag/Milvus/Utils/MilvusCollectionManager.py @@ -17,7 +17,7 @@ class MilvusCollectionManager: """ if utility.has_collection(self.collection_name): self.collection = Collection(name=self.collection_name) - #print(f"集合 '{self.collection_name}' 已加载。") + # print(f"集合 '{self.collection_name}' 已加载。") else: print(f"集合 '{self.collection_name}' 不存在。") @@ -52,6 +52,7 @@ class MilvusCollectionManager: if self.collection is None: raise Exception("集合未加载,请检查集合是否存在。") self.collection.insert(entities) + def load_collection(self): """ 加载集合到内存 @@ -70,7 +71,7 @@ class MilvusCollectionManager: # 使用 Milvus 的 query 方法查询指定 ID 的记录 results = self.collection.query( expr=f"id == {id}", # 查询条件 - output_fields=["id", "person_id", "user_input", "model_response", "timestamp"] # 返回的字段 + output_fields=["id", "document_id", "user_input", "timestamp"] # 返回的字段 ) if results: return results[0] # 返回第一条记录 @@ -79,6 +80,7 @@ class MilvusCollectionManager: except Exception as e: print(f"查询失败: {e}") return None + def search(self, data, search_params, expr=None, limit=5): """ 在集合中搜索与输入向量最相似的数据 @@ -121,4 +123,4 @@ class MilvusCollectionManager: if result: return result[0]["text"] else: - return None \ No newline at end of file + return None diff --git a/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc b/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc index a6935521..e56a1fb1 100644 Binary files a/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc and b/dsRag/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc differ diff --git a/dsRag/Milvus/X1_create_collection.py b/dsRag/Milvus/X1_create_collection.py index aa58e2c7..6a56a86f 100644 --- a/dsRag/Milvus/X1_create_collection.py +++ b/dsRag/Milvus/X1_create_collection.py @@ -27,13 +27,12 @@ if utility.has_collection(collection_name): # 5. 定义集合的字段和模式 fields = [ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), # 主键字段,自动生成 ID - FieldSchema(name="person_id", dtype=DataType.VARCHAR, max_length=64), # 会话 ID + FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=64), # 文档 ID FieldSchema(name="user_input", dtype=DataType.VARCHAR, max_length=65535), # 用户问题 - FieldSchema(name="model_response", dtype=DataType.VARCHAR, max_length=65535), # 大模型反馈结果 FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=32), # 时间 FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=MS_DIMENSION) # 向量字段,维度为 200 ] -schema_description = "Chat records collection with person_id , user input, model response, and timestamp" +schema_description = "Chat records collection with document_id , user_input, and timestamp" # 6. 创建集合 print(f"正在创建集合 '{collection_name}'...") diff --git a/dsRag/Milvus/X4_InsertMathData.py b/dsRag/Milvus/X4_InsertMathData.py index 68fd0e35..0397debe 100644 --- a/dsRag/Milvus/X4_InsertMathData.py +++ b/dsRag/Milvus/X4_InsertMathData.py @@ -46,16 +46,15 @@ for filename in os.listdir(txt_dir): # 5. 获取当前时间和会话ID timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - person_id = "MATH_DATA_" + str(hash(filename)) + document_id = "MATH_DATA_1" # 史校长的这本书定义为 MATH_DATA_1 # 6. 将文本转换为嵌入向量 embedding = text_to_embedding(content) # 7. 插入数据 entities = [ - [person_id], # person_id + [document_id], # document_id [content], # user_input - [""], # model_response (留空) [timestamp], # timestamp [embedding] # embedding ] diff --git a/dsRag/Milvus/X5_select_all_data.py b/dsRag/Milvus/X5_select_all_data.py index 2ca5528c..60cdfaa6 100644 --- a/dsRag/Milvus/X5_select_all_data.py +++ b/dsRag/Milvus/X5_select_all_data.py @@ -21,7 +21,7 @@ try: # 使用 Milvus 的 query 方法查询所有数据 results = collection_manager.collection.query( expr="", # 空表达式表示查询所有数据 - output_fields=["id", "person_id", "user_input", "model_response", "timestamp", "embedding"], # 指定返回的字段 + output_fields=["id", "document_id", "user_input", "timestamp", "embedding"], # 指定返回的字段 limit=1000 # 设置最大返回记录数 ) print("查询结果:") @@ -29,16 +29,14 @@ try: for result in results: try: # 获取字段值 - person_id = result["person_id"] + document_id = result["document_id"] user_input = result["user_input"] - model_response = result["model_response"] timestamp = result["timestamp"] embedding = result["embedding"] # 打印结果 print(f"ID: {result['id']}") - print(f"会话 ID: {person_id}") + print(f"文档 ID: {document_id}") print(f"用户问题: {user_input}") - print(f"大模型回复: {model_response}") print(f"时间: {timestamp}") print(f"向量: {embedding[:5]}...") # 只打印前 5 维向量 print("-" * 40) # 分隔线 diff --git a/dsRag/Milvus/X6_search_near_data.py b/dsRag/Milvus/X6_search_near_data.py index 80738407..de8d1bf7 100644 --- a/dsRag/Milvus/X6_search_near_data.py +++ b/dsRag/Milvus/X6_search_near_data.py @@ -62,9 +62,8 @@ if results: # 查询非向量字段 record = collection_manager.query_by_id(hit.id) print(f"ID: {hit.id}") - print(f"会话 ID: {record['person_id']}") + print(f"文档 ID: {record['document_id']}") print(f"用户问题: {record['user_input']}") - print(f"大模型回复: {record['model_response']}") print(f"时间: {record['timestamp']}") print(f"距离: {hit.distance}") print("-" * 40) # 分隔线 diff --git a/dsRag/Start.py b/dsRag/Start.py index 61c45016..a20ccc31 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -32,11 +32,11 @@ print(f"模型加载成功,词向量维度: {model.vector_size}") async def lifespan(app: FastAPI): # 初始化Milvus连接池 app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS) - + # 初始化集合管理器 app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME) app.state.collection_manager.load_collection() - + # 初始化DeepSeek客户端 app.state.deepseek_client = OpenAI( api_key=Config.DEEPSEEK_API_KEY, @@ -47,8 +47,10 @@ async def lifespan(app: FastAPI): # 关闭Milvus连接池 app.state.milvus_pool.close() + app = FastAPI(lifespan=lifespan) + # 将文本转换为嵌入向量 def text_to_embedding(text): words = jieba.lcut(text) # 使用 jieba 分词 @@ -63,11 +65,12 @@ def text_to_embedding(text): print("未找到有效词,返回零向量") return [0.0] * model.vector_size + async def generate_stream(client, milvus_pool, collection_manager, query): """生成SSE流""" # 从连接池获取连接 connection = milvus_pool.get_connection() - + try: # 1. 将查询文本转换为向量 current_embedding = text_to_embedding(query) @@ -79,10 +82,10 @@ async def generate_stream(client, milvus_pool, collection_manager, query): } # 7. 将文本转换为嵌入向量 results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果 - + # 3. 处理搜索结果 print("最相关的历史对话:") - context="" + context = "" if results: for hits in results: for hit in hits: @@ -90,10 +93,9 @@ async def generate_stream(client, milvus_pool, collection_manager, query): # 查询非向量字段 record = collection_manager.query_by_id(hit.id) print(f"ID: {hit.id}") - print(f"会话 ID: {record['person_id']}") + print(f"文档 ID: {record['document_id']}") print(f"用户问题: {record['user_input']}") - context=context+record['user_input'] - print(f"大模型回复: {record['model_response']}") + context = context + record['user_input'] print(f"时间: {record['timestamp']}") print(f"距离: {hit.distance}") print("-" * 40) # 分隔线 @@ -101,7 +103,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): print(f"查询失败: {e}") else: print("未找到相关历史对话,请检查查询参数或数据。") - + prompt = f"""根据以下关于'{query}'的相关信息,# Role: 信息检索与回答助手 ## Profile @@ -154,7 +156,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): 相关信息: {context}""" - + response = client.chat.completions.create( model="deepseek-chat", messages=[ @@ -164,7 +166,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): temperature=0.3, stream=True ) - + for chunk in response: if chunk.choices[0].delta.content: yield {"data": chunk.choices[0].delta.content} @@ -174,9 +176,13 @@ async def generate_stream(client, milvus_pool, collection_manager, query): finally: # 释放连接 milvus_pool.release_connection(connection) + + """ http://10.10.21.22:8000/api/rag?query=小学数学中有哪些模型 """ + + @app.post("/api/rag") async def rag_stream(request: Request, query: str = Body(...)): """RAG+DeepSeek流式接口""" @@ -189,7 +195,8 @@ async def rag_stream(request: Request, query: str = Body(...)): ) ) + app.mount("/static", StaticFiles(directory="Static"), name="static") if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/dsRag/Txt/~$数学(史校长).docx b/dsRag/Txt/~$数学(史校长).docx deleted file mode 100644 index f286835d..00000000 Binary files a/dsRag/Txt/~$数学(史校长).docx and /dev/null differ diff --git a/dsRag/requirements.txt b/dsRag/requirements.txt index f035c6e6..8b72d5e1 100644 Binary files a/dsRag/requirements.txt and b/dsRag/requirements.txt differ