Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
- DID document service entry updated from `#atproto_appview` / `AtprotoAppView` to `#atdata_appview` / `AtdataAppView`

### Added
- Adversarial review: sendInteractions feature and surrounding code (round 3) (#39)
- Add real-time change stream subscribeChanges endpoint (#50)
- Add sendInteractions XRPC procedure for usage telemetry (#35)

- Dual-hostname DID document support — serve different `did:web` documents for `api.atdata.app` (appview identity) and `atdata.app` (atproto account identity) based on the `Host` header ([#19](https://github.com/forecast-bio/atdata-app/issues/19))
Expand Down
152 changes: 152 additions & 0 deletions src/atdata_app/changestream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""In-memory broadcast channel for real-time change events.

Provides a pub/sub mechanism that the ingestion processor publishes to
and WebSocket subscribers consume from. Maintains a bounded buffer of
recent events for cursor-based replay.
"""

from __future__ import annotations

import asyncio
import logging
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

logger = logging.getLogger(__name__)

DEFAULT_BUFFER_SIZE = 1000
DEFAULT_SUBSCRIBER_QUEUE_SIZE = 256


@dataclass
class ChangeEvent:
"""A single change event in the stream."""

seq: int
type: str # "create", "update", or "delete"
collection: str
did: str
rkey: str
timestamp: str
record: dict[str, Any] | None = None
cid: str | None = None

def to_dict(self) -> dict[str, Any]:
d: dict[str, Any] = {
"seq": self.seq,
"type": self.type,
"collection": self.collection,
"did": self.did,
"rkey": self.rkey,
"timestamp": self.timestamp,
}
if self.record is not None:
d["record"] = self.record
if self.cid is not None:
d["cid"] = self.cid
return d


@dataclass
class ChangeStream:
"""Broadcast channel with bounded replay buffer.

Thread-safe for asyncio: all mutations happen in the event loop.
"""

buffer_size: int = DEFAULT_BUFFER_SIZE
subscriber_queue_size: int = DEFAULT_SUBSCRIBER_QUEUE_SIZE
_seq: int = field(default=0, init=False)
_buffer: deque[ChangeEvent] = field(init=False)
_subscribers: dict[int, asyncio.Queue[ChangeEvent]] = field(
default_factory=dict, init=False
)
_next_sub_id: int = field(default=0, init=False)

def __post_init__(self) -> None:
self._buffer = deque(maxlen=self.buffer_size)

def publish(self, event: ChangeEvent) -> None:
"""Publish an event to all subscribers and the replay buffer.

Non-blocking. If a subscriber's queue is full, the event is dropped
for that subscriber (backpressure via disconnect is handled by the
WebSocket handler).
"""
self._seq += 1
event.seq = self._seq
self._buffer.append(event)

for sub_id, queue in list(self._subscribers.items()):
try:
queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning(
"Subscriber %d queue full, dropping event seq=%d",
sub_id,
event.seq,
)

def subscribe(self) -> tuple[int, asyncio.Queue[ChangeEvent]]:
"""Create a new subscriber. Returns (subscriber_id, queue)."""
sub_id = self._next_sub_id
self._next_sub_id += 1
queue: asyncio.Queue[ChangeEvent] = asyncio.Queue(
maxsize=self.subscriber_queue_size
)
self._subscribers[sub_id] = queue
logger.debug("Subscriber %d connected (total: %d)", sub_id, len(self._subscribers))
return sub_id, queue

def unsubscribe(self, sub_id: int) -> None:
"""Remove a subscriber."""
self._subscribers.pop(sub_id, None)
logger.debug("Subscriber %d disconnected (total: %d)", sub_id, len(self._subscribers))

def replay_from(self, cursor: int) -> list[ChangeEvent]:
"""Return buffered events with seq > cursor.

Returns an empty list if the cursor is outside the buffer window.
"""
if not self._buffer:
return []

oldest_seq = self._buffer[0].seq
if cursor < oldest_seq - 1:
# Cursor is too old — events between cursor and buffer start were lost
return []

return [ev for ev in self._buffer if ev.seq > cursor]

@property
def current_seq(self) -> int:
return self._seq

@property
def subscriber_count(self) -> int:
return len(self._subscribers)


def make_change_event(
*,
event_type: str,
collection: str,
did: str,
rkey: str,
record: dict[str, Any] | None = None,
cid: str | None = None,
) -> ChangeEvent:
"""Factory for creating change events with current timestamp."""
from datetime import datetime, timezone

return ChangeEvent(
seq=0, # Assigned by ChangeStream.publish()
type=event_type,
collection=collection,
did=did,
rkey=rkey,
timestamp=datetime.now(timezone.utc).isoformat(),
record=record,
cid=cid,
)
4 changes: 3 additions & 1 deletion src/atdata_app/ingestion/jetstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ async def jetstream_consumer(app: FastAPI) -> None:
if event.get("kind") != "commit":
continue

await process_commit(pool, event)
await process_commit(
pool, event, getattr(app.state, "change_stream", None)
)

last_time_us = event.get("time_us")
msg_count += 1
Expand Down
27 changes: 26 additions & 1 deletion src/atdata_app/ingestion/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
import asyncpg

from atdata_app import database as db
from atdata_app.changestream import ChangeStream, make_change_event

logger = logging.getLogger(__name__)


async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None:
async def process_commit(
pool: asyncpg.Pool,
event: dict[str, Any],
change_stream: ChangeStream | None = None,
) -> None:
"""Process a Jetstream commit event.

Expected event format::
Expand Down Expand Up @@ -44,11 +49,31 @@ async def process_commit(pool: asyncpg.Pool, event: dict[str, Any]) -> None:
if operation == "delete":
await db.delete_record(pool, table, did, rkey)
logger.debug("Deleted %s %s/%s", collection, did, rkey)
if change_stream is not None:
change_stream.publish(
make_change_event(
event_type="delete",
collection=collection,
did=did,
rkey=rkey,
)
)
elif operation in ("create", "update"):
record = commit["record"]
cid = commit.get("cid")
try:
await db.UPSERT_FNS[table](pool, did, rkey, cid, record)
logger.debug("Upserted %s %s/%s", collection, did, rkey)
if change_stream is not None:
change_stream.publish(
make_change_event(
event_type=operation,
collection=collection,
did=did,
rkey=rkey,
record=record,
cid=cid,
)
)
except Exception:
logger.exception("Failed to upsert %s %s/%s", collection, did, rkey)
4 changes: 4 additions & 0 deletions src/atdata_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

from atdata_app.changestream import ChangeStream
from atdata_app.config import AppConfig
from atdata_app.database import create_pool, run_migrations
from atdata_app.frontend import router as frontend_router
Expand All @@ -30,6 +31,9 @@ async def lifespan(app: FastAPI):
config: AppConfig = app.state.config
logger.info("Starting atdata-app (DID: %s)", config.service_did)

# Change stream (must be created before background tasks)
app.state.change_stream = ChangeStream()

# Database
pool = await create_pool(config.database_url)
app.state.db_pool = pool
Expand Down
4 changes: 3 additions & 1 deletion src/atdata_app/xrpc/router.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Combined XRPC router mounting all query and procedure endpoints."""
"""Combined XRPC router mounting all query, procedure, and subscription endpoints."""

from fastapi import APIRouter

from atdata_app.xrpc.procedures import router as procedures_router
from atdata_app.xrpc.queries import router as queries_router
from atdata_app.xrpc.subscriptions import router as subscriptions_router

router = APIRouter(prefix="/xrpc")
router.include_router(queries_router)
router.include_router(procedures_router)
router.include_router(subscriptions_router)
53 changes: 53 additions & 0 deletions src/atdata_app/xrpc/subscriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""WebSocket subscription endpoints for real-time change streaming."""

from __future__ import annotations

import json
import logging

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from atdata_app.changestream import ChangeStream

logger = logging.getLogger(__name__)

router = APIRouter()


@router.websocket("/science.alt.dataset.subscribeChanges")
async def subscribe_changes(websocket: WebSocket) -> None:
"""Stream real-time change events over WebSocket.

Query parameters:
cursor: Optional sequence number to replay from.
"""
change_stream: ChangeStream = websocket.app.state.change_stream

await websocket.accept()

cursor_param = websocket.query_params.get("cursor")
sub_id, queue = change_stream.subscribe()

try:
# Replay buffered events if cursor provided
if cursor_param is not None:
try:
cursor = int(cursor_param)
except (ValueError, TypeError):
await websocket.close(code=1008, reason="Invalid cursor value")
return
missed = change_stream.replay_from(cursor)
for event in missed:
await websocket.send_text(json.dumps(event.to_dict()))

# Stream live events
while True:
event = await queue.get()
await websocket.send_text(json.dumps(event.to_dict()))

except WebSocketDisconnect:
logger.debug("Subscriber %d disconnected", sub_id)
except Exception:
logger.exception("Error in subscriber %d", sub_id)
finally:
change_stream.unsubscribe(sub_id)
Loading