一文通透想颠覆Transformer的Mamba:从SSM、S4到mamba、线性transformer(含RWKV解析)

前言

不知读者发现没有,本文标题的信息含量很大,比如

  1. 出来了一个新的序列模型:Mamba,其基于SSM或S4(Structured State Space for Sequence Modeling,连起来4个S,故简称S4)发展为S6(S4 models with a selection mechanism and computed with a scan),其对应的论文为《Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  2. 该Mamba模型的提出者为Albert Gu、Tri Dao,前者现在是CMU助理教授,多年来一直推动SSM架构发展,曾在DeepMind 工作,后者为FlashAttention的一作
    换言之,除了论文中展示的效果确实不错之外,由于提出者的背景不一般,所以关注的人比较多
  3. Transformer统治各大领域近7年了,7年来,挑战Transformer的模型其实不少 (比如linear attention, gated convolution and recurrent models, and SSMs),该模型能否真正颠覆Transformer的霸权呢?对此,我们可以细究其原理细节,看看其创新到底是否靠谱、力度是否大

加之有一大模型项目开发营的朋友问道,可否在论文100课上解读下Mamba这篇论文,于此,便有了此文,且具备4个特点

  1. 清晰易懂:也为「不需要天天看paper的朋友」而写
    在ChatGPT诞生后的一年来,以大模型为代表的技术发展特别快,经常一个月会出来很多新的技术、模型
    而不一定非得是每天在实验室扎根于科研的人 才有资格去追踪前沿技术发展,还有一大帮可能是出于对前沿技术的了解、兴趣、热爱、应用而想追踪,可这帮朋友平时或因工作或事太多而不一定对每个新技术、新模型都去看一遍论文,即不可能天天看paper

    那咋办呢?他们可能通过一些比如公众号之类的文章去了解,但有的公号文章写的不错,有的则写的不够清晰易懂甚至漏洞百出,会因此让读到这种文章的朋友对新技术、新模型产生畏难心理甚至被误导

    故,我和我司来了,为帮助更多朋友更好、更快、更细致的了解大模型相关技术及其实践,我个人算是笔耕不辍(我自今年年初以来也史无前例的写了近30篇,详见:大模型/AIGC/ChatGPT系列:原理、论文、代码、实战)、团队和我算讲课不停
  2. 中英对比:部分关键的阐述中英文对照学习
    考虑到这些新技术、新模型刚推出的时候,论文还是相对最严谨的参考,所以本文会延续前几篇文章的风格:对于一些关键的阐述会把原英文的表述用斜体且淡色的黑体表示,毕竟有的描述对其翻译相比,用原英文阐述更精准
  3. 足够细致:从HiPPO、SSM、S4起步,逐步推导到Mamba
    目前介绍mamba模型的文章,少部分写得很不错,大部分不是这个细节没深入,便是那个细节没深入,考虑到如果很多关键细节没有介绍的话,那没法彻底理解mamba模型
    因此,本文会尽可能兼顾所有必须写清楚的细节(比如如果不理解SSM和S4则无法理解mamba模型,故本文会从HiPPO、SSM、S4起步,逐步推导到mamba),尽可能一文通透mamba模型
  4. 足够全面:Mamba之外,更精讲「线性Transformer」相关
    友人钟博士曾评论,不带点积注意力机制的开源模型中,有希望超越带注意力机制的Llama架构的,一个是mamba,一个便是TransnormerLLM
    mamba之外,为帮助大家更深入、更全面的理解不带点积注意力机制的线性transformer,本文第三部分将精讲国内的两个相关工作:TransnormerLLM和RWKV

第一部分 背景知识:Transformer时间复杂度、HiPPO、SSM、S4

1.1 Transformer的二次复杂度

通过之前本博客内的另一篇文章《通透理解FlashAttention与FlashAttention2:让大模型上下文长度突破32K的技术之一》,可知

简单理解的话,计算复杂度和序列长度的平方N^2成正比,可以看一个小例子,比如两个相乘的矩阵大小分别为(N \times d) 和(d \times N),矩阵乘法的一种计算方式是使用第一个矩阵的每一行与第二个矩阵的每一列做​点乘​

因为我们需要拿第一个矩阵的每一行去与第二个矩阵的每一列做点乘,所以总共就需要 N^2 次点乘。而每次点乘又需要 d 次乘法,所以总复杂度就为 \mathrm O(N^2d)


精确理解的话,当输入批次大小为 b​ ,序列长度为 N​ 时,
l​ 层transformer模型的计算量为 l *\left(24 b N d^{2}+4 b N^{2} d\right)​,d​则代表词向量的维度或者隐藏层的维度(隐藏层维度通常等于词向量维度)

但这个结果是怎么一步一步计算得到的呢?请看原文

正因为现有的ChatGPT等大模型处理长文本算力消耗巨大,背后原因是Transformer架构中注意力机制的二次复杂度

  1. 一方面,有了针对注意力机制的各种所谓魔改,甚至也有S4、FlashAttention及其二代等
  2. 二方面,S4、FlashAttention等作者提出了新的序列模型:Mamba,在很多语言任务上击败/匹配Transformer性能,具有线性复杂度和5倍推理吞吐量,下文详述

1.2 状态空间模型SSM(State Space Model)

SSM可以视为从输入信号到输出信号的参数化映射

  1. SSMs可以当做是RNN与CNN的结合「These models can be interpreted as acombination of recurrent neural networks (RNNs) and convolutional neural networks (CNNs)
  2. 这类模型可以非常高效地计算为递归或卷积,在序列长度上具有线性或近线性缩放(This class of models can be computed very efficiently as either arecurrence or convolution, with linear or near-linear scaling in sequence length)

然而,他们在对文本等离散且信息密集的数据进行建模时效果较差(they have been less effective at modeling discrete and information-dense data such as text)

从而有了本文要介绍的一类新的选择性状态空间模型(下文详述),它改进了先前的工作,以实现模型在“序列长度线性缩放(scaling linearly in sequence length)”情况下的建模能力

