diff --git a/.gitignore b/.gitignore
index c905746..5c94db1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,16 +12,30 @@ docs_public
*.csv
*.parquet
*.ipc
+*.mztab
+*.fasta
+*.mgf
*.pkl
*.json
*.yaml
+*.pdf
+*.png
+
+*.ipynb
examples/winnow-general-model
examples/winnow-ms-datasets
examples/output
+# Sample data files
+examples/example_data/*.ipc
+examples/example_data/*.csv
+examples/example_data/*.parquet
+
build/
+.cursorrules
+
# Coverage reports
htmlcov/
.coverage
diff --git a/Makefile b/Makefile
index 31bc740..86a8f81 100644
--- a/Makefile
+++ b/Makefile
@@ -86,12 +86,16 @@ install-all:
## Development commands #
#################################################################################
-.PHONY: tests test-docker bash set-gcp-credentials set-ceph-credentials
+.PHONY: tests clean-coverage test-docker bash build-package clean-build test-build set-gcp-credentials set-ceph-credentials
## Run all tests
tests:
$(PYTEST)
+## Clean coverage reports
+clean-coverage:
+ rm -rf htmlcov/ .coverage coverage.xml pytest.xml
+
## Run all tests in the Docker Image
test-docker:
docker run $(DOCKER_RUN_FLAGS) $(DOCKER_IMAGE) $(PYTEST)
@@ -100,6 +104,17 @@ test-docker:
bash:
docker run -it $(DOCKER_RUN_FLAGS) $(DOCKER_IMAGE) /bin/bash
+## Build the winnow-fdr package (creates wheel and sdist in dist/)
+build-package:
+ uv build
+
+## Clean all build artifacts (dist/, build/, *.egg-info/)
+clean-build:
+ rm -rf dist/ build/ *.egg-info/ winnow_fdr.egg-info/
+
+## Build the package, then remove the artifacts (smoke-tests that the package builds)
+test-build: build-package clean-build
+
## Set the GCP credentials
set-gcp-credentials:
uv run python scripts/set_gcp_credentials.py
@@ -108,3 +123,28 @@ set-gcp-credentials:
## Set the Ceph credentials
set-ceph-credentials:
uv run python scripts/set_ceph_credentials.py
+
+#################################################################################
+## Sample data and CLI commands #
+#################################################################################
+
+.PHONY: sample-data train-sample predict-sample clean clean-all
+
+## Generate sample data files for testing
+sample-data:
+ uv run python scripts/generate_sample_data.py
+
+## Run winnow train with sample data (uses defaults from config)
+train-sample:
+ winnow train
+
+## Run winnow predict with sample data (uses locally trained model from models/new_model)
+predict-sample:
+ winnow predict calibrator.pretrained_model_name_or_path=models/new_model
+
+## Clean output directories (does not delete sample data)
+clean:
+ rm -rf models/ results/
+
+## Clean outputs and regenerate sample data
+clean-all: clean sample-data
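
The `predict-sample` target relies on Hydra's CLI override grammar. A minimal sketch of the equivalent composition in Python, assuming hydra-core >= 1.3 and the packaged `winnow/configs` added in this change (`config_path` is resolved relative to the calling module, as in the example notebook):

```python
from hydra import compose, initialize

# Compose predict.yaml and apply the same override that `make predict-sample`
# passes on the command line.
with initialize(
    config_path="../winnow/configs", version_base="1.3", job_name="predict_sample"
):
    cfg = compose(
        config_name="predict",
        overrides=["calibrator.pretrained_model_name_or_path=models/new_model"],
    )

print(cfg.calibrator.pretrained_model_name_or_path)  # -> models/new_model
```
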
diff --git a/README.md b/README.md
index 40629fe..49bba3b 100644
--- a/README.md
+++ b/README.md
@@ -42,9 +42,9 @@
Explore the docs »
- Report Bug
+ Report bug
·
- Request Feature
+ Request feature
[11/06/25 18:46:36] INFO Enabling RDKit 2024.09.6 jupyter extensions __init__.py:22\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[11/06/25 18:46:36]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Enabling RDKit \u001b[1;36m2024.09\u001b[0m.\u001b[1;36m6\u001b[0m jupyter extensions \u001b]8;id=885939;file:///home/j-daniel/repos/winnow/.venv/lib/python3.12/site-packages/rdkit/__init__.py\u001b\\\u001b[2m__init__.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=21754;file:///home/j-daniel/repos/winnow/.venv/lib/python3.12/site-packages/rdkit/__init__.py#22\u001b\\\u001b[2m22\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import logging\n", "import warnings\n", @@ -96,6 +82,8 @@ "import pandas as pd\n", "import seaborn as sns\n", "from huggingface_hub import list_repo_files, snapshot_download\n", + "from hydra import initialize, compose\n", + "from hydra.utils import instantiate\n", "\n", "from winnow.calibration.calibration_features import (\n", " BeamFeatures,\n", @@ -106,9 +94,7 @@ " RetentionTimeFeature,\n", ")\n", "from winnow.calibration.calibrator import ProbabilityCalibrator\n", - "from winnow.constants import RESIDUE_MASSES\n", "from winnow.datasets.calibration_dataset import CalibrationDataset\n", - "from winnow.datasets.data_loaders import InstaNovoDatasetLoader\n", "from winnow.fdr.database_grounded import DatabaseGroundedFDRControl\n", "from winnow.fdr.nonparametric import NonParametricFDRControl\n", "\n", @@ -141,17 +127,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['.gitattributes', 'README.md', 'celegans_labelled.parquet', 'celegans_labelled_beams.csv', 'celegans_raw.parquet', 'celegans_raw_beams.csv', 'general_test.parquet', 'general_test_beams.csv', 'general_train.parquet', 'general_train_beams.csv', 'general_val.parquet', 'general_val_beams.csv', 'helaqc_labelled.parquet', 'helaqc_labelled_beams.csv', 'helaqc_raw_less_train.parquet', 'helaqc_raw_less_train_beams.csv', 'immuno2_labelled.parquet', 'immuno2_labelled_beams.csv', 'immuno2_raw.parquet', 'immuno2_raw_beams.csv']\n" - ] - } - ], + "outputs": [], "source": [ "repo_id = \"InstaDeepAI/winnow-ms-datasets\"\n", "save_dir = \"winnow-ms-datasets\"\n", @@ -162,34 +140,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5a543758fdac4806b5ff0a3ee5c9cdf7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 4 files: 0%| | 0/4 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'/home/j-daniel/repos/winnow/examples/winnow-ms-datasets'" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# -- Download the helaqc dataset\n", "snapshot_download(\n", @@ -225,7 +178,20 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialise Hydra with the config directory\n", + "with initialize(\n", + " config_path=\"../winnow/configs\", version_base=\"1.3\", job_name=\"winnow_notebook\"\n", + "):\n", + " cfg = compose(config_name=\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": 
[
     {
@@ -272,7 +238,10 @@
     "source": [
      "# -- Load data\n",
      "logger.info(\"Loading dataset.\")\n",
-     "dataset = InstaNovoDatasetLoader().load(\n",
+     "data_loader = instantiate(\n",
+     "    cfg.data_loader\n",
+     ")  # Loads default (InstaNovo) data loader with config\n",
+     "dataset = data_loader.load(\n",
      "    data_path=\"winnow-ms-datasets/helaqc_labelled.parquet\",\n",
      "    predictions_path=\"winnow-ms-datasets/helaqc_labelled_beams.csv\",\n",
      ")\n",
@@ -750,7 +719,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 11,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -797,19 +766,31 @@
     "source": [
      "# -- Set up calibrator\n",
      "logger.info(\"Initialising calibrator.\")\n",
-     "SEED = 42\n",
-     "calibrator = ProbabilityCalibrator(SEED)\n",
+     "calibrator = ProbabilityCalibrator(seed=cfg.calibrator.seed)\n",
      "\n",
      "logger.info(\"Adding features to calibrator.\")\n",
-     "MZ_TOLERANCE = 0.02\n",
-     "HIDDEN_DIM = 10\n",
-     "TRAIN_FRACTION = 0.1\n",
-     "calibrator.add_feature(MassErrorFeature(residue_masses=RESIDUE_MASSES))\n",
-     "calibrator.add_feature(PrositFeatures(mz_tolerance=MZ_TOLERANCE))\n",
      "calibrator.add_feature(\n",
-     "    RetentionTimeFeature(hidden_dim=HIDDEN_DIM, train_fraction=TRAIN_FRACTION)\n",
+     "    MassErrorFeature(residue_masses=cfg.calibrator.features.mass_error.residue_masses)\n",
+     ")\n",
+     "calibrator.add_feature(\n",
+     "    PrositFeatures(\n",
+     "        mz_tolerance=cfg.calibrator.features.prosit_features.mz_tolerance,\n",
+     "        invalid_prosit_tokens=cfg.calibrator.features.prosit_features.invalid_prosit_tokens,\n",
+     "    )\n",
+     ")\n",
+     "calibrator.add_feature(\n",
+     "    RetentionTimeFeature(\n",
+     "        hidden_dim=cfg.calibrator.features.retention_time_feature.hidden_dim,\n",
+     "        train_fraction=cfg.calibrator.features.retention_time_feature.train_fraction,\n",
+     "        invalid_prosit_tokens=cfg.calibrator.features.retention_time_feature.invalid_prosit_tokens,\n",
+     "    )\n",
+     ")\n",
+     "calibrator.add_feature(\n",
+     "    ChimericFeatures(\n",
+     "        mz_tolerance=cfg.calibrator.features.chimeric_features.mz_tolerance,\n",
+     "        invalid_prosit_tokens=cfg.calibrator.features.chimeric_features.invalid_prosit_tokens,\n",
+     "    )\n",
+     ")\n",
-     "calibrator.add_feature(ChimericFeatures(mz_tolerance=MZ_TOLERANCE))\n",
      "calibrator.add_feature(BeamFeatures())"
     ]
    },
@@ -1411,7 +1391,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 18,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -1425,11 +1405,10 @@
     "source": [
      "# -- Database-grounded FDR control\n",
      "database_grounded_fdr_control = DatabaseGroundedFDRControl(\n",
-     "    confidence_feature=\"calibrated_confidence\"\n",
-     ")\n",
-     "database_grounded_fdr_control.fit(\n",
-     "    dataset=test_dataset.metadata, residue_masses=RESIDUE_MASSES\n",
+     "    confidence_feature=\"calibrated_confidence\",\n",
+     "    residue_masses=cfg.residue_masses,\n",
      ")\n",
+     "database_grounded_fdr_control.fit(dataset=test_dataset.metadata)\n",
      "print(\n",
      "    \"Database-grounded FDR control confidence cutoff at 5% FDR using calibrated confidence:\",\n",
      "    database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05),\n",
@@ -1536,7 +1515,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 22,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -1583,7 +1562,7 @@
     "source": [
      "# -- Load the raw, unlabelled data\n",
      "logger.info(\"Loading raw dataset.\")\n",
-     "dataset = InstaNovoDatasetLoader().load(\n",
+     "dataset = data_loader.load(\n",
      "    data_path=\"winnow-ms-datasets/helaqc_raw_less_train.parquet\",\n",
      "    predictions_path=\"winnow-ms-datasets/helaqc_raw_less_train_beams.csv\",\n",
      ")\n",
@@ -1884,8 +1863,10 @@
    ],
    "source": [
     "# Minimal feature set: no Prosit dependency\n",
-    "cal_min = ProbabilityCalibrator(SEED)\n",
-    "cal_min.add_feature(MassErrorFeature(residue_masses=RESIDUE_MASSES))\n",
+    "cal_min = ProbabilityCalibrator(seed=cfg.calibrator.seed)\n",
+    "cal_min.add_feature(\n",
+    "    MassErrorFeature(residue_masses=cfg.calibrator.features.mass_error.residue_masses)\n",
+    ")\n",
     "cal_min.add_feature(BeamFeatures())\n",
     "\n",
     "cal_min.fit(train_dataset)\n",
@@ -1968,7 +1949,7 @@
     "source": [
      "# -- Load data\n",
      "logger.info(\"Loading dataset.\")\n",
-     "celegans_dataset = InstaNovoDatasetLoader().load(\n",
+     "celegans_dataset = data_loader.load(\n",
      "    data_path=\"winnow-ms-datasets/celegans_labelled.parquet\",\n",
      "    predictions_path=\"winnow-ms-datasets/celegans_labelled_beams.csv\",\n",
      ")"
@@ -2158,8 +2139,7 @@
      "database_grounded_fdr_control = DatabaseGroundedFDRControl(\n",
-     "    confidence_feature=\"calibrated_confidence\"\n",
-     ")\n",
-     "database_grounded_fdr_control.fit(\n",
-     "    dataset=celegans_filtered_dataset.metadata, residue_masses=RESIDUE_MASSES\n",
+     "    confidence_feature=\"calibrated_confidence\",\n",
+     "    residue_masses=cfg.residue_masses,\n",
      ")\n",
+     "database_grounded_fdr_control.fit(dataset=celegans_filtered_dataset.metadata)\n",
      "confidence_cutoff_dbg = database_grounded_fdr_control.get_confidence_cutoff(\n",
      "    threshold=0.05\n",
diff --git a/mkdocs.yml b/mkdocs.yml
index fcc7d87..3f81274 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -64,12 +64,14 @@ plugins:

 nav:
   - Home: index.md
-  - API Reference:
+  - User guide:
+      - CLI reference: cli.md
+      - Configuration guide: configuration.md
+      - Examples: examples.md
+  - API reference:
       - Datasets: api/datasets.md
       - Calibration: api/calibration.md
       - FDR: api/fdr.md
-  - CLI Guide: cli.md
-  - Examples: examples.md
   - Contributing: contributing.md
   - License: license.md
diff --git a/pyproject.toml b/pyproject.toml
index fc45d51..b113924 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "typer>=0.15.2",
     "instanovo>=1.1.4",
     "huggingface-hub>=0.35.3",
+    "hydra-core>=1.3.2",
 ]

 [project.urls]
@@ -32,6 +33,9 @@ build-backend = "setuptools.build_meta"
 [tool.setuptools]
 packages = ["winnow", "winnow.scripts", "winnow.fdr", "winnow.datasets", "winnow.calibration"]

+[tool.setuptools.package-data]
+winnow = ["configs/**/*.yaml", "configs/**/*.yml"]
+
 [dependency-groups]
 dev = [
     "pre-commit>=4.1.0",
@@ -86,3 +90,9 @@ exclude_lines = [

 [tool.coverage.html]
 directory = "htmlcov"
+
+[tool.uv.workspace]
+members = [
+    "winnow_demo",
+    "tmp",
+]
diff --git a/scripts/generate_sample_data.py b/scripts/generate_sample_data.py
new file mode 100755
index 0000000..4459df2
--- /dev/null
+++ b/scripts/generate_sample_data.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+"""Generate minimal sample data for winnow train and predict commands."""
+
+import numpy as np
+import pandas as pd
+import polars as pl
+from pathlib import Path
+
+
+def generate_sample_data():
+    """Generate minimal sample IPC and CSV files for InstaNovo format."""
+    output_dir = Path("examples/example_data")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    n_samples = 20
+    spectrum_ids = [f"spectrum_{i}" for i in range(n_samples)]
+
+    # Generate peptides using only valid amino acids (A, C, D, E, F, G, H, I, K, L, M, N, P, Q, R, S, T, V, W, Y)
+    # Note: Must avoid O, U, X, Z, B, J which are not standard amino acids
+    peptides = [
+        "PEPTIDEK",
+        "MASSIVE",
+        "PEPTIDES",
+        "SEQQENCR",
+        "PEPTIDE",
+        "MASSIVE",
+        "PEPTIDES",
+        "SEQQENCR",
+        "PEPTIDEK",
+        "MASSIVE",
+        "PEPTIDES",
+        "SEQQENCR",
+        "PEPTIDE",
+        "MASSIVE",
+        "PEPTIDES",
+        "SEQQENCR",
+        "PEPTIDEK",
+ "MASSIVE", + "PEPTIDES", + "SEQQENCR", + ] + + # Generate spectrum data (IPC format) + # Calculate precursor_mass from mz and charge + np.random.seed(42) # For reproducibility + precursor_mz = np.random.uniform(400, 1200, n_samples) + precursor_charge = np.random.choice([2, 3, 4], n_samples) + proton_mass = 1.007276 + precursor_mass = precursor_mz * precursor_charge - proton_mass * precursor_charge + + # Generate spectrum arrays (mz_array and intensity_array) + mz_arrays = [] + intensity_arrays = [] + for _ in range(n_samples): + n_peaks = np.random.randint(10, 50) + mz_array = np.random.uniform(100, 1000, n_peaks).tolist() + intensity_array = np.random.uniform(0.1, 1.0, n_peaks).tolist() + mz_arrays.append(mz_array) + intensity_arrays.append(intensity_array) + + # Create spectrum data DataFrame using polars + spectrum_data = pl.DataFrame( + { + "spectrum_id": spectrum_ids, + "precursor_mz": precursor_mz, + "precursor_charge": precursor_charge.astype(int), + "precursor_mass": precursor_mass, + "retention_time": np.random.uniform(10, 60, n_samples), + "sequence": peptides, # Ground truth for training + "mz_array": mz_arrays, + "intensity_array": intensity_arrays, + } + ) + + # Generate predictions (CSV format) + predictions_data = { + "spectrum_id": spectrum_ids, + "predictions": peptides, + "predictions_tokenised": [ + ", ".join(list(p)) + for p in peptides # "P, E, P, T, I, D, E, K" + ], + "log_probs": np.log(np.random.uniform(0.1, 0.9, n_samples)), + "sequence": peptides, # Ground truth + } + + # Add beam predictions (top 3 beams) + # Generate valid alternative peptides for runner-up beams + valid_aa = list("ACDEFGHIKLMNPQRSTVWY") + np.random.seed(43) # Different seed for beam alternatives + for beam_idx in range(3): + if beam_idx == 0: + # Top beam uses the main prediction + beam_predictions = peptides + else: + # Generate valid alternative peptides for runner-up beams + beam_predictions = [ + "".join(np.random.choice(valid_aa, size=len(peptides[i]))) + for i in range(n_samples) + ] + predictions_data[f"instanovo_predictions_beam_{beam_idx}"] = beam_predictions + predictions_data[f"instanovo_log_probabilities_beam_{beam_idx}"] = [ + np.log(np.random.uniform(0.1, 0.9)) for _ in range(n_samples) + ] + # Token log probabilities as string representation of list + predictions_data[f"token_log_probabilities_beam_{beam_idx}"] = [ + str([np.log(np.random.uniform(0.5, 0.9)) for _ in range(len(p))]) + for p in peptides + ] + + predictions_df = pd.DataFrame(predictions_data) + + # Save files + spectrum_path = output_dir / "spectra.ipc" + predictions_path = output_dir / "predictions.csv" + + spectrum_data.write_ipc(str(spectrum_path)) + predictions_df.to_csv(predictions_path, index=False) + + print("✓ Generated sample data:") + print(f" - {spectrum_path}") + print(f" - {predictions_path}") + print("\n✓ You can now run:") + print(" winnow train # Uses sample data from config defaults") + + +if __name__ == "__main__": + generate_sample_data() diff --git a/tests/calibration/test_calibration_features.py b/tests/calibration/test_calibration_features.py index 90aba49..d785d26 100644 --- a/tests/calibration/test_calibration_features.py +++ b/tests/calibration/test_calibration_features.py @@ -16,7 +16,6 @@ _raise_value_error, ) from winnow.datasets.calibration_dataset import CalibrationDataset -from winnow.constants import RESIDUE_MASSES class TestUtilityFunctions: @@ -105,7 +104,23 @@ class TestMassErrorFeature: @pytest.fixture() def mass_error_feature(self): """Create a MassErrorFeature instance for 
testing.""" - return MassErrorFeature(residue_masses=RESIDUE_MASSES) + residue_masses = { + "G": 57.021464, + "A": 71.037114, + "P": 97.052764, + "E": 129.042593, + "T": 101.047670, + "I": 113.084064, + "D": 115.026943, + "R": 156.101111, + "O": 237.147727, + "N": 114.042927, + "S": 87.032028, + "M": 131.040485, + "L": 113.084064, + "V": 99.068414, + } + return MassErrorFeature(residue_masses=residue_masses) @pytest.fixture() def sample_dataset(self): @@ -469,7 +484,9 @@ class TestRetentionTimeFeature: @pytest.fixture() def retention_time_feature(self): """Create a RetentionTimeFeature instance for testing.""" - return RetentionTimeFeature(hidden_dim=10, train_fraction=0.8) + return RetentionTimeFeature( + hidden_dim=10, train_fraction=0.8, invalid_prosit_tokens=["U", "O", "X"] + ) @pytest.fixture() def sample_dataset_with_rt(self): @@ -494,7 +511,9 @@ def test_properties(self, retention_time_feature): def test_initialization_parameters(self): """Test initialization with custom parameters.""" - feature = RetentionTimeFeature(hidden_dim=10, train_fraction=0.8) + feature = RetentionTimeFeature( + hidden_dim=10, train_fraction=0.8, invalid_prosit_tokens=["U", "O", "X"] + ) assert feature.hidden_dim == 10 assert feature.train_fraction == 0.8 assert feature.prosit_irt_model_name == "Prosit_2019_irt" @@ -779,7 +798,7 @@ class TestPrositFeatures: @pytest.fixture() def prosit_features(self): """Create a PrositFeatures instance for testing.""" - return PrositFeatures(mz_tolerance=0.02) + return PrositFeatures(mz_tolerance=0.02, invalid_prosit_tokens=["U", "O", "X"]) @pytest.fixture() def sample_dataset_with_spectra(self): @@ -818,7 +837,9 @@ def test_properties(self, prosit_features): def test_initialization_with_tolerance(self): """Test initialization with custom tolerance.""" - feature = PrositFeatures(mz_tolerance=0.01) + feature = PrositFeatures( + mz_tolerance=0.01, invalid_prosit_tokens=["U", "O", "X"] + ) assert feature.mz_tolerance == 0.01 assert feature.prosit_intensity_model_name == "Prosit_2020_intensity_HCD" @@ -1152,7 +1173,9 @@ class TestChimericFeatures: @pytest.fixture() def chimeric_features(self): """Create a ChimericFeatures instance for testing.""" - return ChimericFeatures(mz_tolerance=0.02) + return ChimericFeatures( + mz_tolerance=0.02, invalid_prosit_tokens=["U", "O", "X"] + ) @pytest.fixture() def sample_dataset_with_beam_predictions(self): @@ -1205,7 +1228,9 @@ def test_properties(self, chimeric_features): def test_initialization_with_tolerance(self): """Test initialization with custom tolerance.""" - feature = ChimericFeatures(mz_tolerance=0.01) + feature = ChimericFeatures( + mz_tolerance=0.01, invalid_prosit_tokens=["U", "O", "X"] + ) assert feature.mz_tolerance == 0.01 def test_prepare_does_nothing( diff --git a/tests/fdr/test_database_grounded.py b/tests/fdr/test_database_grounded.py index b5d9a2b..8e5f051 100644 --- a/tests/fdr/test_database_grounded.py +++ b/tests/fdr/test_database_grounded.py @@ -1,7 +1,6 @@ """Unit tests for winnow DatabaseGroundedFDRControl.""" import pytest -from unittest.mock import patch, Mock import pandas as pd from winnow.fdr.database_grounded import DatabaseGroundedFDRControl @@ -12,7 +11,24 @@ class TestDatabaseGroundedFDRControl: @pytest.fixture() def db_fdr_control(self): """Create a DatabaseGroundedFDRControl instance for testing.""" - return DatabaseGroundedFDRControl(confidence_feature="confidence") + residue_masses = { + "G": 57.021464, + "A": 71.037114, + "P": 97.052764, + "E": 129.042593, + "T": 101.047670, + "I": 113.084064, + 
"D": 115.026943, + "R": 156.101111, + "O": 237.147727, + "N": 114.042927, + "S": 87.032028, + "M": 131.040485, + "L": 113.084064, + } + return DatabaseGroundedFDRControl( + confidence_feature="confidence", residue_masses=residue_masses + ) @pytest.fixture() def sample_dataset_df(self): @@ -32,59 +48,40 @@ def test_initialization(self, db_fdr_control): assert db_fdr_control._fdr_values is None assert db_fdr_control._confidence_scores is None - @patch("winnow.fdr.database_grounded.Metrics") - def test_fit_basic(self, mock_metrics, db_fdr_control, sample_dataset_df): + def test_fit_basic(self, db_fdr_control, sample_dataset_df): """Test basic fitting functionality.""" - # Mock the Metrics class and its methods - mock_metrics_instance = Mock() - mock_metrics.return_value = mock_metrics_instance - mock_metrics_instance._split_peptide = lambda x: list(x) - - residue_masses = { - "P": 100.0, - "E": 110.0, - "T": 120.0, - "I": 130.0, - "D": 140.0, - "R": 150.0, - "O": 160.0, - "N": 170.0, - "S": 180.0, - "A": 190.0, - "M": 200.0, - "L": 210.0, - } + # Convert sequences to list format as expected by the implementation + sample_dataset_df = sample_dataset_df.copy() + sample_dataset_df["prediction"] = sample_dataset_df["prediction"].apply(list) # Should not raise an exception - db_fdr_control.fit(sample_dataset_df, residue_masses) + db_fdr_control.fit(sample_dataset_df) - # Check that metrics was called - mock_metrics.assert_called_once() + # Check that fit created the required attributes + assert hasattr(db_fdr_control, "preds") + assert hasattr(db_fdr_control, "_fdr_values") + assert hasattr(db_fdr_control, "_confidence_scores") + assert db_fdr_control._fdr_values is not None + assert db_fdr_control._confidence_scores is not None def test_fit_with_parameters(self, db_fdr_control): """Test fit with custom parameters.""" sample_df = pd.DataFrame( - {"sequence": ["TEST"], "prediction": ["TEST"], "confidence": [0.9]} + {"sequence": ["TEST"], "prediction": [list("TEST")], "confidence": [0.9]} ) - residue_masses = {"T": 100.0, "E": 110.0, "S": 120.0} - - with patch("winnow.fdr.database_grounded.Metrics") as mock_metrics: - mock_metrics_instance = Mock() - mock_metrics.return_value = mock_metrics_instance - mock_metrics_instance._split_peptide = lambda x: list(x) - db_fdr_control.fit( - sample_df, residue_masses, isotope_error_range=(0, 2), drop=5 - ) + db_fdr_control.fit(sample_df) - # Check that Metrics was initialized with correct parameters - mock_metrics.assert_called_once() + # Check that fit created the required attributes + assert hasattr(db_fdr_control, "preds") + assert len(db_fdr_control.preds) == 1 + assert db_fdr_control.preds.iloc[0]["confidence"] == 0.9 def test_fit_with_empty_data(self, db_fdr_control): """Test that fit method handles empty data.""" empty_data = pd.DataFrame() with pytest.raises(AssertionError, match="Fit method requires non-empty data"): - db_fdr_control.fit(empty_data, residue_masses={"A": 71.03}) + db_fdr_control.fit(empty_data) def test_get_confidence_cutoff_requires_fitting(self, db_fdr_control): """Test that get_confidence_cutoff requires fitting first.""" diff --git a/tests/scripts/test_config_paths.py b/tests/scripts/test_config_paths.py new file mode 100644 index 0000000..ac3ccf1 --- /dev/null +++ b/tests/scripts/test_config_paths.py @@ -0,0 +1,199 @@ +"""Tests for config path resolution utilities.""" + +from __future__ import annotations + +import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock + +from winnow.scripts.config_path_utils 
import ( + get_config_dir, + get_config_search_path, + get_primary_config_dir, +) + + +class TestGetConfigDir: + """Tests for get_config_dir() function.""" + + def test_package_mode(self, tmp_path): + """Test config dir resolution in package mode.""" + # Mock importlib.resources to simulate installed package + package_configs = tmp_path / "package_configs" + package_configs.mkdir() + + # Create a mock that simulates files("winnow").joinpath("configs") + mock_winnow_files = MagicMock() + mock_configs = MagicMock() + mock_configs.is_dir.return_value = True + # Make str() return the path + type(mock_configs).__str__ = lambda self: str(package_configs) + mock_winnow_files.joinpath.return_value = mock_configs + + # Mock files("winnow") to return mock_winnow_files + mock_files = MagicMock(return_value=mock_winnow_files) + + # Patch importlib.resources.files (where it's imported from) + with patch("importlib.resources.files", mock_files): + config_dir = get_config_dir() + assert config_dir == package_configs + + def test_dev_mode(self, tmp_path): + """Test config dir resolution in dev mode.""" + # Mock importlib.resources to fail (simulating dev mode) + with patch("importlib.resources.files", side_effect=ModuleNotFoundError()): + # Create a mock repo structure + repo_root = tmp_path / "repo" + winnow_dir = repo_root / "winnow" + configs_dir = winnow_dir / "configs" + configs_dir.mkdir(parents=True) + + # Mock __file__ to point to winnow/scripts/config_paths.py + with patch( + "winnow.scripts.config_path_utils.__file__", + str(winnow_dir / "scripts" / "config_paths.py"), + ): + config_dir = get_config_dir() + assert config_dir == configs_dir + + def test_dev_mode_alt_location(self, tmp_path): + """Test config dir resolution in dev mode with configs at repo root.""" + # Mock importlib.resources to fail + with patch("importlib.resources.files", side_effect=ModuleNotFoundError()): + # Create a mock repo structure with configs at root + repo_root = tmp_path / "repo" + configs_dir = repo_root / "configs" + configs_dir.mkdir(parents=True) + + # Mock __file__ to point to winnow/scripts/config_paths.py + with patch( + "winnow.scripts.config_path_utils.__file__", + str(repo_root / "winnow" / "scripts" / "config_paths.py"), + ): + config_dir = get_config_dir() + assert config_dir == configs_dir + + def test_not_found(self): + """Test error when config dir cannot be found.""" + with patch("importlib.resources.files", side_effect=ModuleNotFoundError()): + with patch( + "winnow.scripts.config_path_utils.__file__", "/nonexistent/path" + ): + with pytest.raises(FileNotFoundError): + get_config_dir() + + +class TestGetConfigSearchPath: + """Tests for get_config_search_path() function.""" + + def test_custom_dir_only(self, tmp_path): + """Test search path with custom directory.""" + custom_dir = tmp_path / "custom_configs" + custom_dir.mkdir() + + with patch( + "winnow.scripts.config_path_utils.get_config_dir", + return_value=Path("/package/configs"), + ): + search_path = get_config_search_path(str(custom_dir)) + assert len(search_path) == 2 + assert search_path[0] == custom_dir.resolve() + assert search_path[1] == Path("/package/configs").resolve() + + def test_no_custom_dir(self, tmp_path): + """Test search path without custom directory.""" + package_dir = tmp_path / "package_configs" + package_dir.mkdir() + + with patch( + "winnow.scripts.config_path_utils.get_config_dir", return_value=package_dir + ): + search_path = get_config_search_path() + assert len(search_path) == 1 + assert search_path[0] == 
package_dir.resolve() + + def test_custom_dir_not_exists(self): + """Test error when custom directory doesn't exist.""" + with pytest.raises(FileNotFoundError, match="does not exist"): + get_config_search_path("/nonexistent/path") + + def test_custom_dir_not_directory(self, tmp_path): + """Test error when custom path is not a directory.""" + file_path = tmp_path / "not_a_dir" + file_path.touch() + + with pytest.raises(ValueError, match="not a directory"): + get_config_search_path(str(file_path)) + + +class TestGetPrimaryConfigDir: + """Tests for get_primary_config_dir() function.""" + + def test_no_custom_dir(self, tmp_path): + """Test primary config dir without custom directory.""" + package_dir = tmp_path / "package_configs" + package_dir.mkdir() + + with patch( + "winnow.scripts.config_path_utils.get_config_dir", return_value=package_dir + ): + primary_dir = get_primary_config_dir() + assert primary_dir == package_dir.resolve() + + def test_with_custom_dir(self, tmp_path): + """Test primary config dir with custom directory (merged).""" + custom_dir = tmp_path / "custom_configs" + custom_dir.mkdir() + (custom_dir / "residues.yaml").write_text("custom: true") + + package_dir = tmp_path / "package_configs" + package_dir.mkdir() + (package_dir / "train.yaml").write_text("package: true") + (package_dir / "residues.yaml").write_text("package: true") + + with patch( + "winnow.scripts.config_path_utils.get_config_dir", return_value=package_dir + ): + primary_dir = get_primary_config_dir(str(custom_dir)) + + # Should be a temporary merged directory + assert primary_dir.exists() + assert primary_dir.is_dir() + + # Custom config should override package config + residues_content = (primary_dir / "residues.yaml").read_text() + assert "custom: true" in residues_content + + # Package config should be available for files not in custom dir + assert (primary_dir / "train.yaml").exists() + train_content = (primary_dir / "train.yaml").read_text() + assert "package: true" in train_content + + def test_partial_configs(self, tmp_path): + """Test that partial configs work (only custom residues.yaml).""" + custom_dir = tmp_path / "custom_configs" + custom_dir.mkdir() + (custom_dir / "residues.yaml").write_text("custom_residues: true") + + package_dir = tmp_path / "package_configs" + package_dir.mkdir() + (package_dir / "train.yaml").write_text("train_config: true") + (package_dir / "residues.yaml").write_text("package_residues: true") + (package_dir / "calibrator.yaml").write_text("calibrator_config: true") + + with patch( + "winnow.scripts.config_path_utils.get_config_dir", return_value=package_dir + ): + primary_dir = get_primary_config_dir(str(custom_dir)) + + # Custom residues should override + residues_content = (primary_dir / "residues.yaml").read_text() + assert "custom_residues" in residues_content + assert "package_residues" not in residues_content + + # Package files not in custom dir should be available + assert (primary_dir / "train.yaml").exists() + assert (primary_dir / "calibrator.yaml").exists() + + train_content = (primary_dir / "train.yaml").read_text() + assert "train_config" in train_content diff --git a/uv.lock b/uv.lock index 7b2a52c..abd63aa 100644 --- a/uv.lock +++ b/uv.lock @@ -5209,6 +5209,7 @@ version = "1.0.3" source = { editable = "." 
} dependencies = [ { name = "huggingface-hub" }, + { name = "hydra-core" }, { name = "instanovo" }, { name = "koinapy" }, { name = "tomli" }, @@ -5239,6 +5240,7 @@ notebook = [ [package.metadata] requires-dist = [ { name = "huggingface-hub", specifier = ">=0.35.3" }, + { name = "hydra-core", specifier = ">=1.3.2" }, { name = "instanovo", specifier = ">=1.1.4" }, { name = "koinapy", specifier = ">=0.0.10" }, { name = "tomli", specifier = ">=2.2.1" }, diff --git a/winnow/calibration/calibration_features.py b/winnow/calibration/calibration_features.py index b5a820e..3fd17df 100644 --- a/winnow/calibration/calibration_features.py +++ b/winnow/calibration/calibration_features.py @@ -13,7 +13,6 @@ import koinapy from winnow.datasets.calibration_dataset import CalibrationDataset -from winnow.constants import INVALID_PROSIT_TOKENS def map_modification(peptide: List[str]) -> List[str]: @@ -197,18 +196,28 @@ def compute_ion_identifications( class PrositFeatures(CalibrationFeatures): """A class for extracting features related to Prosit: a machine learning-based intensity prediction tool for peptide fragmentation.""" - def __init__(self, mz_tolerance: float, learn_from_missing: bool = True) -> None: + def __init__( + self, + mz_tolerance: float, + invalid_prosit_tokens: List[str], + learn_from_missing: bool = True, + prosit_intensity_model_name: str = "Prosit_2020_intensity_HCD", + ) -> None: """Initialize PrositFeatures. Args: mz_tolerance (float): The mass-to-charge tolerance for ion matching. + invalid_prosit_tokens (List[str]): The tokens to consider as invalid for Prosit intensity prediction. learn_from_missing (bool): Whether to learn from missing data by including a missingness indicator column. If False, an error will be raised when invalid spectra are encountered. Defaults to True. + prosit_intensity_model_name (str): The name of the Prosit intensity model to use. + Defaults to "Prosit_2020_intensity_HCD". """ self.mz_tolerance = mz_tolerance + self.invalid_prosit_tokens = invalid_prosit_tokens self.learn_from_missing = learn_from_missing - self.prosit_intensity_model_name = "Prosit_2020_intensity_HCD" + self.prosit_intensity_model_name = prosit_intensity_model_name @property def dependencies(self) -> List[FeatureDependency]: @@ -266,7 +275,7 @@ def check_valid_prosit_prediction(self, dataset: CalibrationDataset) -> pd.Serie metadata_predicate=lambda row: ( any( token in row["prediction_untokenised"] - for token in INVALID_PROSIT_TOKENS + for token in self.invalid_prosit_tokens ) ) ) @@ -322,7 +331,7 @@ def compute(self, dataset: CalibrationDataset) -> None: f"Please filter your dataset to remove:\n" f" - Peptides longer than 30 amino acids\n" f" - Precursor charges greater than 6\n" - f" - Peptides with unsupported modifications (e.g., {', '.join(INVALID_PROSIT_TOKENS[:3])}...)\n" + f" - Peptides with unsupported modifications (e.g., {', '.join(self.invalid_prosit_tokens[:3])}...)\n" f"Or set learn_from_missing=True to handle missing data automatically." ) @@ -413,18 +422,28 @@ class ChimericFeatures(CalibrationFeatures): are stored in the dataset metadata. """ - def __init__(self, mz_tolerance: float, learn_from_missing: bool = True) -> None: + def __init__( + self, + mz_tolerance: float, + invalid_prosit_tokens: List[str], + learn_from_missing: bool = True, + prosit_intensity_model_name: str = "Prosit_2020_intensity_HCD", + ) -> None: """Initialize ChimericFeatures. Args: mz_tolerance (float): The mass-to-charge tolerance for ion matching. 
+ invalid_prosit_tokens (List[str]): The tokens to consider as invalid for Prosit intensity prediction. learn_from_missing (bool): Whether to learn from missing data by including a missingness indicator column. If False, an error will be raised when invalid spectra are encountered. Defaults to True. + prosit_intensity_model_name (str): The name of the Prosit intensity model to use. + Defaults to "Prosit_2020_intensity_HCD". """ self.mz_tolerance = mz_tolerance self.learn_from_missing = learn_from_missing - self.prosit_intensity_model_name = "Prosit_2020_intensity_HCD" + self.invalid_prosit_tokens = invalid_prosit_tokens + self.prosit_intensity_model_name = prosit_intensity_model_name @property def dependencies(self) -> List[FeatureDependency]: @@ -488,7 +507,7 @@ def check_valid_chimeric_prosit_prediction( len(beam) > 1 and any( token in "".join(beam[1].sequence) - for token in INVALID_PROSIT_TOKENS + for token in self.invalid_prosit_tokens ) ) ) @@ -552,7 +571,7 @@ def compute(self, dataset: CalibrationDataset) -> None: f" - Spectra without runner-up sequences (beam search required)\n" f" - Runner-up peptides longer than 30 amino acids\n" f" - Runner-up peptides with precursor charges greater than 6\n" - f" - Runner-up peptides with unsupported modifications (e.g., {', '.join(INVALID_PROSIT_TOKENS[:3])}...)\n" + f" - Runner-up peptides with unsupported modifications (e.g., {', '.join(self.invalid_prosit_tokens[:3])}...)\n" f"Or set learn_from_missing=True to handle missing data automatically." ) @@ -834,23 +853,50 @@ class RetentionTimeFeature(CalibrationFeatures): irt_predictor: MLPRegressor def __init__( - self, hidden_dim: int, train_fraction: float, learn_from_missing: bool = True + self, + hidden_dim: int, + train_fraction: float, + invalid_prosit_tokens: List[str], + learn_from_missing: bool = True, + seed: int = 42, + learning_rate_init: float = 0.001, + alpha: float = 0.0001, + max_iter: int = 200, + early_stopping: bool = False, + validation_fraction: float = 0.1, + prosit_irt_model_name: str = "Prosit_2019_irt", ) -> None: """Initialize RetentionTimeFeature. Args: hidden_dim (int): Hidden dimension size for the MLP regressor. train_fraction (float): Fraction of data to use for training the iRT calibrator. + invalid_prosit_tokens (List[str]): The tokens to consider as invalid for Prosit iRT prediction. learn_from_missing (bool): Whether to learn from missing data by including a missingness indicator column. If False, an error will be raised when invalid spectra are encountered. Defaults to True. + seed (int): Random seed for the regressor. Defaults to 42. + learning_rate_init (float): The initial learning rate. Defaults to 0.001. + alpha (float): L2 regularisation parameter. Defaults to 0.0001. + max_iter (int): Maximum number of training iterations. Defaults to 200. + early_stopping (bool): Whether to use early stopping to terminate training. Defaults to False. + validation_fraction (float): Proportion of training data to use for early stopping validation. Defaults to 0.1. + prosit_irt_model_name (str): The name of the Prosit iRT model to use. + Defaults to "Prosit_2019_irt". 
""" self.train_fraction = train_fraction self.hidden_dim = hidden_dim self.learn_from_missing = learn_from_missing - self.prosit_irt_model_name = "Prosit_2019_irt" + self.invalid_prosit_tokens = invalid_prosit_tokens + self.prosit_irt_model_name = prosit_irt_model_name self.irt_predictor = MLPRegressor( - hidden_layer_sizes=[hidden_dim], random_state=42 + hidden_layer_sizes=[hidden_dim], + random_state=seed, + learning_rate_init=learning_rate_init, + alpha=alpha, + max_iter=max_iter, + early_stopping=early_stopping, + validation_fraction=validation_fraction, ) @property @@ -906,7 +952,7 @@ def check_valid_irt_prediction(self, dataset: CalibrationDataset) -> pd.Series: metadata_predicate=lambda row: ( any( token in row["prediction_untokenised"] - for token in INVALID_PROSIT_TOKENS + for token in self.invalid_prosit_tokens ) ) ) @@ -1003,7 +1049,7 @@ def compute(self, dataset: CalibrationDataset) -> None: f" - Spectra without retention time data\n" f" - Peptides longer than 30 amino acids\n" f" - Precursor charges greater than 6\n" - f" - Peptides with unsupported modifications (e.g., {', '.join(INVALID_PROSIT_TOKENS[:3])}...)\n" + f" - Peptides with unsupported modifications (e.g., {', '.join(self.invalid_prosit_tokens[:3])}...)\n" f"Or set learn_from_missing=True to handle missing data automatically." ) diff --git a/winnow/calibration/calibrator.py b/winnow/calibration/calibrator.py index 1ea2d6a..b3b927d 100644 --- a/winnow/calibration/calibrator.py +++ b/winnow/calibration/calibrator.py @@ -8,6 +8,8 @@ from sklearn.preprocessing import StandardScaler from numpy.typing import NDArray from huggingface_hub import snapshot_download +from omegaconf import DictConfig + from winnow.calibration.calibration_features import ( CalibrationFeatures, FeatureDependency, @@ -21,21 +23,54 @@ class ProbabilityCalibrator: This class provides functionality to recalibrate predicted probabilities by fitting an MLP classifier using various features computed from a calibration dataset. """ - def __init__(self, seed: int = 42) -> None: + def __init__( + self, + seed: int = 42, + features: Optional[ + Union[List[CalibrationFeatures], Dict[str, CalibrationFeatures], DictConfig] + ] = None, + hidden_layer_sizes: Tuple[int, ...] = (50, 50), + learning_rate_init: float = 0.001, + alpha: float = 0.0001, + max_iter: int = 1000, + early_stopping: bool = True, + validation_fraction: float = 0.1, + ) -> None: + """Initialise the probability calibrator. + + Args: + seed (int): Random seed for the classifier. Defaults to 42. + features (Optional[Union[List[CalibrationFeatures], Dict[str, CalibrationFeatures], DictConfig]]): + Features to add to the calibrator. Can be a list or dict of CalibrationFeatures objects. + If None, no features are added. Defaults to None. + hidden_layer_sizes (Tuple[int, ...]): The number of neurons in each hidden layer. Defaults to (50, 50). + learning_rate_init (float): The initial learning rate. Defaults to 0.001. + alpha (float): L2 regularisation parameter. Defaults to 0.0001. + max_iter (int): Maximum number of training iterations. Defaults to 1000. + early_stopping (bool): Whether to use early stopping to terminate training. Defaults to True. + validation_fraction (float): Proportion of training data to use for early stopping validation. Defaults to 0.1. 
+ """ self.feature_dict: Dict[str, CalibrationFeatures] = {} self.dependencies: Dict[str, FeatureDependency] = {} self.dependency_reference_counter: Dict[str, int] = {} self.classifier = MLPClassifier( random_state=seed, - hidden_layer_sizes=(50, 50), - learning_rate_init=0.001, - alpha=0.0001, - max_iter=1000, - early_stopping=True, - validation_fraction=0.1, + hidden_layer_sizes=hidden_layer_sizes, + learning_rate_init=learning_rate_init, + alpha=alpha, + max_iter=max_iter, + early_stopping=early_stopping, + validation_fraction=validation_fraction, ) self.scaler = StandardScaler() + # Add features if provided + if features is not None: + if isinstance(features, (dict, DictConfig)): + self.add_features(list(features.values())) + else: + self.add_features(list(features)) + @property def columns(self) -> List[str]: """Returns the list of column names corresponding to the features added to the calibrator. @@ -59,14 +94,18 @@ def features(self) -> List[str]: return list(self.feature_dict.keys()) @classmethod - def save(cls, calibrator: "ProbabilityCalibrator", dir_path: Path) -> None: + def save( + cls, calibrator: "ProbabilityCalibrator", dir_path: Union[Path, str] + ) -> None: """Save the calibrator to a file. Args: calibrator (ProbabilityCalibrator): The calibrator to save. dir_path (Path): The path to the directory where the calibrator checkpoint will be saved. """ - dir_path.mkdir(parents=True) + if isinstance(dir_path, str): + dir_path = Path(dir_path) + dir_path.mkdir(parents=True, exist_ok=True) pickle.dump(calibrator, open(dir_path / "calibrator.pkl", "wb")) @classmethod diff --git a/winnow/configs/calibrator.yaml b/winnow/configs/calibrator.yaml new file mode 100644 index 0000000..0816814 --- /dev/null +++ b/winnow/configs/calibrator.yaml @@ -0,0 +1,48 @@ +# --- Calibrator configuration --- + +calibrator: + _target_: winnow.calibration.calibrator.ProbabilityCalibrator + + seed: 42 + hidden_layer_sizes: [50, 50] # The number of neurons in each hidden layer of the MLP classifier. + learning_rate_init: 0.001 # The initial learning rate for the MLP classifier. + alpha: 0.0001 # L2 regularisation parameter for the MLP classifier. + max_iter: 1000 # Maximum number of training iterations for the MLP classifier. + early_stopping: true # Whether to use early stopping to terminate training. + validation_fraction: 0.1 # Proportion of training data to use for early stopping validation. + + features: + mass_error: + _target_: winnow.calibration.calibration_features.MassErrorFeature + residue_masses: ${residue_masses} # The residue masses to use for the mass error feature. + + prosit_features: + _target_: winnow.calibration.calibration_features.PrositFeatures + mz_tolerance: 0.02 + learn_from_missing: true # Whether to learn from missing Prosit features. If False, errors will be raised when invalid spectra are encountered. + invalid_prosit_tokens: ${invalid_prosit_tokens} # The tokens to consider as invalid for Prosit features. + prosit_intensity_model_name: Prosit_2020_intensity_HCD # The name of the Prosit intensity model to use. + + retention_time_feature: + _target_: winnow.calibration.calibration_features.RetentionTimeFeature + hidden_dim: 10 # The hidden dimension size for the MLP regressor used to predict iRT from observed retention times. + train_fraction: 0.1 # The fraction of the data to use for training the iRT predictor. + learn_from_missing: true # Whether to learn from missing retention time features. If False, errors will be raised when invalid spectra are encountered. 
+ seed: 42 # Random seed for the MLP regressor. + learning_rate_init: 0.001 # The initial learning rate for the MLP regressor. + alpha: 0.0001 # L2 regularisation parameter for the MLP regressor. + max_iter: 200 # Maximum number of training iterations for the MLP regressor. + early_stopping: false # Whether to use early stopping for the MLP regressor. + validation_fraction: 0.1 # Proportion of training data to use for early stopping validation. + invalid_prosit_tokens: ${invalid_prosit_tokens} # The tokens to consider as invalid for Prosit iRT features. + prosit_irt_model_name: Prosit_2019_irt # The name of the Prosit iRT model to use. + + chimeric_features: + _target_: winnow.calibration.calibration_features.ChimericFeatures + mz_tolerance: 0.02 + learn_from_missing: true # Whether to learn from missing chimeric features. If False, errors will be raised when invalid spectra are encountered. + invalid_prosit_tokens: ${invalid_prosit_tokens} # The tokens to consider as invalid for Prosit chimeric intensity features. + prosit_intensity_model_name: Prosit_2020_intensity_HCD # The name of the Prosit intensity model to use. + + beam_features: + _target_: winnow.calibration.calibration_features.BeamFeatures diff --git a/winnow/configs/data_loader/instanovo.yaml b/winnow/configs/data_loader/instanovo.yaml new file mode 100644 index 0000000..caf142e --- /dev/null +++ b/winnow/configs/data_loader/instanovo.yaml @@ -0,0 +1,23 @@ +# --- InstaNovo data loading configuration --- + +_target_: winnow.datasets.data_loaders.InstaNovoDatasetLoader + +residue_masses: ${residue_masses} +residue_remapping: # Used to map InstaNovo legacy notations to UNIMOD tokens. + "M(ox)": "M[UNIMOD:35]" # Oxidation + "M(+15.99)": "M[UNIMOD:35]" # Oxidation + "S(p)": "S[UNIMOD:21]" # Phosphorylation + "T(p)": "T[UNIMOD:21]" # Phosphorylation + "Y(p)": "Y[UNIMOD:21]" # Phosphorylation + "S(+79.97)": "S[UNIMOD:21]" # Phosphorylation + "T(+79.97)": "T[UNIMOD:21]" # Phosphorylation + "Y(+79.97)": "Y[UNIMOD:21]" # Phosphorylation + "Q(+0.98)": "Q[UNIMOD:7]" # Deamidation + "N(+0.98)": "N[UNIMOD:7]" # Deamidation + "Q(+.98)": "Q[UNIMOD:7]" # Deamidation + "N(+.98)": "N[UNIMOD:7]" # Deamidation + "C(+57.02)": "C[UNIMOD:4]" # Carbamidomethylation + # N-terminal modifications. + "(+42.01)": "[UNIMOD:1]" # Acetylation + "(+43.01)": "[UNIMOD:5]" # Carbamylation + "(-17.03)": "[UNIMOD:385]" # Ammonia loss diff --git a/winnow/configs/data_loader/mztab.yaml b/winnow/configs/data_loader/mztab.yaml new file mode 100644 index 0000000..dc184ea --- /dev/null +++ b/winnow/configs/data_loader/mztab.yaml @@ -0,0 +1,20 @@ +# --- MZTab data loading configuration --- +_target_: winnow.datasets.data_loaders.MZTabDatasetLoader + +residue_masses: ${residue_masses} +residue_remapping: # Used to map Casanovo-specific notations to UNIMOD tokens. + "M+15.995": "M[UNIMOD:35]" # Oxidation + "Q+0.984": "Q[UNIMOD:7]" # Deamidation + "N+0.984": "N[UNIMOD:7]" # Deamidation + "+42.011": "[UNIMOD:1]" # Acetylation + "+43.006": "[UNIMOD:5]" # Carbamylation + "-17.027": "[UNIMOD:385]" # Ammonia loss + "C+57.021": "C[UNIMOD:4]" # Carbamidomethylation + "C[Carbamidomethyl]": "C[UNIMOD:4]" # Carbamidomethylation + "M[Oxidation]": "M[UNIMOD:35]" # Oxidation + "N[Deamidated]": "N[UNIMOD:7]" # Deamidation + "Q[Deamidated]": "Q[UNIMOD:7]" # Deamidation + # N-terminal modifications. 
+ "[Acetyl]-": "[UNIMOD:1]" # Acetylation + "[Carbamyl]-": "[UNIMOD:5]" # Carbamylation + "[Ammonia-loss]-": "[UNIMOD:385]" # Ammonia loss diff --git a/winnow/configs/data_loader/pointnovo.yaml b/winnow/configs/data_loader/pointnovo.yaml new file mode 100644 index 0000000..022691a --- /dev/null +++ b/winnow/configs/data_loader/pointnovo.yaml @@ -0,0 +1,5 @@ +# --- PointNovo data loading configuration --- + +_target_: winnow.datasets.data_loaders.PointNovoDatasetLoader + +residue_masses: ${residue_masses} diff --git a/winnow/configs/data_loader/winnow.yaml b/winnow/configs/data_loader/winnow.yaml new file mode 100644 index 0000000..dbfb632 --- /dev/null +++ b/winnow/configs/data_loader/winnow.yaml @@ -0,0 +1,7 @@ +# --- Winnow data loading configuration --- + +_target_: winnow.datasets.data_loaders.WinnowDatasetLoader + +residue_masses: ${residue_masses} +# The internal Winnow dataset loader does not need a residue remapping +# since it uses the UNIMOD tokens directly. diff --git a/winnow/configs/fdr_method/database_grounded.yaml b/winnow/configs/fdr_method/database_grounded.yaml new file mode 100644 index 0000000..41a6cc3 --- /dev/null +++ b/winnow/configs/fdr_method/database_grounded.yaml @@ -0,0 +1,8 @@ +# --- Database-grounded FDR control configuration --- + +_target_: winnow.fdr.database_grounded.DatabaseGroundedFDRControl + +confidence_feature: ${fdr_control.confidence_column} # Name of the column with confidence scores to use for FDR estimation. +residue_masses: ${residue_masses} # The residue masses from global `residues` config +isotope_error_range: [0, 1] # The isotope error range for matching peptides +drop: 10 # The number of top predictions to drop for stability diff --git a/winnow/configs/fdr_method/nonparametric.yaml b/winnow/configs/fdr_method/nonparametric.yaml new file mode 100644 index 0000000..2d8c5a3 --- /dev/null +++ b/winnow/configs/fdr_method/nonparametric.yaml @@ -0,0 +1,3 @@ +# --- Non-parametric FDR control configuration --- + +_target_: winnow.fdr.nonparametric.NonParametricFDRControl diff --git a/winnow/configs/predict.yaml b/winnow/configs/predict.yaml new file mode 100644 index 0000000..fa53a6a --- /dev/null +++ b/winnow/configs/predict.yaml @@ -0,0 +1,38 @@ +# --- Predicting scores and applying FDR control --- +defaults: + - _self_ + - residues + - data_loader: instanovo # Options: instanovo, mztab, pointnovo, winnow + - fdr_method: nonparametric # Options: nonparametric, database_grounded + +# --- Pipeline Execution Configuration --- + +dataset: + # Dataset paths: + # Path to the spectrum data file or to folder containing saved internal Winnow dataset. + spectrum_path_or_directory: examples/example_data/spectra.ipc + # Path to the beam predictions file. + # Leave as `null` if data source is `winnow`, or loading will fail. + predictions_path: examples/example_data/predictions.csv + # NOTE: Make sure that the data loader type matches the data source type in this dataset section. + +calibrator: + # Model loading: + # Path to the local calibrator directory or the HuggingFace model identifier. + # If the path is a local directory path, it will be used directly. If it is a HuggingFace repository identifier, it will be downloaded from HuggingFace. + pretrained_model_name_or_path: InstaDeepAI/winnow-general-model + # Directory to cache the HuggingFace model. + cache_dir: null # can be set to `null` if using local model or for the default cache directory from HuggingFace. + +fdr_control: + # FDR settings: + # Target FDR threshold (e.g. 0.01 for 1%, 0.05 for 5% etc.). 
+ fdr_threshold: 0.05 + # Name of the column with confidence scores to use for FDR estimation. + confidence_column: calibrated_confidence + +# Folder path to write the outputs to. +# This will create two CSV files in the output folder: +# - metadata.csv: Contains all metadata and feature columns from the input dataset. +# - preds_and_fdr_metrics.csv: Contains predictions and FDR metrics. +output_folder: results/predictions diff --git a/winnow/configs/residues.yaml b/winnow/configs/residues.yaml new file mode 100644 index 0000000..d76d4a3 --- /dev/null +++ b/winnow/configs/residues.yaml @@ -0,0 +1,64 @@ +# --- Residues configuration --- + +# This is Winnow's internal residue representation. +# We use this to calculate the mass error feature and during database-grounded FDR control. +# We also use this to initialise the residue set for the Metrics class. +residue_masses: + "G": 57.021464 + "A": 71.037114 + "S": 87.032028 + "P": 97.052764 + "V": 99.068414 + "T": 101.047670 + "C": 103.009185 + "L": 113.084064 + "I": 113.084064 + "N": 114.042927 + "D": 115.026943 + "Q": 128.058578 + "K": 128.094963 + "E": 129.042593 + "M": 131.040485 + "H": 137.058912 + "F": 147.068414 + "R": 156.101111 + "Y": 163.063329 + "W": 186.079313 + # Modifications + "M[UNIMOD:35]": 147.035400 # Oxidation + "C[UNIMOD:4]": 160.030649 # Carboxyamidomethylation + "N[UNIMOD:7]": 115.026943 # Deamidation + "Q[UNIMOD:7]": 129.042594 # Deamidation + "R[UNIMOD:7]": 157.085127 # Arginine citrullination + "P[UNIMOD:35]": 113.047679 # Proline hydroxylation + "S[UNIMOD:21]": 166.998028 # Phosphorylation + 79.966 + "T[UNIMOD:21]": 181.01367 # Phosphorylation + 79.966 + "Y[UNIMOD:21]": 243.029329 # Phosphorylation + 79.966 + "C[UNIMOD:312]": 222.013284 # Cysteinylation + "E[UNIMOD:27]": 111.032028 # Glu -> pyro-Glu + "Q[UNIMOD:28]": 111.032029 # Gln -> pyro-Gln + # Terminal modifications + "[UNIMOD:1]": 42.010565 # Acetylation + "[UNIMOD:5]": 43.005814 # Carbamylation + "[UNIMOD:385]": -17.026549 # NH3 loss + "(+25.98)": 25.980265 # Carbamylation & NH3 loss (legacy notation) + +# The tokens to consider as invalid for Prosit features. +# We also filter out non-carboxyamidomethylated Cysteine in a separate step. +invalid_prosit_tokens: + # InstaNovo + - "[UNIMOD:7]" + - "[UNIMOD:21]" + - "[UNIMOD:1]" + - "[UNIMOD:5]" + - "[UNIMOD:385]" + - "(+25.98)" # (legacy notation) + # Casanovo + - "+0.984" + - "+42.011" + - "+43.006" + - "-17.027" + - "[Ammonia-loss]-" + - "[Carbamyl]-" + - "[Acetyl]-" + - "[Deamidated]" diff --git a/winnow/configs/train.yaml b/winnow/configs/train.yaml new file mode 100644 index 0000000..839d3f8 --- /dev/null +++ b/winnow/configs/train.yaml @@ -0,0 +1,21 @@ +# --- Training a calibrator --- +defaults: + - _self_ + - residues + - calibrator + - data_loader: instanovo # Options: instanovo, mztab, pointnovo, winnow + +# --- Pipeline Execution Configuration --- + +dataset: + # Dataset paths: + # Path to the spectrum data file or to folder containing saved internal Winnow dataset. + spectrum_path_or_directory: examples/example_data/spectra.ipc + # Path to the beam predictions file. + # Leave as `null` if data source is `winnow`, or loading will fail. + predictions_path: examples/example_data/predictions.csv + # NOTE: Make sure that the data loader type matches the data source type in this dataset section. 
+ +# Output paths: +model_output_dir: models/new_model +dataset_output_path: results/calibrated_dataset.csv diff --git a/winnow/constants.py b/winnow/constants.py deleted file mode 100644 index a4eeb31..0000000 --- a/winnow/constants.py +++ /dev/null @@ -1,103 +0,0 @@ -from instanovo.utils.residues import ResidueSet -from instanovo.utils.metrics import Metrics - -RESIDUE_MASSES: dict[str, float] = { - "G": 57.021464, - "A": 71.037114, - "S": 87.032028, - "P": 97.052764, - "V": 99.068414, - "T": 101.047670, - "C": 103.009185, - "L": 113.084064, - "I": 113.084064, - "N": 114.042927, - "D": 115.026943, - "Q": 128.058578, - "K": 128.094963, - "E": 129.042593, - "M": 131.040485, - "H": 137.058912, - "F": 147.068414, - "R": 156.101111, - "Y": 163.063329, - "W": 186.079313, - # Modifications - "M[UNIMOD:35]": 147.035400, # Oxidation - "N[UNIMOD:7]": 115.026943, # Deamidation - "Q[UNIMOD:7]": 129.042594, # Deamidation - "C[UNIMOD:4]": 160.030649, # Carboxyamidomethylation - "S[UNIMOD:21]": 166.998028, # Phosphorylation - "T[UNIMOD:21]": 181.01367, # Phosphorylation - "Y[UNIMOD:21]": 243.029329, # Phosphorylation - "[UNIMOD:385]": -17.026549, # Ammonia Loss - "[UNIMOD:5]": 43.005814, # Carbamylation - "[UNIMOD:1]": 42.010565, # Acetylation - "C[UNIMOD:312]": 222.013284, # Cysteinylation - "E[UNIMOD:27]": 111.032028, # Glu -> pyro-Glu - "Q[UNIMOD:28]": 111.032029, # Gln -> pyro-Gln - "(+25.98)": 25.980265, # Carbamylation & NH3 loss -} - -RESIDUE_REMAPPING: dict[str, str] = { - "M(ox)": "M[UNIMOD:35]", # Oxidation - "M(+15.99)": "M[UNIMOD:35]", - "S(p)": "S[UNIMOD:21]", # Phosphorylation - "T(p)": "T[UNIMOD:21]", - "Y(p)": "Y[UNIMOD:21]", - "S(+79.97)": "S[UNIMOD:21]", - "T(+79.97)": "T[UNIMOD:21]", - "Y(+79.97)": "Y[UNIMOD:21]", - "Q(+0.98)": "Q[UNIMOD:7]", # Deamidation - "N(+0.98)": "N[UNIMOD:7]", - "Q(+.98)": "Q[UNIMOD:7]", - "N(+.98)": "N[UNIMOD:7]", - "C(+57.02)": "C[UNIMOD:4]", # Carbamidomethylation - "(+42.01)": "[UNIMOD:1]", # Acetylation - "(+43.01)": "[UNIMOD:5]", # Carbamylation - "(-17.03)": "[UNIMOD:385]", # Loss of ammonia -} - -CASANOVO_RESIDUE_REMAPPING: dict[str, str] = { - "M+15.995": "M[UNIMOD:35]", # Oxidation - "Q+0.984": "Q[UNIMOD:7]", # Deamidation - "N+0.984": "N[UNIMOD:7]", # Deamidation - "+42.011": "[UNIMOD:1]", # Acetylation - "+43.006": "[UNIMOD:5]", # Carbamylation - "-17.027": "[UNIMOD:385]", # Loss of ammonia - "C+57.021": "C[UNIMOD:4]", # Carbamidomethylation - # "+43.006-17.027": "[UNIMOD:5][UNIMOD:385]", # Carbamylation and Loss of ammonia - "C[Carbamidomethyl]": "C[UNIMOD:4]", # Carbamidomethylation - "M[Oxidation]": "M[UNIMOD:35]", # Met oxidation: 131.040485 + 15.994915 - "N[Deamidated]": "N[UNIMOD:7]", # Asn deamidation: 114.042927 + 0.984016 - "Q[Deamidated]": "Q[UNIMOD:7]", # Gln deamidation: 128.058578 + 0.984016 - # N-terminal modifications. - "[Acetyl]-": "[UNIMOD:1]", # Acetylation - "[Carbamyl]-": "[UNIMOD:5]", # Carbamylation - "[Ammonia-loss]-": "[UNIMOD:385]", # Ammonia loss - # "[+25.980265]-": 25.980265 # Carbamylation and ammonia loss -} - -# Each C is also treated as Cysteine with carbamidomethylation in Prosit. 
-INVALID_PROSIT_TOKENS: list = [
-    "(+25.98)",
-    "[UNIMOD:7]",
-    "[UNIMOD:21]",
-    "[UNIMOD:1]",
-    "[UNIMOD:5]",
-    "[UNIMOD:385]",
-    "+0.984",
-    "+42.011",
-    "+43.006",
-    "-17.027",
-    "[Ammonia-loss]-",
-    "[Carbamyl]-",
-    "[Acetyl]-",
-    "[Deamidated]",
-]
-
-
-residue_set = ResidueSet(
-    residue_masses=RESIDUE_MASSES, residue_remapping=RESIDUE_REMAPPING
-)
-metrics = Metrics(residue_set=residue_set, isotope_error_range=[0, 1])
diff --git a/winnow/datasets/calibration_dataset.py b/winnow/datasets/calibration_dataset.py
index 118387a..50df89b 100644
--- a/winnow/datasets/calibration_dataset.py
+++ b/winnow/datasets/calibration_dataset.py
@@ -9,7 +9,7 @@
 
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 import pickle
 
 import numpy as np
@@ -48,7 +48,7 @@ def save(self, data_dir: Path) -> None:
             data_dir (Path): Directory to save the dataset. This will contain `metadata.csv`
                 and optionally, `predictions.pkl` for serialized beam search results.
         """
-        data_dir.mkdir(parents=True)
+        data_dir.mkdir(parents=True, exist_ok=True)
         with (data_dir / "metadata.csv").open(mode="w") as metadata_file:
             output_metadata = self.metadata.copy(deep=True)
             if "sequence" in output_metadata.columns:
@@ -127,20 +127,26 @@ def filter_entries(
 
         return CalibrationDataset(predictions=predictions, metadata=metadata)
 
-    def to_csv(self, path: Path) -> None:
+    def to_csv(self, path: Union[Path, str]) -> None:
         """Saves the dataset metadata to a CSV file.
 
         Args:
             path (str): Path to the output CSV file.
         """
+        if isinstance(path, str):
+            path = Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
         self.metadata.to_csv(path)
 
-    def to_parquet(self, path: str) -> None:
+    def to_parquet(self, path: Union[Path, str]) -> None:
         """Saves the dataset metadata to a parquet file.
 
         Args:
             path (str): Path to the output parquet file.
         """
+        if isinstance(path, str):
+            path = Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
         self.metadata.to_parquet(path)
 
     def _create_predicate_error_message(
diff --git a/winnow/datasets/data_loaders.py b/winnow/datasets/data_loaders.py
index ae2acca..37d3436 100644
--- a/winnow/datasets/data_loaders.py
+++ b/winnow/datasets/data_loaders.py
@@ -9,36 +9,61 @@
 import re
 from pathlib import Path
 from typing import Any, List, Optional, Tuple
-
 import numpy as np
 import pandas as pd
 import polars as pl
 import polars.selectors as cs
 from pyteomics import mztab
+from instanovo.utils.residues import ResidueSet
+from instanovo.utils.metrics import Metrics
 
 from winnow.datasets.interfaces import DatasetLoader
 from winnow.datasets.calibration_dataset import (
     CalibrationDataset,
     ScoredSequence,
 )
-from winnow.constants import metrics, CASANOVO_RESIDUE_REMAPPING
 
 
 class InstaNovoDatasetLoader(DatasetLoader):
     """Loader for InstaNovo predictions in CSV format."""
 
+    def __init__(
+        self,
+        residue_masses: dict[str, float],
+        residue_remapping: dict[str, str],
+        isotope_error_range: Tuple[int, int] = (0, 1),
+    ) -> None:
+        """Initialise the InstaNovoDatasetLoader.
+
+        Args:
+            residue_masses: The mapping from residue tokens (including UNIMOD modifications) to their masses.
+            residue_remapping: The mapping of residue notations to UNIMOD tokens.
+            isotope_error_range: The range of isotope errors to consider when matching peptides.
+        """
+        self.metrics = Metrics(
+            residue_set=ResidueSet(
+                residue_masses=residue_masses, residue_remapping=residue_remapping
+            ),
+            isotope_error_range=isotope_error_range,
+        )
+
     @staticmethod
     def _load_beam_preds(
-        predictions_path: Path,
+        predictions_path: Path | str,
     ) -> Tuple[pl.DataFrame, pl.DataFrame]:
         """Loads a dataset from a CSV file and optionally filters it.
 
         Args:
-            predictions_path (Path): The path to the CSV file containing the predictions.
+            predictions_path (Path | str): The path to the CSV file containing the predictions.
 
         Returns:
             Tuple[pl.DataFrame, pl.DataFrame]: A tuple containing the predictions and beams dataframes.
         """
+        predictions_path = Path(predictions_path)
+        if predictions_path.suffix != ".csv":
+            raise ValueError(
+                f"Unsupported file format for InstaNovo beam predictions: {predictions_path.suffix}. Supported format is .csv."
+            )
         df = pl.read_csv(predictions_path)
 
         # Use polars column selectors to split dataframe
         beam_df = df.select(
@@ -54,7 +79,103 @@ def _load_beam_preds(
         return preds_df, beam_df
 
     @staticmethod
-    def _process_beams(beam_df: pl.DataFrame) -> List[Optional[List[ScoredSequence]]]:
+    def _load_spectrum_data(spectrum_path: Path | str) -> Tuple[pl.DataFrame, bool]:
+        """Loads spectrum data from either a Parquet or IPC file.
+
+        Args:
+            spectrum_path (Path | str): The path to the spectrum data file.
+
+        Returns:
+            Tuple[pl.DataFrame, bool]: A tuple containing the spectrum data and a boolean indicating whether the dataset has ground truth labels.
+        """
+        spectrum_path = Path(spectrum_path)
+
+        if spectrum_path.suffix == ".parquet":
+            df = pl.read_parquet(spectrum_path)
+        elif spectrum_path.suffix == ".ipc":
+            df = pl.read_ipc(spectrum_path)
+        else:
+            raise ValueError(
+                f"Unsupported file format for spectrum data: {spectrum_path.suffix}. Supported formats are .parquet and .ipc."
+            )
+
+        if "sequence" in df.columns:
+            has_labels = True
+        else:
+            has_labels = False
+
+        return df, has_labels
+
+    @staticmethod
+    def _merge_spectrum_data(
+        beam_dataset: pd.DataFrame, spectrum_dataset: pd.DataFrame
+    ) -> pd.DataFrame:
+        """Merge the input and output data from the de novo sequencing model.
+
+        Args:
+            beam_dataset (pd.DataFrame): The dataframe containing the beam predictions.
+            spectrum_dataset (pd.DataFrame): The dataframe containing the spectrum data.
+
+        Returns:
+            pd.DataFrame: The merged dataframe.
+        """
+        merged_df = pd.merge(
+            beam_dataset,
+            spectrum_dataset,
+            on=["spectrum_id"],
+            suffixes=("_from_beams", ""),
+        )
+        merged_df = merged_df.drop(
+            columns=[
+                col + "_from_beams"
+                for col in beam_dataset.columns
+                if col in spectrum_dataset.columns and col != "spectrum_id"
+            ],
+            axis=1,
+        )
+
+        if len(merged_df) != len(beam_dataset):
+            raise ValueError(
+                f"Merge conflict: Expected {len(beam_dataset)} rows, but got {len(merged_df)}."
+            )
+
+        return merged_df
+
+    def load(
+        self, *, data_path: Path, predictions_path: Optional[Path] = None, **kwargs: Any
+    ) -> CalibrationDataset:
+        """Load a CalibrationDataset from InstaNovo CSV predictions.
+
+        Args:
+            data_path: Path to the spectrum data file (Parquet or IPC)
+            predictions_path: Path to the CSV beam predictions file
+            **kwargs: Not used
+
+        Returns:
+            CalibrationDataset: An instance of the CalibrationDataset class containing metadata and predictions.
+ + Raises: + ValueError: If predictions_path is None + """ + if predictions_path is None: + raise ValueError("predictions_path is required for InstaNovoDatasetLoader") + + beam_predictions_path = predictions_path + inputs, has_labels = self._load_spectrum_data(data_path) + inputs = self._process_spectrum_data(inputs, has_labels) + + predictions, beams = self._load_beam_preds(beam_predictions_path) + beams = self._process_beams(beams) + predictions = self._process_predictions(predictions.to_pandas(), has_labels) + + predictions = self._merge_spectrum_data(predictions, inputs) + predictions = self._evaluate_predictions(predictions, has_labels) + + return CalibrationDataset(metadata=predictions, predictions=beams) + + def _process_beams( + self, beam_df: pl.DataFrame + ) -> List[Optional[List[ScoredSequence]]]: """Processes beam predictions into scored sequences. Args: @@ -85,7 +206,7 @@ def convert_row_to_scored_sequences( if sequence and log_prob > float("-inf"): scored_sequences.append( ScoredSequence( - sequence=metrics._split_peptide(sequence), + sequence=self.metrics._split_peptide(sequence), mass_error=None, sequence_log_probability=log_prob, token_log_probabilities=token_log_prob, @@ -110,8 +231,35 @@ def convert_row_to_scored_sequences( for row in beam_df.iter_rows(named=True) ] - @staticmethod - def _process_predictions(dataset: pd.DataFrame, has_labels: bool) -> pd.DataFrame: + def _process_spectrum_data( + self, df: pl.DataFrame, has_labels: bool + ) -> pd.DataFrame: + """Processes the input data from the de novo sequencing model. + + Args: + df (pl.DataFrame): The dataframe containing the spectrum data. + has_labels (bool): Whether the dataset has ground truth labels. + + Returns: + pd.DataFrame: The processed dataframe. + """ + # Convert to pandas for downstream compatibility + df = df.to_pandas() + if has_labels: + df["sequence"] = ( + df["sequence"] + .apply( + lambda peptide: peptide.replace("L", "I") + if isinstance(peptide, str) + else peptide + ) + .apply(self.metrics._split_peptide) + ) + return df + + def _process_predictions( + self, dataset: pd.DataFrame, has_labels: bool + ) -> pd.DataFrame: """Processes the predictions obtained from saved beams. Args: @@ -144,7 +292,7 @@ def _process_predictions(dataset: pd.DataFrame, has_labels: bool) -> pd.DataFram else peptide ) dataset["sequence"] = dataset["sequence_untokenised"].apply( - metrics._split_peptide + self.metrics._split_peptide ) dataset["prediction"] = dataset["prediction"].apply( lambda peptide: [ @@ -161,96 +309,9 @@ def _process_predictions(dataset: pd.DataFrame, has_labels: bool) -> pd.DataFram return dataset - @staticmethod - def _load_spectrum_data(spectrum_path: Path | str) -> Tuple[pl.DataFrame, bool]: - """Loads spectrum data from either a Parquet or IPC file. - - Args: - spectrum_path (Path | str): The path to the spectrum data file. - - Returns: - Tuple[pl.DataFrame, bool]: A tuple containing the spectrum data and a boolean indicating whether the dataset has ground truth labels. - """ - spectrum_path = Path(spectrum_path) - - if spectrum_path.suffix == ".parquet": - df = pl.read_parquet(spectrum_path) - elif spectrum_path.suffix == ".ipc": - df = pl.read_ipc(spectrum_path) - else: - raise ValueError( - f"Unsupported file format: {spectrum_path.suffix}. Supported formats are .parquet and .ipc." 
- ) - - if "sequence" in df.columns: - has_labels = True - else: - has_labels = False - - return df, has_labels - - @staticmethod - def _process_spectrum_data(df: pl.DataFrame, has_labels: bool) -> pd.DataFrame: - """Processes the input data from the de novo sequencing model. - - Args: - df (pl.DataFrame): The dataframe containing the spectrum data. - has_labels (bool): Whether the dataset has ground truth labels. - - Returns: - pd.DataFrame: The processed dataframe. - """ - # Convert to pandas for downstream compatibility - df = df.to_pandas() - if has_labels: - df["sequence"] = ( - df["sequence"] - .apply( - lambda peptide: peptide.replace("L", "I") - if isinstance(peptide, str) - else peptide - ) - .apply(metrics._split_peptide) - ) - return df - - @staticmethod - def _merge_spectrum_data( - beam_dataset: pd.DataFrame, spectrum_dataset: pd.DataFrame + def _evaluate_predictions( + self, dataset: pd.DataFrame, has_labels: bool ) -> pd.DataFrame: - """Merge the input and output data from the de novo sequencing model. - - Args: - beam_dataset (pd.DataFrame): The dataframe containing the beam predictions. - spectrum_dataset (pd.DataFrame): The dataframe containing the spectrum data. - - Returns: - pd.DataFrame: The merged dataframe. - """ - merged_df = pd.merge( - beam_dataset, - spectrum_dataset, - on=["spectrum_id"], - suffixes=("_from_beams", ""), - ) - merged_df = merged_df.drop( - columns=[ - col + "_from_beams" - for col in beam_dataset.columns - if col in spectrum_dataset.columns and col != "spectrum_id" - ], - axis=1, - ) - - if len(merged_df) != len(beam_dataset): - raise ValueError( - f"Merge conflict: Expected {len(beam_dataset)} rows, but got {len(merged_df)}." - ) - - return merged_df - - @staticmethod - def _evaluate_predictions(dataset: pd.DataFrame, has_labels: bool) -> pd.DataFrame: """Evaluates predictions in a dataset by checking validity and accuracy. Args: @@ -269,7 +330,9 @@ def _evaluate_predictions(dataset: pd.DataFrame, has_labels: bool) -> pd.DataFra ) if has_labels: dataset["num_matches"] = dataset.apply( - lambda row: metrics._novor_match(row["sequence"], row["prediction"]) + lambda row: self.metrics._novor_match( + row["sequence"], row["prediction"] + ) if isinstance(row["sequence"], list) and isinstance(row["prediction"], list) else 0, @@ -286,38 +349,6 @@ def _evaluate_predictions(dataset: pd.DataFrame, has_labels: bool) -> pd.DataFra ) return dataset - def load( - self, *, data_path: Path, predictions_path: Optional[Path] = None, **kwargs: Any - ) -> CalibrationDataset: - """Load a CalibrationDataset from InstaNovo CSV predictions. - - Args: - data_path: Path to the spectrum data file - predictions_path: Path to the IPC or parquet beam predictions file - **kwargs: Not used - - Returns: - CalibrationDataset: An instance of the CalibrationDataset class containing metadata and predictions. 
-
-        Raises:
-            ValueError: If predictions_path is None
-        """
-        if predictions_path is None:
-            raise ValueError("predictions_path is required for InstaNovoDatasetLoader")
-
-        beam_predictions_path = predictions_path
-        inputs, has_labels = self._load_spectrum_data(data_path)
-        inputs = self._process_spectrum_data(inputs, has_labels)
-
-        predictions, beams = self._load_beam_preds(beam_predictions_path)
-        beams = self._process_beams(beams)
-        predictions = self._process_predictions(predictions.to_pandas(), has_labels)
-
-        predictions = self._merge_spectrum_data(predictions, inputs)
-        predictions = self._evaluate_predictions(predictions, has_labels)
-
-        return CalibrationDataset(metadata=predictions, predictions=beams)
-
 
 class MZTabDatasetLoader(DatasetLoader):
     """Loader for MZTab predictions from both traditional search engines and Casanovo outputs.
@@ -344,36 +375,25 @@ class MZTabDatasetLoader(DatasetLoader):
     """
 
     def __init__(
-        self, residue_remapping: dict[str, str] | None = None, *args: Any, **kwargs: Any
+        self,
+        residue_masses: dict[str, float],
+        residue_remapping: dict[str, str],
+        isotope_error_range: Tuple[int, int] = (0, 1),
     ) -> None:
         """Initialise the MZTabDatasetLoader.
 
         Args:
-            residue_remapping: Optional dictionary mapping modification strings to UNIMOD format.
-                If None, uses the default CASANOVO_RESIDUE_REMAPPING.
-            *args: Additional positional arguments for parent class
-            **kwargs: Additional keyword arguments for parent class
+            residue_masses: The mapping from residue tokens (including UNIMOD modifications) to their masses.
+            residue_remapping: The mapping of residue notations to UNIMOD tokens.
+            isotope_error_range: The range of isotope errors to consider when matching peptides.
         """
-        super().__init__(*args, **kwargs)
-        self.residue_remapping = (
-            residue_remapping
-            if residue_remapping is not None
-            else CASANOVO_RESIDUE_REMAPPING
+        self.metrics = Metrics(
+            residue_set=ResidueSet(
+                residue_masses=residue_masses, residue_remapping=residue_remapping
+            ),
+            isotope_error_range=isotope_error_range,
         )
 
-    @staticmethod
-    def _load_dataset(predictions_path: Path) -> pl.DataFrame:
-        """Load predictions from mzTab file.
-
-        Args:
-            predictions_path: Path to mzTab file containing predictions
-
-        Returns:
-            DataFrame containing predictions
-        """
-        predictions = mztab.MzTab(str(predictions_path)).spectrum_match_table
-        return pl.DataFrame(predictions)
-
     @staticmethod
     def _load_spectrum_data(spectrum_path: Path | str) -> Tuple[pl.DataFrame, bool]:
         """Load spectrum data from either a Parquet or IPC file.
@@ -393,7 +413,7 @@ def _load_spectrum_data(spectrum_path: Path | str) -> Tuple[pl.DataFrame, bool]:
             df = pl.read_ipc(spectrum_path)
         else:
             raise ValueError(
-                f"Unsupported file format: {spectrum_path.suffix}. Supported formats are .parquet and .ipc."
+                f"Unsupported file format for spectrum data: {spectrum_path.suffix}. Supported formats are .parquet and .ipc."
             )
 
         if "sequence" in df.columns:
@@ -401,6 +421,24 @@ def _load_spectrum_data(spectrum_path: Path | str) -> Tuple[pl.DataFrame, bool]:
 
         return df, has_labels
 
+    @staticmethod
+    def _load_dataset(predictions_path: Path | str) -> pl.DataFrame:
+        """Load predictions from mzTab file.
+
+        Args:
+            predictions_path: Path to mzTab file containing predictions
+
+        Returns:
+            DataFrame containing predictions
+        """
+        predictions_path = Path(predictions_path)
+        if predictions_path.suffix != ".mztab":
+            raise ValueError(
+                f"Unsupported file format for MZTab predictions: {predictions_path.suffix}. Supported format is .mztab."
+            )
+        predictions = mztab.MzTab(str(predictions_path)).spectrum_match_table
+        return pl.DataFrame(predictions)
+
     def load(
         self, *, data_path: Path, predictions_path: Optional[Path] = None, **kwargs: Any
     ) -> CalibrationDataset:
@@ -528,7 +566,7 @@ def _tokenize(
         ).with_columns(
             # Split sequence string into list of amino acid tokens
             pl.col(tokenised_column)
-            .map_elements(metrics._split_peptide, return_dtype=pl.List(pl.Utf8))
+            .map_elements(self.metrics._split_peptide, return_dtype=pl.List(pl.Utf8))
             .alias(tokenised_column)
         )
 
@@ -569,7 +607,7 @@ def _create_beam_predictions(
 
     def _map_modifications(self, sequence: str) -> str:
         """Map modifications to UNIMOD."""
-        for mod, unimod in self.residue_remapping.items():
+        for mod, unimod in self.metrics.residue_remapping.items():
             sequence = sequence.replace(mod, unimod)
         return sequence
 
@@ -667,7 +705,9 @@ def _evaluate_predictions(
             # Count matching amino acids between prediction and ground truth
             pl.struct(["sequence", "prediction"])
             .map_elements(
-                lambda row: metrics._novor_match(row["sequence"], row["prediction"])
+                lambda row: self.metrics._novor_match(
+                    row["sequence"], row["prediction"]
+                )
                 if isinstance(row["sequence"], list)
                 and isinstance(row["prediction"], list)
                 else 0,
@@ -724,6 +764,26 @@ def load(
 
 class WinnowDatasetLoader(DatasetLoader):
     """Loader for previously saved CalibrationDataset instances."""
 
+    def __init__(
+        self,
+        residue_masses: dict[str, float],
+        residue_remapping: dict[str, str],
+        isotope_error_range: Tuple[int, int] = (0, 1),
+    ) -> None:
+        """Initialise the WinnowDatasetLoader.
+
+        Args:
+            residue_masses: The mapping from residue tokens (including UNIMOD modifications) to their masses.
+            residue_remapping: The mapping of residue notations to UNIMOD tokens.
+            isotope_error_range: The range of isotope errors to consider when matching peptides.
+        """
+        self.metrics = Metrics(
+            residue_set=ResidueSet(
+                residue_masses=residue_masses, residue_remapping=residue_remapping
+            ),
+            isotope_error_range=isotope_error_range,
+        )
+
     def load(
         self, *, data_path: Path, predictions_path: Optional[Path] = None, **kwargs: Any
     ) -> CalibrationDataset:
@@ -740,34 +800,52 @@ def load(
         if predictions_path is not None:
             raise ValueError("predictions_path is not used for WinnowDatasetLoader")
 
-        with (data_path / "metadata.csv").open(mode="r") as metadata_file:
-            metadata = pd.read_csv(metadata_file)
-            if "sequence" in metadata.columns:
-                metadata["sequence"] = metadata["sequence"].apply(
-                    metrics._split_peptide
-                )
-                metadata["prediction"] = metadata["prediction"].apply(
-                    metrics._split_peptide
+        metadata_csv_path = data_path / "metadata.csv"
+        if not metadata_csv_path.exists():
+            raise FileNotFoundError(
+                f"Winnow dataset loader expects a CSV file containing metadata at {metadata_csv_path}. "
+                f"The specified directory {data_path} should contain a 'metadata.csv' file "
+                f"with PSM metadata from a previously saved Winnow dataset."
             )
-                metadata["mz_array"] = metadata["mz_array"].apply(
-                    lambda s: ast.literal_eval(s)
-                    if "," in s
-                    else ast.literal_eval(
-                        re.sub(r"(\n?)(\s+)", ", ", re.sub(r"\[\s+", "[", s))
-                    )
-                )
-                metadata["intensity_array"] = metadata["intensity_array"].apply(
-                    lambda s: ast.literal_eval(s)
-                    if "," in s
-                    else ast.literal_eval(
-                        re.sub(r"(\n?)(\s+)", ", ", re.sub(r"\[\s+", "[", s))
-                    )
+
+        try:
+            with metadata_csv_path.open(mode="r") as metadata_file:
+                metadata = pd.read_csv(metadata_file)
+        except Exception as e:
+            raise ValueError(
+                f"Failed to read metadata.csv from Winnow dataset directory {data_path}. 
" + f"The file should be a valid CSV containing PSM metadata. Error: {e}" + ) from e + + if "sequence" in metadata.columns: + metadata["sequence"] = metadata["sequence"].apply( + self.metrics._split_peptide ) + metadata["prediction"] = metadata["prediction"].apply( + self.metrics._split_peptide + ) + metadata["mz_array"] = metadata["mz_array"].apply( + lambda s: ast.literal_eval(s) + if "," in s + else ast.literal_eval(re.sub(r"(\n?)(\s+)", ", ", re.sub(r"\[\s+", "[", s))) + ) + metadata["intensity_array"] = metadata["intensity_array"].apply( + lambda s: ast.literal_eval(s) + if "," in s + else ast.literal_eval(re.sub(r"(\n?)(\s+)", ", ", re.sub(r"\[\s+", "[", s))) + ) - predictions_path = data_path / "predictions.pkl" - if predictions_path.exists(): - with predictions_path.open(mode="rb") as predictions_file: - predictions = pickle.load(predictions_file) + predictions_pkl_path = data_path / "predictions.pkl" + if predictions_pkl_path.exists(): + try: + with predictions_pkl_path.open(mode="rb") as predictions_file: + predictions = pickle.load(predictions_file) + except Exception as e: + raise ValueError( + f"Failed to load predictions.pkl from Winnow dataset directory {data_path}. " + f"The file should be a pickled beam predictions object from a previously saved Winnow dataset. " + f"Error: {e}" + ) from e else: predictions = None return CalibrationDataset(metadata=metadata, predictions=predictions) diff --git a/winnow/datasets/interfaces.py b/winnow/datasets/interfaces.py index 91260be..7abfd7a 100644 --- a/winnow/datasets/interfaces.py +++ b/winnow/datasets/interfaces.py @@ -3,7 +3,7 @@ This module provides abstract interfaces that define the contract for dataset loaders. """ -from typing import Protocol, Optional +from typing import Protocol, Optional, Tuple from pathlib import Path from winnow.datasets.calibration_dataset import CalibrationDataset @@ -14,6 +14,21 @@ class DatasetLoader(Protocol): Any class implementing this protocol must provide a load method that returns a CalibrationDataset. """ + def __init__( + self, + residue_masses: dict[str, float], + residue_remapping: dict[str, str] | None = None, + isotope_error_range: Tuple[int, int] = (0, 1), + ) -> None: + """Initialise the DatasetLoader. + + Args: + residue_masses: The mapping of residue masses to UNIMOD tokens. + residue_remapping: Optional mapping of residue notations to UNIMOD tokens. Defaults to None. + isotope_error_range: The range of isotope errors to consider when matching peptides. Defaults to (0, 1). + """ + ... + def load( self, *, data_path: Path, predictions_path: Optional[Path] = None, **kwargs ) -> CalibrationDataset: diff --git a/winnow/fdr/database_grounded.py b/winnow/fdr/database_grounded.py index 141405c..90ba323 100644 --- a/winnow/fdr/database_grounded.py +++ b/winnow/fdr/database_grounded.py @@ -2,9 +2,9 @@ import pandas as pd import numpy as np from instanovo.utils.metrics import Metrics +from instanovo.utils.residues import ResidueSet from winnow.fdr.base import FDRControl -from winnow.constants import residue_set class DatabaseGroundedFDRControl(FDRControl): @@ -13,16 +13,27 @@ class DatabaseGroundedFDRControl(FDRControl): This method estimates FDR thresholds by comparing model-predicted peptides to ground-truth peptides from a database. 
""" - def __init__(self, confidence_feature: str) -> None: + def __init__( + self, + confidence_feature: str, + residue_masses: dict[str, float], + isotope_error_range: Tuple[int, int] = (0, 1), + drop: int = 10, + ) -> None: super().__init__() self.confidence_feature = confidence_feature + self.residue_masses = residue_masses + self.isotope_error_range = isotope_error_range + self.drop = drop + + self.metrics = Metrics( + residue_set=ResidueSet(residue_masses=residue_masses), + isotope_error_range=isotope_error_range, + ) def fit( # type: ignore self, dataset: pd.DataFrame, - residue_masses: dict[str, float], - isotope_error_range: Tuple[int, int] = (0, 1), - drop: int = 10, ) -> None: """Computes the precision-recall curve by comparing model predictions to database-grounded peptide sequences. @@ -32,25 +43,14 @@ def fit( # type: ignore - 'peptide': Ground-truth peptide sequences. - 'prediction': Model-predicted peptide sequences. - 'confidence': Confidence scores associated with predictions. - - residue_masses (dict[str, float]): A dictionary mapping amino acid residues to their respective masses. - - isotope_error_range (Tuple[int, int], optional): Range of isotope errors to consider when matching peptides. Defaults to (0, 1). - - drop (int): Number of top-scoring predictions to exclude when computing FDR thresholds. Defaults to 10. """ assert len(dataset) > 0, "Fit method requires non-empty data" - metrics = Metrics( - residue_set=residue_set, isotope_error_range=isotope_error_range - ) - - dataset["sequence"] = dataset["sequence"].apply(metrics._split_peptide) - # dataset["prediction"] = dataset["prediction"].apply(metrics._split_peptide) + dataset["sequence"] = dataset["sequence"].apply(self.metrics._split_peptide) dataset["num_matches"] = dataset.apply( lambda row: ( - metrics._novor_match(row["sequence"], row["prediction"]) + self.metrics._novor_match(row["sequence"], row["prediction"]) if isinstance(row["prediction"], list) else 0 ), @@ -70,5 +70,5 @@ def fit( # type: ignore precision = np.cumsum(dataset["correct"]) / np.arange(1, len(dataset) + 1) confidence = np.array(dataset[self.confidence_feature]) - self._fdr_values = np.array(1 - precision[drop:]) - self._confidence_scores = confidence[drop:] + self._fdr_values = np.array(1 - precision[self.drop :]) + self._confidence_scores = confidence[self.drop :] diff --git a/winnow/scripts/config_formatter.py b/winnow/scripts/config_formatter.py new file mode 100644 index 0000000..1bb88e4 --- /dev/null +++ b/winnow/scripts/config_formatter.py @@ -0,0 +1,160 @@ +"""Configuration output formatter with hierarchical colour-coding.""" + +from rich.console import Console +from rich.text import Text +from omegaconf import DictConfig, OmegaConf + + +class ConfigFormatter: + """Format Hydra configuration with hierarchical colour-coding based on nesting depth. + + Keys are coloured according to their indentation level to help visualise + the configuration structure. + """ + + # Colour palette for different indentation levels (similar to Typer's style) + INDENT_COLOURS = [ + "bright_cyan", # Level 0 (root keys) + "bright_green", # Level 1 + "bright_yellow", # Level 2 + "bright_magenta", # Level 3 + "bright_blue", # Level 4 + "cyan", # Level 5 + "green", # Level 6 + "yellow", # Level 7+ + ] + + def __init__(self): + """Initialise the formatter.""" + self.console = Console() + + def print_config(self, cfg: DictConfig) -> None: + """Print configuration with hierarchical colour-coding. 
+ + Args: + cfg: OmegaConf configuration object to format and print + """ + yaml_str = OmegaConf.to_yaml(cfg) + output = Text() + + for line in yaml_str.split("\n"): + formatted_line = self._format_line(line) + output.append(formatted_line) + + self.console.print(output, end="") + + def _format_line(self, line: str) -> Text: + """Format a single line of YAML with appropriate colouring. + + Args: + line: A single line from the YAML output + + Returns: + Rich Text object with formatted content + """ + output = Text() + + # Handle empty lines + if not line.strip(): + output.append("\n") + return output + + indent_level = self._get_indent_level(line) + colour = self._get_colour_for_level(indent_level) + + # Handle list items specially (they contain '- ' prefix) + if self._is_list_item(line): + output.append(line) + output.append("\n") + return output + + # Handle key-value pairs + separator_idx = self._find_key_value_separator(line) + if separator_idx != -1: + self._append_key_value_pair(output, line, separator_idx, colour) + else: + # Lines without key-value separator + output.append(line) + output.append("\n") + + return output + + def _get_indent_level(self, line: str) -> int: + """Calculate the indentation level of a line. + + Args: + line: Line to analyse + + Returns: + Indentation level (0 for root, 1 for first nested level, etc.) + """ + return (len(line) - len(line.lstrip())) // 2 + + def _get_colour_for_level(self, indent_level: int) -> str: + """Get the colour for a given indentation level. + + Args: + indent_level: The indentation level + + Returns: + Colour name for Rich + """ + return self.INDENT_COLOURS[min(indent_level, len(self.INDENT_COLOURS) - 1)] + + def _is_list_item(self, line: str) -> bool: + """Check if a line is a YAML list item. + + Args: + line: Line to check + + Returns: + True if line is a list item (starts with '- ') + """ + return line.lstrip().startswith("- ") + + def _find_key_value_separator(self, line: str) -> int: + """Find the position of the YAML key-value separator. + + This finds colons that are followed by a space or end of line, + avoiding colons inside keys like M[UNIMOD:35]. + + Args: + line: Line to search + + Returns: + Index of the separator colon, or -1 if not found + """ + for i, char in enumerate(line): + if char == ":": + # Check if this is followed by space, end of line, or is the last char + if i + 1 >= len(line) or line[i + 1] == " ": + return i + return -1 + + def _append_key_value_pair( + self, output: Text, line: str, separator_idx: int, colour: str + ) -> None: + """Append a formatted key-value pair to the output. + + Args: + output: Text object to append to + line: Original line + separator_idx: Index of the separator colon + colour: Colour to use for the key + """ + key_part = line[:separator_idx] + value_part = line[separator_idx + 1 :] + indent = " " * (len(line) - len(line.lstrip())) + + # Add indentation + output.append(indent) + + # Add coloured key + output.append(key_part.lstrip(), style=f"bold {colour}") + output.append(":") + + # Add value without formatting (plain text) + if value_part: + output.append(value_part) + + output.append("\n") diff --git a/winnow/scripts/config_path_utils.py b/winnow/scripts/config_path_utils.py new file mode 100644 index 0000000..18b50b9 --- /dev/null +++ b/winnow/scripts/config_path_utils.py @@ -0,0 +1,190 @@ +"""Configuration path resolution utilities. 
+ +This module provides robust path resolution for config directories that works +in both development (cloned repo) and package (installed) modes. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional, List +import logging +import shutil +import tempfile +import atexit + +logger = logging.getLogger(__name__) + +# Track temporary directories for cleanup +_temp_dirs: List[Path] = [] + + +def _cleanup_temp_dirs() -> None: + """Clean up temporary directories on exit.""" + for temp_dir in _temp_dirs: + if temp_dir.exists(): + shutil.rmtree(temp_dir, ignore_errors=True) + + +atexit.register(_cleanup_temp_dirs) + + +def get_config_dir() -> Path: + """Get the primary config directory (package or dev mode). + + Returns: + Path to the config directory. In package mode, returns the package + config directory. In dev mode, returns the repo root config directory. + + Raises: + FileNotFoundError: If config directory cannot be found in either mode. + """ + # Try package mode first (when installed) + try: + from importlib.resources import files + + config_path = files("winnow").joinpath("configs") + if config_path.is_dir(): + return Path(str(config_path)) + except (ModuleNotFoundError, TypeError, AttributeError): + pass + + # Fallback to dev mode (cloned repo) + # This file is in winnow/scripts/, so go up to repo root + script_dir = Path(__file__).parent + repo_root = script_dir.parent.parent + dev_configs = repo_root / "winnow" / "configs" + + if dev_configs.exists() and dev_configs.is_dir(): + return dev_configs + + # If neither works, try alternative dev location (configs at repo root) + alt_dev_configs = repo_root / "configs" + if alt_dev_configs.exists() and alt_dev_configs.is_dir(): + return alt_dev_configs + + raise FileNotFoundError( + f"Could not locate configs directory. Tried:\n" + f" - Package configs: winnow.configs\n" + f" - Dev configs: {dev_configs}\n" + f" - Alt dev configs: {alt_dev_configs}" + ) + + +def get_config_search_path(custom_config_dir: Optional[str] = None) -> List[Path]: + """Get ordered list of config directories for Hydra search path. + + The search path is ordered by priority (first directory has highest priority): + 1. Custom config directory (if provided) + 2. Package configs (when installed) + 3. Development configs (when running from cloned repo) + + Args: + custom_config_dir: Optional path to custom config directory. + If provided, this takes highest priority. + + Returns: + List of config directory paths in priority order (highest first). + All paths are absolute. + + Raises: + FileNotFoundError: If custom_config_dir is provided but doesn't exist. + ValueError: If custom_config_dir is provided but is not a directory. + """ + search_path: List[Path] = [] + + # 1. Custom config directory (highest priority) + if custom_config_dir: + custom_path = Path(custom_config_dir).resolve() + if not custom_path.exists(): + raise FileNotFoundError( + f"Custom config directory does not exist: {custom_config_dir}" + ) + if not custom_path.is_dir(): + raise ValueError( + f"Custom config path is not a directory: {custom_config_dir}" + ) + search_path.append(custom_path) + logger.info(f"Using custom config directory: {custom_path}") + + # 2. 
Package configs (fallback for files not in custom dir) + try: + package_config_dir = get_config_dir() + # Only add if it's different from custom dir (avoid duplicates) + if not search_path or package_config_dir.resolve() != search_path[0].resolve(): + search_path.append(package_config_dir.resolve()) + logger.debug(f"Added package config directory: {package_config_dir}") + except FileNotFoundError: + logger.warning("Package config directory not found, skipping") + + return search_path + + +def _merge_config_dirs(custom_dir: Path, package_dir: Path) -> Path: + """Create a merged config directory with custom configs overriding package configs. + + Creates a temporary directory containing: + - All files from custom_dir (highest priority) + - Files from package_dir that don't exist in custom_dir (fallback) + + This allows partial configs to work with Hydra's single-directory search. + + Args: + custom_dir: Custom config directory (highest priority). + package_dir: Package config directory (fallback). + + Returns: + Path to temporary merged config directory. + """ + temp_dir = Path(tempfile.mkdtemp(prefix="winnow_configs_")) + _temp_dirs.append(temp_dir) + + # First, copy all package configs (this provides fallback for missing files) + if package_dir.exists(): + for item in package_dir.rglob("*"): + if item.is_file(): + rel_path = item.relative_to(package_dir) + dest_path = temp_dir / rel_path + dest_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(item, dest_path) + + # Then, copy/override with custom configs (this takes precedence) + if custom_dir.exists(): + for item in custom_dir.rglob("*"): + if item.is_file(): + rel_path = item.relative_to(custom_dir) + dest_path = temp_dir / rel_path + dest_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(item, dest_path) + logger.debug(f"Merged custom config: {rel_path}") + + return temp_dir + + +def get_primary_config_dir(custom_config_dir: Optional[str] = None) -> Path: + """Get the primary config directory to use with Hydra. + + If custom_config_dir is provided, creates a merged directory containing + both custom and package configs (custom takes precedence). This allows + partial configs to work - users only need to include files they want to override. + + Otherwise returns package/dev config directory. + + Args: + custom_config_dir: Optional path to custom config directory. + + Returns: + Path to primary config directory (absolute). May be a temporary directory + if custom_config_dir is provided. 
+ """ + if custom_config_dir: + custom_path = Path(custom_config_dir).resolve() + package_path = get_config_dir().resolve() + # Merge custom and package configs so partial configs work + merged_dir = _merge_config_dirs(custom_path, package_path) + logger.info( + f"Using merged config directory (custom: {custom_path}, " + f"package: {package_path}) -> {merged_dir}" + ) + return merged_dir + return get_config_dir().resolve() diff --git a/winnow/scripts/main.py b/winnow/scripts/main.py index 5ba18ed..ec57e01 100644 --- a/winnow/scripts/main.py +++ b/winnow/scripts/main.py @@ -1,90 +1,26 @@ -# -- Import -from winnow.calibration.calibration_features import ( - PrositFeatures, - MassErrorFeature, - RetentionTimeFeature, - ChimericFeatures, - BeamFeatures, -) -from winnow.calibration.calibrator import ProbabilityCalibrator -from winnow.datasets.calibration_dataset import CalibrationDataset -from winnow.datasets.data_loaders import ( - InstaNovoDatasetLoader, - MZTabDatasetLoader, - PointNovoDatasetLoader, - WinnowDatasetLoader, -) -from winnow.fdr.nonparametric import NonParametricFDRControl -from winnow.fdr.database_grounded import DatabaseGroundedFDRControl -from winnow.constants import RESIDUE_MASSES +"""CLI entry point for winnow. + +Note: This module uses lazy imports to minimise CLI startup time. +Heavy dependencies (PyTorch, InstaNovo, etc.) are imported only when +needed, significantly reducing --help and config command times. +""" -from dataclasses import dataclass -from enum import Enum +from __future__ import annotations + +from typing import Union, Tuple, Optional, List, TYPE_CHECKING, Annotated import typer -from typing_extensions import Annotated -from typing import Union, Optional import logging from rich.logging import RichHandler from pathlib import Path -import yaml -import pandas as pd - - -# --- Configuration --- -SEED = 42 -MZ_TOLERANCE = 0.02 -HIDDEN_DIM = 10 -TRAIN_FRACTION = 0.1 - - -class DataSource(Enum): - """Source of a dataset to be used for calibration.""" - - winnow = "winnow" - instanovo = "instanovo" - pointnovo = "pointnovo" - mztab = "mztab" - - -@dataclass -class WinnowDatasetConfig: - """Config for calibration datasets saved through `winnow`.""" - - data_dir: Path - - -@dataclass -class InstaNovoDatasetConfig: - """Config for calibration datasets generated by InstaNovo.""" - - beam_predictions_path: Path - spectrum_path: Path - - -@dataclass -class MZTabDatasetConfig: - """Config for calibration datasets saved in MZTab format.""" - - spectrum_path: Path - predictions_path: Path - - -@dataclass -class PointNovoDatasetConfig: - """Config for calibration datasets generated by PointNovo.""" - - mgf_path: Path - predictions_path: Path - -class FDRMethod(Enum): - """FDR estimation method.""" +# Lazy imports for heavy dependencies - only imported when actually needed +if TYPE_CHECKING: + import pandas as pd + from winnow.datasets.calibration_dataset import CalibrationDataset + from winnow.fdr.nonparametric import NonParametricFDRControl + from winnow.fdr.database_grounded import DatabaseGroundedFDRControl - database = "database-ground" - winnow = "winnow" - - -# --- Logging Setup --- +# Logging setup logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # Prevent duplicate messages by disabling propagation and using only RichHandler @@ -92,68 +28,33 @@ class FDRMethod(Enum): if not logger.handlers: logger.addHandler(RichHandler()) + +# Typer CLI setup app = typer.Typer( name="winnow", - help=""" - Confidence calibration and FDR estimation for de novo 
peptide sequencing. - """, + help="""Confidence calibration and FDR estimation for de novo peptide sequencing.""", + rich_markup_mode="rich", ) +# Config command group +config_app = typer.Typer( + name="config", + help="Configuration utilities for inspecting resolved settings.", + rich_markup_mode="rich", +) +app.add_typer(config_app) -def load_dataset( - data_source: DataSource, dataset_config_path: Path -) -> CalibrationDataset: - """Load PSM dataset into a `CalibrationDataset` object. - - Args: - data_source (DataSource): The source of the dataset - dataset_config_path (Path): Path to a `.yaml` file containing arguments - for the load method for the data source. - Raises: - TypeError: If `data_source` is not one of the supported data sources +def print_config(cfg) -> None: + """Print configuration with hierarchical colour-coding based on nesting depth. - Returns: - CalibrationDataset: A calibration dataset + Args: + cfg: OmegaConf configuration object to print """ - logger.info(f"Loading dataset from {data_source}.") - with open(dataset_config_path) as dataset_config_file: - if data_source is DataSource.winnow: - winnow_dataset_config = WinnowDatasetConfig( - **yaml.safe_load(dataset_config_file) - ) - dataset = WinnowDatasetLoader().load( - data_path=Path(winnow_dataset_config.data_dir) - ) - elif data_source is DataSource.instanovo: - instanovo_dataset_config = InstaNovoDatasetConfig( - **yaml.safe_load(dataset_config_file) - ) - dataset = InstaNovoDatasetLoader().load( - data_path=Path(instanovo_dataset_config.spectrum_path), - predictions_path=Path(instanovo_dataset_config.beam_predictions_path), - ) - elif data_source is DataSource.mztab: - mztab_dataset_config = MZTabDatasetConfig( - **yaml.safe_load(dataset_config_file) - ) - dataset = MZTabDatasetLoader().load( - data_path=Path(mztab_dataset_config.spectrum_path), - predictions_path=Path(mztab_dataset_config.predictions_path), - ) - elif data_source is DataSource.pointnovo: - pointnovo_dataset_config = PointNovoDatasetConfig( - **yaml.safe_load(dataset_config_file) - ) - dataset = PointNovoDatasetLoader().load( - data_path=Path(pointnovo_dataset_config.mgf_path), - predictions_path=Path(pointnovo_dataset_config.predictions_path), - ) - else: - raise TypeError( - f"Data source was {data_source}. Only 'instanovo', 'mztab' and 'pointnovo' are supported." - ) - return dataset + from winnow.scripts.config_formatter import ConfigFormatter + + formatter = ConfigFormatter() + formatter.print_config(cfg) def filter_dataset(dataset: CalibrationDataset) -> CalibrationDataset: @@ -165,7 +66,6 @@ def filter_dataset(dataset: CalibrationDataset) -> CalibrationDataset: Returns: CalibrationDataset: The filtered dataset """ - logger.info("Filtering dataset.") filtered_dataset = ( dataset.filter_entries( # Filter out non-list predictions @@ -177,47 +77,6 @@ def filter_dataset(dataset: CalibrationDataset) -> CalibrationDataset: return filtered_dataset -def initialise_calibrator( - learn_prosit_missing: bool = True, - learn_chimeric_missing: bool = True, - learn_retention_missing: bool = True, -) -> ProbabilityCalibrator: - """Set up the probability calibrator with features. - - Args: - learn_prosit_missing: Whether to learn from missing Prosit features. If False, - errors will be raised when invalid spectra are encountered. - learn_chimeric_missing: Whether to learn from missing chimeric features. If False, - errors will be raised when invalid spectra are encountered. - learn_retention_missing: Whether to learn from missing retention time features. 
If False,
-            errors will be raised when invalid spectra are encountered.
-
-    Returns:
-        ProbabilityCalibrator: Configured calibrator with specified features.
-    """
-    calibrator = ProbabilityCalibrator(SEED)
-    calibrator.add_feature(MassErrorFeature(residue_masses=RESIDUE_MASSES))
-    calibrator.add_feature(
-        PrositFeatures(
-            mz_tolerance=MZ_TOLERANCE, learn_from_missing=learn_prosit_missing
-        )
-    )
-    calibrator.add_feature(
-        RetentionTimeFeature(
-            hidden_dim=HIDDEN_DIM,
-            train_fraction=TRAIN_FRACTION,
-            learn_from_missing=learn_retention_missing,
-        )
-    )
-    calibrator.add_feature(
-        ChimericFeatures(
-            mz_tolerance=MZ_TOLERANCE, learn_from_missing=learn_chimeric_missing
-        )
-    )
-    calibrator.add_feature(BeamFeatures())
-    return calibrator
-
-
 def apply_fdr_control(
     fdr_control: Union[NonParametricFDRControl, DatabaseGroundedFDRControl],
     dataset: CalibrationDataset,
@@ -225,14 +84,14 @@
     confidence_column: str,
 ) -> pd.DataFrame:
     """Apply FDR control to a dataset."""
+    from winnow.fdr.nonparametric import NonParametricFDRControl
+
     if isinstance(fdr_control, NonParametricFDRControl):
         fdr_control.fit(dataset=dataset.metadata[confidence_column])
         dataset.metadata = fdr_control.add_psm_pep(dataset.metadata, confidence_column)
     else:
-        fdr_control.fit(
-            dataset=dataset.metadata[confidence_column],
-            residue_masses=RESIDUE_MASSES,
-        )
+        fdr_control.fit(dataset=dataset.metadata)
+
+        dataset.metadata = fdr_control.add_psm_fdr(dataset.metadata, confidence_column)
     dataset.metadata = fdr_control.add_psm_q_value(dataset.metadata, confidence_column)
     confidence_cutoff = fdr_control.get_confidence_cutoff(threshold=fdr_threshold)
@@ -249,197 +108,352 @@ def check_if_labelled(dataset: CalibrationDataset) -> None:
     )
 
 
-@app.command(name="train", help="Fit a calibration model.")
-def train(
-    data_source: Annotated[
-        DataSource, typer.Option(help="The type of PSM dataset to be calibrated.")
-    ],
-    dataset_config_path: Annotated[
-        Path,
-        typer.Option(
-            help="The path to the config with the specification of the calibration dataset."
-        ),
-    ],
-    model_output_dir: Annotated[
-        Path,
-        typer.Option(
-            help="The path to the directory where the fitted model checkpoint will be saved."
-        ),
-    ],
-    dataset_output_path: Annotated[
-        Path, typer.Option(help="The path to write the output to.")
-    ],
-    learn_prosit_missing: Annotated[
-        bool,
-        typer.Option(
-            help="Whether to learn from missing Prosit features. If False, training will fail if any spectra have invalid Prosit predictions."
-        ),
-    ] = True,
-    learn_chimeric_missing: Annotated[
-        bool,
-        typer.Option(
-            help="Whether to learn from missing chimeric features. If False, training will fail if any spectra have invalid predictions for chimeric feature computation."
-        ),
-    ] = True,
-    learn_retention_missing: Annotated[
-        bool,
-        typer.Option(
-            help="Whether to learn from missing retention time features. If False, training will fail if any spectra have invalid retention time predictions."
-        ),
-    ] = True,
-):
-    """Fit the calibration model.
+def separate_metadata_and_predictions(
+    dataset_metadata: pd.DataFrame,
+    fdr_control: Union[NonParametricFDRControl, DatabaseGroundedFDRControl],
+    confidence_column: str,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Separate out metadata from prediction and FDR metrics.
 
     Args:
-        data_source (Annotated[ DataSource, typer.Option, optional): The type of PSM dataset to be calibrated.
-        dataset_config_path (Annotated[ Path, typer.Option, optional): The path to the config with the specification of the calibration dataset.
-        model_output_dir (Annotated[Path, typer.Option, optional]): The path to the directory where the fitted model checkpoint will be saved.
-        dataset_output_path (Annotated[Path, typer.Option, optional): The path to write the output to.
+        dataset_metadata: The dataframe containing metadata, computed features, predictions, and FDR metrics to be separated.
+        fdr_control: The FDR control object used (to determine which columns were added).
+        confidence_column: The name of the confidence column.
+
+    Returns:
+        Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the metadata dataframe and the prediction and FDR metrics dataframe.
     """
-    # -- Load dataset
-    logger.info("Loading datasets.")
-    annotated_dataset = load_dataset(
-        data_source=data_source,
-        dataset_config_path=dataset_config_path,
-    )
+    from winnow.fdr.nonparametric import NonParametricFDRControl
+
+    # Separate out metadata from prediction and FDR metrics
+    preds_and_fdr_metrics_cols = [
+        confidence_column,
+        "prediction",
+        "psm_fdr",
+        "psm_q_value",
+    ]
+    if "sequence" in dataset_metadata.columns:
+        preds_and_fdr_metrics_cols.append("sequence")
+    # NonParametricFDRControl adds psm_pep column
+    if isinstance(fdr_control, NonParametricFDRControl):
+        preds_and_fdr_metrics_cols.append("psm_pep")
+    dataset_preds_and_fdr_metrics = dataset_metadata[
+        preds_and_fdr_metrics_cols + ["spectrum_id"]
+    ]
+    dataset_metadata = dataset_metadata.drop(columns=preds_and_fdr_metrics_cols)
+    return dataset_metadata, dataset_preds_and_fdr_metrics
+
+
+def train_entry_point(
+    overrides: Optional[List[str]] = None,
+    execute: bool = True,
+    config_dir: Optional[str] = None,
+) -> None:
+    """The main training pipeline entry point.
+
+    Args:
+        overrides: Optional list of config overrides.
+        execute: If False, only print the configuration and return without executing the pipeline.
+        config_dir: Optional path to custom config directory. If provided, configs in this
+            directory take precedence over package configs. Files not in custom dir will use package defaults (file-by-file resolution).
+ """ + from hydra import initialize_config_dir, compose + from hydra.utils import instantiate + from winnow.scripts.config_path_utils import get_primary_config_dir + + # Get primary config directory (custom if provided, otherwise package/dev) + primary_config_dir = get_primary_config_dir(config_dir) + # Initialise Hydra with primary config directory + with initialize_config_dir( + config_dir=str(primary_config_dir), + version_base="1.3", + job_name="winnow_train", + ): + cfg = compose(config_name="train", overrides=overrides) + + if not execute: + print_config(cfg) + return + + from winnow.calibration.calibrator import ProbabilityCalibrator + + logger.info("Starting training pipeline.") + logger.info(f"Training configuration: {cfg}") + + # Load dataset - Hydra creates the DatasetLoader object + logger.info("Loading dataset.") + data_loader = instantiate(cfg.data_loader) + + # Extract dataset loading parameters and convert to dict for flexible kwargs + dataset_params = dict(cfg.dataset) + # Rename config keys to match the Protocol interface + dataset_params["data_path"] = dataset_params.pop("spectrum_path_or_directory") + dataset_params["predictions_path"] = dataset_params.pop("predictions_path", None) + + annotated_dataset = data_loader.load(**dataset_params) + + logger.info("Filtering dataset.") annotated_dataset = filter_dataset(annotated_dataset) - # Train - logger.info("Training calibrator.") - calibrator = initialise_calibrator( - learn_prosit_missing=learn_prosit_missing, - learn_chimeric_missing=learn_chimeric_missing, - learn_retention_missing=learn_retention_missing, - ) + # Instantiate the calibrator from the config + logger.info("Instantiating calibrator from config.") + calibrator = instantiate(cfg.calibrator) + + # Fit the calibrator to the dataset + logger.info("Fitting calibrator to dataset.") calibrator.fit(annotated_dataset) - # -- Write model checkpoints - logger.info(f"Saving model to {model_output_dir}") - ProbabilityCalibrator.save(calibrator, model_output_dir) + # Save the model + logger.info(f"Saving model to {cfg.model_output_dir}") + ProbabilityCalibrator.save(calibrator, cfg.model_output_dir) - # -- Write output - logger.info("Writing output.") - annotated_dataset.to_csv(dataset_output_path) - logger.info(f"Training dataset results saved: {dataset_output_path}") + # Save the training dataset results + logger.info(f"Saving training dataset results to {cfg.dataset_output_path}") + annotated_dataset.to_csv(cfg.dataset_output_path) + logger.info("Training pipeline completed successfully.") -@app.command( - name="predict", - help="Calibrate scores and optionally filter results to a target FDR.", -) -def predict( - data_source: Annotated[ - DataSource, typer.Option(help="The type of PSM dataset to be calibrated.") - ], - dataset_config_path: Annotated[ - Path, - typer.Option( - help="The path to the config with the specification of the calibration dataset." - ), - ], - method: Annotated[ - FDRMethod, typer.Option(help="Method to use for FDR estimation.") - ], - fdr_threshold: Annotated[ - float, - typer.Option( - help="The target FDR threshold (e.g. 0.01 for 1%, 0.05 for 5% etc.)" - ), - ], - confidence_column: Annotated[ - str, typer.Option(help="Name of the column with confidence scores.") - ], - output_folder: Annotated[ - Path, typer.Option(help="The folder path to write the outputs to.") - ], - huggingface_model_name: Annotated[ - str, - typer.Option( - help="HuggingFace model identifier. 
If neither this nor `--local-model-folder` are provided, loads default model from HuggingFace.", - ), - ] = "InstaDeepAI/winnow-general-model", - local_model_folder: Annotated[ - Optional[Path], - typer.Option( - help="Path to local calibrator directory. If neither this nor `--huggingface-model-name` are provided, loads default pretrained model from HuggingFace.", - ), - ] = None, -): - """Calibrate model scores, estimate FDR and filter for a threshold. + +def predict_entry_point( + overrides: Optional[List[str]] = None, + execute: bool = True, + config_dir: Optional[str] = None, +) -> None: + """The main prediction pipeline entry point. Args: - data_source (Annotated[ DataSource, typer.Option, optional): The type of PSM dataset to be calibrated. - dataset_config_path (Annotated[ Path, typer.Option, optional): The path to the config with the specification of the dataset. - method (Annotated[ FDRMethod, typer.Option, optional): Method to use for FDR estimation. - fdr_threshold (Annotated[ float, typer.Option, optional): The target FDR threshold (e.g. 0.01 for 1%, 0.05 for 5% etc.). - confidence_column (Annotated[ str, typer.Option, optional): Name of the column with confidence scores. - output_folder (Annotated[ Path, typer.Option, optional): The folder path to write the outputs to: `metadata.csv` and `preds_and_fdr_metrics.csv`. - huggingface_model_name (Annotated[str, typer.Option, optional): HuggingFace model identifier. - local_model_folder (Annotated[Path, typer.Option, optional): Path to local calibrator directory (e.g., Path("./my-model-directory")). - - Note that either `local_model_folder` or `huggingface-model-name` may be overwritten, but not both. - If neither `local_model_folder` nor `huggingface-model-name` are provided, the general model from HuggingFace will be loaded by default (i.e., `InstaDeepAI/winnow-general-model`). + overrides: Optional list of config overrides. + execute: If False, only print the configuration and return without executing the pipeline. + config_dir: Optional path to custom config directory. If provided, configs in this + directory take precedence over package configs. Files not in custom dir will use + package defaults (file-by-file resolution). 
""" - # -- Load dataset - logger.info("Loading datasets.") - dataset = load_dataset( - data_source=data_source, - dataset_config_path=dataset_config_path, - ) + from hydra import initialize_config_dir, compose + from hydra.utils import instantiate + from winnow.scripts.config_path_utils import get_primary_config_dir + + # Get primary config directory (custom if provided, otherwise package/dev) + primary_config_dir = get_primary_config_dir(config_dir) + + # Initialize Hydra with primary config directory + with initialize_config_dir( + config_dir=str(primary_config_dir), + version_base="1.3", + job_name="winnow_predict", + ): + cfg = compose(config_name="predict", overrides=overrides) + + if not execute: + print_config(cfg) + return + + from winnow.calibration.calibrator import ProbabilityCalibrator + from winnow.fdr.database_grounded import DatabaseGroundedFDRControl + + logger.info("Starting prediction pipeline.") + logger.info(f"Prediction configuration: {cfg}") + + # Load dataset - Hydra creates the DatasetLoader object + logger.info("Loading dataset.") + data_loader = instantiate(cfg.data_loader) + + # Extract dataset loading parameters and convert to dict for flexible kwargs + dataset_params = dict(cfg.dataset) + # Rename config keys to match the Protocol interface + dataset_params["data_path"] = dataset_params.pop("spectrum_path_or_directory") + dataset_params["predictions_path"] = dataset_params.pop("predictions_path", None) + + dataset = data_loader.load(**dataset_params) + logger.info("Filtering dataset.") dataset = filter_dataset(dataset) - # Predict - # If local_model_folder is an empty string, load the HuggingFace model - if local_model_folder is None: - logger.info(f"Loading HuggingFace model: {huggingface_model_name}") - calibrator = ProbabilityCalibrator.load(huggingface_model_name) - # Otherwise, load the model from the local folder path - else: - logger.info(f"Loading local model from: {local_model_folder}") - calibrator = ProbabilityCalibrator.load(local_model_folder) + # Load trained calibrator + logger.info("Loading trained calibrator.") + calibrator = ProbabilityCalibrator.load( + pretrained_model_name_or_path=cfg.calibrator.pretrained_model_name_or_path, + cache_dir=cfg.calibrator.cache_dir, + ) + # Calibrate scores logger.info("Calibrating scores.") calibrator.predict(dataset) - if method is FDRMethod.winnow: - logger.info("Applying FDR control.") - dataset_metadata = apply_fdr_control( - NonParametricFDRControl(), dataset, fdr_threshold, confidence_column - ) - elif method is FDRMethod.database: - logger.info("Applying FDR control.") + # Instantiate FDR control from config - Hydra handles which FDR method to use + logger.info("Instantiating FDR control from config.") + fdr_control = instantiate(cfg.fdr_method) + + # Check if dataset is labelled for database-grounded FDR + if isinstance(fdr_control, DatabaseGroundedFDRControl): check_if_labelled(dataset) - dataset_metadata = apply_fdr_control( - DatabaseGroundedFDRControl(confidence_feature=confidence_column), - dataset, - fdr_threshold, - confidence_column, - ) - # -- Write output - logger.info("Writing output.") - # Separate out metadata from prediction and FDR metrics - preds_and_fdr_metrics_cols = [ - confidence_column, - "prediction", - "psm_fdr", - "psm_q_value", - ] - if "sequence" in dataset_metadata.columns: - preds_and_fdr_metrics_cols.append("sequence") - if method is FDRMethod.winnow: - preds_and_fdr_metrics_cols.append("psm_pep") - dataset_preds_and_fdr_metrics = dataset_metadata[ - 
-        preds_and_fdr_metrics_cols + ["spectrum_id"]
-    ]
-    dataset_metadata = dataset_metadata.drop(columns=preds_and_fdr_metrics_cols)
-    # Write outputs
-    output_folder.mkdir(parents=True)
-    dataset_metadata.to_csv(output_folder / "metadata.csv")
-    dataset_preds_and_fdr_metrics.to_csv(output_folder / "preds_and_fdr_metrics.csv")
-    logger.info(f"Outputs saved: {output_folder}")
+    # Apply FDR control
+    logger.info(f"Applying {fdr_control.__class__.__name__} FDR control.")
+    dataset_metadata = apply_fdr_control(
+        fdr_control,
+        dataset,
+        cfg.fdr_control.fdr_threshold,
+        cfg.fdr_control.confidence_column,
+    )
+
+    # Write output
+    logger.info(f"Writing output to {cfg.output_folder}")
+    dataset_metadata, dataset_preds_and_fdr_metrics = separate_metadata_and_predictions(
+        dataset_metadata, fdr_control, cfg.fdr_control.confidence_column
+    )
+    output_folder = Path(cfg.output_folder)
+    output_folder.mkdir(parents=True, exist_ok=True)
+    dataset_metadata.to_csv(output_folder.joinpath("metadata.csv"))
+    dataset_preds_and_fdr_metrics.to_csv(
+        output_folder.joinpath("preds_and_fdr_metrics.csv")
+    )
+
+    logger.info("Prediction pipeline completed successfully.")
+
+
+@app.command(
+    name="train",
+    help=(
+        "Train a probability calibration model on annotated peptide sequencing data.\n\n"
+        "This command loads your dataset, fits the calibration features, and saves the trained model.\n\n"
+        "[bold cyan]Quick start:[/bold cyan]\n"
+        " [dim]winnow train[/dim] # Uses default config from config/train.yaml\n\n"
+        "[bold cyan]Override parameters:[/bold cyan]\n"
+        " [dim]winnow train data_loader=mztab[/dim] # Use mzTab format instead of InstaNovo\n"
+        " [dim]winnow train model_output_dir=models/my_model[/dim] # Custom output location\n"
+        " [dim]winnow train calibrator.seed=42[/dim] # Set random seed\n\n"
+        "[bold cyan]Custom config directory:[/bold cyan]\n"
+        " [dim]winnow train --config-dir /path/to/configs[/dim] # Use custom config directory\n"
+        " [dim]winnow train -cp ./my_configs[/dim] # Short form (relative or absolute path)\n"
+        " See docs for advanced usage.\n\n"
+        "[bold cyan]Configuration files to customise:[/bold cyan]\n"
+        " • config/train.yaml - Main config (data paths, output locations)\n"
+        " • config/calibrator.yaml - Model architecture and features\n"
+        " • config/data_loader/ - Dataset format loaders\n"
+        " • config/residues.yaml - Amino acid masses and modifications"
+    ),
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def train(
+    ctx: typer.Context,
+    config_dir: Annotated[
+        Optional[str],
+        typer.Option(
+            "--config-dir",
+            "-cp",
+            help="Path to custom config directory (relative or absolute). See documentation for advanced usage.",
+        ),
+    ] = None,
+) -> None:
+    """Passes control directly to the Hydra training pipeline."""
+    # Capture extra arguments as Hydra overrides (--config-dir already parsed out by Typer)
+    overrides = ctx.args if ctx.args else None
+    train_entry_point(overrides, config_dir=config_dir)
+
+
+@app.command(
+    name="predict",
+    help=(
+        "Calibrate confidence scores and filter peptide predictions by false discovery rate (FDR).\n\n"
+        "This command loads your dataset, applies a trained calibrator to improve confidence scores, "
+        "estimates FDR using your chosen method, and outputs filtered predictions at your target FDR threshold.\n\n"
+        "[bold cyan]Quick start:[/bold cyan]\n"
+        " [dim]winnow predict[/dim] # Uses default config from config/predict.yaml\n\n"
+        "[bold cyan]Override parameters:[/bold cyan]\n"
+        " [dim]winnow predict data_loader=mztab[/dim] # Use mzTab format instead of InstaNovo\n"
+        " [dim]winnow predict fdr_method=database_grounded[/dim] # Use database-grounded FDR\n"
+        " [dim]winnow predict fdr_control.fdr_threshold=0.01[/dim] # Target 1% FDR instead of 5%\n"
+        " [dim]winnow predict output_folder=results/my_run[/dim] # Custom output location\n\n"
+        "[bold cyan]Custom config directory:[/bold cyan]\n"
+        " [dim]winnow predict --config-dir /path/to/configs[/dim] # Use custom config directory\n"
+        " [dim]winnow predict -cp ./my_configs[/dim] # Short form (relative or absolute path)\n"
+        " See docs for advanced usage.\n\n"
+        "[bold cyan]Configuration files to customise:[/bold cyan]\n"
+        " • config/predict.yaml - Main config (data paths, FDR settings, output)\n"
+        " • config/fdr_method/ - FDR methods (nonparametric, database_grounded)\n"
+        " • config/data_loader/ - Dataset format loaders\n"
+        " • config/residues.yaml - Amino acid masses and modifications"
+    ),
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def predict(
+    ctx: typer.Context,
+    config_dir: Annotated[
+        Optional[str],
+        typer.Option(
+            "--config-dir",
+            "-cp",
+            help="Path to custom config directory (relative or absolute). See documentation for advanced usage.",
+        ),
+    ] = None,
+) -> None:
+    """Passes control directly to the Hydra prediction pipeline."""
+    # Capture extra arguments as Hydra overrides (--config-dir already parsed out by Typer)
+    overrides = ctx.args if ctx.args else None
+    predict_entry_point(overrides, config_dir=config_dir)
+
+
+@config_app.command(
+    name="train",
+    help=(
+        "Display the resolved training configuration without running the pipeline.\n\n"
+        "This is useful for inspecting the final configuration after all defaults "
+        "and overrides have been applied.\n\n"
+        "[bold cyan]Usage:[/bold cyan]\n"
+        " [dim]winnow config train[/dim] # Show default config\n"
+        " [dim]winnow config train data_loader=mztab[/dim] # Show config with overrides\n"
+        " [dim]winnow config train calibrator.seed=42[/dim] # Check override application\n"
+        " [dim]winnow config train --config-dir /path/to/configs[/dim] # Show config with custom directory\n"
+        " [dim]winnow config train -cp ./my_configs[/dim] # Short form (relative or absolute path)"
+    ),
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def config_train(
+    ctx: typer.Context,
+    config_dir: Annotated[
+        Optional[str],
+        typer.Option(
+            "--config-dir",
+            "-cp",
+            help="Path to custom config directory (relative or absolute). See documentation for advanced usage.",
+        ),
+    ] = None,
+) -> None:
+    """Display the resolved training configuration."""
+    overrides = ctx.args if ctx.args else None
+    train_entry_point(overrides, execute=False, config_dir=config_dir)
+
+
+@config_app.command(
+    name="predict",
+    help=(
+        "Display the resolved prediction configuration without running the pipeline.\n\n"
+        "This is useful for inspecting the final configuration after all defaults "
+        "and overrides have been applied.\n\n"
+        "[bold cyan]Usage:[/bold cyan]\n"
+        " [dim]winnow config predict[/dim] # Show default config\n"
+        " [dim]winnow config predict fdr_method=database_grounded[/dim] # Show config with overrides\n"
+        " [dim]winnow config predict fdr_control.fdr_threshold=0.01[/dim] # Check override application\n"
+        " [dim]winnow config predict --config-dir /path/to/configs[/dim] # Show config with custom directory\n"
+        " [dim]winnow config predict -cp ./my_configs[/dim] # Short form (relative or absolute path)"
+    ),
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def config_predict(
+    ctx: typer.Context,
+    config_dir: Annotated[
+        Optional[str],
+        typer.Option(
+            "--config-dir",
+            "-cp",
+            help="Path to custom config directory (relative or absolute). See documentation for advanced usage.",
+        ),
+    ] = None,
+) -> None:
+    """Display the resolved prediction configuration."""
+    overrides = ctx.args if ctx.args else None
+    predict_entry_point(overrides, execute=False, config_dir=config_dir)
 
 
 if __name__ == "__main__":
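+    # Running this module directly launches the same Typer `app` that backs the
+    # installed `winnow` console script.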