Skip to content

Commit dd3f82b

Browse files
authored
add SQSTraceExporter and SQSBatchSpanProcessor (#2)
1 parent 74a323e commit dd3f82b

11 files changed

Lines changed: 1021 additions & 10 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
.ruff_cache/
55
.coverage
66
coverage.xml
7+
dist/
78
__pycache__/
89
*.py[cod]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from aws_lambda_opentelemetry.trace.helpers import instrument_handler
2+
3+
__all__ = ["instrument_handler"]
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import base64
2+
import enum
3+
import gzip
4+
import logging
5+
import os
6+
import threading
7+
import zlib
8+
from collections.abc import Sequence
9+
from io import BytesIO
10+
from typing import Any
11+
12+
from opentelemetry.exporter.otlp.proto.common.trace_encoder import encode_spans
13+
from opentelemetry.sdk.environment_variables import (
14+
OTEL_EXPORTER_OTLP_COMPRESSION,
15+
OTEL_EXPORTER_OTLP_TRACES_COMPRESSION,
16+
)
17+
from opentelemetry.sdk.trace import ReadableSpan
18+
from opentelemetry.sdk.trace.export import (
19+
BatchSpanProcessor,
20+
SpanExporter,
21+
SpanExportResult,
22+
)
23+
from uuid_utils import uuid7
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class Compression(enum.Enum):
29+
NoCompression = "none"
30+
Deflate = "deflate"
31+
Gzip = "gzip"
32+
33+
@classmethod
34+
def from_env(cls) -> "Compression":
35+
compression = (
36+
os.getenv(
37+
OTEL_EXPORTER_OTLP_TRACES_COMPRESSION,
38+
os.getenv(OTEL_EXPORTER_OTLP_COMPRESSION, "none"),
39+
)
40+
.lower()
41+
.strip()
42+
)
43+
return Compression(compression)
44+
45+
46+
class Base64SpanSerializer:
47+
def __init__(self, compression: Compression):
48+
self._compression = compression
49+
50+
def serialize(self, spans: Sequence[ReadableSpan]) -> str:
51+
encoded_spans = encode_spans(spans)
52+
data = encoded_spans.SerializeToString()
53+
54+
if self._compression == Compression.Gzip:
55+
gzip_data = BytesIO()
56+
with gzip.GzipFile(fileobj=gzip_data, mode="w") as gzip_stream:
57+
gzip_stream.write(data)
58+
data = gzip_data.getvalue()
59+
elif self._compression == Compression.Deflate:
60+
data = zlib.compress(data)
61+
62+
compressed_serialized_spans = base64.b64encode(data)
63+
return compressed_serialized_spans.decode("utf-8")
64+
65+
66+
class SQSTraceExporter(SpanExporter):
67+
"""
68+
Implements OpenTelemetry SpanExporter interface
69+
which can be used in combination with a SpanProcessor
70+
to publish traces to Amazon SQS.
71+
72+
```
73+
provider = TracerProvider()
74+
processor = SimpleSpanProcessor(SQSTraceExporter())
75+
provider.add_span_processor(processor)
76+
trace.set_tracer_provider(provider)
77+
```
78+
"""
79+
80+
def __init__(
81+
self,
82+
queue_url: str,
83+
sqs_client: Any,
84+
compression: Compression | None = None,
85+
) -> None:
86+
self._compression = compression or Compression.from_env()
87+
self._serializer = Base64SpanSerializer(self._compression)
88+
self._queue_url = queue_url
89+
self._sqs_client = sqs_client
90+
self._shutdown_in_progress = threading.Event()
91+
self._shutdown = False
92+
93+
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
94+
"""
95+
Exports spans to SQS in batches when the batch size is reached.
96+
"""
97+
if self._shutdown:
98+
logger.warning("Exporter already shutdown, ignoring batch")
99+
return SpanExportResult.FAILURE
100+
101+
entries = []
102+
for span in spans:
103+
serialized_span = self._serializer.serialize([span])
104+
id_ = str(span.context.span_id) if span.context else uuid7().hex
105+
entries.append({"Id": id_, "MessageBody": serialized_span})
106+
107+
try:
108+
self._sqs_client.send_message_batch(
109+
QueueUrl=self._queue_url, Entries=entries
110+
)
111+
return SpanExportResult.SUCCESS
112+
except Exception as exc:
113+
logger.exception(f"Unexpected error exporting spans: {exc}")
114+
return SpanExportResult.FAILURE
115+
116+
def shutdown(self) -> None:
117+
"""Flush remaining spans before shutdown."""
118+
if self._shutdown:
119+
logger.warning("Exporter already shutdown, ignoring call")
120+
return
121+
122+
self._shutdown = True
123+
self._shutdown_in_progress.set()
124+
self._sqs_client.close()
125+
126+
def force_flush(self, timeout_millis: int = 30000) -> bool:
127+
"""Nothing is buffered in this exporter, so this method does nothing."""
128+
return True
129+
130+
131+
class SQSBatchSpanProcessor(BatchSpanProcessor):
132+
"""
133+
BatchSpanProcessor configured for SQS limits.
134+
135+
Automatically sets max_export_batch_size to 10 (SQS batch limit).
136+
137+
```
138+
provider = TracerProvider()
139+
exporter = SQSTraceExporter(queue_url="your-sqs-queue-url")
140+
processor = SQSBatchSpanProcessor(exporter)
141+
provider.add_span_processor(processor)
142+
trace.set_tracer_provider(provider)
143+
```
144+
"""
145+
146+
MAX_SQS_BATCH_SIZE = 10
147+
148+
def __init__(
149+
self,
150+
span_exporter: SpanExporter,
151+
max_export_batch_size: int = MAX_SQS_BATCH_SIZE,
152+
**kwargs,
153+
) -> None:
154+
assert max_export_batch_size <= self.MAX_SQS_BATCH_SIZE
155+
super().__init__(
156+
span_exporter=span_exporter,
157+
max_export_batch_size=max_export_batch_size,
158+
**kwargs,
159+
)
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@
1313
from aws_lambda_opentelemetry.utils import set_lambda_handler_attributes
1414

1515

16-
def instrument_lambda_handler(**kwargs):
16+
def instrument_handler(**kwargs):
1717
"""
1818
Decorate a Lambda handler function to automatically create and manage
1919
an OpenTelemetry span for the function invocation.
2020
2121
Accepts all keyword arguments from Tracer.start_as_current_span():
2222
2323
:param name: Span name (defaults to function name if not provided)
24-
:param kind: SpanKind (defaults to SERVER if not provided)
2524
:param context: Parent span context
25+
:param kind: SpanKind (defaults to SERVER if not provided)
2626
:param attributes: Initial span attributes dict
2727
:param links: Span links
2828
:param start_time: Span start timestamp
2929
:param record_exception: Whether to record exceptions (default True)
3030
:param set_status_on_exception: Whether to set error status on exception (default True)
31+
:param end_on_exit: Whether to end the span on exit (default True)
3132
:return: The decorated handler function.
3233
"""
3334

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ classifiers = [
3030
]
3131
dependencies = [
3232
"opentelemetry-api>=1.0.0",
33+
"opentelemetry-exporter-otlp-proto-common>=1.0.0",
3334
"opentelemetry-sdk>=1.0.0",
35+
"uuid-utils>=0.12.0",
3436
]
3537

3638
[dependency-groups]
3739
dev = [
40+
"boto3>=1.42.14",
41+
"moto>=5.1.18",
3842
"pytest>=9.0.2",
3943
"pytest-cov>=7.0.0",
4044
"ruff>=0.14.9",

tests/test_trace/__init__.py

Whitespace-only changes.

tests/test_trace/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import boto3
2+
import pytest
3+
from moto import mock_aws
4+
5+
6+
@pytest.fixture
7+
def mock_sqs_client():
8+
with mock_aws():
9+
sqs = boto3.client("sqs", region_name="us-east-1")
10+
sqs.create_queue(QueueName="test-queue")
11+
yield sqs

tests/test_trace/test_export.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
from opentelemetry import trace
5+
from opentelemetry.sdk.trace import TracerProvider
6+
from opentelemetry.sdk.trace.export import (
7+
SpanExportResult,
8+
)
9+
10+
from aws_lambda_opentelemetry.trace.export import (
11+
Base64SpanSerializer,
12+
Compression,
13+
SQSBatchSpanProcessor,
14+
SQSTraceExporter,
15+
)
16+
from tests.utils import generate_span
17+
18+
19+
class TestSpanSerializer:
20+
@pytest.mark.parametrize(
21+
"compression, expected_length",
22+
[
23+
(Compression.NoCompression, 276),
24+
(Compression.Gzip, 248),
25+
(Compression.Deflate, 232),
26+
],
27+
)
28+
def test_base64_span_serializer(self, compression, expected_length):
29+
serializer = Base64SpanSerializer(compression)
30+
spans = [generate_span()]
31+
result = serializer.serialize(spans)
32+
assert isinstance(result, str)
33+
assert len(result) == expected_length
34+
35+
@pytest.mark.parametrize(
36+
"compression_name, expected_compression",
37+
[
38+
("gzip", Compression.Gzip),
39+
("deflate", Compression.Deflate),
40+
("none", Compression.NoCompression),
41+
],
42+
)
43+
def test_compression_from_env_var(
44+
self, monkeypatch, compression_name, expected_compression
45+
):
46+
monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_COMPRESSION", compression_name)
47+
assert Compression.from_env() == expected_compression
48+
49+
50+
class TestSqsTraceExporter:
51+
QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789/test-queue"
52+
53+
def test_export_when_shutdown_is_called(self, mock_sqs_client):
54+
exporter = SQSTraceExporter(
55+
queue_url=self.QUEUE_URL,
56+
sqs_client=mock_sqs_client,
57+
)
58+
exporter.shutdown()
59+
60+
result = exporter.export([])
61+
assert result == SpanExportResult.FAILURE
62+
63+
def test_export_sends_messages_to_sqs(self, mock_sqs_client):
64+
exporter = SQSTraceExporter(
65+
queue_url=self.QUEUE_URL,
66+
sqs_client=mock_sqs_client,
67+
)
68+
69+
spans = [generate_span() for _ in range(2)]
70+
result = exporter.export(spans)
71+
assert result == SpanExportResult.SUCCESS
72+
73+
response = mock_sqs_client.receive_message(
74+
QueueUrl=self.QUEUE_URL,
75+
MaxNumberOfMessages=10,
76+
)
77+
messages = response.get("Messages", [])
78+
assert len(messages) == 2
79+
80+
def test_export_handles_sqs_client_exception(self):
81+
exporter = SQSTraceExporter(
82+
queue_url=self.QUEUE_URL,
83+
sqs_client=object(),
84+
)
85+
86+
spans = [generate_span() for _ in range(2)]
87+
result = exporter.export(spans)
88+
89+
assert result == SpanExportResult.FAILURE
90+
91+
def test_export_shutdown(self, mock_sqs_client):
92+
mock_sqs_client.close = MagicMock()
93+
exporter = SQSTraceExporter(
94+
queue_url=self.QUEUE_URL,
95+
sqs_client=mock_sqs_client,
96+
)
97+
98+
exporter.shutdown()
99+
100+
mock_sqs_client.close.assert_called_once()
101+
assert exporter._shutdown is True
102+
assert exporter._shutdown_in_progress.is_set()
103+
104+
def test_export_shutdown_successive_calls(self, mock_sqs_client):
105+
mock_sqs_client.close = MagicMock()
106+
exporter = SQSTraceExporter(
107+
queue_url=self.QUEUE_URL,
108+
sqs_client=mock_sqs_client,
109+
)
110+
111+
exporter.shutdown()
112+
exporter.shutdown()
113+
114+
mock_sqs_client.close.assert_called_once()
115+
assert exporter._shutdown is True
116+
assert exporter._shutdown_in_progress.is_set()
117+
118+
def test_export_force_flush(self, mock_sqs_client):
119+
exporter = SQSTraceExporter(
120+
queue_url=self.QUEUE_URL,
121+
sqs_client=mock_sqs_client,
122+
)
123+
124+
result = exporter.force_flush()
125+
assert result is True
126+
127+
128+
class TestSqsBatchSpanProcessor:
129+
QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789/test-queue"
130+
131+
def test_sqs_batch_span_processor_exports_in_batches(self, mock_sqs_client):
132+
exporter = SQSTraceExporter(
133+
queue_url=self.QUEUE_URL,
134+
sqs_client=mock_sqs_client,
135+
)
136+
processor = SQSBatchSpanProcessor(span_exporter=exporter)
137+
provider = TracerProvider()
138+
provider.add_span_processor(processor)
139+
trace.set_tracer_provider(provider)
140+
141+
tracer = trace.get_tracer("test-sqs-batch-span-processor")
142+
for i in range(15):
143+
with tracer.start_as_current_span(f"test-span-{i}"):
144+
...
145+
146+
response = mock_sqs_client.receive_message(
147+
QueueUrl=self.QUEUE_URL,
148+
MaxNumberOfMessages=10,
149+
WaitTimeSeconds=1,
150+
)
151+
assert len(response.get("Messages", [])) == 10
152+
153+
processor.shutdown()
154+
155+
response = mock_sqs_client.receive_message(
156+
QueueUrl=self.QUEUE_URL,
157+
MaxNumberOfMessages=10,
158+
WaitTimeSeconds=1,
159+
)
160+
assert len(response.get("Messages", [])) == 5

0 commit comments

Comments
 (0)