Buiding Custom Trainers!ΒΆ
Torchfusion provides a wide variety of GAN Learners, you will find them in the torchfusion.gan.learners package However, lots of research is ongoing into improved techniques for GANs, hence, we provide multiple levels of abstractions to faciliate research.
Custom Loss
#Extend the StandardBaseGanLearner
class CustomGanLearner(StandardBaseGanLearner):
#Override the __update_discriminator_loss__
def __update_discriminator_loss__(self, real_images, gen_images, real_preds, gen_preds):
pred_loss = -torch.mean(real_preds - gen_preds)
return pred_loss
#Override the __update_generator_loss__
def __update_generator_loss__(self,real_images,gen_images,real_preds,gen_preds):
pred_loss = -torch.mean(gen_preds - real_preds)
return pred_loss
Custom Training Logic
#Extend BaseGanCore
class CustomGanLearner(BaseGanCore):
#Extend train
def train(self,train_loader, gen_optimizer,disc_optimizer,latent_size,loss_fn=nn.BCELoss(),**kwargs):
self.latent_size = latent_size
self.loss_fn = loss_fn
super().__train_loop__(train_loader,gen_optimizer,disc_optimizer,**kwargs)
#Extend __disc_train_func__
def __disc_train_func__(self, data):
super().__disc_train_func__(data)
self.disc_optimizer.zero_grad()
if isinstance(data, list) or isinstance(data, tuple):
x = data[0]
else:
x = data
batch_size = x.size(0)
source = self.dist.sample((batch_size,self.latent_size))
real_labels = torch.ones(batch_size,1)
fake_labels = torch.zeros(batch_size,1)
if self.cuda:
x = x.cuda()
source = source.cuda()
real_labels = real_labels.cuda()
fake_labels = fake_labels.cuda()
x = Variable(x)
source = Variable(source)
outputs = self.disc_model(x)
generated = self.gen_model(source)
gen_outputs = self.disc_model(generated.detach())
gen_loss = self.loss_fn(gen_outputs,fake_labels)
real_loss = self.loss_fn(outputs,real_labels)
loss = gen_loss + real_loss
loss.backward()
self.disc_optimizer.step()
self.disc_running_loss.add_(loss.cpu() * batch_size)
#Extend __gen_train_func__
def __gen_train_func__(self, data):
super().__gen_train_func__(data)
self.gen_optimizer.zero_grad()
if isinstance(data, list) or isinstance(data, tuple):
x = data[0]
else:
x = data
batch_size = x.size(0)
source = self.dist.sample((batch_size,self.latent_size))
real_labels = torch.ones(batch_size,1)
if self.cuda:
source = source.cuda()
real_labels = real_labels.cuda()
source = Variable(source)
fake_images = self.gen_model(source)
outputs = self.disc_model(fake_images)
loss = self.loss_fn(outputs,real_labels)
loss.backward()
self.gen_optimizer.step()
self.gen_running_loss.add_(loss.cpu() * batch_size)