# dsProject/dsSchoolBuddy/Start.py
import json
import logging
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime

import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse

from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Initialize the async OpenAI client
client = AsyncOpenAI(
    api_key=Config.ALY_LLM_API_KEY,
    base_url=Config.ALY_LLM_BASE_URL,
)

# Initialize the ElasticSearch utility
search_util = EsSearchUtil(Config.ES_CONFIG)


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/api/chat")
async def chat(request: fastapi.Request):
    """
    Look up information related to the user's query via both keyword and
    vector search, then call the LLM to generate an answer.
    """
    try:
        data = await request.json()
        user_id = data.get('user_id', 'anonymous')
        query = data.get('query', '')
        query_tags = data.get('tags', [])
        session_id = data.get('session_id', str(uuid.uuid4()))
        include_history = data.get('include_history', True)

        if not query:
            raise HTTPException(status_code=400, detail="Query content cannot be empty")

        # 1. Save the current query to ES
        await save_query_to_es(user_id, session_id, query, query_tags)

        # 2. Fetch related history records
        history_context = ""
        if include_history:
            history_results = await get_related_history_from_es(user_id, query)
            if history_results:
                history_context = "\n".join([
                    f"Past question {i+1}: {item['query']}\nPast answer: {item['answer']}"
                    for i, item in enumerate(history_results[:3])  # use the 3 most relevant history entries
                ])
                logger.info(f"Found {len(history_results)} related history entries")

        # 3. Run the mixed (keyword + vector) search in ES
        logger.info(f"Starting mixed search: query={query}")
        search_results = await search_by_mixed(query, query_tags)

        # 4. Build the prompt
        context = ""
        if search_results:
            context = "\n".join([
                f"Search result {i+1}: {res['_source']['user_input']}"
                for i, res in enumerate(search_results[:5])  # use the top 5 search results
            ])

        # Combine history and search results into the full context
        full_context = ""
        if history_context:
            full_context += f"Related past conversation:\n{history_context}\n\n"
        if context:
            full_context += f"Related knowledge:\n{context}"
        if not full_context:
            full_context = "No related information found"
prompt = f"""
2025-08-19 07:34:39 +08:00
信息检索与回答助手
2025-08-19 10:51:04 +08:00
用户现在的问题是: '{query}'
2025-08-19 07:34:39 +08:00
2025-08-19 10:51:04 +08:00
{full_context}
2025-08-19 07:34:39 +08:00
2025-08-19 10:51:04 +08:00
回答要求:
1. 对于公式内容:
- 使用行内格式: $公式$
2025-08-19 07:34:39 +08:00
- 重要公式可单独一行显示
2025-08-19 10:51:04 +08:00
- 绝对不要使用代码块格式(```''')
2025-08-19 07:34:39 +08:00
- 可适当使用\large增大公式字号
2. 如果没有提供任何资料那就直接拒绝回答明确不在知识范围内
3. 如果发现提供的资料与要询问的问题都不相关就拒绝回答明确不在知识范围内
4. 如果发现提供的资料中只有部分与问题相符那就只提取有用的相关部分其它部分请忽略
2025-08-19 10:51:04 +08:00
5. 回答要基于提供的资料不要编造信息
6. 请结合用户的历史问题和回答提供更连贯的回复
"""

        # 5. Stream the LLM answer
        async def generate_response_stream():
            try:
                stream = await client.chat.completions.create(
                    model=Config.MODEL_NAME,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    max_tokens=8000,
                    stream=True
                )

                # Collect the full answer so it can be saved afterwards
                full_answer = []
                async for chunk in stream:
                    if chunk.choices[0].delta.content:
                        full_answer.append(chunk.choices[0].delta.content)
                        # EventSourceResponse adds the SSE "data:" framing itself,
                        # so yield only the JSON payload
                        yield json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)

                # Save the answer to ES
                if full_answer:
                    await save_answer_to_es(user_id, session_id, query, ''.join(full_answer))
            except Exception as e:
                logger.error(f"LLM call failed: {str(e)}")
                yield json.dumps({'error': f'Failed to generate answer: {str(e)}'}, ensure_ascii=False)

        return EventSourceResponse(generate_response_stream())
    except HTTPException as e:
        logger.error(f"Chat endpoint error: {str(e.detail)}")
        raise e
    except Exception as e:
        logger.error(f"Chat endpoint exception: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to process request: {str(e)}")


async def save_query_to_es(user_id, session_id, query, query_tags):
    """Save a query record to Elasticsearch."""
    try:
        # Generate the query embedding
        query_embedding = search_util.get_query_embedding(query)
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        # Prepare the document
        doc = {
            'user_id': user_id,
            'session_id': session_id,
            'query': query,
            'query_embedding': query_embedding,
            'tags': query_tags,
            'created_at': timestamp,
            'type': 'query'
        }

        # Save to ES.
        # Assumes EsSearchUtil provides a save_document method; if it does not,
        # it needs to be implemented there (a sketch follows below).
        await search_util.save_document('user_queries', doc)
        logger.info(f"Saved user query record: user_id={user_id}, query={query}")
    except Exception as e:
        logger.error(f"Failed to save query record: {str(e)}")


async def save_answer_to_es(user_id, session_id, query, answer):
    """Save an answer record to Elasticsearch."""
    try:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        # Prepare the document
        doc = {
            'user_id': user_id,
            'session_id': session_id,
            'query': query,
            'answer': answer,
            'created_at': timestamp,
            'type': 'answer'
        }

        # Save to ES
        await search_util.save_document('user_answers', doc)
        logger.info(f"Saved answer: user_id={user_id}, query={query}")
    except Exception as e:
        logger.error(f"Failed to save answer: {str(e)}")


async def get_related_history_from_es(user_id, query):
    """Fetch related history records from Elasticsearch."""
    try:
        # Generate the query embedding
        query_embedding = search_util.get_query_embedding(query)

        # Search ES for related past queries.
        # Assumes EsSearchUtil provides a search_related_queries method; if it
        # does not, it needs to be implemented there (a sketch follows below).
        related_queries = await search_util.search_related_queries(
            user_id=user_id,
            query_embedding=query_embedding,
            size=50
        )
        if not related_queries:
            return []

        # Sort the results by similarity, highest first
        related_queries.sort(key=lambda x: x['similarity'], reverse=True)
        return related_queries[:3]  # return the 3 most relevant records
    except Exception as e:
        logger.error(f"Failed to fetch related history: {str(e)}")
        return []
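
# A minimal sketch of the assumed EsSearchUtil.search_related_queries method,
# scoring stored query embeddings with a script_score query. The index and
# field names mirror save_query_to_es above; `self.es` as an AsyncElasticsearch
# client is an assumption. Note that answers live in a separate user_answers
# index, so joining them onto each hit is left to the real implementation:
#
#     async def search_related_queries(self, user_id, query_embedding, size=50):
#         resp = await self.es.search(
#             index='user_queries',
#             size=size,
#             query={
#                 'script_score': {
#                     'query': {'term': {'user_id': user_id}},
#                     'script': {
#                         'source': "cosineSimilarity(params.v, 'query_embedding') + 1.0",
#                         'params': {'v': query_embedding}
#                     }
#                 }
#             }
#         )
#         return [
#             {**hit['_source'], 'similarity': hit['_score']}
#             for hit in resp['hits']['hits']
#         ]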


async def search_by_mixed(query, query_tags):
    """Mixed keyword and vector search."""
    try:
        # 1. Vector search
        query_embedding = search_util.get_query_embedding(query)
        vector_results = search_util.search_by_vector(query_embedding, k=10)

        # 2. Keyword search
        keyword_results = search_util.text_search(query, size=10)
        keyword_hits = keyword_results.get('hits', {}).get('hits', [])

        # 3. Merge the two result sets
        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
        vector_results_with_scores = [(doc, doc['_score']) for doc in vector_results]
        merged_results = search_util.merge_results(keyword_results_with_scores, vector_results_with_scores)

        # 4. Extract the documents
        return [item[0] for item in merged_results[:10]]  # return the top 10 results
    except Exception as e:
        logger.error(f"Mixed search failed: {str(e)}")
        return []
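
# A minimal sketch of the assumed EsSearchUtil.merge_results method, written
# here as reciprocal rank fusion (RRF); the real implementation may normalize
# raw scores or deduplicate differently:
#
#     def merge_results(self, keyword_scored, vector_scored, k=60):
#         fused = {}
#         for results in (keyword_scored, vector_scored):
#             for rank, (doc, _score) in enumerate(results):
#                 entry = fused.setdefault(doc['_id'], [doc, 0.0])
#                 entry[1] += 1.0 / (k + rank + 1)
#         # Sort by fused score, highest first; callers take item[0] (the doc)
#         return sorted(fused.values(), key=lambda e: e[1], reverse=True)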


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
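
# Example request against the endpoint above (values are illustrative):
#
#     curl -N -X POST http://localhost:8000/api/chat \
#          -H "Content-Type: application/json" \
#          -d '{"user_id": "u1", "query": "What is the Pythagorean theorem?", "include_history": true}'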