Files
dsProject/dsLightRag/Util/ASRClient.py
2025-08-22 09:23:33 +08:00

126 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import time
import uuid
import requests
from Config.Config import HS_ASR_APP_ID, HS_ASR_TOKEN
class ASRError(RuntimeError):
    """Raised when an ASR submit/query call fails or polling gives up."""


class ASRClient:
    """Client for the openspeech big-model file ASR HTTP API (``/api/v3/auc/bigmodel``).

    Workflow:
      * ``submit_task`` posts a recognition job for the audio at ``file_url``.
      * ``query_task`` fetches the job state using the request id from submit.
      * ``process_task`` does both and polls until the job finishes, then
        returns the recognised text.

    Errors raise :class:`ASRError` (a ``RuntimeError``) instead of calling
    ``exit(1)``. This keeps a failing request from terminating the whole
    interpreter when the client runs inside a long-lived service.
    """

    # Status code (response header ``X-Api-Status-Code``) meaning success/finished.
    _SUCCESS_CODE = "20000000"
    # Codes for which process_task keeps polling. Per the service docs these
    # are "processing" and "queued" -- NOTE(review): confirm against the API docs.
    _PENDING_CODES = ("20000001", "20000002")

    def __init__(self, appid=None, token=None, file_url=None, timeout=30):
        """Create a client.

        Args:
            appid: App key sent as ``X-Api-App-Key``; defaults to ``HS_ASR_APP_ID``.
            token: Access token sent as ``X-Api-Access-Key``; defaults to ``HS_ASR_TOKEN``.
            file_url: URL of the audio the service should download and
                transcribe; defaults to a built-in sample MP3.
            timeout: Per-request HTTP timeout in seconds passed to
                ``requests.post`` (``None`` waits indefinitely).
        """
        self.appid = appid or HS_ASR_APP_ID
        self.token = token or HS_ASR_TOKEN
        self.file_url = file_url or "https://ttc-advisory-oss.oss-cn-hangzhou.aliyuncs.com/lark_audio/int/T_APLA_1941058348698869760.mp3"
        self.submit_url = "https://openspeech-direct.zijieapi.com/api/v3/auc/bigmodel/submit"
        self.query_url = "https://openspeech-direct.zijieapi.com/api/v3/auc/bigmodel/query"
        self.timeout = timeout
        self.task_id = None
        self.x_tt_logid = None

    def _prepare_headers(self, task_id=None):
        """Build the authentication/routing headers for a request.

        Args:
            task_id: Id of an already-submitted task. When omitted, a fresh
                UUID4 is generated and becomes the new task's request id.

        Returns:
            dict of HTTP headers. ``X-Tt-Logid`` (from the submit response)
            is included when both ``task_id`` and ``self.x_tt_logid`` are set.
            Otherwise ``X-Api-Sequence: -1`` is sent instead.
        """
        headers = {
            "X-Api-App-Key": self.appid,
            "X-Api-Access-Key": self.token,
            "X-Api-Resource-Id": "volc.bigasr.auc",
            "X-Api-Request-Id": task_id or str(uuid.uuid4()),
        }
        if task_id and self.x_tt_logid:
            headers["X-Tt-Logid"] = self.x_tt_logid
        else:
            headers["X-Api-Sequence"] = "-1"
        return headers

    def _prepare_request_body(self):
        """Build the JSON body for the submit request.

        Speaker info, punctuation, ITN, channel split and DDC are all enabled.
        """
        return {
            "user": {
                "uid": "fake_uid"
            },
            "audio": {
                "url": self.file_url
            },
            "request": {
                "model_name": "bigmodel",
                "enable_channel_split": True,
                "enable_ddc": True,
                "enable_speaker_info": True,
                "enable_punc": True,
                "enable_itn": True,
                "corpus": {
                    "correct_table_name": "",
                    "context": ""
                }
            }
        }

    def submit_task(self):
        """Submit a recognition job for ``self.file_url``.

        Returns:
            ``(task_id, x_tt_logid)``. Both are also stored on the instance for
            ``query_task``. ``x_tt_logid`` is ``""`` if the header was absent.

        Raises:
            ASRError: if the response status header is not ``20000000``.
        """
        headers = self._prepare_headers()
        self.task_id = headers["X-Api-Request-Id"]
        request_body = self._prepare_request_body()
        print(f'Submit task id: {self.task_id}')
        response = requests.post(
            self.submit_url,
            data=json.dumps(request_body),
            headers=headers,
            timeout=self.timeout,
        )
        status = response.headers.get("X-Api-Status-Code")
        message = response.headers.get("X-Api-Message", "")
        if status != self._SUCCESS_CODE:
            print(f'Submit task failed and the response headers are: {response.headers}')
            raise ASRError(f"ASR submit failed (status={status!r}, message={message!r})")
        print(f'Submit task response header X-Api-Status-Code: {status}')
        print(f'Submit task response header X-Api-Message: {message}')
        self.x_tt_logid = response.headers.get("X-Tt-Logid", "")
        print(f'Submit task response header X-Tt-Logid: {self.x_tt_logid}\n')
        return self.task_id, self.x_tt_logid

    def query_task(self):
        """Query the state of the task created by ``submit_task``.

        Returns:
            The raw ``requests`` response. Its ``X-Api-Status-Code`` header
            carries the task state and its JSON body carries the result.

        Raises:
            ASRError: if no task id / log id is stored yet, or the response
                carries no ``X-Api-Status-Code`` header.
        """
        if not self.task_id or not self.x_tt_logid:
            print("Task not submitted yet. Please call submit_task() first.")
            raise ASRError("Task not submitted yet (or submit returned no X-Tt-Logid); "
                           "call submit_task() first.")
        headers = self._prepare_headers(self.task_id)
        response = requests.post(
            self.query_url,
            data=json.dumps({}),
            headers=headers,
            timeout=self.timeout,
        )
        if 'X-Api-Status-Code' not in response.headers:
            print(f'Query task failed and the response headers are: {response.headers}')
            raise ASRError("ASR query response has no X-Api-Status-Code header")
        print(f'Query task response header X-Api-Status-Code: {response.headers["X-Api-Status-Code"]}')
        print(f'Query task response header X-Api-Message: {response.headers.get("X-Api-Message", "")}')
        print(f'Query task response header X-Tt-Logid: {response.headers.get("X-Tt-Logid", "")}\n')
        return response

    def extract_text_from_result(self, result):
        """Return ``result["result"]["text"]``, or ``""`` if that path is absent.

        A non-dict ``result`` or a non-dict ``result["result"]`` yields ``""``
        instead of raising.
        """
        if not isinstance(result, dict):
            return ""
        inner = result.get("result")
        # Guard the type: a list/str would pass the ``in`` test but fail indexing.
        if isinstance(inner, dict) and "text" in inner:
            return inner["text"]
        return ""

    def process_task(self, poll_interval=1, max_wait=None):
        """Submit the job, poll until it finishes, and return the recognised text.

        Args:
            poll_interval: Seconds to sleep between queries.
            max_wait: Optional overall polling budget in seconds. ``None``
                (the default) polls until the task finishes or fails.

        Returns:
            The recognised text (``""`` if the result has no text field).

        Raises:
            ASRError: on submit/query failure, on a non-pending failure
                status, or when ``max_wait`` elapses.
        """
        self.submit_task()
        deadline = None if max_wait is None else time.monotonic() + max_wait
        while True:
            query_response = self.query_task()
            code = query_response.headers.get('X-Api-Status-Code', "")
            if code == self._SUCCESS_CODE:  # task finished
                result_json = query_response.json()
                text_content = self.extract_text_from_result(result_json)
                print(f"识别文本: {text_content}")
                return text_content
            if code not in self._PENDING_CODES:  # task failed
                print("FAILED!")
                raise ASRError(
                    f"ASR task {self.task_id} failed with status {code!r}: "
                    f"{query_response.headers.get('X-Api-Message', '')}"
                )
            if deadline is not None and time.monotonic() >= deadline:
                raise ASRError(f"ASR task {self.task_id} did not finish within {max_wait} s")
            time.sleep(poll_interval)
if __name__ == '__main__':
    # Script entry: transcribe the audio URL given as the first CLI argument,
    # or the sample URL below when no argument is supplied.
    import sys

    default_audio_url = "https://dsideal.obs.cn-north-1.myhuaweicloud.com/HuangHai/Temp/temp_audio_20250822085243.wav"
    audio_url = sys.argv[1] if len(sys.argv) > 1 else default_audio_url
    # Create an ASR client for that audio file URL.
    asr_client = ASRClient(file_url=audio_url)
    # Submit, poll until done, and get the recognised text (also printed by process_task).
    text_result = asr_client.process_task()