2025-08-19 14:03:33 +08:00
parent 0440e8c817
commit aa852ce369
24 changed files with 0 additions and 1452 deletions

View File

@@ -1,27 +0,0 @@
# Elasticsearch configuration
ES_CONFIG = {
'hosts': ['https://localhost:9200'],
'basic_auth': ('elastic', 'jv9h8uwRrRxmDi1dq6u8'),
'verify_certs': False,
'index_name': 'ds_db',  # default index name
'student_info_index': 'student_info'  # index used for student profile data
}
# Embedding model
EMBED_MODEL_NAME = "BAAI/bge-m3"
EMBED_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl"
EMBED_BASE_URL = "https://api.siliconflow.cn/v1"
EMBED_DIM = 1024
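# Note: EMBED_DIM must match the dense_vector 'dims' used in the Elasticsearch index mappings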
EMBED_MAX_TOKEN_SIZE = 8192
# Reranking model
RERANK_MODEL = 'BAAI/bge-reranker-v2-m3'
RERANK_BASE_URL = 'https://api.siliconflow.cn/v1/rerank'
RERANK_BINDING_API_KEY = 'sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl'
# Aliyun (DashScope) LLM API settings [HZKJ]
ALY_LLM_API_KEY = "sk-01d13a39e09844038322108ecdbd1bbc"
ALY_LLM_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
ALY_LLM_MODEL_NAME = "qwen-plus"
#ALY_LLM_MODEL_NAME = "deepseek-r1"
# ALY_LLM_MODEL_NAME = "deepseek-v3"

View File

@@ -1,32 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)
def rebuild_index(index_name):
"""
Rebuild the specified index.
Args:
index_name: name of the index to rebuild
Returns:
bool: whether the operation succeeded
"""
print(f"开始重建索引: {index_name}")
if search_util.rebuild_mapping(index_name):
print(f"重建索引 '{index_name}' 操作成功")
return True
else:
print(f"重建索引 '{index_name}' 操作失败")
return False
if __name__ == "__main__":
# Rebuild the ds_db index
rebuild_index(Config.ES_CONFIG['index_name'])
# Rebuild the student_info index
rebuild_index(Config.ES_CONFIG['student_info_index'])
print("所有索引重建操作完成")

View File

@@ -1,38 +0,0 @@
# pip install pydantic requests
from ElasticSearch.Utils.VectorDBUtil import VectorDBUtil
def main():
# Simulated long document text
long_text = """混凝土是一种广泛使用的建筑材料,由水泥、砂、石子和水混合而成。它具有高强度、耐久性和良好的可塑性,被广泛应用于建筑、桥梁、道路等土木工程领域。
混凝土的历史可以追溯到古罗马时期,当时人们使用火山灰、石灰和碎石混合制成类似混凝土的材料。现代混凝土技术始于19世纪,随着波特兰水泥的发明而得到快速发展。
混凝土的性能取决于其配合比,包括水灰比、砂率等参数。水灰比是影响混凝土强度的关键因素,较小的水灰比通常会产生更高强度的混凝土。
为了改善混凝土的性能,常常会添加各种外加剂,如减水剂、早强剂、缓凝剂等。此外,还可以使用纤维增强、聚合物改性等技术来提高混凝土的韧性和耐久性。
在施工过程中,混凝土需要适当的养护,以确保其强度正常发展。养护措施包括浇水、覆盖保湿、蒸汽养护等。
随着建筑技术的发展,高性能混凝土、自密实混凝土、再生骨料混凝土等新型混凝土不断涌现,为土木工程领域提供了更多的选择。"""
# Create the utility instance
vector_util = VectorDBUtil()
# Ingest the text into the vector store
vector_util.text_to_vector_db(long_text)
# Query the vector store
query = "混凝土"
reranked_results = vector_util.query_vector_db(query, k=4)
# Print every result with its relevance score
print("最终查询结果:")
for i, (result, score) in enumerate(reranked_results):
print(f"结果 {i+1} (可信度: {score:.4f}):")
print(result.page_content)
print("---")
if __name__ == "__main__":
main()

View File

@@ -1,28 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
def main():
# Example 1: insert a single long text
long_text = """混凝土是一种广泛使用的建筑材料,由水泥、砂、石子和水混合而成。它具有高强度、耐久性和良好的可塑性,被广泛应用于建筑、桥梁、道路等土木工程领域。
混凝土的历史可以追溯到古罗马时期,当时人们使用火山灰、石灰和碎石混合制成类似混凝土的材料。现代混凝土技术始于19世纪,随着波特兰水泥的发明而得到快速发展。
混凝土的性能取决于其配合比,包括水灰比、砂率等参数。水灰比是影响混凝土强度的关键因素,较小的水灰比通常会产生更高强度的混凝土。
为了改善混凝土的性能,常常会添加各种外加剂,如减水剂、早强剂、缓凝剂等。此外,还可以使用纤维增强、聚合物改性等技术来提高混凝土的韧性和耐久性。
在施工过程中,混凝土需要适当的养护,以确保其强度正常发展。养护措施包括浇水、覆盖保湿、蒸汽养护等。
随着建筑技术的发展,高性能混凝土、自密实混凝土、再生骨料混凝土等新型混凝土不断涌现,为土木工程领域提供了更多的选择。"""
# Tags for the text
tags = ["student_110"]
# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)
search_util.insert_long_text_to_es(long_text, tags)
if __name__ == "__main__":
main()

View File

@@ -1,28 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Create an EsSearchUtil instance
search_util = EsSearchUtil(Config.ES_CONFIG)
# Query all documents
def select_all_data(index_name):
try:
# Call select_all_data on EsSearchUtil
response = search_util.select_all_data()
hits = response['hits']['hits']
if not hits:
print(f"索引 {index_name} 中没有数据")
else:
print(f"索引 {index_name} 中共有 {len(hits)} 条数据:")
for i, hit in enumerate(hits, 1):
print(f"{i}. ID: {hit['_id']}")
print(f" 内容: {hit['_source'].get('user_input', '')}")
print(f" 标签: {hit['_source'].get('tags', '')}")
print(f" 时间戳: {hit['_source'].get('timestamp', '')}")
print("-" * 50)
except Exception as e:
print(f"查询出错: {e}")
if __name__ == "__main__":
select_all_data(Config.ES_CONFIG['index_name'])

View File

@@ -1,28 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# 1. Create an EsSearchUtil instance (it wraps the connection pool)
search_util = EsSearchUtil(Config.ES_CONFIG)
# 2. Specify the query keyword directly in code
query_keyword = "混凝土"
# 3. Run the query and handle the results
try:
# Search via the connection pool
results = search_util.text_search(query_keyword, size=1000)
print(f"查询关键字 '{query_keyword}' 结果:")
if results['hits']['hits']:
for i, hit in enumerate(results['hits']['hits'], 1):
doc = hit['_source']
print(f"{i}. ID: {hit['_id']}")
print(f" 标签: {doc['tags']['tags'] if 'tags' in doc and 'tags' in doc['tags'] else ''}")
print(f" 用户问题: {doc['user_input'] if 'user_input' in doc else ''}")
print(f" 时间: {doc['timestamp'] if 'timestamp' in doc else ''}")
print("-" * 50)
else:
print(f"未找到包含 '{query_keyword}' 的数据。")
except Exception as e:
print(f"查询失败: {e}")
print(f"查询关键字: {query_keyword}")

View File

@@ -1,29 +0,0 @@
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
def main():
# Initialize the search utility
search_util = EsSearchUtil(Config.ES_CONFIG)
# Query text
query = "混凝土"
print(f"查询文本: {query}")
# Embed the query
query_embedding = search_util.get_query_embedding(query)
print(f"查询向量维度: {len(query_embedding)}")
# Vector search
search_results = search_util.search_by_vector(query_embedding, k=10)
print(f"向量搜索结果数量: {len(search_results)}")
# Rerank the results
reranked_results = search_util.rerank_results(query, search_results)
# Display the results
search_util.display_results(reranked_results)
if __name__ == "__main__":
main()

View File

@@ -1,66 +0,0 @@
import logging
from Config.Config import ES_CONFIG
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if __name__ == "__main__":
# Initialize EsSearchUtil
search_util = EsSearchUtil(ES_CONFIG)
# Get user input
user_query = input("请输入查询语句(例如:高性能的混凝土): ")
if not user_query:
user_query = "高性能的混凝土"
print(f"未输入查询语句,使用默认值: {user_query}")
query_tags = []  # tag filtering could be added here if needed
print(f"\n=== 开始执行查询 ===")
print(f"原始查询文本: {user_query}")
try:
# 1. Vector search
print("\n=== 向量搜索阶段 ===")
print("1. 文本向量化处理中...")
query_embedding = search_util.get_query_embedding(user_query)
print(f"2. 生成的查询向量维度: {len(query_embedding)}")
print(f"3. 前3维向量值: {query_embedding[:3]}")
print("4. 正在执行Elasticsearch向量搜索...")
vector_hits = search_util.search_by_vector(query_embedding, k=5)
print(f"5. 向量搜索结果数量: {len(vector_hits)}")
# Rerank the vector search results
print("6. 正在进行向量结果重排...")
reranked_vector_results = search_util.rerank_results(user_query, vector_hits)
print(f"7. 重排后向量结果数量: {len(reranked_vector_results)}")
# 2. Keyword search
print("\n=== 关键字搜索阶段 ===")
print("1. 正在执行Elasticsearch关键字搜索...")
keyword_results = search_util.text_search(user_query, size=5)
keyword_hits = keyword_results['hits']['hits']
print(f"2. 关键字搜索结果数量: {len(keyword_hits)}")
# 3. Merge the results
print("\n=== 合并搜索结果 ===")
# Attach scores to the keyword hits
keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
merged_results = search_util.merge_results(keyword_results_with_scores, reranked_vector_results)
print(f"合并后唯一结果数量: {len(merged_results)}")
# 4. Print the final results
print("\n=== 最终搜索结果 ===")
for i, (doc, score, source) in enumerate(merged_results, 1):
print(f"{i}. 文档ID: {doc['_id']}, 分数: {score:.2f}, 来源: {source}")
print(f" 内容: {doc['_source']['user_input']}")
print(" --- ")
except Exception as e:
logger.error(f"搜索过程中发生错误: {str(e)}")
print(f"搜索失败: {str(e)}")

View File

@@ -1,110 +0,0 @@
from elasticsearch.exceptions import NotFoundError
import logging
logger = logging.getLogger(__name__)
class ElasticsearchCollectionManager:
def __init__(self, index_name):
"""
Initialize the Elasticsearch index manager.
:param index_name: Elasticsearch index name
"""
self.index_name = index_name
def load_collection(self, es_connection):
"""
Load the index, creating it if it does not exist.
:param es_connection: Elasticsearch connection
"""
try:
if not es_connection.indices.exists(index=self.index_name):
logger.warning(f"Index {self.index_name} does not exist, creating new index")
self._create_index(es_connection)
except Exception as e:
logger.error(f"Failed to load collection: {str(e)}")
raise
def _create_index(self, es_connection):
"""Create a new Elasticsearch index."""
mapping = {
"mappings": {
"properties": {
"user_input": {"type": "text"},
"tags": {
"type": "object",
"properties": {
"tags": {"type": "keyword"},
"full_content": {"type": "text"}
}
},
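# The dense_vector 'dims' below must match the embedding model's output dimension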
"timestamp": {"type": "date"},
"embedding": {"type": "dense_vector", "dims": 200}
}
}
}
es_connection.indices.create(index=self.index_name, body=mapping)
def search(self, es_connection, query_embedding, search_params, expr=None, limit=5):
"""
Run a hybrid (vector + keyword) search.
:param es_connection: Elasticsearch connection
:param query_embedding: query vector
:param search_params: search parameters
:param expr: filter expression
:param limit: number of results to return
:return: search hits
"""
try:
# Build the query
query = {
"query": {
"script_score": {
"query": {
"bool": {
"must": []
}
},
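# cosineSimilarity returns values in [-1, 1]; adding 1.0 keeps the script score non-negative, as Elasticsearch requires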
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": {"query_vector": query_embedding}
}
}
},
"size": limit
}
# Add tag filter conditions
if expr:
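# expr is expected to be a string like "tagA OR tagB"; it is split into a terms filter on tags.tags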
query["query"]["script_score"]["query"]["bool"]["must"].append(
{"nested": {
"path": "tags",
"query": {
"terms": {"tags.tags": expr.split(" OR ")}
}
}}
)
logger.info(f"Executing search with query: {query}")
response = es_connection.search(index=self.index_name, body=query)
return response["hits"]["hits"]
except Exception as e:
logger.error(f"Search failed: {str(e)}")
raise
def query_by_id(self, es_connection, doc_id):
"""
Fetch a document by ID.
:param es_connection: Elasticsearch connection
:param doc_id: document ID
:return: document source, or None if not found
"""
try:
response = es_connection.get(index=self.index_name, id=doc_id)
return response["_source"]
except NotFoundError:
logger.warning(f"Document with id {doc_id} not found")
return None
except Exception as e:
logger.error(f"Failed to query document by id: {str(e)}")
raise

View File

@@ -1,65 +0,0 @@
from elasticsearch import Elasticsearch
import threading
import logging
logger = logging.getLogger(__name__)
class ElasticsearchConnectionPool:
def __init__(self, hosts, basic_auth, verify_certs=False, max_connections=50):
"""
Initialize the Elasticsearch connection pool.
:param hosts: Elasticsearch server addresses
:param basic_auth: (username, password) tuple
:param verify_certs: whether to verify SSL certificates
:param max_connections: maximum number of pooled connections
"""
self.hosts = hosts
self.basic_auth = basic_auth
self.verify_certs = verify_certs
self.max_connections = max_connections
self._connections = []
self._lock = threading.Lock()
self._initialize_pool()
def _initialize_pool(self):
"""Create all pooled connections up front."""
for _ in range(self.max_connections):
self._connections.append(self._create_connection())
def _create_connection(self):
"""Create a new Elasticsearch connection."""
return Elasticsearch(
hosts=self.hosts,
basic_auth=self.basic_auth,
verify_certs=self.verify_certs
)
def get_connection(self):
"""Get a connection from the pool."""
with self._lock:
if not self._connections:
logger.warning("Connection pool exhausted, creating new connection")
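# Overflow connection: not tracked by the pool; release_connection() closes it if the pool is already full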
return self._create_connection()
return self._connections.pop()
def release_connection(self, connection):
"""Return a connection to the pool."""
with self._lock:
if len(self._connections) < self.max_connections:
self._connections.append(connection)
else:
try:
connection.close()
except Exception as e:
logger.warning(f"Failed to close connection: {str(e)}")
def close(self):
"""Close all pooled connections."""
with self._lock:
for conn in self._connections:
try:
conn.close()
except Exception as e:
logger.warning(f"Failed to close connection: {str(e)}")
self._connections.clear()

View File

@@ -1,664 +0,0 @@
import json
import logging
import warnings
import hashlib
import time
import jieba
import requests
from ElasticSearch.Utils.ElasticsearchConnectionPool import ElasticsearchConnectionPool
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from Config import Config
from typing import List, Tuple, Dict
# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class EsSearchUtil:
# Conversation history store: key is the session ID, value is the list of dialogue turns
conversation_history = {}
# Student info store: key is the user ID, value is the student's info
student_info = {}
# Grade keyword dictionary
GRADE_KEYWORDS = {
'一年级': ['一年级', '初一'],
'二年级': ['二年级', '初二'],
'三年级': ['三年级', '初三'],
'四年级': ['四年级'],
'五年级': ['五年级'],
'六年级': ['六年级'],
'七年级': ['七年级', '初一'],
'八年级': ['八年级', '初二'],
'九年级': ['九年级', '初三'],
'高一': ['高一'],
'高二': ['高二'],
'高三': ['高三']
}
# Maximum number of dialogue rounds kept in history
MAX_HISTORY_ROUNDS = 10
# Stopwords filtered out of jieba keyword extraction (common Chinese function words)
STOPWORDS = set(
['的', '了', '和', '是', '就', '都', '而', '及', '与', '着', '或', '在', '一个', '也', '我', '你', '他', '她',
'它', '这', '那', '有', '没有', '很', '啊', '自己', '吗'])
def __init__(self, es_config):
"""
Initialize the Elasticsearch search utility.
:param es_config: Elasticsearch config dict containing hosts, basic_auth, index_name, etc.
"""
# Suppress HTTPS-related warnings
warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
self.es_config = es_config
# Initialize the connection pool
self.es_pool = ElasticsearchConnectionPool(
hosts=es_config['hosts'],
basic_auth=es_config['basic_auth'],
verify_certs=es_config.get('verify_certs', False),
max_connections=50
)
self.index_name = es_config['index_name']
logger.info(f"EsSearchUtil初始化成功索引名称: {self.index_name}")
def rebuild_mapping(self, index_name=None):
"""
Rebuild the Elasticsearch index and its mapping.
Args:
index_name: optional; index to rebuild, defaults to the index set at initialization
Returns:
bool: whether the operation succeeded
"""
try:
# 从连接池获取连接
conn = self.es_pool.get_connection()
# 使用指定的索引名称或默认索引名称
target_index = index_name if index_name else self.index_name
logger.info(f"开始重建索引: {target_index}")
# 定义mapping结构
if target_index == 'student_info':
mapping = {
"mappings": {
"properties": {
"user_id": {"type": "keyword"},
"grade": {"type": "keyword"},
"recent_questions": {"type": "text"},
"learned_knowledge": {"type": "text"},
"updated_at": {"type": "date"}
}
}
}
else:
mapping = {
"mappings": {
"properties": {
"embedding": {
"type": "dense_vector",
"dims": Config.EMBED_DIM,
"index": True,
"similarity": "l2_norm"
},
"user_input": {"type": "text"},
"tags": {
"type": "object",
"properties": {
"tags": {"type": "keyword"},
"full_content": {"type": "text"}
}
}
}
}
}
# 检查索引是否存在,存在则删除
if conn.indices.exists(index=target_index):
conn.indices.delete(index=target_index)
logger.info(f"删除已存在的索引 '{target_index}'")
print(f"删除已存在的索引 '{target_index}'")
# 创建索引和mapping
conn.indices.create(index=target_index, body=mapping)
logger.info(f"索引 '{target_index}' 创建成功mapping结构已设置")
print(f"索引 '{target_index}' 创建成功mapping结构已设置。")
return True
except Exception as e:
logger.error(f"重建索引 '{target_index}' 失败: {str(e)}")
print(f"重建索引 '{target_index}' 失败: {e}")
# 提供认证错误的具体提示
if 'AuthenticationException' in str(e):
print("认证失败提示: 请检查Config.py中的ES_CONFIG配置确保用户名和密码正确。")
logger.error("认证失败: 请检查Config.py中的ES_CONFIG配置确保用户名和密码正确。")
return False
finally:
# 释放连接回连接池
self.es_pool.release_connection(conn)
def text_search(self, query, size=10):
# 从连接池获取连接
conn = self.es_pool.get_connection()
try:
# 使用连接执行搜索
result = conn.search(
index=self.es_config['index_name'],
query={"match": {"user_input": query}},
size=size
)
return result
except Exception as e:
logger.error(f"文本搜索失败: {str(e)}")
raise
finally:
# 释放连接回连接池
self.es_pool.release_connection(conn)
def select_all_data(self, size=1000):
"""
Query all documents in the index.
Args:
size: maximum number of results to return (default 1000)
Returns:
dict: the query response
"""
# 从连接池获取连接
conn = self.es_pool.get_connection()
try:
# 构建查询条件 - 匹配所有文档
query = {
"query": {
"match_all": {}
},
"size": size
}
# 执行查询
response = conn.search(index=self.es_config['index_name'], body=query)
return response
except Exception as e:
logger.error(f"查询所有数据失败: {str(e)}")
raise
finally:
# 释放连接回连接池
self.es_pool.release_connection(conn)
def split_text_into_chunks(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> list:
"""
Split text into chunks.
Args:
text: text to split
chunk_size: size of each chunk
chunk_overlap: overlap between adjacent chunks
Returns:
list: list of text chunks
"""
# 创建文档对象
docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
# 切割文档
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
)
all_splits = text_splitter.split_documents(docs)
print(f"切割后的文档块数量:{len(all_splits)}")
return [split.page_content for split in all_splits]
def insert_long_text_to_es(self, long_text: str, tags: list = None) -> bool:
"""
Split a long text into chunks, embed each chunk and insert it into Elasticsearch.
Deduplication uses a hash of the chunk content as the document ID.
Args:
long_text: text to insert
tags: optional list of tags
Returns:
bool: whether the insert succeeded
"""
try:
# 1. 创建EsSearchUtil实例以使用连接池
search_util = EsSearchUtil(Config.ES_CONFIG)
# 2. 从连接池获取连接
conn = search_util.es_pool.get_connection()
# # 3. 检查索引是否存在,不存在则创建
index_name = Config.ES_CONFIG['index_name']
# if not conn.indices.exists(index=index_name):
# # 定义mapping结构
# mapping = {
# "mappings": {
# "properties": {
# "embedding": {
# "type": "dense_vector",
# "dims": Config.EMBED_DIM, # 根据实际embedding维度调整
# "index": True,
# "similarity": "l2_norm"
# },
# "user_input": {"type": "text"},
# "tags": {
# "type": "object",
# "properties": {
# "tags": {"type": "keyword"},
# "full_content": {"type": "text"}
# }
# },
# "timestamp": {"type": "date"}
# }
# }
# }
# conn.indices.create(index=index_name, body=mapping)
# print(f"索引 '{index_name}' 创建成功")
# 4. 切割文本
text_chunks = self.split_text_into_chunks(long_text)
# 5. 准备标签
if tags is None:
tags = ["general_text"]
# 6. 获取当前时间
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# 7. 创建嵌入模型
embeddings = OpenAIEmbeddings(
model=Config.EMBED_MODEL_NAME,
base_url=Config.EMBED_BASE_URL,
api_key=SecretStr(Config.EMBED_API_KEY)
)
# 8. 为每个文本块生成向量并插入
for i, chunk in enumerate(text_chunks):
# Use the MD5 hash of the chunk content as the document ID (deduplication)
doc_id = hashlib.md5(chunk.encode('utf-8')).hexdigest()
# 检查文档是否已存在
if conn.exists(index=index_name, id=doc_id):
print(f"文档块 {i+1} 已存在,跳过插入: {doc_id}")
continue
# 生成文本块的嵌入向量
embedding = embeddings.embed_documents([chunk])[0]
# 准备文档数据
doc = {
'tags': {"tags": tags, "full_content": long_text},
'user_input': chunk,
'timestamp': timestamp,
'embedding': embedding
}
# 插入数据到Elasticsearch
conn.index(index=index_name, id=doc_id, document=doc)
print(f"文档块 {i+1} 插入成功: {doc_id}")
return True
except Exception as e:
print(f"插入数据失败: {e}")
return False
finally:
# 确保释放连接回连接池
if 'conn' in locals() and 'search_util' in locals():
search_util.es_pool.release_connection(conn)
def get_query_embedding(self, query: str) -> list:
"""
Convert query text into an embedding vector.
Args:
query: query text
Returns:
list: the embedding vector
"""
# 创建嵌入模型
embeddings = OpenAIEmbeddings(
model=Config.EMBED_MODEL_NAME,
base_url=Config.EMBED_BASE_URL,
api_key=SecretStr(Config.EMBED_API_KEY)
)
# 生成查询向量
query_embedding = embeddings.embed_query(query)
return query_embedding
def rerank_results(self, query: str, results: list) -> list:
"""
Rerank search results with the reranking model.
Args:
query: query text
results: list of search hits
Returns:
list: reranked results as (hit, score) tuples
"""
if not results:
print("警告: 没有搜索结果可供重排")
return []
try:
# Prepare the rerank request payload
# Keep only hits that are dicts containing '_source' and 'user_input'
documents = []
valid_results = []
for i, doc in enumerate(results):
if isinstance(doc, dict) and '_source' in doc and 'user_input' in doc['_source']:
documents.append(doc['_source']['user_input'])
valid_results.append(doc)
else:
print(f"警告: 结果项 {i} 格式不正确,跳过该结果")
print(f"结果项内容: {doc}")
if not documents:
print("警告: 没有有效的文档可供重排")
# 返回原始结果,但转换为(结果, 分数)的元组格式
return [(doc, doc.get('_score', 0.0)) for doc in results]
rerank_data = {
"model": Config.RERANK_MODEL,
"query": query,
"documents": documents,
"top_n": len(documents)
}
# 调用重排API
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {Config.RERANK_BINDING_API_KEY}"
}
response = requests.post(Config.RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
response.raise_for_status() # 检查请求是否成功
rerank_result = response.json()
# 处理重排结果
reranked_results = []
if "results" in rerank_result:
for item in rerank_result["results"]:
doc_idx = item.get("index")
score = item.get("relevance_score", 0.0)
if 0 <= doc_idx < len(valid_results):
result = valid_results[doc_idx]
reranked_results.append((result, score))
else:
print("警告: 无法识别重排API响应格式")
# 返回原始结果,但转换为(结果, 分数)的元组格式
reranked_results = [(doc, doc.get('_score', 0.0)) for doc in valid_results]
print(f"重排后结果数量:{len(reranked_results)}")
return reranked_results
except Exception as e:
print(f"重排失败: {e}")
print("将使用原始搜索结果")
# 返回原始结果,但转换为(结果, 分数)的元组格式
return [(doc, doc.get('_score', 0.0)) for doc in results]
def search_by_vector(self, query_embedding: list, k: int = 10) -> list:
"""
Run a similarity search with a query vector.
Args:
query_embedding: query vector
k: number of results to return
Returns:
list: search hits
"""
try:
# 从连接池获取连接
conn = self.es_pool.get_connection()
index_name = Config.ES_CONFIG['index_name']
# 执行向量搜索
response = conn.search(
index=index_name,
body={
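# script_score over match_all scores every document (brute force); acceptable for small indexes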
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": {
"query_vector": query_embedding
}
}
}
},
"size": k
}
)
# 提取结果
# 确保我们提取的是 hits.hits 部分
if 'hits' in response and 'hits' in response['hits']:
results = response['hits']['hits']
print(f"向量搜索结果数量: {len(results)}")
return results
else:
print("警告: 向量搜索响应格式不正确")
print(f"响应内容: {response}")
return []
except Exception as e:
print(f"向量搜索失败: {e}")
return []
finally:
# 释放连接回连接池
self.es_pool.release_connection(conn)
def display_results(self, results: list, show_score: bool = True) -> None:
"""
Display search results.
Args:
results: list of search results
show_score: whether to print the score
"""
if not results:
print("没有找到匹配的结果。")
return
print(f"找到 {len(results)} 条结果:\n")
for i, item in enumerate(results, 1):
print(f"结果 {i}:")
try:
# 检查item是否为元组格式 (result, score)
if isinstance(item, tuple):
if len(item) >= 2:
result, score = item[0], item[1]
else:
result, score = item[0], 0.0
else:
# 如果不是元组假设item就是result
result = item
score = result.get('_score', 0.0)
# 确保result是字典类型
if not isinstance(result, dict):
print(f"警告: 结果项 {i} 不是字典类型,跳过显示")
print(f"结果项内容: {result}")
print("---")
continue
# 尝试获取user_input内容
if '_source' in result and 'user_input' in result['_source']:
content = result['_source']['user_input']
print(f"内容: {content}")
elif 'user_input' in result:
content = result['user_input']
print(f"内容: {content}")
else:
print(f"警告: 结果项 {i} 缺少'user_input'字段")
print(f"结果项内容: {result}")
print("---")
continue
# 显示分数
if show_score:
print(f"分数: {score:.4f}")
# 如果有标签信息,也显示出来
if '_source' in result and 'tags' in result['_source']:
tags = result['_source']['tags']
if isinstance(tags, dict) and 'tags' in tags:
print(f"标签: {tags['tags']}")
except Exception as e:
print(f"处理结果项 {i} 时出错: {str(e)}")
print(f"结果项内容: {item}")
print("---")
def merge_results(self, keyword_results: List[Tuple[Dict, float]], vector_results: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float, str]]:
"""
Merge keyword-search and vector-search results.
Args:
keyword_results: list of (document, score) tuples from keyword search
vector_results: list of (document, score) tuples from vector search
Returns:
list: merged results as (document, score, source) tuples
"""
# Label each result with its source and collect them
all_results = []
for doc, score in keyword_results:
all_results.append((doc, score, "关键字搜索"))
for doc, score in vector_results:
all_results.append((doc, score, "向量搜索"))
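# NOTE: keyword _score and rerank relevance_score are on different scales but are compared as-is below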
# Deduplicate by document ID, keeping the highest-scoring occurrence
unique_results = {}
for doc, score, source in all_results:
doc_id = doc['_id']
if doc_id not in unique_results or score > unique_results[doc_id][1]:
unique_results[doc_id] = (doc, score, source)
# 按分数降序排序
sorted_results = sorted(unique_results.values(), key=lambda x: x[1], reverse=True)
return sorted_results
# Persist student info to Elasticsearch
def save_student_info_to_es(self, user_id, info):
"""Save a student's info to Elasticsearch."""
try:
# 使用用户ID作为文档ID
doc_id = f"student_{user_id}"
# 准备文档内容
doc = {
"user_id": user_id,
"info": info,
"update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
}
# 从连接池获取连接
es_conn = self.es_pool.get_connection()
try:
# Index the document; Elasticsearch creates the index automatically if it is missing
es_conn.index(index="student_info", id=doc_id, document=doc)
logger.info(f"学生 {user_id} 的信息已保存到ES: {info}")
finally:
# 释放连接回连接池
self.es_pool.release_connection(es_conn)
except Exception as e:
logger.error(f"保存学生信息到ES失败: {str(e)}", exc_info=True)
# Fetch student info from Elasticsearch
def get_student_info_from_es(self, user_id):
"""Fetch a student's info from Elasticsearch."""
try:
doc_id = f"student_{user_id}"
# 从连接池获取连接
es_conn = self.es_pool.get_connection()
try:
# 确保索引存在
if es_conn.indices.exists(index=Config.ES_CONFIG.get("student_info_index")):
result = es_conn.get(index=Config.ES_CONFIG.get("student_info_index"), id=doc_id)
if result and '_source' in result:
logger.info(f"从ES获取到学生 {user_id} 的信息: {result['_source']['info']}")
return result['_source']['info']
else:
logger.info(f"ES中没有找到学生 {user_id} 的信息")
else:
logger.info("student_info索引不存在")
finally:
# 释放连接回连接池
self.es_pool.release_connection(es_conn)
except Exception as e:
# 如果文档不存在,返回空字典
if "not_found" in str(e).lower():
logger.info(f"学生 {user_id} 的信息在ES中不存在")
return {}
logger.error(f"从ES获取学生信息失败: {str(e)}", exc_info=True)
return {}
def extract_student_info(self, text, user_id):
"""Extract student info from the text using jieba word segmentation."""
try:
# 提取年级信息
seg_list = jieba.cut(text, cut_all=False) # 精确模式
seg_set = set(seg_list)
# 检查是否已有学生信息如果没有则从ES加载
if user_id not in self.student_info:
# 从ES加载学生信息
info_from_es = self.get_student_info_from_es(user_id)
if info_from_es:
self.student_info[user_id] = info_from_es
logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
else:
self.student_info[user_id] = {}
# 提取并更新年级信息
grade_found = False
for grade, keywords in self.GRADE_KEYWORDS.items():
for keyword in keywords:
if keyword in seg_set:
if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
self.student_info[user_id]['grade'] = grade
logger.info(f"提取到用户 {user_id} 的年级信息: {grade}")
# 保存到ES
self.save_student_info_to_es(user_id, self.student_info[user_id])
grade_found = True
break
if grade_found:
break
# If the text mentions a grade but no keyword matched, try to extract the number directly
if not grade_found:
import re
# Match the pattern "我是X年级" ("I am in grade X")
match = re.search(r'我是(\d+)年级', text)
if match:
grade_num = match.group(1)
grade = f"{grade_num}年级"
if 'grade' not in self.student_info[user_id] or self.student_info[user_id]['grade'] != grade:
self.student_info[user_id]['grade'] = grade
logger.info(f"通过正则提取到用户 {user_id} 的年级信息: {grade}")
# 保存到ES
self.save_student_info_to_es(user_id, self.student_info[user_id])
except Exception as e:
logger.error(f"提取学生信息失败: {str(e)}", exc_info=True)

View File

@@ -1,125 +0,0 @@
# pip install pydantic requests
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import SecretStr
import requests
import json
from Config.Config import (
EMBED_MODEL_NAME, EMBED_BASE_URL, EMBED_API_KEY,
RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY
)
class VectorDBUtil:
"""Vector store utility: embeds text, stores it in memory and supports querying with reranking."""
def __init__(self):
"""Initialize the vector store utility."""
# Initialize the embedding model
self.embeddings = OpenAIEmbeddings(
model=EMBED_MODEL_NAME,
base_url=EMBED_BASE_URL,
api_key=SecretStr(EMBED_API_KEY) # 包装成 SecretStr 类型
)
# 初始化向量存储
self.vector_store = None
def text_to_vector_db(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> tuple:
"""
Store text in the vector database.
Args:
text: text to ingest
chunk_size: chunk size for splitting
chunk_overlap: overlap between adjacent chunks
Returns:
tuple: (vector store, document count, number of chunks after splitting)
"""
# 创建文档对象
docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
doc_count = len(docs)
print(f"文档数量:{doc_count}")
# 切割文档
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
)
all_splits = text_splitter.split_documents(docs)
split_count = len(all_splits)
print(f"切割后的文档块数量:{split_count}")
# 向量存储
self.vector_store = InMemoryVectorStore(self.embeddings)
ids = self.vector_store.add_documents(documents=all_splits)
return self.vector_store, doc_count, split_count
def query_vector_db(self, query: str, k: int = 4) -> list:
"""
Query the vector database.
Args:
query: query string
k: number of results to return
Returns:
list: reranked results as (document, relevance score) tuples
"""
if not self.vector_store:
print("错误: 向量数据库未初始化请先调用text_to_vector_db方法")
return []
# Vector search; the hits are reranked below
results = self.vector_store.similarity_search(query, k=k)
print(f"向量搜索结果数量:{len(results)}")
# 存储重排后的文档和分数
reranked_docs_with_scores = []
# 调用重排模型
if len(results) > 1:
# 准备重排请求数据
rerank_data = {
"model": RERANK_MODEL,
"query": query,
"documents": [doc.page_content for doc in results],
"top_n": len(results)
}
# 调用SiliconFlow API进行重排
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
}
try:
response = requests.post(RERANK_BASE_URL, headers=headers, data=json.dumps(rerank_data))
response.raise_for_status() # 检查请求是否成功
rerank_result = response.json()
# 处理重排结果提取relevance_score
if "results" in rerank_result:
for item in rerank_result["results"]:
doc_idx = item.get("index")
score = item.get("relevance_score", 0.0)
if 0 <= doc_idx < len(results):
reranked_docs_with_scores.append((results[doc_idx], score))
else:
print("警告: 无法识别重排API响应格式")
reranked_docs_with_scores = [(doc, 0.0) for doc in results]
print(f"重排后结果数量:{len(reranked_docs_with_scores)}")
except Exception as e:
print(f"重排模型调用失败: {e}")
print("将使用原始搜索结果")
reranked_docs_with_scores = [(doc, 0.0) for doc in results]
else:
# Only one result, no reranking needed
reranked_docs_with_scores = [(doc, 1.0) for doc in results]  # a single result gets confidence 1.0
return reranked_docs_with_scores

View File

@@ -1,212 +0,0 @@
# pip install jieba
import json
from contextlib import asynccontextmanager
import logging
import time
import jieba
import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse
import uuid
from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Initialize the async OpenAI client
client = AsyncOpenAI(
api_key=Config.ALY_LLM_API_KEY,
base_url=Config.ALY_LLM_BASE_URL
)
# Initialize the Elasticsearch utility
search_util = EsSearchUtil(Config.ES_CONFIG)
@asynccontextmanager
async def lifespan(_: FastAPI):
    yield
app = FastAPI(lifespan=lifespan)
@app.post("/api/teaching_chat")
async def teaching_chat(request: fastapi.Request):
"""
Look up related conversation history for the user's query,
then call the LLM to generate a streamed answer.
"""
try:
data = await request.json()
user_id = data.get('user_id', 'anonymous')
query = data.get('query', '')
session_id = data.get('session_id', str(uuid.uuid4())) # 获取或生成会话ID
include_history = data.get('include_history', True)
if not query:
raise HTTPException(status_code=400, detail="查询内容不能为空")
# 1. Initialize conversation history and student info
if session_id not in search_util.conversation_history:
search_util.conversation_history[session_id] = []
# 检查是否已有学生信息如果没有则从ES加载
if user_id not in search_util.student_info:
# 从ES加载学生信息
info_from_es = search_util.get_student_info_from_es(user_id)
if info_from_es:
search_util.student_info[user_id] = info_from_es
logger.info(f"从ES加载用户 {user_id} 的信息: {info_from_es}")
else:
search_util.student_info[user_id] = {}
# 2. Extract student info from the query with jieba
search_util.extract_student_info(query, user_id)
# 输出调试信息
logger.info(f"当前学生信息: {search_util.student_info.get(user_id, {})}")
# 为用户查询生成标签并存储到ES
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]
# 提取查询中的关键词作为额外标签 - 使用jieba分词
try:
seg_list = jieba.cut(query, cut_all=False) # 精确模式
keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
keywords = keywords[:5]
tags.extend([f"keyword:{kw}" for kw in keywords])
logger.info(f"使用jieba分词提取的关键词: {keywords}")
except Exception as e:
logger.error(f"分词失败: {str(e)}")
keywords = query.split()[:5]
tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])
# 存储查询到ES
try:
search_util.insert_long_text_to_es(query, tags)
logger.info(f"用户 {user_id} 的查询已存储到ES标签: {tags}")
except Exception as e:
logger.error(f"存储用户查询到ES失败: {str(e)}")
# 3. Build the conversation-history context
history_context = ""
if include_history and session_id in search_util.conversation_history:
# 获取最近的几次对话历史
recent_history = search_util.conversation_history[session_id][-search_util.MAX_HISTORY_ROUNDS:]
if recent_history:
history_context = "\n\n以下是最近的对话历史,可供参考:\n"
for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
history_context += f"[对话 {i}] 用户: {user_msg}\n"
history_context += f"[对话 {i}] 老师: {ai_msg}\n"
# 4. Build the student-info context
student_context = ""
if user_id in search_util.student_info and search_util.student_info[user_id]:
student_context = "\n\n学生基础信息:\n"
for key, value in search_util.student_info[user_id].items():
if key == 'grade':
student_context += f"- 年级: {value}\n"
else:
student_context += f"- {key}: {value}\n"
# 5. Build the prompt
system_prompt = """
你是一位平易近人且教学方法灵活的教师,通过引导学生自主学习来帮助他们掌握知识。
严格遵循以下教学规则:
1. 基于学生情况调整教学:如果已了解学生的年级水平和知识背景,应基于此调整教学内容和难度。
2. 基于现有知识构建:将新思想与学生已有的知识联系起来。
3. 引导而非灌输:使用问题、提示和小步骤,让学生自己发现答案。
4. 检查和强化:在讲解难点后,确认学生能够重述或应用这些概念。
5. 变化节奏:混合讲解、提问和互动活动,让教学像对话而非讲座。
最重要的是:不要直接给出答案,而是通过合作和基于学生已有知识的引导,帮助学生自己找到答案。
"""
# 添加学生信息到系统提示词
if user_id in search_util.student_info and search_util.student_info[user_id]:
student_info_str = "\n\n学生基础信息:\n"
for key, value in search_util.student_info[user_id].items():
if key == 'grade':
student_info_str += f"- 年级: {value}\n"
else:
student_info_str += f"- {key}: {value}\n"
system_prompt += student_info_str
# 6. Stream the LLM answer
async def generate_response_stream():
try:
# 构建消息列表
messages = [{'role': 'system', 'content': system_prompt.strip()}]
# 添加学生信息(如果有)
if student_context:
messages.append({'role': 'user', 'content': student_context.strip()})
# 添加历史对话(如果有)
if history_context:
messages.append({'role': 'user', 'content': history_context.strip()})
# 添加当前问题
messages.append({'role': 'user', 'content': query})
stream = await client.chat.completions.create(
model=Config.ALY_LLM_MODEL_NAME,
messages=messages,
max_tokens=8000,
stream=True
)
# 收集完整回答用于保存
full_answer = []
async for chunk in stream:
if chunk.choices[0].delta.content:
full_answer.append(chunk.choices[0].delta.content)
# EventSourceResponse adds the SSE "data:" framing itself, so yield only the JSON payload
yield json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)
# 保存回答到ES和对话历史
if full_answer:
answer_text = ''.join(full_answer)
search_util.extract_student_info(answer_text, user_id)
try:
# 为回答添加标签
answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
try:
seg_list = jieba.cut(answer_text, cut_all=False)
answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in search_util.STOPWORDS and len(kw) > 1]
answer_keywords = answer_keywords[:5]
answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
except Exception as e:
logger.error(f"回答分词失败: {str(e)}")
search_util.insert_long_text_to_es(answer_text, answer_tags)
logger.info(f"用户 {user_id} 的回答已存储到ES")
# 更新对话历史
search_util.conversation_history[session_id].append((query, answer_text))
# 保持历史记录不超过最大轮数
if len(search_util.conversation_history[session_id]) > search_util.MAX_HISTORY_ROUNDS:
search_util.conversation_history[session_id].pop(0)
except Exception as e:
logger.error(f"存储回答到ES失败: {str(e)}")
except Exception as e:
logger.error(f"大模型调用失败: {str(e)}")
yield json.dumps({'error': f'生成回答失败: {str(e)}'})
return EventSourceResponse(generate_response_stream())
except HTTPException as e:
logger.error(f"聊天接口错误: {str(e.detail)}")
raise e
except Exception as e:
logger.error(f"聊天接口异常: {str(e)}")
raise HTTPException(status_code=500, detail=f"处理请求失败: {str(e)}")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)