diff --git a/models/controlnet_union.py b/models/controlnet_union.py
index 85ede8e..7339706 100644
--- a/models/controlnet_union.py
+++ b/models/controlnet_union.py
@@ -19,7 +19,7 @@
 from torch.nn import functional as F
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import FromOriginalControlNetMixin
+from diffusers.loaders import FromOriginalModelMixin
 from diffusers.utils import BaseOutput, logging
 from diffusers.models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -30,13 +30,13 @@
 )
 from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.unet_2d_blocks import (
+from diffusers.models.unets.unet_2d_blocks import (
     CrossAttnDownBlock2D,
     DownBlock2D,
     UNetMidBlock2DCrossAttn,
     get_down_block,
 )
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -149,7 +149,7 @@ def forward(self, conditioning):
         return embedding
 
 
-class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
+class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A ControlNet model.
 
@@ -955,4 +955,3 @@ def zero_module(module):
 
 
 
-
diff --git a/pipeline/pipeline_controlnet_union_sd_xl.py b/pipeline/pipeline_controlnet_union_sd_xl.py
index dcf99db..dbf79ac 100644
--- a/pipeline/pipeline_controlnet_union_sd_xl.py
+++ b/pipeline/pipeline_controlnet_union_sd_xl.py
@@ -27,7 +27,8 @@
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from diffusers.models import AutoencoderKL, ControlNetModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 from models.controlnet_union import ControlNetModel_Union
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
@@ -747,6 +748,7 @@ def upcast_vae(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
+        union_control_type: torch.Tensor,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
         image_list: PipelineImageInput = None,
@@ -779,8 +781,8 @@
         negative_original_size: Optional[Tuple[int, int]] = None,
         negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
         negative_target_size: Optional[Tuple[int, int]] = None,
-        union_control = False,
-        union_control_type = None,
+        denoising_end: Optional[float] = None,
+        **kwargs
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1003,6 +1005,7 @@
 
 
         # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
 
@@ -1063,6 +1066,22 @@
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
+        # 8.1 Apply denoising_end
+        if (
+            denoising_end is not None
+            and isinstance(denoising_end, float)
+            and denoising_end > 0
+            and denoising_end < 1
+        ):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
diff --git a/pipeline/pipeline_controlnet_union_sd_xl_img2img.py b/pipeline/pipeline_controlnet_union_sd_xl_img2img.py
index a46749f..698b639 100644
--- a/pipeline/pipeline_controlnet_union_sd_xl_img2img.py
+++ b/pipeline/pipeline_controlnet_union_sd_xl_img2img.py
@@ -32,11 +32,14 @@
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
+    FromSingleFileMixin,
     IPAdapterMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import ImageProjection, UNet2DConditionModel
+from diffusers.models import AutoencoderKL, ControlNetModel
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 
 from models.controlnet_union import ControlNetModel_Union
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
@@ -163,6 +166,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
     TextualInversionLoaderMixin,
     StableDiffusionXLLoraLoaderMixin,
     IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -236,7 +240,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+        controlnet: Optional[ControlNetModel_Union],
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
@@ -1071,10 +1075,11 @@ def num_timesteps(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
+        union_control_type: torch.Tensor,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
-        control_image_list: PipelineImageInput = None,
+        image_list: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 0.8,
@@ -1110,8 +1115,7 @@
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        union_control = False,
-        union_control_type = None,
+        denoising_end: Optional[float] = None,
         **kwargs,
     ):
         r"""
@@ -1308,7 +1312,7 @@
         )
 
         # 1. Check inputs. Raise error if not correct
-        for control_image in control_image_list:
+        for control_image in image_list:
             if control_image:
                 self.check_inputs(
                     prompt,
@@ -1393,10 +1397,10 @@
 
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
-        for idx in range(len(control_image_list)):
-            if control_image_list[idx]:
+        for idx in range(len(image_list)):
+            if image_list[idx]:
                 control_image = self.prepare_control_image(
-                    image=control_image_list[idx],
+                    image=image_list[idx],
                     width=width,
                     height=height,
                     batch_size=batch_size * num_images_per_prompt,
@@ -1407,7 +1411,7 @@
                     guess_mode=guess_mode,
                 )
                 height, width = control_image.shape[-2:]
-                control_image_list[idx] = control_image
+                image_list[idx] = control_image
 
         # 5. Prepare timesteps
 
@@ -1482,6 +1486,22 @@
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device)
 
+        # 8.1 Apply denoising_end
+        if (
+            denoising_end is not None
+            and isinstance(denoising_end, float)
+            and denoising_end > 0
+            and denoising_end < 1
+        ):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1520,7 +1540,7 @@
                     control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
-                    controlnet_cond_list=control_image_list,
+                    controlnet_cond_list=image_list,
                     conditioning_scale=cond_scale,
                     guess_mode=guess_mode,
                     added_cond_kwargs=controlnet_added_cond_kwargs,
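
Note on the denoising_end blocks added to both __call__ bodies: they mirror the base/refiner handoff in the stock SDXL pipelines, converting the fraction into a discrete timestep cutoff and truncating the schedule so a refiner can pick up the remaining steps (the refiner side uses denoising_start). A standalone sketch of the arithmetic, with all values assumed for illustration:

    # Standalone illustration of the denoising_end cutoff (all values assumed).
    num_train_timesteps = 1000      # stands in for scheduler.config.num_train_timesteps
    denoising_end = 0.8             # keep only the first 80% of the schedule
    cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))  # 200

    timesteps = list(range(999, -1, -50))         # toy 20-step schedule: 999, 949, ..., 49
    kept = [t for t in timesteps if t >= cutoff]  # same filter as in the diff
    print(len(kept), kept[-1])                    # -> "16 249": 16 of the 20 steps survive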
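
For reviewers, a minimal usage sketch of the reworked txt2img entry point. Everything below is an assumption for illustration, not part of the diff: the pipeline class name, the checkpoint IDs, the conditioning image file, and the six-slot control layout (slot order differs between checkpoints).

    import torch
    from diffusers.utils import load_image
    from models.controlnet_union import ControlNetModel_Union
    from pipeline.pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline

    controlnet = ControlNetModel_Union.from_pretrained(
        "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        torch_dtype=torch.float16,
    ).to("cuda")

    depth_image = load_image("control_depth.png")  # assumed conditioning input

    # union_control_type is now the first positional argument: a tensor with a 1
    # in each active control slot (slot 3 is assumed to be depth here). image_list
    # carries one entry per slot; inactive slots stay 0, which is what the
    # `if image_list[idx]` checks in the pipelines key off.
    image = pipe(
        torch.Tensor([0, 0, 0, 1, 0, 0]),
        prompt="a photo of a cat",
        image_list=[0, 0, 0, depth_image, 0, 0],
        num_inference_steps=30,
        denoising_end=0.8,  # new keyword: stop after the first 80% of the schedule
    ).images[0]

Since union_control was dropped and union_control_type is now required, old call sites keep working only if they already passed union_control_type as a keyword; the added **kwargs silently swallows the removed union_control flag.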