Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/chatgpt_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@

# Simulate feedback: short completions get a thumbs down.
feedback_type = (
FeedbackType.THUMBS_UP
if len(completion) > 40
else FeedbackType.THUMBS_DOWN
FeedbackType.THUMBS_UP if len(completion) > 40 else FeedbackType.THUMBS_DOWN
)
handle.track_feedback(inference_id, feedback_type)

Expand Down
8 changes: 5 additions & 3 deletions examples/cli/cli_wrapper_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

model = timm.create_model("resnet18", pretrained=False).eval()
batch = torch.randn(1, 3, 224, 224)
iterations = 500

with torch.inference_mode():
output = model(batch)
for _ in range(iterations):
with torch.inference_mode():
output = model(batch)

print("output shape:", tuple(output.shape))
print("Done!")
2 changes: 2 additions & 0 deletions examples/feedback_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@
print(
f"run {i + 1}: confidence={confidence:.3f} → {handle.last_inference_id[:8]}..."
)

client.close()
2 changes: 2 additions & 0 deletions examples/gguf_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@
result = llm(prompt, max_tokens=128, temperature=0.7)
text = result["choices"][0]["text"].strip()
print(f"Q: {prompt}\nA: {text}\n")

client.close()
2 changes: 2 additions & 0 deletions examples/gguf_gemma_manual_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,5 @@
except Exception as exc:
handle.track_error(error_code="UNKNOWN", error_message=str(exc)[:200])
raise

client.close()
2 changes: 2 additions & 0 deletions examples/keras_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@
for _ in range(3):
output = model.predict(batch, verbose=0)
print("output shape:", output.shape)

client.close()
2 changes: 2 additions & 0 deletions examples/onnx_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@
for _ in range(3):
outputs = session.run(None, {"pixel_values": batch})
print("output shape:", outputs[0].shape)

client.close()
2 changes: 2 additions & 0 deletions examples/pytorch_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@ def forward(self, x):
for _ in range(3):
output = model(batch)
print("output shape:", output.shape)

client.close()
2 changes: 2 additions & 0 deletions examples/timm_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@
for _ in range(3):
output = model(batch)
print("output shape:", output.shape)

client.close()
46 changes: 43 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from wildedge.integrations.registry import IntegrationSpec
from wildedge.runtime import bootstrap
from wildedge.runtime import runner as runtime_runner
from wildedge.settings import read_runtime_env


def _fake_execle(captured: dict):
Expand Down Expand Up @@ -55,7 +56,7 @@ def test_cli_run_execs_command_with_env(monkeypatch):
assert captured["env"][constants.ENV_STRICT_INTEGRATIONS] == "0"
assert captured["env"][constants.ENV_PRINT_STARTUP_REPORT] == "0"
assert captured["env"][constants.ENV_FLUSH_TIMEOUT] == str(
constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC
constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC
)


Expand Down Expand Up @@ -148,7 +149,7 @@ def test_install_runtime_requires_dsn(monkeypatch):
bootstrap.install_runtime()


def test_install_runtime_default_flush_timeout_is_shutdown_budget(monkeypatch):
def test_install_runtime_default_flush_timeout_is_runtime_budget(monkeypatch):
class FakeWildEdge:
SUPPORTED_INTEGRATIONS = {"onnx"}

Expand All @@ -171,7 +172,7 @@ def close(self): # type: ignore[no-untyped-def]

context = bootstrap.install_runtime()
try:
assert context.flush_timeout == constants.DEFAULT_SHUTDOWN_FLUSH_TIMEOUT_SEC
assert context.flush_timeout == constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC
finally:
context.shutdown()

Expand Down Expand Up @@ -471,3 +472,42 @@ def test_parse_run_args_only_double_dash_raises():
"""['--'] with nothing after raises ValueError."""
with pytest.raises(ValueError, match="missing command"):
cli.parse_run_args(["--"])


