Skip to content

Commit ca3c651

Browse files
committed
wip
1 parent caa4cb1 commit ca3c651

9 files changed

Lines changed: 618 additions & 27 deletions

File tree

examples/cli/cli_wrapper_example.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
model = timm.create_model("resnet18", pretrained=False).eval()
1212
batch = torch.randn(1, 3, 224, 224)
13+
iterations = 500
1314

14-
with torch.inference_mode():
15-
output = model(batch)
15+
for _ in range(iterations):
16+
with torch.inference_mode():
17+
output = model(batch)
1618

17-
print("output shape:", tuple(output.shape))
19+
# print("output shape:", tuple(output.shape))

tests/test_consumer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def set(self) -> None:
261261

262262
stop_control = StopControl()
263263
consumer.stop_event = stop_control
264-
consumer.drain_once = lambda: False
264+
consumer.drain_once = lambda flush_reservoir=False: False
265265

266266
called = {"count": 0}
267267

tests/test_offline_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_offline_replay_restores_model_registry_for_pending_events(tmp_path):
5454
quantization="fp32",
5555
)
5656
client_a.publish(
57-
{"event_id": "e1", "event_type": "inference", "model_id": "ResNet"}
57+
{"event_id": "e1", "event_type": "model_load", "model_id": "ResNet"}
5858
)
5959
client_a.close()
6060

tests/test_reservoir.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
"""Tests for InferenceReservoir and ReservoirRegistry."""
2+
3+
from __future__ import annotations
4+
5+
import threading
6+
7+
from wildedge.reservoir import InferenceReservoir, ReservoirRegistry, ReservoirStats
8+
9+
# ---------------------------------------------------------------------------
10+
# Helpers
11+
# ---------------------------------------------------------------------------
12+
13+
14+
def make_inference_event(
    model_id: str = "m1",
    success: bool = True,
    avg_confidence: float | None = None,
    avg_token_entropy: float | None = None,
) -> dict:
    """Build a minimal inference event dict for feeding into a reservoir.

    ``output_meta`` is only attached when at least one of the optional
    signal values is provided, mirroring real event payloads.
    """
    meta: dict = {}
    if avg_confidence is not None:
        meta["avg_confidence"] = avg_confidence
    if avg_token_entropy is not None:
        meta["avg_token_entropy"] = avg_token_entropy

    payload: dict = {"success": success}
    if meta:
        payload["output_meta"] = meta

    return {
        "event_type": "inference",
        "model_id": model_id,
        "inference": payload,
    }
35+
36+
37+
# ---------------------------------------------------------------------------
38+
# Stratum A guarantee
39+
# ---------------------------------------------------------------------------
40+
41+
42+
def test_stratum_a_success_false_always_retained():
    """Failed inferences (success=False) are priority and are never dropped."""
    reservoir = InferenceReservoir(reservoir_size=5)
    for _ in range(20):
        reservoir.add(make_inference_event(success=False))
    events, stats = reservoir.snapshot()
    assert stats.priority_seen == 20
    assert stats.priority_sent == 20
    assert stats.total_inference_events_sent == 20
    assert not any("sample_rate" in event for event in events)
51+
52+
53+
def test_stratum_a_low_confidence_always_retained():
    """Events below the confidence threshold bypass reservoir sampling."""
    reservoir = InferenceReservoir(reservoir_size=5, low_confidence_threshold=0.5)
    for _ in range(20):
        reservoir.add(make_inference_event(avg_confidence=0.1))
    events, stats = reservoir.snapshot()
    assert (stats.priority_seen, stats.priority_sent) == (20, 20)
    assert not any("sample_rate" in event for event in events)
61+
62+
63+
def test_stratum_a_high_entropy_always_retained():
    """Events above the entropy threshold bypass reservoir sampling."""
    reservoir = InferenceReservoir(reservoir_size=5, high_entropy_threshold=2.0)
    for _ in range(20):
        reservoir.add(make_inference_event(avg_token_entropy=3.5))
    events, stats = reservoir.snapshot()
    assert (stats.priority_seen, stats.priority_sent) == (20, 20)
    assert not any("sample_rate" in event for event in events)
71+
72+
73+
def test_stratum_a_threshold_boundary():
    """Thresholds are strict: values exactly at the cutoff count as background."""
    reservoir = InferenceReservoir(
        reservoir_size=10,
        low_confidence_threshold=0.5,
        high_entropy_threshold=2.0,
    )
    # Exactly at the threshold: not priority.
    reservoir.add(make_inference_event(avg_confidence=0.5))
    reservoir.add(make_inference_event(avg_token_entropy=2.0))
    # Just past the threshold: priority.
    reservoir.add(make_inference_event(avg_confidence=0.49))
    reservoir.add(make_inference_event(avg_token_entropy=2.01))

    _, stats = reservoir.snapshot()
    assert stats.priority_seen == 2
    assert stats.priority_sent == 2
