Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mu-qa.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
project-org = 'Level 12'
image-name = 'critic'

[tool.mu.event-rules.run-due-checks]
action = 'run_due_checks'
cron = '* * * * *' # Every minute on the minute
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"moto[all]>=5.1.14",
"pydantic>=2.7",
"httpx>=0.27",
"polyfactory>=3.2.0",
]


Expand All @@ -41,6 +42,7 @@ dev = [
]
# Used by nox
pytest = [
"freezegun>=1.5.5",
'pytest',
'pytest-cov',
'respx>=0.21',
Expand Down
8 changes: 8 additions & 0 deletions src/critic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from flask import Flask
import mu

from critic.tasks import run_due_checks


log = logging.getLogger()

Expand Down Expand Up @@ -32,6 +34,12 @@ def error():
class ActionHandler(mu.ActionHandler):
wsgi_app = app

@staticmethod
def run_due_checks(event, context):
    """Triggered by EventBridge rule, invokes `run_due_checks` task.

    `event` and `context` are the standard AWS Lambda handler arguments;
    neither is inspected here — the EventBridge schedule alone decides
    when checks run.
    """
    log.info('Invoking run_due_checks')
    run_due_checks.invoke()


# The entry point for AWS lambda has to be a function
lambda_handler = ActionHandler.on_event
92 changes: 66 additions & 26 deletions src/critic/libs/ddb.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,80 @@
from datetime import datetime
from decimal import Decimal
import os

from boto3 import client
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from pydantic import BaseModel
from pydantic import AwareDatetime, BaseModel, TypeAdapter

from critic.libs.dt import to_utc

client = client('dynamodb')
serializer = TypeSerializer()
deserializer = TypeDeserializer()

# https://www.reddit.com/r/aws/comments/cwams9/dynamodb_i_need_to_sort_whole_table_by_range_how/
CONSTANT_GSI_PK = 'bogus'

def serialize(data: dict) -> dict:
_ddb_client = None


def get_client():
    """
    Return the module-wide boto3 DynamoDB client, creating it on first use.

    The client is cached in the module global `_ddb_client` so repeated
    calls reuse one connection; tests reset the global to swap clients.
    """
    global _ddb_client
    if _ddb_client is not None:
        return _ddb_client
    _ddb_client = client('dynamodb')
    return _ddb_client


class Serializer:
"""Serialize standard JSON to DynamoDB format."""
return {k: serializer.serialize(v) for k, v in data.items()}

_serializer = TypeSerializer()
_aware_dt_adapter = TypeAdapter(AwareDatetime)

def deserialize(data: dict) -> dict:
"""Deserialize DynamoDB format to standard JSON."""
return {k: deserializer.deserialize(v) for k, v in data.items()}
@staticmethod
def dt_to_str(value):
if isinstance(value, datetime):
# Convert datetime to string in the same way Pydantic does to ensure consistency
return Serializer._aware_dt_adapter.dump_python(to_utc(value), mode='json')
return value

@staticmethod
def float_to_decimal(value):
if isinstance(value, float):
return Decimal(str(value))

def namespace_table(table_name: str) -> str:
return f'{table_name}-{os.environ["CRITIC_NAMESPACE"]}'
if isinstance(value, list):
return [Serializer.float_to_decimal(v) for v in value]

if isinstance(value, dict):
return {k: Serializer.float_to_decimal(v) for k, v in value.items()}

def floats_to_decimals(value):
if isinstance(value, float):
return Decimal(str(value))
return value

if isinstance(value, list):
return [floats_to_decimals(v) for v in value]
def serialize(self, value):
value = self.dt_to_str(value)
value = self.float_to_decimal(value)
return self._serializer.serialize(value)

if isinstance(value, dict):
return {k: floats_to_decimals(v) for k, v in value.items()}
def __call__(self, data: dict) -> dict:
return {k: self.serialize(v) for k, v in data.items()}

return value

class Deserializer:
"""Deserialize DynamoDB format to standard JSON."""

_deserializer = TypeDeserializer()

def __call__(self, data: dict) -> dict:
return {k: self._deserializer.deserialize(v) for k, v in data.items()}


serialize = Serializer()
deserialize = Deserializer()


class Table:
name: str
base_name: str
model: type[BaseModel]
partition_key: str
sort_key: str | None = None
Expand All @@ -48,17 +83,22 @@ class Table:
def model_to_ddb(inst: BaseModel) -> dict:
"""Convert a Pydantic model instance to a DynamoDB-compatible dict."""
plain = inst.model_dump(mode='json', exclude_none=True)
return serialize(floats_to_decimals(plain))
return serialize(plain)

@staticmethod
def namespace(table_name: str) -> str:
return f'{table_name}-{os.environ["CRITIC_NAMESPACE"]}'

@classmethod
def table_name(cls):
return namespace_table(cls.name)
def name(cls) -> str:
return cls.namespace(cls.base_name)

@classmethod
def put(cls, data: dict | BaseModel):
if isinstance(data, dict):
data = cls.model(**data)
client.put_item(TableName=cls.table_name(), Item=cls.model_to_ddb(data))
client = get_client()
client.put_item(TableName=cls.name(), Item=cls.model_to_ddb(data))

@classmethod
def get(cls, partition_value: str | int, sort_value: str | int | None = None):
Expand All @@ -70,8 +110,8 @@ def get(cls, partition_value: str | int, sort_value: str | int | None = None):
key[cls.sort_key] = sort_value

# Get item
item = client.get_item(
TableName=cls.table_name(),
item = get_client().get_item(
TableName=cls.name(),
Key=serialize(key),
)['Item']
return cls.model(**deserialize(item))
11 changes: 11 additions & 0 deletions src/critic/libs/dt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from datetime import UTC, datetime


def is_aware(dt: datetime) -> bool:
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None


def to_utc(dt: datetime) -> datetime:
if not is_aware(dt):
raise ValueError(f'datetime must be timezone aware, got {dt}')
return dt.astimezone(UTC)
31 changes: 26 additions & 5 deletions src/critic/libs/testing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import boto3
from polyfactory.factories.pydantic_factory import ModelFactory
from pydantic import BaseModel

from critic.libs.ddb import client, namespace_table
from critic.libs.ddb import Table, get_client
from critic.models import UptimeMonitorModel
from critic.tables import UptimeMonitorTable


def create_tables():
client = get_client()
client.create_table(
TableName=namespace_table('Project'),
TableName=Table.namespace('Project'),
AttributeDefinitions=[
{'AttributeName': 'id', 'AttributeType': 'S'},
],
Expand All @@ -16,7 +21,7 @@ def create_tables():
)

client.create_table(
TableName=namespace_table('UptimeMonitor'),
TableName=Table.namespace('UptimeMonitor'),
AttributeDefinitions=[
# Key attributes
{'AttributeName': 'project_id', 'AttributeType': 'S'},
Expand All @@ -43,7 +48,7 @@ def create_tables():
)

client.create_table(
TableName=namespace_table('UptimeLog'),
TableName=Table.namespace('UptimeLog'),
AttributeDefinitions=[
{'AttributeName': 'monitor_id', 'AttributeType': 'S'},
{'AttributeName': 'timestamp', 'AttributeType': 'S'},
Expand Down Expand Up @@ -93,5 +98,21 @@ def _clear_table(table_name: str):


def clear_tables():
for table_name in [namespace_table(t) for t in ('Project', 'UptimeMonitor', 'UptimeLog')]:
for table_name in [Table.namespace(t) for t in ('Project', 'UptimeMonitor', 'UptimeLog')]:
_clear_table(table_name)


class PutMixin:
__table__: type[Table]

@classmethod
def put(cls, **kwargs) -> BaseModel:
item = cls.build(**kwargs)
cls.__table__.put(item)
return item


class UptimeMonitorFactory(PutMixin, ModelFactory):
__model__ = UptimeMonitorModel
__table__ = UptimeMonitorTable
__use_defaults__ = True
15 changes: 12 additions & 3 deletions src/critic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from typing import Any
from uuid import UUID

from pydantic import BaseModel, Field, HttpUrl
from pydantic import AwareDatetime, BaseModel, Field, HttpUrl, field_validator

from critic.libs.ddb import CONSTANT_GSI_PK
from critic.libs.dt import to_utc


class MonitorState(str, Enum):
Expand All @@ -21,15 +24,21 @@ class UptimeMonitorModel(BaseModel):
state: MonitorState = MonitorState.new
url: HttpUrl
frequency_mins: int = Field(ge=1)
next_due_at: datetime
next_due_at: AwareDatetime
timeout_secs: float = Field(ge=0)
# TODO: assertions should probably become its own model
assertions: dict[str, Any] | None = None
failures_before_alerting: int
alert_slack_channels: list[str] = Field(default_factory=list)
alert_emails: list[str] = Field(default_factory=list)
realert_interval_mins: int = Field(ge=0)
GSI_PK: str = Field(default='all monitors')
GSI_PK: str = Field(default=CONSTANT_GSI_PK)

@field_validator('next_due_at')
@classmethod
def validate_next_due_at(cls, v: datetime) -> datetime:
    """Normalize `next_due_at` to UTC.

    The field is declared `AwareDatetime`, so `v` is guaranteed tz-aware
    and `to_utc` will not raise here. Storing UTC presumably keeps the
    serialized string form comparable in the DynamoDB GSI range query —
    confirm against `get_due_since`.
    """
    return to_utc(v)


class ProjectMonitors(BaseModel):
Expand Down
21 changes: 19 additions & 2 deletions src/critic/tables.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
from critic.libs.ddb import Table
from datetime import datetime

from critic.libs.ddb import CONSTANT_GSI_PK, Table, deserialize, get_client, serialize

from .models import UptimeMonitorModel


class UptimeMonitorTable(Table):
name = 'UptimeMonitor'
base_name = 'UptimeMonitor'
model = UptimeMonitorModel
partition_key = 'project_id'
sort_key = 'slug'

@classmethod
def get_due_since(cls, timestamp: datetime) -> list[UptimeMonitorModel]:
    """Return every monitor whose `next_due_at` is at or before *timestamp*.

    Queries the `NextDueIndex` GSI. All items share the constant partition
    key `CONSTANT_GSI_PK`, which puts every monitor in one partition so the
    whole table can be range-filtered on `next_due_at`.

    `serialize` converts *timestamp* to its string form; a naive datetime
    raises ValueError there rather than silently mis-sorting.
    """
    response = get_client().query(
        TableName=cls.name(),
        IndexName='NextDueIndex',
        KeyConditionExpression='GSI_PK = :pk AND next_due_at <= :timestamp',
        ExpressionAttributeValues=serialize(
            {
                ':pk': CONSTANT_GSI_PK,
                ':timestamp': timestamp,
            }
        ),
    )
    # Deserialize each DynamoDB item back into the table's Pydantic model.
    return [cls.model(**deserialize(item)) for item in response['Items']]
36 changes: 36 additions & 0 deletions src/critic/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from datetime import UTC, datetime, timedelta
import logging

import mu

from critic.tables import UptimeMonitorTable


log = logging.getLogger(__name__)


@mu.task
def run_check(project_id: str, slug: str) -> None:
    """Placeholder task that will perform the uptime check for one monitor.

    `project_id` and `slug` identify the monitor (the table's partition and
    sort keys). No check logic is implemented yet.
    """
    pass


@mu.task
def run_due_checks():
    """
    This task is invoked by an EventBridge rule once a minute. It queries for all monitors that are
    due and invokes `run_check` for each one.
    """
    now = datetime.now(UTC)
    log.info(f'Triggering due checks at {now.isoformat()}')

    # Snap the timestamp to the nearest whole minute so small scheduling
    # jitter in the EventBridge trigger cannot make a monitor appear
    # "not yet due".
    cutoff = now.replace(second=0, microsecond=0)
    if now.second >= 30:
        cutoff += timedelta(minutes=1)

    # Fan out one `run_check` task per due monitor.
    due_monitors = UptimeMonitorTable.get_due_since(cutoff)
    for monitor in due_monitors:
        run_check.invoke(str(monitor.project_id), monitor.slug)

    log.info(f'Due checks triggered for {len(due_monitors)} monitors in {datetime.now(UTC) - now}')
5 changes: 5 additions & 0 deletions tests/critic_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from moto import mock_aws
import pytest

import critic.libs.ddb as ddb_module
from critic.libs.testing import clear_tables, create_tables


Expand Down Expand Up @@ -42,6 +43,10 @@ def moto_for_unit_tests(request):
with mock_aws():
create_tables()
yield
# The DDB module is designed to cache the client. When we're testing unit tests and
# integration tests, this cache needs to be reset so the integration test doesn't get
# the mocked client and vice versa.
ddb_module._ddb_client = None


def pytest_configure(config):
Expand Down
6 changes: 6 additions & 0 deletions tests/critic_tests/test_libs/test_ddb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pytest

from critic.models import UptimeMonitorModel
Expand Down Expand Up @@ -65,3 +67,7 @@ def test_missing_sort_key(self):
# error.
with pytest.raises(ValueError):
UptimeMonitorTable.get('6033aa47-a9f7-4d7f-b7ff-a11ba9b34474')

def test_serialize_unaware_dt(self):
    """A naive datetime must be rejected before it reaches DynamoDB."""
    # `serialize` routes datetimes through `to_utc`, which raises for
    # naive values; the query argument here is naive on purpose.
    with pytest.raises(ValueError, match='must be timezone aware'):
        UptimeMonitorTable.get_due_since(datetime.now())
Loading