From 658b5009f9d69ac6a881d31abb5301f81fffa338 Mon Sep 17 00:00:00 2001 From: "s.islam" <s.islam@fz-juelich.de> Date: Mon, 25 Apr 2022 01:33:16 +0200 Subject: [PATCH] RMI Loss added and 3 channel code(commented/not functional) added. --- code/data.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++++--- code/model.py | 5 +-- 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/code/data.py b/code/data.py index 235cae1..40c4d76 100644 --- a/code/data.py +++ b/code/data.py @@ -121,14 +121,15 @@ class TestDataModule(pl.LightningDataModule): # TODO: Load the PLI and Cytp train data here as lists of numpy arrays: List[np.ndarray] # Load the pyramid/00 per file + #For single channel #For JSC Training. - pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance' - cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained' + #pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance' + #cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained' #For Local Machine Training. - #pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' - #cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' - + pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' + cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' + pli_files_list = [file for file in os.listdir(pli_path) if file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] pli_files_list.sort() @@ -145,6 +146,58 @@ class TestDataModule(pl.LightningDataModule): pli_train_file = np.asarray(pli_train_file).astype(np.float32) pli_train_file = pli_train_file - 0.5 self.pli_train.append(pli_train_file) + + + # For 3 Channel. + ''' + pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance' + pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation' + pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction' + cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained' + + pli_NTransmittance_files_list = [file for file in os.listdir(pli_NTransmittance_path) if + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + pli_NTransmittance_files_list.sort() + + pli_Retardation_files_list = [file for file in os.listdir(pli_Retardation_path) if + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + pli_Retardation_files_list.sort() + + pli_Direction_files_list = [file for file in os.listdir(pli_Direction_path) if + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + pli_Direction_files_list.sort() + cyto_files_list = [file for file in os.listdir(cyto_path) if + file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))] + + pli_NTransmittance_train = [] + pli_Retardation_train = [] + pli_Direction_train = [] + + for i in range(0, 4): + pli_train_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[i]), 'r') + pli_train_file = pli_train_file['pyramid/00'] + pli_train_file = np.asarray(pli_train_file).astype(np.float32) + pli_train_file = pli_train_file - 0.5 + pli_NTransmittance_train.append(pli_train_file) + for i in range(0, 4): + pli_train_file = h5py.File(os.path.join(pli_Retardation_path, pli_Retardation_files_list[i]), 'r') + pli_train_file = pli_train_file['pyramid/00'] + pli_train_file = np.asarray(pli_train_file).astype(np.float32) + pli_train_file = pli_train_file - 0.5 + pli_Retardation_train.append(pli_train_file) + for i in range(0, 4): + pli_train_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[i]), 'r') + pli_train_file = pli_train_file['pyramid/00'] + pli_train_file = np.asarray(pli_train_file).astype(np.float32) + pli_train_file = (pli_train_file/255) - 0.5 + pli_Direction_train.append(pli_train_file) + + self.pli_train = [] + self.cyto_train = [] + + for i in range(0,4): + self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i], pli_Direction_train[i]), axis=-1)) + ''' for i in range(0,4): cyto_train_file = h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r') @@ -165,12 +218,44 @@ class TestDataModule(pl.LightningDataModule): self.pli_val = [] self.cyto_val = [] + + #Single Channel pli_val_file = h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r') pli_val_file = pli_val_file['pyramid/00'] pli_val_file = np.asarray(pli_val_file).astype(np.float32) pli_val_file = pli_val_file - 0.5 self.pli_val.append(pli_val_file) + #3 Channels + ''' + pli_NTransmittance_val = [] + pli_Retardation_val = [] + pli_Direction_val = [] + + + pli_NTransmittance_val_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r') + pli_NTransmittance_val_file = pli_NTransmittance_val_file['pyramid/00'] + pli_NTransmittance_val_file = np.asarray(pli_NTransmittance_val_file).astype(np.float32) + pli_NTransmittance_val_file = pli_NTransmittance_val_file - 0.5 + pli_NTransmittance_val.append(pli_NTransmittance_val_file) + + pli_Retardation_val_file = h5py.File(os.path.join(pli_Retardation_path, pli_Retardation_files_list[4]), 'r') + pli_Retardation_val_file = pli_Retardation_val_file['pyramid/00'] + pli_Retardation_val_file = np.asarray(pli_Retardation_val_file).astype(np.float32) + pli_Retardation_val_file = pli_Retardation_val_file - 0.5 + pli_Retardation_val.append(pli_Retardation_val_file) + + pli_Direction_val_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[4]), 'r') + pli_Direction_val_file = pli_Direction_val_file['pyramid/00'] + pli_Direction_val_file = np.asarray(pli_Direction_val_file).astype(np.float32) + pli_Direction_val_file = (pli_Direction_val_file/255) - 0.5 + pli_Direction_val.append(pli_Direction_val_file) + + self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val, pli_Direction_val), axis=-1) + ''' + + + cyto_val_file = h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r') cyto_val_file = cyto_val_file['pyramid/00'] cyto_val_file = np.asarray(cyto_val_file).astype(np.float32) diff --git a/code/model.py b/code/model.py index 8109a25..4ebd957 100644 --- a/code/model.py +++ b/code/model.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch.nn import functional as F import torch import pytorch_lightning as pl +from rmi import RMILoss from torchvision.utils import make_grid import segmentation_models_pytorch as smp @@ -29,12 +30,12 @@ class TestModule(pl.LightningModule): # Define the model self.model = smp.Unet( - encoder_name="densenet121", # Also consider using smaller or larger encoders + encoder_name="resnet34", # Also consider using smaller or larger encoders encoder_weights= "imagenet", # Do the pretrained weights help? Try with or without in_channels=1, # We use 1 chanel transmittance as input classes=1, # classes == output channels. We use one output channel for cyto data ) - self.loss_f = torch.nn.MSELoss() + self.loss_f = RMILoss(with_logits=True) #torch.nn.MSELoss() def forward(self, x): x = self.model(x) -- GitLab