Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) {
}
case DimExpr::Kind::kVariable: {
auto* av = static_cast<const Variable*>(e);
return xla::DynExpr::V(1);
return xla::DynExpr::V(av->id());
}
case DimExpr::Kind::kAdd: {
auto* ee = static_cast<const ExprAdd*>(e);
Expand All @@ -448,6 +448,13 @@ static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) {
return nullptr;
}

static int64_t GetSingleDynamicVarId(const xla::DynExpr* expr) {
CHECK_NE(expr, nullptr);
auto ids = expr->get_all_ids();
CHECK_EQ(ids.size(), 1);
return ids.front();
}


absl::Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
Expand Down Expand Up @@ -685,10 +692,12 @@ absl::Status CompileToLocalExecutable(
for (int j = 0; j < shp.get_expressions().size(); ++j) {
auto e = shp.get_expression(j);
if (e->is_dynamic()) {
int64_t dynamic_var_id = GetSingleDynamicVarId(e);
int64_t old = shp.dim_size(j);
old_vars.push_back({i, j, old});
xla::DynExpr* padded_expr = xla::DynExpr::_(filled_batch);
xla::DynExpr* subst_expr = e->substitute(1, padded_expr)->s();
xla::DynExpr* subst_expr =
e->substitute(dynamic_var_id, padded_expr)->s();
int64_t new_dim = subst_expr->get_val();
if (new_dim >= 0) {
shp.set_dim(j, new_dim);
Expand Down
15 changes: 12 additions & 3 deletions tensorflow/compiler/jit/xla_launch_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

int64_t GetSingleDynamicVarId(const xla::DynExpr* expr) {
CHECK_NE(expr, nullptr);
auto ids = expr->get_all_ids();
CHECK_EQ(ids.size(), 1);
return ids.front();
}

// Fetch the platform Id from device.
se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) {
auto device = static_cast<Device*>(device_base);
Expand Down Expand Up @@ -444,10 +451,12 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs(
auto expr = subshape.expressions(dim);
if (expr != nullptr && expr->is_dynamic()) {
has_dynamic = true;
int64_t dynamic_var_id = GetSingleDynamicVarId(expr);
VLOG(1) << "Current expression is " << expr;
if (run_options) {
xla::DynExpr* batch_size = xla::DynExpr::_(run_options->batch_size());
xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s();
xla::DynExpr* subst_expr =
expr->substitute(dynamic_var_id, batch_size)->s();
shape.set_dim(dim, subst_expr->get_val());
} else {
// TODO: Fallback to BatchSizeResource for now. Remove it later.
Expand All @@ -457,8 +466,8 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs(
TF_RETURN_IF_ERROR(step_container->Lookup<BatchSizeResource>(
ctx->resource_manager(), BatchSizeResourceName, &bsr));
xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize());
// Just substitute Var(1) for now.
xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s();
xla::DynExpr* subst_expr =
expr->substitute(dynamic_var_id, batch_size)->s();
shape.set_dim(dim, subst_expr->get_val());
bsr->Unref();
}
Expand Down