diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index 693269f6..b9985d1c 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -2,7 +2,7 @@ from numba import njit import numpy as np import platform -from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll +from ctypes import CDLL, c_int64, c_uint8, c_uint64, c_float, POINTER, c_void_p, c_uint32, c_bool, cdll import ffcv._libffcv lib = CDLL(ffcv._libffcv.__file__) @@ -22,6 +22,22 @@ def read(fileno:int, destination:np.ndarray, offset:int): ctypes_resize = lib.resize ctypes_resize.argtypes = 11 * [c_int64] +ctypes_rotate = lib.rotate +ctypes_rotate.argtypes = [c_float, c_int64, c_int64, c_int64, c_int64] + +ctypes_shear = lib.shear +ctypes_shear.argtypes = [c_float, c_float, c_int64, c_int64, c_int64, c_int64] + +ctypes_add_weighted = lib.add_weighted +ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64] + +ctypes_equalize = lib.equalize +ctypes_equalize.argtypes = 4 * [c_int64] + +ctypes_unsharp_mask = lib.unsharp_mask +ctypes_unsharp_mask.argtypes = 4 * [c_int64] + + def resize_crop(source, start_row, end_row, start_col, end_col, destination): ctypes_resize(0, source.ctypes.data, @@ -52,5 +68,4 @@ def imdecode(source: np.ndarray, dst: np.ndarray, ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64] def memcpy(source: np.ndarray, dest: np.ndarray): - return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize) - + return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize) \ No newline at end of file diff --git a/ffcv/pipeline/graph.py b/ffcv/pipeline/graph.py index 05da7cee..220df1bc 100644 --- a/ffcv/pipeline/graph.py +++ b/ffcv/pipeline/graph.py @@ -2,6 +2,7 @@ import warnings import ast + try: # Useful for debugging import astor @@ -330,7 +331,7 @@ def collect_requirements(self, state=INITIAL_STATE, next_state, allocation = operation.declare_state_and_memory(state) state_allocation = operation.declare_shared_memory(state) - if next_state.device.type != 'cuda' and isinstance(operation, + if next_state.device != 'cuda' and isinstance(operation, ModuleWrapper): msg = ("Using a pytorch transform on the CPU is extremely" "detrimental to the performance, consider moving the augmentation" @@ -485,4 +486,4 @@ def codegen_all(self, code): code_stages.append(self.codegen_stage(stage, s_ix, op_to_node, code, already_defined)) final_output = [x.id for x in self.leaf_nodes.values()] - return code_stages, final_output \ No newline at end of file + return code_stages, final_output diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index dc58b55f..c2640078 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -9,6 +9,7 @@ from .translate import RandomTranslate from .mixup import ImageMixup, LabelMixup, MixupToOneHot from .module import ModuleWrapper +from .randaugment import RandAugment from .solarization import Solarization from .color_jitter import RandomBrightness, RandomContrast, RandomSaturation from .erasing import RandomErasing @@ -19,6 +20,7 @@ 'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate', 'Cutout', 'RandomCutout', 'RandomErasing', 'ImageMixup', 'LabelMixup', 'MixupToOneHot', + 'RandAugment', 'Poison', 'ReplaceLabel', 'ModuleWrapper', 'Solarization', diff --git a/ffcv/transforms/randaugment.py b/ffcv/transforms/randaugment.py new file mode 100644 index 00000000..5f9b0733 --- /dev/null +++ b/ffcv/transforms/randaugment.py @@ -0,0 +1,118 @@ +import numpy as np +from ffcv.pipeline.compiler import Compiler +from ffcv.pipeline.operation import Operation, AllocationQuery +from dataclasses import replace +from typing import Callable, Optional, Tuple +from ffcv.pipeline.state import State +from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ + autocontrast, sharpen, adjust_saturation, translate, adjust_brightness + +class RandAugment(Operation): + def __init__(self, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31): + super().__init__() + self.num_ops = num_ops + self.magnitude = magnitude + num_bins = num_magnitude_bins + # index, name (for readability); bins, sign multiplier + # those with a -1 can have negative magnitude with probability 0.5 + self.op_table = [ + (0, "Identity", np.array(0.0), 1), + (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1), + (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1), + (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0, num_bins), -1), + (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0, num_bins), -1), + (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1), + (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1), + (7, "Color", np.linspace(0.0, 0.9, num_bins), -1), + (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1), + (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1), + (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round(), 1), + (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1), + (12, "AutoContrast", np.array(0.0), 1), + (13, "Equalize", np.array(0.0), 1), + ] + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + op_table = self.op_table + magnitudes = np.array([(op[2][self.magnitude] if op[2].ndim > 0 else 0) for op in self.op_table]) + is_signed = np.array([op[3] for op in self.op_table]) + num_ops = self.num_ops +# for i in range(len(magnitudes)): +# print(i, op_table[i][1], '%.3f'%magnitudes[i]) + def randaug(im, mem): + dst, scratch, lut, scratchf = mem + for i in my_range(im.shape[0]): + for n in range(num_ops): + if n == 0: + src = im + else: + src = dst + + idx = np.random.randint(low=0, high=13+1) + mag = magnitudes[idx] + if np.random.random() < 0.5: + mag = mag * is_signed[idx] + + # Not worth fighting numba at the moment. + # TODO + if idx == 0: + dst[i][:] = src[i] + + if idx == 1: # ShearX (0.004) + shear(src[i], dst[i], mag, 0) + + if idx == 2: # ShearY + shear(src[i], dst[i], 0, mag) + + if idx == 3: # TranslateX + translate(src[i], dst[i], int(src[i].shape[1] * mag), 0) + + if idx == 4: # TranslateY + translate(src[i], dst[i], 0, int(src[i].shape[2] * mag)) + + if idx == 5: # Rotate + rotate(src[i], dst[i], mag) + + if idx == 6: # Brightness + adjust_brightness(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 7: # Color + adjust_saturation(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 8: # Contrast + adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 9: # Sharpness + sharpen(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 10: # Posterize + posterize(src[i], int(mag), dst[i]) + + if idx == 11: # Solarize + solarize(src[i], scratch[i][0], mag, dst[i]) + + if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s) + autocontrast(src[i], scratchf[i][0], dst[i]) + + if idx == 13: # Equalize (TODO: +0.008s) + equalize(src[i], lut[i], dst[i]) + + return dst + + randaug.is_parallel = True + return randaug + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + assert previous_state.jit_mode + h, w, c = previous_state.shape + return replace(previous_state, shape=previous_state.shape), [ + AllocationQuery(previous_state.shape, dtype=np.dtype('uint8')), + AllocationQuery((1, h, w, c), dtype=np.dtype('uint8')), + AllocationQuery((c, 256), dtype=np.dtype('int16')), + AllocationQuery((1, h, w, c), dtype=np.dtype('float32')), + ] \ No newline at end of file diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 34cb7835..9dcba9c3 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -1,7 +1,195 @@ import ctypes -from numba import njit +from numba import njit, prange import numpy as np -from ...libffcv import ctypes_resize +from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, \ + ctypes_add_weighted, ctypes_equalize, ctypes_unsharp_mask + +""" +Requires a float32 scratch array +""" +@njit(parallel=True, fastmath=True, inline='always') +def autocontrast(source, scratchf, destination): + # numba: no kwargs in min? as a consequence, I might as well have written + # this in C++ + # TODO assuming 3 channels + minimum = [source[..., 0].min(), source[..., 1].min(), source[..., 2].min()] + maximum = [source[..., 0].max(), source[..., 1].max(), source[..., 2].max()] + scale = [0.0, 0.0, 0.0] + for i in prange(source.shape[-1]): + if minimum[i] == maximum[i]: + scale[i] = 1 + minimum[i] = 0 + else: + scale[i] = 255. / (maximum[i] - minimum[i]) + for i in prange(source.shape[-1]): + scratchf[..., i] = source[..., i] - minimum[i] + scratchf[..., i] = scratchf[..., i] * scale[i] + np.clip(scratchf, 0, 255, out=scratchf) + destination[:] = scratchf + + +""" +Custom equalize -- equivalent to torchvision.transforms.functional.equalize, +but probably slow -- scratch is a (channels, 256) uint16 array. +""" +@njit(parallel=True, fastmath=True, inline='always') +def equalize(source, scratch, destination): + for i in prange(source.shape[-1]): + # TODO memory less than ideal for bincount() and hist() + scratch[i] = np.bincount(source[..., i].flatten(), minlength=256) + nonzero_hist = scratch[i][scratch[i] != 0] + step = nonzero_hist[:-1].sum() // 255 + + if step == 0: + continue + + scratch[i][1:] = scratch[i].cumsum()[:-1] + scratch[i] = (scratch[i] + step // 2) // step + scratch[i][0] = 0 + np.clip(scratch[i], 0, 255, out=scratch[i]) + + # numba doesn't like 2d advanced indexing + for row in prange(source.shape[0]): + destination[row, :, i] = scratch[i][source[row, :, i]] + +""" +Equalize using OpenCV -- not equivalent to +torchvision.transforms.functional.equalize for so-far-unknown reasons. +""" +@njit(parallel=False, fastmath=True, inline='always') +def fast_equalize(source, chw_scratch, destination): + # this seems kind of hacky + # also, assuming ctypes_equalize allocates a minimal amount of memory + # which may be incorrect -- so maybe we should do this from scratch. + # TODO may be a better way to do this in pure OpenCV + c, h, w = chw_scratch.shape + chw_scratch[0] = source[..., 0] + ctypes_equalize(chw_scratch.ctypes.data, + chw_scratch.ctypes.data, + h, w) + chw_scratch[1] = source[..., 1] + ctypes_equalize(chw_scratch.ctypes.data + h*w, + chw_scratch.ctypes.data + h*w, + h, w) + chw_scratch[2] = source[..., 2] + ctypes_equalize(chw_scratch.ctypes.data + 2*h*w, + chw_scratch.ctypes.data + 2*h*w, + h, w) + destination[..., 0] = chw_scratch[0] + destination[..., 1] = chw_scratch[1] + destination[..., 2] = chw_scratch[2] + + +@njit(parallel=False, fastmath=True, inline='always') +def invert(source, destination): + destination[:] = 255 - source + + +@njit(parallel=False, fastmath=True, inline='always') +def solarize(source, scratch, threshold, destination): + invert(source, scratch) + destination[:] = np.where(source >= threshold, scratch, source) + + +@njit(parallel=False, fastmath=True, inline='always') +def posterize(source, bits, destination): + mask = ~(2 ** (8 - bits) - 1) + destination[:] = source & mask + + +@njit(inline='always') +def blend(source1, source2, ratio, destination): + ctypes_add_weighted(source1.ctypes.data, ratio, + source2.ctypes.data, 1 - ratio, + destination.ctypes.data, + source1.shape[0], source1.shape[1]) + + +@njit(inline='always') +def adjust_brightness(source, scratch, factor, destination): + scratch[:] = 0 + blend(source, scratch, factor, destination) + + +@njit(parallel=False, fastmath=True, inline='always') +def adjust_saturation(source, scratch, factor, destination): + # TODO numpy autocasting probably allocates memory here, + # should be more careful. + # TODO do we really need scratch for this? could use destination + scratch[...,0] = 0.299 * source[..., 0] + \ + 0.587 * source[..., 1] + \ + 0.114 * source[..., 2] + scratch[...,1] = scratch[...,0] + scratch[...,2] = scratch[...,1] + + blend(source, scratch, factor, destination) + + +@njit(parallel=False, fastmath=True, inline='always') +def adjust_contrast(source, scratch, factor, destination): + # TODO assuming 3 channels + scratch[:,:,:] = np.mean(0.299 * source[..., 0] + + 0.587 * source[..., 1] + + 0.114 * source[..., 2]) + + blend(source, scratch, factor, destination) + + +@njit(fastmath=True, inline='always') +def sharpen(source, scratch, amount, destination): + ctypes_unsharp_mask(source.ctypes.data, + scratch.ctypes.data, + source.shape[0], source.shape[1]) + + # in PyTorch's implementation, + # the border is unaffected + scratch[0,:] = source[0,:] + scratch[1:,0] = source[1:,0] + scratch[-1,:] = source[-1,:] + scratch[1:-1,-1] = source[1:-1,-1] + + blend(source, scratch, amount, destination) + + +""" +Translation, x and y +Assuming this is faster than warpAffine; +also assuming tx and ty are ints +""" +@njit(inline='always') +def translate(source, destination, tx, ty): + if tx == 0 and ty == 0: + destination[:] = source + return + if tx > 0: + destination[:, tx:] = source[:, :-tx] + destination[:, :tx] = 0 + if tx < 0: + destination[:, :tx] = source[:, -tx:] + destination[:, tx:] = 0 + if ty > 0: + destination[ty:, :] = source[:-ty, :] + destination[:ty, :] = 0 + if ty < 0: + destination[:ty, :] = source[-ty:, :] + destination[ty:, :] = 0 + + +@njit(inline='always') +def rotate(source, destination, angle): + ctypes_rotate(angle, + source.ctypes.data, + destination.ctypes.data, + source.shape[0], source.shape[1]) + + +@njit(inline='always') +def shear(source, destination, shear_x, shear_y): + ctypes_shear(shear_x, shear_y, + source.ctypes.data, + destination.ctypes.data, + source.shape[0], source.shape[1]) + @njit(inline='always') def resize_crop(source, start_row, end_row, start_col, end_col, destination): diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index db4798d1..06fcff14 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -34,13 +34,78 @@ extern "C" { int64_t start_row, int64_t end_row, int64_t start_col, int64_t end_col, int64_t dest_p, int64_t tx, int64_t ty) { // TODO use proper arguments type - cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); cv::Mat dest_matrix(tx, ty, CV_8UC3, (uint8_t*) dest_p); cv::resize(source_matrix.colRange(start_col, end_col).rowRange(start_row, end_row), dest_matrix, dest_matrix.size(), 0, 0, cv::INTER_AREA); } + + EXPORT void rotate(float angle, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + // TODO unsure if this should be sx, sy + cv::Point2f center((sy-1) / 2.0, (sx-1) / 2.0); + cv::Mat rotation = cv::getRotationMatrix2D(center, angle, 1.0); + cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix, rotation, dest_matrix.size(), cv::INTER_NEAREST); + } + + EXPORT void shear(float shear_x, float shear_y, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + + float _shear[6] = { 1, shear_x, 0, shear_y, 1, 0 }; + + float cx = (sx - 1) / 2.0; + float cy = (sy - 1) / 2.0; + + _shear[2] += _shear[0] * -cx + _shear[1] * -cy; + _shear[5] += _shear[3] * -cx + _shear[4] * -cy; + + _shear[2] += cx; + _shear[5] += cy; + + cv::Mat shear = cv::Mat(2, 3, CV_32F, _shear); + cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix, shear, dest_matrix.size(), cv::INTER_NEAREST); + } + + EXPORT void add_weighted(int64_t img1_p, float a, int64_t img2_p, float b, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat img1(sx, sy, CV_8UC3, (uint8_t*) img1_p); + cv::Mat img2(sx, sy, CV_8UC3, (uint8_t*) img2_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + + // TODO doubt we need colRange/rowRange stuff + cv::addWeighted(img1.colRange(0, sy).rowRange(0, sx), a, + img2.colRange(0, sy).rowRange(0, sx), b, + 0, dest_matrix); + } + + EXPORT void equalize(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8U, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8U, (uint8_t*) dest_p); + cv::equalizeHist(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix); + } + + EXPORT void unsharp_mask(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + + cv::Point anchor(-1, -1); + + // 3x3 kernel, all 1s with 5 in center / sum of kernel + float _kernel[9] = { 0.0769, 0.0769, 0.0769, 0.0769, 0.3846, + 0.0769, 0.0769, 0.0769, 0.0769 }; + cv::Mat kernel = cv::Mat(3, 3, CV_32F, _kernel); + + cv::filter2D(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix, -1, kernel, anchor, 0, cv::BORDER_ISOLATED); + + //add_weighted(source_p, amount, dest_p, 1 - amount, dest_p, sx, sy); + } + EXPORT void my_memcpy(void *source, void* dst, uint64_t size) { memcpy(dst, source, size); } diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..667770a1 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +example_imgs/* diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py new file mode 100644 index 00000000..826a75b1 --- /dev/null +++ b/tests/test_rand_aug.py @@ -0,0 +1,314 @@ +import time +import numpy as np +import torch +import matplotlib.pyplot as plt +from ffcv.fields import IntField, RGBImageField +from ffcv.fields.decoders import SimpleRGBImageDecoder +from ffcv.loader import Loader, OrderOption +from ffcv.transforms import ToTensor, ToTorchImage, RandAugment +from ffcv.writer import DatasetWriter +from dataclasses import replace +from typing import Callable, Optional, Tuple +from ffcv.pipeline.state import State +from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ + autocontrast, sharpen, adjust_saturation, translate, adjust_brightness +import torchvision.transforms as tv +import cv2 +import pytest +import math + +@pytest.mark.parametrize('angle', [45, -30]) +def test_rotate(angle): + Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)).permute(2, 0, 1) + Ych = tv.functional.rotate(Xch, angle).permute(1, 2, 0).numpy().astype(np.uint8) + rotate(Xnp, Ynp, angle) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/rotate-%d.png' % angle) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [0.31, -0.31]) +def test_shear(amt): + Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)).permute(2, 0, 1) + Ych = tv.functional.affine(Xch, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0, math.degrees(math.atan(amt))], + interpolation=tv.functional.InterpolationMode.NEAREST, + fill=0, + #center=[0, 0], + ).permute(1, 2, 0).numpy().astype(np.uint8) + shear(Xnp, Ynp, 0, -amt) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/shear-%f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [0.5]) +def test_brightness(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) + Ych = (255*tv.functional.adjust_brightness(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) + adjust_brightness(Xnp, Snp, amt, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/brightness-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [0.5]) +def test_adjust_contrast(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) + Ych = (255*tv.functional.adjust_contrast(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) + adjust_contrast(Xnp, Snp, amt, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/adjust_contrast-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('bits', [2, 3, 4, 5, 6, 7]) +def test_posterize(bits): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.posterize(Xch, bits).permute(1, 2, 0).numpy() + posterize(Xnp, bits, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/posterize-%d.png' % bits) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +def test_invert(): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Xnp[5:9,5:9,:] = 0 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.invert(Xch).permute(1, 2, 0).numpy() + invert(Xnp, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/invert.png') + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +@pytest.mark.parametrize('threshold', [9]) +def test_solarize(threshold): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Xnp[5:9,5:9,:] = 0 + Xnp[10:15,10:15,:] = 8 + Xnp[27:31,27:31,:] = 9 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.solarize(Xch, threshold).permute(1, 2, 0).numpy() + solarize(Xnp, Snp, threshold, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/solarize-%d.png' % threshold) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +def test_equalize(): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Xnp[5:9,5:9,:] = 0 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + #Snp_chw = np.zeros((3, 32, 32), dtype=np.uint8) + Snp = np.zeros((3, 256), dtype=np.int16) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.equalize(Xch).permute(1, 2, 0).numpy() + #fast_equalize(Xnp, Snp_chw, Ynp) + equalize(Xnp, Snp, Ynp) + + plt.subplot(2, 2, 1) + plt.imshow(Xnp) + plt.subplot(2, 2, 2) + plt.imshow(Ynp) + plt.subplot(2, 2, 3) + plt.imshow(Xch.permute(1, 2, 0).numpy()) + plt.subplot(2, 2, 4) + plt.imshow(Ych) + plt.savefig('example_imgs/equalize.png') + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +def test_autocontrast(): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Xnp[5:9,5:9,:] = 0 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros((32, 32, 3), dtype=np.float32) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.autocontrast(Xch).permute(1, 2, 0).numpy() + autocontrast(Xnp, Snp, Ynp) + + plt.subplot(2, 2, 1) + plt.imshow(Xnp) + plt.subplot(2, 2, 2) + plt.imshow(Ynp) + plt.subplot(2, 2, 3) + plt.imshow(Xch.permute(1, 2, 0).numpy()) + plt.subplot(2, 2, 4) + plt.imshow(Ych) + plt.savefig('example_imgs/autocontrast.png') + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +@pytest.mark.parametrize('amt', [0.5, 0.75, 1.0, 2.0]) +def test_sharpen(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.adjust_sharpness(Xch, amt).permute(1, 2, 0).numpy() + sharpen(Xnp, Snp, amt, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/sharpen-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +@pytest.mark.parametrize('amt', [0.5, 1.5]) +def test_adjust_saturation(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) + Ych = (255*tv.functional.adjust_saturation(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) + adjust_saturation(Xnp, Snp, amt, Ynp) + + plt.subplot(2, 2, 1) + plt.imshow(Xnp) + plt.subplot(2, 2, 2) + plt.imshow(Ynp) + plt.subplot(2, 2, 3) + plt.imshow(Xch.permute(1, 2, 0).numpy()) + plt.subplot(2, 2, 4) + plt.imshow(Ych) + plt.savefig('example_imgs/adjust_saturation-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [(4, 0), (0, 4), (-4, 0), (0, -4)]) +def test_translate(amt): + Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)).permute(2, 0, 1) + Ych = tv.functional.affine(Xch, + angle=0.0, + translate=[amt[0], amt[1]], + scale=1.0, + shear=[0, 0], + interpolation=tv.functional.InterpolationMode.NEAREST, + fill=0, + #center=[0, 0], + ).permute(1, 2, 0).numpy().astype(np.uint8) + translate(Xnp, Ynp, amt[0], amt[1]) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/translate-%d-%d.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +if __name__ == '__main__': +# test_rotate(45) +# test_shear(0.31) +# test_brightness(0.5) +# test_adjust_contrast(0.5) +# test_posterize(2) +# test_invert() +# test_solarize(9) +# test_equalize() +# test_autocontrast() +# test_sharpen(2.0) +# test_adjust_saturation(0.5) +# test_translate((4, 0)) + + BATCH_SIZE = 512 + image_pipelines = { + 'with': [SimpleRGBImageDecoder(), RandAugment(32), ToTensor()], + 'without': [SimpleRGBImageDecoder(), ToTensor()], + 'torchvision': [SimpleRGBImageDecoder(), ToTensor(), ToTorchImage(), tv.RandAugment(num_ops=2, magnitude=10)] + } + + for name, pipeline in image_pipelines.items(): + loader = Loader('/home/ashert/iclr-followup/ffcv/ffcv/examples/cifar/betons/cifar_train.beton', batch_size=BATCH_SIZE, + num_workers=2, order=OrderOption.RANDOM, + drop_last=True, pipelines={'image': pipeline}) + + import matplotlib.pyplot as plt + idx = 0 + for ims, labs in loader: pass +# print('a') +# if name == 'with': +# for i in range(5): +# for j in range(5): +# plt.subplot(5, 5, i * 5 + j + 1) +# plt.imshow(ims[i * 5 + j].numpy()) +# plt.savefig('scratch/%d.png'%idx) +# idx+=1 + start_time = time.time() + for _ in range(5): #(100): + for ims, labs in loader: pass + print(f'Method: {name} | Shape: {ims.shape} | Time per epoch: {(time.time() - start_time) / 100:.5f}s')