From 73f69ceec768ff384eb66e58041b2e87bd6cc594 Mon Sep 17 00:00:00 2001 From: ByteDance Date: Mon, 27 Jan 2025 11:20:26 +0800 Subject: [PATCH] Fix for PyTorch2.x: remove _optimize_trace & handle torch_node attributes --- hiddenlayer/pytorch_builder.py | 85 +++++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 23 deletions(-) diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..43edb8a 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -2,7 +2,7 @@ HiddenLayer PyTorch graph importer. - + Written by Waleed Abdulla Licensed under the MIT License """ @@ -18,7 +18,7 @@ # Hide onnx: prefix ht.Rename(op=r"onnx::(.*)", to=r"\1"), # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication). - # It's an odd name that noone recognizes. Rename it. + # It's an odd name that no one recognizes. Rename it. ht.Rename(op=r"Gemm", to=r"Linear"), # PyTorch layers that don't have an ONNX counterpart ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"), @@ -49,10 +49,9 @@ def get_shape(torch_node): """Return the output shape of the given Pytorch node.""" # Extract node output shape from the node string representation # This is a hack because there doesn't seem to be an official way to do it. - # See my quesiton in the PyTorch forum: + # See question in the PyTorch forum: # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2 - # TODO: find a better way to extract output shape - # TODO: Assuming the node has one output. Update if we encounter a multi-output node. + # Assuming the node has one output. Update if we encounter a multi-output node. m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs()))) if m: shape = m.group(1) @@ -64,34 +63,74 @@ def get_shape(torch_node): def import_graph(hl_graph, model, args, input_names=None, verbose=False): - # TODO: add input names to graph - - # Run the Pytorch graph to get a trace and generate a graph from it + """ + Build a hiddenlayer Graph from a PyTorch JIT trace. + """ + # 1) Get trace graph trace, out = torch.jit._get_trace_graph(model, args) - torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) - # Dump list of nodes (DEBUG only) + # 2) Comment out or remove the _optimize_trace call: + # torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) + torch_graph = trace # Use trace directly + + # Debug: optionally dump the list of nodes if verbose: dump_pytorch_graph(torch_graph) - # Loop through nodes and build HL graph + # 3) Traverse the PyTorch graph and build hiddenlayer nodes/edges for torch_node in torch_graph.nodes(): - # Op + # a) Operator kind op = torch_node.kind() - # Parameters - params = {k: torch_node[k] for k in torch_node.attributeNames()} - # Inputs/outputs - # TODO: inputs = [i.unique() for i in node.inputs()] + + # b) Gather attributes (fix 'torch_node[k]' error) + # Use kindOf(k) + corresponding accessor + params = {} + for k in torch_node.attributeNames(): + kind = torch_node.kindOf(k) + if kind == "f": + params[k] = torch_node.f(k) + elif kind == "i": + params[k] = torch_node.i(k) + elif kind == "s": + params[k] = torch_node.s(k) + elif kind == "t": + # e.g. tensor attribute + params[k] = str(torch_node.t(k)) # or more specialized logic + elif kind == "fs": + params[k] = torch_node.fs(k) + elif kind == "is": + params[k] = torch_node.is_(k) + elif kind == "ss": + params[k] = torch_node.ss(k) + else: + # If there's an unrecognized type, store the type name + params[k] = f"<{kind}>" + + # c) Node outputs outputs = [o.unique() for o in torch_node.outputs()] - # Get output shape + + # d) Infer shape from outputs shape = get_shape(torch_node) - # Add HL node - hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op, - output_shape=shape, params=params) + + # e) Create HL node + hl_node = Node( + uid=pytorch_id(torch_node), + name=None, + op=op, + output_shape=shape, + params=params + ) hl_graph.add_node(hl_node) - # Add edges + + # f) Link edges to next nodes that consume these outputs for target_torch_node in torch_graph.nodes(): target_inputs = [i.unique() for i in target_torch_node.inputs()] + # If any output from this node is in target's input => link them if set(outputs) & set(target_inputs): - hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape) - return hl_graph + hl_graph.add_edge_by_id( + pytorch_id(torch_node), + pytorch_id(target_torch_node), + shape + ) + + return hl_graph \ No newline at end of file