Source code for atomai.nets.fcnn

"""
fcnn.py
=========

Fully convolutional neural networks

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import List, Union, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from .blocks import ConvBlock, ResModule, DilatedBlock, UpsampleBlock


[docs]class Unet(nn.Module): """ Builds a fully convolutional Unet-like neural network model Args: nb_classes: Number of classes in the ground truth nb_filters: Number of filters in 1st convolutional block (gets multiplied by 2 in each next block) dropout: Use dropouts to the 3 inner layers (Default: False) batch_norm: Use batch normalization after each convolutional layer (Default: True) upsampling_mode: Select between "bilinear" or "nearest" upsampling method. Bilinear is usually more accurate,but adds additional (small) randomness. For full reproducibility, consider using 'nearest' (this assumes that all other sources of randomness are fixed) with_dilation: Use dilated convolutions instead of regular ones in the bottleneck layers (Default: False) **layers (list): List with a number of layers in each block. The first 4 elements in the list are used to determine the number of layers in each block of the encoder (incluidng bottleneck layers), and the number of layers in the decoder is chosen accordingly (to maintain symmetry between encoder and decoder) """ def __init__(self, nb_classes: int = 1, nb_filters: int = 16, dropout: bool = False, batch_norm: bool = True, upsampling_mode: str = "bilinear", with_dilation: bool = False, **kwargs: List[int]) -> None: """ Initializes model parameters """ super(Unet, self).__init__() nbl = kwargs.get("layers", [1, 2, 2, 3]) dilation_values = torch.arange(2, 2*nbl[-1]+1, 2).tolist() padding_values = dilation_values.copy() dropout_vals = [.1, .2, .1] if dropout else [0, 0, 0] self.c1 = ConvBlock( 2, nbl[0], 1, nb_filters, batch_norm=batch_norm ) self.c2 = ConvBlock( 2, nbl[1], nb_filters, nb_filters*2, batch_norm=batch_norm ) self.c3 = ConvBlock( 2, nbl[2], nb_filters*2, nb_filters*4, batch_norm=batch_norm, dropout_=dropout_vals[0] ) if with_dilation: self.bn = DilatedBlock( 2, nb_filters*4, nb_filters*8, dilation_values=dilation_values, padding_values=padding_values, batch_norm=batch_norm, dropout_=dropout_vals[1] ) else: self.bn = ConvBlock( 2, nbl[3], nb_filters*4, nb_filters*8, batch_norm=batch_norm, dropout_=dropout_vals[1] ) self.upsample_block1 = UpsampleBlock( 2, nb_filters*8, nb_filters*4, mode=upsampling_mode) self.c4 = ConvBlock( 2, nbl[2], nb_filters*8, nb_filters*4, batch_norm=batch_norm, dropout_=dropout_vals[2] ) self.upsample_block2 = UpsampleBlock( 2, nb_filters*4, nb_filters*2, mode=upsampling_mode) self.c5 = ConvBlock( 2, nbl[1], nb_filters*4, nb_filters*2, batch_norm=batch_norm ) self.upsample_block3 = UpsampleBlock( 2, nb_filters*2, nb_filters, mode=upsampling_mode) self.c6 = ConvBlock( 2, nbl[0], nb_filters*2, nb_filters, batch_norm=batch_norm ) self.px = nn.Conv2d(nb_filters, nb_classes, 1, 1, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Defines a forward pass """ # Contracting path c1 = self.c1(x) d1 = F.max_pool2d(c1, kernel_size=2, stride=2) c2 = self.c2(d1) d2 = F.max_pool2d(c2, kernel_size=2, stride=2) c3 = self.c3(d2) d3 = F.max_pool2d(c3, kernel_size=2, stride=2) # Bottleneck layer bn = self.bn(d3) # Expanding path u3 = self.upsample_block1(bn) u3 = torch.cat([c3, u3], dim=1) u3 = self.c4(u3) u2 = self.upsample_block2(u3) u2 = torch.cat([c2, u2], dim=1) u2 = self.c5(u2) u1 = self.upsample_block3(u2) u1 = torch.cat([c1, u1], dim=1) u1 = self.c6(u1) # Final layer used for pixel-wise convolution px = self.px(u1) return px
[docs]class dilnet(nn.Module): """ Builds a fully convolutional neural network model by utilizing a combination of regular and dilated convolutions Args: nb_classes: Number of classes in the ground truth nb_filters: Number of filters in first and last convolutional blocks (gets multiplied by 2 for the bottleneck layer) dropout: Add dropouts to the bottleneck layers (Default: False) batch_norm: Add batch normalization for each convolutional layer (Default: True) upsampling_mode: Select between "bilinear" or "nearest" upsampling method. Bilinear is usually more accurate,but adds additional (small) randomness. For full reproducibility, consider using 'nearest' (this assumes that all other sources of randomness are fixed) **layers (list): List with a number of layers for each block (Default: [3, 3, 3, 3]) """ def __init__(self, nb_classes: int = 1, nb_filters: int = 25, dropout: bool = False, batch_norm: bool = True, upsampling_mode: str = "bilinear", **kwargs: List[int]) -> None: """ Initializes model parameters """ super(dilnet, self).__init__() nbl = kwargs.get("layers", [3, 3, 3, 3]) dilation_values_1 = torch.arange(2, 2*nbl[1]+1, 2).tolist() padding_values_1 = dilation_values_1.copy() dilation_values_2 = torch.arange(2, 2*nbl[2]+1, 2).tolist() padding_values_2 = dilation_values_2.copy() dropout_vals = [.3, .3] if dropout else [0, 0] self.c1 = ConvBlock( 2, nbl[0], 1, nb_filters, batch_norm=batch_norm ) self.at1 = DilatedBlock( 2, nb_filters, nb_filters*2, dilation_values=dilation_values_1, padding_values=padding_values_1, batch_norm=batch_norm, dropout_=dropout_vals[0] ) self.at2 = DilatedBlock( 2, nb_filters*2, nb_filters*2, dilation_values=dilation_values_2, padding_values=padding_values_2, batch_norm=batch_norm, dropout_=dropout_vals[1] ) self.up1 = UpsampleBlock( 2, nb_filters*2, nb_filters, mode=upsampling_mode ) self.c2 = ConvBlock( 2, nbl[3], nb_filters*2, nb_filters, batch_norm=batch_norm ) self.px = nn.Conv2d(nb_filters, nb_classes, 1, 1, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Defines a forward pass """ c1 = self.c1(x) d1 = F.max_pool2d(c1, kernel_size=2, stride=2) at1 = self.at1(d1) at2 = self.at2(at1) u1 = self.up1(at2) u1 = torch.cat([c1, u1], dim=1) u1 = self.c2(u1) px = self.px(u1) return px
[docs]class ResHedNet(nn.Module): """ Holistically nested edge detector with residual connections in each block Args: nb_classes: Number of classes in the ground truth nb_filters: Number of filters in 1st residual block (gets multiplied by 2 in each next block) upsampling_mode: Select between "bilinear" or "nearest" upsampling method. Bilinear is usually more accurate,but adds additional (small) randomness. For full reproducibility, consider using 'nearest' (this assumes that all other sources of randomness are fixed) **layers (list): 3-element list with a number of residual blocks in each segment (Default: [3, 4, 5]) """ def __init__(self, nb_classes: int = 1, nb_filters: int = 64, upsampling_mode: str = "bilinear", **kwargs: List[int]) -> None: """ Initializes model's parameters """ super(ResHedNet, self).__init__() nbl = kwargs.get("layers", [3, 4, 5]) self.upsample = upsampling_mode self.net1 = ResModule(2, nbl[0], 1, nb_filters, True) self.net2 = nn.Sequential( nn.MaxPool2d(2, 2), ResModule(2, nbl[1], nb_filters, 2*nb_filters, True) ) self.net3 = nn.Sequential( nn.MaxPool2d(2, 2), ResModule(2, nbl[2], 2*nb_filters, 4*nb_filters, True) ) self.net1score = nn.Sequential( nn.Conv2d(nb_filters, nb_classes, 1, 1, 0), nn.BatchNorm2d(nb_classes) ) self.net2score = nn.Sequential( nn.Conv2d(2*nb_filters, nb_classes, 1, 1, 0), nn.BatchNorm2d(nb_classes) ) self.net3score = nn.Sequential( nn.Conv2d(4*nb_filters, nb_classes, 1, 1, 0), nn.BatchNorm2d(nb_classes) ) self.out = torch.nn.Conv2d(3*nb_classes, nb_classes, 1, 1, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: h, w = x.shape[2:4] net1out = self.net1(x) net2out = self.net2(net1out) net3out = self.net3(net2out) score1 = self.net1score(net1out) score2 = self.net2score(net2out) score3 = self.net3score(net3out) score2 = F.interpolate(score2, size=(h, w), mode=self.upsample) score3 = F.interpolate(score3, size=(h, w), mode=self.upsample) return self.out(torch.cat([score1, score2, score3], 1))
[docs]class SegResNet(nn.Module): ''' Builds a fully convolutional neural network based on SegNet architecture with residual blocks for semantic segmentation Args: nb_classes: Number of classes in the ground truth nb_filters: Number of filters in 1st residual block (gets multiplied by 2 in each next block) batch_norm: Use batch normalization after each convolutional layer (Default: True) upsampling_mode: Select between "bilinear" or "nearest" upsampling method. Bilinear is usually more accurate,but adds additional (small) randomness. For full reproducibility, consider using 'nearest' (this assumes that all other sources of randomness are fixed) **layers (list): 3-element list with a number of residual blocks in each residual segment (Default: [2, 2]) ''' def __init__(self, nb_classes: int = 1, nb_filters: int = 32, batch_norm: bool = True, upsampling_mode: str = "bilinear", **kwargs: List[int] ) -> None: ''' Initializes module parameters ''' super(SegResNet, self).__init__() nbl = kwargs.get("layers", [2, 2, 2]) self.c1 = ConvBlock( 2, 1, 1, nb_filters, batch_norm=batch_norm ) self.c2 = ResModule( 2, nbl[0], nb_filters, nb_filters*2, batch_norm=batch_norm ) self.bn = ResModule( 2, nbl[1], nb_filters*2, nb_filters*4, batch_norm=batch_norm ) self.upsample_block1 = UpsampleBlock( 2, nb_filters*4, nb_filters*2, 2, upsampling_mode ) self.c3 = ResModule( 2, nbl[2], nb_filters*4, nb_filters*2, batch_norm=batch_norm ) self.upsample_block2 = UpsampleBlock( 2, nb_filters*2, nb_filters, 2, upsampling_mode ) self.c4 = ConvBlock( 2, 1, nb_filters*2, nb_filters, batch_norm=batch_norm ) self.px = nn.Conv2d(nb_filters, nb_classes, 1, 1, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: '''Defines a forward pass''' # Contracting path c1 = self.c1(x) d1 = F.max_pool2d(c1, kernel_size=2, stride=2) c2 = self.c2(d1) d2 = F.max_pool2d(c2, kernel_size=2, stride=2) # Bottleneck bn = self.bn(d2) # Expanding path u2 = self.upsample_block1(bn) u2 = torch.cat([c2, u2], dim=1) u2 = self.c3(u2) u1 = self.upsample_block2(u2) u1 = torch.cat([c1, u1], dim=1) u1 = self.c4(u1) # pixel-wise classification px = self.px(u1) return px
def init_fcnn_model(model: Union[Type[nn.Module], str], nb_classes: int, **kwargs: [bool, int, List] ) -> Type[nn.Module]: """ Initializes a fully convolutional neural network """ if not isinstance(model, str) and hasattr(model, "state_dict"): meta_state_dict = { 'model_type': 'Seg', model: 'custom', 'nb_classes': nb_classes} return model, meta_state_dict batch_norm = kwargs.get('batch_norm', True) dropout = kwargs.get('dropout', False) upsampling = kwargs.get('upsampling', "bilinear") meta_state_dict = { 'model_type': 'seg', 'model': model, 'nb_classes': nb_classes, 'batch_norm': batch_norm, 'dropout': dropout, 'upsampling': upsampling, } if isinstance(model, str) and model == 'Unet': with_dilation = kwargs.get('with_dilation', False) nb_filters = kwargs.get('nb_filters', 16) layers = kwargs.get("layers", [1, 2, 2, 3]) net = Unet( nb_classes, nb_filters, dropout, batch_norm, upsampling, with_dilation, layers=layers ) meta_state_dict["with_dilation"] = with_dilation elif isinstance(model, str) and model == 'dilnet': nb_filters = kwargs.get('nb_filters', 25) layers = kwargs.get("layers", [1, 3, 3, 1]) net = dilnet( nb_classes, nb_filters, dropout, batch_norm, upsampling, layers=layers ) elif isinstance(model, str) and model == 'SegResNet': nb_filters = kwargs.get('nb_filters', 32) layers = kwargs.get("layers", [2, 2, 2]) net = SegResNet( nb_classes, nb_filters, batch_norm, upsampling, layers=layers ) elif isinstance(model, str) and model == 'ResHedNet': nb_filters = kwargs.get('nb_filters', 64) layers = kwargs.get("layers", [3, 4, 5]) net = ResHedNet( nb_classes, nb_filters, upsampling, layers=layers ) else: raise NotImplementedError( "Currently implemented models are 'Unet', 'dilnet', SegResNet', and 'ResHedNet'" ) if model in ["ResHedNet", "SegResNet"]: meta_state_dict["dropout"] = None if model == ['ResHedNet']: meta_state_dict["batch_norm"] = True meta_state_dict["nb_filters"] = nb_filters meta_state_dict["layers"] = layers return net, meta_state_dict