# 2023 Gu & Dao

Mamba: Linear-Time Sequence Modeling with Selective State Spaces

Mamba:具有选择性状态空间的线性时间序列建模 Albert Gu$^ $ 和 Tri Dao$^ $ $^1$ 卡内基梅隆大学,机器学习系 $^2$ 普林斯顿大学,计算机科学系 agu@cs.cmu.edu, tri@tridao.me 摘要 基础模型(Foundation models)目前为深度学习中大多数令人兴奋的应用提供了动力,它们几乎普遍基于 Transformer...

精粹译文

Mamba:具有选择性状态空间的线性时间序列建模

Albert Gu^* 和 Tri Dao^* 1^1 卡内基梅隆大学,机器学习系 2^2 普林斯顿大学,计算机科学系 agu@cs.cmu.edu, tri@tridao.me

摘要

基础模型(Foundation models)目前为深度学习中大多数令人兴奋的应用提供了动力,它们几乎普遍基于 Transformer 架构及其核心注意力模块。为了解决 Transformer 在长序列上的计算效率低下问题,人们开发了许多次二次时间(subquadratic-time)架构,例如线性注意力、门控卷积和循环模型,以及结构化状态空间模型(SSMs),但它们在语言等重要模态上的表现不如注意力机制。我们发现这些模型的一个关键弱点是无法执行基于内容的推理,并对此进行了几项改进。首先,简单地让 SSM 参数成为输入的函数,解决了它们在离散模态上的弱点,允许模型根据当前标记(token)沿序列长度维度选择性地传播或遗忘信息。其次,尽管这种变化阻碍了高效卷积的使用,但我们设计了一种循环模式下的硬件感知并行算法。我们将这些选择性 SSM 集成到一个简化的端到端神经网络架构中,该架构没有注意力机制,甚至没有 MLP 块(Mamba)。Mamba 享受快速推理(吞吐量比 Transformer 高 5 倍)和序列长度的线性扩展,其在真实数据上的性能可提升至百万长度的序列。作为一种通用的序列模型主干,Mamba 在语言、音频和基因组学等多种模态上实现了最先进的性能。在语言建模方面,我们的 Mamba-3B 模型在预训练和下游评估中,性能均优于相同规模的 Transformer,并与两倍于其规模的 Transformer 相媲美。


1 引言

基础模型(FMs),即在海量数据上预训练然后适应下游任务的大型模型,已成为现代机器学习中一种有效的范式。这些 FM 的主干通常是序列模型,对来自语言、图像、语音、音频、时间序列和基因组学等各种领域的任意输入序列进行操作(Brown 等人 2020;Dosovitskiy 等人 2020;Ismail Fawaz 等人 2019;Oord 等人 2016;Poli 等人 2023;Sutskever, Vinyals, 和 Quoc V Le 2014)。虽然这个概念与特定的模型架构选择无关,但现代 FM 主要基于单一类型的序列模型:Transformer(Vaswani 等人 2017)及其核心注意力层(Bahdanau, Cho, 和 Bengio 2015)。自注意力机制的功效归功于其在上下文窗口内密集路由信息的能力,从而允许它对复杂数据进行建模。然而,这一特性带来了根本性的缺陷:无法对有限窗口之外的任何内容进行建模,以及相对于窗口长度的二次方扩展。为了克服这些缺陷,出现了大量关于更高效注意力变体的研究(Tay, Dehghani, Bahri 等人 2022),但往往以牺牲使其有效的特性为代价。到目前为止,还没有这些变体被证明在跨领域的规模上具有经验有效性。

最近,结构化状态空间序列模型(SSMs)(Gu, Goel, 和 Ré 2022;Gu, Johnson, Goel 等人 2021)已成为一种有前途的序列建模架构类别。这些模型可以被解释为循环神经网络(RNNs)和卷积神经网络(CNNs)的结合,灵感来自经典的状态空间模型(Kalman 1960)。这类模型可以作为循环或卷积非常高效地计算,在序列长度上具有线性或近线性扩展。此外,它们在某些数据模态中具有建模长程依赖的原则性机制(Gu, Dao 等人 2020),并主导了诸如 Long Range Arena(Tay, Dehghani, Abnar 等人 2021)等基准测试。许多类型的 SSM(Gu, Goel, 和 Ré 2022;Gu, Gupta 等人 2022;Gupta, Gu, 和 Berant 2022;Y. Li 等人 2023;Ma 等人 2023;Orvieto 等人 2023;Smith, Warrington, 和 Linderman 2023)在涉及音频和视觉等连续信号数据的领域取得了成功(Goel 等人 2022;Nguyen, Goel 等人 2022;Saon, Gupta, 和 Cui 2023)。然而,它们在建模文本等离散且信息密集的数据方面效果较差。

我们提出了一类新的选择性状态空间模型,它在多个轴上改进了先前的工作,以在序列长度上线性扩展的同时实现 Transformer 的建模能力。

选择性机制。首先,我们确定了先前模型的一个关键限制:以输入依赖方式高效选择数据的能力(即关注或忽略特定输入)。基于重要合成任务(如选择性复制和归纳头)的直觉,我们通过根据输入参数化 SSM 参数来设计一种简单的选择机制。这允许模型过滤掉不相关的信息并无限期地记住相关信息。

硬件感知算法。这种简单的变化对模型的计算提出了技术挑战;事实上,所有先前的 SSM 模型必须是时间不变和输入不变的,才能在计算上高效。我们通过一种硬件感知算法克服了这一点,该算法通过扫描(scan)而不是卷积以循环方式计算模型,但不会具体化扩展状态,以避免 GPU 内存层次结构不同级别之间的 IO 访问。由此产生的实现不仅在理论上比以前的方法更快(序列长度线性扩展,而所有基于卷积的 SSM 为伪线性),而且在现代硬件上也更快(在 A100 GPU 上快达 3 倍)。

架构。我们通过将先前 SSM 架构的设计(Dao, Fu, Saab 等人 2023)与 Transformer 的 MLP 块结合成一个单一块,简化了先前的深度序列模型架构,从而产生了一种简单的同质架构设计(Mamba),其中结合了选择性状态空间。

选择性 SSM,以及由此扩展的 Mamba 架构,是完全循环的模型,具有使其适合作为在序列上运行的通用基础模型主干的关键属性。(i)高质量:选择性在语言和基因组学等密集模态上带来了强大的性能。(ii)快速训练和推理:计算和内存随序列长度线性扩展,并且在推理过程中自回归展开模型每步仅需要恒定时间,因为它不需要先前元素的缓存。(iii)长上下文:质量和效率共同在真实数据上实现了高达 1M 序列长度的性能提升。

我们通过预训练质量和特定领域的任务性能,在多种模态和设置下,实证验证了 Mamba 作为通用序列 FM 主干的潜力:

  • 合成任务。在被认为是大型语言模型关键的重要合成任务(如复制和归纳头)上,Mamba 不仅能轻松解决它们,而且可以无限期地外推解决方案(>1M 个标记)。
  • 音频和基因组学。Mamba 在建模音频波形和 DNA 序列方面优于 SaShiMi、Hyena 和 Transformer 等先前最先进的模型,无论是在预训练质量还是下游指标上(例如,在具有挑战性的语音生成数据集上将 FID 降低了一半以上)。在这两种设置中,其性能都随着更长的上下文而提高,最高可达百万长度的序列。
  • 语言建模。Mamba 是第一个真正实现 Transformer 质量性能的线性时间序列模型,无论是在预训练困惑度还是下游评估中。通过高达 1B 参数的缩放定律,我们表明 Mamba 的性能超过了大量的基线,包括基于 LLaMa 的非常强大的现代 Transformer 训练配方(Touvron 等人 2023)。我们的 Mamba 语言模型具有比同等规模的 Transformer 高 5 倍的生成吞吐量,并且 Mamba-3B 的质量与两倍于其规模的 Transformer 相匹配(例如,与 Pythia-3B 相比,常识推理平均高出 4 分,甚至超过了 Pythia-7B)。

模型代码和预训练检查点已在 https://github.com/state-spaces/mamba 开源。


