Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dlomix/_metadata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.4"
__version__ = "0.2.5"
__author__ = "Wilhelm Lab"
__author_email__ = "o.shouman@tum.de"
__license__ = "MIT"
Expand Down
63 changes: 15 additions & 48 deletions src/dlomix/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,6 @@

logger = logging.getLogger(__name__)

# TensorFlow import for tf.data is deferred until needed to avoid unnecessary imports for users who only want to use PyTorch datasets or other functionalities of the PeptideDataset class.
# This also helps to reduce the initial loading time and memory footprint for users who do not need TensorFlow.

# Module-level cache for the lazily imported TensorFlow module (None until first use).
_tf = None


def _get_tensorflow():
    """Return the TensorFlow module, importing it on first use.

    The imported module is cached in the module-level ``_tf`` variable, so the
    import is attempted only until it succeeds once; later calls return the
    cached module.

    Returns:
        The imported ``tensorflow`` module.

    Raises:
        ImportError: If ``tensorflow`` cannot be imported. The original import
            error is attached as the explicit cause.
    """
    global _tf
    if _tf is None:
        try:
            import tensorflow as tf
        except ImportError as err:
            # Chain the original error so the real reason for the failed import
            # (missing package, broken install, ...) stays visible in tracebacks.
            raise ImportError(
                "TensorFlow backend requires tensorflow to be installed. "
                "Install with: pip install tensorflow"
            ) from err
        _tf = tf
    return _tf


class PeptideDataset:
"""
Expand Down Expand Up @@ -189,13 +168,20 @@ def __init__(self, dataset_config: DatasetConfig, **kwargs):
self.processed = True

def _set_num_proc(self):
    """Resolve ``self._num_proc`` against the value of ``get_num_processors()``.

    - If ``self._num_proc`` is unset (falsy), a warning is issued and it is set
      to the value returned by ``get_num_processors()``.
    - If ``self._num_proc`` is larger than that value, a warning is issued and
      it is capped to that value.
    - Otherwise ``self._num_proc`` is left unchanged.
    """
    # Query once; both branches below compare against or assign this value.
    n_processors = get_num_processors()

    if not self._num_proc:
        warnings.warn(
            f"Number of processors not provided. Using the maximum number of processors available: {n_processors}.\n"
            f"If you want to specify a different number of processors, please provide num_proc=<desired_number> parameter in the dataset configuration.\n"
            f"If you face issues with memory usage, please consider providing a smaller number of processors or setting num_proc=1 to disable multi-processing."
        )
        self._num_proc = n_processors
        return

    if self._num_proc > n_processors:
        warnings.warn(
            f"Number of processors provided is greater than the available processors. Using the maximum number of processors available: {n_processors}."
        )
        self._num_proc = n_processors

def _set_hf_cache_management(self):
if self.disable_cache:
Expand Down Expand Up @@ -715,26 +701,13 @@ def tensor_train_data(self):
if self.dataset_type == "pt":
return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[0])
else:
tf = _get_tensorflow()
dataset_len = len(self.hf_dataset[PeptideDataset.DEFAULT_SPLIT_NAMES[0]])
tf_dataset = self._get_split_tf_dataset(
PeptideDataset.DEFAULT_SPLIT_NAMES[0]
)

if self.enable_tf_dataset_cache:
tf_dataset = tf_dataset.cache()

if self.shuffle:
tf_dataset = tf_dataset.shuffle(
buffer_size=min(10000, dataset_len),
reshuffle_each_iteration=True,
)

# Batch the data
tf_dataset = tf_dataset.batch(self.batch_size)

tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)

return tf_dataset

@property
Expand All @@ -743,17 +716,13 @@ def tensor_val_data(self):
if self.dataset_type == "pt":
return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[1])
else:
tf = _get_tensorflow()
tf_dataset = self._get_split_tf_dataset(
PeptideDataset.DEFAULT_SPLIT_NAMES[1]
)

if self.enable_tf_dataset_cache:
tf_dataset = tf_dataset.cache()

tf_dataset = tf_dataset.batch(self.batch_size)
tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)

return tf_dataset

@property
Expand All @@ -762,17 +731,10 @@ def tensor_test_data(self):
if self.dataset_type == "pt":
return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[2])
else:
tf = _get_tensorflow()
tf_dataset = self._get_split_tf_dataset(
PeptideDataset.DEFAULT_SPLIT_NAMES[2]
)

if self.enable_tf_dataset_cache:
tf_dataset = tf_dataset.cache()

tf_dataset = tf_dataset.batch(self.batch_size)
tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)

return tf_dataset

def _check_if_split_exists(self, split_name: str):
Expand All @@ -796,7 +758,10 @@ def _get_split_tf_dataset(self, split_name: str):
return self.hf_dataset[split_name].to_tf_dataset(
columns=self._get_input_tensor_column_names(),
label_cols=label_cols,
shuffle=False,
shuffle=self.shuffle
if split_name == PeptideDataset.DEFAULT_SPLIT_NAMES[0]
else False,
batch_size=self.batch_size,
)

def _get_split_torch_dataset(self, split_name: str):
Expand All @@ -811,7 +776,9 @@ def _get_split_torch_dataset(self, split_name: str):
columns=[*self._get_input_tensor_column_names(), *self.label_column],
),
"batch_size": self.batch_size,
"shuffle": self.shuffle,
"shuffle": self.shuffle
if split_name == PeptideDataset.DEFAULT_SPLIT_NAMES[0]
else False,
}

# Update with user-provided torch_dataloader_kwargs if available
Expand Down