"""
Source code for atomai.predictors.predictor

Module for making predictions with pre-trained neural networks,
including semantic segmentation models and im2spec models.

Created by Maxim Ziatdinov
"""

import time
from typing import Dict, List, Tuple, Type, Union

import numpy as np
import torch
import torch.nn.functional as F
from atomai.utils import (cv_thresh, find_com, get_downsample_factor,
                          get_nb_classes, img_pad, img_resize, peak_refinement,
                          set_train_rng, torch_format_image,
                          torch_format_spectra)

class BasePredictor:
    """
    Base predictor class

    Stores a trained model, selects the device it runs on and provides
    batch-by-batch inference that subclasses reuse.
    """

    def __init__(self,
                 model: Type[torch.nn.Module] = None,
                 use_gpu: bool = False,
                 **kwargs: Union[bool, str]) -> None:
        """
        Initialize predictor

        Args:
            model: trained pytorch model
            use_gpu: Use GPU accelerator (Default: False)
            **device: CUDA device, e.g. 'cuda:0'
            **verbose: Print batch progress during prediction
                (Default: False)
        """
        self.model = model
        self.device = "cpu"
        # A GPU is used only when it is both requested and available
        if use_gpu and torch.cuda.is_available():
            device = kwargs.get("device")
            self.device = "cuda" if device is None else device
        if self.model is not None:
            self.model.to(self.device)
        self.verbose = kwargs.get("verbose", False)

    def preprocess(self,
                   data: Union[torch.Tensor, np.ndarray]
                   ) -> torch.Tensor:
        """
        Preprocess input data

        Numpy arrays are converted to float tensors; any other input
        is returned unchanged.
        """
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data).float()
        return data

    def _model2device(self, device: str = None) -> None:
        """
        Moves the model to *device* (defaults to the predictor's device)
        """
        if device is None:
            device = self.device
        self.model.to(device)

    def _data2device(self,
                     data: torch.Tensor,
                     device: str = None) -> torch.Tensor:
        """
        Returns *data* placed on *device* (defaults to the predictor's device)
        """
        if device is None:
            device = self.device
        data = data.to(device)
        return data

    def forward_(self, xnew: torch.Tensor) -> torch.Tensor:
        """
        Pass data through a trained neural network
        """
        self.model.eval()
        with torch.no_grad():
            out = self.model(xnew.to(self.device))
        return out

    def batch_predict(self,
                      data: torch.Tensor,
                      out_shape: Tuple[int],
                      num_batches: int) -> torch.Tensor:
        """
        Make a prediction batch-by-batch (for larger datasets)

        Args:
            data: Input samples (first dimension is the sample dimension)
            out_shape: Full shape of the output, including the sample dimension
            num_batches: Number of equally sized batches; samples left over
                after the integer division are processed in one extra pass.
                Values below 1 are treated as 1.

        Returns:
            CPU tensor of shape *out_shape* with the model predictions
        """
        prediction_all = torch.zeros(out_shape)
        n_samples = len(data)
        if n_samples == 0:
            # Nothing to predict; avoid calling the model on an empty batch
            return prediction_all
        num_batches = max(int(num_batches), 1)
        batch_size = n_samples // num_batches
        if batch_size < 1:
            # Fewer samples than batches: fall back to one sample per batch
            num_batches = batch_size = 1
        for i in range(num_batches):
            if self.verbose:
                print("\rBatch {}/{}".format(i+1, num_batches), end="")
            start, stop = i * batch_size, (i + 1) * batch_size
            # We put predictions on cpu since the major point of batch-by-batch
            # prediction is to not run out of the GPU memory
            prediction_all[start:stop] = self.forward_(data[start:stop]).cpu()
        leftover = num_batches * batch_size
        if leftover < n_samples:
            prediction_all[leftover:] = self.forward_(data[leftover:]).cpu()
        return prediction_all

    def predict(self,
                data: torch.Tensor,
                out_shape: Tuple[int] = None,
                num_batches: int = 1) -> torch.Tensor:
        """
        Make a prediction on the new data with a trained model

        Args:
            data: Input samples (numpy array or torch tensor)
            out_shape: Shape of a single output sample. When omitted,
                the output is assumed to have the same shape as the input.
            num_batches: Number of batches for batch-by-batch prediction
        """
        if out_shape is None:
            out_shape = data.shape
        else:
            out_shape = (data.shape[0], *out_shape)
        data = self.preprocess(data)
        prediction = self.batch_predict(data, out_shape, num_batches)
        return prediction
