From aa139be73c36ec69a8fefa2756d43265ffb4848a Mon Sep 17 00:00:00 2001 From: Martin Peck Date: Wed, 2 Oct 2024 10:39:18 +0000 Subject: [PATCH] WIP commit --- .../src/aoai_api_simulator/models.py | 19 +++-- .../record_replay/handler.py | 31 ++++---- .../record_replay/openai.py | 71 +++++++++++++++---- 3 files changed, 90 insertions(+), 31 deletions(-) diff --git a/src/aoai-api-simulator/src/aoai_api_simulator/models.py b/src/aoai-api-simulator/src/aoai_api_simulator/models.py index 8e6726f..89cb255 100644 --- a/src/aoai-api-simulator/src/aoai_api_simulator/models.py +++ b/src/aoai-api-simulator/src/aoai_api_simulator/models.py @@ -1,15 +1,25 @@ -from dataclasses import dataclass import random +from dataclasses import dataclass from typing import Annotated, Awaitable, Callable +import nanoid + # from aoai_api_simulator.pipeline import RequestContext from fastapi import Request, Response from pydantic import Field +from pydantic.functional_validators import AfterValidator from pydantic_settings import BaseSettings, SettingsConfigDict from requests import Response as requests_Response -from starlette.routing import Route, Match +from starlette.routing import Match, Route -import nanoid + +def validate_endpoint_format(url: str) -> str: + if url.endswith("/"): + url = url[:-1] + return url + + +ValidatedUrl = Annotated[str, AfterValidator(validate_endpoint_format)] class RequestContext: @@ -67,7 +77,8 @@ class RecordingConfig(BaseSettings): dir: str = Field(default=".recording", alias="RECORDING_DIR") autosave: bool = Field(default=True, alias="RECORDING_AUTOSAVE") aoai_api_key: str | None = Field(default=None, alias="AZURE_OPENAI_KEY") - aoai_api_endpoint: str | None = Field(default=None, alias="AZURE_OPENAI_ENDPOINT") + aoai_api_endpoint: ValidatedUrl | None = Field(default=None, alias="AZURE_OPENAI_ENDPOINT") + forwarders: ( list[ Callable[ diff --git a/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/handler.py b/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/handler.py index 10a4b00..3adf156 100644 --- a/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/handler.py +++ b/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/handler.py @@ -5,11 +5,10 @@ import fastapi import requests - from aoai_api_simulator import constants from aoai_api_simulator.models import RequestContext -from aoai_api_simulator.record_replay.openai import forward_to_azure_openai from aoai_api_simulator.record_replay.models import RecordedResponse, get_request_hash, hash_request_parts +from aoai_api_simulator.record_replay.openai import default_openai_forwarder, openai_image_gen_forwarder from aoai_api_simulator.record_replay.persistence import YamlRecordingPersister logger = logging.getLogger(__name__) @@ -17,18 +16,20 @@ text_content_types = ["application/json", "application/text"] -def get_default_forwarders() -> list[ - Callable[ - [RequestContext], - fastapi.Response - | Awaitable[fastapi.Response] - | requests.Response - | Awaitable[requests.Response] - | dict - | Awaitable[dict] - | None, +def get_default_forwarders() -> ( + list[ + Callable[ + [RequestContext], + fastapi.Response + | Awaitable[fastapi.Response] + | requests.Response + | Awaitable[requests.Response] + | dict + | Awaitable[dict] + | None, + ] ] -]: +): # Return a list of functions to call when recording and no matching saved request is found # # If the function returns a Response object (from FastAPI or requests package) @@ -39,7 +40,8 @@ def get_default_forwarders() -> list[ # # If the function returns None, the next function in the list will be called return [ - forward_to_azure_openai, + openai_image_gen_forwarder, + default_openai_forwarder, ] @@ -58,7 +60,6 @@ def persist_response(self) -> bool: class RecordReplayHandler: - _recordings: dict[str, dict[int, RecordedResponse]] _forwarders: list[ Callable[ diff --git a/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/openai.py b/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/openai.py index 5ff6d68..c40bcb7 100644 --- a/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/openai.py +++ b/src/aoai-api-simulator/src/aoai_api_simulator/record_replay/openai.py @@ -1,15 +1,15 @@ import json import logging -import requests -from aoai_api_simulator.models import RequestContext +import requests from aoai_api_simulator.constants import ( SIMULATOR_KEY_DEPLOYMENT_NAME, - SIMULATOR_KEY_OPENAI_PROMPT_TOKENS, + SIMULATOR_KEY_LIMITER, SIMULATOR_KEY_OPENAI_COMPLETION_TOKENS, + SIMULATOR_KEY_OPENAI_PROMPT_TOKENS, SIMULATOR_KEY_OPENAI_TOTAL_TOKENS, - SIMULATOR_KEY_LIMITER, ) +from aoai_api_simulator.models import RequestContext # This file contains a default openai forwarder # You can configure your own forwarders by creating a forwarder_config.py file and setting the @@ -85,7 +85,7 @@ def _get_token_usage_from_response(body: str) -> int | None: return None -async def forward_to_azure_openai(context: RequestContext) -> dict: +async def default_openai_forwarder(context: RequestContext) -> dict: request = context.request if not request.url.path.startswith("/openai/"): # assume not an OpenAI request @@ -95,22 +95,69 @@ async def forward_to_azure_openai(context: RequestContext) -> dict: # Only initialize once, and only if we need to _validate_endpoint_config(context) - aoai_api_endpoint = context.config.recording.aoai_api_endpoint - aoai_api_key = context.config.recording.aoai_api_key + url = context.config.recording.aoai_api_endpoint + if url.endswith("/"): + url = url[:-1] + url += request.url.path + "?" + request.url.query + + # Copy most headers, but override auth + fwd_headers = { + k: v for k, v in request.headers.items() if k.lower() not in ["content-length", "host", "authorization"] + } + fwd_headers["api-key"] = context.config.recording.aoai_api_key + + body = await request.body() + + response = requests.request( + request.method, + url, + headers=fwd_headers, + data=body, + timeout=30, + ) + + for header in aoai_response_headers_to_remove: + if response.headers.get(header): + del response.headers[header] + + if response.status_code >= 300: + # Likely an error or rate-limit + # no further processing - indicate not to persist this response + return {"response": response, "persist_response": False} - if aoai_api_key is None or aoai_api_endpoint is None: + # store values in the context for use by the rate-limiter etc + deployment_name = _get_deployment_name_from_url(request.url.path) + prompt_tokens, completion_tokens, total_tokens = _get_token_usage_from_response(response.text) + context.values[SIMULATOR_KEY_LIMITER] = "openai" + context.values[SIMULATOR_KEY_DEPLOYMENT_NAME] = deployment_name + context.values[SIMULATOR_KEY_OPENAI_PROMPT_TOKENS] = prompt_tokens + context.values[SIMULATOR_KEY_OPENAI_COMPLETION_TOKENS] = completion_tokens + context.values[SIMULATOR_KEY_OPENAI_TOTAL_TOKENS] = total_tokens + + return {"response": response, "persist_response": True} + + +async def openai_image_gen_forwarder(context: RequestContext) -> dict: + request = context.request + + if request.url.path.find("images/generations") == -1: + # this is not an image generation request return None - url = aoai_api_endpoint - if url.endswith("/"): - url = url[:-1] + if not config_validated: + # Only initialize once, and only if we need to + _validate_endpoint_config(context) + + raise NotImplementedError("Image Forwarder Not Complete.") + + url = context.config.recording.aoai_api_endpoint url += request.url.path + "?" + request.url.query # Copy most headers, but override auth fwd_headers = { k: v for k, v in request.headers.items() if k.lower() not in ["content-length", "host", "authorization"] } - fwd_headers["api-key"] = aoai_api_key + fwd_headers["api-key"] = context.config.recording.aoai_api_key body = await request.body()