
How to Visualize Neural Networks with Python’s NetworkX: Step‑by‑Step Guide

This tutorial explains how to use the Python library NetworkX to create and visualize simple feed‑forward neural network graphs, covering initialization, node and edge addition, layout customization, and rendering with Matplotlib, plus sample code and example images.

Model Perspective

NetworkX Overview

NetworkX is a Python library for graph theory and complex network modeling. It provides built‑in algorithms for analysis and simulation, plus basic drawing through Matplotlib. It supports undirected graphs, directed graphs, and multigraphs, and lets you attach arbitrary data to nodes and edges.

Key features include:

Strong flexibility: create various graph types (undirected, directed, multigraph) with custom nodes and edges.

Rich algorithm library: many built‑in graph‑theoretic and network algorithms for analysis and manipulation.

Visualization: draws graphs through Matplotlib, producing images that are easy to inspect and present.

Active community: open‑source project with active developer and user communities, offering abundant examples and documentation.
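To make the flexibility point concrete, here is a minimal sketch that creates each of the three graph types and attaches attributes to nodes and edges. The node names and the layer and weight attribute keys are arbitrary illustrative choices; NetworkX accepts any hashable object as a node and any key as an attribute name.

```python
import networkx as nx

# Undirected graph: edge (a, b) is the same as edge (b, a)
g = nx.Graph()
g.add_edge("a", "b")

# Directed graph: direction matters, which suits forward-flowing networks
dg = nx.DiGraph()
dg.add_node("x", layer=0)            # node attributes are free-form key/value pairs
dg.add_node("h", layer=1)
dg.add_edge("x", "h", weight=0.5)    # edge attributes work the same way

# Multigraph: several parallel edges between the same pair of nodes are allowed
mg = nx.MultiGraph()
mg.add_edge("p", "q")
mg.add_edge("p", "q")

print(g.has_edge("b", "a"))          # True: undirected
print(dg.has_edge("h", "x"))         # False: only x -> h exists
print(mg.number_of_edges("p", "q"))  # 2 parallel edges
print(dg.nodes["h"]["layer"], dg.edges["x", "h"]["weight"])  # 1 0.5
```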

Steps to Draw a Neural Network with NetworkX

Drawing a neural network with NetworkX requires considering the network’s structure and layout. The main steps are:

Initialize the graph: use nx.DiGraph() to create a directed graph, reflecting the forward flow from input to output layers.

Add nodes: create a node for each neuron in every layer and give each one a unique name (e.g., "Layer_1_Neuron_2") so that edges can refer to it unambiguously.

Determine node positions: store each node's (x, y) coordinates in a dictionary so the neurons appear in order, typically with layers placed along the horizontal axis and the neurons of each layer spaced evenly along the vertical axis.

Add edges: connect each neuron in a layer to every neuron in the subsequent layer.

Draw the graph: use nx.draw() with parameters to customize node color, size, shape, edge color, width, and labels.
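Since every neuron connects to every neuron in the next layer, layer sizes n_0, n_1, ..., n_(L-1) produce n_0·n_1 + n_1·n_2 + ... + n_(L-2)·n_(L-1) edges; for [3, 4, 2] that is 3·4 + 4·2 = 20. The minimal sketch below performs the first four steps without drawing anything and checks that count. The helper name build_feedforward_graph and the layer node attribute are hypothetical, chosen only for illustration.

```python
import networkx as nx

def build_feedforward_graph(layers):
    """Return (G, pos) for a fully connected feed-forward network.

    layers: neuron count per layer, e.g. [3, 4, 2].
    """
    G = nx.DiGraph()  # step 1: directed graph, edges point from input to output
    pos = {}
    for i, size in enumerate(layers):
        for j in range(size):
            name = f"Layer_{i}_Neuron_{j}"  # step 2: unique node name
            G.add_node(name, layer=i)
            # step 3: layer index on x, neurons centred vertically around y = 0
            pos[name] = (i, (size - 1) / 2 - j)
    # step 4: connect every neuron to every neuron in the next layer
    for i in range(len(layers) - 1):
        for j in range(layers[i]):
            for k in range(layers[i + 1]):
                G.add_edge(f"Layer_{i}_Neuron_{j}", f"Layer_{i + 1}_Neuron_{k}")
    return G, pos

layers = [3, 4, 2]
G, pos = build_feedforward_graph(layers)
expected_edges = sum(a * b for a, b in zip(layers, layers[1:]))
print(G.number_of_nodes(), G.number_of_edges(), expected_edges)  # 9 20 20
```

Keeping graph construction separate from drawing like this makes the structure easy to check before any plotting happens; step 5 then only needs nx.draw(G, pos, ...).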

Tips:

Adjust layout: NetworkX offers general-purpose layouts such as spring_layout and circular_layout, but custom positions usually work better for neural networks because they keep the layers in order.

Beautify the graph: adjust node_color, node_size, and edge_color, and set with_labels=True for clear labeling.

Style edges: use edge_color and width to reflect weights or other attributes, optionally adding edge labels or varied line styles.

Add titles and annotations: use plt.title() and other Matplotlib functions for richer labeling.

Scalability: for larger networks, organize the code into functions or classes for readability and reuse. For very large or complex models, consider framework tools such as TensorBoard, which can display model graphs from both TensorFlow and PyTorch.
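Several of these tips can be combined in one function. The minimal sketch below is an illustration under stated assumptions, not the article's own method: the edge weights are random placeholders (fixed seed), the helper name draw_weighted_net is hypothetical, and positions come from NetworkX's built-in multipartite_layout, which places nodes that share a subset attribute in the same column, instead of hand-computed coordinates. Edge width grows with the absolute weight and edge color shows its sign. It selects Matplotlib's non-interactive Agg backend so it runs on machines without a display.

```python
import random

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove it to open a window with plt.show()
import matplotlib.pyplot as plt
import networkx as nx

def draw_weighted_net(layers, seed=0):
    """Draw a feed-forward net with a built-in layered layout and weight-styled edges.

    The weights are random placeholders, not values from a trained model.
    """
    rng = random.Random(seed)
    G = nx.DiGraph()
    for i, size in enumerate(layers):
        for j in range(size):
            # multipartite_layout groups nodes into columns by this 'subset' attribute
            G.add_node(f"L{i}N{j}", subset=i)
    for i in range(len(layers) - 1):
        for j in range(layers[i]):
            for k in range(layers[i + 1]):
                G.add_edge(f"L{i}N{j}", f"L{i + 1}N{k}", weight=rng.uniform(-1, 1))

    pos = nx.multipartite_layout(G, subset_key="subset")
    weights = [w for _, _, w in G.edges(data="weight")]  # same edge order nx.draw uses
    widths = [0.5 + 3 * abs(w) for w in weights]          # thicker line = larger |weight|
    colors = ["tab:red" if w < 0 else "tab:blue" for w in weights]  # sign as color

    fig, ax = plt.subplots(figsize=(6, 4))
    nx.draw(G, pos, ax=ax, with_labels=True, node_size=800, node_color="skyblue",
            font_size=8, width=widths, edge_color=colors, arrows=False)
    ax.set_title(f"{len(layers)}-layer network, edge width ~ |weight|")  # like plt.title()
    return G, pos, fig
```

For example, G, pos, fig = draw_weighted_net([3, 4, 2]) followed by fig.savefig("weighted_net.png") writes the image to disk; to see it in a window instead, drop the matplotlib.use("Agg") line and call plt.show().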

Putting the steps together, the following function draws a fully connected feed‑forward network:

<code>import matplotlib.pyplot as plt
import networkx as nx

def plot_neural_net(layers):
    """Plots a simple feed‑forward neural network graph using NetworkX.
    Args:
        layers (list of ints): each item is the number of neurons in that layer,
        e.g., [2, 3, 1] means 2 input neurons, a hidden layer with 3 neurons, and 1 output neuron.
    """
    G = nx.DiGraph()  # directed: edges point from the input layer towards the output layer
    pos = {}
    # Add nodes and their positions: layer index on the x-axis,
    # neurons spaced one unit apart and centred vertically around y = 0
    for i, layer_size in enumerate(layers):
        for j in range(layer_size):
            node_name = f"Layer_{i}_Neuron_{j}"
            G.add_node(node_name)
            pos[node_name] = (i, (layer_size - 1) / 2 - j)
    # Fully connect each layer to the next one
    for i in range(len(layers) - 1):
        for j in range(layers[i]):
            for k in range(layers[i + 1]):
                G.add_edge(f"Layer_{i}_Neuron_{j}", f"Layer_{i+1}_Neuron_{k}")
    # Draw the graph
    nx.draw(G, pos, with_labels=True, node_size=2000, node_color="skyblue", font_size=10,
            font_weight='bold', width=2, edge_color="gray")
    plt.title(f"{len(layers)}-layer Neural Network")  # title follows the actual layer count
    plt.show()

# Define the number of neurons in each layer for a 3‑layer network
layers = [3, 4, 2]
plot_neural_net(layers)
</code>

Alternative visualizations can be produced by tweaking the drawing parameters, such as node colors, sizes, labels, and edge styles.
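As one such variation, the sketch below colors nodes by role (input, hidden, output) and replaces the long node names with short labels through the labels argument of nx.draw, while neurons stay centred vertically in each layer. The x/h/y label scheme, the color choices, and the helper name draw_layer_colored are illustrative assumptions. Like the previous variations, it uses the non-interactive Agg backend so it also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove it to open a window with plt.show()
import matplotlib.pyplot as plt
import networkx as nx

# Illustrative color per neuron role: x = input, h = hidden, y = output
ROLE_COLORS = {"x": "lightgreen", "h": "skyblue", "y": "salmon"}

def draw_layer_colored(layers):
    """Draw a feed-forward net with role-based node colors and short labels."""
    G = nx.DiGraph()
    pos, labels, node_colors = {}, {}, []
    last = len(layers) - 1
    for i, size in enumerate(layers):
        role = "x" if i == 0 else ("y" if i == last else "h")
        for j in range(size):
            name = f"Layer_{i}_Neuron_{j}"
            G.add_node(name)
            pos[name] = (i, (size - 1) / 2 - j)   # centred vertically per layer
            labels[name] = f"{role}{j + 1}"        # e.g. x1, h3, y2
            node_colors.append(ROLE_COLORS[role])  # same order as G.nodes()
    for i in range(last):
        for j in range(layers[i]):
            for k in range(layers[i + 1]):
                G.add_edge(f"Layer_{i}_Neuron_{j}", f"Layer_{i + 1}_Neuron_{k}")

    fig, ax = plt.subplots(figsize=(6, 4))
    nx.draw(G, pos, ax=ax, with_labels=True, labels=labels, node_color=node_colors,
            node_size=900, edge_color="gray", width=1, arrowsize=12)
    ax.set_title("Role-colored nodes with short labels")
    return G, labels, node_colors, fig
```

Because nodes keep their long unique names internally, only the displayed text changes; edges and positions are unaffected by the labels mapping.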

Written by Model Perspective

Insights, knowledge, and enjoyment from a mathematical modeling researcher and educator. Hosted by Haihua Wang, a modeling instructor and author of "Clever Use of Chat for Mathematical Modeling", "Modeling: The Mathematics of Thinking", "Mathematical Modeling Practice: A Hands‑On Guide to Competitions", and co‑author of "Mathematical Modeling: Teaching Design and Cases".
