import torch.nn as nn
from torch.nn import functional as F
import torch
import pytorch_lightning as pl
from rmi import RMILoss
from torchvision.utils import make_grid
import numpy as np
import segmentation_models_pytorch as smp
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator


class Generator(nn.Module):
    """U-Net generator mapping a 2-channel PLI image to a 1-channel cyto image.

    The U-Net ends in a sigmoid activation, so ``forward`` returns
    probabilities rather than logits.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the generator's CLI options on *parent_parser* and return it.

        NOTE: ``Discriminator`` and ``GAN`` register the same ``--learning_rate``
        and ``--name`` option strings. With argparse's default
        ``conflict_handler='error'`` only one of these helpers can be applied
        to a given parser.
        """
        parser = parent_parser.add_argument_group("Generator")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='Generator')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build an instance from a parsed ``argparse.Namespace``."""
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='Generator', depth=3, **kwargs):
        """Create the U-Net and its RMI loss.

        Args:
            learning_rate: stored on the instance; not used inside this class.
            name: stored on the instance.
            depth: accepted for CLI compatibility; currently unused.
            **kwargs: ignored, so ``from_argparse_args`` can pass every option.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.name = name

        self.g_model = smp.Unet(
            encoder_name="resnet18",  # TODO: also try smaller or larger encoders
            encoder_weights="imagenet",  # TODO: compare with/without pretrained weights
            in_channels=2,  # two-channel PLI input
            classes=1,  # classes == output channels; one cyto output channel
            activation="sigmoid",
        )
        # The U-Net output has already been through a sigmoid, so the loss must
        # be told it receives probabilities. with_logits=True would make RMILoss
        # treat them as logits and squash them through a second sigmoid.
        self.loss_f = RMILoss(with_logits=False)

    def forward(self, x):
        """Return the generated cyto image (sigmoid output) for input *x*."""
        return self.g_model(x)

    def training_step(self, batch, batch_idx):
        """Compute the RMI loss for one batch.

        *batch* is a dict ``{'pli_image': ..., 'cyto_image': ...}``. Returns
        ``(loss, generated_image)``. ``GAN`` below does not call this method;
        it runs the generator and ``loss_f`` itself.
        """
        cyto_imag_generated = self.forward(batch['pli_image'])
        loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
        return loss, cyto_imag_generated


class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator producing one score per image."""

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the discriminator's CLI options on *parent_parser* and return it.

        NOTE: these option strings collide with those of ``Generator`` and
        ``GAN`` if registered on the same parser (argparse raises on duplicates
        by default).
        """
        parser = parent_parser.add_argument_group("discriminator")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='discriminator')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build an instance from a parsed ``argparse.Namespace``."""
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='discriminator', **kwargs):
        """Create the DCGAN discriminator for 1-channel images.

        Args:
            learning_rate: stored on the instance; not used inside this class.
            name: stored on the instance.
            **kwargs: ignored, so ``from_argparse_args`` can pass every option.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.name = name
        self.d_model = DCGANDiscriminator(feature_maps=16, image_channels=1)

    def forward(self, x):
        """Return discriminator scores of shape ``(batch, -1)`` for images *x*.

        The convolutional stack ``d_model.disc`` is run directly and its
        spatial output is averaged to 1x1, so each image yields one value per
        output channel regardless of the input's spatial size.
        """
        features = self.d_model.disc(x)
        # NOTE(review): GAN.adversarial_loss feeds this to
        # F.binary_cross_entropy, which requires values in [0, 1]; presumably
        # `disc` ends in a sigmoid -- confirm against pl_bolts.
        return F.adaptive_avg_pool2d(features, (1, 1)).view(features.size(0), -1)


