Fast and Memory-Efficient Exact Attention with IO-Awareness

FlashAttention computes exact attention: its results are identical to those of standard/vanilla attention. This stands in contrast to approximate-attention algorithms that gain speed by trading away accuracy. Finally, IO-awareness means the improvements come from increasing IO efficiency rather than from reducing computation.

Motivation

First let’s recap the standard procedure of computing attention in the Transformer architecture.

It starts with the input tokens’ vectors, generating the matrices Q, K, V.

Then, by multiplying Q with the transpose of K, we obtain S.

Next, S undergoes row-wise softmax operation to get the attention matrix P.

Finally, P is multiplied by V to get the attention output O.

For simplicity, unless specified otherwise, the following discussion ignores the scaling factor of the dot product, multi-head attention, and dropout.

Let’s see how attention is computed on a physical GPU using PyTorch code.

Q, K, and V are stored in HBM (High Bandwidth Memory) with shape N \times d, where N is the sequence length and d is the feature dimension. The process is as follows (see the PyTorch sketch after the list):


  1. Load Q, K matrices from HBM to SRAM.

  2. Compute S as Q \cdot K^T.

  3. Write S back to HBM.

  4. Load S from HBM to SRAM.

  5. Compute P as Softmax of S.

  6. Write P back to HBM.

  7. Load P and V from HBM to SRAM.

  8. Compute O as P \cdot V.

  9. Write O back to HBM.

  10. Return O.
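A minimal PyTorch sketch of this procedure, for illustration only: scaling, masking, and dropout are omitted as noted above, and in eager PyTorch the intermediates S and P are materialized in GPU memory exactly as in the steps listed.

```python
import torch

N, d = 1024, 128
device = "cuda" if torch.cuda.is_available() else "cpu"     # on a GPU, these tensors live in HBM
Q, K, V = (torch.randn(N, d, device=device) for _ in range(3))

S = Q @ K.T                     # steps 1-3: N x N score matrix, written to memory
P = torch.softmax(S, dim=-1)    # steps 4-6: N x N attention matrix, a second quadratic intermediate
O = P @ V                       # steps 7-9: N x d output
```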


There are many temporary variable reads/writes, like matrices S and P, whose sizes grow quadratically with sequence length.

These intermediate results are necessary for gradient computation during backpropagation.

compute/memory bandwidths in attention calculation

Training speed is limited by one of two bottlenecks: Compute Bound or Memory Bound.

Compute Bound constraints arise from operations such as large matrix multiplications and multi-channel convolutions, which move relatively little data but are computationally intensive.

Memory Bound constraints occur when HBM reads and writes cannot keep up with the compute units, leaving computational resources idle.

Typical memory-bound operations include element-wise operations such as ReLU and Dropout, and reduction operations such as sum and Softmax.

Attention calculations are mostly Memory Bound.

Optimizations for Memory Bound scenarios involve fusing multiple operations (kernel fusion) so that a chain of operations reads from and writes to HBM only once instead of once per operation. Intermediate results needed for backpropagation are then recomputed rather than stored, again to save HBM accesses. Those interested can refer to my earlier videos on gradient checkpointing.
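As a quick illustration of the recompute-instead-of-store idea, here is a minimal gradient-checkpointing sketch using PyTorch's standard torch.utils.checkpoint utility; the layer sizes are arbitrary and unrelated to FlashAttention's kernels.

```python
import torch
from torch.utils.checkpoint import checkpoint

# A block whose intermediate activations we choose not to keep in memory.
block = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512))
x = torch.randn(8, 512, requires_grad=True)

# checkpoint() discards block's internal activations after the forward pass
# and recomputes them during backward, trading compute for memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```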

Memory in GPUs is hierarchical, with fast-access on-chip cache and slower-access off-chip HBM. To optimize IO speed, computations should access on-chip cache as much as possible while reducing off-chip HBM access. Refer to my previous video on GPU architecture for more information.

Earlier Attention improvements focused on reducing computation. FlashAttention focuses on reducing IO access and accelerating IO speed via on-chip cache. Its goal is to avoid Attention operations’ HBM read/writes by:

  1. Matrix partitioning and fusing all Attention operations without caching intermediate results in HBM.
  2. Recomputing intermediate results during backpropagation to mitigate the computational cost of not caching them.

These improvements speed up training by 2-4x and reduce memory usage from quadratic to linear in the sequence length; at a sequence length of 4096, FlashAttention uses roughly 20x less memory than standard PyTorch attention.

Consider the A100-40GB SXM, which offers 312 TFLOPS of FP16/BF16 compute and 1555 GB/s of memory bandwidth. Under mixed-precision training, the operational intensity threshold (ridge point) of this hardware is

\frac{\pi}{\beta} = \frac{312 \times 10^{12}}{1555 \times 10^{9}} \approx 201 \ \text{FLOPS/Byte}

Suppose we compute the attention score matrix S = QK^{T}; its operational intensity is

\frac{\pi_t}{\beta_t} = \frac{2N^2 d}{2(2Nd + N^2)} = \frac{N^2 d}{2Nd + N^2}

For more details refer to this blog. The bottleneck may vary for different values of N and d.

| N | d | operations/bytes | bottleneck |
|---|---|---|---|
| 256 | 128 | 64 | memory bound |
| 1024 | 128 | 102 | memory bound |
| 4096 | 128 | 120 | memory bound |
| 256 | 256 | 85 | memory bound |
| 1024 | 256 | 171 | memory bound |
| 4096 | 256 | 228 | compute bound |
| 256 | 512 | 102 | memory bound |
| 1024 | 512 | 256 | compute bound |
| 4096 | 512 | 410 | compute bound |
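As a rough check, a small Python sketch that reproduces the operations/bytes column above under the same assumptions (FP16, i.e. 2 bytes per element; HBM traffic counted as reading Q and K and writing S):

```python
def attention_intensity(N, d, bytes_per_elem=2):
    """Operational intensity of S = Q K^T in FLOPs per byte of HBM traffic."""
    flops = 2 * N * N * d                              # one multiply-accumulate per (row, col, k)
    hbm_bytes = bytes_per_elem * (2 * N * d + N * N)   # read Q and K, write S
    return flops / hbm_bytes

RIDGE = 312e12 / 1555e9   # ~201 FLOPs/byte for the A100-40GB SXM figures above

for d in (128, 256, 512):
    for N in (256, 1024, 4096):
        r = attention_intensity(N, d)
        print(f"N={N:5d}  d={d:3d}  {r:6.1f} FLOPs/byte  ->  "
              f"{'compute bound' if r > RIDGE else 'memory bound'}")
```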

Approach

  • Through blockwise computation and kernel fusion, avoid caching intermediate results in HBM.
  • Recompute the required intermediate results during the backward pass.

hardware prerequisites

  • GPU SRAM (Static Random-Access Memory): 19 TB/s (20 MB)
  • GPU HBM (High Bandwidth Memory): 1.5 TB/s (40 GB)
  • CPU DRAM: 12.8 GB/s (>1 TB)

Algorithm FlashAttention
  1. Set block sizes B_c = \left\lfloor \frac{M}{4d} \right\rfloor, B_r = \min \left(\left\lfloor \frac{M}{4d} \right\rfloor, d \right).
  2. Initialize \mathbf{O} = (0)_{N \times d} \in \mathbb{R}^{N \times d}, \ell = (0)_{N} \in \mathbb{R}^{N}, m = (-\infty)_{N} \in \mathbb{R}^{N} in HBM.
  3. Divide \mathbf{Q} into T_r = \left\lceil \frac{N}{B_r} \right\rceil blocks \mathbf{Q}_1, \ldots, \mathbf{Q}_{T_r} of size B_r \times d each, and divide \mathbf{K}, \mathbf{V} into T_c = \left\lceil \frac{N}{B_c} \right\rceil blocks \mathbf{K}_1, \ldots, \mathbf{K}_{T_c} and \mathbf{V}_1, \ldots, \mathbf{V}_{T_c}, of size B_c \times d each.
  4. Divide \mathbf{O} into T_r blocks \mathbf{O}_1, \ldots, \mathbf{O}_{T_r} of size B_r \times d each, divide \ell into T_r blocks \ell_1, \ldots, \ell_{T_r} of size B_r each, divide m into T_r blocks m_1, \ldots, m_{T_r} of size B_r each.
  5. for 1 \leq j \leq T_c do
  6. \quadLoad \mathbf{K}_j, \mathbf{V}_j from HBM to on-chip SRAM.
  7. \quadfor 1 \leq i \leq T_r do
  8. \quad\quadLoad \mathbf{Q}_i, \mathbf{O}_i, \ell_i, m_i from HBM to on-chip SRAM.
  9. \quad\quadOn chip, compute S_{ij} = \mathbf{Q}_i \mathbf{K}_j^{\top} \in \mathbb{R}^{B_r \times B_c}.
  10. \quad\quadOn chip, compute \tilde{m}_{ij} = \text{rowmax}(S_{ij}) \in \mathbb{R}^{B_r}, \tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} (pointwise), \tilde{\ell}_{ij} = \text{rowsum}(\tilde{P}_{ij}) \in \mathbb{R}^{B_r}.
  11. \quad\quadOn chip, compute m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}, \ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\ell}_{ij} \in \mathbb{R}^{B_r}.
  12. \quad\quadWrite \mathbf{O}_i \leftarrow \text{diag}(\ell_i^{\text{new}})^{-1} (\text{diag}(\ell_i) e^{m_i - m_i^{\text{new}}} \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{P}_{ij} \mathbf{V}_j) to HBM.
  13. \quad\quadWrite \ell_i \leftarrow \ell_i^{\text{new}}, m_i \leftarrow m_i^{\text{new}} to HBM.
  14. \quadend for
  15. end for
  16. Return \mathbf{O}.
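Below is a pure-PyTorch sketch of this forward pass, for illustration only: the real FlashAttention runs as a fused CUDA kernel, and the default block sizes and test sizes here are arbitrary choices rather than the B_r, B_c derived from the SRAM size M.

```python
import torch

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    """Blockwise exact attention following the loops above (no scaling, masking, or dropout)."""
    N, d = Q.shape
    O = torch.zeros(N, d)
    l = torch.zeros(N)                       # running softmax denominators, one per row
    m = torch.full((N,), float("-inf"))      # running row maxima

    for j in range(0, N, B_c):               # outer loop: K/V blocks (kept "on chip")
        Kj, Vj = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):           # inner loop: Q/O blocks
            Qi, Oi = Q[i:i + B_r], O[i:i + B_r]
            li, mi = l[i:i + B_r], m[i:i + B_r]

            S_ij = Qi @ Kj.T                                 # B_r x B_c score block, never stored globally
            m_ij = S_ij.max(dim=1).values                    # block-local row max
            P_ij = torch.exp(S_ij - m_ij[:, None])
            l_ij = P_ij.sum(dim=1)

            m_new = torch.maximum(mi, m_ij)                  # merge running statistics
            l_new = torch.exp(mi - m_new) * li + torch.exp(m_ij - m_new) * l_ij

            # rescale the old partial output and add the new block's contribution
            O[i:i + B_r] = ((li * torch.exp(mi - m_new))[:, None] * Oi
                            + torch.exp(m_ij - m_new)[:, None] * (P_ij @ Vj)) / l_new[:, None]
            l[i:i + B_r], m[i:i + B_r] = l_new, m_new
    return O

Q, K, V = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax(Q @ K.T, dim=-1) @ V      # unscaled reference, matching the simplification above
assert torch.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-4)
```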

Online Softmax

safe softmax

Vanilla softmax can be numerically unstable, so FlashAttention uses a safe version of softmax in which the maximum of the vector is subtracted before exponentiation.

\text{softmax}(x_i) = \frac{\exp\{x_i - \max(x)\}}{\sum_{j=1}^{n} \exp\{x_j - \max(x)\}}

3-pass softmax

For i=1, ..., N,

m_i \leftarrow \max(m_{i-1}, x_i)

For i=1, ..., N,

d_i \leftarrow d_{i-1} + \exp\{x_i - m_N\}

For i=1, ..., N,

a_i \leftarrow \frac{\exp\{x_i - m_N\}}{d_N}
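A small PyTorch sketch of these three passes, assuming m_0 = -\infty and d_0 = 0, checked against torch.softmax:

```python
import torch

def safe_softmax_3pass(x):
    """Safe softmax computed with the three sequential passes listed above."""
    N = x.shape[0]
    m = torch.tensor(float("-inf"))
    for i in range(N):                  # pass 1: m_i = max(m_{i-1}, x_i)
        m = torch.maximum(m, x[i])
    d = torch.tensor(0.0)
    for i in range(N):                  # pass 2: d_i = d_{i-1} + exp(x_i - m_N)
        d = d + torch.exp(x[i] - m)
    a = torch.empty_like(x)
    for i in range(N):                  # pass 3: a_i = exp(x_i - m_N) / d_N
        a[i] = torch.exp(x[i] - m) / d
    return a

x = torch.randn(16)
assert torch.allclose(safe_softmax_3pass(x), torch.softmax(x, dim=0), atol=1e-6)
```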

2-pass softmax
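This heading refers to what is commonly called online softmax: the first two passes can be fused by keeping a running maximum and rescaling the running sum whenever the maximum grows. A sketch of that recurrence (again assuming m_0 = -\infty and d_0 = 0):

For i=1, ..., N,

m_i \leftarrow \max(m_{i-1}, x_i), \quad d_i \leftarrow d_{i-1} \exp\{m_{i-1} - m_i\} + \exp\{x_i - m_i\}

For i=1, ..., N,

a_i \leftarrow \frac{\exp\{x_i - m_N\}}{d_N}

This is the same rescaling that FlashAttention applies blockwise in steps 10-12 of the algorithm above, which is why Softmax never needs to see a full row at once.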

readme

We’ll now explore how matrix partitioning and operator fusion reduce HBM accesses. Because Softmax is the tricky part to compute blockwise, we set it aside for now (it is discussed later) and simply assume the output is computed directly as S \times V.

The process starts by reading the first two rows of Q, the first three columns of K^{T}, and the corresponding three rows of V from HBM into SRAM. Multiplying the Q block by the K^{T} block yields a block of S, which is never written to HBM; it is used immediately with the V block. The result is not the final first two rows of O but an intermediate partial result that will be updated later.

Keeping the K and V blocks in SRAM, we then load Q's middle two rows from HBM, compute the corresponding partial result for O's middle rows, and repeat for O's final rows. Next, we load the following columns of K^{T} and the corresponding rows of V, iterate over the Q blocks again, and update O using the earlier intermediate results. This continues, block by block, until all of K and V have been processed and O holds the final output.
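A minimal sketch of this partitioned computation with Softmax skipped, so the output is simply (QK^{\top})V; the block sizes (two rows of Q, three rows of K/V) mirror the animation and are otherwise arbitrary:

```python
import torch

N, d, B_r, B_c = 8, 4, 2, 3    # tiny sizes; B_c = 3 mirrors the three K^T columns per step above
Q, K, V = (torch.randn(N, d) for _ in range(3))

O = torch.zeros(N, d)
for j in range(0, N, B_c):                 # K^T columns / V rows kept in "SRAM"
    Kj, Vj = K[j:j + B_c], V[j:j + B_c]
    for i in range(0, N, B_r):             # stream Q (and O) blocks
        O[i:i + B_r] += (Q[i:i + B_r] @ Kj.T) @ Vj   # the S block is used immediately, never stored

assert torch.allclose(O, (Q @ K.T) @ V, atol=1e-5)   # matches the unpartitioned product
```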

Through partitioning and fusion, we avoid storing the intermediate S in HBM, significantly reducing IO time. The remaining obstacle is Softmax: it is a row-wise operation whose normalization sum needs the whole row, so making the fusion work requires a way to compute Softmax over partitions.

In mixed-precision training (FP16), Softmax risks overflow when the exponents are large; Safe Softmax fixes this. It subtracts the largest value m from every term, so all exponents become non-positive and the exponentials stay representable in FP16.

Safe Softmax first finds the maximum m, transforms each x_i into e^{x_i - m}, and sums these terms for normalization. When a row is split into partitions x^{(1)} and x^{(2)}, the global maximum is m(x) = \max(m(x^{(1)}), m(x^{(2)})), and each partition's exponential sum is rescaled by the coefficient e^{m(x^{(1)}) - m(x)} or e^{m(x^{(2)}) - m(x)} so that it is expressed relative to the global maximum. Adding the rescaled partial sums then gives exactly the normalizer of the full-row Softmax.

This requires keeping a couple of extra running statistics per row (the current maximum m and the current sum \ell), but that storage is negligible, and the small amount of extra computation is a favorable trade for the reduction in IO.
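A small sketch of that merging rule: the per-partition (max, sum) statistics can be combined into the full-row statistics without ever seeing the two partitions together (the split point here is arbitrary):

```python
import torch

x = torch.randn(10)
x1, x2 = x[:6], x[6:]                      # two partitions of a single row

m1, m2 = x1.max(), x2.max()                # per-partition maxima
d1, d2 = torch.exp(x1 - m1).sum(), torch.exp(x2 - m2).sum()   # per-partition exponential sums

m = torch.maximum(m1, m2)                  # global maximum
d = torch.exp(m1 - m) * d1 + torch.exp(m2 - m) * d2           # rescale each sum, then add

assert torch.allclose(d, torch.exp(x - x.max()).sum())        # matches the full-row normalizer
```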

Pseudo-code overview: Q, K, V are stored in HBM with shape N \times d, and the SRAM has size M. The column block size is B_c = \lfloor M/4d \rfloor (the factor 4 accounts for the Q, K, V, O blocks that must fit on chip), and the row block size is B_r = \min(\lfloor M/4d \rfloor, d) to cap the size of the Q block. We initialize O, \ell, m in HBM and partition Q, K, V as well as O, \ell, m into blocks.

The outer loop runs over K/V blocks and the inner loop over Q blocks, matching our animations. For each K_j, V_j block loaded into SRAM, we loop over the Q_i, O_i, \ell_i, m_i blocks: compute S_{ij} = Q_i K_j^{\top}, then the block-local row max \tilde{m}_{ij}, exponentials \tilde{P}_{ij}, and row sums \tilde{\ell}_{ij}; update the running maximum and sum; rescale and update O_i using \text{diag}(\ell_i^{\text{new}})^{-1}; and write O_i, \ell_i, m_i back to HBM.

For backpropagation, only Softmax's per-row statistics m and \ell are saved, so the blocks of S and P can be cheaply recomputed during the partitioned backward pass, akin to gradient checkpointing.
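A tiny sketch of that recomputation, assuming the forward pass also kept the final per-row statistics \ell and m (which the simplified forward sketch above does not return):

```python
import torch

def recompute_P_block(Qi, Kj, m_i, l_i):
    """Rebuild a softmax block from Q, K and the saved row statistics, without the N x N matrices."""
    S_ij = Qi @ Kj.T
    return torch.exp(S_ij - m_i[:, None]) / l_i[:, None]   # equals the corresponding block of P
```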

Tip

FlashAttention increases computation slightly but drastically reduces HBM IO, markedly decreasing training time. FlashAttention-2 follows the same principles with further optimizations: it reduces non-matmul computation, swaps the Q and K/V loops (the outer loop runs over Q), improves parallelism, and exploits the block structure to skip computation on fully masked upper-triangular blocks in causal attention, reducing computation further.