From d8177b85f3d4b4b438c544f24b1128ef24abea52 Mon Sep 17 00:00:00 2001
From: Iswarya Alex
Date: Mon, 20 Oct 2025 21:48:14 -0700
Subject: [PATCH 1/5] Support RAI 1.6

---
 .../vitisai_config_whisper_base_decoder.json |  20 ++-
 .../vitisai_config_whisper_base_encoder.json |  14 +-
 config/vitisai_config_zipformer_encoder.json |  18 +-
 lira/models/whisper/export.py                | 159 +++++++++---------
 lira/models/whisper/transcribe.py            |  10 +-
 5 files changed, 127 insertions(+), 94 deletions(-)

diff --git a/config/vitisai_config_whisper_base_decoder.json b/config/vitisai_config_whisper_base_decoder.json
index b31441d..83216e3 100644
--- a/config/vitisai_config_whisper_base_decoder.json
+++ b/config/vitisai_config_whisper_base_decoder.json
@@ -8,15 +8,19 @@
       "name": "vaiml_partition",
       "plugin": "vaip-pass_vaiml_partition",
       "vaiml_config": {
-        "keep_outputs": false,
-        "device": "stx",
-        "optimize_level": 2,
-        "logging_level": "info",
-        "no_failsafe": true,
-        "experiment_features": [
-          "InsertGatherFront"
-        ]
+        "optimize_level": 3,
+        "aiecompiler_args": "--system-stack-size=512"
       }
     }
+  ],
+  "target": "VAIML",
+  "targets": [
+    {
+      "name": "VAIML",
+      "pass": [
+        "init",
+        "vaiml_partition"
+      ]
+    }
   ]
 }
\ No newline at end of file

diff --git a/config/vitisai_config_whisper_base_encoder.json b/config/vitisai_config_whisper_base_encoder.json
index 3fa5036..c13fad7 100644
--- a/config/vitisai_config_whisper_base_encoder.json
+++ b/config/vitisai_config_whisper_base_encoder.json
@@ -8,8 +8,20 @@
       "name": "vaiml_partition",
       "plugin": "vaip-pass_vaiml_partition",
       "vaiml_config": {
-        "optimize_level": 2
+        "optimize_level": 3,
+        "fe_experiment": "use-accurate-mode=LayerNorm2PassAdf",
+        "aiecompiler_args": "--system-stack-size=512"
       }
     }
+  ],
+  "target": "VAIML",
+  "targets": [
+    {
+      "name": "VAIML",
+      "pass": [
+        "init",
+        "vaiml_partition"
+      ]
+    }
   ]
 }
\ No newline at end of file

diff --git a/config/vitisai_config_zipformer_encoder.json b/config/vitisai_config_zipformer_encoder.json
index 68f5de3..c13fad7 100644
--- a/config/vitisai_config_zipformer_encoder.json
+++ b/config/vitisai_config_zipformer_encoder.json
@@ -8,12 +8,20 @@
       "name": "vaiml_partition",
       "plugin": "vaip-pass_vaiml_partition",
       "vaiml_config": {
-        "keep_outputs": false,
-        "optimize_level": 2,
-        "fe_experiment": "prefer-conv-to-gemm-conversion=1 enable-gather-elements-adf=1 enable-edge-reshape-unwrapping=0",
-        "preferred_data_storage": "unvectorized",
-        "threshold_gops_percent": 0
+        "optimize_level": 3,
+        "fe_experiment": "use-accurate-mode=LayerNorm2PassAdf",
+        "aiecompiler_args": "--system-stack-size=512"
       }
     }
+  ],
+  "target": "VAIML",
+  "targets": [
+    {
+      "name": "VAIML",
+      "pass": [
+        "init",
+        "vaiml_partition"
+      ]
+    }
   ]
 }
\ No newline at end of file

diff --git a/lira/models/whisper/export.py b/lira/models/whisper/export.py
index a27f1ba..4c9ace5 100644
--- a/lira/models/whisper/export.py
+++ b/lira/models/whisper/export.py
@@ -1,19 +1,23 @@
-# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2025, Advanced Micro Devices, Inc.
 # SPDX-License-Identifier: MIT
 
-import numpy as np
-import subprocess
-import onnx
-import onnxruntime as ort
-import json
 import os
+import json
 import shutil
-import argparse
+import subprocess
+import argparse
+from pathlib import Path
+import onnx
+import onnxruntime as ort
+import numpy as np
 from onnx import shape_inference
 from lira.utils.config import get_exported_cache_dir
 
-# Global variable for static shape parameters
+
+# ----------------------------------------------------------------------
+# Static shape parameter definitions
+# ----------------------------------------------------------------------
 STATIC_PARAMS = {
     "batch_size": "1",
     "encoder_sequence_length / 2": "1500",
@@ -26,43 +30,37 @@
 }
 
 
+# ----------------------------------------------------------------------
+# Argument parsing
+# ----------------------------------------------------------------------
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Export and fix ONNX models for Whisper"
-    )
+    parser = argparse.ArgumentParser(description="Export and fix ONNX models for Whisper")
     parser.add_argument(
         "--model_name",
         required=True,
         help="Model name (e.g., openai/whisper-base.en, openai/whisper-medium, etc.)",
     )
-    parser.add_argument(
-        "--output_dir", help="Directory to save the exported ONNX models"
-    )
-    parser.add_argument(
-        "--opset", type=int, default=17, help="ONNX opset version (default: 17)"
-    )
-    parser.add_argument(
-        "--static", action="store_true", help="Use static shape parameters"
-    )
-    parser.add_argument(
-        "--force",
-        action="store_true",
-        help="Force export even if cache already exists",
-    )
+    parser.add_argument("--output_dir", help="Directory to save the exported ONNX models")
+    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version (default: 17)")
+    parser.add_argument("--static", action="store_true", help="Use static shape parameters")
+    parser.add_argument("--force", action="store_true", help="Force re-export even if cache exists")
+
     args = parser.parse_args()
 
+    # Load parameter map
     if not args.static:
         with open("config/params.json", "r") as f:
             args.params_to_fix = json.load(f)
     else:
         args.params_to_fix = STATIC_PARAMS
-
     return args
 
 
+# ----------------------------------------------------------------------
+# Optimum CLI model export
+# ----------------------------------------------------------------------
 def export_with_optimum_cli(model_name, output_dir, opset):
-    """Run the optimum-cli export command to generate ONNX models, hiding stdout/stderr."""
-    # Ensure model_name is in the format 'openai/whisper-*'
+    """Export transformer to ONNX format using optimum-cli."""
     command = [
         "optimum-cli",
         "export",
@@ -74,23 +72,19 @@ def export_with_optimum_cli(model_name, output_dir, opset):
         output_dir,
     ]
     print(f"Running optimum-cli export for model: {model_name}")
-    subprocess.run(
-        command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-    )
+    subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     print(f"Exported ONNX model to: {output_dir}")
 
 
+# ----------------------------------------------------------------------
+# Static shape patching with support for models > 2GB
+# ----------------------------------------------------------------------
 def force_set_static(model_path, output_path, mapping):
     """
-    mapping: dict of {dim_param_name: value}
-    Example:
-        {
-            "batch_size": 1,
-            "encoder_sequence_length / 2": 1500,
-            "decoder_sequence_length": 1,
-        }
+    Safely convert ONNX dynamic dims to static dims without breaking >2GB models.
     """
-    model = onnx.load(model_path)
+    print(f"Loading large ONNX model from: {model_path}")
+    model = onnx.load_model(model_path, load_external_data=True)
 
     def patch_value_info(value_infos):
         for vi in value_infos:
@@ -98,64 +92,77 @@ def patch_value_info(value_infos):
                 continue
             shape = vi.type.tensor_type.shape
             for dim in shape.dim:
-                if dim.dim_param in mapping:
+                if dim.dim_param and dim.dim_param in mapping:
                     dim.dim_value = int(mapping[dim.dim_param])
+                    dim.ClearField("dim_param")
 
     patch_value_info(model.graph.input)
     patch_value_info(model.graph.output)
     patch_value_info(model.graph.value_info)
 
-    onnx.save(model, output_path)
-    print(f"Forced static shapes written to {output_path}")
-
+    # Save with external data to handle >2GB tensors
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    datafile_name = output_path.with_suffix(".onnx_data").name
+
+    onnx.save_model(
+        model,
+        str(output_path),
+        save_as_external_data=True,
+        all_tensors_to_one_file=True,
+        location=datafile_name,
+        size_threshold=512 * 1024 * 1024,
+    )
 
-def export_whisper_model(
-    model_name, output_dir=None, opset=17, static=False, force=False
-):
-    """Exports the Whisper model and fixes static shapes if required."""
-    # Set default output directory to project cache if not specified
+    # ✅ Use model path for validation instead of in-memory object
+    try:
+        onnx.checker.check_model(str(output_path))
+        print(f"Model integrity verified for {output_path}")
+    except Exception as e:
+        print(f"Warning: ONNX checker skipped in-memory validation: {e}")
+
+    print(f"Saved static ONNX model to {output_path}")
+    print(f"  External tensor data stored at {datafile_name}")
+
+
+# ----------------------------------------------------------------------
+# Whisper model export workflow
+# ----------------------------------------------------------------------
+def export_whisper_model(model_name, output_dir=None, opset=17, static=False, force=False):
+    """Exports and statically fixes ONNX models for Whisper pipeline."""
     if output_dir is None:
         output_dir = get_exported_cache_dir() / model_name
-        output_dir.mkdir(parents=True, exist_ok=True)
 
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
-    # Check if cache already exists
-    if os.path.exists(output_dir) and not force:
+    if output_dir.exists() and not force:
         print(f"Cache already exists at {output_dir}. Use --force to overwrite.")
         return
 
-    # Determine parameters to fix
-    if not static:
-        with open("config/params.json", "r") as f:
-            params_to_fix = json.load(f)
-    else:
-        params_to_fix = STATIC_PARAMS
-
-    # Step 1: Export ONNX model using optimum-cli
-    export_with_optimum_cli(model_name, output_dir, opset)
-
-    # Step 2: Fix static shapes for encoder and decoder models
-    encoder_model_path = os.path.join(output_dir, "encoder_model.onnx")
-    decoder_model_path = os.path.join(output_dir, "decoder_model.onnx")
-    decoder_init_model_path = os.path.join(output_dir, "decoder_init_model.onnx")
 
-    if os.path.exists(encoder_model_path):
-        force_set_static(encoder_model_path, encoder_model_path, params_to_fix)
+    # Optional: export with optimum
+    # export_with_optimum_cli(model_name, str(output_dir), opset)
 
-    if os.path.exists(decoder_model_path):
+    params_to_fix = STATIC_PARAMS if static else STATIC_PARAMS
+    encoder_model = output_dir / "encoder_model.onnx"
+    decoder_model = output_dir / "decoder_model.onnx"
+    decoder_init_model = output_dir / "decoder_init_model.onnx"
 
-        # Create a copy of decoder_model.onnx before setting it to static
-        shutil.copy(decoder_model_path, decoder_init_model_path)
+    static_dir = output_dir.parent / f"{output_dir.name}_static"
+    static_dir.mkdir(parents=True, exist_ok=True)
 
-        # Set the copied model to static using STATIC_PARAMS_KV
-        force_set_static(
-            decoder_init_model_path, decoder_init_model_path, STATIC_PARAMS_KV
-        )
+    if encoder_model.exists():
+        force_set_static(encoder_model, static_dir / "encoder_model.onnx", params_to_fix)
 
-        static_decoder_model_path = os.path.join(output_dir, "decoder_model.onnx")
-        force_set_static(decoder_model_path, static_decoder_model_path, STATIC_PARAMS)
+    if decoder_model.exists():
+        shutil.copy(decoder_model, decoder_init_model)
+        force_set_static(decoder_init_model, decoder_init_model, STATIC_PARAMS_KV)
+        force_set_static(decoder_model, static_dir / "decoder_model.onnx", STATIC_PARAMS)
 
     print("Model export and static shape conversion completed successfully.")
 
 
+# ----------------------------------------------------------------------
+# Main entry point
+# ----------------------------------------------------------------------
 if __name__ == "__main__":
     args = parse_args()
     export_whisper_model(

diff --git a/lira/models/whisper/transcribe.py b/lira/models/whisper/transcribe.py
index d6745aa..422b54d 100644
--- a/lira/models/whisper/transcribe.py
+++ b/lira/models/whisper/transcribe.py
@@ -6,6 +6,7 @@
 from transformers import WhisperFeatureExtractor, WhisperTokenizer
 import time
 import json
+import os
 import torchaudio
 from pathlib import Path
 from jiwer import wer, cer
@@ -318,10 +319,11 @@ def run(args):
 
     # Load providers based on device
     whisper = WhisperONNX(
-        encoder_path=f"{args.model}/encoder_model.onnx",
-        decoder_path=f"{args.model}/decoder_model.onnx",
-        decoder_init_path=f"{args.model}/decoder_init_model.onnx",
-        decoder_past_path=f"{args.model}/decoder_with_past_model.onnx",
+        encoder_path = (Path(args.model) / "encoder_model.onnx").resolve().as_posix(),
+        decoder_path = (Path(args.model) / "decoder_model.onnx").resolve().as_posix(),
+        decoder_init_path = (Path(args.model) / "decoder_init_model.onnx").resolve().as_posix(),
+        decoder_past_path = (Path(args.model) / "decoder_with_past_model.onnx").resolve().as_posix(),
+
         encoder_provider=get_provider(
             args.device, args.model_type, "encoder", cache_dir=args.cache
         ),
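A note on the vaiml config changes above: the `vaiml_config` keys (`optimize_level`, `fe_experiment`, `aiecompiler_args`) and the new `target`/`targets` sections are read by the VitisAI execution provider when a session is compiled against one of these JSON files. A minimal sketch of wiring a config into ONNX Runtime — the `config_file` provider option is the mechanism the Ryzen AI documentation describes, while the model path here is illustrative, not taken from this repo:

```python
import onnxruntime as ort

# Minimal sketch: route supported subgraphs to the NPU through the VitisAI EP,
# pointing it at the updated encoder config; unsupported ops fall back to CPU.
session = ort.InferenceSession(
    "encoder_model.onnx",  # illustrative path
    providers=["VitisAIExecutionProvider", "CPUExecutionProvider"],
    provider_options=[
        {"config_file": "config/vitisai_config_whisper_base_encoder.json"},
        {},  # no options for the CPU fallback
    ],
)
print(session.get_providers())
```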
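The heart of the new `force_set_static` is the `dim_param` to `dim_value` rewrite, including the `ClearField("dim_param")` call that drops the symbolic name instead of leaving it alongside the integer. A self-contained toy reproduction of that rewrite (model and dimension names are made up for illustration):

```python
import onnx
from onnx import TensorProto, helper

# Build a toy model with a symbolic "batch_size" dimension.
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["batch_size", 80, 3000])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["batch_size", 80, 3000])
graph = helper.make_graph([helper.make_node("Identity", ["x"], ["y"])], "toy", [inp], [out])
model = helper.make_model(graph)

# Same rewrite force_set_static applies to inputs, outputs, and value_info.
mapping = {"batch_size": "1"}  # string values, mirroring STATIC_PARAMS
for vi in list(model.graph.input) + list(model.graph.output):
    for dim in vi.type.tensor_type.shape.dim:
        if dim.dim_param and dim.dim_param in mapping:
            dim.dim_value = int(mapping[dim.dim_param])
            dim.ClearField("dim_param")  # drop the symbolic name entirely

onnx.checker.check_model(model)
print(model.graph.input[0].type.tensor_type.shape)  # now 1 x 80 x 3000, all static
```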
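Likewise, the switch from `onnx.save` to `onnx.save_model(..., save_as_external_data=True)` is what keeps decoders past the 2GB protobuf limit saveable: tensors at or above `size_threshold` move to a sidecar file, and validation then has to go through the file path so the checker can re-resolve that data. A sketch with illustrative file names:

```python
import onnx

# Load with external data resolved so all initializers are in memory.
model = onnx.load_model("decoder_model.onnx", load_external_data=True)

onnx.save_model(
    model,
    "decoder_model_static.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="decoder_model_static.onnx_data",  # sidecar next to the .onnx file
    size_threshold=512 * 1024 * 1024,  # tensors >= 512 MB go external
)

# Checking by path lets the checker re-resolve the sidecar; passing the
# in-memory proto instead can trip the 2GB protobuf limit on large models.
onnx.checker.check_model("decoder_model_static.onnx")
```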
From d0193255b451ecaa1fedf99585983b3d0af10f7e Mon Sep 17 00:00:00 2001
From: Iswarya Alex
Date: Mon, 20 Oct 2025 23:32:39 -0700
Subject: [PATCH 2/5] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index c47b9e7..2ddea46 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ LIRA is a **CLI-first, developer-friendly tool**: run and serve ASR models local
 2. **Activate your conda environment:**
 
    ```bash
-   conda activate ryzen-ai-1.5.0
+   conda activate ryzen-ai-1.6.0
    ```
 
 3. **Install LIRA in editable mode:**
From 131abbda3bcd4f1f29576a11daa982b087600c52 Mon Sep 17 00:00:00 2001
From: Iswarya Alex
Date: Mon, 20 Oct 2025 23:36:45 -0700
Subject: [PATCH 3/5] Reformat export.py and transcribe.py

---
 lira/models/whisper/export.py     | 37 +++++++++++++++++++++++--------
 lira/models/whisper/transcribe.py | 13 ++++++-----
 2 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/lira/models/whisper/export.py b/lira/models/whisper/export.py
index 4c9ace5..f826bb6 100644
--- a/lira/models/whisper/export.py
+++ b/lira/models/whisper/export.py
@@ -34,16 +34,26 @@
 # Argument parsing
 # ----------------------------------------------------------------------
 def parse_args():
-    parser = argparse.ArgumentParser(description="Export and fix ONNX models for Whisper")
+    parser = argparse.ArgumentParser(
+        description="Export and fix ONNX models for Whisper"
+    )
     parser.add_argument(
         "--model_name",
         required=True,
         help="Model name (e.g., openai/whisper-base.en, openai/whisper-medium, etc.)",
     )
-    parser.add_argument("--output_dir", help="Directory to save the exported ONNX models")
-    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version (default: 17)")
-    parser.add_argument("--static", action="store_true", help="Use static shape parameters")
-    parser.add_argument("--force", action="store_true", help="Force re-export even if cache exists")
+    parser.add_argument(
+        "--output_dir", help="Directory to save the exported ONNX models"
+    )
+    parser.add_argument(
+        "--opset", type=int, default=17, help="ONNX opset version (default: 17)"
+    )
+    parser.add_argument(
+        "--static", action="store_true", help="Use static shape parameters"
+    )
+    parser.add_argument(
+        "--force", action="store_true", help="Force re-export even if cache exists"
+    )
 
     args = parser.parse_args()
 
@@ -72,7 +82,9 @@ def export_with_optimum_cli(model_name, output_dir, opset):
         output_dir,
     ]
     print(f"Running optimum-cli export for model: {model_name}")
-    subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+    subprocess.run(
+        command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+    )
     print(f"Exported ONNX model to: {output_dir}")
 
 
@@ -124,10 +136,13 @@ def patch_value_info(value_infos):
     print(f"Saved static ONNX model to {output_path}")
     print(f"  External tensor data stored at {datafile_name}")
 
+
 # ----------------------------------------------------------------------
 # Whisper model export workflow
 # ----------------------------------------------------------------------
-def export_whisper_model(model_name, output_dir=None, opset=17, static=False, force=False):
+def export_whisper_model(
+    model_name, output_dir=None, opset=17, static=False, force=False
+):
     """Exports and statically fixes ONNX models for Whisper pipeline."""
     if output_dir is None:
         output_dir = get_exported_cache_dir() / model_name
@@ -150,12 +165,16 @@ def export_whisper_model(model_name, output_dir=None, opset=17, static=False, fo
         static_dir.mkdir(parents=True, exist_ok=True)
 
     if encoder_model.exists():
-        force_set_static(encoder_model, static_dir / "encoder_model.onnx", params_to_fix)
+        force_set_static(
+            encoder_model, static_dir / "encoder_model.onnx", params_to_fix
+        )
 
     if decoder_model.exists():
         shutil.copy(decoder_model, decoder_init_model)
         force_set_static(decoder_init_model, decoder_init_model, STATIC_PARAMS_KV)
-        force_set_static(decoder_model, static_dir / "decoder_model.onnx", STATIC_PARAMS)
+        force_set_static(
+            decoder_model, static_dir / "decoder_model.onnx", STATIC_PARAMS
+        )
 
     print("Model export and static shape conversion completed successfully.")
 

diff --git a/lira/models/whisper/transcribe.py b/lira/models/whisper/transcribe.py
index 422b54d..1e3849b 100644
--- a/lira/models/whisper/transcribe.py
+++ b/lira/models/whisper/transcribe.py
@@ -319,11 +319,14 @@ def run(args):
 
     # Load providers based on device
     whisper = WhisperONNX(
-        encoder_path = (Path(args.model) / "encoder_model.onnx").resolve().as_posix(),
-        decoder_path = (Path(args.model) / "decoder_model.onnx").resolve().as_posix(),
-        decoder_init_path = (Path(args.model) / "decoder_init_model.onnx").resolve().as_posix(),
-        decoder_past_path = (Path(args.model) / "decoder_with_past_model.onnx").resolve().as_posix(),
-
+        encoder_path=(Path(args.model) / "encoder_model.onnx").resolve().as_posix(),
+        decoder_path=(Path(args.model) / "decoder_model.onnx").resolve().as_posix(),
+        decoder_init_path=(Path(args.model) / "decoder_init_model.onnx")
+        .resolve()
+        .as_posix(),
+        decoder_past_path=(Path(args.model) / "decoder_with_past_model.onnx")
+        .resolve()
+        .as_posix(),
         encoder_provider=get_provider(
             args.device, args.model_type, "encoder", cache_dir=args.cache
         ),
From ccd8bbc160010dc0f101d7312de8c2f1996ee5fa Mon Sep 17 00:00:00 2001
From: Iswarya Alex
Date: Tue, 21 Oct 2025 12:33:00 -0700
Subject: [PATCH 4/5] static for decoder

---
 lira/models/whisper/export.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/lira/models/whisper/export.py b/lira/models/whisper/export.py
index f826bb6..9ab373a 100644
--- a/lira/models/whisper/export.py
+++ b/lira/models/whisper/export.py
@@ -146,34 +146,28 @@ def export_whisper_model(
     """Exports and statically fixes ONNX models for Whisper pipeline."""
     if output_dir is None:
         output_dir = get_exported_cache_dir() / model_name
+
     output_dir = Path(output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
     if output_dir.exists() and not force:
         print(f"Cache already exists at {output_dir}. Use --force to overwrite.")
         return
+
+
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     # Optional: export with optimum
-    # export_with_optimum_cli(model_name, str(output_dir), opset)
+    export_with_optimum_cli(model_name, str(output_dir), opset)
 
     params_to_fix = STATIC_PARAMS if static else STATIC_PARAMS
     encoder_model = output_dir / "encoder_model.onnx"
     decoder_model = output_dir / "decoder_model.onnx"
     decoder_init_model = output_dir / "decoder_init_model.onnx"
 
-    static_dir = output_dir.parent / f"{output_dir.name}_static"
-    static_dir.mkdir(parents=True, exist_ok=True)
-
-    if encoder_model.exists():
-        force_set_static(
-            encoder_model, static_dir / "encoder_model.onnx", params_to_fix
-        )
-
     if decoder_model.exists():
         shutil.copy(decoder_model, decoder_init_model)
         force_set_static(decoder_init_model, decoder_init_model, STATIC_PARAMS_KV)
         force_set_static(
-            decoder_model, static_dir / "decoder_model.onnx", STATIC_PARAMS
+            decoder_model, output_dir / "decoder_model.onnx", STATIC_PARAMS
         )
 
     print("Model export and static shape conversion completed successfully.")
@@ -190,4 +184,4 @@ def export_whisper_model(
         opset=args.opset,
         static=args.static,
         force=args.force,
-    )
+    )
\ No newline at end of file
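Since this patch drops the `_static` output directory and rewrites only the decoder pair in place, a quick sanity check is to open the patched files and confirm that no symbolic dimensions survive; ONNX Runtime reports any leftover `dim_param` as a string in the input shapes. A sketch (the file name follows the layout this series produces):

```python
import onnxruntime as ort

sess = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    # A fully static model prints only integers here; a leftover string such
    # as "decoder_sequence_length" means the mapping missed a dim_param.
    print(inp.name, inp.shape)
```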
From aeb14bf456393708e3b095d370a63018aec31a1a Mon Sep 17 00:00:00 2001
From: Iswarya Alex
Date: Tue, 21 Oct 2025 12:54:30 -0700
Subject: [PATCH 5/5] fix lint

---
 lira/models/whisper/export.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/lira/models/whisper/export.py b/lira/models/whisper/export.py
index 9ab373a..4394f8f 100644
--- a/lira/models/whisper/export.py
+++ b/lira/models/whisper/export.py
@@ -146,13 +146,12 @@ def export_whisper_model(
     """Exports and statically fixes ONNX models for Whisper pipeline."""
     if output_dir is None:
         output_dir = get_exported_cache_dir() / model_name
-
+
     output_dir = Path(output_dir)
     if output_dir.exists() and not force:
         print(f"Cache already exists at {output_dir}. Use --force to overwrite.")
         return
-
-
+
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # Optional: export with optimum
@@ -184,4 +183,4 @@ def export_whisper_model(
         opset=args.opset,
         static=args.static,
         force=args.force,
-    )
+    )
\ No newline at end of file
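For reference, a sketch of the end-to-end flow after this series. Invoking the script directly assumes it is run from the repo root so that `lira` is importable — that layout is an assumption, not something the patches state — and the flags are the ones defined in `parse_args`:

```python
import subprocess

# Hypothetical invocation of the export script touched by this series;
# adjust the path to however lira actually exposes it.
subprocess.run(
    [
        "python",
        "lira/models/whisper/export.py",
        "--model_name", "openai/whisper-base.en",
        "--output_dir", "exported/whisper-base.en",
        "--static",  # use STATIC_PARAMS instead of config/params.json
        "--force",   # re-export even if the cache directory exists
    ],
    check=True,
)
```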