From f76169d98e6d980b1877b8a9fd7b2ce51a8a093c Mon Sep 17 00:00:00 2001 From: bogdan-galileo Date: Fri, 30 May 2025 12:57:03 -0400 Subject: [PATCH 1/5] add batch as arg --- dataquality/dq_auto/tc_trainer.py | 4 +++- dataquality/dq_auto/text_classification.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dataquality/dq_auto/tc_trainer.py b/dataquality/dq_auto/tc_trainer.py index 14a9dae94..4369c5b65 100644 --- a/dataquality/dq_auto/tc_trainer.py +++ b/dataquality/dq_auto/tc_trainer.py @@ -58,6 +58,7 @@ def get_trainer( model_checkpoint: str, max_padding_length: int, num_train_epochs: int, + batch_size: int, early_stopping: bool = True, ) -> Tuple[Trainer, DatasetDict]: tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) @@ -76,7 +77,6 @@ def model_init() -> Any: metric = evaluate.load(EVAL_METRIC) # We use the users chosen evaluation metric by preloading it into the partial compute_metrics_partial = partial(compute_metrics, metric) - batch_size = 64 has_val = Split.validation in encoded_datasets eval_strat = IntervalStrategy.EPOCH if has_val else IntervalStrategy.NO load_best_model = has_val # Can only load the best model if we have validation data @@ -108,4 +108,6 @@ def model_init() -> Any: compute_metrics=compute_metrics_partial, callbacks=callbacks, ) + + print(f"Trainer: {trainer.args.per_device_train_batch_size}") return trainer, encoded_datasets diff --git a/dataquality/dq_auto/text_classification.py b/dataquality/dq_auto/text_classification.py index 305caef47..9adffa3de 100644 --- a/dataquality/dq_auto/text_classification.py +++ b/dataquality/dq_auto/text_classification.py @@ -122,6 +122,7 @@ def auto( inference_data: Optional[Dict[str, Union[pd.DataFrame, Dataset, str]]] = None, max_padding_length: int = 200, num_train_epochs: int = 15, + batch_size: int = 64, hf_model: str = "distilbert-base-uncased", labels: Optional[List[str]] = None, project_name: str = "auto_tc", @@ -266,6 +267,7 @@ def auto( hf_model, max_padding_length, num_train_epochs, + batch_size=batch_size, early_stopping=early_stopping, ) return do_train(trainer, encoded_data, wait, create_data_embs) From fe86acc56dcdeafb7dc0c80c5b361156e7075481 Mon Sep 17 00:00:00 2001 From: bogdan-galileo Date: Fri, 30 May 2025 12:57:58 -0400 Subject: [PATCH 2/5] lint --- dataquality/dq_auto/tc_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataquality/dq_auto/tc_trainer.py b/dataquality/dq_auto/tc_trainer.py index 4369c5b65..645c5d027 100644 --- a/dataquality/dq_auto/tc_trainer.py +++ b/dataquality/dq_auto/tc_trainer.py @@ -108,6 +108,6 @@ def model_init() -> Any: compute_metrics=compute_metrics_partial, callbacks=callbacks, ) - + print(f"Trainer: {trainer.args.per_device_train_batch_size}") return trainer, encoded_datasets From 90433abff979c42e968a469ec110663075aa6ec0 Mon Sep 17 00:00:00 2001 From: bogdan-galileo Date: Fri, 30 May 2025 13:10:31 -0400 Subject: [PATCH 3/5] update transformers dep --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d52b6c3fd..e249255fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ scipy = ">=1.7.0" cachetools = ">=4.2.4" importlib-metadata = "<6.0.1" datasets = ">=2.14.6" -transformers = ">=4.17.0" +transformers = ">=4.17.0, <4.45" seqeval = "*" sentence-transformers = ">=2.2" h5py = ">=3.1.0" From c9e088e286839b801b8e6217314d320b52faa819 Mon Sep 17 00:00:00 2001 From: bogdan-galileo Date: Fri, 30 May 2025 13:11:53 -0400 Subject: [PATCH 4/5] fix version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e249255fb..ab84f763e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ scipy = ">=1.7.0" cachetools = ">=4.2.4" importlib-metadata = "<6.0.1" datasets = ">=2.14.6" -transformers = ">=4.17.0, <4.45" +transformers = ">=4.17.0, <4.45.0" seqeval = "*" sentence-transformers = ">=2.2" h5py = ">=3.1.0" From 847ec5c9d892f3a0220fad744839e8150d66c47b Mon Sep 17 00:00:00 2001 From: bogdan-galileo Date: Fri, 30 May 2025 13:14:13 -0400 Subject: [PATCH 5/5] fix version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ab84f763e..1409ecbe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ scipy = ">=1.7.0" cachetools = ">=4.2.4" importlib-metadata = "<6.0.1" datasets = ">=2.14.6" -transformers = ">=4.17.0, <4.45.0" +transformers = ">=4.17.0,<4.45.0" seqeval = "*" sentence-transformers = ">=2.2" h5py = ">=3.1.0"