# 2023 Ainslie et al.

GQA: Training Generalized Multi-Query Attention

GQA:从多头检查点训练广义多查询 Transformer 模型 Joshua Ainslie , James Lee-Thorp , Michiel de Jong † Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai Google Research 摘要 多查询注意力(Multi-query a...

精粹译文

GQA:从多头检查点训练广义多查询 Transformer 模型

Joshua Ainslie*, James Lee-Thorp*, Michiel de Jong* † Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai

Google Research

摘要

多查询注意力(Multi-query attention, MQA)仅使用单个键值头,极大地加快了解码器推理速度。然而,MQA 可能导致质量下降,此外,仅仅为了更快的推理而训练一个单独的模型可能并不理想。我们(1)提出了一种方案,利用原始预训练计算量的 5% 将现有的多头语言模型检查点“上训练”(uptrain)为具有 MQA 的模型;(2)引入了分组查询注意力(Grouped-query attention, GQA),这是多查询注意力的一种泛化形式,它使用中间数量(大于 1,小于查询头数量)的键值头。我们证明,经过上训练的 GQA 在保持接近 MQA 的速度的同时,达到了接近多头注意力的质量。


1 引言

自回归解码器推理是 Transformer 模型的一个严重瓶颈,这是由于在每个解码步骤中加载解码器权重以及所有注意力键和值所带来的内存带宽开销(Shazeer, 2019; Pope et al., 2022; de Jong et al., 2022)。通过多查询注意力(Shazeer, 2019),可以大幅减少加载键和值的内存带宽,该方法使用多个查询头,但仅使用单个键和值头。

然而,多查询注意力(MQA)可能导致质量下降和训练不稳定,并且针对质量和推理分别优化模型可能并不可行。此外,虽然一些语言模型已经使用了多查询注意力,例如 PaLM(Chowdhery et al., 2022),但许多模型并没有使用,包括公开可用的语言模型,如 T5(Raffel et al., 2020)和 LLaMA(Touvron et al., 2023)。

本工作包含两项旨在实现大型语言模型更快推理的贡献。首先,我们证明具有多头注意力(MHA)的语言模型检查点可以通过上训练(Komatsuzaki et al., 2022)来使用 MQA,且仅需原始训练计算量的一小部分。这提供了一种经济高效的方法来获得快速的多查询检查点以及高质量的 MHA 检查点。

其次,我们提出了分组查询注意力(GQA),这是一种介于多头和多查询注意力之间的插值方法,每个查询头子组对应单个键和值头。我们证明,经过上训练的 GQA 在达到接近多头注意力质量的同时,速度几乎与多查询注意力一样快。

2 方法

2.1 上训练(Uptraining)

从多头模型生成多查询模型分为两个步骤:首先是转换检查点,其次是进行额外的预训练以使模型适应其新结构。图 1 展示了将多头检查点转换为多查询检查点的过程。键和值头的投影矩阵被均值池化(mean pooled)为单个投影矩阵,我们发现这种方法比选择单个键值头或从头开始随机初始化新的键值头效果更好。

多头到多查询注意力转换概览。所有头的键和值投影矩阵被均值池化为一个单一的头。 图 1: 多头到多查询注意力转换概览。所有头的键和值投影矩阵被均值池化为一个单一的头。

转换后的检查点随后在相同的预训练方案下,以原始训练步数的一小部分 α\alpha 进行预训练。

2.2 分组查询注意力(Grouped-query attention)

分组查询注意力将查询头分为 GG 个组,每一组共享一个键头和值头。GQA-GG 指的是具有 GG 个组的分组查询。GQA-1 具有单个组,因此具有单个键和值头,等同于 MQA;而 GQA-HHHH 为头数)等同于 MHA。图 2 展示了分组查询注意力与多头/多查询注意力的比较。当将多头检查点转换为 GQA 检查点时,我们通过对该组内的所有原始头进行均值池化来构建每个组的键和值头。

分组查询方法概览。多头注意力具有 H 个查询、键和值头。多查询注意力在所有查询头之间共享单个键和值头。分组查询注意力则为每组查询头共享单个键和值头,在多头和多查询注意力之间进行插值。 图 2: 分组查询方法概览。多头注意力具有 HH 个查询、键和值头。多查询注意力在所有查询头之间共享单个键和值头。分组查询注意力则为每组查询头共享单个键和值头,在多头和多查询注意力之间进行插值。

中间数量的组会导致一个插值模型,其质量高于 MQA 但速度快于 MHA,并且正如我们将展示的那样,这代表了一种有利的权衡。从 MHA 到 MQA 将 HH 个键和值头减少为单个键和值头,减少了键值缓存的大小,从而将需要加载的数据量减少了 HH 倍。然而,大型模型通常会扩展头数,因此多查询注意力在内存带宽和容量方面代表了更激进的削减。GQA 让我们在模型规模增加时保持带宽和容量的相同比例下降。

此外,大型模型受注意力内存带宽开销的影响相对较小,因为 KV 缓存随模型维度缩放,而模型 FLOPs 和参数随模型维度的平方缩放。最后,大型模型的标准分片(sharding)通过模型分区数量复制单个键和值头(Pope et al., 2022);GQA 消除了这种分区带来的浪费。因此,我们预计 GQA 对于大型模型将呈现出特别好的权衡。

