Skip to content

Commit 79de451

Browse files
authored
TTFT support for remote llm + GGUF integrations (#32)
1 parent 4c80592 commit 79de451

File tree

6 files changed

+439
-47
lines changed

6 files changed

+439
-47
lines changed

examples/gguf_example.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
]
3232

3333
for prompt in prompts:
34-
result = llm(prompt, max_tokens=128, temperature=0.7)
35-
text = result["choices"][0]["text"].strip()
36-
print(f"Q: {prompt}\nA: {text}\n")
34+
stream = llm(prompt, max_tokens=128, temperature=0.7, stream=True)
35+
print(f"Q: {prompt}\nA: ", end="", flush=True)
36+
for chunk in stream:
37+
token = chunk["choices"][0].get("text", "")
38+
print(token, end="", flush=True)
39+
print("\n")

examples/openai_example.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,19 @@
3232
]
3333

3434
for prompt in prompts:
35-
response = openai_client.chat.completions.create(
35+
stream = openai_client.chat.completions.create(
3636
model="gpt-4o",
3737
messages=[{"role": "user", "content": prompt}],
3838
temperature=0.7,
3939
max_tokens=256,
40+
stream=True,
41+
stream_options={"include_usage": True},
4042
)
41-
print(f"Q: {prompt}\nA: {response.choices[0].message.content}\n")
43+
print(f"Q: {prompt}\nA: ", end="", flush=True)
44+
for chunk in stream:
45+
if chunk.choices and chunk.choices[0].delta.content:
46+
print(chunk.choices[0].delta.content, end="", flush=True)
47+
print("\n")
4248

4349
client.flush()
4450
print("Done. Events flushed to WildEdge.")

tests/test_integrations_openai.py

Lines changed: 141 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
import wildedge.integrations.openai as openai_mod
12+
from wildedge.integrations.common import AsyncStreamWrapper, SyncStreamWrapper
1213
from wildedge.integrations.openai import (
1314
OpenAIExtractor,
1415
build_api_meta,
@@ -83,6 +84,56 @@ async def create(self, *args, **kwargs):
8384
return self._response
8485

8586

87+
def make_stream_chunk(content=None, finish_reason=None, usage=None):
88+
chunk = SimpleNamespace(
89+
choices=[
90+
SimpleNamespace(
91+
delta=SimpleNamespace(content=content),
92+
finish_reason=finish_reason,
93+
)
94+
],
95+
usage=usage,
96+
model="gpt-4o",
97+
system_fingerprint=None,
98+
service_tier=None,
99+
)
100+
return chunk
101+
102+
103+
class FakeStreamingCompletions:
104+
def __init__(self, chunks):
105+
self._chunks = chunks
106+
107+
def create(self, *args, **kwargs):
108+
if kwargs.get("stream"):
109+
return iter(self._chunks)
110+
return FakeResponse()
111+
112+
113+
class FakeAsyncStreamingCompletions:
114+
def __init__(self, chunks):
115+
self._chunks = chunks
116+
117+
async def create(self, *args, **kwargs):
118+
if kwargs.get("stream"):
119+
return FakeAsyncIterator(self._chunks)
120+
return FakeResponse()
121+
122+
123+
class FakeAsyncIterator:
124+
def __init__(self, items):
125+
self._iter = iter(items)
126+
127+
def __aiter__(self):
128+
return self
129+
130+
async def __anext__(self):
131+
try:
132+
return next(self._iter)
133+
except StopIteration:
134+
raise StopAsyncIteration
135+
136+
86137
# Named "OpenAI" / "AsyncOpenAI" so can_handle sees the right type name.
87138
class OpenAI:
88139
def __init__(self, base_url="https://api.openai.com/v1", api_key=None):
@@ -371,11 +422,65 @@ def create(self, *args, **kwargs):
371422
client.handles["gpt-4o"].track_error.assert_called_once()
372423
client.handles["gpt-4o"].track_inference.assert_not_called()
373424

374-
def test_streaming_skips_tracking(self):
375-
completions, client = self.setup()
376-
completions.create(model="gpt-4o", messages=[], stream=True)
377-
if "gpt-4o" in client.handles:
378-
client.handles["gpt-4o"].track_inference.assert_not_called()
425+
def test_streaming_returns_sync_stream_wrapper(self):
426+
chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")]
427+
completions = FakeStreamingCompletions(chunks)
428+
client = make_fake_client()
429+
wrap_sync_completions(completions, "openai", lambda: client)
430+
result = completions.create(model="gpt-4o", messages=[], stream=True)
431+
assert isinstance(result, SyncStreamWrapper)
432+
433+
def test_streaming_records_inference_on_exhaustion(self):
434+
chunks = [
435+
make_stream_chunk("Hello", None),
436+
make_stream_chunk(" world", "stop"),
437+
]
438+
completions = FakeStreamingCompletions(chunks)
439+
client = make_fake_client()
440+
wrap_sync_completions(completions, "openai", lambda: client)
441+
stream = completions.create(
442+
model="gpt-4o", messages=[{"role": "user", "content": "hi"}], stream=True
443+
)
444+
list(stream)
445+
handle = client.handles["gpt-4o"]
446+
handle.track_inference.assert_called_once()
447+
kwargs = handle.track_inference.call_args.kwargs
448+
assert kwargs["output_meta"].time_to_first_token_ms is not None
449+
assert kwargs["output_meta"].stop_reason == "stop"
450+
assert kwargs["input_modality"] == "text"
451+
assert kwargs["success"] is True
452+
453+
def test_streaming_captures_usage_from_chunks(self):
454+
usage_chunk = SimpleNamespace(prompt_tokens=8, completion_tokens=15)
455+
chunks = [
456+
make_stream_chunk("hi", None),
457+
make_stream_chunk(None, "stop", usage=usage_chunk),
458+
]
459+
completions = FakeStreamingCompletions(chunks)
460+
client = make_fake_client()
461+
wrap_sync_completions(completions, "openai", lambda: client)
462+
list(completions.create(model="gpt-4o", messages=[], stream=True))
463+
out = client.handles["gpt-4o"].track_inference.call_args.kwargs["output_meta"]
464+
assert out.tokens_in == 8
465+
assert out.tokens_out == 15
466+
467+
def test_streaming_error_during_iteration_tracks_error(self):
468+
def bad_iter():
469+
yield make_stream_chunk("hi", None)
470+
raise RuntimeError("stream error")
471+
472+
class ErrorStreamCompletions:
473+
def create(self, *args, **kwargs):
474+
return bad_iter()
475+
476+
client = make_fake_client()
477+
completions = ErrorStreamCompletions()
478+
wrap_sync_completions(completions, "openai", lambda: client)
479+
stream = completions.create(model="gpt-4o", messages=[], stream=True)
480+
with pytest.raises(RuntimeError, match="stream error"):
481+
list(stream)
482+
client.handles["gpt-4o"].track_error.assert_called_once()
483+
client.handles["gpt-4o"].track_inference.assert_not_called()
379484

380485
def test_closed_client_passes_through(self):
381486
completions, client = self.setup(closed=True)
@@ -438,11 +543,37 @@ async def create(self, *args, **kwargs):
438543

439544
client.handles["gpt-4o"].track_error.assert_called_once()
440545

441-
async def test_streaming_skips_tracking(self):
442-
completions, client = self.setup()
443-
await completions.create(model="qwen/qwen3-235b", messages=[], stream=True)
444-
if "qwen/qwen3-235b" in client.handles:
445-
client.handles["qwen/qwen3-235b"].track_inference.assert_not_called()
546+
async def test_streaming_returns_async_stream_wrapper(self):
547+
chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")]
548+
completions = FakeAsyncStreamingCompletions(chunks)
549+
client = make_fake_client()
550+
wrap_async_completions(completions, "openrouter", lambda: client)
551+
result = await completions.create(
552+
model="qwen/qwen3-235b", messages=[], stream=True
553+
)
554+
assert isinstance(result, AsyncStreamWrapper)
555+
556+
async def test_streaming_records_inference_on_exhaustion(self):
557+
chunks = [
558+
make_stream_chunk("Hello", None),
559+
make_stream_chunk(" world", "stop"),
560+
]
561+
completions = FakeAsyncStreamingCompletions(chunks)
562+
client = make_fake_client()
563+
wrap_async_completions(completions, "openrouter", lambda: client)
564+
stream = await completions.create(
565+
model="qwen/qwen3-235b",
566+
messages=[{"role": "user", "content": "hi"}],
567+
stream=True,
568+
)
569+
async for _ in stream:
570+
pass
571+
handle = client.handles["qwen/qwen3-235b"]
572+
handle.track_inference.assert_called_once()
573+
kwargs = handle.track_inference.call_args.kwargs
574+
assert kwargs["output_meta"].time_to_first_token_ms is not None
575+
assert kwargs["output_meta"].stop_reason == "stop"
576+
assert kwargs["success"] is True
446577

447578

448579
# ---------------------------------------------------------------------------

wildedge/integrations/common.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22

33
from __future__ import annotations
44

5-
from typing import Any
5+
from collections.abc import Callable
6+
from typing import TYPE_CHECKING, Any
67

8+
from wildedge import constants
79
from wildedge.logging import logger
10+
from wildedge.timing import elapsed_ms
11+
12+
if TYPE_CHECKING:
13+
from wildedge.model import ModelHandle
814

915

1016
def debug_failure(framework: str, context: str, exc: BaseException) -> None:
@@ -110,3 +116,118 @@ def num_classes_from_output_shape(shape: tuple) -> int:
110116
if len(shape) >= 2 and isinstance(shape[-1], int) and shape[-1] > 1:
111117
return int(shape[-1])
112118
return 0
119+
120+
121+
# ---------------------------------------------------------------------------
122+
# Generic streaming wrappers
123+
# ---------------------------------------------------------------------------
124+
# Each integration provides:
125+
# on_chunk(chunk) -> None : update mutable state from a single chunk
126+
# on_done(duration_ms, ttft_ms) : record inference once the stream is exhausted
127+
#
128+
# The wrappers handle TTFT capture, error tracking, context-manager delegation,
129+
# and attribute proxying so callers get a drop-in replacement for the raw stream.
130+
131+
132+
class SyncStreamWrapper:
133+
"""Wraps a sync iterable stream to capture TTFT and record inference on exhaustion."""
134+
135+
def __init__(
136+
self,
137+
original: object,
138+
handle: ModelHandle,
139+
t0: float,
140+
on_chunk: Callable[[object], None] | None,
141+
on_done: Callable[[int, int | None], None],
142+
) -> None:
143+
self._original = original
144+
self._handle = handle
145+
self._t0 = t0
146+
self._on_chunk = on_chunk
147+
self._on_done = on_done
148+
149+
def __iter__(self):
150+
return self._track()
151+
152+
def _track(self):
153+
ttft_ms: int | None = None
154+
try:
155+
for chunk in self._original: # type: ignore[union-attr]
156+
if ttft_ms is None:
157+
ttft_ms = elapsed_ms(self._t0)
158+
if self._on_chunk is not None:
159+
self._on_chunk(chunk)
160+
yield chunk
161+
except Exception as exc:
162+
self._handle.track_error(
163+
error_code="UNKNOWN",
164+
error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN],
165+
)
166+
raise
167+
else:
168+
self._on_done(elapsed_ms(self._t0), ttft_ms)
169+
170+
def __enter__(self) -> SyncStreamWrapper:
171+
if hasattr(self._original, "__enter__"):
172+
self._original.__enter__() # type: ignore[union-attr]
173+
return self
174+
175+
def __exit__(self, *args: object) -> object:
176+
if hasattr(self._original, "__exit__"):
177+
return self._original.__exit__(*args) # type: ignore[union-attr]
178+
return None
179+
180+
def __getattr__(self, name: str) -> object:
181+
return getattr(self._original, name)
182+
183+
184+
class AsyncStreamWrapper:
185+
"""Wraps an async iterable stream to capture TTFT and record inference on exhaustion."""
186+
187+
def __init__(
188+
self,
189+
original: object,
190+
handle: ModelHandle,
191+
t0: float,
192+
on_chunk: Callable[[object], None] | None,
193+
on_done: Callable[[int, int | None], None],
194+
) -> None:
195+
self._original = original
196+
self._handle = handle
197+
self._t0 = t0
198+
self._on_chunk = on_chunk
199+
self._on_done = on_done
200+
201+
def __aiter__(self):
202+
return self._track()
203+
204+
async def _track(self):
205+
ttft_ms: int | None = None
206+
try:
207+
async for chunk in self._original: # type: ignore[union-attr]
208+
if ttft_ms is None:
209+
ttft_ms = elapsed_ms(self._t0)
210+
if self._on_chunk is not None:
211+
self._on_chunk(chunk)
212+
yield chunk
213+
except Exception as exc:
214+
self._handle.track_error(
215+
error_code="UNKNOWN",
216+
error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN],
217+
)
218+
raise
219+
else:
220+
self._on_done(elapsed_ms(self._t0), ttft_ms)
221+
222+
async def __aenter__(self) -> AsyncStreamWrapper:
223+
if hasattr(self._original, "__aenter__"):
224+
await self._original.__aenter__() # type: ignore[union-attr]
225+
return self
226+
227+
async def __aexit__(self, *args: object) -> object:
228+
if hasattr(self._original, "__aexit__"):
229+
return await self._original.__aexit__(*args) # type: ignore[union-attr]
230+
return None
231+
232+
def __getattr__(self, name: str) -> object:
233+
return getattr(self._original, name)

0 commit comments

Comments (0)