UniSD:不靠"更强的老师",LLM 能不能自己教自己变强?

核心摘要

最近做 post-training,我和组里同事吵过一个问题:到底什么时候应该用 self-distillation,什么时候必须找一个更强的 teacher model? 之前的直觉是——self-distillation 是穷人方案,能不用就别用。但实际跑下来你会发现,找不到更强的 teacher 是常态:商用 API 不让蒸馏、policy 不允许、或者你的 domain 太特殊根本没有现成 teacher。

这篇 UniSD 把这个问题正经掰开了。作者不是提出新的 SD 方法,而是把过去散落在不同论文里的五种机制(多教师一致性、EMA teacher、token 级对比学习、特征匹配、散度裁剪)放在一个统一框架下做控制变量分析。在 6 个 benchmark、3 个模型家族的 6 个模型上跑,最终拼出来的 UniSD* 比 base model 提升 +5.4 分,比最强 baseline 提升 +2.8 分——完全不依赖外部更强 teacher。

说实话这种"系统化研究 + 模块化框架"的论文我一开始是有点警惕的——容易写成大杂烩。但读完发现它的贡献蛮扎实的:第一次把 SD 拆解成"监督可靠性 / 表征对齐 / 训练稳定性"三轴,并且每个组件什么时候有用、互相之间怎么干扰,给了明确的答案。


论文信息

  • 标题:UniSD: Towards a Unified Self-Distillation Framework for Large Language Models
  • 作者:Yiqiao Jin, Yiyang Wang, Lucheng Fu, Yijia Xiao, Yinyi Luo, Haoxin Liu, B. Aditya Prakash, Josiah Hester, Jindong Wang, Srijan Kumar
  • 机构:Georgia Tech、UCLA、CMU、William & Mary
  • arXiv:https://arxiv.org/abs/2605.06597

为什么 Self-Distillation 在 LLM 上这么难?

我先讲讲 self-distillation 这个概念的常见误解。

很多人一听 self-distillation,脑子里立刻冒出来的画面是 vision 里面那种"小网络蒸大网络"。但 LLM 里的 SD 不是这么回事——它指的是模型用自己生成的轨迹做监督来训练自己。听起来有点像"自己抓自己头发上天",对吧?我第一次听到也是这反应。

但仔细想想,这个 idea 是有道理的。如果模型对某个 prompt 能生成一个合理的 CoT 推理过程,那这个推理过程本身就可以做训练数据。问题是——怎么知道哪些 self-generated rationale 是可信的? 模型可能对一个错误答案生成一段听起来很有道理但是错的推理,你拿这个去训练自己,那就是把错的内化得更深。

这就是 LLM 里 SD 的核心难点。论文里把它拆成三块:

第一Open-Ended Generation。一个 prompt 可能有多个合理的答案、多条合理的推理路径。这跟 vision 不一样——cat 就是 cat,给个 label 完事。LLM 没有"标准答案",连"对"的定义都是 task-dependent 的。

第二监督信号不稳定。On-policy rollout 会让模型暴露在自己的错误上:你今天生成一个错的推理,拿它训练,明天可能更倾向于生成类似的错。这种"自我强化"在 RL 里叫 mode collapse,在 SD 里同样存在。

第三缺乏系统化理解。过去的 SD 论文各搞各的——有人加 EMA teacher、有人做 contrastive learning、有人玩 feature matching。但它们组合在一起到底有没有协同效应?哪些是冗余的?哪个 task 应该用哪些组件?没人说清楚。

这就是 UniSD 想做的事。

图1:UniSD 框架总览。左侧列出 SD 在 LLM 上的三大挑战,中间是 UniSD 的五个核心目标,右侧是收益(更高准确率、不依赖外部模型、可靠的自我提升等)

图1:作者把过去散落的 SD 技术整合到一个统一框架里。中间的五个模块——Multi-Teacher Agreement、EMA Teacher、Divergence Clipping、Contrastive Learning、Feature Matching——分别对应监督可靠性、训练稳定性等不同维度。


三个轴 × 五个组件

UniSD 的组织方式我觉得是这篇论文最值钱的地方。它把 self-distillation 抽象成一个 "reliability-aware self-correction process"——学生先 attempt 一次 completion,然后通过多个 teacher views 的对比和监督来学习,加权可靠信号,把知识整合进自己的行为。

围绕这个 formulation,五个组件分布在三个轴上:

轴一:监督可靠性(Supervision Reliability)

要解决的问题:哪些 self-derived signals 可信,哪些不可信?

Multi-Teacher Agreement:用多个 teacher view 看同一条 trajectory。如果几个 teacher 都对某个 token 的预测分布高度一致,那这个 signal 可信度高。如果它们彼此打架,说明这个位置本身就有不确定性,那个 token 的监督权重应该降低。

具体公式上是用多个 teacher 的预测做 cross-view consistency 估计:

\[\delta_t = A(\{l_t^k\}_{k=1}^K)\]

