1.2: The KV Cache — How It Eliminates Redundancy


The Core Idea

In Article 1, we established that K and V values for position i are mathematically guaranteed to be identical across generation steps (due to frozen weights + causal masking). The KV cache is the natural solution: store K and V after computing them once, and read from storage instead of recomputing.

This is conceptually simple, but the details matter. Let's trace through exactly what gets stored, how the forward pass changes, and what trade-offs emerge.


What Exactly Gets Stored?

For each attention layer in the model, we store two matrices:

K_cache: all Key vectors for positions 0, 1, 2, ..., n-1
V_cache: all Value vectors for positions 0, 1, 2, ..., n-1

Important: We store K and V, but not Q. We'll explain why shortly.

The Shape of the Cache

For a single attention head in a single layer:

K_cache shape: [current_sequence_length, head_dim]
V_cache shape: [current_sequence_length, head_dim]

For the full model:

Full KV cache shape: [num_layers, 2, current_sequence_length, num_heads, head_dim]
                      ↑          ↑   ↑                        ↑          ↑
                      │          │   │                        │          └─ dimension per head (e.g., 128)
                      │          │   │                        └─ attention heads (e.g., 32)
                      │          │   └─ grows during generation
                      │          └─ one for K, one for V
                      └─ each layer has its own cache
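
If it helps to see that five-dimensional layout as an actual array, here's a tiny NumPy sketch (the variable names are just for illustration, not any framework's API):

import numpy as np

num_layers, num_heads, head_dim = 32, 32, 128     # LLaMA-7B-style numbers
seq_len = 3                                       # a snapshot after 3 tokens
kv_cache = np.zeros((num_layers, 2, seq_len, num_heads, head_dim), dtype=np.float16)
print(kv_cache.shape)   # (32, 2, 3, 32, 128); the seq_len axis grows by one per decode step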

Concrete Example: LLaMA-7B

Let's put real numbers to this. LLaMA-7B has:

  • 32 layers
  • 32 attention heads
  • head_dim = 128
  • Using float16 (2 bytes per value)

For a sequence of length seq_len, the KV cache size is:

Cache size = num_layers × 2 × seq_len × num_heads × head_dim × bytes_per_value
           = 32 × 2 × seq_len × 32 × 128 × 2 bytes
           = 524,288 × seq_len bytes
           = 0.5 MB × seq_len

So for a 2048-token sequence: ~1 GB of KV cache for a single request.
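
If you want to check these numbers yourself, a few lines of plain Python reproduce the arithmetic (the helper `kv_cache_bytes` is just a name used here, not a library function):

def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, bytes_per_value=2):
    # the factor of 2 accounts for storing both K and V
    return 2 * num_layers * seq_len * num_heads * head_dim * bytes_per_value

print(kv_cache_bytes(32, 32, 128, seq_len=1))             # 524288 bytes ≈ 0.5 MB per token
print(kv_cache_bytes(32, 32, 128, seq_len=2048) / 2**30)  # 1.0 GiB for a 2048-token sequence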

We'll return to these memory implications later. First, let's see how the cache actually gets used.


How the Forward Pass Changes

Let me show you the forward pass without and with the KV cache, so you can see exactly what changes.

Without KV Cache (Naive)

At each generation step, with current sequence [t_0, t_1, ..., t_{n-1}]:

def forward_naive(all_tokens):
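    # (Pseudocode: layer norms, residual connections, multi-head reshaping, and the
    #  attention output projection are omitted, here and in the cached version below,
    #  to keep the shapes easy to follow.)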
    hidden_states = embed(all_tokens)  # Shape: [n, hidden_dim]
    
    for layer in transformer_layers:
        # Compute Q, K, V for ALL tokens
        Q = hidden_states @ W_Q  # Shape: [n, head_dim]
        K = hidden_states @ W_K  # Shape: [n, head_dim]
        V = hidden_states @ W_V  # Shape: [n, head_dim]
        
        # Full attention computation
        attention_scores = Q @ K.T / sqrt(d)  # Shape: [n, n]
        attention_scores = apply_causal_mask(attention_scores)
        attention_weights = softmax(attention_scores)
        attention_output = attention_weights @ V  # Shape: [n, head_dim]
        
        hidden_states = layer.ffn(attention_output)
    
    # Only need last position's logits for next token prediction
    return hidden_states[-1] @ W_output

The waste: We compute K and V for all n tokens, but K₀ through K_{n-2} and V₀ through V_{n-2} are identical to what we computed in the previous step.

With KV Cache

def forward_with_cache(new_token, kv_cache):
    """
    new_token: the single token we just generated (or last token of prompt)
    kv_cache: dictionary storing K and V for all previous positions, per layer
    """
    hidden_states = embed(new_token)  # Shape: [1, hidden_dim] — just ONE token!
    
    for layer_idx, layer in enumerate(transformer_layers):
        # Compute Q, K, V for ONLY the new token
        Q_new = hidden_states @ W_Q  # Shape: [1, head_dim]
        K_new = hidden_states @ W_K  # Shape: [1, head_dim]
        V_new = hidden_states @ W_V  # Shape: [1, head_dim]
        
        # Retrieve cached K and V from all previous positions
        K_cached = kv_cache[layer_idx]['K']  # Shape: [n-1, head_dim]
        V_cached = kv_cache[layer_idx]['V']  # Shape: [n-1, head_dim]
        
        # Concatenate: full K and V including new token
        K_full = concat([K_cached, K_new])  # Shape: [n, head_dim]
        V_full = concat([V_cached, V_new])  # Shape: [n, head_dim]
        
        # Update cache for next generation step
        kv_cache[layer_idx]['K'] = K_full
        kv_cache[layer_idx]['V'] = V_full
        
        # Attention: Q_new attends to ALL positions (full K and V)
        attention_scores = Q_new @ K_full.T / sqrt(d)  # Shape: [1, n]
        # No causal mask needed — Q_new is the last position, 
        # and it's allowed to attend to all previous positions
        attention_weights = softmax(attention_scores)  # Shape: [1, n]
        attention_output = attention_weights @ V_full  # Shape: [1, head_dim]
        
        hidden_states = layer.ffn(attention_output)
    
    return hidden_states @ W_output  # Logits for next token

The Key Differences

Aspect                    Without Cache      With Cache
Tokens processed          All n tokens       Only 1 new token
K, V computed             n vectors each     1 vector each
K, V read from memory     None               n-1 vectors each (from cache)
Attention matrix size     [n, n]             [1, n]
Q vectors computed        n vectors          1 vector

The transformation: We've traded recomputation for memory storage and retrieval.
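
If you want to convince yourself that the two paths are numerically identical, here's a self-contained NumPy sketch (the `attend` helper and the toy dimensions are invented for illustration). It computes the last position's attention output twice: once by recomputing K and V for every position, and once by reading the first four from a cache.

import numpy as np

rng = np.random.default_rng(0)
d = 8                                        # toy head_dim
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))

def attend(Q, K, V):
    # softmax(Q Kᵀ / √d) V for the rows of Q, with no causal mask
    scores = Q @ K.T / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

hidden = rng.standard_normal((5, d))         # hidden states for positions 0..4

# Naive: recompute K and V for all 5 positions, keep only the last row's output
# (the last row attends to every position anyway, so it needs no mask)
naive_out = attend(hidden @ W_Q, hidden @ W_K, hidden @ W_V)[-1]

# Cached: positions 0..3 were stored on earlier steps; only position 4 is computed now
K_cache, V_cache = hidden[:4] @ W_K, hidden[:4] @ W_V
K_full = np.vstack([K_cache, hidden[4:] @ W_K])
V_full = np.vstack([V_cache, hidden[4:] @ W_V])
cached_out = attend(hidden[4:] @ W_Q, K_full, V_full)[0]

print(np.allclose(naive_out, cached_out))    # True: same output, no recomputation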


Visualizing Cache Growth

Let's trace through generation step by step, showing how the cache evolves.

Initial State: User prompt is ["The", "cat", "sat"]

┌─────────────────────────────────────────────────────────────────────────┐
│  STEP 0: Process prompt (this is the "prefill" phase — more later)     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Input: ["The", "cat", "sat"]  (3 tokens processed in parallel)        │
│                                                                         │
│  Compute:  K₀, K₁, K₂  and  V₀, V₁, V₂                                 │
│                                                                         │
│  KV Cache after this step:                                              │
│  ┌─────────────────────────┐                                           │
│  │ K_cache: [K₀, K₁, K₂]   │  (3 key vectors)                          │
│  │ V_cache: [V₀, V₁, V₂]   │  (3 value vectors)                        │
│  └─────────────────────────┘                                           │
│                                                                         │
│  Output: sample from logits → "on"                                      │
└─────────────────────────────────────────────────────────────────────────┘
                                    │
                                    ▼
┌─────────────────────────────────────────────────────────────────────────┐
│  STEP 1: Generate "on"  (this is the "decode" phase)                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Input: ["on"]  (only the NEW token)                                   │
│                                                                         │
│  Compute:  K₃ and V₃  (only for new token)                             │
│  Read:     K₀, K₁, K₂ and V₀, V₁, V₂  (from cache)                     │
│                                                                         │
│  Attention: Q₃ attends to [K₀, K₁, K₂, K₃]                             │
│                                                                         │
│  KV Cache after this step:                                              │
│  ┌──────────────────────────────┐                                      │
│  │ K_cache: [K₀, K₁, K₂, K₃]   │  (4 key vectors)                      │
│  │ V_cache: [V₀, V₁, V₂, V₃]   │  (4 value vectors)                    │
│  └──────────────────────────────┘                                      │
│                                                                         │
│  Output: sample from logits → "the"                                     │
└─────────────────────────────────────────────────────────────────────────┘
                                    │
                                    ▼
┌─────────────────────────────────────────────────────────────────────────┐
│  STEP 2: Generate "the"                                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Input: ["the"]  (only the NEW token)                                  │
│                                                                         │
│  Compute:  K₄ and V₄  (only for new token)                             │
│  Read:     K₀, K₁, K₂, K₃ and V₀, V₁, V₂, V₃  (from cache)             │
│                                                                         │
│  Attention: Q₄ attends to [K₀, K₁, K₂, K₃, K₄]                         │
│                                                                         │
│  KV Cache after this step:                                              │
│  ┌───────────────────────────────────┐                                 │
│  │ K_cache: [K₀, K₁, K₂, K₃, K₄]    │  (5 key vectors)                 │
│  │ V_cache: [V₀, V₁, V₂, V₃, V₄]    │  (5 value vectors)               │
│  └───────────────────────────────────┘                                 │
│                                                                         │
│  Output: sample from logits → "mat"                                     │
└─────────────────────────────────────────────────────────────────────────┘
                                    │
                                    ▼
                               ... and so on

The pattern: Each decode step computes K and V for exactly one token, reads the full cache, and appends to the cache. The cache grows by one token's worth of storage per generation step.
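
As a toy illustration of that growth pattern, here's how a single layer's cache might be tracked (the zero vectors are stand-ins for real K and V values; `append_to_cache` is a made-up helper):

import numpy as np

d = 8                                           # toy head_dim
kv_cache = {"K": np.empty((0, d)), "V": np.empty((0, d))}

def append_to_cache(k_new, v_new):
    kv_cache["K"] = np.vstack([kv_cache["K"], k_new])
    kv_cache["V"] = np.vstack([kv_cache["V"], v_new])

# Prefill: K, V for all 3 prompt tokens enter the cache at once
append_to_cache(np.zeros((3, d)), np.zeros((3, d)))
print(kv_cache["K"].shape)                      # (3, 8)

# Decode: each step appends exactly one new position
for _ in range(2):
    append_to_cache(np.zeros((1, d)), np.zeros((1, d)))
    print(kv_cache["K"].shape)                  # (4, 8), then (5, 8)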


Why Don't We Cache Q?

You might wonder: if we cache K and V, why not cache Q as well?

The answer lies in how attention works during generation.

During the decode phase, we only compute attention for the new token. This new token needs to attend to all previous tokens (using their K and V), but no previous token needs to attend to the new token (causal masking prohibits this).

Let's be precise:

K and V are used by future tokens:

  • K₀ and V₀ are used when token 1 attends to token 0
  • K₀ and V₀ are used again when token 2 attends to token 0
  • K₀ and V₀ are used again when token 3 attends to token 0
  • ... K₀ and V₀ are used at every future generation step

Q is only used by its own token:

  • Q₀ is used for token 0's attention computation, then never again
  • Q₁ is used for token 1's attention computation, then never again
  • When we generate token 3, we need Q₃ to compute attention scores against K₀, K₁, K₂, K₃

The Query vector for a position is used exactly once (when that position computes its attention output), then discarded. There's no benefit to caching it.

To make this concrete with the attention formula:

For position n (the token we're generating):

attention_output_n = softmax(Q_n @ [K₀, K₁, ..., K_n]ᵀ / √d) @ [V₀, V₁, ..., V_n]
                            ↑      ↑                           ↑
                            │      └── Need all previous K     └── Need all previous V
                            │          (read from cache)           (read from cache)
                            │
                            └── Only need Q for current position
                                (computed fresh, used once, discarded)
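
A quick way to see the asymmetry is to count how often each position's Q and K are touched over a short generation run. This is pure bookkeeping, no real model involved:

from collections import Counter

q_uses, k_uses = Counter(), Counter()
num_steps = 5
for step in range(num_steps):        # at each step, position `step` holds the newest token
    q_uses[step] += 1                # only the newest position's Q is ever needed
    for past in range(step + 1):
        k_uses[past] += 1            # every cached K (and V) is read again

print(dict(q_uses))   # {0: 1, 1: 1, 2: 1, 3: 1, 4: 1}  each Q is used exactly once
print(dict(k_uses))   # {0: 5, 1: 4, 2: 3, 3: 2, 4: 1}  older K/V are reused at every later step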

The Compute-Memory Tradeoff

The KV cache eliminates the O(g²) redundant computation problem, but introduces a memory cost. Let's quantify the tradeoff.

What We Saved (Compute)

Recall from Article 1, for prompt length p and generated tokens g:

Approach         K,V Computations
Naive            g×p + g²/2
With KV Cache    p + g

For p=500, g=200: we reduced from ~120,000 computations to 700 — a 171× reduction.
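
You can check that arithmetic in a couple of lines (nothing model-specific here):

p, g = 500, 200                               # prompt length, tokens to generate
naive  = sum(p + i for i in range(g))         # step i recomputes K,V for all p + i positions
cached = p + g                                # with the cache, each position is computed once
print(naive, cached, round(naive / cached))   # 119900 700 171  (matches g×p + g²/2 ≈ 120,000)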

What We Pay (Memory)

The KV cache for a single request:

Cache size = 2 × num_layers × seq_len × num_heads × head_dim × bytes_per_value

For common models (assuming float16):

Model         Layers   Heads   head_dim   Cache per token   Cache for 2K tokens
LLaMA-7B        32       32       128          0.5 MB              1.0 GB
LLaMA-13B       40       40       128          0.8 MB              1.6 GB
LLaMA-70B       80       64       128          2.5 MB              5.0 GB
GPT-3 175B      96       96       128          4.5 MB              9.0 GB

This is per request. If you're serving 100 concurrent users with 2K context each, you need 100× this memory just for KV cache.
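
As a rough back-of-the-envelope check (using the same simple formula as above):

per_request = 2 * 32 * 2048 * 32 * 128 * 2   # LLaMA-7B, 2K tokens, float16: ~1 GiB
print(100 * per_request / 2**30)             # 100.0 GiB of KV cache for 100 concurrent requests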

The Nature of the Tradeoff

┌─────────────────────────────────────────────────────────────────┐
│                     THE FUNDAMENTAL TRADEOFF                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│    Without KV Cache              With KV Cache                  │
│    ─────────────────             ─────────────────              │
│    • Compute: O(g²)              • Compute: O(g)                │
│    • Memory:  O(1)               • Memory:  O(seq_len)          │
│      (no cache needed)             (cache grows linearly)       │
│                                                                 │
│    Catastrophically slow         Fast generation, but           │
│    for long sequences            memory-hungry                  │
│                                                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│    This tradeoff is ALWAYS worth it in practice.                │
│    The compute savings dwarf the memory cost.                   │
│                                                                 │
│    But the memory cost creates new challenges:                  │
│    • Limits batch size (concurrent requests)                    │
│    • Limits context length                                      │
│    • Creates memory bandwidth bottleneck                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Summary: The KV Cache in One Picture

                        GENERATION WITH KV CACHE
                        
    ┌──────────────────────────────────────────────────────────────┐
    │                                                              │
    │   PREFILL (Process prompt)          DECODE (Generate tokens) │
    │   ════════════════════════          ════════════════════════ │
    │                                                              │
    │   Input: [t₀, t₁, t₂, ...]          Input: [t_new]           │
    │          (all prompt tokens)               (one token)       │
    │                                                              │
    │   Compute: K,V for ALL tokens       Compute: K,V for 1 token │
    │                                                              │
    │   Store: K,V → cache                Read: K,V from cache     │
    │                                     Append: new K,V to cache │
    │                                                              │
    │   Output: first generated token     Output: next token       │
    │                                                              │
    │   ┌─────────────────────┐          Repeat until done         │
    │   │ Cache initialized   │────────────────────────────────────│
    │   │ with prompt K,V     │                                    │
    │   └─────────────────────┘                                    │
    │                                                              │
    └──────────────────────────────────────────────────────────────┘

This diagram introduces the terms prefill and decode — the two phases of inference. The KV cache is what connects them: prefill builds the cache, decode uses and extends it.


What's Next

You now understand:

  1. What the KV cache stores (K and V matrices for all positions, per layer)
  2. How it eliminates redundant computation (read instead of recompute)
  3. How it grows during generation (one token's worth per step)
  4. The compute-memory tradeoff it creates

In Article 3, we'll formally define the two phases (prefill and decode) and understand their fundamentally different computational characteristics. The KV cache is why these phases exist and why they behave so differently.


Check Your Understanding

Before moving on:

  1. If you're generating with a 1000-token prompt and want to produce 500 tokens, how many K vectors get computed in total? How many are computed during prefill vs. decode?

  2. Why do we cache K and V but not Q? What's different about how they're used?

  3. For LLaMA-7B (0.5 MB cache per token), what's the KV cache size after processing a 1000-token prompt and generating 500 tokens?
