Motivation

One of the major downsides of the transformer architecture is that its memory footprint and computation grow quadratically with sequence length.

At inference time, the model is essentially performing a next-token prediction task, i.e., it generates the next token conditioned on all previously generated tokens. Taking a closer look at the self-attention mechanism, one can notice that it only needs the query $q_t$ of the newest token, together with the keys $K$ and values $V$ of all tokens so far, to perform the necessary calculation for the next token.
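
Concretely, for the token at position $t$, standard scaled dot-product attention depends only on that token's query and on the keys and values of positions $1$ through $t$ (the notation $K_{\le t}$, $V_{\le t}$ below is just shorthand for those stacked matrices):

$$\text{Attention}(q_t, K_{\le t}, V_{\le t}) = \text{softmax}\!\left(\frac{q_t K_{\le t}^{\top}}{\sqrt{d_k}}\right) V_{\le t}$$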

Without a KV-cache, each time the model generates a new token, it recomputes the query, key, and value matrices for all previous tokens, even though the keys and values of earlier tokens never change, so a huge amount of calculation is repeated. With a KV-cache, we only compute the query, key, and value for the last token, attend with that single query against the cache, and append the new key and value to the KV-cache. A minimal sketch of this is shown below.
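
As a rough sketch of this idea, here is a minimal single-head, single-layer NumPy implementation of decoding with a KV-cache; the names (`decode_step`, `attention`, the weight matrices) are hypothetical, and a real implementation would be batched, multi-headed, and repeated per layer.

```python
import numpy as np

def attention(q, K, V):
    """Scaled dot-product attention for a single query vector."""
    d_k = K.shape[-1]
    scores = q @ K.T / np.sqrt(d_k)          # (t,)
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights /= weights.sum()
    return weights @ V                       # (d_k,)

def decode_step(x_t, W_q, W_k, W_v, kv_cache):
    """Attention output for the newest token only.

    x_t: hidden state of the newest token, shape (d_model,)
    kv_cache: dict with growing 'K' and 'V' arrays of shape (t, d_k)
    """
    q_t = x_t @ W_q                                   # query for the newest token only
    k_t = x_t @ W_k                                   # its key ...
    v_t = x_t @ W_v                                   # ... and value
    kv_cache["K"] = np.vstack([kv_cache["K"], k_t])   # append; never recompute old keys
    kv_cache["V"] = np.vstack([kv_cache["V"], v_t])   # append; never recompute old values
    return attention(q_t, kv_cache["K"], kv_cache["V"])

# Usage: simulate a few decode steps for one layer, one head (hypothetical sizes).
d_model, d_k = 16, 16
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
cache = {"K": np.empty((0, d_k)), "V": np.empty((0, d_k))}
for _ in range(4):
    x_t = rng.normal(size=(d_model,))
    out = decode_step(x_t, W_q, W_k, W_v, cache)
print(cache["K"].shape)  # (4, 16): one cached key per generated token
```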

Memory usage for the KV-cache is $2 \times \text{num\_layers} \times \text{sequence\_length} \times d_k$ elements per sequence, where the factor of 2 accounts for storing both keys and values.
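
For a sense of scale, plugging in numbers that resemble a GPT-2-small-sized model (these concrete figures are assumptions for illustration, not values from the text above):

```python
num_layers = 12          # assumed depth (GPT-2-small-like)
sequence_length = 1024   # assumed context length
d_k = 768                # assumed total key/value width per layer
bytes_per_element = 2    # assumed fp16 storage

elements = 2 * num_layers * sequence_length * d_k   # the factor 2 covers K and V
print(elements)                                     # 18874368 cached elements
print(elements * bytes_per_element / 2**20, "MiB")  # 36.0 MiB per sequence
```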

References