广义知识蒸馏 (Generalized Knowledge Distillation, GKD):是一种让学生模型在自己生成的 On-Policy 序列上,利用教师模型给出的 Token-Level 分布反馈进行蒸馏,从而缓解自回归模型训练与推理分布不一致问题的方法。

原始论文:On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes

1 自回归模型蒸馏方法

1.1 有监督微调 (Supervised FT)

若无法访问“教师”策略(策略可以理解为参数),但是可以获得“教师”策略生成得到的一组固定 Ground-Truth 数据集。那么一种简单的方法是,在“学生”策略下最小化这些序列的负对数似然值:

$$ L_{\mathrm{SFT}}(\theta) := \mathbb{E}_{(x,y)\sim (X,Y)} \left[-\log p_S^{\theta}(y \mid x)\right]. $$

这种方法在蒸馏领域中也可以叫做 Seq-KD,该方法应该是当前最简单和常用的蒸馏方法了。并且由于它不需要访问“教师”策略,所以可以轻松对优秀的闭源模型进行蒸馏,同时得到的离线数据集可以用在任何其他大模型进行训练,通用性强。

1.2 有监督蒸馏 (Supervised KD)

大模型在自回归生成时,在选取下一个生成 Token 时实际上是从一个分布中采用得到的。蒸馏得到纯文本的数据集,只保留了采样的结果,而丢弃了包含丰富信息的概率分布。有监督蒸馏方法从数据集中抽取数据后,通过 KL 散度对齐“学生”和“教师”模型在数据集条目上的生成概率分布,即训练学生模型模仿教师模型的 Token 级概率分布:

$$ L_{\mathrm{SD}}(\theta) := \mathbb{E}_{(x,y)\sim (X,Y)} \left[ \mathcal{D}_{\mathrm{KL}} \left(p_T \,\|\, p_S^{\theta}\right)(y \mid x) \right]. $$

但由于该方法需要访问到模型生成 Token 的概率分布,而闭源模型基本上只会提供 Token 生成的 API,该方法无法蒸馏闭源模型。对于开源模型,使用 vLLM 框架部署后可以轻松获得生成 Token 的概率分布,因此该方法适合能够访问“教师”策略的场景。

1.3 以上方法的缺点

上面所述的两种方法让“学生”模型在固定的数据集(来自标注或“教师”模型)上进行训练,但在实际推理阶段,“学生”模型是基于自己前一步的输出自回归地生成下一个词的,这导致了训练与推理时的分布不匹配。这种差异导致学生在推理时可能会遇到它在训练中从未见过的序列状态,并且由于大语言模型的自回归生成原理,模型一旦在早期生成中犯错,就会产生连锁反应,导致后续生成质量直线下降。

2 广义知识蒸馏 GKD

广义知识蒸馏的(不完整)流程为:

  1. 给定输入序列 $x\in X$
  2. 学生模型采样生成输出序列 $y$
  3. 教师模型在序列 $y$ 上,对每个 token 前缀状态给出概率分布
  4. 学生基于 KL 散度拟合教师模型的概率分布

从流程就可以看出 GKD 就是针对其他蒸馏方案的缺点设计的。首先模型训练的数据不再是固定的 $(x,y)\in(X,Y)$,而是固定的 $x\in X$ 后由学生模型动态采样生成得到的 $y$。其次,学生模型不再去拟合他人(来自标注或教师模型)生成的序列,学习如何在自己的序列前缀上让概率分布与教师模型相似。这样在训练和推理时,学生模型见到的序列是一致的,均是自己生成的序列。

用公式来表示的话,损失表达为:

$$ L_{OD}(\theta) := \mathbb{E}_{x \sim X} \left[ \mathbb{E}_{y \sim p_S(\cdot \mid x)} \left[ D_{KL}\left(p_T \,\|\, p_S^\theta\right)(y \mid x) \right] \right]. $$

但是根据作者的设计,上述内容只是 GKD 的一部分,称为在线蒸馏 (On-policy KD, OD)。完整的 GKD 将 SFT 和 OD 通过权重 $\lambda$ 结合在一起:

$$ L_{\mathrm{GKD}}(\theta) := (1-\lambda) \mathbb{E}_{(x,y)\sim (X,Y)} \left[ \mathcal{D}\left(p_T \,\|\, p_S^\theta\right)(y \mid x) \right] + \lambda \mathbb{E}_{x\sim X} \left[ \mathbb{E}_{y\sim p_S(\cdot \mid x)} \left[ \mathcal{D}\left(p_T \,\|\, p_S^\theta\right)(y \mid x) \right] \right]. $$