1.3 S4的前身:HiPPO

1.3.1 改进transformer不擅长处理超长的序列的问题:输入u到状态x

如本文开头所说,mamba论文的一作Albert Gu多年来一直在推动SSM的发展

简单来讲,序列数据一般都是离散的数据 比如文本、图、DNA

  1. 但现实生活中还有很多连续的数据,比如音频、视频,对于音视频这种信号而言,其一个重要特点就是有极长的context window
  2. 而在transformer长context上往往会失败,或者注意力机制在有着超长上下文长度的任务上并不擅长(所以你才看到各种对注意力机制的改进,比如flashattention等等,即便如此一般也就32K的上下文长度,在面对100w的序列长度则无能为力),而S4擅长这类任务

为了方便大家更好的理解,Albert Gu举了一个金融领域的例子

  1. 即根据输入,计算其EMA(如下图所示,黑色的一直在跳跃着的曲线是输入x输出y是蓝色的线)

    由于EMA(Exponential Decaying Measure)有着unbounded context(无限长度),Transformers和Convolution因为都只有着有限的上下文窗口而不好计算
  2. Albert Gu发现EMA其实是整个signal的一个summary,相当于是过往所有信号历史的加权平均值,其权重呈指数衰减之势(下图中绿色的线即相当于投影到的指数衰减)

  3. 如果用u表示inputx表示对应的summary(可能你看到这里 觉得表示有点乱,毕竟上面还是输入x 输出y,不过 不要急,很快你会看到:输入u、状态x、输出y)
    那么该summary可以在常数时间内快速计算得到(即summary of entire context update in constant time):

    这个summary作为对之前信息的一个总结,也可以认为是对“当前事物所处在一个什么样的状态”的建模,而随着新信息的不断输入,那么当前事物所处的状态也会不断更新

1.3.2 HiPPO的推导:state compresses the history of input

假设 t_0时刻我们看到了信号 u(t) 的之前部分:

  1. 我们希望在一个memory budget来压缩前面这一段的input来学习特征,一个很容易想到的方法是用多项式去近似这段input

  2. 在我们接收到更多signal的时候,我们希望仍然在这个memory budget内对整段signal进行压缩,自然,你得更新你的多项式的各项系数,如下图底部所示

  3. 以上,会涌现出两个问题:
    1. 如何找到这些最优的近似?
    2. 如何快速地更新多项式的参数?
    为了解决这两个问题,我们需要一个measure去定义一个近似的好坏程度。例如,可以使用EDM

  4. 这就引出了HiPPO(High-order Polynomial Projection Operator)的正式定义,其为两个信号和两个矩阵的组合:

    这个矩阵A就是HiPPO矩阵,比如可以是这样:

  5. HiPPO相当于将函数映射到函数,这里给个通俗的例子解释一下:

    和上面一样,这里的u是原信号,x是压缩后的信号。给定一个持续增长的u,HiPPO允许online update压缩的x。如果使用一个64unit的polynomial压缩器(完全表示需要10000unit,所以是非常高度的压缩),可以发现EDM很不错,保留了大量之前的信息:

    其中红色的线相当于对输入的重建(可以看出来,离当下最近时刻的 其刻画最准确,至于离当下最远的时刻 则其刻画的不那么准确 )
    这里要注意,HiPPO只需要看到这个时刻的多项式(polynomial)参数和在此之前的signal u,不需要看到之前的多项式参数..
  6. 上面都是用EDM这个measure的,但是我们在学习过程中用的往往不只一个measure(例如一个time-varying measure can change over time),这个时候如何去建模?
    最终,作者得到了一个结论:HiPPO可以在各种measure上面成立

1.4 S4的推出:Structured State Space Models

1.4.1 HiPPO的高阶化(输入u到状态x最后输出y)

发现HiPPO在低阶信号上work后,我们希望将它扩展到高阶信号上。阶数越高,与LLM越相似,工作的价值就越大

  1. 但是我们不能直接堆叠HiPPO算子,因为不断增加维度会引起维数爆炸:

  2. 作者想到了非常精妙的一个方法:不考虑input u 到state x,而是直接从state x 到output y
    如下图所示,通过蓝色x(t)的线性组合得到最终的输出红色y(t),这里的 Cx 就是state x 的线性组合,而 D 就是skip connection,是绕开state x ,直接从input u 输出 y 的一个连接:

  3. 这样,我们通过两个方程定义S4
    \rightarrow  一个是之前定义的 x'(下一时刻的 x) 来将input u 记忆成state,如下图左侧所示
    \rightarrow  现在又定义了 y 来将state x 线性组合成一个输出,如下图右侧所示

    相当于输入到状态、状态到输出,至此,也算终于写清楚了S4
  4. 有意思的是,推出来的这些公式组成了一个1960年在ASME会议上提出的State Space Machine! SSM由Kalman提出,原文在这:A New Approach to Linear Filtering and Prediction Problems

1.4.2 Structured SSM

我们正式定义下S4

  1. 首先,有一个state space model,简称为SSM
  2. 其次,在下图所示的两个方程中插入特定的矩阵值

  3. 接着,学习对应的参数

1.5 S4的性质:连续的表示、用Recurrent快速infer、用Convolutional快速训练

接下来,我们来看下如下图所示的S4的三个性质

1.5.1 连续的表示

第一个性质是连续的表示,且就算SSM在离散数据上训练,它仍能学习到底层蕴含的连续信息,因为在SSM眼里,sequence不过是连续信号signal的采样(离散形式),或者说连续的信号模型是离散的序列模型的概括

1.5.2 用Recurrent表示进行快速的infer

第二个性质是有效的online计算,这点之前在HiPPO提到了,就是计算下一时刻的state x' 只需要这一时刻的state x 和全局输入 u

\rightarrow  虽然需要全局输入,但是这个全局的计算是常数时间的,这与RNN相同,而与Transformer/CNN不同
\rightarrow  之所以是常数时间,也与RNN相同,因为有state(中间这条蓝线)这导致下一个state的计算只需要上一个state + 全局的输入

1.5.3 用Convolutional表示进行快速的训练

SSM的一个问题是,当知道未来的signal的时候,训练是低效的。有没有办法并行化SSM?作者提出了使用一个卷积核 K ,绕过状态 x ,直接从输入 u 到输出 y(而非先输入到状态、状态再到输出)

输入u怎么到输出y呢?相当于通过特定的卷积滤波器K对输入进行卷积(即you can involve the input by an exponentially decaying convolution kernel),该滤波器在上图中用绿色线表示

问题好像解决了,但SSM还是存在两个问题

  1. 一个是计算复杂度的问题,最终通过给SSM做结构化(比如使用HiPPO矩阵,相当于变成了S4),即structured state space can be computed faster

  2. 另一个是,作者意识到这个S4某种意义上就是一个很fancy的CNN(包括可以以不同的方式参数化卷积内核),但是context window有时是无限长的
    而刚好convolutional kernel可以无限长(至于单纯的CNN则是有限长的窗口),那其如何设计以适应有时无限长的context window呢?如下图所示

// 待更

第二部分 Mamba的组成结构与原理解析

Mamba在语言、音频、DNA序列模态上都实现SOTA,在最受关注的语言任务上,Mamba-3B超越同等规模的Transformer,与两倍大的Transformer匹敌,并且相关代码、预训练模型checkpoint都已开源

简言之,Mamba是一种状态空间模型(SSM),建立在更现代的适用于深度学习的结构化SSM (简称S6)基础上,与经典架构RNN有相似之处

2.1 Mamba = 有选择处理信息 + 硬件感知算法 + 更简单的SSM架构

与先前的研究相比,Mamba主要有三点创新:

  1. 对输入信息有选择性处理(Selection Mechanism)
    具体而言,设计了一个简单的选择机制,通过“参数化SSM的输入”,以便关注或忽略特定的输入。这样一来,模型能够过滤掉与问题无关的信息,并且可以长期记住与问题相关的信息
    focus on or ignore particular inputs. we design a simple selection mechanism by parameterizing the SSM parameters based on the input. This allows the model to filter out irrelevant information and remember relevant information indefinitely
  2. 硬件感知的算法(Hardware-aware Algorithm)
    该算法采用“扫描”而非“卷积”来进行模型的循环计算,但为了避免GPU内存层次结构中不同级别之间的IO访问,它没有具体化扩展状态
    algorithm that computes the model recurrently with a scan instead of convolution, but does not materialize the expanded state in order to avoid IO access between different levels of theGPU memory hierarchy

    当然,这点也是受到了S5(Simplified State Space Layers for Sequence Modeling)的启发
  3. 更简单的架构
    将SSM架构的设计与transformer的MLP块合并为一个块(combining the design of prior SSM architectures with the MLP block of Transformers into a single block),来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计

2.1.1 选择性状态空间模型:从S4到S6

作者认为,序列建模的一个基础问题是把上下文压缩成更小的状态(We argue that a fundamental problem of sequence modeling is compressing context into a smaller state)

  • 从这个角度来看,注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(也就是KV缓存),直接导致训练和推理消耗算力大
    For example, attention is both effective and inefficient because it explicitly does not compress context at all. This can be seen from the fact that auto regressive inference requires explicitly storing the entire context (i.e. the KV cache), which directly causes the slow linear-time inference and quadratic-time training of Transformers.

    好比,Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢
  • RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制
    On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time training. However, their effectiveness is limited by how well this state has compressed the context.

    好比,RNN每次只参考前面固定的字数,写的快,但容易忘掉更前面的内容
  • Mamba的解决办法是,让模型对信息有选择性处理,可以关注或忽略特定的内容,即使状态大小固定也能压缩上下文
    好比,Mamba每次参考前面所有内容的一个概括,越往后写对前面内容概括得越狠,丢掉细节、保留大意

总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:高效的模型必须有一个小的状态,而有效的模型必须有一个包含来自上下文的所有必要信息的状态,而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的

在其前身结构化状态空间模型S4中,其有4个参数(∆, A, B, C)

且它们都是固定的,不随输入变化(即与输入无关),这些参数控制了以下两个阶段

  • 第一阶段(1a 1b),通常采用固定公式A = 𝑓𝐴(∆, A)和B = 𝑓𝐵(∆, A, B),将“连续参数”(∆,A,B)转化为“离散参数”(A,B),其中(𝑓𝐴, 𝑓𝐵) 称为离散化规则,且可以使用多种规则来实现这一转换,例如下述方程中定义的零阶保持(ZOH)
    The first stage transforms the “continuous parameters” (∆, A, B) to “discrete parameters” (A, B) through fixed formulas A = 𝑓𝐴(∆, A) and B = 𝑓𝐵(∆, A, B), where the pair (𝑓𝐴, 𝑓𝐵) is called a discretization rule. Various rules can be used such as the zero-order hold (ZOH) defined in equation (4).
    \overline{\boldsymbol{A}}=\exp (\Delta \boldsymbol{A}) \quad \overline{\boldsymbol{B}}=(\Delta \boldsymbol{A})^{-1}(\exp (\Delta \boldsymbol{A})-\boldsymbol{I}) \cdot \Delta \boldsymbol{B}
  • 第二阶段(2a 2b,和3a 3b),在参数由(∆,A, B, C)变换为(A, B, C)后,模型可以用两种方式计算,即线性递归(2)或全局卷积(3)
    After the parameters have been transformed from (∆, A, B, C) ↦ (A, B, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3)

    通常,如第一部分最后所讲的,模型使用卷积模式(3)可以进行高效的并行化训练(其中整个输入序列提前看到),并切换到循环模式(2)以高效的自回归推理(其中输入每次只看到一个时间步)
    the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (wheret he inputs are seen one timestep at a time)

    为何可以做高效的并行化呢,因为该模式能够绕过状态计算,并实现仅包含(B, L, D)的卷积核(3a)
    Thus the more efficient convolution mode wasintroduced which could bypass the state computation and materializes a convolution kernel (3a) of only (𝙱, 𝙻, 𝙳)

