import math
from typing import Optional, Tuple, List

import numpy
from numpy.typing import ArrayLike
from sklearn.decomposition import SparseCoder

from aydin.it.classic_denoisers import _defaults
from aydin.util.crop.rep_crop import representative_crop
from aydin.util.dictionary.dictionary import (
    extract_normalised_vectorised_patches,
    fixed_dictionary,
)
from aydin.util.j_invariance.j_invariance import calibrate_denoiser
from aydin.util.log.log import lsection, lprint
from aydin.util.patch_size.patch_size import default_patch_size
from aydin.util.patch_transform.patch_transform import reconstruct_from_nd_patches

def calibrate_denoise_dictionary_fixed(
    image: ArrayLike,
    patch_size: int = None,
    try_omp: bool = True,
    try_lasso_lars: bool = False,
    try_lasso_cd: bool = False,
    try_lars: bool = False,
    try_threshold: bool = False,
    num_sparsity_values_to_try: int = 6,
    dictionaries: str = 'dct',
    crop_size_in_voxels: Optional[int] = _defaults.default_crop_size_normal.value,
    optimiser: str = _defaults.default_optimiser.value,
    max_num_evaluations: int = _defaults.default_max_evals_low.value,
    blind_spots: Optional[List[Tuple[int]]] = _defaults.default_blind_spots.value,
    jinv_interpolation_mode: str = _defaults.default_jinv_interpolation_mode.value,
    display_dictionary: bool = False,
    display_images: bool = False,
    display_crop: bool = False,
    **other_fixed_parameters,
):
    """
    Calibrates the dictionary-based denoiser for the given image and returns
    the optimal parameters obtained using the N2S loss.

    Calibration runs in two stages on a representative crop of the image:
    first the dictionary frequency cut-off ('max_freq') and the sparse coding
    algorithm ('coding_mode') are optimised, then the sparsity is selected
    with those two values held fixed.

    Parameters
    ----------
    image : ArrayLike
        Image to calibrate denoiser for.

    patch_size : int
        Patch size. Common parameter to both 'learned', or 'fixed'
        dictionary types. None defers the choice to `default_patch_size`
        (called with odd=True).
        (advanced)

    try_omp: bool
        Whether OMP should be tried as a sparse coding algorithm during
        calibration.

    try_lasso_lars: bool
        Whether LASSO-LARS should be tried as a sparse coding algorithm
        during calibration.

    try_lasso_cd: bool
        Whether LASSO-CD should be tried as a sparse coding algorithm during
        calibration.

    try_lars: bool
        Whether LARS should be tried as a sparse coding algorithm during
        calibration.

    try_threshold: bool
        Whether 'threshold' should be tried as a sparse coding algorithm
        during calibration.

    num_sparsity_values_to_try: int
        Maximum number of sparsity values to try during calibration, taken
        in order from [1, 2, 3, 4, 8, 16].
        (advanced)

    dictionaries: str
        Fixed dictionaries to be included. Can be: 'dct', 'dst'.

    crop_size_in_voxels: int or None for default
        Number of voxels for crop used to calibrate denoiser.
        Increase this number by factors of two if denoising quality is
        unsatisfactory -- this can be important for very noisy images.
        Values to try are: 65000, 128000, 256000, 320000.
        We do not recommend values higher than 512000.

    optimiser: str
        Optimiser to use for finding the best denoising
        parameters. Can be: 'smart' (default), or 'fast' for a mix of SHGO
        followed by L-BFGS-B.
        (advanced)

    max_num_evaluations: int
        Maximum number of evaluations for finding the optimal parameters.
        Increase this number by factors of two if denoising quality is
        unsatisfactory.

    blind_spots: Optional[List[Tuple[int]]]
        List of voxel coordinates (relative to receptive field center) to
        be included in the blind-spot. For example, you can give a list of
        3 tuples: [(0,0,0), (0,1,0), (0,-1,0)] to extend the blind spot
        to cover voxels of relative coordinates: (0,0,0),(0,1,0),
        and (0,-1,0)
        (advanced) (hidden)

    jinv_interpolation_mode: str
        J-invariance interpolation mode for masking. Can be: 'median' or
        'gaussian'.
        (advanced)

    display_dictionary: bool
        If True displays dictionary with napari -- for debug purposes.
        (advanced) (hidden)

    display_images: bool
        When True the denoised images encountered during optimisation are
        shown.
        (advanced) (hidden)

    display_crop: bool
        Displays crop, for debugging purposes...
        (advanced) (hidden)

    other_fixed_parameters: dict
        Any other fixed parameters. They are held fixed in both calibration
        stages and included in the returned parameters.

    Returns
    -------
    Denoising function, dictionary containing optimal parameters,
    and free memory needed in bytes for computation.

    Raises
    ------
    ValueError
        If none of the try_* flags is enabled, so there is no coding mode
        to calibrate over.
    """
    # Convert image to float if needed:
    image = image.astype(dtype=numpy.float32, copy=False)

    # Normalise patch size:
    patch_size = default_patch_size(image, patch_size, odd=True)

    # obtain representative crop, to speed things up...
    crop = representative_crop(
        image, crop_size=crop_size_in_voxels, display_crop=display_crop
    )

    # Denoising function handed to the calibrator. It rebuilds the fixed
    # dictionary for each candidate max_freq. patch_size and dictionaries
    # are captured from the enclosing scope.
    def _denoise_dictionary(
        image, max_freq: float = 0.5, coding_mode: str = 'omp', **parameters
    ):
        dictionary = fixed_dictionary(
            image, patch_size=patch_size, dictionaries=dictionaries, max_freq=max_freq
        )
        return denoise_dictionary_fixed(
            image, dictionary=dictionary, coding_mode=coding_mode, **parameters
        )

    # Coding modes to try, in a fixed order:
    coding_modes = [
        mode
        for mode, enabled in (
            ('omp', try_omp),
            ('lasso_lars', try_lasso_lars),
            ('lasso_cd', try_lasso_cd),
            ('lars', try_lars),
            ('threshold', try_threshold),
        )
        if enabled
    ]
    if not coding_modes:
        raise ValueError(
            "At least one sparse coding algorithm must be enabled "
            "(try_omp, try_lasso_lars, try_lasso_cd, try_lars or try_threshold)."
        )

    # Stage 1: optimise the dictionary frequency cut-off and the coding mode.
    parameter_ranges = {'max_freq': (0.01, 1.3), 'coding_mode': coding_modes}
    best_parameters = calibrate_denoiser(
        crop,
        _denoise_dictionary,
        mode=optimiser,
        denoise_parameters=parameter_ranges,
        interpolation_mode=jinv_interpolation_mode,
        max_num_evaluations=max_num_evaluations,
        blind_spots=blind_spots,
        **other_fixed_parameters,
    )
    lprint(f"Best parameters: {best_parameters}")

    # Stage 2: choose the sparsity with the stage-1 optimum held fixed.
    # The fixed values are spread as individual keywords. The assumption
    # (confirm against calibrate_denoiser) is that extra keywords are
    # forwarded to the denoise function. A single nested
    # `other_fixed_parameters=` dict would instead land in `**parameters`
    # and be silently ignored by denoise_dictionary_fixed.
    fixed_parameters = best_parameters | other_fixed_parameters
    parameter_ranges = {'sparsity': [1, 2, 3, 4, 8, 16][:num_sparsity_values_to_try]}
    best_parameters = (
        calibrate_denoiser(
            crop,
            _denoise_dictionary,
            denoise_parameters=parameter_ranges,
            interpolation_mode=jinv_interpolation_mode,
            max_num_evaluations=max_num_evaluations,
            display_images=display_images,
            blind_spots=blind_spots,
            **fixed_parameters,
        )
        | fixed_parameters
    )

    # Defensive clean-up, in case a nested entry of this name is present:
    best_parameters.pop('other_fixed_parameters', None)

    lprint(f"Final best parameters: {best_parameters}")

    # The client-facing denoise function expects the dictionary itself,
    # not the max_freq used to build it:
    max_freq = best_parameters.pop('max_freq')
    dictionary = fixed_dictionary(
        image, patch_size=patch_size, dictionaries=dictionaries, max_freq=max_freq
    )
    best_parameters = best_parameters | {'dictionary': dictionary}

    if display_dictionary:
        import napari

        viewer = napari.Viewer()
        viewer.add_image(
            dictionary.reshape(len(dictionary), *patch_size), name='dictionary'
        )
        # Blocks until the viewer window is closed (replaces the
        # deprecated napari.gui_qt() context manager):
        napari.run()

    # Memory needed: two image-sized buffers plus a term that scales with
    # the number of voxels per patch.
    memory_needed = 2 * image.nbytes + 6 * image.nbytes * math.prod(patch_size)

    return denoise_dictionary_fixed, best_parameters, memory_needed


def denoise_dictionary_fixed(
    image: ArrayLike,
    dictionary=None,
    coding_mode: str = 'omp',
    sparsity: int = 1,
    gamma: float = 0.001,
    multi_core: bool = True,
    **kwargs,
):
    """
    Denoises the given image using sparse-coding over a fixed dictionary
    of nD image patches. Patch sparse coding uses scikit-learn's
    SparseCoder with the requested transform algorithm.

    Parameters
    ----------
    image: ArrayLike
        nD image to be denoised

    dictionary: ArrayLike
        Dictionary to use for denoising the image via sparse coding. Its
        first axis indexes atoms and the remaining axes give the patch
        shape. By default (None) a fixed dictionary is built with
        `fixed_dictionary(image)` defaults.

    coding_mode: str
        Type of sparse coding, can be: 'lasso_lars', 'lasso_cd', 'lars',
        'omp', or 'threshold'

    sparsity: int
        How many atoms are used to represent each patch after denoising.
        Passed to SparseCoder as transform_n_nonzero_coefs, which
        scikit-learn only uses for 'omp' and 'lars'. For the other coding
        modes this value has no effect.

    gamma: float
        How much the periphery of the patches contributes to the final
        denoised image. Larger gamma means that we keep more of the central
        pixels of the patches, smaller values lead to a more uniform
        contribution. A value of 1 corresponds to the default blackman
        window.

    multi_core: bool
        By default we use as many cores as possible. For small (test)
        images it might be faster to run on a single core instead of
        starting the whole parallelization machinery.

    kwargs
        Ignored. Accepted so that callers (e.g. calibration) can pass extra
        parameters without error.

    Returns
    -------
    Denoised image
    """
    # Convert image to float if needed:
    image = image.astype(dtype=numpy.float32, copy=False)

    if dictionary is None:
        # Build a fixed dictionary with all defaults:
        dictionary = fixed_dictionary(image)

    # The patch shape is implied by the dictionary atoms:
    patch_size = dictionary.shape[1:]

    with lsection(f"Denoise image of shape {image.shape} and dtype {image.dtype}"):

        # Vectorise dictionary: one flattened atom per row.
        vectorised_dictionary = dictionary.reshape(len(dictionary), -1)

        # Set up sparse coder:
        coder = SparseCoder(
            vectorised_dictionary,
            transform_algorithm=coding_mode,
            transform_n_nonzero_coefs=sparsity,
            n_jobs=-1 if multi_core else 1,
        )

        # Extract _all_ patches, with their means removed (stds untouched).
        # The means are returned so they can be added back after coding.
        with lsection("Extract all patches from image..."):
            patches, patch_means, _ = extract_normalised_vectorised_patches(
                image,
                patch_size=patch_size,
                max_patches=None,
                normalise_means=True,
                normalise_stds=False,
                output_norm_values=True,
            )

        with lsection("Obtain sparse codes for each patch..."):
            code = coder.transform(patches)

        with lsection("Reconstruct patches from codes..."):
            # Linear combination of atoms weighted by the sparse codes:
            denoised_patches = numpy.dot(code, vectorised_dictionary)

            # Add back means:
            denoised_patches += patch_means

        with lsection("Reshape to patches..."):
            denoised_patches = denoised_patches.reshape(len(patches), *patch_size)

        with lsection("Reconstructing image from patches..."):
            # Reconstructs image from denoised patches:
            denoised_image = reconstruct_from_nd_patches(
                patches=denoised_patches, image_shape=image.shape, gamma=gamma
            )

    return denoised_image