Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions examples/text_to_image/train_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from apex.optimizers import FusedAdam
import torch
import torch.nn.functional as F
from accelerate import Accelerator
import time
import functools

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def enable_tf32():
    """Opt in to TF32 math for matmuls and cuDNN convolutions (Ampere+ GPUs)."""
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True

def batch_inp(bs20inp, target_bs):
    """Build a synthetic batch of size ``target_bs`` by tiling the first rows of
    ``bs20inp``.

    The first ``mbs`` (=2) rows are used as a seed and repeated (detached from
    any autograd history) until ``target_bs`` rows are available. 4-D outputs
    are converted to channels-last layout, matching the original behavior.

    Fixes vs. the previous version:
      * ``target_bs == mbs`` used to return the *entire* input tensor rather
        than a batch of size ``mbs``;
      * a non-multiple of ``mbs`` used to return a truncated batch because of
        ``int(target_bs/mbs)``; we now ceil-divide and trim;
      * removed a leftover debug ``print``.

    Args:
        bs20inp: source tensor with batch on dim 0 (at least ``target_bs`` rows
            when ``target_bs < mbs``, at least ``mbs`` rows otherwise).
        target_bs: desired output batch size.

    Returns:
        A tensor with ``target_bs`` rows.
    """
    mbs = 2
    seed = bs20inp[:mbs]
    if target_bs <= mbs:
        return bs20inp[:target_bs]
    # Ceil-divide so batch sizes that are not a multiple of mbs are covered,
    # then trim back down to exactly target_bs rows.
    num = -(-target_bs // mbs)
    out = torch.cat([seed.clone().detach() for _ in range(num)])[:target_bs]
    if out.dim() == 4:
        # Image-like batches are benchmarked in channels-last memory format.
        out = out.to(memory_format=torch.channels_last).contiguous()
    return out



def train(model, vae, optimizer_class, batchsize, use_zero=False, use_amp=True, h=512, w=512, is_xl=False):
    """Benchmark a UNet training loop on synthetic random data.

    Runs 20 iterations of: VAE-encode (frozen, no_grad) -> UNet forward under
    optional fp16 autocast -> MSE loss against random targets -> backward ->
    optimizer step. Prints the VAE encode time each step, the peak CUDA memory,
    and the wall-clock time of iterations 12..20 (the first 11 act as warmup).

    Args:
        model: the UNet (already on CUDA, in train mode).
        vae: AutoencoderKL used only to produce latents; never trained here.
        optimizer_class: callable mapping an iterable of params to an optimizer.
        batchsize: number of synthetic samples per step.
        use_zero: when True, wrap the model in FSDP with SHARD_GRAD_OP sharding.
        use_amp: when True, run the UNet forward under float16 autocast.
        h, w: spatial size of the synthetic image fed to the VAE.
        is_xl: when True, use the SDXL call signature (added_cond_kwargs).
    """
    # Synthetic conditioning inputs. NOTE(review): prompt_embeds' last dim of
    # 768 matches the SD 1.x text-encoder width; SDXL normally expects a wider
    # embedding — confirm before trusting the is_xl=True path.
    timesteps = torch.arange(batchsize, dtype=torch.int64).cuda()+100
    prompt_embeds = torch.rand([batchsize,77,768], dtype=torch.float16).cuda()
    time_ids = torch.rand([batchsize,6], dtype=torch.float16).cuda()
    text_embeds = torch.rand([batchsize,1280], dtype=torch.float16).cuda()
    encoder_hidden_states = torch.rand([batchsize,77,768], dtype=torch.float32).cuda()

    model_input = torch.rand([batchsize, 3, h, w], dtype=torch.float32).cuda()

    # Without autocast the fp16 conditioning tensors must be upcast manually.
    if not use_amp:
        prompt_embeds = prompt_embeds.float()
        text_embeds = text_embeds.float()
        time_ids = time_ids.float()

    # Extra conditioning consumed only by the SDXL call path below.
    unet_added_conditions = {
        "time_ids": time_ids,
        "text_embeds": text_embeds
    }


    # Trade compute for memory; part of the configuration being benchmarked.
    model.enable_gradient_checkpointing()
    # model.enable_xformers_memory_efficient_attention()
    # torch._dynamo.config.suppress_errors = True
    # model=torch.compile(model)
    perf_times = []
    if use_zero:
        # model = DDP(model)
        # SHARD_GRAD_OP: shards gradients + optimizer state, keeps params whole.
        model = FSDP(model, sharding_strategy=torch.distributed.fsdp.ShardingStrategy.SHARD_GRAD_OP)
        opt =optimizer_class(model.parameters())
        # opt = ZeroRedundancyOptimizer(model.parameters(),
        #                 optimizer_class=optimizer_class,
        #                 parameters_as_bucket_view=True,)
    else:
        opt =optimizer_class(model.parameters())

    from torch.profiler import profile, record_function, ProfilerActivity

    # Profiler kept disabled by default; uncomment together with prof.step()
    # and prof.stop() below to capture a trace.
    # prof = torch.profiler.profile(
    #     schedule=torch.profiler.schedule(wait=0, warmup=5, active=1, repeat=1),
    #     on_trace_ready=torch.profiler.tensorboard_trace_handler('./prof/unet_720p_tp2'),
    #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    #     record_shapes=True,
    #     with_stack=True,
    #     profile_memory=True)
    # prof.start()
    for ind in range(20):
        # if ind==5:
        beg = time.time()

        # Synchronize so the timed region measures GPU work, not launch latency.
        torch.cuda.synchronize()
        vae_beg = time.time()
        with torch.no_grad():
            # 0.18215 is the SD latent scaling factor applied to VAE samples.
            noisy_model_input = vae.encode(model_input).latent_dist.sample().mul_(0.18215)
        torch.cuda.synchronize()
        print("vae time:", time.time()-vae_beg)


        with torch.autocast(dtype=torch.float16, device_type='cuda', enabled=use_amp):
            # with accelerator.accumulate(unet):
            # import pdb;pdb.set_trace()
            # print("before fwd", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9)
            # with torch.no_grad():

            # model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
            # print("after fwd", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9)
            # import pdb;pdb.set_trace()
            # loss = F.mse_loss(model_pred.float(), torch.rand_like(model_pred).float(), reduction="mean")
            # print("after fwd", torch.cuda.max_memory_allocated()/1e9)
            if is_xl:
                # SDXL signature: text embeds go positionally, pooled embeds and
                # time ids through added_cond_kwargs.
                model_pred = model(
                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
                ).sample
                # print("after fwd", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9)
                # import pdb;pdb.set_trace()
                # Random target: this loop benchmarks throughput, not learning.
                loss = F.mse_loss(model_pred.float(), torch.rand_like(model_pred).float(), reduction="mean")
            else:
                model_pred = model(noisy_model_input, timesteps, encoder_hidden_states).sample
                # print("after fwd", torch.cuda.memory_allocated()/1e9, torch.cuda.max_memory_allocated()/1e9)
                # import pdb;pdb.set_trace()
                loss = F.mse_loss(model_pred.float(), torch.rand_like(model_pred).float(), reduction="mean")

            # loss = F.mse_loss(model_pred.float(), torch.rand_like(model_pred).float(), reduction="mean")
            # print(loss)

        # Backward runs outside the autocast region, per torch.autocast docs.
        loss.backward()
        # print("after bwd", torch.cuda.max_memory_allocated()/1e9)

        opt.step()
        opt.zero_grad()
        torch.cuda.synchronize()
        # Skip the first 11 iterations as warmup before recording timings.
        if ind>10:
            perf_times.append(time.time()-beg)
            beg=time.time()
        # prof.step()

    # if torch.distributed.get_rank()==0:
    print("max mem", torch.cuda.max_memory_allocated()/1e9)
    print(perf_times)
    # prof.stop()



enable_tf32()
# rank, world_size, port, addr=setup_distributed_slurm()

# Checkpoint to benchmark. The SDXL id is kept for easy switching; the SD 1.4
# assignment is the one in effect. (The original code assigned the SDXL id and
# immediately overwrote it — the dead store is now a commented alternative,
# matching the file's style for the optimizer choices below.)
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"

unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=None,
    low_cpu_mem_usage=False, device_map=None
).cuda()
unet.train()
# unet = unet.to(memory_format=torch.channels_last)

