Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion mapreader/classify/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
is_inception: bool = False,
load_path: str | None = None,
force_device: bool = False,
huggingface=False,
**kwargs,
):
# set up device
Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(
)

self.labels_map = labels_map
self.huggingface = huggingface

# set up model and move to device
print("[INFO] Initializing model.")
Expand All @@ -149,7 +151,25 @@ def __init__(
self.input_size = input_size
self.is_inception = is_inception
elif isinstance(model, str):
self._initialize_model(model, **kwargs)
if self.huggingface:
try:
from transformers import AutoModelForImageClassification, AutoImageProcessor
except ImportError:
raise ImportError(
"Hugging Face models require the 'transformers' library: 'pip install transformers'."
)
print(f"[INFO] Initializing Hugging Face model: {model}")
num_labels = len(self.labels_map)
self.model = AutoModelForImageClassification.from_pretrained(
model,
num_labels=num_labels,
ignore_mismatched_sizes=True
).to(self.device)
self.hf_processor = AutoImageProcessor.from_pretrained(model)
self.input_size = getattr(self.hf_processor, "size", {"height": 224})["height"]
self.is_inception = False
else:
self._initialize_model(model, **kwargs)

self.optimizer = None
self.scheduler = None
Expand Down