From 0994425e5836be3033bcc12005e48246e105ee4b Mon Sep 17 00:00:00 2001
From: "s.islam" <s.islam@fz-juelich.de>
Date: Fri, 29 Apr 2022 14:20:00 +0200
Subject: [PATCH] Using 2 channels(NTransmittance & Retardation) for more
 regional information.

---
 code/data.py  | 43 +++++++++++++++++++++++++++++++++----------
 code/model.py |  6 +++---
 2 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/code/data.py b/code/data.py
index 4b98b44..868f218 100644
--- a/code/data.py
+++ b/code/data.py
@@ -121,6 +121,7 @@ class TestDataModule(pl.LightningDataModule):
             # TODO: Load the PLI and Cytp train data here as lists of numpy arrays: List[np.ndarray]
             # Load the pyramid/00 per file
 
+            '''
             #For single channel
             #For JSC Training.
             pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
@@ -147,14 +148,26 @@ class TestDataModule(pl.LightningDataModule):
                 pli_train_file = pli_train_file - 0.5
                 pli_train_file = np.clip(pli_train_file, 0, 1)
                 self.pli_train.append(pli_train_file)
-                
+            '''
+
+            # For 2/3 Channel.
+
+            #Path for JSC.
+
+            pli_NTransmittance_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
+            pli_Retardation_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/Retardation'
+            pli_Direction_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/Direction'
+            cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
+
+
+            #Path for local machine.
 
-            # For 3 Channel.
             '''
             pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
             pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation'
             pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction'
             cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
+            '''
 
             pli_NTransmittance_files_list = [file for file in os.listdir(pli_NTransmittance_path) if
                               file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
@@ -186,19 +199,23 @@ class TestDataModule(pl.LightningDataModule):
                 pli_train_file = np.asarray(pli_train_file).astype(np.float32)
                 pli_train_file = pli_train_file - 0.5
                 pli_Retardation_train.append(pli_train_file)
+
+            #For Stacking Direction data. As it is not required, don't touch it!
+            '''
             for i in range(0, 4):
                 pli_train_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[i]), 'r')
                 pli_train_file = pli_train_file['pyramid/00']
                 pli_train_file = np.asarray(pli_train_file).astype(np.float32)
                 pli_train_file = (pli_train_file/255) - 0.5
                 pli_Direction_train.append(pli_train_file)
+            '''
 
             self.pli_train = []
             self.cyto_train = []
 
             for i in range(0,4):
-                self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i], pli_Direction_train[i]), axis=-1))
-            '''
+                self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i]), axis=-1))
+
 
             for i in range(0,4):
                 cyto_train_file = h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r')
@@ -207,6 +224,7 @@ class TestDataModule(pl.LightningDataModule):
                 cyto_train_file = (cyto_train_file/255) - 0.5
                 self.cyto_train.append(cyto_train_file)
 
+
         else:
             print(f"Train data for rank {rank}/{size} already prepared")
 
@@ -220,6 +238,7 @@ class TestDataModule(pl.LightningDataModule):
             self.cyto_val = []
 
 
+            '''
             #Single Channel
             pli_val_file = h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r')
             pli_val_file = pli_val_file['pyramid/00']
@@ -227,9 +246,10 @@ class TestDataModule(pl.LightningDataModule):
             pli_val_file = np.clip(pli_val_file, 0, 1)
             pli_val_file = pli_val_file - 0.5
             self.pli_val.append(pli_val_file)
-
-            #3 Channels
             '''
+
+            #2/3 Channels
+
             pli_NTransmittance_val = []
             pli_Retardation_val = []
             pli_Direction_val = []
@@ -247,15 +267,18 @@ class TestDataModule(pl.LightningDataModule):
             pli_Retardation_val_file = pli_Retardation_val_file - 0.5
             pli_Retardation_val.append(pli_Retardation_val_file)
 
+            #For Stacking Direction data. As it is not required, don't touch it!
+            '''
             pli_Direction_val_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[4]), 'r')
             pli_Direction_val_file = pli_Direction_val_file['pyramid/00']
             pli_Direction_val_file = np.asarray(pli_Direction_val_file).astype(np.float32)
             pli_Direction_val_file = (pli_Direction_val_file/255) - 0.5
             pli_Direction_val.append(pli_Direction_val_file)
-
-            self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val, pli_Direction_val), axis=-1)
             '''
 
+            self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val), axis=-1)
+
+
 
 
             cyto_val_file = h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r')
@@ -291,7 +314,7 @@ class TestDataModule(pl.LightningDataModule):
 
         # Define the datasets for training and validation
         self.train_sampler = TestSampler(
-            self.pli_train,
+            list(self.pli_train),
             self.cyto_train,
             self._train_transforms,
             self.crop_size,
@@ -299,7 +322,7 @@ class TestDataModule(pl.LightningDataModule):
         )
 
         self.val_sampler = TestSampler(
-            self.pli_val,
+            list(self.pli_val),
             self.cyto_val,
             self._test_transforms,
             self.crop_size,
diff --git a/code/model.py b/code/model.py
index 4ebd957..22a408f 100644
--- a/code/model.py
+++ b/code/model.py
@@ -30,9 +30,9 @@ class TestModule(pl.LightningModule):
 
         # Define the model
         self.model = smp.Unet(
-            encoder_name="resnet34",  # Also consider using smaller or larger encoders
+            encoder_name="resnet18",  # Also consider using smaller or larger encoders
             encoder_weights= "imagenet",  # Do the pretrained weights help? Try with or without
-            in_channels=1,  # We use 1 chanel transmittance as input
+            in_channels=2,  # We use 1 chanel transmittance as input
             classes=1,  # classes == output channels. We use one output channel for cyto data
         )
         self.loss_f = RMILoss(with_logits=True) #torch.nn.MSELoss()
@@ -57,7 +57,7 @@ class TestModule(pl.LightningModule):
         batch['cyto_image'] = batch['cyto_image'] + 0.5
         cyto_imag_generated = cyto_imag_generated +0.5
         if batch_idx == 0:
-            grid = make_grid([batch['pli_image'][0], batch['cyto_image'][0], cyto_imag_generated[0]])
+            grid = make_grid([batch['pli_image'][0, :1], batch['pli_image'][:1, 0], batch['cyto_image'][0], cyto_imag_generated[0]])
             self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="CHW")
 
     def configure_optimizers(self):
-- 
GitLab