-
Notifications
You must be signed in to change notification settings - Fork 772
Failure after converting HLO to StableHLO: INTERNAL: during context [end-of-post-layout_assignment] #39835
Description
Bug Report:
This bug is triggered by the HLO-to-StableHLO conversion: the HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds, but running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:
INTERNAL: during context [end-of-post-layout_assignment]: Operand %tuple.5 = (s32[], bf16[...]{2,1,0:T(8,128)(2,1)}, bf16[...]{2,1,0}) tuple(%constant.10, %Arg_0.5, %broadcast.1), metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12} shape does not match parameter's %arg_tuple.1 = (s32[], bf16[...]{2,1,0}, bf16[...]{2,1,0}) parameter(0) in %while.1 = (s32[], bf16[...]{2,1,0}, bf16[...]{2,1,0}) while(%tuple.5), condition=%region_1.4, body=%region_0.3, metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12}, backend_config={"known_trip_count":{"n":"16"},"known_init_step":{"init":"0","step":"1"},"known_induction_variable":{"tuple_index":"0"},"dynamic_variable_tuple_indices":[]}
Environment
- XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
- CPU: Intel(R) Core(TM) i9-14900HX
- GPU: NVIDIA GeForce RTX 4060 Laptop GPU
- CUDA Driver: 580.126.09
run_hlo_module (HLO) — Success
HLO:
HloModule jit__prefill_impl, entry_computation_layout={(bf16[2,16,16]{2,1,0:T(8,128)(2,1)S(5)})->bf16[2,16,16]{2,1,0:T(8,128)(2,1)}}
while_body {
input_tuple.0 = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) parameter(0)
current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0
constant_1 = s32[] constant(1)
incremented_index.0 = s32[] add(current_iteration_index.0, constant_1)
orig_data = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} get-tuple-element(input_tuple.0), index=1
custom-call.0 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} custom-call(orig_data), custom_call_target="MoveToDevice"
sum = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} get-tuple-element(input_tuple.0), index=2
sum.1 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} add(custom-call.0, sum)
ROOT tuple_result.0 = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) tuple(incremented_index.0, orig_data, sum.1)
}
while_condition {
condition_param = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) parameter(0)
condition_current_iteration_index = s32[] get-tuple-element(condition_param), index=0
condition_iteration_count = s32[] constant(16)
ROOT condition_result = pred[] compare(condition_current_iteration_index, condition_iteration_count), direction=LT
}
ENTRY main {
constant_0 = s32[] constant(0)
param.0 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} parameter(0)
constant_0.1 = bf16[] constant(0)
broadcast = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} broadcast(constant_0.1), dimensions={}
tuple_for_while = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) tuple(constant_0, param.0, broadcast)
while = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0:T(8,128)(2,1)}) while(tuple_for_while), condition=while_condition, body=while_body
ROOT gte = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} get-tuple-element(while), index=2
}
Execution Command:
run_hlo_module \
--platform=CUDA \
--reference_platform= \
--input_format=hlo \
jit__prefill_impl_1aa2b9ba_23.hlo
Output:
** Running jit__prefill_impl_1aa2b9ba_23.hlo**
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418326.762918 2115344 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0791785s.
Skipping reference runner
run_hlo_module (StableHLO) — FAIL
Translation Command:
hlo-translate \
--hlo-to-mlir \
jit__prefill_impl_1aa2b9ba_23.hlo \
-o \
jit__prefill_impl_1aa2b9ba_23.mlir
IR After Translation:
module @jit__prefill_impl attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false, mhlo.xla_entry_computation_parameter_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], mhlo.xla_entry_computation_parameter_tiles = [[dense<[8, 128]> : tensor<2xindex>, dense<[2, 1]> : tensor<2xindex>]], mhlo.xla_entry_computation_result_layout = [dense<[2, 1, 0]> : tensor<3xindex>], mhlo.xla_entry_computation_result_tiles = [[dense<[8, 128]> : tensor<2xindex>, dense<[2, 1]> : tensor<2xindex>]]} {
func.func private @while_body(%arg0: tensor<i32>, %arg1: tensor<2x16x16xbf16>, %arg2: tensor<2x16x16xbf16>) -> (tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>) {
%c = stablehlo.constant dense<1> : tensor<i32>
%0 = stablehlo.add %arg0, %c : tensor<i32>
%1 = stablehlo.custom_call @MoveToDevice(%arg1) {backend_config = "", result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
%2 = stablehlo.add %1, %arg2 {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : tensor<2x16x16xbf16>
return %0, %arg1, %2 : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
}
func.func private @while_condition(%arg0: tensor<i32>, %arg1: tensor<2x16x16xbf16>, %arg2: tensor<2x16x16xbf16>) -> tensor<i1> {
%c = stablehlo.constant dense<16> : tensor<i32>
%0 = stablehlo.compare LT, %arg0, %c : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func.func @main(%arg0: tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16> {
%c = stablehlo.constant dense<0> : tensor<i32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<bf16>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<bf16>) -> tensor<2x16x16xbf16>
%1:3 = stablehlo.while(%iterArg = %c, %iterArg_0 = %arg0, %iterArg_1 = %0) : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
cond {
%c_2 = stablehlo.constant dense<16> : tensor<i32>
%2 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %2 : tensor<i1>
} do {
%c_2 = stablehlo.constant dense<1> : tensor<i32>
%2 = stablehlo.add %iterArg, %c_2 : tensor<i32>
%3 = stablehlo.custom_call @MoveToDevice(%iterArg_0) {backend_config = "", result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : (tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
%4 = stablehlo.add %3, %iterArg_1 {result_layout = dense<[2, 1, 0]> : tensor<3xindex>, xla_shape = "bf16[2,16,16]{2,1,0:T(8,128)(2,1)}"} : tensor<2x16x16xbf16>
stablehlo.return %2, %iterArg_0, %4 : tensor<i32>, tensor<2x16x16xbf16>, tensor<2x16x16xbf16>
}
return %1#2 : tensor<2x16x16xbf16>
}
}
Execution Command:
run_hlo_module \
--platform=CUDA \
--reference_platform= \
--input_format=stablehlo \
jit__prefill_impl_1aa2b9ba_23.mlir
Output:
** Running jit__prefill_impl_1aa2b9ba_23.mlir**
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418327.177423 2115390 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0434891s.
INTERNAL: during context [end-of-post-layout_assignment]: Operand %tuple.5 = (s32[], bf16[2,16,16]{2,1,0:T(8,128)(2,1)}, bf16[2,16,16]{2,1,0}) tuple(%constant.10, %Arg_0.5, %broadcast.1), metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12} shape does not match parameter's %arg_tuple.1 = (s32[], bf16[2,16,16]{2,1,0}, bf16[2,16,16]{2,1,0}) parameter(0) in %while.1 = (s32[], bf16[2,16,16]{2,1,0}, bf16[2,16,16]{2,1,0}) while(%tuple.5), condition=%region_1.4, body=%region_0.3, metadata={source_file="-" source_line=18 source_end_line=18 source_column=12 source_end_column=12}, backend_config={"known_trip_count":{"n":"16"},"known_init_step":{"init":"0","step":"1"},"known_induction_variable":{"tuple_index":"0"},"dynamic_variable_tuple_indices":[]}
Failed to execute on CUDA
Contact
- Email:
ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com