From f33c6ac544d484b4c9c98369298c923c20c50315 Mon Sep 17 00:00:00 2001
From: HuangHai <10402852@qq.com>
Date: Tue, 19 Aug 2025 10:51:04 +0800
Subject: [PATCH] 'commit'

---
 dsSchoolBuddy/Start.py | 317 +++++++++++++++++++++++++----------------
 1 file changed, 198 insertions(+), 119 deletions(-)

diff --git a/dsSchoolBuddy/Start.py b/dsSchoolBuddy/Start.py
index 516d000d..848785bd 100644
--- a/dsSchoolBuddy/Start.py
+++ b/dsSchoolBuddy/Start.py
@@ -1,170 +1,249 @@
 import json
-import subprocess
-import tempfile
-import urllib.parse
+import logging
 import uuid
 import warnings
-from io import BytesIO
+from datetime import datetime
 
 import fastapi
 import uvicorn
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Depends
 from openai import AsyncOpenAI
 from sse_starlette import EventSourceResponse
-from starlette.responses import StreamingResponse
-from starlette.staticfiles import StaticFiles
 
 from Config import Config
-from ElasticSearch.Utils.EsSearchUtil import *
-from Util.MySQLUtil import init_mysql_pool
+from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
 
 # Initialize logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# Configure log handlers
-log_file = os.path.join(os.path.dirname(__file__), 'Logs', 'app.log')
-os.makedirs(os.path.dirname(log_file), exist_ok=True)
-
-# File handler
-file_handler = RotatingFileHandler(
-    log_file, maxBytes=1024 * 1024, backupCount=5, encoding='utf-8')
-file_handler.setFormatter(logging.Formatter(
-    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-
-# Console handler
-console_handler = logging.StreamHandler()
-console_handler.setFormatter(logging.Formatter(
-    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-
-logger.addHandler(file_handler)
-logger.addHandler(console_handler)
 
 # Initialize the async OpenAI client
 client = AsyncOpenAI(
-    api_key=Config.MODEL_API_KEY,
-    base_url=Config.MODEL_API_URL,
+    api_key=Config.ALY_LLM_API_KEY,
+    base_url=Config.ALY_LLM_MODEL_NAME,
 )
 
+# Initialize the ElasticSearch utility
+search_util = EsSearchUtil(Config.ES_CONFIG)
+
 
 async def lifespan(app: FastAPI):
     # Suppress HTTPS-related warnings
     warnings.filterwarnings('ignore', message='Connecting to .* using TLS with verify_certs=False is insecure')
     warnings.filterwarnings('ignore', message='Unverified HTTPS request is being made to host')
     yield
 
-
 app = FastAPI(lifespan=lifespan)
 
-# Mount the static file directory
-app.mount("/static", StaticFiles(directory="Static"), name="static")
-
-
-@app.post("/api/save-word")
-async def save_to_word(request: fastapi.Request):
-    output_file = None
+@app.post("/api/chat", dependencies=[Depends(verify_api_key)])
+async def chat(request: fastapi.Request):
+    """
+    Query related information for the user's input via both keyword
+    and vector search, then call the LLM to generate an answer.
+    """
     try:
-        # Parse request data
-        try:
-            data = await request.json()
-            markdown_content = data.get('markdown_content', '')
-            if not markdown_content:
-                raise ValueError("Empty MarkDown content")
-        except Exception as e:
-            logger.error(f"Request parsing failed: {str(e)}")
-            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")
+        data = await request.json()
+        user_id = data.get('user_id', 'anonymous')
+        query = data.get('query', '')
+        query_tags = data.get('tags', [])
+        session_id = data.get('session_id', str(uuid.uuid4()))
+        include_history = data.get('include_history', True)
 
-        # Create a temporary Markdown file
-        temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
-        with open(temp_md, "w", encoding="utf-8") as f:
-            f.write(markdown_content)
+        if not query:
+            raise HTTPException(status_code=400, detail="Query content cannot be empty")
 
-        # Convert with pandoc
-        output_file = os.path.join(tempfile.gettempdir(), "【理想大模型】问答.docx")
-        subprocess.run(['pandoc', temp_md, '-o', output_file, '--resource-path=static'], check=True)
+        # 1. Save the current query to ES
+        await save_query_to_es(user_id, session_id, query, query_tags)
 
-        # Read the generated Word file
-        with open(output_file, "rb") as f:
-            stream = BytesIO(f.read())
+        # 2. Fetch related history records
+        history_context = ""
+        if include_history:
+            history_results = await get_related_history_from_es(user_id, query)
+            if history_results:
+                history_context = "\n".join([
+                    f"Historical question {i+1}: {item['query']}\nHistorical answer: {item['answer']}"
+                    for i, item in enumerate(history_results[:3])  # take the 3 most recent related entries
+                ])
+                logger.info(f"Found {len(history_results)} related history records")
 
-        # Return the response
-        encoded_filename = urllib.parse.quote("【理想大模型】问答.docx")
-        return StreamingResponse(
-            stream,
-            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
-            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"})
+        # 3. Run a hybrid search against ES
+        logger.info(f"Starting hybrid search: query={query}")
+        search_results = await search_by_mixed(query, query_tags)
 
-    except HTTPException:
-        raise
-    except Exception as e:
-        logger.error(f"Unexpected error: {str(e)}")
-        raise HTTPException(status_code=500, detail="Internal server error")
-    finally:
-        # Clean up temporary files
-        try:
-            if temp_md and os.path.exists(temp_md):
-                os.remove(temp_md)
-            if output_file and os.path.exists(output_file):
-                os.remove(output_file)
-        except Exception as e:
-            logger.warning(f"Failed to clean up temp files: {str(e)}")
+        # 4. Build the prompt
+        context = ""
+        if search_results:
+            context = "\n".join([
+                f"Search result {i+1}: {res['_source']['user_input']}"
+                for i, res in enumerate(search_results[:5])  # take the top 5 search results
+            ])
 
+        # Combine history and search results into the full context
+        full_context = ""
+        if history_context:
+            full_context += f"Related historical dialogue:\n{history_context}\n\n"
+        if context:
+            full_context += f"Related knowledge:\n{context}"
 
-@app.post("/api/rag", response_model=None)
-async def rag(request: fastapi.Request):
-    data = await request.json()
-    query = data.get('query', '')
-    query_tags = data.get('tags', [])
-    # Run a hybrid search against ES
-    search_results = EsSearchUtil.queryByEs(query, query_tags, logger)
-    # Build the prompt
-    context = "\n".join([
-        f"Result {i + 1}: {res['tags']['full_content']}"
-        for i, res in enumerate(search_results['text_results'])
-    ])
-    # Add an image-recognition hint
-    prompt = f"""
+        if not full_context:
+            full_context = "No related information found"
+
+        prompt = f"""
     Information retrieval and answering assistant
-    Based on the following information about '{query}':
+        The user's current question is: '{query}'
 
-    Related information
-    {context}
+        {full_context}
 
-    Answer requirements
-    1. For formula content:
-      - Use inline format: $formula$
+        Answer requirements:
+        1. For formula content:
+           - Use inline format: $formula$
       - Important formulas may be shown on their own line
-      - Never use code-block format (``` or ''')
+           - Never use code-block format (``` or ''')
      - \large may be used where appropriate to enlarge formulas
-      - If the content contains a mathematical formula, use inline format, e.g. $f(x) = x^2$
-      - If the content contains multiple formulas, use inline format, e.g. $f(x) = x^2$ $g(x) = x^3$
     2. If no material is provided, refuse to answer outright and state that the question is outside the knowledge scope.
     3. If none of the provided material is relevant to the question, refuse to answer and state that it is outside the knowledge scope.
     4. If only part of the provided material matches the question, extract only the useful, relevant parts and ignore the rest.
-    5. For material that matches the question and includes images, try to keep the images in context and preserve their clarity.
-    """
+        5. Base the answer on the provided material; do not fabricate information.
+        6. Use the user's historical questions and answers to provide a more coherent reply.
+        """
 
-    async def generate_response_stream():
-        try:
-            # Stream the LLM call
-            stream = await client.chat.completions.create(
-                model=Config.MODEL_NAME,
-                messages=[
-                    {'role': 'user', 'content': prompt}
-                ],
-                max_tokens=8000,
-                stream=True  # enable streaming mode
-            )
-            # Stream back the model's reply
-            async for chunk in stream:
-                if chunk.choices[0].delta.content:
-                    yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
-
-        except Exception as e:
-            yield f"data: {json.dumps({'error': str(e)})}\n\n"
-
-    return EventSourceResponse(generate_response_stream())
+        # 5. Call the LLM and stream the generated answer
+        async def generate_response_stream():
+            try:
+                stream = await client.chat.completions.create(
+                    model=Config.MODEL_NAME,
+                    messages=[
+                        {'role': 'user', 'content': prompt}
+                    ],
+                    max_tokens=8000,
+                    stream=True
+                )
+
+                # Collect the full answer so it can be saved afterwards
+                full_answer = []
+                async for chunk in stream:
+                    if chunk.choices[0].delta.content:
+                        full_answer.append(chunk.choices[0].delta.content)
+                        yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
+
+                # Save the answer to ES
+                if full_answer:
+                    await save_answer_to_es(user_id, session_id, query, ''.join(full_answer))
+            except Exception as e:
+                logger.error(f"LLM call failed: {str(e)}")
+                yield f"data: {json.dumps({'error': f'Failed to generate answer: {str(e)}'})}\n\n"
+
+        return EventSourceResponse(generate_response_stream())
+
+    except HTTPException as e:
+        logger.error(f"Chat endpoint error: {str(e.detail)}")
+        raise e
+    except Exception as e:
+        logger.error(f"Chat endpoint exception: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Failed to process request: {str(e)}")
+
+
+async def save_query_to_es(user_id, session_id, query, query_tags):
+    """Save a query record to Elasticsearch"""
+    try:
+        # Generate the query embedding
+        query_embedding = search_util.get_query_embedding(query)
+        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+        # Prepare the document data
+        doc = {
+            'user_id': user_id,
+            'session_id': session_id,
+            'query': query,
+            'query_embedding': query_embedding,
+            'tags': query_tags,
+            'created_at': timestamp,
+            'type': 'query'
+        }
+
+        # Save to ES
+        # Assumes EsSearchUtil provides a save_document method;
+        # if it does not, the method needs to be implemented in EsSearchUtil
+        await search_util.save_document('user_queries', doc)
+        logger.info(f"Saved user query record: user_id={user_id}, query={query}")
+    except Exception as e:
+        logger.error(f"Failed to save query record: {str(e)}")
+
+
+async def save_answer_to_es(user_id, session_id, query, answer):
+    """Save an answer record to Elasticsearch"""
+    try:
+        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+        # Prepare the document data
+        doc = {
+            'user_id': user_id,
+            'session_id': session_id,
+            'query': query,
+            'answer': answer,
+            'created_at': timestamp,
+            'type': 'answer'
+        }
+
+        # Save to ES
+        await search_util.save_document('user_answers', doc)
+        logger.info(f"Saved answer: user_id={user_id}, query={query}")
+    except Exception as e:
+        logger.error(f"Failed to save answer: {str(e)}")
+
+
+async def get_related_history_from_es(user_id, query):
+    """Fetch related history records from Elasticsearch"""
+    try:
+        # Generate the query embedding
+        query_embedding = search_util.get_query_embedding(query)
+
+        # Search ES for related queries
+        # Assumes EsSearchUtil provides a search_related_queries method;
+        # if it does not, the method needs to be implemented in EsSearchUtil
+        related_queries = await search_util.search_related_queries(
+            user_id=user_id,
+            query_embedding=query_embedding,
+            size=50
+        )
+
+        if not related_queries:
+            return []
+
+        # Sort the results by similarity
+        related_queries.sort(key=lambda x: x['similarity'], reverse=True)
+
+        return related_queries[:3]  # return the 3 most relevant records
+    except Exception as e:
+        logger.error(f"Failed to fetch related history: {str(e)}")
+        return []
+
+
+async def search_by_mixed(query, query_tags):
+    """Hybrid keyword and vector search"""
+    try:
+        # 1. Vector search
+        query_embedding = search_util.get_query_embedding(query)
+        vector_results = search_util.search_by_vector(query_embedding, k=10)
+
+        # 2. Keyword search
+        keyword_results = search_util.text_search(query, size=10)
+        keyword_hits = keyword_results['hits']['hits'] if 'hits' in keyword_results and 'hits' in keyword_results['hits'] else []
+
+        # 3. Merge the keyword and vector results
+        keyword_results_with_scores = [(doc, doc['_score']) for doc in keyword_hits]
+        vector_results_with_scores = [(doc, doc['_score']) for doc in vector_results]
+
+        merged_results = search_util.merge_results(keyword_results_with_scores, vector_results_with_scores)
+
+        # 4. Extract the documents
+        return [item[0] for item in merged_results[:10]]  # return the top 10 results
+    except Exception as e:
+        logger.error(f"Hybrid search failed: {str(e)}")
+        return []
 
 
 if __name__ == "__main__":
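
The new /api/chat route is registered with dependencies=[Depends(verify_api_key)], but the patch neither imports nor defines verify_api_key, so Start.py would raise a NameError at import time. Below is a minimal sketch of such a dependency, assuming a Bearer token in the Authorization header and a hypothetical Config.API_KEY setting; the module's existing "from Config import Config" import already makes Config available.

from fastapi import Header, HTTPException

async def verify_api_key(authorization: str = Header(None)):
    # Reject requests that do not carry the expected key.
    # Config.API_KEY is an assumed setting; point this at wherever the project keeps its key.
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing or malformed Authorization header")
    if authorization.removeprefix("Bearer ") != Config.API_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")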
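
save_query_to_es and save_answer_to_es await search_util.save_document, which the patch's own comments flag as an assumed method of EsSearchUtil. A sketch of an async-compatible implementation using the official elasticsearch-py AsyncElasticsearch client follows; the aclient attribute and the constructor shape are assumptions, and the real EsSearchUtil may wrap a synchronous client instead.

from elasticsearch import AsyncElasticsearch

class EsSearchUtil:
    def __init__(self, es_config):
        # es_config is assumed to be a dict of AsyncElasticsearch keyword arguments
        self.aclient = AsyncElasticsearch(**es_config)

    async def save_document(self, index_name: str, doc: dict) -> str:
        # Index a single document; ES generates the _id, returned here for reference
        resp = await self.aclient.index(index=index_name, document=doc)
        return resp['_id']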
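
get_related_history_from_es likewise depends on an assumed search_related_queries method and expects each hit to expose query, answer, and similarity. Continuing the EsSearchUtil sketch above, one way to provide it is an Elasticsearch script_score query over the stored query_embedding (mapped as dense_vector). Note the mismatch it has to bridge: the patch stores embeddings only in user_queries while answers live in user_answers, so either the answer documents must also carry the embedding (as assumed here) or the two indexes must be joined by session_id.

    async def search_related_queries(self, user_id, query_embedding, size=50):
        # Score this user's stored records by cosine similarity to the new query;
        # user_id is assumed to be mapped as a keyword field.
        resp = await self.aclient.search(
            index='user_answers',
            size=size,
            query={
                'script_score': {
                    'query': {'bool': {'filter': [{'term': {'user_id': user_id}}]}},
                    'script': {
                        # +1.0 keeps the score non-negative, as ES requires
                        'source': "cosineSimilarity(params.vec, 'query_embedding') + 1.0",
                        'params': {'vec': query_embedding},
                    },
                }
            },
        )
        return [
            {
                'query': hit['_source']['query'],
                'answer': hit['_source']['answer'],
                'similarity': hit['_score'],
            }
            for hit in resp['hits']['hits']
        ]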
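
Finally, search_by_mixed hands search_util.merge_results two lists of (doc, score) pairs, one from BM25 keyword search and one from vector search. The patch does not show that method, and the two score scales are not directly comparable, so a rank-based fusion is a natural fit. The sketch below, continuing the same EsSearchUtil sketch, uses reciprocal rank fusion (RRF); it is one reasonable merging strategy, not necessarily the one EsSearchUtil actually implements.

    def merge_results(self, keyword_results, vector_results, k=60):
        # Rank each list independently and sum 1 / (k + rank) per document,
        # so neither score scale dominates; documents are keyed by ES _id.
        fused = {}
        for results in (keyword_results, vector_results):
            for rank, (doc, _score) in enumerate(results, start=1):
                entry = fused.setdefault(doc['_id'], [doc, 0.0])
                entry[1] += 1.0 / (k + rank)
        # Highest fused score first, preserving the (doc, score) shape the caller expects
        return sorted(fused.values(), key=lambda pair: pair[1], reverse=True)

With k = 60 (the constant from the original RRF formulation), a document ranked first in both lists scores 2/61, comfortably ahead of one that appears at the top of only a single list (1/61).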