# Source code for aydin.it.transforms.motion

import math
from typing import Union, Optional, Sequence

import numpy
import scipy

from numpy.typing import ArrayLike
from scipy.ndimage import gaussian_filter
from scipy.optimize import curve_fit

from aydin.it.transforms.base import ImageTransformBase
from aydin.util.log.log import lprint, lsection


class MotionStabilisationTransform(ImageTransformBase):
    """Motion Stabilisation

    Denoising is more effective if signal-correlated voxels are close to each
    other. When a 2D+t or 3D+t timelapse has shifts between time points,
    pixels that should be close to each other from one time point to the
    next are now further away from each other. Worse, this relative placement
    often varies over time. This complicates denoising and typically leads to
    more blurry denoised images. Thus, stabilizing a timelapse before
    denoising is recommended to improve denoising performance. Currently, we
    assume that all frames can be registered to a common reference frame, and
    thus that all images have a common background that can be used for
    registration. For completeness, multiple axes can be specified and the
    correction is applied along each in sequence.(advanced)
    """

    preprocess_description = (
        "Stabilise motion" + ImageTransformBase.preprocess_description
    )
    postprocess_description = (
        "Reapply motion" + ImageTransformBase.postprocess_description
    )
    postprocess_supported = True
    postprocess_recommended = False

    def __init__(
        self,
        axes: Union[None, int, Sequence[int]] = 0,
        sigma: float = 3,
        pad: bool = False,
        crop: bool = False,
        pad_mode: str = 'min_constant',
        max_pixel_shift: Optional[int] = None,
        reference_index: Optional[int] = None,
        priority: float = 0.45,
        **kwargs,
    ):
        """
        Constructs a Motion Corrector

        Parameters
        ----------
        axes : Union[None, int, Sequence[int]]
            Index of the time axis, or a sequence of axes along which the
            stabilisation is applied in turn. If None, stabilisation is
            applied along every axis of the image in turn.

        sigma : float
            Sigma for Gaussian filtering used to facilitate registration.

        pad : bool
            Pads image before applying stabilisation, default value is False.

        crop : bool
            Crops image after undoing stabilisation, default value is False.

        pad_mode : str
            Padding mode. Can be: 'mean_constant', 'min_constant',
            'max_constant', 'wrap'. Default is 'min_constant'.

        max_pixel_shift : Optional[int]
            Maximum correctable motion in pixels. If None, one third of the
            largest frame dimension is used.

        reference_index : Optional[int]
            Index of image used as reference to register all others. If None,
            the middle image is used. A negative value registers each image to
            the image that many positions before it and accumulates the shifts.

        priority : float
            The priority is a value within [0,1] used to determine the order in
            which to apply the pre- and post-processing transforms. Transforms
            are sorted and applied in ascending order during preprocessing and
            in the reverse, descending, order during post-processing.
        """
        super().__init__(priority=priority, **kwargs)
        self.sigma = sigma
        # Normalise a single axis index to a one-element tuple:
        self.axes = (
            None if axes is None else ((axes,) if isinstance(axes, int) else axes)
        )
        self.pad_mode = pad_mode
        self.pad = pad
        self.crop = crop
        # Centring of the measured shifts is not exposed as a constructor option:
        self.center = False
        self.max_pixel_shift = max_pixel_shift
        self.reference_index = reference_index

        # Per-run state written by preprocess() and read back by postprocess():
        self._shifts = {}
        self._original_dtype = None

        lprint(f"Instanciating: {self}")

    # We exclude certain fields from saving (per-run state, see __init__):
    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_shifts']
        del state['_original_dtype']
        return state

    def __str__(self):
        return (
            f'{type(self).__name__}'
            f' (pad_mode={self.pad_mode},'
            f' pad={self.pad},'
            f' crop={self.crop},'
            f' center={self.center},'
            f' max_pixel_shift={self.max_pixel_shift},'
            f' reference_index={self.reference_index})'
        )

    def __repr__(self):
        return self.__str__()

    def preprocess(self, array: ArrayLike):
        """
        Measures and removes the motion along each configured axis.

        The measured shifts are stored per axis in ``self._shifts`` and the
        input dtype in ``self._original_dtype`` so that ``postprocess`` can
        reapply the motion. The returned array is float32 and, when ``pad`` is
        True, larger than the input.
        """
        with lsection(f"Motion-correcting array of shape: {array.shape}:"):

            self._original_dtype = array.dtype

            # We need a copy because the shift-transfrom is in-place
            array = array.astype(numpy.float32, copy=True)

            axes = range(array.ndim) if self.axes is None else self.axes

            self._shifts = {}
            for axis in axes:
                lprint(f"Correcting along axis: {axis}")
                # Bring the axis to the front so that it indexes the frames:
                array = self._permutate(array, axis=axis)

                shifts, mean_shift = _measure_shifts(
                    array,
                    reference_index=self.reference_index,
                    center=self.center,
                    max_pixel_shift=self.max_pixel_shift,
                    mode='com',
                    sigma=self.sigma,
                )
                lprint(f"Mean shift: {mean_shift}")

                # Undo the measured motion by rolling each frame the other way:
                array = _shift_transform(
                    array, -shifts, pad=self.pad, crop=False, pad_mode=self.pad_mode
                )
                self._shifts[axis] = shifts

                array = self._depermutate(array, axis=axis)

            return array

    def postprocess(self, array: ArrayLike):
        """
        Reapplies the shifts recorded by ``preprocess``, axis by axis in
        reverse order, optionally cropping away padding, and casts the result
        back to the dtype of the array given to ``preprocess``.
        """
        if not self.do_postprocess:
            return array

        with lsection(f"Undoing motion-correction for array of shape: {array.shape}:"):

            # We need a copy because the shift-transfrom is in-place
            array = array.astype(numpy.float32, copy=True)

            axes = range(array.ndim) if self.axes is None else self.axes

            # Undo the per-axis corrections in the reverse order of preprocess:
            for axis in reversed(axes):
                lprint(f"Correcting along axis: {axis}")
                array = self._permutate(array, axis=axis)
                shifts = self._shifts[axis]
                array = _shift_transform(
                    array, shifts, pad=False, crop=self.crop, pad_mode=''
                )
                array = self._depermutate(array, axis=axis)

            # cast back to original dtype:
            array = array.astype(self._original_dtype, copy=False)

            return array

    def _permutate(self, array: ArrayLike, axis: int):
        """Moves ``axis`` to the front of ``array`` (returns a transposed view)."""
        permutation = self._get_permutation(array, axis=axis)
        array = numpy.transpose(array, axes=permutation)
        return array

    def _depermutate(self, array: ArrayLike, axis: int):
        """Inverts ``_permutate``: moves the leading axis back to ``axis``."""
        permutation = self._get_permutation(array, axis=axis, inverse=True)
        array = numpy.transpose(array, axes=permutation)
        return array

    def _get_permutation(self, array: ArrayLike, axis: int, inverse=False):
        """
        Returns the axis order that puts ``axis`` first and keeps the others
        in order, or the inverse of that order when ``inverse`` is True.
        """
        permutation = (axis,) + tuple(a for a in range(array.ndim) if a != axis)
        if inverse:
            permutation = numpy.argsort(permutation)
        return permutation
