
B-SWiRL: Branched Step-Wise Reinforcement Learning for LLM Reasoning

Motivation

Modern LLMs often solve reasoning problems using chains of thought. Most RL setups, however, reward only the final answer, ignoring whether intermediate steps are useful. Recent work such as SWiRL has shown that giving per-step rewards improves reasoning quality, but it still trains on only one sampled continuation per state. That is a significant limitation: at each intermediate state there are usually many plausible next steps.

B-SWiRL addresses this by branching: at each state in a trajectory, we sample multiple candidate next steps and assign them independent step-level rewards. This converts each state into a small empirical action distribution, letting the policy learn comparatively which continuations are better or worse.

High-Level Pipeline

Base model: Instruction-tuned Mistral 7B.
Goal: Improve final answer accuracy and step quality via offline step-wise RL with branching.

The pipeline has three main stages:

  1. Trajectory Generation (Backbone + Branches)

    • Use Mistral 7B (via TogetherAI) to generate full step-by-step solutions (Step 1: ..., Step 2: ..., Answer: ...) for GSM8K.
    • Parse each solution into a trajectory of (state, action) pairs:
      • state: problem + all previous steps
      • action: current step text
    • For each intermediate state, generate multiple alternative next steps using a separate next-action prompt.
    • Store:
      • Backbone actions (branch_id = 0, is_backbone = True)
      • Branched actions (branch_id > 0, is_backbone = False)
        All share the same (question_id, step, state) but differ in action (see the trajectory-parsing sketch after this pipeline overview).
  2. Step-Level Reward Assignment

    • Use a frozen LLM judge (Qwen 2.5 72B) to assign binary rewards to each (state, action) pair:
      • Intermediate steps: judged for local correctness, consistency with history, and progress toward a solution.
      • Final answer step: bypass the judge and compare directly to the GSM8K gold answer.
    • The judge prompt is designed to:
      • Prefer steps that actually compute something over pure planning fluff.
      • Score steps locally, so a good step can still be rewarded even if earlier steps made a mistake.
    • Output: a reward-annotated JSONL file with potentially multiple actions + rewards per state (see the reward-assignment sketch after this pipeline overview).
  3. Offline RL Training (SWiRL vs. B-SWiRL)

    • Model: 4-bit quantized Mistral 7B with LoRA adapters; only adapter weights are trained.

    • SWiRL (single action per state):

      • Dataset: one (state, action, reward) per state.

      • Loss:

        $$ L_{\text{single}} = -\mathbb{E}\big[R \cdot \log \pi_\theta(a \mid s)\big] + \beta \, \mathrm{KL}\big(\pi_\theta(\cdot \mid s) \,\|\, \pi_{\text{base}}(\cdot \mid s)\big) $$
    • B-SWiRL (branched):

      • Group data by (question_id, step) → one group = one state with multiple actions.

      • For each state group:

        • Compute the per-state baseline $b_i = \frac{1}{B_i} \sum_j R_i^{(j)}$, where $B_i$ is the number of sampled actions at state $i$.
        • Advantage: $A_i^{(j)} = R_i^{(j)} - b_i$
      • Loss:

        $$ L_{\text{branch}} = -\mathbb{E}\big[ A \cdot \log \pi_\theta(a \mid s) \big] + \beta \, \mathrm{KL}\big(\pi_\theta(\cdot \mid s) \,\|\, \pi_{\text{base}}(\cdot \mid s)\big) $$
      • This encourages increasing probability of above-average branches and decreasing probability of below-average ones at the same state.
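To make stage 1 concrete, here is a minimal sketch of how a parsed solution could be turned into (state, action) records with backbone and branch annotations. The function and field helpers (parse_steps, make_records, add_branches) are illustrative names, not the repository's actual code; only the stored fields (question_id, step, state, action, branch_id, is_backbone) come from the description above.

```python
import re
from typing import Dict, List

def parse_steps(solution: str) -> List[str]:
    """Split a 'Step 1: ... Step 2: ... Answer: ...' solution into step strings."""
    # Split before each 'Step N:' or 'Answer:' marker, keeping the marker with its text.
    parts = re.split(r"(?=Step \d+:|Answer:)", solution)
    return [p.strip() for p in parts if p.strip()]

def make_records(question_id: str, question: str, solution: str) -> List[Dict]:
    """Turn one backbone solution into (state, action) records (branch_id = 0)."""
    steps = parse_steps(solution)
    records = []
    for i, step in enumerate(steps):
        state = "\n".join([question] + steps[:i])  # problem + all previous steps
        records.append({
            "question_id": question_id,
            "step": i,
            "state": state,
            "action": step,        # current step text
            "branch_id": 0,
            "is_backbone": True,
        })
    return records

def add_branches(record: Dict, alternatives: List[str]) -> List[Dict]:
    """Branched alternatives share (question_id, step, state) but differ in action."""
    return [
        {**record, "action": alt, "branch_id": b + 1, "is_backbone": False}
        for b, alt in enumerate(alternatives)
    ]
```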
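A sketch of the stage-2 reward logic, assuming a judge_step callable that wraps the frozen Qwen 2.5 72B judge and returns a binary 0/1 verdict; the gold-answer comparison is simplified to a numeric match, and all names here are hypothetical.

```python
import re
from typing import Dict, Optional

def extract_final_number(text: str) -> Optional[str]:
    """Pull the last number out of an answer string (GSM8K answers are numeric)."""
    nums = re.findall(r"-?\d+\.?\d*", text.replace(",", ""))
    return nums[-1] if nums else None

def assign_reward(record: Dict, gold_answer: str, judge_step) -> int:
    """Binary reward for one (state, action) pair."""
    if record["action"].startswith("Answer:"):
        # Final step: bypass the judge and compare directly to the GSM8K gold answer.
        return int(extract_final_number(record["action"]) == extract_final_number(gold_answer))
    # Intermediate step: ask the frozen LLM judge for a local 0/1 verdict
    # (correctness, consistency with history, progress toward a solution).
    return judge_step(state=record["state"], action=record["action"])
```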
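The branched objective above can be sketched for one state group as follows. This assumes per-sequence log-probabilities of each sampled action have already been computed under the trainable policy and the frozen base model; the beta value and the simple per-sample KL proxy are placeholders, not the repository's exact training code.

```python
from collections import defaultdict
import torch

def group_by_state(records):
    """Group reward-annotated records into state groups keyed by (question_id, step)."""
    groups = defaultdict(list)
    for r in records:
        groups[(r["question_id"], r["step"])].append(r)
    return groups

def branch_loss(group, logp_policy, logp_base, beta=0.05):
    """B-SWiRL loss for one state with B sampled actions.

    logp_policy / logp_base: shape-(B,) tensors of log pi(a|s) for each branch
    under the trainable (LoRA) policy and the frozen base model.
    """
    rewards = torch.tensor([r["reward"] for r in group], dtype=torch.float32)
    baseline = rewards.mean()              # b_i = mean reward over branches at this state
    advantage = rewards - baseline         # A_i^(j) = R_i^(j) - b_i
    pg_term = -(advantage * logp_policy).mean()
    # Cheap per-sample proxy for KL(pi_theta(.|s) || pi_base(.|s)) on the taken actions.
    kl_term = (logp_policy - logp_base).mean()
    return pg_term + beta * kl_term
```

The SWiRL baseline corresponds to the same update with a single action per state and the raw reward R used in place of the advantage.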

We also track how training changes the model’s log-probabilities on a held-out set of steps, separately for positive and negative rewards, to verify that the policy is actually following the reward signal.
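A sketch of that diagnostic, assuming a sequence_logprob helper that scores an action under the current model: mean log-probabilities are tracked separately for positively and negatively rewarded held-out steps, so training should push the first group up relative to the second.

```python
def heldout_logprob_stats(model, tokenizer, heldout_records, sequence_logprob):
    """Mean log pi_theta(action | state) on held-out steps, split by reward sign."""
    pos, neg = [], []
    for r in heldout_records:
        lp = sequence_logprob(model, tokenizer, state=r["state"], action=r["action"])
        (pos if r["reward"] > 0 else neg).append(lp)
    return {
        "mean_logp_positive": sum(pos) / max(len(pos), 1),
        "mean_logp_negative": sum(neg) / max(len(neg), 1),
    }
```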
