
Failure after converting HLO to StableHLO: INTERNAL: during context [end-of-post-layout_assignment] #39835

@housrepository


Bug Report:

This bug is triggered by the HLO-to-StableHLO conversion. The HLO module runs successfully with run_hlo_module --input_format=hlo, and converting it with hlo-translate --hlo-to-mlir also succeeds. However, running the translated StableHLO with run_hlo_module --input_format=stablehlo fails with:

INTERNAL: during context [end-of-post-layout_assignment]: Operand %tuple.5 = (s32[], bf16[...]{2,1,0:T(8,128)(2,1)}, bf16[...]{2,1,0}) tuple(%constant.10, %Arg_0.5, %broadcast.1), metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12} shape does not match parameter's %arg_tuple.1 = (s32[], bf16[...]{2,1,0}, bf16[...]{2,1,0}) parameter(0) in %while.1 = (s32[], bf16[...]{2,1,0}, bf16[...]{2,1,0}) while(%tuple.5), condition=%region_1.4, body=%region_0.3, metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12}, backend_config={"known_trip_count":{"n":"16"},"known_init_step":{"init":"0","step":"1"},"known_induction_variable":{"tuple_index":"0"},"dynamic_variable_tuple_indices":[]}

Environment

  • XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
  • CPU: Intel(R) Core(TM) i9-14900HX
  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • CUDA Driver: 580.126.09

run_hlo_module (HLO) — Success

HLO:

HloModule jit__prefill_impl, entry_computation_layout={(bf16[2,16,16]{2,1,0:T(8,128)(2,1)S(5)})->bf16[2,16,16]{2,1,0:T(8,128)(2,1)}}

while_body {
  input_tuple.0 = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) parameter(0)
  current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0
  constant_1 = s32[] constant(1)
  incremented_index.0 = s32[] add(current_iteration_index.0, constant_1)
  orig_data = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} get-tuple-element(input_tuple.0), index=1
  custom-call.0 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} custom-call(orig_data), custom_call_target="MoveToDevice"
  sum = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} get-tuple-element(input_tuple.0), index=2
  sum.1 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} add(custom-call.0, sum)
  ROOT tuple_result.0 = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) tuple(incremented_index.0, orig_data, sum.1)
}

while_condition {
  condition_param = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) parameter(0)
  condition_current_iteration_index = s32[] get-tuple-element(condition_param), index=0
  condition_iteration_count = s32[] constant(16)
  ROOT condition_result = pred[] compare(condition_current_iteration_index, condition_iteration_count), direction=LT
}

ENTRY main {
  constant_0 = s32[] constant(0)
  param.0 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} parameter(0)
  constant_0.1 = bf16[] constant(0)
  broadcast = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} broadcast(constant_0.1), dimensions={}
  tuple_for_while = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) tuple(constant_0, param.0, broadcast)
  while = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) while(tuple_for_while), condition=while_condition, body=while_body
  ROOT gte = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} get-tuple-element(while), index=2
}
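Every array in this module uses XLA's HLO layout notation. For example, the entry parameter is bf16[2,16,16]{2,1,0:T(8,128)(2,1)S(5)}. As a reading aid, here is a minimal Python sketch that splits such a layout string into its minor-to-major order, tile shapes, and memory space. It assumes, following XLA's layout printer, that T(...)(...) lists the tiles and S(n) gives the memory space. Any other layout fields are ignored. parse_layout and HloLayout are hypothetical helpers, not an XLA API.

```python
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class HloLayout:
    minor_to_major: List[int]
    tiles: List[Tuple[int, ...]] = field(default_factory=list)
    memory_space: Optional[int] = None


def parse_layout(text: str) -> HloLayout:
    """Parse a layout such as '{2,1,0:T(8,128)(2,1)S(5)}' (hypothetical helper).

    Only the minor-to-major list, T(...) tiles and S(n) memory space are
    handled; any other layout fields are ignored by this sketch.
    """
    body = text.strip()
    if not (body.startswith("{") and body.endswith("}")):
        raise ValueError(f"not a layout: {text!r}")
    dims, _, extras = body[1:-1].partition(":")
    minor_to_major = [int(d) for d in dims.split(",") if d]

    # Tiles follow 'T' as a run of parenthesized groups, e.g. T(8,128)(2,1).
    tiles: List[Tuple[int, ...]] = []
    tile_run = re.search(r"T((?:\([0-9,]+\))+)", extras)
    if tile_run:
        for group in re.findall(r"\(([0-9,]+)\)", tile_run.group(1)):
            tiles.append(tuple(int(x) for x in group.split(",")))

    # Memory space is printed as S(n).
    space = re.search(r"S\((\d+)\)", extras)
    memory_space = int(space.group(1)) if space else None
    return HloLayout(minor_to_major, tiles, memory_space)


if __name__ == "__main__":
    # Entry parameter layout from entry_computation_layout above.
    print(parse_layout("{2,1,0:T(8,128)(2,1)S(5)}"))
    # Untiled layout, as reported for the while-body parameter in the error.
    print(parse_layout("{2,1,0}"))
```

Under these assumptions, the layouts in the error message differ only in whether the tiles list is empty. The minor-to-major order is 2,1,0 throughout.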

Execution Command:

run_hlo_module \
  --platform=CUDA \
  --reference_platform= \
  --input_format=hlo \
  jit__prefill_impl_1aa2b9ba_23.hlo

Output:


 ** Running jit__prefill_impl_1aa2b9ba_23.hlo** 
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418326.762918 2115344 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0791785s.
Skipping reference runner

run_hlo_module (StableHLO) — FAIL

Translation Command:

hlo-translate \
  --hlo-to-mlir \
  jit__prefill_impl_1aa2b9ba_23.hlo \
  -o \
  jit__prefill_impl_1aa2b9ba_23.mlir

IR After Translation:

module @jit__prefill_impl attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false, mhlo.xla_entry_computation_parameter_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], mhlo.xla_entry_computation_parameter_tiles = [[dense<[8, 128]> : tensor<2xindex>, dense<[2, 1]> : tensor<2xindex>]], mhlo.xla_entry_computation_result_layout = [dense<[2, 1, 0]> : tensor<3xindex>], mhlo.xla_entry_computation_result_tiles = [[dense<[8, 128]> : tensor<2xindex>, dense<[2, 1]> : tensor<2xindex>]]} {
  func.func private @while_body(%arg0: tensor<i32>, %arg1: tensor<2x16x16xbf16>, %arg2: tensor<2x16x16xbf16>) -> (tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.add %arg0, %c : tensor<i32>
    %1 = stablehlo.custom_call @MoveToDevice(%arg1) {backend_config = "", result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
    %2 = stablehlo.add %1, %arg2 {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : tensor<2x16x16xbf16>
    return %0, %arg1, %2 : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
  }
  func.func private @while_condition(%arg0: tensor<i32>, %arg1: tensor<2x16x16xbf16>, %arg2: tensor<2x16x16xbf16>) -> tensor<i1> {
    %c = stablehlo.constant dense<16> : tensor<i32>
    %0 = stablehlo.compare  LT, %arg0, %c : (tensor<i32>, tensor<i32>) -> tensor<i1>
    return %0 : tensor<i1>
  }
  func.func @main(%arg0: tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16> {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<bf16>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<bf16>) -> tensor<2x16x16xbf16>
    %1:3 = stablehlo.while(%iterArg = %c, %iterArg_0 = %arg0, %iterArg_1 = %0) : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
    cond {
      %c_2 = stablehlo.constant dense<16> : tensor<i32>
      %2 = stablehlo.compare  LT, %iterArg, %c_2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %2 : tensor<i1>
    } do {
      %c_2 = stablehlo.constant dense<1> : tensor<i32>
      %2 = stablehlo.add %iterArg, %c_2 : tensor<i32>
      %3 = stablehlo.custom_call @MoveToDevice(%iterArg_0) {backend_config = "", result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
      %4 = stablehlo.add %3, %iterArg_1 {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : tensor<2x16x16xbf16>
      stablehlo.return %2, %iterArg_0, %4 : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
    }
    return %1#2 : tensor<2x16x16xbf16>
  }
}
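The translated module still records the tiled layout in two places: the module-level mhlo.xla_entry_computation_parameter_tiles for the entry argument, and per-op xla_shape attributes. The loop-carried while arguments (%iterArg_0, %iterArg_1) carry no layout annotation at all. The Python sketch below uses two hypothetical helpers, find_xla_shapes and has_tiling, which simply regex-scan one-op-per-line text. It lists which results in selected lines of @main carry a tiled xla_shape. Note that %0 (the broadcast) is annotated as tiled here, yet %broadcast.1 appears as bf16[2,16,16]{2,1,0} in the error. Whether the missing annotations on the loop arguments explain the mismatch is an assumption for triage, not something this report establishes.

```python
import re
from typing import Dict

# Selected lines copied from @main in the translated module above
# (leading whitespace dropped, one op per line).
MAIN_EXCERPT = """\
%0 = stablehlo.broadcast_in_dim %cst, dims = [] {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<bf16>) -> tensor<2x16x16xbf16>
%1:3 = stablehlo.while(%iterArg = %c, %iterArg_0 = %arg0, %iterArg_1 = %0) : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
%3 = stablehlo.custom_call @MoveToDevice(%iterArg_0) {backend_config = "", result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
%4 = stablehlo.add %3, %iterArg_1 {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : tensor<2x16x16xbf16>
stablehlo.return %2, %iterArg_0, %4 : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
"""

# Capture the SSA result name at the start of a line and its xla_shape string.
_XLA_SHAPE = re.compile(r'^(%[\w:]+)\s*=.*xla_shape = "([^"]+)"', re.MULTILINE)


def find_xla_shapes(mlir_text: str) -> Dict[str, str]:
    """Map each SSA result that carries an xla_shape attribute to that shape."""
    return {name: shape for name, shape in _XLA_SHAPE.findall(mlir_text)}


def has_tiling(xla_shape: str) -> bool:
    """True if the layout part of an HLO shape string contains a T(...) tile."""
    return ":T(" in xla_shape


if __name__ == "__main__":
    for name, shape in find_xla_shapes(MAIN_EXCERPT).items():
        print(name, shape, "tiled" if has_tiling(shape) else "untiled")
```

Only %0, %3 and %4 are reported. The while results and the iteration arguments never appear, because nothing in the text attaches a layout to them.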

Execution Command:

run_hlo_module \
  --platform=CUDA \
  --reference_platform= \
  --input_format=stablehlo \
  jit__prefill_impl_1aa2b9ba_23.mlir

Output:


 ** Running jit__prefill_impl_1aa2b9ba_23.mlir** 
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418327.177423 2115390 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0434891s.
INTERNAL: during context [end-of-post-layout_assignment]: Operand %tuple.5 = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0}) tuple(%constant.10, %Arg_0.5, %broadcast.1), metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12} shape does not match parameter's %arg_tuple.1 = (s32[], bf16[2,16,16]{2,1,0}, bf16[2,16,16]{2,1,0}) parameter(0) in %while.1 = (s32[], bf16[2,16,16]{2,1,0}, bf16[2,16,16]{2,1,0}) while(%tuple.5), condition=%region_1.4, body=%region_0.3, metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12}, backend_config={"known_trip_count":{"n":"16"},"known_init_step":{"init":"0","step":"1"},"known_induction_variable":{"tuple_index":"0"},"dynamic_variable_tuple_indices":[]}
	Failed to execute on CUDA
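To pin down which tuple element trips the check, the operand and parameter shapes from the error message can be compared element by element. The following is a minimal Python sketch that splits an HLO tuple shape at top-level commas and reports the differing elements. split_tuple_shape and diff_tuple_shapes are hypothetical helpers, not XLA tooling. Applied to the message above, it shows that only element 1 (%Arg_0.5, the loop-carried input) differs, and only by the tile suffix :T(8,128)(2,1). Element 2 (%broadcast.1) has the plain {2,1,0} layout on both sides.

```python
from typing import List, Tuple

# Shapes copied from the INTERNAL error above.
OPERAND = "(s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0})"
PARAMETER = "(s32[], bf16[2,16,16]{2,1,0}, bf16[2,16,16]{2,1,0})"


def split_tuple_shape(shape: str) -> List[str]:
    """Split '(a, b, c)' into ['a', 'b', 'c'], ignoring commas nested in [], {}, ()."""
    s = shape.strip()
    if not (s.startswith("(") and s.endswith(")")):
        raise ValueError(f"not a tuple shape: {shape!r}")
    s = s[1:-1]
    parts: List[str] = []
    depth, start = 0, 0
    for i, ch in enumerate(s):
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(s[start:i].strip())
            start = i + 1
    parts.append(s[start:].strip())
    return parts


def diff_tuple_shapes(operand: str, parameter: str) -> List[Tuple[int, str, str]]:
    """Return (index, operand_elem, parameter_elem) for every mismatching element."""
    a, b = split_tuple_shape(operand), split_tuple_shape(parameter)
    if len(a) != len(b):
        raise ValueError("tuple arity differs")
    return [(i, x, y) for i, (x, y) in enumerate(zip(a, b)) if x != y]


if __name__ == "__main__":
    for index, got, expected in diff_tuple_shapes(OPERAND, PARAMETER):
        print(f"element {index}: operand {got} vs parameter {expected}")
```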

Contact

  • Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com

Labels: bug
