DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training

背景

大型语言模型(LLMs)在处理长文本数据时,例如生成完整代码库或与长篇文档进行对话,需要保持对长上下文信息的敏感性。这要求模型能够处理和记忆更长的序列信息。使用长序列训练LLMs时,会显著增加激活(activation)的内存占用,这对模型训练提出了新的挑战。传统的注意力机制在处理长序列时内存需求呈二次方增长,这限制了模型规模和序列长度。现有的内存高效注意力机制虽然能够在单个设备上减少峰值内存使用,但缺乏分布式扩展,无法充分利用多设备计算资源。此外,现有的分布式训练方法在处理长序列时效率不高,存在通信开销大和扩展性差的问题。为了克服这些挑战,需要一种能够在分布式环境中有效扩展,同时保持高GPU利用率和低通信开销的注意力机制。本文通过提出新的分布式内存高效注意力机制,旨在优化计算资源的使用,特别是减少不必要的内存占用和计算开销,从而提高训练长上下文LLMs的效率。通过解决长上下文LLMs训练中的瓶颈问题,可以推动开发更大、更复杂的模型,这些模型能够更好地理解和生成语言,从而在各种自然语言处理任务中取得更好的性能。

方法设计

序列并行方法

DistFlashAttn的目标是两方面的:(1) 将单个序列分发到多个workers(例如GPU),以便它们共同利用内存来支持长序列训练;(2) 保持内存高效注意力的IO感知优势,使得训练速度快且内存占用较小。为了分发长序列,DISTFLASHATTN将tokens组成的输入序列均匀地跨P个workers(例如GPU)沿序列维度进行分割。每个worker仅计算并存储一个子序列的激活。

每个worker可以收集与其他子序列相关的所有K和V,然后通过调用现有的单机FlashAttention在本地计算。然而,这种收集通过需要存储完整的键和值列表引入了内存压力。DistFlashAttn通过迭代的方法,每次迭代只从其中一个远程worker获取哪个worker上保存的KV进行计算。

负载平衡

在因果注意力中,每个token只关注它之前的tokens,因此worker之间的负载并不均衡(图1(a))。如图1(b)所示,本文让完成了所有
Attention计算的worker 为工作负载较重的worker执行注意力计算,从而完成负载均衡。

图1:每个worker的workload

DistFlashAttn依赖于点对点(P2P)通信,在计算Attention之前从远程workers获KV,这些通信是可以喝计算重叠的。如图2所示,DiststFlashAttn实现了两种并行化的步骤:提前获取(Fetch)和计算(Compute)。在现代加速器中,这可以通过将Attention计算kernel放在主GPU流中,将P2P通信kernel放在另一个流中来实现。这种优化通过将通信时间隐藏在计算时间内部,可以有效地减少了通信开销。

图2:Overlap的例子

重计算梯度检查点算法

梯度检查点(Gradient checkpointing)经常被用于训练长上下文Transformer模型。系统通常使用启发式方法在每个Transformer层的边界处插入梯度检查点。然而,由于FlashAttention的存在,本文发现以前的梯度检查点策略会导致FlashAttention前向kernel存在冗余重计算。具体来说,在计算MLP层的梯度时,会重新计算包括FlashAttention在内的整个Transformer层的前向过程。在这个过程中,FlashAttention backward kernel会再次按块状重新计算softmax以减少内存使用。这个问题的本质是不论外部系统如何冲计算,FlashAttention在前向过程中都不保存激活值,而是在后向过程中进行重计算。

为了解决这个问题,如图3所示,本文FlashAttention kernel的输出处插入检查点,而不是在Transformer层边界处。因此,DiststFlashAttn只需要计算一次FlashAttention的前向,有效地避免了所有FlashAttention的重计算。

图3:HuggingFace和本文的提督检查点对比

如图4所示,在长序列中前向过程中Attention占主导地位,这表明本文方法在单台机器上训练64K序列的Llama-7b示例时节省了约0.23×32(即约7)秒。此外还可以节省DISTFLASHATTN前向在分布式训练场景中带来的通信开销。

图4:Attention和其他模块相比的前向计算时间

实验验证

本文在至多16张NVIDIA A100上进行了实验验证。与Ring Self-Attention相比,DISTFLASHATTN支持长达8倍的序列长度,速度提升了4.45到5.64倍。与Megatron-LM结合FlashAttention相比,DISTFLASHATTN在序列长度延长2到8倍的情况下,速度提升了1.24到2.01倍。与Ring Attention和DeepSpeed-Ulysses相比,DISTFLASHATTN分别取得了1.67倍和1.26到1.88倍的速度提升。

DISTFLASHATTN能够支持更长的序列长度。如表1所示,Ring Self-Attention(RSA)能够支持的最大序列长度为32K,而DISTFLASHATTN能够支持超过256K的序列长度。在两个节点上,RSA能够支持的最大序列长度为64K,而DISTFLASHATTN能够支持超过512K的序列长度。

表1:与Ring Self-Attention相比,支持的最长训练长度每个iteration的平均时间

总结

这项工作介绍了DistFlashAttn,这是一种基于序列并行性的分布式内存高效注意力原型,用于长上下文Transformer训练。DistFlashAttn展示了包括因果语言建模的负载平衡、分布式注意力计算中通信与计算的重叠,以及重计算感知的检查点策略在内的新颖系统优化。实验评估了多种Transformer模型以及在不同类型的集群上,并且与四个强大的分布式系统基线进行了比较。特别是,与流行的系统Megatron-LM和FlashAttention相比,DistFlashAttn展示了高达2.01倍的速度提升,并能够扩展到8倍更长的序列。