Analysis of LLaMA Model Architecture in the Transformers Library
This article provides a detailed walkthrough of the core LLaMA implementation in HuggingFace's Transformers library: the inheritance hierarchy, configuration defaults, model initialization, the embedding layer and stacked decoder layers, the attention and MLP modules, and RMSNorm. Tensor-parallel and memory-optimization code is omitted to keep the focus on model structure.
Model inheritance: The main model class LlamaModel inherits from LlamaPreTrainedModel, which in turn inherits from PreTrainedModel, the base class for all HuggingFace models.
LlamaModel -> LlamaPreTrainedModel -> PreTrainedModel

LlamaConfig: The configuration class defines hyper-parameters such as vocab_size, hidden_size, num_hidden_layers, and num_attention_heads. All parameters have sensible defaults, so a config object can be instantiated directly.
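As a point of reference, the defaults baked into LlamaConfig correspond to the original 7B model. A minimal stand-in (not the library class; the field values are the library's defaults to the best of my knowledge) makes them concrete:

```python
from dataclasses import dataclass

@dataclass
class ConfigSketch:
    # Default values mirror transformers' LlamaConfig (the 7B variant).
    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-6

cfg = ConfigSketch()
# Per-head dimension used by the attention module discussed later:
head_dim = cfg.hidden_size // cfg.num_attention_heads  # 128
```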
config = LlamaConfig()

LlamaModel initialization: The constructor sets the padding index and vocabulary size, creates the token embedding layer, the stack of decoder layers, the final RMSNorm layer, and a flag for gradient checkpointing. It also calls post_init() to initialize weights.
def __init__(self, config: LlamaConfig):
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    self.layers = nn.ModuleList([LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
    self._use_sdpa = config._attn_implementation == "sdpa"
    self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.gradient_checkpointing = False
    self.post_init()

The post_init() method (defined in PreTrainedModel) performs weight initialization and handles backward compatibility for gradient checkpointing.
def post_init(self):
    """Execute code after model initialization, e.g., weight initialization."""
    self.init_weights()
    self._backward_compatibility_gradient_checkpointing()

LlamaModel forward pass: Input token IDs are embedded, then passed through each decoder layer in turn. Hidden states are optionally collected after each layer, and the final hidden states are normalized before being returned in a BaseModelOutputWithPast object.
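Before the layer loop, the library also prepares the attention mask that each decoder layer consumes (that code is omitted here). Conceptually it is an additive causal mask: 0 where a position may attend, a very large negative number where it may not, so the subsequent softmax zeroes out future positions. A plain-Python sketch of the idea:

```python
def causal_mask(seq_len, min_value=float("-inf")):
    # Row i may attend to columns 0..i; future columns get -inf,
    # which softmax turns into a zero attention weight.
    return [[0.0 if j <= i else min_value for j in range(seq_len)]
            for i in range(seq_len)]

mask = causal_mask(3)
# mask[1] allows positions 0 and 1 but blocks position 2
```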
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for decoder_layer in self.layers:
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)

LlamaDecoderLayer: Each layer contains a self-attention module, an MLP, and an RMSNorm before each of them. The forward method saves a residual, applies the input layer-norm, runs attention, adds the residual back, then applies the post-attention layer-norm, runs the MLP, and adds a second residual.
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs

Attention module (LlamaAttention): Implements multi-head attention with rotary position embeddings. It projects the hidden states into query, key, and value tensors, reshapes them into per-head views, applies the rotary embeddings, computes scaled dot-product attention scores, adds the attention mask, applies softmax and dropout, multiplies by the value tensor, reshapes the result, and finally projects it back with o_proj.
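The rotary embedding step (rotary_emb plus apply_rotary_pos_emb in the code below) rotates query and key vectors by position-dependent angles so that their dot product depends only on relative position. A plain-Python sketch of the idea, using the interleaved-pair formulation of the original RoPE paper (the library uses an equivalent "rotate half" layout):

```python
import math

def rope(vec, pos, base=10000.0):
    # Rotate consecutive pairs (x0, x1), (x2, x3), ... of one head vector
    # by angles theta_k = pos * base**(-2k/d), as in the RoPE paper.
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

# Position 0 leaves a vector unchanged; later positions rotate it.
assert rope([1.0, 0.0, 0.5, 0.5], 0) == [1.0, 0.0, 0.5, 0.5]
```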
class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]  # without a KV cache, the key length equals q_len
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # Grouped-query attention: repeat each KV head so counts match the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights, past_key_value

LlamaMLP: A simple feed-forward network consisting of a gate projection, an up projection, and a down projection. The gate output is passed through the activation function (SiLU for LLaMA) and multiplied element-wise with the up-projected tensor before the down projection.
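Since config.hidden_act is SiLU for LLaMA, the block computes down(silu(gate(x)) * up(x)), the SwiGLU-style gated feed-forward. A scalar-level sketch of the gating, with the projections reduced to plain callables purely for illustration:

```python
import math

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def gated_mlp(x, gate, up, down):
    # down(silu(gate(x)) * up(x)): the LLaMA MLP with the linear
    # projections stood in by simple functions
    return down(silu(gate(x)) * up(x))

# With identity stand-ins, the gate path just rescales the up path:
y = gated_mlp(2.0, gate=lambda v: v, up=lambda v: v, down=lambda v: v)
# silu(2.0) ≈ 1.7616, times 2.0 ≈ 3.523
```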
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

LlamaRMSNorm: Implements Root Mean Square Layer Normalization, which normalizes inputs by their RMS value and scales them with a learnable weight.
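Numerically, RMSNorm divides each vector by the root mean square of its elements (plus a small epsilon) and multiplies by a learned per-channel weight; unlike LayerNorm, it does not subtract the mean. A plain-Python sketch of what the module below computes, with the weight fixed at 1:

```python
import math

def rms_norm(xs, eps=1e-6):
    # Divide by sqrt(mean of squares + eps); no mean subtraction,
    # in contrast to standard LayerNorm.
    rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x / rms for x in xs]

out = rms_norm([3.0, 4.0])
# mean of squares = (9 + 16) / 2 = 12.5, so each element is
# divided by sqrt(12.5) ≈ 3.5355
```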
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

This concludes the walkthrough: the core components of the LLaMA model (embedding, stacked decoder layers, multi-head attention, RMSNorm, and MLP) have now been covered, providing a solid foundation for further experimentation or extension.
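To recap the overall structure, each decoder layer is pre-norm residual wiring around its two sub-modules. With the sub-modules reduced to callables on plain lists, the pattern is just:

```python
def decoder_block(x, norm1, attn, norm2, mlp):
    # x = x + attn(norm1(x)); x = x + mlp(norm2(x))
    x = [a + b for a, b in zip(x, attn(norm1(x)))]
    x = [a + b for a, b in zip(x, mlp(norm2(x)))]
    return x

# With identity norms and zero sub-modules the input passes through
# unchanged: the residual identity path is what keeps deep stacks of
# such layers easy to train.
identity = lambda v: v
zero = lambda v: [0.0 for _ in v]
assert decoder_block([1.0, 2.0], identity, zero, identity, zero) == [1.0, 2.0]
```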
Sohu Tech Products
A knowledge-sharing platform for Sohu's technology products. As a leading Chinese internet brand with media, video, search, and gaming services and over 700 million users, Sohu continuously drives tech innovation and practice. We’ll share practical insights and tech news here.