Skip to content

Latest commit

 

History

History
196 lines (115 loc) · 10 KB

File metadata and controls

196 lines (115 loc) · 10 KB

PPO(Proximal Policy Optimization)学习笔记

1. On-Policy 与 Off-Policy

1.1 On-Policy(朴素方法)

在朴素方法中,我们使用策略 $\pi_\theta$ 和环境交互若干次,得到一批回合数据,然后用这批数据计算出来的奖励值去更新 $\pi_\theta$

特点: 产出数据的策略和用这批数据做更新的策略是同一个

1.2 Off-Policy(PPO 的做法)

为了降低采样成本、提升训练效率,同时更加"谨慎"地更新模型,PPO 希望把收集到的一批数据重复使用 k 次

  1. 假设某次更新完毕后,我们得到策略 $\pi_{old}$
  2. $\pi_{old}$ 和环境交互,得到一批回合数据
  3. 将这一批回合数据重复使用 k 次:
    • 先喂给 $\pi_{old}$,更新得到 $\pi_{\theta_0}$
    • 再喂给 $\pi_{\theta_0}$,更新得到 $\pi_{\theta_1}$
    • 以此类推,做 k 次更新后得到 $\pi_\theta$
  4. 在这 k 次更新后,令 $\pi_{old} = \pi_\theta$,重复上面的过程,直到达到设定的停止条件

特点: 产出数据的策略($\pi_{old}$)和用这批数据做更新的策略($\pi_\theta$)不是同一个,这就是 off-policy。

核心问题: 数据是 $\pi_{old}$ 产出的,但我们要更新的是 $\pi_\theta$。随着更新,$\pi_\theta$ 和 $\pi_{old}$ 越来越不同,直接用这批数据就不准确了。怎么办?——重要性采样


2. 重要性采样(Importance Sampling)

2.1 问题定义

  • 目标: 计算 $E_{x \sim p(x)}[f(x)]$,即从分布 $p(x)$ 中采样,求函数 $f(x)$ 的期望
  • 困境: 无法从 $p(x)$ 中直接采样,只能从另一个分布 $q(x)$ 中采样

2.2 数学推导

第一步: 期望的定义展开

$$E_{x \sim p(x)}[f(x)] = \int p(x) f(x) , dx$$

这里 $p(x)$ 是概率密度函数,$f(x)$ 是我们关心的函数。对于连续分布,期望就是"函数值 × 概率密度"在整个定义域上的积分。

第二步: 分子分母同乘 $q(x)$(关键一步)

$$= \int \frac{p(x)}{q(x)} q(x) f(x) , dx$$

这一步纯粹是数学恒等变换:在被积函数中乘以 $\frac{q(x)}{q(x)} = 1$,不改变积分值,只是把形式重组了。

第三步: 识别出新的期望形式

$$= E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right]$$

为什么 $\int \frac{p(x)}{q(x)} q(x) f(x) , dx$ 可以写成 $E_{x \sim q(x)}[\cdots]$

回顾期望的定义:对于任意分布 $q(x)$ 和任意函数 $g(x)$

$$E_{x \sim q(x)}[g(x)] = \int q(x) \cdot g(x) , dx$$

现在令 $g(x) = \frac{p(x)}{q(x)} f(x)$,代入:

$$E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right] = \int q(x) \cdot \frac{p(x)}{q(x)} f(x) , dx$$

这和第二步的积分形式完全一致。所以第二步到第三步只是"认出"了积分的结构符合 $q(x)$ 下的期望定义。

2.3 直觉理解

重要性采样本质上是一个加权修正

  • $q(x)$ 中采了一个样本 $x$
  • 如果 $p(x)$ 在这一点的概率比 $q(x)$ (即 $q(x)$ 对这个点采样不够),权重 $\frac{p(x)}{q(x)} > 1$放大这个样本的贡献
  • 如果 $p(x)$ 在这一点的概率比 $q(x)$ (即 $q(x)$ 对这个点过度采样了),权重 $\frac{p(x)}{q(x)} < 1$缩小这个样本的贡献

