# data.py
from typing import Any, List
from numpy import random

import h5py
import os
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pytorch_lightning as pl

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2 as ToTensor

# Distributed
from atlasmpi import MPI

comm = MPI.COMM_WORLD

# Example data directories:
#   pli:                        '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli'
#   NTransmittance (pli_path):  '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
#   stained (cyto_path):        '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'



class TestSampler(Dataset):
    """Draws a random image and a random crop from it at each request."""

    def __init__(self, pli_files_list, cyto_files_list, transforms, crop_size, dataset_size):
        # crop_size is the size before rotation and center crop, i.e. patch_size * sqrt(2)
        # dataset_size defines the number of patches drawn per epoch. As we draw
        # arbitrarily many random patches, the epoch length has to be set manually.
        super().__init__()
        # The list of PLI files has to be in the same order as the list of cyto files,
        # so that index i refers to the same aligned section in both
        self.list_of_pli = pli_files_list
        self.list_of_cyto = cyto_files_list
        self.n_images = len(self.list_of_pli)
        self.transforms = transforms
        self.crop_size = crop_size
        self.dataset_size = dataset_size

    def __getitem__(self, ix):
        # ix is ignored: every request draws a fresh random image and location
        i = random.randint(self.n_images)
        pli_image = self.list_of_pli[i]
        cyto_image = self.list_of_cyto[i]

        # Generate a random patch location from the image
        x = random.randint(pli_image.shape[1] - self.crop_size)
        y = random.randint(pli_image.shape[0] - self.crop_size)

        # Get crops at the x, y location with size crop_size x crop_size
        random_crop_pli = pli_image[y:y + self.crop_size, x:x + self.crop_size]
        random_crop_cyto = cyto_image[y:y + self.crop_size, x:x + self.crop_size]

        # Apply the transforms to pli and cyto simultaneously
        sample = self.transforms(image=random_crop_pli, cyto_image=random_crop_cyto)
        sample["pli_image"] = sample.pop("image")  # rename the default target
        return sample

    def __len__(self):
        return self.dataset_size
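
# Example usage of TestSampler (a sketch; the `dummy` arrays below are hypothetical
# stand-ins for the real loaded sections):
#
#   dummy = [np.random.rand(1024, 1024).astype(np.float32) for _ in range(2)]
#   tf = A.Compose(
#       [A.CenterCrop(height=256, width=256), ToTensor()],
#       additional_targets={'cyto_image': 'image'},
#   )
#   sampler = TestSampler(dummy, dummy, tf, crop_size=362, dataset_size=16)
#   sample = sampler[0]  # dict with 'pli_image' and 'cyto_image' tensors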


class TestDataModule(pl.LightningDataModule):

    def __init__(
            self,
            train_size: int = 2 ** 10,
            val_size: int = 2 ** 6,
            batch_size: int = 8,
            crop_size: int = 362,  # approx 256 * sqrt(2)
            patch_size: int = 256,
            num_workers: int = 4,
            *args: Any,
            **kwargs: Any
    ):
        super().__init__(*args, **kwargs)
        self.train_size = train_size
        self.val_size = val_size
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.patch_size = patch_size
        self.num_workers = num_workers
        self.pli_train = None
        self.cyto_train = None
        self.pli_val = None
        self.cyto_val = None

    # Runs only on the main process
    # Downloading data and general preparation would go here
    def prepare_data(self):
        # Nothing to download: the HDF5 sections are read directly from disk in setup()
        pass

    # Runs on every single instance (on every process)
    def setup(self, stage=None):
        print("Setup Data Module")

        # This is for multi GPU processing
        rank = comm.Get_rank()
        size = comm.size

        # Resolve the data paths and HDF5 file lists up front so that both the
        # train and the validation branch below can use them

        # For JSC training:
        # pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
        # cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'

        # For local machine training:
        pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
        cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'

        hdf5_suffixes = ('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5')
        # Sort both lists so that index i refers to the same aligned section in each
        pli_files_list = sorted(f for f in os.listdir(pli_path) if f.endswith(hdf5_suffixes))
        cyto_files_list = sorted(f for f in os.listdir(cyto_path) if f.endswith(hdf5_suffixes))

        # Load data from disk
        if self.pli_train is None or self.cyto_train is None:
            print(f"Rank {rank}/{size} preparing train data")

            # Load the PLI and cyto train data as lists of numpy arrays (List[np.ndarray]),
            # reading the pyramid/00 dataset of the first four files
            self.pli_train = []
            self.cyto_train = []

            for i in range(4):
                with h5py.File(os.path.join(pli_path, pli_files_list[i]), 'r') as f:
                    # PLI transmittance is kept at its native scale, only centered around 0
                    self.pli_train.append(np.asarray(f['pyramid/00']).astype(np.float32) - 0.5)

                with h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r') as f:
                    # The 8-bit stained sections are scaled to [0, 1] and centered around 0
                    self.cyto_train.append(np.asarray(f['pyramid/00']).astype(np.float32) / 255 - 0.5)

        else:
            print(f"Train data for rank {rank}/{size} already prepared")

        if self.pli_val is None or self.cyto_val is None:
            print(f"Rank {rank}/{size} preparing validation data")

            # Load the PLI and cyto validation data as lists of numpy arrays (List[np.ndarray])
            # The validation set must contain only unseen images, so the fifth file is held out
            self.pli_val = []
            self.cyto_val = []

            with h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r') as f:
                self.pli_val.append(np.asarray(f['pyramid/00']).astype(np.float32) - 0.5)

            with h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r') as f:
                self.cyto_val.append(np.asarray(f['pyramid/00']).astype(np.float32) / 255 - 0.5)

        else:
            print(f"Validation data for rank {rank}/{size} already prepared")

        # Augmentations for training
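        # Rotating the larger crop first and center-cropping afterwards avoids empty
        # corners: a patch_size square rotated by up to 45 degrees still fits inside
        # a crop of side patch_size * sqrt(2) (256 * 1.414 is about 362)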
        self._train_transforms = A.Compose(
            [
                A.Affine(rotate=(-180, 180)),
                A.CenterCrop(p=1, height=self.patch_size, width=self.patch_size),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                ToTensor(),
            ],
            additional_targets={'pli_image': 'image', 'cyto_image': 'image'}
        )

        # Augmentation for validation and testing
        self._test_transforms = A.Compose(
            [
                A.CenterCrop(p=1, height=self.patch_size, width=self.patch_size),
                ToTensor(),
            ],
            additional_targets={'pli_image': 'image', 'cyto_image': 'image'}
        )

        # Define the datasets for training and validation
        self.train_sampler = TestSampler(
            self.pli_train,
            self.cyto_train,
            self._train_transforms,
            self.crop_size,
            self.train_size
        )

        self.val_sampler = TestSampler(
            self.pli_val,
            self.cyto_val,
            self._test_transforms,
            self.crop_size,
            self.val_size,
        )

    def train_dataloader(self):

        def wif(worker_id):
            # Give each dataloader worker its own numpy seed, derived from the
            # per-worker torch seed, so that random crops differ across workers
            process_seed = torch.initial_seed()
            # Back out the base seed so we can use all of its bits
            base_seed = process_seed - worker_id
            ss = np.random.SeedSequence([worker_id, base_seed])
            # More than 128 bits (4 32-bit words) would be overkill
            np.random.seed(ss.generate_state(4))

        dl = DataLoader(
            self.train_sampler,
            batch_size=self.batch_size,
            shuffle=True,  # TestSampler ignores indices, so shuffling has no effect here
            num_workers=self.num_workers,
            worker_init_fn=wif
        )

        return dl

    def val_dataloader(self):

        dl = DataLoader(
            self.val_sampler,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

        return dl
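
# Example usage (a sketch; assumes the data paths in setup() exist on this machine
# and that `model` is some pl.LightningModule):
#
#   dm = TestDataModule(batch_size=8, num_workers=4)
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, datamodule=dm)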