我们注意到 GQA 不应用于编码器自注意力层;编码器表示是并行计算的,因此内存带宽通常不是主要瓶颈。

3 实验

3.1 实验设置

配置 所有模型均基于 T5.1.1 架构(Raffel et al., 2020),使用 JAX (Bradbury et al., 2018)、Flax (Heek et al., 2020) 和 Flaxformer 实现。对于我们的主要实验,我们考虑了具有多头注意力的 T5 Large 和 XXL,以及具有多查询和分组查询注意力的上训练 T5 XXL 版本。我们使用与 T5 (Raffel et al., 2020) 相同的超参数和学习率调度来使用 Adafactor 优化器。我们将 MQA 和 GQA 应用于解码器自注意力和交叉注意力,但不应用于编码器自注意力。

上训练 上训练模型从公开的 T5.1.1 检查点初始化。键和值头被均值池化为适当的 MQA 或 GQA 结构,然后使用 (Raffel et al., 2020) 的原始预训练设置和数据集,以原始预训练步数的 α\alpha 比例进行进一步预训练。对于 α=0.05\alpha = 0.05,训练大约耗时 600 个 TPUv3 芯片日。

数据 我们在摘要数据集 CNN/Daily Mail (Nallapati et al., 2016)、arXiv 和 PubMed (Cohan et al., 2018)、MediaSum (Zhu et al., 2021) 和 Multi-News (Fabbri et al., 2019) 上进行评估;翻译数据集 WMT 2014 英德翻译;以及问答数据集 TriviaQA (Joshi et al., 2017)。我们不对诸如 GLUE (Wang et al., 2019) 之类的流行分类基准进行评估,因为自回归推理对这些任务的适用性较低。

微调 对于微调,我们对所有任务使用 0.001 的恒定学习率、128 的批大小和 0.1 的 dropout 率。CNN/Daily Mail 和 WMT 使用 512 的输入长度和 256 的输出长度。其他摘要数据集使用 2048 的输入长度和 512 的输出长度。最后,TriviaQA 使用 2048 的输入长度和 32 的输出长度。我们训练直到收敛,并选择开发集性能最高的检查点。我们使用贪婪解码进行推理。

计时 我们报告每个 TPUv4 芯片的每样本时间,由 xprof (Google, 2020) 测量。对于计时实验,我们使用 8 个 TPU,并使用每个 TPU 最多可容纳的最大批大小,且每个模型的并行化均单独优化。

模型TinferT_{infer}平均CNNarXivPubMedMediaSumMultiNewsWMTTriviaQA
sR1R_1R1R_1R1R_1R1R_1R1R_1BLEUF1
MHA-Large0.3746.042.944.646.235.546.627.778.2
MHA-XXL1.5147.243.845.647.536.446.928.481.9
MQA-XXL0.2446.643.045.046.936.146.528.581.3
GQA-8-XXL0.2847.143.545.447.736.347.228.481.6

表 1: T5 Large 和 XXL 模型(多头注意力)与 5% 上训练的 T5-XXL 模型(多查询和分组查询注意力)在摘要数据集 CNN/Daily Mail、arXiv、PubMed、MediaSum 和 MultiNews,翻译数据集 WMT,以及问答数据集 TriviaQA 上的推理时间和平均开发集性能比较。

3.2 主要结果

图 3 显示了所有数据集的平均性能随平均推理时间的变化,对比了 MHA T5-Large 和 T5-XXL,以及上训练比例 α=0.05\alpha = 0.05 的 MQA 和 GQA-8 XXL 模型。我们看到,更大的上训练 MQA 模型相对于 MHA 模型提供了有利的权衡,具有比 MHA-Large 更高的质量和更快的推理速度。此外,GQA 实现了显著的额外质量提升,在达到接近 MQA 速度的同时,性能接近 MHA-XXL。表 1 包含了所有数据集的完整结果。

上训练的 MQA 与 MHA 相比产生了有利的权衡,具有比 MHA-Large 更高的质量和更快的速度,而 GQA 在实现类似速度提升的同时,达到了更好的性能,且质量与 MHA-XXL 相当。图示为 T5-Large 和 T5-XXL(多头注意力)以及 5% 上训练的 T5-XXL(MQA 和 GQA-8 注意力)在所有任务上的平均性能随每样本平均推理时间的变化。 图 3: 上训练的 MQA 与 MHA 相比产生了有利的权衡,具有比 MHA-Large 更高的质量和更快的速度,而 GQA 在实现类似速度提升的同时,达到了更好的性能,且质量与 MHA-XXL 相当。

3.3 消融实验

本节介绍旨在研究不同建模选择影响的实验。我们评估了任务代表性子集上的性能:CNN/Daily Mail(短篇摘要)、MultiNews(长篇摘要)和 TriviaQA(问答)。

检查点转换 图 4 比较了不同检查点转换方法的性能。均值池化似乎效果最好,其次是选择单个头,然后是随机初始化。直观上,结果是按照从预训练模型中保留信息的程度来排序的。

