Motivation
Transformers are slow and memory-hungry on long sequences because the time and memory complexity of self-attention is quadratic in the sequence length. The FlashAttention paper proposes an IO-aware, memory-efficient algorithm that computes exact attention.
Hardware prerequisites
- GPU SRAM (Static Random-Access Memory): 19 TB/s (20 MB)
- GPU HBM (High Bandwidth Memory): 1.5 TB/s (40 GB)
- CPU DRAM: 12.8 GB/s (>1 TB)
Compute/memory bandwidths in attention calculation
Consider an A100-40GB SXM with mixed-precision training: it has a compute bandwidth of 312 TFLOPS and a memory bandwidth of 1555 GB/s, so the operational intensity threshold for this hardware is $\frac{312 \times 10^{12}\ \text{FLOPs/s}}{1.555 \times 10^{12}\ \text{bytes/s}} \approx 201\ \text{FLOPs/byte}$.
Suppose we are to calculate attention $O = \mathrm{softmax}\!\left(QK^\top\right)V$ for $Q, K, V \in \mathbb{R}^{N \times d}$.
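One way to arrive at figures consistent with the table below (a sketch, assuming fp16/bf16 storage at 2 bytes per element, counting only the two matrix multiplications $QK^\top$ and $PV$, and assuming the $N \times N$ score matrix is read from and written to HBM):

$$\text{FLOPs} \approx 4N^2 d, \qquad \text{bytes} \approx 2\,(4Nd + 2N^2), \qquad \text{intensity} \approx \frac{4N^2 d}{2\,(4Nd + 2N^2)} = \frac{Nd}{N + 2d}.$$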
For more details refer to this wonderful blog. The bottleneck may vary for different values of $N$ and $d$:
$N$ | $d$ | operations/bytes | bottleneck
---|---|---|---
256 | 128 | 64 | memory-bound |
1024 | 128 | 102 | memory-bound |
4096 | 128 | 120 | memory-bound |
256 | 256 | 85 | memory-bound |
1024 | 256 | 171 | memory-bound |
4096 | 256 | 228 | compute-bound |
256 | 512 | 102 | memory-bound |
1024 | 512 | 256 | compute-bound |
4096 | 512 | 410 | compute-bound |
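A small script to reproduce the numbers above under the same assumptions ($4N^2d$ FLOPs, $2(4Nd + 2N^2)$ bytes of HBM traffic, and a ~201 FLOPs/byte hardware threshold); the byte-counting convention is an assumption for illustration, not taken from the paper:

```python
# Sketch: reproduce the operational-intensity table under the stated assumptions.
BYTES_PER_ELEM = 2                 # fp16 / bf16
HW_INTENSITY = 312e12 / 1555e9     # A100-40GB SXM: 312 TFLOPS / 1555 GB/s ~= 201 FLOPs/byte

def attention_intensity(n: int, d: int) -> float:
    flops = 4 * n * n * d                                     # QK^T and PV matmuls
    bytes_moved = BYTES_PER_ELEM * (4 * n * d + 2 * n * n)    # Q, K, V, O plus the N x N scores
    return flops / bytes_moved

for d in (128, 256, 512):
    for n in (256, 1024, 4096):
        intensity = attention_intensity(n, d)
        bound = "compute-bound" if intensity > HW_INTENSITY else "memory-bound"
        print(f"N={n:5d}  d={d:4d}  intensity={intensity:6.0f} ops/byte -> {bound}")
```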
Online softmax
Safe softmax
Vanilla softmax can be numerically unstable because large logits overflow in the exponential, hence FlashAttention uses a safe version of softmax where the maximum of the vector is subtracted before exponentiation.
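Concretely, subtracting $m = \max_j x_j$ in the exponent leaves the result unchanged while keeping every exponential bounded by 1:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}}, \qquad m = \max_j x_j.$$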
3-pass softmax
For an input vector $x \in \mathbb{R}^N$, safe softmax takes three passes over the data:

For $i = 1, \dots, N$, compute the running maximum: $m_i = \max(m_{i-1}, x_i)$, with $m_0 = -\infty$.

For $i = 1, \dots, N$, accumulate the normalizer: $d_i = d_{i-1} + e^{x_i - m_N}$, with $d_0 = 0$.

For $i = 1, \dots, N$, compute the output: $y_i = \dfrac{e^{x_i - m_N}}{d_N}$.
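A minimal sketch of the three passes in plain Python/NumPy (the names m, d, y mirror the equations above; this is illustrative only, not the fused kernel from the paper):

```python
import numpy as np

def safe_softmax_3pass(x: np.ndarray) -> np.ndarray:
    """Safe softmax over a 1-D vector using three separate passes."""
    n = x.shape[0]

    # Pass 1: running maximum m_N = max_i x_i
    m = -np.inf
    for i in range(n):
        m = max(m, x[i])

    # Pass 2: normalizer d_N = sum_i exp(x_i - m_N)
    d = 0.0
    for i in range(n):
        d += np.exp(x[i] - m)

    # Pass 3: outputs y_i = exp(x_i - m_N) / d_N
    y = np.empty(n, dtype=np.float64)
    for i in range(n):
        y[i] = np.exp(x[i] - m) / d
    return y

x = np.array([1000.0, 1001.0, 1002.0])  # large logits that would overflow naive softmax
print(safe_softmax_3pass(x))            # ~[0.0900, 0.2447, 0.6652]
```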