diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
index f7573f41..fc2b53cd 100644
--- a/src/twinkle/model/transformers/multi_lora_transformers.py
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -24,7 +24,7 @@ class MultiLoraTransformersModel(TransformersModel, PreTrainedModel):
 
     def __init__(
             self,  # noqa
-            model_cls=AutoModelForCausalLM,
+            model_cls=None,
             model_id: Optional[str] = None,
             config: Optional[PretrainedConfig] = None,
             device_mesh: Optional[DeviceMesh] = None,
@@ -39,9 +39,19 @@ def __init__(
         self._try_init_process_group()
         super(PreTrainedModel, self).__init__()
         model_id = HubOperation.download_model(model_id)
+        self.model_id = model_id
+        if config is None:
+            from transformers import AutoConfig
+            self.hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+        else:
+            self.hf_config = config
+        if model_cls is None and hasattr(self.hf_config, 'architectures'):
+            model_cls = self.hf_config.architectures[0]
+        if model_cls is None:
+            model_cls = AutoModelForCausalLM
         if isinstance(model_cls, str):
             model_cls = getattr(transformers, model_cls)
-        self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
+        self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
         self.model_id = model_id
         self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
         self.device_mesh = device_mesh
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index e8f9bdda..ab464811 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -148,7 +148,7 @@ def __init__(self, *, model_id: str, config: Optional[PretrainedConfig] = None,
 
     def __init__(
             self,  # noqa
-            model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = AutoModelForCausalLM,
+            model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = None,
             model_id: Optional[str] = None,
             config: Optional[PretrainedConfig] = None,
             device_mesh: Optional[DeviceMesh] = None,
@@ -162,8 +162,6 @@ def __init__(
         os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         self._try_init_process_group()
         super(PreTrainedModel, self).__init__()
-        self.model_id = model_id
-        self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
         # The Default tokenizer will be used to save with a model if no template was set.
         self._default_tokenizer = None
         self.device_mesh = device_mesh
@@ -173,15 +171,27 @@ def __init__(
         self._memory_efficient_init = memory_efficient_init
         self._decide_strategy(strategy)
         self.grad_scaler_config = grad_scaler_config
+        if model_id is not None:
+            model_id = HubOperation.download_model(model_id)
+        self.model_id = model_id
+        self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
+        if config is None:
+            from transformers import AutoConfig
+            self.hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+        else:
+            self.hf_config = config
+        if model_cls is None and hasattr(self.hf_config, 'architectures'):
+            model_cls = self.hf_config.architectures[0]
+        if model_cls is None:
+            model_cls = AutoModelForCausalLM
         if isinstance(model_cls, str):
             model_cls = getattr(transformers, model_cls)
         if model_id is None:
-            self.model = model_cls.from_config(config, **kwargs)
+            self.model = model_cls.from_config(self.hf_config, **kwargs)
         else:
-            model_id = HubOperation.download_model(model_id)
             # Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
             with self.strategy.pretrained_load_context():
-                self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
+                self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
         self.model.gradient_checkpointing_enable()
         self.sp_strategy = None
         self._model_wrapped = False
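
Both constructors now resolve `model_cls` in the same order: an explicit argument wins, then the `architectures` entry from the checkpoint's `config.json`, with `AutoModelForCausalLM` as the final fallback. A minimal standalone sketch of that resolution logic (the helper name `resolve_model_cls` is hypothetical, not part of this PR):

```python
import transformers
from transformers import AutoConfig, AutoModelForCausalLM


def resolve_model_cls(model_id, model_cls=None, config=None):
    # Load the HF config when the caller did not pass one,
    # mirroring the new constructor behavior above.
    hf_config = config or AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    # Prefer the architecture recorded in the checkpoint's config.json,
    # e.g. "Qwen2ForCausalLM" -> transformers.Qwen2ForCausalLM.
    if model_cls is None and getattr(hf_config, 'architectures', None):
        model_cls = hf_config.architectures[0]
    if model_cls is None:
        model_cls = AutoModelForCausalLM  # last-resort fallback, as before
    # A string name (from config or caller) is looked up in transformers.
    if isinstance(model_cls, str):
        model_cls = getattr(transformers, model_cls)
    return model_cls, hf_config
```

One deliberate difference in the sketch: it checks the truthiness of `architectures` rather than `hasattr`, since `PretrainedConfig` initializes `architectures` to `None` by default, so the `hasattr` check in the diff passes even when no architecture is recorded and `self.hf_config.architectures[0]` would then raise.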
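With these defaults, callers can omit `model_cls` entirely. A usage sketch under that assumption (the checkpoint id is illustrative only):

```python
# model_cls is inferred from the "architectures" field of config.json;
# 'Qwen/Qwen2-0.5B' is only an illustrative checkpoint id.
model = TransformersModel(model_id='Qwen/Qwen2-0.5B')

# An explicit class name still takes precedence over the inferred one.
model = TransformersModel(model_cls='Qwen2ForCausalLM', model_id='Qwen/Qwen2-0.5B')
```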