机器学习 | LLM 并行方法(一)
本部分文章将涉及以下 LLM 并行方法:
- 数据并行 (Data Parallel, DP):将数据切分给不同 GPU,不同 GPU 并行处理不同输入数据。
- 张量并行 (Tensor Parallel, TP):将单个参数矩阵切分到多卡,通过协作完成层内计算。
本文图片来源 The Ultra-Scale Playbook: Training LLMs on GPU Clusters,使用 Apache-2.0 协议,兼容 CC BY-SA 4.0 协议。
本文需要前置知识并行计算集合通信,建议先确保了解七种集合通信方式再阅读本文。
1 数据并行 Data Parallel
1.1 概念
在单卡训练时,通过指定 batch_size,我们可以充分利用 GPU 的算力,实现多条不同数据并行处理,得到梯度均值后进行优化。数据并行(DP)便将这个流程扩展到了多 GPU 上,每个 GPU 同时处理不同的数据,计算完成后将所有 GPU 的梯度求均值,对参数进行优化。
DP 的示意图如下:

P.s. 图中看起来似乎每个 GPU 处理一条数据,实际上得理解为每个 GPU 处理一批数据,也就是若batch_size_per_gpu=16,那么 3 卡 DP 实际的batch_size=48(忽略梯度累积).
1.2 方法
1.2.1 朴素 DP/DDP
要实现 DP,最朴素的流程便是:
- 将不同的数据给到所有 GPU 进行前向传播,得到每个 GPU 的损失值;
- 每个 GPU 对损失值进行反向传播,得到每个 GPU 的梯度;
- 通过 All-Reduce 操作,将每个 GPU 的梯度汇总平均,同步到所有 GPU;
- 每个 GPU 使用同步后的梯度进行优化,更新参数。
可以看到,朴素 DP 和单卡训练的唯一区别就是第三步的 All-Reduce 操作,基于 Ring 算法的 All-Reduce 操作的单节点通信量是 $\approx2\Psi$,那么朴素 DP 的通信开销就可以记为 $2\Psi$。
什么是 DDP (Distributed Data Parallel)?
DP 和 DDP 实际上区别在于 All-Reduce 操作的工程实现方式:DP 使用主从方式进行 All-Reduce,主 GPU 的通信压力极大;DDP 使用 Ring 算法进行 All-Reduce,通信负载均匀分布在所有节点。它们的思想实际上是完全一致的,在理论讨论时无需区分。
由于每个 GPU 做的事情和单卡时几乎一样,因此朴素的 DP 方法不能节省显存,只能加快训练——若单卡训练需要 64GB 显存,四卡训练每张卡同样需要 64GB 显存,而训练速度可以加速约四倍。
1.2.2 零冗余优化器 (ZeRO)
观察朴素的 DP 方法,朴素 DP 的前向传播和反向传播过程是真正并行的,但从 All-Reduce 操作之后所有的 GPU 实际在做完全一致的事情,储存了大量镜像冗余的参数。零冗余优化器 (Zero Redundancy Optimizer, ZeRO) 是一种 DP 方法,它消除了这些冗余的参数,让 DP 在加快训练的同时还能节省显存。
ZeRO 方法的核心原理是:每个模型参数、该参数对应的梯度、该参数对应的优化器状态可以看作一个“小组”,这些“小组”之间是完全独立的,在每个“小组”内就可以直接完成对应部分 $\mathrm{Gradient}\to\mathrm{Optimizer}$ 的计算以及 $\mathrm{Optimizer}\to\Delta\mathrm{Parameter}$ 的计算。
根据以上性质,我们可以将这些小组均分到 $N_d$ 张 GPU 上,它们都能各自独立运行。仅当运行到需要完整参数才可计算的过程(如模型正向传播、反向传播)时,才收集每个 GPU 上的分片组合为完整参数进行计算。
$$ \overbrace{\mathrm{Parameter}\to\mathrm{Loss}\to}^{需完整参数}\underbrace{\mathrm{Gradient}\to\mathrm{Optimizer}\to\mathrm{Parameter}'}_{可分片独立运行}\overbrace{\to\mathrm{Loss}'\to\cdots}^{需完整参数} $$
ZeRO 按冗余参数的去除程度分为三个阶段 (Stage),可以根据具体需求选择不同的方法:
- ZeRO-1:优化器状态分块;
- ZeRO-2:优化器状态、梯度分块;
- ZeRO-3:优化器状态、梯度、参数分块。

