From 658b5009f9d69ac6a881d31abb5301f81fffa338 Mon Sep 17 00:00:00 2001
From: "s.islam" <s.islam@fz-juelich.de>
Date: Mon, 25 Apr 2022 01:33:16 +0200
Subject: [PATCH] RMI Loss added and 3 channel code(commented/not functional)
 added.

---
 code/data.py  | 95 ++++++++++++++++++++++++++++++++++++++++++++++++---
 code/model.py |  5 +--
 2 files changed, 93 insertions(+), 7 deletions(-)

diff --git a/code/data.py b/code/data.py
index 235cae1..40c4d76 100644
--- a/code/data.py
+++ b/code/data.py
@@ -121,14 +121,15 @@ class TestDataModule(pl.LightningDataModule):
             # TODO: Load the PLI and Cytp train data here as lists of numpy arrays: List[np.ndarray]
             # Load the pyramid/00 per file
 
+            #For single channel
             #For JSC Training.
-            pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
-            cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
+            #pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
+            #cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
 
             #For Local Machine Training.
-            #pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
-            #cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
-
+            pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
+            cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
+            
             pli_files_list = [file for file in os.listdir(pli_path) if
                               file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
             pli_files_list.sort()
@@ -145,6 +146,58 @@ class TestDataModule(pl.LightningDataModule):
                 pli_train_file = np.asarray(pli_train_file).astype(np.float32)
                 pli_train_file = pli_train_file - 0.5
                 self.pli_train.append(pli_train_file)
+                
+
+            # For 3 Channel.
+            '''
+            pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
+            pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation'
+            pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction'
+            cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
+
+            pli_NTransmittance_files_list = [file for file in os.listdir(pli_NTransmittance_path) if
+                              file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+            pli_NTransmittance_files_list.sort()
+
+            pli_Retardation_files_list = [file for file in os.listdir(pli_Retardation_path) if
+                                             file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+            pli_Retardation_files_list.sort()
+
+            pli_Direction_files_list = [file for file in os.listdir(pli_Direction_path) if
+                                          file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+            pli_Direction_files_list.sort()
+            cyto_files_list = [file for file in os.listdir(cyto_path) if
+                               file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
+
+            pli_NTransmittance_train = []
+            pli_Retardation_train = []
+            pli_Direction_train = []
+
+            for i in range(0, 4):
+                pli_train_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[i]), 'r')
+                pli_train_file = pli_train_file['pyramid/00']
+                pli_train_file = np.asarray(pli_train_file).astype(np.float32)
+                pli_train_file = pli_train_file - 0.5
+                pli_NTransmittance_train.append(pli_train_file)
+            for i in range(0, 4):
+                pli_train_file = h5py.File(os.path.join(pli_Retardation_path, pli_Retardation_files_list[i]), 'r')
+                pli_train_file = pli_train_file['pyramid/00']
+                pli_train_file = np.asarray(pli_train_file).astype(np.float32)
+                pli_train_file = pli_train_file - 0.5
+                pli_Retardation_train.append(pli_train_file)
+            for i in range(0, 4):
+                pli_train_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[i]), 'r')
+                pli_train_file = pli_train_file['pyramid/00']
+                pli_train_file = np.asarray(pli_train_file).astype(np.float32)
+                pli_train_file = (pli_train_file/255) - 0.5
+                pli_Direction_train.append(pli_train_file)
+
+            self.pli_train = []
+            self.cyto_train = []
+
+            for i in range(0,4):
+                self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i], pli_Direction_train[i]), axis=-1))
+            '''
 
             for i in range(0,4):
                 cyto_train_file = h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r')
@@ -165,12 +218,44 @@ class TestDataModule(pl.LightningDataModule):
             self.pli_val = []
             self.cyto_val = []
 
+
+            #Single Channel
             pli_val_file = h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r')
             pli_val_file = pli_val_file['pyramid/00']
             pli_val_file = np.asarray(pli_val_file).astype(np.float32)
             pli_val_file = pli_val_file - 0.5
             self.pli_val.append(pli_val_file)
 
+            #3 Channels
+            '''
+            pli_NTransmittance_val = []
+            pli_Retardation_val = []
+            pli_Direction_val = []
+
+
+            pli_NTransmittance_val_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r')
+            pli_NTransmittance_val_file = pli_NTransmittance_val_file['pyramid/00']
+            pli_NTransmittance_val_file = np.asarray(pli_NTransmittance_val_file).astype(np.float32)
+            pli_NTransmittance_val_file = pli_NTransmittance_val_file - 0.5
+            pli_NTransmittance_val.append(pli_NTransmittance_val_file)
+
+            pli_Retardation_val_file = h5py.File(os.path.join(pli_Retardation_path, pli_Retardation_files_list[4]), 'r')
+            pli_Retardation_val_file = pli_Retardation_val_file['pyramid/00']
+            pli_Retardation_val_file = np.asarray(pli_Retardation_val_file).astype(np.float32)
+            pli_Retardation_val_file = pli_Retardation_val_file - 0.5
+            pli_Retardation_val.append(pli_Retardation_val_file)
+
+            pli_Direction_val_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[4]), 'r')
+            pli_Direction_val_file = pli_Direction_val_file['pyramid/00']
+            pli_Direction_val_file = np.asarray(pli_Direction_val_file).astype(np.float32)
+            pli_Direction_val_file = (pli_Direction_val_file/255) - 0.5
+            pli_Direction_val.append(pli_Direction_val_file)
+
+            self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val, pli_Direction_val), axis=-1)
+            '''
+
+
+
             cyto_val_file = h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r')
             cyto_val_file = cyto_val_file['pyramid/00']
             cyto_val_file = np.asarray(cyto_val_file).astype(np.float32)
diff --git a/code/model.py b/code/model.py
index 8109a25..4ebd957 100644
--- a/code/model.py
+++ b/code/model.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 import torch
 import pytorch_lightning as pl
+from rmi import RMILoss
 from torchvision.utils import make_grid
 
 import segmentation_models_pytorch as smp
@@ -29,12 +30,12 @@ class TestModule(pl.LightningModule):
 
         # Define the model
         self.model = smp.Unet(
-            encoder_name="densenet121",  # Also consider using smaller or larger encoders
+            encoder_name="resnet34",  # Also consider using smaller or larger encoders
             encoder_weights= "imagenet",  # Do the pretrained weights help? Try with or without
             in_channels=1,  # We use 1 chanel transmittance as input
             classes=1,  # classes == output channels. We use one output channel for cyto data
         )
-        self.loss_f = torch.nn.MSELoss()
+        self.loss_f = RMILoss(with_logits=True) #torch.nn.MSELoss()
 
     def forward(self, x):
         x = self.model(x)
-- 
GitLab