Changes from all commits (22 commits)
fc751d9
replace pytorch calls to apex calls in microbenchmarking code
amd-sriram Aug 2, 2025
80b38cd
update the optimizer from torch to apex fused sgd optimizer
amd-sriram Oct 9, 2025
674bd6b
update the DistributedDataParallel from torch to apex
amd-sriram Oct 9, 2025
782bf64
Fixed the errors in conformer and wavernn models. also added units fo…
amd-sriram Oct 14, 2025
e28c8e5
refactor the code, move the models, input, output selection to audio_…
amd-sriram Oct 16, 2025
fbdd71a
refactored the audio benchmarking code so that most of the lower leve…
amd-sriram Oct 28, 2025
ead8362
moved audio_model to audio folder
amd-sriram Oct 28, 2025
0cd37b0
add loss functions for more audio models, refactor the input and outp…
amd-sriram Oct 28, 2025
1405036
add hubert loss for hubert pretrained model
amd-sriram Oct 29, 2025
b60ceaa
add recent changes in pytorch microbenchmarking to apex microbenchmar…
amd-sriram Feb 3, 2026
cc26b2f
add methods to calculate the target
amd-sriram Feb 8, 2026
3724039
add loss function for squim objective
amd-sriram Feb 8, 2026
714a142
correct usage of squim subjective model
amd-sriram Feb 9, 2026
ad68179
refactor audio loss code to combine conditions
amd-sriram Feb 9, 2026
d5b520c
refactor audio loss code to combine conditions
amd-sriram Feb 9, 2026
870711b
add error messages for undefined input, target, model, loss function …
amd-sriram Feb 9, 2026
007e0cf
fix error related to sdr loss
amd-sriram Feb 9, 2026
768bf9c
fix error related to hdmucs loss
amd-sriram Feb 9, 2026
8696de9
apply some recent pytorch changes to torchaudio benchmark
amd-sriram Feb 9, 2026
2f02eb4
replace apex amp with torch amp
amd-sriram Feb 9, 2026
a8ee74d
change the location of optimizer step before profiling
amd-sriram Feb 11, 2026
55c5517
add readme sections for apex and audio
amd-sriram Feb 17, 2026
32 changes: 30 additions & 2 deletions README.md
@@ -1,4 +1,9 @@
# pytorch-micro-benchmarking

This repo provides microbenchmarking scripts for training and inference of models from the PyTorch, Apex, and TorchAudio libraries on ROCm.

## Pytorch

We supply a small microbenchmarking script for PyTorch training on ROCm.

To execute:
@@ -37,10 +42,10 @@ python3 micro_benchmarking_pytorch.py --device_ids=1 --network resnet50 --distri
To run FlopsProfiler (with deepspeed.profiling.flops_profiler imported):
`python micro_benchmarking_pytorch.py --network resnet50 --amp-opt-level=2 --batch-size=256 --iterations=20 --flops-prof-step 10`

## Performance tuning
### Performance tuning
If performance on a specific card and/or model is found to be lacking, typically some gains can be made by tuning MIOpen. For this, `export MIOPEN_FIND_ENFORCE=3` prior to running the model. This will take some time if untuned configurations are encountered and write to a local performance database. More information on this can be found in the [MIOpen documentation](https://rocm.github.io/MIOpen/doc/html/perfdatabase.html).

## PyTorch 2.0
### PyTorch 2.0
The `--compile` option opens up PyTorch 2.0 capabilities, which come with several options. Here are some notes from upstream:
```
Optimizes given model/function using TorchDynamo and specified backend.
@@ -75,3 +80,26 @@ python micro_benchmarking_pytorch.py --network resnet50 --compile --compileConte
python micro_benchmarking_pytorch.py --network resnet50 --compile --compileContext "{'options': {'static-memory': 'True', 'matmul-padding': 'True'}}"
```
Note: you cannot pass the `mode` and `options` options together.
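The `--compileContext` value is a Python dict literal whose keys become `torch.compile` keyword arguments. A minimal sketch of the parsing and the mode/options exclusivity check (the helper name and the `ast.literal_eval` choice are assumptions for illustration, not necessarily what the script does):

```python
import ast

def parse_compile_context(ctx_str):
    """Parse a --compileContext string into torch.compile keyword arguments.

    Assumes the string is a Python dict literal; `mode` and `options`
    cannot be passed to torch.compile at the same time.
    """
    kwargs = ast.literal_eval(ctx_str)
    if "mode" in kwargs and "options" in kwargs:
        raise ValueError("pass either 'mode' or 'options', not both")
    return kwargs

print(parse_compile_context("{'mode': 'max-autotune'}"))  # {'mode': 'max-autotune'}
```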

## TorchAudio

The TorchAudio script and its parameters are similar to the PyTorch one.

To execute:
`python micro_benchmarking_audio.py --network <network name> [--batch-size <batch size> ] [--iterations <number of iterations>] [--fp16 <0 or 1> ] [--distributed_dataparallel] [--device_ids <comma separated list (no spaces) of GPU indices (0-indexed) to run distributed_dataparallel api on>] `

Possible network names are: `wav2vec2_base`, `deepspeech`, `hdemucs_low`, `tacotron2`, `wavernn`, `wav2letter`, `hubert_base` etc.
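Internally, the script groups networks into task families (see `audio/audio_model.py`) and selects inputs and losses by family. A minimal pure-Python sketch of that dispatch, with the family tables abbreviated (the full tables live in `audio/audio_model.py`):

```python
# Abbreviated family tables mirroring audio/audio_model.py.
ACOUSTIC = {"wav2vec2_base", "hubert_base", "wavlm_base"}
RECOGNITION = {"conformer", "deepspeech", "emformer", "wav2letter"}
SYNTHESIS = {"tacotron2", "wavernn"}

def input_type(name):
    # Mirrors the logic of get_input_type() in audio/audio_input.py.
    if name in ACOUSTIC:
        return "waveform"
    if name in RECOGNITION:
        return "acoustic features"
    if name in SYNTHESIS:
        # wavernn consumes raw audio; tacotron2 consumes text tokens
        return "waveform" if "wavernn" in name else "tokens"
    raise ValueError(f"unknown network {name}")

print(input_type("wav2letter"))  # acoustic features
```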

## Apex

The Apex script and its parameters are similar to the PyTorch one.

To execute:
`python micro_benchmarking_apex.py --network <network name> [--batch-size <batch size> ] [--iterations <number of iterations>] [--fp16 <0 or 1> ] [--distributed_dataparallel] [--device_ids <comma separated list (no spaces) of GPU indices (0-indexed) to run distributed_dataparallel api on>] [--sync_bn] [--keep-batchnorm-fp32 <true|false>] [--loss-scale <value|dynamic>]`

There are three additional parameters.
1. `--sync_bn`: Use apex synchronized batch normalization across GPUs (useful for multi-GPU training).
2. `--keep-batchnorm-fp32`: Keep batch norm layers in FP32 when using AMP (e.g. `--keep-batchnorm-fp32 true`). Omit it when using opt level `O1`.
3. `--loss-scale`: Loss scale for mixed precision. It is a number (e.g. `1024`) for static scaling, or `dynamic` for adaptive scaling.

Instead of a true/false AMP flag, the Apex script takes an AMP optimization level (`--amp-opt-level`).
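The `--loss-scale dynamic` behavior can be sketched in plain Python to make the semantics concrete: scale the loss up before backward, skip the step and halve the scale when an overflow is detected, and grow the scale after a run of overflow-free steps. This is a toy model, not Apex's implementation; the growth interval and factors are illustrative:

```python
class DynamicLossScaler:
    """Toy sketch of dynamic loss scaling (not Apex's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_overflow):
        """Adjust the scale after a step; return whether to apply the optimizer step."""
        if found_overflow:
            self.scale /= 2.0       # back off on overflow
            self._good_steps = 0
            return False            # skip this optimizer step
        self._good_steps += 1
        if self._good_steps % self.growth_interval == 0:
            self.scale *= 2.0       # grow after a run of good steps
        return True

scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.update(True), scaler.scale)    # False 4.0
print(scaler.update(False), scaler.scale)   # True 4.0
print(scaler.update(False), scaler.scale)   # True 8.0
```

With `--loss-scale <value>`, the scale is simply held constant at that value instead.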
65 changes: 65 additions & 0 deletions audio/audio_input.py
@@ -0,0 +1,65 @@
import sys

import torch

from audio.audio_model import *


def get_input_type(network_name):
    if network_name in acoustic_models or network_name in source_separation_models or network_name in speech_quality_models:
        return "waveform"
    elif network_name in speech_recognition_models:
        return "acoustic features"
    elif network_name in speech_synthesis_models:
        if "wavernn" in network_name:
            return "waveform"
        else:
            return "tokens"
    return None


def get_input(network_name, network, batch_size):
    if network_name in acoustic_models:
        inp = {"waveforms": torch.randn(batch_size, FRAME_COUNT, device="cuda")}
    elif network_name in source_separation_models:
        if "hdemucs" in network_name:
            inp = {"input": torch.randn(batch_size, 2, FRAME_COUNT, device="cuda")}
        else:
            inp = {"input": torch.randn(batch_size, 1, FRAME_COUNT, device="cuda")}
    elif network_name in speech_recognition_models:
        if "deepspeech" in network_name:
            # the channel dimension must be specified explicitly for deepspeech
            inp = {"x": torch.randn(batch_size, 1, FRAME_COUNT, ACOUSTIC_FEATURES_SIZE, device="cuda")}
        elif "wav2letter" in network_name:
            inp = {"x": torch.randn(batch_size, ACOUSTIC_FEATURES_SIZE, FRAME_COUNT, device="cuda")}
        elif "emformer" in network_name:
            inp = {"input": torch.randn(batch_size, FRAME_COUNT, ACOUSTIC_FEATURES_SIZE, device="cuda"),
                   "lengths": torch.randint(1, FRAME_COUNT, (batch_size,), device="cuda")}
        elif "conformer" in network_name:
            lengths = torch.randint(1, FRAME_COUNT, (batch_size,), device="cuda")
            inp = {"input": torch.rand(batch_size, int(lengths.max()), 80, device="cuda"),
                   "lengths": lengths}
    elif network_name in speech_quality_models:
        if "subjective" in network_name:
            inp = {"waveform": torch.randn(batch_size, FRAME_COUNT, device="cuda"),
                   "reference": torch.randn(batch_size, FRAME_COUNT, device="cuda")}
        else:
            inp = {"x": torch.randn(batch_size, FRAME_COUNT, device="cuda")}
    elif network_name in speech_synthesis_models:
        if "wavernn" in network_name:
            spec_frames = 64
            waveform_length = HOP_LENGTH * (spec_frames - 4)

            inp = {"waveform": torch.rand(batch_size, 1, waveform_length, device="cuda"),
                   "specgram": torch.rand(batch_size, 1, N_FREQ, spec_frames, device="cuda")}
        elif "tacotron2" in network_name:
            n_mels = 80
            max_mel_specgram_length = 300
            max_text_length = 100
            inp = {"tokens": torch.randint(0, 148, (batch_size, max_text_length), dtype=torch.int32, device="cuda"),
                   "token_lengths": max_text_length * torch.ones((batch_size,), device="cuda"),
                   "mel_specgram": torch.rand(batch_size, n_mels, max_mel_specgram_length, device="cuda"),
                   "mel_specgram_lengths": max_mel_specgram_length * torch.ones((batch_size,), dtype=torch.int32, device="cuda")}
    elif network_name in speech_representation_models:
        inp = {"waveforms": torch.rand(batch_size, FRAME_COUNT, device="cuda"),
               "labels": torch.randint(0, 100, (batch_size, 2), dtype=torch.int32, device="cuda")}
    else:
        print(f"Input for {network_name} not defined")
        sys.exit(1)
    return inp
73 changes: 73 additions & 0 deletions audio/audio_loss.py
@@ -0,0 +1,73 @@
import string
import sys

import torch
from torch import nn

from audio.audio_model import *
from audio.hubert_loss import hubert_loss
from audio.languagemodels import LanguageModel
from audio.sdr import si_sdr_loss


def get_criterion(network_name):
    criterion = None
    if network_name in speech_representation_models:
        criterion = hubert_loss
    elif network_name in speech_recognition_models or network_name in acoustic_models:
        char_blank = "*"
        char_space = " "
        char_apostrophe = "'"
        labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
        language_model = LanguageModel(labels, char_blank, char_space)
        criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank], zero_infinity=False)
    elif "wavernn" in network_name:
        criterion = nn.CrossEntropyLoss()
    elif "conv_tasnet" in network_name:
        criterion = si_sdr_loss
    elif "tacotron2" in network_name:
        criterion = nn.MSELoss()
    elif "hdemucs" in network_name:
        criterion = nn.L1Loss(reduction='none')
    elif "squim" in network_name:
        criterion = nn.L1Loss()
    else:
        print(f"Criterion for network name {network_name} not defined")
        sys.exit(1)
    return criterion


def calculate_loss(network_name, criterion, output, target, batch_size, input):
    loss = 0
    if network_name in speech_representation_models:
        logit_m, logit_u, feature_penalty = output
        loss = criterion(logit_m, logit_u, feature_penalty)
    elif network_name in speech_recognition_models or network_name in acoustic_models:
        # CTC expects (time, batch, classes)
        output = output.transpose(-1, -2).transpose(0, 1)
        T, N, C = output.shape
        target, target_lengths = target
        tensors_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
        loss = criterion(output, target, tensors_lengths, target_lengths)
    elif "wavernn" in network_name:
        output = output.squeeze(1)
        output = output.transpose(1, 2)
        loss = criterion(output, target)
    elif "conv_tasnet" in network_name:
        target, mask = target
        loss = criterion(output, target, mask)
    elif "tacotron2" in network_name or "subjective" in network_name:
        loss = criterion(output, target)
    elif "objective" in network_name:
        weights = [1, 2, 0.5, 2]  # per-head weights (currently unused)
        loss = sum(criterion(output[index], target[index]) for index in range(len(output)))
        loss += criterion(input["x"], target[3])
    elif "hdemucs" in network_name:
        dims = tuple(range(2, target.dim()))
        loss = criterion(output, target)
        loss = loss.mean(dims).mean(0)
    else:
        print(f"Loss function for {network_name} not defined")
        sys.exit(1)
    return loss
103 changes: 103 additions & 0 deletions audio/audio_model.py
@@ -0,0 +1,103 @@
import torch
import torchaudio
import sys

ACOUSTIC_FEATURES_SIZE = 32
FRAME_COUNT = 1024
HOP_LENGTH = 36
N_FREQ = 128
CLASSES_COUNT = 29

# models for different audio tasks
acoustic_models = {
    "wav2vec2_base": torchaudio.models.wav2vec2_base,
    "wav2vec2_large": torchaudio.models.wav2vec2_large,
    "wav2vec2_large_lv60k": torchaudio.models.wav2vec2_large_lv60k,
    "wav2vec2_xlsr_300m": torchaudio.models.wav2vec2_xlsr_300m,
    "wav2vec2_xlsr_1b": torchaudio.models.wav2vec2_xlsr_1b,
    "wav2vec2_xlsr_2b": torchaudio.models.wav2vec2_xlsr_2b,
    "hubert_base": torchaudio.models.hubert_base,
    "hubert_large": torchaudio.models.hubert_large,
    "hubert_xlarge": torchaudio.models.hubert_xlarge,
    "wavlm_base": torchaudio.models.wavlm_base,
    "wavlm_large": torchaudio.models.wavlm_large,
}

speech_recognition_models = {
    "conformer": torchaudio.models.Conformer,
    "deepspeech": torchaudio.models.DeepSpeech,
    "emformer": torchaudio.models.Emformer,
    "wav2letter": torchaudio.models.Wav2Letter,
}

source_separation_models = {
    "conv_tasnet_base": torchaudio.models.conv_tasnet_base,
    "hdemucs_low": torchaudio.models.hdemucs_low,
    "hdemucs_medium": torchaudio.models.hdemucs_medium,
    "hdemucs_high": torchaudio.models.hdemucs_high,
}

speech_quality_models = {
    "squim_objective_base": torchaudio.models.squim_objective_base,
    "squim_subjective_base": torchaudio.models.squim_subjective_base,
}

speech_synthesis_models = {
    "tacotron2": torchaudio.models.Tacotron2,
    "wavernn": torchaudio.models.WaveRNN,
}

speech_representation_models = {
    "hubert_pretrain_base": torchaudio.models.hubert_pretrain_base,
    "hubert_pretrain_large": torchaudio.models.hubert_pretrain_large,
    "hubert_pretrain_xlarge": torchaudio.models.hubert_pretrain_xlarge,
}

def get_network_names():
    return sorted(list(acoustic_models.keys()) +
                  list(speech_recognition_models.keys()) +
                  list(source_separation_models.keys()) +
                  list(speech_quality_models.keys()) +
                  list(speech_synthesis_models.keys()) +
                  list(speech_representation_models.keys()))


def get_network(network_name):
    if network_name in acoustic_models:
        return acoustic_models[network_name](aux_num_out=CLASSES_COUNT).to(device="cuda")
    elif network_name in source_separation_models:
        if "hdemucs" in network_name:
            return source_separation_models[network_name](sources=["vocals"]).to(device="cuda")
        else:
            return source_separation_models[network_name]().to(device="cuda")
    elif network_name in speech_recognition_models:
        if "deepspeech" in network_name:
            return speech_recognition_models[network_name](n_feature=ACOUSTIC_FEATURES_SIZE).to(device="cuda")
        elif "wav2letter" in network_name:
            return speech_recognition_models[network_name](num_features=ACOUSTIC_FEATURES_SIZE).to(device="cuda")
        elif "emformer" in network_name:
            return speech_recognition_models[network_name](input_dim=ACOUSTIC_FEATURES_SIZE,
                                                           num_heads=8,
                                                           ffn_dim=1024,
                                                           num_layers=20,
                                                           segment_length=4).to(device="cuda")
        elif "conformer" in network_name:
            return speech_recognition_models[network_name](input_dim=80,
                                                           num_heads=4,
                                                           ffn_dim=128,
                                                           num_layers=4,
                                                           depthwise_conv_kernel_size=31).to(device="cuda")
    elif network_name in speech_quality_models:
        return speech_quality_models[network_name]().to(device="cuda")
    elif network_name in speech_synthesis_models:
        if "wavernn" in network_name:
            return speech_synthesis_models[network_name](upsample_scales=[3, 3, 4], n_classes=10,
                                                         hop_length=HOP_LENGTH, n_freq=N_FREQ).to(device="cuda")
        else:
            return speech_synthesis_models[network_name]().to(device="cuda")
    elif network_name in speech_representation_models:
        return speech_representation_models[network_name]().to(device="cuda")
    else:
        print(f"ERROR: not a supported model '{network_name}'")
        sys.exit(1)
54 changes: 54 additions & 0 deletions audio/audio_output.py
@@ -0,0 +1,54 @@
import sys

import torch

from audio.audio_model import *


def get_output_selection(network_name):
    if network_name in acoustic_models:
        return 0
    elif "conformer" in network_name or "emformer" in network_name:
        return 0
    elif "tacotron2" in network_name:
        return 1
    return None


def create_target(network_name, network, input, batch_size):

    # run a forward pass to get the output shape
    output = network(**input)
    output_index = get_output_selection(network_name)
    if output_index is not None:
        output = output[output_index]

    target = None
    if network_name in speech_recognition_models or network_name in acoustic_models:
        # CTC-style target: one flat tensor of labels plus per-sample lengths
        output = output.transpose(-1, -2).transpose(0, 1)
        T, N, C = output.shape
        target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
        target = torch.randint(
            low=1,
            high=C,
            size=(sum(target_lengths),),
            dtype=torch.long,
        )
        target = [target, target_lengths]
    elif "wavernn" in network_name:
        target = torch.randn_like(output)
        target = target.squeeze(1)
        target = target.transpose(1, 2)
    elif "conv_tasnet" in network_name:
        batch, _, time = output.shape
        # high is exclusive in torch.randint, so use 2 to get a random 0/1 mask
        mask = torch.randint(low=0, high=2, size=(batch, 1, time), dtype=torch.long).cuda()
        target = torch.randn_like(output)
        target = [target, mask]
    elif "tacotron2" in network_name or "hdemucs" in network_name or "subjective" in network_name:
        target = torch.randn_like(output)
    elif "objective" in network_name:
        target = []
        for index in range(len(output)):
            target.append(torch.randn_like(output[index]))
        target.append(torch.randn_like(input["x"]))
    else:
        print(f"Target for {network_name} not defined")
        sys.exit(1)
    return target
36 changes: 36 additions & 0 deletions audio/hubert_loss.py
@@ -0,0 +1,36 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def hubert_loss(
    logit_m: Optional[Tensor],
    logit_u: Optional[Tensor],
    feature_penalty: Tensor,
    masked_weight: float = 1.0,
    unmasked_weight: float = 0.0,
    feature_weight: float = 10.0,
    reduction: str = "sum",
) -> Tensor:
    """Compute the cross-entropy loss on HuBERT masked and non-masked logits.

    Args:
        logit_m (Tensor or None): The masked logit Tensor of dimension `(masked_frames, final_dim)`.
        logit_u (Tensor or None): The non-masked logit Tensor of dimension `(unmasked_frames, final_dim)`.
        feature_penalty (Tensor): The feature mean value for additional penalty loss.
        masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``).
        unmasked_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
        feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
        reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).
    """
    # logit_m may be None, so fall back to logit_u for the penalty's sample size
    sample_size = logit_m.shape[0] if logit_m is not None else logit_u.shape[0]
    loss = feature_penalty * feature_weight * sample_size
    if logit_m is not None:
        target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
        loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
        loss += loss_m * masked_weight
    if logit_u is not None:
        target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_u.device)
        loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction)
        loss += loss_u * unmasked_weight
    return loss