This commit is contained in:
2025-08-19 11:56:23 +08:00
parent ebc42552c4
commit 0b7f846637
2 changed files with 92 additions and 15 deletions

View File

@@ -1,13 +1,14 @@
# pip install jieba
import json import json
import logging import logging
import uuid import time
from datetime import datetime import jieba
import fastapi import fastapi
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI from openai import AsyncOpenAI
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
import uuid
from Config import Config from Config import Config
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
@@ -16,6 +17,8 @@ from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# Stop words removed from jieba tokens before they are turned into
# "keyword:" tags for a stored query or answer.
# NOTE(review): in the reviewed text every single-character entry was an
# empty string (apparently lost during extraction), which collapsed the set
# to {'', '一个', '没有', '自己'}. The entries below are restored from the
# commonly used Chinese stop-word sample whose layout matches the surviving
# multi-character words — confirm against the repository copy.
STOPWORDS = {
    '的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一',
    '一个', '上', '也', '很', '到', '说', '要', '去', '你', '会', '着',
    '没有', '看', '好', '自己', '这',
}
# 初始化异步 OpenAI 客户端 # 初始化异步 OpenAI 客户端
client = AsyncOpenAI( client = AsyncOpenAI(
@@ -26,6 +29,12 @@ client = AsyncOpenAI(
# 初始化 ElasticSearch 工具 # 初始化 ElasticSearch 工具
search_util = EsSearchUtil(Config.ES_CONFIG) search_util = EsSearchUtil(Config.ES_CONFIG)
# Maximum number of conversation rounds kept (and replayed) per session.
MAX_HISTORY_ROUNDS = 10
# In-memory conversation history: the key is the session ID and the value is
# that session's list of conversation rounds.
conversation_history = dict()
def get_system_prompt(): def get_system_prompt():
"""获取系统提示""" """获取系统提示"""
@@ -50,34 +59,76 @@ app = FastAPI(lifespan=lifespan)
@app.post("/api/chat") @app.post("/api/chat")
async def chat(request: fastapi.Request): async def chat(request: fastapi.Request):
""" """
根据用户输入的语句,通过关键字和向量两种方式查询相关信息 根据用户输入的语句,查询相关历史对话
然后调用大模型进行回答 然后调用大模型进行回答
""" """
try: try:
data = await request.json() data = await request.json()
user_id = data.get('user_id', 'anonymous') user_id = data.get('user_id', 'anonymous')
query = data.get('query', '') query = data.get('query', '')
session_id = data.get('session_id', str(uuid.uuid4())) # 获取或生成会话ID
include_history = data.get('include_history', True)
if not query: if not query:
raise HTTPException(status_code=400, detail="查询内容不能为空") raise HTTPException(status_code=400, detail="查询内容不能为空")
# 获取系统提示词 # 1. 初始化会话历史
if session_id not in conversation_history:
conversation_history[session_id] = []
# 2. 为用户查询生成标签并存储到ES
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]
# 提取查询中的关键词作为额外标签 - 使用jieba分词
try:
seg_list = jieba.cut(query, cut_all=False) # 精确模式
keywords = [kw for kw in seg_list if kw.strip() and kw not in STOPWORDS and len(kw) > 1]
keywords = keywords[:5]
tags.extend([f"keyword:{kw}" for kw in keywords])
logger.info(f"使用jieba分词提取的关键词: {keywords}")
except Exception as e:
logger.error(f"分词失败: {str(e)}")
keywords = query.split()[:5]
tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])
# 存储查询到ES
try:
search_util.insert_long_text_to_es(query, tags)
logger.info(f"用户 {user_id} 的查询已存储到ES标签: {tags}")
except Exception as e:
logger.error(f"存储用户查询到ES失败: {str(e)}")
# 3. 构建对话历史上下文
history_context = ""
if include_history and session_id in conversation_history:
# 获取最近的几次对话历史
recent_history = conversation_history[session_id][-MAX_HISTORY_ROUNDS:]
if recent_history:
history_context = "\n\n以下是最近的对话历史,可供参考:\n"
for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
history_context += f"[对话 {i}] 用户: {user_msg}\n"
history_context += f"[对话 {i}] 老师: {ai_msg}\n"
# 4. 构建提示词
system_prompt = get_system_prompt() system_prompt = get_system_prompt()
prompt = f"""
{system_prompt.strip()}
用户现在的问题是: '{query}'
"""
# 5. 流式调用大模型生成回答 # 5. 流式调用大模型生成回答
async def generate_response_stream(): async def generate_response_stream():
try: try:
# 构建消息列表
messages = [{'role': 'system', 'content': system_prompt.strip()}]
# 添加历史对话(如果有)
if history_context:
messages.append({'role': 'user', 'content': history_context.strip()})
# 添加当前问题
messages.append({'role': 'user', 'content': query})
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
model=Config.ALY_LLM_MODEL_NAME, model=Config.ALY_LLM_MODEL_NAME,
messages=[ messages=messages,
{'role': 'user', 'content': prompt}
],
max_tokens=8000, max_tokens=8000,
stream=True stream=True
) )
@@ -89,6 +140,32 @@ async def chat(request: fastapi.Request):
full_answer.append(chunk.choices[0].delta.content) full_answer.append(chunk.choices[0].delta.content)
yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
# 保存回答到ES和对话历史
if full_answer:
answer_text = ''.join(full_answer)
try:
# 为回答添加标签
answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
try:
seg_list = jieba.cut(answer_text, cut_all=False)
answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in STOPWORDS and len(kw) > 1]
answer_keywords = answer_keywords[:5]
answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
except Exception as e:
logger.error(f"回答分词失败: {str(e)}")
search_util.insert_long_text_to_es(answer_text, answer_tags)
logger.info(f"用户 {user_id} 的回答已存储到ES")
# 更新对话历史
conversation_history[session_id].append((query, answer_text))
# 保持历史记录不超过最大轮数
if len(conversation_history[session_id]) > MAX_HISTORY_ROUNDS:
conversation_history[session_id].pop(0)
except Exception as e:
logger.error(f"存储回答到ES失败: {str(e)}")
except Exception as e: except Exception as e:
logger.error(f"大模型调用失败: {str(e)}") logger.error(f"大模型调用失败: {str(e)}")
yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n" yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"
@@ -104,6 +181,5 @@ async def chat(request: fastapi.Request):
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -105,6 +105,7 @@ def main():
print("请输入您的问题,比如:帮我讲解一下勾股定理的证明。输入'退出'结束对话") print("请输入您的问题,比如:帮我讲解一下勾股定理的证明。输入'退出'结束对话")
print("===========================") print("===========================")
logger.info("教育助手对话系统已启动") logger.info("教育助手对话系统已启动")
logger.info(f"当前会话ID: {SESSION_ID}")
while True: while True:
try: try: