GAN Inference
Now that we have learnt how to generate images of specific classes, we shall use the trained generator for inference.
Putting it all Together
from torchfusion.gan.learners import *
from torchfusion.gan.applications import StandardGenerator
import torch.cuda as cuda
import torch.nn as nn
from torchvision.utils import save_image
import torch
from torch.distributions import Normal
# Size of the latent (noise) vector. The generator is built with this value and
# the vector sampled for inference below must have the same width.
LATENT_SIZE = 128

# Build a class-conditional generator with output_size (1, 32, 32) and 10 classes.
G = StandardGenerator(
    output_size=(1, 32, 32),
    latent_size=LATENT_SIZE,
    num_classes=10,
)

# Move the generator to the GPU (wrapped in DataParallel) when CUDA is available.
if cuda.is_available():
    G = nn.DataParallel(G.cuda())

# Only the generator is passed to the learner; the second argument is None
# (presumably the discriminator slot, which inference does not need -- confirm).
learner = RStandardGanLearner(G, None)
learner.load_generator("path-to-trained-gen")

if __name__ == "__main__":
    # Define an instance of the standard normal distribution
    dist = Normal(0, 1)

    # Get a sample latent vector (batch of 1) from the distribution
    latent_vector = dist.sample((1, LATENT_SIZE))

    # Define the class of the image you want to generate (class index 5)
    label = torch.tensor([5], dtype=torch.long)

    # Run inference
    # NOTE(review): the inputs are created on the CPU; confirm that
    # learner.predict moves them to the generator's device when CUDA is used.
    image = learner.predict([latent_vector, label])

    # Save generated image
    # NOTE(review): if the generator emits values outside [0, 1], consider
    # save_image(..., normalize=True) -- confirm against the generator's output range.
    save_image(image, "image.jpg")