ZeRO-1 Stage
在朴素 DP 中,从 All-Reduce 开始每个 GPU 就在干一样的活,它将 $1$ 份优化器状态镜像储存了 $N_d$ 份。ZeRO-1 方法打破了这个局面,它将 $1$ 份优化器状态分片到 $N_d$ 个 GPU 上,每个 GPU 只负责储存 $1/N_d$ 份优化器状态,该份优化器也只负责对应部分的梯度和参数。
ZeRO-1 的流程如下:
- 模型正向传播得到损失值;
- 损失值反向传播得到完整梯度;
- 使用 Reduce-Scatter 收集各自分片优化器对应的分片梯度;
- 基于对应的分片梯度更新分片的优化器状态;
- 基于分片的优化器状态得到分片的更新后模型参数;
- 通过 All-Gather 把分片的更新后模型参数拼接回来,得到完整新参数;
- 进行后续操作。
ZeRO-1 方法中的 Reduce-Scatter 通信开销 $\Psi$,All-Gather 通信开销 $\Psi$,因此 ZeRO-1 方法的总通信开销为 $2\Psi$.
用图片表示如下:

ZeRO-2 Stage
在小节开头便提到了,每个“参数、梯度、优化器”小组之间是完全独立的,在每个小组内就可以独立完成 $\mathrm{Gradient}\to\mathrm{Optimizer}$ 的过程。因此,ZeRO-1 方法在 Reduce-Scatter 之后,只有对应优化器状态分片的梯度是有用的,其他梯度完全没用。ZeRO-2 方法的唯一区别就在丢弃释放不需要的梯度,因此不再赘述。
和 ZeRO-1 一样,ZeRO-2 方法中的 Reduce-Scatter 通信开销 $\Psi$,All-Gather 通信开销 $\Psi$,因此 ZeRO-2 方法的总通信开销为 $2\Psi$.
用图片表示如下:

