机器学习 | 语言模型解码算法
语言模型解码算法:指在语言模型生成过程中,从模型输出的概率分布 (logits) 中,根据一定策略选择并生成下一个词或符号的过程,直至构成完整的句子或文本。
需要注意这里的解码不同于 Transformer 架构大语言模型的 Decoder 块。大语言模型的 Decoder 块是模型结构的一部分,而这里讲的解码算法不属于模型结构,只是利用模型来生成完整句子的一种算法。
1 理想情况
1.1 原理
语言模型的解码目标便是寻求最大化以下联合概率:
但需要注意的是,大语言模型是一个自回归模型,下一个 Token 的概率分布是受当前 Token 影响的,这个性质使得取得上述概率的全局最优解较为困难。
例如以下概率分布的例子,红线是全局最优解:
1.2 复杂度
如果我们要求得理想情况下的最优解,那么我们需要遍历所有情况来求取。假设每次生成的概率分布都有
它的复杂度为
2 贪心解码(Greedy Decoding)
2.1 原理
贪心解码利用了贪心思想,每次挑选下一个 Token 时,只挑概率最大的 Token。公式表述即为:
对于上面的示例,使用贪心解码就会选择绿色的路线(红色为全局最优解):
2.2 复杂度
对于贪心解码,假设每次生成的概率分布都有
可以发现贪心解码的复杂度是线性的,效率非常高。
3 束搜索(Beam Search)
3.1 原理
贪心解码只挑概率最大的 Token,而束搜索引入了参数 num_beams
,每次挑选 num_beams
个大值进行尝试。可以发现,当 num_beams=1
时,束搜索便退化为贪心搜索。
对于上面的示例,使用 num_beams=2
的束搜索就会选择绿色的路线(红色为全局最优解):
3.2 复杂度
对于束搜索,假设每次生成的概率分布都有 num_beams
记为
可以发现束搜索的复杂度也是线性的,效率也非常高,性价比比较高。
4 随机采样(Sampling)
4.1 原理
随机采样按照概率分布随机选择下一个词,虽然也不能保证采样到全局最优解,但是这样生成的句子可能更具多样性,有可能形成更好的结果。
随机采样也有很多种方式:
- Temperature 采样
- Top-k 采样
- Top-p 采样
4.2 Temperature 采样
Temperature 采样的公式可以用 softmax 函数表示。假设模型输出一个向量
其中,
- 当
越小,高概率的 Token 会变得更有可能被抽到,低概率的 Token 会变得更不可能被抽到,模型更加固定。 - 当
越大,高概率和低概率的 Token 之间的差距会变得更小,模型生成的文本会更加随机和多样。
4.3 Top-k 采样
Top-k 采样先找出
- LLM 提供下一个 Token 的概率 logits。
- 然后,模型会选择概率最高的
个词。 - 最后,模型会从这
个词中随机抽取一个。
4.4 Top-p 采样
Top-p 才行先找出累计概率超过
- LLM 提供下一个 Token 的概率 logits。
- 然后,模型会从概率高到低的 Token 开始选择,直到累计概率超过
。 - 最后,模型会从这些词中随机抽取一个。
Top-p 采样实际上是对 Top-k 采样的针对优化,因为 Top-k 采样固定了
4.5 复杂度
对于随机采样,假设每次生成的概率分布都有
复杂度和贪心解码是一样的,但是会较贪心解码有更好的多样性,属于综合效果较好的解码方式。
5 transformers 库
transformers 库将以上解码方法合并到了 generate 方法中,通过特定的参数组合来决定使用哪种解码方式:
- 贪心解码:
num_beams=1
且do_sample=False
- 随机采样:
num_beams=1
且do_sample=True
- 束搜索:
num_beams>1
且do_sample=False
- 随机束搜索:
num_beams>1
且do_sample=True
,相当于使用随机采样代替选择 Top-k
源代码:https://github.com/huggingface/transformers/blob/v4.20.1/src/transformers/generation_utils.py
最新版的 transformers 库对 generate 方法有调整,不过没有特别大的变化。
本文采用 CC BY-SA 4.0 许可,本文 Markdown 源码:Haotian-BiJi