diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py index 9cea129d2b..8a18c1b72f 100644 --- a/onnxscript/ir/passes/__init__.py +++ b/onnxscript/ir/passes/__init__.py @@ -5,6 +5,9 @@ "PassBase", "PassResult", "PassManager", + "Sequential", + "InPlacePass", + "FunctionalPass", # Errors "InvariantError", "PreconditionError", @@ -13,6 +16,8 @@ ] from onnxscript.ir.passes._pass_infra import ( + FunctionalPass, + InPlacePass, InvariantError, PassBase, PassError, @@ -20,6 +25,7 @@ PassResult, PostconditionError, PreconditionError, + Sequential, ) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 0d11a23814..e6cd5fbbb9 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -20,6 +20,9 @@ __all__ = [ "PassBase", + "Sequential", + "InPlacePass", + "FunctionalPass", "PassManager", "PassResult", # Errors @@ -68,14 +71,72 @@ class PassResult: class PassBase(abc.ABC): """Base class for all passes. - Class attributes: - in_place: Whether the pass modifies the model in place. + + ``in_place`` and ``changes_input`` properties and what they mean: + + +------------+------------------+----------------------------+ + | | changes_inputs | not changes_inputs | + +------------+------------------+----------------------------+ + | in_place | in place | Side-effect-only pass | + +------------+------------------+----------------------------+ + | not | destructive | functional | + | in_place | | | + +------------+------------------+----------------------------+ """ - in_place: bool = True + @property + @abc.abstractmethod + def in_place(self) -> bool: + """Whether the pass modifies the model in place and returns it. + + If True, the pass will return the same model object that was passed in. + If False, the pass will return a new model object. + """ + raise NotImplementedError + + @property + @abc.abstractmethod + def changes_input(self) -> bool: + """Whether the pass modifies input model.""" + raise NotImplementedError + + @property + def destructive(self) -> bool: + """Whether the pass will destroy the input model when ``in_place=False``. + + A pass is destructive if it is not in place and it modifies the input model. + """ + return not self.in_place and self.changes_input def __call__(self, model: ir.Model) -> PassResult: - return self.call(model) + # Check preconditions + try: + self.requires(model) + except PreconditionError: + raise + except Exception as e: + raise PreconditionError( + f"Pre-condition for pass '{self.__class__.__name__}' failed" + ) from e + + result = self.call(model) + + # Check postconditions + try: + self.ensures(model) + except PostconditionError: + raise + except Exception as e: + raise PostconditionError( + f"Post-condition for pass '{self.__class__.__name__}' failed" + ) from e + + if not isinstance(result, PassResult): + raise TypeError( + f"The result of the pass '{self.__class__.__name__}' should be type PassResult. " + "Please create one with ir.passes.PassResult()." + ) + return result @abc.abstractmethod def call(self, model: ir.Model) -> PassResult: @@ -97,76 +158,105 @@ def ensures(self, model: ir.Model) -> None: del model # Unused -class PassManager: +class InPlacePass(PassBase): + """A pass that modifies the input model in place and returns it.""" + + @property + def in_place(self) -> bool: + return True + + @property + def changes_input(self) -> bool: + return True + + +class FunctionalPass(PassBase): + """A pass that returns a new model but does not modify the input model.""" + + @property + def in_place(self) -> bool: + return False + + @property + def changes_input(self) -> bool: + return False + + +class Sequential(PassBase): + """Run a sequence of passes in order.""" + + def __init__(self, *passes: PassBase): + if not passes: + raise ValueError("Sequential must take at least one pass") + self.passes = passes + self._in_place = all(pass_.in_place for pass_ in passes) + # The reason changes_inputs is decided by the first pass is that if the first pass is either in-place, + # or if it is not designed to be in-place but somehow changes the input (destructive), + # this pass sequence will change inputs. + self._changes_input = self.passes[0].changes_input or self.passes[0].in_place + + @property + def in_place(self) -> bool: + return self._in_place + + @property + def changes_input(self) -> bool: + return self._changes_input + + def call(self, model: ir.Model) -> PassResult: + modified = False + for i, pass_ in enumerate(self.passes): + logger.debug("Running the %s-th pass '%s'", i, pass_) + try: + pass_result = pass_(model) + except Exception as e: + prev_pass_names = [str(p) for p in self.passes[:i]] + raise PassError( + f"An error occurred when running the '{pass_}' pass after the " + f"following passes: {prev_pass_names}" + ) from e + + model = pass_result.model + modified = modified or pass_result.modified + + return PassResult(model, modified) + + +class PassManager(Sequential): """Pass manager for the IR. - The PassManager is a callable that runs a sequence of passes on a model. + The PassManager is a Pass that runs a sequence of passes on a model. Attributes: passes: The passes to run. - check_invariants: Whether to check invariants before and after each pass. steps: The number of times to run the passes. + early_stop: Whether to stop running the passes if the graph stops changing. """ def __init__( self, passes: Sequence[PassBase], - check_invariants: bool = False, steps: int = 1, + early_stop: bool = True, ): # TODO(justinchuby): Implement constraints - self.passes = list(passes) - self.check_invariants = check_invariants + super().__init__(*passes) self.steps = steps + self.early_stop = early_stop - def __call__(self, model: ir.Model) -> PassResult: + def call(self, model: ir.Model) -> PassResult: """Run the set of passes `steps` number of times or until the graph stops changing.""" overall_modified = False for step in range(self.steps): - step_result = self._run_one_step(model, step) + try: + step_result = super().__call__(model) + except Exception as e: + raise PassError(f"An error occurred at step {step}") from e model = step_result.model modified = step_result.modified overall_modified = overall_modified or modified # If the graph no longer changes, then we can stop running these passes - if not modified: + if not modified and self.early_stop: logger.info("PassManager: No more graph changes detected after step %s", step) break return PassResult(model, overall_modified) - - def _run_one_step(self, model: ir.Model, step: int) -> PassResult: - modified = False - for i, pass_ in enumerate(self.passes): - logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step) - - # 1. Check preconditions - if self.check_invariants: - try: - pass_.requires(model) - except Exception as e: - raise PreconditionError(f"Pre-condition failed for {pass_}") from e - - # 2. Run the pass - try: - pass_result = pass_(model) - except Exception as e: - prev_pass_names = [str(p) for p in self.passes[:i]] - raise PassError( - f"An error occurred when running the '{pass_}' pass after the " - f"following passes: {prev_pass_names} during step {step}" - ) from e - if not isinstance(pass_result, PassResult): - raise TypeError( - f"The result of the pass {pass_} should be type PassResult." - "Please create one with ir.passes.PassResult()." - ) - - model = pass_result.model - modified = modified or pass_result.modified - - # 3. Check postconditions - if self.check_invariants: - try: - pass_.ensures(model) - except Exception as e: - raise PostconditionError(f"Post-condition failed for {pass_}") from e - return PassResult(model, modified) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index 7502ecbf79..f6d88584e7 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -22,12 +22,9 @@ _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB -class ShapeInferencePass(ir.passes.PassBase): +class ShapeInferencePass(ir.passes.FunctionalPass): """This pass performs shape inference on the graph.""" - # This pass does not modify the model in place. - in_place = False - def __init__( self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True ) -> None: diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index a40dc76293..db3386f89d 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -797,7 +797,7 @@ def merge_dims(dim1, dim2): return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) -class FoldConstantsPass(ir.passes.PassBase): +class FoldConstantsPass(ir.passes.InPlacePass): def __init__( self, *, diff --git a/onnxscript/optimizer/_remove_unused.py b/onnxscript/optimizer/_remove_unused.py index e1e0136ddb..e160d895ee 100644 --- a/onnxscript/optimizer/_remove_unused.py +++ b/onnxscript/optimizer/_remove_unused.py @@ -82,7 +82,7 @@ def _process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int return count -class RemoveUnusedNodesPass(ir.passes.PassBase): +class RemoveUnusedNodesPass(ir.passes.InPlacePass): def call(self, model: ir.Model) -> ir.passes.PassResult: count = _process_function_or_graph(model.graph) graph_outputs = frozenset(model.graph.outputs) diff --git a/onnxscript/optimizer/_remove_unused_function.py b/onnxscript/optimizer/_remove_unused_function.py index dedf69d91d..64d2643ab2 100644 --- a/onnxscript/optimizer/_remove_unused_function.py +++ b/onnxscript/optimizer/_remove_unused_function.py @@ -25,7 +25,7 @@ def _clean_up_unused_functions(model: ir.Model, unused: set[ir.OperatorIdentifie logger.debug("Functions removed: %s", unused) -class RemoveUnusedFunctionPass(ir.passes.PassBase): +class RemoveUnusedFunctionPass(ir.passes.InPlacePass): def __init__(self): super().__init__() self.used: set[ir.OperatorIdentifier] | None = None