Neural Style Transfer with PyTorch: Theory and Implementation
This article introduces neural style transfer, explains its underlying principles using VGG19 feature extraction, content and style loss definitions, and provides a complete PyTorch implementation with code for loading images, extracting features, computing Gram matrices, and optimizing the output image.
1. Introduction
AI was long considered unrelated to art, but with the emergence of GANs, style transfer, and diffusion models, it can now generate images in artistic styles. Style transfer is an algorithm that transfers the visual style of one image (A) onto another image (B), for example applying Van Gogh's style to an arbitrary photo.
2. Style Transfer
2.1 What is Style Transfer?
Style transfer uses two input images: a content image and a style image. The goal is to produce an output image that preserves the content of the first while adopting the style of the second. The notion of "style" is abstract and often recognized visually (e.g., Van Gogh's "Starry Night" or Monet's "Impression, Sunrise").
Later we will use feature maps to explain the concrete meaning of style.
In practice, we transform a non‑Van‑Gogh image to look like a Van‑Gogh painting, or a non‑Monet image to look like a Monet painting.
2.2 Implementation Principles
2.2.1 Feature Maps
If we denote the style image as A and the content image as B, we want the output image C to have a style similar to A and content similar to B. Simple pixel‑wise comparison is insufficient; instead we use feature maps extracted by a convolutional neural network, which are more robust and capture multi‑scale information.
VGG networks are commonly used for this purpose. The VGG‑19 feature extractor contains five groups of convolutional layers, each group ending in a pooling layer. Early layers capture low‑level textures and edges, while deeper layers capture more abstract representations.
2.2.2 Content Loss
Content loss measures the difference between the content features of the output image and those of the content image. We use the activation of layer conv4_2 (512 feature maps) as the content representation and compute the mean‑squared error between the two feature tensors.
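The content loss is just an element‑wise MSE between two feature tensors of the same shape. A minimal sketch, using hypothetical toy tensors in place of real VGG activations:

```python
import torch
import torch.nn.functional as F

def content_loss(gen_feat, content_feat):
    # gen_feat, content_feat: (B, C, H, W) activations from the same VGG layer
    return F.mse_loss(gen_feat, content_feat)

# Toy example: every element differs by 1, so the mean squared error is 1.0
a = torch.zeros(1, 512, 32, 32)
b = torch.ones(1, 512, 32, 32)
print(content_loss(a, b).item())  # -> 1.0
```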
2.2.3 Style Loss
Style loss captures texture and overall style by comparing Gram matrices of feature maps across several layers. The Gram matrix for a layer is computed as the inner product of vectorised feature maps, and the style loss for a layer is the MSE between the Gram matrices of the style image and the output image. The total style loss is a weighted sum over all selected layers.
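Concretely, the Gram matrix treats each channel of a feature map as a vector and takes all pairwise inner products, so it records which channels co‑activate (texture) while discarding spatial layout. A small sketch with a random feature map, normalised by the number of elements as in the implementation later in this article:

```python
import torch

def gram_matrix(feat):
    # feat: (B, C, H, W) feature map
    b, c, h, w = feat.size()
    flat = feat.view(b * c, h * w)       # each row is one vectorised channel
    return flat @ flat.t() / (b * c * h * w)

feat = torch.randn(1, 8, 4, 4)
G = gram_matrix(feat)
print(G.shape)  # torch.Size([8, 8]) -- one entry per pair of channels
```

Because the Gram matrix is an inner‑product matrix, it is symmetric, and its size depends only on the number of channels, so style can be compared between images of different spatial sizes.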
3. Code Implementation
The following sections provide a complete PyTorch implementation.
3.1 Loading Images
We need three tensors: the content image, the style image, and the generated image (initialized from the content image or random noise). The code loads the images, resizes them to 512×512, converts them to tensors, and moves them to the GPU.
import cv2
import torch
import torchvision.models as models
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Resize
transform = Compose([
    Resize((512, 512)),
    ToTensor(),
])

def load_images(content_path, style_path):
    # Convert to RGB so grayscale or RGBA inputs become 3-channel tensors
    content_img = Image.open(content_path).convert('RGB')
    image_size = content_img.size
    content_img = transform(content_img).unsqueeze(0).cuda()
    style_img = Image.open(style_path).convert('RGB')
    style_img = transform(style_img).unsqueeze(0).cuda()
    # The generated image starts as a copy of the content image and is optimised directly
    var_img = content_img.clone()
    var_img.requires_grad = True
    return content_img, style_img, var_img, image_size

content_img, style_img, var_img, image_size = load_images('content.jpeg', 'style.png')

3.2 Feature Extraction
We use a pre‑trained VGG‑19 model (with gradients disabled) to extract both content and style features. A custom FeatureExtractor class selects the appropriate layers for content (layer 22, the activation after conv4_2) and style (layers 1, 6, 11, 20, 29, the activations after conv1_1 through conv5_1) and replaces max‑pooling with average pooling.
# Load pre-trained VGG-19 and freeze parameters
model = models.vgg19(pretrained=True).cuda()
for params in model.parameters():
    params.requires_grad = False
model.eval()
# Normalization utilities
mu = torch.Tensor([0.485, 0.456, 0.406]).unsqueeze(-1).unsqueeze(-1).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).unsqueeze(-1).unsqueeze(-1).cuda()
unnormalize = lambda x: x * std + mu
normalize = lambda x: (x - mu) / std
class FeatureExtractor(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.module = model.features.cuda().eval()
        self.con_layers = [22]                # activation after conv4_2 (content)
        self.sty_layers = [1, 6, 11, 20, 29]  # one activation per conv group (style)
        # Replace max pooling with average pooling for smoother gradients
        for name, layer in self.module.named_children():
            if isinstance(layer, nn.MaxPool2d):
                self.module[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, tensor: torch.Tensor) -> dict:
        sty_feat_maps = []
        con_feat_maps = []
        x = normalize(tensor)
        for name, layer in self.module.named_children():
            x = layer(x)
            if int(name) in self.con_layers:
                con_feat_maps.append(x)
            if int(name) in self.sty_layers:
                sty_feat_maps.append(x)
        return {"content_features": con_feat_maps, "style_features": sty_feat_maps}

model = FeatureExtractor(model)
style_target = model(style_img)["style_features"]
content_target = model(content_img)["content_features"]

3.3 Gram Matrix Computation
For each selected style layer we compute its Gram matrix, which will be used in the style loss.
gram_target = []
for i in range(len(style_target)):
    b, c, h, w = style_target[i].size()
    tensor_ = style_target[i].view(b * c, h * w)
    gram_i = torch.mm(tensor_, tensor_.t()).div(b * c * h * w)
    gram_target.append(gram_i)

3.4 Optimizing the Generated Image
We treat the generated image as a learnable parameter and optimise it with Adam. The total loss is a weighted sum of content loss, style loss, and a total‑variation regularisation term.
optimizer = torch.optim.Adam([var_img], lr=0.01, betas=(0.9, 0.999), eps=1e-8)
lam1 = 1e-3
lam2 = 1e7
lam3 = 5e-3
for itera in range(20001):
    optimizer.zero_grad()
    output = model(var_img)
    sty_output = output["style_features"]
    con_output = output["content_features"]

    # Content loss: MSE against the content targets
    con_loss = torch.tensor([0]).cuda().float()
    for i in range(len(con_output)):
        con_loss = con_loss + F.mse_loss(con_output[i], content_target[i])

    # Style loss: MSE between Gram matrices, summed over the style layers
    sty_loss = torch.tensor([0]).cuda().float()
    for i in range(len(sty_output)):
        b, c, h, w = sty_output[i].size()
        tensor_ = sty_output[i].view(b * c, h * w)
        gram_i = torch.mm(tensor_, tensor_.t()).div(b * c * h * w)
        sty_loss = sty_loss + F.mse_loss(gram_i, gram_target[i])

    # Total-variation regularisation on the generated image (not the style image)
    b, c, h, w = var_img.size()
    TV_loss = (torch.sum(torch.abs(var_img[:, :, :, :-1] - var_img[:, :, :, 1:])) +
               torch.sum(torch.abs(var_img[:, :, :-1, :] - var_img[:, :, 1:, :]))) / (b * c * h * w)

    loss = con_loss * lam1 + sty_loss * lam2 + TV_loss * lam3
    loss.backward()
    optimizer.step()
    var_img.data.clamp_(0, 1)  # keep pixel values in [0, 1] after each update

    if itera % 100 == 0:
        print('itera: %d, con_loss: %.4f, sty_loss: %.4f, TV_loss: %.4f' % (itera,
              con_loss.item() * lam1,
              sty_loss.item() * lam2,
              TV_loss.item() * lam3), '\n\t total loss:', loss.item())
        print('var_img mean:%.4f, std:%.4f' % (var_img.mean().item(), var_img.std().item()))

    if itera % 1000 == 0:
        save_img = var_img.clone()
        save_img = torch.clamp(save_img, 0, 1)
        save_img = save_img[0].permute(1, 2, 0).data.cpu().numpy() * 255
        save_img = save_img[..., ::-1].astype('uint8')  # RGB -> BGR for OpenCV
        save_img = cv2.resize(save_img, image_size)
        cv2.imwrite('outputs/transfer%d.jpg' % itera, save_img)

Running the script produces intermediate results; image quality improves with more iterations, as shown in the example outputs.
From left to right: style image, content image, result after 1,000 iterations, result after 3,000 iterations. The final images demonstrate successful style transfer.
Code reference: https://blog.csdn.net/Brikie/article/details/115602714