from typing import Optional, Tuple, List
import numpy
from numpy import ndarray
from aydin.features.base import FeatureGeneratorBase
from aydin.features.groups.base import FeatureGroupBase
from aydin.util.log.log import lprint, lsection
class ExtensibleFeatureGenerator(FeatureGeneratorBase):
    """
    Extensible Feature Generator class.

    Holds an ordered list of feature groups and, for every batch and every
    channel of an image, asks each group in turn to compute its features
    into a shared feature array.
    """

    def __init__(self):
        """
        Constructs an extensible feature generator
        """
        # Calls super constructor:
        super().__init__()

        # This list holds the feature groups, in the order in which their
        # features are written along the feature axis:
        self.features_group_list = []

    def _load_internals(self, path: str):
        """
        Loads internal state from the given path.

        This implementation does nothing.

        Parameters
        ----------
        path : str
            path to load from (unused)
        """

    def add_feature_group(self, feature_group: FeatureGroupBase, *args, **kwargs):
        """
        Adds a feature group to this feature generator.

        Parameters
        ----------
        feature_group : FeatureGroupBase
            feature group, appended after the groups already registered
        args
            additional arguments (accepted but currently ignored)
        kwargs
            additional keyword arguments (accepted but currently ignored)
        """
        self.features_group_list.append(feature_group)

    def clear_features(self):
        """
        Clears the features group list
        """
        # Rebinds to a fresh list rather than clearing in place, so that any
        # outside reference to the previous list is left untouched:
        self.features_group_list = []

    def get_num_features(self, ndim: int) -> int:
        """
        Returns the total number of features produced by all registered
        feature groups for the given number of dimensions.

        Parameters
        ----------
        ndim : int
            number of dimensions

        Returns
        -------
        nb_features : int
        """
        return sum(
            feature_group.num_features(ndim)
            for feature_group in self.features_group_list
        )

    def get_receptive_field_radius(self) -> int:
        """
        Returns the receptive field radius in pixels

        Returns
        -------
        result : int
            receptive field radius in pixels: the largest radius among the
            registered feature groups, and 0 if there are none.
        """
        # The leading 0 keeps the result non-negative and covers the case of
        # an empty group list:
        return max(
            [
                0,
                *(
                    feature_group.receptive_field_radius
                    for feature_group in self.features_group_list
                ),
            ]
        )

    def compute(
        self,
        image,
        exclude_center_feature: bool = False,
        exclude_center_value: bool = False,
        features: ndarray = None,
        feature_last_dim: bool = True,
        passthrough_channels: Optional[Tuple[bool]] = None,
        num_reserved_features: int = 0,
        excluded_voxels: Optional[List[Tuple[int]]] = None,
        spatial_feature_offset: Optional[Tuple[float, ...]] = None,
        spatial_feature_scale: Optional[Tuple[float, ...]] = None,
    ):
        """
        Computes the features given an image. The image is indexed as
        (batch, channel, *spatial). Features are written into an array
        indexed as (feature, batch, *spatial); if 'feature_last_dim' is True
        the feature axis is moved to the last position before returning.

        Along the feature axis, and for each batch, channels are processed in
        order: a pass-through channel contributes one feature (the channel
        itself), any other channel contributes the features of all feature
        groups. The 'num_reserved_features' trailing entries are not written
        by this method.

        Parameters
        ----------
        image : numpy.ndarray
            image for which features are computed
        exclude_center_feature : bool
            If true, features that use the image
            patch's center pixel are entirely excluded from the set of computed
            features.
            NOTE(review): this argument is not used by this implementation.
        exclude_center_value : bool
            If true, the center pixel is never used
            to compute any feature, different feature generation algorithms can
            take different approaches to achieve that.
            Can also be a tuple, in which case each entry corresponds to a
            channel. For a channel where this is true, the feature groups are
            given 'excluded_voxels' plus the center voxel as voxels to
            exclude; otherwise they are given an empty list.
        features : ndarray
            NOTE(review): this argument is currently ignored, the feature
            array is always allocated internally with
            'create_feature_array'.
        feature_last_dim : bool
            If True the last dimension of the feature
            array is the feature dimension, if False then it is the first
            dimension.
        passthrough_channels : Optional[Tuple[bool]]
            Optional tuple of booleans that specify which channels are
            'pass-through' channels, i.e. channels that are not featurised
            and directly used as features. By default no channel is
            pass-through.
        num_reserved_features : int
            Number of features to be left as blank,
            useful when adding features separately.
        excluded_voxels : Optional[List[Tuple[int]]]
            List of pixel coordinates -- expressed as tuple of ints relative to
            the central pixel -- that will be excluded from any computed features.
            This is used for implementing 'extended blind-spot' N2S denoising
            approaches. Only taken into account for channels for which
            'exclude_center_value' is true.
        spatial_feature_offset: Optional[Tuple[float, ...]]
            Offset vector to be applied (added) to the spatial features
            (if used). Passed as 'offset' to each feature group.
        spatial_feature_scale: Optional[Tuple[float, ...]]
            Scale vector to be applied (multiplied) to the spatial features
            (if used). Passed as 'scale' to each feature group.

        Returns
        -------
        feature array : numpy.ndarray
        """
        with lsection('Computing features'):

            # some important numbers:
            num_dims = len(image.shape)
            num_spatiotemp_dim = num_dims - 2
            num_batches = image.shape[0]
            num_channels = image.shape[1]
            num_features = self.get_num_features(num_spatiotemp_dim)

            # exclude_center_value can be a tuple, in that case each entry
            # corresponds to a channel:
            if not isinstance(exclude_center_value, tuple):
                exclude_center_value = (exclude_center_value,) * num_channels

            # Fills in the default values for passthrough channels:
            # False ==> by default channels are not passthrough.
            if passthrough_channels is None:
                passthrough_channels = (False,) * num_channels

            # Computes the number of features, taking into account:
            # passthrough_channels, channels, and reserved_features:
            num_passthrough_channels = sum(1 for p in passthrough_channels if p)
            num_normal_features = num_features * (
                num_channels - num_passthrough_channels
            )
            num_total_features = (
                num_normal_features + num_passthrough_channels + num_reserved_features
            )

            # Creates feature array that will hold the final result:
            features = self.create_feature_array(image, num_total_features)

            # Values that do not change across batches, channels and groups:
            if excluded_voxels is None:
                excluded_voxels = []
            center_voxel = (0,) * num_spatiotemp_dim
            spatial_slices = (slice(None),) * num_spatiotemp_dim

            # We iterate over batches:
            for batch_index in range(num_batches):
                with lsection(
                    f'Computing features for batch: {batch_index + 1}/{num_batches}'
                ):
                    # Position of the next feature to write for this batch:
                    feature_pointer = 0

                    # We iterate over channels:
                    for channel_index in range(num_channels):
                        with lsection(
                            f'Computing features for channel: {channel_index + 1}/{num_channels}'
                        ):
                            # We collect the 'exclude_center_value' for the
                            # current channel:
                            exclude_center_value_for_channel = exclude_center_value[
                                channel_index
                            ]
                            lprint(
                                f'Excluding center value for channel: {exclude_center_value_for_channel}'
                            )

                            # Image batch slice:
                            image_slice = (batch_index, channel_index, *spatial_slices)
                            lprint(f'Image slice: {image_slice}')

                            if passthrough_channels[channel_index]:
                                # A passthrough channel is simply fed directly
                                # as a feature:
                                lprint(
                                    f'Adding passthrough channel feature for channel index: {channel_index}'
                                )
                                batch_feature_slice = (
                                    slice(feature_pointer, feature_pointer + 1, 1),
                                    batch_index,
                                    *spatial_slices,
                                )
                                features[batch_feature_slice] = image[
                                    image_slice
                                ].astype(self.dtype, copy=False)
                                feature_pointer += 1
                                continue

                            single_image = image[image_slice]

                            # Feature batch slice:
                            batch_feature_slice = (
                                slice(None),
                                batch_index,
                                *spatial_slices,
                            )
                            lprint(
                                f'Feature slice for batch and channel: {batch_feature_slice}'
                            )
                            batch_channel_features = features[batch_feature_slice]

                            # Excluded voxels for this channel: only when the
                            # center value is excluded do we exclude anything,
                            # and then the center voxel is always included:
                            channel_excluded_voxels = []
                            if exclude_center_value_for_channel:
                                channel_excluded_voxels.extend(excluded_voxels)
                                if center_voxel not in channel_excluded_voxels:
                                    channel_excluded_voxels.append(center_voxel)

                            for feature_group in self.features_group_list:
                                lprint(f'Computing feature {feature_group}, ')

                                # number of features in group:
                                num_group_features = feature_group.num_features(
                                    num_spatiotemp_dim
                                )

                                # We prepare feature generation for the group
                                # by setting the image, any computation that
                                # can be factored should happen now. Each
                                # group gets its own copy of the list:
                                feature_group.prepare(
                                    single_image,
                                    excluded_voxels=list(channel_excluded_voxels),
                                    offset=spatial_feature_offset,
                                    scale=spatial_feature_scale,
                                )

                                for index in range(num_group_features):
                                    # The group writes the feature directly
                                    # into this slice of the feature array:
                                    feature = batch_channel_features[feature_pointer]

                                    # Computing the feature:
                                    feature_group.compute_feature(index, feature)

                                    # Increments feature index:
                                    feature_pointer += 1

                                feature_group.finish()

            # Features are indexed by axis 0 at this point.
            # The following line creates a view of the array
            # in which the features are indexed by the last dimension instead:
            if feature_last_dim:
                lprint('Move feature axis to the last axis ...')
                features = numpy.moveaxis(features, 0, -1)

            return features