# Source code for augmax.geometric

# Copyright 2024 Konrad Heidler
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Tuple
from abc import abstractmethod
import math
import warnings

import jax
import jax.numpy as jnp
import numpy as np
from einops import rearrange

from .base import Transformation, BaseChain, InputType, same_type, PyTree, RNGKey
from . import utils


class LazyCoordinates:
    """Accumulates geometric transformations before any resampling happens.

    Transformations are composed into a single 3x3 matrix (plus optional
    per-pixel offsets). The sampling grid is only materialised when
    :meth:`get_coordinate_grid` is called.
    """
    # Accumulated 3x3 transform. The class-level identity is never mutated in
    # place: push_transform rebinds the attribute on the instance.
    _current_transform: jnp.ndarray = np.eye(3)
    # Sum of all per-pixel offsets applied so far, or None if there are none.
    _offsets: Union[jnp.ndarray, None] = None
    # (H, W) of the data being sampled from.
    input_shape: Tuple[int, int]
    # (H, W) that the transformation currently being applied should assume.
    current_shape: Tuple[int, int]
    # (H, W) of the sampling grid that will be produced.
    final_shape: Tuple[int, int]

    def __init__(self, shape: Tuple[int, int]):
        """Start with an identity transform where all three shapes equal `shape`."""
        self.input_shape = shape
        self.current_shape = shape
        self.final_shape = shape

    def get_coordinate_grid(self) -> jnp.ndarray:
        """Build the grid of input coordinates to sample for each output pixel.

        Returns:
            Array of shape (2, H, W) with H, W taken from `final_shape`.
        """
        H, W = self.final_shape
        # Centre the output grid on the origin so the transform acts around
        # the image centre rather than the top-left corner.
        coordinates = jnp.mgrid[0:H, 0:W] - jnp.array([H/2-0.5, W/2-0.5]).reshape(2, 1, 1)
        coordinates = utils.apply_perspective(coordinates, self._current_transform)

        if self._offsets is not None:
            coordinates = coordinates + self._offsets

        # Shift back into the pixel frame of the *input* image.
        H, W = self.input_shape
        return coordinates + jnp.array([H/2-0.5, W/2-0.5]).reshape(2, 1, 1)

    def apply_to_points(self, points) -> jnp.ndarray:
        """Map points from the input frame into the output frame.

        Args:
            points: Point coordinates, expected as an (N, 2) array (the
                transpose is combined with (2, 1) centre vectors).

        Returns:
            The transformed points, transposed back to the input layout.
        """
        # The stored matrix maps output -> input (it is used for sampling),
        # so points have to go through its inverse.
        M_inv = jnp.linalg.inv(self._current_transform)

        H_in, W_in = self.input_shape
        H_out, W_out = self.final_shape
        c_x = jnp.array([H_in/2 - 0.5, W_in/2 - 0.5]).reshape(2, 1)
        c_y = jnp.array([H_out/2 - 0.5, W_out/2 - 0.5]).reshape(2, 1)
        points = points.T

        transformed_points = utils.apply_perspective(points - c_x, M_inv) + c_y
        if self._offsets is not None:
            # The offsets are defined on the output grid, so the point's
            # output location is needed to look up its own offset. This is
            # approximated with a short fixed-point iteration.
            points_iter = transformed_points
            offset_grid = rearrange(self._offsets, 'c h w -> h w c')
            for _ in range(2):
                offsets = utils.resample_image(offset_grid, points_iter, order=1).T
                points_iter = utils.apply_perspective(points - offsets - c_x, M_inv) + c_y
            transformed_points = points_iter

        return transformed_points.T

    def push_transform(self, M: jnp.ndarray):
        """Compose the 3x3 matrix `M` onto the accumulated transform (left-multiplied)."""
        assert M.shape == (3, 3)
        self._current_transform = M @ self._current_transform
        self._dirty = True

    def apply_pixelwise_offsets(self, offsets: jnp.ndarray):
        """Add per-pixel coordinate offsets on top of the matrix transform.

        Args:
            offsets: Array whose trailing dimensions equal `final_shape`.
        """
        assert offsets.shape[1:] == self.final_shape
        # Identity check: `== None` on an array is an element-wise comparison
        # (or at best an accidental fallback), not a test for "unset".
        if self._offsets is None:
            self._offsets = offsets
        else:
            self._offsets = self._offsets + offsets


class GeometricTransformation(Transformation):
    """Base class for transformations that act by remapping coordinates.

    Subclasses describe their effect by modifying a :class:`LazyCoordinates`
    object in :meth:`transform_coordinates`; :meth:`apply` then resamples all
    inputs once with the resulting coordinates.
    """

    @abstractmethod
    def transform_coordinates(self, rng: RNGKey, coordinates: LazyCoordinates, invert=False) -> LazyCoordinates:
        """Record this transformation on `coordinates`.

        `apply` ignores the return value, so implementations must modify
        `coordinates` in place.

        Args:
            rng: Random key for this transformation.
            coordinates: Accumulator to modify.
            invert: If True, record the inverse transformation instead.
        """
        return coordinates

    def apply(self, rng: RNGKey, inputs: PyTree, input_types: PyTree, invert=False) -> PyTree:
        """Apply the transformation to every leaf of `inputs`.

        Args:
            rng: Random key passed on to `transform_coordinates`.
            inputs: PyTree of inputs.
            input_types: PyTree of input types matching `inputs`.
            invert: If True, apply the inverse transformation.

        Returns:
            PyTree with the same structure as `inputs`.

        Raises:
            ValueError: If the image-like inputs do not share exactly one
                (H, W) shape, or if a size-changing transformation is
                inverted before it has been run forward.
            NotImplementedError: If an input type is not supported.
        """
        # TODO: How do we get the canonical image shape when there are multiple inputs?
        input_shapes = set()

        def extract_shape_if_imagelike(inp, typ):
            if same_type(typ, InputType.IMAGE) \
                    or same_type(typ, InputType.MASK) \
                    or same_type(typ, InputType.DENSE):
                return np.array(inp.shape[:2])
            else:
                # Empty array marks "no spatial shape"; filtered out below.
                return np.array([])
        shapes = jax.tree_util.tree_map(extract_shape_if_imagelike, inputs, input_types)
        for shape in jax.tree_util.tree_flatten(shapes)[0]:
            if len(shape) == 2:
                input_shapes.add(tuple(shape))
        if len(input_shapes) != 1:
            raise ValueError(f'Ambiguous input shape for geometric Transformations, got {shapes}')
        input_shape, = input_shapes
        output_shape = self.output_shape(input_shape)
        if invert:
            if not self.size_changing():
                output_shape = input_shape
            elif hasattr(self, 'shape_full'):
                # Restore the shape seen by the most recent forward pass.
                output_shape = self.shape_full
            else:
                raise ValueError("Can't invert a size-changing transformation without running it forward once.")
        else:
            # Remember the forward input shape so a later inversion can restore it.
            self.shape_full = input_shape

        coordinates = LazyCoordinates(input_shape)
        coordinates.final_shape = output_shape

        if invert:
            coordinates.current_shape = output_shape

        self.transform_coordinates(rng, coordinates, invert)
        sampling_coords = coordinates.get_coordinate_grid()

        def transform_single(input, input_type):
            if same_type(input_type, InputType.IMAGE) or same_type(input_type, InputType.DENSE):
                # Linear Interpolation for Images
                return utils.resample_image(input, sampling_coords, order=1, mode='constant')
            elif same_type(input_type, InputType.MASK):
                # Nearest Interpolation for Masks
                return utils.resample_image(input, sampling_coords, order=0, mode='constant')
            elif same_type(input_type, InputType.KEYPOINTS):
                return coordinates.apply_to_points(input)
            elif same_type(input_type, InputType.CONTOUR):
                current = coordinates.apply_to_points(input)
                # A negative determinant means the transform mirrors the
                # plane; reverse the point order in that case.
                return jnp.where(jnp.linalg.det(coordinates._current_transform) < 0,
                    current[::-1],
                    current
                )
            elif same_type(input_type, InputType.METADATA):
                return input
            else:
                raise NotImplementedError(f"Cannot transform input of type {input_type} with {self.__class__.__name__}")

        return jax.tree_util.tree_map(transform_single, inputs, input_types)

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        """Return the (H, W) output shape for a given input shape (unchanged by default)."""
        return input_shape

    def size_changing(self):
        """Return whether this transformation reports itself as size-changing."""
        return False


class SizeChangingGeometricTransformation(GeometricTransformation):
    """Geometric transformation that reports itself as size-changing."""

    def size_changing(self):
        # Overrides the base-class default of False.
        return True


