# Copyright 2024 Konrad Heidler
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import Union, List, Tuple
import math
import jax
import jax.numpy as jnp
from einops import rearrange
import warnings
from .base import Transformation, InputType, same_type, PyTree, RNGKey
class ImageLevelTransformation(Transformation):
    """Marker base class for the image-level transformations in this module.

    Adds no behaviour on top of `Transformation`; it only groups the
    subclasses defined below under a common parent type.
    """
[docs]
class GridShuffle(ImageLevelTransformation):
    """Divides the image into grid cells and shuffles them randomly.

    Args:
        grid_size (int, int): Tuple of `(gridcells_x, gridcells_y)` that specifies into how many
            cells the image is to be divided along each axis.
            If only a single number is given, that value will be used along both axes.
            Currently requires that each image dimension is a multiple of the corresponding value.
        p (float): Probability of applying the transformation
        input_types: Forwarded unchanged to the base-class constructor.

    Raises:
        ValueError: If `grid_size` is an iterable that does not hold exactly two values.
    """

    def __init__(
        self,
        grid_size: Union[Tuple[int, int], int] = (4, 4),
        p: float = 0.5,
        input_types=InputType.IMAGE,
    ):
        super().__init__(input_types)
        if hasattr(grid_size, '__iter__'):
            self.grid_size = tuple(grid_size)
        else:
            # A single number applies to both axes. Use the argument here:
            # the attribute does not exist yet at this point.
            self.grid_size = (grid_size, grid_size)
        if len(self.grid_size) != 2:
            raise ValueError(
                f"grid_size needs exactly two entries (gridcells_x, gridcells_y), got {self.grid_size}"
            )
        self.probability = p

    def apply(self, rng: RNGKey, inputs: PyTree, input_types: PyTree, invert=False) -> PyTree:
        """Shuffle (or un-shuffle) the grid cells of every image-like leaf in `inputs`.

        Args:
            rng: Random key; split into one key for the apply/skip decision and
                one for the cell permutation.
            inputs: PyTree of arrays to transform.
            input_types: PyTree of input types matching `inputs`. If None, the
                types stored on the instance are used.
            invert: If True, apply the inverse of the permutation that the same
                `rng` produces in the forward direction.

        Returns:
            PyTree with the same structure as `inputs`.

        Raises:
            ValueError: If an image dimension is not a multiple of the
                corresponding number of grid cells.
            NotImplementedError: For input types other than IMAGE, MASK, DENSE
                or METADATA.
        """
        if input_types is None:
            input_types = self.input_types

        key1, key2 = jax.random.split(rng)
        do_apply = jax.random.bernoulli(key1, self.probability)
        gx, gy = self.grid_size

        def transform_single(array, input_type):
            if (
                same_type(input_type, InputType.IMAGE)
                or same_type(input_type, InputType.MASK)
                or same_type(input_type, InputType.DENSE)
            ):
                raw_image = array
                H, W, *_ = raw_image.shape

                # The height is split into gy rows of cells and the width into
                # gx columns, matching the rearrange pattern below.
                if H % gy != 0:
                    raise ValueError(f"Image height ({H}) needs to be a multiple of gridcells_y ({gy})")
                if W % gx != 0:
                    raise ValueError(f"Image width ({W}) needs to be a multiple of gridcells_x ({gx})")

                # Cut the image into gy * gx cells stacked along the leading axis.
                # The pattern names exactly three axes, so the array must be (H, W, C).
                image = rearrange(raw_image, '(gy h) (gx w) c -> (gy gx) h w c', gx=gx, gy=gy)
                if invert:
                    # Re-draw the index permutation from the same key and undo it
                    # via argsort. This relies on jax.random.permutation giving the
                    # same ordering for an integer argument as for an array.
                    inv_permutation = jnp.argsort(jax.random.permutation(key2, image.shape[0]))
                    image = image[inv_permutation]
                else:
                    image = jax.random.permutation(key2, image)
                image = rearrange(image, '(gy gx) h w c -> (gy h) (gx w) c', gx=gx, gy=gy)

                # do_apply is a traced boolean, so select between the two results
                # instead of branching in Python.
                return jnp.where(do_apply, image, raw_image)
            elif same_type(input_type, InputType.METADATA):
                return array
            else:
                raise NotImplementedError(f"GridShuffle for {input_type} not yet implemented")

        return jax.tree_util.tree_map(transform_single, inputs, input_types)
