diff --git a/whisper/README.md b/whisper/README.md
index cd3bc684a..d060bb255 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -31,9 +31,75 @@ mlx_whisper audio_file.mp3
 
 This will make a text file `audio_file.txt` with the results.
 
-Use `-f` to specify the output format and `--model` to specify the model. There
-are many other supported command line options. To see them all, run
-`mlx_whisper -h`.
+**Common Options:**
+
+```sh
+# Specify output format (txt, vtt, srt, tsv, json, all)
+mlx_whisper audio.mp3 -f json
+
+# Choose a different model
+mlx_whisper audio.mp3 --model mlx-community/whisper-large-v3-mlx
+
+# Enable detailed performance metrics
+mlx_whisper audio.mp3 --verbose True
+
+# Combine options
+mlx_whisper audio.mp3 --model mlx-community/whisper-base.en-mlx -f srt --verbose True
+```
+
+**Performance Benchmarking:**
+
+Use `--verbose True` to display detailed performance metrics:
+
+```sh
+mlx_whisper audio.mp3 --verbose True
+```
+
+This will show:
+
+- Model load time (ms)
+- Mel spectrogram computation time (ms)
+- Inference time (ms)
+- Total processing time (ms)
+- RTF (Real-Time Factor): values < 1.0 mean faster than real time (e.g. 691.70 ms / 20.03 s ≈ 0.035 below)
+- Token generation throughput (tokens/sec and steps/sec)
+- Per-segment performance breakdown
+
+Example output:
+
+```
+================================================================================
+BENCHMARK METRICS
+================================================================================
+Model load time: 386.68 ms
+Mel spectrogram time: 62.27 ms
+Inference time: 242.75 ms
+Total time: 691.70 ms
+Audio duration: 20.03 s
+RTF (Real-Time Factor): 0.035
+
+Total output tokens: 75
+Total inference steps (decoder forward passes): 77
+Average output tokens/sec: 308.95
+Average inference steps/sec: 317.19
+================================================================================
+```
+
+**Additional Options:**
+
+```sh
+# Word-level timestamps
+mlx_whisper audio.mp3 --word-timestamps True
+
+# Specify language (skip auto-detection)
+mlx_whisper audio.mp3 --language en
+
+# Translate to English
+mlx_whisper audio.mp3 --task translate
+
+# Custom output name
+mlx_whisper audio.mp3 --output-name my_transcript
+```
 
 You can also pipe the audio content of other programs via stdin:
 
@@ -44,6 +110,12 @@ some-process | mlx_whisper -
 The default output file name will be `content.*`. You can specify the name
 with the `--output-name` flag.
+
+To see all available options, run:
+
+```sh
+mlx_whisper -h
+```
 
 #### API
 
 Transcribe audio with:
diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py
index 814dc95ca..88ce7fc88 100644
--- a/whisper/mlx_whisper/decoding.py
+++ b/whisper/mlx_whisper/decoding.py
@@ -127,6 +127,7 @@ class DecodingResult:
     no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan
+    num_inference_steps: int = 0  # Total decoder forward passes
 
 
 class Inference:
@@ -572,6 +573,7 @@ def _detect_language(self, audio_features: mx.array, tokens: np.array):
     def _main_loop(self, audio_features: mx.array, tokens: mx.array):
         n_batch = tokens.shape[0]
         sum_logprobs = mx.zeros(n_batch)
+        inference_steps = 1  # Count the first step
 
         def _step(inputs, audio_features, tokens, sum_logprobs):
             pre_logits = self.inference.logits(inputs, audio_features)
@@ -608,13 +610,14 @@ def _step(inputs, audio_features, tokens, sum_logprobs):
                 inputs, audio_features, tokens, sum_logprobs
             )
             mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
+            inference_steps += 1  # Count each iteration
             if completed:
                 break
             tokens = next_tokens
             completed = next_completed
             sum_logprobs = next_sum_logprobs
 
-        return tokens, sum_logprobs, no_speech_probs
+        return tokens, sum_logprobs, no_speech_probs, inference_steps
 
     def run(self, mel: mx.array) -> List[DecodingResult]:
         self.inference.reset()
@@ -647,7 +650,9 @@ def run(self, mel: mx.array) -> List[DecodingResult]:
         tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens)))
 
         # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs, inference_steps = self._main_loop(
+            audio_features, tokens
+        )
 
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
@@ -699,6 +704,7 @@ def run(self, mel: mx.array) -> List[DecodingResult]:
                 no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
+                num_inference_steps=inference_steps,
             )
             for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                 *fields
diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py
index bced16a58..991e618f5 100644
--- a/whisper/mlx_whisper/transcribe.py
+++ b/whisper/mlx_whisper/transcribe.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import sys
+import time
 import warnings
 from typing import List, Optional, Tuple, Union
 
@@ -144,10 +145,17 @@
     """
 
     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
+
+    # Track model loading time
+    model_load_start = time.time()
     model = ModelHolder.get_model(path_or_hf_repo, dtype)
+    model_load_time = time.time() - model_load_start
 
+    # Track mel spectrogram computation time
+    mel_start = time.time()
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
+    mel_time = time.time() - mel_start
 
     content_frames = mel.shape[-2] - N_FRAMES
     content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
@@ -204,12 +212,23 @@
 
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")
 
+    # Initialize timing metrics
+    total_inference_time = 0.0
+    total_tokens_generated = 0
+    total_inference_steps = (
+        0  # Total decoder forward passes (comparable to whisper.cpp "runs")
+    )
+    segment_timings = []
+
     def decode_with_fallback(segment: mx.array) -> DecodingResult:
+        nonlocal total_inference_time, total_tokens_generated, total_inference_steps
         temperatures = (
             [temperature] if isinstance(temperature, (int, float)) else temperature
         )
         decode_result = None
+        segment_start_time = time.time()
+
         for t in temperatures:
             kwargs = {**decode_options}
             if t > 0:
@@ -242,6 +261,31 @@ def decode_with_fallback(segment: mx.array) -> DecodingResult:
             if not needs_fallback:
                 break
 
+        segment_inference_time = time.time() - segment_start_time
+        total_inference_time += segment_inference_time
+        num_output_tokens = len(decode_result.tokens)
+        num_inference_steps = decode_result.num_inference_steps
+        total_tokens_generated += num_output_tokens
+        total_inference_steps += num_inference_steps
+
+        segment_timings.append(
+            {
+                "time": segment_inference_time,
+                "output_tokens": num_output_tokens,
+                "inference_steps": num_inference_steps,
+                "output_tokens_per_sec": (
+                    num_output_tokens / segment_inference_time
+                    if segment_inference_time > 0
+                    else 0
+                ),
+                "inference_steps_per_sec": (
+                    num_inference_steps / segment_inference_time
+                    if segment_inference_time > 0
+                    else 0
+                ),
+            }
+        )
+
         return decode_result
 
     clip_idx = 0
@@ -536,6 +580,54 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)
 
+    # Print detailed inference metrics
+    total_time = model_load_time + mel_time + total_inference_time
+    rtf = (total_time / content_duration) if content_duration > 0 else 0
+
+    if verbose is not False:
+        print("\n" + "=" * 80)
+        print("BENCHMARK METRICS")
+        print("=" * 80)
+        print(f"Model load time: {model_load_time * 1000:.2f} ms")
+        print(f"Mel spectrogram time: {mel_time * 1000:.2f} ms")
+        print(f"Inference time: {total_inference_time * 1000:.2f} ms")
+        print(f"Total time: {total_time * 1000:.2f} ms")
+        print(f"Audio duration: {content_duration:.2f} s")
+        print(f"RTF (Real-Time Factor): {rtf:.3f}")
+        print(f"\nTotal output tokens: {total_tokens_generated}")
+        print(
+            f"Total inference steps (decoder forward passes): {total_inference_steps}"
+        )
+        if total_inference_time > 0:
+            print(
+                f"Average output tokens/sec: {total_tokens_generated / total_inference_time:.2f}"
+            )
+            print(
+                f"Average inference steps/sec: {total_inference_steps / total_inference_time:.2f}"
+            )
+        print(f"Number of segments: {len(segment_timings)}")
+
+        if verbose and len(segment_timings) > 0:
print("\nPer-segment details:") + print( + f"{'Seg#':<6} {'Out':<6} {'Steps':<7} {'Time(s)':<10} {'Out/s':<10} {'Steps/s':<10}" + ) + print("-" * 80) + for i, timing in enumerate(segment_timings): + print( + f"{i:<6} {timing['output_tokens']:<6} {timing['inference_steps']:<7} " + f"{timing['time']:<10.3f} {timing['output_tokens_per_sec']:<10.2f} " + f"{timing['inference_steps_per_sec']:<10.2f}" + ) + + print("\nNOTE:") + print("- RTF < 1.0 means faster than real-time") + print("- 'Output tokens': Final text tokens (excluding special tokens)") + print( + "- 'Inference steps': Total decoder forward passes (comparable to whisper.cpp 'runs')" + ) + print("=" * 80 + "\n") + return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments,