From 28beefdaad18467b669925fa21cc62d9b9796c07 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 17:53:25 -0700 Subject: [PATCH 01/13] Graph extractor --- onnxscript/ir/_convenience.py | 2 + onnxscript/ir/convenience.py | 2 + onnxscript/ir/passes/_pass_infra.py | 40 ++--- .../ir/passes/common/graph_extration.py | 137 ++++++++++++++++++ 4 files changed, 163 insertions(+), 18 deletions(-) create mode 100644 onnxscript/ir/passes/common/graph_extration.py diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index d59bfe4797..e46de74285 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -12,6 +12,8 @@ "convert_attribute", "convert_attributes", "replace_all_uses_with", + "create_value_mapping", + "replace_nodes_and_values", ] import typing diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index fc8416cc1f..480ff603b0 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -9,11 +9,13 @@ "convert_attributes", "replace_all_uses_with", "replace_nodes_and_values", + "create_value_mapping", ] from onnxscript.ir._convenience import ( convert_attribute, convert_attributes, + create_value_mapping, replace_all_uses_with, replace_nodes_and_values, ) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index c03a23bd8b..48274e5fd5 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -70,12 +70,31 @@ class PassBase(abc.ABC): Class attributes: in_place: Whether the pass modifies the model in place. + destructive: Whether the pass will destroy the input model when ``in_place=False``. """ in_place: bool = True + destructive: bool = False def __call__(self, model: ir.Model) -> PassResult: - return self.call(model) + # Check preconditions + try: + self.requires(model) + except PreconditionError: + raise + except Exception as e: + raise PreconditionError("Pre-condition failed") from e + + result = self.call(model) + + # Check postconditions + try: + self.ensures(model) + except PostconditionError: + raise + except Exception as e: + raise PostconditionError("Post-condition failed") from e + return result @abc.abstractmethod def call(self, model: ir.Model) -> PassResult: @@ -111,12 +130,10 @@ class PassManager: def __init__( self, passes: Sequence[PassBase], - check_invariants: bool = False, steps: int = 1, ): # TODO(justinchuby): Implement constraints self.passes = list(passes) - self.check_invariants = check_invariants self.steps = steps def __call__(self, model: ir.Model) -> PassResult: @@ -137,17 +154,10 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult: modified = False for i, pass_ in enumerate(self.passes): logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step) - - # 1. Check preconditions - if self.check_invariants: - try: - pass_.requires(model) - except Exception as e: - raise PreconditionError(f"Pre-condition failed for {pass_}") from e - - # 2. Run the pass try: pass_result = pass_(model) + except (PreconditionError, PostconditionError): + raise except Exception as e: prev_pass_names = [str(p) for p in self.passes[:i]] raise PassError( @@ -163,10 +173,4 @@ def _run_one_step(self, model: ir.Model, step: int) -> PassResult: model = pass_result.model modified = modified or pass_result.modified - # 3. Check postconditions - if self.check_invariants: - try: - pass_.ensures(model) - except Exception as e: - raise PostconditionError(f"Post-condition failed for {pass_}") from e return PassResult(model, modified) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py new file mode 100644 index 0000000000..dfb749d6cd --- /dev/null +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Passes for extracting subgraphs from a graph.""" + +from __future__ import annotations +import itertools + +__all__ = [ + "ExtractGraphByNodePass", +] + +from collections.abc import Collection +import logging + +import onnx + +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +def _find_subgraph_bounded_by_values( + graph: ir.Graph, inputs: Collection[ir.Value], outputs: Collection[ir.Value] +) -> tuple[list[ir.Node], list[ir.Value]]: + """Finds the subgraph bounded by the given inputs and outputs. + + Args: + graph: The graph to search. + inputs: The inputs to the subgraph. + outputs: The outputs of the subgraph. + + Returns: + A list of nodes in the subgraph and the initializers used. + """ + all_nodes = [] + value_stack: list[ir.Value] = [*outputs] + visited_nodes: set[ir.Node] = set() + visited_values: set[ir.Value] = set() + initializers = [] + while value_stack: + value = value_stack.pop() + if value in visited_values: + continue + if value.name in graph.initializers: + # Record the initializer + assert value.const_value is not None + initializers.append(value) + visited_values.add(value) + if (node := value.producer()) is not None: + if node not in visited_nodes: + visited_nodes.add(node) + all_nodes.append(node) + for input in node.inputs: + if input not in visited_values and input is not None: + value_stack.append(input) + return all_nodes, initializers + + +class ExtractGraphByValuePass(ir.passes.PassBase): + """This pass performs shape inference on the graph.""" + + # This pass does not modify the model in place + in_place = False + # This pass destroys the input model + destructive = True + + def __init__(self, *, input_names: Collection[str], output_names: Collection[str]) -> None: + """Extracts sub-model from an ONNX model. + + The sub-model is defined by the names of the input and output tensors *exactly*. + + Args: + input_names: The names of the inputs to extract. + output_names: The names of the outputs to extract. + """ + super().__init__() + self.input_names = input_names + self.output_names = output_names + + def requires(self, model: ir.Model) -> None: + # All inputs and outputs can be found in the model + values = ir.convenience.create_value_mapping(model.graph) + input_names_not_found = sorted(set(self.input_names) - set(values.keys())) + if input_names_not_found: + raise ir.passes.PreconditionError( + f"Input names not found in the model: {input_names_not_found}" + ) + output_names_not_found = sorted(set(self.output_names) - set(values.keys())) + if output_names_not_found: + raise ir.passes.PreconditionError( + f"Output names not found in the model: {output_names_not_found}" + ) + + # All inputs and outputs must have type and shape + for name in itertools.chain(self.input_names, self.output_names): + value = values[name] + if value.type is None: + raise ir.passes.PreconditionError( + f"Value {name} does not have a type: {value}. " + "Consider setting its type or running shape inference first." + ) + if value.shape is None: + raise ir.passes.PreconditionError( + f"Value {name} does not have a shape: {value}. " + "Consider setting its shape or running shape inference first." + ) + + def call(self, model: ir.Model) -> ir.passes.PassResult: + values = ir.convenience.create_value_mapping(model.graph) + inputs = [values[name] for name in self.input_names] + outputs = [values[name] for name in self.output_names] + extracted_nodes, initializers = _find_subgraph_bounded_by_values( + model.graph, inputs, outputs + ) + # Create a graph with the extracted nodes + model.graph.remove(extracted_nodes) + new_model = ir.Model( + ir.Graph( + inputs, + outputs, + nodes=extracted_nodes, + initializers=initializers, + doc_string=model.graph.doc_string, + opset_imports=model.graph.opset_imports, + name=model.graph.name, + metadata_props=model.graph.metadata_props, + ), + ir_version=model.ir_version, + producer_name=model.producer_name, + producer_version=model.producer_version, + domain=model.domain, + model_version=model.model_version, + doc_string=model.doc_string, + functions=tuple(model.functions.values()), + meta_data_props=model.metadata_props, + ) + return ir.passes.PassResult(new_model, modified=True) From b4ab30f57184d7f9ef9650bfde2e557fc39b97d1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 17:55:08 -0700 Subject: [PATCH 02/13] rename --- onnxscript/ir/passes/common/graph_extration.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index dfb749d6cd..665baedf9f 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -3,16 +3,15 @@ """Passes for extracting subgraphs from a graph.""" from __future__ import annotations + import itertools __all__ = [ - "ExtractGraphByNodePass", + "ExtractGraphPass", ] -from collections.abc import Collection import logging - -import onnx +from collections.abc import Collection from onnxscript import ir @@ -56,7 +55,7 @@ def _find_subgraph_bounded_by_values( return all_nodes, initializers -class ExtractGraphByValuePass(ir.passes.PassBase): +class ExtractGraphPass(ir.passes.PassBase): """This pass performs shape inference on the graph.""" # This pass does not modify the model in place From 30fb4c4de297ea878f714601bd5d81bc674b083b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 18:04:12 -0700 Subject: [PATCH 03/13] Fix inputs --- .../ir/passes/common/graph_extration.py | 81 +++++++++++-------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index 665baedf9f..bbdf6fc19e 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -34,7 +34,7 @@ def _find_subgraph_bounded_by_values( all_nodes = [] value_stack: list[ir.Value] = [*outputs] visited_nodes: set[ir.Node] = set() - visited_values: set[ir.Value] = set() + visited_values: set[ir.Value] = set(inputs) initializers = [] while value_stack: value = value_stack.pop() @@ -69,41 +69,13 @@ def __init__(self, *, input_names: Collection[str], output_names: Collection[str The sub-model is defined by the names of the input and output tensors *exactly*. Args: - input_names: The names of the inputs to extract. - output_names: The names of the outputs to extract. + input_names: The names of the inputs to extract. Must be deduplicated. + output_names: The names of the outputs to extract. Must be deduplicated. """ super().__init__() self.input_names = input_names self.output_names = output_names - def requires(self, model: ir.Model) -> None: - # All inputs and outputs can be found in the model - values = ir.convenience.create_value_mapping(model.graph) - input_names_not_found = sorted(set(self.input_names) - set(values.keys())) - if input_names_not_found: - raise ir.passes.PreconditionError( - f"Input names not found in the model: {input_names_not_found}" - ) - output_names_not_found = sorted(set(self.output_names) - set(values.keys())) - if output_names_not_found: - raise ir.passes.PreconditionError( - f"Output names not found in the model: {output_names_not_found}" - ) - - # All inputs and outputs must have type and shape - for name in itertools.chain(self.input_names, self.output_names): - value = values[name] - if value.type is None: - raise ir.passes.PreconditionError( - f"Value {name} does not have a type: {value}. " - "Consider setting its type or running shape inference first." - ) - if value.shape is None: - raise ir.passes.PreconditionError( - f"Value {name} does not have a shape: {value}. " - "Consider setting its shape or running shape inference first." - ) - def call(self, model: ir.Model) -> ir.passes.PassResult: values = ir.convenience.create_value_mapping(model.graph) inputs = [values[name] for name in self.input_names] @@ -111,11 +83,25 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: extracted_nodes, initializers = _find_subgraph_bounded_by_values( model.graph, inputs, outputs ) - # Create a graph with the extracted nodes + model.graph.remove(extracted_nodes) + # Create inputs for the new graph as the old inputs are owned by the old nodes + new_inputs = [] + for input in inputs: + new_inputs.append( + ir.Value( + name=input.name, + shape=input.shape, + type=input.type, + doc_string=input.doc_string, + const_value=input.const_value, + ) + ) + ir.convenience.replace_all_uses_with(inputs, new_inputs) + new_model = ir.Model( ir.Graph( - inputs, + new_inputs, outputs, nodes=extracted_nodes, initializers=initializers, @@ -134,3 +120,32 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: meta_data_props=model.metadata_props, ) return ir.passes.PassResult(new_model, modified=True) + + def requires(self, model: ir.Model) -> None: + # All inputs and outputs can be found in the model + values = ir.convenience.create_value_mapping(model.graph) + input_names_not_found = sorted(set(self.input_names) - set(values.keys())) + if input_names_not_found: + raise ir.passes.PreconditionError( + f"Input names not found in the model: {input_names_not_found}" + ) + output_names_not_found = sorted(set(self.output_names) - set(values.keys())) + if output_names_not_found: + raise ir.passes.PreconditionError( + f"Output names not found in the model: {output_names_not_found}" + ) + + # All inputs and outputs must have type and shape + for name in itertools.chain(self.input_names, self.output_names): + value = values[name] + if value.type is None: + raise ir.passes.PreconditionError( + f"Value {name} does not have a type: {value}. " + "Consider setting its type or running shape inference first." + ) + if value.shape is None: + raise ir.passes.PreconditionError( + f"Value {name} does not have a shape: {value}. " + "Consider setting its shape or running shape inference first." + ) + # TODO(justinchuby): Make sure the subgraph is completely bounded by inputs and outputs From 7455c2ea79a3b509790a04628e59a45180a5359f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 18:07:47 -0700 Subject: [PATCH 04/13] change signature --- onnxscript/ir/passes/common/graph_extration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index bbdf6fc19e..9685f80a71 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -63,7 +63,7 @@ class ExtractGraphPass(ir.passes.PassBase): # This pass destroys the input model destructive = True - def __init__(self, *, input_names: Collection[str], output_names: Collection[str]) -> None: + def __init__(self, input_names: Collection[str], output_names: Collection[str]) -> None: """Extracts sub-model from an ONNX model. The sub-model is defined by the names of the input and output tensors *exactly*. From 78aefc2c469f6862ad5b61f53296e967b854822f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 18:11:27 -0700 Subject: [PATCH 05/13] warning --- onnxscript/ir/passes/common/graph_extration.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index 9685f80a71..a42a9d91e1 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -139,13 +139,15 @@ def requires(self, model: ir.Model) -> None: for name in itertools.chain(self.input_names, self.output_names): value = values[name] if value.type is None: - raise ir.passes.PreconditionError( - f"Value {name} does not have a type: {value}. " - "Consider setting its type or running shape inference first." + logger.warning( + "Value %s does not have a type: %s. " + "Consider setting its type or running shape inference first.", + name, value ) if value.shape is None: - raise ir.passes.PreconditionError( - f"Value {name} does not have a shape: {value}. " - "Consider setting its shape or running shape inference first." + logger.warning( + "Value %s does not have a shape: %s. " + "Consider setting its shape or running shape inference first.", + name, value ) # TODO(justinchuby): Make sure the subgraph is completely bounded by inputs and outputs From ca03ed93b9065e6be9729cbc4a2ab57732392d32 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 18:15:46 -0700 Subject: [PATCH 06/13] print --- onnxscript/ir/passes/common/graph_extration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index a42a9d91e1..aa83cc1153 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -140,13 +140,13 @@ def requires(self, model: ir.Model) -> None: value = values[name] if value.type is None: logger.warning( - "Value %s does not have a type: %s. " + "Value %%%s does not have a type: '%r'. " "Consider setting its type or running shape inference first.", name, value ) if value.shape is None: logger.warning( - "Value %s does not have a shape: %s. " + "Value %%%s does not have a shape: '%r'. " "Consider setting its shape or running shape inference first.", name, value ) From 716413f0f35daef2af20ae18b5042779146fc45b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 18:18:19 -0700 Subject: [PATCH 07/13] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/ir/passes/common/graph_extration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index aa83cc1153..3579214dda 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -56,8 +56,7 @@ def _find_subgraph_bounded_by_values( class ExtractGraphPass(ir.passes.PassBase): - """This pass performs shape inference on the graph.""" - + """This pass extracts a subgraph from the given graph.""" # This pass does not modify the model in place in_place = False # This pass destroys the input model From 307c62c530a66bf03727eb195fcdb092a0d44836 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 18:22:35 -0700 Subject: [PATCH 08/13] keep initializers --- onnxscript/ir/passes/common/graph_extration.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index 3579214dda..014aa99fc6 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -35,7 +35,7 @@ def _find_subgraph_bounded_by_values( value_stack: list[ir.Value] = [*outputs] visited_nodes: set[ir.Node] = set() visited_values: set[ir.Value] = set(inputs) - initializers = [] + initializers = [val for val in inputs if val.name in graph.initializers] while value_stack: value = value_stack.pop() if value in visited_values: @@ -57,6 +57,7 @@ def _find_subgraph_bounded_by_values( class ExtractGraphPass(ir.passes.PassBase): """This pass extracts a subgraph from the given graph.""" + # This pass does not modify the model in place in_place = False # This pass destroys the input model @@ -141,12 +142,14 @@ def requires(self, model: ir.Model) -> None: logger.warning( "Value %%%s does not have a type: '%r'. " "Consider setting its type or running shape inference first.", - name, value + name, + value, ) if value.shape is None: logger.warning( "Value %%%s does not have a shape: '%r'. " "Consider setting its shape or running shape inference first.", - name, value + name, + value, ) # TODO(justinchuby): Make sure the subgraph is completely bounded by inputs and outputs From b0d184353fc8a0b98c364ce4d1eb35bc5cde7fb6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:33:57 -0700 Subject: [PATCH 09/13] sort graph --- onnxscript/ir/passes/common/graph_extration.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index 014aa99fc6..e320400ded 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -31,6 +31,9 @@ def _find_subgraph_bounded_by_values( Returns: A list of nodes in the subgraph and the initializers used. """ + node_index = { + node: idx for idx, node in enumerate(graph) + } all_nodes = [] value_stack: list[ir.Value] = [*outputs] visited_nodes: set[ir.Node] = set() @@ -52,6 +55,8 @@ def _find_subgraph_bounded_by_values( for input in node.inputs: if input not in visited_values and input is not None: value_stack.append(input) + # Preserve the original order + all_nodes.sort(key=lambda n: node_index[n]) return all_nodes, initializers From 4479afddd078c3b50da84d06af7240591f175b58 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 21 Mar 2025 19:34:24 -0700 Subject: [PATCH 10/13] lint --- onnxscript/ir/passes/common/graph_extration.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index e320400ded..61f0b1dbfc 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -31,9 +31,7 @@ def _find_subgraph_bounded_by_values( Returns: A list of nodes in the subgraph and the initializers used. """ - node_index = { - node: idx for idx, node in enumerate(graph) - } + node_index = {node: idx for idx, node in enumerate(graph)} all_nodes = [] value_stack: list[ir.Value] = [*outputs] visited_nodes: set[ir.Node] = set() From 18d70b1b640e926672a6d97e6a388b82803b1489 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 26 Mar 2025 11:54:35 -0700 Subject: [PATCH 11/13] Update pass to be inplace --- .../ir/passes/common/graph_extration.py | 39 +++++++------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py index 61f0b1dbfc..910dce964f 100644 --- a/onnxscript/ir/passes/common/graph_extration.py +++ b/onnxscript/ir/passes/common/graph_extration.py @@ -58,14 +58,9 @@ def _find_subgraph_bounded_by_values( return all_nodes, initializers -class ExtractGraphPass(ir.passes.PassBase): +class ExtractGraphPass(ir.passes.InPlacePass): """This pass extracts a subgraph from the given graph.""" - # This pass does not modify the model in place - in_place = False - # This pass destroys the input model - destructive = True - def __init__(self, input_names: Collection[str], output_names: Collection[str]) -> None: """Extracts sub-model from an ONNX model. @@ -102,27 +97,19 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: ) ir.convenience.replace_all_uses_with(inputs, new_inputs) - new_model = ir.Model( - ir.Graph( - new_inputs, - outputs, - nodes=extracted_nodes, - initializers=initializers, - doc_string=model.graph.doc_string, - opset_imports=model.graph.opset_imports, - name=model.graph.name, - metadata_props=model.graph.metadata_props, - ), - ir_version=model.ir_version, - producer_name=model.producer_name, - producer_version=model.producer_version, - domain=model.domain, - model_version=model.model_version, - doc_string=model.doc_string, - functions=tuple(model.functions.values()), - meta_data_props=model.metadata_props, + # Replace the model graph + model.graph = ir.Graph( + new_inputs, + outputs, + nodes=extracted_nodes, + initializers=initializers, + doc_string=model.graph.doc_string, + opset_imports=model.graph.opset_imports, + name=model.graph.name, + metadata_props=model.graph.metadata_props, ) - return ir.passes.PassResult(new_model, modified=True) + + return ir.passes.PassResult(model, modified=True) def requires(self, model: ir.Model) -> None: # All inputs and outputs can be found in the model From 0fd082fa64f9584dab5e5c9a6b0867dd25851542 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 10 Apr 2025 21:43:57 -0700 Subject: [PATCH 12/13] Add tests for ExtractGraphPass in `graph_extration_test.py` * **Test extract subgraph**: Add a test case to validate the extraction of a subgraph with basic operations (Add, Mul). * **Test extract subgraph with initializers**: Add a test case to validate the extraction of a subgraph that includes initializers (Constant). * **Test extract subgraph with subgraph**: Add a test case to validate the extraction of a subgraph that includes nested subgraphs (If node with then and else branches). --- .../ir/passes/common/graph_extration_test.py | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 onnxscript/ir/passes/common/graph_extration_test.py diff --git a/onnxscript/ir/passes/common/graph_extration_test.py b/onnxscript/ir/passes/common/graph_extration_test.py new file mode 100644 index 0000000000..a7490b0013 --- /dev/null +++ b/onnxscript/ir/passes/common/graph_extration_test.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +from onnxscript import ir +from onnxscript.ir.passes.common.graph_extration import ExtractGraphPass + + +class TestExtractGraphPass(unittest.TestCase): + def test_extract_subgraph(self): + inputs = [ + ir.Value(name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))), + ir.Value(name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))), + ] + + add_node = ir.node("Add", inputs=inputs) + mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) + + model = ir.Model( + graph=ir.Graph( + inputs=inputs, + outputs=mul_node.outputs, + nodes=[add_node, mul_node], + opset_imports={"": 20}, + ), + ir_version=10, + ) + + # Perform extract graph pass + extract_pass = ExtractGraphPass(input_names=["input_a"], output_names=[mul_node.outputs[0].name]) + result = extract_pass(model) + self.assertTrue(result.modified) + self.assertEqual(len(result.model.graph.nodes), 2) + self.assertEqual(result.model.graph.nodes[0].op_type, "Add") + self.assertEqual(result.model.graph.nodes[1].op_type, "Mul") + + def test_extract_subgraph_with_initializers(self): + inputs = [ + ir.Value(name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))), + ir.Value(name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))), + ] + + constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy())) + const_node = ir.node( + "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1 + ) + add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]]) + mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) + + model = ir.Model( + graph=ir.Graph( + inputs=inputs, + outputs=mul_node.outputs, + nodes=[const_node, add_node, mul_node], + opset_imports={"": 20}, + ), + ir_version=10, + ) + + # Perform extract graph pass + extract_pass = ExtractGraphPass(input_names=["input_a"], output_names=[mul_node.outputs[0].name]) + result = extract_pass(model) + self.assertTrue(result.modified) + self.assertEqual(len(result.model.graph.nodes), 3) + self.assertEqual(result.model.graph.nodes[0].op_type, "Constant") + self.assertEqual(result.model.graph.nodes[1].op_type, "Add") + self.assertEqual(result.model.graph.nodes[2].op_type, "Mul") + + def test_extract_subgraph_with_subgraph(self): + input_value = ir.Value( + name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) + ) + + then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + then_const_node = ir.node( + "Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1 + ) + add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]]) + then_graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[then_const_node, add_node], + opset_imports={"": 20}, + ) + else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + else_const_node = ir.node( + "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1 + ) + mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]]) + else_graph = ir.Graph( + inputs=[input_value], + outputs=[mul_node.outputs[0]], + nodes=[else_const_node, mul_node], + opset_imports={"": 20}, + ) + cond_node = ir.node( + "If", + inputs=[input_value], + attributes={"then_branch": then_graph, "else_branch": else_graph}, + num_outputs=1, + ) + main_graph = ir.Graph( + inputs=[input_value], + outputs=cond_node.outputs, + nodes=[cond_node], + opset_imports={"": 20}, + ) + main_graph.sort() + model = ir.Model( + graph=main_graph, + ir_version=10, + ) + + # Perform extract graph pass + extract_pass = ExtractGraphPass(input_names=["input"], output_names=[cond_node.outputs[0].name]) + result = extract_pass(model) + self.assertTrue(result.modified) + self.assertEqual(len(result.model.graph.nodes), 1) + self.assertEqual(result.model.graph.nodes[0].op_type, "If") + self.assertEqual(len(result.model.graph.nodes[0].attributes["then_branch"].nodes), 2) + self.assertEqual(len(result.model.graph.nodes[0].attributes["else_branch"].nodes), 2) + + +if __name__ == "__main__": + unittest.main() From 35c45cf403d193e1fd2dbb7491e505aa61a334af Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 10 Apr 2025 21:49:07 -0700 Subject: [PATCH 13/13] --- .../ir/passes/common/graph_extration_test.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/onnxscript/ir/passes/common/graph_extration_test.py b/onnxscript/ir/passes/common/graph_extration_test.py index a7490b0013..e1ab63c56c 100644 --- a/onnxscript/ir/passes/common/graph_extration_test.py +++ b/onnxscript/ir/passes/common/graph_extration_test.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import unittest +import numpy as np from onnxscript import ir from onnxscript.ir.passes.common.graph_extration import ExtractGraphPass @@ -120,6 +121,34 @@ def test_extract_subgraph_with_subgraph(self): self.assertEqual(len(result.model.graph.nodes[0].attributes["then_branch"].nodes), 2) self.assertEqual(len(result.model.graph.nodes[0].attributes["else_branch"].nodes), 2) + def test_extract_partial_subgraph(self): + inputs = [ + ir.Value(name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))), + ir.Value(name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))), + ] + + add_node = ir.node("Add", inputs=inputs) + mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) + sub_node = ir.node("Sub", inputs=[mul_node.outputs[0], inputs[0]]) + + model = ir.Model( + graph=ir.Graph( + inputs=inputs, + outputs=sub_node.outputs, + nodes=[add_node, mul_node, sub_node], + opset_imports={"": 20}, + ), + ir_version=10, + ) + + # Perform extract graph pass + extract_pass = ExtractGraphPass(input_names=["input_a"], output_names=[mul_node.outputs[0].name]) + result = extract_pass(model) + self.assertTrue(result.modified) + self.assertEqual(len(result.model.graph.nodes), 2) + self.assertEqual(result.model.graph.nodes[0].op_type, "Add") + self.assertEqual(result.model.graph.nodes[1].op_type, "Mul") + if __name__ == "__main__": unittest.main()