
Artistic Neural Style Transfer with PyTorch

Jun 9, 2020 • 9 Minute Read

Overview

In this guide, you will implement the Neural Style Transfer (NST) algorithm for artistic style transfer in PyTorch. With it you can, for example, turn a photo into a Picasso-style image. The algorithm creates a new image that combines the content of one image (the input photo) with the style of another (the painting).

This guide builds on the work of Leon A. Gatys et al. in their paper, A Neural Algorithm of Artistic Style.

To avoid training a network from scratch, the guide uses a VGG-19 network (from the Visual Geometry Group) that has already been trained on ImageNet. You can download the pre-trained model through torchvision's models API.

Previous guides based on the TensorFlow framework covered the theory behind artistic neural style transfer, and it is highly recommended that you read them first. Part 1 covers the theoretical aspects and VGG-Net, and Part 2 covers the losses involved in creating AI digital art.

The guide will be a code walkthrough of the PyTorch implementation. Below is an outline of the process.

Importing Libraries

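To work with PyTorch, import the torch library. The PIL (Pillow) library loads and manipulates images, and requests together with BytesIO lets you load an image from a URL. From the torchvision package, import the models module, which provides the vgg19 model, and the transforms module for preprocessing.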

```python
%matplotlib inline

from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
import requests
from torchvision import transforms, models
```

Build the Model

Load the VGG-19 model with pretrained=True. The 19 is the number of weight layers in the network: 16 convolutional and 3 fully connected. VGG was designed for image classification, but NST only needs the features it extracts, so you keep only its convolutional part (.features). Calling param.requires_grad_(False) freezes all VGG parameters, because you optimize only the target image and never the network.

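```python
# pretrained=True downloads ImageNet weights; torchvision 0.13+ prefers weights=models.VGG19_Weights.DEFAULT
vgg = models.vgg19(pretrained=True).features

# Freeze all VGG parameters: only the target image will be optimized
for param in vgg.parameters():
    param.requires_grad_(False)
```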
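To see what freezing does, here is a minimal sketch. A tiny convolutional network stands in for VGG-19, which is an assumption made only to keep the example small and offline. Once the parameters are frozen, backpropagation produces no gradient for the weights, but it still produces one for an input image that requires gradients. That input gradient is the mechanism NST relies on.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in for the VGG feature extractor (assumption: any nn.Module behaves the same way)
extractor = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())

# Freeze every parameter, as the guide does for VGG-19
for param in extractor.parameters():
    param.requires_grad_(False)

# The image being optimized is the only tensor that tracks gradients
image = torch.rand(1, 3, 16, 16, requires_grad=True)

loss = extractor(image).pow(2).mean()
loss.backward()

weight_grad = extractor[0].weight.grad  # stays None: frozen weights get no gradient
image_grad = image.grad                 # populated: gradients flow back to the pixels
print(weight_grad)                      # None
print(image_grad.shape)                 # torch.Size([1, 3, 16, 16])
```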
    

PyTorch needs one more step: check whether a GPU is available on your system. If it is, move the vgg model to the GPU; otherwise, the model runs on the CPU. If you don't have a GPU, you can use Google Colab, which provides one.

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg.to(device)
```
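There is one PyTorch detail to keep in mind here. Module.to() moves a model's parameters in place and returns the same module, so vgg.to(device) works without reassignment. Tensor.to() behaves differently: it returns a tensor on the requested device (possibly the same object if it is already there) and does not move the original. That is why the image tensors in this guide are assigned, as in style = load_image(...).to(device). The sketch below uses a toy nn.Linear model as an assumed stand-in for VGG.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)
returned = model.to(device)   # Module.to(): moves parameters in place and returns the same module

x = torch.zeros(3, 4)
x_on_device = x.to(device)    # Tensor.to(): returns a tensor on `device`; keep the returned value

out = model(x_on_device)      # model and input now live on the same device
```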
    

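Smaller images are processed faster, so load_image limits the resize target with max_size (400 pixels by default). The torchvision.transforms module handles preprocessing. It resizes the image, converts it to a tensor, and normalizes it with the ImageNet mean and standard deviation that VGG was trained with.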

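```python
def load_image(img_path, max_size=400, shape=None):
    # Load from a URL or a local path, forcing 3-channel RGB
    if img_path.startswith("http"):
        response = requests.get(img_path)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')

    # Large images slow down processing, so limit the resize target
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    # An explicit (height, width) shape overrides the size computed above
    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
                        transforms.Resize(size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406),
                                             (0.229, 0.224, 0.225))])

    # Keep only the RGB channels (drops any alpha channel) and add a batch dimension
    image = in_transform(image)[:3,:,:].unsqueeze(0)

    return image
```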
```python
style = load_image("Picasso.jpg").to(device)
content = load_image("houses.png").to(device)
```
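When transforms.Resize gets a single integer, it scales the *shorter* edge to that size and keeps the aspect ratio. As a result, max_size in load_image bounds the shorter edge, not the longer one. The hypothetical helper below mirrors the integer-size rule as recent torchvision versions compute it (the longer edge is truncated to an int). It is only meant to make the resulting dimensions easy to predict.

```python
def resized_dims(width, height, size):
    """Hypothetical helper: output (width, height) of transforms.Resize(size) for an int size."""
    short_edge, long_edge = min(width, height), max(width, height)
    new_short, new_long = size, int(size * long_edge / short_edge)
    # Put the new edges back in (width, height) order
    if width <= height:
        return new_short, new_long
    return new_long, new_short


print(resized_dims(800, 600, 400))  # (533, 400)
```

For an 800x600 photo with the default max_size=400, the tensor comes out 533 pixels wide and 400 tall. If you need a hard bound on both edges, pass an explicit (height, width) tuple through the shape argument.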
    

Convert a normalized image tensor back into a NumPy image that Matplotlib can display.

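```python
def im_convert(tensor):
    # Copy to the CPU and detach from the autograd graph
    image = tensor.to("cpu").clone().detach()
    # Drop the batch dimension and move channels last: (C, H, W) -> (H, W, C)
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    # Undo the ImageNet normalization applied in load_image
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Clip to the [0, 1] range that imshow expects for float images
    image = image.clip(0, 1)

    return image
```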
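im_convert is the inverse of the Normalize step in load_image. The sketch below checks that round trip on a small random "image" instead of a real photo (an assumption that keeps it self-contained). It applies the per-channel (x - mean) / std that transforms.Normalize computes, then the same inverse arithmetic as im_convert, and recovers the original pixels in (H, W, C) order.

```python
import numpy as np
import torch

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

torch.manual_seed(0)
pixels = torch.rand(3, 4, 4)  # fake RGB image in [0, 1], shape (C, H, W)

# Forward: per-channel normalization, as transforms.Normalize(MEAN, STD) computes it
mean_t = torch.tensor(MEAN).view(3, 1, 1)
std_t = torch.tensor(STD).view(3, 1, 1)
normalized = (pixels - mean_t) / std_t

# Inverse: the same steps im_convert performs (add batch dim first, as load_image does)
restored = normalized.unsqueeze(0).numpy().squeeze().transpose(1, 2, 0)
restored = restored * np.array(STD) + np.array(MEAN)
restored = restored.clip(0, 1)
```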
    

Display the images side-by-side.

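```python
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax1.set_title("Content-Image", fontsize=20)
ax2.imshow(im_convert(style))
ax2.set_title("Style-Image", fontsize=20)
plt.show()
```

Print the model to see the index of each layer in vgg. These indices are used to pick the style and content layers.

```python
print(vgg)
```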

Intermediate Layers for Style and Content

Deeper layers of VGG-19 extract higher-level, more complex features, so conv4_2 is used to represent the image content. For style, the first convolutional layer of each block, conv1_1 through conv5_1, is used. The shallower of these layers detect simple features such as lines and edges, and the deeper ones respond to more complex patterns. Refer to the image below.

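```python
def get_features(image, model, layers=None):
    # Map indices in vgg (see print(vgg)) to the layer names used in this guide
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # content representation
                  '28': 'conv5_1'}

    features = {}
    x = image
    # Pass the image through each layer in order, saving the outputs of the selected layers
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features
```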
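The string keys in get_features are positions in vgg19().features. You can derive them without downloading anything. The sketch below assumes torchvision builds VGG-19 from configuration "E" (numbers are conv output channels, "M" is max-pooling), adding a Conv2d followed by a ReLU for each number and a MaxPool2d for each "M". It then counts positions to confirm which index belongs to which layer name, for instance that conv4_2 sits at index 21.

```python
# VGG-19 layout (configuration "E"): channel counts for conv layers, "M" for max-pool
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def conv_layer_indices(cfg):
    """Map each conv layer name (conv<block>_<n>) to its string index in vgg19().features."""
    names, index, block, conv = {}, 0, 1, 1
    for item in cfg:
        if item == "M":
            index += 1                  # one MaxPool2d layer
            block, conv = block + 1, 1  # next block starts at conv<block>_1
        else:
            names[f"conv{block}_{conv}"] = str(index)
            index += 2                  # Conv2d followed by ReLU
            conv += 1
    return names


indices = conv_layer_indices(VGG19_CFG)
print(indices["conv4_2"], indices["conv5_1"], indices["conv5_2"])  # 21 28 30
```

You can compare the result against the output of print(vgg) to double-check the mapping.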
    

Loss Functions

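An image's style is captured by the Gram matrix of each style layer's feature maps, that is, by the correlations between the layer's channels. The target image should keep the content of the content image (the houses and the lake) while taking on a new style. So start by cloning the content image, and then change its style iteratively.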

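```python
def gram_matrix(tensor):
    # Assumes a batch size of 1
    _, d, h, w = tensor.size()
    # Flatten each of the d feature maps into one row
    tensor = tensor.view(d, h * w)
    # Entry (i, j) is the dot product of feature maps i and j
    gram = torch.mm(tensor, tensor.t())

    return gram
```

```python
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

# Gram matrices of the style image, computed once before optimization
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

# Start from a copy of the content image; its pixels are what gets optimized
target = content.clone().to(device).requires_grad_(True)
```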
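A quick check of what gram_matrix returns helps explain why it works as a style representation. For a feature map with d channels, entry (i, j) is the sum over all spatial positions of channel i times channel j. The result is therefore always d x d and symmetric, whatever the spatial size. This is also why the style image does not need to match the content image's dimensions. The random feature maps below are assumed inputs standing in for VGG outputs.

```python
import torch


def gram_matrix(tensor):
    # Same computation as in this guide: flatten each channel, then take pairwise dot products
    _, d, h, w = tensor.size()
    flat = tensor.view(d, h * w)
    return torch.mm(flat, flat.t())


torch.manual_seed(0)
small = torch.rand(1, 8, 10, 10)  # 8 channels, 10x10 spatial
large = torch.rand(1, 8, 32, 24)  # same channels, different spatial size

g_small = gram_matrix(small)
g_large = gram_matrix(large)
print(g_small.shape, g_large.shape)  # torch.Size([8, 8]) torch.Size([8, 8])
```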
    

Assigning Weights

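Each style layer gets its own weight. Giving the earlier layers higher weights produces larger style artifacts. content_weight and style_weight then set the balance between preserving the content and applying the style in the total loss.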

```python
style_weights = {'conv1_1': 1.5,
                 'conv2_1': 0.80,
                 'conv3_1': 0.25,
                 'conv4_1': 0.25,
                 'conv5_1': 0.25}

content_weight = 1e-2
style_weight = 1e9
```

Run the Model

The weights scale the terms of the loss, and the Adam optimizer minimizes that loss by updating the target image over a fixed number of steps. Each step puts everything together. First, extract the target image's features from VGG-19 and calculate the content loss. Next, compute the Gram matrix of each style layer and calculate the style loss, weighting each layer's contribution before adding it to the others. Finally, combine the two into the total loss.

In ordinary training, gradient descent adjusts the network's weights. In NST the weights stay fixed, and the image pixels are adjusted instead. The gradients of the loss are backpropagated all the way to the input image, which transforms it step by step.
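The training loop in this section computes the loss written out below. The symbols are labels introduced here, not taken from the original. $F^{l}$ is the target image's feature map at layer $l$, flattened to $d_l \times (h_l w_l)$. $P$ is the content image's conv4_2 feature map. $A^{l}$ is the style image's Gram matrix. $\omega_l$ is style_weights[layer], $\alpha$ is content_weight, and $\beta$ is style_weight. The normalization constants follow this code (torch.mean plus the division by $d_l h_l w_l$) and may differ from those used in the paper.

$$
\mathcal{L}_{\text{content}} = \frac{1}{d\,h\,w} \sum_{i,j} \left(F^{\text{conv4\_2}}_{ij} - P_{ij}\right)^2
$$

$$
G^{l}_{ij} = \sum_{k} F^{l}_{ik} F^{l}_{jk},
\qquad
\mathcal{L}_{\text{style}} = \sum_{l} \frac{\omega_l}{d_l\,h_l\,w_l} \cdot \frac{1}{d_l^{2}} \sum_{i,j} \left(G^{l}_{ij} - A^{l}_{ij}\right)^2
$$

$$
\mathcal{L}_{\text{total}} = \alpha\,\mathcal{L}_{\text{content}} + \beta\,\mathcal{L}_{\text{style}}
$$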

```python
show = 400  # display the intermediate result every 400 steps

optimizer = optim.Adam([target], lr=0.01)
steps = 7000

for ii in range(1, steps+1):

    target_features = get_features(target, vgg)

    # Content loss: mean squared difference of the conv4_2 feature maps
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # Style loss: weighted, size-normalized Gram matrix differences over the style layers
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    # Update the target image's pixels (the VGG weights are frozen)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
```
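The loop above optimizes pixels, not weights. The same mechanics can be shown in miniature, under some simplifying assumptions. A single frozen random convolution stands in for VGG-19. Only a content-style loss is used. The target starts from noise so that the loss has room to fall. Adam updates only the target tensor, the loss goes down, and the network's weights stay exactly as they were.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Frozen stand-in for VGG (assumption: one random conv layer is enough to show the mechanics)
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
for p in features.parameters():
    p.requires_grad_(False)
weights_before = features[0].weight.clone()

content = torch.rand(1, 3, 16, 16)
content_target = features(content)  # fixed "content features"

# Start from noise so the loss has something to reduce
target = torch.rand(1, 3, 16, 16).requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=0.05)  # only the pixels are optimized

losses = []
for step in range(200):
    loss = torch.mean((features(target) - content_target) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```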
    

Conclusion

Great work! The printed loss should decrease as the optimization proceeds. Experiment with your creation by adjusting the weights and the learning rate. You can also try the L-BFGS optimizer (torch.optim.LBFGS) instead of Adam.
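Switching to L-BFGS changes how the update step is written. torch.optim.LBFGS may evaluate the loss several times per step, so optimizer.step() must receive a closure that zeroes the gradients, recomputes the loss, calls backward(), and returns the loss. The sketch below shows that pattern on a toy problem. The frozen random convolution standing in for VGG, the max_iter value, and the strong Wolfe line search are assumptions chosen for illustration, not settings taken from this guide.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Frozen stand-in for VGG (assumption for illustration)
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
for p in features.parameters():
    p.requires_grad_(False)

content_target = features(torch.rand(1, 3, 16, 16))
target = torch.rand(1, 3, 16, 16).requires_grad_(True)

optimizer = torch.optim.LBFGS([target], lr=1, max_iter=20, line_search_fn="strong_wolfe")


def closure():
    # L-BFGS calls this whenever it needs a fresh loss and gradient
    optimizer.zero_grad()
    loss = torch.mean((features(target) - content_target) ** 2)
    loss.backward()
    return loss


with torch.no_grad():
    initial = torch.mean((features(target) - content_target) ** 2).item()

for _ in range(5):
    optimizer.step(closure)  # each call runs up to max_iter inner iterations

with torch.no_grad():
    final = torch.mean((features(target) - content_target) ** 2).item()
```

In the full NST loop, the body of closure would contain the content, style, and total loss computation from the training loop.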

This guide gave you a general idea of how to code NST using PyTorch. I recommend going through the TensorFlow NST guides for a better understanding of the terms involved (losses, VGG-Net, cost function, etc.).

I hope you learned something new today. If you need any help with machine learning implementations, feel free to contact me.

Gaurav Singhal


Gaurav is a Data Scientist with a strong background in computer science and mathematics. He has extensive research experience in data structures, statistical data analysis, and mathematical modeling. With a solid background in web development, he works with Python, Java, Django, HTML, Struts, Hibernate, Vaadin, web scraping, Angular, and React. His data science skills include Python, Matplotlib, TensorFlow, Pandas, NumPy, Keras, CNN, ANN, NLP, recommenders, and predictive analysis. He has built systems that use both basic machine learning algorithms and complex deep neural networks. He has worked on many data science projects, including product recommendation, user sentiment, Twitter bots, information retrieval, predictive analysis, data mining, image segmentation, SVMs, and random forests.
