import importlib
import inspect
import os
import platform
import shutil
from aydin.it import classic_denoisers
from aydin.it.base import ImageTranslatorBase
from aydin.it.classic import ImageDenoiserClassic
from aydin.it.transforms.padding import PaddingTransform
from aydin.it.transforms.range import RangeTransform
from aydin.it.transforms.variance_stabilisation import VarianceStabilisationTransform
from aydin.restoration.denoise.base import DenoiseRestorationBase
from aydin.util.log.log import lsection
# When the BUNDLED_AYDIN environment variable is "1", import every classic
# denoiser module explicitly. The `Classic` class in this file otherwise only
# reaches these modules dynamically (via `get_implementations_in_a_module`
# and `importlib.import_module`). Presumably a bundling tool cannot follow
# such dynamic imports, so these static imports ensure the modules are
# packaged -- TODO confirm against the bundling setup. The imported names are
# intentionally unused here, hence the `noqa: F401` markers.
if os.getenv("BUNDLED_AYDIN") == "1":
    from aydin.it.classic_denoisers.bilateral import (  # noqa: F401
        calibrate_denoise_bilateral,
        denoise_bilateral,
    )
    from aydin.it.classic_denoisers.bmnd import (  # noqa: F401
        calibrate_denoise_bmnd,
        denoise_bmnd,
    )
    from aydin.it.classic_denoisers.butterworth import (  # noqa: F401
        calibrate_denoise_butterworth,
        denoise_butterworth,
    )
    from aydin.it.classic_denoisers.dictionary_fixed import (  # noqa: F401
        calibrate_denoise_dictionary_fixed,
        denoise_dictionary_fixed,
    )
    from aydin.it.classic_denoisers.dictionary_learned import (  # noqa: F401
        calibrate_denoise_dictionary_learned,
        denoise_dictionary_learned,
    )
    from aydin.it.classic_denoisers.gaussian import (  # noqa: F401
        calibrate_denoise_gaussian,
        denoise_gaussian,
    )
    from aydin.it.classic_denoisers.gm import (  # noqa: F401
        calibrate_denoise_gm,
        denoise_gm,
    )
    from aydin.it.classic_denoisers.harmonic import (  # noqa: F401
        calibrate_denoise_harmonic,
        denoise_harmonic,
    )
    from aydin.it.classic_denoisers.lipschitz import (  # noqa: F401
        calibrate_denoise_lipschitz,
        denoise_lipschitz,
    )
    from aydin.it.classic_denoisers.nlm import (  # noqa: F401
        calibrate_denoise_nlm,
        denoise_nlm,
    )
    from aydin.it.classic_denoisers.pca import (  # noqa: F401
        calibrate_denoise_pca,
        denoise_pca,
    )
    from aydin.it.classic_denoisers.spectral import (  # noqa: F401
        calibrate_denoise_spectral,
        denoise_spectral,
    )
    from aydin.it.classic_denoisers.tv import (  # noqa: F401
        calibrate_denoise_tv,
        denoise_tv,
    )
    from aydin.it.classic_denoisers.wavelet import (  # noqa: F401
        calibrate_denoise_wavelet,
        denoise_wavelet,
    )
