Source code for aydin.restoration.denoise.noise2selfcnn

import importlib
import inspect
import os
import shutil
from typing import Optional

from aydin.it.base import ImageTranslatorBase

# from aydin.it.cnn_torch import ImageTranslatorCNNTorch
from aydin.it.cnn import ImageTranslatorCNN
from aydin.it.transforms.padding import PaddingTransform
from aydin.it.transforms.range import RangeTransform
from aydin.it.transforms.variance_stabilisation import VarianceStabilisationTransform
from aydin.nn.tf import models
from aydin.restoration.denoise.base import DenoiseRestorationBase
from aydin.util.log.log import lsection


[docs]class Noise2SelfCNN(DenoiseRestorationBase): """ Noise2Self image denoising using the "Convolutional Neural Networks" ( CNN) approach. Follows from the theory exposed in the <a href="https://arxiv.org/abs/1901.11365">Noise2Self paper</a>. """ def __init__( self, *, variant: Optional[str] = None, use_model=None, input_model_path=None, lower_level_args=None, it_transforms=None, ): """ Noise2Self image denoising using "Convolutional Neural Networks" (CNN). Parameters ---------- variant : str, optional Variant of CNN denoiser to be used. Variant would supersede the denoiser option passed in lower_level_args. Currently, we support only two variants: `unet` and `jinet`. use_model : bool Flag to choose to train a new model or infer from a previously trained model. By default it is None. input_model_path : string Path to model that is desired to be used for inference. By default it is None. """ super().__init__(variant=variant) self.use_model_flag = use_model self.input_model_path = input_model_path self.lower_level_args = lower_level_args self.it_transforms = ( [ {"class": RangeTransform, "kwargs": {}}, {"class": PaddingTransform, "kwargs": {}}, {"class": VarianceStabilisationTransform, "kwargs": {}}, ] if it_transforms is None else it_transforms ) @property def configurable_arguments(self): """Returns the configurable arguments that will be exposed on GUI and CLI. """ # IT CNN it = ImageTranslatorCNN fullargspec3 = inspect.getfullargspec(ImageTranslatorCNN.__init__) it_args = { "arguments": fullargspec3.args[1:], "defaults": fullargspec3.defaults, "annotations": fullargspec3.annotations, "reference_class": it, } # Model model_modules = DenoiseRestorationBase.get_implementations_in_a_module(models) arguments = {} for module in model_modules: model_args = self.get_class_implementation_kwargs( models, module, module.name + "Model" ) arguments["Noise2SelfCNN-" + module.name] = { "model": model_args, "it": it_args, } return arguments @property def implementations(self): """Returns the list of discovered implementations for given method.""" return [ "Noise2SelfCNN-" + x.name for x in self.get_implementations_in_a_module(models) ] @property def implementations_description(self): cnn_description = Noise2SelfCNN.__doc__.strip() descriptions = [] for module in self.get_implementations_in_a_module(models): response = importlib.import_module(models.__name__ + '.' + module.name) elem = [ x for x in dir(response) if module.name.replace("_", "") in x.lower() ][ 0 ] # class name elem_class = response.__getattribute__(elem) # model_name = elem_class.__name__ model_description = elem_class.__doc__.replace("\n\n", "<br><br>") descriptions.append(cnn_description + f"<br><br>{model_description}") # elem_class = response.__getattribute__(elem) return descriptions
[docs] def stop_running(self): """Method to stop running N2S instance""" self.it.stop_training()
[docs] def get_translator(self): """Returns the corresponding translator instance for given selections. Returns ------- it : ImageTranslatorBase """ if self.variant: return ImageTranslatorCNN(model_architecture=self.variant) # Use a pre-saved model or train a new one from scratch and save it if self.use_model_flag: # Unarchive the model file and load its ImageTranslator object into self.it shutil.unpack_archive( self.input_model_path, os.path.dirname(self.input_model_path), "zip" ) it = ImageTranslatorBase.load(self.input_model_path[:-4]) else: it = ImageTranslatorCNN( **self.lower_level_args["it"]["kwargs"] if self.lower_level_args is not None else {} ) return it
def add_transforms(self): if self.it_transforms is not None: for transform in self.it_transforms: transform_class = transform["class"] transform_kwargs = transform["kwargs"] self.it.add_transform(transform_class(**transform_kwargs))
[docs] def train(self, noisy_image, *, batch_axes=None, chan_axes=None, **kwargs): """Method to run Noise2Self CNN training. Parameters ---------- noisy_image : numpy.ArrayLike batch_axes : array_like, optional Indices of batch axes. chan_axes : array_like, optional Indices of channel axes. Returns ------- response : numpy.ArrayLike """ with lsection("Noise2Self train is starting..."): if sum(chan_axes): return self.it = self.get_translator() self.add_transforms() # Train a new model self.it.train( noisy_image, noisy_image, batch_axes=batch_axes, channel_axes=chan_axes, train_valid_ratio=kwargs['train_valid_ratio'] if 'train_valid_ratio' in kwargs else 0.1, callback_period=kwargs['callback_period'] if 'callback_period' in kwargs else 3, jinv=kwargs['jinv'] if 'jinv' in kwargs else None, )
[docs] def denoise(self, noisy_image, *, batch_axes=None, chan_axes=None, **kwargs): """Method to denoise an image with trained Noise2Self. Parameters ---------- batch_axes : array_like, optional Indices of batch axes. chan_axes : array_like, optional Indices of channel axes. noisy_image : numpy.ndarray Returns ------- response : numpy.ndarray """ with lsection("Noise2Self denoise is starting..."): # Predict the resulting image response = self.it.translate( noisy_image, batch_axes=batch_axes, channel_axes=chan_axes, tile_size=kwargs['tile_size'] if 'tile_size' in kwargs else None, ) response = response.astype(noisy_image.dtype, copy=False) return response
[docs]def noise2self_cnn(image, *, batch_axes=None, chan_axes=None, variant=None): """Method to denoise an image with Noise2Self CNN. Parameters ---------- image : numpy.ndarray Image to denoise batch_axes : array_like, optional Indices of batch axes. chan_axes : array_like, optional Indices of channel axes. variant : str Algorithm variant. Returns ------- Denoised image : numpy.ndarray """ # Run N2S and save the result n2s = Noise2SelfCNN(variant=variant) # Train n2s.train(image, batch_axes=batch_axes, chan_axes=chan_axes) # Denoise denoised = n2s.denoise(image, batch_axes=batch_axes, chan_axes=chan_axes) return denoised