# The VAE is only used (frozen) to encode synthetic images into latents.
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").cuda()


# optimizer_class = FusedAdam
# fused=True runs the Adam update in a single fused CUDA kernel.
optimizer_class = functools.partial(torch.optim.Adam, fused=True)
# optimizer_class = torch.optim.AdamW
train(unet, vae, optimizer_class, 16,
      use_amp=True, use_zero=False, h=512, w=512, is_xl='xl' in pretrained_model_name_or_path)
58 changes: 46 additions & 12 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear

from .tp_utils import ColParallelLinear,RowParallelLinear, gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, _split_along_first_dim

@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
Expand Down Expand Up @@ -62,9 +62,11 @@ def __init__(
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
sequence_parallel = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.sequence_parallel = sequence_parallel

self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
Expand All @@ -91,6 +93,7 @@ def __init__(
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
sequence_parallel = sequence_parallel
)

# 2. Cross-Attn
Expand All @@ -111,14 +114,16 @@ def __init__(
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
sequence_parallel = sequence_parallel
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None

# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout,
sequence_parallel = sequence_parallel)

# let chunk size default to None
self._chunk_size = None
Expand All @@ -138,6 +143,7 @@ def forward(
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
sequence_parallel=False
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
Expand All @@ -160,14 +166,14 @@ def forward(
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output

hidden_states = attn_output + hidden_states

# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)

attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
Expand Down Expand Up @@ -201,6 +207,7 @@ def forward(
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = ff_output + hidden_states
setattr(hidden_states, "sequence_parallel", True)

return hidden_states

Expand All @@ -226,32 +233,33 @@ def __init__(
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
sequence_parallel = False,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.sequence_parallel = sequence_parallel

assert activation_fn == "geglu"
act_fn = GEGLU_ColParallel(dim, inner_dim)

if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)

self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
proj2 = RowParallelLinear(inner_dim, dim_out, sequence_parallel = sequence_parallel)
# LoRACompatibleLinear(inner_dim, dim_out)
self.net.append(proj2)
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states):
    """Run the feed-forward stack over `hidden_states`.

    Under sequence parallelism the input arrives sharded along the sequence
    dimension, so the full sequence is gathered before the dense projections.
    """
    x = hidden_states
    if self.sequence_parallel:
        x = gather_from_sequence_parallel_region(x)
    for layer in self.net:
        x = layer(x)
    return x
Expand Down Expand Up @@ -300,9 +308,35 @@ def gelu(self, gate):

def forward(self, hidden_states):
    """Project the input, split into (value, gate) halves, apply gated GELU.

    Fix: removed a leftover `import pdb;pdb.set_trace()` debugger breakpoint
    that halted every forward pass through this activation.
    """
    # chunk(2, dim=-1) splits the doubled projection into value and gate.
    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
    return hidden_states * self.gelu(gate)


class GEGLU_ColParallel(nn.Module):
    r"""Column-parallel variant of the gated GELU (GEGLU) activation from
    https://arxiv.org/abs/2002.05202.

    The projection doubles the output width: one half is the value path and
    the other half is the gate, which is passed through GELU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # Tensor-parallel (column-sharded) replacement for the dense proj.
        self.proj = ColParallelLinear(dim_in, dim_out * 2)

    def gelu(self, gate):
        # mps lacks a float16 gelu kernel, so round-trip through float32 there.
        if gate.device.type == "mps":
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        value, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return value * self.gelu(gate)

class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
Expand Down
17 changes: 12 additions & 5 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..utils.import_utils import is_xformers_available
from .lora import LoRALinearLayer

from .tp_utils import ColParallelLinear,RowParallelLinear, gather_from_sequence_parallel_region

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
residual_connection: bool = False,
_from_deprecated_attn_block=False,
processor: Optional["AttnProcessor"] = None,
sequence_parallel = False,
):
super().__init__()
inner_dim = dim_head * heads
Expand All @@ -96,6 +98,7 @@ def __init__(

self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
self.sequence_parallel = sequence_parallel

if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
Expand Down Expand Up @@ -135,12 +138,12 @@ def __init__(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)

self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

# self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_q = ColParallelLinear(query_dim, inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_k = ColParallelLinear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = ColParallelLinear(cross_attention_dim, inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
Expand All @@ -150,7 +153,7 @@ def __init__(
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
self.to_out.append(RowParallelLinear(inner_dim, query_dim, bias=out_bias, sequence_parallel=sequence_parallel))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
Expand Down Expand Up @@ -316,6 +319,10 @@ def set_processor(self, processor: "AttnProcessor"):
self.processor = processor

def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
if self.sequence_parallel:
hidden_states = gather_from_sequence_parallel_region(hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = gather_from_sequence_parallel_region(encoder_hidden_states)
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
Expand Down
Loading