2 状态空间模型

结构化状态空间序列模型(S4)是最近一类用于深度学习的序列模型,与 RNN、CNN 和经典状态空间模型广泛相关。它们的灵感来自一个特定的连续系统 (1),该系统将 1 维函数或序列 x(t)Ry(t)Rx(t) \in \mathbb{R} \mapsto y(t) \in \mathbb{R} 通过隐式潜在状态 h(t)RNh(t) \in \mathbb{R}^N 进行映射。

具体而言,S4 模型由四个参数 (Δ,A,B,C)(\Delta, A, B, C) 定义,它们在两个阶段定义了序列到序列的转换。

h(t)=Ah(t)+Bx(t)(1a)h'(t) = Ah(t) + Bx(t) \quad (1a) y(t)=Ch(t)(1b)y(t) = Ch(t) \quad (1b)

ht=Aˉht1+Bˉxt(2a)h_t = \bar{A}h_{t-1} + \bar{B}x_t \quad (2a) yt=Cˉht(2b)y_t = \bar{C}h_t \quad (2b)

Kˉ=(CˉBˉ,CˉAˉBˉ,,CˉAˉkBˉ,)(3a)\bar{K} = (\bar{C}\bar{B}, \bar{C}\bar{A}\bar{B}, \dots, \bar{C}\bar{A}^k\bar{B}, \dots) \quad (3a) y=xKˉ(3b)y = x * \bar{K} \quad (3b)

离散化。第一阶段通过固定公式 Aˉ=fΔ(Δ,A)\bar{A} = f_\Delta(\Delta, A)Bˉ=fΔ(Δ,A,B)\bar{B} = f_\Delta(\Delta, A, B) 将“连续参数” (Δ,A,B)(\Delta, A, B) 转换为“离散参数” (Aˉ,Bˉ)(\bar{A}, \bar{B}),其中对 (fΔ,fB)(f_\Delta, f_B) 称为离散化规则。可以使用各种规则,例如方程 (4) 中定义的零阶保持(ZOH)。

Aˉ=exp(ΔA)Bˉ=(ΔA)1(exp(ΔA)I)ΔB(4)\bar{A} = \exp(\Delta A) \quad \bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \quad (4)

离散化与连续时间系统有着深厚的联系,这可以赋予它们额外的属性,例如分辨率不变性(Nguyen, Goel 等人 2022),并自动确保模型被正确归一化(Gu, Johnson, Timalsina 等人 2023;Orvieto 等人 2023)。它还与 RNN 的门控机制有关(Gu, Gulcehre 等人 2020;Tallec 和 Ollivier 2018),我们将在第 3.5 节中重新讨论。然而,从机械的角度来看,离散化可以简单地视为 SSM 前向传递中计算图的第一步。其他类型的 SSM 可以绕过离散化步骤并直接参数化 (Aˉ,Bˉ)(\bar{A}, \bar{B})(Zhang 等人 2023),这可能更容易推理。

计算。在参数从 (Δ,A,B,C)(Aˉ,Bˉ,Cˉ)(\Delta, A, B, C) \mapsto (\bar{A}, \bar{B}, \bar{C}) 转换后,模型可以通过两种方式计算,即线性循环 (2) 或全局卷积 (3)。

通常,模型使用卷积模式 (3) 进行高效的并行训练(其中提前看到整个输入序列),并切换到循环模式 (2) 进行高效的自回归推理(其中一次看到一个时间步的输入)。

线性时间不变性(LTI)。方程 (1) 到 (3) 的一个重要属性是模型的动力学随时间保持不变。换句话说,(Δ,A,B,C)(\Delta, A, B, C) 以及随之而来的 (Aˉ,Bˉ)(\bar{A}, \bar{B}) 对于所有时间步都是固定的。这个属性被称为线性时间不变性(LTI),它与循环和卷积有着深刻的联系。非正式地,我们将 LTI SSM 视为等同于任何线性循环 (2a) 或卷积 (3b),并使用 LTI 作为这些模型类别的统称。

到目前为止,所有结构化 SSM 都是 LTI 的(例如作为卷积计算),因为存在第 3.3 节中讨论的基本效率约束。然而,这项工作的核心见解是 LTI 模型在建模某些类型的数据时存在根本性的局限性,我们的技术贡献包括在克服效率瓶颈的同时消除 LTI 约束。

结构和维度。最后,我们注意到结构化 SSM 之所以这样命名,是因为高效计算它们还需要对 AA 矩阵施加结构。最流行的结构形式是对角线(Gu, Gupta 等人 2022;Gupta, Gu, 和 Berant 2022;Smith, Warrington, 和 Linderman 2023),我们也使用这种结构。

在这种情况下,ARN×N,BRN×1,CR1×NA \in \mathbb{R}^{N \times N}, B \in \mathbb{R}^{N \times 1}, C \in \mathbb{R}^{1 \times N} 矩阵都可以用 NN 个数字表示。为了对具有 DD 个通道、批大小 BB 和长度 LL 的输入序列 xx 进行操作,SSM 被独立应用于每个通道。注意,在这种情况下,每个输入的总隐藏状态维度为 DNDN,并且在序列长度上计算它需要 O(BLDN)O(BLDN) 的时间和内存;这是第 3.3 节中解决的基本效率瓶颈的根源。

通用状态空间模型。我们注意到“状态空间模型”一词具有非常广泛的含义,它简单地代表了任何具有潜在状态的循环过程的概念。它已被用于指代不同学科中的许多不同概念,包括马尔可夫决策过程(MDP)(强化学习(Hafner 等人 2020))、动态因果建模(DCM)(计算神经科学(Friston, Harrison, 和 Penny 2003))、卡尔曼滤波器(控制(Kalman 1960))、隐马尔可夫模型(HMM)和线性动力系统(LDS)(机器学习),以及大型循环(有时是卷积)模型(深度学习)。

在整篇论文中,我们使用术语“SSM”专门指代结构化 SSM 或 S4 模型(Gu, Goel, 和 Ré 2022;Gu, Gupta 等人 2022;Gupta, Gu, 和 Berant 2022;Hasani 等人 2023;Ma 等人 2023;Smith, Warrington, 和 Linderman 2023),并交替使用这些术语。为了方便起见,我们也可以包括此类模型的衍生产品,例如那些专注于线性循环或全局卷积观点的模型(Y. Li 等人 2023;Orvieto 等人 2023;Poli 等人 2023),并在必要时澄清细微差别。

SSM 架构。SSM 是可以合并到端到端神经网络架构中的独立序列转换。(我们也有时称 SSM 架构为 SSNN,它们之于 SSM 层就像 CNN 之于线性卷积层。)我们讨论了一些最著名的 SSM 架构,其中许多也将作为我们的主要基线。

  • 线性注意力(Katharopoulos 等人 2020)是自注意力的一种近似,涉及一种可以被视为退化线性 SSM 的循环。
  • H3(Dao, Fu, Saab 等人 2023)将这种循环推广到使用 S4;它可以被视为一种架构,其中 SSM 被两个门控连接夹在中间(图 3)。H3 还在主 SSM 层之前插入了一个标准的局部卷积,他们将其称为 shift-SSM。
  • Hyena(Poli 等人 2023)使用与 H3 相同的架构,但用 MLP 参数化的全局卷积(Romero 等人 2021)替换了 S4 层。
  • RetNet(Y. Sun 等人 2023)在架构中增加了一个额外的门,并使用了一个更简单的 SSM,允许一种替代的并行计算路径,使用多头注意力(MHA)的变体而不是卷积。
  • RWKV(B. Peng 等人 2023)是一种最近设计的用于语言建模的 RNN,基于另一种线性注意力近似,即无注意力 Transformer(S. Zhai 等人 2021)。其主要的“WKV”机制涉及 LTI 循环,可以被视为两个 SSM 的比率。

其他密切相关的 SSM 和架构在扩展的相关工作(附录 B)中进行了进一步讨论。我们特别强调 S5(Smith, Warrington, 和 Linderman 2023)、QRNN(Bradbury 等人 2016)和 SRU(Lei 等人 2017),我们认为这些是与我们核心选择性 SSM 最相关的方法。


