Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
model: ['openai_clip', 'open_clip', 'hf_clip']
model: ['openai_clip', 'open_clip', 'hf_clip', 'cn_clip']

steps:
- uses: actions/checkout@v4
Expand Down
55 changes: 55 additions & 0 deletions all_clip/cn_clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""https://github.com/OFA-Sys/Chinese-CLIP"""
import os.path
from typing import Dict

import cn_clip.clip
import torch
from torch import nn


class CnCLIPForBenchmark(nn.Module):
    """Thin wrapper around a Chinese-CLIP model.

    Exposes ``encode_image`` / ``encode_text`` with the same call convention
    as the other CLIP wrappers in this package, so benchmark code can call
    ``model.encode_text(tensor)`` uniformly regardless of backend.
    """

    def __init__(self, model, device):
        super().__init__()
        # Underlying cn_clip model; all work is delegated to it.
        self.model = model
        # Normalized torch.device for the device the caller requested.
        self.device = torch.device(device=device)

    def encode_image(self, image):
        """Delegate image encoding to the wrapped model."""
        embedding = self.model.encode_image(image)
        return embedding

    def encode_text(self, text):
        """Delegate text encoding to the wrapped model."""
        embedding = self.model.encode_text(text)
        return embedding

    def forward(self, *args, **kwargs):
        """Pass any other call straight through to the wrapped model."""
        output = self.model(*args, **kwargs)
        return output


def load_chinese_clip(clip_model, use_jit, device, clip_cache_path):  # pylint: disable=unused-argument
    """Load a Chinese-CLIP model, its image preprocessor, and its tokenizer.

    Args:
        clip_model: string of the form ``"<arch>/<checkpoint path>"``,
            e.g. ``"ViT-B-16/path/to/ckpt.pt"``. Everything after the first
            ``"/"`` is treated as a local checkpoint file path (which may
            itself contain ``"/"``); if that path is not an existing file
            (e.g. ``"ViT-B-16/no_checkpoint"``), ``create_model`` is called
            without a checkpoint.
        use_jit: unused, kept for a uniform loader signature across backends.
        device: torch device string or object the model is moved to.
        clip_cache_path: unused, kept for a uniform loader signature.

    Returns:
        Tuple of ``(CnCLIPForBenchmark wrapper, image transform, tokenizer)``.

    Raises:
        ImportError: if the ``cn_clip`` package is not installed.
        KeyError: if the architecture name is not a known Chinese-CLIP model.
    """
    try:
        from cn_clip.clip.utils import create_model, image_transform, _MODEL_INFO  # pylint: disable=import-outside-toplevel
    except ImportError as exc:
        raise ImportError(
            "Install `Chinese-CLIP` by `pip install git+https://github.com/OFA-Sys/Chinese-CLIP.git`"
        ) from exc

    # Split only on the FIRST "/": the remainder is the checkpoint path.
    model_key, _, checkpoint_file = clip_model.partition("/")
    arch_name = _MODEL_INFO[model_key]["struct"]

    checkpoint = None
    if os.path.isfile(checkpoint_file):
        with open(checkpoint_file, "rb") as opened_file:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(arch_name, checkpoint)
    model.to(device=device, dtype=torch.float32)
    processor = image_transform()

    return CnCLIPForBenchmark(model, device), processor, cn_clip.clip.tokenize
2 changes: 2 additions & 0 deletions all_clip/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .open_clip import load_open_clip
from .openai_clip import load_openai_clip
from .ja_clip import load_japanese_clip
from .cn_clip import load_chinese_clip


_CLIP_REGISTRY = {
Expand All @@ -18,6 +19,7 @@
"nm:": load_deepsparse,
"ja_clip:": load_japanese_clip,
"openai_clip:": load_openai_clip,
"cn_clip:": load_chinese_clip,
"": load_openai_clip,
}

Expand Down
1 change: 1 addition & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"hf_clip:patrickjohncyh/fashion-clip",
# "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds", # deepsparse not compatible with Python 3.10+
# "ja_clip:rinna/japanese-clip-vit-b-16", # japanese-clip has transformers compatibility issues with v4.55+
"cn_clip:ViT-B-16/no_checkpoint"
],
)
def test_load_clip(model):
Expand Down
Loading