Source code for aydin.it.classic_denoisers.spectral

import math
from functools import partial
from typing import Optional, Union, Tuple, Sequence, List

import numpy
from numba import jit, prange
from numpy.fft import fftshift, ifftshift
from numpy.typing import ArrayLike
from scipy.fft import fftn, ifftn, dctn, idctn, dstn, idstn

from aydin.it.classic_denoisers import _defaults
from aydin.util.array.outer import outer_sum
from aydin.util.crop.rep_crop import representative_crop
from aydin.util.j_invariance.j_invariance import calibrate_denoiser
from aydin.util.patch_size.patch_size import default_patch_size
from aydin.util.patch_transform.patch_transform import (
    extract_patches_nd,
    reconstruct_from_nd_patches,
)


def _resolve_named_patch_size(name: str, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Translate a named patch size into a tuple of ints for the given shape.

    Uses the same rules as denoise_spectral: 'full' is the whole shape, 'half'
    is ``max(3, 2 * (s // 4))`` per axis, and 'quarter' is
    ``max(3, 4 * (s // 8))`` per axis.
    """
    if name == 'full':
        return tuple(shape)
    if name == 'half':
        return tuple(max(3, 2 * (s // 4)) for s in shape)
    return tuple(max(3, 4 * (s // 8)) for s in shape)


def calibrate_denoise_spectral(
    image: ArrayLike,
    axes: Optional[Tuple[int, ...]] = None,
    patch_size: Optional[Union[int, Tuple[int], str]] = None,
    try_dct: bool = True,
    try_fft: bool = False,
    try_dst: bool = False,
    max_order: float = 6.0,
    crop_size_in_voxels: Optional[int] = _defaults.default_crop_size_normal.value,
    optimiser: str = _defaults.default_optimiser.value,
    max_num_evaluations: int = _defaults.default_max_evals_low.value,
    blind_spots: Optional[List[Tuple[int]]] = _defaults.default_blind_spots.value,
    jinv_interpolation_mode: str = _defaults.default_jinv_interpolation_mode.value,
    multi_core: bool = True,
    display_images: bool = False,
    display_crop: bool = False,
    **other_fixed_parameters,
):
    """
    Calibrates the Spectral denoiser for the given image and returns the optimal
    parameters obtained using the N2S loss.

    Parameters
    ----------
    image: ArrayLike
        Image to calibrate spectral denoiser for.

    axes: Optional[Tuple[int,...]]
        Axes over which to apply the spectral transform (dct, dst, fft) for
        denoising each patch.

    patch_size: Optional[Union[int, Tuple[int], str]]
        Patch size for the 'image-to-patch' transform.
        Can be: 'full' for a single patch covering the whole image, 'half', 'quarter',
        or an int s that corresponds to isotropic patches of shape: (s,)*image.ndim,
        or a tuple of ints. By default (None) the patch size is chosen automatically
        to give the best results. Named sizes are resolved relative to the image
        being denoised (the crop during calibration).
        (advanced)

    try_dct: bool
        Tries DCT transform during optimisation.

    try_fft: bool
        Tries FFT transform during optimisation.

    try_dst: bool
        Tries DST transform during optimisation.

    max_order: float
        Maximal order for the Butterworth filter.
        (advanced)

    crop_size_in_voxels: int or None for default
        Number of voxels for crop used to calibrate denoiser.
        Increase this number by factors of two if denoising quality is
        unsatisfactory -- this can be important for very noisy images.
        Values to try are: 65000, 128000, 256000, 320000.
        We do not recommend values higher than 512000.

    optimiser: str
        Optimiser to use for finding the best denoising
        parameters. Can be: 'smart' (default), or 'fast' for a mix of SHGO
        followed by L-BFGS-B.
        (advanced)

    max_num_evaluations: int
        Maximum number of evaluations for finding the optimal parameters.
        Increase this number by factors of two if denoising quality is
        unsatisfactory.

    blind_spots: Optional[List[Tuple[int]]]
        List of voxel coordinates (relative to receptive field center) to
        be included in the blind-spot. For example, you can give a list of
        3 tuples: [(0,0,0), (0,1,0), (0,-1,0)] to extend the blind spot
        to cover voxels of relative coordinates: (0,0,0),(0,1,0), and (0,-1,0)
        (advanced) (hidden)

    jinv_interpolation_mode: str
        J-invariance interpolation mode for masking. Can be: 'median' or
        'gaussian'.
        (advanced)

    multi_core: bool
        Use all CPU cores during calibration.
        (advanced)

    display_images: bool
        When True the denoised images encountered during optimisation are shown.
        (advanced) (hidden)

    display_crop: bool
        Displays crop, for debugging purposes...
        (advanced) (hidden)

    other_fixed_parameters: dict
        Any other fixed parameters

    Returns
    -------
    Denoising function, dictionary containing optimal parameters,
    and free memory needed in bytes for computation.

    Raises
    ------
    ValueError
        If none of try_dct, try_fft, try_dst is True.
    """
    # Prepare the list of spectral transforms to try; fail fast if empty,
    # otherwise the optimiser would be handed an empty categorical range:
    modes = []
    if try_dct:
        modes.append("dct")
    if try_fft:
        modes.append("fft")
    if try_dst:
        modes.append("dst")
    if not modes:
        raise ValueError(
            "At least one of try_dct, try_fft or try_dst must be True."
        )

    # Convert image to float if needed:
    image = image.astype(dtype=numpy.float32, copy=False)

    # Obtain representative crop, to speed things up...
    crop = representative_crop(
        image, crop_size=crop_size_in_voxels, display_crop=display_crop
    )

    # Normalise patch size. Named sizes ('full', 'half', 'quarter') are kept
    # symbolic: denoise_spectral resolves them against whatever image it is
    # given (the crop here, the full image later), and default_patch_size is
    # only ever fed numeric sizes -- mirroring what denoise_spectral does.
    if isinstance(patch_size, str) and patch_size in ('full', 'half', 'quarter'):
        numeric_patch_size = default_patch_size(
            image, _resolve_named_patch_size(patch_size, image.shape), odd=True
        )
    else:
        patch_size = default_patch_size(image, patch_size, odd=True)
        numeric_patch_size = patch_size

    # Search ranges for the continuous parameters:
    threshold_range = (0.0, 1.0)
    freq_bias_range = (0.0, 2.0)
    freq_cutoff_range = (0.01, 1.0)
    order_range = (0.5, max_order)

    # Parameters to test when calibrating the denoising algorithm
    # (the 'freq_bias_stength' spelling matches denoise_spectral's keyword):
    parameter_ranges = {
        'threshold': threshold_range,
        'freq_bias_stength': freq_bias_range,
        'freq_cutoff': freq_cutoff_range,
        'order': order_range,
        'mode': modes,
    }

    # Combine fixed parameters:
    other_fixed_parameters = other_fixed_parameters | {
        'patch_size': patch_size,
        'axes': axes,
    }

    # Partial function (multi_core only applies during calibration and is not
    # returned as part of the best parameters):
    _denoise_spectral = partial(
        denoise_spectral, **(other_fixed_parameters | {'multi_core': multi_core})
    )

    # Calibrate denoiser
    best_parameters = (
        calibrate_denoiser(
            crop,
            _denoise_spectral,
            mode=optimiser,
            denoise_parameters=parameter_ranges,
            interpolation_mode=jinv_interpolation_mode,
            max_num_evaluations=max_num_evaluations,
            blind_spots=blind_spots,
            display_images=display_images,
        )
        | other_fixed_parameters
    )

    # Memory needed (heuristic). NOTE(review): for large named patch sizes
    # such as 'full' this is a very loose over-estimate -- confirm with the
    # memory model of extract_patches_nd.
    memory_needed = 2 * image.nbytes + 8 * image.nbytes * math.prod(
        numeric_patch_size
    )

    return denoise_spectral, best_parameters, memory_needed
# Numba-compiled versions of ``_filter``, keyed by numba's ``parallel`` option.
# Wrapping ``_filter`` with ``jit`` yields a new dispatcher with an empty
# compilation cache, so doing it on every call (calibration calls
# denoise_spectral many times) recompiled the kernel each time. Reusing one
# dispatcher per option compiles it once per argument signature.
_jitted_filters = {}


def _get_jitted_filter(parallel: bool):
    """Return the cached numba-jitted ``_filter`` for the given ``parallel`` option."""
    jitted_filter = _jitted_filters.get(parallel)
    if jitted_filter is None:
        jitted_filter = jit(nopython=True, parallel=parallel)(_filter)
        _jitted_filters[parallel] = jitted_filter
    return jitted_filter


def denoise_spectral(
    image: ArrayLike,
    axes: Optional[Tuple[int, ...]] = None,
    patch_size: Optional[Union[int, Tuple[int], str]] = None,
    mode: str = 'dct',
    threshold: float = 0.5,
    freq_bias_stength: float = 1,
    freq_cutoff: Union[float, Sequence[float]] = 0.5,
    order: float = 1,
    reconstruction_gamma: float = 0,
    multi_core: bool = True,
):
    """Denoises the given image by first applying the patch transform, and then
    zeroing Fourier/DCT/DST coefficients below a given threshold. In addition, we
    apply Butterworth filter to suppress frequencies above the band-pass and a
    configurable frequency bias before applying the thresholding to favour
    suppressing high versus low frequencies.
    \n\n
    Note: This seems like a lot of parameters, but thanks to our auto-tuning
    approach these parameters are all automatically determined 😊.

    Parameters
    ----------
    image: ArrayLike
        Image to denoise

    axes: Optional[Tuple[int,...]]
        Axes over which to apply the spectral transform (dct, dst, fft) for
        denoising each patch.

    patch_size: Optional[Union[int, Tuple[int], str]]
        Patch size for the 'image-to-patch' transform.
        Can be: 'full' for a single patch covering the whole image, 'half', 'quarter',
        or an int s that corresponds to isotropic patches of shape: (s,)*image.ndim,
        or a tuple of ints. By default (None) the patch size is chosen automatically
        to give the best results.

    mode: str
        Possible modes are: 'dct'(works best!), 'dst', and 'fft'.

    threshold: float
        Threshold between 0 and 1, relative to the largest frequency-biased
        coefficient magnitude.

    freq_bias_stength: float
        Frequency bias: closer to zero: no bias against high frequencies,
        closer to one and above: stronger bias towards high-frequencies.

    freq_cutoff: Union[float, Sequence[float]]
        Cutoff frequency, must be within [0, 1]. Either a single value used for
        all axes, or a sequence with one value per image axis.

    order: float
        Filter order, typically an integer above 1.

    reconstruction_gamma: float
        Patch reconstruction parameter

    multi_core: bool
        By default we use as many cores as possible, in some cases, for small
        (test) images, it might be faster to run on a single core instead of
        starting the whole parallelization machinery.

    Returns
    -------
    Denoised image

    Raises
    ------
    ValueError
        If mode is not one of 'fft', 'dct' or 'dst'.
    """
    # Convert image to float if needed:
    image = image.astype(dtype=numpy.float32, copy=False)

    # Named patch sizes are resolved relative to this image's shape:
    if patch_size == 'full':
        patch_size = image.shape
    elif patch_size == 'half':
        patch_size = tuple(max(3, 2 * (s // 4)) for s in image.shape)
    elif patch_size == 'quarter':
        patch_size = tuple(max(3, 4 * (s // 8)) for s in image.shape)

    # Normalise patch size:
    patch_size = default_patch_size(image, patch_size, odd=True)

    # Default axes: all of them.
    if axes is None:
        axes = tuple(range(image.ndim))

    # Boolean mask over image axes telling which ones are transformed:
    selected_axes = tuple((a in axes) for a in range(image.ndim))

    # scipy.fft convention: workers=-1 uses all CPU cores.
    workers = -1 if multi_core else 1

    # The patch array has a leading patch-index axis, so image axis a is
    # patch-array axis a + 1:
    axes = tuple(a for a in range(1, image.ndim + 1) if (a - 1) in axes)

    # Forward/inverse transforms and the matching distance-image builder:
    if mode == 'fft':
        # fftshift moves the zero frequency to the centre of each axis:
        transform = lambda x: fftshift(  # noqa: E731
            fftn(x, workers=workers, axes=axes), axes=axes
        )
        i_transform = lambda x: ifftn(  # noqa: E731
            ifftshift(x, axes=axes), workers=workers, axes=axes
        )
        compute_distance_image = _compute_distance_image_for_fft
    elif mode == 'dct':
        transform = partial(dctn, workers=workers, axes=axes)
        i_transform = partial(idctn, workers=workers, axes=axes)
        compute_distance_image = _compute_distance_image_for_dxt
    elif mode == 'dst':
        transform = partial(dstn, workers=workers, axes=axes)
        i_transform = partial(idstn, workers=workers, axes=axes)
        compute_distance_image = _compute_distance_image_for_dxt
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    # Normalise freq_cutoff to one value per image axis: scalars (including
    # numpy scalars) are repeated, any sequence (tuple, list, array) is
    # converted to a tuple.
    if numpy.ndim(freq_cutoff) == 0:
        freq_cutoff = (freq_cutoff,) * image.ndim
    else:
        freq_cutoff = tuple(freq_cutoff)

    # First we apply the patch transform:
    patches = extract_patches_nd(image, patch_size=patch_size)

    # ### PART 1: apply Butterworth filter to patches.
    # Sparsifying transform:
    patches = transform(patches)

    # Squared normalised distance image (shape patch_size), then filtering:
    f = compute_distance_image(freq_cutoff, patch_size, selected_axes)
    patches = _get_jitted_filter(bool(multi_core))(patches, f, order)

    # ### PART 2: apply thresholding.
    # Frequency-bias weights (shape patch_size), broadcast over all patches.
    # NOTE(review): _freq_bias_window peaks at index 0 of every axis, which
    # matches the DCT/DST coefficient layout; for 'fft' the spectrum is
    # fftshift-ed so the zero frequency sits at the patch centre -- confirm
    # whether the bias should be centred for that mode.
    freq_bias = _freq_bias_window(patch_size, freq_bias_stength)

    # Frequency-biased magnitude of each coefficient:
    power = numpy.absolute(patches)
    power *= freq_bias

    # The threshold is relative to the strongest biased coefficient:
    threshold *= numpy.max(power)

    # Zero every coefficient whose biased magnitude is below the threshold:
    patches[power < threshold] = 0

    # Transform back to real space (FFT yields complex values):
    patches = i_transform(patches)
    if numpy.iscomplexobj(patches):
        patches = numpy.real(patches)

    # Transform back from patches to image:
    denoised_image = reconstruct_from_nd_patches(
        patches, image.shape, gamma=reconstruction_gamma
    )

    # Cast back to float32 if needed:
    denoised_image = denoised_image.astype(numpy.float32, copy=False)

    return denoised_image
def _freq_bias_window(shape: Tuple[int], alpha: float = 1):
    """Build the frequency-bias weights used when thresholding coefficients.

    Along every axis the normalised coordinate runs from 0 (index 0) to 1
    (last index). The window is ``1 / (1 + r)``, with ``r`` the square root of
    the summed squared coordinates (plus 1e-6). It is raised to the power
    ``alpha`` and rescaled so that its maximum is 1. It therefore peaks at
    index 0 of every axis. ``alpha == 0`` gives a flat window of ones.

    Parameters
    ----------
    shape: Tuple[int]
        Shape of the window (the patch size).
    alpha: float
        Bias strength (exponent applied to the window).

    Returns
    -------
    float32 array of the given shape.
    """
    # Squared normalised coordinates per axis; outer_sum presumably forms the
    # broadcast sum a[i] + b[j] + ... over all axes -- TODO confirm.
    window_tuple = tuple(numpy.linspace(0, 1, s) ** 2 for s in shape)
    window_nd = numpy.sqrt(outer_sum(*window_tuple)) + 1e-6
    window_nd = 1.0 / (1.0 + window_nd)
    window_nd **= alpha
    window_nd /= window_nd.max()
    window_nd = window_nd.astype(numpy.float32)
    return window_nd


def _squared_distance_image(freq_cutoff, shape, selected_axes, start):
    """Sum over axes of ``(x / fc) ** 2`` on a grid of the given shape.

    Selected axes use coordinates ``linspace(start, 1, s)``; unselected axes
    contribute zero. ``selected_axes=None`` selects every axis.
    """
    if selected_axes is None:
        selected_axes = (True,) * len(shape)

    f = numpy.zeros(shape=shape, dtype=numpy.float32)
    axis_grid = tuple(
        numpy.linspace(start, 1, s) if sa else numpy.zeros((s,))
        for sa, s in zip(selected_axes, shape)
    )
    # Sparse (open) grids broadcast into f, avoiding full-size coordinate
    # arrays per axis while producing the same values:
    for fc, x in zip(
        freq_cutoff, numpy.meshgrid(*axis_grid, indexing='ij', sparse=True)
    ):
        f += (x / fc) ** 2
    return f


def _compute_distance_image_for_dxt(freq_cutoff, shape, selected_axes=None):
    """Squared normalised distance image for DCT/DST spectra.

    Coordinates go from 0 (index 0, zero frequency) to 1 along selected axes.
    ``selected_axes`` is a sequence of booleans, one per axis (None: all).
    Previously None was expanded to the integers 0, 1, ..., which made axis 0
    falsy and silently unselected it.
    """
    return _squared_distance_image(freq_cutoff, shape, selected_axes, 0)


def _compute_distance_image_for_fft(freq_cutoff, shape, selected_axes=None):
    """Squared normalised distance image for fftshift-ed FFT spectra.

    Coordinates go from -1 to 1 along selected axes, so the zero frequency is
    at the centre. ``selected_axes`` is a sequence of booleans, one per axis
    (None: all).
    """
    return _squared_distance_image(freq_cutoff, shape, selected_axes, -1)


def _filter(image_f, f, order):
    """Apply the Butterworth gain ``1 / sqrt(1 + f ** order)`` to each patch.

    ``f`` is a squared normalised distance image, i.e. ``(w / wc) ** 2``, so
    the gain equals ``1 / sqrt(1 + (w / wc) ** (2 * order))``. Every
    ``image_f[i]`` is multiplied in place and ``image_f`` is returned. Meant
    to be wrapped with numba.jit; ``prange`` parallelises the loop when
    compiled with ``parallel=True``.
    """
    factor = 1 / numpy.sqrt(1.0 + f**order)
    factor = factor.astype(numpy.float32)
    n = image_f.shape[0]
    for i in prange(n):
        image_f[i] *= factor
    return image_f