import random
from os.path import join
from typing import Optional, Union, List, Tuple
from deprecated import deprecated
import keras.models
import numpy
from tensorflow.python.eager.context import device
from aydin.io.folders import get_temp_folder
from aydin.it.base import ImageTranslatorBase
from aydin.nn.tf.models.jinet import JINetModel
from aydin.nn.tf.models.unet import UNetModel
from aydin.nn.tf.models.utils.image_tile import tile_target_images, tile_input_images
from aydin.nn.tf.models.utils.unet_patch_size import (
get_ideal_patch_size,
post_tiling_patch_size_validation,
)
from aydin.nn.tf.util.callbacks import (
EarlyStopping,
ReduceLROnPlateau,
StopCenterGradient3D,
StopCenterGradient2D,
ModelCheckpoint,
)
from aydin.nn.tf.util.random_sample_patches import random_sample_patches
from aydin.util.log.log import lsection, lprint
from aydin.util.tf.device import get_best_device_name
@deprecated(
    "All the Tensorflow related code and dependencies are deprecated and will be removed by v0.1.16"
)
class ImageTranslatorCNN(ImageTranslatorBase):
    """
    Convolutional Neural Network (CNN) based Image Translator<br>
    """

    # Verbosity passed to the Keras fit/predict calls.
    verbose = 0

    def __init__(
        self,
        training_architecture: str = 'random',
        model_architecture: str = 'jinet',
        batch_size: int = 32,
        nb_unet_levels: int = 3,
        batch_norm: str = "instance",
        activation: str = 'ReLU',
        patch_size: Optional[Union[int, List[int]]] = None,
        total_num_patches: Optional[int] = None,
        adoption_rate: float = 0.5,
        mask_size: int = 5,
        random_mask_ratio: float = 0.1,
        max_epochs: int = 30,
        patience: int = 4,
        learn_rate: float = 0.01,
        blind_spots: Optional[Union[str, List[Tuple[int]]]] = None,
        tile_min_margin: int = 8,
        tile_max_margin: Optional[int] = None,
        max_memory_usage_ratio: float = 0.9,
        max_tiling_overhead: float = 0.1,
    ):
        """
        Parameters
        ----------
        training_architecture : str
            'shiftconv' or 'checkerbox' or 'random' or 'checkran' architecture
            (advanced)
        model_architecture : str
            'unet' or 'jinet'
        batch_size : int
            Batch size for training
        nb_unet_levels : int
            Number of U-Net levels (layers)
            (advanced)
        batch_norm
            Type of batch normalization (e.g. batch, instance)
            (advanced)
        activation :
            Name of the activation function (e.g. 'ReLU')
            (advanced)
        patch_size : int
            Size for patch sample e.g. 64 for (64, 64) or (64, 64, 64)
            (advanced)
        total_num_patches
            Total number of patches for training
            (advanced)
        adoption_rate
            Fraction of the randomly sampled patches that are used for
            training; the rest are discarded
        mask_size
            Mask size for the masking architectures
            (advanced)
        random_mask_ratio
            Probability of masked pixels in random masking approach
            (advanced)
        max_epochs : int
            Maximum number of epochs allowed
        patience : int
            Patience for EarlyStop or ReducedLR to be triggered (early
            stopping uses twice this value)
        learn_rate : float
            Initial learn rate
        blind_spots : Optional[Union[str,List[Tuple[int]]]]
            List of voxel coordinates (relative to receptive field center) to
            be included in the blind-spot. For example, you can enter:
            '<axis>#<radius>' to extend the blindspot along a given axis by a
            certain radius. For example, for an image of dimension 3, 'x#1'
            extends the blind spot to cover voxels of relative coordinates:
            (0,0,0),(0,1,0), and (0,-1,0). If you want to extend both in x and y,
            enter: 'x#1,y#1' by comma separating between axis. To specify the
            axis you can use integer indices, or 'x', 'y', 'z', and 't'
            (dimension order is tzyx with x being always the last dimension).
            If None is passed then the blindspots are automatically discovered
            from the image content. If 'center' is passed then no additional
            blind-spots are considered and only the default single center
            voxel blind-spot is used.
        tile_min_margin : int
            Minimal width of tile margin in voxels.
            (advanced)
        tile_max_margin : Optional[int]
            Maximal width of tile margin in voxels.
            (advanced)
        max_memory_usage_ratio : float
            Maximum allowed memory load, value must be within [0, 1]. Default is 90%.
            (advanced)
        max_tiling_overhead : float
            Maximum allowed margin overhead during tiling. Default is 10%.
            (advanced)
        """
        super().__init__(
            blind_spots=blind_spots,
            tile_min_margin=tile_min_margin,
            tile_max_margin=tile_max_margin,
            max_memory_usage_ratio=max_memory_usage_ratio,
            max_tiling_overhead=max_tiling_overhead,
        )

        # Settings used by both the U-Net and the JINet models:
        self.model_architecture = model_architecture
        self.batch_size = batch_size
        self.batch_norm = batch_norm
        self.activation_fun = activation
        self.patch_size = patch_size
        self.total_num_patches = total_num_patches
        self.adoption_rate = adoption_rate
        self.max_epochs = max_epochs
        self.patience = patience
        self.learn_rate = learn_rate
        self.model = None  # training model
        self.infmodel = None  # inference model
        self.EStop_patience = self.patience * 2
        self.ReduceLR_patience = self.patience
        self.checkpoint = None
        self.input_dim = None
        self.stop_fitting = False
        self.validation_images = None
        self.validation_markers = None
        # If False, individual pixels (not patches) are used for validation.
        self._create_patches_for_validation = False

        # U-Net specific settings:
        self.mask_size = mask_size
        self.random_mask_ratio = random_mask_ratio
        self.nb_unet_levels = nb_unet_levels
        self.training_architecture = training_architecture

        with lsection("CNN image translator"):
            lprint("training architecture: ", self.training_architecture)
            lprint("number of layers: ", self.nb_unet_levels)
            lprint("batch norm: ", self.batch_norm)
            lprint("mask size: ", self.mask_size)
            lprint("max_epochs", self.max_epochs)
            lprint("verbose: ", self.verbose)

    @property
    def model_class(self):
        """
        Model class matching `model_architecture`.

        Raises
        ------
        ValueError
            If `model_architecture` is neither 'jinet' nor 'unet'.
        """
        if self.model_architecture == "jinet":
            return JINetModel
        elif self.model_architecture == "unet":
            return UNetModel
        else:
            raise ValueError("Unknown model architecture")

    def save(self, path: str):
        """
        Saves a 'all-batteries-included' image translation model at a given path (folder).

        Parameters
        ----------
        path : str
            path to save to

        Returns
        -------
        Whatever ImageTranslatorBase.save returns for this translator.
        """
        with lsection(f"Saving 'CNN' image translator to {path}"):
            frozen = super().save(path)
            self.save_cnn(path)

        return frozen

    def save_cnn(self, path: str):
        """
        Saves the training and inference Keras models (when they exist) into
        the 'tf_model' and 'tf_inf_model' sub-paths of `path`.
        """
        if self.model is not None:
            # Keras model.save writes the whole model (architecture + weights):
            self.model.save(join(path, "tf_model"))
        else:
            lprint("There is no model to save yet.")

        if self.infmodel is not None:
            self.infmodel.save(join(path, "tf_inf_model"))
        else:
            lprint("self.infmodel is None, no inference model will be saved.")

    def __getstate__(self):
        """
        Returns the instance state to serialize.

        Keras models, callbacks, training history and validation images are
        excluded (the models are written separately by `save_cnn`). Entries
        that only exist after training are skipped when absent, so an
        untrained translator can be serialized as well.
        """
        state = self.__dict__.copy()
        # exclude fields below that should/cannot be saved properly:
        for key in (
            'early_stopping',
            'reduce_learning_rate',
            'reduce_learning_rate1',
            'checkpoint',
            'model',
            'loss_history',
            'infmodel',
            'validation_images',
        ):
            state.pop(key, None)
        return state

    def _load_internals(self, path: str):
        """
        Restores the Keras models written by `save_cnn` from `path`.

        A model whose directory does not exist (because it did not exist when
        the translator was saved) is restored as None; a missing inference
        model is rebuilt from the training model on the next translation.
        """
        import os

        with lsection(f"Loading 'cnn' image translator from {path}"):
            model_path = join(path, "tf_model")
            inf_model_path = join(path, "tf_inf_model")

            if os.path.exists(model_path):
                self.model = keras.models.load_model(model_path)
            else:
                lprint(f"No model found at {model_path}.")
                self.model = None

            if os.path.exists(inf_model_path):
                self.infmodel = keras.models.load_model(inf_model_path)
            else:
                lprint(f"No inference model found at {inf_model_path}.")
                self.infmodel = None

    def stop_training(self):
        """Stops currently running training within the instance by turning the flag
        true for early stop callback.
        """
        self.stop_fitting = True

    def _train(
        self,
        input_image,
        target_image,
        train_valid_ratio=0.1,
        callback_period=3,
        jinv=False,
    ):
        """
        Trains the CNN on the given input/target images.

        Parameters
        ----------
        input_image : numpy.ndarray
            Input image laid out as (B, C, *spatial); the channel axis is moved
            last before training.
        target_image : numpy.ndarray
            Target image, same layout as `input_image` (only re-laid out when
            training is not self-supervised).
        train_valid_ratio : float
            Fraction of the data used for validation.
        callback_period : int
            Not used by this implementation.
        jinv : bool
            Not used by this implementation.

        Raises
        ------
        ValueError
            If the image does not have 2 or 3 spacetime dimensions, if a given
            U-Net patch size is invalid, or if `total_num_patches` is smaller
            than the batch size.
        """
        with device(get_best_device_name()):
            # Move the channel axis last: (B, C, *spatial) -> (B, *spatial, C)
            input_image = numpy.moveaxis(input_image, 1, input_image.ndim - 1)
            if not self.self_supervised:
                target_image = numpy.moveaxis(target_image, 1, target_image.ndim - 1)

            self.spacetime_ndim = input_image.ndim - 2
            if self.spacetime_ndim not in [2, 3]:
                raise ValueError(
                    "Number of spacetime dimensions have to be either 2 or 3."
                )

            self.input_dim = input_image.shape[1:]

            # Some model/architecture combinations are trained with batch size 1:
            if (
                self.model_architecture == "unet"
                and 'shiftconv' in self.training_architecture
            ):
                self.batch_size = 1
                lprint(
                    'When patch_size is assigned under shiftconv architecture, batch_size is automatically set to 1.'
                )
            if self.model_architecture == "jinet" and self.spacetime_ndim == 3:
                self.batch_size = 1
            lprint(f"Batch size for training: {self.batch_size}")

            # Patch size: use the ideal default for the network depth, or
            # validate a user-provided value for U-Net models.
            if self.patch_size is None:
                self.patch_size = get_ideal_patch_size(
                    self.nb_unet_levels, self.training_architecture
                )
            elif 'unet' in self.model_architecture:
                # atleast_1d so that a scalar patch size can be sliced below.
                patch_size = numpy.atleast_1d(self.patch_size)
                min_tile_size = 2**self.nb_unet_levels
                if patch_size.max() < min_tile_size:
                    raise ValueError(
                        f'Tile size is too small. The largest dimension of tile size has to be >= {2 ** self.nb_unet_levels}.'
                    )
                if (patch_size[-2:] % min_tile_size != 0).any():
                    raise ValueError(
                        f'Tile sizes on XY plane have to be multiple of 2^{self.nb_unet_levels}'
                    )

            # Expand a scalar patch size to one value per spacetime dimension:
            if isinstance(self.patch_size, (int, numpy.integer)):
                self.patch_size = [self.patch_size] * self.spacetime_ndim

            # If even the smallest patch dimension exceeds the smallest spatial
            # dimension of the input, shrink the patch along that input
            # dimension to the largest even size that fits.
            spatial_dims = self.input_dim[:-1]
            if min(self.patch_size) > min(spatial_dims):
                smallest_dim = min(spatial_dims)
                self.patch_size[numpy.argsort(spatial_dims)[0]] = (
                    smallest_dim // 2 * 2
                )

            # Determine total number of patches:
            if self.total_num_patches is None:
                # Capped at 10240 patches; int() avoids float arithmetic below.
                self.total_num_patches = int(
                    min(input_image.size / numpy.prod(self.patch_size), 10240)
                )
            elif self.total_num_patches < self.batch_size:
                raise ValueError('total_num_patches has to be larger than batch_size.')
            # Round up to whole batches. Note: a value that is already a
            # multiple of batch_size still gains one extra batch.
            self.total_num_patches = (
                self.total_num_patches
                - (self.total_num_patches % self.batch_size)
                + self.batch_size
            )

            # Decide whether to use validation pixels or patches
            self._create_patches_for_validation = 1024 <= input_image.size / numpy.prod(
                self.patch_size
            )

            # Tile input and target image
            with lsection('Random patch sampling...'):
                input_patch_idx = random_sample_patches(
                    input_image,
                    self.patch_size[0],
                    self.total_num_patches,
                    self.adoption_rate,
                )
            self.total_num_patches = len(input_patch_idx)
            lprint(f'Total number of patches: {self.total_num_patches}')

            with lsection('Input image...'):
                (
                    img_train,
                    self.validation_images,
                    self.validation_markers,
                ) = tile_input_images(
                    input_image,
                    self._create_patches_for_validation,
                    input_patch_idx,
                    train_valid_ratio,
                )

            with lsection('Target image...'):
                target_image = tile_target_images(
                    img_train, target_image, input_patch_idx, self.self_supervised
                )

            post_tiling_patch_size_validation(
                img_train,
                self.nb_unet_levels,
                self.training_architecture,
                self.self_supervised,
            )

            unet_only_model_constructor_kwargs = (
                {
                    "mini_batch_size": self.batch_size,
                    "nb_unet_levels": self.nb_unet_levels,
                    "normalization": self.batch_norm,
                    "activation": self.activation_fun,
                    "supervised": not self.self_supervised,
                    "training_architecture": self.training_architecture,
                }
                if self.model_architecture == "unet"
                else {}
            )

            self.model = self.model_class(
                img_train.shape[1:],
                spacetime_ndim=self.spacetime_ndim,
                learning_rate=self.learn_rate,
                **unet_only_model_constructor_kwargs,
            )
            # Any existing inference model carries the previous model's
            # weights; drop it so _translate rebuilds it from the new model.
            self.infmodel = None

            with lsection('CNN model summary:'):
                lprint(f'Model architecture: {self.model_architecture}')
                if self.model_architecture == 'unet':
                    lprint(f'Train scheme: {self.training_architecture}')
                    lprint(f'Number of layers: {self.nb_unet_levels}')
                lprint(
                    f'Number of parameters in the model: {self.model.count_params()}'
                )
                lprint(f'Batch normalization: {self.batch_norm}')
                lprint(f'Training input size: {img_train.shape[1:]}')

            # End of train function and beginning of _train from legacy implementation
            input_image = img_train

            with lsection(
                f"Training image translator from image of shape {input_image.shape} to image of shape {target_image.shape}:"
            ):
                # JINet gets extra patience. Kept in locals (instead of being
                # added to the attributes) so repeated training runs do not
                # keep increasing the patience.
                estop_patience = self.EStop_patience
                reduce_lr_patience = self.ReduceLR_patience
                if 'jinet' in self.model_architecture:
                    estop_patience += 10
                    reduce_lr_patience += 20

                # Early stopping patience:
                lprint(f"Early stopping patience: {estop_patience}")

                # Effective LR patience:
                lprint(f"Effective LR patience: {reduce_lr_patience}")
                lprint(f'Batch size: {self.batch_size}')

                # Early stopping callback:
                self.early_stopping = EarlyStopping(
                    self, patience=estop_patience, restore_best_weights=True
                )

                # Reduce LR on plateau:
                self.reduce_learning_rate = ReduceLROnPlateau(
                    verbose=1, patience=reduce_lr_patience, min_lr=1e-8, min_delta=0
                )

                self.reduce_learning_rate1 = ReduceLROnPlateau(
                    verbose=1,
                    patience=reduce_lr_patience,
                    min_lr=self.learn_rate * 0.01,
                    min_delta=0,
                )

                # 'checkpoint' is absent on instances restored via __getstate__.
                if getattr(self, 'checkpoint', None) is None:
                    self.model_file_path = join(
                        get_temp_folder(),
                        f"aydin_cnn_keras_model_file_{random.randint(0, 10**16)}.hdf5",
                    )
                    lprint(f"Model will be saved at: {self.model_file_path}")
                    self.checkpoint = ModelCheckpoint(
                        self.model_file_path, verbose=1, save_best_only=True
                    )

                callbacks = [
                    self.checkpoint,
                    self.early_stopping,
                    self.reduce_learning_rate,
                ]
                if 'checkran' in self.training_architecture:
                    callbacks.append(self.reduce_learning_rate1)

                # NOTE(review): assumes self.blind_spots is a list of coordinate
                # tuples (or contains 'center') at this point; a None value
                # would fail on the `in` test below — confirm that the base
                # class resolves it before _train is called.
                if 'center' in self.blind_spots:
                    self.blind_spots = []
                center_pixel_coord = (0,) * self.spacetime_ndim
                if center_pixel_coord in self.blind_spots:
                    self.blind_spots.remove((0,) * self.spacetime_ndim)

                if self.spacetime_ndim == 2:
                    stop_center_gradient = StopCenterGradient2D(self.blind_spots)
                elif self.spacetime_ndim == 3:
                    stop_center_gradient = StopCenterGradient3D(self.blind_spots)
                if 'jinet' in self.model_architecture and self.self_supervised:
                    callbacks.append(stop_center_gradient)

                lprint("Training now...")
                if 'jinet' in self.model_architecture:
                    self.loss_history = self.model.fit(
                        input_image=input_image,
                        target_image=target_image,
                        max_epochs=self.max_epochs,
                        callbacks=callbacks,
                        verbose=self.verbose,
                        batch_size=self.batch_size,
                        total_num_patches=self.total_num_patches,
                        img_val=self.validation_images,
                        create_patches_for_validation=self._create_patches_for_validation,
                        train_valid_ratio=train_valid_ratio,
                    )
                else:
                    self.loss_history = self.model.fit(
                        input_image=input_image,
                        target_image=target_image,
                        max_epochs=self.max_epochs,
                        callbacks=callbacks,
                        verbose=self.verbose,
                        batch_size=self.batch_size,
                        total_num_patches=self.total_num_patches,
                        img_val=self.validation_images,
                        val_marker=self.validation_markers,
                        create_patches_for_validation=self._create_patches_for_validation,
                        train_valid_ratio=train_valid_ratio,
                        random_mask_ratio=self.random_mask_ratio,
                        patch_size=self.patch_size,
                        mask_size=self.mask_size,
                        ReduceLR_patience=reduce_lr_patience,
                    )

                # An inference model built during fitting would hold
                # intermediate weights; force a rebuild from the final model.
                self.infmodel = None

    @staticmethod
    def _pad_width_from_halves(half_pads):
        """
        Builds a numpy.pad width list for a (B, *spatial, C) array: for each
        spatial axis pad ceil(half) before and floor(half) after; the batch
        and channel axes are not padded.
        """
        return (
            [[0, 0]]
            + [
                [numpy.ceil(half).astype(int), numpy.floor(half).astype(int)]
                for half in half_pads
            ]
            + [[0, 0]]
        )

    @staticmethod
    def _crop_padding(image, pad_width):
        """Removes the padding described by a numpy.pad `pad_width` list."""
        # `-after or None` maps a zero trailing pad to an open-ended slice.
        return image[
            tuple(slice(before, -after or None) for before, after in pad_width)
        ]

    def _translate(self, input_image, image_slice=None, whole_image_shape=None):
        """
        Applies the trained model to `input_image`.

        The image is edge-padded (to a cube when every pair of adjacent
        spatial sizes differs, then to multiples of 2**nb_unet_levels), passed
        through the inference model and cropped back to its original extent.

        Parameters
        ----------
        input_image : numpy.ndarray
            Image laid out as (B, C, *spatial).
        image_slice, whole_image_shape
            Not used by this implementation.

        Returns
        -------
        numpy.ndarray
            Translated image in the (B, C, *spatial) layout.
        """
        with device(get_best_device_name()):
            # Change dimensions to (B, space, C)
            input_image = numpy.moveaxis(input_image, 1, input_image.ndim - 1)

            reshaped_for_cube = False
            reshaped_for_model = False
            spatial_shape = numpy.array(input_image.shape[1:-1])

            # NOTE(review): `.min() != 0` pads only when *all* adjacent spatial
            # sizes differ, e.g. (64, 64, 128) is not padded — confirm intent.
            if abs(numpy.diff(spatial_shape)).min() != 0:
                reshaped_for_cube = True
                input_shape_max = numpy.ones(spatial_shape.shape) * spatial_shape.max()
                pad_square = (input_shape_max - spatial_shape) / 2
                pad_width1 = self._pad_width_from_halves(pad_square)
                input_image = numpy.pad(input_image, pad_width1, 'edge')
                spatial_shape = numpy.array(input_image.shape[1:-1])

            # Pad so every spatial size is a multiple of 2**nb_unet_levels:
            size_multiple = 2**self.nb_unet_levels
            if not (spatial_shape % size_multiple == 0).all():
                reshaped_for_model = True
                pad_width0 = (size_multiple - (spatial_shape % size_multiple)) / 2
                pad_width2 = self._pad_width_from_halves(pad_width0)
                input_image = numpy.pad(input_image, pad_width2, 'edge')

            if self.infmodel is None:
                # Change the batch_size in split layer or input dimensions accordingly
                kwargs_for_infmodel = {
                    'spacetime_ndim': self.spacetime_ndim,
                    'mini_batch_size': 1,
                    'nb_unet_levels': self.nb_unet_levels,
                    'normalization': self.batch_norm,
                    'activation': self.activation_fun,
                    'training_architecture': self.training_architecture,
                }
                # Spatial input sizes are left unspecified (None); only the
                # channel count is fixed.
                nb_spatial_dims = len(input_image.shape[1:-1])
                if nb_spatial_dims in (2, 3):
                    kwargs_for_infmodel['input_layer_size'] = [None] * nb_spatial_dims + [
                        input_image.shape[-1]
                    ]

                if self.model_architecture == "unet":
                    kwargs_for_infmodel['original_zdim'] = self.patch_size[0]
                    if (
                        'random' in self.training_architecture
                        or 'check' in self.training_architecture
                    ):
                        kwargs_for_infmodel['supervised'] = True
                    else:
                        kwargs_for_infmodel['supervised'] = not self.self_supervised

                self.infmodel = self.model_class(**kwargs_for_infmodel)
                self.infmodel.set_weights(self.model.get_weights())

            try:
                output_image = self.infmodel.predict(
                    input_image, batch_size=self.batch_size, verbose=self.verbose
                )
            except Exception as error:
                # Best-effort fallback (presumably for memory exhaustion with
                # a full batch): retry one image at a time.
                lprint(
                    f"Prediction with batch_size={self.batch_size} failed ({error}); retrying with batch_size=1."
                )
                output_image = self.infmodel.predict(
                    input_image, batch_size=1, verbose=self.verbose
                )

            # Undo the padding in reverse order of application:
            if reshaped_for_model:
                output_image = self._crop_padding(output_image, pad_width2)
            if reshaped_for_cube:
                output_image = self._crop_padding(output_image, pad_width1)

            output_image = numpy.moveaxis(output_image, output_image.ndim - 1, 1)
            return output_image