'commit'
This commit is contained in:
@@ -85,35 +85,43 @@ async def translate_to_ggb(request: Request):
|
|||||||
# 获取请求数据
|
# 获取请求数据
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
natural_language_cmd = data.get("natural_language_cmd")
|
natural_language_cmd = data.get("natural_language_cmd")
|
||||||
|
session_id = data.get("session_id")
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果JSON解析失败,尝试从表单数据中获取
|
# 如果JSON解析失败,尝试从表单数据中获取
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
natural_language_cmd = form.get("natural_language_cmd")
|
natural_language_cmd = form.get("natural_language_cmd")
|
||||||
|
session_id = form.get("session_id")
|
||||||
|
|
||||||
# 验证参数是否存在
|
# 验证参数是否存在
|
||||||
if not natural_language_cmd:
|
if not natural_language_cmd:
|
||||||
raise HTTPException(status_code=400, detail="缺少natural_language_cmd参数")
|
raise HTTPException(status_code=400, detail="缺少natural_language_cmd参数")
|
||||||
|
|
||||||
|
# 确保会话ID存在
|
||||||
|
if not session_id:
|
||||||
|
# 如果没有提供会话ID,生成一个临时ID(实际应用中应该使用更安全的方式)
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
logger.warning(f"未提供会话ID,生成临时ID: {session_id}")
|
||||||
|
|
||||||
async def generate_translation_stream():
|
async def generate_translation_stream():
|
||||||
try:
|
try:
|
||||||
logger.info(f"开始翻译自然语言指令: {natural_language_cmd}")
|
logger.info(f"开始翻译自然语言指令: {natural_language_cmd} (会话ID: {session_id})")
|
||||||
# 发送开始翻译的通知
|
# 发送开始翻译的通知
|
||||||
yield f"{json.dumps({'status': 'translating', 'content': '开始翻译指令...'}, ensure_ascii=False)}"
|
yield f"{json.dumps({'status': 'translating', 'content': '开始翻译指令...'}, ensure_ascii=False)}"
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
# 调用翻译函数并流式返回结果(使用别名)
|
# 调用翻译函数并流式返回结果(传递会话ID)
|
||||||
async for chunk in translate_to_ggb_util(natural_language_cmd):
|
async for chunk in translate_to_ggb_util(natural_language_cmd, session_id):
|
||||||
if chunk:
|
if chunk:
|
||||||
yield f"{json.dumps({'status': 'translating', 'content': chunk}, ensure_ascii=False)}"
|
yield f"{json.dumps({'status': 'translating', 'content': chunk, 'session_id': session_id}, ensure_ascii=False)}"
|
||||||
# 控制台输出
|
# 控制台输出
|
||||||
print(chunk, end="", flush=True)
|
print(chunk, end="", flush=True)
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"翻译指令异常: {str(e)}")
|
logger.error(f"翻译指令异常: {str(e)}")
|
||||||
yield f"{json.dumps({'error': f'翻译指令异常: {str(e)}'}, ensure_ascii=False)}"
|
yield f"{json.dumps({'error': f'翻译指令异常: {str(e)}', 'session_id': session_id}, ensure_ascii=False)}"
|
||||||
finally:
|
finally:
|
||||||
yield f"{json.dumps({'DONE': True}, ensure_ascii=False)}"
|
yield f"{json.dumps({'DONE': True, 'session_id': session_id}, ensure_ascii=False)}"
|
||||||
|
|
||||||
# 返回SSE响应
|
# 返回SSE响应
|
||||||
return EventSourceResponse(generate_translation_stream())
|
return EventSourceResponse(generate_translation_stream())
|
@@ -9,8 +9,8 @@ from Config.Config import ZHIPU_API_KEY
|
|||||||
# 设置日志
|
# 设置日志
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 添加历史记录存储变量
|
# 添加历史记录存储变量 - 使用字典按会话ID存储
|
||||||
history_records = []
|
session_history = {}
|
||||||
MAX_HISTORY = 10
|
MAX_HISTORY = 10
|
||||||
|
|
||||||
def create_llm_client() -> AsyncOpenAI:
|
def create_llm_client() -> AsyncOpenAI:
|
||||||
@@ -123,8 +123,15 @@ async def getGgbCommand(QvqResult):
|
|||||||
yield f"错误: {str(e)}"
|
yield f"错误: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
async def translate_to_ggb(natural_language_cmd):
|
async def translate_to_ggb(natural_language_cmd, session_id):
|
||||||
global history_records
|
global session_history
|
||||||
|
|
||||||
|
# 确保当前会话ID存在于历史记录字典中
|
||||||
|
if session_id not in session_history:
|
||||||
|
session_history[session_id] = []
|
||||||
|
|
||||||
|
# 获取当前会话的历史记录
|
||||||
|
history_records = session_history[session_id]
|
||||||
|
|
||||||
# 构建包含历史记录的提示
|
# 构建包含历史记录的提示
|
||||||
history_prompt = """
|
history_prompt = """
|
||||||
@@ -175,6 +182,8 @@ async def translate_to_ggb(natural_language_cmd):
|
|||||||
# 保持历史记录不超过最大数量
|
# 保持历史记录不超过最大数量
|
||||||
if len(history_records) > MAX_HISTORY:
|
if len(history_records) > MAX_HISTORY:
|
||||||
history_records.pop(0)
|
history_records.pop(0)
|
||||||
|
# 更新会话历史
|
||||||
|
session_history[session_id] = history_records
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"翻译指令错误: {str(e)}")
|
print(f"翻译指令错误: {str(e)}")
|
||||||
@@ -259,7 +268,8 @@ async def process_geometry_image(image_url: str):
|
|||||||
# 示例用法
|
# 示例用法
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
async def test_translate():
|
async def test_translate():
|
||||||
async for content in translate_to_ggb("把A和B连接起来"):
|
# 测试时使用固定会话ID
|
||||||
|
async for content in translate_to_ggb("把A和B连接起来", "test_session_123"):
|
||||||
if content:
|
if content:
|
||||||
print(content, end="", flush=True)
|
print(content, end="", flush=True)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user