# Source code for atomai.nets.blocks

"""
blocks.py
=========

Customized NN blocks

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v2, resnet50, vgg16


class ConvBlock(nn.Module):
    """
    Stacks one or more convolutional layers, each followed by an optional
    dropout, a leaky ReLU activation and an optional batch normalization

    Args:
        ndim: Data dimensionality (1D or 2D)
        nb_layers: Number of layers in the block
        input_channels: Number of input channels for the block
        output_channels: Number of the output channels for the block
        kernel_size: Size of convolutional filter (in pixels)
        stride: Stride of convolutional filter
        padding: Value for edge padding
        batch_norm: Add batch normalization to each layer in the block
        lrelu_a: Value of alpha parameter in leaky ReLU activation
            for each layer in the block
        dropout_: Dropout value for each layer in the block
    """
    def __init__(self,
                 ndim: int, nb_layers: int,
                 input_channels: int, output_channels: int,
                 kernel_size: Union[Tuple[int], int] = 3,
                 stride: Union[Tuple[int], int] = 1,
                 padding: Union[Tuple[int], int] = 1,
                 batch_norm: bool = False, lrelu_a: float = 0.01,
                 dropout_: float = 0) -> None:
        """
        Initializes module parameters
        """
        super(ConvBlock, self).__init__()
        if not 0 < ndim < 3:
            raise AssertionError("ndim must be equal to 1 or 2")
        conv_op = nn.Conv2d if ndim == 2 else nn.Conv1d
        bnorm_op = nn.BatchNorm2d if ndim == 2 else nn.BatchNorm1d
        layers = []
        in_ch = input_channels
        for _ in range(nb_layers):
            layers.append(conv_op(in_ch, output_channels,
                                  kernel_size=kernel_size,
                                  stride=stride, padding=padding))
            if dropout_ > 0:
                layers.append(nn.Dropout(dropout_))
            layers.append(nn.LeakyReLU(negative_slope=lrelu_a))
            if batch_norm:
                layers.append(bnorm_op(output_channels))
            # all layers after the first one map output_channels -> output_channels
            in_ch = output_channels
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines a forward pass
        """
        return self.block(x)
