Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/aoai-api-simulator/src/aoai_api_simulator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@
# from aoai_api_simulator.pipeline import RequestContext
from fastapi import Request, Response
from pydantic import Field, field_validator
from pydantic.functional_validators import AfterValidator
from pydantic_settings import BaseSettings, SettingsConfigDict
from requests import Response as requests_Response
from starlette.routing import Match, Route


def validate_endpoint_format(url: str) -> str:
if url.endswith("/"):
url = url[:-1]
return url


ValidatedUrl = Annotated[str, AfterValidator(validate_endpoint_format)]


class RequestContext:
_config: "Config"
_request: Request
Expand Down Expand Up @@ -67,7 +77,8 @@ class RecordingConfig(BaseSettings):
dir: str = Field(default=".recording", alias="RECORDING_DIR")
autosave: bool = Field(default=True, alias="RECORDING_AUTOSAVE")
aoai_api_key: str | None = Field(default=None, alias="AZURE_OPENAI_KEY")
aoai_api_endpoint: str | None = Field(default=None, alias="AZURE_OPENAI_ENDPOINT")
aoai_api_endpoint: ValidatedUrl | None = Field(default=None, alias="AZURE_OPENAI_ENDPOINT")

forwarders: (
list[
Callable[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,31 @@

import fastapi
import requests

from aoai_api_simulator import constants
from aoai_api_simulator.models import RequestContext
from aoai_api_simulator.record_replay.openai import forward_to_azure_openai
from aoai_api_simulator.record_replay.models import RecordedResponse, get_request_hash, hash_request_parts
from aoai_api_simulator.record_replay.openai import default_openai_forwarder, openai_image_gen_forwarder
from aoai_api_simulator.record_replay.persistence import YamlRecordingPersister

logger = logging.getLogger(__name__)

text_content_types = ["application/json", "application/text"]


def get_default_forwarders() -> list[
Callable[
[RequestContext],
fastapi.Response
| Awaitable[fastapi.Response]
| requests.Response
| Awaitable[requests.Response]
| dict
| Awaitable[dict]
| None,
def get_default_forwarders() -> (
list[
Callable[
[RequestContext],
fastapi.Response
| Awaitable[fastapi.Response]
| requests.Response
| Awaitable[requests.Response]
| dict
| Awaitable[dict]
| None,
]
]
]:
):
# Return a list of functions to call when recording and no matching saved request is found
#
# If the function returns a Response object (from FastAPI or requests package)
Expand All @@ -39,7 +40,8 @@ def get_default_forwarders() -> list[
#
# If the function returns None, the next function in the list will be called
return [
forward_to_azure_openai,
openai_image_gen_forwarder,
default_openai_forwarder,
]


Expand All @@ -58,7 +60,6 @@ def persist_response(self) -> bool:


class RecordReplayHandler:

_recordings: dict[str, dict[int, RecordedResponse]]
_forwarders: list[
Callable[
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import json
import logging
import requests

from aoai_api_simulator.models import RequestContext
import requests
from aoai_api_simulator.constants import (
SIMULATOR_KEY_DEPLOYMENT_NAME,
SIMULATOR_KEY_OPENAI_PROMPT_TOKENS,
SIMULATOR_KEY_LIMITER,
SIMULATOR_KEY_OPENAI_COMPLETION_TOKENS,
SIMULATOR_KEY_OPENAI_PROMPT_TOKENS,
SIMULATOR_KEY_OPENAI_TOTAL_TOKENS,
SIMULATOR_KEY_LIMITER,
)
from aoai_api_simulator.models import RequestContext

# This file contains a default openai forwarder
# You can configure your own forwarders by creating a forwarder_config.py file and setting the
Expand Down Expand Up @@ -85,7 +85,7 @@ def _get_token_usage_from_response(body: str) -> int | None:
return None


async def forward_to_azure_openai(context: RequestContext) -> dict:
async def default_openai_forwarder(context: RequestContext) -> dict:
request = context.request
if not request.url.path.startswith("/openai/"):
# assume not an OpenAI request
Expand All @@ -95,22 +95,69 @@ async def forward_to_azure_openai(context: RequestContext) -> dict:
# Only initialize once, and only if we need to
_validate_endpoint_config(context)

aoai_api_endpoint = context.config.recording.aoai_api_endpoint
aoai_api_key = context.config.recording.aoai_api_key
url = context.config.recording.aoai_api_endpoint
if url.endswith("/"):
url = url[:-1]
url += request.url.path + "?" + request.url.query

# Copy most headers, but override auth
fwd_headers = {
k: v for k, v in request.headers.items() if k.lower() not in ["content-length", "host", "authorization"]
}
fwd_headers["api-key"] = context.config.recording.aoai_api_key

body = await request.body()

response = requests.request(
request.method,
url,
headers=fwd_headers,
data=body,
timeout=30,
)

for header in aoai_response_headers_to_remove:
if response.headers.get(header):
del response.headers[header]

if response.status_code >= 300:
# Likely an error or rate-limit
# no further processing - indicate not to persist this response
return {"response": response, "persist_response": False}

if aoai_api_key is None or aoai_api_endpoint is None:
# store values in the context for use by the rate-limiter etc
deployment_name = _get_deployment_name_from_url(request.url.path)
prompt_tokens, completion_tokens, total_tokens = _get_token_usage_from_response(response.text)
context.values[SIMULATOR_KEY_LIMITER] = "openai"
context.values[SIMULATOR_KEY_DEPLOYMENT_NAME] = deployment_name
context.values[SIMULATOR_KEY_OPENAI_PROMPT_TOKENS] = prompt_tokens
context.values[SIMULATOR_KEY_OPENAI_COMPLETION_TOKENS] = completion_tokens
context.values[SIMULATOR_KEY_OPENAI_TOTAL_TOKENS] = total_tokens

return {"response": response, "persist_response": True}


async def openai_image_gen_forwarder(context: RequestContext) -> dict:
request = context.request

if request.url.path.find("images/generations") == -1:
# this is not an image generation request
return None

url = aoai_api_endpoint
if url.endswith("/"):
url = url[:-1]
if not config_validated:
# Only initialize once, and only if we need to
_validate_endpoint_config(context)

raise NotImplementedError("Image Forwarder Not Complete.")

url = context.config.recording.aoai_api_endpoint
url += request.url.path + "?" + request.url.query

# Copy most headers, but override auth
fwd_headers = {
k: v for k, v in request.headers.items() if k.lower() not in ["content-length", "host", "authorization"]
}
fwd_headers["api-key"] = aoai_api_key
fwd_headers["api-key"] = context.config.recording.aoai_api_key

body = await request.body()

Expand Down