class GeometricChain(GeometricTransformation, BaseChain):
    """Sequence of geometric transformations recorded on one shared coordinate object."""

    def __init__(self, *transforms: GeometricTransformation):
        super().__init__()
        for candidate in transforms:
            assert isinstance(candidate, GeometricTransformation), f"{candidate} is not a GeometricTransformation!"
        self.transforms = transforms

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        """Record every member transformation on `coordinates`.

        Each member receives its own subkey (or None when `rng` is None) and
        sees `coordinates.current_shape` set to its entry of the shape chain.
        """
        steps = self.transforms
        count = len(steps)

        # shapes[i] is the shape fed into steps[i]; the last entry is the
        # output shape of the whole chain.
        shapes = [coordinates.input_shape]
        for step in steps:
            shapes.append(step.output_shape(shapes[-1]))

        if rng is None:
            keys = [None] * count
        else:
            keys = jax.random.split(rng, count)

        if invert:
            schedule = zip(steps, shapes, keys)
        else:
            # Reverse the transformations iff not inverting!
            schedule = zip(reversed(steps), reversed(shapes[:-1]), reversed(keys))

        for step, shape, key in schedule:
            coordinates.current_shape = shape
            step.transform_coordinates(key, coordinates, invert=invert)

        return coordinates

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        """Thread `input_shape` through every member's `output_shape` in order."""
        result = input_shape
        for step in self.transforms:
            result = step.output_shape(result)
        return result

    def size_changing(self):
        """Return True if at least one member is size-changing."""
        for step in self.transforms:
            if step.size_changing():
                return True
        return False


