
Partial dynamic shape support in XLA #59

Open

joeyye-work wants to merge 1 commit into for-serving-2.20 from cleanup-serving-2.20

Conversation

@joeyye-work
Owner

Squashed changes of dynamic shape support in XLA.


Copilot AI left a comment


Pull request overview

This PR introduces partial dynamic-shape support by threading per-dimension symbolic expressions (DynExpr / ExpressionProto) through XLA shapes and TensorFlow shape inference, and by plumbing a runtime “batch size” value into the XLA:CPU execution path for evaluating dynamic dimensions.

Changes:

  • Add per-dimension ExpressionProto / DynExpr* metadata to XLA Shape/ShapeProto and TF TensorShapeProto, and plumb it through many builders/expanders/shape utilities.
  • Add XLA:CPU runtime support for a dynamic “outer batch” dimension (including passing batch_size through the CPU kernel call frame).
  • Add a debug flag (--xla_compile_batch_sizes) plus TF-side utilities to pick compilation batch sizes and propagate dynamic-dimension metadata across encapsulation/launch.

Reviewed changes

Copilot reviewed 160 out of 161 changed files in this pull request and generated 11 comments.

File Description
third_party/xla/xla/xla_data.proto Add expressions field and ExpressionProto message for shape expressions.
third_party/xla/xla/xla.proto Add xla_compile_batch_sizes debug option.
third_party/xla/xla/stream_executor/tpu/c_api_decl.h Extend TPU C API shape struct with batch-related fields.
third_party/xla/xla/shape_util.h Add expression-aware ShapeUtil APIs and expression params to shape construction.
third_party/xla/xla/shape.h Add DynExpr* storage/accessors, hashing changes, and batch-related equality option.
third_party/xla/xla/service/triangular_solve_expander.cc Propagate expressions through reshapes in the expander.
third_party/xla/xla/service/shape_inference.h Extend shape inference APIs to accept/pass expression spans.
third_party/xla/xla/service/reduce_scatter_combiner.cc Adjust bitcast/permutation logic while combining reduce-scatter ops.
third_party/xla/xla/service/outer_dimension_propagation.h Introduce an HLO pass for propagating outer-dimension multiplier metadata.
third_party/xla/xla/service/llvm_ir/tuple_ops.cc Avoid dereferenceable metadata when shapes have dynamic expressions.
third_party/xla/xla/service/llvm_ir/loop_emitter.cc Emit loop bounds from DynExpr via LLVM IR expression emission.
third_party/xla/xla/service/llvm_ir/llvm_util.h Add APIs for batch dim lookup, dynamic GEP, and emitting expressions.
third_party/xla/xla/service/llvm_ir/llvm_util.cc Implement batch dim loading, DynExpr→LLVM lowering, and dynamic GEP.
third_party/xla/xla/service/llvm_ir/llvm_loop.h Allow loop construction to carry an optional DynExpr for end bounds.
third_party/xla/xla/service/llvm_ir/llvm_loop.cc Use DynExpr-derived end bounds for loops (min with static end).
third_party/xla/xla/service/llvm_ir/ir_array.cc Use dynamic GEP when shapes have dynamic expressions.
third_party/xla/xla/service/llvm_ir/BUILD Add dependency for executable run options offset helper.
third_party/xla/xla/service/layout_assignment.cc Ignore batch info during layout propagation equality checks.
third_party/xla/xla/service/hlo_creation_utils.cc Preserve expressions when collapsing dimensions.
third_party/xla/xla/service/get_outer_batch_value_simplifier.h Add HLO pass to simplify GetOuterBatchValue when static.
third_party/xla/xla/service/get_outer_batch_value_simplifier.cc Implement rewrite of GetOuterBatchValue custom call to constant.
third_party/xla/xla/service/elemental_ir_emitter.cc Use DynExpr for bounds in concatenate/pad/dot/reverse/reduce-window.
third_party/xla/xla/service/cpu/thunk_emitter.h Add thunk emitter entry for GetOuterBatchValue.
third_party/xla/xla/service/cpu/thunk_emitter.cc Emit host kernel thunk for GetOuterBatchValue custom call.
third_party/xla/xla/service/cpu/parallel_loop_emitter.cc Pass per-dim expressions into parallel loop bounds.
third_party/xla/xla/service/cpu/ir_emitter2.h Add host kernel emission API for GetOuterBatchValue.
third_party/xla/xla/service/cpu/ir_emitter2.cc Implement host kernel that stores evaluated outer batch dim.
third_party/xla/xla/service/cpu/ir_emitter.h Change transfer-element count parameter to DynExpr*.
third_party/xla/xla/service/cpu/executable_run_options_offset.h Declare helper to compute batch_size_ offset in ExecutableRunOptions.
third_party/xla/xla/service/cpu/executable_run_options_offset.cc Implement offset computation via pointer-to-member trick.
third_party/xla/xla/service/cpu/cpu_executable.cc Add param padding logic and pass batch_size into thunk execution params.
third_party/xla/xla/service/cpu/cpu_compiler.cc Wire in (commented out) outer-dimension passes; adjust pipeline bits.
third_party/xla/xla/service/cpu/BUILD Add new executable_run_options_offset library target.
third_party/xla/xla/service/BUILD Add build targets for new outer-dimension-related passes.
third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc Initialize expression vectors when converting MLIR types to shapes.
third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc Plumb expressions into Reshape/DynamicReshape lowering.
third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc Update reshape/broadcast ops to include expressions.
third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc Propagate expressions through reshapes/broadcasts in QR expander.
third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc Propagate expressions through broadcasts and reshape usage.
third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc Carry expressions when building reshapes/dot shapes.
third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc Pass expressions into BroadcastInDim for correct dynamic shape handling.
third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc Preserve expressions across reshape+broadcast patterns.
third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc Refactor bitcast shape creation; preserve operand for CreateBitcast.
third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc Minor whitespace change in pass pipeline loop.
third_party/xla/xla/hlo/builder/lib/svd.cc Propagate expressions through BroadcastInDim usage.
third_party/xla/xla/hlo/builder/lib/slicing.cc Propagate expressions through reshape/broadcast in slicing helpers.
third_party/xla/xla/hlo/builder/lib/prng.cc Propagate expressions through reshape/broadcast in PRNG utilities.
third_party/xla/xla/hlo/builder/lib/matrix.cc Propagate expressions through broadcast/reshape in matrix helpers.
third_party/xla/xla/hlo/builder/lib/broadcast.h Extend BroadcastTo signature to accept optional expression span.
third_party/xla/xla/hlo/builder/lib/broadcast.cc Implement expression-aware BroadcastTo.
third_party/xla/xla/hlo/builder/lib/arithmetic.cc Create iota shapes carrying expressions for dynamic dims.
third_party/xla/xla/hlo/builder/lib/approx_topk.cc Propagate expressions through reshape of reduction outputs.
third_party/xla/xla/executable_run_options.h Add batch_size_ to run options with setter/getter.
third_party/xla/xla/debug_options_flags.cc Add command-line flag plumbing for xla_compile_batch_sizes.
third_party/xla/xla/backends/cpu/runtime/thunk.h Extend ExecuteParams with batch_size.
third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc Pass batch_size into host kernel call/launch.
third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h Add batch_size to kernel call frame C struct.
third_party/xla/xla/backends/cpu/runtime/kernel.h Extend kernel API to accept batch_size for CallOnce/Launch.
third_party/xla/xla/backends/cpu/runtime/kernel.cc Thread batch_size through call frames and parallel task launching.
third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h Add API for emitting batch dim load from call frame.
third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc Implement batch dim load and avoid deref metadata for dynamic expr shapes.
third_party/xla/xla/BUILD Export shape_dynexpr.h in XLA public headers.
tensorflow/tools/toolchains/python/python_repo.bzl Add USE_PYWRAP_RULES template variable.
tensorflow/core/util/strided_slice_op.h Extend strided-slice validation API to optionally return expressions.
tensorflow/core/kernels/strided_slice_op.cc Pass begin/end expression vectors through validation call.
tensorflow/core/kernels/padding_fifo_queue.cc Preserve/normalize expressions when converting partial dims to concrete shapes.
tensorflow/core/kernels/function_ops.h Add dynamic-dimension tracking field to Arg/Retval ops.
tensorflow/core/kernels/function_ops.cc Record runtime batch size into a resource based on a marked arg dim.
tensorflow/core/grappler/optimizers/remapper.cc Preserve _output_shapes on fused nodes where needed.
tensorflow/core/graph/subgraph.cc Mark _Arg nodes with _dynamic_dim based on unknown dims in _output_shapes.
tensorflow/core/framework/tensor_shape_expr.h Add TF-side symbolic dimension expression representation + proto conversion.
tensorflow/core/framework/tensor_shape_expr.cc Implement TF DimExpr parsing/equality/simplification.
tensorflow/core/framework/tensor_shape.proto Add ExpressionProto and per-dim expression annotations in TensorShapeProto.
tensorflow/core/framework/tensor_shape.h Store per-dimension DynExpr pointers in TensorShapeRep.
tensorflow/core/framework/shape_inference.h Add expression/dynamic ratio support in InferenceContext dimensions.
tensorflow/core/framework/common_shape_fns.cc Improve merging of unknown broadcast dims by attempting Merge first.
tensorflow/core/framework/batch_size_resource.h Introduce BatchSizeResource to store runtime batch size per step.
tensorflow/core/framework/BUILD Add new tensor_shape_expr and batch_size_resource to build/export lists.
tensorflow/core/common_runtime/constant_folding.cc Minor indentation change.
tensorflow/core/BUILD Add dependency for shape_util usage.
tensorflow/compiler/tf2xla/xla_op_kernel.cc Preserve expressions when reshaping variables to representation shape.
tensorflow/compiler/tf2xla/xla_compiler.cc Attach op metadata and reshape args using dimension expressions.
tensorflow/compiler/tf2xla/xla_argument.h Add API to fetch argument dimension expressions.
tensorflow/compiler/tf2xla/shape_util.cc Convert expressions between TensorShape and xla::Shape.
tensorflow/compiler/tf2xla/ops/xla_ops.cc Construct xla::Shape with per-dim expressions based on dynamic ratios.
tensorflow/compiler/tf2xla/lib/data_format.cc Preserve expressions through reshape/transpose data-format transforms.
tensorflow/compiler/tf2xla/lib/broadcast.h Extend TF wrapper for BroadcastTo to accept expression span.
tensorflow/compiler/tf2xla/lib/broadcast.cc Forward expression span to xla::BroadcastTo.
tensorflow/compiler/tf2xla/layout_util.cc Copy expressions when reshaping with correct representation.
tensorflow/compiler/tf2xla/kernels/where_op.cc Preserve expressions when flattening/reshaping for Where lowering.
tensorflow/compiler/tf2xla/kernels/unpack_op.cc Reshape slices using output expressions.
tensorflow/compiler/tf2xla/kernels/unique_op.cc Propagate expressions through reshape/broadcast/slice logic.
tensorflow/compiler/tf2xla/kernels/tile_ops.cc Compute output expressions for tiled dimensions.
tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc Preserve expressions through reshape operations in TensorList helpers.
tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc Build element/result shapes using expressions.
tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc Preserve expressions through broadcast/reshape/dynamic-slice paths.
tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc Preserve expressions through reshape/broadcast in RNG lowering.
tensorflow/compiler/tf2xla/kernels/stack_ops.cc Reshape updates using expressions.
tensorflow/compiler/tf2xla/kernels/split_op.cc Add expression-aware slice begin/limit handling.
tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc Preserve expressions in broadcast initialization.
tensorflow/compiler/tf2xla/kernels/softmax_op.cc Replace MLIR kernel with explicit XLA lowering (softmax/log-softmax).
tensorflow/compiler/tf2xla/kernels/slice_op.cc Use expression-aware Slice/DynamicSlice and size computations.
tensorflow/compiler/tf2xla/kernels/shape_op.cc Preserve expressions through ExpandDims/Squeeze/ZerosLike/OnesLike.
tensorflow/compiler/tf2xla/kernels/select_op.cc Use expression-aware BroadcastInDim and custom kernel impl.
tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc Broadcast buffer initialization using expressions.
tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc Broadcast buffer initialization using expressions.
tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc Build iota shapes with expressions for dynamic bounds.
tensorflow/compiler/tf2xla/kernels/relu_op.cc Replace MLIR kernels; preserve expressions in Relu6Grad broadcasts.
tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc Preserve expressions when reshaping reduced outputs with keep_dims.
tensorflow/compiler/tf2xla/kernels/reduction_ops.cc Use GetOuterBatchValue for dynamic batch reduction divisor when enabled.
tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc Use expression-aware BroadcastInDim for per-channel params.
tensorflow/compiler/tf2xla/kernels/pooling_ops.cc Preserve expressions through reshape/transpose in vectorized pooling.
tensorflow/compiler/tf2xla/kernels/pack_op.cc Preserve expressions when inserting a dim and reshaping inputs.
tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc Add expressions to broadcast shapes used in solve op.
tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc Add expressions to diag shapes and broadcast/reshape operations.
tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc Preserve expressions through reshape used for broadcasting.
tensorflow/compiler/tf2xla/kernels/in_topk_op.cc Use expression-aware broadcasts when building one/zero tensors.
tensorflow/compiler/tf2xla/kernels/image_ops.cc Preserve expressions in broadcast of zeros.
tensorflow/compiler/tf2xla/kernels/gather_op.cc Use expression-aware broadcast for empty gather output.
tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc Use expression-aware broadcast and BroadcastInDim in fake-quant ops.
tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc Preserve expressions when reshaping stitch inputs.
tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc Pass expressions to BroadcastInDim for partitions.
tensorflow/compiler/tf2xla/kernels/diag_op.cc Use reshape overload that accepts expressions (empty here).
tensorflow/compiler/tf2xla/kernels/conv_ops.cc Preserve expressions through reshape operations for batch handling.
tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc Preserve expressions through filter reshapes.
tensorflow/compiler/tf2xla/kernels/const_op.cc Use expression-aware broadcast when materializing scalar consts.
tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc Use expression-aware broadcast for scalar min/max.
tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc Pass expressions into BroadcastTo wrapper call.
tensorflow/compiler/tf2xla/kernels/bincount_op.cc Preserve expressions through flattening reshapes.
tensorflow/compiler/tf2xla/kernels/beta_op.cc Pass expressions through BroadcastTo for Betainc inputs.
tensorflow/compiler/jit/xla_launch_util.h Extend PopulateOutputs to accept optional ExecutableRunOptions.
tensorflow/compiler/jit/xla_launch_util.cc Substitute batch-size into expressions when shaping TF outputs.
tensorflow/compiler/jit/xla_batch_matcher.h Add batch-size selection helper based on debug flag config.
tensorflow/compiler/jit/xla_batch_matcher.cc Implement parsing/selection of compilation batch sizes.
tensorflow/compiler/jit/shape_inference.cc Attach DynExpr-derived annotations to PartialTensorShape results.
tensorflow/compiler/jit/kernels/BUILD Add dependency on xla_batch_matcher library.
tensorflow/compiler/jit/flags.h Add flags for dynamic sizes and clustering behaviors.
tensorflow/compiler/jit/flags.cc Register and initialize new JIT flags.
tensorflow/compiler/jit/encapsulate_util.h Add constants for new inferred shape attributes.
tensorflow/compiler/jit/encapsulate_util.cc Define new inferred shape attribute names.
tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc Copy inferred output shapes to _Arg nodes; gate compilation for failing ops.
tensorflow/compiler/jit/device_compiler.h Instantiate XlaBatchMatcher when dynamic sizes are enabled.
tensorflow/compiler/jit/device_compilation_profiler.cc Make megamorphic threshold configurable via new flag.
tensorflow/compiler/jit/BUILD Add xla_batch_matcher target and deps.


Comment on lines +339 to +341

```cpp
llvm::Value* formatStr =
    builder.CreateGlobalStringPtr("Function: %s, Batch size is : %d!\n");
builder.CreateCall(printfFunc, {formatStr, funcNameStr, bdim_value});
```

Copilot AI Mar 20, 2026


Under PRINT_BATCHSIZE, the generated printf format string uses %d but bdim_value is an i64. This is a varargs type mismatch and is undefined behavior at runtime. Use an i64-appropriate format (e.g. %lld/%ld) or cast/truncate bdim_value to i32 before printing.

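As the comment suggests, the portable fix is a width-correct conversion specifier. A minimal C++ sketch of the equivalent host-side formatting (`FormatBatchSize` is an illustrative helper, not code from this PR):

```cpp
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <string>

// Format an int64_t batch size with a width-correct specifier. Passing an
// int64_t through varargs against a "%d" specifier is undefined behavior;
// PRId64 expands to the correct 64-bit specifier for the platform
// (e.g. "lld" or "ld"). Alternatively, static_cast to int and keep "%d".
std::string FormatBatchSize(int64_t bdim_value) {
  char buf[64];
  std::snprintf(buf, sizeof(buf), "Batch size is : %" PRId64 "!", bdim_value);
  return std::string(buf);
}
```

The same choice applies to the generated IR: either emit the `PRId64`-style format string, or truncate `bdim_value` to i32 before the call.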
Comment on lines +1213 to 1214

```proto
// Next id: 390
```

Copilot AI Mar 20, 2026


The comment // Next id: 390 is now inconsistent with the newly-added field xla_compile_batch_sizes = 399 (and there are fields with ids well above 390 later in the message). Update the Next id comment to the correct next available tag to avoid future field-number collisions during proto evolution.

Comment on lines +186 to +188

```cpp
void* dst = padded-> opaque();
void* src = param_mem.opaque();
```

Copilot AI Mar 20, 2026


This call has a syntax error: padded-> opaque() will not compile (member access has an extra space). It should be padded->opaque().

Owner Author


@copilot open a new pull request to apply changes based on this feedback

Comment on lines +21 to +26

```cpp
size_t ExecutableRunOptionsBatchSizeOffset() {
  auto ptr = get_offset(BatchSizeTag{});
  // Compute offset in bytes from null pointer.
  return reinterpret_cast<size_t>(
      &(reinterpret_cast<xla::ExecutableRunOptions*>(0)->*ptr));
}
```

Copilot AI Mar 20, 2026


ExecutableRunOptionsBatchSizeOffset() computes an offset by applying a pointer-to-member to a null pointer (reinterpret_cast<ExecutableRunOptions*>(0)->*ptr). This is undefined behavior in C++ and can break under different compilers/ABIs/optimizations. Prefer a supported mechanism to pass batch_size into generated code (e.g., via the already-extended XLA_CPU_KernelCallFrame), and remove this offset hack if possible.

Collaborator


@copilot You’re right, this hack is not great. I don’t think there is a safe small fix here, so I’d rather handle it in a separate follow-up where we pass the batch size in a cleaner way instead of reading it by offset.
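The cleaner route described here, threading the value through the kernel call frame the PR already extends, could be sketched as follows (illustrative struct and function names, not the actual XLA_CPU_KernelCallFrame definition):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative mirror of a kernel call frame carrying batch_size: the
// runtime writes the field before invoking the kernel, and generated code
// reads a well-defined struct member instead of probing a C++ object's
// layout with a pointer-to-member applied to a null pointer (which is
// undefined behavior and ABI-dependent).
struct KernelCallFrame {
  const void* const* args;  // kernel argument buffers
  int64_t batch_size;       // runtime batch size, set per invocation
};

int64_t ReadBatchSize(const KernelCallFrame* frame) {
  return frame->batch_size;
}
```

Since the struct layout is part of the kernel C API, the generated code can address `batch_size` by a fixed, well-defined field offset.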

Comment on lines +7 to +12

```cpp
XlaBatchMatcher::XlaBatchMatcher() {
  const std::string xla_compile_batch_sizes =
      xla::GetDebugOptionsFromFlags().xla_compile_batch_sizes();
  env_str_ = xla_compile_batch_sizes.c_str();
  parse_env_config();
}
```

Copilot AI Mar 20, 2026


env_str_ is set to xla_compile_batch_sizes.c_str(), but xla_compile_batch_sizes is a local std::string that is destroyed at the end of the constructor. This leaves env_str_ as a dangling pointer and makes parse_env_config() use-after-free. Store the string as a member (e.g., std::string env_str_) or copy it into owned storage before taking c_str().

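A minimal sketch of the ownership fix the comment suggests (simplified; `BatchSizeConfig` is an illustrative stand-in for the real class, which has more state):

```cpp
#include <cassert>
#include <string>
#include <utility>

// Own the configuration string as a member instead of keeping a raw
// pointer into a local std::string that is destroyed at the end of the
// constructor. Any later parsing can then safely read env_str_.c_str()
// for as long as the object lives.
class BatchSizeConfig {
 public:
  explicit BatchSizeConfig(std::string config)
      : env_str_(std::move(config)) {}  // owned copy, no dangling pointer

  const std::string& config() const { return env_str_; }

 private:
  std::string env_str_;  // owned storage, was: const char* env_str_
};
```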
Comment on lines +134 to 137

```cpp
auto sh = ShapeUtil::PermuteDimensions(*perm, operand_shape), operand;
operands.back() =
    computation.AddInstruction(HloInstruction::CreateBitcast(
        ShapeUtil::PermuteDimensions(*perm, operand_shape), operand));
computation.AddInstruction(HloInstruction::CreateBitcast(sh));
output_shapes.back() = ShapeUtil::PermuteDimensions(*perm, hlo->shape());
```

Copilot AI Mar 20, 2026


This declaration is invalid C++ (auto sh = ..., operand;) and also drops the operand from CreateBitcast, so it won't compile and the bitcast won't have an input value. Define sh separately and call HloInstruction::CreateBitcast(sh, operand) (using the existing operand variable from above).

Comment on lines +295 to 302

```cpp
auto dimensions = LayoutUtil::MinorToMajor(shape);
for (int i = 0; i < dimensions.size(); i++) {
  // The MinorToMajor order reverses dimensions...
  bool is_dynamic =
      shape.expressions(dimensions.size() - 1 - i)->is_dynamic();
  int64_t dim_val = is_dynamic ? 0 : shape.dimensions(dimensions[i]);
  result_type = llvm::ArrayType::get(result_type, dim_val);
}
```

Copilot AI Mar 20, 2026


ShapeToIrType iterates dimensions in LayoutUtil::MinorToMajor(shape) order, but checks dynamic-ness using shape.expressions(dimensions.size() - 1 - i). That index is not the dimension number being processed, so dynamic dimensions can be misidentified and the resulting IR type can be wrong. Use shape.expressions(dimensions[i]) (and keep it consistent with the shape.dimensions(dimensions[i]) lookup).

Collaborator


That is on purpose, because the expressions weren't permuted.

Comment on lines +569 to +579

```cpp
auto expressions = shape_.expressions();
bool dynamic_first_dim =
    expressions[0]->is_dynamic() &&
    std::all_of(expressions.begin() + 1, expressions.end(),
                [](DynExpr* e) { return e->is_constant(); });
if (!dynamic_first_dim && shape_.has_dynamic_expr()) {
  llvm::Type* element_type =
      PrimitiveTypeToIrType(shape_.element_type(), b->getContext());
  return llvm_ir::createDynamicGEP(
      b, base_ptr_, gep_indices, shape_.dimensions(), expressions,
      element_type, llvm_ir::AsStringRef(name));
```

Copilot AI Mar 20, 2026


The dynamic GEP path passes gep_indices in major-to-minor (layout) order (and includes the leading 0 index), but passes shape_.dimensions() and shape_.expressions() in logical dimension order to createDynamicGEP. These orderings don't match, so the linearization will compute incorrect element addresses for non-default layouts. Consider passing dims/expressions permuted to the same order as gep_indices (excluding the leading base index), or change createDynamicGEP to accept a layout and do the permutation internally.

Collaborator


Fair point. I'll have a look

Comment on lines 34 to 115

```
@@ -79,15 +80,37 @@ absl::StatusOr<XlaOp> BroadcastTo(XlaOp input,
   }
   TF_RET_CHECK(input_it == input_dims.rend());

   absl::Span<DynExpr* const> input_exprs = input_shape.expressions();
   std::vector<DynExpr*> broadcast_exprs;
   auto input_et = input_exprs.rbegin();
   for (auto output_et = output_exprs.rbegin(); output_et != output_exprs.rend();
        ++output_et) {
     if (input_et != input_exprs.rend()) {
       if (*(*output_et) == *(*input_et) ||
           (*input_et)->is_constant() && (*input_et)->get_val() == 1) {
         broadcast_exprs.push_back(*output_et);
       } else if (!(*(*output_et) == *(*input_et))) {
         broadcast_exprs.push_back(*input_et);
         broadcast_exprs.push_back((**output_et / **input_et)->s());
       }
       ++input_et;
     } else {
       broadcast_exprs.push_back(*output_et);
     }
   }

   absl::c_reverse(broadcast_dims);
   int broadcast_shape_size = broadcast_shape.size();
   for (int64_t& broadcast_dim : broadcast_dims) {
     broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
   }
   absl::c_reverse(broadcast_shape);
   XlaOp output = BroadcastInDim(input, broadcast_shape, broadcast_dims);
   absl::c_reverse(broadcast_exprs);

   XlaOp output =
       BroadcastInDim(input, broadcast_shape, broadcast_dims, broadcast_exprs);
   if (broadcast_shape != output_dims) {
     output = Reshape(output, output_dims);
     output = Reshape(output, output_dims, output_exprs);
   }
   return output;
```

Copilot AI Mar 20, 2026


BroadcastTo now takes output_exprs with a default of {}, but the implementation always iterates over output_exprs to build broadcast_exprs and then calls BroadcastInDim(..., broadcast_shape, ..., broadcast_exprs). If output_exprs is empty (existing call sites that don't pass expressions), broadcast_exprs will be empty while broadcast_shape is not, causing a rank mismatch. Also, when tiling (output_it != input_it && input_it != 1), broadcast_shape gains two dims, but the expr-building logic must add two expressions correspondingly; otherwise expression rank won't match the broadcasted shape.

Suggested fix: if output_exprs.empty(), synthesize constant expressions from output_dims (similar to ShapeUtil::MakeValidatedShape), and ensure broadcast_exprs is built to exactly match broadcast_shape (including the extra tiling dimension).

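The first part of the suggested fix, synthesizing constant expressions when the caller passes none, could look roughly like this (`DynExpr` here is a simplified stand-in for the PR's type, and `NormalizeOutputExprs` is a hypothetical helper):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-in for the PR's DynExpr: either a constant dimension
// or a symbolic one.
struct DynExpr {
  bool is_constant;
  int64_t value;  // meaningful only when is_constant
};

// If no output expressions were supplied (the default-{} call sites),
// build one constant expression per output dimension so that expression
// rank always matches shape rank in the downstream BroadcastInDim call.
std::vector<DynExpr> NormalizeOutputExprs(
    std::vector<DynExpr> output_exprs,
    const std::vector<int64_t>& output_dims) {
  if (!output_exprs.empty()) return output_exprs;
  output_exprs.reserve(output_dims.size());
  for (int64_t dim : output_dims) {
    output_exprs.push_back(DynExpr{/*is_constant=*/true, /*value=*/dim});
  }
  return output_exprs;
}
```

The tiling case still needs separate handling so that two expressions are appended whenever `broadcast_shape` gains two dimensions.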

Copilot AI commented Mar 21, 2026

@joeyye-work I've opened a new pull request, #61, to work on those changes. Once the pull request is ready, I'll request review from you.

Labels: None yet

Projects: None yet

4 participants