diff --git a/dataquality/dq_auto/tc_trainer.py b/dataquality/dq_auto/tc_trainer.py
index 14a9dae94..645c5d027 100644
--- a/dataquality/dq_auto/tc_trainer.py
+++ b/dataquality/dq_auto/tc_trainer.py
@@ -58,6 +58,7 @@ def get_trainer(
     model_checkpoint: str,
     max_padding_length: int,
     num_train_epochs: int,
+    batch_size: int = 64,
     early_stopping: bool = True,
 ) -> Tuple[Trainer, DatasetDict]:
     tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
@@ -76,7 +77,6 @@ def model_init() -> Any:
     metric = evaluate.load(EVAL_METRIC)
     # We use the users chosen evaluation metric by preloading it into the partial
     compute_metrics_partial = partial(compute_metrics, metric)
-    batch_size = 64
     has_val = Split.validation in encoded_datasets
     eval_strat = IntervalStrategy.EPOCH if has_val else IntervalStrategy.NO
     load_best_model = has_val  # Can only load the best model if we have validation data
diff --git a/dataquality/dq_auto/text_classification.py b/dataquality/dq_auto/text_classification.py
index 305caef47..9adffa3de 100644
--- a/dataquality/dq_auto/text_classification.py
+++ b/dataquality/dq_auto/text_classification.py
@@ -122,6 +122,7 @@ def auto(
     inference_data: Optional[Dict[str, Union[pd.DataFrame, Dataset, str]]] = None,
     max_padding_length: int = 200,
     num_train_epochs: int = 15,
+    batch_size: int = 64,
     hf_model: str = "distilbert-base-uncased",
     labels: Optional[List[str]] = None,
     project_name: str = "auto_tc",
@@ -266,6 +267,7 @@ def auto(
         hf_model,
         max_padding_length,
         num_train_epochs,
+        batch_size=batch_size,
         early_stopping=early_stopping,
     )
     return do_train(trainer, encoded_data, wait, create_data_embs)
diff --git a/pyproject.toml b/pyproject.toml
index d52b6c3fd..1409ecbe9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ scipy = ">=1.7.0"
 cachetools = ">=4.2.4"
 importlib-metadata = "<6.0.1"
 datasets = ">=2.14.6"
-transformers = ">=4.17.0"
+transformers = ">=4.17.0,<4.45.0"
 seqeval = "*"
 sentence-transformers = ">=2.2"
 h5py = ">=3.1.0"