Source code for atomai.transforms.imaug
"""
imaug.py
========
Module for image transformations relevant to data augmentation
Created by Maxim Ziatdinov (maxim.ziatdinov@ai4microscopy.com)
"""
from typing import Optional, Callable, Union, List, Tuple
import numpy as np
import torch
import cv2
from scipy import stats, ndimage
from skimage import exposure
from skimage.util import random_noise
[docs]class datatransform:
"""
Applies a sequence of pre-defined operations for data augmentation.
Args:
n_channels (int):
Number of classes (channels) in the ground truth
dim_order_in (str):
Channel first or channel last ordering in the input masks
dim_order_out (str):
Channel first or channel last ordering in the output masks
seed (int):
Determenism
**custom_transform (Callable):
Python function that takes two ndarrays (images and masks) as
input, applies a set of transformation to them, and returns the two
transformed arrays
**rotation (bool):
Rotating image by +- 90 deg (if image is square)
and horizontal/vertical flipping.
**zoom (bool or int):
Zooming-in by a specified zoom factor (Default: 2)
Note that a zoom window is always square
**gauss_noise (bool or list ot tuple):
Gaussian noise. You can pass min and max values as a list/tuple
(Default [min, max] range: [0, 50])
**poisson_noise (bool or list ot tuple):
Poisson noise. You can pass min and max values as a list/tuple
(Default [min, max] range: [30, 40])
**salt_and_pepper (bool or list ot tuple):
Salt and pepper noise. You can pass min and max values as a list/tuple
(Default [min, max] range: [0, 50])
**blur (bool or list ot tuple):
Gaussian blurring. You can pass min and max values as a list/tuple
(Default [min, max] range: [1, 50])
**contrast (bool or list ot tuple):
Contrast level. You can pass min and max values as a list/tuple
(Default [min, max] range: [5, 20])
**background (bool):
Adds/substracts asymmetric 2D gaussian of random width and intensity
from the image
**resize (tuple):
Values for image resizing
[downscale factor (default: 2), upscale factor (default:1.5)]
"""
def __init__(self,
n_channels: int = None,
dim_order_in: str = 'channel_last',
dim_order_out: str = 'channel_first',
squeeze_channels: bool = False,
seed: Optional[int] = None,
**kwargs: Union[bool, Callable, List, Tuple]) -> None:
"""
Initializes image transformation parameters
"""
self.ch = n_channels
self.dim_order_in = dim_order_in
self.dim_order_out = dim_order_out
self.squeeze = squeeze_channels
self.custom_transform = kwargs.get('custom_transform')
self.rotation = kwargs.get('rotation')
self.background = kwargs.get('background')
self.gauss = kwargs.get('gauss_noise')
if self.gauss is True:
self.gauss = [0, 50]
self.jitter = kwargs.get('jitter')
if self.jitter is True:
self.jitter = [0, 50]
self.poisson = kwargs.get('poisson_noise')
if self.poisson is True:
self.poisson = [30, 40]
self.salt_and_pepper = kwargs.get('salt_and_pepper')
if self.salt_and_pepper is True:
self.salt_and_pepper = [0, 50]
self.blur = kwargs.get('blur')
if self.blur is True:
self.blur = [1, 50]
self.contrast = kwargs.get('contrast')
if self.contrast is True:
self.contrast = [5, 20]
self.zoom = kwargs.get('zoom')
if self.zoom is True:
self.zoom = 2
self.resize = kwargs.get('resize')
if self.resize is True:
self.resize = [2, 1.5]
if seed is not None:
np.random.seed(seed)
[docs] def apply_gauss(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Random application of gaussian noise to each training inage in a stack
"""
n, h, w = X_batch.shape[0:3]
X_batch_noisy = np.zeros((n, h, w))
for i, img in enumerate(X_batch):
gauss_var = np.random.randint(self.gauss[0], self.gauss[1])
X_batch_noisy[i] = random_noise(
img, mode='gaussian', var=1e-4*gauss_var)
return X_batch_noisy, y_batch
[docs] def apply_jitter(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Random application of jitter noise to each training image in a stack
"""
n, h, w = X_batch.shape[0:3]
X_batch_noisy = np.zeros((n, h, w))
for i, img in enumerate(X_batch):
jitter_amount = np.random.randint(self.jitter[0], self.jitter[1]) / 10
shift_arr = stats.poisson.rvs(jitter_amount, loc=0, size=h)
X_batch_noisy[i] = np.array([np.roll(row, z) for row, z in zip(img, shift_arr)])
return X_batch_noisy, y_batch
[docs] def apply_poisson(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Random application of poisson noise to each training inage in a stack
"""
def make_pnoise(image, l):
vals = len(np.unique(image))
vals = (50/l) ** np.ceil(np.log2(vals))
image_n_filt = np.random.poisson(image * vals) / float(vals)
return image_n_filt
n, h, w = X_batch.shape[0:3]
X_batch_noisy = np.zeros((n, h, w))
for i, img in enumerate(X_batch):
poisson_l = np.random.randint(self.poisson[0], self.poisson[1])
X_batch_noisy[i] = make_pnoise(img, poisson_l)
return X_batch_noisy, y_batch
[docs] def apply_sp(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Random application of salt & pepper noise to each training inage in a stack
"""
n, h, w = X_batch.shape[0:3]
X_batch_noisy = np.zeros((n, h, w))
for i, img in enumerate(X_batch):
sp_amount = np.random.randint(
self.salt_and_pepper[0], self.salt_and_pepper[1])
X_batch_noisy[i] = random_noise(img, mode='s&p', amount=sp_amount*1e-3)
return X_batch_noisy, y_batch
[docs] def apply_blur(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Random blurring of each training image in a stack
"""
n, h, w = X_batch.shape[0:3]
X_batch_noisy = np.zeros((n, h, w))
for i, img in enumerate(X_batch):
blur_amount = np.random.randint(self.blur[0], self.blur[1])
X_batch_noisy[i] = ndimage.filters.gaussian_filter(img, blur_amount*5e-2)
return X_batch_noisy, y_batch
[docs] def apply_contrast(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Randomly change level of contrast of each training image on a stack
"""
n, h, w = X_batch.shape[0:3]
X_batch_noisy = np.zeros((n, h, w))
for i, img in enumerate(X_batch):
clevel = np.random.randint(self.contrast[0], self.contrast[1])
X_batch_noisy[i] = exposure.adjust_gamma(img, clevel/10)
return X_batch_noisy, y_batch
[docs] def apply_zoom(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Zoom-in achieved by cropping image and then resizing
to the original size. The zooming window is a square.
"""
n, h, w = X_batch.shape[0:3]
shortdim = min([w, h])
zoom_values = np.arange(int(shortdim // self.zoom), shortdim + 8, 8)
zoom_values = zoom_values[zoom_values <= shortdim]
X_batch_z = np.zeros((n, shortdim, shortdim))
y_batch_z = np.zeros((n, shortdim, shortdim, self.ch))
for i, (img, gt) in enumerate(zip(X_batch, y_batch)):
zv = np.random.choice(zoom_values)
img = img[
(h // 2) - (zv // 2): (h // 2) + (zv // 2),
(w // 2) - (zv // 2): (w // 2) + (zv // 2)]
gt = gt[
(h // 2) - (zv // 2): (h // 2) + (zv // 2),
(w // 2) - (zv // 2): (w // 2) + (zv // 2)]
img = cv2.resize(
img, (shortdim, shortdim), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(
gt, (shortdim, shortdim), interpolation=cv2.INTER_CUBIC)
img = np.clip(img, 0, 1)
gt = np.around(gt)
if len(gt.shape) != 3:
gt = np.expand_dims(gt, axis=2)
X_batch_z[i] = img
y_batch_z[i] = gt
return X_batch_z, y_batch_z
[docs] def apply_background(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Emulates thickness variation in STEM or height variation in STM
"""
def gauss2d(xy, x0, y0, a, b, fwhm):
return np.exp(-np.log(2)*(a*(xy[0]-x0)**2 + b*(xy[1]-y0)**2) / fwhm**2)
n, h, w = X_batch.shape[0:3]
X_batch_b = np.zeros((n, h, w))
x, y = np.meshgrid(
np.linspace(0, h, h), np.linspace(0, w, w), indexing='ij')
for i, img in enumerate(X_batch):
x0 = np.random.randint(0, h - h // 4)
y0 = np.random.randint(0, w - w // 4)
a, b = np.random.randint(10, 20, 2) / 10
fwhm = np.random.randint(min([h, w]) // 4, min([h, w]) - min([h, w]) // 2)
Z = gauss2d([x, y], x0, y0, a, b, fwhm)
img = img + 0.05 * np.random.randint(-10, 10) * Z
X_batch_b[i] = img
return X_batch_b, y_batch
[docs] def apply_rotation(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Flips and rotates training images and correponding ground truth images
"""
n, h, w = X_batch.shape[0:3]
X_batch_r = np.zeros((n, h, w))
y_batch_r = np.zeros((n, h, w, self.ch))
for i, (img, gt) in enumerate(zip(X_batch, y_batch)):
flip_type = np.random.randint(-1, 3)
if flip_type == 3 and h == w:
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
elif flip_type == 2 and h == w:
img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
img = cv2.flip(img, flip_type)
gt = cv2.flip(gt, flip_type)
if len(gt.shape) != 3:
gt = np.expand_dims(gt, axis=2)
X_batch_r[i] = img
y_batch_r[i] = gt
return X_batch_r, y_batch_r
[docs] def apply_imresize(self,
X_batch: np.ndarray,
y_batch: np.ndarray) -> Tuple[np.ndarray]:
"""
Resizes training images and corresponding ground truth images
"""
rs_factor_d = 1 / self.resize[0]
rs_factor_u = self.resize[1]
n, h, w = X_batch.shape[0:3]
s, p = 0.03, 8
while (np.round((h * s), 7) % p != 0
and np.round((w * s), 7) % p != 0):
s += 1e-5
rs_h = (np.arange(rs_factor_d, rs_factor_u, s) * h).astype(np.int64)
rs_w = (np.arange(rs_factor_d, rs_factor_u, s) * w).astype(np.int64)
rs_idx = np.random.randint(len(rs_h))
if X_batch.shape[1:3] == (rs_h[rs_idx], rs_w[rs_idx]):
return X_batch, y_batch
X_batch_r = np.zeros((n, rs_h[rs_idx], rs_w[rs_idx]))
y_batch_r = np.zeros((n, rs_h[rs_idx], rs_w[rs_idx], self.ch))
for i, (img, gt) in enumerate(zip(X_batch, y_batch)):
rs_method = cv2.INTER_AREA if rs_h[rs_idx] < h else cv2.INTER_CUBIC
img = cv2.resize(img, (rs_w[rs_idx], rs_h[rs_idx]), rs_method)
gt = cv2.resize(gt, (rs_w[rs_idx], rs_h[rs_idx]), rs_method)
gt = np.around(gt)
if len(gt.shape) < 3:
gt = np.expand_dims(gt, axis=-1)
X_batch_r[i] = img
y_batch_r[i] = gt
return X_batch_r, y_batch_r
[docs] def run(self, images: np.ndarray, targets: np.ndarray) -> Tuple[np.ndarray]:
"""
Applies a sequence of augmentation procedures to images
and (except for noise) targets. Starts with user defined
custom_transform if available. Then proceeds with
rotation->zoom->resize->gauss->jitter->poisson->sp->blur->contrast->background.
The operations that are not specified in kwargs are skipped.
"""
same_dim = images.ndim + 1 == targets.ndim == 4 and self.ch is not None
if self.dim_order_in == 'channel_first' and same_dim:
targets = np.transpose(targets, [0, 2, 3, 1])
elif self.dim_order_in == 'channel_last':
pass
else:
raise NotImplementedError("Use 'channel_first' or 'channel_last'")
images = (images - images.min()) / images.ptp()
if self.custom_transform is not None:
images, targets = self.custom_transform(images, targets)
if self.rotation and same_dim:
images, targets = self.apply_rotation(images, targets)
if self.zoom and same_dim:
images, targets = self.apply_zoom(images, targets)
if isinstance(self.resize, list) or isinstance(self.resize, tuple):
if same_dim:
images, targets = self.apply_imresize(images, targets)
if isinstance(self.gauss, list) or isinstance(self.gauss, tuple):
images, targets = self.apply_gauss(images, targets)
if isinstance(self.jitter, list) or isinstance(self.jitter, tuple):
images, targets = self.apply_jitter(images, targets)
if isinstance(self.poisson, list) or isinstance(self.poisson, tuple):
images, targets = self.apply_poisson(images, targets)
if isinstance(self.salt_and_pepper, list) or isinstance(self.salt_and_pepper, tuple):
images, targets = self.apply_sp(images, targets)
if isinstance(self.blur, list) or isinstance(self.blur, tuple):
images, targets = self.apply_blur(images, targets)
if isinstance(self.contrast, list) or isinstance(self.contrast, tuple):
images, targets = self.apply_contrast(images, targets)
if self.background:
images, targets = self.apply_background(images, targets)
if self.squeeze and same_dim:
images, targets = squeeze_channels(images, targets)
if self.dim_order_out == 'channel_first':
images = np.expand_dims(images, axis=1)
if same_dim:
if self.squeeze is None or self.ch == 1:
targets = np.transpose(targets, (0, 3, 1, 2))
elif self.dim_order_out == 'channel_last':
images = np.expand_dims(images, axis=3)
else:
raise NotImplementedError("Use 'channel_first' or 'channel_last'")
images = (images - images.min()) / images.ptp()
return images, targets
def squeeze_channels(images: np.ndarray,
labels: np.ndarray,
clip: bool = False) -> Tuple[np.ndarray]:
"""
Squeezes channels in each training image and
filters out image-label pairs where some pixels have multiple values.
As a result the number of image-label-pairs returned may be different
from the number of image-label pairs in the original data.
"""
def squeeze_channels_(label):
"""
Squeezes multiple channel into a single channel for a single label
"""
label_ = np.zeros((1, label.shape[0], label.shape[1]))
for c in range(label.shape[-1]):
label_ += label[:, :, c] * c
return label_
if labels.shape[-1] == 1:
return images, labels
images_valid, labels_valid = [], []
for label, image in zip(labels, images):
label = squeeze_channels_(label)
if clip:
label[label > labels.shape[-1] - 1] = 0
labels_valid.append(label)
images_valid.append(image[None, ...])
else:
if len(np.unique(label)) == labels.shape[-1]:
labels_valid.append(label)
images_valid.append(image[None, ...])
return np.concatenate(images_valid), np.concatenate(labels_valid)
def unsqueeze_channels(labels: np.ndarray, n_channels: int) -> np.ndarray:
"""
Separates pixels with different values into different channels
"""
if n_channels == 1:
return labels
labels_ = np.eye(n_channels)[labels.astype(int)]
return np.transpose(labels_, [0, 3, 1, 2])
def seg_augmentor(nb_classes: int,
**kwargs
) -> Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
auglist = ["custom_transform", "zoom", "gauss_noise", "jitter",
"poisson_noise", "contrast", "salt_and_pepper", "blur",
"resize", "rotation", "background"]
augdict = {k: kwargs[k] for k in auglist if k in kwargs.keys()}
if len(augdict) == 0:
return
def augmentor(images, labels, seed):
images = images.cpu().numpy().astype(np.float64)
labels = labels.cpu().numpy().astype(np.float64)
dt = datatransform(
nb_classes, "channel_first", 'channel_first',
True, seed, **augdict)
images, labels = dt.run(
images[:, 0, ...], unsqueeze_channels(labels, nb_classes))
images = torch.from_numpy(images).float()
if nb_classes == 1:
labels = torch.from_numpy(labels).float()
else:
labels = torch.from_numpy(labels).long()
return images, labels
return augmentor
def imspec_augmentor(in_dim: Tuple[int],
out_dim: Tuple[int],
**kwargs
) -> Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
auglist = ["custom_transform", "gauss_noise", "jitter",
"poisson_noise", "contrast", "salt_and_pepper", "blur",
"background"]
augdict = {k: kwargs[k] for k in auglist if k in kwargs.keys()}
if len(augdict) == 0:
return
if len(in_dim) < len(out_dim):
raise NotImplementedError("The built-in data augmentor works only" +
" for img->spec models (i.e. input is image)")
def augmentor(features, targets, seed):
features = features.cpu().numpy().astype(np.float64)
targets = targets.cpu().numpy().astype(np.float64)
dt = datatransform(seed, **augdict)
features, targets = dt.run(features[:, 0, ...], targets)
features = torch.from_numpy(features).float()
targets = torch.from_numpy(targets).float()
return features, targets
return augmentor
def reg_augmentor(**kwargs) -> Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
auglist = ["custom_transform", "gauss_noise", "jitter",
"poisson_noise", "contrast", "salt_and_pepper", "blur",
"background"]
augdict = {k: kwargs[k] for k in auglist if k in kwargs.keys()}
if len(augdict) == 0:
return
def augmentor(features, targets, seed):
features = features.cpu().numpy().astype(np.float64)
targets = targets.cpu().numpy().astype(np.float64)
dt = datatransform(seed, **augdict)
features, targets = dt.run(features[:, 0, ...], targets)
features = torch.from_numpy(features).float()
targets = torch.from_numpy(targets).float()
return features, targets
return augmentor