From 4470cc1146ee58a2722b7239c57956b0fa160f9d Mon Sep 17 00:00:00 2001 From: sealad886 <155285242+sealad886@users.noreply.github.com> Date: Sun, 7 Dec 2025 04:20:12 +0000 Subject: [PATCH 1/4] feat: Add Voice Activity Detection and Speaker Diarization support - Introduced VAD functionality to filter silent audio regions, improving transcription efficiency. - Added speaker diarization capabilities using pyannote.audio, allowing identification of speakers in multi-speaker audio. - Updated CLI and README to reflect new features and usage examples. - Enhanced transcribe function to support VAD and diarization options. - Implemented RTTM format output for diarization results. Signed-off-by: sealad886 <155285242+sealad886@users.noreply.github.com> --- whisper/README.md | 70 ++++++++ whisper/mlx_whisper/__init__.py | 13 +- whisper/mlx_whisper/cli.py | 133 ++++++++++++++- whisper/mlx_whisper/diarize.py | 256 +++++++++++++++++++++++++++++ whisper/mlx_whisper/transcribe.py | 260 ++++++++++++++++++++++++++++-- whisper/mlx_whisper/vad.py | 197 ++++++++++++++++++++++ whisper/mlx_whisper/writers.py | 42 +++++ whisper/setup.py | 5 + 8 files changed, 959 insertions(+), 17 deletions(-) create mode 100644 whisper/mlx_whisper/diarize.py create mode 100644 whisper/mlx_whisper/vad.py diff --git a/whisper/README.md b/whisper/README.md index cd3bc684a..ab6549b4e 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -82,6 +82,76 @@ To see more transcription options use: >>> help(mlx_whisper.transcribe) ``` +### Voice Activity Detection (VAD) + +Enable Silero VAD to filter silent audio regions before transcription. This can +significantly speed up transcription for audio with long silent periods: + +```bash +# Enable VAD +mlx_whisper audio.mp3 --vad-filter + +# Customize VAD settings +mlx_whisper audio.mp3 --vad-filter --vad-threshold 0.6 --vad-min-silence-ms 1000 +``` + +In Python: + +```python +from mlx_whisper import transcribe +from mlx_whisper.vad import VadOptions + +result = transcribe("audio.mp3", vad_filter=True) + +# With custom options +vad_opts = VadOptions(threshold=0.6, min_silence_duration_ms=1000) +result = transcribe("audio.mp3", vad_filter=True, vad_options=vad_opts) +``` + +**Requirements**: `pip install torch` + +### Speaker Diarization + +Identify who is speaking when with pyannote.audio. Diarization adds speaker +labels to transcription segments: + +```bash +# Enable diarization (requires HuggingFace token) +export HF_TOKEN=your_token +mlx_whisper audio.mp3 --diarize --word-timestamps + +# Specify speaker count +mlx_whisper audio.mp3 --diarize --min-speakers 2 --max-speakers 4 + +# Output diarization in RTTM format +mlx_whisper audio.mp3 --diarize -f rttm +``` + +In Python: + +```python +from mlx_whisper import transcribe_with_diarization + +result = transcribe_with_diarization( + "audio.mp3", + hf_token="your_token", + word_timestamps=True +) + +# Access speaker info +for segment in result["segments"]: + speaker = segment.get("speaker", "Unknown") + print(f"{speaker}: {segment['text']}") + +# List of speakers +print(result["speakers"]) # ['SPEAKER_00', 'SPEAKER_01', ...] +``` + +**Requirements**: +- `pip install pyannote.audio pandas` +- Accept model terms at https://huggingface.co/pyannote/speaker-diarization-3.1 +- Set `HF_TOKEN` environment variable or pass `--hf-token` + ### Converting models > [!TIP] diff --git a/whisper/mlx_whisper/__init__.py b/whisper/mlx_whisper/__init__.py index 14c5197f0..094b4294e 100644 --- a/whisper/mlx_whisper/__init__.py +++ b/whisper/mlx_whisper/__init__.py @@ -2,4 +2,15 @@ from . import audio, decoding, load_models from ._version import __version__ -from .transcribe import transcribe +from .transcribe import transcribe, transcribe_with_diarization + +# Optional modules (may not be available if dependencies are missing or incompatible) +try: + from . import vad +except (ImportError, AttributeError): + vad = None + +try: + from . import diarize +except (ImportError, AttributeError): + diarize = None diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index ee8212648..961b25384 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -59,7 +59,7 @@ def str2bool(string): "-f", type=str, default="txt", - choices=["txt", "vtt", "srt", "tsv", "json", "all"], + choices=["txt", "vtt", "srt", "tsv", "json", "rttm", "all"], help="Format of the output file", ) parser.add_argument( @@ -92,6 +92,12 @@ def str2bool(string): default=5, help="Number of candidates when sampling with non-zero temperature", ) + parser.add_argument( + "--beam-size", + type=optional_int, + default=None, + help="Beam size for beam search (currently not implemented; option will be ignored)", + ) parser.add_argument( "--patience", type=float, @@ -199,6 +205,69 @@ def str2bool(string): default="0", help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file", ) + # VAD arguments + parser.add_argument( + "--vad-filter", + type=str2bool, + default=False, + help="Enable Silero VAD to filter silent audio before transcription", + ) + parser.add_argument( + "--vad-threshold", + type=float, + default=0.5, + help="VAD speech detection threshold (0.0-1.0)", + ) + parser.add_argument( + "--vad-min-silence-ms", + type=int, + default=2000, + help="Minimum silence duration to split speech segments (ms)", + ) + parser.add_argument( + "--vad-speech-pad-ms", + type=int, + default=400, + help="Padding added around speech segments (ms)", + ) + # Diarization arguments + parser.add_argument( + "--diarize", + type=str2bool, + default=False, + help="Enable speaker diarization (requires pyannote.audio)", + ) + parser.add_argument( + "--hf-token", + type=str, + default=None, + help="HuggingFace token for pyannote models (or set HF_TOKEN env var)", + ) + parser.add_argument( + "--diarize-model", + type=str, + default="pyannote/speaker-diarization-3.1", + help="Diarization model to use", + ) + parser.add_argument( + "--min-speakers", + type=optional_int, + default=None, + help="Minimum number of speakers for diarization", + ) + parser.add_argument( + "--max-speakers", + type=optional_int, + default=None, + help="Maximum number of speakers for diarization", + ) + parser.add_argument( + "--diarize-device", + type=str, + default="cpu", + choices=["cpu", "cuda", "mps"], + help="Device for diarization model", + ) return parser @@ -232,6 +301,40 @@ def main(): if writer_args["max_words_per_line"] and writer_args["max_line_width"]: warnings.warn("--max-words-per-line has no effect with --max-line-width") + # Extract VAD options + vad_filter = args.pop("vad_filter") + vad_threshold = args.pop("vad_threshold") + vad_min_silence_ms = args.pop("vad_min_silence_ms") + vad_speech_pad_ms = args.pop("vad_speech_pad_ms") + + vad_options = None + if vad_filter: + from .vad import VadOptions + + vad_options = VadOptions( + threshold=vad_threshold, + min_silence_duration_ms=vad_min_silence_ms, + speech_pad_ms=vad_speech_pad_ms, + ) + elif any( + [vad_threshold != 0.5, vad_min_silence_ms != 2000, vad_speech_pad_ms != 400] + ): + warnings.warn("VAD options have no effect without --vad-filter") + + # Extract diarization options + diarize = args.pop("diarize") + hf_token = args.pop("hf_token") or os.environ.get("HF_TOKEN") + diarize_model = args.pop("diarize_model") + min_speakers = args.pop("min_speakers") + max_speakers = args.pop("max_speakers") + diarize_device = args.pop("diarize_device") + + if diarize and not hf_token: + warnings.warn( + "Diarization requires a HuggingFace token. " + "Set --hf-token or HF_TOKEN environment variable." + ) + for audio_obj in args.pop("audio"): if audio_obj == "-": # receive the contents from stdin rather than read a file @@ -241,11 +344,29 @@ def main(): else: output_name = output_name or pathlib.Path(audio_obj).stem try: - result = transcribe( - audio_obj, - path_or_hf_repo=path_or_hf_repo, - **args, - ) + if diarize: + from .transcribe import transcribe_with_diarization + + result = transcribe_with_diarization( + audio_obj, + path_or_hf_repo=path_or_hf_repo, + hf_token=hf_token, + diarize_model=diarize_model, + min_speakers=min_speakers, + max_speakers=max_speakers, + device=diarize_device, + vad_filter=vad_filter, + vad_options=vad_options, + **args, + ) + else: + result = transcribe( + audio_obj, + path_or_hf_repo=path_or_hf_repo, + vad_filter=vad_filter, + vad_options=vad_options, + **args, + ) writer(result, output_name, **writer_args) except Exception as e: traceback.print_exc() diff --git a/whisper/mlx_whisper/diarize.py b/whisper/mlx_whisper/diarize.py new file mode 100644 index 000000000..99a21ccff --- /dev/null +++ b/whisper/mlx_whisper/diarize.py @@ -0,0 +1,256 @@ +# Copyright © 2024 Apple Inc. + +""" +Speaker diarization module for mlx-whisper. + +Provides optional speaker diarization using pyannote.audio to identify +who is speaking when in multi-speaker audio. +""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +# Graceful dependency handling +_PYANNOTE_AVAILABLE = False +_PANDAS_AVAILABLE = False + +try: + import pandas as pd + + _PANDAS_AVAILABLE = True +except ImportError: + pd = None + +try: + from pyannote.audio import Pipeline + + _PYANNOTE_AVAILABLE = True +except ImportError: + Pipeline = None + + +def _extract_annotation(diarization_output): + """Normalize diarization output to a pyannote Annotation-like object. + + Supports: + - Legacy `Annotation` outputs with ``itertracks`` + - `DiarizeOutput` (pyannote.audio>=3.1) providing `speaker_diarization` + - Dict-like outputs with `speaker_diarization` or `annotation` + Falls back to `exclusive_speaker_diarization` when the primary annotation is + present but empty. + """ + + if hasattr(diarization_output, "itertracks"): + return diarization_output + + # pyannote.audio>=3.1 returns DiarizeOutput + if hasattr(diarization_output, "speaker_diarization"): + ann = diarization_output.speaker_diarization + try: + if hasattr(diarization_output, "exclusive_speaker_diarization") and len(ann) == 0: # type: ignore[arg-type] + ann = diarization_output.exclusive_speaker_diarization + except Exception: + pass + return ann + + # Some variants may expose `.annotation` + if hasattr(diarization_output, "annotation"): + return diarization_output.annotation + + # dict-like outputs + if isinstance(diarization_output, dict): + if "speaker_diarization" in diarization_output: + ann = diarization_output["speaker_diarization"] + if ( + "exclusive_speaker_diarization" in diarization_output + and hasattr(ann, "__len__") + and len(ann) == 0 + ): + ann = diarization_output["exclusive_speaker_diarization"] + return ann + if "annotation" in diarization_output: + return diarization_output["annotation"] + + raise AttributeError( + "Unsupported diarization output: expected an object with `itertracks` " + "or a `speaker_diarization`/`annotation` attribute." + ) + + +def is_available() -> bool: + """Check if diarization dependencies are available.""" + return _PYANNOTE_AVAILABLE and _PANDAS_AVAILABLE + + +class DiarizationUnavailableError(ImportError): + """Raised when diarization is requested but dependencies are missing.""" + + def __init__(self): + missing = [] + if not _PYANNOTE_AVAILABLE: + missing.append("pyannote.audio>=3.1") + if not _PANDAS_AVAILABLE: + missing.append("pandas") + + deps = ", ".join(missing) + super().__init__( + f"Diarization requires: {deps}\n" + f"Install with: pip install {' '.join(missing)}\n" + "Note: pyannote.audio requires a HuggingFace token for model access.\n" + "See: https://huggingface.co/pyannote/speaker-diarization-3.1" + ) + + +class DiarizationPipeline: + """Wrapper for pyannote.audio speaker diarization pipeline. + + Identifies speaker segments in audio, producing a timeline of + who is speaking when. + """ + + def __init__( + self, + model_name: str = "pyannote/speaker-diarization-3.1", + token: Optional[str] = None, + device: str = "cpu", + ): + """Initialize diarization pipeline. + + Args: + model_name: HuggingFace model ID for diarization + use_auth_token: HuggingFace token (required for gated models) + device: Device to run on ('cpu', 'cuda', 'mps') + """ + if not is_available(): + raise DiarizationUnavailableError() + + import torch, pyannote + from pyannote.audio import Pipeline + + torch.serialization.add_safe_globals([torch.torch_version.TorchVersion]) + torch.serialization.add_safe_globals([pyannote.audio.core.task.Specifications]) + torch.serialization.add_safe_globals([pyannote.audio.core.task.Problem]) + torch.serialization.add_safe_globals([pyannote.audio.core.task.Resolution]) + + self.device = torch.device(device) + self.model: Pipeline = Pipeline.from_pretrained( + model_name, token=token + ) + self.model.to(self.device) + + def __call__( + self, + audio: Union[str, np.ndarray], + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + sample_rate: int = 16000, + ): + """Run speaker diarization on audio. + + Args: + audio: Audio file path or waveform array + num_speakers: Exact number of speakers (if known) + min_speakers: Minimum number of speakers + max_speakers: Maximum number of speakers + sample_rate: Sample rate of audio array + + Returns: + pandas DataFrame with columns: start, end, speaker + """ + import pandas as pd + import torch + + if isinstance(audio, str): + from .audio import load_audio + + audio = np.array(load_audio(audio)) + + audio_data = { + "waveform": torch.from_numpy(audio[None, :]).float(), + "sample_rate": sample_rate, + } + + diarization = self.model( + audio_data, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + # Normalize output across pyannote versions + annotation = _extract_annotation(diarization) + + # Convert to DataFrame + segments = [] + for turn, _, speaker in annotation.itertracks(yield_label=True): + segments.append({"start": turn.start, "end": turn.end, "speaker": speaker}) + + return pd.DataFrame(segments) + + +def assign_word_speakers( + diarize_df, segments: List[Dict], fill_nearest: bool = False # pandas DataFrame +) -> List[Dict]: + """Assign speaker labels to transcript segments and words. + + Uses intersection-based assignment: each segment/word is assigned + to the speaker with maximum time overlap. + + Args: + diarize_df: DataFrame with start, end, speaker columns + segments: List of transcript segments from transcribe() + fill_nearest: If True, assign speaker even when no overlap + + Returns: + Segments with 'speaker' field added + """ + import numpy as np + + for seg in segments: + # Calculate intersection with each diarization segment + df_copy = diarize_df.copy() + df_copy["intersection"] = np.minimum( + df_copy["end"].values, seg["end"] + ) - np.maximum(df_copy["start"].values, seg["start"]) + + # Filter to overlapping segments + if fill_nearest: + dia_tmp = df_copy + else: + dia_tmp = df_copy[df_copy["intersection"] > 0] + + if len(dia_tmp) > 0: + # Assign speaker with maximum intersection + speaker = ( + dia_tmp.groupby("speaker")["intersection"] + .sum() + .sort_values(ascending=False) + .index[0] + ) + seg["speaker"] = speaker + + # Assign speakers to individual words + if "words" in seg: + for word in seg["words"]: + if "start" in word and "end" in word: + df_copy["word_intersection"] = np.minimum( + df_copy["end"].values, word["end"] + ) - np.maximum(df_copy["start"].values, word["start"]) + + if fill_nearest: + word_dia = df_copy + else: + word_dia = df_copy[df_copy["word_intersection"] > 0] + + if len(word_dia) > 0: + word_speaker = ( + word_dia.groupby("speaker")["word_intersection"] + .sum() + .sort_values(ascending=False) + .index[0] + ) + word["speaker"] = word_speaker + + return segments diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index bced16a58..29e47f2a6 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -14,6 +14,7 @@ N_FRAMES, N_SAMPLES, SAMPLE_RATE, + load_audio, log_mel_spectrogram, pad_or_trim, ) @@ -21,6 +22,13 @@ from .load_models import load_model from .timing import add_word_timestamps from .tokenizer import LANGUAGES, get_tokenizer +from .vad import ( + VadOptions, + SileroVAD, + SpeechTimestampsMap, + get_speech_chunks, + is_available as vad_is_available, +) def _format_timestamp(seconds: float): @@ -59,6 +67,91 @@ def get_model(cls, model_path: str, dtype: mx.Dtype): return cls.model +def _sanitize_decoding_options(decode_options: dict) -> dict: + """Remove unsupported beam search options. + + Beam search is not implemented in the current decoder, so any related + options (``beam_size``/``patience``) are ignored to avoid runtime errors. + """ + + options = dict(decode_options) + beam_size = options.get("beam_size") + patience = options.get("patience") + + if beam_size is not None: + warnings.warn( + "beam_size is not supported (beam search decoder unavailable); " + "falling back to greedy decoding.", + stacklevel=2, + ) + options.pop("beam_size", None) + if patience is not None: + warnings.warn( + "patience ignored because beam search is not supported.", + stacklevel=2, + ) + options.pop("patience", None) + elif patience is not None: + warnings.warn( + "patience requires beam search, which is not implemented; ignoring patience.", + stacklevel=2, + ) + options.pop("patience", None) + + return options + + +def _filter_empty_segments(segments: List[dict]) -> List[dict]: + """Remove segments with empty text or zero duration.""" + + filtered: List[dict] = [] + for segment in segments: + text = segment.get("text", "") + start = segment.get("start") + end = segment.get("end") + + if text.strip() == "": + continue + if start is not None and end is not None and start == end: + continue + + filtered.append(segment) + + return filtered + + +def _deduplicate_segments( + segments: List[dict], tolerance: float = 0.5 +) -> List[dict]: + """Remove consecutively repeated segments. + + When Whisper gets stuck in a repetition loop it emits the same text with + increasing timestamps. This helper drops segments whose stripped text + matches the previous segment and whose start time is within ``tolerance`` + seconds of the previous segment's end. When merging, the retained segment's + end time is extended to cover the dropped segment. + """ + + if not segments: + return segments + + deduped: List[dict] = [dict(segments[0])] + for seg in segments[1:]: + prev = deduped[-1] + prev_text = prev.get("text", "").strip() + cur_text = seg.get("text", "").strip() + + gap = seg.get("start", 0) - prev.get("end", 0) + if cur_text == prev_text and abs(gap) <= tolerance: + # Extend the retained segment's end time to cover this duplicate + prev["end"] = max(prev.get("end", 0), seg.get("end", 0)) + continue + + deduped.append(dict(seg)) + + return deduped + + def transcribe( audio: Union[str, np.ndarray, mx.array], *, @@ -75,6 +168,8 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + vad_filter: bool = False, + vad_options: Optional[VadOptions] = None, **decode_options, ): """ @@ -137,15 +232,50 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + vad_filter: bool + Enable Voice Activity Detection to filter silent audio before transcription. + Requires torch to be installed. + + vad_options: Optional[VadOptions] + Configuration options for VAD. Only used when vad_filter is True. + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and the spoken language ("language"), which is detected when `decode_options["language"]` is None. """ + # Remove unsupported beam search parameters to avoid runtime errors. + decode_options = _sanitize_decoding_options(decode_options) + dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 model = ModelHolder.get_model(path_or_hf_repo, dtype) + # VAD preprocessing + timestamps_map = None + if vad_filter: + if not vad_is_available(): + from .vad import VadUnavailableError + raise VadUnavailableError() + + # Load audio if path + if isinstance(audio, str): + audio_array = np.array(load_audio(audio)) + elif isinstance(audio, mx.array): + audio_array = np.array(audio) + else: + audio_array = audio + + # Run VAD + vad = SileroVAD() + speech_timestamps = vad(audio_array, vad_options) + + if speech_timestamps: + # Create timestamp mapper for restoring original times + timestamps_map = SpeechTimestampsMap(speech_timestamps, SAMPLE_RATE) + # Concatenate speech chunks + audio = get_speech_chunks(audio_array, speech_timestamps) + # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) content_frames = mel.shape[-2] - N_FRAMES @@ -493,6 +623,9 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: if last_word_end is not None: last_speech_timestamp = last_word_end + # drop empty or zero-length segments before logging/appending + current_segments = _filter_empty_segments(current_segments) + if verbose: for segment in current_segments: start, end, text = ( @@ -503,16 +636,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}" print(make_safe(line)) - # if a segment is instantaneous or does not contain text, clear it - for i, segment in enumerate(current_segments): - if ( - segment["start"] == segment["end"] - or segment["text"].strip() == "" - ): - segment["text"] = "" - segment["tokens"] = [] - segment["words"] = [] - all_segments.extend( [ {"id": i, **segment} @@ -536,8 +659,125 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: # update progress bar pbar.update(min(content_frames, seek) - previous_seek) + # Restore timestamps if VAD was used + if timestamps_map is not None: + for segment in all_segments: + segment["start"] = timestamps_map.get_original_time(segment["start"]) + segment["end"] = timestamps_map.get_original_time(segment["end"]) + + if "words" in segment: + for word in segment["words"]: + if "start" in word: + word["start"] = timestamps_map.get_original_time(word["start"]) + if "end" in word: + word["end"] = timestamps_map.get_original_time(word["end"]) + + # Deduplicate repetition-loop segments + all_segments = _deduplicate_segments(all_segments) + return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments, language=language, ) + + +def transcribe_with_diarization( + audio: Union[str, np.ndarray, mx.array], + *, + hf_token: Optional[str] = None, + diarize_model: str = "pyannote/speaker-diarization-3.1", + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + device: str = "cpu", + **transcribe_kwargs, +) -> dict: + """Transcribe audio with speaker diarization. + + Runs transcription followed by speaker diarization and assignment. + + Parameters + ---------- + audio: Union[str, np.ndarray, mx.array] + The path to the audio file to open, or the audio waveform + + hf_token: Optional[str] + HuggingFace token for pyannote models (or set HF_TOKEN env var) + + diarize_model: str + Diarization model to use (default: pyannote/speaker-diarization-3.1) + + num_speakers: Optional[int] + Exact number of speakers (if known) + + min_speakers: Optional[int] + Minimum number of speakers + + max_speakers: Optional[int] + Maximum number of speakers + + device: str + Device for diarization model ('cpu', 'cuda', 'mps') + + **transcribe_kwargs + Additional arguments passed to transcribe() + + Returns + ------- + dict + Transcription result with speaker assignments including: + - text: Full transcription text + - segments: List of segments with speaker labels + - language: Detected language + - speakers: List of unique speaker IDs + - diarization: Raw diarization info with segments + """ + from .diarize import ( + DiarizationPipeline, + assign_word_speakers, + is_available as diarize_is_available, + DiarizationUnavailableError, + ) + + if not diarize_is_available(): + raise DiarizationUnavailableError() + + # Run transcription + result = transcribe(audio, **transcribe_kwargs) + + # Load audio for diarization + if isinstance(audio, str): + audio_array = np.array(load_audio(audio)) + elif isinstance(audio, mx.array): + audio_array = np.array(audio) + else: + audio_array = audio + + # Run diarization + diarize_pipeline = DiarizationPipeline( + model_name=diarize_model, + token=hf_token, + device=device, + ) + + diarize_df = diarize_pipeline( + audio_array, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + # Assign speakers to segments and words + result["segments"] = assign_word_speakers(diarize_df, result["segments"]) + + # Add speaker list + result["speakers"] = sorted(diarize_df["speaker"].unique().tolist()) + + # Add raw diarization segments + result["diarization"] = { + "num_speakers": len(result["speakers"]), + "segments": diarize_df.to_dict(orient="records"), + } + + return result diff --git a/whisper/mlx_whisper/vad.py b/whisper/mlx_whisper/vad.py new file mode 100644 index 000000000..d7367df43 --- /dev/null +++ b/whisper/mlx_whisper/vad.py @@ -0,0 +1,197 @@ +# Copyright © 2024 Apple Inc. + +""" +Voice Activity Detection (VAD) module for mlx-whisper. + +Provides optional VAD preprocessing using Silero VAD to filter +silent audio regions before transcription, improving speed and +reducing hallucinations on audio with significant silence. +""" + +import bisect +from dataclasses import dataclass +from typing import Dict, List, Optional + +import numpy as np + +# Graceful dependency handling +_TORCH_AVAILABLE = False +try: + import torch + + _TORCH_AVAILABLE = True +except ImportError: + pass + + +def is_available() -> bool: + """Check if VAD dependencies are available.""" + return _TORCH_AVAILABLE + + +class VadUnavailableError(ImportError): + """Raised when VAD is requested but dependencies are missing.""" + + def __init__(self): + super().__init__( + "VAD requires PyTorch. Install with: pip install torch\n" + "Or install mlx-whisper with VAD support: pip install mlx-whisper[vad]" + ) + + +@dataclass +class VadOptions: + """Configuration options for Voice Activity Detection. + + Attributes: + threshold: Speech detection threshold (0.0-1.0). Higher values + require more confidence for speech detection. Default: 0.5 + min_speech_duration_ms: Minimum duration of speech segment in + milliseconds. Shorter segments are discarded. Default: 250 + max_speech_duration_s: Maximum duration of speech segment in + seconds. Longer segments are split. Default: inf + min_silence_duration_ms: Minimum silence duration to split + speech segments in milliseconds. Default: 2000 + speech_pad_ms: Padding added to each side of speech segments + in milliseconds. Default: 400 + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + speech_pad_ms: int = 400 + + +class SileroVAD: + """Wrapper for Silero VAD model. + + Loads the model lazily on first use to avoid import overhead + when VAD is not needed. + """ + + def __init__(self): + self._model = None + self._get_speech_timestamps = None + + def _load_model(self): + """Load Silero VAD model via torch.hub.""" + if not is_available(): + raise VadUnavailableError() + + import torch + + self._model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad", + model="silero_vad", + force_reload=False, + onnx=False, + trust_repo=True, + ) + self._get_speech_timestamps = utils[0] + + def __call__( + self, audio: np.ndarray, options: Optional[VadOptions] = None + ) -> List[Dict[str, int]]: + """Detect speech segments in audio. + + Args: + audio: Audio waveform as numpy array (16kHz, mono) + options: VAD configuration options + + Returns: + List of dictionaries with 'start' and 'end' keys + representing speech segment boundaries in samples. + """ + import torch + + if self._model is None: + self._load_model() + + if options is None: + options = VadOptions() + + wav = torch.from_numpy(audio).float() + + speech_timestamps = self._get_speech_timestamps( + wav, + self._model, + threshold=options.threshold, + min_speech_duration_ms=options.min_speech_duration_ms, + max_speech_duration_s=options.max_speech_duration_s, + min_silence_duration_ms=options.min_silence_duration_ms, + speech_pad_ms=options.speech_pad_ms, + return_seconds=False, # Return in samples + ) + + return speech_timestamps + + +def get_speech_chunks( + audio: np.ndarray, timestamps: List[Dict[str, int]] +) -> np.ndarray: + """Concatenate speech segments from audio. + + Args: + audio: Full audio waveform + timestamps: List of speech segment boundaries from VAD + + Returns: + Concatenated audio containing only speech segments. + Returns original audio if no timestamps provided. + """ + if not timestamps: + return audio + + chunks = [audio[ts["start"] : ts["end"]] for ts in timestamps] + return np.concatenate(chunks) + + +class SpeechTimestampsMap: + """Maps timestamps from VAD-filtered audio back to original timeline. + + When VAD removes silent segments, timestamps in the transcription + refer to the filtered audio. This class provides conversion back + to the original audio timeline. + """ + + def __init__(self, chunks: List[Dict[str, int]], sampling_rate: int = 16000): + """Initialize timestamp mapping. + + Args: + chunks: List of speech segment boundaries from VAD + sampling_rate: Audio sample rate (default: 16000) + """ + self.sampling_rate = sampling_rate + self.chunk_end_sample: List[int] = [] + self.total_silence_before: List[float] = [] + + previous_end = 0 + silent_samples = 0 + + for chunk in chunks: + silent_samples += chunk["start"] - previous_end + previous_end = chunk["end"] + self.chunk_end_sample.append(chunk["end"] - silent_samples) + self.total_silence_before.append(silent_samples / sampling_rate) + + def get_original_time(self, time: float) -> float: + """Convert filtered audio time to original audio time. + + Args: + time: Timestamp in filtered audio (seconds) + + Returns: + Corresponding timestamp in original audio (seconds) + """ + sample = int(time * self.sampling_rate) + # Use bisect_left: find the chunk this sample falls within + chunk_idx = bisect.bisect_left(self.chunk_end_sample, sample) + + if chunk_idx >= len(self.chunk_end_sample): + chunk_idx = len(self.chunk_end_sample) - 1 + + if chunk_idx < 0 or not self.chunk_end_sample: + return time + + return round(self.total_silence_before[chunk_idx] + time, 3) diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py index cdb35063c..bf11b7486 100644 --- a/whisper/mlx_whisper/writers.py +++ b/whisper/mlx_whisper/writers.py @@ -148,6 +148,9 @@ def iterate_subtitles(): subtitle_start = self.format_timestamp(subtitle[0]["start"]) subtitle_end = self.format_timestamp(subtitle[-1]["end"]) subtitle_text = "".join([word["word"] for word in subtitle]) + # Add speaker label if first word has speaker info + if subtitle and "speaker" in subtitle[0]: + subtitle_text = f"[{subtitle[0]['speaker']}] {subtitle_text}" if highlight_words: last = subtitle_start all_words = [timing["word"] for timing in subtitle] @@ -175,6 +178,9 @@ def iterate_subtitles(): segment_start = self.format_timestamp(segment["start"]) segment_end = self.format_timestamp(segment["end"]) segment_text = segment["text"].strip().replace("-->", "->") + # Prepend speaker label if available + if "speaker" in segment: + segment_text = f"[{segment['speaker']}] {segment_text}" yield segment_start, segment_end, segment_text def format_timestamp(self, seconds: float): @@ -243,6 +249,41 @@ def write_result( json.dump(result, file, ensure_ascii=False) +class WriteRTTM(ResultWriter): + """Write diarization results in RTTM format. + + RTTM (Rich Transcription Time Marked) is the standard format + for speaker diarization evaluation. + Format: SPEAKER file 1 start duration speaker + """ + + extension: str = "rttm" + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + file_id = options.get("file_id", "audio") if options else "audio" + + if "diarization" in result: + for seg in result["diarization"]["segments"]: + duration = seg["end"] - seg["start"] + print( + f"SPEAKER {file_id} 1 {seg['start']:.3f} {duration:.3f} " + f" {seg['speaker']} ", + file=file, + ) + else: + # Fall back to segment-level speakers + for segment in result["segments"]: + if "speaker" in segment: + duration = segment["end"] - segment["start"] + print( + f"SPEAKER {file_id} 1 {segment['start']:.3f} {duration:.3f} " + f" {segment['speaker']} ", + file=file, + ) + + def get_writer( output_format: str, output_dir: str ) -> Callable[[dict, TextIO, dict], None]: @@ -252,6 +293,7 @@ def get_writer( "srt": WriteSRT, "tsv": WriteTSV, "json": WriteJSON, + "rttm": WriteRTTM, } if output_format == "all": diff --git a/whisper/setup.py b/whisper/setup.py index 0cabd64b7..4c614deec 100644 --- a/whisper/setup.py +++ b/whisper/setup.py @@ -26,6 +26,11 @@ url="https://github.com/ml-explore/mlx-examples", license="MIT", install_requires=requirements, + extras_require={ + "vad": ["torch"], + "diarize": ["pyannote.audio>=3.1", "pandas", "torch"], + "all": ["torch", "pyannote.audio>=3.1", "pandas"], + }, packages=find_namespace_packages(), include_package_data=True, python_requires=">=3.8", From 1d8889f33a52c5aacaccda348899e585705f27ac Mon Sep 17 00:00:00 2001 From: sealad886 <155285242+sealad886@users.noreply.github.com> Date: Sun, 7 Dec 2025 15:22:52 +0000 Subject: [PATCH 2/4] chore: Bump version to 0.5.0 Signed-off-by: sealad886 <155285242+sealad886@users.noreply.github.com> --- whisper/mlx_whisper/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py index 130888d10..76ef5d3b7 100644 --- a/whisper/mlx_whisper/_version.py +++ b/whisper/mlx_whisper/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.4.3" +__version__ = "0.5.0" From f03e393c3f61fc47103550b4128cf4fcfbe62cb2 Mon Sep 17 00:00:00 2001 From: sealad886 <155285242+sealad886@users.noreply.github.com> Date: Sun, 7 Dec 2025 16:03:35 +0000 Subject: [PATCH 3/4] feat: Implement beam search decoder with patience parameter - Add BeamSearchDecoder class ported from OpenAI whisper (torch to MLX) - Support beam_size and patience parameters in DecodingOptions - Update DecodingTask to use BeamSearchDecoder when beam_size is provided - Handle different return types between GreedyDecoder (mx.array) and BeamSearchDecoder (List[List[List[int]]]) in run() - Remove _sanitize_decoding_options() workaround from transcribe.py - Add comprehensive unit tests for beam search decoder - Update README with beam search CLI and API documentation - Add .gitignore to exclude test artifacts The patience parameter controls early stopping via max_candidates = round(beam_size * patience) --- whisper/.gitignore | 41 ++++++++ whisper/README.md | 27 +++++ whisper/mlx_whisper/decoding.py | 164 ++++++++++++++++++++++++++++-- whisper/mlx_whisper/transcribe.py | 43 +------- 4 files changed, 228 insertions(+), 47 deletions(-) create mode 100644 whisper/.gitignore diff --git a/whisper/.gitignore b/whisper/.gitignore new file mode 100644 index 000000000..4338b3558 --- /dev/null +++ b/whisper/.gitignore @@ -0,0 +1,41 @@ +# Test artifacts +output/ + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +.env +.venv +env/ +venv/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# pytest +.pytest_cache/ +.coverage +htmlcov/ diff --git a/whisper/README.md b/whisper/README.md index ab6549b4e..2dc50847a 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -82,6 +82,33 @@ To see more transcription options use: >>> help(mlx_whisper.transcribe) ``` +### Beam Search Decoding + +By default, mlx-whisper uses greedy decoding. Enable beam search for potentially +more accurate transcriptions at the cost of speed: + +```bash +# Enable beam search with beam size 5 +mlx_whisper audio.mp3 --beam-size 5 + +# Adjust patience for earlier/later stopping (default: 1.0) +mlx_whisper audio.mp3 --beam-size 5 --patience 1.5 +``` + +In Python: + +```python +result = mlx_whisper.transcribe( + "audio.mp3", + beam_size=5, + patience=1.0 +) +``` + +The `patience` parameter controls early stopping: decoding stops when +`round(beam_size * patience)` finished sequences have been collected. +Higher patience values explore more candidates before stopping. + ### Voice Activity Detection (VAD) Enable Silero VAD to filter silent audio regions before transcription. This can diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py index 814dc95ca..153b98f3e 100644 --- a/whisper/mlx_whisper/decoding.py +++ b/whisper/mlx_whisper/decoding.py @@ -283,6 +283,143 @@ def finalize(self, tokens: mx.array, sum_logprobs: mx.array): return tokens, sum_logprobs +class BeamSearchDecoder(TokenDecoder): + """ + Beam search decoder that maintains multiple candidate sequences and selects + the best ones based on cumulative log probabilities. + + The patience parameter controls early stopping: decoding stops when + max_candidates = round(beam_size * patience) finished sequences have been + collected for each audio input. + """ + + def __init__( + self, + beam_size: int, + eot: int, + inference: Inference, + patience: Optional[float] = None, + ): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences = None + + assert ( + self.max_candidates > 0 + ), f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + self.finished_sequences = None + + def update( + self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array + ) -> Tuple[mx.array, bool, mx.array]: + if tokens.shape[0] % self.beam_size != 0: + raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") + + n_audio = tokens.shape[0] // self.beam_size + if self.finished_sequences is None: + self.finished_sequences = [{} for _ in range(n_audio)] + + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + + # Convert to numpy for the beam search logic (topk, indexing) + logprobs_np = np.array(logprobs) + sum_logprobs_np = np.array(sum_logprobs) + tokens_list = tokens.tolist() + + next_tokens, source_indices, finished_sequences = [], [], [] + new_sum_logprobs = [] + + for i in range(n_audio): + scores, sources, finished = {}, {}, {} + + # Calculate cumulative log probabilities for possible candidates + for j in range(self.beam_size): + idx = i * self.beam_size + j + prefix = tuple(tokens_list[idx]) + + # Get top beam_size+1 candidates for this beam + top_k = self.beam_size + 1 + top_indices = np.argpartition(logprobs_np[idx], -top_k)[-top_k:] + top_logprobs = logprobs_np[idx][top_indices] + + for logprob, token in zip(top_logprobs, top_indices): + new_logprob = sum_logprobs_np[idx] + logprob + sequence = prefix + (int(token),) + scores[sequence] = float(new_logprob) + sources[sequence] = idx + + # Rank candidates and keep top beam_size non-finished sequences + saved = 0 + for sequence in sorted(scores, key=scores.get, reverse=True): + if sequence[-1] == self.eot: + finished[sequence] = scores[sequence] + else: + new_sum_logprobs.append(scores[sequence]) + next_tokens.append(list(sequence)) + source_indices.append(sources[sequence]) + + saved += 1 + if saved == self.beam_size: + break + + finished_sequences.append(finished) + + # Convert back to mx arrays + tokens = mx.array(next_tokens) + sum_logprobs = mx.array(new_sum_logprobs) + + # Rearrange KV cache according to selected beams + self.inference.rearrange_kv_cache(source_indices) + + # Add newly finished sequences to self.finished_sequences + assert len(self.finished_sequences) == len(finished_sequences) + for previously_finished, newly_finished in zip( + self.finished_sequences, finished_sequences + ): + for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): + if len(previously_finished) >= self.max_candidates: + break + previously_finished[seq] = newly_finished[seq] + + # Mark as completed if all audio inputs have enough finished sequences + completed = all( + len(sequences) >= self.max_candidates + for sequences in self.finished_sequences + ) + return tokens, completed, sum_logprobs + + def finalize( + self, tokens: mx.array, sum_logprobs: mx.array + ) -> Tuple[Sequence[Sequence[List[int]]], List[List[float]]]: + # Collect all finished sequences; add unfinished ones if not enough + sum_logprobs_np = np.array(sum_logprobs) + tokens_np = np.array(tokens) + + for i, sequences in enumerate(self.finished_sequences): + if len(sequences) < self.beam_size: + # Add highest-scoring unfinished sequences + for j in np.argsort(sum_logprobs_np[i])[::-1]: + sequence = tuple(tokens_np[i, j].tolist()) + (self.eot,) + sequences[sequence] = float(sum_logprobs_np[i][j]) + if len(sequences) >= self.beam_size: + break + + # Return as List[List[List[int]]] for compatibility with run() + tokens_out: List[List[List[int]]] = [ + [list(seq) for seq in sequences.keys()] + for sequences in self.finished_sequences + ] + sum_logprobs_out: List[List[float]] = [ + list(sequences.values()) for sequences in self.finished_sequences + ] + return tokens_out, sum_logprobs_out + + class LogitFilter: def apply(self, logits: mx.array, tokens: mx.array) -> mx.array: """Apply any filtering or masking to logits @@ -434,7 +571,9 @@ def __init__(self, model: "Whisper", options: DecodingOptions): # decoder: implements how to select the next tokens, given the autoregressive distribution if options.beam_size is not None: - raise NotImplementedError("Beam search decoder is not yet implemented") + self.decoder = BeamSearchDecoder( + options.beam_size, tokenizer.eot, self.inference, options.patience + ) else: self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) @@ -659,14 +798,25 @@ def run(self, mel: mx.array) -> List[DecodingResult]: # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) - tokens = tokens[..., self.sample_begin :] - # eval and convert to list - mx.eval(tokens, sum_logprobs, no_speech_probs) - tokens = tokens.tolist() - sum_logprobs = sum_logprobs.tolist() + # Handle different return types from decoders + if isinstance(tokens, mx.array): + # GreedyDecoder returns a 3D mx.array + tokens = tokens[..., self.sample_begin:] + mx.eval(tokens, sum_logprobs, no_speech_probs) + tokens = tokens.tolist() + sum_logprobs = sum_logprobs.tolist() + else: + # BeamSearchDecoder returns List[List[List[int]]] + mx.eval(no_speech_probs) + tokens = [ + [t[self.sample_begin:] for t in s] + for s in tokens + ] + # sum_logprobs is already List[List[float]] + no_speech_probs = no_speech_probs.tolist() - tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens] + tokens = [[t[: t.index(tokenizer.eot)] if tokenizer.eot in t else t for t in s] for s in tokens] # select the top-ranked sample in each group selected = self.sequence_ranker.rank(tokens, sum_logprobs) diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index 29e47f2a6..14dbad063 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -2,7 +2,7 @@ import sys import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, cast import mlx.core as mx import numpy as np @@ -67,40 +67,6 @@ def get_model(cls, model_path: str, dtype: mx.Dtype): return cls.model -def _sanitize_decoding_options(decode_options: dict) -> dict: - """Remove unsupported beam search options. - - Beam search is not implemented in the current decoder, so any related - options (``beam_size``/``patience``) are ignored to avoid runtime errors. - """ - - options = dict(decode_options) - beam_size = options.get("beam_size") - patience = options.get("patience") - - if beam_size is not None: - warnings.warn( - "beam_size is not supported (beam search decoder unavailable); " - "falling back to greedy decoding.", - stacklevel=2, - ) - options.pop("beam_size", None) - if patience is not None: - warnings.warn( - "patience ignored because beam search is not supported.", - stacklevel=2, - ) - options.pop("patience", None) - elif patience is not None: - warnings.warn( - "patience requires beam search, which is not implemented; ignoring patience.", - stacklevel=2, - ) - options.pop("patience", None) - - return options - - def _filter_empty_segments(segments: List[dict]) -> List[dict]: """Remove segments with empty text or zero duration.""" @@ -245,9 +211,6 @@ def transcribe( the spoken language ("language"), which is detected when `decode_options["language"]` is None. """ - # Remove unsupported beam search parameters to avoid runtime errors. - decode_options = _sanitize_decoding_options(decode_options) - dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 model = ModelHolder.get_model(path_or_hf_repo, dtype) @@ -307,7 +270,7 @@ def transcribe( f"Detected language: {LANGUAGES[decode_options['language']].title()}" ) - language: str = decode_options["language"] + language: str = cast(str, decode_options["language"]) task: str = decode_options.get("task", "transcribe") tokenizer = get_tokenizer( model.is_multilingual, @@ -393,7 +356,7 @@ def decode_with_fallback(segment: mx.array) -> DecodingResult: def new_segment( *, start: float, end: float, tokens: mx.array, result: DecodingResult ): - tokens = tokens.tolist() + tokens = tokens.tolist() # type: ignore text_tokens = [token for token in tokens if token < tokenizer.eot] return { "seek": seek, From 12049f2b902524ff7426ac31c63db66175163ae8 Mon Sep 17 00:00:00 2001 From: sealad886 <155285242+sealad886@users.noreply.github.com> Date: Sun, 7 Dec 2025 16:57:59 +0000 Subject: [PATCH 4/4] feat: Add MLX-native Silero VAD support and update CLI arguments for VAD and diarization Signed-off-by: sealad886 <155285242+sealad886@users.noreply.github.com> --- whisper/mlx_whisper/cli.py | 14 +- whisper/mlx_whisper/silero_vad.py | 392 ++++++++++++++++++++++++++++++ whisper/mlx_whisper/transcribe.py | 6 +- whisper/mlx_whisper/vad.py | 194 +++++++++++---- whisper/setup.py | 2 +- 5 files changed, 548 insertions(+), 60 deletions(-) create mode 100644 whisper/mlx_whisper/silero_vad.py diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index 961b25384..8a51083de 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -154,8 +154,7 @@ def str2bool(string): ) parser.add_argument( "--word-timestamps", - type=str2bool, - default=False, + action="store_true", help="Extract word-level timestamps and refine the results based on them", ) parser.add_argument( @@ -172,8 +171,7 @@ def str2bool(string): ) parser.add_argument( "--highlight-words", - type=str2bool, - default=False, + action="store_true", help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt", ) parser.add_argument( @@ -208,8 +206,7 @@ def str2bool(string): # VAD arguments parser.add_argument( "--vad-filter", - type=str2bool, - default=False, + action="store_true", help="Enable Silero VAD to filter silent audio before transcription", ) parser.add_argument( @@ -233,8 +230,7 @@ def str2bool(string): # Diarization arguments parser.add_argument( "--diarize", - type=str2bool, - default=False, + action="store_true", help="Enable speaker diarization (requires pyannote.audio)", ) parser.add_argument( @@ -367,7 +363,7 @@ def main(): vad_options=vad_options, **args, ) - writer(result, output_name, **writer_args) + writer(result, output_name, writer_args) except Exception as e: traceback.print_exc() print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") diff --git a/whisper/mlx_whisper/silero_vad.py b/whisper/mlx_whisper/silero_vad.py new file mode 100644 index 000000000..69d548d76 --- /dev/null +++ b/whisper/mlx_whisper/silero_vad.py @@ -0,0 +1,392 @@ +# Copyright © 2024 Apple Inc. + +""" +Native MLX implementation of Silero VAD (Voice Activity Detection). + +This module provides a pure MLX implementation of the Silero VAD model, +converted from the original ONNX weights. The model detects speech segments +in audio to improve transcription quality by filtering silent regions. + +Architecture: + 1. STFT: Converts audio to frequency domain via learned conv basis + 2. Encoder: 4-layer 1D CNN with ReLU activations + 3. Decoder: LSTM + 1D Conv for speech probability output +""" + +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from huggingface_hub import hf_hub_download + + +class SileroSTFT(nn.Module): + """STFT layer using learned convolutional basis. + + Converts raw audio to magnitude spectrogram using a learned basis + stored as convolution weights (not traditional FFT). + """ + + def __init__(self, n_fft: int = 256): + super().__init__() + self.n_fft = n_fft + self.hop_length = n_fft // 4 + + def __call__(self, x: mx.array) -> mx.array: + """Apply STFT. + + Args: + x: Audio tensor [batch, samples] + + Returns: + Magnitude spectrogram [batch, frames, n_fft//2 + 1] + """ + pad_amount = self.n_fft // 2 + x = mx.pad(x, [(0, 0), (pad_amount, pad_amount)]) + x = mx.expand_dims(x, axis=-1) + x = mx.conv1d(x, self.forward_basis, stride=self.hop_length) + + n_freq = self.n_fft // 2 + 1 + real = x[:, :, :n_freq] + imag = x[:, :, n_freq:] + magnitude = mx.sqrt(real**2 + imag**2 + 1e-9) + + return magnitude + + +class ConvBlock(nn.Module): + """1D Convolution block with ReLU activation.""" + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + + def __call__(self, x: mx.array) -> mx.array: + return nn.relu(self.conv(x)) + + +class SileroEncoder(nn.Module): + """4-layer CNN encoder for Silero VAD.""" + + def __init__(self, input_dim: int): + super().__init__() + self.blocks = [ + ConvBlock(input_dim, 128, 3), + ConvBlock(128, 64, 3), + ConvBlock(64, 64, 3), + ConvBlock(64, 128, 3), + ] + + def __call__(self, x: mx.array) -> mx.array: + for block in self.blocks: + x = block(x) + return x + + +class SileroLSTMDecoder(nn.Module): + """LSTM-based decoder for speech probability prediction. + + Uses nn.LSTM with proper state handling for streaming inference. + """ + + def __init__(self, input_size: int = 128, hidden_size: int = 128): + super().__init__() + self.hidden_size = hidden_size + self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size) + self.out_conv = nn.Conv1d(in_channels=hidden_size, out_channels=1, kernel_size=1) + + def __call__( + self, + x: mx.array, + h: Optional[mx.array] = None, + c: Optional[mx.array] = None, + ) -> Tuple[mx.array, mx.array, mx.array]: + """Decode features to speech probability. + + Args: + x: Encoded features [batch, frames, 128] + h: Hidden state [batch, hidden_size] or None + c: Cell state [batch, hidden_size] or None + + Returns: + Tuple of: + - Speech probability [batch, 1] + - New hidden state [batch, hidden_size] + - New cell state [batch, hidden_size] + """ + batch_size = x.shape[0] + + if h is None: + h = mx.zeros((batch_size, self.hidden_size)) + if c is None: + c = mx.zeros((batch_size, self.hidden_size)) + + # MLX LSTM returns (h_states, c_states) both of shape [batch, seq, hidden] + h_states, c_states = self.lstm(x, hidden=h, cell=c) + + # The last timestep gives us the new hidden and cell states + h_new = h_states[:, -1, :] + c_new = c_states[:, -1, :] + + # Apply ReLU to hidden states (which is also the output) + lstm_out = nn.relu(h_states) + + # Conv and sigmoid for final probability + out = self.out_conv(lstm_out) + out = mx.sigmoid(out) + + # Return [batch, 1] probability from last timestep + out = out[:, -1, :] + + return out, h_new, c_new + + +class SileroVADModel(nn.Module): + """Complete Silero VAD model in native MLX.""" + + def __init__(self, sample_rate: int = 16000): + super().__init__() + + if sample_rate == 16000: + n_fft = 256 + self.window_size = 512 + self.context_size = 64 + elif sample_rate == 8000: + n_fft = 128 + self.window_size = 256 + self.context_size = 32 + else: + raise ValueError(f"Unsupported sample rate: {sample_rate}") + + self.sample_rate = sample_rate + n_freq = n_fft // 2 + 1 + + self.stft = SileroSTFT(n_fft=n_fft) + self.encoder = SileroEncoder(input_dim=n_freq) + self.decoder = SileroLSTMDecoder(input_size=128, hidden_size=128) + + def __call__( + self, + audio: mx.array, + h: Optional[mx.array] = None, + c: Optional[mx.array] = None, + ) -> Tuple[mx.array, mx.array, mx.array]: + """Process audio chunk and return speech probability. + + Args: + audio: Audio chunk [batch, samples] + h: Hidden state from previous call + c: Cell state from previous call + + Returns: + Tuple of (probability, h_new, c_new) + """ + spec = self.stft(audio) + features = self.encoder(spec) + features = mx.mean(features, axis=1, keepdims=True) + prob, h_new, c_new = self.decoder(features, h, c) + + return prob, h_new, c_new + + def reset_state(self, batch_size: int = 1) -> Tuple[mx.array, mx.array]: + """Create fresh LSTM state.""" + h = mx.zeros((batch_size, self.decoder.hidden_size)) + c = mx.zeros((batch_size, self.decoder.hidden_size)) + return h, c + + +def _reorder_lstm_gates_onnx_to_mlx(weights: np.ndarray, hidden_size: int) -> np.ndarray: + """Reorder LSTM gate weights from ONNX to MLX format. + + ONNX gate order: [i, o, f, c] (input, output, forget, cell) + MLX gate order: [i, f, g, o] (input, forget, gate/cell, output) + + Reordering: i->i, o->o, f->f, c->g + ONNX indices: 0=i, 1=o, 2=f, 3=c + MLX indices: 0=i, 1=f, 2=g, 3=o + """ + # Split into 4 gates + i = weights[:hidden_size] # input gate + o = weights[hidden_size : 2 * hidden_size] # output gate + f = weights[2 * hidden_size : 3 * hidden_size] # forget gate + c = weights[3 * hidden_size : 4 * hidden_size] # cell gate + # Reorder to MLX format: [i, f, g, o] + return np.concatenate([i, f, c, o], axis=0) + + +def _convert_onnx_weights(sample_rate: int = 16000) -> dict: + """Extract weights from ONNX model for given sample rate. + + The ONNX model has nested If nodes: + - 16kHz: If_0 -> then_branch -> If_0 -> then_branch + - 8kHz: If_0 -> else_branch -> If_0 -> else_branch + + We recursively extract all weights with their full path prefixes. + """ + import onnx + from onnx import numpy_helper + + model_path = hf_hub_download( + repo_id="onnx-community/silero-vad", filename="onnx/model.onnx" + ) + model = onnx.load(model_path) + + branch_name = "then_branch" if sample_rate == 16000 else "else_branch" + weights = {} + + def extract_all_recursive(graph, prefix=""): + """Recursively extract all initializers from graph and subgraphs.""" + for init in graph.initializer: + full_name = prefix + init.name + arr = numpy_helper.to_array(init) + weights[full_name] = arr + + for node in graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + sub_prefix = f"{prefix}{node.name}__{attr.name}__" + extract_all_recursive(attr.g, sub_prefix) + elif attr.type == onnx.AttributeProto.GRAPHS: + for i, g in enumerate(attr.graphs): + sub_prefix = f"{prefix}{node.name}__{attr.name}_{i}__" + extract_all_recursive(g, sub_prefix) + + extract_all_recursive(model.graph) + + # Filter to only the weights for the target sample rate branch + # 16kHz uses "then_branch", 8kHz uses "else_branch" + filtered_weights = {} + target_prefix = f"If_0__{branch_name}__If_0_{branch_name}__Inline_0__" + + for key, value in weights.items(): + if target_prefix in key: + # Simplify the key by removing the prefix for easier matching + simple_key = key.replace(target_prefix, "") + filtered_weights[simple_key] = value + + return filtered_weights + + +def _convert_weights_to_mlx(onnx_weights: dict, sample_rate: int = 16000) -> dict: + """Convert ONNX weights to MLX format. + + The keys are already simplified (prefix removed by _convert_onnx_weights). + """ + mlx_weights = {} + hidden_size = 128 + + # STFT forward basis + stft_key = "stft.forward_basis_buffer" + if stft_key in onnx_weights: + # ONNX: [n_fft+2, 1, n_fft] -> MLX: [n_fft+2, n_fft, 1] + basis = onnx_weights[stft_key].transpose(0, 2, 1) + mlx_weights["stft.forward_basis"] = mx.array(basis) + + # Encoder conv layers + for i in range(4): + w_key = f"encoder.{i}.reparam_conv.weight" + b_key = f"encoder.{i}.reparam_conv.bias" + if w_key in onnx_weights: + # ONNX: [out, in, kernel] -> MLX: [out, kernel, in] + w = onnx_weights[w_key].transpose(0, 2, 1) + mlx_weights[f"encoder.blocks.{i}.conv.weight"] = mx.array(w) + if b_key in onnx_weights: + mlx_weights[f"encoder.blocks.{i}.conv.bias"] = mx.array(onnx_weights[b_key]) + + # LSTM weights - find by exact key match + for k, v in onnx_weights.items(): + if "/Unsqueeze_7_output_0" in k and v.shape == (1, 512, 128): + # W_ih: [1, 4*hidden, input] -> MLX Wx: [4*hidden, input] + w_ih = v.squeeze(0) # [512, 128] + w_ih = _reorder_lstm_gates_onnx_to_mlx(w_ih, hidden_size) + mlx_weights["decoder.lstm.Wx"] = mx.array(w_ih) + elif "/Unsqueeze_8_output_0" in k and v.shape == (1, 512, 128): + # W_hh: [1, 4*hidden, hidden] -> MLX Wh: [4*hidden, hidden] + w_hh = v.squeeze(0) # [512, 128] + w_hh = _reorder_lstm_gates_onnx_to_mlx(w_hh, hidden_size) + mlx_weights["decoder.lstm.Wh"] = mx.array(w_hh) + elif "/Unsqueeze_9_output_0" in k and v.shape == (1, 1024): + # Combined bias [1, 8*hidden] -> MLX bias [4*hidden] + bias = v.squeeze(0) # [1024] + # ONNX stores Wb and Rb separately: [Wb_i,o,f,g, Rb_i,o,f,g] + # MLX uses single bias = Wb + Rb + b_w = bias[:512] # Input bias + b_r = bias[512:] # Recurrent bias + combined = b_w + b_r + combined = _reorder_lstm_gates_onnx_to_mlx(combined, hidden_size) + mlx_weights["decoder.lstm.bias"] = mx.array(combined) + + # Output conv weights + out_w_key = "decoder.decoder.2.weight" + out_b_key = "decoder.decoder.2.bias" + if out_w_key in onnx_weights: + # Output conv: ONNX [1, 128, 1] -> MLX [1, 1, 128] + w = onnx_weights[out_w_key].transpose(0, 2, 1) + mlx_weights["decoder.out_conv.weight"] = mx.array(w) + if out_b_key in onnx_weights: + mlx_weights["decoder.out_conv.bias"] = mx.array(onnx_weights[out_b_key]) + + return mlx_weights + + +def load_vad_model( + sample_rate: int = 16000, + cache_dir: Optional[Path] = None, +) -> SileroVADModel: + """Load Silero VAD model with converted weights. + + Downloads ONNX model from HuggingFace and converts to MLX format. + Caches converted weights for faster subsequent loads. + """ + if cache_dir is None: + cache_dir = Path.home() / ".cache" / "mlx-whisper" / "vad" + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + + weights_file = cache_dir / f"silero_vad_{sample_rate}.npz" + + if weights_file.exists(): + weights = dict(mx.load(str(weights_file))) # type: ignore + else: + onnx_weights = _convert_onnx_weights(sample_rate) + weights = _convert_weights_to_mlx(onnx_weights, sample_rate) + mx.savez(str(weights_file), **weights) + + model = SileroVADModel(sample_rate=sample_rate) + + # Load STFT weights + if "stft.forward_basis" in weights: + model.stft.forward_basis = weights["stft.forward_basis"] + + # Load encoder weights + for i in range(4): + w_key = f"encoder.blocks.{i}.conv.weight" + b_key = f"encoder.blocks.{i}.conv.bias" + if w_key in weights: + model.encoder.blocks[i].conv.weight = weights[w_key] + if b_key in weights: + model.encoder.blocks[i].conv.bias = weights[b_key] + + # Load LSTM weights + if "decoder.lstm.Wx" in weights: + model.decoder.lstm.Wx = weights["decoder.lstm.Wx"] + if "decoder.lstm.Wh" in weights: + model.decoder.lstm.Wh = weights["decoder.lstm.Wh"] + if "decoder.lstm.bias" in weights: + model.decoder.lstm.bias = weights["decoder.lstm.bias"] + + # Load output conv weights + if "decoder.out_conv.weight" in weights: + model.decoder.out_conv.weight = weights["decoder.out_conv.weight"] + if "decoder.out_conv.bias" in weights: + model.decoder.out_conv.bias = weights["decoder.out_conv.bias"] + + mx.eval(model.parameters()) + return model diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index 14dbad063..bc1c09d60 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -244,14 +244,14 @@ def transcribe( content_frames = mel.shape[-2] - N_FRAMES content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + make_safe = lambda x: x if verbose: system_encoding = sys.getdefaultencoding() if system_encoding != "utf-8": make_safe = lambda x: x.encode(system_encoding, errors="replace").decode( system_encoding ) - else: - make_safe = lambda x: x + if decode_options.get("language", None) is None: if not model.is_multilingual: @@ -335,7 +335,7 @@ def decode_with_fallback(segment: mx.array) -> DecodingResult: if not needs_fallback: break - return decode_result + return decode_result # type: ignore clip_idx = 0 seek = seek_clips[clip_idx][0] diff --git a/whisper/mlx_whisper/vad.py b/whisper/mlx_whisper/vad.py index d7367df43..b2abea6f0 100644 --- a/whisper/mlx_whisper/vad.py +++ b/whisper/mlx_whisper/vad.py @@ -3,40 +3,28 @@ """ Voice Activity Detection (VAD) module for mlx-whisper. -Provides optional VAD preprocessing using Silero VAD to filter -silent audio regions before transcription, improving speed and +Provides VAD preprocessing using a native MLX implementation of Silero VAD +to filter silent audio regions before transcription, improving speed and reducing hallucinations on audio with significant silence. + +This is a pure MLX implementation - no PyTorch required. """ import bisect from dataclasses import dataclass from typing import Dict, List, Optional +import mlx.core as mx import numpy as np -# Graceful dependency handling -_TORCH_AVAILABLE = False -try: - import torch - - _TORCH_AVAILABLE = True -except ImportError: - pass - def is_available() -> bool: - """Check if VAD dependencies are available.""" - return _TORCH_AVAILABLE - - -class VadUnavailableError(ImportError): - """Raised when VAD is requested but dependencies are missing.""" + """Check if VAD dependencies are available. - def __init__(self): - super().__init__( - "VAD requires PyTorch. Install with: pip install torch\n" - "Or install mlx-whisper with VAD support: pip install mlx-whisper[vad]" - ) + Always returns True since MLX VAD has no external dependencies + beyond MLX itself (which is already required for mlx-whisper). + """ + return True @dataclass @@ -64,31 +52,26 @@ class VadOptions: class SileroVAD: - """Wrapper for Silero VAD model. + """Native MLX wrapper for Silero VAD model. Loads the model lazily on first use to avoid import overhead when VAD is not needed. """ - def __init__(self): + def __init__(self, sample_rate: int = 16000): + """Initialize VAD wrapper. + + Args: + sample_rate: Audio sample rate (16000 or 8000) + """ self._model = None - self._get_speech_timestamps = None + self._sample_rate = sample_rate def _load_model(self): - """Load Silero VAD model via torch.hub.""" - if not is_available(): - raise VadUnavailableError() - - import torch - - self._model, utils = torch.hub.load( - repo_or_dir="snakers4/silero-vad", - model="silero_vad", - force_reload=False, - onnx=False, - trust_repo=True, - ) - self._get_speech_timestamps = utils[0] + """Load Silero VAD model (MLX native implementation).""" + from .silero_vad import load_vad_model + + self._model = load_vad_model(sample_rate=self._sample_rate) def __call__( self, audio: np.ndarray, options: Optional[VadOptions] = None @@ -103,28 +86,145 @@ def __call__( List of dictionaries with 'start' and 'end' keys representing speech segment boundaries in samples. """ - import torch - if self._model is None: self._load_model() if options is None: options = VadOptions() - wav = torch.from_numpy(audio).float() - - speech_timestamps = self._get_speech_timestamps( - wav, + return get_speech_timestamps( + audio, self._model, threshold=options.threshold, min_speech_duration_ms=options.min_speech_duration_ms, max_speech_duration_s=options.max_speech_duration_s, min_silence_duration_ms=options.min_silence_duration_ms, speech_pad_ms=options.speech_pad_ms, - return_seconds=False, # Return in samples + sample_rate=self._sample_rate, ) - return speech_timestamps + +def get_speech_timestamps( + audio: np.ndarray, + model, + threshold: float = 0.5, + min_speech_duration_ms: int = 250, + max_speech_duration_s: float = float("inf"), + min_silence_duration_ms: int = 2000, + speech_pad_ms: int = 400, + sample_rate: int = 16000, +) -> List[Dict[str, int]]: + """Detect speech timestamps in audio using Silero VAD. + + Args: + audio: Audio waveform as numpy array + model: SileroVADModel instance + threshold: Speech detection threshold (0.0-1.0) + min_speech_duration_ms: Minimum speech segment duration in ms + max_speech_duration_s: Maximum speech segment duration in seconds + min_silence_duration_ms: Minimum silence duration to split segments + speech_pad_ms: Padding to add around speech segments in ms + sample_rate: Audio sample rate + + Returns: + List of dicts with 'start' and 'end' keys (sample indices) + """ + # Model parameters + window_size = model.window_size # 512 for 16kHz, 256 for 8kHz + + # Convert time parameters to samples + min_speech_samples = int(min_speech_duration_ms * sample_rate / 1000) + min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000) + speech_pad_samples = int(speech_pad_ms * sample_rate / 1000) + max_speech_samples = ( + int(max_speech_duration_s * sample_rate) + if max_speech_duration_s < float("inf") + else float("inf") + ) + + # Ensure audio is the right length (pad if needed) + audio_length = len(audio) + if audio_length % window_size != 0: + pad_length = window_size - (audio_length % window_size) + audio = np.pad(audio, (0, pad_length)) + + # Process audio in chunks + num_chunks = len(audio) // window_size + probs = [] + + h, c = model.reset_state(batch_size=1) + + for i in range(num_chunks): + chunk = audio[i * window_size : (i + 1) * window_size] + chunk_mx = mx.array(chunk.reshape(1, -1)) + + prob, h, c = model(chunk_mx, h, c) + mx.eval(prob, h, c) + + probs.append(float(prob[0, 0])) + + # Convert probabilities to speech segments + speeches = [] + current_speech = None + + for i, prob in enumerate(probs): + sample_pos = i * window_size + + if prob >= threshold: + if current_speech is None: + current_speech = {"start": sample_pos, "end": sample_pos + window_size} + else: + current_speech["end"] = sample_pos + window_size + else: + if current_speech is not None: + # Check if silence is long enough to end segment + silence_duration = sample_pos - current_speech["end"] + if silence_duration >= min_silence_samples: + speeches.append(current_speech) + current_speech = None + else: + # Continue current speech through short silence + current_speech["end"] = sample_pos + window_size + + # Don't forget the last segment + if current_speech is not None: + speeches.append(current_speech) + + # Filter by minimum duration + speeches = [s for s in speeches if s["end"] - s["start"] >= min_speech_samples] + + # Split segments that exceed max duration + if max_speech_samples < float("inf"): + split_speeches = [] + for speech in speeches: + duration = speech["end"] - speech["start"] + if duration <= max_speech_samples: + split_speeches.append(speech) + else: + # Split into smaller segments + start = speech["start"] + while start < speech["end"]: + end = min(start + max_speech_samples, speech["end"]) + split_speeches.append({"start": start, "end": end}) + start = end + speeches = split_speeches + + # Apply padding + for speech in speeches: + speech["start"] = max(0, speech["start"] - speech_pad_samples) + speech["end"] = min(audio_length, speech["end"] + speech_pad_samples) + + # Merge overlapping segments after padding + if speeches: + merged = [speeches[0]] + for speech in speeches[1:]: + if speech["start"] <= merged[-1]["end"]: + merged[-1]["end"] = max(merged[-1]["end"], speech["end"]) + else: + merged.append(speech) + speeches = merged + + return speeches def get_speech_chunks( diff --git a/whisper/setup.py b/whisper/setup.py index 4c614deec..9cbd0bc06 100644 --- a/whisper/setup.py +++ b/whisper/setup.py @@ -12,7 +12,7 @@ sys.path.append(str(package_dir)) -from _version import __version__ +from _version import __version__ # type: ignore setup( name="mlx-whisper",