diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9b38853 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/athena-protobufs b/athena-protobufs index 14f6918..8d64a12 160000 --- a/athena-protobufs +++ b/athena-protobufs @@ -1 +1 @@ -Subproject commit 14f6918a81becbe5d3418d1bee878bbb35548eb9 +Subproject commit 8d64a12636cca9de864a0f51e4fad7afe73ead4c diff --git a/examples/classify_single_example.py b/examples/classify_single_example.py index 346d3db..1f1d305 100755 --- a/examples/classify_single_example.py +++ b/examples/classify_single_example.py @@ -228,7 +228,7 @@ async def main() -> int: resize_images=True, compress_images=True, timeout=30.0, # Shorter timeout for single requests - affiliate="Crisp", + affiliate=os.getenv("ATHENA_AFFILIATE", "athena-test"), deployment_id="single-example-deployment", # Not used ) diff --git a/examples/utils/streaming_classify_utils.py b/examples/utils/streaming_classify_utils.py index 83ab3bd..8fb8cec 100644 --- a/examples/utils/streaming_classify_utils.py +++ b/examples/utils/streaming_classify_utils.py @@ -217,7 +217,6 @@ async def classify_images_break_on_first_result( ) error_count = process_errors(logger, result, error_count) - break except Exception: diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 88385d5..6401044 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -60,7 +60,7 @@ def _create_base_test_image_opencv(width: int, height: int) -> np.ndarray: ] -@pytest_asyncio.fixture +@pytest_asyncio.fixture(scope="session") async def credential_helper() -> CredentialHelper: _ = load_dotenv() client_id = os.environ["OAUTH_CLIENT_ID"] diff --git a/tests/functional/e2e/__init__.py b/tests/functional/e2e/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/functional/e2e/test_classify_single.py b/tests/functional/e2e/test_classify_single.py new file mode 100644 index 0000000..b3c2dd4 --- /dev/null +++ b/tests/functional/e2e/test_classify_single.py @@ -0,0 +1,70 @@ +from pathlib import Path + +import pytest + +from resolver_athena_client.client.athena_client import AthenaClient +from resolver_athena_client.client.athena_options import AthenaOptions +from resolver_athena_client.client.channel import ( + CredentialHelper, + create_channel_with_credentials, +) +from resolver_athena_client.client.models import ImageData +from tests.functional.e2e.testcases.parser import ( + AthenaTestCase, + load_test_cases, +) + +TEST_CASES = load_test_cases("integrator_sample") + +FP_ERROR_TOLERANCE = 1e-4 + + +@pytest.mark.asyncio +@pytest.mark.functional +@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda tc: tc.id) +async def test_classify_single( + athena_options: AthenaOptions, + credential_helper: CredentialHelper, + test_case: AthenaTestCase, +) -> None: + """Functional test for ClassifySingle endpoint and API methods. + + This test reads the image file of each test case and classifies it. 
+ + """ + + # Create gRPC channel with credentials + channel = await create_channel_with_credentials( + athena_options.host, credential_helper + ) + with Path.open(Path(test_case.filepath), "rb") as f: + image_bytes = f.read() + + async with AthenaClient(channel, athena_options) as client: + image_data = ImageData(image_bytes) + + # Classify with auto-generated correlation ID + result = await client.classify_single(image_data) + + if result.error.code: + msg = f"Image Result Error: {result.error.message}" + pytest.fail(msg) + + actual_output = {c.label: c.weight for c in result.classifications} + assert set(test_case.expected_output.keys()).issubset( + set(actual_output.keys()) + ), ( + "Expected output to contain labels: ", + f"{test_case.expected_output.keys() - actual_output.keys()}", + ) + actual_output = {k: actual_output[k] for k in test_case.expected_output} + + for label in test_case.expected_output: + expected = test_case.expected_output[label] + actual = actual_output[label] + diff = abs(expected - actual) + assert diff < FP_ERROR_TOLERANCE, ( + f"Weight for label '{label}' differs by more than " + f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, " + f"diff={diff}" + ) diff --git a/tests/functional/e2e/testcases/__init__.py b/tests/functional/e2e/testcases/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/functional/e2e/testcases/parser.py b/tests/functional/e2e/testcases/parser.py new file mode 100644 index 0000000..80d7901 --- /dev/null +++ b/tests/functional/e2e/testcases/parser.py @@ -0,0 +1,38 @@ +import json +from pathlib import Path + +# Path to the shared testcases directory in athena-protobufs +_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent +TESTCASES_DIR = _REPO_ROOT / "athena-protobufs" / "testcases" + + +class AthenaTestCase: + def __init__( + self, + filepath: str, + expected_output: list[float], + classification_labels: list[str], + ) -> None: + self.id: str = "/".join( + 
Path(filepath).parts[-2:] + ) # e.g. "ducks/duck1.jpg" + self.filepath: str = filepath + self.expected_output: dict[str, float] = dict( + zip(classification_labels, expected_output, strict=True) + ) + self.classification_labels: list[str] = classification_labels + + +def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]: + with Path.open( + Path(TESTCASES_DIR / dirname / "expected_outputs.json"), + ) as f: + test_cases = json.load(f) + return [ + AthenaTestCase( + str(Path(TESTCASES_DIR / dirname / "images" / item[0])), + item[1], + test_cases["classification_labels"], + ) + for item in test_cases["images"] + ] diff --git a/tests/functional/test_classify_streaming.py b/tests/functional/test_classify_streaming.py index 9805e8f..e007c41 100644 --- a/tests/functional/test_classify_streaming.py +++ b/tests/functional/test_classify_streaming.py @@ -38,7 +38,7 @@ async def test_streaming_classify( logger = logging.getLogger(__name__) # Configuration - max_test_images = int(os.getenv("TEST_IMAGE_COUNT", str(5_000))) + max_test_images = int(os.getenv("TEST_IMAGE_COUNT", str(50))) min_interval_ms = os.getenv("TEST_MIN_INTERVAL_MS", None) if min_interval_ms is not None: min_interval_ms = int(min_interval_ms) @@ -59,6 +59,9 @@ async def test_streaming_classify( assert sent == received, f"Incomplete: {sent} sent, {received} received" +@pytest.mark.skip( + reason="Relies on server-side shared queue behavior - needs investigation" +) @pytest.mark.asyncio @pytest.mark.functional async def test_streaming_classify_with_reopened_stream( diff --git a/tests/functional/test_color_channels.py b/tests/functional/test_color_channels.py new file mode 100644 index 0000000..13f9eb2 --- /dev/null +++ b/tests/functional/test_color_channels.py @@ -0,0 +1,110 @@ +"""Functional test for classifying images with specific color channels.""" + +import numpy as np +import pytest + +from resolver_athena_client.client.athena_client import AthenaClient +from 
resolver_athena_client.client.athena_options import AthenaOptions +from resolver_athena_client.client.channel import ( + CredentialHelper, + create_channel_with_credentials, +) +from resolver_athena_client.client.consts import EXPECTED_HEIGHT, EXPECTED_WIDTH +from resolver_athena_client.client.models import ImageData + + +def create_color_channel_image( + channel: str, width: int = EXPECTED_WIDTH, height: int = EXPECTED_HEIGHT +) -> bytes: + """Create a raw BGR image with only one channel set to 255. + + Args: + channel: Color channel to set - 'red', 'green', or 'blue' + width: Image width in pixels + height: Image height in pixels + + Returns: + Raw BGR uint8 image bytes + + """ + # Create BGR image (3 channels) + img = np.zeros((height, width, 3), dtype=np.uint8) + + # Set the specified channel to 255 + if channel == "red": + img[:, :, 2] = 255 # Red is channel 2 in BGR + elif channel == "green": + img[:, :, 1] = 255 # Green is channel 1 in BGR + elif channel == "blue": + img[:, :, 0] = 255 # Blue is channel 0 in BGR + else: + msg = f"Invalid channel: {channel}. Must be 'red', 'green', or 'blue'" + raise ValueError(msg) + + return img.tobytes() + + +@pytest.mark.asyncio +@pytest.mark.functional +async def test_classify_color_channels( + athena_options: AthenaOptions, credential_helper: CredentialHelper +) -> None: + """Test classification of three images with distinct color channels. 
+ + Creates and classifies three 448x448x3 images: + - Red image: R=255, G=0, B=0 + - Green image: R=0, G=255, B=0 + - Blue image: R=0, G=0, B=255 + """ + # Create gRPC channel with credentials + channel = await create_channel_with_credentials( + athena_options.host, credential_helper + ) + + async with AthenaClient(channel, athena_options) as client: + # Test red channel image + red_image_bytes = create_color_channel_image("red") + red_image_data = ImageData(red_image_bytes) + + red_result = await client.classify_single(red_image_data) + + if red_result.error.code: + msg = f"Red image classification error: {red_result.error.message}" + pytest.fail(msg) + + assert len(red_result.classifications) > 0, ( + "No classifications for red image" + ) + + # Test green channel image + green_image_bytes = create_color_channel_image("green") + green_image_data = ImageData(green_image_bytes) + + green_result = await client.classify_single(green_image_data) + + if green_result.error.code: + msg = ( + "Green image classification error: " + f"{green_result.error.message}" + ) + pytest.fail(msg) + + assert len(green_result.classifications) > 0, ( + "No classifications for green image" + ) + + # Test blue channel image + blue_image_bytes = create_color_channel_image("blue") + blue_image_data = ImageData(blue_image_bytes) + + blue_result = await client.classify_single(blue_image_data) + + if blue_result.error.code: + msg = ( + f"Blue image classification error: {blue_result.error.message}" + ) + pytest.fail(msg) + + assert len(blue_result.classifications) > 0, ( + "No classifications for blue image" + ) diff --git a/tests/utils/streaming_classify_utils.py b/tests/utils/streaming_classify_utils.py index e1c4bdd..a8b4082 100644 --- a/tests/utils/streaming_classify_utils.py +++ b/tests/utils/streaming_classify_utils.py @@ -210,7 +210,6 @@ async def classify_images_break_on_first_result( ) error_count = process_errors(logger, result, error_count) - break except Exception: