diff --git a/code/model.py b/code/model.py
index 1c010fdc7daf46305dec0fc29b3f8724722b1fc9..32674a304503aedc907561dc0138f87a66afcceb 100644
--- a/code/model.py
+++ b/code/model.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 import torch
 import pytorch_lightning as pl
+from torchvision.utils import make_grid
 
 import segmentation_models_pytorch as smp
 
@@ -51,6 +52,9 @@ class TestModule(pl.LightningModule):
         cyto_imag_generated = self.forward(batch['pli_image'])
         loss = self.loss_f(cyto_imag_generated, batch['cyto_image'])
         self.log("val_loss", loss)
+        if batch_idx == 0:
+            grid = make_grid([batch['pli_image'][0, 0], batch['cyto_image'][0, 0], cyto_imag_generated[0, 0]])
+            self.logger.experiment.add_image('Grid_images', grid, self.current_epoch, dataformats="HW")
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)