134 lines
4.7 KiB
Python
134 lines
4.7 KiB
Python
import json
|
||
import uuid
|
||
import time
|
||
import logging
|
||
|
||
import requests
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 服务器地址
|
||
BASE_URL = "http://localhost:8000"
|
||
CHAT_ENDPOINT = f"{BASE_URL}/api/chat"
|
||
|
||
# 用户ID(固定一个以便模拟多轮对话)
|
||
USER_ID = "test_user_123"
|
||
# 会话ID(固定一个以便模拟多轮对话)
|
||
SESSION_ID = str(uuid.uuid4())
|
||
# 防止重复发送的时间间隔(秒)
|
||
MIN_QUERY_INTERVAL = 1
|
||
# 上次查询时间
|
||
last_query_time = 0
|
||
|
||
|
||
def send_message(query):
|
||
"""发送消息到聊天API并接收流式响应"""
|
||
global last_query_time
|
||
current_time = time.time()
|
||
|
||
# 检查查询间隔
|
||
if current_time - last_query_time < MIN_QUERY_INTERVAL:
|
||
print(f"\n请等待 {MIN_QUERY_INTERVAL} 秒后再发送消息")
|
||
return
|
||
|
||
last_query_time = current_time
|
||
headers = {
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
data = {
|
||
"user_id": USER_ID,
|
||
"query": query,
|
||
"session_id": SESSION_ID,
|
||
"include_history": True # 包含历史记录
|
||
}
|
||
|
||
try:
|
||
print(f"\n你: {query}")
|
||
print("老师: ", end="", flush=True)
|
||
logger.info(f"发送查询: {query}")
|
||
|
||
# 发送POST请求,使用stream=True以流式接收响应
|
||
with requests.post(CHAT_ENDPOINT, json=data, headers=headers, stream=True, timeout=30) as response:
|
||
if response.status_code == 200:
|
||
# 逐行处理流式响应
|
||
for line in response.iter_lines():
|
||
if line:
|
||
# 去掉前缀'data: '
|
||
line_str = line.decode('utf-8').strip()
|
||
if line_str.startswith('data:'):
|
||
line_str = line_str.replace('data:', '')
|
||
|
||
# 处理[DONE]标记
|
||
if line_str == '[DONE]':
|
||
break
|
||
|
||
if not line_str: # 跳过空行
|
||
continue
|
||
|
||
try:
|
||
# 解析JSON
|
||
data = json.loads(line_str)
|
||
if 'reply' in data:
|
||
print(data['reply'], end="", flush=True)
|
||
elif 'error' in data:
|
||
print(f"\n错误: {data['error']}")
|
||
logger.error(f"API错误: {data['error']}")
|
||
else:
|
||
print(f"\n未知响应格式: {data}")
|
||
logger.warning(f"未知响应格式: {data}")
|
||
except json.JSONDecodeError:
|
||
# 非JSON格式的响应
|
||
print(f"\n无法解析的响应: {line_str}")
|
||
logger.error(f"无法解析的响应: {line_str}")
|
||
print() # 换行
|
||
else:
|
||
error_msg = f"请求失败,状态码: {response.status_code},错误信息: {response.text}"
|
||
print(f"\n{error_msg}")
|
||
logger.error(error_msg)
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
error_msg = f"请求异常: {str(e)}"
|
||
print(f"\n{error_msg}")
|
||
logger.error(error_msg)
|
||
except Exception as e:
|
||
error_msg = f"发生未知错误: {str(e)}"
|
||
print(f"\n{error_msg}")
|
||
logger.exception(error_msg)
|
||
|
||
|
||
def main():
|
||
"""主函数,处理用户输入并调用聊天API"""
|
||
print("===== 教育助手对话系统 =====")
|
||
print("请输入您的问题,比如:帮我讲解一下勾股定理的证明。输入'退出'结束对话")
|
||
print("===========================")
|
||
logger.info("教育助手对话系统已启动")
|
||
|
||
while True:
|
||
try:
|
||
# 获取用户输入
|
||
query = input("\n你: ")
|
||
|
||
# 检查是否退出
|
||
if query.strip() == '退出':
|
||
print("对话已结束,再见!")
|
||
logger.info("对话已结束")
|
||
break
|
||
|
||
# 发送消息
|
||
send_message(query)
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n程序被中断,再见!")
|
||
logger.info("程序被用户中断")
|
||
break
|
||
except Exception as e:
|
||
error_msg = f"发生错误: {str(e)}"
|
||
print(f"\n{error_msg}")
|
||
logger.exception(error_msg)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |