From 66a96ca49865e77e5f200fd07f7ef08494f21e24 Mon Sep 17 00:00:00 2001
From: "s.islam" <s.islam@fz-juelich.de>
Date: Wed, 11 May 2022 05:17:03 +0200
Subject: [PATCH] First GAN model

---
 code/data.py      | 65 +++++++++++++++++++++++------------------------
 code/model_GAN.py | 57 ++++++++++++++++++++++++++++++-----------
 main.py           |  6 ++---
 3 files changed, 77 insertions(+), 51 deletions(-)

diff --git a/code/data.py b/code/data.py
index 580a9d9..704a755 100644
--- a/code/data.py
+++ b/code/data.py
@@ -119,15 +119,14 @@ class TestDataModule(pl.LightningDataModule):
             # TODO: Load the PLI and Cytp train data here as lists of numpy arrays: List[np.ndarray]
             # Load the pyramid/00 per file
 
-            '''
-            #For single channel
-            #For JSC Training.
-            #pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
-            #cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
+            # For single channel
+            # For JSC Training.
+            pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
+            cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
 
-            #For Local Machine Training.
-            pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
-            cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
+            # For Local Machine Training.
+            #pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
+            #cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
 
             pli_files_list = [file for file in os.listdir(pli_path) if
                               file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
@@ -139,7 +138,7 @@ class TestDataModule(pl.LightningDataModule):
             self.pli_train = []
             self.cyto_train = []
 
-            for i in range(0,4):
+            for i in range(0, 4):
                 pli_train_file = h5py.File(os.path.join(pli_path, pli_files_list[i]), 'r')
                 pli_train_file = pli_train_file['pyramid/00']
                 pli_train_file = np.asarray(pli_train_file).astype(np.float32)
@@ -148,33 +147,35 @@ class TestDataModule(pl.LightningDataModule):
                 self.pli_train.append(pli_train_file)
 
             '''
-
             # For 2/3 Channel.
 
-            # Path for JSC.
+            #Path for JSC.
+
 
             pli_NTransmittance_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
             pli_Retardation_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/Retardation'
             pli_Direction_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/Direction'
             cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
 
-            # Path for local machine.
 
-            #pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
-            #pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation'
-            #pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction'
-            #cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
+            #Path for local machine.
+
+            pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
+            pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation'
+            pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction'
+            cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
+
 
             pli_NTransmittance_files_list = [file for file in os.listdir(pli_NTransmittance_path) if
-                                             file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+                              file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
             pli_NTransmittance_files_list.sort()
 
             pli_Retardation_files_list = [file for file in os.listdir(pli_Retardation_path) if
-                                          file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+                                             file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
             pli_Retardation_files_list.sort()
 
             pli_Direction_files_list = [file for file in os.listdir(pli_Direction_path) if
-                                        file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+                                          file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
             pli_Direction_files_list.sort()
             cyto_files_list = [file for file in os.listdir(cyto_path) if
                                file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
@@ -182,7 +183,7 @@ class TestDataModule(pl.LightningDataModule):
 
             pli_NTransmittance_train = []
             pli_Retardation_train = []
-            # pli_Direction_train = []
+            pli_Direction_train = []
 
             for i in range(0, 4):
                 pli_train_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[i]), 'r')
@@ -197,7 +198,6 @@ class TestDataModule(pl.LightningDataModule):
                 pli_train_file = pli_train_file
                 pli_Retardation_train.append(pli_train_file)
 
-            '''
             #For Stacking Direction data. As it is not required, don't touch it!
 
             for i in range(0, 4):
@@ -206,13 +206,14 @@ class TestDataModule(pl.LightningDataModule):
                 pli_train_file = np.asarray(pli_train_file).astype(np.float32)
                 pli_train_file = (pli_train_file/255) 
                 pli_Direction_train.append(pli_train_file)
-            '''
+
 
             self.pli_train = []
             self.cyto_train = []
 
-            for i in range(0, 4):
+            for i in range(0,4):
                 self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i]), axis=-1))
+            '''
 
             for i in range(0, 4):
                 cyto_train_file = h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r')
@@ -234,24 +235,23 @@ class TestDataModule(pl.LightningDataModule):
             self.pli_val = []
             self.cyto_val = []
 
-            '''
-            #Single Channel
+            # Single Channel
             pli_val_file = h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r')
             pli_val_file = pli_val_file['pyramid/00']
             pli_val_file = np.asarray(pli_val_file).astype(np.float32)
             pli_val_file = np.clip(pli_val_file, 0, 1)
             pli_val_file = pli_val_file
             self.pli_val.append(pli_val_file)
