# minirl


Minimal RL training for LLMs. Generate with vLLM. Train with PyTorch FSDP. No magic.

Two backends, one training loop:

| | Co-located | Async |
| --- | --- | --- |
| **How** | Same GPUs switch between vLLM ↔ FSDP | Separate GPU pools, connected by queue |
| **Scale** | 4–64 GPUs | 100+ GPUs |
| **Policy** | Fully on-policy | Slightly off-policy |
| **Module** | `minirl.colocated` | `minirl.async_rl` |

The training loop is identical in both — only the data source differs.


## Install

```bash
pip install minirl
pip install "minirl[async]"   # adds ZMQ for the async backend
```

Requires Python ≥ 3.10, PyTorch ≥ 2.3, and vLLM ≥ 0.8.

## Co-located: Same GPUs

```python
from minirl import WeightManager, clipped_policy_loss, get_logprobs, parse_vllm_outputs
from vllm import SamplingParams

wm = WeightManager("Qwen/Qwen3-4B", lr=1e-6)

for step in range(num_steps):
    # Generation phase: the GPUs run vLLM.
    llm = wm.as_vllm()                                   # → vllm.LLM
    outputs = llm.generate(prompts, SamplingParams(n=8, temperature=1.0, logprobs=1))
    batches = parse_vllm_outputs(outputs, group_size=8, reward_fn=my_reward_fn)
    for b in batches:
        b.compute_advantages()

    # Training phase: the same GPUs switch to FSDP.
    model, optimizer = wm.as_trainable()                 # → (nn.Module, Optimizer)
    for epoch in range(4):
        for b in batches:
            for rollout in b:                            # the G rollouts in one group
                out = model(rollout.token_ids)
                loss = clipped_policy_loss(get_logprobs(out.logits, rollout.token_ids),
                                           rollout.old_logprobs, rollout.advantage)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
```
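The loop assumes you define `num_steps`, `prompts`, and `my_reward_fn`. The exact callback signature `parse_vllm_outputs` expects isn't shown here; a minimal sketch, assuming it receives the prompt and the decoded completion and returns a scalar reward:

```python
def my_reward_fn(prompt: str, completion: str) -> float:
    # Hypothetical verifier: full reward if the expected answer appears.
    return 1.0 if "42" in completion else 0.0
```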

## Async: Separate GPU Pools

```python
# === Script A: generation machines ===
from minirl import RolloutServer

server = RolloutServer("Qwen/Qwen3-4B", queue_address="tcp://*:5555")
server.run(prompts)
```

```python
# === Script B: training machines ===
from minirl import TrainWorker, clipped_policy_loss, get_logprobs

worker = TrainWorker("Qwen/Qwen3-4B", queue_address="tcp://gen-server:5555")
for batches in worker.pull_rollouts():
    for b in batches:
        b.compute_advantages()
        for rollout in b:                                # the G rollouts in one group
            out = worker.model(rollout.token_ids)
            loss = clipped_policy_loss(get_logprobs(out.logits, rollout.token_ids),
                                       rollout.old_logprobs, rollout.advantage)
            worker.optimizer.zero_grad()
            loss.backward()
            worker.optimizer.step()
    worker.sync_weights()   # push updated weights to the generation server
```
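Each script runs on its own machines, so launching is up to you. A hypothetical setup with one generation node and one training node (script names, host, and GPU count are placeholders):

```bash
# Machine A: generation pool, serves rollouts on tcp://*:5555
python rollout_server.py

# Machine B: training pool, one FSDP rank per GPU
torchrun --nproc_per_node=8 train_worker.py
```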

## API

### Core — `minirl.core`

| Function | Description |
| --- | --- |
| `compute_advantages(rewards, group_size)` | GRPO: `A_i = r_i - mean(group)` |
| `clipped_policy_loss(new_lp, old_lp, adv)` | Clipped surrogate loss |
| `get_logprobs(logits, token_ids)` | Extract per-token log-probs |
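For orientation, a reference sketch of what these helpers compute, using the standard PPO clipped surrogate (the clip range here is an assumed illustrative default, not necessarily minirl's):

```python
import torch

def clipped_loss_reference(new_lp, old_lp, adv, clip_eps=0.2):
    # Importance ratio between current and behavior policy, per token.
    ratio = torch.exp(new_lp - old_lp)
    # Take the pessimistic (clipped) side of the surrogate objective.
    surrogate = torch.minimum(ratio * adv,
                              torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv)
    return -surrogate.mean()

# GRPO advantages: each reward minus its group's mean.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])   # one group, group_size = 4
advantages = rewards - rewards.mean()           # → [0.5, -0.5, -0.5, 0.5]
```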

### Data — `minirl.data`

| Name | Description |
| --- | --- |
| `Rollout` | Single completion with token IDs, log-probs, reward |
| `RolloutBatch` | Group of `G` rollouts for one prompt |
| `parse_vllm_outputs(...)` | Convert vLLM outputs → list of `RolloutBatch` |

### Co-located — `minirl.colocated`

| API | Returns | Description |
| --- | --- | --- |
| `WeightManager(path)` | `WeightManager` | Initialize with model |
| `.as_vllm()` | `vllm.LLM` | Switch to generation mode |
| `.as_trainable()` | `(nn.Module, Optimizer)` | Switch to training mode |

### Async — `minirl.async_rl`

| Class | Description |
| --- | --- |
| `RolloutServer(path, addr)` | vLLM generation; pushes rollouts to the queue |
| `TrainWorker(path, addr)` | FSDP training; pulls from the queue, pushes weights back via `.sync_weights()` |

## Design Principles

- **You own the loop.** minirl manages weights; you write the training logic.
- **Raw objects.** `.as_vllm()` returns a real `vllm.LLM`; `.as_trainable()` returns a real `nn.Module`. No wrappers.
- **Composable.** Custom losses, auxiliary heads, and hidden-state access all work because you have the raw model; see the sketch below.
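Because you hold the raw module and its logits, composing extra loss terms is plain PyTorch. A sketch that adds an entropy bonus on top of the clipped loss, reusing the names from the co-located example above (the `0.01` coefficient is illustrative):

```python
import torch.nn.functional as F

out = model(rollout.token_ids)
new_lp = get_logprobs(out.logits, rollout.token_ids)
pg_loss = clipped_policy_loss(new_lp, rollout.old_logprobs, rollout.advantage)

# Entropy of the next-token distribution; the bonus discourages collapse.
log_probs = F.log_softmax(out.logits, dim=-1)
entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

loss = pg_loss - 0.01 * entropy
```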

## License

MIT
