-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathengine_rl.py
More file actions
137 lines (114 loc) · 3.93 KB
/
engine_rl.py
File metadata and controls
137 lines (114 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os, sys, time, math, json, importlib
import torch
import datetime
from collections import defaultdict, OrderedDict
import numpy as np
from tqdm import tqdm
from models.rl.grpo_trainer import GRPOTrainer
from utils.io import save_checkpoint
from utils.misc import SmoothedValue
from utils.dist import (
init_distributed,
is_distributed,
is_primary,
get_rank,
barrier,
all_reduce_average,
all_gather_dict
)
class Logger:
    """Append-mode file logger that also echoes every message to stdout.

    Log lines go to ``<exp_name>-rl-logger.log`` inside
    ``args.checkpoint_dir``, where ``exp_name`` is the last path component
    of the checkpoint directory.  The original implementation never closed
    the file handle; ``close()`` / ``__del__`` fix that leak.
    """

    def __init__(self, args):
        # Experiment name = basename of the checkpoint directory.
        exp_name = os.path.split(args.checkpoint_dir)[-1]
        log_path = os.path.join(args.checkpoint_dir, f'{exp_name}-rl-logger.log')
        # 'a' (append) so restarted runs keep the earlier history.
        self.logger = open(log_path, 'a')

    def __call__(self, info_str):
        """Write *info_str* (plus newline) to the log file and print it."""
        self.logger.write(info_str + "\n")
        # Flush immediately so the log survives crashes mid-epoch.
        self.logger.flush()
        print(info_str)

    def close(self):
        """Release the underlying file handle."""
        self.logger.close()

    def __del__(self):
        # Best-effort cleanup if close() was never called explicitly.
        try:
            self.logger.close()
        except Exception:
            pass
def do_rl_train(
    args,
    model,
    ref_model,
    tokenizer,
    dataset_config,
    dataloaders,
    best_val_metrics=None
):
    """
    RL training using GRPO.

    Args:
        args: experiment configuration namespace; reads ``checkpoint_dir``,
            ``rl_beta``, ``rl_lr``, ``batchsize_per_gpu``, ``rl_num_epochs``,
            ``rl_max_grad_norm``, ``start_epoch``, ``max_epoch``,
            ``save_every`` and ``eval_every_iteration``.
        model: policy model to optimize; must already be on its target device.
        ref_model: frozen reference model for the KL term (passed to the trainer).
        tokenizer: tokenizer shared by policy and reference models.
        dataset_config: dataset configuration forwarded to each test set's
            ``eval_func``.
        dataloaders: dict with 'train', 'test' (iterable of test loaders) and,
            under distributed training, 'train_sampler'.
        best_val_metrics: optional dict of best validation metrics.  Unused in
            this function; kept for interface compatibility with callers.
    """
    # Fix of the mutable-default-argument pitfall: the original signature used
    # ``best_val_metrics=dict()``, which is shared across calls.
    if best_val_metrics is None:
        best_val_metrics = dict()

    logout = Logger(args)
    if is_primary():
        logout(f"Starting RL training with args: {args}")
        logout(f"Model: {model}")

    # Train on whatever device the model already lives on.
    device = next(model.parameters()).device

    # Initialize GRPO trainer
    grpo_config = {
        'beta': args.rl_beta,                    # KL penalty coefficient
        'lr': args.rl_lr,
        'batch_size': args.batchsize_per_gpu,
        'num_epochs': args.rl_num_epochs,
        'max_grad_norm': args.rl_max_grad_norm,  # gradient clipping threshold
    }
    trainer = GRPOTrainer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        device=device,
        grpo_config=grpo_config
    )

    # Training loop.  (The original also built two SmoothedValue trackers,
    # ``time_delta`` and ``loss_avg``, that were never read — removed.)
    curr_epoch = args.start_epoch
    max_epochs = args.max_epoch

    model.train()
    barrier()  # sync all ranks before entering the training loop
    for epoch in range(curr_epoch, max_epochs):
        if is_distributed():
            # Reshuffle the distributed sampler so each epoch sees a new order.
            dataloaders["train_sampler"].set_epoch(epoch)
        epoch_start_time = time.time()

        # Train one epoch
        epoch_stats = trainer.train_epoch(dataloaders['train'])
        epoch_time = time.time() - epoch_start_time

        # Log statistics on the primary rank only, to avoid duplicate lines.
        if is_primary():
            logout(f"Epoch {epoch}/{max_epochs} completed in {epoch_time:.2f}s")
            logout(f"Policy Loss: {epoch_stats['policy_loss']:.4f}")
            logout(f"KL Divergence: {epoch_stats['kl_div']:.4f}")
            logout(f"Average Reward: {epoch_stats['avg_reward']:.4f}")
            logout(f"Format Reward: {epoch_stats['format_reward']:.4f}")
            logout(f"Perception Reward: {epoch_stats['perception_reward']:.4f}")
            logout(f"Semantic Reward: {epoch_stats['semantic_reward']:.4f}")

        # Save checkpoint (primary rank only)
        if is_primary() and (epoch + 1) % args.save_every == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_dir,
                f"rl_checkpoint_epoch_{epoch}.pth"
            )
            trainer.save_checkpoint(checkpoint_path, epoch, epoch_stats)
            logout(f"Saved checkpoint to {checkpoint_path}")

        # Evaluation
        if (epoch + 1) % args.eval_every_iteration == 0:
            model.eval()
            for test_loader in dataloaders['test']:
                # Each test dataset supplies its own evaluation routine.
                test_loader.dataset.eval_func(
                    args,
                    epoch,
                    model,
                    dataset_config,
                    test_loader
                )
            model.train()
        barrier()  # keep ranks in lockstep at each epoch boundary

    # NOTE(review): emitted on every rank (all ranks append to the same log
    # file), matching the original behavior.
    logout("RL training completed!")
def load_sft_model(args, model):
    """
    Initialize *model* with supervised-fine-tuning (SFT) weights before RL.

    Args:
        args: configuration namespace; ``args.sft_checkpoint`` is either a
            path to a checkpoint file containing a ``'model'`` state dict,
            or ``None`` to skip loading.
        model: the model instance to populate (mutated in place).

    Returns:
        The same *model* object, with SFT weights loaded when available.
    """
    ckpt_path = args.sft_checkpoint
    # Guard clause: nothing to load, warn and hand the model straight back.
    if ckpt_path is None:
        print("Warning: No SFT checkpoint provided for RL training")
        return model

    # Load onto CPU first; strict=False tolerates missing/extra keys
    # (e.g. heads added for the RL stage).
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state['model'], strict=False)
    print(f"Loaded SFT checkpoint from {ckpt_path}")
    return model