Source code for augmax.colorspace

# Copyright 2024 Konrad Heidler
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import List, Tuple
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import warnings

from .base import Transformation, BaseChain, InputType, same_type, PyTree, RNGKey
from .utils import log_uniform, rgb_to_hsv, hsv_to_rgb
from .functional import colorspace as F


class ColorspaceTransformation(Transformation):
    @abstractmethod
    def pixelwise(self, rng: RNGKey, pixel: jnp.ndarray, invert=False) -> jnp.ndarray:
        return pixel

    def apply(self, rng: RNGKey, inputs: PyTree, input_types: PyTree, invert=False) -> PyTree:
        op = partial(self.pixelwise, invert=invert)
        full_op = jax.jit(jax.vmap(jax.vmap(op, [None, 0], 0), [None, 1], 1))

        def transform_single(input, input_type):
            if same_type(input_type, InputType.IMAGE):
                return full_op(rng, input)
            else:
                return input

        return jax.tree_util.tree_map(transform_single, inputs, input_types)


class ColorspaceChain(ColorspaceTransformation, BaseChain):
    def __init__(self, *transforms: ColorspaceTransformation, input_types=None):
        super().__init__(input_types)
        self.transforms = transforms

    def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray:
        N = len(self.transforms)
        subkeys = [None]*N if rng is None else jax.random.split(rng, N)

        transforms = self.transforms
        if invert:
            transforms = reversed(transforms)
            subkeys = reversed(subkeys)

        for transform, subkey in zip(transforms, subkeys):
            pixel = transform.pixelwise(subkey, pixel, invert=invert)
        return pixel


