Source code for aydin.nn.models.jinet

import torch
from torch import nn
from aydin.nn.layers.dilated_conv import DilatedConv


class JINetModel(nn.Module):
    def __init__(
        self,
        spacetime_ndim,
        nb_in_channels: int = 1,
        nb_out_channels: int = 1,
        kernel_sizes=None,
        num_features=None,
        nb_dense_layers: int = 3,
        nb_channels: int = None,
        final_relu: bool = False,
        degressive_residuals: bool = False,  # TODO: check what happens when this is True
    ):
        super(JINetModel, self).__init__()

        self.spacetime_ndim = spacetime_ndim
        self.nb_in_channels = nb_in_channels
        self.nb_out_channels = nb_out_channels
        self._kernel_sizes = kernel_sizes
        self._num_features = num_features
        self.nb_dense_layers = nb_dense_layers
        self.nb_channels = nb_channels
        self.final_relu = final_relu
        self.degressive_residuals = degressive_residuals

        if len(self.kernel_sizes) != len(self.num_features):
            raise ValueError("Number of kernel sizes and features does not match.")

        self.dilated_conv_functions = nn.ModuleList()
        current_receptive_field_radius = 0
        for scale_index in range(len(self.kernel_sizes)):
            # Get kernel size for this scale:
            kernel_size = self.kernel_sizes[scale_index]

            # Radius and dilation:
            radius = (kernel_size - 1) // 2
            dilation = 1 + current_receptive_field_radius

            self.dilated_conv_functions.append(
                DilatedConv(
                    self.nb_in_channels
                    if scale_index == 0
                    else self.num_features[scale_index - 1],
                    self.num_features[scale_index],
                    self.spacetime_ndim,
                    padding=dilation * radius,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    activation="lrel",
                )
            )

            # Update the receptive field radius:
            current_receptive_field_radius += dilation * radius

        if spacetime_ndim == 2:
            self.conv = nn.Conv2d
        elif spacetime_ndim == 3:
            self.conv = nn.Conv3d
        else:
            raise ValueError("spacetime_ndim cannot be anything other than 2 or 3.")

        if self.nb_channels is None:
            self.nb_channels = sum(self.num_features)  # * 2

        nb_out = self.nb_channels
        self.kernel_one_conv_functions = nn.ModuleList()
        for index in range(self.nb_dense_layers):
            nb_in = nb_out
            nb_out = (
                self.nb_out_channels
                if index == (self.nb_dense_layers - 1)
                else self.nb_channels
            )
            self.kernel_one_conv_functions.append(
                self.conv(
                    in_channels=nb_in,
                    out_channels=nb_out,
                    kernel_size=(1,) * spacetime_ndim,
                )
            )

        self.final_kernel_one_conv = self.conv(
            in_channels=self.nb_channels,
            out_channels=1,
            kernel_size=(1,) * spacetime_ndim,
        )

        self.relu = nn.ReLU()
        self.lrelu = nn.LeakyReLU(negative_slope=0.01)

    @property
    def kernel_sizes(self):
        if self._kernel_sizes is None:
            if self.spacetime_ndim == 2:
                self._kernel_sizes = [7, 5, 3, 3, 3, 3, 3, 3]
            elif self.spacetime_ndim == 3:
                self._kernel_sizes = [7, 5, 3, 3]
        return self._kernel_sizes

    @property
    def num_features(self):
        if self._num_features is None:
            if self.spacetime_ndim == 2:
                self._num_features = [64, 32, 16, 8, 4, 2, 1, 1]
            elif self.spacetime_ndim == 3:
                self._num_features = [10, 8, 4, 2]
        return self._num_features
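
    # Note on the dilation schedule above: each DilatedConv preserves spatial
    # size (padding = dilation * radius) while the receptive field grows
    # roughly geometrically. With the default 2D settings, kernel sizes
    # [7, 5, 3, 3, 3, 3, 3, 3] yield dilations [1, 4, 12, 24, 48, 96, 192, 384]
    # and a final receptive field radius of 767 pixels.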
    def forward(self, x):
        dilated_conv_list = []

        # Calculate dilated convolutions:
        for index in range(len(self.kernel_sizes)):
            x = self.dilated_conv_functions[index](x)
            dilated_conv_list.append(x)

        # Concatenate the results along the channel dimension:
        x = torch.cat(dilated_conv_list, dim=1)

        # First kernel-size-one convolution:
        x = self.kernel_one_conv_functions[0](x)
        x = self.lrelu(x)
        y = x
        f = 1  # residual scaling factor, halved per layer if degressive_residuals

        # Rest of the kernel-size-one convolutions:
        for index in range(1, self.nb_dense_layers):
            x = self.kernel_one_conv_functions[index](x)
            x = self.lrelu(x)
            y = y + f * x
            if self.degressive_residuals:
                f = f * 0.5

        # Final kernel-size-one convolution:
        y = self.final_kernel_one_conv(y)

        # Final ReLU:
        if self.final_relu:
            y = self.relu(y)

        return y
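
# A minimal usage sketch (editorial addition, not part of the aydin source):
# assuming the imports above resolve, a 2D JINet with default settings maps a
# single-channel image batch to a single-channel output of the same spatial size.
if __name__ == "__main__":
    model = JINetModel(spacetime_ndim=2)
    x = torch.randn(1, 1, 64, 64)  # (batch, channels, height, width)
    with torch.no_grad():
        y = model(x)  # dilated convs, then the 1x1 dense layers
    print(y.shape)  # expected: torch.Size([1, 1, 64, 64])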