diff --git a/docs/explanation/data-flow.md b/docs/explanation/data-flow.md index c8053b2..afbd158 100644 --- a/docs/explanation/data-flow.md +++ b/docs/explanation/data-flow.md @@ -26,14 +26,14 @@ Let's say the lambda runs a few tasks, which either pass (status `OK`) or fail ( ## Phase 2: What Gets Published to SNS -We publish one SNS message per status type, so in this case two messages: one for `OK` and one for `ERROR`. Each message includes the relevant portion of the `_perform_task` return value (JSON-encoded) as the `Message` payload. The `result_type` is included as a `MessageAttribute` to allow for filtering by the SQS subscriptions. +We publish one SNS message per status type, so in this case two messages: one for `OK` and one for `ERROR`. Each message includes the relevant portion of the `_perform_task` return value, with a top-level `result_type` added before publication. The same `result_type` is also included as a `MessageAttribute` to allow for filtering by the SQS subscriptions. For `OK`, the publish call payload shape is: ```json { "TopicArn": "arn:aws:sns:us-east-1:123456789012:lambdacron-results.fifo", - "Message": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}]}", + "Message": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}], \"result_type\": \"OK\"}", "Subject": "Notification for OK", "MessageAttributes": { "result_type": { @@ -47,7 +47,7 @@ For `OK`, the publish call payload shape is: For `ERROR`, the shape is identical except: -- `Message` is `"{\"tasks\": [{\"taskid\": 3, \"name\": \"Baz\"}]}"` +- `Message` is `"{\"tasks\": [{\"taskid\": 3, \"name\": \"Baz\"}], \"result_type\": \"ERROR\"}"` - `Subject` is `"Notification for ERROR"` - `MessageAttributes.result_type.StringValue` is `"ERROR"` @@ -63,7 +63,7 @@ The SQS event record (as seen by notifier Lambda) looks like: { "messageId": "msg-ok-1", "eventSource": "aws:sqs", - "body": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}]}", + "body": "{\"tasks\": [{\"taskid\": 1, \"name\": \"Foo\"}, {\"taskid\": 2, \"name\": \"Bar\"}], \"result_type\": \"OK\"}", "messageAttributes": { "result_type": { "stringValue": "OK", @@ -77,8 +77,7 @@ The SQS event record (as seen by notifier Lambda) looks like: ## Phase 4: Notifier Parse Behavior with This Shape -The notifier parser converts the JSON string back into an object, and, if `result_type` is missing from the payload, injects it from the SQS `messageAttributes`. In this case, the payload doesn't include `result_type`, so it is injected. - +The notifier parser converts the JSON string back into an object and validates that the payload `result_type` matches the SQS `messageAttributes.result_type` value when that attribute is present. The result, which is fed to the notifier's templates, is: ```json diff --git a/docs/how-to/set-up-ses.md b/docs/how-to/set-up-ses.md index 64ed67d..503a385 100644 --- a/docs/how-to/set-up-ses.md +++ b/docs/how-to/set-up-ses.md @@ -12,8 +12,10 @@ Set up Amazon SES once per AWS account and region where your notifier Lambda run * Choose the AWS account/region where email will be sent (SES setup is regional). * Choose a sender identity: + * Email identity for one sender address. * Domain identity if you want to send from multiple addresses in one domain. + * Make sure you can edit DNS records if you choose a domain identity. SES requires verified identities for senders and (in sandbox) recipients. That means you must verify the email address or domain you want to send from, and if in sandbox, also verify any recipient addresses. Sandbox mode limits how much you can send, and you'll probably want to request production access if you want to send to more than one recipient. diff --git a/docs/how-to/write-lambda-and-templates.md b/docs/how-to/write-lambda-and-templates.md index 9445888..6a62deb 100644 --- a/docs/how-to/write-lambda-and-templates.md +++ b/docs/how-to/write-lambda-and-templates.md @@ -33,7 +33,7 @@ Based on `examples/basic/lambda/lambda_module.py`: * Key (`example`) is the `result_type`. * Value (`{"message": "Hello World"}`) is the payload rendered by templates. -* The payload automatically has the key `result_type` injected by the notification handler, so you can access it in templates as `{{ result_type }}`. +* LambdaCron adds `result_type` to the published message body when it sends the payload to SNS, so you can access it in templates as `{{ result_type }}`. ## 2. Create templates that use the payload fields @@ -90,7 +90,8 @@ At runtime: Important detail: -* `result_type` is injected into the render payload by the notification handler when available in message attributes. +* `result_type` is injected into the message body by the publisher before it reaches the notifier. +* The SNS message attribute still carries the same `result_type`. It is used for filter policies and validation. * That is why templates like `{{ result_type }}` work even when your `_perform_task` payload does not explicitly include a `result_type` field. ## 4. Checklist Before Deploying diff --git a/src/lambdacron/lambda_task.py b/src/lambdacron/lambda_task.py index 45f469e..91764fe 100644 --- a/src/lambdacron/lambda_task.py +++ b/src/lambdacron/lambda_task.py @@ -158,6 +158,46 @@ def load_sns_message_group_id(env_var: str = "SNS_MESSAGE_GROUP_ID") -> str: return os.environ.get(env_var, "lambdacron") +def build_result_message_payload(*, result_type: str, message: Any) -> dict[str, Any]: + """ + Build the JSON object published to SNS for a single result type. + + Parameters + ---------- + result_type : str + Result type key from the task output mapping. + message : Any + JSON-serializable payload associated with the result type. + + Returns + ------- + dict[str, Any] + Payload object with a top-level ``result_type`` field. + + Raises + ------ + ValueError + If the payload is not a JSON object or if it contains a conflicting + ``result_type`` value. + """ + if not isinstance(message, Mapping): + raise ValueError( + f"Result payload for type '{result_type}' must be a JSON object" + ) + + payload = dict(message) + existing_result_type = payload.get("result_type") + if existing_result_type is None: + payload["result_type"] = result_type + else: + if existing_result_type != result_type: + raise ValueError( + f"Result payload for type '{result_type}' has conflicting " + f"result_type '{existing_result_type}'" + ) + return payload + + def dispatch_sns_messages( *, result: Mapping[str, Any], @@ -180,9 +220,10 @@ def dispatch_sns_messages( Logger used to emit structured publish logs. """ for result_type, message in result.items(): + payload = build_result_message_payload(result_type=result_type, message=message) sns_client.publish( TopicArn=sns_topic_arn, - Message=json.dumps(message), + Message=json.dumps(payload), Subject=f"Notification for {result_type}", MessageAttributes={ "result_type": { diff --git a/src/lambdacron/notifications/base.py b/src/lambdacron/notifications/base.py index eed5874..644e0b3 100644 --- a/src/lambdacron/notifications/base.py +++ b/src/lambdacron/notifications/base.py @@ -94,8 +94,6 @@ class RenderedTemplateNotificationHandler(ABC): Providers keyed by template name for rendering. expected_queue_arn : str, optional Queue ARN to validate incoming SQS records. - include_result_type : bool, optional - Whether to include the SNS message attribute ``result_type`` in the payload. logger : logging.Logger, optional Logger used for structured logging. jinja_env : jinja2.Environment, optional @@ -107,13 +105,11 @@ def __init__( template_providers: Mapping[str, TemplateProvider], *, expected_queue_arn: Optional[str] = None, - include_result_type: bool = True, logger: Optional[logging.Logger] = None, jinja_env: Optional[Environment] = None, ) -> None: self.template_providers = dict(template_providers) self.expected_queue_arn = expected_queue_arn - self.include_result_type = include_result_type self.logger = logger or logging.getLogger(self.__class__.__name__) self.jinja_env = jinja_env or Environment(undefined=StrictUndefined) @@ -213,10 +209,18 @@ def _parse_result(self, record: Mapping[str, Any]) -> Mapping[str, Any]: raise ValueError("SNS message must be valid JSON") from exc if not isinstance(payload, dict): raise ValueError("Result payload must be a JSON object") - if self.include_result_type: - result_type = self._extract_result_type(record) - if result_type and "result_type" not in payload: - payload["result_type"] = result_type + payload_result_type = payload.get("result_type") + if not isinstance(payload_result_type, str) or not payload_result_type: + raise ValueError( + "Result payload must include a non-empty string result_type" + ) + + attribute_result_type = self._extract_result_type(record) + if attribute_result_type and attribute_result_type != payload_result_type: + raise ValueError( + "Result type mismatch between payload and message attributes " + f"(payload {payload_result_type}, attribute {attribute_result_type})" + ) return payload def _render_template(self, template: str, result: Mapping[str, Any]) -> str: diff --git a/src/lambdacron/notifications/print_handler.py b/src/lambdacron/notifications/print_handler.py index d8ff95c..c07fc85 100644 --- a/src/lambdacron/notifications/print_handler.py +++ b/src/lambdacron/notifications/print_handler.py @@ -25,13 +25,11 @@ def __init__( *, template_provider: TemplateProvider, expected_queue_arn: str | None = None, - include_result_type: bool = True, logger: Any | None = None, ) -> None: super().__init__( template_providers={"body": template_provider}, expected_queue_arn=expected_queue_arn, - include_result_type=include_result_type, logger=logger, ) diff --git a/src/lambdacron/render.py b/src/lambdacron/render.py index 6b6860a..fb8b1f8 100644 --- a/src/lambdacron/render.py +++ b/src/lambdacron/render.py @@ -6,6 +6,7 @@ from jinja2 import TemplateError +from lambdacron.lambda_task import build_result_message_payload from lambdacron.notifications.base import ( FileTemplateProvider, RenderedTemplateNotificationHandler, @@ -46,8 +47,7 @@ def build_parser() -> argparse.ArgumentParser: class RenderNotificationHandler(RenderedTemplateNotificationHandler): def __init__(self, *, template_path: Path, stream: TextIO | None = None) -> None: super().__init__( - template_providers={"body": FileTemplateProvider(template_path)}, - include_result_type=True, + template_providers={"body": FileTemplateProvider(template_path)} ) self.stream = stream or sys.stdout @@ -95,11 +95,15 @@ def extract_result_payload(payload_json: str, *, result_type: str) -> str: if not isinstance(payload, dict): raise ValueError("Task output must be a JSON object keyed by result type") selected = payload.get(result_type) - if not isinstance(selected, dict): + if not isinstance(selected, Mapping): raise ValueError( - f"Result payload for type '{result_type}' must be a JSON object" + f"Result payload for type '{result_type}' must be a JSON object, " + f"got {type(selected).__name__}" ) - return json.dumps(selected) + payload_for_publish = build_result_message_payload( + result_type=result_type, message=selected + ) + return json.dumps(payload_for_publish) def main(argv: list[str] | None = None) -> int: diff --git a/tests/notifications/test_base.py b/tests/notifications/test_base.py index 3ef0e89..55cda37 100644 --- a/tests/notifications/test_base.py +++ b/tests/notifications/test_base.py @@ -71,13 +71,13 @@ def test_file_template_provider_raises_for_missing_file(tmp_path): def test_notification_handler_parses_sqs_json_body(monkeypatch): monkeypatch.setenv("TEMPLATE", "Status {{ status }}") handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()}) - event = build_sqs_event(json.dumps({"status": "ok"})) + event = build_sqs_event(json.dumps({"status": "ok", "result_type": "success"})) response = handler.lambda_handler(event, context=None) assert handler.calls == [ { - "result": {"status": "ok"}, + "result": {"status": "ok", "result_type": "success"}, "rendered": {"body": "Status ok"}, "record": event["Records"][0], } @@ -88,7 +88,9 @@ def test_notification_handler_parses_sqs_json_body(monkeypatch): def test_notification_handler_parses_sns_envelope(monkeypatch): monkeypatch.setenv("TEMPLATE", "Result {{ status }}") handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()}) - sns_body = json.dumps({"Message": json.dumps({"status": "good"})}) + sns_body = json.dumps( + {"Message": json.dumps({"status": "good", "result_type": "success"})} + ) event = build_sqs_event(sns_body) response = handler.lambda_handler(event, context=None) @@ -141,8 +143,14 @@ def test_notification_handler_logs_invocation(monkeypatch, caplog): ) event = { "Records": [ - {"body": json.dumps({"name": "Ada"}), "eventSource": "aws:sqs"}, - {"body": json.dumps({"name": "Grace"}), "eventSource": "aws:sqs"}, + { + "body": json.dumps({"name": "Ada", "result_type": "success"}), + "eventSource": "aws:sqs", + }, + { + "body": json.dumps({"name": "Grace", "result_type": "success"}), + "eventSource": "aws:sqs", + }, ] } @@ -204,30 +212,41 @@ def test_parse_result_rejects_non_object_payload(monkeypatch): assert response == {"batchItemFailures": [{"itemIdentifier": "msg-123"}]} -@pytest.mark.parametrize("include_result_type", [True, False]) -@pytest.mark.parametrize("payload_has_result_type", [True, False]) -def test_notification_handler_result_type_injection( - monkeypatch, include_result_type, payload_has_result_type -): +def test_notification_handler_payload_result_type_passes_through(monkeypatch): + monkeypatch.setenv("TEMPLATE", "Result {{ result_type }}") + handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()}) + payload = {"status": "ok", "result_type": "payload"} + event = build_sqs_event( + json.dumps(payload), + message_attributes={"result_type": {"stringValue": "payload"}}, + ) + + handler.lambda_handler(event, context=None) + + assert handler.calls[0]["result"] == payload + + +def test_notification_handler_requires_payload_result_type(monkeypatch): monkeypatch.setenv("TEMPLATE", "Result {{ result_type | default('none') }}") - handler = CapturingHandler( - template_providers={"body": EnvVarTemplateProvider()}, - include_result_type=include_result_type, + handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()}) + event = build_sqs_event( + json.dumps({"status": "ok"}), + message_attributes={"result_type": {"stringValue": "attribute"}}, ) - payload = {"status": "ok"} - if payload_has_result_type: - payload["result_type"] = "payload" + + response = handler.lambda_handler(event, context=None) + + assert response == {"batchItemFailures": [{"itemIdentifier": "msg-123"}]} + + +def test_notification_handler_rejects_result_type_mismatch(monkeypatch): + monkeypatch.setenv("TEMPLATE", "Result {{ result_type }}") + handler = CapturingHandler(template_providers={"body": EnvVarTemplateProvider()}) event = build_sqs_event( - json.dumps(payload), + json.dumps({"status": "ok", "result_type": "payload"}), message_attributes={"result_type": {"stringValue": "attribute"}}, ) - handler.lambda_handler(event, context=None) + response = handler.lambda_handler(event, context=None) - result = handler.calls[0]["result"] - if payload_has_result_type: - assert result["result_type"] == "payload" - elif include_result_type: - assert result["result_type"] == "attribute" - else: - assert "result_type" not in result + assert response == {"batchItemFailures": [{"itemIdentifier": "msg-123"}]} diff --git a/tests/notifications/test_email_handler.py b/tests/notifications/test_email_handler.py index 3f9e427..cd53b64 100644 --- a/tests/notifications/test_email_handler.py +++ b/tests/notifications/test_email_handler.py @@ -40,7 +40,7 @@ def test_email_handler_sends_rendered_templates(monkeypatch): recipients=["alice@example.com", "bob@example.com"], ses_client=ses_client, ) - event = build_sqs_event({"name": "Ada"}) + event = build_sqs_event({"name": "Ada", "result_type": "success"}) handler.lambda_handler(event, context=None) @@ -71,7 +71,7 @@ def test_email_handler_includes_optional_fields(monkeypatch): config_set="alerts", reply_to=["reply@example.com"], ) - event = build_sqs_event({"name": "Grace"}) + event = build_sqs_event({"name": "Grace", "result_type": "success"}) handler.lambda_handler(event, context=None) @@ -102,7 +102,9 @@ def send_email(self, **kwargs): recipients=["ops@example.com"], ses_client=ErrorSesClient(), ) - event = build_sqs_event({"name": "Ada"}, message_id="msg-err") + event = build_sqs_event( + {"name": "Ada", "result_type": "success"}, message_id="msg-err" + ) response = handler.lambda_handler(event, context=None) diff --git a/tests/notifications/test_print_handler.py b/tests/notifications/test_print_handler.py index 525e387..3a4c41b 100644 --- a/tests/notifications/test_print_handler.py +++ b/tests/notifications/test_print_handler.py @@ -7,7 +7,12 @@ def test_print_handler_prints_rendered_template(monkeypatch, capsys): monkeypatch.setenv("TEMPLATE", "Hello {{ name }}") handler = PrintNotificationHandler(template_provider=EnvVarTemplateProvider()) event = { - "Records": [{"body": json.dumps({"name": "Ada"}), "eventSource": "aws:sqs"}] + "Records": [ + { + "body": json.dumps({"name": "Ada", "result_type": "success"}), + "eventSource": "aws:sqs", + } + ] } handler.lambda_handler(event, context=None) diff --git a/tests/test_lambda_task.py b/tests/test_lambda_task.py index 3c9a9ec..3c2d15a 100644 --- a/tests/test_lambda_task.py +++ b/tests/test_lambda_task.py @@ -7,6 +7,7 @@ from lambdacron.lambda_task import ( CronLambdaTask, + build_result_message_payload, dispatch_sns_messages, extract_context_metadata, load_sns_message_group_id, @@ -57,7 +58,7 @@ def test_dispatch_sns_messages_publishes(caplog): sns_client.publish.assert_any_call( TopicArn="arn:one", - Message=json.dumps({"ok": True}), + Message=json.dumps({"ok": True, "result_type": "success"}), Subject="Notification for success", MessageAttributes={ "result_type": {"DataType": "String", "StringValue": "success"} @@ -66,7 +67,7 @@ def test_dispatch_sns_messages_publishes(caplog): ) sns_client.publish.assert_any_call( TopicArn="arn:one", - Message=json.dumps({"ok": False}), + Message=json.dumps({"ok": False, "result_type": "failure"}), Subject="Notification for failure", MessageAttributes={ "result_type": {"DataType": "String", "StringValue": "failure"} @@ -95,7 +96,7 @@ def _perform_task(self, event, context): sns_client.publish.assert_called_once_with( TopicArn="arn:one", - Message=json.dumps({"ok": True}), + Message=json.dumps({"ok": True, "result_type": "success"}), Subject="Notification for success", MessageAttributes={ "result_type": {"DataType": "String", "StringValue": "success"} @@ -104,6 +105,68 @@ def _perform_task(self, event, context): ) +def test_build_result_message_payload_rejects_conflicting_result_type(): + with pytest.raises( + ValueError, + match="Result payload for type 'success' has conflicting result_type", + ): + build_result_message_payload( + result_type="success", + message={"result_type": "failure", "ok": False}, + ) + + +def test_dispatch_sns_messages_publishes_distinct_bodies_for_identical_payloads(): + sns_client = Mock() + logger = logging.getLogger("test_dispatch_identical") + result = { + "failure": {"status": "bad"}, + "failure_or_warning": {"status": "bad"}, + } + + dispatch_sns_messages( + result=result, + sns_topic_arn="arn:one", + sns_client=sns_client, + logger=logger, + ) + + published_messages = [ + json.loads(call.kwargs["Message"]) for call in sns_client.publish.call_args_list + ] + + assert published_messages == [ + {"status": "bad", "result_type": "failure"}, + {"status": "bad", "result_type": "failure_or_warning"}, + ] + + +def test_dispatch_sns_messages_does_not_mutate_shared_payload(): + sns_client = Mock() + logger = logging.getLogger("test_dispatch_shared") + shared_payload = {"status": "bad"} + + dispatch_sns_messages( + result={ + "failure": shared_payload, + "failure_or_warning": shared_payload, + }, + sns_topic_arn="arn:one", + sns_client=sns_client, + logger=logger, + ) + + published_messages = [ + json.loads(call.kwargs["Message"]) for call in sns_client.publish.call_args_list + ] + + assert published_messages == [ + {"status": "bad", "result_type": "failure"}, + {"status": "bad", "result_type": "failure_or_warning"}, + ] + assert shared_payload == {"status": "bad"} + + def test_cron_lambda_task_init_loads_sns_topic_arn(monkeypatch): monkeypatch.setenv("SNS_TOPIC_ARN", "arn:one") diff --git a/tests/test_render.py b/tests/test_render.py index 36cba0b..9dcb468 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -30,7 +30,9 @@ def test_render_main_renders_with_long_flags(tmp_path, capsys): assert captured.err == "" -def test_render_main_short_flags_preserve_payload_result_type(tmp_path, capsys): +def test_render_main_short_flags_reject_mismatched_payload_result_type( + tmp_path, capsys +): template_path = tmp_path / "template.jinja2" output_path = tmp_path / "output.json" template_path.write_text("{{ result_type }}", encoding="utf-8") @@ -42,9 +44,11 @@ def test_render_main_short_flags_preserve_payload_result_type(tmp_path, capsys): code = render.main(["-t", str(template_path), "-r", "attribute", str(output_path)]) captured = capsys.readouterr() - assert code == 0 - assert captured.out == "payload\n" - assert captured.err == "" + assert code == 1 + assert ( + "Result payload for type 'attribute' has conflicting result_type 'payload'" + in captured.err + ) def test_render_main_rejects_non_object_json(tmp_path, capsys): @@ -117,6 +121,23 @@ def test_render_main_reads_from_stdin_with_dash(tmp_path, capsys, monkeypatch): assert captured.err == "" +def test_render_main_preserves_matching_payload_result_type(tmp_path, capsys): + template_path = tmp_path / "template.jinja2" + output_path = tmp_path / "output.json" + template_path.write_text("{{ result_type }}", encoding="utf-8") + output_path.write_text( + json.dumps({"success": {"status": "ok", "result_type": "success"}}), + encoding="utf-8", + ) + + code = render.main(["-t", str(template_path), "-r", "success", str(output_path)]) + + captured = capsys.readouterr() + assert code == 0 + assert captured.out == "success\n" + assert captured.err == "" + + def test_render_main_extracts_payload_from_result_map_stdin( tmp_path, capsys, monkeypatch ): @@ -149,7 +170,10 @@ def test_render_main_rejects_task_output_missing_result_type(tmp_path, capsys): captured = capsys.readouterr() assert code == 1 - assert "Result payload for type 'success' must be a JSON object" in captured.err + assert ( + "Result payload for type 'success' must be a JSON object, got NoneType" + in captured.err + ) def test_render_main_rejects_non_object_result_payload(tmp_path, capsys): @@ -162,4 +186,7 @@ def test_render_main_rejects_non_object_result_payload(tmp_path, capsys): captured = capsys.readouterr() assert code == 1 - assert "Result payload for type 'success' must be a JSON object" in captured.err + assert ( + "Result payload for type 'success' must be a JSON object, got str" + in captured.err + )