You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 lines
3.2 KiB

"""
conda activate rag
pip install openai
"""
from elasticsearch import Elasticsearch
from openai import OpenAI
from Config import Config
# Initialize the Elasticsearch connection from the project configuration.
es = Elasticsearch(
    hosts=Config.ES_CONFIG["hosts"],
    basic_auth=Config.ES_CONFIG["basic_auth"],
    verify_certs=Config.ES_CONFIG["verify_certs"],
)
# Initialize the DeepSeek client (accessed through the OpenAI SDK class,
# pointed at the base URL taken from the project configuration).
client = OpenAI(
    api_key=Config.DEEPSEEK_API_KEY,
    base_url=Config.DEEPSEEK_URL,
)
def generate_report(query, context):
    """Generate a structured report with DeepSeek, streaming it to stdout.

    The prompt asks the model to organise ``context`` into a report about
    ``query``.  The completion is requested in streaming mode; each piece of
    text is printed as soon as it arrives and also collected for the return
    value.

    Args:
        query: The user's request; interpolated into the prompt as the topic.
        context: Reference text appended to the end of the prompt.

    Returns:
        The complete generated text.  If any exception is raised while
        requesting or consuming the stream, the error is printed and an
        empty string is returned instead (best-effort behaviour).
    """
    prompt = f"""根据以下关于'{query}'的相关信息,整理一份结构化的报告:
要求:
1. 分章节组织内容
2. 包含关键数据和事实
3. 语言简洁专业
相关信息:
{context}"""
    try:
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是一个专业的文档整理助手"},
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            stream=True,
        )
        # Handle the streamed output: collect the pieces in a list and join
        # once at the end instead of growing a string with ``+=``.
        pieces = []
        for chunk in response:
            # A stream event may carry an empty ``choices`` list (for example
            # a usage-only chunk).  Indexing it would raise IndexError, which
            # the handler below would turn into an empty report -- skip it.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                print(content, end="", flush=True)
                pieces.append(content)
        return "".join(pieces)
    except Exception as e:
        # Top-level boundary for the model call: report the failure and fall
        # back to an empty report rather than crashing the caller.
        print(f"生成报告时出错: {str(e)}")
        return ""
def process_query(query):
    """Search for data related to ``query`` and generate a report from it.

    Progress messages, including the number of search results found, are
    printed to stdout along the way.

    Args:
        query: The user's request, used both as the search text and as the
            report topic.

    Returns:
        Whatever ``generate_report`` returns for the query and the retrieved
        context (the report text, or an empty string on failure).
    """
    print(f"正在搜索与'{query}'相关的数据...")
    context = search_related_data(query)
    # Count results by their header lines.  ``search_related_data`` starts
    # every hit with one of these prefixes and ends it with a blank line, so
    # splitting on blank lines over-counts: the trailing separator yields an
    # extra empty piece (zero hits were reported as 1) and blank lines inside
    # a document's text add further pieces.
    result_headers = ("向量相似度结果(score=", "文本精确匹配结果(score=")
    hit_count = sum(
        line.startswith(result_headers) for line in context.splitlines()
    )
    print(f"找到{hit_count}条相关数据")
    print("正在生成报告...")
    return generate_report(query, context)
def search_related_data(query):
    """Collect Elasticsearch hits for ``query`` into one context string.

    Two searches are issued, each limited to 5 hits:

    * a ``match`` query on the ``content`` field of the configured default
      index, analysed with ``ik_smart``;
    * a ``match`` query on ``text.keyword`` in the ``raw_texts`` index.

    Every hit is rendered as a header line carrying its score, followed by
    the hit's ``_source['text']`` and a blank line.  Hits from the first
    search come before hits from the second.

    Args:
        query: The search text.

    Returns:
        The concatenated hit sections; an empty string when neither search
        returns any hit.
    """
    # NOTE(review): the first search is a plain full-text ``match`` even
    # though its results are labelled as vector-similarity results in the
    # output.  It also queries ``content`` but renders ``text`` -- confirm
    # the index mapping provides both fields.
    analyzed_match_request = {
        "query": {
            "match": {
                "content": {
                    "query": query,
                    "analyzer": "ik_smart",  # analyzer applied to the query text
                }
            }
        },
        "size": 5,
    }
    # Match against the un-analysed ``keyword`` sub-field of ``text``.
    keyword_match_request = {
        "query": {
            "match": {
                "text.keyword": query
            }
        },
        "size": 5,
    }

    # Run both searches before touching either response.
    analyzed_response = es.search(
        index=Config.ES_CONFIG['default_index'],
        body=analyzed_match_request,
    )
    keyword_response = es.search(
        index="raw_texts",
        body=keyword_match_request,
    )

    # Merge the results: first-search sections, then second-search sections.
    sections = [
        f"向量相似度结果(score={hit['_score']}):\n{hit['_source']['text']}\n\n"
        for hit in analyzed_response['hits']['hits']
    ]
    sections.extend(
        f"文本精确匹配结果(score={hit['_score']}):\n{hit['_source']['text']}\n\n"
        for hit in keyword_response['hits']['hits']
    )
    return "".join(sections)
if __name__ == "__main__":
    # Script entry point: run one fixed demo query and print the report.
    # The interactive prompt below is disabled in favour of that query:
    # user_query = input("请输入您的查询要求:")
    user_query = "整理云南省初中在校生情况文档"
    report = process_query(user_query)
    # Banner and report each end up on their own newline-terminated line,
    # exactly as two consecutive print() calls would produce.
    print("\n=== 生成的报告 ===\n", report, sep="\n")