Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 112 additions & 62 deletions src/flashpack/integrations/diffusers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
)
from diffusers.utils.torch_utils import get_device, is_compiled_module
from huggingface_hub import DDUFEntry, create_repo, read_dduf_file, snapshot_download
from packaging import version
from typing_extensions import Self

from ...constants import (
Expand Down Expand Up @@ -68,7 +67,6 @@
pass

if is_transformers_available():
import transformers
from transformers import PreTrainedModel

LIBRARIES = []
Expand Down Expand Up @@ -105,6 +103,8 @@ def save_pretrained_flashpack(
silent: bool = True,
num_workers: int = DEFAULT_NUM_WRITE_WORKERS,
target_dtype: torch.dtype | dict[str, torch.dtype] | None = None,
convert_diffusers_models: bool = False,
convert_transformers_models: bool = False,
**kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -137,18 +137,54 @@ def is_saveable_module(name, value):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
sub_model_dir = os.path.join(save_directory, pipeline_component_name)
model_cls = sub_model.__class__

# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
if is_compiled_module(sub_model):
sub_model = _unwrap_model(sub_model)

model_cls = sub_model.__class__
sub_model_target_dtype = (
target_dtype.get(pipeline_component_name, None)
if isinstance(target_dtype, dict)
else target_dtype
)
is_diffusers_model = isinstance(sub_model, ModelMixin)
is_transformers_model = is_transformers_available() and isinstance(
sub_model, PreTrainedModel
)
is_flashpack_diffusers_model = isinstance(
sub_model, FlashPackDiffusersModelMixin
)
is_flashpack_transformers_model = isinstance(
sub_model, FlashPackTransformersModelMixin
)

if (
is_diffusers_model
and not is_flashpack_diffusers_model
and convert_diffusers_models
):
sub_model.__class__ = type(
f"FlashPackAutoClass{model_cls.__name__}",
(FlashPackDiffusersModelMixin, model_cls),
{},
)
is_flashpack_diffusers_model = True

if isinstance(
sub_model,
(FlashPackDiffusersModelMixin, FlashPackTransformersModelMixin),
if (
is_transformers_model
and not is_flashpack_transformers_model
and convert_transformers_models
):
sub_model.__class__ = type(
f"FlashPackAutoClass{model_cls.__name__}",
(FlashPackTransformersModelMixin, model_cls),
{},
)
is_flashpack_transformers_model = True

if is_flashpack_diffusers_model or is_flashpack_transformers_model:
os.makedirs(sub_model_dir, exist_ok=True)
sub_model.save_pretrained_flashpack(
sub_model_dir,
Expand All @@ -160,12 +196,6 @@ def is_saveable_module(name, value):
)
continue

# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
if is_compiled_module(sub_model):
sub_model = _unwrap_model(sub_model)
model_cls = sub_model.__class__

save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
Expand Down Expand Up @@ -248,6 +278,8 @@ def from_pretrained_flashpack(
silent: bool = True,
use_distributed_loading: bool = False,
coerce_dtype: bool = False,
convert_diffusers_models: bool = False,
convert_transformers_models: bool = False,
**kwargs: Any,
) -> Self:
"""
Expand Down Expand Up @@ -618,6 +650,8 @@ def load_module(name, value):
local_rank=local_rank,
world_size=world_size,
coerce_dtype=coerce_dtype,
convert_diffusers_models=convert_diffusers_models,
convert_transformers_models=convert_transformers_models,
)
except Exception as e:
raise RuntimeError(
Expand Down Expand Up @@ -725,6 +759,8 @@ def load_sub_model_flashpack(
local_rank: int | None = None,
world_size: int | None = None,
coerce_dtype: bool = False,
convert_diffusers_models: bool = False,
convert_transformers_models: bool = False,
) -> Any:
"""
Helper method to load the module `name` from `library_name` and `class_name`.
Expand All @@ -735,6 +771,10 @@ def load_sub_model_flashpack(
"""
from diffusers.quantizers import PipelineQuantizationConfig

is_auto_class = class_name.startswith("FlashPackAutoClass")
if is_auto_class:
class_name = class_name[len("FlashPackAutoClass") :]

# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name,
Expand All @@ -746,33 +786,68 @@ def load_sub_model_flashpack(
cache_dir=cached_folder,
)

# Check if flashpack
if issubclass(class_obj, FlashPackDiffusersModelMixin) or issubclass(
component_dir = os.path.join(cached_folder, name)
is_diffusers_model = issubclass(class_obj, ModelMixin)
is_transformers_model = is_transformers_available() and issubclass(
class_obj, PreTrainedModel
)
is_flashpack_diffusers_model = issubclass(class_obj, FlashPackDiffusersModelMixin)
is_flashpack_transformers_model = issubclass(
class_obj, FlashPackTransformersModelMixin
)

if (
is_diffusers_model
and not is_flashpack_diffusers_model
and (convert_diffusers_models or is_auto_class)
):
component_dir = os.path.join(cached_folder, name)
class_obj = type(
f"FlashPackAutoClass{class_obj.__name__}",
(FlashPackDiffusersModelMixin, class_obj),
{},
)
is_flashpack_diffusers_model = True

if device is None:
if device_map in ["cuda", "auto", "balanced"] and torch.cuda.is_available():
device_index = torch.cuda.current_device()
device = torch.device(f"cuda:{device_index}")
else:
device = torch.device("cpu")

return class_obj.from_pretrained_flashpack(
component_dir,
device=device,
num_streams=num_streams,
chunk_bytes=chunk_bytes,
ignore_names=ignore_names,
ignore_prefixes=ignore_prefixes,
silent=silent,
use_distributed_loading=use_distributed_loading,
rank=rank,
local_rank=local_rank,
world_size=world_size,
coerce_dtype=coerce_dtype,
).to(device)
if (
is_transformers_model
and not is_flashpack_transformers_model
and (convert_transformers_models or is_auto_class)
):
class_obj = type(
f"FlashPackAutoClass{class_obj.__name__}",
(FlashPackTransformersModelMixin, class_obj),
{},
)
is_flashpack_transformers_model = True

# Check if flashpack
if is_flashpack_diffusers_model or is_flashpack_transformers_model:
flashpack_model_path = os.path.join(component_dir, "model.flashpack")
if os.path.exists(flashpack_model_path):
if device is None:
if (
device_map in ["cuda", "auto", "balanced"]
and torch.cuda.is_available()
):
device_index = torch.cuda.current_device()
device = torch.device(f"cuda:{device_index}")
else:
device = torch.device("cpu")

return class_obj.from_pretrained_flashpack(
component_dir,
device=device,
num_streams=num_streams,
chunk_bytes=chunk_bytes,
ignore_names=ignore_names,
ignore_prefixes=ignore_prefixes,
silent=silent,
use_distributed_loading=use_distributed_loading,
rank=rank,
local_rank=local_rank,
world_size=world_size,
coerce_dtype=coerce_dtype,
).to(device)

load_method_name = None
# retrieve load method name
Expand Down Expand Up @@ -808,21 +883,6 @@ def load_sub_model_flashpack(
loading_kwargs["sess_options"] = sess_options
loading_kwargs["provider_options"] = provider_options

is_diffusers_model = issubclass(class_obj, ModelMixin)

if is_transformers_available():
transformers_version = version.parse(
version.parse(transformers.__version__).base_version
)
else:
transformers_version = "N/A"

is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)

# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
Expand All @@ -837,17 +897,7 @@ def load_sub_model_flashpack(
if from_flax:
loading_kwargs["from_flax"] = True

# the following can be deleted once the minimum required `transformers` version
# is higher than 4.27
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
if is_transformers_model and loading_kwargs["variant"] is None:
loading_kwargs.pop("variant")

# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
Expand Down
75 changes: 6 additions & 69 deletions tests/test_wan_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,15 @@
import os
from typing import Optional

import pytest
import torch
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
from diffusers.pipelines import WanPipeline
from diffusers.schedulers import UniPCMultistepScheduler
from flashpack.integrations.diffusers import (
FlashPackDiffusersModelMixin,
FlashPackDiffusionPipeline,
)
from flashpack.integrations.transformers import FlashPackTransformersModelMixin
from flashpack.integrations.diffusers import FlashPackDiffusionPipeline
from flashpack.utils import timer
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, UMT5EncoderModel


class FlashPackWanTransformer3DModel(
WanTransformer3DModel, FlashPackDiffusersModelMixin
):
flashpack_ignore_prefixes = ["rope"]


class FlashPackAutoencoderKLWan(AutoencoderKLWan, FlashPackDiffusersModelMixin):
pass


class FlashPackUMT5EncoderModel(UMT5EncoderModel, FlashPackTransformersModelMixin):
flashpack_ignore_names = ["encoder.embed_tokens.weight"]


class FlashPackWanPipeline(WanPipeline, FlashPackDiffusionPipeline):
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: FlashPackUMT5EncoderModel,
vae: FlashPackAutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
transformer: Optional[FlashPackWanTransformer3DModel] = None,
transformer_2: Optional[FlashPackWanTransformer3DModel] = None,
boundary_ratio: float | None = None,
expand_timesteps: bool = False,
):
super().__init__(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
transformer_2=transformer_2,
scheduler=scheduler,
boundary_ratio=boundary_ratio,
expand_timesteps=expand_timesteps,
)
pass


HERE = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -74,31 +32,10 @@ def pipeline_dir():
@pytest.fixture(scope="module")
def saved_pipeline(repo_dir, pipeline_dir):
"""Save the pipeline using flashpack and return the path."""
transformer = FlashPackWanTransformer3DModel.from_pretrained(
os.path.join(repo_dir, "transformer"),
torch_dtype=torch.bfloat16,
).to(dtype=torch.bfloat16)
vae = FlashPackAutoencoderKLWan.from_pretrained(
os.path.join(repo_dir, "vae"),
torch_dtype=torch.float32,
).to(dtype=torch.float32)
text_encoder = FlashPackUMT5EncoderModel.from_pretrained(
os.path.join(repo_dir, "text_encoder"),
torch_dtype=torch.bfloat16,
).to(dtype=torch.bfloat16)
scheduler = UniPCMultistepScheduler.from_pretrained(
os.path.join(repo_dir, "scheduler"),
)
tokenizer = AutoTokenizer.from_pretrained(
os.path.join(repo_dir, "tokenizer"),
)

pipeline = FlashPackWanPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
pipeline = FlashPackWanPipeline.from_pretrained_flashpack(
repo_dir,
convert_diffusers_models=True,
convert_transformers_models=True,
)

with timer("save"):
Expand Down
Loading