2 changes: 1 addition & 1 deletion README.md
@@ -30,7 +30,7 @@ LIRA is a **CLI-first, developer-friendly tool**: run and serve ASR models local

2. **Activate your conda environment:**
```bash
conda activate ryzen-ai-1.5.0
conda activate ryzen-ai-1.6.0
```

3. **Install LIRA in editable mode:**
20 changes: 12 additions & 8 deletions config/vitisai_config_whisper_base_decoder.json
@@ -8,15 +8,19 @@
"name": "vaiml_partition",
"plugin": "vaip-pass_vaiml_partition",
"vaiml_config": {
"keep_outputs": false,
"device": "stx",
"optimize_level": 2,
"logging_level": "info",
"no_failsafe": true,
"experiment_features": [
"InsertGatherFront"
]
"optimize_level": 3,
"aiecompiler_args": "--system-stack-size=512"
}
}
],
"target": "VAIML",
"targets": [
{
"name": "VAIML",
"pass": [
"init",
"vaiml_partition"
]
}
]
}
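
For context, a VitisAI config like the one above is typically handed to ONNX Runtime through the VitisAI execution provider's `config_file` option. The sketch below is illustrative only: the file paths are placeholders, and LIRA's actual wiring goes through its own provider-selection helper (`get_provider` in `transcribe.py`).

```python
import onnxruntime as ort

# Placeholder paths; adjust to the exported-model cache and repo layout.
decoder_onnx = "exported/whisper-base/decoder_model.onnx"
vitisai_cfg = "config/vitisai_config_whisper_base_decoder.json"

# Minimal sketch: create a session on the VitisAI execution provider,
# pointing it at the JSON config shown above.
session = ort.InferenceSession(
    decoder_onnx,
    providers=["VitisAIExecutionProvider"],
    provider_options=[{"config_file": vitisai_cfg}],
)
print([inp.name for inp in session.get_inputs()])
```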
14 changes: 13 additions & 1 deletion config/vitisai_config_whisper_base_encoder.json
@@ -8,8 +8,20 @@
"name": "vaiml_partition",
"plugin": "vaip-pass_vaiml_partition",
"vaiml_config": {
"optimize_level": 2
"optimize_level": 3,
"fe_experiment": "use-accurate-mode=LayerNorm2PassAdf",
"aiecompiler_args": "--system-stack-size=512"
}
}
],
"target": "VAIML",
"targets": [
{
"name": "VAIML",
"pass": [
"init",
"vaiml_partition"
]
}
]
}
18 changes: 13 additions & 5 deletions config/vitisai_config_zipformer_encoder.json
@@ -8,12 +8,20 @@
"name": "vaiml_partition",
"plugin": "vaip-pass_vaiml_partition",
"vaiml_config": {
"keep_outputs": false,
"optimize_level": 2,
"fe_experiment": "prefer-conv-to-gemm-conversion=1 enable-gather-elements-adf=1 enable-edge-reshape-unwrapping=0",
"preferred_data_storage": "unvectorized",
"threshold_gops_percent": 0
"optimize_level": 3,
"fe_experiment": "use-accurate-mode=LayerNorm2PassAdf",
"aiecompiler_args": "--system-stack-size=512"
}
}
],
"target": "VAIML",
"targets": [
{
"name": "VAIML",
"pass": [
"init",
"vaiml_partition"
]
}
]
}
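
All three configs now share the same shape: a `vaiml_partition` pass with `optimize_level: 3` plus `aiecompiler_args`, and a `targets` list naming `VAIML`. A small sanity-check helper like the hypothetical one below (assuming the hunks sit inside the usual top-level `passes` array, which the diff does not show) can catch a config that drifts from that layout.

```python
import json
from pathlib import Path

def check_vaiml_config(path: str) -> None:
    """Sanity-check the pass/target layout used by the configs above."""
    cfg = json.loads(Path(path).read_text())
    pass_names = {p.get("name") for p in cfg.get("passes", [])}
    target_names = {t.get("name") for t in cfg.get("targets", [])}
    assert "vaiml_partition" in pass_names, f"{path}: missing vaiml_partition pass"
    assert "VAIML" in target_names, f"{path}: missing VAIML target"
    vaiml = next(p for p in cfg["passes"] if p.get("name") == "vaiml_partition")
    print(f"{path}: optimize_level={vaiml['vaiml_config'].get('optimize_level')}")

for cfg_path in (
    "config/vitisai_config_whisper_base_decoder.json",
    "config/vitisai_config_whisper_base_encoder.json",
    "config/vitisai_config_zipformer_encoder.json",
):
    check_vaiml_config(cfg_path)
```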
127 changes: 73 additions & 54 deletions lira/models/whisper/export.py
@@ -1,19 +1,23 @@
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2025, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import numpy as np
import subprocess
import onnx
import onnxruntime as ort
import json
import os
import argparse
import json
import shutil
import subprocess
import argparse
from pathlib import Path

import onnx
import onnxruntime as ort
import numpy as np
from onnx import shape_inference
from lira.utils.config import get_exported_cache_dir

# Global variable for static shape parameters

# ----------------------------------------------------------------------
# Static shape parameter definitions
# ----------------------------------------------------------------------
STATIC_PARAMS = {
"batch_size": "1",
"encoder_sequence_length / 2": "1500",
@@ -26,6 +30,9 @@
}


# ----------------------------------------------------------------------
# Argument parsing
# ----------------------------------------------------------------------
def parse_args():
parser = argparse.ArgumentParser(
description="Export and fix ONNX models for Whisper"
Expand All @@ -45,24 +52,25 @@ def parse_args():
"--static", action="store_true", help="Use static shape parameters"
)
parser.add_argument(
"--force",
action="store_true",
help="Force export even if cache already exists",
"--force", action="store_true", help="Force re-export even if cache exists"
)

args = parser.parse_args()

# Load parameter map
if not args.static:
with open("config/params.json", "r") as f:
args.params_to_fix = json.load(f)
else:
args.params_to_fix = STATIC_PARAMS

return args


# ----------------------------------------------------------------------
# Optimum CLI model export
# ----------------------------------------------------------------------
def export_with_optimum_cli(model_name, output_dir, opset):
"""Run the optimum-cli export command to generate ONNX models, hiding stdout/stderr."""
# Ensure model_name is in the format 'openai/whisper-*'
"""Export transformer to ONNX format using optimum-cli."""
command = [
"optimum-cli",
"export",
@@ -80,82 +88,93 @@ def export_with_optimum_cli(model_name, output_dir, opset):
print(f"Exported ONNX model to: {output_dir}")


# ----------------------------------------------------------------------
# Static shape patching with support for models > 2GB
# ----------------------------------------------------------------------
def force_set_static(model_path, output_path, mapping):
"""
mapping: dict of {dim_param_name: value}
Example:
{
"batch_size": 1,
"encoder_sequence_length / 2": 1500,
"decoder_sequence_length": 1,
}
Safely convert ONNX dynamic dims to static dims without breaking >2GB models.
"""
model = onnx.load(model_path)
print(f"Loading large ONNX model from: {model_path}")
model = onnx.load_model(model_path, load_external_data=True)

def patch_value_info(value_infos):
for vi in value_infos:
if not vi.type.HasField("tensor_type"):
continue
shape = vi.type.tensor_type.shape
for dim in shape.dim:
if dim.dim_param in mapping:
if dim.dim_param and dim.dim_param in mapping:
dim.dim_value = int(mapping[dim.dim_param])
dim.ClearField("dim_param")

patch_value_info(model.graph.input)
patch_value_info(model.graph.output)
patch_value_info(model.graph.value_info)

onnx.save(model, output_path)
print(f"Forced static shapes written to {output_path}")
# Save with external data to handle >2GB tensors
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
datafile_name = output_path.with_suffix(".onnx_data").name

onnx.save_model(
model,
str(output_path),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=datafile_name,
size_threshold=512 * 1024 * 1024,
)

# ✅ Use model path for validation instead of in-memory object
try:
onnx.checker.check_model(str(output_path))
print(f"Model integrity verified for {output_path}")
except Exception as e:
print(f"Warning: ONNX checker skipped in-memory validation: {e}")

print(f"Saved static ONNX model to {output_path}")
print(f" External tensor data stored at {datafile_name}")


# ----------------------------------------------------------------------
# Whisper model export workflow
# ----------------------------------------------------------------------
def export_whisper_model(
model_name, output_dir=None, opset=17, static=False, force=False
):
"""Exports the Whisper model and fixes static shapes if required."""
# Set default output directory to project cache if not specified
"""Exports and statically fixes ONNX models for Whisper pipeline."""
if output_dir is None:
output_dir = get_exported_cache_dir() / model_name
output_dir.mkdir(parents=True, exist_ok=True)

# Check if cache already exists
if os.path.exists(output_dir) and not force:
output_dir = Path(output_dir)
if output_dir.exists() and not force:
print(f"Cache already exists at {output_dir}. Use --force to overwrite.")
return

# Determine parameters to fix
if not static:
with open("config/params.json", "r") as f:
params_to_fix = json.load(f)
else:
params_to_fix = STATIC_PARAMS

# Step 1: Export ONNX model using optimum-cli
export_with_optimum_cli(model_name, output_dir, opset)
output_dir.mkdir(parents=True, exist_ok=True)

# Step 2: Fix static shapes for encoder and decoder models
encoder_model_path = os.path.join(output_dir, "encoder_model.onnx")
decoder_model_path = os.path.join(output_dir, "decoder_model.onnx")
decoder_init_model_path = os.path.join(output_dir, "decoder_init_model.onnx")
if os.path.exists(encoder_model_path):
force_set_static(encoder_model_path, encoder_model_path, params_to_fix)
# Optional: export with optimum
export_with_optimum_cli(model_name, str(output_dir), opset)

if os.path.exists(decoder_model_path):
params_to_fix = STATIC_PARAMS if static else STATIC_PARAMS
encoder_model = output_dir / "encoder_model.onnx"
decoder_model = output_dir / "decoder_model.onnx"
decoder_init_model = output_dir / "decoder_init_model.onnx"

# Create a copy of decoder_model.onnx before setting it to static
shutil.copy(decoder_model_path, decoder_init_model_path)

# Set the copied model to static using STATIC_PARAMS_KV
if decoder_model.exists():
shutil.copy(decoder_model, decoder_init_model)
force_set_static(decoder_init_model, decoder_init_model, STATIC_PARAMS_KV)
force_set_static(
decoder_init_model_path, decoder_init_model_path, STATIC_PARAMS_KV
decoder_model, output_dir / "decoder_model.onnx", STATIC_PARAMS
)

static_decoder_model_path = os.path.join(output_dir, "decoder_model.onnx")
force_set_static(decoder_model_path, static_decoder_model_path, STATIC_PARAMS)

print("Model export and static shape conversion completed successfully.")


# ----------------------------------------------------------------------
# Main entry point
# ----------------------------------------------------------------------
if __name__ == "__main__":
args = parse_args()
export_whisper_model(
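
Based on the signature shown above, the export entry point can also be driven programmatically. The model name and flag values below are illustrative only.

```python
from lira.models.whisper.export import export_whisper_model

# Illustrative invocation; "openai/whisper-base" and the flag values are examples.
export_whisper_model(
    "openai/whisper-base",  # forwarded to optimum-cli
    output_dir=None,        # None -> project's exported-model cache directory
    opset=17,
    static=True,            # patch dynamic dims with STATIC_PARAMS
    force=True,             # re-export even if a cached copy exists
)
```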
13 changes: 9 additions & 4 deletions lira/models/whisper/transcribe.py
@@ -6,6 +6,7 @@
from transformers import WhisperFeatureExtractor, WhisperTokenizer
import time
import json
import os
import torchaudio
from pathlib import Path
from jiwer import wer, cer
@@ -318,10 +319,14 @@ def run(args):

# Load providers based on device
whisper = WhisperONNX(
encoder_path=f"{args.model}/encoder_model.onnx",
decoder_path=f"{args.model}/decoder_model.onnx",
decoder_init_path=f"{args.model}/decoder_init_model.onnx",
decoder_past_path=f"{args.model}/decoder_with_past_model.onnx",
encoder_path=(Path(args.model) / "encoder_model.onnx").resolve().as_posix(),
decoder_path=(Path(args.model) / "decoder_model.onnx").resolve().as_posix(),
decoder_init_path=(Path(args.model) / "decoder_init_model.onnx")
.resolve()
.as_posix(),
decoder_past_path=(Path(args.model) / "decoder_with_past_model.onnx")
.resolve()
.as_posix(),
encoder_provider=get_provider(
args.device, args.model_type, "encoder", cache_dir=args.cache
),
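
The switch to `Path(...).resolve().as_posix()` mainly matters on Windows, where it yields absolute, forward-slash paths regardless of how `--model` was spelled. A minimal illustration (the directory name is hypothetical):

```python
from pathlib import Path

model_dir = Path("exported/whisper-base")  # hypothetical model directory
encoder_path = (model_dir / "encoder_model.onnx").resolve().as_posix()
# e.g. "C:/Users/dev/lira/exported/whisper-base/encoder_model.onnx" on Windows
print(encoder_path)
```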