78 changes: 75 additions & 3 deletions whisper/README.md
@@ -31,9 +31,75 @@ mlx_whisper audio_file.mp3

This will make a text file `audio_file.txt` with the results.

Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run
`mlx_whisper -h`.
**Common Options:**

```sh
# Specify output format (txt, vtt, srt, tsv, json, all)
mlx_whisper audio.mp3 -f json

# Choose a different model
mlx_whisper audio.mp3 --model mlx-community/whisper-large-v3-mlx

# Enable detailed performance metrics
mlx_whisper audio.mp3 --verbose True

# Combine options
mlx_whisper audio.mp3 --model mlx-community/whisper-base.en-mlx -f srt --verbose True
```

**Performance Benchmarking:**

Use `--verbose True` to display detailed performance metrics:

```sh
mlx_whisper audio.mp3 --verbose True
```

This will show:

- Model load time (ms)
- Mel spectrogram computation time (ms)
- Inference time (ms)
- Total processing time (ms)
- RTF (Real-Time Factor): values < 1.0 indicate faster-than-real-time processing
- Token generation throughput (tokens/sec and steps/sec)
- Per-segment performance breakdown

Example output:

```
================================================================================
BENCHMARK METRICS
================================================================================
Model load time: 386.68 ms
Mel spectrogram time: 62.27 ms
Inference time: 242.75 ms
Total time: 691.70 ms
Audio duration: 20.03 s
RTF (Real-Time Factor): 0.035

Total output tokens: 75
Total inference steps (decoder forward passes): 77
Average output tokens/sec: 308.95
Average inference steps/sec: 317.19
================================================================================
```
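
The summary figures are related by simple arithmetic. As a quick sanity check, this snippet only re-derives the numbers above (it is not part of the implementation): total time is the sum of the three stages, RTF divides it by the audio duration, and throughput is measured against inference time alone.

```python
# Total time is the sum of the three stages, in milliseconds.
total_time_ms = 386.68 + 62.27 + 242.75   # = 691.70 ms

# RTF = total processing time / audio duration.
rtf = (total_time_ms / 1000) / 20.03      # ≈ 0.035

# Throughput is measured against inference time only.
tokens_per_sec = 75 / (242.75 / 1000)     # ≈ 308.95
steps_per_sec = 77 / (242.75 / 1000)      # ≈ 317.19
```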

**Additional Options:**

```sh
# Word-level timestamps
mlx_whisper audio.mp3 --word-timestamps True

# Specify language (skip auto-detection)
mlx_whisper audio.mp3 --language en

# Translate to English
mlx_whisper audio.mp3 --task translate

# Custom output name
mlx_whisper audio.mp3 --output-name my_transcript
```
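
The same options are available from Python. A minimal sketch, assuming the keyword names accepted by `mlx_whisper.transcribe` (`path_or_hf_repo` selects the model; `language` is forwarded as a decode option):

```python
import mlx_whisper

result = mlx_whisper.transcribe(
    "audio.mp3",
    path_or_hf_repo="mlx-community/whisper-large-v3-mlx",  # --model
    word_timestamps=True,                                  # --word-timestamps True
    language="en",                                         # --language en
)
print(result["text"])
```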

You can also pipe the audio content of other programs via stdin:

@@ -44,6 +110,12 @@ some-process | mlx_whisper -
The default output file name will be `content.*`. You can specify the name with
the `--output-name` flag.

To see all available options, run:

```sh
mlx_whisper -h
```

#### API

Transcribe audio with:
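
A minimal sketch of the basic call, assuming the package-level `transcribe` helper:

```python
import mlx_whisper

result = mlx_whisper.transcribe("audio_file.mp3")
print(result["text"])
```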
10 changes: 8 additions & 2 deletions whisper/mlx_whisper/decoding.py
@@ -127,6 +127,7 @@ class DecodingResult:
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
num_inference_steps: int = 0 # Total decoder forward passes


class Inference:
@@ -572,6 +573,7 @@ def _detect_language(self, audio_features: mx.array, tokens: np.array):
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
n_batch = tokens.shape[0]
sum_logprobs = mx.zeros(n_batch)
inference_steps = 1 # Count the first step

def _step(inputs, audio_features, tokens, sum_logprobs):
pre_logits = self.inference.logits(inputs, audio_features)
@@ -608,13 +610,14 @@ def _step(inputs, audio_features, tokens, sum_logprobs):
inputs, audio_features, tokens, sum_logprobs
)
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
inference_steps += 1 # Count each iteration
if completed:
break
tokens = next_tokens
completed = next_completed
sum_logprobs = next_sum_logprobs

return tokens, sum_logprobs, no_speech_probs
return tokens, sum_logprobs, no_speech_probs, inference_steps

def run(self, mel: mx.array) -> List[DecodingResult]:
self.inference.reset()
@@ -647,7 +650,9 @@ def run(self, mel: mx.array) -> List[DecodingResult]:
tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens)))

# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
tokens, sum_logprobs, no_speech_probs, inference_steps = self._main_loop(
audio_features, tokens
)

# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
@@ -699,6 +704,7 @@ def run(self, mel: mx.array) -> List[DecodingResult]:
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
num_inference_steps=inference_steps,
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
92 changes: 92 additions & 0 deletions whisper/mlx_whisper/transcribe.py
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

import sys
import time
import warnings
from typing import List, Optional, Tuple, Union

@@ -144,10 +145,17 @@ def transcribe(
"""

dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32

# Track model loading time
model_load_start = time.time()
model = ModelHolder.get_model(path_or_hf_repo, dtype)
model_load_time = time.time() - model_load_start

# Track mel spectrogram computation time
mel_start = time.time()
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
mel_time = time.time() - mel_start
content_frames = mel.shape[-2] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

@@ -204,12 +212,23 @@
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")

# Initialize timing metrics
total_inference_time = 0.0
total_tokens_generated = 0
# Total decoder forward passes (comparable to whisper.cpp "runs")
total_inference_steps = 0
segment_timings = []

def decode_with_fallback(segment: mx.array) -> DecodingResult:
nonlocal total_inference_time, total_tokens_generated, total_inference_steps
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None

segment_start_time = time.time()

for t in temperatures:
kwargs = {**decode_options}
if t > 0:
@@ -242,6 +261,31 @@ def decode_with_fallback(segment: mx.array) -> DecodingResult:
if not needs_fallback:
break

segment_inference_time = time.time() - segment_start_time
total_inference_time += segment_inference_time
num_output_tokens = len(decode_result.tokens)
num_inference_steps = decode_result.num_inference_steps
total_tokens_generated += num_output_tokens
total_inference_steps += num_inference_steps

segment_timings.append(
{
"time": segment_inference_time,
"output_tokens": num_output_tokens,
"inference_steps": num_inference_steps,
"output_tokens_per_sec": (
num_output_tokens / segment_inference_time
if segment_inference_time > 0
else 0
),
"inference_steps_per_sec": (
num_inference_steps / segment_inference_time
if segment_inference_time > 0
else 0
),
}
)

return decode_result

clip_idx = 0
@@ -536,6 +580,54 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)

# Print detailed inference metrics
total_time = model_load_time + mel_time + total_inference_time
rtf = (total_time / content_duration) if content_duration > 0 else 0

if verbose is not False:
print("\n" + "=" * 80)
print("BENCHMARK METRICS")
print("=" * 80)
print(f"Model load time: {model_load_time * 1000:.2f} ms")
print(f"Mel spectrogram time: {mel_time * 1000:.2f} ms")
print(f"Inference time: {total_inference_time * 1000:.2f} ms")
print(f"Total time: {total_time * 1000:.2f} ms")
print(f"Audio duration: {content_duration:.2f} s")
print(f"RTF (Real-Time Factor): {rtf:.3f}")
print(f"\nTotal output tokens: {total_tokens_generated}")
print(
f"Total inference steps (decoder forward passes): {total_inference_steps}"
)
if total_inference_time > 0:
print(
f"Average output tokens/sec: {total_tokens_generated / total_inference_time:.2f}"
)
print(
f"Average inference steps/sec: {total_inference_steps / total_inference_time:.2f}"
)
print(f"Number of segments: {len(segment_timings)}")

if verbose and len(segment_timings) > 0:
print("\nPer-segment details:")
print(
f"{'Seg#':<6} {'Out':<6} {'Steps':<7} {'Time(s)':<10} {'Out/s':<10} {'Steps/s':<10}"
)
print("-" * 80)
for i, timing in enumerate(segment_timings):
print(
f"{i:<6} {timing['output_tokens']:<6} {timing['inference_steps']:<7} "
f"{timing['time']:<10.3f} {timing['output_tokens_per_sec']:<10.2f} "
f"{timing['inference_steps_per_sec']:<10.2f}"
)

print("\nNOTE:")
print("- RTF < 1.0 means faster than real-time")
print("- 'Output tokens': Final text tokens (excluding special tokens)")
print(
"- 'Inference steps': Total decoder forward passes (comparable to whisper.cpp 'runs')"
)
print("=" * 80 + "\n")

return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
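
To reproduce the benchmark output through the Python API rather than the CLI, `verbose` can be passed straight to `transcribe` (a sketch; the flag values mirror the `verbose is not False` and `if verbose` checks in the code above):

```python
import mlx_whisper

# verbose=True prints the BENCHMARK METRICS block plus the per-segment table;
# the default (None) prints the summary only, and verbose=False suppresses it.
result = mlx_whisper.transcribe("audio.mp3", verbose=True)
```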