# Source code for aydin.it.classic_denoisers.dictionary_learned

import math
from typing import Optional, Tuple, List

import numpy
from numpy.typing import ArrayLike

from aydin.it.classic_denoisers import _defaults
from aydin.it.classic_denoisers.dictionary_fixed import denoise_dictionary_fixed
from aydin.util.crop.rep_crop import representative_crop
from aydin.util.dictionary.dictionary import learn_dictionary
from aydin.util.j_invariance.j_invariance import calibrate_denoiser
from aydin.util.patch_size.patch_size import default_patch_size


def calibrate_denoise_dictionary_learned(
    image: ArrayLike,
    patch_size: Optional[int] = None,
    try_omp: bool = True,
    try_lasso_lars: bool = False,
    try_lasso_cd: bool = False,
    try_lars: bool = False,
    try_threshold: bool = False,
    max_patches: Optional[int] = int(1e6),
    dictionary_size: Optional[int] = None,
    over_completeness: float = 3,
    max_dictionary_size: int = 256,
    try_kmeans: bool = True,
    try_pca: bool = True,
    try_ica: bool = False,
    try_sdl: bool = False,
    num_sparsity_values_to_try: int = 6,
    optimiser: str = _defaults.default_optimiser.value,
    num_iterations: int = 1024,
    batch_size: int = 3,
    alpha: int = 1,
    do_cleanup_dictionary: bool = True,
    do_denoise_dictionary: bool = False,
    crop_size_in_voxels: Optional[int] = _defaults.default_crop_size_normal.value,
    max_num_evaluations: int = _defaults.default_max_evals_low.value,
    blind_spots: Optional[List[Tuple[int]]] = _defaults.default_blind_spots.value,
    jinv_interpolation_mode: str = _defaults.default_jinv_interpolation_mode.value,
    display_dictionary: bool = False,
    display_images: bool = False,
    display_crop: bool = False,
    **other_fixed_parameters,
):
    """
    Calibrates the dictionary-based denoiser for the given image by finding
    the best learned dictionary of patches and returns the optimal parameters.

    Parameters
    ----------
    image : ArrayLike
        Image to calibrate denoiser for.

    patch_size : int
        Patch size. If None it is automatically adjusted to the number of
        dimensions of the image to ensure a reasonable computational effort.
        (advanced)

    try_omp : bool
        Whether OMP should be tried as a sparse coding algorithm during
        calibration.

    try_lasso_lars : bool
        Whether LASSO-LARS should be tried as a sparse coding algorithm
        during calibration.

    try_lasso_cd : bool
        Whether LASSO-CD should be tried as a sparse coding algorithm
        during calibration.

    try_lars : bool
        Whether LARS should be tried as a sparse coding algorithm during
        calibration.

    try_threshold : bool
        Whether 'threshold' should be tried as a sparse coding algorithm
        during calibration.

    max_patches : Optional[int]
        Max number of patches to extract for dictionary learning.
        If None there is no limit.
        (advanced)

    dictionary_size : int
        Dictionary size in 'atoms'. If None the dictionary size is inferred
        from the over_completeness parameter.
        (advanced)

    over_completeness : float
        Given a certain patch size p and image dimension n, a complete basis
        has p^n elements, the over completeness factor determines the size of
        the dictionary relative to that by the formula: ox*p^n.
        (advanced)

    max_dictionary_size : int
        Independently of any other parameter, we limit the size of the
        dictionary to this provided number.
        (advanced)

    try_kmeans : bool
        Whether to compute a kmeans based dictionary and use it as one of
        the possible dictionaries during calibration.

    try_pca : bool
        Whether to compute a PCA based dictionary and use it as one of
        the possible dictionaries during calibration.

    try_ica : bool
        Whether to compute an ICA (Independent Component Analysis) based
        dictionary and use it as one of the possible dictionaries during
        calibration.

    try_sdl : bool
        Whether to compute a SDL (Sparse Dictionary Learning) based
        dictionary and use it as one of the possible dictionaries during
        calibration.

    num_sparsity_values_to_try : int
        Maximum number of sparsity values to try during calibration.
        (advanced)

    optimiser : str
        Optimiser to use for finding the best denoising parameters.
        Can be: 'smart' (default), or 'fast' for a mix of SHGO followed
        by L-BFGS-B.
        (advanced)

    num_iterations : int
        Number of iterations for learning dictionary.
        (advanced)

    batch_size : int
        Size of batches during batched dictionary learning.
        (advanced)

    alpha : int
        Sparsity prior strength.
        (advanced)

    do_cleanup_dictionary : bool
        Removes dictionary entries that are likely pure noise or have
        impulses or very high-frequencies or checkerboard patterns that are
        unlikely needed to reconstruct the true signal.
        (advanced)

    do_denoise_dictionary : bool
        Whether to denoise the dictionary atoms after learning.
        (advanced)

    crop_size_in_voxels : int or None for default
        Number of voxels for crop used to calibrate denoiser.
        Increase this number by factors of two if denoising quality is
        unsatisfactory -- this can be important for very noisy images.
        Values to try are: 65000, 128000, 256000, 320000.
        We do not recommend values higher than 512000.

    max_num_evaluations : int
        Maximum number of evaluations for finding the optimal parameters.
        Increase this number by factors of two if denoising quality is
        unsatisfactory.

    blind_spots : Optional[List[Tuple[int]]]
        List of voxel coordinates (relative to receptive field center) to
        be included in the blind-spot. For example, you can give a list of
        3 tuples: [(0,0,0), (0,1,0), (0,-1,0)] to extend the blind spot to
        cover voxels of relative coordinates: (0,0,0),(0,1,0), and (0,-1,0)
        (advanced) (hidden)

    jinv_interpolation_mode : str
        J-invariance interpolation mode for masking.
        Can be: 'median' or 'gaussian'.
        (advanced)

    display_dictionary : bool
        If True displays dictionary with napari -- for debug purposes.
        (advanced) (hidden)

    display_images : bool
        When True the denoised images encountered during optimisation are
        shown.
        (advanced) (hidden)

    display_crop : bool
        Displays crop, for debugging purposes...
        (advanced) (hidden)

    other_fixed_parameters : dict
        Any other fixed parameters.

    Returns
    -------
    Denoising function, dictionary containing optimal parameters,
    and free memory needed in bytes for computation.
    """

    # Convert image to float if needed:
    image = image.astype(dtype=numpy.float32, copy=False)

    # Normalise patch size:
    patch_size = default_patch_size(image, patch_size, odd=True)

    # obtain representative crop, to speed things up...
    crop = representative_crop(
        image, crop_size=crop_size_in_voxels, display_crop=display_crop
    )

    # algorithms to try for generation of dictionaries:
    algorithms = []
    if try_kmeans:
        algorithms.append('kmeans')
    if try_pca:
        algorithms.append('pca')
    if try_ica:
        algorithms.append('ica')
    if try_sdl:
        algorithms.append('sdl')

    # dictionaries to try:
    dictionaries = {}
    for algorithm in algorithms:
        # learn dictionary:
        dictionary = learn_dictionary(
            image,
            patch_size=patch_size,
            max_patches=max_patches,
            dictionary_size=dictionary_size,
            over_completeness=over_completeness,
            max_dictionary_size=max_dictionary_size,
            algorithm=algorithm,
            num_iterations=num_iterations,
            batch_size=batch_size,
            alpha=alpha,
            cleanup_dictionary=do_cleanup_dictionary,
            denoise_dictionary=do_denoise_dictionary,
            display_dictionary=display_dictionary,
            **other_fixed_parameters,
        )
        dictionaries[algorithm] = dictionary

    # coding modes to try:
    coding_modes = []
    if try_omp:
        coding_modes.append('omp')
    if try_lasso_lars:
        coding_modes.append('lasso_lars')
    if try_lasso_cd:
        coding_modes.append('lasso_cd')
    if try_lars:
        coding_modes.append('lars')
    if try_threshold:
        coding_modes.append('threshold')

    # Parameters to test when calibrating the denoising algorithm
    parameter_ranges = {
        'sparsity': [1, 2, 3, 4, 8, 16][:num_sparsity_values_to_try],
        'gamma': [0.001],
        'coding_mode': coding_modes,
        'dictlearn_algorithm': algorithms,
    }

    # Partial function that selects the pre-learned dictionary by algorithm:
    def _denoise_dictionary(image, dictlearn_algorithm, *args, **kwargs):
        return denoise_dictionary_learned(
            image,
            *args,
            dictionary=dictionaries[dictlearn_algorithm],
            **other_fixed_parameters,
            **kwargs,
        )

    # Calibrate denoiser:
    best_parameters = (
        calibrate_denoiser(
            crop,
            _denoise_dictionary,
            mode=optimiser,
            denoise_parameters=parameter_ranges,
            interpolation_mode=jinv_interpolation_mode,
            max_num_evaluations=max_num_evaluations,
            blind_spots=blind_spots,
            display_images=display_images,
        )
        | other_fixed_parameters
    )

    # Memory needed:
    memory_needed = 2 * image.nbytes + 6 * image.nbytes * math.prod(patch_size)

    return denoise_dictionary_learned, best_parameters, memory_needed
def denoise_dictionary_learned(*args, **kwargs):
    """
    Denoises the given image using sparse-coding over a fixed dictionary
    of nD image patches. The dictionary learning and patch sparse coding
    uses scikit-learn's Batch-OMP implementation.

    Parameters
    ----------
    args
        Positional arguments forwarded to denoise_dictionary_fixed.
    kwargs
        Keyword arguments forwarded to denoise_dictionary_fixed.

    Returns
    -------
    denoised image
    """
    # Once the dictionary has been learned, denoising is identical to the
    # fixed-dictionary case, so we simply delegate:
    return denoise_dictionary_fixed(*args, **kwargs)