"""
trainer.py
==========
Module for training fully convolutional neural networks
for atom/defect/particle finding and encoder-decoder neural networks
for prediction of spectra/images from images/spectra. It can also be
used for training custom PyTorch neural networks
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""
import copy
import warnings
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from atomai import losses_metrics
from atomai.nets import (init_cls_model, init_fcnn_model, init_imspec_model,
init_reg_model)
from atomai.utils import (array2list, average_weights, gpu_usage_map,
init_cls_dataloaders, init_dataloaders,
init_fcnn_dataloaders, init_imspec_dataloaders,
init_reg_dataloaders, plot_losses,
preprocess_training_cls_data,
preprocess_training_image_data,
preprocess_training_imspec_data,
preprocess_training_reg_data, reset_bnorm,
set_train_rng, weights_init)
warnings.filterwarnings("ignore", module="torch.nn.functional")

# Signature of user-supplied augmentation functions:
# (features, targets, seed) -> (augmented features, augmented targets)
augfn_type = Callable[[torch.Tensor, torch.Tensor, int],
                      Tuple[torch.Tensor, torch.Tensor]]


class BaseTrainer:
    """
    Base trainer class for training semantic segmentation
    and image-to-spectrum/spectrum-to-image deep learning models
    as well as custom PyTorch neural networks

    Example:

    >>> # Load 4 numpy arrays with training and test data
    >>> dataset = np.load('training_data.npz')
    >>> images = dataset['X_train']
    >>> labels = dataset['y_train']
    >>> images_test = dataset['X_test']
    >>> labels_test = dataset['y_test']
    >>> # Initialize a trainer
    >>> t = BaseTrainer()
    >>> # Set a model
    >>> t.set_model(atomai.nets.Unet(), nb_classes=1)
    >>> # Compile trainer
    >>> t.compile_trainer(
    >>>     (images, labels, images_test, labels_test),
    >>>     loss="ce", full_epoch=True, training_cycles=25, swa=True)
    >>> # Train and save model's weights
    >>> t.fit()
    >>> t.save_model("my_model")
    """
    def __init__(self):
        """
        Initializes trainer's parameters with default values
        """
        set_train_rng(1)
        # Select compute device: Apple-silicon GPU > CUDA GPU > CPU
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            self.device = torch.device('mps')  # backend for Apple silicon GPUs
        elif torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.net = None
        self.criterion = None
        self.optimizer = None
        self.compute_accuracy = False
        self.full_epoch = True
        self.swa = False
        self.perturb_weights = False
        self.running_weights = {}  # recent state dicts kept for SWA
        self.training_cycles = 0
        self.batch_idx_train, self.batch_idx_test = [], []
        self.batch_size = 1
        self.nb_classes = None
        self.X_train, self.y_train = None, None
        self.X_test, self.y_test = None, None
        self.train_loader = torch.utils.data.TensorDataset()
        self.test_loader = torch.utils.data.TensorDataset()
        self.data_is_set = False
        self.augdict = {}
        self.augment_fn = None
        self.filename = "model"
        self.print_loss = 1
        self.meta_state_dict = dict()
        self.loss_acc = {"train_loss": [], "test_loss": [],
                         "train_accuracy": [], "test_accuracy": []}

    def _reset_rng(self, seed: int) -> None:
        """
        (re)sets seeds for pytorch and numpy random number generators
        """
        set_train_rng(seed)

    def _reset_weights(self) -> None:
        """
        Resets weights of convolutional and linear NN layers
        using Xavier initialization
        """
        self.net.apply(weights_init)
        self.net.apply(reset_bnorm)

    def _reset_training_history(self) -> None:
        """
        Empties training/test losses and accuracies
        (can be useful for ensemble training)
        """
        self.loss_acc = {"train_loss": [], "test_loss": [],
                         "train_accuracy": [], "test_accuracy": []}

    def _delete_optimizer(self) -> None:
        """
        Sets optimizer to None.
        """
        self.optimizer = None

    def set_data(self,
                 X_train: Union[torch.Tensor, np.ndarray],
                 y_train: Union[torch.Tensor, np.ndarray],
                 X_test: Union[torch.Tensor, np.ndarray],
                 y_test: Union[torch.Tensor, np.ndarray],
                 **kwargs: float) -> None:
        """
        Sets training and test data by initializing PyTorch dataloaders
        or creating a list of PyTorch tensors from which it will randomly
        choose an element at each training iteration.

        Args:
            X_train: Training data
            y_train: Training data labels/ground-truth
            X_test: Test data
            y_test: Test data labels/ground-truth
            memory_alloc: threshold (in GB) for holding all training data on GPU
        """
        memory_alloc = kwargs.get("memory_alloc", 4)
        tor = lambda x: torch.from_numpy(x) if isinstance(x, np.ndarray) else x
        X_train, y_train = tor(X_train), tor(y_train)
        X_test, y_test = tor(X_test), tor(y_test)
        if self.full_epoch:
            self.train_loader, self.test_loader = init_dataloaders(
                X_train, y_train, X_test, y_test,
                self.batch_size, memory_alloc)
        else:
            (self.X_train, self.y_train,
             self.X_test, self.y_test) = array2list(
                 X_train, y_train, X_test, y_test,
                 self.batch_size, memory_alloc)
        self.data_is_set = True

    def set_model(self,
                  model: Type[torch.nn.Module],
                  nb_classes: int = None) -> None:
        """
        Sets a neural network model and a number of classes (if any)

        Args:
            model: initialized PyTorch model
            nb_classes: number of classes in classification scheme (if any)
        """
        self.net = model
        self.net.to(self.device)
        if self.nb_classes is None and nb_classes:
            self.nb_classes = nb_classes

    def get_loss_fn(self,
                    loss: Union[str, Callable] = 'mse',
                    nb_classes: int = None) -> None:
        """
        Returns a loss function. Available loss functions are: 'mse' (MSE),
        'ce' (cross-entropy), 'focal' (focal loss; single class only),
        and 'dice' (dice loss; for semantic segmentation problems)
        """
        return losses_metrics.select_loss(loss, nb_classes)

    def train_step(self,
                   feat: torch.Tensor,
                   tar: torch.Tensor) -> Tuple[float]:
        """
        Propagates image(s) through a network to get model's prediction
        and compares predicted value with ground truth; then performs
        backpropagation to compute gradients and optimizes weights.

        Args:
            feat: input features
            tar: targets
        """
        self.net.train()
        self.optimizer.zero_grad()
        feat, tar = feat.to(self.device), tar.to(self.device)
        prob = self.net(feat)
        loss = self.criterion(prob, tar)
        loss.backward()
        self.optimizer.step()
        if self.compute_accuracy:
            acc_score = self.accuracy_fn(tar, prob)
            return (loss.item(), acc_score)
        return (loss.item(),)

    def test_step(self,
                  feat: torch.Tensor,
                  tar: torch.Tensor) -> Tuple[float]:
        """
        Forward pass for test data with deactivated autograd engine

        Args:
            feat: input features
            tar: targets
        """
        feat, tar = feat.to(self.device), tar.to(self.device)
        self.net.eval()
        with torch.no_grad():
            prob = self.net(feat)
            loss = self.criterion(prob, tar)
        if self.compute_accuracy:
            acc_score = self.accuracy_fn(tar, prob)
            return (loss.item(), acc_score)
        return (loss.item(),)

    def step(self, e: int) -> None:
        """
        Single train-test step which passes a single
        mini-batch (for both training and testing), i.e.
        1 "epoch" = 1 mini-batch
        """
        features, targets = self.dataloader(
            self.batch_idx_train[e], mode='train')
        # Training step
        loss = self.train_step(features, targets)
        self.loss_acc["train_loss"].append(loss[0])
        features_, targets_ = self.dataloader(
            self.batch_idx_test[e], mode='test')
        # Test step
        loss_ = self.test_step(features_, targets_)
        self.loss_acc["test_loss"].append(loss_[0])
        if self.compute_accuracy:
            self.loss_acc["train_accuracy"].append(loss[1])
            self.loss_acc["test_accuracy"].append(loss_[1])

    def step_full(self) -> None:
        """
        A standard PyTorch training loop where
        all available mini-batches are passed at
        a single step/epoch
        """
        c, c_test = 0, 0
        losses, losses_test = 0, 0
        acc, acc_test = 0, 0
        # Training step (all training mini-batches)
        for features, targets in self.train_loader:
            if self.augment_fn is not None:
                features, targets = self.augment_fn(
                    features, targets, seed=c)
            loss = self.train_step(features, targets)
            losses += loss[0]
            if self.compute_accuracy:
                acc += loss[1]
            c += 1
        # Test step (all test mini-batches)
        for features_, targets_ in self.test_loader:
            if self.augment_fn is not None:
                features_, targets_ = self.augment_fn(
                    features_, targets_, seed=c_test)
            loss_ = self.test_step(features_, targets_)
            losses_test += loss_[0]
            if self.compute_accuracy:
                acc_test += loss_[1]
            c_test += 1
        self.loss_acc["train_loss"].append(losses / c)
        self.loss_acc["test_loss"].append(losses_test / c_test)
        if self.compute_accuracy:
            self.loss_acc["train_accuracy"].append(acc / c)
            self.loss_acc["test_accuracy"].append(acc_test / c_test)

    def eval_model(self) -> None:
        """
        Evaluates model on the entire test dataset and prints
        the averaged loss (and, optionally, accuracy score)
        """
        self.net.eval()
        running_loss_test, running_acc_test, c = 0, 0, 0
        if self.full_epoch:
            for features_, targets_ in self.test_loader:
                if self.augment_fn is not None:
                    # Bugfix: augmented tensors were previously assigned to
                    # different names and silently discarded
                    features_, targets_ = self.augment_fn(
                        features_, targets_, seed=c)
                loss_ = self.test_step(features_, targets_)
                running_loss_test += loss_[0]
                if self.compute_accuracy:
                    running_acc_test += loss_[1]
                c += 1
        else:
            for idx in range(len(self.X_test)):
                features_, targets_ = self.dataloader(idx, mode='test')
                loss_ = self.test_step(features_, targets_)
                running_loss_test += loss_[0]
                if self.compute_accuracy:
                    running_acc_test += loss_[1]
                c += 1
        print('Model (final state) evaluation loss:',
              np.around(running_loss_test / c, 4))
        if self.compute_accuracy:
            print('Model (final state) accuracy:',
                  np.around(running_acc_test / c, 4))

    def dataloader(self,
                   batch_num: int,
                   mode: str = 'train') -> Tuple[torch.Tensor]:
        """
        Generates input training data with images/spectra
        and the associated labels (spectra/images)
        """
        if mode == 'test':
            features = self.X_test[batch_num][:self.batch_size]
            targets = self.y_test[batch_num][:self.batch_size]
        else:
            features = self.X_train[batch_num][:self.batch_size]
            targets = self.y_train[batch_num][:self.batch_size]
        if self.augment_fn is not None:
            # Use the current iteration number as the augmentation seed
            features, targets = self.augment_fn(
                features, targets, seed=len(self.loss_acc["train_loss"]))
        return features, targets

    def save_model(self, *args: str) -> None:
        """
        Saves trained weights, optimizer and key information about model's
        architecture (the latter works only for built-in AtomAI models)
        """
        try:
            filename = args[0]
        except IndexError:
            filename = self.filename
        self.meta_state_dict["weights"] = self.meta_state_dict.get(
            "weights", self.net.state_dict())
        self.meta_state_dict["optimizer"] = self.meta_state_dict.get(
            "optimizer", self.optimizer)
        torch.save(self.meta_state_dict,
                   filename + '.tar')

    def print_statistics(self, e: int, **kwargs) -> None:
        """
        Print loss and (optionally) accuracy score on train
        and test data, as well as GPU memory usage.
        """
        accuracy_metrics = self.accuracy_metrics
        if accuracy_metrics is None:
            accuracy_metrics = "Accuracy"
        if torch.cuda.is_available():
            gpu_usage = gpu_usage_map(torch.cuda.current_device())
        else:
            # Memory usage reporting is not available for MPS/CPU backends
            gpu_usage = ['N/A ', ' N/A']
        if self.compute_accuracy:
            print('Epoch {}/{} ...'.format(e+1, self.training_cycles),
                  'Training loss: {} ...'.format(
                      np.around(self.loss_acc["train_loss"][-1], 4)),
                  'Test loss: {} ...'.format(
                      np.around(self.loss_acc["test_loss"][-1], 4)),
                  'Train {}: {} ...'.format(
                      accuracy_metrics,
                      np.around(self.loss_acc["train_accuracy"][-1], 4)),
                  'Test {}: {} ...'.format(
                      accuracy_metrics,
                      np.around(self.loss_acc["test_accuracy"][-1], 4)),
                  'GPU memory usage: {}/{}'.format(
                      gpu_usage[0], gpu_usage[1]))
        else:
            print('Epoch {}/{} ...'.format(e+1, self.training_cycles),
                  'Training loss: {} ...'.format(
                      np.around(self.loss_acc["train_loss"][-1], 4)),
                  'Test loss: {} ...'.format(
                      np.around(self.loss_acc["test_loss"][-1], 4)),
                  'GPU memory usage: {}/{}'.format(
                      gpu_usage[0], gpu_usage[1]))

    def accuracy_fn(self, *args) -> None:
        """
        Computes accuracy score (implemented by subclasses)
        """
        raise NotImplementedError

    def weight_perturbation(self, e: int) -> None:
        """
        Time-dependent weights perturbation
        (role of time is played by "epoch" number)
        """
        # Read the perturbation parameters only after confirming that
        # perturbation is activated (self.perturb_weights is a dict here)
        if not self.perturb_weights:
            return
        e_p = self.perturb_weights["e_p"]
        if (e + 1) % e_p == 0:
            a = self.perturb_weights["a"]
            gamma = self.perturb_weights["gamma"]
            var = torch.tensor(a / (1 + e)**gamma)
            for k, v in self.net.state_dict().items():
                v_prime = v + v.new(v.shape).normal_(0, torch.sqrt(var))
                self.net.state_dict()[k].copy_(v_prime)
        return

    def save_running_weights(self, e: int) -> None:
        """
        Saves running weights (for stochastic weights averaging)
        """
        # Keep the last 5 (full-epoch mode) or 30 (mini-batch mode) states
        swa_epochs = 5 if self.full_epoch else 30
        if self.training_cycles - e <= swa_epochs:
            i_ = swa_epochs - (self.training_cycles - e)
            state_dict_ = OrderedDict()
            for k, v in self.net.state_dict().items():
                state_dict_[k] = copy.deepcopy(v).cpu()
            self.running_weights[i_] = state_dict_
        return

    def data_augmentation(self,
                          augment_fn: augfn_type) -> None:
        """
        Set up data augmentation. To use it, pass a function that takes
        two torch tensors (features and targets), performs some transforms,
        and returns the transformed tensors. The dimensions of the transformed
        tensors must be the same as the dimensions of the original ones.
        """
        self.augment_fn = augment_fn

    def compile_trainer(self,
                        train_data: Union[Tuple[torch.Tensor], Tuple[np.ndarray]] = None,
                        loss: Union[str, Callable] = 'ce',
                        optimizer: Optional[Type[torch.optim.Optimizer]] = None,
                        training_cycles: int = 1000,
                        batch_size: int = 32,
                        compute_accuracy: bool = False,
                        full_epoch: bool = False,
                        swa: bool = False,
                        perturb_weights: bool = False,
                        **kwargs):
        """
        Compile a trainer

        Args:
            train_data:
                4-element tuple of ndarrays or torch tensors
                (train_data, train_labels, test_data, test_labels)
            loss:
                loss function. Available loss functions are: 'mse' (MSE),
                'ce' (cross-entropy), 'focal' (focal loss; single class only),
                and 'dice' (dice loss; for semantic segmentation problems).
                One can also pass a custom loss function that takes prediction
                and ground truth and computes a loss score.
            optimizer:
                weights optimizer (defaults to Adam optimizer with lr=1e-3)
            training_cycles:
                Number of training 'epochs'.
                If full_epoch argument is set to False, 1 epoch == 1 batch.
                Otherwise, each cycle corresponds to all mini-batches of data
                passing through a NN.
            batch_size:
                Size of training and test batches
            compute_accuracy:
                Computes accuracy function at each training cycle
            full_epoch:
                If True, passes all mini-batches of training/test data
                at each training cycle and computes the average loss. If False,
                passes a single (randomly chosen) mini-batch at each cycle.
            swa:
                Saves the recent stochastic weights and averages
                them at the end of training
            perturb_weights:
                Time-dependent weight perturbation, :math:`w\\leftarrow w + a / (1 + e)^\\gamma`,
                where parameters *a* and *gamma* can be passed as a dictionary,
                together with parameter *e_p* determining every n-th epoch at
                which a perturbation is applied
            **lr_scheduler (list of floats):
                List with learning rates for each training iteration/epoch.
                If the length of list is smaller than the number of training iterations,
                the last value in the list is used for the remaining iterations.
            **batch_seed (int):
                Random state for generating sequence of training and test batches
            **overwrite_train_data (bool):
                Overwrites the existing training data using self.set_data()
                (Default: True)
            **memory_alloc (float):
                threshold (in GB) for holding all training data on GPU
            **print_loss (int):
                Prints loss every *n*-th epoch
            **accuracy_metrics (str):
                Accuracy metrics (used only for printing training statistics)
            **filename (str):
                Filename for model weights
                (appended with "_test_weights_best.pt" and "_weights_final.pt")
            **plot_training_history (bool):
                Plots training and test curves vs epochs at the end of training
        """
        self.full_epoch = full_epoch
        self.training_cycles = training_cycles
        self.batch_size = batch_size
        self.compute_accuracy = compute_accuracy
        self.swa = swa
        self.lr_scheduler = kwargs.get("lr_scheduler")
        alloc = kwargs.get("memory_alloc", 4)
        # (Re)set training data, unless it is already set and the user
        # explicitly asked to keep it
        if train_data is not None and (
                not self.data_is_set or kwargs.get("overwrite_train_data", True)):
            self.set_data(*train_data, memory_alloc=alloc)
        elif not self.data_is_set:
            raise ValueError(
                "Please provide training data via the 'train_data' argument "
                "or set it beforehand with set_data()")
        self.perturb_weights = perturb_weights
        if self.perturb_weights:
            # .get() avoids a KeyError for custom models set via set_model(),
            # whose meta_state_dict has no "batchnorm" entry
            if self.meta_state_dict.get("batchnorm"):
                raise AssertionError(
                    "To use time-dependent weights perturbation, " +
                    "turn off the batch normalization layes")
            if isinstance(self.perturb_weights, bool):
                e_p = 1 if self.full_epoch else 50
                self.perturb_weights = {"a": .01, "gamma": 1.5, "e_p": e_p}
        params = self.net.parameters()
        if self.optimizer is None:
            if optimizer is None:
                # will be overwritten by lr_scheduler (if activated)
                self.optimizer = torch.optim.Adam(params, lr=1e-3)
            else:
                self.optimizer = optimizer(params)
        if self.criterion is None:
            self.criterion = self.get_loss_fn(loss, self.nb_classes)
        if not self.full_epoch:
            # Pre-generate a (shuffled) sequence of mini-batch indices,
            # one per training cycle
            r = self.training_cycles // len(self.X_train)
            batch_idx_train = np.arange(
                len(self.X_train)).repeat(r+1)[:self.training_cycles]
            r_ = self.training_cycles // len(self.X_test)
            batch_idx_test = np.arange(
                len(self.X_test)).repeat(r_+1)[:self.training_cycles]
            self.batch_idx_train = shuffle(
                batch_idx_train, random_state=kwargs.get("batch_seed", 1))
            self.batch_idx_test = shuffle(
                batch_idx_test, random_state=kwargs.get("batch_seed", 1))
        self.print_loss = kwargs.get("print_loss")
        if self.print_loss is None:
            self.print_loss = 100 if not self.full_epoch else 1
        self.accuracy_metrics = kwargs.get("accuracy_metrics")
        self.filename = kwargs.get("filename", "./model")
        self.plot_training_history = kwargs.get("plot_training_history", True)

    def select_lr(self, e: int) -> None:
        """
        Sets a learning rate for the current epoch from the user-provided
        lr_scheduler list (the last value is reused when the list runs out)
        """
        lr_i = (self.lr_scheduler[e] if e < len(self.lr_scheduler)
                else self.lr_scheduler[-1])
        for g in self.optimizer.param_groups:
            g['lr'] = lr_i

    def run(self) -> Type[torch.nn.Module]:
        """
        Trains a neural network, prints the statistics,
        saves the final model weights.
        """
        for e in range(self.training_cycles):
            if self.lr_scheduler is not None:
                self.select_lr(e)
            if self.full_epoch:
                self.step_full()
            else:
                self.step(e)
            if self.swa:
                self.save_running_weights(e)
            if self.perturb_weights:
                self.weight_perturbation(e)
            if any([e == 0, (e+1) % self.print_loss == 0,
                    e == self.training_cycles-1]):
                self.print_statistics(e)
        if not self.full_epoch:
            self.eval_model()
        if self.swa:
            print("Performing stochastic weight averaging...")
            self.net.load_state_dict(average_weights(self.running_weights))
            self.eval_model()
        self.save_model(self.filename + "_metadict_final")
        if self.plot_training_history:
            plot_losses(self.loss_acc["train_loss"],
                        self.loss_acc["test_loss"])
        return self.net

    def fit(self) -> None:
        """
        Alias for self.run()
        """
        _ = self.run()
class SegTrainer(BaseTrainer):
    """
    Class for training a fully convolutional neural network
    for semantic segmentation of noisy experimental data

    Args:
        model:
            Type of model to train: 'Unet', 'Uplusnet' or 'dilnet' (Default: 'Unet').
            See atomai.nets for more details. One can also pass a custom fully
            convolutional neural network model.
        nb_classes:
            Number of classes in the classification scheme adopted
            (must match the number of classes in training data)
        **seed (int):
            Deterministic mode for model training (Default: 1)
        **batch_seed (int):
            Separate seed for generating a sequence of batches
            for training/testing. Equal to 'seed' if set to None (default)
        **batch_norm (bool):
            Apply batch normalization after each convolutional layer
            (Default: True)
        **dropout (bool):
            Apply dropouts in the three inner blocks in the middle of a network
            (Default: False)
        **upsampling (str):
            "bilinear" or "nearest" upsampling method (Default: "bilinear")
        **nb_filters (int):
            Number of convolutional filters in the first convolutional block
            (this number doubles in the consecutive block(s),
            see definition of Unet and dilnet models for details)
        **with_dilation (bool):
            Use dilated convolutions in the bottleneck of Unet
            (Default: False)
        **layers (list):
            List with a number of layers in each block.
            For U-Net the first 4 elements in the list
            are used to determine the number of layers
            in each block of the encoder (including bottleneck layer),
            and the number of layers in the decoder is chosen accordingly
            (to maintain symmetry between encoder and decoder)
    """
    def __init__(self,
                 model: Union[Type[torch.nn.Module], str] = "Unet",
                 nb_classes: int = 1,
                 **kwargs: Union[int, List, str, bool]) -> None:
        """
        Initialize a single FCNN model trainer
        """
        super(SegTrainer, self).__init__()
        seed = kwargs.get("seed", 1)
        kwargs["batch_seed"] = kwargs.get("batch_seed", seed)
        set_train_rng(seed)
        self.nb_classes = nb_classes
        self.net, self.meta_state_dict = init_fcnn_model(
            model, self.nb_classes, **kwargs)
        self.net.to(self.device)
        if self.device == 'cpu':
            warnings.warn(
                "No GPU found. The training can be EXTREMELY slow",
                UserWarning)
        self.meta_state_dict["weights"] = self.net.state_dict()

    def set_data(self,
                 X_train: Tuple[np.ndarray, torch.Tensor],
                 y_train: Tuple[np.ndarray, torch.Tensor],
                 X_test: Optional[Tuple[np.ndarray, torch.Tensor]] = None,
                 y_test: Optional[Tuple[np.ndarray, torch.Tensor]] = None,
                 **kwargs: Union[float, int]) -> None:
        """
        Sets training and test data.

        Args:
            X_train:
                4D numpy array or pytorch tensor of training images
                (n_samples, 1, height, width). One can also pass a regular
                3D image stack without a channel dimension of 1 which will
                be added automatically
            y_train:
                4D (binary) / 3D (multiclass) numpy array or pytorch tensor
                of training masks (aka ground truth) stacked along
                the first dimension. The reason why in the multiclass case
                the X_train is 4-dimensional and the y_train is 3-dimensional
                is because of how the cross-entropy loss is calculated in PyTorch
                (see https://pytorch.org/docs/stable/nn.html#nllloss).
            X_test:
                4D numpy array or pytorch tensor of test images
                (stacked along the first dimension)
            y_test:
                4D (binary) / 3D (multiclass) numpy array or pytorch tensor
                of training masks (aka ground truth) stacked along
                the first dimension.
            kwargs:
                Parameters for train_test_split ('test_size' and 'seed') when
                separate test set is not provided and 'memory_alloc', which
                sets a threshold (in GBs) for holding entire training data on GPU
        """
        # Carve a test set out of the training data if none was provided
        if X_test is None or y_test is None:
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train, test_size=kwargs.get("test_size", .15),
                shuffle=True, random_state=kwargs.get("seed", 1))
        if self.full_epoch:
            loaders = init_fcnn_dataloaders(
                X_train, y_train, X_test, y_test,
                self.batch_size, memory_alloc=kwargs.get("memory_alloc", 4))
            self.train_loader, self.test_loader, nb_classes = loaders
        else:
            (self.X_train, self.y_train,
             self.X_test, self.y_test,
             nb_classes) = preprocess_training_image_data(
                 X_train, y_train, X_test, y_test,
                 self.batch_size,
                 kwargs.get("memory_alloc", 4))
        # The data must agree with the classification scheme
        # the model was initialized with
        if self.nb_classes != nb_classes:
            raise AssertionError("Number of classes in initialized model" +
                                 " is different from the number of classes" +
                                 " contained in training data")

    def accuracy_fn(self,
                    y: torch.Tensor,
                    y_prob: torch.Tensor,
                    *args):
        """
        Computes intersection-over-union score for semantic segmentation
        """
        iou_score = losses_metrics.IoU(
            y, y_prob, self.nb_classes).evaluate()
        return iou_score
class ImSpecTrainer(BaseTrainer):
    """
    Trainer of neural network for image-to-spectrum
    and spectrum-to-image transformations

    Args:
        in_dim:
            Input data dimensions.
            (height, width) for images or (length,) for spectra
        out_dim:
            output dimensions.
            (length,) for spectra or (height, width) for images
        latent_dim:
            dimensionality of the latent space
            (number of neurons in a fully connected bottleneck layer)
        **seed (int):
            Deterministic mode for model training (Default: 1)
        **batch_seed (int):
            Separate seed for generating a sequence of batches
            for training/testing. Equal to 'seed' if set to None (default)
        **nblayers_encoder (int):
            number of convolutional layers in the encoder
        **nblayers_decoder (int):
            number of convolutional layers in the decoder
        **nbfilters_encoder (int):
            number of convolutional filters in each layer of the encoder
        **nbfilters_decoder (int):
            number of convolutional filters in each layer of the decoder
        **batch_norm (bool):
            Apply batch normalization after each convolutional layer
            (Default: True)
        **encoder_downsampling (int):
            downsamples input data by this factor before passing
            to convolutional layers (Default: no downsampling)
        **decoder_upsampling (bool):
            performs upsampling+convolution operation twice on the reshaped latent
            vector (starting from image/spectra dims 4x smaller than the target dims)
            before passing to the decoder
    """
    def __init__(self,
                 in_dim: Tuple[int],
                 out_dim: Tuple[int],
                 latent_dim: int = 2,
                 **kwargs: Union[int, bool, str]) -> None:
        """
        Initialize trainer's parameters
        """
        # NOTE: the docstring above was originally a no-op string statement
        # placed *after* the super().__init__() call; it is now a proper
        # __init__ docstring
        super(ImSpecTrainer, self).__init__()
        seed = kwargs.get("seed", 1)
        kwargs["batch_seed"] = kwargs.get("batch_seed", seed)
        set_train_rng(seed)
        self.in_dim, self.out_dim = in_dim, out_dim
        (self.net,
         self.meta_state_dict) = init_imspec_model(in_dim, out_dim, latent_dim,
                                                   **kwargs)
        self.net.to(self.device)
        self.meta_state_dict["weights"] = self.net.state_dict()

    def set_data(self,
                 X_train: Union[np.ndarray, torch.Tensor],
                 y_train: Union[np.ndarray, torch.Tensor],
                 X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 **kwargs: Union[float, int]) -> None:
        """
        Sets training and test data.

        Args:
            X_train:
                4D numpy array or torch tensor with image data
                (n_samples x 1 x height x width) or 3D array/tensor
                with spectral data (n_samples x 1 x signal_length).
                It is also possible to pass 3D and 2D arrays by ignoring
                the channel dim of 1, which will be added automatically.
                The X_train is typically referred to as 'features'
            y_train:
                3D numpy array or torch tensor with spectral data
                (n_samples x 1 x signal_length) or 4D array/tensor with
                image data (n_samples x 1 x height x width).
                It is also possible to pass 2D and 3D arrays by ignoring
                the channel dim of 1, which will be added automatically.
                Note that if your X_train data are images,
                then your y_train must be spectra and vice versa.
                The y_train is typically referred to as "targets"
            X_test:
                Test data (features) of the same dimensionality
                (except for the number of samples) as X_train
            y_test:
                Test data (targets) of the same dimensionality
                (except for the number of samples) as y_train
            kwargs:
                Parameters for train_test_split ('test_size' and 'seed') when
                separate test set is not provided and 'memory_alloc', which
                sets a threshold (in GBs) for holding entire training data on GPU
        """
        # Carve a test set out of the training data if none was provided
        if X_test is None or y_test is None:
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train, test_size=kwargs.get("test_size", .15),
                shuffle=True, random_state=kwargs.get("seed", 1))
        if self.full_epoch:
            self.train_loader, self.test_loader, dims = init_imspec_dataloaders(
                X_train, y_train, X_test, y_test,
                self.batch_size, kwargs.get("memory_alloc", 4))
        else:
            (self.X_train, self.y_train,
             self.X_test, self.y_test, dims) = preprocess_training_imspec_data(
                 X_train, y_train, X_test, y_test,
                 self.batch_size, kwargs.get("memory_alloc", 4))
        # The data dims must agree with the model's declared in/out dims
        if dims[0] != self.in_dim or dims[1] != self.out_dim:
            raise AssertionError(
                "The input/output dimensions of the model must match" +
                " the height, width and length (for spectra) of training data")
class RegTrainer(BaseTrainer):
    """
    Trainer for image-vector regression tasks

    Args:
        out_dim:
            Output size. Equals to 1 for single-output regression tasks
        backbone:
            Type of backbone NN: choose between 'mobilenet', 'vgg', and 'resnet'
        **input_channels:
            Number of input channels. Assumes 1 as we mostly work with grayscale images
        **seed:
            random number generator seed
    """
    def __init__(self,
                 out_dim: int = 1,
                 backbone: str = "mobilenet",
                 **kwargs) -> None:
        """
        Initialize trainer's parameters
        """
        # NOTE: the docstring above was originally a no-op string statement
        # placed *after* the super().__init__() call; it is now a proper
        # __init__ docstring
        super(RegTrainer, self).__init__()
        seed = kwargs.get("seed", 1)
        kwargs["batch_seed"] = kwargs.get("batch_seed", seed)
        set_train_rng(seed)
        self.output_size = out_dim
        self.criterion = self.get_loss_fn('mse')
        (self.net,
         self.meta_state_dict) = init_reg_model(out_dim, backbone, **kwargs)
        self.net.to(self.device)
        self.meta_state_dict["weights"] = self.net.state_dict()

    def set_data(self,
                 X_train: Union[np.ndarray, torch.Tensor],
                 y_train: Union[np.ndarray, torch.Tensor],
                 X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 **kwargs: Union[float, int]) -> None:
        """
        Sets training and test data(loaders)

        Args:
            X_train (numpy array):
                4D numpy array with image data (n_samples x 1 x height x width).
                It is also possible to pass 3D by ignoring the channel dim,
                which will be added automatically.
            y_train (numpy array):
                2D numpy array with target values (n_samples x out_dim).
                For single-output regression tasks, one can simply pass an (n_samples,) array
            X_test (numpy array):
                4D numpy array with image data (n_samples x 1 x height x width).
                It is also possible to pass 3D by ignoring the channel dim,
                which will be added automatically.
            y_test (numpy array):
                2D numpy array with target values (n_samples x out_dim).
                For single-output regression tasks, one can simply pass an (n_samples,) array
            kwargs:
                Parameters for train_test_split ('test_size' and 'seed') when
                separate test set is not provided and 'memory_alloc', which
                sets a threshold (in GBs) for holding entire training data on GPU
        """
        # Carve a test set out of the training data if none was provided
        if X_test is None or y_test is None:
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train, test_size=kwargs.get("test_size", .15),
                shuffle=True, random_state=kwargs.get("seed", 1))
        if self.full_epoch:
            self.train_loader, self.test_loader = init_reg_dataloaders(
                X_train, y_train, X_test, y_test,
                self.batch_size, kwargs.get("memory_alloc", 4))
            train_out_dim = self.train_loader.dataset.tensors[1].shape[-1]
            test_out_dim = self.test_loader.dataset.tensors[1].shape[-1]
            if not train_out_dim == test_out_dim == self.output_size:
                raise AssertionError(
                    "The output dimensions for the training and test data must be" +
                    " equal to the declared output dimensions")
        else:
            (self.X_train, self.y_train,
             self.X_test, self.y_test) = preprocess_training_reg_data(
                 X_train, y_train, X_test, y_test,
                 self.batch_size, kwargs.get("memory_alloc", 4))
            if not self.y_train[0].shape[-1] == self.y_test[0].shape[-1] == self.output_size:
                raise AssertionError(
                    "The output dimensions for the training and test data must be" +
                    " equal to the declared output dimensions")
class clsTrainer(BaseTrainer):
    """
    Class for training classification models

    Args:
        nb_classes:
            Number of classes in the classification scheme adopted
            (must match the number of classes in training data)
        backbone:
            Type of backbone NN: choose between 'mobilenet', 'vgg', and 'resnet'
        **seed (int):
            Deterministic mode for model training (Default: 1)
        **batch_seed (int):
            Separate seed for generating a sequence of batches
            for training/testing. Equal to 'seed' if set to None (default)
    """
    def __init__(self,
                 nb_classes,
                 backbone: str = "mobilenet",
                 **kwargs) -> None:
        """
        Initialize a trainer for classification tasks
        """
        super(clsTrainer, self).__init__()
        seed = kwargs.get("seed", 1)
        kwargs["batch_seed"] = kwargs.get("batch_seed", seed)
        set_train_rng(seed)
        self.nb_classes = nb_classes
        # Classification backbones output log-probabilities,
        # hence the negative log-likelihood loss
        self.criterion = self.get_loss_fn('nll')
        self.net, self.meta_state_dict = init_cls_model(
            nb_classes, backbone, **kwargs)
        self.net.to(self.device)
        if self.device == 'cpu':
            warnings.warn(
                "No GPU found. The training can be EXTREMELY slow",
                UserWarning)
        self.meta_state_dict["weights"] = self.net.state_dict()

    def set_data(self,
                 X_train: Tuple[np.ndarray, torch.Tensor],
                 y_train: Tuple[np.ndarray, torch.Tensor],
                 X_test: Optional[Tuple[np.ndarray, torch.Tensor]] = None,
                 y_test: Optional[Tuple[np.ndarray, torch.Tensor]] = None,
                 **kwargs: Union[float, int]) -> None:
        """
        Sets training and test data.

        Args:
            X_train (numpy array):
                4D numpy array with image data (n_samples x 1 x height x width).
                It is also possible to pass 3D by ignoring the channel dim,
                which will be added automatically.
            y_train (numpy array):
                1D numpy array with target classes.
            X_test (numpy array):
                4D numpy array with image data (n_samples x 1 x height x width).
                It is also possible to pass 3D by ignoring the channel dim,
                which will be added automatically.
            y_test (numpy array):
                1D numpy array with target classes.
            kwargs:
                Parameters for train_test_split ('test_size' and 'seed') when
                separate test set is not provided and 'memory_alloc', which
                sets a threshold (in GBs) for holding entire training data on GPU
        """
        # Carve a test set out of the training data if none was provided
        if X_test is None or y_test is None:
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train, test_size=kwargs.get("test_size", .15),
                shuffle=True, random_state=kwargs.get("seed", 1))
        if self.full_epoch:
            loaders = init_cls_dataloaders(
                X_train, y_train, X_test, y_test,
                self.batch_size, memory_alloc=kwargs.get("memory_alloc", 4))
            self.train_loader, self.test_loader = loaders
        else:
            (self.X_train, self.y_train,
             self.X_test, self.y_test) = preprocess_training_cls_data(
                 X_train, y_train, X_test, y_test,
                 self.batch_size,
                 kwargs.get("memory_alloc", 4))

    def accuracy_fn(self,
                    y: torch.Tensor,
                    y_prob: torch.Tensor,
                    *args):
        """Computes a ratio of correct predictions"""
        # Predicted class is the argmax over the class dimension
        _, predicted = torch.max(y_prob.data, 1)
        total = y.size(0)
        correct = (predicted == y).sum().item()
        return correct / total