图像胜过 16x16 个词:大规模图像识别的 Transformer
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov*, Dirk Weissenborn*, Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby*,†** *等同技术贡献,†等同指导 Google Research, Brain Team {adosovitskiy, neilhoulsby}@google.com
摘要
尽管 Transformer 架构已成为自然语言处理任务的事实标准,但其在计算机视觉领域的应用仍然有限。在视觉领域,注意力机制要么与卷积网络结合使用,要么在保持卷积网络整体结构不变的情况下替换其某些组件。我们证明,这种对 CNN 的依赖并非必要,直接应用于图像块序列的纯 Transformer 在图像分类任务上可以表现得非常好。当在海量数据上进行预训练并迁移到多个中小型图像识别基准(ImageNet、CIFAR-100、VTAB 等)时,Vision Transformer (ViT) 与最先进的卷积网络相比取得了优异的结果,同时训练所需的计算资源显著减少。
1 引言
基于自注意力(self-attention)的架构,特别是 Transformer (Vaswani et al., 2017),已成为自然语言处理 (NLP) 中的首选模型。主流方法是在大型文本语料库上进行预训练,然后在较小的特定任务数据集上进行微调 (Devlin et al., 2019)。得益于 Transformer 的计算效率和可扩展性,训练参数超过 100B 的超大规模模型已成为可能 (Brown et al., 2020; Lepikhin et al., 2020)。随着模型和数据集的增长,性能尚未出现饱和的迹象。
然而,在计算机视觉领域,卷积架构仍然占据主导地位 (LeCun et al., 1989; Krizhevsky et al., 2012; He et al., 2016)。受 NLP 成功的启发,多项研究尝试将类 CNN 架构与自注意力相结合 (Wang et al., 2018; Carion et al., 2020),有些研究则完全取代了卷积 (Ramachandran et al., 2019; Wang et al., 2020a)。后者的模型虽然在理论上是高效的,但由于使用了专门的注意力模式,尚未在现代硬件加速器上实现有效扩展。因此,在大规模图像识别中,经典的类 ResNet 架构仍然是最先进的 (Mahajan et al., 2018; Xie et al., 2020; Kolesnikov et al., 2020)。
受 Transformer 在 NLP 中扩展成功的启发,我们尝试将标准 Transformer 直接应用于图像,并尽可能减少修改。为此,我们将图像分割成块(patches),并将这些块的线性嵌入序列作为 Transformer 的输入。图像块的处理方式与 NLP 应用中的标记(词)相同。我们以监督方式对模型进行图像分类训练。
当在 ImageNet 等中等规模数据集上进行训练且没有强正则化时,这些模型产生的精度比同等规模的 ResNet 低几个百分点。这种看似令人沮丧的结果是可以预料的:Transformer 缺乏 CNN 所固有的某些归纳偏置(inductive biases),例如平移等变性(translation equivariance)和局部性(locality),因此在训练数据不足时无法很好地泛化。
然而,如果模型在更大的数据集(14M-300M 图像)上进行训练,情况就会发生变化。我们发现大规模训练胜过归纳偏置。我们的 Vision Transformer (ViT) 在经过足够规模的预训练并迁移到数据点较少的任务时,取得了优异的结果。当在公开的 ImageNet-21k 数据集或内部的 JFT-300M 数据集上进行预训练时,ViT 在多个图像识别基准上接近或超越了最先进水平。特别是,最佳模型在 ImageNet 上达到了 88.55% 的准确率,在 ImageNet-ReaL 上达到 90.72%,在 CIFAR-100 上达到 94.55%,在 19 个任务的 VTAB 套件上达到 77.63%。
2 相关工作
Transformer 由 Vaswani et al. (2017) 提出用于机器翻译,并已成为许多 NLP 任务中最先进的方法。大型基于 Transformer 的模型通常在大型语料库上进行预训练,然后针对手头的任务进行微调:BERT (Devlin et al., 2019) 使用去噪自监督预训练任务,而 GPT 系列工作使用语言建模作为其预训练任务 (Radford et al., 2018; 2019; Brown et al., 2020)。
将自注意力直接应用于图像需要每个像素关注其他所有像素。由于像素数量的二次方成本,这无法扩展到实际的输入尺寸。因此,为了在图像处理的背景下应用 Transformer,过去已经尝试了几种近似方法。Parmar et al. (2018) 仅对每个查询像素在局部邻域内应用自注意力,而不是全局应用。这种局部多头点积自注意力块可以完全取代卷积 (Hu et al., 2019; Ramachandran et al., 2019; Zhao et al., 2020)。在另一项工作中,Sparse Transformers (Child et al., 2019) 采用可扩展的全局自注意力近似方法,以便应用于图像。扩展注意力的另一种方法是将其应用于不同大小的块 (Weissenborn et al., 2019),在极端情况下仅沿单个轴应用 (Ho et al., 2019; Wang et al., 2020a)。许多这些专门的注意力架构在计算机视觉任务中展示了有希望的结果,但需要复杂的工程才能在硬件加速器上高效实现。
与我们最相关的是 Cordonnier et al. (2020) 的模型,它从输入图像中提取 大小的块,并在其上应用完整的自注意力。该模型与 ViT 非常相似,但我们的工作更进一步,证明了大规模预训练使得普通 Transformer 能够与最先进的 CNN 竞争(甚至更好)。此外,Cordonnier et al. (2020) 使用了 像素的小块大小,这使得该模型仅适用于小分辨率图像,而我们也能处理中等分辨率图像。
将卷积神经网络 (CNN) 与各种形式的自注意力相结合也引起了极大的兴趣,例如通过增强特征图进行图像分类 (Bello et al., 2019),或者通过使用自注意力进一步处理 CNN 的输出,例如用于目标检测 (Hu et al., 2018; Carion et al., 2020)、视频处理 (Wang et al., 2018; Sun et al., 2019)、图像分类 (Wu et al., 2020)、无监督目标发现 (Locatello et al., 2020) 或统一的文本-视觉任务 (Chen et al., 2020c; Lu et al., 2019; Li et al., 2019)。
另一个最近的相关模型是 image GPT (iGPT) (Chen et al., 2020a),它在降低图像分辨率和色彩空间后将 Transformer 应用于图像像素。该模型以无监督方式作为生成模型进行训练,所得表示随后可以进行微调或线性探测以获得分类性能,在 ImageNet 上实现了 72% 的最高准确率。
我们的工作增加了越来越多的论文集合,这些论文探索了比标准 ImageNet 数据集更大规模的图像识别。使用额外的数据源可以实现标准基准上的最先进结果 (Mahajan et al., 2018; Touvron et al., 2019; Xie et al., 2020)。此外,Sun et al. (2017) 研究了 CNN 性能如何随数据集大小扩展,Kolesnikov et al. (2020); Djolonga et al. (2020) 对从 ImageNet-21k 和 JFT-300M 等大规模数据集进行的 CNN 迁移学习进行了实证探索。我们也关注这两个数据集,但训练的是 Transformer,而不是先前工作中使用的基于 ResNet 的模型。

3 方法
在模型设计中,我们尽可能遵循原始的 Transformer (Vaswani et al., 2017)。这种刻意简单的设置的一个优势是,可扩展的 NLP Transformer 架构及其高效实现几乎可以直接使用。
3.1 Vision Transformer (ViT)
模型概览如图 1 所示。标准 Transformer 接收 1D 标记嵌入序列作为输入。为了处理 2D 图像,我们将图像 重塑为展平的 2D 块序列 ,其中 是原始图像的分辨率, 是通道数, 是每个图像块的分辨率,而 是生成的块数量,它也作为 Transformer 的有效输入序列长度。Transformer 在其所有层中都使用恒定的潜在向量大小 ,因此我们展平块并通过可训练的线性投影映射到 维度(公式 1)。我们将此投影的输出称为块嵌入。
类似于 BERT 的 [class] 标记,我们在嵌入块序列前添加一个可学习的嵌入(),其在 Transformer 编码器输出端的状态()作为图像表示 (公式 4)。在预训练和微调期间,分类头都附加到 上。分类头在预训练时由具有一个隐藏层的 MLP 实现,在微调时由单个线性层实现。
位置嵌入被添加到块嵌入中以保留位置信息。我们使用标准的 1D 可学习位置嵌入,因为我们没有观察到使用更高级的 2D 感知位置嵌入带来的显著性能提升(附录 D.4)。生成的嵌入向量序列作为编码器的输入。
Transformer 编码器 (Vaswani et al., 2017) 由多头自注意力(MSA,见附录 A)和 MLP 块交替层组成(公式 2、3)。层归一化 (LN) 应用于每个块之前,残差连接应用于每个块之后 (Wang et al., 2019; Baevski & Auli, 2019)。
MLP 包含两层,具有 GELU 非线性。
归纳偏置。我们注意到 Vision Transformer 比 CNN 具有更少的图像特定归纳偏置。在 CNN 中,局部性、二维邻域结构和平移等变性被烘焙到整个模型中的每一层。在 ViT 中,只有 MLP 层是局部的和平移等变的,而自注意力层是全局的。二维邻域结构的使用非常节制:在模型开始时通过将图像切割成块,以及在微调时用于调整不同分辨率图像的位置嵌入(如下所述)。除此之外,初始化时的位置嵌入不携带关于块的 2D 位置的信息,块之间的所有空间关系都必须从头开始学习。
混合架构。作为原始图像块的替代方案,输入序列可以由 CNN 的特征图形成 (LeCun et al., 1989)。在这种混合模型中,块嵌入投影 (公式 1)应用于从 CNN 特征图提取的块。作为一种特殊情况,块可以具有 的空间大小,这意味着输入序列是通过简单地展平特征图的空间维度并投影到 Transformer 维度而获得的。分类输入嵌入和位置嵌入按上述方式添加。
3.2 微调和更高分辨率
通常,我们在大型数据集上预训练 ViT,并微调到(较小的)下游任务。为此,我们移除预训练的预测头,并附加一个零初始化的 前馈层,其中 是下游类别的数量。以比预训练更高的分辨率进行微调通常是有益的 (Touvron et al., 2019; Kolesnikov et al., 2020)。当输入更高分辨率的图像时,我们保持块大小不变,这会导致更大的有效序列长度。Vision Transformer 可以处理任意序列长度(受内存限制),但是,预训练的位置嵌入可能不再有意义。因此,我们根据预训练位置嵌入在原始图像中的位置对其进行 2D 插值。请注意,这种分辨率调整和块提取是手动将关于图像 2D 结构的归纳偏置注入 Vision Transformer 的唯一几点。
4 实验
我们评估了 ResNet、Vision Transformer (ViT) 和混合架构的表示学习能力。为了了解每个模型的数据需求,我们在不同规模的数据集上进行预训练,并评估许多基准任务。在考虑预训练模型的计算成本时,ViT 表现非常优异,以较低的预训练成本在大多数识别基准上达到了最先进水平。最后,我们使用自监督进行了一个小实验,并表明自监督 ViT 对未来具有前景。
4.1 设置
数据集。为了探索模型的可扩展性,我们使用具有 1k 类和 1.3M 图像的 ILSVRC-2012 ImageNet 数据集(我们在下文中将其称为 ImageNet)、其超集 ImageNet-21k(具有 21k 类和 14M 图像)(Deng et al., 2009),以及 JFT(具有 18k 类和 303M 高分辨率图像)(Sun et al., 2017)。我们按照 Kolesnikov et al. (2020) 的方法,根据下游任务的测试集对预训练数据集进行了去重。我们将这些数据集上训练的模型迁移到几个基准任务:ImageNet 的原始验证标签和清理后的 ReaL 标签 (Beyer et al., 2020)、CIFAR-10/100 (Krizhevsky, 2009)、Oxford-IIIT Pets (Parkhi et al., 2012) 和 Oxford Flowers-102 (Nilsback & Zisserman, 2008)。对于这些数据集,预处理遵循 Kolesnikov et al. (2020)。
| 模型 | 层数 | 隐藏大小 | MLP 大小 | 头数 | 参数量 |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
表 1:Vision Transformer 模型变体的详细信息。
我们还在 19 个任务的 VTAB 分类套件上进行了评估 (Zhai et al., 2019b)。VTAB 评估对不同任务的低数据迁移,每个任务使用 1,000 个训练示例。任务分为三组:Natural(如上述任务、Pets、CIFAR 等)、Specialized(医学和卫星图像)和 Structured(需要几何理解的任务,如定位)。
模型变体。我们将 ViT 配置基于 BERT 中使用的配置(Devlin et al., 2019),如表 1 所示。“Base”和“Large”模型直接采用自 BERT,我们添加了更大的“Huge”模型。在下文中,我们使用简写符号来表示模型大小和输入块大小:例如,ViT-L/16 表示具有 输入块大小的“Large”变体。请注意,Transformer 的序列长度与块大小的平方成反比,因此块大小较小的模型计算成本更高。
对于基线 CNN,我们使用 ResNet (He et al., 2016),但将批归一化层 (Ioffe & Szegedy, 2015) 替换为组归一化 (Wu & He, 2018),并使用了标准化卷积 (Qiao et al., 2019)。这些修改改进了迁移 (Kolesnikov et al., 2020),我们将修改后的模型称为“ResNet (BiT)”。对于混合模型,我们将中间特征图输入到块大小为“一个像素”的 ViT 中。为了试验不同的序列长度,我们要么 (i) 采用常规 ResNet50 的第 4 阶段输出,要么 (ii) 移除第 4 阶段,将相同数量的层放置在第 3 阶段(保持总层数不变),并采用这个扩展的第 3 阶段的输出。选项 (ii) 导致序列长度增加了 4 倍,并且是一个更昂贵的 ViT 模型。
训练与微调。我们使用 Adam (Kingma & Ba, 2015) 训练所有模型(包括 ResNet),参数为 ,批大小为 4096,并应用 0.1 的高权重衰减,我们发现这对所有模型的迁移都很有用(附录 D.1 显示,与通常的做法相反,在我们的设置中,Adam 对 ResNet 的效果略好于 SGD)。我们使用线性学习率预热和衰减,详情见附录 B.1。对于微调,我们对所有模型使用带有动量的 SGD,批大小为 512,见附录 B.1.1。对于表 2 中的 ImageNet 结果,我们在更高分辨率下进行了微调:ViT-L/16 为 512,ViT-H/14 为 518,并且还使用了 Polyak & Juditsky (1992) 的平均法,因子为 0.9999 (Ramachandran et al., 2019; Wang et al., 2020b)。
指标。我们通过少样本(few-shot)或微调准确率报告下游数据集的结果。微调准确率捕获了模型在相应数据集上微调后的性能。少样本准确率是通过解决一个正则化最小二乘回归问题获得的,该问题将训练图像子集的(冻结)表示映射到 目标向量。这种公式允许我们以闭式形式恢复精确解。虽然我们主要关注微调性能,但有时我们使用线性少样本准确率进行快速的即时评估,因为微调的成本太高。
4.2 与最先进水平的比较
我们首先将我们最大的模型——ViT-H/14 和 ViT-L/16——与文献中最先进的 CNN 进行比较。第一个比较点是 Big Transfer (BiT) (Kolesnikov et al., 2020),它使用大型 ResNet 执行监督迁移学习。第二个是 Noisy Student (Xie et al., 2020),这是一个在 ImageNet 和 JFT-300M 上使用半监督学习训练的大型 EfficientNet,标签已被移除。目前,Noisy Student 是 ImageNet 上最先进的,BiT-L 是此处报告的其他数据集上最先进的。所有模型都在 TPUv3 硬件上进行了训练,我们报告了预训练每个模型所花费的 TPUv3 核心天数,即用于训练的 TPU v3 核心数量(每个芯片 2 个)乘以训练时间(以天为单位)。
表 2 显示了结果。在 JFT-300M 上预训练的较小 ViT-L/16 模型在所有任务上都优于 BiT-L(它是在相同数据集上预训练的),同时训练所需的计算资源显著减少。较大的模型 ViT-H/14 进一步提高了性能,特别是在更具挑战性的数据集上——ImageNet、CIFAR-100 和 VTAB 套件。有趣的是,该模型预训练所需的计算量仍然比之前最先进的水平少得多。然而,我们注意到预训练效率可能不仅受架构选择的影响,还受其他参数的影响,例如训练计划、优化器、权重衰减等。我们在第 4.4 节中提供了不同架构的性能与计算量的对照研究。最后,在公开的 ImageNet-21k 数据集上预训练的 ViT-L/16 模型在大多数数据集上也表现良好,同时预训练所需的资源更少:它可以使用带有 8 个核心的标准云 TPUv3 在大约 30 天内完成训练。
图 2 将 VTAB 任务分解为各自的组,并与该基准上的先前 SOTA 方法进行了比较:BiT、VIVI(在 ImageNet 和 Youtube 上共同训练的 ResNet,Tschannen et al., 2020)和 S4L(在 ImageNet 上进行监督加半监督学习,Zhai et al., 2019a)。ViT-H/14 在 Natural 和 Structured 任务上优于 BiT-R152x4 和其他方法。在 Specialized 任务上,前两个模型的性能相似。
| Ours-JFT (ViT-H/14) | Ours-JFT (ViT-L/16) | Ours-I21k (ViT-L/16) | BiT-L (ResNet152x4) | Noisy Student (EfficientNet-L2) | |
|---|---|---|---|---|---|
| ImageNet | 88.55 ± 0.04 | 87.76 ± 0.03 | 85.30 ± 0.02 | 87.54 ± 0.02 | 88.4/88.5* |
| ImageNet ReaL | 90.72 ± 0.05 | 90.54 ± 0.03 | 88.62 ± 0.05 | 90.54 | 90.55 |
| CIFAR-10 | 99.50 ± 0.06 | 99.42 ± 0.03 | 99.15 ± 0.03 | 99.37 ± 0.06 | - |
| CIFAR-100 | 94.55 ± 0.04 | 93.90 ± 0.05 | 93.25 ± 0.05 | 93.51 ± 0.08 | - |
| Oxford-IIIT Pets | 97.56 ± 0.03 | 97.32 ± 0.11 | 94.67 ± 0.15 | 96.62 ± 0.23 | - |
| Oxford Flowers-102 | 99.68 ± 0.02 | 99.74 ± 0.00 | 99.61 ± 0.02 | 99.63 ± 0.03 | - |
| VTAB (19 tasks) | 77.63 ± 0.23 | 76.28 ± 0.46 | 72.72 ± 0.21 | 76.29 ± 1.70 | - |
| TPUv3-core-days | 2.5k | 0.68k | 0.23k | 9.9k | 12.3k |
表 2:与流行图像分类基准上最先进水平的比较。我们报告了三次微调运行的平均准确率和标准差。在 JFT-300M 数据集上预训练的 Vision Transformer 模型在所有数据集上都优于基于 ResNet 的基线,同时预训练所需的计算资源显著减少。在较小的公开 ImageNet-21k 数据集上预训练的 ViT 也表现良好。*Touvron et al. (2020) 报告了略微改进的 88.5% 结果。