class _ConvolutionalBlur(ImageLevelTransformation):
    """Shared implementation for blurs that convolve an image with one fixed kernel.

    The constructor only sets placeholders (`kernel = None`, `kernelsize = -1`);
    subclasses are expected to overwrite both before `apply` is called.
    """

    @abstractmethod
    def __init__(self, p: float = 0.5, input_types=None):
        """Store the application probability and initialise kernel placeholders.

        Args:
            p: Probability of applying the blur.
            input_types: Forwarded to the base-class constructor. Defaults to
                `[InputType.IMAGE]`.
        """
        # Build the default list per call so that instances never share one
        # mutable default object.
        if input_types is None:
            input_types = [InputType.IMAGE]
        super().__init__(input_types)
        self.probability = p
        self.kernel = None
        self.kernelsize = -1

    def apply(self, rng: RNGKey, inputs: PyTree, input_types: PyTree, invert=False) -> PyTree:
        """Blur every IMAGE leaf of `inputs` with probability `self.probability`.

        Args:
            rng: Random key used for the apply/skip decision.
            inputs: PyTree of arrays to transform.
            input_types: PyTree of input types matching `inputs`. If None, the
                types stored on the instance are used.
            invert: If True, a warning is emitted and IMAGE leaves are returned
                unchanged, since a blur cannot be undone.

        Returns:
            PyTree with the same structure as `inputs`; leaves that are not
            IMAGE are passed through untouched.
        """
        if input_types is None:
            input_types = self.input_types

        do_apply = jax.random.bernoulli(rng, self.probability)

        # Padding before/after each spatial axis. Together they add
        # kernelsize - 1 pixels, so the 'valid' convolution below returns the
        # original height and width (also for even kernel sizes).
        p0 = self.kernelsize // 2
        p1 = self.kernelsize - p0 - 1

        def transform_single(image, input_type):
            if not same_type(input_type, InputType.IMAGE):
                return image

            if invert:
                warnings.warn("Trying to invert a Blur Filter, which is not invertible.")
                return image

            image_padded = jnp.pad(image, [(p0, p1), (p0, p1), (0, 0)], mode='edge')
            # Move channels to the leading axis and add a singleton second axis,
            # so each channel is convolved on its own with the same kernel.
            image_padded = rearrange(image_padded, 'h w (c c2) -> c c2 h w', c2=1)
            convolved = jax.lax.conv(image_padded, self.kernel, [1, 1], 'valid')
            convolved = rearrange(convolved, 'c c2 h w -> h w (c c2)', c2=1)
            # do_apply is a traced boolean, so select instead of branching.
            return jnp.where(do_apply, convolved, image)

        return jax.tree_util.tree_map(transform_single, inputs, input_types)
[docs]
class Blur(_ConvolutionalBlur):
    """Box blur that averages over a `size` x `size` window.

    Args:
        size (int): Side length of the square averaging kernel.
        p (float): Probability of applying the transformation
    """

    def __init__(self, size: int = 5, p: float = 0.5):
        super().__init__(p)
        # Uniform weights, normalised so that they sum to one.
        box = jnp.ones([1, 1, size, size])
        self.kernel = box / box.sum()
        self.kernelsize = size
[docs]
class GaussianBlur(_ConvolutionalBlur):
    """Blur with a normalised, Gaussian-shaped kernel.

    The kernel has `N = ceil(2 * sigma)` taps per axis, sampled on an evenly
    spaced grid over [-2, 2] and weighted by `exp(-0.5 / sigma * (x^2 + y^2))`.

    NOTE(review): the sample grid is fixed to [-2, 2] and the exponent divides
    by `sigma` rather than `sigma^2` -- confirm that this scaling is intended.

    Args:
        sigma (int): Controls both the kernel size and the falloff of the weights.
        p (float): Probability of applying the transformation
    """

    def __init__(self, sigma: int = 3, p: float = 0.5):
        super().__init__(p)
        taps = int(math.ceil(2 * sigma))
        coords = jnp.linspace(-2.0, 2.0, taps)
        # Row and column views of the same grid broadcast to a taps x taps
        # array of squared distances from the centre.
        across = coords.reshape(1, -1)
        down = coords.reshape(-1, 1)

        weights = jnp.exp((-0.5/sigma) * (across*across + down*down))
        # Normalise so that the weights sum to one.
        weights = weights / weights.sum()
        self.kernel = weights.reshape(1, 1, taps, taps)
        self.kernelsize = taps