diff --git a/README.md b/README.md
index c47b9e7..2ddea46 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ LIRA is a **CLI-first, developer-friendly tool**: run and serve ASR models local
 2. **Activate your conda environment:**
 
    ```bash
-   conda activate ryzen-ai-1.5.0
+   conda activate ryzen-ai-1.6.0
    ```
 
 3. **Install LIRA in editable mode:**
diff --git a/config/vitisai_config_whisper_base_decoder.json b/config/vitisai_config_whisper_base_decoder.json
index b31441d..83216e3 100644
--- a/config/vitisai_config_whisper_base_decoder.json
+++ b/config/vitisai_config_whisper_base_decoder.json
@@ -8,15 +8,19 @@
       "name": "vaiml_partition",
       "plugin": "vaip-pass_vaiml_partition",
       "vaiml_config": {
-        "keep_outputs": false,
-        "device": "stx",
-        "optimize_level": 2,
-        "logging_level": "info",
-        "no_failsafe": true,
-        "experiment_features": [
-          "InsertGatherFront"
-        ]
+        "optimize_level": 3,
+        "aiecompiler_args": "--system-stack-size=512"
       }
     }
+  ],
+  "target": "VAIML",
+  "targets": [
+    {
+      "name": "VAIML",
+      "pass": [
+        "init",
+        "vaiml_partition"
+      ]
+    }
   ]
 }
\ No newline at end of file
diff --git a/config/vitisai_config_whisper_base_encoder.json b/config/vitisai_config_whisper_base_encoder.json
index 3fa5036..c13fad7 100644
--- a/config/vitisai_config_whisper_base_encoder.json
+++ b/config/vitisai_config_whisper_base_encoder.json
@@ -8,8 +8,20 @@
       "name": "vaiml_partition",
       "plugin": "vaip-pass_vaiml_partition",
       "vaiml_config": {
-        "optimize_level": 2
+        "optimize_level": 3,
+        "fe_experiment": "use-accurate-mode=LayerNorm2PassAdf",
+        "aiecompiler_args": "--system-stack-size=512"
       }
     }
+  ],
+  "target": "VAIML",
+  "targets": [
+    {
+      "name": "VAIML",
+      "pass": [
+        "init",
+        "vaiml_partition"
+      ]
+    }
   ]
 }
\ No newline at end of file
diff --git a/config/vitisai_config_zipformer_encoder.json b/config/vitisai_config_zipformer_encoder.json
index 68f5de3..c13fad7 100644
--- a/config/vitisai_config_zipformer_encoder.json
+++ b/config/vitisai_config_zipformer_encoder.json
@@ -8,12 +8,20 @@
       "name": "vaiml_partition",
       "plugin": "vaip-pass_vaiml_partition",
       "vaiml_config": {
-        "keep_outputs": false,
-        "optimize_level": 2,
-        "fe_experiment": "prefer-conv-to-gemm-conversion=1 enable-gather-elements-adf=1 enable-edge-reshape-unwrapping=0",
-        "preferred_data_storage": "unvectorized",
-        "threshold_gops_percent": 0
+        "optimize_level": 3,
+        "fe_experiment": "use-accurate-mode=LayerNorm2PassAdf",
+        "aiecompiler_args": "--system-stack-size=512"
       }
     }
+  ],
+  "target": "VAIML",
+  "targets": [
+    {
+      "name": "VAIML",
+      "pass": [
+        "init",
+        "vaiml_partition"
+      ]
+    }
   ]
 }
\ No newline at end of file
diff --git a/lira/models/whisper/export.py b/lira/models/whisper/export.py
index a27f1ba..4394f8f 100644
--- a/lira/models/whisper/export.py
+++ b/lira/models/whisper/export.py
@@ -1,19 +1,23 @@
-# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2025, Advanced Micro Devices, Inc.
 # SPDX-License-Identifier: MIT
 
-import numpy as np
-import subprocess
-import onnx
-import onnxruntime as ort
-import json
 import os
-import argparse
+import json
 import shutil
+import subprocess
+import argparse
+from pathlib import Path
+import onnx
+import onnxruntime as ort
+import numpy as np
 
 from onnx import shape_inference
 from lira.utils.config import get_exported_cache_dir
 
-# Global variable for static shape parameters
+
+# ----------------------------------------------------------------------
+# Static shape parameter definitions
+# ----------------------------------------------------------------------
 STATIC_PARAMS = {
     "batch_size": "1",
     "encoder_sequence_length / 2": "1500",
@@ -26,6 +30,9 @@
 }
 
 
+# ----------------------------------------------------------------------
+# Argument parsing
+# ----------------------------------------------------------------------
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Export and fix ONNX models for Whisper"
@@ -45,24 +52,25 @@ def parse_args():
         "--static", action="store_true", help="Use static shape parameters"
     )
     parser.add_argument(
-        "--force",
-        action="store_true",
-        help="Force export even if cache already exists",
+        "--force", action="store_true", help="Force re-export even if cache exists"
     )
+    args = parser.parse_args()
 
+    # Load parameter map
     if not args.static:
         with open("config/params.json", "r") as f:
             args.params_to_fix = json.load(f)
     else:
         args.params_to_fix = STATIC_PARAMS
-
     return args
 
 
+# ----------------------------------------------------------------------
+# Optimum CLI model export
+# ----------------------------------------------------------------------
 def export_with_optimum_cli(model_name, output_dir, opset):
-    """Run the optimum-cli export command to generate ONNX models, hiding stdout/stderr."""
-    # Ensure model_name is in the format 'openai/whisper-*'
+    """Export transformer to ONNX format using optimum-cli."""
     command = [
         "optimum-cli",
         "export",
@@ -80,17 +88,15 @@ def export_with_optimum_cli(model_name, output_dir, opset):
     print(f"Exported ONNX model to: {output_dir}")
 
 
+# ----------------------------------------------------------------------
+# Static shape patching with support for models > 2GB
+# ----------------------------------------------------------------------
 def force_set_static(model_path, output_path, mapping):
     """
-    mapping: dict of {dim_param_name: value}
-    Example:
-        {
-            "batch_size": 1,
-            "encoder_sequence_length / 2": 1500,
-            "decoder_sequence_length": 1,
-        }
+    Safely convert ONNX dynamic dims to static dims without breaking >2GB models.
""" - model = onnx.load(model_path) + print(f"Loading large ONNX model from: {model_path}") + model = onnx.load_model(model_path, load_external_data=True) def patch_value_info(value_infos): for vi in value_infos: @@ -98,64 +104,77 @@ def patch_value_info(value_infos): continue shape = vi.type.tensor_type.shape for dim in shape.dim: - if dim.dim_param in mapping: + if dim.dim_param and dim.dim_param in mapping: dim.dim_value = int(mapping[dim.dim_param]) + dim.ClearField("dim_param") patch_value_info(model.graph.input) patch_value_info(model.graph.output) patch_value_info(model.graph.value_info) - onnx.save(model, output_path) - print(f"Forced static shapes written to {output_path}") + # Save with external data to handle >2GB tensors + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + datafile_name = output_path.with_suffix(".onnx_data").name + + onnx.save_model( + model, + str(output_path), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=datafile_name, + size_threshold=512 * 1024 * 1024, + ) + + # ✅ Use model path for validation instead of in-memory object + try: + onnx.checker.check_model(str(output_path)) + print(f"Model integrity verified for {output_path}") + except Exception as e: + print(f"Warning: ONNX checker skipped in-memory validation: {e}") + + print(f"Saved static ONNX model to {output_path}") + print(f" External tensor data stored at {datafile_name}") +# ---------------------------------------------------------------------- +# Whisper model export workflow +# ---------------------------------------------------------------------- def export_whisper_model( model_name, output_dir=None, opset=17, static=False, force=False ): - """Exports the Whisper model and fixes static shapes if required.""" - # Set default output directory to project cache if not specified + """Exports and statically fixes ONNX models for Whisper pipeline.""" if output_dir is None: output_dir = get_exported_cache_dir() / model_name - output_dir.mkdir(parents=True, exist_ok=True) - # Check if cache already exists - if os.path.exists(output_dir) and not force: + output_dir = Path(output_dir) + if output_dir.exists() and not force: print(f"Cache already exists at {output_dir}. 
         return
 
-    # Determine parameters to fix
-    if not static:
-        with open("config/params.json", "r") as f:
-            params_to_fix = json.load(f)
-    else:
-        params_to_fix = STATIC_PARAMS
-
-    # Step 1: Export ONNX model using optimum-cli
-    export_with_optimum_cli(model_name, output_dir, opset)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
-    # Step 2: Fix static shapes for encoder and decoder models
-    encoder_model_path = os.path.join(output_dir, "encoder_model.onnx")
-    decoder_model_path = os.path.join(output_dir, "decoder_model.onnx")
-    decoder_init_model_path = os.path.join(output_dir, "decoder_init_model.onnx")
-    if os.path.exists(encoder_model_path):
-        force_set_static(encoder_model_path, encoder_model_path, params_to_fix)
+    # Step 1: Export ONNX model using optimum-cli
+    export_with_optimum_cli(model_name, str(output_dir), opset)
 
-    if os.path.exists(decoder_model_path):
+    # Determine parameters to fix
+    if not static:
+        with open("config/params.json", "r") as f:
+            params_to_fix = json.load(f)
+    else:
+        params_to_fix = STATIC_PARAMS
+
+    # Step 2: Fix static shapes for encoder and decoder models
+    encoder_model = output_dir / "encoder_model.onnx"
+    decoder_model = output_dir / "decoder_model.onnx"
+    decoder_init_model = output_dir / "decoder_init_model.onnx"
 
-        # Create a copy of decoder_model.onnx before setting it to static
-        shutil.copy(decoder_model_path, decoder_init_model_path)
-
-        # Set the copied model to static using STATIC_PARAMS_KV
+    if encoder_model.exists():
+        force_set_static(encoder_model, encoder_model, params_to_fix)
+
+    if decoder_model.exists():
+        # Create a copy of decoder_model.onnx before setting it to static
+        shutil.copy(decoder_model, decoder_init_model)
+        # Set the copied model to static using STATIC_PARAMS_KV
+        force_set_static(decoder_init_model, decoder_init_model, STATIC_PARAMS_KV)
         force_set_static(
-            decoder_init_model_path, decoder_init_model_path, STATIC_PARAMS_KV
+            decoder_model, decoder_model, STATIC_PARAMS
         )
-        static_decoder_model_path = os.path.join(output_dir, "decoder_model.onnx")
-        force_set_static(decoder_model_path, static_decoder_model_path, STATIC_PARAMS)
 
     print("Model export and static shape conversion completed successfully.")
 
 
+# ----------------------------------------------------------------------
+# Main entry point
+# ----------------------------------------------------------------------
 if __name__ == "__main__":
     args = parse_args()
     export_whisper_model(
diff --git a/lira/models/whisper/transcribe.py b/lira/models/whisper/transcribe.py
index d6745aa..1e3849b 100644
--- a/lira/models/whisper/transcribe.py
+++ b/lira/models/whisper/transcribe.py
@@ -6,6 +6,7 @@
 from transformers import WhisperFeatureExtractor, WhisperTokenizer
 import time
 import json
+import os
 import torchaudio
 from pathlib import Path
 from jiwer import wer, cer
@@ -318,10 +319,14 @@ def run(args):
 
     # Load providers based on device
     whisper = WhisperONNX(
-        encoder_path=f"{args.model}/encoder_model.onnx",
-        decoder_path=f"{args.model}/decoder_model.onnx",
-        decoder_init_path=f"{args.model}/decoder_init_model.onnx",
-        decoder_past_path=f"{args.model}/decoder_with_past_model.onnx",
+        encoder_path=(Path(args.model) / "encoder_model.onnx").resolve().as_posix(),
+        decoder_path=(Path(args.model) / "decoder_model.onnx").resolve().as_posix(),
+        decoder_init_path=(Path(args.model) / "decoder_init_model.onnx")
+        .resolve()
+        .as_posix(),
+        decoder_past_path=(Path(args.model) / "decoder_with_past_model.onnx")
+        .resolve()
+        .as_posix(),
         encoder_provider=get_provider(
             args.device, args.model_type, "encoder", cache_dir=args.cache
         ),
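Reviewer note: the `dim_param` patching that the reworked `force_set_static` performs is easiest to see on a toy graph. The sketch below is illustrative only — the tensor and dimension names are made up, not taken from the Whisper export — and assumes the `onnx` package is installed:

```python
import onnx
from onnx import helper, TensorProto

# Toy model with a symbolic "batch_size" dimension on its input and output.
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["batch_size", 80, 3000])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["batch_size", 80, 3000])
graph = helper.make_graph([helper.make_node("Identity", ["x"], ["y"])], "demo", [inp], [out])
model = helper.make_model(graph)

# Same idea as force_set_static: replace matching dim_param entries with fixed values.
mapping = {"batch_size": 1}
for vi in list(model.graph.input) + list(model.graph.output):
    for dim in vi.type.tensor_type.shape.dim:
        if dim.dim_param and dim.dim_param in mapping:
            dim.dim_value = int(mapping[dim.dim_param])
            dim.ClearField("dim_param")

onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))  # x and y now show as 1x80x3000
```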
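For context on the external-data handling above: ONNX models larger than 2 GB cannot live in a single protobuf, so the weights are written to a side file and the checker has to be pointed at the on-disk path rather than the in-memory proto. A minimal sketch of that round trip, with placeholder file names rather than the actual exported model paths:

```python
import onnx

# Load, resolving any existing external-data file stored next to the model.
model = onnx.load_model("encoder_model.onnx", load_external_data=True)

# Save with weights split into one side file; tensors above the size
# threshold are moved out of the main protobuf.
onnx.save_model(
    model,
    "encoder_model_static.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="encoder_model_static.onnx_data",
    size_threshold=512 * 1024 * 1024,
)

# Validate from the path so the checker can find the external data file.
onnx.checker.check_model("encoder_model_static.onnx")
```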