diff --git a/all_clip/main.py b/all_clip/main.py
index f4b8262..51630df 100644
--- a/all_clip/main.py
+++ b/all_clip/main.py
@@ -1,6 +1,6 @@
 """load clip"""
 
-from functools import lru_cache
+from functools import lru_cache, partial
 import torch
 from PIL import Image
 import time
@@ -14,6 +14,8 @@
 
 _CLIP_REGISTRY = {
     "open_clip:": load_open_clip,
+    "open_clip_hf:": partial(load_open_clip, schema="hf-hub"),
+    "open_clip_local:": partial(load_open_clip, schema="local-dir"),
     "hf_clip:": load_hf_clip,
     "nm:": load_deepsparse,
     "ja_clip:": load_japanese_clip,
diff --git a/all_clip/open_clip.py b/all_clip/open_clip.py
index 2530a11..bf69bad 100644
--- a/all_clip/open_clip.py
+++ b/all_clip/open_clip.py
@@ -34,25 +34,46 @@ def forward(self, *args, **kwargs):
         return self.inner_model(*args, **kwargs)
 
 
-def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None):
+def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None, schema=None):
     """load open clip"""
     import open_clip  # pylint: disable=import-outside-toplevel
 
     torch.backends.cuda.matmul.allow_tf32 = True
-    clip_model_parts = clip_model.split("/")
-    clip_model = clip_model_parts[0]
-    checkpoint = "/".join(clip_model_parts[1:])
-    if checkpoint == "":
-        pretrained = dict(open_clip.list_pretrained())
-        checkpoint = pretrained[clip_model]
-    model, _, preprocess = open_clip.create_model_and_transforms(
-        clip_model,
-        pretrained=checkpoint,
-        device=device,
-        jit=use_jit,
-        cache_dir=clip_cache_path,
-    )
+
+    if schema is not None:
+        if schema == "hf-hub":
+            from open_clip.factory import HF_HUB_PREFIX
+            clip_model = HF_HUB_PREFIX + clip_model
+        elif schema == "local-dir":
+            try:
+                from open_clip.factory import LOCAL_DIR_PREFIX
+            except ImportError as exc:
+                raise ValueError("Upgrade open clip version to >=3 to load from local directories") from exc
+            clip_model = LOCAL_DIR_PREFIX + clip_model
+        else:
+            raise ValueError("Unknown open clip schema: {}".format(schema))
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            clip_model,
+            device=device,
+            jit=use_jit,
+            cache_dir=clip_cache_path,
+        )
+    else:
+        clip_model_parts = clip_model.split("/")
+        clip_model = clip_model_parts[0]
+        checkpoint = "/".join(clip_model_parts[1:])
+        if checkpoint == "":
+            pretrained = dict(open_clip.list_pretrained())
+            checkpoint = pretrained[clip_model]
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            clip_model,
+            pretrained=checkpoint,
+            device=device,
+            jit=use_jit,
+            cache_dir=clip_cache_path,
+        )
+
     model = OpenClipWrapper(inner_model=model, device=device)
     model.to(device=device)