TPO:把"该学什么"和"怎么学"拆开,梯度自己知道什么时候停

核心摘要

策略梯度方法有个老毛病:reward信号和参数更新死死绑在一起,学习率大了overshoot,小了undershoot,稀疏奖励下更是直接趴窝。TPO的思路很清爽——先算出"理想的概率分布长什么样",再用交叉熵让策略去拟合这个分布。梯度是 \(p^\theta - q\),策略跟目标对齐了,梯度自己就归零了,不用靠裁剪硬拦。在简单任务上跟PG/PPO/GRPO/DG差不多,稀疏奖励下直接拉开差距——H=10的token反转任务,TPO错误率7.4%,GRPO 50.4%,PPO直接挂掉。代码才几行,JAX版核心逻辑就一行softmax。


论文信息

  • 标题:Target Policy Optimization
  • 作者:Jean Kaddour(University College London)
  • 日期:2026年4月7日
  • 链接:https://arxiv.org/abs/2604.06159
  • 代码:https://github.com/JeanKaddour/tpo

从一个让人头疼的问题说起

你有没有碰到过这种情况:用GRPO训练模型做数学推理,reward曲线蹭蹭涨,结果生成质量肉眼可见在变差?或者更邪门——任务本身只有"做对了给1分,做错0分"这种稀疏奖励,模型学了半天跟随机猜一样?

这不是你的调参水平问题,是策略梯度方法的结构性缺陷。

传统的策略梯度(PG)家族——从REINFORCE到PPO到GRPO——都把两个问题搅在一起回答:"哪些补全应该涨概率"和"参数往哪挪"。你拿到一组采样、一组分数,乘起来当梯度用。但这个梯度的方向和大小,全靠学习率、裁剪阈值这些超参来控制,调好了能work,调不好就overshoot或undershoot。

DeepSeek搞出GRPO,用组内相对优势替代value function,省了一个critic网络,算是工程上的一大步。但你仔细看GRPO的梯度——它没有不动点。策略已经收敛了,梯度还在那里push,停不下来。多epoch复用数据的时候更明显,2个epoch可能比4个epoch还差,这种非单调性让人调参调到怀疑人生。

TPO换了个思路:别把"该学什么"和"怎么学"绑在一起。


TPO的核心思路:先定目标,再拟合

一句话讲清楚

给采样的一组补全打分,算出一个"目标分布"\(q\),然后用交叉熵让策略\(\pi_\theta\)去拟合\(q\)。梯度是 \(p^\theta - q\),匹配了就停。

具体怎么算

设上下文 \(x\),采样 \(K\) 个候选 \(y_1, \dots, y_K \sim \pi_{\text{old}}(\cdot|x)\),用评分器 \(S\) 打分 \(s_i = S(x, y_i)\),在组内标准化为 \(u_i\)

策略在组上的分布:

\[p_i^\theta = \frac{\exp(\ell_i^\theta)}{\sum_{j=1}^K \exp(\ell_j^\theta)}\]

其中 \(\ell_i^\theta = \log \pi_\theta(y_i|x)\)

目标分布:

\[q_i = \frac{p_i^{\text{old}} \exp(u_i / \eta)}{\sum_{j=1}^K p_j^{\text{old}} \exp(u_j / \eta)}\]

\(\eta > 0\) 是温度(默认1),\(p_i^{\text{old}}\) 是行为策略的冻结概率。注意这里 \(p^{\text{old}}\) 起到了"锚点"的作用——没有它,目标分布就退化成纯按分数排,容易走极端。

交叉熵损失:

\[\mathcal{L}_{\text{TPO}}(\theta) = -\sum_{i=1}^K q_i \log p_i^\theta\]

关键梯度性质:

\[\frac{\partial \mathcal{L}}{\partial \ell_i^\theta} = p_i^\theta - q_i\]

你想想看,梯度方向就是"当前概率减去目标概率"。高了就往下压,低了就往上拉,刚好对齐就为零。这比PPO的clipping优雅多了——clipping是硬截断,TPO是梯度天然消失。

KL正则化的等价解释

目标分布 \(q\) 其实是下面这个优化问题的唯一解:

\[q = \arg\max_{r \in \Delta^{K-1}} \left\{ \sum_{i=1}^K r_i u_i - \eta \cdot \text{KL}(r \| p^{\text{old}}) \right\}\]

在"最大化奖励"和"别离旧策略太远"之间做了KL正则化的平衡。说到这个,DeepSeek之前用GRPO的时候也加了反向KL惩罚来防mode collapse,但那是外部加的约束。TPO把这个平衡直接内嵌到目标分布的构造里了,\(\eta\) 就是控制"要不要激进"的唯一旋钮。

代码才几行

JAX版核心逻辑:

