Skip to content

Commit f174183

Browse files
committed
Add experiment matrix in config and add better example files
1 parent 9e4861a commit f174183

16 files changed

Lines changed: 524 additions & 365 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ tests/.hypothesis/
2424
config/*
2525
!config/.gitkeep
2626
!config/config_default.yml
27+
!config/config_example_smri.yml
28+
!config/config_example_experiment.yml
2729

2830
# Local data artifacts (never track user datasets)
2931
data/*

README.md

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ docker run --rm \
111111
### Arguments
112112

113113
The `main.py` file accepts the following required arguments:
114-
- `--config`: Path to the configuration file. Default: `config/config_default.yml`.
115-
- `--mode`: Mode to run the project in. Options: `train`, `inference`, `validate`, `tune`. Default: `train`.
114+
- `--config`: Configuration filename located under `./config` (for example `config_default.yml` or `config.yml`).
115+
- `--mode`: Mode to run the project in. Options: `train`, `inference`, `validate`, `tune`, `experiment`.
116116

117117
The `main.py` file accepts the following optional arguments:
118118
- `--checkpoint`: Path to a checkpoint to load.
@@ -151,23 +151,49 @@ To skip the preprocessing pipeline, use the `--skip-preprocessing` flag.
151151

152152
### Configuration
153153

154-
The project configuration is defined in the `config` directory. The configuration file is in `YAML` format. The default configuration file is `config/config_default.yml`.
154+
The project configuration is defined in the `config` directory and uses `YAML`.
155155

156-
**Important:** Copy this file to create a custom configuration. The default configuration will be **overwritten** each time the application is run.
156+
- `config/config_default.yml`: generic runtime defaults intended for new projects and new datasets.
157+
- `config/config_example_smri.yml`: a non-runnable example template showing an sMRI-oriented setup with placeholder paths.
158+
- `config/config_example_experiment.yml`: an experiment-sweep template showing how to define custom `experiment` mode runs.
159+
160+
At startup, the app ensures `config/config_default.yml` exists. If it already exists, it is kept as-is (not overwritten).
161+
For your own run configs, create `config/config.yml` (or another filename in `./config`) and pass it with `--config`.
157162

158163
The configuration file contains the following sections:
159164

160165
- `dataset`: Configuration for the dataset.
161166
- `model`: Configuration for the model.
162167
- `train`: Configuration for the training process.
163-
- `inference`: Configuration for the inference process.
164168
- `validation`: Configuration for the validation process.
165169
- `meta`: Meta information for the project.
166170
- `general`: General configuration for the project.
167171
- `system`: System configuration for the project.
168172

169173
The configuration is validated against the schema before a task is run. Missing values are filled with default values from the default configuration file.
170174

175+
`experiment` mode is fully config-driven via `model.experiment_matrix`:
176+
177+
```yaml
178+
model:
179+
components:
180+
vae:
181+
covariate_embedding: no_embedding
182+
latent_dim: 32
183+
experiment_matrix:
184+
embedding_methods: [no_embedding, fair_embedding]
185+
latent_dims: [8, 16, 32]
186+
dataset_files: [dataset_a.csv, dataset_b.csv]
187+
repetitions: 3
188+
```
189+
190+
For a concrete copy-paste example similar to the previous site-based setup, see `config/config_example_experiment.yml`.
191+
That example uses explicit dataset filenames (`site_dataset_<rep>_site_<site>.rds`) with `repetitions: 1`.
192+
193+
To reproduce the old "harmonized" branch behavior, run a second experiment config where:
194+
- `model.experiment_matrix.embedding_methods: [no_embedding]`
195+
- `model.experiment_matrix.dataset_files` contains only `harmonized_site_dataset_<rep>_site_<site>.rds` files.
196+
171197
### Output Files
172198

173199
The application generates different types of artifacts during the execution of the project.

config/config_default.yml

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ general:
1313
verbose: false
1414
meta:
1515
config_version: 2
16-
description: Variational Autoencoder design experiment setup
16+
description: Multivariate normative modeling configuration
1717
name: vae_basic
1818
version: 1
1919
dataset:
@@ -23,7 +23,7 @@ dataset:
2323
covariates: []
2424
skipped_covariates: []
2525
targets: []
26-
input_data: generated_data.rds
26+
input_data: your_dataset.csv
2727
data_type: tabular
2828
image_type: grayscale
2929
internal_file_format: hdf
@@ -42,40 +42,11 @@ dataset:
4242
- name: EncodingTransform
4343
type: preprocessing
4444
params:
45-
default: z-score
46-
one_hot_encoding:
47-
- site
48-
- sex
49-
z-score:
50-
- age
45+
default: min-max
46+
one_hot_encoding: []
47+
z-score: []
5148
min-max: []
5249
raw: []
53-
- name: SiteFilterTransform
54-
type: preprocessing
55-
params:
56-
selected_site: -1
57-
col_name: site
58-
- name: WaveFilterTransform
59-
type: preprocessing
60-
params:
61-
selected_wave: -1
62-
col_name: wave
63-
- name: AgeFilterTransform
64-
type: preprocessing
65-
params:
66-
age_lowerbound: 0.0
67-
age_upperbound: 100.0
68-
col_name: age
69-
- name: SexFilterTransform
70-
type: preprocessing
71-
params:
72-
sex: -1
73-
col_name: sex
74-
- name: SampleLimitTransform
75-
type: preprocessing
76-
params:
77-
max_samples: 1000
78-
shuffle: true
7950
data_analysis:
8051
features:
8152
reconstruction_mse: true
@@ -388,6 +359,11 @@ model:
388359
decoder: mlp
389360
latent_dim: 32
390361
covariate_embedding: no_embedding
362+
experiment_matrix:
363+
embedding_methods: []
364+
latent_dims: []
365+
dataset_files: []
366+
repetitions: 1
391367
hidden_layers:
392368
- 1024
393369
- 512
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Example experiment configuration (customizable template).
2+
# This file is intended as a reference and uses placeholder dataset names.
3+
4+
meta:
5+
name: experiment_site_sweep_template
6+
description: Example of a custom experiment sweep over embeddings, latent dims, and site-specific datasets
7+
8+
dataset:
9+
input_data: site_dataset_0_site_0.rds
10+
covariates:
11+
- age
12+
- sex
13+
- site
14+
enable_transforms: true
15+
transforms:
16+
- name: DataCleaningTransform
17+
type: preprocessing
18+
params:
19+
drop_na: true
20+
remove_duplicates: true
21+
- name: EncodingTransform
22+
type: preprocessing
23+
params:
24+
default: z-score
25+
one_hot_encoding:
26+
- sex
27+
- site
28+
z-score:
29+
- age
30+
min-max: []
31+
raw: []
32+
33+
model:
34+
architecture: vae
35+
components:
36+
vae:
37+
encoder: mlp
38+
decoder: mlp
39+
covariate_embedding: no_embedding
40+
latent_dim: 8
41+
# Custom experiment sweep keys used by `--mode experiment`.
42+
# This setup is similar to the previous thesis sweep, but fully config-driven.
43+
experiment_matrix:
44+
embedding_methods:
45+
- no_embedding
46+
- encoderdecoder_embedding
47+
- fair_embedding
48+
latent_dims:
49+
- 1
50+
- 2
51+
- 3
52+
- 4
53+
- 5
54+
- 8
55+
- 12
56+
- 16
57+
# Previous setup equivalent: explicit per-file list (site x repetition),
58+
# and `repetitions: 1`.
59+
dataset_files:
60+
- site_dataset_0_site_0.rds
61+
- site_dataset_0_site_1.rds
62+
- site_dataset_0_site_2.rds
63+
- site_dataset_1_site_0.rds
64+
- site_dataset_1_site_1.rds
65+
- site_dataset_1_site_2.rds
66+
- site_dataset_2_site_0.rds
67+
- site_dataset_2_site_1.rds
68+
- site_dataset_2_site_2.rds
69+
repetitions: 1
70+
71+
train:
72+
loss_function: mse_vae
73+
batch_size: 64
74+
epochs: 100
75+
76+
validation:
77+
model: your_model_best.safetensors

config/config_example_smri.yml

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Example-only sMRI configuration template.
2+
# This file is intentionally NOT runnable as-is because it references placeholder
3+
# data/model artifacts that are not part of this repository.
4+
5+
meta:
6+
name: smri_example_template
7+
description: Example-only sMRI template (non-runnable placeholder values)
8+
9+
dataset:
10+
input_data: smri_example_dataset_not_in_repo.rds
11+
unique_identifier_column: subject_id
12+
row_data_leakage_columns:
13+
- subject_id
14+
skipped_columns:
15+
- subject_id
16+
- diagnosis
17+
covariates:
18+
- age
19+
- sex
20+
- site
21+
- wave
22+
enable_transforms: true
23+
transforms:
24+
- name: DataCleaningTransform
25+
type: preprocessing
26+
params:
27+
drop_na: true
28+
remove_duplicates: true
29+
- name: EncodingTransform
30+
type: preprocessing
31+
params:
32+
default: z-score
33+
one_hot_encoding:
34+
- sex
35+
- site
36+
- wave
37+
z-score:
38+
- age
39+
min-max: []
40+
raw: []
41+
- name: SiteFilterTransform
42+
type: preprocessing
43+
params:
44+
selected_site: -1
45+
col_name: site
46+
- name: WaveFilterTransform
47+
type: preprocessing
48+
params:
49+
selected_wave: -1
50+
col_name: wave
51+
- name: AgeFilterTransform
52+
type: preprocessing
53+
params:
54+
age_lowerbound: 6.0
55+
age_upperbound: 25.0
56+
col_name: age
57+
- name: SexFilterTransform
58+
type: preprocessing
59+
params:
60+
sex: -1
61+
col_name: sex
62+
63+
model:
64+
components:
65+
vae:
66+
encoder: mlp
67+
decoder: mlp
68+
latent_dim: 16
69+
covariate_embedding: fair_embedding
70+
71+
train:
72+
loss_function: mse_vae
73+
batch_size: 64
74+
epochs: 100
75+
76+
validation:
77+
model: smri_example_model_best.safetensors

src/config/config_schema.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class DatasetConfig(BaseModel):
102102
covariates: list[str] = []
103103
skipped_covariates: list[str] = []
104104
targets: list[str] = []
105-
input_data: str = "generated_data.rds"
105+
input_data: str = "your_dataset.csv"
106106
data_type: str = "tabular"
107107
image_type: str = "grayscale"
108108
internal_file_format: str = "hdf"
@@ -122,38 +122,13 @@ class DatasetConfig(BaseModel):
122122
name="EncodingTransform",
123123
type="preprocessing",
124124
params={
125-
"default": "z-score",
126-
"one_hot_encoding": ["site", "sex"],
127-
"z-score": ["age"],
125+
"default": "min-max",
126+
"one_hot_encoding": [],
127+
"z-score": [],
128128
"min-max": [],
129129
"raw": [],
130130
},
131131
),
132-
TransformConfig(
133-
name="SiteFilterTransform",
134-
type="preprocessing",
135-
params={"selected_site": -1, "col_name": "site"},
136-
),
137-
TransformConfig(
138-
name="WaveFilterTransform",
139-
type="preprocessing",
140-
params={"selected_wave": -1, "col_name": "wave"},
141-
),
142-
TransformConfig(
143-
name="AgeFilterTransform",
144-
type="preprocessing",
145-
params={"age_lowerbound": 0.0, "age_upperbound": 100.0, "col_name": "age"},
146-
),
147-
TransformConfig(
148-
name="SexFilterTransform",
149-
type="preprocessing",
150-
params={"sex": -1, "col_name": "sex"},
151-
),
152-
TransformConfig(
153-
name="SampleLimitTransform",
154-
type="preprocessing",
155-
params={"max_samples": 1000, "shuffle": True},
156-
),
157132
]
158133

159134
@model_validator(mode="after")
@@ -178,11 +153,27 @@ class MetaConfig(BaseModel):
178153
"""Metadata configuration."""
179154

180155
config_version: int = 2
181-
description: str = "Variational Autoencoder design experiment setup"
156+
description: str = "Multivariate normative modeling configuration"
182157
name: str = "vae_basic"
183158
version: int = 1
184159

185160

161+
class ExperimentMatrixConfig(BaseModel):
162+
"""Experiment sweep configuration used by experiment mode."""
163+
164+
embedding_methods: list[str] = []
165+
latent_dims: list[int] = []
166+
dataset_files: list[str] = []
167+
repetitions: int = Field(default=1, ge=1)
168+
169+
@model_validator(mode="after")
170+
def validate_latent_dims(self) -> "ExperimentMatrixConfig":
171+
"""Ensure configured latent dimensions are strictly positive."""
172+
if any(latent_dim <= 0 for latent_dim in self.latent_dims):
173+
raise ValueError("All experiment_matrix.latent_dims must be > 0.")
174+
return self
175+
176+
186177
class ModelConfig(BaseModel):
187178
"""Model configuration."""
188179

@@ -222,6 +213,9 @@ class ModelConfig(BaseModel):
222213
# },
223214
}
224215
)
216+
experiment_matrix: ExperimentMatrixConfig = Field(
217+
default_factory=ExperimentMatrixConfig
218+
)
225219

226220
# Model components
227221
hidden_layers: list[int] = [1024, 512, 256]

0 commit comments

Comments
 (0)