class HorizontalFlip(GeometricTransformation):
    """Randomly flips an image horizontally.

    Args:
        p (float): Probability of applying the transformation
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.probability = p

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        # A flip is its own inverse, so `invert` needs no special handling.
        flipped = jax.random.bernoulli(rng, self.probability)
        # Turn the draw into a sign: +1 leaves the axis alone, -1 mirrors it.
        sign = 1. - 2. * flipped
        # Only the second (column) coordinate is scaled by the sign.
        flip_matrix = jnp.array([
            [1, 0, 0],
            [0, sign, 0],
            [0, 0, 1],
        ])
        coordinates.push_transform(flip_matrix)
class VerticalFlip(GeometricTransformation):
    """Randomly flips an image vertically.

    Args:
        p (float): Probability of applying the transformation
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.probability = p

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        # A flip is its own inverse, so `invert` needs no special handling.
        flipped = jax.random.bernoulli(rng, self.probability)
        # Turn the draw into a sign: +1 leaves the axis alone, -1 mirrors it.
        sign = 1. - 2. * flipped
        # Only the first (row) coordinate is scaled by the sign.
        flip_matrix = jnp.array([
            [sign, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ])
        coordinates.push_transform(flip_matrix)
class Rotate90(GeometricTransformation):
    """Randomly rotates the image by a multiple of 90 degrees."""

    def __init__(self):
        super().__init__()

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        # Two fair coin flips select one of the four rotations.
        draws = jax.random.bernoulli(rng, 0.5, [2])
        flip = 1. - 2. * draws[0]
        # rot == 1 gives a diagonal matrix (+/- identity),
        # rot == 0 gives an anti-diagonal one (a quarter turn).
        rot = draws[1]

        if invert:
            # Diagonal matrices are their own inverse (factor +1); quarter
            # turns are inverted by negating the sign (factor -1).
            flip = (2. * rot - 1.) * flip

        on_diagonal = flip * rot
        upper_right = flip * (1.-rot)
        lower_left = flip * (-1.+rot)

        rotation = jnp.array([
            [on_diagonal, upper_right, 0],
            [lower_left, on_diagonal, 0],
            [0, 0, 1],
        ])
        coordinates.push_transform(rotation)
class Rotate(GeometricTransformation):
    """Rotates the image by a random arbitrary angle.

    Args:
        angle_range (float, float): Tuple of `(min_angle, max_angle)` in
            degrees to sample from. If only a single number is given, angles
            will be sampled from `(-angle_range, angle_range)`.
        p (float): Probability of applying the transformation
    """

    def __init__(self,
            angle_range: Union[Tuple[float, float], float]=(-30, 30),
            p: float = 1.0):
        super().__init__()
        if not hasattr(angle_range, '__iter__'):
            angle_range = (-angle_range, angle_range)
        # Angles are given in degrees and stored in radians.
        self.theta_min, self.theta_max = map(math.radians, angle_range)
        self.probability = p

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        # Use independent keys for the two draws. Reusing one key for both
        # couples the apply decision to the sampled angle, which skews the
        # angles that actually get applied whenever p < 1.
        # The split is deterministic, so calling again with the same `rng`
        # and invert=True reproduces the same angle.
        apply_key, angle_key = jax.random.split(rng)
        do_apply = jax.random.bernoulli(apply_key, self.probability)
        theta = do_apply * jax.random.uniform(angle_key, minval=self.theta_min, maxval=self.theta_max)

        if invert:
            theta = -theta

        transform = jnp.array([
            [ jnp.cos(theta), jnp.sin(theta), 0],
            [-jnp.sin(theta), jnp.cos(theta), 0],
            [0, 0, 1]
        ])
        coordinates.push_transform(transform)
class Translate(GeometricTransformation):
    """Translates by the fixed offsets `dx` and `dy`.

    Args:
        dx: Offset applied to the second (column) coordinate
        dy: Offset applied to the first (row) coordinate
    """

    def __init__(self, dx, dy):
        super().__init__()
        self.dx = dx
        self.dy = dy

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        # `rng` is unused: the translation is deterministic.
        if invert:
            offset_y, offset_x = -self.dy, -self.dx
        else:
            offset_y, offset_x = self.dy, self.dx

        # The offsets enter the matrix negated.
        shift = jnp.array([
            [1, 0, -offset_y],
            [0, 1, -offset_x],
            [0, 0, 1],
        ])
        coordinates.push_transform(shift)
class Crop(SizeChangingGeometricTransformation):
    """Crop the image at the specified x0 and y0 with given width and height

    Args:
        x0 (float): x-coordinate of the crop's top-left corner
        y0 (float): y-coordinate of the crop's top-left corner
        w (float): width of the crop
        h (float): height of the crop
    """

    def __init__(self, x0, y0, w, h):
        super().__init__()
        self.x0 = x0
        self.y0 = y0
        self.width = w
        self.height = h

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        rows, cols = coordinates.current_shape

        # x0/y0 are given in the (0, 0) -- (H, W) frame. Express the crop
        # centre relative to the image centre instead, i.e. in the
        # (-H/2, -W/2) -- (H/2, W/2) frame.
        shift_x = self.x0 + self.width / 2 - cols / 2
        shift_y = self.y0 + self.height / 2 - rows / 2

        # The inverse of a translation is the opposite translation.
        if invert:
            shift_y, shift_x = -shift_y, -shift_x

        translation = jnp.array([
            [1, 0, shift_y],
            [0, 1, shift_x],
            [0, 0, 1],
        ])
        coordinates.push_transform(translation)

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        # The output size is fixed by the crop, independent of the input.
        return (self.height, self.width)
class Resize(SizeChangingGeometricTransformation):
    """Resizes the image to the given width and height.

    Args:
        width (int): target width
        height (int): target height, defaults to `width` if omitted
    """

    def __init__(self, width: int, height: int = None):
        super().__init__()
        self.width = width
        if height is None:
            self.height = width
        else:
            self.height = height

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        return (self.height, self.width)

    def __repr__(self):
        return f'Resize({self.width}, {self.height})'

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        rows, cols = coordinates.current_shape

        # Ratio of current size to target size along each axis.
        scale_y = rows / self.height
        scale_x = cols / self.width

        if invert:
            scale_y, scale_x = 1 / scale_y, 1 / scale_x

        scaling = jnp.array([
            [scale_y, 0, 0],
            [0, scale_x, 0],
            [0, 0, 1],
        ])
        coordinates.push_transform(scaling)
class CenterCrop(SizeChangingGeometricTransformation):
    """Extracts a central crop from the image with given width and height.

    Args:
        width (int): width of the crop
        height (int): height of the crop, defaults to `width` if omitted
    """
    width: int
    height: int

    def __init__(self, width: int, height: int = None):
        super().__init__()
        self.width = width
        if height is None:
            self.height = width
        else:
            self.height = height

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        # Nothing to record: the crop follows from the smaller output_shape,
        # since the sampling grid is centred on the input image.
        return None

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        return (self.height, self.width)

    def __repr__(self):
        return f'CenterCrop({self.width}, {self.height})'
class RandomCrop(SizeChangingGeometricTransformation):
    """Extracts a random crop from the image with given width and height.

    Args:
        width (int): width of the crop
        height (int): height of the crop, defaults to `width` if omitted
    """
    width: int
    height: int

    def __init__(self, width: int, height: int = None):
        super().__init__()
        self.width = width
        if height is None:
            self.height = width
        else:
            self.height = height

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        rows, cols = coordinates.current_shape

        # Largest shift of the crop centre away from the image centre,
        # per axis.
        max_shift_y = (rows - self.height) / 2
        max_shift_x = (cols - self.width) / 2

        lower = jnp.array([-max_shift_y, -max_shift_x])
        upper = jnp.array([max_shift_y, max_shift_x])
        shift_y, shift_x = jax.random.uniform(rng, [2], minval=lower, maxval=upper)

        # The same rng reproduces the same shift, which is then negated.
        if invert:
            shift_y, shift_x = -shift_y, -shift_x

        translation = jnp.array([
            [1, 0, shift_y],
            [0, 1, shift_x],
            [0, 0, 1],
        ])
        coordinates.push_transform(translation)

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        return (self.height, self.width)
class RandomSizedCrop(SizeChangingGeometricTransformation):
    """Extracts a randomly sized crop from the image and rescales it to the given width and height.

    Args:
        width (int): width of the output
        height (int): height of the output, defaults to `width` if omitted
        zoom_range (float, float): minimum and maximum zoom level for the transformation
        prevent_underzoom (bool): whether to prevent zooming beyond the image size
    """
    width: int
    height: int
    min_zoom: float
    max_zoom: float

    def __init__(self,
            width: int, height: int = None, zoom_range: Tuple[float, float] = (0.5, 2.0),
            prevent_underzoom: bool = True):
        super().__init__()
        self.width = width
        if height is None:
            self.height = width
        else:
            self.height = height
        self.min_zoom, self.max_zoom = zoom_range[0], zoom_range[1]
        self.prevent_underzoom = prevent_underzoom

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        rows, cols = coordinates.current_shape
        zoom_key, shift_key = jax.random.split(rng)

        if self.prevent_underzoom:
            # NOTE(review): these bounds are logarithms of the size ratios,
            # yet they are compared with self.min_zoom and `zoom` multiplies
            # the image size directly below. Confirm against
            # utils.log_uniform whether the plain ratios were intended.
            lowest = max(self.min_zoom,
                         math.log(self.height / rows),
                         math.log(self.width / cols))
            highest = max(self.max_zoom, lowest)
        else:
            lowest = self.min_zoom
            highest = self.max_zoom

        zoom = utils.log_uniform(zoom_key, minval=lowest, maxval=highest)

        # Largest shift of the crop centre, per axis, after zooming.
        max_shift_y = ((rows*zoom) - self.height) / 2
        max_shift_x = ((cols*zoom) - self.width) / 2

        lower = jnp.array([-max_shift_y, -max_shift_x])
        upper = jnp.array([max_shift_y, max_shift_x])
        center = jax.random.uniform(shift_key, [2], minval=lower, maxval=upper)
        shift_column = center.reshape(2, 1)

        # Forward matrix:            Inverse matrix:
        # [ 1/zoom  0       c_y/zoom ]   [ zoom  0     -c_y ]
        # [ 0       1/zoom  c_x/zoom ]   [ 0     zoom  -c_x ]
        # [ 0       0       1        ]   [ 0     0      1   ]
        if invert:
            top_rows = jnp.concatenate([jnp.eye(2) * zoom, -shift_column], axis=1)
        else:
            top_rows = jnp.concatenate([jnp.eye(2), shift_column], axis=1) / zoom

        bottom_row = jnp.array([[0, 0, 1]])
        matrix = jnp.concatenate([top_rows, bottom_row], axis=0)
        coordinates.push_transform(matrix)

    def output_shape(self, input_shape: Tuple[int, int]) -> Tuple[int, int]:
        return (self.height, self.width)
class Warp(GeometricTransformation):
    """Warp an image (similar to ElasticTransform).

    Args:
        strength (float): How strong the transformation is, corresponds to the
            standard deviation of deformation values.
        coarseness (float): Size of the initial deformation grid cells. Lower
            values lead to a more noisy deformation.
    """

    def __init__(self, strength: int = 5, coarseness: int = 32):
        super().__init__()
        self.strength = strength
        self.coarseness = coarseness

    def transform_coordinates(self, rng: jnp.ndarray, coordinates: LazyCoordinates, invert=False):
        if invert:
            # No inverse available: warn and leave the coordinates untouched.
            warnings.warn("Inverting a Warp transform not yet implemented. Returning warped image as is.")
            return

        out_h, out_w = coordinates.final_shape
        coarse_h = out_h // self.coarseness
        coarse_w = out_w // self.coarseness

        # Random displacement field on the coarse grid, one channel per axis.
        coarse_shift = self.strength * jax.random.normal(rng, [2, coarse_h, coarse_w])

        # Note: This is not 100% correct as it ignores possible perspective
        #       components of the current transform. Also, interchanging
        #       resize and transform application is a speed hack, but this
        #       shouldn't diminish the quality.
        linear_part = coordinates._current_transform[:2, :2]
        shift = jnp.tensordot(linear_part, coarse_shift, axes=1)

        # Upsample to one offset vector per output pixel.
        shift = jax.image.resize(shift, (2, out_h, out_w), method='bicubic')
        coordinates.apply_pixelwise_offsets(shift)