选择性状态空间模型与硬件感知状态扩展

图 1:(概述。)结构化 SSM 独立地将输入 xx 的每个通道(例如 D=5D=5)通过更高维的潜在状态 hh(例如 N=4N=4)映射到输出 yy。先前的 SSM 通过巧妙的替代计算路径避免了具体化这个巨大的有效状态(DNDN,乘以批大小 BB 和序列长度 LL),这些路径需要时间不变性:(Δ,A,B,C)(\Delta, A, B, C) 参数在时间上是恒定的。我们的选择机制添加了输入依赖的动力学,这也需要仔细的硬件感知算法,仅在 GPU 内存层次结构中更有效的级别具体化扩展状态。


3 选择性状态空间模型

我们使用来自合成任务的直觉(第 3.1 节)来激发我们的选择机制,然后解释如何将此机制合并到状态空间模型中(第 3.2 节)。由此产生的时变 SSM 不能使用卷积,这提出了如何高效计算它们的技术挑战。我们通过一种利用现代硬件上内存层次结构的硬件感知算法克服了这一点(第 3.3 节)。然后,我们描述了一种没有注意力机制甚至没有 MLP 块的简单 SSM 架构(第 3.4 节)。最后,我们讨论了选择机制的一些额外属性(第 3.5 节)。

3.1 动机:选择作为压缩的一种手段

我们认为序列建模的一个基本问题是将上下文压缩成更小的状态。事实上,我们可以从这个角度看待流行序列模型的权衡。例如,注意力机制既有效又低效,因为它明确地根本不压缩上下文。这可以从自回归推理需要明确存储整个上下文(即 KV 缓存)这一事实中看出,这直接导致了 Transformer 缓慢的线性时间推理和二次方时间训练。另一方面,循环模型是高效的,因为它们具有有限的状态,这意味着恒定时间的推理和线性时间的训练。然而,它们的有效性受到该状态压缩上下文程度的限制。

为了理解这一原则,我们关注两个合成任务的运行示例(图 2)。

  • 选择性复制任务通过改变要记忆的标记的位置,修改了流行的复制任务(Arjovsky, Shah, 和 Bengio 2016)。它需要内容感知推理,以便能够记忆相关的标记(彩色)并过滤掉不相关的标记(白色)。
  • 归纳头任务是一个众所周知的机制,被假设用于解释 LLM 的大部分上下文学习能力(Olsson 等人 2022)。它需要上下文感知推理,以知道何时在适当的上下文中产生正确的输出(黑色)。

这些任务揭示了 LTI 模型的失效模式。从循环的角度来看,它们的恒定动力学(例如 (2) 中的 (Aˉ,Bˉ)(\bar{A}, \bar{B}) 转换)不能让它们从上下文中选择正确的信息,或者以输入依赖的方式影响沿序列传递的隐藏状态。从卷积的角度来看,众所周知,全局卷积可以解决普通的复制任务(Romero 等人 2021),因为它只需要时间感知,但它们在选择性复制任务上存在困难,因为缺乏内容感知(图 2)。更具体地说,输入到输出之间的间距是变化的,不能用静态卷积核建模。

总之,序列模型的效率与有效性权衡的特征在于它们压缩状态的程度:高效模型必须具有较小的状态,而有效模型必须具有包含来自上下文的所有必要信息的状态。反过来,我们提出构建序列模型的一个基本原则是选择性:即关注或过滤输入到顺序状态中的上下文感知能力。特别是,选择机制控制信息如何沿序列维度传播或交互(参见第 3.5 节以获取更多讨论)。

3.2 用选择改进 SSM

将选择机制合并到模型中的一种方法是让它们影响沿序列交互的参数(例如 RNN 的循环动力学或 CNN 的卷积核)成为输入依赖的。

算法 1 和 2 说明了我们使用的主要选择机制。主要区别仅仅是使几个参数 Δ,B,C\Delta, B, C 成为输入的函数,以及随之而来的整个张量形状的变化。特别是,我们强调这些参数现在具有长度维度 LL,这意味着模型已从时间不变变为时变。(注意,形状注释在第 2 节中描述。)这失去了与卷积 (3) 的等价性,并对其效率产生了影响,将在下面讨论。

我们特别选择 sB(x)=LinearN(x),sC(x)=LinearN(x),sΔ(x)=BroadcastD(Linear1(x))s_B(x) = \text{Linear}_N(x), s_C(x) = \text{Linear}_N(x), s_\Delta(x) = \text{Broadcast}_D(\text{Linear}_1(x)), 以及 τΔ=softplus\tau_\Delta = \text{softplus},其中 LinearN\text{Linear}_N 是到维度 NN 的参数化投影。sΔs_\DeltaτΔ\tau_\Delta 的选择是由于与第 3.5 节中解释的 RNN 门控机制的联系。


图 2

图 2:(左)复制任务的标准版本涉及输入和输出元素之间的恒定间距,并且很容易被线性循环和全局卷积等时间不变模型解决。(右上)选择性复制任务在输入之间具有随机间距,需要能够根据内容选择性地记住或忽略输入的时变模型。(右下)归纳头任务是联想回忆的一个例子,它需要根据上下文检索答案,这是 LLM 的一项关键能力。


算法 1 SSM (S4)算法 2 SSM + 选择 (S6)
输入x:(B,L,D)x: (B, L, D)输入x:(B,L,D)x: (B, L, D)
输出y:(B,L,D)y: (B, L, D)输出y:(B,L,D)y: (B, L, D)
1: A:(D,N)ParameterA: (D, N) \leftarrow \text{Parameter}1: A:(D,N)ParameterA: (D, N) \leftarrow \text{Parameter}
\quad \triangleright 表示结构化 N×NN \times N 矩阵\quad \triangleright 表示结构化 N×NN \times N 矩阵
2: B:(D,N)ParameterB: (D, N) \leftarrow \text{Parameter}2: B:(B,L,N)sB(x)B: (B, L, N) \leftarrow s_B(x)
3: C:(D,N)ParameterC: (D, N) \leftarrow \text{Parameter}3: C:(B,L,N)sC(x)C: (B, L, N) \leftarrow s_C(x)
4: Δ:(D)τΔ(Parameter)\Delta: (D) \leftarrow \tau_\Delta(\text{Parameter})4: Δ:(B,L,D)τΔ(Parameter+sΔ(x))\Delta: (B, L, D) \leftarrow \tau_\Delta(\text{Parameter} + s_\Delta(x))
5: Aˉ,Bˉ:(D,N)discretize(Δ,A,B)\bar{A}, \bar{B}: (D, N) \leftarrow \text{discretize}(\Delta, A, B)5: Aˉ,Bˉ:(B,L,D,N)discretize(Δ,A,B)\bar{A}, \bar{B}: (B, L, D, N) \leftarrow \text{discretize}(\Delta, A, B)
6: ySSM(Aˉ,Bˉ,Cˉ)(x)y \leftarrow \text{SSM}(\bar{A}, \bar{B}, \bar{C})(x)6: ySSM(Aˉ,Bˉ,Cˉ)(x)y \leftarrow \text{SSM}(\bar{A}, \bar{B}, \bar{C})(x)
\quad \triangleright 时间不变:循环或卷积\quad \triangleright 时变:仅循环 (scan)
7: return yy7: return yy

3.3 选择性 SSM 的高效实现

卷积(Krizhevsky, Sutskever, 和 Hinton 2012)和注意力(Bahdanau, Cho, 和 Bengio 2015;Vaswani 等人 2017)等硬件友好原语享有广泛的应用。在这里,我们的目标是使选择性 SSM 在现代硬件(GPU)上也高效。选择机制非常自然,早期的工作尝试合并选择的特殊情况,例如在循环 SSM 中让 Δ\Delta 随时间变化(Gu, Dao 等人 2020)。然而,如前所述,SSM 使用的一个核心限制是它们的计算效率,这就是为什么 S4 及其所有衍生产品使用 LTI(非选择性)模型,最常见的是以全局卷积的形式。

