From 9df44512304949e6193e7ff33390342e26d065c6 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 10 Oct 2025 08:49:20 +0000 Subject: [PATCH 01/14] chore(internal): detect missing future annotations with ruff --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1db28690..35fb701e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,6 +228,8 @@ select = [ "B", # remove unused imports "F401", + # check for missing future annotations + "FA102", # bare except statements "E722", # unused arguments @@ -250,6 +252,8 @@ unfixable = [ "T203", ] +extend-safe-fixes = ["FA102"] + [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" From f2ef07dbe6ffd744bf58a6c7b5f3dac8b73a8805 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 17 Oct 2025 10:21:19 +0000 Subject: [PATCH 02/14] chore: bump `httpx-aiohttp` version to 0.1.9 --- pyproject.toml | 2 +- requirements-dev.lock | 2 +- requirements.lock | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 35fb701e..eb115f0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ Homepage = "https://github.com/writer/writer-python" Repository = "https://github.com/writer/writer-python" [project.optional-dependencies] -aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] [tool.rye] managed = true diff --git a/requirements-dev.lock b/requirements-dev.lock index ad44866d..d8160bd5 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -60,7 +60,7 @@ httpx==0.28.1 # via httpx-aiohttp # via respx # via writer-sdk -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via writer-sdk idna==3.4 # via anyio diff --git a/requirements.lock b/requirements.lock index 466a8f8d..0b87d0fd 100644 --- a/requirements.lock +++ b/requirements.lock @@ -43,7 +43,7 @@ httpcore==1.0.9 httpx==0.28.1 # via httpx-aiohttp # via writer-sdk -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via writer-sdk idna==3.4 # via anyio From 31e39034cab026c34c9509757a27d9e2221c0c5b Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 29 Oct 2025 10:44:16 +0000 Subject: [PATCH 03/14] fix(client): close streams without requiring full consumption --- src/writerai/_streaming.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/writerai/_streaming.py b/src/writerai/_streaming.py index ca4711bc..34bb9d92 100644 --- a/src/writerai/_streaming.py +++ b/src/writerai/_streaming.py @@ -76,9 +76,8 @@ def __stream__(self) -> Iterator[_T]: response=self.response, ) - # Ensure the entire stream is consumed - for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + response.close() def __enter__(self) -> Self: return self @@ -159,9 +158,8 @@ async def __stream__(self) -> AsyncIterator[_T]: response=self.response, ) - # Ensure the entire stream is consumed - async for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + await response.aclose() async def __aenter__(self) -> Self: return self From 828ac4d2a57d4f623d4fe2aef25390c5f0051b96 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 11:04:11 +0000 Subject: [PATCH 04/14] chore(internal/tests): avoid race condition with implicit client cleanup --- tests/test_client.py | 371 +++++++++++++++++++++++-------------------- 1 file changed, 202 insertions(+), 169 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index c8a764ce..4066755f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -60,51 +60,49 @@ def _get_open_connections(client: Writer | AsyncWriter) -> int: class TestWriter: - client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - def test_raw_response(self, respx_mock: MockRouter) -> None: + def test_raw_response(self, respx_mock: MockRouter, client: Writer) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Writer) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, client: Writer) -> None: + copied = client.copy() + assert id(copied) != id(client) - copied = self.client.copy(api_key="another My API Key") + copied = client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, client: Writer) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: client = Writer( @@ -139,6 +137,7 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() def test_copy_default_query(self) -> None: client = Writer( @@ -176,13 +175,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + client.close() + + def test_copy_signature(self, client: Writer) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -193,12 +194,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, client: Writer) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -255,14 +256,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + def test_request_timeout(self, client: Writer) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(100.0) @@ -273,6 +272,8 @@ def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + client.close() + def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: @@ -284,6 +285,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + client.close() + # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: client = Writer( @@ -294,6 +297,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + client.close() + # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = Writer( @@ -304,6 +309,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + client.close() + async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: @@ -315,14 +322,14 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = Writer( + test_client = Writer( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = Writer( + test_client2 = Writer( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -331,10 +338,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + test_client.close() + test_client2.close() + def test_validate_headers(self) -> None: client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -363,8 +373,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + client.close() + + def test_request_extra_json(self, client: Writer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -375,7 +387,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -386,7 +398,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -397,8 +409,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Writer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -408,7 +420,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -419,8 +431,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Writer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -433,7 +445,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -447,7 +459,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -490,7 +502,7 @@ def test_multipart_repeating_array(self, client: Writer) -> None: ] @pytest.mark.respx(base_url=base_url) - def test_basic_union_response(self, respx_mock: MockRouter) -> None: + def test_basic_union_response(self, respx_mock: MockRouter, client: Writer) -> None: class Model1(BaseModel): name: str @@ -499,12 +511,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + def test_union_response_different_types(self, respx_mock: MockRouter, client: Writer) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -515,18 +527,18 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Writer) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -542,7 +554,7 @@ class Model(BaseModel): ) ) - response = self.client.get("/foo", cast_to=Model) + response = client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 @@ -554,6 +566,8 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" + client.close() + def test_base_url_env(self) -> None: with update_env(WRITER_BASE_URL="http://localhost:5000/from/env"): client = Writer(api_key=api_key, _strict_response_validation=True) @@ -581,6 +595,7 @@ def test_base_url_trailing_slash(self, client: Writer) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -604,6 +619,7 @@ def test_base_url_no_trailing_slash(self, client: Writer) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -627,35 +643,36 @@ def test_absolute_request_url(self, client: Writer) -> None: ), ) assert request.url == "https://myapi.com/foo" + client.close() def test_copied_client_does_not_close_http(self) -> None: - client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied - assert not client.is_closed() + assert not test_client.is_closed() def test_client_context_manager(self) -> None: - client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - with client as c2: - assert c2 is client + test_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + def test_client_response_validation_error(self, respx_mock: MockRouter, client: Writer) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - self.client.get("/foo", cast_to=Model) + client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -664,13 +681,13 @@ def test_client_max_retries_validation(self) -> None: Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) @pytest.mark.respx(base_url=base_url) - def test_default_stream_cls(self, respx_mock: MockRouter) -> None: + def test_default_stream_cls(self, respx_mock: MockRouter, client: Writer) -> None: class Model(BaseModel): name: str respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) + stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) assert isinstance(stream, Stream) stream.response.close() @@ -686,11 +703,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = client.get("/foo", cast_to=Model) + response = non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + strict_client.close() + non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -713,9 +733,9 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: Writer + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) @@ -729,7 +749,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien with pytest.raises(APITimeoutError): client.chat.with_streaming_response.chat(messages=[{"role": "user"}], model="model").__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -738,7 +758,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client with pytest.raises(APIStatusError): client.chat.with_streaming_response.chat(messages=[{"role": "user"}], model="model").__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -844,83 +864,77 @@ def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - def test_follow_redirects(self, respx_mock: MockRouter) -> None: + def test_follow_redirects(self, respx_mock: MockRouter, client: Writer) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Writer) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - self.client.post( - "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response - ) + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" class TestAsyncWriter: - client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response(self, respx_mock: MockRouter) -> None: + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, async_client: AsyncWriter) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) - copied = self.client.copy(api_key="another My API Key") + copied = async_client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert async_client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, async_client: AsyncWriter) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = async_client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert async_client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(async_client.timeout, httpx.Timeout) - def test_copy_default_headers(self) -> None: + async def test_copy_default_headers(self) -> None: client = AsyncWriter( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) @@ -953,8 +967,9 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() - def test_copy_default_query(self) -> None: + async def test_copy_default_query(self) -> None: client = AsyncWriter( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} ) @@ -990,13 +1005,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + await client.close() + + def test_copy_signature(self, async_client: AsyncWriter) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + async_client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(async_client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -1007,12 +1024,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, async_client: AsyncWriter) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = async_client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -1069,12 +1086,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - async def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + async def test_request_timeout(self, async_client: AsyncWriter) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( + request = async_client._build_request( FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) ) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1089,6 +1106,8 @@ async def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + await client.close() + async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: @@ -1100,6 +1119,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + await client.close() + # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: client = AsyncWriter( @@ -1110,6 +1131,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + await client.close() + # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = AsyncWriter( @@ -1120,6 +1143,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + await client.close() + def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: @@ -1130,15 +1155,15 @@ def test_invalid_http_client(self) -> None: http_client=cast(Any, http_client), ) - def test_default_headers_option(self) -> None: - client = AsyncWriter( + async def test_default_headers_option(self) -> None: + test_client = AsyncWriter( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncWriter( + test_client2 = AsyncWriter( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -1147,10 +1172,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + await test_client.close() + await test_client2.close() + def test_validate_headers(self) -> None: client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -1161,7 +1189,7 @@ def test_validate_headers(self) -> None: client2 = AsyncWriter(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 - def test_default_query_option(self) -> None: + async def test_default_query_option(self) -> None: client = AsyncWriter( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} ) @@ -1179,8 +1207,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + await client.close() + + def test_request_extra_json(self, client: Writer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1191,7 +1221,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1202,7 +1232,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1213,8 +1243,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Writer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1224,7 +1254,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1235,8 +1265,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Writer) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1249,7 +1279,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1263,7 +1293,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1306,7 +1336,7 @@ def test_multipart_repeating_array(self, async_client: AsyncWriter) -> None: ] @pytest.mark.respx(base_url=base_url) - async def test_basic_union_response(self, respx_mock: MockRouter) -> None: + async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: class Model1(BaseModel): name: str @@ -1315,12 +1345,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - async def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -1331,18 +1361,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncWriter + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1358,11 +1390,11 @@ class Model(BaseModel): ) ) - response = await self.client.get("/foo", cast_to=Model) + response = await async_client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 - def test_base_url_setter(self) -> None: + async def test_base_url_setter(self) -> None: client = AsyncWriter( base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True ) @@ -1372,7 +1404,9 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" - def test_base_url_env(self) -> None: + await client.close() + + async def test_base_url_env(self) -> None: with update_env(WRITER_BASE_URL="http://localhost:5000/from/env"): client = AsyncWriter(api_key=api_key, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1392,7 +1426,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncWriter) -> None: + async def test_base_url_trailing_slash(self, client: AsyncWriter) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1401,6 +1435,7 @@ def test_base_url_trailing_slash(self, client: AsyncWriter) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1417,7 +1452,7 @@ def test_base_url_trailing_slash(self, client: AsyncWriter) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncWriter) -> None: + async def test_base_url_no_trailing_slash(self, client: AsyncWriter) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1426,6 +1461,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncWriter) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1442,7 +1478,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncWriter) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncWriter) -> None: + async def test_absolute_request_url(self, client: AsyncWriter) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1451,37 +1487,37 @@ def test_absolute_request_url(self, client: AsyncWriter) -> None: ), ) assert request.url == "https://myapi.com/foo" + await client.close() async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied await asyncio.sleep(0.2) - assert not client.is_closed() + assert not test_client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) - async with client as c2: - assert c2 is client + test_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - await self.client.get("/foo", cast_to=Model) + await async_client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -1492,19 +1528,17 @@ async def test_client_max_retries_validation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: + async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: class Model(BaseModel): name: str respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) + stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) assert isinstance(stream, AsyncStream) await stream.response.aclose() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str @@ -1516,11 +1550,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = await client.get("/foo", cast_to=Model) + response = await non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + await strict_client.close() + await non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -1543,13 +1580,12 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - @pytest.mark.asyncio - async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncWriter + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -1562,7 +1598,7 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, messages=[{"role": "user"}], model="model" ).__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -1573,12 +1609,11 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, await async_client.chat.with_streaming_response.chat( messages=[{"role": "user"}], model="model" ).__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, @@ -1610,7 +1645,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_omit_retry_count_header( self, async_client: AsyncWriter, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1636,7 +1670,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_overwrite_retry_count_header( self, async_client: AsyncWriter, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1686,26 +1719,26 @@ async def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - await self.client.post( + await async_client.post( "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response ) From e8b11131528095f8acf847d126fda21cec0b66c6 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:11:44 +0000 Subject: [PATCH 05/14] chore(internal): grammar fix (it's -> its) --- src/writerai/_utils/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/writerai/_utils/_utils.py b/src/writerai/_utils/_utils.py index 50d59269..eec7f4a1 100644 --- a/src/writerai/_utils/_utils.py +++ b/src/writerai/_utils/_utils.py @@ -133,7 +133,7 @@ def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: # Type safe methods for narrowing types with TypeVars. # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], # however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. +# care about the contained types we can safely use `object` in its place. # # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. # `is_*` is for when you're dealing with an unknown input From 9b204ced5e50fa180e24b3d05ec271b8bbd7baff Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 11:37:39 +0000 Subject: [PATCH 06/14] chore(package): drop Python 3.8 support --- README.md | 2 +- pyproject.toml | 5 ++--- src/writerai/_utils/_sync.py | 34 +++------------------------------- 3 files changed, 6 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index b0ae7de7..b27220e5 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI version](https://img.shields.io/pypi/v/writer-sdk.svg?label=pypi%20(stable))](https://pypi.org/project/writer-sdk/) -The Writer Python library provides access to the Writer REST API from any Python 3.8+ +The Writer Python library provides access to the Writer REST API from any Python 3.9+ application. It includes a set of tools and utilities that make it easy to integrate the capabilities of Writer into your projects. diff --git a/pyproject.toml b/pyproject.toml index eb115f0e..e9762e35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,11 +17,10 @@ dependencies = [ "cached-property; python_version < '3.8'", "jiter>=0.4.0, <1", ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -145,7 +144,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.8" +pythonVersion = "3.9" exclude = [ "_dev", diff --git a/src/writerai/_utils/_sync.py b/src/writerai/_utils/_sync.py index ad7ec71b..f6027c18 100644 --- a/src/writerai/_utils/_sync.py +++ b/src/writerai/_utils/_sync.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import asyncio import functools -import contextvars -from typing import Any, TypeVar, Callable, Awaitable +from typing import TypeVar, Callable, Awaitable from typing_extensions import ParamSpec import anyio @@ -15,34 +13,11 @@ T_ParamSpec = ParamSpec("T_ParamSpec") -if sys.version_info >= (3, 9): - _asyncio_to_thread = asyncio.to_thread -else: - # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - # for Python 3.8 support - async def _asyncio_to_thread( - func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs - ) -> Any: - """Asynchronously run function *func* in a separate thread. - - Any *args and **kwargs supplied for this function are directly passed - to *func*. Also, the current :class:`contextvars.Context` is propagated, - allowing context variables from the main thread to be accessed in the - separate thread. - - Returns a coroutine that can be awaited to get the eventual result of *func*. - """ - loop = asyncio.events.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - async def to_thread( func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> T_Retval: if sniffio.current_async_library() == "asyncio": - return await _asyncio_to_thread(func, *args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) return await anyio.to_thread.run_sync( functools.partial(func, *args, **kwargs), @@ -53,10 +28,7 @@ async def to_thread( def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ Take a blocking function and create an async one that receives the same - positional and keyword arguments. For python version 3.9 and above, it uses - asyncio.to_thread to run the function in a separate thread. For python version - 3.8, it uses locally defined copy of the asyncio.to_thread function which was - introduced in python 3.9. + positional and keyword arguments. Usage: From 56db2716054e1ba6a23071e172584e7c2433ba87 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:37:33 +0000 Subject: [PATCH 07/14] fix: compat with Python 3.14 --- src/writerai/_models.py | 11 ++++++++--- tests/test_models.py | 8 ++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/writerai/_models.py b/src/writerai/_models.py index 14fd47f0..ee5350b3 100644 --- a/src/writerai/_models.py +++ b/src/writerai/_models.py @@ -2,6 +2,7 @@ import os import inspect +import weakref from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( @@ -574,6 +575,9 @@ class CachedDiscriminatorType(Protocol): __discriminator__: DiscriminatorDetails +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() + + class DiscriminatorDetails: field_name: str """The name of the discriminator field in the variant class, e.g. @@ -616,8 +620,9 @@ def __init__( def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ + cached = DISCRIMINATOR_CACHE.get(union) + if cached is not None: + return cached discriminator_field_name: str | None = None @@ -670,7 +675,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, discriminator_field=discriminator_field_name, discriminator_alias=discriminator_alias, ) - cast(CachedDiscriminatorType, union).__discriminator__ = details + DISCRIMINATOR_CACHE.setdefault(union, details) return details diff --git a/tests/test_models.py b/tests/test_models.py index af9b6e48..d5169d03 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from writerai._utils import PropertyInfo from writerai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from writerai._models import BaseModel, construct_type +from writerai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type class BasicModel(BaseModel): @@ -809,7 +809,7 @@ class B(BaseModel): UnionType = cast(Any, Union[A, B]) - assert not hasattr(UnionType, "__discriminator__") + assert not DISCRIMINATOR_CACHE.get(UnionType) m = construct_type( value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) @@ -818,7 +818,7 @@ class B(BaseModel): assert m.type == "b" assert m.data == "foo" # type: ignore[comparison-overlap] - discriminator = UnionType.__discriminator__ + discriminator = DISCRIMINATOR_CACHE.get(UnionType) assert discriminator is not None m = construct_type( @@ -830,7 +830,7 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache - assert UnionType.__discriminator__ is discriminator + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") From 6161879c895c2b47a9ece796261f10f84e953100 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 18:51:28 +0000 Subject: [PATCH 08/14] codegen metadata --- .stats.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.stats.yml b/.stats.yml index b3d54722..9ff96150 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 33 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/writerai%2Fwriter-4ec783072dd7f57c6e021a746df7650fb8d7a164d8ec25c7d5cab06c33bc114f.yml openapi_spec_hash: ceab065d515f3681b0c33137da308968 -config_hash: 089fd5502b9cf91247887b19117f1ca2 +config_hash: 886645f89dc98f04b8931eaf02854e5f From 1fb332284ab2c7ff87afeb686176df1efcf262db Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Tue, 11 Nov 2025 14:37:29 +0000 Subject: [PATCH 09/14] fix(compat): update signatures of `model_dump` and `model_dump_json` for Pydantic v1 --- src/writerai/_models.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/writerai/_models.py b/src/writerai/_models.py index ee5350b3..94bf248d 100644 --- a/src/writerai/_models.py +++ b/src/writerai/_models.py @@ -258,15 +258,16 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, + context: Any | None = None, by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, - serialize_as_any: bool = False, fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -274,16 +275,24 @@ def model_dump( Args: mode: The mode in which `to_python` should run. - If mode is 'json', the dictionary will only contain JSON serializable types. - If mode is 'python', the dictionary may contain any Python objects. - include: A list of fields to include in the output. - exclude: A list of fields to exclude from the output. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + include: A set of fields to include in the output. + exclude: A set of fields to exclude from the output. + context: Additional context to pass to the serializer. by_alias: Whether to use the field's alias in the dictionary key if defined. - exclude_unset: Whether to exclude fields that are unset or None from the output. - exclude_defaults: Whether to exclude fields that are set to their default value from the output. - exclude_none: Whether to exclude fields that have a value of `None` from the output. - round_trip: Whether to enable serialization and deserialization round-trip support. - warnings: Whether to log warnings when invalid fields are encountered. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + exclude_computed_fields: Whether to exclude computed fields. + While this can be useful for round-tripping, it is usually recommended to use the dedicated + `round_trip` parameter instead. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, + "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + fallback: A function to call when an unknown value is encountered. If not provided, + a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. Returns: A dictionary representation of the model. @@ -300,6 +309,8 @@ def model_dump( raise ValueError("serialize_as_any is only supported in Pydantic v2") if fallback is not None: raise ValueError("fallback is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, @@ -316,15 +327,17 @@ def model_dump_json( self, *, indent: int | None = None, + ensure_ascii: bool = False, include: IncEx | None = None, exclude: IncEx | None = None, + context: Any | None = None, by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: @@ -356,6 +369,10 @@ def model_dump_json( raise ValueError("serialize_as_any is only supported in Pydantic v2") if fallback is not None: raise ValueError("fallback is only supported in Pydantic v2") + if ensure_ascii != False: + raise ValueError("ensure_ascii is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, From 3b5b4a69314e7c3853018233796b05a4035710fb Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 21 Nov 2025 23:28:44 +0000 Subject: [PATCH 10/14] chore(internal): codegen related update --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e9762e35..9d39fe60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", From 23c7971d69301956cef01d0041120a848818fa5a Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 27 Nov 2025 16:52:14 +0000 Subject: [PATCH 11/14] fix: ensure streams are always closed --- src/writerai/_streaming.py | 98 +++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/src/writerai/_streaming.py b/src/writerai/_streaming.py index 34bb9d92..a9e69241 100644 --- a/src/writerai/_streaming.py +++ b/src/writerai/_streaming.py @@ -54,30 +54,31 @@ def __stream__(self) -> Iterator[_T]: process_data = self._client._process_response_data iterator = self._iter_events() - for sse in iterator: - if sse.data.startswith("[DONE]"): - break - - if sse.event is None: - yield process_data(data=sse.json(), cast_to=cast_to, response=response) - - if sse.event == "error": - body = sse.data - - try: - body = sse.json() - err_msg = f"{body}" - except Exception: - err_msg = sse.data or f"Error code: {response.status_code}" - - raise self._client._make_status_error( - err_msg, - body=body, - response=self.response, - ) - - # As we might not fully consume the response stream, we need to close it explicitly - response.close() + try: + for sse in iterator: + if sse.data.startswith("[DONE]"): + break + + if sse.event is None: + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + if sse.event == "error": + body = sse.data + + try: + body = sse.json() + err_msg = f"{body}" + except Exception: + err_msg = sse.data or f"Error code: {response.status_code}" + + raise self._client._make_status_error( + err_msg, + body=body, + response=self.response, + ) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + response.close() def __enter__(self) -> Self: return self @@ -136,30 +137,31 @@ async def __stream__(self) -> AsyncIterator[_T]: process_data = self._client._process_response_data iterator = self._iter_events() - async for sse in iterator: - if sse.data.startswith("[DONE]"): - break - - if sse.event is None: - yield process_data(data=sse.json(), cast_to=cast_to, response=response) - - if sse.event == "error": - body = sse.data - - try: - body = sse.json() - err_msg = f"{body}" - except Exception: - err_msg = sse.data or f"Error code: {response.status_code}" - - raise self._client._make_status_error( - err_msg, - body=body, - response=self.response, - ) - - # As we might not fully consume the response stream, we need to close it explicitly - await response.aclose() + try: + async for sse in iterator: + if sse.data.startswith("[DONE]"): + break + + if sse.event is None: + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + if sse.event == "error": + body = sse.data + + try: + body = sse.json() + err_msg = f"{body}" + except Exception: + err_msg = sse.data or f"Error code: {response.status_code}" + + raise self._client._make_status_error( + err_msg, + body=body, + response=self.response, + ) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + await response.aclose() async def __aenter__(self) -> Self: return self From 74b479957daea7272bfd0a7533125b0bd42c17dd Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 27 Nov 2025 18:49:14 +0000 Subject: [PATCH 12/14] chore(deps): mypy 1.18.1 has a regression, pin to 1.17 --- pyproject.toml | 2 +- requirements-dev.lock | 4 +++- requirements.lock | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9d39fe60..402582f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ managed = true # version pins are in requirements-dev.lock dev-dependencies = [ "pyright==1.1.399", - "mypy", + "mypy==1.17", "respx", "pytest", "pytest-asyncio", diff --git a/requirements-dev.lock b/requirements-dev.lock index d8160bd5..642272ac 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -79,7 +79,7 @@ mdurl==0.1.2 multidict==6.4.4 # via aiohttp # via yarl -mypy==1.14.1 +mypy==1.17.0 mypy-extensions==1.0.0 # via mypy nest-asyncio==1.6.0 @@ -89,6 +89,8 @@ nox==2023.4.22 packaging==23.2 # via nox # via pytest +pathspec==0.12.1 + # via mypy platformdirs==3.11.0 # via virtualenv pluggy==1.5.0 diff --git a/requirements.lock b/requirements.lock index 0b87d0fd..5d1ba1fa 100644 --- a/requirements.lock +++ b/requirements.lock @@ -57,21 +57,21 @@ multidict==6.4.4 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.11.9 +pydantic==2.12.5 # via writer-sdk -pydantic-core==2.33.2 +pydantic-core==2.41.5 # via pydantic sniffio==1.3.0 # via anyio # via writer-sdk -typing-extensions==4.12.2 +typing-extensions==4.15.0 # via anyio # via multidict # via pydantic # via pydantic-core # via typing-inspection # via writer-sdk -typing-inspection==0.4.1 +typing-inspection==0.4.2 # via pydantic yarl==1.20.0 # via aiohttp From 04fe0769dcba588b421d5d6fe3fd5b7cf10a726d Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sun, 30 Nov 2025 00:58:18 +0000 Subject: [PATCH 13/14] docs(api): updates to API spec --- .stats.yml | 4 +-- src/writerai/resources/files.py | 30 +++++++++++++++++-- src/writerai/resources/vision.py | 12 +++++--- src/writerai/types/file.py | 8 ++++- src/writerai/types/file_upload_params.py | 10 +++++++ src/writerai/types/shared/tool_param.py | 10 +++++-- .../types/shared_params/tool_param.py | 10 +++++-- src/writerai/types/vision_analyze_params.py | 4 +-- tests/api_resources/test_files.py | 22 ++++++++++++++ 9 files changed, 92 insertions(+), 18 deletions(-) diff --git a/.stats.yml b/.stats.yml index 9ff96150..33ea12a4 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 33 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/writerai%2Fwriter-4ec783072dd7f57c6e021a746df7650fb8d7a164d8ec25c7d5cab06c33bc114f.yml -openapi_spec_hash: ceab065d515f3681b0c33137da308968 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/writerai%2Fwriter-ea6ec4b34f6b7fdecc564f59b2e31482eee05830bf8dc1f389461b158de1548e.yml +openapi_spec_hash: ea89c1faed473908be2740efe6da255f config_hash: 886645f89dc98f04b8931eaf02854e5f diff --git a/src/writerai/resources/files.py b/src/writerai/resources/files.py index f8547d33..94d8495d 100644 --- a/src/writerai/resources/files.py +++ b/src/writerai/resources/files.py @@ -6,7 +6,7 @@ import httpx -from ..types import file_list_params, file_retry_params +from ..types import file_list_params, file_retry_params, file_upload_params from .._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property @@ -274,6 +274,7 @@ def upload( content: FileTypes, content_disposition: str, content_type: str, + graph_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -287,6 +288,13 @@ def upload( DOC, DOCX, PPT, PPTX, JPG, PNG, EML, HTML, SRT, CSV, XLS, and XLSX. Args: + graph_id: The unique identifier of the Knowledge Graph to associate the uploaded file + with. + + Note: The response from the upload endpoint does not include the `graphId` + field, but the association will be visible when you retrieve the file using the + file retrieval endpoint. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -303,7 +311,11 @@ def upload( return self._post( "/v1/files", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"graph_id": graph_id}, file_upload_params.FileUploadParams), ), binary_request=content, cast_to=File, @@ -550,6 +562,7 @@ async def upload( content: FileTypes, content_disposition: str, content_type: str, + graph_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -563,6 +576,13 @@ async def upload( DOC, DOCX, PPT, PPTX, JPG, PNG, EML, HTML, SRT, CSV, XLS, and XLSX. Args: + graph_id: The unique identifier of the Knowledge Graph to associate the uploaded file + with. + + Note: The response from the upload endpoint does not include the `graphId` + field, but the association will be visible when you retrieve the file using the + file retrieval endpoint. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -579,7 +599,11 @@ async def upload( return await self._post( "/v1/files", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"graph_id": graph_id}, file_upload_params.FileUploadParams), ), binary_request=content, cast_to=File, diff --git a/src/writerai/resources/vision.py b/src/writerai/resources/vision.py index d90bf3d3..446b31d2 100644 --- a/src/writerai/resources/vision.py +++ b/src/writerai/resources/vision.py @@ -57,8 +57,10 @@ def analyze( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> VisionResponse: - """ - Submit images and a prompt to generate an analysis of the images. + """Submit images and documents with a prompt to generate an analysis. + + Supports JPG, + PNG, PDF, and TXT files up to 7MB each. Args: model: The model to use for image analysis. @@ -125,8 +127,10 @@ async def analyze( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> VisionResponse: - """ - Submit images and a prompt to generate an analysis of the images. + """Submit images and documents with a prompt to generate an analysis. + + Supports JPG, + PNG, PDF, and TXT files up to 7MB each. Args: model: The model to use for image analysis. diff --git a/src/writerai/types/file.py b/src/writerai/types/file.py index f12ba530..8b129976 100644 --- a/src/writerai/types/file.py +++ b/src/writerai/types/file.py @@ -16,7 +16,13 @@ class File(BaseModel): """The timestamp when the file was uploaded.""" graph_ids: List[str] - """A list of Knowledge Graph IDs that the file is associated with.""" + """A list of Knowledge Graph IDs that the file is associated with. + + If you provided a `graphId` during upload, the file is associated with that + Knowledge Graph. However, the `graph_ids` field in the upload response is an + empty list. The association will be visible in the `graph_ids` list when you + retrieve the file using the file retrieval endpoint. + """ name: str """The name of the file.""" diff --git a/src/writerai/types/file_upload_params.py b/src/writerai/types/file_upload_params.py index 14077b97..caa50879 100644 --- a/src/writerai/types/file_upload_params.py +++ b/src/writerai/types/file_upload_params.py @@ -16,3 +16,13 @@ class FileUploadParams(TypedDict, total=False): content_disposition: Required[Annotated[str, PropertyInfo(alias="Content-Disposition")]] content_type: Required[Annotated[str, PropertyInfo(alias="Content-Type")]] + + graph_id: Annotated[str, PropertyInfo(alias="graphId")] + """ + The unique identifier of the Knowledge Graph to associate the uploaded file + with. + + Note: The response from the upload endpoint does not include the `graphId` + field, but the association will be visible when you retrieve the file using the + file retrieval endpoint. + """ diff --git a/src/writerai/types/shared/tool_param.py b/src/writerai/types/shared/tool_param.py index c88d8aec..934e6bf4 100644 --- a/src/writerai/types/shared/tool_param.py +++ b/src/writerai/types/shared/tool_param.py @@ -203,10 +203,11 @@ class TranslationTool(BaseModel): class VisionToolFunctionVariable(BaseModel): file_id: str - """The File ID of the image to analyze. + """The File ID of the file to analyze. The file must be uploaded to the Writer platform before you use it with the - Vision tool. The maximum allowed file size is 7MB. + Vision tool. Supported file types: JPG, PNG, PDF, TXT. The maximum allowed file + size is 7MB. """ name: str @@ -228,7 +229,10 @@ class VisionToolFunction(BaseModel): class VisionTool(BaseModel): function: VisionToolFunction - """A tool that uses Palmyra Vision to analyze images.""" + """A tool that uses Palmyra Vision to analyze images and documents. + + Supports JPG, PNG, PDF, and TXT files up to 7MB each. + """ type: Literal["vision"] """The type of tool.""" diff --git a/src/writerai/types/shared_params/tool_param.py b/src/writerai/types/shared_params/tool_param.py index c881bcb5..2ab4c094 100644 --- a/src/writerai/types/shared_params/tool_param.py +++ b/src/writerai/types/shared_params/tool_param.py @@ -204,10 +204,11 @@ class TranslationTool(TypedDict, total=False): class VisionToolFunctionVariable(TypedDict, total=False): file_id: Required[str] - """The File ID of the image to analyze. + """The File ID of the file to analyze. The file must be uploaded to the Writer platform before you use it with the - Vision tool. The maximum allowed file size is 7MB. + Vision tool. Supported file types: JPG, PNG, PDF, TXT. The maximum allowed file + size is 7MB. """ name: Required[str] @@ -229,7 +230,10 @@ class VisionToolFunction(TypedDict, total=False): class VisionTool(TypedDict, total=False): function: Required[VisionToolFunction] - """A tool that uses Palmyra Vision to analyze images.""" + """A tool that uses Palmyra Vision to analyze images and documents. + + Supports JPG, PNG, PDF, and TXT files up to 7MB each. + """ type: Required[Literal["vision"]] """The type of tool.""" diff --git a/src/writerai/types/vision_analyze_params.py b/src/writerai/types/vision_analyze_params.py index 9dac31ac..1bcd12a4 100644 --- a/src/writerai/types/vision_analyze_params.py +++ b/src/writerai/types/vision_analyze_params.py @@ -25,10 +25,10 @@ class VisionAnalyzeParams(TypedDict, total=False): class Variable(TypedDict, total=False): file_id: Required[str] - """The File ID of the image to analyze. + """The File ID of the file to analyze. The file must be uploaded to the Writer platform before it can be used in a - vision request. + vision request. Supported file types: JPG, PNG, PDF, TXT (max 7MB each). """ name: Required[str] diff --git a/tests/api_resources/test_files.py b/tests/api_resources/test_files.py index 46a3c6d1..b22112b3 100644 --- a/tests/api_resources/test_files.py +++ b/tests/api_resources/test_files.py @@ -239,6 +239,17 @@ def test_method_upload(self, client: Writer) -> None: ) assert_matches_type(File, file, path=["response"]) + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + def test_method_upload_with_all_params(self, client: Writer) -> None: + file = client.files.upload( + content=b"raw file contents", + content_disposition="Content-Disposition", + content_type="Content-Type", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(File, file, path=["response"]) + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") @parametrize def test_raw_response_upload(self, client: Writer) -> None: @@ -484,6 +495,17 @@ async def test_method_upload(self, async_client: AsyncWriter) -> None: ) assert_matches_type(File, file, path=["response"]) + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + async def test_method_upload_with_all_params(self, async_client: AsyncWriter) -> None: + file = await async_client.files.upload( + content=b"raw file contents", + content_disposition="Content-Disposition", + content_type="Content-Type", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(File, file, path=["response"]) + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") @parametrize async def test_raw_response_upload(self, async_client: AsyncWriter) -> None: From de7ffa23e7ed5f0fb3082dc1db79ea63f60ff9fa Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:16:29 +0000 Subject: [PATCH 14/14] release: 2.3.3-rc1 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 27 +++++++++++++++++++++++++++ README.md | 4 ++-- pyproject.toml | 2 +- src/writerai/_version.py | 2 +- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index c5e4ca3d..7f554478 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.3.2" + ".": "2.3.3-rc1" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 3998c187..9d557021 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,32 @@ # Changelog +## 2.3.3-rc1 (2025-12-01) + +Full Changelog: [v2.3.2...v2.3.3-rc1](https://github.com/writer/writer-python/compare/v2.3.2...v2.3.3-rc1) + +### Bug Fixes + +* **client:** close streams without requiring full consumption ([31e3903](https://github.com/writer/writer-python/commit/31e39034cab026c34c9509757a27d9e2221c0c5b)) +* compat with Python 3.14 ([56db271](https://github.com/writer/writer-python/commit/56db2716054e1ba6a23071e172584e7c2433ba87)) +* **compat:** update signatures of `model_dump` and `model_dump_json` for Pydantic v1 ([1fb3322](https://github.com/writer/writer-python/commit/1fb332284ab2c7ff87afeb686176df1efcf262db)) +* ensure streams are always closed ([23c7971](https://github.com/writer/writer-python/commit/23c7971d69301956cef01d0041120a848818fa5a)) + + +### Chores + +* bump `httpx-aiohttp` version to 0.1.9 ([f2ef07d](https://github.com/writer/writer-python/commit/f2ef07dbe6ffd744bf58a6c7b5f3dac8b73a8805)) +* **deps:** mypy 1.18.1 has a regression, pin to 1.17 ([74b4799](https://github.com/writer/writer-python/commit/74b479957daea7272bfd0a7533125b0bd42c17dd)) +* **internal/tests:** avoid race condition with implicit client cleanup ([828ac4d](https://github.com/writer/writer-python/commit/828ac4d2a57d4f623d4fe2aef25390c5f0051b96)) +* **internal:** codegen related update ([3b5b4a6](https://github.com/writer/writer-python/commit/3b5b4a69314e7c3853018233796b05a4035710fb)) +* **internal:** detect missing future annotations with ruff ([9df4451](https://github.com/writer/writer-python/commit/9df44512304949e6193e7ff33390342e26d065c6)) +* **internal:** grammar fix (it's -> its) ([e8b1113](https://github.com/writer/writer-python/commit/e8b11131528095f8acf847d126fda21cec0b66c6)) +* **package:** drop Python 3.8 support ([9b204ce](https://github.com/writer/writer-python/commit/9b204ced5e50fa180e24b3d05ec271b8bbd7baff)) + + +### Documentation + +* **api:** updates to API spec ([04fe076](https://github.com/writer/writer-python/commit/04fe0769dcba588b421d5d6fe3fd5b7cf10a726d)) + ## 2.3.2 (2025-10-03) Full Changelog: [v2.3.2-rc2...v2.3.2](https://github.com/writer/writer-python/compare/v2.3.2-rc2...v2.3.2) diff --git a/README.md b/README.md index b27220e5..a1a33706 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ To install the package from PyPI, use `pip`: ```sh # install from PyPI -pip install writer-sdk +pip install --pre writer-sdk ``` ## Prequisites @@ -116,7 +116,7 @@ You can enable this by installing `aiohttp`: ```sh # install from PyPI -pip install writer-sdk[aiohttp] +pip install --pre writer-sdk[aiohttp] ``` Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`: diff --git a/pyproject.toml b/pyproject.toml index 402582f0..dd7ef0f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "writer-sdk" -version = "2.3.2" +version = "2.3.3-rc1" description = "The official Python library for the writer API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/writerai/_version.py b/src/writerai/_version.py index 062565c2..6ff9a512 100644 --- a/src/writerai/_version.py +++ b/src/writerai/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "writerai" -__version__ = "2.3.2" # x-release-please-version +__version__ = "2.3.3-rc1" # x-release-please-version