-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathengine_merging.py
More file actions
73 lines (60 loc) · 2.56 KB
/
engine_merging.py
File metadata and controls
73 lines (60 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import argparse
import torch
from typing import Dict
from task_vectors import TaskVector, merge_max_abs
import utils.misc as misc
def create_combined_encoder(
    model_combination: Dict[str, str],
    train_model: str,
    pretrained_model_path: str,
    which_merging_technique: str,
    output_dir: str,
    suffix: str,
    device: torch.device,
    scaling_coef: float,
    args: argparse.Namespace
) -> torch.nn.Module:
    """
    Create a combined encoder by merging per-instrument task vectors.

    Each task vector is the difference between a finetuned checkpoint and the
    pre-trained base model; the merged vector is applied back onto the base
    model and the result is saved to disk.

    Args:
        model_combination: Mapping from instrument name to checkpoint epoch
        train_model: Architecture/run name used in the checkpoint directory
            layout (e.g. a ViT variant) — TODO confirm exact semantics with caller
        pretrained_model_path: Path to the pre-trained base model
        which_merging_technique: Merging technique ('task_vectors' or 'magmax')
        output_dir: Base output directory
        suffix: Suffix appended to the combined model's file name
        device: Device to place the combined model on
        scaling_coef: Scaling coefficient for task vector application
        args: Parsed CLI arguments, forwarded to apply_to / misc.save_model

    Returns:
        Combined encoder model, moved to `device`

    Raises:
        ValueError: If `which_merging_technique` is not a known technique.
    """
    # Build one task vector per instrument: finetuned checkpoint minus
    # the pre-trained weights (computed inside TaskVector).
    task_vectors = []
    for instrument, checkpoint in model_combination.items():
        checkpoint_path = os.path.join(
            output_dir, "pretraining", train_model, f"{instrument}_True",
            f"checkpoint_{instrument}_True-{checkpoint}.pth"
        )
        task_vectors.append(TaskVector(pretrained_model_path, checkpoint_path))

    # Merge task vectors with the requested technique.
    if which_merging_technique == "task_vectors":
        # Element-wise sum (TaskVector implements __add__ via sum()).
        task_vector_sum = sum(task_vectors)
    elif which_merging_technique == "magmax":
        # Keep, per parameter, the entry with the largest absolute magnitude.
        task_vector_sum = merge_max_abs(task_vectors)
    else:
        raise ValueError(f"Unknown merging technique: {which_merging_technique}. "
                         f"Available choices: task_vectors, magmax")

    # Apply the merged task vector to the pre-trained model.
    combined_encoder = task_vector_sum.apply_to(
        pretrained_model_path, train_model, device, args, scaling_coef=scaling_coef
    ).to(device)

    # Name the combined model after its instruments; dicts preserve insertion
    # order, so the name is deterministic for a given model_combination.
    combined_model_name = "_".join(model_combination)
    save_dir = os.path.join(output_dir, "pretraining", train_model, f"model_merging_{which_merging_technique}")
    os.makedirs(save_dir, exist_ok=True)
    # Compute the target path once; it is both logged and written to.
    save_path = os.path.join(save_dir, f"{combined_model_name}_{suffix}.pth")
    print(f"Saving combined encoder to {save_path}")
    if "vit" in train_model:
        # ViT models go through the project's save helper, which handles
        # its own file naming inside save_dir.
        misc.save_model(args, save_dir, f"{combined_model_name}_{suffix}", combined_encoder)
    else:
        torch.save(combined_encoder.state_dict(), save_path)
    return combined_encoder