9 changes: 4 additions & 5 deletions models/controlnet_union.py
@@ -19,7 +19,7 @@
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.loaders import FromOriginalModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -30,13 +30,13 @@
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -149,7 +149,7 @@ def forward(self, conditioning):
return embedding


class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.

@@ -955,4 +955,3 @@ def zero_module(module):




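The changes to `models/controlnet_union.py` track the module reorganization in recent diffusers releases: `unet_2d_blocks` and `unet_2d_condition` now live under `diffusers.models.unets`, and `FromOriginalControlNetMixin` was superseded by `FromOriginalModelMixin`. If the model also needs to import against an older diffusers install, a fallback import is one option; this is a minimal sketch, not part of the PR, and the version boundary named in the comment is an assumption.

```python
try:
    # Reorganized layout used by recent diffusers releases (roughly >= 0.25; an assumption).
    from diffusers.loaders import FromOriginalModelMixin
    from diffusers.models.unets.unet_2d_blocks import get_down_block
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
except ImportError:
    # Older diffusers: pre-reorganization module paths and the old mixin name.
    from diffusers.loaders import FromOriginalControlNetMixin as FromOriginalModelMixin
    from diffusers.models.unet_2d_blocks import get_down_block
    from diffusers.models.unet_2d_condition import UNet2DConditionModel
```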
25 changes: 22 additions & 3 deletions pipeline/pipeline_controlnet_union_sd_xl.py
@@ -27,7 +27,8 @@

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from diffusers.models import AutoencoderKL, ControlNetModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from models.controlnet_union import ControlNetModel_Union
from diffusers.models.attention_processor import (
AttnProcessor2_0,
@@ -747,6 +748,7 @@ def upcast_vae(self):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
union_control_type: torch.Tensor,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image_list: PipelineImageInput = None,
@@ -779,8 +781,8 @@ def __call__(
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
union_control = False,
union_control_type = None,
denoising_end: Optional[float] = None,
**kwargs
):
r"""
The call function to the pipeline for generation.
@@ -1003,6 +1005,7 @@ def __call__(


# 5. Prepare timesteps

self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

@@ -1063,6 +1066,22 @@
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

# 8.1 Apply denoising_end
if (
denoising_end is not None
and isinstance(denoising_end, float)
and denoising_end > 0
and denoising_end < 1
):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
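With these changes, `__call__` takes `union_control_type` as its first positional argument and honors `denoising_end`: the schedule is cut at `num_train_timesteps - denoising_end * num_train_timesteps`, so with 1000 training timesteps and `denoising_end=0.8` only timesteps >= 200 are kept, i.e. the first 80% of the inference schedule. A hedged call sketch follows; the pipeline class name, checkpoint ids, and the control-type slot layout are assumptions from the surrounding repo, not from this diff.

```python
import torch
from PIL import Image

from models.controlnet_union import ControlNetModel_Union
# Assumed class name; only the module path appears in this diff.
from pipeline.pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline

controlnet = ControlNetModel_Union.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16  # assumed checkpoint id
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

depth_map = Image.open("depth.png").convert("RGB")  # preprocessed depth condition (placeholder path)

# One slot per supported condition; which index maps to which condition depends
# on the checkpoint, so treat this particular layout as an assumption.
union_control_type = torch.Tensor([0, 0, 0, 0, 1, 0, 0, 0])

image = pipe(
    union_control_type,                            # now the first positional argument
    prompt="a house in the forest, best quality",
    image_list=[0, 0, 0, 0, depth_map, 0, 0, 0],   # zeros mark unused slots
    denoising_end=0.8,                             # stop after the first 80% of the schedule
).images[0]
```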
42 changes: 31 additions & 11 deletions pipeline/pipeline_controlnet_union_sd_xl_img2img.py
@@ -32,11 +32,14 @@

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import ImageProjection, UNet2DConditionModel
from diffusers.models import AutoencoderKL, ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from models.controlnet_union import ControlNetModel_Union
from diffusers.models.attention_processor import (
AttnProcessor2_0,
@@ -163,6 +166,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
IPAdapterMixin,
FromSingleFileMixin,
):
r"""
Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -236,7 +240,7 @@ def __init__(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
controlnet: ControlNetModel_Union | None,
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
@@ -1071,10 +1075,11 @@ def num_timesteps(self):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
union_control_type: torch.Tensor,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
control_image_list: PipelineImageInput = None,
image_list: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
@@ -1110,8 +1115,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
union_control = False,
union_control_type = None,
denoising_end: Optional[float] = None,
**kwargs,
):
r"""
@@ -1308,7 +1312,7 @@ def __call__(
)

# 1. Check inputs. Raise error if not correct
for control_image in control_image_list:
for control_image in image_list:
if control_image:
self.check_inputs(
prompt,
@@ -1393,10 +1397,10 @@
# 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

for idx in range(len(control_image_list)):
if control_image_list[idx]:
for idx in range(len(image_list)):
if image_list[idx]:
control_image = self.prepare_control_image(
image=control_image_list[idx],
image=image_list[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
@@ -1407,7 +1411,7 @@
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[idx] = control_image
image_list[idx] = control_image


# 5. Prepare timesteps
@@ -1482,6 +1486,22 @@
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)

# 8.1 Apply denoising_end
if (
denoising_end is not None
and isinstance(denoising_end, float)
and denoising_end > 0
and denoising_end < 1
):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1520,7 +1540,7 @@
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond_list=control_image_list,
controlnet_cond_list=image_list,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
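The img2img pipeline gets the same treatment: `union_control_type` moves from a trailing keyword to the first positional argument, the `union_control` flag is dropped, control images now arrive via `image_list` (previously `control_image_list`), `denoising_end` is applied with the same cutoff logic, and the class gains `FromSingleFileMixin`. A hedged usage sketch, under the same checkpoint and slot-layout assumptions as above; only the class and argument names come from this diff.

```python
import torch
from PIL import Image

from models.controlnet_union import ControlNetModel_Union
from pipeline.pipeline_controlnet_union_sd_xl_img2img import (
    StableDiffusionXLControlNetUnionImg2ImgPipeline,
)

controlnet = ControlNetModel_Union.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16  # assumed checkpoint id
)
pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("input.png").convert("RGB")    # img2img source image (placeholder path)
canny_edges = Image.open("canny.png").convert("RGB")   # preprocessed control image (assumed slot)

result = pipe(
    torch.Tensor([1, 0, 0, 0, 0, 0, 0, 0]),        # union_control_type, first positional argument
    prompt="a watercolor painting",
    image=init_image,
    image_list=[canny_edges, 0, 0, 0, 0, 0, 0, 0],  # zeros mark unused slots
    strength=0.6,
    denoising_end=0.9,
).images[0]
```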