通过这个权重,从 $q(x)$ 采样得到的结果就能无偏地估计 $p(x)$ 下的期望。

2.4 $q(x)$ 是什么?

$q(x)$ 是一个任意的概率分布,唯一要求是它的支撑集(support)要覆盖 $p(x)$ 的支撑集(即 $p(x) > 0$ 的地方 $q(x)$ 也必须 $> 0$,否则 $\frac{p(x)}{q(x)}$ 无定义)。

关键在于:$q(x)$ 不是凭空出现的,它是你实际能采样的那个分布。在不同场景中 $q(x)$ 对应不同的东西,在 PPO 中,$q(x)$ 就是旧策略 $\pi_{old}$


3. 以 LLM RLHF 为例理解重要性采样

3.1 场景设定

假设我们有一个经过 SFT(Supervised Fine-Tuning)训练好的大语言模型,现在要用 PPO 进行 RLHF(Reinforcement Learning from Human Feedback)训练。

3.2 各符号的含义

在 LLM 的 PPO 训练中:

数学符号 LLM RLHF 中的含义
状态 $s$ 当前的 prompt + 已经生成的 token 序列
动作 $a$ 模型下一步要生成的那个 token
$\pi(a \mid s)$ 模型在给定上下文(prompt + 已生成 token)下,选择某个 token 的概率
$\pi_{old}$ 旧策略——上一轮更新后"冻结"的模型权重
$\pi_\theta$ 当前策略——正在被优化更新的模型权重
$f(x)$ 优势函数 $A(s,a)$——评价当前动作相对于平均水平有多好

3.3 具体流程

采样阶段:$\pi_{old}$(冻结的旧模型)对一批 prompt 进行推理,生成若干条回复。对每条回复用 Reward Model 打分。此时每个 token 的生成概率 $\pi_{old}(a_t \mid s_t)$ 被记录下来。

更新阶段: 用同一批 prompt 和生成的回复去更新 $\pi_\theta$(当前模型)。但 $\pi_\theta$ 的权重已经和 $\pi_{old}$ 不同了,所以同一个 token 在同样上下文下的概率变了。为了修正这个偏差,我们计算重要性采样比率:

$$r(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{old}(a_t \mid s_t)}$$

3.4 如何获取 $\pi_{old}$$\pi_\theta$ 的概率

对于 LLM,模型最后一层的输出经过 softmax 后得到词表上的概率分布。所以:

  1. $\pi_{old}(a_t \mid s_t)$:在采样阶段,旧模型推理时已经计算过每个 token 的 logits,对其做 softmax 后取对应 token 的概率值,记录保存下来
  2. $\pi_\theta(a_t \mid s_t)$:在更新阶段,把同样的输入喂给当前模型做一次前向传播,得到新的 logits → softmax → 取对应 token 的概率值

实际实现中通常保存和比较的是 log 概率(log-probabilities),因为:

  • 数值更稳定(避免极小概率值的精度问题)
  • 比率的计算变成减法:$\log r(\theta) = \log \pi_\theta(a_t \mid s_t) - \log \pi_{old}(a_t \mid s_t)$
  • 最后 $r(\theta) = \exp(\log r(\theta))$

4. PPO 的 Clip 机制

4.1 为什么需要 Clip

重要性采样虽然在数学上是无偏的,但有一个实际问题:方差可能很大

$\pi_\theta$$\pi_{old}$ 差距太大时,比率 $r(\theta) = \frac{\pi_\theta(a|s)}{\pi_{old}(a|s)}$ 会变得非常大或非常小,导致梯度不稳定,训练可能崩溃。

4.2 PPO 的解决方案:截断目标函数

PPO 定义了一个截断(clipped)的目标函数:

$$L^{CLIP}(\theta) = E\left[\min\left(r(\theta) A(s,a), ; \text{clip}(r(\theta), 1-\epsilon, 1+\epsilon) A(s,a)\right)\right]$$

其中 $\epsilon$ 是一个超参数,通常取 $0.1 \sim 0.2$

clip 函数的作用是把 $r(\theta)$ 限制在 $[1-\epsilon, 1+\epsilon]$ 范围内:

