# Source code for atomai.models.denoiser

"""
denoiser.py
===========

Denoising autoencoder model for image cleaning

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
Modified with conventional batch normalization approach
"""

from typing import Type, Union, Optional, Tuple
import torch
import numpy as np
from ..trainers import BaseTrainer
from ..predictors import BasePredictor
from ..nets import ConvBlock, UpsampleBlock
from ..utils import set_train_rng, preprocess_denoiser_data


class DenoisingAutoencoder(BaseTrainer):
    """
    Denoising autoencoder model for image cleaning and noise reduction

    Args:
        encoder_filters:
            List of filter sizes for encoder layers
            (Default: [8, 16, 32, 64])
        decoder_filters:
            List of filter sizes for decoder layers
            (Default: [64, 32, 16, 8])
        encoder_layers:
            Number of convolutional layers per encoder block
            (Default: [1, 2, 2, 2])
        decoder_layers:
            Number of convolutional layers per decoder block
            (Default: [2, 2, 2, 1])
        use_batch_norm:
            Whether to use batch normalization in both encoder and decoder
            (Default: False)
        upsampling_mode:
            Upsampling method ('nearest' or 'bilinear') (Default: 'nearest')
        **seed:
            Random seed for reproducibility (Default: 1)

    Raises:
        ValueError:
            If a filters list and its matching layers list differ in length.

    Example:
        >>> # Initialize model
        >>> model = aoi.models.DenoisingAutoencoder()
        >>> # Train on noisy/clean image pairs
        >>> model.fit(noisy_images, clean_images, noisy_test, clean_test,
        ...           training_cycles=500, swa=True)
        >>> # Denoise new images
        >>> cleaned = model.predict(new_noisy_images)
    """

    def __init__(self,
                 encoder_filters: list = [8, 16, 32, 64],
                 decoder_filters: list = [64, 32, 16, 8],
                 encoder_layers: list = [1, 2, 2, 2],
                 decoder_layers: list = [2, 2, 2, 1],
                 use_batch_norm: bool = False,
                 upsampling_mode: str = 'nearest',
                 **kwargs) -> None:
        """
        Initialize denoising autoencoder
        """
        super().__init__()
        seed = kwargs.get("seed", 1)
        set_train_rng(seed)

        # Copy the list arguments: the signature defaults are shared list
        # objects, so storing them directly would let a mutation on one
        # instance leak into every later instance created with defaults.
        encoder_filters = list(encoder_filters)
        decoder_filters = list(decoder_filters)
        encoder_layers = list(encoder_layers)
        decoder_layers = list(decoder_layers)

        # zip() in _build_autoencoder would silently drop the unmatched
        # entries, so reject inconsistent configurations up front.
        if len(encoder_filters) != len(encoder_layers):
            raise ValueError(
                "encoder_filters and encoder_layers must have the same "
                "length, got {} and {}".format(
                    len(encoder_filters), len(encoder_layers)))
        if len(decoder_filters) != len(decoder_layers):
            raise ValueError(
                "decoder_filters and decoder_layers must have the same "
                "length, got {} and {}".format(
                    len(decoder_filters), len(decoder_layers)))

        # Store architecture parameters
        self.encoder_filters = encoder_filters
        self.decoder_filters = decoder_filters
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.use_batch_norm = use_batch_norm
        self.upsampling_mode = upsampling_mode

        # Build the autoencoder
        self.net = self._build_autoencoder()
        self.net.to(self.device)

        # Initialize meta state dict for saving/loading
        self.meta_state_dict = {
            "model_type": "denoising_autoencoder",
            "encoder_filters": encoder_filters,
            "decoder_filters": decoder_filters,
            "encoder_layers": encoder_layers,
            "decoder_layers": decoder_layers,
            "use_batch_norm": use_batch_norm,
            "upsampling_mode": upsampling_mode,
            "weights": self.net.state_dict()
        }

    def _build_autoencoder(self) -> torch.nn.Module:
        """
        Build the encoder-decoder architecture with consistent
        batch norm placement

        Returns:
            ``torch.nn.Sequential(encoder, decoder)`` where both parts are
            themselves ``torch.nn.Sequential`` containers.
        """
        # Build encoder
        encoder_modules = []
        in_channels = 1  # Assuming grayscale images
        last_encoder_block = len(self.encoder_filters) - 1
        for i, (filters, layers) in enumerate(
                zip(self.encoder_filters, self.encoder_layers)):
            # Add convolutional block with consistent batch norm usage
            encoder_modules.append(
                ConvBlock(ndim=2, nb_layers=layers,
                          input_channels=in_channels,
                          output_channels=filters,
                          batch_norm=self.use_batch_norm)
            )
            # Add max pooling (except for the last layer)
            if i < last_encoder_block:
                encoder_modules.append(torch.nn.MaxPool2d(2, 2))
            in_channels = filters
        encoder = torch.nn.Sequential(*encoder_modules)

        # Build decoder; in_channels carries over from the last encoder block
        decoder_modules = []
        for i, (filters, layers) in enumerate(
                zip(self.decoder_filters, self.decoder_layers)):
            # Add upsampling (except for the first layer)
            if i > 0:
                decoder_modules.append(
                    UpsampleBlock(ndim=2,
                                  input_channels=in_channels,
                                  output_channels=in_channels,
                                  mode=self.upsampling_mode)
                )
            # Add convolutional block with same batch norm setting as encoder
            decoder_modules.append(
                ConvBlock(ndim=2, nb_layers=layers,
                          input_channels=in_channels,
                          output_channels=filters,
                          batch_norm=self.use_batch_norm)
            )
            in_channels = filters

        # Final output layer (no batch norm for final reconstruction)
        decoder_modules.append(torch.nn.Conv2d(in_channels, 1, 1))
        decoder = torch.nn.Sequential(*decoder_modules)

        # Combine encoder and decoder
        return torch.nn.Sequential(encoder, decoder)

    def fit(self,
            X_train: Union[np.ndarray, torch.Tensor],
            y_train: Union[np.ndarray, torch.Tensor],
            X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            loss: str = 'mse',
            optimizer: Optional[Type[torch.optim.Optimizer]] = None,
            training_cycles: int = 500,
            batch_size: int = 32,
            compute_accuracy: bool = False,
            full_epoch: bool = False,
            swa: bool = True,
            perturb_weights: bool = False,
            **kwargs):
        """
        Train the denoising autoencoder

        Args:
            X_train: Noisy input images for training
            y_train: Clean target images for training
            X_test: Noisy input images for testing
            y_test: Clean target images for testing
            loss: Loss function (Default: 'mse')
            optimizer: Optimizer (Default: Adam with lr=1e-3)
            training_cycles: Number of training epochs
            batch_size: Batch size for training
            compute_accuracy: Whether to compute accuracy metrics
            full_epoch: Whether to use full epochs
            swa: Whether to use stochastic weight averaging
            perturb_weights: Whether to use weight perturbation
            **kwargs:
                Additional arguments for training. ``test_size``
                (Default: 0.15) and ``seed`` (Default: 1) control the
                automatic train/test split used when no test set is given.
        """
        # Without a complete test pair, hold out part of the training data
        if X_test is None or y_test is None:
            from sklearn.model_selection import train_test_split
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train,
                test_size=kwargs.get("test_size", .15),
                shuffle=True,
                random_state=kwargs.get("seed", 1))

        # Preprocess data
        X_train, y_train, X_test, y_test = preprocess_denoiser_data(
            X_train, y_train, X_test, y_test)

        # Compile and run training
        self.compile_trainer(
            (X_train, y_train, X_test, y_test),
            loss=loss,
            optimizer=optimizer,
            training_cycles=training_cycles,
            batch_size=batch_size,
            compute_accuracy=compute_accuracy,
            full_epoch=full_epoch,
            swa=swa,
            perturb_weights=perturb_weights,
            **kwargs
        )
        self.run()

        # Update meta state dict
        self.meta_state_dict["weights"] = self.net.state_dict()

    def predict(self,
                data: Union[np.ndarray, torch.Tensor],
                **kwargs) -> np.ndarray:
        """
        Denoise input images

        Args:
            data:
                Input noisy images. A 2D input is treated as a single image
                and a 3D input as a stack of single-channel images.
            **num_batches: Number of batches for prediction (Default: 10)

        Returns:
            Denoised images as a NumPy array with singleton
            dimensions squeezed out
        """
        use_gpu = self.device == 'cuda'
        # The same keyword arguments go to both the predictor's constructor
        # and its predict() call.
        predictor = BasePredictor(self.net, use_gpu, **kwargs)

        # NOTE(review): assumes BasePredictor operates on torch tensors --
        # confirm. NumPy input is converted to a contiguous float32 tensor.
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(
                np.ascontiguousarray(data, dtype=np.float32))

        # Ensure proper format for prediction (applies to tensors as well)
        if data.ndim == 2:
            data = data[None, None, ...]  # Add batch and channel dims
        elif data.ndim == 3:
            data = data[:, None, ...]  # Add channel dim

        prediction = predictor.predict(data, **kwargs)
        return prediction.detach().cpu().numpy().squeeze()

    def load_weights(self, filepath: str) -> None:
        """
        Load saved model weights

        Args:
            filepath:
                Path to a file saved with ``torch.save`` that holds either a
                bare state dict or a dict with a "weights" entry.
        """
        # NOTE: torch.load unpickles the file, so only load checkpoints
        # that come from a trusted source.
        weight_dict = torch.load(filepath, map_location=self.device)
        if "weights" in weight_dict:
            self.net.load_state_dict(weight_dict["weights"])
        else:
            self.net.load_state_dict(weight_dict)
        # Keep the saved metadata pointing at the freshly loaded weights
        self.meta_state_dict["weights"] = self.net.state_dict()
