-
Notifications
You must be signed in to change notification settings - Fork 182
[WIP] Add RandAugment #154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v1.1.0
Are you sure you want to change the base?
Changes from all commits
19f4039
21e2512
cc3a99b
08b96b0
5b0359a
2025013
6b6432c
fc2bdf6
56e5cc3
5da59e8
784f3d0
e267538
6fc2362
4d2e1ed
9e3a2ee
f6ba479
d6074ba
8a2a9be
69adc18
02792e3
0fa3d7c
2019acc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| import numpy as np | ||
| from ffcv.pipeline.compiler import Compiler | ||
| from ffcv.pipeline.operation import Operation, AllocationQuery | ||
| from dataclasses import replace | ||
| from typing import Callable, Optional, Tuple | ||
| from ffcv.pipeline.state import State | ||
| from ffcv.transforms.utils.fast_crop import rotate, shear, blend, \ | ||
| adjust_contrast, posterize, invert, solarize, equalize, fast_equalize, \ | ||
| autocontrast, sharpen, adjust_saturation, translate, adjust_brightness | ||
|
|
||
| class RandAugment(Operation): | ||
| def __init__(self, | ||
| num_ops: int = 2, | ||
| magnitude: int = 9, | ||
| num_magnitude_bins: int = 31): | ||
| super().__init__() | ||
| self.num_ops = num_ops | ||
| self.magnitude = magnitude | ||
| num_bins = num_magnitude_bins | ||
| # index, name (for readability); bins, sign multiplier | ||
| # those with a -1 can have negative magnitude with probability 0.5 | ||
| self.op_table = [ | ||
| (0, "Identity", np.array(0.0), 1), | ||
| (1, "ShearX", np.linspace(0.0, 0.3, num_bins), -1), | ||
| (2, "ShearY", np.linspace(0.0, 0.3, num_bins), -1), | ||
| (3, "TranslateX", np.linspace(0.0, 150.0 / 331.0, num_bins), -1), | ||
| (4, "TranslateY", np.linspace(0.0, 150.0 / 331.0, num_bins), -1), | ||
| (5, "Rotate", np.linspace(0.0, 30.0, num_bins), -1), | ||
| (6, "Brightness", np.linspace(0.0, 0.9, num_bins), -1), | ||
| (7, "Color", np.linspace(0.0, 0.9, num_bins), -1), | ||
| (8, "Contrast", np.linspace(0.0, 0.9, num_bins), -1), | ||
| (9, "Sharpness", np.linspace(0.0, 0.9, num_bins), -1), | ||
| (10, "Posterize", 8 - (np.arange(num_bins) / ((num_bins - 1) / 4)).round(), 1), | ||
| (11, "Solarize", np.linspace(255.0, 0.0, num_bins), 1), | ||
| (12, "AutoContrast", np.array(0.0), 1), | ||
| (13, "Equalize", np.array(0.0), 1), | ||
| ] | ||
|
|
||
| def generate_code(self) -> Callable: | ||
| my_range = Compiler.get_iterator() | ||
| op_table = self.op_table | ||
| magnitudes = np.array([(op[2][self.magnitude] if op[2].ndim > 0 else 0) for op in self.op_table]) | ||
| is_signed = np.array([op[3] for op in self.op_table]) | ||
| num_ops = self.num_ops | ||
| # for i in range(len(magnitudes)): | ||
| # print(i, op_table[i][1], '%.3f'%magnitudes[i]) | ||
| def randaug(im, mem): | ||
| dst, scratch, lut, scratchf = mem | ||
| for i in my_range(im.shape[0]): | ||
| for n in range(num_ops): | ||
| if n == 0: | ||
| src = im | ||
| else: | ||
| src = dst | ||
|
|
||
| idx = np.random.randint(low=0, high=13+1) | ||
| mag = magnitudes[idx] | ||
| if np.random.random() < 0.5: | ||
| mag = mag * is_signed[idx] | ||
|
|
||
| # Not worth fighting numba at the moment. | ||
| # TODO | ||
| if idx == 0: | ||
| dst[i][:] = src[i] | ||
|
|
||
| if idx == 1: # ShearX (0.004) | ||
| shear(src[i], dst[i], mag, 0) | ||
|
|
||
| if idx == 2: # ShearY | ||
| shear(src[i], dst[i], 0, mag) | ||
|
|
||
| if idx == 3: # TranslateX | ||
| translate(src[i], dst[i], int(src[i].shape[1] * mag), 0) | ||
|
|
||
| if idx == 4: # TranslateY | ||
| translate(src[i], dst[i], 0, int(src[i].shape[2] * mag)) | ||
|
|
||
| if idx == 5: # Rotate | ||
| rotate(src[i], dst[i], mag) | ||
|
|
||
| if idx == 6: # Brightness | ||
| adjust_brightness(src[i], scratch[i][0], 1.0 + mag, dst[i]) | ||
|
|
||
| if idx == 7: # Color | ||
| adjust_saturation(src[i], scratch[i][0], 1.0 + mag, dst[i]) | ||
|
|
||
| if idx == 8: # Contrast | ||
| adjust_contrast(src[i], scratch[i][0], 1.0 + mag, dst[i]) | ||
|
|
||
| if idx == 9: # Sharpness | ||
| sharpen(src[i], scratch[i][0], 1.0 + mag, dst[i]) | ||
|
|
||
| if idx == 10: # Posterize | ||
| posterize(src[i], int(mag), dst[i]) | ||
|
|
||
| if idx == 11: # Solarize | ||
| solarize(src[i], scratch[i][0], mag, dst[i]) | ||
|
|
||
| if idx == 12: # AutoContrast (TODO: takes 0.04s -> 0.052s) (+0.01s) | ||
| autocontrast(src[i], scratchf[i][0], dst[i]) | ||
|
|
||
| if idx == 13: # Equalize (TODO: +0.008s) | ||
| equalize(src[i], lut[i], dst[i]) | ||
|
|
||
| return dst | ||
|
|
||
| randaug.is_parallel = True | ||
| return randaug | ||
|
|
||
| def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: | ||
| assert previous_state.jit_mode | ||
| h, w, c = previous_state.shape | ||
| return replace(previous_state, shape=previous_state.shape), [ | ||
| AllocationQuery(previous_state.shape, dtype=np.dtype('uint8')), | ||
| AllocationQuery((1, h, w, c), dtype=np.dtype('uint8')), | ||
| AllocationQuery((c, 256), dtype=np.dtype('int16')), | ||
| AllocationQuery((1, h, w, c), dtype=np.dtype('float32')), | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,195 @@ | ||
| import ctypes | ||
| from numba import njit | ||
| from numba import njit, prange | ||
| import numpy as np | ||
| from ...libffcv import ctypes_resize | ||
| from ...libffcv import ctypes_resize, ctypes_rotate, ctypes_shear, \ | ||
| ctypes_add_weighted, ctypes_equalize, ctypes_unsharp_mask | ||
|
|
||
| """ | ||
| Requires a float32 scratch array | ||
| """ | ||
| @njit(parallel=True, fastmath=True, inline='always') | ||
| def autocontrast(source, scratchf, destination): | ||
| # numba: no kwargs in min? as a consequence, I might as well have written | ||
| # this in C++ | ||
| # TODO assuming 3 channels | ||
| minimum = [source[..., 0].min(), source[..., 1].min(), source[..., 2].min()] | ||
| maximum = [source[..., 0].max(), source[..., 1].max(), source[..., 2].max()] | ||
| scale = [0.0, 0.0, 0.0] | ||
| for i in prange(source.shape[-1]): | ||
| if minimum[i] == maximum[i]: | ||
| scale[i] = 1 | ||
| minimum[i] = 0 | ||
| else: | ||
| scale[i] = 255. / (maximum[i] - minimum[i]) | ||
| for i in prange(source.shape[-1]): | ||
| scratchf[..., i] = source[..., i] - minimum[i] | ||
| scratchf[..., i] = scratchf[..., i] * scale[i] | ||
| np.clip(scratchf, 0, 255, out=scratchf) | ||
| destination[:] = scratchf | ||
|
|
||
|
|
||
| """ | ||
| Custom equalize -- equivalent to torchvision.transforms.functional.equalize, | ||
| but probably slow -- scratch is a (channels, 256) uint16 array. | ||
| """ | ||
| @njit(parallel=True, fastmath=True, inline='always') | ||
| def equalize(source, scratch, destination): | ||
| for i in prange(source.shape[-1]): | ||
| # TODO memory less than ideal for bincount() and hist() | ||
| scratch[i] = np.bincount(source[..., i].flatten(), minlength=256) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunate that
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A numba version should be pretty fast and relatively easy to implement no ? (and might even be faster since it would skip the first pass of bincount that checks the min and max values)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, good idea. I'll try to add that in the near future. |
||
| nonzero_hist = scratch[i][scratch[i] != 0] | ||
| step = nonzero_hist[:-1].sum() // 255 | ||
|
|
||
| if step == 0: | ||
| continue | ||
|
|
||
| scratch[i][1:] = scratch[i].cumsum()[:-1] | ||
| scratch[i] = (scratch[i] + step // 2) // step | ||
| scratch[i][0] = 0 | ||
| np.clip(scratch[i], 0, 255, out=scratch[i]) | ||
|
|
||
| # numba doesn't like 2d advanced indexing | ||
| for row in prange(source.shape[0]): | ||
| destination[row, :, i] = scratch[i][source[row, :, i]] | ||
|
|
||
| """ | ||
| Equalize using OpenCV -- not equivalent to | ||
| torchvision.transforms.functional.equalize for so-far-unknown reasons. | ||
| """ | ||
| @njit(parallel=False, fastmath=True, inline='always') | ||
| def fast_equalize(source, chw_scratch, destination): | ||
| # this seems kind of hacky | ||
| # also, assuming ctypes_equalize allocates a minimal amount of memory | ||
| # which may be incorrect -- so maybe we should do this from scratch. | ||
| # TODO may be a better way to do this in pure OpenCV | ||
| c, h, w = chw_scratch.shape | ||
| chw_scratch[0] = source[..., 0] | ||
| ctypes_equalize(chw_scratch.ctypes.data, | ||
| chw_scratch.ctypes.data, | ||
| h, w) | ||
| chw_scratch[1] = source[..., 1] | ||
| ctypes_equalize(chw_scratch.ctypes.data + h*w, | ||
| chw_scratch.ctypes.data + h*w, | ||
| h, w) | ||
| chw_scratch[2] = source[..., 2] | ||
| ctypes_equalize(chw_scratch.ctypes.data + 2*h*w, | ||
| chw_scratch.ctypes.data + 2*h*w, | ||
| h, w) | ||
| destination[..., 0] = chw_scratch[0] | ||
| destination[..., 1] = chw_scratch[1] | ||
| destination[..., 2] = chw_scratch[2] | ||
|
|
||
|
|
||
| @njit(parallel=False, fastmath=True, inline='always') | ||
| def invert(source, destination): | ||
| destination[:] = 255 - source | ||
|
|
||
|
|
||
| @njit(parallel=False, fastmath=True, inline='always') | ||
| def solarize(source, scratch, threshold, destination): | ||
| invert(source, scratch) | ||
| destination[:] = np.where(source >= threshold, scratch, source) | ||
|
|
||
|
|
||
| @njit(parallel=False, fastmath=True, inline='always') | ||
| def posterize(source, bits, destination): | ||
| mask = ~(2 ** (8 - bits) - 1) | ||
| destination[:] = source & mask | ||
|
|
||
|
|
||
| @njit(inline='always') | ||
| def blend(source1, source2, ratio, destination): | ||
| ctypes_add_weighted(source1.ctypes.data, ratio, | ||
| source2.ctypes.data, 1 - ratio, | ||
| destination.ctypes.data, | ||
| source1.shape[0], source1.shape[1]) | ||
|
|
||
|
|
||
| @njit(inline='always') | ||
| def adjust_brightness(source, scratch, factor, destination): | ||
| scratch[:] = 0 | ||
| blend(source, scratch, factor, destination) | ||
|
|
||
|
|
||
| @njit(parallel=False, fastmath=True, inline='always') | ||
| def adjust_saturation(source, scratch, factor, destination): | ||
| # TODO numpy autocasting probably allocates memory here, | ||
| # should be more careful. | ||
| # TODO do we really need scratch for this? could use destination | ||
| scratch[...,0] = 0.299 * source[..., 0] + \ | ||
| 0.587 * source[..., 1] + \ | ||
| 0.114 * source[..., 2] | ||
| scratch[...,1] = scratch[...,0] | ||
| scratch[...,2] = scratch[...,1] | ||
|
|
||
| blend(source, scratch, factor, destination) | ||
|
|
||
|
|
||
| @njit(parallel=False, fastmath=True, inline='always') | ||
| def adjust_contrast(source, scratch, factor, destination): | ||
| # TODO assuming 3 channels | ||
| scratch[:,:,:] = np.mean(0.299 * source[..., 0] + | ||
| 0.587 * source[..., 1] + | ||
| 0.114 * source[..., 2]) | ||
|
|
||
| blend(source, scratch, factor, destination) | ||
|
|
||
|
|
||
| @njit(fastmath=True, inline='always') | ||
| def sharpen(source, scratch, amount, destination): | ||
| ctypes_unsharp_mask(source.ctypes.data, | ||
| scratch.ctypes.data, | ||
| source.shape[0], source.shape[1]) | ||
|
|
||
| # in PyTorch's implementation, | ||
| # the border is unaffected | ||
| scratch[0,:] = source[0,:] | ||
| scratch[1:,0] = source[1:,0] | ||
| scratch[-1,:] = source[-1,:] | ||
| scratch[1:-1,-1] = source[1:-1,-1] | ||
|
|
||
| blend(source, scratch, amount, destination) | ||
|
|
||
|
|
||
| """ | ||
| Translation, x and y | ||
| Assuming this is faster than warpAffine; | ||
| also assuming tx and ty are ints | ||
| """ | ||
| @njit(inline='always') | ||
| def translate(source, destination, tx, ty): | ||
| if tx == 0 and ty == 0: | ||
| destination[:] = source | ||
| return | ||
| if tx > 0: | ||
| destination[:, tx:] = source[:, :-tx] | ||
| destination[:, :tx] = 0 | ||
| if tx < 0: | ||
| destination[:, :tx] = source[:, -tx:] | ||
| destination[:, tx:] = 0 | ||
| if ty > 0: | ||
| destination[ty:, :] = source[:-ty, :] | ||
| destination[:ty, :] = 0 | ||
| if ty < 0: | ||
| destination[:ty, :] = source[-ty:, :] | ||
| destination[ty:, :] = 0 | ||
|
|
||
|
|
||
| @njit(inline='always') | ||
| def rotate(source, destination, angle): | ||
| ctypes_rotate(angle, | ||
| source.ctypes.data, | ||
| destination.ctypes.data, | ||
| source.shape[0], source.shape[1]) | ||
|
|
||
|
|
||
| @njit(inline='always') | ||
| def shear(source, destination, shear_x, shear_y): | ||
| ctypes_shear(shear_x, shear_y, | ||
| source.ctypes.data, | ||
| destination.ctypes.data, | ||
| source.shape[0], source.shape[1]) | ||
|
|
||
|
|
||
| @njit(inline='always') | ||
| def resize_crop(source, start_row, end_row, start_col, end_col, destination): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think as of v1.0 the device will be a
torch.devicein which case we would wantnext_state.device.type?