HOME/Articles/

PIL example: style transfer (snippet)

Article Outline

Python pil example 'style transfer'

Functions in program:

  • def main():
  • def transfer_style(model, content_image_path, style_image_path, gen_img_folder, img_size,
  • def total_variation_loss(tensor):
  • def calculate_style_loss(style_image_features, gen_image_features, style_layer_weights):
  • def gram_matrix(tensor):
  • def calculate_content_loss(content_image_features, gen_image_features, content_layer):
  • def get_features(model, image, vgg_layers):
  • def load_image(img_path, img_size, transform):

Modules used in program:

  • import os
  • import time
  • import torch.nn.functional as F
  • import torchvision.transforms.functional as VF
  • import torch.optim as optim
  • import torch.nn as nn
  • import torch
  • import argparse

python style transfer

Python PIL + PyTorch example: neural style transfer


import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import torchvision.transforms.functional as VF
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import time
import os


# Command-line interface. Numeric options declare an explicit `type` so that
# values supplied on the command line are parsed to int/float instead of
# arriving as strings (the defaults were already numeric, which made the two
# paths inconsistent).
parser = argparse.ArgumentParser()
parser.add_argument("--content_weight", type=float, default=1e5)
parser.add_argument("--style_weight", type=float, default=1e2)
parser.add_argument("--tv_weight", type=float, default=1e1)
parser.add_argument("--vgg_model", default="vgg19")
parser.add_argument("--lambda_noise", type=float, default=1,  help="G_init = lambda_noise*C+(1-lambda)*N")
parser.add_argument("--content_image_path", required=True)
parser.add_argument("--style_image_path", required=True)
parser.add_argument("--gen_img_folder", default="generated/")
parser.add_argument("--img_size", type=int, default=512)
parser.add_argument("--num_iterations", type=int, default=50000)
parser.add_argument("--learning_rate", type=float, default=1e-3)
# parser.add_argument("--log_dir", default="logs/{}/".format(int(time.time())))
parser.add_argument("--log_dir", default="logs/")


def load_image(img_path, img_size, transform):
    """Load an image as RGB, resize to a square, transform, and add a batch dim.

    Returns a tensor of shape (1, C, img_size, img_size).
    """
    img = Image.open(img_path).convert('RGB')
    resized = img.resize((img_size, img_size), Image.BICUBIC)
    tensor = transform(resized)
    return tensor.unsqueeze(0)


def get_features(model, image, vgg_layers):
    """Collect intermediate activations of ``model.features`` for selected layers.

    Args:
        model: a network exposing an iterable ``features`` attribute (e.g. VGG).
        image: input tensor fed through the layers in order.
        vgg_layers: dict mapping layer index -> layer name; only these
            activations are kept.

    Returns:
        dict mapping layer name -> activation tensor.
    """
    x = image
    img_features = {}
    for num, layer in enumerate(model.features):
        x = layer(x)
        # `in` on a dict tests keys directly; `.keys()` was redundant.
        if num in vgg_layers:
            img_features[vgg_layers[num]] = x
    return img_features


def calculate_content_loss(content_image_features, gen_image_features, content_layer):
    """MSE between the content image's and generated image's activations at *content_layer*."""
    return F.mse_loss(content_image_features[content_layer],
                      gen_image_features[content_layer])


def gram_matrix(tensor):
    """Gram matrix of a (1, C, H, W) feature map: a (C, C) matrix of channel inner products."""
    _, channels, height, width = tensor.shape
    flat = tensor.reshape(channels, height * width)
    return flat @ flat.T


def calculate_style_loss(style_image_features, gen_image_features, style_layer_weights):
    """Weighted sum over style layers of the MSE between Gram matrices.

    Args:
        style_image_features: dict layer name -> style image activation.
        gen_image_features: dict layer name -> generated image activation.
        style_layer_weights: dict layer name -> scalar weight.

    Returns:
        scalar tensor: sum of weight * MSE(Gram(style), Gram(gen)) per layer.
    """
    style_loss = 0
    for layer, weight in style_layer_weights.items():
        gm_style = gram_matrix(style_image_features[layer])
        gm_gen = gram_matrix(gen_image_features[layer])
        # Removed the dead `c, c = GM_S.shape` unpack: the value was never
        # used, and the self-shadowing assignment read like a bug.
        style_loss += weight * F.mse_loss(gm_style, gm_gen)
    return style_loss


def total_variation_loss(tensor):
    """Anisotropic total variation: L1 norm of differences between neighboring pixels."""
    dh = tensor[:, :, 1:, :] - tensor[:, :, :-1, :]
    dw = tensor[:, :, :, 1:] - tensor[:, :, :, :-1]
    return dh.abs().sum() + dw.abs().sum()


def transfer_style(model, content_image_path, style_image_path, gen_img_folder, img_size,
        lambda_noise, content_weight, style_weight, tv_weight, content_layer, style_layer_weights, vgg_layers, device, learning_rate, num_iterations, writer):
    """Optimize an image so it matches one image's content and another's style.

    Iteratively updates a generated image (initialized as a blend of the
    content image and uniform noise) by gradient descent on a weighted sum of
    content and style losses, logging to TensorBoard and periodically saving
    snapshots to *gen_img_folder*.

    Note: tv_weight is currently unused — the total-variation term is
    commented out of the objective.
    """
    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    # untransform = transforms.Compose([transforms.Normalize((-0.485, -0.456, -0.406), (0.229, 0.224, 0.225))])
    transform = transforms.Compose([transforms.ToTensor()])
    content_image = load_image(content_image_path, img_size, transform)
    style_image = load_image(style_image_path, img_size, transform)
    # rand_like gives noise with the content image's full (1, C, H, W) shape
    # (the old (H, W) noise relied on broadcasting). lambda_noise=1 (the
    # default) starts the optimization from the content image itself.
    noise_image = torch.rand_like(content_image)
    gen_image = lambda_noise*content_image + (1-lambda_noise)*noise_image

    model = model.to(device)
    content_image = content_image.to(device)
    style_image = style_image.to(device)
    gen_image = gen_image.to(device)

    # Save the initial image. clamp keeps pixels in [0, 1]: to_pil_image
    # converts via mul(255).byte(), which wraps on out-of-range values.
    gen_img_path = "{}/{}.jpg".format(gen_img_folder, "000")
    img_to_save = gen_image[0].detach().cpu().clamp(0, 1)
    # VF.to_pil_image(untransform(img_to_save)).save(gen_img_path)
    VF.to_pil_image(img_to_save).save(gen_img_path)

    # Replace all max pooling with average pooling (gives smoother gradients
    # for style transfer, per Gatys et al.).
    for i, layer in enumerate(model.features):
        if isinstance(layer, torch.nn.MaxPool2d):
            model.features[i] = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    # Freeze the network: only the generated image is optimized.
    for param in model.parameters():
        param.requires_grad = False

    gen_image.requires_grad = True
    optimizer = optim.Adam(params=[gen_image], lr=learning_rate)

    # Content and style features are loop-invariant (model frozen, images
    # fixed), so compute them once instead of on every iteration.
    content_image_features = get_features(model, content_image, vgg_layers)
    style_image_features = get_features(model, style_image, vgg_layers)

    for itr in range(num_iterations):
        optimizer.zero_grad()
        gen_image_features = get_features(model, gen_image, vgg_layers)

        content_loss = calculate_content_loss(content_image_features, gen_image_features, content_layer)
        style_loss = calculate_style_loss(style_image_features, gen_image_features, style_layer_weights)
        # tv_loss = total_variation_loss(gen_image)
        # total_loss = content_weight*content_loss + style_weight*style_loss + tv_weight*tv_loss
        total_loss = content_weight*content_loss + style_weight*style_loss

        # print("Iteration = {}: Content Loss = {}, Style Loss = {}, TV Loss = {}, Total Loss = {}".format(itr, content_loss, style_loss, tv_loss, total_loss))
        print("Iteration = {}: Content Loss = {}, Style Loss = {}, Total Loss = {}".format(itr, content_loss, style_loss, total_loss))

        writer.add_scalar("Content Loss", content_loss, itr)
        writer.add_scalar("Style Loss", style_loss, itr)
        # writer.add_scalar("TV Loss", tv_loss, itr)
        writer.add_scalar("Total Loss", total_loss, itr)

        # The graph is rebuilt each iteration, so retain_graph is unnecessary.
        total_loss.backward()
        optimizer.step()

        if itr%10==0:
            gen_img_path = "{}/{}.jpg".format(gen_img_folder, int(time.time()))
            # detach() is required here: gen_image requires grad, and
            # to_pil_image converts through numpy, which rejects grad tensors.
            img_to_save = gen_image[0].detach().cpu().clamp(0, 1)
            # VF.to_pil_image(untransform(img_to_save)).save(gen_img_path)
            VF.to_pil_image(img_to_save).save(gen_img_path)


def main():
    """Parse CLI arguments, build the requested VGG model, and run style transfer."""
    args = parser.parse_args()
    content_weight = args.content_weight
    style_weight = args.style_weight
    tv_weight = args.tv_weight
    # Cast numeric options defensively: without `type=` on the argparse
    # declarations, command-line values arrive as strings. The casts are
    # idempotent when the values are already numeric.
    lambda_noise = float(args.lambda_noise)
    vgg_model = args.vgg_model
    gen_img_folder = args.gen_img_folder
    content_image_path = args.content_image_path
    style_image_path = args.style_image_path
    img_size = int(args.img_size)
    log_dir = args.log_dir
    learning_rate = float(args.learning_rate)
    num_iterations = int(args.num_iterations)

    # Map to constructors, not instances: the original dict eagerly built
    # BOTH networks, downloading both pretrained checkpoints on every run.
    model_builders = {"vgg16": models.vgg16,
                      "vgg19": models.vgg19
                      }
    model = model_builders[vgg_model](pretrained=True)
    os.makedirs(gen_img_folder, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: {}".format(device))

    # Indices into model.features mapped to conventional layer names.
    # NOTE(review): these indices match vgg19's layer numbering; verify
    # before using --vgg_model vgg16, whose features are indexed differently.
    vgg_layers = {0: 'conv1_1', 5: 'conv2_1', 10: 'conv3_1', 19: 'conv4_1', 21: 'conv4_2', \
                    28: 'conv5_1', 30: 'conv5_2'} # 30 is content layer

    content_layer = "conv5_2"
    # style_layer_weights = {'conv1_1': 0.75, 'conv2_1': 0.5, 'conv3_1': 0.2, 'conv4_1': 0.2, 'conv5_1': 0.2}
    style_layer_weights = {'conv1_1': 0.2, 'conv2_1': 0.2, 'conv3_1': 0.2, 'conv4_1': 0.2, 'conv5_1': 0.2}

    transfer_style(model, content_image_path, style_image_path, gen_img_folder, img_size, \
        lambda_noise, content_weight, style_weight, tv_weight, content_layer, style_layer_weights, vgg_layers, device, learning_rate, num_iterations, writer)


# Script entry point: run only when executed directly, not when imported.
if __name__=="__main__":
    main()