From 0420d08935acb0d8fdbfc96cbfd559a929c41c8a Mon Sep 17 00:00:00 2001 From: Victor-George Giurcoiu <74825149+victorgiurcoiu@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:32:52 +0100 Subject: [PATCH 1/9] Update pickled ptm feature dicts (#81) --- .../pickled_feature_dicts/mz_diff.pkl | Bin 18779 -> 19511 bytes .../pickled_feature_dicts/saved_ac_count.pkl | Bin 24289 -> 25186 bytes .../saved_gained_atoms.pkl | Bin 23161 -> 24038 bytes .../saved_loss_atoms.pkl | Bin 23161 -> 24027 bytes 4 files changed, 0 insertions(+), 0 deletions(-) diff --git a/src/dlomix/data/processing/pickled_feature_dicts/mz_diff.pkl b/src/dlomix/data/processing/pickled_feature_dicts/mz_diff.pkl index cfb5dfe955ddf09d8e2864c7b2ae6859af135e06..4e40504d8619d098be582af44088f56f50aa5a2f 100644 GIT binary patch delta 1127 zcmZXSOH30{7=>qwQ%WB|c}%CYO^0%)QWWe=X=`0gl>!As#0P?b7+*11wUW4TBLtQN zgi}OgOmt)1xPXb`&WXmA8)9Nypac@P8WSaK6l{r8(s411x%d0eJvskbnWdNR)6D3e zx&2);jLM%pd#XEm^jOOwQ3{3!wI3-um}!%R&EM*R+D7x-Y3(8sIjDmI--c8a8RmHI zcpeH#g^(IFzF$)^hPQ@lNjLxc$44K96ca!fLy%!!d5>r^VlW&n)RBFbMYJ0C@kjHF z*sOD{h71idQpR<@H4g=2dPDJvXEGMtE?g*ngQUq1!9;Mww2HYBf{~IscsbS*4ei7* z{PoKy?wVX!E-B~x@}iOyDRigyjvH01lmuuFQ$7~5xfG0<&`xYySmmO*iLQ0F=f_1+ zZ#-sPU&E9ojjL9HcFa7f#*ozpVkOwJ_%Z8Y@zi<-#OBAgRRHaE;GC^LtA3MV4}_X? z3+rdNHHxApohNpmiRm+pQ53Jdumxxs$_Q3ViQ+HYrjaRP@!ejE4*N9z*z0PYJ-ZtJ zExmE$ybO0m4_;Ozp}1I#RWxI}qLt=m?G97UN}Smo#pTL+Bq{|gR;JKWbq&PrK)9+L zi&cWRO_L8tn+tQ4JI4knPlIfq1cwF%` zpxr}2JVcrGbkNyK_>u7n*z3IxpRc<7Uw_k;gITN8qysZn7Rz2M1fGD-ep4=oSrT*W zI(+-#VhW%c5O~JI9H!0XcVU?S&{*8*VlI E1usrsmH+?% delta 479 zcmX|-F>4cH0EYQ4xz_7>m(b?EOWK%R?!LRW4nkEZ4oa^H;Yo_SKRX-KyAUS!pZ=vYd)otKU=`G4Pd^wHy#HBQTu z{IV`V*5$FCqi#oZ?b}H+7Yd%`6?gM0d7IxBs8Yzr;1jg0!{A4sFCT-Jc5@YkvOxrX~lO|~c#IgY1cXI;usN<22-2B!Ce-;oo6Ry)13uAewa3(E@+gHaInh^5WMDE@A7?Bv zj94Vh# z^znG+uYteh5yI9evL%w*Yn@DRy;MS`G>joP#kNw4VJm-tM*rXxoKDrZP;odr)i-kC*n+mq}lLU6mNx#lX)J@`N5)l{ca#>Ev zf9gBG{3ItwnGt7K%UDzS1Ru*IYI{|J1ceSgJCXPcJDnT0k1g)kjP$O&IQ82-_nRrE zF*sVOT^;t^Bpt$xyz)25ax+XIshwoGE|rydVGFQ@DzMe7;IN4z?GV`S9llD|Y!vU=bK583s83Qo{sV*&!()%T z1mEh)urJUEYhV-?>mCCF99-G_!e?(O?)W)q}DYifJJ(f4O*5I1<0lC}2L4BPq0#6q6Ah%aKp$4*2ZPK$NgrOeMJKP~4Vu z9F@k{14$*adAv0_e(c1=kps0l@_Pp3G*ALD5CRgAnI!`iM@L`QuqwN#C}rGH_TszZ z;WeLzL75+)BPhxgUw5@p3>)~Fpm>&Wv5;`kN_dE>H%J;Q!xSG>iY0X)CTxl!GYTm= z!`>v%5KL(i~v}HWOZ&63W(6MQ_uv1OU3Or2eFlSDoYEGkPdU$TS zyzvG#uwo8EN-sdml(Cqjc$7YmLzxQeW_}U$vptSAhA8S~g^i4G7otKDs8Mz$_m;4B zVTYg~8Xe2T+U_cPT$FEJ!h6f(2MshW3A46{zO^4GM;Fm97$}w~E_xK_Mis2)BZ3%& z@FSkayIc%^b284B9A3TRQ#4AHy)1PIe8p0~T~sQL7+T{7|Wa zf*^V-=t39jLMe!4R76Dtb)^dtaiI$lMQjzsf{Gu+*MvzcgNxz8yJyZl@1FDCkL%?9 z9pdVDo_3WNyCU1?NYReQhtX{f!N!R?Dsa#>OFxX`H47b4;%#FOy)BcoadVWLF&qY={0BU(@{%%IZ=ZZN^Nc4~hBd(p4B(d+ZzG}jVaTkeobWo>G!jLPC9 z?s6%36&`WE{^SQ9zqmSc0e_e@gcg;5*Y(NrOV;poNG@y&1aXdK@w##=p7;{D;8W4V z@fh&cVY^d;%83}QVd3!a#btjIV@?sL{34$FkHDiPlA)CoP4fSnC_^Vowj7u7g%{1< zs!oD8Q#IVA-=(;K9*l1iU ze*{0Ng7Xm($%urVbiE!?)zDI_1&xl~R-PW)N#jjz#l5JE{cHt(M>gY6qlk-&fK5uW zn0c`rX{9NE&q^I`DVjMedW_HjN$|u(n*3o^L|?TJ%bGo?jSb*YgNDa36(a#2_q{Cg zZUK?FhOe;%`r-rl6y?oS;|m5qY)LaiGsEwQ%v23CaGuTTle?AktG2LGS@3DC)(VVv(xNsnvSW zTl%R4ir~RqJ@^ZX9)pnBOhnLmSDjJ7%lgE0yE5(3SM_2vg z8^PxCs;X?65UTBd5d5myYI*djh<*ojOFB=+D~wbxu~PNvRs()iJ*H|GqGG_anxtKq z?29y34^*#yS@1<#xJ`?&5wOS9=rrco2rM2LL(Df)Y;St_jz+m^@S_lh%Tp$6~=eLIgDYEnVQZ{ATVZU zqqlXrRH4DTOjT2bQy8=|{ff&)ghhEpU43Cc=#GVO%l(a%qhI3j-NWI18H?MVDxLMZ z3Fc3uILL6A%Sc4TBc8*DfX5|jK63(`WKz*}d2%9=nY>TjU{WY&l~eCn{9>vgi7Yyc z7`=tJU~?on;%<&S9AUuw8vYHm%eRy=vUC^m$26_V|x*gRzJm@D1|@Jfs27S zz7$bhHqK81N8pl`xVU=!g8KhBsgbsB0iRi_?<`(K@Tx?{2@iuiB_W)z<}fsk(=(-6 
z8y0Xf&IGSu2PcW*%5nBznE(lv0Z|@Txdr$b^`MzH;tRJPeYBYfm>1-+(O!lKFXJ0k zu#Z24=Pn2zWG#w=$n%S}hjrBnUs)cZsbGz*`uFWj_AH}cb3O5t`; z{XwNp(CPFpP5+fei~T3JqwEb9P)#H!>%+?Ug^-E^AqCGuM{q2xAZ}Q+D@-vEmZpo1 zd5HS2hl1Q#@xtau1az}-j~Ka6y_ zDX(vhwh$vyyg8!r68a^Ivl44J-#Z^Wj(gK(48|l3cqq0QX$e=$_!?92iSglZ^&Y*t z=CvI?r7WB^ReDkuXEyn)wzJ7>F5cVTVw*Cho3VkwYi@6-X*rZOLDm;_wKHUIOrgb4njd=sX-72aTALKZE#4n zi(MrBPy=1fR8;UO2nq^9{ex~IqJvaL#leWEpkVPPhi~|X4?cdsH+RYJTO{2Xx^$x1 zO~1z(ol2Lq2p*)at=LRnS7HU*RTphPhm2ZALw${AHk3EB8G<#H;@j{qn3}-HTqJjE zS%Os>Vx9)rbQlg(3y)}wk}hys53x@V#>2QWR3XLSpQN)(stP9lS$mGp+ou`6*f63pUGnPPwJ;Nn(Yo(x_p_+SV* N9~Y)E7#FEl{15fGt;GNU From ff0c505462308366e820614c8065ce9211da71df Mon Sep 17 00:00:00 2001 From: Omar Shouman Date: Tue, 27 Jan 2026 22:30:27 +0200 Subject: [PATCH 2/9] Feature/add dataset class guide to docs - change default alphabet argument values in datasets (#82) * dataset guide and minor doc additions * changed default alphabet value to be None to trigger learning the tokens and be more explicit * comments NOTE: this PR breaks previous usage if the alphabet was implicitly assumed by the user to be ALPHABET_UNMOD. Yet, we choose to move to a more explicit approach for better transperancy and reproducibility. --- docs/dlomix.callbacks.rst | 16 + docs/dlomix.rst | 13 +- docs/index.rst | 13 +- docs/notes/dataset_guide.rst | 512 ++++++++++++++++++++++ src/dlomix/data/charge_state.py | 5 +- src/dlomix/data/detectability.py | 5 +- src/dlomix/data/fragment_ion_intensity.py | 5 +- src/dlomix/data/ion_mobility.py | 2 +- src/dlomix/data/retention_time.py | 5 +- 9 files changed, 549 insertions(+), 27 deletions(-) create mode 100644 docs/dlomix.callbacks.rst create mode 100644 docs/notes/dataset_guide.rst diff --git a/docs/dlomix.callbacks.rst b/docs/dlomix.callbacks.rst new file mode 100644 index 00000000..be41246b --- /dev/null +++ b/docs/dlomix.callbacks.rst @@ -0,0 +1,16 @@ +``dlomix.callbacks`` +===================== + +.. automodule:: dlomix.callbacks + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + + +.. automodule:: dlomix.callbacks.cyclic_lr + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/dlomix.rst b/docs/dlomix.rst index 0ccf218b..59c29e00 100644 --- a/docs/dlomix.rst +++ b/docs/dlomix.rst @@ -12,6 +12,7 @@ Subpackages .. toctree:: :maxdepth: 4 + dlomix.callbacks dlomix.data dlomix.eval dlomix.layers @@ -34,15 +35,3 @@ Submodules :members: :undoc-members: :show-inheritance: - - -.. automodule:: dlomix.detectability_model_constants - :members: - :undoc-members: - :show-inheritance: - - -.. automodule:: dlomix.types - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index b5e53fc0..51b3f743 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,6 +29,8 @@ The goal of DLOmix is to be easy to use and flexible, while still providing the .. include:: notes/installation.rst .. include:: notes/quickstart.rst +.. include:: notes/dataset_guide.rst + .. toctree:: @@ -36,11 +38,18 @@ The goal of DLOmix is to be easy to use and flexible, while still providing the :maxdepth: 2 :caption: How To - notes/quickstart notes/installation - notes/backend_usage + notes/quickstart notes/citation +.. toctree:: + :glob: + :maxdepth: 2 + :caption: Guides + + notes/dataset_guide + notes/backend_usage + .. 
toctree:: :maxdepth: 2 diff --git a/docs/notes/dataset_guide.rst b/docs/notes/dataset_guide.rst new file mode 100644 index 00000000..a9875767 --- /dev/null +++ b/docs/notes/dataset_guide.rst @@ -0,0 +1,512 @@ + +Datasets Guide +************************** + +This guide provides an introduction to working with DLOmix dataset classes for various proteomics machine learning tasks. + +.. contents:: Table of Contents + :local: + :depth: 1 + +Overview +======== + +Introduction to PeptideDataset +------------------------------- + +``PeptideDataset`` is the foundational class for handling peptide data in DLOmix. It wraps HuggingFace's ``Dataset`` library and provides specialized preprocessing and other modules for proteomics tasks including: + +* Loading data from various formats and sources with configurable parameters and sensible defaults (Hugging Face datasets on the Hugging Face Hub or in-memory, Parquet files, CSV files, etc..) +* Sequence parsing and encoding, including handling post-translational modifications (PTMs) +* Data splitting for model training and evaluation (Train/val/test) +* Feature extraction from the peptide sequences and/or modifications present in the sequence +* Automatic conversion to TensorFlow or PyTorch tensors +* Efficient batch processing with caching support (relies on Hugging Face datasets cache) +* Custom saving and loading the processed dataset object for faster experimentation and re-runs +* Improved reproducibility via logging saving/loading configurations and meta-data of the processed dataset class + +Available Task-Specific Datasets +--------------------------------- + +DLOmix provides task-specific dataset classes that inherit from ``PeptideDataset`` with appropriate default values that work with the `PROSPECT dataset collection`_ on the HF hub. + +.. _PROSPECT dataset collection: https://huggingface.co/collections/Wilhelmlab/prospect-ptms + +* **RetentionTimeDataset** - Predicts peptide retention time with default label ``indexed_retention_time`` +* **ChargeStateDataset** - Predicts charge state with default label ``most_abundant_charge_by_count`` +* **DetectabilityDataset** - Predicts peptide detectability with default label ``Classes`` +* **FragmentIonIntensityDataset** - Predicts fragment ion intensities with default label ``intensities_raw`` + +All classes share the same API but differ in their default parameters for common use cases, which are sensible for the PROSPECT datasets to provide a working example when used with the datasets hosted on the Hugging Face Hub. + + +Basic Usage +=========== + +Loading Data from Files +------------------------ + +Load datasets from local files (CSV, Parquet, etc.): + +.. code-block:: python + + from dlomix.data import RetentionTimeDataset + + # Load from a single CSV file (auto-splits train/val with the provided ratio) + dataset = RetentionTimeDataset( + data_source="data/peptides.csv", + data_format="csv", + sequence_column="sequence", + label_column="retention_time", + val_ratio=0.2 + ) + + # Load with explicit train/val/test splits + dataset = RetentionTimeDataset( + data_source="data/train.parquet", + val_data_source="data/val.parquet", + test_data_source="data/test.parquet", + data_format="parquet" + ) + +Loading Data from HuggingFace Hub +---------------------------------- + +Load directly from HuggingFace Hub datasets (hosted on the hub): + +.. 
code-block:: python + + dataset = RetentionTimeDataset( + data_source="wilhelmlab/prospect-rt-dataset", # example dataset provided by Wilhelmlab, TU Munich + data_format="hub", + # column names match the provided example, replace if needed + sequence_column="modified_sequence", + label_column="indexed_retention_time" + ) + +.. note:: + When using ``data_format="hub"``, ``val_data_source`` and ``test_data_source`` are ignored. The dataset must contain pre-defined splits. + +Using In-Memory HuggingFace Datasets +------------------------------------- + +Local data could be in other formats (e.g. pandas) and require some specific wrangling or formatting that the user would want to run before feeding the data into DLOmix and training a model. For this specific flow, the dataset class offers a way to feed in an in-memory hugging face dataset. + +.. code-block:: python + + from datasets import load_dataset + + # ... user code, column renaming, wrangling, etc.. + + # simulate an in-memory hugging face dataset + hf_dataset = load_dataset("csv", data_files="data.csv") + + # Pass to DLOmix + dataset = RetentionTimeDataset( + data_source=hf_dataset, + data_format="hf", # important to ensure that data is read and parsed correctly + sequence_column="sequence", + label_column="rt" + ) + + +Core Concepts +============= + +Data Splits and Validation +--------------------------- + +DLOmix supports three splitting strategies: + +1. **Single source with auto-split**: Set ``val_ratio`` to automatically split training data into train/val randomly +2. **Multiple sources**: Provide separate files for train/val/test +3. **Pre-split datasets**: Use HuggingFace Hub or ``DatasetDict`` with existing splits + +.. code-block:: python + + # Strategy 1: Auto-split + dataset = RetentionTimeDataset( + data_source="train.csv", + val_ratio=0.2, # 20% for validation + data_format="csv" + ) + + # Strategy 2: Separate files + dataset = RetentionTimeDataset( + data_source="train.csv", + val_data_source="val.csv", + test_data_source="test.csv", + data_format="csv" + ) + + # Strategy 3: pre-split dataset (example below points to a remote hugging face dataset hosted on the hub) + dataset = RetentionTimeDataset( + data_source="wilhelmlab/prospect-rt", + data_format="hub" + ) + + +Sequence Processing Pipeline +----------------------------- + +Datasets automatically process sequences through: + +1. **Parsing**: Extract sequences with PTMs (e.g., ``M[UNIMOD:35]``), where PTM representation follows `Unimod_`'s convention. +2. **Encoding**: Convert amino acids to integers using the alphabet (with options to learn the alphabet from the data) +3. **Padding**: Pad to ``max_seq_len`` (default uses ``padding_value="-"``, which is encoded as `0`'s by default') +4. **Feature Extraction**: Add computed features from the sequence and/or the PTM information + +.. _Unimod: https://unimod.org + +.. 
code-block:: python + + dataset = RetentionTimeDataset( + data_source="data.csv", + max_seq_len=50, # Pad/truncate to length 50 + pad=True, # Enable padding + padding_value="-", # Use '-' for padding + with_termini=True, # Add N/C termini markers, []- and -[], even if there are no terminal modifications present + + # see below and data/processing/feature_extractions.py for more details + features_to_extract=["delta_mass", "mod_gain"] + ) + +Feature Extraction +------------------ + +Built-in feature extractors add computed features to your dataset that are converted to tensors and can be fed into your model: + +There are two options to use feature extractors; (1) dlomix built-in feature extractors as string arguments and (2) custom feature extractors passed as python function objects. + +Available features in the framework as lookup python dicts are: + +* ``mod_loss`` +* ``delta_mass`` +* ``mod_gain`` +* ``atom_count`` +* ``red_smiles`` + +Custom feature extractors can either: + (1) use the `FeatureExtractor` class or + (2) write a function that can be mapped (`dataset.map()`) to the Hugging Face dataset. + +In both cases, you can access the parsed sequence information from the dataset using the following keys, which all provide python lists: + - `_parsed_sequence`: parsed sequence + - `_n_term_mods`: N-terminal modifications + - `_c_term_mods`: C-terminal modifications + +.. code-block:: python + + dataset = RetentionTimeDataset( + data_source="data.csv", + features_to_extract=["mod_loss", "atom_count"], # extracted feautres from the sequence and modifications + model_features=["collision_energy"], # other features present already in the dataset columns + ) + + + +Encoding Schemes and Alphabets +------------------------------ + +Sequences are parsed and are integer encoded to be fed into sequence models (specifically to the embedding layers). Two important parameters control the parsing and encoding: (1) the encoding scheme, and (2) the alphabet or vocabulary used. + +Two primary encoding schemes for sequences are available: + +* **UNMOD**: Assumes the sequences do not contain modifications, hence any [UNIMOD] strings are removed. +* **NAIVE_MODS**: Assumes sequences contain modifications in UNIMOD format (e.g., ``M[UNIMOD:35]``) and encodes them as distinct tokens; separate token from the amino acid. + +The alphabet is a python dict that maps each character (amino acid or amino acid + PTM combination) to a unique integer index. It can either be learnt from the provided data implicitily or provided by the user. + +The user can: +- use built-in alphabets from ``dlomix.constants`` +- or define custom alphabets as needed and pass them as a python dict to the alphabet argument, +- or provide no alphabet or `None` as the alphabet to trigger learning the alphabet from the data based on the selected encoding scheme. + + +Note that if an alphabet is provided, the user has to ensure that it covers all the amino acids (or amino acid + PTM combinations) present in the data. + +1. Use built-in alphabets from ``dlomix.constants`` + +.. 
code-block:: python + + from dlomix.data import RetentionTimeDataset + from dlomix.constants import ALPHABET_UNMOD, ALPHABET_NAIVE_MODS + + # Unmodified sequences, uses built-in unmodified alphabet + dataset = RetentionTimeDataset( + data_source="data.csv", + encoding_scheme="unmod", + alphabet=ALPHABET_UNMOD + ) + + # With PTMs, uses built-in naive-mods alphabet with tokens for some amino acids + PTMs combinations + dataset = RetentionTimeDataset( + data_source="data.csv", + encoding_scheme="naive-mods", + alphabet=ALPHABET_NAIVE_MODS + ) + +2. Define and use a custom alphabet + +The dataset class uses the provided alphabet for encoding sequences, so it must cover all characters present in the data, else the unknown tokens are all encoded as unknown with the same integer index. + +.. code-block:: python + + # Custom alphabet with special amino acids and PTMs + CUSTOM_ALPHABET = { + # .... + } + + dataset = RetentionTimeDataset( + data_source="data.csv", + encoding_scheme="naive-mods", + alphabet=CUSTOM_ALPHABET + ) + + +3. Learn alphabet from data + +If no alphabet is provided (i.e., ``alphabet=None``), the dataset class learns the alphabet from the data based on the selected encoding scheme. +If a pre-defined split is provided, the alphabet is learned from the training and validation data only, the test set is not used for learning the alphabet to allow for proper evaluation on unseen data. + +.. code-block:: python + + dataset = RetentionTimeDataset( + data_source="data.csv", + encoding_scheme="naive-mods", + alphabet=None # Learn from data + ) + + # Access the learned alphabet after the class is initialized and the processing is done + print(dataset.extended_alphabet) + + + +Tensor Datasets for Model Training +================================== + +The tensor datasets can be accessed via the ``train_data``, ``val_data``, and ``test_data`` attributes of the dataset class. They are ready to be fed into TensorFlow or PyTorch models depending on the selected ``dataset_type``. + +.. code-block:: python + + from dlomix.data import RetentionTimeDataset + + # Load and process dataset + dataset = RetentionTimeDataset(...) + + # model initialization, compilation, etc.. + + # pass to model.fit() in Keras + model.fit(dataset.train_data, + validation_data=dataset.val_data, + epochs=10, + **kwargs) + + +Advanced Features +================= + +Custom Feature Extraction +-------------------------- + +Provide custom feature extraction functions: + +.. code-block:: python + + # define the custom feature extraction function + def hydrophobicity_score(input_data): + """Calculate hydrophobicity score""" + + # lookup table for hydrophobicity values + hydro_values = {'A': 1.8, 'R': -4.5, 'N': -3.5, ...} + + # access the parsed sequence or any other column in the dataset + sequence = input_data["_parsed_sequence"] + + # add the new column with the feature name + input_data["hydrophobicity_score"] = sum(hydro_values.get(aa, 0) for aa in sequence) + + # return the whole row with the new feature added + return input_data + + dataset = RetentionTimeDataset( + data_source="data.csv", + features_to_extract=[hydrophobicity_score], + model_features=["collision_energy"] + ) + +Custom functions receive the row as a dictionary and should return the row again after adding the feature. + +Multi-Processing and Performance +--------------------------------- + +Optimize dataset processing with these parameters: + +.. 
code-block:: python + + dataset = RetentionTimeDataset( + data_source="large_dataset.parquet", + num_proc=4, # Use 4 CPU cores for processing + batch_processing_size=5000, # Process 5000 rows at a time + disable_cache=False, # Enable HF datasets caching + auto_cleanup_cache=True, # Clean temp files after processing + enable_tf_dataset_cache=True # Cache TF datasets in memory + ) + +**Performance tips:** + +* ``num_proc``: Set to number of CPU cores for large datasets +* ``batch_processing_size``: Increase for better throughput (default: 1000) +* ``enable_tf_dataset_cache``: Speeds up repeated iterations but uses more memory +* ``disable_cache=False``: Reuses processed datasets across runs. This is the hugging face datasets caching mechanism. +* ``auto_cleanup_cache=True``: Cleans temporary files after processing to save disk space. + +Saving and Loading Processed Datasets +-------------------------------------- + +Save processed datasets to disk to avoid reprocessing: + +.. code-block:: python + + # Save processed dataset + dataset = RetentionTimeDataset( + data_source="train.csv", + val_ratio=0.2 + ) + dataset.save_to_disk("processed_datasets/rt_dataset") + + # Load later + from dlomix.data import load_processed_dataset + + # loads the dataset along with its configuration and metadata, no re-processing needed + dataset = load_processed_dataset("processed_datasets/rt_dataset") + + # Access tensor data immediately + train_data = dataset.train_data + +This saves configuration, processed HuggingFace datasets, and metadata. + + +TensorFlow vs PyTorch +===================== + +Generating TensorFlow Datasets +------------------------------- + +Default behavior returns ``tf.data.Dataset`` objects: + +.. code-block:: python + + dataset = RetentionTimeDataset( + data_source="data.csv", + dataset_type="tf", # Default + batch_size=64 + ) + + # Returns batched tf.data.Dataset + train_data = dataset.train_data + val_data = dataset.val_data + + # Use directly with Keras + model.fit(train_data, validation_data=val_data, epochs=10) + +Generating PyTorch Datasets +---------------------------- + +Use ``dataset_type="pt"`` for PyTorch DataLoaders. Since PyTorch DataLoaders must be created to provide tensors, you can pass additional DataLoader arguments via ``torch_dataloader_kwargs``. + +.. code-block:: python + + dataset = RetentionTimeDataset( + data_source="data.csv", + dataset_type="pt", + batch_size=64, + sequence_column="sequence", + label_column="retention_time", + torch_dataloader_kwargs={ + "shuffle": True, + # other DataLoader args + } + ) + + # Returns PyTorch DataLoader + train_loader = dataset.train_data + val_loader = dataset.val_data + + # Training loop + for batch, label in train_loader: + sequences = batch["sequence"] + # ... training code + + +Configuration Reference +======================= + +DatasetConfig Parameters +------------------------- + +**Data Sources** + +* ``data_source``: Path/URL to training data or HF dataset +* ``val_data_source``: Path to validation data (optional) +* ``test_data_source``: Path to test data (optional) +* ``data_format``: Format: ``"csv"``, ``"parquet"``, ``"hub"``, ``"hf"`` + +**Columns** + +* ``sequence_column``: Column name containing peptide sequences +* ``label_column``: Column(s) name(s) containing label(s) (str or list) +* ``dataset_columns_to_keep``: Additional columns to retain in the dataset after processing, else extra columns are dropped to save memory. 
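A minimal sketch combining the column parameters above; the file path and the ``raw_file`` column are placeholders for illustration and should be replaced with names present in your data:

.. code-block:: python

    from dlomix.data import RetentionTimeDataset

    dataset = RetentionTimeDataset(
        data_source="data/peptides.parquet",        # placeholder path
        data_format="parquet",
        sequence_column="modified_sequence",
        # a single string or a list of strings is accepted
        label_column=["indexed_retention_time"],
        # columns listed here are retained after processing, other extra columns are dropped
        dataset_columns_to_keep=["raw_file"],
    )
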
+ +**Processing** + +* ``max_seq_len``: Maximum sequence length for padding +* ``pad``: Enable padding (default: True) +* ``padding_value``: Character for padding (default: ``"-"``) +* ``with_termini``: Add N/C termini markers (default: True) +* ``encoding_scheme``: ``"unmod"`` or ``"naive-mods"`` +* ``alphabet``: Dict mapping tokens to integers + +**Features** + +* ``features_to_extract``: List of feature names or functions +* ``model_features``: Features to include in tensor output that already exist in the provided data and are to be carried over as tensors to be fed into the model. + +**Training** + +* ``val_ratio``: Validation split ratio (0-1) +* ``batch_size``: Batch size for tensor datasets +* ``dataset_type``: ``"tf"`` or ``"pt"`` +* ``shuffle``: Shuffle data (default: False) + + +Best Practices +============== + +**Caching Strategy** + +* Enable HF caching (``disable_cache=False``) for repeated experiments +* Save processed datasets with ``save_to_disk()`` to save time when iterating over the same data +* Use ``enable_tf_dataset_cache=True`` only if dataset fits in the available memory + +**Batch Size Selection** + +* Start with meaningful defaults if you have limited GPU memory (e.g., 64 or 128) +* Increase based on GPU memory availability and utilization +* Reduce if encountering OOM (Out-of-Memory) errors + +**Validation Splits** + +* Prefer explicit ``val_data_source`` for consistent evaluation +* Always use a separate test dataset for final evaluation, can also be created independently using another Dataset instance with ``test_data_source`` only. + +**Feature Engineering** + +* List all features in ``model_features`` to include them in tensors +* Use custom extractors for domain-specific features + +**Performance** + +* Set ``num_proc`` to match available CPU cores +* Use Parquet format for large datasets +* Process data once and save with ``save_to_disk()`` diff --git a/src/dlomix/data/charge_state.py b/src/dlomix/data/charge_state.py index 6fbfeb6f..6068ce00 100644 --- a/src/dlomix/data/charge_state.py +++ b/src/dlomix/data/charge_state.py @@ -1,6 +1,5 @@ from typing import Callable, Dict, List, Optional, Union -from ..constants import ALPHABET_UNMOD from .dataset import PeptideDataset from .dataset_config import DatasetConfig from .dataset_utils import EncodingScheme @@ -27,7 +26,7 @@ class ChargeStateDataset(PeptideDataset): features_to_extract (Optional[List[Union[Callable, str]]]): The list of features to extract from the dataset. Default is None. pad (bool): Whether to pad the sequences to the maximum length. Default is True. padding_value (str): The value to use for padding. Default is '-'. - alphabet (Dict): The mapping of characters to integers for encoding the sequences. Default is ALPHABET_UNMOD. + alphabet (Optional[Dict]): The mapping of characters to integers for encoding the sequences. Default is None to trigger learning the alphabet. with_termini (bool): Whether to add the N- and C-termini in the sequence column, even if they do not exist. Defaults to True. encoding_scheme (Union[str, EncodingScheme]): The encoding scheme to use for encoding the sequences. Default is EncodingScheme.UNMOD. processed (bool): Whether the data has been preprocessed. Default is False. 
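Note: since the default ``alphabet`` changes from ``ALPHABET_UNMOD`` to ``None`` (learn the tokens from the data), code that relied on the old implicit default can keep the previous behaviour by passing the built-in alphabet explicitly. A hedged sketch, with a placeholder data path:

.. code-block:: python

    from dlomix.constants import ALPHABET_UNMOD
    from dlomix.data import ChargeStateDataset

    dataset = ChargeStateDataset(
        data_source="data.csv",       # placeholder path
        alphabet=ALPHABET_UNMOD,      # restores the previous implicit default
    )
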
@@ -57,7 +56,7 @@ def __init__( features_to_extract: Optional[List[Union[Callable, str]]] = None, pad: bool = True, padding_value: str = "-", - alphabet: Dict = ALPHABET_UNMOD, + alphabet: Optional[Dict] = None, with_termini: bool = True, encoding_scheme: Union[str, EncodingScheme] = EncodingScheme.UNMOD, processed: bool = False, diff --git a/src/dlomix/data/detectability.py b/src/dlomix/data/detectability.py index 4cc974ab..1f8fbdfe 100644 --- a/src/dlomix/data/detectability.py +++ b/src/dlomix/data/detectability.py @@ -1,6 +1,5 @@ from typing import Callable, Dict, List, Optional, Union -from ..constants import ALPHABET_UNMOD from .dataset import PeptideDataset from .dataset_config import DatasetConfig from .dataset_utils import EncodingScheme @@ -27,7 +26,7 @@ class DetectabilityDataset(PeptideDataset): features_to_extract (Optional[List[Union[Callable, str]]]): The list of features to extract from the dataset. Default is None. pad (bool): Whether to pad the sequences to the maximum length. Default is True. padding_value (str): The value to use for padding. Default is '-'. - alphabet (Dict): The mapping of characters to integers for encoding the sequences. Default is ALPHABET_UNMOD. + alphabet (Optional[Dict]): The mapping of characters to integers for encoding the sequences. Default is None to trigger learning the alphabet. with_termini (bool): Whether to add the N- and C-termini in the sequence column, even if they do not exist. Defaults to True. encoding_scheme (Union[str, EncodingScheme]): The encoding scheme to use for encoding the sequences. Default is EncodingScheme.UNMOD. processed (bool): Whether the data has been preprocessed. Default is False. @@ -57,7 +56,7 @@ def __init__( features_to_extract: Optional[List[Union[Callable, str]]] = None, pad: bool = True, padding_value: str = "-", - alphabet: Dict = ALPHABET_UNMOD, + alphabet: Optional[Dict] = None, with_termini: bool = True, encoding_scheme: Union[str, EncodingScheme] = EncodingScheme.UNMOD, processed: bool = False, diff --git a/src/dlomix/data/fragment_ion_intensity.py b/src/dlomix/data/fragment_ion_intensity.py index 03e30403..fa0087b2 100644 --- a/src/dlomix/data/fragment_ion_intensity.py +++ b/src/dlomix/data/fragment_ion_intensity.py @@ -1,6 +1,5 @@ from typing import Callable, Dict, List, Optional, Union -from ..constants import ALPHABET_UNMOD from .dataset import PeptideDataset from .dataset_config import DatasetConfig from .dataset_utils import EncodingScheme @@ -31,7 +30,7 @@ class FragmentIonIntensityDataset(PeptideDataset): features_to_extract (Optional[List[Union[Callable, str]]]): The list of features to extract from the dataset. Default is None. pad (bool): Whether to pad the sequences to the maximum length. Default is True. padding_value (str): The value to use for padding. Default is '-'. - alphabet (Dict): The mapping of characters to integers for encoding the sequences. Default is ALPHABET_UNMOD. + alphabet (Optional[Dict]): The mapping of characters to integers for encoding the sequences. Default is None to trigger learning the alphabet. with_termini (bool): Whether to add the N- and C-termini in the sequence column, even if they do not exist. Defaults to True. encoding_scheme (Union[str, EncodingScheme]): The encoding scheme to use for encoding the sequences. Default is EncodingScheme.UNMOD. processed (bool): Whether the data has been preprocessed before or not. Default is False. 
@@ -62,7 +61,7 @@ def __init__( features_to_extract: Optional[List[Union[Callable, str]]] = None, pad: bool = True, padding_value: str = "-", - alphabet: Dict = ALPHABET_UNMOD, + alphabet: Optional[Dict] = None, with_termini: bool = True, encoding_scheme: Union[str, EncodingScheme] = EncodingScheme.UNMOD, processed: bool = False, diff --git a/src/dlomix/data/ion_mobility.py b/src/dlomix/data/ion_mobility.py index abd6f079..d59398ed 100644 --- a/src/dlomix/data/ion_mobility.py +++ b/src/dlomix/data/ion_mobility.py @@ -28,7 +28,7 @@ class IonMobilityDataset(PeptideDataset): features_to_extract (Optional[List[Union[Callable, str]]]): The features to extract from the dataset. Defaults to None. pad (bool): Whether to pad sequences to the maximum length. Defaults to True. padding_value (str): The value to use for padding sequences. Defaults to '-'. - alphabet (Dict): The alphabet used for encoding sequences. Defaults to ALPHABET_UNMOD. + alphabet (Optional[Dict]): The alphabet used for encoding sequences. Defaults to None to trigger learning the alphabet. with_termini (bool): Whether to add the N- and C-termini in the sequence column, even if they do not exist. Defaults to True. encoding_scheme (Union[str, EncodingScheme]): The encoding scheme to use for sequences. Defaults to EncodingScheme.UNMOD. processed (bool): Whether the dataset has been preprocessed. Defaults to False. diff --git a/src/dlomix/data/retention_time.py b/src/dlomix/data/retention_time.py index 42baead5..09f77d73 100644 --- a/src/dlomix/data/retention_time.py +++ b/src/dlomix/data/retention_time.py @@ -1,6 +1,5 @@ from typing import Callable, Dict, List, Optional, Union -from ..constants import ALPHABET_UNMOD from .dataset import PeptideDataset from .dataset_config import DatasetConfig from .dataset_utils import EncodingScheme @@ -27,7 +26,7 @@ class RetentionTimeDataset(PeptideDataset): features_to_extract (Optional[List[Union[Callable, str]]]): The features to extract from the dataset. Defaults to None. pad (bool): Whether to pad sequences to the maximum length. Defaults to True. padding_value (str): The value to use for padding sequences. Defaults to '-'. - alphabet (Dict): The alphabet used for encoding sequences. Defaults to ALPHABET_UNMOD. + alphabet (Optional[Dict]): The alphabet used for encoding sequences. Defaults to None to trigger learning the alphabet. with_termini (bool): Whether to add the N- and C-termini in the sequence column, even if they do not exist. Defaults to True. encoding_scheme (Union[str, EncodingScheme]): The encoding scheme to use for sequences. Defaults to EncodingScheme.UNMOD. processed (bool): Whether the dataset has been preprocessed. Defaults to False. 
@@ -57,7 +56,7 @@ def __init__( features_to_extract: Optional[List[Union[Callable, str]]] = None, pad: bool = True, padding_value: str = "-", - alphabet: Dict = ALPHABET_UNMOD, + alphabet: Optional[Dict] = None, with_termini: bool = True, encoding_scheme: Union[str, EncodingScheme] = EncodingScheme.UNMOD, processed: bool = False, From c51f9a228a737e99d56ab618a40914d4f35a2e23 Mon Sep 17 00:00:00 2001 From: Omar Shouman Date: Sat, 31 Jan 2026 13:11:51 +0100 Subject: [PATCH 3/9] Feature/revisit tf dataset pipeline (#83) * cache - shuffle - batch - prefetch order * fixes related to termini - hf label column future warning * pr comments + tests --- src/dlomix/data/dataset.py | 53 ++++++++++------- src/dlomix/data/dataset_config.py | 16 ++++++ src/dlomix/data/processing/processors.py | 2 +- tests/conftest.py | 1 + tests/test_datasets.py | 72 ++++++++++++++++++++++++ tests/test_torch_dataset.py | 7 ++- 6 files changed, 126 insertions(+), 25 deletions(-) diff --git a/src/dlomix/data/dataset.py b/src/dlomix/data/dataset.py index d9e0b69d..56610ddc 100644 --- a/src/dlomix/data/dataset.py +++ b/src/dlomix/data/dataset.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Union +import tensorflow as tf from datasets import Dataset, DatasetDict, Sequence, Value, load_dataset from .dataset_config import DatasetConfig @@ -120,15 +121,6 @@ def __init__(self, dataset_config: DatasetConfig, **kwargs): self.encoding_scheme = EncodingScheme(dataset_config.encoding_scheme) - if isinstance(dataset_config.label_column, str): - self.label_column = [dataset_config.label_column] - elif isinstance(dataset_config.label_column, list): - self.label_column = dataset_config.label_column - else: - raise ValueError( - "The label_column parameter should be a string or a list of strings." - ) - self._set_hf_cache_management() self.extended_alphabet = {} @@ -384,18 +376,11 @@ def _configure_padding_step(self): ) return - if self.max_seq_len > 0: - seq_len = self.max_seq_len - else: - raise ValueError( - f"Max sequence length provided is an integer but not a valid value: {self.max_seq_len}, only positive non-zero values are allowed." 
- ) - padding_processor = SequencePaddingProcessor( sequence_column_name=self.sequence_column, batched=True, padding_index=self.extended_alphabet[self.padding_value], - max_length=seq_len, + max_length=self.max_seq_len + 2 if self.with_termini else self.max_seq_len, ) self._processors.append(padding_processor) @@ -421,7 +406,9 @@ def _configure_feature_extraction_step(self): ], feature_column_name=feature_name, **FEATURE_EXTRACTORS_PARAMETERS[feature_name], - max_length=self.max_seq_len, + max_length=self.max_seq_len + 2 + if self.with_termini + else self.max_seq_len, batched=True, ) elif isinstance(feature, Callable): @@ -708,6 +695,7 @@ def tensor_train_data(self): if self.dataset_type == "pt": return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[0]) else: + dataset_len = len(self.hf_dataset[PeptideDataset.DEFAULT_SPLIT_NAMES[0]]) tf_dataset = self._get_split_tf_dataset( PeptideDataset.DEFAULT_SPLIT_NAMES[0] ) @@ -715,6 +703,17 @@ def tensor_train_data(self): if self.enable_tf_dataset_cache: tf_dataset = tf_dataset.cache() + if self.shuffle: + tf_dataset = tf_dataset.shuffle( + buffer_size=min(10000, dataset_len), + reshuffle_each_iteration=True, + ) + + # Batch the data + tf_dataset = tf_dataset.batch(self.batch_size) + + tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE) + return tf_dataset @property @@ -730,6 +729,9 @@ def tensor_val_data(self): if self.enable_tf_dataset_cache: tf_dataset = tf_dataset.cache() + tf_dataset = tf_dataset.batch(self.batch_size) + tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE) + return tf_dataset @property @@ -745,6 +747,9 @@ def tensor_test_data(self): if self.enable_tf_dataset_cache: tf_dataset = tf_dataset.cache() + tf_dataset = tf_dataset.batch(self.batch_size) + tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE) + return tf_dataset def _check_if_split_exists(self, split_name: str): @@ -758,11 +763,17 @@ def _check_if_split_exists(self, split_name: str): def _get_split_tf_dataset(self, split_name: str): self._check_if_split_exists(split_name) + # to return a tuple if it is a single label column and be compatible with HF Datasets API updates + + label_cols = self.label_column + + if isinstance(self.label_column, list) and len(self.label_column) == 1: + label_cols = self.label_column[0] + return self.hf_dataset[split_name].to_tf_dataset( columns=self._get_input_tensor_column_names(), - label_cols=self.label_column, - shuffle=self.shuffle, - batch_size=self.batch_size, + label_cols=label_cols, + shuffle=False, ) def _get_split_torch_dataset(self, split_name: str): diff --git a/src/dlomix/data/dataset_config.py b/src/dlomix/data/dataset_config.py index cf629f8a..d4898bcb 100644 --- a/src/dlomix/data/dataset_config.py +++ b/src/dlomix/data/dataset_config.py @@ -39,6 +39,22 @@ class DatasetConfig: batch_processing_size: int torch_dataloader_kwargs: Optional[Dict] = field(default_factory=dict) + # validate input parameters + def __post_init__(self): + # sequence length validation + if self.max_seq_len <= 0: + raise ValueError( + f"Max sequence length provided is an integer but not a valid value: {self.max_seq_len}, only positive non-zero values are allowed." + ) + + # label column validation, either a string or a list of strings + if not isinstance(self.label_column, (str, list)): + raise ValueError( + "The label_column parameter should be a string or a list of strings." 
+ ) + elif isinstance(self.label_column, str): + self.label_column = [self.label_column] + def save_config_json(self, path: str): """ Save the configuration to a json file. diff --git a/src/dlomix/data/processing/processors.py b/src/dlomix/data/processing/processors.py index 9c632385..405ec939 100644 --- a/src/dlomix/data/processing/processors.py +++ b/src/dlomix/data/processing/processors.py @@ -79,7 +79,7 @@ class SequenceParsingProcessor(PeptideDatasetBaseProcessor): >>> processor = SequenceParsingProcessor("sequence") >>> data = {"sequence": "[]-IGGPC[UNIMOD:4]AHC[UNIMOD:4]AAWEGVR-[]"} >>> processor(data) - {'sequence': ['[]', 'I', 'G', 'G', 'P', 'C', 'A', 'H', 'C', 'A', 'A', 'W', 'E', 'G', 'V', 'R', '-[]'], + {'sequence': ['[]-', 'I', 'G', 'G', 'P', 'C', 'A', 'H', 'C', 'A', 'A', 'W', 'E', 'G', 'V', 'R', '-[]'], '_parsed_sequence': ['I', 'G', 'G', 'P', 'C', 'A', 'H', 'C', 'A', 'A', 'W', 'E', 'G', 'V', 'R'], '_n_term_mods': '[]-', '_c_term_mods': '-[]'} diff --git a/tests/conftest.py b/tests/conftest.py index dea91a47..f3a45e42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ "seq": ["[UNIMOD:737]-DASAQTTSHELTIPN-[]", "[UNIMOD:737]-DLHTGRLC[UNIMOD:4]-[]"], "nested_feature": [[[30, 64]], [[25, 35]]], "label": [0.1, 0.2], + "label2": [1.0, 2.0], } TEST_ASSETS_TO_DOWNLOAD = [ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e0bda501..41658a85 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -347,3 +347,75 @@ def test_torch_dataloader_kwargs(): assert dataloader.pin_memory is False assert dataloader.num_workers == 0 assert dataset.torch_dataloader_kwargs is not None + + +def test_tf_tensor_dataset_string_label(): + """Test that TensorFlow TensorDataset is created properly.""" + hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + + dataset = FragmentIonIntensityDataset( + data_format="hf", + data_source=hfdata, + sequence_column="seq", + label_column="label", + dataset_type="tf", + batch_size=1, + ) + + # Get the TensorFlow dataset + tf_dataset = dataset.tensor_train_data + + # Verify that the TensorFlow dataset is created successfully + assert tf_dataset is not None + for batch in tf_dataset.take(1): + features, labels = batch + assert features is not None + assert labels is not None + + +def test_tf_tensor_dataset_singelton_list_label(): + """Test that TensorFlow TensorDataset is created properly.""" + hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + + dataset = FragmentIonIntensityDataset( + data_format="hf", + data_source=hfdata, + sequence_column="seq", + label_column=["label"], + dataset_type="tf", + batch_size=1, + ) + + # Get the TensorFlow dataset + tf_dataset = dataset.tensor_train_data + + # Verify that the TensorFlow dataset is created successfully + assert tf_dataset is not None + for batch in tf_dataset.take(1): + features, labels = batch + assert features is not None + assert labels is not None + + +def test_tf_tensor_dataset_list_multi_label(): + """Test that TensorFlow TensorDataset is created properly.""" + hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + + dataset = FragmentIonIntensityDataset( + data_format="hf", + data_source=hfdata, + sequence_column="seq", + label_column=["label", "label2"], + dataset_type="tf", + batch_size=1, + ) + + # Get the TensorFlow dataset + tf_dataset = dataset.tensor_train_data + + # Verify that the TensorFlow dataset is created successfully + assert tf_dataset is not None + for batch in 
tf_dataset.take(1): + features, labels = batch + assert features is not None + assert labels is not None diff --git a/tests/test_torch_dataset.py b/tests/test_torch_dataset.py index c5215cac..7e67f469 100644 --- a/tests/test_torch_dataset.py +++ b/tests/test_torch_dataset.py @@ -21,6 +21,7 @@ def test_dataset_torch(): dataset_type="pt", batch_size=2, max_seq_len=15, + with_termini=False, ) logger.info(intensity_dataset) @@ -31,10 +32,10 @@ def test_dataset_torch(): logger.info(batch) - assert list(batch["nested_feature"].shape) == [1, 1, 2] - assert list(batch["seq"].shape) == [1, 15] + assert list(batch["nested_feature"].shape) == [2, 1, 2] + assert list(batch["seq"].shape) == [2, 15] assert list(batch["label"].shape) == [ - 1, + 2, ] assert batch["seq"].dtype == torch.int64 From 01dadc6973654e93a52eed6573af0cbc57cfd1ae Mon Sep 17 00:00:00 2001 From: omsh Date: Sat, 31 Jan 2026 13:13:44 +0100 Subject: [PATCH 4/9] dev version for better version management --- src/dlomix/_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dlomix/_metadata.py b/src/dlomix/_metadata.py index afc0b3de..042beac4 100644 --- a/src/dlomix/_metadata.py +++ b/src/dlomix/_metadata.py @@ -1,4 +1,4 @@ -__version__ = "0.2.3" +__version__ = "0.2.4.dev0" __author__ = "Wilhelm Lab" __author_email__ = "o.shouman@tum.de" __license__ = "MIT" From 44609eaeb5efd386c622cc84485308b8af3887b0 Mon Sep 17 00:00:00 2001 From: Omar Shouman Date: Mon, 2 Feb 2026 21:07:51 +0100 Subject: [PATCH 5/9] Fix/prosit dynamic creation refactoring - fixes feature extractor - improved tests (#84) * prosit model changes and tests * fix in feature extraction + tests * refined run scripts for intensity * refactored tests --- run_scripts/run_prosit_intensity.py | 29 +- run_scripts/run_prosit_intensity_ptms.py | 36 +- .../run_prosit_intensity_ptms_torch.py | 36 +- run_scripts/run_prosit_intensity_torch.py | 34 +- .../data/processing/feature_extractors.py | 16 +- src/dlomix/models/prosit.py | 285 ++++++++----- src/dlomix/models/prosit_torch.py | 310 ++++++++++---- tests/conftest.py | 177 +++++++- tests/test_datasets.py | 67 ++- tests/test_models.py | 33 +- tests/test_processors.py | 399 +++++++++++++++++- tests/test_torch_dataset.py | 5 +- tests/test_torch_models.py | 9 +- 13 files changed, 1105 insertions(+), 331 deletions(-) diff --git a/run_scripts/run_prosit_intensity.py b/run_scripts/run_prosit_intensity.py index 0dbf77da..e4ccb942 100644 --- a/run_scripts/run_prosit_intensity.py +++ b/run_scripts/run_prosit_intensity.py @@ -16,8 +16,6 @@ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) -# consider the use-case for starting from a saved model - optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001) # TRAIN_DATAPATH = "../example_dataset/intensity/intensity_data.parquet" @@ -28,14 +26,18 @@ data_source=TRAIN_DATAPATH, sequence_column="modified_sequence", label_column="intensities_raw", - # model_features=["precursor_charge_onehot", "collision_energy_aligned_normed"], + model_features=["precursor_charge_onehot", "collision_energy_aligned_normed"], max_seq_len=30, batch_size=128, val_ratio=0.2, - with_termini=False, + with_termini=True, + alphabet=None, ) print(d) +print(d["train"][0]) +print(d.extended_alphabet) + model = PrositIntensityPredictor( seq_length=30, @@ -44,18 +46,23 @@ # "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", # "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", }, - # meta_data_keys={ - # "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", - # 
"PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", - # }, - with_termini=False, + meta_data_keys={ + "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", + "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", + }, + with_termini=True, + alphabet=d.extended_alphabet, + use_meta_data=True, ) model.compile(optimizer=optimizer, loss=masked_spectral_distance, metrics=["mse"]) -weights_file = "./run_scripts/output/prosit_intensity_test" +print(model) + +weights_file = "./run_scripts/output/prosit_intensity_test.keras" checkpoint = tf.keras.callbacks.ModelCheckpoint( - weights_file, save_best_only=True, save_weights_only=True + weights_file, + save_best_only=True, ) decay = tf.keras.callbacks.ReduceLROnPlateau( monitor="val_loss", factor=0.1, patience=10, verbose=1, min_lr=0 diff --git a/run_scripts/run_prosit_intensity_ptms.py b/run_scripts/run_prosit_intensity_ptms.py index b91113da..38cb1775 100644 --- a/run_scripts/run_prosit_intensity_ptms.py +++ b/run_scripts/run_prosit_intensity_ptms.py @@ -13,8 +13,25 @@ ) +TRAIN_DATAPATH = "example_dataset/intensity/third_pool_processed_sample.parquet" + +d = FragmentIonIntensityDataset( + data_source=TRAIN_DATAPATH, + max_seq_len=30, + batch_size=128, + val_ratio=0.2, + model_features=["collision_energy_aligned_normed", "precursor_charge_onehot"], + sequence_column="modified_sequence", + label_column="intensities_raw", + # features_to_extract=["mod_loss", "delta_mass"], + features_to_extract=["delta_mass"], + with_termini=False, + alphabet=None, + encoding_scheme="naive-mods", +) + model = PrositIntensityPredictor( - seq_length=32, + seq_length=30, use_prosit_ptm_features=True, input_keys={ "SEQUENCE_KEY": "modified_sequence", @@ -23,25 +40,12 @@ "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", }, - with_termini=True, + with_termini=False, + alphabet=d.extended_alphabet, ) optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001) -TRAIN_DATAPATH = "example_dataset/intensity/third_pool_processed_sample.parquet" - -d = FragmentIonIntensityDataset( - data_source=TRAIN_DATAPATH, - max_seq_len=32, - batch_size=128, - val_ratio=0.2, - model_features=["collision_energy_aligned_normed", "precursor_charge_onehot"], - sequence_column="modified_sequence", - label_column="intensities_raw", - features_to_extract=["mod_loss", "delta_mass"], - with_termini=True, -) - model.compile(optimizer=optimizer, loss=masked_spectral_distance, metrics=["mse"]) weights_file = "./run_scripts/output/prosit_intensity_test_ptms" diff --git a/run_scripts/run_prosit_intensity_ptms_torch.py b/run_scripts/run_prosit_intensity_ptms_torch.py index 1f815f55..16b3413f 100644 --- a/run_scripts/run_prosit_intensity_ptms_torch.py +++ b/run_scripts/run_prosit_intensity_ptms_torch.py @@ -21,38 +21,40 @@ BATCH_SIZE = 8 N_EPOCHS = 20 -model = PrositIntensityPredictor( - seq_length=32, - use_prosit_ptm_features=True, - input_keys={ - "SEQUENCE_KEY": "modified_sequence", - }, - meta_data_keys={ - "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", - "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", - }, - with_termini=True, -) - -optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001) TRAIN_DATAPATH = "example_dataset/intensity/third_pool_processed_sample.parquet" d = FragmentIonIntensityDataset( data_source=TRAIN_DATAPATH, - max_seq_len=32, + max_seq_len=30, batch_size=BATCH_SIZE, val_ratio=0.2, model_features=["collision_energy_aligned_normed", "precursor_charge_onehot"], sequence_column="modified_sequence", 
label_column="intensities_raw", - features_to_extract=["mod_loss", "delta_mass"], + features_to_extract=["delta_mass"], dataset_type="pt", - with_termini=True, + with_termini=False, ) print(d) +model = PrositIntensityPredictor( + seq_length=30, + use_prosit_ptm_features=True, + input_keys={ + "SEQUENCE_KEY": "modified_sequence", + }, + meta_data_keys={ + "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", + "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", + }, + with_termini=False, + alphabet=d.extended_alphabet, +) + +optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001) + loss_criterion = masked_spectral_distance for epoch in tqdm(range(0, N_EPOCHS)): diff --git a/run_scripts/run_prosit_intensity_torch.py b/run_scripts/run_prosit_intensity_torch.py index 9c0fd35d..8bd3c354 100644 --- a/run_scripts/run_prosit_intensity_torch.py +++ b/run_scripts/run_prosit_intensity_torch.py @@ -20,22 +20,6 @@ BATCH_SIZE = 128 -model = PrositIntensityPredictor( - seq_length=30, - use_prosit_ptm_features=False, - input_keys={ - "SEQUENCE_KEY": "modified_sequence", - }, - # meta_data_keys={ - # "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", - # "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", - # }, - with_termini=False, - # alphabet=ALPHABET_UNMOD, -) - -optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001) - TRAIN_DATAPATH = "example_dataset/intensity/third_pool_processed_sample.parquet" d = FragmentIonIntensityDataset( @@ -48,12 +32,28 @@ label_column="intensities_raw", # features_to_extract=["mod_loss", "delta_mass"], dataset_type="pt", - # alphabet=ALPHABET_UNMOD, + alphabet=None, with_termini=False, ) print(d) +model = PrositIntensityPredictor( + seq_length=30, + use_prosit_ptm_features=False, + input_keys={ + "SEQUENCE_KEY": "modified_sequence", + }, + # meta_data_keys={ + # "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", + # "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", + # }, + with_termini=False, + alphabet=d.extended_alphabet, +) + +optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001) + loss_criterion = masked_spectral_distance N_EPOCHS = 20 diff --git a/src/dlomix/data/processing/feature_extractors.py b/src/dlomix/data/processing/feature_extractors.py index 65d2997a..5995463b 100644 --- a/src/dlomix/data/processing/feature_extractors.py +++ b/src/dlomix/data/processing/feature_extractors.py @@ -188,9 +188,21 @@ def single_process(self, input_data, **kwargs): return {self.feature_column_name: feature} def _extract_feature(self, sequence): + print("sequence:", sequence) + print("sequence length:", len(sequence)) + print("max_length:", self.max_length) + + # we lookup unttil the max length only because some sequences in train/val can be longer + lookup_length = min(self.max_length, len(sequence)) + sequence_for_lookup = sequence[:lookup_length] + feature = np.empty((self.max_length, *self._feature_shape), dtype=np.float32) - feature[: len(sequence)] = itemgetter(*sequence)(self.lookup_table) + print("feature shape:", feature.shape) + + feature[:lookup_length] = itemgetter(*sequence_for_lookup)(self.lookup_table) + + feature = self.pad_feature_to_seq_length(feature, self.max_length) - feature = self.pad_feature_to_seq_length(feature, len(sequence)) + print("DONE feature shape:", feature.shape) return feature diff --git a/src/dlomix/models/prosit.py b/src/dlomix/models/prosit.py index efc9cdd4..1b34d287 100644 --- a/src/dlomix/models/prosit.py +++ b/src/dlomix/models/prosit.py @@ -1,5 +1,4 @@ import logging -import 
warnings from collections.abc import Sequence import tensorflow as tf @@ -103,54 +102,75 @@ def call(self, inputs, **kwargs): @tf.keras.utils.register_keras_serializable(package="dlomix") class PrositIntensityPredictor(tf.keras.Model): """ - Prosit model for intensity prediction. + Prosit model for intensity prediction with configurable branches for PTM features and metadata. Parameters ---------- + input_keys : dict, optional + Dictionary mapping for the input keys to look for in the input dict. Defaults to None, which uses default required keys only "seqeuence". + meta_data_keys : list or dict, optional + List or dict of keys for the meta data inputs to use. Defaults to None (no meta data). + alphabet : dict, optional + Dictionary mapping for the alphabet (the amino acids in this case). Defaults to ALPHABET_UNMOD. + with_termini : bool, optional + Whether to include terminal tokens in the sequence embedding. Defaults to False. embedding_output_dim : int, optional Size of the embeddings to use. Defaults to 16. seq_length : int, optional Sequence length of the peptide sequences. Defaults to 30. - alphabet : dict, optional - Dictionary mapping for the vocabulary (the amino acids in this case). Defaults to None, which is mapped to `ALPHABET_UNMOD`. + len_fragment_ion : int, optional + Number of fragment ions to predict. Defaults to 6. dropout_rate : float, optional - Probability to use for dropout layers in the encoder. Defaults to 0.5. + Probability to use for dropout layers in the encoder. Defaults to 0.2. latent_dropout_rate : float, optional Probability to use for dropout layers in the regressor layers after encoding. Defaults to 0.1. recurrent_layers_sizes : tuple, optional A tuple of 2 values for the sizes of the two GRU layers in the encoder. Defaults to (256, 512). regressor_layer_size : int, optional Size of the dense layer in the regressor after the encoder. Defaults to 512. - use_prosit_ptm_features : boolean, optional - Whether to use PTM features and create corresponding layers, has to be aligned with input_keys. Defaults to False. - input_keys : dict, optional - Dict of string keys and values mapping a fixed key to a value key in the inputs dict from the dataset class. Defaults to None, which corresponds then to the required default input keys `DEFAULT_INPUT_KEYS`. - meta_data_keys : list, optional - List of string values corresponding to fixed keys in the inputs dict that are considered meta data. Defaults to None, which corresponds then to the default meta data keys `META_DATA_KEYS`. - with_termini : boolean, optional - Whether to consider the termini in the sequence. Defaults to True. + use_meta_data : bool, optional + Whether to use meta data inputs. Defaults to False. + use_prosit_ptm_features : bool, optional + Whether to use Prosit PTM features as input. Defaults to False. use_instrument_embedding : bool, optional - Whether to include an additional embedding layer for instrument type information. - When enabled, the model expects an `instrument_type` input key and embeds it as an additional meta feature. Defaults to False. + Whether to use instrument type embedding as part of the meta data. Defaults to False. instrument_input_dim : int, optional - The number of distinct instrument types (vocabulary size) used as input to the instrument embedding. Defaults to 3. + Number of unique instrument types for embedding. Defaults to 3. instrument_output_dim : int, optional - The dimensionality of the instrument embedding vector. Defaults to 2. 
+ Size of the instrument type embedding. Defaults to 2. + + Attributes ---------- + + REQUIRED_INPUT_SEQUENCE_KEY : str + Key for the required sequence input in the input dictionary. DEFAULT_INPUT_KEYS : dict - Default keys for the input dict. + Default mapping of input keys for various inputs. META_DATA_KEYS : list - List of meta data keys. + List of possible meta data keys that can be used. PTM_INPUT_KEYS : list - List of PTM feature keys. + List of keys for the PTM feature inputs. See `dlomix.data.processing.feature_extractors.FEATURE_EXTRACTORS_PARAMETERS` for details. + + + Notes + ----- + This model is a flexible implementation of the Prosit intensity predictor, allowing for optional + inclusion of PTM features and meta data inputs. The model architecture consists of embedding layers, + bidirectional GRU encoders, attention mechanisms, and dense regressor layers. + + + + + + """ - # consider using kwargs in the call function instead ! + REQUIRED_INPUT_SEQUENCE_KEY = "SEQUENCE_KEY" DEFAULT_INPUT_KEYS = { - "SEQUENCE_KEY": "sequence", + REQUIRED_INPUT_SEQUENCE_KEY: "sequence", "COLLISION_ENERGY_KEY": "collision_energy", "PRECURSOR_CHARGE_KEY": "precursor_charge", "FRAGMENTATION_TYPE_KEY": "fragmentation_type", @@ -170,18 +190,19 @@ class PrositIntensityPredictor(tf.keras.Model): def __init__( self, + input_keys=None, + meta_data_keys=None, + alphabet=None, + with_termini=False, embedding_output_dim=16, seq_length=30, - len_fion=6, - alphabet=None, + len_fragment_ion=6, dropout_rate=0.2, latent_dropout_rate=0.1, recurrent_layers_sizes=(256, 512), regressor_layer_size=512, + use_meta_data=False, use_prosit_ptm_features=False, - input_keys=None, - meta_data_keys=None, - with_termini=True, use_instrument_embedding=False, instrument_input_dim=3, instrument_output_dim=2, @@ -193,73 +214,105 @@ def __init__( self.dropout_rate = dropout_rate self.latent_dropout_rate = latent_dropout_rate self.regressor_layer_size = regressor_layer_size - self.recurrent_layers_sizes = tuple( - recurrent_layers_sizes - ) # Ensure it's a tuple + self.recurrent_layers_sizes = tuple(recurrent_layers_sizes) self.embedding_output_dim = embedding_output_dim - self.seq_length = seq_length - self.len_fion = len_fion + self.raw_seq_length = seq_length + self.len_fragment_ion = len_fragment_ion self.use_prosit_ptm_features = use_prosit_ptm_features - self.with_termini = with_termini # Store this for serialization + self.with_termini = with_termini + self.use_meta_data = use_meta_data self.use_instrument_embedding = use_instrument_embedding self.instrument_input_dim = instrument_input_dim self.instrument_output_dim = instrument_output_dim - # Handle alphabet - store the actual alphabet used + # handle default and fallback attributes + self._handle_alphabet_and_keys(alphabet, input_keys, meta_data_keys) + + self._validate_config() + + self._compute_attributes() + + # Build layers + self._build_embedding_layers() + self._build_encoders() + self._build_decoder() + self.attention = AttentionLayer(name="encoder_att") + self._build_meta_data_fusion_layer() + self._build_regressor() + + def _handle_alphabet_and_keys(self, alphabet, input_keys, meta_data_keys): + # Handle alphabet if alphabet is not None: - self.alphabet = dict(alphabet) # Make a copy + self.alphabet = dict(alphabet) else: - self.alphabet = dict(ALPHABET_UNMOD) # Make a copy of default + self.alphabet = dict(ALPHABET_UNMOD) # fallback to unmodified amino acids - # Handle input_keys - store the actual keys used - if input_keys is not None: - self.input_keys = 
dict(input_keys) # Make a copy + # Handle input_keys + if input_keys: + self.input_keys = dict(input_keys) else: - self.input_keys = dict(self.DEFAULT_INPUT_KEYS) # Use default + self.input_keys = { + self.REQUIRED_INPUT_SEQUENCE_KEY: "sequence" + } # minimal required key for the sequence input - # Handle meta_data_keys - store the actual keys used - if meta_data_keys is not None: + # Handle meta_data_keys + if not meta_data_keys: + self.meta_data_keys = [] + else: # handle dict in case user passed a mapping, take the values in a list to use for lookup if isinstance(meta_data_keys, dict): self.meta_data_keys = list(meta_data_keys.values()) + elif isinstance(meta_data_keys, list): + self.meta_data_keys = list(meta_data_keys) else: - self.meta_data_keys = list(meta_data_keys) # Make a copy - else: - self.meta_data_keys = list( - [self.input_keys.get(k) for k in self.META_DATA_KEYS] - ) # Use default + raise ValueError( + "meta_data_keys should be either a list of strings or a dict mapping. Provided type: " + f"{type(meta_data_keys)}, and value: {meta_data_keys}" + ) + def _validate_config(self): + if self.use_meta_data and not self.meta_data_keys: + raise ValueError( + "use_meta_data=True requires meta_data_keys to be provided as a list of keys." + ) + + if ( + self.use_instrument_embedding + and "INSTRUMENT_TYPE_KEY" not in self.input_keys + ): + raise ValueError( + "use_instrument_embedding=True requires 'INSTRUMENT_TYPE_KEY' in input_keys" + ) + + def _compute_attributes(self): # Compute derived attributes (will be recomputed during deserialization) - self.max_ion = self.seq_length - 1 - if self.with_termini: - self.max_ion = self.max_ion - 2 + self.max_ion = self.raw_seq_length - 1 + + self.seq_length = ( + self.raw_seq_length + 2 if self.with_termini else self.raw_seq_length + ) # tie the count of embeddings to the size of the vocabulary (count of amino acids) - self.embeddings_count = len(self.alphabet) + 1 + self.embeddings_count = len(self.alphabet) - # Build layers + def _build_embedding_layers(self): self.embedding = tf.keras.layers.Embedding( input_dim=self.embeddings_count, output_dim=self.embedding_output_dim, - input_length=seq_length, + name="sequence_embedding", ) + self.instrument_embedding = None if self.use_instrument_embedding: self.instrument_embedding = tf.keras.layers.Embedding( input_dim=self.instrument_input_dim, output_dim=self.instrument_output_dim, name="instrument_embedding", ) - else: - self.instrument_embedding = None - - self._build_encoders() - self._build_decoder() - - self.attention = AttentionLayer(name="encoder_att") + def _build_meta_data_fusion_layer(self): self.meta_data_fusion_layer = None - if self.meta_data_keys: + if self.use_meta_data: self.meta_data_fusion_layer = tf.keras.Sequential( [ tf.keras.layers.Multiply(name="add_meta"), @@ -267,16 +320,6 @@ def __init__( ] ) - self.regressor = tf.keras.Sequential( - [ - tf.keras.layers.TimeDistributed( - tf.keras.layers.Dense(self.len_fion), name="time_dense" - ), - tf.keras.layers.LeakyReLU(name="activation"), - tf.keras.layers.Flatten(name="out"), - ] - ) - def _build_encoders(self): # sequence encoder -> always present self.sequence_encoder = tf.keras.Sequential( @@ -296,7 +339,7 @@ def _build_encoders(self): # meta data encoder -> optional, only if meta data keys are provided self.meta_encoder = None - if self.meta_data_keys: + if self.use_meta_data: self.meta_encoder = tf.keras.Sequential( [ tf.keras.layers.Concatenate(name="meta_in"), @@ -338,73 +381,102 @@ def _build_decoder(self): ] ) + def 
_build_regressor(self): + self.regressor = tf.keras.Sequential( + [ + tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(self.len_fragment_ion), name="time_dense" + ), + tf.keras.layers.LeakyReLU(name="activation"), + tf.keras.layers.Flatten(name="out"), + ] + ) + def call(self, inputs, **kwargs): + # Handle dict input, complex case: multiple inputs + if isinstance(inputs, dict): + return self._forward_dict(inputs, **kwargs) + + # Handle single input, simple case: sequence only + peptides_in = inputs + return self._forward_sequence_only(peptides_in, **kwargs) + + def _forward_sequence_only(self, peptides_in, **kwargs): + x = self.embedding(peptides_in) + x = self.sequence_encoder(x) + x = self.attention(x) + x = tf.expand_dims(x, axis=1) # add meta data dimension + x = self.decoder(x) + x = self.regressor(x) + return x + + def _forward_dict(self, inputs, **kwargs): + missing_input_keys = [k for k in self.input_keys.values() if k not in inputs] + if missing_input_keys: + raise ValueError(f"Missing required input keys: {missing_input_keys}") + + meta_data = [] encoded_meta = None encoded_ptm = None - if not isinstance(inputs, dict): - # when inputs has (seq, target), it comes as tuple - peptides_in = inputs - else: - peptides_in = inputs.get(self.input_keys["SEQUENCE_KEY"]) + # collect instrument embedding if enabled and add to meta data + if self.use_instrument_embedding: + instrument_type = inputs.get(self.input_keys["INSTRUMENT_TYPE_KEY"]) + instrument_embedded = self.instrument_embedding(instrument_type) + meta_data.append(instrument_embedded) + + # collect meta data from the input dict + if self.use_meta_data: + missing_meta_keys = [k for k in self.meta_data_keys if k not in inputs] + if missing_meta_keys: + raise ValueError( + f"Missing required metadata inputs: {missing_meta_keys}" + ) - # read meta data from the input dict - # note that the value here is the key to use in the inputs dict passed from the dataset - meta_data = self._collect_values_from_inputs_if_exists( - inputs, self.meta_data_keys + meta_data.extend( + self._collect_values_from_inputs_if_exists(inputs, self.meta_data_keys) ) - if self.use_instrument_embedding: - instrument_type = inputs.get(self.input_keys["INSTRUMENT_TYPE_KEY"]) - if instrument_type is not None: - instrument_embedded = self.instrument_embedding(instrument_type) - meta_data.append(instrument_embedded) - - if self.meta_encoder and len(meta_data) > 0: - encoded_meta = self.meta_encoder(meta_data) - else: + # collect PTM features from the input dict + if self.use_prosit_ptm_features: + ptm_keys_exist = [k for k in self.PTM_INPUT_KEYS if k in inputs] + if not ptm_keys_exist: raise ValueError( - f"Following metadata keys were specified when creating the model: {self.meta_data_keys}, but the corresponding values do not exist in the input. The actual input passed to the model contains the following keys: {list(inputs.keys())}" + f"At least one PTM input feature is required when use_prosit_ptm_features=True. Missing all of: {self.PTM_INPUT_KEYS}" ) - # read PTM features from the input dict ptm_ac_features = self._collect_values_from_inputs_if_exists( inputs, PrositIntensityPredictor.PTM_INPUT_KEYS ) - if self.ptm_input_encoder and len(ptm_ac_features) > 0: - encoded_ptm = self.ptm_input_encoder(ptm_ac_features) - elif self.use_prosit_ptm_features: - warnings.warn( - f"PTM features enabled and following PTM features are expected in the model for Prosit Intesity: {PrositIntensityPredictor.PTM_INPUT_KEYS}. 
The actual input passed to the model contains the following keys: {list(inputs.keys())}. Falling back to no PTM features." - ) + peptides_in = inputs[self.input_keys[self.REQUIRED_INPUT_SEQUENCE_KEY]] x = self.embedding(peptides_in) - # fusion of PTMs (before going into the GRU sequence encoder) - if self.ptm_aa_fusion and encoded_ptm is not None: + # encode and fuse PTM features (before going into the GRU sequence encoder) + if self.use_prosit_ptm_features: + encoded_ptm = self.ptm_input_encoder(ptm_ac_features) x = self.ptm_aa_fusion([x, encoded_ptm]) x = self.sequence_encoder(x) - x = self.attention(x) - if self.meta_data_fusion_layer and encoded_meta is not None: + if self.use_meta_data: + encoded_meta = self.meta_encoder(meta_data) x = self.meta_data_fusion_layer([x, encoded_meta]) else: # no metadata -> add a dimension to comply with the shape x = tf.expand_dims(x, axis=1) x = self.decoder(x) - x = self.regressor(x) return x def _collect_values_from_inputs_if_exists(self, inputs, keys_mapping): collected_values = [] - keys = [] + if isinstance(keys_mapping, dict): keys = list(keys_mapping.values()) @@ -433,8 +505,8 @@ def get_config(self): config.update( { "embedding_output_dim": self.embedding_output_dim, - "seq_length": self.seq_length, - "len_fion": self.len_fion, + "seq_length": self.raw_seq_length, + "len_fragment_ion": self.len_fragment_ion, "alphabet": self.alphabet, # Store the actual alphabet dict used "dropout_rate": self.dropout_rate, "latent_dropout_rate": self.latent_dropout_rate, @@ -446,6 +518,7 @@ def get_config(self): "input_keys": self.input_keys, # Store the actual input keys used "meta_data_keys": self.meta_data_keys, # Store the actual meta data keys used "with_termini": self.with_termini, + "use_meta_data": self.use_meta_data, "use_instrument_embedding": self.use_instrument_embedding, "instrument_input_dim": self.instrument_input_dim, "instrument_output_dim": self.instrument_output_dim, diff --git a/src/dlomix/models/prosit_torch.py b/src/dlomix/models/prosit_torch.py index c7cb970a..2cb05741 100644 --- a/src/dlomix/models/prosit_torch.py +++ b/src/dlomix/models/prosit_torch.py @@ -1,6 +1,6 @@ import logging -import warnings from collections import OrderedDict +from collections.abc import Sequence import torch import torch.nn as nn @@ -105,10 +105,64 @@ def forward(self, inputs, **kwargs): class PrositIntensityPredictor(nn.Module): """ - Implementation of the Prosit model for fragment ion intensity prediction. + Prosit model for intensity prediction with configurable branches for PTM features and metadata. Parameters - ----------- + ---------- + input_keys : dict, optional + Dictionary mapping for the input keys to look for in the input dict. Defaults to None, which uses default required keys only "seqeuence". + meta_data_keys : list or dict, optional + List or dict of keys for the meta data inputs to use. Defaults to None (no meta data). + alphabet : dict, optional + Dictionary mapping for the alphabet (the amino acids in this case). Defaults to ALPHABET_UNMOD. + with_termini : bool, optional + Whether to include terminal tokens in the sequence embedding. Defaults to False. + embedding_output_dim : int, optional + Size of the embeddings to use. Defaults to 16. + seq_length : int, optional + Sequence length of the peptide sequences. Defaults to 30. + len_fragment_ion : int, optional + Number of fragment ions to predict. Defaults to 6. + dropout_rate : float, optional + Probability to use for dropout layers in the encoder. Defaults to 0.2. 
+ latent_dropout_rate : float, optional + Probability to use for dropout layers in the regressor layers after encoding. Defaults to 0.1. + recurrent_layers_sizes : tuple, optional + A tuple of 2 values for the sizes of the two GRU layers in the encoder. Defaults to (256, 512). + regressor_layer_size : int, optional + Size of the dense layer in the regressor after the encoder. Defaults to 512. + use_meta_data : bool, optional + Whether to use meta data inputs. Defaults to False. + use_prosit_ptm_features : bool, optional + Whether to use Prosit PTM features as input. Defaults to False. + use_instrument_embedding : bool, optional + Whether to use instrument type embedding as part of the meta data. Defaults to False. + instrument_input_dim : int, optional + Number of unique instrument types for embedding. Defaults to 3. + instrument_output_dim : int, optional + Size of the instrument type embedding. Defaults to 2. + + + + Attributes + ---------- + + REQUIRED_INPUT_SEQUENCE_KEY : str + Key for the required sequence input in the input dictionary. + DEFAULT_INPUT_KEYS : dict + Default mapping of input keys for various inputs. + META_DATA_KEYS : list + List of possible meta data keys that can be used. + PTM_INPUT_KEYS : list + List of keys for the PTM feature inputs. See `dlomix.data.processing.feature_extractors.FEATURE_EXTRACTORS_PARAMETERS` for details. + + + Notes + ----- + This model is a flexible implementation of the Prosit intensity predictor, allowing for optional + inclusion of PTM features and meta data inputs. The model architecture consists of embedding layers, + bidirectional GRU encoders, attention mechanisms, and dense regressor layers. + @@ -116,11 +170,14 @@ class PrositIntensityPredictor(nn.Module): """ + REQUIRED_INPUT_SEQUENCE_KEY = "SEQUENCE_KEY" + DEFAULT_INPUT_KEYS = { - "SEQUENCE_KEY": "sequence", + REQUIRED_INPUT_SEQUENCE_KEY: "sequence", "COLLISION_ENERGY_KEY": "collision_energy", "PRECURSOR_CHARGE_KEY": "precursor_charge", "FRAGMENTATION_TYPE_KEY": "fragmentation_type", + "INSTRUMENT_TYPE_KEY": "instrument_type", } # can be extended to include all possible meta data @@ -128,6 +185,7 @@ class PrositIntensityPredictor(nn.Module): "COLLISION_ENERGY_KEY", "PRECURSOR_CHARGE_KEY", "FRAGMENTATION_TYPE_KEY", + "INSTRUMENT_TYPE_KEY", ] # retrieve the Lookup PTM feature keys @@ -135,73 +193,130 @@ class PrositIntensityPredictor(nn.Module): def __init__( self, + input_keys=None, + meta_data_keys=None, + alphabet=None, + with_termini=False, embedding_output_dim=16, seq_length=30, - len_fion=6, - alphabet=None, + len_fragment_ion=6, dropout_rate=0.2, latent_dropout_rate=0.1, recurrent_layers_sizes=(256, 512), regressor_layer_size=512, + use_meta_data=False, use_prosit_ptm_features=False, - input_keys=None, - meta_data_keys=None, - with_termini=True, + use_instrument_embedding=False, + instrument_input_dim=3, + instrument_output_dim=2, + **kwargs, ): - super(PrositIntensityPredictor, self).__init__() + super(PrositIntensityPredictor, self).__init__(**kwargs) + # Store all configuration parameters self.dropout_rate = dropout_rate self.latent_dropout_rate = latent_dropout_rate self.regressor_layer_size = regressor_layer_size - self.recurrent_layers_sizes = recurrent_layers_sizes + self.recurrent_layers_sizes = tuple(recurrent_layers_sizes) self.embedding_output_dim = embedding_output_dim - self.seq_length = seq_length - self.len_fion = len_fion + self.raw_seq_length = seq_length + self.len_fragment_ion = len_fragment_ion self.use_prosit_ptm_features = use_prosit_ptm_features - 
self.input_keys = input_keys - self.meta_data_keys = meta_data_keys + self.with_termini = with_termini + self.use_meta_data = use_meta_data + self.use_instrument_embedding = use_instrument_embedding + self.instrument_input_dim = instrument_input_dim + self.instrument_output_dim = instrument_output_dim + + # handle default and fallback attributes + self._handle_alphabet_and_keys(alphabet, input_keys, meta_data_keys) - # maximum number of fragment ions - self.max_ion = self.seq_length - 1 + self._validate_config() - # account for encoded termini - if with_termini: - self.max_ion = self.max_ion - 2 + self._compute_attributes() - if alphabet: - self.alphabet = alphabet + # Build layers + self._build_embedding_layers() + self._build_encoders() + self._build_decoder() + self.attention = AttentionLayer( + feature_dim=regressor_layer_size, seq_len=seq_length + ) + self._build_meta_data_fusion_layer() + self._build_regressor() + + def _handle_alphabet_and_keys(self, alphabet, input_keys, meta_data_keys): + # Handle alphabet + if alphabet is not None: + self.alphabet = dict(alphabet) else: - self.alphabet = ALPHABET_UNMOD + self.alphabet = dict(ALPHABET_UNMOD) # fallback to unmodified amino acids + + # Handle input_keys + if input_keys: + self.input_keys = dict(input_keys) + else: + self.input_keys = { + self.REQUIRED_INPUT_SEQUENCE_KEY: "sequence" + } # minimal required key for the sequence input + + # Handle meta_data_keys + if not meta_data_keys: + self.meta_data_keys = [] + else: + # handle dict in case user passed a mapping, take the values in a list to use for lookup + if isinstance(meta_data_keys, dict): + self.meta_data_keys = list(meta_data_keys.values()) + elif isinstance(meta_data_keys, list): + self.meta_data_keys = list(meta_data_keys) + else: + raise ValueError( + "meta_data_keys should be either a list of strings or a dict mapping. Provided type: " + f"{type(meta_data_keys)}, and value: {meta_data_keys}" + ) + + def _validate_config(self): + if self.use_meta_data and not self.meta_data_keys: + raise ValueError( + "use_meta_data=True requires meta_data_keys to be provided as a list of keys." 
+ ) + + if ( + self.use_instrument_embedding + and "INSTRUMENT_TYPE_KEY" not in self.input_keys + ): + raise ValueError( + "use_instrument_embedding=True requires 'INSTRUMENT_TYPE_KEY' in input_keys" + ) + + def _compute_attributes(self): + # Compute derived attributes (will be recomputed during deserialization) + self.max_ion = self.raw_seq_length - 1 + + self.seq_length = ( + self.raw_seq_length + 2 if self.with_termini else self.raw_seq_length + ) # tie the count of embeddings to the size of the vocabulary (count of amino acids) - self.embeddings_count = len(self.alphabet) + 1 + self.embeddings_count = len(self.alphabet) + def _build_embedding_layers(self): self.embedding = nn.Embedding( num_embeddings=self.embeddings_count, embedding_dim=self.embedding_output_dim, ) - self._build_encoders() - self._build_decoder() - - self.attention = AttentionLayer( - feature_dim=regressor_layer_size, seq_len=seq_length - ) + self.instrument_embedding = None + if self.use_instrument_embedding: + self.instrument_embedding = nn.Embedding( + num_embeddings=self.instrument_input_dim, + embedding_dim=self.instrument_output_dim, + ) - self.meta_data_fusion_layer = None + def _build_meta_data_fusion_layer(self): if self.meta_data_keys: self.meta_data_fusion_layer = MetaDataFusionBlock(max_ion=self.max_ion) - self.regressor = nn.Sequential( - OrderedDict( - [ - ("time_dense", nn.LazyLinear(out_features=len_fion)), - ("activation", nn.LeakyReLU()), - ("output", nn.Flatten()), - ] - ) - ) - def _build_encoders(self): # sequence encoder -> always present gru_features_input_size = self.embedding_output_dim @@ -217,7 +332,7 @@ def _build_encoders(self): # meta data encoder -> optional, only if meta data keys are provided self.meta_encoder = None - if self.meta_data_keys: + if self.use_meta_data: self.meta_encoder = nn.Sequential( OrderedDict( [ @@ -252,91 +367,116 @@ def _build_decoder(self): max_ion=self.max_ion, ) + def _build_regressor(self): + self.regressor = nn.Sequential( + OrderedDict( + [ + ("time_dense", nn.LazyLinear(out_features=self.len_fragment_ion)), + ("activation", nn.LeakyReLU()), + ("output", nn.Flatten()), + ] + ) + ) + def forward(self, inputs, **kwargs): + # Handle dict input, complex case: multiple inputs + if isinstance(inputs, dict): + return self._forward_dict(inputs, **kwargs) + + # Handle single input, simple case: sequence only + peptides_in = inputs + return self._forward_sequence_only(peptides_in, **kwargs) + + def _forward_sequence_only(self, peptides_in, **kwargs): + x = self.embedding(peptides_in) + x = self.sequence_encoder(x) + x = self.attention(x) + x = torch.unsqueeze(x, dim=1) + x = self.decoder(x) + x = self.regressor(x) + return x + + def _forward_dict(self, inputs, **kwargs): + missing_input_keys = [k for k in self.input_keys.values() if k not in inputs] + if missing_input_keys: + raise ValueError(f"Missing required input keys: {missing_input_keys}") + + meta_data = [] encoded_meta = None encoded_ptm = None - if not isinstance(inputs, dict): - # when inputs has (seq, target), it comes as tuple - peptides_in = inputs - else: - peptides_in = inputs.get(self.input_keys["SEQUENCE_KEY"]) + # collect instrument embedding if enabled and add to meta data + if self.use_instrument_embedding: + instrument_type = inputs.get(self.input_keys["INSTRUMENT_TYPE_KEY"]) + instrument_embedded = self.instrument_embedding(instrument_type) + meta_data.append(instrument_embedded) - # read meta data from the input dict - # note that the value here is the key to use in the inputs dict passed from 
the dataset - meta_data = self._collect_values_from_inputs_if_exists( - inputs, self.meta_data_keys - ) + # collect meta data from the input dict + if self.use_meta_data: + missing_meta_keys = [k for k in self.meta_data_keys if k not in inputs] + if missing_meta_keys: + raise ValueError( + f"Missing required metadata inputs: {missing_meta_keys}" + ) - if self.meta_encoder and len(meta_data) > 0: - if isinstance(meta_data, list): - meta_data = torch.cat(meta_data, dim=-1) - encoded_meta = self.meta_encoder(meta_data) + meta_data.extend( + self._collect_values_from_inputs_if_exists(inputs, self.meta_data_keys) + ) - elif self.meta_data_keys: + # collect PTM features from the input dict + if self.use_prosit_ptm_features: + ptm_keys_exist = [k for k in self.PTM_INPUT_KEYS if k in inputs] + if not ptm_keys_exist: raise ValueError( - f"Following metadata keys were specified when creating the model: {self.meta_data_keys}, but the corresponding values do not exist in the input. The actual input passed to the model contains the following keys: {list(inputs.keys())}" + f"At least one PTM input feature is required when use_prosit_ptm_features=True. Missing all of: {self.PTM_INPUT_KEYS}" ) - else: - pass - # ToDo: ensure PTM features work as expected - # read PTM features from the input dict # --> Still needs to be implemented ptm_ac_features = self._collect_values_from_inputs_if_exists( inputs, PrositIntensityPredictor.PTM_INPUT_KEYS ) - if self.ptm_input_encoder and len(ptm_ac_features) > 0: - logger.debug("PTM features: ") - for f in ptm_ac_features: - logger.debug(f.shape) - encoded_ptm = self.ptm_input_encoder(ptm_ac_features) - elif self.use_prosit_ptm_features: - warnings.warn( - f"PTM features enabled and following PTM features are expected in the model for Prosit Intesity: {PrositIntensityPredictor.PTM_INPUT_KEYS}. The actual input passed to the model contains the following keys: {list(inputs.keys())}. Falling back to no PTM features." 
- ) + peptides_in = inputs[self.input_keys[self.REQUIRED_INPUT_SEQUENCE_KEY]] x = self.embedding(peptides_in) - # fusion of PTMs (before going into the GRU sequence encoder) - if self.ptm_aa_fusion and encoded_ptm is not None: - logger.debug("before ptm fusion: %s", x.shape) - logger.debug("before ptm fusion enc ptm: %s", encoded_ptm.shape) + # encode and fuse PTM features (before going into the GRU sequence encoder) + if self.use_prosit_ptm_features: + encoded_ptm = self.ptm_input_encoder(ptm_ac_features) x = self.ptm_aa_fusion([x, encoded_ptm]) - logger.debug("concatednated after fusion: %s", x.shape) x = self.sequence_encoder(x) - x = self.attention(x) - if self.meta_data_fusion_layer and encoded_meta is not None: + if self.use_meta_data: + if isinstance(meta_data, list): + meta_data = torch.cat(meta_data, dim=-1) + encoded_meta = self.meta_encoder(meta_data) x = self.meta_data_fusion_layer([x, encoded_meta]) else: # no metadata -> add a dimension to comply with the shape - x = torch.unsqueeze(x, axis=1) + x = torch.unsqueeze(x, dim=1) x = self.decoder(x) - x = self.regressor(x) return x def _collect_values_from_inputs_if_exists(self, inputs, keys_mapping): collected_values = [] - keys = [] + if isinstance(keys_mapping, dict): - keys = keys_mapping.values() + keys = list(keys_mapping.values()) - elif isinstance(keys_mapping, list): - keys = keys_mapping + elif isinstance(keys_mapping, Sequence): + keys = list(keys_mapping) for key_in_inputs in keys: # get the input under the specified key if exists single_input = inputs.get(key_in_inputs, None) if single_input is not None: if single_input.ndim == 1: - single_input = torch.unsqueeze(single_input, axis=-1) + single_input = torch.unsqueeze(single_input, dim=-1) collected_values.append(single_input) return collected_values diff --git a/tests/conftest.py b/tests/conftest.py index f3a45e42..24ba36cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,15 +14,6 @@ RT_CSV_EXAMPLE_URL = "https://raw.githubusercontent.com/wilhelm-lab/dlomix/develop/example_dataset/proteomTools_train_val.csv" INTENSITY_PARQUET_EXAMPLE_URL = "https://raw.githubusercontent.com/wilhelm-lab/dlomix/develop/example_dataset/intensity/intensity_data.parquet" INTENSITY_CSV_EXAMPLE_URL = "https://raw.githubusercontent.com/wilhelm-lab/dlomix/develop/example_dataset/intensity/intensity_data.csv" -RT_HUB_DATASET_NAME = "Wilhelmlab/prospect-ptms-irt" - - -RAW_GENERIC_NESTED_DATA = { - "seq": ["[UNIMOD:737]-DASAQTTSHELTIPN-[]", "[UNIMOD:737]-DLHTGRLC[UNIMOD:4]-[]"], - "nested_feature": [[[30, 64]], [[25, 35]]], - "label": [0.1, 0.2], - "label2": [1.0, 2.0], -} TEST_ASSETS_TO_DOWNLOAD = [ RT_PARQUET_EXAMPLE_URL, @@ -39,14 +30,24 @@ def unzip_file(zip_file_path, dest_dir): f.extractall(dest_dir) -@pytest.fixture(scope="session", autouse=True) -def global_variables(): - pytest.global_variables = { - "RAW_GENERIC_NESTED_DATA": RAW_GENERIC_NESTED_DATA, - "DOWNLOAD_PATH_FOR_ASSETS": DOWNLOAD_PATH_FOR_ASSETS, +@pytest.fixture(scope="session") +def raw_generic_nested_data(): + return { + "seq": [ + "[UNIMOD:737]-DASAQTTSHELTIPN-[]", + "[UNIMOD:737]-DLHTGRLC[UNIMOD:4]-[]", + ], + "nested_feature": [[[30, 64]], [[25, 35]]], + "label": [0.1, 0.2], + "label2": [1.0, 2.0], } +@pytest.fixture(scope="session") +def download_path_for_assets(): + return DOWNLOAD_PATH_FOR_ASSETS + + @pytest.fixture(scope="session", autouse=True) def download_assets(): makedirs(DOWNLOAD_PATH_FOR_ASSETS, exist_ok=True) @@ -90,3 +91,151 @@ def subset_downloaded_files(download_assets): 
pd.read_csv(file).sample(N).to_csv(file) else: continue + + +@pytest.fixture +def basic_alphabet(): + """Basic amino acid alphabet without modifications.""" + return { + "A": 1, + "C": 2, + "D": 3, + "E": 4, + "F": 5, + "G": 6, + "H": 7, + "I": 8, + "K": 9, + "L": 10, + "M": 11, + "N": 12, + "P": 13, + "Q": 14, + "R": 15, + "S": 16, + "T": 17, + "V": 18, + "W": 19, + "Y": 20, + "[]-": 21, + "-[]": 22, + } + + +@pytest.fixture +def ptm_alphabet(): + """Alphabet with PTM modifications included.""" + base = { + "A": 1, + "C": 2, + "D": 3, + "E": 4, + "F": 5, + "G": 6, + "H": 7, + "I": 8, + "K": 9, + "L": 10, + "M": 11, + "N": 12, + "P": 13, + "Q": 14, + "R": 15, + "S": 16, + "T": 17, + "V": 18, + "W": 19, + "Y": 20, + "[]-": 21, + "-[]": 22, + "C[UNIMOD:4]": 23, # Carbamidomethylation of C + "K[UNIMOD:737]": 24, # TMT6plex of K + "[UNIMOD:737]-": 25, # TMT6plex of N-terminus + } + return base + + +@pytest.fixture +def sample_parsed_sequence(): + """Sample parsed sequence data (output of SequenceParsingProcessor).""" + return { + "sequence": ["[]-", "D", "E", "L", "-[]"], + "_parsed_sequence": ["D", "E", "L"], + "_n_term_mods": "[]-", + "_c_term_mods": "-[]", + } + + +@pytest.fixture +def sample_parsed_sequence_with_ptm(): + """Sample parsed sequence with PTM modifications.""" + return { + "sequence": ["[]-", "H", "C[UNIMOD:4]", "V", "D", "-[]"], + "_parsed_sequence": ["H", "C[UNIMOD:4]", "V", "D"], + "_n_term_mods": "[]-", + "_c_term_mods": "-[]", + } + + +@pytest.fixture +def sample_parsed_sequence_with_nterm_mod(): + """Sample parsed sequence with N-terminal modification.""" + return { + "sequence": ["[UNIMOD:737]-", "I", "L", "C[UNIMOD:4]", "S", "-[]"], + "_parsed_sequence": ["I", "L", "C[UNIMOD:4]", "S"], + "_n_term_mods": "[UNIMOD:737]-", + "_c_term_mods": "-[]", + } + + +@pytest.fixture +def mock_lookup_table(): + """Mock PTM feature lookup table.""" + return { + "C[UNIMOD:4]": [1, 2, 3, 4, 5, 6], # 6D feature vector + "K[UNIMOD:737]": [7, 8, 9, 10, 11, 12], + "[UNIMOD:737]-": [0, 0, 0, 0, 0, 0], + "[]-": [0, 0, 0, 0, 0, 0], + "-[]": [0, 0, 0, 0, 0, 0], + "A": [0, 0, 0, 0, 0, 0], + "D": [0, 1, 0, 0, 0, 0], + "E": [0, 1, 0, 0, 0, 0], + "L": [0, 0, 0, 0, 0, 0], + "H": [1, 0, 0, 0, 0, 1], + "V": [0, 0, 0, 0, 0, 0], + "I": [0, 0, 0, 0, 0, 0], + "S": [0, 0, 0, 1, 0, 0], + } + + +@pytest.fixture +def sample_batched_sequences(): + """Sample batch of sequences for batched processor tests.""" + return [ + "[]-DEL-[]", + "[]-HHDELIF-[]", + "[]-C[UNIMOD:4]VD-[]", + ] + + +@pytest.fixture +def sample_custom_function(): + """Custom function for FunctionProcessor testing.""" + + def double_sequence_length(data, **kwargs): + if isinstance(data.get("sequence"), list): + return {"sequence": data["sequence"] * 2} + return data + + return double_sequence_length + + +@pytest.fixture +def sample_custom_function_with_kwargs(): + """Custom function that accepts kwargs.""" + + def scale_feature(data, scale_factor=1.0, **kwargs): + feature = data.get("feature", 0.0) + return {"feature": feature * scale_factor} + + return scale_feature diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 41658a85..324f2645 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -3,7 +3,6 @@ from os.path import join from shutil import rmtree -import pytest from datasets import Dataset, DatasetDict, load_dataset from dlomix.data import ( @@ -23,11 +22,9 @@ def test_empty_rtdataset(): assert rtdataset._empty_dataset_mode is True -def test_parquet_rtdataset(): +def 
test_parquet_rtdataset(download_path_for_assets): rtdataset = RetentionTimeDataset( - data_source=join( - pytest.global_variables["DOWNLOAD_PATH_FOR_ASSETS"], "file_1.parquet" - ), + data_source=join(download_path_for_assets, "file_1.parquet"), sequence_column="modified_sequence", label_column="indexed_retention_time", ) @@ -46,12 +43,10 @@ def test_parquet_rtdataset(): assert rtdataset[RetentionTimeDataset.DEFAULT_SPLIT_NAMES[1]].num_rows > 0 -def test_rtdataset_inmemory(): +def test_rtdataset_inmemory(download_path_for_assets): hf_dataset = load_dataset( "parquet", - data_files=join( - pytest.global_variables["DOWNLOAD_PATH_FOR_ASSETS"], "file_1.parquet" - ), + data_files=join(download_path_for_assets, "file_1.parquet"), split="train", ) @@ -89,11 +84,9 @@ def test_rtdataset_hub(): assert rtdataset[RetentionTimeDataset.DEFAULT_SPLIT_NAMES[2]].num_rows > 0 -def test_csv_rtdataset(): +def test_csv_rtdataset(download_path_for_assets): rtdataset = RetentionTimeDataset( - data_source=join( - pytest.global_variables["DOWNLOAD_PATH_FOR_ASSETS"], "file_2.csv" - ), + data_source=join(download_path_for_assets, "file_2.csv"), data_format="csv", sequence_column="sequence", label_column="irt", @@ -121,10 +114,8 @@ def test_empty_intensitydataset(): assert intensity_dataset._empty_dataset_mode is True -def test_parquet_intensitydataset(): - filepath = join( - pytest.global_variables["DOWNLOAD_PATH_FOR_ASSETS"], "file_3.parquet" - ) +def test_parquet_intensitydataset(download_path_for_assets): + filepath = join(download_path_for_assets, "file_3.parquet") intensity_dataset = FragmentIonIntensityDataset( data_format="parquet", data_source=filepath, @@ -154,8 +145,8 @@ def test_parquet_intensitydataset(): ) -def test_csv_intensitydataset(): - filepath = join(pytest.global_variables["DOWNLOAD_PATH_FOR_ASSETS"], "file_4.csv") +def test_csv_intensitydataset(download_path_for_assets): + filepath = join(download_path_for_assets, "file_4.csv") intensity_dataset = FragmentIonIntensityDataset( data_format="csv", data_source=filepath, @@ -184,8 +175,8 @@ def test_csv_intensitydataset(): ) -def test_nested_model_features(): - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) +def test_nested_model_features(raw_generic_nested_data): + hfdata = Dataset.from_dict(raw_generic_nested_data) intensity_dataset = FragmentIonIntensityDataset( data_format="hf", @@ -202,8 +193,8 @@ def test_nested_model_features(): assert example[0]["nested_feature"].shape == [2, 1, 2] -def test_save_dataset(): - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) +def test_save_dataset(raw_generic_nested_data): + hfdata = Dataset.from_dict(raw_generic_nested_data) intensity_dataset = FragmentIonIntensityDataset( data_format="hf", @@ -221,11 +212,9 @@ def test_save_dataset(): rmtree(save_path) -def test_load_dataset(): +def test_load_dataset(download_path_for_assets): rtdataset = RetentionTimeDataset( - data_source=join( - pytest.global_variables["DOWNLOAD_PATH_FOR_ASSETS"], "file_2.csv" - ), + data_source=join(download_path_for_assets, "file_2.csv"), data_format="csv", sequence_column="sequence", label_column="irt", @@ -259,8 +248,8 @@ def test_load_dataset(): rmtree(save_path) -def test_no_split_datasetDict_hf_inmemory(): - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) +def test_no_split_datasetDict_hf_inmemory(raw_generic_nested_data): + hfdata = Dataset.from_dict(raw_generic_nested_data) hf_dataset = DatasetDict({"train": hfdata}) intensity_dataset = 
FragmentIonIntensityDataset( @@ -288,9 +277,9 @@ def test_no_split_datasetDict_hf_inmemory(): # test learning alphabet for train/val and then using it for test with fallback -def test_shuffle_parameter(): +def test_shuffle_parameter(raw_generic_nested_data): """Test that shuffle parameter works for both TensorFlow and PyTorch datasets.""" - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + hfdata = Dataset.from_dict(raw_generic_nested_data) # Test with shuffle=True for TensorFlow tf_dataset = FragmentIonIntensityDataset( @@ -321,9 +310,9 @@ def test_shuffle_parameter(): assert pt_dataset.tensor_train_data is not None -def test_torch_dataloader_kwargs(): +def test_torch_dataloader_kwargs(raw_generic_nested_data): """Test that additional PyTorch DataLoader kwargs are properly passed through.""" - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + hfdata = Dataset.from_dict(raw_generic_nested_data) dataset = FragmentIonIntensityDataset( data_format="hf", @@ -349,9 +338,9 @@ def test_torch_dataloader_kwargs(): assert dataset.torch_dataloader_kwargs is not None -def test_tf_tensor_dataset_string_label(): +def test_tf_tensor_dataset_string_label(raw_generic_nested_data): """Test that TensorFlow TensorDataset is created properly.""" - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + hfdata = Dataset.from_dict(raw_generic_nested_data) dataset = FragmentIonIntensityDataset( data_format="hf", @@ -373,9 +362,9 @@ def test_tf_tensor_dataset_string_label(): assert labels is not None -def test_tf_tensor_dataset_singelton_list_label(): +def test_tf_tensor_dataset_singelton_list_label(raw_generic_nested_data): """Test that TensorFlow TensorDataset is created properly.""" - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + hfdata = Dataset.from_dict(raw_generic_nested_data) dataset = FragmentIonIntensityDataset( data_format="hf", @@ -397,9 +386,9 @@ def test_tf_tensor_dataset_singelton_list_label(): assert labels is not None -def test_tf_tensor_dataset_list_multi_label(): +def test_tf_tensor_dataset_list_multi_label(raw_generic_nested_data): """Test that TensorFlow TensorDataset is created properly.""" - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) + hfdata = Dataset.from_dict(raw_generic_nested_data) dataset = FragmentIonIntensityDataset( data_format="hf", diff --git a/tests/test_models.py b/tests/test_models.py index 845e4eaf..37c37298 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -83,7 +83,7 @@ def test_prosit_intensity_model_ptm_on_input(): def test_prosit_intensity_model_ptm_on_missing(): model = PrositIntensityPredictor(use_prosit_ptm_features=True) seq_len = model.seq_length - with pytest.warns(UserWarning, match="PTM"): + with pytest.raises(ValueError, match="PTM"): model.build( { "sequence": ( @@ -99,10 +99,12 @@ def test_prosit_intensity_model_ptm_on_missing(): def test_prosit_intensity_model_encoding_metadata_missing(): - model = PrositIntensityPredictor() + model = PrositIntensityPredictor( + meta_data_keys=["meta_data_1", "meta_data_2"], use_meta_data=True + ) seq_len = model.seq_length - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="metadata"): model.build( { "sequence": ( @@ -114,6 +116,31 @@ def test_prosit_intensity_model_encoding_metadata_missing(): ) +def test_prosit_intensity_model_no_metadata(): + model = PrositIntensityPredictor( + input_keys={ + "SEQUENCE_KEY": "sequence", + }, + 
meta_data_keys=None, + ) + + seq_len = model.seq_length + + assert model.meta_data_keys == [] + + model.build( + { + "sequence": ( + None, + seq_len, + ), + } + ) + + assert model is not None + assert model.meta_encoder is None + + def basic_model_existence_test(model): logger.info(model) assert model is not None diff --git a/tests/test_processors.py b/tests/test_processors.py index f4c2a8f1..ea95d8a1 100644 --- a/tests/test_processors.py +++ b/tests/test_processors.py @@ -1,8 +1,4 @@ import logging -import urllib.request -import zipfile -from os import makedirs -from os.path import exists, join import pytest @@ -161,25 +157,396 @@ def test_sequence_padding_processor_drop(): assert not padded[SequencePaddingProcessor.KEEP_COLUMN_NAME] -def test_sequence_encoding_processor(): - pass +def test_sequence_encoding_processor_fixed_alphabet(basic_alphabet): + """Test encoding with fixed alphabet (unknown tokens -> unknown_token_index).""" + alphabet = basic_alphabet.copy() + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + extend_alphabet=False, + unknown_token="X", + unknown_token_index=23, + ) + + # Known sequence + input_data = {SEQ_COLUMN: ["D", "E", "L"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == [3, 4, 10] # D=3, E=4, L=10 from alphabet + + # Unknown token should map to unknown_token_index + # Note: alphabet is a copy from fixture which doesn't have Z + input_data = {SEQ_COLUMN: ["Z"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == [23] # Z unknown -> unknown_token_index + + +def test_sequence_encoding_processor_fixed_alphabet_with_ptm(ptm_alphabet): + """Test encoding with PTM modifications in fixed alphabet.""" + alphabet = ptm_alphabet.copy() + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + extend_alphabet=False, + ) + + input_data = {SEQ_COLUMN: ["H", "C[UNIMOD:4]", "V", "D"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == [7, 23, 18, 3] # H=7, C[UNIMOD:4]=23, V=18, D=3 + + +def test_sequence_encoding_processor_extend_alphabet(): + """Test encoding with alphabet extension (learns new tokens).""" + initial_alphabet = { + "A": 1, + "C": 2, + "D": 3, + "E": 4, + "L": 10, + } + + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=initial_alphabet, + extend_alphabet=True, + ) + + # First sequence - learns new tokens + input_data = {SEQ_COLUMN: ["D", "E", "L"]} + result = p(input_data) + assert result[SEQ_COLUMN] == [3, 4, 10] + + # Second sequence with unknown token - extends alphabet + input_data = {SEQ_COLUMN: ["A", "C", "V"]} + result = p(input_data) + # V should now be in the alphabet + assert "V" in p.alphabet + assert result[SEQ_COLUMN][2] == p.alphabet["V"] + + +def test_sequence_encoding_processor_extend_alphabet_batched(basic_alphabet): + """Test alphabet extension in batched mode.""" + alphabet = basic_alphabet.copy() + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + batched=True, + extend_alphabet=True, + ) + + # Initialize with unknown amino acids + initial_alphabet = {"A": 1, "C": 2, "D": 3} + p.alphabet = initial_alphabet.copy() + + input_data = {SEQ_COLUMN: [["A", "C"], ["D", "E"]]} + result = p(input_data) + + # "E" should be learned and added to alphabet + assert "E" in p.alphabet + assert len(result[SEQ_COLUMN]) == 2 + assert len(result[SEQ_COLUMN][0]) == 2 + assert len(result[SEQ_COLUMN][1]) == 2 + + +def test_sequence_encoding_processor_fallback_mode_unmodified(basic_alphabet): + """Test 
fallback mode: unknown PTM -> unmodified amino acid.""" + alphabet = basic_alphabet.copy() + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + extend_alphabet=False, + fallback_unmodified=True, + ) + + # Unknown PTM C[UNIMOD:999] should fallback to C + input_data = {SEQ_COLUMN: ["C[UNIMOD:999]"]} + result = p(input_data) + + # C[UNIMOD:999] -> C -> 2 + assert result[SEQ_COLUMN][0] == 2 + + +def test_sequence_encoding_processor_fallback_mode_terminal_mods(basic_alphabet): + """Test fallback mode with unknown terminal modifications.""" + alphabet = basic_alphabet.copy() + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + extend_alphabet=False, + fallback_unmodified=True, + ) + + # Unknown N-terminal mod [UNIMOD:999]- should fallback to []- + input_data = {SEQ_COLUMN: ["[UNIMOD:999]-", "D", "E"]} + result = p(input_data) + + assert result[SEQ_COLUMN][0] == 21 # []- + assert result[SEQ_COLUMN][1] == 3 # D + assert result[SEQ_COLUMN][2] == 4 # E + + +def test_sequence_encoding_processor_unknown_token_already_in_alphabet(): + """Test that unknown token uses existing index if already in alphabet.""" + alphabet = { + "A": 1, + "X": 2, # X already in alphabet + "D": 3, + } + + # Should not raise, but should warn + with pytest.warns(UserWarning): + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + unknown_token="X", + unknown_token_index=1, # Will be overridden to 2 + ) + + # unknown_token_index should be set to existing X index + assert p.unknown_token_index == 2 + + +def test_sequence_encoding_processor_unknown_token_index_conflict(): + """Test warning when unknown_token_index is already used.""" + alphabet = { + "A": 1, + "C": 2, + "D": 3, + } + + # Should warn about index conflict + with pytest.warns(UserWarning): + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet, + unknown_token_index=2, # 2 is already used by C + ) + + # Index should be reassigned + assert p.unknown_token_index != 2 + + +def test_sequence_encoding_processor_single_vs_batched_consistency(basic_alphabet): + """Test that single and batched modes produce identical results.""" + alphabet = basic_alphabet.copy() + p_single = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet.copy(), + batched=False, + ) + p_batch = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet=alphabet.copy(), + batched=True, + ) + + sequences = [["D", "E", "L"], ["A", "C"], ["K", "L", "M"]] + + # Process individually + results_single = [p_single({SEQ_COLUMN: seq})[SEQ_COLUMN] for seq in sequences] + + # Process in batch + result_batch = p_batch({SEQ_COLUMN: sequences})[SEQ_COLUMN] + + assert results_single == result_batch + + +# ============================================================================ +# SEQUENCE PTM REMOVAL PROCESSOR TESTS +# ============================================================================ + + +def test_sequence_ptm_removal_processor_basic(): + """Test basic PTM removal from sequence.""" + p = SequencePTMRemovalProcessor(sequence_column_name=SEQ_COLUMN) + + input_data = {SEQ_COLUMN: ["[]-", "C[UNIMOD:4]", "V", "D", "-[]"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == ["[]-", "C", "V", "D", "-[]"] + + +def test_sequence_ptm_removal_processor_multiple_ptms(): + """Test removal of multiple PTMs from same sequence.""" + p = SequencePTMRemovalProcessor(sequence_column_name=SEQ_COLUMN) + + input_data = { + 
SEQ_COLUMN: ["[UNIMOD:737]-", "C[UNIMOD:4]", "K[UNIMOD:737]", "S", "-[]"] + } + result = p(input_data) + + assert result[SEQ_COLUMN] == ["[UNIMOD:737]-", "C", "K", "S", "-[]"] + + +def test_sequence_ptm_removal_processor_no_ptms(): + """Test sequence without PTMs is unchanged.""" + p = SequencePTMRemovalProcessor(sequence_column_name=SEQ_COLUMN) + + input_data = {SEQ_COLUMN: ["[]-", "D", "E", "L", "-[]"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == ["[]-", "D", "E", "L", "-[]"] -def test_sequence_encoding_processor_with_fixed_alphabet(): - pass +def test_sequence_ptm_removal_processor_batched(): + """Test PTM removal in batched mode.""" + p = SequencePTMRemovalProcessor(sequence_column_name=SEQ_COLUMN, batched=True) + input_data = { + SEQ_COLUMN: [ + ["[]-", "C[UNIMOD:4]", "V", "-[]"], + ["[]-", "K[UNIMOD:737]", "D", "-[]"], + ] + } + result = p(input_data) -def test_sequence_encoding_processor_with_extend_alphabet_enabled(): - pass + assert result[SEQ_COLUMN][0] == ["[]-", "C", "V", "-[]"] + assert result[SEQ_COLUMN][1] == ["[]-", "K", "D", "-[]"] -def test_sequence_encoding_processor_with_fallback_enabled(): - pass +def test_sequence_ptm_removal_processor_preserves_terminals(): + """Test that terminal modifications are preserved.""" + p = SequencePTMRemovalProcessor(sequence_column_name=SEQ_COLUMN) + input_data = {SEQ_COLUMN: ["[UNIMOD:737]-", "C[UNIMOD:4]", "-[UNIMOD:1]"]} + result = p(input_data) + + # Terminals should be unchanged + assert result[SEQ_COLUMN][0] == "[UNIMOD:737]-" + assert result[SEQ_COLUMN][-1] == "-[UNIMOD:1]" + + +def test_sequence_ptm_removal_processor_non_list_input_raises_error(): + """Test that non-list input raises ValueError.""" + p = SequencePTMRemovalProcessor(sequence_column_name=SEQ_COLUMN) + + # String input should raise error + input_data = {SEQ_COLUMN: "C[UNIMOD:4]VD"} + + with pytest.raises(ValueError, match="Sequence must be a list"): + p(input_data) + + +# ============================================================================ +# FUNCTION PROCESSOR TESTS +# ============================================================================ + + +def test_function_processor_basic(sample_custom_function): + """Test basic function processor application.""" + p = FunctionProcessor(function=sample_custom_function) + + input_data = {SEQ_COLUMN: ["D", "E", "L"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == ["D", "E", "L", "D", "E", "L"] + + +def test_function_processor_with_kwargs(sample_custom_function_with_kwargs): + """Test function processor with keyword arguments.""" + p = FunctionProcessor(function=sample_custom_function_with_kwargs) + + input_data = {"feature": 5.0} + result = p(input_data, scale_factor=2.0) + + assert result["feature"] == 10.0 + + +def test_function_processor_lambda(): + """Test function processor with lambda function.""" + p = FunctionProcessor( + function=lambda data, **kwargs: {"sequence": data.get("sequence", []) * 2} + ) + + input_data = {"sequence": [1, 2, 3]} + result = p(input_data) + + assert result["sequence"] == [1, 2, 3, 1, 2, 3] + + +# ============================================================================ +# EDGE CASE TESTS +# ============================================================================ + + +def test_sequence_padding_processor_exact_length(): + """Test padding when sequence length equals max_length.""" + length = 3 + p = SequencePaddingProcessor(sequence_column_name=SEQ_COLUMN, max_length=length) + + input_data = {SEQ_COLUMN: ["D", "E", "L"]} + result = p(input_data) + + 
assert result[SEQ_COLUMN] == ["D", "E", "L"] + assert result[SequencePaddingProcessor.KEEP_COLUMN_NAME] is True + + +def test_sequence_padding_processor_custom_padding_index(): + """Test padding with custom padding index.""" + p = SequencePaddingProcessor( + sequence_column_name=SEQ_COLUMN, + max_length=5, + padding_index=-1, + ) + + input_data = {SEQ_COLUMN: ["D", "E", "L"]} + result = p(input_data) + + assert result[SEQ_COLUMN] == ["D", "E", "L", -1, -1] + + +def test_sequence_padding_processor_empty_sequence(): + """Test padding with empty sequence.""" + p = SequencePaddingProcessor(sequence_column_name=SEQ_COLUMN, max_length=3) + + input_data = {SEQ_COLUMN: []} + result = p(input_data) + + assert result[SEQ_COLUMN] == [0, 0, 0] + assert result[SequencePaddingProcessor.KEEP_COLUMN_NAME] is True + + +def test_sequence_padding_processor_long_sequence(): + """Test truncation of very long sequence.""" + p = SequencePaddingProcessor(sequence_column_name=SEQ_COLUMN, max_length=5) + + long_sequence = list(range(10)) + input_data = {SEQ_COLUMN: long_sequence} + result = p(input_data) + + assert len(result[SEQ_COLUMN]) == 5 + assert result[SEQ_COLUMN] == long_sequence[:5] + assert result[SequencePaddingProcessor.KEEP_COLUMN_NAME] is False + + +def test_sequence_parsing_processor_invalid_format(): + """Test error on invalid ProForma format (4+ parts).""" + p = SequenceParsingProcessor(sequence_column_name=SEQ_COLUMN) + + input_data = {SEQ_COLUMN: "[UNIMOD:1]-SEQ-[UNIMOD:2]-EXTRA"} + + with pytest.raises(ValueError, match="Invalid sequence format"): + p(input_data) + + +def test_sequence_encoding_processor_empty_alphabet_fixed_mode(): + """Test that fixed mode works with empty alphabet + unknown tokens.""" + p = SequenceEncodingProcessor( + sequence_column_name=SEQ_COLUMN, + alphabet={}, # Empty alphabet (will get unknown token added) + extend_alphabet=False, + unknown_token="X", + unknown_token_index=1, + ) -def test_sequence_ptm_removal_processor(): - pass + # Alphabet was empty, but unknown token X was added with index 1 + assert "X" in p.alphabet + assert p.unknown_token_index == 1 + input_data = {SEQ_COLUMN: ["A", "C", "D"]} + result = p(input_data) -def test_function_processor(): - pass + # All should map to unknown_token_index since they're not in alphabet + assert result[SEQ_COLUMN] == [1, 1, 1] diff --git a/tests/test_torch_dataset.py b/tests/test_torch_dataset.py index 7e67f469..e3086bfc 100644 --- a/tests/test_torch_dataset.py +++ b/tests/test_torch_dataset.py @@ -1,6 +1,5 @@ import logging -import pytest import torch from datasets import Dataset @@ -9,8 +8,8 @@ logger = logging.getLogger(__name__) -def test_dataset_torch(): - hfdata = Dataset.from_dict(pytest.global_variables["RAW_GENERIC_NESTED_DATA"]) +def test_dataset_torch(raw_generic_nested_data): + hfdata = Dataset.from_dict(raw_generic_nested_data) intensity_dataset = FragmentIonIntensityDataset( data_format="hf", diff --git a/tests/test_torch_models.py b/tests/test_torch_models.py index 00ab7c61..ce05b9e4 100644 --- a/tests/test_torch_models.py +++ b/tests/test_torch_models.py @@ -110,8 +110,13 @@ def test_tf_torch_equivalence_intensity_model_shapes(): dummy_input_torch = torch.randint(low=0, high=15, size=(batch_size, seq_len)) dummi_input_tf = dummy_input_torch.numpy() - model_tf = PrositIntensityPredictor() - model_torch = PrositIntensityPredictorTorch() + model_tf = PrositIntensityPredictor( + seq_length=seq_len, + ) + + model_torch = PrositIntensityPredictorTorch( + seq_length=seq_len, + ) output_tf = 
model_tf(dummi_input_tf) output_torch = model_torch(dummy_input_torch) From 5b6d8465a9458843e08d6a8f3740953212eac5cd Mon Sep 17 00:00:00 2001 From: omsh Date: Mon, 2 Feb 2026 21:43:58 +0100 Subject: [PATCH 6/9] version 0.2.4.dev0 --- src/dlomix/_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dlomix/_metadata.py b/src/dlomix/_metadata.py index 042beac4..1df6c6c7 100644 --- a/src/dlomix/_metadata.py +++ b/src/dlomix/_metadata.py @@ -1,4 +1,4 @@ -__version__ = "0.2.4.dev0" +__version__ = "0.2.4.dev1" __author__ = "Wilhelm Lab" __author_email__ = "o.shouman@tum.de" __license__ = "MIT" From 890fd8e5fd35389c2b88dd7058623ef5750bd7d9 Mon Sep 17 00:00:00 2001 From: Omar Shouman Date: Mon, 9 Feb 2026 09:32:16 +0100 Subject: [PATCH 7/9] feature array subsetting fixed for correct padding after np.empty + tests (#85) --- run_scripts/run_prosit_intensity_ptms.py | 10 +-- .../data/processing/feature_extractors.py | 11 +-- tests/conftest.py | 27 +++---- tests/test_feature_extractors.py | 79 +++++++++++++++++++ 4 files changed, 96 insertions(+), 31 deletions(-) create mode 100644 tests/test_feature_extractors.py diff --git a/run_scripts/run_prosit_intensity_ptms.py b/run_scripts/run_prosit_intensity_ptms.py index 38cb1775..db27a8a9 100644 --- a/run_scripts/run_prosit_intensity_ptms.py +++ b/run_scripts/run_prosit_intensity_ptms.py @@ -23,16 +23,16 @@ model_features=["collision_energy_aligned_normed", "precursor_charge_onehot"], sequence_column="modified_sequence", label_column="intensities_raw", - # features_to_extract=["mod_loss", "delta_mass"], - features_to_extract=["delta_mass"], - with_termini=False, + features_to_extract=["mod_loss", "delta_mass"], + # features_to_extract=["delta_mass"], + with_termini=True, alphabet=None, encoding_scheme="naive-mods", ) - model = PrositIntensityPredictor( seq_length=30, use_prosit_ptm_features=True, + use_meta_data=True, input_keys={ "SEQUENCE_KEY": "modified_sequence", }, @@ -40,7 +40,7 @@ "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", }, - with_termini=False, + with_termini=True, alphabet=d.extended_alphabet, ) diff --git a/src/dlomix/data/processing/feature_extractors.py b/src/dlomix/data/processing/feature_extractors.py index 5995463b..6be789da 100644 --- a/src/dlomix/data/processing/feature_extractors.py +++ b/src/dlomix/data/processing/feature_extractors.py @@ -188,21 +188,16 @@ def single_process(self, input_data, **kwargs): return {self.feature_column_name: feature} def _extract_feature(self, sequence): - print("sequence:", sequence) - print("sequence length:", len(sequence)) - print("max_length:", self.max_length) - # we lookup unttil the max length only because some sequences in train/val can be longer lookup_length = min(self.max_length, len(sequence)) sequence_for_lookup = sequence[:lookup_length] feature = np.empty((self.max_length, *self._feature_shape), dtype=np.float32) - print("feature shape:", feature.shape) feature[:lookup_length] = itemgetter(*sequence_for_lookup)(self.lookup_table) - feature = self.pad_feature_to_seq_length(feature, self.max_length) - - print("DONE feature shape:", feature.shape) + # pad from lookup_length to max_length if needed and expand dims if one-dimensional + # technically, complementing the previous step with a call feature[lookup_length:] = self.feature_default_value + feature = self.pad_feature_to_seq_length(feature, lookup_length) return feature diff --git a/tests/conftest.py b/tests/conftest.py index 24ba36cd..d82fcc2a 
100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -148,9 +148,9 @@ def ptm_alphabet(): "Y": 20, "[]-": 21, "-[]": 22, - "C[UNIMOD:4]": 23, # Carbamidomethylation of C - "K[UNIMOD:737]": 24, # TMT6plex of K - "[UNIMOD:737]-": 25, # TMT6plex of N-terminus + "C[UNIMOD:4]": 23, + "K[UNIMOD:737]": 24, + "[UNIMOD:737]-": 25, } return base @@ -189,22 +189,13 @@ def sample_parsed_sequence_with_nterm_mod(): @pytest.fixture -def mock_lookup_table(): - """Mock PTM feature lookup table.""" +def lookup_table(): + # Create a lookup table return { - "C[UNIMOD:4]": [1, 2, 3, 4, 5, 6], # 6D feature vector - "K[UNIMOD:737]": [7, 8, 9, 10, 11, 12], - "[UNIMOD:737]-": [0, 0, 0, 0, 0, 0], - "[]-": [0, 0, 0, 0, 0, 0], - "-[]": [0, 0, 0, 0, 0, 0], - "A": [0, 0, 0, 0, 0, 0], - "D": [0, 1, 0, 0, 0, 0], - "E": [0, 1, 0, 0, 0, 0], - "L": [0, 0, 0, 0, 0, 0], - "H": [1, 0, 0, 0, 0, 1], - "V": [0, 0, 0, 0, 0, 0], - "I": [0, 0, 0, 0, 0, 0], - "S": [0, 0, 0, 1, 0, 0], + 0: [1.0, 2.0], + 1: [3.0, 4.0], + 2: [5.0, 6.0], + 3: [7.0, 8.0], } diff --git a/tests/test_feature_extractors.py b/tests/test_feature_extractors.py new file mode 100644 index 00000000..ab2711ff --- /dev/null +++ b/tests/test_feature_extractors.py @@ -0,0 +1,79 @@ +import logging + +import numpy as np + +from dlomix.data.processing.feature_extractors import LookupFeatureExtractor + +logger = logging.getLogger(__name__) + + +def test_lookup_feature_extractor_exact_length(lookup_table): + # Create a sequence of indices to look up + sequence_for_lookup = [0, 1, 3] + sequence_length = len(sequence_for_lookup) + max_length = 3 + default_value = [-1.0, -1.0] + + # Create the feature extractor + feature_extractor = LookupFeatureExtractor( + sequence_column_name="sequence", + feature_column_name="feature", + lookup_table=lookup_table, + feature_default_value=default_value, + max_length=max_length, + ) + + # Extract features for the given sequence + feature = feature_extractor._extract_feature(sequence_for_lookup) + logger.info("Extracted feature:\n%s", feature) + + assert feature.shape == ( + max_length, + 2, + ), f"Expected feature shape to be ({max_length}, 2), but got {feature.shape}" + + assert np.array_equal( + feature[:sequence_length], + np.array([lookup_table[idx] for idx in sequence_for_lookup]), + ), "The extracted feature values do not match the expected values from the lookup table." + + logger.info(feature[sequence_length:]) + assert ( + len(feature[sequence_length:]) == 0 + ), "The padded feature values do not match the expected default value from the lookup table." + + +def test_lookup_feature_extractor_with_padding(lookup_table): + # Create a sequence of indices to look up + sequence_for_lookup = [0, 1, 3] + sequence_length = len(sequence_for_lookup) + max_length = 6 + default_value = [-5.0, -5.0] + + # Create the feature extractor + feature_extractor = LookupFeatureExtractor( + sequence_column_name="sequence", + feature_column_name="feature", + lookup_table=lookup_table, + feature_default_value=default_value, + max_length=max_length, + ) + + # Extract features for the given sequence + feature = feature_extractor._extract_feature(sequence_for_lookup) + logger.info("Extracted feature:\n%s", feature) + + assert feature.shape == ( + max_length, + 2, + ), f"Expected feature shape to be ({max_length}, 2), but got {feature.shape}" + + assert np.array_equal( + feature[:sequence_length], + np.array([lookup_table[idx] for idx in sequence_for_lookup]), + ), "The extracted feature values do not match the expected values from the lookup table." 
+ + assert np.array_equal( + feature[sequence_length:], + np.array([default_value] * (max_length - sequence_length)), + ), "The padded feature values do not match the expected default value from the lookup table." From 6443de3c24cf01280a4d6f153fb70a5b8e748107 Mon Sep 17 00:00:00 2001 From: omsh Date: Mon, 9 Feb 2026 09:35:59 +0100 Subject: [PATCH 8/9] version 0.2.4 --- src/dlomix/_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dlomix/_metadata.py b/src/dlomix/_metadata.py index 1df6c6c7..72338d62 100644 --- a/src/dlomix/_metadata.py +++ b/src/dlomix/_metadata.py @@ -1,8 +1,8 @@ -__version__ = "0.2.4.dev1" +__version__ = "0.2.4" __author__ = "Wilhelm Lab" __author_email__ = "o.shouman@tum.de" __license__ = "MIT" __description__ = "Deep Learning for Proteomics" __package__ = "DLOmix" -__copyright__ = "2025, Wilhelm Lab, TU Munich, School of Life Sciences" +__copyright__ = "2026, Wilhelm Lab, TU Munich, School of Life Sciences" __github_url__ = "https://github.com/wilhelm-lab/dlomix" From 20ab3798563521900655268f3d4b92f09dddc3d6 Mon Sep 17 00:00:00 2001 From: omsh Date: Mon, 9 Feb 2026 10:46:39 +0100 Subject: [PATCH 9/9] comment - lazy TF import refactor --- src/dlomix/data/dataset.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/dlomix/data/dataset.py b/src/dlomix/data/dataset.py index 56610ddc..42b42bbc 100644 --- a/src/dlomix/data/dataset.py +++ b/src/dlomix/data/dataset.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Union -import tensorflow as tf from datasets import Dataset, DatasetDict, Sequence, Value, load_dataset from .dataset_config import DatasetConfig @@ -27,6 +26,27 @@ logger = logging.getLogger(__name__) +# TensorFlow import for tf.data is deferred until needed to avoid unnecessary imports for users who only want to use PyTorch datasets or other functionalities of the PeptideDataset class. +# This also helps to reduce the initial loading time and memory footprint for users who do not need TensorFlow. + +_tf = None + + +def _get_tensorflow(): + """Lazy import of TensorFlow. Only imports when needed.""" + global _tf + if _tf is None: + try: + import tensorflow as tf + + _tf = tf + except ImportError: + raise ImportError( + "TensorFlow backend requires tensorflow to be installed. " + "Install with: pip install tensorflow" + ) + return _tf + class PeptideDataset: """ @@ -695,6 +715,7 @@ def tensor_train_data(self): if self.dataset_type == "pt": return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[0]) else: + tf = _get_tensorflow() dataset_len = len(self.hf_dataset[PeptideDataset.DEFAULT_SPLIT_NAMES[0]]) tf_dataset = self._get_split_tf_dataset( PeptideDataset.DEFAULT_SPLIT_NAMES[0] @@ -722,6 +743,7 @@ def tensor_val_data(self): if self.dataset_type == "pt": return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[1]) else: + tf = _get_tensorflow() tf_dataset = self._get_split_tf_dataset( PeptideDataset.DEFAULT_SPLIT_NAMES[1] ) @@ -740,6 +762,7 @@ def tensor_test_data(self): if self.dataset_type == "pt": return self._get_split_torch_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[2]) else: + tf = _get_tensorflow() tf_dataset = self._get_split_tf_dataset( PeptideDataset.DEFAULT_SPLIT_NAMES[2] )
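
Usage sketch (not part of the patch series): after these changes, the metadata and PTM branches of PrositIntensityPredictor are opt-in and must be requested explicitly; enabling use_meta_data without meta_data_keys, or use_prosit_ptm_features without any PTM feature inputs, now raises a ValueError instead of warning and falling back. The sketch below is adapted from run_scripts/run_prosit_intensity_ptms.py in PATCH 7/9; the import path and the dataset variable `d` are assumptions for illustration, not verified against the rest of the codebase.

# Minimal sketch of the explicit model configuration introduced in this series.
# Assumes the public import path dlomix.models and a FragmentIonIntensityDataset `d`
# built with features_to_extract and with_termini=True as in the run script.
from dlomix.models import PrositIntensityPredictor

model = PrositIntensityPredictor(
    seq_length=30,
    with_termini=True,                 # termini tokens extend the effective sequence length by 2
    alphabet=d.extended_alphabet,      # alphabet learned by the dataset; no implicit ALPHABET_UNMOD fallback
    input_keys={"SEQUENCE_KEY": "modified_sequence"},
    use_meta_data=True,                # metadata branch must be enabled explicitly ...
    meta_data_keys={                   # ... and paired with the dataset column names to read
        "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed",
        "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot",
    },
    use_prosit_ptm_features=True,      # PTM branch; at least one PTM feature key must be present in the inputs
)

Passing a plain sequence tensor still works (the sequence-only forward path), while dict inputs must contain every key named in input_keys and meta_data_keys, per the new _forward_dict validation.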