diff --git a/code/model.py b/code/model.py index 1c010fdc7daf46305dec0fc29b3f8724722b1fc9..32674a304503aedc907561dc0138f87a66afcceb 100644 --- a/code/model.py +++ b/code/model.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch.nn import functional as F import torch import pytorch_lightning as pl +from torchvision.utils import make_grid import segmentation_models_pytorch as smp @@ -51,6 +52,9 @@ class TestModule(pl.LightningModule): cyto_imag_generated = self.forward(batch['pli_image']) loss = self.loss_f(cyto_imag_generated, batch['cyto_image']) self.log("val_loss", loss) + if batch_idx == 0: + grid = make_grid([batch['pli_image'][0, 0], batch['cyto_image'][0, 0], cyto_imag_generated[0, 0]]) + self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="HW") def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)