Introduction to Generative Adversarial Networks

Classification and regression models are used for predictive tasks: they map diverse inputs to a fixed set of outputs. This class of models is called discriminative models. Generative models do the opposite: they generate diverse outputs from fixed inputs. An example is a model that can generate new pictures of cars from nothing more than a text description.

Different generative models exist; among the most successful are Generative Adversarial Networks (GANs), introduced by Goodfellow et al., 2014. A GAN consists of a generator model, which is responsible for generating new outputs, and a discriminator model, which attempts to tell whether the outputs it sees are real or fake. During training, the discriminator is presented with both real and generated images. The discriminator is trained to correctly tell the real images apart from the generated ones, while the generator is trained to generate images that are so realistic that the discriminator will classify them as real. Hence, the two networks compete with each other, with the generator trying to fool the discriminator.

While the logic of GANs can be slightly complicated, TorchFusion makes using them a breeze and provides a highly sophisticated framework for doing research with custom GAN logic.
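The competition between the two networks can be made concrete with a small sketch of the textbook GAN objective, written in plain Python on scalar discriminator scores (logits). This is only an illustration under stated assumptions: the function names are hypothetical, the generator loss is the commonly used non-saturating form, and none of it is claimed to be what TorchFusion's learners compute internally.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def discriminator_loss(real_logits, fake_logits):
    # -log(sigmoid(s)) for real samples: small when real images score high.
    real_term = sum(softplus(-s) for s in real_logits) / len(real_logits)
    # -log(1 - sigmoid(s)) for fakes: small when generated images score low.
    fake_term = sum(softplus(s) for s in fake_logits) / len(fake_logits)
    return real_term + fake_term


def generator_loss(fake_logits):
    # Non-saturating form: small when the discriminator scores fakes as real.
    return sum(softplus(-s) for s in fake_logits) / len(fake_logits)


# Hypothetical scores: the discriminator is fairly sure about both batches.
real_scores = [2.0, 3.0]
fake_scores = [-2.5, -1.5]
print("discriminator loss:", discriminator_loss(real_scores, fake_scores))
print("generator loss:", generator_loss(fake_scores))
```

With these made-up scores the discriminator loss is low and the generator loss is high, which is the situation the generator's updates try to reverse. The two losses pull the scores of generated images in opposite directions, and that tension is what "adversarial" refers to.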

Below are two pictures generated by a GAN (Karras et al., 2017).

../_images/gan.png

UNCONDITIONAL GAN EXAMPLE

Earlier on, we learnt to correctly classify grayscale fashion images. Now we shall attempt to generate them instead.

Step 1: Imports!

from torchfusion.gan.learners import RStandardGanLearner
from torchfusion.gan.applications import StandardGenerator,StandardProjectionDiscriminator
from torch.optim import Adam
from torchfusion.datasets import fashionmnist_loader
import torch.cuda as cuda
import torch.nn as nn

Define Generator and Discriminator

G = StandardGenerator(output_size=(1,32,32),latent_size=128)
D = StandardProjectionDiscriminator(input_size=(1,32,32),apply_sigmoid=False)

if cuda.is_available():
    G = nn.DataParallel(G.cuda())
    D = nn.DataParallel(D.cuda())

Here, we use the predefined generator and discriminator in TorchFusion. We set the size of the generated images to (1,32,32) and the latent_size to 128. The images will be generated from a latent code of size 128.
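To make the role of the latent code concrete, the sketch below draws a batch of latent vectors of size 128 in plain Python. It assumes the latent code is sampled from a standard normal distribution, which is a common convention but is an assumption here, and sample_latent_batch is a hypothetical helper rather than a TorchFusion function. In real training the same thing would be done with a tensor library.

```python
import random


def sample_latent_batch(batch_size, latent_size=128, seed=None):
    # One row per image to generate; each row is a latent code whose
    # entries are drawn from a standard normal (assumed distribution).
    rng = random.Random(seed)
    return [
        [rng.gauss(0.0, 1.0) for _ in range(latent_size)]
        for _ in range(batch_size)
    ]


batch = sample_latent_batch(batch_size=64, latent_size=128, seed=0)
print(len(batch), len(batch[0]))  # prints: 64 128
```

A generator configured as above would map each of these 64 vectors to one image of shape (1,32,32), so the latent size passed at training time has to match the latent_size the generator was built with.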

Setup optimizers

g_optim = Adam(G.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optim = Adam(D.parameters(),lr=0.0002,betas=(0.5,0.999))

Since the generator and discriminator are trained separately, we need to specify a separate optimizer for each. Try to stick to the hyper-parameters used here, as GANs can be very sensitive to these values.

Load dataset

dataset = fashionmnist_loader(size=32,batch_size=64)

The image size here is set to be the same as the size of the images to be generated.

Define the learner

learner = RStandardGanLearner(G,D)

The learner does all the heavy lifting.

Train the Models

if __name__ == "__main__":
    learner.train(dataset,gen_optimizer=g_optim,disc_optimizer=d_optim,save_outputs_interval=500,model_dir="./fashion-gan",latent_size=128,num_epochs=50,batch_log=False)

By specifying save_outputs_interval as 500, the learner outputs a sample of generated images every 500 batch iterations. Note that this interval is counted in batches, not in epochs.
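The difference between batch iterations and epochs is easy to work out by hand. The sketch below does the arithmetic for the settings used above, assuming the loader serves the 60,000-image Fashion-MNIST training split, keeps the final partial batch, and triggers an output whenever the iteration count reaches a multiple of the interval. These are assumptions about the loader and the learner, and training_schedule is a hypothetical helper.

```python
import math


def training_schedule(num_samples, batch_size, num_epochs, save_interval):
    # Number of batches needed to see every sample once (last batch may be partial).
    batches_per_epoch = math.ceil(num_samples / batch_size)
    # Total batch iterations over the whole run.
    total_iterations = batches_per_epoch * num_epochs
    # How many times a multiple of save_interval is reached.
    num_outputs = total_iterations // save_interval
    return batches_per_epoch, total_iterations, num_outputs


print(training_schedule(60000, 64, 50, 500))  # prints: (938, 46900, 93)
```

Under these assumptions one epoch is 938 iterations, so an interval of 500 yields a sample roughly twice per epoch, or about 93 sets of samples over the 50 epochs.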

Putting it all Together

from torchfusion.gan.learners import RStandardGanLearner
from torchfusion.gan.applications import StandardGenerator,StandardProjectionDiscriminator
from torch.optim import Adam
from torchfusion.datasets import fashionmnist_loader
import torch.cuda as cuda
import torch.nn as nn

G = StandardGenerator(output_size=(1,32,32),latent_size=128)
D = StandardProjectionDiscriminator(input_size=(1,32,32),apply_sigmoid=False)

if cuda.is_available():
    G = nn.DataParallel(G.cuda())
    D = nn.DataParallel(D.cuda())

g_optim = Adam(G.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optim = Adam(D.parameters(),lr=0.0002,betas=(0.5,0.999))

dataset = fashionmnist_loader(size=32,batch_size=64)

learner = RStandardGanLearner(G,D)

if __name__ == "__main__":
    learner.train(dataset,gen_optimizer=g_optim,disc_optimizer=d_optim,save_outputs_interval=500,model_dir="./fashion-gan",latent_size=128,num_epochs=50,batch_log=False)