Describe the bug
ChatLearn fails to utilize all instances of a model when the number of output batches produced by the previous model is smaller than the next model's num_replica. For example, when the policy emits fewer output batches than there are code reward instances, ChatLearn currently does not split those batches further, so some code reward instances receive no work.
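As an illustration of the expected behavior, below is a minimal re-batching sketch (the helper name and data layout are hypothetical, not ChatLearn's actual scheduler code): when the upstream model emits fewer batches than the downstream model has replicas, the batches are split further so that every replica gets data.

import math

def rebatch_for_replicas(batches, num_replica):
    # Hypothetical helper, not ChatLearn's scheduler: flatten the upstream
    # batches and re-split them so there are enough batches for every replica.
    if len(batches) >= num_replica:
        return batches
    samples = [s for b in batches for s in b]
    per_batch = math.ceil(len(samples) / num_replica)
    return [samples[i:i + per_batch] for i in range(0, len(samples), per_batch)]

# Example: the policy emits 2 batches of 4 samples, but there are 8 reward replicas.
policy_batches = [list(range(4)), list(range(4, 8))]
rebatched = rebatch_for_replicas(policy_batches, num_replica=8)
assert len(rebatched) == 8  # every reward replica now receives one (smaller) batch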
To Reproduce
A simple UT to reproduce the bug:
import os
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import chatlearn
from chatlearn import RLHFEngine
from chatlearn import TorchModule
from chatlearn.utils import future


class CustomDataset(Dataset):

    def __init__(self, data):
        self.data = data
        self.collate_fn = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"query": self.data[idx]}


chatlearn.init()


class PolicyModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 2

    @property
    def data_parallel_rank(self):
        if self._get_rank() < 4:
            return 0
        return 1

    def forward_step(self, data, iteration):
        print(f"policy forward {self.counter}=========", flush=True)
        query = data["query"]
        bs = query.size(0)
        data["policy_out"] = torch.ones([bs, 1024]).cuda()
        self.counter += 1
        return data

    def build_dataset(self, prompts, is_eval=False):
        dataset = CustomDataset(prompts)
        return dataset


class ReferenceModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 8

    @property
    def data_parallel_rank(self):
        return self._get_rank()

    def forward_step(self, data, iteration):
        print(f"reference forward {self.counter} on {self.data_parallel_rank}=========", flush=True)
        query = data["policy_out"].cuda()
        data["ref_out"] = query * 2
        self.counter += 1
        return data


class RewardModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 8

    @property
    def data_parallel_rank(self):
        return self._get_rank()

    def forward_step(self, data, iteration):
        print(f"reward forward {self.counter}=========", flush=True)
        data["reward_out"] = data["ref_out"].cuda() + data["policy_out"].cuda()
        self.counter += 1
        return data


class ValueModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 8

    @property
    def data_parallel_rank(self):
        return self._get_rank()

    def forward_step(self, data, iteration):
        print(f"value forward {self.counter}=========", flush=True)
        data["value_out"] = data["policy_out"].cuda() * 3
        self.counter += 1
        return data


class PPOPolicy(TorchModule):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data = []
        self.counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 2

    @property
    def data_parallel_rank(self):
        if self._get_rank() < 4:
            return 0
        return 1

    def train_step(self, data, iteration):
        print(f"ppo policy train_step {self.counter}========= {self.data_parallel_rank}", flush=True)
        self.data.append(data)
        num_mb = len(data)
        self.counter += 1
        return num_mb

    def get_data(self):
        return self.data


class PPOValue(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 2

    @property
    def data_parallel_rank(self):
        if self._get_rank() < 4:
            return 0
        return 1

    def train_step(self, data, iteration):
        print(f"ppo value train_step {self.counter}=========", flush=True)
        num_mb = len(data)
        self.counter += 1
        return num_mb


for _, model_config in chatlearn.get_args().models.items():
    model_config.num_gpu = 8

chatlearn.get_args().models['policy'].expert_model_parallel_size = 1
chatlearn.get_args().models['reference'].expert_model_parallel_size = 1
chatlearn.get_args().models['reward'].expert_model_parallel_size = 1
chatlearn.get_args().models['value'].expert_model_parallel_size = 1
chatlearn.get_args().models['policy'].tensor_model_parallel_size = 4
chatlearn.get_args().models['reference'].tensor_model_parallel_size = 1
chatlearn.get_args().models['reward'].tensor_model_parallel_size = 1
chatlearn.get_args().models['value'].tensor_model_parallel_size = 1
chatlearn.get_args().models['ppo_policy'].expert_model_parallel_size = 1
chatlearn.get_args().models['ppo_value'].expert_model_parallel_size = 1
chatlearn.get_args().models['ppo_policy'].tensor_model_parallel_size = 4
chatlearn.get_args().models['ppo_value'].tensor_model_parallel_size = 4
chatlearn.get_args().models['policy'].generation_batch_size = 4
chatlearn.get_args().models['reference'].generation_batch_size = 2
chatlearn.get_args().models['reward'].generation_batch_size = 2
chatlearn.get_args().models['value'].generation_batch_size = 2
chatlearn.get_args().runtime_args.colocation = [["policy", "reference", "reward", "value", "ppo_policy", "ppo_value"]]
chatlearn.get_args().runtime_args.train_micro_batch_size = 4
chatlearn.get_args().runtime_args.train_global_batch_size = 8
chatlearn.get_args().runtime_args.max_relay_episode = 1
chatlearn.get_args().runtime_args.sample_per_episode = 16

policy = PolicyModel("policy")
reference = ReferenceModel("reference")
reward = RewardModel("reward")
value = ValueModel("value")
ppo_policy = PPOPolicy("ppo_policy")
ppo_value = PPOValue("ppo_value")

engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value)


def relay_sample_fn(episode_relay_buffers):
    buffer = episode_relay_buffers[-1].buffer
    episode_id = episode_relay_buffers[-1]._episode_id
    assert len(buffer) == 16
    for i in range(len(buffer)):
        assert int(buffer[i]['query'][0].item()) == i + episode_id * 16
    return buffer


engine.set_relay_sample_fn(relay_sample_fn)
# inference models: the policy has 2 dp replicas, while reference/reward/value each have 8
assert policy.num_replica == 2
assert reference.num_replica == 8
assert reward.num_replica == 8
assert value.num_replica == 8
# for training models, ep is combined into dp, leading to only 1 replica
assert ppo_policy.num_replica == 1
assert ppo_value.num_replica == 1
data = [torch.ones([1024]) * i for i in range(2048)]
engine.set_dataset(data)
engine.learn()
assert future.get(engine.named_models['reference'].replicas[-1].get_runtime_args()[0]).consumed_samples == 32

and add the following method to base_module.py so that the UT can query the reference replica's runtime args:

def get_runtime_args(self):
    return self.runtime_args
The program fails with an AssertionError at assert future.get(engine.named_models['reference'].replicas[-1].get_runtime_args()[0]).consumed_samples == 32.
Expected behavior
The UT should pass, i.e. future.get(engine.named_models['reference'].replicas[-1].get_runtime_args()[0]).consumed_samples == 32.
Additional context
The issue may not be of top priority, because users can adjust the parallel sizes of the different parallel strategies to change num_replica and avoid the mismatch.
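For instance, reusing the config names from the UT above, a rough sketch of that kind of workaround is to raise the tensor parallel size of the downstream inference models so that their num_replica shrinks to match the policy's. This is only a workaround sketch under the assumption that num_replica is num_gpu divided by the product of the model-parallel sizes (consistent with the asserts in the UT); it does not fix the scheduling behavior itself.

# Workaround sketch: with 8 GPUs and tensor_model_parallel_size = 4, the
# reference/reward/value models drop from 8 replicas to 2, matching the policy,
# so no replica is left without an input batch.
for name in ['reference', 'reward', 'value']:
    chatlearn.get_args().models[name].tensor_model_parallel_size = 4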