为了方便大家更好的理解,再解释下


通过第一部分的讲解,可知\boldsymbol{A} \in \mathbb{R}^{N \times N}, \boldsymbol{B} \in \mathbb{R}^{N \times 1}, \boldsymbol{C} \in \mathbb{R}^{1 \times N}矩阵都可以由N个数字表示,为了对批量大小为B、长度为L、具有D个通道的输入序列x进行操作,SSM被独立地应用于每个通道(the A ∈ ℝ𝑁×𝑁, B ∈ ℝ𝑁×1 , C ∈ ℝ1×𝑁 matrices can all be represented by 𝑁 numbers. To operate over an input sequence 𝑥 of batch size 𝐵 and length 𝐿 with 𝐷 channels, the SSM is applied independently to each channel.)

请注意,在这种情况下,每个输入的总隐状态具有DN维,在序列长度上计算它需要O(BLDN)时间和内存( the total hidden state has dimension 𝐷𝑁 per input, and computing it over the sequence length requires 𝑂(𝐵𝐿𝐷𝑁) time and memory)

下面,再分析下各个变量的含义

  • \Delta,一个标量,类似遗忘门
    如sonta所说,这个量跟RNN里的gating有着深刻的联系(∆ in SSMs can be seen to play a generalized role of the RNN gating mechanism),data dependent的 Δ 跟RNN的forget gate的功能类似(step size Δ that represents the resolution of the input discretization of SSMs is the principled foundation of heuristic gating mechanisms.)

  • B,起到的作用类似于:进RNN的memory
  • C,起到的作用类似于:取RNN的memory
    所以有人说,data dependent的B/C的功能跟RNN的input/output gate类似
  • A,意味着对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因

而在Mamaba中,作者让这些参数BC\Delta成为输入的函数,让模型能够根据输入内容自适应地调整其行为

  1. 从S4到S6的过程中,可以看出BC的大小从原来的(D,N)变成了(B,L,N)\Delta的大小由原来的D变成了(B,L,D)
    进一步,咱们通过
    s_{B}(x)=\operatorname{Linear}_{N}(x)
    s_{C}(x)=\operatorname{Linear}_{N}(x)
    s_{\Delta}(x)=\operatorname{Linear}_{D}(x)
    \tau_{\Delta}=\text { softplus }
    来逐一将B, C, \Delta数据依赖(data dependent)化
    至于上面的所谓\text { Linear }_{d}(x)代表把D维的输入向量x经过一个线性层map到d维
    N即SSM的隐藏层维度(hidden dimension),当然 一般设的比较小

    且每个位置的BC\Delta都不相同(S4时是所有位置共享)
  2. 虽然A没有变成data dependent,但是通过state space model的离散化操作之后,(\bar{A}, \bar{B})会经过outer product变成(B, L, N, D)的data dependent张量,以一种parameter efficient的方式来达到data dependent的目的

    当然,到底效果变好的最大原因是哪一块,可以参考这篇做下相关的实验:Gated Linear Attention Transformers with Hardware-Efficient Training

2.1.2 硬件感知的状态扩展:借鉴Flash Attention

为了让传统的SSM在现代GPU上也能高效计算,Mamba中也使用了Flash Attention技术

核心思想是利用内存的不同层级结构处理SSM的状态,减少高带宽但慢速的HBM内存反复读写这个瓶颈,即

  1. 在更高速的SRAM内存中执行离散化和递归操作,再将输出写回HBM
    具体来说,我们不是在GPU HBM(高带宽内存)中将大小(𝙱,𝙻,𝙳,𝙽)的扫描输入(A, B),而是
    \rightarrow  首先,直接将SSM参数(∆,A,B,C)从慢速HBM加载到快速SRAM
    \rightarrow  然后,在SRAM中进行离散化和递归计算
    \rightarrow  最后,将大小为(B,L,D)的最终输出写回HBM
    Concretely, instead of preparing the scan input (A, B) of size (𝙱, 𝙻, 𝙳, 𝙽) in GPU HBM (high-bandwidth memory), we load the SSM parameters (∆, A, B, C) directly from slow HBM to fast SRAM, perform the discretization and recurrence in SRAM, and then write the final outputs of size (𝙱, 𝙻, 𝙳) back to HBM.
  2. 通过并行扫描算法实现并行化
  3. 当输入从HBM加载到SRAM时,中间状态不被保存,而是在反向传播中重新计算
    the intermediate states are not stored but recomputed in the backward pass when the inputs are loaded from HBM to SRAM

如下图所示

  1. 结构化SSM通过一个更高维的潜在状态ℎ(例如 𝑁= 4),独立地将输入𝑥的每个通道(例如 𝐷= 5)映射到输出𝑦
    Structured SSMs independently map each channel (e.g. 𝐷= 5) of an input 𝑥 to output 𝑦t hrough a higher dimensional latent state ℎ(e.g. 𝑁= 4).
  2. 先前的SSM通过巧妙的替代计算路径,避免实现这个大的有效状态「𝐷𝑁,乘以批量大小𝐵和序列长度𝐿」,且要求:(∆, A, B, C)参数随时间保持不变
    Prior SSMs avoid materializing this large effective state 「𝐷𝑁, times batch size 𝐵and sequence length 𝐿」 through clever alternate computation paths requiring time-invariance: the (∆, A, B, C) parameters are constant across time.
  3. 我们的选择机制重新添加了依赖于输入的动态特性,这也需要一个精心设计的、考虑硬件的算法,以便只在GPU内存层次结构中更有效的级别上实现这些扩展状态
    Our selection mechanism adds back input-dependent dynamics, which also requires a careful hardware-aware algorithm toonly materialize the expanded states in more efficient levels of the GPU memory hierarchy.

2.1.3 简化的SSM架构

