
Implementing RNN, LSTM, and GRU with PyTorch

This article introduces the basic architectures of recurrent neural networks (RNN), LSTM, and GRU, explains the PyTorch APIs nn.RNN, nn.LSTM, and nn.GRU and their parameters, and walks through code examples for building and testing these models, with practical notes for deep learning practitioners.

DataFunTalk

PyTorch provides two families of recurrent neural network APIs for the basic RNN architecture and its variants (LSTM, GRU): the cell versions (nn.RNNCell, nn.LSTMCell, nn.GRUCell) accept a single time‑step input, while the module versions (nn.RNN, nn.LSTM, nn.GRU) accept an entire sequence.

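The module is constructed as torch.nn.RNN(*args, **kwargs), usually with keyword arguments. At each time step, every layer updates its hidden state as $h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$, where $x_t$ is the input at time t (for stacked layers, the output of the layer below), $h_{t-1}$ is the previous hidden state (the initial state $h_0$ at the first step), and tanh is replaced by ReLU when nonlinearity='relu'.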
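As a quick numerical check of the update rule, the sketch below assumes a single-layer RNN with the default tanh nonlinearity and re-applies h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh) by hand, one step at a time, using the module's own parameters (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0). The small sizes are arbitrary choices for illustration. With batched row-vector inputs, W_ih x_t is written x_t @ W_ih.T.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Single-layer, unidirectional RNN with the default tanh nonlinearity
rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=1)
x = torch.randn(5, 2, 4)   # (seq_len=5, batch=2, input_size=4)
h0 = torch.zeros(1, 2, 3)  # (num_layers * num_directions, batch, hidden_size)

out, h_n = rnn(x, h0)

# Re-apply the update equation by hand using the module's parameters
W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
h = h0[0]                  # (batch, hidden_size)
steps = []
with torch.no_grad():
    for t in range(x.size(0)):
        h = torch.tanh(x[t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        steps.append(h)
manual = torch.stack(steps)  # (seq_len, batch, hidden_size)

print(torch.allclose(out, manual, atol=1e-5))  # True
```

The last manually computed step also matches h_n[0], the final hidden state of the single layer.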

The nn.RNN constructor accepts several important arguments:

input_size: number of features in the input x.

hidden_size: number of features in the hidden state.

num_layers: number of stacked RNN layers.

nonlinearity: activation function, either 'tanh' or 'relu' (default 'tanh').

bias: whether to include bias terms (default True).

batch_first: if True, the input and output tensors have shape (batch, seq, feature); otherwise (seq, batch, feature). The hidden-state tensors are not affected.

dropout: if non-zero, applies dropout with this probability (0 to 1, default 0) to the outputs of every layer except the last, so it only has an effect when num_layers > 1.

bidirectional: if True, creates a bidirectional RNN (default False).
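To see how batch_first and bidirectional change the tensor layouts, here is a minimal sketch with input size 10, hidden size 20, and two layers; the combination of flags is chosen only for illustration.

```python
import torch
import torch.nn as nn

# Same layer sizes, but batch-first layout and two directions
birnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2,
               batch_first=True, bidirectional=True)

x = torch.randn(32, 100, 10)  # (batch, seq_len, input_size) because batch_first=True
out, h_n = birnn(x)           # h_0 omitted, so it defaults to zeros

print(out.shape)  # torch.Size([32, 100, 40]): forward and backward outputs concatenated
print(h_n.shape)  # torch.Size([4, 32, 20]): batch_first does not apply to the hidden state
```

batch_first only changes the layout of the input and output; h_0 and h_n keep the (num_layers * num_directions, batch, hidden_size) layout. A bidirectional network doubles both the output feature size and the first dimension of h_n.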

To illustrate a concrete RNN, we build a two‑layer, unidirectional network with input dimension 10 and hidden dimension 20. The expected tensor shapes are:

Input input: (seq_len, batch, input_size).

Initial hidden state h_0: (num_layers * num_directions, batch, hidden_size).

Output output: (seq_len, batch, num_directions * hidden_size), which is (seq_len, batch, hidden_size) for this unidirectional network.

Final hidden state h_n: same shape as h_0.

Generating synthetic data for this RNN is done with the following code:

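import torch
# Generate input data
input = torch.randn(100, 32, 10)  # (seq_len, batch, input_size)
h_0 = torch.randn(2, 32, 20)      # (num_layers * num_directions, batch, hidden_size)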

Running the RNN with the generated tensors yields an output of shape (100, 32, 20) and a hidden state of shape (2, 32, 20), matching the theoretical expectations.
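A self-contained version of this run, which constructs the two-layer RNN and checks the shapes, might look like the following sketch (the variable names follow the data-generation snippet).

```python
import torch
import torch.nn as nn

# Two-layer, unidirectional RNN: input_size=10, hidden_size=20
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)

input = torch.randn(100, 32, 10)  # (seq_len, batch, input_size)
h_0 = torch.randn(2, 32, 20)      # (num_layers * num_directions, batch, hidden_size)

output, h_n = rnn(input, h_0)
print(output.shape)  # torch.Size([100, 32, 20])
print(h_n.shape)     # torch.Size([2, 32, 20])

# The last time step of `output` is the top layer's final hidden state
print(torch.allclose(output[-1], h_n[-1]))  # True
```

output holds the top layer's hidden state at every time step, while h_n holds every layer's hidden state at the last time step; they overlap only in output[-1] and h_n[-1].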

For the cell‑based API, nn.RNNCell expects inputs of shape (batch, input_size) and returns a hidden state of shape (batch, hidden_size), without a sequence dimension.
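The sketch below unrolls nn.RNNCell over a sequence by hand, then copies the cell's weights into a single-layer nn.RNN to show that the module computes the same thing as the explicit loop. The sizes are reused from the earlier example; the weight-copying step is only a demonstration device, not something needed in practice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.RNNCell(input_size=10, hidden_size=20)
x = torch.randn(100, 32, 10)   # a whole sequence: (seq_len, batch, input_size)
h0 = torch.zeros(32, 20)       # (batch, hidden_size): no layer or sequence dimension

# Unroll the cell manually over the time dimension
h = h0
outputs = []
for x_t in x:                  # x_t: (batch, input_size)
    h = cell(x_t, h)           # h:   (batch, hidden_size)
    outputs.append(h)
output = torch.stack(outputs)  # (seq_len, batch, hidden_size)

# A single-layer nn.RNN with the same weights should give the same result
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1)
with torch.no_grad():
    rnn.weight_ih_l0.copy_(cell.weight_ih)
    rnn.weight_hh_l0.copy_(cell.weight_hh)
    rnn.bias_ih_l0.copy_(cell.bias_ih)
    rnn.bias_hh_l0.copy_(cell.bias_hh)
ref, _ = rnn(x, h0.unsqueeze(0))  # nn.RNN expects (num_layers, batch, hidden_size)

print(output.shape)                            # torch.Size([100, 32, 20])
print(torch.allclose(output, ref, atol=1e-5))  # True
```

The cell API is useful when each step needs custom logic (for example, feeding a prediction back in as the next input), at the cost of an explicit Python loop.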

---

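**LSTM** adds a cell state c and three gates (input, forget, and output), which let it retain information over long sequences. Each gate, plus the candidate cell update, has its own input-to-hidden and hidden-to-hidden weights, so an LSTM layer has four times as many parameters as a plain RNN layer of the same size. The hidden state consists of both h and c, each shaped (num_layers * num_directions, batch, hidden_size).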
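For reference, these are the LSTM update equations as given in the PyTorch nn.LSTM documentation, where $\sigma$ is the sigmoid function and $\odot$ the elementwise product. The four weight blocks (i, f, g, o) are what make an LSTM layer four times the size of an RNN layer.

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```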

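The PyTorch LSTM module is instantiated with the same core arguments as nn.RNN (input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional). The main difference in usage is that the state is passed to and returned from the forward call as a tuple (h, c).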
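A minimal sketch with the same sizes as the RNN example (two layers, input size 10, hidden size 20); the only change in usage is that the state travels as an (h, c) tuple.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

input = torch.randn(100, 32, 10)  # (seq_len, batch, input_size)
h_0 = torch.randn(2, 32, 20)      # (num_layers * num_directions, batch, hidden_size)
c_0 = torch.randn(2, 32, 20)      # the cell state has the same shape as h_0

# The state is passed in and returned as an (h, c) tuple
output, (h_n, c_n) = lstm(input, (h_0, c_0))
print(output.shape)  # torch.Size([100, 32, 20])
print(h_n.shape)     # torch.Size([2, 32, 20])
print(c_n.shape)     # torch.Size([2, 32, 20])
```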

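Inspecting the weight parameters (e.g., weight_ih_l0, weight_hh_l0) confirms the parameter increase: the LSTM's weight_ih_l0 has shape (4 * hidden_size, input_size), four times as many rows as the RNN's (hidden_size, input_size), because the weights for the four blocks are stacked along the first dimension.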
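The following sketch compares parameter shapes and total parameter counts for an RNN and an LSTM with identical sizes (two layers, input size 10, hidden size 20). The helper count is a hypothetical convenience function defined here, not part of PyTorch.

```python
import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# Print each parameter's shape for the RNN and the LSTM side by side
for name in ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "weight_ih_l1"]:
    print(name, tuple(getattr(rnn, name).shape), tuple(getattr(lstm, name).shape))
# weight_ih_l0 (20, 10) (80, 10)
# weight_hh_l0 (20, 20) (80, 20)
# bias_ih_l0 (20,) (80,)
# weight_ih_l1 (20, 20) (80, 20)

def count(module):
    """Hypothetical helper: total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())

print(count(rnn), count(lstm))  # 1480 5920
```

Because every weight matrix and bias vector grows by the same factor, the total parameter count is exactly four times larger.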

If no initial hidden state is supplied, PyTorch automatically initializes h_0 and c_0 to zeros.
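A quick way to confirm the zero default is to run the same LSTM twice, once with no state and once with explicit zero tensors for h_0 and c_0; the sketch assumes the two-layer, hidden-size-20 configuration used earlier.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
input = torch.randn(100, 32, 10)

# No initial state passed: PyTorch fills h_0 and c_0 with zeros
out_default, (h_a, c_a) = lstm(input)

# Explicit zero state of shape (num_layers * num_directions, batch, hidden_size)
zeros = torch.zeros(2, 32, 20)
out_zeros, (h_b, c_b) = lstm(input, (zeros, zeros))

print(torch.allclose(out_default, out_zeros))  # True
```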

---

**GRU** has a structure similar to LSTM but uses only two gates (reset and update) and a single hidden state h, with no separate cell state. Its weights consist of three stacked blocks, so a GRU layer has three times as many parameters as a standard RNN layer.
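For reference, the GRU update as written in the PyTorch nn.GRU documentation uses a reset gate $r_t$, an update gate $z_t$, and a candidate state $n_t$; these three weight blocks account for the threefold parameter count.

```latex
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```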

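A GRU is built and called the same way as the RNN module: nn.GRU takes the same constructor arguments and, unlike nn.LSTM, takes and returns a single hidden-state tensor rather than an (h, c) tuple. nn.GRUCell is the corresponding single-time-step version.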
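A minimal sketch of both GRU APIs, reusing the sizes from the RNN example, that also checks the threefold parameter count against an RNN of the same size.

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2)
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)

input = torch.randn(100, 32, 10)  # (seq_len, batch, input_size)
h_0 = torch.randn(2, 32, 20)      # a single hidden-state tensor, no cell state

output, h_n = gru(input, h_0)
print(output.shape, h_n.shape)  # torch.Size([100, 32, 20]) torch.Size([2, 32, 20])

# Reset, update and candidate blocks are stacked: 3 * hidden_size rows
print(tuple(gru.weight_ih_l0.shape))  # (60, 10)
n_gru = sum(p.numel() for p in gru.parameters())
n_rnn = sum(p.numel() for p in rnn.parameters())
print(n_gru, n_rnn)  # 4440 1480

# The cell version processes one time step: (batch, input_size) -> (batch, hidden_size)
cell = nn.GRUCell(input_size=10, hidden_size=20)
h_t = cell(input[0], torch.zeros(32, 20))
print(h_t.shape)  # torch.Size([32, 20])
```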

---

All the code examples and visualizations are extracted from the textbook Python Deep Learning: Based on PyTorch (2nd Edition) (ISBN 978‑7‑111‑71880‑2), which provides a comprehensive, hands‑on guide to deep‑learning techniques using PyTorch.