def _shift_transform(array: ArrayLike, shifts, pad, crop, pad_mode='wrap'):
    """
    Rolls every frame of a stack by its own integer shift.

    Parameters
    ----------
    array : ArrayLike
        Stack of frames: the first axis indexes frames, and each frame is
        rolled along all of its remaining axes. Frames are written back into
        this array (or into the padded copy when ``pad`` is True), so callers
        that need the input preserved must pass a copy.
    shifts : ArrayLike
        Integer shifts, one row per frame and one column per non-frame axis.
    pad : bool
        If True, every non-frame axis is first padded by the absolute value of
        the smallest shift before, and of the largest shift after, the data.
    crop : bool
        If True, after rolling, every non-frame axis loses the absolute value
        of the largest shift at its start and of the smallest shift at its end.
        A call with ``crop=True`` and shifts ``s`` therefore removes exactly the
        border added by a call with ``pad=True`` and shifts ``-s``.
    pad_mode : str
        Only used when ``pad`` is True. 'mean_constant', 'min_constant' and
        'max_constant' pad with the array's mean, min or max; any other mode
        containing 'constant' pads with 0; anything else is passed to
        ``numpy.pad`` as its ``mode``.

    Returns
    -------
    ArrayLike
        The rolled array, possibly padded or cropped.
    """
    min_shift = abs(numpy.min(shifts, axis=0))
    max_shift = abs(numpy.max(shifts, axis=0))
    lprint(f"min_shift: {min_shift}, max_shift: {max_shift}")

    if pad:
        padding = tuple((mi, ma) for mi, ma in zip(min_shift, max_shift))
        lprint(f"Padding: {padding}")

        if 'constant' in pad_mode:
            if pad_mode == 'mean_constant':
                value = array.mean()
            elif pad_mode == 'min_constant':
                value = array.min()
            elif pad_mode == 'max_constant':
                value = array.max()
            else:
                value = 0
            pad_kwargs = {'constant_values': value}
            numpy_pad_mode = 'constant'
        else:
            pad_kwargs = {}
            numpy_pad_mode = pad_mode

        # The frame axis itself is never padded:
        array = numpy.pad(
            array, pad_width=((0, 0),) + padding, mode=numpy_pad_mode, **pad_kwargs
        )

    frame_axes = tuple(range(0, array.ndim - 1))
    for ti, shift in enumerate(shifts):
        lprint(f"Motion correcting {ti} by {shift}")
        array[ti, ...] = numpy.roll(array[ti, ...], shift=shift, axis=frame_axes)

    if crop:
        crop_slice = (slice(None),) + tuple(
            slice(ma, s - mi) for mi, ma, s in zip(min_shift, max_shift, array.shape[1:])
        )
        array = array[crop_slice]

    return array


