Step‑by‑Step Implementation of Transformer Blocks, Attention, Normalization, Feed‑Forward, Encoder and Decoder in PyTorch
This article provides a comprehensive tutorial on building the core components of a Transformer model—including multi‑head attention, layer normalization, feed‑forward networks, encoder and decoder layers—and assembles them into a complete PyTorch implementation, supplemented with explanatory diagrams and runnable code.
The article introduces a series of linked articles for learning about ChatGPT and Transformer fundamentals, then focuses on the concrete implementation of a generic Transformer block.
Transformer Block Overview – A diagram (not reproduced) shows the encoder side consisting of multi‑head self‑attention, two normalization layers, a feed‑forward network, and an additional weight‑normalization layer.
Attention Implementation – The self‑attention mechanism is explained conceptually (query, key, value) and the following PyTorch function is provided:
def attention(query: Tensor,
key: Tensor,
value: Tensor,
mask: Optional[Tensor] = None,
dropout: float = 0.1):
"""Calculate attention scores and output.
Args:
query: shape (batch_size, num_heads, seq_len, k_dim)
key: shape (batch_size, num_heads, seq_len, k_dim)
value: shape (batch_size, num_heads, seq_len, v_dim)
mask: shape (batch_size, num_heads, seq_len, seq_len)
Returns:
out: shape (batch_size, v_dim)
attention_score: shape (seq_len, seq_len)
"""
k_dim = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(k_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e10)
attention_score = F.softmax(scores, dim=-1)
if dropout is not None:
attention_score = dropout(attention_score)
out = torch.matmul(attention_score, value)
return out, attention_scoreThe article then shows how this function is used inside a MultiHeadedAttention module:
class MultiHeadedAttention(nn.Module):
def __init__(self, num_heads: int, d_model: int, dropout: float = 0.1):
super(MultiHeadedAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.k_dim = d_model // num_heads
self.num_heads = num_heads
self.proj_weights = clones(nn.Linear(d_model, d_model), 4) # W^Q, W^K, W^V, W^O
self.attention_score = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None):
"""Multi‑head attention forward pass.
Returns:
out: shape (batch_size, seq_len, d_model)
"""
if mask is not None:
mask = mask.unsqueeze(1)
batch_size = query.size(0)
# 1) Linear projections
query, key, value = [proj_weight(x).view(batch_size, -1, self.num_heads, self.k_dim)
.transpose(1, 2)
for proj_weight, x in zip(self.proj_weights, [query, key, value])]
# 2) Attention
out, self.attention_score = attention(query, key, value, mask=mask, dropout=self.dropout)
# 3) Concatenate heads
out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.k_dim)
# 4) Final linear
out = self.proj_weights[-1](out)
return outNormalization Layer – A simple layer‑norm implementation is given:
class NormLayer(nn.Module):
def __init__(self, features, eps=1e-6):
super(NormLayer, self).__init__()
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2Feed‑Forward Network – The standard two‑layer feed‑forward module:
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff=2048, dropout=0.1):
super(FeedForward, self).__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)
def forward(self, x):
x = self.dropout(F.relu(self.linear_1(x)))
x = self.linear_2(x)
return xEncoder Layer – Combines the above components into a single encoder block and stacks them:
class EncoderLayer(nn.Module):
    """One encoder block: pre-norm self-attention and pre-norm feed-forward,
    each wrapped in a residual connection with dropout.

    Fix: the original referenced ``Norm`` and ``MultiHeadAttention``, which
    are not defined in this file; the classes defined above are
    ``NormLayer`` and ``MultiHeadedAttention``.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.norm_1 = NormLayer(d_model)
        self.norm_2 = NormLayer(d_model)
        self.attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Self-attention sub-layer (pre-norm residual).
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
        # Feed-forward sub-layer (pre-norm residual).
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x
class Encoder(nn.Module):
def __init__(self, vocab_size, d_model, N, heads, dropout):
super(Encoder, self).__init__()
self.N = N
self.embed = Embedder(d_model, vocab_size)
self.pe = PositionalEncoder(d_model, dropout=dropout)
self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)
self.norm = Norm(d_model)
def forward(self, src, mask):
x = self.embed(src)
x = self.pe(x)
for i in range(self.N):
x = self.layers[i](x, mask)
return self.norm(x)Decoder Layer – Mirrors the encoder but adds a second attention over encoder outputs:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder–decoder
    cross-attention, and feed-forward, each as a pre-norm residual sub-layer.

    Fix: the original referenced ``Norm`` and ``MultiHeadAttention``, which
    are not defined in this file; the classes defined above are
    ``NormLayer`` and ``MultiHeadedAttention``.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.norm_1 = NormLayer(d_model)
        self.norm_2 = NormLayer(d_model)
        self.norm_3 = NormLayer(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)
        self.attn_1 = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.attn_2 = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        # Masked self-attention over the target sequence.
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, src_mask))
        # Position-wise feed-forward.
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x2))
        return x
class Decoder(nn.Module):
def __init__(self, vocab_size, d_model, N, heads, dropout):
super(Decoder, self).__init__()
self.N = N
self.embed = Embedder(vocab_size, d_model)
self.pe = PositionalEncoder(d_model, dropout=dropout)
self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
self.norm = Norm(d_model)
def forward(self, trg, e_outputs, src_mask, trg_mask):
x = self.embed(trg)
x = self.pe(x)
for i in range(self.N):
x = self.layers[i](x, e_outputs, src_mask, trg_mask)
return self.norm(x)Full Transformer Model – The encoder and decoder are combined into a minimal viable product:
class Transformer(nn.Module):
def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
super(Transformer, self).__init__()
self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
self.out = nn.Linear(d_model, trg_vocab)
def forward(self, src, trg, src_mask, trg_mask):
e_outputs = self.encoder(src, src_mask)
d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
output = self.out(d_output)
return outputThe article concludes that the core Transformer code is now complete and points readers to the GitHub repository black‑transformer for the full project.
Nightwalker Tech
[Nightwalker Tech] is the tech sharing channel of "Nightwalker", focusing on AI and large model technologies, internet architecture design, high‑performance networking, and server‑side development (Golang, Python, Rust, PHP, C/C++).
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.