4.3 预训练数据需求
Vision Transformer 在大型 JFT-300M 数据集上预训练时表现良好。与 ResNet 相比,视觉归纳偏置更少,数据集大小有多关键?我们进行了两系列实验。
首先,我们在规模不断增加的数据集上预训练 ViT 模型:ImageNet、ImageNet-21k 和 JFT-300M。为了提高在较小数据集上的性能,我们优化了三个基本正则化参数——权重衰减、dropout 和标签平滑。图 3 显示了微调到 ImageNet 后的结果(其他数据集的结果显示在表 5 中)。当在最小的数据集 ImageNet 上预训练时,尽管有(适度的)正则化,ViT-Large 模型的表现不如 ViT-Base 模型。通过 ImageNet-21k 预训练,它们的性能相似。只有在 JFT-300M 下,我们才能看到更大模型的全部好处。图 3 还显示了不同规模 BiT 模型所覆盖的区域。BiT CNN 在 ImageNet 上优于 ViT,但在更大的数据集上,ViT 超越了它们。
其次,我们在 9M、30M 和 90M 的随机子集以及完整的 JFT-300M 数据集上训练我们的模型。我们不对较小的子集进行额外的正则化,并对所有设置使用相同的超参数。通过这种方式,我们评估了内在的模型属性,而不是正则化的效果。然而,我们确实使用了提前停止(early-stopping),并报告了训练期间达到的最佳验证准确率。为了节省计算量,我们报告了少样本线性准确率,而不是完整的微调准确率。图 4 包含了结果。在较小的数据集上,Vision Transformer 比具有相当计算成本的 ResNet 过拟合更严重。例如,ViT-B/32 比 ResNet50 稍快;它在 9M 子集上表现差得多,但在 90M+ 子集上表现更好。ResNet152x2 和 ViT-L/16 也是如此。这一结果强化了这样一种直觉:卷积归纳偏置对于较小的数据集很有用,但对于较大的数据集,直接从数据中学习相关模式就足够了,甚至是有益的。
总体而言,ImageNet 上的少样本结果(图 4)以及 VTAB 上的低数据结果(表 2)对于极低数据迁移似乎很有希望。对 ViT 少样本属性的进一步分析是未来工作的一个令人兴奋的方向。

