From d63df78a92eb300e6a755fa71a49317f8e9d92d1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:40:45 -0700 Subject: [PATCH 01/19] [IR] Update pass infra --- onnxscript/ir/passes/_pass_infra.py | 45 +++++++++++++++-------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index c03a23bd8b..f08c89d193 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -70,12 +70,31 @@ class PassBase(abc.ABC): Class attributes: in_place: Whether the pass modifies the model in place. + destructive: Whether the pass will destroy the input model when ``in_place=False``. """ in_place: bool = True + destructive: bool = False def __call__(self, model: ir.Model) -> PassResult: - return self.call(model) + # Check preconditions + try: + self.requires(model) + except PreconditionError: + raise + except Exception as e: + raise PreconditionError("Pre-condition failed") from e + + result = self.call(model) + + # Check postconditions + try: + self.ensures(model) + except PostconditionError: + raise + except Exception as e: + raise PostconditionError("Post-condition failed") from e + return result @abc.abstractmethod def call(self, model: ir.Model) -> PassResult: @@ -97,26 +116,23 @@ def ensures(self, model: ir.Model) -> None: del model # Unused -class PassManager: +class PassManager(PassBase): """Pass manager for the IR. - The PassManager is a callable that runs a sequence of passes on a model. + The PassManager is a Pass that runs a sequence of passes on a model. Attributes: passes: The passes to run. - check_invariants: Whether to check invariants before and after each pass. steps: The number of times to run the passes. """ def __init__( self, passes: Sequence[PassBase], - check_invariants: bool = False, steps: int = 1, ): # TODO(justinchuby): Implement constraints self.passes = list(passes) - self.check_invariants = check_invariants self.steps = steps def __call__(self, model: ir.Model) -> PassResult: @@ -137,17 +153,10 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult: modified = False for i, pass_ in enumerate(self.passes): logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step) - - # 1. Check preconditions - if self.check_invariants: - try: - pass_.requires(model) - except Exception as e: - raise PreconditionError(f"Pre-condition failed for {pass_}") from e - - # 2. Run the pass try: pass_result = pass_(model) + except (PreconditionError, PostconditionError): + raise except Exception as e: prev_pass_names = [str(p) for p in self.passes[:i]] raise PassError( @@ -163,10 +172,4 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult: model = pass_result.model modified = modified or pass_result.modified - # 3. Check postconditions - if self.check_invariants: - try: - pass_.ensures(model) - except Exception as e: - raise PostconditionError(f"Post-condition failed for {pass_}") from e return PassResult(model, modified) From 6fa351ae046f32f622bf7c03cb573f31381d1731 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:49:54 -0700 Subject: [PATCH 02/19] Update PassBase --- onnxscript/ir/passes/_pass_infra.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index f08c89d193..1dd008da21 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -66,15 +66,17 @@ class PassResult: class PassBase(abc.ABC): - """Base class for all passes. + """Base class for all passes.""" - Class attributes: - in_place: Whether the pass modifies the model in place. - destructive: Whether the pass will destroy the input model when ``in_place=False``. - """ + @property + def in_place(self) -> bool: + """Whether the pass modifies the model in place.""" + return True - in_place: bool = True - destructive: bool = False + @property + def destructive(self) -> bool: + """Whether the pass will destroy the input model when ``in_place=False``.""" + return False def __call__(self, model: ir.Model) -> PassResult: # Check preconditions @@ -135,7 +137,18 @@ def __init__( self.passes = list(passes) self.steps = steps - def __call__(self, model: ir.Model) -> PassResult: + @property + def in_place(self) -> bool: + """Whether the pass modifies the model in place.""" + return all(pass_.in_place for pass_ in self.passes) + + @property + def destructive(self) -> bool: + """Whether the pass will destroy the input model when ``in_place=False``.""" + # This logic is a little conservative, but it is ok for now + return any(pass_.destructive for pass_ in self.passes) + + def call(self, model: ir.Model) -> PassResult: """Run the set of passes `steps` number of times or until the graph stops changing.""" overall_modified = False for step in range(self.steps): From e2bed4a5a730df7482de7f8906b8a9b17b323cb0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:52:22 -0700 Subject: [PATCH 03/19] Update onnxscript/ir/passes/_pass_infra.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/ir/passes/_pass_infra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 1dd008da21..2399f259af 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -95,7 +95,7 @@ def __call__(self, model: ir.Model) -> PassResult: except PostconditionError: raise except Exception as e: - raise PostconditionError("Post-condition failed") from e + raise PostconditionError(f"Post-condition failed: {e.__class__.__name__}: {e}") from e return result @abc.abstractmethod From 0ced04d6910c3ab195b1493e2c3c13cb62be8ee1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:53:11 -0700 Subject: [PATCH 04/19] msg --- onnxscript/ir/passes/_pass_infra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 2399f259af..550de05b5e 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -85,7 +85,7 @@ def __call__(self, model: ir.Model) -> PassResult: except PreconditionError: raise except Exception as e: - raise PreconditionError("Pre-condition failed") from e + raise PreconditionError(f"Pre-condition failed: {e.__class__.__name__}") from e result = self.call(model) @@ -95,7 +95,7 @@ def __call__(self, model: ir.Model) -> PassResult: except PostconditionError: raise except Exception as e: - raise PostconditionError(f"Post-condition failed: {e.__class__.__name__}: {e}") from e + raise PostconditionError(f"Post-condition failed: {e.__class__.__name__}") from e return result @abc.abstractmethod From c2b1385d5ebce9abb3fcec8eaf2fad3f7295cb3e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:55:05 -0700 Subject: [PATCH 05/19] msg --- onnxscript/ir/passes/_pass_infra.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 550de05b5e..63b2141a2f 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -85,7 +85,9 @@ def __call__(self, model: ir.Model) -> PassResult: except PreconditionError: raise except Exception as e: - raise PreconditionError(f"Pre-condition failed: {e.__class__.__name__}") from e + raise PreconditionError( + f"Pre-condition for pass '{self.__class__.__name__}' failed" + ) from e result = self.call(model) @@ -95,7 +97,9 @@ def __call__(self, model: ir.Model) -> PassResult: except PostconditionError: raise except Exception as e: - raise PostconditionError(f"Post-condition failed: {e.__class__.__name__}") from e + raise PostconditionError( + f"Post-condition for pass '{self.__class__.__name__}' failed" + ) from e return result @abc.abstractmethod From 3fffe61a1a31d04d36540a8e46269ec6b33fc5a8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 20:00:34 -0700 Subject: [PATCH 06/19] checking --- onnxscript/ir/passes/_pass_infra.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 63b2141a2f..8c60d4af3b 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -100,6 +100,12 @@ def __call__(self, model: ir.Model) -> PassResult: raise PostconditionError( f"Post-condition for pass '{self.__class__.__name__}' failed" ) from e + + if not isinstance(result, PassResult): + raise TypeError( + f"The result of the pass '{self.__class__.__name__}' should be type PassResult." + "Please create one with ir.passes.PassResult()." + ) return result @abc.abstractmethod From f9ab36a418013843cd2717d88badfa3195c49af0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 20:00:52 -0700 Subject: [PATCH 07/19] check --- onnxscript/ir/passes/_pass_infra.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 8c60d4af3b..fc050c65ff 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -186,11 +186,6 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult: f"An error occurred when running the '{pass_}' pass after the " f"following passes: {prev_pass_names} during step {step}" ) from e - if not isinstance(pass_result, PassResult): - raise TypeError( - f"The result of the pass {pass_} should be type PassResult." - "Please create one with ir.passes.PassResult()." - ) model = pass_result.model modified = modified or pass_result.modified From 01cee46d32e8f67912f575ff68033744236d643c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 20:21:50 -0700 Subject: [PATCH 08/19] msg --- onnxscript/ir/passes/_pass_infra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index fc050c65ff..4da3c2d41f 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -103,7 +103,7 @@ def __call__(self, model: ir.Model) -> PassResult: if not isinstance(result, PassResult): raise TypeError( - f"The result of the pass '{self.__class__.__name__}' should be type PassResult." + f"The result of the pass '{self.__class__.__name__}' should be type PassResult. " "Please create one with ir.passes.PassResult()." ) return result From ff947e9354a3c0b372a67f2e796b667d2999d95b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 21:56:58 -0700 Subject: [PATCH 09/19] docs --- onnxscript/ir/passes/_pass_infra.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 4da3c2d41f..e5decff93a 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -70,7 +70,11 @@ class PassBase(abc.ABC): @property def in_place(self) -> bool: - """Whether the pass modifies the model in place.""" + """Whether the pass modifies the model in place. + + If True, the pass will return the same model object that was passed in. + If False, the pass will return a new model object. + """ return True @property From 3caf8186cad10c0dbac0c6d31b5d67b818860d9c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Mar 2025 07:58:42 -0700 Subject: [PATCH 10/19] Expand pass categories --- onnxscript/ir/passes/__init__.py | 8 ++ onnxscript/ir/passes/_pass_infra.py | 133 +++++++++++++----- onnxscript/optimizer/_constant_folding.py | 2 +- onnxscript/optimizer/_remove_unused.py | 2 +- .../optimizer/_remove_unused_function.py | 2 +- 5 files changed, 106 insertions(+), 41 deletions(-) diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py index 9cea129d2b..bfc58b9f95 100644 --- a/onnxscript/ir/passes/__init__.py +++ b/onnxscript/ir/passes/__init__.py @@ -5,6 +5,10 @@ "PassBase", "PassResult", "PassManager", + "Sequential", + "InPlacePass", + "OutOfPlacePass", + "DestructivePass", # Errors "InvariantError", "PreconditionError", @@ -13,13 +17,17 @@ ] from onnxscript.ir.passes._pass_infra import ( + DestructivePass, + InPlacePass, InvariantError, + OutOfPlacePass, PassBase, PassError, PassManager, PassResult, PostconditionError, PreconditionError, + Sequential, ) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index e5decff93a..a204a52112 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -20,6 +20,10 @@ __all__ = [ "PassBase", + "Sequential", + "InPlacePass", + "OutOfPlacePass", + "DestructivePass", "PassManager", "PassResult", # Errors @@ -69,18 +73,28 @@ class PassBase(abc.ABC): """Base class for all passes.""" @property + @abc.abstractmethod def in_place(self) -> bool: - """Whether the pass modifies the model in place. + """Whether the pass modifies the model in place and returns it. If True, the pass will return the same model object that was passed in. If False, the pass will return a new model object. """ return True + @property + @abc.abstractmethod + def changes_input(self) -> bool: + """Whether the pass modifies input model.""" + return True + @property def destructive(self) -> bool: - """Whether the pass will destroy the input model when ``in_place=False``.""" - return False + """Whether the pass will destroy the input model when ``in_place=False``. + + A pass is destructive if it is not in place and it modifies the input model. + """ + return not self.in_place and self.changes_input def __call__(self, model: ir.Model) -> PassResult: # Check preconditions @@ -132,7 +146,76 @@ def ensures(self, model: ir.Model) -> None: del model # Unused -class PassManager(PassBase): +class InPlacePass(PassBase): + """A pass that modifies the input model in place and returns it.""" + + @property + def in_place(self) -> bool: + return True + + @property + def changes_input(self) -> bool: + return True + + +class OutOfPlacePass(PassBase): + """A pass that returns a new model but does not modify the input model.""" + + @property + def in_place(self) -> bool: + return False + + @property + def changes_input(self) -> bool: + return False + + +class DestructivePass(PassBase): + """A pass that modifies the input model and returns a new model.""" + + @property + def in_place(self) -> bool: + return False + + @property + def changes_input(self) -> bool: + return True + + +class Sequential(PassBase): + """Run a sequence of passes in order.""" + + def __init__(self, *passes: PassBase): + self.passes = passes + + @property + def in_place(self) -> bool: + return all(pass_.in_place for pass_ in self.passes) + + @property + def changes_input(self) -> bool: + return self.passes[0].changes_input or self.passes[0].in_place + + def call(self, model: ir.Model) -> PassResult: + modified = False + for i, pass_ in enumerate(self.passes): + logger.debug("Running the %s-th pass '%s'", i, pass_) + try: + pass_result = pass_(model) + except Exception as e: + prev_pass_names = [str(p) for p in self.passes[:i]] + raise PassError( + f"An error occurred when running the '{pass_}' pass after the " + f"following passes: {prev_pass_names}" + ) from e + + model = pass_result.model + modified = modified or pass_result.modified + + return PassResult(model, modified) + + +class PassManager(Sequential): """Pass manager for the IR. The PassManager is a Pass that runs a sequence of passes on a model. @@ -146,52 +229,26 @@ def __init__( self, passes: Sequence[PassBase], steps: int = 1, + early_stop=True, ): # TODO(justinchuby): Implement constraints - self.passes = list(passes) + super().__init__(*passes) self.steps = steps - - @property - def in_place(self) -> bool: - """Whether the pass modifies the model in place.""" - return all(pass_.in_place for pass_ in self.passes) - - @property - def destructive(self) -> bool: - """Whether the pass will destroy the input model when ``in_place=False``.""" - # This logic is a little conservative, but it is ok for now - return any(pass_.destructive for pass_ in self.passes) + self.early_stop = early_stop def call(self, model: ir.Model) -> PassResult: """Run the set of passes `steps` number of times or until the graph stops changing.""" overall_modified = False for step in range(self.steps): - step_result = self._run_one_step(model, step) + try: + step_result = super().__call__(model) + except Exception as e: + raise PassError(f"An error occurred at step {step}") from e model = step_result.model modified = step_result.modified overall_modified = overall_modified or modified # If the graph no longer changes, then we can stop running these passes - if not modified: + if not modified and self.early_stop: logger.info("PassManager: No more graph changes detected after step %s", step) break return PassResult(model, overall_modified) - - def _run_one_step(self, model: ir.Model, step: int) -> PassResult: - modified = False - for i, pass_ in enumerate(self.passes): - logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step) - try: - pass_result = pass_(model) - except (PreconditionError, PostconditionError): - raise - except Exception as e: - prev_pass_names = [str(p) for p in self.passes[:i]] - raise PassError( - f"An error occurred when running the '{pass_}' pass after the " - f"following passes: {prev_pass_names} during step {step}" - ) from e - - model = pass_result.model - modified = modified or pass_result.modified - - return PassResult(model, modified) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index a40dc76293..db3386f89d 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -797,7 +797,7 @@ def merge_dims(dim1, dim2): return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) -class FoldConstantsPass(ir.passes.PassBase): +class FoldConstantsPass(ir.passes.InPlacePass): def __init__( self, *, diff --git a/onnxscript/optimizer/_remove_unused.py b/onnxscript/optimizer/_remove_unused.py index e1e0136ddb..e160d895ee 100644 --- a/onnxscript/optimizer/_remove_unused.py +++ b/onnxscript/optimizer/_remove_unused.py @@ -82,7 +82,7 @@ def _process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int return count -class RemoveUnusedNodesPass(ir.passes.PassBase): +class RemoveUnusedNodesPass(ir.passes.InPlacePass): def call(self, model: ir.Model) -> ir.passes.PassResult: count = _process_function_or_graph(model.graph) graph_outputs = frozenset(model.graph.outputs) diff --git a/onnxscript/optimizer/_remove_unused_function.py b/onnxscript/optimizer/_remove_unused_function.py index dedf69d91d..64d2643ab2 100644 --- a/onnxscript/optimizer/_remove_unused_function.py +++ b/onnxscript/optimizer/_remove_unused_function.py @@ -25,7 +25,7 @@ def _clean_up_unused_functions(model: ir.Model, unused: set[ir.OperatorIdentifie logger.debug("Functions removed: %s", unused) -class RemoveUnusedFunctionPass(ir.passes.PassBase): +class RemoveUnusedFunctionPass(ir.passes.InPlacePass): def __init__(self): super().__init__() self.used: set[ir.OperatorIdentifier] | None = None From dd5d17beff007e45b3b52dde3fda36e86af6dc28 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Mar 2025 08:04:49 -0700 Subject: [PATCH 11/19] copilot review --- onnxscript/ir/passes/_pass_infra.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index a204a52112..4877b31751 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -186,6 +186,8 @@ class Sequential(PassBase): """Run a sequence of passes in order.""" def __init__(self, *passes: PassBase): + if not passes: + raise ValueError("Sequential must take at least one pass") self.passes = passes @property From 2d77e3685fe9e2d5e1d0a821ba515e125afaff28 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Mar 2025 08:06:04 -0700 Subject: [PATCH 12/19] early_stop --- onnxscript/ir/passes/_pass_infra.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 4877b31751..5468d5be86 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -225,13 +225,14 @@ class PassManager(Sequential): Attributes: passes: The passes to run. steps: The number of times to run the passes. + early_stop: Whether to stop running the passes if the graph stops changing. """ def __init__( self, passes: Sequence[PassBase], steps: int = 1, - early_stop=True, + early_stop: bool = True, ): # TODO(justinchuby): Implement constraints super().__init__(*passes) From 5f906c923c5931377146b903c45cf4fa915e318e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 11:03:52 -0700 Subject: [PATCH 13/19] Update shape_inference --- onnxscript/ir/passes/common/shape_inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index 7502ecbf79..f90f48a3e9 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -22,12 +22,9 @@ _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB -class ShapeInferencePass(ir.passes.PassBase): +class ShapeInferencePass(ir.passes.OutOfPlacePass): """This pass performs shape inference on the graph.""" - # This pass does not modify the model in place. - in_place = False - def __init__( self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True ) -> None: From b38528a3aa5d6adac36de9f20b6661b6ded5f786 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 13:31:36 -0700 Subject: [PATCH 14/19] Remove DestructivePass --- onnxscript/ir/passes/__init__.py | 2 -- onnxscript/ir/passes/_pass_infra.py | 13 ------------- 2 files changed, 15 deletions(-) diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py index bfc58b9f95..9d1be233e4 100644 --- a/onnxscript/ir/passes/__init__.py +++ b/onnxscript/ir/passes/__init__.py @@ -8,7 +8,6 @@ "Sequential", "InPlacePass", "OutOfPlacePass", - "DestructivePass", # Errors "InvariantError", "PreconditionError", @@ -17,7 +16,6 @@ ] from onnxscript.ir.passes._pass_infra import ( - DestructivePass, InPlacePass, InvariantError, OutOfPlacePass, diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index bcc5e65255..4f98094727 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -23,7 +23,6 @@ "Sequential", "InPlacePass", "OutOfPlacePass", - "DestructivePass", "PassManager", "PassResult", # Errors @@ -170,18 +169,6 @@ def changes_input(self) -> bool: return False -class DestructivePass(PassBase): - """A pass that modifies the input model and returns a new model.""" - - @property - def in_place(self) -> bool: - return False - - @property - def changes_input(self) -> bool: - return True - - class Sequential(PassBase): """Run a sequence of passes in order.""" From 460a41a287257154b98caeb04caf552ae5546a02 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 15:06:31 -0700 Subject: [PATCH 15/19] Cache _changes_input --- onnxscript/ir/passes/_pass_infra.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 4f98094727..f07c115373 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -176,14 +176,16 @@ def __init__(self, *passes: PassBase): if not passes: raise ValueError("Sequential must take at least one pass") self.passes = passes + self._in_place = all(pass_.in_place for pass_ in passes) + self._changes_input = self.passes[0].changes_input or self.passes[0].in_place @property def in_place(self) -> bool: - return all(pass_.in_place for pass_ in self.passes) + return self._in_place @property def changes_input(self) -> bool: - return self.passes[0].changes_input or self.passes[0].in_place + return self._changes_input def call(self, model: ir.Model) -> PassResult: modified = False From 996551ff582c4aad0d08766e212a3b7770cf253a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 16:21:17 -0700 Subject: [PATCH 16/19] FunctionalPass --- onnxscript/ir/passes/__init__.py | 4 ++-- onnxscript/ir/passes/_pass_infra.py | 4 ++-- onnxscript/ir/passes/common/shape_inference.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py index 9d1be233e4..8a18c1b72f 100644 --- a/onnxscript/ir/passes/__init__.py +++ b/onnxscript/ir/passes/__init__.py @@ -7,7 +7,7 @@ "PassManager", "Sequential", "InPlacePass", - "OutOfPlacePass", + "FunctionalPass", # Errors "InvariantError", "PreconditionError", @@ -16,9 +16,9 @@ ] from onnxscript.ir.passes._pass_infra import ( + FunctionalPass, InPlacePass, InvariantError, - OutOfPlacePass, PassBase, PassError, PassManager, diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index f07c115373..51db9eb829 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -22,7 +22,7 @@ "PassBase", "Sequential", "InPlacePass", - "OutOfPlacePass", + "FunctionalPass", "PassManager", "PassResult", # Errors @@ -157,7 +157,7 @@ def changes_input(self) -> bool: return True -class OutOfPlacePass(PassBase): +class FunctionalPass(PassBase): """A pass that returns a new model but does not modify the input model.""" @property diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index f90f48a3e9..f6d88584e7 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -22,7 +22,7 @@ _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB -class ShapeInferencePass(ir.passes.OutOfPlacePass): +class ShapeInferencePass(ir.passes.FunctionalPass): """This pass performs shape inference on the graph.""" def __init__( From 870309760ef23cfe5f9fe4c2729a8d0dc11594c4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 24 Mar 2025 16:35:08 -0700 Subject: [PATCH 17/19] Update onnxscript/ir/passes/_pass_infra.py --- onnxscript/ir/passes/_pass_infra.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 51db9eb829..18d9908b77 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -177,6 +177,9 @@ def __init__(self, *passes: PassBase): raise ValueError("Sequential must take at least one pass") self.passes = passes self._in_place = all(pass_.in_place for pass_ in passes) + # The reason changes_inputs is decided by the first pass is that if the first pass is either in-place, + # or if it is not designed to be in-place but somehow changes the input (destructive), + # this pass sequence will change inputs. self._changes_input = self.passes[0].changes_input or self.passes[0].in_place @property From 9b5264137c1f46ec78b14cf8c598833f1814d967 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 25 Mar 2025 09:12:14 -0700 Subject: [PATCH 18/19] docs --- onnxscript/ir/passes/_pass_infra.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 18d9908b77..c4fe3d868a 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -69,7 +69,20 @@ class PassResult: class PassBase(abc.ABC): - """Base class for all passes.""" + """Base class for all passes. + + + ``in_place`` and ``changes_input`` properties and what they mean: + + +------------+------------------+----------------------------+ + | | changes_inputs | not changes_inputs | + +------------+------------------+----------------------------+ + | in_place | in place | Side-effect-only pass | + +------------+------------------+----------------------------+ + | not | destructive | functional | + | in_place | | | + +------------+------------------+----------------------------+ + """ @property @abc.abstractmethod From 0b059b9cae76c1023057dc0f9f0775e450c6b628 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 25 Mar 2025 09:12:35 -0700 Subject: [PATCH 19/19] NotImplementedError --- onnxscript/ir/passes/_pass_infra.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index c4fe3d868a..e6cd5fbbb9 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -92,13 +92,13 @@ def in_place(self) -> bool: If True, the pass will return the same model object that was passed in. If False, the pass will return a new model object. """ - return True + raise NotImplementedError @property @abc.abstractmethod def changes_input(self) -> bool: """Whether the pass modifies input model.""" - return True + raise NotImplementedError @property def destructive(self) -> bool: