From 73d1caf8f28a387f2db5a77a8892edad8ed505a0 Mon Sep 17 00:00:00 2001 From: Logan Date: Fri, 10 May 2024 12:38:10 +1000 Subject: [PATCH 01/17] Add Align Your Steps to available schedulers * Include both SDXL and SD 1.5 variants (https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html) --- modules/sd_schedulers.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py index 75eb3ac032f..2131eae46cc 100644 --- a/modules/sd_schedulers.py +++ b/modules/sd_schedulers.py @@ -4,6 +4,7 @@ import k_diffusion +import numpy as np @dataclasses.dataclass class Scheduler: @@ -30,6 +31,35 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): sigs += [0.0] return torch.FloatTensor(sigs).to(device) +def get_align_your_steps_sigmas(n, device, sigma_id): + # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html + def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + + if sigma_id == "sdxl": + sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029] + elif sigma_id == "sd15": + sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] + else: + print(f'Align Your Steps sigma identifier "{sigma_id}" not recognized, defaulting to SD 1.5.') + sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] + + if n != len(sigmas): + sigmas = np.append(loglinear_interp(sigmas, n), [0.0]) + else: + sigmas.append(0.0) + + return torch.FloatTensor(sigmas).to(device) schedulers = [ Scheduler('automatic', 'Automatic', None), @@ -38,6 +68,8 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential), Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0), Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]), + Scheduler('align_your_steps_sdxl', 'Align Your Steps (SDXL)', lambda n, sigma_min, sigma_max, device: get_align_your_steps_sigmas(n, device, "sdxl")), + Scheduler('align_your_steps_sd15', 'Align Your Steps (SD 1.5)', lambda n, sigma_min, sigma_max, device: get_align_your_steps_sigmas(n, device, "sd15")), ] schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}} From d6b4444069d36cf7554eb9932061ecf43e9b1335 Mon Sep 17 00:00:00 2001 From: Logan Date: Fri, 10 May 2024 18:05:45 +1000 Subject: [PATCH 02/17] Use shared.sd_model.is_sdxl to determine base AYS sigmas --- modules/sd_schedulers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py index 2131eae46cc..0ac1f7a21f8 100644 --- a/modules/sd_schedulers.py +++ b/modules/sd_schedulers.py @@ -6,6 +6,8 @@ import numpy as np +from modules import shared + @dataclasses.dataclass class Scheduler: name: str @@ -31,7 +33,7 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): sigs += [0.0] return torch.FloatTensor(sigs).to(device) -def get_align_your_steps_sigmas(n, device, sigma_id): +def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device): # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html def loglinear_interp(t_steps, num_steps): """ @@ -46,12 +48,10 @@ def loglinear_interp(t_steps, num_steps): interped_ys = np.exp(new_ys)[::-1].copy() return interped_ys - if sigma_id == "sdxl": + if shared.sd_model.is_sdxl: sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029] - elif sigma_id == "sd15": - sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] else: - print(f'Align Your Steps sigma identifier "{sigma_id}" not recognized, defaulting to SD 1.5.') + # Default to SD 1.5 sigmas. sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] if n != len(sigmas): @@ -68,8 +68,7 @@ def loglinear_interp(t_steps, num_steps): Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential), Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0), Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]), - Scheduler('align_your_steps_sdxl', 'Align Your Steps (SDXL)', lambda n, sigma_min, sigma_max, device: get_align_your_steps_sigmas(n, device, "sdxl")), - Scheduler('align_your_steps_sd15', 'Align Your Steps (SD 1.5)', lambda n, sigma_min, sigma_max, device: get_align_your_steps_sigmas(n, device, "sd15")), + Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas), ] schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}} From 5a5ac686ed8898fe8a7b477f9cb892bbaf1baf97 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 15 May 2024 15:38:53 -0400 Subject: [PATCH 03/17] use_checkpoint = False (#1) --- configs/alt-diffusion-inference.yaml | 2 +- configs/alt-diffusion-m18-inference.yaml | 2 +- configs/instruct-pix2pix.yaml | 2 +- configs/sd_xl_inpaint.yaml | 2 +- configs/v1-inference.yaml | 2 +- configs/v1-inpainting-inference.yaml | 2 +- modules/sd_hijack_checkpoint.py | 9 ++++++--- modules/sd_models_config.py | 2 +- 8 files changed, 13 insertions(+), 10 deletions(-) diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml index cfbee72d71b..4944ab5c8dc 100644 --- a/configs/alt-diffusion-inference.yaml +++ b/configs/alt-diffusion-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml index 41a031d55f0..c60dca8c7b3 100644 --- a/configs/alt-diffusion-m18-inference.yaml +++ b/configs/alt-diffusion-m18-inference.yaml @@ -41,7 +41,7 @@ model: use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml index 4e896879dd7..564e50ae246 100644 --- a/configs/instruct-pix2pix.yaml +++ b/configs/instruct-pix2pix.yaml @@ -45,7 +45,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml index 3bad372186f..f40f45e3316 100644 --- a/configs/sd_xl_inpaint.yaml +++ b/configs/sd_xl_inpaint.yaml @@ -21,7 +21,7 @@ model: params: adm_in_channels: 2816 num_classes: sequential - use_checkpoint: True + use_checkpoint: False in_channels: 9 out_channels: 4 model_channels: 320 diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml index d4effe569e8..25c4d9ed066 100644 --- a/configs/v1-inference.yaml +++ b/configs/v1-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml index f9eec37d24b..68c199f99c3 100644 --- a/configs/v1-inpainting-inference.yaml +++ b/configs/v1-inpainting-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 2604d969f91..b2f05bbdcf0 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -4,16 +4,19 @@ import ldm.modules.diffusionmodules.openaimodel +# Setting flag=False so that torch skips checking parameters. +# parameters checking is expensive in frequent operations. + def BasicTransformerBlock_forward(self, x, context=None): - return checkpoint(self._forward, x, context) + return checkpoint(self._forward, x, context, flag=False) def AttentionBlock_forward(self, x): - return checkpoint(self._forward, x) + return checkpoint(self._forward, x, flag=False) def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb) + return checkpoint(self._forward, x, emb, flag=False) stored = [] diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index b38137eb5a9..9cec4f13dc2 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -35,7 +35,7 @@ def is_using_v_parameterization_for_sd2(state_dict): with sd_disable_initialization.DisableInitialization(): unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( - use_checkpoint=True, + use_checkpoint=False, use_fp16=False, image_size=32, in_channels=4, From aaa8a996dee2a249bca798731233869630958973 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 15 May 2024 15:54:27 -0400 Subject: [PATCH 04/17] Replace einops.rearrange with torch native (#2) --- modules/sd_hijack_optimizations.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7f9e328d05a..4c2dc56d45d 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -486,7 +486,19 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in)) + def _reshape(t): + """rearrange(t, 'b n (h d) -> b n h d', h=h). + Using torch native operations to avoid overhead as this function is + called frequently. (70 times/it for SDXL) + """ + b, n, _ = t.shape # Get the batch size (b) and sequence length (n) + d = t.shape[2] // h # Determine the depth per head + return t.reshape(b, n, h, d) + + q = _reshape(q_in) + k = _reshape(k_in) + v = _reshape(v_in) + del q_in, k_in, v_in dtype = q.dtype @@ -497,7 +509,9 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): out = out.to(dtype) - out = rearrange(out, 'b n h d -> b n (h d)', h=h) + # out = rearrange(out, 'b n h d -> b n (h d)', h=h) + b, n, h, d = out.shape + out = out.reshape(b, n, h * d) return self.to_out(out) From fbeef19152be31867154ac3a199e4cb90e94ca54 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 15 May 2024 16:13:15 -0400 Subject: [PATCH 05/17] Disable nan check by default (#3) --- modules/cmd_args.py | 3 ++- modules/devices.py | 4 +--- modules/launch_utils.py | 8 ++++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 016a33d1057..26335903daa 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -69,7 +69,8 @@ parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*") parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") -parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") +parser.add_argument("--disable-nan-check", action='store_true', help="[Deprecated] do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") +parser.add_argument("--enable-nan-check", action='store_true', help="Check if produced images/latent spaces have nans at extra performance cost. (~20ms/it)") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") diff --git a/modules/devices.py b/modules/devices.py index e4f671ac659..096918ca4dc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -230,7 +230,7 @@ class NansException(Exception): def test_for_nans(x, where): - if shared.cmd_opts.disable_nan_check: + if not shared.cmd_opts.enable_nan_check: return if not torch.all(torch.isnan(x)).item(): @@ -250,8 +250,6 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." - message += " Use --disable-nan-check commandline argument to disable this check." - raise NansException(message) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 5812b0e5855..ddc411076ba 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -440,6 +440,10 @@ def prepare_environment(): git_pull_recursive(extensions_dir) startup_timer.record("update extensions") + if args.disable_nan_check: + print("Nan check disabled by default. --disable-nan-check argument is now ignored. " + "Use --enable-nan-check to re-enable nan check.") + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) @@ -454,8 +458,8 @@ def configure_for_tests(): sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt")) if "--skip-torch-cuda-test" not in sys.argv: sys.argv.append("--skip-torch-cuda-test") - if "--disable-nan-check" not in sys.argv: - sys.argv.append("--disable-nan-check") + if "--enable-nan-check" in sys.argv: + sys.argv.remove("--enable-nan-check") os.environ['COMMANDLINE_ARGS'] = "" From b7b2bdc90b659c743e241704164a25e96eb99b74 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 15 May 2024 16:37:47 -0400 Subject: [PATCH 06/17] Precompute is_sdxl_inpaint flag (#4) --- modules/processing.py | 28 +++++++++++----------------- modules/sd_models.py | 7 +++++++ modules/sd_models_xl.py | 9 ++++----- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 76557dd7f5e..d82cb24fb95 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,20 +115,17 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - sd = sd_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if sd_model.model.is_sdxl_inpaint: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -390,11 +387,8 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - sd = self.sampler.model_wrap.inner_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if self.sampler.model_wrap.inner_model.model.is_sdxl_inpaint: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a668..62e74d27ae0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -380,6 +380,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() + # Set is_sdxl_inpaint flag. + diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None) + model.is_sdxl_inpaint = ( + model.is_sdxl and + diffusion_model_input is not None and + diffusion_model_input.shape[1] == 9 + ) if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 94ff973fb84..35e21f6e470 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -35,11 +35,10 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): - sd = self.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + """WARNING: This function is called once per denoising iteration. DO NOT add + expensive functionc calls such as `model.state_dict`. """ + if self.model.is_sdxl_inpaint: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From 2197ab89243121f39edea4712d8321b5adcc01ea Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 15 May 2024 16:39:32 -0400 Subject: [PATCH 07/17] Revert "Precompute is_sdxl_inpaint flag (#4)" (#5) This reverts commit b7b2bdc90b659c743e241704164a25e96eb99b74. --- modules/processing.py | 28 +++++++++++++++++----------- modules/sd_models.py | 7 ------- modules/sd_models_xl.py | 9 +++++---- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d82cb24fb95..76557dd7f5e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,17 +115,20 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - if sd_model.model.is_sdxl_inpaint: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -387,8 +390,11 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - if self.sampler.model_wrap.inner_model.model.is_sdxl_inpaint: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index 62e74d27ae0..ff245b7a668 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -380,13 +380,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() - # Set is_sdxl_inpaint flag. - diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None) - model.is_sdxl_inpaint = ( - model.is_sdxl and - diffusion_model_input is not None and - diffusion_model_input.shape[1] == 9 - ) if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 35e21f6e470..94ff973fb84 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -35,10 +35,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): - """WARNING: This function is called once per denoising iteration. DO NOT add - expensive functionc calls such as `model.state_dict`. """ - if self.model.is_sdxl_inpaint: - x = torch.cat([x] + cond['c_concat'], dim=1) + sd = self.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From f38bafd36aec356df86ea4246065bba0079351a7 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 15 May 2024 17:01:24 -0400 Subject: [PATCH 08/17] Inpaint fix (#6) * Precompute is_sdxl_inpaint flag * Fix flag check for SD15 --- modules/processing.py | 28 +++++++++++----------------- modules/sd_models.py | 7 +++++++ modules/sd_models_xl.py | 9 ++++----- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 76557dd7f5e..fff2595e70a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -115,20 +115,17 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: - sd = sd_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if getattr(sd_model.model, "is_sdxl_inpaint", False): + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -390,11 +387,8 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) - sd = self.sampler.model_wrap.inner_model.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False): + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a668..62e74d27ae0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -380,6 +380,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() + # Set is_sdxl_inpaint flag. + diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None) + model.is_sdxl_inpaint = ( + model.is_sdxl and + diffusion_model_input is not None and + diffusion_model_input.shape[1] == 9 + ) if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 94ff973fb84..35e21f6e470 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -35,11 +35,10 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): - sd = self.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + """WARNING: This function is called once per denoising iteration. DO NOT add + expensive functionc calls such as `model.state_dict`. """ + if self.model.is_sdxl_inpaint: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From 5b49881f1a3654cbc74984362689ba223e595da2 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 15 May 2024 17:29:52 -0400 Subject: [PATCH 09/17] Fix attr access --- modules/sd_models_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 35e21f6e470..1242a59369f 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -37,7 +37,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): """WARNING: This function is called once per denoising iteration. DO NOT add expensive functionc calls such as `model.state_dict`. """ - if self.model.is_sdxl_inpaint: + if self.is_sdxl_inpaint: x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From b66dfb55a8b2547e6be4fc408287eaccc2f931f9 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 16 May 2024 14:46:56 -0400 Subject: [PATCH 10/17] Bias backup (#7) * Prevent uncessary bias backup * Fix LoRA bias error --------- Co-authored-by: AUTOMATIC1111 <16777216c@gmail.com> --- extensions-builtin/Lora/networks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 42b14dc239d..aee4e9d9ca7 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -378,13 +378,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_weights_backup = weights_backup bias_backup = getattr(self, "network_bias_backup", None) - if bias_backup is None: + if bias_backup is None and wanted_names != (): if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: bias_backup = self.out_proj.bias.to(devices.cpu, copy=True) elif getattr(self, 'bias', None) is not None: bias_backup = self.bias.to(devices.cpu, copy=True) else: bias_backup = None + + # Unlike weight which always has value, some modules don't have bias. + # Only report if bias is not None and current bias are not unchanged. + if bias_backup is not None and current_names != (): + raise RuntimeError("no backup bias found and current bias are not unchanged") self.network_bias_backup = bias_backup if current_names != wanted_names: From da731579808ca6d06781adcdc2e46cce7898bab4 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Thu, 16 May 2024 16:38:39 -0400 Subject: [PATCH 11/17] Fully prevent use_checkpoint --- modules/sd_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 62e74d27ae0..4b4684a60c1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -558,6 +558,11 @@ def repair_config(sd_config): karlo_path = os.path.join(paths.models_path, 'karlo') sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + # Do not use checkpoint for inference. + # This helps prevent extra performance overhead on checking parameters. + # The perf overhead is about 100ms/it on 4090. + sd_config.model.params.network_config.params.use_checkpoint = False + def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar_sqrt = alphas_cumprod.sqrt() From 0061b8900d939acd0c9287f8505267cc0d2728a3 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Thu, 16 May 2024 20:07:27 -0400 Subject: [PATCH 12/17] Fix SD15 --- modules/sd_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b4684a60c1..bb527f5eb91 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -560,8 +560,11 @@ def repair_config(sd_config): # Do not use checkpoint for inference. # This helps prevent extra performance overhead on checking parameters. - # The perf overhead is about 100ms/it on 4090. - sd_config.model.params.network_config.params.use_checkpoint = False + # The perf overhead is about 100ms/it on 4090 for SDXL. + if hasattr(sd_config.model.params, "network_config"): + sd_config.model.params.network_config.params.use_checkpoint = False + if hasattr(sd_config.model.params, "unet_config"): + sd_config.model.params.unet_config.params.use_checkpoint = False def rescale_zero_terminal_snr_abar(alphas_cumprod): From eff3d791ff562a86b673a226f2902283f49cd2ab Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 16 May 2024 20:15:57 -0400 Subject: [PATCH 13/17] Add --precision half cmd option (#8) Co-authored-by: AUTOMATIC1111 <16777216c@gmail.com> --- modules/cmd_args.py | 2 +- modules/devices.py | 24 ++++++++++++++++++++++++ modules/sd_hijack_unet.py | 29 ++++++++++++++++++++++------- modules/sd_hijack_utils.py | 26 +++++++++++++++----------- modules/sd_models.py | 1 + modules/shared_init.py | 8 ++++++++ 6 files changed, 71 insertions(+), 19 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 26335903daa..2b32d138665 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -41,7 +41,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") -parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) diff --git a/modules/devices.py b/modules/devices.py index 096918ca4dc..0e19ed2769a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -114,6 +114,9 @@ def enable_tf32(): cpu: torch.device = torch.device("cpu") fp8: bool = False +# Force fp16 for all models in inference. No casting during inference. +# This flag is controlled by "--precision half" command line arg. +force_fp16: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -127,6 +130,8 @@ def enable_tf32(): def cond_cast_unet(input): + if force_fp16: + return input.to(torch.float16) return input.to(dtype_unet) if unet_needs_upcast else input @@ -206,6 +211,11 @@ def autocast(disable=False): if disable: return contextlib.nullcontext() + if force_fp16: + # No casting during inference if force_fp16 is enabled. + # All tensor dtype conversion happens before inference. + return contextlib.nullcontext() + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) @@ -267,3 +277,17 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + + +def force_model_fp16(): + """ + ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which + force conversion of input to float32. If force_fp16 is enabled, we need to + prevent this casting. + """ + assert force_fp16 + import sgm.modules.diffusionmodules.util as sgm_util + import ldm.modules.diffusionmodules.util as ldm_util + sgm_util.GroupNorm32 = torch.nn.GroupNorm + ldm_util.GroupNorm32 = torch.nn.GroupNorm + print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.") diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 2101f1a0415..41955313a31 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -36,7 +36,7 @@ def cat(self, tensors, *args, **kwargs): # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - + """Always make sure inputs to unet are in correct dtype.""" if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): @@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + if devices.unet_needs_upcast: + return result.float() + else: + return result class GELUHijack(torch.nn.GELU, torch.nn.Module): @@ -64,12 +68,11 @@ def hijack_ddpm_edit(): if not ddpm_edit_hijack: CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) - ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast -CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) + if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) @@ -81,5 +84,17 @@ def hijack_ddpm_edit(): CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) -CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) + + +def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): + if devices.unet_needs_upcast and timesteps.dtype == torch.int64: + dtype = torch.float32 + else: + dtype = devices.dtype_unet + return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) + + +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index 79bf6e46862..546f2eda4ec 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -1,7 +1,11 @@ import importlib + +always_true_func = lambda *args, **kwargs: True + + class CondFunc: - def __new__(cls, orig_func, sub_func, cond_func): + def __new__(cls, orig_func, sub_func, cond_func=always_true_func): self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') @@ -20,13 +24,13 @@ def __new__(cls, orig_func, sub_func, cond_func): print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") pass self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index bb527f5eb91..f7089538722 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -410,6 +410,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.float() model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 + assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half" timer.record("apply float()") else: vae = model.first_stage_model diff --git a/modules/shared_init.py b/modules/shared_init.py index 935e3a21cf2..a6ad0433d6f 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -31,6 +31,14 @@ def initialize(): devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype + if cmd_opts.precision == "half": + msg = "--no-half and --no-half-vae conflict with --precision half" + assert devices.dtype == torch.float16, msg + assert devices.dtype_vae == torch.float16, msg + assert devices.dtype_inference == torch.float16, msg + devices.force_fp16 = True + devices.force_model_fp16() + shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu" From b3971adbeaf43b8f6f5b3a24b5869b325ce17785 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Fri, 17 May 2024 13:23:47 -0400 Subject: [PATCH 14/17] Fix SD15 dtype --- modules/sd_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index f7089538722..f51bf53f25b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -748,6 +748,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config + # ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite + # UnetModel.dtype, it will be the default dtype from config. + # sgm's Unet is not using dtype for casting. The value will be ignored. + sd_model.model.diffusion_model.dtype = devices.dtype_unet timer.record("create model") From 8e355f08b2e9f6b2e75581e5da752bb428c4bcdd Mon Sep 17 00:00:00 2001 From: huchenlei Date: Fri, 17 May 2024 13:33:19 -0400 Subject: [PATCH 15/17] Proper fix of SD15 dtype --- modules/sd_models.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index f51bf53f25b..aed58e19d36 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -548,7 +548,7 @@ def repair_config(sd_config): if hasattr(sd_config.model.params, 'unet_config'): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: + elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half": sd_config.model.params.unet_config.params.use_fp16 = True if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: @@ -748,10 +748,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config - # ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite - # UnetModel.dtype, it will be the default dtype from config. - # sgm's Unet is not using dtype for casting. The value will be ignored. - sd_model.model.diffusion_model.dtype = devices.dtype_unet timer.record("create model") From 1d7448281751ea3223c681a82de8219a6fbe1d22 Mon Sep 17 00:00:00 2001 From: Logan Date: Sat, 18 May 2024 09:09:57 +1000 Subject: [PATCH 16/17] Default device for sigma tensor to CPU * Consistent with implementations in k-diffusion. * Makes this compatible with https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15823 --- modules/sd_schedulers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py index 0ac1f7a21f8..4ddb778501a 100644 --- a/modules/sd_schedulers.py +++ b/modules/sd_schedulers.py @@ -33,7 +33,7 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): sigs += [0.0] return torch.FloatTensor(sigs).to(device) -def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device): +def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device='cpu'): # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html def loglinear_interp(t_steps, num_steps): """ From 6661c219f10798ac7d716a8b80c09021adae0b3f Mon Sep 17 00:00:00 2001 From: rem Date: Wed, 22 May 2024 16:31:39 -0500 Subject: [PATCH 17/17] xformers --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9e2ecfe4d67..d723775e99a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,6 @@ torch torchdiffeq torchsde transformers==4.30.2 -pillow-avif-plugin==1.4.3 \ No newline at end of file +pillow-avif-plugin==1.4.3 + +xformers \ No newline at end of file