
[Bug][Relax][ONNX] BatchNormalization ignores training_mode attribute, always uses training=True #18703

@Anemone220

Description

Expected behavior

When importing an ONNX model whose BatchNormalization operator has training_mode=0 (inference mode), the TVM Relax frontend should generate R.nn.batch_norm(..., training=False) and normalize with the provided running_mean and running_var inputs.

According to the ONNX BatchNormalization specification:

  • training_mode=0 (default): Use running statistics (inference mode)
  • training_mode=1: Compute batch statistics from input (training mode)
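The difference between the two modes can be sketched in plain NumPy. This is an illustration of the spec's formulas, not TVM code; the constants mirror the repro script below:

```python
# The two BatchNormalization modes from the ONNX spec, in plain NumPy.
# Conflating them changes the result whenever the stored running
# statistics differ from the current batch statistics.
import numpy as np

def bn_inference(x, scale, bias, running_mean, running_var, eps=1e-5):
    # training_mode=0: normalize with the stored running statistics
    shape = (1, x.shape[1], 1, 1)
    return (scale.reshape(shape) * (x - running_mean.reshape(shape))
            / np.sqrt(running_var.reshape(shape) + eps) + bias.reshape(shape))

def bn_training(x, scale, bias, eps=1e-5):
    # training_mode=1: normalize with the statistics of the current batch
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    shape = (1, x.shape[1], 1, 1)
    return scale.reshape(shape) * (x - mean) / np.sqrt(var + eps) + bias.reshape(shape)

x = np.random.default_rng(0).standard_normal((2, 3, 4, 4)).astype(np.float32)
scale = np.array([1.0, 2.0, 0.5], np.float32)
bias = np.array([0.0, 1.0, -1.0], np.float32)
mean = np.array([0.5, 1.0, 2.0], np.float32)
var = np.array([0.25, 1.0, 4.0], np.float32)

diff = np.abs(bn_inference(x, scale, bias, mean, var)
              - bn_training(x, scale, bias)).max()
print(f"max |inference - training| = {diff:.4f}")
```

With running statistics deliberately far from the batch statistics (as in the repro script), the two modes diverge by a large margin, matching the scale of the ORT-vs-TVM mismatch reported below.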

Actual behavior

The TVM Relax ONNX frontend ignores the training_mode attribute and always generates R.nn.batch_norm(..., training=True), so TVM computes batch statistics from the input tensor instead of using the provided running_mean and running_var.

Generated IR (incorrect):

lv: R.Tuple(...) = R.nn.batch_norm(X, scale, bias, mean, var, axis=1, epsilon=1e-05, training=True)
                                                                                      ^^^^^^^^^^^^
                                                                            Should be training=False!
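A likely fix is for the converter to read the attribute instead of hard-coding the flag. A rough pseudocode sketch; the attribute access and call shape are assumed from the generated IR above, not verified against the frontend source:

```
# Pseudocode sketch of the frontend change (names assumed, not verified)
training_mode = attr.get("training_mode", 0)   # ONNX default: 0 = inference
R.nn.batch_norm(X, scale, bias, mean, var,
                axis=1, epsilon=epsilon,
                training=bool(training_mode))  # was hard-coded to True
```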

Environment

  • TVM version: 0.23.dev0
  • OS: Ubuntu Linux
  • Target: llvm (CPU)
  • Python: 3.11
  • ONNX opset: 15

Steps to reproduce

Minimal reproduction script

Save as reproduce_bn_bug.py and run with python reproduce_bn_bug.py:

#!/usr/bin/env python3
"""
TVM BatchNormalization training_mode Bug - Minimal Reproduction
"""

import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort
import tvm
from tvm.relax.frontend.onnx import from_onnx
from tvm import relax


def create_minimal_bn_model():
    """Create a minimal ONNX model with only BatchNormalization (training_mode=0)."""
    batch, channels, height, width = 2, 3, 4, 4
    epsilon = 1e-5

    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [batch, channels, height, width])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [batch, channels, height, width])

    scale = numpy_helper.from_array(np.array([1.0, 2.0, 0.5], dtype=np.float32), name='scale')
    bias = numpy_helper.from_array(np.array([0.0, 1.0, -1.0], dtype=np.float32), name='bias')
    mean = numpy_helper.from_array(np.array([0.5, 1.0, 2.0], dtype=np.float32), name='mean')
    var = numpy_helper.from_array(np.array([0.25, 1.0, 4.0], dtype=np.float32), name='var')

    bn_node = helper.make_node(
        'BatchNormalization',
        inputs=['X', 'scale', 'bias', 'mean', 'var'],
        outputs=['Y'],
        epsilon=epsilon,
        momentum=0.9,
        training_mode=0  # KEY: inference mode!
    )

    graph = helper.make_graph([bn_node], 'bn_test', [X], [Y], [scale, bias, mean, var])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 15)])
    model.ir_version = 8
    onnx.checker.check_model(model)
    return model


def main():
    print("Creating ONNX model with BatchNormalization (training_mode=0)...")
    model = create_minimal_bn_model()

    # Verify ONNX attribute
    for node in model.graph.node:
        if node.op_type == 'BatchNormalization':
            training_mode = next((a.i for a in node.attribute if a.name == 'training_mode'), 0)
            print(f"ONNX training_mode = {training_mode}")

    # Test input
    np.random.seed(42)
    input_data = np.random.randn(2, 3, 4, 4).astype(np.float32)

    # ONNX Runtime (reference)
    model_bytes = model.SerializeToString()
    sess = ort.InferenceSession(model_bytes, providers=['CPUExecutionProvider'])
    ort_output = sess.run(None, {'X': input_data})[0]
    print(f"ORT output sample: {ort_output[0, 0, 0, :3]}")

    # TVM Relax
    shape_dict = {'X': list(input_data.shape)}
    mod = from_onnx(model, shape_dict=shape_dict)

    # Check IR for the bug
    ir_text = mod.script()
    if 'training=True' in ir_text:
        print("\n[BUG] TVM IR contains training=True (should be False)!")
        for line in ir_text.split('\n'):
            if 'batch_norm' in line:
                print(f"  {line.strip()}")

    # Compile and run
    target = tvm.target.Target("llvm")
    ex = tvm.compile(mod, target)
    device = tvm.cpu()
    vm = relax.VirtualMachine(ex, device)
    tvm_input = tvm.runtime.tensor(input_data, device=device)
    tvm_output = vm['main'](tvm_input).numpy()
    print(f"TVM output sample: {tvm_output[0, 0, 0, :3]}")

    # Compare
    max_diff = np.max(np.abs(ort_output - tvm_output))
    print(f"\nMax difference (ORT vs TVM): {max_diff:.6f}")

    if max_diff > 0.001:
        print("\n[BUG CONFIRMED] TVM produces incorrect results!")
        return 1
    return 0


if __name__ == '__main__':
    exit(main())

Observed output (demonstrating the bug)

Creating ONNX model with BatchNormalization (training_mode=0)...
ONNX training_mode = 0
ORT output sample: [-0.00657159 -1.2765031   0.29537117]

[BUG] TVM IR contains training=True (should be False)!
  lv: R.Tuple(...) = R.nn.batch_norm(X, ..., training=True)
TVM output sample: [ 0.66183543 -0.05753053  0.8328741 ]

Max difference (ORT vs TVM): 2.758021

[BUG CONFIRMED] TVM produces incorrect results!

Triage

  • needs-triage

cc @KJlaccHoeUM9l @junrushao
