"""Integral stack deskew/reskew image transform (module aydin.it.transforms.deskew)."""
import numpy
from numpy.typing import ArrayLike
from aydin.it.transforms.base import ImageTransformBase
from aydin.util.log.log import lsection, lprint
class DeskewTransform(ImageTransformBase):
    """(Integral) Stack Deskewer

    Denoising is more effective if voxels carrying correlated signal are close to each other. When a stack is skewed
    -- as happens with some imaging modalities -- correlated voxels that should be close in space are far from each
    other. Thus, deskewing the image before denoising is highly recommended. Importantly, the deskewing must be
    'integral', meaning that it must not interpolate voxel values, which would be an inadvisable lossy operation.
    Integral stack deskewing consists of applying an integral shear transformation to a stack. Two axes need to be
    specified: the 'z'-axis and the 'skew'-axis along which shifting happens. The delta parameter controls the amount
    of shift per plane and must be an integer; the given value is automatically snapped to the closest integer.
    Padding is supported.

    Note: this only works for images with at least 3 dimensions. Does nothing
    on images with less than 3 dimensions.(advanced)
    """

    preprocess_description = "Deskew image" + ImageTransformBase.preprocess_description
    postprocess_description = (
        "Reskew image" + ImageTransformBase.postprocess_description
    )
    postprocess_supported = True
    postprocess_recommended = True

    def __init__(
        self,
        delta: float = 0,
        z_axis: int = 0,
        skew_axis: int = 1,
        pad: bool = True,
        priority: float = 0.4,
        **kwargs,
    ):
        """
        Constructs a stack deskewer

        Parameters
        ----------
        delta : float
            How much shifting from one plane to the next. The value is
            rounded to the closest integer.
        z_axis : int
            Axis for which the amount of shift depends upon.
            Negative values count from the last axis, as in numpy.
        skew_axis : int
            Axis over which the image is shifted. Must differ from z_axis.
            Negative values count from the last axis, as in numpy.
        pad : bool
            True for padding before rolling, this is useful
            because normal padding is rarely enough. When False, planes
            are rolled in place (wrapping around) and nothing is cropped
            when the deskew is undone.
        priority : float
            The priority is a value within [0,1] used to determine the order in
            which to apply the pre- and post-processing transforms. Transforms
            are sorted and applied in ascending order during preprocessing and in
            the reverse, descending, order during post-processing.
        """
        super().__init__(priority=priority, **kwargs)
        # Integral shear: only whole-voxel shifts are allowed.
        self.delta = int(round(delta))
        self.z_axis = z_axis
        self.skew_axis = skew_axis
        self.pad = pad
        lprint(f"Instanciating: {self}")

    # We exclude certain fields from saving:
    def __getstate__(self):
        state = self.__dict__.copy()
        # nothing to exclude
        return state

    def __str__(self):
        return (
            f'{type(self).__name__} (delta={self.delta},'
            f' z_axis={self.z_axis},'
            f' skew_axis={self.skew_axis},'
            f' pad={self.pad} )'
        )

    def __repr__(self):
        return self.__str__()

    def preprocess(self, array: ArrayLike):
        """Deskew ``array`` if it has at least 3 dimensions; otherwise return it unchanged."""
        with lsection(
            f"Deskewing (delta={self.delta}, z_axis={self.z_axis}, skew_axis={self.skew_axis}, pad={self.pad}) array of shape: {array.shape} and dtype: {array.dtype}:"
        ):
            if array.ndim >= 3:
                return self.deskew(array)
            else:
                return array

    def postprocess(self, array: ArrayLike):
        """Undo the deskew (if postprocessing is enabled and ``array`` has at least 3 dimensions)."""
        if not self.do_postprocess:
            return array
        with lsection(
            f"Undoing deskew for array of shape: {array.shape} and dtype: {array.dtype}:"
        ):
            if array.ndim >= 3:
                return self.reskew(array)
            else:
                return array

    def deskew(self, array: ArrayLike, pad_mode='wrap'):
        """Shear ``array`` by ``self.delta`` voxels per z-plane along the skew axis.

        When ``self.pad`` is True the skew axis is first padded (using
        ``pad_mode``, passed to ``numpy.pad``) so the shift does not wrap
        voxels around; the skew axis of the result is then longer than the
        input's.
        """
        array = self._permutate(array)
        array = self._skew_transform(
            array, self.delta, pad=self.pad, crop=False, pad_mode=pad_mode
        )
        array = self._depermutate(array)
        return array

    def reskew(self, array: ArrayLike):
        """Invert :meth:`deskew`: shift back by ``-self.delta`` per plane and,
        if padding was applied (``self.pad``), crop it away again."""
        array = self._permutate(array)
        array = self._skew_transform(
            array, -self.delta, pad=False, crop=self.pad, pad_mode=''
        )
        array = self._depermutate(array)
        return array

    def _permutate(self, array: ArrayLike):
        """Move the z axis to position 0 and the skew axis to position 1."""
        permutation = self._get_permutation(array)
        array = numpy.transpose(array, axes=permutation)
        return array

    def _depermutate(self, array: ArrayLike):
        """Undo :meth:`_permutate`, restoring the original axis order."""
        permutation = self._get_permutation(array, inverse=True)
        array = numpy.transpose(array, axes=permutation)
        return array

    def _get_permutation(self, array: ArrayLike, inverse=False):
        """Return the axis permutation (z, skew, remaining axes in order), or its inverse.

        Raises
        ------
        ValueError
            If z_axis and skew_axis refer to the same axis of ``array``.
        """
        ndim = array.ndim
        # Accept negative axis indices, numpy-style (e.g. -1 is the last axis):
        z_axis = self.z_axis + ndim if self.z_axis < 0 else self.z_axis
        skew_axis = self.skew_axis + ndim if self.skew_axis < 0 else self.skew_axis
        if z_axis == skew_axis:
            raise ValueError(
                f"z_axis ({self.z_axis}) and skew_axis ({self.skew_axis}) "
                f"must refer to different axes of a {ndim}-dimensional array"
            )
        permutation = (z_axis, skew_axis) + tuple(
            axis for axis in range(ndim) if axis not in (z_axis, skew_axis)
        )
        if inverse:
            # argsort of a permutation is its inverse permutation.
            permutation = numpy.argsort(permutation)
        return permutation

    @staticmethod
    def _skew_transform(array: ArrayLike, delta, pad, crop, pad_mode='wrap'):
        """
        Roll each z-plane ``zi`` of ``array`` by ``delta * zi`` voxels along the skew axis.

        This method assumes that the first dimension (index=0) is the z dimension,
        and the second dimension (index=1) is the 'skewed' dimension.
        The array can have arbitrary dimensions after that...

        Parameters
        ----------
        array : ArrayLike
            Array with z as axis 0 and the skew axis as axis 1.
        delta : int
            Shift (in voxels) per z-plane; the sign gives the direction.
        pad : bool
            If True, pad the skew axis by ``abs(delta * num_z_planes)`` voxels
            on the side the planes shift towards, so that rolling does not
            wrap content around. If False, the array is copied and rolled in place.
        crop : bool
            If True, remove ``abs(delta * num_z_planes)`` voxels from the skew
            axis on the side opposite to the shift direction, i.e. undo the
            padding that a previous call with ``-delta`` and ``pad=True`` added.
        pad_mode : str
            ``mode`` passed to ``numpy.pad`` when ``pad`` is True.

        Returns
        -------
        A new array; the input is not modified.
        """
        num_z_planes = array.shape[0]
        pad_length = abs(delta * num_z_planes)
        if pad:
            padding = (pad_length, 0) if delta < 0 else (0, pad_length)
            array = numpy.pad(
                array,
                pad_width=((0, 0), padding) + ((0, 0),) * (array.ndim - 2),
                mode=pad_mode,
            )
        else:
            array = array.copy()
        for zi in range(num_z_planes):
            array[zi, ...] = numpy.roll(array[zi, ...], shift=delta * zi, axis=0)
        if crop:
            skew_length = array.shape[1]
            # Use an explicit stop index: 'slice(0, -pad_length)' would be the
            # empty slice(0, 0) when pad_length == 0 (e.g. delta == 0).
            cropping = (
                slice(pad_length, None, 1)
                if delta > 0
                else slice(0, skew_length - pad_length, 1)
            )
            crop_slice = (slice(None), cropping) + (slice(None),) * (array.ndim - 2)
            array = array[crop_slice]
        return array