Source code for aydin.it.transforms.deskew
import numpy
from numpy.typing import ArrayLike
from aydin.it.transforms.base import ImageTransformBase
from aydin.util.log.log import lsection, lprint
[docs]class DeskewTransform(ImageTransformBase):
"""(Integral) Stack Deskewer
Denoising is more effective if voxels carrying correlated signal are close to each other. When a stack is skewed
-- as resulting in some imaging modalities -- correlated voxels that should be close in space are far from each
other. Thus, deskewing the image before denoising is highly recommended. Importantly, the deskewing must be
'integral', meaning that it must not interpolate voxel values, which is a unadvised lossy operation. Integral
stack deskewing consists in applying an integral shear transformation to a stack. Two axes need to be specified:
the 'z'-axis and the 'skew'-axis along which shifting happens. The delta parameter controls the amount of shift
per plane - must be an integer. We automatically snap the delta value to the closest integer. Padding is supported.
Note: this only works for images with at least 3 dimensions. Does nothing
on images with less than 3 dimensions.(advanced)
"""
preprocess_description = "Deskew image" + ImageTransformBase.preprocess_description
postprocess_description = (
"Reskew image" + ImageTransformBase.postprocess_description
)
postprocess_supported = True
postprocess_recommended = True
def __init__(
self,
delta: float = 0,
z_axis: int = 0,
skew_axis: int = 1,
pad: bool = True,
priority: float = 0.4,
**kwargs,
):
"""
Constructs a stack deskewer
Parameters
----------
delta : float
How much shifting from one plane to the next
z_axis : int
Axis for which the amount of shift depends upon.
skew_axis : int
Axis over which the image is shifted.
pad : bool
True for padding before rolling, this is useful
because normal padding is rarely enough.
priority : float
The priority is a value within [0,1] used to determine the order in
which to apply the pre- and post-processing transforms. Transforms
are sorted and applied in ascending order during preprocesing and in
the reverse, descending, order during post-processing.
"""
super().__init__(priority=priority, **kwargs)
self.delta = int(round(delta))
self.z_axis = z_axis
self.skew_axis = skew_axis
self.pad = pad
lprint(f"Instanciating: {self}")
# We exclude certain fields from saving:
def __getstate__(self):
state = self.__dict__.copy()
# nothing to exclude
return state
def __str__(self):
return (
f'{type(self).__name__} (delta={self.delta},'
f' z_axis={self.z_axis},'
f' skew_axis={self.skew_axis},'
f' pad={self.pad} )'
)
def __repr__(self):
return self.__str__()
def preprocess(self, array: ArrayLike):
with lsection(
f"Deskewing (delta={self.delta}, z_axis={self.z_axis}, skew_axis={self.skew_axis}, pad={self.pad}) array of shape: {array.shape} and dtype: {array.dtype}:"
):
if array.ndim >= 3:
return self.deskew(array)
else:
return array
def postprocess(self, array: ArrayLike):
if not self.do_postprocess:
return array
with lsection(
f"Undoing deskew for array of shape: {array.shape} and dtype: {array.dtype}:"
):
if array.ndim >= 3:
return self.reskew(array)
else:
return array
def deskew(self, array: ArrayLike, pad_mode='wrap'):
array = self._permutate(array)
array = self._skew_transform(
array, self.delta, pad=True, crop=False, pad_mode=pad_mode
)
array = self._depermutate(array)
return array
def reskew(self, array: ArrayLike):
array = self._permutate(array)
array = self._skew_transform(
array, -self.delta, pad=False, crop=True, pad_mode=''
)
array = self._depermutate(array)
return array
def _permutate(self, array: ArrayLike):
permutation = self._get_permutation(array)
array = numpy.transpose(array, axes=permutation)
return array
def _depermutate(self, array: ArrayLike):
permutation = self._get_permutation(array, inverse=True)
array = numpy.transpose(array, axes=permutation)
return array
def _get_permutation(self, array: ArrayLike, inverse=False):
permutation = (self.z_axis, self.skew_axis) + tuple(
axis
for axis in range(array.ndim)
if axis not in [self.z_axis, self.skew_axis]
)
if inverse:
permutation = numpy.argsort(permutation)
return permutation
@staticmethod
def _skew_transform(array: ArrayLike, delta, pad, crop, pad_mode='wrap'):
"""
This method assumes that the first dimension (index=0) is the z dimension,
and the second dimension (index=1) is the 'skewed' dimension.
The array can have arbitrary dimensions after that...
We also assume that the array has been properly padded so that we can 'roll' the
skewed dimension without fear or regret.
"""
num_z_planes = array.shape[0]
pad_length = abs(delta * num_z_planes)
if pad:
padding = (pad_length, 0) if delta < 0 else (0, pad_length)
array = numpy.pad(
array,
pad_width=((0, 0), padding) + ((0, 0),) * (array.ndim - 2),
mode=pad_mode,
)
else:
array = array.copy()
for zi in range(num_z_planes):
array[zi, ...] = numpy.roll(array[zi, ...], shift=delta * zi, axis=0)
if crop:
cropping = (
slice(pad_length, None, 1) if delta > 0 else slice(0, -pad_length, 1)
)
crop_slice = (slice(None), cropping) + (slice(None),) * (array.ndim - 2)
array = array[crop_slice]
return array