diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index 9b81e33581..674d1fc593 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -899,12 +899,26 @@ def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: + # List all commutative elementwise (binary) operators for which we + # consider swapping the inputs + COMMUTATIVE_OPS = { + ("", "Add", ""), + ("", "Mul", ""), + ("", "And", ""), + ("", "Or", ""), + ("", "Xor", ""), + ("", "BitwiseAnd", ""), + ("", "BitwiseOr", ""), + ("", "BitwiseXor", ""), + ("", "Equal", ""), + ("", "Max", ""), + ("", "Mean", ""), + ("", "Min", ""), + ("", "Sum", ""), + } + def commute_node(node: NodePattern) -> Iterable[bool]: - if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( - "", - "Mul", - "", - ): + if node.op_identifier() in COMMUTATIVE_OPS: # Try with and without swapping inputs. return [False, True] # No swapping of inputs