将大多数SSM架构比如H3的基础块,与现代神经网络比如transformer中普遍存在的门控MLP相结合,组成新的Mamba块,重复这个块,与归一化和残差连接结合,便构成了Mamba架构

顺带提一嘴,transformer quality in linear time以及mega moving average equipped gated attention的这两个工作,也用了类似的结构:即删除transformer的ffn/glu结构

2.2 对Improving SSMs with Selection的进一步阐述

2.2.1 三个任务的对比:copying、selective copying、induction heads

如下图所示,有三个任务

  1. (左)复制任务的标准版本涉及输入和输出元素之间的固定间距,可以通过线性递归和全局卷积等时不变模型轻松解决
    (Left) The standard version of the Copying task involves constant spacing between input and output elements and is easily solved by time-invariant models such as linear recurrences and global convolutions.
  2. (右上)选择性复制任务在输入之间具有随机间距,需要使用时变模型,在内容上能够灵活地选择记忆或忽略输入
    (Right Top) The Selective Copying task has random spacing in between inputs and requires time-varying models that can selectively remember or ignore inputs depending on their content.

    相当于选择性复制任务通过改变“要记忆的tokens的位置”来改进纯粹的复制任务(Arjovsky, Shah和Bengio 2016)。它需要内容感知推理,以便能够记住相关的标记(有色),并过滤掉不相关的标记(白色)
    The Selective Copying task modifies the popular Copying task (Arjovsky, Shah, and Bengio 2016) by varying the position of the tokens to memorize. It requires content-aware reasoning to be able to memorize the relevant
  3. (右下)归纳头部任务是联想回忆的一个例子,需要根据上下文检索答案,这是LLM关键的能力
    (Right Bottom) The Induction Heads task is an example of associative recall that requires retrieving an answer
    based on context, a key ability for LLMs.

    其实,归纳头部任务是一种众所周知的机制,据推测可以解释LLMs的大部分上下文学习能力(Olsson et al. 2022)。它需要上下文感知的推理,以便知道何时在适当的上下文中产生正确的输出(黑色)
    The Induction Heads task is a well-known mechanism hypothesized to explain the majority of in-context learning abilities of LLMs (Olsson et al. 2022). It requires context-aware reasoning to know when to produce the correct output in the appropriate context (black).

2.2.2

// 待更

2.3 实验结果

Mamba在Chinchilla缩放定律下预训练时,语言任务优于同类开源模型

下游任务上,每个规模尺寸的Mamba都是同类最佳,并且通常与两倍规模的基线性能匹配,特别是当序列长度增加到512k时,相比使用FlashAttention-2的Transformer快几个数量级,而且不会内存不足

最后,有的新闻稿会说Mamba是第一个实现匹配Transformer性能的线性时间序列模型,其实第一个是TransNormerLLM

第三部分 Mamba近似工作之线性Transformer:从TransnormerLLM到RWKV

3.1 从线性Transformer谈到TransnormerLLM

3.1.1 什么是线性transformer:cosformer

我们已知,Dot-product attention与softmax归一化是transformer捕捉长程依赖关系的基石。然而,其关于序列长度的二次空间和时间复杂性使其计算开销令人望而却步,特别是对于长输入。为了解决这个问题,最近提出了许多方法,如稀疏注意力矩阵(sparse attention matrix),低秩表示(lowrank representations)或基于核的方法(kernel-based methods)等,让这些方法皆有其各自的局限性

以上之外,另一个重要的方法便是线性Transformer(Linear Transformer),其将transformer的复杂度从O(N^2)降低为O(N),这对加快Transformer整体的加速非常重要

关于线性Transformer,可以看下这两篇论文:《Fast Autoregressive Transformers with Linear Attention》、以及友人钟博士团队的《COSFORMER : RETHINKING SOFTMAX IN ATTENTION


线性Transformer的核⼼思想是通过Kernel trick的⽅式,如下图右侧所示,将QKV的左乘变成右乘,从⽽将理论计算复杂度降为线性

我们已知

  • Transformer中self-attention的典型计算如下:
    O =\operatorname{softmax}\left(\frac{Q K^{T^{\prime}}}{\sqrt{d}}\right) V
    其中矩阵Q、K、V是由输入 x 经线性变化得到的query、key、value
  • 如果暂不考虑缩放因子,则自注意力的计算可以分解为向量运算
    \operatorname{Attn}(Q, K, V)_{t}=\frac{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}} \odot v_{i}}{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}}}
    其中,上式的分母是一个归一化因子,确保所有的注意力得分加起来等于1

    这一步怎么做到的呢,援引HeptaAI的一个说明图如下