def _measure_shifts(
    array: ArrayLike,
    reference_index: Optional[int] = None,
    center: bool = False,
    max_pixel_shift: Optional[int] = None,
    mode: str = 'com',
    sigma: float = 7,
):
    """
    Measures the integer shift of every frame of a stack.

    Parameters
    ----------
    array : ArrayLike
        Stack of frames; the first axis indexes frames.
    reference_index : Optional[int]
        Index of the reference frame; None selects the middle frame. A
        negative value registers each frame ``i`` to frame
        ``max(0, i + reference_index)`` and accumulates the shifts with a
        cumulative sum (which chains correctly for ``reference_index == -1``).
    center : bool
        If True, the rounded mean shift is subtracted from all shifts.
    max_pixel_shift : Optional[int]
        Search radius in pixels; None uses one third of the largest frame
        dimension.
    mode : str
        Peak localisation mode passed to ``_find_shift``.
    sigma : float
        Blur sigma passed to ``_find_shift``.

    Returns
    -------
    shifts : numpy.ndarray
        Integer array with one row per frame and one column per frame axis.
    mean_shift : numpy.ndarray
        The subtracted mean shift when ``center`` is True, zeros otherwise.
    """
    num_frames = len(array)
    if reference_index is None:
        reference_index = num_frames // 2
    if max_pixel_shift is None:
        max_pixel_shift = max(array.shape[1:]) // 3

    shifts = []
    for i in range(num_frames):
        if reference_index >= 0:
            reference_image = array[reference_index]
        else:
            reference_image = array[max(0, i + reference_index)]

        shift, _ = _find_shift(
            array[i],
            reference_image,
            max_pixel_shift=max_pixel_shift,
            mode=mode,
            sigma=sigma,
        )
        shifts.append(shift)
        lprint(f"Measured shift of {shift} for image {i}.")

    shifts = numpy.array(shifts)

    if reference_index < 0:
        # Frame-to-earlier-frame shifts are chained into absolute shifts:
        shifts = numpy.cumsum(shifts, axis=0)

    # Center shifts:
    if center:
        mean_shift = numpy.round(numpy.mean(shifts, axis=0)).astype(int)
        shifts = shifts - mean_shift
    else:
        mean_shift = numpy.array((0,) * (array.ndim - 1))

    # Convert to int (numpy.int was removed in NumPy 1.24):
    shifts = numpy.round(shifts).astype(int)

    return shifts, mean_shift


