Open
Labels: enhancement (New feature or request)
Description
The code used to detect a checkpoint's model type is minimal, and the package would be better served by including huggingface_guess as a submodule.
One minor change at model_list.py#L100 would be needed to account for the new SD 1.5 base repository:
```python
huggingface_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5"
```
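Until that upstream change lands, the caller could also accept the older SD 1.5 repo id. A minimal sketch, assuming the previous value in model_list.py was `runwayml/stable-diffusion-v1-5`; `LEGACY_REPO_ALIASES` and `normalize_repo_name` are hypothetical names, not part of huggingface_guess:

```python
# Hypothetical alias table: older repo ids mapped to their current equivalents.
# Assumes model_list.py previously used "runwayml/stable-diffusion-v1-5".
LEGACY_REPO_ALIASES = {
    "runwayml/stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
}


def normalize_repo_name(repo_name: str) -> str:
    """Return the current repo id for repo_name, or repo_name unchanged."""
    return LEGACY_REPO_ALIASES.get(repo_name, repo_name)


print(normalize_repo_name("runwayml/stable-diffusion-v1-5"))
# stable-diffusion-v1-5/stable-diffusion-v1-5
```

Ids that are not in the alias table pass through untouched, so the lookup is safe to apply to every guessed repo name.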
This would allow the checkpoint_model_type function in utils.py to be something like:
```python
# load the required modules
import logging

from picklescan.scanner import scan_file_path as legacy_scan
from safetensors.torch import load_file as safe_load
from torch import load as legacy_load

from stablepy import huggingface_guess  # proposed submodule

logger = logging.getLogger(__name__)  # or the package's existing logger

# match the repo_name to the preexisting definitions
REPO_TO_MODEL_TYPE = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd1.5",
    "stabilityai/stable-diffusion-2-1": "sd2.1",
    "stabilityai/stable-diffusion-xl-base-1.0": "sdxl",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "refiner",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
    "black-forest-labs/FLUX.1-schnell": "flux-schnell",
}


def checkpoint_model_type(checkpoint_path):
    state_dict = None
    model_type = None
    try:
        # read the checkpoint
        if checkpoint_path.lower().endswith(".safetensors"):
            state_dict = safe_load(checkpoint_path, device="cpu")
        else:
            # scan pickle-based checkpoints for malicious code before loading them;
            # scan_file_path returns a ScanResult object, not an integer exit code
            scan_result = legacy_scan(checkpoint_path)
            if scan_result.scan_err:
                logger.info("Error reading checkpoint: unable to complete scan for malicious code")
                return None
            if scan_result.issues_count > 0 or scan_result.infected_files > 0:
                logger.info("Error reading checkpoint: potentially malicious code detected")
                return None
            # torch.load takes map_location, not device
            state_dict = legacy_load(checkpoint_path, map_location="cpu")
            # .ckpt files usually nest the weights under a "state_dict" key
            state_dict = state_dict.get("state_dict", state_dict)

        repo_name = huggingface_guess.guess_repo_name(state_dict)
        model_type = REPO_TO_MODEL_TYPE.get(repo_name)
        if model_type is None:
            logger.info("Error reading checkpoint: unsupported model type %s", repo_name)
    except Exception as e:
        # e.g. an unreadable file or a model huggingface_guess cannot identify
        logger.debug(str(e))
        logger.info("Error reading checkpoint: %s", checkpoint_path)
    finally:
        # unload the checkpoint
        del state_dict
    return model_type
```
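picklescan's `scan_file_path` returns a `ScanResult` object rather than the CLI's exit status (0: clean, 1: malware found, 2: scan failed), so comparing it to `0` or `2` never matches. The sketch below isolates the three-way decision so it can run without a checkpoint on disk. It assumes `ScanResult` exposes `scan_err`, `issues_count` and `infected_files`; `FakeScanResult` and `classify_scan` are hypothetical names:

```python
from dataclasses import dataclass


@dataclass
class FakeScanResult:
    """Stand-in using the field names assumed for picklescan's ScanResult."""
    issues_count: int = 0
    infected_files: int = 0
    scan_err: bool = False


def classify_scan(scan_result) -> str:
    """Map a picklescan-style result to one of the outcomes used in utils.py."""
    if scan_result.scan_err:
        return "security_error"    # the scan could not complete
    if scan_result.issues_count > 0 or scan_result.infected_files > 0:
        return "security_blocked"  # dangerous globals were found
    return "ok"                    # safe to hand to torch.load


print(classify_scan(FakeScanResult()))                # ok
print(classify_scan(FakeScanResult(scan_err=True)))   # security_error
print(classify_scan(FakeScanResult(issues_count=1)))  # security_blocked
```

Checking `scan_err` first means a partially completed scan is never treated as clean, even if it reported zero issues before failing.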
This assumes that #14 is implemented.