标题

Mamba: Linear-Time Sequence Modeling with Selective State Spaces

具有选择状态空间的线性时间序列建模

背景

发表时间:2023年12月
单位:卡内基梅隆大学
作者:Albert Gu*, Tri Dao*
image

Mamba架构

Mamba是由许多层Mamba堆叠而成,作者提到Mamba架构是受到H3架构(Hungry Hungry Hippo)的启发。Mamba是由H3和门控MLP操作组合而成。

Mamba和Transformer是相似的,

image

状态空间模型

Mamba是一种状态空间模型,是基于S4(序列模型的结构化状态空间)的工作。

状态空间模型帮助我们使用一组定义的输入和输出对物理系统进行建模,输入和输出通过一阶微分方程相关联。
image

这是一个一节微分方程,变量ABCD称为状态变量。负责记住系统的状态并对其进行建模,在模型训练时被保留为可学习的参数。
通过对状态空间方程模型求导,得到状态空间方程的解:

image

理想的语言模型

Transformer在训练的过程中具有高度并行化的架构,不会梯度爆炸或下降。但Transformer的推理是一个迭代的过程,计算复杂度为O(N^2)。而RNN的计算复杂度只有O(1),因为下一个token仅取决于之前的隐藏状态和当前输入。并且,正常的卷积训练是可并行的。

所以,一个理想的语言模型应该是:

S4的两面性 - RNN和CNN

RNN 循环神经网络
在RNN中,每个RNN块都输入当前状态x_t和之前的隐藏状态h_t-1,以输出下一个隐藏状态h_t和可选的输出状态y_t。

image
CNN 卷积神经网络
通过展开递归方程,我们可以得到一个通式,这个通式可以表示为具有固定K大小内核的连续卷积。
image

image

因此,我们在训练的时候就可以切换到卷积模式,进行并行化训练。在推理过程中,就可以切换到循环模式,进行恒定时间推理。
因为卷积的内核是固定的,所以称这种模型为时序不变的SSM。
由于SSM和RNN非常相似,所以也会遇到梯度爆炸或者消失的问题。因此提出过HiPPO矩阵,
image

通过HiPPO表示A矩阵,每个隐藏状态中的A矩阵都会记住历史信息,只需要计算一次。
矩阵A通过跟踪勒让德多项式的系数表示最新历史信息。

红色信号视为我们想要记住或近似的目标信号,黑框表示每个状态的值,基于
该状态值,蓝线绘制勒让德级数。随着每个步骤的进行,HiPPO矩阵会更新
每个步骤。 最近的步骤越多,近似度越好,时间步长越长时,近似度越低。
image

S4的问题

基于序列的建模的根本问题是:将上下文压缩成更小的可学习状态,在效率和状态表示质量之间存在权衡。
比如:transformer效率低,因为它需要存储上下文的KV缓存。但存储和表示长文方面表现良好。而RNN效率高,因为具有有限形态,但是在上下文压缩质量方面效果较差。S4和所有其他状态空间模型在一些任务上,比如说选择性复制和上下文学习归纳表现效果不佳。

Mamba的选择机制

举个例子,我们想过滤一条评论,不过改推文主题,但是删除单词。基于Transformer的模型可以通学习这些单词来做到。但是当前的SSM模型无法做到这一点,因为他们是时序不变的,内核是固定的,可学习参数随着每个新的令牌传入而保持固定。

image

Mamba引入选择性扫描,不依赖卷积和递归的双重属性,仅依赖于循环,由于时间变化的参数化,矩阵A(HiPPO矩阵)保持不变,但是∆,B和C现在变为输入的函数。其中B代表批次大小、L代表序列长度、D代表输入数据的维度数,N代表隐藏状态的维度数。并且和S4不同的是,B和C不是参数,而是一些投影输出
image
image

Mamba的并行扫描操作

由于Mamba支持用RNN,不使用卷积,所以不能进行并行化。但是Mamba设计了一个并行扫描操作。
image

我们先讲一个前缀和问题,核心思想是计算一个数组中每个位置之前所有元素的综合。
image

它的计算方式和RNN类似,每个新状态都是当前输入x_t和先前状态h_t-1的总和。

由于求和运算具有结合性,(a + b) + c = a + (b + c),我们可以将数组拆分成多个部分,分别计算每个部分的前缀和,然后合并结果。这样,不同部分的计算可以并行进行。
image

Mamba的优化

除了并行扫描算法,mamba还使用了内核融合和激活重新计算技术。
对于Mamba来说,扫描操作识别内存绑定的操作,因此内核融合用于减少内存I/O量,即从HBM加载到SRAM。所以,首先将SSM参数 A、Δ、B 和 C 从 HBM 加载到 SRAM,融合进行离散化操作,即从 A、Δ、B 转换 A_bar 和 B_bar , C,递归运算在SRAM中完成,最终被发送回HBM。
在任何深度学习模型中,都有前向传播和反向传播、在前向传播中,我们计算每层的中间值和输出,在后向传播中,我们计算参数的梯度并更新它。

image

在后向传播过程中,我们需要跟踪所有的中间值,会消耗大量内存,所以在实际计算之前,必须将中间值从HBM加载到SRAM,但需要更多的内存和时间,因此Mamba在计算过程中,所有中间值都不会被存储,而是在输入从HBM加载到SRAM时的反向传播期间重新计算。

实验部分

image
image
image

架构图

image

Mamba是基于之前的工作S4(Structured state space model),在S4的基础上做了两个改进:

代码

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")