Loss Functions and Metrics

Losses for imbalanced classes

class atomai.losses_metrics.focal_loss(alpha=0.5, gamma=2, with_logits=True)

Loss function for classification tasks with large data imbalance. Focal loss (FL) is defined as: FL(p_t) = -alpha*((1-p_t)^gamma)*log(p_t), where p_t is the model's estimated probability for the true class, so that -log(p_t) is the cross-entropy loss for binary classification. For more details, see https://arxiv.org/abs/1708.02002.

Parameters
  • alpha (float) – “balance” coefficient,

  • gamma (float) – “focusing” parameter (>=0),

  • with_logits (bool) – indicates whether the predictions are raw logits, i.e. the sigmoid operation was not applied at the end of a neural network’s forward pass.
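To make the roles of alpha and gamma concrete, here is a minimal plain-Python sketch of the focal loss formula for a single binary prediction. It is an illustration only and not AtomAI's implementation, which operates on PyTorch tensors; the function name `binary_focal_loss_sketch` and its arguments `p` and `y` are hypothetical, and the input is assumed to be a probability (sigmoid already applied).

```python
import math


def binary_focal_loss_sketch(p, y, alpha=0.5, gamma=2.0):
    """Focal loss for one predicted probability p in (0, 1) and a label y in {0, 1}.

    Hypothetical helper for illustration; not part of atomai.
    """
    # p_t is the probability the model assigns to the true class
    p_t = p if y == 1 else 1.0 - p
    # -log(p_t) is the ordinary binary cross-entropy for this sample
    cross_entropy = -math.log(p_t)
    # (1 - p_t)^gamma down-weights samples that are already classified well
    return alpha * (1.0 - p_t) ** gamma * cross_entropy


easy = binary_focal_loss_sketch(0.9, 1)  # confident and correct
hard = binary_focal_loss_sketch(0.1, 1)  # confident and wrong
print(f"easy sample: {easy:.6f}, hard sample: {hard:.6f}")
```

With gamma=0 and alpha=1 the expression reduces to the plain cross-entropy; increasing gamma shrinks the contribution of well-classified samples, which is what lets the rare class dominate the gradient.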

class atomai.losses_metrics.dice_loss(eps=1e-07)

Computes the Sørensen–Dice loss. Adapted with changes from https://github.com/kevinzakka/pytorch-goodies
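As a rough illustration, the Sørensen–Dice coefficient of two masks is 2|A∩B| / (|A| + |B|), and the loss is one minus that value. The sketch below assumes flat lists of values in [0, 1] and assumes that eps is added to the denominator for numerical stability; the function name `dice_loss_sketch` is hypothetical, and the library's tensor-based implementation may differ in detail.

```python
def dice_loss_sketch(true, pred, eps=1e-7):
    """1 - Dice coefficient for two equal-length sequences of values in [0, 1].

    Hypothetical helper for illustration; not part of atomai.
    """
    # Soft intersection: sum of element-wise products
    intersection = sum(t * p for t, p in zip(true, pred))
    # Total "mass" of both masks
    cardinality = sum(true) + sum(pred)
    # eps (assumed placement) guards against division by zero for empty masks
    dice = 2.0 * intersection / (cardinality + eps)
    return 1.0 - dice


perfect = dice_loss_sketch([1, 0, 1, 0], [1, 0, 1, 0])   # close to 0
disjoint = dice_loss_sketch([1, 0], [0, 1])              # no overlap
```

Because the score is normalized by the size of the masks rather than by the number of pixels, a small foreground class is not drowned out by a large background.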

Metrics for imbalanced classes

class atomai.losses_metrics.IoU(true, pred, activation=True, thresh=0.5)

Computes the mean Intersection over Union (IoU). Adapted with changes from https://github.com/kevinzakka/pytorch-goodies

Parameters
  • true (Tensor) – labels (aka ground truth)

  • pred (Tensor) – model predictions

  • activation (bool) – applies softmax/sigmoid to predictions

  • thresh (float) – image binary threshold level for predictions

classmethod threshold_(xarr, thresh)

Thresholds image data

compute_hist(true, pred)

Computes histogram for a single true-pred pair

evaluate()

Computes mean IoU score for a batch
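The threshold-then-score workflow described above can be sketched in plain Python for the binary case. This is a simplified illustration under stated assumptions, not the library's code: the helper names are hypothetical, predictions are assumed to be probabilities already, values strictly greater than thresh are assumed to count as foreground, and the per-pair score is computed directly from the overlap rather than from a histogram.

```python
def threshold_sketch(values, thresh=0.5):
    """Binarize predictions (assumption: values above thresh become 1)."""
    return [1 if v > thresh else 0 for v in values]


def iou_sketch(true, pred, thresh=0.5, eps=1e-7):
    """IoU for one true-pred pair given as flat sequences."""
    pred_bin = threshold_sketch(pred, thresh)
    intersection = sum(t * p for t, p in zip(true, pred_bin))
    union = sum(true) + sum(pred_bin) - intersection
    # eps (an assumption here) avoids division by zero when both masks are empty
    return intersection / (union + eps)


def mean_iou_sketch(batch_true, batch_pred, thresh=0.5):
    """Average the per-pair IoU over a batch."""
    scores = [iou_sketch(t, p, thresh) for t, p in zip(batch_true, batch_pred)]
    return sum(scores) / len(scores)


batch_true = [[1, 1, 0, 0], [1, 0, 0, 0]]
batch_pred = [[0.9, 0.2, 0.8, 0.1], [0.7, 0.1, 0.3, 0.2]]
score = mean_iou_sketch(batch_true, batch_pred)
```

In the first pair one of three foreground pixels in the union is shared, and in the second pair the masks coincide after thresholding, so the batch score is the average of those two per-pair values.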