From cb79401c3155b0c2c60859b77d6a3a243ea6be0b Mon Sep 17 00:00:00 2001
From: Benjamin Paine
Date: Wed, 11 Mar 2026 18:01:05 -0400
Subject: [PATCH 1/2] Feat: auto-convert models to flashpack

---
 .../integrations/diffusers/pipeline.py | 175 +++++++++++-------
 tests/test_wan_pipeline.py             |  75 +-------
 2 files changed, 119 insertions(+), 131 deletions(-)

diff --git a/src/flashpack/integrations/diffusers/pipeline.py b/src/flashpack/integrations/diffusers/pipeline.py
index e03e88a..768ff66 100644
--- a/src/flashpack/integrations/diffusers/pipeline.py
+++ b/src/flashpack/integrations/diffusers/pipeline.py
@@ -27,7 +27,6 @@
 )
 from diffusers.utils.torch_utils import get_device, is_compiled_module
 from huggingface_hub import DDUFEntry, create_repo, read_dduf_file, snapshot_download
-from packaging import version
 from typing_extensions import Self
 
 from ...constants import (
@@ -68,7 +67,6 @@
     pass
 
 if is_transformers_available():
-    import transformers
     from transformers import PreTrainedModel
 
 LIBRARIES = []
@@ -105,6 +103,8 @@ def save_pretrained_flashpack(
         silent: bool = True,
         num_workers: int = DEFAULT_NUM_WRITE_WORKERS,
         target_dtype: torch.dtype | dict[str, torch.dtype] | None = None,
+        convert_diffusers_models: bool = False,
+        convert_transformers_models: bool = False,
         **kwargs: Any,
     ) -> None:
         """
@@ -137,18 +137,54 @@ def is_saveable_module(name, value):
         for pipeline_component_name in model_index_dict.keys():
             sub_model = getattr(self, pipeline_component_name)
             sub_model_dir = os.path.join(save_directory, pipeline_component_name)
-            model_cls = sub_model.__class__
+            # Dynamo wraps the original model in a private class.
+            # I didn't find a public API to get the original class.
+            if is_compiled_module(sub_model):
+                sub_model = _unwrap_model(sub_model)
+
+            model_cls = sub_model.__class__
 
             sub_model_target_dtype = (
                 target_dtype.get(pipeline_component_name, None)
                 if isinstance(target_dtype, dict)
                 else target_dtype
             )
+            is_diffusers_model = isinstance(sub_model, ModelMixin)
+            is_transformers_model = is_transformers_available() and isinstance(
+                sub_model, PreTrainedModel
+            )
+            is_flashpack_diffusers_model = isinstance(
+                sub_model, FlashPackDiffusersModelMixin
+            )
+            is_flashpack_transformers_model = isinstance(
+                sub_model, FlashPackTransformersModelMixin
+            )
+
+            if (
+                is_diffusers_model
+                and not is_flashpack_diffusers_model
+                and convert_diffusers_models
+            ):
+                sub_model.__class__ = type(
+                    f"FlashPackAutoClass{model_cls.__name__}",
+                    (FlashPackDiffusersModelMixin, model_cls),
+                    {},
+                )
+                is_flashpack_diffusers_model = True
 
-            if isinstance(
-                sub_model,
-                (FlashPackDiffusersModelMixin, FlashPackTransformersModelMixin),
+            if (
+                is_transformers_model
+                and not is_flashpack_transformers_model
+                and convert_transformers_models
             ):
+                sub_model.__class__ = type(
+                    f"FlashPackAutoClass{model_cls.__name__}",
+                    (FlashPackTransformersModelMixin, model_cls),
+                    {},
+                )
+                is_flashpack_transformers_model = True
+
+            if is_flashpack_diffusers_model or is_flashpack_transformers_model:
                 os.makedirs(sub_model_dir, exist_ok=True)
                 sub_model.save_pretrained_flashpack(
                     sub_model_dir,
@@ -160,12 +196,6 @@ def is_saveable_module(name, value):
                 )
                 continue
 
-            # Dynamo wraps the original model in a private class.
-            # I didn't find a public API to get the original class.
-            if is_compiled_module(sub_model):
-                sub_model = _unwrap_model(sub_model)
-                model_cls = sub_model.__class__
-
             save_method_name = None
             # search for the model's base class in LOADABLE_CLASSES
             for library_name, library_classes in LOADABLE_CLASSES.items():
@@ -248,6 +278,8 @@ def from_pretrained_flashpack(
         silent: bool = True,
         use_distributed_loading: bool = False,
         coerce_dtype: bool = False,
+        convert_diffusers_models: bool = False,
+        convert_transformers_models: bool = False,
         **kwargs: Any,
     ) -> Self:
         """
@@ -618,6 +650,8 @@ def load_module(name, value):
                     local_rank=local_rank,
                     world_size=world_size,
                     coerce_dtype=coerce_dtype,
+                    convert_diffusers_models=convert_diffusers_models,
+                    convert_transformers_models=convert_transformers_models,
                 )
             except Exception as e:
                 raise RuntimeError(
@@ -725,6 +759,8 @@ def load_sub_model_flashpack(
     local_rank: int | None = None,
     world_size: int | None = None,
     coerce_dtype: bool = False,
+    convert_diffusers_models: bool = False,
+    convert_transformers_models: bool = False,
 ) -> Any:
     """
     Helper method to load the module `name` from `library_name` and `class_name`.
@@ -735,6 +771,11 @@ def load_sub_model_flashpack(
     """
    from diffusers.quantizers import PipelineQuantizationConfig
 
+    print(f"Loading {class_name} from {library_name}")
+    is_auto_class = class_name.startswith("FlashPackAutoClass")
+    if is_auto_class:
+        class_name = class_name[len("FlashPackAutoClass") :]
+
     # retrieve class candidates
     class_obj, class_candidates = get_class_obj_and_candidates(
         library_name,
@@ -746,33 +787,68 @@ def load_sub_model_flashpack(
         cache_dir=cached_folder,
     )
 
-    # Check if flashpack
-    if issubclass(class_obj, FlashPackDiffusersModelMixin) or issubclass(
+    component_dir = os.path.join(cached_folder, name)
+    is_diffusers_model = issubclass(class_obj, ModelMixin)
+    is_transformers_model = is_transformers_available() and issubclass(
+        class_obj, PreTrainedModel
+    )
+    is_flashpack_diffusers_model = issubclass(class_obj, FlashPackDiffusersModelMixin)
+    is_flashpack_transformers_model = issubclass(
         class_obj, FlashPackTransformersModelMixin
+    )
+
+    if (
+        is_diffusers_model
+        and not is_flashpack_diffusers_model
+        and (convert_diffusers_models or is_auto_class)
     ):
-        component_dir = os.path.join(cached_folder, name)
+        class_obj = type(
+            f"FlashPackAutoClass{class_obj.__name__}",
+            (FlashPackDiffusersModelMixin, class_obj),
+            {},
+        )
+        is_flashpack_diffusers_model = True
 
-        if device is None:
-            if device_map in ["cuda", "auto", "balanced"] and torch.cuda.is_available():
-                device_index = torch.cuda.current_device()
-                device = torch.device(f"cuda:{device_index}")
-            else:
-                device = torch.device("cpu")
-
-        return class_obj.from_pretrained_flashpack(
-            component_dir,
-            device=device,
-            num_streams=num_streams,
-            chunk_bytes=chunk_bytes,
-            ignore_names=ignore_names,
-            ignore_prefixes=ignore_prefixes,
-            silent=silent,
-            use_distributed_loading=use_distributed_loading,
-            rank=rank,
-            local_rank=local_rank,
-            world_size=world_size,
-            coerce_dtype=coerce_dtype,
-        ).to(device)
+    if (
+        is_transformers_model
+        and not is_flashpack_transformers_model
+        and (convert_transformers_models or is_auto_class)
+    ):
+        class_obj = type(
+            f"FlashPackAutoClass{class_obj.__name__}",
+            (FlashPackTransformersModelMixin, class_obj),
+            {},
+        )
+        is_flashpack_transformers_model = True
+
+    # Check if flashpack
+    if is_flashpack_diffusers_model or is_flashpack_transformers_model:
+        flashpack_model_path = os.path.join(component_dir, "model.flashpack")
+        if os.path.exists(flashpack_model_path):
+            if device is None:
+                if (
+                    device_map in ["cuda", "auto", "balanced"]
+                    and torch.cuda.is_available()
+                ):
+                    device_index = torch.cuda.current_device()
+                    device = torch.device(f"cuda:{device_index}")
+                else:
+                    device = torch.device("cpu")
+
+            return class_obj.from_pretrained_flashpack(
+                component_dir,
+                device=device,
+                num_streams=num_streams,
+                chunk_bytes=chunk_bytes,
+                ignore_names=ignore_names,
+                ignore_prefixes=ignore_prefixes,
+                silent=silent,
+                use_distributed_loading=use_distributed_loading,
+                rank=rank,
+                local_rank=local_rank,
+                world_size=world_size,
+                coerce_dtype=coerce_dtype,
+            ).to(device)
 
     load_method_name = None
     # retrieve load method name
@@ -808,21 +884,6 @@ def load_sub_model_flashpack(
         loading_kwargs["sess_options"] = sess_options
         loading_kwargs["provider_options"] = provider_options
 
-    is_diffusers_model = issubclass(class_obj, ModelMixin)
-
-    if is_transformers_available():
-        transformers_version = version.parse(
-            version.parse(transformers.__version__).base_version
-        )
-    else:
-        transformers_version = "N/A"
-
-    is_transformers_model = (
-        is_transformers_available()
-        and issubclass(class_obj, PreTrainedModel)
-        and transformers_version >= version.parse("4.20.0")
-    )
-
     # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
     # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
     # This makes sure that the weights won't be initialized which significantly speeds up loading.
@@ -837,17 +898,7 @@ def load_sub_model_flashpack(
         if from_flax:
             loading_kwargs["from_flax"] = True
 
-    # the following can be deleted once the minimum required `transformers` version
-    # is higher than 4.27
-    if (
-        is_transformers_model
-        and loading_kwargs["variant"] is not None
-        and transformers_version < version.parse("4.27.0")
-    ):
-        raise ImportError(
-            f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
-        )
-    elif is_transformers_model and loading_kwargs["variant"] is None:
+    if is_transformers_model and loading_kwargs["variant"] is None:
         loading_kwargs.pop("variant")
 
     # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
diff --git a/tests/test_wan_pipeline.py b/tests/test_wan_pipeline.py
index 68ab69a..3e5784e 100644
--- a/tests/test_wan_pipeline.py
+++ b/tests/test_wan_pipeline.py
@@ -1,57 +1,15 @@
 import os
-from typing import Optional
 
 import pytest
 import torch
-from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
 from diffusers.pipelines import WanPipeline
-from diffusers.schedulers import UniPCMultistepScheduler
-from flashpack.integrations.diffusers import (
-    FlashPackDiffusersModelMixin,
-    FlashPackDiffusionPipeline,
-)
-from flashpack.integrations.transformers import FlashPackTransformersModelMixin
+from flashpack.integrations.diffusers import FlashPackDiffusionPipeline
 from flashpack.utils import timer
 from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer, UMT5EncoderModel
-
-
-class FlashPackWanTransformer3DModel(
-    WanTransformer3DModel, FlashPackDiffusersModelMixin
-):
-    flashpack_ignore_prefixes = ["rope"]
-
-
-class FlashPackAutoencoderKLWan(AutoencoderKLWan, FlashPackDiffusersModelMixin):
-    pass
-
-
-class FlashPackUMT5EncoderModel(UMT5EncoderModel, FlashPackTransformersModelMixin):
-    flashpack_ignore_names = ["encoder.embed_tokens.weight"]
 
 
 class FlashPackWanPipeline(WanPipeline, FlashPackDiffusionPipeline):
-    def __init__(
-        self,
-        tokenizer: AutoTokenizer,
-        text_encoder: FlashPackUMT5EncoderModel,
-        vae: FlashPackAutoencoderKLWan,
-        scheduler: UniPCMultistepScheduler,
-        transformer: Optional[FlashPackWanTransformer3DModel] = None,
-        transformer_2: Optional[FlashPackWanTransformer3DModel] = None,
-        boundary_ratio: float | None = None,
-        expand_timesteps: bool = False,
-    ):
-        super().__init__(
-            tokenizer=tokenizer,
-            text_encoder=text_encoder,
-            vae=vae,
-            transformer=transformer,
-            transformer_2=transformer_2,
-            scheduler=scheduler,
-            boundary_ratio=boundary_ratio,
-            expand_timesteps=expand_timesteps,
-        )
+    pass
 
 
 HERE = os.path.dirname(os.path.abspath(__file__))
@@ -74,31 +32,10 @@ def pipeline_dir():
 @pytest.fixture(scope="module")
 def saved_pipeline(repo_dir, pipeline_dir):
     """Save the pipeline using flashpack and return the path."""
-    transformer = FlashPackWanTransformer3DModel.from_pretrained(
-        os.path.join(repo_dir, "transformer"),
-        torch_dtype=torch.bfloat16,
-    ).to(dtype=torch.bfloat16)
-    vae = FlashPackAutoencoderKLWan.from_pretrained(
-        os.path.join(repo_dir, "vae"),
-        torch_dtype=torch.float32,
-    ).to(dtype=torch.float32)
-    text_encoder = FlashPackUMT5EncoderModel.from_pretrained(
-        os.path.join(repo_dir, "text_encoder"),
-        torch_dtype=torch.bfloat16,
-    ).to(dtype=torch.bfloat16)
-    scheduler = UniPCMultistepScheduler.from_pretrained(
-        os.path.join(repo_dir, "scheduler"),
-    )
-    tokenizer = AutoTokenizer.from_pretrained(
-        os.path.join(repo_dir, "tokenizer"),
-    )
-
-    pipeline = FlashPackWanPipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer,
-        scheduler=scheduler,
+    pipeline = FlashPackWanPipeline.from_pretrained_flashpack(
+        repo_dir,
+        convert_diffusers_models=True,
+        convert_transformers_models=True,
     )
 
     with timer("save"):

From 8191f4239412ea6e1f48018771c9ca85450c3144 Mon Sep 17 00:00:00 2001
From: Benjamin Paine
Date: Wed, 11 Mar 2026 18:01:40 -0400
Subject: [PATCH 2/2] Remove print

---
 src/flashpack/integrations/diffusers/pipeline.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/flashpack/integrations/diffusers/pipeline.py b/src/flashpack/integrations/diffusers/pipeline.py
index 768ff66..c00776e 100644
--- a/src/flashpack/integrations/diffusers/pipeline.py
+++ b/src/flashpack/integrations/diffusers/pipeline.py
@@ -771,7 +771,6 @@ def load_sub_model_flashpack(
     """
     from diffusers.quantizers import PipelineQuantizationConfig
 
-    print(f"Loading {class_name} from {library_name}")
     is_auto_class = class_name.startswith("FlashPackAutoClass")
     if is_auto_class:
         class_name = class_name[len("FlashPackAutoClass") :]
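
A minimal usage sketch of the new flags, modeled on the updated test_wan_pipeline.py; the checkpoint paths below are placeholders, not values from the patch:

    from diffusers.pipelines import WanPipeline
    from flashpack.integrations.diffusers import FlashPackDiffusionPipeline


    # A bare subclass is enough now; the hand-written per-component mixin
    # classes deleted above (e.g. FlashPackWanTransformer3DModel) are not needed.
    class FlashPackWanPipeline(WanPipeline, FlashPackDiffusionPipeline):
        pass


    # Plain ModelMixin / PreTrainedModel components are rewrapped on the fly
    # into generated FlashPackAutoClass* subclasses; components without a
    # model.flashpack file fall back to the regular loading path.
    pipeline = FlashPackWanPipeline.from_pretrained_flashpack(
        "/path/to/wan-checkpoint",  # placeholder; the test uses a snapshot_download dir
        convert_diffusers_models=True,
        convert_transformers_models=True,
    )

    # The same flags exist on save_pretrained_flashpack, so a freshly loaded
    # pipeline can be written straight to flashpack component files.
    pipeline.save_pretrained_flashpack(
        "/path/to/flashpack-pipeline",  # placeholder
        convert_diffusers_models=True,
        convert_transformers_models=True,
    )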