3.3.1 先前模型的动机

我们首先回顾这一动机并概述我们克服先前方法局限性的方法。

  • 在高层面上,循环模型(如 SSM)总是平衡表达能力和速度之间的权衡:如第 3.1 节所述,具有更大隐藏状态维度的模型应该更有效但更慢。因此,我们希望在不支付速度和内存成本的情况下最大化隐藏状态维度。
  • 注意,循环模式比卷积模式更灵活,因为后者 (3) 是从前者 (2) 扩展而来的(Gu, Goel, 和 Ré 2022;Gu, Johnson, Goel 等人 2021)。然而,这将需要计算和具体化形状为 (B,L,D,N)(B, L, D, N) 的潜在状态 hh,这比形状为 (B,L,D)(B, L, D) 的输入 xx 和输出 yy 大得多(大 NN 倍,即 SSM 状态维度)。因此,引入了更高效的卷积模式,它可以绕过状态计算并具体化仅大小为 (B,L,D)(B, L, D) 的卷积核 (3a)。
  • 先前的 LTI 状态空间模型利用双重循环-卷积形式将有效状态维度增加了 NN 倍(10100\approx 10-100),远大于传统 RNN,且没有效率损失。

3.3.2 选择性扫描概述:硬件感知状态扩展

选择机制旨在克服 LTI 模型的局限性;同时,我们因此需要重新审视 SSM 的计算问题。我们通过三种经典技术来解决这个问题:内核融合、并行扫描和重计算。我们有两个主要观察结果:

  • 朴素的循环计算使用 O(BLDN)O(BLDN) FLOPs,而卷积计算使用 O(BLDlogL)O(BLD \log L) FLOPs,并且前者具有更低的常数因子。因此,对于长序列和不太大的状态维度 NN,循环模式实际上可以使用更少的 FLOPs。
  • 两个挑战是循环的顺序性质和巨大的内存使用。为了解决后者,就像卷积模式一样,我们可以尝试不实际具体化完整状态 hh

主要思想是利用现代加速器(GPU)的属性,仅在内存层次结构中更有效的级别具体化状态 hh。特别是,大多数操作(矩阵乘法除外)都受内存带宽限制(Dao, Fu, Ermon 等人 2022;Ivanov 等人 2021;Williams, Waterman, 和 Patterson 2009)。这包括我们的扫描操作,我们使用内核融合来减少内存 IO 的数量,从而导致与标准实现相比的显著加速。

具体而言,我们不是在 GPU HBM(高带宽内存)中准备大小为 (B,L,D,N)(B, L, D, N) 的扫描输入 (Aˉ,Bˉ)(\bar{A}, \bar{B}),而是直接从慢速 HBM 将 SSM 参数 (Δ,A,B,C)(\Delta, A, B, C) 加载到快速 SRAM,在 SRAM 中执行离散化和循环,然后将大小为 (B,L,D)(B, L, D) 的最终输出写回 HBM。

为了避免顺序循环,我们观察到尽管它不是线性的,但它仍然可以通过工作高效的并行扫描算法进行并行化(Blelloch 1990;Martin 和 Cundy 2018;Smith, Warrington, 和 Linderman 2023)。

最后,我们还必须避免保存反向传播所需的中间状态。我们仔细应用重计算的经典技术来减少内存需求:中间状态不被存储,而是在输入从 HBM 加载到 SRAM 时在反向传递中重新计算。结果,融合的选择性扫描层具有与使用 FlashAttention 的优化 Transformer 实现相同的内存需求。

融合内核和重计算的详细信息在附录 D 中。完整的选择性 SSM 层和算法如图 1 所示。

3.4 简化 SSM 架构

与结构化 SSM 一样,选择性 SSM 是可以灵活合并到神经网络中的独立序列转换。H3 架构是大多数著名 SSM 架构的基础(第 2 节),这些架构通常由受线性注意力启发的块与 MLP(多层感知器)块交错组成。我们通过将这两个组件组合成一个来简化此架构,该组件被同质堆叠(图 3)。这受到门控注意力单元(GAU)(Hua 等人 2022)的启发,该单元对注意力做了类似的事情。

该架构涉及通过可控的扩展因子 EE 扩展模型维度 DD。对于每个块,大多数参数 (3D23D^2) 都在线性投影中(2D22D^2 用于输入投影,D2D^2 用于输出投影),而内部 SSM 的贡献较小。相比之下,SSM 参数的数量(用于 Δ,B,C\Delta, B, C 和矩阵 AA 的投影)要小得多。我们重复此块,与标准归一化和残差连接交错,以形成 Mamba 架构。我们在实验中始终固定 E=2E=2,并使用两个块堆栈来匹配 Transformer 交错 MHA(多头注意力)和 MLP 块的 12D212D^2 参数。我们使用 SiLU / Swish 激活函数(Hendrycks 和 Gimpel 2016;Ramachandran, Zoph, 和 Quoc V Le 2017),其动机是使门控 MLP 成为流行的“SwiGLU”变体(Chowdhery 等人 2023;Dauphin 等人 2017;Shazeer 2020;Touvron 等人 2023)。最后,我们额外使用了一个可选的归一化层(我们选择 LayerNorm(J. L. Ba, Kiros, 和 Hinton 2016)),其动机是 RetNet 在类似位置使用了归一化层(Y. Sun 等人 2023)。

3.5 选择机制的属性

选择机制是一个更广泛的概念,可以以不同的方式应用,例如应用于更传统的 RNN 或 CNN,应用于不同的参数(例如算法 2 中的 BB),或使用不同的转换 s(x)s(x)


图 3

图 3:(架构。)我们简化的块设计将 H3 块(大多数 SSM 架构的基础)与现代神经网络中无处不在的 MLP 块相结合。我们没有交错这两个块,而是简单地同质重复 Mamba 块。与 H3 块相比,Mamba 用激活函数替换了第一个乘法门。与 MLP 块相比,Mamba 在主分支中添加了一个 SSM。对于 σ\sigma,我们使用 SiLU / Swish 激活(Hendrycks 和 Gimpel 2016;Ramachandran, Zoph, 和 Quoc V Le 2017)。


3.5.1 与门控机制的联系

我们强调最重要的联系:RNN 的经典门控机制是我们 SSM 选择机制的一个实例。我们注意到 RNN 门控与连续时间系统离散化之间的联系已经确立(Funahashi 和 Nakamura 1993;Tallec 和 Ollivier 2018)。事实上,定理 1 是 Gu, Johnson, Goel 等人(2021,引理 3.1)的改进,推广到 ZOH 离散化和输入依赖门(证明在附录 C 中)。更广泛地说,SSM 中的 Δ\Delta 可以被视为发挥了 RNN 门控机制的通用作用。与先前的工作一致,我们采取的观点是,SSM 的离散化是启发式门控机制的原则性基础。

定理 1。当 N=1,A=1,B=1,sΔ=Linear(x)N=1, A=-1, B=1, s_\Delta = \text{Linear}(x), 且 τΔ=softplus\tau_\Delta = \text{softplus} 时,选择性 SSM 循环(算法 2)采用以下形式

gt=σ(Linear(xt))g_t = \sigma(\text{Linear}(x_t)) ht=(1gt)ht1+gtxt(5)h_t = (1 - g_t)h_{t-1} + g_t x_t \quad (5)

如第 3.2 节所述,我们对 sΔ,τΔs_\Delta, \tau_\Delta 的具体选择来自这种联系。特别是,注意如果给定的输入 xtx_t 应该被完全忽略(如合成任务中必要的那样),所有 DD 个通道都应该忽略它,因此我们在用 Δ\Delta 重复/广播之前将输入投影到 1 维。

3.5.2 选择机制的解释

我们阐述选择的三个特定机械效应。

可变间距。选择性允许过滤掉可能出现在感兴趣输入之间的不相关噪声标记。这由选择性复制任务证明,但在常见数据模态中无处不在,特别是对于离散数据——例如存在诸如“um”之类的语言填充词。此属性的产生是因为模型可以在机械上过滤掉任何特定的输入 xtx_t,例如在门控 RNN 情况(定理 1)中当 gt0g_t \to 0 时。

