Motivation

Transformers are slow and memory-hungry on long sequences because the time and memory complexity of self-attention is quadratic in the sequence length. The FlashAttention paper proposes an IO-aware, memory-efficient algorithm that computes exact attention.

Hardware prerequisites

  • GPU SRAM (Static Random-Access Memory): 19 TB/s (20 MB)
  • GPU HBM (High Bandwidth Memory): 1.5 TB/s (40 GB)
  • CPU DRAM: 12.8 GB/s (>1 TB)

Compute/memory bandwidths in attention calculation

Consider an A100-40GB SXM with mixed-precision training: it has a compute bandwidth of 312 TFLOPS and a memory bandwidth of 1555 GB/s. The operational intensity (ridge point) of this hardware is

$$\frac{312 \times 10^{12}\ \text{FLOPs/s}}{1555 \times 10^{9}\ \text{bytes/s}} \approx 201\ \text{FLOPs/byte}.$$

Suppose we are to calculate attention for $Q, K, V \in \mathbb{R}^{N \times d}$, where $N$ is the sequence length and $d$ is the head dimension.
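As a rough back-of-the-envelope accounting (assuming fp16 storage and counting only the two matmuls; this is an approximation, but it reproduces the table below): $QK^\top$ and $PV$ cost $4N^2d$ FLOPs, while the HBM traffic covers reading $Q$, $K$, $V$, reading and writing the $N \times N$ score matrix, and writing the output, i.e. $2\,(4Nd + 2N^2)$ bytes. The operational intensity of standard attention is then

$$\frac{4N^2d}{2\,(4Nd + 2N^2)} = \frac{Nd}{N + 2d}\ \text{FLOPs/byte},$$

which is memory-bound whenever it falls below the hardware's $\approx 201$ FLOPs/byte ridge point, and compute-bound otherwise.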

For more details refer to this wonderful blog. The bottleneck may vary for different values of $N$ and $d$.

N      d      operations/bytes    bottleneck
256    128     64                 memory-bound
1024   128    102                 memory-bound
4096   128    120                 memory-bound
256    256     85                 memory-bound
1024   256    171                 memory-bound
4096   256    228                 compute-bound
256    512    102                 memory-bound
1024   512    256                 compute-bound
4096   512    410                 compute-bound
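
A short Python sketch to reproduce the table, assuming the $Nd/(N + 2d)$ intensity model above and the A100 figures of 312 TFLOPS and 1555 GB/s:

```python
# Classify attention workloads as memory- or compute-bound on an A100-40GB SXM,
# assuming the rough Nd / (N + 2d) operational-intensity model described above.

PEAK_FLOPS = 312e12          # mixed-precision tensor-core throughput, FLOPs/s
PEAK_BANDWIDTH = 1555e9      # HBM bandwidth, bytes/s
RIDGE = PEAK_FLOPS / PEAK_BANDWIDTH   # ~201 FLOPs/byte

def attention_intensity(N: int, d: int) -> float:
    """FLOPs per byte of HBM traffic for standard (unfused) attention in fp16."""
    flops = 4 * N * N * d                       # QK^T and PV matmuls
    hbm_bytes = 2 * (4 * N * d + 2 * N * N)     # Q, K, V, O plus score-matrix read/write
    return flops / hbm_bytes

for d in (128, 256, 512):
    for N in (256, 1024, 4096):
        oi = attention_intensity(N, d)
        bound = "compute-bound" if oi > RIDGE else "memory-bound"
        print(f"N={N:<5} d={d:<4} {oi:4.0f} FLOPs/byte  {bound}")
```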

Online softmax

Safe softmax

Vanilla softmax can overflow or underflow when exponentiating large-magnitude logits, so FlashAttention uses a safe version of softmax in which the maximum of the vector is subtracted before exponentiation.
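
Written out, for a vector $x \in \mathbb{R}^N$,

$$m = \max_{j} x_j, \qquad \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}},$$

which equals the vanilla softmax because the $e^{-m}$ factors in the numerator and denominator cancel, while every exponent is now at most $0$.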

3-pass softmax

For $i = 1, \dots, N$ (pass 1: running maximum, with $m_0 = -\infty$),

$$m_i \leftarrow \max(m_{i-1},\, x_i)$$

For $i = 1, \dots, N$ (pass 2: normalization term, with $d_0 = 0$),

$$d_i \leftarrow d_{i-1} + e^{x_i - m_N}$$

For $i = 1, \dots, N$ (pass 3: final softmax values),

$$a_i \leftarrow \frac{e^{x_i - m_N}}{d_N}$$

Each pass reads $x$ once, so this formulation sweeps over the data three times.
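
A minimal NumPy sketch of the three passes (plain Python loops that mirror the recurrences above):

```python
import numpy as np

def softmax_3pass(x: np.ndarray) -> np.ndarray:
    """Safe softmax computed in three sequential passes over x."""
    N = x.shape[0]

    # Pass 1: running maximum.
    m = -np.inf
    for i in range(N):
        m = max(m, x[i])

    # Pass 2: normalization term, using the global maximum from pass 1.
    d = 0.0
    for i in range(N):
        d += np.exp(x[i] - m)

    # Pass 3: normalized outputs.
    a = np.empty(N)
    for i in range(N):
        a[i] = np.exp(x[i] - m) / d
    return a

print(softmax_3pass(np.array([2.0, 1000.0, -5.0, 999.0])))  # no overflow despite large logits
```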

2-pass softmax
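
A sketch of the standard online-softmax (2-pass) recurrence, assuming the usual formulation: the running maximum and the normalizer are computed in one fused pass by rescaling the partial sum whenever the maximum changes, with $m_0 = -\infty$ and $d_0 = 0$.

For $i = 1, \dots, N$,

$$m_i \leftarrow \max(m_{i-1}, x_i), \qquad d_i \leftarrow d_{i-1}\, e^{m_{i-1} - m_i} + e^{x_i - m_i}$$

For $i = 1, \dots, N$,

$$a_i \leftarrow \frac{e^{x_i - m_N}}{d_N}$$

This saves one sweep over $x$; FlashAttention applies the same rescaling trick block-wise so the attention output can be accumulated without materializing the full $N \times N$ score matrix in HBM. A matching NumPy sketch (an illustration of the recurrence, not the FlashAttention kernel):

```python
import numpy as np

def softmax_2pass(x: np.ndarray) -> np.ndarray:
    """Online (2-pass) safe softmax: running max and normalizer fused into one pass."""
    m, d = -np.inf, 0.0
    for xi in x:
        m_new = max(m, xi)
        # Rescale the partial normalizer to the new maximum, then add the new term.
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    # Second pass: normalize with the final maximum and normalizer.
    return np.exp(x - m) / d

print(softmax_2pass(np.array([2.0, 1000.0, -5.0, 999.0])))  # agrees with the 3-pass version
```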

Reference