反向传播先得到完整的梯度,再丢掉不需要的梯度,那显存占用峰值不是没有变化吗?
ZeRO-2 的梯度计算过程是逐层的,每计算完模型一层,就规约一次,然后丢弃不需要的梯度,并不是说把模型所有层全计算完了再统一丢弃,因此不会产生非常大的显存峰值。
ZeRO-3 Stage
ZeRO-3 阶段将模型参数也进行了分块,相当于所有的“参数、梯度、优化器”小组被分为了独立的 $N_d$ 份,彻底消除了所有的冗余参数,实现了“Zero Redundancy”.
但这不是没有代价的,当模型参数也被分块后,模型的正向传播和反向传播就不能直接完成了:
$$ \overbrace{\mathrm{Parameter}\to\mathrm{Loss}\to}^{需完整参数}\underbrace{\mathrm{Gradient}\to\mathrm{Optimizer}\to\mathrm{Parameter}'}_{可分片独立运行}\overbrace{\to\mathrm{Loss}'\to\cdots}^{需完整参数} $$
于是 ZeRO-3 的流程如下:
- 对于模型每一层,先用 All-Gather 收集分片模型参数得到完整模型参数,正向传播得到损失值;
- 对于模型每一层,先用 All-Gather 收集分片模型参数得到完整模型参数,反向传播得到完整梯度值;
- 在计算完一层的梯度后,使用 Reduce-Scatter 收集各自分片优化器对应的分片梯度,并丢弃不需要的梯度部分;
- 基于对应的分片梯度更新分片的优化器状态;
- 基于分片的优化器状态得到分片的更新后模型参数;
ZeRO-3 方法中的 All-Gather 通信开销 $\Psi$,两次一共 $2\Psi$,Reduce-Scatter 通信开销 $\Psi$,因此 ZeRO-3 方法的总通信开销为 $3\Psi$.
用图片表示如下:
| 正向传播 | 反向传播 |
|---|---|
![]() | ![]() |
为什么 ZeRO-3 方法比 ZeRO-1/2 多了两次 All-Gather,最终的开销只变大了 $\Psi$?
由于 ZeRO-3 方法每个 GPU 只储存分片的参数,因此不需要 ZeRO-1/2 的“通过 All-Gather 把分片模型参数拼接得到完整新参数”的过程,因此总共 ZeRO-3 相对于 ZeRO-1/2 多了两次正反传播的 All-Gather,少了一次拼接新参数的 All-Gather,总共只相当于多了一次 All-Gather,开销只变大了 $\Psi$.
1.3 代价
总结上面的四种 DP 方法,表格如下:
| 方法 | 显存节省 | 每 GPU 储存内容 | 通信开销 |
|---|---|---|---|
| DDP | 无 | 完整模型参数 + 完整梯度 + 完整优化器状态 | All-Reduce ($2\Psi$) |
| ZeRO-1 | 低 | 完整模型参数 + 完整梯度 + 1/N 优化器状态 | Reduce-Scatter+All-Gather ($2\Psi$) |
| ZeRO-2 | 中 | 完整模型参数 + 1/N 梯度 + 1/N 优化器状态 | Reduce-Scatter+All-Gather ($2\Psi$) |
| ZeRO-3 | 高 | 1/N 参数 + 1/N 梯度 + 1/N 优化器状态 | Reduce-Scatter+$2\times$All-Gather ($3\Psi$) |
从通信开销来看,选择 ZeRO-2 看似是没有任何代价的,它用和 DDP 一样的通信开销,节省了大量的显存,似乎可以无脑选择。但实际上,通信开销只是衡量代价的一个方面,除了通信开销之外,还有通信延迟的增加、显存管理开销等代价。
但总的来说,在单机情况下(基于 NVLink 的机内带宽充足),ZeRO-2 方法确实接近“免费的午餐”。当模型很小或显存充足时,若要追求极限的训练速度,DDP 依然是最稳健、延迟最低的选择。
而对于 ZeRO-3,它引入了 50% 的额外通信开销,只有当想在有限的硬件上跑“远超单卡容量”的巨型模型训练时才会考虑它。
2 张量并行 Tensor Parallel
2.1 概念
张量并行的概念非常简单,在线性代数中,矩阵乘法是可以进行分块的:
$$ Y = X[W_0, W_1] = [XW_0, XW_1] = [Y_0, Y_1] $$
$$ Y = [X_0, X_1] \begin{bmatrix} W_0 \\ W_1 \end{bmatrix} = X_0 W_0 + X_1 W_1 = Y_0 + Y_1 $$
基于该数学性质,可以非常自然地实现多 GPU 的并行处理:
| 按列分块 | 按行分块 |
|---|---|
![]() | ![]() |
2.2 应用
2.2.1 在 LLM 中的应用方式
要在 LLM 中应用张量并行,主要应用的部分就是多头注意力机制和前馈网络部分:
多头注意力机制
- Q/K/V 矩阵:按列拆分,每个 GPU 负责计算一部分注意力头。例如,一个有 32 个头的模型在 4 张 GPU 上运行时,每张显卡只需处理 8 个头。
- 线性输出层 ($W_O$):按行拆分,各头计算完自己的注意力结果后,最后进行一次 All-Reduce 汇总。
P.s. 这也是为什么,在使用 vLLM 进行 TP 并行时,当 GPU 数量不能整除注意力头数时会报错(例如 5 卡并行 32 注意力头时).
TP 并行图示如下:

前馈网络
- 第一层 ($Gelu(X \cdot W_1)$):按列拆分。
- 第二层 ($Y \cdot W_2$):按行拆分。
在这种设计下,MLP 整个过程只需要在最后一步进行一次同步通信(All-Reduce),效率非常高。
TP 并行图示如下:

