Attention Is All You Need to Implement
Part 1 of 4: Scaled Dot-Product & Multi-Head Attention
TL;DR: Attention is differentiable retrieval — every token computes a weighted combination of all other tokens, with weights learned from the data. This article derives scaled dot-product attention from first principles (including the variance proof for why we scale by 1/√d_k), builds multi-head attention with explicit shape annotations at every step, implements a causal mask for autoregressive decoding, and adds a KV-cache for efficient inference. Full tested implementation at rlvr-from-scratch.
Prerequisites: Basic PyTorch (tensors, nn.Module, nn.Linear). Linear algebra (matrix multiplication, transpose).
The Problem With Sequences
If you’re building a sequence model and your sequences are long, you have a latency problem.
Recurrent networks process tokens one at a time. Information from the first token has to travel through every intermediate hidden state to reach the last token — O(n) sequential operations. Double the sequence length, double the latency. For a 4,096-token context, that’s 4,096 serial steps before the last token knows anything about the first.
This is not just slow. It’s a fundamental architectural bottleneck. The path length between any two tokens grows linearly with distance. Long-range dependencies have to survive compression through hundreds of hidden states, each one lossy.
Attention solves both problems at once:
| Property | RNN | Attention |
|---|---|---|
| Path length between any two tokens | O(n) | O(1) |
| Sequential operations | O(n) | O(1) |
| Computation per layer | O(n · d²) | O(n² · d) |
| Parallelizable | ❌ | ✅ |
The tradeoff is quadratic computation in sequence length — O(n² · d) versus the RNN’s O(n · d²). But parallel operations on a GPU are faster than sequential operations. For modern hardware, attention wins.
That’s the engineering reason attention became the dominant sequence modeling primitive. Not elegance, not novelty — parallelism and direct information flow.
This is Part 1 of the Transformer Internals series, where I build a complete transformer from scratch in PyTorch, equation by equation, with tests at every layer. The complete implementation lives at rlvr-from-scratch.
Queries, Keys, and Values
The analogy everyone uses is a database. It’s imperfect, but useful as a starting point.
Imagine a key-value store. You have a query — the thing you’re searching for. Each entry has a key — a descriptor of what it contains. And each entry has a value — the actual content returned when matched.
In attention:
- Query (Q): “What am I looking for?” — derived from the current token
- Key (K): “What do I contain?” — derived from every token in the sequence
- Value (V): “What information do I return if selected?” — also derived from every token
Now break the analogy. A database lookup is hard — you match one key exactly and get one value back. Attention is soft — every key contributes to the output, weighted by how well it matches the query. There’s no binary match/no-match. You get a weighted combination of all values, where the weights reflect relevance.
This is the fundamental insight: attention is differentiable retrieval.
The Linear Projections
We start with an input X of shape (B, T, d_model) — a batch of sequences, where each token is a d_model-dimensional vector. We learn three separate projection matrices:

Q = X·W_Q,   K = X·W_K,   V = X·W_V

where W_Q, W_K ∈ ℝ^(d_model × d_k) and W_V ∈ ℝ^(d_model × d_v).
Why three separate projections? Because what makes a token a good search target (its key) is not the same as what information it should contribute when found (its value), and neither is the same as what the current token is searching for (its query). The model learns to decouple these three roles.
In practice, d_k = d_v = d_model / h, where h is the number of attention heads. More on that later.
import torch
import torch.nn as nn
d_model = 512
d_k = 64 # query/key dimension (typically d_model / n_heads)
d_v = 64 # value dimension
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_v, bias=False)
X = torch.randn(2, 10, d_model)  # X: (B, T, d_model) — example input
Q = W_Q(X) # (B, T, d_model) @ (d_model, d_k) -> (B, T, d_k)
K = W_K(X) # (B, T, d_model) @ (d_model, d_k) -> (B, T, d_k)
V = W_V(X) # (B, T, d_model) @ (d_model, d_v) -> (B, T, d_v)
Each token in the sequence now has three representations — one for each role in the retrieval process.
Key Insight: Q, K, V are not three different inputs. They are three learned views of the same input, each optimized for a different role. The model learns what to search for, what to advertise, and what to return — independently.
Scaled Dot-Product Attention
This is the core operation:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

Five steps, each with a clear mathematical purpose. Let me break this apart.
Step 1: Score Every Pair of Tokens
Compute the dot product between every query and every key:
The result is a T_q × T_k matrix where entry (i, j) measures how much token i’s query aligns with token j’s key. A high value means high relevance. For self-attention, T_q = T_k = T, and you get a T × T attention matrix — every token scored against every other token.
# Q: (B, T_q, d_k), K: (B, T_k, d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) # (B, T_q, T_k)
This is where the O(n²) cost comes from. For a 4,096-token sequence, this matrix has ~16.7 million entries per batch element and head. That’s the price of letting every token see every other token directly.
Step 2: Scale by 1/√d_k
This is not cosmetic. Let me derive why it’s necessary.
Assume q and k are random vectors in ℝ^d_k with entries independently drawn from N(0, 1). Their dot product is:

q · k = Σᵢ qᵢ kᵢ,   i = 1, …, d_k
Each term qᵢkᵢ is the product of two independent standard normals. The product of two independent N(0, 1) variables has:

E[qᵢkᵢ] = E[qᵢ]·E[kᵢ] = 0
Var(qᵢkᵢ) = E[qᵢ²]·E[kᵢ²] = 1
Since the terms are independent, their variances add:

Var(q · k) = Σᵢ Var(qᵢkᵢ) = d_k
So q · k has mean 0 and standard deviation √d_k.
What this means in practice: When d_k = 64, dot products have standard deviation 8. Feed values this large into softmax and you get outputs that are essentially one-hot — one position gets weight ≈ 1, everything else ≈ 0.
Why is this a problem? The gradient of softmax at saturation is near zero: the Jacobian entries are ∂wᵢ/∂sⱼ = wᵢ(δᵢⱼ − wⱼ), so if w is essentially one-hot, every entry is ≈ 0. The model can’t learn which tokens to attend to because the gradients vanish.
Dividing by √d_k normalizes the variance back to 1:

Var(q · k / √d_k) = d_k / d_k = 1
Now softmax operates in a regime where it produces meaningful, non-degenerate distributions with healthy gradients.
import math
d_k = Q.size(-1)
scores = scores / math.sqrt(d_k) # (B, T_q, T_k)
Key Insight: The scaling is not a hyperparameter to tune — it’s derived directly from the variance of dot products. Without it, softmax saturates, gradients vanish, and the model cannot learn attention patterns. This is one of those cases where the math isn’t optional.
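The derivation is easy to check numerically. A quick sketch (illustrative, not from the article’s repo) samples random vectors and measures the spread of their dot products before and after scaling:

```python
import math
import torch

torch.manual_seed(0)
d_k, n_samples = 64, 100_000

# Sample many (q, k) pairs with i.i.d. N(0, 1) entries.
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)
dots = (q * k).sum(dim=-1)  # one dot product per pair

print(f"std of q·k:        {dots.std():.2f}  (theory: sqrt(64) = 8.00)")
print(f"std after scaling: {(dots / math.sqrt(d_k)).std():.2f}  (theory: 1.00)")

# Unscaled scores saturate softmax toward one-hot weights.
raw = torch.randn(8) * math.sqrt(d_k)  # scores with std ≈ 8, like unscaled d_k=64
print("max softmax weight:", torch.softmax(raw, dim=-1).max().item())
```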
Step 3: Apply Mask (Optional)
For autoregressive models, we add a causal mask. For now, I’ll treat this as a simple addition — the next section covers masking in depth.
if mask is not None:
scores = scores + mask # additive: 0.0 = allowed, -inf = blocked
Step 4: Softmax
Each row is now a probability distribution over the key positions. Weights are non-negative and sum to 1 along the last dimension. Token i’s row tells you exactly how much attention it pays to every other token.
weights = torch.softmax(scores, dim=-1) # (B, T_q, T_k) — each row sums to 1
Step 5: Weighted Sum of Values
Each token’s output is a weighted combination of all value vectors. If token i attends strongly to token j, then token j’s value contributes heavily to token i’s output.
output = torch.matmul(weights, V) # (B, T_q, d_v)
The Complete Function
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Scaled dot-product attention.
Args:
Q: Query tensor (B, H, T_q, d_k)
K: Key tensor (B, H, T_k, d_k)
V: Value tensor (B, H, T_k, d_v)
mask: Additive mask (B|1, 1|H, T_q, T_k)
0.0 = allowed, -inf = blocked
Returns:
output: (B, H, T_q, d_v)
weights: (B, H, T_q, T_k)
"""
d_k = Q.size(-1)
# =========================================
# 1. Score: how much does each query match each key?
# =========================================
# (B, H, T_q, d_k) @ (B, H, d_k, T_k) -> (B, H, T_q, T_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# =========================================
# 2. Mask: block positions that shouldn't be attended to
# =========================================
if mask is not None:
scores = scores + mask
# =========================================
# 3. Normalize: convert scores to probabilities
# =========================================
# (B, H, T_q, T_k) — each row sums to 1
weights = torch.softmax(scores, dim=-1)
# =========================================
# 4. Aggregate: weighted sum of values
# =========================================
# (B, H, T_q, T_k) @ (B, H, T_k, d_v) -> (B, H, T_q, d_v)
output = torch.matmul(weights, V)
return output, weights
Note the H dimension — in practice, attention always runs inside multi-head attention, batched over both the batch dimension and heads.
Key Insight: Attention is three matrix multiplies and a softmax. That’s it. QKᵀ computes relevance, the 1/√d_k scaling keeps gradients alive, softmax normalizes, and the result retrieves from V. Everything else — masking, multiple heads, caching — is engineering on top of this core.
The Causal Mask
Why Masking Matters
In autoregressive (decoder) models, token t must only attend to positions ≤ t. During generation, future tokens don’t exist yet — looking at them would be cheating.
Without masking, the model sees the answer while trying to predict it. Training would optimize for a trivial copy operation rather than learning to predict.
Additive Masking
Two conventions exist:
| Convention | Operation | Properties |
|---|---|---|
| Boolean | scores[mask] = -inf | In-place mutation, requires boolean tensor |
| Additive | scores = scores + mask | Pure addition, composable, broadcastable |
We use additive masking. The mask tensor contains 0.0 for allowed positions and -inf for blocked positions.
After softmax, exp(−∞) = 0 — blocked positions get exactly zero attention weight.
position → 0 1 2 3
token 0 [ 0.0, -inf, -inf, -inf ] ← can only see itself
token 1 [ 0.0, 0.0, -inf, -inf ] ← sees token 0 and itself
token 2 [ 0.0, 0.0, 0.0, -inf ] ← sees 0, 1, and itself
token 3 [ 0.0, 0.0, 0.0, 0.0 ] ← sees everything up to itself
Why additive over boolean?
1. Pure operation — no in-place mutation, cleaner for autograd
2. Composable — multiple masks can be summed together (e.g., causal + padding)
3. Broadcastable — shape (1, 1, T, T) works across any batch and head count
def causal_mask(T: int, device=None):
"""
Create additive causal mask.
Returns:
mask: (1, 1, T, T) — 0.0 for allowed, -inf for blocked.
Broadcastable over batch and heads.
"""
# Upper triangle (above diagonal) = True = blocked
mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
return mask.float().masked_fill(mask, float("-inf")).unsqueeze(0).unsqueeze(0)
The unsqueeze(0).unsqueeze(0) adds batch and head dimensions for broadcasting: (T, T) → (1, 1, T, T).
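A quick sanity check of the mask’s effect on attention weights (causal_mask reproduced so the snippet runs standalone):

```python
import torch

# Same causal_mask as above, condensed for a self-contained snippet.
def causal_mask(T, device=None):
    blocked = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
    return blocked.float().masked_fill(blocked, float("-inf")).unsqueeze(0).unsqueeze(0)

m = causal_mask(4)                    # (1, 1, 4, 4)
scores = torch.zeros(1, 1, 4, 4) + m  # uniform scores, causally masked
weights = torch.softmax(scores, dim=-1)
print(weights[0, 0])
# Row t is uniform over positions 0..t and exactly 0.0 beyond:
# token 0 -> [1.00, 0, 0, 0], token 3 -> [0.25, 0.25, 0.25, 0.25]
```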
Key Insight: The causal mask is not a separate mechanism from attention — it’s just an additive bias on the score matrix. Future positions get , softmax converts that to , and those tokens contribute nothing. Masking and attention are the same computation.
Multi-Head Attention
Why Multiple Heads?
A single attention head computes one set of weights — one notion of “relevance” between tokens. But tokens relate to each other in multiple ways simultaneously.
Consider: “The cat sat on the mat because it was tired.”
- One head might learn coreference: “it” attends to “cat”
- Another might learn local context: “it” attends to nearby tokens
- Another might learn semantic roles: “tired” attends to “sat”
A single head can only learn one of these patterns per layer. Multiple heads learn them in parallel.
The Math
With h heads and model dimension d_model, each head operates on dimension d_k = d_model / h:

MultiHead(Q, K, V) = Concat(head_1, …, head_h) · W_O

where each head is:

head_i = Attention(Q·W_Q^(i), K·W_K^(i), V·W_V^(i))

Projection matrices:
- W_Q^(i) ∈ ℝ^(d_model × d_k) — projects into query space for head i
- W_K^(i) ∈ ℝ^(d_model × d_k) — projects into key space for head i
- W_V^(i) ∈ ℝ^(d_model × d_v) — projects into value space for head i
- W_O ∈ ℝ^(h·d_v × d_model) — projects concatenated heads back to d_model
The total parameter count is the same as single-head attention at dimension d_model. You’re partitioning the same capacity into h parallel subspaces.
Implementation: Reshape, Don’t Loop
The naive approach loops over heads. The efficient approach reshapes.
The key insight: the “split into heads” step is just a reshape. No data is copied. Each head sees a different d_k-dimensional slice of the projected representation.
Let me trace the shapes explicitly:
# Concrete example: B=2, T=10, d_model=512, H=8, d_k=64
query = torch.randn(2, 10, 512)        # (B, T, d_model)
W_Q = nn.Linear(512, 512, bias=False)  # W_K, W_V, W_O defined the same way
# =========================================
# 1. Project to full d_model dimension
# =========================================
Q = W_Q(query) # (2, 10, 512) — one Linear layer
# =========================================
# 2. Split into heads: view + transpose
# =========================================
Q = Q.view(2, 10, 8, 64) # (B, T, H, d_k)
Q = Q.transpose(1, 2) # (B, H, T, d_k) = (2, 8, 10, 64)
# =========================================
# 3. Attention (batched over B=2 and H=8)
# =========================================
# (K and V are projected and split exactly like Q)
attn_output, weights = scaled_dot_product_attention(Q, K, V, mask)
# attn_output: (2, 8, 10, 64)
# weights: (2, 8, 10, 10)
# =========================================
# 4. Merge heads: transpose + contiguous + view
# =========================================
attn_output = attn_output.transpose(1, 2) # (2, 10, 8, 64)
attn_output = attn_output.contiguous() # required for view()
attn_output = attn_output.view(2, 10, 512) # (2, 10, 512) = (B, T, d_model)
# =========================================
# 5. Output projection
# =========================================
output = W_O(attn_output) # (2, 10, 512)
Why .contiguous()? After transpose(), the tensor’s memory layout is non-contiguous — the strides don’t match what .view() expects. Without .contiguous(), you get a runtime error. This is the kind of thing that costs you 30 minutes of debugging exactly once.
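A minimal reproduction of that failure mode:

```python
import torch

x = torch.randn(2, 8, 10, 64)  # (B, H, T, d_k), contiguous
xt = x.transpose(1, 2)         # (B, T, H, d_k) — a view with permuted strides
print(xt.is_contiguous())      # False

failed = False
try:
    xt.view(2, 10, 512)        # strides don't line up with the target shape
except RuntimeError:
    failed = True
print("view() without .contiguous() raised:", failed)

y = xt.contiguous().view(2, 10, 512)  # copy into contiguous memory, then view
print(y.shape)                        # torch.Size([2, 10, 512])
```

(`.reshape()` would also work here — it silently copies when needed — but `.contiguous().view()` makes the copy explicit.)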
The Full Module
class MultiHeadAttention(nn.Module):
"""
Multi-head attention with explicit projections.
No torch.nn.MultiheadAttention — every operation visible.
Args:
d_model: Model dimension.
n_heads: Number of attention heads. Must divide d_model.
bias: Whether to use bias in projections.
"""
def __init__(self, d_model: int, n_heads: int, bias: bool = False):
super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# =========================================
# Four learned projections
# =========================================
self.W_Q = nn.Linear(d_model, d_model, bias=bias)
self.W_K = nn.Linear(d_model, d_model, bias=bias)
self.W_V = nn.Linear(d_model, d_model, bias=bias)
self.W_O = nn.Linear(d_model, d_model, bias=bias)
def forward(self, query, key, value, mask=None, kv_cache=None):
"""
Args:
query: (B, T_q, d_model)
key: (B, T_k, d_model)
value: (B, T_k, d_model)
mask: Additive mask (B|1, 1|H, T_q, T_k)
kv_cache: Optional (K, V) from previous steps,
each (B, H, T_prev, d_k)
Returns:
output: (B, T_q, d_model)
weights: (B, H, T_q, T_k)
new_kv_cache: Updated (K, V) or None
"""
B, T_q, _ = query.shape
# =========================================
# 1. Project
# =========================================
Q = self.W_Q(query) # (B, T_q, d_model)
K = self.W_K(key) # (B, T_k, d_model)
V = self.W_V(value) # (B, T_k, d_model)
# =========================================
# 2. Split heads
# =========================================
Q = self._split_heads(Q) # (B, H, T_q, d_k)
K = self._split_heads(K) # (B, H, T_k, d_k)
V = self._split_heads(V) # (B, H, T_k, d_k)
# =========================================
# 3. KV-cache (for incremental decoding)
# =========================================
new_kv_cache = None
if kv_cache is not None:
K_prev, V_prev = kv_cache
K = torch.cat([K_prev, K], dim=2) # (B, H, T_prev+T_k, d_k)
V = torch.cat([V_prev, V], dim=2)
new_kv_cache = (K, V)
# =========================================
# 4. Attention
# =========================================
attn_output, weights = scaled_dot_product_attention(Q, K, V, mask)
# =========================================
# 5. Merge heads + output projection
# =========================================
attn_output = self._merge_heads(attn_output) # (B, T_q, d_model)
output = self.W_O(attn_output) # (B, T_q, d_model)
return output, weights, new_kv_cache
def _split_heads(self, x):
"""(B, T, d_model) -> (B, H, T, d_k)"""
B, T, _ = x.shape
return x.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
def _merge_heads(self, x):
"""(B, H, T, d_k) -> (B, T, d_model)"""
B, _, T, _ = x.shape
return x.transpose(1, 2).contiguous().view(B, T, self.d_model)
Parameter Count
Multi-head attention has exactly four weight matrices:
| Parameter | Shape | Count |
|---|---|---|
| W_Q | d_model × d_model | d_model² |
| W_K | d_model × d_model | d_model² |
| W_V | d_model × d_model | d_model² |
| W_O | d_model × d_model | d_model² |
| Total | | 4 · d_model² |

For d_model = 512: 4 × 512² = 1,048,576 parameters. The number of heads doesn’t change this — you’re partitioning the same total dimension.
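The count is easy to confirm directly (a quick sketch counting the four projection matrices):

```python
import torch.nn as nn

d_model = 512
# The four projections of multi-head attention: W_Q, W_K, W_V, W_O
projections = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]
total = sum(p.numel() for layer in projections for p in layer.parameters())
print(total)  # 1048576 = 4 * 512**2, independent of the number of heads
```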
Key Insight: Multi-head attention doesn’t add parameters compared to single-head attention at the same dimension. It partitions the same capacity into parallel subspaces. The model learns to use each head for a different type of relationship — syntax, semantics, position — without any explicit supervision telling it to do so.
Self-Attention vs Cross-Attention
Same mechanism, two modes:
Self-attention: Q, K, V all come from the same sequence. Each token attends to every other token in the same input. Used in both encoders and decoders.
# Self-attention: same input for all three
output, weights, _ = mha(x, x, x, mask=causal_mask(T))
Cross-attention: Q comes from one sequence (decoder), K and V come from another (encoder output). The decoder queries the encoder’s representation. Used in encoder-decoder models for translation, summarization, and similar tasks.
# Cross-attention: decoder queries, encoder keys/values
output, weights, _ = mha(decoder_state, encoder_output, encoder_output)
Our MultiHeadAttention handles both — the query, key, value arguments are deliberately separate. For self-attention, pass the same tensor for all three. For cross-attention, pass different tensors.
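A shape sketch of cross-attention with T_q ≠ T_k, with scaled_dot_product_attention condensed here so the snippet runs standalone:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same five steps as above, condensed.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores + mask
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

B, H, d_k = 2, 8, 64
decoder_q = torch.randn(B, H, 5, d_k)   # T_q = 5 decoder positions
encoder_k = torch.randn(B, H, 12, d_k)  # T_k = 12 encoder positions
encoder_v = torch.randn(B, H, 12, d_k)

out, w = scaled_dot_product_attention(decoder_q, encoder_k, encoder_v)
print(out.shape)  # torch.Size([2, 8, 5, 64]) — one output per query position
print(w.shape)    # torch.Size([2, 8, 5, 12]) — each query scores all 12 keys
```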
KV-Cache: Making Generation Fast
The Problem
During training, we process the full sequence at once — one forward pass, all positions in parallel. During generation, we decode one token at a time.
At step t, the new token needs to attend to all previous tokens plus itself. Without caching, this means recomputing K and V projections for all t tokens at every step. For a sequence of length n, total projection computation scales as:

1 + 2 + ⋯ + n = n(n + 1) / 2 = O(n²)

That’s O(n²) work just for the linear projections — before we even get to attention.
The Solution
Cache the K and V tensors. At each step:
- Compute K and V for only the new token — O(1) with respect to sequence length
- Concatenate with the cached K, V from all previous steps
- Compute attention using the full cache but only the new Q
# =========================================
# Step t: process one new token
# =========================================
new_Q = split_heads(W_Q(new_token)) # (B, 1, d_model) -> (B, H, 1, d_k)
new_K = split_heads(W_K(new_token)) # (B, H, 1, d_k)
new_V = split_heads(W_V(new_token)) # (B, H, 1, d_k)
# Append to cache
K_cache = torch.cat([K_cache, new_K], dim=2) # (B, H, t, d_k)
V_cache = torch.cat([V_cache, new_V], dim=2) # (B, H, t, d_k)
# Attention: (B, H, 1, d_k) against (B, H, t, d_k)
output, _ = scaled_dot_product_attention(new_Q, K_cache, V_cache)
# output: (B, H, 1, d_k) — one token's representation
Per-step projection cost drops from O(t · d_model²) to O(d_model²). The attention computation itself is still O(t) per step — you can’t avoid looking at all previous tokens.
The Correctness Invariant
This is the most important property of a KV-cache implementation:
Incremental decoding with KV-cache must produce the exact same output as a full forward pass with a causal mask.
If it doesn’t, your cache is wrong. Token t’s output should be identical whether you compute it as part of a full batch or incrementally with cached K, V from steps 0 through t − 1.
We test this explicitly:
def test_kv_cache_matches_full_pass(mha):
"""Cached incremental decoding must match full-sequence result."""
mha.eval()
seq = torch.randn(B, T, D_MODEL)
mask = causal_mask(T)
# Full pass (ground truth)
with torch.no_grad():
full_output, _, _ = mha(seq, seq, seq, mask=mask)
# Incremental pass: token by token with KV-cache
cache = (torch.empty(B, H, 0, D_K), torch.empty(B, H, 0, D_K))
incremental_outputs = []
with torch.no_grad():
for t in range(T):
token = seq[:, t:t+1, :] # (B, 1, d_model)
out, _, cache = mha(token, token, token, kv_cache=cache)
incremental_outputs.append(out)
incremental_output = torch.cat(incremental_outputs, dim=1)
# These must match
torch.testing.assert_close(full_output, incremental_output, atol=1e-5, rtol=1e-5)
Key Insight: KV-cache trades memory for time. You store all previous keys and values (memory grows linearly with sequence length) but avoid recomputing them (per-step projection cost drops from O(t) to O(1)). For long sequences, this is the difference between practical and impractical generation speeds.
Full Implementation
The complete, tested implementation lives at src/rlvr_from_scratch/model/attention.py.
What’s in the module
| Component | What it does | Parameters |
|---|---|---|
| scaled_dot_product_attention | Core: score, scale, mask, softmax, aggregate | None (pure function) |
| causal_mask | Prevents attending to future tokens | None (deterministic) |
| MultiHeadAttention | Projections, head splitting, attention, merging | 4 · d_model² |
| KV-cache support | Incremental decoding without recomputation | None (caches K, V tensors) |
Test Coverage
The test suite at tests/model/test_attention.py covers 24 tests across three categories:
Correctness:
- Output shapes for all configurations (H=1, 2, 4, 8)
- Attention weights sum to 1
- Causal mask blocks all future positions
- Causal mask allows all past positions and self
- Cross-attention with different Q/K lengths
- KV-cache matches full forward pass
Robustness:
- Numerical stability with large dimensions (1024)
- Batch independence (each element processed identically)
- Determinism (same input → same output)
Training:
- Gradient flow through Q, K, V
- Gradient flow through all MHA parameters
- Invalid configuration raises ValueError
Key Takeaways
The Core Operation
Three matrix multiplies and a softmax. Everything else is engineering.
Design Choices
- Masking convention: Additive (0.0 / -inf) — composable, pure, broadcastable
- Head splitting: Reshape, not loop — same computation, GPU-friendly
- Bias in projections: Off by default — modern standard (GPT-2+, LLaMA)
- KV-cache: Concatenation-based — simple, correct, testable
What’s Next
Attention is permutation equivariant — shuffle the input tokens and you get the same output (modulo the shuffling). The model has no sense of order. Token 0 and token 99 are treated identically.
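You can verify the equivariance directly. A minimal sketch with bare single-head self-attention (no projections or mask — enough to exhibit the property):

```python
import math
import torch

def self_attention(X):
    # Unmasked single-head attention with Q = K = V = X.
    scores = X @ X.transpose(-2, -1) / math.sqrt(X.size(-1))
    return torch.softmax(scores, dim=-1) @ X

torch.manual_seed(0)
X = torch.randn(1, 6, 32)   # (B, T, d)
perm = torch.randperm(6)    # a random shuffle of the 6 positions

out = self_attention(X)
out_perm = self_attention(X[:, perm])

# Shuffling the input shuffles the output the same way — no notion of order.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```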
In Part 2: Positional Encoding, I build sinusoidal, learned, and rotary position embeddings from scratch, derive the rotation matrix formulation of RoPE, and show why RoPE won.
After that, Part 3 assembles the full transformer block (attention + FFN + normalization + residuals), and Part 4 builds the training loop with AdamW and cosine warmup from scratch.
Further Reading
Original Papers:
- Attention Is All You Need (Vaswani et al., 2017)
- An Image Is Worth 16x16 Words (Dosovitskiy et al., 2020) — attention beyond NLP
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022) — IO-aware implementation
Pedagogical Resources:
- The Illustrated Transformer (Alammar, 2018)
- The Annotated Transformer (Rush, 2018)
Implementation:
- rlvr-from-scratch — the tested implementation from this article
Cite this reference
Sousa, V. (2026). Attention Is All You Need to Implement. vitorsousa.com (Foundation Reference). https://www.vitorsousa.com/foundations//
@article{sousa2026,
title={Attention Is All You Need to Implement},
author={Sousa, Vitor},
year={2026},
note={Foundation Reference},
url={https://www.vitorsousa.com/foundations//}
}