其中 \(l_t^k\) 是第 k 个 teacher 在第 t 个 token 上的 log probability。\(A\) 是 agreement function——可以是方差、KL 散度的反数等。

Token-Level Contrastive Learning:在 token level 上拉近 positive examples、推开 negative examples。比如对同一个 prompt,生成的 correct rationale 是 positive,generated 但 incorrect 的 rationale 是 negative。

\[\max(0, \gamma + d_t^+ - d_t^-)\]

这个 idea 借鉴自表示学习,目的是让模型在 embedding space 里能区分"看起来合理的"和"真正合理的"——这俩很容易混淆,正是 LLM SD 的核心痛点。

轴二:表征对齐(Representation Alignment)

要解决的问题:只在 output distribution 上做监督是不够的,中间表征也要对齐。

Feature Matching:学生的中间层 feature 要向 teacher 靠拢。

\[\mathcal{L}_{FM} = \sum_t m_t \|f_t^\theta - f_t^*\|_2^2\]

这块跟传统 distillation 里的 hint loss 一脉相承。但放在 SD 场景下,teacher 是自己(或者 EMA 自己),所以这个机制更像是正则化训练动力学——防止学生在 representation 上漂移得太快。

轴三:训练稳定性(Training Stability)

要解决的问题:SD 的 teacher 信号本身在演化,怎么防止小错误被指数级放大?

EMA Teacher:用 exponential moving average 维护一个"过去的自己"作为 teacher。

\[\bar{\theta}_n = \beta \bar{\theta}_{n-1} + (1-\beta) \theta_n\]

EMA teacher 提供时间上 smoothed 的目标,比 instantaneous student 稳得多。这个 trick 在 BYOL、MoCo 这些 self-supervised vision 方法里都验证过有效,搬到 LLM 上也 work。

Divergence Clipping:极端情况下,某些 token 的 KL divergence 会爆掉(比如 teacher 给某个 token 99% 概率,student 给 1%,KL 就发散)。Divergence Clipping 给每个 token 的 divergence 设上限:

\[\tilde{D}_t = \min(D_t^{(\alpha)}, \kappa)\]

防止 rare high-divergence tokens 主导优化。这是个看起来朴素但实际很关键的细节——我之前调 KD loss 的时候就被这个坑过。

图2:UniSD 的完整架构图。Student Policy 生成 on-policy rollout,被 K 个 self-derived teacher views(含 EMA teacher)看到。Multi-Teacher Agreement 在 token-level 和 sequence-level 估计一致性,生成 reliability weight,最终形成 weighted token-level loss。下方是 EMA、Contrastive、Feature Matching、Divergence Clipping 四个辅助模块

图2:这张图基本上就是 UniSD 算法的"全家福"。学生 policy \(\pi_\theta\) 生成 on-policy rollout,K 个 teacher(\(\pi_1^T, ..., \pi_K^T\))从不同 view 看这条轨迹,输出 log-prob。Agreement 函数把这些 log-prob 聚合成 token-level 和 sequence-level 的可靠性权重 \(w_t\),再用这个权重对 KD loss 重新加权。EMA、Contrastive、Feature Matching、Divergence Clipping 作为 B/C/D/E 四个辅助模块插入这个主流程。


这套框架的核心 loss

把上面所有组件凑在一起,UniSD 的训练目标大概是:

\[\mathcal{L} = \sum_{t=1}^T m_t w_t D(\pi_\theta(\cdot|x, \hat{y}_{<t}) \| \pi_*^T(\cdot|x, c^*, \hat{y}_{<t})) + \lambda_{aux} \mathcal{L}_{aux}\]

其中: - \(m_t\) 是 token mask(哪些位置要算 loss) - \(w_t\) 是 Multi-Teacher Agreement 算出来的 reliability weight - \(D(\cdot \| \cdot)\) 是 KL divergence(经过 Divergence Clipping) - \(\mathcal{L}_{aux}\) 是 contrastive + feature matching 的辅助损失

可以看出整个 loss 的"主线"就是一个re-weighted KD loss——每个 token 的监督权重由 cross-teacher agreement 决定,可信的 token 权重大,不可信的 token 权重小。

这跟单纯做 imitation learning(每个 token 平权)有本质区别。我自己之前做 RLHF 时就发现,reward 的方差和质量在不同 token 上差异巨大——一个 trajectory 里可能 90% 的 token 都是无关紧要的连接词,5% 是真正的 reasoning step,5% 是错误位置。如果你对所有 token 平权监督,那 reasoning step 的信号就被淹了。

UniSD 这个 token-level reweighting 思路,本质上和最近 MAPO(用 multi-agent 估计 advantage)、Process Reward 那一脉是相通的——都是在更细粒度上做监督。


实验:6 模型 × 6 benchmark 的系统化对比

实验 setup 我觉得设计得很周到:

  • 模型家族:3 个 family(Qwen、Llama、Mistral 之类),每个 family 取 2 个 size,共 6 个模型
  • 任务:6 个 benchmark,覆盖 Science(GPQA、MMLU-STEM)、Coding(HumanEval、MBPP)、Tool Use(ToolBench)、QA(MMLU、TriviaQA)
  • Setting:in-domain + OOD 都测,验证泛化能力

