Failure after converting HLO to StableHLO: INTERNAL: during context [Unknown]: Mismatched tuple structure in original_value for instruction %Arg_0.3 = f32[...]{1,0} parameter(0), origin={({"t" {0}}, {"t" {1}}, {"t" {2}}, {"t" {3}})}. Leaf indices in shape and original_value do not match. #39827

@housrepository

Bug Report:

This bug is triggered by the HLO-to-StableHLO conversion. The HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds. However, running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:

INTERNAL: during context [Unknown]: Mismatched tuple structure in original_value for instruction %Arg_0.3 = f32[...]{1,0} parameter(0), origin={({"t" {0}}, {"t" {1}}, {"t" {2}}, {"t" {3}})}. Leaf indices in shape and original_value do not match.

Environment

  • XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
  • CPU: Intel(R) Core(TM) i9-14900HX
  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • CUDA Driver: 580.126.09

run_hlo_module (HLO) — Success

HLO:

HloModule UnusedTupleOperands, entry_computation_layout={(f32[20,40]{1,0}, f32[40,40]{1,0}, f32[20,40]{1,0}, f32[40,40]{1,0}, pred[])->(f32[20,40]{1,0})}

on_true {
  t.1 = (f32[20,40]{1,0}, f32[40,40]{1,0}, f32[20,40]{1,0}, f32[40,40]{1,0}) parameter(0)
  lhs.1 = f32[20,40]{1,0} get-tuple-element(t.1), index=2
  rhs.1 = f32[40,40]{1,0} get-tuple-element(t.1), index=3
  dot.1 = f32[20,40]{1,0} dot(lhs.1, rhs.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT result.1 = (f32[20,40]{1,0}) tuple(dot.1)
}

on_false {
  t = (f32[20,40]{1,0}, f32[40,40]{1,0}, f32[20,40]{1,0}, f32[40,40]{1,0}) parameter(0), origin={({"t" {0}}, {"t" {1}}, {"t" {2}}, {"t" {3}})}
  lhs = f32[20,40]{1,0} get-tuple-element(t), index=0
  rhs = f32[40,40]{1,0} get-tuple-element(t), index=1
  dot = f32[20,40]{1,0} dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT result = (f32[20,40]{1,0}) tuple(dot)
}

ENTRY main {
  c0_0 = f32[20,40]{1,0} parameter(0)
  c0_1 = f32[40,40]{1,0} parameter(1)
  c1_0 = f32[20,40]{1,0} parameter(2)
  c1_1 = f32[40,40]{1,0} parameter(3)
  t.2 = (f32[20,40]{1,0}, f32[40,40]{1,0}, f32[20,40]{1,0}, f32[40,40]{1,0}) tuple(c0_0, c0_1, c1_0, c1_1)
  call = (f32[20,40]{1,0}) call(t.2), to_apply=on_true
  p = pred[] parameter(4)
  ROOT result.2 = (f32[20,40]{1,0}) conditional(p, t.2, t.2), true_computation=on_true, false_computation=on_false
}

Execution Command:

run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=hlo \
  UnusedTupleOperands_b54cb2dd_10.hlo

Output:


 ** Running UnusedTupleOperands_b54cb2dd_10.hlo** 
Running HLO module with runner Host...
... compiled and ran in 0.00857111s.
Skipping reference runner

run_hlo_module (StableHLO) — FAIL

Translation Command:

hlo-translate \
  --hlo-to-mlir \
  UnusedTupleOperands_b54cb2dd_10.hlo \
  -o \
  UnusedTupleOperands_b54cb2dd_10.mlir

IR After Translation:

module @UnusedTupleOperands attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @on_true(%arg0: tensor<20x40xf32>, %arg1: tensor<40x40xf32>, %arg2: tensor<20x40xf32>, %arg3: tensor<40x40xf32>) -> tensor<20x40xf32> {
    %0 = stablehlo.dot %arg2, %arg3, precision = [DEFAULT, DEFAULT] : (tensor<20x40xf32>, tensor<40x40xf32>) -> tensor<20x40xf32>
    return %0 : tensor<20x40xf32>
  }
  func.func private @on_false(%arg0: tensor<20x40xf32> {mhlo.original_value = "{({\22t\22 {0}}, {\22t\22 {1}}, {\22t\22 {2}}, {\22t\22 {3}})}"}, %arg1: tensor<40x40xf32>, %arg2: tensor<20x40xf32>, %arg3: tensor<40x40xf32>) -> tensor<20x40xf32> {
    %0 = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] : (tensor<20x40xf32>, tensor<40x40xf32>) -> tensor<20x40xf32>
    return %0 : tensor<20x40xf32>
  }
  func.func @main(%arg0: tensor<20x40xf32>, %arg1: tensor<40x40xf32>, %arg2: tensor<20x40xf32>, %arg3: tensor<40x40xf32>, %arg4: tensor<i1>) -> tensor<20x40xf32> {
    %0 = call @on_true(%arg0, %arg1, %arg2, %arg3) : (tensor<20x40xf32>, tensor<40x40xf32>, tensor<20x40xf32>, tensor<40x40xf32>) -> tensor<20x40xf32>
    %1 = "stablehlo.if"(%arg4) ({
      %2 = stablehlo.dot %arg2, %arg3, precision = [DEFAULT, DEFAULT] : (tensor<20x40xf32>, tensor<40x40xf32>) -> tensor<20x40xf32>
      stablehlo.return %2 : tensor<20x40xf32>
    }, {
      %2 = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] : (tensor<20x40xf32>, tensor<40x40xf32>) -> tensor<20x40xf32>
      stablehlo.return %2 : tensor<20x40xf32>
    }) : (tensor<i1>) -> tensor<20x40xf32>
    return %1 : tensor<20x40xf32>
  }
}
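
Note that in the translated IR the tuple parameter of on_false has been flattened into four tensor arguments, yet %arg0 still carries the tuple-structured mhlo.original_value attribute. A rough way to spot this pattern in a .mlir file is sketched below. This is only a text heuristic based on the IR printed above, not an XLA or MLIR API; the helper name suspicious_args and both regular expressions are made up for illustration.

```python
import re

# Matches a tensor-typed function argument that carries mhlo.original_value.
ARG_WITH_ORIGIN = re.compile(
    r'(%\w+): (tensor<[^>]*>) \{mhlo\.original_value = "((?:[^"\\]|\\.)*)"\}'
)
# Matches one leaf entry such as {"t" {0}} once \22 is decoded to a quote.
LEAF_ENTRY = re.compile(r'\{"[^"]*" \{[\d, ]*\}\}')


def suspicious_args(mlir_text):
    """Return (argument, leaf_count) for tensor args with a tuple origin."""
    found = []
    for name, _type, raw in ARG_WITH_ORIGIN.findall(mlir_text):
        origin = raw.replace(r"\22", '"')  # MLIR prints '"' as \22
        if origin.startswith("{("):  # tuple-structured origin on a non-tuple arg
            found.append((name, len(LEAF_ENTRY.findall(origin))))
    return found


# Signature of @on_false, copied from the translated IR above.
signature = r'''func.func private @on_false(%arg0: tensor<20x40xf32> {mhlo.original_value = "{({\22t\22 {0}}, {\22t\22 {1}}, {\22t\22 {2}}, {\22t\22 {3}})}"}, %arg1: tensor<40x40xf32>, %arg2: tensor<20x40xf32>, %arg3: tensor<40x40xf32>) -> tensor<20x40xf32> {'''
print(suspicious_args(signature))
```

For the signature above this reports %arg0 with four origin leaves, even though the argument itself is a single tensor.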

Execution Command:

run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=stablehlo \
  UnusedTupleOperands_b54cb2dd_10.mlir

Output:


 ** Running UnusedTupleOperands_b54cb2dd_10.mlir** 
INTERNAL: during context [Unknown]: Mismatched tuple structure in original_value for instruction %Arg_0.3 = f32[20,40]{1,0} parameter(0), origin={({"t" {0}}, {"t" {1}}, {"t" {2}}, {"t" {3}})}. Leaf indices in shape and original_value do not match.
In shape only: {{}}
In original_value only: {{0}, {1}, {2}, {3}}
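
The two index sets in the error can be reproduced by hand. The sketch below is not XLA's verifier code; it assumes a toy model in which a tuple shape is a nested Python tuple and an array shape is a string, and it compares the leaf indices of the flattened parameter with the indices listed in the origin attribute.

```python
def leaf_indices(shape, prefix=()):
    """Yield the index path of every leaf in a nested tuple shape.

    A non-tuple shape (modeled here as a string) is a single leaf at the
    empty index path, which XLA prints as {}.
    """
    if isinstance(shape, tuple):
        for i, sub in enumerate(shape):
            yield from leaf_indices(sub, prefix + (i,))
    else:
        yield prefix


# Shape of %Arg_0.3 after the round trip: a plain array, not a tuple.
flattened_param = "f32[20,40]"
# Shape of the original tuple parameter `t` in on_false.
original_param = ("f32[20,40]", "f32[40,40]", "f32[20,40]", "f32[40,40]")
# Leaf indices recorded in origin={({"t" {0}}, ..., {"t" {3}})}.
origin_leaves = {(0,), (1,), (2,), (3,)}

shape_leaves = set(leaf_indices(flattened_param))
in_shape_only = shape_leaves - origin_leaves
in_origin_only = origin_leaves - shape_leaves
print("In shape only:", sorted(in_shape_only))
print("In original_value only:", sorted(in_origin_only))
```

In this model the origin matches the original tuple parameter exactly, and only mismatches once the parameter is a single array. That is consistent with the attribute having been written for the tuple-shaped parameter t and then kept on the first flattened argument.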

Contact

  • Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com
