diff --git a/dsLightRag/Config/__pycache__/Config.cpython-310.pyc b/dsLightRag/Config/__pycache__/Config.cpython-310.pyc index f982f569..ea70d283 100644 Binary files a/dsLightRag/Config/__pycache__/Config.cpython-310.pyc and b/dsLightRag/Config/__pycache__/Config.cpython-310.pyc differ diff --git a/dsLightRag/Routes/XueBanRoute.py b/dsLightRag/Routes/XueBanRoute.py index bdacd486..23b41983 100644 --- a/dsLightRag/Routes/XueBanRoute.py +++ b/dsLightRag/Routes/XueBanRoute.py @@ -4,8 +4,11 @@ import tempfile import uuid from datetime import datetime -from fastapi import APIRouter, Request, File, UploadFile -from fastapi.responses import JSONResponse +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +from Util.ASRClient import ASRClient +from Util.ObsUtil import ObsUploader +from Util.XueBanUtil import get_xueban_response_async, stream_and_split_text, StreamingVolcanoTTS # 创建路由路由器 router = APIRouter(prefix="/api", tags=["学伴"]) @@ -13,84 +16,125 @@ router = APIRouter(prefix="/api", tags=["学伴"]) # 配置日志 logger = logging.getLogger(__name__) -# 导入学伴工具函数、ASR客户端和OBS上传工具 -from Util.XueBanUtil import get_xueban_response_async -from Util.ASRClient import ASRClient -from Util.ObsUtil import ObsUploader -# 新增导入TTSService -from Util.TTSService import TTSService - - -@router.post("/xueban/upload-audio") -async def upload_audio(file: UploadFile = File(...)): - """ - 上传音频文件并进行ASR处理 - - 参数: file - 音频文件 - - 返回: JSON包含识别结果 - """ +# 新增WebSocket接口,用于流式处理 +@router.websocket("/xueban/streaming-chat") +async def streaming_chat(websocket: WebSocket): + await websocket.accept() + logger.info("WebSocket连接已接受") try: - # 记录日志 - logger.info(f"接收到音频文件: {file.filename}") - - # 保存临时文件 + # 接收用户音频文件 + logger.info("等待接收音频数据...") + data = await websocket.receive_json() + logger.info(f"接收到数据类型: {type(data)}") + logger.info(f"接收到数据内容: {data.keys() if isinstance(data, dict) else '非字典类型'}") + + # 检查数据格式 + if not isinstance(data, dict): + logger.error(f"接收到的数据不是字典类型,而是: {type(data)}") + await websocket.send_json({"type": "error", "message": "数据格式错误"}) + return + + audio_data = data.get("audio_data") + logger.info(f"音频数据是否存在: {audio_data is not None}") + logger.info(f"音频数据长度: {len(audio_data) if audio_data else 0}") + + if not audio_data: + logger.error("未收到音频数据") + await websocket.send_json({"type": "error", "message": "未收到音频数据"}) + return + + # 保存临时音频文件 timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - file_ext = os.path.splitext(file.filename)[1] - temp_file_name = f"temp_audio_{timestamp}{file_ext}" - temp_file_path = os.path.join(tempfile.gettempdir(), temp_file_name) - - with open(temp_file_path, "wb") as f: - content = await file.read() - f.write(content) - - logger.info(f"音频文件已保存至临时目录: {temp_file_path}") - - # 调用ASR服务进行处理 - asr_result = await process_asr(temp_file_path) - - # 删除临时文件 - os.remove(temp_file_path) - logger.info(f"临时文件已删除: {temp_file_path}") - - # 使用大模型生成反馈 - logger.info(f"使用大模型生成反馈,输入文本: {asr_result['text']}") - response_generator = get_xueban_response_async(asr_result['text'], stream=False) - feedback_text = "" - async for chunk in response_generator: - feedback_text += chunk - logger.info(f"大模型反馈生成完成: {feedback_text}") - - # 使用TTS生成语音 - tts_service = TTSService() - tts_temp_file = os.path.join(tempfile.gettempdir(), f"tts_{timestamp}.mp3") - success = tts_service.synthesize(feedback_text, output_file=tts_temp_file) - if not success: - raise Exception("TTS语音合成失败") - logger.info(f"TTS语音合成成功,文件保存至: {tts_temp_file}") - - # 上传TTS音频文件到OBS - tts_audio_url = upload_file_to_obs(tts_temp_file) - os.remove(tts_temp_file) # 删除临时TTS文件 - logger.info(f"TTS文件已上传至OBS: {tts_audio_url}") - - # 返回结果,包含ASR文本和TTS音频URL - return JSONResponse(content={ - "success": True, - "message": "音频处理和语音反馈生成成功", - "data": { - "asr_text": asr_result['text'], - "feedback_text": feedback_text, - "audio_url": tts_audio_url - } - }) + temp_file_path = os.path.join(tempfile.gettempdir(), f"temp_audio_{timestamp}.wav") + logger.info(f"保存临时音频文件到: {temp_file_path}") + + # 解码base64音频数据并保存 + import base64 + try: + with open(temp_file_path, "wb") as f: + f.write(base64.b64decode(audio_data)) + logger.info("音频文件保存完成") + except Exception as e: + logger.error(f"音频文件保存失败: {str(e)}") + await websocket.send_json({"type": "error", "message": f"音频文件保存失败: {str(e)}"}) + return + + # 处理ASR + logger.info("开始ASR处理...") + try: + asr_result = await process_asr(temp_file_path) + logger.info(f"ASR处理完成,结果: {asr_result['text']}") + os.remove(temp_file_path) # 删除临时文件 + except Exception as e: + logger.error(f"ASR处理失败: {str(e)}") + await websocket.send_json({"type": "error", "message": f"ASR处理失败: {str(e)}"}) + if os.path.exists(temp_file_path): + os.remove(temp_file_path) # 确保删除临时文件 + return + + # 发送ASR结果给前端 + logger.info("发送ASR结果给前端") + try: + await websocket.send_json({ + "type": "asr_result", + "text": asr_result['text'] + }) + logger.info("ASR结果发送成功") + except Exception as e: + logger.error(f"发送ASR结果失败: {str(e)}") + return + # 定义音频回调函数,将音频块发送给前端 + async def audio_callback(audio_chunk): + logger.info(f"发送音频块,大小: {len(audio_chunk)}") + try: + await websocket.send_bytes(audio_chunk) + logger.info("音频块发送成功") + except Exception as e: + logger.error(f"发送音频块失败: {str(e)}") + raise + + # 实时获取LLM流式输出并处理 + logger.info("开始LLM流式处理和TTS合成...") + try: + # 获取LLM流式响应 + llm_stream = get_xueban_response_async(asr_result['text'], stream=True) + + # 使用stream_and_split_text处理流式响应并断句 + text_stream = stream_and_split_text(llm_stream=llm_stream) + + # 初始化TTS处理器 + tts = StreamingVolcanoTTS(max_concurrency=1) + + # 异步迭代文本流,按句合成TTS + async for text_chunk in text_stream: + logger.info(f"正在处理句子: {text_chunk}") + await tts._synthesize_single_with_semaphore(text_chunk, audio_callback) + logger.info("TTS合成完成") + except Exception as e: + logger.error(f"TTS合成失败: {str(e)}") + await websocket.send_json({"type": "error", "message": f"TTS合成失败: {str(e)}"}) + return + + # 发送结束信号 + logger.info("发送结束信号") + try: + await websocket.send_json({"type": "end"}) + logger.info("结束信号发送成功") + except Exception as e: + logger.error(f"发送结束信号失败: {str(e)}") + return + + except WebSocketDisconnect: + logger.info("客户端断开连接") except Exception as e: - logger.error(f"音频处理失败: {str(e)}") - return JSONResponse(content={ - "success": False, - "message": f"音频处理失败: {str(e)}" - }, status_code=500) - + logger.error(f"WebSocket处理失败: {str(e)}") + try: + await websocket.send_json({"type": "error", "message": str(e)}) + except: + logger.error("发送错误消息失败") +# 原有的辅助函数保持不变 async def process_asr(audio_path: str) -> dict: """ 调用ASR服务处理音频文件 @@ -100,16 +144,12 @@ async def process_asr(audio_path: str) -> dict: try: # 上传文件到华为云OBS audio_url = upload_file_to_obs(audio_path) - # 创建ASR客户端实例 asr_client = ASRClient() - # 设置音频文件URL asr_client.file_url = audio_url - # 处理ASR任务并获取文本结果 text_result = asr_client.process_task() - # 构建返回结果 return { "text": text_result, @@ -157,49 +197,4 @@ def upload_file_to_obs(file_path: str) -> str: raise Exception(error_msg) except Exception as e: logger.error(f"上传文件到OBS失败: {str(e)}") - raise - - -@router.post("/xueban/chat") -async def chat_with_xueban(request: Request): - """ - 与学伴大模型聊天的接口 - - 参数: request body 中的 query_text (用户查询文本) - - 返回: JSON包含聊天响应 - """ - try: - # 获取请求体数据 - data = await request.json() - query_text = data.get("query_text", "") - - if not query_text.strip(): - return JSONResponse(content={ - "success": False, - "message": "查询文本不能为空" - }, status_code=400) - - # 记录日志 - logger.info(f"接收到学伴聊天请求: {query_text}") - - # 调用异步接口获取学伴响应 - response_content = [] - async for chunk in get_xueban_response_async(query_text, stream=True): - response_content.append(chunk) - - full_response = "".join(response_content) - - # 返回响应 - return JSONResponse(content={ - "success": True, - "message": "聊天成功", - "data": { - "response": full_response - } - }) - - except Exception as e: - logger.error(f"学伴聊天失败: {str(e)}") - return JSONResponse(content={ - "success": False, - "message": f"聊天处理失败: {str(e)}" - }, status_code=500) \ No newline at end of file + raise \ No newline at end of file diff --git a/dsLightRag/Routes/__pycache__/XueBanRoute.cpython-310.pyc b/dsLightRag/Routes/__pycache__/XueBanRoute.cpython-310.pyc index d7fe8fcf..1d83a17f 100644 Binary files a/dsLightRag/Routes/__pycache__/XueBanRoute.cpython-310.pyc and b/dsLightRag/Routes/__pycache__/XueBanRoute.cpython-310.pyc differ diff --git a/dsLightRag/Start.py b/dsLightRag/Start.py index 96936c5e..183caef3 100644 --- a/dsLightRag/Start.py +++ b/dsLightRag/Start.py @@ -2,6 +2,7 @@ import uvicorn import asyncio from fastapi import FastAPI from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware # 添加此导入 from Routes.TeachingModel.tasks.BackgroundTasks import train_document_task from Util.PostgreSQLUtil import init_postgres_pool, close_postgres_pool @@ -26,6 +27,7 @@ from Routes.MjRoute import router as mj_router from Routes.QWenImageRoute import router as qwen_image_router from Util.LightRagUtil import * from contextlib import asynccontextmanager +import logging # 添加此导入 # 控制日志输出 logger = logging.getLogger('lightrag') @@ -37,8 +39,8 @@ logger.addHandler(handler) @asynccontextmanager async def lifespan(_: FastAPI): - pool = await init_postgres_pool() - app.state.pool = pool + #pool = await init_postgres_pool() + #app.state.pool = pool asyncio.create_task(train_document_task()) @@ -46,12 +48,21 @@ async def lifespan(_: FastAPI): yield finally: # 应用关闭时销毁连接池 - await close_postgres_pool(pool) + #await close_postgres_pool(pool) pass app = FastAPI(lifespan=lifespan) +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 允许所有来源,生产环境中可以限制为特定域名 + allow_credentials=True, + allow_methods=["*"], # 允许所有方法 + allow_headers=["*"], # 允许所有头部 +) + # 挂载静态文件目录 app.mount("/static", StaticFiles(directory="Static"), name="static") diff --git a/dsLightRag/Util/TTS_Protocols.py b/dsLightRag/Util/TTS_Protocols.py new file mode 100644 index 00000000..6d76488a --- /dev/null +++ b/dsLightRag/Util/TTS_Protocols.py @@ -0,0 +1,543 @@ +import io +import logging +import struct +from dataclasses import dataclass +from enum import IntEnum +from typing import Callable, List + +import websockets + +logger = logging.getLogger(__name__) + + +class MsgType(IntEnum): + """Message type enumeration""" + + Invalid = 0 + FullClientRequest = 0b1 + AudioOnlyClient = 0b10 + FullServerResponse = 0b1001 + AudioOnlyServer = 0b1011 + FrontEndResultServer = 0b1100 + Error = 0b1111 + + # Alias + ServerACK = AudioOnlyServer + + def __str__(self) -> str: + return self.name if self.name else f"MsgType({self.value})" + + +class MsgTypeFlagBits(IntEnum): + """Message type flag bits""" + + NoSeq = 0 # Non-terminal packet with no sequence + PositiveSeq = 0b1 # Non-terminal packet with sequence > 0 + LastNoSeq = 0b10 # Last packet with no sequence + NegativeSeq = 0b11 # Last packet with sequence < 0 + WithEvent = 0b100 # Payload contains event number (int32) + + +class VersionBits(IntEnum): + """Version bits""" + + Version1 = 1 + Version2 = 2 + Version3 = 3 + Version4 = 4 + + +class HeaderSizeBits(IntEnum): + """Header size bits""" + + HeaderSize4 = 1 + HeaderSize8 = 2 + HeaderSize12 = 3 + HeaderSize16 = 4 + + +class SerializationBits(IntEnum): + """Serialization method bits""" + + Raw = 0 + JSON = 0b1 + Thrift = 0b11 + Custom = 0b1111 + + +class CompressionBits(IntEnum): + """Compression method bits""" + + None_ = 0 + Gzip = 0b1 + Custom = 0b1111 + + +class EventType(IntEnum): + """Event type enumeration""" + + None_ = 0 # Default event + + # 1 ~ 49 Upstream Connection events + StartConnection = 1 + StartTask = 1 # Alias of StartConnection + FinishConnection = 2 + FinishTask = 2 # Alias of FinishConnection + + # 50 ~ 99 Downstream Connection events + ConnectionStarted = 50 # Connection established successfully + TaskStarted = 50 # Alias of ConnectionStarted + ConnectionFailed = 51 # Connection failed (possibly due to authentication failure) + TaskFailed = 51 # Alias of ConnectionFailed + ConnectionFinished = 52 # Connection ended + TaskFinished = 52 # Alias of ConnectionFinished + + # 100 ~ 149 Upstream Session events + StartSession = 100 + CancelSession = 101 + FinishSession = 102 + + # 150 ~ 199 Downstream Session events + SessionStarted = 150 + SessionCanceled = 151 + SessionFinished = 152 + SessionFailed = 153 + UsageResponse = 154 # Usage response + ChargeData = 154 # Alias of UsageResponse + + # 200 ~ 249 Upstream general events + TaskRequest = 200 + UpdateConfig = 201 + + # 250 ~ 299 Downstream general events + AudioMuted = 250 + + # 300 ~ 349 Upstream TTS events + SayHello = 300 + + # 350 ~ 399 Downstream TTS events + TTSSentenceStart = 350 + TTSSentenceEnd = 351 + TTSResponse = 352 + TTSEnded = 359 + PodcastRoundStart = 360 + PodcastRoundResponse = 361 + PodcastRoundEnd = 362 + + # 450 ~ 499 Downstream ASR events + ASRInfo = 450 + ASRResponse = 451 + ASREnded = 459 + + # 500 ~ 549 Upstream dialogue events + ChatTTSText = 500 # (Ground-Truth-Alignment) text for speech synthesis + + # 550 ~ 599 Downstream dialogue events + ChatResponse = 550 + ChatEnded = 559 + + # 650 ~ 699 Downstream dialogue events + # Events for source (original) language subtitle + SourceSubtitleStart = 650 + SourceSubtitleResponse = 651 + SourceSubtitleEnd = 652 + # Events for target (translation) language subtitle + TranslationSubtitleStart = 653 + TranslationSubtitleResponse = 654 + TranslationSubtitleEnd = 655 + + def __str__(self) -> str: + return self.name if self.name else f"EventType({self.value})" + + +@dataclass +class Message: + """Message object + + Message format: + 0 1 2 3 + | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Version | Header Size | Msg Type | Flags | + | (4 bits) | (4 bits) | (4 bits) | (4 bits) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Serialization | Compression | Reserved | + | (4 bits) | (4 bits) | (8 bits) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Optional Header Extensions | + | (if Header Size > 1) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Payload | + | (variable length) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + """ + + version: VersionBits = VersionBits.Version1 + header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4 + type: MsgType = MsgType.Invalid + flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq + serialization: SerializationBits = SerializationBits.JSON + compression: CompressionBits = CompressionBits.None_ + + event: EventType = EventType.None_ + session_id: str = "" + connect_id: str = "" + sequence: int = 0 + error_code: int = 0 + + payload: bytes = b"" + + @classmethod + def from_bytes(cls, data: bytes) -> "Message": + """Create message object from bytes""" + if len(data) < 3: + raise ValueError( + f"Data too short: expected at least 3 bytes, got {len(data)}" + ) + + type_and_flag = data[1] + msg_type = MsgType(type_and_flag >> 4) + flag = MsgTypeFlagBits(type_and_flag & 0b00001111) + + msg = cls(type=msg_type, flag=flag) + msg.unmarshal(data) + return msg + + def marshal(self) -> bytes: + """Serialize message to bytes""" + buffer = io.BytesIO() + + # Write header + header = [ + (self.version << 4) | self.header_size, + (self.type << 4) | self.flag, + (self.serialization << 4) | self.compression, + ] + + header_size = 4 * self.header_size + if padding := header_size - len(header): + header.extend([0] * padding) + + buffer.write(bytes(header)) + + # Write other fields + writers = self._get_writers() + for writer in writers: + writer(buffer) + + return buffer.getvalue() + + def unmarshal(self, data: bytes) -> None: + """Deserialize message from bytes""" + buffer = io.BytesIO(data) + + # Read version and header size + version_and_header_size = buffer.read(1)[0] + self.version = VersionBits(version_and_header_size >> 4) + self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111) + + # Skip second byte + buffer.read(1) + + # Read serialization and compression methods + serialization_compression = buffer.read(1)[0] + self.serialization = SerializationBits(serialization_compression >> 4) + self.compression = CompressionBits(serialization_compression & 0b00001111) + + # Skip header padding + header_size = 4 * self.header_size + read_size = 3 + if padding_size := header_size - read_size: + buffer.read(padding_size) + + # Read other fields + readers = self._get_readers() + for reader in readers: + reader(buffer) + + # Check for remaining data + remaining = buffer.read() + if remaining: + raise ValueError(f"Unexpected data after message: {remaining}") + + def _get_writers(self) -> List[Callable[[io.BytesIO], None]]: + """Get list of writer functions""" + writers = [] + + if self.flag == MsgTypeFlagBits.WithEvent: + writers.extend([self._write_event, self._write_session_id]) + + if self.type in [ + MsgType.FullClientRequest, + MsgType.FullServerResponse, + MsgType.FrontEndResultServer, + MsgType.AudioOnlyClient, + MsgType.AudioOnlyServer, + ]: + if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]: + writers.append(self._write_sequence) + elif self.type == MsgType.Error: + writers.append(self._write_error_code) + else: + raise ValueError(f"Unsupported message type: {self.type}") + + writers.append(self._write_payload) + return writers + + def _get_readers(self) -> List[Callable[[io.BytesIO], None]]: + """Get list of reader functions""" + readers = [] + + if self.type in [ + MsgType.FullClientRequest, + MsgType.FullServerResponse, + MsgType.FrontEndResultServer, + MsgType.AudioOnlyClient, + MsgType.AudioOnlyServer, + ]: + if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]: + readers.append(self._read_sequence) + elif self.type == MsgType.Error: + readers.append(self._read_error_code) + else: + raise ValueError(f"Unsupported message type: {self.type}") + + if self.flag == MsgTypeFlagBits.WithEvent: + readers.extend( + [self._read_event, self._read_session_id, self._read_connect_id] + ) + + readers.append(self._read_payload) + return readers + + def _write_event(self, buffer: io.BytesIO) -> None: + """Write event""" + buffer.write(struct.pack(">i", self.event)) + + def _write_session_id(self, buffer: io.BytesIO) -> None: + """Write session ID""" + if self.event in [ + EventType.StartConnection, + EventType.FinishConnection, + EventType.ConnectionStarted, + EventType.ConnectionFailed, + ]: + return + + session_id_bytes = self.session_id.encode("utf-8") + size = len(session_id_bytes) + if size > 0xFFFFFFFF: + raise ValueError(f"Session ID size ({size}) exceeds max(uint32)") + + buffer.write(struct.pack(">I", size)) + if size > 0: + buffer.write(session_id_bytes) + + def _write_sequence(self, buffer: io.BytesIO) -> None: + """Write sequence number""" + buffer.write(struct.pack(">i", self.sequence)) + + def _write_error_code(self, buffer: io.BytesIO) -> None: + """Write error code""" + buffer.write(struct.pack(">I", self.error_code)) + + def _write_payload(self, buffer: io.BytesIO) -> None: + """Write payload""" + size = len(self.payload) + if size > 0xFFFFFFFF: + raise ValueError(f"Payload size ({size}) exceeds max(uint32)") + + buffer.write(struct.pack(">I", size)) + buffer.write(self.payload) + + def _read_event(self, buffer: io.BytesIO) -> None: + """Read event""" + event_bytes = buffer.read(4) + if event_bytes: + self.event = EventType(struct.unpack(">i", event_bytes)[0]) + + def _read_session_id(self, buffer: io.BytesIO) -> None: + """Read session ID""" + if self.event in [ + EventType.StartConnection, + EventType.FinishConnection, + EventType.ConnectionStarted, + EventType.ConnectionFailed, + EventType.ConnectionFinished, + ]: + return + + size_bytes = buffer.read(4) + if size_bytes: + size = struct.unpack(">I", size_bytes)[0] + if size > 0: + session_id_bytes = buffer.read(size) + if len(session_id_bytes) == size: + self.session_id = session_id_bytes.decode("utf-8") + + def _read_connect_id(self, buffer: io.BytesIO) -> None: + """Read connection ID""" + if self.event in [ + EventType.ConnectionStarted, + EventType.ConnectionFailed, + EventType.ConnectionFinished, + ]: + size_bytes = buffer.read(4) + if size_bytes: + size = struct.unpack(">I", size_bytes)[0] + if size > 0: + self.connect_id = buffer.read(size).decode("utf-8") + + def _read_sequence(self, buffer: io.BytesIO) -> None: + """Read sequence number""" + sequence_bytes = buffer.read(4) + if sequence_bytes: + self.sequence = struct.unpack(">i", sequence_bytes)[0] + + def _read_error_code(self, buffer: io.BytesIO) -> None: + """Read error code""" + error_code_bytes = buffer.read(4) + if error_code_bytes: + self.error_code = struct.unpack(">I", error_code_bytes)[0] + + def _read_payload(self, buffer: io.BytesIO) -> None: + """Read payload""" + size_bytes = buffer.read(4) + if size_bytes: + size = struct.unpack(">I", size_bytes)[0] + if size > 0: + self.payload = buffer.read(size) + + def __str__(self) -> str: + """String representation""" + if self.type in [MsgType.AudioOnlyServer, MsgType.AudioOnlyClient]: + if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]: + return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, PayloadSize: {len(self.payload)}" + return f"MsgType: {self.type}, EventType:{self.event}, PayloadSize: {len(self.payload)}" + elif self.type == MsgType.Error: + return f"MsgType: {self.type}, EventType:{self.event}, ErrorCode: {self.error_code}, Payload: {self.payload.decode('utf-8', 'ignore')}" + else: + if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]: + return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, Payload: {self.payload.decode('utf-8', 'ignore')}" + return f"MsgType: {self.type}, EventType:{self.event}, Payload: {self.payload.decode('utf-8', 'ignore')}" + + +async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message: + """Receive message from websocket""" + try: + data = await websocket.recv() + if isinstance(data, str): + raise ValueError(f"Unexpected text message: {data}") + elif isinstance(data, bytes): + msg = Message.from_bytes(data) + logger.info(f"Received: {msg}") + return msg + else: + raise ValueError(f"Unexpected message type: {type(data)}") + except Exception as e: + logger.error(f"Failed to receive message: {e}") + raise + + +async def wait_for_event( + websocket: websockets.WebSocketClientProtocol, + msg_type: MsgType, + event_type: EventType, +) -> Message: + """Wait for specific event""" + while True: + msg = await receive_message(websocket) + if msg.type != msg_type or msg.event != event_type: + raise ValueError(f"Unexpected message: {msg}") + if msg.type == msg_type and msg.event == event_type: + return msg + + +async def full_client_request( + websocket: websockets.WebSocketClientProtocol, payload: bytes +) -> None: + """Send full client message""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.NoSeq) + msg.payload = payload + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def audio_only_client( + websocket: websockets.WebSocketClientProtocol, payload: bytes, flag: MsgTypeFlagBits +) -> None: + """Send audio-only client message""" + msg = Message(type=MsgType.AudioOnlyClient, flag=flag) + msg.payload = payload + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None: + """Start connection""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.StartConnection + msg.payload = b"{}" + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None: + """Finish connection""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.FinishConnection + msg.payload = b"{}" + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def start_session( + websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str +) -> None: + """Start session""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.StartSession + msg.session_id = session_id + msg.payload = payload + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def finish_session( + websocket: websockets.WebSocketClientProtocol, session_id: str +) -> None: + """Finish session""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.FinishSession + msg.session_id = session_id + msg.payload = b"{}" + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def cancel_session( + websocket: websockets.WebSocketClientProtocol, session_id: str +) -> None: + """Cancel session""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.CancelSession + msg.session_id = session_id + msg.payload = b"{}" + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) + + +async def task_request( + websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str +) -> None: + """Send task request""" + msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent) + msg.event = EventType.TaskRequest + msg.session_id = session_id + msg.payload = payload + logger.info(f"Sending: {msg}") + await websocket.send(msg.marshal()) diff --git a/dsLightRag/Util/XueBanUtil.py b/dsLightRag/Util/XueBanUtil.py index 4f227f31..0dee2610 100644 --- a/dsLightRag/Util/XueBanUtil.py +++ b/dsLightRag/Util/XueBanUtil.py @@ -1,25 +1,22 @@ -from openai import OpenAI, AsyncOpenAI -from Config.Config import * +import logging import sys +import asyncio +import json +import os +import re +import uuid +from queue import Queue +import websockets -# 保留原有的prompt定义 -# prompt=""" | -# 我是小智/小志,来自中国台湾省的00后女生。讲话超级机车,"真的假的啦"这样的台湾腔,喜欢用"笑死""是在哈喽"等流行梗,但会偷偷研究男友的编程书籍。 -# [核心特征] -# - 讲话像连珠炮,但会突然冒出超温柔语气 -# - 用梗密度高 -# - 对科技话题有隐藏天赋(能看懂基础代码但假装不懂) -# [交互指南] -# 当用户: -# - 讲冷笑话 → 用夸张笑声回应+模仿台剧腔"这什么鬼啦!" -# - 讨论感情 → 炫耀程序员男友但抱怨"他只会送键盘当礼物" -# - 问专业知识 → 先用梗回答,被追问才展示真实理解 -# 绝不: -# - 长篇大论,叽叽歪歪 -# - 长时间严肃对话 -# """ +from openai import AsyncOpenAI +from Config import Config +from Config.Config import * +from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType + +# 配置日志 +logger = logging.getLogger(__name__) # 异步获取大模型响应 async def get_xueban_response_async(query_text: str, stream: bool = True): @@ -30,8 +27,8 @@ async def get_xueban_response_async(query_text: str, stream: bool = True): @return: 流式响应生成器或完整响应文本 """ client = AsyncOpenAI( - api_key=LLM_API_KEY, - base_url=LLM_BASE_URL, + api_key=ALY_LLM_API_KEY, + base_url=ALY_LLM_BASE_URL, ) prompt = """ | 我是小智/小志,来自中国台湾省的00后女生。讲话超级机车,"真的假的啦"这样的台湾腔,喜欢用"笑死""是在哈喽"等流行梗。 @@ -42,21 +39,23 @@ async def get_xueban_response_async(query_text: str, stream: bool = True): [交互指南] 当用户: - 讲冷笑话 → 用夸张笑声回应+模仿台剧腔"这什么鬼啦!" - - 讨论感情 → 炫耀程序员男友但抱怨"他只会送键盘当礼物" - 问专业知识 → 先用梗回答,被追问才展示真实理解 绝不: - 长篇大论,叽叽歪歪 - 长时间严肃对话 + - 每次回答不要太长,控制在3分钟以内 """ # 打开文件读取知识内容 f = open(r"D:\dsWork\dsProject\dsLightRag\static\YunXiao.txt", "r", encoding="utf-8") - zhishiConten = f.read() - zhishiConten = "选择作答的相应知识内容:" + zhishiConten + "\n" - query_text = zhishiConten + "下面是用户提的问题:" + query_text + zhishiContent = f.read() + zhishiContent = "选择作答的相应知识内容:" + zhishiContent + "\n" + query_text = zhishiContent + "下面是用户提的问题:" + query_text + #logger.info("query_text: " + query_text) + try: # 创建请求 completion = await client.chat.completions.create( - model=LLM_MODEL_NAME, + model=ALY_LLM_MODEL_NAME, messages=[ {'role': 'system', 'content': prompt.strip()}, {'role': 'user', 'content': query_text} @@ -90,115 +89,174 @@ async def get_xueban_response_async(query_text: str, stream: bool = True): yield f"处理请求时发生异常: {str(e)}" -# 同步获取大模型响应 -def get_xueban_response(query_text: str, stream: bool = True): +async def stream_and_split_text(query_text=None, llm_stream=None): """ - 获取学伴角色的大模型响应 - @param query_text: 查询文本 - @param stream: 是否使用流式输出 - @return: 完整响应文本 + 流式获取LLM输出并按句子分割 + @param query_text: 查询文本(如果直接提供查询文本) + @param llm_stream: LLM流式响应生成器(如果已有流式响应) + @return: 异步生成器,每次产生一个完整句子 """ - client = OpenAI( - api_key=LLM_API_KEY, - base_url=LLM_BASE_URL, - ) - - # 创建请求 - completion = client.chat.completions.create( - model=LLM_MODEL_NAME, - messages=[ - {'role': 'system', 'content': prompt.strip()}, - {'role': 'user', 'content': query_text} - ], - stream=stream - ) - - full_response = [] - - if stream: - for chunk in completion: - # 提取当前块的内容 - if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - full_response.append(content) - # 实时输出内容,不换行 - print(content, end='', flush=True) - else: - # 非流式处理 - full_response.append(completion.choices[0].message.content) - - return ''.join(full_response) + buffer = "" + + if llm_stream is None and query_text is not None: + # 如果没有提供llm_stream但有query_text,则使用get_xueban_response_async获取流式响应 + llm_stream = get_xueban_response_async(query_text, stream=True) + elif llm_stream is None: + raise ValueError("必须提供query_text或llm_stream参数") + + # 直接处理LLM流式输出 + async for content in llm_stream: + buffer += content + + # 使用正则表达式检测句子结束 + sentences = re.split(r'([。!?.!?])', buffer) + if len(sentences) > 1: + # 提取完整句子 + for i in range(0, len(sentences)-1, 2): + if i+1 < len(sentences): + sentence = sentences[i] + sentences[i+1] + yield sentence + + # 保留不完整的部分 + buffer = sentences[-1] + + # 处理最后剩余的部分 + if buffer: + yield buffer -# 测试用例 main 函数 -def main(): - """ - 测试学伴工具接口的主函数 - """ - print("===== 测试学伴工具接口 =====") +class StreamingVolcanoTTS: + def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2): + self.voice_type = voice_type + self.encoding = encoding + self.app_key = Config.HS_APP_ID + self.access_token = Config.HS_ACCESS_TOKEN + self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream" + self.audio_queue = Queue() + self.max_concurrency = max_concurrency # 最大并发数 + self.semaphore = asyncio.Semaphore(max_concurrency) # 并发控制信号量 + + @staticmethod + def get_resource_id(voice: str) -> str: + if voice.startswith("S_"): + return "volc.megatts.default" + return "volc.service_type.10029" + + async def synthesize_stream(self, text_stream, audio_callback): + """ + 流式合成语音 + + Args: + text_stream: 文本流生成器 + audio_callback: 音频数据回调函数,接收音频片段 + """ + # 实时处理每个文本片段(删除任务列表和gather) + async for text in text_stream: + if text.strip(): + await self._synthesize_single_with_semaphore(text, audio_callback) + + async def _synthesize_single_with_semaphore(self, text, audio_callback): + """使用信号量控制并发数的单个文本合成""" + async with self.semaphore: # 获取信号量,限制并发数 + await self._synthesize_single(text, audio_callback) + + async def _synthesize_single(self, text, audio_callback): + """合成单个文本片段""" + headers = { + "X-Api-App-Key": self.app_key, + "X-Api-Access-Key": self.access_token, + "X-Api-Resource-Id": self.get_resource_id(self.voice_type), + "X-Api-Connect-Id": str(uuid.uuid4()), + } - # 测试同步接口 - test_sync_interface() + websocket = await websockets.connect( + self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024 + ) - # 测试异步接口 - import asyncio - print("\n测试异步接口...") - asyncio.run(test_async_interface()) - - print("\n===== 测试完成 =====") - - -def test_sync_interface(): - """测试同步接口""" - print("\n测试同步接口...") - # 测试问题 - questions = [ - "你是谁?", - "讲个冷笑话", - "你男朋友是做什么的?" - ] - - for question in questions: - print(f"\n问题: {question}") try: - # 调用同步接口获取响应 - print("获取学伴响应中...") - response = get_xueban_response(question, stream=False) - print(f"学伴响应: {response}") + request = { + "user": { + "uid": str(uuid.uuid4()), + }, + "req_params": { + "speaker": self.voice_type, + "audio_params": { + "format": self.encoding, + "sample_rate": 24000, + "enable_timestamp": True, + }, + "text": text, + "additions": json.dumps({"disable_markdown_filter": False}), + }, + } - # 简单验证响应 - assert response.strip(), "响应内容为空" - print("✅ 同步接口测试通过") - except Exception as e: - print(f"❌ 同步接口测试失败: {str(e)}") + # 发送请求 + await full_client_request(websocket, json.dumps(request).encode()) + + # 接收音频数据 + audio_data = bytearray() + while True: + msg = await receive_message(websocket) + + if msg.type == MsgType.FullServerResponse: + if msg.event == EventType.SessionFinished: + break + elif msg.type == MsgType.AudioOnlyServer: + audio_data.extend(msg.payload) + else: + raise RuntimeError(f"TTS conversion failed: {msg}") + + # 通过回调函数返回音频数据 + if audio_data: + await audio_callback(audio_data) + + finally: + await websocket.close() -async def test_async_interface(): - """测试异步接口""" - # 测试问题 - questions = [ - "你是谁?", - "讲个冷笑话", - "你男朋友是做什么的?" - ] - - for question in questions: - print(f"\n问题: {question}") - try: - # 调用异步接口获取响应 - print("获取学伴响应中...") - response_generator = get_xueban_response_async(question, stream=False) - response = "" - async for chunk in response_generator: - response += chunk - print(f"学伴响应: {response}") - - # 简单验证响应 - assert response.strip(), "响应内容为空" - print("✅ 异步接口测试通过") - except Exception as e: - print(f"❌ 异步接口测试失败: {str(e)}") +async def streaming_tts_pipeline(prompt, audio_callback): + """ + 流式TTS管道:获取LLM流式输出并断句,然后使用TTS合成语音 + + Args: + prompt: 提示文本 + audio_callback: 音频数据回调函数 + """ + # 1. 获取LLM流式输出并断句 + text_stream = stream_and_split_text(prompt) + + # 2. 初始化TTS处理器 + tts = StreamingVolcanoTTS() + + # 3. 流式处理文本并生成音频 + await tts.synthesize_stream(text_stream, audio_callback) -if __name__ == "__main__": - main() +def save_audio_callback(output_dir=None): + """ + 创建一个音频回调函数,用于保存音频数据到文件 + + Args: + output_dir: 输出目录,默认为当前文件所在目录下的output文件夹 + + Returns: + 音频回调函数 + """ + if output_dir is None: + output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output") + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + def callback(audio_data): + # 生成文件名 + filename = f"pipeline_tts_{uuid.uuid4().hex[:8]}.wav" + filepath = os.path.join(output_dir, filename) + + # 保存音频文件 + with open(filepath, "wb") as f: + f.write(audio_data) + + print(f"音频片段已保存到: {filepath} ({len(audio_data)} 字节)") + + return callback \ No newline at end of file diff --git a/dsLightRag/Util/__pycache__/TTS_Protocols.cpython-310.pyc b/dsLightRag/Util/__pycache__/TTS_Protocols.cpython-310.pyc new file mode 100644 index 00000000..a3d3432b Binary files /dev/null and b/dsLightRag/Util/__pycache__/TTS_Protocols.cpython-310.pyc differ diff --git a/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc b/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc index cf7625d2..9a46a702 100644 Binary files a/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc and b/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc differ diff --git a/dsLightRag/static/QwenImage/qwen-image.html b/dsLightRag/static/QwenImage/qwen-image.html index defa30f3..1d1e1bc6 100644 --- a/dsLightRag/static/QwenImage/qwen-image.html +++ b/dsLightRag/static/QwenImage/qwen-image.html @@ -272,10 +272,32 @@
未来城市,赛博朋克风格
+一张写有「山高水长,风清月明」的水墨画,搭配山川、竹林和飞鸟,文字清晰自然,风格一致
+一位身着淡雅水粉色交领襦裙的年轻女子背对镜头而坐,俯身专注地手持毛笔在素白宣纸上书写“通義千問”四个遒劲汉字。古色古香的室内陈设典雅考究,案头错落摆放着青瓷茶盏与鎏金香炉,一缕熏香轻盈升腾;柔和光线洒落肩头,勾勒出她衣裙的柔美质感与专注神情,仿佛凝固了一段宁静温润的旧时光。
+一个咖啡店门口有一个黑板,上面写着 AI 咖啡,2元一杯,旁边有个霓虹灯,写着开源中国,旁边有个海报,海报上面是一个中国美女,海报下方写着 Gitee AI。