Source code for atomai.losses_metrics.losses


"""Custom PyTorch loss functions."""
from typing import List, Optional, Type, Union

import torch
import torch.nn.functional as F

class focal_loss(torch.nn.Module):
    """
    Loss function for classification tasks with large data imbalance.

    Focal loss (FL) is defined as

        FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t),

    where p_t is the probability the model assigns to the true class of an
    element, so that -log(p_t) is that element's binary cross-entropy.
    For more details, see https://arxiv.org/abs/1708.02002

    Args:
        alpha (float): "balance" coefficient (a single global multiplier),
        gamma (float): "focusing" parameter (>=0),
        with_logits (bool): True when ``prediction`` holds raw logits
            (no sigmoid applied at the end of the network's forward path);
            False when it already holds probabilities in [0, 1].
    """
    def __init__(self,
                 alpha: float = 0.5,
                 gamma: float = 2,
                 with_logits: bool = True) -> None:
        """
        Parameter initialization
        """
        super(focal_loss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = with_logits

    def forward(self,
                prediction: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Calculates loss

        Args:
            prediction: model output (logits or probabilities, depending
                on ``with_logits``)
            labels: float targets with the same shape as ``prediction``

        Returns:
            Scalar tensor: the focal loss averaged over all elements.
        """
        # The cross-entropy must stay un-reduced here: the modulating factor
        # (1 - p_t)^gamma has to be applied to every element separately,
        # otherwise easy and hard elements receive the same weight.
        if self.logits:
            ce_loss = F.binary_cross_entropy_with_logits(
                prediction, labels, reduction='none')
        else:
            ce_loss = F.binary_cross_entropy(
                prediction, labels, reduction='none')
        pt = torch.exp(-ce_loss)  # p_t recovered from CE = -log(p_t)
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return f_loss.mean()
class dice_loss(torch.nn.Module):
    """
    Computes the Sørensen–Dice loss.
    Adapted with changes from https://github.com/kevinzakka/pytorch-goodies

    Args:
        eps (float): small constant added to the denominator to avoid
            division by zero
    """
    def __init__(self, eps: float = 1e-7):
        """
        Parameter initialization
        """
        super(dice_loss, self).__init__()
        self.eps = eps

    def forward(self, logits: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate loss

        Args:
            logits: raw network output; the class axis is dim 1. The code
                permutes the one-hot labels as a 4D tensor, i.e. it expects
                (batch, classes, height, width) logits.
            labels: integer-valued class labels; indexed as
                (batch, 1, height, width) when there is a single output
                channel and as (batch, height, width) otherwise.

        Returns:
            Scalar tensor equal to 1 - mean Dice coefficient over classes.
        """
        num_classes = logits.shape[1]
        if num_classes == 1:
            # Binary case: build explicit foreground/background channels
            identity = torch.eye(num_classes + 1, device=labels.device)
            true_1_hot = identity[labels.squeeze(1).long()]
            true_1_hot = true_1_hot.permute(0, 3, 1, 2).contiguous().float()
            true_1_hot_f = true_1_hot[:, 0:1, :, :]
            true_1_hot_s = true_1_hot[:, 1:2, :, :]
            # Foreground first, to match the [pos_prob, neg_prob] order below
            true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
            pos_prob = torch.sigmoid(logits)
            neg_prob = 1 - pos_prob
            probas = torch.cat([pos_prob, neg_prob], dim=1)
        else:
            identity = torch.eye(num_classes, device=labels.device)
            true_1_hot = identity[labels.long()]
            true_1_hot = true_1_hot.permute(0, 3, 1, 2).contiguous().float()
            probas = F.softmax(logits, dim=1)
        true_1_hot = true_1_hot.to(device=logits.device, dtype=logits.dtype)
        # Reduce over the batch and every spatial axis, keeping the class
        # axis. The rank is taken from `probas` because multiclass labels
        # have no channel axis and would leave one spatial axis unreduced.
        dims = (0,) + tuple(range(2, probas.ndimension()))
        intersection = torch.sum(probas * true_1_hot, dims)
        cardinality = torch.sum(probas + true_1_hot, dims)
        dice_coeff = (2. * intersection / (cardinality + self.eps)).mean()
        return 1 - dice_coeff
class MultiTaskLoss(torch.nn.Module):
    """
    Multi-task loss for handling loss computation in multi-task learning.

    Args:
        num_tasks (int): The number of tasks.
        loss_fn (Type[nn.Module], optional): The loss function class to use
            for each task. It is instantiated without arguments, once per
            task. Default is nn.NLLLoss.
        weights (List[float], optional): The weights for each task's loss.
            Default is None, which sets equal weights for all tasks.
    """
    def __init__(self,
                 num_tasks: int,
                 loss_fn: Type[torch.nn.Module] = torch.nn.NLLLoss,
                 weights: Optional[List[float]] = None):
        super(MultiTaskLoss, self).__init__()

        # One loss instance per task. A ModuleList (rather than a plain
        # list) registers them as submodules, so .to()/.state_dict() etc.
        # propagate to the per-task losses.
        self.loss_functions = torch.nn.ModuleList(
            [loss_fn() for _ in range(num_tasks)])

        # Set the weights for each task
        if weights is not None:
            assert len(weights) == num_tasks, \
                "The length of weights must match num_tasks"
            self.weights = weights
        else:
            self.weights = [1.0] * num_tasks

    def forward(self,
                outputs: List[torch.Tensor],
                labels: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the combined loss for multiple tasks.

        Args:
            outputs (List[torch.Tensor]): A list of output tensors from the
                multi-task model, one for each task.
            labels (List[torch.Tensor]): A list of ground truth label
                tensors, one for each task.

        Returns:
            torch.Tensor: The total combined loss.
        """
        # Calculate the weighted individual losses for each task
        individual_losses = [
            weight * loss_fn(output, label)
            for weight, loss_fn, output, label
            in zip(self.weights, self.loss_functions, outputs, labels)
        ]
        # Combine the individual losses to get the total loss
        return sum(individual_losses)


def select_loss(loss: Union[str, torch.nn.Module],
                nb_classes: Optional[Union[int, List[int]]] = None,
                **kwargs):
    """
    Selects loss for DCNN model training

    Args:
        loss: One of 'dice', 'focal', 'ce', 'nll', 'mse', 'multitask_nll',
            'multitask_ce', or a custom callable loss, which is returned
            unchanged.
        nb_classes: Number of classes. Required for 'ce' (an int) and for
            the multitask losses (a list with one entry per task).
        **kwargs: Extra arguments forwarded to MultiTaskLoss
            (e.g. ``weights``).

    Returns:
        The loss object (criterion) to be used for training.

    Raises:
        ValueError: if ``nb_classes`` is missing for 'ce'/multitask losses
            or is not a list for a multitask loss.
        NotImplementedError: if ``loss`` is not a recognized name or
            a callable.
    """
    multitask_names = ('multitask', 'multitask_nll', 'multitask_ce')
    if (loss == 'ce' or loss in multitask_names) and nb_classes is None:
        raise ValueError("For cross-entropy loss function, you must" +
                         " specify the number of classes")
    if loss in multitask_names and not isinstance(nb_classes, list):
        raise ValueError("Provide number of classes for each task as a list")
    if loss == 'dice':
        criterion = dice_loss()
    elif loss == 'focal':
        criterion = focal_loss()
    elif loss == 'ce' and nb_classes == 1:
        criterion = torch.nn.BCEWithLogitsLoss()
    elif loss == 'ce' and nb_classes > 1:
        criterion = torch.nn.CrossEntropyLoss()
    elif loss == 'nll':
        criterion = torch.nn.NLLLoss()
    elif loss == 'multitask_nll':
        criterion = MultiTaskLoss(len(nb_classes), **kwargs)
    elif loss == 'multitask_ce':
        criterion = MultiTaskLoss(
            len(nb_classes), torch.nn.CrossEntropyLoss, **kwargs)
    elif loss == 'mse':
        criterion = torch.nn.MSELoss()
    elif callable(loss):
        criterion = loss
    else:
        raise NotImplementedError(
            "Select Dice loss ('dice'), focal loss ('focal') "
            " cross-entropy loss ('ce'), means-squared error ('mse'),"
            " multitask loss (multitask_nll and multitask_ce)"
            " or pass your custom loss function"
        )
    return criterion