diff --git a/code/model_Unet.py b/code/model_Unet.py index 517183ee0d0c119c7fd3badc69c9d32c94153c14..e4dca0bfdd4e6ce4b2d2b052b7f61960f0806de6 100644 --- a/code/model_Unet.py +++ b/code/model_Unet.py @@ -37,7 +37,7 @@ class TestModule(pl.LightningModule): classes=1, # classes == output channels. We use one output channel for cyto data activation="sigmoid" ) - self.loss_f = torch.nn.MSELoss() #torch.nn.L1Loss() #RMILoss(with_logits=True) #torch.nn.MSELoss() + self.loss_f = RMILoss(with_logits=True) #torch.nn.L1Loss() #RMILoss(with_logits=True) #torch.nn.MSELoss() def forward(self, x): x = self.model(x)