In the vanilla multi-head attention (MHA) setting, every attention head has its own query, key and value vectors for each token. As a result, the memory needed to cache keys and values grows quickly for long sequences. Incremental inference is also often slow, largely because that cache has to be re-read at every decoding step. Multi-query attention (MQA) keeps a separate query per head but shares a single set of keys and values across all attention heads.
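To make the memory argument concrete, here is a rough back-of-the-envelope sketch of the per-layer KV-cache size, reusing the dimensions from the example below (B = 4, D = 512, H = 8). The helper `kv_cache_bytes`, the fp16 element size (2 bytes) and the "long" sequence length of 4096 are assumptions chosen for illustration.

```python
# Hypothetical helper: bytes needed to cache keys and values for ONE layer.
# Assumes fp16 storage (2 bytes per element) by default.
def kv_cache_bytes(batch, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    # Factor 2 accounts for storing both keys and values.
    return 2 * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem


B, D, H = 4, 512, 8  # same batch, embedding size and head count as the example
D_single = D // H    # single head dimension (64)

for T in (128, 4096):  # 4096 is an assumed "long" sequence length
    mha = kv_cache_bytes(B, T, n_kv_heads=H, head_dim=D_single)  # K/V per head
    mqa = kv_cache_bytes(B, T, n_kv_heads=1, head_dim=D_single)  # one shared K/V
    print(f"T={T}: MHA {mha / 2**20:.2f} MiB, MQA {mqa / 2**20:.2f} MiB")
```

Under these assumptions the MHA cache is always H = 8 times larger than the MQA cache: 1 MiB vs 0.125 MiB per layer at T = 128, and 32 MiB vs 4 MiB at T = 4096. The absolute saving therefore grows linearly with sequence length.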

import torch
 
B = 4  # batch
T = 128  # sequence length
D = 512  # embeddings dimension
H = 8  # number of heads
 
D_single = D // H  # single head dimension
 
torch.manual_seed(47)
X = torch.randn(B, T, D)
 
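Wq = torch.nn.Linear(D, D)  # H query heads of size D_single each
Wk = torch.nn.Linear(D, D_single)  # a single key head shared by all query heads
Wv = torch.nn.Linear(D, D_single)  # a single value head shared by all query heads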
 
Q, K, V = Wq(X), Wk(X), Wv(X)
print("Q: ", Q.shape)
print("K: ", K.shape)
print("V: ", V.shape)
>>> Q: torch.Size([4, 128, 512])
>>> K: torch.Size([4, 128, 64])
>>> V: torch.Size([4, 128, 64])
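One consequence visible in these projections is that Wk and Wv are much smaller than in MHA, where they would map D to D. A quick parameter count, assuming each projection is a `torch.nn.Linear` with bias as in the code above (`linear_params` is a hypothetical helper):

```python
# Parameter count of torch.nn.Linear(in_features, out_features) with bias:
# a weight of shape (out, in) plus a bias of shape (out,).
def linear_params(in_features, out_features):
    return in_features * out_features + out_features


D, H = 512, 8
D_single = D // H

# MHA: Q, K and V each project to the full D dimensions (H heads x D_single).
mha_params = 3 * linear_params(D, D)
# MQA: only Q keeps H heads; K and V project to a single head of size D_single.
mqa_params = linear_params(D, D) + 2 * linear_params(D, D_single)

print("MHA projection params:", mha_params)
print("MQA projection params:", mqa_params)
```

This prints 787968 for MHA and 328320 for MQA. The query projection is unchanged, so the whole saving comes from the key and value projections.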

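During the computation in each block, the K and V vectors are broadcast to all H heads with Tensor.expand before the multi-head attention is computed. Because Tensor.expand returns a view rather than a copy, the broadcast itself does not allocate new memory. Once the block's output is computed, its intermediate tensors can be freed before proceeding to the next block; only the single-head K and V need to be kept in the KV-cache.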

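# Split Q into H heads: (B, H, T, D_single)
Q_ = Q.view(B, T, H, D_single).transpose(1, 2)
# Broadcast the single K head to all H heads, then transpose for Q_ @ K_: (B, H, D_single, T)
K_ = K.unsqueeze(1).expand(B, H, T, D_single).transpose(2, 3)
# Broadcast the single V head to all H heads: (B, H, T, D_single)
V_ = V.unsqueeze(1).expand(B, H, T, D_single)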
 
print("Q: ", Q_.shape)
print("K: ", K_.shape)
print("V: ", V_.shape)

The output is

>>> Q: torch.Size([4, 8, 128, 64])
>>> K: torch.Size([4, 8, 64, 128])
>>> V: torch.Size([4, 8, 128, 64])
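The shapes above stop just before the attention itself. A minimal sketch of the remaining steps, reusing the same setup, could look like the following. The causal mask and the output projection `Wo` are assumptions (typical for decoder-style models) and are not part of the original snippet.

```python
import math
import torch

torch.manual_seed(47)
B, T, D, H = 4, 128, 512, 8
D_single = D // H

X = torch.randn(B, T, D)
Wq = torch.nn.Linear(D, D)          # H query heads
Wk = torch.nn.Linear(D, D_single)   # one shared key head
Wv = torch.nn.Linear(D, D_single)   # one shared value head
Wo = torch.nn.Linear(D, D)          # output projection (assumed, not in the original)

Q, K, V = Wq(X), Wk(X), Wv(X)
Q_ = Q.view(B, T, H, D_single).transpose(1, 2)                 # (B, H, T, D_single)
K_ = K.unsqueeze(1).expand(B, H, T, D_single).transpose(2, 3)  # (B, H, D_single, T)
V_ = V.unsqueeze(1).expand(B, H, T, D_single)                  # (B, H, T, D_single)

# Scaled dot-product scores with a causal mask (assumed decoder-style attention).
scores = Q_ @ K_ / math.sqrt(D_single)                         # (B, H, T, T)
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))             # hide future tokens
weights = scores.softmax(dim=-1)

# Weighted sum of the shared values, then merge heads back to (B, T, D).
attn = (weights @ V_).transpose(1, 2).reshape(B, T, D)
out = Wo(attn)
print("out:", out.shape)
```

Since K_ and V_ are views, every query head reads the same single-head keys and values, yet the result has the same (B, T, D) shape an MHA layer would produce.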

As for memory usage, MQA may not give a noticeable saving over MHA when generating relatively short sequences; one can refer to this discussion or run further experiments.
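On the memory side, the broadcast step relies on Tensor.expand, which creates a view instead of copying. The sketch below uses random data with the same shapes as the example and checks this by inspecting strides and storage pointers, contrasting it with Tensor.repeat, which materializes H real copies. It only shows that the broadcast itself is free; it does not measure end-to-end attention memory, which also depends on what the subsequent matmul kernels allocate.

```python
import torch

B, T, H, D_single = 4, 128, 8, 64
K = torch.randn(B, T, D_single)  # single shared key tensor (random stand-in)

# expand() returns a view: the new head dimension gets stride 0,
# so all H "copies" point at the same underlying storage.
K_expanded = K.unsqueeze(1).expand(B, H, T, D_single)
print("expand strides:", K_expanded.stride())
print("expand shares storage:", K_expanded.data_ptr() == K.data_ptr())

# repeat() allocates a new buffer that is H times larger than K.
K_repeated = K.unsqueeze(1).repeat(1, H, 1, 1)
print("repeat shares storage:", K_repeated.data_ptr() == K.data_ptr())
```

Both tensors hold the same values, but only the repeated one pays the H-fold memory cost that MQA is designed to avoid.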