Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions chatstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ class chat_server:
----------
model
OpenAI model to use. Can be a string or a function that returns a string.
azure_deployment_id
Azure deployment ID to use (optional). Azure supports the OpenAI API, but with
some slight changes. If you are using Azure, you must set this to your
deployment ID. Can be a string or a function that returns a string.
api_key
OpenAI API key to use (optional). Can be a string or a function that returns a
string, or `None`. If `None`, then it will use the `OPENAI_API_KEY` environment
Expand All @@ -124,7 +128,7 @@ class chat_server:
temperature
Temperature to use. Can be a float or a function that returns a float.
text_input_placeholder
Placeholder teext to use for the text input. Can be a string or a function that
Placeholder text to use for the text input. Can be a string or a function that
returns a string, or `None` for no placeholder.
throttle
Throttle interval to use for incoming streaming messages. Can be a float or a
Expand Down Expand Up @@ -164,6 +168,7 @@ def __init__(
session: Session,
*,
model: OpenAiModel | Callable[[], OpenAiModel] = DEFAULT_MODEL,
azure_deployment_id: str | Callable[[], str] | None = None,
api_key: str | Callable[[], str] | None = None,
url: str | Callable[[], str] | None = None,
system_prompt: str | Callable[[], str] = DEFAULT_SYSTEM_PROMPT,
Expand All @@ -189,6 +194,12 @@ def __init__(
Callable[[], OpenAiModel],
wrap_function_nonreactive(model),
)

if azure_deployment_id is None:
self.azure_deployment_id = None
else:
self.azure_deployment_id = wrap_function_nonreactive(azure_deployment_id)

if api_key is None:
self.api_key = get_env_var_api_key
else:
Expand Down Expand Up @@ -338,13 +349,19 @@ async def perform_query():
if self.url() is not None:
extra_kwargs["url"] = self.url()

if self.azure_deployment_id is not None:
# Azure-OpenAI uses deployment_id instead of model.
extra_kwargs["deployment_id"] = self.azure_deployment_id()
else:
# OpenAI just uses model.
extra_kwargs["model"] = self.model()

# Launch a Task that updates the chat string asynchronously. We run this in
# a separate task so that the data can come in without needing to await it in
# this Task (which would block other computation from happening, like running
# reactive stuff).
messages: StreamResult[ChatCompletionStreaming] = stream_to_reactive(
openai.ChatCompletion.acreate( # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues]
model=self.model(),
api_key=self.api_key(),
messages=outgoing_messages_normalized,
stream=True,
Expand Down
1 change: 1 addition & 0 deletions chatstream/openai_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"gpt-4-32k-0314",
]


openai_model_context_limits: dict[OpenAiModel, int] = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
Expand Down