Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions examples/gguf_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
]

for prompt in prompts:
result = llm(prompt, max_tokens=128, temperature=0.7)
text = result["choices"][0]["text"].strip()
print(f"Q: {prompt}\nA: {text}\n")
stream = llm(prompt, max_tokens=128, temperature=0.7, stream=True)
print(f"Q: {prompt}\nA: ", end="", flush=True)
for chunk in stream:
token = chunk["choices"][0].get("text", "")
print(token, end="", flush=True)
print("\n")
10 changes: 8 additions & 2 deletions examples/openai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,19 @@
]

for prompt in prompts:
response = openai_client.chat.completions.create(
stream = openai_client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": prompt}],
temperature=0.7,
max_tokens=256,
stream=True,
stream_options={"include_usage": True},
)
print(f"Q: {prompt}\nA: {response.choices[0].message.content}\n")
print(f"Q: {prompt}\nA: ", end="", flush=True)
for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content, end="", flush=True)
print("\n")

client.flush()
print("Done. Events flushed to WildEdge.")
151 changes: 141 additions & 10 deletions tests/test_integrations_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import wildedge.integrations.openai as openai_mod
from wildedge.integrations.common import AsyncStreamWrapper, SyncStreamWrapper
from wildedge.integrations.openai import (
OpenAIExtractor,
build_api_meta,
Expand Down Expand Up @@ -83,6 +84,56 @@ async def create(self, *args, **kwargs):
return self._response


def make_stream_chunk(content=None, finish_reason=None, usage=None):
    """Build a fake streaming chunk shaped like an OpenAI ChatCompletionChunk.

    Each chunk carries a single choice with a delta payload, plus the
    top-level fields the extractor reads (usage, model, fingerprint, tier).
    """
    choice = SimpleNamespace(
        delta=SimpleNamespace(content=content),
        finish_reason=finish_reason,
    )
    return SimpleNamespace(
        choices=[choice],
        usage=usage,
        model="gpt-4o",
        system_fingerprint=None,
        service_tier=None,
    )


class FakeStreamingCompletions:
    """Fake sync completions endpoint: yields canned chunks when stream=True."""

    def __init__(self, chunks):
        self._chunks = chunks

    def create(self, *args, **kwargs):
        # Non-streaming calls fall back to a plain fake response.
        if not kwargs.get("stream"):
            return FakeResponse()
        # A fresh iterator per call, like a real SDK stream.
        return iter(self._chunks)


class FakeAsyncStreamingCompletions:
    """Fake async completions endpoint: yields canned chunks when stream=True."""

    def __init__(self, chunks):
        self._chunks = chunks

    async def create(self, *args, **kwargs):
        # Non-streaming calls fall back to a plain fake response.
        if not kwargs.get("stream"):
            return FakeResponse()
        return FakeAsyncIterator(self._chunks)


class FakeAsyncIterator:
    """Adapts a plain iterable to the async-iterator protocol for tests."""

    def __init__(self, items):
        self._iter = iter(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Translate sync exhaustion into the async protocol's sentinel.
        try:
            return next(self._iter)
        except StopIteration:
            raise StopAsyncIteration


# Named "OpenAI" / "AsyncOpenAI" so can_handle sees the right type name.
class OpenAI:
def __init__(self, base_url="https://api.openai.com/v1", api_key=None):
Expand Down Expand Up @@ -371,11 +422,65 @@ def create(self, *args, **kwargs):
client.handles["gpt-4o"].track_error.assert_called_once()
client.handles["gpt-4o"].track_inference.assert_not_called()

def test_streaming_skips_tracking(self):
completions, client = self.setup()
completions.create(model="gpt-4o", messages=[], stream=True)
if "gpt-4o" in client.handles:
client.handles["gpt-4o"].track_inference.assert_not_called()
def test_streaming_returns_sync_stream_wrapper(self):
    """stream=True calls must come back wrapped in a SyncStreamWrapper."""
    client = make_fake_client()
    fake = FakeStreamingCompletions(
        [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")]
    )
    wrap_sync_completions(fake, "openai", lambda: client)
    wrapped = fake.create(model="gpt-4o", messages=[], stream=True)
    assert isinstance(wrapped, SyncStreamWrapper)

def test_streaming_records_inference_on_exhaustion(self):
    """Draining the stream records one inference with TTFT and stop reason."""
    client = make_fake_client()
    fake = FakeStreamingCompletions(
        [make_stream_chunk("Hello", None), make_stream_chunk(" world", "stop")]
    )
    wrap_sync_completions(fake, "openai", lambda: client)
    stream = fake.create(
        model="gpt-4o", messages=[{"role": "user", "content": "hi"}], stream=True
    )
    for _ in stream:
        pass
    handle = client.handles["gpt-4o"]
    handle.track_inference.assert_called_once()
    kwargs = handle.track_inference.call_args.kwargs
    assert kwargs["output_meta"].time_to_first_token_ms is not None
    assert kwargs["output_meta"].stop_reason == "stop"
    assert kwargs["input_modality"] == "text"
    assert kwargs["success"] is True

def test_streaming_captures_usage_from_chunks(self):
    """Token counts reported in a usage chunk end up on the output meta."""
    usage = SimpleNamespace(prompt_tokens=8, completion_tokens=15)
    client = make_fake_client()
    fake = FakeStreamingCompletions(
        [make_stream_chunk("hi", None), make_stream_chunk(None, "stop", usage=usage)]
    )
    wrap_sync_completions(fake, "openai", lambda: client)
    for _ in fake.create(model="gpt-4o", messages=[], stream=True):
        pass
    meta = client.handles["gpt-4o"].track_inference.call_args.kwargs["output_meta"]
    assert meta.tokens_in == 8
    assert meta.tokens_out == 15

def test_streaming_error_during_iteration_tracks_error(self):
    """A mid-stream exception tracks an error and never records an inference."""

    def exploding_stream():
        yield make_stream_chunk("hi", None)
        raise RuntimeError("stream error")

    class ErrorStreamCompletions:
        def create(self, *args, **kwargs):
            return exploding_stream()

    client = make_fake_client()
    completions = ErrorStreamCompletions()
    wrap_sync_completions(completions, "openai", lambda: client)
    stream = completions.create(model="gpt-4o", messages=[], stream=True)
    with pytest.raises(RuntimeError, match="stream error"):
        for _ in stream:
            pass
    client.handles["gpt-4o"].track_error.assert_called_once()
    client.handles["gpt-4o"].track_inference.assert_not_called()

def test_closed_client_passes_through(self):
completions, client = self.setup(closed=True)
Expand Down Expand Up @@ -438,11 +543,37 @@ async def create(self, *args, **kwargs):

client.handles["gpt-4o"].track_error.assert_called_once()

async def test_streaming_skips_tracking(self):
completions, client = self.setup()
await completions.create(model="qwen/qwen3-235b", messages=[], stream=True)
if "qwen/qwen3-235b" in client.handles:
client.handles["qwen/qwen3-235b"].track_inference.assert_not_called()
async def test_streaming_returns_async_stream_wrapper(self):
    """Async stream=True calls must come back wrapped in an AsyncStreamWrapper."""
    client = make_fake_client()
    fake = FakeAsyncStreamingCompletions(
        [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")]
    )
    wrap_async_completions(fake, "openrouter", lambda: client)
    wrapped = await fake.create(model="qwen/qwen3-235b", messages=[], stream=True)
    assert isinstance(wrapped, AsyncStreamWrapper)

async def test_streaming_records_inference_on_exhaustion(self):
    """Draining the async stream records one inference with TTFT and stop reason."""
    client = make_fake_client()
    fake = FakeAsyncStreamingCompletions(
        [make_stream_chunk("Hello", None), make_stream_chunk(" world", "stop")]
    )
    wrap_async_completions(fake, "openrouter", lambda: client)
    stream = await fake.create(
        model="qwen/qwen3-235b",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    async for _chunk in stream:
        pass
    handle = client.handles["qwen/qwen3-235b"]
    handle.track_inference.assert_called_once()
    kwargs = handle.track_inference.call_args.kwargs
    assert kwargs["output_meta"].time_to_first_token_ms is not None
    assert kwargs["output_meta"].stop_reason == "stop"
    assert kwargs["success"] is True


# ---------------------------------------------------------------------------
Expand Down
123 changes: 122 additions & 1 deletion wildedge/integrations/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

from __future__ import annotations

from typing import Any
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from wildedge import constants
from wildedge.logging import logger
from wildedge.timing import elapsed_ms

if TYPE_CHECKING:
from wildedge.model import ModelHandle


def debug_failure(framework: str, context: str, exc: BaseException) -> None:
Expand Down Expand Up @@ -110,3 +116,118 @@ def num_classes_from_output_shape(shape: tuple) -> int:
if len(shape) >= 2 and isinstance(shape[-1], int) and shape[-1] > 1:
return int(shape[-1])
return 0


# ---------------------------------------------------------------------------
# Generic streaming wrappers
# ---------------------------------------------------------------------------
# Each integration provides:
# on_chunk(chunk) -> None : update mutable state from a single chunk
# on_done(duration_ms, ttft_ms) : record inference once the stream is exhausted
#
# The wrappers handle TTFT capture, error tracking, context-manager delegation,
# and attribute proxying so callers get a drop-in replacement for the raw stream.


class SyncStreamWrapper:
    """Wraps a sync iterable stream to capture TTFT and record inference on exhaustion.

    Drop-in replacement for the provider's raw stream object: iteration,
    context-manager entry/exit, and all other attribute access delegate to
    the wrapped stream.
    """

    def __init__(
        self,
        original: object,
        handle: ModelHandle,
        t0: float,
        on_chunk: Callable[[object], None] | None,
        on_done: Callable[[int, int | None], None],
    ) -> None:
        # original: the raw stream returned by the provider SDK (sync iterable).
        # handle:   model handle used to report errors raised mid-stream.
        # t0:       start timestamp; duration and TTFT are measured from here.
        # on_chunk: optional per-chunk state updater supplied by the integration.
        # on_done:  called with (duration_ms, ttft_ms) once the stream is exhausted.
        self._original = original
        self._handle = handle
        self._t0 = t0
        self._on_chunk = on_chunk
        self._on_done = on_done

    def __iter__(self):
        return self._track()

    def _track(self):
        # Generator that re-yields every chunk while capturing time-to-first-token
        # and forwarding each chunk to on_chunk.
        ttft_ms: int | None = None
        try:
            for chunk in self._original:  # type: ignore[union-attr]
                if ttft_ms is None:
                    # First chunk observed: record time-to-first-token.
                    ttft_ms = elapsed_ms(self._t0)
                if self._on_chunk is not None:
                    self._on_chunk(chunk)
                yield chunk
        except Exception as exc:
            # Mid-stream failure: track the error, then propagate to the caller.
            self._handle.track_error(
                error_code="UNKNOWN",
                error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN],
            )
            raise
        else:
            # on_done fires only when the stream was drained without error;
            # an abandoned (partially-consumed) stream records nothing.
            self._on_done(elapsed_ms(self._t0), ttft_ms)

    def __enter__(self) -> SyncStreamWrapper:
        # Delegate context-manager entry to the wrapped stream when it supports it,
        # but always return the wrapper so tracking stays in place.
        if hasattr(self._original, "__enter__"):
            self._original.__enter__()  # type: ignore[union-attr]
        return self

    def __exit__(self, *args: object) -> object:
        # Delegate context-manager exit to the wrapped stream when it supports it.
        if hasattr(self._original, "__exit__"):
            return self._original.__exit__(*args)  # type: ignore[union-attr]
        return None

    def __getattr__(self, name: str) -> object:
        # Proxy any other attribute lookup straight to the raw stream.
        return getattr(self._original, name)


class AsyncStreamWrapper:
    """Wraps an async iterable stream to capture TTFT and record inference on exhaustion.

    Async counterpart of SyncStreamWrapper: a drop-in replacement for the
    provider's raw async stream, delegating iteration, async context-manager
    entry/exit, and attribute access to the wrapped object.
    """

    def __init__(
        self,
        original: object,
        handle: ModelHandle,
        t0: float,
        on_chunk: Callable[[object], None] | None,
        on_done: Callable[[int, int | None], None],
    ) -> None:
        # original: the raw stream returned by the provider SDK (async iterable).
        # handle:   model handle used to report errors raised mid-stream.
        # t0:       start timestamp; duration and TTFT are measured from here.
        # on_chunk: optional per-chunk state updater supplied by the integration.
        # on_done:  called with (duration_ms, ttft_ms) once the stream is exhausted.
        self._original = original
        self._handle = handle
        self._t0 = t0
        self._on_chunk = on_chunk
        self._on_done = on_done

    def __aiter__(self):
        return self._track()

    async def _track(self):
        # Async generator that re-yields every chunk while capturing
        # time-to-first-token and forwarding each chunk to on_chunk.
        ttft_ms: int | None = None
        try:
            async for chunk in self._original:  # type: ignore[union-attr]
                if ttft_ms is None:
                    # First chunk observed: record time-to-first-token.
                    ttft_ms = elapsed_ms(self._t0)
                if self._on_chunk is not None:
                    self._on_chunk(chunk)
                yield chunk
        except Exception as exc:
            # Mid-stream failure: track the error, then propagate to the caller.
            self._handle.track_error(
                error_code="UNKNOWN",
                error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN],
            )
            raise
        else:
            # on_done fires only when the stream was drained without error;
            # an abandoned (partially-consumed) stream records nothing.
            self._on_done(elapsed_ms(self._t0), ttft_ms)

    async def __aenter__(self) -> AsyncStreamWrapper:
        # Delegate async context-manager entry to the wrapped stream when it
        # supports it, but always return the wrapper so tracking stays in place.
        if hasattr(self._original, "__aenter__"):
            await self._original.__aenter__()  # type: ignore[union-attr]
        return self

    async def __aexit__(self, *args: object) -> object:
        # Delegate async context-manager exit to the wrapped stream when it supports it.
        if hasattr(self._original, "__aexit__"):
            return await self._original.__aexit__(*args)  # type: ignore[union-attr]
        return None

    def __getattr__(self, name: str) -> object:
        # Proxy any other attribute lookup straight to the raw stream.
        return getattr(self._original, name)
Loading
Loading