# model_GAN.py
import torch.nn as nn
from torch.nn import functional as F
import torch
import pytorch_lightning as pl
from rmi import RMILoss
from torchvision.utils import make_grid
import numpy as np

import segmentation_models_pytorch as smp

from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator

class Generator(nn.Module):
    """U-Net generator mapping a 1-channel input image to a 1-channel output.

    Wraps ``segmentation_models_pytorch.Unet`` (ResNet-18 encoder, ImageNet
    weights) whose output passes through a sigmoid. It also owns the RMI
    reconstruction loss (``self.loss_f``) used by ``training_step``.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register Generator CLI options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("Generator")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='Generator')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build a Generator from an ``argparse.Namespace``.

        Keys that ``__init__`` does not name are absorbed by ``**kwargs``.
        """
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='Generator', depth=3, **kwargs):
        """
        Args:
            learning_rate: stored on the instance; not used by this class itself.
            name: human-readable model name.
            depth: accepted for interface compatibility; currently unused.
            **kwargs: ignored, so a full argparse namespace can be passed.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.name = name

        # Define the model
        self.g_model = smp.Unet(
            encoder_name="resnet18",  # Also consider using smaller or larger encoders
            encoder_weights="imagenet",  # Do the pretrained weights help? Try with or without
            in_channels=1,  # We use 1 channel transmittance as input
            classes=1,  # classes == output channels. We use one output channel for cyto data
            activation="sigmoid",
        )
        # The U-Net above already ends in a sigmoid, so the loss must treat
        # its input as probabilities. With with_logits=True the loss would
        # apply a second sigmoid and squash every prediction into (0.5, 0.73).
        self.loss_f = RMILoss(with_logits=False)

    def forward(self, x):
        """Run the U-Net on ``x`` and return its sigmoid output."""
        return self.g_model(x)

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one batch.

        Args:
            batch: mapping with ``'pli_image'`` (input) and ``'cyto_image'``
                (target) tensors.
            batch_idx: index of the batch (unused).

        Returns:
            Tuple ``(loss, generated_image)``.
        """
        cyto_imag_generated = self.forward(batch['pli_image'])
        loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
        return loss, cyto_imag_generated
# NOTE(review): the two string literals below hold disabled Generator methods
# (validation_step, configure_optimizers). They are no-op expressions at module
# level and are never executed; GAN defines the active validation_step and
# configure_optimizers.
'''
    def validation_step(self, batch, batch_idx):
        # Do it the same as in training
        cyto_imag_generated = self.forward(batch['pli_image'])
        loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
        self.log("val_loss", loss)
        batch['pli_image'] = batch['pli_image']
        batch['cyto_image'] = batch['cyto_image']
        cyto_imag_generated = cyto_imag_generated
        if batch_idx == 0:
            grid = make_grid([batch['pli_image'][0, :1], batch['pli_image'][0, 1:], batch['cyto_image'][0], cyto_imag_generated[0]])
            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")
'''

# NOTE(review): disabled; no-op string literal (see note above).
'''
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
'''


class Discriminator_2(nn.Module):
    """Fully connected discriminator over flattened images.

    Args:
        img_shape: shape of one image without the batch dimension; its
            element count is the input width of the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()
        n_inputs = int(np.prod(img_shape))
        widths = (n_inputs, 512, 256)
        layers = []
        # Hidden stack: Linear -> LeakyReLU(0.2) for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Head: a single unit squashed by a sigmoid.
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each sample in ``img`` and return a (batch, 1) sigmoid score."""
        flat = img.view(img.size(0), -1)
        return self.model(flat)

#discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator producing one score vector per image.

    Runs the ``disc`` convolution stack of pl_bolts' ``DCGANDiscriminator``,
    then global-average-pools the resulting map. The output has shape
    ``(batch, C)``, where C is the channel count of the ``disc`` output.
    NOTE(review): GAN feeds this into ``F.binary_cross_entropy``, which
    assumes the ``disc`` stack ends in a sigmoid. Confirm against the
    installed pl_bolts.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register discriminator CLI options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("discriminator")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='discriminator')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build a Discriminator from an ``argparse.Namespace``.

        Keys that ``__init__`` does not name are absorbed by ``**kwargs``.
        """
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='discriminator', feature_maps=64, image_channels=1, **kwargs):
        """
        Args:
            learning_rate: stored on the instance; not used by this class itself.
            name: human-readable model name.
            feature_maps: base width of the DCGAN conv stack.
            image_channels: number of channels of the images being scored.
            **kwargs: ignored, so a full argparse namespace can be passed.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.name = name
        # DCGANDiscriminator's ``feature_maps`` parameter has no default.
        # Omitting it (as ``DCGANDiscriminator(image_channels=1)`` did)
        # raises TypeError at construction.
        self.d_model = DCGANDiscriminator(
            feature_maps=feature_maps,
            image_channels=image_channels,
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch; presumably (batch, image_channels, H, W). TODO confirm.

        Returns:
            Tensor of shape (batch, C), the spatial mean of the ``disc`` output.
        """
        scores = self.d_model.disc(x)
        # Average the spatial score map into one value per sample and channel.
        return F.adaptive_avg_pool2d(scores, (1, 1)).view(scores.size(0), -1)

#GAN
class GAN(pl.LightningModule):
    """Adversarial model translating ``pli_image`` inputs into ``cyto_image`` outputs.

    The generator (``Generator``) maps ``batch['pli_image']`` to a synthetic
    cyto image. The discriminator (``Discriminator``) is trained to label real
    cyto images 1 and generated ones 0. The generator objective is the
    adversarial loss plus ``lamda_G`` times the generator's RMI loss.

    Uses two optimizers (index 0: generator, index 1: discriminator) through
    Lightning's ``optimizer_idx`` multi-optimizer ``training_step`` API, which
    was removed in Lightning 2.0.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register GAN CLI options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("TestModule")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='GAN')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build a GAN from an ``argparse.Namespace``.

        Keys that ``__init__`` does not name are absorbed by ``**kwargs``.
        """
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='model', depth=3, lamda_G=0.1, **kwargs):
        """
        Args:
            learning_rate: Adam learning rate for both optimizers.
            name: human-readable model name.
            depth: accepted for interface compatibility; currently unused.
            lamda_G: weight of the RMI loss in the generator objective.
            **kwargs: ignored, so a full argparse namespace can be passed.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.name = name
        self.generator = Generator()
        # Discriminator takes no image shape: it pools its score map. A
        # positional argument here would be bound to its ``learning_rate``.
        self.discriminator = Discriminator()
        # Most recent generator output (set in training and validation steps).
        self.cyto_imag_generated = None
        self.lamda_G = lamda_G

    def forward(self, x):
        """Generate a cyto image from input ``x``."""
        return self.generator(x)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator scores and 0/1 labels."""
        return F.binary_cross_entropy(y_hat, y)

    @staticmethod
    def _labels(reference, value):
        """Return a (batch, 1) tensor of ``value``, on ``reference``'s device and dtype."""
        return reference.new_full((reference.size(0), 1), value)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator (``optimizer_idx == 0``) or discriminator (``== 1``) step."""
        pli = batch['pli_image']
        cyto = batch['cyto_image']

        # Training the generator.
        if optimizer_idx == 0:
            self.cyto_imag_generated = self(pli)
            loss_RMI = self.generator.loss_f(self.cyto_imag_generated, cyto)

            # The generator wants its output labelled as real (all ones).
            valid = self._labels(pli, 1.0)
            g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
            total_loss = g_loss + self.lamda_G * loss_RMI

            self.log('train_g_loss', g_loss)
            self.log('train_g_RMI_loss', loss_RMI)
            self.log('train_g_Total_loss', total_loss)
            return total_loss

        # Training the discriminator.
        if optimizer_idx == 1:
            # How well can it label real cyto images as real?
            valid = self._labels(pli, 1.0)
            real_loss = self.adversarial_loss(self.discriminator(cyto), valid)

            # How well can it label generator output as fake? The fake images
            # are produced without gradient so only the discriminator learns
            # here. The discriminator's own output must NOT be detached,
            # otherwise this loss cannot update it.
            with torch.no_grad():
                fake_imgs = self(pli)
            fake = self._labels(pli, 0.0)
            fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)

            d_loss = (real_loss + fake_loss) / 2
            self.log('train_d_loss', d_loss)
            return d_loss

    def validation_step(self, batch, batch_idx):
        """Log the generator's total validation loss and, for batch 0, an image grid."""
        pli = batch['pli_image']
        cyto = batch['cyto_image']

        self.cyto_imag_generated = self.generator(pli)
        loss_RMI = self.generator.loss_f(self.cyto_imag_generated, cyto)

        # Adversarial term as in training: generated images labelled real.
        valid = self._labels(pli, 1.0)
        g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
        self.log('val_g_Total_loss', g_loss + self.lamda_G * loss_RMI)

        # NOTE(review): add_image(tag, img, step, dataformats=...) assumes a
        # TensorBoard SummaryWriter-like experiment. Skipped when no logger.
        if batch_idx == 0 and self.logger is not None:
            grid = make_grid([pli[0], cyto[0], self.cyto_imag_generated[0]])
            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")

    def configure_optimizers(self):
        """Return Adam optimizers for the generator (index 0) and discriminator (index 1)."""
        optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate)
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate)
        return [optimizer_g, optimizer_d], []
# NOTE(review): the string literal below is disabled code and never executed.
# It holds an earlier GAN.validation_step variant that took an optimizer_idx
# argument; the active validation_step is defined inside GAN.
'''
    def validation_step(self, batch, batch_idx, optimizer_idx):
        # Training the Generator.
        if optimizer_idx == 0:
            self.cyto_imag_generated = self.forward(batch['pli_image'])
s.islam's avatar
s.islam committed
            g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), batch['cyto_image'])
            self.log('val_g_loss', g_loss)
            return g_loss

s.islam's avatar
s.islam committed
        # Training discriminator.
        if optimizer_idx == 1:
            d_loss = self.adversarial_loss(self.cyto_imag_generated, batch['cyto_image'])
            self.log('val_d_loss', d_loss)
            return d_loss
'''






# NOTE(review): empty triple-quoted string literal; a leftover no-op.
''''''