commit
@@ -1,3 +1,12 @@
# Elasticsearch configuration
ES_CONFIG = {
    'hosts': ['https://localhost:9200'],
    'basic_auth': ('elastic', 'jv9h8uwRrRxmDi1dq6u8'),
    'verify_certs': False,
    'index_name': 'ds_db',  # default index name
    'student_info_index': 'student_info'  # name of the newly added student_info index
}

# Alibaba Cloud credentials 【绘智科技】
ALY_AK = 'LTAI5tE4tgpGcKWhbZg6C4bh'
ALY_SK = 'oizcTOZ8izbGUouboC00RcmGE8vBQ1'
@@ -45,8 +54,8 @@ GLM_MODEL_NAME = "THUDM/GLM-4.1V-9B-Thinking"
# Alibaba Cloud API info 【YLT】
ALY_LLM_API_KEY = "sk-f6da0c787eff4b0389e4ad03a35a911f"
ALY_LLM_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
ALY_LLM_MODEL_NAME = "deepseek-r1"
# ALY_LLM_MODEL_NAME = "qwen-plus"
# ALY_LLM_MODEL_NAME = "deepseek-r1"
ALY_LLM_MODEL_NAME = "qwen-plus"
# ALY_LLM_MODEL_NAME = "deepseek-v3"

# Huawei Cloud storage
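The ES_CONFIG block above is consumed by every utility added in this commit; as a sanity check, a minimal sketch, assuming the official elasticsearch-py 8.x client (the same one the connection pool below instantiates), of how the keys map onto a client:

# Sketch only, not part of the commit: verify ES_CONFIG can open a client.
from elasticsearch import Elasticsearch

from Config.Config import ES_CONFIG

es = Elasticsearch(
    hosts=ES_CONFIG['hosts'],
    basic_auth=ES_CONFIG['basic_auth'],      # (username, password) tuple
    verify_certs=ES_CONFIG['verify_certs']   # False: self-signed localhost cert
)
print(es.ping())  # True when the cluster is reachable and the credentials work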
Binary file not shown.
32  dsLightRag/ElasticSearch/T1_RebuildMapping.py  Normal file
@@ -0,0 +1,32 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)


def rebuild_index(index_name):
    """
    Rebuild the given index

    Args:
        index_name: name of the index to rebuild

    Returns:
        bool: whether the operation succeeded
    """
    print(f"Rebuilding index: {index_name}")
    if search_util.rebuild_mapping(index_name):
        print(f"Rebuilding index '{index_name}' succeeded")
        return True
    else:
        print(f"Rebuilding index '{index_name}' failed")
        return False


if __name__ == "__main__":
    # Rebuild the ds_db index
    rebuild_index(Config.ES_CONFIG['index_name'])

    # Rebuild the student_info index
    rebuild_index(Config.ES_CONFIG['student_info_index'])

    print("All index rebuild operations completed")
38  dsLightRag/ElasticSearch/T2_Vector.py  Normal file
@@ -0,0 +1,38 @@
# pip install pydantic requests
from ElasticSearch.Utils.VectorDBUtil import VectorDBUtil


def main():
    # Simulated long document content
    long_text = """混凝土是一种广泛使用的建筑材料,由水泥、砂、石子和水混合而成。它具有高强度、耐久性和良好的可塑性,被广泛应用于建筑、桥梁、道路等土木工程领域。

混凝土的历史可以追溯到古罗马时期,当时人们使用火山灰、石灰和碎石混合制成类似混凝土的材料。现代混凝土技术始于19世纪,随着波特兰水泥的发明而得到快速发展。

混凝土的性能取决于其配合比,包括水灰比、砂率等参数。水灰比是影响混凝土强度的关键因素,较小的水灰比通常会产生更高强度的混凝土。

为了改善混凝土的性能,常常会添加各种外加剂,如减水剂、早强剂、缓凝剂等。此外,还可以使用纤维增强、聚合物改性等技术来提高混凝土的韧性和耐久性。

在施工过程中,混凝土需要适当的养护,以确保其强度正常发展。养护措施包括浇水、覆盖保湿、蒸汽养护等。

随着建筑技术的发展,高性能混凝土、自密实混凝土、再生骨料混凝土等新型混凝土不断涌现,为土木工程领域提供了更多的选择。"""

    # Create the utility instance
    vector_util = VectorDBUtil()

    # Ingest the text into the vector store
    vector_util.text_to_vector_db(long_text)

    # Query the vector store
    query = "混凝土"
    reranked_results = vector_util.query_vector_db(query, k=4)

    # Print every result with its relevance score
    print("Final query results:")
    for i, (result, score) in enumerate(reranked_results):
        print(f"Result {i+1} (relevance: {score:.4f}):")
        print(result.page_content)
        print("---")


if __name__ == "__main__":
    main()
28  dsLightRag/ElasticSearch/T3_InsertData.py  Normal file
@@ -0,0 +1,28 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


def main():
    # Example 1: insert a single long text
    long_text = """混凝土是一种广泛使用的建筑材料,由水泥、砂、石子和水混合而成。它具有高强度、耐久性和良好的可塑性,被广泛应用于建筑、桥梁、道路等土木工程领域。

混凝土的历史可以追溯到古罗马时期,当时人们使用火山灰、石灰和碎石混合制成类似混凝土的材料。现代混凝土技术始于19世纪,随着波特兰水泥的发明而得到快速发展。

混凝土的性能取决于其配合比,包括水灰比、砂率等参数。水灰比是影响混凝土强度的关键因素,较小的水灰比通常会产生更高强度的混凝土。

为了改善混凝土的性能,常常会添加各种外加剂,如减水剂、早强剂、缓凝剂等。此外,还可以使用纤维增强、聚合物改性等技术来提高混凝土的韧性和耐久性。

在施工过程中,混凝土需要适当的养护,以确保其强度正常发展。养护措施包括浇水、覆盖保湿、蒸汽养护等。

随着建筑技术的发展,高性能混凝土、自密实混凝土、再生骨料混凝土等新型混凝土不断涌现,为土木工程领域提供了更多的选择。"""

    # Tag the text
    tags = ["student_110"]

    # Create an EsSearchUtil instance
    search_util = EsSearchUtil(Config.ES_CONFIG)
    search_util.insert_long_text_to_es(long_text, tags)


if __name__ == "__main__":
    main()
28  dsLightRag/ElasticSearch/T4_SelectAllData.py  Normal file
@@ -0,0 +1,28 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)


# Query all documents
def select_all_data(index_name):
    try:
        # Call the select_all_data method on EsSearchUtil
        response = search_util.select_all_data()
        hits = response['hits']['hits']

        if not hits:
            print(f"Index {index_name} contains no data")
        else:
            print(f"Index {index_name} contains {len(hits)} documents:")
            for i, hit in enumerate(hits, 1):
                print(f"{i}. ID: {hit['_id']}")
                print(f"   Content: {hit['_source'].get('user_input', '')}")
                print(f"   Tags: {hit['_source'].get('tags', '')}")
                print(f"   Timestamp: {hit['_source'].get('timestamp', '')}")
                print("-" * 50)
    except Exception as e:
        print(f"Query failed: {e}")


if __name__ == "__main__":
    select_all_data(Config.ES_CONFIG['index_name'])
28  dsLightRag/ElasticSearch/T5_SelectByKeyWord.py  Normal file
@@ -0,0 +1,28 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


# 1. Create an EsSearchUtil instance (wraps a connection pool)
search_util = EsSearchUtil(Config.ES_CONFIG)

# 2. Specify the query keyword directly in code
query_keyword = "混凝土"

# 3. Run the query and process the results
try:
    # Query through the connection pool
    results = search_util.text_search(query_keyword, size=1000)
    print(f"Results for keyword '{query_keyword}':")
    if results['hits']['hits']:
        for i, hit in enumerate(results['hits']['hits'], 1):
            doc = hit['_source']
            print(f"{i}. ID: {hit['_id']}")
            print(f"   Tags: {doc['tags']['tags'] if 'tags' in doc and 'tags' in doc['tags'] else 'none'}")
            print(f"   User input: {doc['user_input'] if 'user_input' in doc else 'none'}")
            print(f"   Timestamp: {doc['timestamp'] if 'timestamp' in doc else 'none'}")
            print("-" * 50)
    else:
        print(f"No data found containing '{query_keyword}'.")
except Exception as e:
    print(f"Query failed: {e}")
    print(f"Query keyword: {query_keyword}")
29  dsLightRag/ElasticSearch/T6_SelectByVector.py  Normal file
@@ -0,0 +1,29 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


def main():
    # Initialize the search utility
    search_util = EsSearchUtil(Config.ES_CONFIG)

    # Query text
    query = "混凝土"
    print(f"Query text: {query}")

    # Embed the query
    query_embedding = search_util.get_query_embedding(query)
    print(f"Query vector dimensions: {len(query_embedding)}")

    # Vector search
    search_results = search_util.search_by_vector(query_embedding, k=10)
    print(f"Vector search hits: {len(search_results)}")

    # Rerank the results
    reranked_results = search_util.rerank_results(query, search_results)

    # Display the results
    search_util.display_results(reranked_results)


if __name__ == "__main__":
    main()
66  dsLightRag/ElasticSearch/T7_XiangLiangQuery.py  Normal file
@@ -0,0 +1,66 @@
import logging

from Config.Config import ES_CONFIG
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


if __name__ == "__main__":
    # Initialize EsSearchUtil
    search_util = EsSearchUtil(ES_CONFIG)

    # Read the user's query
    user_query = input("Enter a query (e.g. 高性能的混凝土): ")
    if not user_query:
        user_query = "高性能的混凝土"
        print(f"No query entered, using default: {user_query}")

    query_tags = []  # tag filters can be added here as needed

    print("\n=== Starting query ===")
    print(f"Original query text: {user_query}")

    try:
        # 1. Vector search
        print("\n=== Vector search stage ===")
        print("1. Embedding the query text...")
        query_embedding = search_util.get_query_embedding(user_query)
        print(f"2. Query vector dimensions: {len(query_embedding)}")
        print(f"3. First 3 vector values: {query_embedding[:3]}")

        print("4. Running Elasticsearch vector search...")
        vector_hits = search_util.search_by_vector(query_embedding, k=5)
        print(f"5. Vector search hits: {len(vector_hits)}")

        # Rerank the vector results
        print("6. Reranking vector results...")
        reranked_vector_results = search_util.rerank_results(user_query, vector_hits)
        print(f"7. Reranked vector results: {len(reranked_vector_results)}")

        # 2. Keyword search
        print("\n=== Keyword search stage ===")
        print("1. Running Elasticsearch keyword search...")
        keyword_results = search_util.text_search(user_query, size=5)
        keyword_hits = keyword_results['hits']['hits']
        print(f"2. Keyword search hits: {len(keyword_hits)}")

        # 3. Merge the results
        print("\n=== Merging search results ===")
        # Attach scores to the keyword results
        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
        merged_results = search_util.merge_results(keyword_results_with_scores, reranked_vector_results)
        print(f"Unique results after merge: {len(merged_results)}")

        # 4. Print the final results
        print("\n=== Final search results ===")
        for i, (doc, score, source) in enumerate(merged_results, 1):
            print(f"{i}. Document ID: {doc['_id']}, score: {score:.2f}, source: {source}")
            print(f"   Content: {doc['_source']['user_input']}")
            print(" --- ")

    except Exception as e:
        logger.error(f"Error during search: {str(e)}")
        print(f"Search failed: {str(e)}")
110  dsLightRag/ElasticSearch/Utils/ElasticsearchCollectionManager.py  Normal file
@@ -0,0 +1,110 @@
from elasticsearch.exceptions import NotFoundError
import logging

logger = logging.getLogger(__name__)


class ElasticsearchCollectionManager:
    def __init__(self, index_name):
        """
        Initialize the Elasticsearch index manager
        :param index_name: Elasticsearch index name
        """
        self.index_name = index_name

    def load_collection(self, es_connection):
        """
        Load the index, creating it if it does not exist
        :param es_connection: Elasticsearch connection
        """
        try:
            if not es_connection.indices.exists(index=self.index_name):
                logger.warning(f"Index {self.index_name} does not exist, creating new index")
                self._create_index(es_connection)
        except Exception as e:
            logger.error(f"Failed to load collection: {str(e)}")
            raise

    def _create_index(self, es_connection):
        """Create a new Elasticsearch index"""
        mapping = {
            "mappings": {
                "properties": {
                    "user_input": {"type": "text"},
                    "tags": {
                        "type": "object",
                        "properties": {
                            "tags": {"type": "keyword"},
                            "full_content": {"type": "text"}
                        }
                    },
                    "timestamp": {"type": "date"},
                    "embedding": {"type": "dense_vector", "dims": 200}
                }
            }
        }
        es_connection.indices.create(index=self.index_name, body=mapping)

    def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
        """
        Run a hybrid (vector + keyword) search
        :param es_connection: Elasticsearch connection
        :param query_embedding: query vector
        :param search_params: search parameters
        :param expr: filter expression
        :param limit: number of results to return
        :return: search hits
        """
        try:
            # Build the query
            query = {
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "must": []
                            }
                        },
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": limit
            }

            # Add tag filter conditions
            if expr:
                query["query"]["script_score"]["query"]["bool"]["must"].append(
                    {"nested": {
                        "path": "tags",
                        "query": {
                            "terms": {"tags.tags": expr.split(" OR ")}
                        }
                    }}
                )

            logger.info(f"Executing search with query: {query}")
            response = es_connection.search(index=self.index_name, body=query)
            return response["hits"]["hits"]
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            raise

    def query_by_id(self, es_connection, doc_id):
        """
        Fetch a document by ID
        :param es_connection: Elasticsearch connection
        :param doc_id: document ID
        :return: document source
        """
        try:
            response = es_connection.get(index=self.index_name, id=doc_id)
            return response["_source"]
        except NotFoundError:
            logger.warning(f"Document with id {doc_id} not found")
            return None
        except Exception as e:
            logger.error(f"Failed to query document by id: {str(e)}")
            raise
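Note that search() above takes its tag filter as a plain string rather than a list: expr is expected to be tag values joined by " OR ", which the method splits back apart. A tiny sketch with hypothetical tag values:

# Hypothetical tag values; shows the expr format search() expects.
tags = ["student_110", "student_111"]
expr = " OR ".join(tags)           # -> "student_110 OR student_111"
assert expr.split(" OR ") == tags  # search() recovers the list this way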
65  dsLightRag/ElasticSearch/Utils/ElasticsearchConnectionPool.py  Normal file
@@ -0,0 +1,65 @@
from elasticsearch import Elasticsearch
import threading
import logging

logger = logging.getLogger(__name__)


class ElasticsearchConnectionPool:
    def __init__(self, hosts, basic_auth, verify_certs=False, max_connections=50):
        """
        Initialize the Elasticsearch connection pool
        :param hosts: Elasticsearch server addresses
        :param basic_auth: credentials (username, password)
        :param verify_certs: whether to verify SSL certificates
        :param max_connections: maximum number of connections
        """
        self.hosts = hosts
        self.basic_auth = basic_auth
        self.verify_certs = verify_certs
        self.max_connections = max_connections
        self._connections = []
        self._lock = threading.Lock()
        self._initialize_pool()

    def _initialize_pool(self):
        """Pre-fill the pool"""
        for _ in range(self.max_connections):
            self._connections.append(self._create_connection())

    def _create_connection(self):
        """Create a new Elasticsearch connection"""
        return Elasticsearch(
            hosts=self.hosts,
            basic_auth=self.basic_auth,
            verify_certs=self.verify_certs
        )

    def get_connection(self):
        """Take a connection from the pool"""
        with self._lock:
            if not self._connections:
                logger.warning("Connection pool exhausted, creating new connection")
                return self._create_connection()
            return self._connections.pop()

    def release_connection(self, connection):
        """Return a connection to the pool"""
        with self._lock:
            if len(self._connections) < self.max_connections:
                self._connections.append(connection)
            else:
                try:
                    connection.close()
                except Exception as e:
                    logger.warning(f"Failed to close connection: {str(e)}")

    def close(self):
        """Close all pooled connections"""
        with self._lock:
            for conn in self._connections:
                try:
                    conn.close()
                except Exception as e:
                    logger.warning(f"Failed to close connection: {str(e)}")
            self._connections.clear()
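The pool's contract is take, use, return: get_connection() pops a client (or creates an extra one when the pool is drained) and release_connection() pushes it back or closes it. A minimal usage sketch, assuming the ES_CONFIG above and a reachable cluster:

# Sketch of the get/release discipline the utilities below follow.
from Config.Config import ES_CONFIG
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool

pool = ElasticsearchConnectionPool(
    hosts=ES_CONFIG['hosts'],
    basic_auth=ES_CONFIG['basic_auth'],
    verify_certs=ES_CONFIG['verify_certs'],
    max_connections=5  # small pool just for the example
)
conn = pool.get_connection()
try:
    print(conn.info()['version']['number'])  # any client call
finally:
    pool.release_connection(conn)  # always hand the connection back
pool.close()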
664  dsLightRag/ElasticSearch/Utils/EsSearchUtil.py  Normal file
@@ -0,0 +1,664 @@
import json
import logging
import warnings
import hashlib
import time

import jieba
import requests

from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from Config import Config
from typing import List, Tuple, Dict

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class EsSearchUtil:
    # Conversation history, keyed by session ID; each value is a list of dialogue turns
    conversation_history = {}

    # Student info, keyed by user ID
    student_info = {}

    # Grade keyword dictionary ('初一'/'初二'/'初三' are junior-high aliases for grades 7-9)
    GRADE_KEYWORDS = {
        '一年级': ['一年级'],
        '二年级': ['二年级'],
        '三年级': ['三年级'],
        '四年级': ['四年级'],
        '五年级': ['五年级'],
        '六年级': ['六年级'],
        '七年级': ['七年级', '初一'],
        '八年级': ['八年级', '初二'],
        '九年级': ['九年级', '初三'],
        '高一': ['高一'],
        '高二': ['高二'],
        '高三': ['高三']
    }

    # Maximum number of conversation rounds to keep
    MAX_HISTORY_ROUNDS = 10

    # Stopword list
    STOPWORDS = set(
        ['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说',
         '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这'])

    def __init__(self, es_config):
        """
        Initialize the Elasticsearch search utility
        :param es_config: Elasticsearch config dict containing hosts, basic_auth, index_name, etc.
        """
        # Suppress HTTPS-related warnings
        warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
        warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')

        self.es_config = es_config

        # Initialize the connection pool
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False),
            max_connections=50
        )

        self.index_name = es_config['index_name']
        logger.info(f"EsSearchUtil initialized, index name: {self.index_name}")

    def rebuild_mapping(self, index_name=None):
        """
        Rebuild the Elasticsearch index and its mapping

        Args:
            index_name: optional; the index to rebuild, defaulting to the one set at init

        Returns:
            bool: whether the operation succeeded
        """
        try:
            # Get a connection from the pool
            conn = self.es_pool.get_connection()

            # Use the given index name or fall back to the default
            target_index = index_name if index_name else self.index_name
            logger.info(f"Rebuilding index: {target_index}")

            # Define the mapping
            if target_index == 'student_info':
                mapping = {
                    "mappings": {
                        "properties": {
                            "user_id": {"type": "keyword"},
                            "grade": {"type": "keyword"},
                            "recent_questions": {"type": "text"},
                            "learned_knowledge": {"type": "text"},
                            "updated_at": {"type": "date"}
                        }
                    }
                }
            else:
                mapping = {
                    "mappings": {
                        "properties": {
                            "embedding": {
                                "type": "dense_vector",
                                "dims": Config.EMBED_DIM,
                                "index": True,
                                "similarity": "l2_norm"
                            },
                            "user_input": {"type": "text"},
                            "tags": {
                                "type": "object",
                                "properties": {
                                    "tags": {"type": "keyword"},
                                    "full_content": {"type": "text"}
                                }
                            }
                        }
                    }
                }

            # Delete the index if it already exists
            if conn.indices.exists(index=target_index):
                conn.indices.delete(index=target_index)
                logger.info(f"Deleted existing index '{target_index}'")
                print(f"Deleted existing index '{target_index}'")

            # Create the index with the mapping
            conn.indices.create(index=target_index, body=mapping)
            logger.info(f"Index '{target_index}' created and mapping applied")
            print(f"Index '{target_index}' created and mapping applied.")

            return True
        except Exception as e:
            logger.error(f"Failed to rebuild index '{target_index}': {str(e)}")
            print(f"Failed to rebuild index '{target_index}': {e}")

            # Give a specific hint for authentication errors
            if 'AuthenticationException' in str(e):
                print("Authentication hint: check ES_CONFIG in Config.py and make sure the username and password are correct.")
                logger.error("Authentication failed: check ES_CONFIG in Config.py and make sure the username and password are correct.")

            return False
        finally:
            # Return the connection to the pool
            self.es_pool.release_connection(conn)

    def text_search(self, query, size=10):
        """Keyword search on the user_input field"""
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Run the search on this connection
            result = conn.search(
                index=self.es_config['index_name'],
                query={"match": {"user_input": query}},
                size=size
            )
            return result
        except Exception as e:
            logger.error(f"Text search failed: {str(e)}")
            raise
        finally:
            # Return the connection to the pool
            self.es_pool.release_connection(conn)

    def select_all_data(self, size=1000):
        """
        Query all documents in the index

        Args:
            size: maximum number of results to return, default 1000

        Returns:
            dict: query response
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Build a query that matches all documents
            query = {
                "query": {
                    "match_all": {}
                },
                "size": size
            }

            # Execute the query
            response = conn.search(index=self.es_config['index_name'], body=query)
            return response
        except Exception as e:
            logger.error(f"Failed to query all data: {str(e)}")
            raise
        finally:
            # Return the connection to the pool
            self.es_pool.release_connection(conn)

    def split_text_into_chunks(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> list:
        """
        Split text into chunks

        Args:
            text: the text to split
            chunk_size: size of each chunk
            chunk_overlap: overlap between chunks

        Returns:
            list: list of text chunks
        """
        # Wrap the text in a Document
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]

        # Split the document
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        print(f"Number of chunks after splitting: {len(all_splits)}")

        return [split.page_content for split in all_splits]

    def insert_long_text_to_es(self, long_text: str, tags: list = None) -> bool:
        """
        Split a long text, embed each chunk and insert it into Elasticsearch,
        deduplicating on a hash of the chunk content

        Args:
            long_text: the text to insert
            tags: optional list of tags

        Returns:
            bool: whether the insert succeeded
        """
        try:
            # 1. Create an EsSearchUtil instance to use its connection pool
            search_util = EsSearchUtil(Config.ES_CONFIG)

            # 2. Get a connection from the pool
            conn = search_util.es_pool.get_connection()

            # # 3. Check whether the index exists and create it if not
            index_name = Config.ES_CONFIG['index_name']
            # if not conn.indices.exists(index=index_name):
            #     # Define the mapping
            #     mapping = {
            #         "mappings": {
            #             "properties": {
            #                 "embedding": {
            #                     "type": "dense_vector",
            #                     "dims": Config.EMBED_DIM,  # adjust to the actual embedding dimensions
            #                     "index": True,
            #                     "similarity": "l2_norm"
            #                 },
            #                 "user_input": {"type": "text"},
            #                 "tags": {
            #                     "type": "object",
            #                     "properties": {
            #                         "tags": {"type": "keyword"},
            #                         "full_content": {"type": "text"}
            #                     }
            #                 },
            #                 "timestamp": {"type": "date"}
            #             }
            #         }
            #     }
            #     conn.indices.create(index=index_name, body=mapping)
            #     print(f"Index '{index_name}' created")

            # 4. Split the text
            text_chunks = self.split_text_into_chunks(long_text)

            # 5. Prepare the tags
            if tags is None:
                tags = ["general_text"]

            # 6. Current timestamp
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

            # 7. Create the embedding model
            embeddings = OpenAIEmbeddings(
                model=Config.EMBED_MODEL_NAME,
                base_url=Config.EMBED_BASE_URL,
                api_key=SecretStr(Config.EMBED_API_KEY)
            )

            # 8. Embed and insert each chunk
            for i, chunk in enumerate(text_chunks):
                # Use the chunk's hash as the document ID
                doc_id = hashlib.md5(chunk.encode('utf-8')).hexdigest()

                # Skip the chunk if it is already indexed
                if conn.exists(index=index_name, id=doc_id):
                    print(f"Chunk {i+1} already exists, skipping insert: {doc_id}")
                    continue

                # Embed the chunk
                embedding = embeddings.embed_documents([chunk])[0]

                # Prepare the document
                doc = {
                    'tags': {"tags": tags, "full_content": long_text},
                    'user_input': chunk,
                    'timestamp': timestamp,
                    'embedding': embedding
                }

                # Insert the document into Elasticsearch
                conn.index(index=index_name, id=doc_id, document=doc)
                print(f"Chunk {i+1} inserted: {doc_id}")

            return True
        except Exception as e:
            print(f"Insert failed: {e}")
            return False
        finally:
            # Make sure the connection goes back to the pool
            if 'conn' in locals() and 'search_util' in locals():
                search_util.es_pool.release_connection(conn)

    def get_query_embedding(self, query: str) -> list:
        """
        Convert a query string into an embedding vector

        Args:
            query: query text

        Returns:
            list: the embedding
        """
        # Create the embedding model
        embeddings = OpenAIEmbeddings(
            model=Config.EMBED_MODEL_NAME,
            base_url=Config.EMBED_BASE_URL,
            api_key=SecretStr(Config.EMBED_API_KEY)
        )

        # Embed the query
        query_embedding = embeddings.embed_query(query)
        return query_embedding

    def rerank_results(self, query: str, results: list) -> list:
        """
        Rerank search results with the rerank model

        Args:
            query: query text
            results: list of search hits

        Returns:
            list: reranked results, each element a (document, score) tuple
        """
        if not results:
            print("Warning: no search results to rerank")
            return []

        try:
            # Prepare the rerank request payload;
            # each doc must be a dict with '_source' and 'user_input' fields
            documents = []
            valid_results = []
            for i, doc in enumerate(results):
                if isinstance(doc, dict) and '_source' in doc and 'user_input' in doc['_source']:
                    documents.append(doc['_source']['user_input'])
                    valid_results.append(doc)
                else:
                    print(f"Warning: result item {i} is malformed, skipping it")
                    print(f"Result item content: {doc}")

            if not documents:
                print("Warning: no valid documents to rerank")
                # Return the original results as (result, score) tuples
                return [(doc, doc.get('_score', 0.0)) for doc in results]

            rerank_data = {
                "model": Config.RERANK_MODEL,
                "query": query,
                "documents": documents,
                "top_n": len(documents)
            }

            # Call the rerank API
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}"
            }

            response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
            response.raise_for_status()  # raise if the request failed
            rerank_result = response.json()

            # Process the rerank response
            reranked_results = []
            if "results" in rerank_result:
                for item in rerank_result["results"]:
                    doc_idx = item.get("index")
                    score = item.get("relevance_score", 0.0)
                    if 0 <= doc_idx < len(valid_results):
                        result = valid_results[doc_idx]
                        reranked_results.append((result, score))
            else:
                print("Warning: unrecognized rerank API response format")
                # Fall back to the original results as (result, score) tuples
                reranked_results = [(doc, doc.get('_score', 0.0)) for doc in valid_results]

            print(f"Results after rerank: {len(reranked_results)}")
            return reranked_results

        except Exception as e:
            print(f"Rerank failed: {e}")
            print("Falling back to the original search results")
            # Return the original results as (result, score) tuples
            return [(doc, doc.get('_score', 0.0)) for doc in results]

    def search_by_vector(self, query_embedding: list, k: int = 10) -> list:
        """
        Run a similarity search with a query vector

        Args:
            query_embedding: query vector
            k: number of results to return

        Returns:
            list: search hits
        """
        try:
            # Get a connection from the pool
            conn = self.es_pool.get_connection()
            index_name = Config.ES_CONFIG['index_name']

            # Execute the vector search
            response = conn.search(
                index=index_name,
                body={
                    "query": {
                        "script_score": {
                            "query": {"match_all": {}},
                            "script": {
                                "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                                "params": {
                                    "query_vector": query_embedding
                                }
                            }
                        }
                    },
                    "size": k
                }
            )

            # Extract the hits.hits section of the response
            if 'hits' in response and 'hits' in response['hits']:
                results = response['hits']['hits']
                print(f"Vector search hits: {len(results)}")
                return results
            else:
                print("Warning: malformed vector search response")
                print(f"Response content: {response}")
                return []

        except Exception as e:
            print(f"Vector search failed: {e}")
            return []
        finally:
            # Return the connection to the pool
            self.es_pool.release_connection(conn)

    def display_results(self, results: list, show_score: bool = True) -> None:
        """
        Display search results

        Args:
            results: list of search results
            show_score: whether to print scores
        """
        if not results:
            print("No matching results found.")
            return

        print(f"Found {len(results)} results:\n")
        for i, item in enumerate(results, 1):
            print(f"Result {i}:")
            try:
                # The item may be a (result, score) tuple
                if isinstance(item, tuple):
                    if len(item) >= 2:
                        result, score = item[0], item[1]
                    else:
                        result, score = item[0], 0.0
                else:
                    # Otherwise assume the item itself is the result
                    result = item
                    score = result.get('_score', 0.0)

                # Make sure the result is a dict
                if not isinstance(result, dict):
                    print(f"Warning: result item {i} is not a dict, skipping display")
                    print(f"Result item content: {result}")
                    print("---")
                    continue

                # Try to pull out the user_input content
                if '_source' in result and 'user_input' in result['_source']:
                    content = result['_source']['user_input']
                    print(f"Content: {content}")
                elif 'user_input' in result:
                    content = result['user_input']
                    print(f"Content: {content}")
                else:
                    print(f"Warning: result item {i} has no 'user_input' field")
                    print(f"Result item content: {result}")
                    print("---")
                    continue

                # Print the score
                if show_score:
                    print(f"Score: {score:.4f}")

                # Print the tags if present
                if '_source' in result and 'tags' in result['_source']:
                    tags = result['_source']['tags']
                    if isinstance(tags, dict) and 'tags' in tags:
                        print(f"Tags: {tags['tags']}")

            except Exception as e:
                print(f"Error while processing result item {i}: {str(e)}")
                print(f"Result item content: {item}")
            print("---")

    def merge_results(self, keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
        """
        Merge keyword search and vector search results

        Args:
            keyword_results: keyword results, each a (document, score) tuple
            vector_results: vector results, each a (document, score) tuple

        Returns:
            list: merged results, each a (document, score, source) tuple
        """
        # Label each result with its source and merge
        all_results = []
        for doc, score in keyword_results:
            all_results.append((doc, score, "keyword search"))
        for doc, score in vector_results:
            all_results.append((doc, score, "vector search"))

        # Deduplicate, keeping the highest score per document
        unique_results = {}
        for doc, score, source in all_results:
            doc_id = doc['_id']
            if doc_id not in unique_results or score > unique_results[doc_id][1]:
                unique_results[doc_id] = (doc, score, source)

        # Sort by score, descending
        sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
        return sorted_results

    # Save student info to ES
    def save_student_info_to_es(self, user_id, info):
        """Save student info to Elasticsearch"""
        try:
            # Use the user ID as the document ID
            doc_id = f"student_{user_id}"
            # Prepare the document
            doc = {
                "user_id": user_id,
                "info": info,
                "update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            }
            # Get a connection from the pool
            es_conn = self.es_pool.get_connection()
            try:
                # Index the document (the index is created on demand)
                es_conn.index(index="student_info", id=doc_id, document=doc)
                logger.info(f"Student {user_id} info saved to ES: {info}")
            finally:
                # Return the connection to the pool
                self.es_pool.release_connection(es_conn)
        except Exception as e:
            logger.error(f"Failed to save student info to ES: {str(e)}", exc_info=True)

    # Load student info from ES
    def get_student_info_from_es(self, user_id):
        """Load student info from Elasticsearch"""
        try:
            doc_id = f"student_{user_id}"
            # Get a connection from the pool
            es_conn = self.es_pool.get_connection()
            try:
                # Make sure the index exists
                if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
                    result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
                    if result and '_source' in result:
                        logger.info(f"Loaded student {user_id} info from ES: {result['_source']['info']}")
                        return result['_source']['info']
                    else:
                        logger.info(f"No info found in ES for student {user_id}")
                else:
                    logger.info("student_info index does not exist")
            finally:
                # Return the connection to the pool
                self.es_pool.release_connection(es_conn)
        except Exception as e:
            # If the document does not exist, return an empty dict
            if "not_found" in str(e).lower():
                logger.info(f"Student {user_id} info does not exist in ES")
                return {}
            logger.error(f"Failed to load student info from ES: {str(e)}", exc_info=True)
            return {}

    def extract_student_info(self, text, user_id):
        """Extract student info from text using jieba segmentation"""
        try:
            # Extract grade information
            seg_list = jieba.cut(text, cut_all=False)  # precise mode
            seg_set = set(seg_list)

            # If there is no cached info for this student, load it from ES
            if user_id not in self.student_info:
                # Load student info from ES
                info_from_es = self.get_student_info_from_es(user_id)
                if info_from_es:
                    self.student_info[user_id] = info_from_es
                    logger.info(f"Loaded user {user_id} info from ES: {info_from_es}")
                else:
                    self.student_info[user_id] = {}

            # Extract and update the grade
            grade_found = False
            for grade, keywords in self.GRADE_KEYWORDS.items():
                for keyword in keywords:
                    if keyword in seg_set:
                        if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
                            self.student_info[user_id]['grade'] = grade
                            logger.info(f"Extracted grade for user {user_id}: {grade}")
                            # Persist to ES
                            self.save_student_info_to_es(user_id, self.student_info[user_id])
                        grade_found = True
                        break
                if grade_found:
                    break

            # If the text mentions a grade but no keyword matched, try extracting the number directly
            if not grade_found:
                import re
                # Match the pattern "我是X年级" ("I am in grade X")
                match = re.search(r'我是(\d+)年级', text)
                if match:
                    grade_num = match.group(1)
                    grade = f"{grade_num}年级"
                    if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
                        self.student_info[user_id]['grade'] = grade
                        logger.info(f"Extracted grade for user {user_id} via regex: {grade}")
                        # Persist to ES
                        self.save_student_info_to_es(user_id, self.student_info[user_id])
        except Exception as e:
            logger.error(f"Failed to extract student info: {str(e)}", exc_info=True)
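Note how insert_long_text_to_es dedupes: the MD5 of each chunk is used as the Elasticsearch document ID, so re-ingesting identical text hits the conn.exists() check and is skipped rather than duplicated. The idea in isolation:

# Standalone sketch of the hash-as-ID dedup used by insert_long_text_to_es.
import hashlib

def chunk_doc_id(chunk: str) -> str:
    # identical chunk text -> identical ID -> skip instead of duplicate
    return hashlib.md5(chunk.encode('utf-8')).hexdigest()

assert chunk_doc_id("水灰比是影响混凝土强度的关键因素") == chunk_doc_id("水灰比是影响混凝土强度的关键因素")
assert chunk_doc_id("水泥") != chunk_doc_id("砂率")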
125  dsLightRag/ElasticSearch/Utils/VectorDBUtil.py  Normal file
@@ -0,0 +1,125 @@
# pip install pydantic requests
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import SecretStr
import requests
import json
from Config.Config import (
    EMBED_MODEL_NAME, EMBED_BASE_URL, EMBED_API_KEY,
    RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY
)


class VectorDBUtil:
    """Vector store utility providing text vectorization, storage and querying"""

    def __init__(self):
        """Initialize the vector store utility"""
        # Initialize the embedding model
        self.embeddings = OpenAIEmbeddings(
            model=EMBED_MODEL_NAME,
            base_url=EMBED_BASE_URL,
            api_key=SecretStr(EMBED_API_KEY)  # wrap as SecretStr
        )
        # The vector store is created lazily
        self.vector_store = None

    def text_to_vector_db(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> tuple:
        """
        Store text into the vector store

        Args:
            text: the text to ingest
            chunk_size: chunk size for splitting
            chunk_overlap: overlap between chunks

        Returns:
            tuple: (vector store, document count, chunk count)
        """
        # Wrap the text in a Document
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
        doc_count = len(docs)
        print(f"Document count: {doc_count}")

        # Split the document
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        split_count = len(all_splits)
        print(f"Number of chunks after splitting: {split_count}")

        # Store the vectors in memory
        self.vector_store = InMemoryVectorStore(self.embeddings)
        ids = self.vector_store.add_documents(documents=all_splits)

        return self.vector_store, doc_count, split_count

    def query_vector_db(self, query: str, k: int = 4) -> list:
        """
        Query the vector store

        Args:
            query: query string
            k: number of results to return

        Returns:
            list: reranked results, each a (document, relevance score) tuple
        """
        if not self.vector_store:
            print("Error: vector store not initialized, call text_to_vector_db first")
            return []

        # Vector search - fetch more results for reranking
        results = self.vector_store.similarity_search(query, k=k)
        print(f"Vector search hits: {len(results)}")

        # Reranked documents with scores
        reranked_docs_with_scores = []

        # Call the rerank model
        if len(results) > 1:
            # Prepare the rerank request payload
            rerank_data = {
                "model": RERANK_MODEL,
                "query": query,
                "documents": [doc.page_content for doc in results],
                "top_n": len(results)
            }

            # Call the SiliconFlow API for reranking
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
            }

            try:
                response = requests.post(RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
                response.raise_for_status()  # raise if the request failed
                rerank_result = response.json()

                # Extract relevance_score values from the rerank response
                if "results" in rerank_result:
                    for item in rerank_result["results"]:
                        doc_idx = item.get("index")
                        score = item.get("relevance_score", 0.0)
                        if 0 <= doc_idx < len(results):
                            reranked_docs_with_scores.append((results[doc_idx], score))
                else:
                    print("Warning: unrecognized rerank API response format")
                    reranked_docs_with_scores = [(doc, 0.0) for doc in results]

                print(f"Results after rerank: {len(reranked_docs_with_scores)}")
            except Exception as e:
                print(f"Rerank model call failed: {e}")
                print("Falling back to the original search results")
                reranked_docs_with_scores = [(doc, 0.0) for doc in results]
        else:
            # Only one result, no rerank needed
            reranked_docs_with_scores = [(doc, 1.0) for doc in results]  # a single result gets relevance 1.0

        return reranked_docs_with_scores
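Both rerank call sites (here and in EsSearchUtil) parse the same response shape: a top-level "results" list whose items carry "index" (position in the submitted documents) and "relevance_score". A sketch with hypothetical values:

# Hypothetical rerank response; the real shape depends on the provider.
rerank_result = {
    "results": [
        {"index": 2, "relevance_score": 0.91},  # 3rd submitted document ranked first
        {"index": 0, "relevance_score": 0.47},
    ]
}
documents = ["chunk a", "chunk b", "chunk c"]
ranked = [(documents[r["index"]], r["relevance_score"]) for r in rerank_result["results"]]
assert ranked[0] == ("chunk c", 0.91)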
0  dsLightRag/ElasticSearch/Utils/__init__.py  Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
0  dsLightRag/ElasticSearch/__init__.py  Normal file
BIN  dsLightRag/ElasticSearch/__pycache__/__init__.cpython-310.pyc  Normal file
Binary file not shown.
203  dsLightRag/Routes/QA.py  Normal file
@@ -0,0 +1,203 @@
import json
import logging
import time
import uuid

import fastapi
import jieba
from fastapi import APIRouter
from fastapi import HTTPException
from openai import AsyncOpenAI
from sse_starlette.sse import EventSourceResponse

from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Create the router
router = APIRouter(prefix="/qa", tags=["答疑"])

# Configure logging
logger = logging.getLogger(__name__)

# Initialize the async OpenAI client
client = AsyncOpenAI(
    api_key=Config.ALY_LLM_API_KEY,
    base_url=Config.ALY_LLM_BASE_URL
)
# Initialize the ElasticSearch utility
search_util = EsSearchUtil(Config.ES_CONFIG)


@router.post("/chat")
async def chat(request: fastapi.Request):
    """
    Look up related conversation history for the user's query,
    then call the LLM to generate an answer
    """
    try:
        data = await request.json()
        user_id = data.get('user_id', 'anonymous')
        query = data.get('query', '')
        session_id = data.get('session_id', str(uuid.uuid4()))  # get or generate a session ID
        include_history = data.get('include_history', True)

        if not query:
            raise HTTPException(status_code=400, detail="Query must not be empty")

        # 1. Initialize the session history and student info
        if session_id not in search_util.conversation_history:
            search_util.conversation_history[session_id] = []

        # If there is no cached info for this student, load it from ES
        if user_id not in search_util.student_info:
            # Load student info from ES
            info_from_es = search_util.get_student_info_from_es(user_id)
            if info_from_es:
                search_util.student_info[user_id] = info_from_es
                logger.info(f"Loaded user {user_id} info from ES: {info_from_es}")
            else:
                search_util.student_info[user_id] = {}

        # 2. Extract student info with jieba segmentation
        search_util.extract_student_info(query, user_id)

        # Debug output
        logger.info(f"Current student info: {search_util.student_info.get(user_id, {})}")

        # Generate tags for the user query and store it in ES
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]

        # Extract keywords from the query as extra tags, using jieba segmentation
        try:
            seg_list = jieba.cut(query, cut_all=False)  # precise mode
            keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
            keywords = keywords[:5]
            tags.extend([f"keyword:{kw}" for kw in keywords])
            logger.info(f"Keywords extracted with jieba: {keywords}")
        except Exception as e:
            logger.error(f"Segmentation failed: {str(e)}")
            keywords = query.split()[:5]
            tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])

        # Store the query in ES
        try:
            search_util.insert_long_text_to_es(query, tags)
            logger.info(f"Query from user {user_id} stored in ES, tags: {tags}")
        except Exception as e:
            logger.error(f"Failed to store user query in ES: {str(e)}")

        # 3. Build the conversation history context
        history_context = ""
        if include_history and session_id in search_util.conversation_history:
            # Take the most recent conversation rounds
            recent_history = search_util.conversation_history[session_id][-search_util.MAX_HISTORY_ROUNDS:]
            if recent_history:
                history_context = "\n\nRecent conversation history, for reference:\n"
                for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
                    history_context += f"[Turn {i}] User: {user_msg}\n"
                    history_context += f"[Turn {i}] Teacher: {ai_msg}\n"

        # 4. Build the student info context
        student_context = ""
        if user_id in search_util.student_info and search_util.student_info[user_id]:
            student_context = "\n\nStudent profile:\n"
            for key, value in search_util.student_info[user_id].items():
                if key == 'grade':
                    student_context += f"- Grade: {value}\n"
                else:
                    student_context += f"- {key}: {value}\n"

        # 5. Build the prompt
        system_prompt = """
        You are an approachable teacher with flexible teaching methods who helps students master knowledge by guiding them to learn on their own.

        Strictly follow these teaching rules:
        1. Adapt to the student: if the student's grade level and background are known, adjust the content and difficulty accordingly.
        2. Build on existing knowledge: connect new ideas to what the student already knows.
        3. Guide rather than lecture: use questions, hints and small steps so the student discovers answers for themselves.
        4. Check and reinforce: after explaining a difficult point, confirm the student can restate or apply the concept.
        5. Vary the pace: mix explanation, questioning and interactive activities so the session feels like a conversation, not a lecture.

        Most importantly: do not hand out answers directly; help the student find them through collaboration and guidance built on what they already know.
        """

        # Append the student info to the system prompt
        if user_id in search_util.student_info and search_util.student_info[user_id]:
            student_info_str = "\n\nStudent profile:\n"
            for key, value in search_util.student_info[user_id].items():
                if key == 'grade':
                    student_info_str += f"- Grade: {value}\n"
                else:
                    student_info_str += f"- {key}: {value}\n"
            system_prompt += student_info_str

        # 6. Stream the LLM answer
        async def generate_response_stream():
            try:
                # Build the message list
                messages = [{'role': 'system', 'content': system_prompt.strip()}]

                # Add the student info (if any)
                if student_context:
                    messages.append({'role': 'user', 'content': student_context.strip()})

                # Add the conversation history (if any)
                if history_context:
                    messages.append({'role': 'user', 'content': history_context.strip()})

                # Add the current question
                messages.append({'role': 'user', 'content': query})

                stream = await client.chat.completions.create(
                    model=Config.ALY_LLM_MODEL_NAME,
                    messages=messages,
                    max_tokens=8000,
                    stream=True
                )

                # Collect the full answer so it can be saved afterwards
                full_answer = []
                async for chunk in stream:
                    if chunk.choices[0].delta.content:
                        full_answer.append(chunk.choices[0].delta.content)
                        yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"

                # Save the answer to ES and to the conversation history
                if full_answer:
                    answer_text = ''.join(full_answer)
                    search_util.extract_student_info(answer_text, user_id)
                    try:
                        # Tag the answer
                        answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
                        try:
                            seg_list = jieba.cut(answer_text, cut_all=False)
                            answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
                            answer_keywords = answer_keywords[:5]
                            answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
                        except Exception as e:
                            logger.error(f"Answer segmentation failed: {str(e)}")

                        search_util.insert_long_text_to_es(answer_text, answer_tags)
                        logger.info(f"Answer for user {user_id} stored in ES")

                        # Update the conversation history
                        search_util.conversation_history[session_id].append((query, answer_text))
                        # Keep the history within the maximum number of rounds
                        if len(search_util.conversation_history[session_id]) > search_util.MAX_HISTORY_ROUNDS:
                            search_util.conversation_history[session_id].pop(0)

                    except Exception as e:
                        logger.error(f"Failed to store the answer in ES: {str(e)}")

            except Exception as e:
                logger.error(f"LLM call failed: {str(e)}")
                yield f"data: {json.dumps({'error': f'Failed to generate an answer: {str(e)}'})}\n\n"

        return EventSourceResponse(generate_response_stream())

    except HTTPException as e:
        logger.error(f"Chat endpoint error: {str(e.detail)}")
        raise e
    except Exception as e:
        logger.error(f"Chat endpoint exception: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to process the request: {str(e)}")
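/qa/chat streams its answer as Server-Sent Events, one JSON object per data: line. A minimal client sketch (port 8100 comes from Start.py; user_id and the query text are example values):

# Sketch of consuming the /qa/chat SSE stream; not part of the commit.
import json
import requests

resp = requests.post(
    "http://localhost:8100/qa/chat",
    json={"user_id": "test_user_123", "query": "什么是水灰比?"},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    while line and line.startswith("data: "):  # tolerate doubled 'data:' framing
        line = line[len("data: "):]
    if line:
        payload = json.loads(line)
        print(payload.get("reply", payload.get("error", "")), end="")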
@@ -1,19 +1,20 @@
# pip install captcha
# routes/LoginController.py
import base64
import os
import json
import os
import random
import string
import jwt

import jwt
from captcha.image import ImageCaptcha
from fastapi import APIRouter, Request, Response, status, HTTPException
from fastapi import APIRouter

from Util.CommonUtil import *
from Util.CookieUtil import *
from Util.Database import *
from Util.JwtUtil import *
from Util.ParseRequest import *
from Config.Config import *

# Create a router instance
router = APIRouter()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -18,6 +18,7 @@ from Routes.TeachingModel.api.DmController import router as dm_router
from Routes.TeachingModel.api.ThemeController import router as theme_router
from Routes.TeachingModel.api.DocumentController import router as document_router
from Routes.TeachingModel.api.TeachingModelController import router as teaching_model_router
from Routes.QA import router as qa_router

from Util.LightRagUtil import *
from contextlib import asynccontextmanager
@@ -53,8 +54,8 @@ app.include_router(ggb_router)  # GeoGebra routes
app.include_router(rag_router)        # LightRAG routes
app.include_router(knowledge_router)  # knowledge graph routes
app.include_router(oss_router)        # Alibaba Cloud OSS routes

app.include_router(llm_router)        # LLM routes
app.include_router(qa_router)         # Q&A routes

# Teaching Model routes
# Login (no authentication required)
@@ -69,6 +70,9 @@ app.include_router(theme_router, prefix="/api/theme", tags=["theme"])
app.include_router(document_router, prefix="/api/document", tags=["document"])
# Questions (LLM application)
app.include_router(teaching_model_router, prefix="/api/teaching/model", tags=["teacher_model"])
# Teaching Q&A
app.include_router(teaching_model_router, prefix="/api/teaching/model", tags=["teacher_model"])


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8100)
BIN  dsLightRag/Util/__pycache__/CommonUtil.cpython-310.pyc  Normal file
Binary file not shown.
BIN  dsLightRag/Util/__pycache__/CookieUtil.cpython-310.pyc  Normal file
Binary file not shown.
BIN  dsLightRag/Util/__pycache__/Database.cpython-310.pyc  Normal file
Binary file not shown.
Binary file not shown.
BIN  dsLightRag/Util/__pycache__/JwtUtil.cpython-310.pyc  Normal file
Binary file not shown.
BIN  dsLightRag/Util/__pycache__/PageUtil.cpython-310.pyc  Normal file
Binary file not shown.
BIN  dsLightRag/Util/__pycache__/ParseRequest.cpython-310.pyc  Normal file
Binary file not shown.
BIN  dsLightRag/Util/__pycache__/TranslateUtil.cpython-310.pyc  Normal file
Binary file not shown.
1  dsSchoolBuddy/.idea/vcs.xml  generated
@@ -2,6 +2,5 @@
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
    <mapping directory="$PROJECT_DIR$/mtef-go-3" vcs="Git" />
  </component>
</project>
@@ -10,8 +10,8 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
logger = logging.getLogger(__name__)

# Server address
BASE_URL = "http://localhost:8000"
CHAT_ENDPOINT = f"{BASE_URL}/api/teaching_chat"
BASE_URL = "http://localhost:8100"
CHAT_ENDPOINT = f"{BASE_URL}/qa/chat"

# User ID (kept fixed to simulate a multi-round conversation)
USER_ID = "test_user_123"