diff --git a/.gitignore b/.gitignore
index af0b09c..7bda6e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.csv
+.torch_hub/*
 *.mp3
 *.wav
 test.py
@@ -7,7 +8,8 @@ chromadb
 cl_*.py
 *.ipynb
 *.bin
-
+temp/*
+models/*
 .env
 __pycache__/
 *.py[cod]
@@ -70,4 +72,5 @@ ipython_config.py
 .pdm-build/
 __pypackages__/
-venv/*
\ No newline at end of file
+venv/*
+cache
\ No newline at end of file
diff --git a/README.md b/README.md
index 9e26831..6a0c00b 100644
--- a/README.md
+++ b/README.md
@@ -1,278 +1,107 @@
 # Balalaika Pipeline
-A complete production-ready pipeline for processing podcast audio data, from download to feature extraction.
+End-to-end speech data processing: ingest, segmentation, quality filtering, multi-model ASR with ROVER, punctuation, lexical stress, G2P, and export to Parquet / WebDataset.
----
+Works with Yandex Music podcasts out of the box, or **your own corpus** if you follow the expected layout (see [Preparing your dataset](docs/preparing.md)).
-## Table of Contents
-
-1. [Prerequisites](#prerequisites)
-2. [Installation](#installation)
-3. [Data Preparation](#data-preparation)
-   - [Quick Setup (Default Parameters)](#quick-setup)
-   - [Custom Metadata Download](#custom-metadata-download)
-4. [Running the Pipeline](#running-the-pipeline)
-   - [Basic Scenario (Local Processing)](#basic-scenario-local-processing)
-5. [Configuration](#configuration)
-6. [Environment Variables](#environment-variables)
-7. [Models](#models)
-8. [Citation](#citation)
-9. [Acknowledgments](#acknowledgments)
+**Pre-built processed datasets** (segmented, filtered, annotated) are published on Hugging Face: **[Balalaika Dataset — MTUCI collection](https://huggingface.co/collections/MTUCI/balalaika-dataset)**.
 ---
-## Prerequisites
+## Quick Start
-Ensure you have the following tools installed on your system:
+### Prerequisites
 ```bash
 sudo apt update && sudo apt install -y ffmpeg
 wget -qO- https://astral.sh/uv/install.sh | sh
+```
-````
-
----
-
-## Installation
-
-Clone the repository and set up the environment:
+### Installation
 ```bash
 git clone https://github.com/mtuciru/balalaika
 cd balalaika
-# Use this if you want to annotate/modify the dataset
-bash create_dev_env.sh
-# Use this if you only want to use the pre-annotated dataset
-bash create_user_env.sh
+bash create_dev_env.sh # full stack for running the pipeline
+# or
+bash create_user_env.sh # consume pre-built datasets only
 ```
----
-
-## Data Preparation
-
-### Quick Setup (Default Parameters)
-
-To download and prepare the dataset with default settings, choose one of the preconfigured dataset sizes:
-
-* **100-hour dataset**
-  ```bash
-  bash use_meta_100h.sh
-  ```
+### Basic setup
-* **500-hour dataset**
-  ```bash
-  bash use_meta_500h.sh
-  ```
+1. Create `.env`:
-* **1000-hour dataset**
-  ```bash
-  bash use_meta_1000h.sh
-  ```
-
-* **2000-hour dataset**
-  ```bash
-  bash use_meta_2000h.sh
-  ```
-
-All metadata can also be downloaded from [Hugging Face – MTUCI](https://huggingface.co/MTUCI).
-
-### Custom Metadata Download
-
-If you already have generated metadata files (`balalaika.parquet` and `balalaika.pkl`), place them in the project root and run:
-
-```bash
-bash use_meta.sh
+```ini
+HF_TOKEN=
+YANDEX_KEY=
 ```
----
-
-## Running the Pipeline
+2. Edit `configs/config.yaml`: set absolute paths (`podcasts_path`, model files under `models/`, etc.).
-
-### Basic Scenario (Local Processing)
-
-
-This scenario will:
-
-1. Download datasets
-2. Split audio into semantic chunks
-3. Transcribe all segments
-4. Perform speaker segmentation
-5. Apply phonemization
-
-To execute locally, run:
+3. Run stages (see [Usage Guide](docs/guide.md)).
Sequential wrapper:
 ```bash
 bash base.sh configs/config.yaml
 ```
-All output metadata will be saved in `podcasts/result.csv`.
-
----
-
-## Configuration
-
-The main configuration file is located at `configs/config.yaml`. This file is organized into several sections, each corresponding to a specific stage of the podcast processing pipeline. Below is a detailed explanation of the key parameters within each section.
-
----
-
-### Global Parameters
+Note: `base.sh` may have early stages commented out—uncomment what you need.
-* `podcasts_path`: It specifies the **absolute path** to the directory where all downloaded podcast files will be stored and where subsequent processing (preprocessing, separation, transcription, etc.) will look for and save its output.
 ---
-### `download` Section
+## Documentation
-This section controls how podcast episodes are downloaded.
+- **[Preparing your dataset](docs/preparing.md)** — HF collection vs. local pipeline, folder layout, models, config.
+- **[Usage Guide](docs/guide.md)** — stages, artifacts, per-step commands.
+- **[example/README.md](example/README.md)** — loading the WebDataset with Hugging Face `datasets`.
-* `podcasts_path`: (As explained above) The directory where downloaded podcasts will be saved.
-* `episodes_limit`: This sets a **limit on the number of episodes** to download from a single podcast playlist.
-* `num_workers`: Specifies the **number of parallel processes** to use for downloading. A higher number can speed up downloads but will consume more system resources.
-* `podcasts_urls_file`: This parameter points to the **path of a `.pkl` file** that contains a list of podcast URLs to be downloaded.
+Per-module notes live under `src/*/README.md` (aligned with `configs/config.yaml`).
 ---
-### `preprocess` Section
-
-This section handles the initial processing of downloaded audio files, such as chopping them into smaller segments.
+## Pipeline overview
-* `podcasts_path`: (As explained above) The directory containing the raw downloaded podcasts that need to be preprocessed.
-* `duration`: Defines the **maximum length in seconds** for each audio sample (segment).
-* `num_workers`: Specifies the **number of parallel processes** to use during preprocessing.
-* `whisper_model`: Specifies the **name or path of the Faster-Whisper compatible model** to be used for initial audio processing.
-* `compute_type`: Determines the **computation type** for the Whisper model, affecting performance and memory usage.
-* `beam_size`: This parameter is related to the **beam search algorithm** used in the Whisper model's decoding process.
+1. **Download** — optional episode fetch.
+2. **Preprocess** — **Sortformer (ONNX)** diarization, single-speaker selection, **Smart Turn** boundary refinement, chunking + `balalaika.csv`; long source files removed after chunking; **crest-factor** filtering (`crest_factor` written to CSV, bad files deleted and their CSV rows removed); **EBU R128-style** loudness normalization (see `preprocess_yaml.sh` order).
+3. **Separation** — **music detection** (WavLM-based): `music_prob` written to CSV, clips above threshold deleted and their CSV rows removed; **DistillMOS** → `DistillMOS` column in `balalaika.csv`.
+4. **Transcription** — **[onnx-asr](https://github.com/istupakov/onnx-asr)** (ONNX Runtime / optional TensorRT), **ROVER** consensus, optional word-level `.tst`.
+5. **Punctuation** — RUPunct.
+6. **Accents** — ruAccent (e.g. `turbo3.1`).
+7. **Phonemization** — **TryIParu** `G2PModel` → `*_rover_phonemes.txt`.
+8. **Collate / export** — `balalaika.parquet` and WebDataset shards via `src/collate_yamls.sh`.
 ---
-### `separation` Section
-
-This section calculates metrics for each audio
-
-* `podcasts_path`: (As explained above) The directory where the chopped podcasts (from the `preprocess` stage) are located.
-* `num_workers`: The **number of parallel processes** to use for audio separation.
-* `nisqa_config`: Specifies the **path to the configuration file for NISQA**
-* `one_speaker`: A **boolean flag** (`True`/`False`) that, when enabled (`True`), instructs the system to download and process only those audio recordings that should contain a single speaker.
-
----
-
-### `transcription` Section
-
-This section is responsible for converting audio into text.
-
-* `podcasts_path`: (As explained above) The directory containing the processed audio files ready for transcription.
-* `model_name`: Specifies the **type of automatic speech recognition (ASR) model** to use. Options typically include `"ctc" or "rnnt"`.
-* `num_workers`: The **number of parallel processes per GPU** to use for transcription.
-* `with_timestamps`: A **boolean flag** (`True`/`False`) that, when enabled, allows the transcription process to generate timestamps for each word or segment. **it only works with ctc**
-* `lm_path`: Specifies the **path to a language model file (`.bin`)**. A language model can improve transcription accuracy by providing contextual information.
-
----
-
-### `punctuation` Section
-
-This section focuses on adding proper punctuation to the transcribed text.
-
-* `podcasts_path`: (As explained above) The directory where the transcribed text files are located.
-* `model_name`: Specifies the **name of the RUPunct model** to be used for punctuation restoration.
-* `num_workers`: The **number of parallel processes per GPU** to use for punctuation.
----
-
-### `accent` Section
-
-In the transcribed text this part is restored with accents.
-
-* `podcasts_path`: (As explained above) The directory containing the relevant podcast files.
-* `num_workers`: The **number of parallel processes per GPU** to use for accent processing.
-* `model_name`: Specifies the **name of the ruAccent model** to be used.
-
----
-
-### `phonemizer` Section
-
-This section is responsible for converting text into phonetic representations (phonemes).
-
-* `podcasts_path`: (As explained above) The directory where the text files (from transcription and punctuation stages) are located.
-* `num_workers`: The **number of parallel processes per GPU** to use for phonemization.
----
-
-### `classification` Section
-
-This section relates to global speaker clustering.
-
-* `podcasts_path`: (As explained above) The directory containing the podcast files relevant for classification.
-* `num_workers`: The **number of parallel processes per GPU** to use for classification.
-* `threshold`: This is the **speaker classification confidence threshold**. Values typically range from `0.6` to `0.9`. A higher threshold means the model needs to be more confident in its classification to assign a label.
-* `model_path`: Specifies the **path to the pretrained speaker classification model** in `.pt` format.
----
-
-### Execution Scripts
-
-Each processing script (`*_yaml.sh` and `*_args.sh`) offers flexibility in how parameters are provided:
-
-* `*_yaml.sh`: These scripts read all necessary parameters directly from the main `config.yaml` file, ensuring consistency across different stages.
-* `*_args.sh`: These scripts allow for hardcoded arguments directly within the shell script itself, which can be useful for quick tests or specific overrides without modifying the main configuration file.
-
-## Environment Variables
-
-Create a `.env` file in the project root with the following:
-
-```ini
-HF_TOKEN=
-YANDEX_KEY=
-```
-
-* `HF_TOKEN`: Required for speaker count estimation.
-* `YANDEX_KEY`: Required for dataset downloads.
-
----
-
-## Important Notes
-
-- All scripts must be executed from the **project root directory**.
-- Paths in the config file must be **absolute**.
-- The processing scripts (punctuation, accents, yofication) should be run **sequentially**.
-- You’ll need:
-  - Yandex Music API key ([How to get one](https://yandex-music.readthedocs.io/en/main/token.html))
-  - Hugging Face token
-
-## Models
-
-Place all required models under the `models/` directory with the following structure:
+## Citation
-```
-models/
-├── vosblink_resnet/ # Speaker classification model
-│   └── ...
-└── nisqa_s.tar # Audio quality assessment model
+```bibtex
+@article{borodin2025datacentric,
+  title={A Data-Centric Framework for Addressing Phonetic and Prosodic Challenges in Russian Speech Generative Models},
+  author={Borodin, Kirill and Vasiliev, Nikita and Kudryavtsev, Vasiliy and Maslov, Maxim and Gorodnichev, Mikhail and Rogov, Oleg and Mkrtchian, Grach},
+  journal={arXiv preprint arXiv:2507.13563},
+  year={2025}
+}
 ```
-Supported models:
-
-- [NISQA](https://github.com/deepvk/NISQA-s) – Audio quality assessment.
-- [GigaAM](https://github.com/salute-developers/GigaAM) – ASR.
-- [ruAccent](https://github.com/Den4ikAI/ruaccent) – Accent restoration.
-- [RUPunct](https://huggingface.co/RUPunct/RUPunct_big) – Punctuation restoration.
-- [VoxBlink ResNet](https://github.com/wenet-e2e/wespeaker) – Speaker classification.
-- [TryIPaG2P](https://github.com/NikiPshg/TryIPaG2P) – Phonemization.
-- [Speaker Diarization](https://github.com/pyannote/pyannote-audio) – Speaker diarization.
-- [Whisper](https://github.com/SYSTRAN/faster-whisper) – ASR + segmentation
+**Paper**: [arXiv:2507.13563](https://arxiv.org/abs/2507.13563)
+**DOI**: [10.48550/arXiv.2507.13563](https://doi.org/10.48550/arXiv.2507.13563)
 ---
-## Citation
+## Models & tooling
-If you use this pipeline in your research or production, please cite:
-```
-```
+| Piece | Role |
+|--------|------|
+| **Sortformer** (ONNX) | streaming diarization, single-speaker slices |
+| **[Smart Turn](https://github.com/pipecat-ai/smart-turn)** (`smart-turn-v3.0.onnx`) | end-of-speech / turn boundaries |
+| **Music detector** (`music_detection.safetensors`) | drop music-heavy chunks |
+| **DistillMOS** | predicted MOS in `balalaika.csv` |
+| **[onnx-asr](https://github.com/istupakov/onnx-asr)** | GigaAM v3 CTC/RNNT, Vosk, T-one, Parakeet, Canary, Whisper, … |
+| **[RUPunct](https://huggingface.co/RUPunct/RUPunct_big)** | punctuation |
+| **[ruAccent](https://github.com/Den4ikAI/ruaccent)** | stress marks |
+| **TryIParu** (`tryiparu`) | grapheme → IPA |
 ---
-## References and Acknowledgements
-
-Thanks to all the developers and contributors who made this project possible.
-
-
-
-
-
+## License
+See [LICENSE](LICENSE).
\ No newline at end of file
diff --git a/base.sh b/base.sh
index 8aed625..b8aa0b2 100644
--- a/base.sh
+++ b/base.sh
@@ -9,6 +9,18 @@ activate_venv() {
     fi
     source "$venv_path/bin/activate"
     echo "Activated: $(which python)"
+
+    local python_version=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
+    local nvidia_base="$venv_path/lib/python$python_version/site-packages/nvidia"
+
+    if [ -d "$nvidia_base" ]; then
+        export LD_LIBRARY_PATH="${nvidia_base}/cublas/lib:${nvidia_base}/cudnn/lib:${nvidia_base}/cuda_runtime/lib:${nvidia_base}/cuda_nvrtc/lib:${nvidia_base}/cufft/lib:${nvidia_base}/nvjitlink/lib:${nvidia_base}/cusolver/lib:${nvidia_base}/cusparse/lib:${LD_LIBRARY_PATH:-}"
+    fi
+
+    local trt_libs="$venv_path/lib/python$python_version/site-packages/tensorrt_libs"
+    if [ -d "$trt_libs" ]; then
+        export LD_LIBRARY_PATH="${trt_libs}:${LD_LIBRARY_PATH:-}"
+    fi
 }
 if [ -z "${1:-}" ]; then
@@ -17,16 +29,15 @@ if [ -z "${1:-}" ]; then
 fi
 CONFIG_PATH=$(realpath "$1")
-
+echo $CONFIG_PATH "--src"
 SCRIPTS=(
-    "./src/download/download_yaml.sh"
+    # "./src/download/download_yaml.sh"
     "./src/preprocess/preprocess_yaml.sh"
     "./src/separation/separation_yaml.sh"
     "./src/transcription/transcription_yaml.sh"
     "./src/punctuation/punctuation_yaml.sh"
     "./src/accents/accents_yaml.sh"
     "./src/phonemizer/phonemizer_yaml.sh"
-    "./src/classification/classification_yaml.sh"
     "./src/collate_yamls.sh"
 )
@@ -46,4 +57,4 @@ for script in "${SCRIPTS[@]}"; do
     }
 done
-echo -e "\n\033[1;32mAll scripts executed successfully!\033[0m"
\ No newline at end of file
+echo -e "\n\033[1;32mAll scripts executed successfully!\033[0m"
diff --git a/benchmarking/.gitignore b/benchmarking/.gitignore
new file mode 100644
index 0000000..21c1a3a
--- /dev/null
+++ b/benchmarking/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+reports/
diff --git a/benchmarking/README.md b/benchmarking/README.md
new file mode 100644
index 0000000..434ac74
--- /dev/null
+++ b/benchmarking/README.md
@@ -0,0 +1,12
@@
+# Benchmarking
+
+```bash
+DISABLE_DIARIZATION=1 \
+TARGET=pipeline.base \
+DATASET=/mnt/drive2_8tb/ruslan \
+NUM_SAMPLES=512 \
+REPEATS=1 \
+GPU_IDS=0 \
+benchmarking/run_benchmark.sh
+```
+
diff --git a/benchmarking/__init__.py b/benchmarking/__init__.py
new file mode 100644
index 0000000..68b72ff
--- /dev/null
+++ b/benchmarking/__init__.py
@@ -0,0 +1 @@
+"""Benchmarking harness package for Balalaika."""
diff --git a/benchmarking/bench.py b/benchmarking/bench.py
new file mode 100755
index 0000000..18a6c10
--- /dev/null
+++ b/benchmarking/bench.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+from benchmarking.cli import main
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/benchmarking/cli.py b/benchmarking/cli.py
new file mode 100644
index 0000000..dc6f6a4
--- /dev/null
+++ b/benchmarking/cli.py
@@ -0,0 +1,203 @@
+from __future__ import annotations
+
+import argparse
+import copy
+import json
+from dataclasses import asdict
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List
+
+from .common import REPO_ROOT, load_full_config, parse_gpu_ids, run_git, utc_now
+from .resources import filter_gpu_inventory, query_gpu_inventory
+from .runner import (
+    aggregate_repeats,
+    host_metadata,
+    make_env,
+    repeat_label,
+    resolve_source_dataset,
+    run_single_repeat,
+    selected_gpu_ids,
+)
+from .sampling import collect_source_samples
+from .targets import TARGETS, list_targets
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Benchmark harness for Balalaika stages and models.")
+    parser.add_argument("--target", type=str, help="Benchmark target name")
+    parser.add_argument("--config-path", type=Path, default=REPO_ROOT / "configs" / "config.yaml")
+
parser.add_argument("--dataset", type=Path, help="Source dataset root. Defaults to the dataset from config.")
+    parser.add_argument(
+        "--num-examples",
+        type=int,
+        default=None,
+        help="How many examples to benchmark. Omit or set <= 0 to use all eligible files.",
+    )
+    parser.add_argument("--sample-mode", choices=("first", "random"), default="first")
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--repeats", type=int, default=3)
+    parser.add_argument("--warmup-repeats", type=int, default=0)
+    parser.add_argument("--gpu-ids", type=parse_gpu_ids)
+    parser.add_argument("--num-gpus", type=int)
+    parser.add_argument("--cpu-workers-per-gpu", type=int)
+    parser.add_argument("--cpu-workers-total", type=int)
+    parser.add_argument("--batch-size-override", type=int)
+    parser.add_argument("--model-name-override", type=str)
+    parser.add_argument("--disable-diarization", action="store_true")
+    parser.add_argument("--sample-interval-sec", type=float, default=0.5)
+    parser.add_argument("--output-root", type=Path, default=REPO_ROOT / "benchmarking" / "reports")
+    parser.add_argument("--keep-workdirs", action="store_true")
+    parser.add_argument("--dry-run", action="store_true")
+    parser.add_argument("--list-targets", action="store_true")
+    return parser.parse_args()
+
+
+def main() -> int:
+    args = parse_args()
+
+    if args.list_targets:
+        list_targets()
+        return 0
+
+    if not args.target:
+        raise SystemExit("--target is required unless --list-targets is used")
+
+    if args.target not in TARGETS:
+        raise SystemExit(f"Unknown target: {args.target}.
Use --list-targets.")
+
+    target = TARGETS[args.target]
+    base_config = load_full_config(args.config_path.resolve())
+    args.dataset = resolve_source_dataset(args, copy.deepcopy(base_config), target)
+
+    if not args.dataset.exists():
+        raise SystemExit(f"Dataset does not exist: {args.dataset}")
+
+    gpu_inventory = query_gpu_inventory()
+    gpu_ids = selected_gpu_ids(args, gpu_inventory)
+    scoped_gpu_inventory = filter_gpu_inventory(gpu_inventory, gpu_ids)
+    env = make_env(args, gpu_ids if target.uses_gpu else [])
+
+    selected_samples = collect_source_samples(
+        source_dataset=args.dataset,
+        target=target,
+        config=base_config,
+        sample_mode=args.sample_mode,
+        num_examples=args.num_examples,
+        seed=args.seed,
+    )
+
+    if not selected_samples:
+        raise SystemExit(
+            "No eligible samples were found. Check dataset contents and target prerequisites."
+        )
+
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    run_root = (args.output_root / f"{timestamp}__{target.name.replace('.', '_')}").resolve()
+    run_root.mkdir(parents=True, exist_ok=True)
+
+    selected_samples_path = run_root / "selected_samples.json"
+    with selected_samples_path.open("w", encoding="utf-8") as handle:
+        json.dump([asdict(sample) for sample in selected_samples], handle, indent=2, ensure_ascii=False)
+
+    summary_stub = {
+        "report_version": 1,
+        "created_at_utc": utc_now(),
+        "repo_root": str(REPO_ROOT),
+        "git": {
+            "branch": run_git(("git", "branch", "--show-current")),
+            "commit": run_git(("git", "rev-parse", "HEAD")),
+            "status_short": run_git(("git", "status", "--short")),
+        },
+        "host": host_metadata(),
+        "gpu_inventory": scoped_gpu_inventory,
+        "target": {
+            "name": target.name,
+            "description": target.description,
+            "modules": list(target.modules),
+            "required_sidecars": list(target.required_sidecars),
+            "copied_sidecars": list(target.copied_sidecars),
+            "uses_gpu": target.uses_gpu,
+        },
+        "params": {
+            "config_path": str(args.config_path.resolve()),
+            "source_dataset": str(args.dataset),
+
"num_examples_requested": args.num_examples,
+            "num_examples_selected": len(selected_samples),
+            "sample_mode": args.sample_mode,
+            "seed": args.seed,
+            "repeats": args.repeats,
+            "warmup_repeats": args.warmup_repeats,
+            "gpu_ids": gpu_ids,
+            "cpu_workers_per_gpu": args.cpu_workers_per_gpu,
+            "cpu_workers_total": args.cpu_workers_total,
+            "batch_size_override": args.batch_size_override,
+            "model_name_override": args.model_name_override,
+            "disable_diarization": args.disable_diarization,
+            "sample_interval_sec": args.sample_interval_sec,
+            "keep_workdirs": args.keep_workdirs,
+            "cuda_visible_devices": env.get("CUDA_VISIBLE_DEVICES"),
+        },
+        "selected_samples_path": str(selected_samples_path),
+    }
+
+    if args.dry_run:
+        print(json.dumps(summary_stub, indent=2, ensure_ascii=False))
+        return 0
+
+    warmups: List[Dict[str, Any]] = []
+    for warmup_index in range(1, args.warmup_repeats + 1):
+        label = repeat_label("warmup", warmup_index)
+        warmups.append(
+            run_single_repeat(
+                label=label,
+                target=target,
+                sample_records=selected_samples,
+                base_config=base_config,
+                run_root=run_root,
+                args=args,
+                gpu_ids=gpu_ids,
+                env=env,
+            )
+        )
+
+    repeats: List[Dict[str, Any]] = []
+    for repeat_index in range(1, args.repeats + 1):
+        label = repeat_label("repeat", repeat_index)
+        repeats.append(
+            run_single_repeat(
+                label=label,
+                target=target,
+                sample_records=selected_samples,
+                base_config=base_config,
+                run_root=run_root,
+                args=args,
+                gpu_ids=gpu_ids,
+                env=env,
+            )
+        )
+
+    report = dict(summary_stub)
+    report["warmups"] = warmups
+    report["repeats"] = repeats
+    report["aggregate"] = aggregate_repeats(repeats)
+    report["successful"] = report["aggregate"]["failed_repeats"] == 0
+
+    report_path = run_root / "report.json"
+    with report_path.open("w", encoding="utf-8") as handle:
+        json.dump(report, handle, indent=2, ensure_ascii=False)
+        handle.write("\n")
+
+    final_summary = {
+        "report_path": str(report_path),
+        "target": target.name,
+        "successful":
report["successful"],
+        "successful_repeats": report["aggregate"]["successful_repeats"],
+        "avg_rtf": report["aggregate"]["rtf"]["avg"],
+        "avg_gpu_util_percent": report["aggregate"]["gpu_util_percent"]["avg"],
+        "avg_gpu_vram_mb": report["aggregate"]["gpu_vram_mb"]["avg"],
+        "avg_rss_gb": report["aggregate"]["rss_gb"]["avg"],
+        "avg_cpu_util_percent": report["aggregate"]["cpu_util_percent"]["avg"],
+    }
+    print(json.dumps(final_summary, ensure_ascii=False))
+    return 0 if report["successful"] else 1
diff --git a/benchmarking/common.py b/benchmarking/common.py
new file mode 100644
index 0000000..c940b72
--- /dev/null
+++ b/benchmarking/common.py
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import os
+import shutil
+import subprocess
+import sys
+import warnings
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Sequence
+
+import yaml
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+TRANSCRIPTION_MODELS: tuple[str, ...] = (
+    "giga_ctc",
+    "giga_rnnt",
+    "giga_ctc_lm",
+    "tone",
+    "vosk",
+    "vosk_small",
+    "parakeet_v2",
+    "parakeet_v3",
+    "canary",
+    "whisper_base",
+    "whisper_turbo",
+)
+
+COLLATE_SIDECARS: tuple[str, ...]
= (
+    "_rover.txt",
+    "_punct.txt",
+    "_accent.txt",
+    "_rover_phonemes.txt",
+)
+
+CLK_TCK = os.sysconf(os.sysconf_names["SC_CLK_TCK"])
+PAGE_SIZE = os.sysconf("SC_PAGE_SIZE")
+
+
+def eprint(message: str) -> None:
+    print(message, file=sys.stderr, flush=True)
+
+
+def utc_now() -> str:
+    return datetime.now(timezone.utc).isoformat()
+
+
+def preserve_repeat_artifacts(dataset_root: Path, repeat_root: Path) -> List[str]:
+    artifacts_dir = repeat_root / "artifacts"
+    preserved_paths: List[str] = []
+
+    for filename in ("balalaika.csv", "balalaika.parquet"):
+        source_path = dataset_root / filename
+        if not source_path.exists():
+            continue
+        artifacts_dir.mkdir(parents=True, exist_ok=True)
+        destination_path = artifacts_dir / filename
+        shutil.copy2(source_path, destination_path)
+        preserved_paths.append(str(destination_path))
+
+    return preserved_paths
+
+
+def ensure_dict(mapping: Dict[str, Any], key: str) -> Dict[str, Any]:
+    value = mapping.get(key)
+    if not isinstance(value, dict):
+        value = {}
+        mapping[key] = value
+    return value
+
+
+def load_full_config(config_path: Path) -> Dict[str, Any]:
+    with config_path.open("r", encoding="utf-8") as handle:
+        data = yaml.safe_load(handle) or {}
+    if not isinstance(data, dict):
+        raise ValueError(f"Config {config_path} must be a YAML mapping")
+    return data
+
+
+def write_yaml(path: Path, data: Dict[str, Any]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as handle:
+        yaml.safe_dump(data, handle, sort_keys=False, allow_unicode=False)
+
+
+def run_git(command: Sequence[str]) -> str:
+    try:
+        result = subprocess.run(
+            list(command),
+            cwd=REPO_ROOT,
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+    except Exception:
+        return ""
+    return result.stdout.strip()
+
+
+def get_audio_duration(path: Path) -> float:
+    try:
+        import torchaudio
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            info = torchaudio.info(str(path))
+        if info.sample_rate > 0
and info.num_frames > 0:
+            return float(info.num_frames) / float(info.sample_rate)
+    except Exception:
+        pass
+
+    try:
+        import soundfile as sf
+
+        info = sf.info(str(path))
+        if info.samplerate > 0 and info.frames > 0:
+            return float(info.frames) / float(info.samplerate)
+    except Exception:
+        pass
+
+    raise RuntimeError(f"Unable to determine duration for {path}")
+
+
+def sidecar_path(audio_path: Path, suffix: str) -> Path:
+    return audio_path.with_name(f"{audio_path.stem}{suffix}")
+
+
+def copy_file(src: Path, dst: Path) -> None:
+    dst.parent.mkdir(parents=True, exist_ok=True)
+    shutil.copy2(src, dst)
+
+
+def summarize_numeric(values: Iterable[Optional[float]]) -> Dict[str, Optional[float]]:
+    filtered = [float(value) for value in values if value is not None]
+    if not filtered:
+        return {"avg": None, "max": None, "min": None}
+    return {
+        "avg": sum(filtered) / len(filtered),
+        "max": max(filtered),
+        "min": min(filtered),
+    }
+
+
+def parse_gpu_ids(raw: Optional[str]) -> Optional[List[int]]:
+    if not raw:
+        return None
+    values: List[int] = []
+    for part in raw.split(","):
+        part = part.strip()
+        if not part:
+            continue
+        values.append(int(part))
+    return values or None
diff --git a/benchmarking/models.py b/benchmarking/models.py
new file mode 100644
index 0000000..d62059c
--- /dev/null
+++ b/benchmarking/models.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import argparse
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, TypeAlias
+
+
+@dataclass(frozen=True)
+class SampleRecord:
+    audio_path: str
+    relative_path: str
+    duration_sec: float
+    copied_sidecars: tuple[str, ...]
+
+
+@dataclass(frozen=True)
+class CommandSpec:
+    name: str
+    argv: tuple[str, ...]
+
+
+ConfigMutator: TypeAlias = Callable[[Dict[str, Any], argparse.Namespace, Path], None]
+
+
+@dataclass(frozen=True)
+class TargetSpec:
+    name: str
+    description: str
+    modules: tuple[str, ...]
+    required_sidecars: tuple[str, ...] = ()
+    copied_sidecars: tuple[str, ...] = ()
+    mutator: Optional[ConfigMutator] = None
+    min_input_duration_from_config: Optional[tuple[str, ...]] = None
+    uses_gpu: bool = True
diff --git a/benchmarking/resources.py b/benchmarking/resources.py
new file mode 100644
index 0000000..3bc28f1
--- /dev/null
+++ b/benchmarking/resources.py
@@ -0,0 +1,366 @@
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+import threading
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence
+
+from .common import CLK_TCK, PAGE_SIZE, REPO_ROOT, summarize_numeric, utc_now
+
+
+def build_runtime_library_paths() -> List[str]:
+    python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+    site_packages = Path(sys.prefix) / "lib" / f"python{python_version}" / "site-packages"
+    nvidia_root = site_packages / "nvidia"
+
+    library_dirs: List[Path] = []
+    for candidate in [
+        nvidia_root / "cu13" / "lib",
+        nvidia_root / "cublas" / "lib",
+        nvidia_root / "cudnn" / "lib",
+        nvidia_root / "cuda_runtime" / "lib",
+        nvidia_root / "cuda_nvrtc" / "lib",
+        nvidia_root / "cufft" / "lib",
+        nvidia_root / "nvjitlink" / "lib",
+        nvidia_root / "cusolver" / "lib",
+        nvidia_root / "cusparse" / "lib",
+        nvidia_root / "cusparselt" / "lib",
+        nvidia_root / "nccl" / "lib",
+        site_packages / "tensorrt_libs",
+    ]:
+        if candidate.exists():
+            library_dirs.append(candidate)
+
+    seen: set[str] = set()
+    resolved: List[str] = []
+    for path in library_dirs:
+        path_str = str(path)
+        if path_str in seen:
+            continue
+        seen.add(path_str)
+        resolved.append(path_str)
+    return resolved
+
+
+def nvidia_smi_env() -> Dict[str, str]:
+    env = os.environ.copy()
+    env.pop("LD_LIBRARY_PATH", None)
+    return env
+
+
+def proc_root() -> Path:
+    return Path("/proc")
+
+
+def system_cpu_snapshot() -> Optional[tuple[int, int]]:
+    try:
+        first_line = (proc_root() / "stat").read_text(encoding="utf-8").splitlines()[0]
+    except
Exception:
+        return None
+
+    parts = first_line.split()
+    if not parts or parts[0] != "cpu":
+        return None
+
+    values = [int(value) for value in parts[1:]]
+    total = sum(values)
+    idle = values[3] + values[4] if len(values) > 4 else values[3]
+    return total, idle
+
+
+def system_memory_snapshot() -> Dict[str, Optional[float]]:
+    meminfo_path = proc_root() / "meminfo"
+    values: Dict[str, int] = {}
+    try:
+        for line in meminfo_path.read_text(encoding="utf-8").splitlines():
+            key, _, raw_value = line.partition(":")
+            parts = raw_value.strip().split()
+            if not parts:
+                continue
+            values[key] = int(parts[0]) * 1024
+    except Exception:
+        return {"used_gb": None, "percent": None}
+
+    total = values.get("MemTotal")
+    available = values.get("MemAvailable")
+    if not total or available is None:
+        return {"used_gb": None, "percent": None}
+
+    used = total - available
+    return {
+        "used_gb": used / (1024 ** 3),
+        "percent": (used / total) * 100.0,
+    }
+
+
+def read_process_stat(pid: int) -> Optional[Dict[str, int]]:
+    stat_path = proc_root() / str(pid) / "stat"
+    statm_path = proc_root() / str(pid) / "statm"
+
+    try:
+        stat_text = stat_path.read_text(encoding="utf-8")
+        statm_text = statm_path.read_text(encoding="utf-8")
+    except Exception:
+        return None
+
+    right_paren = stat_text.rfind(")")
+    if right_paren == -1:
+        return None
+    fields = stat_text[right_paren + 2 :].split()
+    if len(fields) < 13:
+        return None
+
+    try:
+        ppid = int(fields[1])
+        utime = int(fields[11])
+        stime = int(fields[12])
+        resident_pages = int(statm_text.split()[1])
+    except (IndexError, ValueError):
+        return None
+
+    return {
+        "pid": pid,
+        "ppid": ppid,
+        "cpu_ticks": utime + stime,
+        "rss_bytes": resident_pages * PAGE_SIZE,
+    }
+
+
+def process_tree_snapshot(root_pid: int) -> Dict[int, Dict[str, int]]:
+    stats: Dict[int, Dict[str, int]] = {}
+    children: Dict[int, List[int]] = {}
+
+    for entry in proc_root().iterdir():
+        if not entry.name.isdigit():
+            continue
+        pid = int(entry.name)
+        stat =
read_process_stat(pid) + if not stat: + continue + stats[pid] = stat + children.setdefault(stat["ppid"], []).append(pid) + + if root_pid not in stats: + return {} + + result: Dict[int, Dict[str, int]] = {} + pending = [root_pid] + while pending: + pid = pending.pop() + if pid in result: + continue + stat = stats.get(pid) + if not stat: + continue + result[pid] = stat + pending.extend(children.get(pid, [])) + + return result + + +def query_gpu_inventory() -> Dict[str, Any]: + command = [ + "nvidia-smi", + "--query-gpu=index,name,memory.total", + "--format=csv,noheader,nounits", + ] + try: + result = subprocess.run( + command, + cwd=REPO_ROOT, + capture_output=True, + text=True, + timeout=5, + env=nvidia_smi_env(), + check=True, + ) + except Exception as exc: + return {"gpus": [], "error": str(exc)} + + gpus: List[Dict[str, Any]] = [] + for line in result.stdout.splitlines(): + parts = [item.strip() for item in line.split(",")] + if len(parts) != 3: + continue + try: + gpus.append( + { + "index": int(parts[0]), + "name": parts[1], + "memory_total_mb": float(parts[2]), + } + ) + except ValueError: + continue + return {"gpus": gpus, "error": None} + + +def filter_gpu_inventory(inventory: Dict[str, Any], gpu_ids: Optional[List[int]]) -> Dict[str, Any]: + if gpu_ids is None: + return inventory + if gpu_ids == []: + return {"gpus": [], "error": inventory.get("error")} + gpus = [gpu for gpu in inventory.get("gpus", []) if gpu.get("index") in gpu_ids] + return {"gpus": gpus, "error": inventory.get("error")} + + +class ResourceSampler: + def __init__(self, pid: int, gpu_ids: Optional[List[int]], interval_sec: float) -> None: + self.pid = pid + self.gpu_ids = gpu_ids + self.interval_sec = interval_sec + self.samples: List[Dict[str, Any]] = [] + self.gpu_query_error: Optional[str] = None + self._stop_event = threading.Event() + self._thread: Optional[threading.Thread] = None + self._cpu_count = max(os.cpu_count() or 1, 1) + self._prev_process_cpu_ticks: Dict[int, int] = {} + 
self._prev_process_sample_time: Optional[float] = None + self._prev_system_cpu_snapshot: Optional[tuple[int, int]] = None + + def start(self) -> None: + self._thread = threading.Thread(target=self._run, name=f"resource-sampler-{self.pid}", daemon=True) + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + if self._thread: + self._thread.join(timeout=max(self.interval_sec * 2, 1.0)) + + def _query_gpu(self) -> List[Dict[str, Any]]: + if self.gpu_ids == []: + return [] + command = [ + "nvidia-smi", + "--query-gpu=index,utilization.gpu,memory.used,memory.total", + "--format=csv,noheader,nounits", + ] + try: + result = subprocess.run( + command, + cwd=REPO_ROOT, + capture_output=True, + text=True, + timeout=5, + env=nvidia_smi_env(), + check=True, + ) + except Exception as exc: + if self.gpu_query_error is None: + self.gpu_query_error = str(exc) + return [] + + metrics: List[Dict[str, Any]] = [] + for line in result.stdout.splitlines(): + parts = [item.strip() for item in line.split(",")] + if len(parts) != 4: + continue + try: + index = int(parts[0]) + if self.gpu_ids is not None and index not in self.gpu_ids: + continue + metrics.append( + { + "index": index, + "gpu_util_percent": float(parts[1]), + "memory_used_mb": float(parts[2]), + "memory_total_mb": float(parts[3]), + } + ) + except ValueError: + continue + return metrics + + def _run(self) -> None: + while not self._stop_event.is_set(): + now = time.monotonic() + process_stats = process_tree_snapshot(self.pid) + cpu_ticks_total = sum(stat["cpu_ticks"] for stat in process_stats.values()) + rss_bytes = sum(stat["rss_bytes"] for stat in process_stats.values()) + + cpu_percent_raw = 0.0 + if self._prev_process_sample_time is not None: + elapsed = now - self._prev_process_sample_time + if elapsed > 0: + previous_total = sum( + self._prev_process_cpu_ticks.get(pid, stat["cpu_ticks"]) + for pid, stat in process_stats.items() + ) + delta_ticks = max(cpu_ticks_total - previous_total, 0) + 
cpu_percent_raw = (delta_ticks / (elapsed * CLK_TCK)) * 100.0 + + self._prev_process_cpu_ticks = {pid: stat["cpu_ticks"] for pid, stat in process_stats.items()} + self._prev_process_sample_time = now + + system_cpu_percent = 0.0 + current_system_cpu_snapshot = system_cpu_snapshot() + if self._prev_system_cpu_snapshot and current_system_cpu_snapshot: + prev_total, prev_idle = self._prev_system_cpu_snapshot + total, idle = current_system_cpu_snapshot + total_delta = total - prev_total + idle_delta = idle - prev_idle + if total_delta > 0: + system_cpu_percent = (1.0 - (idle_delta / total_delta)) * 100.0 + self._prev_system_cpu_snapshot = current_system_cpu_snapshot + + memory_snapshot = system_memory_snapshot() + + sample = { + "timestamp_utc": utc_now(), + "process_count": len(process_stats), + "cpu_percent_raw": cpu_percent_raw, + "cpu_util_percent": cpu_percent_raw / self._cpu_count, + "cpu_cores_used": cpu_percent_raw / 100.0, + "rss_bytes": rss_bytes, + "rss_gb": rss_bytes / (1024 ** 3), + "system_cpu_percent": system_cpu_percent, + "system_ram_used_gb": memory_snapshot["used_gb"], + "system_ram_percent": memory_snapshot["percent"], + "gpus": self._query_gpu(), + } + self.samples.append(sample) + self._stop_event.wait(self.interval_sec) + + +def summarize_gpu_samples(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + per_gpu: Dict[str, Dict[str, List[float]]] = {} + all_utils: List[float] = [] + all_vram_mb: List[float] = [] + + for sample in samples: + for gpu in sample.get("gpus", []): + gpu_key = str(gpu["index"]) + bucket = per_gpu.setdefault(gpu_key, {"gpu_util_percent": [], "memory_used_mb": []}) + bucket["gpu_util_percent"].append(gpu["gpu_util_percent"]) + bucket["memory_used_mb"].append(gpu["memory_used_mb"]) + all_utils.append(gpu["gpu_util_percent"]) + all_vram_mb.append(gpu["memory_used_mb"]) + + return { + "gpu_util_percent": summarize_numeric(all_utils), + "gpu_vram_mb": summarize_numeric(all_vram_mb), + "per_gpu": { + gpu_key: { + 
"gpu_util_percent": summarize_numeric(values["gpu_util_percent"]), + "memory_used_mb": summarize_numeric(values["memory_used_mb"]), + } + for gpu_key, values in per_gpu.items() + }, + } + + +def summarize_resource_samples(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + summary = { + "sample_count": len(samples), + "cpu_util_percent": summarize_numeric(sample["cpu_util_percent"] for sample in samples), + "cpu_cores_used": summarize_numeric(sample["cpu_cores_used"] for sample in samples), + "rss_gb": summarize_numeric(sample["rss_gb"] for sample in samples), + "system_cpu_percent": summarize_numeric(sample["system_cpu_percent"] for sample in samples), + "system_ram_used_gb": summarize_numeric(sample["system_ram_used_gb"] for sample in samples), + "system_ram_percent": summarize_numeric(sample["system_ram_percent"] for sample in samples), + } + summary.update(summarize_gpu_samples(samples)) + return summary diff --git a/benchmarking/run_benchmark.sh b/benchmarking/run_benchmark.sh new file mode 100755 index 0000000..c2a1a64 --- /dev/null +++ b/benchmarking/run_benchmark.sh @@ -0,0 +1,89 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +REPO_ROOT=$(cd "${SCRIPT_DIR}/.." && pwd) +cd "${REPO_ROOT}" + +if [[ ! -f ".dev_venv/bin/activate" ]]; then + echo "Missing virtual environment: ${REPO_ROOT}/.dev_venv" >&2 + exit 1 +fi + +source ".dev_venv/bin/activate" + +TARGET="${TARGET:-${1:-}}" +if [[ -z "${TARGET}" ]]; then + echo "Usage: TARGET= DATASET=/path/to/data ${SCRIPT_DIR}/run_benchmark.sh" >&2 + echo "Or: ${SCRIPT_DIR}/run_benchmark.sh " >&2 + echo "Use TARGET=list to print available benchmark targets." 
>&2 + exit 1 +fi + +CONFIG_PATH="${CONFIG_PATH:-${REPO_ROOT}/configs/config.yaml}" +SAMPLE_MODE="${SAMPLE_MODE:-first}" +SEED="${SEED:-42}" +REPEATS="${REPEATS:-3}" +WARMUP_REPEATS="${WARMUP_REPEATS:-0}" +SAMPLE_INTERVAL_SEC="${SAMPLE_INTERVAL_SEC:-0.5}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${REPO_ROOT}/benchmarking/reports}" + +ARGS=( + "--config-path" "${CONFIG_PATH}" + "--sample-mode" "${SAMPLE_MODE}" + "--seed" "${SEED}" + "--repeats" "${REPEATS}" + "--warmup-repeats" "${WARMUP_REPEATS}" + "--sample-interval-sec" "${SAMPLE_INTERVAL_SEC}" + "--output-root" "${OUTPUT_ROOT}" +) + +if [[ "${TARGET}" == "list" ]]; then + exec python3 "${SCRIPT_DIR}/bench.py" --list-targets +fi + +ARGS+=("--target" "${TARGET}") + +if [[ -n "${DATASET:-}" ]]; then + ARGS+=("--dataset" "${DATASET}") +fi + +if [[ -n "${NUM_SAMPLES:-}" ]]; then + ARGS+=("--num-examples" "${NUM_SAMPLES}") +fi + +if [[ -n "${GPU_IDS:-}" ]]; then + ARGS+=("--gpu-ids" "${GPU_IDS}") +elif [[ -n "${NUM_GPUS:-}" ]]; then + ARGS+=("--num-gpus" "${NUM_GPUS}") +fi + +if [[ -n "${CPU_WORKERS_PER_GPU:-}" ]]; then + ARGS+=("--cpu-workers-per-gpu" "${CPU_WORKERS_PER_GPU}") +fi + +if [[ -n "${CPU_WORKERS_TOTAL:-}" ]]; then + ARGS+=("--cpu-workers-total" "${CPU_WORKERS_TOTAL}") +fi + +if [[ -n "${BATCH_SIZE_OVERRIDE:-}" ]]; then + ARGS+=("--batch-size-override" "${BATCH_SIZE_OVERRIDE}") +fi + +if [[ -n "${MODEL_NAME_OVERRIDE:-}" ]]; then + ARGS+=("--model-name-override" "${MODEL_NAME_OVERRIDE}") +fi + +if [[ "${DISABLE_DIARIZATION:-0}" == "1" ]]; then + ARGS+=("--disable-diarization") +fi + +if [[ "${KEEP_WORKDIRS:-0}" == "1" ]]; then + ARGS+=("--keep-workdirs") +fi + +if [[ "${DRY_RUN:-0}" == "1" ]]; then + ARGS+=("--dry-run") +fi + +exec python3 "${SCRIPT_DIR}/bench.py" "${ARGS[@]}" diff --git a/benchmarking/runner.py b/benchmarking/runner.py new file mode 100644 index 0000000..b891b4f --- /dev/null +++ b/benchmarking/runner.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import argparse +import copy +import json 
+import os +import platform +import shutil +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + +from .common import REPO_ROOT, eprint, ensure_dict, preserve_repeat_artifacts, summarize_numeric, utc_now, write_yaml +from .models import CommandSpec, SampleRecord, TargetSpec +from .resources import ResourceSampler, build_runtime_library_paths, summarize_resource_samples +from .sampling import copy_benchmark_dataset +from .targets import build_commands + + +def run_command( + command: CommandSpec, + config_path: Path, + log_path: Path, + gpu_ids: Optional[List[int]], + sample_interval_sec: float, + env: Dict[str, str], +) -> Dict[str, Any]: + argv = list(command.argv) + [str(config_path)] + log_path.parent.mkdir(parents=True, exist_ok=True) + started_at = utc_now() + started = time.monotonic() + + with log_path.open("w", encoding="utf-8") as handle: + process = subprocess.Popen( + argv, + cwd=REPO_ROOT, + stdout=handle, + stderr=subprocess.STDOUT, + env=env, + ) + sampler = ResourceSampler(pid=process.pid, gpu_ids=gpu_ids, interval_sec=sample_interval_sec) + sampler.start() + return_code = process.wait() + sampler.stop() + + wall_time_sec = time.monotonic() - started + finished_at = utc_now() + summary = summarize_resource_samples(sampler.samples) + return { + "name": command.name, + "argv": argv, + "log_path": str(log_path), + "started_at_utc": started_at, + "finished_at_utc": finished_at, + "wall_time_sec": wall_time_sec, + "return_code": return_code, + "resource_summary": summary, + "resource_samples": sampler.samples, + "gpu_query_error": sampler.gpu_query_error, + } + + +def aggregate_repeats(repeats: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + successful = [repeat for repeat in repeats if repeat.get("success")] + return { + "total_repeats": len(repeats), + "successful_repeats": len(successful), + "failed_repeats": len(repeats) - len(successful), + "wall_time_sec": 
summarize_numeric(repeat.get("wall_time_sec") for repeat in successful), + "rtf": summarize_numeric(repeat.get("rtf") for repeat in successful), + "x_realtime": summarize_numeric(repeat.get("x_realtime") for repeat in successful), + "cpu_util_percent": summarize_numeric( + repeat.get("resource_summary", {}).get("cpu_util_percent", {}).get("avg") + for repeat in successful + ), + "rss_gb": summarize_numeric( + repeat.get("resource_summary", {}).get("rss_gb", {}).get("avg") for repeat in successful + ), + "gpu_util_percent": summarize_numeric( + repeat.get("resource_summary", {}).get("gpu_util_percent", {}).get("avg") + for repeat in successful + ), + "gpu_vram_mb": summarize_numeric( + repeat.get("resource_summary", {}).get("gpu_vram_mb", {}).get("avg") + for repeat in successful + ), + } + + +def selected_gpu_ids(args: argparse.Namespace, inventory: Dict[str, Any]) -> Optional[List[int]]: + if args.gpu_ids is not None: + return args.gpu_ids + if args.num_gpus is not None: + return list(range(args.num_gpus)) + gpu_inventory = inventory.get("gpus", []) + if not gpu_inventory: + return [] + return [int(gpu["index"]) for gpu in gpu_inventory] + + +def make_env(args: argparse.Namespace, gpu_ids: Optional[List[int]]) -> Dict[str, str]: + env = os.environ.copy() + pythonpath_parts = [str(REPO_ROOT)] + if env.get("PYTHONPATH"): + pythonpath_parts.append(env["PYTHONPATH"]) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts) + + ld_library_parts = build_runtime_library_paths() + if env.get("LD_LIBRARY_PATH"): + ld_library_parts.append(env["LD_LIBRARY_PATH"]) + if ld_library_parts: + env["LD_LIBRARY_PATH"] = os.pathsep.join(ld_library_parts) + + if gpu_ids is not None and gpu_ids != []: + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in gpu_ids) + return env + + +def repeat_label(prefix: str, index: int) -> str: + return f"{prefix}_{index:02d}" + + +def run_single_repeat( + label: str, + target: TargetSpec, + sample_records: Sequence[SampleRecord], + 
base_config: Dict[str, Any], + run_root: Path, + args: argparse.Namespace, + gpu_ids: Optional[List[int]], + env: Dict[str, str], +) -> Dict[str, Any]: + repeat_root = run_root / label + repeat_root.mkdir(parents=True, exist_ok=True) + + dataset_root = repeat_root / "dataset" + copy_benchmark_dataset(destination_dataset=dataset_root, sample_records=sample_records) + + effective_config = copy.deepcopy(base_config) + if target.mutator is not None: + target.mutator(effective_config, args, dataset_root) + + config_path = repeat_root / "config.yaml" + write_yaml(config_path, effective_config) + + commands = build_commands(target.modules) + command_results: List[Dict[str, Any]] = [] + all_samples: List[Dict[str, Any]] = [] + success = True + error_message = None + + for index, command in enumerate(commands, start=1): + log_path = repeat_root / f"command_{index:02d}_{command.name.replace('.', '_')}.log" + eprint(f"[{label}] running {command.name}") + result = run_command( + command=command, + config_path=config_path, + log_path=log_path, + gpu_ids=gpu_ids if target.uses_gpu else [], + sample_interval_sec=args.sample_interval_sec, + env=env, + ) + all_samples.extend(result.pop("resource_samples")) + command_results.append(result) + if result["return_code"] != 0: + success = False + error_message = f"{command.name} failed with exit code {result['return_code']}" + break + + samples_path = repeat_root / "resource_samples.jsonl" + with samples_path.open("w", encoding="utf-8") as handle: + for sample in all_samples: + handle.write(json.dumps(sample, ensure_ascii=False) + "\n") + + total_audio_sec = sum(record.duration_sec for record in sample_records) + wall_time_sec = sum(command_result["wall_time_sec"] for command_result in command_results) + rtf = (wall_time_sec / total_audio_sec) if total_audio_sec > 0 else None + x_realtime = (total_audio_sec / wall_time_sec) if wall_time_sec > 0 else None + + repeat_result = { + "label": label, + "repeat_root": str(repeat_root), + 
"dataset_root": str(dataset_root), + "config_path": str(config_path), + "samples_path": str(samples_path), + "success": success, + "error": error_message, + "command_results": command_results, + "wall_time_sec": wall_time_sec, + "total_audio_sec": total_audio_sec, + "rtf": rtf, + "x_realtime": x_realtime, + "resource_summary": summarize_resource_samples(all_samples), + "preserved_artifacts": preserve_repeat_artifacts(dataset_root, repeat_root), + } + + if not args.keep_workdirs: + shutil.rmtree(dataset_root, ignore_errors=True) + + return repeat_result + + +def host_metadata() -> Dict[str, Any]: + return { + "hostname": socket.gethostname(), + "platform": platform.platform(), + "python": sys.version, + "logical_cpu_count": os.cpu_count(), + "physical_cpu_count": None, + } + + +def resolve_source_dataset(args: argparse.Namespace, config: Dict[str, Any], target: TargetSpec) -> Path: + if args.dataset: + return args.dataset.resolve() + + fallback_sections = { + "collate.stage": "download", + "pipeline.base": "preprocess", + } + preferred_section = fallback_sections.get(target.name) + if preferred_section: + preferred = ensure_dict(config, preferred_section).get("podcasts_path") + if preferred: + return Path(preferred).resolve() + + for key in ("preprocess", "separation", "transcription", "punctuation", "accent", "phonemizer", "download"): + value = ensure_dict(config, key).get("podcasts_path") + if value: + return Path(value).resolve() + + raise ValueError("Dataset path was not provided and could not be inferred from config") diff --git a/benchmarking/sampling.py b/benchmarking/sampling.py new file mode 100644 index 0000000..305786c --- /dev/null +++ b/benchmarking/sampling.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import random +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + +from src.utils.utils import get_audio_paths + +from .common import copy_file, eprint, get_audio_duration, sidecar_path +from 
.models import SampleRecord, TargetSpec +from .targets import min_duration_for_target + + +def collect_source_samples( + source_dataset: Path, + target: TargetSpec, + config: Dict[str, Any], + sample_mode: str, + num_examples: Optional[int], + seed: int, +) -> List[SampleRecord]: + all_audio_paths = sorted(Path(path) for path in get_audio_paths(str(source_dataset))) + minimum_duration = min_duration_for_target(target, config) + invalid_audio_count = 0 + invalid_audio_examples: List[str] = [] + + def build_record(audio_path: Path) -> Optional[SampleRecord]: + nonlocal invalid_audio_count + for suffix in target.required_sidecars: + if not sidecar_path(audio_path, suffix).exists(): + return None + try: + if audio_path.stat().st_size == 0: + raise RuntimeError("empty file") + duration_sec = get_audio_duration(audio_path) + except Exception as exc: + invalid_audio_count += 1 + if len(invalid_audio_examples) < 5: + invalid_audio_examples.append(f"{audio_path}: {exc}") + return None + if minimum_duration is not None and duration_sec <= minimum_duration: + return None + try: + relative_path = str(audio_path.relative_to(source_dataset)) + except ValueError: + return None + copied_sidecars = tuple( + suffix for suffix in target.copied_sidecars if sidecar_path(audio_path, suffix).exists() + ) + return SampleRecord( + audio_path=str(audio_path), + relative_path=relative_path, + duration_sec=duration_sec, + copied_sidecars=copied_sidecars, + ) + + if sample_mode == "first": + records: List[SampleRecord] = [] + for audio_path in all_audio_paths: + record = build_record(audio_path) + if not record: + continue + records.append(record) + if num_examples is not None and num_examples > 0 and len(records) >= num_examples: + break + if invalid_audio_count: + eprint(f"Skipped {invalid_audio_count} unreadable audio files during sample selection.") + for example in invalid_audio_examples: + eprint(f" {example}") + return records + + candidates = [record for record in (build_record(path) 
for path in all_audio_paths) if record] + if invalid_audio_count: + eprint(f"Skipped {invalid_audio_count} unreadable audio files during sample selection.") + for example in invalid_audio_examples: + eprint(f" {example}") + random.Random(seed).shuffle(candidates) + if num_examples is not None and num_examples > 0: + return candidates[:num_examples] + return candidates + + +def copy_benchmark_dataset(destination_dataset: Path, sample_records: Sequence[SampleRecord]) -> None: + if destination_dataset.exists(): + shutil.rmtree(destination_dataset) + destination_dataset.mkdir(parents=True, exist_ok=True) + + for record in sample_records: + src_audio = Path(record.audio_path) + dst_audio = destination_dataset / record.relative_path + copy_file(src_audio, dst_audio) + for suffix in record.copied_sidecars: + src_sidecar = sidecar_path(src_audio, suffix) + dst_sidecar = sidecar_path(dst_audio, suffix) + if src_sidecar.exists(): + copy_file(src_sidecar, dst_sidecar) diff --git a/benchmarking/targets.py b/benchmarking/targets.py new file mode 100644 index 0000000..48de640 --- /dev/null +++ b/benchmarking/targets.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + +from .common import COLLATE_SIDECARS, TRANSCRIPTION_MODELS, ensure_dict +from .models import CommandSpec, ConfigMutator, TargetSpec + + +def build_commands(modules: Sequence[str]) -> List[CommandSpec]: + return [ + CommandSpec( + name=module, + argv=(sys.executable, "-m", module, "--config_path"), + ) + for module in modules + ] + + +def base_mutator(config: Dict[str, Any], _: argparse.Namespace, work_dataset_path: Path) -> None: + config["cache_path"] = str(work_dataset_path.parent / "cache") + + +def set_dataset_everywhere(config: Dict[str, Any], work_dataset_path: Path) -> None: + dataset_path = str(work_dataset_path) + for key in ("download", "preprocess", "separation", "transcription", 
"punctuation", "accent", "phonemizer"): + section = ensure_dict(config, key) + section["podcasts_path"] = dataset_path + + +def patch_preprocess(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "preprocess") + section["podcasts_path"] = str(work_dataset_path) + if args.cpu_workers_total is not None: + section["num_workers"] = args.cpu_workers_total + elif args.cpu_workers_per_gpu is not None: + section["num_workers"] = args.cpu_workers_per_gpu + + +def patch_separation_common(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> Dict[str, Any]: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "separation") + section["podcasts_path"] = str(work_dataset_path) + section["cache_path"] = str(work_dataset_path.parent / "cache") + return section + + +def patch_separation_music_detect(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + section = patch_separation_common(config, args, work_dataset_path) + music_detect = ensure_dict(section, "music_detect") + music_detect["cache_path"] = str(work_dataset_path.parent / "cache") + if args.cpu_workers_per_gpu is not None: + music_detect["num_workers"] = args.cpu_workers_per_gpu + if args.batch_size_override is not None: + music_detect["bs"] = args.batch_size_override + + +def patch_separation_nisqa(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + section = patch_separation_common(config, args, work_dataset_path) + nisqa = ensure_dict(section, "nisqa") + if args.batch_size_override is not None: + nisqa["bs"] = args.batch_size_override + if args.cpu_workers_per_gpu is not None: + nisqa["num_workers"] = args.cpu_workers_per_gpu + section["bs"] = nisqa.get("bs", 32) + section["num_workers_nisqa"] = nisqa.get("num_workers", 4) + section["nisqa_config_path"] = nisqa.get("nisqa_config_path", 
"./configs/nisqa_b.yaml") + + +def patch_separation_distillmos(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + patch_separation_common(config, args, work_dataset_path) + + +def patch_separation_diarization(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + section = patch_separation_common(config, args, work_dataset_path) + diarization = ensure_dict(section, "diarization") + if args.disable_diarization: + diarization["enabled"] = False + if args.cpu_workers_per_gpu is not None: + diarization["num_workers"] = args.cpu_workers_per_gpu + + +def patch_separation_silence(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + section = patch_separation_common(config, args, work_dataset_path) + silence_detect = ensure_dict(section, "silence_detect") + if args.cpu_workers_per_gpu is not None: + silence_detect["num_workers"] = args.cpu_workers_per_gpu + + +def patch_separation_stage(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + patch_separation_common(config, args, work_dataset_path) + patch_separation_music_detect(config, args, work_dataset_path) + patch_separation_nisqa(config, args, work_dataset_path) + patch_separation_diarization(config, args, work_dataset_path) + patch_separation_silence(config, args, work_dataset_path) + + +def transcription_batch_section(model_name: str) -> str: + if "giga" in model_name: + return "giga" + if "vosk" in model_name: + return "vosk" + return model_name + + +def patch_transcription_common(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> Dict[str, Any]: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "transcription") + section["podcasts_path"] = str(work_dataset_path) + return section + + +def patch_transcription_stage(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + section = 
patch_transcription_common(config, args, work_dataset_path) + if args.batch_size_override is not None: + for model_name in section.get("model_names", []): + model_section = ensure_dict(section, transcription_batch_section(str(model_name))) + model_section["batch_size"] = args.batch_size_override + + +def patch_transcription_model(model_name: str) -> ConfigMutator: + def mutator(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + section = patch_transcription_common(config, args, work_dataset_path) + section["model_names"] = [model_name] + section["consensus_num"] = 0 + section["use_rover"] = False + if args.batch_size_override is not None: + model_section = ensure_dict(section, transcription_batch_section(model_name)) + model_section["batch_size"] = args.batch_size_override + + return mutator + + +def patch_punctuation(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "punctuation") + section["podcasts_path"] = str(work_dataset_path) + if args.cpu_workers_per_gpu is not None: + section["num_workers"] = args.cpu_workers_per_gpu + if args.model_name_override: + section["model_name"] = args.model_name_override + + +def patch_accent(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "accent") + section["podcasts_path"] = str(work_dataset_path) + if args.cpu_workers_per_gpu is not None: + section["num_workers"] = args.cpu_workers_per_gpu + if args.model_name_override: + section["model_name"] = args.model_name_override + + +def patch_phonemizer(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "phonemizer") + section["podcasts_path"] = str(work_dataset_path) + if args.cpu_workers_per_gpu is not None: + 
section["num_workers"] = args.cpu_workers_per_gpu + + +def patch_collate(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + base_mutator(config, args, work_dataset_path) + section = ensure_dict(config, "download") + section["podcasts_path"] = str(work_dataset_path) + if args.cpu_workers_total is not None: + section["num_workers"] = args.cpu_workers_total + elif args.cpu_workers_per_gpu is not None: + section["num_workers"] = args.cpu_workers_per_gpu + + +def patch_pipeline(config: Dict[str, Any], args: argparse.Namespace, work_dataset_path: Path) -> None: + set_dataset_everywhere(config, work_dataset_path) + patch_preprocess(config, args, work_dataset_path) + patch_separation_stage(config, args, work_dataset_path) + patch_transcription_stage(config, args, work_dataset_path) + patch_punctuation(config, args, work_dataset_path) + patch_accent(config, args, work_dataset_path) + patch_phonemizer(config, args, work_dataset_path) + patch_collate(config, args, work_dataset_path) + + +TARGETS: Dict[str, TargetSpec] = { + "preprocess.stage": TargetSpec( + name="preprocess.stage", + description="Full preprocess sequence: VAD chunking, crest-factor filter, loudness normalization.", + modules=( + "src.preprocess.preprocess", + "src.preprocess.crest_factor_remover", + "src.preprocess.preprocess_audio", + ), + mutator=patch_preprocess, + uses_gpu=True, + ), + "preprocess.vad": TargetSpec( + name="preprocess.vad", + description="Smart-turn VAD chunking only.", + modules=("src.preprocess.preprocess",), + mutator=patch_preprocess, + uses_gpu=True, + ), + "preprocess.crest_factor": TargetSpec( + name="preprocess.crest_factor", + description="Crest-factor filtering only.", + modules=("src.preprocess.crest_factor_remover",), + mutator=patch_preprocess, + uses_gpu=False, + ), + "preprocess.audio_normalize": TargetSpec( + name="preprocess.audio_normalize", + description="Loudness normalization only.", + modules=("src.preprocess.preprocess_audio",), + 
mutator=patch_preprocess, + uses_gpu=False, + ), + "separation.stage": TargetSpec( + name="separation.stage", + description="Full separation sequence: music_detect, NISQA, DistillMOS, diarization, silence_detect.", + modules=( + "src.separation.music_detect", + "src.separation.nisqa_process", + "src.separation.distillmos_process", + "src.separation.diarization", + "src.separation.silence_detect", + ), + mutator=patch_separation_stage, + uses_gpu=True, + ), + "separation.music_detect": TargetSpec( + name="separation.music_detect", + description="Music detection model only.", + modules=("src.separation.music_detect",), + mutator=patch_separation_music_detect, + uses_gpu=True, + ), + "separation.nisqa": TargetSpec( + name="separation.nisqa", + description="NISQA MOS model only.", + modules=("src.separation.nisqa_process",), + mutator=patch_separation_nisqa, + uses_gpu=True, + ), + "separation.distillmos": TargetSpec( + name="separation.distillmos", + description="DistillMOS model only.", + modules=("src.separation.distillmos_process",), + mutator=patch_separation_distillmos, + uses_gpu=True, + ), + "separation.diarization": TargetSpec( + name="separation.diarization", + description="Pyannote diarization only.", + modules=("src.separation.diarization",), + mutator=patch_separation_diarization, + uses_gpu=True, + ), + "separation.silence_detect": TargetSpec( + name="separation.silence_detect", + description="Silero silence metrics only.", + modules=("src.separation.silence_detect",), + mutator=patch_separation_silence, + uses_gpu=True, + ), + "transcription.stage": TargetSpec( + name="transcription.stage", + description="Full transcription stage using model_names from config.", + modules=("src.transcription.transcription",), + mutator=patch_transcription_stage, + uses_gpu=True, + ), + "punctuation.stage": TargetSpec( + name="punctuation.stage", + description="Punctuation restoration stage.", + modules=("src.punctuation.punctuation",), + 
required_sidecars=("_rover.txt",), + copied_sidecars=("_rover.txt",), + mutator=patch_punctuation, + uses_gpu=True, + ), + "accents.stage": TargetSpec( + name="accents.stage", + description="Accent restoration stage.", + modules=("src.accents.accents",), + required_sidecars=("_punct.txt",), + copied_sidecars=("_punct.txt",), + mutator=patch_accent, + uses_gpu=True, + ), + "phonemizer.stage": TargetSpec( + name="phonemizer.stage", + description="Phonemizer stage.", + modules=("src.phonemizer.phonemizer",), + required_sidecars=("_rover.txt",), + copied_sidecars=("_rover.txt",), + mutator=patch_phonemizer, + uses_gpu=True, + ), + "collate.stage": TargetSpec( + name="collate.stage", + description="Final collate into parquet.", + modules=("src.collate",), + copied_sidecars=COLLATE_SIDECARS, + mutator=patch_collate, + uses_gpu=False, + ), + "pipeline.base": TargetSpec( + name="pipeline.base", + description="Full base pipeline from preprocess to collate.", + modules=( + "src.preprocess.preprocess", + "src.preprocess.crest_factor_remover", + "src.preprocess.preprocess_audio", + "src.separation.music_detect", + "src.separation.distillmos_process", + "src.transcription.transcription", + "src.punctuation.punctuation", + "src.accents.accents", + "src.phonemizer.phonemizer", + "src.collate", + ), + mutator=patch_pipeline, + uses_gpu=True, + ), +} + +for transcription_model in TRANSCRIPTION_MODELS: + TARGETS[f"transcription.{transcription_model}"] = TargetSpec( + name=f"transcription.{transcription_model}", + description=f"Transcription benchmark for {transcription_model}.", + modules=("src.transcription.transcription",), + mutator=patch_transcription_model(transcription_model), + uses_gpu=True, + ) + + +def target_names() -> List[str]: + return sorted(TARGETS.keys()) + + +def min_duration_for_target(target: TargetSpec, config: Dict[str, Any]) -> Optional[float]: + if not target.min_input_duration_from_config: + return None + current: Any = config + for key in 
target.min_input_duration_from_config: + if not isinstance(current, dict): + return None + current = current.get(key) + if current is None: + return None + try: + return float(current) + except (TypeError, ValueError): + return None + + +def list_targets() -> None: + for name in target_names(): + target = TARGETS[name] + sidecars = ",".join(target.required_sidecars) if target.required_sidecars else "-" + print(f"{name}\tGPU={int(target.uses_gpu)}\tinputs={sidecars}\t{target.description}") diff --git a/configs/config.yaml b/configs/config.yaml index 13fef0f..03a30c3 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,49 +1,148 @@ +# ============================================================================= +# Balalaika — config reference (configs/config.yaml) +# ============================================================================= +# How it is loaded: each stage calls load_config(config_path, SECTION) and only +# sees that SECTION's mapping. Keys at the root that are not a section name are +# ignored by load_config (except you may read the full file elsewhere). +# +# podcasts_path — use the SAME dataset root in every section that processes +# on-disk chunks under playlist_id/podcast_id/, unless you know you need +# different trees. +# +# cache_path (top-level) — not read by load_config-based stages; reserved / +# used by some benchmarks. Music detection uses separation.music_detect.cache_path +# if you add it (default ./cache inside that subsection in code). +# +# download — Yandex Music downloader (src/download). +# NOTE: src/collate.py also loads the "download" section for podcasts_path +# and num_workers when merging text sidecars into balalaika.parquet. Keep +# download.podcasts_path aligned with the rest of the pipeline. +# +# preprocess — Sortformer diarization, Smart Turn VAD, chunking, crest filter, +# loudness normalization (see src/preprocess/preprocess_yaml.sh order). 
+# +# separation — music detection (deletes clips) + DistillMOS → balalaika.csv. +# +# transcription — onnx-asr multi-model ASR, optional ROVER (src/transcription). +# +# punctuation — RUPunct on *_rover.txt files. +# +# accent — ruAccent on *_punct.txt (YAML key is accent, not accents). +# +# phonemizer — TryIParu G2P on *_rover.txt → *_rover_phonemes.txt. +# +# export — WebDataset tar shards (src/to_webdataset.py); collate parquet uses +# download section as noted above. +# +# Optional transcription keys (not set below; see src/transcription/transcription.py): +# model_path, vosk_path — local override paths for onnx-asr load_model. +# quantization — passed to onnx-asr when set. +# vad_params — Silero VAD kwargs when use_vad: True. +# ============================================================================= + +cache_path: ./cache + download: - podcasts_path: /home/nikita/balalaika - episodes_limit: 1 + # Root directory where downloaded episodes are stored ({podcast_id}/{episode_id}/*.mp3). + podcasts_path: /mnt/drive2_8tb/ruslan + # Max episodes fetched per playlist when using the downloader. + episodes_limit: 100 + # Parallel download threads. num_workers: 4 + # Optional pickle path with podcast / episode URLs (see download module). podcasts_urls_file: '' - + preprocess: - podcasts_path: /home/nikita/balalaika # full path to the folder where the podcasts are located + # Dataset root: long-form audio before chunking; after chunking, same tree holds segments. + podcasts_path: /mnt/drive2_8tb/ruslan + # Maximum duration of each output segment in seconds. duration: 15 - num_workers: 2 - whisper_model: 'large-v3' - compute_type: 'float16' - beam_size: 5 + # Worker processes per GPU for diarization+VAD chunking (Sortformer is heavy — often 1). + num_workers: 1 + # Delete files whose crest factor (peak / RMS) exceeds this value (linear ratio, not dB). + # crest_factor is saved to balalaika.csv for all files; deleted files also have their CSV row removed. 
+ crest_treshold: 10 + # True peak ceiling in dBFS before/after loudness step (see preprocess_audio). + peak: -1.0 + # Target integrated loudness in LUFS (ITU-R BS.1770-4 meter in code). + loudness: -23.0 + # Block size in seconds for loudness measurement windows. + block_size: 0.400 + # Max audio length in seconds fed to Sortformer at once; longer files are windowed. + chunk_duration: 900 + # ONNX Sortformer weights for streaming diarization. + sortformer_model: ./models/diar_streaming_sortformer_4spk-v2.1.onnx + # Use TensorRT execution provider for Sortformer ONNX Runtime session (needs CUDA+TensorRT). + use_tensorrt: True + vad_args: + # Smart Turn end-of-turn ONNX model path. + smart_vad_model: ./models/smart-turn-v3.0.onnx + # Smart Turn probability threshold for cutting/refining segment ends. + smart_vad_threshold: 0.4 separation: - podcasts_path: /home/nikita/balalaika # full path to the folder where the chopped podcasts are located - num_workers: 4 - nisqa_config: "configs/nisqa_s.yaml" - one_speaker: False + # Root containing chunked audio and balalaika.csv. + podcasts_path: /mnt/drive2_8tb/ruslan + + music_detect: + # Inference batch size for music classifier. + bs: 32 + # DataLoader worker processes per GPU worker (see music_detect.py). + num_workers: 1 + # Fine-tuned WavLM music-detector weights (safetensors). + music_detect_model: ./models/music_detection.safetensors + # If P(music) > threshold, the file is deleted and its CSV row removed. + # music_prob is saved to balalaika.csv for all processed files. + threshold: 0.6 + # Optional: HF model id for WavLM backbone (default microsoft/wavlm-base-plus in code). + # base_model: microsoft/wavlm-base-plus + # Optional: directory for per-worker audio length cache (default ./cache). 
+ # cache_path: ./cache transcription: - podcasts_path: /home/nikita/balalaika # full path to the folder where the chopped podcasts are located - model_name: "rnnt" - num_workers: 4 # per one gpu - with_timestamps: False # ctc only support - lm_path: '' # path to lm.bin - + podcasts_path: /mnt/drive2_8tb/ruslan + # If >0: before running model K, skip files where at least this many *earlier* + # models in model_names already wrote identical normalized transcripts (early exit). + consensus_num: 3 + # Request word-level timestamps where the model supports it (.tst TSV sidecars). + with_timestamps: True + # TensorRT EP for onnx-asr ORT sessions (requires compatible CUDA/TensorRT). + use_tensorrt: True + # Chunk long audio with Silero VAD inside onnx-asr (e.g. >30s). + use_vad: False + # After all models, run ROVER to write {stem}_rover.txt. + use_rover: True + # Logical names → HF/onnx-asr models (see MODEL_MAP in transcription.py). + # Supported: giga_ctc, giga_rnnt, vosk, vosk_small, tone, parakeet_v2, parakeet_v3, + # canary, whisper_base, whisper_turbo, giga_ctc_lm, ... + model_names: ['giga_ctc', 'giga_rnnt', 'vosk', 'tone'] + # Batch size passed to recognize() loops in onnx-asr. + batch_size: 16 + punctuation: - podcasts_path: /home/nikita/balalaika # full path to the folder where the chopped podcasts are located + podcasts_path: /mnt/drive2_8tb/ruslan + # Hugging Face model id for RUPunct. model_name: "RUPunct/RUPunct_big" - num_workers: 4 # per one gpu + # Parallel worker count for punctuation jobs (see punctuation.py). + num_workers: 1 accent: - podcasts_path: /home/nikita/balalaika # full path to the folder where the chopped podcasts are located - num_workers: 8 + podcasts_path: /mnt/drive2_8tb/ruslan + num_workers: 1 + # ruAccent package model tag (e.g. turbo3.1). 
model_name: turbo3.1 - -phonemizer: - podcasts_path: /home/nikita/balalaika # full path to the folder where the text files are located - num_workers: 4 # per one gpu - -classification: - podcasts_path: /home/nikita/balalaika - num_workers: 8 - threshold: 0.8 - model_path: models/voxblink2_samresnet100_ft + # TensorRT EP for ruAccent ONNX path when supported. + use_tensorrt: False +phonemizer: + podcasts_path: /mnt/drive2_8tb/ruslan + num_workers: 1 - +export: + podcasts_path: /mnt/drive2_8tb/ruslan + # Parallel processes writing WebDataset shards. + num_workers: 1 + # Approximate max uncompressed bytes per .tar shard before rolling to next file. + max_shard_size: 536870912 + # Max samples per shard file. + max_shard_count: 10000 diff --git a/configs/nisqa_s.yaml b/configs/nisqa_s.yaml deleted file mode 100644 index 7599c96..0000000 --- a/configs/nisqa_s.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# Mel-Specs options -ms_sr: null # resample speech signal to 'ms_sr' -ms_fmax: 20000 # maximum considered Mel-band frequency (in Hz), set to 20k for fullband speech samples -ms_n_fft: 960 # fft size -ms_hop_length: 480 # hop length of fft windowing -ms_win_length: 960 # fft window length, will be padded with zeros to match 'ms_n_fft' -ms_n_mels: 48 # number of Mel bands -ms_seg_length: 15 # width of extracted Mel-spec segments (in bins) -ms_seg_hop_length: 3 # hop length of segments (in bins), decreasing this may improve performance but increases memory usage and runtime. -ms_channel: null # audio channel in case of stereo file (0->left, 1->right). if null, mono mix is used -ms_max_length: 1300 # spec length for training only (in bins). if samples of different duration are used in dataloader they will be padded. one segment corresponds to 40ms -> 0.04*1300=52sec max sample duration. 
change if you want to train on different samples - - -# CNN parameters -cnn_c_out_1: 16 # number of output channels of first convolutional layer -cnn_c_out_2: 32 # number of output channels of the second convolutional layer -cnn_c_out_3: 64 # number of output channels of the last four convolutional layer -cnn_kernel_size: !!python/tuple [3,3] -cnn_dropout: 0.2 -cnn_fc_out_h: null # length of the CNN output feature vector, if 'null' the last fully connected layer is omitted -cnn_pool_1: [24,7] # output dimensions of first adaptive pooling ('adaptive' CNN only) -cnn_pool_2: [12,5] # output dimensions of second adaptive pooling ('adaptive' CNN only) -cnn_pool_3: [6,3] # output dimensions of third adaptive pooling ('adaptive' CNN only) - -# LSTM parameters -td_lstm_h: 128 # number of LSTM hidden units -td_lstm_num_layers: 1 # LSTM depth -td_lstm_dropout: 0 -td_lstm_bidirectional: true # use bidirectional LSTM -> hidden units x 2 - -# Arguments for inference -ckp: models/nisqa_s.tar # checkpoint that will be used for inference -inf_device: cuda # device for inference runs -warmup: False # warmup run before inference; usually is not needed on CPU runs - -#Microphone inference specifics -frame: 2 # framesize for file/mic capture in seconds -updates: null # if null, metrics will be calculated over whole available frame (every [frame] seconds); if int - metrics will be calculated every n bins (which equivalent to sr / ms_n_fft * n seconds, for updates=1 it will be 48000/960*1 = 20ms) -sd_device: null # check mic device ID in sounddevice or leave null to use default input device (most probably your system default mic) -sd_dump: null # set this to filename if you want to dump mic signal into file - - - - - diff --git a/create_dev_env.sh b/create_dev_env.sh index e185bda..c2dc144 100644 --- a/create_dev_env.sh +++ b/create_dev_env.sh @@ -7,7 +7,7 @@ create_venv_env() { if [ ! -d "$env_name" ]; then echo "Creating $env_name environment..." 
- uv venv "$env_name" + uv venv "$env_name" --python 3.12 if [ -f "$env_name/Scripts/activate" ]; then source "$env_name/Scripts/activate" @@ -19,6 +19,13 @@ create_venv_env() { fi uv pip install -r "$requirements_file" + + echo "Installing ONNX Runtime GPU (CUDA 13 nightly)..." + uv pip install coloredlogs flatbuffers numpy packaging protobuf sympy + uv pip install --pre --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-13-nightly/pypi/simple/ onnxruntime-gpu + uv pip install tensorrt-cu13 + uv pip install onnx-asr[gpu,hub] + deactivate else echo "Environment $env_name already exists" @@ -40,16 +47,4 @@ create_venv_env() { fi } -if ! command -v python &> /dev/null; then - echo "Python not found! Please install Python 3.10+" - exit 1 -fi - -python -c "import sys; exit(0 if sys.version_info >= (3,10) else 1)" || { - echo "Requires Python 3.10 or newer!" - exit 1 -} - -# create_venv_env ".main_venv" "requirements_main.txt" -# create_venv_env ".support_venv" "requirements_support.txt" create_venv_env ".dev_venv" "requirements_dev.txt" \ No newline at end of file diff --git a/docs/guide.md b/docs/guide.md new file mode 100644 index 0000000..08887e5 --- /dev/null +++ b/docs/guide.md @@ -0,0 +1,327 @@ +# Usage Guide + +This guide explains how to use the YapodDataset pipeline, what files are created at each stage, and how to run individual processing stages. + +--- + +## Table of Contents + +1. [Pipeline Stages](#pipeline-stages) +2. [Running the Pipeline](#running-the-pipeline) +3. [Output Files](#output-files) +4. [Configuration](#configuration) +5. [Running Individual Stages](#running-individual-stages) + +--- + +## Pipeline Stages + +### 1. Download (`src/download/`) +Downloads podcast episodes from Yandex Music based on provided URLs or playlists. 
+ +**Input**: Podcast URLs or playlist IDs +**Output**: Raw audio files (`.mp3`) organized by `{album_id}/{episode_id}/` + +**Configuration**: `config.yaml` → `download` section + +--- + +### 2. Preprocess (`src/preprocess/`) +The preprocessing stage consists of three sequential steps: + +#### 2.1. Crest Factor Removal (`crest_factor_remover.py`) +Removes audio files that have excessive crest factor (peak/RMS ratio). Files with crest factor exceeding the threshold are deleted to filter out problematic audio with extreme dynamic range. + +**Input**: Raw audio files +**Output**: Filtered audio files (files with high crest factor are deleted) + +**Configuration**: `config.yaml` → `preprocess` section +- `crest_treshold`: Maximum allowed crest factor (peak/RMS). Files exceeding this value are deleted. Default: 10.0 + +#### 2.2. Loudness Normalization (`preprocess_audio.py`) +Normalizes audio loudness using ITU-R BS.1770-4 standard. All audio files are normalized to a consistent loudness level, overwriting the original files. + +**Input**: Filtered audio files +**Output**: Loudness-normalized audio files (original files overwritten) + +**Configuration**: `config.yaml` → `preprocess` section +- `peak`: Peak normalization level in dB. Default: -1.0 +- `loudness`: Target loudness level in LUFS. Default: -23.0 +- `block_size`: Block size for loudness measurement in seconds. Default: 0.400 + +#### 2.3. Audio Segmentation (`preprocess.py`) +Splits long audio files into shorter segments (default: 15 seconds) using Voice Activity Detection (VAD). Removes segments that are too short (< 1 second) or too long (> duration limit). 
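The length rule above can be sketched in a few lines. This is a minimal illustration, not the code in `src/preprocess/preprocess.py`: the function name `keep_segment` and the `min_duration` parameter are hypothetical. The 1-second floor and the `duration` ceiling come from the text, and treating both bounds as inclusive is an assumption.

```python
def keep_segment(start: float, end: float,
                 max_duration: float = 15.0,
                 min_duration: float = 1.0) -> bool:
    """Return True if the segment [start, end] (seconds) should be kept.

    Hypothetical helper: segments shorter than min_duration or longer than
    max_duration (the `duration` config value) are dropped.
    """
    length = end - start
    return min_duration <= length <= max_duration


# Three candidate segments: too short, acceptable, too long.
segments = [(0.0, 0.4), (12.5, 26.3), (30.0, 52.0)]
kept = [seg for seg in segments if keep_segment(*seg)]
print(kept)
```

Only the middle segment survives with the default 15-second limit; raising `max_duration` would also keep the last one.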
+ +**Input**: Normalized audio files +**Output**: Segmented audio files named `{start_time}_{end_time}_{album_id}_{episode_id}.mp3` + +**Configuration**: `config.yaml` → `preprocess` section +- `duration`: Maximum segment length in seconds +- `vad_args`: VAD thresholds and model path + - `smart_vad_model`: Path to Smart VAD model + - `silero_vad_threshold`: Threshold for Silero VAD (0.0-1.0) + - `smart_vad_threshold`: Threshold for Smart VAD (0.0-1.0) + +**Note**: The preprocessing stage runs all three steps sequentially. After preprocessing, the separation stage will create `balalaika.csv` with metadata including single speaker flags and audio quality metrics. Files with detected music will be automatically deleted during the separation stage. + +--- + +### 3. Separation (`src/separation/`) +Performs four types of analysis: +- **Diarization**: Identifies and separates different speakers, creates `.rttm` files +- **NISQA**: Assesses audio quality metrics +- **Music Detection**: Detects music segments in audio +- **Silence Detection**: Analyzes silence patterns in audio + +**Input**: Segmented audio files +**Output**: +- `.rttm` files (speaker diarization data) +- **`balalaika.csv`**: Metadata file containing: + - Single speaker flags (indicating whether each audio segment contains only one speaker) + - Audio quality metrics (from NISQA assessment) + - **Silence percent**: Percentage of silence in each audio segment + - **Max silence duration**: Maximum continuous silence duration in seconds + - File paths and processing status +- Files with detected music are **automatically deleted** during music detection stage +- Can filter out multi-speaker files if `one_speaker: True` is set (files are deleted) + +**Configuration**: `config.yaml` → `separation` section +- `diarization`: Speaker diarization settings + - `num_workers`: Number of workers per GPU + - `one_speaker`: Filter for single-speaker audio only +- `nisqa`: Audio quality assessment settings + - `bs`: Batch 
size + - `num_workers`: Number of workers + - `nisqa_config_path`: Path to NISQA config +- `music_detect`: Music detection settings + - `bs`: Batch size + - `num_workers`: Number of workers per GPU + - `music_detect_model`: Path to model + - `threshold`: Detection threshold +- `silence_detect`: Silence detection settings + - `num_workers`: Number of workers per GPU + +**Note**: The `balalaika.csv` file is created/updated during the separation stage and contains important metadata about each audio segment, including speaker information, quality metrics, and silence analysis. Files detected as containing music are removed from the dataset. + +--- + +### 4. Transcription (`src/transcription/`) +Transcribes audio using multiple ASR models in parallel. This stage is now powered by **onnx-asr**, which provides a unified, high-performance interface for all supported models without requiring manual dataloaders or complex PyTorch code. + +**Key Features**: +- **Consensus Processing**: If `consensus_num` is set, the pipeline will automatically skip processing remaining models for files where the specified number of models have already produced identical transcriptions. +- **Direct GPU Inference**: Uses `onnxruntime-gpu` and `tensorrt-cu13` for native GPU acceleration. +- **Multiprocessing**: Automatically distributes workload across all available GPUs. +- **Word-Level Timestamps**: Supports extracting word-level timestamps for multiple models. + +**Input**: Audio files (`.mp3`) +**Output**: +- Individual model transcriptions: `{filename}_{model_name}.txt` +- Timestamp files: `{filename}_{model_name}.tst` (available for supported models if enabled) +- **Consensus transcription**: `{filename}_rover.txt` (aggregated from all models using the ROVER consensus algorithm) + +**Configuration**: `config.yaml` → `transcription` section +- `model_names`: List of models to use +- `consensus_num`: Number of models that need to agree before skipping remaining models. 
+- `with_timestamps`: Enable word-level timestamp generation. +- `use_tensorrt`: Enable TensorRT 10 for maximum performance. +- `use_vad`: Use Silero VAD to process long audio files in chunks. + +--- + +### 5. Punctuation (`src/punctuation/`) +Restores punctuation marks in transcribed text using the RUPunct model. + +**Input**: `{filename}_rover.txt` +**Output**: `{filename}_punct.txt` + +**Configuration**: `config.yaml` → `punctuation` section +- `model_name`: RUPunct model name (e.g., `"RUPunct/RUPunct_big"`) + +--- + +### 6. Accents (`src/accents/`) +Restores stress marks (accents) in Russian text using the ruAccent model. + +**Input**: `{filename}_punct.txt` +**Output**: `{filename}_accent.txt` + +**Configuration**: `config.yaml` → `accent` section +- `model_name`: ruAccent model name (e.g., `"turbo3.1"`) + +--- + +### 7. Phonemizer (`src/phonemizer/`) +Converts text to a phonetic representation (phonemes) using TryIPaG2P. + +**Input**: `{filename}_rover.txt` +**Output**: `{filename}_rover_phonemes.txt` + +**Configuration**: `config.yaml` → `phonemizer` section + +--- + +### 8. Collate (`src/collate.py`) +Collects all generated metadata files and aggregates them into a single Parquet file for easy access and analysis. + +**Input**: All generated text files (`_rover.txt`, `_punct.txt`, `_accent.txt`, `_rover_phonemes.txt`) +**Output**: `balalaika.parquet` (contains columns: filepath, rover, punct, accent, phonemes) + +**Usage**: +```bash +bash src/collate_yamls.sh configs/config.yaml +``` + +--- + +## Running the Pipeline + +### Complete Pipeline + +To run the complete annotation pipeline: + +```bash +bash base.sh configs/config.yaml +``` + +This executes all enabled scripts in sequence: +1. Separation (diarization, quality assessment, music detection) +2. Transcription (multi-model ASR with ROVER consensus) +3. Punctuation restoration +4. Accent restoration +5. 
Phonemization + +### Collecting Metadata + +After processing, collect all metadata: + +```bash +bash src/collate_yamls.sh configs/config.yaml +``` + +This creates `balalaika.parquet` in your `podcasts_path` directory. + +--- + +## Output Files + +For each audio segment, the pipeline generates: + +``` +{start_time}_{end_time}_{album_id}_{episode_id}.mp3 # Audio file +{start_time}_{end_time}_{album_id}_{episode_id}.rttm # Speaker diarization + +# Individual model transcriptions (if enabled) +{start_time}_{end_time}_{album_id}_{episode_id}_giga_ctc.txt +{start_time}_{end_time}_{album_id}_{episode_id}_giga_rnnt.txt +{start_time}_{end_time}_{album_id}_{episode_id}_vosk.txt +{start_time}_{end_time}_{album_id}_{episode_id}_tone.txt +{start_time}_{end_time}_{album_id}_{episode_id}_giga_ctc.tst # Timestamps (if enabled) + +# Consensus and processed text files +{start_time}_{end_time}_{album_id}_{episode_id}_rover.txt # Consensus transcription +{start_time}_{end_time}_{album_id}_{episode_id}_punct.txt # With punctuation +{start_time}_{end_time}_{album_id}_{episode_id}_accent.txt # With accents +{start_time}_{end_time}_{album_id}_{episode_id}_rover_phonemes.txt # Phonetic representation +``` + +**Intermediate metadata:** +- `balalaika.csv`: Created during separation stage, contains: + - Single speaker flags (indicating one-speaker vs multi-speaker segments) + - Audio quality metrics (NISQA scores) + - **Silence metrics**: + - `silence_percent`: Percentage of silence in each segment + - `max_silence_duration`: Maximum continuous silence duration (seconds) + - File paths and processing status + - Files with music are automatically deleted and not included + +**Final aggregated metadata:** +- `balalaika.parquet`: All metadata in structured format (created by collate stage) +- Contains: filepath, rover, punct, accent, phonemes columns + +--- + +## Configuration + +The main configuration file is `configs/config.yaml`. 
Key sections: + +### Global Parameters +- `cache_path`: Path for caching temporary files +- `podcasts_path`: **Absolute path** to your data directory + +### Stage-Specific Configuration + +Each stage has its own configuration section. See `config.yaml` for all available parameters. + +**Important**: All paths must be **absolute paths**. + +--- + +## Running Individual Stages + +### Modify `base.sh` + +Edit the `SCRIPTS` array to run only specific stages: + +```bash +SCRIPTS=( + # "./src/download/download_yaml.sh" + # "./src/preprocess/preprocess_yaml.sh" + # "./src/separation/separation_yaml.sh" + "./src/transcription/transcription_yaml.sh" + # "./src/punctuation/punctuation_yaml.sh" + # "./src/accents/accents_yaml.sh" + # "./src/phonemizer/phonemizer_yaml.sh" + # "./src/collate_yamls.sh" +) +``` + +### Run Scripts Directly + +```bash +# Activate virtual environment +source .dev_venv/bin/activate + +# Run specific stages +bash src/download/download_yaml.sh configs/config.yaml +bash src/preprocess/preprocess_yaml.sh configs/config.yaml +bash src/separation/separation_yaml.sh configs/config.yaml +bash src/transcription/transcription_yaml.sh configs/config.yaml +bash src/punctuation/punctuation_yaml.sh configs/config.yaml +bash src/accents/accents_yaml.sh configs/config.yaml +bash src/phonemizer/phonemizer_yaml.sh configs/config.yaml +bash src/collate_yamls.sh configs/config.yaml +``` + +### Processing Order + +The stages must be run in this order: +1. **Download** → Downloads raw audio files +2. **Preprocess** → Three sequential steps: + - **Crest Factor Removal** → Removes files with excessive peak/RMS ratio + - **Loudness Normalization** → Normalizes audio loudness (overwrites files) + - **Audio Segmentation** → Segments audio into chunks using VAD +3. **Separation** → Diarization, quality assessment, music detection, silence detection +4. **Transcription** → Creates individual model transcriptions + `_rover.txt` (consensus) +5. 
**Punctuation** → Processes `_rover.txt` → `_punct.txt` +6. **Accents** → Processes `_punct.txt` → `_accent.txt` +7. **Phonemizer** → Processes `_rover.txt` → `_rover_phonemes.txt` +8. **Collate** → Aggregates all metadata into Parquet + +**Important Notes:** +- All scripts must be executed from the **project root directory** +- Processing scripts (punctuation, accents, phonemizer) should be run **sequentially** after transcription +- The pipeline processes files in place, so ensure you have backups if needed +- The transcription stage creates individual model files first, then aggregates them into `_rover.txt` + +--- + +## Troubleshooting + +### Common Issues + +1. **Path errors**: Ensure all paths in `config.yaml` are **absolute paths** +2. **Missing `_rover.txt` files**: Ensure the transcription stage completed successfully. ROVER aggregation runs automatically after all model transcriptions finish. +3. **File naming**: The pipeline expects specific file naming patterns. Ensure audio files follow the expected structure + +For more troubleshooting tips, see individual module READMEs in `src/*/README.md`. diff --git a/docs/preparing.md b/docs/preparing.md new file mode 100644 index 0000000..ee6a3d1 --- /dev/null +++ b/docs/preparing.md @@ -0,0 +1,77 @@ +# Preparing your dataset + +You can either **use ready-made Balalaika exports on Hugging Face** or **run this repo** on your own audio. + +--- + +## Pre-built datasets (Hugging Face) + +MTUCI hosts processed datasets in the **[Balalaika Dataset collection](https://huggingface.co/collections/MTUCI/balalaika-dataset)** on Hugging Face. Those snapshots are already segmented, filtered, and annotated through the Balalaika-style pipeline. Use them directly for training or evaluation without running download → preprocess → … locally. + +To load a WebDataset-style export with the Hugging Face `datasets` library, follow [example/README.md](../example/README.md). + +--- + +## Running the pipeline on your own data + +### 1. 
Directory layout (for annotating your own corpus) + +Put **`podcasts_path`** at the root of your tree (use the same root everywhere in `configs/config.yaml`). For local annotation, a practical layout is **one subfolder per source group**, with **one long audio file per recording** (name = episode / clip id): + +```text +dataset/ # = podcasts_path +├── balalaika.csv # created/updated by preprocess and later stages (may appear after first run) +├── 00/ +│ ├── audio1.mp3 +│ ├── audio2.mp3 +│ └── ... +├── 01/ +│ ├── interview_a.mp3 +│ └── ... +└── ... +``` + +Rules (see `src/preprocess/preprocess.py`): + +- **`playlist_id`** = the **immediate parent folder name** (`00`, `01`, … or any stable label—album, speaker batch, etc.). +- **`podcast_id`** = the **file name without extension** (`audio1` from `audio1.mp3`). Use **unique stems** inside each playlist folder so metadata does not collide. + +Supported extensions are whatever `get_audio_paths` collects (e.g. `.mp3`, `.wav`, `.flac`, `.ogg`, `.opus`). + +After **preprocess**, a long file is **removed** once chunks are written. Chunks land in a subfolder named after that file’s stem: + +```text +{playlist_id}/{podcast_id}/{start}_{end}_{playlist_id}_{podcast_id}.mp3 +``` + +Example: from `dataset/00/audio1.mp3` you get files like `dataset/00/audio1/12.50_26.30_00_audio1.mp3`. + +If a file is **shorter than** `preprocess.duration`, it is **left in place** and only metrics are appended to `balalaika.csv` (no chunk subfolder for that case). + +### 2. Environment + +Create a `.env` in the repo root (see main [README.md](../README.md)): + +- **`HF_TOKEN`** — Hugging Face token for gated models and hub downloads (RUPunct, onnx-asr models, etc.). +- **`YANDEX_KEY`** — only if you use the Yandex Music **download** stage. + +### 3. Configuration + +Edit **`configs/config.yaml`**: absolute `podcasts_path` everywhere you process data, batch sizes, thresholds, and `model_names` for transcription. 
The file includes an inline **parameter reference** at the top. + +**Collate note:** `src/collate.py` reads the **`download`** section for `podcasts_path` and `num_workers`. Keep `download.podcasts_path` the same as your working dataset root, or collate will look in the wrong place. + +### 4. Run order + +Use [docs/guide.md](guide.md) for stage-by-stage commands, or uncomment stages in `base.sh` and run: + +```bash +bash base.sh configs/config.yaml +``` + +--- + +## More reading + +- [Usage Guide](guide.md) — stages and artifacts (some sections may still mention older tooling; prefer `src/*/README.md` and `configs/config.yaml` for the current stack). +- [example/README.md](../example/README.md) — WebDataset + `datasets` loading example. diff --git a/example/README.md b/example/README.md index 20b4c0d..9a3e625 100644 --- a/example/README.md +++ b/example/README.md @@ -1,62 +1,68 @@ -# BalalaikaDataset Example +# Balalaika WebDataset (Hugging Face) -This example demonstrates how to load and access a sample from the `BalalaikaDataset`. +Pipeline export is packed as [WebDataset](https://github.com/webdataset/webdataset) `.tar` shards. Load with Hugging Face [`datasets`](https://huggingface.co/docs/datasets) and `streaming=True` to avoid holding the full corpus in RAM. -## Usage +## Install -```python -from dataset import BalalaikaDataset - -if __name__ == "__main__": - dataset = BalalaikaDataset( - podcasts_path='../Balalaika100H', - parquet_path='../balalaika/balalaika.parquet' - ) - print(dataset[0]) +```bash +pip install datasets webdataset ``` -## Output Example +You may also need `torchaudio` (or another backend) depending on your `datasets` version for decoding `mp3` / `wav` columns. + +## Loading + +After `src/collate_yamls.sh`, shards are written to +`{parent of podcasts_path}/{dataset_folder_name}_webdataset/train/` +(see `src/to_webdataset.py`). Point `data_dir` at that `train` folder. + +[`example.py`](example.py) shows a minimal loop. 
```python -( - '/home/nikita/podcasts_1/21851634/102739417/469.72_483.93_21851634_102739417.mp3', - { - 'audio_path': '21851634/102739417/469.72_483.93_21851634_102739417.mp3', - 'is_mono': True, - 'NOI': 3.0010192, - 'COL': 4.100467, - 'DISC': 2.4562664, - 'LOUD': 3.4343345, - 'MOS': 3.9075212, - 'playlist_id': 21851634, - 'podcast_id': 102739417, - 'start': 469.72, - 'end': 483.93, - 'speaker': 0.0, - 'fullness': 0.8249, - 'accent': 'П+апа +едет дом+ой. ...', - 'phonemes': 'p a p ə j e dʲ ɪ t d ɐ m o j ...', - 'giga': 'папа едет домой и оля сейчас ...', - 'punct': 'Папа едет домой. И Оля сейчас поедет домой. ...', - 'whisper': 'Папа едет домой. И Оля сейчас поедет домой. ...', - } +from datasets import load_dataset + +dataset = load_dataset( + "webdataset", + data_dir="/path/to/your_dataset_webdataset/train", + split="train", + streaming=True, ) + +for item in dataset: + print(item["__key__"]) + audio_key = next(k for k in item if k in ("mp3", "wav", "flac", "ogg")) + audio = item[audio_key] + print(audio["array"].shape, audio["sampling_rate"]) + print(item["json"]) # dict: CSV metadata + all sidecar texts ``` -## Field Descriptions +## Sample layout + +| Field | Type | Description | +|--------|------|-------------| +| `__key__` | `str` | Sample id; dots in the stem are replaced with `_` for HF / WebDataset parsing. | +| `mp3` / `wav` / … | `dict` | Audio: NumPy `array`, `sampling_rate`. Extension matches the chunk on disk. | +| `json` | `dict` | Merged metadata from `balalaika.csv` plus every sidecar file next to the chunk. | + +## Typical `json` keys (full run) + +Keys mirror CSV columns and filenames `{stem}_{postfix}` → JSON key `postfix`. + +**From `balalaika.csv`:** e.g. `start`, `end`, `total_duration`, `speaker_id`, `playlist_id`, `podcast_id`, `silence_percent`, `max_silence_duration`, `crest_factor`, `music_prob`, `DistillMOS`, and optionally `is_single_speaker`, etc. Only files that passed all filters are in the dataset. 
+ +**Text sidecars** (depend on `transcription.model_names` and which stages you ran): -- **`audio_path`** — relative path to the audio segment. -- **`is_mono`** — whether the audio is mono. -- **`NOI`, `COL`, `DISC`, `LOUD`, `MOS`** — NISQA metrics (noise, coloration, discontinuity, loudness, MOS). -- **`playlist_id`, `podcast_id`, `start`, `end`** — source identifiers and time boundaries of the segment. -- **`speaker`** — predicted speaker ID. -- **`fullness`** — ratio of speech to silence in the segment. -- **`accent`** — text with stress markers. -- **`phonemes`** — phoneme-level representation of the utterance. -- **`giga`** — raw ASR output from GigaAM. -- **`punct`** — GigaAM output with punctuation. -- **`whisper`** — transcription from Whisper model. +| Key | Content | +|-----|---------| +| `giga_ctc.txt`, `giga_rnnt.txt`, `vosk.txt`, `tone.txt`, … | Raw ASR text per model. | +| `giga_ctc.tst`, `tone.tst`, … | Word-level TSV: `start_sec\tend_sec\tword` per line if timestamps enabled. | +| `rover.txt` | Multi-model consensus (ROVER). | +| `punct.txt` | Punctuation restoration (RUPunct). | +| `accent.txt` | Stress marks + normalized text (ruAccent). | +| `rover_phonemes.txt` | IPA string from consensus text (TryIParu `G2PModel`). | ---- +## Why WebDataset -Ensure that the dataset path and metadata are correctly specified before running the script. +- **Throughput**: sequential read of large `.tar` files instead of millions of tiny files. +- **Streaming**: train without fully unpacking to disk. +- **HF-friendly**: same loader pattern for local folders or Hub-hosted shards. 
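The word-level `.tst` sidecars in the table above hold one tab-separated `start_sec`, `end_sec`, `word` record per line. A minimal sketch of turning such a string into tuples, assuming the value stored under a key like `giga_ctc.tst` in `item["json"]` is plain text; `parse_word_timestamps` is a hypothetical helper, not part of the repository:

```python
from typing import List, Tuple


def parse_word_timestamps(tst_text: str) -> List[Tuple[float, float, str]]:
    """Parse tab-separated 'start_sec, end_sec, word' lines into tuples."""
    words: List[Tuple[float, float, str]] = []
    for line in tst_text.splitlines():
        parts = line.strip().split("\t")
        if len(parts) != 3:
            continue  # skip blank or malformed rows
        start, end, word = parts
        try:
            words.append((float(start), float(end), word))
        except ValueError:
            continue  # skip rows whose timestamps are not numeric
    return words
```

For example, `parse_word_timestamps(item["json"]["giga_ctc.tst"])` would give per-word boundaries for one model, provided timestamps were enabled when the transcription stage ran.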
diff --git a/example/dataset.py b/example/dataset.py deleted file mode 100644 index 99b7bed..0000000 --- a/example/dataset.py +++ /dev/null @@ -1,44 +0,0 @@ -import pandas as pd -import os -from pathlib import Path -from torch.utils.data import Dataset - - -class BalalaikaDataset(Dataset): - def __init__( - self, - podcasts_path: str, - parquet_path: str, - audio_key_column: str = "audio_path" - ): - self.podcasts_path = podcasts_path - self.audio_key_column = audio_key_column - - self.metadata = pd.read_parquet(parquet_path) - - available_files = { - str(path) for path in Path(self.podcasts_path).rglob('*.mp3') - } - - self.valid_items = [] - for idx, row in self.metadata.iterrows(): - audio_path_in_meta = row[self.audio_key_column] - if not os.path.isabs(audio_path_in_meta): - full_audio_path = os.path.join(self.podcasts_path, audio_path_in_meta) - else: - full_audio_path = audio_path_in_meta - - if full_audio_path in available_files: - self.valid_items.append((idx, full_audio_path)) - else: - continue - - print(f"Found {len(self.valid_items)} matches between Parquet and audio files") - - def __len__(self): - return len(self.valid_items) - - def __getitem__(self, idx): - meta_idx, audio_path = self.valid_items[idx] - row = self.metadata.iloc[meta_idx] - return audio_path, row.to_dict() \ No newline at end of file diff --git a/example/example.py b/example/example.py index 8cdd67b..22fc274 100644 --- a/example/example.py +++ b/example/example.py @@ -1,10 +1,28 @@ -from dataset import BalalaikaDataset +from datasets import load_dataset import time + if __name__ == "__main__": - dataset = BalalaikaDataset( - podcasts_path='/home/nikita/balalaika', - parquet_path='/home/nikita/yapoddataset/Balalaika100H.parquet' + # Hugging Face will find all .tar archives in the folder and collect them into a dataset + dataset = load_dataset( + "webdataset", + data_dir="/home/nikita/balalaika/dataset_webdataset", + split="train", + streaming=True # Streaming reading, does not fill up 
the RAM ) + for item in dataset: - print(item) + print(f"=== Sample: {item['__key__']} ===") + + # Audio will be automatically loaded as a dictionary (NumPy array and Sampling Rate) + # Hugging Face automatically names the columns by the extensions from the archive + audio_key = next((k for k in item.keys() if k in ['mp3', 'wav', 'flac', 'ogg']), None) + + if audio_key: + audio_data = item[audio_key] + print(f"Audio Array: {audio_data['array'].shape}, SR: {audio_data['sampling_rate']}") + + # JSON will be automatically parsed + print(item['json']) + + print("-" * 50) time.sleep(1) \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index 64a91c5..f805587 100644 Binary files a/requirements_dev.txt and b/requirements_dev.txt differ diff --git a/src/accents/README.md b/src/accents/README.md index 5f39629..756af84 100644 --- a/src/accents/README.md +++ b/src/accents/README.md @@ -1,45 +1,21 @@ -## Usage/Examples +## Accents (ruAccent) -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (`accents/accents_args.sh`) and then run it: -~~~sh -sh accents/accents_args.sh -~~~ +Lexical stress and text normalization from **`{stem}_punct.txt`**. -### Running the Code via Config File -Example: -~~~sh -bash accents/accents_yaml.sh config_path -~~~ +## Run -## Explanation of Parameters +```bash +bash src/accents/accents_yaml.sh configs/config.yaml +``` -- `--config_path`: Path to the YAML configuration file. -- `--podcasts_path`: Root directory containing the text files for accent restoration (default: "../../../podcasts"). -- `--num_workers`: Number of worker processes per GPU for parallel processing (default: 4). -- `--model_name`: Model version to use with RUAccent (default: "turbo3.1"). -- `--device`: Device to run the model on (default: "cuda"). +## Parameters -## Output Structure +See **`accent`** in `configs/config.yaml` (`podcasts_path`, `model_name`, `num_workers`, `use_tensorrt`). 
The YAML section name is **`accent`**, not `accents`. -For each text file ending with `_punct.txt`, a corresponding `_accent.txt` file will be created: - -~~~ -podcasts/ -└── {album_id}/ - └── {episode_id}/ - ├── {start_time}_{end_time}_{album_id}_{episode_id}.mp3 - ├── {start_time}_{end_time}_{album_id}_{episode_id}_giga.txt - ├── {start_time}_{end_time}_{album_id}_{episode_id}_punct.txt - └── {start_time}_{end_time}_{album_id}_{episode_id}_accent.txt -~~~ - -### File Descriptions -- `.mp3`: Original audio file -- `_giga.txt`: Initial transcription without punctuation -- `_punct.txt`: Text with restored punctuation (input file for accent restoration) -- `_accent.txt`: Final text with restored accents, punctuation, and capitalization - -The script processes all `_punct.txt` files found in the directory structure and creates corresponding `_accent.txt` files. Processing is done in parallel using available GPUs for better performance. +## Output +```text +{stem}_punct.txt → {stem}_accent.txt +``` +WebDataset key: **`accent.txt`**. 
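The `{stem}_punct.txt → {stem}_accent.txt` mapping can be sketched in a few lines. The rename below mirrors the expression used in `get_valid_txt_paths` in `src/accents/accents.py`; the helper names `accent_path_for` and `pending_punct_files` are hypothetical and exist only for illustration:

```python
from pathlib import Path
from typing import List


def accent_path_for(punct_path: Path) -> Path:
    """Map '<stem>_punct.txt' to its '<stem>_accent.txt' sibling."""
    # Same rename as in accents.py: swap the postfix inside the stem.
    return punct_path.with_name(punct_path.stem.replace("_punct", "_accent") + ".txt")


def pending_punct_files(root: Path) -> List[Path]:
    """List '_punct.txt' files under root that have no '_accent.txt' output yet."""
    return sorted(
        p for p in root.rglob("*_punct.txt")
        if not accent_path_for(p).exists()
    )
```

Because existing outputs are skipped this way, rerunning the stage only touches files that have not been processed yet.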
diff --git a/src/accents/accents.py b/src/accents/accents.py index 4ce94a9..63e7b9e 100644 --- a/src/accents/accents.py +++ b/src/accents/accents.py @@ -10,22 +10,45 @@ from ruaccent import RUAccent from tqdm import tqdm -from src.utils import load_config, get_txt_paths, read_file_content +from src.utils.utils import load_config, get_txt_paths, read_file_content + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.enable_flash_sdp(True) +torch.backends.cuda.enable_mem_efficient_sdp(True) +torch.backends.cuda.enable_math_sdp(False) accentizer = None -def init_process( - model_name: str, - device: str - ) -> None: +def get_providers(cuda_id: int, use_tensorrt: bool = False) -> list: + if use_tensorrt: + cache_path = f".cache/trt_cache_{cuda_id}" + os.makedirs(cache_path, exist_ok=True) + return [ + ("TensorrtExecutionProvider", { + "device_id": cuda_id, + "trt_max_workspace_size": 4 * 1024**3, + "trt_fp16_enable": True, + "trt_engine_cache_enable": True, + "trt_engine_cache_path": cache_path, + }), + ("CUDAExecutionProvider", {"device_id": cuda_id}), + ] + return [("CUDAExecutionProvider", {"device_id": cuda_id})] + + +def init_process(model_name: str, cuda_id: int, use_tensorrt: bool) -> None: global accentizer - + + providers = get_providers(cuda_id, use_tensorrt) + + logger.info(f"Initializing worker on GPU:{cuda_id} (TRT={use_tensorrt})") + accentizer = RUAccent() accentizer.load( - omograph_model_size=model_name, - use_dictionary=True, - tiny_mode=False, - device=device + omograph_model_size=model_name, + use_dictionary=True, + tiny_mode=False, + providers=providers ) @@ -37,6 +60,8 @@ def process_file(path: Path): return text = read_file_content(path) + if not text or len(text.strip()) == 0: + return processed_text = accentizer.process_all(text) @@ -44,108 +69,87 @@ def process_file(path: Path): f.write(processed_text) except Exception as e: - logger.error(f"Error processing {path}: {e}") - raise + logger.error(f"Error processing {path.name}: 
{e}") -def get_valid_txt_paths(path: str) -> List[str]: +def get_valid_txt_paths(path: str) -> List[Path]: all_punct_paths = get_txt_paths(path, "_punct.txt") valid_paths = [] for punct_path in all_punct_paths: - accent_path = punct_path.with_name(punct_path.stem.replace("_punct", "_accent") + ".txt") - if not os.path.exists(accent_path): - valid_paths.append(punct_path) + p = Path(punct_path) + accent_path = p.with_name(p.stem.replace("_punct", "_accent") + ".txt") + if not accent_path.exists(): + valid_paths.append(p) return valid_paths def main(args): config = load_config(args.config_path, 'accent') - num_workers = args.num_workers if args.num_workers else config.get('num_workers', 4) - model_name = args.model_name if args.model_name else config.get('model_name', 'turbo3.1') - podcast_path = args.podcasts_path if args.podcasts_path else config.get('podcasts_path', '../../../balalaika') + + num_workers_per_gpu = config.get('num_workers', 1) + model_name = config.get('model_name', 'turbo3.1') + podcast_path = config.get('podcasts_path', './data') + use_tensorrt = config.get('use_tensorrt', False) available_gpu_ids = list(range(torch.cuda.device_count())) num_gpus = len(available_gpu_ids) if num_gpus == 0: - logger.error("No GPUs available. Exiting.") + logger.error("No GPUs found via torch.cuda.device_count().") return - logger.info( - f""" - Using parms - podcast_path:{podcast_path} - num_workers:{num_workers} - model_name:{model_name} - devices:{available_gpu_ids} - """) + logger.info(f"Config loaded. 
GPUs: {available_gpu_ids}, Workers per GPU: {num_workers_per_gpu}, TRT: {use_tensorrt}") valid_text_files = get_valid_txt_paths(podcast_path) + if not valid_text_files: + logger.success("All files are already processed (no new _punct.txt found).") + return - files_for_each_gpu = [[] for _ in range(num_gpus)] - for i, path in enumerate(valid_text_files): - gpu_assignment_index = i % num_gpus - files_for_each_gpu[gpu_assignment_index].append(path) - - logger.info(f"Found {len(valid_text_files)} files to process") + logger.info(f"Found {len(valid_text_files)} files to process.") + + files_per_gpu = [[] for _ in range(num_gpus)] + for i, file_path in enumerate(valid_text_files): + files_per_gpu[i % num_gpus].append(file_path) all_futures = [] executors = [] for i, gpu_id in enumerate(available_gpu_ids): - device_str = f'cuda:{gpu_id}' - files_for_this_gpu = files_for_each_gpu[i] - - if not files_for_this_gpu: + gpu_files = files_per_gpu[i] + if not gpu_files: continue - logger.info(f"Creating ProcessPoolExecutor for {device_str} with {num_workers} workers for {len(files_for_this_gpu)} files.") - + logger.info(f"Starting {num_workers_per_gpu} workers for GPU:{gpu_id} ({len(gpu_files)} files)") + executor = ProcessPoolExecutor( - max_workers=num_workers, + max_workers=num_workers_per_gpu, initializer=init_process, - initargs=(model_name, device_str) + initargs=(model_name, gpu_id, use_tensorrt) ) executors.append(executor) - for path in files_for_this_gpu: + for path in gpu_files: future = executor.submit(process_file, path) all_futures.append(future) - logger.info(f"Submitted all {len(all_futures)} tasks across {len(executors)} GPU(s). 
Waiting for completion...") - - for future in tqdm(as_completed(all_futures), total=len(all_futures), desc="Overall Punctuation Progress"): - try: - future.result() - except Exception as e: - logger.error(f"A task processing encountered an error (already logged by worker): {e}") + try: + with tqdm(total=len(all_futures), desc="Accents Restoration") as pbar: + for future in as_completed(all_futures): + future.result() + pbar.update(1) + except KeyboardInterrupt: + logger.warning("Interrupted by user. Shutting down...") + finally: + for e in executors: + e.shutdown(wait=True) - logger.info("Processing completed") + logger.success("Accent restoration completed!") if __name__ == "__main__": - multiprocessing.set_start_method("spawn") + multiprocessing.set_start_method("spawn", force=True) - parser = argparse.ArgumentParser(description="Accent restoration script.") - parser.add_argument( - "--config_path", - type=str, - help="Path to config" - ) - parser.add_argument( - "--podcasts_path", - type=str, - help="Path to dataset directory" - ) - parser.add_argument( - "--num_workers", - type=int, - help="Number of worker processes" - ) - parser.add_argument( - "--model_name", - type=str, - help="Model version" - ) + parser = argparse.ArgumentParser(description="Multi-GPU Accent Restoration") + parser.add_argument("--config_path", type=str, required=True, help="Path to YAML config") args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/accents/accents_args.sh b/src/accents/accents_args.sh deleted file mode 100644 index 05202a0..0000000 --- a/src/accents/accents_args.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! 
-f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - - -activate_venv ".dev_venv" - -PODCASTS_PATH="../../../balalaika" -MODEL_NAME="turbo3.1" -NUM_WORKERS=4 - -python3 -m src.accents.accents \ - --podcasts_path "$PODCASTS_PATH" \ - --model_name "$MODEL_NAME" \ - --num_workers "$NUM_WORKERS" diff --git a/src/accents/accents_yaml.sh b/src/accents/accents_yaml.sh index 0fbd946..d38c522 100644 --- a/src/accents/accents_yaml.sh +++ b/src/accents/accents_yaml.sh @@ -19,4 +19,4 @@ CONFIG_PATH=$(realpath "$1") activate_venv ".dev_venv" -python3 -m src.accents.accents --config_path "$CONFIG_PATH" \ No newline at end of file +taskset -c 0-24 python3 -m src.accents.accents --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/classification/README.md b/src/classification/README.md deleted file mode 100644 index d18981f..0000000 --- a/src/classification/README.md +++ /dev/null @@ -1,36 +0,0 @@ -## Usage/Examples - -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (e.g., `classification/classification_args.sh`) and then run it: -~~~sh -sh classification/classification_args.sh -~~~ - -### Running the Code via Config File -Example: -~~~sh -bash classification/classification_yaml.sh -~~~ - -## Explanation of Parameters - -- `--config_path`: Path to the YAML configuration file. -- `--podcasts_path`: Path to the podcast folder. This folder must contain the `results.csv` file with metadata about the podcast segments. -- `--threshold`: Similarity threshold for clustering speaker embeddings (default: 0.8). Higher values result in stricter clustering. -- `--model_path`: Embedder model path or identifier (default: `"voxblink2_samresnet100_ft"`). This model is used to generate speaker embeddings. -- `--device`: Device to run the embedder model (e.g., `cuda` or `cpu`). 
- -## Output Structure - -After execution, a CSV file named `clustering_result.csv` will be generated in the specified podcasts folder. This file includes the original metadata along with an additional column `speaker` that indicates the assigned speaker cluster ID for each audio segment. - - -### File Descriptions -- **`results.csv`**: The input CSV file located in the podcasts folder containing metadata for each podcast segment. This file should include a column `IsMono` used to filter segments for clustering. -- **`clustering_result.csv`**: The output CSV file containing the original metadata along with an additional `speaker` column. Each entry in this column represents the speaker cluster ID assigned to that segment. - -## Important Notice - -- Ensure that the `results.csv` file exists in the specified podcasts folder and contains the required data (e.g., the `IsMono` column) for proper clustering. -- The clustering process is applied only to segments marked as mono (i.e., `IsMono == True`). 
- diff --git a/src/classification/classification.py b/src/classification/classification.py deleted file mode 100644 index eccfe25..0000000 --- a/src/classification/classification.py +++ /dev/null @@ -1,328 +0,0 @@ -import argparse -import os -from concurrent.futures import ProcessPoolExecutor, as_completed -import multiprocessing as mp - -import numpy as np -import pandas as pd -import torch -import torch.nn.functional as F -from loguru import logger -from tqdm import tqdm - -from src.classification.emb.embeder import ResNetEmbedder -from src.utils import load_config - -class SpeakerClustering: - def __init__( - self, - podcasts_path: str, - threshold=0.85, - model_path: str = 'voxblink2_samresnet100_ft', - ): - - self.podcasts_path = podcasts_path - df_path = os.path.join(podcasts_path, 'results.csv') - self.full_df = pd.read_csv(df_path) - - if 'fullness' not in self.full_df.columns: - self.full_df['fullness'] = np.nan - - self.df = self.full_df[self.full_df['is_mono'] == True].copy() - self.threshold = threshold - self.model_path = model_path - self.all_clusters = [] # {"embeddings": list, "paths": list} - - def norm_cos_sim(self, emb1: torch.Tensor, emb2: torch.Tensor) -> float: - if emb1.dim() == 1: - emb1 = emb1.unsqueeze(0) - if emb2.dim() == 1: - emb2 = emb2.unsqueeze(0) - - cosine_score = F.cosine_similarity( - F.normalize(emb1, p=2, dim=1), - F.normalize(emb2, p=2, dim=1), - dim=1 - ) - cosine_score = cosine_score.item() - return (cosine_score + 1.0) / 2 # normalize: [-1, 1] => [0, 1] - - def load_embedding(self, audio_path: str) -> torch.Tensor: - emb_path = audio_path.rsplit('.', 1)[0] + '.emb' - if os.path.exists(emb_path): - try: - emb, _ = torch.load(emb_path) - return emb - except Exception as e: - logger.error(f"Failed to load embedding from {emb_path}: {str(e)}") - - def cluster_mono_podcast(self, podcast_id: int): - group = self.df[self.df['podcast_id'] == podcast_id] - temp_clusters = [] - - for _, row in group.iterrows(): - audio_path = 
os.path.join( - self.podcasts_path, - row["audio_path"] - ) - emb = self.load_embedding(audio_path) - - if emb is None: - continue - - best_similarity = 0.0 - best_cluster = None - - for cluster in temp_clusters: - cluster_embs = torch.stack(cluster["embeddings"]) - avg_emb = cluster_embs.mean(dim=0) - - sim = self.norm_cos_sim(emb, avg_emb) - if sim > best_similarity: - best_similarity = sim - best_cluster = cluster - - if best_similarity >= self.threshold and best_cluster is not None: - best_cluster["embeddings"].append(emb) - best_cluster["paths"].append(row["audio_path"]) - else: - new_cluster = { - "embeddings": [emb], - "paths": [row["audio_path"]] - } - temp_clusters.append(new_cluster) - - self.all_clusters.extend(temp_clusters) - - def cluster_all_podcast(self): - for podcast_id in tqdm(self.df['podcast_id'].unique()): - self.cluster_mono_podcast(int(podcast_id)) - - def merge_clusters(self): - changed = True - while changed: - changed = False - new_clusters = [] - used = [False] * len(self.all_clusters) - - for i in range(len(self.all_clusters)): - if used[i]: - continue - base_cluster = self.all_clusters[i] - base_embs = torch.stack(base_cluster["embeddings"]) - base_avg = base_embs.mean(dim=0) - - for j in range(i + 1, len(self.all_clusters)): - if used[j]: - continue - comp_cluster = self.all_clusters[j] - comp_embs = torch.stack(comp_cluster["embeddings"]) - comp_avg = comp_embs.mean(dim=0) - sim = self.norm_cos_sim(base_avg, comp_avg) - if sim >= self.threshold: - base_cluster["embeddings"].extend(comp_cluster["embeddings"]) - base_cluster["paths"].extend(comp_cluster["paths"]) - - base_embs = torch.stack(base_cluster["embeddings"]) - base_avg = base_embs.mean(dim=0) - used[j] = True - changed = True - new_clusters.append(base_cluster) - used[i] = True - self.all_clusters = new_clusters - - for cluster_id, cluster in enumerate(self.all_clusters): - cluster['id'] = cluster_id - - return self.all_clusters - - def assign_cluster_ids(self): - 
cluster_mapping = {} - for cluster in self.all_clusters: - for path in cluster["paths"]: - cluster_mapping[path] = cluster["id"] - - self.full_df["speaker"] = ( - self.full_df["audio_path"] - .map(cluster_mapping) - .astype(pd.Int64Dtype()) - ) - -def init_worker(model_path, device_str): - global worker_embedder - worker_embedder = ResNetEmbedder(model_path, device_str) - -def compute_embedding_for_file(audio_path): - try: - emb, fullness = worker_embedder(audio_path) - if fullness: - emb_path = audio_path.rsplit('.', 1)[0] + '.emb' - torch.save((emb, fullness), emb_path) - return audio_path, fullness - return None - except Exception as e: - logger.error(f"Error processing {audio_path}: {str(e)}") - return None - -def precompute_embeddings( - audio_files: list, - model_path: str, - num_workers_per_gpu: int = 8 -) -> dict: - num_gpus = torch.cuda.device_count() - available_gpu_ids = list(range(num_gpus)) - - files_for_each_gpu = [[] for _ in range(num_gpus)] - for i, path in enumerate(audio_files): - gpu_assignment_index = i % num_gpus - files_for_each_gpu[gpu_assignment_index].append(path) - - all_futures = [] - executors = [] - fullness_results = {} - - logger.info(f"Starting embedding computation on {num_gpus} GPUs with {num_workers_per_gpu} workers per GPU") - - for i, gpu_id in enumerate(available_gpu_ids): - device_str = f'cuda:{gpu_id}' - files_for_this_gpu = files_for_each_gpu[i] - - if not files_for_this_gpu: - continue - - logger.info(f"Creating ProcessPoolExecutor for {device_str} with {num_workers_per_gpu} workers for {len(files_for_this_gpu)} files") - - executor = ProcessPoolExecutor( - max_workers=num_workers_per_gpu, - initializer=init_worker, - initargs=(model_path, device_str) - ) - executors.append(executor) - - for path in files_for_this_gpu: - future = executor.submit(compute_embedding_for_file, path) - all_futures.append(future) - - logger.info(f"Submitted {len(all_futures)} tasks across {len(executors)} GPUs") - - completed_count = 0 - for 
future in tqdm(as_completed(all_futures), total=len(all_futures), desc="Computing embeddings"): - try: - result = future.result() - if result: - audio_path, fullness = result - fullness_results[audio_path] = fullness - completed_count += 1 - except Exception as e: - logger.error(f"Task failed: {str(e)}") - - for executor in executors: - executor.shutdown() - - logger.info(f"Completed {completed_count}/{len(all_futures)} embedding computations") - return fullness_results - -def main(args): - config = load_config(args.config_path, 'classification') - podcasts_path = config.get('podcasts_path', '../../../podcasts') if args.podcasts_path is None else args.podcasts_path - threshold = config.get('threshold', 0.85) if args.threshold is None else args.threshold - model_path = config.get('model_path', '/models/voxblink2_samresnet100_ft') if args.model_path is None else args.model_path - num_workers = config.get('num_workers', 8) if args.num_workers is None else args.num_workers - - available_gpu_ids = list(range(torch.cuda.device_count())) - num_gpus = len(available_gpu_ids) - - if num_gpus == 0: - logger.error("No GPUs available. 
Exiting.") - return - - logger.info( - f""" - Used params: - podcasts path: {podcasts_path} - threshold: {threshold} - emb model path: {model_path} - list available devices: {available_gpu_ids} - workers per GPU: {num_workers} - """ - ) - - sc = SpeakerClustering( - podcasts_path=podcasts_path, - threshold=threshold, - model_path=model_path, - ) - - audio_files = [ - os.path.join(podcasts_path, path) - for path in sc.df['audio_path'].unique().tolist() - ] - - missing_files = [] - for path in audio_files: - emb_path = os.path.join(podcasts_path,path.rsplit('.', 1)[0] + '.emb') - if not os.path.exists(emb_path): - missing_files.append(path) - - if missing_files: - logger.info(f"Found {len(missing_files)} missing embeddings, starting computation...") - fullness_results = precompute_embeddings( - missing_files, - model_path, - num_workers_per_gpu=num_workers - ) - - for path, fullness in fullness_results.items(): - relative_path = os.path.relpath(path, podcasts_path) - mask = sc.full_df['audio_path'] == relative_path - sc.full_df.loc[mask, 'fullness'] = fullness - - mask_df = sc.df['audio_path'] == relative_path - sc.df.loc[mask_df, 'fullness'] = fullness - - updated_csv_path = os.path.join(podcasts_path, 'results.csv') - sc.full_df.to_csv(updated_csv_path, index=False) - logger.info(f"Updated results.csv with fullness values") - else: - logger.info("All embeddings already precomputed") - - sc.cluster_all_podcast() - sc.merge_clusters() - sc.assign_cluster_ids() - - output_path = os.path.join(podcasts_path, 'results.csv') - sc.full_df.to_csv(output_path, index=False) - logger.info(f"Clustered DataFrame saved to {output_path}") - -if __name__ == '__main__': - mp.set_start_method('spawn', force=True) - - parser = argparse.ArgumentParser(description="Classification speaker") - parser.add_argument( - "--config_path", - type=str, - help="Path to the configuration file" - ) - parser.add_argument( - "--podcasts_path", - type=str, - help="Path to the podcast folder" - ) - 
parser.add_argument( - "--threshold", - type=float, - help="Threshold for clustering" - ) - parser.add_argument( - "--model_path", - type=str, - help="embedder model path" - ) - parser.add_argument( - "--num_workers", - type=int, - help="Number of workers per GPU" - ) - - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/src/classification/classification_args.sh b/src/classification/classification_args.sh deleted file mode 100644 index 3b2f2ba..0000000 --- a/src/classification/classification_args.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! -f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - - -activate_venv ".dev_venv" - -MODEL_PATH="$SCRIPT_DIR/voxblink2_samresnet100_ft" -PODCASTS_PATH=""../../../podcasts"" -THRESHOLD=0.8 - -python -m src.classificatoin.classificatoin \ - --podcasts_path "$PODCASTS_PATH" \ - --model_path "$MODEL_PATH" \ - --threshold "$THRESHOLD" \ diff --git a/src/classification/classification_yaml.sh b/src/classification/classification_yaml.sh deleted file mode 100644 index 06cec60..0000000 --- a/src/classification/classification_yaml.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! 
-f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - -if [ -z "${1:-}" ]; then - echo "Usage: $0 " - exit 1 -fi - -CONFIG_PATH=$(realpath "$1") - -activate_venv ".dev_venv" - -python3 -m src.classification.classification --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/classification/emb/embeder.py b/src/classification/emb/embeder.py deleted file mode 100644 index 1e1f124..0000000 --- a/src/classification/emb/embeder.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import yaml -import torchaudio.compliance.kaldi as kaldi -import torch -import torchaudio -from silero_vad import load_silero_vad, get_speech_timestamps -from loguru import logger - -from wespeaker.models.speaker_model import get_speaker_model -from wespeaker.utils.checkpoint import load_checkpoint - -class ResNetEmbedder: - def __init__(self, model_path, device, resample_rate=16000, use_vad=True): - self.device = device - self.resample_rate = resample_rate - self.use_vad = use_vad - - config_path = os.path.join(model_path, 'config.yaml') - model_file = os.path.join(model_path, 'avg_model.pt') - - if not os.path.exists(config_path): - raise FileNotFoundError(f"Config file {config_path} not found") - if not os.path.exists(model_file): - raise FileNotFoundError(f"Model file {model_file} not found") - - with open(config_path, 'r') as fin: - configs = yaml.safe_load(fin) - - self.model = get_speaker_model(configs['model'])(**configs['model_args']) - load_checkpoint(self.model, model_file) - self.model.to(self.device).eval() - - if self.use_vad: - self.vad_model = load_silero_vad() - - def __call__(self, path): - wav = self._preprocess(path) - - if wav is None: - return None, None - - duration = wav.shape[-1] / self.resample_rate - fullness = duration / self.total_duration - - feats = self.compute_fbank(wav) - feats = feats.unsqueeze(0).to(self.device) - 
- with torch.no_grad(): - outputs = self.model(feats) - outputs = outputs[-1] if isinstance(outputs, tuple) else outputs - - return outputs[0].detach().cpu(), fullness - - def _preprocess(self, path): - try: - waveform, sr = torchaudio.load(path) - if sr != self.resample_rate: - waveform = torchaudio.functional.resample( - waveform, - orig_freq=sr, - new_freq=self.resample_rate - ) - self.total_duration = waveform.shape[-1] / self.resample_rate - except Exception as e: - logger.error(f"failed to load {path} - {e}") - return None - - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - if self.use_vad: - speech_segments = get_speech_timestamps( - waveform.squeeze(0), - self.vad_model, - threshold=0.4, - return_seconds=False - ) - - if len(speech_segments) == 0: # empty audio - return - - segments = [waveform[:, seg['start']:seg['end']] - for seg in speech_segments] - - waveform = torch.cat(segments, dim=-1) - - return waveform - - def compute_fbank(self, - waveform, - sample_rate=16000, - num_mel_bins=80, - frame_length=25, - frame_shift=10, - cmn=True - )->torch.Tensor: - - feat = kaldi.fbank(waveform, - num_mel_bins=num_mel_bins, - frame_length=frame_length, - frame_shift=frame_shift, - sample_frequency=sample_rate) - if cmn: - feat = feat - torch.mean(feat, 0) - - return feat \ No newline at end of file diff --git a/src/collate.py b/src/collate.py index 79dc767..aeadce9 100644 --- a/src/collate.py +++ b/src/collate.py @@ -6,7 +6,7 @@ import concurrent.futures from loguru import logger -from src.utils import load_config, read_file_content +from src.utils.utils import load_config, read_file_content, get_audio_paths def process_audio_file(audio_path_str: str, base_path: Path) -> Dict[str, Optional[str]]: @@ -16,13 +16,12 @@ def process_audio_file(audio_path_str: str, base_path: Path) -> Dict[str, Option file_types = { 'accent': '_accent.txt', - 'giga': '_giga.txt', + 'rover': '_rover.txt', 'punct': '_punct.txt', - 'whisper': 
'_whisper.txt', - 'phonemes': '_giga_phonemes.txt' + 'phonemes': '_rover_phonemes.txt', } - results = {'audio_path': audio_path_str} + results = {'filepath': audio_path_str} for key, suffix in file_types.items(): file_path = base_path / dir_path / f"{base_name}{suffix}" results[key] = read_file_content(file_path) @@ -31,20 +30,23 @@ def process_audio_file(audio_path_str: str, base_path: Path) -> Dict[str, Option def main(args): - - base_path = Path( - load_config(args.config_path, 'download').get('podcasts_path', '../../balalaika') - if args.config_path else args.podcasts_path - ) - - - df = pd.read_csv(base_path / "results.csv") - df.drop_duplicates(subset='audio_path', inplace=True) + config = load_config(args.config_path, 'download') + base_path = Path(config.get('podcasts_path', '../../balalaika')) + num_workers = config.get('num_workers', 32) + + df_path = Path(base_path) / "balalaika.csv" + if df_path.exists(): + logger.info(f"Loading existing dataframe from {df_path}") + df = pd.read_csv(df_path) + df.drop_duplicates(subset='filepath', inplace=True) + else: + logger.info(f"No existing dataframe found. 
Creating new one from audio paths.") + audio_paths = [str(path) for path in get_audio_paths(base_path)] + df = pd.DataFrame({'filepath': audio_paths}) - audio_paths = df['audio_path'].tolist() + audio_paths = df['filepath'].tolist() results = [] - num_workers = 32 logger.info(f"Starting processing with {num_workers} workers") with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: @@ -65,7 +67,7 @@ def main(args): extracted_df = pd.DataFrame(results) - final_df = pd.merge(df, extracted_df, on='audio_path', how='left') + final_df = pd.merge(df, extracted_df, on='filepath', how='left') output_path = base_path / "balalaika.parquet" final_df.to_parquet(output_path, engine='pyarrow', index=False) @@ -79,12 +81,6 @@ def main(args): type=str, help="Path to config file" ) - parser.add_argument( - "--podcasts_path", - type=str, - default='../../balalaika', - help="Path to dataset directory" - ) - + args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/collate_yamls.sh b/src/collate_yamls.sh index 95c59bf..acfec96 100644 --- a/src/collate_yamls.sh +++ b/src/collate_yamls.sh @@ -21,4 +21,5 @@ activate_venv ".dev_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.collate --config_path "$CONFIG_PATH" \ No newline at end of file +taskset -c 0-24 python3 -m src.collate --config_path "$CONFIG_PATH" +taskset -c 0-24 python3 -m src.to_webdataset --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/download/README.md b/src/download/README.md index 9be0fc5..c1c2e51 100644 --- a/src/download/README.md +++ b/src/download/README.md @@ -1,28 +1,26 @@ +## Download (Yandex Music) -## Usage/Examples -Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (download_args.sh) and then run it. 
-~~~ -bash download/download_args.sh -~~~ - -### Running the Code via Config File -Example: -~~~ -bash download/download_yaml.sh config_path -~~~ - -## Explanation of Parameters -- `--config_path`: Path to the YAML configuration file. -- `--podcasts_path`: Directory to save downloaded podcasts. -- `--episodes_limit`: Maximum number of episodes to download per podcast. -- `--num_workers`: Number of parallel threads for downloading. - - -## Output Structure -~~~ -podcasts/ +Downloads episodes from URLs / playlists configured for the downloader. + +## Run + +```bash +bash src/download/download_yaml.sh configs/config.yaml +``` + +## Parameters + +See **`download`** in `configs/config.yaml` (`podcasts_path`, `episodes_limit`, `num_workers`, `podcasts_urls_file`). + +**Note:** `src/collate.py` also reads the **`download`** section for `podcasts_path` and `num_workers` when building `balalaika.parquet`. Keep that path aligned with the rest of the pipeline. + +## Output + +```text +{podcasts_path}/ └── {podcast_id}/ - ├── {segment_id}/ - └──... -~~~ \ No newline at end of file + └── {episode_id}/ + └── *.mp3 +``` + +Next step: **preprocess** (Sortformer, Smart VAD, chunking). 
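The note above about `src/collate.py` reusing the `download` section can be illustrated with a minimal sketch. It assumes the section has already been parsed into a plain dict; the helper name `collate_settings` and the example values are hypothetical, while the fallback values mirror the `config.get(...)` calls shown in the `src/collate.py` hunk.

```python
from pathlib import Path
from typing import Tuple

# Hypothetical, already-parsed `download` section of configs/config.yaml.
download_cfg = {"podcasts_path": "/data/balalaika", "num_workers": 8}


def collate_settings(cfg: dict) -> Tuple[Path, int]:
    """Resolve the two keys collate reads, falling back to the defaults in the diff."""
    base_path = Path(cfg.get("podcasts_path", "../../balalaika"))
    num_workers = cfg.get("num_workers", 32)
    return base_path, num_workers


base_path, num_workers = collate_settings(download_cfg)
print(base_path / "balalaika.parquet", num_workers)
```

If `podcasts_path` is missing from the section, this sketch silently falls back to the relative default, which is one reason to keep the path explicit and aligned across stages.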
diff --git a/src/download/download.py b/src/download/download.py index c352bc4..0baef2b 100644 --- a/src/download/download.py +++ b/src/download/download.py @@ -11,7 +11,7 @@ from loguru import logger from yandex_music import Client -from src.utils import load_config +from src.utils.utils import load_config def init_client(client_key): try: @@ -147,9 +147,9 @@ def main(args): if not client: return - podcasts_path = args.podcasts_path if args.podcasts_path else config.get('podcasts_path','../../../balalaika') - episodes_limit = args.episodes_limit if args.episodes_limit else config.get('episodes_limit',1) - urls_pickle_path = args.podcasts_urls_file if args.podcasts_urls_file else config.get('podcasts_urls_file','alboms.pkl') + podcasts_path = config.get('podcasts_path','../../../balalaika') + episodes_limit = config.get('episodes_limit',1) + urls_pickle_path = config.get('podcasts_urls_file','alboms.pkl') try: with open(urls_pickle_path, 'rb') as file: @@ -196,28 +196,6 @@ def main(args): default="./configs/config.yaml", help="Path to the configuration file" ) - parser.add_argument( - "--podcasts_urls_file", - default=None, - help="Path to the pickle file with album urls" - ) - parser.add_argument( - "--podcasts_path", - default=None, - help="Path for saving podcasts" - ) - parser.add_argument( - "--episodes_limit", - default=None, - type=int, - help="Limit for episodes to download" - ) - parser.add_argument( - "--num_workers", - default=None, - type=int, - help="num workers" - ) - + args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/download/download_args.sh b/src/download/download_args.sh deleted file mode 100644 index f2cf543..0000000 --- a/src/download/download_args.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! 
-f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - -activate_venv ".dev_venv" - -SCRIPT_DIR=$(dirname "$(realpath "$0")") - - -PODCASTS_PATH="../../../balalaika" -EPISODES_LIMIT=2 -NUM_WORKERS=2 - -python3 -m src.download.download \ - --podcasts_path "$PODCASTS_PATH" \ - --episodes_limit "$EPISODES_LIMIT" \ - --num_workers "$NUM_WORKERS" \ No newline at end of file diff --git a/src/download/download_prepared.sh b/src/download/download_prepared.sh index 2af56e0..decbe55 100644 --- a/src/download/download_prepared.sh +++ b/src/download/download_prepared.sh @@ -18,4 +18,4 @@ activate_venv ".user_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.download.download_prepared --pickle_path "$PICKLE_PATH" --podcasts_path "$PODCASTS_PATH" --num_workers "$NUM_WORKERS" \ No newline at end of file +taskset -c 0-24 python3 -m src.download.download_prepared --pickle_path "$PICKLE_PATH" --podcasts_path "$PODCASTS_PATH" --num_workers "$NUM_WORKERS" \ No newline at end of file diff --git a/src/download/download_yaml.sh b/src/download/download_yaml.sh index 16a57b8..fff1f63 100644 --- a/src/download/download_yaml.sh +++ b/src/download/download_yaml.sh @@ -21,4 +21,4 @@ activate_venv ".dev_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.download.download --config_path "$CONFIG_PATH" \ No newline at end of file +taskset -c 0-24 python3 -m src.download.download --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/libs/nisqa/__init__.py b/src/libs/nisqa/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/libs/nisqa/core/__init__.py b/src/libs/nisqa/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/libs/nisqa/core/model_torch.py b/src/libs/nisqa/core/model_torch.py deleted file mode 100644 index ed3f69f..0000000 --- 
a/src/libs/nisqa/core/model_torch.py +++ /dev/null @@ -1,340 +0,0 @@ -# --------------------------------------------------------------------- -# This file is based on code from the NISQA-s project by Deep Intelligence: -# https://github.com/deepvk/NISQA-s -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 -# --------------------------------------------------------------------- - -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence - - -class PoolAvg(torch.nn.Module): - """ - PoolAvg: Average pooling that consideres masked time-steps. - """ - - def __init__(self, d_input, output_size): - super().__init__() - - self.linear = nn.Linear(d_input, output_size) - - def forward(self, x, n_wins): - mask = torch.arange(x.shape[1])[None, :] < n_wins[:, None].to("cpu").to(torch.long) - mask = ~mask.unsqueeze(2).to(x.device) - x.masked_fill_(mask, 0.0) - - x = torch.div(x.sum(1), n_wins.unsqueeze(1)) - - x = self.linear(x) - - return x - - -class AdaptCNN(nn.Module): - """ - AdaptCNN: CNN with adaptive maxpooling that can be used as framewise model. - Overall, it has six convolutional layers. This CNN module is more flexible - than the StandardCNN that requires a fixed input dimension of 48x15. 
- """ - - def __init__( - self, - input_channels, - c_out_1, - c_out_2, - c_out_3, - kernel_size, - dropout, - pool_1, - pool_2, - pool_3, - fc_out_h=20, - ): - super().__init__() - self.name = "CNN_adapt" - - self.input_channels = input_channels - self.c_out_1 = c_out_1 - self.c_out_2 = c_out_2 - self.c_out_3 = c_out_3 - self.kernel_size = kernel_size - self.pool_1 = pool_1 - self.pool_2 = pool_2 - self.pool_3 = pool_3 - self.dropout_rate = dropout - self.fc_out_h = fc_out_h - - self.dropout = nn.Dropout2d(p=self.dropout_rate) - - if isinstance(self.kernel_size, int): - self.kernel_size = (self.kernel_size, self.kernel_size) - - self.kernel_size_last = (self.kernel_size[0], self.pool_3[1]) - - if self.kernel_size[1] == 1: - self.cnn_pad = (1, 0) - else: - self.cnn_pad = (1, 1) - - self.conv1 = nn.Conv2d(self.input_channels, self.c_out_1, self.kernel_size, padding=self.cnn_pad) - - self.bn1 = nn.BatchNorm2d(self.conv1.out_channels) - - self.conv2 = nn.Conv2d(self.conv1.out_channels, self.c_out_2, self.kernel_size, padding=self.cnn_pad) - - self.bn2 = nn.BatchNorm2d(self.conv2.out_channels) - - self.conv3 = nn.Conv2d(self.conv2.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad) - - self.bn3 = nn.BatchNorm2d(self.conv3.out_channels) - - self.conv4 = nn.Conv2d(self.conv3.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad) - - self.bn4 = nn.BatchNorm2d(self.conv4.out_channels) - - self.conv5 = nn.Conv2d(self.conv4.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad) - - self.bn5 = nn.BatchNorm2d(self.conv5.out_channels) - - self.conv6 = nn.Conv2d(self.conv5.out_channels, self.c_out_3, self.kernel_size_last, padding=(1, 0)) - - self.bn6 = nn.BatchNorm2d(self.conv6.out_channels) - - if self.fc_out_h: - self.fc = nn.Linear(self.conv6.out_channels * self.pool_3[0], self.fc_out_h) - self.fan_out = self.fc_out_h - else: - self.fan_out = self.conv6.out_channels * self.pool_3[0] - - def forward(self, x, n_wins): - x_packed = 
pack_padded_sequence(x, n_wins.cpu(), batch_first=True, enforce_sorted=False) - - x = F.relu(self.bn1(self.conv1(x_packed.data))) - x = F.adaptive_max_pool2d(x, output_size=self.pool_1) - - x = F.relu(self.bn2(self.conv2(x))) - x = F.adaptive_max_pool2d(x, output_size=self.pool_2) - x = self.dropout(x) - - x = F.relu(self.bn3(self.conv3(x))) - x = self.dropout(x) - - x = F.relu(self.bn4(self.conv4(x))) - x = F.adaptive_max_pool2d(x, output_size=self.pool_3) - x = self.dropout(x) - - x = F.relu(self.bn5(self.conv5(x))) - x = self.dropout(x) - - x = F.relu(self.bn6(self.conv6(x))) - - x = x.view(-1, self.conv6.out_channels * self.pool_3[0]) - - if self.fc_out_h: - x = self.fc(x) - - x = x_packed._replace(data=x) - - x, _ = pad_packed_sequence(x, batch_first=True, padding_value=0.0, total_length=n_wins.max()) - return x - - -class LSTM(nn.Module): - def __init__(self, input_size, lstm_h=128, num_layers=2, dropout=0.1, bidirectional=True): - super().__init__() - - self.lstm = nn.LSTM( - input_size=input_size, - hidden_size=lstm_h, - num_layers=num_layers, - dropout=dropout, - batch_first=True, - bidirectional=bidirectional, - ) - if bidirectional: - num_directions = 2 - else: - num_directions = 1 - self.fan_out = num_directions * lstm_h - - def forward(self, x, n_wins, h0=None, c0=None): - x = pack_padded_sequence(x, n_wins.cpu(), batch_first=True, enforce_sorted=False) - self.lstm.flatten_parameters() - x, (h, c) = self.lstm(x, (h0, c0)) - x, _ = pad_packed_sequence(x, batch_first=True, padding_value=0.0, total_length=n_wins.max()) - return x, h, c - - -class NISQA_DIM(nn.Module): - """ - NISQA_DIM: The main speech quality model with speech quality dimension - estimation (MOS, Noisiness, Coloration, Discontinuity, and Loudness). 
- """ - - def __init__( - self, - cnn_c_out_1=16, - cnn_c_out_2=32, - cnn_c_out_3=64, - cnn_kernel_size=3, - cnn_dropout=0.2, - cnn_pool_1=[24, 7], - cnn_pool_2=[12, 5], - cnn_pool_3=[6, 3], - cnn_fc_out_h=None, - td_lstm_h=128, - td_lstm_num_layers=1, - td_lstm_dropout=0, - td_lstm_bidirectional=True, - ): - super().__init__() - - self.name = "NISQA_DIM" - - self.cnn = AdaptCNN( - input_channels=1, - c_out_1=cnn_c_out_1, - c_out_2=cnn_c_out_2, - c_out_3=cnn_c_out_3, - kernel_size=cnn_kernel_size, - dropout=cnn_dropout, - pool_1=cnn_pool_1, - pool_2=cnn_pool_2, - pool_3=cnn_pool_3, - fc_out_h=cnn_fc_out_h, - ) - - self.time_dependency = LSTM( - input_size=self.cnn.fan_out, - lstm_h=td_lstm_h, - num_layers=td_lstm_num_layers, - dropout=td_lstm_dropout, - bidirectional=td_lstm_bidirectional, - ) - - self.pool_layers = nn.ModuleList( - [ - PoolAvg( - self.time_dependency.fan_out, - output_size=1, - ) - for _ in range(5) - ] - ) - - def forward(self, x, n_wins, h0, c0): - x = self.cnn(x, n_wins) - x, h, c = self.time_dependency(x, n_wins, h0, c0) - out = [mod(x, n_wins) for mod in self.pool_layers] - out = torch.cat(out, dim=1) - - return out, h, c - - -def loadModel(args): - """ - Loads the Pytorch models with given input arguments. 
- """ - # if True overwrite input arguments from pretrained model - if torch.cuda.is_available(): - dev = torch.device("cuda") - else: - dev = torch.device("cpu") - - if "tr_device" in args: - if args["tr_device"] == "cpu": - dev = torch.device("cpu") - elif args["tr_device"] == "cuda": - dev = torch.device("cuda") - print("Device: {}".format(dev)) - - if "tr_parallel" in args: - if (dev == torch.device("cpu")) and args["tr_parallel"]: - args["tr_parallel"] == False - print("Using CPU -> tr_parallel set to False") - if args["pretrained_model"]: - if os.path.isabs(args["pretrained_model"]): - model_path = os.path.join(args["pretrained_model"]) - else: - model_path = os.path.join(os.getcwd(), args["pretrained_model"]) - checkpoint = torch.load(model_path, map_location=dev) - - # update checkpoint arguments with new arguments - checkpoint["args"].update(args) - args = checkpoint["args"] - - args["dim"] = True - args["csv_mos_train"] = None # column names hardcoded for dim models - args["csv_mos_val"] = None - - args["double_ended"] = False - args["csv_ref"] = None - - # Load Model - model_args = { - "cnn_c_out_1": args["cnn_c_out_1"], - "cnn_c_out_2": args["cnn_c_out_2"], - "cnn_c_out_3": args["cnn_c_out_3"], - "cnn_kernel_size": args["cnn_kernel_size"], - "cnn_dropout": args["cnn_dropout"], - "cnn_pool_1": args["cnn_pool_1"], - "cnn_pool_2": args["cnn_pool_2"], - "cnn_pool_3": args["cnn_pool_3"], - "cnn_fc_out_h": args["cnn_fc_out_h"], - "td_lstm_h": args["td_lstm_h"], - "td_lstm_num_layers": args["td_lstm_num_layers"], - "td_lstm_dropout": args["td_lstm_dropout"], - "td_lstm_bidirectional": args["td_lstm_bidirectional"], - } - - model = NISQA_DIM(**model_args) - - # Load weights if pretrained model is used ------------------------------------ - if args["pretrained_model"]: - missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model_state_dict"], strict=True) - print("Loaded pretrained model from " + args["pretrained_model"]) - if missing_keys: - 
print("missing_keys:") - print(missing_keys) - if unexpected_keys: - print("unexpected_keys:") - print(unexpected_keys) - return model, dev, model_args - - -def model_init(args): - model = NISQA_DIM( - cnn_c_out_1=args["cnn_c_out_1"], - cnn_c_out_2=args["cnn_c_out_2"], - cnn_c_out_3=args["cnn_c_out_3"], - cnn_kernel_size=args["cnn_kernel_size"], - cnn_dropout=args["cnn_dropout"], - cnn_pool_1=args["cnn_pool_1"], - cnn_pool_2=args["cnn_pool_2"], - cnn_pool_3=args["cnn_pool_3"], - cnn_fc_out_h=args["cnn_fc_out_h"], - td_lstm_h=args["td_lstm_h"], - td_lstm_num_layers=args["td_lstm_num_layers"], - td_lstm_dropout=args["td_lstm_num_layers"], - td_lstm_bidirectional=args["td_lstm_bidirectional"], - ) - - ckp = torch.load(args["ckp"], map_location="cpu", weights_only=False) - model.load_state_dict(ckp["model_state_dict"], strict=True) - model = model.to(torch.device(args["inf_device"])) - model.eval() - - # init lstm states - directions = 2 if args["td_lstm_bidirectional"] else 1 - h0 = torch.zeros(args["td_lstm_num_layers"] * directions, 1, args["td_lstm_h"]) - c0 = torch.zeros(args["td_lstm_num_layers"] * directions, 1, args["td_lstm_h"]) - - return model, h0, c0 diff --git a/src/libs/nisqa/utils/__init__.py b/src/libs/nisqa/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/libs/nisqa/utils/process_utils.py b/src/libs/nisqa/utils/process_utils.py deleted file mode 100644 index 5906e1f..0000000 --- a/src/libs/nisqa/utils/process_utils.py +++ /dev/null @@ -1,124 +0,0 @@ -# --------------------------------------------------------------------- -# This file is based on code from the NISQA-s project by Deep Intelligence: -# https://github.com/deepvk/NISQA-s -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 -# --------------------------------------------------------------------- -import numpy as np -import torch -import torchaudio as ta - - -def get_ta_melspec( - y, - sr=48e3, - n_fft=1024, - hop_length=80, - win_length=170, - n_mels=32, - fmax=16e3, # for the sake of consistency with original librosa implementation -): - """ - Calculate mel-spectrograms with torchaudio (librosa-like). - """ - if isinstance(y, str): - try: - y, sr = ta.load(y) - y = y[0] - except: - raise ValueError("Could not load file {}".format(y)) - - melSpec = ta.transforms.MelSpectrogram( - sample_rate=int(sr), - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - window_fn=torch.hann_window, - center=True, - pad_mode="reflect", - power=1.0, - n_mels=n_mels, - norm=None, - mel_scale="htk", - ) - - S = melSpec(y) - - spec = ta.functional.amplitude_to_DB(S, amin=1e-4, top_db=80.0, multiplier=20.0, db_multiplier=0.0) - return spec - - -def segment_specs(x, seg_length=15, seg_hop=3, max_length=None): - """ - Segment a spectrogram into "seg_length" wide spectrogram segments. - Instead of using only the frequency bin of the current time step, - the neighboring bins are included as input to the CNN. For example - for a seg_length of 7, the previous 3 and the follwing 3 frequency - bins are included. - - A spectrogram with input size [H x W] will be segmented to: - [W-(seg_length-1) x C x H x seg_length], where W is the width of the - original mel-spec (corresponding to the length of the speech signal), - H is the height of the mel-spec (corresponding to the number of mel bands), - C is the number of CNN input Channels (always one in our case). 
- """ - n_wins = x.shape[1] - (seg_length - 1) - transposed_x = x.transpose(1, 0) - - unfolded_x = transposed_x.unfold(0, seg_length, 1) - expanded_x = unfolded_x.unsqueeze(1) - x = expanded_x[::seg_hop, :] - n_wins = int(np.ceil(n_wins / seg_hop)) - if max_length is not None: - if max_length < x.shape[0]: - raise ValueError( - "n_wins {} > max_length {}. Increase max window length ms_max_segments!".format(x.shape[0], max_length) - ) - x_padded = torch.zeros((max_length, x.shape[1], x.shape[2], x.shape[3])) - x_padded[:n_wins, :] = x - x = x_padded - - return x, np.array(n_wins) - - -def process(audio, sr, model, h0, c0, args): - audio = get_ta_melspec( - torch.as_tensor(audio).float().cpu(), - sr, - args["ms_n_fft"], - args["ms_hop_length"], - args["ms_win_length"], - args["ms_n_mels"], - args["ms_fmax"], - ) - device = args['inf_device'] - audio, n_wins = segment_specs(audio, args["ms_seg_length"], args["ms_seg_hop_length"]) - np.set_printoptions(precision=3) - if args["updates"]: - n_wins = args["updates"] - audio = torch.split(audio, args["updates"], dim=0) - for seg in audio: - if seg.shape[0] < n_wins: - to_pad = torch.zeros(audio[0].shape) - to_pad[: seg.shape[0], :] = seg - seg = to_pad - with torch.no_grad(): - out, h0, c0 = model( - seg.unsqueeze(0).float().to(device), - torch.as_tensor(n_wins).unsqueeze(0).to(device), - h0.to(device), - c0.to(device) - ) - else: - with torch.no_grad(): - out, h0, c0 = model( - audio.unsqueeze(0).float().to(device), - torch.as_tensor(n_wins).unsqueeze(0).to(device), - h0.to(device), - c0.to(device) - ) - - return out, h0, c0 diff --git a/src/libs/smart_turn/.gitignore b/src/libs/smart_turn/.gitignore new file mode 100644 index 0000000..5553366 --- /dev/null +++ b/src/libs/smart_turn/.gitignore @@ -0,0 +1,2 @@ +.torch_hub +*.onnx \ No newline at end of file diff --git a/src/libs/smart_turn/inference.py b/src/libs/smart_turn/inference.py new file mode 100644 index 0000000..034f029 --- /dev/null +++ 
b/src/libs/smart_turn/inference.py @@ -0,0 +1,85 @@ +import torch +from transformers import WhisperFeatureExtractor +import onnxruntime as ort + +ONNX_MODEL_PATH = "smart-turn-v3.0.onnx" + +# Load model and processor +def build_session(onnx_path): + so = ort.SessionOptions() + so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + so.inter_op_num_threads = 1 + so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + return ort.InferenceSession(onnx_path, sess_options=so) + +feature_extractor = WhisperFeatureExtractor(chunk_length=8) +session = build_session(ONNX_MODEL_PATH) + +def truncate_audio_to_last_n_seconds(audio_array, n_seconds=8, sample_rate=16000): + """Truncate audio to last n seconds or pad with zeros to meet n seconds.""" + max_samples = n_seconds * sample_rate + if len(audio_array) > max_samples: + return audio_array[-max_samples:] + elif len(audio_array) < max_samples: + # Pad with zeros at the beginning + padding = max_samples - len(audio_array) + return np.pad(audio_array, (padding, 0), mode='constant', constant_values=0) + return audio_array + + +def predict_endpoint(audio_array): + """ + Predict whether an audio segment is complete (turn ended) or incomplete. 
+ + Args: + audio_array: Numpy array containing audio samples at 16kHz + + Returns: + Dictionary containing prediction results: + - prediction: 1 for complete, 0 for incomplete + - probability: Probability of completion (sigmoid output) + """ + + # Truncate to 8 seconds (keeping the end) or pad to 8 seconds + audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8) + + # Process audio using Whisper's feature extractor + inputs = feature_extractor( + audio_array, + sampling_rate=16000, + return_tensors="pt", + padding="max_length", + max_length=8 * 16000, + truncation=True, + do_normalize=True, + ) + + # Convert to numpy and ensure correct shape for ONNX + input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32) + input_features = np.expand_dims(input_features, axis=0) # Add batch dimension + + # Run ONNX inference + outputs = session.run(None, {"input_features": input_features}) + + # Extract probability (ONNX model returns sigmoid probabilities) + probability = outputs[0][0].item() + + # Make prediction (1 for Complete, 0 for Incomplete) + prediction = 1 if probability > 0.5 else 0 + + return { + "prediction": prediction, + "probability": probability, + } + + +# Example usage +if __name__ == "__main__": + import numpy as np + + # Create a dummy audio array for testing (1 second of random audio) + dummy_audio = np.random.randn(16000).astype(np.float32) + + result = predict_endpoint(dummy_audio) + print(f"Prediction: {result['prediction']}") + print(f"Probability: {result['probability']:.4f}") \ No newline at end of file diff --git a/src/libs/smart_turn/offline_svad.py b/src/libs/smart_turn/offline_svad.py new file mode 100644 index 0000000..236e0b4 --- /dev/null +++ b/src/libs/smart_turn/offline_svad.py @@ -0,0 +1,104 @@ +import os +from loguru import logger +import numpy as np +import onnxruntime as ort +from transformers import WhisperFeatureExtractor +from huggingface_hub import hf_hub_download + + +class SmartVAD: + """EOS 
(End of Speech) classifier using Smart Turn model. + Classifies whether a speech segment represents a complete utterance. + """ + + def __init__( + self, + smart_vad_threshold: float = 0.4, + device: str = 'cuda:0', + resample_rate: int = 16_000, + smart_vad_model: str = "pipecat-ai/smart-turn-v3" + ): + self.smart_vad_threshold = smart_vad_threshold + self.sample_rate = resample_rate + self.smart_vad_model = smart_vad_model + self.device = device + self.device_id = int(device.split(':')[1]) if ':' in device else 0 + + if not os.path.exists(self.smart_vad_model): + self._load_from_hf() + + self._init_smart_vad() + + def _init_smart_vad(self): + logger.info('Initializing Smart VAD (EOS classifier)...') + self.feature_extractor = WhisperFeatureExtractor(chunk_length=8) + + so = ort.SessionOptions() + so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + so.inter_op_num_threads = 1 + so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + self.session = ort.InferenceSession( + self.smart_vad_model, sess_options=so, + providers=[ + ( + "CUDAExecutionProvider", + {"device_id": self.device_id} + ), + "CPUExecutionProvider", + ] + ) + logger.info('Smart VAD (EOS classifier) initialized.') + + def _load_from_hf(self): + return hf_hub_download( + repo_id="pipecat-ai/smart-turn-v3", + filename="smart-turn-v3.2-gpu.onnx", + local_dir="./models", + local_dir_use_symlinks=False + ) + + def predict_endpoint(self, audio_array: np.ndarray) -> dict: + """ + Predict whether an audio segment is complete (turn ended) or incomplete. 
+ + Args: + audio_array: Numpy array containing audio samples at 16kHz + + Returns: + Dictionary with 'prediction' (1=complete, 0=incomplete) and 'probability' + """ + audio_array = self._truncate_audio(audio_array, n_seconds=8) + + inputs = self.feature_extractor( + audio_array, + sampling_rate=16000, + return_tensors="pt", + padding="max_length", + max_length=8 * 16000, + truncation=True, + do_normalize=True + ) + + input_features = inputs.input_features.squeeze(0).numpy().astype(np.float32) + input_features = np.expand_dims(input_features, axis=0) + + outputs = self.session.run(None, {"input_features": input_features}) + probability = outputs[0][0].item() + prediction = 1 if probability > self.smart_vad_threshold else 0 + + return { + "prediction": prediction, + "probability": round(probability, 4), + } + + @staticmethod + def _truncate_audio(audio_array: np.ndarray, n_seconds: int = 8, sample_rate: int = 16000) -> np.ndarray: + """Truncate audio to last n seconds or pad with zeros to meet n seconds.""" + max_samples = n_seconds * sample_rate + if len(audio_array) > max_samples: + return audio_array[-max_samples:] + elif len(audio_array) < max_samples: + padding = max_samples - len(audio_array) + return np.pad(audio_array, (padding, 0), mode='constant', constant_values=0) + return audio_array + diff --git a/src/phonemizer/README.md b/src/phonemizer/README.md index 4764151..9fe9ca9 100644 --- a/src/phonemizer/README.md +++ b/src/phonemizer/README.md @@ -1,41 +1,21 @@ -## Usage/Examples - -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (`phonemizer/phonemizer_args.sh`) and then run it: -~~~sh -sh phonemizer/phonemizer_args.sh -~~~ - -### Running the Code via Config File -Example: -~~~sh -bash phonemizer/phonemizer_yaml.sh config_path -~~~ - -## Explanation of Parameters - -- `--config_path`: Path to the YAML configuration file. 
-- `--podcasts_path`: Root directory containing the text files for phoneme conversion (default: "../../../podcasts"). -- `--num_workers`: Number of worker processes per GPU for parallel processing (default: 8). - -## Output Structure - -For each text file ending with `_giga.txt`, a corresponding `_phonemes.txt` file will be created: - -~~~ -podcasts/ -└── {album_id}/ - └── {episode_id}/ - ├── {start_time}_{end_time}_{album_id}_{episode_id}.mp3 - ├── {start_time}_{end_time}_{album_id}_{episode_id}_giga.txt - ├── {start_time}_{end_time}_{album_id}_{episode_id}_punct.txt - ├── {start_time}_{end_time}_{album_id}_{episode_id}_accent.txt - └── {start_time}_{end_time}_{album_id}_{episode_id}_giga_phonemes.txt -~~~ - -### File Descriptions -- `.mp3`: Original audio file -- `_giga.txt`: Initial transcription without punctuation -- `_punct.txt`: Text with restored punctuation -- `_accent.txt`: Text with restored accents -- `_giga_phonemes.txt`: Text converted to phonemes +## Phonemizer (TryIParu) + +Grapheme → IPA from **`{stem}_rover.txt`** using **`tryiparu.G2PModel`** (workers call `load_dataset=True` at init). + +## Run + +```bash +bash src/phonemizer/phonemizer_yaml.sh configs/config.yaml +``` + +## Parameters + +See **`phonemizer`** in `configs/config.yaml` (`podcasts_path`, `num_workers`). + +## Output + +For each `{stem}_rover.txt`: + +- **`{stem}_rover_phonemes.txt`** — space-separated IPA symbols. + +WebDataset key: **`rover_phonemes.txt`**. 
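The naming rule above (`{stem}_rover.txt` to `{stem}_rover_phonemes.txt`) can be sketched as a small path helper. This is an illustration only: `phoneme_output_path` is a hypothetical name, not a function from `src/phonemizer/phonemizer.py`, and the G2P call itself is left out.

```python
from pathlib import Path


def phoneme_output_path(rover_path: Path) -> Path:
    """Map `{stem}_rover.txt` to `{stem}_rover_phonemes.txt` (hypothetical helper)."""
    suffix = "_rover.txt"
    if not rover_path.name.endswith(suffix):
        raise ValueError(f"expected a *{suffix} file, got {rover_path.name}")
    # Drop only the trailing ".txt" so dots inside the timestamp part survive.
    base = rover_path.name[: -len(".txt")]
    return rover_path.with_name(base + "_phonemes.txt")


print(phoneme_output_path(Path("ep/12.50_26.30_pl_ep_rover.txt")))
```

Slicing off the extension explicitly, rather than splitting on the first dot, matters here because chunk filenames contain decimal timestamps.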
diff --git a/src/phonemizer/phonemizer.py b/src/phonemizer/phonemizer.py index 31784dd..df30411 100644 --- a/src/phonemizer/phonemizer.py +++ b/src/phonemizer/phonemizer.py @@ -9,7 +9,13 @@ from tryiparu import G2PModel from tqdm import tqdm -from src.utils import get_txt_paths, load_config, read_file_content +from src.utils.utils import get_txt_paths, load_config, read_file_content + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.enable_flash_sdp(True) +torch.backends.cuda.enable_mem_efficient_sdp(True) +torch.backends.cuda.enable_math_sdp(False) + g2p_model: Any = None @@ -33,7 +39,7 @@ def process_text(text_path: Path): f.write(" ".join(phonemes)) def get_valid_text_paths(src_path: str) -> List[Path]: - all_paths = get_txt_paths(src_path, '_giga.txt') + all_paths = get_txt_paths(src_path, '_rover.txt') valid_paths = [] for giga_path in all_paths: @@ -47,8 +53,8 @@ def get_valid_text_paths(src_path: str) -> List[Path]: def main(args): config = load_config(args.config_path, 'phonemizer') - num_workers = args.num_workers if args.num_workers else config.get('num_workers', 4) - src_path_str = args.podcasts_path if args.podcasts_path else config.get('podcasts_path', '../../../podcasts') + num_workers = config.get('num_workers', 4) + src_path_str = config.get('podcasts_path', '../../../podcasts') all_text_paths = get_valid_text_paths(src_path_str) logger.info(f"Found {len(all_text_paths)} text files to process") @@ -116,16 +122,6 @@ def main(args): type=str, help="Path to the configuration YAML file." ) - parser.add_argument( - "--podcasts_path", - type=str, - help="Path to the directory containing audio files (e.g., MP3s)." - ) - parser.add_argument( - "--num_workers", - type=int, - help="Number of worker processes per GPU for parallel processing." 
- ) args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/phonemizer/phonemizer_args.sh b/src/phonemizer/phonemizer_args.sh deleted file mode 100644 index dc4ccec..0000000 --- a/src/phonemizer/phonemizer_args.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! -f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - -activate_venv ".dev_venv" - -SCRIPT_DIR=$(dirname "$(realpath "$0")") -CONFIG_PATH="$SCRIPT_DIR/../../configs/config.yaml" - -PODCASTS_PATH="../../../balalaika" -NUM_WORKERS=8 - -python3 -m src.phonemizer.phonemizer \ - --podcasts_path "$PODCASTS_PATH" \ - --num_workers "$NUM_WORKERS" \ No newline at end of file diff --git a/src/phonemizer/phonemizer_yaml.sh b/src/phonemizer/phonemizer_yaml.sh index 2991021..ed2be54 100644 --- a/src/phonemizer/phonemizer_yaml.sh +++ b/src/phonemizer/phonemizer_yaml.sh @@ -21,4 +21,4 @@ activate_venv ".dev_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.phonemizer.phonemizer --config_path "$CONFIG_PATH" \ No newline at end of file +taskset -c 0-24 python3 -m src.phonemizer.phonemizer --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/preprocess/README.md b/src/preprocess/README.md index 0bdb27d..30755a1 100644 --- a/src/preprocess/README.md +++ b/src/preprocess/README.md @@ -1,45 +1,43 @@ -## Usage/Examples - -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (`predprocess_args.sh`) and then run it: -~~~ -sh predprocess/predprocess_args.sh -~~~ - -### Running the Code via Config File -Example: -~~~ -sh predprocess/predprocess_yaml.sh config_path -~~~ - -## Explanation of Parameters - -- `--config_path`: Path to the YAML configuration file (default: None). 
-- `--podcasts_path`: Root directory containing podcast audio files (default: '../podcasts'). -- `--whisper_model`: Name of the Whisper model to use (default: 'large-v3'). -- `--compute_type`: Compute type for the model (default: 'float16'). -- `--beam_size`: Beam size for beam search decoding (default: 5). -- `--duration`: Target duration in seconds for each audio segment (default: 15). -- `--device`: Hardware accelerator for the Whisper model (default: 'cpu'). -- `--num_workers`: Number of parallel processes for audio processing (default: 1). - -## Output Structure - -After processing, the original audio file is segmented into shorter clips. The resulting structure will be: - -~~~ -podcasts/ -└── {album_id}/ - ├── {episode_id}/ - │ ├── start_time_end_time_{album_id}_{episode_id}.mp3 - │ ├── start_time_end_time_{album_id}_{episode_id}_whisper.txt - │ └── ... (other segments) -~~~ - -Each podcast episode (originally an `.mp3` file) is moved to its own folder named after the episode ID (within its album folder) and then segmented. For each segment, two files are created: -- An audio file (`{start_time}_{end_time}_{album_id}_{episode_id}.mp3`) containing the audio segment -- A text file (`{start_time}_{end_time}_{album_id}_{episode_id}_whisper.txt`) containing the transcription for that segment - -The `start_time` and `end_time` in the filenames represent the timestamp positions (in seconds) of the segment in the original audio file. - -If the segmentation is successful, the original file is deleted. +## Overview + +Prepares long recordings for ASR/TTS: **Sortformer (ONNX)** diarization, **single-speaker** selection, **Smart Turn** end-of-segment refinement, chunk export and `balalaika.csv`, then **crest-factor** filtering and **loudness** normalization (ITU-R BS.1770-4). + +### Stage order (`preprocess_yaml.sh`) + +1. 
**`src.preprocess.preprocess`**: runs Sortformer in windows of up to `chunk_duration` seconds, applies overlap filtering (segments that overlap another segment are dropped), refines segment ends with Smart Turn (`SmartVAD`), writes `{start}_{end}_{playlist_id}_{podcast_id}.mp3`, and appends rows to `balalaika.csv`; **deletes the original long file** after successful chunking. +2. **`crest_factor_remover`**: computes the crest factor (peak/RMS) for every chunk, writes it to `balalaika.csv` as `crest_factor`, then deletes files that exceed the threshold and removes their rows from the CSV. +3. **`preprocess_audio`**: peak-normalizes each chunk, then normalizes it to the target LUFS, and overwrites the file in place. + +### Parameters (`configs/config.yaml` → `preprocess`) + +See the **preprocess** block in `configs/config.yaml` for a line-by-line description (`podcasts_path`, `duration`, `chunk_duration`, `num_workers`, `crest_treshold`, `peak`, `loudness`, `block_size`, `sortformer_model`, `use_tensorrt`, `vad_args.*`). + +## Run + +```bash +bash src/preprocess/preprocess_yaml.sh configs/config.yaml +``` + +(`src/preprocess/preprocess_args.sh` is removed in this change; `preprocess.py` now accepts only `--config_path`.) + +## Output layout + +```text +{podcasts_path}/ +├── balalaika.csv +└── {playlist_id}/ + └── {podcast_id}/ + ├── 12.50_26.30_{playlist_id}_{podcast_id}.mp3 + └── ... +``` + +Times in filenames are offsets in seconds from the start of the **source** episode. + +## `balalaika.csv` columns after preprocess + +| Column | Added by | +|--------|----------| +| `filepath`, `speaker_id`, `start`, `end`, `total_duration`, `playlist_id`, `podcast_id`, `silence_percent`, `max_silence_duration`, `is_single_speaker` | `preprocess.py` | +| `crest_factor` | `crest_factor_remover.py` (files above threshold are deleted and their rows removed) | + +`DistillMOS`, `music_prob`, and transcription fields are added in **separation** / downstream stages.
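
The crest-factor filter from stage 2 can be illustrated with a minimal, standard-library sketch. The helper names `crest_factor` and `split_by_threshold` are hypothetical, the two signals are synthetic, and the 10.0 threshold only mirrors the fallback that `crest_factor_remover.py` uses when `crest_treshold` is not set. The real script works on decoded audio with NumPy and deletes files instead of returning lists.

```python
import math
from typing import Dict, List, Sequence, Tuple


def crest_factor(samples: Sequence[float]) -> float:
    """Return peak / RMS; infinite for an all-zero (silent) signal."""
    if not samples:
        raise ValueError("samples must not be empty")
    peak = max(abs(x) for x in samples)
    rms = math.sqrt(sum(x * x for x in samples) / len(samples))
    return float("inf") if rms == 0 else peak / rms


def split_by_threshold(
    chunks: Dict[str, Sequence[float]], threshold: float = 10.0
) -> Tuple[List[str], List[str]]:
    """Partition chunk names into (kept, rejected) by crest factor."""
    kept: List[str] = []
    rejected: List[str] = []
    for name, samples in chunks.items():
        # Chunks whose crest factor exceeds the threshold are rejected.
        (rejected if crest_factor(samples) > threshold else kept).append(name)
    return kept, rejected


# Synthetic stand-ins for decoded chunks (illustrative data only).
sine = [math.sin(2 * math.pi * 5 * n / 1000) for n in range(1000)]  # steady tone
click = [0.0] * 999 + [1.0]                                         # single impulse

kept, rejected = split_by_threshold({"sine.mp3": sine, "click.mp3": click})
print("kept:", kept)
print("rejected:", rejected)
```

A steady tone has a crest factor near the square root of 2, while a lone impulse in silence has a very large one, so only the impulse-like chunk is rejected. A fully silent chunk yields an infinite crest factor and is rejected as well.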
diff --git a/src/preprocess/crest_factor_remover.py b/src/preprocess/crest_factor_remover.py new file mode 100644 index 0000000..8dbb233 --- /dev/null +++ b/src/preprocess/crest_factor_remover.py @@ -0,0 +1,161 @@ +import argparse +import os +import torch +import torch.multiprocessing as mp +import torchaudio +import pandas as pd +from pathlib import Path +from typing import List + +import numpy as np +from loguru import logger +from tqdm import tqdm + +from src.utils.utils import load_config, get_audio_paths + + +def calculate_crest_factor(audio: np.ndarray) -> float: + peak = np.max(np.abs(audio)) + rms = np.sqrt(np.mean(audio ** 2)) + if rms == 0: + return float('inf') + return peak / rms + + +def run_worker( + rank: int, + world_size: int, + all_file_paths: List[str], + crest_threshold: float, + output_dir: str, +): + my_files = all_file_paths[rank::world_size] + if not my_files: + return + + results = [] + logger.info(f"Worker {rank}/{world_size} processing {len(my_files)} files") + + for path_str in tqdm(my_files, desc=f"Worker-{rank}", position=rank): + try: + audio_tensor, _ = torchaudio.load_with_torchcodec(path_str) + if audio_tensor.shape[0] > 1: + audio = audio_tensor.mean(dim=0).numpy() + else: + audio = audio_tensor.squeeze(0).numpy() + cf = calculate_crest_factor(audio) + except Exception as e: + logger.error(f"Error processing {path_str}: {e}") + continue + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if cf > crest_threshold: + try: + os.remove(path_str) + logger.debug(f"Deleted {path_str} (crest_factor={cf:.2f})") + except OSError as e: + logger.error(f"Could not delete {path_str}: {e}") + + results.append({'filepath': str(Path(path_str).resolve()), 'crest_factor': round(cf, 4)}) + + if results: + part_path = Path(output_dir) / f'crest_part_{rank}.csv' + pd.DataFrame(results).to_csv(part_path, index=False) + logger.info(f"Worker {rank} saved {len(results)} crest_factor results.") + + +def update_csv(podcasts_path: Path, 
num_workers: int, crest_threshold: float): + csv_path = podcasts_path / 'balalaika.csv' + parts = [podcasts_path / f'crest_part_{i}.csv' for i in range(num_workers)] + existing_parts = [p for p in parts if p.exists()] + + if not existing_parts: + logger.warning("No crest_part_*.csv files found; skipping CSV update.") + if csv_path.exists(): + df = pd.read_csv(csv_path) + before = len(df) + df = df[df['filepath'].apply(lambda p: Path(p).exists())] + if before != len(df): + df.to_csv(csv_path, index=False) + logger.info(f"Removed {before - len(df)} missing rows from CSV.") + return + + results_df = pd.concat([pd.read_csv(p) for p in existing_parts], ignore_index=True) + for p in existing_parts: + p.unlink() + + if not csv_path.exists(): + logger.warning(f"balalaika.csv not found at {csv_path}; skipping CSV update.") + return + + df = pd.read_csv(csv_path) + df['filepath'] = df['filepath'].apply(lambda p: str(Path(p).resolve())) + results_df['filepath'] = results_df['filepath'].apply(lambda p: str(Path(p).resolve())) + + if 'crest_factor' in df.columns: + df = df.drop(columns=['crest_factor']) + df = df.merge(results_df[['filepath', 'crest_factor']], on='filepath', how='left') + + before = len(df) + df = df[df['filepath'].apply(lambda p: Path(p).exists())] + removed = before - len(df) + deleted_count = (results_df['crest_factor'] > crest_threshold).sum() + logger.info( + f"Crest filter: {deleted_count} files deleted (threshold={crest_threshold}), " + f"{removed} rows removed from CSV." 
+ ) + + df.to_csv(csv_path, index=False) + logger.success(f"CSV updated: {len(df)} rows remain.") + + +def main(args): + config = load_config(args.config_path, 'preprocess') + + podcasts_path = config.get('podcasts_path') + if not podcasts_path: + podcasts_path = '../../../podcasts' + logger.warning("Using default podcasts_path") + podcasts_path = Path(podcasts_path) + + crest_threshold = config.get('crest_treshold', 10.0) + num_workers = config.get('num_workers', 4) + + logger.info( + f"Running crest factor removal: path={podcasts_path}, " + f"threshold={crest_threshold}, workers={num_workers}" + ) + + audio_paths = [str(p) for p in get_audio_paths(str(podcasts_path))] + if not audio_paths: + logger.info("No audio files found for processing.") + return + + logger.info(f"Found {len(audio_paths)} audio files to check") + + if num_workers > 1: + mp.spawn( + run_worker, + args=(num_workers, audio_paths, crest_threshold, str(podcasts_path)), + nprocs=num_workers, + join=True + ) + else: + run_worker(0, 1, audio_paths, crest_threshold, str(podcasts_path)) + + update_csv(podcasts_path, num_workers, crest_threshold) + logger.info("Crest factor check completed.") + + +if __name__ == "__main__": + mp.set_start_method('spawn', force=True) + + parser = argparse.ArgumentParser( + description="Remove audio files that exceed crest factor threshold (peak/rms > threshold)." 
+ ) + parser.add_argument("--config_path", type=str, required=True) + args = parser.parse_args() + + main(args) diff --git a/src/preprocess/preprocess.py b/src/preprocess/preprocess.py index eccbb21..771fb35 100644 --- a/src/preprocess/preprocess.py +++ b/src/preprocess/preprocess.py @@ -1,326 +1,377 @@ import argparse import os import multiprocessing +import re from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Any, List, Tuple +from pathlib import Path +from typing import Dict, List, Tuple, Any import torch import torchaudio +import pandas as pd from loguru import logger from tqdm import tqdm -from faster_whisper import WhisperModel from dotenv import load_dotenv - from huggingface_hub import login -from src.utils import load_config +from src.utils.utils import load_config, get_audio_paths +from src.libs.smart_turn.offline_svad import SmartVAD + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.enable_flash_sdp(True) +torch.backends.cuda.enable_mem_efficient_sdp(True) +torch.backends.cuda.enable_math_sdp(False) + +CHUNK_DURATION_S = 15 * 60 + +sortformer_model = None +smart_vad = None + +def get_providers(cuda_id: int) -> list: + return [ + ("TensorrtExecutionProvider", { + "device_id": cuda_id, + "trt_max_workspace_size": 4 * 1024**3, + "trt_fp16_enable": True, + "trt_engine_cache_enable": True, + "trt_engine_cache_path": f".cache/trt_cache_{cuda_id}", + }), + ("CUDAExecutionProvider", {"device_id": cuda_id}), + ] -def get_audio_paths(directory: str) -> List[str]: - audio_paths = [] - for entry in os.listdir(directory): - full_path = os.path.join(directory, entry) - if len(os.path.basename(full_path).split('_')) == 4: - continue - if os.path.isdir(full_path): - audio_paths.extend(get_audio_paths(full_path)) - elif entry.endswith(".mp3") : - audio_paths.append(full_path) - return audio_paths - -def get_whisper_segments( - model: Any, - path_audio: str, - beam_size: int = 5 -) -> Tuple[List[float], List[float], 
List[str]]: - - segments, info = model.transcribe( - path_audio, - beam_size=beam_size, - ) +def init_models(gpu_id: int, config: Dict[str, Any]): + global sortformer_model, smart_vad + device = f"cuda:{gpu_id}" + providers = get_providers(gpu_id) - timesteps_starts = [] - timesteps_ends = [] - - timestamps_text = [] + try: + from src.preprocess.sortformer_onnx import Sortformer, DiarizationConfig + except ImportError: + logger.error("Sortformer module or Sortformer class not found in src.preprocess.sortformer_onnx") + raise - for segment in segments: - timesteps_starts.append(segment.start) - timesteps_ends.append(segment.end + 0.05) - timestamps_text.append(segment.text) - - assert len(timesteps_starts) == len(timesteps_ends) == len(timestamps_text), "Mismatch in timestamps lengths." - return timesteps_starts, timesteps_ends, timestamps_text - - -def get_piece_idx( - timesteps_starts: List[float], - timesteps_ends: List[float], - duration: float = 15.0, -) -> List[Tuple[int, int]]: - """ - Groups speech segments into optimal chunks respecting duration constraints, - returning indices of the original segments rather than time intervals. - - Processes speech segments to identify optimal groupings that: - 1. Do not exceed specified maximum duration when combined - 2. Contain complete consecutive speech segments - 3. 
Maintain natural segmentation boundaries - - Args: - timesteps_starts: Start times of speech segments (seconds) - timesteps_ends: End times of speech segments (seconds) - duration: Maximum allowed combined duration in seconds (default: 15.0) - - Returns: - List of tuples representing segment index ranges: - - Each tuple contains (start_index, end_index) of segments in original lists - - Index ranges reference positions in timesteps_starts/timesteps_ends - - Resulting combined segments would be between duration//3 and duration seconds - - Maintains original segment order - - Example: - Input segments: - starts = [0.0, 2.5, 5.5, 12.5] - ends = [2.15, 5.15, 12.15, 18.15] + model_config = DiarizationConfig() + sortformer_model = Sortformer( + model_path=config.get('sortformer_model'), + config=model_config, + providers=providers + ) + + vad_args = config.get('vad_args', {}) + smart_vad = SmartVAD( + smart_vad_model=vad_args.get('smart_vad_model', './models/smart-turn-v3.0.onnx'), + smart_vad_threshold=vad_args.get('smart_vad_threshold', 0.4), + device=device + ) + logger.info(f"Models initialized on {device}") + + +def parse_diarization_output(raw_results) -> List[Tuple[float, float, int]]: + segments = [] + if not raw_results or len(raw_results) == 0: return segments + inner_results = raw_results[0] if isinstance(raw_results[0], list) else raw_results + for seg in inner_results: + try: + if isinstance(seg, str): + parts = seg.strip().split() + if len(parts) >= 3: + segments.append((float(parts[0]), float(parts[1]), int(parts[2].replace('speaker_', '')))) + elif isinstance(seg, (list, tuple)) and len(seg) >= 3: + segments.append((float(seg[0]), float(seg[1]), int(seg[2]))) + except (ValueError, IndexError): + pass + return sorted(segments, key=lambda x: x[0]) + +def diarize_audio(audio: torch.Tensor, sr: int, chunk_duration: float = CHUNK_DURATION_S) -> List[Tuple[float, float, int]]: + global sortformer_model + total_samples = audio.shape[-1] + chunk_samples = 
int(chunk_duration * sr) + all_segments, offset = [], 0 + + while offset < total_samples: + end = min(offset + chunk_samples, total_samples) + chunk = audio[:, offset:end] - Output with duration=15: - [(0, 2), (3, 3)] # First 3 segments grouped, last segment standalone - """ + audio_np = chunk.squeeze(0).numpy() if chunk.dim() > 1 else chunk.numpy() + raw = sortformer_model.diarize(audio=audio_np, sample_rate=sr, include_tensor_outputs=False) + segs = parse_diarization_output(raw) - if not timesteps_starts or not timesteps_ends: - return [] + offset_sec = offset / sr + segs = [(s + offset_sec, e + offset_sec, spk) for s, e, spk in segs] + + if len(segs) > 2 and total_samples > chunk_samples: + segs = segs[1:-1] + + all_segments.extend(segs) + offset = end + + return sorted(all_segments, key=lambda x: x[0]) + + +def filter_single_speaker_segments(segments: List[Tuple[float, float, int]], min_duration: float = 1.0, max_duration: float = 15.0) -> List[Tuple[float, float, int]]: + filtered = [] + segments = sorted(segments, key=lambda x: x[0]) + for i, (start, end, spk) in enumerate(segments): + dur = end - start + if not (min_duration <= dur <= max_duration): continue + overlap = False + for j, (s2, e2, _) in enumerate(segments): + if i != j and start < e2 and end > s2: + overlap = True + break + if not overlap: + filtered.append((start, end, spk)) + return filtered + +def apply_eos_classification(audio: torch.Tensor, sr: int, segments: List[Tuple[float, float, int]], max_duration: float = 15.0) -> List[Tuple[float, float, int]]: + global smart_vad + if not smart_vad or not segments: return segments + + audio_np = audio.squeeze(0).numpy() if audio.dim() > 1 else audio.numpy() + classified = [] + for s, e, spk in segments: + segment_audio = audio_np[int(s * sr):min(int(e * sr), len(audio_np))] + if len(segment_audio) == 0: continue + pred = smart_vad.predict_endpoint(segment_audio)['prediction'] + classified.append((s, e, spk, pred)) + + merged = [] + i = 0 + while i < 
len(classified): + start, end, spk, pred = classified[i] + if pred == 1: # EOS detected + merged.append((start, end, spk)) + i += 1 + continue + + j = i + 1 + while j < len(classified): + ns, ne, nspk, npred = classified[j] + if nspk != spk or ne - start > max_duration: break + end = ne + j += 1 + if npred == 1: break + + if end - start >= 1.0: + merged.append((start, end, spk)) + i = max(j, i + 1) + return merged + + +def get_chunk_metrics(c_start: float, c_end: float, raw_segments: List[Tuple[float, float, int]]) -> Tuple[float, float, int]: + chunk_dur = c_end - c_start + if chunk_dur <= 0: return 0.0, 0.0, 0 + + intervals = [] + speakers_in_chunk = set() + + for rs, re, spk in raw_segments: + overlap_s = max(c_start, rs) + overlap_e = min(c_end, re) + if overlap_s < overlap_e: + intervals.append([overlap_s, overlap_e]) + speakers_in_chunk.add(spk) + + intervals.sort(key=lambda x: x[0]) + if not intervals: + return 100.0, round(chunk_dur, 2), 0 + + merged_speech = [] + for interval in intervals: + if not merged_speech: + merged_speech.append(interval) + else: + prev = merged_speech[-1] + if interval[0] <= prev[1]: + prev[1] = max(prev[1], interval[1]) + else: + merged_speech.append(interval) + + speech_dur = sum(e - s for s, e in merged_speech) + silence_dur = max(0.0, chunk_dur - speech_dur) + silence_pct = (silence_dur / chunk_dur) * 100 + + gaps = [merged_speech[0][0] - c_start] + for i in range(len(merged_speech) - 1): + gaps.append(merged_speech[i+1][0] - merged_speech[i][1]) + gaps.append(c_end - merged_speech[-1][1]) + max_gap = max(gaps) + + return round(silence_pct, 2), round(max_gap, 2), len(speakers_in_chunk) + + +def cut_audio(audio: torch.Tensor, sr: int, final_segments: List[Tuple[float, float, int]], raw_segments: List[Tuple[float, float, int]], output_folder: str, album_id: str, episode_id: str, fmt: str = 'mp3', max_duration: float = 15.0) -> List[Dict]: + os.makedirs(output_folder, exist_ok=True) + results = [] + + for start, end, spk in 
final_segments: + dur = end - start + if dur <= 0.5: continue + + s_sample, e_sample = int(start * sr), min(int(end * sr), audio.shape[-1]) + if e_sample <= s_sample: continue + + sil_pct, max_sil, unique_spk = get_chunk_metrics(start, end, raw_segments) + + segment = audio[:, s_sample:e_sample] + fname = f"{start:.2f}_{end:.2f}_{album_id}_{episode_id}.{fmt}" + out_path = os.path.join(output_folder, fname) + torchaudio.save(out_path, segment, sr) + + results.append({ + 'filepath': os.path.abspath(out_path), + 'speaker_id': spk, + 'start': round(start, 2), + 'end': round(end, 2), + 'total_duration': round(dur, 2), + 'playlist_id': album_id, + 'podcast_id': episode_id, + 'silence_percent': sil_pct, + 'max_silence_duration': max_sil, + 'is_single_speaker': unique_spk == 1 + }) + return results + + +def process_audio_file(path_audio: str, config: Dict[str, Any]) -> List[Dict]: + limit_dur = config.get('duration', 15) + chunk_duration = config.get('chunk_duration', CHUNK_DURATION_S) + + p_audio = Path(path_audio) + album_id, episode_id = p_audio.parent.name, p_audio.stem + episode_folder = p_audio.parent / episode_id - pieces = [] - n = len(timesteps_starts) - m = len(timesteps_ends) - start_idx = 0 + try: + audio, sr = torchaudio.load(path_audio) + except Exception as e: + logger.error(f"Broken file {path_audio}: {e}") + return [] - while start_idx < n: - left_s = timesteps_starts[start_idx] - end_idx = start_idx - temp_duration = 0.0 + total_audio_duration = audio.shape[-1] / sr - while end_idx < m: - right_s = timesteps_ends[end_idx] - temp_duration = right_s - left_s - if temp_duration > duration: - break - end_idx += 1 - - if end_idx > start_idx: - current_duration = timesteps_ends[end_idx - 1] - left_s - if (current_duration >= duration / 3) and (current_duration <= duration): - pieces.append((start_idx, end_idx - 1)) - - start_idx = max(end_idx, start_idx + 1) - - return pieces - -def cut_audio( - audio: torch.Tensor, - sr: int, - pieces: List[Tuple[int, int]], - 
satrt_timestamps: List[float], - end_timestamps: List[float], - text_segments: List[str], - output_folder: str, - album_id: str, - episode_id: str, - format: str = 'mp3', - duration:float = 15.0 -): try: - os.makedirs(output_folder, exist_ok=True) - for i, (start_idx, end_idx) in enumerate(pieces): + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) - start_time = satrt_timestamps[start_idx] - end_time = end_timestamps[end_idx] + raw_segments = diarize_audio(audio, sr, chunk_duration) + if not raw_segments: return [] - if end_time - start_time <= duration / 3 : - continue + if total_audio_duration <= limit_dur: + sil_pct, max_sil, unique_spk = get_chunk_metrics(0.0, total_audio_duration, raw_segments) + main_spk = raw_segments[0][2] if raw_segments else -1 - start_sample = int(start_time * sr) - end_sample = int(end_time * sr) - end_sample = min(audio.shape[-1], end_sample) - assert end_sample > start_sample - - segment = audio[:, start_sample:end_sample] - output_audio_filename = f"{start_time:.2f}_{end_time:.2f}_{album_id}_{episode_id}.{format}" - output_whisper_filename = f"{start_time:.2f}_{end_time:.2f}_{album_id}_{episode_id}_whisper.txt" - - output_audio_path = os.path.join(output_folder, output_audio_filename) - output_whisper_path = os.path.join(output_folder, output_whisper_filename) - - whisper_text = ' '.join(text_segments[start_idx:end_idx + 1]) - with open(output_whisper_path, 'w', encoding='utf-8') as f: - f.write(whisper_text) + return [{ + 'filepath': os.path.abspath(path_audio), + 'speaker_id': main_spk, + 'start': 0.0, + 'end': round(total_audio_duration, 2), + 'total_duration': round(total_audio_duration, 2), + 'playlist_id': album_id, + 'podcast_id': episode_id, + 'silence_percent': sil_pct, + 'max_silence_duration': max_sil, + 'is_single_speaker': unique_spk == 1 + }] + + clean_segments = filter_single_speaker_segments(raw_segments, min_duration=1.0, max_duration=limit_dur) + final_segments = 
apply_eos_classification(audio, sr, clean_segments, max_duration=limit_dur) + + if not final_segments: return [] - torchaudio.save(output_audio_path, segment, sr) + seg_results = cut_audio(audio, sr, final_segments, raw_segments, str(episode_folder), album_id, episode_id, max_duration=limit_dur) - logger.success(f"The folder has been processed : {output_folder}") + if seg_results: + logger.success(f"Processed {len(seg_results)} chunks from: {p_audio.name}") + if p_audio.exists(): + os.remove(p_audio) + + return seg_results except Exception as e: - logger.error(f"Error : {e}") - raise - -def process_audio_file( - path_audio: str, - duration: float, - beam_size: int -): - album_id = os.path.basename(os.path.dirname(path_audio)) - episode_id = os.path.splitext(os.path.basename(path_audio))[0] - - logger.info(f"Processing: Album={album_id}, Episode={episode_id}") + logger.error(f"Processing error {path_audio}: {e}") + return [] + finally: + if torch.cuda.is_available(): torch.cuda.empty_cache() - episode_folder = os.path.join(os.path.dirname(path_audio), episode_id) - try: - audio, sr = torchaudio.load(path_audio) +def main(args): + load_dotenv() + if hf_key := os.environ.get('HF_TOKEN'): login(token=hf_key) - if audio.shape[-1] / sr <= duration: - return - except Exception as e: - os.remove(path_audio) - logger.info(f"broken file {path_audio}: {e}") - return + config = load_config(args.config_path, 'preprocess') + podcasts_path = Path(config.get('podcasts_path', '../../../podcasts')) + num_workers_per_gpu = config.get('num_workers', 1) - try: - timesteps_starts, timesteps_ends, timestamps_text = get_whisper_segments(model, path_audio, beam_size) - pieces = get_piece_idx(timesteps_starts, timesteps_ends, duration) - - cut_audio( - audio=audio, - sr=sr, - pieces=pieces, - satrt_timestamps=timesteps_starts, - end_timestamps=timesteps_ends, - text_segments=timestamps_text, - output_folder=episode_folder, - album_id=album_id, - episode_id=episode_id, - format='mp3', - 
duration=duration - ) + num_gpus = torch.cuda.device_count() + total_workers = max(1, num_gpus * num_workers_per_gpu) + csv_path = podcasts_path / 'balalaika.csv' + existing_df = pd.read_csv(csv_path) if csv_path.exists() else pd.DataFrame() - except Exception as e: - logger.error(f"Processing error {path_audio}: {e}") - torch.cuda.empty_cache() - - - if len(os.listdir(episode_folder)) > 0 or len(pieces) == 0 : # the audio was cut or we couldn't cut it considering our length - os.remove(path_audio) - logger.info(f"Temporary file deleted: {path_audio}") - -def init_process( - whisper_model: str, - device: str, - compute_type: str = 'float16', - device_index = [0] -): - - global model - model = WhisperModel( - whisper_model, - device=device, - compute_type=compute_type, - device_index=device_index, - ) + raw_audio_paths = get_audio_paths(str(podcasts_path)) + paths_to_process = [] + chunk_pattern = re.compile(r'^\d+\.\d+_\d+\.\d+_') -def main(args): - load_dotenv() - hf_key = os.getenv("HF_TOKEN") - login(token=hf_key) + for p_str in raw_audio_paths: + p = Path(p_str) + if chunk_pattern.match(p.name): continue + paths_to_process.append(p) - config = load_config(args.config_path, 'preprocess') + if not paths_to_process: + logger.info("No new files to process.") + return - podcasts_path = args.podcasts_path if args.podcasts_path else config.get('podcasts_path', '../../../podcasts') - duration = args.duration if args.duration else config.get('duration', 15) - num_workers = args.num_workers if args.num_workers else config.get('num_workers', 4) - whisper_model = args.whisper_model if args.whisper_model else config.get('whisper_model', 'large-v3') - compute_type = args.compute_type if args.compute_type else config.get('compute_type', 'float16') - beam_size = args.beam_size if args.beam_size else config.get('beam_size', 5) - device = 'cuda' - available_gpu_ids = list(range(torch.cuda.device_count())) + logger.info(f"Files to process: {len(paths_to_process)} on {num_gpus} 
GPU(s)") - num_gpus = len(available_gpu_ids) + all_results = [] + files_per_gpu = [[] for _ in range(num_gpus)] if num_gpus > 0 else [paths_to_process] + + if num_gpus > 0: + for i, p in enumerate(paths_to_process): + files_per_gpu[i % num_gpus].append(p) - if num_gpus == 0: - logger.error("No GPUs available. Exiting.") - return + for gpu_id in range(max(1, num_gpus)): + gpu_files = files_per_gpu[gpu_id] + if not gpu_files: continue - audio_paths = get_audio_paths(podcasts_path) - - max_workers = min(num_workers, os.cpu_count()) - - logger.info( - f""" - Using parms - podcasts_path:{podcasts_path} - whisper_model:{whisper_model} - duration:{duration} - num_workers:{num_workers} - devices:{available_gpu_ids} - compute_type:{compute_type} - beam_size:{beam_size} - """) - - with ProcessPoolExecutor( - max_workers=max_workers, - initializer=init_process, - initargs=(whisper_model, device, compute_type, available_gpu_ids, ) + logger.info(f"GPU:{gpu_id} processing {len(gpu_files)} files...") + + with ProcessPoolExecutor( + max_workers=num_workers_per_gpu, + initializer=init_models, + initargs=(gpu_id, config) ) as executor: - futures = [ - executor.submit(process_audio_file, path_audio, duration, beam_size) - for path_audio in audio_paths + futures = [executor.submit(process_audio_file, str(p), config) for p in gpu_files] + for future in tqdm(as_completed(futures), total=len(gpu_files), desc=f"GPU {gpu_id}"): + try: + res = future.result() + if res: all_results.extend(res) + except Exception as e: + logger.error(f"Task error: {e}") + + if all_results: + new_df = pd.DataFrame(all_results) + + if not existing_df.empty: + df = pd.concat([existing_df, new_df], ignore_index=True).drop_duplicates(subset=['filepath'], keep='last') + else: + df = new_df + + base_cols = [ + 'filepath', 'speaker_id', 'start', 'end', 'total_duration', + 'playlist_id', 'podcast_id', 'silence_percent', + 'max_silence_duration', 'is_single_speaker' ] - for future in tqdm(as_completed(futures), 
total=len(futures)): - try: - future.result() - except Exception as e: - logger.error(f"Error processing file: {e}") + + cols = [c for c in base_cols if c in df.columns] + [c for c in df.columns if c not in base_cols] + df = df[cols] + + df.to_csv(csv_path, index=False) + logger.success(f"Successfully processed {len(all_results)} samples. Metadata saved to {csv_path}") if __name__ == "__main__": - torchaudio.set_audio_backend('soundfile') multiprocessing.set_start_method('spawn', force=True) - parser = argparse.ArgumentParser(description="Process audio files using whisper model.") - parser.add_argument( - "--config_path", - help="Path to YAML configuration file", - type=str, - ) - parser.add_argument( - "--podcasts_path", - help="Path to podcasts folder", - type=str, - ) - parser.add_argument( - "--whisper_model", - help="name of model", - type=str, - ) - parser.add_argument( - "--compute_type", - help="compute type", - type=str, - ) - parser.add_argument( - "--beam_size", - help="beam size", - type=int, - ) - parser.add_argument( - "--duration", - help="Duration in seconds", - type=int, - ) - parser.add_argument( - "--num_workers", - help="Number of workers", - type=int, - ) - - args = parser.parse_args() - main(args) \ No newline at end of file + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to YAML config file") + main(parser.parse_args()) \ No newline at end of file diff --git a/src/preprocess/preprocess_args.sh b/src/preprocess/preprocess_args.sh deleted file mode 100644 index 3061350..0000000 --- a/src/preprocess/preprocess_args.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! 
-f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - - -activate_venv ".dev_venv" - -SCRIPT_DIR=$(dirname "$(realpath "$0")") - -PODCASTS_PATH="../../../balalaika" -DURATION=15 -NUM_WORKERS=2 -WHISPER_MODEL="large-v3" -COMPUTE_TYPE="float16" -BEAM_SIZE=5 - -python3 -m src.preprocess.preprocess \ - --podcasts_path "$PODCASTS_PATH" \ - --duration "$DURATION" \ - --num_workers "$NUM_WORKERS" \ - --whisper_model "$WHISPER_MODEL" \ - --compute_type "$COMPUTE_TYPE" \ - --beam_size "$BEAM_SIZE" - - diff --git a/src/preprocess/preprocess_audio.py b/src/preprocess/preprocess_audio.py new file mode 100644 index 0000000..064dbd3 --- /dev/null +++ b/src/preprocess/preprocess_audio.py @@ -0,0 +1,215 @@ +import argparse +import torch +import torch.multiprocessing as mp +import torchaudio +from pathlib import Path +from typing import List + +import numpy as np +import pyloudnorm as pyln +from loguru import logger +from tqdm import tqdm + +import soundfile as sf + +from src.utils.utils import load_config, get_audio_paths + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.enable_flash_sdp(True) +torch.backends.cuda.enable_mem_efficient_sdp(True) +torch.backends.cuda.enable_math_sdp(False) + + +def normalize_audio_loudness( + audio: np.ndarray, + rate: int, + peak: float = -1.0, + loudness: float = -23.0, + block_size: float = 0.400 +) -> np.ndarray: + """ + Perform loudness normalization (ITU-R BS.1770-4) on audio data. + + Args: + audio: audio data array + rate: sample rate + peak: peak normalize audio to N dB. Defaults to -1.0. + loudness: loudness normalize audio to N dB LUFS. Defaults to -23.0. + block_size: block size for loudness measurement in seconds. Defaults to 0.400 (400 ms). 
+ + Returns: + loudness normalized audio array + """ + # Peak normalize audio to [peak] dB + audio = pyln.normalize.peak(audio, peak) + + # Measure the loudness first + meter = pyln.Meter(rate, block_size=block_size) # create BS.1770 meter + _loudness = meter.integrated_loudness(audio) + + return pyln.normalize.loudness(audio, _loudness, loudness) + + +def process_audio_file( + audio_path: str, + peak: float, + loudness: float, + block_size: float +): + """ + Process a single audio file: normalize loudness and overwrite the original file. + + Args: + audio_path: Path to the audio file to process + peak: Peak normalization level in dB + loudness: Target loudness level in LUFS + block_size: Block size for loudness measurement in seconds + """ + try: + audio, sample_rate = torchaudio.load_with_torchcodec(audio_path) + audio_np = audio.numpy() + + # torchaudio returns (channels, samples), pyloudnorm expects (samples,) or (samples, channels≤5) + if audio_np.shape[0] == 1: + audio_np = audio_np.squeeze(0) + else: + audio_np = audio_np.T + + normalized_audio = normalize_audio_loudness( + audio_np, + sample_rate, + peak=peak, + loudness=loudness, + block_size=block_size + ) + + # Convert back to torchaudio format (channels, samples) + if normalized_audio.ndim == 1: + normalized_audio = normalized_audio[np.newaxis, :] + else: + normalized_audio = normalized_audio.T + + torchaudio.save(audio_path, torch.from_numpy(normalized_audio), sample_rate) + + logger.debug(f"Normalized: {audio_path}") + + except Exception as e: + logger.error(f"Error processing {audio_path}: {e}") + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def run_worker( + rank: int, + world_size: int, + all_file_paths: List[Path], + peak: float, + loudness: float, + block_size: float +): + """ + Worker function for processing files on a specific GPU/process. 
+ + Args: + rank: Process rank (0 to world_size-1) + world_size: Total number of processes + all_file_paths: List of all audio file paths to process + peak: Peak normalization level in dB + loudness: Target loudness level in LUFS + block_size: Block size for loudness measurement in seconds + """ + if not all_file_paths: + return + + # Distribute files across workers + my_files = all_file_paths[rank::world_size] + + if not my_files: + return + + logger.info(f"Worker {rank}/{world_size} processing {len(my_files)} files") + + for file_path in tqdm(my_files, desc=f"Worker-{rank}", position=rank): + process_audio_file( + str(file_path), + peak, + loudness, + block_size + ) + + +def main(args): + """ + Main function to normalize audio loudness for all audio files in the specified directory. + """ + config = load_config(args.config_path, 'preprocess') + + podcasts_path = config.get('podcasts_path') + if not podcasts_path: + # Key is missing or empty: fall back explicitly (a second .get() would return the same empty value) + podcasts_path = '../../../podcasts' + logger.warning("Using default podcasts_path") + + # Loudness normalization parameters + peak = config.get('peak', -1.0) + loudness = config.get('loudness', -23.0) + block_size = config.get('block_size', 0.400) + num_workers = config.get('num_workers', 4) + + # Use CPU workers (loudness normalization doesn't require GPU) + # But we can use multiple CPU cores + num_processes = num_workers + + logger.info(f""" + Running loudness normalization: + Podcasts path: {podcasts_path} + Peak normalization: {peak} dB + Target loudness: {loudness} LUFS + Block size: {block_size} seconds + Number of processes: {num_processes} + """) + + # Get all audio files + audio_paths = get_audio_paths(podcasts_path) + if not audio_paths: + logger.info("No audio files found for processing.") + return + + logger.info(f"Found {len(audio_paths)} audio files to process") + + # Process files using torch multiprocessing + if num_processes > 1: + mp.spawn( + run_worker, + args=(num_processes, audio_paths, peak, loudness, 
block_size), + nprocs=num_processes, + join=True + ) + else: + # Single process mode + for file_path in tqdm(audio_paths, desc="Normalizing loudness"): + process_audio_file( + str(file_path), + peak, + loudness, + block_size + ) + + logger.info("All files have been processed and normalized.") + + +if __name__ == "__main__": + mp.set_start_method('spawn', force=True) + + parser = argparse.ArgumentParser( + description="Normalize audio loudness (ITU-R BS.1770-4) for all audio files in the dataset." + ) + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Path to YAML configuration file" + ) + args = parser.parse_args() + + main(args) diff --git a/src/preprocess/preprocess_yaml.sh b/src/preprocess/preprocess_yaml.sh index acaf32d..adb2f1e 100644 --- a/src/preprocess/preprocess_yaml.sh +++ b/src/preprocess/preprocess_yaml.sh @@ -21,4 +21,6 @@ activate_venv ".dev_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.preprocess.preprocess --config_path "$CONFIG_PATH" \ No newline at end of file +python3 -m src.preprocess.preprocess --config_path "$CONFIG_PATH" +python3 -m src.preprocess.crest_factor_remover --config_path "$CONFIG_PATH" +python3 -m src.preprocess.preprocess_audio --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/preprocess/sortformer_onnx.py b/src/preprocess/sortformer_onnx.py new file mode 100644 index 0000000..c5c78f3 --- /dev/null +++ b/src/preprocess/sortformer_onnx.py @@ -0,0 +1,398 @@ +import numpy as np +import onnxruntime as ort +import librosa +from typing import List, Tuple, Union +import time +import os +import huggingface_hub + +# Model constants +N_FFT = 512 +WIN_LENGTH = 400 +HOP_LENGTH = 160 +N_MELS = 128 +PREEMPH = 0.97 +LOG_ZERO_GUARD = 5.9604645e-8 +SAMPLE_RATE = 16000 + +# Streaming constants defaults +CHUNK_LEN = 124 +FIFO_LEN = 124 +SPKCACHE_LEN = 188 +RIGHT_CONTEXT = 1 +SUBSAMPLING = 8 +EMB_DIM = 512 +NUM_SPEAKERS = 4 +FRAME_DURATION = 0.08 + +# Cache compression params 
+SPKCACHE_SIL_FRAMES_PER_SPK = 3 +PRED_SCORE_THRESHOLD = 0.25 +STRONG_BOOST_RATE = 0.75 +WEAK_BOOST_RATE = 1.5 +MIN_POS_SCORES_RATE = 0.5 +SIL_THRESHOLD = 0.2 +MAX_INDEX = 999999 + + +class DiarizationConfig: + def __init__(self, onset=0.5, offset=0.5, pad_onset=0.0, pad_offset=0.0, + min_duration_on=0.0, min_duration_off=0.0, median_window=1): + self.onset = onset + self.offset = offset + self.pad_onset = pad_onset + self.pad_offset = pad_offset + self.min_duration_on = min_duration_on + self.min_duration_off = min_duration_off + self.median_window = median_window + + +class Sortformer: + def __init__(self, model_path: str, config: DiarizationConfig = None, providers: List[str] = None): + if config is None: + self.config = DiarizationConfig() + else: + self.config = config + + if not os.path.exists(model_path): + model_path = huggingface_hub.hf_hub_download(repo_id="altunenes/parakeet-rs", filename="diar_streaming_sortformer_4spk-v2.1.onnx", local_dir="./models") + + if providers is None: + providers = [ + # ( + # "TensorrtExecutionProvider", + # { + # "trt_max_workspace_size": 6 * 1024**3, + # "trt_fp16_enable": True, + # "trt_engine_cache_enable": True, + # "trt_engine_cache_path": "./trt_cache", + # } + # ), + "CUDAExecutionProvider", + "CPUExecutionProvider" + ] + + self.session = ort.InferenceSession(model_path, providers=providers) + + meta = self.session.get_modelmeta().custom_metadata_map + self.chunk_len = int(meta.get("chunk_len", CHUNK_LEN)) + self.fifo_len = int(meta.get("fifo_len", FIFO_LEN)) + self.spkcache_len = int(meta.get("spkcache_len", SPKCACHE_LEN)) + self.right_context = int(meta.get("right_context", RIGHT_CONTEXT)) + + self.mel_basis = librosa.filters.mel(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS, norm='slaney') + + self.reset_state() + + def reset_state(self): + self.spkcache = np.zeros((1, 0, EMB_DIM), dtype=np.float32) + self.spkcache_preds = None + self.fifo = np.zeros((1, 0, EMB_DIM), dtype=np.float32) + self.fifo_preds = np.zeros((1, 
0, NUM_SPEAKERS), dtype=np.float32) + self.mean_sil_emb = np.zeros((1, EMB_DIM), dtype=np.float32) + self.n_sil_frames = 0 + + def extract_mel_features(self, audio: np.ndarray) -> np.ndarray: + preemphasized = np.append(audio[0], audio[1:] - PREEMPH * audio[:-1]) + S = librosa.stft(preemphasized, n_fft=N_FFT, hop_length=HOP_LENGTH, + win_length=WIN_LENGTH, window='hann', center=True) + power_spec = np.abs(S) ** 2 + mel_spec = np.dot(self.mel_basis, power_spec) + log_mel_spec = np.log(mel_spec + LOG_ZERO_GUARD) + return log_mel_spec.T[np.newaxis, :, :].astype(np.float32) + + def diarize(self, audio: np.ndarray, sample_rate: int = 16000, include_tensor_outputs: bool = False) -> Union[List[List[str]], Tuple[List[List[str]], np.ndarray]]: + if sample_rate != SAMPLE_RATE: + audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=SAMPLE_RATE) + + if len(audio.shape) > 1: + audio = np.mean(audio, axis=0) + + self.reset_state() + + features = self.extract_mel_features(audio) + full_preds = self._process_features(features) + + if self.config.median_window > 1: + from scipy.ndimage import median_filter + filtered_preds = median_filter(full_preds, size=(self.config.median_window, 1)) + else: + filtered_preds = full_preds + + audio_duration_sec = len(audio) / SAMPLE_RATE + segments = self._binarize(filtered_preds, audio_duration_sec) + + formatted_result = [segments] + + if include_tensor_outputs: + return formatted_result, full_preds + return formatted_result + + def _process_features(self, features: np.ndarray) -> np.ndarray: + total_frames = features.shape[1] + chunk_stride = self.chunk_len * SUBSAMPLING + feed_size = (self.chunk_len + self.right_context) * SUBSAMPLING + num_chunks = int(np.ceil(total_frames / chunk_stride)) + + all_chunk_preds = [] + + for chunk_idx in range(num_chunks): + start = chunk_idx * chunk_stride + end = min(start + feed_size, total_frames) + current_len = end - start + + chunk_feat = features[:, start:end, :] + + if current_len < 
feed_size: + padded = np.zeros((1, feed_size, N_MELS), dtype=np.float32) + padded[:, :current_len, :] = chunk_feat + chunk_feat = padded + + chunk_preds = self._streaming_update(chunk_feat, current_len) + all_chunk_preds.append(chunk_preds) + + if len(all_chunk_preds) == 0: + return np.zeros((0, NUM_SPEAKERS), dtype=np.float32) + + return np.concatenate(all_chunk_preds, axis=0) + + def _streaming_update(self, chunk_feat: np.ndarray, current_len: int) -> np.ndarray: + spkcache_len = self.spkcache.shape[1] + fifo_len = self.fifo.shape[1] + + inputs = { + "chunk": chunk_feat, + "chunk_lengths": np.array([current_len], dtype=np.int64), + "spkcache": self.spkcache, + "spkcache_lengths": np.array([spkcache_len], dtype=np.int64), + "fifo": self.fifo, + "fifo_lengths": np.array([fifo_len], dtype=np.int64) + } + + outputs = self.session.run(["spkcache_fifo_chunk_preds", "chunk_pre_encode_embs"], inputs) + + preds = outputs[0] + new_embs = outputs[1] + + valid_frames = int(np.ceil(current_len / SUBSAMPLING)) + fifo_preds = preds[:, spkcache_len:spkcache_len+fifo_len, :] if fifo_len > 0 else np.zeros((1, 0, NUM_SPEAKERS)) + + keep = min(self.chunk_len, valid_frames) + chunk_preds_idx_start = spkcache_len + fifo_len + chunk_preds = preds[:, chunk_preds_idx_start:chunk_preds_idx_start+keep, :] + chunk_embs = new_embs[:, :keep, :] + + self.fifo = np.concatenate([self.fifo, chunk_embs], axis=1) + self.fifo_preds = np.concatenate([fifo_preds, chunk_preds], axis=1) + + fifo_len_after = self.fifo.shape[1] + + if fifo_len_after > self.fifo_len: + pop_out_len = max(self.chunk_len, valid_frames - self.fifo_len + fifo_len) + pop_out_len = min(pop_out_len, fifo_len_after) + + pop_out_embs = self.fifo[:, :pop_out_len, :] + pop_out_preds = self.fifo_preds[:, :pop_out_len, :] + + self._update_silence_profile(pop_out_embs[0], pop_out_preds[0]) + + self.fifo = self.fifo[:, pop_out_len:, :] + self.fifo_preds = self.fifo_preds[:, pop_out_len:, :] + + self.spkcache = 
np.concatenate([self.spkcache, pop_out_embs], axis=1) + + if self.spkcache_preds is not None: + self.spkcache_preds = np.concatenate([self.spkcache_preds, pop_out_preds], axis=1) + + if self.spkcache.shape[1] > self.spkcache_len: + if self.spkcache_preds is None: + initial_cache_preds = preds[:, :spkcache_len, :] + self.spkcache_preds = np.concatenate([initial_cache_preds, pop_out_preds], axis=1) + self._compress_spkcache() + + return chunk_preds[0] + + def _update_silence_profile(self, embs: np.ndarray, preds: np.ndarray): + sums = np.sum(preds, axis=1) + sil_mask = sums < SIL_THRESHOLD + if np.any(sil_mask): + sil_embs = embs[sil_mask] + for emb in sil_embs: + self.mean_sil_emb[0] = (self.mean_sil_emb[0] * self.n_sil_frames + emb) / (self.n_sil_frames + 1) + self.n_sil_frames += 1 + + def _compress_spkcache(self): + if self.spkcache_preds is None: return + + n_frames = self.spkcache.shape[1] + per_spk = self.spkcache_len // NUM_SPEAKERS + if per_spk <= SPKCACHE_SIL_FRAMES_PER_SPK: + self.spkcache = self.spkcache[:, :self.spkcache_len, :] + self.spkcache_preds = self.spkcache_preds[:, :self.spkcache_len, :] + return + + spkcache_len_per_spk = per_spk - SPKCACHE_SIL_FRAMES_PER_SPK + strong_boost = int(spkcache_len_per_spk * STRONG_BOOST_RATE) + weak_boost = int(spkcache_len_per_spk * WEAK_BOOST_RATE) + min_pos = int(spkcache_len_per_spk * MIN_POS_SCORES_RATE) + + preds_2d = self.spkcache_preds[0] + scores = self._get_log_pred_scores(preds_2d) + scores = self._disable_low_scores(preds_2d, scores, min_pos) + scores = self._boost_topk_scores(scores, strong_boost, 2.0) + scores = self._boost_topk_scores(scores, weak_boost, 1.0) + + if SPKCACHE_SIL_FRAMES_PER_SPK > 0: + padded = np.full((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), -np.inf, dtype=np.float32) + padded[:n_frames, :] = scores + padded[n_frames:, :] = np.inf + scores = padded + + topk_indices, is_disabled = self._get_topk_indices(scores, n_frames) + new_embs, new_preds = 
self._gather_spkcache(topk_indices, is_disabled) + + self.spkcache = new_embs + self.spkcache_preds = new_preds + + def _get_log_pred_scores(self, preds: np.ndarray) -> np.ndarray: + p = np.maximum(preds, PRED_SCORE_THRESHOLD) + log_p = np.log(p) + log_1_p = np.log(np.maximum(1.0 - preds, PRED_SCORE_THRESHOLD)) + log_1_probs_sum = np.sum(log_1_p, axis=1, keepdims=True) + return log_p - log_1_p + log_1_probs_sum - np.log(0.5) + + def _disable_low_scores(self, preds: np.ndarray, scores: np.ndarray, min_pos: int) -> np.ndarray: + pos_count = np.sum(scores > 0.0, axis=0) + is_speech = preds > 0.5 + is_pos = scores > 0.0 + mask = (~is_speech) | ((~is_pos) & (pos_count >= min_pos)) + scores[mask] = -np.inf + return scores + + def _boost_topk_scores(self, scores: np.ndarray, n_boost: int, scale_factor: float) -> np.ndarray: + for s in range(NUM_SPEAKERS): + col = scores[:, s].copy() + top_idx = np.argsort(col)[::-1][:n_boost] + valid_mask = scores[top_idx, s] != -np.inf + scores[top_idx[valid_mask], s] -= scale_factor * np.log(0.5) + return scores + + def _get_topk_indices(self, scores: np.ndarray, n_frames_no_sil: int) -> Tuple[List[int], List[bool]]: + n_frames = scores.shape[0] + flat_scores = scores.flatten('F') + sorted_flat_idx = np.argsort(flat_scores)[::-1] + + topk_flat = [] + for idx in sorted_flat_idx[:self.spkcache_len]: + if flat_scores[idx] == -np.inf: + topk_flat.append(MAX_INDEX) + else: + topk_flat.append(idx) + topk_flat.sort() + + is_disabled = [False] * self.spkcache_len + frame_indices = [0] * self.spkcache_len + + for i, flat_idx in enumerate(topk_flat): + if flat_idx == MAX_INDEX: + is_disabled[i] = True + else: + frame_idx = flat_idx % n_frames + if frame_idx >= n_frames_no_sil: + is_disabled[i] = True + else: + frame_indices[i] = frame_idx + return frame_indices, is_disabled + + def _gather_spkcache(self, indices: List[int], is_disabled: List[bool]) -> Tuple[np.ndarray, np.ndarray]: + new_embs = np.zeros((1, self.spkcache_len, EMB_DIM), 
dtype=np.float32) + new_preds = np.zeros((1, self.spkcache_len, NUM_SPEAKERS), dtype=np.float32) + cache_preds = self.spkcache_preds[0] + cache_embs = self.spkcache[0] + + for i, (idx, disabled) in enumerate(zip(indices, is_disabled)): + if disabled: + new_embs[0, i, :] = self.mean_sil_emb[0] + elif idx < cache_embs.shape[0]: + new_embs[0, i, :] = cache_embs[idx] + new_preds[0, i, :] = cache_preds[idx] + return new_embs, np.expand_dims(new_preds[0], axis=0) + + def _binarize(self, preds: np.ndarray, audio_duration_sec: float) -> List[str]: + raw_segments = [] + num_frames = preds.shape[0] + + for spk in range(NUM_SPEAKERS): + raw_intervals = [] + in_seg = False + start_t = 0.0 + + for t in range(num_frames): + p = preds[t, spk] + if p >= self.config.onset and not in_seg: + in_seg = True + start_t = t * FRAME_DURATION + elif p < self.config.offset and in_seg: + in_seg = False + raw_intervals.append([start_t, t * FRAME_DURATION, spk]) + + if in_seg: + raw_intervals.append([start_t, num_frames * FRAME_DURATION, spk]) + + if not raw_intervals: + continue + + merged_intervals = [raw_intervals[0]] + for i in range(1, len(raw_intervals)): + gap = raw_intervals[i][0] - merged_intervals[-1][1] + if gap <= self.config.min_duration_off: + merged_intervals[-1][1] = raw_intervals[i][1] + else: + merged_intervals.append(raw_intervals[i]) + + filtered_intervals = [] + for seg in merged_intervals: + if (seg[1] - seg[0]) >= self.config.min_duration_on: + filtered_intervals.append(seg) + + padded_intervals = [] + for seg in filtered_intervals: + start_s = max(0.0, seg[0] - self.config.pad_onset) + end_s = min(audio_duration_sec, seg[1] + self.config.pad_offset) + + if not padded_intervals: + padded_intervals.append([start_s, end_s, spk]) + else: + if start_s <= padded_intervals[-1][1]: + padded_intervals[-1][1] = max(padded_intervals[-1][1], end_s) + else: + padded_intervals.append([start_s, end_s, spk]) + + raw_segments.extend(padded_intervals) + + raw_segments.sort(key=lambda x: 
(x[0], x[2])) + + str_segments = [] + for seg in raw_segments: + str_segments.append(f"{seg[0]} {seg[1]} speaker_{seg[2]}") + + return str_segments + + + +if __name__ == "__main__": + model_path = "/home/nikita/balalaika/models/diar_streaming_sortformer_4spk-v2.1.onnx" + audio_path = "/home/nikita/balalaika/datkamatka/12.mp3" + + audio, sr = librosa.load(audio_path, sr=16000, mono=True) + + config = DiarizationConfig() + diarizer = Sortformer(model_path, config=config) + + start_time = time.time() + # print(audio.shape) + results = diarizer.diarize(audio, sample_rate=16000, include_tensor_outputs=False) + end_time = time.time() + + print(f"RTF: {(end_time - start_time) / (audio.shape[-1] / 16_000):.3f}") + print(results) diff --git a/src/punctuation/README.md b/src/punctuation/README.md index 6d72550..bb112db 100644 --- a/src/punctuation/README.md +++ b/src/punctuation/README.md @@ -1,44 +1,28 @@ -## Usage/Examples +## Punctuation (RUPunct) -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (`punctuation/punctuation_args.sh`) and then run it: -```sh -sh punctuation/punctuation_args.sh -``` +Restores punctuation and capitalization from **`{stem}_rover.txt`**. + +## Run -### Running the Code via Config File -Example: -```sh -bash punctuation/punctuation_yaml.sh config_path +```bash +bash src/punctuation/punctuation_yaml.sh configs/config.yaml ``` -## Explanation of Parameters +## Parameters -- `--config_path`: Path to the YAML configuration file. -- `--podcasts_path`: Root directory containing text files for processing (default: "../../../podcasts"). -- `--model_name`: Name of the punctuation model (default: "RUPunct/RUPunct_big"). -- `--num_workers`: Number of worker processes per GPU for parallel processing (default: 4). +See **`punctuation`** in `configs/config.yaml` (`podcasts_path`, `model_name`, `num_workers`). 
-## Output Structure +## Output -For each transcribed audio file, a new file with restored punctuation will be created: +For each `{stem}_rover.txt`, writes **`{stem}_punct.txt`**. +```text +{podcasts_path}/ +└── {playlist_id}/ + └── {podcast_id}/ + ├── {stem}.mp3 + ├── {stem}_rover.txt # input + └── {stem}_punct.txt # output ``` -podcasts/ -└── {album_id}/ - └── {episode_id}/ - ├── {start_time}_{end_time}_{album_id}_{episode_id}.mp3 - ├── {start_time}_{end_time}_{album_id}_{episode_id}_giga.txt - └── {start_time}_{end_time}_{album_id}_{episode_id}_punct.txt -``` - -### File Descriptions -- `.mp3`: Original audio file -- `_giga.txt`: Transcribed text without punctuation (input file) -- `_punct.txt`: Text with restored punctuation using the RUPunct model - -The script processes all `_giga.txt` files found in the directory structure and creates corresponding `_punct.txt` files with restored punctuation. Processing is done in parallel using available GPUs for better performance. - -## Important Notice -The punctuation and yofication scripts must be executed sequentially! +WebDataset packs this as `punct.txt` inside `json` (`src/to_webdataset.py`). 
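The input/output naming in the README above can be illustrated with a short sketch. This is a minimal example under stated assumptions, not the repository's implementation: `pending_punct_inputs` is a hypothetical helper, and it globs for text files directly, whereas the pipeline script derives the text paths from the audio files it finds.

```python
from pathlib import Path
from typing import List


def pending_punct_inputs(podcasts_path: str) -> List[Path]:
    """Return `{stem}_rover.txt` files that have no `{stem}_punct.txt` yet.

    Hypothetical helper for illustration only; it assumes the
    `{podcasts_path}/{playlist_id}/{podcast_id}/` layout shown above.
    """
    pending = []
    for rover_path in sorted(Path(podcasts_path).rglob("*_rover.txt")):
        # Same suffix swap the stage uses to name its output file.
        punct_path = rover_path.with_name(
            rover_path.name.replace("_rover.txt", "_punct.txt")
        )
        if not punct_path.exists():
            pending.append(rover_path)
    return pending
```

Because files that already have a `_punct.txt` sibling are skipped, re-running the stage after an interruption should only pick up the remaining inputs.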
diff --git a/src/punctuation/punctuation.py b/src/punctuation/punctuation.py index 78075f9..766a9e0 100644 --- a/src/punctuation/punctuation.py +++ b/src/punctuation/punctuation.py @@ -10,7 +10,12 @@ from tqdm import tqdm from transformers import pipeline, AutoTokenizer -from src.utils import load_config, get_audio_paths, process_token, read_file_content +from src.utils.utils import load_config, get_audio_paths, process_token, read_file_content + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.enable_flash_sdp(True) +torch.backends.cuda.enable_mem_efficient_sdp(True) +torch.backends.cuda.enable_math_sdp(False) model = None @@ -39,7 +44,7 @@ def make_punct_txt( src_text = read_file_content(path) - punct_path = path.with_name(path.name.replace("_giga.txt", "_punct.txt")) + punct_path = path.with_name(path.name.replace("_rover.txt", "_punct.txt")) if str(path).endswith('_punct.txt') or str(path).endswith('_accent.txt') or os.path.exists(punct_path): return @@ -61,7 +66,7 @@ def get_valid_txt_paths(src_path: str) -> List[str]: valid_paths = [] for audio_path in all_audio_paths: - giga_path = audio_path.with_name(audio_path.stem + "_giga.txt") + giga_path = audio_path.with_name(audio_path.stem + "_rover.txt") punct_path = audio_path.with_name(audio_path.stem + "_punct.txt") if os.path.exists(giga_path) and not os.path.exists(punct_path): @@ -72,9 +77,9 @@ def get_valid_txt_paths(src_path: str) -> List[str]: def main(args): config = load_config(args.config_path, 'punctuation') - num_workers_per_gpu = args.num_workers if args.num_workers else config.get('num_workers', 4) - model_name = args.model_name if args.model_name else config.get('model_name', 'RUPunct/RUPunct_big') - podcasts_path = args.podcasts_path if args.podcasts_path else config.get('podcasts_path', '../../../balalaika') + num_workers_per_gpu = config.get('num_workers', 4) + model_name = config.get('model_name', 'RUPunct/RUPunct_big') + podcasts_path = config.get('podcasts_path', 
'../../../balalaika') all_text_files = get_valid_txt_paths(podcasts_path) @@ -141,21 +146,6 @@ def main(args): type=str, help="Path to the configuration file" ) - parser.add_argument( - "--podcasts_path", - type=str, - help="Path to the dataset directory containing .txt files" - ) - parser.add_argument( - "--num_workers", - type=int, - help="Number of worker processes per GPU" - ) - parser.add_argument( - "--model_name", - type=str, - help="Hugging Face NER model name for punctuation" - ) args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/punctuation/punctuation_args.sh b/src/punctuation/punctuation_args.sh deleted file mode 100644 index cc095f9..0000000 --- a/src/punctuation/punctuation_args.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! -f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - - -activate_venv ".dev_venv" -SCRIPT_DIR=$(dirname "$(realpath "$0")") - -PODCASTS_PATH="../../../balalaika" -MODEL_NAME="RUPunct/RUPunct_big" -NUM_WORKERS=4 - -python3 -m src.punctuation.punctuation \ - --podcasts_path "$PODCASTS_PATH" \ - --model_name "$MODEL_NAME" \ - --num_workers "$NUM_WORKERS" diff --git a/src/punctuation/punctuation_yaml.sh b/src/punctuation/punctuation_yaml.sh index c298b63..938b383 100644 --- a/src/punctuation/punctuation_yaml.sh +++ b/src/punctuation/punctuation_yaml.sh @@ -20,4 +20,4 @@ CONFIG_PATH=$(realpath "$1") activate_venv ".dev_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.punctuation.punctuation --config_path "$CONFIG_PATH" \ No newline at end of file +taskset -c 0-24 python3 -m src.punctuation.punctuation --config_path "$CONFIG_PATH" \ No newline at end of file diff --git a/src/recovery_from_meta_yamls.sh b/src/recovery_from_meta_yamls.sh index 591034e..2779063 100644 --- a/src/recovery_from_meta_yamls.sh +++ 
b/src/recovery_from_meta_yamls.sh @@ -19,4 +19,4 @@ activate_venv "$VENV_PATH" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.recovery_from_meta --podcasts_path $PODCASTS_PATH --parquet_path $PARQUET_PATH --num_workers $NUM_WORKERS \ No newline at end of file +taskset -c 0-24 python3 -m src.recovery_from_meta --podcasts_path $PODCASTS_PATH --parquet_path $PARQUET_PATH --num_workers $NUM_WORKERS \ No newline at end of file diff --git a/src/separation/README.md b/src/separation/README.md index b62f24a..ba07be1 100644 --- a/src/separation/README.md +++ b/src/separation/README.md @@ -1,45 +1,32 @@ -## Usage/Examples - -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (`separation_args.sh`) and then run it: -~~~ -bash separation/separation_args.sh -~~~ - -### Running the Code via Config File -Example: -~~~ -bash separation/separation_yaml.sh config_path -~~~ - -## Explanation of Parameters - -- `--config_path`: Path to the YAML configuration file. -- `--podcasts_path`: Root directory containing podcast audio files for processing. -- `--one_speaker`: Boolean flag to indicate if only one speaker is expected per audio file (default: True). -- `--num_workers`: Number of parallel processes for audio processing (default: 4). - -## Output Structure - -After running the script, a results.csv file will be created in the specified `podcasts_path` directory: - -~~~ -podcasts/ -└── {album_id}/ -| └── {episode_id}/ -| .... 
-└── results.csv -~~~ - -The `results.csv` file contains the following information for each processed audio file: -- `audio_path`: Path to the audio file relative to the podcasts directory -- `is_mono`: Boolean indicating if the file contains a single speaker -- `NOI`: Noise metric score -- `COL`: Coloration metric score -- `DISC`: Discontinuity metric score -- `LOUD`: Loudness metric score -- `MOS`: Mean Opinion Score -- `playlist_id`: ID of the playlist -- `podcast_id`: ID of the podcast -- `start`: Start time of the segment -- `end`: End time of the segment +## Overview + +Quality filtering on chunked clips: + +1. **Music detection** — WavLM backbone + fine-tuned head at `music_detect.music_detect_model`. For every processed clip the model's music probability is written to `balalaika.csv` as `music_prob`. Clips above the threshold are deleted from disk and their rows removed from the CSV. +2. **DistillMOS** — speech quality score written to `balalaika.csv` as `DistillMOS`. Runs only on files not yet scored; skips if all entries are already present. + +Speaker diarization is handled in **preprocess** (Sortformer), not here. + +## Run + +```bash +bash src/separation/separation_yaml.sh configs/config.yaml +``` + +## Parameters + +Documented under **`separation`** and **`separation.music_detect`** in `configs/config.yaml` (`podcasts_path`, `bs`, `num_workers`, `music_detect_model`, `threshold`, optional `base_model` / `cache_path`). + +## `balalaika.csv` columns added here + +| Column | Description | +|--------|-------------| +| `music_prob` | Music classifier probability (0–1). Row removed if file deleted. | +| `DistillMOS` | Predicted MOS score. | + +## Result + +- Music-heavy chunks removed; CSV rows for deleted files are also removed. +- `balalaika.csv` updated with `music_prob` and `DistillMOS`; parallel runs use partial CSVs for safety. + +For merged fields in exported WebDataset `json`, see [`example/README.md`](../../example/README.md). 
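The music-filtering rule described in the README above can be sketched in a few lines. This is a simplified illustration, not the module's code: `split_by_music_prob` is a hypothetical helper, rows are assumed to be plain dicts with `filepath` and `music_prob` keys, and the default threshold of 0.5 with a strict `>` comparison mirrors what `music_detect.py` in this diff appears to use. The actual stage deletes the audio files and then drops CSV rows whose files no longer exist, rather than filtering on the probability column directly.

```python
from typing import Dict, List, Tuple

Row = Dict[str, object]


def split_by_music_prob(
    rows: List[Row], threshold: float = 0.5
) -> Tuple[List[Row], List[Row]]:
    """Partition rows into (kept, removed) by music probability.

    Hypothetical helper: a row is removed only when its `music_prob`
    is strictly greater than `threshold`.
    """
    kept: List[Row] = []
    removed: List[Row] = []
    for row in rows:
        if float(row["music_prob"]) > threshold:
            removed.append(row)
        else:
            kept.append(row)
    return kept, removed
```

A clip scoring exactly at the threshold is kept under this rule, which may matter when choosing a `threshold` value in `configs/config.yaml`.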
diff --git a/src/separation/distillmos_process.py b/src/separation/distillmos_process.py new file mode 100644 index 0000000..81688f6 --- /dev/null +++ b/src/separation/distillmos_process.py @@ -0,0 +1,184 @@ +import argparse +import os +import re +import torch +import torch.multiprocessing as mp +import pandas as pd +import torchaudio +from pathlib import Path +from typing import List, Dict +from loguru import logger +from tqdm import tqdm + +from src.utils.utils import get_audio_paths, load_config + +torch.backends.cuda.matmul.allow_tf32 = True + +def save_chunk(results: List[Dict], output_path: Path): + """Saves a chunk of results to CSV, appending if exists.""" + if not results: + return + df = pd.DataFrame(results) + output_path.parent.mkdir(parents=True, exist_ok=True) + header = not output_path.exists() + df.to_csv(output_path, mode='a', header=header, index=False) + +def run_inference_worker(rank: int, world_size: int, file_paths: List[str], config: dict, final_output_path: Path): + """ + Worker function running on a dedicated GPU for DistillMOS. 
+ """ + my_files = file_paths[rank::world_size] + if not my_files: + logger.info(f"Worker {rank}: No files to process.") + return + + device = torch.device(f"cuda:{rank}") + worker_output_path = final_output_path.with_name(f"distillmos_part_{rank}.csv") + + logger.info(f"[cuda:{rank}] Loading DistillMOS model...") + try: + import distillmos + sqa_model = distillmos.ConvTransformerSQAModel() + sqa_model.to(device) + sqa_model.eval() + except Exception as e: + logger.error(f"Failed to load distillmos model on worker {rank}: {e}") + return + + results_buffer = [] + save_every = 500 + + logger.info(f"[cuda:{rank}] Starting inference for {len(my_files)} files.") + + for path_str in tqdm(my_files, desc=f"DistillMOS-{rank}", position=rank): + try: + x, sr = torchaudio.load(path_str) + + if x.shape[0] > 1: + x = x[0, None, :] + + if sr != 16000: + resampler = torchaudio.transforms.Resample(sr, 16000).to(device) + x = resampler(x.to(device)) + else: + x = x.to(device) + + with torch.no_grad(): + mos = sqa_model(x) + mos_val = mos[0].item() + + results_buffer.append({ + 'filepath': path_str, + 'DistillMOS': mos_val + }) + + except Exception as e: + logger.warning(f"Error processing {path_str}: {e}") + continue + + if len(results_buffer) >= save_every: + save_chunk(results_buffer, worker_output_path) + results_buffer = [] + + save_chunk(results_buffer, worker_output_path) + logger.success(f"[cuda:{rank}] Finished.") + +def combine_results(final_output_path: Path, num_parts: int): + """Merges partial CSVs into the main balalaika.csv securely.""" + logger.info("Combining DistillMOS results...") + dfs = [] + for i in range(num_parts): + part_path = final_output_path.with_name(f"distillmos_part_{i}.csv") + if part_path.exists(): + try: + dfs.append(pd.read_csv(part_path)) + os.remove(part_path) + except Exception as e: + logger.error(f"Error reading {part_path}: {e}") + + if not dfs: + logger.warning("No DistillMOS results found to merge.") + return + + new_df = pd.concat(dfs, 
ignore_index=True) + + if final_output_path.exists(): + logger.info(f"Safely merging with existing CSV: {final_output_path}") + main_df = pd.read_csv(final_output_path) + + main_df.set_index('filepath', inplace=True) + new_df.set_index('filepath', inplace=True) + + main_df = main_df.combine_first(new_df).reset_index() + else: + main_df = new_df + + if 'is_single_speaker' in main_df.columns: + main_df.drop(columns=['is_single_speaker'], inplace=True) + + base_cols = ['filepath', 'speaker_id', 'start', 'end', 'total_duration', + 'playlist_id', 'podcast_id', 'silence_percent', 'max_silence_duration', 'DistillMOS'] + final_cols = [c for c in base_cols if c in main_df.columns] + [c for c in main_df.columns if c not in base_cols] + + main_df[final_cols].to_csv(final_output_path, index=False) + logger.success(f"Combined successfully. Total rows: {len(main_df)}") + +def get_unprocessed_paths(podcasts_path: Path, result_csv_path: Path) -> List[str]: + """Finds all audio files that haven't been processed by DistillMOS yet.""" + all_audio_paths = [str(Path(p).resolve()) for p in get_audio_paths(str(podcasts_path))] + + if not result_csv_path.exists(): + return all_audio_paths + + try: + df = pd.read_csv(result_csv_path) + if 'DistillMOS' not in df.columns: + return all_audio_paths + + processed = set( + df.dropna(subset=['DistillMOS'])['filepath'] + .apply(lambda p: str(Path(p).resolve())) + .tolist() + ) + + return [p for p in all_audio_paths if p not in processed] + except Exception as e: + logger.warning(f"Could not read CSV to filter paths: {e}. 
Processing all chunks.") + return all_audio_paths + +def main(): + mp.set_start_method('spawn', force=True) + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True) + args = parser.parse_args() + + config = load_config(args.config_path, 'separation') + podcasts_path = Path(config.get('podcasts_path', '.')) + final_output_path = podcasts_path / 'balalaika.csv' + + available_gpus = torch.cuda.device_count() + if available_gpus == 0: + logger.error("No GPU detected.") + return + + unprocessed = get_unprocessed_paths(podcasts_path, final_output_path) + if not unprocessed: + logger.success("All small audio files already have a DistillMOS score. Exiting.") + return + + logger.info(f"Processing {len(unprocessed)} files on {available_gpus} GPUs.") + + try: + mp.spawn( + run_inference_worker, + args=(available_gpus, unprocessed, config, final_output_path), + nprocs=available_gpus, + join=True + ) + except Exception as e: + logger.critical(f"Multiprocessing failed: {e}") + + combine_results(final_output_path, available_gpus) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/separation/music_detect.py b/src/separation/music_detect.py new file mode 100644 index 0000000..964c0d5 --- /dev/null +++ b/src/separation/music_detect.py @@ -0,0 +1,165 @@ +import argparse +import os +import torch +import torch.multiprocessing as mp +import pandas as pd +from pathlib import Path +from typing import List +from loguru import logger +from tqdm import tqdm +from safetensors import safe_open +from transformers import AutoFeatureExtractor +from torch.utils.data import DataLoader + +from musicdetection.audio_cache import create_audio_length_cache +from musicdetection.dataset import MusicDetectionDataset, AudioCollate +from musicdetection.core.model import WavLMForMusicDetection +from musicdetection.audio_sampler import LengthBasedBatchSampler +from src.utils.utils import get_audio_paths, load_config + 
+torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.enable_flash_sdp(True) + +def create_loader(paths: List[str], model_name: str, batch_size: int, num_workers: int, cache_file: Path): + audio_lengths = create_audio_length_cache(file_paths=paths, cache_file=str(cache_file)) + processor = AutoFeatureExtractor.from_pretrained(model_name) + dataset = MusicDetectionDataset(file_paths=paths, target_sample_rate=processor.sampling_rate) + sampler = LengthBasedBatchSampler(paths, audio_lengths, batch_size=batch_size, shuffle=False) + return DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=AudioCollate(processor), + num_workers=num_workers, + pin_memory=True + ) + +def load_model(model_path: str, base_model: str, device: torch.device): + model = WavLMForMusicDetection(base_model_name=base_model) + with safe_open(model_path, framework="pt", device="cpu") as f: + model.load_state_dict({k: f.get_tensor(k) for k in f.keys()}) + model = model.to(device).eval() + model.device = device + return model + +def run_worker(rank: int, world_size: int, all_paths: List[str], config: dict): + my_paths = all_paths[rank::world_size] + if not my_paths: + return + + device = torch.device(f"cuda:{rank}") + cfg = config.get('music_detect', {}) + podcasts_path = Path(config.get('podcasts_path', '.')) + + threshold = cfg.get('threshold', 0.5) + cache_dir = Path(cfg.get('cache_path', './cache')) / f'nisqa_temp_worker_{rank}' + cache_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"[{device}] Processing {len(my_paths)} files...") + + try: + dataloader = create_loader( + my_paths, + cfg.get('base_model', 'microsoft/wavlm-base-plus'), + cfg.get('bs', 32), + cfg.get('num_workers', 4), + cache_dir / 'audio_lengths.json' + ) + + model = load_model( + cfg.get('music_detect_model'), + cfg.get('base_model', 'microsoft/wavlm-base-plus'), + device + ) + + probs, paths = model.predict_proba(dataloader) + + results = [] + deleted_count = 0 + for path, prob in zip(paths, 
probs.detach().flatten()): + prob_val = round(float(prob), 6) + results.append({'filepath': str(Path(path).resolve()), 'music_prob': prob_val}) + if prob_val > threshold: + try: + os.remove(path) + deleted_count += 1 + except OSError as e: + logger.warning(f"Could not delete {path}: {e}") + + if results: + part_path = podcasts_path / f'music_part_{rank}.csv' + pd.DataFrame(results).to_csv(part_path, index=False) + + logger.success(f"[{device}] Done. Deleted {deleted_count}/{len(my_paths)} files.") + + except Exception as e: + logger.exception(f"Worker {rank} error: {e}") + + +def update_csv(podcasts_path: Path, n_gpus: int): + csv_path = podcasts_path / 'balalaika.csv' + parts = [podcasts_path / f'music_part_{i}.csv' for i in range(n_gpus)] + existing_parts = [p for p in parts if p.exists()] + + if not existing_parts: + logger.warning("No music_part_*.csv files found; skipping CSV update.") + if csv_path.exists(): + df = pd.read_csv(csv_path) + before = len(df) + df = df[df['filepath'].apply(lambda p: Path(p).exists())] + if before != len(df): + df.to_csv(csv_path, index=False) + logger.info(f"Removed {before - len(df)} missing rows from CSV.") + return + + results_df = pd.concat([pd.read_csv(p) for p in existing_parts], ignore_index=True) + for p in existing_parts: + p.unlink() + + if not csv_path.exists(): + logger.warning(f"balalaika.csv not found at {csv_path}; skipping CSV update.") + return + + df = pd.read_csv(csv_path) + df['filepath'] = df['filepath'].apply(lambda p: str(Path(p).resolve())) + results_df['filepath'] = results_df['filepath'].apply(lambda p: str(Path(p).resolve())) + + if 'music_prob' in df.columns: + df = df.drop(columns=['music_prob']) + df = df.merge(results_df[['filepath', 'music_prob']], on='filepath', how='left') + + before = len(df) + df = df[df['filepath'].apply(lambda p: Path(p).exists())] + removed = before - len(df) + logger.info(f"Music detection: removed {removed} rows from CSV (files deleted).") + + df.to_csv(csv_path, 
index=False) + logger.success(f"CSV updated: {len(df)} rows remain.") + + +def main(args): + mp.set_start_method('spawn', force=True) + config = load_config(args.config_path, 'separation') + podcasts_path = config.get('podcasts_path') + + if not podcasts_path: + logger.error("No podcasts_path in config") + return + + all_paths = list(get_audio_paths(podcasts_path)) + n_gpus = torch.cuda.device_count() + + if not all_paths: + logger.warning("No audio files found.") + return + + if n_gpus == 0: + logger.error("No GPU found.") + return + + mp.spawn(run_worker, args=(n_gpus, all_paths, config), nprocs=n_gpus, join=True) + update_csv(Path(podcasts_path), n_gpus) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True) + main(parser.parse_args()) diff --git a/src/separation/separation.py b/src/separation/separation.py deleted file mode 100644 index 3cdca60..0000000 --- a/src/separation/separation.py +++ /dev/null @@ -1,299 +0,0 @@ -import argparse -import os -import multiprocessing -from concurrent.futures import ProcessPoolExecutor, as_completed -from pathlib import Path -from typing import Dict, List -import yaml - -import numpy as np -import pandas as pd -import torch -import torchaudio -from dotenv import load_dotenv -from loguru import logger -from pyannote.audio import Pipeline -from tqdm import tqdm - -from src.libs.nisqa.core.model_torch import model_init -from src.libs.nisqa.utils.process_utils import process -from src.utils import load_config, get_audio_paths - -_global_worker = None - -class Worker: - def __init__( - self, - one_speaker: bool, - nisqa_config_path: str, - gpu_id: int, - hf_token: str, - ): - self.one_speaker = one_speaker - self.nisqa_config_path = nisqa_config_path - self.hf_token = hf_token - self.device = f"cuda:{gpu_id}" - self._init_models() - - def _init_models(self): - - torch.cuda.set_device(self.device) - - try: - with open(self.nisqa_config_path, "r") as f: - 
args_yaml = yaml.load(f, Loader=yaml.FullLoader) - - self.nisqa_device = self.device - args = {**args_yaml, "inf_device": self.device} - - self.nisqa_model, self.h0, self.c0 = model_init(args) - self.nisqa_args = args - except Exception as e: - logger.warning(f"{e} on {self.device}.") - - - try: - self.diarization_model = Pipeline.from_pretrained( - "pyannote/speaker-diarization-3.1", - use_auth_token=self.hf_token - ).to(torch.device(self.device)) - except Exception as e: - logger.warning(f"{e} on {self.device}.") - - def process_audio(self, audio_path: str) -> Dict: - audio_path = Path(audio_path) - frame_duration = self.nisqa_args.get("frame") - audio_frames, sr, audio = self._preprocess_audio(str(audio_path), frame_duration) - - is_mono = True - is_mono = self._check_single_speaker(audio, sr) - - avg_out = self._nisqua_predict(audio_frames, sr) - NOI, COL, DISC, LOUD, MOS = avg_out - - file_parts = audio_path.name.split('_') - playlist_id = file_parts[-2] if len(file_parts) > 0 else 'N/A' - podcast_id = file_parts[-1].split('.')[0] if len(file_parts) > 1 else 'N/A' - - return { - 'audio_path': '/'.join(audio_path.parts[-3:]), - 'is_mono': is_mono, - 'NOI': NOI, - 'COL': COL, - 'DISC': DISC, - 'LOUD': LOUD, - 'MOS': MOS, - 'playlist_id': playlist_id, - 'podcast_id': podcast_id, - 'start': file_parts[0] if len(file_parts) > 0 else 'N/A', - 'end': file_parts[1] if len(file_parts) > 1 else 'N/A' - } - - - def _preprocess_audio(self, audio_path: str, frame_duration: int): - audio, sr = torchaudio.load(audio_path) - if audio.shape[0] != 1: - audio = torch.mean(audio, dim=0, keepdim=True) - audio = audio.squeeze(0) - - audio = audio.to(self.device) - frame_size = int(sr * frame_duration) - if len(audio) % frame_size != 0: - padding = frame_size - (len(audio) % frame_size) - audio = torch.cat([audio, torch.zeros(padding, device=self.device)]) - frames = torch.split(audio, frame_size) - - return frames, sr, audio - - def _check_single_speaker(self, waveform: 
torch.Tensor, sr: int) -> bool: - try: - diarization = self.diarization_model({ - "waveform": waveform.unsqueeze(0), - "sample_rate": sr - }) - is_single_speaker = len({speaker for _, _, speaker in diarization.itertracks(yield_label=True)}) == 1 - - if not is_single_speaker and self.one_speaker: - audio_path = str(waveform.audio_path) - base_path = os.path.splitext(audio_path)[0] - - for ext in ['.mp3', '_giga.txt', '_punct.txt', '_accent.txt', '_e.txt', '_e_phonemes.txt']: - file_path = base_path + ext - if os.path.exists(file_path): - os.remove(file_path) - logger.info(f"Deleted {file_path} due to multiple speakers detected") - - return is_single_speaker - - except Exception as e: - logger.error(f"Diarization error on {self.device}: {e}") - return True - - - def _nisqua_predict(self, frames: List[torch.Tensor], sr: int) -> np.ndarray: - - outputs = [] - h, c = self.h0.clone().to(self.device), self.c0.clone().to(self.device) - - for frame in frames: - out, h, c = process(frame.to(self.device), sr, self.nisqa_model, h, c, self.nisqa_args) - outputs.append(out[0].cpu().numpy()) - - return np.mean(outputs, axis=0) - - -def _worker_initializer( - one_speaker: bool, - nisqa_config_path: str, - hf_token: str, - gpu_id_assignment_queue - ): - global _global_worker - gpu_id = None - try: - if not gpu_id_assignment_queue.empty(): - gpu_id = gpu_id_assignment_queue.get() - - _global_worker = Worker( - one_speaker=one_speaker, - nisqa_config_path=nisqa_config_path, - gpu_id=gpu_id, - hf_token=hf_token - ) - except Exception as e: - logger.error(f"Failed to initialize worker process on {f'GPU {gpu_id}' if gpu_id is not None else 'CPU'}: {e}") - -def _process_audio_task(audio_path: str) -> Dict: - global _global_worker - - if _global_worker is None: - logger.error(f"Worker not initialized in this process. 
Cannot process {audio_path}.") - try: - result = _global_worker.process_audio(audio_path) - return result - except Exception as e: - logger.error(f"Error during audio processing for {audio_path} by worker on {_global_worker.device}: {e}") - return None - finally: - if _global_worker and _global_worker.device.startswith("cuda"): - torch.cuda.empty_cache() - -def main(args): - load_dotenv() - hf_token = os.getenv("HF_TOKEN") - config = load_config(args.config_path, 'separation') - - podcasts_path = args.podcasts_path if args.podcasts_path else config.get('podcasts_path', '/../../../balalaika') - one_speaker = args.one_speaker if args.one_speaker else config.get('one_speaker', False) - num_workers = args.num_workers if args.num_workers else config.get('num_workers', 4) - - nisqa_config_path = args.nisqa_config if args.nisqa_config else config.get('nisqa_config', '') - num_gpus = torch.cuda.device_count() - - if num_gpus == 0: - logger.error("No GPUs available. Exiting.") - return - - actual_max_workers = num_gpus - gpu_ids_available = list(range(num_gpus)) - - logger.info(f""" - Using params: - Podcasts path: {podcasts_path} - One speaker : {one_speaker} - num_workers : {num_workers} - Number of GPUs detected: {num_gpus} - """) - - audio_paths = get_audio_paths(podcasts_path) - result_csv_path = Path(podcasts_path) / 'results.csv' - - - if os.path.exists(result_csv_path): - logger.info('csv file exists') - df = pd.read_csv(result_csv_path) - processed_audio_paths = set(df['audio_path'].tolist()) - else: - processed_audio_paths = set() - logger.info(f'csv does not exist: found {len(audio_paths) - len(processed_audio_paths)} files') - - audio_paths_to_process = [ - audio_path for audio_path in audio_paths - if str('/'.join(Path(audio_path).parts[-3:])) not in processed_audio_paths - ] - - - if not audio_paths: - logger.error(f"No audio files found in {podcasts_path}") - return - - results = [] - num_workers_per_gpu = num_workers - manager = multiprocessing.Manager() - 
gpu_id_assignment_queue = manager.Queue() - - for gid in gpu_ids_available: - for _ in range(num_workers_per_gpu): - gpu_id_assignment_queue.put(gid) - - actual_max_workers = len(gpu_ids_available) * num_workers_per_gpu - - if actual_max_workers == 0: - actual_max_workers = 1 - if gpu_id_assignment_queue.empty(): - gpu_id_assignment_queue.put(None) - - with ProcessPoolExecutor( - max_workers=actual_max_workers, - mp_context=multiprocessing.get_context('spawn'), - initializer=_worker_initializer, - initargs=(one_speaker, nisqa_config_path, hf_token, gpu_id_assignment_queue) - ) as executor: - futures = [executor.submit(_process_audio_task, str(path)) for path in audio_paths_to_process] - - with tqdm(total=len(audio_paths_to_process), desc="Processing files") as pbar: - for future in as_completed(futures): - result = future.result() - if result: - results.append(result) - pbar.update(1) - - csv_path = Path(podcasts_path) / "results.csv" - pd.DataFrame(results).to_csv( - csv_path, - mode='a', - header=not csv_path.exists(), - index=False - ) - logger.success(f"Processing completed successfully. Results saved to {result_csv_path}") - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Process audio files using multiple GPUs") - parser.add_argument( - "--config_path", - type=str, - help="Path to the YAML configuration file." - ) - parser.add_argument( - "--nisqa_config", - type=str, - help="Path to the NISQA YAML configuration file." - ) - parser.add_argument( - "--podcasts_path", - type=str, - help="Path to the directory containing podcast audio files." - ) - parser.add_argument( - "--one_speaker", - type=bool, - help="Boolean flag to indicate if only one speaker is expected per audio file" - ) - parser.add_argument( - "--num_workers", - type=int, - help="Number of worker processes per GPU for parallel processing." 
- ) - - args = parser.parse_args() - main(args) diff --git a/src/separation/separation_args.sh b/src/separation/separation_args.sh deleted file mode 100644 index dcbfec6..0000000 --- a/src/separation/separation_args.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -activate_venv() { - local venv_path=$1 - if [ ! -f "$venv_path/bin/activate" ]; then - echo "Error: Virtual environment not found at $venv_path" - exit 1 - fi - source "$venv_path/bin/activate" - echo "Activated: $(which python)" -} - -activate_venv ".dev_venv" - -SCRIPT_DIR=$(dirname "$(realpath "$0")") -CONFIG_PATH="$SCRIPT_DIR/../../configs/config.yaml" -NISQA_CONFIG_PATH="$SCRIPT_DIR/../../configs/nisqa_config.yaml" -PODCASTS_PATH="../../../balalaika" -USE_NISQA="True" -USE_MONO="True" -ONE_SPEAKER="False" -NUM_WORKERS="4" - -python3 -m src.separation.separation \ - --config_path "$CONFIG_PATH" \ - --nisqa_config "$NISQA_CONFIG_PATH" \ - --podcasts_path "$PODCASTS_PATH" \ - --use_nisqa "$USE_NISQA" \ - --use_mono "$USE_MONO" \ - --one_speaker "$ONE_SPEAKER" \ - --num_workers "$NUM_WORKERS" \ No newline at end of file diff --git a/src/separation/separation_yaml.sh b/src/separation/separation_yaml.sh index a3053f1..0241487 100644 --- a/src/separation/separation_yaml.sh +++ b/src/separation/separation_yaml.sh @@ -16,9 +16,9 @@ if [ -z "${1:-}" ]; then fi CONFIG_PATH=$(realpath "$1") - activate_venv ".dev_venv" SCRIPT_DIR=$(dirname "$(realpath "$0")") -python3 -m src.separation.separation --config_path "$CONFIG_PATH" \ No newline at end of file +python3 -m src.separation.music_detect --config_path "$CONFIG_PATH" +python3 -m src.separation.distillmos_process --config_path "$CONFIG_PATH" diff --git a/src/to_webdataset.py b/src/to_webdataset.py new file mode 100644 index 0000000..5af7b68 --- /dev/null +++ b/src/to_webdataset.py @@ -0,0 +1,173 @@ +import argparse +import multiprocessing +import pandas as pd +import math +import json +from pathlib import Path +from concurrent.futures import 
ProcessPoolExecutor, as_completed +from typing import List, Dict + +import webdataset as wds +from tqdm import tqdm +from loguru import logger + +from src.utils.utils import load_config, get_audio_paths + +def load_metadata(csv_path: Path) -> Dict[str, dict]: + """Loads balalaika.csv and builds a dictionary keyed by each file's base name (stem).""" + if not csv_path.exists(): + logger.warning(f"Metadata file {csv_path} not found!") + return {} + + df = pd.read_csv(csv_path) + metadata_dict = {} + + for _, row in df.iterrows(): + base_name = Path(row['filepath']).stem + row_dict = row.to_dict() + row_dict.pop('filepath', None) + metadata_dict[base_name] = row_dict + + logger.info(f"Loaded metadata for {len(metadata_dict)} files.") + return metadata_dict + +def worker_fn(worker_id: int, audio_paths: List[str], output_dir: Path, metadata_dict: Dict[str, dict], max_shard_size: int, max_shard_count: int): + if not audio_paths: + return 0 + + pattern = str(output_dir / f"shard_{worker_id:03d}_%04d.tar") + samples_processed = 0 + + with wds.ShardWriter(pattern, maxsize=max_shard_size, maxcount=max_shard_count) as sink: + for audio_str in tqdm(audio_paths, desc=f"Worker {worker_id}", position=worker_id): + audio_path = Path(audio_str) + + key = audio_path.stem + ext = audio_path.suffix.lstrip('.') + + if not audio_path.exists(): + continue + + # IMPORTANT: WebDataset uses the first dot to separate the key from the extension. + # Replaces dots in keys with underscores to not break the HuggingFace parser. + safe_key = key.replace('.', '_') + + # --- 1. Reads audio bytes --- + try: + audio_bytes = audio_path.read_bytes() + except Exception as e: + logger.warning(f"Error reading {audio_path}: {e}") + continue + + # --- 2. 
Formats JSON with texts and metadata --- + json_data = {} + + # Search in CSV still goes by the original key (with dots) + if key in metadata_dict: + for k, v in metadata_dict[key].items(): + k_str = str(k) + if pd.isna(v) or (isinstance(v, float) and math.isnan(v)): + json_data[k_str] = None + elif isinstance(v, (pd.Timestamp, pd.Timedelta)): + json_data[k_str] = str(v) + elif hasattr(v, 'item'): + json_data[k_str] = v.item() + else: + json_data[k_str] = v + + # Finds ALL files with the same prefix + parent_dir = audio_path.parent + siblings = set(parent_dir.glob(f"{key}_*")).union(set(parent_dir.glob(f"{key}.*"))) + + for sibling in siblings: + if not sibling.is_file() or sibling == audio_path: + continue + + postfix_name = sibling.name[len(key):].lstrip('_.') + + try: + text_content = sibling.read_text(encoding='utf-8').strip() + json_data[str(postfix_name)] = text_content + except UnicodeDecodeError: + pass + except Exception as e: + logger.warning(f"Error reading {sibling}: {e}") + + # --- 3. 
MANUAL SERIALIZATION OF JSON --- + try: + json_bytes = json.dumps(json_data, ensure_ascii=False).encode('utf-8') + except Exception as e: + logger.error(f"Failed to serialize JSON for {key}: {e}") + continue + + # Uses safe_key, so that the files safe_key.mp3 and safe_key.json are inside the .tar + sample = { + "__key__": safe_key, + ext: audio_bytes, + "json": json_bytes + } + + try: + sink.write(sample) + samples_processed += 1 + except Exception as e: + logger.error(f"Failed to write sample {key} to tar: {e}") + + return samples_processed + +def main(config): + podcasts_path_str = config.get('podcasts_path') + if not podcasts_path_str: + logger.error("podcasts_path is not defined in the config!") + return + + max_shard_size = config.get('max_shard_size', 512 * 1024 * 1024) + max_shard_count = config.get('max_shard_count', 10000) + + podcasts_path = Path(podcasts_path_str) + csv_path = podcasts_path / 'balalaika.csv' + + wds_output_dir = podcasts_path.parent / f"{podcasts_path.name}_webdataset" / "train" + wds_output_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"WebDataset shards will be saved to: {wds_output_dir}") + + num_workers = config.get('num_workers', 4) + num_workers = max(1, num_workers) + + all_audio_paths = get_audio_paths(podcasts_path_str) + if not all_audio_paths: + logger.warning("No audio data to process.") + return + + metadata_dict = load_metadata(csv_path) + + chunk_size = len(all_audio_paths) // num_workers + 1 + chunks = [all_audio_paths[i:i + chunk_size] for i in range(0, len(all_audio_paths), chunk_size)] + + logger.info(f"Starting {len(chunks)} workers to build WebDataset from {len(all_audio_paths)} audio files...") + + total_processed = 0 + with ProcessPoolExecutor(max_workers=len(chunks)) as executor: + futures = [ + executor.submit(worker_fn, worker_id, chunk, wds_output_dir, metadata_dict, max_shard_size, max_shard_count ) + for worker_id, chunk in enumerate(chunks) + ] + + for future in as_completed(futures): + try: + 
total_processed += future.result() + except Exception as e: + logger.error(f"Worker failed with error: {e}") + + logger.success(f"WebDataset creation completed! Total samples packed: {total_processed}") + logger.success(f"Output directory: {wds_output_dir}") + +if __name__ == "__main__": + multiprocessing.set_start_method('spawn', force=True) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to YAML config") + args = parser.parse_args() + + config = load_config(args.config_path, process_name='export') + main(config) \ No newline at end of file diff --git a/src/transcription/README.md b/src/transcription/README.md index a36a5ce..d2cd211 100644 --- a/src/transcription/README.md +++ b/src/transcription/README.md @@ -1,35 +1,43 @@ -## Usage/Examples +## Transcription (onnx-asr) -### Running the Code via Command-Line Arguments -You can modify the parameters directly in the shell script (`transcription/transcription_args.sh`) and then run it: -~~~ -sh transcription/transcription_args.sh -~~~ +ASR via **[onnx-asr](https://github.com/istupakov/onnx-asr)** on **ONNX Runtime**, optionally **TensorRT** — no custom PyTorch dataloaders in this repo. -### Running the Code via Config File -Example: -~~~ -sh transcription/transcription_yaml.sh config_path -~~~ +### Features -## Explanation of Parameters +- Run multiple models sequentially with **early skip** when `consensus_num` earlier models agree on normalized text. +- **ROVER** → `{stem}_rover.txt` when `use_rover: True`. +- **Word-level timestamps** → `{stem}_{model}.tst` (TSV) when `with_timestamps: True` and the model is in the supported set. +- **Multi-GPU** via multiprocessing. -- `--config_path`: Path to the YAML configuration file. - *Note: When provided, the config file may include additional settings such as `model_name` and `device`.* -- `--podcasts_path`: Path to the directory containing audio files for transcription (default: "../../../podcasts"). 
-- `--num_workers`: Number of worker processes per GPU for parallel processing (default: 4). -- `--model_name`: Name of the model to use for transcription (default: "rnnt"). +### Typical `model_names` (Russian) -## Output Structure +| Config name | Backend (onnx-asr / HF id) | +|-------------|----------------------------| +| `giga_ctc` | GigaAM v3 CTC | +| `giga_rnnt` | GigaAM v3 RNN-T | +| `vosk` | Vosk Russian | +| `tone` | T-one | -For each `.mp3` audio file found within the specified `podcasts_path`, a corresponding `_giga.txt` file will be created in the same directory containing the transcription: +Others: `parakeet_v2`, `parakeet_v3`, `canary`, `whisper_base`, `whisper_turbo`, … — see `MODEL_MAP` in `transcription.py` and comments in `configs/config.yaml`. +## Run + +```bash +bash src/transcription/transcription_yaml.sh configs/config.yaml ``` -podcasts/ -└── {album_id}/ - └── {episode_id}/ - ├── {start_time}_{end_time}_{album_id}_{episode_id}.mp3 - └── {start_time}_{end_time}_{album_id}_{episode_id}_giga.txt -``` -The `_giga.txt` file contains the transcribed text from the audio file. The transcription is performed using the specified model (default: "rnnt") and is processed in parallel using multiple GPUs if available. +## Config snippet + +All keys are documented under **`transcription`** in `configs/config.yaml` (`podcasts_path`, `consensus_num`, `with_timestamps`, `use_tensorrt`, `use_vad`, `use_rover`, `model_names`, `batch_size`, plus optional `model_path`, `vosk_path`, `quantization`, `vad_params`). + +## On-disk artifacts + +For chunk `{stem}.mp3`: + +- `{stem}_{model}.txt` — hypothesis. +- `{stem}_{model}.tst` — timestamps when enabled. +- `{stem}_rover.txt` — ROVER consensus. + +## Dependencies + +`create_dev_env.sh` typically installs nightly **onnxruntime-gpu** for your CUDA, **`tensorrt-cu13`** (or matching wheel), and **`onnx-asr[gpu,hub]`** — pin versions there. 
diff --git a/src/transcription/rover.py b/src/transcription/rover.py new file mode 100644 index 0000000..b065e32 --- /dev/null +++ b/src/transcription/rover.py @@ -0,0 +1,75 @@ +from pathlib import Path +from typing import List + +import pandas as pd +from crowdkit.aggregation import ROVER +from loguru import logger +from tqdm import tqdm + +from src.utils.utils import read_file_content, get_audio_paths + +class ROVERWrapper: + def __init__(self, podcasts_path: str, model_names: List[str]): + self.podcasts_path = Path(podcasts_path) + self.model_names = model_names + self.tokenizer = lambda s: s.lower().split() + self.detokenizer = lambda tokens: ' '.join(tokens) + self.rover_aggregator = ROVER(self.tokenizer, self.detokenizer) + + def aggregate_and_save(self): + logger.info("Starting transcription aggregation based on audio files.") + + all_audio_paths = get_audio_paths(str(self.podcasts_path)) + + if not all_audio_paths: + logger.warning("Audio files not found. Aggregation finished.") + return + + records = [] + excluded_patterns = ['_rover', '_phonemes', '_accent'] + + for audio_path in tqdm(all_audio_paths, desc="Aggregating transcriptions"): + if any(pattern in audio_path.stem for pattern in excluded_patterns): + continue + + for model_name in self.model_names: + suffix = 'vosk' if 'vosk' in model_name else model_name + transcript_path = audio_path.with_name(f"{audio_path.stem}_{suffix}.txt") + + if not transcript_path.exists(): + continue + + try: + text = read_file_content(transcript_path) + if not text: + continue + + records.append({ + 'task': str(audio_path), + 'worker': model_name, + 'text': text + }) + except Exception as e: + logger.error(f"Error reading file {transcript_path}: {e}") + + df = pd.DataFrame(records) + if df.empty: + logger.warning("No transcriptions found for aggregation. 
Check file paths and names.") + return + + df['text'] = df['text'].str.lower() + logger.info(f"Running ROVER on {len(df['task'].unique())} unique audio files...") + result = self.rover_aggregator.fit_predict(df) + + logger.info("Saving aggregated results...") + for task_path, agg_text in result.items(): + audio_path = Path(task_path) + output_path = audio_path.with_name(f"{audio_path.stem}_rover.txt") + + try: + with open(output_path, "w", encoding="utf-8") as f: + f.write(agg_text) + except IOError as e: + logger.error(f"Failed to write result to {output_path}: {e}") + + logger.info("Aggregation complete.") \ No newline at end of file diff --git a/src/transcription/transcription.py b/src/transcription/transcription.py index 3a8ef7a..777054d 100644 --- a/src/transcription/transcription.py +++ b/src/transcription/transcription.py @@ -1,254 +1,321 @@ import argparse -import multiprocessing -import os -from concurrent.futures import ProcessPoolExecutor, as_completed -from functools import partial +import multiprocessing as mp from pathlib import Path -from typing import List, Tuple - -import gigaam -import pyctcdecode -import torch -import torchaudio +from typing import List, Optional +from collections import Counter from loguru import logger from tqdm import tqdm -from src.utils import get_audio_paths, load_config +import onnx_asr -# Global variables for each worker process -model = None -decoder = None -# Frame size in milliseconds for the GigaAM model, crucial for timestamp calculation -GIGA_AM_FRAME_SIZE_MS = 40 +try: + import soundfile as sf + HAS_SOUNDFILE = True +except ImportError: + HAS_SOUNDFILE = False +from src.utils.utils import get_audio_paths, load_config, read_file_content -def init_process( - model_name: str, - device_str: str, - lm_path: str, - with_timestamps: bool, -): - """ - Initializes the model and, if needed, the CTC decoder for each worker process. 
- """ - global model, decoder - logger.info(f"Initializing worker on {device_str}...") - if not (with_timestamps and 'ctc' in model_name): - logger.info("Timestamp generation requested CTC model. Decoder will not be initialized.") - with_timestamps=False +MODEL_MAP = { + 'giga_rnnt': 'gigaam-v3-rnnt', + 'giga_ctc': 'gigaam-v3-ctc', + 'giga_ctc_lm': 'gigaam-v3-ctc', + 'tone': 't-tech/t-one', + 'vosk': 'alphacep/vosk-model-ru', + 'vosk_small': 'alphacep/vosk-model-small-ru', + 'parakeet_v2': 'nemo-parakeet-tdt-0.6b-v2', + 'parakeet_v3': 'nemo-parakeet-tdt-0.6b-v3', + 'canary': 'nemo-canary-1b-v2', + 'whisper_base': 'whisper-base', + 'whisper_turbo': 'onnx-community/whisper-large-v3-turbo', +} - model = gigaam.load_model(model_name, device=device_str) +SUPPORTED_TIMESTAMPS = {'giga_ctc', 'giga_ctc_lm', 'tone', 'parakeet_v2', 'parakeet_v3', 'canary'} - if with_timestamps: - if not lm_path: - logger.warning("Timestamp generation requested without an LM path. Decoder will not be initialized.") - decoder = None - return - logger.info(f"Building CTC decoder with LM: {lm_path}") +def get_gpu_count() -> int: + try: + import onnxruntime as ort + if 'CUDAExecutionProvider' not in ort.get_available_providers(): + return 0 + except ImportError: + return 0 + try: + import subprocess + result = subprocess.run( + ['nvidia-smi', '--query-gpu=index', '--format=csv,noheader'], + capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0: + return len([line for line in result.stdout.strip().split('\n') if line.strip()]) + except Exception: + pass + return 1 + + +def get_providers(cuda_id: int, use_tensorrt: bool = False) -> list: + if use_tensorrt: + return [ + ("TensorrtExecutionProvider", { + "device_id": cuda_id, + "trt_max_workspace_size": 6 * 1024**3, + "trt_fp16_enable": True, + "trt_engine_cache_enable": True, + "trt_engine_cache_path": f".cache/trt_cache_{cuda_id}", + }), + ("CUDAExecutionProvider", {"device_id": cuda_id}), + ] + return [("CUDAExecutionProvider", 
{"device_id": cuda_id})] + + +def save_results(paths: List[str], texts: List[str], timestamps: Optional[List[str]], model_suffix: str): + for i, (path_str, text) in enumerate(zip(paths, texts)): + path = Path(path_str) + + txt_path = path.with_name(f"{path.stem}_{model_suffix}.txt") try: - vocab = model.decoding.tokenizer.vocab - decoder = pyctcdecode.build_ctcdecoder( - vocab, - lm_path, - alpha=0.5, - beta=1.0, - ) + with open(txt_path, "w", encoding="utf-8") as f: + f.write(text) except Exception as e: - logger.error(f"Failed to build CTC decoder: {e}") - decoder = None - + logger.error(f"Write TXT failed {path.name}: {e}") + + ts = timestamps[i] if timestamps and i < len(timestamps) else '' + if ts: + tst_path = path.with_name(f"{path.stem}_{model_suffix}.tst") + try: + with open(tst_path, "w", encoding="utf-8") as f: + f.write(ts) + except Exception as e: + logger.error(f"Write TST failed {path.name}: {e}") + + +def load_batch(file_paths: List[str]): + """Read batch — pass WAV paths directly to onnx-asr; non-WAV read via soundfile.""" + all_wav = all(Path(f).suffix.lower() == '.wav' for f in file_paths) + if all_wav or not HAS_SOUNDFILE: + return file_paths, None + + waveforms, sr = [], None + for f in file_paths: + wf, file_sr = sf.read(f, dtype='float32') + if wf.ndim > 1: + wf = wf.mean(axis=1) + waveforms.append(wf) + sr = file_sr + return waveforms, sr + + +def extract_text(result) -> str: + """Extract plain text from onnx-asr result (str or TimestampedResult).""" + if hasattr(result, 'text'): + return result.text + return str(result) + + +def format_timestamps(result) -> str: + """Format TimestampedResult as word-level TSV: start\\tend\\tword per line. + + onnx-asr TimestampedResult has parallel arrays: + .tokens = ['с', 'п', 'а', 'с', 'и', 'б', 'о', ' ', ...] + .timestamps = [0.39, 0.44, 0.51, 0.54, 0.57, 0.63, 0.66, 0.75, ...] + We group characters into words and produce word-level timestamps. 
+    """
+    tokens = getattr(result, 'tokens', None)
+    timestamps = getattr(result, 'timestamps', None)
+
+    if not tokens or not timestamps or len(tokens) != len(timestamps):
+        return ''
+
+    words = []
+    current_word = ''
+    word_start = None
+
+    for token, ts in zip(tokens, timestamps):
+        if token.strip() == '':
+            if current_word and word_start is not None:
+                words.append((word_start, ts, current_word))
+                current_word = ''
+                word_start = None
+        else:
+            if word_start is None:
+                word_start = ts
+            current_word += token
+
+    if current_word and word_start is not None:
+        words.append((word_start, timestamps[-1], current_word))
+
+    return '\n'.join(f"{start:.3f}\t{end:.3f}\t{word}" for start, end, word in words)
+
+
+def run_worker(cuda_id: int, world_size: int, model_name: str,
+               all_files: List[str], config: dict):
+    """Inference worker: loads onnx-asr model on a single GPU and processes its shard."""
+    my_files = all_files[cuda_id::world_size]
+    if not my_files:
+        return
+
+    batch_size = config.get('batch_size', 16)
+    use_trt = config.get('use_tensorrt', False)
+    quantization = config.get('quantization')
-def to_simple_timestamps(word_timestamps: List[Tuple[str, Tuple[int, int]]]) -> str:
-    output_lines = []
-    sec_per_frame = GIGA_AM_FRAME_SIZE_MS / 1000.0
-    for word, (start_frame, end_frame) in word_timestamps:
-        start_time = start_frame * sec_per_frame
-        end_time = end_frame * sec_per_frame
-        output_lines.append(f"{word} {start_time:.3f} {end_time:.3f}")
-    return "\n".join(output_lines)
+    onnx_name = MODEL_MAP.get(model_name, model_name)
+    output_suffix = 'vosk' if 'vosk' in model_name else model_name
+    do_timestamps = config.get('with_timestamps', False) and model_name in SUPPORTED_TIMESTAMPS
+    local_path = config.get('vosk_path') if 'vosk' in model_name else config.get('model_path')
-def make_txt_and_tst(path: Path, with_timestamps: bool):
-    """
-    Transcribes an audio file. If with_timestamps is True and the decoder is available,
-    it generates both a .txt (transcription) and a .tstt (timestamps) file.
-    Otherwise, it only generates the .txt file.
-    """
-    txt_path = path.with_name(f"{path.stem}_giga.txt")
-    tst_path = path.with_name(f"{path.stem}_giga.tst")
-
-    if os.path.exists(txt_path):
-        return
+    logger.info(f"Worker {cuda_id}/{world_size}: {onnx_name} on cuda:{cuda_id}, {len(my_files)} files, batch={batch_size}")
-    # Timestamp-enabled path using the CTC decoder
-    if not (with_timestamps and decoder):
-        text = model.transcribe(str(path))
-        with open(txt_path, "w", encoding="utf-8") as f:
-            f.write(text)
-        return
     try:
-        wav, sr = torchaudio.load(path)
-
-        if wav.shape[0] > 1:
-            wav = wav.mean(dim=0).unsqueeze(0)
-
-        wav = torchaudio.functional.resample(wav, sr, 16000)
-        length = torch.full([1], wav.shape[-1])
-
-        encoded, _ = model.forward(wav.to(model._device), length.to(model._device))
-        logitst = model.head(encoded).squeeze(0).detach().cpu().numpy()
-
-        # Use decode_beams to get timestamps
-        beams = decoder.decode_beams(logitst, beam_width=100)
-
-        # The top beam result with timestamps is typically at index 0
-        best_beam = beams[0]
-        word_timestamps = best_beam[2]
-
-        # 1. Save the plain text transcription to _giga.txt
-        plain_text = best_beam[0]
-        with open(txt_path, "w", encoding="utf-8") as f:
-            f.write(plain_text)
-
-        # 2. Save the timestamps to _giga.tst
-        tst_content = to_simple_timestamps(word_timestamps)
-        with open(tst_path, "w", encoding="utf-8") as f:
-            f.write(tst_content)
+        providers = get_providers(cuda_id, use_trt)
+        load_args = [onnx_name] + ([local_path] if local_path else [])
+        load_kwargs = {"providers": providers}
+        if quantization:
+            load_kwargs["quantization"] = quantization
+
+        model = onnx_asr.load_model(*load_args, **load_kwargs)
+
+        if do_timestamps:
+            model = model.with_timestamps()
+
+        if config.get('use_vad', False):
+            vad_params = config.get('vad_params', {})
+            vad = onnx_asr.load_vad("silero", **vad_params)
+            model = model.with_vad(vad)
+
+        for i in tqdm(range(0, len(my_files), batch_size), desc=f"ASR-{cuda_id}", position=cuda_id):
+            batch = my_files[i:i + batch_size]
+
+            try:
+                data, sr = load_batch(batch)
+                kw = {"sample_rate": sr} if sr else {}
+                results = model.recognize(data, **kw)
+            except Exception as e:
+                logger.error(f"Batch failed: {e}. Falling back to single-file mode.")
+                results = []
+                for f in batch:
+                    try:
+                        results.append(model.recognize(f))
+                    except Exception as e2:
+                        logger.error(f"File failed {f}: {e2}")
+                        results.append("")
+
+            if not isinstance(results, list):
+                results = [results]
+
+            texts = [extract_text(r) for r in results]
+            ts = [format_timestamps(r) for r in results] if do_timestamps else None
+
+            save_results(batch, texts, ts, output_suffix)
     except Exception as e:
-        logger.error(f"Error processing {path} with timestamps: {e}")
-        # Fallback to simple transcription if timestamping fails
-        text = model.transcribe(str(path))
-        with open(txt_path, "w", encoding="utf-8") as f:
-            f.write(text)
-
-
-def get_valid_audio_paths(src_path: str) -> List[Path]:
-    """
-    Getst all audio paths and filters out those that have already been transcribed.
-    """
-    all_audio_paths = get_audio_paths(src_path)
-    valid_paths = []
-    for audio_path in all_audio_paths:
-        giga_path = audio_path.with_name(audio_path.stem + "_giga.txt")
-        if not giga_path.exists():
-            valid_paths.append(audio_path)
-    return valid_paths
+        logger.exception(f"Worker {cuda_id} fatal error ({model_name}): {e}")
+
+
+def check_consensus(audio_path: Path, model_names: List[str], consensus_num: int) -> bool:
+    texts = []
+    for mn in model_names:
+        suffix = 'vosk' if 'vosk' in mn else mn
+        tp = audio_path.with_name(f"{audio_path.stem}_{suffix}.txt")
+        if tp.exists():
+            try:
+                t = read_file_content(tp)
+                if t:
+                    texts.append(t.lower().strip())
+            except Exception:
+                pass
+    if len(texts) < consensus_num:
+        return False
+    return max(Counter(texts).values()) >= consensus_num
+
+
+def get_valid_paths(src_path: str, output_suffix: str,
+                    processed: List[str], consensus_num: int) -> List[str]:
+    all_paths = get_audio_paths(src_path)
+    if not all_paths:
+        return []
+
+    valid = [p for p in all_paths if not p.with_name(f"{p.stem}_{output_suffix}.txt").exists()]
+
+    if consensus_num > 0 and len(processed) >= consensus_num:
+        skipped = 0
+        filtered = []
+        for p in valid:
+            if check_consensus(p, processed, consensus_num):
+                skipped += 1
+            else:
+                filtered.append(p)
+        if skipped:
+            logger.info(f"Consensus reached for {skipped} files, skipping")
+        valid = filtered
+
+    return [str(p) for p in valid]
 def main(args):
     config = load_config(args.config_path, 'transcription')
+    model_names = config.get('model_names', ['giga_rnnt'])
+    src_path = config.get('podcasts_path', '.')
+    consensus_num = config.get('consensus_num', 0)
-    model_name = args.model_name if args.model_name else config.get('model_name', 'rnnt')
-    num_workers_per_gpu = args.num_workers if args.num_workers else config.get('num_workers', 4)
-    src_path = args.podcasts_path if args.podcasts_path else config.get('podcasts_path', '../../../balalaika')
-    lm_path = args.lm_path if args.lm_path else config.get('lm_path', 'ru.lm.bin')
-    with_timestamps = args.with_timestamps if args.with_timestamps else config.get('with_timestamps', 'False')
-
-    if with_timestamps and not lm_path:
-        raise ValueError("Language model path (--lm_path) is required when using --with_timestamps.")
-
-    all_audio_paths = get_valid_audio_paths(src_path)
-    logger.info(f"Found {len(all_audio_paths)} audio files to process.")
-
-    available_gpu_ids = list(range(torch.cuda.device_count()))
-    num_gpus = len(available_gpu_ids)
-
+    num_gpus = get_gpu_count()
     if num_gpus == 0:
-        logger.error("No GPUs available. Exiting.")
+        logger.error("No CUDA GPUs detected. GPU required for transcription.")
         return
-    logger.info(
-        f"""
-        Starting transcription with parameters:
-        Source Path: {src_path}
-        Model Name: {model_name}
-        Timestamps Enabled: {with_timestamps}
-        Language Model Path: {lm_path}
-        Number of GPUs: {num_gpus} (IDs: {available_gpu_ids})
-        Workers per GPU: {num_workers_per_gpu}
-        Total Worker Processes: {num_gpus * num_workers_per_gpu}
-        """
-    )
-
-    files_for_each_gpu = [[] for _ in range(num_gpus)]
-    for i, path in enumerate(all_audio_paths):
-        gpu_assignment_index = i % num_gpus
-        files_for_each_gpu[gpu_assignment_index].append(path)
-
-    all_futures = []
-    executors = []
-
-    task_fn = partial(make_txt_and_tst, with_timestamps=with_timestamps)
-
-    for i, gpu_id in enumerate(available_gpu_ids):
-        device_str = f'cuda:{gpu_id}'
-        files_for_this_gpu = files_for_each_gpu[i]
-
-        if not files_for_this_gpu:
-            continue
+    logger.info(f"{num_gpus} GPU(s) detected. Starting transcription pipeline.")
+    if consensus_num > 0:
+        logger.info(f"Consensus mode: {consensus_num} models must agree")
-        logger.info(f"Creating ProcessPoolExecutor for {device_str} with {num_workers_per_gpu} workers for {len(files_for_this_gpu)} files.")
-
-        executor = ProcessPoolExecutor(
-            max_workers=num_workers_per_gpu,
-            initializer=init_process,
-            initargs=(model_name, device_str, lm_path, with_timestamps),
-        )
-        executors.append(executor)
+    for idx, model_name in enumerate(model_names):
+        logger.info(f"=== [{idx + 1}/{len(model_names)}] {model_name} ===")
-        for path in files_for_this_gpu:
-            future = executor.submit(task_fn, path)
-            all_futures.append(future)
+        output_suffix = 'vosk' if 'vosk' in model_name else model_name
+        processed = model_names[:idx] if consensus_num > 0 else []
+        paths = get_valid_paths(src_path, output_suffix, processed, consensus_num)
-    for future in tqdm(as_completed(all_futures), total=len(all_futures), desc="Overall Transcription Progress"):
+        if not paths:
+            logger.info(f"No files to process for {model_name}")
+            continue
+
+        logger.info(f"{len(paths)} files to process")
+
+        if num_gpus == 1:
+            run_worker(0, 1, model_name, paths, config)
+        else:
+            procs = []
+            for gid in range(num_gpus):
+                p = mp.Process(
+                    target=run_worker,
+                    args=(gid, num_gpus, model_name, paths, config)
+                )
+                p.start()
+                procs.append(p)
+
+            for p in procs:
+                p.join()
+
+            failed = [p.exitcode for p in procs if p.exitcode != 0]
+            if failed:
+                logger.error(f"Workers failed with exit codes: {failed}")
+
+    if config.get('use_rover', False):
+        logger.info("ROVER aggregation...")
         try:
-            future.result()
+            from src.transcription.rover import ROVERWrapper
+            ROVERWrapper(podcasts_path=src_path, model_names=model_names).aggregate_and_save()
+            logger.info("ROVER done.")
+        except ImportError:
+            logger.warning("ROVER module not available, skipping")
         except Exception as e:
-            logger.error(f"A task processing encountered an error: {e}")
+            logger.error(f"ROVER failed: {e}")
-    for executor in executors:
-        executor.shutdown()
+    logger.info("Transcription pipeline complete!")
 if __name__ == "__main__":
-    multiprocessing.set_start_method('spawn', force=True)
-    torchaudio.set_audio_backend('soundfile')
-
-    parser = argparse.ArgumentParser(
-        description="Transcribe audio files in parallel using multiple GPUs."
-    )
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        help="Path to the configuration YAML file."
-    )
-    parser.add_argument(
-        "--podcasts_path",
-        type=str,
-        help="Path to the directory containing audio files (e.g., MP3s)."
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        help="Number of worker processes per GPU for parallel processing."
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        help="Name of the model to use for transcription (e.g., 'rnnt', 'ctc')."
-    )
-    parser.add_argument(
-        "--lm_path",
-        type=str,
-        help="Path to the language model binary file (e.g., 'ru.lm.bin') required for timestamps."
-    )
-    parser.add_argument(
-        '--with_timestamps',
-        type=bool,
-        help="Enable to generate tst files with word timestamps."
-    )
-
-    args = parser.parse_args()
-    main(args)
\ No newline at end of file
+    mp.set_start_method('spawn', force=True)
+
+    parser = argparse.ArgumentParser(description="ASR Transcription (onnx-asr)")
+    parser.add_argument("--config_path", type=str, required=True)
+    main(parser.parse_args())
diff --git a/src/transcription/transcription_args.sh b/src/transcription/transcription_args.sh
deleted file mode 100644
index 633cc29..0000000
--- a/src/transcription/transcription_args.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/bin/bash
-
-activate_venv() {
-    local venv_path=$1
-    if [ ! -f "$venv_path/bin/activate" ]; then
-        echo "Error: Virtual environment not found at $venv_path"
-        exit 1
-    fi
-    source "$venv_path/bin/activate"
-    echo "Activated: $(which python)"
-}
-
-activate_venv ".dev_venv"
-
-SCRIPT_DIR=$(dirname "$(realpath "$0")")
-
-PODCASTS_PATH="/home/nikita/Balalaika100H"
-NUM_WORKERS=1
-MODEL_NAME="ctc"
-LM_PATH="/home/nikita/yapoddataset/ru.lm.bin"
-WITH_TIMESTAMPS=True
-
-python -m src.transcription.transcription \
-    --podcasts_path "$PODCASTS_PATH" \
-    --num_workers "$NUM_WORKERS" \
-    --model_name "$MODEL_NAME" \
-    --lm_path "$LM_PATH" \
-    --with_timestamps "$WITH_TIMESTAMPS"
diff --git a/src/transcription/transcription_yaml.sh b/src/transcription/transcription_yaml.sh
index e221468..0b9eb35 100644
--- a/src/transcription/transcription_yaml.sh
+++ b/src/transcription/transcription_yaml.sh
@@ -21,4 +21,4 @@ activate_venv ".dev_venv"

 SCRIPT_DIR=$(dirname "$(realpath "$0")")

-python3 -m src.transcription.transcription --config_path "$CONFIG_PATH"
\ No newline at end of file
+python3 -m src.transcription.transcription --config_path "$CONFIG_PATH"
\ No newline at end of file
diff --git a/src/utils.py b/src/utils/utils.py
similarity index 91%
rename from src/utils.py
rename to src/utils/utils.py
index bfd3cf6..25e538a 100644
--- a/src/utils.py
+++ b/src/utils/utils.py
@@ -2,6 +2,7 @@
 from typing import List
 import yaml
 from loguru import logger
+import re

 def load_config(config_path: str, process_name: str):
     config = {}
@@ -29,15 +30,17 @@ def read_file_content(file_path):
     except FileNotFoundError:
         return ''

-def get_audio_paths(podcast_path: str) -> List[Path]:
+def get_audio_paths(podcast_path: str):
     podcast_path=Path(podcast_path)
     return (
         list(podcast_path.rglob("*.mp3")) +
         list(podcast_path.rglob("*.wav")) +
         list(podcast_path.rglob("*.flac")) +
-        list(podcast_path.rglob("*.ogg"))
+        list(podcast_path.rglob("*.ogg")) +
+        list(podcast_path.rglob("*.opus"))
     )

+
 def process_token(token, label):
     if label == "LOWER_O":
         return token
@@ -104,4 +107,10 @@ def process_token(token, label):
     if label == "UPPER_TOTAL_MNOGOTOCHIE":
         return token.upper() + "..."
     if label == "UPPER_TOTAL_QUESTIONVOSKL":
-        return token.upper() + "?!"
\ No newline at end of file
+        return token.upper() + "?!"
+
+def normalize_text(text: str) -> str:
+    text = text.lower().strip()
+    text = re.sub(r'[^\w\s]', '', text)
+    text = re.sub(r'\s+', ' ', text)
+    return text
\ No newline at end of file
diff --git a/use_meta_1000h.sh b/use_meta_1000h.sh
index d68f6a8..b47688e 100644
--- a/use_meta_1000h.sh
+++ b/use_meta_1000h.sh
@@ -1,5 +1,4 @@
-# bin/bash
-
+#!/bin/bash

 activate_venv() {
     local venv_path=$1
@@ -8,18 +7,36 @@ activate_venv() {
         exit 1
     fi
     source "$venv_path/bin/activate"
-    echo "Activated: $(which python)"
+    echo "Activated virtual environment: $(which python)"
 }

-wget https://huggingface.co/datasets/MTUCI/Balalaika1000H/resolve/main/Balalaika1000H.parquet
-wget https://huggingface.co/datasets/MTUCI/Balalaika1000H/resolve/main/Balalaika1000H.pkl
+download_if_not_exists() {
+    local url=$1
+    local filename=$2
+
+    if [ ! -f "$filename" ]; then
+        echo "Downloading $filename..."
+        wget "$url" -O "$filename" || {
+            echo "Error: Failed to download $filename"
+            exit 1
+        }
+    else
+        echo "$filename already exists, skipping download."
+    fi
+}

 PODCASTS_PATH="Balalaika1000H"
 PICKLE_PATH="Balalaika1000H.pkl"
 PARQUET_PATH="Balalaika1000H.parquet"
 NUM_WORKERS=4

+PICKLE_URL="https://huggingface.co/datasets/MTUCI/Balalaika1000H/resolve/main/Balalaika1000H.pkl"
+PARQUET_URL="https://huggingface.co/datasets/MTUCI/Balalaika1000H/resolve/main/Balalaika1000H.parquet"
+
+download_if_not_exists "$PICKLE_URL" "$PICKLE_PATH"
+download_if_not_exists "$PARQUET_URL" "$PARQUET_PATH"
+
 activate_venv ".user_venv"

-bash src/download/download_prepared.sh $PODCASTS_PATH $PICKLE_PATH $NUM_WORKERS
-bash src/recovery_from_meta_yamls.sh $PODCASTS_PATH $PARQUET_PATH $NUM_WORKERS
\ No newline at end of file
+bash src/download/download_prepared.sh "$PODCASTS_PATH" "$PICKLE_PATH" "$NUM_WORKERS"
+bash src/recovery_from_meta_yamls.sh "$PODCASTS_PATH" "$PARQUET_PATH" "$NUM_WORKERS"
diff --git a/use_meta_100h.sh b/use_meta_100h.sh
index d670ecf..6babb67 100644
--- a/use_meta_100h.sh
+++ b/use_meta_100h.sh
@@ -1,4 +1,4 @@
-# bin/bash
+#!/bin/bash

 activate_venv() {
     local venv_path=$1
@@ -7,18 +7,46 @@ activate_venv() {
         exit 1
     fi
     source "$venv_path/bin/activate"
-    echo "Activated: $(which python)"
+    echo "Activated virtual environment: $(which python)"
 }

-wget https://huggingface.co/datasets/MTUCI/Balalaika100H/resolve/main/Balalaika100H.parquet
-wget https://huggingface.co/datasets/MTUCI/Balalaika100H/resolve/main/Balalaika100H.pkl
+download_if_not_exists() {
+    local url=$1
+    local filename=$2
+
+    if [ ! -f "$filename" ]; then
+        echo "Downloading $filename..."
+        wget "$url" -O "$filename" || {
+            echo "Error: Failed to download $filename"
+            exit 1
+        }
+    else
+        echo "$filename already exists, skipping download."
+    fi
+}

 PODCASTS_PATH="Balalaika100H"
 PICKLE_PATH="Balalaika100H.pkl"
 PARQUET_PATH="Balalaika100H.parquet"
 NUM_WORKERS=4

+PICKLE_URL="https://huggingface.co/datasets/MTUCI/Balalaika100H/resolve/main/Balalaika100H.pkl"
+PARQUET_URL="https://huggingface.co/datasets/MTUCI/Balalaika100H/resolve/main/Balalaika100H.parquet"
+
+download_if_not_exists "$PICKLE_URL" "$PICKLE_PATH"
+download_if_not_exists "$PARQUET_URL" "$PARQUET_PATH"
+
 activate_venv ".user_venv"

-bash src/download/download_prepared.sh $PODCASTS_PATH $PICKLE_PATH $NUM_WORKERS
-bash src/recovery_from_meta_yamls.sh $PODCASTS_PATH $PARQUET_PATH $NUM_WORKERS
\ No newline at end of file
+echo "Starting processing..."
+bash src/download/download_prepared.sh "$PODCASTS_PATH" "$PICKLE_PATH" "$NUM_WORKERS" || {
+    echo "Error in download_prepared.sh"
+    exit 1
+}
+
+bash src/recovery_from_meta_yamls.sh "$PODCASTS_PATH" "$PARQUET_PATH" "$NUM_WORKERS" || {
+    echo "Error in recovery_from_meta_yamls.sh"
+    exit 1
+}
+
+echo "All operations completed successfully."
\ No newline at end of file
diff --git a/use_meta_2000h.sh b/use_meta_2000h.sh
index 87efed1..43b9b16 100644
--- a/use_meta_2000h.sh
+++ b/use_meta_2000h.sh
@@ -1,4 +1,4 @@
-# bin/bash
+#!/bin/bash

 activate_venv() {
     local venv_path=$1
@@ -7,18 +7,36 @@ activate_venv() {
         exit 1
     fi
     source "$venv_path/bin/activate"
-    echo "Activated: $(which python)"
+    echo "Activated virtual environment: $(which python)"
 }

-wget https://huggingface.co/datasets/MTUCI/Balalaika2000H/resolve/main/Balalaika2000H.parquet
-wget https://huggingface.co/datasets/MTUCI/Balalaika2000H/resolve/main/Balalaika2000H.pkl
+download_if_not_exists() {
+    local url=$1
+    local filename=$2
+
+    if [ ! -f "$filename" ]; then
+        echo "Downloading $filename..."
+        wget "$url" -O "$filename" || {
+            echo "Error: Failed to download $filename"
+            exit 1
+        }
+    else
+        echo "$filename already exists, skipping download."
+    fi
+}

 PODCASTS_PATH="Balalaika2000H"
 PICKLE_PATH="Balalaika2000H.pkl"
 PARQUET_PATH="Balalaika2000H.parquet"
 NUM_WORKERS=4

+PICKLE_URL="https://huggingface.co/datasets/MTUCI/Balalaika2000H/resolve/main/Balalaika2000H.pkl"
+PARQUET_URL="https://huggingface.co/datasets/MTUCI/Balalaika2000H/resolve/main/Balalaika2000H.parquet"
+
+download_if_not_exists "$PICKLE_URL" "$PICKLE_PATH"
+download_if_not_exists "$PARQUET_URL" "$PARQUET_PATH"
+
 activate_venv ".user_venv"

-bash src/download/download_prepared.sh $PODCASTS_PATH $PICKLE_PATH $NUM_WORKERS
-bash src/recovery_from_meta_yamls.sh $PODCASTS_PATH $PARQUET_PATH $NUM_WORKERS
\ No newline at end of file
+bash src/download/download_prepared.sh "$PODCASTS_PATH" "$PICKLE_PATH" "$NUM_WORKERS"
+bash src/recovery_from_meta_yamls.sh "$PODCASTS_PATH" "$PARQUET_PATH" "$NUM_WORKERS"
diff --git a/use_meta_500h.sh b/use_meta_500h.sh
index 62a6dd2..3f92148 100644
--- a/use_meta_500h.sh
+++ b/use_meta_500h.sh
@@ -1,5 +1,4 @@
-# bin/bash
-
+#!/bin/bash

 activate_venv() {
     local venv_path=$1
@@ -8,18 +7,36 @@ activate_venv() {
         exit 1
     fi
     source "$venv_path/bin/activate"
-    echo "Activated: $(which python)"
+    echo "Activated virtual environment: $(which python)"
 }

-wget https://huggingface.co/datasets/MTUCI/Balalaika500H/resolve/main/Balalaika500H.parquet
-wget https://huggingface.co/datasets/MTUCI/Balalaika500H/resolve/main/Balalaika500H.pkl
-
-activate_venv ".user_venv"
+download_if_not_exists() {
+    local url=$1
+    local filename=$2
+
+    if [ ! -f "$filename" ]; then
+        echo "Downloading $filename..."
+        wget "$url" -O "$filename" || {
+            echo "Error: Failed to download $filename"
+            exit 1
+        }
+    else
+        echo "$filename already exists, skipping download."
+    fi
+}

 PODCASTS_PATH="Balalaika500H"
 PICKLE_PATH="Balalaika500H.pkl"
 PARQUET_PATH="Balalaika500H.parquet"
 NUM_WORKERS=4

-bash src/download/download_prepared.sh $PODCASTS_PATH $PICKLE_PATH $NUM_WORKERS
-bash src/recovery_from_meta_yamls.sh $PODCASTS_PATH $PARQUET_PATH $NUM_WORKERS
\ No newline at end of file
+PICKLE_URL="https://huggingface.co/datasets/MTUCI/Balalaika500H/resolve/main/Balalaika500H.pkl"
+PARQUET_URL="https://huggingface.co/datasets/MTUCI/Balalaika500H/resolve/main/Balalaika500H.parquet"
+
+download_if_not_exists "$PICKLE_URL" "$PICKLE_PATH"
+download_if_not_exists "$PARQUET_URL" "$PARQUET_PATH"
+
+activate_venv ".user_venv"
+
+bash src/download/download_prepared.sh "$PODCASTS_PATH" "$PICKLE_PATH" "$NUM_WORKERS"
+bash src/recovery_from_meta_yamls.sh "$PODCASTS_PATH" "$PARQUET_PATH" "$NUM_WORKERS"
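
Reviewer note on the `format_timestamps` change in `src/transcription/transcription.py`: the grouping of per-character tokens into word-level rows may be easier to check in isolation. The sketch below is a standalone approximation that works on plain lists instead of an onnx-asr result object. The names `group_tokens_into_words` and `to_tsv` are hypothetical and not part of this diff. The first eight tokens and timestamps come from the docstring in the diff, and the trailing `д`, `а` pair is invented to exercise the last-word branch.

```python
from typing import List, Optional, Sequence, Tuple

Word = Tuple[float, float, str]


def group_tokens_into_words(tokens: Sequence[str], timestamps: Sequence[float]) -> List[Word]:
    """Group per-character tokens into (start, end, word) tuples (hypothetical helper)."""
    # Mirror the guard in the diff: empty or mismatched arrays yield nothing.
    if not tokens or not timestamps or len(tokens) != len(timestamps):
        return []

    words: List[Word] = []
    current = ""
    start: Optional[float] = None
    for token, ts in zip(tokens, timestamps):
        if token.strip() == "":
            # A whitespace token closes the current word; its timestamp becomes the word end.
            if current and start is not None:
                words.append((start, ts, current))
            current, start = "", None
        else:
            if start is None:
                start = ts  # the first character of a word sets its start time
            current += token

    if current and start is not None:
        # No trailing whitespace: the last token timestamp is used as the end.
        words.append((start, timestamps[-1], current))
    return words


def to_tsv(words: Sequence[Word]) -> str:
    """Render one start<TAB>end<TAB>word row per word (hypothetical helper)."""
    return "\n".join(f"{s:.3f}\t{e:.3f}\t{w}" for s, e, w in words)


tokens = ["с", "п", "а", "с", "и", "б", "о", " ", "д", "а"]
timestamps = [0.39, 0.44, 0.51, 0.54, 0.57, 0.63, 0.66, 0.75, 0.80, 0.85]
words = group_tokens_into_words(tokens, timestamps)
print(to_tsv(words))
```

One consequence worth noting: a word's end time is the timestamp of the whitespace token that follows it, while the final word ends at the timestamp of its own last character, so the last row of each `.tst` file may be slightly shorter than the others.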