
Attention Mechanism

Self Attention

KV-Cache

At inference time, the model is essentially performing next-token prediction, i.e., it generates the next token conditioned on all previously generated tokens. Taking a closer look at attention, one can notice that it only needs $Q_{T}$, $K$ and $V$ in order to perform the necessary calculation for the next token.
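
Concretely, with $T$ tokens generated so far, the output for the newly generated token is just the standard scaled dot-product attention restricted to the last query row (restated here for reference):

$$\text{softmax}\!\left(\frac{Q_{T}K^{\top}}{\sqrt{d_{k}}}\right)V, \qquad Q_{T} \in \mathbb{R}^{1 \times d_{k}}, \quad K, V \in \mathbb{R}^{T \times d_{k}},$$

so the keys and values of earlier tokens can be cached and reused rather than recomputed.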

Without a KV-cache, every forward pass through the self-attention block recomputes the $Q$, $K$, $V$ matrices for all previous tokens, which involves a huge amount of repeated calculation. With a KV-cache, we only compute the query for the last token and append the corresponding key and value to the cache.
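
A minimal single-head sketch of one decoding step with a cache (the projection names and cache handling here are illustrative, not a specific library's API):

import torch

D, D_k = 512, 64                      # embedding and per-head dimensions
q_proj = torch.nn.Linear(D, D_k)      # toy single-head projections
k_proj = torch.nn.Linear(D, D_k)
v_proj = torch.nn.Linear(D, D_k)

k_cache = torch.empty(0, D_k)         # keys of all previous tokens
v_cache = torch.empty(0, D_k)         # values of all previous tokens

def decode_step(x_last, k_cache, v_cache):
    # x_last: embedding of the newly generated token, shape (1, D)
    q = q_proj(x_last)                                      # query for the last token only
    k_cache = torch.cat([k_cache, k_proj(x_last)], dim=0)   # append new key to the cache
    v_cache = torch.cat([v_cache, v_proj(x_last)], dim=0)   # append new value to the cache
    scores = q @ k_cache.T / D_k ** 0.5                     # (1, T)
    out = torch.softmax(scores, dim=-1) @ v_cache           # (1, D_k)
    return out, k_cache, v_cache

out, k_cache, v_cache = decode_step(torch.randn(1, D), k_cache, v_cache)

Each subsequent call appends one row to k_cache and v_cache, so only the new token's key and value are ever computed.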

Memory usage for the KV-cache is $2 \times \text{num\_layers} \times \text{sequence\_length} \times d_{k}$ elements per attention head, where the factor of 2 accounts for storing both keys and values.
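
For a rough sense of scale (hypothetical numbers, not tied to a specific model): with 32 layers, 32 heads, $d_{k} = 128$, a 4096-token sequence and fp16 storage (2 bytes per element), the cache for a single sequence takes

$$2 \times 32 \times 32 \times 4096 \times 128 \times 2\ \text{bytes} = 2\ \text{GiB},$$

which is the main motivation for variants such as MQA and GQA below that reduce the number of cached key/value heads.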



Multi Head Attention

Multi Query Attention

In the vanilla multi-head attention setting, each attention head has its own set of query, key and value projections. In this case, the KV-cache memory grows quickly for long sequences and incremental inference is often slow. Multi-query attention (MQA) instead shares a single set of keys and values across all attention heads.

import torch

B = 4 # batch size
T = 128 # sequence length
D = 512 # embedding dimension
H = 8 # number of heads

D_single = D // H # single head dimension

torch.manual_seed(47)
X = torch.randn(B, T, D)

Wq = torch.nn.Linear(D, D)
Wk = torch.nn.Linear(D, D_single)
Wv = torch.nn.Linear(D, D_single)

Q, K, V = Wq(X), Wk(X), Wv(X)
print("Q: ", Q.shape)
print("K: ", K.shape)
print("V: ", V.shape)

>>> Q: torch.Size([4, 128, 512])
>>> K: torch.Size([4, 128, 64])
>>> V: torch.Size([4, 128, 64])

During the calculation for block $i$, the $K$ and $V$ tensors are broadcast across the heads with Tensor.expand before doing multi-head attention. Once the output is computed, the memory for these expanded tensors is freed before proceeding to the next block.

Q_ = Q.view(B, T, H, D_single).transpose(1, 2)
K_ = K.unsqueeze(1).expand(B, H, T, D_single).transpose(2, 3)
V_ = V.unsqueeze(1).expand(B, H, T, D_single)

print("Q: ", Q_.shape)
print("K: ", K_.shape)
print("V: ", V_.shape)

The output is

>>> Q: torch.Size([4, 8, 128, 64])
>>> K: torch.Size([4, 8, 64, 128])
>>> V: torch.Size([4, 8, 128, 64])
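
To complete the computation, the attention output can be formed from these expanded tensors in the usual way (a minimal sketch of scaled dot-product attention, with no causal mask applied):

import math

scores = Q_ @ K_ / math.sqrt(D_single)        # (B, H, T, T) attention scores
weights = torch.softmax(scores, dim=-1)       # normalize over the key dimension
out = weights @ V_                            # (B, H, T, D_single)
out = out.transpose(1, 2).reshape(B, T, D)    # merge heads back into (B, T, D)
print("out: ", out.shape)

>>> out: torch.Size([4, 128, 512])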

Grouped Query Attention

Grouped-query attention (GQA) is an interpolation between MHA and MQA that aims to balance their trade-offs. The query heads are divided into groups, and all query heads within a group share a single key head and value head. With the number of key/value heads ranging from one (which recovers MQA) up to the number of query heads (which recovers MHA), GQA leverages advantages from both sides.


The following table is a direct comparison of the dimensions of the query, key, value and output tensors for MHA, MQA and GQA. $B$ is the batch size, $T$ is the sequence length, $d$ is the embedding dimension, $\text{n\_heads}$ is the number of query heads, $d_{k} = d/\text{n\_heads}$ is the per-head query/key/value dimension, $\text{n\_kv\_heads}$ is the number of key/value heads, and $G = \text{n\_heads}/\text{n\_kv\_heads}$ is the number of query heads per group.
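
Since GQA only differs from the MQA code above in the number of key/value heads, here is a minimal sketch (the choice of n_kv_heads = 2 and the use of repeat_interleave to share key/value heads within each group are illustrative assumptions; B, T, D, H, D_single and X are reused from the code above):

n_kv_heads = 2                      # number of key/value heads (assumed for illustration)
G = H // n_kv_heads                 # query heads per group

Wq_g = torch.nn.Linear(D, D)                        # one projection per query head
Wk_g = torch.nn.Linear(D, n_kv_heads * D_single)    # fewer key heads than query heads
Wv_g = torch.nn.Linear(D, n_kv_heads * D_single)

Qg = Wq_g(X).view(B, T, H, D_single).transpose(1, 2)            # (B, H, T, d_k)
Kg = Wk_g(X).view(B, T, n_kv_heads, D_single).transpose(1, 2)   # (B, n_kv_heads, T, d_k)
Vg = Wv_g(X).view(B, T, n_kv_heads, D_single).transpose(1, 2)

# each key/value head is shared by G consecutive query heads
Kg = Kg.repeat_interleave(G, dim=1)   # (B, H, T, d_k)
Vg = Vg.repeat_interleave(G, dim=1)

scores = Qg @ Kg.transpose(2, 3) / D_single ** 0.5
out = torch.softmax(scores, dim=-1) @ Vg              # (B, H, T, d_k)
print("out: ", out.shape)

>>> out: torch.Size([4, 8, 128, 64])

Setting n_kv_heads = 1 reduces this to MQA, while n_kv_heads = H recovers standard multi-head attention.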

Other variants of Attention

Linear Attention