def init_denoising_autoencoder(**kwargs) -> Tuple[torch.nn.Module, dict]:
    """
    Initialize a denoising autoencoder model

    Args:
        **kwargs: Forwarded unchanged to :class:`DenoisingAutoencoder`

    Returns:
        Tuple of (model, meta_state_dict)
    """
    model = DenoisingAutoencoder(**kwargs)
    return model.net, model.meta_state_dict


# Convenience function for quick denoising
def denoise_images(noisy_images: np.ndarray,
                   clean_images: np.ndarray,
                   test_noisy: Optional[np.ndarray] = None,
                   test_clean: Optional[np.ndarray] = None,
                   training_cycles: int = 500,
                   **kwargs) -> Tuple[DenoisingAutoencoder,
                                      Optional[np.ndarray]]:
    """
    Convenience function for training a denoising autoencoder
    and making predictions

    Args:
        noisy_images: Training noisy images
        clean_images: Training clean images
        test_noisy: Test noisy images (optional)
        test_clean: Test clean images (optional)
        training_cycles: Number of training cycles
        **kwargs:
            Additional arguments for model and training. Architecture
            options (encoder_filters, decoder_filters, encoder_layers,
            decoder_layers, use_batch_norm, upsampling_mode) go to the
            model only; ``seed`` goes to both the model and ``fit``;
            everything else goes to ``fit`` only.

    Returns:
        Tuple of (trained_model, predictions_on_test_data). The second
        item is None when ``test_noisy`` is not provided.
    """
    # Names of the constructor's architecture parameters; these must not be
    # passed on to the trainer through fit().
    architecture_keys = (
        "encoder_filters", "decoder_filters",
        "encoder_layers", "decoder_layers",
        "use_batch_norm", "upsampling_mode",
    )
    model_kwargs = {key: value for key, value in kwargs.items()
                    if key in architecture_keys or key == "seed"}
    fit_kwargs = {key: value for key, value in kwargs.items()
                  if key not in architecture_keys}

    # Initialize model
    model = DenoisingAutoencoder(**model_kwargs)

    # Train model
    model.fit(noisy_images, clean_images, test_noisy, test_clean,
              training_cycles=training_cycles, **fit_kwargs)

    # Make predictions if test data provided
    predictions = None
    if test_noisy is not None:
        predictions = model.predict(test_noisy)

    return model, predictions