[docs] class ByteToFloat(ColorspaceTransformation): """Transforms images from uint8 representation (values 0-255) to normalized float representation (values 0.0-1.0) """ def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if invert: return jnp.clip(255.0 * pixel, 0, 255).astype(jnp.uint8) else: return pixel.astype(jnp.float32) / 255.0
[docs] class Normalize(ColorspaceTransformation): """Normalizes images using given coefficients using the mapping .. math:: p_k \\longmapsto \\frac{p_k - \\mathtt{mean}_k}{\\mathtt{std}_k} Args: mean (jnp.ndarray): Mean values for each channel std (jnp.ndarray): Standard deviation for each channel """ def __init__(self, mean: jnp.ndarray = np.array([0.485, 0.456, 0.406]), std: jnp.ndarray = np.array([0.229, 0.224, 0.225]), input_types=None ): super().__init__(input_types) self.mean = jnp.asarray(mean) self.std = jnp.asarray(std) def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if not invert: return (pixel - self.mean) / self.std else: return (pixel * self.std) + self.mean
[docs] class ChannelShuffle(ColorspaceTransformation): """Randomly shuffles an images channels. Args: p (float): Probability of applying the transformation """ def __init__(self, p: float = 0.5, input_types=None ): super().__init__(input_types) self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: k1, k2 = jax.random.split(rng) do_apply = jax.random.bernoulli(k2, self.probability) if not invert: return jnp.where(do_apply, jax.random.permutation(k1, pixel), pixel ) else: inv_permutation = jnp.argsort(jax.random.permutation(k1, pixel.shape[0])) return jnp.where(do_apply, pixel[inv_permutation], pixel )
[docs] class RandomGamma(ColorspaceTransformation): """Randomly adjusts the image gamma. Args: range (float, float): p (float): Probability of applying the transformation """ def __init__(self, range: Tuple[float, float]=(0.25, 4.0), p: float = 0.5, input_types=None ): super().__init__(input_types) self.range = range self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.dtype != jnp.float32: raise ValueError(f"RandomGamma can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") k1, k2 = jax.random.split(rng) random_gamma = log_uniform(k1, minval=self.range[0], maxval=self.range[1]) gamma = jnp.where(jax.random.bernoulli(k2, self.probability), random_gamma, 1.0) if not invert: return jnp.power(pixel, gamma) else: return jnp.power(pixel, 1/gamma)
[docs] class RandomBrightness(ColorspaceTransformation): """Randomly adjusts the image brightness. Args: range (float, float): p (float): Probability of applying the transformation """ def __init__(self, range: Tuple[float, float] = (-1.0, 1.0), p: float = 0.5, input_types=None ): super().__init__(input_types) self.minval = range[0] / 2.0 self.maxval = range[1] / 2.0 self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.dtype != jnp.float32: raise ValueError(f"RandomContrast can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") k1, k2 = jax.random.split(rng) random_brightness = jax.random.uniform(k1, minval=self.minval, maxval=self.maxval) brightness = jnp.where(jax.random.bernoulli(k2, self.probability), random_brightness, 0.0) # cf. https://gitlab.gnome.org/GNOME/gimp/-/blob/master/app/operations/gimpoperationbrightnesscontrast.c return F.adjust_brightness(pixel, brightness, invert=invert)
[docs] class RandomContrast(ColorspaceTransformation): """Randomly adjusts the image contrast. Args: range (float, float): p (float): Probability of applying the transformation """ def __init__(self, range: Tuple[float, float] = (-1.0, 1.0), p: float = 0.5, input_types=None ): super().__init__(input_types) self.minval = range[0] / 2.0 self.maxval = range[1] / 2.0 self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.dtype != jnp.float32: raise ValueError(f"RandomContrast can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") k1, k2 = jax.random.split(rng) random_contrast = jax.random.uniform(k1, minval=self.minval, maxval=self.maxval) contrast = jnp.where(jax.random.bernoulli(k2, self.probability), random_contrast, 0.0) return F.adjust_contrast(pixel, contrast, invert=invert)
[docs] class ColorJitter(ColorspaceTransformation): """Randomly jitter the image colors. Args: range (float, float): p (float): Probability of applying the transformation """ def __init__(self, brightness: float = 0.1, contrast: float = 0.1, saturation: float = 0.1, hue: float = 0.1, p: float=0.5, input_types=None ): super().__init__(input_types) self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue self.probability = p self.keys_needed = sum(1 if val > 0 else 0 for val in [brightness, contrast, saturation, hue]) if p < 1: self.keys_needed += 1 def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.shape != (3, ): raise ValueError(f"ColorJitter only supports RGB imagery for now, got {pixel.shape}") if pixel.dtype != jnp.float32: raise ValueError(f"ColorJitter can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") keys = jax.random.split(rng, self.keys_needed) hue, saturation, value = rgb_to_hsv(pixel) ops = ['brightness', 'contrast', 'hue', 'saturation'] if invert: ops = reversed(ops) keys = reversed(keys) for op, key in zip(ops, keys): strength = getattr(self, op) if strength <= 0: continue amount = jax.random.uniform(key, minval=-strength, maxval=strength) if op == 'brightness': value = F.adjust_brightness(value, amount, invert=invert) elif op == 'contrast': value = F.adjust_contrast(value, amount, invert=invert) elif op == 'hue': if invert: amount = -amount hue = hue + amount elif op == 'saturation': F.adjust_brightness(saturation, amount, invert=invert) transformed = hsv_to_rgb(hue, saturation, value) if self.probability < 1: do_apply = jax.random.bernoulli(rng, self.probability) transformed = jnp.where(do_apply, transformed, pixel) return transformed
[docs] class RandomGrayscale(ColorspaceTransformation): """Randomly converts the image to grayscale. Args: p (float): Probability of applying the transformation """ def __init__(self, p: float = 0.5, input_types=None ): super().__init__(input_types) self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.dtype != jnp.float32: raise ValueError(f"RandomGrayscale can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") if invert: warnings.warn("Trying to invert a Grayscale Filter, which is not invertible.") return pixel do_apply = jax.random.bernoulli(rng, self.probability) return jnp.where(do_apply, F.to_grayscale(pixel), pixel )
[docs] class RandomChannelGamma(ColorspaceTransformation): """Randomly adjusts each channel's gamma. Args: range (float, float): p (float): Probability of applying the transformation """ def __init__(self, range: Tuple[float, float]=(0.25, 4.0), p: float = 0.5, input_types=None ): super().__init__(input_types) self.range = range self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.dtype != jnp.float32: raise ValueError(f"RandomGamma can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") k1, k2 = jax.random.split(rng) random_gamma = log_uniform(k1, shape=pixel.shape, minval=self.range[0], maxval=self.range[1]) gamma = jnp.where(jax.random.bernoulli(k2, self.probability), random_gamma, 1.0) if not invert: return jnp.power(pixel, gamma) else: return jnp.power(pixel, 1/gamma)
[docs] class Solarization(ColorspaceTransformation): """Randomly solarizes the image. Args: range (float, float): p (float): Probability of applying the transformation """ def __init__(self, threshold: float = 0.5, p: float = 0.5, input_types=None ): super().__init__(input_types) self.range = range self.threshold = threshold self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: if pixel.dtype != jnp.float32: raise ValueError(f"Solarization can only be applied to float images, but the input is {pixel.dtype}. " "Please call ByteToFloat first.") if invert: warnings.warn("Trying to invert a Solarization Filter, which is not invertible.") return pixel do_apply = jax.random.bernoulli(rng, self.probability) solarized = jnp.where((pixel > self.threshold) & do_apply, 1.0 - pixel, pixel ) return solarized
[docs] class ChannelDrop(ColorspaceTransformation): """Randomly drops a channelf from the image Args: p (float): Probability of applying the transformation """ def __init__(self, p: float = 0.5, input_types=None ): super().__init__(input_types) self.probability = p def pixelwise(self, rng: jnp.ndarray, pixel: jnp.ndarray, invert=False) -> jnp.ndarray: k1, k2 = jax.random.split(rng) C, = pixel.shape do_apply = jax.random.bernoulli(k1, self.probability) apply_channel = jax.random.randint(k1, [], minval=0, maxval=C) return jnp.where(do_apply & (jnp.arange(C) == apply_channel), 0.0, pixel)