Wasserstein GAN (WGAN): Theory and Hands‑On Implementation
This article explains why traditional GANs suffer from training instability, introduces the Wasserstein (Earth‑Mover) distance as a smoother alternative, derives the WGAN objective, discusses Lipschitz constraints, and provides practical PyTorch code modifications to convert a vanilla GAN into a stable WGAN.
Introduction
The author, "Bald‑Headed Xiao‑Su", continues a series on Generative Adversarial Networks (GANs) and presents a detailed walkthrough of Wasserstein GAN (WGAN), which aims to solve the notorious training difficulties of classic GANs.
Why GANs Are Hard to Train
Standard GANs minimize a loss that, at the optimal discriminator, is equivalent to the Jensen‑Shannon (JS) divergence between the real data distribution p_r and the generator distribution p_g. When the two distributions have disjoint (or negligibly overlapping) supports, which is typical when both lie on low‑dimensional manifolds in a high‑dimensional space, the JS divergence is stuck at the constant log 2, so the generator receives vanishing gradients.
KL and JS Divergence
The article briefly defines KL divergence and expresses JS divergence as a symmetrized KL against the mixture m = (p_r + p_g)/2: JS(p_r \| p_g) = \frac{1}{2} KL(p_r \| m) + \frac{1}{2} KL(p_g \| m). It then shows that when the supports of p_r and p_g are disjoint, each KL term equals log 2, leaving the constant JS value of log 2.
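This disjoint‑support behaviour is easy to check numerically. The sketch below (plain Python, not from the article) computes the JS divergence as a symmetrized KL against the mixture and shows that for two point masses with non‑overlapping support it equals log 2, no matter how far apart the masses are:

```python
import math

def kl(p, q):
    """KL(p || q) over a shared finite support; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js(p, q):
    """JS(p || q) = 1/2 KL(p || m) + 1/2 KL(q || m), with m the mixture."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Two distributions with disjoint supports on {0, 1, 2, 3}
p_r = [1.0, 0.0, 0.0, 0.0]   # all "real" mass at position 0
p_g = [0.0, 0.0, 0.0, 1.0]   # all generated mass at position 3

print(js(p_r, p_g))          # log 2, independent of the gap between the supports
```

Because the value is constant in the gap between the two supports, its derivative with respect to the generator's position is zero, which is exactly the vanishing‑gradient problem described above.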
Earth‑Mover (Wasserstein) Distance
To obtain a smoother metric, WGAN replaces the JS divergence with the Wasserstein (Earth‑Mover) distance:

W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x,y)\sim\gamma}[\|x-y\|]

This distance measures the minimal "cost" of transporting probability mass from one distribution to the other, and it varies smoothly even when the supports are disjoint.

Lipschitz Constraint

Using the Kantorovich‑Rubinstein duality, the Wasserstein distance can be written as:

W(p_r, p_g) = \frac{1}{K}\sup_{\|f\|_L \le K}\big(\mathbb{E}_{x\sim p_r}[f(x)]-\mathbb{E}_{x\sim p_g}[f(x)]\big)

Here f must be K‑Lipschitz, i.e., |f(x_1)-f(x_2)| \le K\|x_1-x_2\| for all x_1, x_2. In practice, the discriminator (called the "critic") is trained to approximate such a Lipschitz function.

Practical WGAN Implementation (PyTorch)

The author lists four modifications required to turn a vanilla GAN into a WGAN:

1. Remove the sigmoid activation from the critic's output.
2. Do not apply a logarithm to the loss; use raw expectations.
3. Clip the critic's weights after each update (e.g., to [-0.01, 0.01]).
4. Prefer RMSProp over momentum‑based optimizers such as Adam.

Key code snippets:

```python
# Loss in the original GAN; for WGAN, replace with raw means
criterion = nn.BCEWithLogitsLoss(reduction='mean')

# Critic loss: maximize E[D(real)] - E[D(fake)]
d_loss = -torch.mean(d_out_real.view(-1)) + torch.mean(d_out_fake.view(-1))

# Generator loss: maximize E[D(fake)]
g_loss = -torch.mean(d_out_fake.view(-1))

# Weight clipping after each critic update
for p in D.parameters():
    p.data.clamp_(-0.01, 0.01)

# Optimizers
optimizerG = torch.optim.RMSprop(G.parameters(), lr=5e-5)
optimizerD = torch.optim.RMSprop(D.parameters(), lr=5e-5)
```

The article also includes several illustrative figures (e.g., EM‑distance examples, KL/JS diagrams, and the WGAN training flowchart) to help readers visualise the concepts.
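Assembled into a training loop, the four modifications look roughly like the following minimal sketch on toy 1‑D data. The tiny MLP generator and critic, the data distribution, and the hyper‑parameters (including the common practice of several critic updates per generator update, here `n_critic = 5`) are illustrative assumptions, not the article's exact setup:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

nz, n_critic, clip_value = 8, 5, 0.01
G = nn.Sequential(nn.Linear(nz, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))  # no sigmoid
optimizerG = torch.optim.RMSprop(G.parameters(), lr=5e-5)
optimizerD = torch.optim.RMSprop(D.parameters(), lr=5e-5)

for step in range(100):
    real = torch.randn(64, 1) * 0.5 + 2.0        # toy "real" data: N(2, 0.5)

    # --- Critic update: maximize E[D(real)] - E[D(fake)] ---
    optimizerD.zero_grad()
    fake = G(torch.randn(64, nz)).detach()       # don't backprop into G here
    d_loss = -torch.mean(D(real)) + torch.mean(D(fake))
    d_loss.backward()
    optimizerD.step()
    for p in D.parameters():                     # weight clipping (crude Lipschitz)
        p.data.clamp_(-clip_value, clip_value)

    # --- Generator update: once every n_critic critic updates ---
    if step % n_critic == 0:
        optimizerG.zero_grad()
        g_loss = -torch.mean(D(G(torch.randn(64, nz))))
        g_loss.backward()
        optimizerG.step()
```

Note that the critic loss and generator loss are plain (signed) means of the critic's unbounded outputs, there is no log anywhere, and clipping runs after every critic step so the weights always stay inside [-0.01, 0.01].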
Conclusion

Understanding the theoretical foundation, namely why the JS divergence leads to vanishing gradients and how the Wasserstein distance provides meaningful gradients, allows practitioners to implement a stable WGAN with only a few code changes. The author encourages readers to experiment with the four modifications and to explore the referenced papers for deeper insight.

References

Links to the seminal WGAN papers, the analysis of GANs under JS divergence, WGAN‑GP, Spectral Normalization, and related lecture notes.