Error when torch.compile() and torchao.autoquant() are used together #45

@ankit-vaidya19

Description

torch.compile() and torchao.autoquant() each work on their own, but using them together raises the error below.

torch - 2.5.1 + cu124
torchao - 0.8.0 + cu124
diffusers - 0.33.0.dev0
GPU - A100 80GB

Code -

from diffusers import FluxPipeline
from torchao.quantization import autoquant
import torch 

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = autoquant(torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True), error_on_unseen=False)
image = pipeline(
    "a dog surfing on moon", guidance_scale=3.5, num_inference_steps=50
).images[0]

Error -

  0%|          | 0/50 [02:00<?, ?it/s]
Traceback (most recent call last):
  File "/flux_dev.py", line 10, in <module>
    image = pipeline(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 912, in __call__
    noise_pred = self.transformer(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
    return inner()
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
    result = forward_call(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 952, in _compile
    raise InternalTorchDynamoError(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1318, in compile_and_call_fx_graph
    fx.GraphModule(root, self.graph),
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/fx/graph_module.py", line 471, in __init__
    self.graph = graph
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2032, in __setattr__
    super().__setattr__(name, value)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/fx/graph_module.py", line 518, in graph
    self.recompile()
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/fx/graph_module.py", line 770, in recompile
    cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in fx_forward_from_src_skip_result
    result = original_forward_from_src(src, globals, co_fields)
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/fx/graph_module.py", line 90, in _forward_from_src
    return _method_from_src(
  File "/anaconda3/envs/optim/lib/python3.10/site-packages/torch/fx/graph_module.py", line 100, in _method_from_src
    _exec_with_source(src, globals_copy, co_fields)
  File "anaconda3/envs/optim/lib/python3.10/site-packages/torch/fx/graph_module.py", line 86, in _exec_with_source
    exec(compile(src, key, "exec"), globals)
torch._dynamo.exc.InternalTorchDynamoError: SystemError: excessive stack use: stack is 6202 deep
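The last frames show torch.fx turning the traced graph into Python source and passing it to the built-in `compile()`. The final `SystemError` appears to come from CPython's bytecode compiler, which rejects a code object that needs too deep a value stack ("stack is 6202 deep"). The sketch below is an illustration only and uses just the standard library. It does not reproduce the FX-generated `forward` from this log, and the helpers `make_source` and `stack_size` are hypothetical. It shows that the same work written as one wide unpacking needs a stack proportional to the number of values, while separate statements keep the stack shallow. That required depth (`co_stacksize`) is the quantity the error message reports.

```python
# Illustration only: how the shape of generated source affects the value-stack
# depth CPython computes for it (exposed as code.co_stacksize). This does not
# reproduce the FX-generated forward from the traceback above.


def make_source(n: int, wide: bool) -> str:
    """Build source for a function that pulls n values out of a tuple `t`."""
    names = [f"a{i}" for i in range(n)]
    if wide:
        # One unpacking statement: UNPACK_SEQUENCE pushes all n values at once.
        body = f"    {', '.join(names)} = t\n"
    else:
        # n separate statements: only a couple of values are live at a time.
        body = "".join(f"    {name} = t[{i}]\n" for i, name in enumerate(names))
    return f"def f(t):\n{body}    return None\n"


def stack_size(src: str) -> int:
    """Compile and exec the source (as the last traceback frame does), then
    return the stack depth CPython computed for the generated function."""
    namespace: dict = {}
    exec(compile(src, "<generated>", "exec"), namespace)
    return namespace["f"].__code__.co_stacksize


if __name__ == "__main__":
    n = 200
    print("wide:", stack_size(make_source(n, wide=True)))
    print("narrow:", stack_size(make_source(n, wide=False)))
```

With `n = 200` both versions compile fine. The point is only that the wide form's stack requirement grows with `n`. Which construct in the FX-generated code pushes the depth past the limit here is not visible in the log.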
