Fast and Memory-Efficient Exact Attention with IO-Awareness

FlashAttention computes exact attention, meaning its results are mathematically identical to standard/vanilla attention. This stands in contrast to approximate-attention algorithms that gain speed by sacrificing accuracy. Finally, IO-awareness indicates that the speedup comes from improving IO efficiency rather than from reducing the amount of computation.

Motivation

First let’s recap the standard procedure of computing attention in the Transformer architecture.

It starts with the input tokens' vectors, generating the matrices $Q$, $K$, $V$.

Then, by multiplying $Q$ with the transpose of $K$, we obtain the score matrix $S = QK^\top$.

Next, $S$ undergoes a row-wise softmax operation to get the attention matrix $P = \mathrm{softmax}(S)$.

Finally, $P$ is multiplied by $V$ to get the attention output $O = PV$.

For simplicity, unless specified otherwise, the following discussion ignores the $1/\sqrt{d}$ scaling factor, multi-head attention, and dropout.

Let’s see how attention is computed on a physical GPU using PyTorch code.

The $Q$, $K$, and $V$ matrices are stored in HBM (High Bandwidth Memory) with shape $N \times d$,

where $N$ is the sequence length and $d$ is the feature dimension. The process is as follows:


  1. Load $Q$, $K$ matrices from HBM to SRAM.
  2. Compute $S = QK^\top$.
  3. Write $S$ back to HBM.
  4. Load $S$ from HBM to SRAM.
  5. Compute $P = \mathrm{softmax}(S)$.
  6. Write $P$ back to HBM.
  7. Load $P$ and $V$ from HBM to SRAM.
  8. Compute $O = PV$.
  9. Write $O$ back to HBM.
  10. Return $O$.
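A minimal PyTorch sketch of this unfused procedure (variable names are mine, chosen to match the notation above; the $1/\sqrt{d}$ scaling is omitted as stated earlier). Note that both $N \times N$ intermediates, `S` and `P`, are fully materialized in GPU memory:

```python
import torch

def standard_attention(Q, K, V):
    # Q, K, V: (N, d) tensors resident in HBM (GPU global memory).
    # S and P are N x N intermediates that are written to and read back
    # from HBM, so memory traffic grows quadratically with N.
    S = Q @ K.transpose(-1, -2)      # scores, shape (N, N)
    P = torch.softmax(S, dim=-1)     # row-wise softmax, shape (N, N)
    O = P @ V                        # output, shape (N, d)
    return O

N, d = 1024, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
Q = torch.randn(N, d, device=device)
K, V = torch.randn_like(Q), torch.randn_like(Q)
O = standard_attention(Q, K, V)
```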


There are many temporary variable reads/writes, like the matrices $S$ and $P$, whose sizes ($N \times N$) grow quadratically with the sequence length.

These intermediate results are necessary for gradient computation during backpropagation.

compute/memory bandwidths in attention calculation

Training speed constraints include Compute Bound and Memory Bound scenarios.

Compute Bound constraints are due to operations like large matrix multiplications and multi-channel convolutions, which perform a large amount of computation relative to the data they read and write.

Memory Bound constraints occur when HBM data read speeds cannot keep up with computation speeds, leading to idle computational resources.

Typical memory-bound operations include element-wise operations such as ReLU and Dropout, and reduction operations such as sum and Softmax.

Attention calculations are mostly Memory Bound.

Optimizations for Memory Bound scenarios involve fusing multiple operations (kernel fusion) so that a chain of operations reads its inputs from HBM and writes its outputs back only once, instead of once per operation. Intermediate results that would be needed for backpropagation are then recomputed rather than stored, trading a little extra computation for less HBM traffic. Those interested can refer to my earlier videos on gradient checkpointing.
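The recompute-instead-of-store idea is the same one behind gradient checkpointing. As a rough illustration (a generic PyTorch sketch, not FlashAttention's fused kernel), `torch.utils.checkpoint` discards a wrapped function's intermediates during the forward pass and recomputes them during the backward pass:

```python
import torch
from torch.utils.checkpoint import checkpoint

def attention_block(Q, K, V):
    # Intermediates S and P would normally be stored for the backward pass.
    S = Q @ K.transpose(-1, -2)
    P = torch.softmax(S, dim=-1)
    return P @ V

N, d = 1024, 64
Q = torch.randn(N, d, requires_grad=True)
K = torch.randn(N, d, requires_grad=True)
V = torch.randn(N, d, requires_grad=True)

# With checkpointing, S and P are not kept in memory after the forward
# pass; they are recomputed when gradients are needed.
O = checkpoint(attention_block, Q, K, V, use_reentrant=False)
O.sum().backward()
```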

Memory in GPUs is hierarchical, with fast-access on-chip cache and slower-access off-chip HBM. To optimize IO speed, computations should access on-chip cache as much as possible while reducing off-chip HBM access. Refer to my previous video on GPU architecture for more information.

Earlier Attention improvements focused on reducing computation. FlashAttention focuses on reducing IO access and accelerating IO speed via on-chip cache. Its goal is to avoid Attention operations’ HBM read/writes by:

  1. Matrix partitioning and fusing all Attention operations without caching intermediate results in HBM.
  2. Recomputing intermediate results during backpropagation to mitigate the computational cost of not caching them.

These improvements have enhanced training speed 2-4 times and reduced memory usage from quadratic to linear growth with sequence length. At a sequence length of 4096, it saves 20 times the memory compared to PyTorch.

Consider the A100-40GB SXM, which has a compute bandwidth of 312 TFLOPS (FP16/BF16 tensor cores, as used in mixed-precision training) and a memory bandwidth of 1555 GB/s. The operational intensity (ridge point) for this particular hardware is

$$\frac{312 \times 10^{12}\ \text{FLOPs/s}}{1555 \times 10^{9}\ \text{B/s}} \approx 201\ \text{FLOPs/byte}.$$

Operations whose arithmetic intensity falls below this value are memory bound; those above it are compute bound.

Suppose we are to calculate attention for a given sequence length $N$ and head dimension $d$.

For more details refer to this blog. The bottleneck may vary for different values of $N$ and $d$.
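To make this concrete, here is a small back-of-the-envelope script (my own illustration, not from the paper) that estimates the arithmetic intensity of the individual, unfused attention operations. It assumes FP16 storage (2 bytes per element), counts a multiply-add as 2 FLOPs, and assumes every operand is read from HBM and every result written back:

```python
# Rough arithmetic-intensity estimates for unfused attention on an A100.
BYTES = 2                      # FP16: 2 bytes per element
RIDGE = 312e12 / 1555e9        # ~201 FLOPs per byte on A100-40GB SXM

def matmul_qk(N, d):
    flops = 2 * N * N * d                      # S = Q @ K^T
    traffic = BYTES * (2 * N * d + N * N)      # read Q, K; write S
    return flops / traffic

def softmax_rowwise(N):
    flops = 5 * N * N                          # max, subtract, exp, sum, divide (approx.)
    traffic = BYTES * (2 * N * N)              # read S; write P
    return flops / traffic

def matmul_pv(N, d):
    flops = 2 * N * N * d                      # O = P @ V
    traffic = BYTES * (N * N + 2 * N * d)      # read P, V; write O
    return flops / traffic

for N, d in [(1024, 64), (4096, 64), (4096, 128)]:
    print(N, d,
          round(matmul_qk(N, d)), round(softmax_rowwise(N)),
          round(matmul_pv(N, d)), "ridge:", round(RIDGE))
```

Under these assumptions the two matmuls land around $Nd/(N + 2d) \approx d$ FLOPs/byte and the softmax near 1, all well below the ~201 FLOPs/byte ridge point, which is why unfused attention at typical head dimensions is memory bound.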

Approach

  • Through blockwise (tiled) computation and kernel fusion, avoid writing intermediate results to HBM.
  • Recompute the required intermediate results during the backward pass.

hardware prerequisites

  • GPU SRAM (Static Random-Access Memory): 19 TB/s (20 MB)
  • GPU HBM (High Bandwidth Memory): 1.5 TB/s (40 GB)
  • CPU DRAM: 12.8 GB/s (> 1 TB)

Algorithm FlashAttention (forward pass)
Require: matrices $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, on-chip SRAM of size $M$.
  1. Set block sizes $B_c = \lceil M/(4d) \rceil$, $B_r = \min(\lceil M/(4d) \rceil, d)$.
  2. Initialize $O = (0)_{N \times d} \in \mathbb{R}^{N \times d}$, $\ell = (0)_N \in \mathbb{R}^N$, $m = (-\infty)_N \in \mathbb{R}^N$ in HBM.
  3. Divide $Q$ into $T_r = \lceil N/B_r \rceil$ blocks $Q_1, \dots, Q_{T_r}$ of size $B_r \times d$ each, and divide $K, V$ into $T_c = \lceil N/B_c \rceil$ blocks $K_1, \dots, K_{T_c}$ and $V_1, \dots, V_{T_c}$, of size $B_c \times d$ each.
  4. Divide $O$ into $T_r$ blocks $O_1, \dots, O_{T_r}$ of size $B_r \times d$ each, divide $\ell$ into $T_r$ blocks $\ell_1, \dots, \ell_{T_r}$ of size $B_r$ each, divide $m$ into $T_r$ blocks $m_1, \dots, m_{T_r}$ of size $B_r$ each.
  5. for $1 \le j \le T_c$ do
  6. Load $K_j, V_j$ from HBM to on-chip SRAM.
  7. for $1 \le i \le T_r$ do
  8. Load $Q_i, O_i, \ell_i, m_i$ from HBM to on-chip SRAM.
  9. On chip, compute $S_{ij} = Q_i K_j^\top \in \mathbb{R}^{B_r \times B_c}$.
  10. On chip, compute $\tilde{m}_{ij} = \mathrm{rowmax}(S_{ij}) \in \mathbb{R}^{B_r}$, $\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c}$ (pointwise), $\tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{P}_{ij}) \in \mathbb{R}^{B_r}$.
  11. On chip, compute $m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}$, $\ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\ell}_{ij} \in \mathbb{R}^{B_r}$.
  12. Write $O_i \leftarrow \mathrm{diag}(\ell_i^{\text{new}})^{-1} \big( \mathrm{diag}(\ell_i)\, e^{m_i - m_i^{\text{new}}} O_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{P}_{ij} V_j \big)$ to HBM.
  13. Write $\ell_i \leftarrow \ell_i^{\text{new}}$, $m_i \leftarrow m_i^{\text{new}}$ to HBM.
  14. end for
  15. end for
  16. Return $O$.
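To make the tiling and the running rescaling concrete, here is a minimal PyTorch sketch of the forward pass above, written for readability rather than speed (the real kernel runs these loops on chip in CUDA; block sizes and names here are illustrative):

```python
import torch

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Blockwise exact attention (FlashAttention-style forward pass, unscaled)."""
    N, d = Q.shape
    O = torch.zeros(N, d)                        # output accumulator (in "HBM")
    l = torch.zeros(N)                           # running softmax denominators
    m = torch.full((N,), float("-inf"))          # running row maxima

    for j in range(0, N, Bc):                    # outer loop over K/V blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):                # inner loop over Q/O blocks
            Qi = Q[i:i+Br]
            mi, li, Oi = m[i:i+Br], l[i:i+Br], O[i:i+Br]

            Sij = Qi @ Kj.T                      # (Br, Bc) block of scores
            m_tilde = Sij.max(dim=1).values      # blockwise row maxima
            P_tilde = torch.exp(Sij - m_tilde[:, None])
            l_tilde = P_tilde.sum(dim=1)

            m_new = torch.maximum(mi, m_tilde)   # merge running statistics
            l_new = torch.exp(mi - m_new) * li + torch.exp(m_tilde - m_new) * l_tilde

            # Rescale the previous partial output and add the new contribution.
            O[i:i+Br] = (torch.exp(mi - m_new)[:, None] * li[:, None] * Oi
                         + torch.exp(m_tilde - m_new)[:, None] * (P_tilde @ Vj)) / l_new[:, None]
            m[i:i+Br], l[i:i+Br] = m_new, l_new
    return O

# Sanity check against the unfused reference.
N, d = 256, 64
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
ref = torch.softmax(Q @ K.T, dim=-1) @ V
assert torch.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-4)
```

The assert at the end checks exactness against the unfused reference: the blockwise result matches standard attention up to floating-point rounding, and the full $N \times N$ score matrix is never materialized.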

Online Softmax

safe softmax

Vanilla softmax can overflow numerically when exponentiating large values, so FlashAttention uses the safe version of softmax, in which the maximum of the vector is subtracted before exponentiation: $\mathrm{softmax}(x)_i = \dfrac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}}$.

3-pass softmax

For a vector $x \in \mathbb{R}^N$, with $m_0 = -\infty$ and $d_0 = 0$:

For $i \leftarrow 1, N$: $\quad m_i = \max(m_{i-1}, x_i)$ (running maximum).

For $i \leftarrow 1, N$: $\quad d_i = d_{i-1} + e^{x_i - m_N}$ (denominator).

For $i \leftarrow 1, N$: $\quad a_i = \dfrac{e^{x_i - m_N}}{d_N}$ (normalized outputs).

Each pass reads the full vector $x$, so the naive safe softmax makes three passes over the data.

2-pass softmax

The first two passes can be fused into one (the "online softmax" trick) by keeping a running maximum and rescaling the running denominator whenever the maximum grows:

For $i \leftarrow 1, N$: $\quad m_i = \max(m_{i-1}, x_i), \quad d_i = d_{i-1}\, e^{m_{i-1} - m_i} + e^{x_i - m_i}$.

For $i \leftarrow 1, N$: $\quad a_i = \dfrac{e^{x_i - m_N}}{d_N}$.

FlashAttention pushes this one step further: the same rescaling is applied to the partial output $O$, so the softmax normalization is fused into the blockwise computation of $PV$.
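A small Python sketch of the 2-pass recurrence, written for this article (not library code) and checked against `torch.softmax`:

```python
import math
import torch

def online_softmax(x):
    # Pass 1: fused running maximum m and rescaled running denominator d.
    m, d = float("-inf"), 0.0
    for xi in x.tolist():
        m_new = max(m, xi)
        # Rescale the old partial sum if the maximum changed, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    # Pass 2: normalize every element with the final statistics.
    return torch.exp(x - m) / d

x = torch.randn(16) * 50            # values large enough to overflow a naive exp()
assert torch.allclose(online_softmax(x), torch.softmax(x, dim=0), atol=1e-6)
```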


We'll now explore how matrix partitioning and computation fusion reduce HBM access. We bypass Softmax initially because it is the tricky part of partitioned computation, and discuss it later. For now, assume $QK^\top$ directly yields the attention weights $P$.

The process starts by reading the first two rows of $Q$, the first three columns of $K^\top$, and the corresponding partition of $V$ from HBM to SRAM. Multiplying the $Q$ block by the $K^\top$ block yields a small block of attention weights, which is not stored in HBM; it is used immediately with the $V$ partition. The result is not the final first two rows of $O$, but an intermediate result that will be updated later.

Keeping the $K^\top$ and $V$ partitions in SRAM, load $Q$'s middle two rows from HBM and compute the intermediate results for $O$'s middle rows in the same way, then repeat for the final rows. Next, load the last columns of $K^\top$ and the last partition of $V$, reload the previous $Q$ rows, and update $O$ using the earlier intermediate results. Keep these $K^\top$, $V$ partitions in SRAM, load the next $Q$ chunk, compute, and update $O$ again, until all blocks are processed.

Through partitioning and fusion, we avoid storing the intermediate $S$ and $P$ in HBM, significantly reducing IO time. The exception is Softmax: it is a row-wise computation that needs the complete row for the maximum and the summation, so making the fusion work requires solving Softmax's partitioned computation.

Softmax computation in mixed-precision training (FP16) risks overflow when exponentiating large values; Safe Softmax fixes this by subtracting the largest value of the row from every term before exponentiation. Every exponent then becomes non-positive, which is representable in FP16.
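As a quick illustration (my own example, not from the paper): FP16 can represent values only up to 65504, so $e^{12} \approx 1.6 \times 10^{5}$ already overflows, while the safe version keeps every exponent at or below zero:

```python
import torch

FP16_MAX = torch.finfo(torch.float16).max      # 65504.0

x = torch.tensor([9.0, 12.0, 11.0])

naive = torch.exp(x)                            # exp(12) ~ 1.6e5, larger than FP16_MAX
print(naive > FP16_MAX)                         # the middle entry exceeds the FP16 range
print(naive.half())                             # and indeed becomes inf when cast to FP16

safe = torch.exp(x - x.max())                   # all exponents <= 0, so every term <= 1
print((safe / safe.sum()).half())               # well-defined FP16 probabilities

# Subtracting a constant does not change the result:
# softmax(x) == softmax(x - c) for any constant c.
```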

Safe Softmax involves finding the row maximum $m$, transforming each element into $e^{x_i - m}$, and summing these terms for normalization. With partitioned values, each block only knows its local maximum, which may not be the global one. The partial results are therefore adjusted with correction coefficients of the form $e^{m_{\text{block}} - m_{\text{new}}}$ whenever a larger maximum is encountered, so that every term ends up referenced to the same global maximum; summing the adjusted partial sums then gives the correct Softmax, as in the merge rule below.
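Concretely, if a row is split into two blocks $x^{(1)}$ and $x^{(2)}$ with local statistics $(m^{(1)}, \ell^{(1)})$ and $(m^{(2)}, \ell^{(2)})$, where $m^{(b)}$ is the block maximum and $\ell^{(b)} = \sum_i e^{x^{(b)}_i - m^{(b)}}$, the global statistics are recovered by rescaling each partial sum:

$$
m = \max\!\left(m^{(1)}, m^{(2)}\right), \qquad
\ell = e^{m^{(1)} - m}\,\ell^{(1)} + e^{m^{(2)} - m}\,\ell^{(2)} = \sum_i e^{x_i - m}.
$$

This is the same rescaling used in steps 10-11 of the algorithm above.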

Though this requires a couple of extra values per row (the running maximum $m$ and the running sum $\ell$), the storage overhead is minor ($O(N)$ rather than $O(N^2)$), and trading a little extra computation for much less IO is favorable.

Pseudo-code overview: $Q$, $K$, $V$ are stored in HBM, and the SRAM has size $M$. The column block size is $B_c = \lceil M/(4d) \rceil$; the factor 4 accounts for the $Q$, $K$, $V$, $O$ partitions that must fit on chip together. The row block size $B_r = \min(\lceil M/(4d) \rceil, d)$ additionally caps the $Q$ block at $d$ rows so the on-chip blocks stay within SRAM. Initialize $O$, $\ell$, $m$ in HBM, then partition $Q$, $K$, $V$ and $O$, $\ell$, $m$ into blocks as in the algorithm above.

The outer loop runs over $K$, $V$ blocks and the inner loop over $Q$ blocks, matching our animation. Load a $K_j$, $V_j$ partition into SRAM, then loop over the $Q_i$, $O_i$, $\ell_i$, $m_i$ partitions: compute the score block $S_{ij} = Q_i K_j^\top$, its row maxima and exponentials, update the running maximum and sum values, and use $\mathrm{diag}(\ell_i^{\text{new}})^{-1}$ to rescale and update $O_i$ in HBM.

The backward pass saves only Softmax's intermediate statistics ($m$ and $\ell$) so the attention matrix can be quickly recomputed block by block during backpropagation, akin to gradient checkpointing.

Tip

FlashAttention increases computation slightly but drastically reduces HBM IO, markedly decreasing training time. FlashAttention-2 follows the same principles with further optimizations: it reduces non-matmul computation, swaps the Q and K/V loops (the outer loop runs over Q), improves parallelism, and exploits the block structure to skip computation for fully masked upper-triangular blocks in causal attention, further reducing work.