From c1b5924338522ccc554559970efcc702f0ae83b3 Mon Sep 17 00:00:00 2001 From: Utku Saglam Date: Fri, 27 Mar 2026 00:54:12 +0800 Subject: [PATCH] Changed default substitution logic. --- tensorflow/compiler/jit/kernels/xla_ops.cc | 13 +++++++++++-- tensorflow/compiler/jit/xla_launch_util.cc | 15 ++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index e161bfc5a26fd5..f00492f70393b5 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -426,7 +426,7 @@ static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { } case DimExpr::Kind::kVariable: { auto* av = static_cast(e); - return xla::DynExpr::V(1); + return xla::DynExpr::V(av->id()); } case DimExpr::Kind::kAdd: { auto* ee = static_cast(e); @@ -448,6 +448,13 @@ static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { return nullptr; } +static int64_t GetSingleDynamicVarId(const xla::DynExpr* expr) { + CHECK_NE(expr, nullptr); + auto ids = expr->get_all_ids(); + CHECK_EQ(ids.size(), 1); + return ids.front(); +} + absl::Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, @@ -685,10 +692,12 @@ absl::Status CompileToLocalExecutable( for (int j = 0; j < shp.get_expressions().size(); ++j) { auto e = shp.get_expression(j); if (e->is_dynamic()) { + int64_t dynamic_var_id = GetSingleDynamicVarId(e); int64_t old = shp.dim_size(j); old_vars.push_back({i, j, old}); xla::DynExpr* padded_expr = xla::DynExpr::_(filled_batch); - xla::DynExpr* subst_expr = e->substitute(1, padded_expr)->s(); + xla::DynExpr* subst_expr = + e->substitute(dynamic_var_id, padded_expr)->s(); int64_t new_dim = subst_expr->get_val(); if (new_dim >= 0) { shp.set_dim(j, new_dim); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 06196e8d78f26c..ff6ee094338e40 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -68,6 +68,13 @@ namespace { using xla::ScopedShapedBuffer; using xla::ShapedBuffer; +int64_t GetSingleDynamicVarId(const xla::DynExpr* expr) { + CHECK_NE(expr, nullptr); + auto ids = expr->get_all_ids(); + CHECK_EQ(ids.size(), 1); + return ids.front(); +} + // Fetch the platform Id from device. se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) { auto device = static_cast(device_base); @@ -444,10 +451,12 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( auto expr = subshape.expressions(dim); if (expr != nullptr && expr->is_dynamic()) { has_dynamic = true; + int64_t dynamic_var_id = GetSingleDynamicVarId(expr); VLOG(1) << "Current expression is " << expr; if (run_options) { xla::DynExpr* batch_size = xla::DynExpr::_(run_options->batch_size()); - xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + xla::DynExpr* subst_expr = + expr->substitute(dynamic_var_id, batch_size)->s(); shape.set_dim(dim, subst_expr->get_val()); } else { // TODO: Fallback to BatchSizeResource for now. Remove it later. @@ -457,8 +466,8 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( TF_RETURN_IF_ERROR(step_container->Lookup( ctx->resource_manager(), BatchSizeResourceName, &bsr)); xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize()); - // Just substitute Var(1) for now. - xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + xla::DynExpr* subst_expr = + expr->substitute(dynamic_var_id, batch_size)->s(); shape.set_dim(dim, subst_expr->get_val()); bsr->Unref(); }