-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Open
Labels
needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address them; type: bug
Description
Expected behavior
When importing an ONNX model with BatchNormalization operator that has training_mode=0 (inference mode), TVM Relax should generate R.nn.batch_norm(..., training=False) and use the provided running_mean and running_var parameters for normalization.
According to the ONNX BatchNormalization specification:
- training_mode=0 (default): Use running statistics (inference mode)
- training_mode=1: Compute batch statistics from input (training mode)
Actual behavior
TVM Relax ONNX frontend ignores the training_mode attribute and always generates R.nn.batch_norm(..., training=True), causing TVM to compute batch statistics from the input tensor instead of using the provided running_mean and running_var.
Generated IR (incorrect):
lv: R.Tuple(...) = R.nn.batch_norm(X, scale, bias, mean, var, axis=1, epsilon=1e-05, training=True)
^^^^^^^^^^^^
Should be training=False!
Environment
- TVM version: 0.23.dev0 (commit hash if available)
- OS: Ubuntu Linux
- Target: llvm (CPU)
- Python: 3.11
- ONNX opset: 15
Steps to reproduce
Minimal reproduction script
Save as reproduce_bn_bug.py and run with python reproduce_bn_bug.py:
#!/usr/bin/env python3
"""
TVM BatchNormalization training_mode Bug - Minimal Reproduction
"""
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort
import tvm
from tvm.relax.frontend.onnx import from_onnx
from tvm import relax
def create_minimal_bn_model():
    """Build a one-node ONNX graph: BatchNormalization with training_mode=0.

    Returns a checked ModelProto (opset 15, IR version 8) whose scale, bias,
    running mean, and running variance are supplied as graph initializers.
    """
    n, c, h, w = 2, 3, 4, 4
    eps = 1e-5
    x_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, [n, c, h, w])
    y_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [n, c, h, w])
    # Per-channel (c == 3) parameters, stored as graph initializers.
    initializers = [
        numpy_helper.from_array(np.array([1.0, 2.0, 0.5], dtype=np.float32), name='scale'),
        numpy_helper.from_array(np.array([0.0, 1.0, -1.0], dtype=np.float32), name='bias'),
        numpy_helper.from_array(np.array([0.5, 1.0, 2.0], dtype=np.float32), name='mean'),
        numpy_helper.from_array(np.array([0.25, 1.0, 4.0], dtype=np.float32), name='var'),
    ]
    node = helper.make_node(
        'BatchNormalization',
        inputs=['X', 'scale', 'bias', 'mean', 'var'],
        outputs=['Y'],
        epsilon=eps,
        momentum=0.9,
        training_mode=0,  # KEY: inference mode!
    )
    graph = helper.make_graph([node], 'bn_test', [x_info], [y_info], initializers)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 15)])
    model.ir_version = 8
    onnx.checker.check_model(model)
    return model
def main():
    """Run the repro: compare ONNX Runtime (reference) against TVM Relax.

    Returns 1 when the outputs diverge beyond tolerance (bug confirmed),
    0 otherwise.
    """
    print("Creating ONNX model with BatchNormalization (training_mode=0)...")
    model = create_minimal_bn_model()

    # Confirm the attribute actually made it into the serialized graph.
    for node in model.graph.node:
        if node.op_type != 'BatchNormalization':
            continue
        mode = next((attr.i for attr in node.attribute if attr.name == 'training_mode'), 0)
        print(f"ONNX training_mode = {mode}")

    # Deterministic test input.
    np.random.seed(42)
    input_data = np.random.randn(2, 3, 4, 4).astype(np.float32)

    # Reference execution with ONNX Runtime.
    sess = ort.InferenceSession(model.SerializeToString(), providers=['CPUExecutionProvider'])
    ort_output = sess.run(None, {'X': input_data})[0]
    print(f"ORT output sample: {ort_output[0, 0, 0, :3]}")

    # Import into TVM Relax and inspect the printed IR for the bad attribute.
    mod = from_onnx(model, shape_dict={'X': list(input_data.shape)})
    ir_text = mod.script()
    if 'training=True' in ir_text:
        print("\n[BUG] TVM IR contains training=True (should be False)!")
        for line in ir_text.split('\n'):
            if 'batch_norm' in line:
                print(f" {line.strip()}")

    # Compile for CPU and execute through the Relax VM.
    ex = tvm.compile(mod, tvm.target.Target("llvm"))
    device = tvm.cpu()
    vm = relax.VirtualMachine(ex, device)
    tvm_input = tvm.runtime.tensor(input_data, device=device)
    tvm_output = vm['main'](tvm_input).numpy()
    print(f"TVM output sample: {tvm_output[0, 0, 0, :3]}")

    # Numerical comparison against the reference.
    max_diff = np.max(np.abs(ort_output - tvm_output))
    print(f"\nMax difference (ORT vs TVM): {max_diff:.6f}")
    if max_diff > 0.001:
        print("\n[BUG CONFIRMED] TVM produces incorrect results!")
        return 1
    return 0
if __name__ == '__main__':
    exit(main())

Expected output
Creating ONNX model with BatchNormalization (training_mode=0)...
ONNX training_mode = 0
ORT output sample: [-0.00657159 -1.2765031 0.29537117]
[BUG] TVM IR contains training=True (should be False)!
lv: R.Tuple(...) = R.nn.batch_norm(X, ..., training=True)
TVM output sample: [ 0.66183543 -0.05753053 0.8328741 ]
Max difference (ORT vs TVM): 2.758021
[BUG CONFIRMED] TVM produces incorrect results!
Triage
- needs-triage
Metadata
Metadata
Assignees
Labels
needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address them; type: bug