Skip to content
Snippets Groups Projects
model_GAN.py 8.65 KiB
Newer Older
  • Learn to ignore specific revisions
  • import torch.nn as nn
    from torch.nn import functional as F
    import torch
    import pytorch_lightning as pl
    from rmi import RMILoss
    from torchvision.utils import make_grid
    import numpy as np
    
    import segmentation_models_pytorch as smp
    
    from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator
    
    class Generator(nn.Module):
    
        @staticmethod
        def add_model_specific_args(parent_parser):
            parser = parent_parser.add_argument_group("Generator")
            parser.add_argument('--learning_rate', type=float, default=1e-3)
            parser.add_argument('--name', type=str, default='Generator')
            return parent_parser
    
        @classmethod
        def from_argparse_args(cls, args):
            dict_args = vars(args)
            return cls(**dict_args)
    
        def __init__(self, learning_rate=1e-3, name='Generator', depth=3, **kwargs):
            super().__init__()
            #self.save_hyperparameters()
            self.learning_rate = learning_rate
            self.name = name
    
            # Define the model
            self.g_model = smp.Unet(
                encoder_name="resnet18",  # Also consider using smaller or larger encoders
                encoder_weights= "imagenet",  # Do the pretrained weights help? Try with or without
    
    s.islam's avatar
    s.islam committed
                in_channels=1,  # We use 1 chanel transmittance as input
    
                classes=1,  # classes == output channels. We use one output channel for cyto data
                activation="sigmoid"
            )
            self.loss_f = RMILoss(with_logits=True) #torch.nn.MSELoss()
    
        def forward(self, x):
            x = self.g_model(x)
            return x
    
        def training_step(self, batch, batch_idx):
            # Structure of the batch: {'pli_image': pli_image, 'cyto_image': stained_image}
            cyto_imag_generated = self.forward(batch['pli_image'])
            loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
            #self.log('train_loss', loss)
            return loss, cyto_imag_generated
    '''
        def validation_step(self, batch, batch_idx):
            # Do it the same as in training
            cyto_imag_generated = self.forward(batch['pli_image'])
            loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
            self.log("val_loss", loss)
            batch['pli_image'] = batch['pli_image']
            batch['cyto_image'] = batch['cyto_image']
            cyto_imag_generated = cyto_imag_generated
            if batch_idx == 0:
                grid = make_grid([batch['pli_image'][0, :1], batch['pli_image'][0, 1:], batch['cyto_image'][0], cyto_imag_generated[0]])
                self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")
    '''
    
    '''
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
            return optimizer
    '''
    
    
    class Discriminator_2(nn.Module):
        def __init__(self, img_shape):
            super().__init__()
    
            self.model = nn.Sequential(
                nn.Linear(int(np.prod(img_shape)), 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            )
    
        def forward(self, img):
            img_flat = img.view(img.size(0), -1)
            validity = self.model(img_flat)
    
            return validity
    
    
    s.islam's avatar
    s.islam committed
    #discriminator
    class Discriminator(nn.Module):
    
    
        @staticmethod
        def add_model_specific_args(parent_parser):
    
    s.islam's avatar
    s.islam committed
            parser = parent_parser.add_argument_group("discriminator")
    
            parser.add_argument('--learning_rate', type=float, default=1e-3)
    
    s.islam's avatar
    s.islam committed
            parser.add_argument('--name', type=str, default='discriminator')
    
            return parent_parser
    
        @classmethod
        def from_argparse_args(cls, args):
            dict_args = vars(args)
            return cls(**dict_args)
    
    
    s.islam's avatar
    s.islam committed
        def __init__(self, learning_rate=1e-3, name='discriminator', **kwargs):
    
            super().__init__()
            #self.save_hyperparameters()
            self.learning_rate = learning_rate
            self.name = name
            self.d_model = DCGANDiscriminator(
    
            image_channels = 1
            )
    
        def forward(self, x):
    
            #x = self.d_model(x)
            #self.disc(x)
            x = self.d_model.disc(x)
            #out = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()
            out = F.adaptive_avg_pool2d(x, (1, 1)).view((x.size(0), -1))
    
            return out
    
    
    #GAN
    class GAN(pl.LightningModule):
    
        @staticmethod
        def add_model_specific_args(parent_parser):
            parser = parent_parser.add_argument_group("TestModule")
            parser.add_argument('--learning_rate', type=float, default=1e-3)
            parser.add_argument('--name', type=str, default='GAN')
            return parent_parser
    
        @classmethod
        def from_argparse_args(cls, args):
            dict_args = vars(args)
            return cls(**dict_args)
    
    
        def __init__(self, learning_rate=1e-3, name='model', depth=3, lamda_G = 0, **kwargs):
    
            super().__init__()
            self.save_hyperparameters()
            self.learning_rate = learning_rate
            self.name = name
            self.generator = Generator()
    
            self.discriminator = Discriminator((256,256))
    
            self.cyto_imag_generated = None
            self.lamda_G = lamda_G
    
        def forward(self, x):
            result = self.generator(x)
            return result
    
        def adversarial_loss(self, y_hat, y):
            return F.binary_cross_entropy(y_hat, y)
    
    
        def training_step(self, batch, batch_idx, optimizer_idx):
    
            #Training the Generator.
            if optimizer_idx == 0:
                self.cyto_imag_generated = self(batch['pli_image'])
    
                loss_RMI = self.generator.loss_f(self.cyto_imag_generated, batch['cyto_image'])
    
                # ground truth result (ie: all fake)
                # put on GPU because we created this tensor inside training_loop
                valid = torch.ones(batch['pli_image'].size(0), 1)
                valid = valid.type_as(batch['pli_image'])
    
    
    s.islam's avatar
    s.islam committed
                g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
    
                self.log('train_g_loss', g_loss)
                self.log('train_g_RMI_loss', loss_RMI)
                self.log('train_g_Total_loss', g_loss + self.lamda_G * loss_RMI)
                return g_loss + self.lamda_G * loss_RMI
    
    
    s.islam's avatar
    s.islam committed
            #Training discriminator.
    
            if optimizer_idx == 1:
                # how well can it label as real?
                valid = torch.ones(batch['pli_image'].size(0), 1)
                valid = valid.type_as(batch['pli_image'])
    
    
    s.islam's avatar
    s.islam committed
                real_loss = self.adversarial_loss(self.discriminator(batch['cyto_image']), valid)
    
    
                # how well can it label as fake?
                fake = torch.zeros(batch['pli_image'].size(0), 1)
                fake = fake.type_as(batch['pli_image'])
    
    
                fake_loss = self.adversarial_loss(self.discriminator(self.generator(batch['pli_image']).detach()), fake)
    
    
                d_loss = (real_loss + fake_loss) / 2
                self.log('train_d_loss', d_loss)
                return d_loss
    
    
    s.islam's avatar
    s.islam committed
    
    
    
    
        def validation_step(self, batch, batch_idx):
            # Validating the Generator.
    
    s.islam's avatar
    s.islam committed
            self.cyto_imag_generated = self.generator(batch['pli_image'])
    
    s.islam's avatar
    s.islam committed
            loss_RMI = self.generator.loss_f(self.cyto_imag_generated, batch['cyto_image'])
    
            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
    
    s.islam's avatar
    s.islam committed
            valid = torch.ones(batch['pli_image'].size(0), 1)
            valid = valid.type_as(batch['pli_image'])
    
    s.islam's avatar
    s.islam committed
    
    
    s.islam's avatar
    s.islam committed
            g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
    
    s.islam's avatar
    s.islam committed
            #self.log('val_g_loss', g_loss)
            #self.log('val_g_RMI_loss', loss_RMI)
    
    s.islam's avatar
    s.islam committed
            self.log('val_g_Total_loss', g_loss + self.lamda_G * loss_RMI)
    
    s.islam's avatar
    s.islam committed
            if batch_idx == 0:
    
                #grid = make_grid([batch['pli_image'][0], batch['cyto_image'][0], self.cyto_imag_generated[0]])
                grid = make_grid(list(batch['pli_image']) + list(batch['cyto_image']) + list(self.cyto_imag_generated))
    
    s.islam's avatar
    s.islam committed
                self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")
    
    
    
    
    
    
    
        def configure_optimizers(self):
            optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate)
    
    s.islam's avatar
    s.islam committed
            optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate)
    
            return [optimizer_g, optimizer_d], []
    '''
        def validation_step(self, batch, batch_idx, optimizer_idx):
            # Training the Generator.
            if optimizer_idx == 0:
                self.cyto_imag_generated = self.forward(batch['pli_image'])
    
    s.islam's avatar
    s.islam committed
                g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), batch['cyto_image'])
    
                self.log('val_g_loss', g_loss)
                return g_loss
    
    
    s.islam's avatar
    s.islam committed
            # Training discriminator.
    
            if optimizer_idx == 1:
                d_loss = self.adversarial_loss(self.cyto_imag_generated, batch['cyto_image'])
                self.log('val_d_loss', d_loss)
                return d_loss
    '''
    
    
    
    
    
    
    ''''''