def tpo_target(log_scores, u, eta=1.0):
    return jax.nn.softmax(
        jax.nn.log_softmax(log_scores, -1) + u / eta, -1)

q = jax.lax.stop_gradient(tpo_target(log_scores, u))
log_p = jax.nn.log_softmax(log_scores, -1)
loss = -(q * log_p).sum(-1).mean()

log_softmax(log_scores) 就是 \(\log p^{\text{old}}\),加上 \(u/\eta\) 再过一遍softmax就得到 \(q\)。然后 \(q\) detach掉,用交叉熵拟合。说实话,看到这个实现的时候我愣了一下——就这么简单?

图1:TPO与基线方法对比

图1:TPO在简单任务上匹配基线,在稀疏奖励下大幅超越。左边是MNIST上下文赌博机(密集奖励),TPO略快于GRPO和DG;右边是token反转任务(终端奖励),GRPO和DG在随机水平停滞,TPO直接学会

图2:代码实现示意

图2:TPO的on-policy实现只需几行代码。log_scores是策略的log概率,u是标准化分数,eta是温度。同一个log_scores既用来构造q也用来算log_p,q在更新前detach


方法对比:TPO到底新在哪

先放一张全景对比表,看清楚TPO在策略优化家族里的位置:

方法 更新规则 组内比较 需要Critic 需要固定参考
REINFORCE PG + baseline
PPO Clipped PG surrogate
GRPO Clipped PG + group adv.
RLOO PG + leave-one-out baseline
REBEL Sq. loss on reward diffs
MPO/V-MPO q∝π_old·exp(signal/η); fit π→q
TPO q∝p_old·exp(u); CE to q

最该注意的是最后一行和MPO那行。MPO是DeepMind 2018年提出的,思路一样——构造目标分布再拟合。但MPO需要value function来估计优势,而TPO利用了LLM RLVR场景的特殊性:同一prompt采K个补全,组内分数天然就能当相对优势用,不需要critic。

这是TPO真正聪明的地方:不是发明了新算法,是把MPO的思路适配到了"K个候选同prompt"这个特定场景,砍掉了critic这个大包袱。


实验:简单任务差不多,稀疏奖励拉开差距

密集奖励:大家都能学会

MNIST上下文赌博机,dense reward,TPO收敛稍快(1,600步达5%误差 vs DG的2,200步),最终大家都到差不多水平——TPO 2.9%,GRPO约3-4%,DG约4%。这个结果不意外——密集奖励下信用分配不难,每个动作都有反馈,所有方法都能工作。

Token级别密集奖励的小规模Transformer实验更有意思。作者测了4种目标逻辑(copy, flip, reverse copy, reverse flip)× 2种奖励结构(bag-of-tokens, sequential)= 8个变体:

表2:达到1%误差的步数(K=8 token候选,H=10, V=2)

奖励 目标 TPO_token GRPO_token DG PPO
Bag of tokens Copy 81 338 219 170
Bag of tokens Flip 56 104 201 146
Bag of tokens Rev. copy 55 352 202
Bag of tokens Rev. flip 59 209 200 143
Sequential Copy 295 439
Sequential Flip 321 349
Sequential Rev. copy 159 515
Sequential Rev. flip 276 309

"−"表示在预算内没达到1%误差。注意sequential奖励那几行——GRPO和PPO直接没学会,TPO虽然慢一点但稳定收敛。这说明TPO不只是"稀疏奖励下更强",而是对奖励结构的适应性更广。bag-of-tokens奖励下GRPO勉强能学,但一换到sequential就歇菜了。

还有个有趣的发现:词汇表大小V的影响。V=2时TPO_token只需58步达1%误差,GRPO_token需要904步——差了15倍。V变大后差距缩小,但TPO始终最快。

稀疏奖励:TPO的主场

这才是重点。当奖励只在序列结束时给一次"做对了没":

表1:精确匹配错误率%,终端奖励,token反转任务

方法 H=7 H=8 H=9 H=10
TPO 6.9 8.6 6.1 7.4
GRPO 14.5 27.6 30.0 50.4
GRPO (no KL) 66.6 92.5
PPO 12.0 26.3 90.6
DG 33.8 58.8

H=10时TPO 7.4% vs GRPO 50.4%,差了近7倍。没有KL惩罚的GRPO直接崩到92.5%,说明GRPO在这个场景下极度依赖外部KL约束来保命。PPO在H=9就挂了。DG更惨,H=8就58.8%。

图3:稀疏奖励下的学习曲线

图3:终端奖励下各方法的学习曲线对比,序列长度从H=7到H=10。TPO在所有长度上稳定收敛,其他方法随序列增长急剧退化

说实话,看到GRPO (no KL)的数字时我确实有点惊讶——66.6%到92.5%,这已经不是"效果差"了,是模型在主动学坏。这也印证了GRPO的核心问题:没有不动点的梯度一直在推,推到mode collapse。

