From ff12bfe1dc22e4228203de069e2397f697ebef7e Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Tue, 17 Mar 2026 23:42:00 +0000 Subject: [PATCH 1/5] Add comparison and select DynExpr operations --- .../xla/xla/service/llvm_ir/llvm_util.cc | 41 +++ third_party/xla/xla/shape.cc | 120 ++++++++ third_party/xla/xla/shape_dynexpr.h | 267 ++++++++++++++++++ third_party/xla/xla/xla_data.proto | 50 +++- 4 files changed, 477 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index fa27de28a84201..7ccb6ae878a2e3 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -913,6 +913,47 @@ llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { llvm::Value* v_rhs = EmitExpression(b, sub_node->get_rhs()); return b->CreateSub(v_lhs, v_rhs, "sub_dims"); } + if (GtExpr* gt_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, gt_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, gt_node->get_rhs()); + llvm::Value* pred = b->CreateICmpSGT(v_lhs, v_rhs, "gt_dims"); + return b->CreateZExt(pred, i64Type, "gt_dims_i64"); + } + if (EqExpr* eq_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, eq_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, eq_node->get_rhs()); + llvm::Value* pred = b->CreateICmpEQ(v_lhs, v_rhs, "eq_dims"); + return b->CreateZExt(pred, i64Type, "eq_dims_i64"); + } + if (NeExpr* ne_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, ne_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, ne_node->get_rhs()); + llvm::Value* pred = b->CreateICmpNE(v_lhs, v_rhs, "ne_dims"); + return b->CreateZExt(pred, i64Type, "ne_dims_i64"); + } + if (GeExpr* ge_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, ge_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, ge_node->get_rhs()); + llvm::Value* pred = b->CreateICmpSGE(v_lhs, v_rhs, "ge_dims"); + return b->CreateZExt(pred, i64Type, "ge_dims_i64"); + } + if (AbsExpr* abs_node = dynamic_cast(expr)) { + llvm::Value* value = EmitExpression(b, abs_node->get_expr()); + llvm::Value* is_negative = + b->CreateICmpSLT(value, llvm::ConstantInt::get(i64Type, 0, true), + "abs_dims_neg"); + llvm::Value* negated = b->CreateNeg(value, "abs_dims_negated"); + return b->CreateSelect(is_negative, negated, value, "abs_dims"); + } + if (SelectExpr* select_node = dynamic_cast(expr)) { + llvm::Value* pred = EmitExpression(b, select_node->get_pred()); + llvm::Value* on_true = EmitExpression(b, select_node->get_on_true()); + llvm::Value* on_false = EmitExpression(b, select_node->get_on_false()); + llvm::Value* cond = + b->CreateICmpNE(pred, llvm::ConstantInt::get(i64Type, 0, true), + "select_dims_pred"); + return b->CreateSelect(cond, on_true, on_false, "select_dims"); + } return nullptr; } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 962984b9225b6a..58a44c1ab7889b 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -69,6 +69,48 @@ DynExpr* ExprFromProto(const ExpressionProto& proto) { return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); } + case ExpressionProto::kLtNode: { + const auto& lt = proto.lt_node(); + return Gt(*ExprFromProto(lt.rhs()), *ExprFromProto(lt.lhs())); + } + + case ExpressionProto::kGtNode: { + const auto& gt = proto.gt_node(); + return Gt(*ExprFromProto(gt.lhs()), *ExprFromProto(gt.rhs())); + } + + case ExpressionProto::kSelectNode: { + const auto& select = proto.select_node(); + return Select(*ExprFromProto(select.pred()), + *ExprFromProto(select.on_true()), + *ExprFromProto(select.on_false())); + } + + case ExpressionProto::kEqNode: { + const auto& eq = proto.eq_node(); + return Eq(*ExprFromProto(eq.lhs()), *ExprFromProto(eq.rhs())); + } + + case ExpressionProto::kNeNode: { + const auto& ne = proto.ne_node(); + return Ne(*ExprFromProto(ne.lhs()), *ExprFromProto(ne.rhs())); + } + + case ExpressionProto::kLeNode: { + const auto& le = proto.le_node(); + return Ge(*ExprFromProto(le.rhs()), *ExprFromProto(le.lhs())); + } + + case ExpressionProto::kGeNode: { + const auto& ge = proto.ge_node(); + return Ge(*ExprFromProto(ge.lhs()), *ExprFromProto(ge.rhs())); + } + + case ExpressionProto::kAbsNode: { + const auto& abs = proto.abs_node(); + return Abs(*ExprFromProto(abs.expr())); + } + case ExpressionProto::NODE_TYPE_NOT_SET: default: return nullptr; @@ -91,6 +133,26 @@ DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { return new Sub(&lhs, &rhs); } DynExpr* operator-(DynExpr& lhs, int64_t d) { return new Sub(&lhs, DynExpr::_(d)); } +DynExpr* Gt(DynExpr& lhs, DynExpr& rhs) { return new GtExpr(&lhs, &rhs); } +DynExpr* Gt(DynExpr& lhs, int64_t rhs) { + return new GtExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Eq(DynExpr& lhs, DynExpr& rhs) { return new EqExpr(&lhs, &rhs); } +DynExpr* Eq(DynExpr& lhs, int64_t rhs) { + return new EqExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Ne(DynExpr& lhs, DynExpr& rhs) { return new NeExpr(&lhs, &rhs); } +DynExpr* Ne(DynExpr& lhs, int64_t rhs) { + return new NeExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Ge(DynExpr& lhs, DynExpr& rhs) { return new GeExpr(&lhs, &rhs); } +DynExpr* Ge(DynExpr& lhs, int64_t rhs) { + return new GeExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Abs(DynExpr& expr) { return new AbsExpr(&expr); } +DynExpr* Select(DynExpr& pred, DynExpr& on_true, DynExpr& on_false) { + return new SelectExpr(&pred, &on_true, &on_false); +} bool operator==(DynExpr& lhs, DynExpr& rhs) { return DynExpr::equal(&lhs, &rhs); } @@ -320,6 +382,64 @@ DynExpr* Div::s() { return *s_lhs / *s_rhs; } +DynExpr* GtExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::zero; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() > r->get_val()); + return Gt(*s_lhs, *s_rhs); +} + +DynExpr* SelectExpr::s() { + DynExpr* s_pred = get_pred()->s(); + DynExpr* s_true = get_on_true()->s(); + DynExpr* s_false = get_on_false()->s(); + Constant* p = dynamic_cast(s_pred); + if (p) return p->get_val() ? s_true : s_false; + if (*s_true == *s_false) return s_true; + return Select(*s_pred, *s_true, *s_false); +} + +DynExpr* EqExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::one; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() == r->get_val()); + return Eq(*s_lhs, *s_rhs); +} + +DynExpr* NeExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::zero; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() != r->get_val()); + return Ne(*s_lhs, *s_rhs); +} + +DynExpr* GeExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::one; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() >= r->get_val()); + return Ge(*s_lhs, *s_rhs); +} + +DynExpr* AbsExpr::s() { + DynExpr* s_expr = get_expr()->s(); + Constant* c = dynamic_cast(s_expr); + if (c) return DynExpr::_(c->get_val() < 0 ? -c->get_val() : c->get_val()); + if (AbsExpr* inner = dynamic_cast(s_expr)) return inner; + return Abs(*s_expr); +} + std::ostream& operator<<(std::ostream& os, DynExpr* expr) { ExpressionProto proto; expr->to_proto(&proto); diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index 5c1f5645c25e0e..6b0f1b0fc8d623 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -356,6 +356,263 @@ class Div : public DynExpr { } }; +class GtExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + GtExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" > "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* gt_msg = proto->mutable_gt_node(); + lhs->to_proto(gt_msg->mutable_lhs()); + rhs->to_proto(gt_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() > rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) override { + return new GtExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) override { return -1; } + DynExpr* s() override; + + ~GtExpr() { + delete lhs; + delete rhs; + } +}; + +class SelectExpr : public DynExpr { + DynExpr* pred; + DynExpr* on_true; + DynExpr* on_false; + + public: + SelectExpr(DynExpr* p, DynExpr* t, DynExpr* f) + : pred(std::move(p)), on_true(std::move(t)), on_false(std::move(f)) {} + void print(xla::Printer* printer) const override { + printer->Append("select("); + pred->print(printer); + printer->Append(", "); + on_true->print(printer); + printer->Append(", "); + on_false->print(printer); + printer->Append(")"); + } + + DynExpr* get_pred() const { return pred; } + DynExpr* get_on_true() const { return on_true; } + DynExpr* get_on_false() const { return on_false; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* select_msg = proto->mutable_select_node(); + pred->to_proto(select_msg->mutable_pred()); + on_true->to_proto(select_msg->mutable_on_true()); + on_false->to_proto(select_msg->mutable_on_false()); + } + + bool is_constant() const override { + return pred->is_constant() && on_true->is_constant() && + on_false->is_constant(); + } + + int64_t get_val() const override { + return pred->get_val() ? on_true->get_val() : on_false->get_val(); + } + + DynExpr* substitute(int id, DynExpr* v) override { + return new SelectExpr(pred->substitute(id, v), on_true->substitute(id, v), + on_false->substitute(id, v)); + } + + std::set get_all_ids() override { + auto s = pred->get_all_ids(); + s.merge(on_true->get_all_ids()); + s.merge(on_false->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) override { return -1; } + DynExpr* s() override; + + ~SelectExpr() { + delete pred; + delete on_true; + delete on_false; + } +}; + +class EqExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + EqExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" == "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_eq_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { return lhs->get_val() == rhs->get_val(); } + DynExpr* substitute(int id, DynExpr* v) override { + return new EqExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + int64_t solve(int64_t x) override { return -1; } + DynExpr* s() override; + ~EqExpr() { + delete lhs; + delete rhs; + } +}; + +class NeExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + NeExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" != "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_ne_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { return lhs->get_val() != rhs->get_val(); } + DynExpr* substitute(int id, DynExpr* v) override { + return new NeExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + int64_t solve(int64_t x) override { return -1; } + DynExpr* s() override; + ~NeExpr() { + delete lhs; + delete rhs; + } +}; + +class GeExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + GeExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" >= "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_ge_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { return lhs->get_val() >= rhs->get_val(); } + DynExpr* substitute(int id, DynExpr* v) override { + return new GeExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + int64_t solve(int64_t x) override { return -1; } + DynExpr* s() override; + ~GeExpr() { + delete lhs; + delete rhs; + } +}; + +class AbsExpr : public DynExpr { + DynExpr* expr; + + public: + explicit AbsExpr(DynExpr* e) : expr(std::move(e)) {} + void print(xla::Printer* printer) const override { + printer->Append("abs("); + expr->print(printer); + printer->Append(")"); + } + DynExpr* get_expr() const { return expr; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_abs_node(); + expr->to_proto(msg->mutable_expr()); + } + bool is_constant() const override { return expr->is_constant(); } + int64_t get_val() const override { + int64_t v = expr->get_val(); + return v < 0 ? -v : v; + } + DynExpr* substitute(int id, DynExpr* v) override { + return new AbsExpr(expr->substitute(id, v)); + } + std::set get_all_ids() override { return expr->get_all_ids(); } + int64_t solve(int64_t x) override { return -1; } + DynExpr* s() override; + ~AbsExpr() { delete expr; } +}; + DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); DynExpr* operator*(int64_t k, DynExpr& rhs); DynExpr* operator/(DynExpr& lhs, DynExpr& rhs); @@ -364,6 +621,16 @@ DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); DynExpr* operator+(DynExpr& lhs, int64_t d); DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); DynExpr* operator-(DynExpr& lhs, int64_t d); +DynExpr* Gt(DynExpr& lhs, DynExpr& rhs); +DynExpr* Gt(DynExpr& lhs, int64_t rhs); +DynExpr* Eq(DynExpr& lhs, DynExpr& rhs); +DynExpr* Eq(DynExpr& lhs, int64_t rhs); +DynExpr* Ne(DynExpr& lhs, DynExpr& rhs); +DynExpr* Ne(DynExpr& lhs, int64_t rhs); +DynExpr* Ge(DynExpr& lhs, DynExpr& rhs); +DynExpr* Ge(DynExpr& lhs, int64_t rhs); +DynExpr* Abs(DynExpr& expr); +DynExpr* Select(DynExpr& pred, DynExpr& on_true, DynExpr& on_false); bool operator==(DynExpr& lhs, DynExpr& rhs); bool operator==(DynExpr& lhs, int64_t d); diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 798dc12fb5734e..518cd4fe5e39b6 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -1215,6 +1215,14 @@ message ExpressionProto { SubNode sub_node = 4; // exp - exp MulNode mul_node = 5; // exp * exp DivNode div_node = 6; // exp / exp + LtNode lt_node = 7; // exp < exp + GtNode gt_node = 8; // exp > exp + SelectNode select_node = 9; // select(pred, on_true, on_false) + EqNode eq_node = 10; // exp == exp + NeNode ne_node = 11; // exp != exp + LeNode le_node = 12; // exp <= exp + GeNode ge_node = 13; // exp >= exp + AbsNode abs_node = 14; // abs(exp) } } @@ -1236,4 +1244,44 @@ message MulNode { message DivNode { ExpressionProto lhs = 1; ExpressionProto rhs = 2; -} \ No newline at end of file +} + +message LtNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message GtNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message SelectNode { + ExpressionProto pred = 1; + ExpressionProto on_true = 2; + ExpressionProto on_false = 3; +} + +message EqNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message NeNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message LeNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message GeNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message AbsNode { + ExpressionProto expr = 1; +} From 912a34b8757cadccf41b68ca8c1001ba73af4b36 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 18 Mar 2026 00:12:20 +0000 Subject: [PATCH 2/5] Add Max DynExpr support --- .../tf2xla/kernels/strided_slice_op.cc | 9 ++- .../xla/xla/service/llvm_ir/llvm_util.cc | 28 +++++---- third_party/xla/xla/shape.cc | 61 ++++++++++++++----- third_party/xla/xla/shape_dynexpr.h | 58 ++++++++++++++++++ third_party/xla/xla/xla_data.proto | 6 ++ 5 files changed, 128 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index df7deaaf80d3bf..8d56835cace86c 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -313,8 +313,7 @@ class StridedSliceOp : public XlaOpKernel { slice_begin.push_back(begin[i]); slice_begin_expr.push_back(begin_expr[i]); slice_end.push_back(std::max(end[i], begin[i])); - slice_end_expr.push_back((end[i] > begin[i]) ? end_expr[i] - : begin_expr[i]); + slice_end_expr.push_back(xla::dexpr::Max(*end_expr[i], *begin_expr[i])->s()); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval @@ -326,9 +325,9 @@ class StridedSliceOp : public XlaOpKernel { slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, input_shape.dim_size(i) - begin[i] - 1)); slice_end_expr.push_back( - (end[i] < begin[i]) - ? (*input_exprs[i] - *end_expr[i] - 1)->s() - : (*input_exprs[i] - *begin_expr[i] - 1)->s()); + xla::dexpr::Max(*(*input_exprs[i] - *end_expr[i] - 1)->s(), + *(*input_exprs[i] - *begin_expr[i] - 1)->s()) + ->s()); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 7ccb6ae878a2e3..2451b4f5a10b06 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -897,10 +897,12 @@ llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { llvm::Value* v_rhs = EmitExpression(b, mul_node->get_rhs()); return b->CreateMul(v_lhs, v_rhs, "mul_dims"); } - // TODO: Check if this should ever happen if (Div* div_node = dynamic_cast(expr)) { llvm::Value* v_lhs = EmitExpression(b, div_node->get_lhs()); llvm::Value* v_rhs = EmitExpression(b, div_node->get_rhs()); + if (auto* rhs_const = llvm::dyn_cast(v_rhs)) { + CHECK_NE(rhs_const->getSExtValue(), 0); + } return b->CreateUDiv(v_lhs, v_rhs, "div_dims"); } if (Add* add_node = dynamic_cast(expr)) { @@ -916,26 +918,28 @@ llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { if (GtExpr* gt_node = dynamic_cast(expr)) { llvm::Value* v_lhs = EmitExpression(b, gt_node->get_lhs()); llvm::Value* v_rhs = EmitExpression(b, gt_node->get_rhs()); - llvm::Value* pred = b->CreateICmpSGT(v_lhs, v_rhs, "gt_dims"); - return b->CreateZExt(pred, i64Type, "gt_dims_i64"); + return b->CreateICmpSGT(v_lhs, v_rhs, "gt_dims"); } if (EqExpr* eq_node = dynamic_cast(expr)) { llvm::Value* v_lhs = EmitExpression(b, eq_node->get_lhs()); llvm::Value* v_rhs = EmitExpression(b, eq_node->get_rhs()); - llvm::Value* pred = b->CreateICmpEQ(v_lhs, v_rhs, "eq_dims"); - return b->CreateZExt(pred, i64Type, "eq_dims_i64"); + return b->CreateICmpEQ(v_lhs, v_rhs, "eq_dims"); } if (NeExpr* ne_node = dynamic_cast(expr)) { llvm::Value* v_lhs = EmitExpression(b, ne_node->get_lhs()); llvm::Value* v_rhs = EmitExpression(b, ne_node->get_rhs()); - llvm::Value* pred = b->CreateICmpNE(v_lhs, v_rhs, "ne_dims"); - return b->CreateZExt(pred, i64Type, "ne_dims_i64"); + return b->CreateICmpNE(v_lhs, v_rhs, "ne_dims"); } if (GeExpr* ge_node = dynamic_cast(expr)) { llvm::Value* v_lhs = EmitExpression(b, ge_node->get_lhs()); llvm::Value* v_rhs = EmitExpression(b, ge_node->get_rhs()); - llvm::Value* pred = b->CreateICmpSGE(v_lhs, v_rhs, "ge_dims"); - return b->CreateZExt(pred, i64Type, "ge_dims_i64"); + return b->CreateICmpSGE(v_lhs, v_rhs, "ge_dims"); + } + if (MaxExpr* max_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, max_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, max_node->get_rhs()); + llvm::Value* pred = b->CreateICmpSGT(v_lhs, v_rhs, "max_dims_pred"); + return b->CreateSelect(pred, v_lhs, v_rhs, "max_dims"); } if (AbsExpr* abs_node = dynamic_cast(expr)) { llvm::Value* value = EmitExpression(b, abs_node->get_expr()); @@ -949,10 +953,8 @@ llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { llvm::Value* pred = EmitExpression(b, select_node->get_pred()); llvm::Value* on_true = EmitExpression(b, select_node->get_on_true()); llvm::Value* on_false = EmitExpression(b, select_node->get_on_false()); - llvm::Value* cond = - b->CreateICmpNE(pred, llvm::ConstantInt::get(i64Type, 0, true), - "select_dims_pred"); - return b->CreateSelect(cond, on_true, on_false, "select_dims"); + CHECK(pred->getType()->isIntegerTy(1)); + return b->CreateSelect(pred, on_true, on_false, "select_dims"); } return nullptr; } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 58a44c1ab7889b..0a8fbf3a186353 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -71,44 +71,49 @@ DynExpr* ExprFromProto(const ExpressionProto& proto) { case ExpressionProto::kLtNode: { const auto& lt = proto.lt_node(); - return Gt(*ExprFromProto(lt.rhs()), *ExprFromProto(lt.lhs())); + return dexpr::Gt(*ExprFromProto(lt.rhs()), *ExprFromProto(lt.lhs())); } case ExpressionProto::kGtNode: { const auto& gt = proto.gt_node(); - return Gt(*ExprFromProto(gt.lhs()), *ExprFromProto(gt.rhs())); + return dexpr::Gt(*ExprFromProto(gt.lhs()), *ExprFromProto(gt.rhs())); } case ExpressionProto::kSelectNode: { const auto& select = proto.select_node(); - return Select(*ExprFromProto(select.pred()), - *ExprFromProto(select.on_true()), - *ExprFromProto(select.on_false())); + return dexpr::Select(*ExprFromProto(select.pred()), + *ExprFromProto(select.on_true()), + *ExprFromProto(select.on_false())); } case ExpressionProto::kEqNode: { const auto& eq = proto.eq_node(); - return Eq(*ExprFromProto(eq.lhs()), *ExprFromProto(eq.rhs())); + return dexpr::Eq(*ExprFromProto(eq.lhs()), *ExprFromProto(eq.rhs())); } case ExpressionProto::kNeNode: { const auto& ne = proto.ne_node(); - return Ne(*ExprFromProto(ne.lhs()), *ExprFromProto(ne.rhs())); + return dexpr::Ne(*ExprFromProto(ne.lhs()), *ExprFromProto(ne.rhs())); } case ExpressionProto::kLeNode: { const auto& le = proto.le_node(); - return Ge(*ExprFromProto(le.rhs()), *ExprFromProto(le.lhs())); + return dexpr::Ge(*ExprFromProto(le.rhs()), *ExprFromProto(le.lhs())); } case ExpressionProto::kGeNode: { const auto& ge = proto.ge_node(); - return Ge(*ExprFromProto(ge.lhs()), *ExprFromProto(ge.rhs())); + return dexpr::Ge(*ExprFromProto(ge.lhs()), *ExprFromProto(ge.rhs())); } case ExpressionProto::kAbsNode: { const auto& abs = proto.abs_node(); - return Abs(*ExprFromProto(abs.expr())); + return dexpr::Abs(*ExprFromProto(abs.expr())); + } + + case ExpressionProto::kMaxNode: { + const auto& max = proto.max_node(); + return dexpr::Max(*ExprFromProto(max.lhs()), *ExprFromProto(max.rhs())); } case ExpressionProto::NODE_TYPE_NOT_SET: @@ -133,6 +138,15 @@ DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { return new Sub(&lhs, &rhs); } DynExpr* operator-(DynExpr& lhs, int64_t d) { return new Sub(&lhs, DynExpr::_(d)); } +namespace dexpr { +DynExpr* Mul(DynExpr& lhs, DynExpr& rhs) { return new xla::Mul(&lhs, &rhs); } +DynExpr* Mul(int64_t lhs, DynExpr& rhs) { return new xla::Mul(DynExpr::_(lhs), &rhs); } +DynExpr* Div(DynExpr& lhs, DynExpr& rhs) { return new xla::Div(&lhs, &rhs); } +DynExpr* Div(DynExpr& lhs, int64_t rhs) { return new xla::Div(&lhs, DynExpr::_(rhs)); } +DynExpr* Add(DynExpr& lhs, DynExpr& rhs) { return new xla::Add(&lhs, &rhs); } +DynExpr* Add(DynExpr& lhs, int64_t rhs) { return new xla::Add(&lhs, DynExpr::_(rhs)); } +DynExpr* Sub(DynExpr& lhs, DynExpr& rhs) { return new xla::Sub(&lhs, &rhs); } +DynExpr* Sub(DynExpr& lhs, int64_t rhs) { return new xla::Sub(&lhs, DynExpr::_(rhs)); } DynExpr* Gt(DynExpr& lhs, DynExpr& rhs) { return new GtExpr(&lhs, &rhs); } DynExpr* Gt(DynExpr& lhs, int64_t rhs) { return new GtExpr(&lhs, DynExpr::_(rhs)); @@ -149,10 +163,15 @@ DynExpr* Ge(DynExpr& lhs, DynExpr& rhs) { return new GeExpr(&lhs, &rhs); } DynExpr* Ge(DynExpr& lhs, int64_t rhs) { return new GeExpr(&lhs, DynExpr::_(rhs)); } +DynExpr* Max(DynExpr& lhs, DynExpr& rhs) { return new MaxExpr(&lhs, &rhs); } +DynExpr* Max(DynExpr& lhs, int64_t rhs) { + return new MaxExpr(&lhs, DynExpr::_(rhs)); +} DynExpr* Abs(DynExpr& expr) { return new AbsExpr(&expr); } DynExpr* Select(DynExpr& pred, DynExpr& on_true, DynExpr& on_false) { return new SelectExpr(&pred, &on_true, &on_false); } +} // namespace dexpr bool operator==(DynExpr& lhs, DynExpr& rhs) { return DynExpr::equal(&lhs, &rhs); } @@ -389,7 +408,7 @@ DynExpr* GtExpr::s() { Constant* l = dynamic_cast(s_lhs); Constant* r = dynamic_cast(s_rhs); if (l && r) return DynExpr::_(l->get_val() > r->get_val()); - return Gt(*s_lhs, *s_rhs); + return dexpr::Gt(*s_lhs, *s_rhs); } DynExpr* SelectExpr::s() { @@ -399,7 +418,7 @@ DynExpr* SelectExpr::s() { Constant* p = dynamic_cast(s_pred); if (p) return p->get_val() ? s_true : s_false; if (*s_true == *s_false) return s_true; - return Select(*s_pred, *s_true, *s_false); + return dexpr::Select(*s_pred, *s_true, *s_false); } DynExpr* EqExpr::s() { @@ -409,7 +428,7 @@ DynExpr* EqExpr::s() { Constant* l = dynamic_cast(s_lhs); Constant* r = dynamic_cast(s_rhs); if (l && r) return DynExpr::_(l->get_val() == r->get_val()); - return Eq(*s_lhs, *s_rhs); + return dexpr::Eq(*s_lhs, *s_rhs); } DynExpr* NeExpr::s() { @@ -419,7 +438,7 @@ DynExpr* NeExpr::s() { Constant* l = dynamic_cast(s_lhs); Constant* r = dynamic_cast(s_rhs); if (l && r) return DynExpr::_(l->get_val() != r->get_val()); - return Ne(*s_lhs, *s_rhs); + return dexpr::Ne(*s_lhs, *s_rhs); } DynExpr* GeExpr::s() { @@ -429,7 +448,17 @@ DynExpr* GeExpr::s() { Constant* l = dynamic_cast(s_lhs); Constant* r = dynamic_cast(s_rhs); if (l && r) return DynExpr::_(l->get_val() >= r->get_val()); - return Ge(*s_lhs, *s_rhs); + return dexpr::Ge(*s_lhs, *s_rhs); +} + +DynExpr* MaxExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return s_lhs; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(std::max(l->get_val(), r->get_val())); + return dexpr::Max(*s_lhs, *s_rhs); } DynExpr* AbsExpr::s() { @@ -437,7 +466,7 @@ DynExpr* AbsExpr::s() { Constant* c = dynamic_cast(s_expr); if (c) return DynExpr::_(c->get_val() < 0 ? -c->get_val() : c->get_val()); if (AbsExpr* inner = dynamic_cast(s_expr)) return inner; - return Abs(*s_expr); + return dexpr::Abs(*s_expr); } std::ostream& operator<<(std::ostream& os, DynExpr* expr) { diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index 6b0f1b0fc8d623..aa93ac6bd3fe97 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SHAPE_DYNEXPR_H_ #define XLA_SHAPE_DYNEXPR_H_ +#include #include #include #include @@ -584,6 +585,51 @@ class GeExpr : public DynExpr { } }; +class MaxExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + MaxExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("max("); + lhs->print(printer); + printer->Append(", "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_max_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { + return std::max(lhs->get_val(), rhs->get_val()); + } + DynExpr* substitute(int id, DynExpr* v) override { + return new MaxExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + std::optional solve(int64_t x) override { + // max(A, c) = x is not uniquely invertible. + return std::nullopt; + } + DynExpr* s() override; + ~MaxExpr() { + delete lhs; + delete rhs; + } +}; + class AbsExpr : public DynExpr { DynExpr* expr; @@ -621,6 +667,15 @@ DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); DynExpr* operator+(DynExpr& lhs, int64_t d); DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); DynExpr* operator-(DynExpr& lhs, int64_t d); +namespace dexpr { +DynExpr* Mul(DynExpr& lhs, DynExpr& rhs); +DynExpr* Mul(int64_t lhs, DynExpr& rhs); +DynExpr* Div(DynExpr& lhs, DynExpr& rhs); +DynExpr* Div(DynExpr& lhs, int64_t rhs); +DynExpr* Add(DynExpr& lhs, DynExpr& rhs); +DynExpr* Add(DynExpr& lhs, int64_t rhs); +DynExpr* Sub(DynExpr& lhs, DynExpr& rhs); +DynExpr* Sub(DynExpr& lhs, int64_t rhs); DynExpr* Gt(DynExpr& lhs, DynExpr& rhs); DynExpr* Gt(DynExpr& lhs, int64_t rhs); DynExpr* Eq(DynExpr& lhs, DynExpr& rhs); @@ -629,8 +684,11 @@ DynExpr* Ne(DynExpr& lhs, DynExpr& rhs); DynExpr* Ne(DynExpr& lhs, int64_t rhs); DynExpr* Ge(DynExpr& lhs, DynExpr& rhs); DynExpr* Ge(DynExpr& lhs, int64_t rhs); +DynExpr* Max(DynExpr& lhs, DynExpr& rhs); +DynExpr* Max(DynExpr& lhs, int64_t rhs); DynExpr* Abs(DynExpr& expr); DynExpr* Select(DynExpr& pred, DynExpr& on_true, DynExpr& on_false); +} // namespace dexpr bool operator==(DynExpr& lhs, DynExpr& rhs); bool operator==(DynExpr& lhs, int64_t d); diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 518cd4fe5e39b6..9344f490b983b3 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -1223,6 +1223,7 @@ message ExpressionProto { LeNode le_node = 12; // exp <= exp GeNode ge_node = 13; // exp >= exp AbsNode abs_node = 14; // abs(exp) + MaxNode max_node = 15; // max(exp, exp) } } @@ -1285,3 +1286,8 @@ message GeNode { message AbsNode { ExpressionProto expr = 1; } + +message MaxNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} From fa2f4cf98d44fc833911c56a5c8d8fb791303367 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Tue, 17 Mar 2026 23:53:43 +0000 Subject: [PATCH 3/5] Extend DynExpr solving and equality --- tensorflow/compiler/jit/kernels/xla_ops.cc | 11 ++- third_party/xla/xla/shape.cc | 88 ++++++++++++----- third_party/xla/xla/shape_dynexpr.h | 108 ++++++++++++++++----- 3 files changed, 156 insertions(+), 51 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index e161bfc5a26fd5..1099fd555113f2 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -1272,14 +1272,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(1) << "input shape is " << ctx->input(input_idx).shape() << ", corresponding xla input shape is " << xla_shape; int64_t size = ctx->input(input_idx).shape().dim_size(dim); - int64_t dyn_val = expr->solve(size); // TODO: check if the result is correct later. + std::optional dyn_val = + expr->solve(size); // TODO: check if the result is correct later. VLOG(1) << "Found dynamic input. Real size is: " << size - << ", solved dynamic value is " << dyn_val; - if (dyn_val == -1) { + << ", solved dynamic value is " + << (dyn_val.has_value() ? std::to_string(*dyn_val) + : std::string("")); + if (!dyn_val.has_value()) { VLOG(1) << "Warning: Failed to solve the expression"; continue; } - dyn_vals.insert(dyn_val); + dyn_vals.insert(*dyn_val); } } } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 0a8fbf3a186353..d1286cd058c1b8 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -186,50 +186,92 @@ bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { auto e1 = expr1->s(); auto e2 = expr2->s(); if (e1 == nullptr || e2 == nullptr) return false; - Constant* c1 = dynamic_cast(e1); - Constant* c2 = dynamic_cast(e2); - if (c1 && c2) return c1->get_val() == c2->get_val(); + + auto ordered_binary_equal = [&](auto* lhs_node, auto* rhs_node) { + return *lhs_node->get_lhs() == *rhs_node->get_lhs() && + *lhs_node->get_rhs() == *rhs_node->get_rhs(); + }; + auto unordered_binary_equal = [&](auto* lhs_node, auto* rhs_node) { + return (*lhs_node->get_lhs() == *rhs_node->get_lhs() && + *lhs_node->get_rhs() == *rhs_node->get_rhs()) || + (*lhs_node->get_lhs() == *rhs_node->get_rhs() && + *lhs_node->get_rhs() == *rhs_node->get_lhs()); + }; + auto constant_equal = dynamic_cast(e1); + auto other_constant = dynamic_cast(e2); + if (constant_equal && other_constant) { + return constant_equal->get_val() == other_constant->get_val(); + } + // Var x = Var y <=> x = y if (Variable* varx = dynamic_cast(e1), *vary = dynamic_cast(e2); varx && vary) { return varx->get_id() == vary->get_id(); } + // a * b = c * d <=> (a = c /\ b = d) \/ (a = d /\ b = c) if (Mul* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto a = ab->get_lhs(); - auto b = ab->get_rhs(); - auto c = cd->get_lhs(); - auto d = cd->get_rhs(); - return (*a == *c && *b == *d) || (*a == *d && *b == *c); + return unordered_binary_equal(ab, cd); } // a / b = c / d <=> (a = c /\ b = d) if (Div* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto a = ab->get_lhs(); - auto b = ab->get_rhs(); - auto c = cd->get_lhs(); - auto d = cd->get_rhs(); - return *a == *c && *b == *d; + return ordered_binary_equal(ab, cd); } // a + b = c + d <=> (a = c /\ b = d) \/ (a = d /\ b = c) if (Add* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto a = ab->get_lhs(); - auto b = ab->get_rhs(); - auto c = cd->get_lhs(); - auto d = cd->get_rhs(); - return (*a == *c && *b == *d) || (*a == *d && *b == *c); + return unordered_binary_equal(ab, cd); } // a - b = c - d <=> (a = c /\ b = d) if (Sub* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto* a = ab->get_lhs(); - auto* b = ab->get_rhs(); - auto* c = cd->get_lhs(); - auto* d = cd->get_rhs(); - return *a == *c && *b == *d; + return ordered_binary_equal(ab, cd); + } + // a < b = c < d <=> (a = c /\ b = d) + if (LtExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // a > b = c > d <=> (a = c /\ b = d) + if (GtExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // a == b = c == d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (EqExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return unordered_binary_equal(ab, cd); + } + // a != b = c != d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (NeExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return unordered_binary_equal(ab, cd); + } + // a <= b = c <= d <=> (a = c /\ b = d) + if (LeExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // a >= b = c >= d <=> (a = c /\ b = d) + if (GeExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // abs(a) = abs(b) <=> a = b + if (AbsExpr* a = dynamic_cast(e1), *b = dynamic_cast(e2); + a && b) { + return *a->get_expr() == *b->get_expr(); + } + // select(p, a, b) = select(q, c, d) <=> p = q /\ a = c /\ b = d + if (SelectExpr* pab = dynamic_cast(e1), + *qcd = dynamic_cast(e2); + pab && qcd) { + return *pab->get_pred() == *qcd->get_pred() && + *pab->get_on_true() == *qcd->get_on_true() && + *pab->get_on_false() == *qcd->get_on_false(); } return false; } diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index aa93ac6bd3fe97..2bafc56a0153ff 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "xla/printer.h" @@ -36,7 +37,7 @@ class DynExpr { virtual DynExpr* s() = 0; // simplify virtual DynExpr* substitute(int id, DynExpr* v) = 0; virtual std::set get_all_ids() = 0; - virtual int64_t solve(int64_t x) = 0; + virtual std::optional solve(int64_t x) = 0; bool is_dynamic() { return !is_constant(); } @@ -72,7 +73,7 @@ class Constant : public DynExpr { int64_t get_val() const override { return value; } DynExpr* substitute(int id, DynExpr* v) { return this; } std::set get_all_ids() { return {}; } - int64_t solve(int64_t x) { return -1; } + std::optional solve(int64_t x) { return std::nullopt; } DynExpr* s() override; }; @@ -95,7 +96,7 @@ class Variable : public DynExpr { int get_id() const { return id; } DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} std::set get_all_ids() { return {get_id()}; } - int64_t solve(int64_t x) { return x; } + std::optional solve(int64_t x) { return x; } DynExpr* s() override; }; @@ -139,9 +140,9 @@ class Add : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A + c) = x <=> A = x - c => solve A = y with y = x - c return lhs->solve(x - rhs->get_val()); @@ -151,7 +152,7 @@ class Add : public DynExpr { return rhs->solve(x - lhs->get_val()); } // No solution - return -1; + return std::nullopt; } DynExpr* s() override; @@ -202,9 +203,9 @@ class Sub : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A - c) = x <=> A = x + c => solve A = y with y = x + c return lhs->solve(x + rhs->get_val()); @@ -214,7 +215,7 @@ class Sub : public DynExpr { return rhs->solve(x + lhs->get_val()); } // No solution - return -1; + return std::nullopt; } DynExpr* s() override; @@ -265,23 +266,23 @@ class Mul : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A * c) = x <=> A = x / c => solve A = y with y = x / c int64_t c = rhs->get_val(); - if (x % c != 0) return -1; + if (x % c != 0) return std::nullopt; return lhs->solve(x / c); } if (rhs->get_all_ids().size() == 1) { // (c * A) = x <=> A = x / c => solve A = y with y = x / c int64_t c = lhs->get_val(); - if (x % c != 0) return -1; + if (x % c != 0) return std::nullopt; return rhs->solve(x / c); } // No solution - return -1; + return std::nullopt; } DynExpr* s() override; @@ -334,9 +335,9 @@ class Div : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A / c) = x <=> A = x * c => solve A = y with y = x * c return lhs->solve(x * rhs->get_val()); @@ -344,11 +345,11 @@ class Div : public DynExpr { if (rhs->get_all_ids().size() == 1) { // (c / A) = x <=> A = c / x => solve A = y with y = c / x int64_t c = lhs->get_val(); - if (c % x != 0) return -1; + if (c % x != 0) return std::nullopt; return rhs->solve(c / x); } // No solution - return -1; + return std::nullopt; } ~Div() { @@ -396,7 +397,10 @@ class GtExpr : public DynExpr { return s; } - int64_t solve(int64_t x) override { return -1; } + std::optional solve(int64_t x) override { + // (A > c) = x is not uniquely invertible. + return std::nullopt; + } DynExpr* s() override; ~GtExpr() { @@ -455,7 +459,25 @@ class SelectExpr : public DynExpr { return s; } - int64_t solve(int64_t x) override { return -1; } + std::optional solve(int64_t x) override { + // select(c, A, B) = x <=> solve selected branch = x + if (pred->is_constant()) { + return pred->get_val() ? on_true->solve(x) : on_false->solve(x); + } + // select(p, A, A) = x <=> A = x + if (DynExpr::equal(on_true, on_false)) { + return on_true->solve(x); + } + if (on_true->is_constant() && on_false->is_constant()) { + const bool matches_true = on_true->get_val() == x; + const bool matches_false = on_false->get_val() == x; + // select(p, x, c) = x <=> p = 1 + if (matches_true && !matches_false) return pred->solve(1); + // select(p, c, x) = x <=> p = 0 + if (!matches_true && matches_false) return pred->solve(0); + } + return std::nullopt; + } DynExpr* s() override; ~SelectExpr() { @@ -497,7 +519,20 @@ class EqExpr : public DynExpr { s.merge(rhs->get_all_ids()); return s; } - int64_t solve(int64_t x) override { return -1; } + std::optional solve(int64_t x) override { + if (x != 0 && x != 1) return std::nullopt; + if (x == 1) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { + // (A == c) = 1 <=> A = c => solve A = y with y = c + return lhs->solve(rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { + // (c == A) = 1 <=> A = c => solve A = y with y = c + return rhs->solve(lhs->get_val()); + } + } + return std::nullopt; + } DynExpr* s() override; ~EqExpr() { delete lhs; @@ -537,7 +572,20 @@ class NeExpr : public DynExpr { s.merge(rhs->get_all_ids()); return s; } - int64_t solve(int64_t x) override { return -1; } + std::optional solve(int64_t x) override { + if (x != 0 && x != 1) return std::nullopt; + if (x == 0) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { + // (A != c) = 0 <=> A = c => solve A = y with y = c + return lhs->solve(rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { + // (c != A) = 0 <=> A = c => solve A = y with y = c + return rhs->solve(lhs->get_val()); + } + } + return std::nullopt; + } DynExpr* s() override; ~NeExpr() { delete lhs; @@ -577,7 +625,10 @@ class GeExpr : public DynExpr { s.merge(rhs->get_all_ids()); return s; } - int64_t solve(int64_t x) override { return -1; } + std::optional solve(int64_t x) override { + // (A >= c) = x is not uniquely invertible. + return std::nullopt; + } DynExpr* s() override; ~GeExpr() { delete lhs; @@ -654,7 +705,16 @@ class AbsExpr : public DynExpr { return new AbsExpr(expr->substitute(id, v)); } std::set get_all_ids() override { return expr->get_all_ids(); } - int64_t solve(int64_t x) override { return -1; } + std::optional solve(int64_t x) override { + if (x < 0) return std::nullopt; + // abs(A) = x <=> A = x or A = -x. Solve only if one branch works or both + // branches agree on the same solution. + std::optional pos = expr->solve(x); + std::optional neg = expr->solve(-x); + if (!pos.has_value()) return neg; + if (!neg.has_value()) return pos; + return *pos == *neg ? pos : std::nullopt; + } DynExpr* s() override; ~AbsExpr() { delete expr; } }; From b52597d39e21bf9630b92df36b7ccb10a19c55f1 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Tue, 17 Mar 2026 23:42:00 +0000 Subject: [PATCH 4/5] Use conditional DynExprs in strided slice --- .../tf2xla/kernels/strided_slice_op.cc | 3 +- tensorflow/core/util/strided_slice_op.cc | 34 +++++++++++-------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 8d56835cace86c..099eaba008244e 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -313,7 +313,8 @@ class StridedSliceOp : public XlaOpKernel { slice_begin.push_back(begin[i]); slice_begin_expr.push_back(begin_expr[i]); slice_end.push_back(std::max(end[i], begin[i])); - slice_end_expr.push_back(xla::dexpr::Max(*end_expr[i], *begin_expr[i])->s()); + slice_end_expr.push_back( + xla::dexpr::Max(*end_expr[i], *begin_expr[i])->s()); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 29c5d56efec504..40e7071915d4a3 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -371,21 +371,22 @@ absl::Status ValidateStridedSliceOp( : x_fwd; } }; - auto canonical_expr = [stride_i, dim_i, masks, valid_range, - valid_range_expr, dim_i_expr](int64_t x, int c) { + auto canonical_expr = [stride_i, masks, valid_range_expr, + dim_i_expr](xla::DynExpr* x_expr, int c) { if (masks[c]) { return stride_i > 0 ? valid_range_expr[c] : valid_range_expr[(c + 1) & 1]; } else { - int64_t x_fwd = - x < 0 ? dim_i + x : x; // make negative indices positive - xla::DynExpr* x_expr = xla::DynExpr::_(x); + xla::DynExpr* wrapped_x_expr = (*dim_i_expr + *x_expr)->s(); xla::DynExpr* x_fwd_expr = - x < 0 ? (*dim_i_expr + *x_expr) - : x_expr; // make negative indices positive - return x_fwd < valid_range[0] ? valid_range_expr[0] - : x_fwd > valid_range[1] ? valid_range_expr[1] - : x_fwd_expr; + xla::Select(*xla::Lt(*x_expr, 0), *wrapped_x_expr, *x_expr)->s(); + xla::DynExpr* low_clamped_expr = + xla::Select(*xla::Lt(*x_fwd_expr, *valid_range_expr[0]), + *valid_range_expr[0], *x_fwd_expr) + ->s(); + return xla::Select(*xla::Gt(*low_clamped_expr, *valid_range_expr[1]), + *valid_range_expr[1], *low_clamped_expr) + ->s(); } }; if (shrink_i && stride_i <= 0) { @@ -403,9 +404,12 @@ absl::Status ValidateStridedSliceOp( // and canonical puts these to n-1 and 0, which implies a degenerate // interval. Fortunately, it is now safe to re-create end as begin+1. int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; - xla::DynExpr* x_fwd_expr = begin_i < 0 - ? (*dim_i_expr + *(*begin_expr)[i])->s() - : (*begin_expr)[i]; + xla::DynExpr* wrapped_x_expr = + (*dim_i_expr + *(*begin_expr)[i])->s(); + xla::DynExpr* x_fwd_expr = + xla::Select(*xla::Lt(*(*begin_expr)[i], 0), *wrapped_x_expr, + *(*begin_expr)[i]) + ->s(); begin_i = x_fwd; end_i = begin_i + 1; @@ -422,10 +426,10 @@ absl::Status ValidateStridedSliceOp( begin_i = canonical(begin_raw, 0); end_i = canonical(end_raw, 1); if (begin_expr) { - (*begin_expr)[i] = canonical_expr(begin_raw, 0)->s(); + (*begin_expr)[i] = canonical_expr((*begin_expr)[i], 0)->s(); } if (end_expr) { - (*end_expr)[i] = canonical_expr(end_raw, 1)->s(); + (*end_expr)[i] = canonical_expr((*end_expr)[i], 1)->s(); } } // Update optimization values From f5c10a515b0ff4a0ba7cc966ef3a687d7a11684b Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 18 Mar 2026 00:59:39 +0000 Subject: [PATCH 5/5] Use xla::dexpr helpers in strided slice --- tensorflow/core/util/strided_slice_op.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 40e7071915d4a3..97c7a410ab4e61 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -379,13 +379,17 @@ absl::Status ValidateStridedSliceOp( } else { xla::DynExpr* wrapped_x_expr = (*dim_i_expr + *x_expr)->s(); xla::DynExpr* x_fwd_expr = - xla::Select(*xla::Lt(*x_expr, 0), *wrapped_x_expr, *x_expr)->s(); + xla::dexpr::Select(*xla::dexpr::Gt(0, *x_expr), *wrapped_x_expr, + *x_expr) + ->s(); xla::DynExpr* low_clamped_expr = - xla::Select(*xla::Lt(*x_fwd_expr, *valid_range_expr[0]), - *valid_range_expr[0], *x_fwd_expr) + xla::dexpr::Select( + *xla::dexpr::Gt(*valid_range_expr[0], *x_fwd_expr), + *valid_range_expr[0], *x_fwd_expr) ->s(); - return xla::Select(*xla::Gt(*low_clamped_expr, *valid_range_expr[1]), - *valid_range_expr[1], *low_clamped_expr) + return xla::dexpr::Select( + *xla::dexpr::Gt(*low_clamped_expr, *valid_range_expr[1]), + *valid_range_expr[1], *low_clamped_expr) ->s(); } }; @@ -407,8 +411,9 @@ absl::Status ValidateStridedSliceOp( xla::DynExpr* wrapped_x_expr = (*dim_i_expr + *(*begin_expr)[i])->s(); xla::DynExpr* x_fwd_expr = - xla::Select(*xla::Lt(*(*begin_expr)[i], 0), *wrapped_x_expr, - *(*begin_expr)[i]) + xla::dexpr::Select( + *xla::dexpr::Gt(0, *(*begin_expr)[i]), *wrapped_x_expr, + *(*begin_expr)[i]) ->s(); begin_i = x_fwd; end_i = begin_i + 1;