class Generator(nn.Module):
def __init__(self, dropout_p=0.4):
super(Generator, self).__init__()
self.dropout_p = dropout_p
# Load a U-Net with a ResNet-34 encoder from the SMP library, pre-trained on ImageNet
self.unet = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet",
in_channels=1, classes=3, activation=None)
# Add two Dropout layers, since the original U-Net has none
# They are used to feed noise into the network during both training and evaluation
# The extra layers are added to the decoder (upsampling) blocks
for idx in range(1, 3):
self.unet.decoder.blocks[idx].conv1.add_module('3', nn.Dropout2d(p=self.dropout_p))
# Disable in-place ReLU to avoid in-place operations, which would
# cause issues when backpropagating twice through the same graph
for module in self.unet.modules():
if isinstance(module, nn.ReLU):
module.inplace = False
def forward(self, x):
x = self.unet(x)
x = F.relu(x)
return x
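Since model.eval() normally disables dropout, the noise-at-inference behaviour described above only holds if the injected Dropout2d layers are kept in train mode. Below is a minimal sketch of one way to do that; the helper name and input size are illustrative, not part of the original code.
def enable_inference_dropout(model):
    # Hypothetical helper: after model.eval(), switch only the Dropout2d layers
    # back to train mode so they keep sampling noise during inference
    for module in model.modules():
        if isinstance(module, nn.Dropout2d):
            module.train()
# generator = Generator(dropout_p=0.4)
# generator.eval()
# enable_inference_dropout(generator)
# with torch.no_grad():
#     sample = generator(torch.rand(1, 1, 256, 256))  # assumed 1x256x256 grayscale input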
class Pix2Pix(pl.LightningModule):
def __init__(self, generator_dropout_p=0.4, discriminator_dropout_p=0.4, generator_lr=1e-3, discriminator_lr=1e-6,
weight_decay=1e-5, lr_scheduler_T_0=1000, lr_scheduler_T_mult=2):
super(Pix2Pix, self).__init__()
self.save_hyperparameters()
# Automatic optimization must be disabled because the two
# optimizers are stepped manually in training_step
self.automatic_optimization = False
self.generator_lr = generator_lr # Generator learning rate
self.discriminator_lr = discriminator_lr # Discriminator learning rate
self.weight_decay = weight_decay # Weight decay e.g. L2 regularization
self.lr_scheduler_T_0 = lr_scheduler_T_0 # Number of steps before the first learning rate restart
self.lr_scheduler_T_mult = lr_scheduler_T_mult # Factor by which the restart period grows after each restart
# Models
self.generator = Generator(dropout_p=generator_dropout_p)
self.discriminator = Discriminator(dropout_p=discriminator_dropout_p)
def forward(self, x):
return self.generator(x)
def generator_loss(self, prediction_image, target_image, prediction_label, target_label):
"""
Generator loss (a combination of):
1 - Binary Cross-Entropy
Between predicted labels (generated by the discriminator) and target labels, which are all 1s
2 - L1 / Mean Absolute Error (weighted by lambda)
Between generated image and target image
3 - L2 / Mean Squared Error (weighted by lambda)
Between generated image and target image
"""
bce_loss = F.binary_cross_entropy(prediction_label, target_label)
l1_loss = F.l1_loss(prediction_image, target_image)
mse_loss = F.mse_loss(prediction_image, target_image)
return bce_loss, l1_loss, mse_loss
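Putting the three terms together, training_step below uses bce + LAMBDA * l1 + LAMBDA * mse as the full generator objective. A minimal numeric sketch with dummy tensors; the shapes and the LAMBDA value are assumptions:
# Illustrative sanity check only: shapes and LAMBDA are placeholders
fake = torch.rand(2, 3, 64, 64)    # generated image
real = torch.rand(2, 3, 64, 64)    # target image
d_out = torch.rand(2, 1, 8, 8)     # discriminator output for the fake pair
bce = F.binary_cross_entropy(d_out, torch.ones_like(d_out))
l1 = F.l1_loss(fake, real)
mse = F.mse_loss(fake, real)
total = bce + 100 * (l1 + mse)     # assuming LAMBDA = 100, as in the original pix2pix paper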
def discriminator_loss(self, prediction_label, target_label):
"""
Discriminator loss:
1 - Binary Cross-Entropy
Between predicted labels (generated by the discriminator) and target labels
The target is all 0s when the discriminator input is the generated (fake) image
The target is all 1s when the discriminator input is the target (real) image from the dataloader
"""
bce_loss = F.binary_cross_entropy(prediction_label, target_label)
return bce_loss
def configure_optimizers(self):
"""
Use the Adam optimizer for both the generator and the discriminator, with L2 regularization via weight decay
The two optimizers have different initial learning rates
A cosine annealing schedule with warm restarts (SGDR, https://arxiv.org/abs/1608.03983) is attached to both optimizers
"""
# Optimizers
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.generator_lr, weight_decay=self.weight_decay)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.discriminator_lr, weight_decay=self.weight_decay)
# Learning Scheduler
generator_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0=self.lr_scheduler_T_0, T_mult=self.lr_scheduler_T_mult)
discriminator_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0=self.lr_scheduler_T_0, T_mult=self.lr_scheduler_T_mult)
return [generator_optimizer, discriminator_optimizer], [generator_lr_scheduler, discriminator_lr_scheduler]
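With the values used here (T_0 = 1000 steps, T_mult = 2), the warm restarts fall after 1000, 3000, 7000, 15000, ... scheduler steps, so each cosine cycle is twice as long as the previous one. A small sketch of that cadence:
# Illustrative: restart points of CosineAnnealingWarmRestarts for T_0=1000, T_mult=2
T_0, T_mult = 1000, 2
restarts, period, step = [], T_0, 0
for _ in range(4):
    step += period
    restarts.append(step)
    period *= T_mult
print(restarts)  # [1000, 3000, 7000, 15000]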
def training_step(self, batch, batch_idx):
# Optimizers
generator_optimizer, discriminator_optimizer = self.optimizers()
generator_lr_scheduler, discriminator_lr_scheduler = self.lr_schedulers()
image, target = batch
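# Split the batch in half: the first half is used for the discriminator update, the second half for the generator update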
image_i, image_j = torch.split(image, TRAIN_BATCH_SIZE // 2)
target_i, target_j = torch.split(target, TRAIN_BATCH_SIZE // 2)
######################################
# Discriminator Loss and Optimizer #
######################################
# Generator Feed-Forward
generator_prediction = self.forward(image_i)
generator_prediction = torch.clip(generator_prediction, 0, 1)
# Discriminator Feed-Forward
# The generator output is detached so the discriminator update does not backpropagate through the generator
discriminator_prediction_real = self.discriminator(torch.cat((image_i, target_i), dim=1))
discriminator_prediction_fake = self.discriminator(torch.cat((image_i, generator_prediction.detach()), dim=1))
# Discriminator Loss
discriminator_label_real = self.discriminator_loss(discriminator_prediction_real,
torch.ones_like(discriminator_prediction_real))
discriminator_label_fake = self.discriminator_loss(discriminator_prediction_fake,
torch.zeros_like(discriminator_prediction_fake))
discriminator_loss = discriminator_label_real + discriminator_label_fake
# Discriminator Optimizer
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss)
discriminator_optimizer.step()
discriminator_lr_scheduler.step()
##################################
# Generator Loss and Optimizer #
##################################
# Generator Feed-Forward
generator_prediction = self.forward(image_j)
generator_prediction = torch.clip(generator_prediction, 0, 1)
# Discriminator Feed-Forward
discriminator_prediction_fake = self.discriminator(torch.cat((image_j, generator_prediction), dim=1))
# Generator loss
generator_bce_loss, generator_l1_loss, generator_mse_loss = self.generator_loss(generator_prediction, target_j,
discriminator_prediction_fake,
torch.ones_like(discriminator_prediction_fake))
generator_loss = generator_bce_loss + (generator_l1_loss * LAMBDA) + (generator_mse_loss * LAMBDA)
# Generator Optimizer
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
generator_optimizer.step()
generator_lr_scheduler.step()
# Progressbar and Logging
loss = OrderedDict({'train_g_bce_loss': generator_bce_loss, 'train_g_l1_loss': generator_l1_loss, 'train_g_mse_loss': generator_mse_loss,
'train_g_loss': generator_loss, 'train_d_loss': discriminator_loss,
'train_g_lr': generator_lr_scheduler.get_last_lr()[0], 'train_d_lr': discriminator_lr_scheduler.get_last_lr()[0]})
self.log_dict(loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
image, target = batch
# Generator Feed-Forward
generator_prediction = self.forward(image)
generator_prediction = torch.clip(generator_prediction, 0, 1)
# Generator Metrics
generator_psnr = psnr(generator_prediction, target)
generator_ssim = ssim(generator_prediction, target)
discriminator_prediction_fake = self.discriminator(torch.cat((image, generator_prediction), dim=1))
generator_accuracy = accuracy(discriminator_prediction_fake, torch.ones_like(discriminator_prediction_fake, dtype=torch.int32))
# Discriminator Feed-Forward
discriminator_prediction_real = self.discriminator(torch.cat((image, target), dim=1))
discriminator_prediction_fake = self.discriminator(torch.cat((image, generator_prediction), dim=1))
# Discriminator Metrics
discriminator_accuracy = accuracy(discriminator_prediction_real, torch.ones_like(discriminator_prediction_real, dtype=torch.int32)) * 0.5 + \
accuracy(discriminator_prediction_fake, torch.zeros_like(discriminator_prediction_fake, dtype=torch.int32)) * 0.5
# Progressbar and Logging
metrics = OrderedDict({'val_g_psnr': generator_psnr, 'val_g_ssim': generator_ssim,
'val_g_accuracy': generator_accuracy, 'val_d_accuracy': discriminator_accuracy})
self.log_dict(metrics, prog_bar=True)
return metrics
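For reference, the discriminator is conditioned on the input image: the 1-channel grayscale input and the 3-channel colour image are concatenated along the channel axis, so the Discriminator (not shown here) has to accept 4 input channels. A minimal shape sketch with assumed sizes:
gray = torch.rand(2, 1, 256, 256)    # grayscale edge input (in_channels=1 above)
rgb = torch.rand(2, 3, 256, 256)     # target or generated colour image (classes=3 above)
pair = torch.cat((gray, rgb), dim=1)
print(pair.shape)                    # torch.Size([2, 4, 256, 256])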
Callbacks
class EpochInference(pl.Callback):
"""
Callback executed at the end of every training epoch
It runs inference on the test dataloader using the current model weights
The results are saved as an image with 4 rows:
1 - Input image, e.g. the grayscale edge input
2 - Ground truth
3 - A single inference
4 - The mean of one hundred accumulated inferences
Note that inference includes a noise factor (dropout), so each execution produces a different output
"""
def __init__(self, dataloader, *args, **kwargs):
super(EpochInference, self).__init__(*args, **kwargs)
self.dataloader = dataloader
def on_train_epoch_end(self, trainer, pl_module):
super(EpochInference, self).on_train_epoch_end(trainer, pl_module)
data = next(iter(self.dataloader))
image, target = data
image = image.to(pl_module.device)
target = target.to(pl_module.device)
with torch.no_grad():
# Average multiple inferences, since the random noise makes each one different
# Single
reconstruction_init = pl_module.forward(image)
reconstruction_init = torch.clip(reconstruction_init, 0, 1)
# Mean
reconstruction_mean = torch.stack([pl_module.forward(image) for _ in range(100)])
reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
# Repeat the grayscale channel so the 1-channel input becomes a 3-channel image for the grid
image = torch.stack([image for _ in range(3)], dim=1)
image = torch.squeeze(image)
grid_image = torchvision.utils.make_grid(torch.cat([image, target, reconstruction_init, reconstruction_mean], dim=0), nrow=20)
torchvision.utils.save_image(grid_image, fp=f'{trainer.log_dir}/epoch-{trainer.current_epoch:04}.png')
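A minimal sketch of how the callback could be wired into training; the dataloader names and trainer arguments are placeholders, not taken from the original code:
# Hypothetical usage: train_loader, val_loader and test_loader are placeholder dataloaders
model = Pix2Pix()
trainer = pl.Trainer(max_epochs=50, callbacks=[EpochInference(test_loader)])
trainer.fit(model, train_loader, val_loader)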