LLM RLVR:十亿参数模型

用verl框架,Qwen3-1.7B和DeepSeek-R1-Distill-Qwen-1.5B,K=16:

  • GSM8K:TPO学得更快(比GRPO早约10步达50%准确率),但最终两者收敛相近(约85-87%)。密集奖励嘛,大家都行。
  • Reasoning Gym图着色:GRPO在Qwen3-1.7B上完全失败,300步近零分;TPO达到约0.96。这才是TPO的真正价值——图着色是典型的稀疏奖励任务,做对就是做对,没有半对。
  • Knights & Knaves:同样的模式,TPO优势更明显。

图4:LLM RLVR实验结果

图4:LLM RLVR实验。上排Qwen3-1.7B,下排DeepSeek-R1-Distill-Qwen-1.5B。三列分别是GSM8K、图着色、Knights & Knaves。图着色任务上GRPO在Qwen3上完全失败,TPO稳定收敛到约0.96


为什么TPO在稀疏奖励下这么强?

论文给了三个解释,每一个都有实验支撑,不是空谈。

1. 梯度自消失

TPO的梯度在 \(p^\theta = q\) 时为零,这是个不动点。GRPO没有这个性质——策略已经收敛了,梯度还在。

图5:梯度范数对比

图5:TPO的梯度自消失而GRPO不会。左图是训练过程中梯度L2范数,TPO随收敛快速衰减至近零,GRPO全程维持较高梯度范数。右图是成功/失败候选上的权重分配

这有多重要?在稀疏奖励下,大部分采样组都是全失败的(零方差→\(u=0\)\(q=p^{\text{old}}\)→损失贡献为零)。GRPO对这些全失败组还在持续输出梯度信号,推着模型瞎动。TPO直接忽略它们,把更新集中在那少数几个"有信息"的组上。

2. 信号集中在有信息组

当K=32且序列成功率约0.4%时,约90%的组全部失败。TPO对这些组天然免疫(\(q=p^{\text{old}}\),损失为零),自动把计算资源集中在有成功的组上。

这个机制其实很直觉:一组采样里要是没有一个人做对,标准化后的 \(u\) 就是零向量(因为方差为零),目标分布 \(q\) 就等于 \(p^{\text{old}}\),交叉熵损失对这组的梯度贡献为零。GRPO就不一样了——即使全失败,GRPO的组内优势计算照样会给负梯度,推着模型远离这些失败样本。听起来也对?问题是,"远离失败"和"靠近成功"不是一回事。远离失败可能推到完全随机的方向,尤其是在稀疏奖励下失败的原因千差万别。

组大小消融:

K TPO错误率 GRPO错误率
4 8.9% 19.4%
8 5.2% 19.8%
16 5.1% 9.2%
32 2.6% 4.4%
64 0.36% 5.6%

K=64时TPO 0.36% vs GRPO 5.6%。K越大,稀疏奖励下"组里至少有一个成功"的概率越高,TPO就越能从这些稀有成功信号中受益。而GRPO在K=64反而比K=32差了——更多全失败组带来了更多噪声梯度。

有个细节让我印象深刻:论文试了显式掩码零方差组来"帮"GRPO,结果反而有害(错误率从6.3%升至29.7%)。原因是在多epoch训练中,这些全失败组虽然当前没有新信息,但提供了"锚点回拉"的效果——防止策略在已经学会的组上走太远。TPO不需要这种外部帮助,它天然就能处理。

3. 多epoch稳定复用

Epoch数量消融:

Epoch数 TPO最终错误率 GRPO最终错误率
1 0.02% 4.3%
2 37.6%
4 0.05% 6.3%
8 \lt2.3% 3.3%
16 \lt2.3% 1.1%

TPO在1到16个epoch范围内全稳,均低于2.3%。GRPO在2个epoch时崩到37.6%,4个epoch才恢复到6.3%——这种非单调性简直让人崩溃。你在调参的时候怎么知道该用几个epoch?只能一个一个试。

TPO为什么稳定?因为冻结的 \(q\) 提供了一个固定的吸引子。不管你跑几个epoch,目标不变,策略朝着同一个方向走,不会来回震荡。GRPO的梯度方向在epoch间会变——上一步推过去的,下一步可能推回来。

这让我想到一个类比:TPO就像GPS导航,目的地(目标分布)设定好了,你只管往那个方向开;GRPO像是指南针,告诉你哪个方向"好一点",但具体走到哪算好,没有明确定义,走过了也不会告诉你停。


消融实验:锚点和目标匹配都不可少

论文做了三个消融:

  • TPO:完整方法
  • TPO-no-anchor:去掉 \(p^{\text{old}}\) 锚点,目标变成 \(q_i \propto \exp(u_i)\)
  • Group PG:保留候选和标准化分数,但用标量加权PG替代目标匹配

