From 0764a0ee844135c64f82590257f5f00be2342715 Mon Sep 17 00:00:00 2001
From: Anthony DePasquale
Date: Mon, 15 Dec 2025 19:44:20 +0100
Subject: [PATCH] Use optimal alignment heads for Whisper

---
 whisper/convert.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/whisper/convert.py b/whisper/convert.py
index 9cc8b861b..0d7c82762 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -59,6 +59,19 @@
 }
 
 
+def _get_model_variant(name_or_path: str) -> str | None:
+    """Extract model variant for alignment heads lookup."""
+    if name_or_path in _ALIGNMENT_HEADS:
+        return name_or_path
+
+    # Extract from repo name like "openai/whisper-large-v3"
+    name = name_or_path.split("/")[-1]
+    if name.startswith("whisper-"):
+        return name[8:]  # Remove "whisper-" prefix
+
+    return None
+
+
 def _download(url: str, root: str) -> str:
     os.makedirs(root, exist_ok=True)
 
@@ -156,10 +169,11 @@ def load_torch_weights_and_config(
     if download_root is None:
         download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
 
-    # todo: accept alignment_heads of local Pytorch checkpoint
-    alignment_heads = None
+    # Look up alignment heads using normalized variant name
+    variant = _get_model_variant(name_or_path)
+    alignment_heads = _ALIGNMENT_HEADS.get(variant) if variant else None
+
     if name_or_path in _MODELS:
-        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        name_or_path = _download(_MODELS[name_or_path], download_root)
     elif not Path(name_or_path).exists():
         # Try downloading from HF
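
Note: below is a quick standalone sketch of how the new helper is expected to
normalize names before the alignment-heads lookup. The table here is a
stand-in with placeholder values; the real _ALIGNMENT_HEADS in
whisper/convert.py maps variant names to packed alignment-head data, and the
repo ids and paths used in the asserts are illustrative.

    # Stand-in table for illustration only; values are placeholders.
    _ALIGNMENT_HEADS = {"tiny": b"...", "large-v3": b"..."}

    def _get_model_variant(name_or_path: str) -> str | None:
        """Extract model variant for alignment heads lookup."""
        if name_or_path in _ALIGNMENT_HEADS:
            return name_or_path
        name = name_or_path.split("/")[-1]
        if name.startswith("whisper-"):
            return name[8:]  # Remove "whisper-" prefix
        return None

    # Bare variant names pass through unchanged.
    assert _get_model_variant("large-v3") == "large-v3"
    # Repo ids like "openai/whisper-large-v3" are reduced to the variant suffix.
    assert _get_model_variant("openai/whisper-large-v3") == "large-v3"
    # Anything else (e.g. a local checkpoint path) returns None, so
    # alignment_heads stays None as before.
    assert _get_model_variant("/path/to/checkpoint") is None

With this, repo ids that follow the "whisper-<variant>" naming pick up the
same alignment heads as the corresponding built-in checkpoints, while
unrecognized names fall back to alignment_heads = None as before.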