Skip to content
Open
3 changes: 2 additions & 1 deletion examples/cifar/train_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def make_dataloaders(train_dataset=None, val_dataset=None, batch_size=None, num_
Convert(ch.float16),
torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

ordering = OrderOption.RANDOM if name == 'train' else OrderOption.SEQUENTIAL

loaders[name] = Loader(paths[name], batch_size=batch_size, num_workers=num_workers,
Expand Down Expand Up @@ -145,6 +145,7 @@ def construct_model():
model = model.to(memory_format=ch.channels_last).cuda()
return model


@param('training.lr')
@param('training.epochs')
@param('training.momentum')
Expand Down
71 changes: 46 additions & 25 deletions ffcv/fields/rgb_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from functools import partial
from dataclasses import replace
from typing import Optional, Callable, TYPE_CHECKING, Tuple, Type

Expand All @@ -23,8 +24,11 @@
IMAGE_MODES['raw'] = 1


def encode_jpeg(numpy_image, quality):
numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
def encode_jpeg(numpy_image, quality, is_rgb):
if is_rgb:
numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)

# TODO: this function assumes RGB input
success, result = cv2.imencode('.jpg', numpy_image,
[int(cv2.IMWRITE_JPEG_QUALITY), quality])

Expand Down Expand Up @@ -86,7 +90,9 @@ class SimpleRGBImageDecoder(Operation):

It only supports datasets with a constant image resolution and will simply read (potentially decompress) and pass the images as is.
"""
def __init__(self, is_rgb=True):
    """Create the decoder and record the colour layout of the images.

    Parameters
    ----------
    is_rgb : bool
        ``True`` for three-channel images, ``False`` for single-channel
        images. Defaults to ``True`` so that call sites written as
        ``SimpleRGBImageDecoder()`` keep working.
    """
    self.is_rgb = is_rgb
    # Number of channels in the last axis of the decoded image buffer.
    self.channels = 3 if is_rgb else 1
    super().__init__()

def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
Expand All @@ -102,7 +108,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
instead."""
raise TypeError(msg)

biggest_shape = (max_height, max_width, 3)
biggest_shape = (max_height, max_width, self.channels)
my_dtype = np.dtype('<u1')

return (
Expand All @@ -119,6 +125,7 @@ def generate_code(self) -> Callable:
raw = IMAGE_MODES['raw']
my_range = Compiler.get_iterator()
my_memcpy = Compiler.compile(memcpy)
is_rgb = self.is_rgb

def decode(batch_indices, destination, metadata, storage_state):
for dst_ix in my_range(len(batch_indices)):
Expand All @@ -128,8 +135,9 @@ def decode(batch_indices, destination, metadata, storage_state):
height, width = field['height'], field['width']

if field['mode'] == jpg:
imdecode_c(image_data, destination[dst_ix],
height, width, height, width, 0, 0, 1, 1, False, False)
imdecode_c(image_data, destination[dst_ix], height, width,
height, width, 0, 0, 1, 1, False, False,
is_rgb)
else:
my_memcpy(image_data, destination[dst_ix])

Expand All @@ -138,14 +146,17 @@ def decode(batch_indices, destination, metadata, storage_state):
decode.is_parallel = True
return decode

class SimpleGrayscaleImageDecoder(SimpleRGBImageDecoder):
    """:class:`SimpleRGBImageDecoder` preconfigured with ``is_rgb=False``.

    Takes no constructor arguments, so it can be instantiated the same way
    regardless of the colour mode it stands for.
    """

    def __init__(self):
        super().__init__(is_rgb=False)

class ResizedCropRGBImageDecoder(SimpleRGBImageDecoder, metaclass=ABCMeta):
"""Abstract decoder for :class:`~ffcv.fields.RGBImageField` that performs a crop and a resize operation.

It supports both variable and constant resolution datasets.
"""
def __init__(self, output_size, is_rgb=True):
    """Record the target output size and the colour layout.

    Parameters
    ----------
    output_size
        Indexable whose first two entries give the first two dimensions
        of the decoded output.
    is_rgb : bool
        Forwarded to :class:`SimpleRGBImageDecoder`. Defaults to ``True``
        so that callers that only pass ``output_size`` keep working.
    """
    super().__init__(is_rgb)
    self.output_size = output_size

def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
Expand All @@ -154,19 +165,19 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
# We convert to uint64 to avoid overflows
self.max_width = np.uint64(widths.max())
self.max_height = np.uint64(heights.max())
output_shape = (self.output_size[0], self.output_size[1], 3)
output_shape = (self.output_size[0], self.output_size[1], self.channels)
my_dtype = np.dtype('<u1')

channels = np.uint64(self.channels)
return (
replace(previous_state, jit_mode=True,
shape=output_shape, dtype=my_dtype),
(AllocationQuery(output_shape, my_dtype),
AllocationQuery((self.max_height * self.max_width * np.uint64(3),), my_dtype),
AllocationQuery((self.max_height * self.max_width * channels,), my_dtype),
)
)

def generate_code(self) -> Callable:

jpg = IMAGE_MODES['jpg']

mem_read = self.memory_read
Expand All @@ -177,6 +188,8 @@ def generate_code(self) -> Callable:

scale = self.scale
ratio = self.ratio
is_rgb = self.is_rgb
channels = self.channels
if isinstance(scale, tuple):
scale = np.array(scale)
if isinstance(ratio, tuple):
Expand All @@ -193,21 +206,21 @@ def decode(batch_indices, my_storage, metadata, storage_state):

if field['mode'] == jpg:
temp_buffer = temp_storage[dst_ix]
imdecode_c(image_data, temp_buffer,
height, width, height, width, 0, 0, 1, 1, False, False)
selected_size = 3 * height * width
imdecode_c(image_data, temp_buffer, height, width, height,
width, 0, 0, 1, 1, False, False, is_rgb)
selected_size = channels * height * width
temp_buffer = temp_buffer.reshape(-1)[:selected_size]
temp_buffer = temp_buffer.reshape(height, width, 3)

temp_buffer = temp_buffer.reshape(height, width, channels)
else:
temp_buffer = image_data.reshape(height, width, 3)
temp_buffer = image_data.reshape(height, width, channels)

i, j, h, w = get_crop_c(height, width, scale, ratio)

resize_crop_c(temp_buffer, i, i + h, j, j + w,
destination[dst_ix])
destination[dst_ix], is_rgb)

return destination[:len(batch_indices)]

decode.is_parallel = True
return decode

Expand All @@ -231,8 +244,9 @@ class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder):
ratio : Tuple[float]
The range of potential aspect ratios that can be randomly sampled
"""
def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3)):
super().__init__(output_size)
def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3),
is_rgb=True):
super().__init__(output_size, is_rgb=is_rgb)
self.scale = scale
self.ratio = ratio
self.output_size = output_size
Expand All @@ -255,8 +269,8 @@ class CenterCropRGBImageDecoder(ResizedCropRGBImageDecoder):
ratio of (crop size) / (min side length)
"""
# output size: resize crop size -> output size
def __init__(self, output_size, ratio, is_rgb=True):
    """Configure a center-crop decoder.

    Parameters
    ----------
    output_size
        Forwarded to the parent decoder as the size of the decoded output.
    ratio
        Stored as ``self.ratio``.
    is_rgb : bool
        Forwarded to the parent decoder; defaults to ``True``.
    """
    super().__init__(output_size, is_rgb=is_rgb)
    # No scale is configured for this decoder; only ``ratio`` is set.
    self.scale = None
    self.ratio = ratio

