diff --git a/code/data.py b/code/data.py index 3dc5ced42c2e13150ffdcedd74572eac2eb7ea89..da228d72c859870a9b2b9aa3d4aa7e6c350ff37f 100644 --- a/code/data.py +++ b/code/data.py @@ -294,10 +294,10 @@ class TestDataModule(pl.LightningDataModule): # Augmentations for training self._train_transforms = A.Compose( [ - A.Affine(rotate=(-180, 180), translate_percent=(0.1, 0.1), shear=(-30, 30), cval=0, p=0.9), + A.Affine(rotate=(-180, 180), translate_percent=(0.1, 0.1), shear=(-30, 30), cval=0, p=1), A.CenterCrop(p=1, height=self.patch_size, width=self.patch_size), - A.HorizontalFlip(p=0.9), - A.VerticalFlip(p=0.9), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), ToTensor(), ], additional_targets={'pli_image': 'image', 'cyto_image': 'image'}