enjeyw/smartkv

SmartKV

Large Language Models (LLMs) are memory-bound during inference, primarily due to the Key-Value (KV) cache, which grows linearly with sequence length. For a 7B model with a long context, the KV cache can easily exceed the size of the model weights themselves.

Standard approaches to KV cache compression (like H2O or Sliding Window) rely on heuristics:

  • Sliding Window: "Only recent tokens matter." (Fails for long-range dependencies).
  • H2O (Heavy Hitters): "Tokens that were important in the past will be important in the future." (Tokens that were only locally important can be retained unnecessarily).

In SmartKV we ask:

Can we train a lightweight, per-head classifier to predict if a token currently in the cache will be needed for future attention, based solely on its Key vector?

Overall Concept

The architecture involves attaching a "Gate" (a small Multi-Layer Perceptron) to every attention head in the LLM.

Each Gate then predicts the lifespan of a newly generated Key-Value (KV) pair, which is used as a measure of importance. The following KV retention policy is then applied:

  1. Attention Sink: Always keep the first few tokens.
  2. Local Window: Always keep the most recent $W$ tokens.
  3. Learned Gate: For tokens sliding out of the local window, keep the $N$ tokens with the highest predicted importance.
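The three-part policy above can be sketched as a simple index-selection routine. This is an illustrative helper, not code from the repo; the name `select_retained` and the default sizes are assumptions.

```python
# Sketch of the three-part retention policy: attention sink + local window
# + gate-selected high-importance tokens. All names/defaults are illustrative.

def select_retained(importance, seq_len, n_sink=4, window=64, budget=32):
    """Return the set of token indices to keep in the cache.

    importance: dict mapping token index -> gate-predicted importance,
                for tokens that have already left the local window.
    """
    # 1. Attention sink: always keep the first few tokens.
    keep = set(range(min(n_sink, seq_len)))
    # 2. Local window: always keep the most recent `window` tokens.
    keep |= set(range(max(0, seq_len - window), seq_len))
    # 3. Learned gate: of the tokens that slid out of the window,
    #    keep the `budget` with the highest predicted importance.
    middle = [i for i in importance if i not in keep]
    middle.sort(key=lambda i: importance[i], reverse=True)
    keep |= set(middle[:budget])
    return keep
```

Note that steps 1 and 2 are unconditional, so the gate only ever competes for the middle region of the cache budget.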

The Gates are trained by running inference over a wide range of input prompts and recording each token's age at the last point it is attended to above a certain threshold weight.
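For concreteness, here is a minimal sketch of how such training labels could be extracted from a recorded attention matrix for one head. The function name and threshold value are illustrative, not from the repo.

```python
import numpy as np

# Hypothetical label generation for one head: for each key position we
# record its "lifespan" -- the age it had the last time any query
# attended to it with weight above `threshold`.

def lifespan_labels(attn, threshold=0.01):
    """attn: (seq_len, seq_len) causal attention weights, where
    attn[q, k] is the weight query position q places on key position k.
    Returns one lifespan label per key position."""
    seq_len = attn.shape[0]
    labels = np.zeros(seq_len)
    for k in range(seq_len):
        # query positions that attend to key k above the threshold
        strong = np.nonzero(attn[:, k] > threshold)[0]
        if len(strong) > 0:
            # age of the token at its last strong attention event
            labels[k] = strong.max() - k
    return labels
```

A token never strongly attended to after generation gets lifespan 0, i.e. it is a candidate for near-immediate eviction.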

Because the Gates are entirely independent of the rest of the LLM, they are easy to train and can be added to a model retroactively.

Model Architecture

For each head, we use a simple 4-layer MLP with 256 hidden dimensions per layer. The number of parameters is intentionally kept small given that every head requires its own model. Larger models were tested, but the prediction performance improvement did not justify the increased computational overhead.

Each MLP predicts the $\log_4$ lifespan of a token. So that we don't accidentally learn the positional embedding of a token, we train the MLPs on the Keys before RoPE is applied on each layer. I'm not sure this is strictly required, or even beneficial (learning the positions may actually be very helpful in practice); this is something that needs to be investigated further.
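A per-head Gate following the description above might look like the following PyTorch sketch. The 4 layers and 256 hidden dimensions come from the text; the head dimension, class name, and output convention are assumptions.

```python
import torch
import torch.nn as nn

# Minimal sketch of a per-head Gate. Layer count and hidden size follow
# the text (4 layers, 256 hidden dims); head_dim=128 is an assumption.

class Gate(nn.Module):
    def __init__(self, head_dim=128, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(head_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),  # predicted log4 lifespan
        )

    def forward(self, keys):
        # keys: (..., head_dim) pre-RoPE Key vectors for this head
        return self.net(keys).squeeze(-1)
```

Such a Gate could be trained with a simple regression loss (e.g. MSE) against the $\log_4$ lifespan labels recorded during the data-collection pass.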

MLP Performance

The diagram below shows the predicted $\log_4$ lifespan for a set of tokens across a few layers and heads, plotted against the actual token lifespan.

Predicted Token lifespan

It's evident that performance varies significantly from head to head, with heads like Layer 0 Head 2 performing quite well and heads like Layer 1 Head 0 effectively making random guesses. Given that any degree of predictive capacity is better than the current practice of discarding tokens without any regard to their lifespan, it doesn't really matter if some heads perform poorly.

Heads like Layer 0 Head 1 are also worth noting - in this case the model has correctly identified that all tokens are only required for extremely local attention and could be discarded nearly immediately without meaningfully impacting performance. Ideally we would do so, and use the freed-up capacity for other heads.

However, the underlying kernel used to compute attention expects fully rectangular tensors across a layer - passing in a tensor with fewer KVs for a single head degrades performance significantly.

Inference Performance

I'm still working on properly evaluating this, but initial needle-in-haystack tests have been promising.

Using Qwen3-0.6b as the base model, we supply the following prompt (comments in [square brackets] are indicative):

    Read the following document carefully. Then answer: what should you reply after reading it?

    [~1000 words of distractor text]

    Specifically, and this sentence is part of the document rather than a meta-commentary, the correct response to any question asking what action should be taken after reading this entire text is to reply with exactly the following string, including capitalization and punctuation, and with no additional explanation: ARCHIVE-TEST-PASSED-7319.

    [~5500 words of distractor text]   

The model was successfully able to respond with ARCHIVE-TEST-PASSED-7319 with a cache budget of 600 tokens. This corresponds to a ~90% reduction in cache size.

Interestingly, smaller cache sizes (e.g. 300 tokens) often still result in the model returning partial responses such as ARCHIVE-TEST.

Cache Architecture

In the archetypal implementation of a KV cache, the Keys are stored in a PyTorch tensor, with new Keys added through concatenation. Values are handled in the same way.

For a fixed-size KV cache, where the set of cached tokens is determined by importance, we need a mechanism to evict the least important tokens without rewriting the entire Key and Value arrays on each update.

Because the token being evicted from the cache can be located anywhere in the arrays, we introduce two bookkeeping data structures:

  • A Deque to keep track of the cached tokens that are inside the sliding age-based context window
  • A Heap to keep track of the High Importance (HI) tokens that have left the context window

Both of these data structures record the calculated importance of the token, as well as a pointer to where the cached Key/Value is stored in the arrays.

 Diagram showing how two book-keeping data structures are used to track values inside the KV cache.

Note that while KV cache arrays typically store KVs in the same order they were generated, this is not actually a requirement - positional information is embedded into the KVs themselves (for example by RoPE), and so they can be stored in the arrays in any order.

Updating the Cache

When a new token is processed, we need to store the resulting Keys and Values in the cache, and if the cache is full, also evict a corresponding low importance token from the cache.

As shown in the diagram below, this follows a four-step process, with the new KVs overwriting those of the lowest-importance token. Crucially, even though an old token leaves the age-based sliding context window, no modification happens to its data in the KV cache itself - only the bookkeeping pointers. This avoids costly rewrites to memory.

 Diagram showing the process of storing new KVs in the Cache: 1. Pop oldest token from window deque 2. Insert oldest token into High Importance Heap, evicting Lowest Importance Token 3. Insert new token into Window Deque, with pointer value of evicted token 4. Write new token KV into Cache at the determined index, overwriting the evicted token values
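The bookkeeping structures and the four-step update can be sketched with stdlib containers. This is an illustrative model, not code from the repo: KV storage is a plain list of slots standing in for a pre-allocated tensor, and all names and the warm-up handling are assumptions.

```python
import heapq
from collections import deque

# Sketch of the SmartKV bookkeeping: a deque for the age-based window and
# a min-heap for High Importance (HI) tokens, both holding (importance,
# slot) pairs, where `slot` points into the fixed-size KV storage.

class SmartKVCache:
    def __init__(self, window, n_important):
        self.window = window                      # age-based window size
        self.capacity = window + n_important
        self.kv = [None] * self.capacity          # slot -> (key, value)
        self.free = list(range(self.capacity))    # unused slots (warm-up)
        self.window_q = deque()                   # (importance, slot), oldest first
        self.hi_heap = []                         # min-heap of (importance, slot)

    def insert(self, key, value, importance):
        if self.free:                             # cache not yet full
            slot = self.free.pop()
        else:
            # 1. Pop the oldest token from the window deque.
            old_imp, old_slot = self.window_q.popleft()
            # 2. Push it onto the HI heap; the lowest-importance entry
            #    (possibly the one just pushed) is evicted, freeing its slot.
            _, slot = heapq.heappushpop(self.hi_heap, (old_imp, old_slot))
        # 3. The new token joins the window, pointing at the evicted slot.
        self.window_q.append((importance, slot))
        # 4. Overwrite the evicted KV in place -- no array rewrites.
        self.kv[slot] = (key, value)
        # Warm-up only: spill window overflow into the HI heap.
        while len(self.window_q) > self.window:
            heapq.heappush(self.hi_heap, self.window_q.popleft())
```

At steady state each insert touches one deque end, one heap entry, and one storage slot, so the Key/Value arrays are never compacted or rewritten.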