89+
90+
91+
# ---------------------------------------------------------------------------
92+
# priority_fn override
93+
# ---------------------------------------------------------------------------
94+
95+
96+
def test_priority_fn_replaces_builtin_signals():
    """A custom priority_fn fully replaces the built-in priority signals."""
    # success=False would normally be priority, but priority_fn always says no.
    reservoir = InferenceReservoir(
        reservoir_size=3,
        priority_fn=lambda event: False,
    )
    for _ in range(10):
        reservoir.add(make_inference_event(success=False))
    _, stats = reservoir.snapshot()
    assert stats.priority_seen == 0
106+
107+
108+
def test_priority_fn_can_promote_to_stratum_a():
    """priority_fn can mark otherwise-background events as priority."""

    def successful(event: dict) -> bool:
        # Treat successful inferences as priority (the inverse of the default).
        return event.get("inference", {}).get("success", True) is True

    reservoir = InferenceReservoir(reservoir_size=3, priority_fn=successful)
    for _ in range(20):
        reservoir.add(make_inference_event(success=True))
    _, stats = reservoir.snapshot()
    assert stats.priority_seen == 20
    assert stats.priority_sent == 20
118+
119+
120+
# ---------------------------------------------------------------------------
121+
# Warm-up: no sample_rate while seen < capacity
122+
# ---------------------------------------------------------------------------
123+
124+
125+
def test_warmup_no_sample_rate():
    """During warm-up (seen <= capacity) no event carries a sample_rate."""
    reservoir = InferenceReservoir(reservoir_size=10, low_confidence_slots_pct=0.2)
    # background_capacity = 10 - int(10 * 0.2) = 8; fill it exactly.
    for _ in range(8):
        reservoir.add(make_inference_event(avg_confidence=0.9))
    events, stats = reservoir.snapshot()
    assert not any("sample_rate" in event for event in events)
    assert stats.total_inference_events_sent == 8
133+
134+
135+
# ---------------------------------------------------------------------------
136+
# Post-warm-up: sample_rate set on Stratum B events only
137+
# ---------------------------------------------------------------------------
138+
139+
140+
def test_sample_rate_set_post_warmup():
    """Post-warm-up, background events carry sample_rate = sent / seen."""
    background_capacity = 4  # reservoir_size 5 - int(5 * 0.2)
    reservoir = InferenceReservoir(reservoir_size=5, low_confidence_slots_pct=0.2)
    total_background = 20
    for _ in range(total_background):
        reservoir.add(make_inference_event(avg_confidence=0.9))
    events, stats = reservoir.snapshot()

    assert stats.total_inference_events_seen == total_background
    assert stats.total_inference_events_sent == background_capacity
    assert stats.priority_seen == 0

    expected_rate = background_capacity / total_background
    for event in events:
        assert "sample_rate" in event
        assert abs(event["sample_rate"] - expected_rate) < 1e-9
157+
158+
159+
def test_no_sample_rate_on_stratum_a_events_post_warmup():
    """Only background (Stratum B) events are annotated with sample_rate."""
    reservoir = InferenceReservoir(reservoir_size=5, low_confidence_slots_pct=0.2)
    for _ in range(20):
        reservoir.add(make_inference_event(avg_confidence=0.1))  # priority
    for _ in range(20):
        reservoir.add(make_inference_event(avg_confidence=0.9))  # background
    events, stats = reservoir.snapshot()

    priority = [event for event in events if "sample_rate" not in event]
    background = [event for event in events if "sample_rate" in event]

    assert len(priority) == 20  # every priority event retained
    assert stats.priority_sent == 20
    assert len(background) <= 4  # bounded by background capacity
    for event in background:
        assert event["sample_rate"] < 1.0
175+
176+
177+
# ---------------------------------------------------------------------------
178+
# snapshot() atomically resets
179+
# ---------------------------------------------------------------------------
180+
181+
182+
def test_snapshot_resets_counters():
    """snapshot() drains the reservoir and zeroes all statistics."""
    reservoir = InferenceReservoir(reservoir_size=10)
    for _ in range(5):
        reservoir.add(make_inference_event())
    reservoir.snapshot()

    assert reservoir.size() == 0
    _, stats = reservoir.snapshot()
    assert stats.total_inference_events_seen == 0
    assert stats.total_inference_events_sent == 0
    assert stats.priority_seen == 0
    assert stats.priority_sent == 0
194+
195+
196+
def test_snapshot_empty_reservoir():
    """An untouched reservoir yields no events and default stats."""
    reservoir = InferenceReservoir()
    events, stats = reservoir.snapshot()
    assert events == []
    assert stats == ReservoirStats()