class Classic(DenoiseRestorationBase):
    """Classic Image Denoising"""

    # Modules of `classic_denoisers` that are never exposed as variants.
    disabled_modules = ["bilateral", "bmnd", "_defaults"]

    # Modules additionally hidden when running on macOS (platform "Darwin").
    darwin_disabled_modules = ["spectral", "dictionary_learned"]

    def __init__(
        self,
        *,
        variant: str = None,
        use_model=None,
        input_model_path=None,
        lower_level_args=None,
        it_transforms=None,
    ):
        """
        Parameters
        ----------
        variant : str
            Variant of the Classic denoiser to be used. Variant
            would supersede the denoiser option passed in lower_level_args.
            `implementations` property would return a complete list
            of variants (with a prefix of 'Classic-') that can be used
            on a given installation. Both the bare name (e.g. `butterworth`)
            and the prefixed name (e.g. `Classic-butterworth`) are accepted.
            Example variants: `butterworth`, `gaussian`, `lipschitz`,
            `nlm`, ...
        use_model : bool
            Flag to choose to train a new model or infer from a
            previously trained model. By default it is None.
        input_model_path : string
            Path to a zipped model that is desired to be used for inference.
            By default it is None.
        lower_level_args : dict, optional
            Advanced settings used by `get_translator` when no `variant`
            is given. The keys read by this class are "variant"
            (e.g. "Classic-butterworth"), "calibration" -> {"kwargs": dict}
            and "it" -> {"kwargs": dict}.
        it_transforms : list of dict, optional
            Transforms added to the image translator before training. Each
            entry is a dict with a "class" and a "kwargs" key. When None,
            RangeTransform, PaddingTransform and
            VarianceStabilisationTransform are used with default kwargs.
        """
        super().__init__(variant=variant)
        self.lower_level_args = lower_level_args
        self.input_model_path = input_model_path
        self.use_model_flag = use_model
        # A fresh default list is built per instance, so instances never share it.
        self.it_transforms = (
            [
                {"class": RangeTransform, "kwargs": {}},
                {"class": PaddingTransform, "kwargs": {}},
                {"class": VarianceStabilisationTransform, "kwargs": {}},
            ]
            if it_transforms is None
            else it_transforms
        )

    def _enabled_method_modules(self):
        """Return the classic-denoiser modules enabled on this platform."""
        excluded = set(self.disabled_modules)
        if platform.system() == "Darwin":
            excluded.update(self.darwin_disabled_modules)
        # Build a new list instead of removing items from the list being
        # iterated, which would silently skip the element after each removal.
        return [
            module
            for module in self.get_implementations_in_a_module(classic_denoisers)
            if module.name not in excluded
        ]

    @staticmethod
    def _drop_calibration_argument(calibration_args, arg_name):
        """Remove `arg_name` and its default/annotation from calibration_args, in place."""
        arg_names = calibration_args["arguments"]
        if arg_name not in arg_names:
            return
        idx = arg_names.index(arg_name)
        del arg_names[idx]
        # NOTE(review): assumes "defaults" is positionally aligned with
        # "arguments" (as the original index-based filtering did) -- confirm
        # against get_function_implementation_kwargs.
        calibration_args["defaults"] = tuple(
            default
            for position, default in enumerate(calibration_args["defaults"])
            if position != idx
        )
        calibration_args["annotations"].pop(arg_name, None)

    @property
    def configurable_arguments(self):
        """Returns the configurable arguments that will be exposed
        on GUI and CLI.

        Returns
        -------
        dict
            Maps "Classic-<method>" to a dict with a "calibration" entry
            (the calibration function's arguments, without `display_images`,
            plus a "backend" key) and an "it" entry (the ImageDenoiserClassic
            constructor arguments).
        """
        it_class = ImageDenoiserClassic
        init_argspec = inspect.getfullargspec(it_class.__init__)
        it_args = {
            "arguments": init_argspec.args[3:],
            "defaults": init_argspec.defaults[2:],
            "annotations": init_argspec.annotations,
            "reference_class": it_class,
        }

        arguments = {}
        for module in self._enabled_method_modules():
            calibration_args = self.get_function_implementation_kwargs(
                classic_denoisers, module, "calibrate_denoise_" + module.name
            )
            # `display_images` is not a user-facing setting, so hide it.
            self._drop_calibration_argument(calibration_args, "display_images")
            calibration_args["backend"] = module.name
            arguments["Classic-" + module.name] = {
                "calibration": calibration_args,
                "it": it_args,
            }
        return arguments

    @property
    def implementations(self):
        """Returns the list of discovered implementations for given method."""
        return ["Classic-" + module.name for module in self._enabled_method_modules()]

    @property
    def implementations_description(self):
        """Returns one description string per enabled implementation.

        Each description is the ImageDenoiserClassic docstring, the method
        name, and the part of the method's `denoise_<name>` docstring that
        precedes its "Parameters" section (blank lines turned into <br><br>).
        """
        it_classic_description = ImageDenoiserClassic.__doc__.strip()
        descriptions = []
        for module in self._enabled_method_modules():
            denoiser_module = importlib.import_module(
                classic_denoisers.__name__ + '.' + module.name
            )
            denoise_function = getattr(denoiser_module, "denoise_" + module.name)
            docstring = denoise_function.__doc__ or ""
            # str.find returns -1 when absent; slicing with -1 would chop
            # the last character, so only truncate when the section exists.
            parameters_start = docstring.find("Parameters")
            if parameters_start != -1:
                docstring = docstring[:parameters_start]
            descriptions.append(
                it_classic_description
                + ": "
                + module.name
                + "\n\n"
                + docstring.replace("\n\n", "<br><br>")
            )
        return descriptions

    def stop_running(self):
        """Method to stop the running Classic image translator, if any."""
        it = getattr(self, "it", None)
        # Nothing to stop when `train` has not created a translator yet.
        if it is not None:
            it.stop_training()

    def get_translator(self):
        """Returns the corresponding translator instance for given selections.

        Precedence: `variant` (if set), then a saved model (if `use_model`
        is set), then `lower_level_args` (if it specifies a variant), and
        finally a default ImageDenoiserClassic.

        Returns
        -------
        it : ImageTranslatorBase
        """
        if self.variant:
            method = self.variant
            # Accept the prefixed names returned by `implementations`.
            if method.startswith("Classic-"):
                method = method[len("Classic-"):]
            return ImageDenoiserClassic(method=method)

        # Use a pre-saved model or train a new one from scratch and save it
        if self.use_model_flag:
            # Unarchive the model next to the archive and load its
            # ImageTranslator; the archive path is assumed to end in ".zip"
            # (its last 4 characters are stripped to get the model folder).
            shutil.unpack_archive(
                self.input_model_path, os.path.dirname(self.input_model_path), "zip"
            )
            return ImageTranslatorBase.load(self.input_model_path[:-4])

        if (
            self.lower_level_args is not None
            and self.lower_level_args.get("variant") is not None
        ):
            # Variant names look like "Classic-<method>".
            method = self.lower_level_args["variant"].split("-")[1]
            return ImageDenoiserClassic(
                method=method,
                calibration_kwargs=self.lower_level_args["calibration"]["kwargs"],
                **self.lower_level_args["it"]["kwargs"],
            )

        return ImageDenoiserClassic()

    def add_transforms(self):
        """Instantiate each entry of `it_transforms` and add it to `self.it`."""
        if self.it_transforms is not None:
            for transform in self.it_transforms:
                transform_class = transform["class"]
                transform_kwargs = transform["kwargs"]
                self.it.add_transform(transform_class(**transform_kwargs))

    def train(self, noisy_image, *, batch_axes=None, chan_axes=None, **kwargs):
        """Method to run training (calibration) for the Classic denoiser.

        The noisy image is used as both input and target.

        Parameters
        ----------
        noisy_image : numpy.ndarray
        batch_axes : array_like, optional
            Indices of batch axes.
        chan_axes : array_like, optional
            Indices of channel axes.
        kwargs
            Optional `train_valid_ratio` (default 0.1), `callback_period`
            (default 3) and `jinv` (default None), forwarded to the
            translator's `train`.

        Returns
        -------
        None
            The trained translator is stored in `self.it`.
        """
        with lsection("Classic train is starting..."):
            self.it = self.get_translator()
            self.add_transforms()

            # Train a new model
            self.it.train(
                noisy_image,
                noisy_image,
                batch_axes=batch_axes,
                channel_axes=chan_axes,
                train_valid_ratio=kwargs.get('train_valid_ratio', 0.1),
                callback_period=kwargs.get('callback_period', 3),
                jinv=kwargs.get('jinv', None),
            )

    def denoise(self, noisy_image, *, batch_axes=None, chan_axes=None, **kwargs):
        """Method to denoise an image with the trained Classic denoiser.

        Parameters
        ----------
        noisy_image : numpy.ndarray
        batch_axes : array_like, optional
            Indices of batch axes.
        chan_axes : array_like, optional
            Indices of channel axes.
        kwargs
            Optional `tile_size` (default None), forwarded to `translate`.

        Returns
        -------
        response : numpy.ndarray
            Denoised image, cast to the dtype of `noisy_image`.
        """
        with lsection("Classic denoise is starting..."):
            # Predict the resulting image
            response = self.it.translate(
                noisy_image,
                batch_axes=batch_axes,
                channel_axes=chan_axes,
                tile_size=kwargs.get('tile_size'),
            )
            response = response.astype(noisy_image.dtype, copy=False)
        return response
def classic_denoise(noisy, *, batch_axes=None, chan_axes=None, variant=None):
    """Denoise an image in one call using the Classic restoration module.

    A `Classic` instance is created for `variant`, trained on `noisy`,
    and then used to denoise that same image.

    Parameters
    ----------
    noisy : numpy.ndarray
        Image to denoise
    batch_axes : array_like, optional
        Indices of batch axes.
    chan_axes : array_like, optional
        Indices of channel axes.
    variant : str
        Algorithm variant.

    Returns
    -------
    Denoised image : numpy.ndarray
    """
    axes = {"batch_axes": batch_axes, "chan_axes": chan_axes}
    restoration = Classic(variant=variant)
    # Calibrate on the noisy image itself, then apply to the same image.
    restoration.train(noisy, **axes)
    return restoration.denoise(noisy, **axes)