H=10时,完整TPO达7.4%,每个消融都超过99%。移除锚点始终有害;目标匹配本身也关键(Group PG表现最差)。

这个消融结果挺有说服力的。锚点 \(p^{\text{old}}\) 的作用不是装饰——没有它,目标分布就变成纯exp分数,极端分数下会把所有概率集中到一个候选上,策略一步到位然后锁死。目标匹配(用交叉熵而非标量加权)同样关键:它给了梯度消失的性质,而标量加权的PG永远在推。

图6:消融实验对比

图6:锚点和目标匹配的消融实验。完整TPO在每个序列长度上都优于所有消融变体,H=10时差距尤其明显


温度鲁棒性

\(\eta\) 从0.25到2.0都能快速收敛,只有4.0时才明显退化。默认值1.0就是个好选择,不需要怎么调。这一点对实际工程来说很重要——你不想每换一个任务就要重新调温度。

η 最终错误率% 达到1%步数
0.25 1.0 72
0.50 0.0 67
1.00 0.7 96
2.00 1.0 141
4.00 0.8 260

我的判断

亮点

  1. 思路干净:把"该学什么"和"怎么学"分开,这个拆解本身就有价值。不是靠工程trick堆出来的,是一个有数学保证的设计。
  2. 梯度消失性质\(p^\theta = q\) 时梯度为零,这不是近似,是精确的。比clipping那种硬截断优雅太多。
  3. 代码极简:JAX核心就一行softmax,PyTorch也差不多。从GRPO迁移到TPO的代码改动极小。
  4. 稀疏奖励下的优势是实打实的:不是在某个小数据集上刷了0.1个点,是7倍以上的差距。

问题和局限

  1. LLM实验规模偏小:1.5-1.7B参数的模型离实际部署差距不小。7B以上的模型、更难的benchmark(MATH, AIME)还没测。论文自己也承认了这一点。说实话,在小模型上的优势能不能scale up,还需要验证。
  2. K个采样的成本没变:TPO跟GRPO一样需要K个rollout/prompt。它解决的是"给定K个采样怎么更新更好",不是"怎么减少采样"。
  3. 跟DAPO等更近期的工程方法缺对比:DAPO是字节2025年提出的方法,在AIME 2024上刷了50分。TPO只跟了GRPO、PPO、DG对比,没跟DAPO比。可能是因为DAPO更偏工程层面的改进(动态采样、解耦裁剪),但读者肯定想知道TPO+DAPO能不能叠加。
  4. 标准化分数的放大效应:当组内方差极小时,标准化会放大小数值差异,跟GRPO的difficulty bias是同一个问题。论文提到了但没给解决方案。

跟MPO的关系

TPO跟DeepMind 2018年的MPO/V-MPO思路几乎一样——构造目标分布再拟合。作者在论文里也明确说了这一点。区别在于MPO需要value function,TPO利用了K候选同prompt的结构省掉了critic。这不算从零到一的突破,但确实是一个聪明且实用的适配。

再说一个背景:MPO在连续控制领域已经被验证过了,DeepMind的Alpha系列工作大量使用MPO族方法。TPO的贡献不是发明了"目标分布+交叉熵拟合"这个范式,而是证明了这个范式在LLM RLVR场景下——特别是稀疏奖励下——比当前主流的GRPO/PPO更合适。这其实是一个挺好的信号:有时候最有效的改进不是发明新东西,而是把已知的好方法搬到对的地方。

工程启发

如果你正在做LLM的RLVR训练,特别是在稀疏奖励场景(数学推理、代码生成、逻辑推理),TPO值得一试。迁移成本极低——把GRPO的loss换成TPO的loss,就几行代码的事。\(\eta=1\) 作为默认值大概率够用,不需要额外调参。多epoch复用数据也更稳定,不用纠结"到底用几个epoch"这种玄学问题。

但如果你做的是密集奖励场景(比如已经有过程奖励模型PRM),TPO的优势可能没那么明显。GSM8K上TPO和GRPO最终收敛就差不多。


总结

TPO用一个干净的设计选择替代了标量加权策略梯度:在评分候选集上构建目标分布,然后通过交叉熵拟合策略。三个关键性质——梯度自消失、信号集中在有信息组、多epoch稳定复用——解释了它在稀疏奖励下的强势表现。

这不是一个推翻一切的范式转移,更像是一个"就该这么做"的修正。策略梯度方法把评分和更新绑在一起是历史包袱,TPO把它拆开了,效果立竿见影。代码开源,几行就能用,对实际工程的迁移成本几乎为零。

唯一的悬念是大模型上的验证。1.5B参数的实验说服力有限,7B+和更难benchmark的结果才是真正的试金石。

觉得有启发的话,欢迎点赞、在看、转发。跟进最新AI前沿,关注我