Skip to content
Snippets Groups Projects
Commit 658b5009 authored by s.islam's avatar s.islam
Browse files

RMI Loss added and 3 channel code(commented/not functional) added.

parent 91657cad
No related branches found
No related tags found
No related merge requests found
......@@ -121,14 +121,15 @@ class TestDataModule(pl.LightningDataModule):
# TODO: Load the PLI and Cytp train data here as lists of numpy arrays: List[np.ndarray]
# Load the pyramid/00 per file
#For single channel
#For JSC Training.
pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
#pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
#cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
#For Local Machine Training.
#pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
#cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
pli_files_list = [file for file in os.listdir(pli_path) if
file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
pli_files_list.sort()
......@@ -145,6 +146,58 @@ class TestDataModule(pl.LightningDataModule):
pli_train_file = np.asarray(pli_train_file).astype(np.float32)
pli_train_file = pli_train_file - 0.5
self.pli_train.append(pli_train_file)
# For 3 Channel.
'''
pli_NTransmittance_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
pli_Retardation_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Retardation'
pli_Direction_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/Direction'
cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
pli_NTransmittance_files_list = [file for file in os.listdir(pli_NTransmittance_path) if
file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
pli_NTransmittance_files_list.sort()
pli_Retardation_files_list = [file for file in os.listdir(pli_Retardation_path) if
file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
pli_Retardation_files_list.sort()
pli_Direction_files_list = [file for file in os.listdir(pli_Direction_path) if
file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
pli_Direction_files_list.sort()
cyto_files_list = [file for file in os.listdir(cyto_path) if
file.endswith(('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5'))]
pli_NTransmittance_train = []
pli_Retardation_train = []
pli_Direction_train = []
for i in range(0, 4):
pli_train_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[i]), 'r')
pli_train_file = pli_train_file['pyramid/00']
pli_train_file = np.asarray(pli_train_file).astype(np.float32)
pli_train_file = pli_train_file - 0.5
pli_NTransmittance_train.append(pli_train_file)
for i in range(0, 4):
pli_train_file = h5py.File(os.path.join(pli_Retardation_path, pli_Retardation_files_list[i]), 'r')
pli_train_file = pli_train_file['pyramid/00']
pli_train_file = np.asarray(pli_train_file).astype(np.float32)
pli_train_file = pli_train_file - 0.5
pli_Retardation_train.append(pli_train_file)
for i in range(0, 4):
pli_train_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[i]), 'r')
pli_train_file = pli_train_file['pyramid/00']
pli_train_file = np.asarray(pli_train_file).astype(np.float32)
pli_train_file = (pli_train_file/255) - 0.5
pli_Direction_train.append(pli_train_file)
self.pli_train = []
self.cyto_train = []
for i in range(0,4):
self.pli_train.append(np.stack((pli_NTransmittance_train[i], pli_Retardation_train[i], pli_Direction_train[i]), axis=-1))
'''
for i in range(0,4):
cyto_train_file = h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r')
......@@ -165,12 +218,44 @@ class TestDataModule(pl.LightningDataModule):
self.pli_val = []
self.cyto_val = []
#Single Channel
pli_val_file = h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r')
pli_val_file = pli_val_file['pyramid/00']
pli_val_file = np.asarray(pli_val_file).astype(np.float32)
pli_val_file = pli_val_file - 0.5
self.pli_val.append(pli_val_file)
#3 Channels
'''
pli_NTransmittance_val = []
pli_Retardation_val = []
pli_Direction_val = []
pli_NTransmittance_val_file = h5py.File(os.path.join(pli_NTransmittance_path, pli_NTransmittance_files_list[4]), 'r')
pli_NTransmittance_val_file = pli_NTransmittance_val_file['pyramid/00']
pli_NTransmittance_val_file = np.asarray(pli_NTransmittance_val_file).astype(np.float32)
pli_NTransmittance_val_file = pli_NTransmittance_val_file - 0.5
pli_NTransmittance_val.append(pli_NTransmittance_val_file)
pli_Retardation_val_file = h5py.File(os.path.join(pli_Retardation_path, pli_Retardation_files_list[4]), 'r')
pli_Retardation_val_file = pli_Retardation_val_file['pyramid/00']
pli_Retardation_val_file = np.asarray(pli_Retardation_val_file).astype(np.float32)
pli_Retardation_val_file = pli_Retardation_val_file - 0.5
pli_Retardation_val.append(pli_Retardation_val_file)
pli_Direction_val_file = h5py.File(os.path.join(pli_Direction_path, pli_Direction_files_list[4]), 'r')
pli_Direction_val_file = pli_Direction_val_file['pyramid/00']
pli_Direction_val_file = np.asarray(pli_Direction_val_file).astype(np.float32)
pli_Direction_val_file = (pli_Direction_val_file/255) - 0.5
pli_Direction_val.append(pli_Direction_val_file)
self.pli_val = np.stack((pli_NTransmittance_val, pli_Retardation_val, pli_Direction_val), axis=-1)
'''
cyto_val_file = h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r')
cyto_val_file = cyto_val_file['pyramid/00']
cyto_val_file = np.asarray(cyto_val_file).astype(np.float32)
......
......@@ -2,6 +2,7 @@ import torch.nn as nn
from torch.nn import functional as F
import torch
import pytorch_lightning as pl
from rmi import RMILoss
from torchvision.utils import make_grid
import segmentation_models_pytorch as smp
......@@ -29,12 +30,12 @@ class TestModule(pl.LightningModule):
# Define the model
self.model = smp.Unet(
encoder_name="densenet121", # Also consider using smaller or larger encoders
encoder_name="resnet34", # Also consider using smaller or larger encoders
encoder_weights= "imagenet", # Do the pretrained weights help? Try with or without
in_channels=1, # We use 1 chanel transmittance as input
classes=1, # classes == output channels. We use one output channel for cyto data
)
self.loss_f = torch.nn.MSELoss()
self.loss_f = RMILoss(with_logits=True) #torch.nn.MSELoss()
def forward(self, x):
x = self.model(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment