Files
QingLong/XiaoZhi/xiaozhi-esp32/main/protocols/mqtt_protocol.cc
2025-08-15 09:13:13 +08:00

313 lines
10 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "mqtt_protocol.h"
#include "board.h"
#include "application.h"
#include "settings.h"
#include <esp_log.h>
#include <ml307_mqtt.h>
#include <ml307_udp.h>
#include <cstring>
#include <arpa/inet.h>
#include "assets/lang_config.h"
#define TAG "MQTT"
MqttProtocol::MqttProtocol() {
event_group_handle_ = xEventGroupCreate();
}
MqttProtocol::~MqttProtocol() {
ESP_LOGI(TAG, "MqttProtocol deinit");
if (udp_ != nullptr) {
delete udp_;
}
if (mqtt_ != nullptr) {
delete mqtt_;
}
vEventGroupDelete(event_group_handle_);
}
void MqttProtocol::Start() {
StartMqttClient(false);
}
bool MqttProtocol::StartMqttClient(bool report_error) {
if (mqtt_ != nullptr) {
ESP_LOGW(TAG, "Mqtt client already started");
delete mqtt_;
}
Settings settings("mqtt", false);
endpoint_ = settings.GetString("endpoint");
client_id_ = settings.GetString("client_id");
username_ = settings.GetString("username");
password_ = settings.GetString("password");
publish_topic_ = settings.GetString("publish_topic");
if (endpoint_.empty()) {
ESP_LOGW(TAG, "MQTT endpoint is not specified");
if (report_error) {
SetError(Lang::Strings::SERVER_NOT_FOUND);
}
return false;
}
mqtt_ = Board::GetInstance().CreateMqtt();
mqtt_->SetKeepAlive(90);
mqtt_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Disconnected from endpoint");
});
mqtt_->OnMessage([this](const std::string& topic, const std::string& payload) {
cJSON* root = cJSON_Parse(payload.c_str());
if (root == nullptr) {
ESP_LOGE(TAG, "Failed to parse json message %s", payload.c_str());
return;
}
cJSON* type = cJSON_GetObjectItem(root, "type");
if (type == nullptr) {
ESP_LOGE(TAG, "Message type is not specified");
cJSON_Delete(root);
return;
}
if (strcmp(type->valuestring, "hello") == 0) {
ParseServerHello(root);
} else if (strcmp(type->valuestring, "goodbye") == 0) {
auto session_id = cJSON_GetObjectItem(root, "session_id");
ESP_LOGI(TAG, "Received goodbye message, session_id: %s", session_id ? session_id->valuestring : "null");
if (session_id == nullptr || session_id_ == session_id->valuestring) {
Application::GetInstance().Schedule([this]() {
CloseAudioChannel();
});
}
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
}
cJSON_Delete(root);
last_incoming_time_ = std::chrono::steady_clock::now();
});
ESP_LOGI(TAG, "Connecting to endpoint %s", endpoint_.c_str());
if (!mqtt_->Connect(endpoint_, 8883, client_id_, username_, password_)) {
ESP_LOGE(TAG, "Failed to connect to endpoint");
SetError(Lang::Strings::SERVER_NOT_CONNECTED);
return false;
}
ESP_LOGI(TAG, "Connected to endpoint");
return true;
}
bool MqttProtocol::SendText(const std::string& text) {
if (publish_topic_.empty()) {
return false;
}
if (!mqtt_->Publish(publish_topic_, text)) {
ESP_LOGE(TAG, "Failed to publish message: %s", text.c_str());
SetError(Lang::Strings::SERVER_ERROR);
return false;
}
return true;
}
void MqttProtocol::SendAudio(const std::vector<uint8_t>& data) {
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ == nullptr) {
return;
}
std::string nonce(aes_nonce_);
*(uint16_t*)&nonce[2] = htons(data.size());
*(uint32_t*)&nonce[12] = htonl(++local_sequence_);
std::string encrypted;
encrypted.resize(aes_nonce_.size() + data.size());
memcpy(encrypted.data(), nonce.data(), nonce.size());
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
if (mbedtls_aes_crypt_ctr(&aes_ctx_, data.size(), &nc_off, (uint8_t*)nonce.c_str(), stream_block,
(uint8_t*)data.data(), (uint8_t*)&encrypted[nonce.size()]) != 0) {
ESP_LOGE(TAG, "Failed to encrypt audio data");
return;
}
busy_sending_audio_ = true;
udp_->Send(encrypted);
busy_sending_audio_ = false;
}
void MqttProtocol::CloseAudioChannel() {
{
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ != nullptr) {
delete udp_;
udp_ = nullptr;
}
}
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"goodbye\"";
message += "}";
SendText(message);
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
}
bool MqttProtocol::OpenAudioChannel() {
if (mqtt_ == nullptr || !mqtt_->IsConnected()) {
ESP_LOGI(TAG, "MQTT is not connected, try to connect now");
if (!StartMqttClient(true)) {
return false;
}
}
busy_sending_audio_ = false;
error_occurred_ = false;
session_id_ = "";
xEventGroupClearBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
// 发送 hello 消息申请 UDP 通道
std::string message = "{";
message += "\"type\":\"hello\",";
message += "\"version\": 3,";
message += "\"transport\":\"udp\",";
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1, \"frame_duration\":" + std::to_string(OPUS_FRAME_DURATION_MS);
message += "}}";
if (!SendText(message)) {
return false;
}
// 等待服务器响应
EventBits_t bits = xEventGroupWaitBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT, pdTRUE, pdFALSE, pdMS_TO_TICKS(10000));
if (!(bits & MQTT_PROTOCOL_SERVER_HELLO_EVENT)) {
ESP_LOGE(TAG, "Failed to receive server hello");
SetError(Lang::Strings::SERVER_TIMEOUT);
return false;
}
std::lock_guard<std::mutex> lock(channel_mutex_);
if (udp_ != nullptr) {
delete udp_;
}
udp_ = Board::GetInstance().CreateUdp();
udp_->OnMessage([this](const std::string& data) {
if (data.size() < sizeof(aes_nonce_)) {
ESP_LOGE(TAG, "Invalid audio packet size: %zu", data.size());
return;
}
if (data[0] != 0x01) {
ESP_LOGE(TAG, "Invalid audio packet type: %x", data[0]);
return;
}
uint32_t sequence = ntohl(*(uint32_t*)&data[12]);
if (sequence < remote_sequence_) {
ESP_LOGW(TAG, "Received audio packet with old sequence: %lu, expected: %lu", sequence, remote_sequence_);
return;
}
if (sequence != remote_sequence_ + 1) {
ESP_LOGW(TAG, "Received audio packet with wrong sequence: %lu, expected: %lu", sequence, remote_sequence_ + 1);
}
std::vector<uint8_t> decrypted;
size_t decrypted_size = data.size() - aes_nonce_.size();
size_t nc_off = 0;
uint8_t stream_block[16] = {0};
decrypted.resize(decrypted_size);
auto nonce = (uint8_t*)data.data();
auto encrypted = (uint8_t*)data.data() + aes_nonce_.size();
int ret = mbedtls_aes_crypt_ctr(&aes_ctx_, decrypted_size, &nc_off, nonce, stream_block, encrypted, (uint8_t*)decrypted.data());
if (ret != 0) {
ESP_LOGE(TAG, "Failed to decrypt audio data, ret: %d", ret);
return;
}
if (on_incoming_audio_ != nullptr) {
on_incoming_audio_(std::move(decrypted));
}
remote_sequence_ = sequence;
last_incoming_time_ = std::chrono::steady_clock::now();
});
udp_->Connect(udp_server_, udp_port_);
if (on_audio_channel_opened_ != nullptr) {
on_audio_channel_opened_();
}
return true;
}
void MqttProtocol::ParseServerHello(const cJSON* root) {
auto transport = cJSON_GetObjectItem(root, "transport");
if (transport == nullptr || strcmp(transport->valuestring, "udp") != 0) {
ESP_LOGE(TAG, "Unsupported transport: %s", transport->valuestring);
return;
}
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (session_id != nullptr) {
session_id_ = session_id->valuestring;
ESP_LOGI(TAG, "Session ID: %s", session_id_.c_str());
}
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
server_sample_rate_ = sample_rate->valueint;
}
auto frame_duration = cJSON_GetObjectItem(audio_params, "frame_duration");
if (frame_duration != NULL) {
server_frame_duration_ = frame_duration->valueint;
}
}
auto udp = cJSON_GetObjectItem(root, "udp");
if (udp == nullptr) {
ESP_LOGE(TAG, "UDP is not specified");
return;
}
udp_server_ = cJSON_GetObjectItem(udp, "server")->valuestring;
udp_port_ = cJSON_GetObjectItem(udp, "port")->valueint;
auto key = cJSON_GetObjectItem(udp, "key")->valuestring;
auto nonce = cJSON_GetObjectItem(udp, "nonce")->valuestring;
// auto encryption = cJSON_GetObjectItem(udp, "encryption")->valuestring;
// ESP_LOGI(TAG, "UDP server: %s, port: %d, encryption: %s", udp_server_.c_str(), udp_port_, encryption);
aes_nonce_ = DecodeHexString(nonce);
mbedtls_aes_init(&aes_ctx_);
mbedtls_aes_setkey_enc(&aes_ctx_, (const unsigned char*)DecodeHexString(key).c_str(), 128);
local_sequence_ = 0;
remote_sequence_ = 0;
xEventGroupSetBits(event_group_handle_, MQTT_PROTOCOL_SERVER_HELLO_EVENT);
}
static const char hex_chars[] = "0123456789ABCDEF";
// 辅助函数,将单个十六进制字符转换为对应的数值
static inline uint8_t CharToHex(char c) {
if (c >= '0' && c <= '9') return c - '0';
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
return 0; // 对于无效输入返回0
}
std::string MqttProtocol::DecodeHexString(const std::string& hex_string) {
std::string decoded;
decoded.reserve(hex_string.size() / 2);
for (size_t i = 0; i < hex_string.size(); i += 2) {
char byte = (CharToHex(hex_string[i]) << 4) | CharToHex(hex_string[i + 1]);
decoded.push_back(byte);
}
return decoded;
}
bool MqttProtocol::IsAudioChannelOpened() const {
return udp_ != nullptr && !error_occurred_ && !IsTimeout();
}