
[Refactor] Replace custom detection code with huggingface_guess #15

@iwr-redmond


The current code for detecting a checkpoint's model type is minimal, and the package would be better served by including huggingface_guess as a submodule.

One minor change to model_list.py#L100 would account for the new SD 1.5 base repository:

huggingface_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5"

This would allow the checkpoint_model_type function in utils.py to be something like:

# load the required modules
import logging

from picklescan.scanner import scan_file_path as legacy_scan
from safetensors.torch import load_file as safe_load
from torch import load as legacy_load

from stablepy import huggingface_guess  # assumes huggingface_guess is vendored as a submodule

logger = logging.getLogger(__name__)

# map the repository names reported by huggingface_guess to the preexisting model types
REPO_TO_MODEL_TYPE = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd1.5",
    "stabilityai/stable-diffusion-2-1": "sd2.1",
    "stabilityai/stable-diffusion-xl-base-1.0": "sdxl",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "refiner",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
    "black-forest-labs/FLUX.1-schnell": "flux-schnell",
}


def checkpoint_model_type(checkpoint_path):
    state_dict = None
    model_type = None

    # read the checkpoint
    if checkpoint_path.lower().endswith(".safetensors"):
        # safetensors files cannot carry executable code, so no pickle scan is needed
        state_dict = safe_load(checkpoint_path, device="cpu")
        repo_name = huggingface_guess.guess_repo_name(state_dict)
    else:
        # legacy pickle checkpoints are scanned before loading;
        # scan_file_path returns a ScanResult object, not an exit code
        scan_result = legacy_scan(checkpoint_path)
        if scan_result.scan_err:
            repo_name = "security_error"
        elif scan_result.infected_files > 0:
            repo_name = "security_blocked"
        else:
            # torch.load takes map_location, not device
            state_dict = legacy_load(checkpoint_path, map_location="cpu")
            # original SD .ckpt files nest the weights under a "state_dict" key
            state_dict = state_dict.get("state_dict", state_dict)
            repo_name = huggingface_guess.guess_repo_name(state_dict)

    # match the repo_name to the preexisting definitions (compare strings with ==/in, not "is")
    if repo_name in REPO_TO_MODEL_TYPE:
        model_type = REPO_TO_MODEL_TYPE[repo_name]
    elif repo_name == "security_error":
        logger.info("Error reading checkpoint: unable to complete scan for malicious code")
    elif repo_name == "security_blocked":
        logger.info("Error reading checkpoint: potentially malicious code detected")
    else:
        logger.info("Error reading checkpoint: unsupported model type %s", repo_name)

    # unload the checkpoint
    del state_dict

    return model_type
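The matching and scan-gating logic can be checked without loading any weights. The sketch below factors it into a pure function; `resolve_model_type` and its keyword arguments are hypothetical names, and `scan_err` / `infected_files` are assumed to mirror the fields of the picklescan result used above.

```python
from typing import Optional

# Hypothetical helper: mirrors the decision logic of checkpoint_model_type
# without touching the filesystem, so it can be unit-tested.
REPO_TO_MODEL_TYPE = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd1.5",
    "stabilityai/stable-diffusion-2-1": "sd2.1",
    "stabilityai/stable-diffusion-xl-base-1.0": "sdxl",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "refiner",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
    "black-forest-labs/FLUX.1-schnell": "flux-schnell",
}


def resolve_model_type(
    repo_name: Optional[str],
    scan_err: bool = False,
    infected_files: int = 0,
) -> Optional[str]:
    """Return the model type, or None if the checkpoint is rejected.

    scan_err / infected_files are assumed to mirror the picklescan result
    fields; for .safetensors files the defaults (no scan) apply.
    """
    # a failed or positive scan takes priority over whatever the guesser reported
    if scan_err or infected_files > 0:
        return None
    # unknown repository names (including a None guess) are unsupported
    return REPO_TO_MODEL_TYPE.get(repo_name)


print(resolve_model_type("stabilityai/stable-diffusion-xl-base-1.0"))  # sdxl
print(resolve_model_type("black-forest-labs/FLUX.1-dev", infected_files=1))  # None
print(resolve_model_type("runwayml/stable-diffusion-v1-5"))  # None
```

Any repository name missing from the mapping, such as the legacy runwayml/stable-diffusion-v1-5 ID, falls through to None.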

This assumes that #14 is implemented.

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests