diff --git a/bigearthnet/configs/config.yaml b/bigearthnet/configs/config.yaml index 6e9423a..0de3e08 100644 --- a/bigearthnet/configs/config.yaml +++ b/bigearthnet/configs/config.yaml @@ -67,3 +67,4 @@ defaults: - _self_ - model: baseline.yaml # uses the default baseline model when training - transforms: norm.yaml # performs normalization as default transform + - override hydra/sweeper: orion diff --git a/bigearthnet/configs/hydra/sweeper/orion.yaml b/bigearthnet/configs/hydra/sweeper/orion.yaml new file mode 100644 index 0000000..8936d56 --- /dev/null +++ b/bigearthnet/configs/hydra/sweeper/orion.yaml @@ -0,0 +1,28 @@ +orion: + name: 'experiment' + version: '1' + +algorithm: + type: random + config: + seed: 1 + +worker: + n_workers: -1 + max_broken: 3 + max_trials: 3 + +storage: + type: legacy + + database: + type: pickleddb + host: 'database.pkl' + +# parametrization of the hyperparameter space +parametrization: + datamodule: + batch_size: "choices([64, 128, 256])" + optimizer: + lr: "uniform(0, 1)" + name: "choices(['adam', 'sgd'])" diff --git a/bigearthnet/train.py b/bigearthnet/train.py index d60cd94..485d302 100644 --- a/bigearthnet/train.py +++ b/bigearthnet/train.py @@ -29,6 +29,8 @@ def main(cfg: DictConfig): trainer.fit(model, datamodule=datamodule) log.info("Training Done.") + return model.val_best_metric + if __name__ == "__main__": main() diff --git a/setup.py b/setup.py index 1deec00..c3a77d7 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ "jupyter", "matplotlib", "numpy>=1.23", + "orion", "pyyaml>=5.3", "pytest", "pytorch_lightning==1.6.4",