-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathparallel_training.py
More file actions
109 lines (92 loc) · 2.59 KB
/
parallel_training.py
File metadata and controls
109 lines (92 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
import subprocess
import time
import os
# Configuration for 4x 3090s
# For our current implementation, relative training speed is:
# standard / p_anneal > top_k > batch_top_k > jump_relu > gated
# So, we have jump_relu and gated on their own GPUs
MODEL_NAME = "EleutherAI/pythia-160m-deduped"
MODEL_NAME = "google/gemma-2-2b"
# MODEL_NAME = "EleutherAI/pythia-70m-deduped"
# So we are forced to set MODEL_NAME before running this script
MODEL_NAME = None
if "gemma" in MODEL_NAME:
layer = 12
elif "pythia-70m" in MODEL_NAME:
layer = 3
elif "pythia-160m" in MODEL_NAME:
layer = 8
else:
raise ValueError("Unknown model name")
configurations = [
{
"arch": "jump_relu",
"layers": layer,
"device": "cuda:0",
"save_checkpoints": False
},
{
"arch": "top_k p_anneal",
"layers": layer,
"device": "cuda:1",
"save_checkpoints": False
},
{
"arch": "batch_top_k standard_new",
"layers": layer,
"device": "cuda:2",
"save_checkpoints": False
},
{
"arch": "gated",
"layers": layer,
"device": "cuda:3",
"save_checkpoints": False
},
]
# config for 2x 3090s
# configurations = [
# {
# "arch": "gated top_k p_anneal",
# "layers": layer,
# "device": "cuda:0"
# },
# {
# "arch": "jump_relu batch_top_k standard_new",
# "layers": layer,
# "device": "cuda:1"
# },
# ]
SAVE_DIR = "trained_saes/"
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
# Launch jobs
for i, config in enumerate(configurations):
log_file = f"logs/{(config['arch'].replace(' ', '_'))}_l{config['layers']}_{config['device'].replace(':', '_')}.out"
if config["save_checkpoints"]:
save_command = "--save_checkpoints"
else:
save_command = ""
cmd = [
"python", "demo.py",
"--save_dir", SAVE_DIR,
"--model_name", MODEL_NAME,
"--architectures", config["arch"],
"--layers", str(config["layers"]),
"--device", config["device"],
save_command,
# "--use_wandb"
]
print(" ".join(cmd))
# Launch with nohup
with open(log_file, "w") as f:
subprocess.Popen(
f"nohup {' '.join(cmd)} > {log_file} 2>&1",
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL
)
print(f"Started job {i + 1}/{len(configurations)}: {config['arch']} with layers: {config['layers']}")
time.sleep(2)
print("All jobs submitted!")