Source code for atomai.nets.ed

"""
ed.py
=========

Encoder and decoder modules for VAE/VED and im2spec/spec2im models

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)

"""

from typing import Tuple, Type, Union, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .blocks import ConvBlock, DilatedBlock


class SignalEncoder(nn.Module):
    """
    Encodes 1D/2D signal into a latent vector

    Args:
        signal_dim:
            Size of input signal. For images, it is (height, width).
            For spectra, it is (length,)
        z_dim:
            Number of fully-connected neurons in a "bottleneck layer"
            (latent dimensions)
        nb_layers:
            Number of convolutional layers
        nb_filters:
            Number of convolutional filters (aka "kernels") in each layer
        **batch_norm (bool):
            Apply batch normalization after each convolutional layer
            (Default: True)
        **downsampling (int):
            Downsamples input data by this factor before passing
            to convolutional layers (Default: no downsampling)
    """
    def __init__(self,
                 signal_dim: Tuple[int],
                 z_dim: int,
                 nb_layers: int,
                 nb_filters: int,
                 **kwargs: int) -> None:
        """
        Initializes module parameters
        """
        super(SignalEncoder, self).__init__()
        if isinstance(signal_dim, int):
            signal_dim = (signal_dim,)
        if not 0 < len(signal_dim) < 3:
            raise AssertionError("signal dimensionality must be 1D or 2D")
        ndim = 2 if len(signal_dim) == 2 else 1
        self.downsample = kwargs.get("downsampling", 0)
        bn = kwargs.get('batch_norm', True)
        if self.downsample:
            signal_dim = [s // self.downsample for s in signal_dim]
        n = np.product(signal_dim)
        self.reshape_ = nb_filters * n
        self.conv = ConvBlock(
            ndim, nb_layers, 1, nb_filters,
            lrelu_a=0.1, batch_norm=bn)
        self.fc = nn.Linear(nb_filters * n, z_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embeds the input signal into a latent vector
        """
        if self.downsample:
            if x.ndim == 3:
                x = F.avg_pool1d(
                    x, self.downsample, self.downsample)
            else:
                x = F.avg_pool2d(
                    x, self.downsample, self.downsample)
        x = self.conv(x)
        x = x.reshape(-1, self.reshape_)
        return self.fc(x)

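# Usage sketch (illustrative only, not part of the original module): assuming
# a batch of single-channel 64x64 images, SignalEncoder maps each image to a
# 16-dimensional latent vector. Tensor shapes below are example assumptions.
#
#   encoder = SignalEncoder(signal_dim=(64, 64), z_dim=16,
#                           nb_layers=2, nb_filters=32)
#   x = torch.randn(8, 1, 64, 64)   # (batch, channels, height, width)
#   z = encoder(x)                  # expected shape: (8, 16)
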
class SignalDecoder(nn.Module):
    """
    Decodes a latent vector into 1D/2D signal

    Args:
        signal_dim:
            Size of the output signal. For images, it is (height, width).
            For spectra, it is (length,)
        z_dim:
            Number of fully-connected neurons in a "bottleneck layer"
            (latent dimensions)
        nb_layers:
            Number of convolutional layers
        nb_filters:
            Number of convolutional filters (aka "kernels") in each layer
        **batch_norm (bool):
            Apply batch normalization after each convolutional layer
            (Default: True)
        **upsampling (bool):
            Performs upsampling+convolution operation twice on the reshaped
            latent vector (starting from image/spectra dims 4x smaller than
            the target dims) before passing to the decoder
    """
    def __init__(self,
                 signal_dim: Tuple[int],
                 z_dim: int,
                 nb_layers: int,
                 nb_filters: int,
                 **kwargs: bool) -> None:
        """
        Initializes module parameters
        """
        super(SignalDecoder, self).__init__()
        self.upsampling = kwargs.get("upsampling", False)
        bn = kwargs.get('batch_norm', True)
        if isinstance(signal_dim, int):
            signal_dim = (signal_dim,)
        if not 0 < len(signal_dim) < 3:
            raise AssertionError("signal dimensionality must be 1D or 2D")
        ndim = 2 if len(signal_dim) == 2 else 1
        if self.upsampling:
            signal_dim = [s // 4 for s in signal_dim]
        n = np.product(signal_dim)
        self.reshape_ = (nb_filters, *signal_dim)
        self.fc = nn.Linear(z_dim, nb_filters*n)
        if self.upsampling:
            self.deconv1 = ConvBlock(
                ndim, 1, nb_filters, nb_filters,
                lrelu_a=0.1, batch_norm=bn)
            self.deconv2 = ConvBlock(
                ndim, 1, nb_filters, nb_filters,
                lrelu_a=0.1, batch_norm=bn)
        self.dilblock = DilatedBlock(
            ndim, nb_filters, nb_filters,
            dilation_values=torch.arange(1, nb_layers + 1).tolist(),
            padding_values=torch.arange(1, nb_layers + 1).tolist(),
            lrelu_a=0.1, batch_norm=bn)
        self.conv = ConvBlock(
            ndim, 1, nb_filters, 1,
            lrelu_a=0.1, batch_norm=bn)
        if ndim == 2:
            self.out = nn.Conv2d(1, 1, 1)
        else:
            self.out = nn.Conv1d(1, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generates a signal from embedded features (latent vector)
        """
        x = self.fc(x)
        x = x.reshape(-1, *self.reshape_)
        if self.upsampling:
            x = self.deconv1(x)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x = self.deconv2(x)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.dilblock(x)
        x = self.conv(x)
        return self.out(x)

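# Usage sketch (illustrative only, not part of the original module): assuming
# a 16-point spectrum as the target signal, SignalDecoder maps a latent vector
# back to a (batch, 1, 16) signal. Values below are example assumptions.
#
#   decoder = SignalDecoder(signal_dim=(16,), z_dim=16,
#                           nb_layers=2, nb_filters=32)
#   z = torch.randn(8, 16)          # (batch, z_dim)
#   s = decoder(z)                  # expected shape: (8, 1, 16)
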
class SignalED(nn.Module):
    """
    Transforms image into spectra (im2spec) and vice versa (spec2im)

    Args:
        feature_dim:
            Input data dimensions. (height, width) for images
            or (length,) for spectra
        target_dim:
            Output dimensions. (length,) for spectra
            or (height, width) for images
        latent_dim:
            Dimensionality of the latent space (number of neurons
            in a fully connected "bottleneck" layer)
        nblayers_encoder:
            Number of convolutional layers in the encoder
        nblayers_decoder:
            Number of convolutional layers in the decoder
        nbfilters_encoder:
            Number of convolutional filters in each layer of the encoder
        nbfilters_decoder:
            Number of convolutional filters in each layer of the decoder
        batch_norm:
            Apply batch normalization after each convolutional layer
            (Default: True)
        encoder_downsampling:
            Downsamples input data by this factor before passing
            to convolutional layers (Default: no downsampling)
        decoder_upsampling:
            Performs upsampling+convolution operation twice on the reshaped
            latent vector (starting from image/spectra dims 4x smaller than
            the target dims) before passing to the decoder
    """
    def __init__(self,
                 feature_dim: Tuple[int],
                 target_dim: Tuple[int],
                 latent_dim: int,
                 nblayers_encoder: int = 2,
                 nblayers_decoder: int = 2,
                 nbfilters_encoder: int = 64,
                 nbfilters_decoder: int = 2,
                 batch_norm: bool = True,
                 encoder_downsampling: int = 0,
                 decoder_upsampling: bool = False) -> None:
        """
        Initializes im2spec/spec2im parameters
        """
        super(SignalED, self).__init__()
        self.encoder = SignalEncoder(
            feature_dim, latent_dim, nblayers_encoder, nbfilters_encoder,
            batch_norm=batch_norm, downsampling=encoder_downsampling)
        self.decoder = SignalDecoder(
            target_dim, latent_dim, nblayers_decoder, nbfilters_decoder,
            batch_norm=batch_norm, upsampling=decoder_upsampling)

    def encode(self, features: torch.Tensor) -> torch.Tensor:
        """
        Embeds the input image into a latent vector
        """
        return self.encoder(features)

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Generates signal from the embedded features
        """
        return self.decoder(latent)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        x = self.encode(x)
        return self.decode(x)

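# Usage sketch (illustrative only, not part of the original module): an
# assumed im2spec setup where 64x64 image patches are translated into
# 16-point spectra; swapping feature_dim and target_dim gives spec2im.
#
#   net = SignalED(feature_dim=(64, 64), target_dim=(16,), latent_dim=10)
#   imgs = torch.randn(4, 1, 64, 64)
#   spectra = net(imgs)             # expected shape: (4, 1, 16)
#   z = net.encode(imgs)            # latent representation, shape (4, 10)
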
class convEncoderNet(nn.Module):
    """
    Convolutional encoder/inference network (for variational autoencoder)

    Args:
        in_dim:
            Input dimensions. For images, it is (height, width) or
            (height, width, channels). For spectra, it is (length,)
        latent_dim:
            number of latent dimensions
            (the first 3 latent dimensions are angle & translations by default)
        num_layers:
            number of NN layers
        hidden_dim:
            number of neurons in each fully connected layer (for mlp=True)
            or number of filters in each convolutional layer (for mlp=False)
        **softplus_out (bool):
            Optionally applies a softplus activation to the output
            associated with standard deviation of the encoded distribution
    """
    def __init__(self,
                 in_dim: Tuple[int],
                 latent_dim: int = 2,
                 num_layers: int = 2,
                 hidden_dim: int = 32,
                 **kwargs: Union[float, bool]
                 ) -> None:
        """
        Initializes network parameters
        """
        super(convEncoderNet, self).__init__()
        if len(in_dim) not in (1, 2, 3):
            raise ValueError(
                "The input dimensions must be (length,) for 1D data and " +
                "(height, width) or (height, width, channel) for 2D data")
        dim = 2 if len(in_dim) > 1 else 1
        c = in_dim[-1] if len(in_dim) > 2 else 1
        self.conv = ConvBlock(
            dim, num_layers, c, hidden_dim,
            lrelu_a=kwargs.get("lrelu_a", 0.1))
        self.reshape_ = hidden_dim * np.product(in_dim[:2])
        self.fc11 = nn.Linear(self.reshape_, latent_dim)
        self.fc12 = nn.Linear(self.reshape_, latent_dim)
        self._out = nn.Softplus() if kwargs.get("softplus_out") else lambda x: x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        """
        Forward pass

        Args:
            x: Input tensor with channel (> 1) as the last dimension
        """
        x = x.unsqueeze(1) if x.ndim in (2, 3) else x.permute(0, -1, 1, 2)
        x = self.conv(x)
        x = x.reshape(-1, self.reshape_)
        z_mu = self.fc11(x)
        z_logstd = self._out(self.fc12(x))
        return z_mu, z_logstd

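# Usage sketch (illustrative only, not part of the original module): for an
# assumed batch of 64x64 grayscale images, convEncoderNet returns the mean
# and log-standard-deviation of the approximate posterior q(z|x).
#
#   enc = convEncoderNet(in_dim=(64, 64), latent_dim=2,
#                        num_layers=2, hidden_dim=32)
#   x = torch.randn(8, 64, 64)      # channel dim is added internally
#   z_mu, z_logstd = enc(x)         # each of shape (8, 2)
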
class fcEncoderNet(nn.Module):
    """
    Encoder/inference network (for variational autoencoder)

    Args:
        in_dim:
            Input dimensions. For images, it is (height, width) or
            (height, width, channels). For spectra, it is (length,)
        latent_dim:
            number of latent dimensions
            (the first 3 latent dimensions are angle & translations by default)
        num_layers:
            number of NN layers
        hidden_dim:
            number of neurons in each fully connected layer (for mlp=True)
            or number of filters in each convolutional layer (for mlp=False)
        **softplus_out (bool):
            Optionally applies a softplus activation to the output
            associated with standard deviation of the encoded distribution
    """
    def __init__(self,
                 in_dim: Tuple[int],
                 latent_dim: int = 2,
                 num_layers: int = 2,
                 hidden_dim: int = 32,
                 **kwargs: bool
                 ) -> None:
        """
        Initializes network parameters
        """
        super(fcEncoderNet, self).__init__()
        dense = []
        for i in range(num_layers):
            input_dim = np.product(in_dim) if i == 0 else hidden_dim
            dense.extend([nn.Linear(input_dim, hidden_dim), nn.Tanh()])
        self.dense = nn.Sequential(*dense)
        self.reshape_ = hidden_dim
        self.fc11 = nn.Linear(self.reshape_, latent_dim)
        self.fc12 = nn.Linear(self.reshape_, latent_dim)
        self._out = nn.Softplus() if kwargs.get("softplus_out") else lambda x: x

    def forward(self, x: torch.Tensor):
        """
        Forward pass
        """
        x = x.reshape(-1, np.product(x.size()[1:]))
        x = self.dense(x)
        x = x.reshape(-1, self.reshape_)
        z_mu = self.fc11(x)
        z_logstd = self._out(self.fc12(x))
        return z_mu, z_logstd

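# Usage sketch (illustrative only, not part of the original module): the
# fully-connected encoder flattens each input before the dense layers, so
# the same call works for images and spectra. Shapes are example assumptions.
#
#   enc = fcEncoderNet(in_dim=(64, 64), latent_dim=2,
#                      num_layers=2, hidden_dim=128)
#   x = torch.randn(8, 64, 64)
#   z_mu, z_logstd = enc(x)         # each of shape (8, 2)
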
class jfcEncoderNet(nn.Module):
    """
    Encoder/inference network for joint continuous and discrete
    latent distributions (for variational autoencoder)

    Args:
        in_dim:
            Input dimensions. For images, it is (height, width) or
            (height, width, channels). For spectra, it is (length,)
        latent_dim:
            number of continuous latent dimensions
            (the first 3 latent dimensions are angle & translations by default)
        discrete_dim:
            list with the number of categories for each discrete latent variable
        num_layers:
            number of NN layers
        hidden_dim:
            number of neurons in each fully connected layer (for mlp=True)
            or number of filters in each convolutional layer (for mlp=False)
        **softplus_out (bool):
            Optionally applies a softplus activation to the output
            associated with standard deviation of the encoded distribution
    """
    def __init__(self,
                 in_dim: Tuple[int],
                 latent_dim: int = 2,
                 discrete_dim: List = [1],
                 num_layers: int = 2,
                 hidden_dim: int = 32,
                 **kwargs: bool
                 ) -> None:
        """
        Initializes network parameters
        """
        super(jfcEncoderNet, self).__init__()
        dense = []
        for i in range(num_layers):
            input_dim = np.product(in_dim) if i == 0 else hidden_dim
            dense.extend([nn.Linear(input_dim, hidden_dim), nn.Tanh()])
        self.dense = nn.Sequential(*dense)
        self.reshape_ = hidden_dim
        self.fc11 = nn.Linear(self.reshape_, latent_dim)
        self.fc12 = nn.Linear(self.reshape_, latent_dim)
        fc13 = []
        for disc in discrete_dim:
            fc13.append(nn.Linear(self.reshape_, disc))
        self.fc13 = nn.ModuleList(fc13)
        self._out = nn.Softplus() if kwargs.get("softplus_out") else lambda x: x

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Forward pass
        """
        x = x.reshape(-1, np.product(x.size()[1:]))
        x = self.dense(x)
        x = x.reshape(-1, self.reshape_)
        encoded = [self.fc11(x), self._out(self.fc12(x))]
        for fc_ in self.fc13:
            encoded.append(F.softmax(fc_(x), dim=1))
        return encoded


class jconvEncoderNet(nn.Module):
    """
    Convolutional encoder/inference network for joint continuous and
    discrete distributions (for variational autoencoder)

    Args:
        in_dim:
            Input dimensions. For images, it is (height, width) or
            (height, width, channels). For spectra, it is (length,)
        latent_dim:
            number of continuous latent dimensions
            (the first 3 latent dimensions are angle & translations by default)
        discrete_dim:
            list with the number of categories for each discrete latent variable
        num_layers:
            number of NN layers
        hidden_dim:
            number of neurons in each fully connected layer (for mlp=True)
            or number of filters in each convolutional layer (for mlp=False)
        **softplus_out (bool):
            Optionally applies a softplus activation to the output
            associated with standard deviation of the encoded distribution
    """
    def __init__(self,
                 in_dim: Tuple[int],
                 latent_dim: int = 2,
                 discrete_dim: List = [1],
                 num_layers: int = 2,
                 hidden_dim: int = 32,
                 **kwargs: Union[float, bool]
                 ) -> None:
        """
        Initializes network parameters
        """
        super(jconvEncoderNet, self).__init__()
        if len(in_dim) not in (1, 2, 3):
            raise ValueError(
                "The input dimensions must be (length,) for 1D data and " +
                "(height, width) or (height, width, channel) for 2D data")
        dim = 2 if len(in_dim) > 1 else 1
        c = in_dim[-1] if len(in_dim) > 2 else 1
        self.conv = ConvBlock(
            dim, num_layers, c, hidden_dim,
            lrelu_a=kwargs.get("lrelu_a", 0.1))
        self.reshape_ = hidden_dim * np.product(in_dim[:2])
        self.fc11 = nn.Linear(self.reshape_, latent_dim)
        self.fc12 = nn.Linear(self.reshape_, latent_dim)
        fc13 = []
        for disc in discrete_dim:
            fc13.append(nn.Linear(self.reshape_, disc))
        self.fc13 = nn.ModuleList(fc13)
        self._out = nn.Softplus() if kwargs.get("softplus_out") else lambda x: x

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Forward pass
        """
        x = x.unsqueeze(1) if x.ndim in (2, 3) else x.permute(0, -1, 1, 2)
        x = self.conv(x)
        x = x.reshape(-1, self.reshape_)
        encoded = [self.fc11(x), self._out(self.fc12(x))]
        for fc_ in self.fc13:
            encoded.append(F.softmax(fc_(x), dim=1))
        return encoded

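# Usage sketch (illustrative only, not part of the original module): a joint
# encoder with 2 continuous latent dimensions and one 3-category discrete
# variable returns [z_mu, z_logstd, alpha], where alpha holds softmax
# probabilities over the categories. Values below are example assumptions.
#
#   enc = jfcEncoderNet(in_dim=(64, 64), latent_dim=2, discrete_dim=[3],
#                       num_layers=2, hidden_dim=128)
#   x = torch.randn(8, 64, 64)
#   z_mu, z_logstd, alpha = enc(x)  # shapes: (8, 2), (8, 2), (8, 3)
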
class convDecoderNet(nn.Module):
    """
    Convolutional decoder network (for variational autoencoder)

    Args:
        out_dim:
            Output dimensions. For images, it is (height, width) or
            (height, width, channels). For spectra, it is (length,)
        latent_dim:
            number of latent dimensions associated with image content
        num_layers:
            number of convolutional layers
        hidden_dim:
            number of filters in each convolutional layer
    """
    def __init__(self,
                 out_dim: Tuple[int],
                 latent_dim: int,
                 num_layers: int = 2,
                 hidden_dim: int = 32,
                 **kwargs: float) -> None:
        """
        Initializes network parameters
        """
        super(convDecoderNet, self).__init__()
        if len(out_dim) not in (1, 2, 3):
            raise ValueError(
                "The output dimensions must be (length,) for 1D data and " +
                "(height, width) or (height, width, channel) for 2D data")
        dim = 2 if len(out_dim) > 1 else 1
        c = out_dim[-1] if len(out_dim) > 2 else 1
        self.fc_linear = nn.Linear(
            latent_dim, hidden_dim * np.product(out_dim[:2]), bias=False)
        self.reshape_ = (hidden_dim, *out_dim[:2])
        self.decoder = ConvBlock(
            dim, num_layers, hidden_dim, hidden_dim,
            lrelu_a=kwargs.get("lrelu_a", 0.1))
        conv_1x1 = nn.Conv2d if dim == 2 else nn.Conv1d
        self.conv_1x1 = conv_1x1(hidden_dim, c, 1, 1, 0)
        self.out_dim = (c, *out_dim[:2])

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        z = self.fc_linear(z)
        z = z.reshape(-1, *self.reshape_)
        h = self.decoder(z)
        h = self.conv_1x1(h)
        h = h.reshape(-1, *self.out_dim)
        if h.size(1) == 1:
            h = h.squeeze(1)
        else:
            h = h.permute(0, 2, 3, 1)
        return h

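# Usage sketch (illustrative only, not part of the original module): the
# convolutional decoder projects a latent vector to a feature map and then
# refines it with convolutions. Example values below are assumptions.
#
#   dec = convDecoderNet(out_dim=(64, 64), latent_dim=2,
#                        num_layers=2, hidden_dim=32)
#   z = torch.randn(8, 2)
#   x_rec = dec(z)                  # expected shape: (8, 64, 64)
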
class fcDecoderNet(nn.Module):
    """
    Decoder network (for variational autoencoder)

    Args:
        out_dim:
            Output dimensions. For images, it is (height, width) or
            (height, width, channels). For spectra, it is (length,)
        latent_dim:
            number of latent dimensions associated with image content
        num_layers:
            number of fully connected layers
        hidden_dim:
            number of neurons in each fully connected layer
    """
    def __init__(self,
                 out_dim: Tuple[int],
                 latent_dim: int,
                 num_layers: int = 2,
                 hidden_dim: int = 32,
                 ) -> None:
        """
        Initializes network parameters
        """
        super(fcDecoderNet, self).__init__()
        if len(out_dim) not in (1, 2, 3):
            raise ValueError(
                "The output dimensions must be (length,) for 1D data and " +
                "(height, width) or (height, width, channel) for 2D data")
        c = out_dim[-1] if len(out_dim) > 2 else 1
        decoder = []
        for i in range(num_layers):
            hidden_dim_ = latent_dim if i == 0 else hidden_dim
            decoder.extend([nn.Linear(hidden_dim_, hidden_dim), nn.Tanh()])
        self.decoder = nn.Sequential(*decoder)
        self.out = nn.Linear(hidden_dim, np.product(out_dim))
        self.out_dim = (c, *out_dim[:2])

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        h = self.decoder(z)
        h = self.out(h)
        h = h.reshape(-1, *self.out_dim)
        if h.size(1) == 1:
            h = h.squeeze(1)
        else:
            h = h.permute(0, 2, 3, 1)
        return h

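# Usage sketch (illustrative only, not part of the original module): the
# fully-connected decoder maps latent vectors straight to flattened outputs,
# which are then reshaped. Example values below are assumptions.
#
#   dec = fcDecoderNet(out_dim=(28, 28), latent_dim=2,
#                      num_layers=2, hidden_dim=128)
#   z = torch.randn(8, 2)
#   x_rec = dec(z)                  # expected shape: (8, 28, 28)
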
class rDecoderNet(nn.Module):
    """
    Spatial decoder network with (optional) skip connections

    Args:
        out_dim:
            output dimensions: (height, width) or (height, width, channels)
        latent_dim:
            number of latent dimensions associated with image content
        num_layers:
            number of fully connected layers
        hidden_dim:
            number of neurons in each fully connected layer
        skip:
            Use skip connections to propagate latent variables
            through decoder network (Default: False)
    """
    def __init__(self,
                 out_dim: Tuple[int],
                 latent_dim: int,
                 num_layers: int,
                 hidden_dim: int,
                 skip: bool = False,
                 ) -> None:
        """
        Initializes network parameters
        """
        super(rDecoderNet, self).__init__()
        if len(out_dim) == 2:
            c = 1
            self.reshape_ = (out_dim[0], out_dim[1])
        else:
            c = out_dim[-1]
            self.reshape_ = (out_dim[0], out_dim[1], c)
        self.skip = skip
        self.coord_latent = coord_latent(latent_dim, hidden_dim, not skip)
        fc_decoder = []
        for i in range(num_layers):
            fc_decoder.extend([nn.Linear(hidden_dim, hidden_dim), nn.Tanh()])
        self.fc_decoder = nn.Sequential(*fc_decoder)
        self.out = nn.Linear(hidden_dim, c)

    def forward(self, x_coord: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        batch_dim = x_coord.size()[0]
        h = self.coord_latent(x_coord, z)
        if self.skip:
            residual = h
            for i, fc_block in enumerate(self.fc_decoder):
                h = fc_block(h)
                if (i + 1) % 2 == 0:
                    h = h.add(residual)
        else:
            h = self.fc_decoder(h)
        h = self.out(h)
        h = h.reshape(batch_dim, *self.reshape_)
        return h

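# Usage sketch (illustrative only, not part of the original module): the
# spatial decoder takes a per-pixel coordinate grid together with the latent
# vector. The coordinate tensor below is a placeholder; in practice it would
# be a (possibly rotated/shifted) grid of pixel coordinates.
#
#   dec = rDecoderNet(out_dim=(64, 64), latent_dim=2,
#                     num_layers=2, hidden_dim=128)
#   x_coord = torch.rand(8, 64 * 64, 2)   # (batch, n_pixels, 2)
#   z = torch.randn(8, 2)
#   x_rec = dec(x_coord, z)               # expected shape: (8, 64, 64)
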
class coord_latent(nn.Module):
    """
    The "spatial" part of the rVAE's decoder that allows for translational
    and rotational invariance (based on https://arxiv.org/abs/1909.11663)

    Args:
        latent_dim:
            number of latent dimensions associated with image content
        out_dim:
            number of output dimensions (usually equal to number of hidden
            units in the first layer of the corresponding VAE's decoder)
        activation:
            Applies tanh activation to the output (Default: False)
    """
    def __init__(self,
                 latent_dim: int,
                 out_dim: int,
                 activation: bool = False) -> None:
        """
        Initializes module parameters
        """
        super(coord_latent, self).__init__()
        self.fc_coord = nn.Linear(2, out_dim)
        self.fc_latent = nn.Linear(latent_dim, out_dim, bias=False)
        self.activation = nn.Tanh() if activation else None

    def forward(self,
                x_coord: torch.Tensor,
                z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        batch_dim, n = x_coord.size()[:2]
        x_coord = x_coord.reshape(batch_dim * n, -1)
        h_x = self.fc_coord(x_coord)
        h_x = h_x.reshape(batch_dim, n, -1)
        h_z = self.fc_latent(z)
        h = h_x.add(h_z.unsqueeze(1))
        h = h.reshape(batch_dim * n, -1)
        if self.activation is not None:
            h = self.activation(h)
        return h


def init_imspec_model(in_dim: Tuple[int],
                      out_dim: Tuple[int],
                      latent_dim: int,
                      **kwargs: Union[int, bool]
                      ) -> Tuple[Type[nn.Module], Dict[str, Union[int, bool]]]:
    """
    Initializes ImSpec model
    """
    nblayers_encoder = kwargs.get("nblayers_encoder", 3)
    nblayers_decoder = kwargs.get("nblayers_decoder", 4)
    nbfilters_encoder = kwargs.get("nbfilters_encoder", 64)
    nbfilters_decoder = kwargs.get("nbfilters_decoder", 64)
    batch_norm = kwargs.get("batch_norm", True)
    encoder_downsampling = kwargs.get("encoder_downsampling", 0)
    decoder_upsampling = kwargs.get("decoder_upsampling", False)
    net = SignalED(
        in_dim, out_dim, latent_dim, nblayers_encoder, nblayers_decoder,
        nbfilters_encoder, nbfilters_decoder, batch_norm,
        encoder_downsampling, decoder_upsampling)
    meta_state_dict = {
        "model_type": "imspec",
        "in_dim": in_dim,
        "out_dim": out_dim,
        "latent_dim": latent_dim,
        "nblayers_encoder": nblayers_encoder,
        "nblayers_decoder": nblayers_decoder,
        "nbfilters_encoder": nbfilters_encoder,
        "nbfilters_decoder": nbfilters_decoder,
        "batchnorm": batch_norm,
        "encoder_downsampling": encoder_downsampling,
        "decoder_upsampling": decoder_upsampling
    }
    return net, meta_state_dict


def init_VAE_nets(in_dim: Tuple[int],
                  latent_dim: int,
                  coord: int = 0,
                  discrete_dim: Optional[List] = None,
                  nb_classes: int = 0,
                  **kwargs
                  ) -> Tuple[Type[nn.Module], Type[nn.Module],
                             Dict[str, Union[int, bool]]]:
    """
    Initializes encoder and decoder for VAE
    """
    conv_e = kwargs.get("conv_encoder", False)
    if not coord:
        conv_d = kwargs.get("conv_decoder", False)
    numlayers_e = kwargs.get("numlayers_encoder", 2)
    numlayers_d = kwargs.get("numlayers_decoder", 2)
    numhidden_e = kwargs.get("numhidden_encoder", 128)
    numhidden_d = kwargs.get("numhidden_decoder", 128)
    skip = kwargs.get("skip", False)
    sigmoid_out = kwargs.get("sigmoid_out", False)
    softplus_out = kwargs.get("softplus_out")
    discrete_dim_ = 0
    if discrete_dim:
        discrete_dim_ = sum(discrete_dim)
    nb_classes_ = nb_classes if discrete_dim_ == 0 else 0

    if not coord:
        dnet = convDecoderNet if conv_d else fcDecoderNet
        decoder_net = dnet(
            in_dim, latent_dim+discrete_dim_+nb_classes_,
            numlayers_d, numhidden_d)
    else:
        decoder_net = rDecoderNet(
            in_dim, latent_dim+discrete_dim_+nb_classes_,
            numlayers_d, numhidden_d, skip)
    if not discrete_dim:
        enet = convEncoderNet if conv_e else fcEncoderNet
        encoder_net = enet(
            in_dim, latent_dim + coord, numlayers_e, numhidden_e,
            softplus_out=softplus_out)
    else:
        enet = jconvEncoderNet if conv_e else jfcEncoderNet
        encoder_net = enet(
            in_dim, latent_dim + coord, discrete_dim,
            numlayers_e, numhidden_e, softplus_out=softplus_out)

    meta_state_dict = {
        "model_type": "vae",
        "in_dim": in_dim,
        "latent_dim": latent_dim,
        "coord": coord,
        "conv_encoder": conv_e,
        "numlayers_encoder": numlayers_e,
        "numlayers_decoder": numlayers_d,
        "numhidden_encoder": numhidden_e,
        "numhidden_decoder": numhidden_d,
        "skip": skip,
        "nb_classes": nb_classes,
        "discrete_dim": discrete_dim,
        "sigmoid_out": sigmoid_out,
        "softplus_out": softplus_out
    }
    if not coord:
        meta_state_dict["conv_decoder"] = conv_d
    return encoder_net, decoder_net, meta_state_dict

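# Usage sketch (illustrative only, not part of the original module): the
# helper functions return an initialized network (or an encoder/decoder pair)
# together with a dictionary of metadata. Dimensions below are example
# assumptions.
#
#   imspec_net, imspec_meta = init_imspec_model(
#       in_dim=(64, 64), out_dim=(16,), latent_dim=10)
#   encoder, decoder, vae_meta = init_VAE_nets(
#       in_dim=(28, 28), latent_dim=2, coord=3)   # rVAE-style setup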