4.4 扩展研究
我们通过评估从 JFT-300M 的迁移性能,对不同模型进行了受控的扩展研究。在这种设置下,数据大小不会成为模型性能的瓶颈,我们评估了每个模型的性能与预训练成本。模型集包括:7 个 ResNet,R50x1、R50x2、R101x1、R152x1、R152x2,预训练 7 个 epoch,加上预训练 14 个 epoch 的 R152x2 和 R200x3;6 个 Vision Transformer,ViT-B/32、B/16、L/32、L/16,预训练 7 个 epoch,加上预训练 14 个 epoch 的 L/16 和 H/14;以及 5 个混合架构,R50+ViT-B/32、B/16、L/32、L/16,预训练 7 个 epoch,加上预训练 14 个 epoch 的 R50+ViT-L/16(对于混合架构,模型名称末尾的数字代表的不是块大小,而是 ResNet 主干中的总下采样率)。
图 5 包含了迁移性能与总预训练计算量的关系(计算成本的详细信息见附录 D.5)。每个模型的详细结果在附录的表 6 中提供。可以观察到一些模式。首先,Vision Transformer 在性能/计算权衡上主导了 ResNet。ViT 使用大约 更少的计算量来达到相同的性能(5 个数据集的平均值)。其次,混合架构在较小的计算预算下略微优于 ViT,但差距在较大的模型上消失了。这个结果有点令人惊讶,因为人们可能期望卷积局部特征处理在任何规模上都能辅助 ViT。第三,Vision Transformer 在尝试的范围内似乎没有饱和,这激发了未来的扩展努力。
4.5 检查 Vision Transformer

为了开始理解 Vision Transformer 如何处理图像数据,我们分析了其内部表示。Vision Transformer 的第一层将展平的块线性投影到低维空间(公式 1)。图 7(左)显示了学习到的嵌入滤波器的前几个主成分。这些成分类似于每个块内精细结构的低维表示的合理基函数。
投影后,将学习到的位置嵌入添加到块表示中。图 7(中心)显示模型学习在位置嵌入的相似性中编码图像内的距离,即更接近的块倾向于具有更相似的位置嵌入。此外,行-列结构出现;同一行/列中的块具有相似的嵌入。最后,对于较大的网格,有时会出现正弦结构(附录 D)。位置嵌入学习表示 2D 图像拓扑这一事实解释了为什么手工制作的 2D 感知嵌入变体没有带来改进(附录 D.4)。
自注意力允许 ViT 即使在最低层也能整合整个图像的信息。我们调查了网络在多大程度上利用了这种能力。具体来说,我们根据注意力权重计算了信息整合的图像空间平均距离(图 7,右)。这种“注意力距离”类似于 CNN 中的感受野大小。
我们发现一些头在最低层就已经关注了大部分图像,这表明模型确实使用了全局整合信息的能力。其他注意力头在低层具有持续较小的注意力距离。这种高度局部的注意力在应用 ResNet 之后再应用 Transformer 的混合模型中不太明显(图 7,右),这表明它可能发挥了与 CNN 中早期卷积层类似的功能。此外,注意力距离随着网络深度而增加。总体而言,我们发现模型关注对分类具有语义相关性的图像区域(图 6)。
4.6 自监督
Transformer 在 NLP 任务上表现出令人印象深刻的性能。然而,它们的大部分成功不仅源于其出色的可扩展性,还源于大规模自监督预训练 (Devlin et al., 2019; Radford et al., 2018)。我们也对用于自监督的*掩码块预测(masked patch prediction)*进行了初步探索,模仿了 BERT 中使用的掩码语言建模任务。通过自监督预训练,我们较小的 ViT-B/16 模型在 ImageNet 上达到了 79.9% 的准确率,比从头开始训练提高了 2%,但仍比监督预训练落后 4%。附录 B.1.2 包含更多详细信息。我们将对比预训练(Chen et al., 2020b; He et al., 2020; Bachman et al., 2019; Hénaff et al., 2020)的探索留给未来的工作。

