Files
dsProject/dsLightRag/Util/TTS_Protocols.py
2025-08-31 09:20:47 +08:00

544 lines
18 KiB
Python

import io
import logging
import struct
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, List
import websockets
logger = logging.getLogger(__name__)
class MsgType(IntEnum):
"""Message type enumeration"""
Invalid = 0
FullClientRequest = 0b1
AudioOnlyClient = 0b10
FullServerResponse = 0b1001
AudioOnlyServer = 0b1011
FrontEndResultServer = 0b1100
Error = 0b1111
# Alias
ServerACK = AudioOnlyServer
def __str__(self) -> str:
return self.name if self.name else f"MsgType({self.value})"
class MsgTypeFlagBits(IntEnum):
"""Message type flag bits"""
NoSeq = 0 # Non-terminal packet with no sequence
PositiveSeq = 0b1 # Non-terminal packet with sequence > 0
LastNoSeq = 0b10 # Last packet with no sequence
NegativeSeq = 0b11 # Last packet with sequence < 0
WithEvent = 0b100 # Payload contains event number (int32)
class VersionBits(IntEnum):
"""Version bits"""
Version1 = 1
Version2 = 2
Version3 = 3
Version4 = 4
class HeaderSizeBits(IntEnum):
"""Header size bits"""
HeaderSize4 = 1
HeaderSize8 = 2
HeaderSize12 = 3
HeaderSize16 = 4
class SerializationBits(IntEnum):
"""Serialization method bits"""
Raw = 0
JSON = 0b1
Thrift = 0b11
Custom = 0b1111
class CompressionBits(IntEnum):
"""Compression method bits"""
None_ = 0
Gzip = 0b1
Custom = 0b1111
class EventType(IntEnum):
"""Event type enumeration"""
None_ = 0 # Default event
# 1 ~ 49 Upstream Connection events
StartConnection = 1
StartTask = 1 # Alias of StartConnection
FinishConnection = 2
FinishTask = 2 # Alias of FinishConnection
# 50 ~ 99 Downstream Connection events
ConnectionStarted = 50 # Connection established successfully
TaskStarted = 50 # Alias of ConnectionStarted
ConnectionFailed = 51 # Connection failed (possibly due to authentication failure)
TaskFailed = 51 # Alias of ConnectionFailed
ConnectionFinished = 52 # Connection ended
TaskFinished = 52 # Alias of ConnectionFinished
# 100 ~ 149 Upstream Session events
StartSession = 100
CancelSession = 101
FinishSession = 102
# 150 ~ 199 Downstream Session events
SessionStarted = 150
SessionCanceled = 151
SessionFinished = 152
SessionFailed = 153
UsageResponse = 154 # Usage response
ChargeData = 154 # Alias of UsageResponse
# 200 ~ 249 Upstream general events
TaskRequest = 200
UpdateConfig = 201
# 250 ~ 299 Downstream general events
AudioMuted = 250
# 300 ~ 349 Upstream TTS events
SayHello = 300
# 350 ~ 399 Downstream TTS events
TTSSentenceStart = 350
TTSSentenceEnd = 351
TTSResponse = 352
TTSEnded = 359
PodcastRoundStart = 360
PodcastRoundResponse = 361
PodcastRoundEnd = 362
# 450 ~ 499 Downstream ASR events
ASRInfo = 450
ASRResponse = 451
ASREnded = 459
# 500 ~ 549 Upstream dialogue events
ChatTTSText = 500 # (Ground-Truth-Alignment) text for speech synthesis
# 550 ~ 599 Downstream dialogue events
ChatResponse = 550
ChatEnded = 559
# 650 ~ 699 Downstream dialogue events
# Events for source (original) language subtitle
SourceSubtitleStart = 650
SourceSubtitleResponse = 651
SourceSubtitleEnd = 652
# Events for target (translation) language subtitle
TranslationSubtitleStart = 653
TranslationSubtitleResponse = 654
TranslationSubtitleEnd = 655
def __str__(self) -> str:
return self.name if self.name else f"EventType({self.value})"
@dataclass
class Message:
"""Message object
Message format:
0 1 2 3
| 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version | Header Size | Msg Type | Flags |
| (4 bits) | (4 bits) | (4 bits) | (4 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Serialization | Compression | Reserved |
| (4 bits) | (4 bits) | (8 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Optional Header Extensions |
| (if Header Size > 1) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Payload |
| (variable length) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
version: VersionBits = VersionBits.Version1
header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
type: MsgType = MsgType.Invalid
flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
serialization: SerializationBits = SerializationBits.JSON
compression: CompressionBits = CompressionBits.None_
event: EventType = EventType.None_
session_id: str = ""
connect_id: str = ""
sequence: int = 0
error_code: int = 0
payload: bytes = b""
@classmethod
def from_bytes(cls, data: bytes) -> "Message":
"""Create message object from bytes"""
if len(data) < 3:
raise ValueError(
f"Data too short: expected at least 3 bytes, got {len(data)}"
)
type_and_flag = data[1]
msg_type = MsgType(type_and_flag >> 4)
flag = MsgTypeFlagBits(type_and_flag & 0b00001111)
msg = cls(type=msg_type, flag=flag)
msg.unmarshal(data)
return msg
def marshal(self) -> bytes:
"""Serialize message to bytes"""
buffer = io.BytesIO()
# Write header
header = [
(self.version << 4) | self.header_size,
(self.type << 4) | self.flag,
(self.serialization << 4) | self.compression,
]
header_size = 4 * self.header_size
if padding := header_size - len(header):
header.extend([0] * padding)
buffer.write(bytes(header))
# Write other fields
writers = self._get_writers()
for writer in writers:
writer(buffer)
return buffer.getvalue()
def unmarshal(self, data: bytes) -> None:
"""Deserialize message from bytes"""
buffer = io.BytesIO(data)
# Read version and header size
version_and_header_size = buffer.read(1)[0]
self.version = VersionBits(version_and_header_size >> 4)
self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)
# Skip second byte
buffer.read(1)
# Read serialization and compression methods
serialization_compression = buffer.read(1)[0]
self.serialization = SerializationBits(serialization_compression >> 4)
self.compression = CompressionBits(serialization_compression & 0b00001111)
# Skip header padding
header_size = 4 * self.header_size
read_size = 3
if padding_size := header_size - read_size:
buffer.read(padding_size)
# Read other fields
readers = self._get_readers()
for reader in readers:
reader(buffer)
# Check for remaining data
remaining = buffer.read()
if remaining:
raise ValueError(f"Unexpected data after message: {remaining}")
def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
"""Get list of writer functions"""
writers = []
if self.flag == MsgTypeFlagBits.WithEvent:
writers.extend([self._write_event, self._write_session_id])
if self.type in [
MsgType.FullClientRequest,
MsgType.FullServerResponse,
MsgType.FrontEndResultServer,
MsgType.AudioOnlyClient,
MsgType.AudioOnlyServer,
]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
writers.append(self._write_sequence)
elif self.type == MsgType.Error:
writers.append(self._write_error_code)
else:
raise ValueError(f"Unsupported message type: {self.type}")
writers.append(self._write_payload)
return writers
def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
"""Get list of reader functions"""
readers = []
if self.type in [
MsgType.FullClientRequest,
MsgType.FullServerResponse,
MsgType.FrontEndResultServer,
MsgType.AudioOnlyClient,
MsgType.AudioOnlyServer,
]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
readers.append(self._read_sequence)
elif self.type == MsgType.Error:
readers.append(self._read_error_code)
else:
raise ValueError(f"Unsupported message type: {self.type}")
if self.flag == MsgTypeFlagBits.WithEvent:
readers.extend(
[self._read_event, self._read_session_id, self._read_connect_id]
)
readers.append(self._read_payload)
return readers
def _write_event(self, buffer: io.BytesIO) -> None:
"""Write event"""
buffer.write(struct.pack(">i", self.event))
def _write_session_id(self, buffer: io.BytesIO) -> None:
"""Write session ID"""
if self.event in [
EventType.StartConnection,
EventType.FinishConnection,
EventType.ConnectionStarted,
EventType.ConnectionFailed,
]:
return
session_id_bytes = self.session_id.encode("utf-8")
size = len(session_id_bytes)
if size > 0xFFFFFFFF:
raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")
buffer.write(struct.pack(">I", size))
if size > 0:
buffer.write(session_id_bytes)
def _write_sequence(self, buffer: io.BytesIO) -> None:
"""Write sequence number"""
buffer.write(struct.pack(">i", self.sequence))
def _write_error_code(self, buffer: io.BytesIO) -> None:
"""Write error code"""
buffer.write(struct.pack(">I", self.error_code))
def _write_payload(self, buffer: io.BytesIO) -> None:
"""Write payload"""
size = len(self.payload)
if size > 0xFFFFFFFF:
raise ValueError(f"Payload size ({size}) exceeds max(uint32)")
buffer.write(struct.pack(">I", size))
buffer.write(self.payload)
def _read_event(self, buffer: io.BytesIO) -> None:
"""Read event"""
event_bytes = buffer.read(4)
if event_bytes:
self.event = EventType(struct.unpack(">i", event_bytes)[0])
def _read_session_id(self, buffer: io.BytesIO) -> None:
"""Read session ID"""
if self.event in [
EventType.StartConnection,
EventType.FinishConnection,
EventType.ConnectionStarted,
EventType.ConnectionFailed,
EventType.ConnectionFinished,
]:
return
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
session_id_bytes = buffer.read(size)
if len(session_id_bytes) == size:
self.session_id = session_id_bytes.decode("utf-8")
def _read_connect_id(self, buffer: io.BytesIO) -> None:
"""Read connection ID"""
if self.event in [
EventType.ConnectionStarted,
EventType.ConnectionFailed,
EventType.ConnectionFinished,
]:
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
self.connect_id = buffer.read(size).decode("utf-8")
def _read_sequence(self, buffer: io.BytesIO) -> None:
"""Read sequence number"""
sequence_bytes = buffer.read(4)
if sequence_bytes:
self.sequence = struct.unpack(">i", sequence_bytes)[0]
def _read_error_code(self, buffer: io.BytesIO) -> None:
"""Read error code"""
error_code_bytes = buffer.read(4)
if error_code_bytes:
self.error_code = struct.unpack(">I", error_code_bytes)[0]
def _read_payload(self, buffer: io.BytesIO) -> None:
"""Read payload"""
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
self.payload = buffer.read(size)
def __str__(self) -> str:
"""String representation"""
if self.type in [MsgType.AudioOnlyServer, MsgType.AudioOnlyClient]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, PayloadSize: {len(self.payload)}"
return f"MsgType: {self.type}, EventType:{self.event}, PayloadSize: {len(self.payload)}"
elif self.type == MsgType.Error:
return f"MsgType: {self.type}, EventType:{self.event}, ErrorCode: {self.error_code}, Payload: {self.payload.decode('utf-8', 'ignore')}"
else:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, Payload: {self.payload.decode('utf-8', 'ignore')}"
return f"MsgType: {self.type}, EventType:{self.event}, Payload: {self.payload.decode('utf-8', 'ignore')}"
async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
"""Receive message from websocket"""
try:
data = await websocket.recv()
if isinstance(data, str):
raise ValueError(f"Unexpected text message: {data}")
elif isinstance(data, bytes):
msg = Message.from_bytes(data)
logger.info(f"Received: {msg}")
return msg
else:
raise ValueError(f"Unexpected message type: {type(data)}")
except Exception as e:
logger.error(f"Failed to receive message: {e}")
raise
async def wait_for_event(
websocket: websockets.WebSocketClientProtocol,
msg_type: MsgType,
event_type: EventType,
) -> Message:
"""Wait for specific event"""
while True:
msg = await receive_message(websocket)
if msg.type != msg_type or msg.event != event_type:
raise ValueError(f"Unexpected message: {msg}")
if msg.type == msg_type and msg.event == event_type:
return msg
async def full_client_request(
websocket: websockets.WebSocketClientProtocol, payload: bytes
) -> None:
"""Send full client message"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.NoSeq)
msg.payload = payload
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def audio_only_client(
websocket: websockets.WebSocketClientProtocol, payload: bytes, flag: MsgTypeFlagBits
) -> None:
"""Send audio-only client message"""
msg = Message(type=MsgType.AudioOnlyClient, flag=flag)
msg.payload = payload
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
"""Start connection"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.StartConnection
msg.payload = b"{}"
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
"""Finish connection"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.FinishConnection
msg.payload = b"{}"
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def start_session(
websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str
) -> None:
"""Start session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.StartSession
msg.session_id = session_id
msg.payload = payload
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def finish_session(
websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
"""Finish session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.FinishSession
msg.session_id = session_id
msg.payload = b"{}"
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def cancel_session(
websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
"""Cancel session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.CancelSession
msg.session_id = session_id
msg.payload = b"{}"
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def task_request(
websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str
) -> None:
"""Send task request"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.TaskRequest
msg.session_id = session_id
msg.payload = payload
logger.info(f"Sending: {msg}")
await websocket.send(msg.marshal())