diff --git a/.gitignore b/.gitignore index 64d49ae..468c518 100644 --- a/.gitignore +++ b/.gitignore @@ -213,4 +213,5 @@ marimo/_lsp/ __marimo__/ # Streamlit -.streamlit/secrets.toml \ No newline at end of file +.streamlit/secrets.toml +internal/ diff --git a/README.md b/README.md index 1c24d0b..1777cd2 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,6 @@ client = wildedge.init( If no DSN is configured, the client becomes a no-op and logs a warning. `init(...)` is a convenience wrapper for `WildEdge(...)` + `instrument(...)`. - ## Supported integrations **On-device** diff --git a/docs/manual-tracking.md b/docs/manual-tracking.md index f90f58a..0fa227e 100644 --- a/docs/manual-tracking.md +++ b/docs/manual-tracking.md @@ -215,3 +215,58 @@ handle.feedback(FeedbackType.THUMBS_DOWN) ``` `FeedbackType` values: `THUMBS_UP`, `THUMBS_DOWN`. + +## Track spans for agentic workflows + +Use span events to track non-inference steps like planning, tool calls, retrieval, or memory updates. + +```python +from wildedge.timing import Timer + +with Timer() as t: + tool_result = call_tool() + +client.track_span( + kind="tool", + name="call_tool", + duration_ms=t.elapsed_ms, + status="ok", + attributes={"tool": "search"}, +) +``` + +You can also attach optional correlation fields (`trace_id`, `span_id`, +`parent_span_id`, `run_id`, `agent_id`, `step_index`, `conversation_id`) to any +event by passing them into `track_inference`, `track_error`, `track_feedback`, +or `track_span`. Use `context=` for correlation attributes shared across events. + +### Trace context helpers + +Use `client.trace()` and `client.span()` to auto-populate correlation fields for +all events emitted inside the block. `client.span()` times the block and emits a +span event on exit: + +```python +import wildedge +from wildedge.timing import Timer + +client = wildedge.init() +handle = client.register_model(my_model, model_id="my-org/my-model") + +with client.trace(run_id="run-123", agent_id="agent-1"): + with client.span(kind="agent_step", name="plan", step_index=1): + with Timer() as t: + result = my_model(prompt) + handle.track_inference(duration_ms=t.elapsed_ms, input_modality="text", output_modality="generation") +``` + +If you need to set correlation fields without emitting a span event, use the +lower-level `span_context()` directly: + +```python +with client.trace(run_id="run-123", agent_id="agent-1"): + with wildedge.span_context(step_index=1): + with Timer() as t: + result = my_model(prompt) + handle.track_inference(duration_ms=t.elapsed_ms, input_modality="text", output_modality="generation") +``` diff --git a/examples/agentic_example.py b/examples/agentic_example.py new file mode 100644 index 0000000..446c812 --- /dev/null +++ b/examples/agentic_example.py @@ -0,0 +1,181 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = ["wildedge-sdk", "openai"] +# +# [tool.uv.sources] +# wildedge-sdk = { path = "..", editable = true } +# /// +"""Agentic workflow example with tool use. + +Demonstrates WildEdge tracing for a simple agent that: + - Runs within a trace (one per agent session) + - Wraps each reasoning step in an agent_step span + - Wraps each tool call in a tool span + - Tracks LLM inference automatically via the OpenAI integration + +Run with: uv run agentic_example.py +Requires: OPENROUTER_API_KEY environment variable. Set WILDEDGE_DSN to send events. +""" + +import json +import os +import time +import uuid + +from openai import OpenAI + +import wildedge + +we = wildedge.init( + app_version="1.0.0", + integrations="openai", +) + +openai_client = OpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), +) + +# --- Tools ------------------------------------------------------------------- + +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Return current weather for a city.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + }, + "required": ["city"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "calculator", + "description": "Evaluate a simple arithmetic expression.", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string"}, + }, + "required": ["expression"], + }, + }, + }, +] + + +def get_weather(city: str) -> str: + # ~150ms to simulate a real weather API call. + time.sleep(0.15) + return json.dumps({"city": city, "temperature_c": 18, "condition": "partly cloudy"}) + + +def calculator(expression: str) -> str: + # ~60ms to simulate a remote computation call. + time.sleep(0.06) + try: + result = eval(expression, {"__builtins__": {}}) # noqa: S307 + return json.dumps({"expression": expression, "result": result}) + except Exception as e: + return json.dumps({"error": str(e)}) + + +TOOL_HANDLERS = { + "get_weather": get_weather, + "calculator": calculator, +} + + +# --- Agent loop -------------------------------------------------------------- + + +def call_tool(name: str, arguments: dict) -> str: + with we.span( + kind="tool", + name=name, + input_summary=json.dumps(arguments)[:200], + ) as span: + result = TOOL_HANDLERS[name](**arguments) + span.output_summary = result[:200] + return result + + +def retrieve_context(query: str) -> str: + """Fetch relevant context from the vector store (~120ms).""" + with we.span( + kind="retrieval", + name="vector_search", + input_summary=query[:200], + ) as span: + time.sleep(0.12) + result = f"[context: background knowledge relevant to '{query[:40]}']" + span.output_summary = result + return result + + +def run_agent(task: str, step_index: int, messages: list) -> str: + # Fetch context before the first reasoning step, include it in the user turn. + context = retrieve_context(task) + messages.append({"role": "user", "content": f"{task}\n\nContext: {context}"}) + + while True: + with we.span( + kind="agent_step", + name="reason", + step_index=step_index, + input_summary=task[:200], + ) as span: + response = openai_client.chat.completions.create( + model="qwen/qwen3.5-flash-02-23", + messages=messages, + tools=TOOLS, + tool_choice="auto", + max_tokens=512, + ) + choice = response.choices[0] + span.output_summary = choice.finish_reason + + messages.append(choice.message.model_dump(exclude_none=True)) + + if choice.finish_reason == "tool_calls": + step_index += 1 + for tool_call in choice.message.tool_calls: + arguments = json.loads(tool_call.function.arguments) + result = call_tool(tool_call.function.name, arguments) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + # Not instrumented: context window update between tool calls (~80ms). + # Shows up as a gap stripe in the trace view. + time.sleep(0.08) + else: + return choice.message.content or "" + + +# --- Main -------------------------------------------------------------------- + +TASKS = [ + "What's the weather like in Tokyo, and what is 42 * 18?", + "Is it warmer in Paris or Berlin right now?", +] + +system_prompt = "You are a helpful assistant. Use tools when needed." +messages = [{"role": "system", "content": system_prompt}] + +with we.trace(agent_id="demo-agent", run_id=str(uuid.uuid4())): + for i, task in enumerate(TASKS, start=1): + print(f"\nTask {i}: {task}") + reply = run_agent(task, step_index=i, messages=messages) + print(f"Reply: {reply}") + +we.flush() diff --git a/tests/test_event_serialization.py b/tests/test_event_serialization.py index 12028c7..587940b 100644 --- a/tests/test_event_serialization.py +++ b/tests/test_event_serialization.py @@ -4,6 +4,7 @@ from wildedge.events.inference import InferenceEvent, TextInputMeta from wildedge.events.model_download import AdapterDownload, ModelDownloadEvent from wildedge.events.model_load import AdapterLoad, ModelLoadEvent +from wildedge.events.span import SpanEvent def test_inference_event_to_dict_omits_none_fields(): @@ -72,3 +73,44 @@ def test_feedback_event_enum_and_string_forms(): ) assert enum_event.to_dict()["feedback"]["feedback_type"] == "accept" assert string_event.to_dict()["feedback"]["feedback_type"] == "reject" + + +def test_span_event_to_dict_includes_required_fields(): + event = SpanEvent( + kind="tool", + name="search", + duration_ms=250, + status="ok", + attributes={"provider": "custom"}, + ) + data = event.to_dict() + assert data["event_type"] == "span" + assert data["span"]["kind"] == "tool" + assert data["span"]["attributes"]["provider"] == "custom" + + +def test_span_event_context_serializes_under_context_key(): + event = SpanEvent( + kind="agent_step", + name="plan", + duration_ms=10, + status="ok", + context={"user_id": "u1"}, + ) + data = event.to_dict() + assert data["context"] == {"user_id": "u1"} + assert "attributes" not in data + + +def test_span_event_attributes_and_context_are_independent(): + event = SpanEvent( + kind="tool", + name="search", + duration_ms=50, + status="ok", + attributes={"provider": "custom"}, + context={"user_id": "u1"}, + ) + data = event.to_dict() + assert data["span"]["attributes"] == {"provider": "custom"} + assert data["context"] == {"user_id": "u1"} diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 0000000..944f1e0 --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from wildedge.client import SpanContextManager +from wildedge.model import ModelHandle, ModelInfo +from wildedge.tracing import get_span_context, span_context, trace_context + + +def test_track_inference_uses_trace_context(): + events: list[dict] = [] + + def publish(event: dict) -> None: + events.append(event) + + handle = ModelHandle( + "model-1", + ModelInfo( + model_name="test", + model_version="1.0", + model_source="local", + model_format="onnx", + ), + publish, + ) + + with trace_context( + trace_id="trace-123", + run_id="run-1", + agent_id="agent-1", + attributes={"trace_key": "trace_val"}, + ): + with span_context(span_id="span-abc", step_index=2, attributes={"span_key": 2}): + handle.track_inference(duration_ms=5) + + assert events[0]["trace_id"] == "trace-123" + assert events[0]["parent_span_id"] == "span-abc" + assert events[0]["run_id"] == "run-1" + assert events[0]["agent_id"] == "agent-1" + assert events[0]["step_index"] == 2 + assert events[0]["attributes"] == {"trace_key": "trace_val", "span_key": 2} + + +class _FakeClient: + def __init__(self, events: list[dict]) -> None: + self._events = events + + def track_span(self, **kwargs) -> str: + self._events.append(kwargs) + return kwargs.get("span_id", "") + + +def test_span_root_has_no_parent(): + """A root span must not reference itself as its own parent.""" + events: list[dict] = [] + client = _FakeClient(events) + + with SpanContextManager(client, kind="agent_step", name="root"): + pass + + assert len(events) == 1 + assert events[0]["parent_span_id"] is None + + +def test_span_context_restored_after_exit(): + """The active span context must revert to the parent after a span exits.""" + events: list[dict] = [] + client = _FakeClient(events) + + with span_context(span_id="parent-span"): + with SpanContextManager(client, kind="agent_step", name="child"): + inner_id = get_span_context().span_id + + assert get_span_context().span_id == "parent-span" + + assert inner_id != "parent-span" + assert events[0]["parent_span_id"] == "parent-span" + assert events[0]["span_id"] != "parent-span" + + +def test_nested_spans_correct_parent_chain(): + """Nested spans must each point to their direct parent, not themselves.""" + events: list[dict] = [] + client = _FakeClient(events) + + with SpanContextManager(client, kind="agent_step", name="outer") as outer: + with SpanContextManager(client, kind="tool", name="inner") as inner: + pass + + assert len(events) == 2 + inner_ev, outer_ev = events[0], events[1] + assert inner_ev["span_id"] == inner.span_id + assert inner_ev["parent_span_id"] == outer.span_id + assert outer_ev["span_id"] == outer.span_id + assert outer_ev["parent_span_id"] is None diff --git a/wildedge/__init__.py b/wildedge/__init__.py index f1f449b..2b7993f 100644 --- a/wildedge/__init__.py +++ b/wildedge/__init__.py @@ -1,6 +1,6 @@ """WildEdge Python SDK.""" -from wildedge.client import WildEdge +from wildedge.client import SpanContextManager, WildEdge from wildedge.convenience import init from wildedge.decorators import track from wildedge.events import ( @@ -15,13 +15,22 @@ GenerationConfig, GenerationOutputMeta, ImageInputMeta, + SpanEvent, TextInputMeta, ) +from wildedge.events.span import SpanKind, SpanStatus from wildedge.platforms import capture_hardware from wildedge.platforms.device_info import DeviceInfo from wildedge.platforms.hardware import HardwareContext, ThermalContext from wildedge.queue import QueuePolicy from wildedge.timing import Timer +from wildedge.tracing import ( + SpanContext, + TraceContext, + get_span_context, + get_trace_context, + span_context, +) __all__ = [ "WildEdge", @@ -42,7 +51,16 @@ "GenerationConfig", "AdapterLoad", "AdapterDownload", + "SpanEvent", "FeedbackType", "ErrorCode", "Timer", + "span_context", + "TraceContext", + "SpanContext", + "get_trace_context", + "get_span_context", + "SpanKind", + "SpanStatus", + "SpanContextManager", ] diff --git a/wildedge/client.py b/wildedge/client.py index da5e052..627cbc1 100644 --- a/wildedge/client.py +++ b/wildedge/client.py @@ -12,6 +12,8 @@ from wildedge import constants from wildedge.consumer import Consumer from wildedge.dead_letters import DeadLetterStore +from wildedge.events import SpanEvent +from wildedge.events.span import SpanKind, SpanStatus from wildedge.hubs.base import BaseHubTracker from wildedge.hubs.huggingface import HuggingFaceHubTracker from wildedge.hubs.registry import supported_hubs @@ -39,6 +41,14 @@ from wildedge.queue import EventQueue, QueuePolicy from wildedge.settings import read_client_env, resolve_app_identity from wildedge.timing import Timer, elapsed_ms +from wildedge.tracing import ( + SpanContext, + _merge_correlation_fields, + _reset_span_context, + _set_span_context, + get_span_context, + trace_context, +) from wildedge.transmitter import Transmitter DSN_FORMAT = "'https://@ingest.wildedge.dev/'" @@ -103,6 +113,101 @@ def parse_dsn(dsn: str) -> tuple[str, str, str]: ] +class SpanContextManager: + def __init__( + self, + client: WildEdge, + *, + kind: SpanKind, + name: str, + status: SpanStatus = "ok", + model_id: str | None = None, + input_summary: str | None = None, + output_summary: str | None = None, + attributes: dict[str, Any] | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, + ): + self._client = client + self.kind = kind + self.name = name + self.status = status + self.model_id = model_id + self.input_summary = input_summary + self.output_summary = output_summary + self.attributes = attributes + self.trace_id = trace_id + self.span_id = span_id + self.parent_span_id = parent_span_id + self.run_id = run_id + self.agent_id = agent_id + self.step_index = step_index + self.conversation_id = conversation_id + self.context = context + self._t0: float | None = None + self._span_token = None + + def __enter__(self): + self._t0 = time.perf_counter() + if self.span_id is None: + self.span_id = str(uuid.uuid4()) + if self.parent_span_id is None: + current = get_span_context() + self.parent_span_id = current.span_id if current else None + self._span_token = _set_span_context( + SpanContext( + span_id=self.span_id, + parent_span_id=self.parent_span_id, + step_index=self.step_index, + attributes=self.context, + ) + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._t0 is None: + return False + duration_ms = elapsed_ms(self._t0) + status = "error" if exc_type else self.status + # Restore parent span context before emitting, so _merge_correlation_fields + # sees the parent context rather than this span (which would make the span + # appear as its own parent). + if self._span_token is not None: + _reset_span_context(self._span_token) + self._span_token = None + self._client.track_span( + kind=self.kind, + name=self.name, + duration_ms=duration_ms, + status=status, + model_id=self.model_id, + input_summary=self.input_summary, + output_summary=self.output_summary, + attributes=self.attributes, + trace_id=self.trace_id, + span_id=self.span_id, + parent_span_id=self.parent_span_id, + run_id=self.run_id, + agent_id=self.agent_id, + step_index=self.step_index, + conversation_id=self.conversation_id, + context=self.context, + ) + return False + + async def __aenter__(self): + return self.__enter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return self.__exit__(exc_type, exc_val, exc_tb) + + class WildEdge: """ WildEdge on-device ML monitoring client. @@ -381,6 +486,110 @@ def register_model( return handle + def trace( + self, + *, + trace_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + conversation_id: str | None = None, + attributes: dict[str, Any] | None = None, + ): + """Context manager that sets trace correlation fields.""" + return trace_context( + trace_id=trace_id, + run_id=run_id, + agent_id=agent_id, + conversation_id=conversation_id, + attributes=attributes, + ) + + def span( + self, + *, + kind: SpanKind, + name: str, + status: SpanStatus = "ok", + model_id: str | None = None, + input_summary: str | None = None, + output_summary: str | None = None, + attributes: dict[str, Any] | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, + ) -> SpanContextManager: + """Context manager that times and emits a span event.""" + return SpanContextManager( + self, + kind=kind, + name=name, + status=status, + model_id=model_id, + input_summary=input_summary, + output_summary=output_summary, + attributes=attributes, + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) + + def track_span( + self, + *, + kind: SpanKind, + name: str, + duration_ms: int, + status: SpanStatus = "ok", + model_id: str | None = None, + input_summary: str | None = None, + output_summary: str | None = None, + attributes: dict[str, Any] | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, + ) -> str: + """Emit a span event for agentic workflows and tooling.""" + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) + if correlation["span_id"] is None: + correlation["span_id"] = str(uuid.uuid4()) + event = SpanEvent( + kind=kind, + name=name, + duration_ms=duration_ms, + status=status, + model_id=model_id, + input_summary=input_summary, + output_summary=output_summary, + attributes=attributes, + **correlation, + ) + self.publish(event.to_dict()) + return correlation["span_id"] + def _find_extractor(self, model_obj: object) -> BaseExtractor | None: for candidate in DEFAULT_EXTRACTORS: if candidate.can_handle(model_obj): diff --git a/wildedge/decorators.py b/wildedge/decorators.py index fe2e3ff..f953159 100644 --- a/wildedge/decorators.py +++ b/wildedge/decorators.py @@ -38,6 +38,14 @@ def __init__( input_meta: Any = None, output_meta: Any = None, generation_config: Any = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ): self.handle = handle self.input_type = input_type @@ -46,6 +54,16 @@ def __init__( self.input_meta = input_meta self.output_meta = output_meta self.generation_config = generation_config + self._correlation = dict( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) self.start_time: float | None = None def __call__(self, func): @@ -62,6 +80,7 @@ def wrapper(*args, **kwargs): input_meta=self.input_meta, output_meta=self.output_meta, generation_config=self.generation_config, + **self._correlation, ) return result except Exception as exc: @@ -69,6 +88,7 @@ def wrapper(*args, **kwargs): self.handle.track_error( error_code="UNKNOWN", error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + **self._correlation, ) raise @@ -89,6 +109,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): error_message=str(exc_val)[: constants.ERROR_MSG_MAX_LEN] if exc_val else None, + **self._correlation, ) else: self.handle.track_inference( @@ -99,5 +120,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): input_meta=self.input_meta, output_meta=self.output_meta, generation_config=self.generation_config, + **self._correlation, ) return False diff --git a/wildedge/events/__init__.py b/wildedge/events/__init__.py index b08bc78..6648c40 100644 --- a/wildedge/events/__init__.py +++ b/wildedge/events/__init__.py @@ -19,6 +19,7 @@ from wildedge.events.model_download import AdapterDownload, ModelDownloadEvent from wildedge.events.model_load import AdapterLoad, ModelLoadEvent from wildedge.events.model_unload import ModelUnloadEvent +from wildedge.events.span import SpanEvent __all__ = [ "ApiMeta", @@ -40,6 +41,7 @@ "ModelDownloadEvent", "ModelLoadEvent", "ModelUnloadEvent", + "SpanEvent", "TextInputMeta", "TopKPrediction", ] diff --git a/wildedge/events/common.py b/wildedge/events/common.py new file mode 100644 index 0000000..399b55b --- /dev/null +++ b/wildedge/events/common.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import Any + + +def add_optional_fields(event: dict, fields: dict[str, Any]) -> dict: + """Add non-None fields to an event payload.""" + for key, value in fields.items(): + if value is not None: + event[key] = value + return event diff --git a/wildedge/events/error.py b/wildedge/events/error.py index 13d29f1..e814120 100644 --- a/wildedge/events/error.py +++ b/wildedge/events/error.py @@ -23,6 +23,14 @@ class ErrorEvent: error_message: str | None = None stack_trace_hash: str | None = None related_event_id: str | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -40,10 +48,26 @@ def to_dict(self) -> dict: if self.related_event_id is not None: error_data["related_event_id"] = self.related_event_id - return { + event = { "event_id": self.event_id, "event_type": "error", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "error": error_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/feedback.py b/wildedge/events/feedback.py index 650904b..dfb6efa 100644 --- a/wildedge/events/feedback.py +++ b/wildedge/events/feedback.py @@ -24,6 +24,14 @@ class FeedbackEvent: feedback_type: str | FeedbackType delay_ms: int | None = None edit_distance: int | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -42,10 +50,26 @@ def to_dict(self) -> dict: if self.edit_distance is not None: feedback_data["edit_distance"] = self.edit_distance - return { + event = { "event_id": self.event_id, "event_type": "feedback", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "feedback": feedback_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/inference.py b/wildedge/events/inference.py index 0c02996..ef76ae5 100644 --- a/wildedge/events/inference.py +++ b/wildedge/events/inference.py @@ -304,6 +304,14 @@ class InferenceEvent: generation_config: GenerationConfig | None = None hardware: HardwareContext | None = None api_meta: ApiMeta | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) inference_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -333,10 +341,26 @@ def to_dict(self) -> dict: if self.api_meta is not None: inference_data["api_meta"] = self.api_meta.to_dict() - return { + event = { "event_id": self.event_id, "event_type": "inference", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "inference": inference_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/model_download.py b/wildedge/events/model_download.py index 784c62f..2e68a3a 100644 --- a/wildedge/events/model_download.py +++ b/wildedge/events/model_download.py @@ -49,6 +49,14 @@ class ModelDownloadEvent: cdn_edge: str | None = None error_code: str | None = None adapter: AdapterDownload | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -83,10 +91,26 @@ def to_dict(self) -> dict: if self.adapter is not None: download_data["adapter"] = self.adapter.to_dict() - return { + event = { "event_id": self.event_id, "event_type": "model_download", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "download": download_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/model_load.py b/wildedge/events/model_load.py index 9fae742..e058092 100644 --- a/wildedge/events/model_load.py +++ b/wildedge/events/model_load.py @@ -55,6 +55,14 @@ class ModelLoadEvent: cold_start: bool | None = None compile_time_ms: int | None = None adapter: AdapterLoad | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -84,10 +92,26 @@ def to_dict(self) -> dict: if self.adapter is not None: load_data["adapter"] = self.adapter.to_dict() - return { + event = { "event_id": self.event_id, "event_type": "model_load", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "load": load_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/model_unload.py b/wildedge/events/model_unload.py index 9dab481..16def90 100644 --- a/wildedge/events/model_unload.py +++ b/wildedge/events/model_unload.py @@ -14,6 +14,14 @@ class ModelUnloadEvent: memory_freed_bytes: int | None = None peak_memory_bytes: int | None = None uptime_ms: int | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -30,10 +38,26 @@ def to_dict(self) -> dict: if v is not None: unload_data[k] = v - return { + event = { "event_id": self.event_id, "event_type": "model_unload", "timestamp": self.timestamp.isoformat(), "model_id": self.model_id, "unload": unload_data, } + from wildedge.events.common import add_optional_fields + + add_optional_fields( + event, + { + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "attributes": self.context, + }, + ) + return event diff --git a/wildedge/events/span.py b/wildedge/events/span.py new file mode 100644 index 0000000..9d5be3c --- /dev/null +++ b/wildedge/events/span.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +from wildedge.events.common import add_optional_fields + +SpanKind = Literal[ + "agent_step", + "tool", + "retrieval", + "memory", + "router", + "guardrail", + "cache", + "eval", + "custom", +] +SpanStatus = Literal["ok", "error"] + + +@dataclass +class SpanEvent: + kind: SpanKind + name: str + duration_ms: int + status: SpanStatus + model_id: str | None = None + input_summary: str | None = None + output_summary: str | None = None + attributes: dict[str, Any] | None = None + trace_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + run_id: str | None = None + agent_id: str | None = None + step_index: int | None = None + conversation_id: str | None = None + context: dict[str, Any] | None = None + event_id: str = field(default_factory=lambda: str(uuid.uuid4())) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_dict(self) -> dict: + span_data: dict[str, Any] = { + "kind": self.kind, + "name": self.name, + "duration_ms": self.duration_ms, + "status": self.status, + } + if self.input_summary is not None: + span_data["input_summary"] = self.input_summary + if self.output_summary is not None: + span_data["output_summary"] = self.output_summary + if self.attributes is not None: + span_data["attributes"] = self.attributes + + event = { + "event_id": self.event_id, + "event_type": "span", + "timestamp": self.timestamp.isoformat(), + "span": span_data, + } + add_optional_fields( + event, + { + "model_id": self.model_id, + "trace_id": self.trace_id, + "span_id": self.span_id, + "parent_span_id": self.parent_span_id, + "run_id": self.run_id, + "agent_id": self.agent_id, + "step_index": self.step_index, + "conversation_id": self.conversation_id, + "context": self.context, + }, + ) + return event diff --git a/wildedge/integrations/openai.py b/wildedge/integrations/openai.py index 31281e9..700aace 100644 --- a/wildedge/integrations/openai.py +++ b/wildedge/integrations/openai.py @@ -45,13 +45,21 @@ def source_from_base_url(base_url: str | None) -> str: return SOURCE_BY_HOSTNAME.get(hostname or "", hostname or "openai") +def _msg_role(m) -> str | None: + return m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + + +def _msg_content(m) -> str | None: + return m.get("content") if isinstance(m, dict) else getattr(m, "content", None) + + def build_input_meta(messages: list, tokens_in: int | None) -> TextInputMeta | None: if not messages: return None - last_user = next((m for m in reversed(messages) if m.get("role") == "user"), None) + last_user = next((m for m in reversed(messages) if _msg_role(m) == "user"), None) if not last_user: return None - content = last_user.get("content", "") + content = _msg_content(last_user) or "" if not isinstance(content, str) or not content: return None return TextInputMeta( diff --git a/wildedge/model.py b/wildedge/model.py index de7e7b1..2502b14 100644 --- a/wildedge/model.py +++ b/wildedge/model.py @@ -28,6 +28,7 @@ from wildedge.logging import logger from wildedge.platforms import capture_hardware, is_sampling from wildedge.platforms.hardware import HardwareContext +from wildedge.tracing import _merge_correlation_fields @dataclass @@ -68,8 +69,26 @@ def track_load( accelerator: str | None = None, success: bool = True, error_code: str | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, **kwargs: Any, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ModelLoadEvent( model_id=self.model_id, duration_ms=duration_ms, @@ -77,6 +96,7 @@ def track_load( accelerator=accelerator or self.detected_accelerator, success=success, error_code=error_code, + **correlation, **kwargs, ) self.publish(event.to_dict()) @@ -89,7 +109,25 @@ def track_unload( memory_freed_bytes: int | None = None, peak_memory_bytes: int | None = None, uptime_ms: int | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ModelUnloadEvent( model_id=self.model_id, duration_ms=duration_ms, @@ -97,6 +135,7 @@ def track_unload( memory_freed_bytes=memory_freed_bytes, peak_memory_bytes=peak_memory_bytes, uptime_ms=uptime_ms, + **correlation, ) self.publish(event.to_dict()) @@ -111,8 +150,26 @@ def track_download( resumed: bool, cache_hit: bool, success: bool, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, **kwargs: Any, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ModelDownloadEvent( model_id=self.model_id, source_url=source_url, @@ -124,6 +181,7 @@ def track_download( resumed=resumed, cache_hit=cache_hit, success=success, + **correlation, **kwargs, ) self.publish(event.to_dict()) @@ -146,9 +204,27 @@ def track_inference( generation_config: GenerationConfig | None = None, hardware: HardwareContext | None = None, api_meta: ApiMeta | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> str: if hardware is None and is_sampling(): hardware = capture_hardware() + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = InferenceEvent( model_id=self.model_id, duration_ms=duration_ms, @@ -162,6 +238,7 @@ def track_inference( generation_config=generation_config, hardware=hardware, api_meta=api_meta, + **correlation, ) self.last_inference_id = event.inference_id self.publish(event.to_dict()) @@ -174,13 +251,32 @@ def track_feedback( *, delay_ms: int | None = None, edit_distance: int | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = FeedbackEvent( model_id=self.model_id, related_inference_id=related_inference_id, feedback_type=feedback_type, delay_ms=delay_ms, edit_distance=edit_distance, + **correlation, ) self.publish(event.to_dict()) @@ -205,13 +301,32 @@ def track_error( error_message: str | None = None, stack_trace_hash: str | None = None, related_event_id: str | None = None, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, ) -> None: + correlation = _merge_correlation_fields( + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + run_id=run_id, + agent_id=agent_id, + step_index=step_index, + conversation_id=conversation_id, + context=context, + ) event = ErrorEvent( model_id=self.model_id, error_code=error_code, error_message=error_message, stack_trace_hash=stack_trace_hash, related_event_id=related_event_id, + **correlation, ) self.publish(event.to_dict()) diff --git a/wildedge/tracing.py b/wildedge/tracing.py new file mode 100644 index 0000000..a205543 --- /dev/null +++ b/wildedge/tracing.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import contextlib +import contextvars +import uuid +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class TraceContext: + trace_id: str + run_id: str | None = None + agent_id: str | None = None + conversation_id: str | None = None + attributes: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class SpanContext: + span_id: str + parent_span_id: str | None = None + step_index: int | None = None + attributes: dict[str, Any] | None = None + + +_TRACE_CTX: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar( + "wildedge_trace_ctx", default=None +) +_SPAN_CTX: contextvars.ContextVar[SpanContext | None] = contextvars.ContextVar( + "wildedge_span_ctx", default=None +) + + +def get_trace_context() -> TraceContext | None: + return _TRACE_CTX.get() + + +def get_span_context() -> SpanContext | None: + return _SPAN_CTX.get() + + +def _set_trace_context(ctx: TraceContext) -> contextvars.Token: + return _TRACE_CTX.set(ctx) + + +def _reset_trace_context(token: contextvars.Token) -> None: + _TRACE_CTX.reset(token) + + +def _set_span_context(ctx: SpanContext) -> contextvars.Token: + return _SPAN_CTX.set(ctx) + + +def _reset_span_context(token: contextvars.Token) -> None: + _SPAN_CTX.reset(token) + + +@contextlib.contextmanager +def trace_context( + *, + trace_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + conversation_id: str | None = None, + attributes: dict[str, Any] | None = None, +): + if trace_id is None: + trace_id = str(uuid.uuid4()) + token = _set_trace_context( + TraceContext( + trace_id=trace_id, + run_id=run_id, + agent_id=agent_id, + conversation_id=conversation_id, + attributes=attributes, + ) + ) + try: + yield get_trace_context() + finally: + _reset_trace_context(token) + + +@contextlib.contextmanager +def span_context( + *, + span_id: str | None = None, + parent_span_id: str | None = None, + step_index: int | None = None, + attributes: dict[str, Any] | None = None, +): + """Low-level context manager that sets span correlation fields without emitting a span event. + + Prefer client.span() for most use cases. Use this only when you need correlation + fields attached to auto-instrumented events (e.g. an OpenAI call) without emitting + a redundant span wrapper. + """ + if span_id is None: + span_id = str(uuid.uuid4()) + if parent_span_id is None: + current = get_span_context() + parent_span_id = current.span_id if current else None + token = _set_span_context( + SpanContext( + span_id=span_id, + parent_span_id=parent_span_id, + step_index=step_index, + attributes=attributes, + ) + ) + try: + yield get_span_context() + finally: + _reset_span_context(token) + + +def _merge_context(*candidates: dict[str, Any] | None) -> dict[str, Any] | None: + merged: dict[str, Any] = {} + for attrs in candidates: + if not attrs: + continue + merged.update(attrs) + return merged or None + + +def _merge_correlation_fields( + *, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + run_id: str | None = None, + agent_id: str | None = None, + step_index: int | None = None, + conversation_id: str | None = None, + context: dict[str, Any] | None = None, +) -> dict[str, Any]: + trace = get_trace_context() + span = get_span_context() + + resolved_trace_id = trace_id or (trace.trace_id if trace else None) + resolved_span_id = span_id + resolved_parent_span_id = parent_span_id or (span.span_id if span else None) + resolved_run_id = run_id or (trace.run_id if trace else None) + resolved_agent_id = agent_id or (trace.agent_id if trace else None) + resolved_step_index = ( + step_index if step_index is not None else (span.step_index if span else None) + ) + resolved_conversation_id = conversation_id or ( + trace.conversation_id if trace else None + ) + resolved_context = _merge_context( + trace.attributes if trace else None, + span.attributes if span else None, + context, + ) + + return { + "trace_id": resolved_trace_id, + "span_id": resolved_span_id, + "parent_span_id": resolved_parent_span_id, + "run_id": resolved_run_id, + "agent_id": resolved_agent_id, + "step_index": resolved_step_index, + "conversation_id": resolved_conversation_id, + "context": resolved_context, + }