Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ You will need:
- An Athena host URL.
- An OAuth client ID and secret with access to the Athena environment.
- An affiliate with Athena enabled.
- `imagemagick` installed on your system and on your path at `magick`.


#### Preparing your environment
Expand Down
5 changes: 5 additions & 0 deletions common_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Utility package for testing and examples.

This package contains helper functions that are not core to the client library,
but are shared across the examples and the tests.
"""
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
"""Ultra-fast random image creation utilities for maximum throughput."""
"""Ultra-fast random image creation utilities for maximum throughput.

This file is intended to be used for generating benign test images for the
purposes of integration testing the client, and is provided as a convenience
for API consumers.
"""

import asyncio
import io
import random
import time
from collections.abc import AsyncIterator

from PIL import Image, ImageDraw
import cv2 as cv
import numpy as np

from resolver_athena_client.client.models import ImageData

# Global cache for reusable objects and constants
_image_cache: dict[
tuple[int, int], tuple[Image.Image, ImageDraw.ImageDraw]
] = {}
_image_cache: dict[tuple[int, int], np.ndarray] = {}
_rng = random.Random() # noqa: S311 - Not used for cryptographic purposes


def _get_cached_image(
width: int, height: int
) -> tuple[Image.Image, ImageDraw.ImageDraw]:
"""Get cached image and draw objects, creating if needed."""
def _get_cached_image(width: int, height: int) -> np.ndarray:
"""Get cached image array, creating if needed."""
key = (width, height)
if key not in _image_cache:
img = Image.new("RGB", (width, height), (0, 0, 0))
draw = ImageDraw.Draw(img)
_image_cache[key] = (img, draw)
img = np.zeros((height, width, 3), dtype=np.uint8)
_image_cache[key] = img
return _image_cache[key]


Expand All @@ -45,8 +45,9 @@ def create_random_image(
PNG image bytes

"""
# Get cached image and draw objects
image, draw = _get_cached_image(width, height)
# Get cached image array
image = _get_cached_image(width, height)
img = image.copy()

# Random background color
bg_r, bg_g, bg_b = (
Expand All @@ -56,21 +57,24 @@ def create_random_image(
)

# Fill with background color
draw.rectangle([0, 0, width, height], fill=(bg_r, bg_g, bg_b))
img[:, :] = (bg_b, bg_g, bg_r) # OpenCV uses BGR

# Add single accent rectangle for visual variation
accent_color = (255 - bg_r, 255 - bg_g, 255 - bg_b)
accent_color = (255 - bg_b, 255 - bg_g, 255 - bg_r) # BGR
x1, y1 = width // 4, height // 4
x2, y2 = (width * 3) // 4, (height * 3) // 4
draw.rectangle([x1, y1, x2, y2], fill=accent_color)
img = cv.rectangle(img, (x1, y1), (x2, y2), accent_color, thickness=-1)

if img_format.upper() == "RAW_UINT8":
return image.tobytes()
return img.tobytes()

# Convert to PNG bytes
buffer = io.BytesIO()
image.save(buffer, format=img_format)
return buffer.getvalue()
# Convert to PNG/JPEG bytes
ext = f".{img_format.lower()}"
success, buf = cv.imencode(ext, img)
if not success:
err = f"Failed to encode image as {img_format}"
raise RuntimeError(err)
return buf.tobytes()


def create_batch_images(
Expand All @@ -90,29 +94,32 @@ def create_batch_images(

"""
images: list[bytes] = []
image, draw = _get_cached_image(width, height)
image = _get_cached_image(width, height)

# Pre-calculate accent rectangle coordinates
x1, y1 = width // 4, height // 4
x2, y2 = (width * 3) // 4, (height * 3) // 4

for _ in range(count):
img = image.copy()
# Random background
bg_r, bg_g, bg_b = (
_rng.randint(0, 255),
_rng.randint(0, 255),
_rng.randint(0, 255),
)
draw.rectangle([0, 0, width, height], fill=(bg_r, bg_g, bg_b))
img[:, :] = (bg_b, bg_g, bg_r) # OpenCV uses BGR

# Complement accent color
accent_color = (255 - bg_r, 255 - bg_g, 255 - bg_b)
draw.rectangle([x1, y1, x2, y2], fill=accent_color)
accent_color = (255 - bg_b, 255 - bg_g, 255 - bg_r) # BGR
img = cv.rectangle(img, (x1, y1), (x2, y2), accent_color, thickness=-1)

# Convert to PNG bytes
buffer = io.BytesIO()
image.save(buffer, format="PNG")
images.append(buffer.getvalue())
success, buf = cv.imencode(".png", img)
if not success:
msg = "Failed to encode image as PNG"
raise RuntimeError(msg)
images.append(buf.tobytes())

return images

Expand Down Expand Up @@ -196,18 +203,7 @@ async def rate_limited_image_iter(
def create_random_image_generator(
max_images: int, rate_limit_min_interval_ms: int | None = None
) -> AsyncIterator[ImageData]:
"""Generate a stream of random test images.

Args:
----
max_images: Maximum number of images to generate
rate_limit_min_interval_ms: Minimum interval in ms between images

Yields:
------
ImageData objects containing random image bytes

"""
"""Create an async generator for images with optional rate limiting."""
if rate_limit_min_interval_ms is not None:
return rate_limited_image_iter(rate_limit_min_interval_ms, max_images)

Expand Down
2 changes: 1 addition & 1 deletion docs/api/transformers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ Transformers accept configuration through their constructors:
**ImageResizer Configuration:**

* ``target_size``: Tuple of (width, height) for output dimensions
* ``resampling``: PIL resampling algorithm (default: ``Image.LANCZOS``)
* ``resampling``: OpenCV resampling algorithm (default: ``cv.INTER_LINEAR``)
* ``maintain_aspect_ratio``: Whether to preserve aspect ratio (default: ``True``)

**BrotliCompressor Configuration:**
Expand Down
2 changes: 1 addition & 1 deletion examples/classify_single_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from dotenv import load_dotenv

from examples.utils.image_generation import create_test_image
from common_utils.image_generation import create_test_image
from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
from resolver_athena_client.client.channel import (
Expand Down
2 changes: 1 addition & 1 deletion examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from dotenv import load_dotenv

from examples.utils.image_generation import iter_images
from common_utils.image_generation import iter_images
from examples.utils.streaming_classify_utils import count_and_yield
from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
"grpcio-tools>=1.74.0",
"httpx>=0.25.0",
"numpy>=2.2.6",
"pillow>=11.3.0",
"opencv-python-headless>=4.13.0.92"
]

[project.optional-dependencies]
Expand Down
9 changes: 6 additions & 3 deletions src/resolver_athena_client/client/athena_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from dataclasses import dataclass

from PIL.Image import Resampling

from resolver_athena_client.client.correlation import (
CorrelationProvider,
HashCorrelationProvider,
)
from resolver_athena_client.client.transformers.core import (
OpenCVResamplingAlgorithm,
)


@dataclass
Expand Down Expand Up @@ -69,4 +70,6 @@ class AthenaOptions:
timeout: float | None = 120.0
keepalive_interval: float | None = None
compression_quality: int = 11 # Brotli quality level (0-11)
resampling_algorithm: Resampling = Resampling.LANCZOS
resampling_algorithm: OpenCVResamplingAlgorithm = (
OpenCVResamplingAlgorithm.BILINEAR
)
64 changes: 36 additions & 28 deletions src/resolver_athena_client/client/transformers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
"""

import asyncio
from io import BytesIO
import enum

import brotli
from PIL import Image
import cv2 as cv
import numpy as np

from resolver_athena_client.client.consts import EXPECTED_HEIGHT, EXPECTED_WIDTH
from resolver_athena_client.client.models import ImageData
Expand All @@ -20,14 +21,29 @@
_expected_raw_size = EXPECTED_WIDTH * EXPECTED_HEIGHT * 3


class OpenCVResamplingAlgorithm(enum.Enum):
    """OpenCV resampling configuration.

    Enum for ease of configuration and type-safety when selecting OpenCV
    resampling algorithms. Each member's value is the corresponding
    ``cv.INTER_*`` interpolation flag, so ``member.value`` can be passed
    as the ``interpolation`` argument of ``cv.resize``.
    """

    # Nearest-neighbour interpolation.
    NEAREST = cv.INTER_NEAREST
    # Area-based resampling (exposed here under the name BOX).
    BOX = cv.INTER_AREA
    # Bilinear interpolation.
    BILINEAR = cv.INTER_LINEAR
    # Lanczos interpolation (OpenCV's INTER_LANCZOS4 flag).
    LANCZOS = cv.INTER_LANCZOS4


def _is_raw_bgr_expected_size(data: bytes) -> bool:
    """Return True if ``data`` has the length of a raw expected-size image.

    Only the byte count is compared against ``_expected_raw_size``; the
    payload itself is not inspected.
    """
    byte_count = len(data)
    return byte_count == _expected_raw_size


async def resize_image(
image_data: ImageData,
sampling_algorithm: Image.Resampling = Image.Resampling.LANCZOS,
sampling_algorithm: OpenCVResamplingAlgorithm = (
OpenCVResamplingAlgorithm.BILINEAR
),
) -> ImageData:
"""Resize an image to expected dimensions.

Expand All @@ -49,31 +65,23 @@ def process_image() -> tuple[bytes, bool]:
return image_data.data, False # No transformation needed

# Try to load the image data directly
input_buffer = BytesIO(image_data.data)

with Image.open(input_buffer) as image:
# Convert to RGB if needed
rgb_image = image.convert("RGB") if image.mode != "RGB" else image

# Resize if needed
if rgb_image.size != _target_size:
resized_image = rgb_image.resize(
_target_size, sampling_algorithm
)
else:
resized_image = rgb_image

rgb_bytes = resized_image.tobytes()

# Convert RGB to BGR by swapping channels
bgr_bytes = bytearray(len(rgb_bytes))

for i in range(0, len(rgb_bytes), 3):
bgr_bytes[i] = rgb_bytes[i + 2]
bgr_bytes[i + 1] = rgb_bytes[i + 1]
bgr_bytes[i + 2] = rgb_bytes[i]

return bytes(bgr_bytes), True # Data was transformed
img_data_buf = np.frombuffer(image_data.data, dtype=np.uint8)
img = cv.imdecode(img_data_buf, cv.IMREAD_COLOR)

if img is None:
err = "Failed to decode image data for resizing"
raise ValueError(err)

if img.shape[0] == EXPECTED_HEIGHT and img.shape[1] == EXPECTED_WIDTH:
resized_img = img
else:
resized_img = cv.resize(
img, _target_size, interpolation=sampling_algorithm.value
)

# OpenCV loads in BGR format by default, so we can directly convert to
# bytes
return resized_img.tobytes(), True # Data was transformed

# Use thread pool for CPU-intensive processing
resized_bytes, was_transformed = await asyncio.to_thread(process_image)
Expand Down
28 changes: 20 additions & 8 deletions tests/client/transformers/test_core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Test core transformation functions."""

from io import BytesIO

import cv2 as cv
import numpy as np
import pytest
from PIL import Image

from resolver_athena_client.client.consts import (
EXPECTED_HEIGHT,
Expand All @@ -19,11 +18,24 @@
def create_test_image(
width: int = 100, height: int = 100, mode: str = "RGB"
) -> bytes:
"""Create a test image with specified dimensions."""
img = Image.new(mode, (width, height), color="red")
img_bytes = BytesIO()
img.save(img_bytes, format="PNG")
return img_bytes.getvalue()
"""Create a test image with specified dimensions using OpenCV."""
# Map mode to OpenCV color shape
if mode == "RGB":
color = (255, 0, 0) # Red in RGB
img = np.full((height, width, 3), color, dtype=np.uint8)
elif mode == "L":
color = 76 # Red in grayscale
img = np.full((height, width), color, dtype=np.uint8)
else:
err = f"Unsupported mode: {mode}"
raise ValueError(err)

success, buf = cv.imencode(".png", img)
if not success:
err = "Failed to encode image to PNG"
raise RuntimeError(err)

return buf.tobytes()


@pytest.mark.asyncio
Expand Down
19 changes: 13 additions & 6 deletions tests/client/transformers/test_hash_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Tests for hash list behavior throughout the transformation pipeline."""

import hashlib
from io import BytesIO

import cv2 as cv
import numpy as np
import pytest
from PIL import Image

from resolver_athena_client.client.consts import EXPECTED_HEIGHT, EXPECTED_WIDTH
from resolver_athena_client.client.models import ImageData
Expand All @@ -21,10 +21,17 @@

def create_test_png_image(width: int = 200, height: int = 200) -> bytes:
"""Create a test PNG image with specified dimensions."""
img = Image.new("RGB", (width, height), color=(255, 0, 0))
buffer = BytesIO()
img.save(buffer, format="PNG")
return buffer.getvalue()

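# Create a solid-colour image using numpy. NOTE: cv.imencode treats the
# channel order as BGR, so (255, 0, 0) is encoded as blue, not red.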
img = np.zeros((height, width, 3), dtype=np.uint8)
img[:] = (255, 0, 0) # Red color

# Encode image as PNG to memory
success, buffer = cv.imencode(".png", img)
if not success:
err = "Failed to encode image as PNG"
raise RuntimeError(err)
return buffer.tobytes()


@pytest.mark.asyncio
Expand Down
Loading