接下来,便有以下一系列推导

  1. 如果用下标i来表示矩阵的第i行(如 Q_i表示矩阵 Q 的第i行),那么可以将上述公式中的计算用如下形式抽象出来:
    O_{i} =\frac{\sum_{j=1}^{N} \operatorname{sim}\left(Q_{i}, K_{j}\right) }{\sum_{j=1}^{N} \operatorname{sim}\left(Q_{i}, K_{j}\right)} V_{j}
    其中sim()为抽象出的计算Query和Key相似度的函数
  2. Linear Transformer采用了kernel来定义sim():
    \operatorname{sim}\left(Q_{i}, K_{j}\right)=\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}
    其中 \phi 是一个特征映射函数,可根据情况自行设计

    考虑到矩阵乘法有结合律,softmax只能左乘,linear可以右乘,而右乘更快,正因为矩阵乘积的这个属性可以实现注意力操作的线性复杂度:
    \left(\phi(Q) \phi(K)^{T}\right) V=\phi(Q)\left(\phi(K)^{T} V\right)

    相当于不是显式地计算注意力矩阵A=Q K^{T} \in \mathbb{R}^{N \times N},而是先计算\phi(K)^{T} V \in \mathbb{R}^{d \times d},然后乘以\phi(Q) \in \mathbb{R}^{N \times d},从而最终的时间复杂度为O\left(N d^{2}\right)
    考虑到,在一般的NLP任务中,一个头d的特征维度总是比输入序列长度N (d \ll N)小得多,因此可以忽略d,实现O(N)的计算复杂度
  3. 因此,self-attention可以从
    O_{i}=\frac{\sum_{j=1}^{N}\left(\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}\right) V_{j}}{\sum_{j=1}^{N}\left(\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}\right)}
    转化为:
    O_{i}^{\prime}=\frac{\phi\left(Q_{i}\right) \sum_{j=1}^{N} \phi\left(K_{j}\right)^{T} V_{j}}{\phi\left(Q_{i}\right) \sum_{j-1}^{N} \phi\left(K_{j}\right)^{T}}

    原始Transformer的计算复杂度之所以随序列长N呈二次方增长,这是因为attention的计算包含两层for循环
    \rightarrow  外层是对于每一个Query,我们需要计算它对应token的新表征
    \rightarrow  内层for循环是为了计算每一个Query对应的新表征,需要让该Query与每一个Key进行计算
    所以外层是 for q in Queries,内层是 for k in Keys,Queries数量和Keys数量都是N,从而复杂度是 O(N^2)

    好比军训时,甲乙丙丁4个人列成一队,计算注意力机制的过程相当于
    首先把甲站到队伍的前面,算“其”与“自己在内所有人”的相似度,即计算这些的内积值:
    甲q甲k、甲q乙k、甲q丙k、甲q丁k
    接着,再乙站到队伍的前面,算“其”与“自己在内所有人”的相似度,即计算这些的内积值:
    乙q甲k、乙q乙k、乙q丙k、乙q丁k

    丙、丁以此类推,即分别计算这两批内积值:
    丙q甲k、丙q乙k、丙q丁k、丙q丙k
    丁q甲k、丁q乙k、丁q丙k、丁q丁k


    而Linear Transformer,它只有外层for q in Queries这个循环了,因为求和项的计算与i 无关,所以所有的 Q_i可以共享求和项的值。换言之,求和项的值可以只计算一次,然后存在内存中供所有 Q_i 去使用,所以Linear Transformer的计算复杂度是O(N)
  4. 引入以下两个新符号:
    \begin{array}{c} S_{i}=\sum_{j=1}^{i} \phi\left(K_{j}\right)^{T} V_{j} \\ Z_{i}=\sum_{j=1}^{i} \phi\left(K_{j}\right)^{T} \end{array}

    稍作变换,可以将Si 和Zi 写作递归形式:
    \begin{array}{l} S_{i}=\sum_{j=1}^{i} \phi\left(K_{j}\right)^{T} V_{j}=\phi\left(K_{i}\right)^{T} V_{i}+\sum_{j=1}^{i-1} \phi\left(K_{j}\right)^{T} V_{j}=\phi\left(K_{i}\right)^{T} V_{i}+S_{i-1} \\ Z_{i}=\sum_{j=1}^{i} \phi\left(K_{j}\right)^{T}=\phi\left(K_{i}\right)^{T}+\sum_{j=1}^{i-1} \phi\left(K_{j}\right)^{T}=\phi\left(K_{i}\right)^{T}+Z_{i-1} \end{array}
    因此,在inference阶段,当需要计算第i时刻的输出时,Linear Transformer可以复用之前的状态 Si−1 和 Zi−1 ,再额外加上一个与当前时刻相关的计算量即可。而Transformer在计算第i时刻的输出时,它在第i-1个时刻的所有计算都无法被i时刻所复用。因此,Linear Transformer更加高效

总结一下:

  • Linear Transformer的计算复杂度为 O(N) (不考虑embedding的维度的情况下)
  • 因为Si可由Si−1计算得到(Zi同理),所以它可实现Sequential Decoding(先算S1,由S1算S2,以此类推)。能Sequential Decoding是让这类Transformer看起来像RNN的核心原因

3.1.2 TransnormerLLM

如qinzhen所说,transnomerLLM相比cosformer,最本质的区别是其位置编码的不同,剩下就是结构细微的优化以及工程

// 待更

3.2 从AFT到RWKV

3.2.1 AFT(Attention Free Transformer)

Attention Free Transformer (AFT) 是Apple公司提出的一种新型的神经网络模型,它在传统的 Transformer 模型的基础上,通过使用像Residual Connection之类的技术来消除注意力机制,从而减少计算量和提升性能

AFT在不同的资料中有不同的表达形式

  • 比如有的资料会写成
    O_{i}^{\prime}=\sigma\left(Q_{i}\right) \odot \frac{\sum_{j-1}^{i} \exp \left(K_{j}+w_{i, j}\right) \odot V_{j}}{\sum_{j=1}^{j} \exp \left(K_{j}+w_{i, j}\right)}
    其中\sigma是sigmoid函数;⊙是逐元素相乘(element-wise product), w_{i,j}是待训练的参数

    AFT采用的形式和上面的Linear Transformer不一样
    \rightarrow  首先是attention score,Linear Transformer仍然是同Transformer一样,为每一个Value赋予一个weight,而AFT会为每个dimension赋予weight
    换言之,在Linear Transformer中,同一个Value中不同dimension的weight是一致的,而AFT同一Value中不同dimension的weight不同
    \rightarrow  此外,attention score的计算也变得格外简单,用K去加一个可训练的bias。Q的用法很像一个gate

    可以很容易把AFT也写成递归形式,这样容易看出,AFT也可以像Linear Transformer,在inference阶段复用前面时刻的计算结果,表现如RNN形式,从而相比于Transformer变得更加高效
  • 还有的资料比如RWKV论文会写成(和上式一个意思)
    \operatorname{Attn}^{+}(W, K, V)_{t}=\frac{\sum_{i=1}^{t} e^{w_{t, i}+k_{i}} v_{i}}{\sum_{i=1}^{t} e^{w_{t, i}+k_{i}}}
    其中,其中 \left\{w_{t, i}\right\} \in R^{T \times T} 是学习的pair-wise位置偏差,每个 w_{t, i} 是一个标量

    下图是对该式的解释说明

    其实从式子上看,AFT无非是将矩阵乘改成了矩阵加,加上模型只能看到前面的token。注意这里的 w 是一个二维矩阵,和attention中的positional encoding作用相似,都是为了给模型输入位置信息

