# Source code for aydin.nn.models.unet

import torch
from torch import nn

from aydin.nn.layers.custom_conv import double_conv_block
from aydin.nn.layers.pooling_down import PoolingDown


class UNetModel(nn.Module):
    """U-Net style encoder / bottom / decoder network with skip connections.

    The network is assembled from ``double_conv_block`` modules: one per
    encoder level, one at the bottom, and one per decoder level. A final
    1x1 convolution maps ``nb_filters`` channels to a single output channel.

    Parameters
    ----------
    spacetime_ndim
        Number of spatial(-temporal) dimensions. ``2`` selects ``nn.Conv2d``
        for the final convolution; any other value selects ``nn.Conv3d``.
        The value is also forwarded to ``PoolingDown`` and
        ``double_conv_block``.
    nb_unet_levels : int
        Number of encoder levels (and, symmetrically, decoder levels).
    nb_filters : int
        Base filter count; encoder level ``i`` outputs
        ``nb_filters * 2**i`` channels.
    learning_rate
        Stored on the instance as ``self.learning_rate``; it is not used
        anywhere else inside this class.
    pooling_mode : str
        Forwarded unchanged to ``PoolingDown``.
    """

    def __init__(
        self,
        spacetime_ndim,
        nb_unet_levels: int = 4,
        nb_filters: int = 8,
        learning_rate=0.01,
        pooling_mode: str = 'max',
    ):
        super().__init__()

        self.spacetime_ndim = spacetime_ndim
        self.nb_unet_levels = nb_unet_levels
        self.nb_filters = nb_filters
        self.learning_rate = learning_rate

        # NOTE: submodules are created in the same order as before so that
        # state-dict keys and parameter initialisation order are unchanged.
        self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
        self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')

        self.double_conv_blocks_encoder = self._encoder_convolutions()

        # The bottom block works at the channel count of the deepest
        # encoder level: nb_filters * 2**(nb_unet_levels - 1).
        self.unet_bottom_conv_out_channels = self.nb_filters * (
            2 ** (self.nb_unet_levels - 1)
        )
        self.unet_bottom_conv_block = double_conv_block(
            self.unet_bottom_conv_out_channels,
            self.unet_bottom_conv_out_channels * 2,
            self.unet_bottom_conv_out_channels,
            spacetime_ndim,
        )

        self.double_conv_blocks_decoder = self._decoder_convolutions()

        if spacetime_ndim == 2:
            self.final_conv = nn.Conv2d(self.nb_filters, 1, 1)
        else:
            self.final_conv = nn.Conv3d(self.nb_filters, 1, 1)

    def forward(self, x):
        """Run the U-Net on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. Presumably shaped ``(batch, 1, *spatial)`` since
            the first encoder block is built with one input channel --
            TODO confirm against ``double_conv_block``.

        Returns
        -------
        torch.Tensor
            Output of the final 1x1 convolution (one output channel).
        """
        skip_layer = []

        # Encoder: keep each level's pre-pooling activation for the decoder.
        for layer_index in range(self.nb_unet_levels):
            x = self.double_conv_blocks_encoder[layer_index](x)
            skip_layer.append(x)
            x = self.pooling_down(x)

        # Bottom
        x = self.unet_bottom_conv_block(x)

        # Decoder: consume the skip activations deepest-first.
        for layer_index in range(self.nb_unet_levels):
            skip = skip_layer.pop()
            x = self.upsampling(x)

            # A fixed x2 upsampling only reproduces the skip's spatial size
            # when pooling halved it exactly. If the sizes differ (e.g. an
            # odd spatial size), resize to the skip so that the channel
            # concatenation below does not fail. When the sizes already
            # match this branch is skipped and behaviour is unchanged.
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[2:], mode='nearest'
                )

            x = torch.cat([x, skip], dim=1)
            x = self.double_conv_blocks_decoder[layer_index](x)

        # Final convolution
        return self.final_conv(x)

    def _encoder_convolutions(self):
        """Build one ``double_conv_block`` per encoder level.

        Level 0 goes from 1 channel to ``nb_filters``; level ``i > 0`` goes
        from ``nb_filters * 2**(i - 1)`` to ``nb_filters * 2**i``.

        Returns
        -------
        nn.ModuleList
            The encoder blocks, shallowest level first.
        """
        convolution = nn.ModuleList()

        for layer_index in range(self.nb_unet_levels):
            if layer_index == 0:
                nb_filters_in = 1
                nb_filters_inner = self.nb_filters
                nb_filters_out = self.nb_filters
            else:
                nb_filters_in = self.nb_filters * (2 ** (layer_index - 1))
                nb_filters_inner = self.nb_filters * (2**layer_index)
                nb_filters_out = self.nb_filters * (2**layer_index)

            convolution.append(
                double_conv_block(
                    nb_filters_in,
                    nb_filters_inner,
                    nb_filters_out,
                    self.spacetime_ndim,
                )
            )

        return convolution

    def _decoder_convolutions(self):
        """Build one ``double_conv_block`` per decoder level.

        Each block's input width is the channel count after concatenating
        the upsampled tensor with the matching skip activation. The last
        (shallowest) block maps ``2 * nb_filters`` to ``nb_filters``;
        every other block maps ``c`` to ``c // 4`` with an inner width of
        ``c // 2``, where ``c = nb_filters * 2**(nb_unet_levels - index)``.

        Returns
        -------
        nn.ModuleList
            The decoder blocks, deepest level first.
        """
        convolutions = nn.ModuleList()

        for layer_index in range(self.nb_unet_levels):
            if layer_index == self.nb_unet_levels - 1:
                nb_filters_in = self.nb_filters * 2
                nb_filters_inner = nb_filters_out = self.nb_filters
            else:
                nb_filters_in = self.nb_filters * (
                    2 ** (self.nb_unet_levels - layer_index)
                )
                nb_filters_inner = nb_filters_in // 2
                nb_filters_out = nb_filters_in // 4

            convolutions.append(
                double_conv_block(
                    nb_filters_in,
                    nb_filters_inner,
                    nb_filters_out,
                    spacetime_ndim=self.spacetime_ndim,
                    normalizations=(None, "batch"),
                )
            )

        return convolutions