Visualizing PyTorch Neural Network Architecture and Training Process with HiddenLayer, torchviz, TensorBoardX, and Visdom
This tutorial shows how to visualize both the architecture and the training process of a simple convolutional neural network built with PyTorch, using HiddenLayer, torchviz, tensorboardX, and Visdom, with step-by-step code examples for each tool.
1. Network structure visualization
We first define a basic ConvNet class consisting of two convolutional blocks, a fully‑connected block, and an output layer, then print the model to show its architecture.
<code>import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # Block 1: 1x28x28 -> 16x28x28 -> 16x14x14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2)
        )
        # Block 2: 16x14x14 -> 32x14x14 -> 32x7x7
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # Fully-connected block on the flattened feature map
        self.fc = nn.Sequential(
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        # Output layer: 10 class scores
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32 * 7 * 7)
        x = self.fc(x)
        output = self.out(x)
        return output</code>
Printing the model yields a detailed hierarchical description of each layer.
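The `32 * 7 * 7` input size of the first linear layer follows from the layer settings and the 1×28×28 input used later in this article. A minimal sketch of that arithmetic, using the standard output-size formula for convolution and pooling layers (the helper name `out_size` is hypothetical, introduced only for illustration):

```python
def out_size(size, kernel, stride, padding=0):
    """Spatial output size of a conv/pool layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28                                              # assumed input: 1 x 28 x 28
side = out_size(side, kernel=3, stride=1, padding=1)   # conv1 keeps the size
side = out_size(side, kernel=2, stride=2)              # AvgPool2d halves it
side = out_size(side, kernel=3, stride=1, padding=1)   # conv2 keeps the size
side = out_size(side, kernel=2, stride=2)              # MaxPool2d halves it again
flat_features = 32 * side * side                       # 32 channels after conv2

print(side, flat_features)
```

With the shapes confirmed, the model can be instantiated and printed: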
<code>MyConvNet = ConvNet()
print(MyConvNet)</code>
1.1 Visualizing with HiddenLayer
Install the library and generate a graph object for the model.
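<code>pip install hiddenlayer</code>
<code>import hiddenlayer as h
# Trace the model with a dummy 1x1x28x28 input to build the graph
vis_graph = h.build_graph(MyConvNet, torch.zeros([1, 1, 28, 28]))
vis_graph.theme = h.graph.THEMES["blue"].copy()
# save() writes a PDF by default, so request PNG explicitly
vis_graph.save("./demo1.png", format="png")</code>
The resulting PNG shows the network topology.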
1.2 Visualizing with torchviz
Install torchviz (it also needs the Graphviz binaries installed on the system) and use make_dot to create a Graphviz representation of the model's autograd graph.
<code>pip install torchviz</code>
<code>from torchviz import make_dot
x = torch.randn(1, 1, 28, 28).requires_grad_(True)
y = MyConvNet(x)
MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
MyConvNetVis.format = "png"
MyConvNetVis.directory = "data"
MyConvNetVis.view()</code>
This generates a .gv script and a rendered .png image.
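make_dot builds its picture from the autograd graph attached to the output tensor. As a rough illustration of what gets traversed, the sketch below walks `grad_fn.next_functions` by hand on a small stand-in model and collects the node type names. `TinyNet` and `walk` are hypothetical names used only for this example and are not part of torchviz.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in model, kept small so the graph is easy to read."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

def walk(fn, depth=0, seen=None, out=None):
    """Depth-first walk over autograd nodes, recording (depth, node type name)."""
    seen = set() if seen is None else seen
    out = [] if out is None else out
    if fn is None or fn in seen:   # None marks inputs that do not require grad
        return out
    seen.add(fn)
    out.append((depth, type(fn).__name__))
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, seen, out)
    return out

y = TinyNet()(torch.randn(1, 4))
nodes = walk(y.grad_fn)
for depth, name in nodes:
    print("  " * depth + name)
```

The exact node names depend on the PyTorch version, but the leaves are the `AccumulateGrad` nodes that hold the model parameters, which is where torchviz attaches the parameter names passed through `params`.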
2. Training process visualization
Monitoring loss and accuracy during training helps assess model performance. The tutorial shows how to log these metrics with tensorboardX and HiddenLayer.
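The training loops in this section use `train_loader`, `test_data_x`, and `test_data_y`, which the article does not define. A minimal stand-in, assuming MNIST-shaped inputs (1×28×28 images, 10 classes) and random tensors in place of a real dataset; the sample counts and batch size are arbitrary assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Random "images" and labels standing in for a real training set (assumption)
train_x = torch.rand(512, 1, 28, 28)
train_y = torch.randint(0, 10, (512,))
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=64, shuffle=True)

# Held-out tensors used for the accuracy check inside the loop
test_data_x = torch.rand(128, 1, 28, 28)
test_data_y = torch.randint(0, 10, (128,))

x, y = next(iter(train_loader))
print(x.shape, y.shape, len(train_loader))
```

With random labels the accuracy will hover around chance, so this only exercises the logging code. Note also that 8 batches per epoch over 5 epochs gives 40 iterations, so `log_step_interval` would need to be lowered from 100 for anything to be logged with this stand-in.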
2.1 Using tensorboardX
Install the required packages and make sure the tensorboard command is available on the system PATH.
<code>pip install tensorboardX
pip install tensorboard</code>
Typical training loop with logging:
<code>import torchvision.utils as vutils
from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter

# train_loader, test_data_x and test_data_y are assumed to be defined beforehand
logger = SummaryWriter(log_dir="data/log")
optimizer = torch.optim.Adam(MyConvNet.parameters(), lr=3e-4)
loss_func = nn.CrossEntropyLoss()
log_step_interval = 100  # log every 100 iterations

for epoch in range(5):
    for step, (x, y) in enumerate(train_loader):
        predict = MyConvNet(x)
        loss = loss_func(predict, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 1-based iteration counter that keeps growing across epochs
        global_iter_num = epoch * len(train_loader) + step + 1
        if global_iter_num % log_step_interval == 0:
            print(f"global_step:{global_iter_num}, loss:{loss.item():.2}")
            logger.add_scalar("train loss", loss.item(), global_step=global_iter_num)
            # Accuracy on the held-out test tensors
            test_predict = MyConvNet(test_data_x)
            _, predict_idx = torch.max(test_predict, 1)
            acc = accuracy_score(test_data_y, predict_idx)
            logger.add_scalar("test accuracy", float(acc), global_step=global_iter_num)
            # Grid of the current training batch
            img = vutils.make_grid(x, nrow=12)
            logger.add_image("train image sample", img, global_step=global_iter_num)
            # Histogram of every parameter tensor
            for name, param in MyConvNet.named_parameters():
                logger.add_histogram(name, param.data.numpy(), global_step=global_iter_num)</code>
Run tensorboard --logdir="./data/log" to launch the visual interface.
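The logging condition depends only on the running iteration count, `epoch * len(train_loader) + step + 1`. A small sketch of that bookkeeping without any training, assuming a hypothetical loader length of 250 batches per epoch (`logged_steps` is an illustrative helper, not part of tensorboardX):

```python
def logged_steps(num_epochs, steps_per_epoch, interval):
    """Return the global iteration numbers at which the loop would log."""
    hits = []
    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            # Same formula as in the training loop: 1-based, continues across epochs
            global_iter_num = epoch * steps_per_epoch + step + 1
            if global_iter_num % interval == 0:
                hits.append(global_iter_num)
    return hits

hits = logged_steps(num_epochs=2, steps_per_epoch=250, interval=100)
print(hits)
```

Because the counter does not reset at epoch boundaries, the x-axis in TensorBoard stays monotonic and log points are evenly spaced even when the interval does not divide the epoch length.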
2.2 Visualizing training with HiddenLayer
HiddenLayer can dynamically plot loss, accuracy, and weight matrices during training.
<code>import hiddenlayer as hl
from sklearn.metrics import accuracy_score

# train_loader, test_data_x and test_data_y are assumed to be defined beforehand
history = hl.History()  # stores the logged metrics
canvas = hl.Canvas()    # redraws the plots on each update
optimizer = torch.optim.Adam(MyConvNet.parameters(), lr=3e-4)
loss_func = nn.CrossEntropyLoss()
log_step_interval = 100

for epoch in range(5):
    for step, (x, y) in enumerate(train_loader):
        predict = MyConvNet(x)
        loss = loss_func(predict, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_iter_num = epoch * len(train_loader) + step + 1
        if global_iter_num % log_step_interval == 0:
            test_predict = MyConvNet(test_data_x)
            _, predict_idx = torch.max(test_predict, 1)
            acc = accuracy_score(test_data_y, predict_idx)
            # Record loss, accuracy, and the weights of the second linear layer
            history.log((epoch, step), train_loss=loss, test_acc=acc, hidden_weight=MyConvNet.fc[2].weight)
            with canvas:
                canvas.draw_plot(history["train_loss"])
                canvas.draw_plot(history["test_acc"])
                canvas.draw_image(history["hidden_weight"])</code>
3. Using Visdom for visualization
Visdom, a Facebook tool for PyTorch, provides a flexible web‑based interface similar to Matplotlib.
<code>pip install visdom</code>
Basic usage example:
<code>from visdom import Visdom
from sklearn.datasets import load_iris
import torch, numpy as np
from PIL import Image
vis = Visdom()
# line plot
x = torch.linspace(-6, 6, 100).view([-1, 1])
sigmoid_y = torch.nn.Sigmoid()(x)
tanh_y = torch.nn.Tanh()(x)
relu_y = torch.nn.ReLU()(x)
plot_x = torch.cat([x, x, x], dim=1)
plot_y = torch.cat([sigmoid_y, tanh_y, relu_y], dim=1)
vis.line(X=plot_x, Y=plot_y, win="line plot", env="main", opts={"legend": ["Sigmoid", "Tanh", "ReLU"]})
# scatter plot
iris_x, iris_y = load_iris(return_X_y=True)
vis.scatter(iris_x[:, :2], Y=iris_y+1, win="scatter2d", env="main")
# stem plot
x = torch.linspace(-6, 6, 100).view([-1, 1])
y1 = torch.sin(x)
y2 = torch.cos(x)
plot_x = torch.cat([x, x], dim=1)
plot_y = torch.cat([y1, y2], dim=1)
vis.stem(X=plot_x, Y=plot_y, win="stem plot", env="main", opts={"legend": ["sin", "cos"], "title": "Stem Plot"})
# heatmap
iris_corr = torch.from_numpy(np.corrcoef(iris_x, rowvar=False))
vis.heatmap(iris_corr, win="heatmap", env="main", opts={"title": "Correlation Heatmap"})
# image
img = Image.open("./example.jpg").convert("L")
img_tensor = torch.from_numpy(np.array(img, dtype=np.float32))
vis.image(img_tensor, win="one image", env="MyPlotEnv", opts={"title": "Sample Image"})
# text
vis.text("hello world", win="text plot", env="MyPlotEnv", opts={"title": "Text Visualization"})</code>
Start the server with python -m visdom.server before running the script, then open the provided URL in a browser to explore the visualizations.
Additional notes cover saving and reloading Visdom environments, handling environment names, and retrieving window data via the Visdom API.
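As a hedged sketch of the saving and retrieval steps mentioned above: the Visdom client exposes `save`, which takes a list of environment names, and `get_window_data`, which to my knowledge returns the window contents as a JSON string. The helper name `snapshot_window` is hypothetical, and the function takes the client as an argument so it can be tried without a running server.

```python
import json

def snapshot_window(vis, win, env):
    """Save `env` on the server and return the stored data of window `win` as a dict.

    `vis` is expected to be a visdom.Visdom client (or any object exposing the
    same three methods).
    """
    if not vis.check_connection():
        raise RuntimeError("Visdom server is not reachable")
    vis.save([env])                               # persist the environment server-side
    raw = vis.get_window_data(win=win, env=env)   # JSON string describing the window
    return json.loads(raw)
```

With a running server, this could be called as `snapshot_window(Visdom(), win="line plot", env="main")`, using the window and environment names from the example above.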
Python Programming Learning Circle