Source code for atomai.predictors.epredictor

"""
epredictor.py
=============

Module for predicting with ensembles of pre-trained neural networks

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import Dict, Tuple, Type, Union

import numpy as np
import torch
from torch.nn.functional import softmax

from ..utils import (get_downsample_factor, torch_format_image,
                     torch_format_spectra, cluster_coord)
from .predictor import BasePredictor, Locator


class EnsemblePredictor(BasePredictor):
    """
    Prediction with ensemble of models

    Args:
        skeleton: Model skeleton (can be with randomly initialized weights)
        ensemble: Ensemble of trained weights
        data_type: Input data type (image or spectra)
        output_type: Output data type (image or spectra)
        nb_classes: Number of classes (e.g. for semantic segmentation)
        in_dim: Input data size (for models with fully-connected layers)
        out_dim: Output data size (for models with fully-connected layers)
        **output_shape: Optionally one may specify the exact output shape
        **verbose: verbosity

    Example:

        >>> p = aoi.predictors.EnsemblePredictor(skeleton, ensemble, nb_classes=3)
        >>> nn_out_mean, nn_out_var = p.predict(expdata)
    """
    def __init__(self,
                 skeleton: Type[torch.nn.Module],
                 ensemble: Dict[int, Dict[str, torch.Tensor]],
                 data_type: str = "image",
                 output_type: str = "image",
                 nb_classes: int = None,
                 in_dim: Tuple[int] = None,
                 out_dim: Tuple[int] = None,
                 **kwargs: Union[str, Tuple[int]]) -> None:
        """
        Initialize ensemble predictor
        """
        super(EnsemblePredictor, self).__init__()
        if output_type not in ["image", "spectra"]:
            raise TypeError("Supported output types are 'image' and 'spectra'")
        inout = [data_type, output_type]
        inout_d = not all([in_dim, out_dim])
        if inout in (["image", "spectra"], ["spectra", "image"]) and inout_d:
            raise TypeError(
                "Specify input (in_dim) & output (out_dim) dimensions")
        self.device = "cpu"
        if kwargs.get("use_gpu", True) and torch.cuda.is_available():
            if kwargs.get("device") is None:
                self.device = "cuda"
            else:
                self.device = kwargs.get("device")
        self.model = skeleton
        self.ensemble = ensemble
        self.data_type = data_type
        self.output_type = output_type
        self.nb_classes = nb_classes
        self.in_dim, self.out_dim = in_dim, out_dim
        self.downsample_factor = None
        self.logits = kwargs.get("logits", True)
        self.output_shape = kwargs.get("output_shape")
        # Always define both verbosity flags so that verbose=0 does not leave
        # them unset (ensemble_batch_predict reads self.everbose)
        verbose = kwargs.get("verbose", 1)
        self.everbose = verbose > 0
        self.verbose = verbose > 1

    def _set_output_shape(self, data: np.ndarray) -> None:
        """
        Sets output shape
        """
        if self.data_type == self.output_type == "image":
            if self.nb_classes:  # semantic segmentation
                out_shape = (len(data), self.nb_classes, *data.shape[2:])
            else:  # image cleaning
                out_shape = (len(data), 1, *data.shape[2:])
        elif self.data_type == "spectra" and self.output_type == "image":
            if self.nb_classes:
                out_shape = (len(data), self.nb_classes, *self.out_dim)
            else:
                out_shape = (len(data), 1, *self.out_dim)
        elif self.data_type == "image" and self.output_type == "spectra":
            out_shape = (len(data), 1, *self.out_dim)
        elif self.data_type == self.output_type == "spectra":
            out_shape = (len(data), 1, *data.shape[2:])
        else:
            raise TypeError("Data not understood")
        self.output_shape = out_shape
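
    # Illustration (hypothetical numbers): for a batch of four 256x256 images
    # fed to a three-class segmentation model (nb_classes=3),
    # _set_output_shape sets self.output_shape = (4, 3, 256, 256),
    # i.e. (n_samples, n_channels, *spatial_dims) in channel-first order.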

    def preprocess(self, data: np.ndarray,
                   norm: bool = True) -> torch.Tensor:
        """
        Preprocesses data depending on its type (image or spectra)
        """
        if self.data_type == "image":
            if data.ndim == 2:
                data = data[np.newaxis, ...]
            data = torch_format_image(data, norm)
        elif self.data_type == "spectra":
            if data.ndim == 1:
                data = data[np.newaxis, ...]
            data = torch_format_spectra(data, norm)
        return data
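
    # A minimal usage sketch (assuming torch_format_image adds a channel axis
    # and returns a float tensor):
    #
    #   img = np.random.randn(256, 256)   # single 2D image
    #   tensor = p.preprocess(img)        # -> torch.Tensor of shape (1, 1, 256, 256)
    #
    # where p is an initialized EnsemblePredictor with data_type="image".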

    def ensemble_forward_(self, data: torch.Tensor,
                          out_shape: Tuple[int]
                          ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Computes mean and variance of predictions with ensemble models
        """
        eprediction = self.ensemble_forward(data, out_shape)
        return np.mean(eprediction, axis=0), np.var(eprediction, axis=0)
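
    # ensemble_forward_ reduces over the model axis (axis=0): for an ensemble
    # of M models, the returned arrays are the predictive mean
    # mu(x) = (1/M) * sum_m f_m(x) and the across-model variance, which is
    # commonly read as an (epistemic) uncertainty estimate.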

    def ensemble_forward(self, data: torch.Tensor,
                         out_shape: Tuple[int],
                         num_batches: int = 1) -> np.ndarray:
        """
        Computes prediction with ensemble models.
        Returns ALL calculated predictions (n_models * n_samples)
        """
        eprediction = np.zeros((len(self.ensemble), *out_shape))
        for i, m in enumerate(self.ensemble.values()):
            self.model.load_state_dict(m)
            self._model2device()
            if num_batches > 1:
                prob = self.batch_predict(data, out_shape, num_batches)
            else:
                prob = self.forward_(data)
            nclasses = 0 if not self.nb_classes else self.nb_classes
            if self.logits:
                if nclasses > 1:
                    prob = softmax(prob, dim=1)
                elif self.nb_classes == 1:
                    prob = torch.sigmoid(prob)
            else:
                if nclasses > 1:
                    prob = torch.exp(prob)
            eprediction[i] = prob.cpu().numpy()
        return eprediction
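
    # Note on output activations: with logits=True (the default), multi-class
    # outputs are passed through softmax and single-class outputs through
    # sigmoid; with logits=False, multi-class models are presumably assumed to
    # end in log-softmax, so torch.exp recovers probabilities.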

    def ensemble_batch_predict(self, data: np.ndarray,
                               num_batches: int = 10
                               ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Batch-by-batch prediction with ensemble models
        """
        batch_size = len(data) // num_batches
        if batch_size < 1:
            num_batches = batch_size = 1
        prediction_mean = np.zeros(shape=self.output_shape)
        prediction_var = np.zeros(shape=self.output_shape)
        for i in range(num_batches):
            if self.everbose:
                print("\rBatch {}/{}".format(i+1, num_batches), end="")
            data_i = data[i*batch_size:(i+1)*batch_size]
            pred_mean, pred_var = self.ensemble_forward_(
                data_i, (batch_size, *self.output_shape[1:]))
            prediction_mean[i*batch_size:(i+1)*batch_size] = pred_mean
            prediction_var[i*batch_size:(i+1)*batch_size] = pred_var
        # The remainder (when len(data) is not divisible by num_batches)
        # is handled once, after the main loop
        data_i = data[(i+1)*batch_size:]
        if len(data_i) > 0:
            pred_mean, pred_var = self.ensemble_forward_(
                data_i, (len(data_i), *self.output_shape[1:]))
            prediction_mean[(i+1)*batch_size:] = pred_mean
            prediction_var[(i+1)*batch_size:] = pred_var
        return prediction_mean, prediction_var
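
    # Sketch of the batching arithmetic (hypothetical numbers): with
    # len(data) == 103 and num_batches == 10, batch_size == 10, the loop
    # covers samples 0..99, and the post-loop remainder block handles the
    # last 3 samples.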

    def predict(self, data: np.ndarray,
                num_batches: int = 10,
                format_out: str = "channel_last",
                norm: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Predicts mean and variance for all the data points
        with ensemble of models

        Args:
            data: input data
            num_batches:
                number of batches for batch-by-batch prediction (Default: 10)
            format_out:
                'channel_last' or 'channel_first' dimension order in the output
            norm: Normalize input data to (0, 1)

        Returns:
            Tuple of numpy arrays with predicted mean and variance
        """
        if format_out not in ["channel_first", "channel_last"]:
            raise ValueError(
                "Specify channel_last or channel_first output format")
        data = self.preprocess(data, norm)
        if not self.output_shape:
            self._set_output_shape(data)
        if (self.data_type == self.output_type == "image"
                and self.downsample_factor is None):
            self.downsample_factor = get_downsample_factor(self.model)
        prediction = self.ensemble_batch_predict(data, num_batches)
        prediction_mean, prediction_var = prediction
        # channel transpose
        if format_out == "channel_last":
            size_dim = np.arange(prediction_mean.ndim - 2) + 2
            c_tr = (0, *size_dim, 1)
        elif format_out == "channel_first":
            c_tr = np.arange(prediction_mean.ndim)
        return prediction_mean.transpose(c_tr), prediction_var.transpose(c_tr)
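
# A minimal end-to-end sketch (hypothetical file name and shapes; assumes
# `skeleton` is a segmentation model and `ensemble` is a dict of state_dicts,
# e.g. produced by atomai's ensemble trainers):
#
#   expdata = np.load("expdata.npy")              # e.g. (10, 512, 512) image stack
#   p = EnsemblePredictor(skeleton, ensemble, nb_classes=3)
#   nn_out_mean, nn_out_var = p.predict(expdata)  # (10, 512, 512, 3) each,
#                                                 # channel-last by default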

def ensemble_locate(nn_output_ensemble: np.ndarray,
                    **kwargs: Dict
                    ) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
    """
    Finds coordinates for each ensemble prediction

    Args:
        nn_output_ensemble (numpy array):
            5D numpy array with ensemble predictions
        **eps (float):
            DBSCAN epsilon for clustering coordinates
        **threshold (float):
            threshold value for atomnet.locator

    Returns:
        Mean and variance for every detected atom/defect/particle coordinate
    """
    eps = kwargs.get("eps", 0.5)
    thresh = kwargs.get("threshold", 0.5)
    coord_mean_all = {}
    coord_var_all = {}
    for i in range(nn_output_ensemble.shape[1]):
        coordinates = {}
        nn_output = nn_output_ensemble[:, i]
        for i2, img in enumerate(nn_output):
            coord = Locator(thresh).run(img[None, ...])
            coordinates[i2] = coord[0]
        _, coord_mean, coord_var = cluster_coord(coordinates, eps)
        coord_mean_all[i] = coord_mean
        coord_var_all[i] = coord_var
    return coord_mean_all, coord_var_all
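
# Usage sketch for ensemble_locate (hypothetical shapes): stack the raw
# per-model predictions into a 5D array, e.g.
# (n_models, n_samples, height, width, channels), then:
#
#   coord_mean, coord_var = ensemble_locate(nn_out_all, eps=0.5, threshold=0.5)
#
# coord_mean / coord_var are dictionaries keyed by image index, each holding
# the clustered mean coordinates and their across-model variance.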