A Comprehensive Introduction to BERT: Architecture, Pre‑training, and Implementation
This article provides an in‑depth overview of BERT, covering its NLP background, GLUE benchmark achievements, Transformer‑based architecture, pre‑training strategies (MLM and NSP), downstream fine‑tuning methods, and includes detailed PyTorch code implementations of its core components.
The article begins with a brief introduction to Natural Language Processing (NLP) and the significance of BERT as a breakthrough model in 2018, highlighting its state‑of‑the‑art results on eleven NLP tasks, including the GLUE benchmark.
It explains the four major NLP task categories—sequence labeling, classification, sentence‑pair relationship, and generation—and shows how BERT's bidirectional Transformer encoder is applied across them.
The core language‑model concepts are reviewed, including n‑gram models, their limitations, and the shift to neural language models and pre‑training techniques such as Word2Vec, ELMo, and GPT, leading up to BERT’s design.
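As a concrete illustration of the n‑gram idea and its central limitation (any bigram unseen in training receives zero probability), a toy maximum‑likelihood bigram model can be sketched in a few lines. The function name and corpus below are illustrative, not from the article:

```python
from collections import Counter, defaultdict

def train_bigram(corpus):
    """Maximum-likelihood bigram model: P(w_i | w_{i-1}) = count(w_{i-1}, w_i) / count(w_{i-1})."""
    counts = defaultdict(Counter)
    for sentence in corpus:
        tokens = ["<s>"] + sentence.split() + ["</s>"]
        for prev, cur in zip(tokens, tokens[1:]):
            counts[prev][cur] += 1
    # normalize counts into conditional probabilities
    return {prev: {w: c / sum(nxt.values()) for w, c in nxt.items()}
            for prev, nxt in counts.items()}

model = train_bigram(["the cat sat", "the dog sat"])
# P(cat | the) = 0.5, since "the" is followed by "cat" and "dog" equally often
```

Neural language models replace these brittle count tables with learned dense representations, which is the shift the article traces toward Word2Vec, ELMo, GPT, and ultimately BERT.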
BERT’s architecture is described in detail: a stack of identical Transformer encoder layers, each containing multi‑head self‑attention and position‑wise feed‑forward sub‑layers, with residual connections and layer normalization.
The input representation combines token, segment, and positional embeddings, using special tokens [CLS] and [SEP] for single‑sentence and sentence‑pair inputs.
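The embedding module itself is not shown in the article's snippets; a minimal sketch consistent with this description — token, segment, and learned position embeddings summed, then dropout — might look as follows (the `max_len` and `dropout` values are illustrative assumptions):

```python
import torch
import torch.nn as nn

class BERTEmbedding(nn.Module):
    """Sum of token, segment, and learned position embeddings."""
    def __init__(self, vocab_size, embed_size, max_len=512, dropout=0.1):
        super().__init__()
        self.token = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = nn.Embedding(2, embed_size)        # sentence A vs. sentence B
        self.position = nn.Embedding(max_len, embed_size)  # learned, not sinusoidal
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens, segments):
        # position ids 0..seq_len-1, broadcast across the batch
        positions = torch.arange(tokens.size(1), device=tokens.device).unsqueeze(0)
        x = self.token(tokens) + self.segment(segments) + self.position(positions)
        return self.dropout(x)
```

The [CLS] and [SEP] special tokens are ordinary vocabulary entries here; what distinguishes them is only the conventions built around their positions.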
Pre‑training strategies are explained: Masked Language Modeling (MLM) randomly selects 15% of tokens and predicts them, while Next Sentence Prediction (NSP) learns sentence‑pair relationships. The article notes the trade‑off of MLM—the [MASK] token never appears at fine‑tuning time—and BERT's mitigation: of the selected tokens, 80% are replaced with [MASK], 10% with a random token, and 10% left unchanged.
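The masking procedure can be sketched as follows. The function name and token ids are illustrative assumptions, and the -100 label convention (positions ignored by the loss) follows common PyTorch practice rather than the article:

```python
import random

MASK_ID, VOCAB_LOW, VOCAB_HIGH = 103, 1000, 2000  # illustrative token ids

def mask_tokens(token_ids, mask_prob=0.15, seed=None):
    """BERT-style masking: select ~15% of positions; of those, 80% become
    [MASK], 10% a random token, 10% stay unchanged. Returns (inputs, labels),
    where labels is -100 at unselected positions (ignored by the loss)."""
    rng = random.Random(seed)
    inputs, labels = list(token_ids), [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() < mask_prob:
            labels[i] = tok              # the model must recover the original token
            r = rng.random()
            if r < 0.8:
                inputs[i] = MASK_ID      # 80%: replace with [MASK]
            elif r < 0.9:
                inputs[i] = rng.randrange(VOCAB_LOW, VOCAB_HIGH)  # 10%: random token
            # else 10%: keep the original token
    return inputs, labels
```

Keeping 10% of selected tokens unchanged forces the model to maintain a good representation for every input token, since it cannot tell which positions will be scored.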
Four downstream fine‑tuning approaches are outlined: sentence‑pair classification and single‑sentence classification add a simple classification head on top of the [CLS] representation, question answering (e.g., SQuAD) predicts answer‑span start and end positions over the token representations, and sequence labeling (e.g., NER) classifies each token's output individually.
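For the question‑answering case, the span‑prediction head is a small variation on the classification head: instead of reading only [CLS], it projects every token's hidden state to a start logit and an end logit, and the predicted answer is the highest‑scoring (start, end) pair. A minimal sketch, with names of our choosing:

```python
import torch
import torch.nn as nn

class SpanHead(nn.Module):
    """SQuAD-style head: per-token start/end logits over the passage."""
    def __init__(self, hidden):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden, 2)  # one column for start, one for end

    def forward(self, hidden_states):            # (batch, seq_len, hidden)
        logits = self.qa_outputs(hidden_states)  # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)
```

Training sums a cross‑entropy loss over the start positions and one over the end positions, so the head adds only a single linear layer on top of the pre‑trained encoder.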
Implementation details are provided using PyTorch. The following code snippets illustrate the core components:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLanguageModel(nn.Module):
    """MLM head: predicts a vocabulary distribution at each (masked) position."""
    def __init__(self, hidden, vocab_size):
        super(MaskedLanguageModel, self).__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

class NextSentencePrediction(nn.Module):
    """NSP head: binary is-next / not-next classification from the [CLS] position."""
    def __init__(self, hidden):
        super(NextSentencePrediction, self).__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # x[:, 0] is the hidden state of the [CLS] token
        return self.softmax(self.linear(x[:, 0]))

class BERT(PreTrainedBERTModel):
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        config = BertConfig(vocab_size, hidden_size=hidden, num_hidden_layers=n_layers,
                            num_attention_heads=attn_heads, hidden_dropout_prob=dropout)
        super(BERT, self).__init__(config)
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.feed_forward_hidden = hidden * 4  # FFN inner dimension, 4x hidden as in the paper
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # padding mask of shape (batch, 1, seq_len, seq_len), broadcast over attention heads
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.embedding(x, segment_info)
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x

class TransformerBlock(nn.Module):
    """One encoder layer: multi-head self-attention and a position-wise feed-forward
    network, each wrapped in a residual connection with layer normalization."""
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0  # model dimension must split evenly across heads
        self.d_k = d_model // h
        self.h = h
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # project Q, K, V, then reshape to (batch, heads, seq_len, d_k)
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        # concatenate heads back to (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.output_linear(x)

class Attention(nn.Module):
    """Scaled dot-product attention."""
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # block attention to padding
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn

class SublayerConnection(nn.Module):
    """Residual connection with dropout around a pre-normalized sublayer."""
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

Finally, the article lists key references (Devlin et al., 2018; Vaswani et al., 2017; etc.) and provides author information, emphasizing the relevance of BERT for future NLP research and industrial applications.
DataFunTalk
Dedicated to sharing and discussing big data and AI technology applications, aiming to empower a million data scientists. Regularly hosts live tech talks and curates articles on big data, recommendation/search algorithms, advertising algorithms, NLP, intelligent risk control, autonomous driving, and machine learning/deep learning.