Expand Down Expand Up @@ -291,12 +305,13 @@ class RGBImageField(Field):
"""
def __init__(self, write_mode='raw', max_resolution: int = None,
             smart_threshold: int = None, jpeg_quality: int = 90,
             compress_probability: float = 0.5, is_rgb: bool = True) -> None:
    """Store the encoding options of this image field.

    Parameters
    ----------
    write_mode
        Stored as ``self.write_mode``; defaults to ``'raw'``.
    max_resolution : int
        Stored as ``self.max_resolution``.
    smart_threshold : int
        Stored as ``self.smart_threshold``.
    jpeg_quality : int
        Coerced with ``int()`` before being stored.
    compress_probability : float
        Stored as ``self.proportion``.
    is_rgb : bool
        ``True`` when the field holds three-channel images, ``False`` for
        two-dimensional (single-channel) images.
    """
    self.write_mode = write_mode
    self.smart_threshold = smart_threshold
    self.max_resolution = max_resolution
    self.jpeg_quality = int(jpeg_quality)
    # Note the attribute name differs from the constructor argument.
    self.proportion = compress_probability
    self.is_rgb = is_rgb

@property
def metadata_type(self) -> np.dtype:
Expand All @@ -308,7 +323,10 @@ def metadata_type(self) -> np.dtype:
])

def get_decoder_class(self) -> Type[Operation]:
    """Return the decoder class that matches this field's colour mode.

    Returns
    -------
    Type[Operation]
        ``SimpleRGBImageDecoder`` when ``self.is_rgb`` is true, otherwise
        ``SimpleGrayscaleImageDecoder``.
    """
    if self.is_rgb:
        return SimpleRGBImageDecoder  # TODO
    return SimpleGrayscaleImageDecoder  # TODO

@staticmethod
def from_binary(binary: ARG_TYPE) -> Field:
Expand All @@ -327,7 +345,10 @@ def encode(self, destination, image, malloc):
if image.dtype != np.uint8:
raise ValueError("Image type has to be uint8")

if image.shape[2] != 3:
shape = image.shape
is_ok_grayscale = len(shape) == 2 and not self.is_rgb
is_ok_rgb = len(shape) == 3 and shape[2] == 3 and self.is_rgb
if not (is_ok_rgb or is_ok_grayscale):
raise ValueError(f"Invalid shape for rgb image: {image.shape}")

assert image.dtype == np.uint8
Expand Down
13 changes: 7 additions & 6 deletions ffcv/libffcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,33 @@ def read(fileno:int, destination:np.ndarray, offset:int):


ctypes_resize = lib.resize
# Eleven 64-bit integer arguments followed by one boolean (the is_rgb flag
# passed last by resize_crop).
ctypes_resize.argtypes = (11 * [c_int64]) + [c_bool]

def resize_crop(source, start_row, end_row, start_col, end_col, destination,
                is_rgb=True):
    """Hand a crop window of ``source`` and a ``destination`` buffer to the native resize routine.

    Parameters
    ----------
    source
        Array exposing ``.ctypes.data`` and a ``shape`` whose first two
        entries are forwarded to the native call.
    start_row, end_row, start_col, end_col
        Bounds of the crop window, forwarded unchanged.
    destination
        Array exposing ``.ctypes.data`` and a ``shape`` whose first two
        entries are forwarded to the native call.
    is_rgb : bool
        Forwarded as the final (boolean) argument. Defaults to ``True``,
        matching ``imdecode``, so six-argument callers keep working.
    """
    # NOTE(review): the meaning of the leading 0 is not visible here -- confirm
    # against the native signature.
    ctypes_resize(0,
                  source.ctypes.data,
                  source.shape[0], source.shape[1],
                  start_row, end_row, start_col, end_col,
                  destination.ctypes.data,
                  destination.shape[0], destination.shape[1], is_rgb)

# Extract and define the interface of imdecode
ctypes_imdecode = lib.imdecode
# Argument order mirrors the call made in imdecode():
#   source pointer, source size, source height, source width,
#   destination pointer, crop height, crop width,
#   offset x, offset y, scale numerator, scale denominator,
#   enable_crop, do_flip, is_rgb
ctypes_imdecode.argtypes = [
    c_void_p, c_uint64, c_uint32, c_uint32, c_void_p, c_uint32, c_uint32,
    c_uint32, c_uint32, c_uint32, c_uint32, c_bool, c_bool, c_bool
]

def imdecode(source: np.ndarray, dst: np.ndarray,
             source_height: int, source_width: int,
             crop_height=None, crop_width=None,
             offset_x=0, offset_y=0, scale_factor_num=1, scale_factor_denom=1,
             enable_crop=False, do_flip=False, is_rgb=True):
    """Call the native ``imdecode`` routine on ``source``, writing into ``dst``.

    Parameters
    ----------
    source : np.ndarray
        Input buffer; its data pointer and ``size`` are passed to the
        native call.
    dst : np.ndarray
        Output buffer; only its data pointer is passed.
    source_height, source_width : int
        Dimensions of the source image.
    crop_height, crop_width
        Forwarded unchanged.
        NOTE(review): these are declared as ``c_uint32`` in the native
        signature, which does not accept ``None``; the ``None`` defaults
        presumably act only as placeholders and callers are expected to
        pass both values -- confirm.
    offset_x, offset_y, scale_factor_num, scale_factor_denom
        Forwarded unchanged.
    enable_crop, do_flip, is_rgb : bool
        Boolean flags forwarded unchanged, in this order.

    Returns
    -------
    Whatever the native ``imdecode`` call returns.
    """
    return ctypes_imdecode(source.ctypes.data, source.size,
                           source_height, source_width, dst.ctypes.data,
                           crop_height, crop_width, offset_x, offset_y,
                           scale_factor_num, scale_factor_denom,
                           enable_crop, do_flip, is_rgb)


ctypes_memcopy = lib.my_memcpy
Expand Down
1 change: 1 addition & 0 deletions ffcv/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __init__(self,
# We check if the user disabled this field
if operations is None:
continue

if not isinstance(operations[0], DecoderClass):
msg = "The first operation of the pipeline for "
msg += f"'{field_name}' has to be a subclass of "
Expand Down
1 change: 1 addition & 0 deletions ffcv/memory_managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, reader:Reader):
self.ptrs = self.ptrs[order]
self.sizes = self.sizes[order]

print('initi memory manager', len(self.ptrs), len(self.sizes))
self.ptr_to_size = dict(zip(self.ptrs, self.sizes))

# We extract the page number by shifting the address corresponding
Expand Down
3 changes: 2 additions & 1 deletion ffcv/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from .translate import RandomTranslate
from .mixup import ImageMixup, LabelMixup, MixupToOneHot
from .module import ModuleWrapper
from .colorjitter import ColorJitter

# Public names re-exported by this package.
__all__ = ['ToTensor', 'ToDevice',
           'ToTorchImage', 'NormalizeImage',
           'Convert', 'Squeeze', 'View',
           'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate',
           'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot',
           'Poison', 'ReplaceLabel',
           'ModuleWrapper', 'ColorJitter']
Loading