3.2.2 RWKV:试图在Transformer时代重塑RNN

RWKV其实是我司论文审稿GPT第一版(详见此文的第三部分 七月论文审稿GPT第一版:基于论文审稿语料微调RWKV )选用的模型之一,虽然当时第一版用RWKV的效果没符合预期,但在有些任务上的表现还是不错的,加之因为写mamba模型而再次关注到有点类似的RWKV,故本文也顺带讲一下

据RWKV论文可知,RWKV 架构的名称源自timemixingchannel-mixing模块中使用的四个主要模型元素(defined by four fundamental elements that are intrinsic to the timemixing and channel-mixing blocks):

  • R:表示过去的信息,用的sigmoid激活函数
  • W:权重是位置权重衰减向量,是可训练的模型参数(后面还会再出来个U,是对当前位置信号的补偿)
  • K:Key    是类似于传统注意力中的K 的向量
  • V :value 是类似于传统注意力中的V 的向量

每个时间步的主要元素之间的相互作用都是乘法的,如下图所示

在RWKV的结构中,其中的递归被表述为当前输入和前一个时间步的输入之间的线性插值(我们将这种技术称为time-shift mixing或token shift,如下图中的对角线所示)

  • 可以表示为针对输入嵌入的每个线性投影(例如,timemixing中的 R、K、V,以及channel-mixing中的 R、K)进行独立调整,并作为 WKV 的时间相关更新
  • WKV 计算与 AFT 类似,但 W 现在是“通道向量”乘以“相对位置”(下文详述),而不是 AFT 中的pairwise position matrix。我们还引入了一个向量 U 来单独关注当前token,以补偿 W 的潜在退化

一看有点懵,没事,因为其中有不少细节,咱们来逐一阐述

3.2.2.1 RWKV的时间混合(time mix)模块与通道混合(channel mix)模块

如下图所示,假设输入sequence是My name is,目前 t = 2 ,则这里 x_{t-1}是上一个输入token(My), x_t是这个输入token(name)
\mu是遗忘因子,越大对上个token(My)就忘的越多,也就是对这个token(name)更专注,黄色(μ)表示token shift「至于红色(1)表示分母,蓝色(2)表示分子,粉色(3)表示16种分数计算,h代表了分子和分母的元组

可有以下五个公式

先解释前三个公式

  • 在传统Transformer中, q,k,v 本质上都是 x_t 的线性变换,可以用来动态调整表示的子空间维度且增大参数量
  • 在RWKV中, r,k,v 本质上都是 x_t,x_{t-1} 线性组合的变换,且作为计算RKV的输入的x:不再是当前token的embedding,而是当前token与上一个token embedding的加权和

\begin{aligned} r_{t} & =W_{r} \cdot\left(\mu_{r} x_{t}+\left(1-\mu_{r}\right) x_{t-1}\right) \\ k_{t} & =W_{k} \cdot\left(\mu_{k} x_{t}+\left(1-\mu_{k}\right) x_{t-1}\right) \\ v_{t} & =W_{v} \cdot\left(\mu_{v} x_{t}+\left(1-\mu_{v}\right) x_{t-1}\right) \end{aligned}

接下来 重点解释下其中最难的部分第4个公式w k v_{t}

  • 原始的attention是这样的:

\operatorname{Attn}(Q, K, V)_{t}=\frac{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}} \odot v_{i}}{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}}}

  • AFT的attention

\operatorname{Attn}^{+}(W, K, V)_{t}=\frac{\sum_{i=1}^{t} e^{w_{t, i}+k_{i}} v_{i}}{\sum_{i=1}^{t} e^{w_{t, i}+k_{i}}}

  • RWKV的attention
                                            w k v_{t}=\frac{\sum_{i=1}^{t-1} e^{-(t-1-i) w+k_{i}} \odot v_{i}+e^{u+k_{t}} \odot v_{t}}{\sum_{i=1}^{t-1} e^{-(t-1-i) w+k_{i}}+e^{u+k_{t}}}

    怎么理解这个RWKV attention的这个表达式呢?
    受 AFT 的启发,RWKV 中的每个 w_{t, i} 都代表一个「通道时间衰减向量」,该向量乘以相对位置,并且在衰减时从当前时间开始向后追踪(Each wt,i in RWKV is a channelwise time decay vector multiplied by the relative position and traced backward from current time as it decays):
                                            w_{t, i}=-(t-i) w

    其中 w \in\left(R_{\geq 0}\right)^{d} , d 是通道数,RWKV要求 w为非负数,以确保 e^{w_{t, i}} \leq 1 并且确保每个通道的权重在时间上向后衰减(ensure that e wt,i ≤ 1 and the per-channel weights decay backwards in time)
    这个操作与后面的 e^{u} 都是用来建模序列的time decay的

    以上可能解释的比较绕,不够通俗,其实说白了,相比AFT,原来的依靠绝对位置的偏置w_{t, i}没有了,改成了相对位置,并且只有一个参数w向量需要训练
    其次,对当前位置单独处理,增加了参数u

最后,再解释第5个公式

  • 其中 W K V计算, w k v_{t} , 在 Transformers 中扮演 \operatorname{Attn}(Q, K, V)的角色,而不会产生quadratic成本,因为计算的都是标量,这就是上面的第5个公式

                                              o_{t}=W_{o} \cdot\left(\sigma\left(r_{t}\right) \odot w k v_{t}\right)

  • 直观上,随着时间 t 的增加,向量 o_{t} 取决于较长的历史,由越来越多的项的总和表示。对于目标位置 t ,RWKV在 [1, t]的位置区间进行加权求和,然后乘以接受度 \sigma(r) 
    因此,交互作用在给定的时间步长内是乘法的,并在不同的时间步长上求和