上训练步数 图 5 显示了性能随 T5 XXL 模型 MQA 和 GQA 上训练比例的变化。首先,我们注意到 GQA 在转换后就已经达到了合理的性能,而 MQA 需要上训练才能变得有用。MQA 和 GQA 都从 5% 的上训练中获益,并在 10% 时收益递减。

T5-Large 上训练为 MQA(比例 \alpha = 0.05)的不同检查点转换方法的性能比较。‘Mean’ 表示对键和值头进行均值池化,‘First’ 选择第一个头,‘Random’ 从头开始初始化头。 图 4: T5-Large 上训练为 MQA(比例 α=0.05\alpha = 0.05)的不同检查点转换方法的性能比较。

T5 XXL 模型(MQA 和 GQA-8)性能随上训练比例的变化。 图 5: T5 XXL 模型(MQA 和 GQA-8)性能随上训练比例的变化。

组数 图 6 演示了 GQA 组数对推理速度的影响。对于更大的模型,来自 KV 缓存的内存带宽开销限制较小(Shazeer, 2019),而由于头数增加,键值大小的减少更为剧烈。因此,从 MQA 开始增加组数最初只会导致轻微的减速,随着我们向 MHA 靠近,成本会增加。我们选择了 8 个组作为有利的中间地带。

GQA-XXL 每样本时间随 GQA 组数的变化(输入长度 2048,输出长度 512)。从 1(MQA)增加到 8 个组会增加轻微的推理开销,增加更多组的成本会随之增加。 图 6: GQA-XXL 每样本时间随 GQA 组数的变化(输入长度 2048,输出长度 512)。

4 相关工作

本工作专注于通过减少加载键和值的内存带宽开销(Williams et al., 2009)来实现解码器质量和推理时间之间更好的权衡。Shazeer (2019) 首先提出了通过多查询注意力来减少这种开销。后续工作表明,多查询注意力对于长输入特别有帮助(Pope et al., 2022; de Jong et al., 2022)。Rabe (2023) 独立开发了具有公共实现的 GQA。其他工作探索了为了计算效率而对注意力头进行分组(Park et al., 2020; Luo et al., 2022; Ni et al., 2023),但没有专门关注决定内存带宽开销的键值头。

已经提出了许多其他方法来减少键和值的内存带宽开销以及参数量。Flash attention (Dao et al., 2022) 构建了注意力计算以避免实现二次注意力分数,从而减少内存并加快训练。量化(Dettmers et al., 2022; Frantar et al., 2022)通过降低精度来减小权重和激活(包括键和值)的大小。模型蒸馏(Hinton et al., 2015; Gou et al., 2021)则在给定精度下减小模型大小,使用从较大模型生成的数据来微调较小模型。层稀疏交叉注意力(de Jong et al., 2022)消除了大部分交叉注意力层,这些层是长输入的主要开销。推测采样(Chen et al., 2023; Leviathan et al., 2022)通过用较小模型提出多个标记,然后由较大模型并行评分,从而改善了内存带宽瓶颈。

最后,我们提出的上训练程序受到 Komatsuzaki et al. (2022) 的启发,该研究将标准 T5 检查点上训练为稀疏激活的混合专家模型。

5 结论

语言模型的推理成本高昂,主要是由于加载键和值的内存带宽开销。多查询注意力以牺牲模型容量和质量为代价减少了这种开销。我们提出将多头注意力模型转换为多查询模型,且仅需原始预训练计算量的一小部分。此外,我们引入了分组查询注意力,这是一种多查询和多头注意力的插值方法,在达到接近多查询注意力的速度的同时,实现了接近多头注意力的质量。

局限性 本文专注于改善加载键和值的内存带宽开销。这种开销在生成较长序列时最为重要,而长序列的质量本身就难以评估。对于摘要,我们采用了 Rouge 分数,我们知道这是一种有缺陷的评估方法,不能说明全部情况;因此,很难确定我们的权衡是否完全正确。由于计算量有限,我们也没有将我们的 XXL GQA 模型与从头训练的比较模型进行对比,因此我们不知道上训练与从头训练的相对性能。最后,我们仅在编码器-解码器模型上评估了上训练和 GQA 的影响。最近,仅解码器模型非常流行,由于这些模型没有单独的自注意力和交叉注意力,我们预计 GQA 相对于 MQA 将具有更强的优势。

致谢 我们感谢 Santiago Ontañón、Afroz Mohiuddin、William Cohen 以及 Google Research 的其他人提供的深刻建议和讨论。


(参考文献部分略,保持原格式)

硬核测试

正确率:0 / 5
1

根据论文,多查询注意力(MQA)的主要优势是什么?

2

在将多头注意力(MHA)检查点转换为多查询或分组查询注意力(GQA)时,论文推荐的最佳转换方法是什么?

3

关于分组查询注意力(GQA),下列描述正确的是:

4

论文中提到的“上训练”(Uptraining)过程,其计算量大约占原始预训练计算量的多少?

5

根据论文,为什么 GQA 不应用于编码器自注意力层?