class UpsampleBlock(nn.Module):
    """
    Upsamples the input with bilinear or nearest-neighbor interpolation,
    then applies a 1-by-1 convolution (the latter can be used to reduce
    a number of feature channels)

    Args:
        ndim: Data dimensionality (1D or 2D)
        input_channels: Number of input channels for the block
        output_channels: Number of the output channels for the block
        scale_factor: Scale factor for upsampling
        mode: Upsampling mode. Select between "bilinear" and "nearest"
    """
    def __init__(self,
                 ndim: int,
                 input_channels: int,
                 output_channels: int,
                 scale_factor: int = 2,
                 mode: str = "bilinear") -> None:
        """
        Initializes module parameters
        """
        super(UpsampleBlock, self).__init__()
        if mode not in ("bilinear", "nearest"):
            raise NotImplementedError(
                "use 'bilinear' or 'nearest' for upsampling mode")
        if not 0 < ndim < 3:
            raise AssertionError("ndim must be equal to 1 or 2")
        self.scale_factor = scale_factor
        # bilinear interpolation is defined for 2D data only;
        # 1D inputs always fall back to nearest-neighbor
        self.mode = "nearest" if ndim == 1 else mode
        conv_op = nn.Conv2d if ndim == 2 else nn.Conv1d
        self.conv = conv_op(input_channels, output_channels,
                            kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines a forward pass
        """
        upsampled = F.interpolate(
            x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(upsampled)
class ResBlock(nn.Module):
    """
    Builds a residual block: a 1x1 projection convolution followed by two
    3x3 convolutions whose output is added back to the projected input

    Args:
        ndim: Data dimensionality (1D or 2D)
        input_channels: Number of input channels for the block
        output_channels: Number of the output channels for the block
        kernel_size: Size of convolutional filter (in pixels)
        stride: Stride of convolutional filter
        padding: Value for edge padding
        batch_norm: Add batch normalization to each layer in the block
        lrelu_a: Value of alpha parameter in leaky ReLU activation
            for each layer in the block
    """
    def __init__(self,
                 ndim: int, input_channels: int,
                 output_channels: int,
                 kernel_size: Union[Tuple[int], int] = 3,
                 stride: Union[Tuple[int], int] = 1,
                 padding: Union[Tuple[int], int] = 1,
                 batch_norm: bool = True,
                 lrelu_a: float = 0.01
                 ) -> None:
        """
        Initializes block's parameters
        """
        super(ResBlock, self).__init__()
        if not 0 < ndim < 3:
            raise AssertionError("ndim must be equal to 1 or 2")
        conv_op = nn.Conv2d if ndim == 2 else nn.Conv1d
        self.lrelu_a = lrelu_a
        self.batch_norm = batch_norm
        # 1x1 conv aligns the channel count so the skip connection adds up
        self.c0 = conv_op(input_channels, output_channels,
                          kernel_size=1, stride=1, padding=0)
        self.c1 = conv_op(output_channels, output_channels,
                          kernel_size=3, stride=1, padding=1)
        self.c2 = conv_op(output_channels, output_channels,
                          kernel_size=3, stride=1, padding=1)
        if batch_norm:
            bnorm_op = nn.BatchNorm2d if ndim == 2 else nn.BatchNorm1d
            self.bn1 = bnorm_op(output_channels)
            self.bn2 = bnorm_op(output_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines forward pass
        """
        x = self.c0(x)
        skip = x
        h = self.c1(x)
        if self.batch_norm:
            h = self.bn1(h)
        h = F.leaky_relu(h, negative_slope=self.lrelu_a)
        h = self.c2(h)
        if self.batch_norm:
            h = self.bn2(h)
        h = h + skip
        return F.leaky_relu(h, negative_slope=self.lrelu_a)
class ResModule(nn.Module):
    """
    Stitches multiple convolutional blocks with residual connections together

    Args:
        ndim: Data dimensionality (1D or 2D)
        res_depth: Number of residual blocks in a residual module
        input_channels: Number of filters in the input layer
        output_channels: Number of channels in the output layer
        batch_norm: Batch normalization for non-unity layers in the block
        lrelu_a: value of negative slope for LeakyReLU activation
    """
    def __init__(self,
                 ndim: int, res_depth: int,
                 input_channels: int, output_channels: int,
                 batch_norm: bool = True,
                 lrelu_a: float = 0.01
                 ) -> None:
        """
        Initializes module parameters
        """
        super(ResModule, self).__init__()
        # only the first block maps input_channels; the rest are channel-preserving
        blocks = [
            ResBlock(ndim,
                     input_channels if i == 0 else output_channels,
                     output_channels,
                     lrelu_a=lrelu_a, batch_norm=batch_norm)
            for i in range(res_depth)
        ]
        self.res_module = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines a forward pass
        """
        return self.res_module(x)
class DilatedBlock(nn.Module):
    """
    Creates a "cascade" with dilated convolutional
    layers (aka atrous convolutions)

    Args:
        ndim: Data dimensionality (1D or 2D)
        input_channels: Number of input channels for the block
        output_channels: Number of the output channels for the block
        dilation_values: List of dilation rates for each convolution
            layer in the block (for example, dilation_values = [2, 4, 6]
            means that the dilated block will 3 layers with dilation
            values of 2, 4, and 6).
        padding_values: Edge padding for each dilated layer.
            The number of elements in this list should be equal to that
            in the dilated_values list and typically they can have the
            same values.
        kernel_size: Size of convolutional filter (in pixels)
        stride: Stride of convolutional filter
        batch_norm: Add batch normalization to each layer in the block
        lrelu_a: Value of alpha parameter in leaky ReLU activation
            for each layer in the block
        dropout_: Dropout value for each layer in the block
    """
    def __init__(self, ndim: int, input_channels: int,
                 output_channels: int,
                 dilation_values: List[int],
                 padding_values: List[int],
                 kernel_size: Union[Tuple[int], int] = 3,
                 stride: Union[Tuple[int], int] = 1,
                 lrelu_a: float = 0.01,
                 batch_norm: bool = False,
                 dropout_: float = 0) -> None:
        """
        Initializes module parameters
        """
        super(DilatedBlock, self).__init__()
        if not 0 < ndim < 3:
            raise AssertionError("ndim must be equal to 1 or 2")
        conv_op = nn.Conv2d if ndim == 2 else nn.Conv1d
        bnorm_op = nn.BatchNorm2d if ndim == 2 else nn.BatchNorm1d
        layers = []
        in_ch = input_channels
        for dil, pad in zip(dilation_values, padding_values):
            layers.append(conv_op(in_ch, output_channels,
                                  kernel_size=kernel_size,
                                  stride=stride, padding=pad,
                                  dilation=dil, bias=True))
            if dropout_ > 0:
                layers.append(nn.Dropout(dropout_))
            layers.append(nn.LeakyReLU(negative_slope=lrelu_a))
            if batch_norm:
                layers.append(bnorm_op(output_channels))
            in_ch = output_channels
        self.atrous_module = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines a forward pass
        """
        # NOTE: the sum runs over the output of *every* submodule
        # (conv, dropout, activation and batchnorm alike), reproducing
        # the original cascade behavior
        collected = []
        for layer in self.atrous_module:
            x = layer(x)
            collected.append(x)
        return torch.stack(collected, dim=-1).sum(dim=-1)
class CustomBackbone(nn.Module):
    """
    Custom backbone class to support ResNet50, VGG16, and MobileNetV2 architectures.

    Args:
        input_channels (int): The number of input channels.
        backbone_type (str, optional): The type of backbone architecture.
            Choose from "resnet", "vgg", or "mobilenet". Default is "mobilenet".

    Raises:
        ValueError: If backbone_type is not one of "resnet", "vgg", "mobilenet".
    """
    def __init__(self, input_channels: int, backbone_type: str = "mobilenet"):
        super(CustomBackbone, self).__init__()

        if backbone_type == "resnet":
            # Load the (randomly initialized) ResNet50 model
            backbone = resnet50(weights=None)
            # Modify the first convolutional layer to accept the given number of input channels
            backbone.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7,
                                       stride=2, padding=3, bias=False)
            # Set the number of in_features for the fully connected layer
            self.in_features = backbone.fc.in_features
            # Remove the last fully connected layer (classification layer)
            self.backbone_layers = nn.Sequential(*list(backbone.children())[:-2])
        elif backbone_type == "vgg":
            # Load the (randomly initialized) VGG16 model
            backbone = vgg16(weights=None)
            # Modify the first convolutional layer to accept the given number of input channels
            backbone.features[0] = nn.Conv2d(input_channels, 64,
                                             kernel_size=3, padding=1)
            # Set the number of in_features for the fully connected layer
            self.in_features = backbone.features[-3].out_channels
            # Set the backbone layers
            self.backbone_layers = nn.Sequential(*list(backbone.features.children())[:-1])
        elif backbone_type == "mobilenet":
            # Load the (randomly initialized) MobileNetV2 model
            backbone = mobilenet_v2(weights=None)
            # Modify the first convolutional layer to accept the given number of input channels
            backbone.features[0][0] = nn.Conv2d(input_channels, 32, kernel_size=3,
                                                stride=2, padding=1, bias=False)
            # Set the number of in_features for the fully connected layer
            self.in_features = backbone.classifier[1].in_features
            # Set the backbone layers
            self.backbone_layers = nn.Sequential(*list(backbone.features.children()))
        else:
            # Fixed: the original message omitted the supported 'mobilenet' option
            raise ValueError("Unsupported backbone_type. "
                             "Choose 'resnet', 'vgg', or 'mobilenet'.")

        # Add adaptive average pooling to produce a fixed (1, 1) spatial output
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor):
        """
        Forward pass

        Args:
            x (torch.Tensor): Input tensor with shape (batch_size, input_channels, height, width).

        Returns:
            torch.Tensor: Output tensor with shape (batch_size, in_features, 1, 1).
        """
        x = self.adaptive_pool(self.backbone_layers(x))
        return x