From 587094829220a9f4631853bd593dc74076da1ae1 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Thu, 3 Apr 2025 16:15:32 +0200
Subject: [PATCH 1/2] add a parameter to the dataset that, if set, throws
 out all instances with more than x tokens

---
 chebai/preprocessing/datasets/base.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 817bc1d1..b2390ae9 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -79,6 +79,7 @@ def __init__(
         inner_k_folds: int = -1,  # use inner cross-validation if > 1
         fold_index: Optional[int] = None,
         base_dir: Optional[str] = None,
+        n_token_limit: Optional[int] = None,
         **kwargs,
     ):
         super().__init__()
@@ -110,6 +111,7 @@ def __init__(
         ), "fold_index can't be larger than the total number of folds"
         self.fold_index = fold_index
         self._base_dir = base_dir
+        self.n_token_limit = n_token_limit
         os.makedirs(self.raw_dir, exist_ok=True)
         os.makedirs(self.processed_dir, exist_ok=True)
         if self.use_inner_cross_validation:
@@ -311,8 +313,9 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
             for d in tqdm.tqdm(self._load_dict(path), total=lines)
             if d["features"] is not None
         ]
-        # filter for missing features in resulting data
-        data = [val for val in data if val["features"] is not None]
+        # filter for missing features; if n_token_limit is set, drop entries exceeding it
+        data = [val for val in data if val["features"] is not None
+                and (self.n_token_limit is None or len(val["features"]) <= self.n_token_limit)]
 
         return data
 
@@ -1181,4 +1184,6 @@ def processed_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset keys to their respective file names.
                   For example, {"data": "data.pt"}.
         """
+        if self.n_token_limit is not None:
+            return {"data": f"data_maxlen{self.n_token_limit}.pt"}
         return {"data": "data.pt"}

From 0a62a95bba5cd8d128998f0265cb42b5a81cab81 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Mon, 7 Apr 2025 11:08:33 +0200
Subject: [PATCH 2/2] reformat using black

---
 chebai/preprocessing/datasets/base.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index b2390ae9..39e5fbec 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -314,8 +314,15 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
             if d["features"] is not None
         ]
         # filter for missing features; if n_token_limit is set, drop entries exceeding it
-        data = [val for val in data if val["features"] is not None
-                and (self.n_token_limit is None or len(val["features"]) <= self.n_token_limit)]
+        data = [
+            val
+            for val in data
+            if val["features"] is not None
+            and (
+                self.n_token_limit is None
+                or len(val["features"]) <= self.n_token_limit
+            )
+        ]
 
         return data
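
Note on the parenthesization in the new filter: Python parses "A and B or C"
as "(A and B) or C", so without explicit parentheses the token-length check
would still be evaluated for entries whose features are None, or compared
against a limit of None. A minimal standalone sketch of the intended logic
(plain Python, not repo code; the entries list and the limit value are made
up for illustration):

    entries = [
        {"features": [1, 2, 3]},        # kept: within limit
        {"features": None},             # dropped: missing features
        {"features": list(range(10))},  # dropped: exceeds limit
    ]
    n_token_limit = 5

    # the "is not None" guard must bind to the whole parenthesized group
    kept = [
        e
        for e in entries
        if e["features"] is not None
        and (n_token_limit is None or len(e["features"]) <= n_token_limit)
    ]
    print(kept)  # [{'features': [1, 2, 3]}]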
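
And a self-contained sketch of how the two patched behaviors fit together
(a hypothetical minimal re-implementation; TokenLimitedDataset is a stand-in
for the data module subclasses in chebai, not an actual repo class):

    from typing import Any, Dict, List, Optional


    class TokenLimitedDataset:
        def __init__(self, n_token_limit: Optional[int] = None):
            # None means "no limit", mirroring the patched default
            self.n_token_limit = n_token_limit

        def filter_rows(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
            # mirrors the patched comprehension in _load_data_from_file
            return [
                val
                for val in data
                if val["features"] is not None
                and (
                    self.n_token_limit is None
                    or len(val["features"]) <= self.n_token_limit
                )
            ]

        @property
        def processed_file_names_dict(self) -> dict:
            # a separate cache file per limit keeps filtered and unfiltered
            # processed data from overwriting each other
            if self.n_token_limit is not None:
                return {"data": f"data_maxlen{self.n_token_limit}.pt"}
            return {"data": "data.pt"}


    ds = TokenLimitedDataset(n_token_limit=2)
    rows = [{"features": [1]}, {"features": [1, 2, 3]}, {"features": None}]
    print(ds.filter_rows(rows))          # [{'features': [1]}]
    print(ds.processed_file_names_dict)  # {'data': 'data_maxlen2.pt'}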