纯的 OD 虽然能缓解训练推理的不一致,但是学生冷启动时可能导致采样出来的序列太差,训练不稳定,因此作者将 GKD 设计为 SFT 和 OD 的混合。在代码中,混合方式也是非常简单:生成一个随机数 $u$,如果 $u\leq\lambda$ 则走 OD 训练;如果 $u>\lambda$ 则走 SFT 训练。

另外,该方法的“广义”就在它通过控制 $\lambda$ 就可以包含多个方法:

  • $\lambda=0$:退化为 Supervised KD;
  • $\lambda=1$:退化为 On-policy KD;
  • $\lambda=0.5$:固定数据和学生生成数据混合,真正的 GKD 方法。

关于原文 ...where we do not backpropagate through the student’s sampling distribution pS(·|x)... 的解释

原始论文特别强调了,采样动作不进行反向传播,而只是用于收集训练数据。如果采样动作也进行反向传播,那么该方法就变成了 RL 风格的训练,方差更大,也更不稳定。同时,模型可能会开始尝试改变自己采样哪些序列,而倾向于生成那些能让学生教师分布 KL 更小的序列,而不是去学习如何在当前序列下让学生教师分布的 KL 更小。因此为了训练的稳定性,GKD 的采样部分不记录梯度,不进行反向传播。

3 分布相似性

论文中还提到了如何度量两个分布的相似性,KL 散度是最常见的度量方式:

$$ D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{x \in \mathcal{X}} P(x) \log \frac{P(x)}{Q(x)} $$

需要注意的是,KL 散度不对称,即 $D_{\mathrm{KL}}(P \,\|\, Q)\neq D_{\mathrm{KL}}(Q \,\|\, P)$,因此使用两种不同的 KL 计算方式,会让训练行为截然不同。

首先来看正向的 KL 散度

$$ D_{\mathrm{KL}}(p_T \,\|\, p_S^\theta) = \sum_{v \in \mathcal{V}} p_T(v) \log \frac{p_T(v)}{p_S^\theta(v)} $$

注意到学生模型在分母,系数项是教师模型。这种情况下,如果教师在某个 token 上概率很高,但学生概率很低,惩罚会很大。因此这种方式倾向于让学生尽量覆盖教师分布里的所有可能模式。

然后来看反向的 KL 散度

$$ D_{\mathrm{KL}}(p_S^\theta \,\|\, p_T) = \sum_{v \in \mathcal{V}} p_S^\theta(v) \log \frac{p_S^\theta(v)}{p_T(v)} $$

注意到教师模型在分母,系数项是学生模型。这种情况下,如果学生给某个 token 很高概率,但教师给它很低概率,惩罚会很大。因此这种方式倾向于集中到教师最确信的高概率区域,而不是覆盖教师的所有可能输出。

广义 Jensen-Shannon 散度 (Generalized JSD) 融合了上面两种计算方式,提供了更均衡的计算方式。它不直接比较教师分布和学生分布,而是先构造一个中间分布,然后分别比较教师/学生和这个中间分布的距离。

例如对于分布 $P,Q$,JSD 首先构造 $M=\beta P+(1-\beta)Q$,然后结果便是:

$$ D_{\mathrm{JSD(\beta)}}(P \,\|\, Q)=\beta D_{\mathrm{KL}}(P \,\|\, M)+(1-\beta)D_{\mathrm{KL}}(Q \,\|\, M) $$

$\beta$ 是一个 $0\sim1$ 的超参数,用于控制 JSD 的倾向性:$\beta$ 越小更接近正向 KL,偏覆盖教师分布;$\beta$ 越大更接近反向 KL,更倾向教师高概率 token.

4 应用于训练

GKD 方法已经被主流大模型训练框架例如 ms-swift 支持了,我们看到 ms-swift 文档的 GKD 页面,在学习完原论文后,框架里的参数含义就非常明了了:

  • --beta:就是第 3 节里提到的控制 JSD 的 $\beta$
  • --lmbda:就是第 2 节里提到的控制 GKD 的 $\lambda$
  • --seq_kd:若为 False,则 GKD 不走 OD 分支时使用数据集直接 SFT;若为 True,则 GKD 不走 OD 分支时采样一个教师生成来 SFT

具体的示例,ms-swift 官方也提供了,这里就不再重复了。

文章目录