Gaurav Singhal

Artistic Neural Style Transfer with PyTorch

  • Jun 9, 2020

Overview

In this guide, you will implement Neural Style Transfer (NST), the algorithm for artistic style transfer with neural networks, in PyTorch. The algorithm lets you render a photo in the style of, say, a Picasso painting. It does so by generating a new image that combines the style of one image (the painting) with the content of another (the input photo).

Credit goes to Leon A. Gatys et al. for their paper, A Neural Algorithm of Artistic Style.

To keep the process fast and the extracted features reliable, a pre-trained VGG-19 network (VGG stands for Visual Geometry Group) is used. The model is trained on ImageNet images and can be downloaded through torchvision, PyTorch's vision library.

Overall picture

The theory behind artistic neural transfer has been covered in previous guides that were based on the TensorFlow framework. It is highly recommended that you go through these guides first. Part 1 talks about theoretical aspects and VGG-Net, and Part 2 talks about losses involved in creating AI digital art.

The guide will be a code walkthrough of the PyTorch implementation. Below is an outline of the process.

Steps

Importing Libraries

To work with PyTorch, import the torch library. The PIL library loads and manipulates the images. From the torchvision package, import the transforms and models modules; models provides the vgg19 model.

%matplotlib inline

from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
import requests
from torchvision import transforms, models

Build the Model

Load the VGG-19 model with pretrained=True. The number 19 refers to the network's 19 weight layers (16 convolutional and 3 fully connected). VGG was designed for image classification, but in NST you only need its feature maps, so keep only the features part (the convolutional and pooling layers) and drop the classifier. Calling param.requires_grad_(False) on each parameter freezes all VGG parameters, since you're only optimizing the target image.

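# Keep only the convolutional/pooling part; the classifier head is not needed.
# On torchvision 0.13+, pretrained=True is deprecated in favor of
# weights=models.VGG19_Weights.DEFAULT.
vgg = models.vgg19(pretrained=True).features

# Freeze every VGG parameter; only the target image will be optimized.
for param in vgg.parameters():
    param.requires_grad_(False)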
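
To see what freezing does without downloading VGG, here is a minimal sketch that uses a tiny stand-in network (the `features` model below is hypothetical, not the real VGG-19). After a backward pass, the frozen parameters should have no gradients, while the input image should.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for models.vgg19(...).features, small enough to build instantly.
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

# Freeze the parameters exactly as the guide does for VGG.
for param in features.parameters():
    param.requires_grad_(False)

# The image is the only tensor that tracks gradients.
image = torch.rand(1, 3, 16, 16, requires_grad=True)
features(image).sum().backward()

frozen = all(p.grad is None for p in features.parameters())
image_has_grad = image.grad is not None
print("parameters frozen:", frozen)
print("image received a gradient:", image_has_grad)
```

The same check can be run against the real vgg object once it is loaded.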

One more PyTorch-specific step: check whether a GPU is available on your system. If it is, move the vgg model to the GPU; otherwise, the model runs on the CPU. And if you don't have a GPU? No worries: you can use Google Colab.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg.to(device)

The smaller the image, the faster the processing. The transforms module of torchvision handles the pre-processing: it resizes the image, converts it to a tensor, and normalizes it with the ImageNet channel means and standard deviations.

def load_image(img_path, max_size=400, shape=None):
    # Open the file and force three RGB channels.
    image = Image.open(img_path).convert('RGB')

    # Cap the size passed to Resize at max_size to keep processing fast.
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    # An explicit shape overrides the computed size.
    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
                        transforms.Resize(size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406),
                                             (0.229, 0.224, 0.225))])
    # Keep the three color channels and add a batch dimension.
    image = in_transform(image)[:3,:,:].unsqueeze(0)

    return image

style = load_image("Picasso.jpg").to(device)

content  = load_image("houses.png").to(device)

Define a helper that converts a normalized tensor back into an image array that matplotlib can display.

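def im_convert(tensor):
    # Move to the CPU and detach from the autograd graph.
    image = tensor.to("cpu").clone().detach()
    # Drop the batch dimension and reorder C x H x W to H x W x C.
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    # Undo the normalization, then clamp to the displayable range.
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image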
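
The helper above inverts the normalization applied when loading. A minimal sketch of that round trip, on a small random array rather than a real image (the array size and seed are arbitrary), could look like this:

```python
import numpy as np
import torch

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

# Random H x W x C "image" with values in [0, 1).
original = np.random.default_rng(0).random((4, 5, 3))

# Forward: per-channel normalization on a C x H x W tensor, as transforms.Normalize does.
chw = torch.from_numpy(original).permute(2, 0, 1)
normalized = (chw - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)

# Inverse: the same steps im_convert uses on a batched tensor.
restored = normalized.unsqueeze(0).numpy().squeeze().transpose(1, 2, 0)
restored = (restored * std + mean).clip(0, 1)

print(np.allclose(restored, original))
```

If the means or standard deviations in im_convert did not match those used in load_image, the restored colors would be shifted.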

Display the images side-by-side.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax1.set_title("Content-Image",fontsize = 20)
ax2.imshow(im_convert(style))
ax2.set_title("Style-Image", fontsize = 20)
plt.show()

Content and Style image

Print the model to see its layers and their indices.

print(vgg)

Layers

Intermediate Layers for Style and Content

Deeper layers of VGG-19 extract more complex, higher-level features, so conv4_2 is used to represent the content. The first convolution layer of each block, from conv1_1 to conv5_1, is used to represent the style; the shallow layers among them respond to simple features like lines or edges. Refer to the image below.

Architecture

def get_features(image, model, layers=None):

    # Map indices in vgg.features to the layer names used in the paper.
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2', #content
                  '28': 'conv5_1'}

    features = {}
    x = image
    # Run the image through the network, saving the output of each selected layer.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features
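
The index-to-name mapping used in get_features can be derived rather than memorized. The sketch below assumes the standard VGG-19 layout without batch normalization, in which each convolution is followed by a ReLU and each block ends with a max-pool; the configuration list is written out by hand here rather than read from torchvision, and conv_indices is a hypothetical helper.

```python
# Assumed VGG-19 layout: numbers are conv output channels, "M" is a max-pool.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

def conv_indices(cfg):
    """Map names like 'conv4_2' to their index in a Conv/ReLU/MaxPool Sequential."""
    indices = {}
    index, block, conv = 0, 1, 1
    for item in cfg:
        if item == "M":
            index += 1                  # the pooling layer occupies one slot
            block, conv = block + 1, 1  # the next conv starts a new block
        else:
            indices[f"conv{block}_{conv}"] = index
            index += 2                  # one slot for the conv, one for its ReLU
            conv += 1
    return indices

indices = conv_indices(VGG19_CFG)
wanted = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv4_2", "conv5_1"]
layers = {str(indices[name]): name for name in wanted}
print(layers)
```

The result can be compared against the output of print(vgg) to confirm that each index points at a Conv2d layer.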

Loss Functions

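Calculate the Gram matrix for each layer; it captures the correlations between feature channels and serves as the style representation. The target image should keep the content of the content image (here, the houses and the lake), so start by cloning the content image and then iteratively change its style.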

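def gram_matrix(tensor):

    # Batch size is assumed to be 1; d is the number of feature channels.
    _, d, h, w = tensor.size()

    # Flatten each channel's feature map into a row vector.
    tensor = tensor.view(d, h * w)

    # Channel-by-channel correlations: a d x d matrix.
    gram = torch.mm(tensor, tensor.t())

    return gram

content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

# Pre-compute the style Gram matrices once; the style image never changes.
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

# The target starts as a copy of the content image and is the only tensor optimized.
target = content.clone().requires_grad_(True).to(device)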
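
The Gram matrix is easy to verify by hand on a tiny tensor. The sketch below repeats the gram_matrix computation on a made-up two-channel feature map and also shows that reordering the spatial positions leaves the result unchanged, which is why the Gram matrix describes style (which features occur together) rather than layout.

```python
import torch

def gram_matrix(tensor):
    # Same computation as in the guide: flatten each channel, then F @ F^T.
    _, d, h, w = tensor.size()
    flat = tensor.view(d, h * w)
    return torch.mm(flat, flat.t())

# Shape (1, 2, 1, 2): channel 0 = [1, 2], channel 1 = [3, 4].
toy = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])
gram = gram_matrix(toy)
print(gram)

# Flip the spatial positions: channel 0 = [2, 1], channel 1 = [4, 3].
flipped_gram = gram_matrix(toy.flip(-1))
print(torch.equal(gram, flipped_gram))
```

The entries are the dot products of the flattened channels: 1·1 + 2·2 = 5, 1·3 + 2·4 = 11, and 3·3 + 4·4 = 25.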

Assigning Weights

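A weight is assigned to each style layer. Giving the earlier layers a higher weight produces larger style artifacts. The content_weight and style_weight values then set the balance between the content loss and the style loss.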

style_weights = {'conv1_1': 1.5,
                 'conv2_1': 0.80,
                 'conv3_1': 0.25,
                 'conv4_1': 0.25,
                 'conv5_1': 0.25}

content_weight = 1e-2
style_weight = 1e9

Run the Model

These weights enter the loss that the optimizer (Adam) minimizes. Define the number of steps used to update the image. Putting everything together: extract the target's features from the VGG-Net and calculate the content loss. Then compute the Gram matrix at each style layer to calculate the style loss, weighting each layer's loss before adding it to the others. Finally, calculate the total loss!

In ordinary gradient descent for a neural network, the weights are adjusted, but in NST they are kept fixed. Instead, the image pixels are adjusted. The gradients of the loss are backpropagated all the way to the input image, which is then updated.
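
Written out, the losses computed in the code take roughly the following form. The notation is introduced here for illustration and is not from the original guide: T, C, and S are the target, content, and style images; F_l(x) is the layer-l feature map reshaped to d_l × (h_l w_l); λ_l is the entry of style_weights for layer l; and α and β are content_weight and style_weight.

```latex
\begin{aligned}
G_l(x) &= F_l(x)\,F_l(x)^{\top}, \qquad F_l(x) \in \mathbb{R}^{d_l \times h_l w_l} \\
L_{\text{content}} &= \operatorname{mean}\!\left[\big(F_{\text{conv4\_2}}(T) - F_{\text{conv4\_2}}(C)\big)^{2}\right] \\
L_{\text{style}} &= \sum_{l} \frac{\lambda_l}{d_l\, h_l\, w_l}\,
    \operatorname{mean}\!\left[\big(G_l(T) - G_l(S)\big)^{2}\right] \\
L_{\text{total}} &= \alpha\, L_{\text{content}} + \beta\, L_{\text{style}}
\end{aligned}
```

Here mean[·] averages over all entries of its argument, matching torch.mean in the code.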

show = 400  # print the loss and show the image every `show` steps

optimizer = optim.Adam([target], lr=0.01)
steps = 7000

for ii in range(1, steps+1):

    target_features = get_features(target, vgg)

    # Content loss: mean squared difference of the conv4_2 feature maps.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # Style loss: weighted sum of Gram-matrix differences over the style layers.
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    # Update the target image; the VGG weights stay frozen.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
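
The idea of optimizing pixels instead of weights can be seen in isolation with a much smaller problem. The following is a minimal sketch under stated assumptions: a single frozen convolution stands in for VGG, random tensors stand in for the images, and the sizes, learning rate, and step count are arbitrary.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# A tiny frozen "feature extractor" standing in for VGG (hypothetical).
net = nn.Conv2d(3, 4, kernel_size=3, padding=1)
for p in net.parameters():
    p.requires_grad_(False)
weights_before = net.weight.clone()

# Plays the role of the content image; its features are the optimization target.
reference = torch.rand(1, 3, 8, 8)
reference_features = net(reference)

# The image being optimized is the only tensor handed to the optimizer.
image = torch.rand(1, 3, 8, 8).requires_grad_(True)
optimizer = optim.Adam([image], lr=0.05)

losses = []
for _ in range(200):
    loss = torch.mean((net(image) - reference_features) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print("first loss:", losses[0])
print("last loss:", losses[-1])
print("weights unchanged:", torch.equal(net.weight, weights_before))
```

The loss falls while the network's weights stay exactly as they were, which is the same arrangement the NST loop uses with target and vgg.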

Output-images-progress 1

Output-images-progress 2

Conclusion

Great work! The loss should keep decreasing as the optimization steps proceed. Play with your creation by adjusting the weights and the learning rate. You may also try the L-BFGS optimizer (torch.optim.LBFGS).
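
If you try L-BFGS, note that torch.optim.LBFGS needs a closure that re-evaluates the loss, so the update step looks different from the Adam loop. Below is a minimal sketch on a toy objective; the toy_loss function and the tensors are placeholders for illustration, and in the NST loop the closure would compute total_loss instead.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)

# Toy problem: move `image` toward `goal` (stand-ins for the real tensors).
goal = torch.rand(1, 3, 8, 8)
image = torch.zeros(1, 3, 8, 8, requires_grad=True)

optimizer = optim.LBFGS([image], lr=1.0, max_iter=20)

def toy_loss():
    return torch.mean((image - goal) ** 2)

initial_loss = toy_loss().item()

def closure():
    # LBFGS may call this several times per step to re-evaluate the loss.
    optimizer.zero_grad()
    loss = toy_loss()
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)

final_loss = toy_loss().item()
print("initial loss:", initial_loss)
print("final loss:", final_loss)
```

Each call to optimizer.step(closure) can run up to max_iter inner iterations, so far fewer outer steps are typically needed than with Adam.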

This guide gave you a general idea of how to code NST using PyTorch. I would recommend going through the TensorFlow-based NST guides for a better understanding of the terms involved (losses, VGG-Net, cost function, etc.).

I hope you learned something new today. If you need any help with machine learning implementations, feel free to contact me.
