-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
114 lines (83 loc) · 3.78 KB
/
main.py
File metadata and controls
114 lines (83 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""CMVR: CLI entry point for SSL pretraining and downstream tasks."""
import argparse
import multiprocessing
import yaml
# Runs at import time: select the "fork" start method for child processes.
# force=True replaces any start method that was already set instead of raising.
# NOTE(review): "fork" is not available on Windows — confirm this CLI is POSIX-only.
multiprocessing.set_start_method("fork", force=True)
def load_config(config_path: str, overrides: dict) -> dict:
    """Load a YAML config file and apply dotted-key CLI overrides.

    Each override key is a dotted path such as ``"training.lr"``. Every
    segment but the last selects a nested section; the last segment names the
    entry to set. When the entry already holds a non-null value, the override
    is converted to that value's type before being stored; otherwise it is
    stored as given.

    Args:
        config_path: Path to the YAML configuration file.
        overrides: Mapping of dotted keys to raw override values (strings
            when they come from the command line).

    Returns:
        The loaded configuration with all overrides applied.

    Raises:
        KeyError: If an intermediate section of a dotted key does not exist.
        ValueError: If an override cannot be converted to the type of the
            entry it replaces.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if config is None:
        # An empty YAML document parses to None; treat it as an empty config
        # so overrides can still be applied and the return type holds.
        config = {}

    # Apply flat CLI overrides like --training.lr=1e-3
    for key, value in overrides.items():
        *sections, leaf = key.split(".")
        node = config
        for depth, section in enumerate(sections, start=1):
            try:
                node = node[section]
            except KeyError:
                missing = ".".join(sections[:depth])
                raise KeyError(
                    f"cannot apply override {key!r}: "
                    f"config has no section {missing!r}"
                ) from None

        # Try to cast to the type of the value being replaced.
        current = node.get(leaf)
        if current is not None:
            target_type = type(current)
            if target_type is bool:
                # bool("false") is True, so compare against accepted spellings.
                value = value.lower() in ("true", "1", "yes")
            elif target_type in (list, dict):
                # list("[1, 2]") would split the text into characters; parse
                # the override as YAML so "[1, 2]" / "{a: 1}" work as written.
                try:
                    parsed = yaml.safe_load(value)
                except yaml.YAMLError as err:
                    raise ValueError(
                        f"override {key!r}: cannot parse {value!r} as YAML"
                    ) from err
                if not isinstance(parsed, target_type):
                    raise ValueError(
                        f"override {key!r}: expected a {target_type.__name__}, "
                        f"got {value!r}"
                    )
                value = parsed
            else:
                value = target_type(value)
        node[leaf] = value

    return config
def parse_overrides(args: list[str]) -> dict:
    """Collect ``--key=value`` tokens from leftover CLI arguments.

    Tokens that do not start with ``--`` or contain no ``=`` are ignored.
    The value is everything after the first ``=``; when a key appears more
    than once, the last occurrence wins.

    Args:
        args: Command-line tokens not consumed by the argument parser.

    Returns:
        Mapping of key (without the leading ``--``) to its raw string value.
    """
    split_tokens = (
        token[2:].partition("=")
        for token in args
        if token.startswith("--") and "=" in token
    )
    return {name: raw for name, _, raw in split_tokens}
def cmd_pretrain_moco(args: argparse.Namespace, remaining: list[str]) -> None:
    """Handle the ``pretrain-moco`` command: MoCo v2 pretraining.

    Loads the config at ``args.config``, applies any ``--key=value``
    overrides found in ``remaining``, and passes the result to ``train_moco``.
    """
    # Deferred import: the MoCo module is loaded only when this command runs.
    from ssl_methods.moco import train_moco

    cli_overrides = parse_overrides(remaining)
    train_moco(load_config(args.config, cli_overrides))
def cmd_pretrain_dino(args: argparse.Namespace, remaining: list[str]) -> None:
    """Handle the ``pretrain-dino`` command: DINO pretraining.

    Loads the config at ``args.config``, applies any ``--key=value``
    overrides found in ``remaining``, and passes the result to ``train_dino``.
    """
    # Deferred import: the DINO module is loaded only when this command runs.
    from ssl_methods.dino import train_dino

    cli_overrides = parse_overrides(remaining)
    train_dino(load_config(args.config, cli_overrides))
def cmd_pretrain_spark(args: argparse.Namespace, remaining: list[str]) -> None:
    """Handle the ``pretrain-spark`` command: SparK pretraining.

    Loads the config at ``args.config``, applies any ``--key=value``
    overrides found in ``remaining``, and passes the result to ``train_spark``.
    """
    # Deferred import: the SparK module is loaded only when this command runs.
    from ssl_methods.spark import train_spark

    cli_overrides = parse_overrides(remaining)
    train_spark(load_config(args.config, cli_overrides))
def cmd_pretrain_barlow(args: argparse.Namespace, remaining: list[str]) -> None:
    """Handle the ``pretrain-barlow`` command: BarlowTwins pretraining.

    Loads the config at ``args.config``, applies any ``--key=value``
    overrides found in ``remaining``, and passes the result to ``train_barlow``.
    """
    # Deferred import: the BarlowTwins module is loaded only when this command runs.
    from ssl_methods.barlow import train_barlow

    cli_overrides = parse_overrides(remaining)
    train_barlow(load_config(args.config, cli_overrides))
def main() -> None:
    """Parse the command line and run the selected pretraining command.

    Registers one subcommand per SSL method, each with a ``--config`` option
    that defaults to that method's YAML file. Arguments the parser does not
    recognise are forwarded to the command handler, which reads
    ``--key=value`` config overrides from them.
    """
    parser = argparse.ArgumentParser(prog="cmvr", description="CMVR: Chest X-ray SSL Pretraining")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # (command name, help text, default config path), in registration order.
    subcommands = (
        ("pretrain-moco", "Run MoCo v2 self-supervised pretraining", "configs/moco.yaml"),
        ("pretrain-dino", "Run DINO self-supervised pretraining", "configs/dino.yaml"),
        ("pretrain-spark", "Run SparK self-supervised pretraining", "configs/spark.yaml"),
        ("pretrain-barlow", "Run BarlowTwins self-supervised pretraining", "configs/barlow.yaml"),
    )
    for command, summary, default_config in subcommands:
        sub = subparsers.add_parser(command, help=summary)
        sub.add_argument("--config", type=str, default=default_config, help="Path to config YAML")

    # parse_known_args leaves unrecognised tokens in `remaining`.
    args, remaining = parser.parse_known_args()

    handlers = {
        "pretrain-moco": cmd_pretrain_moco,
        "pretrain-dino": cmd_pretrain_dino,
        "pretrain-spark": cmd_pretrain_spark,
        "pretrain-barlow": cmd_pretrain_barlow,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args, remaining)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()