[BUG] fails to utilize all instances when the output batch size of a model is smaller than num_replica of the next model #209

@haolin-nju

Description

Describe the bug
ChatLearn fails to utilize all instances when the output batch size of a model is smaller than num_replica of the next model. For example, when the batch size of the policy outputs is smaller than the number of code reward instances, ChatLearn currently does not further split those batches, so it cannot fully utilize the code reward instances.
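
For concreteness, here is the arithmetic behind the reproduction below, as a minimal sketch (this is not ChatLearn's actual dispatcher; the helper and its round-robin assignment are illustrative only): with 16 samples per episode, policy generation_batch_size=4, and 8 reward replicas each with generation_batch_size=2, the policy emits 4 output batches, so unless each batch is further split into reward-sized micro-batches, only 4 of the 8 reward replicas ever receive data.

# Illustration only: not ChatLearn's real scheduler, just the arithmetic of the report,
# using the numbers from the reproduction below.
def assign_batches(num_samples, producer_bs, consumer_bs, num_consumer_replicas, split):
    """Return how many samples end up on each consumer replica."""
    batches = [producer_bs] * (num_samples // producer_bs)  # producer output batches
    if split:
        # further divide each output batch into consumer-sized micro-batches
        batches = [consumer_bs for b in batches for _ in range(b // consumer_bs)]
    used = [0] * num_consumer_replicas
    for i, b in enumerate(batches):  # round-robin dispatch across consumer replicas
        used[i % num_consumer_replicas] += b
    return used

print(assign_batches(16, 4, 2, 8, split=False))  # [4, 4, 4, 4, 0, 0, 0, 0] -> 4 idle replicas
print(assign_batches(16, 4, 2, 8, split=True))   # [2, 2, 2, 2, 2, 2, 2, 2] -> all replicas used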

To Reproduce
A simple UT to reproduce the bug:

import os
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import chatlearn
from chatlearn import RLHFEngine
from chatlearn import TorchModule
from chatlearn.utils import future


class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.collate_fn = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"query": self.data[idx]}


chatlearn.init()

class PolicyModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 2

    @property
    def data_parallel_rank(self):
        if self._get_rank() < 4:
            return 0
        return 1

    def forward_step(self, data, iteration):
        print(f"policy forward {self.counter}=========", flush=True)
        query = data["query"]
        bs = query.size(0)
        data["policy_out"] = torch.ones([bs, 1024]).cuda()
        self.counter += 1
        return data

    def build_dataset(self, prompts, is_eval=False):
        dataset = CustomDataset(prompts)
        return dataset



class ReferenceModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 8

    @property
    def data_parallel_rank(self):
        return self._get_rank()

    def forward_step(self, data, iteration):
        print(f"reference forward {self.counter} on {self.data_parallel_rank}=========", flush=True)
        query = data["policy_out"].cuda()
        data["ref_out"] = query * 2
        self.counter += 1
        return data


class RewardModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 8

    @property
    def data_parallel_rank(self):
        return self._get_rank()

    def forward_step(self, data, iteration):
        print(f"reward forward {self.counter}=========", flush=True)
        data["reward_out"] = data["ref_out"].cuda() + data["policy_out"].cuda()
        self.counter += 1
        return data

class ValueModel(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 8

    @property
    def data_parallel_rank(self):
        return self._get_rank()

    def forward_step(self, data, iteration):
        print(f"value forward {self.counter}=========", flush=True)
        data["value_out"] = data["policy_out"].cuda() * 3
        self.counter += 1
        return data


class PPOPolicy(TorchModule):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data = []
        self.counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 2

    @property
    def data_parallel_rank(self):
        if self._get_rank() < 4:
            return 0
        return 1

    def train_step(self, data, iteration):
        print(f"ppo policy train_step {self.counter}========= {self.data_parallel_rank}", flush=True)
        self.data.append(data)
        num_mb = len(data)
        self.counter += 1
        return num_mb

    def get_data(self):
        return self.data

class PPOValue(TorchModule):
    counter = 1

    def _get_rank(self):
        return int(os.environ["RANK"])

    @property
    def data_parallel_size(self):
        return 2

    @property
    def data_parallel_rank(self):
        if self._get_rank() < 4:
            return 0
        return 1

    def train_step(self, data, iteration):
        print(f"ppo value train_step {self.counter}=========", flush=True)
        num_mb = len(data)
        self.counter += 1
        return num_mb

for _, model_config in chatlearn.get_args().models.items():
    model_config.num_gpu = 8

chatlearn.get_args().models['policy'].expert_model_parallel_size = 1
chatlearn.get_args().models['reference'].expert_model_parallel_size = 1
chatlearn.get_args().models['reward'].expert_model_parallel_size = 1
chatlearn.get_args().models['value'].expert_model_parallel_size = 1

chatlearn.get_args().models['policy'].tensor_model_parallel_size = 4
chatlearn.get_args().models['reference'].tensor_model_parallel_size = 1
chatlearn.get_args().models['reward'].tensor_model_parallel_size = 1
chatlearn.get_args().models['value'].tensor_model_parallel_size = 1

chatlearn.get_args().models['ppo_policy'].expert_model_parallel_size = 1
chatlearn.get_args().models['ppo_value'].expert_model_parallel_size = 1

chatlearn.get_args().models['ppo_policy'].tensor_model_parallel_size = 4
chatlearn.get_args().models['ppo_value'].tensor_model_parallel_size = 4

chatlearn.get_args().models['policy'].generation_batch_size = 4
chatlearn.get_args().models['reference'].generation_batch_size = 2
chatlearn.get_args().models['reward'].generation_batch_size = 2
chatlearn.get_args().models['value'].generation_batch_size = 2

chatlearn.get_args().runtime_args.colocation = [["policy", "reference", "reward", "value", "ppo_policy", "ppo_value"]]
chatlearn.get_args().runtime_args.train_micro_batch_size = 4
chatlearn.get_args().runtime_args.train_global_batch_size = 8
chatlearn.get_args().runtime_args.max_relay_episode = 1
chatlearn.get_args().runtime_args.sample_per_episode = 16
policy = PolicyModel("policy")
reference = ReferenceModel("reference")
reward = RewardModel("reward")
value = ValueModel("value")
ppo_policy = PPOPolicy("ppo_policy")
ppo_value = PPOValue("ppo_value")

engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value)

def relay_sample_fn(episode_relay_buffers):
    buffer = episode_relay_buffers[-1].buffer
    episode_id = episode_relay_buffers[-1]._episode_id
    assert len(buffer) == 16
    for i in range(len(buffer)):
        assert int(buffer[i]['query'][0].item()) == i + episode_id * 16
    return buffer

engine.set_relay_sample_fn(relay_sample_fn)
# inference models: policy (tp=4 on 8 GPUs) has 2 dp replicas, while reference/reward/value (tp=1) have 8
assert policy.num_replica == 2
assert reference.num_replica == 8
assert reward.num_replica == 8
assert value.num_replica == 8
# for training models, ep is combined into dp, leading to only 1 replica
assert ppo_policy.num_replica == 1
assert ppo_value.num_replica == 1
data = [torch.ones([1024]) * i for i in range(2048)]
engine.set_dataset(data)
engine.learn()
assert future.get(engine.named_models['reference'].replicas[-1].get_runtime_args()[0]).consumed_samples == 32

Additionally, add

def get_runtime_args(self):
    return self.runtime_args

to base_module.py.

The program then hits an assertion failure at assert future.get(engine.named_models['reference'].replicas[-1].get_runtime_args()[0]).consumed_samples == 32.

Expected behavior
The UT passes, i.e. future.get(engine.named_models['reference'].replicas[-1].get_runtime_args()[0]).consumed_samples == 32 holds.

Additional context
The issue might not be top priority, since users can adjust the parallel sizes of the different parallel strategies to change num_replica.
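
As an illustrative, hypothetical workaround under the reproduction's setup, and assuming num_replica is num_gpu divided by the product of the model-parallel sizes (which is what the asserts above suggest), raising the downstream models' tensor parallel size shrinks their replica count to match the policy:

# Hypothetical workaround sketch, not a recommended production setting: with 8 GPUs,
# tp=4 gives 2 replicas instead of 8, matching policy.num_replica == 2, so every
# replica receives data even without further batch splitting.
for name in ['reference', 'reward', 'value']:
    chatlearn.get_args().models[name].tensor_model_parallel_size = 4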
