Training MNIST with Burn on wgpu: From PyTorch to Rust Backend

This tutorial demonstrates how to train an MNIST digit‑recognition model using the Rust‑based Burn framework on top of the cross‑platform wgpu API, covering model export from PyTorch to ONNX, code generation, data loading, training loops, and performance comparison across CPU, GPU, and other backends.

Rare Earth Juejin Tech Community

When the author first tried PyTorch on a modest laptop, the GPU requirements for training CNNs became an obstacle, which prompted a search for a cross‑platform, driver‑agnostic framework.

1. wgpu

wgpu is a pure‑Rust, cross‑platform graphics API that implements the WebGPU standard and runs on Vulkan, Metal, D3D12, OpenGL, and WebGL2/WebGPU in browsers. Its broad driver support makes it a solid foundation for GPU‑accelerated workloads, as used by Firefox and Deno.

2. burn

Burn is a Rust deep‑learning framework emphasizing flexibility, computational efficiency, and portability. Here it acts as the intermediate layer between the model code and the wgpu backend, offering strong device compatibility and easy model import.

3. Code Walkthrough

The example starts with a standard PyTorch MNIST model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 24, 3)
        self.norm1 = nn.BatchNorm2d(24)
        self.dropout1 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(24 * 22 * 22, 32)
        self.fc2 = nn.Linear(32, 10)
        self.norm2 = nn.BatchNorm1d(10)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.norm1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = self.norm2(self.fc2(x))
        return F.log_softmax(x, dim=1)
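
The 24 * 22 * 22 input size of fc1 follows from the three unpadded 3×3 convolutions applied to a 28×28 MNIST image. The sketch below reproduces that arithmetic in plain Python; the helper name conv2d_out is hypothetical and exists only for illustration (it assumes stride 1, no padding, and dilation 1, which are the nn.Conv2d defaults used above).

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28  # MNIST images are 28x28 pixels
for kernel in (3, 3, 3):  # conv1, conv2, conv3 each use a 3x3 kernel
    side = conv2d_out(side, kernel)  # 28 -> 26 -> 24 -> 22

channels = 24  # output channels of conv3
flat_features = channels * side * side  # size of the flattened tensor fed to fc1
print(side, flat_features)  # prints "22 11616"
```

If the kernel sizes or padding change, the first argument of fc1 has to be recomputed the same way.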

The parameters are saved with torch.save(model.state_dict(), "mnist.pt"), and Burn then re‑creates an equivalent model in Rust:

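use burn::{
    module::Module,
    nn::{conv::Conv2d, BatchNorm, Linear},
    tensor::backend::Backend,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    conv3: Conv2d<B>,
    norm1: BatchNorm<B, 2>,
    fc1: Linear<B>,
    fc2: Linear<B>,
    norm2: BatchNorm<B, 0>,
    phantom: core::marker::PhantomData<B>,
}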

The model is exported to ONNX with torch.onnx.export(..., "mnist.onnx"), then converted to Rust source via burn-import:

use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("./model/mnist.onnx")
        .out_dir("./model/")
        .run_from_script();
}

After generation, the model can be embedded directly into the binary or loaded from the generated mnist.rs, mnist.bin, and mnist.mpk files.

4. Training Pipeline

Data loading mirrors the PyTorch workflow using Burn’s MNISTDataset and a custom ClassificationBatcher. The training loop employs CrossEntropyLoss and the AdaGrad optimizer, with TrainStep and ValidStep traits implemented for the model.

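pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
    // Dataloader and learner construction elided; `learner`, `dataloader_train`,
    // and `dataloader_test` are built from `config` and `device`.
    let model_trained = learner.fit(dataloader_train, dataloader_test);
    // `format!` is needed so that ARTIFACT_DIR is interpolated into the path.
    model_trained
        .save_file(format!("{ARTIFACT_DIR}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}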
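
For readers unfamiliar with the AdaGrad optimizer used in the training loop, the following is a minimal sketch of a single update step in plain Python. It is not Burn's implementation: the function name adagrad_step, the default learning rate, and the epsilon value are assumptions chosen for illustration. The idea is that each parameter accumulates its own squared gradients, so parameters that have seen large gradients receive smaller steps.

```python
import math


def adagrad_step(params, grads, accum, lr=0.01, eps=1e-10):
    """Apply one AdaGrad update to plain lists of floats.

    Returns the updated parameters and the updated per-parameter
    accumulators of squared gradients.
    """
    new_params, new_accum = [], []
    for p, g, a in zip(params, grads, accum):
        a = a + g * g  # running sum of squared gradients for this parameter
        new_params.append(p - lr * g / (math.sqrt(a) + eps))  # scaled step
        new_accum.append(a)
    return new_params, new_accum


# Illustrative values only: two parameters, one gradient evaluation.
params, accum = [1.0, -2.0], [0.0, 0.0]
params, accum = adagrad_step(params, [0.5, -0.5], accum)
```

On the first step the accumulator equals the squared gradient, so each parameter moves by approximately lr in the direction opposite to its gradient; later steps shrink as the accumulator grows.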

Configuration examples show how to run on CPU, wgpu (GPU), or other backends such as LibTorch with CUDA or Apple Silicon.

5. Results and Observations

Running on the laptop’s integrated GPU dramatically reduces training time (≈37 minutes) and improves accuracy compared to CPU, while monitoring tools confirm GPU utilization and modest CPU load.

Conclusion

Burn, combined with wgpu, provides a viable path to train deep‑learning models in pure Rust across diverse hardware, though some advanced loss functions and optimizers remain unsupported.
