diff --git a/agave/tools/__init__.py b/agave/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agave/tools/asyncio/__init__.py b/agave/tools/asyncio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agave/tools/asyncio/sqs_celery_client.py b/agave/tools/asyncio/sqs_celery_client.py new file mode 100644 index 00000000..ed649251 --- /dev/null +++ b/agave/tools/asyncio/sqs_celery_client.py @@ -0,0 +1,27 @@ +import asyncio +from dataclasses import dataclass +from typing import Iterable, Optional + +from ..celery import build_celery_message +from .sqs_client import SqsClient + + +@dataclass +class SqsCeleryClient(SqsClient): + async def send_task( + self, + name: str, + args: Optional[Iterable] = None, + kwargs: Optional[dict] = None, + ) -> None: + celery_message = build_celery_message(name, args or (), kwargs or {}) + await super().send_message(celery_message) + + def send_background_task( + self, + name: str, + args: Optional[Iterable] = None, + kwargs: Optional[dict] = None, + ) -> asyncio.Task: + celery_message = build_celery_message(name, args or (), kwargs or {}) + return super().send_message_async(celery_message) diff --git a/agave/tasks/sqs_client.py b/agave/tools/asyncio/sqs_client.py similarity index 82% rename from agave/tasks/sqs_client.py rename to agave/tools/asyncio/sqs_client.py index cb20d2fe..31e05683 100644 --- a/agave/tasks/sqs_client.py +++ b/agave/tools/asyncio/sqs_client.py @@ -9,8 +9,8 @@ from types_aiobotocore_sqs import SQSClient except ImportError: raise ImportError( - "You must install agave with [fastapi, tasks] option.\n" - "You can install it with: pip install agave[fastapi, tasks]" + "You must install agave with [asyncio_aws_tools] option.\n" + "You can install it with: pip install agave[asyncio_aws_tools]" ) @@ -25,11 +25,16 @@ class SqsClient: def background_tasks(self) -> set: return self._background_tasks - async def __aenter__(self): + async def __aenter__(self) -> "SqsClient": await self.start() return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: Optional[type], + exc_val: Optional[Exception], + exc_tb: Optional[object], + ) -> None: await self.close() async def start(self): diff --git a/agave/tasks/sqs_celery_client.py b/agave/tools/celery.py similarity index 62% rename from agave/tasks/sqs_celery_client.py rename to agave/tools/celery.py index 33f02bb5..647b926b 100644 --- a/agave/tasks/sqs_celery_client.py +++ b/agave/tools/celery.py @@ -1,14 +1,15 @@ -import asyncio import json from base64 import b64encode -from dataclasses import dataclass -from typing import Iterable, Optional +from typing import Iterable from uuid import uuid4 -from agave.tasks.sqs_client import SqsClient +def _b64_encode(value: str) -> str: + encoded = b64encode(bytes(value, 'utf-8')) + return encoded.decode('utf-8') -def _build_celery_message( + +def build_celery_message( task_name: str, args_: Iterable, kwargs_: dict ) -> str: task_id = str(uuid4()) @@ -47,29 +48,3 @@ def _build_celery_message( encoded = _b64_encode(json.dumps(message)) return encoded - - -def _b64_encode(value: str) -> str: - encoded = b64encode(bytes(value, 'utf-8')) - return encoded.decode('utf-8') - - -@dataclass -class SqsCeleryClient(SqsClient): - async def send_task( - self, - name: str, - args: Optional[Iterable] = None, - kwargs: Optional[dict] = None, - ) -> None: - celery_message = _build_celery_message(name, args or (), kwargs or {}) - await super().send_message(celery_message) - - def send_background_task( - self, - name: str, - args: Optional[Iterable] = None, - kwargs: Optional[dict] = None, - ) -> asyncio.Task: - celery_message = _build_celery_message(name, args or (), kwargs or {}) - return super().send_message_async(celery_message) diff --git a/agave/tools/sync/__init__.py b/agave/tools/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agave/tools/sync/sqs_celery_client.py b/agave/tools/sync/sqs_celery_client.py new file mode 100644 index 00000000..06cdf337 --- /dev/null +++ b/agave/tools/sync/sqs_celery_client.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from typing import Iterable, Optional + +from ..celery import build_celery_message +from .sqs_client import SqsClient + + +@dataclass +class SqsCeleryClient(SqsClient): + def send_task( + self, + name: str, + args: Optional[Iterable] = None, + kwargs: Optional[dict] = None, + ) -> None: + celery_message = build_celery_message(name, args or (), kwargs or {}) + self.send_message(celery_message) diff --git a/agave/tools/sync/sqs_client.py b/agave/tools/sync/sqs_client.py new file mode 100644 index 00000000..d86fa1a6 --- /dev/null +++ b/agave/tools/sync/sqs_client.py @@ -0,0 +1,34 @@ +import json +from dataclasses import dataclass, field +from typing import Optional, Union +from uuid import uuid4 + +try: + import boto3 + from types_boto3_sqs import SQSClient as Boto3SQSClient +except ImportError: + raise ImportError( + "You must install agave with [sync_aws_tools] option.\n" + "You can install it with: pip install agave[sync_aws_tools]" + ) + + +@dataclass +class SqsClient: + queue_url: str + region_name: str + _sqs: Boto3SQSClient = field(init=False) + + def __post_init__(self) -> None: + self._sqs = boto3.client('sqs', region_name=self.region_name) + + def send_message( + self, + data: Union[str, dict], + message_group_id: Optional[str] = None, + ) -> None: + self._sqs.send_message( + QueueUrl=self.queue_url, + MessageBody=data if isinstance(data, str) else json.dumps(data), + MessageGroupId=message_group_id or str(uuid4()), + ) diff --git a/agave/version.py b/agave/version.py index 1a72d32e..58d478ab 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '1.1.0' +__version__ = '1.2.0' diff --git a/requirements-test.txt b/requirements-test.txt index 751866ed..4e0126d5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,10 +7,10 @@ mypy==1.14.1 mongomock==4.3.0 mock==5.1.0 pytest-freezegun==0.4.2 -pytest-chalice==0.0.5 click==8.1.8 moto[server]==5.0.26 pytest-vcr==1.0.2 pytest-asyncio==0.18.* requests==2.32.3 httpx==0.28.1 +typing_extensions==4.12.2 diff --git a/requirements.txt b/requirements.txt index 7f65f8b4..d1938327 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ cuenca-validations==2.1.0 +boto3==1.35.74 +types-boto3[sqs]==1.35.74 chalice==1.31.3 mongoengine==0.29.1 fastapi==0.115.6 diff --git a/setup.py b/setup.py index e9ab8e19..4ffc4593 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,14 @@ 'aiobotocore>=2.0.0,<3.0.0', 'types-aiobotocore-sqs>=2.1.0,<3.0.0', ], + 'sync_aws_tools': [ + 'boto3>=1.34.106,<2.0.0', + 'types-boto3[sqs]>=1.34.106,<2.0.0', + ], + 'asyncio_aws_tools': [ + 'aiobotocore>=2.0.0,<3.0.0', + 'types-aiobotocore-sqs>=2.1.0,<3.0.0', + ], }, classifiers=[ 'Programming Language :: Python :: 3.9', diff --git a/tests/conftest.py b/tests/conftest.py index 5e204b34..1a78af30 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,21 @@ import datetime as dt import functools import json +import os +from functools import partial from typing import Callable, Generator +import aiobotocore +import boto3 import pytest +from _pytest.monkeypatch import MonkeyPatch +from aiobotocore.session import AioSession from chalice.test import Client as OriginalChaliceClient from fastapi.testclient import TestClient as FastAPIClient from mongoengine import Document +from typing_extensions import deprecated +from agave.tasks import sqs_tasks from examples.config import ( TEST_DEFAULT_PLATFORM_ID, TEST_DEFAULT_USER_ID, @@ -216,3 +224,98 @@ def chalice_client() -> Generator[ChaliceClient, None, None]: client = ChaliceClient(app) yield client + + +@deprecated('Use fixtures from cuenca-test-fixtures') +@pytest.fixture(scope='session') +def aws_credentials() -> None: + """Mocked AWS Credentials for moto.""" + os.environ['AWS_ACCESS_KEY_ID'] = 'testing' + os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing' + os.environ['AWS_SECURITY_TOKEN'] = 'testing' + os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' + boto3.setup_default_session() + + +@deprecated('Use fixtures from cuenca-test-fixtures') +@pytest.fixture(scope='session') +def aws_endpoint_urls( + aws_credentials, +) -> Generator[dict[str, str], None, None]: + from moto.server import ThreadedMotoServer + + server = ThreadedMotoServer(port=4000) + server.start() + + endpoints = dict( + sqs='http://127.0.0.1:4000/', + ) + yield endpoints + + server.stop() + + +@pytest.fixture(autouse=True) +def patch_tasks_count(monkeypatch: MonkeyPatch) -> None: + def one_loop(*_, **__): + # Para pruebas solo unos cuantos ciclos + for i in range(5): + yield i + + monkeypatch.setattr(sqs_tasks, 'count', one_loop) + + +@deprecated('Use fixtures from cuenca-test-fixtures') +@pytest.fixture(autouse=True) +def patch_aiobotocore_create_client( + aws_endpoint_urls, monkeypatch: MonkeyPatch +) -> None: + create_client = AioSession.create_client + + def mock_create_client(*args, **kwargs): + service_name = next(a for a in args if type(a) is str) + kwargs['endpoint_url'] = aws_endpoint_urls[service_name] + + return create_client(*args, **kwargs) + + monkeypatch.setattr(AioSession, 'create_client', mock_create_client) + + +@deprecated('Use fixtures from cuenca-test-fixtures') +@pytest.fixture(autouse=True) +def patch_boto3_create_client( + aws_endpoint_urls, monkeypatch: MonkeyPatch +) -> None: + create_client = boto3.Session.client + + def mock_client(*args, **kwargs): + service_name = next(a for a in args if type(a) is str) + if service_name in aws_endpoint_urls: + kwargs['endpoint_url'] = aws_endpoint_urls[service_name] + return create_client(*args, **kwargs) + + monkeypatch.setattr(boto3.Session, 'client', mock_client) + + +@deprecated('Use fixtures from cuenca-test-fixtures') +@pytest.fixture +async def sqs_client(): + session = aiobotocore.session.get_session() + async with session.create_client('sqs', 'us-east-1') as sqs: + await sqs.create_queue( + QueueName='core.fifo', + Attributes={ + 'FifoQueue': 'true', + 'ContentBasedDeduplication': 'true', + }, + ) + resp = await sqs.get_queue_url(QueueName='core.fifo') + sqs.send_message = partial(sqs.send_message, QueueUrl=resp['QueueUrl']) + sqs.receive_message = partial( + sqs.receive_message, + QueueUrl=resp['QueueUrl'], + AttributeNames=['ApproximateReceiveCount'], + ) + sqs.queue_url = resp['QueueUrl'] + yield sqs + await sqs.purge_queue(QueueUrl=resp['QueueUrl']) diff --git a/tests/tasks/conftest.py b/tests/tasks/conftest.py deleted file mode 100644 index 28bec091..00000000 --- a/tests/tasks/conftest.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -from functools import partial -from typing import Generator - -import aiobotocore -import boto3 -import pytest -from _pytest.monkeypatch import MonkeyPatch -from aiobotocore.session import AioSession - -from agave.tasks import sqs_tasks - - -@pytest.fixture(scope='session') -def aws_credentials() -> None: - """Mocked AWS Credentials for moto.""" - os.environ['AWS_ACCESS_KEY_ID'] = 'testing' - os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing' - os.environ['AWS_SECURITY_TOKEN'] = 'testing' - os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' - boto3.setup_default_session() - - -@pytest.fixture(scope='session') -def aws_endpoint_urls( - aws_credentials, -) -> Generator[dict[str, str], None, None]: - from moto.server import ThreadedMotoServer - - server = ThreadedMotoServer(port=4000) - server.start() - - endpoints = dict( - sqs='http://127.0.0.1:4000/', - ) - yield endpoints - - server.stop() - - -@pytest.fixture(autouse=True) -def patch_tasks_count(monkeypatch: MonkeyPatch) -> None: - def one_loop(*_, **__): - # Para pruebas solo unos cuantos ciclos - for i in range(5): - yield i - - monkeypatch.setattr(sqs_tasks, 'count', one_loop) - - -@pytest.fixture(autouse=True) -def patch_create_client(aws_endpoint_urls, monkeypatch: MonkeyPatch) -> None: - create_client = AioSession.create_client - - def mock_create_client(*args, **kwargs): - service_name = next(a for a in args if type(a) is str) - kwargs['endpoint_url'] = aws_endpoint_urls[service_name] - - return create_client(*args, **kwargs) - - monkeypatch.setattr(AioSession, 'create_client', mock_create_client) - - -@pytest.fixture -async def sqs_client(): - session = aiobotocore.session.get_session() - async with session.create_client('sqs', 'us-east-1') as sqs: - await sqs.create_queue( - QueueName='core.fifo', - Attributes={ - 'FifoQueue': 'true', - 'ContentBasedDeduplication': 'true', - }, - ) - resp = await sqs.get_queue_url(QueueName='core.fifo') - sqs.send_message = partial(sqs.send_message, QueueUrl=resp['QueueUrl']) - sqs.receive_message = partial( - sqs.receive_message, - QueueUrl=resp['QueueUrl'], - AttributeNames=['ApproximateReceiveCount'], - ) - sqs.queue_url = resp['QueueUrl'] - yield sqs - await sqs.purge_queue(QueueUrl=resp['QueueUrl']) diff --git a/tests/tasks/test_imports.py b/tests/tasks/test_imports.py index ac59dbb2..25b02b84 100644 --- a/tests/tasks/test_imports.py +++ b/tests/tasks/test_imports.py @@ -4,17 +4,38 @@ import pytest -def test_tasks_import_error(monkeypatch): - for module in ['types_aiobotocore_sqs', 'agave.tasks.sqs_client']: +@pytest.mark.parametrize( + 'module_path,required_types,required_option', + [ + ( + 'agave.tools.asyncio.sqs_client', + 'types_aiobotocore_sqs', + 'asyncio_aws_tools', + ), + ( + 'agave.tools.sync.sqs_client', + 'types_boto3_sqs', + 'sync_aws_tools', + ), + ], +) +def test_tasks_import_error( + monkeypatch: pytest.MonkeyPatch, + module_path: str, + required_types: str, + required_option: str, +) -> None: + # Clear modules from sys.modules + for module in [required_types, module_path]: if module in sys.modules: del sys.modules[module] - monkeypatch.setitem(sys.modules, 'types_aiobotocore_sqs', None) + monkeypatch.setitem(sys.modules, required_types, None) with pytest.raises(ImportError) as exc_info: - importlib.import_module('agave.tasks.sqs_client') + importlib.import_module(module_path) - assert "You must install agave with [fastapi, tasks] option" in str( + assert f"You must install agave with [{required_option}] option" in str( exc_info.value ) - assert "pip install agave[fastapi, tasks]" in str(exc_info.value) + assert f"pip install agave[{required_option}]" in str(exc_info.value) diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tools/asyncio/__init__.py b/tests/tools/asyncio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tasks/test_sqs_celery_client.py b/tests/tools/asyncio/test_sqs_celery_client.py similarity index 96% rename from tests/tasks/test_sqs_celery_client.py rename to tests/tools/asyncio/test_sqs_celery_client.py index bb0df9a0..4c7f000f 100644 --- a/tests/tasks/test_sqs_celery_client.py +++ b/tests/tools/asyncio/test_sqs_celery_client.py @@ -1,7 +1,7 @@ import base64 import json -from agave.tasks.sqs_celery_client import SqsCeleryClient +from agave.tools.asyncio.sqs_celery_client import SqsCeleryClient CORE_QUEUE_REGION = 'us-east-1' diff --git a/tests/tasks/test_sqs_client.py b/tests/tools/asyncio/test_sqs_client.py similarity index 94% rename from tests/tasks/test_sqs_client.py rename to tests/tools/asyncio/test_sqs_client.py index f5849dac..ac9b54bf 100644 --- a/tests/tasks/test_sqs_client.py +++ b/tests/tools/asyncio/test_sqs_client.py @@ -1,6 +1,6 @@ import json -from agave.tasks.sqs_client import SqsClient +from agave.tools.asyncio.sqs_client import SqsClient CORE_QUEUE_REGION = 'us-east-1' diff --git a/tests/tools/sync/__init__.py b/tests/tools/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tools/sync/test_sqs_celery_client.py b/tests/tools/sync/test_sqs_celery_client.py new file mode 100644 index 00000000..f9d02f2f --- /dev/null +++ b/tests/tools/sync/test_sqs_celery_client.py @@ -0,0 +1,27 @@ +import base64 +import json + +from agave.tools.sync.sqs_celery_client import SqsCeleryClient + +CORE_QUEUE_REGION = 'us-east-1' + + +async def test_send_task(sqs_client) -> None: + args = [10, 'foo'] + kwargs = dict(hola='mundo') + client = SqsCeleryClient(sqs_client.queue_url, CORE_QUEUE_REGION) + + client.send_task('some.task', args=args, kwargs=kwargs) + sqs_message = await sqs_client.receive_message() + encoded_body = sqs_message['Messages'][0]['Body'] + message = json.loads( + base64.b64decode(encoded_body.encode('utf-8')).decode() + ) + body_json = json.loads( + base64.b64decode(message['body'].encode('utf-8')).decode() + ) + + assert body_json[0] == args + assert body_json[1] == kwargs + assert message['headers']['lang'] == 'py' + assert message['headers']['task'] == 'some.task' diff --git a/tests/tools/sync/test_sqs_client.py b/tests/tools/sync/test_sqs_client.py new file mode 100644 index 00000000..15c3eac0 --- /dev/null +++ b/tests/tools/sync/test_sqs_client.py @@ -0,0 +1,22 @@ +import json + +from agave.tools.sync.sqs_client import SqsClient + +CORE_QUEUE_REGION = 'us-east-1' + + +async def test_send_message(sqs_client) -> None: + data1 = dict(hola='mundo') + data2 = dict(foo='bar') + + client = SqsClient(sqs_client.queue_url, CORE_QUEUE_REGION) + client.send_message(data1) + client.send_message(data2, message_group_id='12345') + + sqs_message = await sqs_client.receive_message() + message_body = json.loads(sqs_message['Messages'][0]['Body']) + assert message_body == data1 + + sqs_message = await sqs_client.receive_message() + message = json.loads(sqs_message['Messages'][0]['Body']) + assert message == data2