From 19f4039b992561eb93542ca43dd841218d206b56 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Tue, 15 Feb 2022 20:20:28 -0500 Subject: [PATCH 01/20] Init RandAug work --- ffcv/libffcv.py | 16 ++- ffcv/transforms/utils/fast_crop.py | 37 ++++++- libffcv/libffcv.cpp | 44 +++++++- tests/.gitignore | 1 + tests/test_rand_aug.py | 162 +++++++++++++++++++++++++++++ 5 files changed, 253 insertions(+), 7 deletions(-) create mode 100644 tests/.gitignore create mode 100644 tests/test_rand_aug.py diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index 52219f3c..ae923f52 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -1,7 +1,7 @@ import ctypes from numba import njit import numpy as np -from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll +from ctypes import CDLL, c_int64, c_uint8, c_uint64, c_float, POINTER, c_void_p, c_uint32, c_bool, cdll import ffcv._libffcv lib = CDLL(ffcv._libffcv.__file__) @@ -13,10 +13,19 @@ def read(fileno:int, destination:np.ndarray, offset:int): return read_c(fileno, destination.ctypes.data, destination.size, offset) - ctypes_resize = lib.resize ctypes_resize.argtypes = 11 * [c_int64] +ctypes_rotate = lib.rotate +ctypes_rotate.argtypes = [c_float, c_int64, c_int64, c_int64, c_int64] + +ctypes_shear = lib.shear +ctypes_shear.argtypes = [c_float, c_float, c_int64, c_int64, c_int64, c_int64] + +ctypes_add_weighted = lib.add_weighted +ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64] + + def resize_crop(source, start_row, end_row, start_col, end_col, destination): ctypes_resize(0, source.ctypes.data, @@ -47,5 +56,4 @@ def imdecode(source: np.ndarray, dst: np.ndarray, ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64] def memcpy(source: np.ndarray, dest: np.ndarray): - return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size) - + return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size) \ No newline at end of file diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 3b3f2af3..f0b3b78b 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -1,7 +1,42 @@ import ctypes from numba import njit import numpy as np -from ...libffcv import ctypes_resize +from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted + + +@njit(inline='always') +def blend(source1, source2, ratio, destination): + ctypes_add_weighted(source1.ctypes.data, ratio, + source2.ctypes.data, 1 - ratio, + destination.ctypes.data, + source1.shape[0], source1.shape[1]) + + +@njit(parallel=False, fastmath=True, inline='always') +def adjust_contrast(source, scratch, factor, destination): + # TODO assuming 3 channels + scratch[:,:,:] = np.mean(0.299 * source[..., 0] + + 0.587 * source[..., 1] + + 0.114 * source[..., 2]) + + blend(source, scratch, factor, destination) + + +@njit(inline='always') +def rotate(source, destination, angle): + ctypes_rotate(angle, + source.ctypes.data, + destination.ctypes.data, + source.shape[0], source.shape[1]) + + +@njit(inline='always') +def shear(source, destination, shear_x, shear_y): + ctypes_shear(shear_x, shear_y, + source.ctypes.data, + destination.ctypes.data, + source.shape[0], source.shape[1]) + @njit(inline='always') def resize_crop(source, start_row, end_row, start_col, end_col, destination): diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index 7bae23ba..475a7e22 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -27,17 +27,57 @@ extern 
"C" { int64_t start_row, int64_t end_row, int64_t start_col, int64_t end_col, int64_t dest_p, int64_t tx, int64_t ty) { // TODO use proper arguments type - cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); cv::Mat dest_matrix(tx, ty, CV_8UC3, (uint8_t*) dest_p); cv::resize(source_matrix.colRange(start_col, end_col).rowRange(start_row, end_row), dest_matrix, dest_matrix.size(), 0, 0, cv::INTER_AREA); } + void rotate(float angle, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + // TODO unsure if this should be sx, sy + cv::Point2f center((sy-1) / 2.0, (sx-1) / 2.0); + cv::Mat rotation = cv::getRotationMatrix2D(center, angle, 1.0); + cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix, rotation, dest_matrix.size(), cv::INTER_NEAREST); + } + + void shear(float shear_x, float shear_y, int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + + float _shear[6] = { 1, shear_x, 0, shear_y, 1, 0 }; + + float cx = (sx - 1) / 2.0; + float cy = (sy - 1) / 2.0; + + _shear[2] += _shear[0] * -cx + _shear[1] * -cy; + _shear[5] += _shear[3] * -cx + _shear[4] * -cy; + + _shear[2] += cx; + _shear[5] += cy; + + cv::Mat shear = cv::Mat(2, 3, CV_32F, _shear); + cv::warpAffine(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix, shear, dest_matrix.size(), cv::INTER_NEAREST); + } + + void add_weighted(int64_t img1_p, float a, int64_t img2_p, float b, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat img1(sx, sy, CV_8UC3, (uint8_t*) img1_p); + cv::Mat img2(sx, sy, CV_8UC3, (uint8_t*) img2_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + + // TODO doubt we need colRange/rowRange stuff + cv::addWeighted(img1.colRange(0, sy).rowRange(0, sx), a, + img2.colRange(0, sy).rowRange(0, sx), b, + 0, dest_matrix); + } + void my_memcpy(void *source, void* dst, uint64_t size) { memcpy(dst, source, size); } - + void my_fread(int64_t fp, int64_t offset, void *destination, int64_t size) { fseek((FILE *) fp, offset, SEEK_SET); fread(destination, 1, size, (FILE *) fp); diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..667770a1 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +example_imgs/* diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py new file mode 100644 index 00000000..b16ff5f8 --- /dev/null +++ b/tests/test_rand_aug.py @@ -0,0 +1,162 @@ +import time +import numpy as np +import torch +import matplotlib.pyplot as plt +from ffcv.fields import IntField, RGBImageField +from ffcv.fields.decoders import SimpleRGBImageDecoder +from ffcv.loader import Loader, OrderOption +from ffcv.pipeline.compiler import Compiler +from ffcv.pipeline.operation import Operation, AllocationQuery +from ffcv.transforms import ToTensor, ToTorchImage +from ffcv.writer import DatasetWriter +from dataclasses import replace +from typing import Callable, Optional, Tuple +from ffcv.pipeline.state import State +from ffcv.transforms.utils.fast_crop import rotate, shear, blend, adjust_contrast +import torchvision.transforms as tv +import cv2 +import pytest +import math + +class RandAugment(Operation): + def __init__(self, size: int): + super().__init__() + self.size = size + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + def randaug(im, mem): + dst, scratch = mem + 
for i in my_range(im.shape[0]): + + ## TODO actual randaug logic + + ## rotate + deg = np.random.random() * 45.0 + rotate(im[i], dst[i], deg) + + ## brighten + blend(im[i], scratch[i][0], 0.5, dst[i]) + + ## adjust contrast + adjust_contrast(im[i], scratch[i][0], 0.5, dst[i]) + + return dst + + randaug.is_parallel = True + return randaug + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + assert previous_state.jit_mode + return replace(previous_state, shape=(self.size, self.size, 3)), [ + AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')), + AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')) + ] + + +@pytest.mark.parametrize('angle', [45]) +def test_rotate(angle): + Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)).permute(2, 0, 1) + Ych = tv.functional.rotate(Xch, angle).permute(1, 2, 0).numpy().astype(np.uint8) + rotate(Xnp, Ynp, angle) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/rotate-%d.png' % angle) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [0.31]) +def test_shear(amt): + Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)).permute(2, 0, 1) + Ych = tv.functional.affine(Xch, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0, math.degrees(math.atan(0.31))], + interpolation=tv.functional.InterpolationMode.NEAREST, + fill=0, + #center=[0, 0], + ).permute(1, 2, 0).numpy().astype(np.uint8) + shear(Xnp, Ynp, 0, -0.31) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/shear-%f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [0.5]) +def test_brightness(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) + Ych = (255*tv.functional.adjust_brightness(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) + blend(Xnp, Snp, amt, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/brightness-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + +@pytest.mark.parametrize('amt', [0.5]) +def test_adjust_contrast(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) + Ych = (255*tv.functional.adjust_contrast(Xch, 0.5).permute(1, 2, 0).numpy()).astype(np.uint8) + adjust_contrast(Xnp, Snp, 0.5, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/adjust_contrast-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + + + +if __name__ == '__main__': 
+ test_rotate(45) + test_shear(0.31) + test_brightness(0.5) + test_adjust_contrast(0.5) + + BATCH_SIZE = 512 + image_pipelines = { + 'with': [SimpleRGBImageDecoder(), RandAugment(32), ToTensor()], + 'without': [SimpleRGBImageDecoder(), ToTensor()], + 'torchvision': [SimpleRGBImageDecoder(), ToTensor(), ToTorchImage(), tv.RandAugment(num_ops=2, magnitude=10)] + } + + for name, pipeline in image_pipelines.items(): + loader = Loader('/home/ashert/iclr-followup/ffcv/ffcv/examples/cifar/betons/cifar_train.beton', batch_size=BATCH_SIZE, + num_workers=2, order=OrderOption.RANDOM, + drop_last=True, pipelines={'image': pipeline}) + + for ims, labs in loader: pass + start_time = time.time() + for _ in range(5): #(100): + for ims, labs in loader: pass + print(f'Method: {name} | Shape: {ims.shape} | Time per epoch: {(time.time() - start_time) / 100:.5f}s') From 21e25124e490dd4e9838142fcd516f3369f4472b Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Tue, 15 Feb 2022 21:50:00 -0500 Subject: [PATCH 02/20] Add posterize --- ffcv/transforms/utils/fast_crop.py | 6 ++++++ tests/test_rand_aug.py | 21 +++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index f0b3b78b..d6eab515 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -4,6 +4,12 @@ from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted +@njit(parallel=False, fastmath=True, inline='always') +def posterize(source, bits, destination): + mask = ~(2 ** (8 - bits) - 1) + destination[:] = source & mask + + @njit(inline='always') def blend(source1, source2, ratio, destination): ctypes_add_weighted(source1.ctypes.data, ratio, diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index b16ff5f8..e986a8a6 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -12,7 +12,8 @@ from dataclasses import replace from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State -from ffcv.transforms.utils.fast_crop import rotate, shear, blend, adjust_contrast +from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ + adjust_contrast, posterize import torchvision.transforms as tv import cv2 import pytest @@ -135,6 +136,22 @@ def test_adjust_contrast(amt): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) +@pytest.mark.parametrize('bits', [2]) +def test_posterize(bits): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.posterize(Xch, bits).permute(1, 2, 0).numpy() + posterize(Xnp, bits, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/posterize-%d.png' % bits) + + print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 if __name__ == '__main__': @@ -142,7 +159,7 @@ def test_adjust_contrast(amt): test_shear(0.31) test_brightness(0.5) test_adjust_contrast(0.5) - + test_posterize(2) BATCH_SIZE = 512 image_pipelines = { 'with': [SimpleRGBImageDecoder(), RandAugment(32), ToTensor()], From cc3a99b11c74b6fa50dec3a3c47aa3ceabea2e32 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Tue, 15 Feb 2022 22:01:39 -0500 Subject: [PATCH 03/20] Add invert --- ffcv/transforms/utils/fast_crop.py | 5 +++++ 
tests/test_rand_aug.py | 23 +++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index d6eab515..4c0e62c3 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -4,6 +4,11 @@ from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted +@njit(parallel=False, fastmath=True, inline='always') +def invert(source, destination): + destination[:] = 255 - source + + @njit(parallel=False, fastmath=True, inline='always') def posterize(source, bits, destination): mask = ~(2 ** (8 - bits) - 1) diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index e986a8a6..54654ddb 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -13,7 +13,7 @@ from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ - adjust_contrast, posterize + adjust_contrast, posterize, invert import torchvision.transforms as tv import cv2 import pytest @@ -136,6 +136,7 @@ def test_adjust_contrast(amt): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + @pytest.mark.parametrize('bits', [2]) def test_posterize(bits): Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) @@ -150,7 +151,23 @@ def test_posterize(bits): plt.imshow(Ych) plt.savefig('example_imgs/posterize-%d.png' % bits) - print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + +def test_invert(): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Xnp[5:9,5:9,:] = 0 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.invert(Xch).permute(1, 2, 0).numpy() + invert(Xnp, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/invert.png') + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 @@ -160,6 +177,8 @@ def test_posterize(bits): test_brightness(0.5) test_adjust_contrast(0.5) test_posterize(2) + test_invert() + BATCH_SIZE = 512 image_pipelines = { 'with': [SimpleRGBImageDecoder(), RandAugment(32), ToTensor()], From 08b96b0b9fcdcc078aa2a0b082cb9dbfe7c946d4 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Tue, 15 Feb 2022 22:13:40 -0500 Subject: [PATCH 04/20] Add solarize --- ffcv/transforms/utils/fast_crop.py | 6 ++++++ tests/test_rand_aug.py | 23 ++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 4c0e62c3..b0ea9b7a 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -9,6 +9,12 @@ def invert(source, destination): destination[:] = 255 - source +@njit(parallel=False, fastmath=True, inline='always') +def solarize(source, threshold, destination): + invert(source, destination) + destination[:] = np.where(source >= threshold, destination, source) + + @njit(parallel=False, fastmath=True, inline='always') def posterize(source, bits, destination): mask = ~(2 ** (8 - bits) - 1) diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 54654ddb..e52c9fe4 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -13,7 +13,7 @@ from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State 
from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ - adjust_contrast, posterize, invert + adjust_contrast, posterize, invert, solarize import torchvision.transforms as tv import cv2 import pytest @@ -171,6 +171,26 @@ def test_invert(): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 +@pytest.mark.parametrize('threshold', [9]) +def test_solarize(threshold): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + Xnp[5:9,5:9,:] = 0 + Xnp[10:15,10:15,:] = 8 + Xnp[27:31,27:31,:] = 9 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.solarize(Xch, threshold).permute(1, 2, 0).numpy() + solarize(Xnp, threshold, Ynp) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/solarize-%d.png' % threshold) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + if __name__ == '__main__': test_rotate(45) test_shear(0.31) @@ -178,6 +198,7 @@ def test_invert(): test_adjust_contrast(0.5) test_posterize(2) test_invert() + test_solarize(9) BATCH_SIZE = 512 image_pipelines = { From 5b0359a85d23aef419a490bf3e31b474673e4093 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Wed, 16 Feb 2022 15:02:57 -0500 Subject: [PATCH 05/20] Add equalize --- ffcv/libffcv.py | 3 ++ ffcv/transforms/utils/fast_crop.py | 55 ++++++++++++++++++++++++++++-- libffcv/libffcv.cpp | 7 ++++ tests/test_rand_aug.py | 49 ++++++++++++++++++++------ 4 files changed, 102 insertions(+), 12 deletions(-) diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index ae923f52..a044ac66 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -25,6 +25,9 @@ def read(fileno:int, destination:np.ndarray, offset:int): ctypes_add_weighted = lib.add_weighted ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64] +ctypes_equalize = lib.equalize +ctypes_equalize.argtypes = [c_int64, c_int64, c_int64, c_int64] + def resize_crop(source, start_row, end_row, start_col, end_col, destination): ctypes_resize(0, diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index b0ea9b7a..340c4547 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -1,7 +1,58 @@ import ctypes -from numba import njit +from numba import njit, prange import numpy as np -from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted +from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted, ctypes_equalize + + +""" +Custom equalize -- equivalent to torchvision.transforms.functional.equalize, +but probably slow -- scratch is a (channels, 256) uint16 array. +""" +@njit(parallel=True, fastmath=True, inline='always') +def equalize(source, scratch, destination): + for i in prange(source.shape[-1]): + scratch[i] = np.bincount(source[..., i].flatten(), minlength=256) + nonzero_hist = scratch[i][scratch[i] != 0] + step = nonzero_hist[:-1].sum() // 255 + + if step == 0: + continue + + scratch[i][1:] = scratch[i].cumsum()[:-1] + scratch[i] = (scratch[i] + step // 2) // step + scratch[i][0] = 0 + np.clip(scratch[i], 0, 255, out=scratch[i]) + + # numba doesn't like 2d advanced indexing + for row in prange(source.shape[0]): + destination[row, :, i] = scratch[i][source[row, :, i]] + +""" +Equalize using OpenCV -- not equivalent to +torchvision.transforms.functional.equalize for so-far-unknown reasons. 
+""" +@njit(parallel=False, fastmath=True, inline='always') +def fast_equalize(source, chw_scratch, destination): + # this seems kind of hacky + # also, assuming ctypes_equalize allocates a minimal amount of memory + # which may be incorrect -- so maybe we should do this from scratch. + # TODO may be a better way to do this in pure OpenCV + c, h, w = chw_scratch.shape + chw_scratch[0] = source[..., 0] + ctypes_equalize(chw_scratch.ctypes.data, + chw_scratch.ctypes.data, + h, w) + chw_scratch[1] = source[..., 1] + ctypes_equalize(chw_scratch.ctypes.data + h*w, + chw_scratch.ctypes.data + h*w, + h, w) + chw_scratch[2] = source[..., 2] + ctypes_equalize(chw_scratch.ctypes.data + 2*h*w, + chw_scratch.ctypes.data + 2*h*w, + h, w) + destination[..., 0] = chw_scratch[0] + destination[..., 1] = chw_scratch[1] + destination[..., 2] = chw_scratch[2] @njit(parallel=False, fastmath=True, inline='always') diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index 475a7e22..e867f282 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -74,6 +74,13 @@ extern "C" { 0, dest_matrix); } + void equalize(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8U, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8U, (uint8_t*) dest_p); + cv::equalizeHist(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix); + } + void my_memcpy(void *source, void* dst, uint64_t size) { memcpy(dst, source, size); } diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index e52c9fe4..70b76343 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -13,7 +13,7 @@ from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ - adjust_contrast, posterize, invert, solarize + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize import torchvision.transforms as tv import cv2 import pytest @@ -27,7 +27,7 @@ def __init__(self, size: int): def generate_code(self) -> Callable: my_range = Compiler.get_iterator() def randaug(im, mem): - dst, scratch = mem + dst, scratch, lut = mem for i in my_range(im.shape[0]): ## TODO actual randaug logic @@ -42,6 +42,8 @@ def randaug(im, mem): ## adjust contrast adjust_contrast(im[i], scratch[i][0], 0.5, dst[i]) + ## equalize + equalize(im[i], lut[i], dst[i]) return dst randaug.is_parallel = True @@ -51,7 +53,8 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Option assert previous_state.jit_mode return replace(previous_state, shape=(self.size, self.size, 3)), [ AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')), - AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')) + AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')), + AllocationQuery((3, 256), dtype=np.dtype('int16')) ] @@ -191,14 +194,40 @@ def test_solarize(threshold): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 +def test_equalize(): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Xnp[5:9,5:9,:] = 0 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + #Snp_chw = np.zeros((3, 32, 32), dtype=np.uint8) + Snp = np.zeros((3, 256), dtype=np.int16) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.equalize(Xch).permute(1, 2, 0).numpy() + #fast_equalize(Xnp, Snp_chw, Ynp) + equalize(Xnp, Snp, Ynp) + + plt.subplot(2, 2, 1) + plt.imshow(Xnp) + plt.subplot(2, 2, 2) + plt.imshow(Ynp) + 
plt.subplot(2, 2, 3) + plt.imshow(Xch.permute(1, 2, 0).numpy()) + plt.subplot(2, 2, 4) + plt.imshow(Ych) + plt.savefig('example_imgs/equalize.png') + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + if __name__ == '__main__': - test_rotate(45) - test_shear(0.31) - test_brightness(0.5) - test_adjust_contrast(0.5) - test_posterize(2) - test_invert() - test_solarize(9) +# test_rotate(45) +# test_shear(0.31) +# test_brightness(0.5) +# test_adjust_contrast(0.5) +# test_posterize(2) +# test_invert() +# test_solarize(9) +# test_equalize() BATCH_SIZE = 512 image_pipelines = { From 202501359c6fd786af6370c996d6b52850e35a72 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Wed, 16 Feb 2022 15:57:58 -0500 Subject: [PATCH 06/20] Add autocontrast --- ffcv/transforms/utils/fast_crop.py | 24 +++++++++++++++++ tests/test_rand_aug.py | 43 ++++++++++++++++++++++++++---- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 340c4547..1d1027a8 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -3,6 +3,29 @@ import numpy as np from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted, ctypes_equalize +""" +Requires a float32 scratch array +""" +@njit(parallel=True, fastmath=True, inline='always') +def autocontrast(source, scratchf, destination): + # numba: no kwargs in min? as a consequence, I might as well have written + # this in C++ + # TODO assuming 3 channels + minimum = [source[..., 0].min(), source[..., 1].min(), source[..., 2].min()] + maximum = [source[..., 0].max(), source[..., 1].max(), source[..., 2].max()] + scale = [0.0, 0.0, 0.0] + for i in prange(source.shape[-1]): + if minimum[i] == maximum[i]: + scale[i] = 1 + minimum[i] = 0 + else: + scale[i] = 255. 
/ (maximum[i] - minimum[i]) + for i in prange(source.shape[-1]): + scratchf[..., i] = source[..., i] - minimum[i] + scratchf[..., i] = scratchf[..., i] * scale[i] + np.clip(scratchf, 0, 255, out=scratchf) + destination[:] = scratchf + """ Custom equalize -- equivalent to torchvision.transforms.functional.equalize, @@ -11,6 +34,7 @@ @njit(parallel=True, fastmath=True, inline='always') def equalize(source, scratch, destination): for i in prange(source.shape[-1]): + # TODO memory less than ideal for bincount() and hist() scratch[i] = np.bincount(source[..., i].flatten(), minlength=256) nonzero_hist = scratch[i][scratch[i] != 0] step = nonzero_hist[:-1].sum() // 255 diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 70b76343..92eee588 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -13,7 +13,7 @@ from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ - adjust_contrast, posterize, invert, solarize, equalize, fast_equalize + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, autocontrast import torchvision.transforms as tv import cv2 import pytest @@ -27,7 +27,7 @@ def __init__(self, size: int): def generate_code(self) -> Callable: my_range = Compiler.get_iterator() def randaug(im, mem): - dst, scratch, lut = mem + dst, scratch, lut, scratchf = mem for i in my_range(im.shape[0]): ## TODO actual randaug logic @@ -42,8 +42,16 @@ def randaug(im, mem): ## adjust contrast adjust_contrast(im[i], scratch[i][0], 0.5, dst[i]) - ## equalize - equalize(im[i], lut[i], dst[i]) + if deg < 10: + ## equalize + equalize(im[i], lut[i], dst[i]) + + if 10 < deg < 20: + ## autocontrast -- things are getting slower now. + autocontrast(im[i], scratchf[i][0], dst[i]) + # --^ this is a good candidate for moving entirely to OpenCV + # it would involve less casting/scratch memory I think + return dst randaug.is_parallel = True @@ -54,7 +62,8 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Option return replace(previous_state, shape=(self.size, self.size, 3)), [ AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')), AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')), - AllocationQuery((3, 256), dtype=np.dtype('int16')) + AllocationQuery((3, 256), dtype=np.dtype('int16')), + AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('float32')), ] @@ -218,6 +227,29 @@ def test_equalize(): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + +def test_autocontrast(): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Xnp[5:9,5:9,:] = 0 + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros((32, 32, 3), dtype=np.float32) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.autocontrast(Xch).permute(1, 2, 0).numpy() + autocontrast(Xnp, Snp, Ynp) + + plt.subplot(2, 2, 1) + plt.imshow(Xnp) + plt.subplot(2, 2, 2) + plt.imshow(Ynp) + plt.subplot(2, 2, 3) + plt.imshow(Xch.permute(1, 2, 0).numpy()) + plt.subplot(2, 2, 4) + plt.imshow(Ych) + plt.savefig('example_imgs/autocontrast.png') + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + if __name__ == '__main__': # test_rotate(45) @@ -228,6 +260,7 @@ def test_equalize(): # test_invert() # test_solarize(9) # test_equalize() +# test_autocontrast() BATCH_SIZE = 512 image_pipelines = { From 6b6432cb10d81b20bede3209f10c055eb1d25854 Mon 
Sep 17 00:00:00 2001 From: Asher Trockman Date: Wed, 16 Feb 2022 20:59:26 -0500 Subject: [PATCH 07/20] Add sharpness --- ffcv/libffcv.py | 6 +++++- ffcv/transforms/utils/fast_crop.py | 19 ++++++++++++++++++- libffcv/libffcv.cpp | 17 +++++++++++++++++ tests/test_rand_aug.py | 22 +++++++++++++++++++++- 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index a044ac66..815d23e0 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -13,6 +13,7 @@ def read(fileno:int, destination:np.ndarray, offset:int): return read_c(fileno, destination.ctypes.data, destination.size, offset) + ctypes_resize = lib.resize ctypes_resize.argtypes = 11 * [c_int64] @@ -26,7 +27,10 @@ def read(fileno:int, destination:np.ndarray, offset:int): ctypes_add_weighted.argtypes = [c_int64, c_float, c_int64, c_float, c_int64, c_int64, c_int64] ctypes_equalize = lib.equalize -ctypes_equalize.argtypes = [c_int64, c_int64, c_int64, c_int64] +ctypes_equalize.argtypes = 4 * [c_int64] + +ctypes_unsharp_mask = lib.unsharp_mask +ctypes_unsharp_mask.argtypes = 4 * [c_int64] def resize_crop(source, start_row, end_row, start_col, end_col, destination): diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 1d1027a8..a04370c9 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -1,7 +1,8 @@ import ctypes from numba import njit, prange import numpy as np -from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, ctypes_add_weighted, ctypes_equalize +from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, \ + ctypes_add_weighted, ctypes_equalize, ctypes_unsharp_mask """ Requires a float32 scratch array @@ -114,6 +115,22 @@ def adjust_contrast(source, scratch, factor, destination): blend(source, scratch, factor, destination) +@njit(fastmath=True, inline='always') +def sharpen(source, destination, amount): + ctypes_unsharp_mask(source.ctypes.data, + destination.ctypes.data, + source.shape[0], source.shape[1]) + + # in PyTorch's implementation, + # the border is unaffected + destination[0,:] = source[0,:] + destination[1:,0] = source[1:,0] + destination[-1,:] = source[-1,:] + destination[1:-1,-1] = source[1:-1,-1] + + blend(source, destination, amount, destination) + + @njit(inline='always') def rotate(source, destination, angle): ctypes_rotate(angle, diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index e867f282..e0277dc9 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -81,6 +81,23 @@ extern "C" { dest_matrix); } + void unsharp_mask(int64_t source_p, int64_t dest_p, int64_t sx, int64_t sy) { + cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); + cv::Mat dest_matrix(sx, sy, CV_8UC3, (uint8_t*) dest_p); + + cv::Point anchor(-1, -1); + + // 3x3 kernel, all 1s with 5 in center / sum of kernel + float _kernel[9] = { 0.0769, 0.0769, 0.0769, 0.0769, 0.3846, + 0.0769, 0.0769, 0.0769, 0.0769 }; + cv::Mat kernel = cv::Mat(3, 3, CV_32F, _kernel); + + cv::filter2D(source_matrix.colRange(0, sy).rowRange(0, sx), + dest_matrix, -1, kernel, anchor, 0, cv::BORDER_ISOLATED); + + //add_weighted(source_p, amount, dest_p, 1 - amount, dest_p, sx, sy); + } + void my_memcpy(void *source, void* dst, uint64_t size) { memcpy(dst, source, size); } diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 92eee588..97d8f9a3 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -13,7 +13,7 @@ from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State from 
ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ - adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, autocontrast + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, autocontrast, sharpen import torchvision.transforms as tv import cv2 import pytest @@ -251,6 +251,25 @@ def test_autocontrast(): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 +@pytest.mark.parametrize('amt', [2.0]) +def test_sharpen(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp).permute(2, 0, 1) + Ych = tv.functional.adjust_sharpness(Xch, amt).permute(1, 2, 0).numpy() + sharpen(Xnp, Ynp, amt) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/sharpen-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + if __name__ == '__main__': # test_rotate(45) # test_shear(0.31) @@ -261,6 +280,7 @@ def test_autocontrast(): # test_solarize(9) # test_equalize() # test_autocontrast() +# test_sharpen(2.0) BATCH_SIZE = 512 image_pipelines = { From fc2bdf6581b550bc486a6eb2f3d6cf7162d96f10 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 11:10:38 -0500 Subject: [PATCH 08/20] Add color (adjust_saturation), fix test params --- ffcv/transforms/utils/fast_crop.py | 14 ++++++++++++ tests/test_rand_aug.py | 36 +++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index a04370c9..7241dcd6 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -104,6 +104,20 @@ def blend(source1, source2, ratio, destination): destination.ctypes.data, source1.shape[0], source1.shape[1]) + +@njit(parallel=False, fastmath=True, inline='always') +def adjust_saturation(source, scratch, factor, destination): + # TODO numpy autocasting probably allocates memory here, + # should be more careful. + # TODO do we really need scratch for this? 
could use destination + scratch[...,0] = 0.299 * source[..., 0] + \ + 0.587 * source[..., 1] + \ + 0.114 * source[..., 2] + scratch[...,1] = scratch[...,0] + scratch[...,2] = scratch[...,1] + + blend(source, scratch, factor, destination) + @njit(parallel=False, fastmath=True, inline='always') def adjust_contrast(source, scratch, factor, destination): diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 97d8f9a3..aa9cb408 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -13,7 +13,8 @@ from typing import Callable, Optional, Tuple from ffcv.pipeline.state import State from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ - adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, autocontrast, sharpen + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ + autocontrast, sharpen, adjust_saturation import torchvision.transforms as tv import cv2 import pytest @@ -94,12 +95,12 @@ def test_shear(amt): angle=0.0, translate=[0, 0], scale=1.0, - shear=[0, math.degrees(math.atan(0.31))], + shear=[0, math.degrees(math.atan(amt))], interpolation=tv.functional.InterpolationMode.NEAREST, fill=0, #center=[0, 0], ).permute(1, 2, 0).numpy().astype(np.uint8) - shear(Xnp, Ynp, 0, -0.31) + shear(Xnp, Ynp, 0, -amt) plt.subplot(1, 2, 1) plt.imshow(Ynp) @@ -136,8 +137,8 @@ def test_adjust_contrast(amt): Ynp = np.zeros(Xnp.shape, dtype=np.uint8) Snp = np.zeros(Xnp.shape, dtype=np.uint8) Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) - Ych = (255*tv.functional.adjust_contrast(Xch, 0.5).permute(1, 2, 0).numpy()).astype(np.uint8) - adjust_contrast(Xnp, Snp, 0.5, Ynp) + Ych = (255*tv.functional.adjust_contrast(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) + adjust_contrast(Xnp, Snp, amt, Ynp) plt.subplot(1, 2, 1) plt.imshow(Ynp) @@ -269,7 +270,29 @@ def test_sharpen(amt): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 - +@pytest.mark.parametrize('amt', [0.5, 1.5]) +def test_adjust_saturation(amt): + Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) + #Xnp = cv2.imread('example_imgs/0249.png') + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Snp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) + Ych = (255*tv.functional.adjust_saturation(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) + adjust_saturation(Xnp, Snp, amt, Ynp) + + plt.subplot(2, 2, 1) + plt.imshow(Xnp) + plt.subplot(2, 2, 2) + plt.imshow(Ynp) + plt.subplot(2, 2, 3) + plt.imshow(Xch.permute(1, 2, 0).numpy()) + plt.subplot(2, 2, 4) + plt.imshow(Ych) + plt.savefig('example_imgs/adjust_saturation-%.2f.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + if __name__ == '__main__': # test_rotate(45) # test_shear(0.31) @@ -281,6 +304,7 @@ def test_sharpen(amt): # test_equalize() # test_autocontrast() # test_sharpen(2.0) +# test_adjust_saturation(0.5) BATCH_SIZE = 512 image_pipelines = { From 56e5cc39a8046c7a106962715ef9edef2ca87395 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 11:41:24 -0500 Subject: [PATCH 09/20] Add translate --- ffcv/transforms/utils/fast_crop.py | 15 +++++++++++++++ tests/test_rand_aug.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 7241dcd6..90a4a4cf 100644 --- 
a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -145,6 +145,21 @@ def sharpen(source, destination, amount): blend(source, destination, amount, destination) +""" +Translation, x and y +Assuming this is faster than warpAffine; +also assuming tx and ty are ints +""" +@njit(inline='always') +def translate(source, destination, tx, ty): + if tx > 0: + destination[:, tx:] = source[:, :-tx] + destination[:, :tx] = 0 + if ty > 0: + destination[ty:, :] = source[:-ty, :] + destination[:ty, :] = 0 + + @njit(inline='always') def rotate(source, destination, angle): ctypes_rotate(angle, diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index aa9cb408..35bcca1a 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -14,7 +14,7 @@ from ffcv.pipeline.state import State from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ - autocontrast, sharpen, adjust_saturation + autocontrast, sharpen, adjust_saturation, translate import torchvision.transforms as tv import cv2 import pytest @@ -270,6 +270,7 @@ def test_sharpen(amt): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + @pytest.mark.parametrize('amt', [0.5, 1.5]) def test_adjust_saturation(amt): Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8) @@ -293,6 +294,32 @@ def test_adjust_saturation(amt): assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) + +@pytest.mark.parametrize('amt', [(4, 0), (0, 4)]) +def test_translate(amt): + Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) + Ynp = np.zeros(Xnp.shape, dtype=np.uint8) + Xch = torch.tensor(Xnp.astype(np.float32)).permute(2, 0, 1) + Ych = tv.functional.affine(Xch, + angle=0.0, + translate=[amt[0], amt[1]], + scale=1.0, + shear=[0, 0], + interpolation=tv.functional.InterpolationMode.NEAREST, + fill=0, + #center=[0, 0], + ).permute(1, 2, 0).numpy().astype(np.uint8) + translate(Xnp, Ynp, amt[0], amt[1]) + + plt.subplot(1, 2, 1) + plt.imshow(Ynp) + plt.subplot(1, 2, 2) + plt.imshow(Ych) + plt.savefig('example_imgs/translate-%d-%d.png' % amt) + + assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100 + + if __name__ == '__main__': # test_rotate(45) # test_shear(0.31) @@ -305,6 +332,7 @@ def test_adjust_saturation(amt): # test_autocontrast() # test_sharpen(2.0) # test_adjust_saturation(0.5) +# test_translate((4, 0)) BATCH_SIZE = 512 image_pipelines = { From 784f3d081a9dc7759df02118e52c9dbedf83a901 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 12:12:05 -0500 Subject: [PATCH 10/20] Fix merge problems/typos --- ffcv/libffcv.py | 1 + libffcv/libffcv.cpp | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index 53ffef15..b9985d1c 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -1,6 +1,7 @@ import ctypes from numba import njit import numpy as np +import platform from ctypes import CDLL, c_int64, c_uint8, c_uint64, c_float, POINTER, c_void_p, c_uint32, c_bool, cdll import ffcv._libffcv diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index 41ad994f..06fcff14 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -105,12 +105,6 @@ extern "C" { //add_weighted(source_p, amount, dest_p, 1 - amount, dest_p, sx, sy); } - - void my_memcpy(void *source, void* dst, uint64_t size) { - memcpy(dst, source, size); - } - - void 
my_fread(int64_t fp, int64_t offset, void *destination, int64_t size) { EXPORT void my_memcpy(void *source, void* dst, uint64_t size) { memcpy(dst, source, size); From e2675380228c093a80630a80cce1bd56dd4ca1c1 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 12:12:24 -0500 Subject: [PATCH 11/20] Remove unused imports --- ffcv/fields/ndarray.py | 1 - ffcv/pipeline/graph.py | 1 - 2 files changed, 2 deletions(-) diff --git a/ffcv/fields/ndarray.py b/ffcv/fields/ndarray.py index d10687f9..df347d43 100644 --- a/ffcv/fields/ndarray.py +++ b/ffcv/fields/ndarray.py @@ -2,7 +2,6 @@ import warnings import json from dataclasses import replace -from kornia import warnings import numpy as np import torch as ch diff --git a/ffcv/pipeline/graph.py b/ffcv/pipeline/graph.py index ccc657bc..b7f6b9d5 100644 --- a/ffcv/pipeline/graph.py +++ b/ffcv/pipeline/graph.py @@ -2,7 +2,6 @@ import warnings import ast -import astor from collections import defaultdict from typing import Callable, Dict, List, Optional, Sequence, Set from abc import ABC, abstractmethod From 6fc2362dbaa9c122b035ab96a5effd89580ba3b9 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 12:28:59 -0500 Subject: [PATCH 12/20] Support translate by negative amount --- ffcv/transforms/utils/fast_crop.py | 6 ++++++ tests/test_rand_aug.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 90a4a4cf..194a5f18 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -155,9 +155,15 @@ def translate(source, destination, tx, ty): if tx > 0: destination[:, tx:] = source[:, :-tx] destination[:, :tx] = 0 + if tx < 0: + destination[:, :tx] = source[:, -tx:] + destination[:, tx:] = 0 if ty > 0: destination[ty:, :] = source[:-ty, :] destination[:ty, :] = 0 + if ty < 0: + destination[:ty, :] = source[-ty:, :] + destination[ty:, :] = 0 @njit(inline='always') diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 35bcca1a..9abf9ade 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -295,7 +295,7 @@ def test_adjust_saturation(amt): #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) -@pytest.mark.parametrize('amt', [(4, 0), (0, 4)]) +@pytest.mark.parametrize('amt', [(4, 0), (0, 4), (-4, 0), (0, -4)]) def test_translate(amt): Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) Ynp = np.zeros(Xnp.shape, dtype=np.uint8) From 4d2e1ed0a033906305e0832638b9e3fd82464dad Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 12:35:49 -0500 Subject: [PATCH 13/20] Ensure rotate, shear amounts can be negative --- tests/test_rand_aug.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 9abf9ade..3857de6a 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -68,7 +68,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Option ] -@pytest.mark.parametrize('angle', [45]) +@pytest.mark.parametrize('angle', [45, -30]) def test_rotate(angle): Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) Ynp = np.zeros(Xnp.shape, dtype=np.uint8) @@ -86,7 +86,7 @@ def test_rotate(angle): #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max()) -@pytest.mark.parametrize('amt', [0.31]) +@pytest.mark.parametrize('amt', [0.31, -0.31]) def test_shear(amt): Xnp = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.uint8) Ynp = np.zeros(Xnp.shape, dtype=np.uint8) 
From 9e3a2ee5769c7210a2f0bb747998054bbbbaf913 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 15:00:13 -0500 Subject: [PATCH 14/20] Implement RandAug (WIP) --- tests/test_rand_aug.py | 102 +++++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 23 deletions(-) diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 3857de6a..0ba3adfb 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -21,37 +21,92 @@ import math class RandAugment(Operation): - def __init__(self, size: int): + def __init__(self, + size: int = 32, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31): super().__init__() self.size = size + self.num_ops = num_ops + self.magnitude = magnitude + num_bins = num_magnitude_bins + # index, name (for readability); bins, sign multiplier + # those with a -1 can have negative magnitude with probability 0.5 + self.op_table = [ + (0, "Identity", np.array(0.0), 1), + (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1), + (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1), + (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1), + (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1), + (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1), + (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1), + (7, "Color", np.linspace(0.0, 0.9, num_bins), -1), + (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1), + (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1), + (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round().astype('uint8'), 1), + (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1), + (12, "AutoContrast", np.array(0.0), 1), + (13, "Equalize", np.array(0.0), 1), + ] def generate_code(self) -> Callable: my_range = Compiler.get_iterator() + op_table = self.op_table + magnitudes = np.array([(op[2][self.magnitude] if op[2].ndim > 0 else 0) for op in self.op_table]) + is_signed = np.array([op[3] for op in self.op_table]) + num_ops = self.num_ops +# for i in range(len(magnitudes)): +# print(i, op_table[i][1], '%.3f'%magnitudes[i]) def randaug(im, mem): dst, scratch, lut, scratchf = mem for i in my_range(im.shape[0]): - - ## TODO actual randaug logic - - ## rotate - deg = np.random.random() * 45.0 - rotate(im[i], dst[i], deg) - - ## brighten - blend(im[i], scratch[i][0], 0.5, dst[i]) - - ## adjust contrast - adjust_contrast(im[i], scratch[i][0], 0.5, dst[i]) - - if deg < 10: - ## equalize - equalize(im[i], lut[i], dst[i]) - - if 10 < deg < 20: - ## autocontrast -- things are getting slower now. - autocontrast(im[i], scratchf[i][0], dst[i]) - # --^ this is a good candidate for moving entirely to OpenCV - # it would involve less casting/scratch memory I think + for _ in range(num_ops): + idx = np.random.randint(low=0, high=13+1) + mag = magnitudes[idx] + if np.random.random() < 0.5: + mag = mag * is_signed[idx] + + # Not worth fighting numba at the moment. 
+ # TODO + if idx == 1: # ShearX (0.004) + shear(im[i], dst[i], mag, 0) + + if idx == 2: # ShearY + shear(im[i], dst[i], 0, mag) + + if idx == 3: # TranslateX + translate(im[i], dst[i], int(mag), 0) + + if idx == 4: # TranslateY + translate(im[i], dst[i], 0, int(mag)) + + if idx == 5: # Rotate + rotate(im[i], dst[i], mag) + + if idx == 6: # Brightness + blend(im[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 7: # Color + adjust_saturation(im[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 8: # Contrast + adjust_contrast(im[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 9: # Sharpness + sharpen(im[i], dst[i], 1.0 + mag) + + if idx == 10: # Posterize + posterize(im[i], int(mag), dst[i]) + + if idx == 11: # Solarize + solarize(im[i], mag, dst[i]) + + if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s) + autocontrast(im[i], scratchf[i][0], dst[i]) + + if idx == 13: # Equalize (TODO: +0.008s) + equalize(im[i], lut[i], dst[i]) return dst @@ -346,6 +401,7 @@ def test_translate(amt): num_workers=2, order=OrderOption.RANDOM, drop_last=True, pipelines={'image': pipeline}) + import matplotlib.pyplot as plt for ims, labs in loader: pass start_time = time.time() for _ in range(5): #(100): From f6ba47925124af8256b5a5cbe6211d0fc9dbe115 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 18:22:56 -0500 Subject: [PATCH 15/20] Allow more than one op per image, fixes --- ffcv/transforms/utils/fast_crop.py | 6 ++++ tests/test_rand_aug.py | 49 ++++++++++++++++++++---------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 194a5f18..1569c9f7 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -105,6 +105,12 @@ def blend(source1, source2, ratio, destination): source1.shape[0], source1.shape[1]) +@njit(inline='always') +def adjust_brightness(source, scratch, factor, destination): + scratch[:] = 0 + blend(source, scratch, factor, destination) + + @njit(parallel=False, fastmath=True, inline='always') def adjust_saturation(source, scratch, factor, destination): # TODO numpy autocasting probably allocates memory here, diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 0ba3adfb..619c045e 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -14,7 +14,7 @@ from ffcv.pipeline.state import State from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ - autocontrast, sharpen, adjust_saturation, translate + autocontrast, sharpen, adjust_saturation, translate, adjust_brightness import torchvision.transforms as tv import cv2 import pytest @@ -61,7 +61,12 @@ def generate_code(self) -> Callable: def randaug(im, mem): dst, scratch, lut, scratchf = mem for i in my_range(im.shape[0]): - for _ in range(num_ops): + for n in range(num_ops): + if n == 0: + src = im + else: + src = dst + idx = np.random.randint(low=0, high=13+1) mag = magnitudes[idx] if np.random.random() < 0.5: @@ -69,44 +74,47 @@ def randaug(im, mem): # Not worth fighting numba at the moment. 
# TODO + if idx == 0: + dst[i][:] = src[i] + if idx == 1: # ShearX (0.004) - shear(im[i], dst[i], mag, 0) + shear(src[i], dst[i], mag, 0) if idx == 2: # ShearY - shear(im[i], dst[i], 0, mag) + shear(src[i], dst[i], 0, mag) if idx == 3: # TranslateX - translate(im[i], dst[i], int(mag), 0) + translate(src[i], dst[i], int(mag), 0) if idx == 4: # TranslateY - translate(im[i], dst[i], 0, int(mag)) + translate(src[i], dst[i], 0, int(mag)) if idx == 5: # Rotate - rotate(im[i], dst[i], mag) + rotate(src[i], dst[i], mag) if idx == 6: # Brightness - blend(im[i], scratch[i][0], 1.0 + mag, dst[i]) + adjust_brightness(src[i], scratch[i][0], 1.0 + mag, dst[i]) if idx == 7: # Color - adjust_saturation(im[i], scratch[i][0], 1.0 + mag, dst[i]) + adjust_saturation(src[i], scratch[i][0], 1.0 + mag, dst[i]) if idx == 8: # Contrast - adjust_contrast(im[i], scratch[i][0], 1.0 + mag, dst[i]) + adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i]) if idx == 9: # Sharpness - sharpen(im[i], dst[i], 1.0 + mag) + sharpen(src[i], dst[i], 1.0 + mag) if idx == 10: # Posterize - posterize(im[i], int(mag), dst[i]) + posterize(src[i], int(mag), dst[i]) if idx == 11: # Solarize - solarize(im[i], mag, dst[i]) + solarize(src[i], mag, dst[i]) if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s) - autocontrast(im[i], scratchf[i][0], dst[i]) + autocontrast(src[i], scratchf[i][0], dst[i]) if idx == 13: # Equalize (TODO: +0.008s) - equalize(im[i], lut[i], dst[i]) + equalize(src[i], lut[i], dst[i]) return dst @@ -174,7 +182,7 @@ def test_brightness(amt): Snp = np.zeros(Xnp.shape, dtype=np.uint8) Xch = torch.tensor(Xnp.astype(np.float32)/255.).permute(2, 0, 1) Ych = (255*tv.functional.adjust_brightness(Xch, amt).permute(1, 2, 0).numpy()).astype(np.uint8) - blend(Xnp, Snp, amt, Ynp) + adjust_brightness(Xnp, Snp, amt, Ynp) plt.subplot(1, 2, 1) plt.imshow(Ynp) @@ -402,7 +410,16 @@ def test_translate(amt): drop_last=True, pipelines={'image': pipeline}) import matplotlib.pyplot as plt + idx = 0 for ims, labs in loader: pass +# print('a') +# if name == 'with': +# for i in range(5): +# for j in range(5): +# plt.subplot(5, 5, i * 5 + j + 1) +# plt.imshow(ims[i * 5 + j].numpy()) +# plt.savefig('scratch/%d.png'%idx) +# idx+=1 start_time = time.time() for _ in range(5): #(100): for ims, labs in loader: pass From d6074ba7a3f35103e481b4d9d68363378d0ebc8f Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 19:25:15 -0500 Subject: [PATCH 16/20] Move RandAugment --- ffcv/transforms/__init__.py | 1 + ffcv/transforms/randaugment.py | 119 +++++++++++++++++++++++++++++++++ tests/test_rand_aug.py | 115 +------------------------------ 3 files changed, 121 insertions(+), 114 deletions(-) create mode 100644 ffcv/transforms/randaugment.py diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index bc8fa321..fa95e1d1 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -9,6 +9,7 @@ from .translate import RandomTranslate from .mixup import ImageMixup, LabelMixup, MixupToOneHot from .module import ModuleWrapper +from .randaugment import RandAugment __all__ = ['ToTensor', 'ToDevice', 'ToTorchImage', 'NormalizeImage', diff --git a/ffcv/transforms/randaugment.py b/ffcv/transforms/randaugment.py new file mode 100644 index 00000000..40042fcf --- /dev/null +++ b/ffcv/transforms/randaugment.py @@ -0,0 +1,119 @@ +import numpy as np +from ffcv.pipeline.compiler import Compiler +from ffcv.pipeline.operation import Operation, AllocationQuery +from dataclasses import replace +from typing import 
Callable, Optional, Tuple +from ffcv.pipeline.state import State +from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ + adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ + autocontrast, sharpen, adjust_saturation, translate, adjust_brightness + +class RandAugment(Operation): + def __init__(self, + size: int = 32, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31): + super().__init__() + self.size = size + self.num_ops = num_ops + self.magnitude = magnitude + num_bins = num_magnitude_bins + # index, name (for readability); bins, sign multiplier + # those with a -1 can have negative magnitude with probability 0.5 + self.op_table = [ + (0, "Identity", np.array(0.0), 1), + (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1), + (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1), + (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1), + (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1), + (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1), + (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1), + (7, "Color", np.linspace(0.0, 0.9, num_bins), -1), + (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1), + (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1), + (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round().astype('uint8'), 1), + (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1), + (12, "AutoContrast", np.array(0.0), 1), + (13, "Equalize", np.array(0.0), 1), + ] + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + op_table = self.op_table + magnitudes = np.array([(op[2][self.magnitude] if op[2].ndim > 0 else 0) for op in self.op_table]) + is_signed = np.array([op[3] for op in self.op_table]) + num_ops = self.num_ops +# for i in range(len(magnitudes)): +# print(i, op_table[i][1], '%.3f'%magnitudes[i]) + def randaug(im, mem): + dst, scratch, lut, scratchf = mem + for i in my_range(im.shape[0]): + for n in range(num_ops): + if n == 0: + src = im + else: + src = dst + + idx = np.random.randint(low=0, high=13+1) + mag = magnitudes[idx] + if np.random.random() < 0.5: + mag = mag * is_signed[idx] + + # Not worth fighting numba at the moment. 
+ # TODO + if idx == 0: + dst[i][:] = src[i] + + if idx == 1: # ShearX (0.004) + shear(src[i], dst[i], mag, 0) + + if idx == 2: # ShearY + shear(src[i], dst[i], 0, mag) + + if idx == 3: # TranslateX + translate(src[i], dst[i], int(mag), 0) + + if idx == 4: # TranslateY + translate(src[i], dst[i], 0, int(mag)) + + if idx == 5: # Rotate + rotate(src[i], dst[i], mag) + + if idx == 6: # Brightness + adjust_brightness(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 7: # Color + adjust_saturation(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 8: # Contrast + adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i]) + + if idx == 9: # Sharpness + sharpen(src[i], dst[i], 1.0 + mag) + + if idx == 10: # Posterize + posterize(src[i], int(mag), dst[i]) + + if idx == 11: # Solarize + solarize(src[i], mag, dst[i]) + + if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s) + autocontrast(src[i], scratchf[i][0], dst[i]) + + if idx == 13: # Equalize (TODO: +0.008s) + equalize(src[i], lut[i], dst[i]) + + return dst + + randaug.is_parallel = True + return randaug + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + assert previous_state.jit_mode + return replace(previous_state, shape=(self.size, self.size, 3)), [ + AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')), + AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')), + AllocationQuery((3, 256), dtype=np.dtype('int16')), + AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('float32')), + ] \ No newline at end of file diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py index 619c045e..5de7ee14 100644 --- a/tests/test_rand_aug.py +++ b/tests/test_rand_aug.py @@ -5,9 +5,7 @@ from ffcv.fields import IntField, RGBImageField from ffcv.fields.decoders import SimpleRGBImageDecoder from ffcv.loader import Loader, OrderOption -from ffcv.pipeline.compiler import Compiler -from ffcv.pipeline.operation import Operation, AllocationQuery -from ffcv.transforms import ToTensor, ToTorchImage +from ffcv.transforms import ToTensor, ToTorchImage, RandAugment from ffcv.writer import DatasetWriter from dataclasses import replace from typing import Callable, Optional, Tuple @@ -19,117 +17,6 @@ import cv2 import pytest import math - -class RandAugment(Operation): - def __init__(self, - size: int = 32, - num_ops: int = 2, - magnitude: int = 9, - num_magnitude_bins: int = 31): - super().__init__() - self.size = size - self.num_ops = num_ops - self.magnitude = magnitude - num_bins = num_magnitude_bins - # index, name (for readability); bins, sign multiplier - # those with a -1 can have negative magnitude with probability 0.5 - self.op_table = [ - (0, "Identity", np.array(0.0), 1), - (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1), - (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1), - (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1), - (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1), - (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1), - (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1), - (7, "Color", np.linspace(0.0, 0.9, num_bins), -1), - (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1), - (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1), - (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round().astype('uint8'), 1), - (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1), - (12, "AutoContrast", np.array(0.0), 1), - (13, "Equalize", np.array(0.0), 1), - ] - - 
def generate_code(self) -> Callable: - my_range = Compiler.get_iterator() - op_table = self.op_table - magnitudes = np.array([(op[2][self.magnitude] if op[2].ndim > 0 else 0) for op in self.op_table]) - is_signed = np.array([op[3] for op in self.op_table]) - num_ops = self.num_ops -# for i in range(len(magnitudes)): -# print(i, op_table[i][1], '%.3f'%magnitudes[i]) - def randaug(im, mem): - dst, scratch, lut, scratchf = mem - for i in my_range(im.shape[0]): - for n in range(num_ops): - if n == 0: - src = im - else: - src = dst - - idx = np.random.randint(low=0, high=13+1) - mag = magnitudes[idx] - if np.random.random() < 0.5: - mag = mag * is_signed[idx] - - # Not worth fighting numba at the moment. - # TODO - if idx == 0: - dst[i][:] = src[i] - - if idx == 1: # ShearX (0.004) - shear(src[i], dst[i], mag, 0) - - if idx == 2: # ShearY - shear(src[i], dst[i], 0, mag) - - if idx == 3: # TranslateX - translate(src[i], dst[i], int(mag), 0) - - if idx == 4: # TranslateY - translate(src[i], dst[i], 0, int(mag)) - - if idx == 5: # Rotate - rotate(src[i], dst[i], mag) - - if idx == 6: # Brightness - adjust_brightness(src[i], scratch[i][0], 1.0 + mag, dst[i]) - - if idx == 7: # Color - adjust_saturation(src[i], scratch[i][0], 1.0 + mag, dst[i]) - - if idx == 8: # Contrast - adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i]) - - if idx == 9: # Sharpness - sharpen(src[i], dst[i], 1.0 + mag) - - if idx == 10: # Posterize - posterize(src[i], int(mag), dst[i]) - - if idx == 11: # Solarize - solarize(src[i], mag, dst[i]) - - if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s) - autocontrast(src[i], scratchf[i][0], dst[i]) - - if idx == 13: # Equalize (TODO: +0.008s) - equalize(src[i], lut[i], dst[i]) - - return dst - - randaug.is_parallel = True - return randaug - - def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: - assert previous_state.jit_mode - return replace(previous_state, shape=(self.size, self.size, 3)), [ - AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')), - AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')), - AllocationQuery((3, 256), dtype=np.dtype('int16')), - AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('float32')), - ] - @pytest.mark.parametrize('angle', [45, -30]) def test_rotate(angle): From 8a2a9bef5d005c266a5c684b604c9af5db6cece7 Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Thu, 17 Feb 2022 21:41:28 -0500 Subject: [PATCH 17/20] Fix bug (all 0 image) when translation amount is 0 --- ffcv/transforms/utils/fast_crop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 1569c9f7..02687fde 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -158,6 +158,9 @@ def sharpen(source, destination, amount): """ @njit(inline='always') def translate(source, destination, tx, ty): + if tx == 0 and ty == 0: + destination[:] = source + return if tx > 0: destination[:, tx:] = source[:, :-tx] destination[:, :tx] = 0 From 69adc182139fa2b4864adad20867d6b56b6c6f5e Mon Sep 17 00:00:00 2001 From: Asher Trockman Date: Fri, 18 Feb 2022 13:28:01 -0500 Subject: [PATCH 18/20] Solarize and sharpen need scratch memory --- ffcv/transforms/randaugment.py | 4 ++-- ffcv/transforms/utils/fast_crop.py | 20 ++++++++++---------- tests/test_rand_aug.py | 9 +++++---- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/ffcv/transforms/randaugment.py 
b/ffcv/transforms/randaugment.py
index 40042fcf..21723660 100644
--- a/ffcv/transforms/randaugment.py
+++ b/ffcv/transforms/randaugment.py
@@ -90,13 +90,13 @@ def randaug(im, mem):
                     adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i])

                 if idx == 9: # Sharpness
-                    sharpen(src[i], dst[i], 1.0 + mag)
+                    sharpen(src[i], scratch[i][0], 1.0 + mag, dst[i])

                 if idx == 10: # Posterize
                     posterize(src[i], int(mag), dst[i])

                 if idx == 11: # Solarize
-                    solarize(src[i], mag, dst[i])
+                    solarize(src[i], scratch[i][0], mag, dst[i])

                 if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s)
                     autocontrast(src[i], scratchf[i][0], dst[i])
diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py
index 02687fde..9417d834 100644
--- a/ffcv/transforms/utils/fast_crop.py
+++ b/ffcv/transforms/utils/fast_crop.py
@@ -86,9 +86,9 @@ def invert(source, destination):


 @njit(parallel=False, fastmath=True, inline='always')
-def solarize(source, threshold, destination):
-    invert(source, destination)
-    destination[:] = np.where(source >= threshold, destination, source)
+def solarize(source, scratch, threshold, destination):
+    invert(source, scratch)
+    destination[:] = np.where(source >= threshold, scratch, source)


 @njit(parallel=False, fastmath=True, inline='always')
@@ -136,19 +136,19 @@ def adjust_contrast(source, scratch, factor, destination):


 @njit(fastmath=True, inline='always')
-def sharpen(source, destination, amount):
+def sharpen(source, scratch, amount, destination):
     ctypes_unsharp_mask(source.ctypes.data,
-                        destination.ctypes.data,
+                        scratch.ctypes.data,
                         source.shape[0], source.shape[1])

     # in PyTorch's implementation,
     # the border is unaffected
-    destination[0,:] = source[0,:]
-    destination[1:,0] = source[1:,0]
-    destination[-1,:] = source[-1,:]
-    destination[1:-1,-1] = source[1:-1,-1]
+    scratch[0,:] = source[0,:]
+    scratch[1:,0] = source[1:,0]
+    scratch[-1,:] = source[-1,:]
+    scratch[1:-1,-1] = source[1:-1,-1]

-    blend(source, destination, amount, destination)
+    blend(source, scratch, amount, destination)


 """
diff --git a/tests/test_rand_aug.py b/tests/test_rand_aug.py
index 5de7ee14..826a75b1 100644
--- a/tests/test_rand_aug.py
+++ b/tests/test_rand_aug.py
@@ -100,7 +100,7 @@ def test_adjust_contrast(amt):
     #print(Ynp.min(), Ynp.max(), Ych.min(), Ych.max())


-@pytest.mark.parametrize('bits', [2])
+@pytest.mark.parametrize('bits', [2, 3, 4, 5, 6, 7])
 def test_posterize(bits):
     Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8)
     Ynp = np.zeros(Xnp.shape, dtype=np.uint8)
@@ -141,9 +141,10 @@ def test_solarize(threshold):
     Xnp[10:15,10:15,:] = 8
     Xnp[27:31,27:31,:] = 9
     Ynp = np.zeros(Xnp.shape, dtype=np.uint8)
+    Snp = np.zeros(Xnp.shape, dtype=np.uint8)
     Xch = torch.tensor(Xnp).permute(2, 0, 1)
     Ych = tv.functional.solarize(Xch, threshold).permute(1, 2, 0).numpy()
-    solarize(Xnp, threshold, Ynp)
+    solarize(Xnp, Snp, threshold, Ynp)

     plt.subplot(1, 2, 1)
     plt.imshow(Ynp)
@@ -202,7 +203,7 @@ def test_autocontrast():
     assert np.linalg.norm(Ynp.astype(np.float32) - Ych.astype(np.float32)) < 100


-@pytest.mark.parametrize('amt', [2.0])
+@pytest.mark.parametrize('amt', [0.5, 0.75, 1.0, 2.0])
 def test_sharpen(amt):
     Xnp = np.random.uniform(0, 256, size=(32, 32, 3)).astype(np.uint8)
     #Xnp = cv2.imread('example_imgs/0249.png')
     Ynp = np.zeros(Xnp.shape, dtype=np.uint8)
     Snp = np.zeros(Xnp.shape, dtype=np.uint8)
     Xch = torch.tensor(Xnp).permute(2, 0, 1)
     Ych = tv.functional.adjust_sharpness(Xch, amt).permute(1, 2, 0).numpy()
-    sharpen(Xnp, Ynp, amt)
+    sharpen(Xnp, Snp, amt, Ynp)

     plt.subplot(1, 2, 1)
     plt.imshow(Ynp)
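
Note on the change above: the commit message only states that solarize and sharpen need scratch memory; one plausible reason (not spelled out in the patch) is buffer aliasing inside randaug, where the second of the two sampled ops reads from dst and writes back into dst, so an invert written straight into the destination clobbers the very pixels np.where still has to read. A minimal numpy sketch of that failure mode, with made-up pixel values:

    import numpy as np

    src = np.array([10, 200, 30], dtype=np.uint8)

    # separate scratch buffer (the behaviour after this patch): correct solarize
    scratch = 255 - src
    np.where(src >= 128, scratch, src)   # -> [ 10,  55,  30]

    # aliased buffers (dst is src, as on the second sampled op): wrong result
    dst = src
    dst[:] = 255 - src                   # this assignment also rewrites src...
    np.where(src >= 128, dst, src)       # -> [245,  55, 225], originals are gone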
From 02792e3dd92ca22f28d48139c38e24186c9f2483 Mon Sep 17 00:00:00 2001
From: Asher Trockman
Date: Fri, 18 Feb 2022 19:23:01 -0500
Subject: [PATCH 19/20] Remove size argument

---
 ffcv/transforms/__init__.py    |  1 +
 ffcv/transforms/randaugment.py | 23 +++++++++++------------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py
index fa95e1d1..55e235cf 100644
--- a/ffcv/transforms/__init__.py
+++ b/ffcv/transforms/__init__.py
@@ -15,6 +15,7 @@
            'ToTorchImage', 'NormalizeImage',
            'Convert',  'Squeeze', 'View',
            'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate',
+           'RandAugment',
            'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot',
            'Poison', 'ReplaceLabel',
            'ModuleWrapper']
\ No newline at end of file
diff --git a/ffcv/transforms/randaugment.py b/ffcv/transforms/randaugment.py
index 21723660..5f9b0733 100644
--- a/ffcv/transforms/randaugment.py
+++ b/ffcv/transforms/randaugment.py
@@ -10,12 +10,10 @@

 class RandAugment(Operation):
     def __init__(self,
-                 size: int = 32,
                  num_ops: int = 2,
                  magnitude: int = 9,
                  num_magnitude_bins: int = 31):
         super().__init__()
-        self.size = size
         self.num_ops = num_ops
         self.magnitude = magnitude
         num_bins = num_magnitude_bins
@@ -25,14 +23,14 @@ def __init__(self,
             (0, "Identity", np.array(0.0), 1),
             (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1),
             (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1),
-            (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1),
-            (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0 * size, num_bins), -1),
+            (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0, num_bins), -1),
+            (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0, num_bins), -1),
             (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1),
             (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1),
             (7, "Color", np.linspace(0.0, 0.9, num_bins), -1),
             (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1),
             (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1),
-            (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round().astype('uint8'), 1),
+            (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round(), 1),
             (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1),
             (12, "AutoContrast", np.array(0.0), 1),
             (13, "Equalize", np.array(0.0), 1),
@@ -72,10 +70,10 @@ def randaug(im, mem):
                     shear(src[i], dst[i], 0, mag)

                 if idx == 3: # TranslateX
-                    translate(src[i], dst[i], int(mag), 0)
+                    translate(src[i], dst[i], int(src[i].shape[1] * mag), 0)

                 if idx == 4: # TranslateY
-                    translate(src[i], dst[i], 0, int(mag))
+                    translate(src[i], dst[i], 0, int(src[i].shape[0] * mag))

                 if idx == 5: # Rotate
                     rotate(src[i], dst[i], mag)
@@ -111,9 +109,10 @@ def randaug(im, mem):

     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
         assert previous_state.jit_mode
-        return replace(previous_state, shape=(self.size, self.size, 3)), [
-            AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')),
-            AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('uint8')),
-            AllocationQuery((3, 256), dtype=np.dtype('int16')),
-            AllocationQuery((1, self.size, self.size, 3), dtype=np.dtype('float32')),
+        h, w, c = previous_state.shape
+        return replace(previous_state, shape=previous_state.shape), [
+            AllocationQuery(previous_state.shape, dtype=np.dtype('uint8')),
+            AllocationQuery((1, h, w, c), dtype=np.dtype('uint8')),
+            AllocationQuery((c, 256), dtype=np.dtype('int16')),
+            AllocationQuery((1, h, w, c), dtype=np.dtype('float32')),
         ]
\ No newline at end of file
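
With the size argument removed, RandAugment takes its buffer shapes from the incoming pipeline state, so a caller only chooses num_ops and magnitude. A minimal usage sketch follows; it is editorial rather than part of the patch, the dataset path and loader settings are placeholders, and the decoder/loader classes are the ones already used in tests/test_rand_aug.py:

    from ffcv.fields.decoders import SimpleRGBImageDecoder
    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import RandAugment, ToTensor, ToTorchImage

    # decode -> apply 2 random ops at magnitude 9 -> convert for PyTorch
    image_pipeline = [
        SimpleRGBImageDecoder(),
        RandAugment(num_ops=2, magnitude=9),
        ToTensor(),
        ToTorchImage(),
    ]

    loader = Loader('train.beton',              # placeholder dataset file
                    batch_size=512,
                    num_workers=8,
                    order=OrderOption.RANDOM,
                    drop_last=True,
                    pipelines={'image': image_pipeline})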
From 0fa3d7c8a041fec7348025afafe6e4f6fff60f33 Mon Sep 17 00:00:00 2001
From: Asher Trockman
Date: Thu, 16 Feb 2023 20:24:12 -0500
Subject: [PATCH 20/20] Update graph.py

---
 ffcv/pipeline/graph.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ffcv/pipeline/graph.py b/ffcv/pipeline/graph.py
index b7f6b9d5..584a45b2 100644
--- a/ffcv/pipeline/graph.py
+++ b/ffcv/pipeline/graph.py
@@ -324,7 +324,7 @@ def collect_requirements(self, state=INITIAL_STATE,
             next_state, allocation = operation.declare_state_and_memory(state)
             state_allocation = operation.declare_shared_memory(state)

-            if next_state.device.type != 'cuda' and isinstance(operation,
+            if next_state.device != 'cuda' and isinstance(operation,
                                                                ModuleWrapper):
                 msg = ("Using a pytorch transform on the CPU is extremely"
                        "detrimental to the performance, consider moving the augmentation"
@@ -479,4 +479,4 @@ def codegen_all(self, code):
             code_stages.append(self.codegen_stage(stage, s_ix,
                                                   op_to_node, code, already_defined))
         final_output = [x.id for x in self.leaf_nodes.values()]
-        return code_stages, final_output
\ No newline at end of file
+        return code_stages, final_output
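
A closing note on the guard added earlier in this series (PATCH 17, "Fix bug (all 0 image) when translation amount is 0"): the visible branches of translate only copy pixels when tx or ty is non-zero, so a zero offset apparently left the destination scratch buffer untouched, which is how the all-zero image described in that commit message arises. A small sketch of the guarded behaviour, where the zeroed buffer stands in for how the scratch allocation starts out:

    import numpy as np

    src = np.full((4, 4), 7, dtype=np.uint8)
    dst = np.zeros_like(src)        # scratch buffers start out zeroed
    tx = ty = 0

    # guard from PATCH 17: a zero translation becomes an explicit identity copy
    if tx == 0 and ty == 0:
        dst[:] = src

    assert (dst == src).all()       # without the guard, dst stayed all zero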