In the vanilla multi-head attention (MHA) setting, every attention head keeps its own set of query, key, and value vectors for each token. Memory usage therefore grows quickly for long sequences, and incremental inference is often slow. Multi-query attention (MQA) instead shares a single set of keys and values across all attention heads.
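To make the size difference concrete, here is a minimal sketch contrasting the two KV-cache layouts (the batch size, sequence length, head count, and head dimension below are illustrative assumptions, not the original setup):

```python
import torch

batch, seq_len, n_heads, head_dim = 1, 2048, 16, 64

# MHA: every head caches its own keys and values.
mha_kv_cache = torch.zeros(2, batch, n_heads, seq_len, head_dim)

# MQA: a single key/value head is shared by all query heads.
mqa_kv_cache = torch.zeros(2, batch, 1, seq_len, head_dim)

# The cache shrinks by a factor of n_heads.
print(mha_kv_cache.numel() / mqa_kv_cache.numel())  # 16.0
```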
During the computation of block $i$, the shared K and V tensors are broadcast across the attention heads with `Tensor.expand` before the usual multi-head attention is applied; since `expand` returns a view, no copies are materialized. Once the block's output is computed, the memory for its KV-cache is freed before proceeding to the next block.
The output of head $i$ is

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K^{\top}}{\sqrt{d_k}}\right) V,
$$

where $Q_i$ is the query projection of head $i$ and $K$, $V$ are the single key and value projections shared by all heads.
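A minimal sketch of such a block, assuming the shapes above (the function name `mqa_block` and the concrete dimensions are illustrative, not the post's actual code):

```python
import math
import torch
import torch.nn.functional as F

def mqa_block(q, k, v):
    """q: (batch, n_heads, seq, head_dim); k, v: (batch, 1, seq, head_dim),
    i.e. a single KV head shared by all query heads."""
    n_heads = q.shape[1]
    # expand creates a broadcast view across heads without copying memory.
    k = k.expand(-1, n_heads, -1, -1)
    v = v.expand(-1, n_heads, -1, -1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return F.softmax(scores, dim=-1) @ v

q = torch.randn(1, 16, 128, 64)
k = torch.randn(1, 1, 128, 64)
v = torch.randn(1, 1, 128, 64)
out = mqa_block(q, k, v)  # (1, 16, 128, 64)
```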
Regarding the memory usage of MQA versus MHA: when generating relatively short sequences there may not be a noticeable decrease, since the KV-cache is still small compared with the rest of the model's memory footprint; one can refer to this discussion or conduct further experiments.
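As a rough illustration, a hypothetical back-of-the-envelope calculation (the layer count, head configuration, and fp16 storage below are all assumptions): at short lengths the whole cache is small under either scheme, while at long lengths the per-head savings dominate.

```python
def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=32, head_dim=128,
                   bytes_per_elem=2):
    # Factor of 2 accounts for storing both keys and values.
    return 2 * n_layers * n_kv_heads * seq_len * head_dim * bytes_per_elem

for seq_len in (64, 32768):
    mha = kv_cache_bytes(seq_len, n_kv_heads=32)  # one KV head per query head
    mqa = kv_cache_bytes(seq_len, n_kv_heads=1)   # single shared KV head
    print(f"seq={seq_len}: MHA {mha / 2**20:.1f} MiB vs MQA {mqa / 2**20:.1f} MiB")
```

Under these assumed numbers, a 64-token generation caches 32 MiB with MHA and 1 MiB with MQA, both negligible next to multi-gigabyte weights, whereas at 32k tokens the gap (16 GiB vs 0.5 GiB) becomes decisive.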