-            '''
 
-            # 2/3 Channels
+            '''
+            #2/3 Channels
 
             pli_NTransmittance_val = []
             pli_Retardation_val = []
-            # pli_Direction_val = []
+            pli_Direction_val = []
+
 
-            pli_NTransmittance_val_file = h5py.File(
-                os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r')
+            pli_NTransmittance_val_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r')
             pli_NTransmittance_val_file = pli_NTransmittance_val_file['pyramid/00']
             pli_NTransmittance_val_file = np.asarray(pli_NTransmittance_val_file).astype(np.float32)
             pli_NTransmittance_val_file = pli_NTransmittance_val_file
@@ -263,7 +263,6 @@ class TestDataModule(pl.LightningDataModule):
             pli_Retardation_val_file = pli_Retardation_val_file
             pli_Retardation_val.append(pli_Retardation_val_file)
 
-            '''
             #For Stacking Direction data. As it is not required, don't touch it!
 
             pli_Direction_val_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[4]), 'r')
@@ -271,9 +270,9 @@ class TestDataModule(pl.LightningDataModule):
             pli_Direction_val_file = np.asarray(pli_Direction_val_file).astype(np.float32)
             pli_Direction_val_file = (pli_Direction_val_file/255)
             pli_Direction_val.append(pli_Direction_val_file)
-            '''
 
             self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val), axis=-1)
+            '''
 
             cyto_val_file = h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r')
             cyto_val_file = cyto_val_file['pyramid/00']
@@ -354,4 +353,4 @@ class TestDataModule(pl.LightningDataModule):
         return dl
 
 
-''''''
+''''''
\ No newline at end of file
diff --git a/code/model_GAN.py b/code/model_GAN.py
index d2c9662..a643bd4 100644
--- a/code/model_GAN.py
+++ b/code/model_GAN.py
@@ -34,7 +34,7 @@ class Generator(nn.Module):
         self.g_model = smp.Unet(
             encoder_name="resnet18",  # Also consider using smaller or larger encoders
             encoder_weights= "imagenet",  # Do the pretrained weights help? Try with or without
-            in_channels=2,  # We use 1 chanel transmittance as input
+            in_channels=1,  # We use 1 chanel transmittance as input
             classes=1,  # classes == output channels. We use one output channel for cyto data
             activation="sigmoid"
         )
@@ -90,14 +90,14 @@ class Discriminator_2(nn.Module):
 
         return validity
 
-#Descriminator
-class Descriminator(nn.Module):
+#discriminator
+class Discriminator(nn.Module):
 
     @staticmethod
     def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("Descriminator")
+        parser = parent_parser.add_argument_group("discriminator")
         parser.add_argument('--learning_rate', type=float, default=1e-3)
-        parser.add_argument('--name', type=str, default='Descriminator')
+        parser.add_argument('--name', type=str, default='discriminator')
         return parent_parser
 
     @classmethod
@@ -105,13 +105,13 @@ class Descriminator(nn.Module):
         dict_args = vars(args)
         return cls(**dict_args)
 
-    def __init__(self, learning_rate=1e-3, name='Descriminator', **kwargs):
+    def __init__(self, learning_rate=1e-3, name='discriminator', **kwargs):
         super().__init__()
         #self.save_hyperparameters()
         self.learning_rate = learning_rate
         self.name = name
         self.d_model = DCGANDiscriminator(
-        feature_maps = 8,
+        feature_maps = 1,
         image_channels = 1
         )
 
@@ -140,7 +140,7 @@ class GAN(pl.LightningModule):
         self.learning_rate = learning_rate
         self.name = name
         self.generator = Generator()
-        self.descriminator = Discriminator_2((256,256))
+        self.discriminator = Discriminator_2((256,256))
         self.cyto_imag_generated = None
         self.lamda_G = lamda_G
 
@@ -165,44 +165,71 @@ class GAN(pl.LightningModule):
             valid = torch.ones(batch['pli_image'].size(0), 1)
             valid = valid.type_as(batch['pli_image'])
 
-            g_loss = self.adversarial_loss(self.descriminator(self.cyto_imag_generated), valid)
+            g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
             self.log('train_g_loss', g_loss)
             self.log('train_g_RMI_loss', loss_RMI)
             self.log('train_g_Total_loss', g_loss + self.lamda_G * loss_RMI)
             return g_loss + self.lamda_G * loss_RMI
 
-        #Training Descriminator.
+        #Training discriminator.
         if optimizer_idx == 1:
             # how well can it label as real?
             valid = torch.ones(batch['pli_image'].size(0), 1)
             valid = valid.type_as(batch['pli_image'])
 
-            real_loss = self.adversarial_loss(self.descriminator(batch['cyto_image']), valid)
+            real_loss = self.adversarial_loss(self.discriminator(batch['cyto_image']), valid)
 
             # how well can it label as fake?
             fake = torch.zeros(batch['pli_image'].size(0), 1)
             fake = fake.type_as(batch['pli_image'])
 
-            fake_loss = self.adversarial_loss(self.descriminator(self(batch['pli_image']).detach()), fake)
+            fake_loss = self.adversarial_loss(self.discriminator(batch['pli_image']).detach(), fake)
 
             d_loss = (real_loss + fake_loss) / 2
             self.log('train_d_loss', d_loss)
             return d_loss
 
+
+
+
+
+    def validation_step(self, batch, batch_idx):
+        # Validating the Generator.
+        cyto_imag_generated = self.generator(batch['pli_image'])
+        loss_RMI = self.generator.loss_f(self.cyto_imag_generated, batch['cyto_image'])
+
+        # ground truth result (ie: all fake)
+        # put on GPU because we created this tensor inside training_loop
+        valid = torch.ones(batch['pli_image'].size(0), 1)
+        valid = valid.type_as(batch['pli_image'])
+
+        g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), valid)
+        self.log('val_g_loss', g_loss)
+        self.log('val_g_RMI_loss', loss_RMI)
+        self.log('val_g_Total_loss', g_loss + self.lamda_G * loss_RMI)
+        if batch_idx == 0:
+            grid = make_grid([batch['pli_image'][0], batch['cyto_image'][0], cyto_imag_generated[0]])
+            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")
+
+
+
+
+
+
     def configure_optimizers(self):
         optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate)
-        optimizer_d = torch.optim.Adam(self.descriminator.parameters(), lr=self.learning_rate)
+        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate)
         return [optimizer_g, optimizer_d], []
 '''
     def validation_step(self, batch, batch_idx, optimizer_idx):
         # Training the Generator.
         if optimizer_idx == 0:
             self.cyto_imag_generated = self.forward(batch['pli_image'])
