Source code for atomai.models.dgm.vae

"""
vae.py
=======

Module for analysis of system "building blocks" with variational autoencoders

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

import os
from copy import deepcopy as dc
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import norm
from torchvision.utils import make_grid

from ...losses_metrics import vae_loss
from ...nets import init_VAE_nets
from ...trainers import viBaseTrainer
from ...utils import (crop_borders, extract_subimages, get_coord_grid,
                      imcoordgrid, set_train_rng, subimg_trajectories,
                      to_onehot)


class BaseVAE(viBaseTrainer):
    """
    General class for VAE models

    Args:
        in_dim:
            (height, width) or (height, width, channels) of input images
        latent_dim:
            Number of latent dimensions
        nb_classes:
            Number of classes (for class-conditional VAEs)
        seed:
            seed for torch and numpy (pseudo-)random number generators
        **conv_encoder (bool):
            use convolutional layers in encoder
        **conv_decoder (bool):
            use convolutional layers in decoder (doesn't apply to rVAE)
        **numlayers_encoder (int):
            number of layers in encoder (Default: 2)
        **numlayers_decoder (int):
            number of layers in decoder (Default: 2)
        **numhidden_encoder (int):
            number of hidden units OR conv filters in encoder (Default: 128)
        **numhidden_decoder (int):
            number of hidden units OR conv filters in decoder (Default: 128)
    """
    def __init__(self,
                 in_dim: Tuple[int],
                 latent_dim: int,
                 nb_classes: int = 0,
                 coord: int = 0,
                 discrete_dim: Optional[List] = None,
                 seed: int = 0,
                 **kwargs: Union[int, bool]) -> None:
        """
        Initializes encoder-decoder object
        """
        super(BaseVAE, self).__init__()

        in_dim_error_msg = (
            "You must specify the input dimensions and pass them as a tuple. "
            "For images, specify (height, width) or (height, width, channels)" +
            " if multiple channels. For spectra, specify (length,)")

        if in_dim is None or not isinstance(in_dim, (tuple, list)):
            raise AssertionError(in_dim_error_msg)
        if isinstance(in_dim, tuple) and not isinstance(in_dim[0], int):
            raise AssertionError(in_dim_error_msg)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        set_train_rng(seed)
        np.random.seed(seed)

        self.in_dim = in_dim
        self.z_dim = latent_dim
        if isinstance(discrete_dim, list):
            self.z_dim = self.z_dim + sum(discrete_dim)
        self.discrete_dim = discrete_dim

        if coord:
            if len(in_dim) not in (2, 3):
                raise NotImplementedError(
                    "VAE with rotation and translational invariance are " +
                    "available only for 2D image data")
            self.z_dim = self.z_dim + coord
            self.x_coord = imcoordgrid(in_dim).to(self.device)
        self.nb_classes = nb_classes

        (encoder_net, decoder_net,
         self.metadict) = init_VAE_nets(
            in_dim, latent_dim, coord, discrete_dim,
            nb_classes, **kwargs)
        self.set_model(encoder_net, decoder_net)
        self.sigmoid_out = self.metadict["sigmoid_out"]
        self.coord = coord
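
    # Illustrative sketch (an assumption, not part of the class): the architecture-related
    # kwargs documented above are forwarded to init_VAE_nets. For instance, a model with a
    # 3-layer convolutional encoder of 64 filters per layer could be requested through a
    # subclass such as aoi.models.VAE:
    #
    #   vae = aoi.models.VAE((28, 28), latent_dim=2, conv_encoder=True,
    #                        numlayers_encoder=3, numhidden_encoder=64)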

    def encode_(self,
                x_new: Union[np.ndarray, torch.Tensor],
                **kwargs: int) -> np.ndarray:
        """
        Encodes input image data using a trained VAE's encoder

        Args:
            x_new:
                image array to encode
            **num_batches (int):
                number of batches (Default: 10)

        Returns:
            Concatenated array of encoded vectors
        """
        def inference() -> np.ndarray:
            with torch.no_grad():
                encoded = self.encoder_net(x_i)
            encoded = torch.cat(encoded, -1).cpu().numpy()
            return encoded

        if isinstance(x_new, np.ndarray):
            x_new = torch.from_numpy(x_new).float()
        if (x_new.ndim == len(self.in_dim) == 2 or
           x_new.ndim == len(self.in_dim) == 3):
            x_new = x_new.unsqueeze(0)
        x_new = x_new.to(self.device)
        num_batches = kwargs.get("num_batches", 10)
        batch_size = len(x_new) // num_batches
        z_encoded = []
        for i in range(num_batches):
            x_i = x_new[i*batch_size:(i+1)*batch_size]
            z_encoded_i = inference()
            z_encoded.append(z_encoded_i)
        x_i = x_new[(i+1)*batch_size:]
        if len(x_i) > 0:
            z_encoded_i = inference()
            z_encoded.append(z_encoded_i)
        return np.concatenate(z_encoded)

    def encode(self,
               x_new: Union[np.ndarray, torch.Tensor],
               **kwargs: int) -> Tuple[np.ndarray]:
        """
        Encodes input image data using a trained VAE's encoder

        Args:
            x_new:
                image array to encode
            **num_batches (int):
                number of batches (Default: 10)

        Returns:
            Mean and SD of the encoded continuous distribution,
            plus alphas ("class probabilities") for the encoded
            discrete distribution(s), if any. For VAE and rVAE the output is
            (z_mean, z_sd); for jVAE and jrVAE it is (z_mean, z_sd, alphas).
            For rotationally invariant models, the 1st dimension of z_mean is
            the encoded rotation angle, the 2nd and 3rd dimensions are the
            encoded x- and y-shifts (if translation is set to True), and the
            remaining dimensions are the standard VAE latent variables
        """
        z = self.encode_(x_new, **kwargs)
        if not self.discrete_dim:
            z_mean = z[:, :self.z_dim]
            z_logsd = z[:, self.z_dim:]
            return z_mean, z_logsd
        cont_dim = self.z_dim - sum(self.discrete_dim)
        z_mean = z[:, :cont_dim]
        z_logsd = z[:, cont_dim:cont_dim+cont_dim]
        alphas = z[:, cont_dim+cont_dim:]
        return z_mean, z_logsd, alphas
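
    # Usage sketch (assumes a trained model): `subimages` is a hypothetical
    # (n, height, width) numpy array matching in_dim. For a model without discrete
    # latent variables, encode() returns the means and (log) SDs of the latent distribution:
    #
    #   z_mean, z_logsd = vae.encode(subimages, num_batches=10)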

    def decode(self, z_sample: Union[np.ndarray, torch.Tensor],
               y: Optional[Union[int, np.ndarray, torch.Tensor]] = None
               ) -> np.ndarray:
        """
        Takes a point in latent space and maps it to data space
        via the learned generative model

        Args:
            z_sample: coordinates in latent space
            y: label (optional)

        Returns:
            Generated ("decoded") image(s)
        """

        if isinstance(z_sample, np.ndarray):
            z_sample = torch.from_numpy(z_sample).float()
        if len(z_sample.size()) == 1:
            z_sample = z_sample[None, ...]
        if self.coord:
            x_coord = self.x_coord.expand(z_sample.size(0), *self.x_coord.size())
        z_sample = z_sample.cuda() if torch.cuda.is_available() else z_sample
        if y is not None:
            if isinstance(y, int):
                y = torch.tensor(y)
            elif isinstance(y, np.ndarray):
                y = torch.from_numpy(y)
            if y.dim() == 0:
                y = y.unsqueeze(0)
            y = y.cuda() if torch.cuda.is_available() else y
            targets = to_onehot(y, self.nb_classes)
            z_sample = torch.cat((z_sample, targets), dim=-1)
        if torch.cuda.is_available():
            self.decoder_net.cuda()
        self.decoder_net.eval()
        with torch.no_grad():
            if self.coord:
                x_decoded = self.decoder_net(x_coord, z_sample)
            else:
                x_decoded = self.decoder_net(z_sample)
        if self.sigmoid_out:
            x_decoded = torch.sigmoid(x_decoded)
        imdec = x_decoded.cpu().numpy()
        return imdec
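
    # Usage sketch (assumes a trained model with latent_dim=2): decode a single point
    # of the latent space back into data space; for class-conditional models a
    # (hypothetical) label can be passed as the second argument.
    #
    #   decoded_img = vae.decode(np.array([0.5, -1.0]))
    #   decoded_img_c = vae.decode(np.array([0.5, -1.0]), y=3)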

    def reconstruct(self,
                    x_new: Union[np.ndarray, torch.Tensor],
                    **kwargs: int) -> np.ndarray:
        """
        Forward prediction with uncertainty quantification by sampling from
        the encoded mean and std. Works only for regular VAE (and not for rVAE)

        Args:
            x_new:
                image array to encode
            **label (int):
                class to be reconstructed (for cVAE, crVAE, jVAE, and jrVAE)
            **num_samples (int):
                number of samples to generate from normal distribution

        Returns:
            Ensemble of "decoded" images
        """
        num_samples = kwargs.get("num_samples", 32)
        label = kwargs.get("label")
        if isinstance(x_new, np.ndarray):
            x_new = torch.from_numpy(x_new).float()
        if torch.cuda.is_available():
            x_new = x_new.cuda()
            self.encoder_net.cuda()
        with torch.no_grad():
            encoded = self.encoder_net(x_new)
        if len(encoded) == 2:
            z_mean, z_logsd = encoded
        else:
            z_mean, z_logsd, alphas = encoded
        z_mean = z_mean[:, self.coord:]
        z_logsd = z_logsd[:, self.coord:]
        if label is not None:
            n = self.nb_classes if self.discrete_dim is None else self.discrete_dim  # probably makes sense to use nb_classes for j(r)VAE's discrete_dim
            alphas = to_onehot(
                torch.tensor(label).unsqueeze(0).to(self.device),
                torch.tensor(n).to(self.device))
        z_sd = torch.exp(z_logsd)
        ndist = torch.distributions.Normal(z_mean, z_sd)
        decoded_all = []
        for i in range(num_samples):
            z_sample = ndist.rsample()
            z_sample = z_sample.view(1, -1)
            if len(encoded) > 2 or label is not None:
                z_sample = torch.cat([z_sample, alphas], dim=1)
            decoded_all.append(self.decode(z_sample))
        decoded_all = np.concatenate(decoded_all, axis=0)
        return decoded_all
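
    # Usage sketch (assumes a trained model): reconstruct a single image `x` (shaped like
    # the training data, with a leading batch dimension) and quantify uncertainty from
    # the spread of the sampled reconstructions.
    #
    #   decoded = vae.reconstruct(x, num_samples=64)
    #   img_mean, img_var = decoded.mean(0), decoded.var(0)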

    def encode_images(self,
                      imgdata: np.ndarray,
                      **kwargs: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Encodes every pixel of every image in image stack

        Args:
            imgdata:
                3D numpy array of images. Can also be a single 2D image
            **num_batches (int):
                number of batches for encoding pixels of a single image

        Returns:
            Cropped original image stack and encoded array (cropping is due to finite window size)
        """

        if (imgdata.ndim == len(self.in_dim) == 2 or
           imgdata.ndim == len(self.in_dim) == 3):
            imgdata = np.expand_dims(imgdata, axis=0)
        imgdata_encoded, imgdata_ = [], []
        for i, img in enumerate(imgdata):
            print("\rImage {}/{}".format(i+1, imgdata.shape[0]), end="")
            img_, img_encoded = self.encode_image_(img, **kwargs)
            imgdata_encoded.append(img_encoded)
            imgdata_.append(img_)
        return np.array(imgdata_), np.array(imgdata_encoded)
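
    # Usage sketch (assumes a trained model): encode every pixel of a hypothetical movie
    # `imgdata` of shape (n_frames, H, W). The returned arrays are cropped at the borders
    # because of the finite window size.
    #
    #   imgdata_cropped, encoded = vae.encode_images(imgdata, num_batches=20)
    #   latent1_map = encoded[0, ..., 0]  # map of the 1st latent variable in frame 0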

    def encode_image_(self,
                      img: np.ndarray,
                      **kwargs: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Crops and encodes a subimage around each pixel in the input image.
        The size of each subimage is determined by the size of the images in the VAE training data.

        Args:
            img:
                2D numpy array
            **num_batches (int):
                number of batches for encoding subimages

        Returns:
            Cropped original image and encoded array (cropping is due to finite window size)
        """

        num_batches = kwargs.get("num_batches", 10)
        inf = int(1e5)
        img_to_encode = img.copy()
        coordinates = get_coord_grid(img_to_encode, 1, return_dict=False)
        batch_size = coordinates.shape[0] // num_batches
        encoded_img = -inf * np.ones((*img_to_encode.shape, self.z_dim))
        for i in range(num_batches):
            coord_i = coordinates[i*batch_size:(i+1)*batch_size]
            subimgs_i, com_i, _ = extract_subimages(
                                    img_to_encode, coord_i, self.in_dim[0])
            if len(subimgs_i) > 0:
                z_mean, _ = self.encode(subimgs_i, num_batches=10)
                for k, (l, m) in enumerate(com_i):
                    encoded_img[int(l), int(m)] = z_mean[k]
        coord_i = coordinates[(i+1)*batch_size:]
        if len(coord_i) > 0:
            subimgs_i, com_i, _ = extract_subimages(
                                    img_to_encode, coord_i, self.in_dim[0])
            if len(subimgs_i) > 0:
                z_mean, _ = self.encode(subimgs_i, num_batches=10)
                for k, (l, m) in enumerate(com_i):
                    encoded_img[int(l), int(m)] = z_mean[k]

        img_to_encode[encoded_img[..., 0] == -inf] = 0
        img_to_encode = crop_borders(img_to_encode[..., None], 0)
        encoded_img = crop_borders(encoded_img, -inf)

        return img_to_encode[..., 0], encoded_img

    def encode_trajectories(self,
                            imgdata: np.ndarray,
                            coord_class_dict: Dict[int, np.ndarray],
                            window_size: int,
                            min_length: int,
                            rmax: int,
                            **kwargs: int
                            ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """
        Calculates trajectories and latent variable value
        for each point in a trajectory.

        Args:
            imgdata:
                NN output (preferable) or raw data
            coord_class_dict:
                atomic/defect/particle coordinates
            window_size:
                size of subimages to crop
            min_length:
                minimum length of trajectory to be included
            rmax:
                maximum allowed distance (projected onto the xy plane) between a defect
                in one frame and the position of its nearest neighbor in the next one
            **num_batches (int):
                number of batches for self.encode (Default: 10)

        Returns:
            Lists of encoded trajectories, corresponding movie frame numbers, and cropped subimages
        """
        t = subimg_trajectories(
                imgdata, coord_class_dict, window_size, min_length, rmax)
        trajectories, frames, subimgs_all = t.get_all_trajectories()
        trajectories_enc_all = []
        for traj, subimgs in zip(trajectories, subimgs_all):
            z_mean, _ = self.encode(
                subimgs, num_batches=kwargs.get("num_batches", 10))
            traj_enc = np.concatenate((traj[:, :2], z_mean), axis=-1)
            trajectories_enc_all.append(traj_enc)
        return trajectories_enc_all, frames, subimgs_all
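
    # Usage sketch (assumes a trained model): encode defect trajectories from a movie of
    # (preferably) NN-predicted frames; `coord_class_dict` maps frame numbers to coordinates.
    # The window_size/min_length/rmax values below are illustrative.
    #
    #   trajectories, frames, subimgs = vae.encode_trajectories(
    #       nn_output, coord_class_dict, window_size=32, min_length=10, rmax=10)
    #   # each trajectory row is (y, x, z_1, ..., z_n)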

    def manifold2d(self, **kwargs: Union[int, List, str, bool]) -> np.ndarray:  # use torchvision's grid here
        """
        Performs mapping from latent space to data space allowing the learned
        manifold to be visualized. This works only for a 2D latent space
        (not counting angle & translation dimensions)

        Args:
            **d (int): grid size
            **l1 (list): range of 1st latent variable
            **l2 (list): range of 2nd latent variable
            **label (int): label in class-conditioned (r)VAE
            **disc_idx (int): discrete "class"
            **cmap (str): color map (Default: gnuplot)
            **draw_grid (bool): plot semi-transparent grid
            **origin (str): plot origin (e.g. 'lower')
        """
        y = kwargs.get("label")
        if y is None and self.nb_classes != 0:
            y = 0
        elif y and self.nb_classes == 0:
            y = None
        l1, l2 = kwargs.get("l1"), kwargs.get("l2")
        d = kwargs.get("d", 9)
        cmap = kwargs.get("cmap", "gnuplot")
        if len(self.in_dim) == 2:
            figure = np.zeros((self.in_dim[0] * d, self.in_dim[1] * d))
        elif len(self.in_dim) == 3:
            figure = np.zeros((self.in_dim[0] * d, self.in_dim[1] * d, self.in_dim[-1]))
        if l1 and l2:
            grid_x = np.linspace(l1[1], l1[0], d)
            grid_y = np.linspace(l2[0], l2[1], d)
        else:
            grid_x = norm.ppf(np.linspace(0.95, 0.05, d))
            grid_y = norm.ppf(np.linspace(0.05, 0.95, d))

        if self.discrete_dim:
            z_disc = np.zeros((sum(self.discrete_dim)))[None]
            z_disc[:, kwargs.get("disc_idx", 0)] = 1
        for i, xi in enumerate(grid_x):
            for j, yi in enumerate(grid_y):
                z_sample = np.array([[xi, yi]])
                if self.discrete_dim:
                    z_sample = np.concatenate((z_sample, z_disc), -1)
                if y is not None:
                    imdec = self.decode(z_sample, y)
                else:
                    imdec = self.decode(z_sample)
                figure[i * self.in_dim[0]: (i + 1) * self.in_dim[0],
                       j * self.in_dim[1]: (j + 1) * self.in_dim[1]] = imdec
        if figure.min() < 0:
            figure = (figure - figure.min()) / figure.ptp()

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(figure, cmap=cmap, origin=kwargs.get("origin", "lower"),
                  extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()])
        ax.set_xlabel("$z_1$")
        ax.set_ylabel("$z_2$")
        draw_grid = kwargs.get("draw_grid")
        if draw_grid:
            major_ticks_x = np.arange(0, d * self.in_dim[0], self.in_dim[0])
            major_ticks_y = np.arange(0, d * self.in_dim[1], self.in_dim[1])
            ax.set_xticks(major_ticks_x)
            ax.set_yticks(major_ticks_y)
            ax.grid(which='major', alpha=0.6)
        for item in ([ax.xaxis.label, ax.yaxis.label] +
                     ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(18)
        if not kwargs.get("savefig"):
            plt.show()
        else:
            savedir = kwargs.get("savedir", './vae_learning/')
            fname = kwargs.get("filename", "manifold_2d")
            if not os.path.exists(savedir):
                os.makedirs(savedir)
            fig.savefig(os.path.join(savedir, '{}.png'.format(fname)))
            plt.close(fig)
        return figure
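
    # Usage sketch (assumes a trained model with 2 continuous latent variables): plot the
    # learned manifold on a 12 x 12 grid over a custom latent range.
    #
    #   vae.manifold2d(d=12, l1=[-2, 2], l2=[-2, 2], cmap="viridis", origin="upper")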

    def manifold_traversal(self, cont_idx: int,
                           d: int = 10,
                           cont_idx_fixed: int = 0,
                           plot: bool = True,
                           **kwargs: Union[str, float]
                           ) -> np.ndarray:
        """
        Latent space traversals for joint continuous and discrete
        latent representations
        """
        if self.discrete_dim is None:
            raise TypeError(
                "Traversal of latent space is implemented only for joint " +
                "continuous and discrete latent distributions")
        if len(self.in_dim) == 3 and self.in_dim[-1] != 3:
            raise ValueError(
                "This works only for single-channel and 3-channel images")
        num_samples = d**2
        cont_dim = self.z_dim - sum(self.discrete_dim) - self.coord
        # Get continuous latent coordinates
        samples_cont = np.zeros(shape=(num_samples, cont_dim)) + cont_idx_fixed
        cdf_traversal = np.linspace(0.05, 0.95, d)
        cont_traversal = norm.ppf(cdf_traversal)
        for i in range(d):
            for j in range(d):
                samples_cont[i * d + j, cont_idx] = cont_traversal[j]
        # Get discrete latent coordinates
        disc_dim = self.discrete_dim[0]
        n = np.arange(0, disc_dim)
        n = np.tile(n, d // disc_dim + 1)[:d]
        samples_disc = []
        for i in range(d):
            samples_disc_i = np.zeros((d, disc_dim))
            samples_disc_i[:, n[i]] = 1
            samples_disc.append(samples_disc_i)
        samples_disc = np.concatenate(samples_disc)
        # Put them together and pass through a decoder
        samples = np.concatenate((samples_cont, samples_disc), -1)
        decoded = self.decode(samples)
        # Use a built-in torchvision utility to construct a nice grid
        decoded = decoded.transpose(0, 3, 1, 2) if decoded.ndim == 4 else decoded[:, None]
        grid = make_grid(torch.from_numpy(decoded),
                         nrow=d, padding=kwargs.get("pad", 2)).numpy()
        grid = grid.transpose(1, 2, 0) if len(self.in_dim) == 3 else grid[0]
        grid = (grid - grid.min()) / grid.ptp()
        if not kwargs.get("keep_square", False) and disc_dim != d:
            grid = grid[:(self.in_dim[0]+kwargs.get("pad", 2)) * disc_dim]
        if plot:
            plt.figure(figsize=(12, 12))
            plt.imshow(grid, cmap='gnuplot',
                       origin=kwargs.get("origin", "lower"))
            plt.xlabel("$z_{cont}$", fontsize=18)
            plt.ylabel("$z_{disc}$", fontsize=18)
            plt.xticks([])
            plt.yticks([])
            plt.show()
        return grid
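
    # Usage sketch (jVAE/jrVAE only, assumes a trained model): traverse the 1st continuous
    # latent variable along one grid axis while cycling through the discrete classes along
    # the other.
    #
    #   grid = vae.manifold_traversal(cont_idx=0, d=10)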

    @classmethod
    def visualize_manifold_learning(cls,
                                    frames_dir: str,
                                    **kwargs: Union[str, int]) -> None:
        """
        Creates and stores a video showing evolution of
        learned 2D manifold during rVAE's training

        Args:
            frames_dir:
                directory with snapshots of manifold as .png files
                (the files should be named as "1.png", "2.png", etc.)
            **moviename (str): name of the movie
            **frame_duration (int): duration of each movie frame
        """
        from atomai.utils import animation_from_png
        movie_name = kwargs.get("moviename", "manifold_learning")
        duration = kwargs.get("frame_duration", 1)
        animation_from_png(frames_dir, movie_name, duration, remove_dir=False)
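
    # Usage sketch: if manifold2d() snapshots were saved during training as "1.png",
    # "2.png", ... in a directory (e.g. via its savefig/savedir options), they can be
    # compiled into a movie:
    #
    #   BaseVAE.visualize_manifold_learning("./vae_learning",
    #                                       moviename="manifold", frame_duration=1)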

    def _check_inputs(self,
                      X_train: np.ndarray,
                      y_train: Optional[np.ndarray] = None,
                      X_test: Optional[np.ndarray] = None,
                      y_test: Optional[np.ndarray] = None
                      ) -> None:
        """
        Asserts that the dimensionality and the number of classes contained in
        the training and test data match those specified at initialization
        """
        if self.in_dim != X_train.shape[1:]:
            raise RuntimeError(
                "The values of input dimensions you specified do not match " +
                "the training data dimensions. " +
                "Expected {} but got {}".format(self.in_dim, X_train.shape[1:]))
        if X_test is not None and self.in_dim != X_test.shape[1:]:
            raise RuntimeError(
                "The values of input dimensions you specified do not match " +
                "the test data dimensions. " +
                "Expected {} but got {}".format(self.in_dim, X_test.shape[1:]))
        if y_train is not None and self.nb_classes == 0:
            raise RuntimeError(
                "You must have forgotten to specify number of classes " +
                "during the initialization. Example of correct usage: " +
                "vae = VAE(in_dim=(28, 28), nb_classes=10)); " +
                "vae.fit(train_data, train_labels).")
        lbl_match = True
        if y_train is not None and y_test is None:
            lbl_match = self.nb_classes == len(np.unique(y_train))
        elif y_train is not None and y_test is not None:
            lbl_match = (self.nb_classes == len(np.unique(y_train))
                         == len(np.unique(y_test)))
        if not lbl_match:
            raise RuntimeError(
                "The number of classes specified at initialization must be " +
                "equal the the number of classes in train and test labels")

    def _2torch(self,
                X: Union[np.ndarray, torch.Tensor],
                y: Optional[Union[np.ndarray, torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rules for conversion of numpy arrays to torch tensors
        """
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).float()
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y).long()
        return X, y


class VAE(BaseVAE):
    """
    Implements a standard Variational Autoencoder (VAE)

    Args:
        in_dim:
            Input dimensions for image data passed as (height, width)
            for grayscale data or (height, width, channels) for multichannel data
        latent_dim:
            Number of VAE latent dimensions
        nb_classes:
            Number of classes for class-conditional VAE
        seed:
            seed for torch and numpy (pseudo-)random number generators
        **conv_encoder (bool):
            use convolutional layers in encoder
        **conv_decoder (bool):
            use convolutional layers in decoder
        **numlayers_encoder (int):
            number of layers in encoder (Default: 2)
        **numlayers_decoder (int):
            number of layers in decoder (Default: 2)
        **numhidden_encoder (int):
            number of hidden units OR conv filters in encoder (Default: 128)
        **numhidden_decoder (int):
            number of hidden units OR conv filters in decoder (Default: 128)

    Example:

    >>> input_dim = (28, 28)  # Input data dimensions (without n_samples)
    >>> # Initialize model
    >>> vae = aoi.models.VAE(input_dim)
    >>> # Train
    >>> vae.fit(imstack_train, training_cycles=100, batch_size=100)
    >>> # Visualize learned manifold (for 2 latent dimensions)
    >>> vae.manifold2d(origin="upper", cmap="gnuplot2")

    One can also pass labels to train a class-conditional VAE

    >>> # Initialize model
    >>> vae = aoi.models.VAE(input_dim, nb_classes=10)
    >>> # Train
    >>> vae.fit(imstack_train, labels_train, training_cycles=100, batch_size=100)
    >>> # Visualize learned manifold for class 1
    >>> vae.manifold2d(label=1, origin="upper", cmap="gnuplot2")
    """
    def __init__(self,
                 in_dim: Tuple[int] = None,
                 latent_dim: int = 2,
                 nb_classes: int = 0,
                 seed: int = 0,
                 **kwargs: Union[int, bool, str]
                 ) -> None:
        super(VAE, self).__init__(in_dim, latent_dim, nb_classes, 0, **kwargs)
        set_train_rng(seed)
        self.kdict_ = dc(kwargs)
        self.kdict_["num_iter"] = 0

    def elbo_fn(self,
                x: torch.Tensor,
                x_reconstr: torch.Tensor,
                *args: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """
        Calculates ELBO
        """
        return vae_loss(self.loss, self.in_dim, x, x_reconstr, *args, **kwargs)

    def forward_compute_elbo(self,
                             x: torch.Tensor,
                             y: Optional[torch.Tensor] = None,
                             mode: str = "train"
                             ) -> torch.Tensor:
        """
        VAE's forward pass with training/test loss computation
        """
        x = x.to(self.device)
        if mode == "eval":
            with torch.no_grad():
                z_mean, z_logsd = self.encoder_net(x)
        else:
            z_mean, z_logsd = self.encoder_net(x)
            self.kdict_["num_iter"] += 1
        z_sd = torch.exp(z_logsd)
        z = self.reparameterize(z_mean, z_sd)
        if y is not None:
            targets = to_onehot(y, self.nb_classes)
            z = torch.cat((z, targets), -1)
        if mode == "eval":
            with torch.no_grad():
                x_reconstr = self.decoder_net(z)
        else:
            x_reconstr = self.decoder_net(z)
        return self.elbo_fn(x, x_reconstr, z_mean, z_logsd, **self.kdict_)

    def fit(self,
            X_train: Union[np.ndarray, torch.Tensor],
            y_train: Optional[Union[np.ndarray, torch.Tensor]] = None,
            X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            loss: str = "mse",
            **kwargs) -> None:
        """
        Trains VAE model

        Args:
            X_train:
                For images, 3D or 4D stack of training images with dimensions
                (n_images, height, width) for grayscale data or
                (n_images, height, width, channels) for multi-channel data.
                For spectra, 2D stack of spectra with dimensions (n_spectra, length)
            y_train:
                Vector with labels of dimension (n_images,), where n_images
                is a number of training images/spectra
            X_test:
                3D or 4D stack of test images or 2D stack of spectra with
                the same dimensions as for the X_train (Default: None)
            y_test:
                Vector with labels of dimension (n_images,), where n_images
                is a number of test images/spectra
            loss:
                reconstruction loss function, "ce" or "mse" (Default: "mse")
            **capacity (list):
                List containing (max_capacity, num_iters, gamma) parameters
                to control the capacity of the latent channel.
                Based on https://arxiv.org/pdf/1804.03599.pdf
            **filename (str):
                file path for saving model after each training cycle ("epoch")
        """
        self._check_inputs(X_train, y_train, X_test, y_test)
        for k, v in kwargs.items():
            if k in ["capacity"]:
                self.kdict_[k] = v
        self.compile_trainer(
            (X_train, y_train), (X_test, y_test), **kwargs)
        self.loss = loss  # this part needs to be handled better
        if self.loss == "ce":
            self.sigmoid_out = True  # Use sigmoid layer for "prediction" stage
            self.metadict["sigmoid_out"] = True
        for e in range(self.training_cycles):
            self.current_epoch = e
            elbo_epoch = self.train_epoch()
            self.loss_history["train_loss"].append(elbo_epoch)
            if self.test_iterator is not None:
                elbo_epoch_test = self.evaluate_model()
                self.loss_history["test_loss"].append(elbo_epoch_test)
            self.print_statistics(e)
            self.update_metadict()
        self.save_model(self.filename)
        return

    def update_metadict(self):
        self.metadict["num_epochs"] = self.current_epoch
        self.metadict["num_iter"] = self.kdict_["num_iter"]
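
# Usage sketch (illustrative, not part of the module): training a VAE with the optional
# capacity schedule for the latent channel described in fit()'s docstring.
# `imstack_train` is assumed to be an (n_images, 28, 28) numpy array; the capacity
# values below are illustrative only.
#
#   vae = VAE(in_dim=(28, 28), latent_dim=2)
#   vae.fit(imstack_train, training_cycles=100, batch_size=100,
#           loss="mse", capacity=[5.0, 25000, 1000], filename="vae_model")
#   vae.manifold2d(cmap="gnuplot2")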