During inference, the model performs next-token prediction: it generates the next token conditioned on all previously generated tokens.
Taking a closer look at attention, one can notice that it only needs Q_T (the query for the last token), K, and V to perform the necessary calculation for the next token, since the output at position T is out_T = softmax(Q_T K^T / sqrt(dk)) V.
Without a KV-cache, every forward pass recomputes the Q, K, and V matrices for all previous tokens inside the self-attention block, which means a huge amount of repeated computation.
With a KV-cache, we only compute the query for the last token and append that token's key and value to the cache.
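As a concrete illustration, here is a minimal single-head decoding step with a KV-cache, written as a PyTorch sketch; the function and tensor names (decode_step, k_cache, v_cache, W_q, W_k, W_v) are assumptions made for this example, not taken from the text.

```python
import math
import torch

def decode_step(x_new, W_q, W_k, W_v, k_cache, v_cache):
    """One single-head attention step for the newest token.

    x_new:   (B, 1, d)   embedding of the last generated token
    k_cache: (B, T, dk)  keys of all previous tokens (may be empty)
    v_cache: (B, T, dk)  values of all previous tokens
    """
    q_new = x_new @ W_q                            # (B, 1, dk)  query only for the last token
    k_new = x_new @ W_k                            # (B, 1, dk)
    v_new = x_new @ W_v                            # (B, 1, dk)

    # Append the new key/value instead of recomputing K, V for all previous tokens.
    k_cache = torch.cat([k_cache, k_new], dim=1)   # (B, T+1, dk)
    v_cache = torch.cat([v_cache, v_new], dim=1)   # (B, T+1, dk)

    dk = q_new.shape[-1]
    scores = q_new @ k_cache.transpose(1, 2) / math.sqrt(dk)   # (B, 1, T+1)
    out = torch.softmax(scores, dim=-1) @ v_cache              # (B, 1, dk)
    return out, k_cache, v_cache
```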
Memory usage for the KV-cache is 2 × num_layers × n_heads × sequence_length × dk elements per sequence, where the factor of 2 accounts for storing both keys and values.
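A quick back-of-the-envelope check of that formula, using hypothetical model sizes (32 layers, 32 heads, dk = 128, 4096 tokens, fp16); the numbers are illustrative only:

```python
# Hypothetical configuration, chosen only to illustrate the formula above.
num_layers = 32
n_heads = 32
dk = 128             # per-head key/value dimension
seq_len = 4096
bytes_per_elem = 2   # fp16

# 2x accounts for storing both keys and values at every layer, head, and position.
kv_cache_bytes = 2 * num_layers * n_heads * seq_len * dk * bytes_per_elem
print(f"{kv_cache_bytes / 2**30:.1f} GiB per sequence")  # 2.0 GiB
```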
In the vanilla multi-head attention (MHA) setting, every attention head has its own query, key, and value vectors for each token.
In this case, memory usage grows quickly for long sequences, and incremental inference is often slow. Multi-query attention (MQA)
instead shares a single key head and a single value head across all attention heads.
During the calculation for block i, the shared K and V tensors are broadcast across the heads with
Tensor.expand before performing multi-head attention.
Once the output is computed, the memory for the KV-cache is freed before proceeding to the next block.
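A rough PyTorch sketch of the MQA step described above, assuming a (B, n_heads, T, dk) layout; the shapes and function name are my own choices for illustration:

```python
import math
import torch

def mqa_attention(q, k, v):
    """Multi-query attention: a single key/value head shared by all query heads.

    q: (B, n_heads, T, dk)
    k: (B, 1, T, dk)
    v: (B, 1, T, dk)
    """
    B, n_heads, T, dk = q.shape
    # Broadcast the shared key/value head across all query heads without copying memory.
    k = k.expand(B, n_heads, T, dk)
    v = v.expand(B, n_heads, T, dk)
    scores = q @ k.transpose(-2, -1) / math.sqrt(dk)   # (B, n_heads, T, T)
    return torch.softmax(scores, dim=-1) @ v           # (B, n_heads, T, dk)
```

With this layout, the KV-cache only needs to store one key and one value vector per token per layer, roughly an n_heads-fold reduction compared to MHA.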
Grouped-query attention (GQA) is an interpolation between MHA and MQA that aims to balance their trade-offs:
query heads are divided into groups, and the heads within each group share a single set of key and value vectors.
GQA thus sits between the MHA of the vanilla transformer and MQA, leveraging the advantages of both sides.
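Under the same assumed layout, here is a minimal GQA sketch in which each group of G = n_heads/n_kv_heads query heads shares one key/value head (names are illustrative):

```python
import math
import torch

def gqa_attention(q, k, v):
    """Grouped-query attention.

    q: (B, n_heads, T, dk)
    k: (B, n_kv_heads, T, dk)
    v: (B, n_kv_heads, T, dk)
    """
    B, n_heads, T, dk = q.shape
    G = n_heads // k.shape[1]             # query heads per key/value head
    # Repeat each key/value head G times so every query head has a matching K/V head.
    k = k.repeat_interleave(G, dim=1)     # (B, n_heads, T, dk)
    v = v.repeat_interleave(G, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(dk)
    return torch.softmax(scores, dim=-1) @ v
```

Setting n_kv_heads = 1 recovers MQA, while n_kv_heads = n_heads recovers standard MHA, which is exactly the interpolation described above.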
The following table is a direct comparison of the dimensions of the query, key, value, and output tensors for MHA, MQA, and GQA. B is the batch size, T is the sequence length, d is the embedding dimension, n_heads is the number of heads, dk = d/n_heads is the per-head query/key/value dimension, n_kv_heads is the number of key/value heads, and G = n_heads/n_kv_heads is the number of groups.