You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.0 KiB
68 lines
2.0 KiB
from elasticsearch import Elasticsearch
|
|
from Util.EmbeddingUtil import text_to_embedding
|
|
import random
|
|
import Config.Config as config
|
|
|
|
# Initialize the Elasticsearch connection.
# Module-level client used by the search functions below; every setting is
# read from the project's Config.Config.ES_CONFIG mapping.
_es_cfg = config.ES_CONFIG
es = Elasticsearch(
    hosts=_es_cfg['hosts'],
    basic_auth=_es_cfg['basic_auth'],
    verify_certs=_es_cfg['verify_certs'],
)
|
|
|
|
def vector_search(text, size=5, index='knowledge_base'):
    """Return the stored texts most similar to *text* by embedding (cosine).

    The query string is embedded with ``text_to_embedding`` and documents in
    *index* are ranked by a ``script_score`` query that computes the cosine
    similarity between that embedding and each document's ``vector`` field.

    Args:
        text: Query string to embed and search for.
        size: Maximum number of hits to return. Defaults to 5 (the value
            previously hard-coded).
        index: Elasticsearch index to search. Defaults to
            ``'knowledge_base'`` (the value previously hard-coded).

    Returns:
        A list with the ``text`` field of each hit, in the order returned
        by Elasticsearch.
    """
    vector = text_to_embedding(text)

    script_query = {
        "script_score": {
            # match_all means every document in the index is scored by the
            # script, i.e. a full scan rather than an approximate kNN lookup.
            # NOTE(review): per the Elasticsearch docs, cosineSimilarity raises
            # an error for a document that has no value in the vector field.
            # If the index can contain such documents, restrict this query
            # (e.g. with an exists filter). Confirm against how the index is built.
            "query": {"match_all": {}},
            "script": {
                # cosineSimilarity is in [-1, 1]; adding 1.0 keeps the score
                # non-negative, which script_score requires.
                "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                "params": {"query_vector": vector},
            },
        }
    }

    response = es.search(
        index=index,
        body={
            "size": size,
            "query": script_query,
            # Only the text field is used below, so fetch nothing else.
            "_source": ["text"],
        },
    )
    return [hit['_source']['text'] for hit in response['hits']['hits']]
|
|
|
|
def text_search(text, size=10, index='raw_texts'):
    """Return the stored texts that contain *text* as a phrase.

    Uses a ``match_phrase`` query on the ``text`` field. The query string is
    analyzed and must match as a contiguous sequence of terms, so this is
    phrase matching rather than a byte-for-byte exact comparison.

    Args:
        text: Phrase to look for.
        size: Maximum number of hits to return. Defaults to 10, which is
            Elasticsearch's default page size. Results are therefore the
            same as when no size was sent.
        index: Elasticsearch index to search. Defaults to ``'raw_texts'``
            (the value previously hard-coded).

    Returns:
        A list with the ``text`` field of each matching hit, in the order
        returned by Elasticsearch.
    """
    response = es.search(
        index=index,
        body={
            "size": size,
            "query": {
                "match_phrase": {
                    "text": text,
                }
            },
            # Only the text field is read from each hit, so don't transfer
            # the rest of the stored document.
            "_source": ["text"],
        },
    )
    return [hit['_source']['text'] for hit in response['hits']['hits']]
|
|
|
|
def test_queries(file_path, num_samples=5):
    """Run both searches on randomly chosen lines of a file and print results.

    Each non-blank line of the file, stripped of surrounding whitespace, is
    a candidate query. Up to *num_samples* distinct lines are drawn at random.
    For each one, the ``vector_search`` and ``text_search`` results are
    printed, followed by a separator line.

    Args:
        file_path: Path to a UTF-8 encoded text file.
        num_samples: Number of lines to sample. Defaults to 5 (the value
            previously hard-coded). If the file has fewer non-blank lines,
            all of them are used.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # Strip each line once and keep only the non-empty results.
        sentences = [s for s in (line.strip() for line in f) if s]

    # min() stops random.sample from raising when the file is short.
    test_samples = random.sample(sentences, min(num_samples, len(sentences)))

    for sample in test_samples:
        print(f"测试句子: {sample}")
        print("向量搜索结果:")
        for result in vector_search(sample):
            print(f"- {result}")

        print("\n文本精确搜索结果:")
        for result in text_search(sample):
            print(f"- {result}")
        print("="*50)
|
|
|
|
if __name__ == "__main__":
    import sys

    # The sample file may be passed as the first command-line argument.
    # Otherwise the original hard-coded path is used. That default is
    # relative, so it is resolved against the current working directory.
    default_path = "../Txt/人口变化趋势对云南教育的影响.txt"
    test_queries(sys.argv[1] if len(sys.argv) > 1 else default_path)