From 0d04ee723439e7cf131c81624fbd540fdf81097b Mon Sep 17 00:00:00 2001 From: Azzedde Date: Mon, 12 May 2025 17:39:50 +0200 Subject: [PATCH 1/2] enhancement: Add Structured Output support --- skllm/llm/base.py | 9 +++- skllm/llm/gpt/clients/openai/completion.py | 53 +++++++++++++++++++++- tests/conftest.py | 43 ++++++++++++++++++ tests/test_structured_outputs.py | 38 ++++++++++++++++ 4 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_structured_outputs.py diff --git a/skllm/llm/base.py b/skllm/llm/base.py index 18b7edf..3a29cf5 100644 --- a/skllm/llm/base.py +++ b/skllm/llm/base.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypeVar, Type +from pydantic import BaseModel +T = TypeVar('T', bound=BaseModel) class BaseTextCompletionMixin(ABC): @abstractmethod @@ -13,6 +15,11 @@ def _convert_completion_to_str(self, completion: Any): """Converts a completion object to a string""" pass + @abstractmethod + def _get_parsed_completion(self, output_model: Type[T], **kwargs) -> T: + """Gets a chat completion parsed into the specified Pydantic model""" + pass + class BaseClassifierMixin(BaseTextCompletionMixin): @abstractmethod diff --git a/skllm/llm/gpt/clients/openai/completion.py b/skllm/llm/gpt/clients/openai/completion.py index 779c0f2..d9c80c9 100644 --- a/skllm/llm/gpt/clients/openai/completion.py +++ b/skllm/llm/gpt/clients/openai/completion.py @@ -1,4 +1,5 @@ -import openai +from typing import TypeVar, Type +from pydantic import BaseModel from openai import OpenAI from skllm.llm.gpt.clients.openai.credentials import ( set_azure_credentials, @@ -6,6 +7,8 @@ ) from skllm.utils import retry +T = TypeVar('T', bound=BaseModel) + @retry(max_retries=3) def get_chat_completion( @@ -50,3 +53,51 @@ def get_chat_completion( temperature=0.0, messages=messages, **model_dict ) return completion + + +@retry(max_retries=3) +def get_parsed_completion( + messages: 
list, + output_model: Type[T], + key: str, + org: str, + model: str = "gpt-3.5-turbo", + api="openai", +) -> T: + """Gets a chat completion parsed into the specified Pydantic model. + + Parameters + ---------- + messages : list of dict + input messages to use. + output_model : Type[T] + Pydantic model class to parse the response into. + key : str + The OpenAI key to use. + org : str + The OpenAI organization ID to use. + model : str, optional + The OpenAI model to use. Defaults to "gpt-3.5-turbo". + api : str + The API to use. Must be one of "openai" or "azure". Defaults to "openai". + + Returns + ------- + parsed_model : T + Instance of the specified Pydantic model + """ + if api in ("openai", "custom_url"): + client = set_credentials(key, org) + elif api == "azure": + client = set_azure_credentials(key, org) + else: + raise ValueError("Invalid API") + + completion = client.beta.chat.completions.parse( + model=model, + messages=messages, + response_format=output_model, + temperature=0.0 + ) + + return completion.choices[0].message.parsed diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b915165 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,43 @@ +# tests/conftest.py +import pytest +from types import SimpleNamespace +import skllm.llm.gpt.clients.openai.completion as completion_mod +from test_structured_outputs import TestEvent + +class DummyCompletions: + def __init__(self, model_cls): + self._model_cls = model_cls + + def parse(self, *, model, messages, response_format, temperature): + # response_format is the Pydantic model class (TestEvent) + fake = self._model_cls( + event_name="science fair", + date="Friday", + attendees=["Alice", "Bob"], + ) + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(parsed=fake))] + ) + + def create(self, *, temperature, messages, **kwargs): + # if you ever test get_chat_completion + return {"id": "dummy", "choices": []} + +class DummyClient: + def __init__(self, 
model_cls): + self.chat = SimpleNamespace(completions=DummyCompletions(model_cls)) + self.beta = SimpleNamespace(chat=SimpleNamespace(completions=DummyCompletions(model_cls))) + + +@pytest.fixture(autouse=True) +def patch_openai(monkeypatch): + monkeypatch.setattr( + completion_mod, + "set_credentials", + lambda key, org: DummyClient(TestEvent), + ) + monkeypatch.setattr( + completion_mod, + "set_azure_credentials", + lambda key, org: DummyClient(TestEvent), + ) diff --git a/tests/test_structured_outputs.py b/tests/test_structured_outputs.py new file mode 100644 index 0000000..0ae6810 --- /dev/null +++ b/tests/test_structured_outputs.py @@ -0,0 +1,38 @@ + +from pydantic import BaseModel + +from skllm.llm.gpt.clients.openai.completion import get_parsed_completion + +# Add __test__ = False to prevent pytest collection +class TestEvent(BaseModel): + __test__ = False + event_name: str + date: str + attendees: list[str] + +def test_openai_structured_output(): + """Test that structured outputs are properly parsed into Pydantic models.""" + messages = [ + {"role": "system", "content": "Extract event information in JSON format"}, + {"role": "user", "content": "Alice and Bob are attending the science fair on Friday"} + ] + + # Test successful parsing + result = get_parsed_completion( + messages=messages, + output_model=TestEvent, + key="dummy_value", # Replace with actual key + org="dummy_value", # Replace with actual org + model="gpt-4o-mini" + ) + + # Validate the result structure + assert isinstance(result, TestEvent) + assert isinstance(result.event_name, str) + assert len(result.event_name) > 0 + assert isinstance(result.date, str) + assert len(result.date) > 0 + assert isinstance(result.attendees, list) + assert len(result.attendees) >= 2 # Should have at least Alice and Bob + assert all(isinstance(name, str) for name in result.attendees) + From 952cc2b8f3668f0afb06cf03f80f0c418edfd069 Mon Sep 17 00:00:00 2001 From: Azzedde Date: Mon, 19 May 2025 11:30:58 +0200 
Subject: [PATCH 2/2] enhancement: convert pytest tests to unittest --- requirements-dev.txt | 1 + tests/conftest.py | 43 ------------- tests/test_structured_outputs.py | 105 ++++++++++++++++++++++--------- 3 files changed, 78 insertions(+), 71 deletions(-) delete mode 100644 tests/conftest.py diff --git a/requirements-dev.txt b/requirements-dev.txt index c2de4e7..0a0a1cf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,3 +6,4 @@ docformatter interrogate numpy pandas +pydantic diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index b915165..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,43 +0,0 @@ -# tests/conftest.py -import pytest -from types import SimpleNamespace -import skllm.llm.gpt.clients.openai.completion as completion_mod -from test_structured_outputs import TestEvent - -class DummyCompletions: - def __init__(self, model_cls): - self._model_cls = model_cls - - def parse(self, *, model, messages, response_format, temperature): - # response_format is the Pydantic model class (TestEvent) - fake = self._model_cls( - event_name="science fair", - date="Friday", - attendees=["Alice", "Bob"], - ) - return SimpleNamespace( - choices=[SimpleNamespace(message=SimpleNamespace(parsed=fake))] - ) - - def create(self, *, temperature, messages, **kwargs): - # if you ever test get_chat_completion - return {"id": "dummy", "choices": []} - -class DummyClient: - def __init__(self, model_cls): - self.chat = SimpleNamespace(completions=DummyCompletions(model_cls)) - self.beta = SimpleNamespace(chat=SimpleNamespace(completions=DummyCompletions(model_cls))) - - -@pytest.fixture(autouse=True) -def patch_openai(monkeypatch): - monkeypatch.setattr( - completion_mod, - "set_credentials", - lambda key, org: DummyClient(TestEvent), - ) - monkeypatch.setattr( - completion_mod, - "set_azure_credentials", - lambda key, org: DummyClient(TestEvent), - ) diff --git a/tests/test_structured_outputs.py b/tests/test_structured_outputs.py index 
0ae6810..f78d3dc 100644 --- a/tests/test_structured_outputs.py +++ b/tests/test_structured_outputs.py @@ -1,38 +1,87 @@ +import unittest from pydantic import BaseModel - from skllm.llm.gpt.clients.openai.completion import get_parsed_completion +import unittest +from unittest.mock import patch +from types import SimpleNamespace +import skllm.llm.gpt.clients.openai.completion as completion_mod + +class DummyCompletions: + def __init__(self, model_cls): + self._model_cls = model_cls + + def parse(self, *, model, messages, response_format, temperature): + # response_format is the Pydantic model class (TestEvent) + fake = self._model_cls( + event_name="science fair", + date="Friday", + attendees=["Alice", "Bob"], + ) + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(parsed=fake))] + ) + + def create(self, *, temperature, messages, **kwargs): + # if you ever test get_chat_completion + return {"id": "dummy", "choices": []} + +class DummyClient: + def __init__(self, model_cls): + self.chat = SimpleNamespace(completions=DummyCompletions(model_cls)) + self.beta = SimpleNamespace(chat=SimpleNamespace(completions=DummyCompletions(model_cls))) + +class OpenAITestCase(unittest.TestCase): + def setUp(self): + self.patcher1 = patch.object( + completion_mod, + "set_credentials", + lambda key, org: DummyClient(TestEvent) + ) + self.patcher2 = patch.object( + completion_mod, + "set_azure_credentials", + lambda key, org: DummyClient(TestEvent) + ) + self.patcher1.start() + self.patcher2.start() + + def tearDown(self): + self.patcher1.stop() + self.patcher2.stop() + -# Add __test__ = False to prevent pytest collection class TestEvent(BaseModel): - __test__ = False event_name: str date: str attendees: list[str] -def test_openai_structured_output(): - """Test that structured outputs are properly parsed into Pydantic models.""" - messages = [ - {"role": "system", "content": "Extract event information in JSON format"}, - {"role": "user", "content": "Alice and Bob 
are attending the science fair on Friday"} - ] - - # Test successful parsing - result = get_parsed_completion( - messages=messages, - output_model=TestEvent, - key="dummy_value", # Replace with actual key - org="dummy_value", # Replace with actual org - model="gpt-4o-mini" - ) - - # Validate the result structure - assert isinstance(result, TestEvent) - assert isinstance(result.event_name, str) - assert len(result.event_name) > 0 - assert isinstance(result.date, str) - assert len(result.date) > 0 - assert isinstance(result.attendees, list) - assert len(result.attendees) >= 2 # Should have at least Alice and Bob - assert all(isinstance(name, str) for name in result.attendees) +class TestOpenAIStructuredOutput(OpenAITestCase): + def test_openai_structured_output(self): + """Test that structured outputs are properly parsed into Pydantic models.""" + messages = [ + {"role": "system", "content": "Extract event information in JSON format"}, + {"role": "user", "content": "Alice and Bob are attending the science fair on Friday"} + ] + + # Test successful parsing + result = get_parsed_completion( + messages=messages, + output_model=TestEvent, + key="dummy_value", # dummy credential; client is mocked in setUp + org="dummy_value", # dummy credential; client is mocked in setUp + model="gpt-4o-mini" + ) + + # Validate the result structure + self.assertIsInstance(result, TestEvent) + self.assertIsInstance(result.event_name, str) + self.assertGreater(len(result.event_name), 0) + self.assertIsInstance(result.date, str) + self.assertGreater(len(result.date), 0) + self.assertIsInstance(result.attendees, list) + self.assertGreaterEqual(len(result.attendees), 2) # Should have at least Alice and Bob + self.assertTrue(all(isinstance(name, str) for name in result.attendees)) +if __name__ == '__main__': + unittest.main()