最后,通道混合块(channel mix block)根据time-mixing block的输出,然后使用下述三个公式的前两个公式计算一组心的R、K,最后根据下面第三个公式计算最终输出

\begin{aligned} r_{t} & =W_{r} \cdot\left(\mu_{r} x_{t}+\left(1-\mu_{r}\right) x_{t-1}\right) \\ k_{t} & =W_{k} \cdot\left(\mu_{k} x_{t}+\left(1-\mu_{k}\right) x_{t-1}\right) \\ o_{t} & =\sigma\left(r_{t}\right) \odot\left(W_{v} \cdot \max \left(k_{t}, 0\right)^{2}\right) \end{aligned}

3.2.2.2 RWKV的训练阶段与推理阶段
训练阶段:时间并行模式

在训练复杂度上,我们对比下标准注意力与RWKV

  • 对于标准注意力而言,假设是T个最大token,因为RWKV只需要上一时刻的state vector和这一时刻的输入。因此,生成的每一个token只要考虑常数个变量,所以复杂度为\mathcal{O}(T)

    如果是d个通道,则每个 \operatorname{Attn}_{t} 需要进行 T 次求和,每次求和都涉及一维向量分别点乘,复杂度为\mathcal{O}(T d),因此对于整个序列的复杂度为\left(T^{2} d\right)
    \operatorname{Attn}(Q, K, V)_{t}=\frac{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}} \odot v_{i}}{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}}}
    当然,如果是B个序列,则复杂度为O\left(B T^{2} d\right)
  • 对于RWKV而言
    w k v_{t}=\frac{\sum_{i=1}^{t-1} e^{-(t-1-i) w+k_{i}} v_{i}+e^{u+k_{t}} v_{t}}{\sum_{i=1}^{t-1} e^{-(t-1-i) w+k_{i}}+e^{u+k_{t}}}

    \rightarrow  针对\sum_{i=1}^{t-1} e^{(-t-1-i) w+k_{i}} v_{i}, t不是向量下标,意味着对每个 t,我们知道 w,k_i 是复用的,因此,t → T 时复杂度为\mathcal{O}(T d)
    \rightarrow  针对 \sum_{i=1}^{t-1} e^{u+k_{t}} v_{t}, i 不是向量下标,意味着对每个 i ,我们知道 k_t,v_t 是复用的,因此,时间复杂度为 \mathcal{O}(T d)+\mathcal{O}(T d)=\mathcal{O}(T d)

    也就是说,在内层循环,算出的\sum_{i=1}^{t-1} e^{(-t-1-i) w+k_{i}} v_{i}可以直接存起来供外层循环使用。即,RWKV的内外层循环是解耦的
    当然,如果是B个序列,则复杂度为O(B T d)
推理阶段:时间顺序模式

在循环网络中,使用状态 t 的输出作为状态 t+1 的输入是很常见的。这在语言模型的自回归解码推理中尤其明显,要求每个标记在输入下一步之前进行计算,从而使得RWKV 利用其类似 RNN 的结构,称为时间顺序模式(time-sequence mode),如下图所示(来自小冬瓜AIGC)

  • 在这种情况下,可以方便地递归地制定 RWKV 以便在推理过程中进行解码,它利用了每个输出token仅依赖于最新状态的优点,该状态具有恒定的大小,而与序列长度无关
  • 然后,它充当 RNN 解码器,根据序列长度产生恒定的速度和内存占用,从而能够更有效地处理较长的序列。相比之下,自注意力通常需要 KV 缓存相对于序列长度线性增长,从而导致效率下降,并且随着序列变长而增加内存占用和时间

参考文献与推荐阅读

  1. Transformer挑战者出现!FlashAttention作者参与,模型代码都开源,公司已创办
  2. [线性RNN系列] Mamba: S4史诗级升级
  3. Structured State Spaces for Sequence Modeling (S4)
  4. S4: 使用结构化状态空间对长序列进行高效建模
  5. Efficiently Modeling Long Sequences with Structured State Spaces
    首次提出了结构化状态空间S4
  6. S4作者在YouTube上对S4论文的精彩解读
  7. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
  8. RWKV: Reinventing RNNs for the Transformer Era(下载地址2),这是其翻译,这是其解读之一
  9. 【手撕LLM-RWKV】重塑RNN 效率
  10. ..

创作、修订、完善记录

  1. 12.11,开写,且发现Google抓的也是真快(当天用Google搜:mamba模型,本文已排第一)
  2. 12.12,考虑到想理解好mamba,则需要先理解好SSM,故全力完善这几节的内容:“1.2 状态空间模型SSM”、“1.3 S4的前身:HiPPO”、“1.4 S4的推出:Structured State Space Models”
  3. 12.13,完善此节:“1.5 S4的性质:连续的表示、用Recurrent快速infer、用Convolutional快速训练”
  4. 12.14,结合mamba论文,开始精修“第二部分 Mamba的组成结构与原理解析”
    特别是以下这两节
    2.1.1 选择性状态空间模型:从S4到S6
    2.1.2 硬件感知的状态扩展:借鉴Flash Attention
  5. 12.15,开始写:“第三部分 Mamba近似工作之线性Transformer:从AFT、RWKV谈到TransnormerLLM”
    特别是此节:“3.2 RWKV:试图在Transformer时代重塑RNN”
  6. 12.17,修正1.4节中的一个笔误,已修正为:“作者想到了非常精妙的一个方法:不考虑input u 到state x,而是直接从state x 到output y ”
  7. 12.19,在TransNormer的提出者qinzhen的建议之下,补充关于线性transformer的一些解释说明,特别是关键的这一句
    “考虑到矩阵乘法有结合律,softmax只能左乘,linear可以右乘,而右乘更快,正因为矩阵乘积的这个属性可以实现注意力操作的线性复杂度”
  8. 12.23,根据友人钟博士的反馈,在文中强调:第一个实现匹配Transformer性能的线性时间序列模型是TransNormerLLM..