class SegPredictor(BasePredictor):
    """
    Prediction with a trained fully convolutional neural network

    Args:
        trained_model:
            Trained pytorch model (skeleton+weights)
        refine:
            Atomic positions refinement with 2d Gaussian peak fitting
        resize:
            Target dimensions for optional image(s) resizing
        use_gpu:
            Use gpu device for inference
        logits:
            Indicates that the image data is passed through
            a softmax/sigmoid layer when set to False
            (logits=True for AtomAI models)
        **thresh (float):
            value between 0 and 1 for thresholding the NN output
            (Default: 0.5)
        **d (int):
            half-side of a square around each atomic position used
            for refinement with 2d Gaussian peak fitting. Defaults to 1/4
            of average nearest neighbor atomic distance
        **nb_classes (int):
            Number of classes in the model
        **downsampling (int or float):
            Downsampling factor (equal to :math:`2^n` where *n* is a number
            of pooling operations)
        **verbose (bool):
            Print progress and timing information (Default: True)

    Example:

        >>> # Here we load new experimental data (as 2D or 3D numpy array)
        >>> expdata = np.load('expdata-test.npy')
        >>> # Get prediction from a trained model
        >>> pseg = atomnet.SegPredictor(trained_model)
        >>> nn_output, coords = pseg.run(expdata)
    """

    def __init__(self,
                 trained_model: Type[torch.nn.Module],
                 refine: bool = False,
                 resize: Union[Tuple, List] = None,
                 use_gpu: bool = False,
                 logits: bool = True,
                 **kwargs: Union[int, float, bool]) -> None:
        """
        Initializes predictive object
        """
        super(SegPredictor, self).__init__(trained_model, use_gpu)
        set_train_rng(1)
        # Model properties are inferred from the model unless given explicitly
        self.nb_classes = kwargs.get('nb_classes', None)
        if self.nb_classes is None:
            self.nb_classes = get_nb_classes(trained_model)
        self.downsampling = kwargs.get('downsampling', None)
        if self.downsampling is None:
            self.downsampling = get_downsample_factor(trained_model)
        self.resize = resize
        self.logits = logits
        self.refine = refine
        self.d = kwargs.get("d", None)
        self.thresh = kwargs.get("thresh", .5)
        self.use_gpu = use_gpu
        self.verbose = kwargs.get("verbose", True)

    def preprocess(self,
                   image_data: np.ndarray,
                   norm: bool = True) -> torch.Tensor:
        """
        Prepares an input for a neural network

        Args:
            image_data: Single 2D image, 3D image stack, or 4D stack with
                a singleton channel dimension (second or last)
            norm: Passed to ``torch_format_image``
        """
        if image_data.ndim == 2:
            image_data = image_data[np.newaxis, ...]
        elif image_data.ndim == 4:
            # Drop a singleton channel dimension (channel-last or channel-first)
            if image_data.shape[-1] == 1:
                image_data = image_data[..., 0]
            elif image_data.shape[1] == 1:
                image_data = image_data[:, 0, ...]
        if self.resize is not None:
            image_data = img_resize(image_data, self.resize)
        image_data = img_pad(image_data, self.downsampling)
        image_data = torch_format_image(image_data, norm)
        return image_data

    def forward_(self, images: torch.Tensor) -> torch.Tensor:
        """
        Returns 'probability' of each pixel
        in image(s) belonging to an atom/defect

        The result is a CPU tensor with the channel as the last dimension.
        """
        images = images.to(self.device)
        self.model.eval()
        with torch.no_grad():
            prob = self.model(images)
            if self.logits:
                # Model returns raw logits
                if self.nb_classes > 1:
                    prob = F.softmax(prob, dim=1)
                else:
                    prob = torch.sigmoid(prob)
            elif self.nb_classes > 1:
                # Multi-class output is exponentiated; a single-class
                # output is used as is
                prob = torch.exp(prob)
            prob = prob.permute(0, 2, 3, 1)  # reshape to have channel as a last dim
        return prob.cpu()

    def predict(self,
                image_data: np.ndarray,
                return_image: bool = False,
                **kwargs: int
                ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Make prediction

        Args:
            image_data:
                3D image stack or a single 2D image (all greyscale)
            return_image:
                Returns images used as input into NN
            **num_batches (int): number of batches
            **norm (bool): Normalize data to (0, 1) during pre-processing

        Returns:
            Network output (channel-last numpy array); when *return_image*
            is True, a tuple of the pre-processed images (channel-last)
            and the network output.
        """
        image_data = self.preprocess(image_data, kwargs.get("norm", True))
        n, _, h, w = image_data.shape
        num_batches = kwargs.get("num_batches")
        if num_batches is None:
            # Large images are processed one at a time
            if h >= 256 or w >= 256:
                num_batches = len(image_data)
            else:
                num_batches = 10
        segmented_imgs = self.batch_predict(
            image_data, (n, h, w, self.nb_classes), num_batches)
        if return_image:
            image_data = image_data.permute(0, 2, 3, 1).numpy()
            return image_data, segmented_imgs.numpy()
        return segmented_imgs.numpy()

    def run(self,
            image_data: np.ndarray,
            compute_coords=True,
            **kwargs: int
            ) -> Union[np.ndarray, Tuple[np.ndarray, Dict[int, np.ndarray]]]:
        """
        Make prediction with a trained model and calculate coordinates

        Args:
            image_data:
                Image stack or a single image (all greyscale)
            compute_coords:
                Computes centers of the mass of individual blobs
                in the segmented images (Default: True)
            **num_batches (int):
                number of batches for batch-by-batch prediction
                which ensures that one doesn't run out of memory
                (Default: 10)
            **norm (bool): Normalize data to (0, 1) during pre-processing

        Returns:
            Network output only when *compute_coords* is False; otherwise
            a tuple of the network output and a dictionary of coordinates
            keyed by frame number.
        """
        start_time = time.time()
        if not compute_coords:
            decoded_imgs = self.predict(image_data, **kwargs)
            return decoded_imgs
        images, decoded_imgs = self.predict(
            image_data, return_image=True, **kwargs)
        loc = Locator(self.thresh, refine=self.refine, d=self.d)
        # Input images are passed along for the optional peak refinement
        coordinates = loc.run(decoded_imgs, images)
        if self.verbose:
            n_images_str = " image was " if decoded_imgs.shape[0] == 1 else " images were "
            print("\n" + str(decoded_imgs.shape[0])
                  + n_images_str + "decoded in approximately "
                  + str(np.around(time.time() - start_time, decimals=4))
                  + ' seconds')
        return decoded_imgs, coordinates
class ImSpecPredictor(BasePredictor):
    """
    Prediction with a trained im2spec or spec2im model

    Args:
        trained_model:
            Pre-trained neural network
        output_dim:
            Output dimensions. For im2spec, the output_dim is (length,).
            For spec2im, the output_dim is (height, width)
        use_gpu:
            Use GPU acceleration for prediction
        verbose:
            Verbosity

    Example:

        >>> # Predict spectra from images with pretrained im2spec model
        >>> out_dim = (16,)  # spectra length
        >>> prediction = ImSpecPredictor(trained_model, out_dim).run(data)
    """

    def __init__(self,
                 trained_model: Type[torch.nn.Module],
                 output_dim: Tuple[int],
                 use_gpu: bool = False,
                 **kwargs: str) -> None:
        """
        Initialize predictor
        """
        super().__init__(trained_model, use_gpu)
        # A bare integer is shorthand for a one-element (spectrum) shape
        dims = (output_dim,) if isinstance(output_dim, int) else output_dim
        if not 1 <= len(dims) <= 2:
            raise ValueError("output_dim must be a two-value tuple for images" +
                             " and a single-value tuple for spectra")
        set_train_rng(1)
        self.output_dim = dims
        self.verbose = kwargs.get("verbose", True)

    def preprocess(self,
                   signal: np.ndarray,
                   norm: bool = True) -> torch.Tensor:
        """
        Preprocess input signal (images or spectra)

        A 1D output means the inputs are images; a 2D output means the
        inputs are spectra. A single unbatched input gets a leading
        batch dimension before formatting.
        """
        n_out_dims = len(self.output_dim)
        if n_out_dims == 1:
            batch = signal[np.newaxis, ...] if signal.ndim == 2 else signal
            return torch_format_image(batch, norm)
        if n_out_dims == 2:
            batch = signal[np.newaxis, ...] if signal.ndim == 1 else signal
            return torch_format_spectra(batch, norm)
        return signal

    def predict(self,
                signal: np.ndarray,
                **kwargs: int) -> np.ndarray:
        """
        Predict spectra from images or vice versa

        Args:
            signal:
                Input image/spectrum or batch of images/spectra
            **num_batches (int): number of batches (Default: 10)
            **norm (bool): Normalize data to (0, 1) during pre-processing
        """
        formatted = self.preprocess(signal, kwargs.get("norm", True))
        n_batches = kwargs.get("num_batches", 10)
        full_shape = (len(formatted), 1, *self.output_dim)
        result = self.batch_predict(formatted, full_shape, n_batches)
        # Drop the singleton channel dimension
        return result[:, 0].numpy()

    def run(self,
            signal: np.ndarray,
            **kwargs: int) -> np.ndarray:
        """
        Make prediction with a trained model

        Args:
            signal:
                Input image/spectrum or batch of images/spectra
            **num_batches (int): number of batches (Default: 10)
            **norm (bool): Normalize data to (0, 1) during pre-processing
        """
        start_time = time.time()
        prediction = self.predict(signal, **kwargs)
        if not self.verbose:
            return prediction
        n_decoded = prediction.shape[0]
        if len(self.output_dim) == 1:
            singular, plural = " image was ", " images were "
        else:
            singular, plural = " spectrum was ", " spectra were "
        str_ = singular if n_decoded == 1 else plural
        elapsed = np.around(time.time() - start_time, decimals=4)
        print("\n" + str(n_decoded)
              + str_ + "decoded in approximately "
              + str(elapsed)
              + ' seconds')
        return prediction
class RegPredictor(BasePredictor):
    """
    Prediction with a trained regression model

    Args:
        trained_model:
            Pre-trained neural network
        output_dim:
            Output dimensions (e.g., for single-output regression,
            output_dim=1)
        use_gpu:
            Use GPU acceleration for prediction
        verbose:
            Verbosity

    Example:

        >>> # Make predictions with trained regression model
        >>> out_dim = 1  # single-output regression
        >>> prediction = RegPredictor(trained_model, out_dim).run(data)
    """

    def __init__(self,
                 trained_model: Type[torch.nn.Module],
                 output_dim: int,
                 use_gpu: bool = False,
                 **kwargs: str) -> None:
        """
        Initialize predictor
        """
        super(RegPredictor, self).__init__(trained_model, use_gpu)
        set_train_rng(1)
        self.output_dim = output_dim
        self.verbose = kwargs.get("verbose", True)

    def preprocess(self,
                   image_data: np.ndarray,
                   norm: bool = True) -> torch.Tensor:
        """
        Preprocess input image(s)

        A single 2D image gets a leading batch dimension before formatting.
        """
        if image_data.ndim == 2:
            image_data = image_data[np.newaxis, ...]
        image_data = torch_format_image(image_data, norm)
        return image_data

    def predict(self,
                image_data: np.ndarray,
                **kwargs: int) -> np.ndarray:
        """
        Predict target value(s) from image(s)

        Args:
            image_data:
                Input image or batch of images
            **num_batches (int): number of batches (Default: 10)
            **norm (bool): Normalize data to (0, 1) during pre-processing

        Returns:
            Predictions with all singleton dimensions squeezed out
        """
        num_batches = kwargs.get("num_batches", 10)
        image_data = self.preprocess(image_data, kwargs.get("norm", True))
        output = self.batch_predict(
            image_data, (len(image_data), self.output_dim), num_batches)
        return output.squeeze().numpy()

    def run(self,
            image_data: np.ndarray,
            **kwargs: int) -> np.ndarray:
        """
        Make prediction with a trained regression model

        Args:
            image_data:
                Input image or batch of images
            **num_batches (int): number of batches (Default: 10)
            **norm (bool): Normalize data to (0, 1) during pre-processing
        """
        start_time = time.time()
        prediction = self.predict(image_data, **kwargs)
        if self.verbose:
            # Count images from the input: the prediction is squeezed, so its
            # first dimension is not the number of images when a single image
            # yields several outputs.
            n_images = 1 if image_data.ndim == 2 else len(image_data)
            n_images_str = " image was " if n_images == 1 else " images were "
            print("\n" + str(n_images)
                  + n_images_str + "decoded in approximately "
                  + str(np.around(time.time() - start_time, decimals=4))
                  + ' seconds')
        return prediction


class clsPredictor(RegPredictor):
    """
    Prediction with a trained classifier

    Args:
        trained_model:
            Pre-trained neural network
        nb_classes:
            number of classes in a classification scheme
        use_gpu:
            Use GPU acceleration for prediction
        verbose:
            Verbosity

    Example:

        >>> # Make predictions with trained classifier
        >>> nb_classes = 10
        >>> prediction = clsPredictor(trained_model, nb_classes).run(data)
    """

    def __init__(self,
                 trained_model: Type[torch.nn.Module],
                 nb_classes: int,
                 use_gpu: bool = False,
                 **kwargs: str) -> None:
        """
        Initialize predictor
        """
        # The number of classes plays the role of the regression output_dim
        super(clsPredictor, self).__init__(
            trained_model, nb_classes, use_gpu, **kwargs)

    def predict(self,
                image_data: np.ndarray,
                **kwargs: int) -> np.ndarray:
        """
        Categorizes an input image or a batch of input images

        Args:
            image_data (numpy array):
                Input image or batch of images
            **num_batches (int): number of batches (Default: 10)
            **norm (bool): Normalize data to (0, 1) during pre-processing

        Returns:
            Index of the highest-scoring class for each image
            (singleton dimensions squeezed out)
        """
        num_batches = kwargs.get("num_batches", 10)
        image_data = self.preprocess(image_data, kwargs.get("norm", True))
        output = self.batch_predict(
            image_data, (len(image_data), self.output_dim), num_batches)
        output = torch.argmax(output, 1)
        return output.squeeze().numpy()


class Locator:
    """
    Transforms pixel data from NN output into coordinate data

    Args:
        threshold:
            Value at which the neural network output is thresholded
        dist_edge:
            Distance within image boundaries not to consider
        dim_order:
            'channel_last' or 'channel_first' (Default: 'channel last')
        **refine (bool):
            Refine the extracted positions with ``peak_refinement``
        **d:
            Passed to ``peak_refinement`` when refinement is enabled

    Example:

        >>> # Transform output of atomnet.predictor to atomic classes and coordinates
        >>> coordinates = Locator(dist_edge=10, refine=False).run(nn_output)
    """

    def __init__(self,
                 threshold: float = 0.5,
                 dist_edge: int = 5,
                 dim_order: str = 'channel_last',
                 **kwargs: Union[bool, float]) -> None:
        """
        Initialize locator parameters
        """
        self.dim_order = dim_order
        self.threshold = threshold
        self.dist_edge = dist_edge
        self.refine = kwargs.get("refine")
        self.d = kwargs.get("d")

    def preprocess(self, nn_output: np.ndarray) -> np.ndarray:
        """
        Prepares data for coordinates extraction

        Brings the data to channel-last layout and, for single-channel
        data, appends a background channel equal to ``1 - nn_output``.

        Raises:
            NotImplementedError: if ``dim_order`` is not 'channel_first'
                or 'channel_last'
        """
        if self.dim_order == 'channel_first':  # make channel dim the last dim
            nn_output = np.transpose(nn_output, (0, 2, 3, 1))
        elif self.dim_order != 'channel_last':
            raise NotImplementedError(
                'For dim_order, use "channel_first" '
                'or "channel_last" (e.g. tensorflow)')
        # The channel count must be checked after the transpose; otherwise
        # the last axis of channel-first data is a spatial one.
        if nn_output.shape[-1] == 1:  # Add background class for 1-channel data
            nn_output = np.concatenate((nn_output, 1 - nn_output), axis=-1)
        return nn_output

    def run(self, nn_output: np.ndarray, *args: np.ndarray) -> Dict[int, np.ndarray]:
        """
        Extract all atomic coordinates in image
        via CoM method & store data as a dictionary
        (key: frame number)

        Args:
            nn_output (4D numpy array):
                Output (prediction) of a neural network
            *args: 4D input into a neural network (experimental data),
                required when refinement is enabled

        Returns:
            Dictionary mapping the frame number to an array whose rows hold
            two coordinates followed by the class (channel) index

        Raises:
            AssertionError: if refinement is enabled and no input images
                were passed
        """
        nn_output = self.preprocess(nn_output)
        height, width = nn_output.shape[1:3]
        d_coord = {}
        for i, decoded_img in enumerate(nn_output):
            coordinates = [np.empty((0, 2))]
            category = [np.empty((0, 1))]
            # we assume that class 'background' is always the last one
            for ch in range(decoded_img.shape[2]-1):
                decoded_img_c = cv_thresh(
                    decoded_img[:, :, ch], self.threshold)
                coord = find_com(decoded_img_c)
                coord_ch = self.rem_edge_coord(coord, height, width)
                coordinates.append(coord_ch)
                category.append(np.zeros((coord_ch.shape[0], 1)) + ch)
            d_coord[i] = np.concatenate(
                (np.concatenate(coordinates, axis=0),
                 np.concatenate(category, axis=0)), axis=1)
        if not self.refine:
            return d_coord
        if len(args) == 0:
            raise AssertionError("Pass input image(s) for coordinates refinement")
        imgdata = args[0]
        print('\n\rRefining atomic positions... ', end="")
        d_coord_r = {}
        for i, (img, coord) in enumerate(zip(imgdata, d_coord.values())):
            d_coord_r[i] = peak_refinement(img[..., 0], coord, self.d)
        print("Done")
        return d_coord_r

    def rem_edge_coord(self, coordinates: np.ndarray, h: int, w: int) -> np.ndarray:
        """
        Removes coordinates at the image edges

        A coordinate is dropped when either component lies closer than
        ``dist_edge`` to the corresponding image boundary.
        """
        coordinates = np.asarray(coordinates)
        if coordinates.size == 0:
            return coordinates
        first, second = coordinates[:, 0], coordinates[:, 1]
        near_edge = ((first > h - self.dist_edge) | (first < self.dist_edge)
                     | (second > w - self.dist_edge) | (second < self.dist_edge))
        return coordinates[~near_edge]