diff --git a/examples/gguf_example.py b/examples/gguf_example.py index cf58f88..bd50f98 100644 --- a/examples/gguf_example.py +++ b/examples/gguf_example.py @@ -31,6 +31,9 @@ ] for prompt in prompts: - result = llm(prompt, max_tokens=128, temperature=0.7) - text = result["choices"][0]["text"].strip() - print(f"Q: {prompt}\nA: {text}\n") + stream = llm(prompt, max_tokens=128, temperature=0.7, stream=True) + print(f"Q: {prompt}\nA: ", end="", flush=True) + for chunk in stream: + token = chunk["choices"][0].get("text", "") + print(token, end="", flush=True) + print("\n") diff --git a/examples/openai_example.py b/examples/openai_example.py index d9410e4..ad2d3a9 100644 --- a/examples/openai_example.py +++ b/examples/openai_example.py @@ -32,13 +32,19 @@ ] for prompt in prompts: - response = openai_client.chat.completions.create( + stream = openai_client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": prompt}], temperature=0.7, max_tokens=256, + stream=True, + stream_options={"include_usage": True}, ) - print(f"Q: {prompt}\nA: {response.choices[0].message.content}\n") + print(f"Q: {prompt}\nA: ", end="", flush=True) + for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="", flush=True) + print("\n") client.flush() print("Done. Events flushed to WildEdge.") diff --git a/tests/test_integrations_openai.py b/tests/test_integrations_openai.py index 9fea3ce..dbba8d0 100644 --- a/tests/test_integrations_openai.py +++ b/tests/test_integrations_openai.py @@ -9,6 +9,7 @@ import pytest import wildedge.integrations.openai as openai_mod +from wildedge.integrations.common import AsyncStreamWrapper, SyncStreamWrapper from wildedge.integrations.openai import ( OpenAIExtractor, build_api_meta, @@ -83,6 +84,56 @@ async def create(self, *args, **kwargs): return self._response +def make_stream_chunk(content=None, finish_reason=None, usage=None): + chunk = SimpleNamespace( + choices=[ + SimpleNamespace( + delta=SimpleNamespace(content=content), + finish_reason=finish_reason, + ) + ], + usage=usage, + model="gpt-4o", + system_fingerprint=None, + service_tier=None, + ) + return chunk + + +class FakeStreamingCompletions: + def __init__(self, chunks): + self._chunks = chunks + + def create(self, *args, **kwargs): + if kwargs.get("stream"): + return iter(self._chunks) + return FakeResponse() + + +class FakeAsyncStreamingCompletions: + def __init__(self, chunks): + self._chunks = chunks + + async def create(self, *args, **kwargs): + if kwargs.get("stream"): + return FakeAsyncIterator(self._chunks) + return FakeResponse() + + +class FakeAsyncIterator: + def __init__(self, items): + self._iter = iter(items) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._iter) + except StopIteration: + raise StopAsyncIteration + + # Named "OpenAI" / "AsyncOpenAI" so can_handle sees the right type name. class OpenAI: def __init__(self, base_url="https://api.openai.com/v1", api_key=None): @@ -371,11 +422,65 @@ def create(self, *args, **kwargs): client.handles["gpt-4o"].track_error.assert_called_once() client.handles["gpt-4o"].track_inference.assert_not_called() - def test_streaming_skips_tracking(self): - completions, client = self.setup() - completions.create(model="gpt-4o", messages=[], stream=True) - if "gpt-4o" in client.handles: - client.handles["gpt-4o"].track_inference.assert_not_called() + def test_streaming_returns_sync_stream_wrapper(self): + chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")] + completions = FakeStreamingCompletions(chunks) + client = make_fake_client() + wrap_sync_completions(completions, "openai", lambda: client) + result = completions.create(model="gpt-4o", messages=[], stream=True) + assert isinstance(result, SyncStreamWrapper) + + def test_streaming_records_inference_on_exhaustion(self): + chunks = [ + make_stream_chunk("Hello", None), + make_stream_chunk(" world", "stop"), + ] + completions = FakeStreamingCompletions(chunks) + client = make_fake_client() + wrap_sync_completions(completions, "openai", lambda: client) + stream = completions.create( + model="gpt-4o", messages=[{"role": "user", "content": "hi"}], stream=True + ) + list(stream) + handle = client.handles["gpt-4o"] + handle.track_inference.assert_called_once() + kwargs = handle.track_inference.call_args.kwargs + assert kwargs["output_meta"].time_to_first_token_ms is not None + assert kwargs["output_meta"].stop_reason == "stop" + assert kwargs["input_modality"] == "text" + assert kwargs["success"] is True + + def test_streaming_captures_usage_from_chunks(self): + usage_chunk = SimpleNamespace(prompt_tokens=8, completion_tokens=15) + chunks = [ + make_stream_chunk("hi", None), + make_stream_chunk(None, "stop", usage=usage_chunk), + ] + completions = FakeStreamingCompletions(chunks) + client = make_fake_client() + wrap_sync_completions(completions, "openai", lambda: client) + list(completions.create(model="gpt-4o", messages=[], stream=True)) + out = client.handles["gpt-4o"].track_inference.call_args.kwargs["output_meta"] + assert out.tokens_in == 8 + assert out.tokens_out == 15 + + def test_streaming_error_during_iteration_tracks_error(self): + def bad_iter(): + yield make_stream_chunk("hi", None) + raise RuntimeError("stream error") + + class ErrorStreamCompletions: + def create(self, *args, **kwargs): + return bad_iter() + + client = make_fake_client() + completions = ErrorStreamCompletions() + wrap_sync_completions(completions, "openai", lambda: client) + stream = completions.create(model="gpt-4o", messages=[], stream=True) + with pytest.raises(RuntimeError, match="stream error"): + list(stream) + client.handles["gpt-4o"].track_error.assert_called_once() + client.handles["gpt-4o"].track_inference.assert_not_called() def test_closed_client_passes_through(self): completions, client = self.setup(closed=True) @@ -438,11 +543,37 @@ async def create(self, *args, **kwargs): client.handles["gpt-4o"].track_error.assert_called_once() - async def test_streaming_skips_tracking(self): - completions, client = self.setup() - await completions.create(model="qwen/qwen3-235b", messages=[], stream=True) - if "qwen/qwen3-235b" in client.handles: - client.handles["qwen/qwen3-235b"].track_inference.assert_not_called() + async def test_streaming_returns_async_stream_wrapper(self): + chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")] + completions = FakeAsyncStreamingCompletions(chunks) + client = make_fake_client() + wrap_async_completions(completions, "openrouter", lambda: client) + result = await completions.create( + model="qwen/qwen3-235b", messages=[], stream=True + ) + assert isinstance(result, AsyncStreamWrapper) + + async def test_streaming_records_inference_on_exhaustion(self): + chunks = [ + make_stream_chunk("Hello", None), + make_stream_chunk(" world", "stop"), + ] + completions = FakeAsyncStreamingCompletions(chunks) + client = make_fake_client() + wrap_async_completions(completions, "openrouter", lambda: client) + stream = await completions.create( + model="qwen/qwen3-235b", + messages=[{"role": "user", "content": "hi"}], + stream=True, + ) + async for _ in stream: + pass + handle = client.handles["qwen/qwen3-235b"] + handle.track_inference.assert_called_once() + kwargs = handle.track_inference.call_args.kwargs + assert kwargs["output_meta"].time_to_first_token_ms is not None + assert kwargs["output_meta"].stop_reason == "stop" + assert kwargs["success"] is True # --------------------------------------------------------------------------- diff --git a/wildedge/integrations/common.py b/wildedge/integrations/common.py index 22a431a..1506b01 100644 --- a/wildedge/integrations/common.py +++ b/wildedge/integrations/common.py @@ -2,9 +2,15 @@ from __future__ import annotations -from typing import Any +from collections.abc import Callable +from typing import TYPE_CHECKING, Any +from wildedge import constants from wildedge.logging import logger +from wildedge.timing import elapsed_ms + +if TYPE_CHECKING: + from wildedge.model import ModelHandle def debug_failure(framework: str, context: str, exc: BaseException) -> None: @@ -110,3 +116,118 @@ def num_classes_from_output_shape(shape: tuple) -> int: if len(shape) >= 2 and isinstance(shape[-1], int) and shape[-1] > 1: return int(shape[-1]) return 0 + + +# --------------------------------------------------------------------------- +# Generic streaming wrappers +# --------------------------------------------------------------------------- +# Each integration provides: +# on_chunk(chunk) -> None : update mutable state from a single chunk +# on_done(duration_ms, ttft_ms) : record inference once the stream is exhausted +# +# The wrappers handle TTFT capture, error tracking, context-manager delegation, +# and attribute proxying so callers get a drop-in replacement for the raw stream. + + +class SyncStreamWrapper: + """Wraps a sync iterable stream to capture TTFT and record inference on exhaustion.""" + + def __init__( + self, + original: object, + handle: ModelHandle, + t0: float, + on_chunk: Callable[[object], None] | None, + on_done: Callable[[int, int | None], None], + ) -> None: + self._original = original + self._handle = handle + self._t0 = t0 + self._on_chunk = on_chunk + self._on_done = on_done + + def __iter__(self): + return self._track() + + def _track(self): + ttft_ms: int | None = None + try: + for chunk in self._original: # type: ignore[union-attr] + if ttft_ms is None: + ttft_ms = elapsed_ms(self._t0) + if self._on_chunk is not None: + self._on_chunk(chunk) + yield chunk + except Exception as exc: + self._handle.track_error( + error_code="UNKNOWN", + error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + ) + raise + else: + self._on_done(elapsed_ms(self._t0), ttft_ms) + + def __enter__(self) -> SyncStreamWrapper: + if hasattr(self._original, "__enter__"): + self._original.__enter__() # type: ignore[union-attr] + return self + + def __exit__(self, *args: object) -> object: + if hasattr(self._original, "__exit__"): + return self._original.__exit__(*args) # type: ignore[union-attr] + return None + + def __getattr__(self, name: str) -> object: + return getattr(self._original, name) + + +class AsyncStreamWrapper: + """Wraps an async iterable stream to capture TTFT and record inference on exhaustion.""" + + def __init__( + self, + original: object, + handle: ModelHandle, + t0: float, + on_chunk: Callable[[object], None] | None, + on_done: Callable[[int, int | None], None], + ) -> None: + self._original = original + self._handle = handle + self._t0 = t0 + self._on_chunk = on_chunk + self._on_done = on_done + + def __aiter__(self): + return self._track() + + async def _track(self): + ttft_ms: int | None = None + try: + async for chunk in self._original: # type: ignore[union-attr] + if ttft_ms is None: + ttft_ms = elapsed_ms(self._t0) + if self._on_chunk is not None: + self._on_chunk(chunk) + yield chunk + except Exception as exc: + self._handle.track_error( + error_code="UNKNOWN", + error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + ) + raise + else: + self._on_done(elapsed_ms(self._t0), ttft_ms) + + async def __aenter__(self) -> AsyncStreamWrapper: + if hasattr(self._original, "__aenter__"): + await self._original.__aenter__() # type: ignore[union-attr] + return self + + async def __aexit__(self, *args: object) -> object: + if hasattr(self._original, "__aexit__"): + return await self._original.__aexit__(*args) # type: ignore[union-attr] + return None + + def __getattr__(self, name: str) -> object: + return getattr(self._original, name) diff --git a/wildedge/integrations/gguf.py b/wildedge/integrations/gguf.py index 9937cdd..dc44360 100644 --- a/wildedge/integrations/gguf.py +++ b/wildedge/integrations/gguf.py @@ -12,7 +12,7 @@ from wildedge import constants from wildedge.events.inference import GenerationOutputMeta, TextInputMeta from wildedge.integrations.base import BaseExtractor, patch_instance_call_once -from wildedge.integrations.common import debug_failure +from wildedge.integrations.common import SyncStreamWrapper, debug_failure from wildedge.logging import logger from wildedge.model import ModelInfo from wildedge.platforms import CURRENT_PLATFORM @@ -69,6 +69,76 @@ def parse_quantization(filename: str) -> str | None: return None +def make_gguf_input_meta(prompt: object, tokens_in: int | None) -> TextInputMeta | None: + if not isinstance(prompt, str) or not prompt: + return None + return TextInputMeta( + char_count=len(prompt), + word_count=len(prompt.split()), + token_count=tokens_in, + ) + + +def make_gguf_output_meta( + tokens_in: int | None, + tokens_out: int | None, + stop_reason: str | None, + ttft_ms: int | None, + duration_ms: int, +) -> GenerationOutputMeta | None: + if tokens_out is None and ttft_ms is None: + return None + tps = ( + round(tokens_out / duration_ms * 1000, 1) + if duration_ms > 0 and tokens_out + else None + ) + return GenerationOutputMeta( + task="generation", + tokens_in=tokens_in, + tokens_out=tokens_out, + time_to_first_token_ms=ttft_ms, + tokens_per_second=tps, + stop_reason=stop_reason, + ) + + +def make_gguf_stream_callbacks(handle: object, prompt: object) -> tuple: + """Return (on_chunk, on_done) callbacks for a llama-cpp-python streaming response. + + Chunks are dicts; usage appears in the final chunk when available. + """ + tokens_in: list[int | None] = [None] + tokens_out: list[int | None] = [None] + stop_reason: list[str | None] = [None] + + def on_chunk(chunk: object) -> None: + if not isinstance(chunk, dict): + return + usage = chunk.get("usage") + if usage: + tokens_in[0] = usage.get("prompt_tokens") + tokens_out[0] = usage.get("completion_tokens") + choices = chunk.get("choices") or [] + if choices: + reason = choices[0].get("finish_reason") + if reason: + stop_reason[0] = reason + + def on_done(duration_ms: int, ttft_ms: int | None) -> None: + ti, to, sr = tokens_in[0], tokens_out[0], stop_reason[0] + handle.track_inference( # type: ignore[union-attr] + duration_ms=duration_ms, + input_modality="text", + output_modality="generation", + input_meta=make_gguf_input_meta(prompt, ti), + success=True, + output_meta=make_gguf_output_meta(ti, to, sr, ttft_ms, duration_ms), + ) + + return on_chunk, on_done + + def build_patched_call(original_call): def patched_call(self_inner, *args, **kwargs): handle = getattr(self_inner, GGUF_HANDLE_ATTR, None) @@ -76,9 +146,13 @@ def patched_call(self_inner, *args, **kwargs): return original_call(self_inner, *args, **kwargs) prompt = args[0] if args else kwargs.get("prompt", "") + is_streaming: bool = bool(kwargs.get("stream", False)) t0 = time.perf_counter() try: result = original_call(self_inner, *args, **kwargs) + if is_streaming: + on_chunk, on_done = make_gguf_stream_callbacks(handle, prompt) + return SyncStreamWrapper(result, handle, t0, on_chunk, on_done) duration_ms = elapsed_ms(t0) tokens_in = None tokens_out = None @@ -89,36 +163,15 @@ def patched_call(self_inner, *args, **kwargs): tokens_out = usage.get("completion_tokens") except Exception as exc: debug_gguf_failure("usage extraction", exc) - - input_meta = None - if isinstance(prompt, str) and prompt: - input_meta = TextInputMeta( - char_count=len(prompt), - word_count=len(prompt.split()), - token_count=tokens_in, - ) - - output_meta = None - if tokens_out is not None: - tps = ( - round(tokens_out / duration_ms * 1000, 1) - if duration_ms > 0 - else None - ) - output_meta = GenerationOutputMeta( - task="generation", - tokens_in=tokens_in, - tokens_out=tokens_out, - tokens_per_second=tps, - ) - handle.track_inference( duration_ms=duration_ms, input_modality="text", output_modality="generation", - input_meta=input_meta, + input_meta=make_gguf_input_meta(prompt, tokens_in), success=True, - output_meta=output_meta, + output_meta=make_gguf_output_meta( + tokens_in, tokens_out, None, None, duration_ms + ), ) return result except Exception as exc: diff --git a/wildedge/integrations/openai.py b/wildedge/integrations/openai.py index e5d1525..31281e9 100644 --- a/wildedge/integrations/openai.py +++ b/wildedge/integrations/openai.py @@ -11,7 +11,11 @@ from wildedge import constants from wildedge.events.inference import ApiMeta, GenerationOutputMeta, TextInputMeta from wildedge.integrations.base import BaseExtractor -from wildedge.integrations.common import debug_failure +from wildedge.integrations.common import ( + AsyncStreamWrapper, + SyncStreamWrapper, + debug_failure, +) from wildedge.model import ModelInfo from wildedge.timing import elapsed_ms @@ -58,6 +62,28 @@ def build_input_meta(messages: list, tokens_in: int | None) -> TextInputMeta | N ) +def build_streaming_output_meta( + ttft_ms: int | None, + tokens_in: int | None, + tokens_out: int | None, + stop_reason: str | None, + duration_ms: int, +) -> GenerationOutputMeta: + tps = ( + round(tokens_out / duration_ms * 1000, 1) + if duration_ms > 0 and tokens_out + else None + ) + return GenerationOutputMeta( + task="generation", + tokens_in=tokens_in, + tokens_out=tokens_out, + time_to_first_token_ms=ttft_ms, + tokens_per_second=tps, + stop_reason=stop_reason, + ) + + def build_output_meta( response: object, duration_ms: int ) -> GenerationOutputMeta | None: @@ -145,6 +171,50 @@ def record_inference( ) +def make_openai_stream_callbacks( + handle: ModelHandle, + messages: list, +) -> tuple: + """Return (on_chunk, on_done) callbacks for an OpenAI streaming response. + + on_chunk updates mutable state from each ChatCompletionChunk. + on_done is called with (duration_ms, ttft_ms) when the stream is exhausted. + """ + tokens_in: list[int | None] = [None] + tokens_out: list[int | None] = [None] + stop_reason: list[str | None] = [None] + last_chunk: list[object] = [None] + + def on_chunk(chunk: object) -> None: + last_chunk[0] = chunk + chunk_usage = getattr(chunk, "usage", None) + if chunk_usage is not None: + tokens_in[0] = getattr(chunk_usage, "prompt_tokens", None) + tokens_out[0] = getattr(chunk_usage, "completion_tokens", None) + choices = getattr(chunk, "choices", None) or [] + if choices: + reason = getattr(choices[0], "finish_reason", None) + if reason: + stop_reason[0] = reason + + def on_done(duration_ms: int, ttft_ms: int | None) -> None: + ti, to, sr = tokens_in[0], tokens_out[0], stop_reason[0] + output_meta = build_streaming_output_meta(ttft_ms, ti, to, sr, duration_ms) + handle.track_inference( + duration_ms=duration_ms, + input_modality="text", + output_modality="generation", + success=True, + input_meta=build_input_meta(messages, ti), + output_meta=output_meta, + api_meta=build_api_meta(last_chunk[0]) + if last_chunk[0] is not None + else None, + ) + + return on_chunk, on_done + + def wrap_sync_completions(completions: object, source: str, client_ref: object) -> None: original_create = completions.create # type: ignore[attr-defined] model_handles: dict[str, ModelHandle] = {} @@ -160,8 +230,12 @@ def patched_create(*args, **kwargs): t0 = time.perf_counter() try: result = original_create(*args, **kwargs) - if not is_streaming and handle is not None: - record_inference(handle, result, messages, elapsed_ms(t0)) + if handle is not None: + if is_streaming: + on_chunk, on_done = make_openai_stream_callbacks(handle, messages) + return SyncStreamWrapper(result, handle, t0, on_chunk, on_done) + else: + record_inference(handle, result, messages, elapsed_ms(t0)) return result except Exception as exc: if handle is not None: @@ -191,8 +265,12 @@ async def patched_create(*args, **kwargs): t0 = time.perf_counter() try: result = await original_create(*args, **kwargs) - if not is_streaming and handle is not None: - record_inference(handle, result, messages, elapsed_ms(t0)) + if handle is not None: + if is_streaming: + on_chunk, on_done = make_openai_stream_callbacks(handle, messages) + return AsyncStreamWrapper(result, handle, t0, on_chunk, on_done) + else: + record_inference(handle, result, messages, elapsed_ms(t0)) return result except Exception as exc: if handle is not None: