-
Notifications
You must be signed in to change notification settings - Fork 772
Failure after converting HLO to StableHLO: error: 'builtin.module' op -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape #39828
Description
Bug Report:
This bug is triggered by the HLO-to-StableHLO conversion: the HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds, but running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:
error: 'builtin.module' op -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (arg_tuple.2: (s32[...], s32[...], s32[], pred[])) -> pred[]; body: (arg_tuple: (s32[...], s32[...], s32[], pred[])) -> ((s32[...]), s32[...], s32[], pred[]); init: (s32[...], s32[...], s32[], pred[])..
Environment
- XLA commit:
5ce7908a2d32a9f91fd99380435cda1b645c8cc7 - CPU:
Intel(R) Core(TM) i9-14900HX - GPU:
NVIDIA GeForce RTX 4060 Laptop GPU - CUDA Driver:
580.126.09
run_hlo_module (HLO) — Success
HLO:
HloModule module, entry_computation_layout={(s32[256]{0}, s32[], pred[])->s32[1024]{0}}
body {
input_tuple.1 = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) parameter(0)
input.1 = s32[1024]{0} get-tuple-element(input_tuple.1), index=0
input.2 = s32[256]{0} get-tuple-element(input_tuple.1), index=1
input.3 = s32[] get-tuple-element(input_tuple.1), index=2
async-start = ((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) dynamic-update-slice-start(input.1, input.2, input.3)
async-done = s32[1024]{0} dynamic-update-slice-done(async-start)
input.4 = pred[] get-tuple-element(input_tuple.1), index=3
ROOT tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) tuple(async-done, input.2, input.3, input.4)
}
condition {
input_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) parameter(0)
ROOT cond = pred[] get-tuple-element(input_tuple), index=3
}
ENTRY main {
input.5 = s32[] parameter(1)
broadcast = s32[1024]{0} broadcast(input.5), dimensions={}
input.0 = s32[256]{0} parameter(0)
input.6 = pred[] parameter(2)
while_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) tuple(broadcast, input.0, input.5, input.6)
while = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) while(while_tuple), condition=condition, body=body
ROOT gte = s32[1024]{0} get-tuple-element(while), index=0
}
Execution Command:
run_hlo_module \
--platform=CUDA \
--reference_platform= \
--input_format=hlo \
module_9d759a35_34.hlo

Output:
** Running module_9d759a35_34.hlo**
Running HLO module with runner CUDA...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1774418321.819708 2114360 cuda_dnn.cc:461] Loaded cuDNN version 91900
... compiled and ran in 0.0697718s.
Skipping reference runner
run_hlo_module (StableHLO) — FAIL
Translation Command:
hlo-translate \
--hlo-to-mlir \
module_9d759a35_34.hlo \
-o \
module_9d759a35_34.mlir

IR After Translation:
module @module attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func private @async_wrapped(%arg0: tensor<1024xi32>, %arg1: tensor<256xi32>, %arg2: tensor<i32>) -> tensor<1024xi32> attributes {execution_thread = "main"} {
%0 = stablehlo.dynamic_update_slice %arg0, %arg1, %arg2 : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> tensor<1024xi32>
return %0 : tensor<1024xi32>
}
func.func private @body(%arg0: tensor<1024xi32>, %arg1: tensor<256xi32>, %arg2: tensor<i32>, %arg3: tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) {
%0 = "mhlo.async_start"(%arg0, %arg1, %arg2) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
%1 = "mhlo.async_done"(%0) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
return %1, %arg1, %arg2, %arg3 : tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>
}
func.func private @condition(%arg0: tensor<1024xi32>, %arg1: tensor<256xi32>, %arg2: tensor<i32>, %arg3: tensor<i1>) -> tensor<i1> {
return %arg3 : tensor<i1>
}
func.func @main(%arg0: tensor<256xi32>, %arg1: tensor<i32>, %arg2: tensor<i1>) -> tensor<1024xi32> {
%0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<1024xi32>
%1:4 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %arg1, %iterArg_2 = %arg2) : tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>
cond {
stablehlo.return %iterArg_2 : tensor<i1>
} do {
%2 = "mhlo.async_start"(%iterArg, %iterArg_0, %iterArg_1) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
%3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
stablehlo.return %3, %iterArg_0, %iterArg_1, %iterArg_2 : tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>
}
return %1#0 : tensor<1024xi32>
}
}

Execution Command:
run_hlo_module \
--platform=CUDA \
--reference_platform= \
--input_format=stablehlo \
module_9d759a35_34.mlir

Output:
loc("-":1:1): error: 'builtin.module' op -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (arg_tuple.2: (s32[1024], s32[256], s32[], pred[])) -> pred[]; body: (arg_tuple: (s32[1024], s32[256], s32[], pred[])) -> ((s32[1024]), s32[256], s32[], pred[]); init: (s32[1024], s32[256], s32[], pred[])..
-:16:12: note: see current operation:
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
"stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
%2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
%3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
"stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)
-:16:12: error: 'stablehlo.while' op can't be translated to XLA HLO
-:16:12: note: see current operation:
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
"stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
%2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
%3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
"stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1774418322.175888 2114407 translate.cc:190] Conversion to HLO module failed: UNKNOWN: -:16:12: error: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (arg_tuple.2: (s32[1024], s32[256], s32[], pred[])) -> pred[]; body: (arg_tuple: (s32[1024], s32[256], s32[], pred[])) -> ((s32[1024]), s32[256], s32[], pred[]); init: (s32[1024], s32[256], s32[], pred[])..
-:16:12: note: see current operation:
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
"stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
%2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
%3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
"stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)
-:16:12: error: 'stablehlo.while' op can't be translated to XLA HLO
-:16:12: note: see current operation:
%1:4 = "stablehlo.while"(%0, %arg0, %arg1, %arg2) ({
^bb0(%arg7: tensor<1024xi32>, %arg8: tensor<256xi32>, %arg9: tensor<i32>, %arg10: tensor<i1>):
"stablehlo.return"(%arg10) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<1024xi32>, %arg4: tensor<256xi32>, %arg5: tensor<i32>, %arg6: tensor<i1>):
%2 = "mhlo.async_start"(%arg3, %arg4, %arg5) <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[])"} : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>) -> !mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>
%3 = "mhlo.async_done"(%2) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<tensor<1024xi32>, tensor<256xi32>, tensor<i32>>, tensor<1024xi32>, tensor<ui32>>) -> tensor<1024xi32>
"stablehlo.return"(%3, %arg4, %arg5, %arg6) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> ()
}) : (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>) -> (tensor<1024xi32>, tensor<256xi32>, tensor<i32>, tensor<i1>)
F0000 00:00:1774418322.176375 2114407 hlo_module_loader.cc:117] Failed to translate input stablehlo program to HLO text
Contact
- Email:
ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com