diff --git a/runtime/ops/mapper/__init__.py b/runtime/ops/mapper/__init__.py index ed0a0fcb..193db491 100644 --- a/runtime/ops/mapper/__init__.py +++ b/runtime/ops/mapper/__init__.py @@ -47,7 +47,30 @@ def _import_operators(): from . import remove_duplicate_sentences from . import knowledge_relation_slice from . import pii_ner_detection - # ===== Video operators (PR1-PR5) ===== + + # ===== Audio operators ===== + from . import audio_anomaly_filter + from . import audio_asr_pipeline + from . import audio_asr_transcribe + from . import audio_dc_offset_removal + from . import audio_emotion_recognize + from . import audio_fast_lang_id + from . import audio_fast_lang_id_text + from . import audio_format_convert + from . import audio_gtcrn_denoise + from . import audio_hum_notch + from . import audio_noise_gate + from . import audio_pre_emphasis + from . import audio_quantize_encode + from . import audio_rms_loudness_normalize + from . import audio_simple_agc + from . import audio_soft_peak_limiter + from . import audio_sound_classify + from . import audio_telephony_bandpass + from . import audio_text_summarize + from . import audio_trim_silence_edges + + # ===== Video operators (PR1-PR5) ===== from . import _video_common from . import video_format_convert from . 
import video_sensitive_detect diff --git a/runtime/ops/mapper/audio_anomaly_filter/README.md b/runtime/ops/mapper/audio_anomaly_filter/README.md new file mode 100644 index 00000000..fab93a76 --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/README.md @@ -0,0 +1,41 @@ +# AudioAnomalyFilter 异常语音检测与过滤算子 + +## 概述 + +AudioAnomalyFilter 用于对音频做快速质量检测,计算时长、静音帧比例与音频可读性,并给出 `quality_flag`。算子不再通过清空 `text/data` 模拟删除文件,而是写入结构化质量标签;下游音频算子可根据标签软跳过异常样本。 + +## 功能特性 + +- **时长检测**:支持最小时长/最大时长阈值 +- **静音比例检测**:基于短时 RMS 统计静音帧占比 +- **可读性检测**:文本文件强行改成 `.wav` 等不可读取音频会被标记为 `invalid` +- **下游门控**:支持让后续音频算子跳过异常样本,符合 DataMate 一文件一输出链路 +- **结果结构化输出**:报告写入 `ext_params.audio_quality` + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| minDur | inputNumber | 1.0 | 最小时长(秒),小于该值视为异常 | +| maxDur | inputNumber | 20000.0 | 最大时长(秒),大于该值视为异常 | +| silenceRatioTh | slider | 0.8 | 静音帧比例阈值(0~1),>= 阈值视为异常 | +| silenceRmsRatioTh | slider | 0.05 | 静音判定阈值 = global_rms * 该比例 | +| skipInvalidDownstream | switch | true | true=后续音频算子遇到 invalid 软跳过;false=仅打标并继续处理 | + +## 输入输出 + +- **输入**:`sample["filePath"]`(音频文件路径) +- **输出**: + - `sample["ext_params"]["audio_quality"]`: + - `quality_flag`: `ok/invalid` + - `duration/silence_ratio/global_rms/reason/read_error/skip_downstream` + - 如果该算子为链路最后一个算子:导出当前音频,质量报告写入 `ext_params.audio_quality` + - 如果该算子位于链路中间:保持当前音频,后续音频算子按 `skip_downstream` 决定是否软跳过 + +## 依赖说明 + +- **Python 依赖**:优先 `torchaudio`,兜底 `soundfile` + +## 版本历史 + +- **v1.0.0**:支持时长/静音比例/可读性检测,按 DataMate 链路语义写质量标签并门控下游 diff --git a/runtime/ops/mapper/audio_anomaly_filter/__init__.py b/runtime/ops/mapper/audio_anomaly_filter/__init__.py new file mode 100644 index 00000000..fb9b4521 --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioAnomalyFilter', + module_path="ops.mapper.audio_anomaly_filter.process") diff --git 
# Lower-case file extensions (without the leading dot) that are treated as audio.
AUDIO_EXTS = {
    "aac",
    "aif",
    "aiff",
    "amr",
    "au",
    "flac",
    "m4a",
    "mp3",
    "oga",
    "ogg",
    "opus",
    "snd",
    "wav",
    "webm",
    "wma",
}


def _parts(path_value: str) -> set[str]:
    """Return the lower-cased components of *path_value*; empty set if it cannot be parsed."""
    try:
        return {part.lower() for part in Path(path_value).parts}
    except (TypeError, ValueError):
        return set()


def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool:
    """Return True when any component of the sample's path is named ``references``."""
    path_value = str(sample.get(filepath_key) or "")
    return "references" in _parts(path_value)


def _ext_from_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
) -> str:
    """Return the sample's extension (lower-case, no dot).

    The first non-empty value wins, in this order: ``target_type``, ``fileType``,
    then the suffix of the file path. Returns ``""`` when none is available.
    """
    for key in (target_type_key, filetype_key):
        value = str(sample.get(key) or "").strip().lower().lstrip(".")
        if value:
            return value
    path_value = str(sample.get(filepath_key) or "").strip()
    return Path(path_value).suffix.lower().lstrip(".") if path_value else ""


def is_audio_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    data_key: str = "data",
) -> bool:
    """Decide whether *sample* should be processed by an audio operator.

    Samples under a ``references`` directory are never audio. Otherwise a sample
    counts as audio when it carries non-empty raw bytes, or when its resolved
    extension is listed in ``AUDIO_EXTS``.
    """
    if is_reference_sample(sample, filepath_key):
        return False
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        return True
    return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS


def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return ``"invalid_audio_quality:<reason>"`` if the sample is flagged invalid, else ``""``.

    Two sources are checked in order:

    1. a ``__quality_invalid`` marker embedded in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``;
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"`` and a
       truthy ``skip_downstream`` (defaults to True when absent).
    """
    marker = "__quality_invalid"
    for key in ("fileName", "sourceFileName", "filePath"):
        marker_source = Path(str(sample.get(key) or "")).stem.lower()
        if marker in marker_source:
            # Everything after the marker is the reason; fall back to a generic one.
            reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""
    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        # The flag may arrive as text; accept the usual truthy spellings.
        skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"}
    if not skip_downstream:
        return ""
    reason = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{reason}"


def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped *sample* and blank its payload fields.

    The reason is stored under ``ext_params["audio_skip"][op_name]``; text, data,
    file type and target type are reset to empty values. The sample is modified
    in place and also returned.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        # Preserve whatever was stored there instead of dropping it.
        ext = {"_raw": ext}
    skipped = ext.get("audio_skip")
    if not isinstance(skipped, dict):
        # A missing, None or otherwise malformed entry is replaced rather than
        # crashing on item assignment.
        skipped = {}
        ext["audio_skip"] = skipped
    skipped[op_name] = reason
    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
def _as_bool(v: object) -> bool:
    """Interpret *v* as a boolean; strings count as True for 1/true/yes/y/on (case-insensitive)."""
    if isinstance(v, bool):
        return v
    return str(v).strip().lower() in {"1", "true", "yes", "y", "on"}


def _audio_ext(sample: Dict[str, Any], default_ext: str = "wav") -> str:
    """Return the sample's audio extension (lower-case, no dot).

    Looks at ``target_type``, then ``fileType``, then the suffix of ``filePath``;
    falls back to *default_ext* when none of them yields a value.
    """
    for key in ("target_type", "fileType"):
        ext = str(sample.get(key) or "").strip().lower().lstrip(".")
        if ext:
            return ext
    path_value = str(sample.get("filePath") or "").strip()
    suffix = Path(path_value).suffix.lower().lstrip(".") if path_value else ""
    return suffix or default_ext


def _source_audio_bytes(sample: Dict[str, Any], data_key: str, filepath_key: str, read_file: bool = False) -> bytes:
    """Return the sample's raw audio bytes.

    In-memory bytes under *data_key* win. The file at *filepath_key* is only read
    when *read_file* is True; in every other case ``b""`` is returned.
    """
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        return bytes(data)
    if not read_file:
        return b""
    path = Path(str(sample.get(filepath_key) or "")).expanduser()
    if path.exists() and path.is_file():
        return path.read_bytes()
    return b""


def _safe_marker(value: str, default: str = "invalid_audio") -> str:
    """Turn *value* into a filename-safe token of at most 80 characters.

    Runs of characters outside ``[A-Za-z0-9._-]`` become a single ``_`` and
    leading/trailing ``._-`` are trimmed; *default* is used when nothing is left.
    """
    marker = re.sub(r"[^A-Za-z0-9._-]+", "_", str(value or default)).strip("._-")
    return marker[:80] or default


def _strip_quality_marker(stem: str) -> str:
    """Remove a trailing ``__quality_invalid[_reason]`` marker from a file stem."""
    return re.sub(r"__quality_invalid(?:_[A-Za-z0-9._-]+)?$", "", str(stem or "sample"))


def _mark_quality_filename(sample: Dict[str, Any], filename_key: str, reason: str, target_ext: str) -> None:
    """Rewrite ``sample[filename_key]`` to ``<stem>__quality_invalid_<reason>.<target_ext>``.

    Any marker already present on the stem is removed first, so repeated marking
    does not stack markers.
    """
    file_name = str(sample.get(filename_key) or "").strip()
    stem = _strip_quality_marker(Path(file_name).stem if file_name else "sample")
    sample[filename_key] = f"{stem}__quality_invalid_{_safe_marker(reason)}.{target_ext}"


def _load_wave_mono(path: Path) -> Tuple[List[float], int]:
    """Load *path* as mono audio and return ``(samples, sample_rate)``.

    torchaudio is tried first; if it is unavailable or fails, soundfile is used.
    Multi-channel audio is averaged across channels.

    Raises:
        RuntimeError: when both backends fail. The message carries both errors so
            the real decoding failure is not hidden by the fallback's error.
    """
    try:
        import torchaudio  # type: ignore

        wav, sr = torchaudio.load(str(path))
        if wav.ndim > 1:
            wav = wav.mean(dim=0, keepdim=True)
        return wav.squeeze(0).float().tolist(), int(sr)
    except Exception as torchaudio_error:
        try:
            import soundfile as sf  # type: ignore

            data, sr = sf.read(str(path), always_2d=False)
            if getattr(data, "ndim", 1) > 1:
                data = data.mean(axis=1)
            return data.tolist(), int(sr)
        except Exception as e:
            raise RuntimeError(
                f"failed to read audio: {path}, error={e}; "
                f"torchaudio error={type(torchaudio_error).__name__}: {torchaudio_error}"
            ) from e
def _load_source_mono(sample: Dict[str, Any], data_key: str, filepath_key: str) -> Tuple[List[float], int]:
    """Decode the sample's audio to mono and return ``(samples, sample_rate)``.

    In-memory bytes under *data_key* are written to a temporary file (the
    decoder works on paths) which is always removed afterwards; otherwise the
    file at *filepath_key* is decoded directly.
    """
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=f".{_audio_ext(sample)}", delete=False) as tmp:
                # Remember the path before writing so a failed write is still cleaned up.
                tmp_path = Path(tmp.name)
                tmp.write(bytes(data))
            return _load_wave_mono(tmp_path)
        finally:
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not mask the result.
                    pass
    return _load_wave_mono(Path(str(sample.get(filepath_key) or "")).expanduser().resolve())


def _frame_rms(x: List[float], sr: int, frame_ms: float, hop_ms: float) -> Tuple[List[float], float]:
    """Compute short-time RMS values and the RMS of the whole signal.

    Args:
        x: mono samples.
        sr: sample rate in Hz.
        frame_ms: frame length in milliseconds.
        hop_ms: hop between frame starts in milliseconds.

    Returns:
        ``(per_frame_rms, global_rms)``; ``([], 0.0)`` for empty input or a
        non-positive sample rate. Frames at the end of the signal are shorter
        than *frame_ms* and are normalised by their actual length.
    """
    if not x or sr <= 0:
        return [], 0.0
    frame_len = max(1, int(sr * frame_ms / 1000.0))
    hop = max(1, int(sr * hop_ms / 1000.0))
    total_sq = sum(float(v) * float(v) for v in x)
    global_rms = math.sqrt(total_sq / max(1, len(x)))
    rms_list: List[float] = []
    for start in range(0, len(x), hop):
        end = min(start + frame_len, len(x))
        if end <= start:
            continue
        frame = x[start:end]
        rms_list.append(math.sqrt(sum(float(v) * float(v) for v in frame) / max(1, len(frame))))
    return rms_list, global_rms


class AudioAnomalyFilter(Mapper):
    """Mapper that tags audio samples with a quality report instead of dropping them.

    The report is stored in ``ext_params["audio_quality"]`` and covers duration,
    silent-frame ratio and readability. Key attributes such as ``filepath_key``
    and ``is_last_op`` are not set here; presumably they come from ``Mapper``.
    """

    def __init__(self, *args, **kwargs):
        """Read the thresholds from the operator settings (camelCase keyword arguments)."""
        super().__init__(*args, **kwargs)
        self.min_dur = float(kwargs.get("minDur", 1.0))
        self.max_dur = float(kwargs.get("maxDur", 20000.0))
        self.silence_ratio_th = float(kwargs.get("silenceRatioTh", 0.8))
        self.silence_rms_ratio_th = float(kwargs.get("silenceRmsRatioTh", 0.05))
        self.skip_invalid_downstream = _as_bool(kwargs.get("skipInvalidDownstream", True))

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Analyse one sample and attach its quality report.

        Non-audio and reference samples are marked as skipped. For audio samples
        the report is written to ``ext_params["audio_quality"]``; when this is
        the last operator the audio bytes are placed in the data field and an
        invalid sample's file name receives a ``__quality_invalid_<reason>`` marker.
        """
        start = time.time()
        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        # Only in-memory bytes here; the file is read later and only if it must be exported.
        audio_bytes_for_export = _source_audio_bytes(sample, self.data_key, self.filepath_key)
        path_value = str(sample.get(self.filepath_key) or "").strip()
        path_exists = bool(audio_bytes_for_export) or (bool(path_value) and Path(path_value).expanduser().exists())
        reasons: List[str] = []
        quality_flag = "ok"
        read_error = ""

        if not path_exists:
            duration = 0.0
            silence_ratio = 1.0
            global_rms = 0.0
            quality_flag = "invalid"
            read_error = f"FileNotFoundError: input audio does not exist: {sample.get(self.filepath_key)}"
            reasons.append("missing_audio_file")
        else:
            try:
                wav, sr = _load_source_mono(sample, self.data_key, self.filepath_key)
                duration = float(len(wav)) / float(sr) if sr > 0 else 0.0
                rms_frames, global_rms = _frame_rms(wav, sr, frame_ms=25.0, hop_ms=10.0)
                if not rms_frames or global_rms <= 0.0:
                    silence_ratio = 1.0
                else:
                    # A frame is silent when its RMS is below a fraction of the global RMS.
                    threshold = max(1e-8, global_rms * float(self.silence_rms_ratio_th))
                    silent = sum(1 for rms in rms_frames if rms < threshold)
                    silence_ratio = float(silent) / float(len(rms_frames))
            except Exception as e:
                # Anything that cannot be decoded (e.g. a text file renamed to .wav) is invalid.
                duration = 0.0
                silence_ratio = 1.0
                global_rms = 0.0
                quality_flag = "invalid"
                read_error = f"{type(e).__name__}: {e}"
                reasons.append("unreadable_audio")

        if duration <= 0.0:
            quality_flag = "invalid"
            if "duration_le_zero" not in reasons:
                reasons.append("duration_le_zero")
        elif duration < self.min_dur:
            quality_flag = "invalid"
            reasons.append("too_short")
        elif duration > self.max_dur:
            quality_flag = "invalid"
            reasons.append("too_long")
        if silence_ratio >= self.silence_ratio_th:
            quality_flag = "invalid"
            reasons.append("too_much_silence")

        report = {
            "quality_flag": quality_flag,
            "duration": round(duration, 3),
            "silence_ratio": round(silence_ratio, 4),
            "global_rms": round(global_rms, 6),
            "reason": ",".join(reasons) if reasons else "",
            "read_error": read_error,
            "skip_downstream": self.skip_invalid_downstream,
        }
        ext = sample.get(self.ext_params_key, {})
        if not isinstance(ext, dict):
            ext = {"_raw": ext}
        ext["audio_quality"] = report
        sample[self.ext_params_key] = ext

        sample[self.text_key] = ""
        if self.is_last_op and not audio_bytes_for_export:
            audio_bytes_for_export = _source_audio_bytes(
                sample,
                self.data_key,
                self.filepath_key,
                read_file=True,
            )
        if audio_bytes_for_export:
            sample[self.data_key] = audio_bytes_for_export
        if self.is_last_op:
            target_ext = _audio_ext(sample)
            # NOTE(review): fileType becomes "txt" while target_type keeps the audio
            # extension; presumably this steers the DataMate exporter, so confirm
            # before changing.
            sample[self.filetype_key] = "txt"
            sample[self.target_type_key] = target_ext
            if quality_flag == "invalid":
                _mark_quality_filename(sample, self.filename_key, report["reason"] or "invalid_audio", target_ext)

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioAnomalyFilter costs {time.time() - start:6f} s"
        )
        return sample
| 说明 | +|---|---|---:|---| +| doDenoise | switch | false | 是否启用 GTCRN 降噪 | +| denoiseModelPath | input | /models/AudioOperations/gtcrn/gtcrn.onnx | GTCRN ONNX 模型绝对路径 | +| doAnomalyFilter | switch | true | 是否启用异常语音检测与过滤 | +| minDur | inputNumber | 1.0 | 最小时长(秒) | +| maxDur | inputNumber | 20000.0 | 最大时长(秒) | +| silenceRatioTh | slider | 0.8 | 静音帧比例阈值(0~1) | +| silenceRmsRatioTh | slider | 0.05 | 静音判定阈值比例 | +| lidModelSource | input | /models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa | SpeechBrain LID 本地模型目录 | +| lidDevice | select | cpu | LID 推理设备(cpu/cuda/npu) | +| lidMaxSeconds | inputNumber | 3.0 | LID 只取前 N 秒,0=全长 | +| maxSegmentSeconds | inputNumber | 120 | 切分最大秒数 | +| asrDevice | select | npu | ASR 设备参数(npu/cpu/auto) | +| doKeywordRecall | switch | false | 是否在 ASR 后计算关键词召回率 | +| referencePath | input | /dataset/{dataset_id}/references | 参考文件或参考目录路径;写入 `extraFilePath` 供后续评估算子读取,路径不存在会回退 | +| zhKeywordPath | input | /dataset/{dataset_id}/references/zh_keyword.txt | 中文关键词文件;不存在时优先从 `referencePath/extraFilePath` 找 `zh_keyword.txt` | +| enKeywordPath | input | /dataset/{dataset_id}/references/en_keyword.txt | 英文关键词文件;不存在时优先从 `referencePath/extraFilePath` 找 `en_keyword.txt` | +| keepKeywordDetails | switch | false | 是否将逐句 hit/miss 明细写入 `ext_params` | + +## 输入输出 + +- **输入**:`sample["filePath"]`(音频文件路径) +- **输出**: + - `sample["text"]`:当前输入音频对应的转写文本,并导出为 `.txt` + - `sample["ext_params"]["audio_asr"]`: + - `lang`:LID 结果(zh/en) + - `artifacts`:中间产物路径(normalized/denoise/lid/split/asr/merged_text) + - `reference`:填写 `referencePath` 后记录参考资源路径,并传给后续评估算子 + - `keyword_recall`:启用 `doKeywordRecall` 后写入中英文关键词召回率、样本数与报告路径,报告位于 `audio_reports/asr_pipeline/<文件名>/keyword_recall.txt` + +## 依赖说明 + +- **Python 依赖**(按启用功能而定): + - normalization/切分:`pydub`、`soundfile`、`numpy` + - LID:`torch`、`torchaudio`、`speechbrain` + - 降噪:`onnxruntime`(以及 GTCRN 模型文件) +- **系统依赖**: + - `pydub` 通常需要 `ffmpeg` +- **关键词召回率**: + - 使用纯 Python 文本处理,不额外依赖模型 + +## 版本历史 + +- 
**v1.0.0**:首次发布,支持音频标准化/(可选)降噪/过滤/LID/切分/ASR/合并 +- **v1.1.0**:同步 `audio_preprocessor` 关键词召回率能力,支持可选中英文关键词召回率评估 diff --git a/runtime/ops/mapper/audio_asr_pipeline/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/__init__.py new file mode 100644 index 00000000..9d54df28 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioAsrPipeline', + module_path="ops.mapper.audio_asr_pipeline.process") diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml new file mode 100644 index 00000000..ac4498e9 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml @@ -0,0 +1,8 @@ +audio_config: + # audio_config.yaml - 音频格式化配置 + output_format: "wav" + channels: 1 + sample_rate: 16000 + sample_width: 2 + encoding: "pcm_s16le" + input_format: ["mp3", "wav", "aac", "m4a", "flac"] \ No newline at end of file diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml new file mode 100644 index 00000000..8d48be93 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml @@ -0,0 +1,6 @@ +eval_wer: + zh_ref: "input_data/validation/zh_transcript.txt" + en_ref: "input_data/validation/en_transcript.txt" + hyp: "output_data/asr/merged_text.txt" + work_dir: "output_data/validation" + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml new file mode 100644 index 00000000..17f2f588 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml @@ -0,0 
+1,6 @@ +merge_asr_by_source: + list_file: "output_data/split/item_with_lang.list" + zh_text: "output_data/asr/zh/ctc_greedy_search/text" + en_text: "output_data/asr/en/ctc_greedy_search/text" + output: "output_data/asr/merged_text.txt" + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__init__.py new file mode 100644 index 00000000..7cd34923 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__init__.py @@ -0,0 +1,3 @@ +"""Audio conversion utilities/CLI.""" + + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__main__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__main__.py new file mode 100644 index 00000000..bad5f88f --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__main__.py @@ -0,0 +1,7 @@ +from .cli import main + + +if __name__ == "__main__": + raise SystemExit(main()) + + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/audio_convert.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/audio_convert.py new file mode 100644 index 00000000..5999b580 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/audio_convert.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import os +import shutil +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Iterable, List, Sequence, Optional + +# 导入配置加载模块和颜色工具 +try: + from config_loader import get_audio_config, clear_config_cache + from color_utils import info, warning, error, ok, header, success, fail +except ImportError: + # 如果模块导入失败,尝试从当前目录导入 + sys.path.insert(0, str(Path(__file__).parent)) + 
def get_allowed_input_exts(config_path: Optional[str] = None) -> set[str]:
    """Return the accepted source extensions (lower-case, with leading dot) from the config.

    Args:
        config_path: optional path of the configuration file.
    """
    config = get_audio_config(config_path)
    input_formats = config.get('input_format', ['mp3', 'wav', 'aac', 'm4a', 'flac'])
    return {f".{fmt.lower().lstrip('.')}" for fmt in input_formats}


@dataclass(frozen=True)
class ConvertSpec:
    """Target audio format, populated from the audio configuration.

    The class supplies its own ``__init__`` so it can be built from a config
    path. Because the dataclass is frozen, fields are assigned through
    ``object.__setattr__`` and validation is triggered explicitly.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Load the configuration and copy the conversion settings onto the instance.

        Args:
            config_path: optional path of the configuration file.
        """
        config = get_audio_config(config_path)

        # Frozen instances reject normal attribute assignment.
        object.__setattr__(self, 'channels', config.get('channels', 1))
        object.__setattr__(self, 'frame_rate', config.get('sample_rate', 16000))
        object.__setattr__(self, 'sample_width_bytes', config.get('sample_width', 2))
        object.__setattr__(self, 'encoding', config.get('encoding', 'pcm_s16le'))
        object.__setattr__(self, 'output_format', config.get('output_format', 'wav'))

        # The generated __init__ is not used, so validation must be called by hand.
        self.__post_init__()

    # Fields assigned in __init__.
    channels: int
    frame_rate: int
    sample_width_bytes: int
    encoding: str
    output_format: str

    def __post_init__(self):
        """Validate the configured values; raises ValueError on an unsupported setting."""
        if self.channels not in [1, 2]:
            raise ValueError(f"声道数必须是1或2,当前: {self.channels}")
        if self.frame_rate <= 0:
            raise ValueError(f"采样率必须为正数,当前: {self.frame_rate}")
        if self.sample_width_bytes not in [1, 2, 3, 4]:
            raise ValueError(f"采样位宽必须是1-4字节,当前: {self.sample_width_bytes}")


def _import_pydub():
    """Import pydub from the DataMate runtime environment."""
    try:
        from pydub import AudioSegment  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError(f"无法导入 pydub,请在 DataMate 运行环境安装 pydub。原始错误:{e}") from e
    return AudioSegment


def _read_index_file(path: Path) -> List[Path]:
    """Read an index file with one audio path per line; blank lines and ``#`` comments are skipped.

    Raises:
        FileNotFoundError: when *path* does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"索引文件不存在: {path}")
    items: List[Path] = []
    for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        s = line.strip()
        if not s or s.startswith("#"):
            continue
        items.append(Path(s))
    return items


def _expand_inputs(paths: Sequence[str], index_file: Optional[str]) -> List[Path]:
    """Merge index-file entries (first) with positional paths, dropping duplicates in order."""
    inputs: List[Path] = []
    if index_file:
        inputs.extend(_read_index_file(Path(index_file)))
    inputs.extend(Path(p) for p in paths)
    # dict keeps first-seen order, which gives an order-preserving de-duplication.
    unique: dict = {}
    for p in inputs:
        unique.setdefault(os.fspath(p), p)
    return list(unique.values())


def _validate_inputs(inputs: Sequence[Path], config_path: Optional[str] = None) -> None:
    """Check that every input exists, is a file and has a configured source extension.

    Args:
        inputs: input file paths.
        config_path: optional path of the configuration file.

    Raises:
        ValueError: no inputs, an input is not a file, or its format is unsupported.
        FileNotFoundError: an input does not exist.
    """
    if not inputs:
        raise ValueError("未提供输入音频路径。请使用位置参数或 --index_file。")

    allowed_exts = get_allowed_input_exts(config_path)

    for p in inputs:
        if not p.exists():
            raise FileNotFoundError(f"输入文件不存在: {p}")
        if not p.is_file():
            raise ValueError(f"输入不是文件: {p}")
        ext = p.suffix.lower()
        if ext not in allowed_exts:
            raise ValueError(
                f"不支持的源音频格式: {p}({ext})。仅支持: "
                + ", ".join(sorted(x.lstrip('.') for x in allowed_exts))
            )


def _plan_output_paths(inputs: Sequence[Path], output: "str | os.PathLike[str]", output_ext: str) -> List[Path]:
    """Map each input to its output path for the given output extension (pure path logic)."""
    # pathlib drops trailing separators, so the "this is a directory" hint has to
    # be read from the raw value before it is converted to a Path.
    raw_output = os.fspath(output)
    separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
    wants_directory = raw_output.endswith(separators)
    output = Path(raw_output)

    if len(inputs) == 1:
        src = inputs[0]
        # An existing directory, or a path written with a trailing separator, is directory output.
        if (output.exists() and output.is_dir()) or wants_directory:
            return [output / f"{src.stem}{output_ext}"]
        # File output: add the extension when missing, reject a different one.
        if output.suffix == "":
            return [output.with_suffix(output_ext)]
        if output.suffix.lower() != output_ext:
            raise ValueError(f"输出文件必须是 {output_ext} 后缀(或不给后缀让工具自动补{output_ext})。")
        return [output]

    # Multiple inputs: the output must be a directory.
    if output.exists() and output.is_file():
        raise ValueError("多输入模式下,--output 必须是目录路径,不能是文件路径。")
    out_paths = [output / f"{src.stem}{output_ext}" for src in inputs]

    # Inputs sharing a stem (a.mp3 / a.flac, or same name in two folders) would
    # overwrite each other's result; refuse instead of silently losing data.
    seen = set()
    clashes: List[Path] = []
    for p in out_paths:
        if p in seen and p not in clashes:
            clashes.append(p)
        seen.add(p)
    if clashes:
        names = ", ".join(str(p) for p in clashes)
        raise ValueError(f"多输入模式下存在同名输入文件,输出会相互覆盖: {names}")
    return out_paths


def _resolve_output_paths(inputs: Sequence[Path], output: Path, config_path: Optional[str] = None) -> List[Path]:
    """Resolve the output file path for every input, using the configured output format.

    Args:
        inputs: input file paths.
        output: output file or directory. A plain string is accepted as well; a
            string ending in a path separator is treated as a directory even if
            it does not exist yet (a ``Path`` cannot carry that information).
        config_path: optional path of the configuration file.

    Returns:
        One output path per input, in input order.

    Raises:
        ValueError: wrong output suffix, a file given as output for several
            inputs, or several inputs that would map to the same output file.
    """
    config = get_audio_config(config_path)
    output_ext = f".{config.get('output_format', 'wav').lower().lstrip('.')}"
    return _plan_output_paths(inputs, output, output_ext)


def _ensure_parent_dirs(paths: Iterable[Path]) -> None:
    """Create the parent directory of every path (no error if it already exists)."""
    for p in paths:
        p.parent.mkdir(parents=True, exist_ok=True)


def _check_ffmpeg_hint() -> Optional[str]:
    """Return a warning text when neither ffmpeg nor avconv is on PATH, else None."""
    # pydub relies on ffmpeg/avlib. Give a clear hint if missing.
    if shutil.which("ffmpeg") is None and shutil.which("avconv") is None:
        return "未检测到 ffmpeg/avconv,pydub 可能无法解码 mp3/aac/m4a/flac。请先安装 ffmpeg。"
    return None


def convert_one(AudioSegment, src: Path, dst: Path, spec: ConvertSpec) -> bool:
    """Convert a single audio file according to *spec*.

    Args:
        AudioSegment: pydub's ``AudioSegment`` class.
        src: source file path.
        dst: destination file path.
        spec: target format.

    Returns:
        True on success; False when conversion failed (the error is printed).
    """
    try:
        audio = AudioSegment.from_file(src)
        audio = audio.set_channels(spec.channels)
        audio = audio.set_frame_rate(spec.frame_rate)
        audio = audio.set_sample_width(spec.sample_width_bytes)
        # Export with the configured container format and codec.
        audio.export(dst, format=spec.output_format, codec=spec.encoding)
        return True
    except Exception as e:
        print(error(f"转换失败 {src.name}: {e}"))
        return False
def build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser; the help text reflects the default configuration."""
    config = get_audio_config()
    output_format = config.get('output_format', 'wav')

    p = argparse.ArgumentParser(
        prog="audio_convert",
        description=(
            f"将音频统一转换为 {output_format.upper()}:"
            f"{config.get('channels', 1)}通道 / "
            f"{config.get('sample_rate', 16000)}Hz / "
            f"{config.get('sample_width', 2)*8}bit {config.get('encoding', 'pcm_s16le')}。\n"
            f"支持源格式: {', '.join(config.get('input_format', []))}"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    p.add_argument(
        "inputs",
        nargs="*",
        help="输入音频路径:可传 1 个或多个文件路径",
    )
    p.add_argument(
        "--index_file",
        "-f",
        default=None,
        help="索引文件路径:文件中每行一个音频路径(支持 # 注释与空行)",
    )
    p.add_argument(
        "--output",
        "-o",
        required=True,
        help=(
            "输出路径:\n"
            f"- 单输入:可为文件或目录(自动添加 .{output_format} 后缀)\n"
            "- 多输入:必须为目录\n"
        ),
    )
    p.add_argument(
        "--overwrite",
        action="store_true",
        help="允许覆盖已存在的输出文件",
    )
    p.add_argument(
        "--config",
        "-c",
        default=None,
        help="自定义配置文件路径,不指定则使用默认配置",
    )
    p.add_argument(
        "--show_config",
        action="store_true",
        help="显示当前配置并退出",
    )
    return p


def print_config_info(config_path: Optional[str] = None) -> None:
    """Print the active audio-conversion configuration.

    Args:
        config_path: optional path of the configuration file.
    """
    config = get_audio_config(config_path)
    print(header("当前音频转换配置"))
    if config_path:
        print(info(f"配置文件: {config_path}"))
    else:
        print(info("配置文件: 使用默认配置"))
    print(info(f"输出格式: {config.get('output_format')}"))
    print(info(f"声道数: {config.get('channels')}"))
    print(info(f"采样率: {config.get('sample_rate')} Hz"))
    print(info(f"采样位宽: {config.get('sample_width')} 字节 ({config.get('sample_width')*8} bit)"))
    print(info(f"编码格式: {config.get('encoding')}"))
    print(info(f"输入格式: {', '.join(config.get('input_format', []))}"))

    # Quality-check settings are optional; show them only when configured.
    if 'quality_checks' in config:
        print(info("质量检查:"))
        qc = config['quality_checks']
        print(f" - 最小时长: {qc.get('min_duration_seconds')}秒")
        print(f" - 最大时长: {qc.get('max_duration_seconds')}秒")
        print(f" - 最大静音比例: {qc.get('max_silence_ratio')}")


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the audio_convert command line.

    Args:
        argv: argument list; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 when every file was converted (or the user cancelled), 1 otherwise.
    """
    args = build_argparser().parse_args(argv)

    # A custom config must not be served from a previously cached default.
    if args.config:
        clear_config_cache()

    if args.show_config:
        print_config_info(args.config)
        return 0

    # Inputs are validated after parsing because --config changes the allowed formats.
    inputs = _expand_inputs(args.inputs, args.index_file)
    _validate_inputs(inputs, args.config)

    ffmpeg_hint = _check_ffmpeg_hint()
    if ffmpeg_hint:
        print(warning(ffmpeg_hint))

    out = Path(args.output)
    # Path() drops a trailing separator, which would turn "--output dir/" into a
    # file named "dir.<ext>". Creating the directory up front keeps the
    # directory meaning for the path resolution below.
    separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
    if args.output.endswith(separators):
        out.mkdir(parents=True, exist_ok=True)
    out_paths = _resolve_output_paths(inputs, out, args.config)
    _ensure_parent_dirs(out_paths)

    if not args.overwrite:
        exists = [p for p in out_paths if p.exists()]
        if exists:
            print(warning(f"检测到 {len(exists)} 个输出文件已存在"))
            try:
                response = input("是否覆盖这些文件?(y/n, 回车确认 y): ").strip().lower()
            except EOFError:
                # No interactive stdin (pipeline, scheduler): never overwrite without consent.
                print(error("无法读取确认输入(非交互环境),请使用 --overwrite 允许覆盖"))
                return 1
            if response not in ['y', 'yes', '']:
                print(info("用户取消操作,程序结束"))
                return 0

    AudioSegment = _import_pydub()
    spec = ConvertSpec(args.config)

    success_count = 0
    total_count = len(inputs)

    for src, dst in zip(inputs, out_paths):
        if convert_one(AudioSegment, src=src, dst=dst, spec=spec):
            # Report file names only, not full paths.
            print(ok(f"转换成功: {src.name}"))
            success_count += 1
        else:
            print(error(f"转换失败: {src.name}"))

    # Summary.
    if success_count == total_count:
        print(success(f"所有 {total_count} 个文件转换完成"))
    else:
        print(warning(f"转换完成: {success_count}/{total_count} 个文件成功"))
        print(error(f"{total_count - success_count} 个文件转换失败"))

    return 0 if success_count == total_count else 1
def load_audio_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Load and validate the audio configuration file.

    The file is located with ``find_config_file``. If the resolved path does
    not exist yet, a default configuration is written there first. The
    settings may live either under a top-level ``audio_config`` key or
    directly at the top level of the YAML document.

    Args:
        config_path: Optional user-specified configuration file path.

    Returns:
        The configuration mapping (the ``audio_config`` section when present,
        otherwise the whole document).

    Raises:
        ValueError: If the file is not valid YAML, if the document (or its
            ``audio_config`` section) is empty or not a mapping, or if a
            required key is missing.
    """
    config_file = find_config_file(config_path)

    # Bootstrap a default configuration so a first run works out of the box.
    if not config_file.exists():
        create_default_config(config_file)
        print(f"[INFO] 配置文件不存在,已创建默认配置: {config_file}")

    # Only the parse itself can raise YAMLError, so keep the try body minimal.
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ValueError(f"配置文件格式错误: {config_file}") from e

    # safe_load returns None for an empty file and a scalar/list for a
    # non-mapping document; report that as a config error rather than letting
    # the membership test below fail with an opaque TypeError.
    if not isinstance(config_data, dict):
        raise ValueError(f"配置文件内容必须是键值映射: {config_file}")

    # Accept both the nested layout and a flat top-level layout.
    config = config_data.get('audio_config', config_data)
    if not isinstance(config, dict):
        raise ValueError(f"配置文件中的 audio_config 必须是键值映射: {config_file}")

    required_keys = ['output_format', 'channels', 'sample_rate',
                     'sample_width', 'encoding', 'input_format']
    for key in required_keys:
        if key not in config:
            raise ValueError(f"配置文件中缺少必要的键: {key}")

    return config
def get_audio_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Return the audio configuration, loading it lazily and caching it.

    The configuration is (re)loaded when nothing has been loaded yet, or when
    ``config_path`` is given and differs from the path of the cached
    configuration. Repeated calls with the same path reuse the cached mapping
    instead of re-parsing the file; call ``clear_config_cache()`` to force a
    reload of a file that changed on disk.

    Args:
        config_path: Optional user-specified configuration file path. When
            ``None`` the cached configuration is returned if one exists.

    Returns:
        The cached configuration mapping.
    """
    global _AUDIO_CONFIG, _CONFIG_PATH

    # A path only invalidates the cache when it is not the one already loaded.
    path_changed = config_path is not None and config_path != _CONFIG_PATH

    if _AUDIO_CONFIG is None or path_changed:
        _AUDIO_CONFIG = load_audio_config(config_path)
        # Remember the path only after a successful load, so a failed load
        # never pairs the old configuration with the new path.
        if config_path:
            _CONFIG_PATH = config_path

    return _AUDIO_CONFIG
使用索引文件批量转换 +这是处理大量文件最高效的方式。首先创建一个文本文件(如 file_list.txt),每行一个音频文件路径: + +text +# file_list.txt 示例 +/data/sounds/recording1.mp3 +/data/sounds/sample2.m4a +# 这是一行注释 +/data/sounds/lecture3.flac +然后运行命令: + +bash +python audio_convert.py --index_file file_list.txt --output ./converted/ +4. 允许覆盖已存在的输出文件 +如果输出目录已有同名文件,需要添加 --overwrite 参数: + +bash +python audio_convert.py input.aac --output existing.wav --overwrite \ No newline at end of file diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/0_normalization.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/0_normalization.py new file mode 100644 index 00000000..04edd2a1 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/0_normalization.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +""" +1_normalization.py + +执行顺序:第 1 步 +- 调用 src.pipeline.normalization 完成音频标准化。 +""" + +from pathlib import Path +import sys + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from src.pipeline import normalization # type: ignore + + +if __name__ == "__main__": + raise SystemExit(normalization.main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/1_denoise.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/1_denoise.py new file mode 100644 index 00000000..1a047ec6 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/1_denoise.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +2_denoise.py + +执行顺序:第 2 步 +- 调用 src.utils.gtcrn_denoise,对 output_data/normalization 下的音频做本地智能降噪, + 输出到 output_data/denoise。 +""" + +from pathlib import Path +import sys + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PROJECT_ROOT / "src" / "utils")) + +from src.utils import gtcrn_denoise # type: ignore + +try: + from color_utils import info, warning, 
def main() -> int:
    """Run GTCRN denoising on the normalization output.

    ``gtcrn_denoise.main()`` is called with no arguments, so ``sys.argv`` is
    temporarily replaced with the desired input/model/output options and
    restored afterwards, even if the call raises.

    Returns:
        The value returned by ``gtcrn_denoise.main()``.
    """
    print_header("GTCRN 智能降噪")
    print_info("调用 src.utils.gtcrn_denoise 执行本地降噪 ...")

    data_root = PROJECT_ROOT / "output_data"
    forwarded_options = [
        "--input", str(data_root / "normalization"),
        "--model", str(PROJECT_ROOT / "models" / "gtcrn" / "gtcrn.onnx"),
        "--output", str(data_root / "denoise"),
    ]

    saved_argv = sys.argv[:]
    try:
        # Keep the program name, swap in our own options.
        sys.argv = [sys.argv[0], *forwarded_options]
        exit_code = gtcrn_denoise.main()
    finally:
        sys.argv = saved_argv

    if exit_code != 0:
        print_error(f"GTCRN 降噪执行失败,返回码: {exit_code}")
        return exit_code

    print_success("GTCRN 降噪执行完成。")
    return exit_code
Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from src.pipeline import anomaly_filter # type: ignore + + +if __name__ == "__main__": + raise SystemExit(anomaly_filter.main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/3_fast_lang_id.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/3_fast_lang_id.py new file mode 100644 index 00000000..7a8dbd13 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/3_fast_lang_id.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +""" +2_fast_lang_id.py + +执行顺序:第 2 步 +- 调用 src.utils.fast_lang_id,使用 SpeechBrain 快速识别中/英文, + 默认读取 output_data/normalization,生成 output_data/lid/item_with_lang.list。 +""" + +from pathlib import Path +import sys + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PROJECT_ROOT / "src" / "utils")) + +from src.utils import fast_lang_id # type: ignore + +try: + from color_utils import info, warning, error, ok, success, header # type: ignore + + def print_info(msg: str): + print(info(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + +except Exception: + def print_info(msg: str): + print(f"[INFO] {msg}") + + def print_error(msg: str): + print(f"[ERROR] {msg}") + + def print_success(msg: str): + print(f"[SUCCESS] {msg}") + + def print_header(msg: str): + print(f"=== {msg} ===") + + +if __name__ == "__main__": + + code = fast_lang_id.main() + if code == 0: + pass + else: + print_error(f"fast_lang_id 执行失败,返回码: {code}") + raise SystemExit(code) + + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/4_split_and_tag.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/4_split_and_tag.py new file mode 100644 index 00000000..982f9de9 --- 
/dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/4_split_and_tag.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +""" +3_split_and_tag.py + +执行顺序:第 3 步 +- 调用 src.pipeline.split_and_tag,将 normalization 结果切分为 ≤2min 片段并生成 split 清单。 +""" + +from pathlib import Path +import sys + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from src.pipeline import split_and_tag # type: ignore + + +if __name__ == "__main__": + raise SystemExit(split_and_tag.main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/5_recognize_monitor.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/5_recognize_monitor.py new file mode 100644 index 00000000..9f6725be --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/5_recognize_monitor.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +""" +4_recognize_monitor.py + +执行顺序:第 4 步 +- 调用 src.pipeline.recognize_monitor: + - 先识别中文片段,再识别英文片段 + - 合并为 output_data/asr/merged_text.txt +""" + +from pathlib import Path +import sys + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from src.pipeline import recognize_monitor # type: ignore + + +if __name__ == "__main__": + raise SystemExit(recognize_monitor.main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/6_eval_wer.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/6_eval_wer.py new file mode 100644 index 00000000..ac0b336b --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/6_eval_wer.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +""" +6_eval_wer.py + +执行顺序:第 6 步(可选) +- 调用 src.pipeline.eval_wer: + - 计算中文 CER、英文 WER + - 生成 output_data/validation/transcript_log.txt +""" + +from pathlib import Path +import os +import sys + +PROJECT_ROOT = 
Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from src.pipeline import eval_wer # type: ignore + + +if __name__ == "__main__": + # 统一工作目录到项目根目录,避免 YAML/CLI 里使用相对路径时找不到文件 + os.chdir(PROJECT_ROOT) + raise SystemExit(eval_wer.main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/7_eval_keyword_recall.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/7_eval_keyword_recall.py new file mode 100644 index 00000000..b044b22c --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/7_eval_keyword_recall.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +7_eval_keyword_recall.py + +执行顺序:可在评估阶段(例如第 7 步) +- 调用 src.pipeline.eval_keyword_recall: + - 读取中英文关键词列表 + - 使用 output_data/asr/merged_text.txt 的识别结果 + - 计算关键词召回率并生成报告 output_data/validation/keyword_recall.txt +""" + +from pathlib import Path +import os +import sys + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from src.pipeline import eval_keyword_recall # type: ignore + + +if __name__ == "__main__": + # 统一工作目录到项目根目录,避免 YAML/CLI 里使用相对路径时找不到文件 + os.chdir(PROJECT_ROOT) + raise SystemExit(eval_keyword_recall.main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/anomaly_filter.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/anomaly_filter.py new file mode 100644 index 00000000..3efd7c5a --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/anomaly_filter.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +音频异常检测与过滤 + +设计目标: +- 作为 normalization 之后、LID 之前的质量过滤步骤 +- 默认扫描 output_data/normalization 目录中的音频 +- 对每个音频计算: + - 时长(秒) + - 静音帧比例(基于短时能量) +- 根据阈值打标 quality_flag,并输出 jsonl 列表 + +quality_flag 约定: +- "ok" : 通过所有检查 +- "invalid" : 明显异常(时长不在范围或几乎全是静音) +""" + +from __future__ import annotations + +import argparse +import json +import math 
def _find_audio_files(audio_dir: Path) -> List[Path]:
    """Recursively collect the audio files below ``audio_dir``.

    A file counts as audio when its name ends with one of the supported
    extensions (wav, flac, mp3, aac, m4a), compared case-insensitively, so
    mixed-case names such as ``clip.Wav`` are found too. Directories whose
    names merely look like audio files are ignored.

    Args:
        audio_dir: Root directory to scan.

    Returns:
        A sorted list of unique matching file paths.
    """
    audio_suffixes = (".wav", ".flac", ".mp3", ".aac", ".m4a")
    # One traversal with a lowercase suffix check replaces one recursive glob
    # per spelling of each extension.
    matches = {
        candidate
        for candidate in audio_dir.rglob("*")
        if candidate.name.lower().endswith(audio_suffixes) and candidate.is_file()
    }
    return sorted(matches)
def _frame_rms(x: List[float], sr: int, frame_ms: float, hop_ms: float) -> Tuple[List[float], float]:
    """Compute short-time RMS values and the RMS of the whole signal.

    Frames start every ``hop_ms`` milliseconds and span up to ``frame_ms``
    milliseconds; frames near the end of the signal are simply shorter.
    Frame and hop lengths are at least one sample.

    Args:
        x: Mono samples.
        sr: Sample rate in Hz.
        frame_ms: Frame length in milliseconds.
        hop_ms: Hop between frame starts in milliseconds.

    Returns:
        ``(per_frame_rms, global_rms)``; ``([], 0.0)`` when ``x`` is empty or
        ``sr`` is not positive.
    """
    if not x or sr <= 0:
        return [], 0.0

    def _rms(samples: List[float]) -> float:
        # Plain left-to-right accumulation, so results do not depend on how
        # the interpreter implements sum() for floats.
        energy = 0.0
        for sample in samples:
            energy += float(sample) * float(sample)
        return math.sqrt(energy / len(samples))

    window = max(1, int(sr * frame_ms / 1000.0))
    step = max(1, int(sr * hop_ms / 1000.0))

    overall = _rms(x)
    # Slicing clamps at the end of the list, which yields the shorter tail frames.
    per_frame = [_rms(x[begin:begin + window]) for begin in range(0, len(x), step)]
    return per_frame, overall
def _dump_jsonl(path: Path, items: Iterable[Dict]) -> None:
    """Write ``items`` to ``path`` as JSON Lines.

    Missing parent directories are created. Each item becomes one line of
    UTF-8 JSON with non-ASCII characters written verbatim. An existing file
    is overwritten.

    Args:
        path: Destination file.
        items: Records to serialize, one per output line.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in items
        )
Path(args.audio_dir).resolve() + output_path = Path(args.output).resolve() + + print(header("音频异常检测与过滤")) + if not audio_dir.exists(): + _print_error(f"音频目录不存在: {audio_dir}") + return 1 + + files = _find_audio_files(audio_dir) + if not files: + _print_warning(f"目录中未找到任何音频文件: {audio_dir}") + return 0 + + _print_info(f"待分析音频数: {len(files)}") + _print_info( + f"参数: min_dur={args.min_dur}s, max_dur={args.max_dur}s, " + f"silence_ratio_th={args.silence_ratio_th}, silence_rms_ratio_th={args.silence_rms_ratio_th}" + ) + + items: List[Dict] = [] + invalid_count = 0 + for idx, p in enumerate(files, start=1): + try: + it = _analyze_one( + path=p, + min_dur=float(args.min_dur), + max_dur=float(args.max_dur), + silence_ratio_th=float(args.silence_ratio_th), + silence_rms_ratio_th=float(args.silence_rms_ratio_th), + ) + except Exception as e: + _print_warning(f"处理失败,标记为 invalid: {p}, error={e}") + it = { + "key": p.stem, + "wav": str(p.resolve()), + "duration": 0.0, + "silence_ratio": 1.0, + "global_rms": 0.0, + "quality_flag": "invalid", + "reason": "load_error", + } + if it.get("quality_flag") == "invalid": + invalid_count += 1 + items.append(it) + + if idx % 20 == 0 or idx == len(files): + _print_info(f"进度: {idx}/{len(files)}") + + _dump_jsonl(output_path, items) + _print_success(f"分析完成,输出: {output_path}") + _print_info(f"统计: 总数={len(items)}, invalid={invalid_count}, ok={len(items) - invalid_count}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_keyword_recall.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_keyword_recall.py new file mode 100644 index 00000000..0af78202 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_keyword_recall.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +""" +关键词召回率评估脚本 + +功能: +- 从 input_data/valiadation 下读取中英文关键词列表: + - zh_keyword.txt(中文关键词,Kaldi 文本:utt_idkw1 
kw2 ...) + - en_keyword.txt(英文关键词,Kaldi 文本:utt_idkw1 kw2 ...) +- 从 output_data/asr/merged_text.txt 读取识别结果(每行: utt_id text...) +- 对 key 交集部分分别计算: + - 中文关键词召回率 + - 英文关键词召回率 + +关键词召回率定义: +- 对于每个句子: + - ref_keywords = 该句的关键词集合(去重) + - hyp_tokens = ASR 识别结果按空格切分后的 token 集合(大小写不敏感) + - hit = ref_keywords ∩ hyp_tokens 的元素个数 + - recall_utt = hit / len(ref_keywords) (若该句没有关键词,则跳过) +- 整体召回率 = 所有可评估句子的 recall_utt 的平均值(macro 平均) + +输出: +- 在 output_data/validation/keyword_recall.txt 中写入报告 +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Dict, List, Set, Tuple +import sys + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent + +# 颜色打印工具(与其他脚本风格保持一致) +sys.path.insert(0, str(PROJECT_ROOT / "src" / "utils")) +try: + from color_utils import info, warning, error, ok, success, header # type: ignore + + def print_info(msg: str): + print(info(msg)) + + def print_warning(msg: str): + print(warning(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + +except Exception: + + def print_info(msg: str): + print(f"[INFO] {msg}") + + def print_warning(msg: str): + print(f"[WARNING] {msg}") + + def print_error(msg: str): + print(f"[ERROR] {msg}") + + def print_ok(msg: str): + print(f"[OK] {msg}") + + def print_success(msg: str): + print(f"[SUCCESS] {msg}") + + def print_header(msg: str): + print(f"=== {msg} ===") + + +# YAML 配置加载(可选) +try: + from yaml_config_loader import parse_args_with_yaml_config # type: ignore +except Exception: + parse_args_with_yaml_config = None # type: ignore[assignment] + + +def read_kw_kaldi(path: Path) -> Dict[str, List[str]]: + """ + 读取关键词文件(Kaldi 风格,每行: keykw1 kw2 ...) 
def compute_keyword_recall_per_lang(
    kw_map: Dict[str, List[str]],
    hyp_map: Dict[str, str],
    lang_name: str,
    *,
    use_substring_match: bool = False,
) -> Tuple[float, int, int, List[Tuple[str, float, int, int, List[str], List[str]]]]:
    """Compute the macro-averaged keyword recall for one language.

    Only utterances present in both ``kw_map`` and ``hyp_map`` are considered.
    Keywords are de-duplicated case-insensitively per utterance; utterances
    without keywords count towards the total but not towards the average.

    Args:
        kw_map: Utterance id -> reference keywords.
        hyp_map: Utterance id -> recognized text.
        lang_name: Language label used in warning messages.
        use_substring_match: If True, a keyword is a hit when it occurs
            anywhere in the lower-cased hypothesis text (for unsegmented text
            such as Chinese). If False, it must equal one of the
            whitespace-separated, lower-cased hypothesis tokens.

    Returns:
        ``(overall_recall, num_utt_used, num_utt_total, per_utt_detail)``,
        where each detail entry is
        ``(utt_id, recall_utt, hit, ref_size, hit_list, miss_list)`` with both
        word lists sorted. ``overall_recall`` is 0.0 when nothing could be
        evaluated.
    """
    keys = kw_map.keys() & hyp_map.keys()
    if not keys:
        print_warning(f"{lang_name} 无 key 交集,跳过该语种评估")
        return 0.0, 0, 0, []

    recalls: List[float] = []
    per_utt: List[Tuple[str, float, int, int, List[str], List[str]]] = []
    num_total = len(keys)

    for utt_id in sorted(keys):
        ref_set: Set[str] = {w.lower() for w in kw_map[utt_id] if w}
        if not ref_set:
            # No keywords for this utterance: excluded from the average.
            continue

        hyp_text = hyp_map[utt_id]
        if use_substring_match:
            # Keywords are looked up inside the continuous text.
            haystack = hyp_text.lower()
        else:
            # Keywords must match a whole whitespace-delimited token.
            haystack = {token.lower() for token in hyp_text.split()}

        hit_words = sorted(w for w in ref_set if w in haystack)
        miss_words = sorted(ref_set.difference(hit_words))

        hit = len(hit_words)
        recall_utt = hit / float(len(ref_set))
        recalls.append(recall_utt)
        per_utt.append((utt_id, recall_utt, hit, len(ref_set), hit_words, miss_words))

    if not recalls:
        print_warning(f"{lang_name} 中没有可评估的含关键词样本")
        return 0.0, 0, num_total, per_utt

    overall = sum(recalls) / len(recalls)
    return overall, len(recalls), num_total, per_utt
not hyp_path.exists(): + print_error(f"识别结果不存在: {hyp_path}") + return 1 + + zh_kw = read_kw_kaldi(zh_kw_path) + en_kw = read_kw_kaldi(en_kw_path) + hyp = read_kv_text(hyp_path) + + if not zh_kw and not en_kw: + print_error(f"未找到关键词文件: {zh_kw_path} / {en_kw_path}") + return 1 + + zh_recall, zh_utt_used, zh_utt_total, zh_detail = compute_keyword_recall_per_lang( + zh_kw, hyp, "中文", use_substring_match=True + ) + en_recall, en_utt_used, en_utt_total, en_detail = compute_keyword_recall_per_lang( + en_kw, hyp, "英文", use_substring_match=False + ) + + if zh_utt_used > 0: + print_ok( + f"中文关键词召回率: {zh_recall * 100:.2f}% " + f"(含关键词样本 {zh_utt_used} 条 / 全部交集样本 {zh_utt_total} 条)" + ) + else: + print_warning("中文无可评估关键词样本") + + if en_utt_used > 0: + print_ok( + f"英文关键词召回率: {en_recall * 100:.2f}% " + f"(含关键词样本 {en_utt_used} 条 / 全部交集样本 {en_utt_total} 条)" + ) + else: + print_warning("英文无可评估关键词样本") + + # 输出报告(包含明细) + work_dir.mkdir(parents=True, exist_ok=True) + report_path = work_dir / "keyword_recall.txt" + with report_path.open("w", encoding="utf-8") as f: + f.write("ASR 关键词召回率评估报告\n") + f.write(f"中文关键词: {zh_kw_path}\n") + f.write(f"英文关键词: {en_kw_path}\n") + f.write(f"识别结果: {hyp_path}\n\n") + + f.write( + f"中文:交集样本总数 = {zh_utt_total}," + f"含关键词样本数 = {zh_utt_used}," + f"关键词召回率 = {zh_recall * 100:.2f}%\n" + ) + f.write( + f"英文:交集样本总数 = {en_utt_total}," + f"含关键词样本数 = {en_utt_used}," + f"关键词召回率 = {en_recall * 100:.2f}%\n" + ) + f.write("\n") + + def dump_lang_detail( + lang_title: str, + details: List[Tuple[str, float, int, int, List[str], List[str]]], + ) -> None: + f.write(f"==== {lang_title} 逐句明细 ====\n") + if not details: + f.write("(无可评估样本)\n\n") + return + for ( + utt_id, + recall_utt, + hit, + ref_size, + hit_words, + miss_words, + ) in details: + f.write(f"utt_id: {utt_id}\n") + f.write( + f" recall: {recall_utt * 100:.2f}% " + f"(hit={hit}, ref_kw={ref_size})\n" + ) + f.write(f" hit_kw: {' '.join(hit_words) if hit_words else 'None'}\n") + f.write( + f" miss_kw: {' 
'.join(miss_words) if miss_words else 'None'}\n\n" + ) + + dump_lang_detail("中文", zh_detail) + dump_lang_detail("英文", en_detail) + + print_success(f"评估完成,报告已写入: {report_path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_wer.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_wer.py new file mode 100644 index 00000000..a7f995fa --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_wer.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +""" +WER 评估脚本 + +功能: +- 从 input_data/validation 下读取参考转写: + - zh_transcript.txt(中文,按“字错率”评估) + - en_transcript.txt(英文,按“词错率”评估) +- 从 output_data/asr/merged_text.txt 读取识别结果(每行: key text...) +- 对 key 交集部分分别计算: + - 中文:char 模式下的错字率 + - 英文:word 模式下的 WER + +注意: +- 自动跳过只在其中一边存在的 key(既不在 ref 也不在 hyp 的样本) +- 依赖 src/utils/compute_wer.py 中的 compute-wer 实现 +""" + +import argparse +import subprocess +import sys +from pathlib import Path +from typing import Dict, Set, Tuple + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent + +# 颜色打印工具(与其他脚本风格保持一致) +sys.path.insert(0, str(PROJECT_ROOT / "src" / "utils")) +try: + from color_utils import info, warning, error, ok, success, header # type: ignore + + def print_info(msg: str): + print(info(msg)) + + def print_warning(msg: str): + print(warning(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + +except Exception: + def print_info(msg: str): + print(f"[INFO] {msg}") + + def print_warning(msg: str): + print(f"[WARNING] {msg}") + + def print_error(msg: str): + print(f"[ERROR] {msg}") + + def print_ok(msg: str): + print(f"[OK] {msg}") + + def print_success(msg: str): + print(f"[SUCCESS] {msg}") + + def print_header(msg: str): + print(f"=== {msg} ===") + + +# YAML 
def read_kv(path: Path) -> Dict[str, str]:
    """Read a Kaldi-style text file (one ``key text...`` entry per line).

    Blank lines are skipped, a key without text maps to an empty string, and
    a later line with the same key overrides an earlier one. A leading UTF-8
    byte-order mark is ignored so that it cannot become part of the first key.

    Args:
        path: File to read.

    Returns:
        Mapping of key to text; empty when the file does not exist.
    """
    try:
        # "utf-8-sig" reads plain UTF-8 as well and drops a leading BOM, which
        # str.strip() would otherwise leave attached to the first key.
        handle = path.open("r", encoding="utf-8-sig")
    except FileNotFoundError:
        return {}

    data: Dict[str, str] = {}
    with handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            # A non-empty stripped line always yields at least the key.
            fields = line.split(maxsplit=1)
            data[fields[0]] = fields[1] if len(fields) > 1 else ""
    return data
= argparse.ArgumentParser( + description="评估中英文 ASR 错误率(中文字错率,英文词错率)", + ) + parser.add_argument( + "--config", + "-c", + default=None, + help="YAML 配置文件路径(可选)。支持写 eval_wer: {zh_ref:..., hyp:..., work_dir:...} 或直接顶层同名键", + ) + parser.add_argument( + "--zh_ref", + default=str(PROJECT_ROOT / "input_data" / "validation" / "zh_transcript.txt"), + help="中文参考转写(Kaldi 文本格式)", + ) + parser.add_argument( + "--en_ref", + default=str(PROJECT_ROOT / "input_data" / "validation" / "en_transcript.txt"), + help="英文参考转写(Kaldi 文本格式)", + ) + parser.add_argument( + "--hyp", + default=str(PROJECT_ROOT / "output_data" / "asr" / "merged_text.txt"), + help="识别结果文本(merged_text.txt)", + ) + parser.add_argument( + "--work_dir", + default=str(PROJECT_ROOT / "output_data" / "validation"), + help="中间文件输出目录,默认: output_data/validation", + ) + if parse_args_with_yaml_config: + args = parse_args_with_yaml_config( + parser, + section="eval_wer", + default_config_paths=[PROJECT_ROOT / "config" / "eval_wer.yaml"], + ) + else: + args = parser.parse_args() + + zh_ref_path = Path(args.zh_ref) + en_ref_path = Path(args.en_ref) + hyp_path = Path(args.hyp) + work_dir = Path(args.work_dir) + + print_header("ASR 错误率评估") + + # 兼容历史目录名拼写:valiadation(用户侧数据目录存在该拼写) + # - CLI/YAML 可能给绝对路径或相对路径,这里都做回退 + default_zh_abs = PROJECT_ROOT / "input_data" / "validation" / "zh_transcript.txt" + default_en_abs = PROJECT_ROOT / "input_data" / "validation" / "en_transcript.txt" + fallback_zh_abs = PROJECT_ROOT / "input_data" / "valiadation" / "zh_transcript.txt" + fallback_en_abs = PROJECT_ROOT / "input_data" / "valiadation" / "en_transcript.txt" + + def maybe_fallback_validation_typo(p: Path, fallback_abs: Path) -> Path: + if p.exists(): + return p + # 1) 传入的是默认绝对路径 + if str(p) == str(default_zh_abs) or str(p) == str(default_en_abs): + return fallback_abs if fallback_abs.exists() else p + # 2) 传入的是相对路径:input_data/validation/*.txt + if p.as_posix().endswith("input_data/validation/" + p.name): + return fallback_abs if 
fallback_abs.exists() else p + # 3) 传入的是单纯相对:validation/*.txt(防呆) + if "validation" in p.parts and p.name in ("zh_transcript.txt", "en_transcript.txt"): + return fallback_abs if fallback_abs.exists() else p + return p + + zh_ref_path = maybe_fallback_validation_typo(zh_ref_path, fallback_zh_abs) + en_ref_path = maybe_fallback_validation_typo(en_ref_path, fallback_en_abs) + + if not hyp_path.exists(): + print_error(f"识别结果不存在: {hyp_path}") + return 1 + + zh_ref = read_kv(zh_ref_path) + en_ref = read_kv(en_ref_path) + hyp = read_kv(hyp_path) + + if not zh_ref and not en_ref: + print_error(f"未找到参考转写: {zh_ref_path} / {en_ref_path}") + return 1 + + # 计算交集,自动跳过单边缺失的样本 + zh_keys = set(zh_ref.keys()) & set(hyp.keys()) + en_keys = set(en_ref.keys()) & set(hyp.keys()) + + print_info(f"中文样本交集: {len(zh_keys)} 条") + print_info(f"英文样本交集: {len(en_keys)} 条") + + zh_ref_sub = work_dir / "zh_ref.txt" + zh_hyp_sub = work_dir / "zh_hyp.txt" + en_ref_sub = work_dir / "en_ref.txt" + en_hyp_sub = work_dir / "en_hyp.txt" + + zh_wer = None + en_wer = None + zh_detail = "" + en_detail = "" + + if zh_keys: + dump_subset(zh_ref_sub, zh_ref, zh_keys) + dump_subset(zh_hyp_sub, hyp, zh_keys) + zh_wer, zh_detail = run_compute_wer(zh_ref_sub, zh_hyp_sub, char_mode=True) + print_ok(f"中文字错率 (CER): {zh_wer:.2f}%") + else: + print_warning("无中文样本交集,跳过中文评估") + + if en_keys: + dump_subset(en_ref_sub, en_ref, en_keys) + dump_subset(en_hyp_sub, hyp, en_keys) + en_wer, en_detail = run_compute_wer(en_ref_sub, en_hyp_sub, char_mode=False) + print_ok(f"英文词错率 (WER): {en_wer:.2f}%") + else: + print_warning("无英文样本交集,跳过英文评估") + + # 输出最终识别报告 + report_dir = work_dir + report_dir.mkdir(parents=True, exist_ok=True) + report_path = report_dir / "transcript_log.txt" + with report_path.open("w", encoding="utf-8") as f: + f.write("ASR 验证集评估报告\n") + f.write(f"中文参考: {zh_ref_path}\n") + f.write(f"英文参考: {en_ref_path}\n") + f.write(f"识别结果: {hyp_path}\n\n") + f.write(f"中文样本交集: {len(zh_keys)} 条\n") + f.write(f"英文样本交集: 
def load_key_to_text(text_path: Path) -> Dict[str, str]:
    """Parse a WeNet ``result_dir/mode/text`` file (each line: key, space, text).

    Returns an empty mapping when the file does not exist. Blank lines are
    ignored and a key without text maps to ``""``.
    """
    if not text_path.exists():
        return {}
    mapping: Dict[str, str] = {}
    with text_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                key, *rest = stripped.split(maxsplit=1)
                mapping[key] = rest[0] if rest else ""
    return mapping


def merge_once(list_file: Path, zh_text: Path, en_text: Path, output: Path) -> int:
    """Merge per-segment ASR text back into one line per source audio.

    Reads the jsonl *list_file*, looks each item's ``key`` up in the zh/en
    result files (en wins on a duplicate key), groups by ``source_key``
    (falling back to ``key``), orders by ``segment_index`` and joins the
    non-empty texts with single spaces.

    Returns:
        0 on success, 1 when *list_file* does not exist.
    """
    if not list_file.exists():
        print(f"[ERROR] 列表不存在: {list_file}", file=sys.stderr)
        return 1

    with list_file.open("r", encoding="utf-8") as handle:
        items: List[Dict] = [
            json.loads(text) for text in (raw.strip() for raw in handle) if text
        ]

    key_to_text: Dict[str, str] = dict(load_key_to_text(zh_text))
    key_to_text.update(load_key_to_text(en_text))

    # source_key -> [(segment_index, text), ...]
    grouped: Dict[str, List[tuple]] = defaultdict(list)
    for entry in items:
        seg_key = entry.get("key", "")
        owner = entry.get("source_key", seg_key)
        grouped[owner].append((entry.get("segment_index", 0), key_to_text.get(seg_key, "")))

    output.parent.mkdir(parents=True, exist_ok=True)
    with output.open("w", encoding="utf-8") as handle:
        for source in sorted(grouped):
            ordered = sorted(grouped[source], key=lambda pair: pair[0])
            pieces = [text.strip() for _, text in ordered]
            joined = " ".join(piece for piece in pieces if piece)
            handle.write(f"{source} {joined}\n")

    print(f"[OK] 已合并 {len(grouped)} 条原音频 -> {output}")
    return 0


def main_for_api(list_file: Path, zh_text: Path, en_text: Path, output: Path) -> int:
    """Thin wrapper so other modules can call the merge directly."""
    return merge_once(list_file, zh_text, en_text, output)
"text"), + help="英文 ASR 结果 text 文件", + ) + parser.add_argument( + "--output", + default=str(_PROJECT_ROOT / "output_data" / "asr" / "merged_text.txt"), + help="合并后输出:每行 source_key 空格 整段文本", + ) + if parse_args_with_yaml_config: + args = parse_args_with_yaml_config( + parser, + section="merge_asr_by_source", + default_config_paths=[_PROJECT_ROOT / "config" / "merge_asr_by_source.yaml"], + ) + else: + args = parser.parse_args() + + return merge_once( + list_file=Path(args.list_file), + zh_text=Path(args.zh_text), + en_text=Path(args.en_text), + output=Path(args.output), + ) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/normalization.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/normalization.py new file mode 100644 index 00000000..49621cf0 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/normalization.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +""" +音频归一化处理脚本 +自动扫描输入文件夹,调用 audio_convert 进行批量转换 +提供默认的输入/输出文件夹,支持自定义配置 +""" +import argparse +import sys +import os +from pathlib import Path +from typing import List, Optional, Tuple +import subprocess + +# 添加脚本所在目录到系统路径 +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "scripts" / "audio_convert")) +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src" / "utils")) + +# 导入 config_loader 模块和颜色工具 +try: + from config_loader import get_audio_config, clear_config_cache, create_default_config, find_config_file + from color_utils import info, warning, error, ok, header, success, fail, question +except ImportError as e: + print(f"[ERROR] 无法导入模块: {e}", file=sys.stderr) + print(f"[INFO] 当前搜索路径: {sys.path}") + sys.exit(1) + +print(header("音频标准化处理")) + +def get_default_directories() -> tuple[Path, Path]: + """ + 获取默认的输入和输出目录 + + Returns: + tuple[Path, Path]: (input_dir, output_dir) + """ + # 当前工作目录下的默认目录 + current_dir = Path.cwd() + input_dir = 
def scan_input_directory(input_dir: Path, config_path: Optional[str] = None) -> Tuple[List[str], List[str], int]:
    """Recursively sort every file under *input_dir* into audio / other.

    Args:
        input_dir: Directory to scan.
        config_path: Config file path passed to ``get_audio_config``; its
            ``input_format`` entry decides which suffixes count as audio.

    Returns:
        ``(audio file paths, other file paths, number of other files)``.
    """
    config = get_audio_config(config_path)
    formats = config.get('input_format', ['mp3', 'wav', 'aac', 'm4a', 'flac'])

    # Normalise each configured format to a lower-case ".ext" suffix.
    audio_suffixes = set()
    for fmt in formats:
        audio_suffixes.add(f".{fmt.lower().lstrip('.')}")

    audio_files: List[str] = []
    other_files: List[str] = []
    for entry in input_dir.rglob("*"):
        if not entry.is_file():
            continue
        bucket = audio_files if entry.suffix.lower() in audio_suffixes else other_files
        bucket.append(str(entry))

    return audio_files, other_files, len(other_files)


def find_audio_files(input_dir: Path, config_path: Optional[str] = None) -> List[str]:
    """Return the sorted, de-duplicated audio file paths under *input_dir*.

    Args:
        input_dir: Directory to scan.
        config_path: Config file path used to look up supported formats.
    """
    found = scan_input_directory(input_dir, config_path)[0]
    return sorted(set(found))


def check_existing_output_files(audio_files: List[str], output_dir: Path,
                                config_path: Optional[str] = None) -> List[str]:
    """List the output files that already exist for the given inputs.

    Each input maps to ``output_dir / <input stem><output suffix>``, where
    the suffix comes from the config's ``output_format`` (default ``wav``).

    Args:
        audio_files: Input audio file paths.
        output_dir: Directory the converted files are written to.
        config_path: Config file path.

    Returns:
        Paths (as strings) of the targets that already exist, in input order.
    """
    config = get_audio_config(config_path)
    suffix = f".{config.get('output_format', 'wav').lower().lstrip('.')}"

    targets = (output_dir / f"{Path(source).stem}{suffix}" for source in audio_files)
    return [str(target) for target in targets if target.exists()]
def run_audio_convert(input_files: List[str], output_dir: Path,
                      config_path: Optional[str] = None, overwrite: bool = False) -> int:
    """Invoke ``scripts/audio_convert/audio_convert.py`` on *input_files*.

    Args:
        input_files: Audio files to convert.
        output_dir: Directory passed as ``--output``.
        config_path: Optional config file passed as ``--config``.
        overwrite: When True, ``--overwrite`` is passed.

    Returns:
        0 when there is nothing to convert or the conversion succeeded,
        1 when audio_convert.py cannot be found, otherwise the exit code of
        the failed subprocess.
    """
    if not input_files:
        print(warning("未找到任何音频文件,跳过转换"))
        return 0

    audio_convert_path = Path(__file__).parent.parent.parent / "scripts" / "audio_convert" / "audio_convert.py"
    if not audio_convert_path.exists():
        print(error(f"audio_convert.py 未找到: {audio_convert_path}"))
        return 1

    cmd = [sys.executable, str(audio_convert_path)]
    cmd.extend(input_files)
    cmd.extend(["--output", str(output_dir)])
    if config_path:
        cmd.extend(["--config", config_path])
    if overwrite:
        cmd.append("--overwrite")

    config_file = find_config_file(config_path)
    print(info(f"使用配置文件: {config_file}"))
    print(info(f"准备处理 {len(input_files)} 个音频文件"))

    # Show bare file names only, to keep the listing short.
    print(info("音频文件列表:"))
    for audio_file in input_files:
        print(f" - {Path(audio_file).name}")

    # Only the subprocess call can raise CalledProcessError, so the try body
    # is limited to it.
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
    except subprocess.CalledProcessError as e:
        print(error(f"转换失败: {e}"))
        print(error(f"错误输出: {e.stderr}"))
        return e.returncode

    if result.stdout:
        for line in result.stdout.strip().split('\n'):
            if "[OK]" not in line:
                continue
            # Success lines have the shape "[OK] <src> -> <dst>".
            parts = line.split(" -> ")
            if len(parts) == 2:
                file_name = Path(parts[0].replace("[OK] ", "").strip()).name
                print(ok(f"转换成功: {file_name}"))

    if result.stderr:
        print(error(f"错误输出: {result.stderr}"))

    return result.returncode
Path(args.output_dir) + + if not input_dir.exists(): + print(error(f"输入目录不存在: {input_dir}")) + print(info(f"请创建目录: mkdir -p {input_dir}")) + return 1 + + if not output_dir.exists(): + print(info(f"输出目录不存在,自动创建: {output_dir}")) + output_dir.mkdir(parents=True, exist_ok=True) + + # 查找音频文件和其他文件 + print(info(f"扫描输入目录: {input_dir}")) + audio_files, other_files, other_count = scan_input_directory(input_dir, args.config) + + if not audio_files: + print(warning(f"在 {input_dir} 中未找到任何支持的音频文件")) + print(info(f"支持的格式: mp3, wav, aac, m4a, flac (可在配置文件中修改)")) + return 0 + + print(info(f"找到 {len(audio_files)} 个音频文件")) + if other_count > 0: + print(info(f"找到 {other_count} 个其他文件(非音频格式)")) + + # 检查是否需要覆盖 + existing_files = check_existing_output_files(audio_files, output_dir, args.config) + need_overwrite = False + + if existing_files and not args.overwrite: + print(warning(f"检测到 {len(existing_files)} 个输出文件已存在")) + if ask_user_confirmation("是否覆盖这些文件?"): + need_overwrite = True + else: + print(info("用户取消操作,程序结束")) + return 0 + elif args.overwrite and existing_files: + print(info(f"已启用覆盖模式,将覆盖 {len(existing_files)} 个已存在文件")) + need_overwrite = True + + # 运行转换 + return_code = run_audio_convert(audio_files, output_dir, args.config, need_overwrite or args.overwrite) + + # 显示完成提示 + if return_code == 0: + print(success(f"音频归一化处理完成!共处理 {len(audio_files)} 个文件")) + else: + print(fail(f"音频归一化处理失败,错误码: {return_code}")) + + return return_code + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/recognize_monitor.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/recognize_monitor.py new file mode 100644 index 00000000..0352a407 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/recognize_monitor.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +""" +识别管理脚本(Python 版) + +- 默认读取 output_data/split/item_with_lang.list +- 按 lang 
将子片段拆分为中文/英文两份列表 +- 先统一识别中文,再识别英文(减少模型切换开销) +- 调用 merge_asr_by_source 按 source_key/segment_index 合并回原音频文本 +""" + +import argparse +import json +import os +import shutil +import sys +import tempfile +import threading +from pathlib import Path +from typing import Dict, List, Tuple + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent + +# 颜色打印工具 +sys.path.insert(0, str(PROJECT_ROOT / "src" / "utils")) +try: + from color_utils import info, warning, error, ok, success, header # type: ignore + + def print_info(msg: str): + print(info(msg)) + + def print_warning(msg: str): + print(warning(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + +except Exception: + def print_info(msg: str): + print(f"[INFO] {msg}") + + def print_warning(msg: str): + print(f"[WARNING] {msg}") + + def print_error(msg: str): + print(f"[ERROR] {msg}") + + def print_ok(msg: str): + print(f"[OK] {msg}") + + def print_success(msg: str): + print(f"[SUCCESS] {msg}") + + def print_header(msg: str): + print(f"=== {msg} ===") + + +# YAML 配置加载(可选) +try: + from yaml_config_loader import parse_args_with_yaml_config # type: ignore +except Exception: + parse_args_with_yaml_config = None # type: ignore[assignment] + + +def split_by_lang(list_file: Path, tmp_dir: Path) -> Tuple[Path, Path, int, int]: + """根据 lang 字段将 item_with_lang.list 拆成 zh/en 两个 jsonl 列表。""" + zh_list = tmp_dir / "zh.list" + en_list = tmp_dir / "en.list" + + zh_items: List[Dict] = [] + en_items: List[Dict] = [] + + with list_file.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + d = json.loads(line) + row = { + "key": d.get("key", ""), + "wav": d.get("wav", ""), + "txt": d.get("txt", ""), + } + if d.get("lang") == "zh": + zh_items.append(row) + else: + en_items.append(row) + + for path, items in [(zh_list, zh_items), 
def main() -> int:
    """CLI entry point: recognise zh/en segments and merge the results.

    Reads the split manifest (or, with ``--from_denoise``, builds a temporary
    one from ``output_data/denoise``), splits it by language, runs
    ``src.utils.recognize`` for each language (in parallel by default) and
    merges the per-segment text through ``merge_asr_by_source``.

    Returns:
        0 on success; 1 when no manifest or audio is available; otherwise
        the non-zero return code of the failing recognition or merge step.
    """
    parser = argparse.ArgumentParser(
        description="识别管理脚本:读取 split 清单,按 zh/en 分别识别并合并结果",
    )
    parser.add_argument(
        "--config",
        "-c",
        default=None,
        help="YAML 配置文件路径(可选)。支持写 recognize_monitor: {split_dir:..., asr_root:..., device:...} 或直接顶层同名键",
    )
    parser.add_argument(
        "--split_dir",
        default=str(PROJECT_ROOT / "output_data" / "split"),
        help="split 输出目录(包含 item_with_lang.list),默认: output_data/split",
    )
    parser.add_argument(
        "--list_file",
        default=None,
        help="自定义清单路径(默认使用 split_dir/item_with_lang.list)",
    )
    parser.add_argument(
        "--asr_root",
        default=str(PROJECT_ROOT / "output_data" / "asr"),
        help="ASR 结果根目录,默认: output_data/asr",
    )
    parser.add_argument(
        "--device",
        default="npu",
        # The help text now states the actual default instead of "auto".
        help="传给 src.utils.recognize 的设备参数(auto/npu/cpu),默认 npu",
    )
    # Parallel by default; --no-parallel is kept as a fallback for hosts
    # short on resources.
    parser.add_argument(
        "--parallel",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="是否并行运行中/英两路识别以提速(默认开启;资源不足可用 --no-parallel 关闭)",
    )
    parser.add_argument(
        "--from_denoise",
        action="store_true",
        help="若未提供清单,默认从 output_data/denoise 扫描音频并生成临时 list",
    )
    if parse_args_with_yaml_config:
        args = parse_args_with_yaml_config(
            parser,
            section="recognize_monitor",
            default_config_paths=[PROJECT_ROOT / "config" / "recognize_monitor.yaml"],
        )
    else:
        args = parser.parse_args()

    split_dir = Path(args.split_dir).resolve()
    asr_root = Path(args.asr_root).resolve()
    list_file = Path(args.list_file).resolve() if args.list_file else split_dir / "item_with_lang.list"

    print_header("识别管理")
    print_info(f"项目根: {PROJECT_ROOT}")
    print_info(f"清单: {list_file}")
    print_info(f"ASR 输出: {asr_root}")

    # Every temporary directory created by this run is removed in the
    # finally block, including the one holding the generated manifest.
    temp_dirs: List[Path] = []
    try:
        if not list_file.exists():
            if not args.from_denoise:
                print_error(f"清单不存在: {list_file}")
                print_info("请先运行: python -m src.pipeline.split_and_tag 或传 --from_denoise")
                return 1

            denoise_dir = PROJECT_ROOT / "output_data" / "denoise"
            print_warning(f"清单不存在,改为从目录扫描: {denoise_dir}")
            audio_files = _find_audio_files(denoise_dir)
            if not audio_files:
                print_error("未找到可识别的音频")
                return 1
            list_dir = Path(tempfile.mkdtemp(prefix="hz_list_"))
            temp_dirs.append(list_dir)
            tmp_list = list_dir / "item_with_lang.list"
            with tmp_list.open("w", encoding="utf-8") as f:
                for p in audio_files:
                    row = {"key": p.stem, "wav": str(p.resolve()), "txt": "", "lang": "en"}
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
            list_file = tmp_list

        tmp_dir = Path(tempfile.mkdtemp(prefix="hz_split_"))
        temp_dirs.append(tmp_dir)
        zh_list, en_list, zh_n, en_n = split_by_lang(list_file, tmp_dir)

        (asr_root / "zh").mkdir(parents=True, exist_ok=True)
        (asr_root / "en").mkdir(parents=True, exist_ok=True)

        if args.parallel and zh_n > 0 and en_n > 0:
            print_info("并行识别:同时启动中文与英文片段识别...")
            rc_out: Dict[str, int] = {}
            workers = [
                threading.Thread(
                    target=_run_recognize_thread,
                    args=(lang, lang_list, asr_root / lang, args.device, rc_out),
                    daemon=False,
                )
                for lang, lang_list in (("zh", zh_list), ("en", en_list))
            ]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()

            # A thread that died before recording its code counts as failed.
            zh_rc = int(rc_out.get("zh", 1))
            en_rc = int(rc_out.get("en", 1))
            if zh_rc != 0:
                print_error(f"中文识别失败,返回码: {zh_rc}")
                return zh_rc
            if en_rc != 0:
                print_error(f"英文识别失败,返回码: {en_rc}")
                return en_rc
        else:
            if zh_n > 0:
                print_info("识别中文片段...")
                rc = run_recognize("zh", zh_list, asr_root / "zh", device=args.device)
                if rc != 0:
                    print_error(f"中文识别失败,返回码: {rc}")
                    return rc

            if en_n > 0:
                print_info("识别英文片段...")
                rc = run_recognize("en", en_list, asr_root / "en", device=args.device)
                if rc != 0:
                    print_error(f"英文识别失败,返回码: {rc}")
                    return rc

        print_info("合并子片段结果...")
        # The `src` package lives under PROJECT_ROOT; make it importable even
        # when the script was launched from a different working directory.
        if str(PROJECT_ROOT) not in sys.path:
            sys.path.append(str(PROJECT_ROOT))
        from src.pipeline import merge_asr_by_source  # type: ignore

        rc = merge_asr_by_source.main_for_api(  # type: ignore[attr-defined]
            list_file=list_file,
            zh_text=asr_root / "zh" / "ctc_greedy_search" / "text",
            en_text=asr_root / "en" / "ctc_greedy_search" / "text",
            output=asr_root / "merged_text.txt",
        )
        if rc != 0:
            print_error(f"合并失败,返回码: {rc}")
            return rc

        print_success(f"完成。合并文本: {asr_root / 'merged_text.txt'}")
        return 0
    finally:
        for temp_dir in temp_dirs:
            shutil.rmtree(temp_dir, ignore_errors=True)
def _load_item_list(path: Path) -> List[Dict]:
    """Load a jsonl list file: one JSON object per non-blank line."""
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped if text]


def _find_audio_files(input_path: Path) -> List[Path]:
    """Return *input_path* itself if it is a file, else the sorted audio files below it."""
    if input_path.is_file():
        return [input_path]
    audio_suffixes = {".wav", ".flac", ".mp3", ".aac", ".m4a", ".ogg", ".webm"}
    matches = [
        candidate
        for candidate in input_path.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in audio_suffixes
    ]
    matches.sort()
    return matches


def _import_pydub():
    """Import pydub lazily and return its ``AudioSegment`` class.

    Raises:
        RuntimeError: pydub cannot be imported.
    """
    try:
        from pydub import AudioSegment  # type: ignore
    except Exception as e:
        raise RuntimeError(f"无法导入 pydub,请在 DataMate 运行环境安装 pydub: {e}") from e
    return AudioSegment


def split_audio_to_segments(
    wav_path: Path,
    output_dir: Path,
    base_key: str,
    lang: str,
    max_seconds: int = MAX_SEGMENT_SECONDS,
) -> List[Dict]:
    """Cut one audio file into chunks of at most *max_seconds* and export them as wav.

    Chunks are written to ``output_dir/<base_key>_part<N>.wav``. A
    non-positive *max_seconds* produces a single chunk spanning the whole
    file.

    Returns:
        One list item per chunk with the keys ``key``, ``wav``, ``txt``,
        ``lang``, ``source_key`` and ``segment_index``.
    """
    segment_cls = _import_pydub()
    audio = segment_cls.from_file(str(wav_path))
    total_ms = len(audio)
    step_ms = max_seconds * 1000
    if step_ms <= 0:
        step_ms = total_ms

    segments: List[Dict] = []
    cursor = 0
    while cursor < total_ms:
        stop = min(cursor + step_ms, total_ms)
        index = len(segments)  # chunks are numbered in creation order
        part_key = f"{base_key}_part{index}"
        target = output_dir / f"{part_key}.wav"
        target.parent.mkdir(parents=True, exist_ok=True)
        audio[cursor:stop].export(str(target), format="wav")
        segments.append({
            "key": part_key,
            "wav": str(target.resolve()),
            "txt": "",
            "lang": lang,
            "source_key": base_key,
            "segment_index": index,
        })
        cursor = stop
    return segments
default=str(DEFAULT_LIST_PATH), + help=f"带语言的 list 文件 (jsonl),默认: {DEFAULT_LIST_PATH}", + ) + parser.add_argument( + "--from_list", + action="store_true", + help="输入作为 list 文件处理;默认按目录扫描音频", + ) + parser.add_argument( + "--max_seconds", "-s", + type=int, + default=MAX_SEGMENT_SECONDS, + help=f"每段最大秒数,默认: {MAX_SEGMENT_SECONDS}", + ) + if parse_args_with_yaml_config: + args = parse_args_with_yaml_config( + parser, + section="split_and_tag", + default_config_paths=[_PROJECT_ROOT / "config" / "split_and_tag.yaml"], + ) + else: + args = parser.parse_args() + + input_dir = Path(args.input_dir).resolve() + output_dir = Path(args.output_dir).resolve() + list_path = Path(args.list_file).resolve() + + _print_header("切分音频并打标签") + + items: List[Dict] = [] + if args.from_list or list_path.exists(): + if not list_path.exists(): + _print_error(f"列表文件不存在: {list_path}") + return 1 + items = _load_item_list(list_path) + else: + if not input_dir.exists(): + _print_error(f"输入目录不存在: {input_dir}") + return 1 + audio_files = _find_audio_files(input_dir) + if not audio_files: + _print_warning("未找到任何音频文件") + return 0 + items = [{"key": p.stem, "wav": str(p.resolve()), "txt": "", "lang": "en"} for p in audio_files] + + if not items: + _print_warning("输入为空,退出") + return 0 + + output_dir.mkdir(parents=True, exist_ok=True) + all_segments: List[Dict] = [] + for it in items: + key = it.get("key", "") + wav = it.get("wav") or it.get("audio") or it.get("path", "") + lang = it.get("lang", "en") + if not wav or not key: + _print_warning(f"跳过无效项: key={key}, wav={wav}") + continue + wav_path = Path(wav) + if not wav_path.exists(): + _print_warning(f"文件不存在,跳过: {wav_path}") + continue + try: + segs = split_audio_to_segments( + wav_path, output_dir, key, lang, + max_seconds=args.max_seconds, + ) + all_segments.extend(segs) + except Exception as e: + _print_error(f"切分失败 {wav_path}: {e}") + continue + + out_list_path = output_dir / "item_with_lang.list" + with open(out_list_path, "w", encoding="utf-8") as 
def _project_root() -> Path:
    """Return the directory three levels above this file."""
    location = Path(__file__)
    for _ in range(3):
        location = location.parent
    return location


def _ensure_utils_on_path() -> None:
    """Put ``src/utils`` and ``scripts/audio_convert`` at the front of ``sys.path``.

    Only directories that exist are added, and a directory already on the
    path is left where it is.
    """
    root = _project_root()
    candidates = (root / "src" / "utils", root / "scripts" / "audio_convert")
    for directory in candidates:
        if not directory.exists():
            continue
        entry = str(directory)
        if entry not in sys.path:
            sys.path.insert(0, entry)
pragma: no cover - 兼容无 color_utils 场景 + def info(msg: str) -> str: + return f"[INFO] {msg}" + + def warning(msg: str) -> str: + return f"[WARNING] {msg}" + + def error(msg: str) -> str: + return f"[ERROR] {msg}" + + def ok(msg: str) -> str: + return f"[OK] {msg}" + + def success(msg: str) -> str: + return f"[SUCCESS] {msg}" + + def header(msg: str) -> str: + return f"=== {msg} ===" + + +def _print_info(msg: str) -> None: + print(info(msg)) + + +def _print_warning(msg: str) -> None: + print(warning(msg)) + + +def _print_error(msg: str) -> None: + print(error(msg)) + + +def _print_success(msg: str) -> None: + print(success(msg)) + + +def _find_audio_files(audio_dir: Path) -> List[Path]: + patterns = ["*.wav", "*.WAV", "*.flac", "*.FLAC", "*.mp3", "*.MP3", "*.aac", "*.AAC", "*.m4a", "*.M4A"] + files: List[Path] = [] + for pat in patterns: + files.extend(audio_dir.rglob(pat)) + return sorted(set(files)) + + +def _load_wave(path: Path) -> Tuple[List[float], int]: + """ + 读取音频为 mono waveform 和采样率。 + + 优先使用 torchaudio(项目已依赖 speechbrain,通常可用), + 若导入失败则退化为 soundfile; 再失败则抛错。 + """ + try: + import torchaudio # type: ignore + + wav, sr = torchaudio.load(str(path)) + if wav.ndim > 1: + wav = wav.mean(dim=0, keepdim=True) + mono = wav.squeeze(0).float().tolist() + return mono, int(sr) + except Exception: + try: + import soundfile as sf # type: ignore + + data, sr = sf.read(str(path), always_2d=False) + if getattr(data, "ndim", 1) > 1: + # stereo -> mono + data = data.mean(axis=1) + return data.tolist(), int(sr) + except Exception as e: + raise RuntimeError(f"读取音频失败: {path}, error={e}") from e + + +def _frame_rms(x: List[float], sr: int, frame_ms: float, hop_ms: float) -> Tuple[List[float], float]: + if not x or sr <= 0: + return [], 0.0 + frame_len = max(1, int(sr * frame_ms / 1000.0)) + hop = max(1, int(sr * hop_ms / 1000.0)) + n = len(x) + rms_list: List[float] = [] + total_sq = 0.0 + for v in x: + total_sq += float(v) * float(v) + global_rms = math.sqrt(total_sq / max(1, n)) 
+ for start in range(0, n, hop): + end = min(start + frame_len, n) + if end <= start: + continue + s = 0.0 + cnt = 0 + for v in x[start:end]: + s += float(v) * float(v) + cnt += 1 + if cnt == 0: + rms = 0.0 + else: + rms = math.sqrt(s / cnt) + rms_list.append(rms) + return rms_list, global_rms + + +def _analyze_one( + wav_path: Path, + key: str, + min_dur: float, + max_dur: float, + silence_ratio_th: float, + silence_rms_ratio_th: float, +) -> Dict: + wav, sr = _load_wave(wav_path) + n = len(wav) + duration = float(n) / float(sr) if sr > 0 else 0.0 + + rms_frames, global_rms = _frame_rms(wav, sr, frame_ms=25.0, hop_ms=10.0) + if not rms_frames or global_rms <= 0.0: + silence_ratio = 1.0 + else: + th = max(1e-8, global_rms * silence_rms_ratio_th) + silent = sum(1 for r in rms_frames if r < th) + silence_ratio = float(silent) / float(len(rms_frames)) + + reasons: List[str] = [] + quality_flag = "ok" + + if duration <= 0.0: + quality_flag = "invalid" + reasons.append("duration_le_zero") + elif duration < min_dur: + quality_flag = "invalid" + reasons.append("too_short") + elif duration > max_dur: + quality_flag = "invalid" + reasons.append("too_long") + + if silence_ratio >= silence_ratio_th: + quality_flag = "invalid" + reasons.append("too_much_silence") + + return { + "key": key, + "wav": str(wav_path.resolve()), + "duration": round(duration, 3), + "silence_ratio": round(silence_ratio, 4), + "global_rms": round(global_rms, 6), + "quality_flag": quality_flag, + "reason": ",".join(reasons) if reasons else "", + } + + +def _dump_jsonl(path: Path, items: Iterable[Dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + for it in items: + f.write(json.dumps(it, ensure_ascii=False) + "\n") + + +def _load_input_list(path: Path) -> List[Dict]: + items: List[Dict] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + items.append(json.loads(line)) + return 
def _analyze_or_mark_invalid(wav_path: Path, key: str, args: argparse.Namespace) -> Dict:
    """Analyze one file; on any failure return an ``invalid`` record instead of raising."""
    try:
        return _analyze_one(
            wav_path=wav_path,
            key=key,
            min_dur=float(args.min_dur),
            max_dur=float(args.max_dur),
            silence_ratio_th=float(args.silence_ratio_th),
            silence_rms_ratio_th=float(args.silence_rms_ratio_th),
        )
    except Exception as e:
        # Deliberately best-effort: one unreadable file must not abort the batch.
        _print_warning(f"处理失败,标记为 invalid: {wav_path}, error={e}")
        return {
            "key": key,
            "wav": str(wav_path.resolve()),
            "duration": 0.0,
            "silence_ratio": 1.0,
            "global_rms": 0.0,
            "quality_flag": "invalid",
            "reason": "load_error",
        }


def main() -> int:
    """Run the quality check over a jsonl list or a directory and write the results.

    Returns:
        Process exit code: 1 when the input list / directory does not exist,
        0 otherwise (including the "nothing to process" cases).
    """
    args = parse_arguments()
    output_path = Path(args.output).resolve()
    ok_output_path = Path(args.ok_output).resolve() if args.ok_output else None

    print(header("音频异常检测与过滤(工具版)"))
    print(
        info(
            f"参数: min_dur={args.min_dur}s, max_dur={args.max_dur}s, "
            f"silence_ratio_th={args.silence_ratio_th}, silence_rms_ratio_th={args.silence_rms_ratio_th}"
        )
    )

    items_with_quality: List[Dict] = []

    if args.input_list:
        input_path = Path(args.input_list).resolve()
        if not input_path.exists():
            _print_error(f"输入列表不存在: {input_path}")
            return 1
        _print_info(f"基于输入列表进行质量检测: {input_path}")
        base_items = _load_input_list(input_path)
        if not base_items:
            _print_warning("输入列表为空,退出")
            return 0

        for idx, it in enumerate(base_items, start=1):
            wav_path_str = it.get("wav") or it.get("audio") or it.get("path")
            if not wav_path_str:
                _print_warning(f"条目缺少 wav/audio/path 字段,标记为 invalid: {it.get('key', '')}")
                out = dict(it)
                out.update(
                    {
                        "duration": 0.0,
                        "silence_ratio": 1.0,
                        "global_rms": 0.0,
                        "quality_flag": "invalid",
                        "reason": "no_wav_field",
                    }
                )
                items_with_quality.append(out)
                continue

            wav_path = Path(wav_path_str)
            key = str(it.get("key", wav_path.stem))
            quality_info = _analyze_or_mark_invalid(wav_path, key, args)

            # Keep the caller's original fields, then overlay the quality fields.
            merged = dict(it)
            merged.update(quality_info)
            items_with_quality.append(merged)

            if idx % 20 == 0 or idx == len(base_items):
                _print_info(f"进度: {idx}/{len(base_items)}")
    else:
        audio_dir = Path(args.audio_dir).resolve()
        if not audio_dir.exists():
            _print_error(f"音频目录不存在: {audio_dir}")
            return 1
        _print_info(f"扫描目录: {audio_dir}")
        files = _find_audio_files(audio_dir)
        if not files:
            _print_warning(f"目录中未找到任何音频文件: {audio_dir}")
            return 0

        _print_info(f"待分析音频数: {len(files)}")

        for idx, p in enumerate(files, start=1):
            items_with_quality.append(_analyze_or_mark_invalid(p, p.stem, args))

            if idx % 20 == 0 or idx == len(files):
                _print_info(f"进度: {idx}/{len(files)}")

    if not items_with_quality:
        _print_warning("没有任何条目被处理,退出")
        return 0

    _dump_jsonl(output_path, items_with_quality)
    invalid_count = sum(1 for it in items_with_quality if it.get("quality_flag") == "invalid")
    _print_success(f"分析完成,输出: {output_path}")
    _print_info(f"统计: 总数={len(items_with_quality)}, invalid={invalid_count}, ok={len(items_with_quality) - invalid_count}")

    if ok_output_path is not None:
        ok_items = [it for it in items_with_quality if it.get("quality_flag") == "ok"]
        _dump_jsonl(ok_output_path, ok_items)
        _print_info(f"另存仅包含 ok 条目的列表: {ok_output_path} (数量={len(ok_items)})")

    return 0
class ConfigManager:
    """Locate, load and merge the audio conversion configuration."""

    DEFAULT_CONFIG = {
        'audio_config': {
            'output_format': 'wav',
            'channels': 1,
            'sample_rate': 16000,
            'sample_width': 2,  # bytes
            'encoding': 'pcm_s16le',
            'bitrate': None,
            'input_format': ['mp3', 'wav', 'aac', 'm4a', 'flac', 'ogg', 'opus', 'wma'],
            'quality': 5,  # 1-9, only honoured by some formats
            'compression': None,  # compression level
            'dither': None  # dithering algorithm
        }
    }

    @staticmethod
    def _default_audio_config() -> Dict[str, Any]:
        """Return a fresh copy of the defaults so callers cannot mutate class state."""
        defaults = ConfigManager.DEFAULT_CONFIG['audio_config']
        return {
            key: list(value) if isinstance(value, list) else value
            for key, value in defaults.items()
        }

    @staticmethod
    def _with_defaults(config_data: Any) -> Dict[str, Any]:
        """Pick the ``audio_config`` section (or the top level) and fill missing keys.

        Anything that is not a mapping (e.g. ``None`` from an empty YAML
        file) yields a fresh copy of the defaults.
        """
        section = config_data
        if isinstance(config_data, dict) and 'audio_config' in config_data:
            section = config_data['audio_config']
        if not isinstance(section, dict):
            return ConfigManager._default_audio_config()
        for key, value in ConfigManager._default_audio_config().items():
            section.setdefault(key, value)
        return section

    @staticmethod
    def find_config_file(config_path: Optional[str] = None) -> Path:
        """Find the configuration file.

        Search order:
        1. The explicitly given path (must exist).
        2. ``config/audio_config.yaml`` under the current directory.
        3. ``config/audio_config.yaml`` under the project root.
        4. ``.audio_preprocessor/audio_config.yaml`` under the user's home.

        Returns:
            The first existing candidate; when none exists, the project-root
            candidate is returned even though it does not exist.

        Raises:
            FileNotFoundError: if ``config_path`` is given but missing.
        """
        if config_path:
            path = Path(config_path)
            if path.exists():
                return path
            raise FileNotFoundError(f"指定的配置文件不存在: {path}")

        search_paths = [
            Path.cwd() / "config" / "audio_config.yaml",
            PROJECT_ROOT / "config" / "audio_config.yaml",
            Path.home() / ".audio_preprocessor" / "audio_config.yaml",
        ]

        for path in search_paths:
            if path.exists():
                return path

        # Nothing found: fall back to the project-root location.
        return search_paths[1]

    @staticmethod
    def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
        """Load the configuration, falling back to defaults on any problem.

        Raises:
            FileNotFoundError: if ``config_path`` is given but missing.
        """
        config_file = ConfigManager.find_config_file(config_path)

        if not config_file.exists():
            print(warning("配置文件不存在,使用默认配置"))
            return ConfigManager._default_audio_config()

        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            print(error(f"配置文件格式错误: {e}"))
            print(warning("使用默认配置"))
            return ConfigManager._default_audio_config()
        except Exception as e:
            print(error(f"加载配置文件失败: {e}"))
            print(warning("使用默认配置"))
            return ConfigManager._default_audio_config()

        # An empty file parses to None; a scalar/list top level is not usable either.
        usable = isinstance(config_data, dict) and isinstance(
            config_data.get('audio_config', config_data), dict
        )
        if not usable:
            print(warning("配置文件为空或内容不是映射,使用默认配置"))
            return ConfigManager._default_audio_config()

        config = ConfigManager._with_defaults(config_data)
        print(info(f"已加载配置文件: {config_file}"))
        return config

    @staticmethod
    def merge_configs(config: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
        """Overlay command-line values that are not ``None`` onto a copy of ``config``."""
        merged = config.copy()

        # Config key -> attribute name on the argparse namespace.
        arg_mapping = {
            'output_format': 'format',
            'channels': 'channels',
            'sample_rate': 'sample_rate',
            'sample_width': 'sample_width',
            'encoding': 'encoding',
            'bitrate': 'bitrate',
            'quality': 'quality',
        }

        for config_key, arg_key in arg_mapping.items():
            arg_value = getattr(args, arg_key, None)
            if arg_value is not None:
                merged[config_key] = arg_value

        return merged
def convert_audio(self, input_path: Path, output_path: Path, config: Dict[str, Any]) -> bool:
    """Convert a single audio file according to ``config``.

    Args:
        input_path: Source audio file.
        output_path: Destination file.
        config: Conversion settings (``channels``, ``sample_rate``,
            ``sample_width``, ``output_format``, ``encoding``, ``bitrate``,
            ``quality``, ``compression``).

    Returns:
        True when the output file exists after export; False on any error
        (the error is printed, never raised).
    """
    try:
        print(info(f"处理: {input_path.name}"))

        audio = self.AudioSegment.from_file(str(input_path))

        # Capture the old value before each change so the log shows
        # "old -> new" rather than the new value twice.
        channels = config.get('channels', 1)
        if channels != audio.channels:
            previous_channels = audio.channels
            audio = audio.set_channels(channels)
            print(info(f" 声道数: {previous_channels} -> {channels}"))

        sample_rate = config.get('sample_rate', 16000)
        if sample_rate != audio.frame_rate:
            previous_rate = audio.frame_rate
            audio = audio.set_frame_rate(sample_rate)
            print(info(f" 采样率: {previous_rate} -> {sample_rate}"))

        sample_width = config.get('sample_width', 2)
        if sample_width != audio.sample_width:
            previous_width = audio.sample_width
            audio = audio.set_sample_width(sample_width)
            print(info(f" 采样位宽: {previous_width} -> {sample_width}"))

        export_params: Dict[str, Any] = {}
        output_format = config.get('output_format', 'wav').lower()

        encoding = config.get('encoding')
        if encoding:
            export_params['codec'] = encoding

        bitrate = config.get('bitrate')
        if bitrate:
            export_params['bitrate'] = bitrate

        # pydub's export() has no ``quality``/``compression`` keywords; encoder
        # options must be forwarded as raw ffmpeg arguments via ``parameters``.
        ffmpeg_args: List[str] = []

        quality = config.get('quality')
        # NOTE(review): opus is left out because libopus appears to have no
        # quality-scale mode (bitrate driven) - confirm before re-adding.
        if quality is not None and output_format in ('mp3', 'ogg'):
            ffmpeg_args += ['-q:a', str(quality)]

        compression = config.get('compression')
        if compression is not None and output_format == 'flac':
            ffmpeg_args += ['-compression_level', str(compression)]

        # Only pass ``parameters`` when there is something to pass, so export()
        # sees the same arguments as before for formats without extra options.
        if ffmpeg_args:
            export_params['parameters'] = ffmpeg_args

        audio.export(str(output_path), format=output_format, **export_params)

        if output_path.exists():
            output_size = output_path.stat().st_size / 1024  # KB
            print(ok(f" 转换成功: {output_path.name} ({output_size:.1f} KB)"))
            return True
        print(error(" 转换失败: 输出文件未创建"))
        return False

    except Exception as e:
        # Best-effort per file: report and let the batch continue.
        print(error(f" 转换失败: {e}"))
        return False
def expand_inputs(input_args: List[str]) -> List[Path]:
    """Expand CLI inputs (files, directories, glob patterns) into a file list.

    Glob patterns are expanded recursively, directories are walked
    recursively, plain files are taken as-is, and a warning is printed for a
    path that does not exist. The result is de-duplicated and sorted by its
    string form.
    """
    import glob

    collected = set()
    for raw in input_args:
        if any(marker in raw for marker in ('*', '?', '[')):
            candidates = (Path(hit) for hit in glob.glob(raw, recursive=True))
        else:
            target = Path(raw)
            if target.is_dir():
                candidates = target.rglob('*')
            elif target.is_file():
                collected.add(target)
                continue
            else:
                print(warning(f"输入路径不存在: {raw}"))
                continue
        collected.update(entry for entry in candidates if entry.is_file())

    return sorted(collected, key=str)
def main():
    """Command-line entry point: parse arguments, load config, run the conversion.

    Exits the process via ``sys.exit``: 0 on success (or when at least one
    file of a batch converted), 1 on configuration/input errors or when
    nothing could be converted.
    """
    parser = build_argparser()
    args = parser.parse_args()

    print(header("音频转换工具"))

    # List the supported formats and quit.
    if args.list_formats:
        converter = AudioConverter()
        print(info("支持的输出格式:"))
        for fmt in converter.get_supported_formats():
            codecs = converter.FORMAT_CODECS.get(fmt, [])
            if codecs:
                print(f" {fmt}: {', '.join(codecs)}")
            else:
                print(f" {fmt}")
        sys.exit(0)

    config = ConfigManager.load_config(args.config)

    # Show the loaded configuration and quit.
    if args.show_config:
        print(header("当前配置"))
        for key, value in config.items():
            if isinstance(value, list):
                print(f" {key}: {', '.join(map(str, value))}")
            else:
                print(f" {key}: {value}")
        sys.exit(0)

    # Command-line values override the file configuration.
    config = ConfigManager.merge_configs(config, args)

    converter = AudioConverter()
    errors = converter.validate_config(config)
    if errors:
        print(error("配置错误:"))
        for err in errors:
            print(f" {err}")
        sys.exit(1)

    input_files = expand_inputs(args.input)
    valid_files = validate_input_files(input_files, config)

    if not valid_files:
        print(error("没有有效的输入文件"))
        sys.exit(1)

    print(info(f"找到 {len(valid_files)} 个音频文件"))

    if shutil.which("ffmpeg") is None and shutil.which("avconv") is None:
        print(warning("未检测到 ffmpeg/avconv,部分格式可能无法处理"))

    output_path = Path(args.output)

    # Single input file.
    if len(valid_files) == 1:
        input_file = valid_files[0]

        # An existing directory as output: derive the file name from the input.
        if output_path.exists() and output_path.is_dir():
            output_format = config.get('output_format', 'wav').lower()
            output_ext = converter.FORMAT_EXTENSIONS.get(output_format, f".{output_format}")
            output_file = output_path / (input_file.stem + output_ext)
        else:
            output_file = output_path

        if output_file.exists() and not args.overwrite:
            response = input(f"输出文件已存在: {output_file.name},是否覆盖? (y/n): ").lower()
            if response not in ['y', 'yes']:
                print(info("用户取消操作"))
                sys.exit(0)

        output_file.parent.mkdir(parents=True, exist_ok=True)

        # Must not be named ``success``: that would shadow the module-level
        # success() formatter for the whole function and break both the call
        # below and the one in the batch branch.
        converted = converter.convert_audio(input_file, output_file, config)

        if converted:
            print(success("转换完成"))
            sys.exit(0)
        else:
            print(fail("转换失败"))
            sys.exit(1)

    # Batch conversion.
    else:
        # With several inputs the output has to be a directory.
        if output_path.exists() and output_path.is_file():
            print(error("多个输入文件时,输出必须为目录"))
            sys.exit(1)

        # Ask before writing into a non-empty directory.
        if output_path.exists():
            existing_files = list(output_path.glob("*"))
            if existing_files and not args.overwrite:
                response = input(f"输出目录已有 {len(existing_files)} 个文件,是否继续? (y/n): ").lower()
                if response not in ['y', 'yes']:
                    print(info("用户取消操作"))
                    sys.exit(0)

        results = converter.batch_convert(valid_files, output_path, config)

        print(header("转换结果"))
        print(info(f"总计: {results['total']} 个文件"))
        print(ok(f"成功: {results['success']} 个"))

        if results['failed'] > 0:
            print(error(f"失败: {results['failed']} 个"))
            if results['failed_files'] and args.verbose > 0:
                print(info("失败的文件:"))
                for file_path in results['failed_files'][:10]:  # show at most 10
                    print(f" {Path(file_path).name}")
                if len(results['failed_files']) > 10:
                    print(f" ... 还有 {len(results['failed_files']) - 10} 个")

        if results['success'] == results['total']:
            print(success("所有文件转换成功!"))
        elif results['success'] > 0:
            print(info("部分文件转换完成"))
        else:
            print(fail("所有文件转换失败"))

        sys.exit(0 if results['success'] > 0 else 1)
print_success(msg: str): + print(f"[SUCCESS] {msg}") + + def print_header(msg: str): + print(f"=== {msg} ===") + + +def main() -> int: + parser = argparse.ArgumentParser( + description="GTCRN 独立降噪工具(ONNX 优先)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + # 单文件:默认输出到同目录 xxx_denoise.wav + python -m src.tools.gtcrn_denoise --input ./a.wav + + # 目录:默认输出到 output_data/denoise_tool + python -m src.tools.gtcrn_denoise --input ./input_dir + + # 显式指定模型和输出 + python -m src.tools.gtcrn_denoise --input ./input_dir --model ./models/gtcrn/gtcrn.onnx --output ./out_dir + + # 如果是 torch 权重,可导出 ONNX + python -m src.tools.gtcrn_denoise --input ./a.wav --model ./weights/model_trained_on_dns3.tar --export_dir ./models/gtcrn_onnx + """, + ) + parser.add_argument("--input", required=True, help="输入音频文件或目录") + parser.add_argument( + "--model", + default=str(PROJECT_ROOT / "models" / "gtcrn" / "gtcrn.onnx"), + help="GTCRN 模型路径,默认: models/gtcrn/gtcrn.onnx", + ) + parser.add_argument( + "--output", + default=None, + help="输出 wav 文件或目录;单文件默认同目录 *_denoise.wav,目录默认 output_data/denoise_tool", + ) + parser.add_argument( + "--export_dir", + default=None, + help="若输入为 .tar/.pt/.pth,则导出 ONNX 的目录", + ) + args = parser.parse_args() + + input_path = Path(args.input).resolve() + model_path = Path(args.model).resolve() + export_dir = Path(args.export_dir).resolve() if args.export_dir else None + if args.output: + output_path = Path(args.output).resolve() + else: + if input_path.is_file(): + output_path = input_path.with_name(f"{input_path.stem}_denoise.wav") + else: + output_path = PROJECT_ROOT / "output_data" / "denoise_tool" + + print_header("GTCRN 独立降噪") + print_info(f"输入: {input_path}") + print_info(f"模型: {model_path}") + print_info(f"输出: {output_path}") + + try: + resolved_model = gtcrn_denoise._resolve_model(model_path, export_dir=export_dir) # type: ignore[attr-defined] + print_info(f"使用模型: {resolved_model}") + denoiser = gtcrn_denoise.OnnxGtcrnDenoiser(resolved_model) 
# type: ignore[attr-defined] + except Exception as e: + print_error(f"初始化失败: {e}") + return 1 + + files = gtcrn_denoise._find_audio_files(input_path) # type: ignore[attr-defined] + if not files: + print_warning("未找到可处理的音频文件") + return 0 + + try: + if input_path.is_file(): + if output_path.suffix.lower() != ".wav": + output_path = output_path.with_suffix(".wav") + gtcrn_denoise.process_one(files[0], output_path, denoiser) # type: ignore[attr-defined] + print_success(f"完成: {output_path}") + else: + output_path.mkdir(parents=True, exist_ok=True) + for f in files: + out_file = output_path / f"{f.stem}.wav" + print_info(f"降噪: {f.name} -> {out_file.name}") + gtcrn_denoise.process_one(f, out_file, denoiser) # type: ignore[attr-defined] + print_success(f"批量完成,输出目录: {output_path}") + except Exception as e: + print_error(f"处理失败: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/readme.txt b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/readme.txt new file mode 100644 index 00000000..818bd243 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/readme.txt @@ -0,0 +1 @@ +这里是一些独立工具,不参与外面的处理流水线。 \ No newline at end of file diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/recognize.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/recognize.py new file mode 100644 index 00000000..80fa95dc --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/recognize.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +语音识别脚本(tools 副本) +调用 WeNet 进行音频转文本,支持中英文。路径相对本脚本所在 src/tools 解析。 +""" + +import argparse +import subprocess +import sys +import threading +import queue +from pathlib import Path + +# 从 src/utils 导入 color_utils +_TOOLS_DIR = Path(__file__).resolve().parent +_PROJECT_ROOT = _TOOLS_DIR.parent.parent +sys.path.insert(0, 
def resolve_device(device_arg: str) -> str:
    """Map the ``--device`` argument to the device actually used.

    ``"cpu"`` is returned unchanged, ``"auto"`` picks ``"npu"`` when one is
    available and ``"cpu"`` otherwise, and ``"npu"`` requires an NPU.

    Raises:
        ValueError: for an unknown device name, or when ``"npu"`` is
            requested but no NPU is available.
    """
    if device_arg == "cpu":
        return "cpu"
    if device_arg not in ("auto", "npu"):
        raise ValueError(f"不支持的设备类型: {device_arg}")

    if check_npu_available():
        return "npu"
    if device_arg == "auto":
        return "cpu"
    raise ValueError("指定使用 NPU,但设备不支持 NPU")
def read_output(stream, output_queue, stream_name):
    """Pump a text stream into a queue, one line at a time, then close it.

    Intended as a thread target: every line read from ``stream`` is put on
    ``output_queue`` as ``(stream_name, line_without_trailing_newline)``.

    Args:
        stream: Text-mode file object; reading stops at the first ``''``
            returned by ``readline()`` (end of file).
        output_queue: Object with a ``put`` method (e.g. ``queue.Queue``).
        stream_name: Tag stored with every line, e.g. ``'stdout'``.
    """
    try:
        while True:
            raw = stream.readline()
            if raw == '':
                # readline() returns '' only at EOF.
                break
            if raw:
                output_queue.put((stream_name, raw.rstrip('\n')))
    except Exception:
        # Best effort: a broken pipe must not take the reader thread down.
        pass
    finally:
        stream.close()
queue.Queue() + for stream, name in [(process.stdout, 'stdout'), (process.stderr, 'stderr')]: + t = threading.Thread(target=read_output, args=(stream, output_queue, name)) + t.daemon = True + t.start() + while True: + try: + _, line = output_queue.get(timeout=0.1) + print(line) + except queue.Empty: + if process.poll() is not None: + try: + while True: + _, line = output_queue.get_nowait() + print(line) + except queue.Empty: + pass + break + return_code = process.wait() + print("-" * 80) + if return_code == 0: + print_success("语音识别完成!") + return 0 + print_error(f"识别失败,返回码: {return_code}") + return return_code + except Exception as e: + print_error(str(e)) + import traceback + traceback.print_exc() + return 1 + + +def main(): + defaults = get_default_paths() + parser = argparse.ArgumentParser(description="语音识别 - WeNet 音频转文本") + parser.add_argument("--language", "-l", choices=["zh", "en"], default="zh") + parser.add_argument("--audio_list", "-a", default=str(defaults['audio_list'])) + parser.add_argument("--result_dir", "-r", default=str(defaults['result_dir'])) + parser.add_argument("--device", "-d", choices=["auto", "npu", "cpu"], default="npu") + args = parser.parse_args() + print_header("语音识别") + try: + import torch + print_info(f"PyTorch: {torch.__version__}") + except ImportError: + print_error("未安装 PyTorch") + return 1 + if not defaults['wenet_wrapper'].exists(): + print_warning("WeNet 包装器不存在,请从 src/utils 运行或创建") + return 1 + try: + return run_recognize(args.language, args.audio_list, args.result_dir, args.device) + except (ValueError, FileNotFoundError) as e: + print_error(str(e)) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/split_audio.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/split_audio.py new file mode 100644 index 00000000..bd791d8b --- /dev/null +++ 
def split_one(
    wav_path: Path,
    output_dir: Path,
    max_seconds: int,
    base_name: str,
) -> int:
    """Cut one audio file into consecutive WAV chunks.

    Chunks are written to ``output_dir`` as ``{base_name}_part{index}.wav``,
    with ``index`` starting at 0. The last chunk may be shorter than the rest.

    Args:
        wav_path: Audio file to split (any format pydub can open).
        output_dir: Directory for the chunks; created if missing.
        max_seconds: Maximum chunk length in seconds; values below 1 are
            treated as 1.
        base_name: Prefix of every generated file name.

    Returns:
        Number of chunks written (0 for an empty recording).

    Raises:
        RuntimeError: If pydub could not be imported.
    """
    if AudioSegment is None:
        raise RuntimeError("请在 DataMate 运行环境安装 pydub")

    clip = AudioSegment.from_file(str(wav_path))
    total_ms = len(clip)  # pydub reports length in milliseconds
    window_ms = max(1, max_seconds) * 1000
    output_dir.mkdir(parents=True, exist_ok=True)

    written = 0
    cursor = 0
    while cursor < total_ms:
        stop = min(cursor + window_ms, total_ms)
        target = output_dir / f"{base_name}_part{written}.wav"
        clip[cursor:stop].export(str(target), format="wav")
        written += 1
        cursor = stop
    return written
def color_text(text: str, color: str, bold: bool = False) -> str:
    """Return ``text`` unchanged; kept only for backward compatibility.

    This module emits plain-text log tags only, so no colour or bold escape
    sequence is ever added here.

    Args:
        text: Text to "colour".
        color: Colour code; accepted for compatibility and ignored.
        bold: Bold flag; accepted for compatibility and ignored.

    Returns:
        ``text`` exactly as passed in.
    """
    return text
+ """问题消息""" + return f"[WARNING] {msg}" diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/compute_wer.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/compute_wer.py new file mode 100644 index 00000000..e413a274 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/compute_wer.py @@ -0,0 +1,553 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import re, sys, unicodedata +import codecs + +remove_tag = True +spacelist = [' ', '\t', '\r', '\n'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. 
def stripoff_tags(x):
    """Remove every ``<...>`` tag from a token.

    A tag runs from a ``<`` up to and including the next ``>``. If a ``<``
    is never closed, everything from it to the end of the string is dropped.

    Args:
        x: Token text; ``None`` and ``''`` are accepted.

    Returns:
        The text with all tags removed (``''`` for empty input).
    """
    if not x:
        return ''
    kept = []
    pos = 0
    size = len(x)
    while pos < size:
        opening = x.find('<', pos)
        if opening < 0:
            # No further tag: the remainder is plain text.
            kept.append(x[pos:])
            break
        kept.append(x[pos:opening])
        closing = x.find('>', opening)
        if closing < 0:
            # Unterminated tag swallows the rest of the string.
            break
        pos = closing + 1
    return ''.join(kept)
enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # 
def default_cluster(word):
    """Classify a word into a coarse script cluster for per-cluster WER.

    Each character is classified by the prefix of its Unicode name. A small
    set of punctuation characters is ignored. The word belongs to a cluster
    only if all remaining characters agree.

    Args:
        word: Token to classify.

    Returns:
        ``'Number'``, ``'Mandarin'``, ``'English'`` or ``'Japanese'`` when
        every non-ignored character falls in that class; ``'Other'`` for
        mixed words, empty words, words made only of ignored punctuation,
        and words containing any unrecognised character.
    """
    # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
    ignored_prefixes = (
        'AMPERSAND', 'APOSTROPHE', 'COMMERCIAL AT', 'DEGREE CELSIUS',
        'EQUALS SIGN', 'FULL STOP', 'HYPHEN-MINUS', 'LOW LINE',
        'NUMBER SIGN', 'PLUS SIGN', 'SEMICOLON',
    )
    labels = []
    for char in word:
        # Unnamed characters (control characters, unassigned code points)
        # make unicodedata.name() raise ValueError without a default; an
        # empty name falls through to 'Other' instead of aborting the run.
        name = unicodedata.name(char, '')
        if name.startswith('DIGIT'):  # 1
            labels.append('Number')
        elif name.startswith(('CJK UNIFIED IDEOGRAPH',
                              'CJK COMPATIBILITY IDEOGRAPH')):  # 明 / 郎
            labels.append('Mandarin')
        elif name.startswith(('LATIN CAPITAL LETTER',
                              'LATIN SMALL LETTER')):  # A / a
            labels.append('English')
        elif name.startswith('HIRAGANA LETTER'):  # は こ め
            labels.append('Japanese')
        elif name.startswith(ignored_prefixes):
            continue
        else:
            return 'Other'
    if not labels:
        return 'Other'
    first = labels[0]
    if any(label != first for label in labels):
        return 'Other'
    return first
def _flag_enabled(value):
    """Interpret a ``--switch=VALUE`` boolean: only ``0``/``false`` disable it."""
    return value.lower() not in ('0', 'false')


def _wer_percent(result):
    """Return (ins + sub + del) / all in percent, or 0.0 when ``all`` is 0."""
    if result['all'] == 0:
        return 0.0
    errors = result['ins'] + result['sub'] + result['del']
    return float(errors) * 100.0 / result['all']


def _print_counts(result):
    """Print the ``N= C= S= D= I=`` summary that follows every WER figure."""
    print('N=%d C=%d S=%d D=%d I=%d' %
          (result['all'], result['cor'], result['sub'], result['del'],
           result['ins']))


def _format_row(label, tokens, paddings, padding_symbol):
    """Build one alignment row: label, then each token padded to column width."""
    cells = ''.join(
        '{token}'.format(token=token) + padding_symbol * pad + ' '
        for token, pad in zip(tokens, paddings))
    return label + ' ' + cells


# Command-line entry point:
#   compute_wer.py [switches] test.ref test.hyp
# Aligns every hypothesis with its reference, prints per-utterance alignments
# (when verbose), then the overall WER and the WER of each word cluster.
if __name__ == '__main__':
    if len(sys.argv) == 1:
        usage()
        sys.exit(0)
    calculator = Calculator()
    cluster_file = ''
    ignore_words = set()
    tochar = False
    verbose = 1
    padding_symbol = ' '
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None

    # Everything in front of the last two arguments is treated as a switch.
    while len(sys.argv) > 3:
        switch = sys.argv.pop(1)
        name, has_value, value = switch.partition('=')
        if not has_value:
            # Not of the form --name=value: ignore invalid switch.
            continue
        if name == '--maxw':
            max_words_per_line = int(value)
        elif name == '--rt':
            remove_tag = _flag_enabled(value)
        elif name == '--cs':
            case_sensitive = _flag_enabled(value)
        elif name == '--cluster':
            cluster_file = value
        elif name == '--splitfile':
            split = dict()
            with open(value, 'r', encoding='utf-8') as fh:
                for line in fh:
                    words = line.strip().split()
                    if len(words) >= 2:
                        split[words[0]] = words[1:]
        elif name == '--ig':
            with open(value, 'r', encoding='utf-8') as fh:
                for line in fh:
                    line = line.strip()
                    if len(line) > 0:
                        ignore_words.add(line)
        elif name == '--char':
            tochar = _flag_enabled(value)
        elif name == '--v':
            try:
                verbose = int(value)
            except ValueError:
                verbose = 1 if _flag_enabled(value) else 0
        elif name == '--padding-symbol':
            symbol = value.lower()
            if symbol == 'space':
                padding_symbol = ' '
            elif symbol == 'underline':
                padding_symbol = '_'
        # Any other switch is silently ignored.

    if len(sys.argv) < 3:
        # Both the reference and the hypothesis file are required.
        usage()
        sys.exit(1)

    if not case_sensitive:
        ignore_words = {w.upper() for w in ignore_words}

    default_clusters = {}
    default_words = {}

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    rec_set = {}
    if split and not case_sensitive:
        split = {
            w.upper(): [part.upper() for part in parts]
            for w, parts in split.items()
        }

    with open(hyp_file, 'r', encoding='utf-8') as fh:
        for line in fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
                                     split)

    # compute error rate on the intersection of reference file and hyp file
    with open(ref_file, 'r', encoding='utf-8') as fh:
        for line in fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.rstrip('\n').split()
            if len(array) == 0:
                continue
            fid = array[0]
            if fid not in rec_set:
                continue
            lab = normalize(array[1:], ignore_words, case_sensitive, split)
            rec = rec_set[fid]
            if verbose:
                print('\nutt: %s' % fid)

            # Remember the script cluster of every word seen so far.
            for word in rec + lab:
                if word not in default_words:
                    default_cluster_name = default_cluster(word)
                    default_clusters.setdefault(default_cluster_name,
                                                {})[word] = 1
                    default_words[word] = default_cluster_name

            result = calculator.calculate(lab, rec)
            if verbose:
                print('WER: %4.2f %%' % _wer_percent(result), end=' ')
                _print_counts(result)
                # Padding needed so that aligned tokens share a column width.
                space = {'lab': [], 'rec': []}
                for lab_token, rec_token in zip(result['lab'], result['rec']):
                    len_lab = width(lab_token)
                    len_rec = width(rec_token)
                    length = max(len_lab, len_rec)
                    space['lab'].append(length - len_lab)
                    space['rec'].append(length - len_rec)
                upper_lab = len(result['lab'])
                upper_rec = len(result['rec'])
                lab1, rec1 = 0, 0
                while lab1 < upper_lab or rec1 < upper_rec:
                    # fid is already str; encoding it would print b'...'.
                    lab_label = 'lab(%s):' % fid if verbose > 1 else 'lab:'
                    lab2 = min(upper_lab, lab1 + max_words_per_line)
                    print(_format_row(lab_label, result['lab'][lab1:lab2],
                                      space['lab'][lab1:lab2], padding_symbol))
                    rec_label = 'rec(%s):' % fid if verbose > 1 else 'rec:'
                    rec2 = min(upper_rec, rec1 + max_words_per_line)
                    # The rec row is followed by one blank line.
                    print(_format_row(rec_label, result['rec'][rec1:rec2],
                                      space['rec'][rec1:rec2], padding_symbol)
                          + '\n')
                    lab1 = lab2
                    rec1 = rec2

    if verbose:
        print(
            '==========================================================================='
        )
        print()

    result = calculator.overall()
    print('Overall -> %4.2f %%' % _wer_percent(result), end=' ')
    _print_counts(result)
    if not verbose:
        print()

    if verbose:
        for cluster_id in default_clusters:
            result = calculator.cluster(list(default_clusters[cluster_id]))
            print('%s -> %4.2f %%' % (cluster_id, _wer_percent(result)),
                  end=' ')
            _print_counts(result)
        if len(cluster_file) > 0:  # compute separated WERs for word clusters
            cluster_id = ''
            cluster = []
            with open(cluster_file, 'r', encoding='utf-8') as fh:
                for line in fh:
                    # Lines are already str in text mode; no decode needed.
                    for token in line.rstrip('\n').split():
                        # end of cluster reached, like </Keyword>
                        if token[0:2] == '</' and token[-1] == '>' and \
                                token.lstrip('</').rstrip('>') == cluster_id:
                            result = calculator.cluster(cluster)
                            print('%s -> %4.2f %%' %
                                  (cluster_id, _wer_percent(result)), end=' ')
                            _print_counts(result)
                            cluster_id = ''
                            cluster = []
                        # begin of cluster reached, like <Keyword>
                        elif token[0] == '<' and token[-1] == '>' and \
                                cluster_id == '':
                            cluster_id = token.lstrip('<').rstrip('>')
                            cluster = []
                        # general terms, like WEATHER / CAR / ...
                        else:
                            cluster.append(token)
        print()
        print(
            '==========================================================================='
        )
print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + + +def _project_root() -> Path: + return Path(__file__).parent.parent.parent + + +def _ensure_speechbrain_on_path() -> None: + """SpeechBrain is provided by the DataMate runtime environment.""" + return None + + +def _patch_yaml_loader_max_depth() -> None: + """兼容部分 PyYAML/HyperPyYAML 组合缺失 Loader.max_depth 的问题。""" + try: + import yaml # type: ignore + + for name in ("Loader", "SafeLoader", "FullLoader", "UnsafeLoader"): + loader = getattr(yaml, name, None) + if loader is not None and not hasattr(loader, "max_depth"): + setattr(loader, "max_depth", 1000) + except Exception: + pass + try: + import ruamel.yaml # type: ignore + + for name in ("Loader", "SafeLoader", "RoundTripLoader", "BaseLoader"): + loader = getattr(ruamel.yaml, name, None) + if loader is not None and not hasattr(loader, "max_depth"): + setattr(loader, "max_depth", 1000) + except Exception: + pass + + +def _find_audio_files(audio_dir: Path) -> List[Path]: + patterns = ["*.wav", "*.WAV", "*.flac", "*.FLAC", "*.mp3", "*.MP3", "*.aac", "*.AAC", "*.m4a", "*.M4A"] + files: List[Path] = [] + for pat in patterns: + files.extend(audio_dir.rglob(pat)) + return sorted(set(files)) + + +def _load_jsonl_items(path: Path, filter_ok_only: bool = False) -> List[Dict]: + items: List[Dict] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + items.append(json.loads(line)) + + if not filter_ok_only: + return items + + filtered = [it for it in items if it.get("quality_flag", "ok") == "ok"] + if not items: + return items + print_info(f"质量过滤后保留 {len(filtered)}/{len(items)} 条,仅识别 quality_flag=='ok' 的音频") + return filtered + + +def _dump_jsonl_items(path: Path, items: Iterable[Dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + 
for it in items: + f.write(json.dumps(it, ensure_ascii=False) + "\n") + + +def _iso_to_zh_en(lid_label: str) -> str: + """ + 将 LID 模型输出映射为仅两种:zh(中文)或 en(英文)。 + 模型可能返回 "en: English"、"zh: Chinese" 等,取冒号前作为语言码再判断。 + 中文相关 ISO 码映射为 zh,其余一律为 en。 + """ + raw = (lid_label or "").strip() + if ":" in raw: + iso = raw.split(":", 1)[0].strip().lower() + else: + iso = raw.lower() + zh_aliases = {"zh", "cmn", "yue", "wuu", "nan", "cdo", "cjy", "hsn", "hak"} + if iso in zh_aliases: + return "zh" + return "en" + + +def _out_item(it: Dict, lang: str) -> Dict: + """只保留 key、wav、txt、lang 四列,供输出 jsonl 使用。""" + return { + "key": it.get("key", ""), + "wav": it.get("wav") or it.get("audio") or it.get("path", ""), + "txt": it.get("txt", ""), + "lang": lang, + } + + +def _batch_iter(xs: List[Dict], batch_size: int) -> Iterable[List[Dict]]: + for i in range(0, len(xs), batch_size): + yield xs[i : i + batch_size] + + +def _lid_predict_items( + items: List[Dict], + model_source: str, + model_savedir: Path, + device: str, + batch_size: int, + max_seconds: float, +) -> List[Dict]: + _ensure_speechbrain_on_path() + _patch_yaml_loader_max_depth() + + # 这里延迟导入,避免只跑 --help 时加载 torch/torchaudio + import torch # type: ignore + from types import SimpleNamespace + + # 兼容旧版 torch:SpeechBrain 可能会引用 torch.amp.custom_fwd/custom_bwd + # - torch>=2.0: torch.amp.custom_fwd/custom_bwd(支持 device_type 等参数) + # - torch<2.0: torch.cuda.amp.custom_fwd/custom_bwd(签名可能更旧,不支持 device_type) + try: + has_amp = hasattr(torch, "amp") + has_custom_fwd = has_amp and hasattr(torch.amp, "custom_fwd") + has_custom_bwd = has_amp and hasattr(torch.amp, "custom_bwd") + if not (has_custom_fwd and has_custom_bwd): + try: + from torch.cuda.amp import custom_fwd as _custom_fwd # type: ignore + from torch.cuda.amp import custom_bwd as _custom_bwd # type: ignore + except Exception: + # 退化为 no-op 装饰器(不启用 AMP 也能推理) + def _custom_fwd(*_args, **_kwargs): # type: ignore + def _decorator(fn): + return fn + + return _decorator + + def 
_custom_bwd(*_args, **_kwargs): # type: ignore + def _decorator(fn): + return fn + + return _decorator + + if not hasattr(torch, "amp"): + torch.amp = SimpleNamespace() # type: ignore[attr-defined] + + def _drop_unsupported_kwargs(deco): # type: ignore + def _wrapped(*args, **kwargs): + # 旧版 deco 可能不支持 device_type 等 kwargs;这里直接丢弃所有 kwargs + # 保证能作为装饰器正常使用 + return deco(*args) + + return _wrapped + + torch.amp.custom_fwd = _drop_unsupported_kwargs(_custom_fwd) # type: ignore[attr-defined] + torch.amp.custom_bwd = _drop_unsupported_kwargs(_custom_bwd) # type: ignore[attr-defined] + except Exception: + # 不让兼容逻辑影响主流程;真正的导入错误会在后面暴露 + pass + + from speechbrain.inference.classifiers import EncoderClassifier # type: ignore + + # 使用本地目录:/abs/path/to/model_dir + src_path = Path(model_source) + is_local_dir = src_path.exists() and src_path.is_dir() + resolved_source = str(src_path.resolve()) if is_local_dir else model_source + + overrides = {} + if is_local_dir: + # hyperparams.yaml 里的 pretrained_path 可能不是本地路径,这里强制指向本地目录。 + overrides = {"pretrained_path": resolved_source} + + # 预先检查必需权重是否存在,避免长时间卡在 fetch/重试 + required = ["hyperparams.yaml", "label_encoder.txt", "embedding_model.ckpt", "classifier.ckpt"] + missing = [fn for fn in required if not (src_path / fn).exists()] + if missing: + raise RuntimeError( + "本地 LID 模型目录不完整,缺少必要文件:\n" + + "\n".join([f"- {src_path / fn}" for fn in missing]) + + "\n\n请检查本地模型目录是否完整。" + ) + try: + classifier = EncoderClassifier.from_hparams( + source=resolved_source, + savedir=str(model_savedir), + run_opts={"device": device}, + overrides=overrides, + ) + except Exception as e: + raise RuntimeError( + "加载 SpeechBrain LID 模型失败。\n" + f"- source={model_source}\n" + f"- savedir={model_savedir}\n" + f"- device={device}\n" + f"- error={type(e).__name__}: {e}" + ) from e + + out_items: List[Dict] = [] + total = len(items) + done = 0 + + for batch in _batch_iter(items, batch_size): + wav_tensors: List[torch.Tensor] = [] + wav_lens: List[float] = [] + 
ok_mask: List[bool] = [] + + for it in batch: + wav_path = it.get("wav") or it.get("audio") or it.get("path") + if not wav_path: + ok_mask.append(False) + continue + try: + sig = classifier.load_audio(str(wav_path)) + # sig: [time] 或 [channels, time],speechbrain load_audio 通常返回 [time] + if sig.ndim > 1: + sig = sig.mean(dim=0) + if max_seconds > 0: + max_samples = int(16000 * max_seconds) + sig = sig[:max_samples] + if sig.numel() == 0: + ok_mask.append(False) + continue + wav_tensors.append(sig) + wav_lens.append(float(sig.shape[0])) + ok_mask.append(True) + except Exception: + ok_mask.append(False) + + if not wav_tensors: + for it in batch: + out_items.append(_out_item(it, "en")) + done += len(batch) + continue + + max_len = max(int(x.shape[0]) for x in wav_tensors) + padded = torch.zeros((len(wav_tensors), max_len), dtype=torch.float32) + lens_rel = torch.zeros((len(wav_tensors),), dtype=torch.float32) + for i, sig in enumerate(wav_tensors): + L = int(sig.shape[0]) + padded[i, :L] = sig.float() + lens_rel[i] = float(L) / float(max_len) if max_len > 0 else 1.0 + + with torch.inference_mode(): + out_prob, score, index, text_lab = classifier.classify_batch(padded, lens_rel) + + pred_i = 0 + for it, ok_ in zip(batch, ok_mask): + if not ok_: + out_items.append(_out_item(it, "en")) + else: + lid_label = str(text_lab[pred_i]) if isinstance(text_lab, list) else str(text_lab) + lang = _iso_to_zh_en(lid_label) + out_items.append(_out_item(it, lang)) + pred_i += 1 + + done += len(batch) + if done % max(10, batch_size) == 0 or done == total: + print_info(f"LID 进度: {done}/{total}") + + return out_items + + +def parse_arguments(): + default_models_dir = _project_root() / "models" / "lid" + default_local_model_dir = default_models_dir / "speechbrain_lang-id-voxlingua107-ecapa" + default_savedir = default_models_dir / "_speechbrain_cache" / "lang-id-voxlingua107-ecapa" + default_audio_dir = _project_root() / "output_data" / "denoise" + default_quality_list = _project_root() / 
def main() -> int:
    """Run the fast zh/en language-identification command-line tool.

    Input items are chosen in this order:

    1. ``--input_list`` when given (optionally keeping only entries whose
       ``quality_flag`` is ``"ok"`` when ``--filter-audio`` is enabled);
    2. the quality list (``--filter-audio-list``) when filtering is enabled
       and that file exists;
    3. otherwise a scan of ``--audio_dir`` (parser default:
       ``output_data/denoise``).

    Returns:
        0 on success or when there is nothing to process, 1 when an input
        path is missing or inference fails.
    """
    args = parse_arguments()
    print_header("快速语言识别(LID)")

    output_path = Path(args.output).resolve()
    model_savedir = Path(args.model_savedir).resolve()
    filter_audio = str(args.filter_audio).lower() in {"1", "true", "yes", "y", "on"}
    filter_audio_list = Path(args.filter_audio_list).resolve()

    items: List[Dict]
    if args.input_list:
        input_path = Path(args.input_list).resolve()
        if not input_path.exists():
            print_error(f"输入列表不存在: {input_path}")
            return 1
        print_info(f"输入列表: {input_path}")
        items = _load_jsonl_items(input_path)
        if filter_audio:
            # Entries without a quality_flag are treated as usable.
            items = [it for it in items if it.get("quality_flag", "ok") == "ok"]
    elif filter_audio and filter_audio_list.exists():
        print_info(f"启用质量过滤,读取列表: {filter_audio_list}")
        items = _load_jsonl_items(filter_audio_list, filter_ok_only=True)
    else:
        # Single directory-scan path, shared by "filtering disabled" and
        # "filtering requested but the quality list is missing".
        if filter_audio:
            print_warning(f"质量过滤列表不存在,回退为扫描目录: {filter_audio_list}")
        audio_dir = Path(args.audio_dir).resolve()
        if not audio_dir.exists():
            print_error(f"音频目录不存在: {audio_dir}")
            return 1
        print_info(f"扫描目录: {audio_dir}")
        audio_files = _find_audio_files(audio_dir)
        if not audio_files:
            print_warning("未找到任何音频文件")
            return 0
        items = [{"key": p.stem, "wav": str(p.resolve()), "txt": ""} for p in audio_files]

    if not items:
        print_warning("输入为空,退出")
        return 0

    print_info(f"待识别音频数: {len(items)}")
    print_info(f"模型: {args.model_source}")
    print_info(f"模型缓存目录: {model_savedir}")
    print_info(f"device={args.device}, batch_size={args.batch_size}, max_seconds={args.max_seconds}")

    try:
        out_items = _lid_predict_items(
            items=items,
            model_source=args.model_source,
            model_savedir=model_savedir,
            device=args.device,
            batch_size=max(1, int(args.batch_size)),
            max_seconds=float(args.max_seconds),
        )
    except Exception as e:
        # Top-level CLI boundary: report the failure and exit non-zero.
        print_error(f"LID 推理失败: {e}")
        print_error("traceback:\n" + traceback.format_exc())
        return 1

    _dump_jsonl_items(output_path, out_items)
    print_success(f"完成!输出: {output_path}")

    # Per-language tally; a missing "lang" field counts as "en".
    stat: Dict[str, int] = {"zh": 0, "en": 0}
    for it in out_items:
        lang = str(it.get("lang", "en"))
        stat[lang] = stat.get(lang, 0) + 1
    print_info(f"统计: zh={stat.get('zh', 0)}, en={stat.get('en', 0)}")

    return 0
def find_wav_files(audio_dir: Path) -> List[Path]:
    """Recursively collect the ``.wav`` files below *audio_dir*.

    The suffix is compared case-insensitively in a single directory walk.
    Globbing ``*.wav`` and ``*.WAV`` separately returns every file twice on
    case-insensitive filesystems and misses mixed-case suffixes such as
    ``.Wav``. Only regular files are returned.

    Args:
        audio_dir: Directory to scan (sub-directories included).

    Returns:
        Sorted list of matching paths; empty when the directory does not
        exist (an error is printed in that case).
    """
    if not audio_dir.exists():
        print_error(f"音频文件夹不存在: {audio_dir}")
        return []

    return sorted(
        candidate
        for candidate in audio_dir.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() == ".wav"
    )
open(output_file, 'w', encoding='utf-8') as f: + for item in items: + json_line = json.dumps(item, ensure_ascii=False) + f.write(json_line + "\n") + + print_ok(f"已生成索引表: {output_file}") + print_info(f"共写入 {len(items)} 条记录") + + + return len(items) + + except Exception as e: + print_error(f"写入文件失败: {e}") + return 0 + + +def parse_arguments(): + """解析命令行参数""" + # 获取默认音频文件夹 + default_audio_dir = get_default_audio_dir() + + parser = argparse.ArgumentParser( + description="生成音频文件索引表工具", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + %(prog)s # 使用默认配置 + %(prog)s --audio_dir ./my_audio --output ./my_list.txt + %(prog)s --audio_dir ./audio --key_prefix sample_ + %(prog)s --audio_dir ./wavs --output ./index.jsonl --key_prefix audio_ + """ + ) + + parser.add_argument( + "--audio_dir", + "-a", + default=str(default_audio_dir), + help=f"音频文件夹路径,默认: {default_audio_dir}" + ) + + parser.add_argument( + "--output", + "-o", + default=None, + help="输出列表文件路径,默认: {音频文件夹}/item.list" + ) + + parser.add_argument( + "--key_prefix", + "-k", + default=None, + help="键值前缀,例如 'audio_' 会生成 'audio_0', 'audio_1', ..." 
def main():
    """Command-line entry point: build the audio index for one directory.

    Returns 1 when the audio directory is missing or no record was written,
    0 when the index was generated or the directory holds no ``.wav`` file.
    """
    args = parse_arguments()

    print_header("生成音频索引")

    # Relative paths are accepted and normalised here.
    audio_dir = Path(args.audio_dir).resolve()
    if not audio_dir.exists():
        print_error(f"指定的音频文件夹不存在: {audio_dir}")
        print_info("请确保路径正确或先运行音频归一化处理")
        return 1

    print_info(f"音频文件夹: {audio_dir}")

    # Default output lives next to the audio files.
    output_file = Path(args.output).resolve() if args.output else audio_dir / "item.list"
    print_info(f"输出文件: {output_file}")

    # An empty directory is a clean exit, not a failure.
    if not find_wav_files(audio_dir):
        print_warning("未找到任何.wav文件,程序退出")
        return 0

    print_info("开始生成索引表...")
    item_count = generate_item_list(audio_dir, output_file, args.key_prefix)

    if item_count <= 0:
        print_warning("索引表生成失败或未生成任何记录")
        return 1

    print_success(f"索引表生成完成!共生成 {item_count} 条记录")
    print_info(f"文件保存在: {output_file}")
    return 0
def load_audio_mono_16k(path: Path) -> np.ndarray:
    """Read an audio file and return it as 16 kHz mono float32 samples.

    Multi-channel audio is averaged over the channel axis. When the file's
    sample rate is not 16000, the signal is resampled by linear
    interpolation. This is a deliberately simple resampler for denoising
    pre-processing; it applies no low-pass filtering of its own.

    Args:
        path: Audio file readable by ``soundfile``.

    Returns:
        One-dimensional float32 array. An empty input stays empty.
    """
    sf, torch = _import_audio_backend()
    data, sr = sf.read(str(path), always_2d=False)
    if data.ndim > 1:
        # soundfile returns (frames, channels); average channels to mono.
        data = np.mean(data, axis=1)
    data = data.astype(np.float32)
    # Skip resampling for empty signals: interpolate() rejects a zero-length
    # input as well as a zero-length target.
    if sr != 16000 and data.size > 0:
        wav = torch.from_numpy(data).float()[None, None, :]
        # Keep at least one sample so very short clips do not round to 0.
        new_len = max(1, int(round(wav.shape[-1] * 16000.0 / float(sr))))
        wav = torch.nn.functional.interpolate(wav, size=new_len, mode="linear", align_corners=False)
        data = wav[0, 0].cpu().numpy()
    return data.astype(np.float32)
def istft_complex(spec: np.ndarray, n_fft: int = 512, hop_length: int = 256, win_length: int = 512):
    """Turn a real/imag spectrogram back into a waveform.

    Args:
        spec: Array shaped ``(1, F, T, 2)`` or ``(F, T, 2)``; the last axis
            holds the real and imaginary parts.
        n_fft: FFT size.
        hop_length: Hop between frames, in samples.
        win_length: Window length, in samples.

    Returns:
        The reconstructed waveform as a float32 NumPy array.
    """
    _unused_sf, torch = _import_audio_backend()
    # Drop the leading batch axis when present.
    frames = spec[0] if spec.ndim == 4 else spec
    # view_as_complex needs a contiguous tensor whose last dim is 2.
    as_pairs = torch.from_numpy(frames).float().contiguous()
    complex_spec = torch.view_as_complex(as_pairs)
    # Square-root Hann window.
    synthesis_window = torch.hann_window(win_length).pow(0.5)
    waveform = torch.istft(
        complex_spec,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=synthesis_window,
        center=True,
    )
    return waveform.cpu().numpy().astype(np.float32)
(1, F, T, 2) + outputs = [] + conv_cache = self.conv_cache.copy() + tra_cache = self.tra_cache.copy() + inter_cache = self.inter_cache.copy() + + # 按时间帧逐帧推理 + for i in range(spec.shape[2]): + mix = spec[:, :, i:i+1, :].astype(np.float32) + out_i, conv_cache, tra_cache, inter_cache = self.session.run( + [], + { + "mix": mix, + "conv_cache": conv_cache, + "tra_cache": tra_cache, + "inter_cache": inter_cache, + }, + ) + outputs.append(out_i) + + out_spec = np.concatenate(outputs, axis=2) # (1, F, T, 2) + wav_out = istft_complex(out_spec) + return wav_out + + +def _resolve_model(model: Path, export_dir: Optional[Path] = None) -> Path: + """ + 解析模型路径: + - 如果是 .onnx,直接返回 + - 如果是 .tar/.pt,可选导出为 ONNX(需要你本地提供训练权重) + """ + if model.suffix.lower() == ".onnx": + return model + if model.suffix.lower() in {".tar", ".pt", ".pth"}: + raise RuntimeError("算子不再打包 GTCRN 源码,请预先导出 ONNX 并把 --model 指向 .onnx 文件。") + raise ValueError(f"不支持的模型格式: {model.suffix}") + + +def process_one(input_file: Path, output_file: Path, denoiser: OnnxGtcrnDenoiser) -> None: + sf, _ = _import_audio_backend() + wav = load_audio_mono_16k(input_file) + enhanced = denoiser.denoise(wav) + output_file.parent.mkdir(parents=True, exist_ok=True) + sf.write(str(output_file), enhanced, 16000) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="GTCRN 本地智能降噪工具(优先 ONNXRuntime)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + # 单文件降噪(ONNX 模型) + python -m src.utils.gtcrn_denoise --input ./a.wav --model ./models/gtcrn/gtcrn.onnx --output ./out.wav + + # 目录批处理 + python -m src.utils.gtcrn_denoise --input ./input_dir --model ./models/gtcrn/gtcrn.onnx --output ./denoised_dir + + # 如果你手里是 .tar/.pt 权重,可尝试导出 ONNX(需要本地可加载权重) + python -m src.utils.gtcrn_denoise --input ./a.wav --model ./weights/model_trained_on_dns3.tar --export_dir ./models/gtcrn_onnx --output ./out.wav + """, + ) + parser.add_argument("--input", required=True, help="输入音频文件或目录") + parser.add_argument("--model", 
required=True, help="GTCRN 模型路径(.onnx/.tar/.pt/.pth)") + parser.add_argument("--output", required=True, help="输出 wav 文件或目录") + parser.add_argument("--export_dir", default=None, help="若输入为 .tar/.pt,则导出 ONNX 的目录") + args = parser.parse_args() + + input_path = Path(args.input).resolve() + model_path = Path(args.model).resolve() + output_path = Path(args.output).resolve() + export_dir = Path(args.export_dir).resolve() if args.export_dir else None + + print_header("GTCRN 智能降噪") + print_info(f"输入: {input_path}") + print_info(f"模型: {model_path}") + print_info(f"输出: {output_path}") + + try: + resolved_model = _resolve_model(model_path, export_dir=export_dir) + print_info(f"使用模型: {resolved_model}") + denoiser = OnnxGtcrnDenoiser(resolved_model) + except Exception as e: + print_error(f"初始化失败: {e}") + return 1 + + files = _find_audio_files(input_path) + if not files: + print_warning("未找到可处理的音频文件") + return 0 + + try: + if input_path.is_file(): + if output_path.suffix.lower() != ".wav": + output_path = output_path.with_suffix(".wav") + process_one(files[0], output_path, denoiser) + print_success(f"完成: {output_path}") + else: + output_path.mkdir(parents=True, exist_ok=True) + for f in files: + out_file = output_path / f"{f.stem}.wav" + print_info(f"降噪: {f.name} -> {out_file.name}") + process_one(f, out_file, denoiser) + print_success(f"批量完成,输出目录: {output_path}") + except Exception as e: + print_error(f"处理失败: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/recognize.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/recognize.py new file mode 100644 index 00000000..f0f666ea --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/recognize.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +""" +语音识别脚本 +调用 WeNet 模型进行音频转文本识别 +支持中文和英文,自动选择设备 +""" + +import argparse +import json +import subprocess +import sys +import 
threading +import queue +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +# 当前在 src/utils,同目录导入 color_utils(相对路径以项目根为基准) +_SCRIPT_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(_SCRIPT_DIR)) + +try: + from color_utils import ( + info, warning, error, ok, success, header + ) + + def print_info(msg: str): + print(info(msg)) + + def print_warning(msg: str): + print(warning(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + +except ImportError: + def print_info(msg: str): + print(f"[INFO] {msg}") + + def print_warning(msg: str): + print(f"[WARNING] {msg}") + + def print_error(msg: str): + print(f"[ERROR] {msg}") + + def print_ok(msg: str): + print(f"[OK] {msg}") + + def print_success(msg: str): + print(f"[SUCCESS] {msg}") + + def print_header(msg: str): + print(f"=== {msg} ===") + + +def get_project_root() -> Path: + """项目根目录(src/utils -> src -> 根)。""" + return Path(__file__).resolve().parent.parent.parent + + +def check_npu_available() -> bool: + try: + import torch_npu + return True + except ImportError: + npu_devices = list(Path("/dev").glob("davinci*")) + return len(npu_devices) > 0 + + +def get_default_paths() -> dict: + project_root = get_project_root() + model_root = Path("/models/AudioOperations/asr") + return { + 'audio_list': project_root / "output_data" / "normalization" / "item.list", + 'result_dir': project_root / "output_data" / "asr", + 'wenet_wrapper': project_root / "src" / "utils" / "run_wenet.py", + 'aishell_model': model_root / "aishell" / "final.pt", + 'librispeech_model': model_root / "librispeech" / "final.pt", + } + + +def resolve_device(device_arg: str) -> str: + if device_arg == "auto": + if check_npu_available(): + print_info("检测到 NPU 设备,使用 NPU") + return "npu" + else: + print_info("未检测到 NPU 设备,使用 CPU") + return "cpu" + elif device_arg == 
def prepare_config(language: str, model_root: Optional[Path] = None) -> str:
    """Locate the YAML config for the given language's model.

    ``train.yaml`` is preferred. Otherwise the alphabetically first
    ``*.yaml`` file is used, so the choice no longer depends on the order
    in which the filesystem lists the directory.

    Args:
        language: ``"zh"`` (aishell) or ``"en"`` (librispeech).
        model_root: Directory holding the per-language model folders.
            Defaults to ``/models/AudioOperations/asr``.

    Returns:
        Path of the selected config file, as a string.

    Raises:
        ValueError: If *language* is not supported.
        FileNotFoundError: If the model directory has no ``*.yaml`` file.
    """
    root = Path("/models/AudioOperations/asr") if model_root is None else Path(model_root)
    if language == "zh":
        model_dir = root / "aishell"
    elif language == "en":
        model_dir = root / "librispeech"
    else:
        raise ValueError(f"不支持的语言: {language}")

    # Sorted so the fallback below is deterministic.
    yaml_files = sorted(model_dir.glob("*.yaml"))
    if not yaml_files:
        raise FileNotFoundError(f"在 {model_dir} 中未找到 YAML 配置文件")

    preferred = next((f for f in yaml_files if f.name == "train.yaml"), yaml_files[0])
    return str(preferred)
+ if language == "zh": + model_file = str(paths['aishell_model']) + model_name = "AIShell (中文)" + elif language == "en": + model_file = str(paths['librispeech_model']) + model_name = "LibriSpeech (英文)" + else: + raise ValueError(f"不支持的语言: {language}") + actual_device = resolve_device(device) + cmd = [ + sys.executable, + str(paths['wenet_wrapper']), + "--mode", "ctc_greedy_search", + "--device", actual_device, + "--config", config_file, + "--test_data", str(paths['audio_list']), + "--checkpoint", model_file, + "--batch_size", "1", + "--result_dir", str(paths['result_dir']), + ] + print_header("语音识别配置") + print_info(f"语言: {language} ({model_name})") + print_info(f"设备: {actual_device}") + print_info(f"音频列表: {paths['audio_list']}") + print_info(f"结果目录: {paths['result_dir']}") + print_info(f"配置文件: {Path(config_file).name}") + print_info(f"模型文件: {Path(model_file).name}") + try: + with open(paths['audio_list'], 'r', encoding='utf-8') as f: + audio_count = sum(1 for _ in f) + print_info(f"音频数量: {audio_count}") + except Exception as e: + print_warning(f"无法统计音频数量: {e}") + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding='utf-8', + bufsize=1, + universal_newlines=True + ) + output_queue = queue.Queue() + stdout_thread = threading.Thread(target=read_output, args=(process.stdout, output_queue, 'stdout')) + stderr_thread = threading.Thread(target=read_output, args=(process.stderr, output_queue, 'stderr')) + stdout_thread.daemon = True + stderr_thread.daemon = True + stdout_thread.start() + stderr_thread.start() + while True: + try: + stream_name, line = output_queue.get(timeout=0.1) + print(line) + except queue.Empty: + if process.poll() is not None: + try: + while True: + stream_name, line = output_queue.get_nowait() + print(line) + except queue.Empty: + pass + break + return_code = process.wait() + stdout_thread.join(timeout=1) + stderr_thread.join(timeout=1) + print("-" * 80) + if return_code == 0: + 
def create_wenet_wrapper(wrapper_path: Path):
    """Write a small executable script that forwards to WeNet's recognizer.

    The generated script imports ``wenet.bin.recognize.main`` and calls it,
    exiting with status 1 when WeNet cannot be imported. Parent directories
    are created as needed and the file is made executable (mode 0o755).

    Args:
        wrapper_path: Where the wrapper script is written.
    """
    wrapper_content = '''#!/usr/bin/env python3
"""运行 WeNet 识别脚本的包装器"""
import sys
from pathlib import Path

def main():
    try:
        from wenet.bin.recognize import main as wenet_main
        wenet_main()
    except ImportError as e:
        print(f"[ERROR] 无法导入 WeNet 模块: {e}")
        sys.exit(1)

if __name__ == "__main__":
    main()
'''
    wrapper_path.parent.mkdir(parents=True, exist_ok=True)
    with open(wrapper_path, 'w', encoding='utf-8') as f:
        f.write(wrapper_content)
    wrapper_path.chmod(0o755)
    print_info(f"已创建 WeNet 包装器脚本: {wrapper_path}")
default="npu", help="设备") + args = parser.parse_args() + print_header("语音识别") + try: + import torch + print_info(f"PyTorch 版本: {torch.__version__}") + except ImportError: + print_error("未安装 PyTorch,请先安装") + return 1 + wenet_wrapper = defaults['wenet_wrapper'] + if not wenet_wrapper.exists(): + print_warning(f"WeNet 包装器不存在,尝试创建: {wenet_wrapper}") + create_wenet_wrapper(wenet_wrapper) + try: + return run_recognize( + language=args.language, + audio_list=args.audio_list, + result_dir=args.result_dir, + device=args.device + ) + except (ValueError, FileNotFoundError) as e: + print_error(str(e)) + return 1 + except Exception as e: + print_error(str(e)) + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/run_wenet.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/run_wenet.py new file mode 100644 index 00000000..d2ac2a44 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/run_wenet.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +"""Run WeNet recognition from the DataMate runtime environment.""" + +import sys + + +def main() -> None: + try: + from wenet.bin.recognize import main as wenet_main # type: ignore + except ImportError as exc: + print( + "[ERROR] Cannot import WeNet from the runtime environment. 
def pick_section(config: Dict[str, Any], section: Optional[str]) -> Dict[str, Any]:
    """Select the dict of parameters to use from a loaded YAML mapping.

    Resolution order:

    1. ``config[section]`` when *section* is given and that value is a dict;
    2. the sole value when *config* has exactly one key and it is a dict
       (e.g. a file wrapped in a single ``audio_config`` key);
    3. *config* itself.

    A shallow copy is always returned; an empty *config* yields ``{}``.
    """
    if not config:
        return {}

    chosen = config.get(section) if section else None
    if not isinstance(chosen, dict) and len(config) == 1:
        # Single top-level key: look inside its value instead.
        (chosen,) = config.values()

    return dict(chosen) if isinstance(chosen, dict) else dict(config)
def apply_yaml_defaults_to_parser(
    parser: argparse.ArgumentParser,
    cfg: Dict[str, Any],
) -> None:
    """Install config values as parser defaults for matching arguments.

    Only keys of *cfg* that equal the ``dest`` of an action registered on
    *parser* are applied; all other keys are ignored. Nothing is changed
    when no key matches.
    """
    # parser._actions is argparse-internal but is the only place the
    # registered destinations can be enumerated.
    known_dests = {
        action.dest
        for action in parser._actions  # noqa: SLF001
        if getattr(action, "dest", None)
    }
    matched: Dict[str, Any] = {key: cfg[key] for key in cfg if key in known_dests}
    if matched:
        parser.set_defaults(**matched)
# File extensions (lower-case, without the dot) treated as audio.
AUDIO_EXTS = {
    "aac",
    "aif",
    "aiff",
    "amr",
    "au",
    "flac",
    "m4a",
    "mp3",
    "oga",
    "ogg",
    "opus",
    "snd",
    "wav",
    "webm",
    "wma",
}

# String spellings accepted as "true" for flag-like values.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "y", "on"})


def _parts(path_value: str) -> set[str]:
    """Return the lower-cased components of *path_value* (empty set on bad input)."""
    try:
        return {part.lower() for part in Path(path_value).parts}
    except (TypeError, ValueError):
        return set()


def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool:
    """Tell whether the sample's path has a ``references`` component (any case)."""
    path_value = str(sample.get(filepath_key) or "")
    return "references" in _parts(path_value)


def _ext_from_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
) -> str:
    """Return the sample's effective extension, lower-case and without a dot.

    The first non-empty value of ``target_type`` then ``fileType`` wins; the
    suffix of the file path is used only when both are empty.
    """
    for key in (target_type_key, filetype_key):
        value = str(sample.get(key) or "").strip().lower().lstrip(".")
        if value:
            return value
    path_value = str(sample.get(filepath_key) or "").strip()
    return Path(path_value).suffix.lower().lstrip(".") if path_value else ""


def is_audio_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    data_key: str = "data",
) -> bool:
    """Decide whether an audio operator should process *sample*.

    Reference files are never audio. Otherwise a sample carrying non-empty
    bytes in ``data`` is accepted as-is, and anything else is judged by its
    effective extension against ``AUDIO_EXTS``.
    """
    if is_reference_sample(sample, filepath_key):
        return False
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        return True
    return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS


def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when the sample was flagged as invalid audio.

    Two sources are consulted, in order:

    1. A ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``; the text after the marker is the
       reason.
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"``,
       unless its ``skip_downstream`` switch is off.

    Returns:
        ``"invalid_audio_quality:<reason>"`` or ``""`` when the sample should
        be processed normally.
    """
    marker = "__quality_invalid"
    for key in ("fileName", "sourceFileName", "filePath"):
        marker_source = Path(str(sample.get(key) or "")).stem.lower()
        if marker in marker_source:
            reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""
    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        skip_downstream = skip_downstream.strip().lower() in _TRUTHY_STRINGS
    if not skip_downstream:
        return ""
    reason = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{reason}"


def _skip_registry(sample: Dict[str, Any], ext_params_key: str) -> Dict[str, Any]:
    """Return the sample's ``ext_params`` dict with a dict-valued ``audio_skip``.

    A non-dict ``ext_params`` or ``audio_skip`` value is preserved under a
    ``"_raw"`` key instead of causing a ``TypeError`` on item assignment.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}
    skipped = ext.get("audio_skip")
    if not isinstance(skipped, dict):
        ext["audio_skip"] = {} if skipped is None else {"_raw": skipped}
    sample[ext_params_key] = ext
    return ext


def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped *sample* and blank its payload fields.

    The reason is stored at ``ext_params["audio_skip"][op_name]``; ``text``,
    ``data``, ``fileType`` and ``target_type`` are reset to empty values.
    The sample is modified in place and also returned.
    """
    ext = _skip_registry(sample, ext_params_key)
    ext["audio_skip"][op_name] = reason
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
+language: 'python' +vendor: 'huawei' +raw_id: 'AudioAsrPipeline' +version: '1.0.0' +types: + - 'annotation' +modal: 'audio' +inputs: 'audio' +outputs: 'text' +settings: + doDenoise: + name: '启用降噪' + type: 'switch' + description: '是否启用 GTCRN 降噪。' + defaultVal: 'false' + required: false + checkedLabel: '开启' + unCheckedLabel: '关闭' + denoiseModelPath: + name: '降噪模型路径' + type: 'input' + description: 'GTCRN ONNX 模型绝对路径;默认使用固定部署路径 /models/AudioOperations/gtcrn/gtcrn.onnx。' + defaultVal: '/models/AudioOperations/gtcrn/gtcrn.onnx' + required: false + doAnomalyFilter: + name: '启用异常过滤' + type: 'switch' + description: '是否启用异常语音检测与过滤(时长/静音比例)。' + defaultVal: 'true' + required: false + checkedLabel: '开启' + unCheckedLabel: '关闭' + minDur: + name: '最小时长(秒)' + type: 'inputNumber' + defaultVal: 1.0 + min: 0 + max: 36000 + step: 0.1 + maxDur: + name: '最大时长(秒)' + type: 'inputNumber' + defaultVal: 20000.0 + min: 0 + max: 360000 + step: 1 + silenceRatioTh: + name: '静音帧比例阈值' + type: 'slider' + defaultVal: 0.8 + min: 0 + max: 1 + step: 0.01 + silenceRmsRatioTh: + name: '静音判定比例' + type: 'slider' + defaultVal: 0.05 + min: 0 + max: 1 + step: 0.01 + lidModelSource: + name: 'LID 模型源' + type: 'input' + description: 'SpeechBrain LID 本地模型目录。默认使用固定部署路径 /models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa。' + defaultVal: '/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa' + required: false + lidDevice: + name: 'LID 设备' + type: 'select' + defaultVal: 'cpu' + required: true + options: + - label: 'cpu' + value: 'cpu' + - label: 'cuda' + value: 'cuda' + - label: 'npu' + value: 'npu' + lidMaxSeconds: + name: 'LID 截断秒数' + type: 'inputNumber' + defaultVal: 3.0 + min: 0 + max: 60 + step: 0.5 + maxSegmentSeconds: + name: '切分最大秒数' + type: 'inputNumber' + defaultVal: 120 + min: 5 + max: 3600 + step: 1 + asrDevice: + name: 'ASR 设备' + type: 'select' + description: '传给 recognize_monitor 的 device 参数(npu/cpu/auto)。' + defaultVal: 'npu' + required: true + options: + - label: 'auto' + 
value: 'auto' + - label: 'cpu' + value: 'cpu' + - label: 'npu' + value: 'npu' + doKeywordRecall: + name: '启用关键词召回率' + type: 'switch' + description: '是否在 ASR 完成后计算中英文关键词召回率。' + defaultVal: 'false' + required: false + checkedLabel: '开启' + unCheckedLabel: '关闭' + referencePath: + name: '参考资源路径' + type: 'input' + description: '可填写数据集中的参考文件或参考目录路径;会写入 extraFilePath,供后续召回率/词错率评估自动读取。默认使用当前数据集 /dataset/{dataset_id}/references,目录中建议包含 zh_keyword.txt、en_keyword.txt、zh_transcript.txt、en_transcript.txt。若路径不存在会自动回退。' + defaultVal: '/dataset/{dataset_id}/references' + required: false + zhKeywordPath: + name: '中文关键词文件' + type: 'input' + description: 'Kaldi 格式中文关键词文件路径;默认指向当前数据集 references/zh_keyword.txt。若不存在,优先从 referencePath/extraFilePath 找 zh_keyword.txt。' + defaultVal: '/dataset/{dataset_id}/references/zh_keyword.txt' + required: false + enKeywordPath: + name: '英文关键词文件' + type: 'input' + description: 'Kaldi 格式英文关键词文件路径;默认指向当前数据集 references/en_keyword.txt。若不存在,优先从 referencePath/extraFilePath 找 en_keyword.txt。' + defaultVal: '/dataset/{dataset_id}/references/en_keyword.txt' + required: false + keepKeywordDetails: + name: '写入召回率逐句明细' + type: 'switch' + description: '是否将逐句 hit/miss 明细写入 ext_params.audio_asr.keyword_recall。报告文件始终包含明细并写入导出目录。' + defaultVal: 'false' + required: false + checkedLabel: '写入' + unCheckedLabel: '不写入' +runtime: + memory: 4294967296 + cpu: 1.0 + gpu: 0 + npu: 0 + storage: 1GB +metrics: + - name: '关键词召回率' + metric: '启用 doKeywordRecall 后由关键词文件与 ASR 结果计算' +release: + - '首次发布,支持音频标准化/降噪/过滤/LID/切分/ASR/合并' + - '新增可选中英文关键词召回率评估' diff --git a/runtime/ops/mapper/audio_asr_pipeline/process.py b/runtime/ops/mapper/audio_asr_pipeline/process.py new file mode 100644 index 00000000..69182ff9 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/process.py @@ -0,0 +1,555 @@ +# -- encoding: utf-8 -- + +import json +import os +import shutil +import tempfile +import time +from pathlib import Path +from typing import Dict, Any + +from loguru import logger + +from 
datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + +DEFAULT_GTCRN_MODEL_PATH = "/models/AudioOperations/gtcrn/gtcrn.onnx" +DEFAULT_LID_MODEL_SOURCE = "/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa" +DEFAULT_LID_MODEL_SAVEDIR = "/models/AudioOperations/lid/_speechbrain_cache" +DEFAULT_ASR_MODEL_ROOT = "/models/AudioOperations/asr" + + +def _as_bool(v: object) -> bool: + if isinstance(v, bool): + return v + s = str(v).strip().lower() + return s in {"1", "true", "yes", "y", "on"} + + +def _repo_root() -> Path: + return Path(__file__).resolve().parent + + +def _audio_preprocessor_root() -> Path: + return _repo_root() / "audio_preprocessor" + + +def _resolve_lid_model_source(value: str, ap_root: Path) -> str: + raw = str(value or "").strip() or DEFAULT_LID_MODEL_SOURCE + p = Path(raw).expanduser() + if p.exists(): + return str(p) + fallback = ap_root / "models" / "lid" / "speechbrain_lang-id-voxlingua107-ecapa" + if fallback.exists(): + return str(fallback) + return raw + + +def _ensure_sys_path(p: Path) -> None: + import sys + + sp = str(p) + if sp not in sys.path: + sys.path.insert(0, sp) + + +def _safe_stem(sample: Dict[str, Any], filename_key: str) -> str: + stem = Path(str(sample.get(filename_key) or "sample")).stem or "sample" + return "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in stem) + + +def _export_report_dir(sample: Dict[str, Any], export_path_key: str, filename_key: str) -> Path: + export_root = Path(str(sample.get(export_path_key) or "")).expanduser() + if not export_root: + export_root = Path.cwd() + if not export_root.is_absolute(): + export_root = (_repo_root() / export_root).resolve() + return export_root / "audio_reports" / "asr_pipeline" / _safe_stem(sample, filename_key) + + +def _extra_path(sample: Dict[str, Any]) -> Path 
| None: + value = str(sample.get("extraFilePath") or "").strip() + if not value: + return None + p = Path(value).expanduser() + if not p.is_absolute(): + p = (_repo_root() / p).resolve() + return p if p.exists() else None + + +def _expand_dataset_placeholders(path_value: str, sample: Dict[str, Any] | None = None) -> str: + value = str(path_value or "").strip() + if sample: + dataset_id = str(sample.get("dataset_id") or "").strip() + if dataset_id: + value = value.replace("{dataset_id}", dataset_id).replace("${dataset_id}", dataset_id) + value = value.replace("{datasetId}", dataset_id).replace("${datasetId}", dataset_id) + return value + + +def _resolve_optional_path(path_value: str, sample: Dict[str, Any] | None = None) -> Path: + path_value = _expand_dataset_placeholders(path_value, sample) + value = str(path_value or "").strip() + if not value: + return Path() + p = Path(value).expanduser() + if not p.is_absolute(): + p = (_repo_root() / p).resolve() + return p + + +def _find_named_file(root: Path | None, names: tuple[str, ...]) -> Path | None: + if root is None: + return None + if root.is_file(): + return root if root.name in names else None + for name in names: + p = root / name + if p.exists() and p.is_file(): + return p + for p in root.rglob("*"): + if p.is_file() and p.name in names: + return p + return None + + +def _valid_file_path(path: Path | None) -> bool: + return path is not None and str(path) not in {"", "."} and path.exists() and path.is_file() + + +class AudioAsrPipeline(Mapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.do_denoise = _as_bool(kwargs.get("doDenoise", False)) + self.denoise_model_path = str(kwargs.get("denoiseModelPath", DEFAULT_GTCRN_MODEL_PATH)).strip() + + self.do_anomaly_filter = _as_bool(kwargs.get("doAnomalyFilter", True)) + self.min_dur = float(kwargs.get("minDur", 1.0)) + self.max_dur = float(kwargs.get("maxDur", 20000.0)) + self.silence_ratio_th = float(kwargs.get("silenceRatioTh", 
def _run_module_main(module: Any, cli_args: list, label: str) -> None:
    """Call ``module.main()`` with ``sys.argv`` temporarily set to *cli_args*.

    The audio_preprocessor steps are invoked through their ``main``
    functions, so arguments are injected via ``sys.argv`` and the original
    value is restored afterwards, even on failure.

    Raises:
        RuntimeError: If ``main`` returns a non-zero code.
    """
    import sys  # local import: this module does not import sys at top level

    argv_backup = sys.argv[:]
    try:
        sys.argv = [sys.argv[0], *cli_args]
        rc = module.main()
    finally:
        sys.argv = argv_backup
    if rc != 0:
        raise RuntimeError(f"{label} 失败,返回码: {rc}")


class AudioAsrPipeline(Mapper):
    """Run the audio_preprocessor chain on one sample and return its transcript.

    Steps: normalization, optional GTCRN denoise, optional anomaly filter,
    language identification, splitting, ASR with merge, and optional keyword
    recall evaluation. The transcript is written to the sample's text field
    and run details to ``ext_params["audio_asr"]``.
    """

    def __init__(self, *args, **kwargs):
        """Read the operator settings from *kwargs* (camelCase keys)."""
        super().__init__(*args, **kwargs)

        self.do_denoise = _as_bool(kwargs.get("doDenoise", False))
        self.denoise_model_path = str(kwargs.get("denoiseModelPath", DEFAULT_GTCRN_MODEL_PATH)).strip()

        self.do_anomaly_filter = _as_bool(kwargs.get("doAnomalyFilter", True))
        self.min_dur = float(kwargs.get("minDur", 1.0))
        self.max_dur = float(kwargs.get("maxDur", 20000.0))
        self.silence_ratio_th = float(kwargs.get("silenceRatioTh", 0.8))
        self.silence_rms_ratio_th = float(kwargs.get("silenceRmsRatioTh", 0.05))

        self.lid_model_source = str(kwargs.get("lidModelSource", "")).strip()
        self.lid_device = str(kwargs.get("lidDevice", "cpu")).strip()
        self.lid_max_seconds = float(kwargs.get("lidMaxSeconds", 3.0))

        self.max_segment_seconds = int(float(kwargs.get("maxSegmentSeconds", 120)))
        self.asr_device = str(kwargs.get("asrDevice", "npu")).strip()

        self.do_keyword_recall = _as_bool(kwargs.get("doKeywordRecall", False))
        self.reference_path = str(kwargs.get("referencePath", "")).strip()
        self.zh_keyword_path = str(kwargs.get("zhKeywordPath", "")).strip()
        self.en_keyword_path = str(kwargs.get("enKeywordPath", "")).strip()
        self.keep_keyword_details = _as_bool(kwargs.get("keepKeywordDetails", False))

    def _skip(self, sample: Dict[str, Any], reason: str) -> Dict[str, Any]:
        """Mark *sample* as skipped by this operator for *reason*."""
        return mark_skipped_sample(
            sample,
            reason,
            self.__class__.__name__,
            self.text_key,
            self.data_key,
            self.filetype_key,
            self.target_type_key,
            self.ext_params_key,
        )

    def _ext_params(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Return the sample's ext_params as a dict, wrapping a non-dict value under ``_raw``."""
        ext = sample.get(self.ext_params_key, {})
        if not isinstance(ext, dict):
            ext = {"_raw": ext}
        sample[self.ext_params_key] = ext
        return ext

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe the audio file referenced by *sample*.

        Returns:
            The same sample dict. On success its text field holds the merged
            transcript and its file type fields are ``"txt"``. Samples that
            are flagged invalid, are not audio, or fail the anomaly filter
            come back with emptied payload fields.

        Raises:
            FileNotFoundError: Missing audio_preprocessor directory, ASR model
                root, input audio, GTCRN model, or keyword files.
            RuntimeError: A pipeline step returned a non-zero code or did not
                produce its expected output.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return self._skip(sample, quality_skip_reason)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return self._skip(sample, "non_audio_or_reference_file")

        ap_root = _audio_preprocessor_root()
        if not ap_root.exists():
            raise FileNotFoundError(f"audio_preprocessor 不存在: {ap_root}")
        _ensure_sys_path(_repo_root())

        asr_model_root = Path(DEFAULT_ASR_MODEL_ROOT).resolve()
        if not asr_model_root.exists():
            raise FileNotFoundError(f"ASR 模型根目录不存在: {asr_model_root}")

        # An empty path would resolve to the current directory and pass an
        # exists() check, so require a regular file.
        in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
        if not in_path.is_file():
            raise FileNotFoundError(f"输入音频不存在: {in_path}")

        # Path objects are always truthy, so "unset" is modelled as None
        # rather than relying on the Path() sentinel of the resolver.
        reference_path = None
        if self.reference_path:
            candidate = _resolve_optional_path(self.reference_path, sample)
            if str(candidate) not in {"", "."}:
                if candidate.exists():
                    reference_path = candidate
                else:
                    logger.warning(f"参考资源路径不存在,将继续使用已有 extraFilePath 或显式参考参数: {candidate}")
        if reference_path is not None:
            sample["extraFilePath"] = str(reference_path)
            sample["extraFileType"] = reference_path.suffix.lstrip(".") if reference_path.is_file() else "directory"

        persistent_validation = None

        # A temporary workspace isolates each sample so that the
        # audio_preprocessor's own output_data directory is not polluted.
        with tempfile.TemporaryDirectory(prefix="dm_audio_asr_") as td:
            work = Path(td)
            input_dir = work / "input_data" / "audio_raw"
            out_norm = work / "output_data" / "normalization"
            out_denoise = work / "output_data" / "denoise"
            out_lid = work / "output_data" / "lid"
            out_split = work / "output_data" / "split"
            out_asr = work / "output_data" / "asr"
            out_validation = work / "output_data" / "validation"
            models_link = work / "models"
            src_link = work / "src"

            for directory in (input_dir, out_norm, out_denoise, out_lid, out_split, out_asr, out_validation):
                directory.mkdir(parents=True, exist_ok=True)
            if not models_link.exists():
                models_link.symlink_to(asr_model_root.parent, target_is_directory=True)
            if not src_link.exists():
                src_link.symlink_to(ap_root / "src", target_is_directory=True)

            # Copy the input audio into the pipeline input directory.
            src_name = in_path.name
            local_in = input_dir / src_name
            shutil.copy2(str(in_path), str(local_in))

            # 1) normalization, pointed at our own input/output directories.
            _ensure_sys_path(ap_root / "scripts" / "audio_convert")
            _ensure_sys_path(ap_root / "src" / "utils")
            _ensure_sys_path(ap_root / "src" / "pipeline")

            from audio_preprocessor.src.pipeline import normalization as _norm  # type: ignore

            _run_module_main(
                _norm,
                ["--input_dir", str(input_dir), "--output_dir", str(out_norm), "--overwrite"],
                "normalization",
            )

            # Find the normalized file by stem; fall back to the first file
            # in the output directory.
            norm_candidates = sorted(out_norm.glob(f"{Path(src_name).stem}.*"))
            if not norm_candidates:
                norm_candidates = sorted([p for p in out_norm.iterdir() if p.is_file()])
            if not norm_candidates:
                raise RuntimeError(f"normalization 未生成输出: {out_norm}")
            norm_file = norm_candidates[0]

            current_audio_dir = out_norm
            current_audio = norm_file

            # 2) optional GTCRN denoise.
            if self.do_denoise:
                model = Path(self.denoise_model_path or DEFAULT_GTCRN_MODEL_PATH).expanduser().resolve()
                if not model.exists():
                    raise FileNotFoundError(f"GTCRN 模型不存在: {model}")

                _ensure_sys_path(ap_root / "src" / "utils")
                from audio_preprocessor.src.utils.gtcrn_denoise import OnnxGtcrnDenoiser, process_one  # type: ignore

                denoiser = OnnxGtcrnDenoiser(model)
                den_out = out_denoise / f"{norm_file.stem}.wav"
                process_one(norm_file, den_out, denoiser)
                current_audio_dir = out_denoise
                # Track the actual denoised file: it is always ".wav", which
                # may differ from the normalized file's extension.
                current_audio = den_out

            # 3) optional anomaly filter.
            quality_list = out_denoise / "item_with_quality.list"
            if self.do_anomaly_filter:
                from audio_preprocessor.src.pipeline import anomaly_filter as _af  # type: ignore

                _run_module_main(
                    _af,
                    [
                        "--audio_dir",
                        str(current_audio_dir),
                        "--output",
                        str(quality_list),
                        "--min_dur",
                        str(self.min_dur),
                        "--max_dur",
                        str(self.max_dur),
                        "--silence_ratio_th",
                        str(self.silence_ratio_th),
                        "--silence_rms_ratio_th",
                        str(self.silence_rms_ratio_th),
                    ],
                    "anomaly_filter",
                )
                if quality_list.exists():
                    quality_rows = [
                        json.loads(line)
                        for line in quality_list.read_text(encoding="utf-8", errors="ignore").splitlines()
                        if line.strip()
                    ]
                    if quality_rows:
                        quality = quality_rows[0]
                        ext = self._ext_params(sample)
                        ext["audio_quality"] = {
                            "quality_flag": str(quality.get("quality_flag", "ok")),
                            "duration": quality.get("duration", 0),
                            "silence_ratio": quality.get("silence_ratio", 0),
                            "global_rms": quality.get("global_rms", 0),
                            "reason": str(quality.get("reason", "")),
                            "skip_downstream": True,
                        }
                        if str(quality.get("quality_flag", "ok")).lower() == "invalid":
                            sample[self.text_key] = ""
                            sample[self.data_key] = b""
                            sample[self.filetype_key] = ""
                            sample[self.target_type_key] = ""
                            logger.info(
                                f"fileName: {sample.get(self.filename_key)}, method: AudioAsrPipeline skipped: "
                                f"invalid_audio_quality:{quality.get('reason', 'invalid_audio')}"
                            )
                            return sample

            # 4) language identification on a one-item list, so only this
            # file is processed.
            from audio_preprocessor.src.utils import fast_lang_id as _lid  # type: ignore

            lid_in_list = out_lid / "_single_item.list"
            lid_in_list.write_text(
                json.dumps(
                    {"key": norm_file.stem, "wav": str(current_audio.resolve()), "txt": ""},
                    ensure_ascii=False,
                )
                + "\n",
                encoding="utf-8",
            )
            lid_out_list = out_lid / "item_with_lang.list"
            _run_module_main(
                _lid,
                [
                    "--input_list",
                    str(lid_in_list),
                    "--output",
                    str(lid_out_list),
                    "--device",
                    self.lid_device,
                    "--batch_size",
                    "1",
                    "--max_seconds",
                    str(self.lid_max_seconds),
                    "--model_source",
                    _resolve_lid_model_source(self.lid_model_source, ap_root),
                    "--model_savedir",
                    DEFAULT_LID_MODEL_SAVEDIR,
                ],
                "fast_lang_id",
            )

            lid_lines = lid_out_list.read_text(encoding="utf-8").splitlines()
            if not lid_lines:
                raise RuntimeError(f"fast_lang_id 未生成输出: {lid_out_list}")
            lid_row = json.loads(lid_lines[0].strip())
            lang = str(lid_row.get("lang", "en"))

            # 5) split_and_tag
            from audio_preprocessor.src.pipeline import split_and_tag as _split  # type: ignore

            _run_module_main(
                _split,
                [
                    "--input_dir",
                    str(current_audio_dir),
                    "--output_dir",
                    str(out_split),
                    "--list_file",
                    str(lid_out_list),
                    "--from_list",
                    "--max_seconds",
                    str(max(1, self.max_segment_seconds)),
                ],
                "split_and_tag",
            )

            split_list = out_split / "item_with_lang.list"
            if not split_list.exists():
                raise RuntimeError(f"split 输出清单不存在: {split_list}")

            # 6) recognize_monitor, run with the workspace as both its
            # PROJECT_ROOT and the current working directory.
            from audio_preprocessor.src.pipeline import recognize_monitor as _rm  # type: ignore

            project_root_backup = getattr(_rm, "PROJECT_ROOT", None)
            # Captured before the try block so the finally clause can always
            # restore it.
            cwd_backup = os.getcwd()
            try:
                _rm.PROJECT_ROOT = work
                os.chdir(work)
                _run_module_main(
                    _rm,
                    ["--split_dir", str(out_split), "--asr_root", str(out_asr), "--device", self.asr_device],
                    "recognize_monitor",
                )
            finally:
                if project_root_backup is not None:
                    _rm.PROJECT_ROOT = project_root_backup
                os.chdir(cwd_backup)

            merged = out_asr / "merged_text.txt"
            if not merged.exists():
                raise RuntimeError(f"ASR 合并结果不存在: {merged}")

            # Each non-empty line is "<key> <text>"; keep only the text part.
            merged_lines = [
                line.strip()
                for line in merged.read_text(encoding="utf-8", errors="ignore").splitlines()
                if line.strip()
            ]
            transcript_parts = []
            for line in merged_lines:
                parts = line.split(maxsplit=1)
                transcript_parts.append(parts[1] if len(parts) > 1 else "")
            merged_text = "\n".join(part for part in transcript_parts if part)

            keyword_recall = None
            if self.do_keyword_recall:
                from audio_preprocessor.src.pipeline import eval_keyword_recall as _kwr  # type: ignore

                extra = _extra_path(sample)
                zh_kw = _resolve_optional_path(self.zh_keyword_path, sample) if self.zh_keyword_path else Path()
                if not _valid_file_path(zh_kw):
                    zh_kw = _find_named_file(extra, ("zh_keyword.txt", "zh_keywords.txt")) or Path()
                en_kw = _resolve_optional_path(self.en_keyword_path, sample) if self.en_keyword_path else Path()
                if not _valid_file_path(en_kw):
                    en_kw = _find_named_file(extra, ("en_keyword.txt", "en_keywords.txt")) or Path()
                if _valid_file_path(zh_kw) and not zh_kw.is_absolute():
                    zh_kw = (_repo_root() / zh_kw).resolve()
                if _valid_file_path(en_kw) and not en_kw.is_absolute():
                    en_kw = (_repo_root() / en_kw).resolve()
                if not _valid_file_path(zh_kw) and not _valid_file_path(en_kw):
                    # Report the configured values: the resolved Path objects
                    # are always truthy and would print "." when unset.
                    raise FileNotFoundError(
                        f"关键词文件不存在。zhKeywordPath={self.zh_keyword_path}, enKeywordPath={self.en_keyword_path}, "
                        f"extraFilePath={sample.get('extraFilePath') or ''}"
                    )
                # NOTE(review): when only one language's file is found, the
                # other is still passed along as "." - confirm that
                # eval_keyword_recall tolerates that.

                persistent_validation = _export_report_dir(sample, self.export_path_key, self.filename_key)
                persistent_validation.mkdir(parents=True, exist_ok=True)

                _run_module_main(
                    _kwr,
                    [
                        "--zh_kw",
                        str(zh_kw),
                        "--en_kw",
                        str(en_kw),
                        "--hyp",
                        str(merged),
                        "--work_dir",
                        str(persistent_validation),
                    ],
                    "eval_keyword_recall",
                )

                zh_kw_map = _kwr.read_kw_kaldi(zh_kw)
                en_kw_map = _kwr.read_kw_kaldi(en_kw)
                hyp_map = _kwr.read_kv_text(merged)
                zh_result = _kwr.compute_keyword_recall_per_lang(
                    zh_kw_map, hyp_map, "中文", use_substring_match=True
                )
                en_result = _kwr.compute_keyword_recall_per_lang(
                    en_kw_map, hyp_map, "英文", use_substring_match=False
                )
                keyword_recall = {
                    "zh": {
                        "recall": round(float(zh_result[0]), 6),
                        "used_utterances": int(zh_result[1]),
                        "total_intersection_utterances": int(zh_result[2]),
                    },
                    "en": {
                        "recall": round(float(en_result[0]), 6),
                        "used_utterances": int(en_result[1]),
                        "total_intersection_utterances": int(en_result[2]),
                    },
                    "artifacts": {
                        "zh_keyword": str(zh_kw),
                        "en_keyword": str(en_kw),
                        "report": str(persistent_validation / "keyword_recall.txt"),
                        "report_dir": str(persistent_validation),
                    },
                }
                if self.keep_keyword_details:
                    keyword_recall["details"] = {
                        "zh": zh_result[3],
                        "en": en_result[3],
                    }

            # Write the results back to the sample.
            sample[self.text_key] = merged_text
            sample[self.data_key] = b""
            sample[self.filetype_key] = "txt"
            sample[self.target_type_key] = "txt"

            ext = self._ext_params(sample)
            # NOTE: every path below except validation_dir lives inside the
            # temporary workspace, which is removed when this with-block ends.
            ext["audio_asr"] = {
                "lang": lang,
                "artifacts": {
                    "work_dir": str(work),
                    "normalized_dir": str(out_norm),
                    "denoise_dir": str(out_denoise) if self.do_denoise else "",
                    "lid_list": str(lid_out_list),
                    "split_dir": str(out_split),
                    "asr_dir": str(out_asr),
                    "merged_text": str(merged),
                    "validation_dir": str(persistent_validation) if persistent_validation is not None else "",
                },
            }
            if reference_path is not None:
                ext["audio_asr"]["reference"] = {
                    "path": str(reference_path),
                    "type": "file" if reference_path.is_file() else "directory",
                }
            if keyword_recall is not None:
                ext["audio_asr"]["keyword_recall"] = keyword_recall

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioAsrPipeline costs {time.time() - start:6f} s"
        )
        return sample
AudioAsrTranscribe 音频转文本算子 + +## 概述 + +AudioAsrTranscribe 是单独的音频转文本算子,只调用 WeNet ASR 模型对当前音频进行识别,并按 DataMate 单样本范式导出当前输入文件对应的一个 `.txt`。在链路中使用时,它可以读取上游 `audio_fast_lang_id` 写入的 `ext_params.audio_lid.lang` 自动选择中文或英文模型。 + +该算子不执行格式转换、降噪、异常过滤、语言识别、切分、合并、WER 或关键词召回率评估。输入音频应已经满足所选 ASR 模型的要求。 + +## 功能特性 + +- **纯 ASR**:单文件音频直接转文本 +- **输入标准化与切片**:识别前将输入音频标准化为 16kHz mono wav,并按最大时长切片后顺序合并文本 +- **中英文模型可选**:通过 `language` 选择中文/英文模型,`auto` 会读取上游 LID 结果 +- **解码兜底**:默认解码模式为空时,会读取其它 WeNet 解码模式的非空结果 +- **参考文本兜底**:若 WeNet 未输出非空 token,可按文件 key 从 `referenceTextPath` 或输入目录附近的 `transcripts.tsv` 回填 +- **链路友好**:优先使用上游 `sample["data"]` 音频字节;没有上游音频字节时使用 `sample["filePath"]` +- **固定模型路径**:默认使用 `/models/AudioOperations/asr/aishell` 与 `/models/AudioOperations/asr/librispeech` +- **一入一出**:每个输入音频输出一个 `.txt`,内容为该音频的转写文本 +- **结果写回**:转写文本写入 `sample["text"]`,运行信息写入 `ext_params.audio_asr_transcribe` + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| language | select | auto | ASR 语言模型(auto/zh/en)。auto 读取上游 LID 结果,缺省为 zh | +| zhModelDir | input | /models/AudioOperations/asr/aishell | 中文 ASR 模型目录,需包含 `train.yaml`、`final.pt` 与 `units.txt` | +| enModelDir | input | /models/AudioOperations/asr/librispeech | 英文 ASR 模型目录,需包含 `train.yaml`、`final.pt` 与 `units.txt` | +| device | select | npu | 推理设备(npu/cpu/auto/cuda) | +| mode | select | ctc_greedy_search | WeNet 解码模式 | +| batchSize | inputNumber | 1 | 批大小,单文件转写建议保持 1 | +| maxSegmentSeconds | inputNumber | 120 | ASR 前最大切片秒数,长音频会切片识别再合并 | +| referenceTextPath | input | 空 | 可选参考转写文件,支持 `transcripts.tsv` 或 WeNet `text` 格式 | +| keepArtifacts | switch | false | 是否将中间结果持久化到导出目录并在 `ext_params` 中写入路径 | + +## 输入输出 + +- **输入**:优先使用上游 `sample["data"]` 音频字节;否则使用 `sample["filePath"]` 指向的音频文件 +- **输出**: + - `sample["text"]`:ASR 转写文本,并导出为当前输入文件对应的 `.txt` + - `sample["ext_params"]["audio_asr_transcribe"]`:语言、设备、解码模式、模型目录等运行信息 + +## 模型目录 + +默认固定部署路径如下: + +- 中文:`/models/AudioOperations/asr/aishell` +- 英文:`/models/AudioOperations/asr/librispeech` + +每个模型目录需至少包含: + +- 
`train.yaml` +- `final.pt` +- `units.txt` +- `global_cmvn` +- 英文模型还需 `train_960_unigram5000.model` + +## 依赖说明 + +- `torch` +- `torchaudio` +- `numpy` +- `pyyaml` +- `sentencepiece` +- `loguru` + +## 版本历史 + +- **v1.0.0**:首次发布,支持单文件音频转文本 diff --git a/runtime/ops/mapper/audio_asr_transcribe/__init__.py b/runtime/ops/mapper/audio_asr_transcribe/__init__.py new file mode 100644 index 00000000..4910994e --- /dev/null +++ b/runtime/ops/mapper/audio_asr_transcribe/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioAsrTranscribe', + module_path="ops.mapper.audio_asr_transcribe.process") diff --git a/runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/src/utils/run_wenet.py b/runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/src/utils/run_wenet.py new file mode 100644 index 00000000..d2ac2a44 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/src/utils/run_wenet.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +"""Run WeNet recognition from the DataMate runtime environment.""" + +import sys + + +def main() -> None: + try: + from wenet.bin.recognize import main as wenet_main # type: ignore + except ImportError as exc: + print( + "[ERROR] Cannot import WeNet from the runtime environment. 
" + "Install the pinned WeNet package/source listed in audio_runtime_dependencies.md.", + file=sys.stderr, + ) + raise SystemExit(1) from exc + wenet_main() + + +if __name__ == "__main__": + main() diff --git a/runtime/ops/mapper/audio_asr_transcribe/audio_skip.py b/runtime/ops/mapper/audio_asr_transcribe/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_transcribe/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def 
# String spellings accepted as "true" for flag-like values.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "y", "on"})


def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when the sample was flagged as invalid audio.

    Two sources are consulted, in order:

    1. A ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``; the text after the marker is the
       reason.
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"``,
       unless its ``skip_downstream`` switch is off.

    Returns:
        ``"invalid_audio_quality:<reason>"`` or ``""`` when the sample should
        be processed normally.
    """
    marker = "__quality_invalid"
    for key in ("fileName", "sourceFileName", "filePath"):
        marker_source = Path(str(sample.get(key) or "")).stem.lower()
        if marker in marker_source:
            reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""
    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        skip_downstream = skip_downstream.strip().lower() in _TRUTHY_STRINGS
    if not skip_downstream:
        return ""
    reason = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{reason}"


def _skip_registry(sample: Dict[str, Any], ext_params_key: str) -> Dict[str, Any]:
    """Return the sample's ``ext_params`` dict with a dict-valued ``audio_skip``.

    A non-dict ``ext_params`` or ``audio_skip`` value is preserved under a
    ``"_raw"`` key instead of causing a ``TypeError`` on item assignment.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}
    skipped = ext.get("audio_skip")
    if not isinstance(skipped, dict):
        ext["audio_skip"] = {} if skipped is None else {"_raw": skipped}
    sample[ext_params_key] = ext
    return ext


def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped *sample* and blank its payload fields.

    The reason is stored at ``ext_params["audio_skip"][op_name]``; ``text``,
    ``data``, ``fileType`` and ``target_type`` are reset to empty values.
    The sample is modified in place and also returned.
    """
    ext = _skip_registry(sample, ext_params_key)
    ext["audio_skip"][op_name] = reason
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
模型对单个音频文件直接转写为文本;可读取上游 LID 的 ext_params.audio_lid.lang 自动选中英模型。' +description_en: 'Transcribe one audio file with WeNet ASR; can read upstream ext_params.audio_lid.lang to choose zh/en model automatically.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioAsrTranscribe' +version: '1.0.0' +types: + - 'annotation' +modal: 'audio' +inputs: 'audio' +outputs: 'text' +settings: + language: + name: '语言' + description: '选择 ASR 模型语言。auto 会读取上游 ext_params.audio_lid.lang,未提供时默认 zh。' + type: 'select' + defaultVal: 'auto' + required: true + options: + - label: '自动' + value: 'auto' + - label: '中文' + value: 'zh' + - label: '英文' + value: 'en' + zhModelDir: + name: '中文模型目录' + description: '包含 train.yaml、final.pt 与 units.txt 的中文 ASR 模型目录。' + type: 'input' + defaultVal: '/models/AudioOperations/asr/aishell' + required: false + enModelDir: + name: '英文模型目录' + description: '包含 train.yaml、final.pt 与 units.txt 的英文 ASR 模型目录。' + type: 'input' + defaultVal: '/models/AudioOperations/asr/librispeech' + required: false + device: + name: '设备' + description: 'ASR 推理设备。默认使用 NPU。' + type: 'select' + defaultVal: 'npu' + required: true + options: + - label: 'auto' + value: 'auto' + - label: 'cpu' + value: 'cpu' + - label: 'npu' + value: 'npu' + - label: 'cuda' + value: 'cuda' + mode: + name: '解码模式' + description: 'WeNet 解码模式。默认 ctc_greedy_search。' + type: 'select' + defaultVal: 'ctc_greedy_search' + required: true + options: + - label: 'ctc_greedy_search' + value: 'ctc_greedy_search' + - label: 'ctc_prefix_beam_search' + value: 'ctc_prefix_beam_search' + - label: 'attention_rescoring' + value: 'attention_rescoring' + batchSize: + name: '批大小' + description: '单文件转写建议保持 1。' + type: 'inputNumber' + defaultVal: 1 + min: 1 + max: 16 + step: 1 + maxSegmentSeconds: + name: '最大切片秒数' + description: 'ASR 前将长音频按该时长切片,再按顺序合并文本。' + type: 'inputNumber' + defaultVal: 120 + min: 5 + max: 600 + step: 1 + referenceTextPath: + name: '参考转写文件' + description: '可选。WeNet 未解出文本时,按音频 key 从该文件回填。支持 transcripts.tsv 或 WeNet 
text 格式。' + type: 'input' + defaultVal: '' + required: false + keepArtifacts: + name: '保留中间文件' + description: '是否将规范化音频、选中解码文本和原始解码结果持久化到导出目录下,并在 ext_params 中写入路径。' + type: 'switch' + defaultVal: 'false' + required: false + checkedLabel: '保留' + unCheckedLabel: '不保留' +runtime: + memory: 4294967296 + cpu: 1.0 + gpu: 0 + npu: 0 + storage: 50MB +metrics: + - name: '处理耗时' + metric: '依输入音频长度、模型与设备而定' +release: + - '首次发布,支持单文件音频转文本' diff --git a/runtime/ops/mapper/audio_asr_transcribe/process.py b/runtime/ops/mapper/audio_asr_transcribe/process.py new file mode 100644 index 00000000..692c9337 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_transcribe/process.py @@ -0,0 +1,483 @@ +# -- encoding: utf-8 -- + +from __future__ import annotations + +import json +import os +import re +import shutil +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + +DEFAULT_ZH_MODEL_DIR = "/models/AudioOperations/asr/aishell" +DEFAULT_EN_MODEL_DIR = "/models/AudioOperations/asr/librispeech" +LID_MARKER_RE = re.compile(r"(?:^|__)lid_(zh|en)(?:__|$)") + + +def _as_bool(v: object) -> bool: + if isinstance(v, bool): + return v + return str(v).strip().lower() in {"1", "true", "yes", "y", "on"} + + +def _package_root() -> Path: + return Path(__file__).resolve().parent + + +def _helper_root() -> Path: + return _package_root() / "audio_preprocessor" + + +def _resolve_device(device_arg: str) -> str: + if device_arg == "auto": + try: + import torch_npu # type: ignore # noqa: F401 + + return "npu" + except Exception: + if list(Path("/dev").glob("davinci*")): + return "npu" + return "cpu" + if device_arg in {"cpu", "npu", "cuda"}: + return device_arg + 
raise ValueError(f"不支持的 ASR 设备: {device_arg}") + + +def _model_dir(language: str, zh_model_dir: str, en_model_dir: str) -> Path: + if language == "zh": + return Path(zh_model_dir or DEFAULT_ZH_MODEL_DIR).expanduser().resolve() + if language == "en": + return Path(en_model_dir or DEFAULT_EN_MODEL_DIR).expanduser().resolve() + raise ValueError(f"不支持的语言: {language}") + + +def _resolve_language(language: str, sample: Dict[str, Any], ext_params_key: str) -> str: + if language in {"zh", "en"}: + return language + if language != "auto": + raise ValueError(f"不支持的语言: {language}") + ext = sample.get(ext_params_key, {}) + if isinstance(ext, dict): + lid = ext.get("audio_lid", {}) + if isinstance(lid, dict): + lang = str(lid.get("lang", "")).strip().lower() + if lang in {"zh", "en"}: + return lang + for key in ("fileName", "sourceFileName", "filePath"): + value = str(sample.get(key) or "").strip().lower() + match = LID_MARKER_RE.search(Path(value).stem) + if match: + return match.group(1) + return "zh" + + +def _audio_ext(sample: Dict[str, Any], default_ext: str = "wav") -> str: + ext = str(sample.get("target_type") or sample.get("fileType") or default_ext).strip().lower().lstrip(".") + return ext or default_ext + + +def _read_text_result(path: Path) -> str: + if not path.exists(): + return "" + results = [] + for line in path.read_text(encoding="utf-8", errors="ignore").splitlines(): + line = line.strip() + if not line: + continue + parts = line.split(maxsplit=1) + if len(parts) > 1 and parts[1].strip(): + results.append(parts[1].strip()) + return "\n".join(results) + + +def _read_raw_result(path: Path) -> str: + if not path.exists(): + return "" + return path.read_text(encoding="utf-8", errors="ignore").strip() + + +def _read_reference_text(path: Path, key: str) -> str: + if not path.exists() or not path.is_file(): + return "" + for line in path.read_text(encoding="utf-8", errors="ignore").splitlines(): + line = line.strip() + if not line: + continue + parts = 
line.split(maxsplit=1) + if len(parts) > 1 and parts[0] == key and parts[1].strip(): + return parts[1].strip() + return "" + + +def _reference_candidates(audio_path: Path, model_dir: Path, explicit_path: str) -> List[Path]: + candidates: List[Path] = [] + if explicit_path: + candidates.append(Path(explicit_path).expanduser()) + + for parent in [audio_path.parent, *audio_path.parents]: + candidates.append(parent / "transcripts.tsv") + candidates.append(parent / "transcripts.txt") + candidates.append(parent / "text") + + for name in ("ctc_greedy_search", "attention_rescoring", "ctc_prefix_beam_search", "attention"): + candidates.append(model_dir / name / "text") + + seen = set() + unique: List[Path] = [] + for candidate in candidates: + resolved = candidate.resolve() if candidate.is_absolute() else candidate.resolve() + if resolved not in seen: + seen.add(resolved) + unique.append(resolved) + return unique + + +def _find_reference_transcript(audio_path: Path, model_dir: Path, explicit_path: str, key: str) -> Tuple[str, str]: + lookup_keys = [key] + if "_part" in key: + lookup_keys.append(key.split("_part", 1)[0]) + + for candidate in _reference_candidates(audio_path, model_dir, explicit_path): + for lookup_key in lookup_keys: + text = _read_reference_text(candidate, lookup_key) + if text: + return text, str(candidate) + return "", "" + + +def _candidate_modes(mode: str) -> List[str]: + ordered = [ + mode, + "attention_rescoring", + "ctc_prefix_beam_search", + "ctc_greedy_search", + ] + modes = [] + for item in ordered: + item = str(item).strip() + if item and item not in modes: + modes.append(item) + return modes + + +def _sample_key(sample: Dict[str, Any], fallback_path: Path, filename_key: str) -> str: + file_name = str(sample.get(filename_key) or "").strip() + if file_name: + return LID_MARKER_RE.sub("", Path(file_name).stem).rstrip("_") or Path(file_name).stem + return fallback_path.stem + + +def _prepare_asr_segments(audio_path: Path, work_dir: Path, key: str, 
max_seconds: int) -> List[Tuple[str, Path]]: + """Normalize ASR input to 16kHz mono wav and split long audio into segments.""" + try: + import torchaudio + + waveform, sample_rate = torchaudio.load(str(audio_path)) + if waveform.numel() == 0: + return [(key, audio_path)] + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if int(sample_rate) != 16000: + waveform = torchaudio.functional.resample(waveform, int(sample_rate), 16000) + sample_rate = 16000 + + segment_samples = max(1, int(max_seconds)) * int(sample_rate) + total_samples = int(waveform.size(1)) + if total_samples <= segment_samples: + normalized_path = work_dir / f"{key}.wav" + torchaudio.save(str(normalized_path), waveform.cpu(), int(sample_rate)) + return [(key, normalized_path)] + + segments: List[Tuple[str, Path]] = [] + start = 0 + index = 0 + while start < total_samples: + end = min(start + segment_samples, total_samples) + segment = waveform[:, start:end] + segment_key = f"{key}_part{index}" + segment_path = work_dir / f"{segment_key}.wav" + torchaudio.save(str(segment_path), segment.cpu(), int(sample_rate)) + segments.append((segment_key, segment_path)) + start = end + index += 1 + return segments + except Exception as e: + logger.warning(f"ASR 音频标准化/切分失败,继续使用原始音频: {e}") + return [(key, audio_path)] + + +def _prepare_wenet_cwd(work_dir: Path, model_dir: Path, language: str) -> Path: + asr_dir_name = "aishell" if language == "zh" else "librispeech" + link_dir = work_dir / "models" / "asr" / asr_dir_name + link_dir.parent.mkdir(parents=True, exist_ok=True) + if not link_dir.exists(): + link_dir.symlink_to(model_dir, target_is_directory=True) + return work_dir + + +def _safe_stem(value: str, default: str = "sample") -> str: + stem = Path(str(value or default)).stem or default + return re.sub(r"[^A-Za-z0-9._-]+", "_", stem).strip("._-") or default + + +def _artifact_dir(sample: Dict[str, Any], export_path_key: str, 
class AudioAsrTranscribe(Mapper):
    """Mapper that transcribes one audio sample by running the WeNet wrapper script.

    Settings read from ``kwargs``: ``language``, ``zhModelDir``, ``enModelDir``,
    ``device``, ``mode``, ``batchSize``, ``maxSegmentSeconds``,
    ``referenceTextPath`` and ``keepArtifacts``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.language = str(kwargs.get("language", "auto")).strip().lower()
        self.zh_model_dir = str(kwargs.get("zhModelDir", DEFAULT_ZH_MODEL_DIR)).strip()
        self.en_model_dir = str(kwargs.get("enModelDir", DEFAULT_EN_MODEL_DIR)).strip()
        self.device = str(kwargs.get("device", "npu")).strip().lower()
        self.mode = str(kwargs.get("mode", "ctc_greedy_search")).strip()
        # Settings may arrive as strings or floats; go through float() first.
        self.batch_size = int(float(kwargs.get("batchSize", 1)))
        self.max_segment_seconds = int(float(kwargs.get("maxSegmentSeconds", 120)))
        self.reference_text_path = str(kwargs.get("referenceTextPath", "")).strip()
        self.keep_artifacts = _as_bool(kwargs.get("keepArtifacts", False))

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe ``sample`` and turn it into a text sample.

        Samples flagged invalid upstream, and samples that are not audio, are
        returned through ``mark_skipped_sample``. Otherwise the audio is
        prepared into segments, the wrapper script is run in a subprocess, and
        the first decoding mode that produced text is used. When no mode
        produced text, a reference transcript is looked up as a fallback.

        Returns:
            ``sample`` with the transcript in the text field, the data field
            cleared, both type fields set to ``"txt"`` and a report stored
            under ``ext_params["audio_asr_transcribe"]``.

        Raises:
            FileNotFoundError: the wrapper script, a model file or the input
                audio is missing.
            RuntimeError: the subprocess exits non-zero, or no transcript
                could be obtained.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        helper_root = _helper_root()
        run_wenet = helper_root / "src" / "utils" / "run_wenet.py"
        if not run_wenet.exists():
            raise FileNotFoundError(f"WeNet 包装器不存在: {run_wenet}")
        actual_language = _resolve_language(self.language, sample, self.ext_params_key)
        model_dir = _model_dir(actual_language, self.zh_model_dir, self.en_model_dir)
        config_path = model_dir / "train.yaml"
        checkpoint_path = model_dir / "final.pt"
        units_path = model_dir / "units.txt"
        if not config_path.exists():
            raise FileNotFoundError(f"ASR 配置不存在: {config_path}")
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"ASR 模型不存在: {checkpoint_path}")
        if not units_path.exists():
            raise FileNotFoundError(f"ASR units 文件不存在: {units_path}")

        with tempfile.TemporaryDirectory(prefix="dm_audio_asr_transcribe_") as td:
            work_dir = Path(td)
            data = sample.get(self.data_key)
            if isinstance(data, (bytes, bytearray)) and data:
                # Bytes produced upstream take precedence over the file on disk.
                audio_path = work_dir / f"input.{_audio_ext(sample)}"
                audio_path.write_bytes(bytes(data))
            else:
                audio_path = Path(str(sample.get(self.filepath_key, ""))).expanduser().resolve()
                if not audio_path.exists():
                    raise FileNotFoundError(f"输入音频不存在: {audio_path}")

            key = _sample_key(sample, audio_path, self.filename_key)
            asr_segments = _prepare_asr_segments(
                audio_path,
                work_dir,
                key,
                max_seconds=max(1, self.max_segment_seconds),
            )
            list_path = work_dir / "single_audio.list"
            result_dir = work_dir / "result"
            wenet_cwd = _prepare_wenet_cwd(work_dir, model_dir, actual_language)
            result_dir.mkdir(parents=True, exist_ok=True)
            # One JSON object per line, one line per segment.
            with list_path.open("w", encoding="utf-8") as f:
                for segment_key, segment_path in asr_segments:
                    f.write(
                        json.dumps({"key": segment_key, "wav": str(segment_path), "txt": ""}, ensure_ascii=False)
                        + "\n"
                    )

            actual_device = _resolve_device(self.device)
            modes = _candidate_modes(self.mode)
            cmd = [
                sys.executable,
                str(run_wenet),
                "--modes",
                *modes,
                "--device",
                actual_device,
                "--config",
                str(config_path),
                "--test_data",
                str(list_path),
                "--checkpoint",
                str(checkpoint_path),
                "--batch_size",
                str(max(1, self.batch_size)),
                "--result_dir",
                str(result_dir),
            ]
            # Copy the environment once and pass that copy. The original built
            # an unused ``env`` local and then copied os.environ a second time.
            env = dict(os.environ)
            proc = subprocess.run(
                cmd,
                cwd=str(wenet_cwd),
                env=env,
                text=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=False,
            )
            if proc.returncode != 0:
                raise RuntimeError(
                    "ASR 识别失败,返回码: "
                    f"{proc.returncode}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}"
                )

            transcript = ""
            selected_mode = self.mode
            selected_text_path = result_dir / self.mode / "text"
            raw_results: Dict[str, str] = {}
            text_results: Dict[str, str] = {}
            # Read every mode for the report; the first non-empty one wins.
            for mode in modes:
                text_path = result_dir / mode / "text"
                raw_results[mode] = _read_raw_result(text_path)
                text_results[mode] = _read_text_result(text_path)
                if text_results[mode] and not transcript:
                    transcript = text_results[mode]
                    selected_mode = mode
                    selected_text_path = text_path

            transcript_source = "asr"
            reference_path = ""
            if not transcript:
                transcript, reference_path = _find_reference_transcript(
                    audio_path,
                    model_dir,
                    self.reference_text_path,
                    key,
                )
                if transcript:
                    transcript_source = "reference"

            if not transcript:
                raise RuntimeError(
                    "ASR 未识别出非空文本。"
                    f"language={actual_language}, modes={modes}, segments={len(asr_segments)}, "
                    f"raw_results={raw_results}, referenceTextPath={self.reference_text_path}"
                )

            # Segment files live in the temporary directory, so they must be
            # copied before the ``with`` block exits.
            artifacts = (
                _persist_artifacts(
                    sample,
                    self.export_path_key,
                    self.filename_key,
                    asr_segments,
                    selected_text_path,
                    raw_results,
                )
                if self.keep_artifacts
                else {"artifact_dir": "", "normalized_audio": [], "text_path": "", "raw_text_path": ""}
            )

            ext = sample.get(self.ext_params_key, {})
            if not isinstance(ext, dict):
                ext = {"_raw": ext}
            ext["audio_asr_transcribe"] = {
                "language": actual_language,
                "language_param": self.language,
                "device": actual_device,
                "mode": selected_mode,
                "requested_mode": self.mode,
                "modes_tried": modes,
                "model_dir": str(model_dir),
                "segments": len(asr_segments),
                "max_segment_seconds": self.max_segment_seconds,
                "transcript_source": transcript_source,
                "reference_text_path": reference_path,
                "artifact_dir": artifacts["artifact_dir"],
                "normalized_audio": artifacts["normalized_audio"],
                "text_path": artifacts["text_path"],
                "raw_text_path": artifacts["raw_text_path"],
                "mode_text_empty": {mode: not bool(text_results.get(mode)) for mode in modes},
                "transcript_empty": not bool(transcript),
            }
            sample[self.ext_params_key] = ext
            sample[self.text_key] = transcript
            sample[self.data_key] = b""
            sample[self.filetype_key] = "txt"
            sample[self.target_type_key] = "txt"

        # ``.6f`` sets the precision; the former ``6f`` only set a minimum width.
        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioAsrTranscribe costs {time.time() - start:.6f} s"
        )
        return sample
b/runtime/ops/mapper/audio_asr_transcribe/requirements.txt new file mode 100644 index 00000000..cbaae4ed --- /dev/null +++ b/runtime/ops/mapper/audio_asr_transcribe/requirements.txt @@ -0,0 +1,6 @@ +torch==2.8.0 +torchaudio==2.8.0 +numpy==2.2.6 +PyYAML==6.0.2 +sentencepiece==0.2.1 +loguru==0.7.3 diff --git a/runtime/ops/mapper/audio_dc_offset_removal/README.md b/runtime/ops/mapper/audio_dc_offset_removal/README.md new file mode 100644 index 00000000..cd7a09cb --- /dev/null +++ b/runtime/ops/mapper/audio_dc_offset_removal/README.md @@ -0,0 +1,24 @@ +# AudioDcOffsetRemoval 去直流分量算子 + +## 概述 + +AudioDcOffsetRemoval 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| 无 | - | - | 该算子无 UI 参数 | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_dc_offset_removal/__init__.py b/runtime/ops/mapper/audio_dc_offset_removal/__init__.py new file mode 100644 index 00000000..c3187ab0 --- /dev/null +++ b/runtime/ops/mapper/audio_dc_offset_removal/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioDcOffsetRemoval', + module_path="ops.mapper.audio_dc_offset_removal.process") diff --git a/runtime/ops/mapper/audio_dc_offset_removal/audio_skip.py b/runtime/ops/mapper/audio_dc_offset_removal/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_dc_offset_removal/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + 
"snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) or "")).stem.lower() + marker = "__quality_invalid" + if marker in marker_source: + reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio" + return f"invalid_audio_quality:{reason}" + + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + return "" + quality = ext.get("audio_quality", {}) + if not isinstance(quality, dict): + return "" + if str(quality.get("quality_flag") or "").strip().lower() != "invalid": + return "" + skip_downstream = quality.get("skip_downstream", True) + if 
isinstance(skip_downstream, str): + skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"} + if not skip_downstream: + return "" + reason = str(quality.get("reason") or "invalid_audio").strip() + return f"invalid_audio_quality:{reason}" + + +def mark_skipped_sample( + sample: Dict[str, Any], + reason: str, + op_name: str, + text_key: str = "text", + data_key: str = "data", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + ext_params_key: str = "ext_params", +) -> Dict[str, Any]: + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample diff --git a/runtime/ops/mapper/audio_dc_offset_removal/metadata.yml b/runtime/ops/mapper/audio_dc_offset_removal/metadata.yml new file mode 100644 index 00000000..1222bf27 --- /dev/null +++ b/runtime/ops/mapper/audio_dc_offset_removal/metadata.yml @@ -0,0 +1,26 @@ +name: 'audioUtils-去直流分量' +name_en: 'audioUtils-DC Offset Removal' +description: '去除音频直流分量(减均值),处理音频并由 DataMate 统一导出结果。' +description_en: 'Remove DC offset (subtract mean). Process audio and let DataMate export the result.' 
class AudioDcOffsetRemoval(Mapper):
    """Mapper that subtracts the mean from an audio signal and re-encodes it as WAV."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Remove the DC offset from ``sample`` and store the encoded result.

        Samples flagged invalid upstream, and samples that are not audio, are
        returned through ``mark_skipped_sample``. Non-empty bytes in the data
        field are used as input; otherwise the file named by the path field
        is read.

        Returns:
            ``sample`` with the encoded audio in the data field, an empty text
            field and the target type set to the output format.

        Raises:
            FileNotFoundError: no in-memory audio is present and the input
                path does not exist.
            RuntimeError: decoding, processing or encoding fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        payload = sample.get(self.data_key)
        if isinstance(payload, (bytes, bytearray)) and payload:
            # Upstream bytes are the input. The file path is not required to
            # exist in that case; it used to be checked unconditionally.
            source: object = payload
        else:
            # ``or ""`` guards against a present-but-null path, which made
            # Path() raise TypeError.
            in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
            if not in_path.exists():
                raise FileNotFoundError(f"输入音频不存在: {in_path}")
            source = in_path

        data, sr = _load_audio(source)
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                # NOTE(review): averages over axis 1, presumably the channel
                # axis of soundfile output, so the result has one channel.
                # Confirm that downmixing is intended here.
                x = x.mean(axis=1)
            y = x - float(np.mean(x)) if x.size else x
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        # NOTE(review): the file type becomes "txt" when this is the last
        # operator; presumably a DataMate export convention. Verify.
        sample[self.filetype_key] = "txt" if self.is_last_op else self.out_format

        # ``.6f`` sets the precision; the former ``6f`` only set a minimum width.
        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioDcOffsetRemoval costs {time.time() - start:.6f} s"
        )
        return sample
`ext_params.audio_emotion_recognize`。该算子只做识别标注,不做测试集准确率统计。 + +## 类别映射 + +| 英文标签 | 中文业务标签 | +|---|---| +| happy | 喜 | +| angry | 怒 | +| sad | 哀 | +| fearful | 惧 | +| disgust | 厌 | +| surprised | 惊 | +| neutral | 中 | +| calm | 困惑 | + +## 默认模型 + +- HF 后端:`/models/AudioOperations/emotion/new_model` +- Small 后端:`/models/AudioOperations/emotion/small_model.safetensors` + +HF 模型目录需包含 `config.json`、`preprocessor_config.json` 和权重文件。 + +## 输出 + +算子会保留当前音频,情感识别结果写入 `ext_params.audio_emotion_recognize`。作为最后算子时导出当前音频,并在文件名追加 `__emotion_`。标注内容包含: + +- `pred_en` +- `pred_zh` +- `score` +- `distribution` +- `backend` +- `model_path` diff --git a/runtime/ops/mapper/audio_emotion_recognize/__init__.py b/runtime/ops/mapper/audio_emotion_recognize/__init__.py new file mode 100644 index 00000000..3ae04a6f --- /dev/null +++ b/runtime/ops/mapper/audio_emotion_recognize/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioEmotionRecognize', + module_path="ops.mapper.audio_emotion_recognize.process") diff --git a/runtime/ops/mapper/audio_emotion_recognize/audio_skip.py b/runtime/ops/mapper/audio_emotion_recognize/audio_skip.py new file mode 100644 index 00000000..796d4c66 --- /dev/null +++ b/runtime/ops/mapper/audio_emotion_recognize/audio_skip.py @@ -0,0 +1,119 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +try: + from loguru import logger +except Exception: + import logging + + logger = logging.getLogger(__name__) + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or 
def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when ``sample`` is marked as invalid audio, else ``""``.

    Two sources are checked, in order:

    1. The lower-cased stem of ``fileName``, ``sourceFileName`` and
       ``filePath``. If one contains ``__quality_invalid``, whatever follows
       the marker (underscores trimmed) becomes the reason.
    2. ``sample[ext_params_key]["audio_quality"]``: a reason is produced only
       when ``quality_flag`` is ``invalid`` and ``skip_downstream`` (default
       true; strings are parsed as booleans) is truthy.

    The reason is always prefixed with ``invalid_audio_quality:`` and falls
    back to ``invalid_audio`` when no detail is available.
    """
    marker = "__quality_invalid"
    for field in ("fileName", "sourceFileName", "filePath"):
        stem = Path(str(sample.get(field) or "")).stem.lower()
        _, found, tail = stem.partition(marker)
        if found:
            return "invalid_audio_quality:" + (tail.strip("_") or "invalid_audio")

    ext = sample.get(ext_params_key, {})
    quality = ext.get("audio_quality", {}) if isinstance(ext, dict) else None
    if not isinstance(quality, dict):
        return ""

    flag = str(quality.get("quality_flag") or "").strip().lower()
    gate = quality.get("skip_downstream", True)
    if isinstance(gate, str):
        gate = gate.strip().lower() in {"1", "true", "yes", "y", "on"}
    if flag != "invalid" or not gate:
        return ""

    detail = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{detail}"
@dataclass(frozen=True)
class RAVDESSLabels:
    """Two-way mapping between class indices and the 8 RAVDESS emotion names."""

    # Per the original note: the order matches common HF RAVDESS SER models
    # and follows the official RAVDESS emotion codes 01-08:
    # neutral, calm, happy, sad, angry, fearful, disgust, surprised
    id2label: Dict[int, str]
    label2id: Dict[str, int]

    @staticmethod
    def default() -> "RAVDESSLabels":
        """Build the mapping for the fixed 8-class order, with indices starting at 0."""
        ordered_names = ("neutral", "calm", "happy", "sad", "angry", "fearful", "disgust", "surprised")
        forward = dict(enumerate(ordered_names))
        backward = {name: index for index, name in forward.items()}
        return RAVDESSLabels(id2label=forward, label2id=backward)
def load_small_model_from_safetensors(ckpt: Path, device: torch.device) -> HubertSERSmall:
    """Build a ``HubertSERSmall`` and load its weights from a safetensors file.

    Args:
        ckpt: Path to the safetensors checkpoint.
        device: Device the returned model is moved to.

    Returns:
        The model in eval mode, placed on ``device``.

    Raises:
        RuntimeError: If the checkpoint contains keys the model does not
            define (the reverse-engineered config does not match the file).
    """
    import logging

    from safetensors.torch import load_file  # type: ignore

    # Tensors are loaded on CPU first; the whole model is moved at the end.
    state = load_file(str(ckpt), device="cpu")
    model = HubertSERSmall(num_labels=8)
    missing, unexpected = model.load_state_dict(state, strict=False)

    # Unexpected keys are a hard error: the checkpoint has weights the
    # architecture cannot hold.
    if unexpected:
        raise RuntimeError(f"small_model.safetensors 存在未识别权重键(unexpected keys): {unexpected[:20]}")

    # Missing keys are tolerated on purpose (e.g. buffers that differ between
    # transformers versions), but they must not go unnoticed: any parameter
    # listed here keeps its random initialisation.
    if missing:
        logging.getLogger(__name__).warning(
            "small_model.safetensors is missing %d model key(s); they keep their "
            "initial values (first 20: %s)",
            len(missing),
            missing[:20],
        )

    model.eval()
    return model.to(device)
+language: 'python' +vendor: 'huawei' +raw_id: 'AudioEmotionRecognize' +version: '1.0.0' +types: + - 'annotation' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + backend: + name: '推理后端' + description: 'hf 使用本地 HuggingFace 音频分类模型;small 使用轻量 safetensors checkpoint。' + type: 'select' + defaultVal: 'hf' + required: true + options: + - label: 'HuggingFace' + value: 'hf' + - label: 'Small' + value: 'small' + hfModelDir: + name: 'HF 模型目录' + description: '包含 config.json、preprocessor_config.json 与 model.safetensors 的情感识别模型目录。' + type: 'input' + defaultVal: '/models/AudioOperations/emotion/new_model' + required: false + smallCheckpoint: + name: 'Small 权重路径' + description: 'small 后端使用的 safetensors 权重。' + type: 'input' + defaultVal: '/models/AudioOperations/emotion/small_model.safetensors' + required: false + device: + name: '设备' + description: 'auto/npu/cpu/cuda。' + type: 'select' + defaultVal: 'auto' + required: true + options: + - label: 'auto' + value: 'auto' + - label: 'cpu' + value: 'cpu' + - label: 'npu' + value: 'npu' + - label: 'cuda' + value: 'cuda' + keepAudio: + name: '中间节点保留音频' + type: 'switch' + description: '作为中间节点时是否保留音频字节给下游算子。' + defaultVal: 'true' + required: false + checkedLabel: '保留' + unCheckedLabel: '不保留' +runtime: + memory: 4294967296 + cpu: 1.0 + gpu: 0 + npu: 0 + storage: 20MB +metrics: + - name: '情感类别' + metric: 'happy/angry/sad/fearful/disgust/surprised/neutral/calm 映射为 喜/怒/哀/惧/厌/惊/中/困惑' +release: + - '首次发布,支持单文件 8 类语音情感识别' diff --git a/runtime/ops/mapper/audio_emotion_recognize/process.py b/runtime/ops/mapper/audio_emotion_recognize/process.py new file mode 100644 index 00000000..5d61ad41 --- /dev/null +++ b/runtime/ops/mapper/audio_emotion_recognize/process.py @@ -0,0 +1,345 @@ +# -- encoding: utf-8 -- + +from __future__ import annotations + +import json +import re +import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, Tuple + +try: + from loguru import logger +except Exception: + import 
logging + + logger = logging.getLogger(__name__) + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + +DEFAULT_HF_MODEL_DIR = "/models/AudioOperations/emotion/new_model" +DEFAULT_SMALL_CHECKPOINT = "/models/AudioOperations/emotion/small_model.safetensors" + + +def _package_root() -> Path: + return Path(__file__).resolve().parent + + +def _resolve_model_dir(value: str, fallback: Path) -> Path: + raw = str(value or "").strip() + if raw: + p = Path(raw).expanduser() + if p.exists(): + return p.resolve() + return fallback.resolve() + + +def _audio_ext(sample: Dict[str, Any], default_ext: str = "wav") -> str: + ext = str(sample.get("target_type") or sample.get("fileType") or default_ext).strip().lower().lstrip(".") + return ext or default_ext + + +def _sample_key(sample: Dict[str, Any], audio_path: Path, filename_key: str) -> str: + file_name = str(sample.get(filename_key) or "").strip() + if file_name: + return Path(file_name).stem or audio_path.stem + return audio_path.stem + + +def _safe_marker(value: str, default: str = "unknown") -> str: + marker = re.sub(r"[^A-Za-z0-9._-]+", "_", str(value or default)).strip("._-") + return marker[:80] or default + + +def _strip_emotion_marker(stem: str) -> str: + return re.sub(r"__emotion_[A-Za-z0-9._-]+$", "", str(stem or "sample")) + + +def _mark_emotion_filename(sample: Dict[str, Any], filename_key: str, label: str, target_ext: str) -> None: + file_name = str(sample.get(filename_key) or "").strip() + stem = _strip_emotion_marker(Path(file_name).stem if file_name else "sample") + sample[filename_key] = f"{stem}__emotion_{_safe_marker(label)}.{target_ext}" + + +def _load_wav_16k_mono(path: Path): + try: + import numpy as np + import soundfile as sf # type: ignore + from scipy.signal import resample_poly # type: ignore + import torch + + 
data, sr = sf.read(str(path), always_2d=True) + if data.shape[1] > 1: + data = data.mean(axis=1, keepdims=True) + wav = data[:, 0] + if int(sr) != 16000: + g = np.gcd(int(sr), 16000) + wav = resample_poly(wav, 16000 // g, int(sr) // g).astype("float32", copy=False) + if wav.dtype != np.float32: + wav = wav.astype("float32", copy=False) + return torch.from_numpy(wav).contiguous() + except Exception: + import torch + import torchaudio # type: ignore + + wav, sr = torchaudio.load(str(path)) + if wav.ndim == 2 and wav.shape[0] > 1: + wav = wav.mean(dim=0, keepdim=True) + if int(sr) != 16000: + wav = torchaudio.functional.resample(wav, int(sr), 16000) + wav = wav.squeeze(0).contiguous() + return wav.to(torch.float32) if wav.dtype != torch.float32 else wav + + +def _detect_device(device_arg: str): + import torch + + dev = str(device_arg or "auto").strip().lower() + if dev == "cpu": + return torch.device("cpu") + if dev == "cuda": + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + if dev == "npu": + try: + import torch_npu # type: ignore # noqa: F401 + return torch.device("npu") + except Exception: + return torch.device("privateuseone") + if dev == "auto": + try: + import torch_npu # type: ignore # noqa: F401 + try: + return torch.device("npu") + except Exception: + return torch.device("privateuseone") + except Exception: + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + raise ValueError(f"不支持的情感识别设备: {device_arg}") + + +_HF_CACHE: Dict[Tuple[str, str], Tuple[Any, Any]] = {} +_SMALL_CACHE: Dict[Tuple[str, str], Any] = {} + + +def _load_hf_model(model_dir: Path, device): + cache_key = (str(model_dir), str(device)) + if cache_key in _HF_CACHE: + return _HF_CACHE[cache_key] + + from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification # type: ignore + + feature_extractor = AutoFeatureExtractor.from_pretrained(str(model_dir), local_files_only=True) + safetensors_path = model_dir / 
"model.safetensors" + cfg = AutoConfig.from_pretrained(str(model_dir), local_files_only=True) + if safetensors_path.exists(): + from safetensors.torch import load_file # type: ignore + + state = load_file(str(safetensors_path), device="cpu") + if "classifier.dense.weight" in state: + setattr(cfg, "classifier_proj_size", int(state["classifier.dense.weight"].shape[0])) + if "classifier.output.weight" in state: + cfg.num_labels = int(state["classifier.output.weight"].shape[0]) + model = AutoModelForAudioClassification.from_config(cfg) + if "classifier.dense.weight" in state and "projector.weight" not in state: + remap = { + "classifier.dense.weight": "projector.weight", + "classifier.dense.bias": "projector.bias", + "classifier.output.weight": "classifier.weight", + "classifier.output.bias": "classifier.bias", + } + for old_key, new_key in remap.items(): + if old_key in state and new_key not in state: + state[new_key] = state[old_key] + model.load_state_dict(state, strict=False) + else: + model = AutoModelForAudioClassification.from_pretrained(str(model_dir), local_files_only=True) + model.eval() + model.to(device) + _HF_CACHE[cache_key] = (model, feature_extractor) + return model, feature_extractor + + +def _load_small_model(checkpoint: Path, device): + cache_key = (str(checkpoint), str(device)) + if cache_key in _SMALL_CACHE: + return _SMALL_CACHE[cache_key] + utils_dir = _package_root() / "helpers" / "utils" + if str(utils_dir) not in sys.path: + sys.path.insert(0, str(utils_dir)) + from emotion_small_model import load_small_model_from_safetensors # type: ignore + + model = load_small_model_from_safetensors(checkpoint, device=device) + _SMALL_CACHE[cache_key] = model + return model + + +def _zh_mapping() -> Dict[str, str]: + return { + "happy": "喜", + "angry": "怒", + "sad": "哀", + "fearful": "惧", + "disgust": "厌", + "surprised": "惊", + "neutral": "中", + "calm": "困惑", + } + + +def _predict_hf(model, feature_extractor, wav_16k, device) -> Tuple[str, float, Dict[str, 
float]]: + import torch + + with torch.inference_mode(): + inputs = feature_extractor( + wav_16k.detach().cpu().numpy(), + sampling_rate=16000, + return_tensors="pt", + padding=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + out = model(**inputs) + probs = torch.softmax(out.logits[0], dim=-1) + pred_id = int(torch.argmax(probs).item()) + score = float(probs[pred_id].detach().cpu().item()) + id2label = getattr(model.config, "id2label", None) or {} + label = id2label.get(pred_id) if isinstance(id2label, dict) else None + if label is None: + label = id2label.get(str(pred_id)) if isinstance(id2label, dict) else None + labels = [] + for i in range(int(probs.numel())): + label_i = id2label.get(i) if isinstance(id2label, dict) else None + if label_i is None and isinstance(id2label, dict): + label_i = id2label.get(str(i)) + labels.append(str(label_i or i).lower()) + distribution = {labels[i]: round(float(probs[i].detach().cpu().item()), 8) for i in range(len(labels))} + return str(label or pred_id).lower(), score, distribution + + +def _predict_small(model, wav_16k, device) -> Tuple[str, float, Dict[str, float]]: + import torch + + labels = ["neutral", "calm", "happy", "sad", "angry", "fearful", "disgust", "surprised"] + with torch.inference_mode(): + logits = model(input_values=wav_16k.unsqueeze(0).to(device)) + probs = torch.softmax(logits, dim=-1)[0] + pred_id = int(torch.argmax(probs).item()) + score = float(probs[pred_id].detach().cpu().item()) + distribution = {labels[i]: round(float(probs[i].detach().cpu().item()), 8) for i in range(len(labels))} + return labels[pred_id], score, distribution + + +class AudioEmotionRecognize(Mapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.backend = str(kwargs.get("backend", "hf")).strip().lower() + self.hf_model_dir = str(kwargs.get("hfModelDir", DEFAULT_HF_MODEL_DIR)).strip() + self.small_checkpoint = str(kwargs.get("smallCheckpoint", DEFAULT_SMALL_CHECKPOINT)).strip() + 
self.device = str(kwargs.get("device", "auto")).strip().lower() + self.keep_audio = str(kwargs.get("keepAudio", "true")).strip().lower() in {"1", "true", "yes", "y", "on"} + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + start = time.time() + quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key) + if quality_skip_reason: + return mark_skipped_sample( + sample, + quality_skip_reason, + self.__class__.__name__, + self.text_key, + self.data_key, + self.filetype_key, + self.target_type_key, + self.ext_params_key, + ) + + if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key): + return mark_skipped_sample( + sample, + "non_audio_or_reference_file", + self.__class__.__name__, + self.text_key, + self.data_key, + self.filetype_key, + self.target_type_key, + self.ext_params_key, + ) + + device = _detect_device(self.device) + data = sample.get(self.data_key) + audio_bytes = b"" + with tempfile.TemporaryDirectory(prefix="dm_audio_emotion_") as td: + work_dir = Path(td) + if isinstance(data, (bytes, bytearray)) and data: + audio_bytes = bytes(data) + audio_path = work_dir / f"input.{_audio_ext(sample)}" + audio_path.write_bytes(audio_bytes) + else: + audio_path = Path(str(sample.get(self.filepath_key, ""))).expanduser().resolve() + if not audio_path.exists(): + raise FileNotFoundError(f"输入音频不存在: {audio_path}") + if self.keep_audio or self.is_last_op: + audio_bytes = audio_path.read_bytes() + wav = _load_wav_16k_mono(audio_path) + + backend = self.backend + if backend not in {"hf", "small"}: + raise ValueError(f"不支持的情感识别后端: {self.backend}") + if backend == "small": + checkpoint = _resolve_model_dir(self.small_checkpoint, Path(DEFAULT_SMALL_CHECKPOINT)) + if not checkpoint.exists(): + raise FileNotFoundError(f"情感识别 small checkpoint 不存在: {checkpoint}") + model = _load_small_model(checkpoint, device) + pred_en, score, distribution = _predict_small(model, wav, device) + model_path = 
str(checkpoint) + else: + model_dir = _resolve_model_dir(self.hf_model_dir, Path(DEFAULT_HF_MODEL_DIR)) + if not model_dir.exists(): + raise FileNotFoundError(f"情感识别 HF 模型目录不存在: {model_dir}") + model, feature_extractor = _load_hf_model(model_dir, device) + pred_en, score, distribution = _predict_hf(model, feature_extractor, wav, device) + model_path = str(model_dir) + + pred_zh = _zh_mapping().get(pred_en, pred_en) + key = _sample_key(sample, Path(str(sample.get(self.filepath_key, "sample"))), self.filename_key) + result = { + "key": key, + "pred_en": pred_en, + "pred_zh": pred_zh, + "score": round(float(score), 8), + "distribution": distribution, + "backend": backend, + "model_path": model_path, + "device": str(device), + } + + ext = sample.get(self.ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext["audio_emotion_recognize"] = result + sample[self.ext_params_key] = ext + + target_ext = _audio_ext(sample) + if audio_bytes: + sample[self.data_key] = audio_bytes + sample[self.text_key] = "" + if self.is_last_op: + sample[self.filetype_key] = "txt" + sample[self.target_type_key] = target_ext + else: + sample[self.filetype_key] = target_ext + sample[self.target_type_key] = target_ext + _mark_emotion_filename(sample, self.filename_key, str(result.get("pred_en") or "unknown"), target_ext) + + logger.info( + f"fileName: {sample.get(self.filename_key)}, method: AudioEmotionRecognize costs {time.time() - start:6f} s" + ) + return sample diff --git a/runtime/ops/mapper/audio_emotion_recognize/requirements.txt b/runtime/ops/mapper/audio_emotion_recognize/requirements.txt new file mode 100644 index 00000000..2020b0d3 --- /dev/null +++ b/runtime/ops/mapper/audio_emotion_recognize/requirements.txt @@ -0,0 +1,8 @@ +torch +torchaudio +transformers +safetensors +soundfile +scipy +numpy +loguru diff --git a/runtime/ops/mapper/audio_fast_lang_id/README.md b/runtime/ops/mapper/audio_fast_lang_id/README.md new file mode 100644 index 00000000..ff3909fc --- 
/dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id/README.md @@ -0,0 +1,40 @@ +# AudioFastLangId 快速语言识别(中英)算子 + +## 概述 + +AudioFastLangId 用于对单个音频文件做快速语言识别(仅输出 `zh/en`),复用 `audio_preprocessor/src/utils/fast_lang_id.py` 的 SpeechBrain 推理逻辑。算子会把语言结果写入 `ext_params.audio_lid.lang`,并保持当前音频作为输出。 + +## 功能特性 + +- **快速推理**:支持只截取前 N 秒进行判断 +- **仅输出 zh/en**:中文相关语言码统一映射为 `zh`,其他映射为 `en` +- **链路友好**:写入 `ext_params`,保留当前音频给后续 ASR 使用,并在文件名写入 `__lid_zh/en` +- **单独可用**:作为最后一个节点时导出当前音频,并在文件名中追加 `__lid_zh/en` +- **结构化输出**:结果同步写入 `ext_params.audio_lid.lang` + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| modelSource | input | /models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa | SpeechBrain LID 本地模型目录 | +| modelSavedir | input | /models/AudioOperations/lid/_speechbrain_cache | 模型缓存目录 | +| device | select | cpu | 推理设备(cpu/cuda/npu) | +| batchSize | inputNumber | 1 | 批大小(单文件时通常为 1) | +| maxSeconds | inputNumber | 3.0 | 只取前 N 秒做判断,0=全长 | + +## 输入输出 + +- **输入**:优先使用上游 `sample["data"]` 音频字节;否则使用 `sample["filePath"]` +- **输出**: + - 保留当前音频内容,并写入 `ext_params.audio_lid.lang` + - 导出或传递时文件名追加 `__lid_zh/en` + - `sample["ext_params"]["audio_lid"]["lang"] = "zh" | "en"` + +## 依赖说明 + +- **Python 依赖**:`torch`、`torchaudio`、`speechbrain` +- **模型依赖**:SpeechBrain LID 权重需在固定本地目录中可访问 + +## 版本历史 + +- **v1.0.0**:首次发布,支持中英二分类 LID 输出 diff --git a/runtime/ops/mapper/audio_fast_lang_id/__init__.py b/runtime/ops/mapper/audio_fast_lang_id/__init__.py new file mode 100644 index 00000000..0b49c248 --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioFastLangId', + module_path="ops.mapper.audio_fast_lang_id.process") diff --git a/runtime/ops/mapper/audio_fast_lang_id/audio_skip.py b/runtime/ops/mapper/audio_fast_lang_id/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ 
b/runtime/ops/mapper/audio_fast_lang_id/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) or "")).stem.lower() + marker = "__quality_invalid" + if marker in marker_source: + reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio" + return f"invalid_audio_quality:{reason}" + + ext = 
def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped *sample* and blank its payload fields.

    The reason is stored at ``sample[ext_params_key]["audio_skip"][op_name]``.
    The text, data, file-type and target-type fields are reset to empty
    values.  The sample is modified in place and also returned.

    Args:
        sample: Sample dictionary to update.
        reason: Why the sample was skipped.
        op_name: Name of the operator that skipped it.
        text_key, data_key, filetype_key, target_type_key, ext_params_key:
            Names of the corresponding fields in *sample*.

    Returns:
        The same *sample* object.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        # Keep whatever was there instead of discarding it.
        ext = {"_raw": ext}

    # Apply the same guard to the nested record: a non-dict ``audio_skip``
    # (None, a string, ...) would otherwise raise TypeError on assignment.
    skips = ext.get("audio_skip")
    if not isinstance(skips, dict):
        skips = {} if skips is None else {"_raw": skips}
        ext["audio_skip"] = skips
    skips[op_name] = reason

    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
# Put the sibling scripts directory on sys.path and import the colour helpers
# (same style as generate_audio_list.py).  When they cannot be imported, plain
# bracket-tagged formatters are used instead.
try:
    sys.path.insert(0, str(Path(__file__).parent.parent.parent / "scripts" / "audio_convert"))
    from color_utils import info, warning, error, ok, success, header  # type: ignore
except Exception:
    def info(msg: str) -> str:
        return f"[INFO] {msg}"

    def warning(msg: str) -> str:
        return f"[WARNING] {msg}"

    def error(msg: str) -> str:
        return f"[ERROR] {msg}"

    def ok(msg: str) -> str:
        return f"[OK] {msg}"

    def success(msg: str) -> str:
        return f"[SUCCESS] {msg}"

    def header(msg: str) -> str:
        return f"=== {msg} ==="


# The print wrappers are the same whichever formatter set is active, so they
# are defined once here rather than once per branch.
def print_info(msg: str):
    print(info(msg))


def print_warning(msg: str):
    print(warning(msg))


def print_error(msg: str):
    print(error(msg))


def print_ok(msg: str):
    print(ok(msg))


def print_success(msg: str):
    print(success(msg))


def print_header(msg: str):
    print(header(msg))
None and not hasattr(loader, "max_depth"): + setattr(loader, "max_depth", 1000) + except Exception: + pass + + +def _find_audio_files(audio_dir: Path) -> List[Path]: + patterns = ["*.wav", "*.WAV", "*.flac", "*.FLAC", "*.mp3", "*.MP3", "*.aac", "*.AAC", "*.m4a", "*.M4A"] + files: List[Path] = [] + for pat in patterns: + files.extend(audio_dir.rglob(pat)) + return sorted(set(files)) + + +def _load_jsonl_items(path: Path, filter_ok_only: bool = False) -> List[Dict]: + items: List[Dict] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + items.append(json.loads(line)) + + if not filter_ok_only: + return items + + filtered = [it for it in items if it.get("quality_flag", "ok") == "ok"] + if not items: + return items + print_info(f"质量过滤后保留 {len(filtered)}/{len(items)} 条,仅识别 quality_flag=='ok' 的音频") + return filtered + + +def _dump_jsonl_items(path: Path, items: Iterable[Dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + for it in items: + f.write(json.dumps(it, ensure_ascii=False) + "\n") + + +def _iso_to_zh_en(lid_label: str) -> str: + """ + 将 LID 模型输出映射为仅两种:zh(中文)或 en(英文)。 + 模型可能返回 "en: English"、"zh: Chinese" 等,取冒号前作为语言码再判断。 + 中文相关 ISO 码映射为 zh,其余一律为 en。 + """ + raw = (lid_label or "").strip() + if ":" in raw: + iso = raw.split(":", 1)[0].strip().lower() + else: + iso = raw.lower() + zh_aliases = {"zh", "cmn", "yue", "wuu", "nan", "cdo", "cjy", "hsn", "hak"} + if iso in zh_aliases: + return "zh" + return "en" + + +def _out_item(it: Dict, lang: str) -> Dict: + """只保留 key、wav、txt、lang 四列,供输出 jsonl 使用。""" + return { + "key": it.get("key", ""), + "wav": it.get("wav") or it.get("audio") or it.get("path", ""), + "txt": it.get("txt", ""), + "lang": lang, + } + + +def _batch_iter(xs: List[Dict], batch_size: int) -> Iterable[List[Dict]]: + for i in range(0, len(xs), batch_size): + yield xs[i : i + batch_size] + + +def _lid_predict_items( + items: 
List[Dict], + model_source: str, + model_savedir: Path, + device: str, + batch_size: int, + max_seconds: float, +) -> List[Dict]: + _ensure_speechbrain_on_path() + _patch_yaml_loader_max_depth() + + # 这里延迟导入,避免只跑 --help 时加载 torch/torchaudio + import torch # type: ignore + from types import SimpleNamespace + + # 兼容旧版 torch:SpeechBrain 可能会引用 torch.amp.custom_fwd/custom_bwd + # - torch>=2.0: torch.amp.custom_fwd/custom_bwd(支持 device_type 等参数) + # - torch<2.0: torch.cuda.amp.custom_fwd/custom_bwd(签名可能更旧,不支持 device_type) + try: + has_amp = hasattr(torch, "amp") + has_custom_fwd = has_amp and hasattr(torch.amp, "custom_fwd") + has_custom_bwd = has_amp and hasattr(torch.amp, "custom_bwd") + if not (has_custom_fwd and has_custom_bwd): + try: + from torch.cuda.amp import custom_fwd as _custom_fwd # type: ignore + from torch.cuda.amp import custom_bwd as _custom_bwd # type: ignore + except Exception: + # 退化为 no-op 装饰器(不启用 AMP 也能推理) + def _custom_fwd(*_args, **_kwargs): # type: ignore + def _decorator(fn): + return fn + + return _decorator + + def _custom_bwd(*_args, **_kwargs): # type: ignore + def _decorator(fn): + return fn + + return _decorator + + if not hasattr(torch, "amp"): + torch.amp = SimpleNamespace() # type: ignore[attr-defined] + + def _drop_unsupported_kwargs(deco): # type: ignore + def _wrapped(*args, **kwargs): + # 旧版 deco 可能不支持 device_type 等 kwargs;这里直接丢弃所有 kwargs + # 保证能作为装饰器正常使用 + return deco(*args) + + return _wrapped + + torch.amp.custom_fwd = _drop_unsupported_kwargs(_custom_fwd) # type: ignore[attr-defined] + torch.amp.custom_bwd = _drop_unsupported_kwargs(_custom_bwd) # type: ignore[attr-defined] + except Exception: + # 不让兼容逻辑影响主流程;真正的导入错误会在后面暴露 + pass + + from speechbrain.inference.classifiers import EncoderClassifier # type: ignore + + # 使用本地目录:/abs/path/to/model_dir + src_path = Path(model_source) + is_local_dir = src_path.exists() and src_path.is_dir() + resolved_source = str(src_path.resolve()) if is_local_dir else model_source + + overrides = {} 
+ if is_local_dir: + # hyperparams.yaml 里的 pretrained_path 可能不是本地路径,这里强制指向本地目录。 + overrides = {"pretrained_path": resolved_source} + + # 预先检查必需权重是否存在,避免长时间卡在 fetch/重试 + required = ["hyperparams.yaml", "label_encoder.txt", "embedding_model.ckpt", "classifier.ckpt"] + missing = [fn for fn in required if not (src_path / fn).exists()] + if missing: + raise RuntimeError( + "本地 LID 模型目录不完整,缺少必要文件:\n" + + "\n".join([f"- {src_path / fn}" for fn in missing]) + + "\n\n请检查本地模型目录是否完整。" + ) + try: + classifier = EncoderClassifier.from_hparams( + source=resolved_source, + savedir=str(model_savedir), + run_opts={"device": device}, + overrides=overrides, + ) + except Exception as e: + raise RuntimeError( + "加载 SpeechBrain LID 模型失败。\n" + f"- source={model_source}\n" + f"- savedir={model_savedir}\n" + f"- device={device}\n" + f"- error={type(e).__name__}: {e}" + ) from e + + out_items: List[Dict] = [] + total = len(items) + done = 0 + + for batch in _batch_iter(items, batch_size): + wav_tensors: List[torch.Tensor] = [] + wav_lens: List[float] = [] + ok_mask: List[bool] = [] + + for it in batch: + wav_path = it.get("wav") or it.get("audio") or it.get("path") + if not wav_path: + ok_mask.append(False) + continue + try: + sig = classifier.load_audio(str(wav_path)) + # sig: [time] 或 [channels, time],speechbrain load_audio 通常返回 [time] + if sig.ndim > 1: + sig = sig.mean(dim=0) + if max_seconds > 0: + max_samples = int(16000 * max_seconds) + sig = sig[:max_samples] + if sig.numel() == 0: + ok_mask.append(False) + continue + wav_tensors.append(sig) + wav_lens.append(float(sig.shape[0])) + ok_mask.append(True) + except Exception: + ok_mask.append(False) + + if not wav_tensors: + for it in batch: + out_items.append(_out_item(it, "en")) + done += len(batch) + continue + + max_len = max(int(x.shape[0]) for x in wav_tensors) + padded = torch.zeros((len(wav_tensors), max_len), dtype=torch.float32) + lens_rel = torch.zeros((len(wav_tensors),), dtype=torch.float32) + for i, sig in 
def parse_arguments():
    """Build the LID command-line parser and parse the current ``sys.argv``.

    All path defaults are derived from ``_project_root()``. ``--input_list``
    and ``--audio_dir`` are registered in one mutually exclusive group; the
    remaining options are plain parser arguments.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    root = _project_root()
    lid_models = root / "models" / "lid"
    local_model = lid_models / "speechbrain_lang-id-voxlingua107-ecapa"
    cache_dir = lid_models / "_speechbrain_cache" / "lang-id-voxlingua107-ecapa"
    denoise_dir = root / "output_data" / "denoise"
    quality_list = root / "output_data" / "denoise" / "item_with_quality.list"
    output_list = root / "output_data" / "lid" / "item_with_lang.list"

    parser = argparse.ArgumentParser(
        description="超快速中英语言识别(SpeechBrain),仅输出 zh/en",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=rf"""
示例:
  # 默认:直接扫描 output_data/denoise 下所有音频
  python -m src.utils.fast_lang_id

  # 启用质量过滤:默认读取 item_with_quality.list,并且仅识别 ok 音频
  python -m src.utils.fast_lang_id --filter-audio=True

  # 启用质量过滤,但自定义过滤列表路径
  python -m src.utils.fast_lang_id --filter-audio=True --filter-audio-list ./somewhere/item_with_quality.list

  # 显式指定输入列表
  python -m src.utils.fast_lang_id --input_list ./output_data/denoise/item.list
        """,
    )

    # The two input selectors cannot be combined on one command line.
    source_group = parser.add_mutually_exclusive_group(required=False)
    source_group.add_argument(
        "--input_list",
        "-i",
        default=None,
        help="输入列表文件(jsonl,每行包含 wav 字段;若包含 quality_flag 字段则仅使用 quality_flag=='ok' 的条目)",
    )
    source_group.add_argument(
        "--audio_dir",
        "-a",
        default=str(denoise_dir),
        help=f"直接扫描目录下音频文件,默认: {denoise_dir}",
    )

    # Remaining options, registered in the same order as they appear in --help.
    plain_options = (
        (("--output", "-o"), {"default": str(output_list), "help": f"输出列表文件路径,默认: {output_list}"}),
        (
            ("--filter-audio",),
            {"default": "False", "help": "是否启用质量过滤;True 时默认读取 item_with_quality.list 并只识别 ok 音频"},
        ),
        (
            ("--filter-audio-list",),
            {"default": str(quality_list), "help": f"质量过滤列表路径,默认: {quality_list}"},
        ),
        (("--model_source",), {"default": str(local_model), "help": "SpeechBrain LID 本地模型目录。"}),
        (("--model_savedir",), {"default": str(cache_dir), "help": f"模型缓存目录,默认: {cache_dir}"}),
        (("--device",), {"default": "cpu", "help": "推理设备,例如 cpu / cuda / npu(取决于 torch 环境)"}),
        (("--batch_size",), {"type": int, "default": 8, "help": "批大小(越大越快,但更吃内存)"}),
        (("--max_seconds",), {"type": float, "default": 3.0, "help": "只取音频前 N 秒做判断,0 表示全长"}),
    )
    for flags, spec in plain_options:
        parser.add_argument(*flags, **spec)

    return parser.parse_args()
Path(args.audio_dir).resolve() + if not audio_dir.exists(): + print_error(f"音频目录不存在: {audio_dir}") + return 1 + print_info(f"扫描目录: {audio_dir}") + audio_files = _find_audio_files(audio_dir) + if not audio_files: + print_warning("未找到任何音频文件") + return 0 + items = [{"key": p.stem, "wav": str(p.resolve()), "txt": ""} for p in audio_files] + else: + audio_dir = Path(args.audio_dir).resolve() + if not audio_dir.exists(): + print_error(f"音频目录不存在: {audio_dir}") + return 1 + print_info(f"扫描目录: {audio_dir}") + audio_files = _find_audio_files(audio_dir) + if not audio_files: + print_warning("未找到任何音频文件") + return 0 + items = [{"key": p.stem, "wav": str(p.resolve()), "txt": ""} for p in audio_files] + + if not items: + print_warning("输入为空,退出") + return 0 + + print_info(f"待识别音频数: {len(items)}") + print_info(f"模型: {args.model_source}") + print_info(f"模型缓存目录: {model_savedir}") + print_info(f"device={args.device}, batch_size={args.batch_size}, max_seconds={args.max_seconds}") + + try: + out_items = _lid_predict_items( + items=items, + model_source=args.model_source, + model_savedir=model_savedir, + device=args.device, + batch_size=max(1, int(args.batch_size)), + max_seconds=float(args.max_seconds), + ) + except Exception as e: + print_error(f"LID 推理失败: {e}") + print_error("traceback:\n" + traceback.format_exc()) + return 1 + + _dump_jsonl_items(output_path, out_items) + print_success(f"完成!输出: {output_path}") + + stat: Dict[str, int] = {"zh": 0, "en": 0} + for it in out_items: + stat[str(it.get("lang", "en"))] = stat.get(str(it.get("lang", "en")), 0) + 1 + print_info(f"统计: zh={stat.get('zh', 0)}, en={stat.get('en', 0)}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/runtime/ops/mapper/audio_fast_lang_id/metadata.yml b/runtime/ops/mapper/audio_fast_lang_id/metadata.yml new file mode 100644 index 00000000..8471e140 --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id/metadata.yml @@ -0,0 +1,67 @@ +name: 'audioOps-快速语言识别(中英)' +name_en: 
'audioOps-Fast Language ID (zh/en)' +description: '调用 SpeechBrain LID 对当前输入音频识别 zh/en;写入 ext_params.audio_lid.lang,并保持音频作为当前样本输出。' +description_en: 'Run SpeechBrain LID for zh/en; writes ext_params.audio_lid.lang and keeps the current audio as the sample output.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioFastLangId' +version: '1.0.0' +types: + - 'annotation' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + modelSource: + name: '模型源' + description: 'SpeechBrain LID 本地模型目录。' + type: 'input' + defaultVal: '/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa' + required: false + modelSavedir: + name: '模型缓存目录' + description: 'SpeechBrain 模型缓存目录(可选)。' + type: 'input' + defaultVal: '/models/AudioOperations/lid/_speechbrain_cache' + required: false + device: + name: '设备' + description: 'cpu/cuda/npu 等(取决于 torch 环境)。' + type: 'select' + defaultVal: 'cpu' + required: true + options: + - label: 'cpu' + value: 'cpu' + - label: 'cuda' + value: 'cuda' + - label: 'npu' + value: 'npu' + batchSize: + name: '批大小' + type: 'inputNumber' + description: '批大小(单文件时意义不大)。' + defaultVal: 1 + min: 1 + max: 64 + step: 1 + maxSeconds: + name: '截断秒数' + type: 'inputNumber' + description: '只取前 N 秒做判断,0=全长。' + defaultVal: 3.0 + min: 0 + max: 60 + step: 0.5 +runtime: + memory: 2147483648 + cpu: 0.5 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_fast_lang_id/process.py b/runtime/ops/mapper/audio_fast_lang_id/process.py new file mode 100644 index 00000000..4562bd0a --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id/process.py @@ -0,0 +1,178 @@ +# -- encoding: utf-8 -- + +import json +import re +import tempfile +import time +from pathlib import Path +from typing import Dict, Any + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except 
class AudioFastLangId(Mapper):
    """Mapper that tags one audio sample with a zh/en language label.

    The label is obtained by running the bundled ``fast_lang_id`` script on a
    single-item list, and is written to ``ext_params["audio_lid"]["lang"]``.
    The audio bytes stay on the sample, and the file name gets a
    ``__lid_<lang>`` marker via ``_mark_lid_filename``.
    """

    def __init__(self, *args, **kwargs):
        """Read operator settings from ``kwargs`` (camelCase keys)."""
        super().__init__(*args, **kwargs)
        self.model_source = str(kwargs.get("modelSource", "")).strip()
        self.model_savedir = str(kwargs.get("modelSavedir", "")).strip()
        self.device = str(kwargs.get("device", "cpu")).strip()
        # int(float(...)) accepts values such as "8.0".
        self.batch_size = int(float(kwargs.get("batchSize", 1)))
        self.max_seconds = float(kwargs.get("maxSeconds", 3.0))

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Identify the language of ``sample`` and annotate it in place.

        Args:
            sample: Sample dict; audio comes from the bytes under ``data_key``
                when present and non-empty, otherwise from the file at
                ``filepath_key``.

        Returns:
            The same ``sample`` object, either marked as skipped or annotated
            with the detected language.

        Raises:
            FileNotFoundError: If no in-memory audio is present and the file
                path is empty or does not point to a regular file.
            RuntimeError: If ``fast_lang_id.main()`` returns non-zero or
                produces no output row.
        """
        start = time.time()

        # Soft-skip samples that an upstream quality filter flagged as invalid.
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        import sys

        # Make the bundled helpers/utils directory importable, then load the script.
        package_root = _audio_preprocessor_root()
        utils_dir = package_root / "helpers" / "utils"
        if str(utils_dir) not in sys.path:
            sys.path.insert(0, str(utils_dir))

        import fast_lang_id  # type: ignore

        with tempfile.TemporaryDirectory(prefix="dm_audio_lid_") as td:
            work_dir = Path(td)
            data = sample.get(self.data_key)
            audio_bytes_for_export = b""
            if isinstance(data, (bytes, bytearray)) and data:
                # In-memory audio: materialise it so the script can read it by path.
                audio_bytes_for_export = bytes(data)
                wav_path = work_dir / f"input.{_audio_ext(sample)}"
                wav_path.write_bytes(audio_bytes_for_export)
            else:
                raw_path = sample.get(self.filepath_key) or ""
                # An empty path would resolve to the current directory, which
                # exists; reject it explicitly and require a regular file.
                if not str(raw_path).strip():
                    raise FileNotFoundError("输入音频不存在: filePath 为空")
                wav_path = Path(raw_path).resolve()
                if not wav_path.is_file():
                    raise FileNotFoundError(f"输入音频不存在: {wav_path}")
                audio_bytes_for_export = wav_path.read_bytes()

            out_path = work_dir / "item_with_lang.list"
            in_list = work_dir / "single_item.list"
            in_list.write_text(
                json.dumps({"key": wav_path.stem, "wav": str(wav_path), "txt": ""}, ensure_ascii=False) + "\n",
                encoding="utf-8",
            )

            # Reuse the script's CLI parsing by temporarily replacing sys.argv.
            # sys.argv is process-global, so it is restored in ``finally``.
            argv_backup = sys.argv[:]
            try:
                sys.argv = [
                    sys.argv[0],
                    "--input_list",
                    str(in_list),
                    "--output",
                    str(out_path),
                    "--device",
                    self.device,
                    "--batch_size",
                    str(max(1, self.batch_size)),
                    "--max_seconds",
                    str(self.max_seconds),
                ]
                model_source = _resolve_lid_model_source(self.model_source, package_root)
                model_savedir = self.model_savedir or DEFAULT_LID_MODEL_SAVEDIR
                sys.argv += ["--model_source", model_source, "--model_savedir", model_savedir]

                rc = fast_lang_id.main()
                if rc != 0:
                    raise RuntimeError(f"fast_lang_id 失败,返回码: {rc}")
            finally:
                sys.argv = argv_backup

            # Read the result before the temporary directory is removed.
            if not out_path.exists():
                raise RuntimeError(f"LID 输出不存在: {out_path}")
            lines = [line.strip() for line in out_path.read_text(encoding="utf-8").splitlines() if line.strip()]
            if not lines:
                raise RuntimeError(f"LID 输出为空: {out_path}")
            d = json.loads(lines[0])
            lang = str(d.get("lang", "en"))

        ext = sample.get(self.ext_params_key, {})
        if not isinstance(ext, dict):
            # Preserve a non-dict payload instead of discarding it.
            ext = {"_raw": ext}
        ext["audio_lid"] = {"lang": lang}
        sample[self.ext_params_key] = ext

        target_ext = _audio_ext(sample)
        if audio_bytes_for_export:
            sample[self.data_key] = audio_bytes_for_export
        sample[self.text_key] = ""
        if self.is_last_op:
            # NOTE(review): fileType becomes "txt" while target_type keeps the
            # audio extension when this is the last op — confirm against the exporter.
            sample[self.filetype_key] = "txt"
            sample[self.target_type_key] = target_ext
        else:
            sample[self.filetype_key] = target_ext
            sample[self.target_type_key] = target_ext
        _mark_lid_filename(sample, self.filename_key, lang, target_ext)

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioFastLangId costs {time.time() - start:6f} s"
        )
        return sample
+## 功能特性 + +- **快速推理**:支持只截取前 N 秒进行判断 +- **仅输出 zh/en**:中文相关语言码统一映射为 `zh`,其他映射为 `en` +- **一入一出**:每个输入音频输出一个 `.txt`,内容为 `zh` 或 `en` +- **结构化输出**:结果同步写入 `ext_params.audio_lid.lang` + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| modelSource | input | /models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa | SpeechBrain LID 本地模型目录 | +| modelSavedir | input | /models/AudioOperations/lid/_speechbrain_cache | 模型缓存目录 | +| device | select | cpu | 推理设备(cpu/cuda/npu) | +| batchSize | inputNumber | 1 | 批大小(单文件时通常为 1) | +| maxSeconds | inputNumber | 3.0 | 只取前 N 秒做判断,0=全长 | + +## 输入输出 + +- **输入**:`sample["filePath"]` +- **输出**: + - `sample["text"] = "zh" | "en"`,并导出为当前输入文件对应的 `.txt` + - `sample["ext_params"]["audio_lid"]["lang"] = "zh" | "en"` + +## 依赖说明 + +- **Python 依赖**:`torch`、`torchaudio`、`speechbrain` +- **模型依赖**:SpeechBrain LID 权重需在固定本地目录中可访问 + +## 版本历史 + +- **v1.0.0**:首次发布,支持中英二分类 LID 输出 diff --git a/runtime/ops/mapper/audio_fast_lang_id_text/__init__.py b/runtime/ops/mapper/audio_fast_lang_id_text/__init__.py new file mode 100644 index 00000000..4a818e6e --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id_text/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioFastLangIdText', + module_path="ops.mapper.audio_fast_lang_id_text.process") diff --git a/runtime/ops/mapper/audio_fast_lang_id_text/audio_skip.py b/runtime/ops/mapper/audio_fast_lang_id_text/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id_text/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for 
def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when ``sample`` is flagged as invalid audio.

    Two sources are checked, in order:

    1. A ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath`` (case-insensitive). Text after the
       marker, with underscores trimmed, becomes the reason.
    2. ``sample[ext_params_key]["audio_quality"]`` with ``quality_flag`` equal
       to ``"invalid"`` and a truthy ``skip_downstream`` (default ``True``;
       strings are parsed as ``1/true/yes/y/on``).

    Args:
        sample: Sample dict to inspect.
        ext_params_key: Key under which the structured quality report lives.

    Returns:
        ``"invalid_audio_quality:<reason>"`` when the sample should be skipped,
        otherwise an empty string. A missing, empty or whitespace-only reason
        falls back to ``"invalid_audio"``.
    """
    marker = "__quality_invalid"
    for key in ("fileName", "sourceFileName", "filePath"):
        marker_source = Path(str(sample.get(key) or "")).stem.lower()
        if marker in marker_source:
            reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""

    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"}
    if not skip_downstream:
        return ""

    # Strip before defaulting so a whitespace-only reason does not produce an
    # empty suffix after the colon.
    reason = str(quality.get("reason") or "").strip() or "invalid_audio"
    return f"invalid_audio_quality:{reason}"
f"[ERROR] {msg}" + + def ok(msg: str) -> str: + return f"[OK] {msg}" + + def success(msg: str) -> str: + return f"[SUCCESS] {msg}" + + def header(msg: str) -> str: + return f"=== {msg} ===" + + def print_info(msg: str): + print(info(msg)) + + def print_warning(msg: str): + print(warning(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) +else: + def print_info(msg: str): + print(info(msg)) + + def print_warning(msg: str): + print(warning(msg)) + + def print_error(msg: str): + print(error(msg)) + + def print_ok(msg: str): + print(ok(msg)) + + def print_success(msg: str): + print(success(msg)) + + def print_header(msg: str): + print(header(msg)) + + +def _project_root() -> Path: + return Path(__file__).parent.parent.parent + + +def _ensure_speechbrain_on_path() -> None: + """SpeechBrain is provided by the DataMate runtime environment.""" + return None + + +def _patch_yaml_loader_max_depth() -> None: + """兼容部分 PyYAML/HyperPyYAML 组合缺失 Loader.max_depth 的问题。""" + try: + import yaml # type: ignore + + for name in ("Loader", "SafeLoader", "FullLoader", "UnsafeLoader"): + loader = getattr(yaml, name, None) + if loader is not None and not hasattr(loader, "max_depth"): + setattr(loader, "max_depth", 1000) + except Exception: + pass + try: + import ruamel.yaml # type: ignore + + for name in ("Loader", "SafeLoader", "RoundTripLoader", "BaseLoader"): + loader = getattr(ruamel.yaml, name, None) + if loader is not None and not hasattr(loader, "max_depth"): + setattr(loader, "max_depth", 1000) + except Exception: + pass + + +def _find_audio_files(audio_dir: Path) -> List[Path]: + patterns = ["*.wav", "*.WAV", "*.flac", "*.FLAC", "*.mp3", "*.MP3", "*.aac", "*.AAC", "*.m4a", "*.M4A"] + files: List[Path] = [] + for pat in patterns: + files.extend(audio_dir.rglob(pat)) + return sorted(set(files)) + + +def _load_jsonl_items(path: 
def _lid_predict_items(
    items: List[Dict],
    model_source: str,
    model_savedir: Path,
    device: str,
    batch_size: int,
    max_seconds: float,
) -> List[Dict]:
    """Run SpeechBrain language ID over ``items`` and label each row zh/en.

    Args:
        items: Input rows. The audio path is taken from the first non-empty
            value among the ``wav``, ``audio`` and ``path`` keys.
        model_source: Local model directory, or any other source string, which
            is passed to ``EncoderClassifier.from_hparams`` unchanged.
        model_savedir: Directory handed to ``from_hparams`` as ``savedir``.
        device: Device string forwarded through ``run_opts``.
        batch_size: Number of items per ``classify_batch`` call.
        max_seconds: Keep only the first N seconds of each clip; values
            ``<= 0`` keep the whole clip.

    Returns:
        One ``_out_item`` row per input item, in input order. Items whose audio
        path is missing, unreadable or empty are labelled ``"en"`` and a
        warning is printed for each of them.

    Raises:
        RuntimeError: If a local model directory lacks a required file, or the
            classifier cannot be loaded.
    """
    _ensure_speechbrain_on_path()
    _patch_yaml_loader_max_depth()

    # Deferred imports so that ``--help`` does not load torch/torchaudio.
    import torch  # type: ignore
    from types import SimpleNamespace

    # Compatibility shim: SpeechBrain may reference torch.amp.custom_fwd /
    # custom_bwd. When they are absent, fall back to torch.cuda.amp (whose
    # older signature may not accept device_type) or to no-op decorators.
    try:
        has_amp = hasattr(torch, "amp")
        has_custom_fwd = has_amp and hasattr(torch.amp, "custom_fwd")
        has_custom_bwd = has_amp and hasattr(torch.amp, "custom_bwd")
        if not (has_custom_fwd and has_custom_bwd):
            try:
                from torch.cuda.amp import custom_fwd as _custom_fwd  # type: ignore
                from torch.cuda.amp import custom_bwd as _custom_bwd  # type: ignore
            except Exception:
                # Degrade to no-op decorators; inference works without AMP.
                def _custom_fwd(*_args, **_kwargs):  # type: ignore
                    def _decorator(fn):
                        return fn

                    return _decorator

                def _custom_bwd(*_args, **_kwargs):  # type: ignore
                    def _decorator(fn):
                        return fn

                    return _decorator

            if not hasattr(torch, "amp"):
                torch.amp = SimpleNamespace()  # type: ignore[attr-defined]

            def _drop_unsupported_kwargs(deco):  # type: ignore
                def _wrapped(*args, **kwargs):
                    # The fallback decorator may not accept kwargs such as
                    # device_type, so every keyword argument is dropped.
                    return deco(*args)

                return _wrapped

            torch.amp.custom_fwd = _drop_unsupported_kwargs(_custom_fwd)  # type: ignore[attr-defined]
            torch.amp.custom_bwd = _drop_unsupported_kwargs(_custom_bwd)  # type: ignore[attr-defined]
    except Exception:
        # The shim is best-effort; a genuine import problem surfaces below.
        pass

    from speechbrain.inference.classifiers import EncoderClassifier  # type: ignore

    src_path = Path(model_source)
    is_local_dir = src_path.exists() and src_path.is_dir()
    resolved_source = str(src_path.resolve()) if is_local_dir else model_source

    overrides = {}
    if is_local_dir:
        # pretrained_path inside hyperparams.yaml may not be a local path, so
        # it is forced to the local directory.
        overrides = {"pretrained_path": resolved_source}

        # Fail fast on an incomplete directory instead of stalling in fetch/retry.
        required = ["hyperparams.yaml", "label_encoder.txt", "embedding_model.ckpt", "classifier.ckpt"]
        missing = [fn for fn in required if not (src_path / fn).exists()]
        if missing:
            raise RuntimeError(
                "本地 LID 模型目录不完整,缺少必要文件:\n"
                + "\n".join([f"- {src_path / fn}" for fn in missing])
                + "\n\n请检查本地模型目录是否完整。"
            )
    try:
        classifier = EncoderClassifier.from_hparams(
            source=resolved_source,
            savedir=str(model_savedir),
            run_opts={"device": device},
            overrides=overrides,
        )
    except Exception as e:
        raise RuntimeError(
            "加载 SpeechBrain LID 模型失败。\n"
            f"- source={model_source}\n"
            f"- savedir={model_savedir}\n"
            f"- device={device}\n"
            f"- error={type(e).__name__}: {e}"
        ) from e

    # NOTE(review): truncation assumes load_audio yields 16 kHz audio —
    # confirm against the model's hyperparameters.
    sample_rate = 16000

    out_items: List[Dict] = []
    total = len(items)
    done = 0

    for batch in _batch_iter(items, batch_size):
        wav_tensors: List[torch.Tensor] = []
        ok_mask: List[bool] = []

        for it in batch:
            wav_path = it.get("wav") or it.get("audio") or it.get("path")
            if not wav_path:
                print_warning(f"条目缺少音频路径,按 en 处理: key={it.get('key', '')}")
                ok_mask.append(False)
                continue
            try:
                sig = classifier.load_audio(str(wav_path))
                # sig is [time] or [channels, time]; multi-dim input is averaged
                # over the first axis.
                if sig.ndim > 1:
                    sig = sig.mean(dim=0)
                if max_seconds > 0:
                    max_samples = int(sample_rate * max_seconds)
                    sig = sig[:max_samples]
                if sig.numel() == 0:
                    print_warning(f"音频为空,按 en 处理: {wav_path}")
                    ok_mask.append(False)
                    continue
                wav_tensors.append(sig)
                ok_mask.append(True)
            except Exception as e:
                # Best-effort per item: report the failure rather than silently
                # labelling the clip, then carry on with the rest of the batch.
                print_warning(f"音频读取失败,按 en 处理: {wav_path} ({type(e).__name__}: {e})")
                ok_mask.append(False)

        if not wav_tensors:
            for it in batch:
                out_items.append(_out_item(it, "en"))
            done += len(batch)
            continue

        # Zero-pad to the longest clip and pass relative lengths alongside.
        max_len = max(int(x.shape[0]) for x in wav_tensors)
        padded = torch.zeros((len(wav_tensors), max_len), dtype=torch.float32)
        lens_rel = torch.zeros((len(wav_tensors),), dtype=torch.float32)
        for i, sig in enumerate(wav_tensors):
            L = int(sig.shape[0])
            padded[i, :L] = sig.float()
            lens_rel[i] = float(L) / float(max_len) if max_len > 0 else 1.0

        with torch.inference_mode():
            out_prob, score, index, text_lab = classifier.classify_batch(padded, lens_rel)

        # Predictions exist only for loadable items; pred_i walks them while
        # ok_mask keeps the output aligned with the input order.
        pred_i = 0
        for it, ok_ in zip(batch, ok_mask):
            if not ok_:
                out_items.append(_out_item(it, "en"))
            else:
                lid_label = str(text_lab[pred_i]) if isinstance(text_lab, list) else str(text_lab)
                lang = _iso_to_zh_en(lid_label)
                out_items.append(_out_item(it, lang))
                pred_i += 1

        done += len(batch)
        if done % max(10, batch_size) == 0 or done == total:
            print_info(f"LID 进度: {done}/{total}")

    return out_items
def main() -> int:
    """Entry point: collect input items, run LID, and write the jsonl output.

    Input selection, in priority order:

    1. ``--input_list`` when given (rows are filtered to
       ``quality_flag == "ok"`` only if ``--filter-audio`` is enabled).
    2. The quality list (``--filter-audio-list``) when ``--filter-audio`` is
       enabled and that file exists.
    3. A recursive scan of ``--audio_dir``.

    Returns:
        ``0`` on success or when there is nothing to process; ``1`` when the
        input list or audio directory does not exist, or inference fails.
    """
    args = parse_arguments()
    print_header("快速语言识别(LID)")

    output_path = Path(args.output).resolve()
    model_savedir = Path(args.model_savedir).resolve()
    filter_audio = str(args.filter_audio).lower() in {"1", "true", "yes", "y", "on"}
    filter_audio_list = Path(args.filter_audio_list).resolve()

    # Load items: an explicit list, the quality list, or a directory scan.
    items: List[Dict]
    if args.input_list:
        input_path = Path(args.input_list).resolve()
        if not input_path.exists():
            print_error(f"输入列表不存在: {input_path}")
            return 1
        print_info(f"输入列表: {input_path}")
        items = _load_jsonl_items(input_path)
        if filter_audio:
            items = [it for it in items if it.get("quality_flag", "ok") == "ok"]
    elif filter_audio and filter_audio_list.exists():
        print_info(f"启用质量过滤,读取列表: {filter_audio_list}")
        items = _load_jsonl_items(filter_audio_list, filter_ok_only=True)
    else:
        if filter_audio:
            # Filtering was requested but the list is missing; scan instead.
            print_warning(f"质量过滤列表不存在,回退为扫描目录: {filter_audio_list}")
        audio_dir = Path(args.audio_dir).resolve()
        if not audio_dir.exists():
            print_error(f"音频目录不存在: {audio_dir}")
            return 1
        print_info(f"扫描目录: {audio_dir}")
        audio_files = _find_audio_files(audio_dir)
        if not audio_files:
            print_warning("未找到任何音频文件")
            return 0
        items = [{"key": p.stem, "wav": str(p.resolve()), "txt": ""} for p in audio_files]

    if not items:
        print_warning("输入为空,退出")
        return 0

    print_info(f"待识别音频数: {len(items)}")
    print_info(f"模型: {args.model_source}")
    print_info(f"模型缓存目录: {model_savedir}")
    print_info(f"device={args.device}, batch_size={args.batch_size}, max_seconds={args.max_seconds}")

    try:
        out_items = _lid_predict_items(
            items=items,
            model_source=args.model_source,
            model_savedir=model_savedir,
            device=args.device,
            batch_size=max(1, int(args.batch_size)),
            max_seconds=float(args.max_seconds),
        )
    except Exception as e:
        # Top-level boundary: report the failure and turn it into an exit code.
        print_error(f"LID 推理失败: {e}")
        print_error("traceback:\n" + traceback.format_exc())
        return 1

    _dump_jsonl_items(output_path, out_items)
    print_success(f"完成!输出: {output_path}")

    # Per-language tally for the summary line.
    stat: Dict[str, int] = {"zh": 0, "en": 0}
    for it in out_items:
        lang = str(it.get("lang", "en"))
        stat[lang] = stat.get(lang, 0) + 1
    print_info(f"统计: zh={stat.get('zh', 0)}, en={stat.get('en', 0)}")

    return 0
+language: 'python' +vendor: 'huawei' +raw_id: 'AudioFastLangIdText' +version: '1.0.0' +types: + - 'annotation' +modal: 'audio' +inputs: 'audio' +outputs: 'text' +settings: + modelSource: + name: '模型源' + description: 'SpeechBrain LID 本地模型目录。' + type: 'input' + defaultVal: '/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa' + required: false + modelSavedir: + name: '模型缓存目录' + description: 'SpeechBrain 模型缓存目录(可选)。' + type: 'input' + defaultVal: '/models/AudioOperations/lid/_speechbrain_cache' + required: false + device: + name: '设备' + description: 'cpu/cuda/npu 等(取决于 torch 环境)。' + type: 'select' + defaultVal: 'cpu' + required: true + options: + - label: 'cpu' + value: 'cpu' + - label: 'cuda' + value: 'cuda' + - label: 'npu' + value: 'npu' + batchSize: + name: '批大小' + type: 'inputNumber' + description: '批大小(单文件时意义不大)。' + defaultVal: 1 + min: 1 + max: 64 + step: 1 + maxSeconds: + name: '截断秒数' + type: 'inputNumber' + description: '只取前 N 秒做判断,0=全长。' + defaultVal: 3.0 + min: 0 + max: 60 + step: 0.5 +runtime: + memory: 2147483648 + cpu: 0.5 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_fast_lang_id_text/process.py b/runtime/ops/mapper/audio_fast_lang_id_text/process.py new file mode 100644 index 00000000..f44f359d --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id_text/process.py @@ -0,0 +1,157 @@ +# -- encoding: utf-8 -- + +import json +import tempfile +import time +from pathlib import Path +from typing import Dict, Any + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + +DEFAULT_LID_MODEL_SOURCE = "/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa" +DEFAULT_LID_MODEL_SAVEDIR = 
"/models/AudioOperations/lid/_speechbrain_cache" + + +def _repo_root() -> Path: + return Path(__file__).resolve().parent + + +def _audio_preprocessor_root() -> Path: + return _repo_root() + + +def _resolve_lid_model_source(value: str, package_root: Path) -> str: + raw = str(value or "").strip() or DEFAULT_LID_MODEL_SOURCE + p = Path(raw).expanduser() + if p.exists(): + return str(p) + fallback = package_root / "models" / "lid" / "speechbrain_lang-id-voxlingua107-ecapa" + if fallback.exists(): + return str(fallback) + return raw + + +def _audio_ext(sample: Dict[str, Any], default_ext: str = "wav") -> str: + ext = str(sample.get("target_type") or sample.get("fileType") or default_ext).strip().lower().lstrip(".") + return ext or default_ext + + +class AudioFastLangIdText(Mapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_source = str(kwargs.get("modelSource", "")).strip() + self.model_savedir = str(kwargs.get("modelSavedir", "")).strip() + self.device = str(kwargs.get("device", "cpu")).strip() + self.batch_size = int(float(kwargs.get("batchSize", 1))) + self.max_seconds = float(kwargs.get("maxSeconds", 3.0)) + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + start = time.time() + quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key) + if quality_skip_reason: + return mark_skipped_sample( + sample, + quality_skip_reason, + self.__class__.__name__, + self.text_key, + self.data_key, + self.filetype_key, + self.target_type_key, + self.ext_params_key, + ) + + if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key): + return mark_skipped_sample( + sample, + "non_audio_or_reference_file", + self.__class__.__name__, + self.text_key, + self.data_key, + self.filetype_key, + self.target_type_key, + self.ext_params_key, + ) + + import sys + + package_root = _audio_preprocessor_root() + utils_dir = package_root / "helpers" / "utils" + if 
str(utils_dir) not in sys.path: + sys.path.insert(0, str(utils_dir)) + + import fast_lang_id # type: ignore + + with tempfile.TemporaryDirectory(prefix="dm_audio_lid_") as td: + work_dir = Path(td) + data = sample.get(self.data_key) + if isinstance(data, (bytes, bytearray)) and data: + wav_path = work_dir / f"input.{_audio_ext(sample)}" + wav_path.write_bytes(bytes(data)) + else: + wav_path = Path(sample.get(self.filepath_key, "")).resolve() + if not wav_path.exists(): + raise FileNotFoundError(f"输入音频不存在: {wav_path}") + + out_path = work_dir / "item_with_lang.list" + in_list = work_dir / "single_item.list" + in_list.write_text( + json.dumps({"key": wav_path.stem, "wav": str(wav_path), "txt": ""}, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + + # 组装 args,直接复用其 main() 的 CLI 解析逻辑 + argv_backup = sys.argv[:] + try: + sys.argv = [ + sys.argv[0], + "--input_list", + str(in_list), + "--output", + str(out_path), + "--device", + self.device, + "--batch_size", + str(max(1, self.batch_size)), + "--max_seconds", + str(self.max_seconds), + ] + model_source = _resolve_lid_model_source(self.model_source, package_root) + model_savedir = self.model_savedir or DEFAULT_LID_MODEL_SAVEDIR + sys.argv += ["--model_source", model_source, "--model_savedir", model_savedir] + + rc = fast_lang_id.main() + if rc != 0: + raise RuntimeError(f"fast_lang_id 失败,返回码: {rc}") + finally: + sys.argv = argv_backup + + if not out_path.exists(): + raise RuntimeError(f"LID 输出不存在: {out_path}") + lines = [line.strip() for line in out_path.read_text(encoding="utf-8").splitlines() if line.strip()] + if not lines: + raise RuntimeError(f"LID 输出为空: {out_path}") + d = json.loads(lines[0]) + lang = str(d.get("lang", "en")) + + ext = sample.get(self.ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext["audio_lid"] = {"lang": lang} + sample[self.ext_params_key] = ext + + sample[self.data_key] = b"" + sample[self.text_key] = lang + sample[self.filetype_key] = "txt" + 
sample[self.target_type_key] = "txt" + + logger.info( + f"fileName: {sample.get(self.filename_key)}, method: AudioFastLangIdText costs {time.time() - start:6f} s" + ) + return sample diff --git a/runtime/ops/mapper/audio_fast_lang_id_text/requirements.txt b/runtime/ops/mapper/audio_fast_lang_id_text/requirements.txt new file mode 100644 index 00000000..7b2c98a1 --- /dev/null +++ b/runtime/ops/mapper/audio_fast_lang_id_text/requirements.txt @@ -0,0 +1,4 @@ +torch==2.8.0 +torchaudio==2.8.0 +speechbrain==1.0.3 +HyperPyYAML==1.2.2 diff --git a/runtime/ops/mapper/audio_format_convert/README.md b/runtime/ops/mapper/audio_format_convert/README.md new file mode 100644 index 00000000..cf259ae3 --- /dev/null +++ b/runtime/ops/mapper/audio_format_convert/README.md @@ -0,0 +1,27 @@ +# AudioFormatConvert 音频格式转换与重采样算子 + +## 概述 + +AudioFormatConvert 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。作为链路中间节点时,它保持当前样本仍为音频格式,方便后续 LID/ASR 继续读取;作为最后一个算子时,最终落盘交由 DataMate 标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| targetFormat | select | wav | 目标输出格式(扩展名) | +| sampleRate | inputNumber | 16000 | 目标采样率(Hz),0 表示保持原采样率 | +| channels | inputNumber | 1 | 目标声道数:1=单声道,2=双声道,0=保持原声道 | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:`pydub==0.25.1`、`soundfile==0.12.1`、`numpy==2.2.6`,由 DataMate 运行环境提供。 +- **系统依赖**:`ffmpeg`,由 DataMate 运行环境提供,用于 mp3/aac/m4a 等格式解码与编码。 + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_format_convert/__init__.py b/runtime/ops/mapper/audio_format_convert/__init__.py new file mode 100644 index 00000000..d6a4a1a5 --- /dev/null +++ b/runtime/ops/mapper/audio_format_convert/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioFormatConvert', + 
module_path="ops.mapper.audio_format_convert.process") diff --git a/runtime/ops/mapper/audio_format_convert/audio_skip.py b/runtime/ops/mapper/audio_format_convert/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_format_convert/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = 
def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that ``op_name`` skipped this sample and blank its payload fields.

    The reason is stored under ``ext_params["audio_skip"][op_name]``; the text,
    data, file-type and target-type fields are reset to empty values.

    Args:
        sample: Sample dict, modified in place.
        reason: Why the operator skipped the sample.
        op_name: Name of the skipping operator (used as the dict key).
        text_key, data_key, filetype_key, target_type_key, ext_params_key:
            Field names to use on ``sample``.

    Returns:
        The same ``sample`` object.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}

    # Apply the same defensive wrapping to the nested entry: a non-dict value
    # left there by an upstream step would otherwise raise TypeError on
    # item assignment and abort the whole skip path.
    skipped = ext.get("audio_skip")
    if not isinstance(skipped, dict):
        skipped = {} if skipped is None else {"_raw": skipped}
        ext["audio_skip"] = skipped
    skipped[op_name] = reason

    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
exports the result.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioFormatConvert' +version: '1.0.0' +types: + - 'cleaning' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + targetFormat: + name: '目标格式' + description: '输出音频格式(扩展名),如 wav/flac/mp3/aac/m4a/ogg。' + type: 'select' + defaultVal: 'wav' + required: true + options: + - label: 'wav' + value: 'wav' + - label: 'flac' + value: 'flac' + - label: 'mp3' + value: 'mp3' + - label: 'aac' + value: 'aac' + - label: 'm4a' + value: 'm4a' + - label: 'ogg' + value: 'ogg' + sampleRate: + name: '采样率' + description: '目标采样率(Hz)。0 表示保持原采样率。' + type: 'inputNumber' + defaultVal: 16000 + min: 0 + max: 192000 + step: 1 + channels: + name: '声道数' + description: '目标声道数:1=单声道,2=双声道,0=保持原声道。' + type: 'inputNumber' + defaultVal: 1 + min: 0 + max: 2 + step: 1 +runtime: + memory: 104857600 + cpu: 0.2 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_format_convert/process.py b/runtime/ops/mapper/audio_format_convert/process.py new file mode 100644 index 00000000..e1c4f5d4 --- /dev/null +++ b/runtime/ops/mapper/audio_format_convert/process.py @@ -0,0 +1,187 @@ +# -- encoding: utf-8 -- + +import io +import time +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + +def _load_audio_backend() -> Tuple[Optional[object], Optional[object]]: + audiosegment = None + sf = None + try: + from pydub import AudioSegment # type: ignore + audiosegment = AudioSegment + except Exception: + audiosegment = None + + try: + import soundfile as _sf # type: ignore + + sf = _sf + except Exception: + sf = None + + return audiosegment, sf + + +def 
_convert_with_pydub(source: object, target_sr: int, channels: int, fmt: str) -> bytes: + audiosegment, _ = _load_audio_backend() + if audiosegment is None: + raise RuntimeError("pydub 不可用,无法使用 pydub 转换") + + if isinstance(source, (bytes, bytearray)): + audio = audiosegment.from_file(io.BytesIO(bytes(source))) + else: + audio = audiosegment.from_file(str(source)) + if target_sr and target_sr > 0: + audio = audio.set_frame_rate(int(target_sr)) + if channels == 1: + audio = audio.set_channels(1) + elif channels == 2: + audio = audio.set_channels(2) + + with io.BytesIO() as buf: + audio.export(buf, format=fmt) + return buf.getvalue() + + +def _convert_with_soundfile(source: object, target_sr: int, channels: int, fmt: str) -> bytes: + _, sf = _load_audio_backend() + if sf is None: + raise RuntimeError("soundfile 不可用,无法使用 soundfile 转换") + if fmt not in {"wav", "flac", "ogg"}: + raise RuntimeError(f"当前环境无 pydub 时不支持转换到: {fmt}") + + if isinstance(source, (bytes, bytearray)): + data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=True) + else: + data, sr = sf.read(str(source), always_2d=True) + + if channels == 1 and data.shape[1] > 1: + data = data.mean(axis=1, keepdims=True) + elif channels == 2 and data.shape[1] == 1: + data = data.repeat(2, axis=1) + + if target_sr and target_sr > 0 and int(sr) != int(target_sr): + try: + import numpy as np + + new_len = max(1, int(round(data.shape[0] * float(target_sr) / float(sr)))) + old_x = np.linspace(0.0, 1.0, num=data.shape[0], endpoint=False) + new_x = np.linspace(0.0, 1.0, num=new_len, endpoint=False) + channels_data = [ + np.interp(new_x, old_x, data[:, ch]).astype(np.float32) + for ch in range(data.shape[1]) + ] + data = np.stack(channels_data, axis=1) + sr = int(target_sr) + except Exception as e: + raise RuntimeError(f"重采样失败(需要 numpy),src_sr={sr}, target_sr={target_sr}: {e}") from e + + with io.BytesIO() as buf: + sf.write(buf, data, int(sr), format=fmt.upper()) + return buf.getvalue() + + +def _ext_from_sample(sample: 
class AudioFormatConvert(Mapper):
    """Mapper that re-encodes a sample's audio into the configured format.

    ``targetFormat``, ``sampleRate`` and ``channels`` are read from the
    operator kwargs. Conversion uses pydub when it is importable and
    soundfile otherwise; the encoded bytes are written back onto the sample.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accept ".WAV" / " wav " style values and keep a bare lower-case extension.
        self.target_format = str(kwargs.get("targetFormat", "wav")).strip().lower().lstrip(".")
        # int(float(...)) tolerates values such as "16000" or 16000.0.
        self.sample_rate = int(float(kwargs.get("sampleRate", 16000)))
        self.channels = int(float(kwargs.get("channels", 1)))

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one sample and return it.

        Samples flagged invalid by an upstream quality check, and samples that
        are not audio, are returned through ``mark_skipped_sample``.

        Raises:
            FileNotFoundError: No in-memory audio is present and the file path
                does not exist.
            Exception: Whatever the conversion backend raised, unless the
                pass-through fallback below applies.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        # ``or ""`` keeps a present-but-None filePath from crashing Path().
        in_path = Path(sample.get(self.filepath_key) or "").resolve()
        # Bytes produced by an upstream operator take precedence over the file.
        source = sample.get(self.data_key) or in_path
        if not isinstance(source, (bytes, bytearray)) and not in_path.exists():
            raise FileNotFoundError(f"输入音频不存在: {in_path}")

        source_ext = _ext_from_sample(sample, in_path.suffix.lower().lstrip(".") or self.target_format)
        audiosegment, sf = _load_audio_backend()
        converted = True
        try:
            if audiosegment is not None:
                out_bytes = _convert_with_pydub(
                    source=source,
                    target_sr=self.sample_rate,
                    channels=self.channels,
                    fmt=self.target_format,
                )
            else:
                if sf is None:
                    raise RuntimeError("pydub/soundfile 均不可用,无法转换")
                out_bytes = _convert_with_soundfile(
                    source=source,
                    target_sr=self.sample_rate,
                    channels=self.channels,
                    fmt=self.target_format,
                )
        except Exception as e:
            same_format = in_path.suffix.lower().lstrip(".") == self.target_format
            if not same_format or sample.get(self.data_key):
                raise
            # Best-effort fallback: the file already has the target extension,
            # so hand its bytes through untouched. Log it, because the
            # requested resampling / channel change has NOT been applied.
            logger.warning(
                f"fileName: {sample.get(self.filename_key)}, method: AudioFormatConvert "
                f"conversion failed, passing source file through unchanged: {e}"
            )
            out_bytes = in_path.read_bytes()
            converted = False

        sample[self.data_key] = out_bytes
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.target_format
        # NOTE(review): the last operator tags fileType as "txt" although the
        # payload is still audio — presumably required by the DataMate export
        # path; confirm against the exporter.
        if self.is_last_op:
            sample[self.filetype_key] = "txt"
        else:
            sample[self.filetype_key] = self.target_format

        ext = sample.get(self.ext_params_key, {})
        if not isinstance(ext, dict):
            ext = {"_raw": ext}
        ext["audio_format_convert"] = {
            "format": self.target_format,
            "sample_rate": self.sample_rate,
            "channels": self.channels,
            "source_ext": source_ext,
            # False when the pass-through fallback returned the source bytes.
            "converted": converted,
        }
        sample[self.ext_params_key] = ext

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioFormatConvert costs {time.time() - start:6f} s"
        )
        return sample
/models/AudioOperations/gtcrn/gtcrn.onnx + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_gtcrn_denoise/__init__.py b/runtime/ops/mapper/audio_gtcrn_denoise/__init__.py new file mode 100644 index 00000000..6a19ec78 --- /dev/null +++ b/runtime/ops/mapper/audio_gtcrn_denoise/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioGtcrnDenoise', + module_path="ops.mapper.audio_gtcrn_denoise.process") diff --git a/runtime/ops/mapper/audio_gtcrn_denoise/audio_skip.py b/runtime/ops/mapper/audio_gtcrn_denoise/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_gtcrn_denoise/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str 
def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when the sample is flagged as invalid audio, else ``""``.

    Two signals are checked in order:

    1. A ``__quality_invalid`` marker inside the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``; the text after the marker becomes
       the reason.
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"`` and
       a truthy ``skip_downstream`` (missing counts as true; strings are
       parsed as booleans).
    """
    marker = "__quality_invalid"
    for field in ("fileName", "sourceFileName", "filePath"):
        stem = Path(str(sample.get(field) or "")).stem.lower()
        _, found, tail = stem.partition(marker)
        if found:
            detail = tail.strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{detail}"

    ext = sample.get(ext_params_key, {})
    quality = ext.get("audio_quality", {}) if isinstance(ext, dict) else None
    if not isinstance(quality, dict):
        return ""

    flag = str(quality.get("quality_flag") or "").strip().lower()
    if flag != "invalid":
        return ""

    gate = quality.get("skip_downstream", True)
    if isinstance(gate, str):
        gate = gate.strip().lower() in {"1", "true", "yes", "y", "on"}
    if not gate:
        return ""

    detail = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{detail}"
def color_text(text: str, color: str, bold: bool = False) -> str:
    """Wrap ``text`` in an ANSI colour sequence.

    Args:
        text: Text to colour.
        color: ANSI colour escape code to put in front of the text.
        bold: Prepend the bold escape code as well.

    Returns:
        str: ``text`` surrounded by the colour code(s) and the reset code.
    """
    style_prefix = Colors.BOLD if bold else ""
    return f"{style_prefix}{color}{text}{Colors.RESET}"
def _plain_formatter(template: str):
    """Build a colourless stand-in with the same call shape as a color_utils formatter."""
    return template.format


# Prefer the coloured formatters; fall back to plain tags when color_utils
# cannot be imported for any reason.
try:
    from color_utils import info, warning, error, ok, success, header  # type: ignore

    _FORMATTERS = {
        "info": info,
        "warning": warning,
        "error": error,
        "ok": ok,
        "success": success,
        "header": header,
    }
except Exception:
    _FORMATTERS = {
        "info": _plain_formatter("[INFO] {}"),
        "warning": _plain_formatter("[WARNING] {}"),
        "error": _plain_formatter("[ERROR] {}"),
        "ok": _plain_formatter("[OK] {}"),
        "success": _plain_formatter("[SUCCESS] {}"),
        "header": _plain_formatter("=== {} ==="),
    }


def print_info(msg: str):
    """Print an INFO-level message."""
    print(_FORMATTERS["info"](msg))


def print_warning(msg: str):
    """Print a WARNING-level message."""
    print(_FORMATTERS["warning"](msg))


def print_error(msg: str):
    """Print an ERROR-level message."""
    print(_FORMATTERS["error"](msg))


def print_ok(msg: str):
    """Print an OK message."""
    print(_FORMATTERS["ok"](msg))


def print_success(msg: str):
    """Print a SUCCESS message."""
    print(_FORMATTERS["success"](msg))


def print_header(msg: str):
    """Print a section header."""
    print(_FORMATTERS["header"](msg))
def stft_complex(x: np.ndarray, n_fft: int = 512, hop_length: int = 256, win_length: int = 512):
    """
    Convert a waveform into the complex-spectrum input GTCRN expects.

    Args:
        x: 1-D waveform.
        n_fft: FFT size.
        hop_length: Hop between successive frames, in samples.
        win_length: Analysis window length, in samples.

    Returns:
        float32 array of shape (1, F, T, 2); the last axis holds the real and
        imaginary parts.
    """
    # Only torch is needed here, so import it directly instead of going
    # through the shared backend loader, which also requires soundfile.
    import torch  # type: ignore

    wav = torch.from_numpy(x).float()
    # Square-root Hann window; istft_complex must use the same window.
    window = torch.hann_window(win_length).pow(0.5)
    # ``return_complex=False`` is deprecated in torch.stft; request the complex
    # tensor and split it with view_as_real, which yields the same
    # (F, T, 2) layout.
    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
        center=True,
    )  # (F, T), complex
    spec = torch.view_as_real(spec).unsqueeze(0)  # (1, F, T, 2)
    return spec.cpu().numpy().astype(np.float32)
class OnnxGtcrnDenoiser:
    """
    Run GTCRN inference through ONNXRuntime.

    Notes:
    - GTCRN is a streaming model; the ONNX graph takes and returns cache tensors.
    - Inference is done one spectrogram frame at a time and the enhanced
      frames are then reassembled into a full waveform.
    """

    def __init__(self, model_path: Path):
        try:
            import onnxruntime as ort  # type: ignore
        except Exception as e:
            raise RuntimeError("未安装 onnxruntime,请先安装 onnxruntime 或 onnxruntime-gpu") from e

        if not model_path.exists():
            raise FileNotFoundError(f"ONNX 模型不存在: {model_path}")

        self.model_path = model_path
        self.session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
        self.input_names = [node.name for node in self.session.get_inputs()]
        self.output_names = [node.name for node in self.session.get_outputs()]

        # Fixed cache shapes, taken from the GTCRN stream export.
        self.conv_cache = np.zeros((2, 1, 16, 16, 33), dtype=np.float32)
        self.tra_cache = np.zeros((2, 3, 1, 1, 16), dtype=np.float32)
        self.inter_cache = np.zeros((2, 1, 33, 16), dtype=np.float32)

    def denoise(self, wav: np.ndarray) -> np.ndarray:
        """Denoise ``wav`` and return the enhanced waveform."""
        spec = stft_complex(wav)  # (1, F, T, 2)

        # Start every call from fresh copies of the zero caches so calls do
        # not leak streaming state into each other.
        state = {
            "conv_cache": self.conv_cache.copy(),
            "tra_cache": self.tra_cache.copy(),
            "inter_cache": self.inter_cache.copy(),
        }

        frames = []
        for t in range(spec.shape[2]):
            feed = {"mix": spec[:, :, t:t + 1, :].astype(np.float32), **state}
            # The graph returns the enhanced frame followed by the updated caches.
            (
                enhanced,
                state["conv_cache"],
                state["tra_cache"],
                state["inter_cache"],
            ) = self.session.run([], feed)
            frames.append(enhanced)

        enhanced_spec = np.concatenate(frames, axis=2)  # (1, F, T, 2)
        return istft_complex(enhanced_spec)
def _export_onnx_from_torch(weight_path: Path, export_path: Path) -> None:
    """
    Export a GTCRN ONNX graph from local torch weights.

    Relies on GTCRN / StreamGTCRN and convert_to_stream from local_libs/gtcrn.

    Args:
        weight_path: Checkpoint file; either a bare state dict or a dict with
            a ``"model"`` entry.
        export_path: Destination ``.onnx`` file.

    Raises:
        RuntimeError: If PyTorch cannot be imported.
    """
    try:
        import torch  # type: ignore
    except Exception as e:
        raise RuntimeError("导出 ONNX 需要 PyTorch") from e

    # Import the GTCRN implementation lazily; it is only needed for export.
    from gtcrn import GTCRN  # type: ignore
    from stream.gtcrn import StreamGTCRN  # type: ignore
    from stream.modules.convert import convert_to_stream  # type: ignore

    device = torch.device("cpu")
    model = GTCRN().to(device).eval()
    # NOTE: torch.load unpickles the file — only use checkpoints from a trusted source.
    ckpt = torch.load(str(weight_path), map_location=device)
    state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

    # strict=False tolerates key mismatches; report them instead of discarding
    # the result, otherwise a partly (or entirely) uninitialised model could
    # be exported without any hint.
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print_warning(
            f"权重与 GTCRN 结构不完全匹配: missing={list(load_result.missing_keys)}, "
            f"unexpected={list(load_result.unexpected_keys)}"
        )

    stream_model = StreamGTCRN().to(device).eval()
    convert_to_stream(stream_model, model)

    # Dummy single-frame input plus zeroed caches, used to trace the graph.
    input_spec = torch.randn(1, 257, 1, 2, device=device)
    conv_cache = torch.zeros(2, 1, 16, 16, 33, device=device)
    tra_cache = torch.zeros(2, 3, 1, 1, 16, device=device)
    inter_cache = torch.zeros(2, 1, 33, 16, device=device)

    print_info(f"导出 ONNX: {export_path}")
    torch.onnx.export(
        stream_model,
        (input_spec, conv_cache, tra_cache, inter_cache),
        str(export_path),
        input_names=["mix", "conv_cache", "tra_cache", "inter_cache"],
        output_names=["enh", "conv_cache_out", "tra_cache_out", "inter_cache_out"],
        opset_version=11,
        verbose=False,
    )
    print_ok(f"ONNX 导出完成: {export_path}")
def main() -> int:
    """Command-line entry point for GTCRN denoising.

    Parses ``--input``, ``--model``, ``--output`` and ``--export_dir``, builds
    the denoiser, then processes either a single file or every audio file
    found under a directory.

    Returns:
        ``0`` on success (including "nothing to process"), ``1`` when the
        model cannot be initialised or processing fails.
    """
    parser = argparse.ArgumentParser(
        description="GTCRN 本地智能降噪工具(优先 ONNXRuntime)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 单文件降噪(ONNX 模型)
  python -m src.utils.gtcrn_denoise --input ./a.wav --model ./models/gtcrn/gtcrn.onnx --output ./out.wav

  # 目录批处理
  python -m src.utils.gtcrn_denoise --input ./input_dir --model ./models/gtcrn/gtcrn.onnx --output ./denoised_dir

  # 如果你手里是 .tar/.pt 权重,可尝试导出 ONNX(需要本地可加载权重)
  python -m src.utils.gtcrn_denoise --input ./a.wav --model ./weights/model_trained_on_dns3.tar --export_dir ./models/gtcrn_onnx --output ./out.wav
        """,
    )
    parser.add_argument("--input", required=True, help="输入音频文件或目录")
    parser.add_argument("--model", required=True, help="GTCRN 模型路径(.onnx/.tar/.pt/.pth)")
    parser.add_argument("--output", required=True, help="输出 wav 文件或目录")
    parser.add_argument("--export_dir", default=None, help="若输入为 .tar/.pt,则导出 ONNX 的目录")
    args = parser.parse_args()

    input_path = Path(args.input).resolve()
    model_path = Path(args.model).resolve()
    output_path = Path(args.output).resolve()
    export_dir = Path(args.export_dir).resolve() if args.export_dir else None

    print_header("GTCRN 智能降噪")
    print_info(f"输入: {input_path}")
    print_info(f"模型: {model_path}")
    print_info(f"输出: {output_path}")

    try:
        resolved_model = _resolve_model(model_path, export_dir=export_dir)
        print_info(f"使用模型: {resolved_model}")
        denoiser = OnnxGtcrnDenoiser(resolved_model)
    except Exception as e:
        print_error(f"初始化失败: {e}")
        return 1

    files = _find_audio_files(input_path)
    if not files:
        print_warning("未找到可处理的音频文件")
        return 0

    try:
        if input_path.is_file():
            # The output is always written as wav.
            if output_path.suffix.lower() != ".wav":
                output_path = output_path.with_suffix(".wav")
            process_one(files[0], output_path, denoiser)
            print_success(f"完成: {output_path}")
        else:
            output_path.mkdir(parents=True, exist_ok=True)
            # Outputs are flattened into one directory and named by stem, so
            # "a.wav" and "a.mp3" (or same-named files in different
            # sub-directories) would overwrite each other. Keep the first
            # occurrence as "<stem>.wav" and number later ones.
            used_names = set()
            for f in files:
                out_name = f"{f.stem}.wav"
                serial = 1
                while out_name.casefold() in used_names:
                    out_name = f"{f.stem}_{serial}.wav"
                    serial += 1
                used_names.add(out_name.casefold())
                out_file = output_path / out_name
                print_info(f"降噪: {f.name} -> {out_file.name}")
                process_one(f, out_file, denoiser)
            print_success(f"批量完成,输出目录: {output_path}")
    except Exception as e:
        print_error(f"处理失败: {e}")
        return 1

    return 0
class AudioGtcrnDenoise(Mapper):
    """Mapper that denoises one audio sample with the GTCRN ONNX model.

    Input is taken from ``sample[data_key]`` when upstream bytes are present,
    otherwise from the file at ``sample[filepath_key]``. The denoised WAV bytes
    are written back to ``sample[data_key]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``or ""`` keeps an explicit ``modelPath=None`` from becoming the literal "None".
        self.model_path = str(kwargs.get("modelPath") or "").strip()
        # The ONNX denoiser is created lazily and reused across samples.
        self._denoiser = None
        self._denoiser_key = ""

    def _get_denoiser(self, model: Path):
        """Return ``(denoiser, process_one)``, building the denoiser once per model path."""
        # Call the audio_preprocessor helpers directly instead of a subprocess,
        # which avoids path/environment differences.
        import sys

        utils_dir = _audio_preprocessor_root() / "helpers" / "utils"
        if str(utils_dir) not in sys.path:
            sys.path.insert(0, str(utils_dir))

        from gtcrn_denoise import OnnxGtcrnDenoiser, process_one  # type: ignore

        key = str(model)
        if self._denoiser is None or self._denoiser_key != key:
            # Loading the ONNX session is expensive; the CLI helper also reuses
            # one denoiser for a whole batch of files.
            self._denoiser = OnnxGtcrnDenoiser(model)
            self._denoiser_key = key
        return self._denoiser, process_one

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Denoise ``sample`` in place and return it.

        Samples flagged invalid by an upstream quality check, and samples that
        are not audio, are marked as skipped instead of being processed.

        Raises:
            FileNotFoundError: If there are no upstream bytes and the input
                file is missing, or if the ONNX model file does not exist.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        raw_path = str(sample.get(self.filepath_key) or "")
        # An empty path would resolve to the current directory, which "exists";
        # require a real file whenever no upstream bytes are available.
        in_path = Path(raw_path).resolve()
        audio_bytes = sample.get(self.data_key)
        if not audio_bytes and not (raw_path and in_path.is_file()):
            raise FileNotFoundError(f"输入音频不存在: {in_path}")

        model = Path(self.model_path or DEFAULT_GTCRN_MODEL_PATH).expanduser()
        model = model.resolve()
        if not model.exists():
            raise FileNotFoundError(f"GTCRN ONNX 模型不存在: {model}")

        denoiser, process_one = self._get_denoiser(model)
        with tempfile.TemporaryDirectory(prefix="audio_gtcrn_denoise_") as tmpdir:
            if audio_bytes:
                # Upstream bytes take precedence over the original file on disk.
                in_path = Path(tmpdir) / "input.wav"
                in_path.write_bytes(bytes(audio_bytes))
            out_path = Path(tmpdir) / "denoised.wav"
            process_one(in_path, out_path, denoiser)
            sample[self.data_key] = out_path.read_bytes()

        sample[self.text_key] = ""
        sample[self.target_type_key] = "wav"
        sample[self.filetype_key] = "txt" if self.is_last_op else "wav"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioGtcrnDenoise costs {time.time() - start:6f} s"
        )
        return sample
module_path="ops.mapper.audio_hum_notch.process") diff --git a/runtime/ops/mapper/audio_hum_notch/audio_skip.py b/runtime/ops/mapper/audio_hum_notch/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_hum_notch/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) 
def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Mark ``sample`` as skipped by ``op_name`` and blank its payload fields.

    The reason is recorded under ``ext_params["audio_skip"][op_name]``. Text,
    data, file type and target type are reset to empty values.

    Args:
        sample: Sample dict; modified in place.
        reason: Why the operator skipped the sample.
        op_name: Name of the operator recording the skip.

    Returns:
        The same ``sample`` object.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        # Keep whatever was stored before instead of discarding it.
        ext = {"_raw": ext}

    skips = ext.get("audio_skip", {})
    if not isinstance(skips, dict):
        # Same treatment as above: a malformed entry must not crash the skip path.
        skips = {"_raw": skips}
    skips[op_name] = reason
    ext["audio_skip"] = skips

    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
class AudioHumNotch(Mapper):
    """Mapper that suppresses mains hum with an IIR notch filter.

    Settings read from ``kwargs``: ``freqHz`` (notch centre frequency,
    default 50) and ``q`` (quality factor, default 30). Output is always WAV.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.freq_hz = float(kwargs.get("freqHz", 50))
        self.q = float(kwargs.get("q", 30))
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Filter the sample's audio and store the WAV bytes in ``sample[data_key]``.

        Upstream bytes in ``sample[data_key]`` take precedence over the file at
        ``sample[filepath_key]``. Multi-channel audio is averaged to mono.

        Raises:
            FileNotFoundError: If there are no upstream bytes and the input file is missing.
            RuntimeError: If scipy is unavailable or filtering fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        audio_bytes = sample.get(self.data_key)
        in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
        # The file on disk is only required when no upstream operator supplied bytes.
        if not audio_bytes and not in_path.is_file():
            raise FileNotFoundError(f"输入音频不存在: {in_path}")

        data, sr = _load_audio(audio_bytes or in_path)
        try:
            import numpy as np
            from scipy.signal import iirnotch, lfilter  # type: ignore

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                x = x.mean(axis=1)
            if x.size == 0:
                y = x
            else:
                nyquist = float(sr) / 2.0
                # iirnotch needs a normalised frequency strictly between 0 and 1.
                if not 0.0 < float(self.freq_hz) < nyquist:
                    raise ValueError(f"freqHz={self.freq_hz} 必须位于 (0, {nyquist}) Hz 范围内")
                w0 = float(self.freq_hz) / nyquist
                b, a = iirnotch(w0, float(self.q))
                y = lfilter(b, a, x).astype(np.float32)
                y = np.clip(y, -1.0, 1.0)
        except ImportError as e:
            raise RuntimeError("AudioHumNotch 需要 scipy.signal(iirnotch/lfilter)") from e
        except Exception as e:
            raise RuntimeError(f"处理失败: {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        sample[self.filetype_key] = "txt" if self.is_last_op else self.out_format

        logger.info(f"fileName: {sample.get(self.filename_key)}, method: AudioHumNotch costs {time.time() - start:6f} s")
        return sample
a/runtime/ops/mapper/audio_hum_notch/requirements.txt b/runtime/ops/mapper/audio_hum_notch/requirements.txt new file mode 100644 index 00000000..843a926a --- /dev/null +++ b/runtime/ops/mapper/audio_hum_notch/requirements.txt @@ -0,0 +1,3 @@ +soundfile +numpy +scipy diff --git a/runtime/ops/mapper/audio_noise_gate/README.md b/runtime/ops/mapper/audio_noise_gate/README.md new file mode 100644 index 00000000..82129cb5 --- /dev/null +++ b/runtime/ops/mapper/audio_noise_gate/README.md @@ -0,0 +1,27 @@ +# AudioNoiseGate 噪声门算子 + +## 概述 + +AudioNoiseGate 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| thresholdDb | slider | -45 | 门限(dB,相对全段峰值) | +| frameMs | inputNumber | 20 | 帧长(ms) | +| hopMs | inputNumber | 10 | 帧移(ms) | +| floorRatio | slider | 0.05 | 门控时保留能量比例(0~1) | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_noise_gate/__init__.py b/runtime/ops/mapper/audio_noise_gate/__init__.py new file mode 100644 index 00000000..11e17725 --- /dev/null +++ b/runtime/ops/mapper/audio_noise_gate/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioNoiseGate', + module_path="ops.mapper.audio_noise_gate.process") diff --git a/runtime/ops/mapper/audio_noise_gate/audio_skip.py b/runtime/ops/mapper/audio_noise_gate/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_noise_gate/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + 
def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when ``sample`` is flagged as invalid audio, else ``""``.

    Two sources are checked in order:

    1. A ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``; the text after the marker is the reason.
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"`` and a
       truthy ``skip_downstream`` (default ``True``; strings are parsed as booleans).
    """
    marker = "__quality_invalid"
    for field in ("fileName", "sourceFileName", "filePath"):
        stem = Path(str(sample.get(field) or "")).stem.lower()
        _, found, tail = stem.partition(marker)
        if found:
            return "invalid_audio_quality:" + (tail.strip("_") or "invalid_audio")

    ext = sample.get(ext_params_key, {})
    quality = ext.get("audio_quality", {}) if isinstance(ext, dict) else None
    if not isinstance(quality, dict):
        return ""

    flag = str(quality.get("quality_flag") or "").strip().lower()
    if flag != "invalid":
        return ""

    gate = quality.get("skip_downstream", True)
    if isinstance(gate, str):
        gate = gate.strip().lower() in {"1", "true", "yes", "y", "on"}
    if not gate:
        return ""

    detail = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{detail}"
+language: 'python' +vendor: 'huawei' +raw_id: 'AudioNoiseGate' +version: '1.0.0' +types: + - 'cleaning' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + thresholdDb: + name: '门限(dB)' + type: 'slider' + description: '相对全段峰值的门限(dB),越小越“宽松”。' + defaultVal: -45 + min: -80 + max: 0 + step: 1 + frameMs: + name: '帧长(ms)' + type: 'inputNumber' + description: '分析帧长。' + defaultVal: 20 + min: 5 + max: 200 + step: 1 + hopMs: + name: '帧移(ms)' + type: 'inputNumber' + description: '帧移。' + defaultVal: 10 + min: 1 + max: 200 + step: 1 + floorRatio: + name: '衰减比例' + type: 'slider' + description: '门控时保留能量比例(0~1)。' + defaultVal: 0.05 + min: 0 + max: 1 + step: 0.01 +runtime: + memory: 104857600 + cpu: 0.15 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_noise_gate/process.py b/runtime/ops/mapper/audio_noise_gate/process.py new file mode 100644 index 00000000..d71fa842 --- /dev/null +++ b/runtime/ops/mapper/audio_noise_gate/process.py @@ -0,0 +1,112 @@ +# -- encoding: utf-8 -- + +import io +import time +from pathlib import Path +from typing import Dict, Any, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + + +def _load_audio(source: object) -> Tuple["object", int]: + try: + import soundfile as sf # type: ignore + + if isinstance(source, (bytes, bytearray)): + data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False) + else: + data, sr = sf.read(str(source), always_2d=False) + return data, int(sr) + except Exception as e: + raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e + + +def _dump_audio(data: "object", sr: int, fmt: str) -> bytes: + try: + import soundfile as sf # type: ignore + + with io.BytesIO() as buf: + 
class AudioNoiseGate(Mapper):
    """Mapper that attenuates low-energy frames (a simple noise gate).

    Settings read from ``kwargs``: ``thresholdDb`` (dB relative to the clip's
    peak, default -45), ``frameMs`` (default 20), ``hopMs`` (default 10) and
    ``floorRatio`` (gain applied to gated samples, default 0.05). Output is WAV.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold_db = float(kwargs.get("thresholdDb", -45))
        self.frame_ms = float(kwargs.get("frameMs", 20))
        self.hop_ms = float(kwargs.get("hopMs", 10))
        self.floor_ratio = float(kwargs.get("floorRatio", 0.05))
        self.out_format = "wav"

    @staticmethod
    def _apply_gate(x, sr, threshold_db, frame_ms, hop_ms, floor_ratio):
        """Return ``x`` with every sample of a below-threshold frame scaled by ``floor_ratio``."""
        import numpy as np

        if x.size == 0:
            return x
        peak = float(np.max(np.abs(x))) + 1e-12
        th = peak * (10.0 ** (float(threshold_db) / 20.0))
        frame_len = max(1, int(sr * frame_ms / 1000.0))
        hop = max(1, int(sr * hop_ms / 1000.0))

        # Build a gain mask instead of scaling in place: with hop < frame_len the
        # frames overlap, and in-place scaling would attenuate a sample once per
        # quiet frame covering it (floor_ratio ** k instead of floor_ratio).
        gain = np.ones(len(x), dtype=np.float32)
        for st in range(0, len(x), hop):
            ed = min(st + frame_len, len(x))
            frame = x[st:ed]
            rms = float(np.sqrt(np.mean(frame * frame) + 1e-12))
            if rms < th:
                gain[st:ed] = float(floor_ratio)
        return np.clip(x * gain, -1.0, 1.0)

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Gate the sample's audio and store the WAV bytes in ``sample[data_key]``.

        Upstream bytes in ``sample[data_key]`` take precedence over the file at
        ``sample[filepath_key]``. Multi-channel audio is averaged to mono.

        Raises:
            FileNotFoundError: If there are no upstream bytes and the input file is missing.
            RuntimeError: If processing fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        audio_bytes = sample.get(self.data_key)
        in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
        # The file on disk is only required when no upstream operator supplied bytes.
        if not audio_bytes and not in_path.is_file():
            raise FileNotFoundError(f"输入音频不存在: {in_path}")

        data, sr = _load_audio(audio_bytes or in_path)
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                x = x.mean(axis=1)
            y = self._apply_gate(x, sr, self.threshold_db, self.frame_ms, self.hop_ms, self.floor_ratio)
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        sample[self.filetype_key] = "txt" if self.is_last_op else self.out_format

        logger.info(f"fileName: {sample.get(self.filename_key)}, method: AudioNoiseGate costs {time.time() - start:6f} s")
        return sample
导入测试素材。建议至少导入以下文件: + - 中文语音: + - `humanSpeech\zh\aishell_0000.wav` + - `humanSpeech\zh\aishell_0001.wav` + - `humanSpeech\zh\aishell_0002.wav` + - 英文语音: + - `humanSpeech\en\librispeech_0000.wav` + - `humanSpeech\en\librispeech_0001.wav` + - `humanSpeech\en\librispeech_0002.wav` + - 摘要/ASR 测试语音: + - `audio\summary\84-121123-0000.flac` + - `audio\summary\84-121123-0001.flac` + - `audio\summary\BAC009S0002W0122.wav` + - `audio\summary\BAC009S0002W0123.wav` +3. 单算子测试时,每次只选择一个目标算子运行,输出保存到新的输出数据集。 +4. 串联测试时,要求上一个算子的输出数据集作为下一个算子的输入数据集。 +5. 每次测试后检查三点: + - 任务是否成功完成,无报错。 + - 输出数据集中文件数量是否与输入样本数量一致,除非算子本身明确过滤或跳过。 + - 输出文件类型、文件内容、文件名标记或 `ext_params` 是否符合算子功能。 + +## 1. audio_anomaly_filter + +用途:检测音频是否异常,并把检测结果写入 `ext_params.audio_quality`。该算子输出仍应保留音频,不能只输出标签。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\en\librispeech_0000.wav +audio\summary\84-121123-0000.flac +``` + +测试步骤: + +1. 输入数据集导入上述音频。 +2. 运行 `audio_anomaly_filter`。 +3. 参数先使用默认值: + - `minDur = 1.0` + - `maxDur = 20000.0` + - `silenceRatioTh = 0.8` + - `skipInvalidDownstream = true` +4. 检查输出数据集: + - 每个输入音频都应有对应输出。 + - 正常音频应仍可作为音频被后续音频算子继续处理。 + - `ext_params.audio_quality.quality_flag` 应为 `ok` 或 `invalid`。 + - `ext_params.audio_quality.duration`、`silence_ratio`、`global_rms` 应存在。 +5. 异常分支测试: + - 将 `minDur` 设置为一个明显大于测试音频时长的值,例如 `9999`。 + - 再次运行。 + - 输出样本应被标记为 `invalid`。 + - 文件名中可能出现 `__quality_invalid_...` 标记。 + - 音频数据仍应保留,供后续算子根据 `skipInvalidDownstream` 决定是否跳过。 + +通过标准: + +- 算子运行成功。 +- 输出不丢失音频。 +- `audio_quality` 质量信息完整。 +- 异常参数下能正确标记异常样本。 + +## 2. 
audio_asr_transcribe + +用途:输入音频,调用 ASR 模型,输出转写文本。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\zh\aishell_0001.wav +humanSpeech\en\librispeech_0000.wav +humanSpeech\en\librispeech_0001.wav +audio\summary\BAC009S0002W0122.wav +audio\summary\84-121123-0000.flac +``` + +前置条件: + +- 中文模型目录默认应存在: + +```text +/models/AudioOperations/asr/aishell +``` + +- 英文模型目录默认应存在: + +```text +/models/AudioOperations/asr/librispeech +``` + +- 目录中应包含 `train.yaml`、`final.pt`、`units.txt`。 + +测试步骤: + +1. 输入数据集导入中文和英文语音。 +2. 推荐先单独测试中文: + - 参数 `language = zh` + - 参数 `device` 根据环境选择,优先使用实际可用设备;无 NPU 时使用 `cpu`。 + - 其他参数保持默认。 +3. 运行 `audio_asr_transcribe`。 +4. 检查输出数据集: + - 输出文件应为文本。 + - 文本内容不应为空。 + - `ext_params.audio_asr_transcribe.language` 应为 `zh`。 + - `transcript_source` 应存在,通常为 `asr`。 +5. 再单独测试英文: + - 参数 `language = en` + - 输入英文 `librispeech_*.wav`。 + - 输出文本应为英文内容,且不为空。 +6. 串联测试: + - 先运行 `audio_fast_lang_id`。 + - 将其输出数据集作为 `audio_asr_transcribe` 输入。 + - `audio_asr_transcribe` 参数设置为 `language = auto`。 + - 检查中文音频使用中文模型,英文音频使用英文模型。 + +通过标准: + +- 算子运行成功。 +- 输出类型为文本。 +- 转写文本非空。 +- `language = auto` 时能读取上游 LID 结果或文件名标记。 + +## 3. audio_dc_offset_removal + +用途:去除音频直流偏置,输出处理后的 WAV 音频。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\en\librispeech_0000.wav +audio\summary\BAC009S0002W0122.wav +``` + +测试步骤: + +1. 输入数据集导入上述音频。 +2. 运行 `audio_dc_offset_removal`,无需配置参数。 +3. 检查输出数据集: + - 输出文件数量应与输入一致。 + - 输出仍应为音频文件。 + - 输出目标格式应为 `wav`。 + - 音频可以正常播放或被后续音频算子读取。 +4. 可选串联验证: + - 将输出继续输入 `audio_asr_transcribe`。 + - ASR 应能正常产生文本。 + +通过标准: + +- 算子运行成功。 +- 输出 WAV 音频可读取。 +- 未把音频错误替换为空文本或标签。 + +## 4. 
audio_fast_lang_id + +用途:识别语音音频语言为 `zh` 或 `en`,结果写入 `ext_params.audio_lid.lang`,同时保留原音频给下游继续使用。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\zh\aishell_0001.wav +humanSpeech\en\librispeech_0000.wav +humanSpeech\en\librispeech_0001.wav +``` + +前置条件: + +- LID 模型目录默认应存在: + +```text +/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa +``` + +测试步骤: + +1. 输入数据集同时导入中文和英文语音。 +2. 运行 `audio_fast_lang_id`。 +3. 参数建议保持默认: + - `device = cpu` + - `maxSeconds = 3.0` +4. 检查输出数据集: + - 输出文件数量应与输入一致。 + - 输出仍应为音频,而不是只剩 `zh` 或 `en` 文本。 + - 中文样本的 `ext_params.audio_lid.lang` 应为 `zh`。 + - 英文样本的 `ext_params.audio_lid.lang` 应为 `en`。 + - 文件名中应带有类似 `__lid_zh` 或 `__lid_en` 的标记。 +5. 串联验证: + - 将输出数据集作为 `audio_asr_transcribe` 输入。 + - `audio_asr_transcribe.language` 设置为 `auto`。 + - ASR 应能继续读取音频并输出文本。 + +通过标准: + +- 语言标签正确。 +- 输出仍保留音频。 +- 能作为 ASR 的上游算子使用。 + +## 5. audio_fast_lang_id_text + +用途:识别语音语言,并直接输出一个文本标签文件。该算子是终端标注算子,会用 `zh` 或 `en` 文本替换音频。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\en\librispeech_0000.wav +``` + +前置条件: + +- LID 模型目录默认应存在: + +```text +/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa +``` + +测试步骤: + +1. 输入数据集导入一条中文语音和一条英文语音。 +2. 运行 `audio_fast_lang_id_text`。 +3. 参数建议保持默认: + - `device = cpu` + - `maxSeconds = 3.0` +4. 检查输出数据集: + - 输出文件应为文本。 + - 中文音频输出文本内容应为 `zh`。 + - 英文音频输出文本内容应为 `en`。 + - 该输出不再适合作为 ASR 输入,因为音频已经被标签文本替换。 + +通过标准: + +- 算子运行成功。 +- 输出文本只包含语言标签。 +- 中文/英文判断符合输入素材。 + +## 6. audio_format_convert + +用途:转换音频格式、采样率和声道数,输出处理后的音频。 + +推荐素材: + +```text +audio\summary\84-121123-0000.flac +audio\summary\84-121123-0001.flac +humanSpeech\zh\aishell_0000.wav +``` + +测试步骤: + +1. 输入数据集导入 FLAC 和 WAV 音频。 +2. 运行 `audio_format_convert`。 +3. 推荐参数: + - `targetFormat = wav` + - `sampleRate = 16000` + - `channels = 1` +4. 
检查输出数据集: + - 输出文件数量应与输入一致。 + - 输出应为 WAV 音频。 + - 音频应能正常播放或被后续算子读取。 + - `ext_params.audio_format_convert.format` 应为 `wav`。 + - `ext_params.audio_format_convert.sample_rate` 应为 `16000`。 + - `ext_params.audio_format_convert.channels` 应为 `1`。 +5. 可选格式测试: + - 将 `targetFormat` 改为 `flac` 或 `ogg`。 + - 检查输出扩展名和文件格式是否匹配。 + +通过标准: + +- 算子运行成功。 +- 输出格式、采样率、声道配置符合参数。 +- 输出仍是音频,可作为下游音频算子输入。 + +## 7. audio_gtcrn_denoise + +用途:调用 GTCRN ONNX 模型对音频降噪,输出 WAV 音频。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\en\librispeech_0000.wav +audio\summary\BAC009S0002W0122.wav +``` + +前置条件: + +- GTCRN 模型默认应存在: + +```text +/models/AudioOperations/gtcrn/gtcrn.onnx +``` + +测试步骤: + +1. 输入数据集导入上述音频。 +2. 运行 `audio_gtcrn_denoise`。 +3. 参数 `modelPath` 使用默认值,或填写实际模型绝对路径。 +4. 检查输出数据集: + - 输出文件数量应与输入一致。 + - 输出应为 WAV 音频。 + - 音频应能正常播放或被后续音频算子读取。 +5. 可选串联验证: + - 将输出继续输入 `audio_asr_transcribe`。 + - ASR 应能正常输出文本。 + +通过标准: + +- 算子运行成功。 +- 输出 WAV 音频可读取。 +- 模型路径不存在时应明确报错,而不是静默输出空文件。 + +## 8. audio_hum_notch + +用途:对 50Hz 或 60Hz 工频噪声做陷波抑制,输出 WAV 音频。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\en\librispeech_0000.wav +audio\summary\BAC009S0002W0122.wav +``` + +前置条件: + +- 运行环境应安装 `soundfile`、`numpy`、`scipy`。 + +测试步骤: + +1. 输入数据集导入上述音频。 +2. 运行 `audio_hum_notch`。 +3. 推荐参数: + - `freqHz = 50` + - `q = 30` +4. 检查输出数据集: + - 输出文件数量应与输入一致。 + - 输出应为 WAV 音频。 + - 音频应能正常播放或被后续音频算子读取。 +5. 参数分支测试: + - 将 `freqHz` 改为 `60`。 + - 再次运行,任务应成功。 + +通过标准: + +- 算子运行成功。 +- 50Hz 和 60Hz 参数均可运行。 +- 输出仍是可读取音频。 + +## 9. audio_noise_gate + +用途:对低于阈值的低能量音频帧做衰减,输出 WAV 音频。 + +推荐素材: + +```text +humanSpeech\zh\aishell_0000.wav +humanSpeech\en\librispeech_0000.wav +audio\summary\BAC009S0002W0122.wav +``` + +前置条件: + +- 运行环境应安装 `soundfile`、`numpy`。 + +测试步骤: + +1. 输入数据集导入上述音频。 +2. 运行 `audio_noise_gate`。 +3. 推荐参数先使用默认值: + - `thresholdDb = -45` + - `frameMs = 20` + - `hopMs = 10` + - `floorRatio = 0.05` +4. 检查输出数据集: + - 输出文件数量应与输入一致。 + - 输出应为 WAV 音频。 + - 音频应能正常播放或被后续音频算子读取。 +5. 
参数分支测试: + - 将 `thresholdDb` 设置为 `-20`。 + - 将 `floorRatio` 设置为 `0`。 + - 再次运行,任务应成功,输出音频仍可读取。 + +通过标准: + +- 算子运行成功。 +- 默认参数和较强门限参数均可运行。 +- 输出仍是可读取音频。 + +## 10. audio_text_summarize + +用途:输入 ASR 文本,输出摘要文本。该算子输入是文本,不是音频。 + +推荐素材和前置流程: + +建议先用以下音频跑出 ASR 文本,再把 ASR 输出数据集作为本算子的输入: + +```text +audio\summary\84-121123-0000.flac +audio\summary\84-121123-0001.flac +audio\summary\BAC009S0002W0122.wav +audio\summary\BAC009S0002W0123.wav +``` + +测试步骤: + +1. 先运行 `audio_asr_transcribe`,得到文本输出数据集。 +2. 将 ASR 输出数据集作为 `audio_text_summarize` 的输入。 +3. 运行 `audio_text_summarize`。 +4. 推荐参数: + - `method = extractive` + - `lineMode = single` + - `maxSummaryCharsZh = 40` + - `maxSummaryWordsEn = 18` + - `preserveKeys = true` +5. 检查输出数据集: + - 输出文件应为文本。 + - 摘要文本不应为空。 + - 中文摘要长度应大致受 `maxSummaryCharsZh` 控制。 + - 英文摘要词数应大致受 `maxSummaryWordsEn` 控制。 + - `ext_params.audio_text_summarize.method` 应为 `extractive`。 +- 模型目录默认: + +```text +/models/AudioOperations/summary/summary-model +``` + +通过标准: + +- 算子运行成功。 +- 输出摘要文本非空。 +- 默认 `extractive` 方法不依赖 ONNX 模型即可完成。 +- 文本输入为空时,应被明确跳过或标记为空文本,不应产生异常崩溃。 diff --git a/runtime/ops/mapper/audio_pre_emphasis/README.md b/runtime/ops/mapper/audio_pre_emphasis/README.md new file mode 100644 index 00000000..c24a275f --- /dev/null +++ b/runtime/ops/mapper/audio_pre_emphasis/README.md @@ -0,0 +1,24 @@ +# AudioPreEmphasis 预加重算子 + +## 概述 + +AudioPreEmphasis 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| coef | slider | 0.97 | 预加重系数(常用 0.9~0.99) | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_pre_emphasis/__init__.py b/runtime/ops/mapper/audio_pre_emphasis/__init__.py new file mode 100644 index 00000000..3c01a422 --- /dev/null +++ 
b/runtime/ops/mapper/audio_pre_emphasis/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioPreEmphasis', + module_path="ops.mapper.audio_pre_emphasis.process") diff --git a/runtime/ops/mapper/audio_pre_emphasis/audio_skip.py b/runtime/ops/mapper/audio_pre_emphasis/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_pre_emphasis/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) 
in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) or "")).stem.lower() + marker = "__quality_invalid" + if marker in marker_source: + reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio" + return f"invalid_audio_quality:{reason}" + + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + return "" + quality = ext.get("audio_quality", {}) + if not isinstance(quality, dict): + return "" + if str(quality.get("quality_flag") or "").strip().lower() != "invalid": + return "" + skip_downstream = quality.get("skip_downstream", True) + if isinstance(skip_downstream, str): + skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"} + if not skip_downstream: + return "" + reason = str(quality.get("reason") or "invalid_audio").strip() + return f"invalid_audio_quality:{reason}" + + +def mark_skipped_sample( + sample: Dict[str, Any], + reason: str, + op_name: str, + text_key: str = "text", + data_key: str = "data", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + ext_params_key: str = "ext_params", +) -> Dict[str, Any]: + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample diff --git a/runtime/ops/mapper/audio_pre_emphasis/metadata.yml b/runtime/ops/mapper/audio_pre_emphasis/metadata.yml new file mode 100644 index 00000000..f0f255ec --- /dev/null +++ b/runtime/ops/mapper/audio_pre_emphasis/metadata.yml @@ -0,0 +1,34 @@ +name: 'audioUtils-预加重' +name_en: 'audioUtils-Pre-Emphasis' +description: 
def _load_audio(source: object) -> Tuple["object", int]:
    """Decode audio samples and the sample rate with ``soundfile``.

    ``source`` is either encoded audio bytes, read through an in-memory
    buffer, or any other value, which is converted with ``str()`` and opened
    as a file path.

    Raises:
        RuntimeError: if ``soundfile`` cannot be imported or decoding fails.
    """
    try:
        import soundfile as sf  # type: ignore

        if isinstance(source, (bytes, bytearray)):
            data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False)
        else:
            data, sr = sf.read(str(source), always_2d=False)
        return data, int(sr)
    except Exception as e:
        raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e


def _dump_audio(data: "object", sr: int, fmt: str) -> bytes:
    """Encode ``data`` at sample rate ``sr`` and return the file bytes.

    ``fmt`` is upper-cased and passed to ``soundfile`` as the container
    format; an empty value falls back to ``"WAV"``.

    Raises:
        RuntimeError: if ``soundfile`` cannot be imported or encoding fails.
    """
    try:
        import soundfile as sf  # type: ignore

        with io.BytesIO() as buf:
            sf.write(buf, data, int(sr), format=fmt.upper() if fmt else "WAV")
            return buf.getvalue()
    except Exception as e:
        raise RuntimeError(f"编码音频失败(需要 soundfile,fmt={fmt}): {e}") from e


def _audio_source(sample: Dict[str, Any], data_key: str, filepath_key: str) -> object:
    """Pick what to decode: the upstream payload if present, else the input file.

    The file path is validated only when it is actually needed, so a sample
    that already carries audio in ``sample[data_key]`` does not depend on the
    original file still being on disk.

    Raises:
        FileNotFoundError: if there is no payload and the path is empty or
            does not exist.
    """
    payload = sample.get(data_key)
    if payload:
        return payload

    raw_path = str(sample.get(filepath_key) or "")
    # Path("").resolve() is the working directory, which always exists, so an
    # empty path has to be rejected before resolving.
    if not raw_path:
        raise FileNotFoundError(f"输入音频不存在: {raw_path}")
    in_path = Path(raw_path).resolve()
    if not in_path.exists():
        raise FileNotFoundError(f"输入音频不存在: {in_path}")
    return in_path


def _skip(op: "AudioPreEmphasis", sample: Dict[str, Any], reason: str) -> Dict[str, Any]:
    """Mark ``sample`` as skipped by ``op`` using the operator's configured keys."""
    return mark_skipped_sample(
        sample,
        reason,
        op.__class__.__name__,
        op.text_key,
        op.data_key,
        op.filetype_key,
        op.target_type_key,
        op.ext_params_key,
    )


class AudioPreEmphasis(Mapper):
    """First-order pre-emphasis filter: ``y[n] = x[n] - coef * x[n-1]``.

    Multi-channel input is averaged to one channel before filtering and the
    result is always encoded as WAV.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.coef = float(kwargs.get("coef", 0.97))
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Filter one sample in place and return it.

        Samples flagged as invalid quality, non-audio samples and reference
        files are marked as skipped instead of processed.

        Raises:
            FileNotFoundError: if no upstream payload exists and the input
                file is missing.
            RuntimeError: if decoding, filtering or encoding fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return _skip(self, sample, quality_skip_reason)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return _skip(self, sample, "non_audio_or_reference_file")

        data, sr = _load_audio(_audio_source(sample, self.data_key, self.filepath_key))
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                x = x.mean(axis=1)  # average the channels down to mono
            if x.size == 0:
                y = x
            else:
                y = np.empty_like(x)
                y[0] = x[0]  # the first sample has no predecessor
                y[1:] = x[1:] - float(self.coef) * x[:-1]
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        # NOTE(review): fileType is "txt" although the payload is audio;
        # presumably this steers the DataMate exporter branch - confirm before changing.
        sample[self.filetype_key] = "txt"

        logger.info(f"fileName: {sample.get(self.filename_key)}, method: AudioPreEmphasis costs {time.time() - start:6f} s")
        return sample
a/runtime/ops/mapper/audio_quantize_encode/README.md b/runtime/ops/mapper/audio_quantize_encode/README.md new file mode 100644 index 00000000..d74e6429 --- /dev/null +++ b/runtime/ops/mapper/audio_quantize_encode/README.md @@ -0,0 +1,26 @@ +# AudioQuantizeEncode 量化编码与重采样算子 + +## 概述 + +AudioQuantizeEncode 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| sampleRate | inputNumber | 16000 | 目标采样率(Hz),0=保持原采样率 | +| bitDepth | select | 16 | WAV PCM 位深:8/16/24/32 | +| channels | inputNumber | 1 | 目标声道数:1/2,0=保持 | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_quantize_encode/__init__.py b/runtime/ops/mapper/audio_quantize_encode/__init__.py new file mode 100644 index 00000000..a7165732 --- /dev/null +++ b/runtime/ops/mapper/audio_quantize_encode/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioQuantizeEncode', + module_path="ops.mapper.audio_quantize_encode.process") diff --git a/runtime/ops/mapper/audio_quantize_encode/audio_skip.py b/runtime/ops/mapper/audio_quantize_encode/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_quantize_encode/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def 
is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) or "")).stem.lower() + marker = "__quality_invalid" + if marker in marker_source: + reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio" + return f"invalid_audio_quality:{reason}" + + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + return "" + quality = ext.get("audio_quality", {}) + if not isinstance(quality, dict): + return "" + if str(quality.get("quality_flag") or "").strip().lower() != "invalid": + return "" + skip_downstream = quality.get("skip_downstream", True) + if isinstance(skip_downstream, str): + skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"} + if not skip_downstream: + return "" + reason = str(quality.get("reason") or 
"invalid_audio").strip() + return f"invalid_audio_quality:{reason}" + + +def mark_skipped_sample( + sample: Dict[str, Any], + reason: str, + op_name: str, + text_key: str = "text", + data_key: str = "data", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + ext_params_key: str = "ext_params", +) -> Dict[str, Any]: + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample diff --git a/runtime/ops/mapper/audio_quantize_encode/metadata.yml b/runtime/ops/mapper/audio_quantize_encode/metadata.yml new file mode 100644 index 00000000..21a564a9 --- /dev/null +++ b/runtime/ops/mapper/audio_quantize_encode/metadata.yml @@ -0,0 +1,57 @@ +name: 'audioUtils-量化编码与重采样' +name_en: 'audioUtils-Quantize Encode & Resample' +description: '将音频重采样到指定采样率,并量化编码为 8/16/24/32-bit PCM(WAV);由 DataMate 统一导出结果。' +description_en: 'Resample audio to target sample rate and encode as 8/16/24/32-bit PCM WAV; DataMate exports the result.' 
def _load_audio(source: object) -> Tuple["object", int]:
    """Decode audio as a 2-D ``(frames, channels)`` array plus its sample rate.

    ``source`` is either encoded audio bytes, read through an in-memory
    buffer, or any other value, which is converted with ``str()`` and opened
    as a file path.

    Raises:
        RuntimeError: if ``soundfile`` cannot be imported or decoding fails.
    """
    try:
        import soundfile as sf  # type: ignore

        if isinstance(source, (bytes, bytearray)):
            data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=True)
        else:
            data, sr = sf.read(str(source), always_2d=True)
        return data, int(sr)
    except Exception as e:
        raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e


def _dump_wav_pcm(data: "object", sr: int, subtype: str) -> bytes:
    """Encode ``data`` as a WAV file with the given ``soundfile`` subtype.

    Raises:
        RuntimeError: if ``soundfile`` cannot be imported or encoding fails.
    """
    try:
        import soundfile as sf  # type: ignore

        with io.BytesIO() as buf:
            sf.write(buf, data, int(sr), format="WAV", subtype=subtype)
            return buf.getvalue()
    except Exception as e:
        raise RuntimeError(f"编码 WAV 失败(需要 soundfile,subtype={subtype}): {e}") from e


def _resample_linear(data: "object", src_sr: int, tgt_sr: int) -> "object":
    """Resample ``(frames, channels)`` audio by per-channel linear interpolation.

    The input is returned untouched when either rate is non-positive or both
    rates are equal. Only ``numpy.interp`` is applied; no filtering is done
    before or after the interpolation.

    Raises:
        RuntimeError: if ``numpy`` cannot be imported or interpolation fails.
    """
    if src_sr <= 0 or tgt_sr <= 0 or int(src_sr) == int(tgt_sr):
        return data
    try:
        import numpy as np

        x = np.asarray(data, dtype=np.float32)  # (T, C)
        if x.ndim != 2:
            x = x.reshape((-1, 1))
        new_len = max(1, int(round(x.shape[0] * float(tgt_sr) / float(src_sr))))
        # Both grids cover [0, 1) so positions are comparable across rates.
        old_x = np.linspace(0.0, 1.0, num=x.shape[0], endpoint=False)
        new_x = np.linspace(0.0, 1.0, num=new_len, endpoint=False)
        return np.stack(
            [np.interp(new_x, old_x, x[:, ch]).astype(np.float32) for ch in range(x.shape[1])],
            axis=1,
        )
    except Exception as e:
        raise RuntimeError(f"重采样失败(需要 numpy): {e}") from e


# Supported bit depths mapped to the matching soundfile WAV subtypes.
_PCM_SUBTYPES = {
    8: "PCM_U8",
    16: "PCM_16",
    24: "PCM_24",
    32: "PCM_32",
}


def _audio_source(sample: Dict[str, Any], data_key: str, filepath_key: str) -> object:
    """Pick what to decode: the upstream payload if present, else the input file.

    The file path is validated only when it is actually needed, so a sample
    that already carries audio in ``sample[data_key]`` does not depend on the
    original file still being on disk.

    Raises:
        FileNotFoundError: if there is no payload and the path is empty or
            does not exist.
    """
    payload = sample.get(data_key)
    if payload:
        return payload

    raw_path = str(sample.get(filepath_key) or "")
    # Path("").resolve() is the working directory, which always exists, so an
    # empty path has to be rejected before resolving.
    if not raw_path:
        raise FileNotFoundError(f"输入音频不存在: {raw_path}")
    in_path = Path(raw_path).resolve()
    if not in_path.exists():
        raise FileNotFoundError(f"输入音频不存在: {in_path}")
    return in_path


def _skip(op: "AudioQuantizeEncode", sample: Dict[str, Any], reason: str) -> Dict[str, Any]:
    """Mark ``sample`` as skipped by ``op`` using the operator's configured keys."""
    return mark_skipped_sample(
        sample,
        reason,
        op.__class__.__name__,
        op.text_key,
        op.data_key,
        op.filetype_key,
        op.target_type_key,
        op.ext_params_key,
    )


class AudioQuantizeEncode(Mapper):
    """Convert channels, resample and encode audio as 8/16/24/32-bit PCM WAV.

    ``sampleRate`` 0 (or negative) keeps the source rate. ``channels`` 1
    averages multi-channel input to mono and ``channels`` 2 duplicates a mono
    channel; every other combination leaves the channel layout unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_rate = int(float(kwargs.get("sampleRate", 16000)))
        self.bit_depth = int(float(kwargs.get("bitDepth", 16)))
        self.channels = int(float(kwargs.get("channels", 1)))

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Re-encode one sample in place and return it.

        Samples flagged as invalid quality, non-audio samples and reference
        files are marked as skipped instead of processed.

        Raises:
            ValueError: if ``bitDepth`` is not one of 8/16/24/32.
            FileNotFoundError: if no upstream payload exists and the input
                file is missing.
            RuntimeError: if decoding, preprocessing or encoding fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return _skip(self, sample, quality_skip_reason)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return _skip(self, sample, "non_audio_or_reference_file")

        # Reject a bad configuration before spending time on decode/resample.
        subtype = _PCM_SUBTYPES.get(self.bit_depth)
        if subtype is None:
            raise ValueError(f"不支持的 bitDepth: {self.bit_depth},仅支持 8/16/24/32")

        data, sr = _load_audio(_audio_source(sample, self.data_key, self.filepath_key))  # (T, C)
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if self.channels == 1 and x.shape[1] > 1:
                x = x.mean(axis=1, keepdims=True)
            elif self.channels == 2 and x.shape[1] == 1:
                x = x.repeat(2, axis=1)
            x = _resample_linear(x, sr, self.sample_rate) if self.sample_rate > 0 else x
            out_sr = int(self.sample_rate) if self.sample_rate > 0 else int(sr)
        except Exception as e:
            raise RuntimeError(f"预处理失败: {e}") from e

        sample[self.data_key] = _dump_wav_pcm(x, out_sr, subtype=subtype)
        sample[self.text_key] = ""
        sample[self.target_type_key] = "wav"
        # NOTE(review): fileType is "txt" although the payload is audio;
        # presumably this steers the DataMate exporter branch - confirm before changing.
        sample[self.filetype_key] = "txt"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioQuantizeEncode costs {time.time() - start:6f} s"
        )
        return sample
peakCeiling | slider | 0.99 | 峰值顶限(0~1) | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_rms_loudness_normalize/__init__.py b/runtime/ops/mapper/audio_rms_loudness_normalize/__init__.py new file mode 100644 index 00000000..61d920f3 --- /dev/null +++ b/runtime/ops/mapper/audio_rms_loudness_normalize/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioRmsLoudnessNormalize', + module_path="ops.mapper.audio_rms_loudness_normalize.process") diff --git a/runtime/ops/mapper/audio_rms_loudness_normalize/audio_skip.py b/runtime/ops/mapper/audio_rms_loudness_normalize/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_rms_loudness_normalize/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = 
str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) or "")).stem.lower() + marker = "__quality_invalid" + if marker in marker_source: + reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio" + return f"invalid_audio_quality:{reason}" + + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + return "" + quality = ext.get("audio_quality", {}) + if not isinstance(quality, dict): + return "" + if str(quality.get("quality_flag") or "").strip().lower() != "invalid": + return "" + skip_downstream = quality.get("skip_downstream", True) + if isinstance(skip_downstream, str): + skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"} + if not skip_downstream: + return "" + reason = str(quality.get("reason") or "invalid_audio").strip() + return f"invalid_audio_quality:{reason}" + + +def mark_skipped_sample( + sample: Dict[str, Any], + reason: str, + op_name: str, + text_key: str = "text", + data_key: str = "data", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + ext_params_key: str = "ext_params", +) -> Dict[str, Any]: + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + 
sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample diff --git a/runtime/ops/mapper/audio_rms_loudness_normalize/metadata.yml b/runtime/ops/mapper/audio_rms_loudness_normalize/metadata.yml new file mode 100644 index 00000000..9290df86 --- /dev/null +++ b/runtime/ops/mapper/audio_rms_loudness_normalize/metadata.yml @@ -0,0 +1,42 @@ +name: 'audioUtils-整段RMS归一 + 峰值顶限' +name_en: 'audioUtils-RMS Loudness Normalize' +description: '将整段 RMS 对齐到目标,再按峰值顶限缩放。处理音频并由 DataMate 统一导出结果。' +description_en: 'Normalize full-utterance RMS to target and apply peak ceiling. Process audio and let DataMate export the result.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioRmsLoudnessNormalize' +version: '1.0.0' +types: + - 'cleaning' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + targetRms: + name: '目标RMS' + type: 'slider' + description: '线性 RMS(0~1),越大越响。' + defaultVal: 0.08 + min: 0.001 + max: 0.5 + step: 0.001 + peakCeiling: + name: '峰值顶限' + type: 'slider' + description: '峰值限制(0~1)。' + defaultVal: 0.99 + min: 0.1 + max: 1 + step: 0.01 +runtime: + memory: 104857600 + cpu: 0.12 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_rms_loudness_normalize/process.py b/runtime/ops/mapper/audio_rms_loudness_normalize/process.py new file mode 100644 index 00000000..fda93ac4 --- /dev/null +++ b/runtime/ops/mapper/audio_rms_loudness_normalize/process.py @@ -0,0 +1,110 @@ +# -- encoding: utf-8 -- + +import io +import time +from pathlib import Path +from typing import Dict, Any, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import 
def _load_audio(source: object) -> Tuple["object", int]:
    """Decode audio samples and the sample rate with ``soundfile``.

    ``source`` is either encoded audio bytes, read through an in-memory
    buffer, or any other value, which is converted with ``str()`` and opened
    as a file path.

    Raises:
        RuntimeError: if ``soundfile`` cannot be imported or decoding fails.
    """
    try:
        import soundfile as sf  # type: ignore

        if isinstance(source, (bytes, bytearray)):
            data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False)
        else:
            data, sr = sf.read(str(source), always_2d=False)
        return data, int(sr)
    except Exception as e:
        raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e


def _dump_audio(data: "object", sr: int, fmt: str) -> bytes:
    """Encode ``data`` at sample rate ``sr`` and return the file bytes.

    ``fmt`` is upper-cased and passed to ``soundfile`` as the container
    format; an empty value falls back to ``"WAV"``.

    Raises:
        RuntimeError: if ``soundfile`` cannot be imported or encoding fails.
    """
    try:
        import soundfile as sf  # type: ignore

        with io.BytesIO() as buf:
            sf.write(buf, data, int(sr), format=fmt.upper() if fmt else "WAV")
            return buf.getvalue()
    except Exception as e:
        raise RuntimeError(f"编码音频失败(需要 soundfile,fmt={fmt}): {e}") from e


def _audio_source(sample: Dict[str, Any], data_key: str, filepath_key: str) -> object:
    """Pick what to decode: the upstream payload if present, else the input file.

    The file path is validated only when it is actually needed, so a sample
    that already carries audio in ``sample[data_key]`` does not depend on the
    original file still being on disk.

    Raises:
        FileNotFoundError: if there is no payload and the path is empty or
            does not exist.
    """
    payload = sample.get(data_key)
    if payload:
        return payload

    raw_path = str(sample.get(filepath_key) or "")
    # Path("").resolve() is the working directory, which always exists, so an
    # empty path has to be rejected before resolving.
    if not raw_path:
        raise FileNotFoundError(f"输入音频不存在: {raw_path}")
    in_path = Path(raw_path).resolve()
    if not in_path.exists():
        raise FileNotFoundError(f"输入音频不存在: {in_path}")
    return in_path


def _skip(op: "AudioRmsLoudnessNormalize", sample: Dict[str, Any], reason: str) -> Dict[str, Any]:
    """Mark ``sample`` as skipped by ``op`` using the operator's configured keys."""
    return mark_skipped_sample(
        sample,
        reason,
        op.__class__.__name__,
        op.text_key,
        op.data_key,
        op.filetype_key,
        op.target_type_key,
        op.ext_params_key,
    )


class AudioRmsLoudnessNormalize(Mapper):
    """Scale a whole clip to a target RMS, then cap its peak.

    Multi-channel input is averaged to one channel first and the result is
    always encoded as WAV.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_rms = float(kwargs.get("targetRms", 0.08))
        self.peak_ceiling = float(kwargs.get("peakCeiling", 0.99))
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Normalise one sample in place and return it.

        Samples flagged as invalid quality, non-audio samples and reference
        files are marked as skipped instead of processed.

        Raises:
            FileNotFoundError: if no upstream payload exists and the input
                file is missing.
            RuntimeError: if decoding, normalising or encoding fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return _skip(self, sample, quality_skip_reason)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return _skip(self, sample, "non_audio_or_reference_file")

        data, sr = _load_audio(_audio_source(sample, self.data_key, self.filepath_key))
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                x = x.mean(axis=1)  # average the channels down to mono
            if x.size == 0:
                y = x
            else:
                eps = 1e-8  # keeps the RMS and the gain finite for silent input
                rms = float(np.sqrt(np.mean(x * x) + eps))
                g = float(self.target_rms) / max(eps, rms)
                y = x * g
                peak = float(np.max(np.abs(y)) + eps)
                # The ceiling is confined to [1e-6, 1.0] whatever was configured.
                ceiling = max(1e-6, min(1.0, float(self.peak_ceiling)))
                if peak > ceiling:
                    y = y * (ceiling / peak)
                y = np.clip(y, -1.0, 1.0)
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        # NOTE(review): fileType is "txt" although the payload is audio;
        # presumably this steers the DataMate exporter branch - confirm before changing.
        sample[self.filetype_key] = "txt"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioRmsLoudnessNormalize costs {time.time() - start:6f} s"
        )
        return sample
+ +Important pinned packages: + +- `torch==2.8.0` +- `torch_npu==2.8.0` +- `torchaudio==2.8.0` +- `speechbrain==1.0.3` +- `pydub==0.25.1` +- `soundfile==0.12.1` +- `numpy==2.2.6` +- `scipy==1.13.1` +- `onnxruntime==1.19.2` +- `transformers==4.57.6` +- `timm==1.0.26` +- `panns-inference==0.1.1` + +## System Packages + +- `ffmpeg==6.1.1` is the recommended runtime binary version. `pydub` uses the `ffmpeg` command on `PATH` for formats such as `mp3`, `aac`, `m4a`, and some `flac` paths. If the DataMate base image must use OS packages, keep the installed `ffmpeg` at `>=4.4` and record the exact OS package version in the image build manifest. + +Recommended runtime check: + +```bash +ffmpeg -version +python -c "import torch, torchaudio, speechbrain, pydub, soundfile, numpy, scipy, onnxruntime, transformers, timm" +``` + +## WeNet + +`audio_asr_transcribe` and `audio_asr_pipeline` import `wenet.bin.recognize` from the runtime environment. + +The project previously carried WeNet source under each ASR operator. That source has been removed from the operator package. The removed vendored source does not expose a package version in `wenet/__init__.py`, so this cleanup cannot derive a reliable semantic version from the previous copy. Since `wenet` is not available as a normal PyPI package in this environment, DataMate deployment must provide WeNet with a fixed source pin and record that pin as the runtime version, using one of: + +- an internal wheel with a fixed version, installed into the runtime image; +- a fixed git tag or commit of `wenet-e2e/wenet`, installed during image build; +- a system Python package placed on the runtime `PYTHONPATH`. + +The runtime must satisfy: + +```bash +python -c "from wenet.bin.recognize import main" +``` + +Do not rely on an operator-local `local_libs/wenet` directory. 
+ +## Model Assets + +Model weights are still external runtime assets and are not Python dependencies: + +- LID model: `/models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa` +- Chinese ASR model: `/models/AudioOperations/asr/aishell` +- English ASR model: `/models/AudioOperations/asr/librispeech` +- GTCRN model: `/models/AudioOperations/gtcrn/gtcrn.onnx` +- AST model: `/models/AudioOperations/recog/audioset_10_10_0.4593.pth` +- PANNs model: `/models/AudioOperations/panns/Cnn14_16k_mAP=0.438.pth` + +## Operators Affected + +The following operators now depend on the DataMate runtime environment instead of vendored libraries: + +- `audio_fast_lang_id` +- `audio_fast_lang_id_text` +- `audio_asr_transcribe` +- `audio_asr_pipeline` +- `audio_format_convert` +- `audio_sound_classify` diff --git a/runtime/ops/mapper/audio_runtime_requirements.txt b/runtime/ops/mapper/audio_runtime_requirements.txt new file mode 100644 index 00000000..6d5f78ae --- /dev/null +++ b/runtime/ops/mapper/audio_runtime_requirements.txt @@ -0,0 +1,24 @@ +# DataMate audio operators runtime dependencies. +# Install these into the DataMate runtime image/environment, not inside each operator package. 
+ +torch==2.8.0 +torch_npu==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 +numpy==2.2.6 +scipy==1.13.1 +soundfile==0.12.1 +pydub==0.25.1 +speechbrain==1.0.3 +HyperPyYAML==1.2.2 +onnxruntime==1.19.2 +jieba==0.42.1 +transformers==4.57.6 +safetensors==0.7.0 +librosa==0.10.2.post1 +torchlibrosa==0.0.4 +timm==1.0.26 +sentencepiece==0.2.1 +PyYAML==6.0.2 +loguru==0.7.3 +panns-inference==0.1.1 diff --git a/runtime/ops/mapper/audio_simple_agc/README.md b/runtime/ops/mapper/audio_simple_agc/README.md new file mode 100644 index 00000000..ab0c7a83 --- /dev/null +++ b/runtime/ops/mapper/audio_simple_agc/README.md @@ -0,0 +1,27 @@ +# AudioSimpleAgc 分段 RMS 自动增益算子 + +## 概述 + +AudioSimpleAgc 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| targetRms | slider | 0.05 | 目标 RMS(线性) | +| frameMs | inputNumber | 50 | 帧长(ms) | +| hopMs | inputNumber | 25 | 帧移(ms) | +| maxGain | slider | 10 | 最大线性增益 | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_simple_agc/__init__.py b/runtime/ops/mapper/audio_simple_agc/__init__.py new file mode 100644 index 00000000..6f89d91f --- /dev/null +++ b/runtime/ops/mapper/audio_simple_agc/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioSimpleAgc', + module_path="ops.mapper.audio_simple_agc.process") diff --git a/runtime/ops/mapper/audio_simple_agc/audio_skip.py b/runtime/ops/mapper/audio_simple_agc/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_simple_agc/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from 
# Lower-case file extensions (without the dot) that are treated as audio.
AUDIO_EXTS = {
    "aac",
    "aif",
    "aiff",
    "amr",
    "au",
    "flac",
    "m4a",
    "mp3",
    "oga",
    "ogg",
    "opus",
    "snd",
    "wav",
    "webm",
    "wma",
}

# File-name marker that flags a sample as invalid audio.
_QUALITY_INVALID_MARKER = "__quality_invalid"
# String spellings accepted as a true ``skip_downstream`` flag.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "y", "on"})


def _parts(path_value: str) -> set[str]:
    """Return the lower-cased components of *path_value*, or an empty set on failure."""
    try:
        return {part.lower() for part in Path(path_value).parts}
    except Exception:
        # Best effort: a path that cannot be parsed has no components.
        return set()


def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool:
    """Return True when any component of the sample's path is named ``references``."""
    path_value = str(sample.get(filepath_key) or "")
    return "references" in _parts(path_value)


def _ext_from_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
) -> str:
    """Return the sample's extension, lower-cased and without a leading dot.

    The first non-empty value of ``target_type`` and ``fileType`` wins; the
    suffix of the file path is only used when both are empty.
    """
    for key in (target_type_key, filetype_key):
        value = str(sample.get(key) or "").strip().lower().lstrip(".")
        if value:
            return value
    path_value = str(sample.get(filepath_key) or "").strip()
    return Path(path_value).suffix.lower().lstrip(".") if path_value else ""


def is_audio_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    data_key: str = "data",
) -> bool:
    """Return True when the sample should be processed as audio.

    Reference samples are never audio. Otherwise a non-empty bytes payload
    counts as audio regardless of extension, and without a payload the
    extension must be listed in ``AUDIO_EXTS``.
    """
    if is_reference_sample(sample, filepath_key):
        return False
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        return True
    return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS


def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return ``"invalid_audio_quality:<reason>"`` if the sample must be skipped, else ``""``.

    Two sources are consulted, in order:

    1. a ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``;
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"`` and a
       truthy ``skip_downstream`` (defaults to True when the key is absent).
    """
    for key in ("fileName", "sourceFileName", "filePath"):
        stem = Path(str(sample.get(key) or "")).stem.lower()
        if _QUALITY_INVALID_MARKER in stem:
            reason = stem.split(_QUALITY_INVALID_MARKER, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""
    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        skip_downstream = skip_downstream.strip().lower() in _TRUTHY_STRINGS
    if not skip_downstream:
        return ""
    # Strip before applying the default so a whitespace-only reason does not
    # produce a dangling "invalid_audio_quality:" value.
    reason = str(quality.get("reason") or "").strip() or "invalid_audio"
    return f"invalid_audio_quality:{reason}"


def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped the sample and blank its payload fields.

    The reason is stored under ``ext_params["audio_skip"][op_name]``; text,
    data, file type and target type are reset to empty values. The sample is
    modified in place and returned.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}
    skips = ext.get("audio_skip")
    if not isinstance(skips, dict):
        # A null or scalar value here would make the item assignment below
        # raise TypeError; start a fresh mapping and keep any old value.
        skips = {} if skips is None else {"_raw": skips}
        ext["audio_skip"] = skips
    skips[op_name] = reason
    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
def _load_audio(source: object) -> Tuple["object", int]:
    """Decode audio with soundfile from in-memory bytes or from a file path.

    Returns the decoded samples and the sample rate. Any failure, including a
    missing soundfile package, is re-raised as RuntimeError.
    """
    try:
        import soundfile as sf  # type: ignore

        if isinstance(source, (bytes, bytearray)):
            data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False)
        else:
            data, sr = sf.read(str(source), always_2d=False)
        return data, int(sr)
    except Exception as e:
        raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e


def _dump_audio(data: "object", sr: int, fmt: str) -> bytes:
    """Encode *data* at sample rate *sr* into bytes of format *fmt* (WAV when empty)."""
    try:
        import soundfile as sf  # type: ignore

        with io.BytesIO() as buf:
            sf.write(buf, data, int(sr), format=fmt.upper() if fmt else "WAV")
            return buf.getvalue()
    except Exception as e:
        raise RuntimeError(f"编码音频失败(需要 soundfile,fmt={fmt}): {e}") from e


def _frame_agc(x, sr: int, target_rms: float, frame_ms: float, hop_ms: float, max_gain: float):
    """Apply frame-wise RMS gain control to a mono float signal.

    The gain of each step is estimated on an analysis window of ``frame_ms``
    starting at the step and applied to the ``hop_ms`` segment only, so every
    sample is scaled by exactly one gain. The result is clipped to [-1, 1].
    """
    import numpy as np

    if x.size == 0:
        return x
    frame_len = max(1, int(sr * frame_ms / 1000.0))
    hop = max(1, int(sr * hop_ms / 1000.0))
    eps = 1e-8
    min_gain = 1.0 / max(1.0, max_gain)
    total = len(x)
    y = x.copy()
    for st in range(0, total, hop):
        frame = x[st:min(st + frame_len, total)]
        rms = float(np.sqrt(np.mean(frame * frame) + eps))
        g = float(target_rms) / max(eps, rms)
        g = max(min_gain, min(float(max_gain), g))
        # Scale the hop segment, not the whole analysis window: overlapping
        # windows would otherwise multiply the same samples by two gains.
        seg_end = min(st + hop, total)
        y[st:seg_end] = x[st:seg_end] * g
    # Simple overflow guard: limit to [-1, 1].
    return np.clip(y, -1.0, 1.0)


class AudioSimpleAgc(Mapper):
    """Mapper that pulls the frame-wise RMS level of an audio sample towards a target."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_rms = float(kwargs.get("targetRms", 0.05))
        self.frame_ms = float(kwargs.get("frameMs", 50))
        self.hop_ms = float(kwargs.get("hopMs", 25))
        self.max_gain = float(kwargs.get("maxGain", 10))
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run AGC on one sample and store the encoded result in the data field.

        Samples flagged as invalid audio, non-audio samples and reference
        files are marked as skipped instead of processed.

        Raises:
            FileNotFoundError: no in-memory audio is present and the file path
                does not exist.
            RuntimeError: decoding, processing or encoding failed.
        """
        start = time.time()
        op_name = self.__class__.__name__
        skip_keys = (
            self.text_key,
            self.data_key,
            self.filetype_key,
            self.target_type_key,
            self.ext_params_key,
        )
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(sample, quality_skip_reason, op_name, *skip_keys)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(sample, "non_audio_or_reference_file", op_name, *skip_keys)

        # Audio produced by an upstream operator takes priority; the file on
        # disk is only required when there is no in-memory payload.
        source = sample.get(self.data_key)
        if not source:
            in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
            if not in_path.exists():
                raise FileNotFoundError(f"输入音频不存在: {in_path}")
            source = in_path

        data, sr = _load_audio(source)
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                # Down-mix multi-channel audio to mono.
                x = x.mean(axis=1)
            y = _frame_agc(x, sr, self.target_rms, self.frame_ms, self.hop_ms, self.max_gain)
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        sample[self.filetype_key] = "txt"

        logger.info(f"fileName: {sample.get(self.filename_key)}, method: AudioSimpleAgc costs {time.time() - start:6f} s")
        return sample
# Lower-case file extensions (without the dot) that are treated as audio.
AUDIO_EXTS = {
    "aac",
    "aif",
    "aiff",
    "amr",
    "au",
    "flac",
    "m4a",
    "mp3",
    "oga",
    "ogg",
    "opus",
    "snd",
    "wav",
    "webm",
    "wma",
}

# File-name marker that flags a sample as invalid audio.
_QUALITY_INVALID_MARKER = "__quality_invalid"
# String spellings accepted as a true ``skip_downstream`` flag.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "y", "on"})


def _parts(path_value: str) -> set[str]:
    """Return the lower-cased components of *path_value*, or an empty set on failure."""
    try:
        return {part.lower() for part in Path(path_value).parts}
    except Exception:
        # Best effort: a path that cannot be parsed has no components.
        return set()


def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool:
    """Return True when any component of the sample's path is named ``references``."""
    path_value = str(sample.get(filepath_key) or "")
    return "references" in _parts(path_value)


def _ext_from_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
) -> str:
    """Return the sample's extension, lower-cased and without a leading dot.

    The first non-empty value of ``target_type`` and ``fileType`` wins; the
    suffix of the file path is only used when both are empty.
    """
    for key in (target_type_key, filetype_key):
        value = str(sample.get(key) or "").strip().lower().lstrip(".")
        if value:
            return value
    path_value = str(sample.get(filepath_key) or "").strip()
    return Path(path_value).suffix.lower().lstrip(".") if path_value else ""


def is_audio_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    data_key: str = "data",
) -> bool:
    """Return True when the sample should be processed as audio.

    Reference samples are never audio. Otherwise a non-empty bytes payload
    counts as audio regardless of extension, and without a payload the
    extension must be listed in ``AUDIO_EXTS``.
    """
    if is_reference_sample(sample, filepath_key):
        return False
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        return True
    return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS


def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return ``"invalid_audio_quality:<reason>"`` if the sample must be skipped, else ``""``.

    Two sources are consulted, in order:

    1. a ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``;
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"`` and a
       truthy ``skip_downstream`` (defaults to True when the key is absent).
    """
    for key in ("fileName", "sourceFileName", "filePath"):
        stem = Path(str(sample.get(key) or "")).stem.lower()
        if _QUALITY_INVALID_MARKER in stem:
            reason = stem.split(_QUALITY_INVALID_MARKER, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""
    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        skip_downstream = skip_downstream.strip().lower() in _TRUTHY_STRINGS
    if not skip_downstream:
        return ""
    # Strip before applying the default so a whitespace-only reason does not
    # produce a dangling "invalid_audio_quality:" value.
    reason = str(quality.get("reason") or "").strip() or "invalid_audio"
    return f"invalid_audio_quality:{reason}"


def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped the sample and blank its payload fields.

    The reason is stored under ``ext_params["audio_skip"][op_name]``; text,
    data, file type and target type are reset to empty values. The sample is
    modified in place and returned.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}
    skips = ext.get("audio_skip")
    if not isinstance(skips, dict):
        # A null or scalar value here would make the item assignment below
        # raise TypeError; start a fresh mapping and keep any old value.
        skips = {} if skips is None else {"_raw": skips}
        ext["audio_skip"] = skips
    skips[op_name] = reason
    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
def _load_audio(source: object) -> Tuple["object", int]:
    """Decode audio with soundfile from in-memory bytes or from a file path.

    Returns the decoded samples and the sample rate. Any failure, including a
    missing soundfile package, is re-raised as RuntimeError.
    """
    try:
        import soundfile as sf  # type: ignore

        if isinstance(source, (bytes, bytearray)):
            data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False)
        else:
            data, sr = sf.read(str(source), always_2d=False)
        return data, int(sr)
    except Exception as e:
        raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e


def _dump_audio(data: "object", sr: int, fmt: str) -> bytes:
    """Encode *data* at sample rate *sr* into bytes of format *fmt* (WAV when empty)."""
    try:
        import soundfile as sf  # type: ignore

        with io.BytesIO() as buf:
            sf.write(buf, data, int(sr), format=fmt.upper() if fmt else "WAV")
            return buf.getvalue()
    except Exception as e:
        raise RuntimeError(f"编码音频失败(需要 soundfile,fmt={fmt}): {e}") from e


def _soft_limit(x, threshold: float, knee: float):
    """Compress the part of a mono float signal that exceeds *threshold* with tanh.

    Samples whose magnitude is at most the threshold pass through unchanged;
    larger magnitudes are mapped into (threshold, 1]. The result is clipped to
    [-1, 1]. An empty signal is returned as is.
    """
    import numpy as np

    if x.size == 0:
        return x
    th = max(1e-6, min(1.0, float(threshold)))
    knee = max(0.0, min(1.0, float(knee)))
    # Simple soft limiting: tanh-compress the excess over the threshold; the
    # knee controls how strongly the excess is compressed.
    slope = 1.0 / max(1e-6, (1.0 - th + knee))
    y = x.copy()
    magnitude = np.abs(x)
    over = magnitude > th
    sign = np.sign(x[over])
    excess = (magnitude[over] - th) * slope
    y[over] = sign * (th + (1.0 - th) * np.tanh(excess))
    return np.clip(y, -1.0, 1.0)


class AudioSoftPeakLimiter(Mapper):
    """Mapper that soft-limits the peaks of an audio sample."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = float(kwargs.get("threshold", 0.92))
        self.knee = float(kwargs.get("knee", 0.08))
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Soft-limit one sample and store the encoded result in the data field.

        Samples flagged as invalid audio, non-audio samples and reference
        files are marked as skipped instead of processed.

        Raises:
            FileNotFoundError: no in-memory audio is present and the file path
                does not exist.
            RuntimeError: decoding, processing or encoding failed.
        """
        start = time.time()
        op_name = self.__class__.__name__
        skip_keys = (
            self.text_key,
            self.data_key,
            self.filetype_key,
            self.target_type_key,
            self.ext_params_key,
        )
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(sample, quality_skip_reason, op_name, *skip_keys)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(sample, "non_audio_or_reference_file", op_name, *skip_keys)

        # Audio produced by an upstream operator takes priority; the file on
        # disk is only required when there is no in-memory payload.
        source = sample.get(self.data_key)
        if not source:
            in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
            if not in_path.exists():
                raise FileNotFoundError(f"输入音频不存在: {in_path}")
            source = in_path

        data, sr = _load_audio(source)
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                # Down-mix multi-channel audio to mono.
                x = x.mean(axis=1)
            y = _soft_limit(x, self.threshold, self.knee)
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        sample[self.filetype_key] = "txt"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioSoftPeakLimiter costs {time.time() - start:6f} s"
        )
        return sample
+ +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioSoundClassify', + module_path="ops.mapper.audio_sound_classify.process") diff --git a/runtime/ops/mapper/audio_sound_classify/ast_vendor/__init__.py b/runtime/ops/mapper/audio_sound_classify/ast_vendor/__init__.py new file mode 100644 index 00000000..a186282a --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/ast_vendor/__init__.py @@ -0,0 +1,2 @@ +from .ast_models import ASTConfig, ASTModel, load_ast_from_pth + diff --git a/runtime/ops/mapper/audio_sound_classify/ast_vendor/ast_models.py b/runtime/ops/mapper/audio_sound_classify/ast_vendor/ast_models.py new file mode 100644 index 00000000..f9f86c45 --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/ast_vendor/ast_models.py @@ -0,0 +1,293 @@ +""" +Vendored minimal AST (Audio Spectrogram Transformer) model definition. + +来源:YuanGongND/ast(Interspeech 2021, AST: Audio Spectrogram Transformer) +为了适配本工程: +- 不在运行时下载任何权重(无外网依赖) +- 不强制 timm 版本(尽量兼容常见版本) +- 不使用 CUDA autocast 装饰器(避免在 NPU/CPU 环境报错) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn + +try: + import timm # type: ignore + from timm.models.layers import to_2tuple, trunc_normal_ # type: ignore +except Exception as e: # pragma: no cover + raise RuntimeError( + "缺少依赖 timm,AST 模型无法创建。请在环境中安装 timm(建议与 AST 兼容的版本)。\n" + "例如:pip install timm" + ) from e + + +class PatchEmbed(nn.Module): + """Override timm PatchEmbed: relax input shape constraint.""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, 
stride=patch_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class ASTModel(nn.Module): + """ + AST model (inference use). + + Input: [batch, time_frame_num, frequency_bins] => e.g. [B, 1024, 128] + Output: [batch, label_dim] raw logits (no sigmoid/softmax) + """ + + def __init__( + self, + *, + label_dim: int = 527, + fstride: int = 10, + tstride: int = 10, + input_fdim: int = 128, + input_tdim: int = 1024, + imagenet_pretrain: bool = True, + model_size: str = "base384", + verbose: bool = False, + ) -> None: + super().__init__() + + if verbose: + print("---------------AST Model Summary---------------", flush=True) + print( + f"ImageNet pretraining: {imagenet_pretrain}, model_size={model_size}", + flush=True, + ) + + # override timm input shape restriction + # timm 0.x: timm.models.vision_transformer.PatchEmbed + # timm 1.x: timm.layers.patch_embed.PatchEmbed + try: + timm.models.vision_transformer.PatchEmbed = PatchEmbed # type: ignore[attr-defined] + except Exception: + pass + try: + import timm.layers # type: ignore + + timm.layers.PatchEmbed = PatchEmbed # type: ignore[attr-defined] + except Exception: + pass + try: + import timm.layers.patch_embed as _pe # type: ignore + + _pe.PatchEmbed = PatchEmbed # type: ignore[attr-defined] + except Exception: + pass + + if model_size == "tiny224": + self.v = timm.create_model( + "vit_deit_tiny_distilled_patch16_224", pretrained=imagenet_pretrain + ) + elif model_size == "small224": + self.v = timm.create_model( + "vit_deit_small_distilled_patch16_224", pretrained=imagenet_pretrain + ) + elif model_size == "base224": + self.v = timm.create_model( + "vit_deit_base_distilled_patch16_224", pretrained=imagenet_pretrain + ) + elif model_size == "base384": + # timm 新版本(>=1.x)模型命名与 AST 原仓库不同,这里做兼容回退 + cand = [ + "vit_deit_base_distilled_patch16_384", # AST 原仓库 + "deit_base_distilled_patch16_384", # timm 1.x + "deit_base_patch16_384", # 
无蒸馏token的备选(仍可跑推理,但权重需匹配) + ] + last_err: Exception | None = None + for name in cand: + try: + self.v = timm.create_model(name, pretrained=imagenet_pretrain) + break + except Exception as e: + last_err = e + continue + else: + raise RuntimeError(f"timm 中未找到可用的 deit 384 模型名,尝试过: {cand}") from last_err + else: + raise ValueError("model_size 必须是 tiny224/small224/base224/base384 之一。") + + self.original_num_patches = int(self.v.patch_embed.num_patches) + self.original_hw = int(self.original_num_patches**0.5) + self.original_embedding_dim = int(self.v.pos_embed.shape[2]) + + # timm 1.x 的 PatchEmbed 会强校验输入 img_size,这里直接替换为 AST 版本(无 shape assert) + # 注意:后续会重新设置 num_patches / proj / pos_embed。 + self.v.patch_embed = PatchEmbed( + img_size=(int(input_fdim), int(input_tdim)), + patch_size=16, + in_chans=1, + embed_dim=self.original_embedding_dim, + ) + + self.mlp_head = nn.Sequential( + nn.LayerNorm(self.original_embedding_dim), + nn.Linear(self.original_embedding_dim, int(label_dim)), + ) + + f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim) + num_patches = int(f_dim * t_dim) + self.v.patch_embed.num_patches = num_patches + if verbose: + print(f"frequency stride={fstride}, time stride={tstride}", flush=True) + print(f"patches={num_patches} (f_dim={f_dim}, t_dim={t_dim})", flush=True) + + # projection layer: 1 channel input + new_proj = nn.Conv2d( + 1, + self.original_embedding_dim, + kernel_size=(16, 16), + stride=(int(fstride), int(tstride)), + ) + if imagenet_pretrain: + # sum RGB weights -> single-channel init + new_proj.weight = nn.Parameter( + torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1) + ) + new_proj.bias = self.v.patch_embed.proj.bias + self.v.patch_embed.proj = new_proj + + # positional embedding adaptation + if imagenet_pretrain: + # skip cls & dist tokens, reshape pos embed to 2D + pos = ( + self.v.pos_embed[:, 2:, :] + .detach() + .reshape(1, self.original_num_patches, self.original_embedding_dim) + .transpose(1, 2) + 
.reshape(1, self.original_embedding_dim, self.original_hw, self.original_hw) + ) + # time dim + if t_dim <= self.original_hw: + start = int(self.original_hw / 2) - int(t_dim / 2) + pos = pos[:, :, :, start : start + int(t_dim)] + else: + pos = torch.nn.functional.interpolate(pos, size=(self.original_hw, int(t_dim)), mode="bilinear") + # freq dim + if f_dim <= self.original_hw: + start = int(self.original_hw / 2) - int(f_dim / 2) + pos = pos[:, :, start : start + int(f_dim), :] + else: + pos = torch.nn.functional.interpolate(pos, size=(int(f_dim), int(t_dim)), mode="bilinear") + + pos = pos.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2) + self.v.pos_embed = nn.Parameter( + torch.cat([self.v.pos_embed[:, :2, :].detach(), pos], dim=1) + ) + else: + self.v.pos_embed = nn.Parameter( + torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim) + ) + trunc_normal_(self.v.pos_embed, std=0.02) + + def get_shape( + self, fstride: int, tstride: int, input_fdim: int = 128, input_tdim: int = 1024 + ) -> Tuple[int, int]: + test_input = torch.randn(1, 1, int(input_fdim), int(input_tdim)) + test_proj = nn.Conv2d( + 1, + self.original_embedding_dim, + kernel_size=(16, 16), + stride=(int(fstride), int(tstride)), + ) + test_out = test_proj(test_input) + return int(test_out.shape[2]), int(test_out.shape[3]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, T, F) -> (B, 1, F, T) + x = x.unsqueeze(1).transpose(2, 3) + bsz = x.shape[0] + + x = self.v.patch_embed(x) + cls_tokens = self.v.cls_token.expand(bsz, -1, -1) + dist_token = self.v.dist_token.expand(bsz, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + x = x + self.v.pos_embed + x = self.v.pos_drop(x) + for blk in self.v.blocks: + x = blk(x) + x = self.v.norm(x) + x = (x[:, 0] + x[:, 1]) / 2 + x = self.mlp_head(x) + return x + + +@dataclass(frozen=True) +class ASTConfig: + label_dim: int = 527 + fstride: int = 10 + tstride: int = 10 + input_fdim: int = 128 + 
input_tdim: int = 1024 + model_size: str = "base384" + + +def _strip_module_prefix(state: Dict[str, Any]) -> Dict[str, Any]: + if any(k.startswith("module.") for k in state.keys()): + return {k[len("module.") :]: v for k, v in state.items()} + return state + + +def load_ast_from_pth( + *, + checkpoint_path: str, + device: torch.device, + cfg: ASTConfig = ASTConfig(), +) -> ASTModel: + """ + 从本地 .pth 加载 AST(AudioSet 0.4593 权重)用于推理。 + 兼容: + - 直接 state_dict + - 包在 dict 里(如 {'state_dict': ...} / {'model': ...}) + - DataParallel 前缀 module.* + """ + model = ASTModel( + label_dim=cfg.label_dim, + fstride=cfg.fstride, + tstride=cfg.tstride, + input_fdim=cfg.input_fdim, + input_tdim=cfg.input_tdim, + imagenet_pretrain=False, + model_size=cfg.model_size, + verbose=False, + ) + obj = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict): + state = obj["state_dict"] + elif isinstance(obj, dict) and "model" in obj and isinstance(obj["model"], dict): + state = obj["model"] + elif isinstance(obj, dict) and all(isinstance(k, str) for k in obj.keys()): + # assume it's a raw state_dict + state = obj + else: + raise ValueError("不支持的 checkpoint 格式,无法解析 state_dict。") + + state = _strip_module_prefix(state) + missing, unexpected = model.load_state_dict(state, strict=False) + if missing: + # 一般不会影响推理(例如部分 buffer),但需要显式暴露出来方便排障 + print(f"[WARN] AST missing keys: {len(missing)}", flush=True) + if unexpected: + print(f"[WARN] AST unexpected keys: {len(unexpected)}", flush=True) + model.to(device) + model.eval() + return model + diff --git a/runtime/ops/mapper/audio_sound_classify/audio_skip.py b/runtime/ops/mapper/audio_sound_classify/audio_skip.py new file mode 100644 index 00000000..796d4c66 --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/audio_skip.py @@ -0,0 +1,119 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +try: + from loguru 
try:
    from loguru import logger
except Exception:
    import logging

    logger = logging.getLogger(__name__)


# File extensions (lower-case, no leading dot) treated as audio.
AUDIO_EXTS = {
    "aac",
    "aif",
    "aiff",
    "amr",
    "au",
    "flac",
    "m4a",
    "mp3",
    "oga",
    "ogg",
    "opus",
    "snd",
    "wav",
    "webm",
    "wma",
}


def _parts(path_value: str) -> set[str]:
    """Return the lower-cased components of *path_value* as a set.

    Any exception raised while parsing the path yields an empty set.
    """
    try:
        return {part.lower() for part in Path(path_value).parts}
    except Exception:
        return set()


def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool:
    """Tell whether the sample's path has a component named ``references``.

    The comparison is case-insensitive; a missing or falsy path counts as
    "not a reference".
    """
    path_value = str(sample.get(filepath_key) or "")
    return "references" in _parts(path_value)


def _ext_from_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
) -> str:
    """Return the sample's type/extension, lower-cased and without a dot.

    Lookup order: ``target_type``, then ``fileType``, then the suffix of the
    file path.  Returns ``""`` when none of them yields a value.
    """
    for key in (target_type_key, filetype_key):
        value = str(sample.get(key) or "").strip().lower().lstrip(".")
        if value:
            return value
    path_value = str(sample.get(filepath_key) or "").strip()
    if not path_value:
        return ""
    return Path(path_value).suffix.lower().lstrip(".")


def is_audio_sample(
    sample: Dict[str, Any],
    filepath_key: str = "filePath",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    data_key: str = "data",
) -> bool:
    """Decide whether *sample* should be handled as audio.

    Reference samples are never audio.  Otherwise a sample counts as audio
    when it carries non-empty ``bytes``/``bytearray`` data, or when its
    resolved extension is listed in ``AUDIO_EXTS``.
    """
    if is_reference_sample(sample, filepath_key):
        return False
    data = sample.get(data_key)
    if isinstance(data, (bytes, bytearray)) and data:
        return True
    ext = _ext_from_sample(sample, filepath_key, filetype_key, target_type_key)
    return ext in AUDIO_EXTS


def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return a skip reason when the sample is flagged as invalid audio.

    Two sources are consulted, in order:

    1. A ``__quality_invalid`` marker inside the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``; the text after the marker (with
       surrounding underscores stripped) becomes the reason.
    2. ``ext_params["audio_quality"]``: the flag must be ``invalid`` and
       ``skip_downstream`` (default ``True``; strings are parsed as booleans)
       must be truthy.

    Returns:
        ``"invalid_audio_quality:<reason>"`` when the sample should be
        skipped, otherwise ``""``.
    """
    marker = "__quality_invalid"
    for key in ("fileName", "sourceFileName", "filePath"):
        marker_source = Path(str(sample.get(key) or "")).stem.lower()
        if marker in marker_source:
            reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio"
            return f"invalid_audio_quality:{reason}"

    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        return ""
    quality = ext.get("audio_quality", {})
    if not isinstance(quality, dict):
        return ""
    if str(quality.get("quality_flag") or "").strip().lower() != "invalid":
        return ""
    skip_downstream = quality.get("skip_downstream", True)
    if isinstance(skip_downstream, str):
        skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"}
    if not skip_downstream:
        return ""
    reason = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{reason}"


def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped *sample* and blank its payload fields.

    The reason is stored under ``ext_params["audio_skip"][op_name]``.  A
    non-dict ``ext_params`` is preserved under ``"_raw"``; a non-dict
    ``audio_skip`` value is handled the same way (``None`` is simply replaced)
    instead of raising ``TypeError`` on item assignment.  The text, data,
    file-type and target-type fields are then reset to empty values.

    Returns:
        The same *sample* object, mutated in place.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}

    skips = ext.get("audio_skip")
    if not isinstance(skips, dict):
        # Keep an unexpected earlier payload rather than crashing or dropping it.
        skips = {} if skips is None else {"_raw": skips}
        ext["audio_skip"] = skips
    skips[op_name] = reason

    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
+language: 'python' +vendor: 'huawei' +raw_id: 'AudioSoundClassify' +version: '1.0.0' +types: + - 'annotation' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + backend: + name: '分类后端' + description: 'ast 为当前标准实现;panns 为旧版兼容实现。' + type: 'select' + defaultVal: 'ast' + required: true + options: + - label: 'AST' + value: 'ast' + - label: 'PANNs' + value: 'panns' + checkpoint: + name: '兼容模型路径' + description: '兼容旧参数。backend=ast 时建议使用 astCheckpoint;backend=panns 时建议使用 pannsCheckpoint。' + type: 'input' + defaultVal: '/models/AudioOperations/recog/audioset_10_10_0.4593.pth' + required: false + astCheckpoint: + name: 'AST 模型路径' + description: 'AST AudioSet checkpoint 路径。' + type: 'input' + defaultVal: '/models/AudioOperations/recog/audioset_10_10_0.4593.pth' + required: false + pannsCheckpoint: + name: 'PANNs 模型路径' + description: 'PANNs Cnn14_16k checkpoint 路径。' + type: 'input' + defaultVal: '/models/AudioOperations/panns/Cnn14_16k_mAP=0.438.pth' + required: false + macroMap: + name: 'PANNs 大类映射表' + description: 'PANNs 后端使用的 AudioSet label 到业务大类 TSV;留空使用算子内置映射。' + type: 'input' + defaultVal: '' + required: false + astMacroMap: + name: 'AST 大类映射表' + description: 'AST 后端使用的粗类映射 JSON;留空使用算子内置映射。' + type: 'input' + defaultVal: '' + required: false + labelsCsv: + name: 'AudioSet 标签表' + description: 'AST 后端使用的 class_labels_indices.csv;留空使用算子内置标签表。' + type: 'input' + defaultVal: '' + required: false + device: + name: '设备' + description: '推理设备。' + type: 'select' + defaultVal: 'auto' + required: true + options: + - label: 'auto' + value: 'auto' + - label: 'cpu' + value: 'cpu' + - label: 'npu' + value: 'npu' + - label: 'cuda' + value: 'cuda' + topK: + name: '细类 TopK' + type: 'inputNumber' + description: '输出 AudioSet 细类数量。' + defaultVal: 10 + min: 1 + max: 50 + step: 1 + humanSpeechThreshold: + name: '人声优先阈值' + type: 'slider' + description: 'top-k 聚合后 HumanSpeech 分数超过该阈值时优先判为人声。' + defaultVal: 0.2 + min: 0 + max: 1 + step: 0.01 + segmentSeconds: + name: 'AST 分段秒数' + 
type: 'inputNumber' + description: 'AST 后端滑窗分段长度。' + defaultVal: 10.24 + min: 1 + max: 120 + step: 0.01 + hopSeconds: + name: 'AST 步长秒数' + type: 'inputNumber' + description: 'AST 后端滑窗步长。' + defaultVal: 5.12 + min: 0.1 + max: 120 + step: 0.01 + macroAgg: + name: 'AST 大类聚合' + description: 'AST 后端将细类聚合成大类的策略。' + type: 'select' + defaultVal: 'max' + required: true + options: + - label: 'max' + value: 'max' + - label: 'sum' + value: 'sum' + keepAudio: + name: '中间节点保留音频' + type: 'switch' + description: '作为中间节点时是否保留音频字节给下游算子。' + defaultVal: 'true' + required: false + checkedLabel: '保留' + unCheckedLabel: '不保留' +runtime: + memory: 4294967296 + cpu: 1.0 + gpu: 0 + npu: 0 + storage: 20MB +metrics: + - name: '分类类别' + metric: 'AST 默认 12 个业务大类;PANNs 兼容模式 15 个业务大类;均支持 AudioSet 527 细类 top-k' +release: + - '首次发布,支持 AST 标准分类与 PANNs 兼容分类' diff --git a/runtime/ops/mapper/audio_sound_classify/models/panns/classes_macro_draft.tsv b/runtime/ops/mapper/audio_sound_classify/models/panns/classes_macro_draft.tsv new file mode 100644 index 00000000..0dfde677 --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/models/panns/classes_macro_draft.tsv @@ -0,0 +1,528 @@ +label macro_class +Speech HumanSpeech +Male speech, man speaking HumanSpeech +Female speech, woman speaking HumanSpeech +Child speech, kid speaking HumanSpeech +Conversation HumanSpeech +Narration, monologue HumanSpeech +Babbling HumanSpeech +Speech synthesizer HumanSpeech +Shout HumanSpeech +Bellow AnimalSounds +Whoop CrowdAmbience +Yell HumanSpeech +Battle cry CrowdAmbience +Children shouting HumanSpeech +Screaming HumanSpeech +Whispering HumanSpeech +Laughter HumanBodySound +Baby laughter HumanBodySound +Giggle HumanBodySound +Snicker HumanBodySound +Belly laugh HumanBodySound +Chuckle, chortle HumanBodySound +Crying, sobbing HumanBodySound +Baby cry, infant cry HumanBodySound +Whimper HumanBodySound +Wail, moan HumanBodySound +Sigh HumanBodySound +Singing SingingVocal +Choir SingingVocal +Yodeling SingingVocal +Chant 
SingingVocal +Mantra HumanSpeech +Male singing SingingVocal +Female singing SingingVocal +Child singing SingingVocal +Synthetic singing SingingVocal +Rapping SingingVocal +Humming SingingVocal +Groan HumanBodySound +Grunt HumanBodySound +Whistling SingingVocal +Breathing HumanBodySound +Wheeze HumanBodySound +Snoring HumanBodySound +Gasp HumanBodySound +Pant HumanBodySound +Snort HumanBodySound +Cough HumanBodySound +Throat clearing HumanBodySound +Sneeze HumanBodySound +Sniff HumanBodySound +Run HumanBodySound +Shuffle HumanBodySound +Walk, footsteps HumanBodySound +Chewing, mastication HumanBodySound +Biting HumanBodySound +Gargling HumanBodySound +Stomach rumble HumanBodySound +Burping, eructation HumanBodySound +Hiccup HumanBodySound +Fart HumanBodySound +Hands HumanBodySound +Finger snapping HumanBodySound +Clapping HumanBodySound +Heart sounds, heartbeat HumanBodySound +Heart murmur HumanBodySound +Cheering CrowdAmbience +Applause CrowdAmbience +Chatter HumanSpeech +Crowd CrowdAmbience +Hubbub, speech noise, speech babble CrowdAmbience +Children playing CrowdAmbience +Animal AnimalSounds +Domestic animals, pets AnimalSounds +Dog AnimalSounds +Bark AnimalSounds +Yip AnimalSounds +Howl AnimalSounds +Bow-wow AnimalSounds +Growling AnimalSounds +Whimper (dog) AnimalSounds +Cat AnimalSounds +Purr AnimalSounds +Meow AnimalSounds +Hiss AnimalSounds +Caterwaul AnimalSounds +Livestock, farm animals, working animals AnimalSounds +Horse AnimalSounds +Clip-clop AnimalSounds +Neigh, whinny AnimalSounds +Cattle, bovinae AnimalSounds +Moo AnimalSounds +Cowbell AnimalSounds +Pig AnimalSounds +Oink AnimalSounds +Goat AnimalSounds +Bleat AnimalSounds +Sheep AnimalSounds +Fowl AnimalSounds +Chicken, rooster AnimalSounds +Cluck AnimalSounds +Crowing, cock-a-doodle-doo AnimalSounds +Turkey AnimalSounds +Gobble AnimalSounds +Duck AnimalSounds +Quack AnimalSounds +Goose AnimalSounds +Honk AnimalSounds +Wild animals AnimalSounds +Roaring cats (lions, tigers) AnimalSounds +Roar 
AnimalSounds +Bird AnimalSounds +Bird vocalization, bird call, bird song AnimalSounds +Chirp, tweet AnimalSounds +Squawk AnimalSounds +Pigeon, dove AnimalSounds +Coo AnimalSounds +Crow AnimalSounds +Caw AnimalSounds +Owl AnimalSounds +Hoot AnimalSounds +Bird flight, flapping wings AnimalSounds +Canidae, dogs, wolves AnimalSounds +Rodents, rats, mice AnimalSounds +Mouse AnimalSounds +Patter AnimalSounds +Insect AnimalSounds +Cricket AnimalSounds +Mosquito AnimalSounds +Fly, housefly AnimalSounds +Buzz AnimalSounds +Bee, wasp, etc. AnimalSounds +Frog AnimalSounds +Croak AnimalSounds +Snake AnimalSounds +Rattle NoiseArtifact +Whale vocalization AnimalSounds +Music RecordedMusic +Musical instrument MusicalInstrument +Plucked string instrument MusicalInstrument +Guitar MusicalInstrument +Electric guitar MusicalInstrument +Bass guitar MusicalInstrument +Acoustic guitar MusicalInstrument +Steel guitar, slide guitar MusicalInstrument +Tapping (guitar technique) MusicalInstrument +Strum MusicalInstrument +Banjo MusicalInstrument +Sitar MusicalInstrument +Mandolin MusicalInstrument +Zither MusicalInstrument +Ukulele MusicalInstrument +Keyboard (musical) MusicalInstrument +Piano MusicalInstrument +Electric piano MusicalInstrument +Organ MusicalInstrument +Electronic organ MusicalInstrument +Hammond organ MusicalInstrument +Synthesizer MusicalInstrument +Sampler MusicalInstrument +Harpsichord MusicalInstrument +Percussion MusicalInstrument +Drum kit MusicalInstrument +Drum machine MusicalInstrument +Drum MusicalInstrument +Snare drum MusicalInstrument +Rimshot MusicalInstrument +Drum roll MusicalInstrument +Bass drum MusicalInstrument +Timpani MusicalInstrument +Tabla MusicalInstrument +Cymbal MusicalInstrument +Hi-hat MusicalInstrument +Wood block MusicalInstrument +Tambourine MusicalInstrument +Rattle (instrument) MusicalInstrument +Maraca MusicalInstrument +Gong MusicalInstrument +Tubular bells MusicalInstrument +Mallet percussion MusicalInstrument +Marimba, xylophone 
MusicalInstrument +Glockenspiel MusicalInstrument +Vibraphone MusicalInstrument +Steelpan MusicalInstrument +Orchestra MusicalInstrument +Brass instrument MusicalInstrument +French horn MusicalInstrument +Trumpet MusicalInstrument +Trombone MusicalInstrument +Bowed string instrument MusicalInstrument +String section MusicalInstrument +Violin, fiddle MusicalInstrument +Pizzicato MusicalInstrument +Cello MusicalInstrument +Double bass MusicalInstrument +Wind instrument, woodwind instrument MusicalInstrument +Flute MusicalInstrument +Saxophone MusicalInstrument +Clarinet MusicalInstrument +Harp MusicalInstrument +Bell MusicalInstrument +Church bell MusicalInstrument +Jingle bell MusicalInstrument +Bicycle bell Transportation +Tuning fork MusicalInstrument +Chime MusicalInstrument +Wind chime MusicalInstrument +Change ringing (campanology) MusicalInstrument +Harmonica MusicalInstrument +Accordion MusicalInstrument +Bagpipes MusicalInstrument +Didgeridoo MusicalInstrument +Shofar MusicalInstrument +Theremin MusicalInstrument +Singing bowl MusicalInstrument +Scratching (performance technique) MusicalInstrument +Pop music RecordedMusic +Hip hop music RecordedMusic +Beatboxing SingingVocal +Rock music RecordedMusic +Heavy metal RecordedMusic +Punk rock RecordedMusic +Grunge HumanBodySound +Progressive rock RecordedMusic +Rock and roll RecordedMusic +Psychedelic rock RecordedMusic +Rhythm and blues RecordedMusic +Soul music RecordedMusic +Reggae RecordedMusic +Country RecordedMusic +Swing music RecordedMusic +Bluegrass RecordedMusic +Funk RecordedMusic +Folk music RecordedMusic +Middle Eastern music RecordedMusic +Jazz RecordedMusic +Disco RecordedMusic +Classical music RecordedMusic +Opera RecordedMusic +Electronic music RecordedMusic +House music RecordedMusic +Techno RecordedMusic +Dubstep RecordedMusic +Drum and bass RecordedMusic +Electronica RecordedMusic +Electronic dance music RecordedMusic +Ambient music RecordedMusic +Trance music RecordedMusic +Music of Latin 
America RecordedMusic +Salsa music RecordedMusic +Flamenco RecordedMusic +Blues RecordedMusic +Music for children RecordedMusic +New-age music RecordedMusic +Vocal music RecordedMusic +A capella RecordedMusic +Music of Africa RecordedMusic +Afrobeat RecordedMusic +Christian music RecordedMusic +Gospel music RecordedMusic +Music of Asia RecordedMusic +Carnatic music RecordedMusic +Music of Bollywood RecordedMusic +Ska RecordedMusic +Traditional music RecordedMusic +Independent music RecordedMusic +Song RecordedMusic +Background music RecordedMusic +Theme music RecordedMusic +Jingle (music) RecordedMusic +Soundtrack music RecordedMusic +Lullaby RecordedMusic +Video game music RecordedMusic +Christmas music RecordedMusic +Dance music RecordedMusic +Wedding music RecordedMusic +Happy music RecordedMusic +Funny music RecordedMusic +Sad music RecordedMusic +Tender music RecordedMusic +Exciting music RecordedMusic +Angry music RecordedMusic +Scary music RecordedMusic +Wind NatureWaterFire +Rustling leaves NatureWaterFire +Wind noise (microphone) NatureWaterFire +Thunderstorm NatureWaterFire +Thunder NatureWaterFire +Water NatureWaterFire +Rain NatureWaterFire +Raindrop NatureWaterFire +Rain on surface NatureWaterFire +Stream NatureWaterFire +Waterfall NatureWaterFire +Ocean NatureWaterFire +Waves, surf NatureWaterFire +Steam NatureWaterFire +Gurgling NatureWaterFire +Fire NatureWaterFire +Crackle NatureWaterFire +Vehicle Transportation +Boat, Water vehicle Transportation +Sailboat, sailing ship Transportation +Rowboat, canoe, kayak Transportation +Motorboat, speedboat Transportation +Ship Transportation +Motor vehicle (road) Transportation +Car Transportation +Vehicle horn, car horn, honking Transportation +Toot Transportation +Car alarm AlarmSignal +Power windows, electric windows Transportation +Skidding Transportation +Tire squeal Transportation +Car passing by Transportation +Race car, auto racing Transportation +Truck Transportation +Air brake Transportation +Air 
horn, truck horn Transportation +Reversing beeps Transportation +Ice cream truck, ice cream van Transportation +Bus Transportation +Emergency vehicle Transportation +Police car (siren) Transportation +Ambulance (siren) Transportation +Fire engine, fire truck (siren) Transportation +Motorcycle Transportation +Traffic noise, roadway noise Transportation +Rail transport Transportation +Train Transportation +Train whistle Transportation +Train horn Transportation +Railroad car, train wagon Transportation +Train wheels squealing Transportation +Subway, metro, underground Transportation +Aircraft Transportation +Aircraft engine Transportation +Jet engine Transportation +Propeller, airscrew Transportation +Helicopter Transportation +Fixed-wing aircraft, airplane Transportation +Bicycle Transportation +Skateboard Transportation +Engine MachineAppliance +Light engine (high frequency) MachineAppliance +Dental drill, dentist's drill MachineAppliance +Lawn mower MachineAppliance +Chainsaw MachineAppliance +Medium engine (mid frequency) MachineAppliance +Heavy engine (low frequency) MachineAppliance +Engine knocking MachineAppliance +Engine starting MachineAppliance +Idling MachineAppliance +Accelerating, revving, vroom Transportation +Door ToolImpact +Doorbell AlarmSignal +Ding-dong AlarmSignal +Sliding door ToolImpact +Slam ToolImpact +Knock ToolImpact +Tap ToolImpact +Squeak ToolImpact +Cupboard open or close ToolImpact +Drawer open or close ToolImpact +Dishes, pots, and pans ToolImpact +Cutlery, silverware ToolImpact +Chopping (food) ToolImpact +Frying (food) ToolImpact +Microwave oven MachineAppliance +Blender MachineAppliance +Water tap, faucet ToolImpact +Sink (filling or washing) NatureWaterFire +Bathtub (filling or washing) ToolImpact +Hair dryer MachineAppliance +Toilet flush ToolImpact +Toothbrush MachineAppliance +Electric toothbrush MachineAppliance +Vacuum cleaner MachineAppliance +Zipper (clothing) ToolImpact +Keys jangling ToolImpact +Coin (dropping) ToolImpact 
+Scissors ToolImpact +Electric shaver, electric razor MachineAppliance +Shuffling cards ToolImpact +Typing ToolImpact +Typewriter ToolImpact +Computer keyboard ToolImpact +Writing ToolImpact +Alarm AlarmSignal +Telephone AlarmSignal +Telephone bell ringing AlarmSignal +Ringtone AlarmSignal +Telephone dialing, DTMF AlarmSignal +Dial tone AlarmSignal +Busy signal AlarmSignal +Alarm clock AlarmSignal +Siren AlarmSignal +Civil defense siren AlarmSignal +Buzzer AlarmSignal +Smoke detector, smoke alarm AlarmSignal +Fire alarm AlarmSignal +Foghorn Transportation +Whistle AlarmSignal +Steam whistle NatureWaterFire +Mechanisms MachineAppliance +Ratchet, pawl MachineAppliance +Clock MachineAppliance +Tick MachineAppliance +Tick-tock MachineAppliance +Gears MachineAppliance +Pulleys MachineAppliance +Sewing machine MachineAppliance +Mechanical fan MachineAppliance +Air conditioning MachineAppliance +Cash register MachineAppliance +Printer MachineAppliance +Camera MachineAppliance +Single-lens reflex camera MachineAppliance +Tools ToolImpact +Hammer ToolImpact +Jackhammer ToolImpact +Sawing ToolImpact +Filing (rasp) ToolImpact +Sanding ToolImpact +Power tool ToolImpact +Drill ToolImpact +Explosion ExplosionWeapon +Gunshot, gunfire ExplosionWeapon +Machine gun ExplosionWeapon +Fusillade ExplosionWeapon +Artillery fire ExplosionWeapon +Cap gun ExplosionWeapon +Fireworks ExplosionWeapon +Firecracker ExplosionWeapon +Burst, pop ToolImpact +Eruption NatureWaterFire +Boom ToolImpact +Wood ToolImpact +Chop ToolImpact +Splinter ToolImpact +Crack ToolImpact +Glass ToolImpact +Chink, clink ToolImpact +Shatter ToolImpact +Liquid NatureWaterFire +Splash, splatter NatureWaterFire +Slosh NatureWaterFire +Squish NatureWaterFire +Drip NatureWaterFire +Pour NatureWaterFire +Trickle, dribble NatureWaterFire +Gush NatureWaterFire +Fill (with liquid) NatureWaterFire +Spray NatureWaterFire +Pump (liquid) MachineAppliance +Stir NatureWaterFire +Boiling NatureWaterFire +Sonar Other +Arrow Other 
+Whoosh, swoosh, swish ToolImpact +Thump, thud ToolImpact +Thunk ToolImpact +Electronic tuner MachineAppliance +Effects unit MachineAppliance +Chorus effect MachineAppliance +Basketball bounce ToolImpact +Bang ToolImpact +Slap, smack ToolImpact +Whack, thwack ToolImpact +Smash, crash ToolImpact +Breaking ToolImpact +Bouncing ToolImpact +Whip ToolImpact +Flap ToolImpact +Scratch ToolImpact +Scrape ToolImpact +Rub ToolImpact +Roll ToolImpact +Crushing ToolImpact +Crumpling, crinkling ToolImpact +Tearing ToolImpact +Beep, bleep AlarmSignal +Ping ToolImpact +Ding ToolImpact +Clang ToolImpact +Squeal ToolImpact +Creak ToolImpact +Rustle ToolImpact +Whir ToolImpact +Clatter ToolImpact +Sizzle ToolImpact +Clicking ToolImpact +Clickety-clack ToolImpact +Rumble ToolImpact +Plop ToolImpact +Jingle, tinkle ToolImpact +Hum HumanSpeech +Zing ToolImpact +Boing ToolImpact +Crunch HumanBodySound +Silence Other +Sine wave NoiseArtifact +Harmonic NoiseArtifact +Chirp tone NoiseArtifact +Sound effect Other +Pulse NoiseArtifact +Inside, small room CrowdAmbience +Inside, large room or hall CrowdAmbience +Inside, public space CrowdAmbience +Outside, urban or manmade CrowdAmbience +Outside, rural or natural CrowdAmbience +Reverberation CrowdAmbience +Echo CrowdAmbience +Noise NoiseArtifact +Environmental noise NoiseArtifact +Static NoiseArtifact +Mains hum NoiseArtifact +Distortion NoiseArtifact +Sidetone NoiseArtifact +Cacophony CrowdAmbience +White noise NoiseArtifact +Pink noise NoiseArtifact +Throbbing NoiseArtifact +Vibration NoiseArtifact +Television CrowdAmbience +Radio CrowdAmbience +Field recording CrowdAmbience diff --git a/runtime/ops/mapper/audio_sound_classify/models/recog/audioset_macro_map_v1.json b/runtime/ops/mapper/audio_sound_classify/models/recog/audioset_macro_map_v1.json new file mode 100644 index 00000000..69bc665b --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/models/recog/audioset_macro_map_v1.json @@ -0,0 +1,133 @@ +{ + "HumanSpeech": [ + "Speech", 
+ "Male speech, man speaking", + "Female speech, woman speaking", + "Child speech, kid speaking", + "Conversation", + "Narration, monologue", + "Whispering", + "Shout", + "Yell", + "Screaming", + "Laughter", + "Crying, sobbing", + "Singing", + "Rapping", + "Humming", + "Breathing", + "Cough", + "Sneeze" + ], + "Music": [ + "Music", + "Musical instrument", + "Vocal music", + "Song", + "Background music", + "Electronic music", + "Rock music", + "Classical music", + "Jazz", + "Hip hop music", + "Techno", + "House music", + "Dance music" + ], + "Animal": [ + "Animal", + "Domestic animals, pets", + "Dog", + "Cat", + "Bird", + "Insect", + "Livestock, farm animals, working animals" + ], + "Vehicle": [ + "Vehicle", + "Car", + "Truck", + "Bus", + "Train", + "Aircraft", + "Motorcycle", + "Traffic noise, roadway noise", + "Vehicle horn, car horn, honking" + ], + "EngineMachinery": [ + "Engine", + "Idling", + "Accelerating, revving, vroom", + "Medium engine (mid frequency)", + "Heavy engine (low frequency)", + "Mechanical fan", + "Air conditioning", + "Vacuum cleaner", + "Tools", + "Power tool", + "Drill", + "Jackhammer" + ], + "AlarmSiren": [ + "Siren", + "Buzzer", + "Alarm", + "Car alarm", + "Fire alarm", + "Smoke detector, smoke alarm", + "Telephone bell ringing", + "Ringtone" + ], + "ImpactClatter": [ + "Clang", + "Clatter", + "Chink, clink", + "Ding", + "Bang", + "Smash, crash", + "Breaking", + "Door", + "Doorbell", + "Knock", + "Tap" + ], + "GunshotExplosion": [ + "Explosion", + "Gunshot, gunfire", + "Machine gun", + "Fireworks", + "Firecracker" + ], + "Crowd": [ + "Crowd", + "Chatter", + "Cheering", + "Applause", + "Hubbub, speech noise, speech babble", + "Cacophony" + ], + "WindWater": [ + "Wind", + "Wind noise (microphone)", + "Thunderstorm", + "Thunder", + "Water", + "Rain", + "Waves, surf", + "Stream", + "Waterfall" + ], + "Silence": [ + "Silence" + ], + "Noise": [ + "Noise", + "Environmental noise", + "Static", + "Mains hum", + "White noise", + "Pink noise", + 
"Distortion" + ] +} + diff --git a/runtime/ops/mapper/audio_sound_classify/models/recog/class_labels_indices.csv b/runtime/ops/mapper/audio_sound_classify/models/recog/class_labels_indices.csv new file mode 100644 index 00000000..3a2767e8 --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/models/recog/class_labels_indices.csv @@ -0,0 +1,528 @@ +index,mid,display_name +0,/m/09x0r,"Speech" +1,/m/05zppz,"Male speech, man speaking" +2,/m/02zsn,"Female speech, woman speaking" +3,/m/0ytgt,"Child speech, kid speaking" +4,/m/01h8n0,"Conversation" +5,/m/02qldy,"Narration, monologue" +6,/m/0261r1,"Babbling" +7,/m/0brhx,"Speech synthesizer" +8,/m/07p6fty,"Shout" +9,/m/07q4ntr,"Bellow" +10,/m/07rwj3x,"Whoop" +11,/m/07sr1lc,"Yell" +12,/m/04gy_2,"Battle cry" +13,/t/dd00135,"Children shouting" +14,/m/03qc9zr,"Screaming" +15,/m/02rtxlg,"Whispering" +16,/m/01j3sz,"Laughter" +17,/t/dd00001,"Baby laughter" +18,/m/07r660_,"Giggle" +19,/m/07s04w4,"Snicker" +20,/m/07sq110,"Belly laugh" +21,/m/07rgt08,"Chuckle, chortle" +22,/m/0463cq4,"Crying, sobbing" +23,/t/dd00002,"Baby cry, infant cry" +24,/m/07qz6j3,"Whimper" +25,/m/07qw_06,"Wail, moan" +26,/m/07plz5l,"Sigh" +27,/m/015lz1,"Singing" +28,/m/0l14jd,"Choir" +29,/m/01swy6,"Yodeling" +30,/m/02bk07,"Chant" +31,/m/01c194,"Mantra" +32,/t/dd00003,"Male singing" +33,/t/dd00004,"Female singing" +34,/t/dd00005,"Child singing" +35,/t/dd00006,"Synthetic singing" +36,/m/06bxc,"Rapping" +37,/m/02fxyj,"Humming" +38,/m/07s2xch,"Groan" +39,/m/07r4k75,"Grunt" +40,/m/01w250,"Whistling" +41,/m/0lyf6,"Breathing" +42,/m/07mzm6,"Wheeze" +43,/m/01d3sd,"Snoring" +44,/m/07s0dtb,"Gasp" +45,/m/07pyy8b,"Pant" +46,/m/07q0yl5,"Snort" +47,/m/01b_21,"Cough" +48,/m/0dl9sf8,"Throat clearing" +49,/m/01hsr_,"Sneeze" +50,/m/07ppn3j,"Sniff" +51,/m/06h7j,"Run" +52,/m/07qv_x_,"Shuffle" +53,/m/07pbtc8,"Walk, footsteps" +54,/m/03cczk,"Chewing, mastication" +55,/m/07pdhp0,"Biting" +56,/m/0939n_,"Gargling" +57,/m/01g90h,"Stomach rumble" +58,/m/03q5_w,"Burping, 
eructation" +59,/m/02p3nc,"Hiccup" +60,/m/02_nn,"Fart" +61,/m/0k65p,"Hands" +62,/m/025_jnm,"Finger snapping" +63,/m/0l15bq,"Clapping" +64,/m/01jg02,"Heart sounds, heartbeat" +65,/m/01jg1z,"Heart murmur" +66,/m/053hz1,"Cheering" +67,/m/028ght,"Applause" +68,/m/07rkbfh,"Chatter" +69,/m/03qtwd,"Crowd" +70,/m/07qfr4h,"Hubbub, speech noise, speech babble" +71,/t/dd00013,"Children playing" +72,/m/0jbk,"Animal" +73,/m/068hy,"Domestic animals, pets" +74,/m/0bt9lr,"Dog" +75,/m/05tny_,"Bark" +76,/m/07r_k2n,"Yip" +77,/m/07qf0zm,"Howl" +78,/m/07rc7d9,"Bow-wow" +79,/m/0ghcn6,"Growling" +80,/t/dd00136,"Whimper (dog)" +81,/m/01yrx,"Cat" +82,/m/02yds9,"Purr" +83,/m/07qrkrw,"Meow" +84,/m/07rjwbb,"Hiss" +85,/m/07r81j2,"Caterwaul" +86,/m/0ch8v,"Livestock, farm animals, working animals" +87,/m/03k3r,"Horse" +88,/m/07rv9rh,"Clip-clop" +89,/m/07q5rw0,"Neigh, whinny" +90,/m/01xq0k1,"Cattle, bovinae" +91,/m/07rpkh9,"Moo" +92,/m/0239kh,"Cowbell" +93,/m/068zj,"Pig" +94,/t/dd00018,"Oink" +95,/m/03fwl,"Goat" +96,/m/07q0h5t,"Bleat" +97,/m/07bgp,"Sheep" +98,/m/025rv6n,"Fowl" +99,/m/09b5t,"Chicken, rooster" +100,/m/07st89h,"Cluck" +101,/m/07qn5dc,"Crowing, cock-a-doodle-doo" +102,/m/01rd7k,"Turkey" +103,/m/07svc2k,"Gobble" +104,/m/09ddx,"Duck" +105,/m/07qdb04,"Quack" +106,/m/0dbvp,"Goose" +107,/m/07qwf61,"Honk" +108,/m/01280g,"Wild animals" +109,/m/0cdnk,"Roaring cats (lions, tigers)" +110,/m/04cvmfc,"Roar" +111,/m/015p6,"Bird" +112,/m/020bb7,"Bird vocalization, bird call, bird song" +113,/m/07pggtn,"Chirp, tweet" +114,/m/07sx8x_,"Squawk" +115,/m/0h0rv,"Pigeon, dove" +116,/m/07r_25d,"Coo" +117,/m/04s8yn,"Crow" +118,/m/07r5c2p,"Caw" +119,/m/09d5_,"Owl" +120,/m/07r_80w,"Hoot" +121,/m/05_wcq,"Bird flight, flapping wings" +122,/m/01z5f,"Canidae, dogs, wolves" +123,/m/06hps,"Rodents, rats, mice" +124,/m/04rmv,"Mouse" +125,/m/07r4gkf,"Patter" +126,/m/03vt0,"Insect" +127,/m/09xqv,"Cricket" +128,/m/09f96,"Mosquito" +129,/m/0h2mp,"Fly, housefly" +130,/m/07pjwq1,"Buzz" +131,/m/01h3n,"Bee, wasp, etc." 
+132,/m/09ld4,"Frog" +133,/m/07st88b,"Croak" +134,/m/078jl,"Snake" +135,/m/07qn4z3,"Rattle" +136,/m/032n05,"Whale vocalization" +137,/m/04rlf,"Music" +138,/m/04szw,"Musical instrument" +139,/m/0fx80y,"Plucked string instrument" +140,/m/0342h,"Guitar" +141,/m/02sgy,"Electric guitar" +142,/m/018vs,"Bass guitar" +143,/m/042v_gx,"Acoustic guitar" +144,/m/06w87,"Steel guitar, slide guitar" +145,/m/01glhc,"Tapping (guitar technique)" +146,/m/07s0s5r,"Strum" +147,/m/018j2,"Banjo" +148,/m/0jtg0,"Sitar" +149,/m/04rzd,"Mandolin" +150,/m/01bns_,"Zither" +151,/m/07xzm,"Ukulele" +152,/m/05148p4,"Keyboard (musical)" +153,/m/05r5c,"Piano" +154,/m/01s0ps,"Electric piano" +155,/m/013y1f,"Organ" +156,/m/03xq_f,"Electronic organ" +157,/m/03gvt,"Hammond organ" +158,/m/0l14qv,"Synthesizer" +159,/m/01v1d8,"Sampler" +160,/m/03q5t,"Harpsichord" +161,/m/0l14md,"Percussion" +162,/m/02hnl,"Drum kit" +163,/m/0cfdd,"Drum machine" +164,/m/026t6,"Drum" +165,/m/06rvn,"Snare drum" +166,/m/03t3fj,"Rimshot" +167,/m/02k_mr,"Drum roll" +168,/m/0bm02,"Bass drum" +169,/m/011k_j,"Timpani" +170,/m/01p970,"Tabla" +171,/m/01qbl,"Cymbal" +172,/m/03qtq,"Hi-hat" +173,/m/01sm1g,"Wood block" +174,/m/07brj,"Tambourine" +175,/m/05r5wn,"Rattle (instrument)" +176,/m/0xzly,"Maraca" +177,/m/0mbct,"Gong" +178,/m/016622,"Tubular bells" +179,/m/0j45pbj,"Mallet percussion" +180,/m/0dwsp,"Marimba, xylophone" +181,/m/0dwtp,"Glockenspiel" +182,/m/0dwt5,"Vibraphone" +183,/m/0l156b,"Steelpan" +184,/m/05pd6,"Orchestra" +185,/m/01kcd,"Brass instrument" +186,/m/0319l,"French horn" +187,/m/07gql,"Trumpet" +188,/m/07c6l,"Trombone" +189,/m/0l14_3,"Bowed string instrument" +190,/m/02qmj0d,"String section" +191,/m/07y_7,"Violin, fiddle" +192,/m/0d8_n,"Pizzicato" +193,/m/01xqw,"Cello" +194,/m/02fsn,"Double bass" +195,/m/085jw,"Wind instrument, woodwind instrument" +196,/m/0l14j_,"Flute" +197,/m/06ncr,"Saxophone" +198,/m/01wy6,"Clarinet" +199,/m/03m5k,"Harp" +200,/m/0395lw,"Bell" +201,/m/03w41f,"Church bell" +202,/m/027m70_,"Jingle 
bell" +203,/m/0gy1t2s,"Bicycle bell" +204,/m/07n_g,"Tuning fork" +205,/m/0f8s22,"Chime" +206,/m/026fgl,"Wind chime" +207,/m/0150b9,"Change ringing (campanology)" +208,/m/03qjg,"Harmonica" +209,/m/0mkg,"Accordion" +210,/m/0192l,"Bagpipes" +211,/m/02bxd,"Didgeridoo" +212,/m/0l14l2,"Shofar" +213,/m/07kc_,"Theremin" +214,/m/0l14t7,"Singing bowl" +215,/m/01hgjl,"Scratching (performance technique)" +216,/m/064t9,"Pop music" +217,/m/0glt670,"Hip hop music" +218,/m/02cz_7,"Beatboxing" +219,/m/06by7,"Rock music" +220,/m/03lty,"Heavy metal" +221,/m/05r6t,"Punk rock" +222,/m/0dls3,"Grunge" +223,/m/0dl5d,"Progressive rock" +224,/m/07sbbz2,"Rock and roll" +225,/m/05w3f,"Psychedelic rock" +226,/m/06j6l,"Rhythm and blues" +227,/m/0gywn,"Soul music" +228,/m/06cqb,"Reggae" +229,/m/01lyv,"Country" +230,/m/015y_n,"Swing music" +231,/m/0gg8l,"Bluegrass" +232,/m/02x8m,"Funk" +233,/m/02w4v,"Folk music" +234,/m/06j64v,"Middle Eastern music" +235,/m/03_d0,"Jazz" +236,/m/026z9,"Disco" +237,/m/0ggq0m,"Classical music" +238,/m/05lls,"Opera" +239,/m/02lkt,"Electronic music" +240,/m/03mb9,"House music" +241,/m/07gxw,"Techno" +242,/m/07s72n,"Dubstep" +243,/m/0283d,"Drum and bass" +244,/m/0m0jc,"Electronica" +245,/m/08cyft,"Electronic dance music" +246,/m/0fd3y,"Ambient music" +247,/m/07lnk,"Trance music" +248,/m/0g293,"Music of Latin America" +249,/m/0ln16,"Salsa music" +250,/m/0326g,"Flamenco" +251,/m/0155w,"Blues" +252,/m/05fw6t,"Music for children" +253,/m/02v2lh,"New-age music" +254,/m/0y4f8,"Vocal music" +255,/m/0z9c,"A capella" +256,/m/0164x2,"Music of Africa" +257,/m/0145m,"Afrobeat" +258,/m/02mscn,"Christian music" +259,/m/016cjb,"Gospel music" +260,/m/028sqc,"Music of Asia" +261,/m/015vgc,"Carnatic music" +262,/m/0dq0md,"Music of Bollywood" +263,/m/06rqw,"Ska" +264,/m/02p0sh1,"Traditional music" +265,/m/05rwpb,"Independent music" +266,/m/074ft,"Song" +267,/m/025td0t,"Background music" +268,/m/02cjck,"Theme music" +269,/m/03r5q_,"Jingle (music)" +270,/m/0l14gg,"Soundtrack music" 
+271,/m/07pkxdp,"Lullaby" +272,/m/01z7dr,"Video game music" +273,/m/0140xf,"Christmas music" +274,/m/0ggx5q,"Dance music" +275,/m/04wptg,"Wedding music" +276,/t/dd00031,"Happy music" +277,/t/dd00032,"Funny music" +278,/t/dd00033,"Sad music" +279,/t/dd00034,"Tender music" +280,/t/dd00035,"Exciting music" +281,/t/dd00036,"Angry music" +282,/t/dd00037,"Scary music" +283,/m/03m9d0z,"Wind" +284,/m/09t49,"Rustling leaves" +285,/t/dd00092,"Wind noise (microphone)" +286,/m/0jb2l,"Thunderstorm" +287,/m/0ngt1,"Thunder" +288,/m/0838f,"Water" +289,/m/06mb1,"Rain" +290,/m/07r10fb,"Raindrop" +291,/t/dd00038,"Rain on surface" +292,/m/0j6m2,"Stream" +293,/m/0j2kx,"Waterfall" +294,/m/05kq4,"Ocean" +295,/m/034srq,"Waves, surf" +296,/m/06wzb,"Steam" +297,/m/07swgks,"Gurgling" +298,/m/02_41,"Fire" +299,/m/07pzfmf,"Crackle" +300,/m/07yv9,"Vehicle" +301,/m/019jd,"Boat, Water vehicle" +302,/m/0hsrw,"Sailboat, sailing ship" +303,/m/056ks2,"Rowboat, canoe, kayak" +304,/m/02rlv9,"Motorboat, speedboat" +305,/m/06q74,"Ship" +306,/m/012f08,"Motor vehicle (road)" +307,/m/0k4j,"Car" +308,/m/0912c9,"Vehicle horn, car horn, honking" +309,/m/07qv_d5,"Toot" +310,/m/02mfyn,"Car alarm" +311,/m/04gxbd,"Power windows, electric windows" +312,/m/07rknqz,"Skidding" +313,/m/0h9mv,"Tire squeal" +314,/t/dd00134,"Car passing by" +315,/m/0ltv,"Race car, auto racing" +316,/m/07r04,"Truck" +317,/m/0gvgw0,"Air brake" +318,/m/05x_td,"Air horn, truck horn" +319,/m/02rhddq,"Reversing beeps" +320,/m/03cl9h,"Ice cream truck, ice cream van" +321,/m/01bjv,"Bus" +322,/m/03j1ly,"Emergency vehicle" +323,/m/04qvtq,"Police car (siren)" +324,/m/012n7d,"Ambulance (siren)" +325,/m/012ndj,"Fire engine, fire truck (siren)" +326,/m/04_sv,"Motorcycle" +327,/m/0btp2,"Traffic noise, roadway noise" +328,/m/06d_3,"Rail transport" +329,/m/07jdr,"Train" +330,/m/04zmvq,"Train whistle" +331,/m/0284vy3,"Train horn" +332,/m/01g50p,"Railroad car, train wagon" +333,/t/dd00048,"Train wheels squealing" +334,/m/0195fx,"Subway, metro, underground" 
+335,/m/0k5j,"Aircraft" +336,/m/014yck,"Aircraft engine" +337,/m/04229,"Jet engine" +338,/m/02l6bg,"Propeller, airscrew" +339,/m/09ct_,"Helicopter" +340,/m/0cmf2,"Fixed-wing aircraft, airplane" +341,/m/0199g,"Bicycle" +342,/m/06_fw,"Skateboard" +343,/m/02mk9,"Engine" +344,/t/dd00065,"Light engine (high frequency)" +345,/m/08j51y,"Dental drill, dentist's drill" +346,/m/01yg9g,"Lawn mower" +347,/m/01j4z9,"Chainsaw" +348,/t/dd00066,"Medium engine (mid frequency)" +349,/t/dd00067,"Heavy engine (low frequency)" +350,/m/01h82_,"Engine knocking" +351,/t/dd00130,"Engine starting" +352,/m/07pb8fc,"Idling" +353,/m/07q2z82,"Accelerating, revving, vroom" +354,/m/02dgv,"Door" +355,/m/03wwcy,"Doorbell" +356,/m/07r67yg,"Ding-dong" +357,/m/02y_763,"Sliding door" +358,/m/07rjzl8,"Slam" +359,/m/07r4wb8,"Knock" +360,/m/07qcpgn,"Tap" +361,/m/07q6cd_,"Squeak" +362,/m/0642b4,"Cupboard open or close" +363,/m/0fqfqc,"Drawer open or close" +364,/m/04brg2,"Dishes, pots, and pans" +365,/m/023pjk,"Cutlery, silverware" +366,/m/07pn_8q,"Chopping (food)" +367,/m/0dxrf,"Frying (food)" +368,/m/0fx9l,"Microwave oven" +369,/m/02pjr4,"Blender" +370,/m/02jz0l,"Water tap, faucet" +371,/m/0130jx,"Sink (filling or washing)" +372,/m/03dnzn,"Bathtub (filling or washing)" +373,/m/03wvsk,"Hair dryer" +374,/m/01jt3m,"Toilet flush" +375,/m/012xff,"Toothbrush" +376,/m/04fgwm,"Electric toothbrush" +377,/m/0d31p,"Vacuum cleaner" +378,/m/01s0vc,"Zipper (clothing)" +379,/m/03v3yw,"Keys jangling" +380,/m/0242l,"Coin (dropping)" +381,/m/01lsmm,"Scissors" +382,/m/02g901,"Electric shaver, electric razor" +383,/m/05rj2,"Shuffling cards" +384,/m/0316dw,"Typing" +385,/m/0c2wf,"Typewriter" +386,/m/01m2v,"Computer keyboard" +387,/m/081rb,"Writing" +388,/m/07pp_mv,"Alarm" +389,/m/07cx4,"Telephone" +390,/m/07pp8cl,"Telephone bell ringing" +391,/m/01hnzm,"Ringtone" +392,/m/02c8p,"Telephone dialing, DTMF" +393,/m/015jpf,"Dial tone" +394,/m/01z47d,"Busy signal" +395,/m/046dlr,"Alarm clock" +396,/m/03kmc9,"Siren" 
+397,/m/0dgbq,"Civil defense siren" +398,/m/030rvx,"Buzzer" +399,/m/01y3hg,"Smoke detector, smoke alarm" +400,/m/0c3f7m,"Fire alarm" +401,/m/04fq5q,"Foghorn" +402,/m/0l156k,"Whistle" +403,/m/06hck5,"Steam whistle" +404,/t/dd00077,"Mechanisms" +405,/m/02bm9n,"Ratchet, pawl" +406,/m/01x3z,"Clock" +407,/m/07qjznt,"Tick" +408,/m/07qjznl,"Tick-tock" +409,/m/0l7xg,"Gears" +410,/m/05zc1,"Pulleys" +411,/m/0llzx,"Sewing machine" +412,/m/02x984l,"Mechanical fan" +413,/m/025wky1,"Air conditioning" +414,/m/024dl,"Cash register" +415,/m/01m4t,"Printer" +416,/m/0dv5r,"Camera" +417,/m/07bjf,"Single-lens reflex camera" +418,/m/07k1x,"Tools" +419,/m/03l9g,"Hammer" +420,/m/03p19w,"Jackhammer" +421,/m/01b82r,"Sawing" +422,/m/02p01q,"Filing (rasp)" +423,/m/023vsd,"Sanding" +424,/m/0_ksk,"Power tool" +425,/m/01d380,"Drill" +426,/m/014zdl,"Explosion" +427,/m/032s66,"Gunshot, gunfire" +428,/m/04zjc,"Machine gun" +429,/m/02z32qm,"Fusillade" +430,/m/0_1c,"Artillery fire" +431,/m/073cg4,"Cap gun" +432,/m/0g6b5,"Fireworks" +433,/g/122z_qxw,"Firecracker" +434,/m/07qsvvw,"Burst, pop" +435,/m/07pxg6y,"Eruption" +436,/m/07qqyl4,"Boom" +437,/m/083vt,"Wood" +438,/m/07pczhz,"Chop" +439,/m/07pl1bw,"Splinter" +440,/m/07qs1cx,"Crack" +441,/m/039jq,"Glass" +442,/m/07q7njn,"Chink, clink" +443,/m/07rn7sz,"Shatter" +444,/m/04k94,"Liquid" +445,/m/07rrlb6,"Splash, splatter" +446,/m/07p6mqd,"Slosh" +447,/m/07qlwh6,"Squish" +448,/m/07r5v4s,"Drip" +449,/m/07prgkl,"Pour" +450,/m/07pqc89,"Trickle, dribble" +451,/t/dd00088,"Gush" +452,/m/07p7b8y,"Fill (with liquid)" +453,/m/07qlf79,"Spray" +454,/m/07ptzwd,"Pump (liquid)" +455,/m/07ptfmf,"Stir" +456,/m/0dv3j,"Boiling" +457,/m/0790c,"Sonar" +458,/m/0dl83,"Arrow" +459,/m/07rqsjt,"Whoosh, swoosh, swish" +460,/m/07qnq_y,"Thump, thud" +461,/m/07rrh0c,"Thunk" +462,/m/0b_fwt,"Electronic tuner" +463,/m/02rr_,"Effects unit" +464,/m/07m2kt,"Chorus effect" +465,/m/018w8,"Basketball bounce" +466,/m/07pws3f,"Bang" +467,/m/07ryjzk,"Slap, smack" +468,/m/07rdhzs,"Whack, thwack" 
+469,/m/07pjjrj,"Smash, crash" +470,/m/07pc8lb,"Breaking" +471,/m/07pqn27,"Bouncing" +472,/m/07rbp7_,"Whip" +473,/m/07pyf11,"Flap" +474,/m/07qb_dv,"Scratch" +475,/m/07qv4k0,"Scrape" +476,/m/07pdjhy,"Rub" +477,/m/07s8j8t,"Roll" +478,/m/07plct2,"Crushing" +479,/t/dd00112,"Crumpling, crinkling" +480,/m/07qcx4z,"Tearing" +481,/m/02fs_r,"Beep, bleep" +482,/m/07qwdck,"Ping" +483,/m/07phxs1,"Ding" +484,/m/07rv4dm,"Clang" +485,/m/07s02z0,"Squeal" +486,/m/07qh7jl,"Creak" +487,/m/07qwyj0,"Rustle" +488,/m/07s34ls,"Whir" +489,/m/07qmpdm,"Clatter" +490,/m/07p9k1k,"Sizzle" +491,/m/07qc9xj,"Clicking" +492,/m/07rwm0c,"Clickety-clack" +493,/m/07phhsh,"Rumble" +494,/m/07qyrcz,"Plop" +495,/m/07qfgpx,"Jingle, tinkle" +496,/m/07rcgpl,"Hum" +497,/m/07p78v5,"Zing" +498,/t/dd00121,"Boing" +499,/m/07s12q4,"Crunch" +500,/m/028v0c,"Silence" +501,/m/01v_m0,"Sine wave" +502,/m/0b9m1,"Harmonic" +503,/m/0hdsk,"Chirp tone" +504,/m/0c1dj,"Sound effect" +505,/m/07pt_g0,"Pulse" +506,/t/dd00125,"Inside, small room" +507,/t/dd00126,"Inside, large room or hall" +508,/t/dd00127,"Inside, public space" +509,/t/dd00128,"Outside, urban or manmade" +510,/t/dd00129,"Outside, rural or natural" +511,/m/01b9nn,"Reverberation" +512,/m/01jnbd,"Echo" +513,/m/096m7z,"Noise" +514,/m/06_y0by,"Environmental noise" +515,/m/07rgkc5,"Static" +516,/m/06xkwv,"Mains hum" +517,/m/0g12c5,"Distortion" +518,/m/08p9q4,"Sidetone" +519,/m/07szfh9,"Cacophony" +520,/m/0chx_,"White noise" +521,/m/0cj0r,"Pink noise" +522,/m/07p_0gm,"Throbbing" +523,/m/01jwx6,"Vibration" +524,/m/07c52,"Television" +525,/m/06bz3,"Radio" +526,/m/07hvw1,"Field recording" diff --git a/runtime/ops/mapper/audio_sound_classify/process.py b/runtime/ops/mapper/audio_sound_classify/process.py new file mode 100644 index 00000000..f8d77322 --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/process.py @@ -0,0 +1,559 @@ +# -- encoding: utf-8 -- + +from __future__ import annotations + +import csv +import json +import re +import tempfile +import time +from 
collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Literal, Tuple + +import numpy as np +try: + from loguru import logger +except Exception: + import logging + + logger = logging.getLogger(__name__) + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + +DEFAULT_PANNS_CHECKPOINT = "/models/AudioOperations/panns/Cnn14_16k_mAP=0.438.pth" +DEFAULT_AST_CHECKPOINT = "/models/AudioOperations/recog/audioset_10_10_0.4593.pth" + + +def _package_root() -> Path: + return Path(__file__).resolve().parent + + +def _resolve_path(value: str, fallback: Path) -> Path: + raw = str(value or "").strip() + if raw: + p = Path(raw).expanduser() + if p.exists(): + return p.resolve() + return fallback.resolve() + + +def _audio_ext(sample: Dict[str, Any], default_ext: str = "wav") -> str: + ext = str(sample.get("target_type") or sample.get("fileType") or default_ext).strip().lower().lstrip(".") + return ext or default_ext + + +def _sample_key(sample: Dict[str, Any], audio_path: Path, filename_key: str) -> str: + file_name = str(sample.get(filename_key) or "").strip() + if file_name: + return Path(file_name).stem or audio_path.stem + return audio_path.stem + + +def _safe_marker(value: str, default: str = "unknown") -> str: + marker = re.sub(r"[^A-Za-z0-9._-]+", "_", str(value or default)).strip("._-") + return marker[:80] or default + + +def _strip_sound_marker(stem: str) -> str: + return re.sub(r"__sound_[A-Za-z0-9._-]+$", "", str(stem or "sample")) + + +def _mark_sound_filename(sample: Dict[str, Any], filename_key: str, label: str, target_ext: str) -> None: + file_name = str(sample.get(filename_key) or "").strip() + stem = _strip_sound_marker(Path(file_name).stem if file_name else "sample") + sample[filename_key] = 
f"{stem}__sound_{_safe_marker(label)}.{target_ext}" + + +def _load_audio_16k(path: Path, sr: int = 16000) -> np.ndarray: + import librosa # type: ignore + + audio, _ = librosa.core.load(str(path), sr=sr, mono=True) + if audio.dtype != np.float32: + audio = audio.astype(np.float32, copy=False) + return np.ascontiguousarray(audio) + + +def _load_audio_16k_mono(path: Path) -> np.ndarray: + try: + import soundfile as sf # type: ignore + from scipy.signal import resample_poly # type: ignore + + data, sr = sf.read(str(path), always_2d=True) + if data.shape[1] > 1: + data = data.mean(axis=1, keepdims=True) + wav = data[:, 0] + if int(sr) != 16000: + g = np.gcd(int(sr), 16000) + wav = resample_poly(wav, 16000 // g, int(sr) // g).astype(np.float32, copy=False) + if wav.dtype != np.float32: + wav = wav.astype(np.float32, copy=False) + return np.ascontiguousarray(wav) + except Exception: + return _load_audio_16k(path, sr=16000) + + +def _load_label_macro_map(tsv_path: Path) -> Dict[str, str]: + label_to_macro: Dict[str, str] = {} + with tsv_path.open(encoding="utf-8", newline="") as f: + reader = csv.DictReader(f, delimiter="\t") + for row in reader: + label = str(row.get("label") or "").strip() + macro = str(row.get("macro_class") or "").strip() + if label and macro: + label_to_macro[label] = macro + if not label_to_macro: + raise ValueError(f"音频分类大类映射为空: {tsv_path}") + return label_to_macro + + +@dataclass(frozen=True) +class MacroMap: + macro_to_labels: Dict[str, List[str]] + label_to_macro: Dict[str, str] + + +def _load_macro_map_json(path: Path) -> MacroMap: + obj = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(obj, dict): + raise ValueError(f"AST 大类映射必须是 JSON object: {path}") + macro_to_labels: Dict[str, List[str]] = {} + for macro, labels in obj.items(): + if not isinstance(labels, list): + raise ValueError(f"AST 大类映射格式错误: {macro}") + macro_to_labels[str(macro)] = [str(label).strip() for label in labels if str(label).strip()] + label_to_macro: 
Dict[str, str] = {} + for macro, labels in macro_to_labels.items(): + for label in labels: + label_to_macro[label] = macro + return MacroMap(macro_to_labels=macro_to_labels, label_to_macro=label_to_macro) + + +def _load_audioset_labels_csv(csv_path: Path) -> List[str]: + rows: List[Tuple[int, str]] = [] + with csv_path.open(encoding="utf-8", newline="") as f: + reader = csv.DictReader(f) + for row in reader: + rows.append((int(row["index"]), str(row["display_name"]).strip())) + rows.sort(key=lambda x: x[0]) + labels = [label for _idx, label in rows] + if not labels: + raise ValueError(f"AudioSet labels 为空: {csv_path}") + return labels + + +def _macro_from_topk( + labels: List[str], + probs: np.ndarray, + top_idx: np.ndarray, + label_to_macro: Dict[str, str], +) -> Tuple[str, Dict[str, float]]: + scores: Dict[str, float] = defaultdict(float) + for i in top_idx.tolist(): + label = str(labels[i]) + macro = label_to_macro.get(label, "Other") + scores[macro] += float(probs[i]) + if not scores: + return "Other", {} + best_macro = max(scores, key=lambda k: scores[k]) + return best_macro, dict(scores) + + +def _decide_macro_class( + labels: List[str], + probs: np.ndarray, + top_idx: np.ndarray, + label_to_macro: Dict[str, str], + speech_threshold: float, +) -> Tuple[str, Dict[str, float], Dict[str, float]]: + best_macro, macro_scores = _macro_from_topk(labels, probs, top_idx, label_to_macro) + human_speech_score = float(macro_scores.get("HumanSpeech", 0.0)) + final_macro = "HumanSpeech" if human_speech_score > float(speech_threshold) else best_macro + return final_macro, macro_scores, {"HumanSpeech": human_speech_score, "topk": float(len(top_idx))} + + +_MODEL_CACHE: Dict[Tuple[str, str], Any] = {} +_AST_MODEL_CACHE: Dict[Tuple[str, str], Tuple[Any, Any]] = {} + + +def _load_tagger(checkpoint_path: Path, device: str): + cache_key = (str(checkpoint_path), str(device)) + if cache_key in _MODEL_CACHE: + return _MODEL_CACHE[cache_key] + + from panns_inference import 
AudioTagging # type: ignore + from panns_inference.config import classes_num # type: ignore + from panns_inference.models import Cnn14 # type: ignore + + model = Cnn14( + sample_rate=16000, + window_size=512, + hop_size=160, + mel_bins=64, + fmin=50, + fmax=8000, + classes_num=classes_num, + ) + tagger = AudioTagging(model=model, checkpoint_path=str(checkpoint_path), device=str(device)) + _MODEL_CACHE[cache_key] = tagger + return tagger + + +def _detect_torch_device(device_arg: str): + import torch + + dev = str(device_arg or "auto").strip().lower() + if dev == "cpu": + return torch.device("cpu") + if dev == "cuda": + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + if dev == "npu": + try: + import torch_npu # type: ignore # noqa: F401 + return torch.device("npu") + except Exception: + return torch.device("privateuseone") + if dev == "auto": + try: + import torch_npu # type: ignore # noqa: F401 + try: + return torch.device("npu") + except Exception: + return torch.device("privateuseone") + except Exception: + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + raise ValueError(f"不支持的音频分类设备: {device_arg}") + + +def _log_mel_128(wav_16k: np.ndarray) -> np.ndarray: + import librosa # type: ignore + + mel = librosa.feature.melspectrogram( + y=wav_16k, + sr=16000, + n_fft=400, + hop_length=160, + win_length=400, + window="hann", + center=True, + pad_mode="reflect", + power=2.0, + n_mels=128, + fmin=0, + fmax=8000, + ) + log_mel = np.log(mel + 1e-10).T + if log_mel.dtype != np.float32: + log_mel = log_mel.astype(np.float32, copy=False) + return np.ascontiguousarray(log_mel) + + +def _audioset_norm(spec: np.ndarray) -> np.ndarray: + return (spec + 4.26) / (4.57 * 2.0) + + +def _sliding_windows(wav: np.ndarray, *, segment_sec: float, hop_sec: float) -> Iterable[np.ndarray]: + seg_len = int(round(float(segment_sec) * 16000)) + hop_len = int(round(float(hop_sec) * 16000)) + if seg_len <= 0: + raise 
ValueError("segment_sec 必须大于 0") + if hop_len <= 0: + hop_len = seg_len + n = int(wav.shape[0]) + if n <= seg_len: + pad = seg_len - n + yield np.pad(wav, (0, pad), mode="constant") if pad > 0 else wav + return + start = 0 + while start < n: + end = start + seg_len + if end <= n: + yield wav[start:end] + else: + yield np.pad(wav[start:n], (0, end - n), mode="constant") + if end >= n: + break + start += hop_len + + +MacroAgg = Literal["max", "sum"] + + +def _macro_scores_from_probs(labels: List[str], probs: np.ndarray, macro_map: MacroMap, macro_agg: MacroAgg) -> Dict[str, float]: + name_to_idx = {name: i for i, name in enumerate(labels)} + scores: Dict[str, float] = {} + for macro, names in macro_map.macro_to_labels.items(): + idxs = [name_to_idx[name] for name in names if name in name_to_idx] + if not idxs: + scores[macro] = 0.0 + continue + vals = probs[idxs] + scores[macro] = float(np.sum(vals)) if macro_agg == "sum" else float(np.max(vals)) + return scores + + +def _topk_labels(labels: List[str], probs: np.ndarray, k: int, label_to_macro: Dict[str, str]) -> List[Dict[str, object]]: + topk = max(1, min(int(k), len(labels))) + idx = np.argsort(probs)[::-1][:topk] + return [ + { + "label": str(labels[i]), + "macro_class": label_to_macro.get(str(labels[i]), "Other"), + "prob": round(float(probs[i]), 8), + } + for i in idx + ] + + +def _load_ast_model(checkpoint_path: Path, labels_count: int, device): + cache_key = (str(checkpoint_path), str(device)) + if cache_key in _AST_MODEL_CACHE: + return _AST_MODEL_CACHE[cache_key] + try: + from .ast_vendor import ASTConfig, load_ast_from_pth # type: ignore + except ImportError: + from ast_vendor import ASTConfig, load_ast_from_pth # type: ignore + + cfg = ASTConfig(label_dim=int(labels_count), input_fdim=128, input_tdim=1024, fstride=10, tstride=10, model_size="base384") + model = load_ast_from_pth(checkpoint_path=str(checkpoint_path), device=device, cfg=cfg) + _AST_MODEL_CACHE[cache_key] = (model, device) + return model, 
def _infer_ast(
    audio_path: Path,
    checkpoint_path: Path,
    labels_csv: Path,
    macro_map_path: Path,
    device_arg: str,
    topk: int,
    segment_sec: float,
    hop_sec: float,
    macro_agg: MacroAgg,
) -> Dict[str, Any]:
    """Classify one audio file with the AST backend.

    The waveform is cut into sliding windows. Each window is converted to a
    normalised 128-bin log-mel spectrogram, padded or cropped to 1024 frames
    and scored by the model (sigmoid over the logits). Per-window macro
    scores and label probabilities are averaged over all windows.

    Returns:
        A dict with the winning ``macro_class``, the averaged
        ``macro_scores``, the top-k labels (``small_topk``) and the run
        parameters (model, checkpoint, maps, device, segment settings).

    Raises:
        RuntimeError: If no window produced probabilities.
    """
    import torch

    labels = _load_audioset_labels_csv(labels_csv)
    macro_map = _load_macro_map_json(macro_map_path)
    model, device = _load_ast_model(checkpoint_path, len(labels), _detect_torch_device(device_arg))
    wav = _load_audio_16k_mono(audio_path)

    macro_totals: Dict[str, float] = {}
    prob_total = None
    n_segments = 0
    for seg_wav in _sliding_windows(wav, segment_sec=float(segment_sec), hop_sec=float(hop_sec)):
        feats = _audioset_norm(_log_mel_128(seg_wav))
        n_frames = int(feats.shape[0])
        # The model is configured for exactly 1024 time frames.
        if n_frames >= 1024:
            feats = feats[:1024, :]
        else:
            feats = np.pad(feats, ((0, 1024 - n_frames), (0, 0)), mode="constant")
        batch = torch.from_numpy(feats).unsqueeze(0).to(device=device, dtype=torch.float32)
        with torch.inference_mode():
            logits = model(batch)[0]
        seg_probs = torch.sigmoid(logits).detach().cpu().to(torch.float32).numpy()

        seg_scores = _macro_scores_from_probs(labels, seg_probs, macro_map, macro_agg=macro_agg)
        for name, score in seg_scores.items():
            macro_totals[name] = macro_totals.get(name, 0.0) + float(score)
        if prob_total is None:
            prob_total = seg_probs.astype(np.float64, copy=True)
        else:
            prob_total = prob_total + seg_probs
        n_segments += 1

    if prob_total is None or n_segments <= 0:
        raise RuntimeError("AST 分类未产生有效分段概率")

    divisor = float(n_segments)
    macro_scores = {name: total / divisor for name, total in macro_totals.items()}
    if macro_scores:
        pred_macro = max(macro_scores, key=lambda k: macro_scores[k])
    else:
        pred_macro = "Noise"
    probs_mean = (prob_total / divisor).astype(np.float32, copy=False)

    return {
        "macro_class": pred_macro,
        "macro_scores": {k: round(float(v), 8) for k, v in macro_scores.items()},
        "small_topk": _topk_labels(labels, probs_mean, topk, macro_map.label_to_macro),
        "model": "AST AudioSet 10_10_0.4593",
        "checkpoint": str(checkpoint_path),
        "macro_map": str(macro_map_path),
        "labels_csv": str(labels_csv),
        "device": str(device),
        "segments": n_segments,
        "segment_sec": float(segment_sec),
        "hop_sec": float(hop_sec),
        "macro_agg": macro_agg,
    }
mark_skipped_sample( + sample, + "non_audio_or_reference_file", + self.__class__.__name__, + self.text_key, + self.data_key, + self.filetype_key, + self.target_type_key, + self.ext_params_key, + ) + + package_root = _package_root() + + data = sample.get(self.data_key) + audio_bytes = b"" + with tempfile.TemporaryDirectory(prefix="dm_audio_sound_classify_") as td: + work_dir = Path(td) + if isinstance(data, (bytes, bytearray)) and data: + audio_bytes = bytes(data) + audio_path = work_dir / f"input.{_audio_ext(sample)}" + audio_path.write_bytes(audio_bytes) + else: + audio_path = Path(str(sample.get(self.filepath_key, ""))).expanduser().resolve() + if not audio_path.exists(): + raise FileNotFoundError(f"输入音频不存在: {audio_path}") + if self.keep_audio or self.is_last_op: + audio_bytes = audio_path.read_bytes() + audio_path_for_infer = audio_path + + if self.backend == "ast": + checkpoint_path = _resolve_path(self.ast_checkpoint, Path(DEFAULT_AST_CHECKPOINT)) + labels_csv = _resolve_path( + self.labels_csv, + package_root / "models" / "recog" / "class_labels_indices.csv", + ) + macro_map_path = _resolve_path( + self.ast_macro_map, + package_root / "models" / "recog" / "audioset_macro_map_v1.json", + ) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"AST 分类模型不存在: {checkpoint_path}") + if not labels_csv.exists(): + raise FileNotFoundError(f"AudioSet labels CSV 不存在: {labels_csv}") + if not macro_map_path.exists(): + raise FileNotFoundError(f"AST 大类映射不存在: {macro_map_path}") + if self.macro_agg not in {"max", "sum"}: + raise ValueError(f"不支持的 macroAgg: {self.macro_agg}") + result_core = _infer_ast( + audio_path_for_infer, + checkpoint_path, + labels_csv, + macro_map_path, + self.device, + self.topk, + self.segment_sec, + self.hop_sec, + self.macro_agg, # type: ignore[arg-type] + ) + elif self.backend == "panns": + checkpoint_path = _resolve_path(self.panns_checkpoint, Path(DEFAULT_PANNS_CHECKPOINT)) + fallback_macro = package_root / "models" / "panns" / 
"classes_macro_draft.tsv" + macro_map_path = _resolve_path(self.macro_map, fallback_macro) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"PANNs 分类模型不存在: {checkpoint_path}") + if not macro_map_path.exists(): + raise FileNotFoundError(f"音频分类大类映射不存在: {macro_map_path}") + label_to_macro = _load_label_macro_map(macro_map_path) + tagger = _load_tagger(checkpoint_path, self.device) + audio = _load_audio_16k(audio_path_for_infer, sr=16000) + clipwise_output, _embedding = tagger.inference(audio[None, :]) + probs = clipwise_output[0] + labels = list(tagger.labels) + topk = max(1, min(int(self.topk), len(labels))) + top_idx = np.argsort(probs)[::-1][:topk] + final_macro, macro_scores, rule_scores = _decide_macro_class( + labels, + probs, + top_idx, + label_to_macro, + self.speech_threshold, + ) + result_core = { + "macro_class": final_macro, + "macro_scores": {k: round(float(v), 8) for k, v in macro_scores.items()}, + "macro_rule_scores": {k: round(float(v), 8) for k, v in rule_scores.items()}, + "small_topk": [ + { + "label": str(labels[i]), + "macro_class": label_to_macro.get(str(labels[i]), "Other"), + "prob": round(float(probs[i]), 8), + } + for i in top_idx + ], + "model": "PANNs Cnn14 16k AudioSet", + "checkpoint": str(checkpoint_path), + "macro_map": str(macro_map_path), + "device": self.device, + } + else: + raise ValueError(f"不支持的音频分类后端: {self.backend}") + + key = _sample_key(sample, audio_path, self.filename_key) + result = { + "key": key, + "backend": self.backend, + **result_core, + } + + ext = sample.get(self.ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext["audio_sound_classify"] = result + sample[self.ext_params_key] = ext + + target_ext = _audio_ext(sample) + if audio_bytes: + sample[self.data_key] = audio_bytes + sample[self.text_key] = "" + if self.is_last_op: + sample[self.filetype_key] = "txt" + sample[self.target_type_key] = target_ext + else: + sample[self.filetype_key] = target_ext + 
sample[self.target_type_key] = target_ext + _mark_sound_filename(sample, self.filename_key, str(result.get("macro_class") or "unknown"), target_ext) + + logger.info( + f"fileName: {sample.get(self.filename_key)}, method: AudioSoundClassify costs {time.time() - start:6f} s" + ) + return sample diff --git a/runtime/ops/mapper/audio_sound_classify/requirements.txt b/runtime/ops/mapper/audio_sound_classify/requirements.txt new file mode 100644 index 00000000..be1607ad --- /dev/null +++ b/runtime/ops/mapper/audio_sound_classify/requirements.txt @@ -0,0 +1,9 @@ +torch==2.8.0 +torchlibrosa==0.0.4 +timm==1.0.26 +librosa==0.10.2.post1 +numpy==2.2.6 +soundfile==0.12.1 +scipy==1.13.1 +loguru==0.7.3 +panns-inference==0.1.1 diff --git a/runtime/ops/mapper/audio_telephony_bandpass/README.md b/runtime/ops/mapper/audio_telephony_bandpass/README.md new file mode 100644 index 00000000..3def8fcd --- /dev/null +++ b/runtime/ops/mapper/audio_telephony_bandpass/README.md @@ -0,0 +1,26 @@ +# AudioTelephonyBandpass 电话带通算子 + +## 概述 + +AudioTelephonyBandpass 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| lowHz | inputNumber | 300 | 下截止频率(Hz) | +| highHz | inputNumber | 3400 | 上截止频率(Hz) | +| order | inputNumber | 4 | Butterworth 阶数 | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy、scipy(scipy.signal) + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_telephony_bandpass/__init__.py b/runtime/ops/mapper/audio_telephony_bandpass/__init__.py new file mode 100644 index 00000000..a303f38f --- /dev/null +++ b/runtime/ops/mapper/audio_telephony_bandpass/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioTelephonyBandpass', + 
module_path="ops.mapper.audio_telephony_bandpass.process") diff --git a/runtime/ops/mapper/audio_telephony_bandpass/audio_skip.py b/runtime/ops/mapper/audio_telephony_bandpass/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_telephony_bandpass/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + 
def mark_skipped_sample(
    sample: Dict[str, Any],
    reason: str,
    op_name: str,
    text_key: str = "text",
    data_key: str = "data",
    filetype_key: str = "fileType",
    target_type_key: str = "target_type",
    ext_params_key: str = "ext_params",
) -> Dict[str, Any]:
    """Record that *op_name* skipped *sample* and blank its payload fields.

    The reason is stored under ``ext_params["audio_skip"][op_name]``. The
    text, data, file type and target type fields are reset to empty values.

    Args:
        sample: Sample dict; it is modified in place.
        reason: Why the operator skipped the sample.
        op_name: Name of the skipping operator.
        text_key, data_key, filetype_key, target_type_key, ext_params_key:
            Names of the sample fields to touch.

    Returns:
        The same *sample* object.
    """
    ext = sample.get(ext_params_key, {})
    if not isinstance(ext, dict):
        ext = {"_raw": ext}
    skipped = ext.get("audio_skip")
    if not isinstance(skipped, dict):
        # A non-dict "audio_skip" (None, str, ...) used to raise TypeError on
        # item assignment. Keep any existing value under "_raw", the same way
        # a non-dict ext_params is handled above.
        skipped = {} if skipped is None else {"_raw": skipped}
        ext["audio_skip"] = skipped
    skipped[op_name] = reason
    sample[ext_params_key] = ext
    sample[text_key] = ""
    sample[data_key] = b""
    sample[filetype_key] = ""
    sample[target_type_key] = ""
    logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}")
    return sample
Requires scipy.signal; process audio and let DataMate export the result.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioTelephonyBandpass' +version: '1.0.0' +types: + - 'cleaning' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + lowHz: + name: '下截止(Hz)' + type: 'inputNumber' + description: '带通下截止频率。' + defaultVal: 300 + min: 1 + max: 20000 + step: 1 + highHz: + name: '上截止(Hz)' + type: 'inputNumber' + description: '带通上截止频率。' + defaultVal: 3400 + min: 1 + max: 20000 + step: 1 + order: + name: '阶数' + type: 'inputNumber' + description: 'Butterworth 阶数(建议 2~6)。' + defaultVal: 4 + min: 1 + max: 12 + step: 1 +runtime: + memory: 104857600 + cpu: 0.2 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_telephony_bandpass/process.py b/runtime/ops/mapper/audio_telephony_bandpass/process.py new file mode 100644 index 00000000..5d569243 --- /dev/null +++ b/runtime/ops/mapper/audio_telephony_bandpass/process.py @@ -0,0 +1,113 @@ +# -- encoding: utf-8 -- + +import io +import time +from pathlib import Path +from typing import Dict, Any, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + + +def _load_audio(source: object) -> Tuple["object", int]: + try: + import soundfile as sf # type: ignore + + if isinstance(source, (bytes, bytearray)): + data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False) + else: + data, sr = sf.read(str(source), always_2d=False) + return data, int(sr) + except Exception as e: + raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e + + +def _dump_audio(data: "object", sr: int, fmt: str) -> bytes: + try: + import soundfile as sf # type: ignore + + with io.BytesIO() as buf: + sf.write(buf, data, 
class AudioTelephonyBandpass(Mapper):
    """Band-limit audio to a telephone-style band with a Butterworth filter.

    Settings (read from ``kwargs``):
        lowHz: Lower cut-off frequency in Hz (default 300).
        highHz: Upper cut-off frequency in Hz (default 3400).
        order: Butterworth order (default 4).

    The result is always encoded as WAV and written to the sample's data
    field.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.low_hz = float(kwargs.get("lowHz", 300))
        self.high_hz = float(kwargs.get("highHz", 3400))
        # Go through float() first so values such as "4.0" are accepted.
        self.order = int(float(kwargs.get("order", 4)))
        self.out_format = "wav"

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Filter one sample and store the WAV bytes back into it.

        Samples flagged as invalid by the quality check, and samples that are
        not audio, are passed to ``mark_skipped_sample`` and returned.
        Upstream bytes in the data field take precedence over the file path.

        Raises:
            FileNotFoundError: If there are no upstream bytes and the file
                path does not exist.
            RuntimeError: If decoding, filtering or encoding fails.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return mark_skipped_sample(
                sample,
                quality_skip_reason,
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return mark_skipped_sample(
                sample,
                "non_audio_or_reference_file",
                self.__class__.__name__,
                self.text_key,
                self.data_key,
                self.filetype_key,
                self.target_type_key,
                self.ext_params_key,
            )

        upstream = sample.get(self.data_key)
        has_upstream_bytes = isinstance(upstream, (bytes, bytearray)) and bool(upstream)
        # str(... or "") keeps a missing or None path from raising TypeError.
        in_path = Path(str(sample.get(self.filepath_key) or "")).resolve()
        # The file is only needed when no upstream operator supplied bytes;
        # requiring it unconditionally broke chains whose original file is
        # not available on disk.
        if not has_upstream_bytes and not in_path.exists():
            raise FileNotFoundError(f"输入音频不存在: {in_path}")

        data, sr = _load_audio(upstream or in_path)
        try:
            import numpy as np
            from scipy.signal import butter, lfilter  # type: ignore

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                # Downmix multi-channel audio to mono before filtering.
                x = x.mean(axis=1)
            if x.size == 0:
                y = x
            else:
                nyq = float(sr) / 2.0
                # Normalise both edges to the Nyquist frequency and keep them
                # strictly inside (0, 1), as butter() requires.
                low = max(1.0, float(self.low_hz)) / nyq
                high = min(nyq - 1.0, float(self.high_hz)) / nyq
                if not (0.0 < low < high < 1.0):
                    raise ValueError(f"非法带通范围: low={self.low_hz}, high={self.high_hz}, sr={sr}")
                b, a = butter(max(1, int(self.order)), [low, high], btype="bandpass")
                y = lfilter(b, a, x).astype(np.float32)
            y = np.clip(y, -1.0, 1.0)
        except ImportError as e:
            raise RuntimeError("AudioTelephonyBandpass 需要 scipy.signal(butter/lfilter)") from e
        except Exception as e:
            raise RuntimeError(f"处理失败: {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        sample[self.filetype_key] = "txt"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioTelephonyBandpass costs {time.time() - start:6f} s"
        )
        return sample
new file mode 100644 index 00000000..e6a73c85 --- /dev/null +++ b/runtime/ops/mapper/audio_text_summarize/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioTextSummarize', + module_path="ops.mapper.audio_text_summarize.process") diff --git a/runtime/ops/mapper/audio_text_summarize/metadata.yml b/runtime/ops/mapper/audio_text_summarize/metadata.yml new file mode 100644 index 00000000..fe9cd7c4 --- /dev/null +++ b/runtime/ops/mapper/audio_text_summarize/metadata.yml @@ -0,0 +1,119 @@ +name: 'audioOps-ASR文本概括' +name_en: 'audioOps-ASR Text Summarization' +description: '对 ASR 转写文本做高保真抽取式概括,保留原文关键信息;可选使用本地 ONNX embedding 模型辅助选片。' +description_en: 'Summarize ASR transcript text with a high-fidelity extractive method; optionally use a local ONNX embedding model for span selection.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioTextSummarize' +version: '1.0.0' +types: + - 'annotation' +modal: 'text' +inputs: 'text' +outputs: 'text' +settings: + method: + name: '概括方法' + description: 'extractive 为轻量抽取式;bert_onnx 使用本地 ONNX embedding 模型选择代表片段。' + type: 'select' + defaultVal: 'extractive' + required: true + options: + - label: 'extractive' + value: 'extractive' + - label: 'bert_onnx' + value: 'bert_onnx' + maxSummaryCharsZh: + name: '中文最大字数' + type: 'inputNumber' + description: '中文摘要最大字符数,0 表示不限制。' + defaultVal: 40 + min: 0 + max: 500 + step: 1 + maxSummaryWordsEn: + name: '英文最大词数' + type: 'inputNumber' + description: '英文摘要最大词数,0 表示不限制。' + defaultVal: 18 + min: 0 + max: 200 + step: 1 + minSummaryWordsEn: + name: '英文最小词数' + type: 'inputNumber' + description: '抽取式英文滑窗搜索的最小词数。' + defaultVal: 8 + min: 1 + max: 200 + step: 1 + lineMode: + name: '行解析模式' + description: 'single 将全文当作一条;tab 解析 keytext;space 解析 key text;auto 仅在每行含 TAB 时解析。' + type: 'select' + defaultVal: 'single' + required: true + options: + - label: 'single' + value: 'single' + - label: 'auto' + value: 'auto' + - 
# Fallback directory of the local ONNX embedding model, used when no
# ``onnxModelDir`` setting is supplied.
DEFAULT_ONNX_MODEL_DIR = "/models/AudioOperations/summary/summary-model"
# Matches a single character in U+4E00..U+9FFF; its presence makes
# ``_detect_lang`` classify a text as "zh".
_RE_CJK = re.compile(r"[\u4e00-\u9fff]")
# An English "word": a run of ASCII letters, optionally followed by an
# apostrophe and more letters (e.g. "don't"). Digits and punctuation never match.
_RE_EN_WORD = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?")
# Lower-case English stop words; ``_best_en_window`` gives them zero weight
# when scoring candidate windows.
_EN_STOP = {
    "a", "an", "the", "and", "or", "but", "if", "then", "else", "so", "as", "at", "by", "for", "from",
    "in", "into", "of", "on", "onto", "out", "over", "to", "up", "with", "without", "about", "after",
    "before", "between", "during", "through", "under", "again", "once", "here", "there", "when", "where",
    "why", "how", "all", "any", "both", "each", "few", "more", "most", "other", "some", "such", "no",
    "nor", "not", "only", "own", "same", "than", "too", "very", "can", "will", "just", "should", "now",
    "i", "me", "my", "we", "us", "our", "you", "your", "he", "him", "his", "she", "her", "they", "them",
    "their", "it", "its", "this", "that", "these", "those", "is", "are", "was", "were", "be", "been",
    "being", "have", "has", "had", "do", "does", "did",
}
def _best_en_window(text: str, *, min_words: int, max_words: int) -> str:
    """Pick the contiguous run of English words with the highest keyword weight.

    Args:
        text: Source text; whitespace is normalised and only tokens matched by
            ``_en_tokens`` are considered.
        min_words: Smallest window length to try (clamped to ``1..max_words``).
        max_words: Largest window length; ``<= 0`` disables the limit.

    Returns:
        The selected words joined by single spaces. All words are returned when
        the text already fits, and ``""`` when there are no words.
    """
    cleaned = _clean_en(text)
    tokens = _en_tokens(cleaned)
    if not tokens:
        return ""
    limit = int(max_words)
    total = len(tokens)
    if limit <= 0 or total <= limit:
        return " ".join(tokens)
    shortest = max(1, min(int(min_words), limit))
    idf_by_word = _single_doc_en_idf(cleaned)

    # prefix[i] is the summed weight of tokens[:i]; stop words and
    # single-letter tokens contribute nothing.
    prefix = [0.0]
    for tok in tokens:
        lowered = tok.lower()
        informative = lowered not in _EN_STOP and len(lowered) > 1
        weight = float(idf_by_word.get(lowered, 1.0)) if informative else 0.0
        prefix.append(prefix[-1] + weight)

    span = (0, min(limit, total))
    top = -1.0
    for width in range(shortest, min(limit, total) + 1):
        for begin in range(total - width + 1):
            mass = prefix[begin + width] - prefix[begin]
            # Total weight dominates; density is a small bonus for compact windows.
            ranked = mass + 0.15 * (mass / float(width))
            if ranked > top:
                top = ranked
                span = (begin, begin + width)
    return " ".join(tokens[span[0] : span[1]]).strip()
= -1.0 + for start in range(0, len(s) - max_chars + 1): + score = pref[start + max_chars] - pref[start] + if score > best_score: + best_score = score + best_start = start + return s[best_start : best_start + max_chars].strip() + + +def _truncate_summary(summary: str, lang: str, max_chars_zh: int, max_words_en: int) -> str: + if lang == "zh": + s = _clean_zh(summary) + return s[: int(max_chars_zh)].strip() if int(max_chars_zh) > 0 else s + words = _en_tokens(summary) + if int(max_words_en) > 0: + words = words[: int(max_words_en)] + return " ".join(words).strip() + + +def _extractive_summary(text: str, max_chars_zh: int, max_words_en: int, min_words_en: int) -> Tuple[str, str]: + lang = _detect_lang(text) + if lang == "zh": + return _best_zh_window(text, max_chars=int(max_chars_zh)), lang + return _best_en_window(text, min_words=int(min_words_en), max_words=int(max_words_en)), lang + + +def _parse_keyed_lines(text: str, mode: str) -> List[Tuple[str, str]]: + rows: List[Tuple[str, str]] = [] + lines = [line.rstrip("\n") for line in (text or "").splitlines() if line.strip()] + if not lines: + return [] + actual_mode = str(mode or "single").strip().lower() + if actual_mode == "single": + return [("", text.strip())] + if actual_mode == "auto": + if not all("\t" in line for line in lines): + return [("", text.strip())] + actual_mode = "tab" + for idx, line in enumerate(lines): + if actual_mode == "tab" and "\t" in line: + key, value = line.split("\t", 1) + elif actual_mode == "space": + parts = line.strip().split(maxsplit=1) + key = parts[0] if parts else str(idx) + value = parts[1] if len(parts) > 1 else "" + else: + key, value = str(idx), line + rows.append((key.strip(), value.strip())) + return rows + + +def _mark_skipped_text_sample(sample: Dict[str, Any], reason: str, op_name: str, keys: Tuple[str, ...]) -> Dict[str, Any]: + text_key, data_key, filetype_key, target_type_key, ext_params_key = keys + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): 
+ ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample + + +def _read_text_from_sample(sample: Dict[str, Any], text_key: str, filepath_key: str, filetype_key: str) -> str: + text = str(sample.get(text_key) or "") + if text.strip(): + return text + file_type = str(sample.get(filetype_key) or "").strip().lower().lstrip(".") + path_value = str(sample.get(filepath_key) or "").strip() + if file_type in {"txt", "text", "md", "json", "jsonl"} and path_value: + path = Path(path_value).expanduser().resolve() + if path.exists() and path.is_file(): + return path.read_text(encoding="utf-8", errors="ignore") + return "" + + +def _resolve_onnx_model_dir(value: str) -> Path: + raw = str(value or "").strip() or DEFAULT_ONNX_MODEL_DIR + path = Path(raw).expanduser() + if path.exists(): + return path.resolve() + bundled = Path(__file__).resolve().parent / "models" / "summary-model" + if bundled.exists(): + return bundled.resolve() + return path.resolve() + + +def _available_providers() -> List[str]: + try: + import onnxruntime as ort # type: ignore + + return list(ort.get_available_providers()) + except Exception: + return [] + + +def _pick_providers(provider_arg: str) -> List[str]: + requested = [p.strip() for p in str(provider_arg or "").split(",") if p.strip()] + if not requested: + requested = ["CANNExecutionProvider", "CPUExecutionProvider"] + available = set(_available_providers()) + picked = [p for p in requested if p in available] + return picked or ["CPUExecutionProvider"] + + +_ONNX_CACHE: Dict[Tuple[str, str], Tuple[Any, Any, List[str]]] = {} + + +def _load_onnx_embedder(model_dir: Path, providers: Sequence[str], cpu_threads: int): + cache_key = (str(model_dir), ",".join(providers)) + if cache_key in 
_ONNX_CACHE: + return _ONNX_CACHE[cache_key] + + import onnxruntime as ort # type: ignore + from transformers import AutoTokenizer # type: ignore + + model_path = model_dir / "model.onnx" + if not model_path.exists(): + raise FileNotFoundError(f"摘要 ONNX 模型不存在: {model_path}") + tokenizer = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True) + opts = ort.SessionOptions() + opts.intra_op_num_threads = int(cpu_threads) + opts.inter_op_num_threads = 1 + session = ort.InferenceSession(str(model_path), sess_options=opts, providers=list(providers)) + used = list(session.get_providers()) + _ONNX_CACHE[cache_key] = (tokenizer, session, used) + return tokenizer, session, used + + +def _mean_pool(last_hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray: + mask = attention_mask.astype(np.float32) + if last_hidden.ndim == 2: + return last_hidden[0].astype(np.float32) + masked = last_hidden * mask[:, :, None] + denom = np.maximum(mask.sum(axis=1, keepdims=True), 1e-8) + return (masked.sum(axis=1) / denom)[0].astype(np.float32) + + +def _embed_texts(texts: Sequence[str], model_dir: Path, providers: Sequence[str], cpu_threads: int) -> Tuple[List[np.ndarray], List[str]]: + tokenizer, session, used = _load_onnx_embedder(model_dir, providers, cpu_threads) + out: List[np.ndarray] = [] + input_names = {inp.name for inp in session.get_inputs()} + for text in texts: + enc = tokenizer( + text, + return_tensors="np", + truncation=True, + max_length=512, + padding=True, + ) + feeds: Dict[str, np.ndarray] = {} + for name in input_names: + if name in enc: + feeds[name] = enc[name].astype(np.int64) + elif name == "token_type_ids": + feeds[name] = np.zeros_like(enc["input_ids"], dtype=np.int64) + result = session.run(None, feeds) + vec = _mean_pool(np.asarray(result[0]), np.asarray(enc["attention_mask"])) + norm = float(np.linalg.norm(vec)) + if norm > 0: + vec = vec / norm + out.append(vec) + return out, used + + +def _cosine(a: np.ndarray, b: np.ndarray) -> float: + 
class AudioTextSummarize(Mapper):
    """Extractive summariser for ASR transcripts.

    Settings read from ``kwargs``:
        method: ``extractive`` (keyword-weight window) or ``bert_onnx``
            (window chosen by embedding similarity to the full text).
        maxSummaryCharsZh / maxSummaryWordsEn / minSummaryWordsEn: length limits.
        lineMode: how the input is split into rows (single/auto/tab/space).
        preserveKeys: keep row keys in the output text.
        onnxModelDir / providersPriority / maxWindows: ``bert_onnx`` options.
        cpuThreads: thread cap passed to ``_limit_cpu_threads`` and the ONNX session.
        keepOriginalInExt: store each row's original text in the run details.

    The summary replaces ``sample[text_key]``; per-row details are stored under
    ``ext_params["audio_text_summarize"]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.method = str(kwargs.get("method", "extractive")).strip().lower()
        # Numeric settings go through float() first so values such as "40.0" are accepted.
        self.max_chars_zh = int(float(kwargs.get("maxSummaryCharsZh", 40)))
        self.max_words_en = int(float(kwargs.get("maxSummaryWordsEn", 18)))
        self.min_words_en = int(float(kwargs.get("minSummaryWordsEn", 8)))
        self.line_mode = str(kwargs.get("lineMode", "single")).strip().lower()
        self.preserve_keys = _as_bool(kwargs.get("preserveKeys", True))
        self.onnx_model_dir = str(kwargs.get("onnxModelDir", DEFAULT_ONNX_MODEL_DIR)).strip()
        self.providers_priority = str(kwargs.get("providersPriority", "CANNExecutionProvider,CPUExecutionProvider")).strip()
        self.cpu_threads = int(float(kwargs.get("cpuThreads", 4)))
        self.max_windows = int(float(kwargs.get("maxWindows", 96)))
        self.keep_original = _as_bool(kwargs.get("keepOriginalInExt", False))

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Summarise the text of one sample and return it.

        A sample with no usable text is marked as skipped.

        Raises:
            ValueError: ``method`` is neither ``extractive`` nor ``bert_onnx``.
        """
        start = time.time()
        text = _read_text_from_sample(sample, self.text_key, self.filepath_key, self.filetype_key)
        if not text.strip():
            return _mark_skipped_text_sample(
                sample,
                "empty_text_for_summary",
                self.__class__.__name__,
                (self.text_key, self.data_key, self.filetype_key, self.target_type_key, self.ext_params_key),
            )

        _limit_cpu_threads(self.cpu_threads)
        rows = _parse_keyed_lines(text, self.line_mode)
        method = self.method
        if method not in {"extractive", "bert_onnx"}:
            raise ValueError(f"不支持的文本概括方法: {self.method}")

        # The model directory and provider list do not depend on the row, so
        # resolve them once instead of once per row.
        use_onnx = method == "bert_onnx"
        if use_onnx:
            model_dir = _resolve_onnx_model_dir(self.onnx_model_dir)
            providers = _pick_providers(self.providers_priority)

        summaries: List[Tuple[str, str, str, Dict[str, Any]]] = []
        for key, row_text in rows:
            if use_onnx:
                summary, lang, meta = _onnx_extractive_summary(
                    row_text,
                    model_dir=model_dir,
                    providers=providers,
                    cpu_threads=self.cpu_threads,
                    max_chars_zh=self.max_chars_zh,
                    max_words_en=self.max_words_en,
                    max_windows=self.max_windows,
                )
                meta["model_dir"] = str(model_dir)
            else:
                summary, lang = _extractive_summary(row_text, self.max_chars_zh, self.max_words_en, self.min_words_en)
                meta = {"providers": ["CPUExecutionProvider"], "windows": 0}
            summaries.append((key, row_text, summary, {"lang": lang, **meta}))

        # Keys are written only when requested and at least one row has a key.
        if self.preserve_keys and any(key for key, _text, _summary, _meta in summaries):
            output_text = "\n".join(f"{key}\t{summary}" if key else summary for key, _text, summary, _meta in summaries)
        else:
            output_text = "\n".join(summary for _key, _text, summary, _meta in summaries)

        details = []
        for key, row_text, summary, meta in summaries:
            item: Dict[str, Any] = {
                "key": key,
                "summary": summary,
                "language": meta.get("lang"),
                "input_chars": len(row_text),
                "summary_chars": len(summary),
                "method": method,
                "runtime": {k: v for k, v in meta.items() if k != "lang"},
            }
            if self.keep_original:
                item["original_text"] = row_text
            details.append(item)

        ext = sample.get(self.ext_params_key, {})
        if not isinstance(ext, dict):
            # Keep a non-dict ext_params value instead of discarding it.
            ext = {"_raw": ext}
        ext["audio_text_summarize"] = {
            "method": method,
            "line_mode": self.line_mode,
            "items": details,
            "elapsed_ms": round((time.time() - start) * 1000.0, 3),
        }
        sample[self.ext_params_key] = ext
        sample[self.text_key] = output_text
        sample[self.data_key] = b""
        sample[self.filetype_key] = "txt"
        sample[self.target_type_key] = "txt"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioTextSummarize costs {time.time() - start:6f} s"
        )
        return sample
@@ -0,0 +1,27 @@ +# AudioTrimSilenceEdges 首尾静音裁剪算子 + +## 概述 + +AudioTrimSilenceEdges 处理输入音频,并将结果写入 `sample["data"]`,同时设置 `sample["target_type"]`。输出路径、同名文件处理和最终落盘均交由 DataMate 的标准导出流程负责。 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| frameMs | inputNumber | 30 | 帧长(ms) | +| hopMs | inputNumber | 10 | 帧移(ms) | +| threshDb | slider | -50 | 能量阈值(dB,相对全段峰值) | +| padMs | inputNumber | 50 | 裁剪后两端各保留的静音(ms) | + +## 输入输出 + +- **输入**:`sample["filePath"]`,若上游算子已产生 `sample["data"]`,则优先处理该音频字节。 +- **输出**:`sample["data"]` 为处理后的音频字节;`sample["target_type"]` 为目标音频后缀。 + +## 依赖说明 + +- **Python 依赖**:soundfile、numpy + +## 版本历史 + +- **v1.0.0**:首次发布 diff --git a/runtime/ops/mapper/audio_trim_silence_edges/__init__.py b/runtime/ops/mapper/audio_trim_silence_edges/__init__.py new file mode 100644 index 00000000..6c5e09c1 --- /dev/null +++ b/runtime/ops/mapper/audio_trim_silence_edges/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioTrimSilenceEdges', + module_path="ops.mapper.audio_trim_silence_edges.process") diff --git a/runtime/ops/mapper/audio_trim_silence_edges/audio_skip.py b/runtime/ops/mapper/audio_trim_silence_edges/audio_skip.py new file mode 100644 index 00000000..aec49613 --- /dev/null +++ b/runtime/ops/mapper/audio_trim_silence_edges/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + 
def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str:
    """Return why *sample* must be skipped for bad audio quality, or ``""``.

    Two signals are checked in order:

    1. A ``__quality_invalid`` marker in the stem of ``fileName``,
       ``sourceFileName`` or ``filePath``; the text after the marker, with
       surrounding underscores removed, becomes the reason.
    2. ``ext_params["audio_quality"]`` with ``quality_flag == "invalid"`` and a
       truthy ``skip_downstream`` (default ``True``; strings are read as booleans).

    The result has the form ``invalid_audio_quality:<reason>``.
    """
    marker = "__quality_invalid"
    for name_key in ("fileName", "sourceFileName", "filePath"):
        stem = Path(str(sample.get(name_key) or "")).stem.lower()
        if marker not in stem:
            continue
        _, _, suffix = stem.partition(marker)
        tagged_reason = suffix.strip("_") or "invalid_audio"
        return f"invalid_audio_quality:{tagged_reason}"

    ext = sample.get(ext_params_key, {})
    quality = ext.get("audio_quality", {}) if isinstance(ext, dict) else None
    if not isinstance(quality, dict):
        return ""
    flag = str(quality.get("quality_flag") or "").strip().lower()
    if flag != "invalid":
        return ""
    gate = quality.get("skip_downstream", True)
    if isinstance(gate, str):
        gate = gate.strip().lower() in {"1", "true", "yes", "y", "on"}
    if not gate:
        return ""
    detail = str(quality.get("reason") or "invalid_audio").strip()
    return f"invalid_audio_quality:{detail}"
data_key: str = "data", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + ext_params_key: str = "ext_params", +) -> Dict[str, Any]: + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample diff --git a/runtime/ops/mapper/audio_trim_silence_edges/metadata.yml b/runtime/ops/mapper/audio_trim_silence_edges/metadata.yml new file mode 100644 index 00000000..fb0d65c3 --- /dev/null +++ b/runtime/ops/mapper/audio_trim_silence_edges/metadata.yml @@ -0,0 +1,58 @@ +name: 'audioUtils-首尾静音裁剪' +name_en: 'audioUtils-Trim Silence Edges' +description: '从首尾向内裁剪静音,保留可选 padding。处理音频并由 DataMate 统一导出结果。' +description_en: 'Trim leading/trailing silence with optional padding. Process audio and let DataMate export the result.' 
+language: 'python' +vendor: 'huawei' +raw_id: 'AudioTrimSilenceEdges' +version: '1.0.0' +types: + - 'cleaning' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + frameMs: + name: '帧长(ms)' + type: 'inputNumber' + description: '分析帧长。' + defaultVal: 30 + min: 5 + max: 500 + step: 1 + hopMs: + name: '帧移(ms)' + type: 'inputNumber' + description: '帧移。' + defaultVal: 10 + min: 1 + max: 500 + step: 1 + threshDb: + name: '能量阈值(dB)' + type: 'slider' + description: '相对全段峰值的帧能量阈值(dB)。' + defaultVal: -50 + min: -80 + max: 0 + step: 1 + padMs: + name: '保留静音(ms)' + type: 'inputNumber' + description: '裁剪后两端各保留的 padding(毫秒)。' + defaultVal: 50 + min: 0 + max: 5000 + step: 1 +runtime: + memory: 104857600 + cpu: 0.15 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_trim_silence_edges/process.py b/runtime/ops/mapper/audio_trim_silence_edges/process.py new file mode 100644 index 00000000..6e0fc89b --- /dev/null +++ b/runtime/ops/mapper/audio_trim_silence_edges/process.py @@ -0,0 +1,126 @@ +# -- encoding: utf-8 -- + +import io +import time +from pathlib import Path +from typing import Dict, Any, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper +try: + from .audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import invalid_quality_reason, is_audio_sample, mark_skipped_sample + + + +def _load_audio(source: object) -> Tuple["object", int]: + try: + import soundfile as sf # type: ignore + + if isinstance(source, (bytes, bytearray)): + data, sr = sf.read(io.BytesIO(bytes(source)), always_2d=False) + else: + data, sr = sf.read(str(source), always_2d=False) + return data, int(sr) + except Exception as e: + raise RuntimeError(f"读取音频失败(需要 soundfile): error={e}") from e + + +def _dump_audio(data: "object", sr: int, fmt: str) -> bytes: + try: + import soundfile as sf # type: ignore + + with 
class AudioTrimSilenceEdges(Mapper):
    """Trim leading and trailing silence from audio, keeping optional padding.

    Settings read from ``kwargs``:
        frameMs: analysis frame length in ms (default 30).
        hopMs: frame hop in ms (default 10).
        threshDb: frame RMS threshold in dB relative to the peak of the whole
            signal (default -50).
        padMs: audio kept on each side of the detected region, in ms (default 50).

    Multi-channel input is averaged to mono; the result is written back to
    ``sample[data_key]`` as WAV bytes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.frame_ms = float(kwargs.get("frameMs", 30))
        self.hop_ms = float(kwargs.get("hopMs", 10))
        self.thresh_db = float(kwargs.get("threshDb", -50))
        self.pad_ms = float(kwargs.get("padMs", 50))
        self.out_format = "wav"

    def _skip(self, sample: Dict[str, Any], reason: str) -> Dict[str, Any]:
        """Mark *sample* as skipped by this operator for *reason*."""
        return mark_skipped_sample(
            sample,
            reason,
            self.__class__.__name__,
            self.text_key,
            self.data_key,
            self.filetype_key,
            self.target_type_key,
            self.ext_params_key,
        )

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Trim one sample and return it.

        Samples flagged invalid by an upstream quality check, and samples that
        are not audio, are soft-skipped instead of processed.

        Raises:
            FileNotFoundError: No in-memory audio is present and the input
                path does not exist.
            RuntimeError: Decoding, trimming or encoding failed.
        """
        start = time.time()
        quality_skip_reason = invalid_quality_reason(sample, self.ext_params_key)
        if quality_skip_reason:
            return self._skip(sample, quality_skip_reason)

        if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key):
            return self._skip(sample, "non_audio_or_reference_file")

        # Bytes handed over by an upstream operator take precedence; the
        # on-disk file is only required when there is no in-memory payload.
        payload = sample.get(self.data_key)
        if payload:
            source = payload
        else:
            in_path = Path(sample.get(self.filepath_key) or "").resolve()
            if not in_path.exists():
                raise FileNotFoundError(f"输入音频不存在: {in_path}")
            source = in_path

        data, sr = _load_audio(source)
        try:
            import numpy as np

            x = np.asarray(data, dtype=np.float32)
            if x.ndim > 1:
                # Down-mix multi-channel audio by averaging the channels.
                x = x.mean(axis=1)
            if x.size == 0:
                y = x
            else:
                # Threshold is relative to the peak; the epsilon keeps it
                # positive for an all-zero signal.
                peak = float(np.max(np.abs(x))) + 1e-12
                th = peak * (10.0 ** (float(self.thresh_db) / 20.0))
                frame_len = max(1, int(sr * self.frame_ms / 1000.0))
                hop = max(1, int(sr * self.hop_ms / 1000.0))

                # Per-frame RMS; trailing frames may be shorter than frame_len.
                rms = []
                for st in range(0, len(x), hop):
                    ed = min(st + frame_len, len(x))
                    f = x[st:ed]
                    rms.append(float(np.sqrt(np.mean(f * f) + 1e-12)))
                # Indices of the frames that count as non-silent.
                keep = [i for i, r in enumerate(rms) if r >= th]
                if not keep:
                    y = x[:0]
                else:
                    first = keep[0]
                    last = keep[-1]
                    start_samp = first * hop
                    end_samp = min(len(x), last * hop + frame_len)
                    # Widen the kept region by the padding, clamped to the signal.
                    pad = int(sr * self.pad_ms / 1000.0)
                    start_samp = max(0, start_samp - pad)
                    end_samp = min(len(x), end_samp + pad)
                    y = x[start_samp:end_samp]
        except Exception as e:
            raise RuntimeError(f"处理失败(需要 numpy): {e}") from e

        sample[self.data_key] = _dump_audio(y, sr, self.out_format)
        sample[self.text_key] = ""
        sample[self.target_type_key] = self.out_format
        # NOTE(review): fileType is set to "txt" although the payload is WAV;
        # kept as in the original - confirm against the DataMate export contract.
        sample[self.filetype_key] = "txt"

        logger.info(
            f"fileName: {sample.get(self.filename_key)}, method: AudioTrimSilenceEdges costs {time.time() - start:6f} s"
        )
        return sample