Partial dynamic shape support in XLA #59
Conversation
Pull request overview
This PR introduces partial dynamic-shape support by threading per-dimension symbolic expressions (DynExpr / ExpressionProto) through XLA shapes and TensorFlow shape inference, and by plumbing a runtime “batch size” value into the XLA:CPU execution path for evaluating dynamic dimensions.
Changes:
- Add per-dimension `ExpressionProto`/`DynExpr*` metadata to XLA `Shape`/`ShapeProto` and TF `TensorShapeProto`, and plumb it through many builders/expanders/shape utilities.
- Add XLA:CPU runtime support for a dynamic "outer batch" dimension (including passing `batch_size` through the CPU kernel call frame).
- Add a debug flag (`--xla_compile_batch_sizes`) plus TF-side utilities to pick compilation batch sizes and propagate dynamic-dimension metadata across encapsulation/launch.
Reviewed changes
Copilot reviewed 160 out of 161 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| third_party/xla/xla/xla_data.proto | Add expressions field and ExpressionProto message for shape expressions. |
| third_party/xla/xla/xla.proto | Add xla_compile_batch_sizes debug option. |
| third_party/xla/xla/stream_executor/tpu/c_api_decl.h | Extend TPU C API shape struct with batch-related fields. |
| third_party/xla/xla/shape_util.h | Add expression-aware ShapeUtil APIs and expression params to shape construction. |
| third_party/xla/xla/shape.h | Add DynExpr* storage/accessors, hashing changes, and batch-related equality option. |
| third_party/xla/xla/service/triangular_solve_expander.cc | Propagate expressions through reshapes in the expander. |
| third_party/xla/xla/service/shape_inference.h | Extend shape inference APIs to accept/pass expression spans. |
| third_party/xla/xla/service/reduce_scatter_combiner.cc | Adjust bitcast/permutation logic while combining reduce-scatter ops. |
| third_party/xla/xla/service/outer_dimension_propagation.h | Introduce an HLO pass for propagating outer-dimension multiplier metadata. |
| third_party/xla/xla/service/llvm_ir/tuple_ops.cc | Avoid dereferenceable metadata when shapes have dynamic expressions. |
| third_party/xla/xla/service/llvm_ir/loop_emitter.cc | Emit loop bounds from DynExpr via LLVM IR expression emission. |
| third_party/xla/xla/service/llvm_ir/llvm_util.h | Add APIs for batch dim lookup, dynamic GEP, and emitting expressions. |
| third_party/xla/xla/service/llvm_ir/llvm_util.cc | Implement batch dim loading, DynExpr→LLVM lowering, and dynamic GEP. |
| third_party/xla/xla/service/llvm_ir/llvm_loop.h | Allow loop construction to carry an optional DynExpr for end bounds. |
| third_party/xla/xla/service/llvm_ir/llvm_loop.cc | Use DynExpr-derived end bounds for loops (min with static end). |
| third_party/xla/xla/service/llvm_ir/ir_array.cc | Use dynamic GEP when shapes have dynamic expressions. |
| third_party/xla/xla/service/llvm_ir/BUILD | Add dependency for executable run options offset helper. |
| third_party/xla/xla/service/layout_assignment.cc | Ignore batch info during layout propagation equality checks. |
| third_party/xla/xla/service/hlo_creation_utils.cc | Preserve expressions when collapsing dimensions. |
| third_party/xla/xla/service/get_outer_batch_value_simplifier.h | Add HLO pass to simplify GetOuterBatchValue when static. |
| third_party/xla/xla/service/get_outer_batch_value_simplifier.cc | Implement rewrite of GetOuterBatchValue custom call to constant. |
| third_party/xla/xla/service/elemental_ir_emitter.cc | Use DynExpr for bounds in concatenate/pad/dot/reverse/reduce-window. |
| third_party/xla/xla/service/cpu/thunk_emitter.h | Add thunk emitter entry for GetOuterBatchValue. |
| third_party/xla/xla/service/cpu/thunk_emitter.cc | Emit host kernel thunk for GetOuterBatchValue custom call. |
| third_party/xla/xla/service/cpu/parallel_loop_emitter.cc | Pass per-dim expressions into parallel loop bounds. |
| third_party/xla/xla/service/cpu/ir_emitter2.h | Add host kernel emission API for GetOuterBatchValue. |
| third_party/xla/xla/service/cpu/ir_emitter2.cc | Implement host kernel that stores evaluated outer batch dim. |
| third_party/xla/xla/service/cpu/ir_emitter.h | Change transfer-element count parameter to DynExpr*. |
| third_party/xla/xla/service/cpu/executable_run_options_offset.h | Declare helper to compute batch_size_ offset in ExecutableRunOptions. |
| third_party/xla/xla/service/cpu/executable_run_options_offset.cc | Implement offset computation via pointer-to-member trick. |
| third_party/xla/xla/service/cpu/cpu_executable.cc | Add param padding logic and pass batch_size into thunk execution params. |
| third_party/xla/xla/service/cpu/cpu_compiler.cc | Wire in (commented out) outer-dimension passes; adjust pipeline bits. |
| third_party/xla/xla/service/cpu/BUILD | Add new executable_run_options_offset library target. |
| third_party/xla/xla/service/BUILD | Add build targets for new outer-dimension-related passes. |
| third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc | Initialize expression vectors when converting MLIR types to shapes. |
| third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | Plumb expressions into Reshape/DynamicReshape lowering. |
| third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc | Update reshape/broadcast ops to include expressions. |
| third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc | Propagate expressions through reshapes/broadcasts in QR expander. |
| third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc | Propagate expressions through broadcasts and reshape usage. |
| third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc | Carry expressions when building reshapes/dot shapes. |
| third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc | Pass expressions into BroadcastInDim for correct dynamic shape handling. |
| third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc | Preserve expressions across reshape+broadcast patterns. |
| third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc | Refactor bitcast shape creation; preserve operand for CreateBitcast. |
| third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc | Minor whitespace change in pass pipeline loop. |
| third_party/xla/xla/hlo/builder/lib/svd.cc | Propagate expressions through BroadcastInDim usage. |
| third_party/xla/xla/hlo/builder/lib/slicing.cc | Propagate expressions through reshape/broadcast in slicing helpers. |
| third_party/xla/xla/hlo/builder/lib/prng.cc | Propagate expressions through reshape/broadcast in PRNG utilities. |
| third_party/xla/xla/hlo/builder/lib/matrix.cc | Propagate expressions through broadcast/reshape in matrix helpers. |
| third_party/xla/xla/hlo/builder/lib/broadcast.h | Extend BroadcastTo signature to accept optional expression span. |
| third_party/xla/xla/hlo/builder/lib/broadcast.cc | Implement expression-aware BroadcastTo. |
| third_party/xla/xla/hlo/builder/lib/arithmetic.cc | Create iota shapes carrying expressions for dynamic dims. |
| third_party/xla/xla/hlo/builder/lib/approx_topk.cc | Propagate expressions through reshape of reduction outputs. |
| third_party/xla/xla/executable_run_options.h | Add batch_size_ to run options with setter/getter. |
| third_party/xla/xla/debug_options_flags.cc | Add command-line flag plumbing for xla_compile_batch_sizes. |
| third_party/xla/xla/backends/cpu/runtime/thunk.h | Extend ExecuteParams with batch_size. |
| third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc | Pass batch_size into host kernel call/launch. |
| third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h | Add batch_size to kernel call frame C struct. |
| third_party/xla/xla/backends/cpu/runtime/kernel.h | Extend kernel API to accept batch_size for CallOnce/Launch. |
| third_party/xla/xla/backends/cpu/runtime/kernel.cc | Thread batch_size through call frames and parallel task launching. |
| third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h | Add API for emitting batch dim load from call frame. |
| third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc | Implement batch dim load and avoid deref metadata for dynamic expr shapes. |
| third_party/xla/xla/BUILD | Export shape_dynexpr.h in XLA public headers. |
| tensorflow/tools/toolchains/python/python_repo.bzl | Add USE_PYWRAP_RULES template variable. |
| tensorflow/core/util/strided_slice_op.h | Extend strided-slice validation API to optionally return expressions. |
| tensorflow/core/kernels/strided_slice_op.cc | Pass begin/end expression vectors through validation call. |
| tensorflow/core/kernels/padding_fifo_queue.cc | Preserve/normalize expressions when converting partial dims to concrete shapes. |
| tensorflow/core/kernels/function_ops.h | Add dynamic-dimension tracking field to Arg/Retval ops. |
| tensorflow/core/kernels/function_ops.cc | Record runtime batch size into a resource based on a marked arg dim. |
| tensorflow/core/grappler/optimizers/remapper.cc | Preserve _output_shapes on fused nodes where needed. |
| tensorflow/core/graph/subgraph.cc | Mark _Arg nodes with _dynamic_dim based on unknown dims in _output_shapes. |
| tensorflow/core/framework/tensor_shape_expr.h | Add TF-side symbolic dimension expression representation + proto conversion. |
| tensorflow/core/framework/tensor_shape_expr.cc | Implement TF DimExpr parsing/equality/simplification. |
| tensorflow/core/framework/tensor_shape.proto | Add ExpressionProto and per-dim expression annotations in TensorShapeProto. |
| tensorflow/core/framework/tensor_shape.h | Store per-dimension DynExpr pointers in TensorShapeRep. |
| tensorflow/core/framework/shape_inference.h | Add expression/dynamic ratio support in InferenceContext dimensions. |
| tensorflow/core/framework/common_shape_fns.cc | Improve merging of unknown broadcast dims by attempting Merge first. |
| tensorflow/core/framework/batch_size_resource.h | Introduce BatchSizeResource to store runtime batch size per step. |
| tensorflow/core/framework/BUILD | Add new tensor_shape_expr and batch_size_resource to build/export lists. |
| tensorflow/core/common_runtime/constant_folding.cc | Minor indentation change. |
| tensorflow/core/BUILD | Add dependency for shape_util usage. |
| tensorflow/compiler/tf2xla/xla_op_kernel.cc | Preserve expressions when reshaping variables to representation shape. |
| tensorflow/compiler/tf2xla/xla_compiler.cc | Attach op metadata and reshape args using dimension expressions. |
| tensorflow/compiler/tf2xla/xla_argument.h | Add API to fetch argument dimension expressions. |
| tensorflow/compiler/tf2xla/shape_util.cc | Convert expressions between TensorShape and xla::Shape. |
| tensorflow/compiler/tf2xla/ops/xla_ops.cc | Construct xla::Shape with per-dim expressions based on dynamic ratios. |
| tensorflow/compiler/tf2xla/lib/data_format.cc | Preserve expressions through reshape/transpose data-format transforms. |
| tensorflow/compiler/tf2xla/lib/broadcast.h | Extend TF wrapper for BroadcastTo to accept expression span. |
| tensorflow/compiler/tf2xla/lib/broadcast.cc | Forward expression span to xla::BroadcastTo. |
| tensorflow/compiler/tf2xla/layout_util.cc | Copy expressions when reshaping with correct representation. |
| tensorflow/compiler/tf2xla/kernels/where_op.cc | Preserve expressions when flattening/reshaping for Where lowering. |
| tensorflow/compiler/tf2xla/kernels/unpack_op.cc | Reshape slices using output expressions. |
| tensorflow/compiler/tf2xla/kernels/unique_op.cc | Propagate expressions through reshape/broadcast/slice logic. |
| tensorflow/compiler/tf2xla/kernels/tile_ops.cc | Compute output expressions for tiled dimensions. |
| tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc | Preserve expressions through reshape operations in TensorList helpers. |
| tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc | Build element/result shapes using expressions. |
| tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | Preserve expressions through broadcast/reshape/dynamic-slice paths. |
| tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc | Preserve expressions through reshape/broadcast in RNG lowering. |
| tensorflow/compiler/tf2xla/kernels/stack_ops.cc | Reshape updates using expressions. |
| tensorflow/compiler/tf2xla/kernels/split_op.cc | Add expression-aware slice begin/limit handling. |
| tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc | Preserve expressions in broadcast initialization. |
| tensorflow/compiler/tf2xla/kernels/softmax_op.cc | Replace MLIR kernel with explicit XLA lowering (softmax/log-softmax). |
| tensorflow/compiler/tf2xla/kernels/slice_op.cc | Use expression-aware Slice/DynamicSlice and size computations. |
| tensorflow/compiler/tf2xla/kernels/shape_op.cc | Preserve expressions through ExpandDims/Squeeze/ZerosLike/OnesLike. |
| tensorflow/compiler/tf2xla/kernels/select_op.cc | Use expression-aware BroadcastInDim and custom kernel impl. |
| tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc | Broadcast buffer initialization using expressions. |
| tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc | Broadcast buffer initialization using expressions. |
| tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc | Build iota shapes with expressions for dynamic bounds. |
| tensorflow/compiler/tf2xla/kernels/relu_op.cc | Replace MLIR kernels; preserve expressions in Relu6Grad broadcasts. |
| tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc | Preserve expressions when reshaping reduced outputs with keep_dims. |
| tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | Use GetOuterBatchValue for dynamic batch reduction divisor when enabled. |
| tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc | Use expression-aware BroadcastInDim for per-channel params. |
| tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | Preserve expressions through reshape/transpose in vectorized pooling. |
| tensorflow/compiler/tf2xla/kernels/pack_op.cc | Preserve expressions when inserting a dim and reshaping inputs. |
| tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc | Add expressions to broadcast shapes used in solve op. |
| tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc | Add expressions to diag shapes and broadcast/reshape operations. |
| tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc | Preserve expressions through reshape used for broadcasting. |
| tensorflow/compiler/tf2xla/kernels/in_topk_op.cc | Use expression-aware broadcasts when building one/zero tensors. |
| tensorflow/compiler/tf2xla/kernels/image_ops.cc | Preserve expressions in broadcast of zeros. |
| tensorflow/compiler/tf2xla/kernels/gather_op.cc | Use expression-aware broadcast for empty gather output. |
| tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc | Use expression-aware broadcast and BroadcastInDim in fake-quant ops. |
| tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc | Preserve expressions when reshaping stitch inputs. |
| tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc | Pass expressions to BroadcastInDim for partitions. |
| tensorflow/compiler/tf2xla/kernels/diag_op.cc | Use reshape overload that accepts expressions (empty here). |
| tensorflow/compiler/tf2xla/kernels/conv_ops.cc | Preserve expressions through reshape operations for batch handling. |
| tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc | Preserve expressions through filter reshapes. |
| tensorflow/compiler/tf2xla/kernels/const_op.cc | Use expression-aware broadcast when materializing scalar consts. |
| tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc | Use expression-aware broadcast for scalar min/max. |
| tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc | Pass expressions into BroadcastTo wrapper call. |
| tensorflow/compiler/tf2xla/kernels/bincount_op.cc | Preserve expressions through flattening reshapes. |
| tensorflow/compiler/tf2xla/kernels/beta_op.cc | Pass expressions through BroadcastTo for Betainc inputs. |
| tensorflow/compiler/jit/xla_launch_util.h | Extend PopulateOutputs to accept optional ExecutableRunOptions. |
| tensorflow/compiler/jit/xla_launch_util.cc | Substitute batch-size into expressions when shaping TF outputs. |
| tensorflow/compiler/jit/xla_batch_matcher.h | Add batch-size selection helper based on debug flag config. |
| tensorflow/compiler/jit/xla_batch_matcher.cc | Implement parsing/selection of compilation batch sizes. |
| tensorflow/compiler/jit/shape_inference.cc | Attach DynExpr-derived annotations to PartialTensorShape results. |
| tensorflow/compiler/jit/kernels/BUILD | Add dependency on xla_batch_matcher library. |
| tensorflow/compiler/jit/flags.h | Add flags for dynamic sizes and clustering behaviors. |
| tensorflow/compiler/jit/flags.cc | Register and initialize new JIT flags. |
| tensorflow/compiler/jit/encapsulate_util.h | Add constants for new inferred shape attributes. |
| tensorflow/compiler/jit/encapsulate_util.cc | Define new inferred shape attribute names. |
| tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | Copy inferred output shapes to _Arg nodes; gate compilation for failing ops. |
| tensorflow/compiler/jit/device_compiler.h | Instantiate XlaBatchMatcher when dynamic sizes are enabled. |
| tensorflow/compiler/jit/device_compilation_profiler.cc | Make megamorphic threshold configurable via new flag. |
| tensorflow/compiler/jit/BUILD | Add xla_batch_matcher target and deps. |
```cpp
llvm::Value* formatStr =
    builder.CreateGlobalStringPtr("Function: %s, Batch size is : %d!\n");
builder.CreateCall(printfFunc, {formatStr, funcNameStr, bdim_value});
```
Under PRINT_BATCHSIZE, the generated printf format string uses %d but bdim_value is an i64. This is a varargs type mismatch and is undefined behavior at runtime. Use an i64-appropriate format (e.g. %lld/%ld) or cast/truncate bdim_value to i32 before printing.
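The same varargs rule applies to ordinary host code, which makes it easy to check the fix in isolation. This is a minimal sketch (the function name and format text are illustrative, not from the PR) showing the portable way to format a 64-bit value with `PRId64` from `<cinttypes>` instead of `%d`:

```cpp
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-side analogue of the mismatch flagged above: %d expects int, but the
// batch value is 64-bit. PRId64 expands to the correct conversion specifier
// for int64_t on the current platform, keeping the varargs call well-defined.
int FormatBatchSize(char* buf, std::size_t len, std::int64_t batch) {
  return std::snprintf(buf, len, "Batch size is : %" PRId64 "!\n", batch);
}
```

An explicit `trunc` of `bdim_value` to i32 before the `CreateCall` would be the equivalent fix on the LLVM IR side if `%d` is kept.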
```proto
// Next id: 390
```
The comment // Next id: 390 is now inconsistent with the newly-added field xla_compile_batch_sizes = 399 (and there are fields with ids well above 390 later in the message). Update the Next id comment to the correct next available tag to avoid future field-number collisions during proto evolution.
```cpp
void* dst = padded-> opaque();
void* src = param_mem.opaque();
```
The spacing in padded-> opaque() is inconsistent with the surrounding style (whitespace around -> is legal C++, so this compiles, but it reads like a typo). It should be padded->opaque().
@copilot open a new pull request to apply changes based on this feedback
```cpp
size_t ExecutableRunOptionsBatchSizeOffset() {
  auto ptr = get_offset(BatchSizeTag{});
  // Compute offset in bytes from null pointer.
  return reinterpret_cast<size_t>(
      &(reinterpret_cast<xla::ExecutableRunOptions*>(0)->*ptr));
}
```
ExecutableRunOptionsBatchSizeOffset() computes an offset by applying a pointer-to-member to a null pointer (reinterpret_cast<ExecutableRunOptions*>(0)->*ptr). This is undefined behavior in C++ and can break under different compilers/ABIs/optimizations. Prefer a supported mechanism to pass batch_size into generated code (e.g., via the already-extended XLA_CPU_KernelCallFrame), and remove this offset hack if possible.
@copilot You’re right, this hack is not great. I don’t think there is a safe small fix here, so I’d rather handle it in a separate follow-up where we pass the batch size in a cleaner way instead of reading it by offset.
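For reference, the cleaner direction the review suggests can be sketched with a toy call-frame struct. The PR already extends the real `XLA_CPU_KernelCallFrame` with a `batch_size` field; the struct and function below are illustrative stand-ins, not the PR's actual definitions:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical slice of a kernel call-frame struct. Passing batch_size by
// value in the frame avoids computing a member offset through a null
// pointer, which is undefined behavior.
struct KernelCallFrame {
  int64_t batch_size;
  // ... argument/result buffer pointers elided ...
};

// Generated kernel code reads the batch size from the frame it is handed.
int64_t ReadBatchSize(const KernelCallFrame* frame) {
  return frame->batch_size;
}
```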
```cpp
XlaBatchMatcher::XlaBatchMatcher() {
  const std::string xla_compile_batch_sizes =
      xla::GetDebugOptionsFromFlags().xla_compile_batch_sizes();
  env_str_ = xla_compile_batch_sizes.c_str();
  parse_env_config();
}
```
env_str_ is set to xla_compile_batch_sizes.c_str(), but xla_compile_batch_sizes is a local std::string that is destroyed at the end of the constructor. This leaves env_str_ as a dangling pointer and makes parse_env_config() use-after-free. Store the string as a member (e.g., std::string env_str_) or copy it into owned storage before taking c_str().
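A minimal sketch of the ownership fix: keep the flag value in an owned `std::string` member so any `c_str()` pointer stays valid for the object's lifetime. The class body here is illustrative; only the member-ownership pattern is the point:

```cpp
#include <cassert>
#include <string>

// Sketch of the suggested fix: copy the flag string into owned storage
// before any parsing touches it, so no pointer into a destroyed temporary
// survives the constructor.
class BatchMatcher {
 public:
  explicit BatchMatcher(const std::string& flag_value)
      : env_str_(flag_value) {}  // copy into owned member storage

  // Valid for as long as this object lives (and env_str_ is unmodified).
  const char* env_cstr() const { return env_str_.c_str(); }

 private:
  std::string env_str_;  // owned; parse_env_config() would read from this
};
```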
```cpp
auto sh = ShapeUtil::PermuteDimensions(*perm, operand_shape), operand;
operands.back() =
    computation.AddInstruction(HloInstruction::CreateBitcast(
        ShapeUtil::PermuteDimensions(*perm, operand_shape), operand));
computation.AddInstruction(HloInstruction::CreateBitcast(sh));
output_shapes.back() = ShapeUtil::PermuteDimensions(*perm, hlo->shape());
```
This declaration is invalid C++ (auto sh = ..., operand;) and also drops the operand from CreateBitcast, so it won't compile and the bitcast won't have an input value. Define sh separately and call HloInstruction::CreateBitcast(sh, operand) (using the existing operand variable from above).
```cpp
auto dimensions = LayoutUtil::MinorToMajor(shape);
for (int i = 0; i < dimensions.size(); i++) {
  // The MinorToMajor order reverses dimensions...
  bool is_dynamic =
      shape.expressions(dimensions.size() - 1 - i)->is_dynamic();
  int64_t dim_val = is_dynamic ? 0 : shape.dimensions(dimensions[i]);
  result_type = llvm::ArrayType::get(result_type, dim_val);
}
```
ShapeToIrType iterates dimensions in LayoutUtil::MinorToMajor(shape) order, but checks dynamic-ness using shape.expressions(dimensions.size() - 1 - i). That index is not the dimension number being processed, so dynamic dimensions can be misidentified and the resulting IR type can be wrong. Use shape.expressions(dimensions[i]) (and keep it consistent with the shape.dimensions(dimensions[i]) lookup).
That is on purpose because the expression wasn't permuted.
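The disagreement above hinges on a small indexing fact that can be modeled without XLA types: the index `dimensions.size() - 1 - i` coincides with `dimensions[i]` only for the default descending minor-to-major layout. This toy check (plain vectors standing in for `Shape`/`Layout`) makes that concrete:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// For a rank-n shape with layout permutation minor_to_major, the two
// candidate indices used in the snippet above -- (n - 1 - i) versus
// minor_to_major[i] -- agree at every position i only when the layout is
// the default descending one {n-1, ..., 1, 0}.
bool IndicesCoincide(const std::vector<int64_t>& minor_to_major) {
  const int64_t n = static_cast<int64_t>(minor_to_major.size());
  for (int64_t i = 0; i < n; ++i) {
    if (minor_to_major[i] != n - 1 - i) return false;
  }
  return true;
}
```

So whether the mixed indexing is a bug or intentional (expressions stored unpermuted, as the author says) only shows up for non-default layouts.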
```cpp
auto expressions = shape_.expressions();
bool dynamic_first_dim =
    expressions[0]->is_dynamic() &&
    std::all_of(expressions.begin() + 1, expressions.end(),
                [](DynExpr* e) { return e->is_constant(); });
if (!dynamic_first_dim && shape_.has_dynamic_expr()) {
  llvm::Type* element_type =
      PrimitiveTypeToIrType(shape_.element_type(), b->getContext());
  return llvm_ir::createDynamicGEP(
      b, base_ptr_, gep_indices, shape_.dimensions(), expressions,
      element_type, llvm_ir::AsStringRef(name));
```
The dynamic GEP path passes gep_indices in major-to-minor (layout) order (and includes the leading 0 index), but passes shape_.dimensions() and shape_.expressions() in logical dimension order to createDynamicGEP. These orderings don't match, so the linearization will compute incorrect element addresses for non-default layouts. Consider passing dims/expressions permuted to the same order as gep_indices (excluding the leading base index), or change createDynamicGEP to accept a layout and do the permutation internally.
Fair point. I'll have a look.
```diff
@@ -79,15 +80,37 @@ absl::StatusOr<XlaOp> BroadcastTo(XlaOp input,
   }
   TF_RET_CHECK(input_it == input_dims.rend());

+  absl::Span<DynExpr* const> input_exprs = input_shape.expressions();
+  std::vector<DynExpr*> broadcast_exprs;
+  auto input_et = input_exprs.rbegin();
+  for (auto output_et = output_exprs.rbegin(); output_et != output_exprs.rend();
+       ++output_et) {
+    if (input_et != input_exprs.rend()) {
+      if (*(*output_et) == *(*input_et) ||
+          (*input_et)->is_constant() && (*input_et)->get_val() == 1) {
+        broadcast_exprs.push_back(*output_et);
+      } else if (!(*(*output_et) == *(*input_et))) {
+        broadcast_exprs.push_back(*input_et);
+        broadcast_exprs.push_back((**output_et / **input_et)->s());
+      }
+      ++input_et;
+    } else {
+      broadcast_exprs.push_back(*output_et);
+    }
+  }
+
   absl::c_reverse(broadcast_dims);
   int broadcast_shape_size = broadcast_shape.size();
   for (int64_t& broadcast_dim : broadcast_dims) {
     broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
   }
   absl::c_reverse(broadcast_shape);
-  XlaOp output = BroadcastInDim(input, broadcast_shape, broadcast_dims);
+  absl::c_reverse(broadcast_exprs);
+
+  XlaOp output =
+      BroadcastInDim(input, broadcast_shape, broadcast_dims, broadcast_exprs);
   if (broadcast_shape != output_dims) {
-    output = Reshape(output, output_dims);
+    output = Reshape(output, output_dims, output_exprs);
   }
   return output;
```
BroadcastTo now takes output_exprs with a default of {}, but the implementation always iterates over output_exprs to build broadcast_exprs and then calls BroadcastInDim(..., broadcast_shape, ..., broadcast_exprs). If output_exprs is empty (existing call sites that don't pass expressions), broadcast_exprs will be empty while broadcast_shape is not, causing a rank mismatch. Also, when tiling (output_it != input_it && input_it != 1), broadcast_shape gains two dims, but the expr-building logic must add two expressions correspondingly; otherwise expression rank won't match the broadcasted shape.
Suggested fix: if output_exprs.empty(), synthesize constant expressions from output_dims (similar to ShapeUtil::MakeValidatedShape), and ensure broadcast_exprs is built to exactly match broadcast_shape (including the extra tiling dimension).
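The empty-span guard in the suggested fix is easy to model without the real `DynExpr` type. Here `std::optional<int64_t>` stands in for an expression (a value meaning a constant dimension); when the caller passed no expressions, constants are synthesized from the output dimensions so the expression rank always matches the shape rank. All names are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy model of the guard: an Expr with a value is a constant dimension
// expression. If the caller supplied no expressions, synthesize one
// constant per output dimension so downstream code can assume
// exprs.size() == dims.size().
using Expr = std::optional<int64_t>;

std::vector<Expr> ExprsOrConstants(std::vector<Expr> output_exprs,
                                   const std::vector<int64_t>& output_dims) {
  if (output_exprs.empty()) {
    output_exprs.reserve(output_dims.size());
    for (int64_t d : output_dims) output_exprs.push_back(Expr(d));
  }
  assert(output_exprs.size() == output_dims.size());  // rank invariant
  return output_exprs;
}
```

The real fix would also need the tiling branch to emit two expressions whenever the broadcast shape gains two dimensions, as the comment notes.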
@joeyye-work I've opened a new pull request, #61, to work on those changes. Once the pull request is ready, I'll request review from you.
Squashed changes of dynamic shape support in XLA.