diff --git a/examples/text_to_image/train_unet.py b/examples/text_to_image/train_unet.py new file mode 100644 index 000000000000..227fbd3dd691 --- /dev/null +++ b/examples/text_to_image/train_unet.py @@ -0,0 +1,164 @@ +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from apex.optimizers import FusedAdam +import torch +import torch.nn.functional as F +from accelerate import Accelerator +import time +import functools + +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy + + +def enable_tf32(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +def batch_inp(bs20inp, target_bs): + mbs=2 + bs4inp = bs20inp[:mbs] + if target_bs10: + perf_times.append(time.time()-beg) + beg=time.time() + # prof.step() + + # if torch.distributed.get_rank()==0: + print("max mem", torch.cuda.max_memory_allocated()/1e9) + print(perf_times) + # prof.stop() + + + +enable_tf32() +# rank, world_size, port, addr=setup_distributed_slurm() + +pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" + +pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" +unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet", revision=None, + low_cpu_mem_usage=False, device_map=None +).cuda() +unet.train() +# unet = unet.to(memory_format=torch.channels_last) + +vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").cuda() + + +# optimizer_class = FusedAdam +optimizer_class = functools.partial(torch.optim.Adam,fused = True) +# optimizer_class = torch.optim.AdamW +train(unet, vae, optimizer_class, 16, + use_amp=True, use_zero=False, h=512, w=512, is_xl ='xl' in pretrained_model_name_or_path) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ad899212e5a5..29ff024af079 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,7 @@ from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings from .lora import LoRACompatibleLinear - +from .tp_utils import ColParallelLinear,RowParallelLinear, gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, _split_along_first_dim @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): @@ -62,9 +62,11 @@ def __init__( norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", final_dropout: bool = False, + sequence_parallel = False, ): super().__init__() self.only_cross_attention = only_cross_attention + self.sequence_parallel = sequence_parallel self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" @@ -91,6 +93,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + sequence_parallel = sequence_parallel ) # 2. 
Cross-Attn @@ -111,6 +114,7 @@ def __init__( dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, + sequence_parallel = sequence_parallel ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None @@ -118,7 +122,8 @@ # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, + sequence_parallel = sequence_parallel) # let chunk size default to None self._chunk_size = None @@ -138,6 +143,7 @@ def forward( timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, + sequence_parallel=False ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention @@ -160,6 +166,7 @@ def forward( ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states # 2. Cross-Attention @@ -167,7 +174,6 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -201,6 +207,7 @@ def forward( ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states + setattr(hidden_states, "sequence_parallel", True) return hidden_states @@ -226,19 +233,16 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + sequence_parallel = False, ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim + self.sequence_parallel = sequence_parallel + + assert activation_fn == "geglu" + act_fn = GEGLU_ColParallel(dim, inner_dim) - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) self.net = nn.ModuleList([]) # project in @@ -246,12 +250,16 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + proj2 = RowParallelLinear(inner_dim, dim_out, sequence_parallel = sequence_parallel) + # LoRACompatibleLinear(inner_dim, dim_out) + self.net.append(proj2) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states): + if self.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) for module in self.net: hidden_states = module(hidden_states) return hidden_states @@ -300,9 +308,34 @@ def gelu(self, gate): def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) +class GEGLU_ColParallel(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output.
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + # self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) + self.proj = ColParallelLinear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + # import pdb;pdb.set_trace() + return hidden_states * self.gelu(gate) + class ApproximateGELU(nn.Module): """ The approximate form of Gaussian Error Linear Unit (GELU) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 43497c2284ac..8b70f99059f2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,6 +21,7 @@ from ..utils.import_utils import is_xformers_available from .lora import LoRALinearLayer +from .tp_utils import ColParallelLinear,RowParallelLinear, gather_from_sequence_parallel_region logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,6 +72,7 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, + sequence_parallel = False, ): super().__init__() inner_dim = dim_head * heads @@ -96,6 +98,7 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention + self.sequence_parallel = sequence_parallel if self.added_kv_proj_dim is None and self.only_cross_attention: raise ValueError( @@ -135,12 +138,12 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - + # self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_q = ColParallelLinear(query_dim, inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_k = ColParallelLinear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = ColParallelLinear(cross_attention_dim, inner_dim, bias=bias) else: self.to_k = None self.to_v = None @@ -150,7 +153,7 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) + self.to_out.append(RowParallelLinear(inner_dim, query_dim, bias=out_bias, sequence_parallel=sequence_parallel)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -316,6 +319,10 @@ def set_processor(self, processor: "AttnProcessor"): self.processor = processor def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + if self.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = gather_from_sequence_parallel_region(encoder_hidden_states) # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty diff --git a/src/diffusers/models/tp_utils.py b/src/diffusers/models/tp_utils.py new file mode 100644 index 
000000000000..e18e023aba8f --- /dev/null +++ b/src/diffusers/models/tp_utils.py @@ -0,0 +1,245 @@ +import torch +from torch import nn as nn +from torch.nn.parameter import Parameter + +import torch.distributed as dist + +TP_GROUP=None +def get_tp_group(): + global TP_GROUP + return TP_GROUP + +def set_tp_group(group): + global TP_GROUP + TP_GROUP=group + +def get_tensor_model_parallel_world_size(): + return dist.get_world_size(get_tp_group()) + +def maybe_gather_from_sequence_parallel(inp): + if hasattr(inp, "sequence_parallel") and inp.sequence_parallel: + return gather_from_sequence_parallel_region(inp) + return inp + +def maybe_split_into_sequence_parallel(inp): + if not hasattr(inp, "sequence_parallel"): + return _split_along_first_dim(inp) + return inp + +def set_sequence_parallel_attr(inp, value=True): + setattr(inp, "sequence_parallel", value) + return inp + +def is_sequence_parallel_tensor(inp): + return hasattr(inp, "sequence_parallel") and inp.sequence_parallel==True + + +# the following Reduce and Gather functions are adapted from https://github.com/NVIDIA/Megatron-LM +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def forward(ctx, input_): + torch.distributed.all_reduce(input_, group=get_tp_group()) + return input_ + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert dim_size[0] % world_size == 0, \ "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._reduce_scatter_base(output, input_.contiguous(), + group=get_tp_group()) + return output + +def _gather_along_first_dim(input_): + """Gather tensors and concatenate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base(output, input_.contiguous(), + group=get_tp_group()) + + return output +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension.
+ dim_size = input_.size()[0] + assert dim_size % world_size == 0, \ "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = torch.distributed.get_rank(get_tp_group()) + dim_offset = rank * local_dim_size + + output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + + set_sequence_parallel_attr(output, True) + return output + +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from the sequence parallel region and concatenate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in tensor parallel mode, output gradients need to be + # reduce-scattered, whereas if the computation is duplicated, + # output gradients only need to be split. + if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + output = _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + set_sequence_parallel_attr(output, False) + return output + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) + + +class TpLinear(nn.Module): + def __init__(self, fin, fout, bias=True): + super(TpLinear, self).__init__() + self.weight = Parameter(torch.rand((fin, fout))) + self.bias =None + if bias: + self.bias = Parameter(torch.zeros(fout)) + + def forward(self, x): + out = torch.matmul(x, self.weight) + if self.bias is not None: + out +=self.bias + return out + +class ColParallelLinear(nn.Module): + def __init__(self, fin, fout, bias=True): + super(ColParallelLinear, self).__init__() + tp_group=get_tp_group() + self.tp_world_size = torch.distributed.get_world_size(tp_group) + assert fout%self.tp_world_size==0 + self.fout = int(fout/self.tp_world_size) + + self.linear = TpLinear(fin, self.fout, bias) + + def forward(self, x): + """ + 1. the weight is split along fout, so each rank holds a slice of the output features + 2. do the linear compute + 3.
the output stays split; no communication is needed """ + out = self.linear(x) + return out + + def init_weight_from_full(self, fullwt): + cur_rank = torch.distributed.get_rank(get_tp_group()) + start_ind = cur_rank*self.fout + end_ind = (cur_rank+1)*self.fout + slice = fullwt[:,start_ind:end_ind] + with torch.no_grad(): + self.linear.weight.copy_(slice) + + def init_weight_from_full_attn(self, fullwt): + cur_rank = torch.distributed.get_rank(get_tp_group()) + ws = torch.distributed.get_world_size(get_tp_group()) + dim=fullwt.shape[0] + dim3=fullwt.shape[1] + fullwts = fullwt.split(dim3//3, dim=-1) # (q,k,v) + splits = [] + for wt in fullwts: + splits.append(wt.split(wt.shape[-1]//ws, dim=-1)[cur_rank]) + + cat_full = torch.cat(splits, dim=-1) + + with torch.no_grad(): + self.linear.weight.copy_(cat_full) + +class RowParallelLinear(nn.Module): + def __init__(self, fin, fout, bias=True, sequence_parallel=False): + super(RowParallelLinear, self).__init__() + tp_group=get_tp_group() + self.tp_world_size = torch.distributed.get_world_size(tp_group) + assert fin%self.tp_world_size==0 + self.fin = int(fin/self.tp_world_size) + self.linear = TpLinear(self.fin, fout, bias) + self.sequence_parallel = sequence_parallel + + + def forward(self, x): + """ + 1. the weight is split along fin, so each rank holds a slice of the input features + 2. do the linear compute + 3. the partial outputs are all-reduced (reduce-scattered when sequence_parallel) + """ + out = self.linear(x) + if not self.sequence_parallel: + out = _ReduceFromModelParallelRegion.apply(out) + else: + out = reduce_scatter_to_sequence_parallel_region(out) + return out + + def init_weight_from_full(self, fullwt): + cur_rank = torch.distributed.get_rank(get_tp_group()) + start_ind = cur_rank*self.fin + end_ind = (cur_rank+1)*self.fin + slice = fullwt[start_ind:end_ind] + with torch.no_grad(): + self.linear.weight.copy_(slice) \ No newline at end of file diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 344a9441ced1..4633b069812e 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -26,7 +26,6 @@ from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin - @dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -91,6 +90,7 @@ def __init__( upcast_attention: bool = False, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, + sequence_parallel=False, ): super().__init__() self.use_linear_projection = use_linear_projection @@ -183,6 +183,7 @@ def __init__( upcast_attention=upcast_attention, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, + sequence_parallel=sequence_parallel ) for d in range(num_layers) ] @@ -313,7 +314,6 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) - # 3.
Output if self.is_input_continuous: if not self.use_linear_projection: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index e894628462ef..6789762fa01e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -26,7 +26,6 @@ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -54,6 +53,7 @@ def get_down_block( cross_attention_norm=None, attention_head_dim=None, downsample_type=None, + sequence_parallel=True, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -129,6 +129,7 @@ def get_down_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + sequence_parallel=sequence_parallel ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: @@ -249,6 +250,7 @@ def get_up_block( cross_attention_norm=None, attention_head_dim=None, upsample_type=None, + sequence_parallel=True, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: @@ -307,6 +309,7 @@ def get_up_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + sequence_parallel=sequence_parallel ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: @@ -556,6 +559,7 @@ def __init__( dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, + sequence_parallel=True, ): super().__init__() @@ -592,6 +596,7 @@ def __init__( norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel ) ) else: @@ -624,6 +629,7 @@ def __init__( self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False + self.sequence_parallel = sequence_parallel def forward( self, @@ -673,6 +679,8 @@ def custom_forward(*inputs): )[0] hidden_states = resnet(hidden_states, temb) + # if self.sequence_parallel: + # hidden_states = set_sequence_parallel_attr(hidden_states) return hidden_states @@ -934,6 +942,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + sequence_parallel=True, ): super().__init__() resnets = [] @@ -970,6 +979,7 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel, ) ) else: @@ -998,6 +1008,7 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self.sequence_parallel = sequence_parallel def forward( self, @@ -1042,6 +1053,7 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet(hidden_states, temb) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1051,6 +1063,9 @@ def custom_forward(*inputs): return_dict=False, )[0] + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) + # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals @@ -1061,6 +1076,8 @@ def custom_forward(*inputs): for downsampler in self.downsamplers: hidden_states = 
downsampler(hidden_states) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1082,6 +1099,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + sequence_parallel=True, ): super().__init__() resnets = [] @@ -1117,10 +1135,10 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False + self.sequence_parallel = sequence_parallel def forward(self, hidden_states, temb=None): output_states = () - for resnet in self.resnets: if self.training and self.gradient_checkpointing: @@ -1140,13 +1158,16 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -2068,6 +2089,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + sequence_parallel=False ): super().__init__() resnets = [] @@ -2106,6 +2128,7 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel ) ) else: @@ -2139,6 +2162,7 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, + sequence_parallel=False, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -2206,6 +2230,7 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + sequence_parallel=True, ): super().__init__() resnets = [] @@ -2237,8 +2262,9 @@ def __init__( self.upsamplers = None self.gradient_checkpointing = False + # self.sequence_parallel = sequence_parallel - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, sequence_parallel=False): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2263,6 +2289,8 @@ def custom_forward(*inputs): ) else: hidden_states = resnet(hidden_states, temb) + # if self.sequence_parallel: + # set_sequence_parallel_attr(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index fea1b4cd7823..480ddb6b456a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -41,7 +41,7 @@ get_down_block, get_up_block, ) - +from .tp_utils import gather_from_sequence_parallel_region, _split_along_first_dim logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -202,6 +202,7 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, + sequence_parallel: bool = True, ): super().__init__() @@ -450,6 +451,7 @@ def __init__( resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] 
is not None else output_channel, + sequence_parallel=sequence_parallel ) self.down_blocks.append(down_block) @@ -469,6 +471,7 @@ def __init__( dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, + sequence_parallel=sequence_parallel ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( @@ -539,6 +542,7 @@ def __init__( resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + sequence_parallel=sequence_parallel ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -560,6 +564,8 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + self.sequence_parallel=sequence_parallel + @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" @@ -893,8 +899,11 @@ def forward( image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process + if self.sequence_parallel: + sample = _split_along_first_dim(sample) + emb = _split_along_first_dim(emb) + encoder_hidden_states = _split_along_first_dim(encoder_hidden_states) sample = self.conv_in(sample) - # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None @@ -946,7 +955,6 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) - if is_controlnet: sample = sample + mid_block_additional_residual @@ -984,6 +992,9 @@ def forward( sample = self.conv_act(sample) sample = self.conv_out(sample) + if self.sequence_parallel: + sample = gather_from_sequence_parallel_region(sample) + if not return_dict: return (sample,)
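Since the patch itself carries no usage notes, here is a minimal sanity-check sketch (not part of the diff) of how the new layers compose: a ColParallelLinear feeding a RowParallelLinear should match a dense two-matmul reference once the tensor-parallel group is set. It assumes this branch is installed so diffusers.models.tp_utils is importable, two GPUs with an NCCL backend, a launch via torchrun --nproc_per_node=2, and that the whole world is used as the TP group; the script name, tensor sizes, and use of dist.group.WORLD are illustrative assumptions, not part of the PR.

# tp_sanity_check.py  (hypothetical file name; run with: torchrun --nproc_per_node=2 tp_sanity_check.py)
import os

import torch
import torch.distributed as dist

from diffusers.models.tp_utils import ColParallelLinear, RowParallelLinear, set_tp_group


def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    set_tp_group(dist.group.WORLD)  # assumption: all ranks form one tensor-parallel group

    torch.manual_seed(0)  # identical "full" weights and inputs on every rank
    fin, fhid, bs, seq = 64, 256, 4, 16  # illustrative sizes
    w1 = torch.rand(fin, fhid).cuda()    # (fin, fout) layout, matching TpLinear
    w2 = torch.rand(fhid, fin).cuda()
    x = torch.rand(bs, seq, fin).cuda()

    col = ColParallelLinear(fin, fhid, bias=False).cuda()  # keeps a column slice of w1
    row = RowParallelLinear(fhid, fin, bias=False).cuda()  # keeps a row slice of w2
    col.init_weight_from_full(w1)
    row.init_weight_from_full(w2)

    out_tp = row(col(x))   # partial products are all-reduced inside RowParallelLinear
    out_ref = x @ w1 @ w2  # dense single-device reference
    if dist.get_rank() == 0:
        print("max abs diff vs dense reference:", (out_tp - out_ref).abs().max().item())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With sequence_parallel=True the row-parallel output would instead come back reduce-scattered along the first dimension, which is why the UNet forward in this patch splits sample and encoder_hidden_states with _split_along_first_dim on entry and gathers the final sample with gather_from_sequence_parallel_region before returning.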