$$\text{clip}(r, 1-\epsilon, 1+\epsilon) = \begin{cases} 1-\epsilon & \text{if } r < 1-\epsilon \ r & \text{if } 1-\epsilon \le r \le 1+\epsilon \ 1+\epsilon & \text{if } r > 1+\epsilon \end{cases}$$

4.3 分情况理解 min 和 clip 的配合

$\min$ 取两项中较小的那个,其效果取决于优势函数 $A(s,a)$ 的正负:

情况一:$A(s,a) > 0$(这个动作是好的,我们想鼓励它)

梯度会推动 $r(\theta)$ 增大(让新策略更倾向于选择这个动作)。但如果 $r(\theta)$ 已经超过了 $1+\epsilon$,clip 那一项会被截断为 $(1+\epsilon) \cdot A$,比未截断的 $r(\theta) \cdot A$ 更小,min 会选择截断项,梯度为零,阻止进一步增大

$$\min\bigl(\underbrace{r(\theta) \cdot A}_{\text{想继续增大}}, ; \underbrace{(1+\epsilon) \cdot A}_{\text{封顶}}\bigr) = (1+\epsilon) \cdot A$$

效果: "这个动作虽然好,但你不能一步走太远。"

情况二:$A(s,a) < 0$(这个动作是坏的,我们想抑制它)

梯度会推动 $r(\theta)$ 减小(让新策略更不倾向于选择这个动作)。但如果 $r(\theta)$ 已经低于 $1-\epsilon$,clip 那一项为 $(1-\epsilon) \cdot A$。由于 $A &lt; 0$,$(1-\epsilon) \cdot A$ 比 $r(\theta) \cdot A$(更接近 0),min 会选择未截断项 $r(\theta) \cdot A$... 不对,让我们仔细算:

$r(\theta) &lt; 1-\epsilon$$A &lt; 0$

  • 未截断项:$r(\theta) \cdot A$($r$ 小,$A$ 负,绝对值大,这是个很负的数)
  • 截断项:$(1-\epsilon) \cdot A$($(1-\epsilon) > r(\theta)$,$A$ 负,绝对值较小,没那么负)
  • $\min$ 取更小的 = 未截断项?不对,应该取截断项。

实际上:$(1-\epsilon) > r(\theta)$,乘以负数 $A$ 后,$(1-\epsilon) A < r(\theta) A$,所以 $\min$ 选截断项,梯度同样被截断。

效果: "这个动作虽然坏,但也不要一步惩罚太多。"

4.4 总结 Clip 的作用

情况 $r(\theta)$ 范围 clip 的效果
$A &gt; 0$,$r > 1+\epsilon$ 新策略已经太倾向这个动作了 截断,阻止继续增大
$A &gt; 0$,$r \in [1-\epsilon, 1+\epsilon]$ 变化在合理范围 不截断,正常更新
$A &lt; 0$,$r < 1-\epsilon$ 新策略已经太不倾向这个动作了 截断,阻止继续减小
$A &lt; 0$,$r \in [1-\epsilon, 1+\epsilon]$ 变化在合理范围 不截断,正常更新

核心思想就一句话:允许策略改进,但每一步不能走太远,确保新策略始终在旧策略的"信任域"(Trust Region)附近。这也是 PPO 名字中 "Proximal"(近端)的含义。


5. PPO 完整目标函数

PPO 的完整目标函数通常还包含另外两项:

$$L(\theta) = L^{CLIP}(\theta) - c_1 \cdot L^{VF}(\theta) + c_2 \cdot H[\pi_\theta]$$

含义
$L^{CLIP}(\theta)$ 截断的策略梯度目标(主项)
$L^{VF}(\theta)$ 价值函数的均方误差损失(训练 Critic)
$H[\pi_\theta]$ 策略的熵(鼓励探索,防止策略过早收敛到某个动作)
$c_1, c_2$ 平衡系数

在 LLM RLHF 场景中,还常常额外加入 KL 散度惩罚项(参见 kl_divergence.md),防止模型偏离 SFT 基线太远。