class GAN(pl.LightningModule):
    """Image-to-image GAN: a U-Net generator trained with adversarial + RMI loss.

    Uses two optimizers (index 0: generator, index 1: discriminator) with the
    ``optimizer_idx`` form of ``training_step``.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the GAN's CLI options on *parent_parser* and return it.

        NOTE: these option strings collide with those of ``Generator`` and
        ``Discriminator`` if registered on the same parser.
        """
        parser = parent_parser.add_argument_group("TestModule")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='GAN')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build an instance from a parsed ``argparse.Namespace``."""
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='model', depth=3, lamda_G=0.2, **kwargs):
        """Create the generator/discriminator pair.

        Args:
            learning_rate: Adam learning rate for both optimizers.
            name: stored on the instance.
            depth: accepted for CLI compatibility; currently unused.
            lamda_G: weight of the RMI term in the generator's total loss.
            **kwargs: recorded by ``save_hyperparameters``; otherwise ignored.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.name = name
        self.generator = Generator()
        # Discriminator's first positional parameter is learning_rate, so no
        # image shape is passed here (passing one would be stored as the rate).
        self.discriminator = Discriminator()
        # Most recent generator output; set by training_step / validation_step.
        self.cyto_imag_generated = None
        self.lamda_G = lamda_G

    def forward(self, x):
        """Return the generator's output for input *x*."""
        return self.generator(x)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator scores and targets."""
        return F.binary_cross_entropy(y_hat, y)

    @staticmethod
    def _constant_targets(reference, value):
        """Return a ``(batch, 1)`` tensor of *value* with *reference*'s dtype and device."""
        # Built with the reference's device so it lives on the GPU when training does.
        return torch.full((reference.size(0), 1), value).type_as(reference)

    def _generator_losses(self, batch):
        """Run the generator on *batch*; return (generated, adversarial loss, RMI loss)."""
        pli_image = batch['pli_image']
        generated = self.generator(pli_image)
        loss_rmi = self.generator.loss_f(generated, batch['cyto_image'])
        # The generator is rewarded when the discriminator labels its output
        # as real, so the adversarial target is all ones ("valid").
        valid = self._constant_targets(pli_image, 1.0)
        g_loss = self.adversarial_loss(self.discriminator(generated), valid)
        return generated, g_loss, loss_rmi

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Generator step for ``optimizer_idx == 0``, discriminator step for ``1``."""
        if optimizer_idx == 0:
            self.cyto_imag_generated, g_loss, loss_rmi = self._generator_losses(batch)
            total_loss = g_loss + self.lamda_G * loss_rmi
            self.log('train_g_loss', g_loss)
            self.log('train_g_RMI_loss', loss_rmi)
            self.log('train_g_Total_loss', total_loss)
            return total_loss

        if optimizer_idx == 1:
            pli_image = batch['pli_image']
            # How well can it label real images as real?
            valid = self._constant_targets(pli_image, 1.0)
            real_loss = self.adversarial_loss(self.discriminator(batch['cyto_image']), valid)
            # How well can it label generated images as fake? detach() keeps
            # this loss from backpropagating into the generator.
            fake = self._constant_targets(pli_image, 0.0)
            generated = self.generator(pli_image).detach()
            fake_loss = self.adversarial_loss(self.discriminator(generated), fake)
            d_loss = (real_loss + fake_loss) / 2
            self.log('train_d_loss', d_loss)
            return d_loss

    def validation_step(self, batch, batch_idx):
        """Log the generator losses and, for the first batch, an image grid."""
        self.cyto_imag_generated, g_loss, loss_rmi = self._generator_losses(batch)
        self.log('val_g_loss', g_loss)
        self.log('val_g_RMI_loss', loss_rmi)
        self.log('val_g_Total_loss', g_loss + self.lamda_G * loss_rmi)

        if batch_idx == 0:
            pli_image = batch['pli_image']
            # Grid order: every PLI channel 0, every PLI channel 1, every target
            # cyto image, every generated cyto image.
            tiles = (
                list(pli_image[:, :1])
                + list(pli_image[:, 1:])
                + list(batch['cyto_image'])
                + list(self.cyto_imag_generated)
            )
            grid = make_grid(tiles)
            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")

    def configure_optimizers(self):
        """Return one Adam optimizer per network (generator first) and no schedulers."""
        optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate)
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate)
        return [optimizer_g, optimizer_d], []