diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index e161bfc5a26fd5..1099fd555113f2 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -1272,14 +1272,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(1) << "input shape is " << ctx->input(input_idx).shape() << ", corresponding xla input shape is " << xla_shape; int64_t size = ctx->input(input_idx).shape().dim_size(dim); - int64_t dyn_val = expr->solve(size); // TODO: check if the result is correct later. + std::optional dyn_val = + expr->solve(size); // TODO: check if the result is correct later. VLOG(1) << "Found dynamic input. Real size is: " << size - << ", solved dynamic value is " << dyn_val; - if (dyn_val == -1) { + << ", solved dynamic value is " + << (dyn_val.has_value() ? std::to_string(*dyn_val) + : std::string("")); + if (!dyn_val.has_value()) { VLOG(1) << "Warning: Failed to solve the expression"; continue; } - dyn_vals.insert(dyn_val); + dyn_vals.insert(*dyn_val); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index df7deaaf80d3bf..099eaba008244e 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -313,8 +313,8 @@ class StridedSliceOp : public XlaOpKernel { slice_begin.push_back(begin[i]); slice_begin_expr.push_back(begin_expr[i]); slice_end.push_back(std::max(end[i], begin[i])); - slice_end_expr.push_back((end[i] > begin[i]) ? end_expr[i] - : begin_expr[i]); + slice_end_expr.push_back( + xla::dexpr::Max(*end_expr[i], *begin_expr[i])->s()); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval @@ -326,9 +326,9 @@ class StridedSliceOp : public XlaOpKernel { slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, input_shape.dim_size(i) - begin[i] - 1)); slice_end_expr.push_back( - (end[i] < begin[i]) - ? (*input_exprs[i] - *end_expr[i] - 1)->s() - : (*input_exprs[i] - *begin_expr[i] - 1)->s()); + xla::dexpr::Max(*(*input_exprs[i] - *end_expr[i] - 1)->s(), + *(*input_exprs[i] - *begin_expr[i] - 1)->s()) + ->s()); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 29c5d56efec504..97c7a410ab4e61 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -371,21 +371,26 @@ absl::Status ValidateStridedSliceOp( : x_fwd; } }; - auto canonical_expr = [stride_i, dim_i, masks, valid_range, - valid_range_expr, dim_i_expr](int64_t x, int c) { + auto canonical_expr = [stride_i, masks, valid_range_expr, + dim_i_expr](xla::DynExpr* x_expr, int c) { if (masks[c]) { return stride_i > 0 ? valid_range_expr[c] : valid_range_expr[(c + 1) & 1]; } else { - int64_t x_fwd = - x < 0 ? dim_i + x : x; // make negative indices positive - xla::DynExpr* x_expr = xla::DynExpr::_(x); + xla::DynExpr* wrapped_x_expr = (*dim_i_expr + *x_expr)->s(); xla::DynExpr* x_fwd_expr = - x < 0 ? (*dim_i_expr + *x_expr) - : x_expr; // make negative indices positive - return x_fwd < valid_range[0] ? valid_range_expr[0] - : x_fwd > valid_range[1] ? valid_range_expr[1] - : x_fwd_expr; + xla::dexpr::Select(*xla::dexpr::Gt(0, *x_expr), *wrapped_x_expr, + *x_expr) + ->s(); + xla::DynExpr* low_clamped_expr = + xla::dexpr::Select( + *xla::dexpr::Gt(*valid_range_expr[0], *x_fwd_expr), + *valid_range_expr[0], *x_fwd_expr) + ->s(); + return xla::dexpr::Select( + *xla::dexpr::Gt(*low_clamped_expr, *valid_range_expr[1]), + *valid_range_expr[1], *low_clamped_expr) + ->s(); } }; if (shrink_i && stride_i <= 0) { @@ -403,9 +408,13 @@ absl::Status ValidateStridedSliceOp( // and canonical puts these to n-1 and 0, which implies a degenerate // interval. Fortunately, it is now safe to re-create end as begin+1. int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; - xla::DynExpr* x_fwd_expr = begin_i < 0 - ? (*dim_i_expr + *(*begin_expr)[i])->s() - : (*begin_expr)[i]; + xla::DynExpr* wrapped_x_expr = + (*dim_i_expr + *(*begin_expr)[i])->s(); + xla::DynExpr* x_fwd_expr = + xla::dexpr::Select( + *xla::dexpr::Gt(0, *(*begin_expr)[i]), *wrapped_x_expr, + *(*begin_expr)[i]) + ->s(); begin_i = x_fwd; end_i = begin_i + 1; @@ -422,10 +431,10 @@ absl::Status ValidateStridedSliceOp( begin_i = canonical(begin_raw, 0); end_i = canonical(end_raw, 1); if (begin_expr) { - (*begin_expr)[i] = canonical_expr(begin_raw, 0)->s(); + (*begin_expr)[i] = canonical_expr((*begin_expr)[i], 0)->s(); } if (end_expr) { - (*end_expr)[i] = canonical_expr(end_raw, 1)->s(); + (*end_expr)[i] = canonical_expr((*end_expr)[i], 1)->s(); } } // Update optimization values diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index fa27de28a84201..2451b4f5a10b06 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -897,10 +897,12 @@ llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { llvm::Value* v_rhs = EmitExpression(b, mul_node->get_rhs()); return b->CreateMul(v_lhs, v_rhs, "mul_dims"); } - // TODO: Check if this should ever happen if (Div* div_node = dynamic_cast(expr)) { llvm::Value* v_lhs = EmitExpression(b, div_node->get_lhs()); llvm::Value* v_rhs = EmitExpression(b, div_node->get_rhs()); + if (auto* rhs_const = llvm::dyn_cast(v_rhs)) { + CHECK_NE(rhs_const->getSExtValue(), 0); + } return b->CreateUDiv(v_lhs, v_rhs, "div_dims"); } if (Add* add_node = dynamic_cast(expr)) { @@ -913,6 +915,47 @@ llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { llvm::Value* v_rhs = EmitExpression(b, sub_node->get_rhs()); return b->CreateSub(v_lhs, v_rhs, "sub_dims"); } + if (GtExpr* gt_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, gt_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, gt_node->get_rhs()); + return b->CreateICmpSGT(v_lhs, v_rhs, "gt_dims"); + } + if (EqExpr* eq_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, eq_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, eq_node->get_rhs()); + return b->CreateICmpEQ(v_lhs, v_rhs, "eq_dims"); + } + if (NeExpr* ne_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, ne_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, ne_node->get_rhs()); + return b->CreateICmpNE(v_lhs, v_rhs, "ne_dims"); + } + if (GeExpr* ge_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, ge_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, ge_node->get_rhs()); + return b->CreateICmpSGE(v_lhs, v_rhs, "ge_dims"); + } + if (MaxExpr* max_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, max_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, max_node->get_rhs()); + llvm::Value* pred = b->CreateICmpSGT(v_lhs, v_rhs, "max_dims_pred"); + return b->CreateSelect(pred, v_lhs, v_rhs, "max_dims"); + } + if (AbsExpr* abs_node = dynamic_cast(expr)) { + llvm::Value* value = EmitExpression(b, abs_node->get_expr()); + llvm::Value* is_negative = + b->CreateICmpSLT(value, llvm::ConstantInt::get(i64Type, 0, true), + "abs_dims_neg"); + llvm::Value* negated = b->CreateNeg(value, "abs_dims_negated"); + return b->CreateSelect(is_negative, negated, value, "abs_dims"); + } + if (SelectExpr* select_node = dynamic_cast(expr)) { + llvm::Value* pred = EmitExpression(b, select_node->get_pred()); + llvm::Value* on_true = EmitExpression(b, select_node->get_on_true()); + llvm::Value* on_false = EmitExpression(b, select_node->get_on_false()); + CHECK(pred->getType()->isIntegerTy(1)); + return b->CreateSelect(pred, on_true, on_false, "select_dims"); + } return nullptr; } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 962984b9225b6a..d1286cd058c1b8 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -69,6 +69,53 @@ DynExpr* ExprFromProto(const ExpressionProto& proto) { return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); } + case ExpressionProto::kLtNode: { + const auto& lt = proto.lt_node(); + return dexpr::Gt(*ExprFromProto(lt.rhs()), *ExprFromProto(lt.lhs())); + } + + case ExpressionProto::kGtNode: { + const auto& gt = proto.gt_node(); + return dexpr::Gt(*ExprFromProto(gt.lhs()), *ExprFromProto(gt.rhs())); + } + + case ExpressionProto::kSelectNode: { + const auto& select = proto.select_node(); + return dexpr::Select(*ExprFromProto(select.pred()), + *ExprFromProto(select.on_true()), + *ExprFromProto(select.on_false())); + } + + case ExpressionProto::kEqNode: { + const auto& eq = proto.eq_node(); + return dexpr::Eq(*ExprFromProto(eq.lhs()), *ExprFromProto(eq.rhs())); + } + + case ExpressionProto::kNeNode: { + const auto& ne = proto.ne_node(); + return dexpr::Ne(*ExprFromProto(ne.lhs()), *ExprFromProto(ne.rhs())); + } + + case ExpressionProto::kLeNode: { + const auto& le = proto.le_node(); + return dexpr::Ge(*ExprFromProto(le.rhs()), *ExprFromProto(le.lhs())); + } + + case ExpressionProto::kGeNode: { + const auto& ge = proto.ge_node(); + return dexpr::Ge(*ExprFromProto(ge.lhs()), *ExprFromProto(ge.rhs())); + } + + case ExpressionProto::kAbsNode: { + const auto& abs = proto.abs_node(); + return dexpr::Abs(*ExprFromProto(abs.expr())); + } + + case ExpressionProto::kMaxNode: { + const auto& max = proto.max_node(); + return dexpr::Max(*ExprFromProto(max.lhs()), *ExprFromProto(max.rhs())); + } + case ExpressionProto::NODE_TYPE_NOT_SET: default: return nullptr; @@ -91,6 +138,40 @@ DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { return new Sub(&lhs, &rhs); } DynExpr* operator-(DynExpr& lhs, int64_t d) { return new Sub(&lhs, DynExpr::_(d)); } +namespace dexpr { +DynExpr* Mul(DynExpr& lhs, DynExpr& rhs) { return new xla::Mul(&lhs, &rhs); } +DynExpr* Mul(int64_t lhs, DynExpr& rhs) { return new xla::Mul(DynExpr::_(lhs), &rhs); } +DynExpr* Div(DynExpr& lhs, DynExpr& rhs) { return new xla::Div(&lhs, &rhs); } +DynExpr* Div(DynExpr& lhs, int64_t rhs) { return new xla::Div(&lhs, DynExpr::_(rhs)); } +DynExpr* Add(DynExpr& lhs, DynExpr& rhs) { return new xla::Add(&lhs, &rhs); } +DynExpr* Add(DynExpr& lhs, int64_t rhs) { return new xla::Add(&lhs, DynExpr::_(rhs)); } +DynExpr* Sub(DynExpr& lhs, DynExpr& rhs) { return new xla::Sub(&lhs, &rhs); } +DynExpr* Sub(DynExpr& lhs, int64_t rhs) { return new xla::Sub(&lhs, DynExpr::_(rhs)); } +DynExpr* Gt(DynExpr& lhs, DynExpr& rhs) { return new GtExpr(&lhs, &rhs); } +DynExpr* Gt(DynExpr& lhs, int64_t rhs) { + return new GtExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Eq(DynExpr& lhs, DynExpr& rhs) { return new EqExpr(&lhs, &rhs); } +DynExpr* Eq(DynExpr& lhs, int64_t rhs) { + return new EqExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Ne(DynExpr& lhs, DynExpr& rhs) { return new NeExpr(&lhs, &rhs); } +DynExpr* Ne(DynExpr& lhs, int64_t rhs) { + return new NeExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Ge(DynExpr& lhs, DynExpr& rhs) { return new GeExpr(&lhs, &rhs); } +DynExpr* Ge(DynExpr& lhs, int64_t rhs) { + return new GeExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Max(DynExpr& lhs, DynExpr& rhs) { return new MaxExpr(&lhs, &rhs); } +DynExpr* Max(DynExpr& lhs, int64_t rhs) { + return new MaxExpr(&lhs, DynExpr::_(rhs)); +} +DynExpr* Abs(DynExpr& expr) { return new AbsExpr(&expr); } +DynExpr* Select(DynExpr& pred, DynExpr& on_true, DynExpr& on_false) { + return new SelectExpr(&pred, &on_true, &on_false); +} +} // namespace dexpr bool operator==(DynExpr& lhs, DynExpr& rhs) { return DynExpr::equal(&lhs, &rhs); } @@ -105,50 +186,92 @@ bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { auto e1 = expr1->s(); auto e2 = expr2->s(); if (e1 == nullptr || e2 == nullptr) return false; - Constant* c1 = dynamic_cast(e1); - Constant* c2 = dynamic_cast(e2); - if (c1 && c2) return c1->get_val() == c2->get_val(); + + auto ordered_binary_equal = [&](auto* lhs_node, auto* rhs_node) { + return *lhs_node->get_lhs() == *rhs_node->get_lhs() && + *lhs_node->get_rhs() == *rhs_node->get_rhs(); + }; + auto unordered_binary_equal = [&](auto* lhs_node, auto* rhs_node) { + return (*lhs_node->get_lhs() == *rhs_node->get_lhs() && + *lhs_node->get_rhs() == *rhs_node->get_rhs()) || + (*lhs_node->get_lhs() == *rhs_node->get_rhs() && + *lhs_node->get_rhs() == *rhs_node->get_lhs()); + }; + auto constant_equal = dynamic_cast(e1); + auto other_constant = dynamic_cast(e2); + if (constant_equal && other_constant) { + return constant_equal->get_val() == other_constant->get_val(); + } + // Var x = Var y <=> x = y if (Variable* varx = dynamic_cast(e1), *vary = dynamic_cast(e2); varx && vary) { return varx->get_id() == vary->get_id(); } + // a * b = c * d <=> (a = c /\ b = d) \/ (a = d /\ b = c) if (Mul* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto a = ab->get_lhs(); - auto b = ab->get_rhs(); - auto c = cd->get_lhs(); - auto d = cd->get_rhs(); - return (*a == *c && *b == *d) || (*a == *d && *b == *c); + return unordered_binary_equal(ab, cd); } // a / b = c / d <=> (a = c /\ b = d) if (Div* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto a = ab->get_lhs(); - auto b = ab->get_rhs(); - auto c = cd->get_lhs(); - auto d = cd->get_rhs(); - return *a == *c && *b == *d; + return ordered_binary_equal(ab, cd); } // a + b = c + d <=> (a = c /\ b = d) \/ (a = d /\ b = c) if (Add* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto a = ab->get_lhs(); - auto b = ab->get_rhs(); - auto c = cd->get_lhs(); - auto d = cd->get_rhs(); - return (*a == *c && *b == *d) || (*a == *d && *b == *c); + return unordered_binary_equal(ab, cd); } // a - b = c - d <=> (a = c /\ b = d) if (Sub* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); ab && cd) { - auto* a = ab->get_lhs(); - auto* b = ab->get_rhs(); - auto* c = cd->get_lhs(); - auto* d = cd->get_rhs(); - return *a == *c && *b == *d; + return ordered_binary_equal(ab, cd); + } + // a < b = c < d <=> (a = c /\ b = d) + if (LtExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // a > b = c > d <=> (a = c /\ b = d) + if (GtExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // a == b = c == d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (EqExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return unordered_binary_equal(ab, cd); + } + // a != b = c != d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (NeExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return unordered_binary_equal(ab, cd); + } + // a <= b = c <= d <=> (a = c /\ b = d) + if (LeExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // a >= b = c >= d <=> (a = c /\ b = d) + if (GeExpr* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + return ordered_binary_equal(ab, cd); + } + // abs(a) = abs(b) <=> a = b + if (AbsExpr* a = dynamic_cast(e1), *b = dynamic_cast(e2); + a && b) { + return *a->get_expr() == *b->get_expr(); + } + // select(p, a, b) = select(q, c, d) <=> p = q /\ a = c /\ b = d + if (SelectExpr* pab = dynamic_cast(e1), + *qcd = dynamic_cast(e2); + pab && qcd) { + return *pab->get_pred() == *qcd->get_pred() && + *pab->get_on_true() == *qcd->get_on_true() && + *pab->get_on_false() == *qcd->get_on_false(); } return false; } @@ -320,6 +443,74 @@ DynExpr* Div::s() { return *s_lhs / *s_rhs; } +DynExpr* GtExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::zero; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() > r->get_val()); + return dexpr::Gt(*s_lhs, *s_rhs); +} + +DynExpr* SelectExpr::s() { + DynExpr* s_pred = get_pred()->s(); + DynExpr* s_true = get_on_true()->s(); + DynExpr* s_false = get_on_false()->s(); + Constant* p = dynamic_cast(s_pred); + if (p) return p->get_val() ? s_true : s_false; + if (*s_true == *s_false) return s_true; + return dexpr::Select(*s_pred, *s_true, *s_false); +} + +DynExpr* EqExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::one; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() == r->get_val()); + return dexpr::Eq(*s_lhs, *s_rhs); +} + +DynExpr* NeExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::zero; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() != r->get_val()); + return dexpr::Ne(*s_lhs, *s_rhs); +} + +DynExpr* GeExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return DynExpr::one; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(l->get_val() >= r->get_val()); + return dexpr::Ge(*s_lhs, *s_rhs); +} + +DynExpr* MaxExpr::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + if (*s_lhs == *s_rhs) return s_lhs; + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + if (l && r) return DynExpr::_(std::max(l->get_val(), r->get_val())); + return dexpr::Max(*s_lhs, *s_rhs); +} + +DynExpr* AbsExpr::s() { + DynExpr* s_expr = get_expr()->s(); + Constant* c = dynamic_cast(s_expr); + if (c) return DynExpr::_(c->get_val() < 0 ? -c->get_val() : c->get_val()); + if (AbsExpr* inner = dynamic_cast(s_expr)) return inner; + return dexpr::Abs(*s_expr); +} + std::ostream& operator<<(std::ostream& os, DynExpr* expr) { ExpressionProto proto; expr->to_proto(&proto); diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index 5c1f5645c25e0e..2bafc56a0153ff 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef XLA_SHAPE_DYNEXPR_H_ #define XLA_SHAPE_DYNEXPR_H_ +#include #include #include +#include #include #include "xla/printer.h" @@ -35,7 +37,7 @@ class DynExpr { virtual DynExpr* s() = 0; // simplify virtual DynExpr* substitute(int id, DynExpr* v) = 0; virtual std::set get_all_ids() = 0; - virtual int64_t solve(int64_t x) = 0; + virtual std::optional solve(int64_t x) = 0; bool is_dynamic() { return !is_constant(); } @@ -71,7 +73,7 @@ class Constant : public DynExpr { int64_t get_val() const override { return value; } DynExpr* substitute(int id, DynExpr* v) { return this; } std::set get_all_ids() { return {}; } - int64_t solve(int64_t x) { return -1; } + std::optional solve(int64_t x) { return std::nullopt; } DynExpr* s() override; }; @@ -94,7 +96,7 @@ class Variable : public DynExpr { int get_id() const { return id; } DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} std::set get_all_ids() { return {get_id()}; } - int64_t solve(int64_t x) { return x; } + std::optional solve(int64_t x) { return x; } DynExpr* s() override; }; @@ -138,9 +140,9 @@ class Add : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A + c) = x <=> A = x - c => solve A = y with y = x - c return lhs->solve(x - rhs->get_val()); @@ -150,7 +152,7 @@ class Add : public DynExpr { return rhs->solve(x - lhs->get_val()); } // No solution - return -1; + return std::nullopt; } DynExpr* s() override; @@ -201,9 +203,9 @@ class Sub : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A - c) = x <=> A = x + c => solve A = y with y = x + c return lhs->solve(x + rhs->get_val()); @@ -213,7 +215,7 @@ class Sub : public DynExpr { return rhs->solve(x + lhs->get_val()); } // No solution - return -1; + return std::nullopt; } DynExpr* s() override; @@ -264,23 +266,23 @@ class Mul : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A * c) = x <=> A = x / c => solve A = y with y = x / c int64_t c = rhs->get_val(); - if (x % c != 0) return -1; + if (x % c != 0) return std::nullopt; return lhs->solve(x / c); } if (rhs->get_all_ids().size() == 1) { // (c * A) = x <=> A = x / c => solve A = y with y = x / c int64_t c = lhs->get_val(); - if (x % c != 0) return -1; + if (x % c != 0) return std::nullopt; return rhs->solve(x / c); } // No solution - return -1; + return std::nullopt; } DynExpr* s() override; @@ -333,9 +335,9 @@ class Div : public DynExpr { return s; } - int64_t solve(int64_t x) { + std::optional solve(int64_t x) { // Cannot solve if both lhs and rhs are dynamic... - if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->is_dynamic() && rhs->is_dynamic()) return std::nullopt; if (lhs->get_all_ids().size() == 1) { // (A / c) = x <=> A = x * c => solve A = y with y = x * c return lhs->solve(x * rhs->get_val()); @@ -343,11 +345,11 @@ class Div : public DynExpr { if (rhs->get_all_ids().size() == 1) { // (c / A) = x <=> A = c / x => solve A = y with y = c / x int64_t c = lhs->get_val(); - if (c % x != 0) return -1; + if (c % x != 0) return std::nullopt; return rhs->solve(c / x); } // No solution - return -1; + return std::nullopt; } ~Div() { @@ -356,6 +358,367 @@ class Div : public DynExpr { } }; +class GtExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + GtExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" > "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* gt_msg = proto->mutable_gt_node(); + lhs->to_proto(gt_msg->mutable_lhs()); + rhs->to_proto(gt_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() > rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) override { + return new GtExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + std::optional solve(int64_t x) override { + // (A > c) = x is not uniquely invertible. + return std::nullopt; + } + DynExpr* s() override; + + ~GtExpr() { + delete lhs; + delete rhs; + } +}; + +class SelectExpr : public DynExpr { + DynExpr* pred; + DynExpr* on_true; + DynExpr* on_false; + + public: + SelectExpr(DynExpr* p, DynExpr* t, DynExpr* f) + : pred(std::move(p)), on_true(std::move(t)), on_false(std::move(f)) {} + void print(xla::Printer* printer) const override { + printer->Append("select("); + pred->print(printer); + printer->Append(", "); + on_true->print(printer); + printer->Append(", "); + on_false->print(printer); + printer->Append(")"); + } + + DynExpr* get_pred() const { return pred; } + DynExpr* get_on_true() const { return on_true; } + DynExpr* get_on_false() const { return on_false; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* select_msg = proto->mutable_select_node(); + pred->to_proto(select_msg->mutable_pred()); + on_true->to_proto(select_msg->mutable_on_true()); + on_false->to_proto(select_msg->mutable_on_false()); + } + + bool is_constant() const override { + return pred->is_constant() && on_true->is_constant() && + on_false->is_constant(); + } + + int64_t get_val() const override { + return pred->get_val() ? on_true->get_val() : on_false->get_val(); + } + + DynExpr* substitute(int id, DynExpr* v) override { + return new SelectExpr(pred->substitute(id, v), on_true->substitute(id, v), + on_false->substitute(id, v)); + } + + std::set get_all_ids() override { + auto s = pred->get_all_ids(); + s.merge(on_true->get_all_ids()); + s.merge(on_false->get_all_ids()); + return s; + } + + std::optional solve(int64_t x) override { + // select(c, A, B) = x <=> solve selected branch = x + if (pred->is_constant()) { + return pred->get_val() ? on_true->solve(x) : on_false->solve(x); + } + // select(p, A, A) = x <=> A = x + if (DynExpr::equal(on_true, on_false)) { + return on_true->solve(x); + } + if (on_true->is_constant() && on_false->is_constant()) { + const bool matches_true = on_true->get_val() == x; + const bool matches_false = on_false->get_val() == x; + // select(p, x, c) = x <=> p = 1 + if (matches_true && !matches_false) return pred->solve(1); + // select(p, c, x) = x <=> p = 0 + if (!matches_true && matches_false) return pred->solve(0); + } + return std::nullopt; + } + DynExpr* s() override; + + ~SelectExpr() { + delete pred; + delete on_true; + delete on_false; + } +}; + +class EqExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + EqExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" == "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_eq_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { return lhs->get_val() == rhs->get_val(); } + DynExpr* substitute(int id, DynExpr* v) override { + return new EqExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + std::optional solve(int64_t x) override { + if (x != 0 && x != 1) return std::nullopt; + if (x == 1) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { + // (A == c) = 1 <=> A = c => solve A = y with y = c + return lhs->solve(rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { + // (c == A) = 1 <=> A = c => solve A = y with y = c + return rhs->solve(lhs->get_val()); + } + } + return std::nullopt; + } + DynExpr* s() override; + ~EqExpr() { + delete lhs; + delete rhs; + } +}; + +class NeExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + NeExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" != "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_ne_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { return lhs->get_val() != rhs->get_val(); } + DynExpr* substitute(int id, DynExpr* v) override { + return new NeExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + std::optional solve(int64_t x) override { + if (x != 0 && x != 1) return std::nullopt; + if (x == 0) { + if (lhs->get_all_ids().size() == 1 && rhs->is_constant()) { + // (A != c) = 0 <=> A = c => solve A = y with y = c + return lhs->solve(rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1 && lhs->is_constant()) { + // (c != A) = 0 <=> A = c => solve A = y with y = c + return rhs->solve(lhs->get_val()); + } + } + return std::nullopt; + } + DynExpr* s() override; + ~NeExpr() { + delete lhs; + delete rhs; + } +}; + +class GeExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + GeExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" >= "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_ge_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { return lhs->get_val() >= rhs->get_val(); } + DynExpr* substitute(int id, DynExpr* v) override { + return new GeExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + std::optional solve(int64_t x) override { + // (A >= c) = x is not uniquely invertible. + return std::nullopt; + } + DynExpr* s() override; + ~GeExpr() { + delete lhs; + delete rhs; + } +}; + +class MaxExpr : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + MaxExpr(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("max("); + lhs->print(printer); + printer->Append(", "); + rhs->print(printer); + printer->Append(")"); + } + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_max_node(); + lhs->to_proto(msg->mutable_lhs()); + rhs->to_proto(msg->mutable_rhs()); + } + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + int64_t get_val() const override { + return std::max(lhs->get_val(), rhs->get_val()); + } + DynExpr* substitute(int id, DynExpr* v) override { + return new MaxExpr(lhs->substitute(id, v), rhs->substitute(id, v)); + } + std::set get_all_ids() override { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + std::optional solve(int64_t x) override { + // max(A, c) = x is not uniquely invertible. + return std::nullopt; + } + DynExpr* s() override; + ~MaxExpr() { + delete lhs; + delete rhs; + } +}; + +class AbsExpr : public DynExpr { + DynExpr* expr; + + public: + explicit AbsExpr(DynExpr* e) : expr(std::move(e)) {} + void print(xla::Printer* printer) const override { + printer->Append("abs("); + expr->print(printer); + printer->Append(")"); + } + DynExpr* get_expr() const { return expr; } + void to_proto(xla::ExpressionProto* proto) const override { + auto* msg = proto->mutable_abs_node(); + expr->to_proto(msg->mutable_expr()); + } + bool is_constant() const override { return expr->is_constant(); } + int64_t get_val() const override { + int64_t v = expr->get_val(); + return v < 0 ? -v : v; + } + DynExpr* substitute(int id, DynExpr* v) override { + return new AbsExpr(expr->substitute(id, v)); + } + std::set get_all_ids() override { return expr->get_all_ids(); } + std::optional solve(int64_t x) override { + if (x < 0) return std::nullopt; + // abs(A) = x <=> A = x or A = -x. Solve only if one branch works or both + // branches agree on the same solution. + std::optional pos = expr->solve(x); + std::optional neg = expr->solve(-x); + if (!pos.has_value()) return neg; + if (!neg.has_value()) return pos; + return *pos == *neg ? pos : std::nullopt; + } + DynExpr* s() override; + ~AbsExpr() { delete expr; } +}; + DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); DynExpr* operator*(int64_t k, DynExpr& rhs); DynExpr* operator/(DynExpr& lhs, DynExpr& rhs); @@ -364,6 +727,28 @@ DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); DynExpr* operator+(DynExpr& lhs, int64_t d); DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); DynExpr* operator-(DynExpr& lhs, int64_t d); +namespace dexpr { +DynExpr* Mul(DynExpr& lhs, DynExpr& rhs); +DynExpr* Mul(int64_t lhs, DynExpr& rhs); +DynExpr* Div(DynExpr& lhs, DynExpr& rhs); +DynExpr* Div(DynExpr& lhs, int64_t rhs); +DynExpr* Add(DynExpr& lhs, DynExpr& rhs); +DynExpr* Add(DynExpr& lhs, int64_t rhs); +DynExpr* Sub(DynExpr& lhs, DynExpr& rhs); +DynExpr* Sub(DynExpr& lhs, int64_t rhs); +DynExpr* Gt(DynExpr& lhs, DynExpr& rhs); +DynExpr* Gt(DynExpr& lhs, int64_t rhs); +DynExpr* Eq(DynExpr& lhs, DynExpr& rhs); +DynExpr* Eq(DynExpr& lhs, int64_t rhs); +DynExpr* Ne(DynExpr& lhs, DynExpr& rhs); +DynExpr* Ne(DynExpr& lhs, int64_t rhs); +DynExpr* Ge(DynExpr& lhs, DynExpr& rhs); +DynExpr* Ge(DynExpr& lhs, int64_t rhs); +DynExpr* Max(DynExpr& lhs, DynExpr& rhs); +DynExpr* Max(DynExpr& lhs, int64_t rhs); +DynExpr* Abs(DynExpr& expr); +DynExpr* Select(DynExpr& pred, DynExpr& on_true, DynExpr& on_false); +} // namespace dexpr bool operator==(DynExpr& lhs, DynExpr& rhs); bool operator==(DynExpr& lhs, int64_t d); diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 798dc12fb5734e..9344f490b983b3 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -1215,6 +1215,15 @@ message ExpressionProto { SubNode sub_node = 4; // exp - exp MulNode mul_node = 5; // exp * exp DivNode div_node = 6; // exp / exp + LtNode lt_node = 7; // exp < exp + GtNode gt_node = 8; // exp > exp + SelectNode select_node = 9; // select(pred, on_true, on_false) + EqNode eq_node = 10; // exp == exp + NeNode ne_node = 11; // exp != exp + LeNode le_node = 12; // exp <= exp + GeNode ge_node = 13; // exp >= exp + AbsNode abs_node = 14; // abs(exp) + MaxNode max_node = 15; // max(exp, exp) } } @@ -1236,4 +1245,49 @@ message MulNode { message DivNode { ExpressionProto lhs = 1; ExpressionProto rhs = 2; -} \ No newline at end of file +} + +message LtNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message GtNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message SelectNode { + ExpressionProto pred = 1; + ExpressionProto on_true = 2; + ExpressionProto on_false = 3; +} + +message EqNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message NeNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message LeNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message GeNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message AbsNode { + ExpressionProto expr = 1; +} + +message MaxNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +}