'commit'
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
# pip install jieba
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import time
|
||||
import jieba
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from sse_starlette import EventSourceResponse
|
||||
import uuid
|
||||
|
||||
from Config import Config
|
||||
from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
|
||||
@@ -16,6 +17,8 @@ from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# 初始化停用词表
|
||||
STOPWORDS = set(['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这'])
|
||||
|
||||
# 初始化异步 OpenAI 客户端
|
||||
client = AsyncOpenAI(
|
||||
@@ -26,6 +29,12 @@ client = AsyncOpenAI(
|
||||
# 初始化 ElasticSearch 工具
|
||||
search_util = EsSearchUtil(Config.ES_CONFIG)
|
||||
|
||||
# 存储对话历史的字典,键为会话ID,值为对话历史列表
|
||||
conversation_history = {}
|
||||
|
||||
# 最大对话历史轮数
|
||||
MAX_HISTORY_ROUNDS = 10
|
||||
|
||||
|
||||
def get_system_prompt():
|
||||
"""获取系统提示"""
|
||||
@@ -50,34 +59,76 @@ app = FastAPI(lifespan=lifespan)
|
||||
@app.post("/api/chat")
|
||||
async def chat(request: fastapi.Request):
|
||||
"""
|
||||
根据用户输入的语句,通过关键字和向量两种方式查询相关信息
|
||||
根据用户输入的语句,查询相关历史对话
|
||||
然后调用大模型进行回答
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
user_id = data.get('user_id', 'anonymous')
|
||||
query = data.get('query', '')
|
||||
session_id = data.get('session_id', str(uuid.uuid4())) # 获取或生成会话ID
|
||||
include_history = data.get('include_history', True)
|
||||
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="查询内容不能为空")
|
||||
|
||||
# 获取系统提示词
|
||||
# 1. 初始化会话历史
|
||||
if session_id not in conversation_history:
|
||||
conversation_history[session_id] = []
|
||||
|
||||
# 2. 为用户查询生成标签并存储到ES
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]
|
||||
|
||||
# 提取查询中的关键词作为额外标签 - 使用jieba分词
|
||||
try:
|
||||
seg_list = jieba.cut(query, cut_all=False) # 精确模式
|
||||
keywords = [kw for kw in seg_list if kw.strip() and kw not in STOPWORDS and len(kw) > 1]
|
||||
keywords = keywords[:5]
|
||||
tags.extend([f"keyword:{kw}" for kw in keywords])
|
||||
logger.info(f"使用jieba分词提取的关键词: {keywords}")
|
||||
except Exception as e:
|
||||
logger.error(f"分词失败: {str(e)}")
|
||||
keywords = query.split()[:5]
|
||||
tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])
|
||||
|
||||
# 存储查询到ES
|
||||
try:
|
||||
search_util.insert_long_text_to_es(query, tags)
|
||||
logger.info(f"用户 {user_id} 的查询已存储到ES,标签: {tags}")
|
||||
except Exception as e:
|
||||
logger.error(f"存储用户查询到ES失败: {str(e)}")
|
||||
|
||||
# 3. 构建对话历史上下文
|
||||
history_context = ""
|
||||
if include_history and session_id in conversation_history:
|
||||
# 获取最近的几次对话历史
|
||||
recent_history = conversation_history[session_id][-MAX_HISTORY_ROUNDS:]
|
||||
if recent_history:
|
||||
history_context = "\n\n以下是最近的对话历史,可供参考:\n"
|
||||
for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
|
||||
history_context += f"[对话 {i}] 用户: {user_msg}\n"
|
||||
history_context += f"[对话 {i}] 老师: {ai_msg}\n"
|
||||
|
||||
# 4. 构建提示词
|
||||
system_prompt = get_system_prompt()
|
||||
|
||||
prompt = f"""
|
||||
{system_prompt.strip()}
|
||||
|
||||
用户现在的问题是: '{query}'
|
||||
"""
|
||||
|
||||
# 5. 流式调用大模型生成回答
|
||||
async def generate_response_stream():
|
||||
try:
|
||||
# 构建消息列表
|
||||
messages = [{'role': 'system', 'content': system_prompt.strip()}]
|
||||
|
||||
# 添加历史对话(如果有)
|
||||
if history_context:
|
||||
messages.append({'role': 'user', 'content': history_context.strip()})
|
||||
|
||||
# 添加当前问题
|
||||
messages.append({'role': 'user', 'content': query})
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=Config.ALY_LLM_MODEL_NAME,
|
||||
messages=[
|
||||
{'role': 'user', 'content': prompt}
|
||||
],
|
||||
messages=messages,
|
||||
max_tokens=8000,
|
||||
stream=True
|
||||
)
|
||||
@@ -89,6 +140,32 @@ async def chat(request: fastapi.Request):
|
||||
full_answer.append(chunk.choices[0].delta.content)
|
||||
yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 保存回答到ES和对话历史
|
||||
if full_answer:
|
||||
answer_text = ''.join(full_answer)
|
||||
try:
|
||||
# 为回答添加标签
|
||||
answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
|
||||
try:
|
||||
seg_list = jieba.cut(answer_text, cut_all=False)
|
||||
answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in STOPWORDS and len(kw) > 1]
|
||||
answer_keywords = answer_keywords[:5]
|
||||
answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
|
||||
except Exception as e:
|
||||
logger.error(f"回答分词失败: {str(e)}")
|
||||
|
||||
search_util.insert_long_text_to_es(answer_text, answer_tags)
|
||||
logger.info(f"用户 {user_id} 的回答已存储到ES")
|
||||
|
||||
# 更新对话历史
|
||||
conversation_history[session_id].append((query, answer_text))
|
||||
# 保持历史记录不超过最大轮数
|
||||
if len(conversation_history[session_id]) > MAX_HISTORY_ROUNDS:
|
||||
conversation_history[session_id].pop(0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储回答到ES失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"大模型调用失败: {str(e)}")
|
||||
yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"
|
||||
@@ -104,6 +181,5 @@ async def chat(request: fastapi.Request):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
@@ -105,6 +105,7 @@ def main():
|
||||
print("请输入您的问题,比如:帮我讲解一下勾股定理的证明。输入'退出'结束对话")
|
||||
print("===========================")
|
||||
logger.info("教育助手对话系统已启动")
|
||||
logger.info(f"当前会话ID: {SESSION_ID}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
Reference in New Issue
Block a user