Transformers Architecture: A Deep Dive

The Transformer architecture, introduced in the seminal paper "Attention Is All You Need" (Vaswani et al., 2017), fundamentally changed how we approach sequence-to-sequence tasks. Let's break down why this matters.

The Problem with RNNs

Traditional recurrent neural networks had three critical limitations:

  1. Sequential processing - Tokens are handled one at a time, so training can't be parallelized
  2. Vanishing gradients - Error signals decay over long sequences, making distant dependencies hard to learn
  3. Limited context - All prior context must squeeze through a fixed-size hidden state

LSTMs and GRUs helped, but didn't solve the parallelization problem.

Enter Self-Attention

The breakthrough insight: what if we could attend to all positions in the input simultaneously?

import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Queries (batch, heads, seq_len, d_k)
        K: Keys (batch, heads, seq_len, d_k)
        V: Values (batch, heads, seq_len, d_v)
        mask: Optional mask (0 marks positions to hide)

    Returns:
        output: Attention-weighted values (batch, heads, seq_len, d_v)
        attn_weights: Attention probabilities (batch, heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)
    # Similarity of every query with every key, scaled to keep the softmax stable
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Masked positions get a large negative score so softmax drives them to ~0
        scores = scores.masked_fill(mask == 0, -1e9)

    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)

    return output, attn_weights
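
A quick way to sanity-check the shapes (a minimal sketch; the tensor sizes below are arbitrary):

# Hypothetical sizes, chosen only to check that the shapes line up
batch, heads, seq_len, d_k = 2, 4, 10, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)      # torch.Size([2, 4, 10, 16])
print(weights.shape)  # torch.Size([2, 4, 10, 10])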

Why It Works

The attention mechanism computes three things for each token:

  • Query - what this token is looking for
  • Key - what this token offers for others to match against
  • Value - the information this token contributes when attended to

This creates a dynamic, context-aware representation: each token's output is a weighted sum of all tokens' values, with the weights determined by how well its query matches their keys.
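
Concretely, each of the three is just a learned linear projection of the token embedding. A minimal sketch (the sizes and layer names here are illustrative, not from the paper's code):

import torch
import torch.nn as nn

d_model = 512  # illustrative embedding size

# One learned projection per role (in practice these live inside the attention layer)
to_query = nn.Linear(d_model, d_model)
to_key = nn.Linear(d_model, d_model)
to_value = nn.Linear(d_model, d_model)

x = torch.randn(1, 10, d_model)  # a batch with one sequence of 10 token embeddings
Q, K, V = to_query(x), to_key(x), to_value(x)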

Multi-Head Attention

Instead of one attention mechanism, we use multiple "heads" in parallel:

| Head | Purpose   | Example Focus           |
|------|-----------|-------------------------|
| 1    | Syntax    | Subject-verb agreement  |
| 2    | Semantics | Related concepts        |
| 3    | Discourse | Coreference resolution  |
| 4    | Position  | Local context           |

Each head learns different patterns, then we concatenate and project:

import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        # One projection per role, plus the output projection
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections, then split d_model into num_heads x d_k
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention to every head in parallel
        x, attn = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and apply the final linear projection
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.W_o(x)
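
And a quick usage sketch (the sizes are arbitrary; for self-attention the same tensor is passed as Q, K, and V):

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
out = mha(x, x, x)           # self-attention: queries, keys, values all come from x
print(out.shape)             # torch.Size([2, 10, 512])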

Positional Encoding

Since transformers process all tokens in parallel, we need to inject positional information:

def positional_encoding(seq_len, d_model):
    # Position indices (seq_len, 1) and one frequency per pair of dimensions
    pos = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div_term)  # even dimensions get sine
    pe[:, 1::2] = torch.cos(pos * div_term)  # odd dimensions get cosine

    return pe

Key insight: Sinusoidal encodings allow the model to extrapolate to sequence lengths not seen during training.
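
In use, the encodings are simply added to the token embeddings before the first layer. A minimal sketch with arbitrary sizes:

pe = positional_encoding(seq_len=100, d_model=512)  # (100, 512)
embeddings = torch.randn(2, 100, 512)               # (batch, seq_len, d_model)

# Broadcast-add the same positional pattern to every sequence in the batch
x = embeddings + pe.unsqueeze(0)
print(x.shape)  # torch.Size([2, 100, 512])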

The Complete Architecture

The full transformer consists of the following components (a simplified encoder-layer sketch follows the list):

  1. Encoder Stack (N=6 layers)

    • Multi-head self-attention
    • Feed-forward network
    • Layer normalization + residual connections
  2. Decoder Stack (N=6 layers)

    • Masked multi-head self-attention
    • Encoder-decoder attention
    • Feed-forward network
    • Layer normalization + residual connections
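
Here is a simplified sketch of how one encoder layer wires these pieces together. It reuses the MultiHeadAttention class above; dropout and other training details from the paper are omitted, so treat it as an illustration rather than a faithful reimplementation:

import torch.nn as nn


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then a feed-forward network,
    each wrapped in a residual connection and layer normalization."""

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Sub-layer 1: multi-head self-attention with residual + norm
        x = self.norm1(x + self.self_attn(x, x, x, mask))
        # Sub-layer 2: position-wise feed-forward with residual + norm
        x = self.norm2(x + self.ff(x))
        return x

Stacking N of these layers on top of the input embeddings plus positional encodings gives the full encoder; the decoder adds masked self-attention and encoder-decoder attention on top of the same pattern.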

Why This Matters Today

The transformer architecture enabled:

  • Large-scale pretrained language models such as BERT and the GPT family
  • Training that scales efficiently on modern accelerators, since attention parallelizes over the whole sequence
  • The same core architecture spreading beyond text to vision, speech, and code

The beauty of transformers lies in their simplicity and scalability. By replacing recurrence with attention, we unlocked unprecedented parallelization and the ability to model long-range dependencies effectively.

Next up: We'll explore how attention patterns emerge during training and what they reveal about language understanding.
