diff --git a/.gitignore b/.gitignore index 64d49ae..468c518 100644 --- a/.gitignore +++ b/.gitignore @@ -213,4 +213,5 @@ marimo/_lsp/ __marimo__/ # Streamlit -.streamlit/secrets.toml \ No newline at end of file +.streamlit/secrets.toml +internal/ diff --git a/README.md b/README.md index cbe9c81..1777cd2 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,6 @@ client = wildedge.init( If no DSN is configured, the client becomes a no-op and logs a warning. `init(...)` is a convenience wrapper for `WildEdge(...)` + `instrument(...)`. - ## Supported integrations **On-device** @@ -105,6 +104,15 @@ For unsupported frameworks, see [Manual tracking](https://github.com/wild-edge/w For advanced options (batching, queue tuning, dead-letter storage), see [Configuration](https://github.com/wild-edge/wildedge-python/blob/main/docs/configuration.md). +## Projects using this SDK + +| Name | Link | +|---|---| +| agntr | [github.com/pmaciolek/agntr](https://github.com/pmaciolek/agntr) | +| *(your project here)* | - | + +Using WildEdge in your project? Open a PR to add it to the list. + ## Privacy Report security & privacy issues to: wildedge@googlegroups.com. diff --git a/docs/manual-tracking.md b/docs/manual-tracking.md index f90f58a..0fa227e 100644 --- a/docs/manual-tracking.md +++ b/docs/manual-tracking.md @@ -215,3 +215,58 @@ handle.feedback(FeedbackType.THUMBS_DOWN) ``` `FeedbackType` values: `THUMBS_UP`, `THUMBS_DOWN`. + +## Track spans for agentic workflows + +Use span events to track non-inference steps like planning, tool calls, retrieval, or memory updates. 
+ +```python +from wildedge.timing import Timer + +with Timer() as t: + tool_result = call_tool() + +client.track_span( + kind="tool", + name="call_tool", + duration_ms=t.elapsed_ms, + status="ok", + attributes={"tool": "search"}, +) +``` + +You can also attach optional correlation fields (`trace_id`, `span_id`, +`parent_span_id`, `run_id`, `agent_id`, `step_index`, `conversation_id`) to any +event by passing them into `track_inference`, `track_error`, `track_feedback`, +or `track_span`. Use `context=` for correlation attributes shared across events. + +### Trace context helpers + +Use `client.trace()` and `client.span()` to auto-populate correlation fields for +all events emitted inside the block. `client.span()` times the block and emits a +span event on exit: + +```python +import wildedge +from wildedge.timing import Timer + +client = wildedge.init() +handle = client.register_model(my_model, model_id="my-org/my-model") + +with client.trace(run_id="run-123", agent_id="agent-1"): + with client.span(kind="agent_step", name="plan", step_index=1): + with Timer() as t: + result = my_model(prompt) + handle.track_inference(duration_ms=t.elapsed_ms, input_modality="text", output_modality="generation") +``` + +If you need to set correlation fields without emitting a span event, use the +lower-level `span_context()` directly: + +```python +with client.trace(run_id="run-123", agent_id="agent-1"): + with wildedge.span_context(step_index=1): + with Timer() as t: + result = my_model(prompt) + handle.track_inference(duration_ms=t.elapsed_ms, input_modality="text", output_modality="generation") +``` diff --git a/examples/agentic_example.py b/examples/agentic_example.py new file mode 100644 index 0000000..446c812 --- /dev/null +++ b/examples/agentic_example.py @@ -0,0 +1,181 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = ["wildedge-sdk", "openai"] +# +# [tool.uv.sources] +# wildedge-sdk = { path = "..", editable = true } +# /// +"""Agentic workflow example with 
def _eval_arithmetic(expression: str) -> float:
    """Evaluate a plain arithmetic expression without calling ``eval``.

    Only numeric literals, parentheses, the binary operators
    ``+ - * / // % **`` and unary ``+``/``-`` are accepted. Anything else
    (names, attribute access, calls, subscripts, strings, ...) raises
    ``ValueError``.

    Args:
        expression: Source text of the expression, e.g. ``"42 * 18"``.

    Returns:
        The numeric result (``int`` or ``float``).

    Raises:
        SyntaxError: If ``expression`` is not valid Python expression syntax.
        ValueError: If the expression uses an unsupported construct or an
            exponent larger than the safety limit.
        ArithmeticError: For arithmetic failures such as division by zero.
    """
    # Function-scope imports keep this example block self-contained.
    import ast
    import operator

    # Cap on ``**`` so a single tool call cannot allocate an enormous integer.
    max_exponent = 1024

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is an int subclass; treat it as non-numeric here.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(f"unsupported literal: {value!r}")
            return value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError(f"unsupported expression: {type(node).__name__}")

    return evaluate(ast.parse(expression, mode="eval"))


def calculator(expression: str) -> str:
    """Evaluate a simple arithmetic expression and return the result as JSON.

    SECURITY: ``expression`` comes from model-generated tool arguments, so it
    is untrusted. The previous implementation passed it to ``eval`` with an
    empty ``__builtins__``, which is not a sandbox (attribute access on
    literals still works). It is now evaluated by a restricted AST walker.

    Args:
        expression: Arithmetic expression text supplied by the model.

    Returns:
        A JSON string: ``{"expression": ..., "result": ...}`` on success, or
        ``{"error": ...}`` when evaluation (or serialising the result) fails.
    """
    # ~60ms to simulate a remote computation call.
    time.sleep(0.06)
    try:
        result = _eval_arithmetic(expression)
        return json.dumps({"expression": expression, "result": result})
    except Exception as e:
        # Deliberately broad: tool failures are reported back to the model as
        # a tool result instead of aborting the agent loop.
        return json.dumps({"error": str(e)})