def _find_shift(a, b, max_pixel_shift: int = 64, mode: str = 'com', sigma: float = 7):
    """
    Estimates the translation of image ``a`` relative to reference image ``b``.

    Both images are Gaussian-blurred and phase-correlated; the correlogram is
    blurred again and thresholded at a noise floor (99.9th percentile of its
    central region, i.e. of shifts beyond the search radius, assumed to hold
    no true peak). The peak is then located within ``max_pixel_shift`` of zero.

    Parameters
    ----------
    a : ArrayLike
        Image to register.
    b : ArrayLike
        Reference image with the same shape as ``a``.
    max_pixel_shift : int
        Search radius in pixels along every axis.
    mode : str
        'com': arg-max refined by a centre of mass (default, robust);
        'gfit': least-squares Gaussian fit, 2D only.
    sigma : float
        Blur sigma; also sets the centre-of-mass window radius
        (``ceil(4 * sigma)``).

    Returns
    -------
    shift : numpy.ndarray
        Signed, sub-pixel shift per axis, such that ``a`` is approximately
        ``b`` rolled by ``shift``.
    correlation : numpy.ndarray
        Thresholded correlogram restricted to the search window; zero shift
        lies at index ``max_pixel_shift`` along every axis.

    Raises
    ------
    ValueError
        If ``mode`` is unknown, or 'gfit' is used on non-2D images.
    """
    # Basic idea: We just need to low-pass filter the heck of it, and it works.
    lprint(f"max_pixel_shift: {max_pixel_shift}, mode: {mode}, sigma: {sigma}")

    # First we blur the input images:
    a = _fast_denoise(a, sigma=sigma)
    b = _fast_denoise(b, sigma=sigma)

    # We compute the phase correlation:
    raw_correlation = _phase_correlation(a, b)

    # We denoise the correlogram itself again:
    correlation = _fast_denoise(raw_correlation, sigma=sigma)

    # We estimate the noise floor of the correlation from its central region:
    empty_region_slice = tuple(
        slice(min(s // 2 - 1, max_pixel_shift), -min(s // 2 - 1, max_pixel_shift))
        for s in correlation.shape
    )
    noise_floor_level = numpy.percentile(correlation[empty_region_slice], q=99.9)

    # We use that floor to clip anything below (values are >= 0 afterwards):
    correlation = correlation.clip(noise_floor_level, math.inf) - noise_floor_level

    # We roll the array and crop it to restrict ourself to the search region;
    # after this, zero shift sits at index max_pixel_shift along every axis:
    correlation = numpy.roll(
        correlation, shift=max_pixel_shift, axis=tuple(range(a.ndim))
    )
    correlation = correlation[(slice(0, 2 * max_pixel_shift),) * a.ndim]

    if mode == 'gfit':
        # Least-squares fit of an isotropic Gaussian to the correlogram.
        # Slower and in practice less reliable than 'com'.
        if correlation.ndim != 2:
            raise ValueError(
                f"Mode 'gfit' only supports 2D images, got {correlation.ndim}D."
            )

        def gaussian(x, sx=0, sy=0, b=1, a=1):
            return a * numpy.exp(-b * ((x[0] - sx) ** 2 + (x[1] - sy) ** 2))

        rows, cols = correlation.shape
        xx, yy = numpy.meshgrid(numpy.arange(rows), numpy.arange(cols), indexing='ij')
        xdata = numpy.stack([xx, yy]).reshape(2, -1)
        ydata = correlation.reshape(-1)

        # The fit works in correlogram-index coordinates: start at the arg-max
        # and keep the centre inside the search window.
        peak = numpy.unravel_index(numpy.argmax(correlation, axis=None), correlation.shape)
        popt, _ = curve_fit(
            gaussian,
            xdata,
            ydata,
            method='trf',
            p0=(peak[0], peak[1], 1, 1),
            bounds=(
                [0, 0, 0, 0],
                [rows - 1, cols - 1, numpy.inf, numpy.inf],
            ),
        )
        # Convert correlogram indices to signed shifts:
        shift = popt[:2] - max_pixel_shift

    elif mode == 'com':
        # This is simple, and works brilliantly, even with tons of noise.
        # We use the max as quickly computed proxy for the real center:
        rough_shift = numpy.unravel_index(
            numpy.argmax(correlation, axis=None), correlation.shape
        )

        # We crop a symmetric window around it to facilitate center-of-mass
        # estimation (clamped to the correlogram bounds):
        fine_window_radius = int(math.ceil(4 * sigma))
        window_starts = tuple(
            max(0, int(rs) - fine_window_radius) for rs in rough_shift
        )
        cropped_correlation_slice = tuple(
            slice(start, min(s, int(rs) + fine_window_radius + 1))
            for start, rs, s in zip(window_starts, rough_shift, correlation.shape)
        )
        lprint(f"Cropped correlation: {cropped_correlation_slice}")
        cropped_correlation = correlation[cropped_correlation_slice]

        if numpy.all(cropped_correlation == 0):
            # The window contains the maximum, so an all-zero window means no
            # peak above the noise floor anywhere: the arg-max is meaningless
            # (it would be index 0, i.e. -max_pixel_shift), report no motion.
            shift = numpy.zeros(correlation.ndim)
        else:
            # We take the square to squash small values far from the maximum
            # that are likely noisy...
            com = numpy.array(scipy.ndimage.center_of_mass(cropped_correlation**2))
            # The centre of mass is relative to the window origin: add the
            # actual (possibly clamped) window start to get a correlogram
            # index, then subtract the zero-shift index to get a signed shift.
            shift = numpy.array(window_starts) + com - max_pixel_shift

    else:
        raise ValueError(f"Unknown shift estimation mode: {mode}")

    return shift, correlation


def _fast_denoise(array: ArrayLike, sigma):
    """Gaussian-blurs ``array`` with periodic ('wrap') boundary handling."""
    denoised = gaussian_filter(array, sigma=sigma, mode='wrap')
    return denoised


def _phase_correlation(image, reference_image):
    """
    Returns the (real) phase correlation of ``image`` against
    ``reference_image``.

    The cross-power spectrum is normalised to unit magnitude; frequencies
    where it vanishes are set to zero instead of producing NaNs.
    """
    G_a = scipy.fft.fftn(image, workers=-1)
    G_b = scipy.fft.fftn(reference_image, workers=-1)
    R = G_a * numpy.conj(G_b)
    magnitude = numpy.absolute(R)
    R = numpy.divide(R, magnitude, out=numpy.zeros_like(R), where=magnitude > 0)
    r = scipy.fft.ifftn(R, workers=-1).real
    return r