@@ -1,13 +1,14 @@
+# pip install jieba
 import json
 import logging
-import uuid
-from datetime import datetime
-
+import time
+import jieba
 import fastapi
 import uvicorn
 from fastapi import FastAPI, HTTPException
 from openai import AsyncOpenAI
 from sse_starlette import EventSourceResponse
+import uuid
 
 from Config import Config
 from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
@@ -16,6 +17,8 @@ from ElasticSearch.Utils.EsSearchUtil import EsSearchUtil
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
+# Initialize the stop-word list
+STOPWORDS = set(['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这'])
 
 # Initialize the async OpenAI client
 client = AsyncOpenAI(
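As context for the STOPWORDS set added above, here is a minimal, hedged sketch of the keyword-tagging approach this commit applies further down in the diff: jieba segments the text in accurate mode, stop words and single characters are dropped, and up to five of the remaining tokens become keyword: tags. The helper name extract_keyword_tags and the shortened stop-word list are illustrative only; they are not part of the commit.

# Illustrative sketch only -- extract_keyword_tags is not defined in the commit;
# it mirrors the tagging logic added later in this diff.
import jieba

STOPWORDS = {'的', '了', '在', '是', '我'}  # shortened stand-in for the full stop-word set above

def extract_keyword_tags(text, limit=5):
    """Segment text with jieba and keep up to `limit` informative keywords as tags."""
    words = jieba.cut(text, cut_all=False)  # accurate mode, same as in the handler
    keywords = [w for w in words if w.strip() and w not in STOPWORDS and len(w) > 1]
    return [f"keyword:{kw}" for kw in keywords[:limit]]

# Prints something like ['keyword:勾股定理', 'keyword:证明', ...]
print(extract_keyword_tags("帮我讲解一下勾股定理的证明"))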
@@ -26,6 +29,12 @@ client = AsyncOpenAI(
 # Initialize the ElasticSearch helper
 search_util = EsSearchUtil(Config.ES_CONFIG)
 
+# Conversation history store: keyed by session ID, each value is a list of exchanges
+conversation_history = {}
+
+# Maximum number of conversation rounds to keep per session
+MAX_HISTORY_ROUNDS = 10
+
 
 def get_system_prompt():
     """Get the system prompt."""
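A small, hedged sketch of how the new in-memory history store is meant to behave (remember_exchange is an illustrative name, not part of the commit): each session ID maps to a list of (user message, answer) tuples, and the list is trimmed so at most MAX_HISTORY_ROUNDS exchanges are kept.

# Illustrative sketch only -- remember_exchange is not part of the commit; it shows the
# append-and-trim behaviour the endpoint applies to conversation_history below.
conversation_history = {}
MAX_HISTORY_ROUNDS = 10

def remember_exchange(session_id, user_msg, ai_msg):
    """Append one (question, answer) pair and drop the oldest beyond the limit."""
    history = conversation_history.setdefault(session_id, [])
    history.append((user_msg, ai_msg))
    if len(history) > MAX_HISTORY_ROUNDS:
        history.pop(0)

for i in range(12):
    remember_exchange("demo-session", f"question {i}", f"answer {i}")
print(len(conversation_history["demo-session"]))  # 10 -- only the most recent rounds remain

Because the store is a plain module-level dict, this history is per-process and lost on restart; the ElasticSearch writes added in this commit are the durable copy.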
@@ -50,34 +59,76 @@ app = FastAPI(lifespan=lifespan)
 @app.post("/api/chat")
 async def chat(request: fastapi.Request):
     """
-    Look up related information for the user's input via both keyword and vector search,
+    Look up related conversation history for the user's input,
     then call the LLM to generate the answer.
     """
     try:
         data = await request.json()
         user_id = data.get('user_id', 'anonymous')
         query = data.get('query', '')
+        session_id = data.get('session_id', str(uuid.uuid4()))  # get or create the session ID
+        include_history = data.get('include_history', True)
 
         if not query:
             raise HTTPException(status_code=400, detail="查询内容不能为空")
 
-        # Get the system prompt
+        # 1. Initialize the session history
+        if session_id not in conversation_history:
+            conversation_history[session_id] = []
+
+        # 2. Tag the user query and store it in ES
+        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        tags = [user_id, f"time:{current_time.split()[0]}", f"session:{session_id}"]
+
+        # Extract keywords from the query as extra tags, using jieba segmentation
+        try:
+            seg_list = jieba.cut(query, cut_all=False)  # accurate mode
+            keywords = [kw for kw in seg_list if kw.strip() and kw not in STOPWORDS and len(kw) > 1]
+            keywords = keywords[:5]
+            tags.extend([f"keyword:{kw}" for kw in keywords])
+            logger.info(f"使用jieba分词提取的关键词: {keywords}")
+        except Exception as e:
+            logger.error(f"分词失败: {str(e)}")
+            keywords = query.split()[:5]
+            tags.extend([f"keyword:{kw}" for kw in keywords if kw.strip()])
+
+        # Store the query in ES
+        try:
+            search_util.insert_long_text_to_es(query, tags)
+            logger.info(f"用户 {user_id} 的查询已存储到ES,标签: {tags}")
+        except Exception as e:
+            logger.error(f"存储用户查询到ES失败: {str(e)}")
+
+        # 3. Build the conversation history context
+        history_context = ""
+        if include_history and session_id in conversation_history:
+            # Take only the most recent rounds of history
+            recent_history = conversation_history[session_id][-MAX_HISTORY_ROUNDS:]
+            if recent_history:
+                history_context = "\n\n以下是最近的对话历史,可供参考:\n"
+                for i, (user_msg, ai_msg) in enumerate(recent_history, 1):
+                    history_context += f"[对话 {i}] 用户: {user_msg}\n"
+                    history_context += f"[对话 {i}] 老师: {ai_msg}\n"
+
+        # 4. Build the prompt
         system_prompt = get_system_prompt()
 
-        prompt = f"""
-{system_prompt.strip()}
-
-用户现在的问题是: '{query}'
-"""
 
         # 5. Call the LLM with streaming to generate the answer
         async def generate_response_stream():
             try:
+                # Build the message list
+                messages = [{'role': 'system', 'content': system_prompt.strip()}]
+
+                # Add the conversation history (if any)
+                if history_context:
+                    messages.append({'role': 'user', 'content': history_context.strip()})
+
+                # Add the current question
+                messages.append({'role': 'user', 'content': query})
+
                 stream = await client.chat.completions.create(
                     model=Config.ALY_LLM_MODEL_NAME,
-                    messages=[
-                        {'role': 'user', 'content': prompt}
-                    ],
+                    messages=messages,
                     max_tokens=8000,
                     stream=True
                 )
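The central change in this hunk is the move from a single flattened prompt string to a structured messages list: the system prompt, an optional user turn carrying the formatted history block, and then the current question. A hedged sketch of the list the revised handler ends up passing to client.chat.completions.create (the string values below are placeholders, not the commit's actual prompts):

# Placeholder values; in the handler these come from get_system_prompt(), the history
# loop above, and the request body.
system_prompt = "<system prompt text>"
history_context = "<formatted recent exchanges, empty when include_history is false>"
query = "帮我讲解一下勾股定理的证明"

messages = [{'role': 'system', 'content': system_prompt.strip()}]
if history_context:
    messages.append({'role': 'user', 'content': history_context.strip()})
messages.append({'role': 'user', 'content': query})

# messages is then passed as-is: client.chat.completions.create(..., messages=messages, stream=True)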
@@ -89,6 +140,32 @@ async def chat(request: fastapi.Request):
                     full_answer.append(chunk.choices[0].delta.content)
                     yield f"data: {json.dumps({'reply': chunk.choices[0].delta.content}, ensure_ascii=False)}\n\n"
 
+                # Save the answer to ES and to the conversation history
+                if full_answer:
+                    answer_text = ''.join(full_answer)
+                    try:
+                        # Tag the answer
+                        answer_tags = [f"{user_id}_answer", f"time:{current_time.split()[0]}", f"session:{session_id}"]
+                        try:
+                            seg_list = jieba.cut(answer_text, cut_all=False)
+                            answer_keywords = [kw for kw in seg_list if kw.strip() and kw not in STOPWORDS and len(kw) > 1]
+                            answer_keywords = answer_keywords[:5]
+                            answer_tags.extend([f"keyword:{kw}" for kw in answer_keywords])
+                        except Exception as e:
+                            logger.error(f"回答分词失败: {str(e)}")
+
+                        search_util.insert_long_text_to_es(answer_text, answer_tags)
+                        logger.info(f"用户 {user_id} 的回答已存储到ES")
+
+                        # Update the conversation history
+                        conversation_history[session_id].append((query, answer_text))
+                        # Keep the history within the maximum number of rounds
+                        if len(conversation_history[session_id]) > MAX_HISTORY_ROUNDS:
+                            conversation_history[session_id].pop(0)
+
+                    except Exception as e:
+                        logger.error(f"存储回答到ES失败: {str(e)}")
+
             except Exception as e:
                 logger.error(f"大模型调用失败: {str(e)}")
                 yield f"data: {json.dumps({'error': f'生成回答失败: {str(e)}'})}\n\n"
@@ -104,6 +181,5 @@ async def chat(request: fastapi.Request):
 
 
-
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
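A hedged example of calling the revised endpoint (not part of the commit): it assumes the service is running locally on port 8000 as in the __main__ block above, uses httpx as the client library, and reuses one session_id across requests so the server-side history is applied. The exact SSE framing produced by EventSourceResponse is not visible in this diff, so the parser below simply strips any leading data: prefixes before decoding each JSON chunk.

# Hedged example client -- not part of the commit. Assumes httpx is installed and the
# service is running on localhost:8000 as in the __main__ block above.
import asyncio
import json
import uuid

import httpx

SESSION_ID = str(uuid.uuid4())  # reuse one ID so the server keeps history across calls

async def ask(query):
    payload = {
        "user_id": "demo_user",
        "query": query,
        "session_id": SESSION_ID,
        "include_history": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/api/chat", json=payload) as resp:
            async for line in resp.aiter_lines():
                text = line.strip()
                while text.startswith("data:"):  # tolerate single or double SSE prefixes
                    text = text[len("data:"):].strip()
                if not text.startswith("{"):
                    continue
                chunk = json.loads(text)
                print(chunk.get("reply", chunk.get("error", "")), end="", flush=True)
    print()

asyncio.run(ask("帮我讲解一下勾股定理的证明"))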
A second file in the same commit, the command-line chat loop in main(), also gains a line logging the current session ID on startup:

@@ -105,6 +105,7 @@ def main():
     print("请输入您的问题,比如:帮我讲解一下勾股定理的证明。输入'退出'结束对话")
     print("===========================")
     logger.info("教育助手对话系统已启动")
+    logger.info(f"当前会话ID: {SESSION_ID}")
 
     while True:
         try: