# Source code for aydin.it.transforms.salt_pepper

import numpy

from numpy.typing import ArrayLike
from numpy import sort
from scipy.ndimage import uniform_filter

from aydin.it.classic_denoisers.lipschitz import denoise_lipschitz
from aydin.it.transforms.base import ImageTransformBase
from aydin.util.log.log import lsection, lprint


class SaltPepperTransform(ImageTransformBase):
    """Salt And Pepper Correction

    Detectors such as cameras have 'broken' pixels that blink, are very dim,
    or very bright. Other phenomena cause voxels to have very different
    values from their neighbors, this is often called 'impulse' or
    'salt-and-pepper' noise. While self-supervised denoising can solve many
    of these issues, there is no reason to not to alleviate the task,
    especially when there are simple and fast approaches that can tackle
    this kind of noise. This preprocessing replaces voxels with the median
    of its neighbors if the voxel value is too different from its neighbors.
    This difference is proportional to the local second-derivative of the
    image. Increase the threshold parameter to tolerate more variation,
    decrease it to be more aggressive in removing salt & pepper noise. The
    algorithm is iterative, starting with the most offending pixels, until
    no pixels are corrected. You can set the max proportion of pixels that
    are allowed to be corrected if you can give a good estimate for that.
    """

    preprocess_description = (
        "Salt and pepper pixels correction"
        + ImageTransformBase.preprocess_description
    )
    postprocess_description = "Not supported (why would anyone want to do that? ☺)"
    postprocess_supported = False
    postprocess_recommended = False

    def __init__(
        self,
        fix_repeated: bool = True,
        max_repeated: int = 4,
        fix_lipschitz: bool = True,
        lipschitz: float = 0.1,
        percentile: float = 0.01,
        num_iterations: int = 64,
        priority: float = 0.08,
        **kwargs,
    ):
        """
        Constructs a Salt And Pepper Transform

        Parameters
        ----------
        fix_repeated : bool
            Removes Salt & pepper by finding highly repeated values.
            These values are then considered as erroneous and are fixed by
            interpolation.
        max_repeated : int
            Max number of repeated values to fix.
        fix_lipschitz : bool
            Removes Salt & pepper by enforcing Lipschitz continuity.
        lipschitz : float
            Lipschitz threshold. Increase to tolerate more variation,
            decrease to be more aggressive in removing impulse/salt&pepper
            noise.
        percentile : float
            Percentile value used to determine the threshold for choosing
            the worst offending voxels per iteration according to the
            Lipschitz threshold.
        num_iterations : int
            Number of iterations for enforcing Lipschitz continuity.
        priority : float
            The priority is a value within [0,1] used to determine the order
            in which to apply the pre- and post-processing transforms.
            Transforms are sorted and applied in ascending order during
            preprocesing and in the reverse, descending, order during
            post-processing.
        """
        super().__init__(priority=priority, **kwargs)
        self.fix_lipschitz = fix_lipschitz
        self.num_iterations = num_iterations
        self.correction_percentile = percentile
        self.lipschitz = lipschitz
        self.fix_repeated = fix_repeated
        self.max_repeated = max_repeated

        # Dtype of the last array given to preprocess; restored by postprocess.
        self._original_dtype = None

        lprint(f"Instanciating: {self}")

    # We exclude certain fields from saving:
    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_original_dtype']
        return state

    def __str__(self):
        return (
            f'{type(self).__name__} (fix_lipschitz={self.fix_lipschitz},'
            f' num_iterations={self.num_iterations},'
            f' correction_percentile={self.correction_percentile},'
            f' lipschitz={self.lipschitz},'
            f' fix_repeated={self.fix_repeated},'
            f' max_repeated={self.max_repeated} )'
        )

    def __repr__(self):
        return self.__str__()

    def preprocess(self, array: ArrayLike):
        """Corrects salt & pepper noise in the given array.

        Converts to float32, then applies the 'repeated-value' correction
        and/or the Lipschitz-continuity correction depending on the
        constructor flags. Returns the corrected float32 array.
        """
        with lsection(
            f"Broken Pixels Correction for array of shape: {array.shape} and dtype: {array.dtype}:"
        ):
            # We save the original dtype:
            self._original_dtype = array.dtype

            # If needed, we convert to float32:
            array = array.astype(dtype=numpy.float32, copy=True)

            # First we look at over represented voxel values -- a sign of
            # problematic voxels, and try to fix them:
            if self.fix_repeated:
                array = self._repeated_value_method(array)

            # Then we enforce Lipschitz continuity:
            if self.fix_lipschitz:
                array = self._lipschitz_method(array)

            return array

    def postprocess(self, array: ArrayLike):
        # Undoing this transform is impractical and unlikely to be useful,
        # so we only restore the original dtype:
        array = array.astype(self._original_dtype, copy=False)
        return array

    def _repeated_value_method(self, array: ArrayLike):
        """Finds over-represented voxel values and fixes the corresponding
        voxels by harmonic interpolation from their neighbors."""
        with lsection(
            "Correcting for wrong pixels values using the 'repeated-value' approach:"
        ):
            unique, counts = numpy.unique(array, return_counts=True)

            # How many unique values in image?
            num_unique_values = unique.size
            lprint(f"Number of unique values in image: {num_unique_values}.")

            # Most occuring value:
            most_occuring_value = unique[numpy.argmax(counts)]
            highest_count = numpy.max(counts)
            lprint(
                f"Most occurring value in array: {most_occuring_value}, {highest_count} times."
            )

            # Assuming a uniform distribution we would expect each value to be
            # used at most:
            average_count = array.size // num_unique_values
            lprint(
                f"Average number of occurences of a value assuming uniform distribution: {average_count}"
            )

            # We fix at most n over-represented values:
            selected_counts = sort(counts.flatten())

            # First we ignore counts below a certain threshold:
            selected_counts = selected_counts[selected_counts > average_count]

            # FIX: previously an empty selection crashed below with an
            # IndexError (n == 0 makes selected_counts[-0] index an empty
            # array). Nothing over-represented means nothing to correct.
            if selected_counts.size == 0:
                lprint("No over-represented values found, nothing to correct.")
                return array

            # Use Otsu split to clean up remaining values:
            otsu_mask = _otsu_split(selected_counts)
            selected_counts = selected_counts[otsu_mask]
            if selected_counts.size == 0:
                lprint("No over-represented values found, nothing to correct.")
                return array

            # Maximum number of repeated values to remove:
            n = self.max_repeated
            n = min(n, len(selected_counts))
            max_tolerated_count = selected_counts[-n]
            lprint(f"Maximum tolerated count per value: {max_tolerated_count}.")

            # Values occurring strictly more often than tolerated are problematic:
            problematic_counts_mask = counts > max_tolerated_count
            problematic_counts = counts[problematic_counts_mask]
            problematic_values = unique[problematic_counts_mask]
            lprint(f"Problematic values: {list(problematic_values)}.")
            lprint(f"Problematic counts: {list(problematic_counts)}.")

            # We construct the mask of good values:
            good_values_mask = numpy.ones_like(array, dtype=numpy.bool_)
            for problematic_value in problematic_values:
                good_values_mask &= array != problematic_value

            with lsection(f"Correcting voxels with values: {problematic_values}."):
                # We save the good values (copy!):
                good_values = array[good_values_mask].copy()

                # We compute the number of bad values:
                num_bad_values = array.size - len(good_values)

                # FIX: previously a ZeroDivisionError occurred when no voxel
                # actually exceeded the tolerated count (e.g. when n == 1):
                if num_bad_values == 0:
                    lprint("No voxels to correct.")
                    return array

                # Number of smoothing iterations scales with the typical
                # distance between bad voxels:
                num_iterations = 16 * int(
                    (array.size / num_bad_values) ** (1.0 / array.ndim)
                )

                # We solve the harmonic equation: repeated local averaging
                # (uniform_filter, i.e. a local mean -- not a median) while
                # clamping good voxels back to their original values:
                for i in range(num_iterations):
                    lprint(f"Iteration {i}")

                    # We compute the local mean:
                    array = uniform_filter(array, size=3)

                    # We use it to correct (only) the bad pixels:
                    array[good_values_mask] = good_values

                # FIX: the corrected-voxel count is num_bad_values; the old
                # code summed an unrelated mask variable here.
                lprint(f"Number of corrections: {num_bad_values}.")

            return array

    def _lipschitz_method(self, array):
        """Fixes salt & pepper noise by enforcing Lipschitz continuity."""
        with lsection(
            "Correcting for wrong pixels values using the Lipschitz approach:"
        ):
            array = denoise_lipschitz(
                array,
                lipschitz=self.lipschitz,
                percentile=self.correction_percentile,
                max_num_iterations=self.num_iterations,
            )

            return array

        # OLD METHOD KEEP!
        # for i in range(self.num_iterations):
        #     lprint(f"Iteration {i}")
        #
        #     # Compute median:
        #     median = median_filter(array, size=3)
        #
        #     # We scale the lipschitz threshold to the image std at '3 sigma' :
        #     lipschitz = self.lipschitz * 3 * median.std()
        #
        #     # We compute the 'error':
        #     median, error = self._compute_error(
        #         array, median=median, lipschitz=lipschitz
        #     )
        #
        #     # We compute the threshold on the basis of the errors,
        #     # we first tackle the most offending voxels:
        #     threshold = numpy.percentile(
        #         error, q=100 * (1 - self.correction_percentile)
        #     )
        #
        #     # We compute the mask:
        #     mask = error > threshold
        #
        #     # count number of corrections for this round:
        #     num_corrections = numpy.sum(mask)
        #     lprint(f"Number of corrections: {num_corrections}")
        #
        #     # if no corrections made we stop iterating:
        #     if num_corrections == 0:
        #         break
        #
        #     # We keep track of the proportion of voxels corrected:
        #     proportion = (
        #         num_corrections + total_number_of_corrections
        #     ) / array.size
        #     lprint(
        #         f"Proportion of corrected pixels: {int(proportion * 100)}% (up to now), versus maximum: {int(self.max_proportion_corrected * 100)}%) "
        #     )
        #
        #     # If too many voxels have been corrected we stop:
        #     if proportion > self.max_proportion_corrected:
        #         break
        #
        #     # We use the median to correct pixels:
        #     array[mask] = median[mask]
        #
        #     # increment total number of corrections:
        #     total_number_of_corrections += num_corrections

    def _compute_error(self, array, median, lipschitz):
        """Returns (median, error) where error is how much each voxel exceeds
        the Lipschitz tolerance relative to the median-filtered image.
        NOTE(review): only used by the commented-out 'OLD METHOD' above."""
        # we compute the error map:
        error = median.copy()
        error -= array
        numpy.abs(error, out=error)
        numpy.maximum(error, lipschitz, out=error)
        error -= lipschitz
        return median, error
def _otsu_split(array: ArrayLike): # Flatten array: shape = array.shape array = array.reshape(-1) mean_weigth = 1.0 / array.size his, bins = numpy.histogram(array, bins='auto', density=True) final_thresh = -1 final_value = -1 for i in range(1, len(bins) - 1): Wb = numpy.sum(his[:i]) * mean_weigth Wf = numpy.sum(his[i:]) * mean_weigth mub = numpy.mean(his[:i]) muf = numpy.mean(his[i:]) value = Wb * Wf * (mub - muf) ** 2 # print("Wb", Wb, "Wf", Wf) # print("t", i, "value", value) if value > final_value: final_thresh = 0.5 * (bins[i] + bins[i + 1]) final_value = value mask = array > final_thresh mask = mask.reshape(shape) return mask