-            g_loss = self.adversarial_loss(self.descriminator(self.cyto_imag_generated), batch['cyto_image'])
+            g_loss = self.adversarial_loss(self.discriminator(self.cyto_imag_generated), batch['cyto_image'])
             self.log('val_g_loss', g_loss)
             return g_loss
 
-        # Training Descriminator.
+        # Training discriminator.
         if optimizer_idx == 1:
             d_loss = self.adversarial_loss(self.cyto_imag_generated, batch['cyto_image'])
             self.log('val_d_loss', d_loss)
diff --git a/main.py b/main.py
index e99198b..92adcc2 100644
--- a/main.py
+++ b/main.py
@@ -36,7 +36,7 @@ sys.path.insert(0, "code/")
 
 import utils
 import model_GAN
-import model_Unet
+import model_GAN
 import data
 
 
@@ -44,7 +44,7 @@ def main():
     parser = ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
     parser = data.TestDataModule.add_argparse_args(parser)
-    parser = model_Unet.TestModule.add_model_specific_args(parser)
+    parser = model_GAN.GAN.add_model_specific_args(parser)
     parser.add_argument('--log_dir', type=str, default='doc/tensorboard/')
     parser.add_argument('--ckpt_dir', type=str, default='tmp/ckpt/')
     parser.add_argument('--save_every_n_epochs', type=int, default=None)
@@ -53,7 +53,7 @@ def main():
     dict_args = vars(args)
 
     print("Create model")
-    test_model = model_Unet.TestModule.from_argparse_args(args)
+    test_model = model_GAN.GAN.from_argparse_args(args)
     print(f"Model '{test_model.name}' loaded")
 
     print("Load train data and create Data Module")
-- 
GitLab