From 66a96ca49865e77e5f200fd07f7ef08494f21e24 Mon Sep 17 00:00:00 2001 From: "s.islam" <s.islam@fz-juelich.de> Date: Wed, 11 May 2022 05:17:03 +0200 Subject: [PATCH] First GAN model --- code/data.py | 65 +++++++++++++++++++++++------------------------ code/model_GAN.py | 57 ++++++++++++++++++++++++++++++----------- main.py | 6 ++--- 3 files changed, 77 insertions(+), 51 deletions(-) diff --git a/code/data.py b/code/data.py index 580a9d9..704a755 100644 --- a/code/data.py +++ b/code/data.py @@ -119,15 +119,14 @@ class TestDataModule(pl.LightningDataModule): # TODO: Load the PLI and Cytp train data here as lists of numpy arrays: List[np.ndarray] # Load the pyramid/00 per file - ''' - #For single channel - #For JSC Training. - #pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance' - #cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained' + # For single channel + # For JSC Training. + pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance' + cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained' - #For Local Machine Training. - pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' - cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' + # For Local Machine Training. + #pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' + #cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' pli_files_list = [file for file in os.listdir(pli_path) if file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] @@ -139,7 +138,7 @@ class TestDataModule(pl.LightningDataModule): self.pli_train = [] self.cyto_train = [] - for i in range(0,4): + for i in range(0, 4): pli_train_file = h5py.File(os.path.join(pli_path, pli_files_list[i]), 'r') pli_train_file = pli_train_file['pyramid/00'] pli_train_file = np.asarray(pli_train_file).astype(np.float32) @@ -148,33 +147,35 @@ class TestDataModule(pl.LightningDataModule): self.pli_train.append(pli_train_file) ''' - # For 2/3 Channel. - # Path for JSC. + #Path for JSC. + pli_NTransmittance_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance' pli_Retardation_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/Retardation' pli_Direction_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/Direction' cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained' - # Path for local machine. - #pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' - #pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation' - #pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction' - #cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' + #Path for local machine. + + pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' + pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation' + pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction' + cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' + pli_NTransmittance_files_list = [file for file in os.listdir(pli_NTransmittance_path) if - file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] pli_NTransmittance_files_list.sort() pli_Retardation_files_list = [file for file in os.listdir(pli_Retardation_path) if - file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] pli_Retardation_files_list.sort() pli_Direction_files_list = [file for file in os.listdir(pli_Direction_path) if - file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] pli_Direction_files_list.sort() cyto_files_list = [file for file in os.listdir(cyto_path) if file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] @@ -182,7 +183,7 @@ class TestDataModule(pl.LightningDataModule): pli_NTransmittance_train = [] pli_Retardation_train = [] - # pli_Direction_train = [] + pli_Direction_train = [] for i in range(0, 4): pli_train_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[i]), 'r') @@ -197,7 +198,6 @@ class TestDataModule(pl.LightningDataModule): pli_train_file = pli_train_file pli_Retardation_train.append(pli_train_file) - ''' #For Stacking Direction data. As it is not required, don't touch it! for i in range(0, 4): @@ -206,13 +206,14 @@ class TestDataModule(pl.LightningDataModule): pli_train_file = np.asarray(pli_train_file).astype(np.float32) pli_train_file = (pli_train_file/255) pli_Direction_train.append(pli_train_file) - ''' + self.pli_train = [] self.cyto_train = [] - for i in range(0, 4): + for i in range(0,4): self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i]), axis=-1)) + ''' for i in range(0, 4): cyto_train_file = h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r') @@ -234,24 +235,23 @@ class TestDataModule(pl.LightningDataModule): self.pli_val = [] self.cyto_val = [] - ''' - #Single Channel + # Single Channel pli_val_file = h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r') pli_val_file = pli_val_file['pyramid/00'] pli_val_file = np.asarray(pli_val_file).astype(np.float32) pli_val_file = np.clip(pli_val_file, 0, 1) pli_val_file = pli_val_file self.pli_val.append(pli_val_file) - ''' - # 2/3 Channels + ''' + #2/3 Channels pli_NTransmittance_val = [] pli_Retardation_val = [] - # pli_Direction_val = [] + pli_Direction_val = [] + - pli_NTransmittance_val_file = h5py.File( - os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r') + pli_NTransmittance_val_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r') pli_NTransmittance_val_file = pli_NTransmittance_val_file['pyramid/00'] pli_NTransmittance_val_file = np.asarray(pli_NTransmittance_val_file).astype(np.float32) pli_NTransmittance_val_file = pli_NTransmittance_val_file @@ -263,7 +263,6 @@ class TestDataModule(pl.LightningDataModule): pli_Retardation_val_file = pli_Retardation_val_file pli_Retardation_val.append(pli_Retardation_val_file) - ''' #For Stacking Direction data. As it is not required, don't touch it! pli_Direction_val_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[4]), 'r') @@ -271,9 +270,9 @@ class TestDataModule(pl.LightningDataModule): pli_Direction_val_file = np.asarray(pli_Direction_val_file).astype(np.float32) pli_Direction_val_file = (pli_Direction_val_file/255) pli_Direction_val.append(pli_Direction_val_file) - ''' self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val), axis=-1) + ''' cyto_val_file = h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r') cyto_val_file = cyto_val_file['pyramid/00'] @@ -354,4 +353,4 @@ class TestDataModule(pl.LightningDataModule): return dl -'''''' +'''''' \ No newline at end of file diff --git a/code/model_GAN.py b/code/model_GAN.py index d2c9662..a643bd4 100644 --- a/code/model_GAN.py +++ b/code/model_GAN.py @@ -34,7 +34,7 @@ class Generator(nn.Module): self.g_model = smp.Unet( encoder_name="resnet18", # Also consider using smaller or larger encoders encoder_weights= "imagenet", # Do the pretrained weights help? Try with or without - in_channels=2, # We use 1 chanel transmittance as input + in_channels=1, # We use 1 chanel transmittance as input classes=1, # classes == output channels. We use one output channel for cyto data activation="sigmoid" ) @@ -90,14 +90,14 @@ class Discriminator_2(nn.Module): return validity -#Descriminator -class Descriminator(nn.Module): +#discriminator +class Discriminator(nn.Module): @staticmethod def add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("Descriminator") + parser = parent_parser.add_argument_group("discriminator") parser.add_argument('--learning_rate', type=float, default=1e-3) - parser.add_argument('--name', type=str, default='Descriminator') + parser.add_argument('--name', type=str, default='discriminator') return parent_parser @classmethod @@ -105,13 +105,13 @@ class Descriminator(nn.Module): dict_args = vars(args) return cls(**dict_args) - def __init__(self, learning_rate=1e-3, name='Descriminator', **kwargs): + def __init__(self, learning_rate=1e-3, name='discriminator', **kwargs): super().__init__() #self.save_hyperparameters() self.learning_rate = learning_rate self.name = name self.d_model = DCGANDiscriminator( - feature_maps = 8, + feature_maps = 1, image_channels = 1 ) @@ -140,7 +140,7 @@ class GAN(pl.LightningModule): self.learning_rate = learning_rate self.name = name self.generator = Generator() - self.descriminator = Discriminator_2((256,256)) + self.discriminator = Discriminator_2((256,256)) self.cyto_imag_generated = None self.lamda_G = lamda_G @@ -165,44 +165,71 @@ class GAN(pl.LightningModule): valid = torch.ones(batch['pli_image'].size(0), 1) valid = valid.type_as(batch['pli_image']) - g_loss = self.adversarial_loss(self.descriminator(self.cyto_imag_generated), valid) + g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid) self.log('train_g_loss', g_loss) self.log('train_g_RMI_loss', loss_RMI) self.log('train_g_Total_loss', g_loss + self.lamda_G * loss_RMI) return g_loss + self.lamda_G * loss_RMI - #Training Descriminator. + #Training discriminator. if optimizer_idx == 1: # how well can it label as real? valid = torch.ones(batch['pli_image'].size(0), 1) valid = valid.type_as(batch['pli_image']) - real_loss = self.adversarial_loss(self.descriminator(batch['cyto_image']), valid) + real_loss = self.adversarial_loss(self.discriminator(batch['cyto_image']), valid) # how well can it label as fake? fake = torch.zeros(batch['pli_image'].size(0), 1) fake = fake.type_as(batch['pli_image']) - fake_loss = self.adversarial_loss(self.descriminator(self(batch['pli_image']).detach()), fake) + fake_loss = self.adversarial_loss(self.discriminator(batch['pli_image']).detach(), fake) d_loss = (real_loss + fake_loss) / 2 self.log('train_d_loss', d_loss) return d_loss + + + + + def validation_step(self, batch, batch_idx): + # Validating the Generator. + cyto_imag_generated = self.generator(batch['pli_image']) + loss_RMI = self.generator.loss_f(self.cyto_imag_generated, batch['cyto_image']) + + # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop + valid = torch.ones(batch['pli_image'].size(0), 1) + valid = valid.type_as(batch['pli_image']) + + g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid) + self.log('val_g_loss', g_loss) + self.log('val_g_RMI_loss', loss_RMI) + self.log('val_g_Total_loss', g_loss + self.lamda_G * loss_RMI) + if batch_idx == 0: + grid = make_grid([batch['pli_image'][0], batch['cyto_image'][0], cyto_imag_generated[0]]) + self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW") + + + + + + def configure_optimizers(self): optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate) - optimizer_d = torch.optim.Adam(self.descriminator.parameters(), lr=self.learning_rate) + optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate) return [optimizer_g, optimizer_d], [] ''' def validation_step(self, batch, batch_idx, optimizer_idx): # Training the Generator. if optimizer_idx == 0: self.cyto_imag_generated = self.forward(batch['pli_image']) - g_loss = self.adversarial_loss(self.descriminator(self.cyto_imag_generated), batch['cyto_image']) + g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), batch['cyto_image']) self.log('val_g_loss', g_loss) return g_loss - # Training Descriminator. + # Training discriminator. if optimizer_idx == 1: d_loss = self.adversarial_loss(self.cyto_imag_generated, batch['cyto_image']) self.log('val_d_loss', d_loss) diff --git a/main.py b/main.py index e99198b..92adcc2 100644 --- a/main.py +++ b/main.py @@ -36,7 +36,7 @@ sys.path.insert(0, "code/") import utils import model_GAN -import model_Unet +import model_GAN import data @@ -44,7 +44,7 @@ def main(): parser = ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) parser = data.TestDataModule.add_argparse_args(parser) - parser = model_Unet.TestModule.add_model_specific_args(parser) + parser = model_GAN.GAN.add_model_specific_args(parser) parser.add_argument('--log_dir', type=str, default='doc/tensorboard/') parser.add_argument('--ckpt_dir', type=str, default='tmp/ckpt/') parser.add_argument('--save_every_n_epochs', type=int, default=None) @@ -53,7 +53,7 @@ def main(): dict_args = vars(args) print("Create model") - test_model = model_Unet.TestModule.from_argparse_args(args) + test_model = model_GAN.GAN.from_argparse_args(args) print(f"Model '{test_model.name}' loaded") print("Load train data and create Data Module") -- GitLab