关键结果

UniSD*(整合后的 pipeline)相比 base model 平均提升 +5.4 分,相比最强的 SD baseline(论文里报的应该是某个单组件方法)提升 +2.8 分

具体到组件贡献的部分,论文给出几个有意思的观察:

观察 1:单个组件单独跑,提升都比较小(+1 ~ +2 分),但组合起来有协同效应——多个组件的互补让 UniSD* 比任何单组件方法都好。这种"组合优势"在 ablation 里得到了支持。

观察 2Multi-Teacher Agreement 是收益最稳定的组件。在所有 6 个 benchmark 上都贡献正向提升。这印证了我前面说的——token-level reweighting 是 SD 的关键,比单纯的 imitation 强得多。

观察 3Feature Matching 的贡献和模型规模相关。小模型(7B 以下)上 Feature Matching 帮助大,大模型上贡献递减。这个我有点意外——按直觉应该是大模型 hidden representation 更"对齐",但实验显示反而是小模型受益更多。我的解释是:小模型 hidden representation 更不稳定,需要 explicit alignment regularizer 拉一把。

观察 4Divergence Clipping 是"防崩"组件,平均收益很小但去掉之后训练偶尔会炸。这是个典型的"日常用不上但关键时刻保命"的机制。


我的判断:值不值得读?

先说结论:如果你在做 LLM post-training,特别是 SD 或 KD 相关,强烈推荐读。不推荐的场景:纯 pretraining 或 RL 方向,这篇关联度低。

亮点

  1. 三轴拆解是个好的思维框架。Supervision Reliability / Representation Alignment / Training Stability 这个分类把过去散乱的 SD trick 组织得很清晰。后续做 SD 的论文如果要做 ablation,基本可以照这个框架来
  2. 实验 setup 扎实。6 模型 × 6 benchmark × in-domain + OOD,覆盖度足够。比那些只在 LLaMA 上 demo 的论文要严谨
  3. 找到了"组合效应"这个非平凡发现。单组件提升小,组合起来反而强——说明这五个组件确实在解决不同的问题

问题

  1. 整体方法新颖性偏弱。每个组件单看都是已有的——Multi-Teacher Agreement 让我想到 Cotrain、PoE;EMA teacher 是 MoCo/BYOL 的老把戏;Feature Matching 是 FitNet 的标准操作。论文的贡献更多在整合 + 系统分析,而不是单点突破
  2. "自蒸馏"的边界有点模糊。Multi-Teacher Agreement 里那些 teacher 是从哪来的?如果是同一个 base model 的多次采样,那就是真 SD;如果是 prompt 不同 view,那 teacher 的"多"是 view 多,不是 model 多。这块描述不够精确
  3. 缺与 RL-based 方法的直接对比。GRPO、DAPO 等最近的 RL 方法也是不依赖更强 teacher 的,但实验里基本没出现。Self-distillation vs RLHF / GRPO 哪个更适合什么场景?没有答案
  4. UniSD* 的最优组合是怎么搜出来的?论文说"guided by these insights",但具体搜索过程描述得比较模糊——是穷举所有组合?还是基于消融的某种启发式?这块对复现挺重要

对工程实践的启发

  • 如果你在自蒸馏 / 数据合成场景下,Token-level reweighting 是必加的。把 trajectory 上的每个 token 平权监督,是浪费监督信号
  • EMA teacher 比 instantaneous teacher 稳定得多。BYOL 的老把戏在 LLM 上一样有效,无脑加
  • Divergence Clipping 是廉价的"保险"。KL loss 系列方法默认应该加一个 max divergence 上限,防止偶发崩溃

收尾

这类"统一框架 + 系统分析"的论文在 ML 社区不算稀缺,但能做得扎实的不多——很多最后变成大杂烩,缺乏洞察。UniSD 在这块做得算合格——三轴拆解清晰,组件之间的交互有数据支撑,最后的整合方案也确实拿到了 +2.8 分的 net gain。

回头看 self-distillation 这个方向的发展,从最早的 Born Again Networks,到 LLM 里的 STaR、ReST、SPIN,再到今天的 UniSD——可以看出一个清晰的趋势:SD 方法越来越精细化,从"用自己的输出当 label"到"按 token 加权 + 多视角一致性 + 训练稳定性保护"的多模块组合。这个方向短期内不会停。

下一个值得追的问题:SD 和 RL 的边界在哪? UniSD 强调 reliability weighting,但 reliability 本身和 RL 的 advantage 概念有点像——一个 token 是否"可信",跟它是否对最终 reward 有贡献,在某种程度上是同一个问题的两个 framing。下一篇好的 SD 论文,应该会把这两个思路打通。


觉得有启发的话,欢迎点赞、在看、转发。跟进最新 AI 前沿,关注我