Failure after converting HLO to StableHLO: error: 'builtin.module' op -:12:5: error: Dynamic input dimension to reshape that is both splitted and combined is not supported #39829
Status: Open
Labels: bug
Bug Report:
This bug is triggered by the HLO-to-StableHLO conversion: the HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds, but running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:
error: 'builtin.module' op -:12:5: error: Dynamic input dimension to reshape that is both splitted and combined is not supported: output: s32[...], input: s32[2,<=4,4], input_dim: 1:
Environment
- XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
- CPU: Intel(R) Core(TM) i9-14900HX
- GPU: NVIDIA GeForce RTX 4060 Laptop GPU
- CUDA Driver: 580.126.09
run_hlo_module (HLO) — Success
HLO:
```
HloModule TensorFlowScatterV1, entry_computation_layout={(s32[2,4,4]{2,1,0})->s32[]}

update_s32 {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[2,4,4]{2,1,0} parameter(0)
  two = s32[] constant(2)
  param_padded_dynamic = s32[2,<=4,4]{2,1,0} set-dimension-size(param, two), dimensions={1}
  reshaped = s32[<=16,2]{1,0} reshape(param_padded_dynamic), inferred_dimension=0
  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(reshaped, init), dimensions={0,1}, to_apply=update_s32
}
```
Execution Command:
```
run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=hlo \
  TensorFlowScatterV1_60e32416_7.hlo
```
Output:
```
** Running TensorFlowScatterV1_60e32416_7.hlo**
Running HLO module with runner Host...
... compiled and ran in 0.0316903s.
Skipping reference runner
```
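To make the expected behavior concrete, here is a minimal pure-Python sketch of what the HLO module should compute, assuming XLA's usual dynamic-shape semantics: `set-dimension-size` marks only the first 2 of the 4 entries along dimension 1 as valid, the reshape rearranges only those 16 valid elements (so `inferred_dimension=0` resolves to 8 at runtime, with a static bound of 16), and the full `reduce` is a sum, which does not depend on the order the reshape produces. The helper names `inferred_dim` and `reference_sum` and the sample input are hypothetical; the sample is not the data `run_hlo_module` actually fed to the module.

```python
from math import prod


def inferred_dim(input_dims, known_output_dims):
    """Size of the inferred reshape output dimension: element count / known dims."""
    total, known = prod(input_dims), prod(known_output_dims)
    assert total % known == 0, "element counts must divide evenly"
    return total // known


def reference_sum(param, valid_dim1):
    """Sum of the valid elements param[:, :valid_dim1, :] of a [2][4][4] nested list."""
    return sum(
        param[i][j][k]
        for i in range(len(param))
        for j in range(valid_dim1)
        for k in range(len(param[i][j]))
    )


# Hypothetical sample input: param[i][j][k] = 16*i + 4*j + k, shape [2, 4, 4].
param = [[[16 * i + 4 * j + k for k in range(4)] for j in range(4)] for i in range(2)]

print(inferred_dim([2, 2, 4], [2]))  # runtime size of output dim 0: 8
print(inferred_dim([2, 4, 4], [2]))  # static bound of output dim 0: 16
print(reference_sum(param, valid_dim1=2))  # 184
```

Because the reduction covers every dimension, any correct lowering of this module should agree with `reference_sum` on the same input, whichever element order it uses for the reshape.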
run_hlo_module (StableHLO) — FAIL
Translation Command:
```
hlo-translate \
  --hlo-to-mlir \
  TensorFlowScatterV1_60e32416_7.hlo \
  -o \
  TensorFlowScatterV1_60e32416_7.mlir
```
IR After Translation:
```
module @TensorFlowScatterV1 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @update_s32(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
  func.func @main(%arg0: tensor<2x4x4xi32>) -> tensor<i32> {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.set_dimension_size %arg0, %c, dim = 1 : (tensor<2x4x4xi32>, tensor<i32>) -> tensor<2x?x4xi32, #stablehlo.bounds<?, 4, ?>>
    %1 = stablehlo.reshape %0 : (tensor<2x?x4xi32, #stablehlo.bounds<?, 4, ?>>) -> tensor<?x2xi32, #stablehlo.bounds<16, ?>>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %2 = stablehlo.reduce(%1 init: %c_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<?x2xi32, #stablehlo.bounds<16, ?>>, tensor<i32>) -> tensor<i32>
    return %2 : tensor<i32>
  }
}
```
Execution Command:
```
run_hlo_module \
  --platform=CPU \
  --reference_platform= \
  --input_format=stablehlo \
  TensorFlowScatterV1_60e32416_7.mlir
```
Output:
```
loc("-":1:1): error: 'builtin.module' op -:12:5: error: Dynamic input dimension to reshape that is both splitted and combined is not supported: output: s32[16,2], input: s32[2,<=4,4], input_dim: 1:
-:12:5: note: see current operation: "func.return"(%4) : (tensor<i32>) -> ()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1774418323.889865 2114695 translate.cc:190] Conversion to HLO module failed: UNKNOWN: -:12:5: error: Dynamic input dimension to reshape that is both splitted and combined is not supported: output: s32[16,2], input: s32[2,<=4,4], input_dim: 1:
-:12:5: note: see current operation: "func.return"(%4) : (tensor<i32>) -> ()
F0000 00:00:1774418323.890028 2114695 hlo_module_loader.cc:117] Failed to translate input stablehlo program to HLO text
```
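The error wording ("both splitted and combined") refers to how an input dimension maps onto output dimensions in a reshape. The sketch below is a simplified, hypothetical model of that idea, not XLA's actual dynamic-dimension-inference code; `reshape_groups` and `classify` are illustrative names. It splits a reshape into minimal groups of input and output dimensions with equal element counts, assuming equal total sizes and no size-1 dimensions, and uses the bound 4 for the dynamic dimension 1.

```python
from math import prod


def reshape_groups(in_shape, out_shape):
    """Split a reshape into minimal (input dims, output dims) groups with equal element counts.

    Simplified model: assumes equal total element counts and no size-1 dimensions.
    """
    assert prod(in_shape) == prod(out_shape), "reshape must preserve element count"
    groups, i, j = [], 0, 0
    while i < len(in_shape):
        in_dims, out_dims = [i], [j]
        in_p, out_p = in_shape[i], out_shape[j]
        i, j = i + 1, j + 1
        # Grow whichever side has the smaller running product until they match.
        while in_p != out_p:
            if in_p < out_p:
                in_dims.append(i)
                in_p *= in_shape[i]
                i += 1
            else:
                out_dims.append(j)
                out_p *= out_shape[j]
                j += 1
        groups.append((in_dims, out_dims))
    return groups


def classify(in_shape, out_shape, dim):
    """Describe how input dimension `dim` is transformed, in this simplified model."""
    for in_dims, out_dims in reshape_groups(in_shape, out_shape):
        if dim in in_dims:
            if len(in_dims) == 1 and len(out_dims) == 1:
                return "unchanged"
            if len(out_dims) == 1:
                return "combined"
            if len(in_dims) == 1:
                return "split"
            return "split and combined"
    raise ValueError(f"dimension {dim} out of range")


# Shapes from this report: s32[2,<=4,4] -> s32[16,2].
print(reshape_groups([2, 4, 4], [16, 2]))   # [([0, 1, 2], [0, 1])]
print(classify([2, 4, 4], [16, 2], dim=1))  # split and combined
# Hypothetical contrast: merging only dims 1 and 2 is a pure combine.
print(classify([2, 4, 4], [2, 16], dim=1))  # combined
```

In this model, no intermediate dimension boundary of `[2,4,4]` lines up with one of `[16,2]`, so dimension 1 lands in a single many-to-many group, which matches the error's wording. This report does not establish whether XLA accepts the contrast case.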
Contact
- Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com