2.2.2 与序列并行(Sequence Parallel)结合
通过上面的图示可以发现,模型的参数、梯度、优化器状态确实是均匀分布到了多张 GPU 上,多卡并行计算 Attention 和 Linear,降低了单卡显存需求同时加快速度。但每次计算完成后得到的模型激活值,却通过 All-Reduce 完整镜像到了每张 GPU 上,每张 GPU 用完整的激活值镜像地计算 Dropout 和 LayerNorm,浪费显存同时无法加速。
为了解决这个问题,序列并行应运而生,它针对的就是张量并行无法并行的 Dropout 和 LayerNorm 部分,将激活值均分到每张 GPU 上,每张 GPU 并行计算一部分的 Dropout 和 LayerNorm 操作。接下来具体来看看如何分别对 Dropout 和 LayerNorm 并行化。
Dropout
对于 LLM 中使用的 Dropout(nn.Dropout),它天然就是元素级独立的,每个隐藏状态是否被丢弃互不干扰。因此,Dropout 操作可以从序列长度方向 $seq$ 进行切分,每个 GPU 只处理自己负责序列的 Dropout.
LayerNorm
对于一个 $d_\mathrm{model}$ 维度的隐藏状态 $h$,首先要计算它的均值和方差:
$$ \mu = \frac{1}{d_\mathrm{model}} \sum_{i=1}^{d_\mathrm{model}} h_i\\ \sigma^2 = \frac{1}{d_\mathrm{model}} \sum_{i=1}^{d_\mathrm{model}} (h_i - \mu)^2 $$
得到均值和方差之后,便可以进行归一化了($\epsilon$ 为防止除零的极小值):
$$ \hat{h}_i = \frac{h_i - \mu}{\sqrt{\sigma^2 + \epsilon}} $$
最后进行线性变化,$\gamma$ 和 $\beta$ 是可学习的参数,$\odot$ 是哈达玛积:
$$ h' = \gamma \odot \hat{h} + \beta $$
$h'$ 便是 LayerNorm 后的隐藏状态了。
注意上面的整个流程,LayerNorm 实际上是在对单独的每个隐藏状态 $h$ 做操作,它并不需要像 BatchNorm 一样需要完整 Batch 的数据才能计算。因此,LayerNorm 操作可以从序列长度方向 $seq$ 进行切分,每个 GPU 只处理一部分序列,全部处理完后拼接还原即可。
到这里,序列并行的方式应该就很清楚了,它和张量并行的区别就在于切分的方向不同:
- 张量并行:$(batch,seq,hidden)\to N_{\mathrm{gpu}}\times(batch,seq,\frac{hidden}{N_{\mathrm{gpu}}})$
- 序列并行:$(batch,seq,hidden)\to N_{\mathrm{gpu}}\times(batch,\frac{seq}{N_{\mathrm{gpu}}},hidden)$
它们二者结合使用,就可以实现模型参数、梯度、优化器、激活值全部均分到多张 GPU 上了。示意图如下:

2.3 代价
数据并行的通信仅发生在整个模型反向传播结束时,但张量并行发生在每一层内部计算过程中,因此张量并行的通信量极大。这就导致张量并行基本上只能在单机 NVLink 中完成,几乎不可能进行多级多卡并行。在多机多卡并行时,一般采用机内张量并行、机间数据并行或者其他并行方式。
接下来探讨序列并行的加入,对通信开销的影响。
看到上图,若仅使用张量并行,那么它所需集合通信有:
- 正向传播:$f$ 位置无通信,$f^*$ 位置 All-Reduce
- 反向传播:$f$ 位置 All-Reduce,$f^*$ 位置无通信
若使用张量并行+序列并行,那么它所需的集合通信有:
- 正向传播:$g$ 位置 All-Gather,$g^*$ 位置 Reduce-Scatter
- 反向传播:$g$ 位置 Reduce-Scatter,$g^*$ 位置 All-Gather
我们知道,All-Reduce 操作的开销是其他六种集合通信操作的两倍,因此单独使用张量并行的开销和使用张量并行+序列并行是相同的。可以理解为,要使用张量并行时,结合使用序列并行算是“免费的午餐”。
本文采用 CC BY-SA 4.0 许可,本文 Markdown 源码:Haotian-BiJi



