app/services/providers/openai_adapter.py (38 additions, 15 deletions)
@@ -40,9 +40,16 @@ def get_model_id(self, payload: dict[str, Any]) -> str:
logger.error(f"Model ID not found in payload for {self.provider_name}")
raise BaseInvalidRequestException(
provider_name=self.provider_name,
error=ValueError("Model ID not found in payload")
error=ValueError("Model ID not found in payload"),
)

def _ensure_list(
self, value: str | list[str] | list[int] | list[list[int]]
) -> list[str] | list[int] | list[list[int]]:
if not isinstance(value, list):
return [value]
return value

async def list_models(
self,
api_key: str,
@@ -70,19 +77,22 @@ async def list_models(
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"List Models API error for {self.provider_name}: {error_text}")
logger.error(
f"List Models API error for {self.provider_name}: {error_text}"
)
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
error_message=error_text,
)
resp = await response.json()

# Better compatibility with Forge
models_list = resp["data"] if isinstance(resp, dict) else resp

self.OPENAI_MODEL_MAPPING = {
d.get("name", self.get_model_id(d)): self.get_model_id(d) for d in models_list
d.get("name", self.get_model_id(d)): self.get_model_id(d)
for d in models_list
}
models = [self.get_model_id(d) for d in models_list]

Expand Down Expand Up @@ -122,11 +132,13 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}")
logger.error(
f"Completion Streaming API error for {self.provider_name}: {error_text}"
)
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
error_message=error_text,
)

# Stream the response back
@@ -148,11 +160,13 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"Completion API error for {self.provider_name}: {error_text}")
logger.error(
f"Completion API error for {self.provider_name}: {error_text}"
)
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
error_message=error_text,
)

return await response.json()
@@ -177,11 +191,13 @@ async def process_image_generation(
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"Image Generation API error for {self.provider_name}: {error_text}")
logger.error(
f"Image Generation API error for {self.provider_name}: {error_text}"
)
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
error_message=error_text,
)

return await response.json()
@@ -210,11 +226,11 @@ async def process_image_edits(
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
error_message=error_text,
)

return await response.json()

async def process_embeddings(
self,
endpoint: str,
@@ -230,6 +246,9 @@ async def process_embeddings(
"Content-Type": "application/json",
}

# process single and batch jobs
payload["input"] = self._ensure_list(payload["input"])

# input_type is for Cohere embeddings only
if "input_type" in payload:
del payload["input_type"]
@@ -239,15 +258,19 @@ async def process_embeddings(

async with (
aiohttp.ClientSession() as session,
session.post(url, headers=headers, json=payload, params=query_params) as response,
session.post(
url, headers=headers, json=payload, params=query_params
) as response,
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"Embeddings API error for {self.provider_name}: {error_text}")
logger.error(
f"Embeddings API error for {self.provider_name}: {error_text}"
)
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
error_message=error_text,
)

return await response.json()
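
Note on the behavioral change above: process_embeddings now routes payload["input"] through the new _ensure_list helper, so callers may pass either a single string or a batch. A minimal standalone sketch of that normalization (mirroring, not importing, the adapter code):

def ensure_list(value):
    # Wrap a bare scalar in a list; pass lists (including token arrays) through.
    if not isinstance(value, list):
        return [value]
    return value

assert ensure_list("hello") == ["hello"]                      # single string
assert ensure_list(["hello", "world"]) == ["hello", "world"]  # batch of strings
assert ensure_list([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]      # token arrays
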
app/services/providers/tensorblock_adapter.py (0 additions, 2 deletions)
@@ -3,10 +3,8 @@
from .azure_adapter import AzureAdapter

TENSORBLOCK_MODELS = [
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4o",
"gpt-4o-mini",
"o3-mini",
"text-embedding-3-large",
tests/unit_tests/assets/openai/embeddings_response.json (20 additions, 0 deletions)
@@ -0,0 +1,20 @@
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
-0.006929283495992422,
-0.005336422007530928,
-4.547132266452536e-05,
-0.024047503247857094
]
}
],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
}
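
For reference, a hedged sketch of how a caller might unpack this fixture; the field names ("data", "embedding", "usage") follow the OpenAI embeddings response shape shown above, and the path assumes the script runs from the repository root:

import json

with open("tests/unit_tests/assets/openai/embeddings_response.json") as f:
    resp = json.load(f)

vectors = [item["embedding"] for item in resp["data"]]  # one vector per input
print(resp["model"], resp["usage"]["total_tokens"], len(vectors[0]))
# -> text-embedding-ada-002 8 4
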
tests/unit_tests/test_openai_provider.py (22 additions, 0 deletions)
@@ -30,6 +30,11 @@
) as f:
MOCK_CHAT_COMPLETION_STREAMING_RESPONSE_DATA = json.load(f)

with open(
os.path.join(CURRENT_DIR, "docs", "openai", "embeddings_response.json"), "r"
) as f:
MOCK_EMBEDDINGS_RESPONSE_DATA = json.load(f)


class TestOpenAIProvider(TestCase):
def setUp(self):
@@ -108,3 +113,20 @@ async def test_chat_completion_streaming(self):
expected_model="gpt-4o-mini-2024-07-18",
expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE,
)

async def test_process_embeddings(self):
payload = {
"model": "text-embedding-ada-002",
"input": ["hello", "world"],
}
with patch("aiohttp.ClientSession", ClientSessionMock()) as mock_session:
mock_session.responses = [(MOCK_EMBEDDINGS_RESPONSE_DATA, 200)]

# Call the method
result = await self.adapter.process_embeddings(
api_key=self.api_key, payload=payload, endpoint="embeddings"
)
self.assertEqual(result, MOCK_EMBEDDINGS_RESPONSE_DATA)

# verify that the payload sent to openai has a list as input
self.assertIsInstance(mock_session.posted_json[0]["input"], list)
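
A natural companion case, sketched here as a hypothetical and not part of this PR, would assert that a bare string input is wrapped into a list before being sent to the provider (reusing only the ClientSessionMock interface exercised above):

async def test_process_embeddings_single_input(self):
    payload = {"model": "text-embedding-ada-002", "input": "hello"}
    with patch("aiohttp.ClientSession", ClientSessionMock()) as mock_session:
        mock_session.responses = [(MOCK_EMBEDDINGS_RESPONSE_DATA, 200)]
        await self.adapter.process_embeddings(
            api_key=self.api_key, payload=payload, endpoint="embeddings"
        )
        # _ensure_list should have wrapped the bare string into a one-element list
        self.assertEqual(mock_session.posted_json[0]["input"], ["hello"])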