过滤上下文。已经实证观察到,许多序列模型不会随着更长的上下文而提高(F. Shi 等人 2023),尽管原则上更多的上下文应该导致严格更好的性能。一种解释是,许多序列模型在必要时无法有效地忽略不相关的上下文;一个直观的例子是全局卷积(和通用 LTI 模型)。另一方面,选择性模型可以在任何时候简单地重置其状态以删除无关的历史记录,因此它们的性能原则上随着上下文长度单调提高(例如第 4.3.2 节)。

边界重置。在将多个独立序列拼接在一起的设置中,Transformer 可以通过实例化特定的注意力掩码来保持它们分开,而 LTI 模型会在序列之间渗漏信息。选择性 SSM 也可以在边界处重置其状态(例如 Δt\Delta_t \to \infty,或定理 1 中当 gt1g_t \to 1 时)。这些设置可能人为地发生(例如将文档打包在一起以提高硬件利用率)或自然地发生(例如强化学习中的情节边界(Lu 等人 2023))。

此外,我们阐述每个选择性参数的影响。

Δ\Delta 的解释。通常,Δ\Delta 控制在多大程度上关注或忽略当前输入 xtx_t 之间的平衡。它推广了 RNN 门(例如定理 1 中的 gtg_t):机械地,大的 Δ\Delta 重置状态 hh 并关注当前输入 xtx_t,而小的 Δ\Delta 保持状态并忽略当前输入。SSM (1)-(2) 可以被解释为由时间步 Δ\Delta 离散化的连续系统,在这种背景下,直觉是大的 Δ\Delta \to \infty 代表系统更长时间地关注当前输入(从而“选择”它并忘记其当前状态),而小的 Δ0\Delta \to 0 代表被忽略的瞬态输入。

AA 的解释。我们注意到,虽然 AA 参数也可以是选择性的,但它最终仅通过其与 Δ\Delta 的交互(通过 Aˉ=exp(ΔA)\bar{A} = \exp(\Delta A),即离散化 (4))影响模型。因此,Δ\Delta 中的选择性足以确保 (Aˉ,Bˉ)(\bar{A}, \bar{B}) 中的选择性,并且是改进的主要来源。我们假设使 AA 除了 Δ\Delta 之外(或代替 Δ\Delta)具有选择性将具有相似的性能,为了简单起见,我们将其省略。

BBCC 的解释。如第 3.1 节所述,选择性的最重要属性是过滤掉不相关的信息,以便序列模型的上下文可以被压缩成高效的状态。在 SSM 中,将 BBCC 修改为选择性的,允许对是否让输入 xtx_t 进入状态 hth_t,或将状态进入输出 yty_t 进行更细粒度的控制。这些可以被解释为允许模型分别基于内容(输入)和上下文(隐藏状态)来调节循环动力学。

3.6 额外模型细节

实数 vs. 复数。大多数先前的 SSM 在其状态 hh 中使用复数,这对于在感知模态的许多任务上获得强大性能是必要的(Gu, Goel, 和 Ré 2022)。然而,已经实证观察到,完全实值的 SSM 在某些设置中似乎工作得很好,甚至可能更好(Ma 等人 2023)。我们使用实值作为默认值,这对于除我们的一项任务之外的所有任务都非常有效;我们假设复数-实数权衡与数据模态中的连续-离散谱有关,其中复数对连续模态(例如音频、视频)有帮助,但对离散模态(例如文本、DNA)没有帮助。

初始化。大多数先前的 SSM 也建议特殊的初始化,特别是在复值情况下,这可以在低数据状态等几种设置中提供帮助。我们复值情况下的默认初始化是 S4D-Lin,实值情况下的默认初始化是 S4D-Real(Gu, Gupta 等人 2022),它们基于 HIPPO 理论(Gu, Dao 等人 2020)。这些分别将 AA 的第 nn 个元素定义为 1/2+ni-1/2 + ni(n+1)-(n+1)。然而,我们预计许多初始化都能很好地工作,特别是在大数据和实值 SSM 状态下;第 4.6 节考虑了一些消融实验。

Δ\Delta 的参数化。我们将 Δ\Delta 的选择性调整定义为 sΔ(x)=BroadcastD(Linear1(x))s_\Delta(x) = \text{Broadcast}_D(\text{Linear}_1(x)),这是受 Δ\Delta 机制(第 3.5 节)的启发。我们观察到它可以从维度 1 推广到更大的维度 RR。我们将其设置为 DD 的一小部分,与块中的主要线性投影相比,它使用的参数数量可以忽略不计。我们还注意到,广播操作可以被视为另一个线性投影,初始化为 1 和 0 的特定模式;如果此投影是可训练的,则会导致替代的 sΔ(x)=LinearR(Linear1(x))s_\Delta(x) = \text{Linear}_R(\text{Linear}_1(x)),这可以被视为低秩投影。

在我们的实验中,Δ\Delta 参数(可以被视为偏置项)被初始化为 τΔ1(Uniform([0.001,0.1]))\tau_\Delta^{-1}(\text{Uniform}([0.001, 0.1])),遵循先前关于 SSM 的工作(Gu, Johnson, Timalsina 等人 2023)。

备注 3.1。为了简洁起见,在我们的实验结果中,我们有时将选择性 SSM 简称为 S6 模型,因为它们是具有选择机制并用扫描计算的 S4 模型。


4 经验评估

在第 4.1 节中,我们测试了 Mamba 解决第 3.1 节中激发的两个合成任务的能力。然后,我们在三个领域进行评估,每个领域都评估了自回归预训练以及下游任务。

  • 第 4.2 节:语言模型预训练(缩放定律)和零样本下游评估。
  • 第 4.3 节:DNA 序列预训练,以及在长序列分类任务上的微调。
  • 第 4.4 节:音频波形预训练,以及自回归生成的语音片段的质量。

最后,第 4.5 节展示了 Mamba 在训练和推理时间上的计算效率,第 4.6 节消融了架构和选择性 SSM 的各个组件。

4.1 合成任务

这些任务的完整实验细节,包括任务细节和训练协议,都在附录 E.1 中。

4.1.1 选择性复制

复制任务是序列建模中最受研究的合成任务之一,最初旨在测试循环模型的记忆能力。如第 3.1 节所述,LTI SSM(线性循环和全局卷积)可以通过仅跟踪时间而不是对数据进行推理来轻松解决此任务;例如,通过构建长度完全正确的卷积核(图 2)。这在先前关于全局卷积的工作中得到了明确验证(Romero 等人 2021)。选择性复制任务通过随机化标记之间的间距来防止这种捷径。注意,此任务之前已被引入为去噪任务(Jing 等人 2019)。

注意,许多先前的工作认为添加架构门控(乘法交互)可以赋予模型“数据依赖性”并解决相关任务(Dao, Fu, Saab 等人 2023;Poli 等人 2023)。然而,我们发现这种解释在直觉上是不够的,因为这种门控不会沿序列轴交互,也不能影响标记之间的间距。特别是架构门控不是选择机制的一个实例(附录 A)。

表 1 证实,H3 和 Mamba 等门控架构仅部分提高了性能,而选择机制(将 S4 修改为 S6)可以轻松解决此任务,特别是当与这些更强大的架构结合时。

4.1.2 归纳头

归纳头(Olsson 等人 2022)是来自机械可解释性视角(Elhage 等人 2021)的一个简单任务,它出人意料地预测了 LLM 的上下文学习能力。它要求模型执行联想回忆和复制:例如,如果模型在序列中看到了诸如“Harry Potter”之类的双字母组,那么当“Harry”下一次出现在同一序列中时,模型应该能够通过从历史记录中复制来预测“Potter”。

数据集。我们在序列长度 256、词汇量 16 的归纳头任务上训练了一个 2 层模型,这与该任务的先前工作相当(Dao, Fu, Saab 等人 2023),但序列更长。我们还通过在测试时评估从 26=642^6 = 64220=10485762^{20} = 1048576 的一系列序列长度,研究了泛化和外推能力。

