diff --git a/chatstream/__init__.py b/chatstream/__init__.py
index d6b6806..d19f97c 100644
--- a/chatstream/__init__.py
+++ b/chatstream/__init__.py
@@ -111,6 +111,10 @@ class chat_server:
     ----------
     model
         OpenAI model to use. Can be a string or a function that returns a string.
+    azure_deployment_id
+        Azure deployment ID to use (optional). Azure supports the OpenAI API, but with
+        some slight changes. If you are using Azure, you must set this to your
+        deployment ID. Can be a string or a function that returns a string.
     api_key
         OpenAI API key to use (optional). Can be a string or a function that returns a
         string, or `None`. If `None`, then it will use the `OPENAI_API_KEY` environment
@@ -124,7 +128,7 @@ class chat_server:
     temperature
         Temperature to use. Can be a float or a function that returns a float.
     text_input_placeholder
-        Placeholder teext to use for the text input. Can be a string or a function that
+        Placeholder text to use for the text input. Can be a string or a function that
         returns a string, or `None` for no placeholder.
     throttle
         Throttle interval to use for incoming streaming messages. Can be a float or a
@@ -164,6 +168,7 @@ def __init__(
         session: Session,
         *,
         model: OpenAiModel | Callable[[], OpenAiModel] = DEFAULT_MODEL,
+        azure_deployment_id: str | Callable[[], str] | None = None,
         api_key: str | Callable[[], str] | None = None,
         url: str | Callable[[], str] | None = None,
         system_prompt: str | Callable[[], str] = DEFAULT_SYSTEM_PROMPT,
@@ -189,6 +194,12 @@ def __init__(
             Callable[[], OpenAiModel],
             wrap_function_nonreactive(model),
         )
+
+        if azure_deployment_id is None:
+            self.azure_deployment_id = None
+        else:
+            self.azure_deployment_id = wrap_function_nonreactive(azure_deployment_id)
+
         if api_key is None:
             self.api_key = get_env_var_api_key
         else:
@@ -338,13 +349,19 @@ async def perform_query():
             if self.url() is not None:
                 extra_kwargs["url"] = self.url()
 
+            if self.azure_deployment_id is not None:
+                # Azure-OpenAI uses deployment_id instead of model.
+                extra_kwargs["deployment_id"] = self.azure_deployment_id()
+            else:
+                # OpenAI just uses model.
+                extra_kwargs["model"] = self.model()
+
             # Launch a Task that updates the chat string asynchronously. We run this in
             # a separate task so that the data can come in without need to await it in
             # this Task (which would block other computation to happen, like running
             # reactive stuff).
             messages: StreamResult[ChatCompletionStreaming] = stream_to_reactive(
                 openai.ChatCompletion.acreate(  # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues]
-                    model=self.model(),
                     api_key=self.api_key(),
                     messages=outgoing_messages_normalized,
                     stream=True,
diff --git a/chatstream/openai_types.py b/chatstream/openai_types.py
index f790cf8..bca06fc 100644
--- a/chatstream/openai_types.py
+++ b/chatstream/openai_types.py
@@ -15,6 +15,7 @@
     "gpt-4-32k-0314",
 ]
 
+
 openai_model_context_limits: dict[OpenAiModel, int] = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-16k": 16384,
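
For reference, a minimal usage sketch of the new `azure_deployment_id` parameter from an application's point of view. The module ID "mychat", the deployment name, and the surrounding Shiny app scaffolding are illustrative assumptions and are not part of this diff.

# Hypothetical usage sketch (not part of this diff): pointing chatstream at an
# Azure OpenAI deployment instead of the public OpenAI API.
from shiny import App, ui
import chatstream

app_ui = ui.page_fixed(
    chatstream.chat_ui("mychat"),
)

def server(input, output, session):
    chatstream.chat_server(
        "mychat",
        azure_deployment_id="my-gpt-35-deployment",  # placeholder deployment name
    )

app = App(app_ui, server)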
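
For background on why `model=` is swapped for `deployment_id`: with the pre-1.0 `openai` Python package, Azure OpenAI requests are routed to a named deployment rather than a model name, and the client also needs the Azure endpoint, API type, and API version configured. The sketch below shows that convention directly against the `openai` package; the endpoint URL, API version, deployment name, and key are placeholders, and this client configuration is outside the scope of this diff.

# Hedged sketch of the Azure convention in the pre-1.0 openai package (not part
# of this diff). All values below are placeholders.
import openai

openai.api_type = "azure"
openai.api_base = "https://my-resource.openai.azure.com/"
openai.api_version = "2023-05-15"
openai.api_key = "..."

response = openai.ChatCompletion.create(
    deployment_id="my-gpt-35-deployment",  # Azure: deployment_id, not model=
    messages=[{"role": "user", "content": "Hello"}],
)
print(response["choices"][0]["message"]["content"])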