diff --git a/packages/mistralai_gcp/src/mistralai_gcp/sdk.py b/packages/mistralai_gcp/src/mistralai_gcp/sdk.py index dd93cc7f..ae2b4483 100644 --- a/packages/mistralai_gcp/src/mistralai_gcp/sdk.py +++ b/packages/mistralai_gcp/src/mistralai_gcp/sdk.py @@ -226,6 +226,7 @@ def before_request( headers=headers, content=new_content, stream=None, + extensions=request.extensions, ) return next_request diff --git a/src/mistralai/extra/tests/test_gcp_timeout_extensions.py b/src/mistralai/extra/tests/test_gcp_timeout_extensions.py new file mode 100644 index 00000000..1ff39b1d --- /dev/null +++ b/src/mistralai/extra/tests/test_gcp_timeout_extensions.py @@ -0,0 +1,25 @@ +import sys +from pathlib import Path +import httpx + +# Add the mistralai_gcp package to the path for imports +gcp_package_path = Path(__file__).parent.parent.parent.parent.parent / "packages" / "mistralai_gcp" / "src" +sys.path.insert(0, str(gcp_package_path)) + +from mistralai_gcp.sdk import GoogleCloudBeforeRequestHook + +def test_gcp_before_request_preserves_timeout_extension(): + hook = GoogleCloudBeforeRequestHook(region="europe-west4", project_id="proj") + + req = httpx.Request( + "POST", + "https://europe-west4-aiplatform.googleapis.com/v1/dummy", + content=b'{"model":"mistral-large-2407"}', + headers={"content-type": "application/json"}, + ) + timeout = httpx.Timeout(30.0) + req.extensions["timeout"] = timeout + + out = hook.before_request(None, req) + assert isinstance(out, httpx.Request) + assert out.extensions.get("timeout") == timeout