模型。遵循关于归纳头的既定工作,我们使用 2 层模型,这允许注意力机制在机械上解决归纳头任务(Olsson 等人 2022)。我们测试了多头注意力(8 个头,具有各种位置编码)和 SSM 变体。我们为 Mamba 使用了 64 的模型维度 DD,为其他模型使用了 128。

结果。表 2 显示 Mamba——更准确地说是其选择性 SSM 层——有能力完美地解决该任务,因为它能够选择性地记住相关标记,同时忽略中间的所有其他内容。它完美地泛化到百万长度的序列,或者比它在训练期间看到的长度长 4000 倍,而没有其他方法超过 2 倍。


模型架构准确率
S4无门控S418.3
-无门控S697.0
H3H3S457.0
HyenaH3Hyena30.1
-H3S699.7
-MambaS456.4
-MambaHyena28.4
MambaMambaS699.8

表 1:(选择性复制。)架构和内部序列层组合的准确率。

图 4

表 2:(归纳头。)模型在序列长度 28=2562^8 = 256 上进行训练,并在 26=642^6 = 64220=10485762^{20} = 1048576 的增加序列长度上进行测试。完整数字见表 11。

图 4:(缩放定律。)在 Pile 上训练的 125M\approx 125M1.3B\approx 1.3B 参数的模型。Mamba 的缩放效果优于所有其他无注意力模型,并且是第一个匹配非常强大的“Transformer++”配方性能的模型,该配方现已成为标准,特别是在序列长度增加时。


在注意力模型的位置编码变体中,xPos(专为长度外推而设计)略好于其他变体;还要注意,由于内存限制,所有注意力模型仅在高达序列长度 214=163842^{14} = 16384 的情况下进行了测试。在其他 SSM 中,H3 和 Hyena 是相似的,这与 Poli 等人(2023)的发现相反。

4.2 语言建模

我们在标准自回归语言建模上评估 Mamba 架构与其他架构的对比,包括预训练指标(困惑度)和零样本评估。我们将模型大小(深度和宽度)设置为镜像 GPT3 规范。我们使用 Pile 数据集(L. Gao, Biderman 等人 2020),并遵循 Brown 等人(2020)中描述的训练配方。所有训练细节都在附录 E.2 中。

4.2.1 缩放定律

对于基线,我们与标准 Transformer 架构(GPT3 架构)以及我们所知的最强 Transformer 配方(此处称为 Transformer++)进行比较,该配方基于 PaLM 和 LLaMa 架构(例如旋转嵌入、SwiGLU MLP、RMSNorm 而不是 LayerNorm、无线性偏置和更高的学习率)。我们还与其他最近的次二次架构进行了比较(图 4)。所有模型细节都在附录 E.2 中。

图 4 显示了在标准 Chinchilla(Hoffmann 等人 2022)协议下,从 125M\approx 125M1.3B\approx 1.3B 参数的模型上的缩放定律。Mamba 是第一个匹配非常强大的 Transformer 配方(Transformer++)性能的无注意力模型,该配方现已成为标准,特别是在序列长度增加时。(我们注意到 RWKV 和 RetNet 基线缺少上下文长度 8k 的完整结果,这些先前强大的循环模型也可以被解释为 SSM,因为缺乏高效的实现导致内存不足或不切实际的计算需求。)


4.2.2 下游评估

表 3 显示了 Mamba 在一系列流行的下游零样本评估任务上的性能。我们与这些规模下最著名的开源模型进行了比较,最重要的是 Pythia(Biderman 等人 2023)和 RWKV(B. Peng 等人 2023),它们使用与我们模型相同的分词器、数据集和训练长度(300B 个标记)进行了训练。(注意,Mamba 和 Pythia 是以 2048 的上下文长度训练的,而 RWKV 是以 1024 的上下文长度训练的。)

表 3:(零样本评估。)每种规模的最佳结果以粗体显示。我们与使用各种分词器、训练长达 300B 个标记的开源 LM 进行了比较。Pile 指的是验证集,仅与在相同数据集和分词器(GPT-NeoX-20B)上训练的模型进行比较。对于每种模型规模,Mamba 在每一项评估结果上都是同类最佳,并且通常与两倍于其规模的基线相匹配。

模型分词器Pile PPL ↓LAMBADA PPL ↓LAMBADA ACC ↑HellaSwag ACC ↑PIQA ACC ↑Arc-E ACC ↑Arc-C ACC ↑WinoGrande ACC ↑平均 ACC ↑
Hybrid H3-130MGPT289.4825.7731.764.244.424.250.640.1
Pythia-160MNeoX29.6438.1033.030.261.443.224.151.940.6
Mamba-130MNeoX10.5616.0744.335.364.548.024.351.944.7
Hybrid H3-360MGPT212.5848.041.568.151.424.754.148.0
Pythia-410MNeoX9.9510.8451.440.666.952.124.653.848.2
Mamba-370MNeoX8.288.1455.646.569.555.128.055.350.0
Pythia-1BNeoX7.827.9256.147.270.757.027.153.551.9
Mamba-790MNeoX7.336.0262.755.172.161.229.556.157.1
GPT-Neo 1.3BGPT27.5057.248.971.156.225.954.952.4
Hybrid H3-1.3BGPT211.2549.652.671.359.228.156.953.0
OPT-1.3BOPT6.6458.053.772.456.729.659.555.0
Pythia-1.4BNeoX7.516.0861.752.171.060.528.557.255.2
RWKV-1.5BNeoX7.707.0456.452.572.460.529.454.654.3
Mamba-1.4BNeoX6.805.0464.959.174.265.532.861.559.7
GPT-Neo 2.7BGPT25.6362.255.872.161.130.257.656.5
Hybrid H3-2.7BGPT27.9255.759.773.365.632.361.458.0
OPT-2.7BOPT5.1263.660.674.860.831.361.058.7
Pythia-2.8BNeoX6.735.0464.759.374.064.132.959.759.1
RWKV-3BNeoX7.005.2463.959.673.767.833.159.659.6
Mamba-2.8BNeoX6.224.2369.266.175.269.736.363.563.3
GPT-J-6BGPT24.1068.366.375.467.036.664.163.0
OPT-6.7BOPT4.2567.767.276.365.634.965.562.9
Pythia-6.9BNeoX6.514.4567.164.075.267.335.561.361.7
RWKV-7.4BNeoX6.314.3867.265.576.167.837.561.062.5

4.3 DNA 建模

受大型语言模型成功的激励,最近出现了将基础模型范式用于基因组学的探索。DNA 被比作语言,因为它由具有有限词汇的离散标记序列组成。它也以需要长程依赖来建模而闻名(Avsec 等人 2021)。我们研究了 Mamba 作为 FM 主干在与最近关于 DNA 长序列模型的工作相同的设置下进行预训练和微调(Nguyen, Poli 等人 2023)。特别是,我们专注于两个关于模型规模和序列长度的缩放定律探索(图 5),以及一个需要长上下文的困难下游合成分类任务(图 6)。

对于预训练,我们在很大程度上遵循标准的因果语言建模(下一个标记预测)设置来进行训练和模型细节(另请参阅附录 E.2)。对于数据集,我们很大程度上遵循 HyenaDNA 的设置(Nguyen, Poli 等人 2023),它使用 HG38 数据集进行预训练,该数据集由单个包含约 45 亿个标记(DNA 碱基对)的人类基因组组成,位于训练拆分中。


图 5

图 5:(DNA 缩放定律。)在 HG38(人类基因组)数据集上进行预训练。(左)固定短上下文长度 210=10242^{10} = 1024 并将规模从 200K\approx 200K 增加到 40M\approx 40M 参数,Mamba 的缩放效果优于基线。(右)固定模型规模并增加序列长度,同时保持标记/批次和总训练标记固定。与基线不同,Mamba 的选择机制促进了随着上下文长度增加而更好的性能。

4.3.1 缩放:模型规模