+ context = retrieve_context(task) + messages.append({"role": "user", "content": f"{task}\n\nContext: {context}"}) + + while True: + with we.span( + kind="agent_step", + name="reason", + step_index=step_index, + input_summary=task[:200], + ) as span: + response = openai_client.chat.completions.create( + model="qwen/qwen3.5-flash-02-23", + messages=messages, + tools=TOOLS, + tool_choice="auto", + max_tokens=512, + ) + choice = response.choices[0] + span.output_summary = choice.finish_reason + + messages.append(choice.message.model_dump(exclude_none=True)) + + if choice.finish_reason == "tool_calls": + step_index += 1 + for tool_call in choice.message.tool_calls: + arguments = json.loads(tool_call.function.arguments) + result = call_tool(tool_call.function.name, arguments) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + # Not instrumented: context window update between tool calls (~80ms). + # Shows up as a gap stripe in the trace view. + time.sleep(0.08) + else: + return choice.message.content or "" + + +# --- Main -------------------------------------------------------------------- + +TASKS = [ + "What's the weather like in Tokyo, and what is 42 * 18?", + "Is it warmer in Paris or Berlin right now?", +] + +system_prompt = "You are a helpful assistant. Use tools when needed." 
+messages = [{"role": "system", "content": system_prompt}] + +with we.trace(agent_id="demo-agent", run_id=str(uuid.uuid4())): + for i, task in enumerate(TASKS, start=1): + print(f"\nTask {i}: {task}") + reply = run_agent(task, step_index=i, messages=messages) + print(f"Reply: {reply}") + +we.flush() diff --git a/examples/gguf_example.py b/examples/gguf_example.py index cf58f88..bd50f98 100644 --- a/examples/gguf_example.py +++ b/examples/gguf_example.py @@ -31,6 +31,9 @@ ] for prompt in prompts: - result = llm(prompt, max_tokens=128, temperature=0.7) - text = result["choices"][0]["text"].strip() - print(f"Q: {prompt}\nA: {text}\n") + stream = llm(prompt, max_tokens=128, temperature=0.7, stream=True) + print(f"Q: {prompt}\nA: ", end="", flush=True) + for chunk in stream: + token = chunk["choices"][0].get("text", "") + print(token, end="", flush=True) + print("\n") diff --git a/examples/openai_example.py b/examples/openai_example.py index d9410e4..ad2d3a9 100644 --- a/examples/openai_example.py +++ b/examples/openai_example.py @@ -32,13 +32,19 @@ ] for prompt in prompts: - response = openai_client.chat.completions.create( + stream = openai_client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": prompt}], temperature=0.7, max_tokens=256, + stream=True, + stream_options={"include_usage": True}, ) - print(f"Q: {prompt}\nA: {response.choices[0].message.content}\n") + print(f"Q: {prompt}\nA: ", end="", flush=True) + for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="", flush=True) + print("\n") client.flush() print("Done. 
def test_span_event_to_dict_includes_required_fields():
    """A serialized span reports its event type and nests kind/attributes under ``span``."""
    payload = SpanEvent(
        kind="tool",
        name="search",
        duration_ms=250,
        status="ok",
        attributes={"provider": "custom"},
    ).to_dict()

    assert payload["event_type"] == "span"
    span_section = payload["span"]
    assert span_section["kind"] == "tool"
    assert span_section["attributes"]["provider"] == "custom"
class FakeAsyncIterator:
    """Async iterator over an in-memory iterable, used to fake a streamed response.

    Args:
        items: Any iterable; its elements are yielded in order by ``async for``.
    """

    def __init__(self, items):
        self._iter = iter(items)

    def __aiter__(self):
        """Return self; the object is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next item, or raise ``StopAsyncIteration`` when exhausted."""
        try:
            return next(self._iter)
        except StopIteration:
            # Translate the sync end-of-iteration signal into the async one.
            # ``from None`` keeps the internal StopIteration out of the
            # exception chain, since it is an implementation detail.
            raise StopAsyncIteration from None
def test_streaming_returns_sync_stream_wrapper(self):
    """A ``stream=True`` call on wrapped sync completions returns a SyncStreamWrapper."""
    fake_client = make_fake_client()
    completions = FakeStreamingCompletions(
        [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")]
    )
    wrap_sync_completions(completions, "openai", lambda: fake_client)

    stream = completions.create(model="gpt-4o", messages=[], stream=True)

    assert isinstance(stream, SyncStreamWrapper)
"openai", lambda: client) + list(completions.create(model="gpt-4o", messages=[], stream=True)) + out = client.handles["gpt-4o"].track_inference.call_args.kwargs["output_meta"] + assert out.tokens_in == 8 + assert out.tokens_out == 15 + + def test_streaming_error_during_iteration_tracks_error(self): + def bad_iter(): + yield make_stream_chunk("hi", None) + raise RuntimeError("stream error") + + class ErrorStreamCompletions: + def create(self, *args, **kwargs): + return bad_iter() + + client = make_fake_client() + completions = ErrorStreamCompletions() + wrap_sync_completions(completions, "openai", lambda: client) + stream = completions.create(model="gpt-4o", messages=[], stream=True) + with pytest.raises(RuntimeError, match="stream error"): + list(stream) + client.handles["gpt-4o"].track_error.assert_called_once() + client.handles["gpt-4o"].track_inference.assert_not_called() def test_closed_client_passes_through(self): completions, client = self.setup(closed=True) @@ -438,11 +543,37 @@ async def create(self, *args, **kwargs): client.handles["gpt-4o"].track_error.assert_called_once() - async def test_streaming_skips_tracking(self): - completions, client = self.setup() - await completions.create(model="qwen/qwen3-235b", messages=[], stream=True) - if "qwen/qwen3-235b" in client.handles: - client.handles["qwen/qwen3-235b"].track_inference.assert_not_called() + async def test_streaming_returns_async_stream_wrapper(self): + chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")] + completions = FakeAsyncStreamingCompletions(chunks) + client = make_fake_client() + wrap_async_completions(completions, "openrouter", lambda: client) + result = await completions.create( + model="qwen/qwen3-235b", messages=[], stream=True + ) + assert isinstance(result, AsyncStreamWrapper) + + async def test_streaming_records_inference_on_exhaustion(self): + chunks = [ + make_stream_chunk("Hello", None), + make_stream_chunk(" world", "stop"), + ] + completions = 
def test_track_inference_uses_trace_context():
    """Inference events emitted inside trace/span contexts inherit their correlation fields."""
    events: list[dict] = []

    def publish(event: dict) -> None:
        events.append(event)

    info = ModelInfo(
        model_name="test",
        model_version="1.0",
        model_source="local",
        model_format="onnx",
    )
    handle = ModelHandle("model-1", info, publish)

    trace_kwargs = dict(
        trace_id="trace-123",
        run_id="run-1",
        agent_id="agent-1",
        attributes={"trace_key": "trace_val"},
    )
    with trace_context(**trace_kwargs):
        with span_context(span_id="span-abc", step_index=2, attributes={"span_key": 2}):
            handle.track_inference(duration_ms=5)

    emitted = events[0]
    assert emitted["trace_id"] == "trace-123"
    # The enclosing span becomes the parent of the inference event.
    assert emitted["parent_span_id"] == "span-abc"
    assert emitted["run_id"] == "run-1"
    assert emitted["agent_id"] == "agent-1"
    assert emitted["step_index"] == 2
    assert emitted["attributes"] == {"trace_key": "trace_val", "span_key": 2}
def test_nested_spans_correct_parent_chain():
    """Nested spans must each point to their direct parent, not themselves."""
    recorded: list[dict] = []
    fake = _FakeClient(recorded)

    # Single ``with`` statement; equivalent to nesting outer around inner.
    with SpanContextManager(fake, kind="agent_step", name="outer") as outer, \
            SpanContextManager(fake, kind="tool", name="inner") as inner:
        pass

    assert len(recorded) == 2
    # The inner span exits (and is emitted) first.
    first, second = recorded
    assert first["span_id"] == inner.span_id
    assert first["parent_span_id"] == outer.span_id
    assert second["span_id"] == outer.span_id
    assert second["parent_span_id"] is None
class SpanContextManager:
    """Time a block of work and emit it as a span event when the block exits.

    On entry a ``SpanContext`` describing this span is installed as the active
    span context; on exit the previous context is restored and
    ``client.track_span(...)`` is called with the measured duration.

    All constructor arguments are stored as plain attributes and the emitted
    values are read in ``__exit__``, so they may be updated inside the block
    (for example ``span.output_summary = ...``) before the event is sent.

    Usable with both ``with`` and ``async with``; the async methods delegate
    to the synchronous ones.
    """

    def __init__(
        self,
        client: WildEdge,
        *,
        kind: SpanKind,
        name: str,
        status: SpanStatus = "ok",
        model_id: str | None = None,
        input_summary: str | None = None,
        output_summary: str | None = None,
        attributes: dict[str, Any] | None = None,
        trace_id: str | None = None,
        span_id: str | None = None,
        parent_span_id: str | None = None,
        run_id: str | None = None,
        agent_id: str | None = None,
        step_index: int | None = None,
        conversation_id: str | None = None,
        context: dict[str, Any] | None = None,
    ):
        """Store the span description; nothing is timed or emitted until entry.

        Args:
            client: Object whose ``track_span`` is called on exit.
            kind: Span kind forwarded to ``track_span``.
            name: Span name forwarded to ``track_span``.
            status: Status reported when the block exits without an exception.
            model_id, input_summary, output_summary, attributes: Forwarded
                unchanged to ``track_span``.
            trace_id, run_id, agent_id, conversation_id: Correlation fields
                forwarded unchanged to ``track_span``.
            span_id: Explicit span id; a UUID4 is generated on entry if None.
            parent_span_id: Explicit parent; defaults on entry to the span id
                of the currently active span context, if any.
            step_index: Forwarded to ``track_span`` and also placed on the
                installed ``SpanContext``.
            context: Forwarded to ``track_span`` and also used as the
                ``attributes`` of the installed ``SpanContext``.
        """
        self._client = client
        self.kind = kind
        self.name = name
        self.status = status
        self.model_id = model_id
        self.input_summary = input_summary
        self.output_summary = output_summary
        self.attributes = attributes
        self.trace_id = trace_id
        self.span_id = span_id
        self.parent_span_id = parent_span_id
        self.run_id = run_id
        self.agent_id = agent_id
        self.step_index = step_index
        self.conversation_id = conversation_id
        self.context = context
        # perf_counter() reading taken in __enter__; None until entered.
        self._t0: float | None = None
        # Token returned by _set_span_context, used to restore the prior context.
        self._span_token = None

    def __enter__(self):
        """Start timing, resolve span/parent ids, and install this span's context.

        Returns:
            This manager, so callers can read ``span_id`` or set fields such
            as ``output_summary`` inside the block.
        """
        self._t0 = time.perf_counter()
        if self.span_id is None:
            self.span_id = str(uuid.uuid4())
        if self.parent_span_id is None:
            # Inherit the parent from whatever span context is active right now.
            current = get_span_context()
            self.parent_span_id = current.span_id if current else None
        self._span_token = _set_span_context(
            SpanContext(
                span_id=self.span_id,
                parent_span_id=self.parent_span_id,
                step_index=self.step_index,
                attributes=self.context,
            )
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the previous span context and emit the span event.

        The reported status is ``"error"`` when the block raised, otherwise
        the configured ``status``. Always returns False, so an exception
        raised inside the block propagates to the caller.
        """
        if self._t0 is None:
            # __enter__ never ran; there is nothing to time or emit.
            return False
        duration_ms = elapsed_ms(self._t0)
        status = "error" if exc_type else self.status
        # Restore parent span context before emitting, so _merge_correlation_fields
        # sees the parent context rather than this span (which would make the span
        # appear as its own parent).
        if self._span_token is not None:
            _reset_span_context(self._span_token)
            self._span_token = None
        self._client.track_span(
            kind=self.kind,
            name=self.name,
            duration_ms=duration_ms,
            status=status,
            model_id=self.model_id,
            input_summary=self.input_summary,
            output_summary=self.output_summary,
            attributes=self.attributes,
            trace_id=self.trace_id,
            span_id=self.span_id,
            parent_span_id=self.parent_span_id,
            run_id=self.run_id,
            agent_id=self.agent_id,
            step_index=self.step_index,
            conversation_id=self.conversation_id,
            context=self.context,
        )
        return False

    async def __aenter__(self):
        """Async entry; delegates to the synchronous ``__enter__``."""
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async exit; delegates to the synchronous ``__exit__``."""
        return self.__exit__(exc_type, exc_val, exc_tb)
def track_span(
    self,
    *,
    kind: SpanKind,
    name: str,
    duration_ms: int,
    status: SpanStatus = "ok",
    model_id: str | None = None,
    input_summary: str | None = None,
    output_summary: str | None = None,
    attributes: dict[str, Any] | None = None,
    trace_id: str | None = None,
    span_id: str | None = None,
    parent_span_id: str | None = None,
    run_id: str | None = None,
    agent_id: str | None = None,
    step_index: int | None = None,
    conversation_id: str | None = None,
    context: dict[str, Any] | None = None,
) -> str:
    """Publish a span event and return its span id.

    The explicit correlation arguments are passed through
    ``_merge_correlation_fields``; if the merged result has no ``span_id``,
    a fresh UUID4 is generated for it.

    Returns:
        The span id carried by the published event.
    """
    fields = _merge_correlation_fields(
        trace_id=trace_id,
        span_id=span_id,
        parent_span_id=parent_span_id,
        run_id=run_id,
        agent_id=agent_id,
        step_index=step_index,
        conversation_id=conversation_id,
        context=context,
    )
    resolved_span_id = fields["span_id"]
    if resolved_span_id is None:
        # Every emitted span needs an id, even when the caller gave none.
        resolved_span_id = fields["span_id"] = str(uuid.uuid4())

    payload = SpanEvent(
        kind=kind,
        name=name,
        duration_ms=duration_ms,
        status=status,
        model_id=model_id,
        input_summary=input_summary,
        output_summary=output_summary,
        attributes=attributes,
        **fields,
    ).to_dict()
    self.publish(payload)
    return resolved_span_id
run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ): self.handle = handle self.input_type = input_type @@ -46,6 +54,16 @@ def __init__( self.input_meta = input_meta self.output_meta = output_meta self.generation_config = generation_config + self._correlation = dict( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) self.start_time: float | None = None def __call__(self, func): @@ -62,6 +80,7 @@ def wrapper(*args, **kwargs): input_meta=self.input_meta, output_meta=self.output_meta, generation_config=self.generation_config, + **self._correlation, ) return result except Exception as exc: @@ -69,6 +88,7 @@ def wrapper(*args, **kwargs): self.handle.track_error( error_code="UNKNOWN", error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + **self._correlation, ) raise @@ -89,6 +109,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): error_message=str(exc_val)[: constants.ERROR_MSG_MAX_LEN] if exc_val else None, + **self._correlation, ) else: self.handle.track_inference( @@ -99,5 +120,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): input_meta=self.input_meta, output_meta=self.output_meta, generation_config=self.generation_config, + **self._correlation, ) return False diff --git a/wildedge/events/__init__.py b/wildedge/events/__init__.py index b08bc78..6648c40 100644 --- a/wildedge/events/__init__.py +++ b/wildedge/events/__init__.py @@ -19,6 +19,7 @@ from wildedge.events.model_download import AdapterDownload, ModelDownloadEvent from wildedge.events.model_load import AdapterLoad, ModelLoadEvent from wildedge.events.model_unload import ModelUnloadEvent +from wildedge.events.span import SpanEvent __all__ = [ "ApiMeta", @@ -40,6 +41,7 @@ "ModelDownloadEvent", "ModelLoadEvent", "ModelUnloadEvent", + "SpanEvent", 
"TextInputMeta", "TopKPrediction", ] diff --git a/wildedge/events/common.py b/wildedge/events/common.py new file mode 100644 index 0000000..399b55b --- /dev/null +++ b/wildedge/events/common.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import Any + + +def add_optional_fields(event: dict, fields: dict[str, Any]) -> dict: + """Add non-None fields to an event payload.""" + for key, value in fields.items(): + if value is not None: + event[key] = value + return event diff --git a/wildedge/events/error.py b/wildedge/events/error.py index 13d29f1..e814120 100644 --- a/wildedge/events/error.py +++ b/wildedge/events/error.py @@ -23,6 +23,14 @@ class ErrorEvent: error_message: str | None = None stack_trace_hash: str | None = None related_event_id: str | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -40,10 +48,26 @@ def to_dict(self) -> dict: if self.related_event_id is not None: error_data["related_event_id"] = self.related_event_id - return { + event = { "event_id": self.event_id, "event_type": "error", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "error": error_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/feedback.py b/wildedge/events/feedback.py index 650904b..dfb6efa 100644 --- 
a/wildedge/events/feedback.py +++ b/wildedge/events/feedback.py @@ -24,6 +24,14 @@ class FeedbackEvent: feedback_type: str | FeedbackType delay_ms: int | None = None edit_distance: int | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -42,10 +50,26 @@ def to_dict(self) -> dict: if self.edit_distance is not None: feedback_data["edit_distance"] = self.edit_distance - return { + event = { "event_id": self.event_id, "event_type": "feedback", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "feedback": feedback_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/inference.py b/wildedge/events/inference.py index 0c02996..ef76ae5 100644 --- a/wildedge/events/inference.py +++ b/wildedge/events/inference.py @@ -304,6 +304,14 @@ class InferenceEvent: generation_config: GenerationConfig | None = None hardware: HardwareContext | None = None api_meta: ApiMeta | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = 
field(default_factory=lambda: datetime.now(timezone.utc)) inference_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -333,10 +341,26 @@ def to_dict(self) -> dict: if self.api_meta is not None: inference_data["api_meta"] = self.api_meta.to_dict() - return { + event = { "event_id": self.event_id, "event_type": "inference", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "inference": inference_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/model_download.py b/wildedge/events/model_download.py index 784c62f..2e68a3a 100644 --- a/wildedge/events/model_download.py +++ b/wildedge/events/model_download.py @@ -49,6 +49,14 @@ class ModelDownloadEvent: cdn_edge: str | None = None error_code: str | None = None adapter: AdapterDownload | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -83,10 +91,26 @@ def to_dict(self) -> dict: if self.adapter is not None: download_data["adapter"] = self.adapter.to_dict() - return { + event = { "event_id": self.event_id, "event_type": "model_download", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "download": download_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": 
self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/model_load.py b/wildedge/events/model_load.py index 9fae742..e058092 100644 --- a/wildedge/events/model_load.py +++ b/wildedge/events/model_load.py @@ -55,6 +55,14 @@ class ModelLoadEvent: cold_start: bool | None = None compile_time_ms: int | None = None adapter: AdapterLoad | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -84,10 +92,26 @@ def to_dict(self) -> dict: if self.adapter is not None: load_data["adapter"] = self.adapter.to_dict() - return { + event = { "event_id": self.event_id, "event_type": "model_load", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "load": load_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/model_unload.py b/wildedge/events/model_unload.py index 9dab481..16def90 100644 --- a/wildedge/events/model_unload.py +++ b/wildedge/events/model_unload.py @@ -14,6 +14,14 @@ class ModelUnloadEvent: memory_freed_bytes: int | None = None peak_memory_bytes: int | None = None uptime_ms: int | None = None + trace_id: str | None = None + 
span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -30,10 +38,26 @@ def to_dict(self) -> dict: if v is not None: unload_data[k] = v - return { + event = { "event_id": self.event_id, "event_type": "model_unload", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "unload": unload_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/span.py b/wildedge/events/span.py new file mode 100644 index 0000000..9d5be3c --- /dev/null +++ b/wildedge/events/span.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +from wildedge.events.common import add_optional_fields + +SpanKind = Literal[ + "agent_step", + "tool", + "retrieval", + "memory", + "router", + "guardrail", + "cache", + "eval", + "custom", +] +SpanStatus = Literal["ok", "error"] + + +@dataclass +class SpanEvent: + kind: SpanKind + name: str + duration_ms: int + status: SpanStatus + model_id: str | None = None + input_summary: str | None = None + output_summary: str | None = None + attributes: dict[str, Any] | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | 
None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None + event_id: str = field(default_factory=lambda: str(uuid.uuid4())) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_dict(self) -> dict: + span_data: dict[str, Any] = { + "kind": self.kind, + "name": self.name, + "duration_ms": self.duration_ms, + "status": self.status, + } + if self.input_summary is not None: + span_data["input_summary"] = self.input_summary + if self.output_summary is not None: + span_data["output_summary"] = self.output_summary + if self.attributes is not None: + span_data["attributes"] = self.attributes + + event = { + "event_id": self.event_id, + "event_type": "span", + "timestamp": self.timestamp.isoformat(), + "span": span_data, + } + add_optional_fields( + event, + { + "model_id": self.model_id, + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "context": self.context, + }, + ) + return event diff --git a/wildedge/integrations/common.py b/wildedge/integrations/common.py index 22a431a..1506b01 100644 --- a/wildedge/integrations/common.py +++ b/wildedge/integrations/common.py @@ -2,9 +2,15 @@ from __future__ import annotations -from typing import Any +from collections.abc import Callable +from typing import TYPE_CHECKING, Any +from wildedge import constants from wildedge.logging import logger +from wildedge.timing import elapsed_ms + +if TYPE_CHECKING: + from wildedge.model import ModelHandle def debug_failure(framework: str, context: str, exc: BaseException) -> None: @@ -110,3 +116,118 @@ def num_classes_from_output_shape(shape: tuple) -> int: if len(shape) >= 2 and isinstance(shape[-1], int) and shape[-1] > 1: return int(shape[-1]) return 0 + + +# --------------------------------------------------------------------------- +# Generic 
streaming wrappers +# --------------------------------------------------------------------------- +# Each integration provides: +# on_chunk(chunk) -> None : update mutable state from a single chunk +# on_done(duration_ms, ttft_ms) : record inference once the stream is exhausted +# +# The wrappers handle TTFT capture, error tracking, context-manager delegation, +# and attribute proxying so callers get a drop-in replacement for the raw stream. + + +class SyncStreamWrapper: + """Wraps a sync iterable stream to capture TTFT and record inference on exhaustion.""" + + def __init__( + self, + original: object, + handle: ModelHandle, + t0: float, + on_chunk: Callable[[object], None] | None, + on_done: Callable[[int, int | None], None], + ) -> None: + self._original = original + self._handle = handle + self._t0 = t0 + self._on_chunk = on_chunk + self._on_done = on_done + + def __iter__(self): + return self._track() + + def _track(self): + ttft_ms: int | None = None + try: + for chunk in self._original: # type: ignore[union-attr] + if ttft_ms is None: + ttft_ms = elapsed_ms(self._t0) + if self._on_chunk is not None: + self._on_chunk(chunk) + yield chunk + except Exception as exc: + self._handle.track_error( + error_code="UNKNOWN", + error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + ) + raise + else: + self._on_done(elapsed_ms(self._t0), ttft_ms) + + def __enter__(self) -> SyncStreamWrapper: + if hasattr(self._original, "__enter__"): + self._original.__enter__() # type: ignore[union-attr] + return self + + def __exit__(self, *args: object) -> object: + if hasattr(self._original, "__exit__"): + return self._original.__exit__(*args) # type: ignore[union-attr] + return None + + def __getattr__(self, name: str) -> object: + return getattr(self._original, name) + + +class AsyncStreamWrapper: + """Wraps an async iterable stream to capture TTFT and record inference on exhaustion.""" + + def __init__( + self, + original: object, + handle: ModelHandle, + t0: float, + on_chunk: 
Callable[[object], None] | None, + on_done: Callable[[int, int | None], None], + ) -> None: + self._original = original + self._handle = handle + self._t0 = t0 + self._on_chunk = on_chunk + self._on_done = on_done + + def __aiter__(self): + return self._track() + + async def _track(self): + ttft_ms: int | None = None + try: + async for chunk in self._original: # type: ignore[union-attr] + if ttft_ms is None: + ttft_ms = elapsed_ms(self._t0) + if self._on_chunk is not None: + self._on_chunk(chunk) + yield chunk + except Exception as exc: + self._handle.track_error( + error_code="UNKNOWN", + error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + ) + raise + else: + self._on_done(elapsed_ms(self._t0), ttft_ms) + + async def __aenter__(self) -> AsyncStreamWrapper: + if hasattr(self._original, "__aenter__"): + await self._original.__aenter__() # type: ignore[union-attr] + return self + + async def __aexit__(self, *args: object) -> object: + if hasattr(self._original, "__aexit__"): + return await self._original.__aexit__(*args) # type: ignore[union-attr] + return None + + def __getattr__(self, name: str) -> object: + return getattr(self._original, name) diff --git a/wildedge/integrations/gguf.py b/wildedge/integrations/gguf.py index 9937cdd..dc44360 100644 --- a/wildedge/integrations/gguf.py +++ b/wildedge/integrations/gguf.py @@ -12,7 +12,7 @@ from wildedge import constants from wildedge.events.inference import GenerationOutputMeta, TextInputMeta from wildedge.integrations.base import BaseExtractor, patch_instance_call_once -from wildedge.integrations.common import debug_failure +from wildedge.integrations.common import SyncStreamWrapper, debug_failure from wildedge.logging import logger from wildedge.model import ModelInfo from wildedge.platforms import CURRENT_PLATFORM @@ -69,6 +69,76 @@ def parse_quantization(filename: str) -> str | None: return None +def make_gguf_input_meta(prompt: object, tokens_in: int | None) -> TextInputMeta | None: + if not 
isinstance(prompt, str) or not prompt: + return None + return TextInputMeta( + char_count=len(prompt), + word_count=len(prompt.split()), + token_count=tokens_in, + ) + + +def make_gguf_output_meta( + tokens_in: int | None, + tokens_out: int | None, + stop_reason: str | None, + ttft_ms: int | None, + duration_ms: int, +) -> GenerationOutputMeta | None: + if tokens_out is None and ttft_ms is None: + return None + tps = ( + round(tokens_out / duration_ms * 1000, 1) + if duration_ms > 0 and tokens_out + else None + ) + return GenerationOutputMeta( + task="generation", + tokens_in=tokens_in, + tokens_out=tokens_out, + time_to_first_token_ms=ttft_ms, + tokens_per_second=tps, + stop_reason=stop_reason, + ) + + +def make_gguf_stream_callbacks(handle: object, prompt: object) -> tuple: + """Return (on_chunk, on_done) callbacks for a llama-cpp-python streaming response. + + Chunks are dicts; usage appears in the final chunk when available. + """ + tokens_in: list[int | None] = [None] + tokens_out: list[int | None] = [None] + stop_reason: list[str | None] = [None] + + def on_chunk(chunk: object) -> None: + if not isinstance(chunk, dict): + return + usage = chunk.get("usage") + if usage: + tokens_in[0] = usage.get("prompt_tokens") + tokens_out[0] = usage.get("completion_tokens") + choices = chunk.get("choices") or [] + if choices: + reason = choices[0].get("finish_reason") + if reason: + stop_reason[0] = reason + + def on_done(duration_ms: int, ttft_ms: int | None) -> None: + ti, to, sr = tokens_in[0], tokens_out[0], stop_reason[0] + handle.track_inference( # type: ignore[union-attr] + duration_ms=duration_ms, + input_modality="text", + output_modality="generation", + input_meta=make_gguf_input_meta(prompt, ti), + success=True, + output_meta=make_gguf_output_meta(ti, to, sr, ttft_ms, duration_ms), + ) + + return on_chunk, on_done + + def build_patched_call(original_call): def patched_call(self_inner, *args, **kwargs): handle = getattr(self_inner, GGUF_HANDLE_ATTR, None) @@ 
-76,9 +146,13 @@ def patched_call(self_inner, *args, **kwargs): return original_call(self_inner, *args, **kwargs) prompt = args[0] if args else kwargs.get("prompt", "") + is_streaming: bool = bool(kwargs.get("stream", False)) t0 = time.perf_counter() try: result = original_call(self_inner, *args, **kwargs) + if is_streaming: + on_chunk, on_done = make_gguf_stream_callbacks(handle, prompt) + return SyncStreamWrapper(result, handle, t0, on_chunk, on_done) duration_ms = elapsed_ms(t0) tokens_in = None tokens_out = None @@ -89,36 +163,15 @@ def patched_call(self_inner, *args, **kwargs): tokens_out = usage.get("completion_tokens") except Exception as exc: debug_gguf_failure("usage extraction", exc) - - input_meta = None - if isinstance(prompt, str) and prompt: - input_meta = TextInputMeta( - char_count=len(prompt), - word_count=len(prompt.split()), - token_count=tokens_in, - ) - - output_meta = None - if tokens_out is not None: - tps = ( - round(tokens_out / duration_ms * 1000, 1) - if duration_ms > 0 - else None - ) - output_meta = GenerationOutputMeta( - task="generation", - tokens_in=tokens_in, - tokens_out=tokens_out, - tokens_per_second=tps, - ) - handle.track_inference( duration_ms=duration_ms, input_modality="text", output_modality="generation", - input_meta=input_meta, + input_meta=make_gguf_input_meta(prompt, tokens_in), success=True, - output_meta=output_meta, + output_meta=make_gguf_output_meta( + tokens_in, tokens_out, None, None, duration_ms + ), ) return result except Exception as exc: diff --git a/wildedge/integrations/openai.py b/wildedge/integrations/openai.py index e5d1525..700aace 100644 --- a/wildedge/integrations/openai.py +++ b/wildedge/integrations/openai.py @@ -11,7 +11,11 @@ from wildedge import constants from wildedge.events.inference import ApiMeta, GenerationOutputMeta, TextInputMeta from wildedge.integrations.base import BaseExtractor -from wildedge.integrations.common import debug_failure +from wildedge.integrations.common import ( + 
AsyncStreamWrapper, + SyncStreamWrapper, + debug_failure, +) from wildedge.model import ModelInfo from wildedge.timing import elapsed_ms @@ -41,13 +45,21 @@ def source_from_base_url(base_url: str | None) -> str: return SOURCE_BY_HOSTNAME.get(hostname or "", hostname or "openai") +def _msg_role(m) -> str | None: + return m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + + +def _msg_content(m) -> str | None: + return m.get("content") if isinstance(m, dict) else getattr(m, "content", None) + + def build_input_meta(messages: list, tokens_in: int | None) -> TextInputMeta | None: if not messages: return None - last_user = next((m for m in reversed(messages) if m.get("role") == "user"), None) + last_user = next((m for m in reversed(messages) if _msg_role(m) == "user"), None) if not last_user: return None - content = last_user.get("content", "") + content = _msg_content(last_user) or "" if not isinstance(content, str) or not content: return None return TextInputMeta( @@ -58,6 +70,28 @@ def build_input_meta(messages: list, tokens_in: int | None) -> TextInputMeta | N ) +def build_streaming_output_meta( + ttft_ms: int | None, + tokens_in: int | None, + tokens_out: int | None, + stop_reason: str | None, + duration_ms: int, +) -> GenerationOutputMeta: + tps = ( + round(tokens_out / duration_ms * 1000, 1) + if duration_ms > 0 and tokens_out + else None + ) + return GenerationOutputMeta( + task="generation", + tokens_in=tokens_in, + tokens_out=tokens_out, + time_to_first_token_ms=ttft_ms, + tokens_per_second=tps, + stop_reason=stop_reason, + ) + + def build_output_meta( response: object, duration_ms: int ) -> GenerationOutputMeta | None: @@ -145,6 +179,50 @@ def record_inference( ) +def make_openai_stream_callbacks( + handle: ModelHandle, + messages: list, +) -> tuple: + """Return (on_chunk, on_done) callbacks for an OpenAI streaming response. + + on_chunk updates mutable state from each ChatCompletionChunk. 
+ on_done is called with (duration_ms, ttft_ms) when the stream is exhausted. + """ + tokens_in: list[int | None] = [None] + tokens_out: list[int | None] = [None] + stop_reason: list[str | None] = [None] + last_chunk: list[object] = [None] + + def on_chunk(chunk: object) -> None: + last_chunk[0] = chunk + chunk_usage = getattr(chunk, "usage", None) + if chunk_usage is not None: + tokens_in[0] = getattr(chunk_usage, "prompt_tokens", None) + tokens_out[0] = getattr(chunk_usage, "completion_tokens", None) + choices = getattr(chunk, "choices", None) or [] + if choices: + reason = getattr(choices[0], "finish_reason", None) + if reason: + stop_reason[0] = reason + + def on_done(duration_ms: int, ttft_ms: int | None) -> None: + ti, to, sr = tokens_in[0], tokens_out[0], stop_reason[0] + output_meta = build_streaming_output_meta(ttft_ms, ti, to, sr, duration_ms) + handle.track_inference( + duration_ms=duration_ms, + input_modality="text", + output_modality="generation", + success=True, + input_meta=build_input_meta(messages, ti), + output_meta=output_meta, + api_meta=build_api_meta(last_chunk[0]) + if last_chunk[0] is not None + else None, + ) + + return on_chunk, on_done + + def wrap_sync_completions(completions: object, source: str, client_ref: object) -> None: original_create = completions.create # type: ignore[attr-defined] model_handles: dict[str, ModelHandle] = {} @@ -160,8 +238,12 @@ def patched_create(*args, **kwargs): t0 = time.perf_counter() try: result = original_create(*args, **kwargs) - if not is_streaming and handle is not None: - record_inference(handle, result, messages, elapsed_ms(t0)) + if handle is not None: + if is_streaming: + on_chunk, on_done = make_openai_stream_callbacks(handle, messages) + return SyncStreamWrapper(result, handle, t0, on_chunk, on_done) + else: + record_inference(handle, result, messages, elapsed_ms(t0)) return result except Exception as exc: if handle is not None: @@ -191,8 +273,12 @@ async def patched_create(*args, **kwargs): t0 = 
time.perf_counter() try: result = await original_create(*args, **kwargs) - if not is_streaming and handle is not None: - record_inference(handle, result, messages, elapsed_ms(t0)) + if handle is not None: + if is_streaming: + on_chunk, on_done = make_openai_stream_callbacks(handle, messages) + return AsyncStreamWrapper(result, handle, t0, on_chunk, on_done) + else: + record_inference(handle, result, messages, elapsed_ms(t0)) return result except Exception as exc: if handle is not None: diff --git a/wildedge/model.py b/wildedge/model.py index de7e7b1..2502b14 100644 --- a/wildedge/model.py +++ b/wildedge/model.py @@ -28,6 +28,7 @@ from wildedge.logging import logger from wildedge.platforms import capture_hardware, is_sampling from wildedge.platforms.hardware import HardwareContext +from wildedge.tracing import _merge_correlation_fields @dataclass @@ -68,8 +69,26 @@ def track_load( accelerator: str | None = None, success: bool = True, error_code: str | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, **kwargs: Any, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ModelLoadEvent( model_id=self.model_id, duration_ms=duration_ms, @@ -77,6 +96,7 @@ def track_load( accelerator=accelerator or self.detected_accelerator, success=success, error_code=error_code, + **correlation, **kwargs, ) self.publish(event.to_dict()) @@ -89,7 +109,25 @@ def track_unload( memory_freed_bytes: int | None = None, peak_memory_bytes: int | None = None, uptime_ms: int | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str 
| None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ModelUnloadEvent( model_id=self.model_id, duration_ms=duration_ms, @@ -97,6 +135,7 @@ def track_unload( memory_freed_bytes=memory_freed_bytes, peak_memory_bytes=peak_memory_bytes, uptime_ms=uptime_ms, + **correlation, ) self.publish(event.to_dict()) @@ -111,8 +150,26 @@ def track_download( resumed: bool, cache_hit: bool, success: bool, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, **kwargs: Any, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ModelDownloadEvent( model_id=self.model_id, source_url=source_url, @@ -124,6 +181,7 @@ def track_download( resumed=resumed, cache_hit=cache_hit, success=success, + **correlation, **kwargs, ) self.publish(event.to_dict()) @@ -146,9 +204,27 @@ def track_inference( generation_config: GenerationConfig | None = None, hardware: HardwareContext | None = None, api_meta: ApiMeta | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> str: if 
hardware is None and is_sampling(): hardware = capture_hardware() + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = InferenceEvent( model_id=self.model_id, duration_ms=duration_ms, @@ -162,6 +238,7 @@ def track_inference( generation_config=generation_config, hardware=hardware, api_meta=api_meta, + **correlation, ) self.last_inference_id = event.inference_id self.publish(event.to_dict()) @@ -174,13 +251,32 @@ def track_feedback( *, delay_ms: int | None = None, edit_distance: int | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = FeedbackEvent( model_id=self.model_id, related_inference_id=related_inference_id, feedback_type=feedback_type, delay_ms=delay_ms, edit_distance=edit_distance, + **correlation, ) self.publish(event.to_dict()) @@ -205,13 +301,32 @@ def track_error( error_message: str | None = None, stack_trace_hash: str | None = None, related_event_id: str | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + 
agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ErrorEvent( model_id=self.model_id, error_code=error_code, error_message=error_message, stack_trace_hash=stack_trace_hash, related_event_id=related_event_id, + **correlation, ) self.publish(event.to_dict()) diff --git a/wildedge/tracing.py b/wildedge/tracing.py new file mode 100644 index 0000000..a205543 --- /dev/null +++ b/wildedge/tracing.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import contextlib +import contextvars +import uuid +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class TraceContext: + trace_id: str + run_id: str | None = None + agent_id: str | None = None + conversation_id: str | None = None + attributes: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class SpanContext: + span_id: str + parent_span_id: str | None = None + step_index: int | None = None + attributes: dict[str, Any] | None = None + + +_TRACE_CTX: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar( + "wildedge_trace_ctx", default=None +) +_SPAN_CTX: contextvars.ContextVar[SpanContext | None] = contextvars.ContextVar( + "wildedge_span_ctx", default=None +) + + +def get_trace_context() -> TraceContext | None: + return _TRACE_CTX.get() + + +def get_span_context() -> SpanContext | None: + return _SPAN_CTX.get() + + +def _set_trace_context(ctx: TraceContext) -> contextvars.Token: + return _TRACE_CTX.set(ctx) + + +def _reset_trace_context(token: contextvars.Token) -> None: + _TRACE_CTX.reset(token) + + +def _set_span_context(ctx: SpanContext) -> contextvars.Token: + return _SPAN_CTX.set(ctx) + + +def _reset_span_context(token: contextvars.Token) -> None: + _SPAN_CTX.reset(token) + + +@contextlib.contextmanager +def trace_context( + *, + trace_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + conversation_id: str | None = None, + attributes: dict[str, Any] | None 
= None, +): + if trace_id is None: + trace_id = str(uuid.uuid4()) + token = _set_trace_context( + TraceContext( + trace_id=trace_id, + run_id=run_id, + agent_id=agent_id, + conversation_id=conversation_id, + attributes=attributes, + ) + ) + try: + yield get_trace_context() + finally: + _reset_trace_context(token) + + +@contextlib.contextmanager +def span_context( + *, + span_id: str | None = None, + parent_span_id: str | None = None, + step_index: int | None = None, + attributes: dict[str, Any] | None = None, +): + """Low-level context manager that sets span correlation fields without emitting a span event. + + Prefer client.span() for most use cases. Use this only when you need correlation + fields attached to auto-instrumented events (e.g. an OpenAI call) without emitting + a redundant span wrapper. + """ + if span_id is None: + span_id = str(uuid.uuid4()) + if parent_span_id is None: + current = get_span_context() + parent_span_id = current.span_id if current else None + token = _set_span_context( + SpanContext( + span_id=span_id, + parent_span_id=parent_span_id, + step_index=step_index, + attributes=attributes, + ) + ) + try: + yield get_span_context() + finally: + _reset_span_context(token) + + +def _merge_context(*candidates: dict[str, Any] | None) -> dict[str, Any] | None: + merged: dict[str, Any] = {} + for attrs in candidates: + if not attrs: + continue + merged.update(attrs) + return merged or None + + +def _merge_correlation_fields( + *, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, +) -> dict[str, Any]: + trace = get_trace_context() + span = get_span_context() + + resolved_trace_id = trace_id or (trace.trace_id if trace else None) + resolved_span_id = span_id + resolved_parent_span_id = parent_span_id or (span.span_id if span else 
None) + resolved_run_id = run_id or (trace.run_id if trace else None) + resolved_agent_id = agent_id or (trace.agent_id if trace else None) + resolved_step_index = ( + step_index if step_index is not None else (span.step_index if span else None) + ) + resolved_conversation_id = conversation_id or ( + trace.conversation_id if trace else None + ) + resolved_context = _merge_context( + trace.attributes if trace else None, + span.attributes if span else None, + context, + ) + + return { + "trace_id": resolved_trace_id, + "span_id": resolved_span_id, + "parent_span_id": resolved_parent_span_id, + "run_id": resolved_run_id, + "agent_id": resolved_agent_id, + "step_index": resolved_step_index, + "conversation_id": resolved_conversation_id, + "context": resolved_context, + }