From 6980e73721de818993bf002306d41009053801e5 Mon Sep 17 00:00:00 2001 From: Piotr Duda Date: Tue, 10 Mar 2026 18:54:23 +0100 Subject: [PATCH 1/3] wip --- examples/cli/cli_wrapper_example.py | 8 +- tests/test_consumer.py | 2 +- tests/test_offline_replay.py | 2 +- tests/test_reservoir.py | 320 ++++++++++++++++++++++++++++ wildedge/batch.py | 8 +- wildedge/client.py | 25 ++- wildedge/constants.py | 5 + wildedge/consumer.py | 77 +++++-- wildedge/reservoir.py | 198 +++++++++++++++++ 9 files changed, 618 insertions(+), 27 deletions(-) create mode 100644 tests/test_reservoir.py create mode 100644 wildedge/reservoir.py diff --git a/examples/cli/cli_wrapper_example.py b/examples/cli/cli_wrapper_example.py index fff54bc..9b8a17f 100644 --- a/examples/cli/cli_wrapper_example.py +++ b/examples/cli/cli_wrapper_example.py @@ -10,8 +10,10 @@ model = timm.create_model("resnet18", pretrained=False).eval() batch = torch.randn(1, 3, 224, 224) +iterations = 500 -with torch.inference_mode(): - output = model(batch) +for _ in range(iterations): + with torch.inference_mode(): + output = model(batch) -print("output shape:", tuple(output.shape)) +# print("output shape:", tuple(output.shape)) diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 0084c3d..d982121 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -261,7 +261,7 @@ def set(self) -> None: stop_control = StopControl() consumer.stop_event = stop_control - consumer.drain_once = lambda: False + consumer.drain_once = lambda flush_reservoir=False: False called = {"count": 0} diff --git a/tests/test_offline_replay.py b/tests/test_offline_replay.py index c3970f7..6d5a057 100644 --- a/tests/test_offline_replay.py +++ b/tests/test_offline_replay.py @@ -54,7 +54,7 @@ def test_offline_replay_restores_model_registry_for_pending_events(tmp_path): quantization="fp32", ) client_a.publish( - {"event_id": "e1", "event_type": "inference", "model_id": "ResNet"} + {"event_id": "e1", "event_type": "model_load", "model_id": "ResNet"} ) client_a.close() diff --git a/tests/test_reservoir.py b/tests/test_reservoir.py new file mode 100644 index 0000000..e6f497a --- /dev/null +++ b/tests/test_reservoir.py @@ -0,0 +1,320 @@ +"""Tests for InferenceReservoir and ReservoirRegistry.""" + +from __future__ import annotations + +import threading + +from wildedge.reservoir import InferenceReservoir, ReservoirRegistry, ReservoirStats + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_inference_event( + model_id: str = "m1", + success: bool = True, + avg_confidence: float | None = None, + avg_token_entropy: float | None = None, +) -> dict: + output_meta: dict = {} + if avg_confidence is not None: + output_meta["avg_confidence"] = avg_confidence + if avg_token_entropy is not None: + output_meta["avg_token_entropy"] = avg_token_entropy + + inference: dict = {"success": success} + if output_meta: + inference["output_meta"] = output_meta + + return { + "event_type": "inference", + "model_id": model_id, + "inference": inference, + } + + +# --------------------------------------------------------------------------- +# Stratum A guarantee +# --------------------------------------------------------------------------- + + +def test_stratum_a_success_false_always_retained(): + r = InferenceReservoir(reservoir_size=5) + for _ in range(20): + r.add(make_inference_event(success=False)) + events, stats = r.snapshot() + assert stats.priority_seen == 20 + assert stats.priority_sent == 20 + assert stats.total_inference_events_sent == 20 + assert all("sample_rate" not in e for e in events) + + +def test_stratum_a_low_confidence_always_retained(): + r = InferenceReservoir(reservoir_size=5, low_confidence_threshold=0.5) + for _ in range(20): + r.add(make_inference_event(avg_confidence=0.1)) + events, stats = r.snapshot() + assert stats.priority_seen == 20 + assert stats.priority_sent == 20 + assert all("sample_rate" not in e for e in events) + + +def test_stratum_a_high_entropy_always_retained(): + r = InferenceReservoir(reservoir_size=5, high_entropy_threshold=2.0) + for _ in range(20): + r.add(make_inference_event(avg_token_entropy=3.5)) + events, stats = r.snapshot() + assert stats.priority_seen == 20 + assert stats.priority_sent == 20 + assert all("sample_rate" not in e for e in events) + + +def test_stratum_a_threshold_boundary(): + r = InferenceReservoir( + reservoir_size=10, + low_confidence_threshold=0.5, + high_entropy_threshold=2.0, + ) + # Exactly at threshold: not priority + r.add(make_inference_event(avg_confidence=0.5)) + r.add(make_inference_event(avg_token_entropy=2.0)) + # Just inside threshold: priority + r.add(make_inference_event(avg_confidence=0.49)) + r.add(make_inference_event(avg_token_entropy=2.01)) + + _, stats = r.snapshot() + assert stats.priority_seen == 2 + assert stats.priority_sent == 2 + + +# --------------------------------------------------------------------------- +# priority_fn override +# --------------------------------------------------------------------------- + + +def test_priority_fn_replaces_builtin_signals(): + # Built-in signals would make these priority, but priority_fn always returns False + r = InferenceReservoir( + reservoir_size=3, + priority_fn=lambda e: False, + ) + for _ in range(10): + r.add(make_inference_event(success=False)) + _, stats = r.snapshot() + assert stats.priority_seen == 0 + + +def test_priority_fn_can_promote_to_stratum_a(): + def my_fn(event: dict) -> bool: + return event.get("inference", {}).get("success", True) is True + + r = InferenceReservoir(reservoir_size=3, priority_fn=my_fn) + for _ in range(20): + r.add(make_inference_event(success=True)) + _, stats = r.snapshot() + assert stats.priority_seen == 20 + assert stats.priority_sent == 20 + + +# --------------------------------------------------------------------------- +# Warm-up: no sample_rate while seen < capacity +# --------------------------------------------------------------------------- + + +def test_warmup_no_sample_rate(): + r = InferenceReservoir(reservoir_size=10, low_confidence_slots_pct=0.2) + # background_capacity = 10 - int(10 * 0.2) = 8 + for _ in range(8): # exactly fills stratum B during warm-up + r.add(make_inference_event(avg_confidence=0.9)) + events, stats = r.snapshot() + assert all("sample_rate" not in e for e in events) + assert stats.total_inference_events_sent == 8 + + +# --------------------------------------------------------------------------- +# Post-warm-up: sample_rate set on Stratum B events only +# --------------------------------------------------------------------------- + + +def test_sample_rate_set_post_warmup(): + b_cap = 4 + r = InferenceReservoir(reservoir_size=5, low_confidence_slots_pct=0.2) + # background_capacity = 5 - int(5 * 0.2) = 4 + total_bg = 20 + for _ in range(total_bg): + r.add(make_inference_event(avg_confidence=0.9)) + events, stats = r.snapshot() + + assert stats.total_inference_events_seen == total_bg + assert stats.total_inference_events_sent == b_cap + assert stats.priority_seen == 0 + + expected_rate = b_cap / total_bg + for e in events: + assert "sample_rate" in e + assert abs(e["sample_rate"] - expected_rate) < 1e-9 + + +def test_no_sample_rate_on_stratum_a_events_post_warmup(): + r = InferenceReservoir(reservoir_size=5, low_confidence_slots_pct=0.2) + for _ in range(20): + r.add(make_inference_event(avg_confidence=0.1)) # priority + for _ in range(20): + r.add(make_inference_event(avg_confidence=0.9)) # background + events, stats = r.snapshot() + + priority_events = [e for e in events if "sample_rate" not in e] + background_events = [e for e in events if "sample_rate" in e] + + assert len(priority_events) == 20 # all priority retained + assert stats.priority_sent == 20 + assert len(background_events) <= 4 # background capacity + for e in background_events: + assert e["sample_rate"] < 1.0 + + +# --------------------------------------------------------------------------- +# snapshot() atomically resets +# --------------------------------------------------------------------------- + + +def test_snapshot_resets_counters(): + r = InferenceReservoir(reservoir_size=10) + for _ in range(5): + r.add(make_inference_event()) + r.snapshot() + + assert r.size() == 0 + _, stats = r.snapshot() + assert stats.total_inference_events_seen == 0 + assert stats.total_inference_events_sent == 0 + assert stats.priority_seen == 0 + assert stats.priority_sent == 0 + + +def test_snapshot_empty_reservoir(): + r = InferenceReservoir() + events, stats = r.snapshot() + assert events == [] + assert stats == ReservoirStats() + + +# --------------------------------------------------------------------------- +# Algorithm R statistical distribution +# --------------------------------------------------------------------------- + + +def test_algorithm_r_uniform_distribution(): + """Each slot should be selected with roughly equal probability.""" + n_slots = 4 + n_total = 400 + n_trials = 200 + slot_counts: dict[int, int] = {i: 0 for i in range(n_slots)} + + # background_capacity = 5 - 1 = 4 with low_confidence_slots_pct=0.2, size=5 + for _ in range(n_trials): + r = InferenceReservoir(reservoir_size=5, low_confidence_slots_pct=0.2) + for i in range(n_total): + ev = make_inference_event(avg_confidence=0.9) + ev["idx"] = i + r.add(ev) + events, _ = r.snapshot() + for e in events: + bucket = e["idx"] % n_slots + slot_counts[bucket] += 1 + + total = sum(slot_counts.values()) + expected = total / n_slots + for count in slot_counts.values(): + # Allow 30% deviation from expected (loose but deterministic) + assert abs(count - expected) / expected < 0.3 + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + + +def test_concurrent_add_and_snapshot(): + r = InferenceReservoir(reservoir_size=50) + errors: list[Exception] = [] + + def producer(): + try: + for _ in range(500): + r.add(make_inference_event(avg_confidence=0.9)) + except Exception as exc: + errors.append(exc) + + def snapshotter(): + try: + for _ in range(10): + r.snapshot() + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=producer) for _ in range(4)] + threads += [threading.Thread(target=snapshotter) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [] + + +# --------------------------------------------------------------------------- +# ReservoirRegistry +# --------------------------------------------------------------------------- + + +def test_registry_get_or_create_is_idempotent(): + reg = ReservoirRegistry() + r1 = reg.get_or_create("model-a") + r2 = reg.get_or_create("model-a") + assert r1 is r2 + + +def test_registry_snapshot_all_sampling_envelope(): + reg = ReservoirRegistry( + reservoir_size=10, + low_confidence_threshold=0.4, + high_entropy_threshold=1.5, + ) + for _ in range(3): + reg.get_or_create("m1").add(make_inference_event("m1", avg_confidence=0.1)) + for _ in range(2): + reg.get_or_create("m2").add(make_inference_event("m2", avg_confidence=0.9)) + + events, sampling = reg.snapshot_all() + + assert len(events) == 5 + assert sampling["priority_thresholds"]["low_confidence"] == 0.4 + assert sampling["priority_thresholds"]["high_entropy"] == 1.5 + + assert sampling["m1"]["total_inference_events_seen"] == 3 + assert sampling["m1"]["priority_seen"] == 3 + assert sampling["m1"]["priority_sent"] == 3 + + assert sampling["m2"]["total_inference_events_seen"] == 2 + assert sampling["m2"]["priority_seen"] == 0 + + +def test_registry_snapshot_all_excludes_empty_models(): + reg = ReservoirRegistry() + reg.get_or_create("m1") # never receives events + reg.get_or_create("m2").add(make_inference_event("m2")) + + _, sampling = reg.snapshot_all() + assert "m1" not in sampling + assert "m2" in sampling + + +def test_registry_has_events(): + reg = ReservoirRegistry() + assert not reg.has_events() + reg.get_or_create("m1").add(make_inference_event("m1")) + assert reg.has_events() + reg.snapshot_all() + assert not reg.has_events() diff --git a/wildedge/batch.py b/wildedge/batch.py index 9afa3df..c2ef151 100644 --- a/wildedge/batch.py +++ b/wildedge/batch.py @@ -17,9 +17,10 @@ def build_batch( events: list[dict], session_id: str, created_at: datetime, + sampling: dict | None = None, ) -> dict: """Build a protocol-compliant batch envelope.""" - return { + batch: dict = { "protocol_version": constants.PROTOCOL_VERSION, "device": device.to_dict(), "models": models, @@ -27,5 +28,8 @@ def build_batch( "batch_id": str(uuid.uuid4()), "created_at": created_at.isoformat(), "sent_at": datetime.now(timezone.utc).isoformat(), - "events": [_sanitize_event(event) for event in events], } + if sampling is not None: + batch["sampling"] = sampling + batch["events"] = [_sanitize_event(event) for event in events] + return batch diff --git a/wildedge/client.py b/wildedge/client.py index ac2dec2..8b87aae 100644 --- a/wildedge/client.py +++ b/wildedge/client.py @@ -32,6 +32,7 @@ default_pending_queue_dir, ) from wildedge.queue import EventQueue, QueuePolicy +from wildedge.reservoir import ReservoirRegistry from wildedge.settings import read_client_env, resolve_app_identity from wildedge.timing import Timer, elapsed_ms from wildedge.transmitter import Transmitter @@ -136,6 +137,11 @@ def __init__( dead_letter_dir: str | None = None, max_dead_letter_batches: int = constants.DEFAULT_MAX_DEAD_LETTER_BATCHES, on_delivery_failure: Callable[[str, int, int], None] | None = None, + reservoir_size: int = constants.DEFAULT_RESERVOIR_SIZE, + low_confidence_threshold: float = constants.DEFAULT_LOW_CONFIDENCE_THRESHOLD, + high_entropy_threshold: float = constants.DEFAULT_HIGH_ENTROPY_THRESHOLD, + low_confidence_slots_pct: float = constants.DEFAULT_LOW_CONFIDENCE_SLOTS_PCT, + priority_fn: Callable[[dict], bool] | None = None, ): env = read_client_env(dsn=dsn, debug=debug, app_identity=app_identity) dsn = env.dsn @@ -208,6 +214,13 @@ def __init__( directory=resolved_dead_letter_dir, max_batches=max_dead_letter_batches, ) + self.reservoir_registry = ReservoirRegistry( + reservoir_size=reservoir_size, + low_confidence_threshold=low_confidence_threshold, + high_entropy_threshold=high_entropy_threshold, + low_confidence_slots_pct=low_confidence_slots_pct, + priority_fn=priority_fn, + ) self.consumer = Consumer( queue=self.queue, transmitter=self.transmitter, @@ -220,6 +233,7 @@ def __init__( max_event_age_sec=max_event_age_sec, dead_letter_store=self.dead_letter_store, on_delivery_failure=on_delivery_failure, + reservoir_registry=self.reservoir_registry, ) self.auto_loaded: set[str] = set() @@ -247,9 +261,14 @@ def publish(self, event_dict: dict) -> None: event_dict.get("event_type"), event_dict.get("model_id"), ) - event_dict.setdefault("__we_first_queued_at", time.time()) - event_dict.setdefault("__we_attempts", 0) - self.queue.add(event_dict) + + if event_dict.get("event_type") == "inference": + model_id = event_dict.get("model_id", "") + self.reservoir_registry.get_or_create(model_id).add(event_dict) + else: + event_dict.setdefault("__we_first_queued_at", time.time()) + event_dict.setdefault("__we_attempts", 0) + self.queue.add(event_dict) def register_model( self, diff --git a/wildedge/constants.py b/wildedge/constants.py index 07cd45b..2fdae58 100644 --- a/wildedge/constants.py +++ b/wildedge/constants.py @@ -44,6 +44,11 @@ "WILDEDGE_AUTOLOAD" # set to "1" by `wildedge run` to activate sitecustomize ) WILDEDGE_AUTOLOAD_ACTIVE = "WILDEDGE_AUTOLOAD_ACTIVE" # guard against double-init +# Reservoir sampling +DEFAULT_RESERVOIR_SIZE = 200 +DEFAULT_LOW_CONFIDENCE_THRESHOLD = 0.5 +DEFAULT_HIGH_ENTROPY_THRESHOLD = 2.0 +DEFAULT_LOW_CONFIDENCE_SLOTS_PCT = 0.20 # Runtime validation limits BATCH_SIZE_MIN = 1 diff --git a/wildedge/consumer.py b/wildedge/consumer.py index 068a793..604ae39 100644 --- a/wildedge/consumer.py +++ b/wildedge/consumer.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from wildedge.device import DeviceInfo + from wildedge.reservoir import ReservoirRegistry class Consumer: @@ -35,6 +36,7 @@ def __init__( max_event_age_sec: float = constants.DEFAULT_MAX_EVENT_AGE_SEC, dead_letter_store: DeadLetterStore | None = None, on_delivery_failure: Callable[[str, int, int], None] | None = None, + reservoir_registry: ReservoirRegistry | None = None, ): self.queue = queue self.transmitter = transmitter @@ -47,6 +49,11 @@ def __init__( self.max_event_age_sec = max_event_age_sec self.dead_letter_store = dead_letter_store self.on_delivery_failure = on_delivery_failure + self.reservoir_registry = reservoir_registry + + # Holds a failed reservoir snapshot across retry attempts. + # Cleared on successful transmission or permanent error. + self._held_snapshot: tuple[list[dict[str, Any]], dict[str, Any]] | None = None self.stop_event = threading.Event() self.stopped = False @@ -59,6 +66,18 @@ def __init__( self.thread.start() atexit.register(self.flush, constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC) + def has_pending(self) -> bool: + """True if the FIFO queue or a held snapshot needs draining.""" + return self.queue.length() > 0 or self._held_snapshot is not None + + def has_any_pending(self) -> bool: + """True if anything is pending, including the reservoir.""" + if self.has_pending(): + return True + if self.reservoir_registry is not None: + return self.reservoir_registry.has_events() + return False + def run(self) -> None: last_flush = time.monotonic() while not self.stop_event.is_set(): @@ -66,8 +85,8 @@ def run(self) -> None: time_since_flush = now - last_flush force_flush = time_since_flush >= self.flush_interval_sec - if self.queue.length() > 0 or force_flush: - sent = self.drain_once() + if self.has_pending() or force_flush: + sent = self.drain_once(flush_reservoir=force_flush) if sent: last_flush = time.monotonic() self.backoff = constants.BACKOFF_MIN @@ -133,21 +152,37 @@ def dead_letter_and_drop( self.queue.remove_first_n(len(events)) self.notify_delivery_failure(reason, len(events)) - def drain_once(self) -> bool: - events = self.queue.peek_many(self.batch_size) - if not events: + def drain_once(self, flush_reservoir: bool = False) -> bool: + # --- Reservoir snapshot (inference events) --- + if self.reservoir_registry is not None: + if self._held_snapshot is not None: + # Always retry a held snapshot regardless of flush_reservoir + inference_events, sampling = self._held_snapshot + elif flush_reservoir: + inference_events, sampling = self.reservoir_registry.snapshot_all() + else: + inference_events, sampling = [], {} + else: + inference_events, sampling = [], {} + + # --- Non-inference FIFO --- + fifo_events = self.queue.peek_many(self.batch_size) + + if not inference_events and not fifo_events: return False + # Age-check FIFO events (inference events in reservoir are not age-checked; + # they represent the current flush window and are always transmitted) now_unix = time.time() expired_count = 0 - for event in events: + for event in fifo_events: first_seen = float(event.get("__we_first_queued_at", now_unix)) if (now_unix - first_seen) > self.max_event_age_sec: expired_count += 1 else: break if expired_count > 0: - expired = events[:expired_count] + expired = fifo_events[:expired_count] self.dead_letter_and_drop( reason="event_age_exceeded", events=expired, @@ -160,21 +195,25 @@ def drain_once(self) -> bool: ) return True - for event in events: + for event in fifo_events: event["__we_attempts"] = int(event.get("__we_attempts", 0)) + 1 + all_events = inference_events + fifo_events batch = build_batch( device=self.device, models=self.get_models(), - events=events, + events=all_events, session_id=self.session_id, created_at=self.created_at, + sampling=sampling if sampling else None, ) if self.debug: logger.debug( - "wildedge: transmitting %d events (batch_id=%s)", - len(events), + "wildedge: transmitting %d events (%d inference, %d fifo, batch_id=%s)", + len(all_events), + len(inference_events), + len(fifo_events), batch["batch_id"], ) @@ -182,10 +221,13 @@ def drain_once(self) -> bool: response = self.transmitter.send(batch) except TransmitError as exc: logger.warning("wildedge: transmit failed, will retry: %s", exc) + if self.reservoir_registry is not None and self._held_snapshot is None: + self._held_snapshot = (inference_events, sampling) return False if response.status in ("accepted", "partial"): - self.queue.remove_first_n(len(events)) + self.queue.remove_first_n(len(fifo_events)) + self._held_snapshot = None if self.debug: logger.debug( "wildedge: accepted=%d rejected=%d", @@ -199,7 +241,7 @@ def drain_once(self) -> bool: if response.status in ("rejected", "unauthorized", "error"): self.dead_letter_and_drop( reason=f"permanent_{response.status}", - events=events, + events=fifo_events, batch_id=batch["batch_id"], details={ "response_status": response.status, @@ -207,19 +249,20 @@ def drain_once(self) -> bool: "events_rejected": response.events_rejected, }, ) + self._held_snapshot = None return True return False def flush(self, timeout: float = 5.0) -> None: - """Block until the queue drains or timeout expires.""" + """Block until the queue and reservoir drain or timeout expires.""" if self.stopped: return deadline = time.monotonic() + timeout backoff = constants.BACKOFF_MIN - while self.queue.length() > 0 and time.monotonic() < deadline: - progressed = self.drain_once() - if self.queue.length() == 0: + while self.has_any_pending() and time.monotonic() < deadline: + progressed = self.drain_once(flush_reservoir=True) + if not self.has_any_pending(): break remaining = deadline - time.monotonic() if remaining <= 0: diff --git a/wildedge/reservoir.py b/wildedge/reservoir.py new file mode 100644 index 0000000..8ae2ef5 --- /dev/null +++ b/wildedge/reservoir.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import random +import threading +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from wildedge import constants + + +@dataclass +class ReservoirStats: + total_inference_events_seen: int = 0 + total_inference_events_sent: int = 0 + priority_seen: int = 0 + priority_sent: int = 0 + + +class InferenceReservoir: + """ + Two-stratum per-model inference reservoir sampler (Algorithm R). + + Stratum A — priority (always retain, sample_rate omitted): + 1. success=False + 2. output_meta.avg_confidence < low_confidence_threshold + 3. output_meta.avg_token_entropy > high_entropy_threshold + 4. priority_fn(event) returns True (replaces built-ins when set) + + Stratum B — background (Algorithm R, sample_rate set at flush time). + """ + + def __init__( + self, + reservoir_size: int = constants.DEFAULT_RESERVOIR_SIZE, + low_confidence_threshold: float = constants.DEFAULT_LOW_CONFIDENCE_THRESHOLD, + high_entropy_threshold: float = constants.DEFAULT_HIGH_ENTROPY_THRESHOLD, + low_confidence_slots_pct: float = constants.DEFAULT_LOW_CONFIDENCE_SLOTS_PCT, + priority_fn: Callable[[dict], bool] | None = None, + ) -> None: + self._reservoir_size = reservoir_size + self._low_confidence_threshold = low_confidence_threshold + self._high_entropy_threshold = high_entropy_threshold + self._background_capacity = max( + 1, reservoir_size - int(reservoir_size * low_confidence_slots_pct) + ) + self._priority_fn = priority_fn + + # Stratum A: unbounded — all priority events are always retained + self._stratum_a: list[dict] = [] + self._stratum_a_seen: int = 0 + + # Stratum B: Algorithm R reservoir + self._stratum_b: list[dict] = [] + self._stratum_b_seen: int = 0 + + self._lock = threading.Lock() + + def add(self, event: dict) -> None: + with self._lock: + if self._is_priority(event): + self._stratum_a.append(event) + self._stratum_a_seen += 1 + else: + self._stratum_b_seen += 1 + if len(self._stratum_b) < self._background_capacity: + self._stratum_b.append(event) + else: + j = random.randint(0, self._stratum_b_seen - 1) + if j < self._background_capacity: + self._stratum_b[j] = event + + def snapshot(self) -> tuple[list[dict[str, Any]], ReservoirStats]: + """Atomically drain and reset. Annotates Stratum B events with sample_rate.""" + with self._lock: + stratum_a = self._stratum_a + stratum_b = self._stratum_b + a_seen = self._stratum_a_seen + b_seen = self._stratum_b_seen + + self._stratum_a = [] + self._stratum_b = [] + self._stratum_a_seen = 0 + self._stratum_b_seen = 0 + + # Annotate Stratum B events with sample_rate (omit when 1.0 i.e. warm-up) + if b_seen > 0: + b_rate = len(stratum_b) / b_seen + if b_rate < 1.0: + for ev in stratum_b: + ev["sample_rate"] = b_rate + + stats = ReservoirStats( + total_inference_events_seen=a_seen + b_seen, + total_inference_events_sent=len(stratum_a) + len(stratum_b), + priority_seen=a_seen, + priority_sent=len(stratum_a), + ) + + return stratum_a + stratum_b, stats + + def size(self) -> int: + with self._lock: + return len(self._stratum_a) + len(self._stratum_b) + + def _is_priority(self, event: dict) -> bool: + if self._priority_fn is not None: + return self._priority_fn(event) + + inference = event.get("inference", {}) + + if not inference.get("success", True): + return True + + output_meta = inference.get("output_meta", {}) + if isinstance(output_meta, dict): + avg_confidence = output_meta.get("avg_confidence") + if ( + avg_confidence is not None + and avg_confidence < self._low_confidence_threshold + ): + return True + + avg_token_entropy = output_meta.get("avg_token_entropy") + if ( + avg_token_entropy is not None + and avg_token_entropy > self._high_entropy_threshold + ): + return True + + return False + + +class ReservoirRegistry: + """Thread-safe registry of per-model InferenceReservoirs.""" + + def __init__( + self, + reservoir_size: int = constants.DEFAULT_RESERVOIR_SIZE, + low_confidence_threshold: float = constants.DEFAULT_LOW_CONFIDENCE_THRESHOLD, + high_entropy_threshold: float = constants.DEFAULT_HIGH_ENTROPY_THRESHOLD, + low_confidence_slots_pct: float = constants.DEFAULT_LOW_CONFIDENCE_SLOTS_PCT, + priority_fn: Callable[[dict], bool] | None = None, + ) -> None: + self._reservoir_kwargs: dict[str, Any] = { + "reservoir_size": reservoir_size, + "low_confidence_threshold": low_confidence_threshold, + "high_entropy_threshold": high_entropy_threshold, + "low_confidence_slots_pct": low_confidence_slots_pct, + "priority_fn": priority_fn, + } + self._reservoirs: dict[str, InferenceReservoir] = {} + self._lock = threading.Lock() + + def get_or_create(self, model_id: str) -> InferenceReservoir: + with self._lock: + if model_id not in self._reservoirs: + self._reservoirs[model_id] = InferenceReservoir( + **self._reservoir_kwargs + ) + return self._reservoirs[model_id] + + def has_events(self) -> bool: + """Return True if any reservoir holds at least one event.""" + with self._lock: + return any(r.size() > 0 for r in self._reservoirs.values()) + + def snapshot_all(self) -> tuple[list[dict[str, Any]], dict[str, Any]]: + """ + Snapshot every reservoir atomically. Returns: + - combined list of all inference events (with sample_rate where < 1.0) + - sampling envelope dict for the batch (priority_thresholds + per-model stats) + """ + with self._lock: + model_ids = list(self._reservoirs.keys()) + + all_events: list[dict[str, Any]] = [] + sampling: dict[str, Any] = { + "priority_thresholds": { + "low_confidence": self._reservoir_kwargs["low_confidence_threshold"], + "high_entropy": self._reservoir_kwargs["high_entropy_threshold"], + } + } + + for model_id in model_ids: + reservoir = self._reservoirs[model_id] + events, stats = reservoir.snapshot() + if stats.total_inference_events_seen == 0: + continue + all_events.extend(events) + sampling[model_id] = { + "total_inference_events_seen": stats.total_inference_events_seen, + "total_inference_events_sent": stats.total_inference_events_sent, + "priority_seen": stats.priority_seen, + "priority_sent": stats.priority_sent, + } + + return all_events, sampling From 0275a897d3bb3a643ce82a4f2ed5b74943cdbd67 Mon Sep 17 00:00:00 2001 From: Piotr Duda Date: Wed, 11 Mar 2026 11:39:38 +0100 Subject: [PATCH 2/3] fixes, tests, cleanup --- examples/cli/cli_wrapper_example.py | 4 +-- examples/feedback_example.py | 2 ++ examples/gguf_example.py | 2 ++ examples/gguf_gemma_manual_example.py | 2 ++ examples/keras_example.py | 2 ++ examples/onnx_example.py | 2 ++ examples/pytorch_example.py | 2 ++ examples/timm_example.py | 2 ++ tests/test_cli.py | 46 +++++++++++++++++++++++++-- wildedge/cli.py | 2 +- wildedge/constants.py | 1 + wildedge/settings.py | 2 +- 12 files changed, 62 insertions(+), 7 deletions(-) diff --git a/examples/cli/cli_wrapper_example.py b/examples/cli/cli_wrapper_example.py index 9b8a17f..abb9cba 100644 --- a/examples/cli/cli_wrapper_example.py +++ b/examples/cli/cli_wrapper_example.py @@ -13,7 +13,7 @@ iterations = 500 for _ in range(iterations): - with torch.inference_mode(): - output = model(batch) + with torch.inference_mode(): + output = model(batch) # print("output shape:", tuple(output.shape)) diff --git a/examples/feedback_example.py b/examples/feedback_example.py index 6dd57c1..91eca35 100644 --- a/examples/feedback_example.py +++ b/examples/feedback_example.py @@ -48,3 +48,5 @@ print( f"run {i + 1}: confidence={confidence:.3f} → {handle.last_inference_id[:8]}..." ) + +client.close() diff --git a/examples/gguf_example.py b/examples/gguf_example.py index cc3543b..f8467b0 100644 --- a/examples/gguf_example.py +++ b/examples/gguf_example.py @@ -33,3 +33,5 @@ result = llm(prompt, max_tokens=128, temperature=0.7) text = result["choices"][0]["text"].strip() print(f"Q: {prompt}\nA: {text}\n") + +client.close() diff --git a/examples/gguf_gemma_manual_example.py b/examples/gguf_gemma_manual_example.py index 2418187..97699b9 100644 --- a/examples/gguf_gemma_manual_example.py +++ b/examples/gguf_gemma_manual_example.py @@ -114,3 +114,5 @@ except Exception as exc: handle.track_error(error_code="UNKNOWN", error_message=str(exc)[:200]) raise + +client.close() diff --git a/examples/keras_example.py b/examples/keras_example.py index c854c89..8068a9b 100644 --- a/examples/keras_example.py +++ b/examples/keras_example.py @@ -42,3 +42,5 @@ for _ in range(3): output = model.predict(batch, verbose=0) print("output shape:", output.shape) + +client.close() diff --git a/examples/onnx_example.py b/examples/onnx_example.py index 73dc77c..0ac23a6 100644 --- a/examples/onnx_example.py +++ b/examples/onnx_example.py @@ -25,3 +25,5 @@ for _ in range(3): outputs = session.run(None, {"pixel_values": batch}) print("output shape:", outputs[0].shape) + +client.close() diff --git a/examples/pytorch_example.py b/examples/pytorch_example.py index b28760e..32e0b7c 100644 --- a/examples/pytorch_example.py +++ b/examples/pytorch_example.py @@ -40,3 +40,5 @@ def forward(self, x): for _ in range(3): output = model(batch) print("output shape:", output.shape) + +client.close() diff --git a/examples/timm_example.py b/examples/timm_example.py index 47ba3a7..a335d49 100644 --- a/examples/timm_example.py +++ b/examples/timm_example.py @@ -33,3 +33,5 @@ for _ in range(3): output = model(batch) print("output shape:", output.shape) + +client.close() diff --git a/tests/test_cli.py b/tests/test_cli.py index 4db4ea9..8eb6008 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,7 @@ from wildedge.integrations.registry import IntegrationSpec from wildedge.runtime import bootstrap from wildedge.runtime import runner as runtime_runner +from wildedge.settings import read_runtime_env def _fake_execle(captured: dict): @@ -55,7 +56,7 @@ def test_cli_run_execs_command_with_env(monkeypatch): assert captured["env"][constants.ENV_STRICT_INTEGRATIONS] == "0" assert captured["env"][constants.ENV_PRINT_STARTUP_REPORT] == "0" assert captured["env"][constants.ENV_FLUSH_TIMEOUT] == str( - constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC + constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC ) @@ -148,7 +149,7 @@ def test_install_runtime_requires_dsn(monkeypatch): bootstrap.install_runtime() -def test_install_runtime_default_flush_timeout_is_shutdown_budget(monkeypatch): +def test_install_runtime_default_flush_timeout_is_runtime_budget(monkeypatch): class FakeWildEdge: SUPPORTED_INTEGRATIONS = {"onnx"} @@ -171,7 +172,7 @@ def close(self): # type: ignore[no-untyped-def] context = bootstrap.install_runtime() try: - assert context.flush_timeout == constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC + assert context.flush_timeout == constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC finally: context.shutdown() @@ -471,3 +472,42 @@ def test_parse_run_args_only_double_dash_raises(): """['--'] with nothing after raises ValueError.""" with pytest.raises(ValueError, match="missing command"): cli.parse_run_args(["--"]) + + +def test_cli_run_default_flush_timeout_is_nonzero(monkeypatch): + """CLI must pass a non-zero flush timeout so reservoir events are sent at shutdown.""" + captured: dict = {} + monkeypatch.setattr(cli.os, "execle", _fake_execle(captured)) + monkeypatch.setattr(cli.shutil, "which", lambda cmd: f"/usr/bin/{cmd}") + + cli.main(["run", "--", "gunicorn", "myapp.wsgi:app"]) + + flush_timeout = float(captured["env"][constants.ENV_FLUSH_TIMEOUT]) + assert flush_timeout > 0, ( + "flush timeout must be > 0 so reservoir inference events are flushed at shutdown" + ) + assert flush_timeout == constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC + + +def test_cli_run_flush_timeout_override(monkeypatch): + """--flush-timeout arg is forwarded to child process.""" + captured: dict = {} + monkeypatch.setattr(cli.os, "execle", _fake_execle(captured)) + monkeypatch.setattr(cli.shutil, "which", lambda cmd: f"/usr/bin/{cmd}") + + cli.main(["run", "--flush-timeout", "10.0", "--", "gunicorn", "myapp.wsgi:app"]) + + assert captured["env"][constants.ENV_FLUSH_TIMEOUT] == "10.0" + + +def test_read_runtime_env_default_flush_timeout_is_nonzero(): + """read_runtime_env must default to a non-zero flush timeout when env var is absent.""" + env = {constants.ENV_DSN: "https://secret@ingest.wildedge.dev/key"} + result = read_runtime_env(all_integrations=[], all_hubs=[], environ=env) + assert result.flush_timeout > 0 + assert result.flush_timeout == constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC + + +def test_runtime_flush_timeout_constant_is_nonzero(): + """Guard: DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC must stay > 0.""" + assert constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC > 0 diff --git a/wildedge/cli.py b/wildedge/cli.py index 06a7470..7cef388 100644 --- a/wildedge/cli.py +++ b/wildedge/cli.py @@ -49,7 +49,7 @@ def build_parser() -> argparse.ArgumentParser: run.add_argument( "--flush-timeout", type=float, - default=constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC, + default=constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC, help="Flush timeout (seconds) for shutdown.", ) run.add_argument( diff --git a/wildedge/constants.py b/wildedge/constants.py index 2fdae58..d495b17 100644 --- a/wildedge/constants.py +++ b/wildedge/constants.py @@ -24,6 +24,7 @@ DEFAULT_ENABLE_DEAD_LETTER_PERSISTENCE = False DEFAULT_MAX_DEAD_LETTER_BATCHES = 10 DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC = 0.0 +DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC = 5.0 # Consumer backoff BACKOFF_MIN = 1.0 diff --git a/wildedge/settings.py b/wildedge/settings.py index fc32606..de0c360 100644 --- a/wildedge/settings.py +++ b/wildedge/settings.py @@ -95,7 +95,7 @@ def read_runtime_env( flush_timeout = float( env.get( constants.ENV_FLUSH_TIMEOUT, - str(constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC), + str(constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC), ) ) return RuntimeEnv( From 0db3079e141bb316431a6ae2f9980db0fc417836 Mon Sep 17 00:00:00 2001 From: Piotr Duda Date: Wed, 11 Mar 2026 11:53:03 +0100 Subject: [PATCH 3/3] fixes, tests, cleanup --- examples/chatgpt_example.py | 4 +--- examples/cli/cli_wrapper_example.py | 2 +- tests/test_client_flows.py | 9 +++++++++ wildedge/client.py | 4 +--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/chatgpt_example.py b/examples/chatgpt_example.py index 709942d..dc00b39 100644 --- a/examples/chatgpt_example.py +++ b/examples/chatgpt_example.py @@ -100,9 +100,7 @@ # Simulate feedback: short completions get a thumbs down. feedback_type = ( - FeedbackType.THUMBS_UP - if len(completion) > 40 - else FeedbackType.THUMBS_DOWN + FeedbackType.THUMBS_UP if len(completion) > 40 else FeedbackType.THUMBS_DOWN ) handle.track_feedback(inference_id, feedback_type) diff --git a/examples/cli/cli_wrapper_example.py b/examples/cli/cli_wrapper_example.py index abb9cba..254d9f5 100644 --- a/examples/cli/cli_wrapper_example.py +++ b/examples/cli/cli_wrapper_example.py @@ -16,4 +16,4 @@ with torch.inference_mode(): output = model(batch) -# print("output shape:", tuple(output.shape)) +print("Done!") \ No newline at end of file diff --git a/tests/test_client_flows.py b/tests/test_client_flows.py index e0ab7ae..75b881d 100644 --- a/tests/test_client_flows.py +++ b/tests/test_client_flows.py @@ -18,6 +18,15 @@ def test_register_model_fallback_requires_id_when_no_extractor( client.register_model(object()) +def test_register_model_fallback_uses_model_id_as_name( + client_with_stubbed_runtime, +): + client = client_with_stubbed_runtime + with patch.object(client, "_find_extractor", return_value=None): + client.register_model(object(), model_id="openai/gpt-4o") + assert client.registry.models["openai/gpt-4o"].model_name == "openai/gpt-4o" + + def test_on_model_auto_loaded_uses_hub_records_when_downloads_missing( client_with_stubbed_runtime, dummy_handle ): diff --git a/wildedge/client.py b/wildedge/client.py index 8b87aae..852427c 100644 --- a/wildedge/client.py +++ b/wildedge/client.py @@ -306,9 +306,7 @@ def register_model( else: # No extractor matched - require explicit id model_id = overrides.pop("id", None) - model_name = overrides.pop("model_name", None) or ( - str(type(model_obj).__name__) - ) + model_name = overrides.pop("model_name", None) or model_id info = ModelInfo( model_name=model_name, model_version=overrides.pop("version", "unknown"),