from typing import Any, List
from numpy import random
import h5py
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2 as ToTensor
# Distributed
from atlasmpi import MPI
comm = MPI.COMM_WORLD
# Example data directories:
#   pli (base):                '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli'
#   NTransmittance (pli_path): '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
#   stained (cyto_path):       '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
class TestSampler(Dataset):
    # Returns a random crop from a randomly drawn image at each request
    def __init__(self, pli_files_list, cyto_files_list, transforms, crop_size, dataset_size):
        # crop_size is the size before rotation and center crop, so patch_size * sqrt(2)
        # dataset_size defines the number of drawn patches per epoch. As we draw
        # arbitrarily many random patches, it has to be set manually
        super().__init__()
        # The list of pli images has to be in the same order as the list of cyto images,
        # so index i in pli corresponds to the same section in cyto
self.list_of_pli = pli_files_list
self.list_of_cyto = cyto_files_list
self.n_images = len(self.list_of_pli)
self.transforms = transforms
self.crop_size = crop_size
self.dataset_size = dataset_size
def __getitem__(self, ix):
# Get a random image
i = random.randint(self.n_images)
pli_image = self.list_of_pli[i]
cyto_image = self.list_of_cyto[i]
# Generate a random patch location from the image
x = random.randint(pli_image.shape[1] - self.crop_size)
y = random.randint(pli_image.shape[0] - self.crop_size)
# Get crops at the x, y location with size crop_size x crop_size
random_crop_pli = pli_image[y:y + self.crop_size, x:x + self.crop_size]
random_crop_cyto = cyto_image[y:y + self.crop_size, x:x + self.crop_size]
        # Apply transforms to pli and cyto simultaneously
sample = self.transforms(image=random_crop_pli, cyto_image=random_crop_cyto)
sample["pli_image"] = sample.pop("image")
return sample
def __len__(self):
return self.dataset_size
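# A minimal sanity check for TestSampler (illustration only, not part of the
# training pipeline): dummy arrays stand in for the real PLI/cyto sections and
# the transform is a stripped-down stand-in for the pipelines built in
# TestDataModule.setup().
#
#   demo_transforms = A.Compose(
#       [ToTensor()],
#       additional_targets={'pli_image': 'image', 'cyto_image': 'image'},
#   )
#   demo_sampler = TestSampler(
#       pli_files_list=[np.random.rand(512, 512).astype(np.float32)],
#       cyto_files_list=[np.random.rand(512, 512).astype(np.float32)],
#       transforms=demo_transforms,
#       crop_size=362,
#       dataset_size=4,
#   )
#   sample = demo_sampler[0]  # dict with 'pli_image' and 'cyto_image' tensors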
class TestDataModule(pl.LightningDataModule):
def __init__(
self,
train_size: int = 2 ** 10,
val_size: int = 2 ** 6,
batch_size: int = 8,
crop_size: int = 362, # approx 256 * sqrt(2)
patch_size: int = 256,
num_workers: int = 4,
*args: Any,
**kwargs: Any
):
super().__init__(*args, **kwargs)
self.train_size = train_size
self.val_size = val_size
self.batch_size = batch_size
self.crop_size = crop_size
self.patch_size = patch_size
self.num_workers = num_workers
self.pli_train = None
self.cyto_train = None
self.pli_val = None
self.cyto_val = None
    # Runs only on the main process
    # Downloading data, doing some general preparation
    def prepare_data(self):
        pass
    # Runs on every single instance (on every process)
def setup(self, stage=None):
print("Setup Data Module")
# This is for multi GPU processing
rank = comm.Get_rank()
size = comm.size
        # Load data from disk
        # Paths and file lists are needed by both the train and val branches below,
        # so they are prepared unconditionally
        # For JSC training:
        # pli_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/pli/NTransmittance'
        # cyto_path = '/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818/vervet1818-stained/data/aligned/stained'
        # For local machine training:
        pli_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/pli/NTransmittance'
        cyto_path = '/media/tushar/A2246889246861F1/Master Thesis MAIA/example-data/stained'
        hdf5_suffixes = ('.h5', '.hdf', '.h4', '.hdf4', '.he2', '.hdf5', '.he5')
        pli_files_list = sorted(file for file in os.listdir(pli_path) if file.endswith(hdf5_suffixes))
        cyto_files_list = sorted(file for file in os.listdir(cyto_path) if file.endswith(hdf5_suffixes))
        if self.pli_train is None or self.cyto_train is None:
            print(f"Rank {rank}/{size} preparing train data")
            # Load the PLI and cyto train data as lists of numpy arrays (List[np.ndarray]),
            # reading the pyramid/00 dataset from each file
            self.pli_train = []
            self.cyto_train = []
            # The first four sections are used for training
            for i in range(4):
                with h5py.File(os.path.join(pli_path, pli_files_list[i]), 'r') as h5f:
                    pli_train_file = np.asarray(h5f['pyramid/00']).astype(np.float32)
                # NTransmittance is assumed to lie in [0, 1]; center it around 0
                pli_train_file = pli_train_file - 0.5
                self.pli_train.append(pli_train_file)
            for i in range(4):
                with h5py.File(os.path.join(cyto_path, cyto_files_list[i]), 'r') as h5f:
                    cyto_train_file = np.asarray(h5f['pyramid/00']).astype(np.float32)
                # Scale 8-bit staining values to [0, 1], then center around 0
                cyto_train_file = (cyto_train_file / 255) - 0.5
                self.cyto_train.append(cyto_train_file)
        else:
            print(f"Train data for rank {rank}/{size} already prepared")
        if self.pli_val is None or self.cyto_val is None:
            print(f"Rank {rank}/{size} preparing validation data")
            # Load the PLI and cyto val data as lists of numpy arrays (List[np.ndarray]),
            # reading the pyramid/00 dataset from each file
            # The validation set contains only unseen images
            self.pli_val = []
            self.cyto_val = []
            # The fifth section is held out for validation
            with h5py.File(os.path.join(pli_path, pli_files_list[4]), 'r') as h5f:
                pli_val_file = np.asarray(h5f['pyramid/00']).astype(np.float32)
            pli_val_file = pli_val_file - 0.5
            self.pli_val.append(pli_val_file)
            with h5py.File(os.path.join(cyto_path, cyto_files_list[4]), 'r') as h5f:
                cyto_val_file = np.asarray(h5f['pyramid/00']).astype(np.float32)
            cyto_val_file = (cyto_val_file / 255) - 0.5
            self.cyto_val.append(cyto_val_file)
        else:
            print(f"Validation data for rank {rank}/{size} already prepared")
# Augmentations for training
self._train_transforms = A.Compose(
[
A.Affine(rotate=(-180, 180)),
A.CenterCrop(p=1, height=self.patch_size, width=self.patch_size),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
ToTensor(),
],
additional_targets={'pli_image': 'image', 'cyto_image': 'image'}
)
        # Augmentations for validation and testing
self._test_transforms = A.Compose(
[
A.CenterCrop(p=1, height=self.patch_size, width=self.patch_size),
ToTensor(),
],
additional_targets={'pli_image': 'image', 'cyto_image': 'image'}
)
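        # Note (illustration only; pli_patch/cyto_patch are placeholder arrays):
        # with additional_targets, albumentations samples one set of transform
        # parameters and applies it to every target, e.g.:
        #   demo = A.Compose([A.HorizontalFlip(p=0.5)],
        #                    additional_targets={'cyto_image': 'image'})
        #   out = demo(image=pli_patch, cyto_image=cyto_patch)
        #   # out['image'] and out['cyto_image'] are flipped (or not) together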
# Define the datasets for training and validation
self.train_sampler = TestSampler(
self.pli_train,
self.cyto_train,
self._train_transforms,
self.crop_size,
self.train_size
)
self.val_sampler = TestSampler(
self.pli_val,
self.cyto_val,
self._test_transforms,
self.crop_size,
self.val_size,
)
    def train_dataloader(self):
        def wif(worker_id):
            # NumPy's global RNG state is copied into every DataLoader worker,
            # so each worker has to be reseeded to draw independent random patches
            process_seed = torch.initial_seed()
            # Back out the base seed so we can use all the bits
            base_seed = process_seed - worker_id
            ss = np.random.SeedSequence([worker_id, base_seed])
            # More than 128 bits (4 32-bit words) would be overkill
            np.random.seed(ss.generate_state(4))
        dl = DataLoader(
            self.train_sampler,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            worker_init_fn=wif,
        )
        return dl
def val_dataloader(self):
dl = DataLoader(
self.val_sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
)
return dl
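# A minimal smoke test (assumption: the HDF5 paths configured in setup() exist
# on this machine). Run `python data.py` to draw a single batch.
if __name__ == "__main__":
    dm = TestDataModule(batch_size=2, num_workers=0)
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    print(batch["pli_image"].shape, batch["cyto_image"].shape)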