在此实验中,我们研究了具有各种模型主干的基因组学基础模型的缩放属性(图 5 左)。

训练。为了使基线具有优势,我们在 1024 的短序列长度上进行训练;如第 4.3.2 节所示,我们预计结果在更长的序列长度上会更偏向 Mamba。我们固定全局批大小为 1024,每批总共 2201M2^{20} \approx 1M 个标记。模型训练了 10K10K 个梯度步,总共 10B10B 个标记。

结果。图 5(左)显示 Mamba 的预训练困惑度随着模型规模平滑提高,并且 Mamba 的缩放效果优于 HyenaDNA 和 Transformer++。例如,在 40M\approx 40M 参数的最大模型规模下,曲线显示 Mamba 可以以大约 3 倍到 4 倍更少的参数匹配 Transformer++ 和 HyenaDNA 模型。

4.3.2 缩放:上下文长度

在下一个 DNA 实验中,我们研究了模型相对于序列长度的缩放属性。我们仅比较 HyenaDNA 和 Mamba 模型,因为二次方注意力在更长的序列长度上变得极其昂贵。我们预训练了序列长度为 210=1024,212=4096,214=16384,216=65536,218=262144,220=10485762^{10} = 1024, 2^{12} = 4096, 2^{14} = 16384, 2^{16} = 65536, 2^{18} = 262144, 2^{20} = 1048576 的模型。我们固定模型规模为 6 层,宽度为 128(约 1.3M-1.4M 参数)。模型训练了 20K20K 个梯度步,总共 330B\approx 330B 个标记。更长的序列长度使用了类似于(Nguyen, Poli 等人 2023)的序列长度预热。

结果。图 5(右)显示 Mamba 能够利用更长的上下文,甚至高达 1M 长度的极长序列,并且其预训练困惑度随着上下文的增加而提高。另一方面,HyenaDNA 模型随着序列长度的增加而变差。这从第 3.5 节关于选择机制属性的讨论中是直观的。特别是,LTI 模型不能选择性地忽略信息;从卷积的角度来看,非常长的卷积核正在聚合跨长序列的所有信息,这可能非常嘈杂。注意,虽然 HyenaDNA 声称随着更长的上下文而提高,但他们的结果没有控制计算时间。

4.3.3 合成物种分类

我们评估了模型在通过随机采样其 DNA 的连续片段来对 5 种不同物种进行分类的下游任务上的表现。此任务改编自 HyenaDNA,它使用了物种 {人类, 狐猴, 老鼠, 猪, 河马}。我们将任务修改为更具挑战性,通过对五种大猿物种 {人类, 黑猩猩, 大猩猩, 红毛猩猩, 倭黑猩猩} 进行分类,已知它们共享 99% 的 DNA。


图 6

图 6:(大猿 DNA 分类。)在序列长度 210=10242^{10} = 1024220=10485762^{20} = 1048576 上微调后的准确率,使用相同上下文长度的预训练模型。数值结果在表 13 中。

图 7

图 7:(音频预训练。)Mamba 在自回归音频建模方面提高了性能,优于先前最先进的模型(Sashimi),同时改进了长达一分钟的上下文或百万长度的序列(控制计算)。


4.4 音频建模与生成

对于音频波形模态,我们主要与 SaShiMi 架构和训练协议(Goel 等人 2022)进行比较。该模型包括:

  1. 一个具有两个池化阶段的 U-Net 主干,池化因子为 pp,每阶段使模型维度 DD 加倍,
  2. 每个阶段交替使用 S4 和 MLP 块。

我们考虑用 Mamba 块替换 S4+MLP 块。实验细节在附录 E.4 中。

4.4.1 长上下文自回归预训练

我们评估了 YouTubeMix(DeepSound 2017)上的预训练质量(自回归下一个样本预测),这是一个先前工作使用的标准钢琴音乐数据集,由 4 小时的独奏钢琴音乐组成,采样率为 16000 Hz。预训练细节很大程度上遵循标准的语言建模设置(第 4.2 节)。图 7 评估了将训练序列长度从 213=81922^{13} = 8192 增加到 2201062^{20} \approx 10^6 的影响,同时保持计算固定。(数据整理方式存在一些轻微的边缘情况,这可能导致缩放曲线出现扭结。例如,只有分钟长的片段可用,因此最大序列长度实际上受限于 60s16000Hz=96000060s \cdot 16000Hz = 960000。)

Mamba 和 SaShiMi(S4+MLP)基线都随着更长的上下文长度而持续提高;Mamba 在整个过程中表现更好,并且差距在更长的长度上扩大。主要指标是每字节位数(BPB),它是预训练其他模态的标准负对数似然(NLL)损失的常数因子 log(2)\log(2)

我们注意一个重要的细节:这是本文中唯一一个我们从实参数化切换到复参数化的实验(第 3.6 节)。我们在附录 E.4 中展示了额外的消融实验。

4.4.2 自回归语音生成

SC09 是一个基准语音生成数据集(Donahue, McAuley, 和 Puckette 2019;Warden 2018),由采样率为 16000 Hz 的 1 秒片段组成,包含数字“零”到“九”,具有高度可变的特征。我们很大程度上遵循 Goel 等人(2022)的自回归训练设置和生成协议。

表 4 显示了 Mamba-UNet 模型与 Goel 等人(2022)的各种基线相比的自动化指标:WaveNet(Oord 等人 2016)、SampleRNN(Mehri 等人 2017)、WaveGAN(Donahue, McAuley, 和 Puckette 2019)、DiffWave(Z. Kong 等人 2021)和 SaShiMi。一个小型的 Mamba 模型优于最先进的(且大得多的)GAN 和基于扩散的模型。一个参数与基线匹配的更大模型进一步显著提高了保真度指标。

表 5 采用了小型 Mamba 模型,并研究了外层阶段和中心阶段不同架构的组合。它表明 Mamba 在外层块中始终优于 S4+MLP,并且在中心块中 Mamba > S4+MLP > MHA+MLP。


表 4:(SC09)在具有挑战性的固定长度语音片段数据集上进行无条件生成的自动化指标。(从上到下)自回归基线、非自回归基线、Mamba 和数据集指标。

模型参数NLL ↓FID ↓IS ↑mIS ↑AM ↓
SampleRNN35.0M2.0428.961.713.021.76
WaveNet4.2M1.9255.082.275.801.47
SaShiMi5.8M1.8731.995.1342.570.74
WaveGAN19.1M-2.034.9036.100.80
DiffWave24.1M-1.925.2651.210.68
+ SaShiMi23.0M-1.425.9469.170.59
Mamba6.1M1.8520.946.2688.540.52
Mamba24.3M1.8600.677.33144.90.36
训练--0.008.56292.50.16
测试--0.028.33257.60.19

表 5:(SC09 模型消融)具有 6M 参数的模型。在 SaShiMi 的 U-Net 主干中,有 8 个中心块在序列长度 1000 上运行,两侧各被 8 个在序列长度 4000 上运行的外部块夹在中间,再被 8 个在序列长度 16000 上运行的外部块夹在中间(总共 40 个块)。8 个中心块的架构独立于其余部分进行消融。注意,由于效率限制,Transformer(MHA+MLP)没有在更重要的外部块中进行测试。

外层中心NLL ↓FID ↓IS ↑mIS ↑AM ↓
S4+MLPMHA+MLP1.8591.455.0647.030.70
S4+MLPS4+MLP1.8671.435.4253.540.65
S4+MLPMamba1.8591.425.7156.510.64
MambaMHA+MLP1.8501.375.6358.230.62
MambaS4+MLP1.8531.076.0573.340.55
MambaMamba1.8520.946.2688.540.52

4.5 速度和内存基准

我们在图 8 中对 SSM 扫描操作(状态扩展 N=16N=16)的速度以及 Mamba 的端到端推理吞吐量进行了基准测试。我们的高效 SSM 扫描在序列长度 2K 以上时比我们所知的最佳注意力实现(FlashAttention-2(Dao 2024))更快,并且比 PyTorch 中的标准扫描实现快 20-40 倍。Mamba 实现了比同等规模的 Transformer 高 4-5 倍的推理吞吐量,因为没有 KV 缓存,它可以利用更高的批大小。例如,Mamba-6.9B(未训练)的推理吞吐量将高于 5 倍更小的 Transformer-1.3B。详细信息在附录 E.5 中,其中还包括内存消耗的基准测试。

图 8

图 8:(效率基准。)(左)训练:我们的高效扫描比标准实现快 40 倍。(右)推理:作为循环模型,Mamba 可以实现比 Transformer 高 5 倍的吞吐量。

4.6 模型消融

我们对模型的组件进行了一系列详细的消融实验,重点关注 Chinchilla 标记计数下 350M\approx 350M 模型规模的语言建模设置(与图 4 相同的设置)。

4.6.1 架构

表 6 研究了架构(块)及其内部 SSM 层(图 3)的影响。我们发现

  • 在先前等同于全局卷积的非选择性(LTI)SSM 中,性能非常相似。
  • 用实值变体替换先前工作中的复值 S4 变体对性能影响不大,这表明(至少对于 LM 而言)在考虑硬件效率时,实值 SSM 可能是更好的选择。
  • 用选择性 SSM (S6) 替换其中任何一个都会显著提高性能,验证了第 3 节的动机。

表 6:(消融:架构和 SSM 层。)Mamba 块的表现类似于 H3,同时更简单。在内部层中,LTI 模型的不同参数化之间几乎没有差异,而选择性 SSM (S6) 提供了巨大的改进。更具体地说,S4(实数)变体是 S4D-Real,S4(复数)变体是 S4D-Lin。

模型架构SSM 层困惑度
HyenaH3Hyena10.24
H3H3S4 (复数)10.30
-H3S4 (实数)10.34
-H3S68.95
模型架构SSM 层困惑度
-MambaHyena10.75
-MambaS4 (复数)10.54
-MambaS4 (实数)10.56
MambaMambaS68.69

表 7:(消融:选择性参数。)Δ\Delta 是最重要的参数(定理 1),但将多个选择性参数结合在一起会产生协同效应。

选择性 Δ\Delta选择性 BB选择性 CC困惑度
10.93
10.15
9.99.98
9.81
8.71

表 8:(消融:AA 的参数化。)当 SSM 是选择性时,基于 S4D-Lin(Gu, Gupta 等人 2022)的更标准初始化比 S4D-Real 或随机初始化表现更差。

AnA_n 初始化困惑度
An=1/2+niA_n = -1/2 + ni复数9.16
An=1/2A_n = -1/2实数8.85
An=(n+1)A_n = -(n+1)实数8.71
Anexp(N(0,1))A_n \sim \exp(\mathcal{N}(0, 1))实数8.71
  • Mamba 架构的表现类似于 H3 架构(并且在使用选择性层时似乎略好)。

我们还在附录 E.2.2 中研究了将 Mamba 块与其他块(如 MLP(传统架构)或 MHA(混合注意力架构))交错的情况。

4.6.2 选择性 SSM

表 7 通过考虑选择性 Δ,B\Delta, BCC 参数的不同组合(算法 2)来消融选择性 SSM 层,表明由于其与 RNN 门控的联系(定理 1),Δ\Delta 是最重要的参数。

表 8 考虑了 SSM 的不同初始化,这些初始化已被证明在某些数据模态和设置中会产生巨大差异(Gu, Goel, 和 Ré 2022;Gu, Gupta 等人 2022)。在语言建模上,我们发现更简单的实值对角线初始化(S4D-Real,第 3 行)比更标准的复值参数化(S4D-Lin,第 1 行)表现更好。随机初始化也工作得很好,这与先前工作的发现一致(Mehta 等人 2023)。

表 9 和表 10 分别考虑了改变 Δ\Delta(B,C)(B, C) 投影的维度。将它们从静态变为选择性提供了最大的好处,而进一步增加维度通常会以参数数量的微小增加为代价,适度提高性能。

特别值得注意的是,当状态大小 NN 增加时,选择性 SSM 的性能得到了显著提升,仅以 1% 的额外参数成本就实现了超过 1.0 的困惑度改进。这验证了我们在第 3.1 节和第 3.3 节中的核心动机。


5 讨论

我们讨论相关工作、局限性以及一些未来方向。

相关工作。附录 A 讨论了选择机制如何与类似概念相关联。附录 B 提供了 SSM 和其他相关模型的扩展相关工作。


表 9:(消融:Δ\Delta 的表达能力。)Δ\Delta 的选择机制通过输入的投影来构建它。将其投影到维度 1 即可提供性能的大幅提升;进一步增加它可以在参数适度增加的成本下提供进一步的改进。状态大小固定为 N=16N=16

Δ\Delta 投影大小参数 (M)困惑度
-358.99.12
1359.18.97
2359.38.97
4359.78.91
8360.58.83
16362.18.84
32365.28.80
64371.58.71

表 10:(消融:SSM 状态维度。)(上)恒定 BBCC(下)选择性 BBCC。增加 SSM 状态维度 NN(可以被视为循环状态维度的扩展因子)可以在参数/FLOPs 成本可忽略不计的情况下显著提高性能,但前提是 BBCC 也是选择性的。Δ\Delta 投影大小固定为 64。

状态维度 NN参数 (M)困惑度
1367.19.88
2367.49.86
4368.09.82
8369.19.82
16371.59.81
1367.19.73
2367.49.40
4368.09.09
8369.18.84
16371.58.71

没有免费午餐:连续-离散谱。结构化 SSM 最初被定义为连续系统 (1) 的离散化,并且对感知信号(例如音频、视频)等连续时间数据模态具有很强的归纳偏置。如第 3.1 节和第 3.5 节所述,选择机制克服了它们在文本和 DNA 等离散模态上的弱点;但这反过来可能会阻碍它们在 LTI SSM 擅长的数据上的性能。我们对音频波形的消融实验更详细地检查了这种权衡。

下游可供性。基于 Transformer 的基础模型(特别是 LLM)拥有丰富的属性和与预训练模型交互的模式生态系统,例如微调、适应、提示、上下文学习、指令微调、RLHF、量化等。我们特别感兴趣的是,Transformer 的替代方案(如 SSM)是否具有类似的属性和可供性。

缩放。我们的经验评估仅限于小模型规模,低于大多数强开源 LLM(例如 Llama(Touvron 等人 2023))以及其他循环模型(如 RWKV(B. Peng 等人 2023)和 RetNet(Y. Sun 等人 2023))的阈值,这些模型已在 7B 参数规模及以上进行了评估。Mamba 在这些更大的规模下是否仍然表现良好,还有待评估。我们还注意到,缩放 SSM 可能涉及本文未讨论的进一步工程挑战和模型调整。


6 结论

我们为结构化状态空间模型引入了一种选择机制,使它们能够在序列长度上线性扩展的同时执行上下文相关推理。当合并到简单的无注意力架构中时,Mamba 在一组不同的领域上实现了最先进的结果,其性能匹配或超过了强大的 Transformer 模型。我们对利用选择性状态空间模型为不同领域构建基础模型感到兴奋,特别是在需要长上下文的新兴模态(如基因组学、音频和视频)中。我们的结果表明,Mamba 是作为通用序列模型主干的有力候选者。

致谢 我们感谢 Karan Goel、Arjun Desai 和 Kush Bhatia 对草稿提供的有益反馈。

参考文献 [1] Martin Arjovsky, Amar Shah, and Yoshua Bengio. “Unitary Evolution Recurrent Neural Networks”. In: The International Conference on Machine Learning (ICML). 2016, pp. 1120–1128.

硬核测试

正确率:0 / 5
1

根据论文,Transformer 架构在处理长序列时面临的主要计算挑战是什么?

2

Mamba 模型引入的“选择性机制”主要解决了传统 SSM 的什么问题?

3

Mamba 模型为了在循环模式下实现高效计算,采用了什么关键技术?

4

关于 Mamba 模型的性能表现,下列说法错误的是?

5

在状态空间模型(SSM)的背景下,“线性时间不变性(LTI)”指的是什么?