diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..96c38cd 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -68,7 +68,7 @@ def import_graph(hl_graph, model, args, input_names=None, verbose=False): # Run the Pytorch graph to get a trace and generate a graph from it trace, out = torch.jit._get_trace_graph(model, args) - torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) + torch_graph = torch.onnx._optimize_graph(trace, torch.onnx.OperatorExportTypes.ONNX) # Dump list of nodes (DEBUG only) if verbose: