Newer
Older
import torch.nn as nn
from torch.nn import functional as F
import torch
import pytorch_lightning as pl
from rmi import RMILoss
from torchvision.utils import make_grid
import numpy as np
import segmentation_models_pytorch as smp
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator
class Generator(nn.Module):
    """U-Net generator that maps PLI images to a single-channel cyto image.

    Wraps a ``segmentation_models_pytorch`` U-Net (ResNet-18 encoder, sigmoid
    output activation) and keeps its reconstruction loss in ``self.loss_f``.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add the generator's CLI options to ``parent_parser`` and return it."""
        group = parent_parser.add_argument_group("Generator")
        group.add_argument('--learning_rate', type=float, default=1e-3)
        group.add_argument('--name', type=str, default='Generator')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build a Generator from an ``argparse.Namespace``.

        Every namespace entry is passed as a keyword argument; entries the
        constructor does not name are absorbed by ``**kwargs``.
        """
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='Generator', depth=3, **kwargs):
        super().__init__()
        # Plain nn.Module, so hyperparameters are kept as attributes
        # (save_hyperparameters() is a LightningModule feature).
        self.learning_rate = learning_rate
        self.name = name
        # NOTE(review): `depth` and `**kwargs` are accepted but unused here.
        # NOTE(review): in_channels is not passed, so smp's default applies --
        # confirm it matches the channel count of batch['pli_image'].
        self.g_model = smp.Unet(
            encoder_name="resnet18",  # Also consider using smaller or larger encoders
            encoder_weights="imagenet",  # Do the pretrained weights help? Try with or without
            classes=1,  # classes == output channels. We use one output channel for cyto data
            activation="sigmoid",
        )
        # The U-Net already ends in a sigmoid, so its outputs are probabilities,
        # not logits; with_logits=True would make RMILoss apply a second sigmoid.
        self.loss_f = RMILoss(with_logits=False)  # alternative: torch.nn.MSELoss()

    def forward(self, x):
        """Return the U-Net prediction for the input batch ``x``."""
        return self.g_model(x)

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one batch.

        Args:
            batch: dict with keys ``'pli_image'`` (network input) and
                ``'cyto_image'`` (target).
            batch_idx: index of the batch (unused).

        Returns:
            Tuple ``(loss, generated_image)``.
        """
        cyto_imag_generated = self.forward(batch['pli_image'])
        loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
        return loss, cyto_imag_generated

    # Disabled methods, kept for reference (these string literals have no effect).
    '''
    def validation_step(self, batch, batch_idx):
        # Do it the same as in training
        cyto_imag_generated = self.forward(batch['pli_image'])
        loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
        self.log("val_loss", loss)
        batch['pli_image'] = batch['pli_image']
        batch['cyto_image'] = batch['cyto_image']
        cyto_imag_generated = cyto_imag_generated
        if batch_idx == 0:
            grid = make_grid([batch['pli_image'][0, :1], batch['pli_image'][0, 1:], batch['cyto_image'][0], cyto_imag_generated[0]])
            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")
    '''
    '''
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    '''
class Discriminator_2(nn.Module):
    """Fully-connected discriminator for images of one fixed shape.

    Each sample is flattened and passed through a 512-256-1 MLP with
    LeakyReLU activations and a final sigmoid, giving one score per sample
    with shape ``(N, 1)``.
    """

    def __init__(self, img_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # reshape (unlike view) also accepts non-contiguous inputs.
        img_flat = img.reshape(img.size(0), -1)
        return self.model(img_flat)


class Discriminator(nn.Module):
    """Convolutional discriminator built on pl_bolts' ``DCGANDiscriminator``.

    Runs the wrapped network's ``.disc`` stack directly and averages the
    resulting spatial score map down to 1x1, returning one score per image
    with shape ``(N, 1)`` regardless of the map's spatial size.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add the discriminator's CLI options to ``parent_parser`` and return it."""
        # Previously `parser` was used without being defined (NameError).
        parser = parent_parser.add_argument_group("Discriminator")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='discriminator')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build a Discriminator from an ``argparse.Namespace`` (extra keys go to ``**kwargs``)."""
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='discriminator', feature_maps=64, **kwargs):
        """
        Args:
            learning_rate: stored as an attribute; not used inside this class.
            name: stored as an attribute.
            feature_maps: base width passed to ``DCGANDiscriminator``, which
                requires it; 64 is the discriminator default of pl_bolts' DCGAN.
            **kwargs: ignored (lets ``from_argparse_args`` pass a full namespace).
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.name = name
        self.d_model = DCGANDiscriminator(feature_maps=feature_maps, image_channels=1)

    def forward(self, x):
        # Use the conv stack rather than DCGANDiscriminator.forward so the
        # spatial score map can be averaged into a single value per image.
        scores = self.d_model.disc(x)
        return F.adaptive_avg_pool2d(scores, (1, 1)).view(scores.size(0), -1)
