# Source: dsProject/dsSchoolBuddy/ElasticSearch/Utils/VectorDBUtil.py
# Retrieved: 2025-08-19 10:41:58 +08:00
# NOTE: the original capture included repository-viewer UI text ("Raw Blame History",
# a confusable-Unicode warning, size/line counts); it was page chrome, not file content.
# pip install pydantic requests
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import SecretStr
import requests
import json
from Config.Config import (
EMBED_MODEL_NAME, EMBED_BASE_URL, EMBED_API_KEY,
RERANK_MODEL, RERANK_BASE_URL, RERANK_BINDING_API_KEY
)
class VectorDBUtil:
    """Vector-database utility: embeds text into an in-memory vector store and
    answers similarity queries, reranked via an external rerank API."""

    def __init__(self):
        """Initialize the embedding model; the vector store itself is created
        lazily by text_to_vector_db."""
        self.embeddings = OpenAIEmbeddings(
            model=EMBED_MODEL_NAME,
            base_url=EMBED_BASE_URL,
            api_key=SecretStr(EMBED_API_KEY)  # langchain-openai expects a SecretStr
        )
        # None until text_to_vector_db is called; query_vector_db checks this.
        self.vector_store = None

    def text_to_vector_db(self, text: str, chunk_size: int = 200, chunk_overlap: int = 0) -> tuple:
        """
        Split *text* into chunks and load them into a fresh in-memory vector store.

        Args:
            text: raw text to ingest.
            chunk_size: maximum characters per chunk.
            chunk_overlap: characters shared between adjacent chunks.

        Returns:
            tuple: (vector store object, source-document count, chunk count)
        """
        docs = [Document(page_content=text, metadata={"source": "simulated_document"})]
        doc_count = len(docs)
        print(f"文档数量:{doc_count}")
        # Split the document into overlapping chunks; start indexes are kept in metadata.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        split_count = len(all_splits)
        print(f"切割后的文档块数量:{split_count}")
        # Each call replaces any previous store — the instance holds one corpus at a time.
        self.vector_store = InMemoryVectorStore(self.embeddings)
        self.vector_store.add_documents(documents=all_splits)
        return self.vector_store, doc_count, split_count

    def query_vector_db(self, query: str, k: int = 4) -> list:
        """
        Similarity-search the vector store, then rerank the hits with the rerank model.

        Args:
            query: query string.
            k: number of similarity-search hits to retrieve.

        Returns:
            list: (document, relevance score) tuples, best score first.
                  Empty list if the store was never initialized.
                  Falls back to score 0.0 on rerank failure; a single hit gets 1.0.
        """
        if not self.vector_store:
            print("错误: 向量数据库未初始化请先调用text_to_vector_db方法")
            return []
        results = self.vector_store.similarity_search(query, k=k)
        print(f"向量搜索结果数量:{len(results)}")
        if len(results) <= 1:
            # Zero or one hit needs no reranking; a lone hit gets full confidence.
            return [(doc, 1.0) for doc in results]
        # Rerank request payload for the SiliconFlow-style rerank endpoint.
        rerank_data = {
            "model": RERANK_MODEL,
            "query": query,
            "documents": [doc.page_content for doc in results],
            "top_n": len(results)
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {RERANK_BINDING_API_KEY}"
        }
        try:
            # json= handles serialization; timeout prevents an indefinite hang
            # (the original call had no timeout at all).
            response = requests.post(
                RERANK_BASE_URL, headers=headers, json=rerank_data, timeout=30
            )
            response.raise_for_status()
            rerank_result = response.json()
            if "results" in rerank_result:
                reranked_docs_with_scores = []
                for item in rerank_result["results"]:
                    doc_idx = item.get("index")
                    score = item.get("relevance_score", 0.0)
                    # isinstance guard: a missing index would otherwise make
                    # `0 <= None` raise TypeError.
                    if isinstance(doc_idx, int) and 0 <= doc_idx < len(results):
                        reranked_docs_with_scores.append((results[doc_idx], score))
                # Guarantee best-first order even if the API returns unsorted items.
                reranked_docs_with_scores.sort(key=lambda pair: pair[1], reverse=True)
            else:
                print("警告: 无法识别重排API响应格式")
                reranked_docs_with_scores = [(doc, 0.0) for doc in results]
            print(f"重排后结果数量:{len(reranked_docs_with_scores)}")
        except Exception as e:
            # Best effort: keep the raw similarity order rather than failing the query.
            print(f"重排模型调用失败: {e}")
            print("将使用原始搜索结果")
            reranked_docs_with_scores = [(doc, 0.0) for doc in results]
        return reranked_docs_with_scores