Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion all_clip/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""load clip"""

from functools import lru_cache
from functools import lru_cache, partial
import torch
from PIL import Image
import time
Expand All @@ -14,6 +14,8 @@

_CLIP_REGISTRY = {
"open_clip:": load_open_clip,
"open_clip_hf:": partial(load_open_clip, schema="hf-hub"),
"open_clip_local:": partial(load_open_clip, schema="local-dir"),
"hf_clip:": load_hf_clip,
"nm:": load_deepsparse,
"ja_clip:": load_japanese_clip,
Expand Down
49 changes: 35 additions & 14 deletions all_clip/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,46 @@ def forward(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)


def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None):
def load_open_clip(clip_model, use_jit=True, device="cuda", clip_cache_path=None, schema=None):
"""load open clip"""

import open_clip # pylint: disable=import-outside-toplevel

torch.backends.cuda.matmul.allow_tf32 = True
clip_model_parts = clip_model.split("/")
clip_model = clip_model_parts[0]
checkpoint = "/".join(clip_model_parts[1:])
if checkpoint == "":
pretrained = dict(open_clip.list_pretrained())
checkpoint = pretrained[clip_model]
model, _, preprocess = open_clip.create_model_and_transforms(
clip_model,
pretrained=checkpoint,
device=device,
jit=use_jit,
cache_dir=clip_cache_path,
)

if schema is not None:
if schema == "hf-hub":
from open_clip.factory import HF_HUB_PREFIX
clip_model = HF_HUB_PREFIX + clip_model
elif schema == "local-dir":
try:
from open_clip.factory import LOCAL_DIR_PREFIX
except ImportError:
raise ValueError("Upgrade open clip version to >=3 to load from local directories")
clip_model = LOCAL_DIR_PREFIX + clip_model
else:
            raise ValueError("Unknown open clip schema: {}".format(schema))
model, _, preprocess = open_clip.create_model_and_transforms(
clip_model,
device=device,
jit=use_jit,
cache_dir=clip_cache_path,
)
else:
clip_model_parts = clip_model.split("/")
clip_model = clip_model_parts[0]
checkpoint = "/".join(clip_model_parts[1:])
if checkpoint == "":
pretrained = dict(open_clip.list_pretrained())
checkpoint = pretrained[clip_model]
model, _, preprocess = open_clip.create_model_and_transforms(
clip_model,
pretrained=checkpoint,
device=device,
jit=use_jit,
cache_dir=clip_cache_path,
)

model = OpenClipWrapper(inner_model=model, device=device)
model.to(device=device)

Expand Down