#GAN
class GAN(pl.LightningModule):
    """Adversarial training of ``Generator`` against ``Discriminator``.

    Batches are dicts with a ``'pli_image'`` tensor (generator input) and a
    ``'cyto_image'`` tensor (target). The generator is trained with an
    adversarial BCE term plus ``lamda_G`` times its RMI reconstruction loss.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add the GAN's CLI options to ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("TestModule")
        parser.add_argument('--learning_rate', type=float, default=1e-3)
        parser.add_argument('--name', type=str, default='GAN')
        return parent_parser

    @classmethod
    def from_argparse_args(cls, args):
        """Build a GAN from an ``argparse.Namespace`` (extra keys go to ``**kwargs``)."""
        return cls(**vars(args))

    def __init__(self, learning_rate=1e-3, name='model', depth=3, lamda_G=0.1, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.name = name
        self.generator = Generator()
        # NOTE(review): if Discriminator's first parameter is `learning_rate`,
        # this (256, 256) lands there and is unused -- confirm the intent.
        self.discriminator = Discriminator((256, 256))
        # Generator output from the most recent generator training step.
        self.cyto_imag_generated = None
        # Weight of the RMI reconstruction term in the generator loss.
        self.lamda_G = lamda_G

    def forward(self, x):
        """Generate a cyto image from the PLI input ``x``."""
        return self.generator(x)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator scores and 0/1 labels."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one GAN step for the optimizer chosen by ``optimizer_idx``.

        optimizer_idx 0: generator loss ``adversarial + lamda_G * RMI``.
        optimizer_idx 1: discriminator loss, the mean of the real and fake BCE.
        """
        pli_image = batch['pli_image']
        cyto_image = batch['cyto_image']
        batch_size = pli_image.size(0)

        if optimizer_idx == 0:
            # Train the generator.
            self.cyto_imag_generated = self(pli_image)
            loss_RMI = self.generator.loss_f(self.cyto_imag_generated, cyto_image)
            # The generator wants its output scored as real (label 1);
            # type_as puts the labels on the batch's device and dtype.
            valid = torch.ones(batch_size, 1).type_as(pli_image)
            g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
            total_loss = g_loss + self.lamda_G * loss_RMI
            self.log('train_g_loss', g_loss)
            self.log('train_g_RMI_loss', loss_RMI)
            self.log('train_g_Total_loss', total_loss)
            return total_loss

        if optimizer_idx == 1:
            # Train the discriminator: real cyto images -> 1, generated -> 0.
            valid = torch.ones(batch_size, 1).type_as(pli_image)
            real_loss = self.adversarial_loss(self.discriminator(cyto_image), valid)

            # Score *generated* images (not the PLI inputs). Only the generator
            # output is cut from the graph; the discriminator output must keep
            # its gradient, otherwise the fake term cannot train the discriminator.
            with torch.no_grad():
                generated = self(pli_image)
            fake = torch.zeros(batch_size, 1).type_as(pli_image)
            fake_loss = self.adversarial_loss(self.discriminator(generated), fake)

            d_loss = (real_loss + fake_loss) / 2
            self.log('train_d_loss', d_loss)
            return d_loss

        return None

    def validation_step(self, batch, batch_idx):
        """Log the generator's total loss on a validation batch.

        For the first validation batch, also log an image grid.
        """
        pli_image = batch['pli_image']
        cyto_image = batch['cyto_image']
        # Generate from *this* validation batch; the cached training output is
        # None before the first training step and belongs to a different batch.
        generated = self(pli_image)
        loss_RMI = self.generator.loss_f(generated, cyto_image)
        valid = torch.ones(pli_image.size(0), 1).type_as(pli_image)
        g_loss = self.adversarial_loss(self.discriminator(generated), valid)
        self.log('val_g_Total_loss', g_loss + self.lamda_G * loss_RMI)

        # One grid per epoch instead of rewriting the same tag for every batch.
        if batch_idx == 0 and self.logger is not None:
            # make_grid stacks its inputs, so all tiles need the same channel
            # count: split the PLI input into single-channel tiles. Assumes the
            # cyto target and the generated image are single-channel.
            tiles = [*pli_image[0].split(1, dim=0), cyto_image[0], generated[0]]
            grid = make_grid(tiles)
            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")

    def configure_optimizers(self):
        """Return separate Adam optimizers for the generator (idx 0) and the discriminator (idx 1)."""
        optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate)
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate)
        return [optimizer_g, optimizer_d], []
'''
def validation_step(self, batch, batch_idx, optimizer_idx):
# Training the Generator.
if optimizer_idx == 0:
self.cyto_imag_generated = self.forward(batch['pli_image'])
g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), batch['cyto_image'])
self.log('val_g_loss', g_loss)
return g_loss