Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions whisper/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Test artifacts
output/

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
.env
.venv
env/
venv/

# IDE
.idea/
.vscode/
*.swp
*.swo

# pytest
.pytest_cache/
.coverage
htmlcov/
97 changes: 97 additions & 0 deletions whisper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,103 @@ To see more transcription options use:
>>> help(mlx_whisper.transcribe)
```

### Beam Search Decoding

By default, mlx-whisper uses greedy decoding. Enable beam search for potentially
more accurate transcriptions at the cost of speed:

> [!NOTE]
> The CLI help for `--beam-size` states the option is currently not implemented
> and will be ignored. Confirm the implementation status before relying on beam
> search; the examples below show the intended interface.

```bash
# Enable beam search with beam size 5
mlx_whisper audio.mp3 --beam-size 5

# Adjust patience for earlier/later stopping (default: 1.0)
mlx_whisper audio.mp3 --beam-size 5 --patience 1.5
```

In Python:

```python
result = mlx_whisper.transcribe(
"audio.mp3",
beam_size=5,
patience=1.0
)
```

The `patience` parameter controls early stopping: decoding stops when
`round(beam_size * patience)` finished sequences have been collected.
Higher patience values explore more candidates before stopping.

### Voice Activity Detection (VAD)

Enable Silero VAD to filter silent audio regions before transcription. This can
significantly speed up transcription for audio with long silent periods:

```bash
# Enable VAD
mlx_whisper audio.mp3 --vad-filter

# Customize VAD settings
mlx_whisper audio.mp3 --vad-filter --vad-threshold 0.6 --vad-min-silence-ms 1000
```

In Python:

```python
from mlx_whisper import transcribe
from mlx_whisper.vad import VadOptions

result = transcribe("audio.mp3", vad_filter=True)

# With custom options
vad_opts = VadOptions(threshold=0.6, min_silence_duration_ms=1000)
result = transcribe("audio.mp3", vad_filter=True, vad_options=vad_opts)
```

**Requirements**: `pip install torch`

### Speaker Diarization

Identify who is speaking when with pyannote.audio. Diarization adds speaker
labels to transcription segments:

```bash
# Enable diarization (requires HuggingFace token)
export HF_TOKEN=your_token
mlx_whisper audio.mp3 --diarize --word-timestamps

# Specify speaker count
mlx_whisper audio.mp3 --diarize --min-speakers 2 --max-speakers 4

# Output diarization in RTTM format
mlx_whisper audio.mp3 --diarize -f rttm
```

In Python:

```python
from mlx_whisper import transcribe_with_diarization

result = transcribe_with_diarization(
"audio.mp3",
hf_token="your_token",
word_timestamps=True
)

# Access speaker info
for segment in result["segments"]:
speaker = segment.get("speaker", "Unknown")
print(f"{speaker}: {segment['text']}")

# List of speakers
print(result["speakers"]) # ['SPEAKER_00', 'SPEAKER_01', ...]
```

**Requirements**:
- `pip install pyannote.audio pandas`
- Accept model terms at https://huggingface.co/pyannote/speaker-diarization-3.1
- Set `HF_TOKEN` environment variable or pass `--hf-token`

### Converting models

> [!TIP]
Expand Down
13 changes: 12 additions & 1 deletion whisper/mlx_whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,15 @@

from . import audio, decoding, load_models
from ._version import __version__
from .transcribe import transcribe
from .transcribe import transcribe, transcribe_with_diarization

# Optional modules (may not be available if dependencies are missing or incompatible)
try:
from . import vad
except (ImportError, AttributeError):
vad = None

try:
from . import diarize
except (ImportError, AttributeError):
diarize = None
2 changes: 1 addition & 1 deletion whisper/mlx_whisper/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.4.3"
__version__ = "0.5.0"
139 changes: 128 additions & 11 deletions whisper/mlx_whisper/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def str2bool(string):
"-f",
type=str,
default="txt",
choices=["txt", "vtt", "srt", "tsv", "json", "all"],
choices=["txt", "vtt", "srt", "tsv", "json", "rttm", "all"],
help="Format of the output file",
)
parser.add_argument(
Expand Down Expand Up @@ -92,6 +92,12 @@ def str2bool(string):
default=5,
help="Number of candidates when sampling with non-zero temperature",
)
parser.add_argument(
"--beam-size",
type=optional_int,
default=None,
help="Beam size for beam search (currently not implemented; option will be ignored)",
)
parser.add_argument(
"--patience",
type=float,
Expand Down Expand Up @@ -148,8 +154,7 @@ def str2bool(string):
)
parser.add_argument(
"--word-timestamps",
type=str2bool,
default=False,
action="store_true",
help="Extract word-level timestamps and refine the results based on them",
)
parser.add_argument(
Expand All @@ -166,8 +171,7 @@ def str2bool(string):
)
parser.add_argument(
"--highlight-words",
type=str2bool,
default=False,
action="store_true",
help="(requires --word-timestamps) underline each word as it is spoken in srt and vtt",
)
parser.add_argument(
Expand Down Expand Up @@ -199,6 +203,67 @@ def str2bool(string):
default="0",
help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
)
# VAD arguments
parser.add_argument(
"--vad-filter",
action="store_true",
help="Enable Silero VAD to filter silent audio before transcription",
)
parser.add_argument(
"--vad-threshold",
type=float,
default=0.5,
help="VAD speech detection threshold (0.0-1.0)",
)
parser.add_argument(
"--vad-min-silence-ms",
type=int,
default=2000,
help="Minimum silence duration to split speech segments (ms)",
)
parser.add_argument(
"--vad-speech-pad-ms",
type=int,
default=400,
help="Padding added around speech segments (ms)",
)
# Diarization arguments
parser.add_argument(
"--diarize",
action="store_true",
help="Enable speaker diarization (requires pyannote.audio)",
)
parser.add_argument(
"--hf-token",
type=str,
default=None,
help="HuggingFace token for pyannote models (or set HF_TOKEN env var)",
)
parser.add_argument(
"--diarize-model",
type=str,
default="pyannote/speaker-diarization-3.1",
help="Diarization model to use",
)
parser.add_argument(
"--min-speakers",
type=optional_int,
default=None,
help="Minimum number of speakers for diarization",
)
parser.add_argument(
"--max-speakers",
type=optional_int,
default=None,
help="Maximum number of speakers for diarization",
)
parser.add_argument(
"--diarize-device",
type=str,
default="cpu",
choices=["cpu", "cuda", "mps"],
help="Device for diarization model",
)
return parser


Expand Down Expand Up @@ -232,6 +297,40 @@ def main():
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")

# Extract VAD options
vad_filter = args.pop("vad_filter")
vad_threshold = args.pop("vad_threshold")
vad_min_silence_ms = args.pop("vad_min_silence_ms")
vad_speech_pad_ms = args.pop("vad_speech_pad_ms")

vad_options = None
if vad_filter:
from .vad import VadOptions

vad_options = VadOptions(
threshold=vad_threshold,
min_silence_duration_ms=vad_min_silence_ms,
speech_pad_ms=vad_speech_pad_ms,
)
elif any(
[vad_threshold != 0.5, vad_min_silence_ms != 2000, vad_speech_pad_ms != 400]
):
warnings.warn("VAD options have no effect without --vad-filter")

# Extract diarization options
diarize = args.pop("diarize")
hf_token = args.pop("hf_token") or os.environ.get("HF_TOKEN")
diarize_model = args.pop("diarize_model")
min_speakers = args.pop("min_speakers")
max_speakers = args.pop("max_speakers")
diarize_device = args.pop("diarize_device")

if diarize and not hf_token:
warnings.warn(
"Diarization requires a HuggingFace token. "
"Set --hf-token or HF_TOKEN environment variable."
)

for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
Expand All @@ -241,12 +340,30 @@ def main():
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try:
result = transcribe(
audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
writer(result, output_name, **writer_args)
if diarize:
from .transcribe import transcribe_with_diarization

result = transcribe_with_diarization(
audio_obj,
path_or_hf_repo=path_or_hf_repo,
hf_token=hf_token,
diarize_model=diarize_model,
min_speakers=min_speakers,
max_speakers=max_speakers,
device=diarize_device,
vad_filter=vad_filter,
vad_options=vad_options,
**args,
)
else:
result = transcribe(
audio_obj,
path_or_hf_repo=path_or_hf_repo,
vad_filter=vad_filter,
vad_options=vad_options,
**args,
)
writer(result, output_name, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
Expand Down
Loading