From 05ad44916d7be16e97516e47c1cb2c64971ee465 Mon Sep 17 00:00:00 2001
From: "C. Weaver"
Date: Fri, 19 Dec 2025 16:40:14 -0500
Subject: [PATCH] Re-implement looking up multiple messages with more
 flexibility.

Allow mixing results across multiple topics, with potentially differing
access levels. Allow reversing the order of results. Use more efficient
paging.
---
 archive/access_api.py      |  62 ++++-
 archive/database_api.py    | 303 ++++++++++++++++++---
 pyproject.toml             |   1 +
 tests/test_access_api.py   |  65 ++++-
 tests/test_database_api.py | 520 ++++++++++++++++++++++++++++++++++---
 tests/test_decision_api.py |   2 +
 tests/test_store_api.py    |   8 +-
 7 files changed, 867 insertions(+), 94 deletions(-)

diff --git a/archive/access_api.py b/archive/access_api.py
index 4b9ca90..8c89505 100644
--- a/archive/access_api.py
+++ b/archive/access_api.py
@@ -7,6 +7,7 @@
 """
 import bson
 import logging
+from typing import Optional, List
 from . import database_api
 from . import decision_api
 from . import store_api
@@ -48,10 +49,6 @@ async def get_metadata(self, uuid):
         "get metadata, if any associated with a message with the given ID"
         return await self.db.fetch(uuid)
 
-    async def get_metadata_for_time_range(self, topic: str, start_time: int, end_time: int, limit: int=10, offset: int=0):
-        "get metadata, if any associated with a message with the given ID"
-        return await self.db.get_message_records_for_time_range(topic, start_time, end_time, limit, offset)
-
     async def get_object_lazily(self, key):
         """
         get the raw object in the form of the S3 response which can be streamed
@@ -96,4 +93,59 @@ async def store_message(self, payload, metadata, public: bool=True, direct_uploa
         return (True,
                 {"archive_uuid": annotations["con_text_uuid"],
                  "is_client_uuid": annotations["con_is_client_uuid"]},
-                "")
\ No newline at end of file
+                "")
+
+    async def get_topics_with_public_messages(self):
+        """
+        Get the names of all topics on which at least one public message is archived
+        """
+        return await self.db.get_topics_with_public_messages()
+
+    async def get_message_records(self, *args, **kwargs):
+        """
+        Get the records for messages satisfying specified criteria, with results split/batched
+        into 'pages'.
+        Selecting messages by topic has some complexity: First, if neither topics_public nor
+        topics_full is specified, the default is to select public messages from any topic. If
+        either topic restriction argument is specified, no message is returned which is on a
+        topic not specified by one of the two arguments. Both arguments may be specified at the
+        same time to select a union of messages across multiple topics with different access
+        levels.
+
+        Args:
+            bookmark: If not None, this must be a 'bookmark' string returned by a previous call,
+                      to select another page of results.
+            page_size: The maximum number of results to return for this call; any further
+                       results can be retrieved as subsequent 'pages'.
+            ascending: Whether the results should be sorted in ascending timestamp order.
+            topics_public: If not None, only messages flagged as public appearing on these
+                           topics will be returned. Can be used at the same time as topics_full.
+            topics_full: If not None, any message appearing on one of these topics is a
+                         candidate to be returned.
+            start_time: The beginning of the message timestamp range to select.
+            end_time: The end of the message timestamp range to select. The range is half-open,
+                      so messages with this exact timestamp will be excluded.
+        Return: A tuple consisting of the results (a list of MessageRecords), a 'bookmark' which
+                can be used to fetch the next page of results or None if there are no subsequent
+                results, and a 'bookmark' for the previous page of results or None if there are
+                no prior results.
+        """
+        return await self.db.get_message_records(*args, **kwargs)
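
Usage sketch (illustrative only, not part of the patch): draining every page of
results via the bookmark protocol described above; `aa` is assumed to be a
connected Archive_access instance and the topic name is a placeholder.

    async def fetch_all_public(aa, topic):
        # The first call, made without a bookmark, returns the first page and
        # a bookmark for the next page (None once no further results exist).
        records, next_bm, _ = await aa.get_message_records(
            topics_public=[topic], page_size=100)
        while next_bm is not None:
            batch, next_bm, _ = await aa.get_message_records(
                topics_public=[topic], page_size=100, bookmark=next_bm)
            records.extend(batch)
        return records
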
+
+    async def count_message_records(self, *args, **kwargs):
+        """
+        Count the number of messages satisfying specified criteria, as they would be returned
+        by get_message_records.
+
+        Args:
+            topics_public: If not None, only messages flagged as public appearing on these
+                           topics will be counted. Can be used at the same time as topics_full.
+            topics_full: If not None, any message appearing on one of these topics is a
+                         candidate to be counted.
+            start_time: The beginning of the message timestamp range to select.
+            end_time: The end of the message timestamp range to select. The range is half-open,
+                      so messages with this exact timestamp will be excluded.
+        Return: An integer count of messages
+        """
+        return await self.db.count_message_records(*args, **kwargs)
\ No newline at end of file
diff --git a/archive/database_api.py b/archive/database_api.py
index db6c3d2..1a1a829 100644
--- a/archive/database_api.py
+++ b/archive/database_api.py
@@ -21,14 +21,17 @@ class can be configured to access the production or
 from botocore.exceptions import ClientError
 import sqlalchemy
 bindparam = sqlalchemy.bindparam
-from sqlalchemy.ext.asyncio import create_async_engine, create_async_pool_from_url
+from sqlalchemy.ext.asyncio import create_async_engine, create_async_pool_from_url, AsyncSession
+from sqlakeyset.asyncio import select_page
 import psycopg
 from psycopg_pool import AsyncConnectionPool
 import aioboto3
 import logging
-from collections import namedtuple
+from dataclasses import dataclass
 from . import utility_api
 import os
+from typing import Optional, List
+from uuid import UUID
 
 ##################################
 #   "databases"
@@ -88,10 +91,21 @@ def log(self):
     async def insert(self, metadata, annotations):
         raise NotImplementedError
 
-    MessageRecord = namedtuple("MessageRecord",
-                               ["id", "topic", "timestamp", "uuid", "size", "key",
-                                "bucket", "crc32", "is_client_uuid", "public",
-                                "direct_upload", "message_crc32"])
+    # the fields in this class must match the columns defined in SQL_db.connect
+    @dataclass
+    class MessageRecord:
+        id: int
+        topic: str
+        timestamp: int
+        uuid: UUID
+        size: int
+        key: str
+        bucket: str
+        crc32: int
+        is_client_uuid: bool
+        public: bool
+        direct_upload: bool
+        message_crc32: int
 
     async def fetch(self, uuid) -> MessageRecord:
         raise NotImplementedError
@@ -150,6 +164,70 @@ async def get_message_locations(self, ids):
         """
         raise NotImplementedError
 
+    async def get_topics_with_public_messages(self):
+        """
+        Get the names of all topics on which at least one public message is archived
+        """
+        raise NotImplementedError
+
+    async def get_message_records(self, bookmark: Optional[str]=None, page_size: int=1024,
+                                  ascending: bool=True,
+                                  topics_public: Optional[List[str]]=None,
+                                  topics_full: Optional[List[str]]=None,
+                                  start_time: Optional[int]=None,
+                                  end_time: Optional[int]=None):
+        """
+        Get the records for messages satisfying specified criteria, with results split/batched
+        into 'pages'.
+        Selecting messages by topic has some complexity: First, if neither topics_public nor
+        topics_full is specified, the default is to select public messages from any topic. If
+        either topic restriction argument is specified, no message is returned which is on a
+        topic not specified by one of the two arguments. Both arguments may be specified at the
+        same time to select a union of messages across multiple topics with different access
+        levels.
+
+        Args:
+            bookmark: If not None, this must be a 'bookmark' string returned by a previous call,
+                      to select another page of results.
+            page_size: The maximum number of results to return for this call; any further
+                       results can be retrieved as subsequent 'pages'.
+            ascending: Whether the results should be sorted in ascending timestamp order.
+            topics_public: If not None, only messages flagged as public appearing on these
+                           topics will be returned. Can be used at the same time as topics_full.
+            topics_full: If not None, any message appearing on one of these topics is a
+                         candidate to be returned.
+            start_time: The beginning of the message timestamp range to select.
+            end_time: The end of the message timestamp range to select. The range is half-open,
+                      so messages with this exact timestamp will be excluded.
+        Return: A tuple consisting of the results (a list of MessageRecords), a 'bookmark' which
+                can be used to fetch the next page of results or None if there are no subsequent
+                results, and a 'bookmark' for the previous page of results or None if there are
+                no prior results.
+        """
+        raise NotImplementedError
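
To make the topic-selection semantics above concrete, a sketch with placeholder
topic names (`db` is assumed to be any connected Base_db implementation):

    async def topic_selection_example(db):
        # Union of access levels: public messages on "survey.alerts" plus
        # every message, public or private, on "team.internal". The time
        # range is half-open, so timestamp 2000 itself is excluded.
        mixed, nxt, prv = await db.get_message_records(
            topics_public=["survey.alerts"],
            topics_full=["team.internal"],
            start_time=1000, end_time=2000)
        # With neither topic argument, all public messages on all topics match.
        all_public, _, _ = await db.get_message_records()
        return mixed, all_public
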
+
+    async def count_message_records(self,
+                                    topics_public: Optional[List[str]]=None,
+                                    topics_full: Optional[List[str]]=None,
+                                    start_time: Optional[int]=None,
+                                    end_time: Optional[int]=None):
+        """
+        Count the number of messages satisfying specified criteria, as they would be returned
+        by get_message_records.
+
+        Args:
+            topics_public: If not None, only messages flagged as public appearing on these
+                           topics will be counted. Can be used at the same time as topics_full.
+            topics_full: If not None, any message appearing on one of these topics is a
+                         candidate to be counted.
+            start_time: The beginning of the message timestamp range to select.
+            end_time: The end of the message timestamp range to select. The range is half-open,
+                      so messages with this exact timestamp will be excluded.
+        Return: An integer count of messages
+        """
+        raise NotImplementedError
+
 
 class Mock_db(Base_db):
     """
@@ -243,20 +321,110 @@ async def get_message_locations(self, ids):
             # TODO do what if id is not known?
         return results
 
-    async def get_message_records_for_time_range(self, topic: str, start_time: int, end_time: int, limit: int=10, offset: int=0):
+    async def get_topics_with_public_messages(self):
         assert self.connected
-        # This is not at all efficient, but should not be used for serious amounts of data
-        results = []
-        for record in sorted(self.data.values(), key=lambda r: r.timestamp):
-            if record.topic == topic and \
-               record.timestamp >= start_time and record.timestamp < end_time:
-                if offset > 0:
-                    offset -= 1
-                else:
-                    results.append(record)
-                    if limit!=0 and len(results) == limit:
-                        return results
-        return results
+        topics = set()
+        for record in self.data.values():
+            if record.public:
+                topics.add(record.topic)
+        return list(topics)
+
+    def _generate_query_filter(self,
+                               topics_public: Optional[List[str]]=None,
+                               topics_full: Optional[List[str]]=None,
+                               start_time: Optional[int]=None,
+                               end_time: Optional[int]=None):
+        def filter(record):
+            if topics_public is not None or topics_full is not None:
+                if topics_public is None:
+                    if record.topic not in topics_full:
+                        return False
+                elif topics_full is None:
+                    if not record.public or record.topic not in topics_public:
+                        return False
+                else:
+                    if record.topic not in topics_full and (not record.public or record.topic not in topics_public):
+                        return False
+            elif not record.public:
+                return False
+            if start_time is not None and record.timestamp < start_time:
+                return False
+            if end_time is not None and record.timestamp >= end_time:
+                return False
+            return True
+        return filter
+
+    async def get_message_records(self, bookmark: Optional[str]=None, page_size: int=1024,
+                                  ascending: bool=True,
+                                  topics_public: Optional[List[str]]=None,
+                                  topics_full: Optional[List[str]]=None,
+                                  start_time: Optional[int]=None,
+                                  end_time: Optional[int]=None):
+        assert self.connected
+        filter = self._generate_query_filter(topics_public, topics_full, start_time, end_time)
+        # This is not at all efficient, but should not be used for serious amounts of data
+        matching = [r for r in self.data.values() if filter(r)]
+        if len(matching) == 0:
+            return [], None, None
+        matching.sort(key=lambda r: (r.timestamp, r.id), reverse=not ascending)
+        # apply pagination
+        def make_bookmark(dir, ts, i):
+            # This is not quite the format used by sqlakeyset, but we don't need as much
+            # generality, nor is interoperability needed, and this is simpler to parse.
+            return f"{dir}{ts}~{i}"
+        if bookmark is not None:
+            assert len(bookmark) >= 4
+            direction = bookmark[0]
+            assert direction in ('<', '>')
+            ts, i = tuple(int(s) for s in bookmark[1:].split("~"))
+            if direction == '>':
+                # advance past every entry at or before the bookmark position (ts, i)
+                startIdx = 0
+                while startIdx < len(matching) and \
+                      (matching[startIdx].timestamp < ts or
+                       (matching[startIdx].timestamp == ts and matching[startIdx].id <= i)):
+                    startIdx += 1
+                if startIdx >= len(matching):
+                    return [], None, make_bookmark('<', matching[-1].timestamp, matching[-1].id+1)
+                endIdx = startIdx + page_size
+                if endIdx > len(matching):
+                    endIdx = len(matching)
+            else: # '<'
+                # count the entries strictly before the bookmark position (ts, i)
+                endIdx = 0
+                while endIdx < len(matching) and \
+                      (matching[endIdx].timestamp < ts or
+                       (matching[endIdx].timestamp == ts and matching[endIdx].id < i)):
+                    endIdx += 1
+                if endIdx == 0:
+                    return [], make_bookmark('>', matching[0].timestamp, matching[0].id-1), None
+                startIdx = endIdx - page_size
+                if startIdx < 0:
+                    startIdx = 0
+        else:
+            startIdx = 0
+            endIdx = startIdx + page_size
+            if endIdx > len(matching):
+                endIdx = len(matching)
+        if endIdx < len(matching):
+            next = make_bookmark('>', matching[endIdx-1].timestamp, matching[endIdx-1].id)
+        else:
+            next = None
+        if startIdx > 0:
+            prev = make_bookmark('<', matching[startIdx].timestamp, matching[startIdx].id)
+        else:
+            prev = None
+        return matching[startIdx:endIdx], next, prev
+
+    async def count_message_records(self,
+                                    topics_public: Optional[List[str]]=None,
+                                    topics_full: Optional[List[str]]=None,
+                                    start_time: Optional[int]=None,
+                                    end_time: Optional[int]=None):
+        assert self.connected
+        filter = self._generate_query_filter(topics_public, topics_full, start_time, end_time)
+        # This is not at all efficient, but should not be used for serious amounts of data
+        count = 0
+        for record in self.data.values():
+            if filter(record):
+                count += 1
+        return count
 
 class SQL_db(Base_db):
     def __init__(self, config):
@@ -282,6 +450,7 @@ async def connect(self):
         )
         self.db_meta = sqlalchemy.MetaData()
         Column = sqlalchemy.Column
+        # the columns defined in here must match the fields in Base_db.MessageRecord
         self.table = sqlalchemy.Table(
             self.table_name,
             self.db_meta,
@@ -304,7 +473,14 @@ async def close(self):
 
     async def make_schema(self):
         "Declare tables"
-        ts_idx = sqlalchemy.Index(f"{self.table_name}_timestamp_idx", self.table.c.timestamp)
+        # There's generally little reason to search for specific timestamps by value; range
+        # queries are more natural. In addition, to paginate the results, we need uniqueness,
+        # so we want to use the (internal, DB) ID as well. Also, filtering based on source
+        # topic and the public flag is frequently needed, and this is much more efficient if
+        # those columns are included in the index (but not used for ordering).
+        ts_idx = sqlalchemy.Index(f"{self.table_name}_timestamp_id_idx",
+                                  self.table.c.timestamp, self.table.c.id,
+                                  postgresql_include=["topic", "public"])
         topic_idx = sqlalchemy.Index(f"{self.table_name}_topic_idx", self.table.c.topic)
         key_idx = sqlalchemy.Index(f"{self.table_name}_key_idx", self.table.c.key)
         uuid_idx = sqlalchemy.Index(f"{self.table_name}_uuid_idx", self.table.c.uuid)
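
For orientation (an illustrative sketch, not code from the patch): the composite
(timestamp, id) index above is shaped to serve keyset predicates of roughly the
following form, which sqlakeyset issues when resuming from a bookmark row
(last_ts, last_id); the INCLUDEd topic/public columns let the common filters be
evaluated from the index alone.

    import sqlalchemy

    def next_page_filter(table, last_ts, last_id):
        # Select rows strictly after (last_ts, last_id) in (timestamp, id) order.
        return sqlalchemy.or_(
            table.c.timestamp > last_ts,
            sqlalchemy.and_(table.c.timestamp == last_ts,
                            table.c.id > last_id))
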
""" async with self.engine.connect() as conn: - result = await conn.execute(sqlalchemy.select(self.table.c.id)\ - .select_from(self.table)\ + result = await conn.execute(sqlalchemy.select(self.table.c.id) + .select_from(self.table) .where(self.table.c.uuid == bindparam("uuid")), {"uuid":uuid}) return result.scalar() @@ -463,25 +639,82 @@ async def get_message_locations(self, ids): be found in the data store. """ async with self.engine.connect() as conn: - result = await conn.execute(sqlalchemy.select(self.table.c.bucket, self.table.c.key)\ - .select_from(self.table)\ + result = await conn.execute(sqlalchemy.select(self.table.c.bucket, self.table.c.key) + .select_from(self.table) .where(self.table.c.id.in_(bindparam("ids"))), {"ids":ids}) return result.all() - async def get_message_records_for_time_range(self, topic: str, start_time: int, end_time: int, limit: int=10, offset: int=0): + async def get_topics_with_public_messages(self): async with self.engine.connect() as conn: - result = await conn.execute(self.table.select()\ - .where((self.table.c.topic == bindparam("topic")) & - (self.table.c.timestamp >= bindparam("start_time")) & - (self.table.c.timestamp < bindparam("end_time"))) - .limit(bindparam("limit")).offset(bindparam("offset")), - {"topic":topic, - "start_time":start_time, - "end_time":end_time, - "limit":limit, - "offset":offset}) - return [Base_db.MessageRecord(**record._mapping) for record in result.all()] + result = await conn.execute(sqlalchemy.select(sqlalchemy.distinct(self.table.c.topic)) + .where(self.table.c.public == sqlalchemy.true())) + return [r[0] for r in result.all()] + + def _generate_query_restrictions(self, q, + topics_public: Optional[List[str]]=None, + topics_full: Optional[List[str]]=None, + start_time: Optional[int]=None, + end_time: Optional[int]=None): + if topics_public is not None or topics_full is not None: + if topics_public is not None and len(topics_public) > 0: + if len(topics_public) == 1: + pub_clause = sqlalchemy.and_(self.table.c.public == sqlalchemy.true(), + self.table.c.topic == topics_public[0]) + else: + pub_clause = sqlalchemy.and_(self.table.c.public == sqlalchemy.true(), + self.table.c.topic.in_(topics_public)) + else: + pub_clause = None + if topics_full is not None and len(topics_full) > 0: + if len(topics_full) == 1: + full_clause = self.table.c.topic == topics_full[0] + else: + full_clause = self.table.c.topic.in_(topics_full) + else: + full_clause = None + if pub_clause is not None and full_clause is not None: + q = q.where(sqlalchemy.or_(pub_clause, full_clause)) + elif pub_clause is not None: + q = q.where(pub_clause) + elif full_clause is not None: + q = q.where(full_clause) + else: # if no topics specified, select public messages across all topics + q = q.where(self.table.c.public == sqlalchemy.true()) + if start_time is not None: + q = q.where(self.table.c.timestamp >= start_time) + if end_time is not None: + q = q.where(self.table.c.timestamp < end_time) + return q + + async def get_message_records(self, bookmark: Optional[str]=None, page_size: int=1024, + ascending: bool=True, + topics_public: Optional[List[str]]=None, + topics_full: Optional[List[str]]=None, + start_time: Optional[int]=None, + end_time: Optional[int]=None,): + q = self.table.select() + if ascending: + q = q.order_by(self.table.c.timestamp, self.table.c.id) + else: + q = q.order_by(sqlalchemy.desc(self.table.c.timestamp), + sqlalchemy.desc(self.table.c.id)) + q = self._generate_query_restrictions(q, topics_public, topics_full, start_time, end_time) + async 
+        async with AsyncSession(self.engine) as session:
+            page = await select_page(session, q, per_page=page_size, page=bookmark)
+            return ([Base_db.MessageRecord(*r) for r in page],
+                    page.paging.bookmark_next if page.paging.has_next else None,
+                    page.paging.bookmark_previous if page.paging.has_previous else None,)
+
+    async def count_message_records(self, topics_public: Optional[List[str]]=None,
+                                    topics_full: Optional[List[str]]=None,
+                                    start_time: Optional[int]=None,
+                                    end_time: Optional[int]=None):
+        q = sqlalchemy.select(sqlalchemy.func.count()).select_from(self.table)
+        q = self._generate_query_restrictions(q, topics_public, topics_full, start_time, end_time)
+        async with AsyncSession(self.engine) as session:
+            count = (await session.execute(q)).scalar()
+            return count
 
 
 class AWS_db(SQL_db):
diff --git a/pyproject.toml b/pyproject.toml
index c7fb3dc..976e7e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "psycopg[binary]",
     "psycopg_pool",
     "sqlalchemy",
+    "sqlakeyset",
    "toml",
     "uvicorn[standard]",
 ]
diff --git a/tests/test_access_api.py b/tests/test_access_api.py
index 10874de..e4bb7dc 100644
--- a/tests/test_access_api.py
+++ b/tests/test_access_api.py
@@ -30,7 +30,8 @@ def get_mock_config():
     return {"db_type": "mock",
             "store_type": "mock",
             "store_primary_bucket": "archive",
-            "store_backup_bucket": "backup"}
+            "store_backup_bucket": "backup",
+            "store_region_name": "eu-north-3"}
 
 @pytest.mark.asyncio
 async def test_archive_access_startup_shutdown():
@@ -88,7 +89,7 @@ async def test_archive_access_store_message_no_uuid():
     r = await aa.store_message(message, metadata)
     assert r[0], "Insertion should succeed"
 
-    dr = await aa.db.get_message_records_for_time_range("t1", start_time=355, end_time=357)
+    dr, _, _ = await aa.get_message_records(topics_full=["t1"])
     assert len(dr) == 1, "Message should be recorded in database"
     assert dr[0].topic == metadata.topic
     assert dr[0].timestamp == metadata.timestamp
@@ -143,7 +144,24 @@ async def test_archive_access_get_metadata():
     assert r.uuid == str(u)
 
 @pytest.mark.asyncio
-async def test_archive_access_get_metadata_for_time_range():
+async def test_archive_access_get_topics_with_public_messages():
+    messages = []
+    for i in range(0,10):
+        ms = b"datadatadata"
+        md = Metadata(topic=f"t{i}", partition=0, offset=i, timestamp=i, key="", headers=[("_id",uuid.uuid4().bytes)], _raw=None)
+        messages.append((ms,md))
+
+    aa = access_api.Archive_access(get_mock_config())
+    await aa.connect()
+
+    for m in messages:
+        await aa.store_message(m[0],m[1], public=(m[1].offset % 2 == 0))
+
+    pub_top = await aa.get_topics_with_public_messages()
+    assert len(pub_top) == 5
+
+@pytest.mark.asyncio
+async def test_archive_access_get_message_records():
     messages = []
     for i in range(0,10):
         ms = b"datadatadata"
@@ -156,27 +174,58 @@
     for m in messages:
         await aa.store_message(m[0],m[1])
 
-    r = await aa.get_metadata_for_time_range("t1", start_time=4, end_time=7)
+    r, _, _ = await aa.get_message_records(topics_full=["t1"], start_time=4, end_time=7)
     assert len(r) == 3
     assert r[0].timestamp == 4
     assert r[1].timestamp == 5
     assert r[2].timestamp == 6
 
-    r = await aa.get_metadata_for_time_range("t1", start_time=3, end_time=7, limit=2)
+    r, n, p = await aa.get_message_records(topics_full=["t1"], start_time=3, end_time=7, page_size=2)
     assert len(r) == 2
     assert r[0].timestamp == 3
     assert r[1].timestamp == 4
+    assert n is not None
+    assert p is None
 
-    r = await aa.get_metadata_for_time_range("t1", start_time=3, end_time=7, limit=2, offset=2)
+    r, n, p = await aa.get_message_records(topics_full=["t1"], start_time=3, end_time=7, page_size=2, bookmark=n)
     assert len(r) == 2
     assert r[0].timestamp == 5
     assert r[1].timestamp == 6
+    assert n is None
+    assert p is not None
 
-    r = await aa.get_metadata_for_time_range("t1", start_time=12, end_time=14)
+    r, n, p = await aa.get_message_records(topics_full=["t1"], start_time=12, end_time=14)
     assert len(r) == 0
+    assert n is None
+    assert p is None
 
-    r = await aa.get_metadata_for_time_range("t2", start_time=0, end_time=5)
+    r, n, p = await aa.get_message_records(topics_full=["t2"], start_time=0, end_time=5)
     assert len(r) == 0
+    assert n is None
+    assert p is None
+
+@pytest.mark.asyncio
+async def test_archive_access_count_message_records():
+    messages = []
+    for i in range(0,10):
+        ms = b"datadatadata"
+        md = Metadata(topic="t1", partition=0, offset=i, timestamp=i, key="", headers=[("_id",uuid.uuid4().bytes)], _raw=None)
+        messages.append((ms,md))
+
+    aa = access_api.Archive_access(get_mock_config())
+    await aa.connect()
+
+    for m in messages:
+        await aa.store_message(m[0],m[1])
+
+    c = await aa.count_message_records(topics_full=["t1"], start_time=4, end_time=7)
+    assert c == 3
+
+    c = await aa.count_message_records(topics_full=["t1"], start_time=12, end_time=14)
+    assert c == 0
+
+    c = await aa.count_message_records(topics_full=["t2"], start_time=0, end_time=5)
+    assert c == 0
 
 @pytest.mark.asyncio
 async def test_archive_access_get_object_lazily():
diff --git a/tests/test_database_api.py b/tests/test_database_api.py
index 003c85e..7b62096 100644
--- a/tests/test_database_api.py
+++ b/tests/test_database_api.py
@@ -82,10 +82,20 @@ def test_add_parser_options(tmpdir):
 async def get_mock_store():
     st = store_api.StoreFactory({"store_type": "mock",
                                  "store_primary_bucket": "archive",
-                                 "store_backup_bucket": "backup"})
+                                 "store_backup_bucket": "backup",
+                                 "store_region_name": "eu-north-3"})
     await st.connect()
     return st
 
+def generate_message(payload: bytes, topic: str, timestamp: int, public: bool=True, headers=[]):
+    metadata = Metadata(topic=topic, partition=0, offset=0, timestamp=timestamp, key="", headers=headers, _raw=None)
+    annotations = decision_api.get_annotations(payload, metadata.headers, public=public)
+    annotations['size'] = len(payload)
+    annotations['key'] = annotations["con_text_uuid"]
+    annotations['bucket'] = "bucket"
+    annotations['crc32'] = 0
+    return payload, metadata, annotations
+
 @pytest.mark.asyncio
 async def test_SQL_db_startup(tmpdir):
     with temp_postgres(tmpdir) as db_conf:
@@ -378,52 +388,467 @@ async def test_SQL_db_get_message_locations(tmpdir):
     assert result[0][1] == annotations['key'], "Object key should be correct"
 
 @pytest.mark.asyncio
-async def test_SQL_db_get_message_records_for_time_range(tmpdir):
-    messages = []
-    st = await get_mock_store()
-    for i in range(0,10):
-        ms = b"datadatadata"
-        md = Metadata(topic="t1", partition=0, offset=i, timestamp=i, key="", headers=[("_id",uuid.uuid4().bytes)], _raw=None)
-        an = decision_api.get_annotations(ms, md.headers)
-        await st.store(ms, md, an)
-        messages.append((ms,md,an))
-
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_get_topics_with_public_messages(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        p1, m1, a1 = generate_message(b"alert", topic="t1", timestamp=47, public=True)
+        await db.insert(m1, a1)
+        p2, m2, a2 = generate_message(b"another alert", topic="t1", timestamp=49, public=True)
+        await db.insert(m2, a2)
+
+        p3, m3, a3 = generate_message(b"secret", topic="t2", timestamp=22, public=False)
+        await db.insert(m3, a3)
+        p4, m4, a4 = generate_message(b"shared", topic="t2", timestamp=81, public=True)
+        await db.insert(m4, a4)
+
+        p5, m5, a5 = generate_message(b"private", topic="t3", timestamp=35, public=False)
+        await db.insert(m5, a5)
+        p6, m6, a6 = generate_message(b"proprietary", topic="t3", timestamp=48, public=False)
+        await db.insert(m6, a6)
+
+        pub_tops = await db.get_topics_with_public_messages()
+        assert "t1" in pub_tops
+        assert "t2" in pub_tops
+        assert "t3" not in pub_tops
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_db_get_message_records_public(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        p1, m1, a1 = generate_message(b"alert", topic="t1", timestamp=47, public=True)
+        await db.insert(m1, a1)
+        p2, m2, a2 = generate_message(b"another alert", topic="t1", timestamp=49, public=True)
+        await db.insert(m2, a2)
+
+        p3, m3, a3 = generate_message(b"secret", topic="t2", timestamp=22, public=False)
+        await db.insert(m3, a3)
+        p4, m4, a4 = generate_message(b"shared", topic="t2", timestamp=81, public=True)
+        await db.insert(m4, a4)
+
+        p5, m5, a5 = generate_message(b"private", topic="t3", timestamp=35, public=False)
+        await db.insert(m5, a5)
+        p6, m6, a6 = generate_message(b"proprietary", topic="t3", timestamp=48, public=False)
+        await db.insert(m6, a6)
+
+        # with no topics explicitly selected, all public messages across all topics should be
+        # selected
+        results, n, p = await db.get_message_records(ascending=True)
+        assert len(results) == 3
+        # exploit time ordering to check that we got the right messages by their timestamps
+        assert results[0].timestamp == 47
+        assert results[1].timestamp == 49
+        assert results[2].timestamp == 81
+        assert all([r.public for r in results])
+
+        # repeat, selecting only messages on one topic
+        results, n, p = await db.get_message_records(topics_public=["t2"], ascending=True)
+        # only the public message on the selected topics should be returned
+        assert len(results) == 1
+        assert results[0].timestamp == 81
+        assert results[0].public
+
+        # explicitly select multiple topics
+        results, n, p = await db.get_message_records(topics_public=["t1", "t2"], ascending=True)
+        assert len(results) == 3
+        assert results[0].timestamp == 47
+        assert results[1].timestamp == 49
+        assert results[2].timestamp == 81
+        assert all([r.public for r in results])
+
+        # select a topic with no matching messages
+        results, n, p = await db.get_message_records(topics_public=["t3"], ascending=True)
+        assert len(results) == 0
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_get_message_records_full(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        p1, m1, a1 = generate_message(b"alert", topic="t1", timestamp=47, public=True)
+        await db.insert(m1, a1)
+        p2, m2, a2 = generate_message(b"another alert", topic="t1", timestamp=49, public=True)
+        await db.insert(m2, a2)
+
+        p3, m3, a3 = generate_message(b"secret", topic="t2", timestamp=22, public=False)
+        await db.insert(m3, a3)
+        p4, m4, a4 = generate_message(b"shared", topic="t2", timestamp=81, public=True)
+        await db.insert(m4, a4)
+
+        p5, m5, a5 = generate_message(b"private", topic="t3", timestamp=35, public=False)
+        await db.insert(m5, a5)
+        p6, m6, a6 = generate_message(b"proprietary", topic="t3", timestamp=48, public=False)
+        await db.insert(m6, a6)
+
+        # reading from a topic with full access should find both public and private messages
+        results, n, p = await db.get_message_records(topics_full=["t2"], ascending=True)
+        assert len(results) == 2
+        # exploit time ordering to check that we got the right messages by their timestamps
+        assert results[0].timestamp == 22
+        assert not results[0].public
+        assert results[1].timestamp == 81
+        assert results[1].public
+
+        # reading multiple topics should interleave messages in time order
+        results, n, p = await db.get_message_records(topics_full=["t2", "t3"], ascending=True)
+        assert len(results) == 4
+        assert results[0].timestamp == 22
+        assert not results[0].public
+        assert results[0].topic == "t2"
+        assert results[1].timestamp == 35
+        assert not results[1].public
+        assert results[1].topic == "t3"
+        assert results[2].timestamp == 48
+        assert not results[2].public
+        assert results[2].topic == "t3"
+        assert results[3].timestamp == 81
+        assert results[3].public
+        assert results[3].topic == "t2"
+
+        # should be able to mix public and full access to different topics
+        results, n, p = await db.get_message_records(topics_public=["t2"], topics_full=["t3"], ascending=True)
+        assert len(results) == 3
+        assert results[0].timestamp == 35
+        assert not results[0].public
+        assert results[0].topic == "t3"
+        assert results[1].timestamp == 48
+        assert not results[1].public
+        assert results[1].topic == "t3"
+        assert results[2].timestamp == 81
+        assert results[2].public
+        assert results[2].topic == "t2"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_get_message_records_descending(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        p1, m1, a1 = generate_message(b"alert", topic="t1", timestamp=47, public=True)
+        await db.insert(m1, a1)
+        p2, m2, a2 = generate_message(b"another alert", topic="t1", timestamp=49, public=True)
+        await db.insert(m2, a2)
+
+        p3, m3, a3 = generate_message(b"secret", topic="t2", timestamp=22, public=False)
+        await db.insert(m3, a3)
+        p4, m4, a4 = generate_message(b"shared", topic="t2", timestamp=81, public=True)
+        await db.insert(m4, a4)
+
+        p5, m5, a5 = generate_message(b"private", topic="t3", timestamp=35, public=False)
+        await db.insert(m5, a5)
+        p6, m6, a6 = generate_message(b"proprietary", topic="t3", timestamp=48, public=False)
+        await db.insert(m6, a6)
+
+        # get all public messages, in descending order
+        results, n, p = await db.get_message_records(ascending=False)
+        assert len(results) == 3
+        assert results[0].timestamp == 81
+        assert results[1].timestamp == 49
+        assert results[2].timestamp == 47
+        assert all([r.public for r in results])
+
+        results, n, p = await db.get_message_records(topics_public=["t1"], ascending=False)
+        assert len(results) == 2
+        assert results[0].timestamp == 49
+        assert results[1].timestamp == 47
+        assert all([r.public for r in results])
+
+        results, n, p = await db.get_message_records(topics_full=["t1", "t3"], ascending=False)
+        assert len(results) == 4
+        assert results[0].timestamp == 49
+        assert results[0].public
+        assert results[0].topic == "t1"
+        assert results[1].timestamp == 48
+        assert not results[1].public
+        assert results[1].topic == "t3"
+        assert results[2].timestamp == 47
+        assert results[2].public
+        assert results[2].topic == "t1"
+        assert results[3].timestamp == 35
+        assert not results[3].public
+        assert results[3].topic == "t3"
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_get_message_records_paging(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        for i in range(0,4):
+            p, m, a = generate_message(b"data", topic="t1", timestamp=i, public=True)
+            await db.insert(m, a)
+        for i in range(4,8):
+            p, m, a = generate_message(b"data", topic="t2", timestamp=i, public=False)
+            await db.insert(m, a)
+        for i in range(8,12):
+            p, m, a = generate_message(b"data", topic="t1", timestamp=i, public=True)
+            await db.insert(m, a)
+
+        # get the first page of public messages
+        results, n, p = await db.get_message_records(ascending=True, page_size=4)
+        assert len(results) == 4
+        assert n is not None
+        assert p is None
+        for r in results:
+            assert r.topic == "t1"
+            assert r.public
+        assert [r.timestamp for r in results] == [0, 1, 2, 3]
+
+        # get the next page of public messages
+        results, n, p = await db.get_message_records(ascending=True, page_size=4, bookmark=n)
+        assert len(results) == 4
+        assert n is None
+        assert p is not None
+        for r in results:
+            assert r.topic == "t1"
+            assert r.public
+        assert [r.timestamp for r in results] == [8, 9, 10, 11]
+
+        # it should be possible to go back to the previous page
+        # test with an increased page size to check that nothing goes wrong trying to get data
+        # from before the beginning
+        results, n, p = await db.get_message_records(ascending=True, page_size=5, bookmark=p)
+        assert len(results) == 4
+        assert n is not None
+        assert p is None
+        for r in results:
+            assert r.topic == "t1"
+            assert r.public
+        assert [r.timestamp for r in results] == [0, 1, 2, 3]
+
+        # get the first page of all messages
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True, page_size=5)
+        assert len(results) == 5
+        assert n is not None
+        assert p is None
+        assert [r.timestamp for r in results] == [0, 1, 2, 3, 4]
+
+        # get the second page of all messages
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True, page_size=5, bookmark=n)
+        assert len(results) == 5
+        assert n is not None
+        assert p is not None
+        assert [r.timestamp for r in results] == [5, 6, 7, 8, 9]
+
+        # get the third page of all messages
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True, page_size=5, bookmark=n)
+        assert len(results) == 2
+        assert n is None
+        assert p is not None
+        assert [r.timestamp for r in results] == [10, 11]
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_get_message_records_time_range(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        for i in range(0,4):
+            p, m, a = generate_message(b"data", topic="t1", timestamp=i, public=True)
+            await db.insert(m, a)
+        for i in range(4,8):
+            p, m, a = generate_message(b"data", topic="t2", timestamp=i, public=False)
+            await db.insert(m, a)
+        for i in range(8,12):
+            p, m, a = generate_message(b"data", topic="t1", timestamp=i, public=True)
+            await db.insert(m, a)
+
+        # all public messages
+        results, n, p = await db.get_message_records(ascending=True)
+        assert len(results) == 8
+        assert [r.timestamp for r in results] == [0, 1, 2, 3, 8, 9, 10, 11]
+
+        # all public messages after 2
+        results, n, p = await db.get_message_records(ascending=True, start_time=2)
+        assert len(results) == 6
+        assert [r.timestamp for r in results] == [2, 3, 8, 9, 10, 11]
+
+        # all public messages before 10
+        results, n, p = await db.get_message_records(ascending=True, end_time=10)
+        assert len(results) == 6
+        assert [r.timestamp for r in results] == [0, 1, 2, 3, 8, 9]
+
+        # all public messages between 2 and 10
+        results, n, p = await db.get_message_records(ascending=True, start_time=2, end_time=10)
+        assert len(results) == 4
+        assert [r.timestamp for r in results] == [2, 3, 8, 9]
+
+        # all messages
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True)
+        assert len(results) == 12
+        assert [r.timestamp for r in results] == list(range(0, 12))
+
+        # all messages after 2
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True, start_time=2)
+        assert len(results) == 10
+        assert [r.timestamp for r in results] == list(range(2, 12))
+
+        # all messages before 10
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True, end_time=10)
+        assert len(results) == 10
+        assert [r.timestamp for r in results] == list(range(0, 10))
+
+        # all messages between 2 and 10
+        results, n, p = await db.get_message_records(topics_full=["t1", "t2"], ascending=True, start_time=2, end_time=10)
+        assert len(results) == 8
+        assert [r.timestamp for r in results] == list(range(2, 10))
+
+        # make sure paging works sensibly with time limits
+        results, n, p = await db.get_message_records(ascending=True, start_time=2, end_time=10, page_size=2)
+        assert len(results) == 2
+        assert [r.timestamp for r in results] == [2, 3]
+        assert n is not None
+        assert p is None
+        results, n, p = await db.get_message_records(ascending=True, start_time=2, end_time=10, page_size=2, bookmark=n)
+        assert len(results) == 2
+        assert [r.timestamp for r in results] == [8, 9]
+        assert n is None
+        assert p is not None
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_count_message_records_public(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        p1, m1, a1 = generate_message(b"alert", topic="t1", timestamp=47, public=True)
+        await db.insert(m1, a1)
+        p2, m2, a2 = generate_message(b"another alert", topic="t1", timestamp=49, public=True)
+        await db.insert(m2, a2)
+
+        p3, m3, a3 = generate_message(b"secret", topic="t2", timestamp=22, public=False)
+        await db.insert(m3, a3)
+        p4, m4, a4 = generate_message(b"shared", topic="t2", timestamp=81, public=True)
+        await db.insert(m4, a4)
+
+        p5, m5, a5 = generate_message(b"private", topic="t3", timestamp=35, public=False)
+        await db.insert(m5, a5)
+        p6, m6, a6 = generate_message(b"proprietary", topic="t3", timestamp=48, public=False)
+        await db.insert(m6, a6)
+
+        # with no topics explicitly selected, all public messages across all topics should be
+        # counted
+        c = await db.count_message_records()
+        assert c == 3
+
+        # repeat, selecting only messages on one topic
+        c = await db.count_message_records(topics_public=["t2"])
+        # only the public message on the selected topics should be counted
+        assert c == 1
+
+        # explicitly select multiple topics
+        c = await db.count_message_records(topics_public=["t1", "t2"])
+        assert c == 3
+
+        # explicitly select a topic with no matching messages
+        c = await db.count_message_records(topics_public=["t3"])
+        assert c == 0
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_count_message_records_full(db_class, tmpdir):
+    with temp_postgres(tmpdir) as db_conf:
+        db = db_class(db_conf)
+        await db.connect()
+        await db.make_schema()
+
+        p1, m1, a1 = generate_message(b"alert", topic="t1", timestamp=47, public=True)
+        await db.insert(m1, a1)
+        p2, m2, a2 = generate_message(b"another alert", topic="t1", timestamp=49, public=True)
+        await db.insert(m2, a2)
+
+        p3, m3, a3 = generate_message(b"secret", topic="t2", timestamp=22, public=False)
+        await db.insert(m3, a3)
+        p4, m4, a4 = generate_message(b"shared", topic="t2", timestamp=81, public=True)
+        await db.insert(m4, a4)
+
+        p5, m5, a5 = generate_message(b"private", topic="t3", timestamp=35, public=False)
+        await db.insert(m5, a5)
+        p6, m6, a6 = generate_message(b"proprietary", topic="t3", timestamp=48, public=False)
+        await db.insert(m6, a6)
+
+        # reading from a topic with full access should count both public and private messages
+        c = await db.count_message_records(topics_full=["t2"])
+        assert c == 2
+
+        # counting across multiple topics should include the messages from all of them
+        c = await db.count_message_records(topics_full=["t2", "t3"])
+        assert c == 4
+
+        # should be able to mix public and full access to different topics
+        c = await db.count_message_records(topics_public=["t2"], topics_full=["t3"])
+        assert c == 3
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("db_class", (database_api.SQL_db, database_api.Mock_db))
+async def test_SQL_db_count_message_records_time_range(db_class, tmpdir):
     with temp_postgres(tmpdir) as db_conf:
-        db = database_api.SQL_db(db_conf)
+        db = db_class(db_conf)
         await db.connect()
         await db.make_schema()
 
-        for m in messages:
-            await db.insert(m[1],m[2])
-
-        r = await db.get_message_records_for_time_range("t1", start_time=4, end_time=7)
-        assert len(r) == 3
-        assert r[0].timestamp == 4
-        assert r[0].uuid == uuid.UUID(messages[4][2]["con_text_uuid"])
-        assert r[1].timestamp == 5
-        assert r[1].uuid == uuid.UUID(messages[5][2]["con_text_uuid"])
-        assert r[2].timestamp == 6
-        assert r[2].uuid == uuid.UUID(messages[6][2]["con_text_uuid"])
-
-        r = await db.get_message_records_for_time_range("t1", start_time=3, end_time=7, limit=2)
-        assert len(r) == 2
-        assert r[0].timestamp == 3
-        assert r[0].uuid == uuid.UUID(messages[3][2]["con_text_uuid"])
-        assert r[1].timestamp == 4
-        assert r[1].uuid == uuid.UUID(messages[4][2]["con_text_uuid"])
-
-        r = await db.get_message_records_for_time_range("t1", start_time=3, end_time=7, limit=2, offset=2)
-        assert len(r) == 2
-        assert r[0].timestamp == 5
-        assert r[0].uuid == uuid.UUID(messages[5][2]["con_text_uuid"])
-        assert r[1].timestamp == 6
-        assert r[1].uuid == uuid.UUID(messages[6][2]["con_text_uuid"])
-
-        r = await db.get_message_records_for_time_range("t1", start_time=12, end_time=14)
-        assert len(r) == 0
-
-        r = await db.get_message_records_for_time_range("t2", start_time=0, end_time=5)
-        assert len(r) == 0
+        for i in range(0,4):
+            p, m, a = generate_message(b"data", topic="t1", timestamp=i, public=True)
+            await db.insert(m, a)
+        for i in range(4,8):
+            p, m, a = generate_message(b"data", topic="t2", timestamp=i, public=False)
+            await db.insert(m, a)
+        for i in range(8,12):
+            p, m, a = generate_message(b"data", topic="t1", timestamp=i, public=True)
+            await db.insert(m, a)
+
+        # all public messages
+        c = await db.count_message_records()
+        assert c == 8
+
+        # all public messages after 2
+        c = await db.count_message_records(start_time=2)
+        assert c == 6
+
+        # all public messages before 10
+        c = await db.count_message_records(end_time=10)
+        assert c == 6
+
+        # all public messages between 2 and 10
+        c = await db.count_message_records(start_time=2, end_time=10)
+        assert c == 4
+
+        # all messages
+        c = await db.count_message_records(topics_full=["t1", "t2"])
+        assert c == 12
+
+        # all messages after 2
+        c = await db.count_message_records(topics_full=["t1", "t2"], start_time=2)
+        assert c == 10
+
+        # all messages before 10
+        c = await db.count_message_records(topics_full=["t1", "t2"], end_time=10)
+        assert c == 10
+
+        # all messages between 2 and 10
+        c = await db.count_message_records(topics_full=["t1", "t2"], start_time=2, end_time=10)
+        assert c == 8
+
 
 # These tests test test code, which is a bit pointless, but keeps it from cluttering up the coverage
 # reports as being un-covered
@@ -458,6 +883,15 @@ async def test_Base_db_unimlemented():
 
     with pytest.raises(NotImplementedError):
         await db.get_message_locations(None)
+
+    with pytest.raises(NotImplementedError):
+        await db.get_topics_with_public_messages()
+
+    with pytest.raises(NotImplementedError):
+        await db.get_message_records()
+
+    with pytest.raises(NotImplementedError):
+        await db.count_message_records()
 
 @pytest.mark.asyncio
 async def test_Mock_db_get_message_id():
@@ -515,4 +949,4 @@ async def test_Mock_db_fetch():
     assert result.bucket == annotations['bucket']
     assert result.crc32 == annotations['crc32']
     assert result.is_client_uuid == annotations['con_is_client_uuid']
-    assert result.message_crc32 == annotations['con_message_crc32']
\ No newline at end of file
+    assert result.message_crc32 == annotations['con_message_crc32']
diff --git a/tests/test_decision_api.py b/tests/test_decision_api.py
index 7119438..4f189c5 100644
--- a/tests/test_decision_api.py
+++ b/tests/test_decision_api.py
@@ -16,6 +16,7 @@ async def test_is_content_identical():
         "store_type": "mock",
         "store_primary_bucket": "b1",
         "store_backup_bucket": "b2",
+        "store_region_name": "eu-north-3"
     }
     db = database_api.DbFactory(config)
     st = store_api.StoreFactory(config)
@@ -105,6 +106,7 @@ async def test_is_deemed_duplicate():
         "store_type": "mock",
         "store_primary_bucket": "b1",
         "store_backup_bucket": "b2",
+        "store_region_name": "r",
     }
     db = database_api.DbFactory(config)
     st = store_api.StoreFactory(config)
diff --git a/tests/test_store_api.py b/tests/test_store_api.py
index f09938e..dfab9ef 100644
--- a/tests/test_store_api.py
+++ b/tests/test_store_api.py
@@ -175,7 +175,8 @@ async def test_Mock_store_store_readonly(tmpdir):
     metadata = Metadata(topic="t1", partition=0, offset=2, timestamp=356, key="",
                         headers=[("_id",u.bytes)], _raw=None)
     annotations = decision_api.get_annotations(message, metadata.headers)
-    st = store_api.Mock_store({"store_primary_bucket": "a", "store_backup_bucket": "b"})
+    st = store_api.Mock_store({"store_primary_bucket": "a", "store_backup_bucket": "b",
+                               "store_region_name": "r"})
     await st.set_read_only()
     await st.connect()
 
@@ -193,7 +194,8 @@ async def test_Mock_store_deep_delete(tmpdir):
     metadata = Metadata(topic="t1", partition=0, offset=2, timestamp=356, key="",
                         headers=[("_id",u.bytes)], _raw=None)
     annotations = decision_api.get_annotations(message, metadata.headers)
-    st = store_api.Mock_store({"store_primary_bucket": "a", "store_backup_bucket": "b"})
+    st = store_api.Mock_store({"store_primary_bucket": "a", "store_backup_bucket": "b",
+                               "store_region_name": "r"})
     await st.connect()
     await st.store(message, metadata, annotations)
     assert await st.get_object(annotations["key"]) is not None
@@ -216,7 +218,7 @@ async def test_Mock_store_deep_delete(tmpdir):
 
 def test_Base_store_unimplemented():
     bs = store_api.Base_store({"store_primary_bucket": "a",
                                "store_backup_bucket": "b",
-                               "store_region_name": "nowhere"})
+                               "store_region_name": "r"})
     with pytest.raises(NotImplementedError):
         bs.get_object("akey")
     with pytest.raises(NotImplementedError):
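
Closing sketch (illustrative only, not part of the patch): counting and then
fetching the newest matching messages with the new API; `db` is assumed to be a
connected SQL_db or Mock_db instance and `topics` a list of topic names.

    async def newest_first(db, topics):
        # Total number of matches, computed with the same filters as the fetch.
        total = await db.count_message_records(topics_full=topics)
        # First page in descending timestamp order; nxt/prv are opaque bookmark
        # strings (sqlakeyset-generated for the SQL backend).
        page, nxt, prv = await db.get_message_records(
            topics_full=topics, ascending=False, page_size=50)
        return total, page, nxt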