"""
etrainer.py
===========

Module for deep ensemble training of neural networks

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from copy import deepcopy as dc
from typing import Callable, Dict, Optional, Tuple, Type, Union
import warnings

import numpy as np
import torch
from sklearn.model_selection import train_test_split

from ..losses_metrics import IoU
from ..nets import init_fcnn_model, init_imspec_model
from ..utils import (average_weights, check_image_dims, check_signal_dims,
                     num_classes_from_labels, sample_weights)
from .trainer import BaseTrainer

augfn_type = Callable[[torch.Tensor, torch.Tensor, int], Tuple[torch.Tensor, torch.Tensor]]
compile_kwargs_type = Union[Type[torch.optim.Optimizer], str, int, bool]
ensemble_type = Dict[int, Dict[str, torch.Tensor]]
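
# Note on ``augment_fn`` (an editorial sketch, not part of the library API):
# the training methods below accept a user-defined augmentation callable of
# type ``augfn_type``, i.e. one that takes a batch of features, a batch of
# targets, and an integer (used as a seed, per the docstrings), and returns
# transformed tensors with the same dimensions as the originals, e.g.:
#
# def my_augment_fn(features: torch.Tensor, targets: torch.Tensor,
#                   seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
#     rng = np.random.default_rng(seed)
#     if rng.random() > 0.5:  # mirror both features and targets
#         features, targets = features.flip(-1), targets.flip(-1)
#     return features, targets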


class BaseEnsembleTrainer(BaseTrainer):
    """
    Base class for deep ensemble training
    """
    def __init__(self,
                 model: Type[torch.nn.Module] = None,
                 nb_classes: int = None
                 ) -> None:
        """
        Initialize base ensemble trainer
        """
        super(BaseEnsembleTrainer, self).__init__()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if model is not None:
            self.set_model(model, nb_classes)
        self.ensemble_state_dict = {}

    def compile_ensemble_trainer(self,
                                 **kwargs: compile_kwargs_type
                                 ) -> None:
        """
        Compile ensemble trainer.

        Args:
            kwargs:
                Keyword arguments to be passed to BaseTrainer.compile_trainer
                (loss, optimizer, compute_accuracy, full_epoch, swa,
                perturb_weights, batch_size, training_cycles, accuracy_metrics,
                filename, print_loss, plot_training_history)
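
        Example (illustrative; ``training_cycles`` and ``batch_size`` are
        among the keyword arguments listed above):

        >>> etrainer.compile_ensemble_trainer(
        >>>     training_cycles=500, batch_size=32)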
        """
        self.kdict = kwargs

    def train_baseline(self,
                       X_train: np.ndarray,
                       y_train: np.ndarray,
                       X_test: Optional[np.ndarray] = None,
                       y_test: Optional[np.ndarray] = None,
                       seed: int = 1,
                       augment_fn: augfn_type = None
                       ) -> Type[torch.nn.Module]:
        """
        Trains baseline weights

        Args:
            X_train:
                Training features
            y_train:
                Training targets
            X_test:
                Test features
            y_test:
                Test targets
            seed:
                seed to be used for the PyTorch and NumPy random number generators
            augment_fn:
                function that takes two torch tensors (features and targets),
                performs some transforms, and returns the transformed tensors.
                The dimensions of the transformed tensors must be the same as
                the dimensions of the original ones.

        Returns:
            Trained baseline model
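
        Example (illustrative; assumes the trainer was already compiled via
        ``compile_ensemble_trainer`` and that ``images`` and ``labels`` are
        numpy arrays):

        >>> model = etrainer.train_baseline(images, labels, seed=42)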
        """
        if self.net is None:
            raise AssertionError("You need to set a model first")
        self._reset_rng(seed)
        self._reset_weights()
        self._reset_training_history()
        self._delete_optimizer()
        (X_train, y_train,
         X_test, y_test) = self.preprocess_train_data(
            X_train, y_train, X_test, y_test)

        self.compile_trainer(
            (X_train, y_train, X_test, y_test), **self.kdict)
        self.data_augmentation(augment_fn)
        self.fit()
        return self.net

    def train_ensemble_from_scratch(self,
                                    X_train: np.ndarray,
                                    y_train: np.ndarray,
                                    X_test: Optional[np.ndarray] = None,
                                    y_test: Optional[np.ndarray] = None,
                                    n_models: int = 10,
                                    augment_fn: augfn_type = None,
                                    **kwargs
                                    ) -> Tuple[Type[torch.nn.Module], ensemble_type]:
        """
        Trains an ensemble of models, each time starting from scratch with a
        different initialization

        Args:
            X_train: Training features
            y_train: Training targets
            X_test: Test features
            y_test: Test targets
            n_models: number of models to be trained
            augment_fn:
                function that takes two torch tensors (features and targets),
                performs some transforms, and returns the transformed tensors.
                The dimensions of the transformed tensors must be the same as
                the dimensions of the original ones.
            **kwargs:
                Updates kwargs from initial compilation, which can be useful
                for iterative training.

        Returns:
            The last trained model and dictionary with ensemble weights
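
        Example (illustrative; keyword arguments from the initial compilation,
        e.g. ``training_cycles``, can be overridden through ``**kwargs``):

        >>> smodel, ensemble = etrainer.train_ensemble_from_scratch(
        >>>     images, labels, n_models=10, training_cycles=300)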
        """

        self.update_training_parameters(kwargs)

        print("Training ensemble models (strategy = 'from_scratch')")
        for i in range(n_models):
            print("\nEnsemble model {}".format(i + 1))
            self.kdict["batch_seed"] = i
            model_i = self.train_baseline(
                X_train, y_train, X_test, y_test, i, augment_fn)
            self.ensemble_state_dict[i] = dc(model_i.state_dict())
            self.save_ensemble_metadict()
        return self.net, self.ensemble_state_dict

    def train_ensemble_from_baseline(self,
                                     X_train: np.ndarray,
                                     y_train: np.ndarray,
                                     X_test: Optional[np.ndarray] = None,
                                     y_test: Optional[np.ndarray] = None,
                                     basemodel: Type[torch.nn.Module] = None,
                                     n_models: int = 10,
                                     training_cycles_base: int = 1000,
                                     training_cycles_ensemble: int = 100,
                                     augment_fn: augfn_type = None,
                                     **kwargs
                                     ) -> Tuple[Type[torch.nn.Module], ensemble_type]:
        """
        Trains an ensemble of models, each time starting from a baseline model.
        Each ensemble model is trained with a different random shuffling of
        batches (and a different seed for data augmentation, if any).
        If a baseline model is not provided, the baseline weights are first
        trained for *N* epochs and then used as a starting point to train
        multiple ensemble models for *n* epochs (*n* << *N*).

        Args:
            X_train: Training features
            y_train: Training targets
            X_test: Test features
            y_test: Test targets
            basemodel: Provide a baseline model (Optional)
            n_models: number of models in ensemble
            training_cycles_base:
                Number of training iterations for baseline model
            training_cycles_ensemble:
                Number of training iterations for every ensemble model
            augment_fn:
                function that takes two torch tensors (features and targets),
                performs some transforms, and returns the transformed tensors.
                The dimensions of the transformed tensors must be the same as
                the dimensions of the original ones.
            **kwargs: Updates kwargs from initial compilation
                (can be useful for iterative training)

        Returns:
            Model with averaged weights and dictionary with ensemble weights
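
        Example (illustrative; ``images``/``labels`` are numpy arrays and the
        trainer is assumed to be compiled):

        >>> avg_model, ensemble = etrainer.train_ensemble_from_baseline(
        >>>     images, labels, images_test, labels_test, n_models=10,
        >>>     training_cycles_base=1000, training_cycles_ensemble=100)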
        """

        self.update_training_parameters(kwargs)

        if basemodel is None:
            self.kdict["training_cycles"] = training_cycles_base
            print("Training baseline model...")
            basemodel = self.train_baseline(
                X_train, y_train, X_test, y_test, 1, augment_fn)
        else:  # the only case where we do not call train_baseline, so we preprocess here
            (X_train, y_train,
             X_test, y_test) = self.preprocess_train_data(
                X_train, y_train, X_test, y_test)

        self.set_model(basemodel)
        basemodel_state_dict = dc(self.net.state_dict())

        self.kdict["training_cycles"] = training_cycles_ensemble
        if not self.full_epoch:
            if "print_loss" not in self.kdict.keys():
                self.kdict["print_loss"] = 10

        print("\nTraining ensemble models (strategy = 'from_baseline')")
        for i in range(n_models):
            print("\nEnsemble model {}".format(i + 1))
            if i > 0:
                self.net.load_state_dict(basemodel_state_dict)
            self._reset_rng(i+2)
            self._reset_training_history()
            self._delete_optimizer()
            self.compile_trainer(  # Note that here we reinitialize optimizer
                (X_train, y_train, X_test, y_test),
                batch_seed=i+2, **self.kdict)
            model_i = self.run()
            self.ensemble_state_dict[i] = dc(model_i.state_dict())
            self.save_ensemble_metadict()
            averaged_weights = average_weights(self.ensemble_state_dict)
            model_i.load_state_dict(averaged_weights)
        return model_i, self.ensemble_state_dict

    def train_swag(self,
                   X_train: np.ndarray,
                   y_train: np.ndarray,
                   X_test: Optional[np.ndarray] = None,
                   y_test: Optional[np.ndarray] = None,
                   n_models: int = 10,
                   augment_fn: augfn_type = None,
                   **kwargs: compile_kwargs_type
                   ) -> Tuple[Type[torch.nn.Module], ensemble_type]:
        """
        Performs SWAG-like weights sampling at the end of single model training

        Args:
            X_train: Training features
            y_train: Training targets
            X_test: Test features
            y_test: Test targets
            n_models: number of samples to be drawn
            augment_fn:
                function that takes two torch tensors (features and targets),
                performs some transforms, and returns the transformed tensors.
                The dimensions of the transformed tensors must be the same as
                the dimensions of the original ones.
            **kwargs: Updates kwargs from initial compilation
                (can be useful for iterative training)

        Returns:
            Baseline model and dictionary with sampled weights
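
        Example (illustrative; weights are sampled at the end of a single
        model training run with ``swa=True``):

        >>> smodel, sampled_weights = etrainer.train_swag(
        >>>     images, labels, images_test, labels_test, n_models=10)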
        """
        self.update_training_parameters(kwargs)
        self.kdict["swa"] = True
        basemodel = self.train_baseline(
                X_train, y_train, X_test, y_test, 1, augment_fn)
        self.ensemble_state_dict = sample_weights(
            self.running_weights, n_models)
        self.save_ensemble_metadict()

        return basemodel, self.ensemble_state_dict

    def update_training_parameters(self, kwargs):
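        """
        Updates the training parameters stored in ``self.kdict`` with the
        user-provided keyword arguments, warning when an initial value is
        being overwritten
        """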
        warn_msg = "Overwriting the initial value '{}' of parameter '{}' with new value '{}'"
        if len(kwargs) != 0:
            for k, v in kwargs.items():
                if k in self.kdict.keys():
                    warnings.warn(
                        warn_msg.format(self.kdict[k], k, kwargs[k]),
                        UserWarning)
                self.kdict[k] = v

    def preprocess_train_data(self,
                              *train_data: np.ndarray
                              ) -> Tuple[torch.Tensor]:
        """
        Converts training and test data to torch tensors
        """
        X, y, X_, y_ = train_data
        tor = lambda x: torch.from_numpy(x)
        return tor(X), tor(y), tor(X_), tor(y_)

    def save_ensemble_metadict(self, filename: str = None) -> None:
        """
        Saves meta dictionary with ensemble weights and key information about
        model's structure (needed to load it back) to disk
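
        Example (illustrative; the saved metadict can be read back with
        ``torch.load``, with the ensemble weights stored under the "weights" key):

        >>> metadict = torch.load("mymodel_ensemble_metadict.tar")  # filename is illustrative
        >>> ensemble_weights = metadict["weights"]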
        """
        fname = self.filename if filename is None else filename
        ensemble_metadict = dc(self.meta_state_dict)
        ensemble_metadict["weights"] = self.ensemble_state_dict
        torch.save(ensemble_metadict, fname + "_ensemble_metadict.tar")


class EnsembleTrainer(BaseEnsembleTrainer):
    """
    Deep ensemble trainer

    Args:
        model:
            Built-in AtomAI model (passed as string)
            or initialized custom PyTorch model
        nb_classes:
            Number of classes (if any) in the model's output
        **kwargs:
            Number of input, output, and latent dimensions
            for imspec models (in_dim, out_dim, latent_dim)

    Example:

    >>> # Train an ensemble of Unet-s
    >>> etrainer = aoi.trainers.EnsembleTrainer(
    >>>     "Unet", batch_norm=True, nb_classes=3, with_dilation=False)
    >>> etrainer.compile_ensemble_trainer(training_cycles=500)
    >>> # Train 10 different models from scratch
    >>> smodel, ensemble = etrainer.train_ensemble_from_scratch(
    >>>     images, labels, images_test, labels_test, n_models=10)
    """
    def __init__(self,
                 model: Union[str, Type[torch.nn.Module]] = None,
                 nb_classes: int = 1,
                 **kwargs) -> None:
        """
        Initializes ensemble trainer
        """
        super(EnsembleTrainer, self).__init__()
        self.nb_classes = nb_classes
        if isinstance(model, str):
            if model in ["Unet", "dilnet", "SegResNet", "ResHedNet"]:
                self.net, self.meta_state_dict = init_fcnn_model(
                    model, self.nb_classes, **kwargs)
                self.accuracy_fn = accuracy_fn_seg(nb_classes)
            elif model == "imspec":
                keys_check = []
                for k in ["in_dim", "out_dim", "latent_dim"]:
                    if k not in kwargs.keys():
                        keys_check.append(k)
                if len(keys_check) > 0:
                    raise AssertionError(
                        "Specify input, output, and latent dimensions " +
                        "(Missing dimensions: {})".format(str(keys_check)[1:-1]))
                self.in_dim = kwargs.pop("in_dim")
                self.out_dim = kwargs.pop("out_dim")
                latent_dim = kwargs.pop("latent_dim")
                self.net, self.meta_state_dict = init_imspec_model(
                    self.in_dim, self.out_dim, latent_dim, **kwargs)
            self.net.to(self.device)
        else:
            self.set_model(model, nb_classes)
        self.meta_state_dict["weights"] = self.net.state_dict()
        self.meta_state_dict["optimizer"] = self.optimizer

    def compile_ensemble_trainer(self,
                                 **kwargs: compile_kwargs_type
                                 ) -> None:
        """
        Compile ensemble trainer.

        Args:
            kwargs:
                Keyword arguments to be passed to BaseTrainer.compile_trainer
                (loss, optimizer, compute_accuracy, full_epoch, swa,
                perturb_weights, batch_size, training_cycles, accuracy_metrics,
                filename, print_loss, plot_training_history)
        """
        self.kdict = kwargs
        self.full_epoch = self.kdict.get("full_epoch", False)
        self.batch_size = self.kdict.get("batch_size", 32)
        self.kdict["overwrite_train_data"] = False

    def train_baseline(self,
                       X_train: np.ndarray,
                       y_train: np.ndarray,
                       X_test: Optional[np.ndarray] = None,
                       y_test: Optional[np.ndarray] = None,
                       seed: int = 1,
                       augment_fn: augfn_type = None
                       ) -> Type[torch.nn.Module]:
        """
        Trains baseline weights

        Args:
            X_train:
                Training features
            y_train:
                Training targets
            X_test:
                Test features
            y_test:
                Test targets
            seed:
                seed to be used for the PyTorch and NumPy random number generators
            augment_fn:
                function that takes two torch tensors (features and targets),
                performs some transforms, and returns the transformed tensors.
                The dimensions of the transformed tensors must be the same as
                the dimensions of the original ones.

        Returns:
            Trained baseline model
        """
        if self.net is None:
            raise AssertionError("You need to set a model first")
        train_data = self.preprocess_train_data(
            X_train, y_train, X_test, y_test)
        self.set_data(*train_data, **self.kdict)
        self._reset_rng(seed)
        self._reset_weights()
        self._reset_training_history()
        self._delete_optimizer()
        self.compile_trainer(
            (X_train, y_train, X_test, y_test), **self.kdict)
        self.data_augmentation(augment_fn)
        self.fit()
        return self.net

    def preprocess_train_data(self,
                              *args: np.ndarray
                              ) -> Tuple[np.ndarray]:
        """
        Training and test data preprocessing
        """
        if self.meta_state_dict.get("model_type") == "seg":
            train_data = set_data_seg(*args, self.nb_classes)
        elif self.meta_state_dict.get("model_type") == "imspec":
            train_data = set_data_imspec(*args, (self.in_dim, self.out_dim))
        return train_data


def set_data_seg(X_train: np.ndarray,
                 y_train: np.ndarray,
                 X_test: Optional[np.ndarray] = None,
                 y_test: Optional[np.ndarray] = None,
                 nb_classes_set: int = 1,
                 **kwargs: Union[float, int]
                 ) -> Tuple[np.ndarray]:
    """
    Sets training and test data for semantic segmentation
    """
    nb_classes = num_classes_from_labels(y_train)
    if nb_classes != nb_classes_set:
        raise AssertionError("Number of specified classes" +
                             " is different from the number of classes" +
                             " contained in training data")
    if X_test is None or y_test is None:
        X_train, X_test, y_train, y_test = train_test_split(
            X_train, y_train, test_size=kwargs.get("test_size", .15),
            shuffle=True, random_state=kwargs.get("seed", 1))
    X_train, y_train, X_test, y_test = check_image_dims(
        X_train, y_train, X_test, y_test, nb_classes)
    f32, i64 = lambda x: x.astype(np.float32), lambda x: x.astype(np.int64)
    X_train, X_test = f32(X_train), f32(X_test)
    if nb_classes > 1:
        y_train, y_test = i64(y_train), i64(y_test)
    else:
        y_train, y_test = f32(y_train), f32(y_test)
    return X_train, y_train, X_test, y_test


def set_data_imspec(X_train: np.ndarray,
                    y_train: np.ndarray,
                    X_test: Optional[np.ndarray] = None,
                    y_test: Optional[np.ndarray] = None,
                    dims: Tuple[int] = None,
                    **kwargs: Union[float, int]
                    ) -> Tuple[np.ndarray]:
    """
    Sets training and test data for im2spec and spec2im models
    """
    if X_test is None or y_test is None:
        X_train, X_test, y_train, y_test = train_test_split(
            X_train, y_train, test_size=kwargs.get("test_size", .15),
            shuffle=True, random_state=kwargs.get("seed", 1))
    X_train, y_train, X_test, y_test = check_signal_dims(
        X_train, y_train, X_test, y_test)
    in_dim, out_dim = X_train.shape[2:], y_train.shape[2:]
    if dims[0] != in_dim or dims[1] != out_dim:
        raise AssertionError(
            "The input/output dimensions of the model must match" +
            " the height, width and length (for spectra) of the training data")
    f32 = lambda x: x.astype(np.float32)
    X_train, X_test = f32(X_train), f32(X_test)
    y_train, y_test = f32(y_train), f32(y_test)
    return X_train, y_train, X_test, y_test


def accuracy_fn_seg(nb_classes: int
                    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Returns function that computes IoU score
    """
    def accuracy(y, y_prob, *args):
        iou_score = IoU(
            y, y_prob, nb_classes).evaluate()
        return iou_score
    return accuracy
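

# Illustrative note (an editorial sketch, not part of the library API):
# ``accuracy_fn_seg`` builds the IoU-based metric that ``EnsembleTrainer``
# assigns to ``self.accuracy_fn`` for segmentation models.  The returned
# closure is called with ground-truth labels and predicted probabilities,
# e.g. (assuming ``labels`` and ``predicted_probs`` are torch tensors):
#
# iou_fn = accuracy_fn_seg(nb_classes=3)
# score = iou_fn(labels, predicted_probs)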