Flash Attention
此内容尚不支持你的语言。
Motivation
大家好 今天我们来讨论一下FlashAttention 它现在是训练大模型默认采用的技术 从FlashAttention的论文题目就可以看出 它有两大优势 一,Fast 它可以加快模型训练的速度 二,Memory Efficient 显存高效的 它可以减少显存的占用 并且它保证Exact Attention 也就是它和标准的Attention计算得到的结果是完全一致的 并不像其他一些算法是以降低Attention的精度来提高训练速度的 最后,With IO Awareness 说明了它是通过对IO感知的方式来进行训练的 也就是整个算法是以改进IO效率来达到的 这个论文题目写的确实不错 严简易赅 首先我们看一下Transformer架构里一个Attention的标准计算过程 首先是输入的Token的向量 然后生成Q,K,V矩阵 然后通过Q乘以K的转制得到S 接着对S按行求Softmax 接着得到注意力矩阵P 最后P乘以V得到注意力输出O 这里我们为了讲解方便,对整个过程进行了简化 没有对S做Scale,也没有多头注意力机制,也没有Dropout 那我们用PyTorch写的代码 在实际显卡上的Attention是如何计算的呢? 首先,矩阵Q,K,V存储在HBM里 它们的形状都是N乘以D N是序列长度,D是特征维度 第一步,从HBM加载Q,K矩阵到SRAM 第二步,计算出S,它等于Q乘以K的转制 第三步,将S写到HBM 第四步,将S加载到SRAM 第五步,计算出P,它等于Softmax S 第六步,将P写出到HBM 第七步,从HBM加载P和V到SRAM 第八步,计算O,它等于P乘以V 第九步,把O写出到HBM 第十步,返回O 可以看到中间有很多临时变量的读写 比如S和P矩阵 它们的大小都是随着序列长度的平方增长的 比如,拉马森8B 支持的序列长度为8192 每个头的维度为128 中间的S,P,临时矩阵占用的显存非常大 保留中间结果,比如S,P,会占用显存 但是还是需要的 因为反向传播时需要它们来计算梯度
First let’s recap the standard procedure of computing attention in the Transformer architecture. It starts with the input tokens’ vectors, generating the matrices , , . Then, by multiplying with the transpose of , we obtain . Next, undergoes row-wise softmax operation to get the attention matrix . Finally, is multiplied by to get the attention output . For the sake of simplification, unless specified otherwise, the following discussion does not consider scaled dot product, multi-head attention, or dropout.
Let’s see how attention is computed on a physical GPU using PyTorch code. , , and matrices are stored in HBM (High Bandwidth Memory) with shapes , where is sequence length and is feature dimension. The process is as follows:
-
Load , matrices from HBM to SRAM.
-
Compute as .
-
Write back to HBM.
-
Load from HBM to SRAM.
-
Compute as Softmax of .
-
Write back to HBM.
-
Load and from HBM to SRAM.
-
Compute as .
-
Write back to HBM.
-
Return .
There are many temporary variable reads/writes, like matrices and , whose sizes grow quadratically with sequence length. These intermediate results are necessary for gradient computation during backpropagation.
compute/memory bandwidths in attention calculation
在模型训练时 制约训练速度有两种情况 一种情况是Compute Bound 训练速度的瓶颈在于运算 比如对于大的矩阵乘法 还有多Channel的卷积操作 这些操作都是需要的数据量不大 但是计算很复杂 还有另一种情况是Memory Bound 训练速度的瓶颈在于对 对于HBM数据的读取速度 从HBM读取数据的速度跟不上运算的速度 算力在等待数据 主要的操作包括两类 一类是安慰的操作 比如软路和Dropout 还有一类是规约操作 比如SUM, Softmax等 这些操作都是需要数据很多 但是计算相对简单 Attention计算操作 主要是Memory Bound的计算 可以看到 Compute Bound的操作 比如矩阵乘法 只占用了非常少的时间 但是Memory Bound的计算 占据了很多的时间 对于Memory Bound的优化 主要通过融合多个操作 也叫做Fusion 它节约了原来多个操作之间 要存取HBM的时间 让多个操作 只要存取一次HBM 但是中间结果 我们之前说过 它在反向传播时是有用的 但这里为了节省 访问HBM的时间 我们不保存中间结果 在反向传播时重新计算 这里不清楚的话 可以看我之前讲 梯度检查点的视频 显存里的存储是分级的 有芯片内的缓存 还有芯片外的HBM显存 芯片内缓存容量小 但是访问快 芯片外的HBM容量大 但是访问慢 所以对于优化IO速度 应该尽可能的让计算 访问芯片内的缓存 而尽可能减少 访问芯片外的HBM的显存 这里不清楚的同学 可以看我上一期讲的 GPU原理的视频 之前很多对Attention改进的算法 都着眼于减少计算量 但是FlashAttention 着眼于减少IO访问量 以及通过访问芯片内缓存 而加快IO的速度 所以FlashAttention的目标 就是尽量避免Attention操作 从HBM来读写 它是通过以下两点来实现的 一通过对矩阵分块 并且融合Attention内的所有操作 不缓存中间结果到HBM 来加快速度 二通过在反向传播时 重新计算中间结果 来解决不缓存中间结果 带来的7度计算问题 通过这两点改进 它达到了2到4倍的训练速度的提升 而且将原来训练时对显存的占用 随序列长度平方增长 减小为随序列长度线性增长 可以看到序列长度越长 它对显存节省越多 序列长度为4096时 它比PyTorch的实现 节省20倍的显存占用
Training speed constraints include Compute Bound and Memory Bound scenarios. Compute Bound constraints are due to operations like large matrix multiplications and multi-channel convolutions, which require minimal data but are computationally intensive. Memory Bound constraints occur when HBM data read speeds cannot keep up with computation speeds, leading to idle computational resources. Main operations include element-wise operations like ReLU and Dropout and reduction operations like SUM and Softmax. Attention calculations are mostly Memory Bound.
Optimizations for Memory Bound scenarios involve fusing multiple operations (Fusion) to reduce HBM access time, allowing multiple operations to access HBM only once. However, intermediate results needed for backpropagation are recomputed rather than stored to save HBM access time. Those interested can refer to my earlier videos on gradient checkpointing.
Memory in GPUs is hierarchical, with fast-access on-chip cache and slower-access off-chip HBM. To optimize IO speed, computations should access on-chip cache as much as possible while reducing off-chip HBM access. Refer to my previous video on GPU architecture for more information.
Earlier Attention improvements focused on reducing computation. FlashAttention focuses on reducing IO access and accelerating IO speed via on-chip cache. Its goal is to avoid Attention operations’ HBM read/writes by:
- Matrix partitioning and fusing all Attention operations without caching intermediate results in HBM.
- Recomputing intermediate results during backpropagation to mitigate the computational cost of not caching them.
These improvements have enhanced training speed 2-4 times and reduced memory usage from quadratic to linear growth with sequence length. At a sequence length of 4096, it saves 20 times the memory compared to PyTorch.
Consider A100-40GB SXM, which has a compute bandwidth of 312TFLOPS and memory bandwidth of 1555GB/s and mixed precision training, the operational intensity for the particular hardware is
Suppose we are to calculate attention for ,
For more details refer to this blog. The bottleneck may vary for different values of and .
Approach
- Through block computation and fusion, reduce the cache of intermediate results onto HBM.
- Recompute the required intermediate results during backward pass.
hardware prerequisites
- GPU SRAM(Static Random-Access Memory): 19TB/s(20MB)
- GPU HBM(High Bandwidth Memory): 1.5TB/s(40GB)
- CPU DRAM: 12.8GB/s(>1TB)
- Set block sizes , .
- Initialize , , in HBM.
- Divide into blocks of size each, and divide in to blocks and , of size each.
- Divide into blocks of size each, divide into blocks of size each, divide into blocks of size each.
- for do
- Load from HBM to on-chip SRAM.
- for do
- Load from HBM to on-chip SRAM.
- On chip, compute .
- On chip, compute , (pointwise), .
- On chip, compute , .
- Write .
- Write to HBM.
- end for
- end for
- Return .
Online Softmax
safe softmax
Vanilla softmax could encounter numerical instability, hence a safe version of softmax is used in flash attention where the maximum of the vector is subtracted before exponentiation.
3-pass softmax
For ,
For ,
For ,
2-pass softmax
read more
下面我们看一下 如何通过矩阵分块和融合多个计算 来减少对HBM的访问 这里我们先跳过Softmax的操作 因为它比较特殊 分块计算比较麻烦 后面来单独讨论 我们先认为结果直接就是S乘以V 我们看一下整个过程 开始时我们从HBM里 读取Q的前两行 K转制的前三列 V的前三列 然后传入到SRAM上 对它们进行计算 Q乘以K的转制得到S 并不存入HBM 然后直接和V的分块进行计算 这时我们得到的并不是最终的结果 O的前两行 因为我们知道 O是对所有的V的一个加权平均 目前只是对V的前三行进行加权平均 后边我们还要对O进行更新 这里我们用浅一点的颜色来表示 这个值还只是一个中间结果 接着K和V的分块还保留在SRAM里 从HBM里读取O的中间 Q的中间两行 然后经过同样的计算 得到O的第三和第四行的中间结果 然后继续保留K和V的分块在SRAM里 从HBM里读取Q的最后两行 经过同样的计算 得到O的最后两行的中间结果 接下来读取K转制的后三列 V的后三行 Q的前两行 进行计算 这里算出的O是对V的后三行值的加权平均 再从HBM里读取O之前保存的中间结果 也就是对V的前三行的加权平均值 进行加核 这就是O最终的结果了 同样保持K和V的分块不变 从HBM里读取下一个分块的Q 进行计算 从HBM里读取之前计算的中间结果 O 加核更新后存入HBM 最后继续保持SRAM里的K和V分块不变 从HBM里读取最后一个分块的Q 进行计算 从HBM里读取之前计算的中间结果 O 加核更新后存入HBM 这时就完成了一个Attention的计算 我们可以发现 通过将矩阵分块 以及将多部计算进行融合 中途没有将中间计算结果 S存入HBM 大大减少了IO的时间 这一切看起来都不错 除了Softmax Softmax是暗行进行的 只有一行所有的数据都计算完成后 才能进行这里的求核计算 所以我们想要让我们之前的矩阵分块 对Attention多部进行融合计算 得以进行的前提 是必须解决Softmax的分块计算问题 下面我们就来讨论Softmax的分块计算 现在我们训练都是混合精度 在FP16下进行 如果X等于12 则E的X次方 就大于FP16能表示的最大的数了 为了解决这个数值溢出问题 人们提出了一种叫做 Safe Softmax的算法 首先找出从X1到Xn里最大的值m 给Softmax计算公式的分子和分母 同时除以E的m次方 Softmax的结果不变 进行化简就变成了后面的表达式 E的指数部分 就都小于等于0了 这时用FP16表示 就不会有数值溢出的问题了 我们再看一下 Safe Softmax的计算过程 有一组X 通过mx求出X里的最大值 通过px将X变化成E的xi减去mx 通过Lx对所有的px求和 最终的Softmax值就是px除以Lx 接下来我们看一下 如果有一组X 它有In个 我们将它分成两个部分 第一个部分是从1到n 第二个部分是从n加1到2n 然后对两个块分别计算 它们的mx最大值 pxE的x次方减去mx Lxpx的求和 然后我们看是否能求出 对于原始In个X的 正确的Softmax的值 首先mx等于max mx1 mx2 我们需要两个临时变量 保存两个分块里的最大值 mx1 mx2 就可以保证 最后可以求出 全局最大值mx了 第二步 我们把px1和px2拼接起来 然后给它们都乘以一个系数 因为mx是mx1和mx2里 较大的值 所以它肯定等于mx1 或者mx2 我们假设mx等于mx2 那么后面E的指数项就等于0 E的0次方等于1 后边就是px2不变 这里很好理解 因为px2分块计算的时候 它的x减去的就是全局最大值 所以此时不需要再进行调整 那么对于px1的部分 它在分块计算时 它的x减去的是局部最大值 不是全局最大值 那么它和全局最大值比少减了多少 就是这里的mx1减去mx 现在给它补回来 这样乘以这个系数后 px1就被调整为 减去全局最大值的正确结果了 对于Lx的计算也是同样的道理 把Lx1和Lx2调整后加起来 得到全局的正确的加和值 最后用px除以Lx 就得到了2n个x 正确的softmax的值了 可以看到softmax 也可以通过分块来计算了 只是我们需要额外保存几个变量 mx1 mx2 Lx1 Lx2 不过他们对于softmax 每行都各占用一个数字 存储占用非常小 另外就是在分块进行合并时 需要额外的调整计算 增加了计算量 但是这些计算量 相对于减少的IO时间 都是非常划算的 我们看一下伪代码的实现 QKV存储在HBM上 它们的维度都是nxd GPU芯片上的sRAM大小为m 然后确定列分块大小为m除以4d 向上取整 为什么是4 因为要存放QKVO 4个分块矩阵 行分块大小为m除以4d 和d取小的值 这里之所以要和d相比取小的值 是为了控制Q矩阵分块 最大就是一个d乘以d的方阵 不要让Q分块矩阵的行太大 从而让在sRAM里生成的 中间矩阵计算结果太大 而超出sRAM的大小 在HBM里初始化OLM 然后对QKV进行分块 对OLM进行分块 然后进入循环 外循环是KV 内循环是Q 这里和我们之前动画演示的例子 是一致的 首先从HBM里读取K和V的分块 到sRAM 然后循环读取QOLM的分块 到sRAM 计算分块的Q乘以K的转制 然后计算分块的MPL 然后根据之前分块 已经计算出来的最大值 和当前分块的最大值进行比较 找出新的最大值 同时更新L的值 接着更新O的值 对角矩阵的逆矩阵 等于原始矩阵 对角元素的导数构成的矩阵 乘以后面的O值 相当于给每一行 除以每一行的求和值 这里给原来的OI 先乘以原来的DAG LI还原回去 加上新的O的值 再除以新的求和值 更新O值到HBM 更新L M到HBM 最后返回O 接着我们看反向传播 在前向传播时 会保存softmax计算得到的 M和L的中间值 这样在反向传播分块计算时 就可以快速重新计算激活值了 这可以看作是另一种形式的 梯度检查点 最后我们发现 FlashAttention的计算量 比标准实现增加了一些 但是对HBM的IO访问量大幅减少了 训练的时间 更是大幅缩减到原来的1⁶ FlashAttention2 大致思想和FlashAttention1类似 增加了一些工程上的优化 比如减少了非矩阵乘法的计算 将Q改为外循环 KV改为内循环 更进一步减少了对HBM的读写 增加了并行度 最后进一步利用分块计算的优势 如果判断一个分块 是原始矩阵的上三角部分 也就是它是被mask掉的部分 那么就不需要进行Attention的计算了 从而更进一步减少了计算量
We’ll now explore how matrix partitioning and computing fusion reduces HBM access. We bypass Softmax initially for its complexity in partitioned computation, discussed later. Assume results directly as .
The process starts by reading the first two rows of Q, three columns of , and three columns of V from HBM to SRAM for computation. yields , which isn’t stored in HBM, used immediately with V partitions, resulting not in final ‘s two rows but intermediate results updated later. Different colors denote intermediate results needing update.
Keep and V partitions in SRAM, load ‘s middle two rows from HBM, compute similarly for ‘s middle rows, repeat for final rows. Load last columns of , last rows of V, previous Q rows, update using earlier intermediate results. Maintain , V partitions, load next Q chunk, compute, update .
Through partitioning and fusion, we avoid storing intermediate in HBM, significantly reducing IO time. Except for Softmax, row-wise computation requiring complete data for summation, crucial for fusion requires solving Softmax’s partition computation.
Softmax computation in mixed-precision training (FP16) risks overflow for large exponents, fixed by Safe Softmax. It adjusts all terms by the largest value to prevent overflow. Simplifies to expressions where exponents are non-positive, manageable in FP16.
Safe Softmax involves finding maximum , transforming , and summing for normalization. For partitioned values, . Adjust with coefficients, ensuring correct global maximum adjustments. Sum adjusted partitions for correct Softmax.
Though requiring extra variables per row, storage is minor. The balancing of computation with IO reduction is favorable.
Pseudo-code overview: QKV stored in HBM, SRAM size . Column block size . 4 for QKVO partitions. Row block size to control Q block size. Initialize in HBM, partition QKV, .
Outer loop KV, inner Q matches our animations. Load K, V partition to SRAM, loop Q, OL partitions, compute Q block, , , update max values, using DAG’s inverse matrix, update in HBM.
Backpropagation saves Softmax’s intermediate for quick recomputation during partitioned backpropagation, akin to gradient checkpointing.
FlashAttention increases computation slightly but drastically reduces HBM IO, markedly decreasing training time. FlashAttention2 follows similar principles, with further optimizations, reducing non-matrix computation, reversing Q/KV loops, enhancing parallelism, exploiting partition advantages, skipping computations for masked upper-triangle, further reducing computation.