5 结论
我们探索了 Transformer 在图像识别中的直接应用。与之前在计算机视觉中使用自注意力的工作不同,除了初始块提取步骤外,我们没有在架构中引入图像特定的归纳偏置。相反,我们将图像解释为块序列,并由 NLP 中使用的标准 Transformer 编码器进行处理。这种简单但可扩展的策略在与大型数据集上的预训练相结合时效果出奇地好。因此,Vision Transformer 在许多图像分类数据集上匹配或超过了最先进水平,同时预训练相对便宜。
虽然这些初步结果令人鼓舞,但许多挑战仍然存在。一个是将 ViT 应用于其他计算机视觉任务,例如检测和分割。我们的结果,加上 Carion et al. (2020) 中的结果,表明了这种方法的前景。另一个挑战是继续探索自监督预训练方法。我们的初步实验显示了自监督预训练带来的改进,但自监督预训练与大规模监督预训练之间仍然存在巨大差距。最后,ViT 的进一步扩展可能会带来性能的提升。
致谢
这项工作是在柏林、苏黎世和阿姆斯特丹完成的。我们感谢 Google 的许多同事提供的帮助,特别是 Andreas Steiner 在基础设施和代码开源方面的关键帮助;Joan Puigcerver 和 Maxim Neumann 在大规模训练基础设施方面的帮助;Dmitry Lepikhin、Aravindh Mahendran、Daniel Keysers、Mario Lučić、Noam Shazeer、Ashish Vaswani 和 Colin Raffel 的有益讨论。
参考文献
(此处省略参考文献列表,以保持翻译的完整性与格式)
附录
A#### A 多头自注意力
标准 qkv 自注意力(SA,Vaswani et al. (2017))是神经网络架构中一种流行的构建块。对于输入序列 中的每个元素,我们计算序列中所有值 的加权和。注意力权重 基于序列中两个元素及其各自的查询 和键 表示之间的成对相似度。
多头自注意力(MSA)是 SA 的一种扩展,我们在其中并行运行 个自注意力操作(称为“头”),并投影它们的拼接输出。为了在改变 时保持计算量和参数数量不变,(公式 5)通常设置为 。
B 实验细节
B.1 训练
表 3 总结了我们不同模型的训练设置。我们发现,在 ImageNet 上从头开始训练模型时,强正则化是关键。Dropout(如果使用)应用于除 qkv 投影之外的每个密集层,并直接应用于位置嵌入到块嵌入的相加之后。混合模型使用与其 ViT 对应模型完全相同的设置进行训练。最后,所有训练均在分辨率 224 下完成。
B.1.1 微调
我们使用动量为 0.9 的 SGD 微调所有 ViT 模型。我们对学习率进行了小范围的网格搜索,参见表 4 中的学习率范围。为此,我们使用训练集的小子集(Pets 和 Flowers 为 10%,CIFAR 为 2%,ImageNet 为 1%)作为开发集,并在剩余数据上进行训练。对于最终结果,我们在整个训练集上进行训练,并在相应的测试数据上进行评估。对于 ResNet 和混合模型的微调,我们使用完全相同的设置,唯一的例外是 ImageNet,我们在学习率扫描中添加了另一个值 0.06。此外,对于 ResNet,我们还运行了 Kolesnikov et al. (2020) 的设置,并从该运行和我们的扫描中选择了最佳结果。最后,除非另有说明,所有微调实验均在 384 分辨率下运行(在与训练不同的分辨率下运行微调是常见的做法 (Kolesnikov et al., 2020))。
当将 ViT 模型迁移到另一个数据集时,我们移除整个头(两个线性层),并将其替换为输出目标数据集所需类别数量的单个零初始化线性层。我们发现这比简单地重新初始化最后一层更稳健。
对于 VTAB,我们遵循 Kolesnikov et al. (2020) 中的协议,并对所有任务使用相同的超参数设置。我们使用 0.01 的学习率并训练 2500 步(表 4)。我们通过对两个学习率和两个计划进行小范围扫描,并选择在 200 个示例验证集上具有最高 VTAB 分数的设置来选择此设置。我们遵循 Kolesnikov et al. (2020) 中使用的预处理,只是我们不使用特定于任务的输入分辨率。相反,我们发现 Vision Transformer 在所有任务的高分辨率()下受益最大。
B.1.2 自监督
我们采用掩码块预测(masked patch prediction)目标进行初步自监督实验。为此,我们通过以下方式破坏 50% 的块嵌入:用可学习的 [mask] 嵌入替换(80%)、替换为随机的其他块嵌入(10%)或保持原样(10%)。此设置与 Devlin et al. (2019) 用于语言的设置非常相似。最后,我们使用各自的块表示预测每个被破坏块的 3 位平均颜色(即总共 512 种颜色)。
我们在 JFT 上以 4096 的批大小训练了我们的自监督模型 1M 步(约 14 个 epoch)。我们使用 Adam,基础学习率为 ,预热 10k 步,并使用余弦学习率衰减。作为预训练的预测目标,我们尝试了以下设置:1) 仅预测平均 3 位颜色(即 512 种颜色的 1 次预测),2) 并行预测 块的 缩小版本,使用 3 位颜色(即 512 种颜色的 16 次预测),3) 使用 L2 对完整块进行回归(即对 3 个 RGB 通道进行 256 次回归)。令人惊讶的是,我们发现所有方法都运行得相当好,尽管 L2 稍差。我们仅报告选项 1) 的最终结果,因为它显示了最佳的少样本性能。我们还尝试了 Devlin et al. (2019) 使用的 15% 破坏率,但结果在我们的少样本指标上也稍差。
最后,我们想指出,我们的掩码块预测实例化不需要如此大量的预训练,也不需要像 JFT 这样的大型数据集来在 ImageNet 分类上获得类似的性能提升。也就是说,我们观察到在 100k 预训练步后下游性能的回报递减,并且在 ImageNet 上预训练时看到了类似的收益。
C 附加结果
我们报告了与论文中呈现的图表相对应的详细结果。表 5 对应于论文中的图 3,显示了在规模不断增加的数据集(ImageNet、ImageNet-21k 和 JFT-300M)上预训练的不同 ViT 模型的迁移性能。表 6 对应于论文中的图 5,显示了 ViT、ResNet 和不同规模混合模型的迁移性能,以及它们预训练的估计计算成本。
| 模型 | 数据集 | Epochs | 基础 LR | LR 衰减 | 权重衰减 | Dropout |
|---|---|---|---|---|---|---|
| ViT-B/{16,32} | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| ViT-L/32 | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| ViT-L/16 | JFT-300M | 7/14 | linear | 0.1 | 0.0 | |
| ViT-H/14 | JFT-300M | 14 | linear | 0.1 | 0.0 | |
| R50x{1,2} | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| R101x1 | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| R152x{1,2} | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| R50+ViT-B/{16,32} | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| R50+ViT-L/32 | JFT-300M | 7 | linear | 0.1 | 0.0 | |
| R50+ViT-L/16 | JFT-300M | 7/14 | linear | 0.1 | 0.0 | |
| ViT-B/{16,32} | ImageNet-21k | 90 | linear | 0.03 | 0.1 | |
| ViT-L/{16,32} | ImageNet-21k | 30/90 | linear | 0.03 | 0.1 | |
| ViT-* | ImageNet | 300 | cosine | 0.3 | 0.1 |
表 3:训练超参数。所有模型均以 4096 的批大小和 10k 步的学习率预热进行训练。对于 ImageNet,我们发现额外应用全局范数为 1 的梯度裁剪是有益的。训练分辨率为 224。
| 数据集 | 步数 | 基础 LR |
|---|---|---|
| ImageNet | 20 000 | {0.003, 0.01, 0.03, 0.06} |
| CIFAR100 | 10 000 | {0.001, 0.003, 0.01, 0.03} |
| CIFAR10 | 10 000 | {0.001, 0.003, 0.01, 0.03} |
| Oxford-IIIT Pets | 500 | {0.001, 0.003, 0.01, 0.03} |
| Oxford Flowers-102 | 500 | {0.001, 0.003, 0.01, 0.03} |
| VTAB (19 tasks) | 2 500 | 0.01 |
表 4:微调超参数。所有模型均使用余弦学习率衰减、512 的批大小、无权重衰减和全局范数为 1 的梯度裁剪进行微调。除非另有说明,微调分辨率为 384。
D 附加分析
D.1 SGD 与 Adam 用于 ResNet
ResNet 通常使用 SGD 训练,我们使用 Adam 作为优化器是非常规的。在这里,我们展示了促使这一选择的实验。即,我们比较了在 JFT 上使用 SGD 和 Adam 预训练的两个 ResNet(50x1 和 152x2)的微调性能。对于 SGD,我们使用 Kolesnikov et al. (2020) 推荐的超参数。结果如表 7 所示。Adam 预训练在大多数数据集上以及平均水平上都优于 SGD 预训练。这证明了选择 Adam 作为在 JFT 上预训练 ResNet 的优化器是合理的。请注意,绝对数字低于 Kolesnikov et al. (2020) 报告的数字,因为我们仅预训练了 7 个 epoch,而不是 30 个。
| ResNet50 | ResNet152x2 | |||
|---|---|---|---|---|
| 数据集 | Adam | SGD | Adam | SGD |
| ImageNet | 77.54 | 78.24 | 84.97 | 84.37 |
| CIFAR10 | 97.67 | 97.46 | 99.06 | 99.07 |
| CIFAR100 | 86.07 | 85.17 | 92.05 | 91.06 |
| Oxford-IIIT Pets | 91.11 | 91.00 | 95.37 | 94.79 |
| Oxford Flowers-102 | 94.26 | 92.06 | 98.62 | 99.32 |
| 平均 | 89.33 | 88.79 | 94.01 | 93.72 |
表 7:使用 Adam 和 SGD 预训练的 ResNet 模型的微调。
D.2 Transformer 形状
我们对 Transformer 架构的不同维度进行了缩放消融实验,以找出最适合扩展到超大规模模型的维度。图 8 显示了 ImageNet 上不同配置的 5-shot 性能。所有配置均基于具有 8 层、、 和块大小为 32 的 ViT 模型,这是所有线的交点。我们可以看到,缩放深度带来的改进最大,在 64 层之前清晰可见。然而,在 16 层之后已经可以看到收益递减。有趣的是,缩放网络的宽度似乎导致的变化最小。减小块大小从而增加有效序列长度显示出令人惊讶的稳健改进,而无需引入参数。这些发现表明,计算量可能是比参数数量更好的性能预测指标,并且缩放应该优先考虑深度而不是宽度(如果有的话)。总体而言,我们发现按比例缩放所有维度会带来稳健的改进。
D.3 头类型和 [class] 标记
为了尽可能接近原始 Transformer 模型,我们使用了额外的 [class] 标记,将其作为图像表示。然后,该标记的输出通过具有单隐藏层中 tanh 非线性的多层感知机(MLP)转换为类预测。
这种设计继承自文本的 Transformer 模型,我们在整篇论文中都使用了它。最初尝试仅使用图像块嵌入,对其进行全局平均池化(GAP),然后进行线性分类器——就像 ResNet 的最终特征图一样——表现非常差。然而,我们发现这既不是因为额外的标记,也不是因为 GAP 操作。相反,性能差异完全由对不同学习率的要求来解释,参见图 9。
D.4 位置嵌入
我们对使用位置嵌入编码空间信息的不同方式进行了消融实验。我们尝试了以下情况:
- 不提供位置信息:将输入视为块的集合(bag of patches)。
- 1 维位置嵌入:将输入视为光栅顺序的块序列(本文所有其他实验的默认设置)。
- 2 维位置嵌入:将输入视为二维块网格。在这种情况下,学习两组嵌入,每组对应一个轴,X-嵌入和 Y-嵌入,每组大小为 。然后,根据输入中路径上的坐标,我们拼接 X 和 Y 嵌入以获得该块的最终位置嵌入。
- 相对位置嵌入:考虑块之间的相对距离来编码空间信息,而不是它们的绝对位置。为此,我们使用 1 维相对注意力,其中我们定义所有可能的块对的相对距离。因此,对于每个给定的对(一个作为查询,另一个作为注意力机制中的键/值),我们有一个偏移量 ,其中每个偏移量都与一个嵌入相关联。然后,我们简单地运行额外的注意力,其中我们使用原始查询(查询的内容),但使用相对位置嵌入作为键。然后,我们在应用 softmax 之前,将来自相对注意力的 logits 作为偏置项添加到主注意力(基于内容的注意力)的 logits 中。
除了编码空间信息的不同方式外,我们还尝试了将此信息纳入我们模型的不同方式。对于 1 维和 2 维位置嵌入,我们尝试了三种不同的情况:(1) 在模型主干之后、输入 Transformer 编码器之前将位置嵌入添加到输入中(本文所有其他实验的默认设置);(2) 在每一层的开始学习并添加位置嵌入到输入中;(3) 在每一层的开始将学习到的位置嵌入添加到输入中(层之间共享)。
表 8 总结了 ViT-B/16 模型上这项消融研究的结果。正如我们所见,虽然没有位置嵌入的模型与有位置嵌入的模型之间存在巨大差距,但编码位置信息的不同方式之间几乎没有区别。我们推测,由于我们的 Transformer 编码器在块级输入上运行,而不是像素级,因此如何编码空间信息的差异不太重要。更准确地说,在块级输入中,空间维度比原始像素级输入小得多,例如 而不是 ,并且学习以这种分辨率表示空间关系对于这些不同的位置编码策略同样容易。即便如此,网络学习到的位置嵌入相似性的特定模式取决于训练超参数(图 10)。
| Pos. Emb. | Default/Stem | Every Layer | Every Layer-Shared |
|---|---|---|---|
| No Pos. Emb. | 0.61382 | N/A | N/A |
| 1-D Pos. Emb. | 0.64206 | 0.63964 | 0.64292 |
| 2-D Pos. Emb. | 0.64001 | 0.64046 | 0.64022 |
| Rel. Pos. Emb. | 0.64032 | N/A | N/A |
表 8:在 ImageNet 5-shot 线性评估上使用 ViT-B/16 模型的位置嵌入消融研究结果。
D.5 经验计算成本
我们也对架构在我们的硬件上的实际速度感兴趣,由于通道宽度和缓存大小等细节,这并不总是能通过理论 FLOPs 很好地预测。为此,我们在 TPUv3 加速器上对我们感兴趣的主要模型进行了推理速度计时;推理和反向传播速度之间的差异是一个与模型无关的常数因子。
图 12(左)显示了在各种输入尺寸下,一个核心每秒可以处理多少张图像。每一个点都指在广泛的批大小范围内测得的峰值性能。可以看出,ViT 随图像尺寸的理论双二次缩放仅在最大分辨率下的最大模型上才刚刚开始发生。
另一个感兴趣的数量是每个模型可以容纳在核心上的最大批大小,越大越有利于扩展到大型数据集。图 12(右)显示了同一组模型的此数量。这表明大型 ViT 模型在内存效率方面比 ResNet 模型具有明显的优势。
D.6 轴向注意力
轴向注意力(Huang et al., 2020; Ho et al., 2019)是一种简单而有效的技术,用于对组织为多维张量的海量输入运行自注意力。轴向注意力的总体思想是执行多个注意力操作,每个操作沿着输入张量的单个轴,而不是对输入的展平版本应用 1 维注意力。在轴向注意力中,每个注意力沿着特定轴混合信息,同时保持沿其他轴的信息独立。沿着这条线,Wang et al. (2020b) 提出了 AxialResNet 模型,其中 ResNet50 中所有内核大小为 的卷积都被轴向自注意力取代,即行和列注意力,并辅以相对位置编码。我们已经实现了 AxialResNet 作为基线模型。
此外,我们修改了 ViT 以处理 2 维形状的输入,而不是 1 维块序列,并结合了轴向 Transformer 块,其中我们不是自注意力后跟 MLP,而是有一个行自注意力加 MLP,后跟列自注意力加 MLP。
图 13 展示了 Axial ResNet、Axial-ViT-B/32 和 Axial-ViT-B/16 在 ImageNet 5-shot 线性上的性能(在 JFT 数据集上预训练),与预训练计算量(以 FLOPs 数量和推理时间(每秒示例数)衡量)的关系。正如我们所见,Axial-ViT-B/32 和 Axial-ViT-B/16 在性能方面都优于它们的 ViT-B 对应模型,但这是以更多计算为代价的。这是因为在 Axial-ViT 模型中,每个具有全局自注意力的 Transformer 块都被两个轴向 Transformer 块取代,一个具有行自注意力,一个具有列自注意力,尽管自注意力在轴向情况下运行的序列长度较小,但每个轴向 ViT 块都有一个额外的 MLP。对于 AxialResNet,虽然它在性能/计算权衡方面看起来合理(图 13,左),但朴素的实现方式在 TPU 上非常慢(图 13,右)。
D.7 注意力距离
为了理解 ViT 如何使用自注意力来整合图像信息,我们分析了不同层注意力权重所跨越的平均距离(图 11)。这种“注意力距离”类似于 CNN 中的感受野大小。平均注意力距离在低层的头之间变化很大,一些头关注大部分图像,而另一些头关注查询位置处或附近的小区域。随着深度的增加,所有头的注意力距离都会增加。在网络的后半部分,大多数头在标记之间广泛关注。
D.8 注意力图
为了计算从输出标记到输入空间的注意力图(图 6 和 14),我们使用了 Attention Rollout (Abnar & Zuidema, 2020)。简而言之,我们对 ViT-L/16 的所有头的注意力权重进行了平均,然后递归地乘以所有层的权重矩阵。这解释了注意力通过所有层在标记之间的混合。
D.9 ObjectNet 结果
我们还按照 Kolesnikov et al. (2020) 中的评估设置在 ObjectNet 基准上评估了我们的旗舰 ViT-H/14 模型,结果为 82.1% 的 top-5 准确率和 61.7% 的 top-1 准确率。
D.10 VTAB 分解
表 9 显示了在每个 VTAB-1k 任务上获得的分数。