diff --git a/pinecone_text/sparse/splade_encoder.py b/pinecone_text/sparse/splade_encoder.py
index 240a9df..f5225ee 100644
--- a/pinecone_text/sparse/splade_encoder.py
+++ b/pinecone_text/sparse/splade_encoder.py
@@ -1,5 +1,7 @@
 from typing import List, Union, Optional
-
+from os import PathLike
+import os
+import json
 try:
     import torch
 except (OSError, ImportError, ModuleNotFoundError) as e:
@@ -26,11 +28,12 @@ class SpladeEncoder(BaseSparseEncoder):
     Currently only supports inference with naver/splade-cocondenser-ensembledistil
     """
 
-    def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):
+    def __init__(self, max_seq_length: int = 256, device: Optional[str] = None, model_dir: Optional[Union[str, PathLike]] = None):
         """
         Args:
             max_seq_length: Maximum sequence length for the model. Must be between 1 and 512.
             device: Device to use for inference. Defaults to GPU if available, otherwise CPU.
+            model_dir: Optional local directory to cache the model in; loads from disk when present, downloads otherwise.
 
         Example:
 
@@ -61,11 +64,40 @@ def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
-
-        model = "naver/splade-cocondenser-ensembledistil"
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
-        self.model = AutoModelForMaskedLM.from_pretrained(model).to(self.device)
+        expected_model_name = "naver/splade-cocondenser-ensembledistil"
+        if model_dir:
+            if not self._is_correct_model(model_dir, expected_model_name):
+                self.tokenizer, self.model = self._download_model(model_dir, expected_model_name)
+            else:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+                self.model = AutoModelForMaskedLM.from_pretrained(model_dir).to(self.device)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(expected_model_name)
+            self.model = AutoModelForMaskedLM.from_pretrained(expected_model_name).to(self.device)
         self.max_seq_length = max_seq_length
 
+    def _is_correct_model(self, model_dir, expected_model_name):
+        # Verify the cached model by checking config.json for the expected model name
+        config_path = os.path.join(model_dir, "config.json")
+        if not os.path.exists(config_path):
+            return False
+
+        with open(config_path, "r") as config_file:
+            config = json.load(config_file)
+        return config.get("_name_or_path") == expected_model_name
+
+    def _download_model(self, model_dir, model_name):
+        # Ensure the cache directory exists
+        os.makedirs(model_dir, exist_ok=True)
+
+        # Download the tokenizer and model, then save them to the cache directory
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.save_pretrained(model_dir)
+
+        # Move the model to the configured device, matching the other load paths
+        model = AutoModelForMaskedLM.from_pretrained(model_name).to(self.device)
+        model.save_pretrained(model_dir)
+        return tokenizer, model
+
     def encode_documents(
         self, texts: Union[str, List[str]]
     ) -> Union[SparseVector, List[SparseVector]]:
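
Usage sketch for the new model_dir parameter (a hedged example, not part of the diff; the import path follows the package layout above, and "./splade_cache" is an arbitrary placeholder):

    from pinecone_text.sparse import SpladeEncoder

    # First run: downloads naver/splade-cocondenser-ensembledistil into ./splade_cache.
    # Later runs: _is_correct_model() finds the cached config.json and loads from disk.
    encoder = SpladeEncoder(max_seq_length=256, model_dir="./splade_cache")
    sparse_vec = encoder.encode_documents("sample document text")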