201+
202+
203+
# ---------------------------------------------------------------------------
204+
# Algorithm R statistical distribution
205+
# ---------------------------------------------------------------------------
206+
207+
208+
def test_algorithm_r_uniform_distribution():
    """Each slot should be selected with roughly equal probability."""
    n_slots = 4
    n_total = 400
    n_trials = 200
    slot_counts: dict[int, int] = dict.fromkeys(range(n_slots), 0)

    # background_capacity = 5 - 1 = 4 with low_confidence_slots_pct=0.2, size=5
    for _ in range(n_trials):
        reservoir = InferenceReservoir(reservoir_size=5, low_confidence_slots_pct=0.2)
        for idx in range(n_total):
            event = make_inference_event(avg_confidence=0.9)
            event["idx"] = idx
            reservoir.add(event)
        events, _ = reservoir.snapshot()
        for event in events:
            slot_counts[event["idx"] % n_slots] += 1

    total = sum(slot_counts.values())
    expected = total / n_slots
    for count in slot_counts.values():
        # Allow 30% deviation from expected (loose but deterministic)
        assert abs(count - expected) / expected < 0.3
232+
233+
234+
# ---------------------------------------------------------------------------
235+
# Thread safety
236+
# ---------------------------------------------------------------------------
237+
238+
239+
def test_concurrent_add_and_snapshot():
    """Concurrent producers and snapshotters must not raise."""
    reservoir = InferenceReservoir(reservoir_size=50)
    errors: list[Exception] = []

    def run(times: int, action) -> None:
        # Invoke *action* repeatedly, recording any exception for the assert.
        try:
            for _ in range(times):
                action()
        except Exception as exc:
            errors.append(exc)

    def add_background() -> None:
        reservoir.add(make_inference_event(avg_confidence=0.9))

    workers = [
        threading.Thread(target=run, args=(500, add_background)) for _ in range(4)
    ]
    workers += [
        threading.Thread(target=run, args=(10, reservoir.snapshot)) for _ in range(2)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert errors == []
265+
266+
267+
# ---------------------------------------------------------------------------
268+
# ReservoirRegistry
269+
# ---------------------------------------------------------------------------
270+
271+
272+
def test_registry_get_or_create_is_idempotent():
    """Repeated get_or_create for one model id returns the same reservoir."""
    registry = ReservoirRegistry()
    first = registry.get_or_create("model-a")
    second = registry.get_or_create("model-a")
    assert first is second
277+
278+
279+
def test_registry_snapshot_all_sampling_envelope():
    """snapshot_all() aggregates events and per-model sampling statistics."""
    registry = ReservoirRegistry(
        reservoir_size=10,
        low_confidence_threshold=0.4,
        high_entropy_threshold=1.5,
    )
    for _ in range(3):
        registry.get_or_create("m1").add(make_inference_event("m1", avg_confidence=0.1))
    for _ in range(2):
        registry.get_or_create("m2").add(make_inference_event("m2", avg_confidence=0.9))

    events, sampling = registry.snapshot_all()

    assert len(events) == 5
    thresholds = sampling["priority_thresholds"]
    assert thresholds["low_confidence"] == 0.4
    assert thresholds["high_entropy"] == 1.5

    assert sampling["m1"]["total_inference_events_seen"] == 3
    assert sampling["m1"]["priority_seen"] == 3
    assert sampling["m1"]["priority_sent"] == 3

    assert sampling["m2"]["total_inference_events_seen"] == 2
    assert sampling["m2"]["priority_seen"] == 0
302+
303+
304+
def test_registry_snapshot_all_excludes_empty_models():
    """Models that never received events are omitted from the sampling map."""
    registry = ReservoirRegistry()
    registry.get_or_create("m1")  # never receives events
    registry.get_or_create("m2").add(make_inference_event("m2"))

    _, sampling = registry.snapshot_all()
    assert "m1" not in sampling
    assert "m2" in sampling
312+
313+
314+
def test_registry_has_events():
    """has_events() reflects pending events and resets after snapshot_all()."""
    registry = ReservoirRegistry()
    assert not registry.has_events()
    registry.get_or_create("m1").add(make_inference_event("m1"))
    assert registry.has_events()
    registry.snapshot_all()
    assert not registry.has_events()

wildedge/batch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,19 @@ def build_batch(
1717
events: list[dict],
1818
session_id: str,
1919
created_at: datetime,
20+
sampling: dict | None = None,
2021
) -> dict:
2122
"""Build a protocol-compliant batch envelope."""
22-
return {
23+
batch: dict = {
2324
"protocol_version": constants.PROTOCOL_VERSION,
2425
"device": device.to_dict(),
2526
"models": models,
2627
"session_id": session_id,
2728
"batch_id": str(uuid.uuid4()),
2829
"created_at": created_at.isoformat(),
2930
"sent_at": datetime.now(timezone.utc).isoformat(),
30-
"events": [_sanitize_event(event) for event in events],
3131
}
32+
if sampling is not None:
33+
batch["sampling"] = sampling
34+
batch["events"] = [_sanitize_event(event) for event in events]
35+
return batch

0 commit comments

Comments
 (0)