21 changes: 18 additions & 3 deletions ffcv/libffcv.py
@@ -2,7 +2,7 @@
from numba import njit
import numpy as np
import platform
-from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll
+from ctypes import CDLL, c_int64, c_uint8, c_uint64, c_float, POINTER, c_void_p, c_uint32, c_bool, cdll
import ffcv._libffcv

lib = CDLL(ffcv._libffcv.__file__)
@@ -22,6 +22,22 @@ def read(fileno:int, destination:np.ndarray, offset:int):
ctypes_resize = lib.resize
ctypes_resize.argtypes = 11 * [c_int64]

ctypes_rotate = lib.rotate
ctypes_rotate.argtypes = [c_float, c_int64, c_int64, c_int64, c_int64]

ctypes_shear = lib.shear
ctypes_shear.argtypes = [c_float, c_float, c_int64, c_int64, c_int64, c_int64]

ctypes_add_weighted = lib.add_weighted
ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64]

ctypes_equalize = lib.equalize
ctypes_equalize.argtypes = 4 * [c_int64]

ctypes_unsharp_mask = lib.unsharp_mask
ctypes_unsharp_mask.argtypes = 4 * [c_int64]


def resize_crop(source, start_row, end_row, start_col, end_col, destination):
    ctypes_resize(0,
                  source.ctypes.data,
@@ -52,5 +68,4 @@ def imdecode(source: np.ndarray, dst: np.ndarray,
ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64]

def memcpy(source: np.ndarray, dest: np.ndarray):
-    return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize)
-
+    return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize)
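The new bindings follow the same calling convention as the existing resize binding: raw buffer addresses plus image dimensions, matching how the njit wrappers in ffcv/transforms/utils/fast_crop.py (further down in this diff) invoke them. A minimal sketch of calling one directly, assuming the compiled extension treats the buffers as contiguous uint8 HWC images:

import numpy as np
from ffcv.libffcv import ctypes_rotate

img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
out = np.empty_like(img)
# arguments: angle, source pointer, destination pointer, rows, cols
ctypes_rotate(15.0, img.ctypes.data, out.ctypes.data,
              img.shape[0], img.shape[1])

In practice these functions are only reached through those njit wrappers, which pass the same arguments from inside compiled code.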
5 changes: 3 additions & 2 deletions ffcv/pipeline/graph.py
@@ -2,6 +2,7 @@
import warnings
import ast


try:
    # Useful for debugging
    import astor
@@ -330,7 +331,7 @@ def collect_requirements(self, state=INITIAL_STATE,
        next_state, allocation = operation.declare_state_and_memory(state)
        state_allocation = operation.declare_shared_memory(state)

-        if next_state.device.type != 'cuda' and isinstance(operation,
+        if next_state.device != 'cuda' and isinstance(operation,
Contributor comment: I think as of v1.0 the device will be a torch.device, in which case we would want next_state.device.type?
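A quick illustration of the reviewer's point (standard torch.device semantics):

import torch

dev = torch.device('cuda:0')
print(dev.type)  # 'cuda'   -- the device kind, regardless of index
print(str(dev))  # 'cuda:0' -- the full spec; not the bare string 'cuda'

A comparison against the bare string 'cuda' does not match an indexed device, while dev.type == 'cuda' checks only the device kind.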
                                                       ModuleWrapper):
            msg = ("Using a pytorch transform on the CPU is extremely "
                   "detrimental to the performance, consider moving the augmentation "
@@ -485,4 +486,4 @@ def codegen_all(self, code):
            code_stages.append(self.codegen_stage(stage, s_ix, op_to_node, code, already_defined))

        final_output = [x.id for x in self.leaf_nodes.values()]
-        return code_stages, final_output
+        return code_stages, final_output
2 changes: 2 additions & 0 deletions ffcv/transforms/__init__.py
@@ -9,6 +9,7 @@
from .translate import RandomTranslate
from .mixup import ImageMixup, LabelMixup, MixupToOneHot
from .module import ModuleWrapper
from .randaugment import RandAugment
from .solarization import Solarization
from .color_jitter import RandomBrightness, RandomContrast, RandomSaturation
from .erasing import RandomErasing
@@ -19,6 +20,7 @@
    'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate',
    'Cutout', 'RandomCutout', 'RandomErasing',
    'ImageMixup', 'LabelMixup', 'MixupToOneHot',
    'RandAugment',
    'Poison', 'ReplaceLabel',
    'ModuleWrapper',
    'Solarization',
118 changes: 118 additions & 0 deletions ffcv/transforms/randaugment.py
@@ -0,0 +1,118 @@
import numpy as np
from ffcv.pipeline.compiler import Compiler
from ffcv.pipeline.operation import Operation, AllocationQuery
from dataclasses import replace
from typing import Callable, Optional, Tuple
from ffcv.pipeline.state import State
from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \
    adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \
    autocontrast, sharpen, adjust_saturation, translate, adjust_brightness

class RandAugment(Operation):
    def __init__(self,
                 num_ops: int = 2,
                 magnitude: int = 9,
                 num_magnitude_bins: int = 31):
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        num_bins = num_magnitude_bins
        # index, name (for readability); magnitude bins, sign multiplier --
        # ops with a -1 can have negative magnitude with probability 0.5
        self.op_table = [
            (0, "Identity", np.array(0.0), 1),
            (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1),
            (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1),
            (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0, num_bins), -1),
            (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0, num_bins), -1),
            (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1),
            (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1),
            (7, "Color", np.linspace(0.0, 0.9, num_bins), -1),
            (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1),
            (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1),
            (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round(), 1),
            (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1),
            (12, "AutoContrast", np.array(0.0), 1),
            (13, "Equalize", np.array(0.0), 1),
        ]

    def generate_code(self) -> Callable:
        my_range = Compiler.get_iterator()
        op_table = self.op_table
        magnitudes = np.array([(op[2][self.magnitude] if op[2].ndim > 0 else 0)
                               for op in self.op_table])
        is_signed = np.array([op[3] for op in self.op_table])
        num_ops = self.num_ops
        # for i in range(len(magnitudes)):
        #     print(i, op_table[i][1], '%.3f' % magnitudes[i])

        def randaug(im, mem):
            dst, scratch, lut, scratchf = mem
            for i in my_range(im.shape[0]):
                for n in range(num_ops):
                    if n == 0:
                        src = im
                    else:
                        src = dst

                    # pick one of the 14 ops uniformly (positional arguments,
                    # which numba handles reliably)
                    idx = np.random.randint(0, 13 + 1)
                    mag = magnitudes[idx]
                    if np.random.random() < 0.5:
                        mag = mag * is_signed[idx]

                    # Not worth fighting numba at the moment.
                    # TODO
                    if idx == 0:  # Identity
                        dst[i][:] = src[i]

                    if idx == 1:  # ShearX (0.004)
                        shear(src[i], dst[i], mag, 0)

                    if idx == 2:  # ShearY
                        shear(src[i], dst[i], 0, mag)

                    if idx == 3:  # TranslateX
                        translate(src[i], dst[i], int(src[i].shape[1] * mag), 0)

                    if idx == 4:  # TranslateY (shape[0] is the image height)
                        translate(src[i], dst[i], 0, int(src[i].shape[0] * mag))

                    if idx == 5:  # Rotate
                        rotate(src[i], dst[i], mag)

                    if idx == 6:  # Brightness
                        adjust_brightness(src[i], scratch[i][0], 1.0 + mag, dst[i])

                    if idx == 7:  # Color
                        adjust_saturation(src[i], scratch[i][0], 1.0 + mag, dst[i])

                    if idx == 8:  # Contrast
                        adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i])

                    if idx == 9:  # Sharpness
                        sharpen(src[i], scratch[i][0], 1.0 + mag, dst[i])

                    if idx == 10:  # Posterize
                        posterize(src[i], int(mag), dst[i])

                    if idx == 11:  # Solarize
                        solarize(src[i], scratch[i][0], mag, dst[i])

                    if idx == 12:  # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s)
                        autocontrast(src[i], scratchf[i][0], dst[i])

                    if idx == 13:  # Equalize (TODO: +0.008s)
                        equalize(src[i], lut[i], dst[i])

            return dst

        randaug.is_parallel = True
        return randaug

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        assert previous_state.jit_mode
        h, w, c = previous_state.shape
        return replace(previous_state, shape=previous_state.shape), [
            AllocationQuery(previous_state.shape, dtype=np.dtype('uint8')),  # dst
            AllocationQuery((1, h, w, c), dtype=np.dtype('uint8')),          # scratch
            AllocationQuery((c, 256), dtype=np.dtype('int16')),              # lut (equalize)
            AllocationQuery((1, h, w, c), dtype=np.dtype('float32')),        # scratchf (autocontrast)
        ]
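For context, a sketch of where the new transform might sit in an FFCV image pipeline (the dataset path and surrounding pipeline are hypothetical). Since declare_state_and_memory asserts jit_mode, RandAugment has to run on the CPU/numba side of the pipeline: after a decoder that yields uint8 images, before conversion to tensors:

from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder, IntDecoder
from ffcv.transforms import RandAugment, ToTensor, ToTorchImage

loader = Loader('train.beton',  # hypothetical dataset file
                batch_size=128,
                num_workers=8,
                order=OrderOption.RANDOM,
                pipelines={
                    'image': [SimpleRGBImageDecoder(),
                              RandAugment(num_ops=2, magnitude=9),
                              ToTensor(),
                              ToTorchImage()],
                    'label': [IntDecoder(), ToTensor()],
                })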
192 changes: 190 additions & 2 deletions ffcv/transforms/utils/fast_crop.py
@@ -1,7 +1,195 @@
import ctypes
-from numba import njit
+from numba import njit, prange
import numpy as np
-from ...libffcv import ctypes_resize
+from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, \
+    ctypes_add_weighted, ctypes_equalize, ctypes_unsharp_mask

"""
Requires a float32 scratch array
"""
@njit(parallel=True, fastmath=True, inline='always')
def autocontrast(source, scratchf, destination):
# numba: no kwargs in min? as a consequence, I might as well have written
# this in C++
# TODO assuming 3 channels
minimum = [source[..., 0].min(), source[..., 1].min(), source[..., 2].min()]
maximum = [source[..., 0].max(), source[..., 1].max(), source[..., 2].max()]
scale = [0.0, 0.0, 0.0]
for i in prange(source.shape[-1]):
if minimum[i] == maximum[i]:
scale[i] = 1
minimum[i] = 0
else:
scale[i] = 255. / (maximum[i] - minimum[i])
for i in prange(source.shape[-1]):
scratchf[..., i] = source[..., i] - minimum[i]
scratchf[..., i] = scratchf[..., i] * scale[i]
np.clip(scratchf, 0, 255, out=scratchf)
destination[:] = scratchf


"""
Custom equalize -- equivalent to torchvision.transforms.functional.equalize,
but probably slow -- scratch is a (channels, 256) uint16 array.
"""
@njit(parallel=True, fastmath=True, inline='always')
def equalize(source, scratch, destination):
for i in prange(source.shape[-1]):
# TODO memory less than ideal for bincount() and hist()
scratch[i] = np.bincount(source[..., i].flatten(), minlength=256)
Author comment: Unfortunate that np.bincount doesn't have an out argument...

GuillaumeLeclerc (Collaborator, Feb 16, 2022): A numba version should be pretty fast and relatively easy to implement, no? (It might even be faster, since it would skip the first pass of bincount that checks the min and max values.)

Author: Yeah, good idea. I'll try to add that in the near future.
        nonzero_hist = scratch[i][scratch[i] != 0]
        step = nonzero_hist[:-1].sum() // 255

        if step == 0:
            continue

        scratch[i][1:] = scratch[i].cumsum()[:-1]
        scratch[i] = (scratch[i] + step // 2) // step
        scratch[i][0] = 0
        np.clip(scratch[i], 0, 255, out=scratch[i])

        # numba doesn't like 2d advanced indexing
        for row in prange(source.shape[0]):
            destination[row, :, i] = scratch[i][source[row, :, i]]
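Following up on the review thread above, a possible numba-friendly replacement for the np.bincount call (a sketch, not part of this PR): since the input is uint8, the 256-bin histogram can be accumulated directly, skipping bincount's validation pass over the data.

from numba import njit

@njit(inline='always')
def hist_uint8(channel, out):
    # out: a length-256 integer array; counts occurrences of each byte value
    out[:] = 0
    flat = channel.ravel()
    for j in range(flat.size):
        out[flat[j]] += 1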

"""
Equalize using OpenCV -- not equivalent to
torchvision.transforms.functional.equalize for so-far-unknown reasons.
"""
@njit(parallel=False, fastmath=True, inline='always')
def fast_equalize(source, chw_scratch, destination):
# this seems kind of hacky
# also, assuming ctypes_equalize allocates a minimal amount of memory
# which may be incorrect -- so maybe we should do this from scratch.
# TODO may be a better way to do this in pure OpenCV
c, h, w = chw_scratch.shape
chw_scratch[0] = source[..., 0]
ctypes_equalize(chw_scratch.ctypes.data,
chw_scratch.ctypes.data,
h, w)
chw_scratch[1] = source[..., 1]
ctypes_equalize(chw_scratch.ctypes.data + h*w,
chw_scratch.ctypes.data + h*w,
h, w)
chw_scratch[2] = source[..., 2]
ctypes_equalize(chw_scratch.ctypes.data + 2*h*w,
chw_scratch.ctypes.data + 2*h*w,
h, w)
destination[..., 0] = chw_scratch[0]
destination[..., 1] = chw_scratch[1]
destination[..., 2] = chw_scratch[2]


@njit(parallel=False, fastmath=True, inline='always')
def invert(source, destination):
    destination[:] = 255 - source


@njit(parallel=False, fastmath=True, inline='always')
def solarize(source, scratch, threshold, destination):
    # invert every pixel at or above the threshold, keep the rest
    invert(source, scratch)
    destination[:] = np.where(source >= threshold, scratch, source)


@njit(parallel=False, fastmath=True, inline='always')
def posterize(source, bits, destination):
    # keep only the top `bits` bits of each channel, e.g. bits=4 -> mask 0xF0
    mask = ~(2 ** (8 - bits) - 1)
    destination[:] = source & mask


@njit(inline='always')
def blend(source1, source2, ratio, destination):
    # destination = ratio * source1 + (1 - ratio) * source2
    # (presumably cv::addWeighted on the C side)
    ctypes_add_weighted(source1.ctypes.data, ratio,
                        source2.ctypes.data, 1 - ratio,
                        destination.ctypes.data,
                        source1.shape[0], source1.shape[1])
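For reference, what blend computes, written out in plain numpy (a sketch; the C-side add_weighted presumably also handles rounding and uint8 saturation, in the style of OpenCV's addWeighted):

import numpy as np

def blend_numpy(src1, src2, ratio):
    # ratio * src1 + (1 - ratio) * src2, clipped back into uint8 range
    out = ratio * src1.astype(np.float32) + (1.0 - ratio) * src2.astype(np.float32)
    return np.clip(out, 0, 255).astype(np.uint8)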


@njit(inline='always')
def adjust_brightness(source, scratch, factor, destination):
    # blending with a black image scales brightness by `factor`
    scratch[:] = 0
    blend(source, scratch, factor, destination)


@njit(parallel=False, fastmath=True, inline='always')
def adjust_saturation(source, scratch, factor, destination):
    # TODO numpy autocasting probably allocates memory here,
    # should be more careful.
    # TODO do we really need scratch for this? could use destination
    # blend the image with its grayscale version (BT.601 luma weights)
    scratch[..., 0] = 0.299 * source[..., 0] + \
                      0.587 * source[..., 1] + \
                      0.114 * source[..., 2]
    scratch[..., 1] = scratch[..., 0]
    scratch[..., 2] = scratch[..., 1]

    blend(source, scratch, factor, destination)


@njit(parallel=False, fastmath=True, inline='always')
def adjust_contrast(source, scratch, factor, destination):
    # TODO assuming 3 channels
    # blend the image with its mean grayscale intensity
    scratch[:, :, :] = np.mean(0.299 * source[..., 0] +
                               0.587 * source[..., 1] +
                               0.114 * source[..., 2])

    blend(source, scratch, factor, destination)


@njit(fastmath=True, inline='always')
def sharpen(source, scratch, amount, destination):
    ctypes_unsharp_mask(source.ctypes.data,
                        scratch.ctypes.data,
                        source.shape[0], source.shape[1])

    # in PyTorch's implementation, the border is unaffected
    scratch[0, :] = source[0, :]
    scratch[1:, 0] = source[1:, 0]
    scratch[-1, :] = source[-1, :]
    scratch[1:-1, -1] = source[1:-1, -1]

    blend(source, scratch, amount, destination)


"""
Translation, x and y
Assuming this is faster than warpAffine;
also assuming tx and ty are ints
"""
@njit(inline='always')
def translate(source, destination, tx, ty):
if tx == 0 and ty == 0:
destination[:] = source
return
if tx > 0:
destination[:, tx:] = source[:, :-tx]
destination[:, :tx] = 0
if tx < 0:
destination[:, :tx] = source[:, -tx:]
destination[:, tx:] = 0
if ty > 0:
destination[ty:, :] = source[:-ty, :]
destination[:ty, :] = 0
if ty < 0:
destination[:ty, :] = source[-ty:, :]
destination[ty:, :] = 0


@njit(inline='always')
def rotate(source, destination, angle):
    ctypes_rotate(angle,
                  source.ctypes.data,
                  destination.ctypes.data,
                  source.shape[0], source.shape[1])


@njit(inline='always')
def shear(source, destination, shear_x, shear_y):
    ctypes_shear(shear_x, shear_y,
                 source.ctypes.data,
                 destination.ctypes.data,
                 source.shape[0], source.shape[1])


@njit(inline='always')
def resize_crop(source, start_row, end_row, start_col, end_col, destination):