diff --git a/build_launcher.py b/build_launcher.py index 4443888b20..00f0cec8e4 100644 --- a/build_launcher.py +++ b/build_launcher.py @@ -12,6 +12,7 @@ def build_launcher(): + """Builds the launcher scripts for Windows.""" if not is_win32_standalone_build: return diff --git a/launch.py b/launch.py index eae5b19eb8..a5c6cfce63 100644 --- a/launch.py +++ b/launch.py @@ -27,6 +27,7 @@ def prepare_environment(): + """Prepares the environment for running the application.""" torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") @@ -68,6 +69,10 @@ def prepare_environment(): def ini_args(): + """Initializes the arguments. + Returns: + The arguments. + """ from args_manager import args return args @@ -101,6 +106,17 @@ def ini_args(): def download_models(default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads, vae_downloads): + """Downloads the models. + Args: + default_model (str): The default model. + previous_default_models (list): A list of previous default models. + checkpoint_downloads (dict): A dictionary of checkpoint downloads. + embeddings_downloads (dict): A dictionary of embeddings downloads. + lora_downloads (dict): A dictionary of LoRA downloads. + vae_downloads (dict): A dictionary of VAE downloads. + Returns: + A tuple of the default model and the checkpoint downloads. + """ from modules.util import get_file_from_folder_list for file_name, url in vae_approx_filenames: diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py index 840d79a07b..bed340c01d 100644 --- a/ldm_patched/modules/model_management.py +++ b/ldm_patched/modules/model_management.py @@ -6,6 +6,7 @@ import sys class VRAMState(Enum): + """An enumeration of the available VRAM states.""" DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram LOW_VRAM = 2 @@ -14,6 +15,7 @@ class VRAMState(Enum): SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both. class CPUState(Enum): + """An enumeration of the available CPU states.""" GPU = 0 CPU = 1 MPS = 2 @@ -66,6 +68,10 @@ class CPUState(Enum): cpu_state = CPUState.CPU def is_intel_xpu(): + """Checks if an Intel XPU is available. + Returns: + bool: True if an Intel XPU is available, False otherwise. + """ global cpu_state global xpu_available if cpu_state == CPUState.GPU: @@ -74,6 +80,10 @@ def is_intel_xpu(): return False def get_torch_device(): + """Gets the torch device. + Returns: + torch.device: The torch device. + """ global directml_enabled global cpu_state if directml_enabled: @@ -90,6 +100,13 @@ def get_torch_device(): return torch.device(torch.cuda.current_device()) def get_total_memory(dev=None, torch_total_too=False): + """Gets the total memory of a device. + Args: + dev (torch.device, optional): The device. Defaults to None. + torch_total_too (bool, optional): Whether to also return the total memory as reported by torch. Defaults to False. + Returns: + int or tuple[int, int]: The total memory of the device, or a tuple of the total memory and the total memory as reported by torch. + """ global directml_enabled if dev is None: dev = get_torch_device() @@ -159,6 +176,10 @@ def get_total_memory(dev=None, torch_total_too=False): XFORMERS_IS_AVAILABLE = False def is_nvidia(): + """Checks if an NVIDIA GPU is available. + Returns: + bool: True if an NVIDIA GPU is available, False otherwise. + """ global cpu_state if cpu_state == CPUState.GPU: if torch.version.cuda: @@ -242,6 +263,12 @@ def is_nvidia(): print("Always offload VRAM") def get_torch_device_name(device): + """Gets the name of a torch device. + Args: + device (torch.device): The device. + Returns: + str: The name of the device. + """ if hasattr(device, 'type'): if device.type == "cuda": try: @@ -266,6 +293,12 @@ def get_torch_device_name(device): current_loaded_models = [] def module_size(module): + """Gets the size of a module in bytes. + Args: + module (torch.nn.Module): The module. + Returns: + int: The size of the module in bytes. + """ module_mem = 0 sd = module.state_dict() for k in sd: @@ -274,21 +307,42 @@ def module_size(module): return module_mem class LoadedModel: + """A class that represents a loaded model.""" def __init__(self, model): + """Initializes a LoadedModel. + Args: + model: The model. + """ self.model = model self.model_accelerated = False self.device = model.load_device def model_memory(self): + """Gets the memory usage of the model. + Returns: + int: The memory usage of the model in bytes. + """ return self.model.model_size() def model_memory_required(self, device): + """Gets the memory required to load the model on a device. + Args: + device (torch.device): The device. + Returns: + int: The memory required to load the model on the device in bytes. + """ if device == self.model.current_device: return 0 else: return self.model_memory() def model_load(self, lowvram_model_memory=0): + """Loads the model. + Args: + lowvram_model_memory (int, optional): The amount of memory to reserve for low VRAM mode. Defaults to 0. + Returns: + The loaded model. + """ patch_model_to = None if lowvram_model_memory == 0: patch_model_to = self.device @@ -327,6 +381,7 @@ def model_load(self, lowvram_model_memory=0): return self.real_model def model_unload(self): + """Unloads the model.""" if self.model_accelerated: for m in self.real_model.modules(): if hasattr(m, "prev_ldm_patched_cast_weights"): @@ -342,9 +397,17 @@ def __eq__(self, other): return self.model is other.model def minimum_inference_memory(): + """Gets the minimum memory required for inference. + Returns: + int: The minimum memory required for inference in bytes. + """ return (1024 * 1024 * 1024) def unload_model_clones(model): + """Unloads all clones of a model. + Args: + model: The model. + """ to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): @@ -355,6 +418,12 @@ def unload_model_clones(model): current_loaded_models.pop(i).model_unload() def free_memory(memory_required, device, keep_loaded=[]): + """Frees memory by unloading models. + Args: + memory_required (int): The amount of memory required. + device (torch.device): The device to free memory on. + keep_loaded (list, optional): A list of models to keep loaded. Defaults to []. + """ unloaded_model = False for i in range(len(current_loaded_models) -1, -1, -1): if not ALWAYS_VRAM_OFFLOAD: @@ -377,6 +446,11 @@ def free_memory(memory_required, device, keep_loaded=[]): soft_empty_cache() def load_models_gpu(models, memory_required=0): + """Loads a list of models on the GPU. + Args: + models (list): A list of models to load. + memory_required (int, optional): The amount of memory required. Defaults to 0. + """ global vram_state inference_memory = minimum_inference_memory() @@ -440,9 +514,16 @@ def load_models_gpu(models, memory_required=0): def load_model_gpu(model): + """Loads a model on the GPU. + Args: + model: The model to load. + Returns: + The loaded model. + """ return load_models_gpu([model]) def cleanup_models(): + """Cleans up unused models.""" to_delete = [] for i in range(len(current_loaded_models)): if sys.getrefcount(current_loaded_models[i].model) <= 2: @@ -454,6 +535,12 @@ def cleanup_models(): del x def dtype_size(dtype): + """Gets the size of a dtype in bytes. + Args: + dtype: The dtype. + Returns: + int: The size of the dtype in bytes. + """ dtype_size = 4 if dtype == torch.float16 or dtype == torch.bfloat16: dtype_size = 2 @@ -467,12 +554,23 @@ def dtype_size(dtype): return dtype_size def unet_offload_device(): + """Gets the device to offload the UNet to. + Returns: + torch.device: The device to offload the UNet to. + """ if vram_state == VRAMState.HIGH_VRAM: return get_torch_device() else: return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + """Gets the initial device to load the UNet on. + Args: + parameters: The parameters of the UNet. + dtype: The dtype of the UNet. + Returns: + torch.device: The initial device to load the UNet on. + """ torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM: return torch_dev @@ -491,6 +589,13 @@ def unet_inital_load_device(parameters, dtype): return cpu_dev def unet_dtype(device=None, model_params=0): + """Gets the dtype for the UNet. + Args: + device (torch.device, optional): The device. Defaults to None. + model_params (int, optional): The number of parameters in the model. Defaults to 0. + Returns: + torch.dtype: The dtype for the UNet. + """ if args.unet_in_bf16: return torch.bfloat16 if args.unet_in_fp16: @@ -505,6 +610,13 @@ def unet_dtype(device=None, model_params=0): # None means no manual cast def unet_manual_cast(weight_dtype, inference_device): + """Gets the manual cast for the UNet. + Args: + weight_dtype: The weight dtype of the UNet. + inference_device (torch.device): The inference device. + Returns: + torch.dtype or None: The manual cast for the UNet, or None if no manual cast is needed. + """ if weight_dtype == torch.float32: return None @@ -518,12 +630,20 @@ def unet_manual_cast(weight_dtype, inference_device): return torch.float32 def text_encoder_offload_device(): + """Gets the device to offload the text encoder to. + Returns: + torch.device: The device to offload the text encoder to. + """ if args.always_gpu: return get_torch_device() else: return torch.device("cpu") def text_encoder_device(): + """Gets the device for the text encoder. + Returns: + torch.device: The device for the text encoder. + """ if args.always_gpu: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: @@ -537,6 +657,12 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_dtype(device=None): + """Gets the dtype for the text encoder. + Args: + device (torch.device, optional): The device. Defaults to None. + Returns: + torch.dtype: The dtype for the text encoder. + """ if args.clip_in_fp8_e4m3fn: return torch.float8_e4m3fn elif args.clip_in_fp8_e5m2: @@ -555,32 +681,61 @@ def text_encoder_dtype(device=None): return torch.float32 def intermediate_device(): + """Gets the intermediate device. + Returns: + torch.device: The intermediate device. + """ if args.always_gpu: return get_torch_device() else: return torch.device("cpu") def vae_device(): + """Gets the device for the VAE. + Returns: + torch.device: The device for the VAE. + """ if args.vae_in_cpu: return torch.device("cpu") return get_torch_device() def vae_offload_device(): + """Gets the device to offload the VAE to. + Returns: + torch.device: The device to offload the VAE to. + """ if args.always_gpu: return get_torch_device() else: return torch.device("cpu") def vae_dtype(): + """Gets the dtype for the VAE. + Returns: + torch.dtype: The dtype for the VAE. + """ global VAE_DTYPE return VAE_DTYPE def get_autocast_device(dev): + """Gets the autocast device for a device. + Args: + dev (torch.device): The device. + Returns: + str: The autocast device. + """ if hasattr(dev, 'type'): return dev.type return "cuda" def supports_dtype(device, dtype): #TODO + """Checks if a device supports a dtype. + Args: + device (torch.device): The device. + dtype: The dtype. + Returns: + bool: True if the device supports the dtype, False otherwise. + """ if dtype == torch.float32: return True if is_device_cpu(device): @@ -592,11 +747,26 @@ def supports_dtype(device, dtype): #TODO return False def device_supports_non_blocking(device): + """Checks if a device supports non-blocking operations. + Args: + device (torch.device): The device. + Returns: + bool: True if the device supports non-blocking operations, False otherwise. + """ if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking return True def cast_to_device(tensor, device, dtype, copy=False): + """Casts a tensor to a device and dtype. + Args: + tensor: The tensor to cast. + device (torch.device): The device to cast to. + dtype: The dtype to cast to. + copy (bool, optional): Whether to copy the tensor. Defaults to False. + Returns: + The cast tensor. + """ device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: device_supports_cast = True @@ -619,6 +789,10 @@ def cast_to_device(tensor, device, dtype, copy=False): return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): + """Checks if xformers is enabled. + Returns: + bool: True if xformers is enabled, False otherwise. + """ global directml_enabled global cpu_state if cpu_state != CPUState.GPU: @@ -631,6 +805,10 @@ def xformers_enabled(): def xformers_enabled_vae(): + """Checks if xformers is enabled for the VAE. + Returns: + bool: True if xformers is enabled for the VAE, False otherwise. + """ enabled = xformers_enabled() if not enabled: return False @@ -638,10 +816,18 @@ def xformers_enabled_vae(): return XFORMERS_ENABLED_VAE def pytorch_attention_enabled(): + """Checks if PyTorch attention is enabled. + Returns: + bool: True if PyTorch attention is enabled, False otherwise. + """ global ENABLE_PYTORCH_ATTENTION return ENABLE_PYTORCH_ATTENTION def pytorch_attention_flash_attention(): + """Checks if PyTorch flash attention is enabled. + Returns: + bool: True if PyTorch flash attention is enabled, False otherwise. + """ global ENABLE_PYTORCH_ATTENTION if ENABLE_PYTORCH_ATTENTION: #TODO: more reliable way of checking for flash attention? @@ -650,6 +836,13 @@ def pytorch_attention_flash_attention(): return False def get_free_memory(dev=None, torch_free_too=False): + """Gets the free memory of a device. + Args: + dev (torch.device, optional): The device. Defaults to None. + torch_free_too (bool, optional): Whether to also return the free memory as reported by torch. Defaults to False. + Returns: + int or tuple[int, int]: The free memory of the device, or a tuple of the free memory and the free memory as reported by torch. + """ global directml_enabled if dev is None: dev = get_torch_device() @@ -682,26 +875,54 @@ def get_free_memory(dev=None, torch_free_too=False): return mem_free_total def cpu_mode(): + """Checks if the CPU is in CPU mode. + Returns: + bool: True if the CPU is in CPU mode, False otherwise. + """ global cpu_state return cpu_state == CPUState.CPU def mps_mode(): + """Checks if the CPU is in MPS mode. + Returns: + bool: True if the CPU is in MPS mode, False otherwise. + """ global cpu_state return cpu_state == CPUState.MPS def is_device_cpu(device): + """Checks if a device is a CPU. + Args: + device (torch.device): The device. + Returns: + bool: True if the device is a CPU, False otherwise. + """ if hasattr(device, 'type'): if (device.type == 'cpu'): return True return False def is_device_mps(device): + """Checks if a device is an MPS device. + Args: + device (torch.device): The device. + Returns: + bool: True if the device is an MPS device, False otherwise. + """ if hasattr(device, 'type'): if (device.type == 'mps'): return True return False def should_use_fp16(device=None, model_params=0, prioritize_performance=True): + """Checks if fp16 should be used. + Args: + device (torch.device, optional): The device. Defaults to None. + model_params (int, optional): The number of parameters in the model. Defaults to 0. + prioritize_performance (bool, optional): Whether to prioritize performance. Defaults to True. + Returns: + bool: True if fp16 should be used, False otherwise. + """ global directml_enabled if device is not None: @@ -760,6 +981,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return True def soft_empty_cache(force=False): + """Softly empties the cache. + Args: + force (bool, optional): Whether to force the cache to be emptied. Defaults to False. + """ global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() @@ -771,34 +996,53 @@ def soft_empty_cache(force=False): torch.cuda.ipc_collect() def unload_all_models(): + """Unloads all models.""" free_memory(1e30, get_torch_device()) def resolve_lowvram_weight(weight, model, key): #TODO: remove + """Resolves the low VRAM weight. + Args: + weight: The weight. + model: The model. + key: The key. + Returns: + The resolved low VRAM weight. + """ return weight #TODO: might be cleaner to put this somewhere else import threading class InterruptProcessingException(Exception): + """An exception that is raised when processing is interrupted.""" pass interrupt_processing_mutex = threading.RLock() interrupt_processing = False def interrupt_current_processing(value=True): + """Interrupts the current processing. + Args: + value (bool, optional): The value to set the interrupt flag to. Defaults to True. + """ global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: interrupt_processing = value def processing_interrupted(): + """Checks if processing has been interrupted. + Returns: + bool: True if processing has been interrupted, False otherwise. + """ global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: return interrupt_processing def throw_exception_if_processing_interrupted(): + """Throws an exception if processing has been interrupted.""" global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py index dd816e52e1..da1747481f 100644 --- a/ldm_patched/modules/model_patcher.py +++ b/ldm_patched/modules/model_patcher.py @@ -6,7 +6,17 @@ import ldm_patched.modules.model_management class ModelPatcher: + """A class for patching models.""" def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): + """Initializes a ModelPatcher. + Args: + model: The model to patch. + load_device (torch.device): The device to load the model on. + offload_device (torch.device): The device to offload the model to. + size (int, optional): The size of the model. Defaults to 0. + current_device (torch.device, optional): The current device of the model. Defaults to None. + weight_inplace_update (bool, optional): Whether to update the weights in place. Defaults to False. + """ self.size = size self.model = model self.patches = {} @@ -25,6 +35,10 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No self.weight_inplace_update = weight_inplace_update def model_size(self): + """Gets the size of the model. + Returns: + int: The size of the model in bytes. + """ if self.size > 0: return self.size model_sd = self.model.state_dict() @@ -33,6 +47,10 @@ def model_size(self): return self.size def clone(self): + """Clones the ModelPatcher. + Returns: + ModelPatcher: The cloned ModelPatcher. + """ n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: @@ -44,14 +62,31 @@ def clone(self): return n def is_clone(self, other): + """Checks if another ModelPatcher is a clone of this one. + Args: + other (ModelPatcher): The other ModelPatcher. + Returns: + bool: True if the other ModelPatcher is a clone of this one, False otherwise. + """ if hasattr(other, 'model') and self.model is other.model: return True return False def memory_required(self, input_shape): + """Gets the memory required for the model. + Args: + input_shape (tuple): The input shape. + Returns: + int: The memory required for the model in bytes. + """ return self.model.memory_required(input_shape=input_shape) def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): + """Sets the model's sampler CFG function. + Args: + sampler_cfg_function (function): The sampler CFG function. + disable_cfg1_optimization (bool, optional): Whether to disable CFG1 optimization. Defaults to False. + """ if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: @@ -60,20 +95,42 @@ def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_opti self.model_options["disable_cfg1_optimization"] = True def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): + """Sets the model's sampler post-CFG function. + Args: + post_cfg_function (function): The sampler post-CFG function. + disable_cfg1_optimization (bool, optional): Whether to disable CFG1 optimization. Defaults to False. + """ self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] if disable_cfg1_optimization: self.model_options["disable_cfg1_optimization"] = True def set_model_unet_function_wrapper(self, unet_wrapper_function): + """Sets the model's UNet function wrapper. + Args: + unet_wrapper_function (function): The UNet function wrapper. + """ self.model_options["model_function_wrapper"] = unet_wrapper_function def set_model_patch(self, patch, name): + """Sets a model patch. + Args: + patch: The patch to set. + name (str): The name of the patch. + """ to = self.model_options["transformer_options"] if "patches" not in to: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): + """Sets a model patch replacement. + Args: + patch: The patch to set. + name (str): The name of the patch. + block_name (str): The name of the block. + number (int): The number of the block. + transformer_index (int, optional): The index of the transformer. Defaults to None. + """ to = self.model_options["transformer_options"] if "patches_replace" not in to: to["patches_replace"] = {} @@ -86,36 +143,87 @@ def set_model_patch_replace(self, patch, name, block_name, number, transformer_i to["patches_replace"][name][block] = patch def set_model_attn1_patch(self, patch): + """Sets a model attention1 patch. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "attn1_patch") def set_model_attn2_patch(self, patch): + """Sets a model attention2 patch. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "attn2_patch") def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + """Sets a model attention1 replacement. + Args: + patch: The patch to set. + block_name (str): The name of the block. + number (int): The number of the block. + transformer_index (int, optional): The index of the transformer. Defaults to None. + """ self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + """Sets a model attention2 replacement. + Args: + patch: The patch to set. + block_name (str): The name of the block. + number (int): The number of the block. + transformer_index (int, optional): The index of the transformer. Defaults to None. + """ self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): + """Sets a model attention1 output patch. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "attn1_output_patch") def set_model_attn2_output_patch(self, patch): + """Sets a model attention2 output patch. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "attn2_output_patch") def set_model_input_block_patch(self, patch): + """Sets a model input block patch. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "input_block_patch") def set_model_input_block_patch_after_skip(self, patch): + """Sets a model input block patch after the skip connection. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "input_block_patch_after_skip") def set_model_output_block_patch(self, patch): + """Sets a model output block patch. + Args: + patch: The patch to set. + """ self.set_model_patch(patch, "output_block_patch") def add_object_patch(self, name, obj): + """Adds an object patch. + Args: + name (str): The name of the patch. + obj: The object to patch. + """ self.object_patches[name] = obj def model_patches_to(self, device): + """Moves the model patches to a device. + Args: + device (torch.device): The device to move the patches to. + """ to = self.model_options["transformer_options"] if "patches" in to: patches = to["patches"] @@ -137,10 +245,22 @@ def model_patches_to(self, device): self.model_options["model_function_wrapper"] = wrap_func.to(device) def model_dtype(self): + """Gets the dtype of the model. + Returns: + torch.dtype: The dtype of the model. + """ if hasattr(self.model, "get_dtype"): return self.model.get_dtype() def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): + """Adds patches to the model. + Args: + patches (dict): A dictionary of patches to add. + strength_patch (float, optional): The strength of the patch. Defaults to 1.0. + strength_model (float, optional): The strength of the model. Defaults to 1.0. + Returns: + list: A list of the patched keys. + """ p = set() for k in patches: if k in self.model_keys: @@ -152,6 +272,12 @@ def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return list(p) def get_key_patches(self, filter_prefix=None): + """Gets the key patches. + Args: + filter_prefix (str, optional): A prefix to filter the keys by. Defaults to None. + Returns: + dict: A dictionary of the key patches. + """ ldm_patched.modules.model_management.unload_model_clones(self) model_sd = self.model_state_dict() p = {} @@ -166,6 +292,12 @@ def get_key_patches(self, filter_prefix=None): return p def model_state_dict(self, filter_prefix=None): + """Gets the model's state dict. + Args: + filter_prefix (str, optional): A prefix to filter the keys by. Defaults to None. + Returns: + dict: The model's state dict. + """ sd = self.model.state_dict() keys = list(sd.keys()) if filter_prefix is not None: @@ -175,6 +307,13 @@ def model_state_dict(self, filter_prefix=None): return sd def patch_model(self, device_to=None, patch_weights=True): + """Patches the model. + Args: + device_to (torch.device, optional): The device to move the model to. Defaults to None. + patch_weights (bool, optional): Whether to patch the weights. Defaults to True. + Returns: + The patched model. + """ for k in self.object_patches: old = getattr(self.model, k) if k not in self.object_patches_backup: @@ -213,6 +352,14 @@ def patch_model(self, device_to=None, patch_weights=True): return self.model def calculate_weight(self, patches, weight, key): + """Calculates the weight of a patch. + Args: + patches (list): A list of patches. + weight: The weight to patch. + key (str): The key of the weight. + Returns: + The patched weight. + """ for p in patches: alpha = p[0] v = p[1] @@ -335,6 +482,10 @@ def calculate_weight(self, patches, weight, key): return weight def unpatch_model(self, device_to=None): + """Unpatches the model. + Args: + device_to (torch.device, optional): The device to move the model to. Defaults to None. + """ keys = list(self.backup.keys()) if self.weight_inplace_update: diff --git a/ldm_patched/modules/sample.py b/ldm_patched/modules/sample.py index 0f48395037..981594c9eb 100644 --- a/ldm_patched/modules/sample.py +++ b/ldm_patched/modules/sample.py @@ -7,9 +7,13 @@ import numpy as np def prepare_noise(latent_image, seed, noise_inds=None): - """ - creates random noise given a latent image and a seed. - optional arg skip can be used to skip and discard x number of noise generations for a given seed + """Creates random noise given a latent image and a seed. + Args: + latent_image (torch.Tensor): The latent image. + seed (int): The random seed. + noise_inds (np.ndarray, optional): The noise indices. Defaults to None. + Returns: + torch.Tensor: The random noise. """ generator = torch.manual_seed(seed) if noise_inds is None: @@ -26,7 +30,14 @@ def prepare_noise(latent_image, seed, noise_inds=None): return noises def prepare_mask(noise_mask, shape, device): - """ensures noise mask is of proper dimensions""" + """Ensures noise mask is of proper dimensions. + Args: + noise_mask (torch.Tensor): The noise mask. + shape (tuple): The shape of the noise mask. + device (torch.device): The device to create the noise mask on. + Returns: + torch.Tensor: The prepared noise mask. + """ noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = ldm_patched.modules.utils.repeat_to_batch_size(noise_mask, shape[0]) @@ -34,6 +45,13 @@ def prepare_mask(noise_mask, shape, device): return noise_mask def get_models_from_cond(cond, model_type): + """Gets a list of models from a conditioning. + Args: + cond: The conditioning. + model_type (str): The type of model to get. + Returns: + list: A list of models. + """ models = [] for c in cond: if model_type in c: @@ -41,6 +59,12 @@ def get_models_from_cond(cond, model_type): return models def convert_cond(cond): + """Converts a conditioning to a different format. + Args: + cond: The conditioning to convert. + Returns: + The converted conditioning. + """ out = [] for c in cond: temp = c[1].copy() @@ -53,7 +77,14 @@ def convert_cond(cond): return out def get_additional_models(positive, negative, dtype): - """loads additional models in positive and negative conditioning""" + """Loads additional models in positive and negative conditioning. + Args: + positive: The positive conditioning. + negative: The negative conditioning. + dtype: The dtype of the models. + Returns: + A tuple of the additional models and the inference memory requirements. + """ control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) inference_memory = 0 @@ -68,12 +99,25 @@ def get_additional_models(positive, negative, dtype): return models, inference_memory def cleanup_additional_models(models): - """cleanup additional models that were loaded""" + """Cleans up additional models that were loaded. + Args: + models (list): A list of models to clean up. + """ for m in models: if hasattr(m, 'cleanup'): m.cleanup() def prepare_sampling(model, noise_shape, positive, negative, noise_mask): + """Prepares for sampling. + Args: + model: The model. + noise_shape (tuple): The shape of the noise. + positive: The positive conditioning. + negative: The negative conditioning. + noise_mask (torch.Tensor): The noise mask. + Returns: + A tuple of the real model, the positive conditioning, the negative conditioning, the noise mask, and the additional models. + """ device = model.load_device positive = convert_cond(positive) negative = convert_cond(negative) @@ -90,6 +134,30 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask): def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): + """Samples from a model. + Args: + model: The model to sample from. + noise (torch.Tensor): The noise to start with. + steps (int): The number of steps to sample for. + cfg (float): The classifier-free guidance scale. + sampler_name (str): The name of the sampler to use. + scheduler (str): The name of the scheduler to use. + positive: The positive conditioning. + negative: The negative conditioning. + latent_image (torch.Tensor): The latent image to start with. + denoise (float, optional): The denoising strength. Defaults to 1.0. + disable_noise (bool, optional): Whether to disable noise. Defaults to False. + start_step (int, optional): The step to start sampling from. Defaults to None. + last_step (int, optional): The step to stop sampling at. Defaults to None. + force_full_denoise (bool, optional): Whether to force full denoising. Defaults to False. + noise_mask (torch.Tensor, optional): The noise mask. Defaults to None. + sigmas (torch.Tensor, optional): The sigmas to use. Defaults to None. + callback (function, optional): A callback function. Defaults to None. + disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False. + seed (int, optional): The random seed. Defaults to None. + Returns: + torch.Tensor: The sampled images. + """ real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) noise = noise.to(model.load_device) @@ -105,6 +173,23 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): + """Samples from a model using a custom sampler. + Args: + model: The model to sample from. + noise (torch.Tensor): The noise to start with. + cfg (float): The classifier-free guidance scale. + sampler: The sampler to use. + sigmas (torch.Tensor): The sigmas to use. + positive: The positive conditioning. + negative: The negative conditioning. + latent_image (torch.Tensor): The latent image to start with. + noise_mask (torch.Tensor, optional): The noise mask. Defaults to None. + callback (function, optional): A callback function. Defaults to None. + disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False. + seed (int, optional): The random seed. Defaults to None. + Returns: + torch.Tensor: The sampled images. + """ real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) noise = noise.to(model.load_device) latent_image = latent_image.to(model.load_device) diff --git a/modules/anisotropic.py b/modules/anisotropic.py index 5768222407..85424cd380 100644 --- a/modules/anisotropic.py +++ b/modules/anisotropic.py @@ -8,11 +8,23 @@ def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]: + """Computes the zero padding for a given kernel size. + Args: + kernel_size (tuple[int, int] or int): The kernel size. + Returns: + tuple[int, int]: The zero padding. + """ ky, kx = _unpack_2d_ks(kernel_size) return (ky - 1) // 2, (kx - 1) // 2 def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]: + """Unpacks a 2D kernel size into its y and x components. + Args: + kernel_size (tuple[int, int] or int): The kernel size. + Returns: + tuple[int, int]: The y and x components of the kernel size. + """ if isinstance(kernel_size, int): ky = kx = kernel_size else: @@ -27,6 +39,15 @@ def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]: def gaussian( window_size: int, sigma: Tensor | float, *, device: Device | None = None, dtype: Dtype | None = None ) -> Tensor: + """Generates a 1D Gaussian kernel. + Args: + window_size (int): The size of the window. + sigma (Tensor or float): The standard deviation of the Gaussian. + device (Device, optional): The device to create the tensor on. Defaults to None. + dtype (Dtype, optional): The data type of the tensor. Defaults to None. + Returns: + Tensor: The 1D Gaussian kernel. + """ batch_size = sigma.shape[0] @@ -48,6 +69,16 @@ def get_gaussian_kernel1d( device: Device | None = None, dtype: Dtype | None = None, ) -> Tensor: + """Generates a 1D Gaussian kernel. + Args: + kernel_size (int): The size of the kernel. + sigma (float or Tensor): The standard deviation of the Gaussian. + force_even (bool, optional): Whether to force the kernel to be even. Defaults to False. + device (Device, optional): The device to create the tensor on. Defaults to None. + dtype (Dtype, optional): The data type of the tensor. Defaults to None. + Returns: + Tensor: The 1D Gaussian kernel. + """ return gaussian(kernel_size, sigma, device=device, dtype=dtype) @@ -60,6 +91,16 @@ def get_gaussian_kernel2d( device: Device | None = None, dtype: Dtype | None = None, ) -> Tensor: + """Generates a 2D Gaussian kernel. + Args: + kernel_size (tuple[int, int] or int): The size of the kernel. + sigma (tuple[float, float] or Tensor): The standard deviation of the Gaussian. + force_even (bool, optional): Whether to force the kernel to be even. Defaults to False. + device (Device, optional): The device to create the tensor on. Defaults to None. + dtype (Dtype, optional): The data type of the tensor. Defaults to None. + Returns: + Tensor: The 2D Gaussian kernel. + """ sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype) @@ -123,10 +164,28 @@ def bilateral_blur( border_type: str = 'reflect', color_distance_type: str = 'l1', ) -> Tensor: + """Applies a bilateral blur to an image. + Args: + input (Tensor): The input image. + kernel_size (tuple[int, int] or int, optional): The size of the kernel. Defaults to (13, 13). + sigma_color (float or Tensor, optional): The standard deviation of the color space. Defaults to 3.0. + sigma_space (tuple[float, float] or Tensor, optional): The standard deviation of the coordinate space. Defaults to 3.0. + border_type (str, optional): The padding mode. Defaults to 'reflect'. + color_distance_type (str, optional): The color distance type. Defaults to 'l1'. + Returns: + Tensor: The blurred image. + """ return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type) def adaptive_anisotropic_filter(x, g=None): + """Applies an adaptive anisotropic filter to an image. + Args: + x (Tensor): The input image. + g (Tensor, optional): The guidance image. Defaults to None. + Returns: + Tensor: The filtered image. + """ if g is None: g = x s, m = torch.std_mean(g, dim=(1, 2, 3), keepdim=True) @@ -150,10 +209,23 @@ def joint_bilateral_blur( border_type: str = 'reflect', color_distance_type: str = 'l1', ) -> Tensor: + """Applies a joint bilateral blur to an image. + Args: + input (Tensor): The input image. + guidance (Tensor): The guidance image. + kernel_size (tuple[int, int] or int): The size of the kernel. + sigma_color (float or Tensor): The standard deviation of the color space. + sigma_space (tuple[float, float] or Tensor): The standard deviation of the coordinate space. + border_type (str, optional): The padding mode. Defaults to 'reflect'. + color_distance_type (str, optional): The color distance type. Defaults to 'l1'. + Returns: + Tensor: The blurred image. + """ return _bilateral_blur(input, guidance, kernel_size, sigma_color, sigma_space, border_type, color_distance_type) class _BilateralBlur(torch.nn.Module): + """Base class for bilateral blur modules.""" def __init__( self, kernel_size: tuple[int, int] | int, @@ -162,6 +234,14 @@ def __init__( border_type: str = 'reflect', color_distance_type: str = "l1", ) -> None: + """Initializes the _BilateralBlur module. + Args: + kernel_size (tuple[int, int] or int): The size of the kernel. + sigma_color (float or Tensor): The standard deviation of the color space. + sigma_space (tuple[float, float] or Tensor): The standard deviation of the coordinate space. + border_type (str, optional): The padding mode. Defaults to 'reflect'. + color_distance_type (str, optional): The color distance type. Defaults to 'l1'. + """ super().__init__() self.kernel_size = kernel_size self.sigma_color = sigma_color @@ -181,14 +261,29 @@ def __repr__(self) -> str: class BilateralBlur(_BilateralBlur): + """Applies a bilateral blur to an image.""" def forward(self, input: Tensor) -> Tensor: + """Forward pass of the BilateralBlur module. + Args: + input (Tensor): The input image. + Returns: + Tensor: The blurred image. + """ return bilateral_blur( input, self.kernel_size, self.sigma_color, self.sigma_space, self.border_type, self.color_distance_type ) class JointBilateralBlur(_BilateralBlur): + """Applies a joint bilateral blur to an image.""" def forward(self, input: Tensor, guidance: Tensor) -> Tensor: + """Forward pass of the JointBilateralBlur module. + Args: + input (Tensor): The input image. + guidance (Tensor): The guidance image. + Returns: + Tensor: The blurred image. + """ return joint_bilateral_blur( input, guidance, diff --git a/modules/async_worker.py b/modules/async_worker.py index a0b96a542f..80bd86000f 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -8,7 +8,12 @@ class AsyncTask: + """Represents a task for asynchronous image generation.""" def __init__(self, args): + """Initializes an AsyncTask object. + Args: + args (list): A list of arguments for the task. + """ from modules.flags import Performance, MetadataScheme, ip_list, disabled from modules.util import get_enabled_loras from modules.config import default_max_lora_number @@ -162,10 +167,12 @@ def __init__(self, args): class EarlyReturnException(BaseException): + """An exception that is used to signal an early return from a task.""" pass def worker(): + """The main worker function that processes asynchronous tasks.""" global async_tasks import os diff --git a/modules/auth.py b/modules/auth.py index 3ba1114245..286b5b3de0 100644 --- a/modules/auth.py +++ b/modules/auth.py @@ -6,6 +6,12 @@ def auth_list_to_dict(auth_list): + """Converts a list of authentication data to a dictionary. + Args: + auth_list (list): A list of dictionaries, where each dictionary contains a 'user' and either a 'hash' or a 'pass'. + Returns: + dict: A dictionary where the keys are usernames and the values are password hashes. + """ auth_dict = {} for auth_data in auth_list: if 'user' in auth_data: @@ -17,6 +23,12 @@ def auth_list_to_dict(auth_list): def load_auth_data(filename=None): + """Loads authentication data from a JSON file. + Args: + filename (str, optional): The name of the authentication file. Defaults to None. + Returns: + dict or None: A dictionary of authentication data, or None if the file cannot be loaded. + """ auth_dict = None if filename != None and exists(filename): with open(filename, encoding='utf-8') as auth_file: @@ -35,6 +47,13 @@ def load_auth_data(filename=None): def check_auth(user, password): + """Checks if a user's password is correct. + Args: + user (str): The username. + password (str): The password. + Returns: + bool: True if the password is correct, False otherwise. + """ if user not in auth_dict: return False else: diff --git a/modules/config.py b/modules/config.py index 8609b4154b..5517669490 100644 --- a/modules/config.py +++ b/modules/config.py @@ -14,6 +14,13 @@ def get_config_path(key, default_value): + """Gets a configuration path from an environment variable or returns a default value. + Args: + key (str): The name of the environment variable. + default_value (str): The default value to use if the environment variable is not set. + Returns: + str: The configuration path. + """ env = os.getenv(key) if env is not None and isinstance(env, str): print(f"Environment: {key} = {env}") @@ -50,6 +57,7 @@ def get_config_path(key, default_value): def try_load_deprecated_user_path_config(): + """Loads a deprecated user path configuration file if it exists and updates the main configuration dictionary.""" global config_dict if not os.path.exists('user_path_config.txt'): @@ -99,6 +107,10 @@ def replace_config(old_key, new_key): try_load_deprecated_user_path_config() def get_presets(): + """Gets a list of available presets from the 'presets' directory. + Returns: + list: A list of preset names. + """ preset_folder = 'presets' presets = ['initial'] if not os.path.exists(preset_folder): @@ -108,10 +120,17 @@ def get_presets(): return presets + [f[:f.index(".json")] for f in os.listdir(preset_folder) if f.endswith('.json')] def update_presets(): + """Updates the list of available presets.""" global available_presets available_presets = get_presets() def try_get_preset_content(preset): + """Tries to load the content of a preset file. + Args: + preset (str): The name of the preset to load. + Returns: + dict: The content of the preset file as a dictionary, or an empty dictionary if the preset cannot be loaded. + """ if isinstance(preset, str): preset_path = os.path.abspath(f'./presets/{preset}.json') try: @@ -144,6 +163,15 @@ def get_path_output() -> str: def get_dir_or_set_default(key, default_value, as_array=False, make_directory=False): + """Gets a directory path from the configuration or sets a default value. + Args: + key (str): The key of the configuration item. + default_value (str or list): The default value to use. + as_array (bool, optional): Whether to return the path as an array. Defaults to False. + make_directory (bool, optional): Whether to create the directory if it doesn't exist. Defaults to False. + Returns: + str or list: The directory path(s). + """ global config_dict, visited_keys, always_save_keys if key not in visited_keys: @@ -205,6 +233,16 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, expected_type=None): + """Gets a configuration item or sets a default value. + Args: + key (str): The key of the configuration item. + default_value: The default value to use. + validator (function): A function to validate the configuration item. + disable_empty_as_none (bool, optional): Whether to disable treating empty values as None. Defaults to False. + expected_type (type, optional): The expected type of the configuration item. Defaults to None. + Returns: + The configuration item. + """ global config_dict, visited_keys if key not in visited_keys: @@ -234,6 +272,13 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ def init_temp_path(path: str | None, default_path: str) -> str: + """Initializes the temporary path. + Args: + path (str or None): The path to initialize. + default_path (str): The default path to use. + Returns: + str: The initialized temporary path. + """ if args_manager.args.temp_path: path = args_manager.args.temp_path @@ -765,6 +810,12 @@ def init_temp_path(path: str | None, default_path: str) -> str: def add_ratio(x): + """Adds a ratio to a string that represents a resolution. + Args: + x (str): The resolution string (e.g., "1024*1024"). + Returns: + str: The resolution string with the ratio appended (e.g., "1024×1024 ❘ 1:1"). + """ a, b = x.replace('*', ' ').split(' ')[:2] a, b = int(a), int(b) g = math.gcd(a, b) @@ -798,6 +849,14 @@ def add_ratio(x): def get_model_filenames(folder_paths, extensions=None, name_filter=None): + """Gets a list of model filenames from a list of folder paths. + Args: + folder_paths (list or str): A list of folder paths or a single folder path. + extensions (list, optional): A list of file extensions to include. Defaults to None. + name_filter (str, optional): A string to filter filenames by. Defaults to None. + Returns: + list: A list of model filenames. + """ if extensions is None: extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] files = [] @@ -811,6 +870,7 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None): def update_files(): + """Updates the lists of available model, LoRA, VAE, and wildcard filenames.""" global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets model_filenames = get_model_filenames(paths_checkpoints) lora_filenames = get_model_filenames(paths_loras) @@ -821,6 +881,12 @@ def update_files(): def downloading_inpaint_models(v): + """Downloads the inpainting models. + Args: + v (str): The version of the inpainting model to download. + Returns: + A tuple that contains the paths to the downloaded head and patch files. + """ assert v in modules.flags.inpaint_engine_versions load_file_from_url( @@ -859,6 +925,10 @@ def downloading_inpaint_models(v): def downloading_sdxl_lcm_lora(): + """Downloads the SDXL LCM LoRA model. + Returns: + str: The filename of the downloaded LoRA model. + """ load_file_from_url( url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors', model_dir=paths_loras[0], @@ -868,6 +938,10 @@ def downloading_sdxl_lcm_lora(): def downloading_sdxl_lightning_lora(): + """Downloads the SDXL Lightning LoRA model. + Returns: + str: The filename of the downloaded LoRA model. + """ load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_lightning_4step_lora.safetensors', model_dir=paths_loras[0], @@ -877,6 +951,10 @@ def downloading_sdxl_lightning_lora(): def downloading_sdxl_hyper_sd_lora(): + """Downloads the SDXL Hyper SD LoRA model. + Returns: + str: The filename of the downloaded LoRA model. + """ load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_hyper_sd_4step_lora.safetensors', model_dir=paths_loras[0], @@ -886,6 +964,10 @@ def downloading_sdxl_hyper_sd_lora(): def downloading_controlnet_canny(): + """Downloads the ControlNet Canny model. + Returns: + str: The path to the downloaded ControlNet Canny model. + """ load_file_from_url( url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors', model_dir=path_controlnet, @@ -895,6 +977,10 @@ def downloading_controlnet_canny(): def downloading_controlnet_cpds(): + """Downloads the ControlNet CPDS model. + Returns: + str: The path to the downloaded ControlNet CPDS model. + """ load_file_from_url( url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors', model_dir=path_controlnet, @@ -904,6 +990,12 @@ def downloading_controlnet_cpds(): def downloading_ip_adapters(v): + """Downloads the IP adapter models. + Args: + v (str): The type of IP adapter to download ('ip' or 'face'). + Returns: + list: A list of paths to the downloaded models. + """ assert v in ['ip', 'face'] results = [] @@ -942,6 +1034,10 @@ def downloading_ip_adapters(v): def downloading_upscale_model(): + """Downloads the upscale model. + Returns: + str: The path to the downloaded upscale model. + """ load_file_from_url( url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin', model_dir=path_upscale_models, @@ -950,6 +1046,10 @@ def downloading_upscale_model(): return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') def downloading_safety_checker_model(): + """Downloads the safety checker model. + Returns: + str: The path to the downloaded safety checker model. + """ load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin', model_dir=path_safety_checker, @@ -959,6 +1059,12 @@ def downloading_safety_checker_model(): def download_sam_model(sam_model: str) -> str: + """Downloads a SAM model. + Args: + sam_model (str): The name of the SAM model to download. + Returns: + str: The path to the downloaded SAM model. + """ match sam_model: case 'vit_b': return downloading_sam_vit_b() @@ -971,6 +1077,10 @@ def download_sam_model(sam_model: str) -> str: def downloading_sam_vit_b(): + """Downloads the SAM ViT-B model. + Returns: + str: The path to the downloaded SAM ViT-B model. + """ load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_b_01ec64.pth', model_dir=path_sam, @@ -980,6 +1090,10 @@ def downloading_sam_vit_b(): def downloading_sam_vit_l(): + """Downloads the SAM ViT-L model. + Returns: + str: The path to the downloaded SAM ViT-L model. + """ load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_l_0b3195.pth', model_dir=path_sam, @@ -989,6 +1103,10 @@ def downloading_sam_vit_l(): def downloading_sam_vit_h(): + """Downloads the SAM ViT-H model. + Returns: + str: The path to the downloaded SAM ViT-H model. + """ load_file_from_url( url='https://huggingface.co/mashb1t/misc/resolve/main/sam_vit_h_4b8939.pth', model_dir=path_sam, diff --git a/modules/core.py b/modules/core.py index 1c3dacb998..5ffa4aeb66 100644 --- a/modules/core.py +++ b/modules/core.py @@ -35,7 +35,19 @@ class StableDiffusionModel: + """A class that represents a Stable Diffusion model, which includes its components: UNet, VAE, and CLIP. + It also includes functionality for applying LoRAs. + """ def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None): + """Initializes the StableDiffusionModel. + Args: + unet: UNet model. + vae: VAE model. + clip: CLIP model for text encoding. + clip_vision: CLIP model for image encoding. + filename (str): Filename of the main model checkpoint. + vae_filename (str): Filename of the VAE model checkpoint. + """ self.unet = unet self.vae = vae self.clip = clip @@ -60,6 +72,10 @@ def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=No @torch.no_grad() @torch.inference_mode() def refresh_loras(self, loras): + """Applies a list of LoRAs to the UNet and CLIP models. + Args: + loras (list): A list of tuples, where each tuple contains the LoRA filename and its weight. + """ assert isinstance(loras, list) if self.visited_loras == str(loras): @@ -125,18 +141,46 @@ def refresh_loras(self, loras): @torch.no_grad() @torch.inference_mode() def apply_freeu(model, b1, b2, s1, s2): + """Applies the FreeU patch to a model. + Args: + model: The model to patch. + b1 (float): The first balancing factor. + b2 (float): The second balancing factor. + s1 (float): The first scaling factor. + s2 (float): The second scaling factor. + Returns: + The patched model. + """ return opFreeU.patch(model=model, b1=b1, b2=b2, s1=s1, s2=s2)[0] @torch.no_grad() @torch.inference_mode() def load_controlnet(ckpt_filename): + """Loads a ControlNet model from a checkpoint file. + Args: + ckpt_filename (str): The path to the ControlNet checkpoint file. + Returns: + The loaded ControlNet model. + """ return ldm_patched.modules.controlnet.load_controlnet(ckpt_filename) @torch.no_grad() @torch.inference_mode() def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent): + """Applies a ControlNet to the positive and negative conditioning. + Args: + positive: The positive conditioning. + negative: The negative conditioning. + control_net: The ControlNet model. + image: The input image for the ControlNet. + strength (float): The strength of the ControlNet. + start_percent (float): The start percentage for applying the ControlNet. + end_percent (float): The end percentage for applying the ControlNet. + Returns: + A tuple that contains the modified positive and negative conditioning. + """ return opControlNetApplyAdvanced.apply_controlnet(positive=positive, negative=negative, control_net=control_net, image=image, strength=strength, start_percent=start_percent, end_percent=end_percent) @@ -144,6 +188,13 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per @torch.no_grad() @torch.inference_mode() def load_model(ckpt_filename, vae_filename=None): + """Loads a Stable Diffusion model from a checkpoint file. + Args: + ckpt_filename (str): The path to the model checkpoint file. + vae_filename (str, optional): The path to the VAE checkpoint file. Defaults to None. + Returns: + A StableDiffusionModel object. + """ unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings, vae_filename_param=vae_filename) return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename) @@ -152,12 +203,28 @@ def load_model(ckpt_filename, vae_filename=None): @torch.no_grad() @torch.inference_mode() def generate_empty_latent(width=1024, height=1024, batch_size=1): + """Generates an empty latent tensor. + Args: + width (int, optional): The width of the latent tensor. Defaults to 1024. + height (int, optional): The height of the latent tensor. Defaults to 1024. + batch_size (int, optional): The batch size. Defaults to 1. + Returns: + A torch.Tensor that represents the empty latent. + """ return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0] @torch.no_grad() @torch.inference_mode() def decode_vae(vae, latent_image, tiled=False): + """Decodes a latent image using the VAE. + Args: + vae: The VAE model. + latent_image: The latent image to decode. + tiled (bool, optional): Whether to use tiled decoding. Defaults to False. + Returns: + The decoded image as a torch.Tensor. + """ if tiled: return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=512)[0] else: @@ -167,6 +234,14 @@ def decode_vae(vae, latent_image, tiled=False): @torch.no_grad() @torch.inference_mode() def encode_vae(vae, pixels, tiled=False): + """Encodes an image into a latent representation using the VAE. + Args: + vae: The VAE model. + pixels: The image to encode. + tiled (bool, optional): Whether to use tiled encoding. Defaults to False. + Returns: + The latent representation as a torch.Tensor. + """ if tiled: return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=512)[0] else: @@ -176,6 +251,14 @@ def encode_vae(vae, pixels, tiled=False): @torch.no_grad() @torch.inference_mode() def encode_vae_inpaint(vae, pixels, mask): + """Encodes an image for inpainting and takes a mask into account. + Args: + vae: The VAE model. + pixels: The image to encode. + mask: The inpainting mask. + Returns: + A tuple that contains the latent representation and the latent mask. + """ assert mask.ndim == 3 and pixels.ndim == 4 assert mask.shape[-1] == pixels.shape[-2] assert mask.shape[-2] == pixels.shape[-3] @@ -194,7 +277,9 @@ def encode_vae_inpaint(vae, pixels, mask): class VAEApprox(torch.nn.Module): + """A lightweight, approximate VAE for generating image previews during the sampling process.""" def __init__(self): + """Initializes the VAEApprox model.""" super(VAEApprox, self).__init__() self.conv1 = torch.nn.Conv2d(4, 8, (7, 7)) self.conv2 = torch.nn.Conv2d(8, 16, (5, 5)) @@ -207,6 +292,12 @@ def __init__(self): self.current_type = None def forward(self, x): + """Performs the forward pass of the VAEApprox model. + Args: + x: The input tensor. + Returns: + The output tensor. + """ extra = 11 x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) x = torch.nn.functional.pad(x, (extra, extra, extra, extra)) @@ -222,6 +313,13 @@ def forward(self, x): @torch.no_grad() @torch.inference_mode() def get_previewer(model): + """Returns a function that can be used to generate previews of the image during sampling. + Args: + model: The main Stable Diffusion model. + Returns: + A function that takes the latent representation, step number, and total steps as input + and returns a preview image as a NumPy array. + """ global VAE_approx_models from modules.config import path_vae_approx @@ -266,6 +364,33 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1, previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None, disable_preview=False): + """Performs k-sampling to generate an image. + Args: + model: The main Stable Diffusion model. + positive: The positive conditioning. + negative: The negative conditioning. + latent: The initial latent representation. + seed (int, optional): The random seed. Defaults to None. + steps (int, optional): The number of sampling steps. Defaults to 30. + cfg (float, optional): The CFG scale. Defaults to 7.0. + sampler_name (str, optional): The name of the sampler to use. Defaults to 'dpmpp_2m_sde_gpu'. + scheduler (str, optional): The name of the scheduler to use. Defaults to 'karras'. + denoise (float, optional): The denoising strength. Defaults to 1.0. + disable_noise (bool, optional): Whether to disable noise. Defaults to False. + start_step (int, optional): The starting step for sampling. Defaults to None. + last_step (int, optional): The last step for sampling. Defaults to None. + force_full_denoise (bool, optional): Whether to force full denoising. Defaults to False. + callback_function (function, optional): A callback function to be called at each step. Defaults to None. + refiner (StableDiffusionModel, optional): A refiner model to use. Defaults to None. + refiner_switch (int, optional): The step at which to switch to the refiner model. Defaults to -1. + previewer_start (int, optional): The starting step for generating previews. Defaults to None. + previewer_end (int, optional): The ending step for generating previews. Defaults to None. + sigmas (torch.Tensor, optional): The sigmas to use for sampling. Defaults to None. + noise_mean (torch.Tensor, optional): The mean of the noise to add. Defaults to None. + disable_preview (bool, optional): Whether to disable previews. Defaults to False. + Returns: + The generated latent representation. + """ if sigmas is not None: sigmas = sigmas.clone().to(ldm_patched.modules.model_management.get_torch_device()) @@ -328,12 +453,24 @@ def callback(step, x0, x, total_steps): @torch.no_grad() @torch.inference_mode() def pytorch_to_numpy(x): + """Converts a list of PyTorch tensors to a list of NumPy arrays. + Args: + x (list): A list of PyTorch tensors. + Returns: + A list of NumPy arrays. + """ return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x] @torch.no_grad() @torch.inference_mode() def numpy_to_pytorch(x): + """Converts a NumPy array to a PyTorch tensor. + Args: + x (np.ndarray): A NumPy array. + Returns: + A PyTorch tensor. + """ y = x.astype(np.float32) / 255.0 y = y[None] y = np.ascontiguousarray(y.copy()) diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 494644d69d..c10b28dabe 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -31,6 +31,10 @@ @torch.no_grad() @torch.inference_mode() def refresh_controlnets(model_paths): + """Refreshes the loaded ControlNet models. + Args: + model_paths (list): A list of paths to the ControlNet models. + """ global loaded_ControlNets cache = {} for p in model_paths: @@ -46,6 +50,7 @@ def refresh_controlnets(model_paths): @torch.no_grad() @torch.inference_mode() def assert_model_integrity(): + """Asserts that the loaded models are valid.""" error_message = None if not isinstance(model_base.unet_with_lora.model, SDXL): @@ -60,6 +65,11 @@ def assert_model_integrity(): @torch.no_grad() @torch.inference_mode() def refresh_base_model(name, vae_name=None): + """Refreshes the base model. + Args: + name (str): The name of the base model. + vae_name (str, optional): The name of the VAE model. Defaults to None. + """ global model_base filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) @@ -80,6 +90,10 @@ def refresh_base_model(name, vae_name=None): @torch.no_grad() @torch.inference_mode() def refresh_refiner_model(name): + """Refreshes the refiner model. + Args: + name (str): The name of the refiner model. + """ global model_refiner filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) @@ -111,6 +125,7 @@ def refresh_refiner_model(name): @torch.no_grad() @torch.inference_mode() def synthesize_refiner_model(): + """Synthesizes a refiner model from the base model.""" global model_base, model_refiner print('Synthetic Refiner Activated') @@ -131,6 +146,11 @@ def synthesize_refiner_model(): @torch.no_grad() @torch.inference_mode() def refresh_loras(loras, base_model_additional_loras=None): + """Refreshes the LoRAs. + Args: + loras (list): A list of LoRAs to apply. + base_model_additional_loras (list, optional): A list of additional LoRAs to apply to the base model. Defaults to None. + """ global model_base, model_refiner if not isinstance(base_model_additional_loras, list): @@ -145,6 +165,14 @@ def refresh_loras(loras, base_model_additional_loras=None): @torch.no_grad() @torch.inference_mode() def clip_encode_single(clip, text, verbose=False): + """Encodes a single text prompt using CLIP. + Args: + clip: The CLIP model. + text (str): The text to encode. + verbose (bool, optional): Whether to print verbose output. Defaults to False. + Returns: + The encoded text. + """ cached = clip.fcs_cond_cache.get(text, None) if cached is not None: if verbose: @@ -161,6 +189,12 @@ def clip_encode_single(clip, text, verbose=False): @torch.no_grad() @torch.inference_mode() def clone_cond(conds): + """Clones a conditioning. + Args: + conds: The conditioning to clone. + Returns: + The cloned conditioning. + """ results = [] for c, p in conds: @@ -180,6 +214,13 @@ def clone_cond(conds): @torch.no_grad() @torch.inference_mode() def clip_encode(texts, pool_top_k=1): + """Encodes a list of text prompts using CLIP. + Args: + texts (list): A list of texts to encode. + pool_top_k (int, optional): The number of top-k pooled outputs to use. Defaults to 1. + Returns: + The encoded texts. + """ global final_clip if final_clip is None: @@ -204,6 +245,10 @@ def clip_encode(texts, pool_top_k=1): @torch.no_grad() @torch.inference_mode() def set_clip_skip(clip_skip: int): + """Sets the CLIP skip value. + Args: + clip_skip (int): The CLIP skip value. + """ global final_clip if final_clip is None: @@ -215,12 +260,17 @@ def set_clip_skip(clip_skip: int): @torch.no_grad() @torch.inference_mode() def clear_all_caches(): + """Clears all caches.""" final_clip.fcs_cond_cache = {} @torch.no_grad() @torch.inference_mode() def prepare_text_encoder(async_call=True): + """Prepares the text encoder for use. + Args: + async_call (bool, optional): Whether to prepare the text encoder asynchronously. Defaults to True. + """ if async_call: # TODO: make sure that this is always called in an async way so that users cannot feel it. pass @@ -233,6 +283,15 @@ def prepare_text_encoder(async_call=True): @torch.inference_mode() def refresh_everything(refiner_model_name, base_model_name, loras, base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None): + """Refreshes all models and components. + Args: + refiner_model_name (str): The name of the refiner model. + base_model_name (str): The name of the base model. + loras (list): A list of LoRAs to apply. + base_model_additional_loras (list, optional): A list of additional LoRAs to apply to the base model. Defaults to None. + use_synthetic_refiner (bool, optional): Whether to use a synthetic refiner. Defaults to False. + vae_name (str, optional): The name of the VAE model. Defaults to None. + """ global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion final_unet = None @@ -278,6 +337,12 @@ def refresh_everything(refiner_model_name, base_model_name, loras, @torch.no_grad() @torch.inference_mode() def vae_parse(latent): + """Parses a latent tensor using the VAE. + Args: + latent: The latent tensor to parse. + Returns: + The parsed latent tensor. + """ if final_refiner_vae is None: return latent @@ -288,6 +353,15 @@ def vae_parse(latent): @torch.no_grad() @torch.inference_mode() def calculate_sigmas_all(sampler, model, scheduler, steps): + """Calculates all sigmas for a given sampler, model, scheduler, and number of steps. + Args: + sampler: The sampler. + model: The model. + scheduler: The scheduler. + steps (int): The number of steps. + Returns: + The calculated sigmas. + """ from ldm_patched.modules.samplers import calculate_sigmas_scheduler discard_penultimate_sigma = False @@ -305,6 +379,16 @@ def calculate_sigmas_all(sampler, model, scheduler, steps): @torch.no_grad() @torch.inference_mode() def calculate_sigmas(sampler, model, scheduler, steps, denoise): + """Calculates the sigmas for a given sampler, model, scheduler, number of steps, and denoise value. + Args: + sampler: The sampler. + model: The model. + scheduler: The scheduler. + steps (int): The number of steps. + denoise (float): The denoise value. + Returns: + The calculated sigmas. + """ if denoise is None or denoise > 0.9999: sigmas = calculate_sigmas_all(sampler, model, scheduler, steps) else: @@ -317,6 +401,15 @@ def calculate_sigmas(sampler, model, scheduler, steps, denoise): @torch.no_grad() @torch.inference_mode() def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'): + """Gets a candidate VAE for a given number of steps, switch value, denoise value, and refiner swap method. + Args: + steps (int): The number of steps. + switch (int): The switch value. + denoise (float, optional): The denoise value. Defaults to 1.0. + refiner_swap_method (str, optional): The refiner swap method. Defaults to 'joint'. + Returns: + A tuple that contains the candidate VAE and the candidate refiner VAE. + """ assert refiner_swap_method in ['joint', 'separate', 'vae'] if final_refiner_vae is not None and final_refiner_unet is not None: @@ -334,6 +427,27 @@ def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'): @torch.no_grad() @torch.inference_mode() def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint', disable_preview=False): + """Processes the diffusion. + Args: + positive_cond: The positive conditioning. + negative_cond: The negative conditioning. + steps (int): The number of steps. + switch (int): The switch value. + width (int): The width of the image. + height (int): The height of the image. + image_seed (int): The image seed. + callback (function): The callback function. + sampler_name (str): The name of the sampler. + scheduler_name (str): The name of the scheduler. + latent (optional): The latent tensor. Defaults to None. + denoise (float, optional): The denoise value. Defaults to 1.0. + tiled (bool, optional): Whether to use tiled processing. Defaults to False. + cfg_scale (float, optional): The CFG scale. Defaults to 7.0. + refiner_swap_method (str, optional): The refiner swap method. Defaults to 'joint'. + disable_preview (bool, optional): Whether to disable the preview. Defaults to False. + Returns: + The processed images. + """ target_unet, target_vae, target_refiner_unet, target_refiner_vae, target_clip \ = final_unet, final_vae, final_refiner_unet, final_refiner_vae, final_clip diff --git a/modules/extra_utils.py b/modules/extra_utils.py index c2dfa81044..25d1e569c8 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -3,6 +3,10 @@ def makedirs_with_log(path): + """Creates a directory with a log message. + Args: + path (str): The path to the directory to create. + """ try: os.makedirs(path, exist_ok=True) except OSError as error: @@ -10,6 +14,14 @@ def makedirs_with_log(path): def get_files_from_folder(folder_path, extensions=None, name_filter=None): + """Gets a list of files from a folder. + Args: + folder_path (str): The path to the folder. + extensions (list, optional): A list of file extensions to include. Defaults to None. + name_filter (str, optional): A string to filter filenames by. Defaults to None. + Returns: + list: A list of filenames. + """ if not os.path.isdir(folder_path): raise ValueError("Folder path is not a valid directory.") @@ -29,6 +41,13 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): def try_eval_env_var(value: str, expected_type=None): + """Tries to evaluate an environment variable. + Args: + value (str): The value of the environment variable. + expected_type (type, optional): The expected type of the environment variable. Defaults to None. + Returns: + The evaluated environment variable, or the original value if evaluation fails. + """ try: value_eval = value if expected_type is bool: diff --git a/modules/flags.py b/modules/flags.py index 05c29a2328..673d99e521 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -108,6 +108,7 @@ class MetadataScheme(Enum): + """An enumeration of the available metadata schemes.""" FOOOCUS = 'fooocus' A1111 = 'a1111' @@ -119,16 +120,22 @@ class MetadataScheme(Enum): class OutputFormat(Enum): + """An enumeration of the available output formats.""" PNG = 'png' JPEG = 'jpeg' WEBP = 'webp' @classmethod def list(cls) -> list: + """Returns a list of the available output formats. + Returns: + list: A list of the available output formats. + """ return list(map(lambda c: c.value, cls)) class PerformanceLoRA(Enum): + """An enumeration of the available performance LoRAs.""" QUALITY = None SPEED = None EXTREME_SPEED = 'sdxl_lcm_lora.safetensors' @@ -137,6 +144,7 @@ class PerformanceLoRA(Enum): class Steps(IntEnum): + """An enumeration of the available steps for each performance level.""" QUALITY = 60 SPEED = 30 EXTREME_SPEED = 8 @@ -145,10 +153,15 @@ class Steps(IntEnum): @classmethod def keys(cls) -> list: + """Returns a list of the available steps keys. + Returns: + list: A list of the available steps keys. + """ return list(map(lambda c: c, Steps.__members__)) class StepsUOV(IntEnum): + """An enumeration of the available steps for upscale or variation for each performance level.""" QUALITY = 36 SPEED = 18 EXTREME_SPEED = 8 @@ -157,6 +170,7 @@ class StepsUOV(IntEnum): class Performance(Enum): + """An enumeration of the available performance levels.""" QUALITY = 'Quality' SPEED = 'Speed' EXTREME_SPEED = 'Extreme Speed' @@ -165,27 +179,59 @@ class Performance(Enum): @classmethod def list(cls) -> list: + """Returns a list of the available performance levels. + Returns: + list: A list of the available performance levels. + """ return list(map(lambda c: (c.name, c.value), cls)) @classmethod def values(cls) -> list: + """Returns a list of the available performance level values. + Returns: + list: A list of the available performance level values. + """ return list(map(lambda c: c.value, cls)) @classmethod def by_steps(cls, steps: int | str): + """Returns a performance level by the number of steps. + Args: + steps (int or str): The number of steps. + Returns: + Performance: The performance level. + """ return cls[Steps(int(steps)).name] @classmethod def has_restricted_features(cls, x) -> bool: + """Checks if a performance level has restricted features. + Args: + x: The performance level. + Returns: + bool: True if the performance level has restricted features, False otherwise. + """ if isinstance(x, Performance): x = x.value return x in [cls.EXTREME_SPEED.value, cls.LIGHTNING.value, cls.HYPER_SD.value] def steps(self) -> int | None: + """Returns the number of steps for the performance level. + Returns: + int or None: The number of steps for the performance level, or None if it is not defined. + """ return Steps[self.name].value if self.name in Steps.__members__ else None def steps_uov(self) -> int | None: + """Returns the number of steps for upscale or variation for the performance level. + Returns: + int or None: The number of steps for upscale or variation for the performance level, or None if it is not defined. + """ return StepsUOV[self.name].value if self.name in StepsUOV.__members__ else None def lora_filename(self) -> str | None: + """Returns the LoRA filename for the performance level. + Returns: + str or None: The LoRA filename for the performance level, or None if it is not defined. + """ return PerformanceLoRA[self.name].value if self.name in PerformanceLoRA.__members__ else None diff --git a/modules/gradio_hijack.py b/modules/gradio_hijack.py index 35df81c00c..0acbefaabd 100644 --- a/modules/gradio_hijack.py +++ b/modules/gradio_hijack.py @@ -461,6 +461,13 @@ def as_example(self, input_data: str | None) -> str: def blk_ini(self, *args, **kwargs): + """Initializes a block. + Args: + *args: The arguments to pass to the original __init__ method. + **kwargs: The keyword arguments to pass to the original __init__ method. + Returns: + The result of the original __init__ method. + """ all_components.append(self) return Block.original_init(self, *args, **kwargs) @@ -475,6 +482,13 @@ def blk_ini(self, *args, **kwargs): def patched_wait_for(fut, timeout): + """A patched version of asyncio.wait_for that uses a longer timeout. + Args: + fut: The future to wait for. + timeout: The timeout value. + Returns: + The result of the original wait_for function. + """ del timeout return gradio.routes.asyncio.original_wait_for(fut, timeout=65535) diff --git a/modules/hash_cache.py b/modules/hash_cache.py index e857b0030e..c90a9cd735 100644 --- a/modules/hash_cache.py +++ b/modules/hash_cache.py @@ -11,6 +11,12 @@ def sha256_from_cache(filepath): + """Calculates the SHA256 hash of a file, using a cache to avoid recalculating it if it has been calculated before. + Args: + filepath (str): The path to the file. + Returns: + str: The SHA256 hash of the file. + """ global hash_cache if filepath not in hash_cache: print(f"[Cache] Calculating sha256 for {filepath}") @@ -23,6 +29,7 @@ def sha256_from_cache(filepath): def load_cache_from_file(): + """Loads the hash cache from a file.""" global hash_cache try: @@ -40,6 +47,11 @@ def load_cache_from_file(): def save_cache_to_file(filename=None, hash_value=None): + """Saves the hash cache to a file. + Args: + filename (str, optional): The name of the file to save. Defaults to None. + hash_value (str, optional): The hash value to save. Defaults to None. + """ global hash_cache if filename is not None and hash_value is not None: @@ -59,6 +71,13 @@ def save_cache_to_file(filename=None, hash_value=None): def init_cache(model_filenames, paths_checkpoints, lora_filenames, paths_loras): + """Initializes the hash cache. + Args: + model_filenames (list): A list of model filenames. + paths_checkpoints (list): A list of paths to the model checkpoints. + lora_filenames (list): A list of LoRA filenames. + paths_loras (list): A list of paths to the LoRAs. + """ load_cache_from_file() if args_manager.args.rebuild_hash_cache: @@ -70,6 +89,14 @@ def init_cache(model_filenames, paths_checkpoints, lora_filenames, paths_loras): def rebuild_cache(lora_filenames, model_filenames, paths_checkpoints, paths_loras, max_workers=cpu_count()): + """Rebuilds the hash cache. + Args: + lora_filenames (list): A list of LoRA filenames. + model_filenames (list): A list of model filenames. + paths_checkpoints (list): A list of paths to the model checkpoints. + paths_loras (list): A list of paths to the LoRAs. + max_workers (int, optional): The maximum number of workers to use. Defaults to the number of CPUs. + """ def thread(filename, paths): filepath = get_file_from_folder_list(filename, paths) sha256_from_cache(filepath) diff --git a/modules/html.py b/modules/html.py index 25771cb9f5..28484136f0 100644 --- a/modules/html.py +++ b/modules/html.py @@ -10,4 +10,11 @@ def make_progress_html(number, text): + """Creates an HTML progress bar. + Args: + number (int): The progress value. + text (str): The progress text. + Returns: + str: The HTML code for the progress bar. + """ return progress_html.replace('*number*', str(number)).replace('*text*', text) diff --git a/modules/inpaint_worker.py b/modules/inpaint_worker.py index 88a78a6d61..2d44e033f5 100644 --- a/modules/inpaint_worker.py +++ b/modules/inpaint_worker.py @@ -11,11 +11,19 @@ class InpaintHead(torch.nn.Module): + """A PyTorch module that represents the inpainting head.""" def __init__(self, *args, **kwargs): + """Initializes the InpaintHead.""" super().__init__(*args, **kwargs) self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu')) def __call__(self, x): + """Performs the forward pass of the inpainting head. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor. + """ x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate") return torch.nn.functional.conv2d(input=x, weight=self.head) @@ -24,18 +32,38 @@ def __call__(self, x): def box_blur(x, k): + """Applies a box blur to an image. + Args: + x (np.ndarray): The input image. + k (int): The kernel size. + Returns: + np.ndarray: The blurred image. + """ x = Image.fromarray(x) x = x.filter(ImageFilter.BoxBlur(k)) return np.array(x) def max_filter_opencv(x, ksize=3): + """Applies a maximum filter to an image using OpenCV. + Args: + x (np.ndarray): The input image. + ksize (int, optional): The kernel size. Defaults to 3. + Returns: + np.ndarray: The filtered image. + """ # Use OpenCV maximum filter # Make sure the input type is int16 return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16)) def morphological_open(x): + """Performs a morphological opening on an image. + Args: + x (np.ndarray): The input image. + Returns: + np.ndarray: The opened image. + """ # Convert array to int16 type via threshold operation x_int16 = np.zeros_like(x, dtype=np.int16) x_int16[x > 127] = 256 @@ -51,17 +79,39 @@ def morphological_open(x): def up255(x, t=0): + """Converts an image to a binary image with values of 0 or 255. + Args: + x (np.ndarray): The input image. + t (int, optional): The threshold value. Defaults to 0. + Returns: + np.ndarray: The binary image. + """ y = np.zeros_like(x).astype(np.uint8) y[x > t] = 255 return y def imsave(x, path): + """Saves an image to a file. + Args: + x (np.ndarray): The image to save. + path (str): The path to the file. + """ x = Image.fromarray(x) x.save(path) def regulate_abcd(x, a, b, c, d): + """Regulates the coordinates of a bounding box to be within the image boundaries. + Args: + x (np.ndarray): The input image. + a (int): The top coordinate. + b (int): The bottom coordinate. + c (int): The left coordinate. + d (int): The right coordinate. + Returns: + tuple[int, int, int, int]: The regulated coordinates. + """ H, W = x.shape[:2] if a < 0: a = 0 @@ -83,6 +133,12 @@ def regulate_abcd(x, a, b, c, d): def compute_initial_abcd(x): + """Computes the initial bounding box of a mask. + Args: + x (np.ndarray): The input mask. + Returns: + tuple[int, int, int, int]: The initial bounding box. + """ indices = np.where(x) a = np.min(indices[0]) b = np.max(indices[0]) @@ -102,6 +158,17 @@ def compute_initial_abcd(x): def solve_abcd(x, a, b, c, d, k): + """Solves for the bounding box of a mask. + Args: + x (np.ndarray): The input mask. + a (int): The top coordinate. + b (int): The bottom coordinate. + c (int): The left coordinate. + d (int): The right coordinate. + k (float): The scaling factor. + Returns: + tuple[int, int, int, int]: The solved bounding box. + """ k = float(k) assert 0.0 <= k <= 1.0 @@ -134,6 +201,13 @@ def solve_abcd(x, a, b, c, d, k): def fooocus_fill(image, mask): + """Fills an image using the Fooocus inpainting algorithm. + Args: + image (np.ndarray): The input image. + mask (np.ndarray): The inpainting mask. + Returns: + np.ndarray: The filled image. + """ current_image = image.copy() raw_image = image.copy() area = np.where(mask < 127) @@ -148,7 +222,15 @@ def fooocus_fill(image, mask): class InpaintWorker: + """A class that performs inpainting on an image.""" def __init__(self, image, mask, use_fill=True, k=0.618): + """Initializes the InpaintWorker. + Args: + image (np.ndarray): The input image. + mask (np.ndarray): The inpainting mask. + use_fill (bool, optional): Whether to use filling. Defaults to True. + k (float, optional): The scaling factor for the bounding box. Defaults to 0.618. + """ a, b, c, d = compute_initial_abcd(mask > 0) a, b, c, d = solve_abcd(mask, a, b, c, d, k=k) @@ -186,12 +268,27 @@ def __init__(self, image, mask, use_fill=True, k=0.618): return def load_latent(self, latent_fill, latent_mask, latent_swap=None): + """Loads the latent variables. + Args: + latent_fill: The latent variable for the fill. + latent_mask: The latent variable for the mask. + latent_swap (optional): The latent variable for the swap. Defaults to None. + """ self.latent = latent_fill self.latent_mask = latent_mask self.latent_after_swap = latent_swap return def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model): + """Patches the model with the inpainting head. + Args: + inpaint_head_model_path (str): The path to the inpainting head model. + inpaint_latent: The inpainting latent variable. + inpaint_latent_mask: The inpainting latent mask. + model: The model to patch. + Returns: + The patched model. + """ global inpaint_head_model if inpaint_head_model is None: @@ -217,6 +314,7 @@ def input_block_patch(h, transformer_options): return m def swap(self): + """Swaps the latent variables.""" if self.swapped: return @@ -231,6 +329,7 @@ def swap(self): return def unswap(self): + """Unswaps the latent variables.""" if not self.swapped: return @@ -245,6 +344,12 @@ def unswap(self): return def color_correction(self, img): + """Performs color correction on an image. + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: The color-corrected image. + """ fg = img.astype(np.float32) bg = self.image.copy().astype(np.float32) w = self.mask[:, :, None].astype(np.float32) / 255.0 @@ -252,6 +357,12 @@ def color_correction(self, img): return y.clip(0, 255).astype(np.uint8) def post_process(self, img): + """Post-processes an image. + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: The post-processed image. + """ a, b, c, d = self.interested_area content = resample_image(img, d - c, b - a) result = self.image.copy() @@ -260,5 +371,9 @@ def post_process(self, img): return result def visualize_mask_processing(self): + """Visualizes the mask processing. + Returns: + list[np.ndarray]: A list of images that represent the mask processing steps. + """ return [self.interested_fill, self.interested_mask, self.interested_image] diff --git a/modules/model_loader.py b/modules/model_loader.py index 1143f75e25..5fa58ea9aa 100644 --- a/modules/model_loader.py +++ b/modules/model_loader.py @@ -10,9 +10,14 @@ def load_file_from_url( progress: bool = True, file_name: Optional[str] = None, ) -> str: - """Download a file from `url` into `model_dir`, using the file present if possible. - - Returns the path to the downloaded file. + """Downloads a file from a URL and saves it to a directory. + Args: + url (str): The URL of the file to download. + model_dir (str): The directory to save the file to. + progress (bool, optional): Whether to display a progress bar. Defaults to True. + file_name (str, optional): The name of the file to save. Defaults to None. + Returns: + str: The path to the downloaded file. """ domain = os.environ.get("HF_MIRROR", "https://huggingface.co").rstrip('/') url = str.replace(url, "https://huggingface.co", domain, 1) diff --git a/readme.md b/readme.md index 6f34a95064..4624318e57 100644 --- a/readme.md +++ b/readme.md @@ -452,6 +452,17 @@ Also, thanks [daswer123](https://github.com/daswer123) for contributing the Canv The log is [here](update_log.md). +## Codebase Overview + +The Fooocus codebase is organized into several directories, each with a specific purpose. Here's a brief overview of the most important ones: + +* **`modules/`**: This directory contains the core logic of the application. It includes modules for image generation (`core.py`), configuration management (`config.py`), model loading (`model_loader.py`), and more. +* **`ldm_patched/`**: This directory contains a patched version of the Latent Diffusion Model (LDM) library. The patches are used to improve performance and add new features. +* **`webui.py`**: This is the main script for the web user interface. It uses the Gradio library to create the UI components and handle user interactions. +* **`launch.py`**: This script is responsible for launching the application. It prepares the environment, downloads the necessary models, and starts the web UI. + +The application works by taking user input from the web UI, processing it in the `modules/` directory, and then using the `ldm_patched/` library to generate images. The `webui.py` script then displays the generated images to the user. + ## Localization/Translation/I18N You can put json files in the `language` folder to translate the user interface. diff --git a/webui.py b/webui.py index b8159d8558..8fec06a451 100644 --- a/webui.py +++ b/webui.py @@ -25,12 +25,22 @@ from modules.util import is_json def get_task(*args): + """Gets a task from the arguments. + Args: + *args: The arguments. + Returns: + worker.AsyncTask: The task. + """ args = list(args) args.pop(0) return worker.AsyncTask(args=args) def generate_clicked(task: worker.AsyncTask): + """Handles the generate button click. + Args: + task (worker.AsyncTask): The task to run. + """ import ldm_patched.modules.model_management as model_management with model_management.interrupt_processing_mutex: @@ -94,6 +104,13 @@ def generate_clicked(task: worker.AsyncTask): def sort_enhance_images(images, task): + """Sorts the enhanced images. + Args: + images (list): The images to sort. + task (worker.AsyncTask): The task. + Returns: + list: The sorted images. + """ if not task.should_enhance or len(images) <= task.images_to_enhance_count: return images @@ -113,6 +130,13 @@ def sort_enhance_images(images, task): def inpaint_mode_change(mode, inpaint_engine_version): + """Handles the inpaint mode change. + Args: + mode (str): The inpaint mode. + inpaint_engine_version (str): The inpaint engine version. + Returns: + list: A list of gradio updates. + """ assert mode in modules.flags.inpaint_options # inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts, @@ -1111,6 +1135,7 @@ def trigger_auto_describe(mode, img, prompt, apply_styles): .then(lambda: None, _js='()=>{refresh_style_localization();}') def dump_default_english_config(): + """Dumps the default English config.""" from modules.localization import dump_english_config dump_english_config(grh.all_components)