Conversation
Pull request overview
This PR addresses incorrect propagation/handling of per-dimension shape expressions by ensuring expression vectors don’t outlive the current rank and by limiting when output-dimension canonicalization runs during Grappler shape inference.
Changes:
- Gate `CanonicalizeOutputDims()` in Grappler shape inference behind `TensorShapeExpressionsEnabled()`.
- Clamp `TensorShapeRep` expression accessors/mutators so `expressions_.size()` cannot exceed the shape rank.
- Add bounds checks in `TensorShapeRep::set_expression()` and truncate inputs in `set_expressions()`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tensorflow/core/grappler/costs/graph_properties.cc | Only canonicalize output dims when TensorShape expressions are enabled. |
| tensorflow/core/framework/tensor_shape.h | Ensure get_expressions() does not return more entries than the shape rank. |
| tensorflow/core/framework/tensor_shape.cc | Enforce expression index/rank bounds and truncate expression vectors to rank. |
```diff
 void TensorShapeRep::set_expression(int d, xla::DynExpr* expr) {
   if (!kTensorShapeExpressionsEnabled) {
     expressions_.clear();
     return;
   }
-  if (expressions_.size() <= static_cast<size_t>(d)) {
-    expressions_.resize(d + 1, nullptr);
+  CHECK_GE(d, 0);
+  CHECK_LT(d, ndims_byte());
+  const size_t new_size = static_cast<size_t>(d) + 1;
+  if (expressions_.size() < new_size) {
+    expressions_.resize(new_size, nullptr);
+  }
   expressions_[d] = expr;
 }
```
```cpp
void TensorShapeRep::AddExpression(xla::DynExpr* expr) {
  if (!kTensorShapeExpressionsEnabled) {
    return;
  }
  CHECK_LT(expressions_.size(), ndims_byte());
  expressions_.push_back(expr);
}
```
```cpp
void TensorShapeRep::set_expressions(std::vector<xla::DynExpr*> exprs) {
  if (!kTensorShapeExpressionsEnabled) {
    expressions_.clear();
    return;
  }
  if (exprs.size() > ndims_byte()) {
    exprs.resize(ndims_byte());
  }
```
These changes clamp expressions to ndims_byte(), which should prevent producing or propagating TensorShapeProto instances where expressions.size() > dim_size(). There are existing unit tests for TensorShape behavior, but none appear to cover expression truncation or bounds. Please add a test case that sets expressions beyond the current rank and verifies that serialization (AsProto) and the getters (get_expressions) stay consistent with dims() when dynamic sizes are enabled.
Force-pushed from f2b1c4c to 9eb0d14.
No description provided.