|
9 | 9 | import pytest |
10 | 10 |
|
11 | 11 | import wildedge.integrations.openai as openai_mod |
| 12 | +from wildedge.integrations.common import AsyncStreamWrapper, SyncStreamWrapper |
12 | 13 | from wildedge.integrations.openai import ( |
13 | 14 | OpenAIExtractor, |
14 | 15 | build_api_meta, |
@@ -83,6 +84,56 @@ async def create(self, *args, **kwargs): |
83 | 84 | return self._response |
84 | 85 |
|
85 | 86 |
|
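| | +# Chunk stand-in mirroring the OpenAI SDK's ChatCompletionChunk shape: one |
| | +# choice carrying a delta and finish_reason, plus the top-level usage, |
| | +# model, system_fingerprint, and service_tier fields. |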
| 87 | +def make_stream_chunk(content=None, finish_reason=None, usage=None): |
| 88 | + chunk = SimpleNamespace( |
| 89 | + choices=[ |
| 90 | + SimpleNamespace( |
| 91 | + delta=SimpleNamespace(content=content), |
| 92 | + finish_reason=finish_reason, |
| 93 | + ) |
| 94 | + ], |
| 95 | + usage=usage, |
| 96 | + model="gpt-4o", |
| 97 | + system_fingerprint=None, |
| 98 | + service_tier=None, |
| 99 | + ) |
| 100 | + return chunk |
| 101 | + |
| 102 | + |
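| | +# Fake sync completions endpoint: create(stream=True) hands back a bare |
| | +# iterator, so any SyncStreamWrapper the tests observe must come from |
| | +# wrap_sync_completions rather than from the fake itself. |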
| 103 | +class FakeStreamingCompletions: |
| 104 | + def __init__(self, chunks): |
| 105 | + self._chunks = chunks |
| 106 | + |
| 107 | + def create(self, *args, **kwargs): |
| 108 | + if kwargs.get("stream"): |
| 109 | + return iter(self._chunks) |
| 110 | + return FakeResponse() |
| 111 | + |
| 112 | + |
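| | +# Async analogue of the fake above: mimics AsyncOpenAI, where awaiting |
| | +# create(stream=True) yields an async iterator of chunks. |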
| 113 | +class FakeAsyncStreamingCompletions: |
| 114 | + def __init__(self, chunks): |
| 115 | + self._chunks = chunks |
| 116 | + |
| 117 | + async def create(self, *args, **kwargs): |
| 118 | + if kwargs.get("stream"): |
| 119 | + return FakeAsyncIterator(self._chunks) |
| 120 | + return FakeResponse() |
| 121 | + |
| 122 | + |
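| | +# Adapts a plain iterable to the async-iterator protocol |
| | +# (__aiter__/__anext__), matching how AsyncOpenAI streams are consumed. |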
| 123 | +class FakeAsyncIterator: |
| 124 | + def __init__(self, items): |
| 125 | + self._iter = iter(items) |
| 126 | + |
| 127 | + def __aiter__(self): |
| 128 | + return self |
| 129 | + |
| 130 | + async def __anext__(self): |
| 131 | + try: |
| 132 | + return next(self._iter) |
| 133 | + except StopIteration: |
| 134 | + raise StopAsyncIteration |
| 135 | + |
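| | + |
| | +# The call flow these fakes model (a sketch, not the real wrapper API): |
| | +# |
| | +#   stream = client.chat.completions.create(model=..., messages=[...], stream=True) |
| | +#   for chunk in stream:  # the wrapper observes each chunk as it passes through |
| | +#       ... |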
| 136 | + |
86 | 137 | # Named "OpenAI" / "AsyncOpenAI" so can_handle sees the right type name. |
87 | 138 | class OpenAI: |
88 | 139 | def __init__(self, base_url="https://api.openai.com/v1", api_key=None): |
@@ -371,11 +422,65 @@ def create(self, *args, **kwargs): |
371 | 422 | client.handles["gpt-4o"].track_error.assert_called_once() |
372 | 423 | client.handles["gpt-4o"].track_inference.assert_not_called() |
373 | 424 |
|
374 | | - def test_streaming_skips_tracking(self): |
375 | | - completions, client = self.setup() |
376 | | - completions.create(model="gpt-4o", messages=[], stream=True) |
377 | | - if "gpt-4o" in client.handles: |
378 | | - client.handles["gpt-4o"].track_inference.assert_not_called() |
| 425 | + def test_streaming_returns_sync_stream_wrapper(self): |
| 426 | + chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")] |
| 427 | + completions = FakeStreamingCompletions(chunks) |
| 428 | + client = make_fake_client() |
| 429 | + wrap_sync_completions(completions, "openai", lambda: client) |
| 430 | + result = completions.create(model="gpt-4o", messages=[], stream=True) |
| 431 | + assert isinstance(result, SyncStreamWrapper) |
| 432 | + |
| 433 | + def test_streaming_records_inference_on_exhaustion(self): |
| 434 | + chunks = [ |
| 435 | + make_stream_chunk("Hello", None), |
| 436 | + make_stream_chunk(" world", "stop"), |
| 437 | + ] |
| 438 | + completions = FakeStreamingCompletions(chunks) |
| 439 | + client = make_fake_client() |
| 440 | + wrap_sync_completions(completions, "openai", lambda: client) |
| 441 | + stream = completions.create( |
| 442 | + model="gpt-4o", messages=[{"role": "user", "content": "hi"}], stream=True |
| 443 | + ) |
| 444 | +        list(stream)  # drain to exhaustion; tracking fires afterwards |
| 445 | + handle = client.handles["gpt-4o"] |
| 446 | + handle.track_inference.assert_called_once() |
| 447 | + kwargs = handle.track_inference.call_args.kwargs |
| 448 | + assert kwargs["output_meta"].time_to_first_token_ms is not None |
| 449 | + assert kwargs["output_meta"].stop_reason == "stop" |
| 450 | + assert kwargs["input_modality"] == "text" |
| 451 | + assert kwargs["success"] is True |
| 452 | + |
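| | +    # The real API reports usage on the final chunk when the request sets |
| | +    # stream_options={"include_usage": True}; the fake mimics that shape. |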
| 453 | + def test_streaming_captures_usage_from_chunks(self): |
| 454 | + usage_chunk = SimpleNamespace(prompt_tokens=8, completion_tokens=15) |
| 455 | + chunks = [ |
| 456 | + make_stream_chunk("hi", None), |
| 457 | + make_stream_chunk(None, "stop", usage=usage_chunk), |
| 458 | + ] |
| 459 | + completions = FakeStreamingCompletions(chunks) |
| 460 | + client = make_fake_client() |
| 461 | + wrap_sync_completions(completions, "openai", lambda: client) |
| 462 | + list(completions.create(model="gpt-4o", messages=[], stream=True)) |
| 463 | + out = client.handles["gpt-4o"].track_inference.call_args.kwargs["output_meta"] |
| 464 | + assert out.tokens_in == 8 |
| 465 | + assert out.tokens_out == 15 |
| 466 | + |
| 467 | + def test_streaming_error_during_iteration_tracks_error(self): |
| 468 | + def bad_iter(): |
| 469 | + yield make_stream_chunk("hi", None) |
| 470 | + raise RuntimeError("stream error") |
| 471 | + |
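| | +        # create() ignores the stream kwarg here, so the wrapper presumably |
| | +        # keys off the stream=True argument it intercepts. |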
| 472 | + class ErrorStreamCompletions: |
| 473 | + def create(self, *args, **kwargs): |
| 474 | + return bad_iter() |
| 475 | + |
| 476 | + client = make_fake_client() |
| 477 | + completions = ErrorStreamCompletions() |
| 478 | + wrap_sync_completions(completions, "openai", lambda: client) |
| 479 | + stream = completions.create(model="gpt-4o", messages=[], stream=True) |
| 480 | + with pytest.raises(RuntimeError, match="stream error"): |
| 481 | + list(stream) |
| 482 | + client.handles["gpt-4o"].track_error.assert_called_once() |
| 483 | + client.handles["gpt-4o"].track_inference.assert_not_called() |
379 | 484 |
|
380 | 485 | def test_closed_client_passes_through(self): |
381 | 486 | completions, client = self.setup(closed=True) |
@@ -438,11 +543,37 @@ async def create(self, *args, **kwargs): |
438 | 543 |
|
439 | 544 | client.handles["gpt-4o"].track_error.assert_called_once() |
440 | 545 |
|
441 | | - async def test_streaming_skips_tracking(self): |
442 | | - completions, client = self.setup() |
443 | | - await completions.create(model="qwen/qwen3-235b", messages=[], stream=True) |
444 | | - if "qwen/qwen3-235b" in client.handles: |
445 | | - client.handles["qwen/qwen3-235b"].track_inference.assert_not_called() |
| 546 | + async def test_streaming_returns_async_stream_wrapper(self): |
| 547 | + chunks = [make_stream_chunk("hi", None), make_stream_chunk(None, "stop")] |
| 548 | + completions = FakeAsyncStreamingCompletions(chunks) |
| 549 | + client = make_fake_client() |
| 550 | + wrap_async_completions(completions, "openrouter", lambda: client) |
| 551 | + result = await completions.create( |
| 552 | + model="qwen/qwen3-235b", messages=[], stream=True |
| 553 | + ) |
| 554 | + assert isinstance(result, AsyncStreamWrapper) |
| 555 | + |
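| | +    # Mirrors the sync exhaustion test: draining the wrapper with |
| | +    # `async for` is what should trigger tracking. |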
| 556 | + async def test_streaming_records_inference_on_exhaustion(self): |
| 557 | + chunks = [ |
| 558 | + make_stream_chunk("Hello", None), |
| 559 | + make_stream_chunk(" world", "stop"), |
| 560 | + ] |
| 561 | + completions = FakeAsyncStreamingCompletions(chunks) |
| 562 | + client = make_fake_client() |
| 563 | + wrap_async_completions(completions, "openrouter", lambda: client) |
| 564 | + stream = await completions.create( |
| 565 | + model="qwen/qwen3-235b", |
| 566 | + messages=[{"role": "user", "content": "hi"}], |
| 567 | + stream=True, |
| 568 | + ) |
| 569 | + async for _ in stream: |
| 570 | + pass |
| 571 | + handle = client.handles["qwen/qwen3-235b"] |
| 572 | + handle.track_inference.assert_called_once() |
| 573 | + kwargs = handle.track_inference.call_args.kwargs |
| 574 | + assert kwargs["output_meta"].time_to_first_token_ms is not None |
| 575 | + assert kwargs["output_meta"].stop_reason == "stop" |
| 576 | + assert kwargs["success"] is True |
446 | 577 |
|
447 | 578 |
|
448 | 579 | # --------------------------------------------------------------------------- |
|