'commit'
@@ -1,27 +0,0 @@
# Elasticsearch configuration
ES_CONFIG = {
    'hosts': ['https://localhost:9200'],
    'basic_auth': ('elastic', 'jv9h8uwRrRxmDi1dq6u8'),
    'verify_certs': False,
    'index_name': 'ds_db',  # default index name
    'student_info_index': 'student_info'  # index name used for student_info documents
}

# Embedding model
EMBED_MODEL_NAME = "BAAI/bge-m3"
EMBED_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl"
EMBED_BASE_URL = "https://api.siliconflow.cn/v1"
EMBED_DIM = 1024
EMBED_MAX_TOKEN_SIZE = 8192

# Rerank model
RERANK_MODEL = 'BAAI/bge-reranker-v2-m3'
RERANK_BASE_URL = 'https://api.siliconflow.cn/v1/rerank'
RERANK_BINDING_API_KEY = 'sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl'

# Aliyun LLM API settings [HZKJ]
ALY_LLM_API_KEY = "sk-01d13a39e09844038322108ecdbd1bbc"
ALY_LLM_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
ALY_LLM_MODEL_NAME = "qwen-plus"
# ALY_LLM_MODEL_NAME = "deepseek-r1"
# ALY_LLM_MODEL_NAME = "deepseek-v3"
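As a reference, here is a minimal sketch of how this ES_CONFIG dictionary could be handed to the official `elasticsearch` Python client; the client construction below is an assumption based on the keys above and is not part of the original commit.

# Hypothetical usage sketch (not in the original commit): build a client from ES_CONFIG.
from elasticsearch import Elasticsearch

from Config.Config import ES_CONFIG

es = Elasticsearch(
    hosts=ES_CONFIG['hosts'],
    basic_auth=ES_CONFIG['basic_auth'],      # (username, password) tuple
    verify_certs=ES_CONFIG['verify_certs'],  # False: skip TLS verification (dev only)
)
print(es.ping())  # True if the cluster is reachable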
Binary file not shown.
Binary file not shown.
@@ -1,32 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)


def rebuild_index(index_name):
    """
    Rebuild the specified index.

    Args:
        index_name: name of the index to rebuild

    Returns:
        bool: whether the operation succeeded
    """
    print(f"开始重建索引: {index_name}")
    if search_util.rebuild_mapping(index_name):
        print(f"重建索引 '{index_name}' 操作成功")
        return True
    else:
        print(f"重建索引 '{index_name}' 操作失败")
        return False


if __name__ == "__main__":
    # Rebuild the ds_db index
    rebuild_index(Config.ES_CONFIG['index_name'])

    # Rebuild the student_info index
    rebuild_index(Config.ES_CONFIG['student_info_index'])

    print("所有索引重建操作完成")
@@ -1,38 +0,0 @@
# pip install pydantic requests
from ElasticSearch.Utils.VectorDBUtil import VectorDBUtil


def main():
    # Simulated long document text
    long_text = """混凝土是一种广泛使用的建筑材料,由水泥、砂、石子和水混合而成。它具有高强度、耐久性和良好的可塑性,被广泛应用于建筑、桥梁、道路等土木工程领域。

混凝土的历史可以追溯到古罗马时期,当时人们使用火山灰、石灰和碎石混合制成类似混凝土的材料。现代混凝土技术始于19世纪,随着波特兰水泥的发明而得到快速发展。

混凝土的性能取决于其配合比,包括水灰比、砂率等参数。水灰比是影响混凝土强度的关键因素,较小的水灰比通常会产生更高强度的混凝土。

为了改善混凝土的性能,常常会添加各种外加剂,如减水剂、早强剂、缓凝剂等。此外,还可以使用纤维增强、聚合物改性等技术来提高混凝土的韧性和耐久性。

在施工过程中,混凝土需要适当的养护,以确保其强度正常发展。养护措施包括浇水、覆盖保湿、蒸汽养护等。

随着建筑技术的发展,高性能混凝土、自密实混凝土、再生骨料混凝土等新型混凝土不断涌现,为土木工程领域提供了更多的选择。"""

    # Create the utility instance
    vector_util = VectorDBUtil()

    # Store the text in the vector database
    vector_util.text_to_vector_db(long_text)

    # Query the vector database
    query = "混凝土"
    reranked_results = vector_util.query_vector_db(query, k=4)

    # Print all results with their relevance scores
    print("最终查询结果:")
    for i, (result, score) in enumerate(reranked_results):
        print(f"结果 {i+1} (可信度: {score:.4f}):")
        print(result.page_content)
        print("---")


if __name__ == "__main__":
    main()
@@ -1,28 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


def main():
    # Example 1: insert a single long text
    long_text = """混凝土是一种广泛使用的建筑材料,由水泥、砂、石子和水混合而成。它具有高强度、耐久性和良好的可塑性,被广泛应用于建筑、桥梁、道路等土木工程领域。

混凝土的历史可以追溯到古罗马时期,当时人们使用火山灰、石灰和碎石混合制成类似混凝土的材料。现代混凝土技术始于19世纪,随着波特兰水泥的发明而得到快速发展。

混凝土的性能取决于其配合比,包括水灰比、砂率等参数。水灰比是影响混凝土强度的关键因素,较小的水灰比通常会产生更高强度的混凝土。

为了改善混凝土的性能,常常会添加各种外加剂,如减水剂、早强剂、缓凝剂等。此外,还可以使用纤维增强、聚合物改性等技术来提高混凝土的韧性和耐久性。

在施工过程中,混凝土需要适当的养护,以确保其强度正常发展。养护措施包括浇水、覆盖保湿、蒸汽养护等。

随着建筑技术的发展,高性能混凝土、自密实混凝土、再生骨料混凝土等新型混凝土不断涌现,为土木工程领域提供了更多的选择。"""

    # Tag the document
    tags = ["student_110"]

    # Create an EsSearchUtil instance and insert the text
    search_util = EsSearchUtil(Config.ES_CONFIG)
    search_util.insert_long_text_to_es(long_text, tags)


if __name__ == "__main__":
    main()
@@ -1,28 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)


# Query all documents in an index
def select_all_data(index_name):
    try:
        # Call the select_all_data method on EsSearchUtil
        response = search_util.select_all_data()
        hits = response['hits']['hits']

        if not hits:
            print(f"索引 {index_name} 中没有数据")
        else:
            print(f"索引 {index_name} 中共有 {len(hits)} 条数据:")
            for i, hit in enumerate(hits, 1):
                print(f"{i}. ID: {hit['_id']}")
                print(f"   内容: {hit['_source'].get('user_input', '')}")
                print(f"   标签: {hit['_source'].get('tags', '')}")
                print(f"   时间戳: {hit['_source'].get('timestamp', '')}")
                print("-" * 50)
    except Exception as e:
        print(f"查询出错: {e}")


if __name__ == "__main__":
    select_all_data(Config.ES_CONFIG['index_name'])
@@ -1,28 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


# 1. Create an EsSearchUtil instance (wraps the connection pool)
search_util = EsSearchUtil(Config.ES_CONFIG)

# 2. Keyword to search for, hard-coded for this demo
query_keyword = "混凝土"

# 3. Run the query and handle the results
try:
    # Search through the connection pool
    results = search_util.text_search(query_keyword, size=1000)
    print(f"查询关键字 '{query_keyword}' 结果:")
    if results['hits']['hits']:
        for i, hit in enumerate(results['hits']['hits'], 1):
            doc = hit['_source']
            print(f"{i}. ID: {hit['_id']}")
            print(f"   标签: {doc['tags']['tags'] if 'tags' in doc and 'tags' in doc['tags'] else '无'}")
            print(f"   用户问题: {doc['user_input'] if 'user_input' in doc else '无'}")
            print(f"   时间: {doc['timestamp'] if 'timestamp' in doc else '无'}")
            print("-" * 50)
    else:
        print(f"未找到包含 '{query_keyword}' 的数据。")
except Exception as e:
    print(f"查询失败: {e}")
    print(f"查询关键字: {query_keyword}")
@@ -1,29 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil


def main():
    # Initialise the search utility
    search_util = EsSearchUtil(Config.ES_CONFIG)

    # Query text
    query = "混凝土"
    print(f"查询文本: {query}")

    # Embed the query
    query_embedding = search_util.get_query_embedding(query)
    print(f"查询向量维度: {len(query_embedding)}")

    # Vector search
    search_results = search_util.search_by_vector(query_embedding, k=10)
    print(f"向量搜索结果数量: {len(search_results)}")

    # Rerank the results
    reranked_results = search_util.rerank_results(query, search_results)

    # Display the results
    search_util.display_results(reranked_results)


if __name__ == "__main__":
    main()
@@ -1,66 +0,0 @@
import logging

from Config.Config import ES_CONFIG
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Initialise logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


if __name__ == "__main__":
    # Initialise EsSearchUtil
    search_util = EsSearchUtil(ES_CONFIG)

    # Read the user's query
    user_query = input("请输入查询语句(例如:高性能的混凝土): ")
    if not user_query:
        user_query = "高性能的混凝土"
        print(f"未输入查询语句,使用默认值: {user_query}")

    query_tags = []  # tag filters can be added here if needed

    print(f"\n=== 开始执行查询 ===")
    print(f"原始查询文本: {user_query}")

    try:
        # 1. Vector search
        print("\n=== 向量搜索阶段 ===")
        print("1. 文本向量化处理中...")
        query_embedding = search_util.get_query_embedding(user_query)
        print(f"2. 生成的查询向量维度: {len(query_embedding)}")
        print(f"3. 前3维向量值: {query_embedding[:3]}")

        print("4. 正在执行Elasticsearch向量搜索...")
        vector_hits = search_util.search_by_vector(query_embedding, k=5)
        print(f"5. 向量搜索结果数量: {len(vector_hits)}")

        # Rerank the vector results
        print("6. 正在进行向量结果重排...")
        reranked_vector_results = search_util.rerank_results(user_query, vector_hits)
        print(f"7. 重排后向量结果数量: {len(reranked_vector_results)}")

        # 2. Keyword search
        print("\n=== 关键字搜索阶段 ===")
        print("1. 正在执行Elasticsearch关键字搜索...")
        keyword_results = search_util.text_search(user_query, size=5)
        keyword_hits = keyword_results['hits']['hits']
        print(f"2. 关键字搜索结果数量: {len(keyword_hits)}")

        # 3. Merge the two result sets
        print("\n=== 合并搜索结果 ===")
        # Attach scores to the keyword results
        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
        merged_results = search_util.merge_results(keyword_results_with_scores, reranked_vector_results)
        print(f"合并后唯一结果数量: {len(merged_results)}")

        # 4. Print the final results
        print("\n=== 最终搜索结果 ===")
        for i, (doc, score, source) in enumerate(merged_results, 1):
            print(f"{i}. 文档ID: {doc['_id']}, 分数: {score:.2f}, 来源: {source}")
            print(f"   内容: {doc['_source']['user_input']}")
            print(" --- ")

    except Exception as e:
        logger.error(f"搜索过程中发生错误: {str(e)}")
        print(f"搜索失败: {str(e)}")
@@ -1,110 +0,0 @@
from elasticsearch.exceptions import NotFoundError
import logging

logger = logging.getLogger(__name__)


class ElasticsearchCollectionManager:
    def __init__(self, index_name):
        """
        Initialise the Elasticsearch index manager.
        :param index_name: Elasticsearch index name
        """
        self.index_name = index_name

    def load_collection(self, es_connection):
        """
        Load the index, creating it if it does not exist.
        :param es_connection: Elasticsearch connection
        """
        try:
            if not es_connection.indices.exists(index=self.index_name):
                logger.warning(f"Index {self.index_name} does not exist, creating new index")
                self._create_index(es_connection)
        except Exception as e:
            logger.error(f"Failed to load collection: {str(e)}")
            raise

    def _create_index(self, es_connection):
        """Create a new Elasticsearch index."""
        mapping = {
            "mappings": {
                "properties": {
                    "user_input": {"type": "text"},
                    "tags": {
                        "type": "object",
                        "properties": {
                            "tags": {"type": "keyword"},
                            "full_content": {"type": "text"}
                        }
                    },
                    "timestamp": {"type": "date"},
                    "embedding": {"type": "dense_vector", "dims": 200}
                }
            }
        }
        es_connection.indices.create(index=self.index_name, body=mapping)

    def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
        """
        Run a hybrid (vector + keyword) search.
        :param es_connection: Elasticsearch connection
        :param query_embedding: query vector
        :param search_params: search parameters
        :param expr: filter expression
        :param limit: number of results to return
        :return: search hits
        """
        try:
            # Build the script_score query
            query = {
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "must": []
                            }
                        },
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": limit
            }

            # Add tag filter conditions
            if expr:
                query["query"]["script_score"]["query"]["bool"]["must"].append(
                    {"nested": {
                        "path": "tags",
                        "query": {
                            "terms": {"tags.tags": expr.split(" OR ")}
                        }
                    }}
                )

            logger.info(f"Executing search with query: {query}")
            response = es_connection.search(index=self.index_name, body=query)
            return response["hits"]["hits"]
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            raise

    def query_by_id(self, es_connection, doc_id):
        """
        Fetch a document by ID.
        :param es_connection: Elasticsearch connection
        :param doc_id: document ID
        :return: document source, or None if not found
        """
        try:
            response = es_connection.get(index=self.index_name, id=doc_id)
            return response["_source"]
        except NotFoundError:
            logger.warning(f"Document with id {doc_id} not found")
            return None
        except Exception as e:
            logger.error(f"Failed to query document by id: {str(e)}")
            raise
@@ -1,65 +0,0 @@
from elasticsearch import Elasticsearch
import threading
import logging

logger = logging.getLogger(__name__)


class ElasticsearchConnectionPool:
    def __init__(self, hosts, basic_auth, verify_certs=False, max_connections=50):
        """
        Initialise the Elasticsearch connection pool.
        :param hosts: Elasticsearch server addresses
        :param basic_auth: credentials as (username, password)
        :param verify_certs: whether to verify SSL certificates
        :param max_connections: maximum number of pooled connections
        """
        self.hosts = hosts
        self.basic_auth = basic_auth
        self.verify_certs = verify_certs
        self.max_connections = max_connections
        self._connections = []
        self._lock = threading.Lock()
        self._initialize_pool()

    def _initialize_pool(self):
        """Pre-create the pooled connections."""
        for _ in range(self.max_connections):
            self._connections.append(self._create_connection())

    def _create_connection(self):
        """Create a new Elasticsearch connection."""
        return Elasticsearch(
            hosts=self.hosts,
            basic_auth=self.basic_auth,
            verify_certs=self.verify_certs
        )

    def get_connection(self):
        """Take a connection from the pool."""
        with self._lock:
            if not self._connections:
                logger.warning("Connection pool exhausted, creating new connection")
                return self._create_connection()
            return self._connections.pop()

    def release_connection(self, connection):
        """Return a connection to the pool."""
        with self._lock:
            if len(self._connections) < self.max_connections:
                self._connections.append(connection)
            else:
                try:
                    connection.close()
                except Exception as e:
                    logger.warning(f"Failed to close connection: {str(e)}")

    def close(self):
        """Close all pooled connections."""
        with self._lock:
            for conn in self._connections:
                try:
                    conn.close()
                except Exception as e:
                    logger.warning(f"Failed to close connection: {str(e)}")
            self._connections.clear()
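A minimal usage sketch of the pool above: take a connection, use it, and always hand it back in a finally block. The index name and query below are placeholders and the snippet is not part of the original commit.

# Hypothetical usage sketch of ElasticsearchConnectionPool (not in the original commit).
from Config.Config import ES_CONFIG
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool

pool = ElasticsearchConnectionPool(
    hosts=ES_CONFIG['hosts'],
    basic_auth=ES_CONFIG['basic_auth'],
    verify_certs=ES_CONFIG['verify_certs'],
    max_connections=10,
)
conn = pool.get_connection()
try:
    # Any elasticsearch-py call works on the pooled client, e.g. a match_all probe.
    resp = conn.search(index=ES_CONFIG['index_name'], query={"match_all": {}}, size=1)
    print(resp['hits']['total'])
finally:
    pool.release_connection(conn)  # always return the connection to the pool
pool.close()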
@@ -1,664 +0,0 @@
import json
import logging
import warnings
import hashlib
import time

import jieba
import requests

from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from Config import Config
from typing import List, Tuple, Dict

# Initialise logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class EsSearchUtil:
    # Conversation history keyed by session ID; each value is a list of (user, assistant) turns
    conversation_history = {}

    # Student information keyed by user ID
    student_info = {}

    # Grade keyword dictionary
    GRADE_KEYWORDS = {
        '一年级': ['一年级', '初一'],
        '二年级': ['二年级', '初二'],
        '三年级': ['三年级', '初三'],
        '四年级': ['四年级'],
        '五年级': ['五年级'],
        '六年级': ['六年级'],
        '七年级': ['七年级', '初一'],
        '八年级': ['八年级', '初二'],
        '九年级': ['九年级', '初三'],
        '高一': ['高一'],
        '高二': ['高二'],
        '高三': ['高三']
    }

    # Maximum number of conversation rounds kept in history
    MAX_HISTORY_ROUNDS = 10

    # Stopword list
    STOPWORDS = set(
        ['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说',
         '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这'])

    def __init__(self, es_config):
        """
        Initialise the Elasticsearch search utility.
        :param es_config: Elasticsearch config dict containing hosts, basic_auth, index_name, etc.
        """
        # Suppress HTTPS-related warnings
        warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
        warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')

        self.es_config = es_config

        # Initialise the connection pool
        self.es_pool = ElasticsearchConnectionPool(
            hosts=es_config['hosts'],
            basic_auth=es_config['basic_auth'],
            verify_certs=es_config.get('verify_certs', False),
            max_connections=50
        )

        self.index_name = es_config['index_name']
        logger.info(f"EsSearchUtil初始化成功,索引名称: {self.index_name}")

    def rebuild_mapping(self, index_name=None):
        """
        Rebuild the Elasticsearch index and its mapping.

        Args:
            index_name: optional index name to rebuild; defaults to the one set at init time

        Returns:
            bool: whether the operation succeeded
        """
        conn = None
        # Use the given index name, or fall back to the default one
        target_index = index_name if index_name else self.index_name
        try:
            # Get a connection from the pool
            conn = self.es_pool.get_connection()
            logger.info(f"开始重建索引: {target_index}")

            # Define the mapping
            if target_index == 'student_info':
                mapping = {
                    "mappings": {
                        "properties": {
                            "user_id": {"type": "keyword"},
                            "grade": {"type": "keyword"},
                            "recent_questions": {"type": "text"},
                            "learned_knowledge": {"type": "text"},
                            "updated_at": {"type": "date"}
                        }
                    }
                }
            else:
                mapping = {
                    "mappings": {
                        "properties": {
                            "embedding": {
                                "type": "dense_vector",
                                "dims": Config.EMBED_DIM,
                                "index": True,
                                "similarity": "l2_norm"
                            },
                            "user_input": {"type": "text"},
                            "tags": {
                                "type": "object",
                                "properties": {
                                    "tags": {"type": "keyword"},
                                    "full_content": {"type": "text"}
                                }
                            }
                        }
                    }
                }

            # Delete the index if it already exists
            if conn.indices.exists(index=target_index):
                conn.indices.delete(index=target_index)
                logger.info(f"删除已存在的索引 '{target_index}'")
                print(f"删除已存在的索引 '{target_index}'")

            # Create the index with the mapping
            conn.indices.create(index=target_index, body=mapping)
            logger.info(f"索引 '{target_index}' 创建成功,mapping结构已设置")
            print(f"索引 '{target_index}' 创建成功,mapping结构已设置。")

            return True
        except Exception as e:
            logger.error(f"重建索引 '{target_index}' 失败: {str(e)}")
            print(f"重建索引 '{target_index}' 失败: {e}")

            # Give a concrete hint for authentication errors
            if 'AuthenticationException' in str(e):
                print("认证失败提示: 请检查Config.py中的ES_CONFIG配置,确保用户名和密码正确。")
                logger.error("认证失败: 请检查Config.py中的ES_CONFIG配置,确保用户名和密码正确。")

            return False
        finally:
            # Return the connection to the pool
            if conn is not None:
                self.es_pool.release_connection(conn)

    def text_search(self, query, size=10):
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Run a match query on user_input
            result = conn.search(
                index=self.es_config['index_name'],
                query={"match": {"user_input": query}},
                size=size
            )
            return result
        except Exception as e:
            logger.error(f"文本搜索失败: {str(e)}")
            raise
        finally:
            # Return the connection to the pool
            self.es_pool.release_connection(conn)

    def select_all_data(self, size=1000):
        """
        Return all documents in the index.

        Args:
            size: maximum number of results, default 1000

        Returns:
            dict: the raw query response
        """
        # Get a connection from the pool
        conn = self.es_pool.get_connection()
        try:
            # Build a match_all query
            query = {
                "query": {
                    "match_all": {}
                },
                "size": size
            }

            # Execute the query
            response = conn.search(index=self.es_config['index_name'], body=query)
            return response
        except Exception as e:
            logger.error(f"查询所有数据失败: {str(e)}")
            raise
        finally:
            # Return the connection to the pool
            self.es_pool.release_connection(conn)

    def split_text_into_chunks(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> list:
        """
        Split a text into chunks.

        Args:
            text: text to split
            chunk_size: size of each chunk
            chunk_overlap: overlap between chunks

        Returns:
            list: list of text chunks
        """
        # Wrap the text in a Document
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]

        # Split the document
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        print(f"切割后的文档块数量:{len(all_splits)}")

        return [split.page_content for split in all_splits]

    def insert_long_text_to_es(self, long_text: str, tags: list = None) -> bool:
        """
        Split a long text, embed each chunk and insert it into Elasticsearch,
        deduplicating on a hash of the chunk content.

        Args:
            long_text: text to insert
            tags: optional list of tags

        Returns:
            bool: whether the insert succeeded
        """
        conn = None
        try:
            # 1. Get a connection from the pool
            conn = self.es_pool.get_connection()

            # 2. Target index (its mapping is expected to have been created via rebuild_mapping)
            index_name = Config.ES_CONFIG['index_name']

            # 3. Split the text into chunks
            text_chunks = self.split_text_into_chunks(long_text)

            # 4. Prepare tags
            if tags is None:
                tags = ["general_text"]

            # 5. Current timestamp
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

            # 6. Create the embedding model
            embeddings = OpenAIEmbeddings(
                model=Config.EMBED_MODEL_NAME,
                base_url=Config.EMBED_BASE_URL,
                api_key=SecretStr(Config.EMBED_API_KEY)
            )

            # 7. Embed and insert each chunk
            for i, chunk in enumerate(text_chunks):
                # Use the chunk's MD5 hash as the document ID
                doc_id = hashlib.md5(chunk.encode('utf-8')).hexdigest()

                # Skip chunks that are already indexed
                if conn.exists(index=index_name, id=doc_id):
                    print(f"文档块 {i+1} 已存在,跳过插入: {doc_id}")
                    continue

                # Embed the chunk
                embedding = embeddings.embed_documents([chunk])[0]

                # Prepare the document
                doc = {
                    'tags': {"tags": tags, "full_content": long_text},
                    'user_input': chunk,
                    'timestamp': timestamp,
                    'embedding': embedding
                }

                # Insert the document into Elasticsearch
                conn.index(index=index_name, id=doc_id, document=doc)
                print(f"文档块 {i+1} 插入成功: {doc_id}")

            return True
        except Exception as e:
            print(f"插入数据失败: {e}")
            return False
        finally:
            # Always return the connection to the pool
            if conn is not None:
                self.es_pool.release_connection(conn)

    def get_query_embedding(self, query: str) -> list:
        """
        Embed a query string.

        Args:
            query: query text

        Returns:
            list: the embedding vector
        """
        # Create the embedding model
        embeddings = OpenAIEmbeddings(
            model=Config.EMBED_MODEL_NAME,
            base_url=Config.EMBED_BASE_URL,
            api_key=SecretStr(Config.EMBED_API_KEY)
        )

        # Embed the query
        query_embedding = embeddings.embed_query(query)
        return query_embedding

    def rerank_results(self, query: str, results: list) -> list:
        """
        Rerank search results with the rerank model.

        Args:
            query: query text
            results: list of search hits

        Returns:
            list: reranked results as (document, score) tuples
        """
        if not results:
            print("警告: 没有搜索结果可供重排")
            return []

        try:
            # Prepare the rerank request; keep only hits that have '_source' with 'user_input'
            documents = []
            valid_results = []
            for i, doc in enumerate(results):
                if isinstance(doc, dict) and '_source' in doc and 'user_input' in doc['_source']:
                    documents.append(doc['_source']['user_input'])
                    valid_results.append(doc)
                else:
                    print(f"警告: 结果项 {i} 格式不正确,跳过该结果")
                    print(f"结果项内容: {doc}")

            if not documents:
                print("警告: 没有有效的文档可供重排")
                # Fall back to the original results as (hit, score) tuples
                return [(doc, doc.get('_score', 0.0)) for doc in results]

            rerank_data = {
                "model": Config.RERANK_MODEL,
                "query": query,
                "documents": documents,
                "top_n": len(documents)
            }

            # Call the rerank API
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}"
            }

            response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
            response.raise_for_status()  # raise on HTTP errors
            rerank_result = response.json()

            # Map the rerank response back onto the original hits
            reranked_results = []
            if "results" in rerank_result:
                for item in rerank_result["results"]:
                    doc_idx = item.get("index")
                    score = item.get("relevance_score", 0.0)
                    if 0 <= doc_idx < len(valid_results):
                        result = valid_results[doc_idx]
                        reranked_results.append((result, score))
            else:
                print("警告: 无法识别重排API响应格式")
                # Fall back to the original results as (hit, score) tuples
                reranked_results = [(doc, doc.get('_score', 0.0)) for doc in valid_results]

            print(f"重排后结果数量:{len(reranked_results)}")
            return reranked_results

        except Exception as e:
            print(f"重排失败: {e}")
            print("将使用原始搜索结果")
            # Fall back to the original results as (hit, score) tuples
            return [(doc, doc.get('_score', 0.0)) for doc in results]

    def search_by_vector(self, query_embedding: list, k: int = 10) -> list:
        """
        Run a similarity search with a query vector.

        Args:
            query_embedding: query vector
            k: number of results to return

        Returns:
            list: search hits
        """
        conn = None
        try:
            # Get a connection from the pool
            conn = self.es_pool.get_connection()
            index_name = Config.ES_CONFIG['index_name']

            # Execute the vector search via script_score over the 'embedding' field
            response = conn.search(
                index=index_name,
                body={
                    "query": {
                        "script_score": {
                            "query": {"match_all": {}},
                            "script": {
                                "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                                "params": {
                                    "query_vector": query_embedding
                                }
                            }
                        }
                    },
                    "size": k
                }
            )

            # Extract hits.hits from the response
            if 'hits' in response and 'hits' in response['hits']:
                results = response['hits']['hits']
                print(f"向量搜索结果数量: {len(results)}")
                return results
            else:
                print("警告: 向量搜索响应格式不正确")
                print(f"响应内容: {response}")
                return []

        except Exception as e:
            print(f"向量搜索失败: {e}")
            return []
        finally:
            # Return the connection to the pool
            if conn is not None:
                self.es_pool.release_connection(conn)

    def display_results(self, results: list, show_score: bool = True) -> None:
        """
        Print search results.

        Args:
            results: list of results
            show_score: whether to print the score
        """
        if not results:
            print("没有找到匹配的结果。")
            return

        print(f"找到 {len(results)} 条结果:\n")
        for i, item in enumerate(results, 1):
            print(f"结果 {i}:")
            try:
                # Items may be (result, score) tuples or bare hits
                if isinstance(item, tuple):
                    if len(item) >= 2:
                        result, score = item[0], item[1]
                    else:
                        result, score = item[0], 0.0
                else:
                    # Not a tuple: treat the item itself as the hit
                    result = item
                    score = result.get('_score', 0.0)

                # Only dict results can be displayed
                if not isinstance(result, dict):
                    print(f"警告: 结果项 {i} 不是字典类型,跳过显示")
                    print(f"结果项内容: {result}")
                    print("---")
                    continue

                # Try to print the user_input content
                if '_source' in result and 'user_input' in result['_source']:
                    content = result['_source']['user_input']
                    print(f"内容: {content}")
                elif 'user_input' in result:
                    content = result['user_input']
                    print(f"内容: {content}")
                else:
                    print(f"警告: 结果项 {i} 缺少'user_input'字段")
                    print(f"结果项内容: {result}")
                    print("---")
                    continue

                # Print the score
                if show_score:
                    print(f"分数: {score:.4f}")

                # Print tags if present
                if '_source' in result and 'tags' in result['_source']:
                    tags = result['_source']['tags']
                    if isinstance(tags, dict) and 'tags' in tags:
                        print(f"标签: {tags['tags']}")

            except Exception as e:
                print(f"处理结果项 {i} 时出错: {str(e)}")
                print(f"结果项内容: {item}")
            print("---")

    def merge_results(self, keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
        """
        Merge keyword-search and vector-search results.

        Args:
            keyword_results: keyword results as (document, score) tuples
            vector_results: vector results as (document, score) tuples

        Returns:
            list: merged results as (document, score, source) tuples
        """
        # Tag each result with its source and combine
        all_results = []
        for doc, score in keyword_results:
            all_results.append((doc, score, "关键字搜索"))
        for doc, score in vector_results:
            all_results.append((doc, score, "向量搜索"))

        # Deduplicate by document ID, keeping the highest score
        unique_results = {}
        for doc, score, source in all_results:
            doc_id = doc['_id']
            if doc_id not in unique_results or score > unique_results[doc_id][1]:
                unique_results[doc_id] = (doc, score, source)

        # Sort by score, descending
        sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
        return sorted_results

    # Save student information to Elasticsearch
    def save_student_info_to_es(self, user_id, info):
        """Persist a student's information to Elasticsearch."""
        try:
            # Use the user ID as the document ID
            doc_id = f"student_{user_id}"
            # Prepare the document
            doc = {
                "user_id": user_id,
                "info": info,
                "update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            }
            # Get a connection from the pool
            es_conn = self.es_pool.get_connection()
            try:
                # Index the student_info document (created on first write)
                es_conn.index(index="student_info", id=doc_id, document=doc)
                logger.info(f"学生 {user_id} 的信息已保存到ES: {info}")
            finally:
                # Return the connection to the pool
                self.es_pool.release_connection(es_conn)
        except Exception as e:
            logger.error(f"保存学生信息到ES失败: {str(e)}", exc_info=True)

    # Load student information from Elasticsearch
    def get_student_info_from_es(self, user_id):
        """Fetch a student's information from Elasticsearch."""
        try:
            doc_id = f"student_{user_id}"
            # Get a connection from the pool
            es_conn = self.es_pool.get_connection()
            try:
                # Only query if the student_info index exists
                if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
                    result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
                    if result and '_source' in result:
                        logger.info(f"从ES获取到学生 {user_id} 的信息: {result['_source']['info']}")
                        return result['_source']['info']
                    else:
                        logger.info(f"ES中没有找到学生 {user_id} 的信息")
                else:
                    logger.info("student_info索引不存在")
            finally:
                # Return the connection to the pool
                self.es_pool.release_connection(es_conn)
        except Exception as e:
            # Treat a missing document as "no info"
            if "not_found" in str(e).lower():
                logger.info(f"学生 {user_id} 的信息在ES中不存在")
                return {}
            logger.error(f"从ES获取学生信息失败: {str(e)}", exc_info=True)
            return {}
        return {}

    def extract_student_info(self, text, user_id):
        """Extract student information from free text using jieba segmentation."""
        try:
            # Segment the text to look for grade keywords
            seg_list = jieba.cut(text, cut_all=False)  # accurate mode
            seg_set = set(seg_list)

            # Load cached student info from ES the first time this user is seen
            if user_id not in self.student_info:
                info_from_es = self.get_student_info_from_es(user_id)
                if info_from_es:
                    self.student_info[user_id] = info_from_es
                    logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
                else:
                    self.student_info[user_id] = {}

            # Extract and update the grade
            grade_found = False
            for grade, keywords in self.GRADE_KEYWORDS.items():
                for keyword in keywords:
                    if keyword in seg_set:
                        if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
                            self.student_info[user_id]['grade'] = grade
                            logger.info(f"提取到用户 {user_id} 的年级信息: {grade}")
                            # Persist to ES
                            self.save_student_info_to_es(user_id, self.student_info[user_id])
                        grade_found = True
                        break
                if grade_found:
                    break

            # If no keyword matched, try to pull a numeric grade directly
            if not grade_found:
                import re
                # Match the "我是X年级" pattern
                match = re.search(r'我是(\d+)年级', text)
                if match:
                    grade_num = match.group(1)
                    grade = f"{grade_num}年级"
                    if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
                        self.student_info[user_id]['grade'] = grade
                        logger.info(f"通过正则提取到用户 {user_id} 的年级信息: {grade}")
                        # Persist to ES
                        self.save_student_info_to_es(user_id, self.student_info[user_id])
        except Exception as e:
            logger.error(f"提取学生信息失败: {str(e)}", exc_info=True)
@@ -1,125 +0,0 @@
# pip install pydantic requests
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import SecretStr
import requests
import json
from Config.Config import (
    EMBED_MODEL_NAME, EMBED_BASE_URL, EMBED_API_KEY,
    RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY
)


class VectorDBUtil:
    """Vector-database helper: stores text as vectors and queries it."""

    def __init__(self):
        """Initialise the vector-database helper."""
        # Embedding model
        self.embeddings = OpenAIEmbeddings(
            model=EMBED_MODEL_NAME,
            base_url=EMBED_BASE_URL,
            api_key=SecretStr(EMBED_API_KEY)  # wrap as SecretStr
        )
        # Vector store is created lazily in text_to_vector_db
        self.vector_store = None

    def text_to_vector_db(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> tuple:
        """
        Store a text in the vector database.

        Args:
            text: text to store
            chunk_size: chunk size for splitting
            chunk_overlap: overlap between chunks

        Returns:
            tuple: (vector store, document count, number of chunks)
        """
        # Wrap the text in a Document
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
        doc_count = len(docs)
        print(f"文档数量:{doc_count} 个")

        # Split the document
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        split_count = len(all_splits)
        print(f"切割后的文档块数量:{split_count}")

        # Store the chunks as vectors
        self.vector_store = InMemoryVectorStore(self.embeddings)
        ids = self.vector_store.add_documents(documents=all_splits)

        return self.vector_store, doc_count, split_count

    def query_vector_db(self, query: str, k: int = 4) -> list:
        """
        Query the vector database.

        Args:
            query: query string
            k: number of results to return

        Returns:
            list: reranked results as (document, relevance score) tuples
        """
        if not self.vector_store:
            print("错误: 向量数据库未初始化,请先调用text_to_vector_db方法")
            return []

        # Vector search; fetch k results to feed the reranker
        results = self.vector_store.similarity_search(query, k=k)
        print(f"向量搜索结果数量:{len(results)}")

        # Reranked documents with scores
        reranked_docs_with_scores = []

        # Call the rerank model
        if len(results) > 1:
            # Prepare the rerank request
            rerank_data = {
                "model": RERANK_MODEL,
                "query": query,
                "documents": [doc.page_content for doc in results],
                "top_n": len(results)
            }

            # Call the SiliconFlow rerank API
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
            }

            try:
                response = requests.post(RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
                response.raise_for_status()  # raise on HTTP errors
                rerank_result = response.json()

                # Extract relevance_score for each reranked document
                if "results" in rerank_result:
                    for item in rerank_result["results"]:
                        doc_idx = item.get("index")
                        score = item.get("relevance_score", 0.0)
                        if 0 <= doc_idx < len(results):
                            reranked_docs_with_scores.append((results[doc_idx], score))
                else:
                    print("警告: 无法识别重排API响应格式")
                    reranked_docs_with_scores = [(doc, 0.0) for doc in results]

                print(f"重排后结果数量:{len(reranked_docs_with_scores)}")
            except Exception as e:
                print(f"重排模型调用失败: {e}")
                print("将使用原始搜索结果")
                reranked_docs_with_scores = [(doc, 0.0) for doc in results]
        else:
            # A single result needs no reranking; give it full confidence
            reranked_docs_with_scores = [(doc, 1.0) for doc in results]

        return reranked_docs_with_scores
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,212 +0,0 @@
# pip install jieba
import json
import logging
import time
import jieba
import fastapi
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse
import uuid

from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Initialise logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Async OpenAI-compatible client
client = AsyncOpenAI(
    api_key=Config.ALY_LLM_API_KEY,
    base_url=Config.ALY_LLM_BASE_URL
)

# ElasticSearch utility
search_util = EsSearchUtil(Config.ES_CONFIG)


@asynccontextmanager
async def lifespan(_: FastAPI):
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/api/teaching_chat")
async def teaching_chat(request: fastapi.Request):
    """
    Look up related conversation history for the user's query,
    then call the LLM to answer, streaming the reply.
    """
    try:
        data = await request.json()
        user_id = data.get('user_id', 'anonymous')
        query = data.get('query', '')
        session_id = data.get('session_id', str(uuid.uuid4()))  # reuse or generate a session ID
        include_history = data.get('include_history', True)

        if not query:
            raise HTTPException(status_code=400, detail="查询内容不能为空")

        # 1. Initialise conversation history and student info
        if session_id not in search_util.conversation_history:
            search_util.conversation_history[session_id] = []

        # Load student info from ES if this user has not been seen yet
        if user_id not in search_util.student_info:
            info_from_es = search_util.get_student_info_from_es(user_id)
            if info_from_es:
                search_util.student_info[user_id] = info_from_es
                logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
            else:
                search_util.student_info[user_id] = {}

        # 2. Extract student info from the query with jieba
        search_util.extract_student_info(query, user_id)

        # Debug output
        logger.info(f"当前学生信息: {search_util.student_info.get(user_id, {})}")

        # Generate tags for the query and store it in ES
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]

        # Extract keywords from the query as extra tags, using jieba
        try:
            seg_list = jieba.cut(query, cut_all=False)  # accurate mode
            keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
            keywords = keywords[:5]
            tags.extend([f"keyword:{kw}" for kw in keywords])
            logger.info(f"使用jieba分词提取的关键词: {keywords}")
        except Exception as e:
            logger.error(f"分词失败: {str(e)}")
            keywords = query.split()[:5]
            tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])

        # Store the query in ES
        try:
            search_util.insert_long_text_to_es(query, tags)
            logger.info(f"用户 {user_id} 的查询已存储到ES,标签: {tags}")
        except Exception as e:
            logger.error(f"存储用户查询到ES失败: {str(e)}")

        # 3. Build the conversation-history context
        history_context = ""
        if include_history and session_id in search_util.conversation_history:
            # Take the most recent rounds
            recent_history = search_util.conversation_history[session_id][-search_util.MAX_HISTORY_ROUNDS:]
            if recent_history:
                history_context = "\n\n以下是最近的对话历史,可供参考:\n"
                for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
                    history_context += f"[对话 {i}] 用户: {user_msg}\n"
                    history_context += f"[对话 {i}] 老师: {ai_msg}\n"

        # 4. Build the student-info context
        student_context = ""
        if user_id in search_util.student_info and search_util.student_info[user_id]:
            student_context = "\n\n学生基础信息:\n"
            for key, value in search_util.student_info[user_id].items():
                if key == 'grade':
                    student_context += f"- 年级: {value}\n"
                else:
                    student_context += f"- {key}: {value}\n"

        # 5. Build the prompt
        system_prompt = """
你是一位平易近人且教学方法灵活的教师,通过引导学生自主学习来帮助他们掌握知识。

严格遵循以下教学规则:
1. 基于学生情况调整教学:如果已了解学生的年级水平和知识背景,应基于此调整教学内容和难度。
2. 基于现有知识构建:将新思想与学生已有的知识联系起来。
3. 引导而非灌输:使用问题、提示和小步骤,让学生自己发现答案。
4. 检查和强化:在讲解难点后,确认学生能够重述或应用这些概念。
5. 变化节奏:混合讲解、提问和互动活动,让教学像对话而非讲座。

最重要的是:不要直接给出答案,而是通过合作和基于学生已有知识的引导,帮助学生自己找到答案。
"""

        # Append student info to the system prompt
        if user_id in search_util.student_info and search_util.student_info[user_id]:
            student_info_str = "\n\n学生基础信息:\n"
            for key, value in search_util.student_info[user_id].items():
                if key == 'grade':
                    student_info_str += f"- 年级: {value}\n"
                else:
                    student_info_str += f"- {key}: {value}\n"
            system_prompt += student_info_str

        # 6. Stream the LLM answer
        async def generate_response_stream():
            try:
                # Build the message list
                messages = [{'role': 'system', 'content': system_prompt.strip()}]

                # Student info, if any
                if student_context:
                    messages.append({'role': 'user', 'content': student_context.strip()})

                # Conversation history, if any
                if history_context:
                    messages.append({'role': 'user', 'content': history_context.strip()})

                # Current question
                messages.append({'role': 'user', 'content': query})

                stream = await client.chat.completions.create(
                    model=Config.ALY_LLM_MODEL_NAME,
                    messages=messages,
                    max_tokens=8000,
                    stream=True
                )

                # Collect the full answer so it can be saved afterwards
                full_answer = []
                async for chunk in stream:
                    if chunk.choices[0].delta.content:
                        full_answer.append(chunk.choices[0].delta.content)
                        yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"

                # Save the answer to ES and to the conversation history
                if full_answer:
                    answer_text = ''.join(full_answer)
                    search_util.extract_student_info(answer_text, user_id)
                    try:
                        # Tag the answer
                        answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
                        try:
                            seg_list = jieba.cut(answer_text, cut_all=False)
                            answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
                            answer_keywords = answer_keywords[:5]
                            answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
                        except Exception as e:
                            logger.error(f"回答分词失败: {str(e)}")

                        search_util.insert_long_text_to_es(answer_text, answer_tags)
                        logger.info(f"用户 {user_id} 的回答已存储到ES")

                        # Update the conversation history
                        search_util.conversation_history[session_id].append((query, answer_text))
                        # Keep at most MAX_HISTORY_ROUNDS turns
                        if len(search_util.conversation_history[session_id]) > search_util.MAX_HISTORY_ROUNDS:
                            search_util.conversation_history[session_id].pop(0)

                    except Exception as e:
                        logger.error(f"存储回答到ES失败: {str(e)}")

            except Exception as e:
                logger.error(f"大模型调用失败: {str(e)}")
                yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"

        return EventSourceResponse(generate_response_stream())

    except HTTPException as e:
        logger.error(f"聊天接口错误: {str(e.detail)}")
        raise e
    except Exception as e:
        logger.error(f"聊天接口异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"处理请求失败: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
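The endpoint above streams its reply as server-sent events. Below is a minimal client sketch, assuming the service runs locally on port 8000 as configured above; the payload field names mirror those read from the request body (user_id, query, session_id, include_history), while the example values are placeholders and the snippet is not part of the original commit.

# Hypothetical client sketch for /api/teaching_chat (not part of the original commit).
import json
import requests

payload = {
    "user_id": "student_110",    # example user ID (assumption)
    "query": "什么是水灰比?",     # example question (assumption)
    "include_history": True,
}
with requests.post("http://localhost:8000/api/teaching_chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event.get("reply", event))  # print each streamed chunk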