def test_cli_run_default_flush_timeout_is_nonzero(monkeypatch):
    """Without --flush-timeout, the CLI still exports a positive flush timeout.

    The value found in the environment handed to the (faked) ``execle`` call
    must be greater than zero and equal to the runtime default constant.
    """
    exec_capture: dict = {}
    # Stub out executable lookup and the exec call so nothing is really spawned.
    monkeypatch.setattr(cli.shutil, "which", lambda cmd: f"/usr/bin/{cmd}")
    monkeypatch.setattr(cli.os, "execle", _fake_execle(exec_capture))

    cli.main(["run", "--", "gunicorn", "myapp.wsgi:app"])

    raw_timeout = exec_capture["env"][constants.ENV_FLUSH_TIMEOUT]
    timeout_sec = float(raw_timeout)
    assert timeout_sec > 0, (
        "flush timeout must be > 0 so reservoir inference events are flushed at shutdown"
    )
    assert timeout_sec == constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC


def test_cli_run_flush_timeout_override(monkeypatch):
    """An explicit --flush-timeout value shows up unchanged in the child env."""
    exec_capture: dict = {}
    # Stub out executable lookup and the exec call so nothing is really spawned.
    monkeypatch.setattr(cli.shutil, "which", lambda cmd: f"/usr/bin/{cmd}")
    monkeypatch.setattr(cli.os, "execle", _fake_execle(exec_capture))

    argv = ["run", "--flush-timeout", "10.0", "--", "gunicorn", "myapp.wsgi:app"]
    cli.main(argv)

    child_env = exec_capture["env"]
    assert child_env[constants.ENV_FLUSH_TIMEOUT] == "10.0"


def test_read_runtime_env_default_flush_timeout_is_nonzero():
    """With no flush-timeout variable set, read_runtime_env yields the positive runtime default."""
    # Only the DSN is supplied; the flush timeout variable is deliberately absent.
    dsn_only_env = {constants.ENV_DSN: "https://secret@ingest.wildedge.dev/key"}

    runtime_settings = read_runtime_env(
        all_integrations=[],
        all_hubs=[],
        environ=dsn_only_env,
    )

    assert runtime_settings.flush_timeout > 0
    assert (
        runtime_settings.flush_timeout
        == constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC
    )


def test_runtime_flush_timeout_constant_is_nonzero():
    """Guard: the runtime flush timeout default must remain strictly positive."""
    default_timeout = constants.DEFAULT_RUNTIME_FLUSH_TIMEOUT_SEC
    assert default_timeout > 0
9 changes: 9 additions & 0 deletions tests/test_client_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ def test_register_model_fallback_requires_id_when_no_extractor(
client.register_model(object())


def test_register_model_fallback_uses_model_id_as_name(
    client_with_stubbed_runtime,
):
    """When no extractor is found, the registered model's name is the given model_id."""
    client = client_with_stubbed_runtime
    # Force the fallback path by making extractor lookup return nothing.
    no_extractor = patch.object(client, "_find_extractor", return_value=None)
    with no_extractor:
        client.register_model(object(), model_id="openai/gpt-4o")

    registered = client.registry.models["openai/gpt-4o"]
    assert registered.model_name == "openai/gpt-4o"


def test_on_model_auto_loaded_uses_hub_records_when_downloads_missing(
client_with_stubbed_runtime, dummy_handle
):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def set(self) -> None:

stop_control = StopControl()
consumer.stop_event = stop_control
consumer.drain_once = lambda: False
consumer.drain_once = lambda flush_reservoir=False: False

called = {"count": 0}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_offline_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_offline_replay_restores_model_registry_for_pending_events(tmp_path):
quantization="fp32",
)
client_a.publish(
{"event_id": "e1", "event_type": "inference", "model_id": "ResNet"}
{"event_id": "e1", "event_type": "model_load", "model_id": "ResNet"}
)
client_a.close()

Expand Down
Loading
Loading