diff --git a/mapreader/classify/classifier.py b/mapreader/classify/classifier.py index 671d41bd..fe4acbc4 100644 --- a/mapreader/classify/classifier.py +++ b/mapreader/classify/classifier.py @@ -108,6 +108,7 @@ def __init__( is_inception: bool = False, load_path: str | None = None, force_device: bool = False, + huggingface=False, **kwargs, ): # set up device @@ -141,6 +142,7 @@ def __init__( ) self.labels_map = labels_map + self.huggingface = huggingface # set up model and move to device print("[INFO] Initializing model.") @@ -149,7 +151,25 @@ def __init__( self.input_size = input_size self.is_inception = is_inception elif isinstance(model, str): - self._initialize_model(model, **kwargs) + if self.huggingface: + try: + from transformers import AutoModelForImageClassification, AutoImageProcessor + except ImportError: + raise ImportError( + "Hugging Face models require the 'transformers' library: 'pip install transformers'." + ) + print(f"[INFO] Initializing Hugging Face model: {model}") + num_labels = len(self.labels_map) + self.model = AutoModelForImageClassification.from_pretrained( + model, + num_labels=num_labels, + ignore_mismatched_sizes=True + ).to(self.device) + self.hf_processor = AutoImageProcessor.from_pretrained(model) + self.input_size = getattr(self.hf_processor, "size", {"height": 224})["height"] + self.is_inception = False + else: + self._initialize_model(model, **kwargs) self.optimizer = None self.scheduler = None