|
10 | 10 |
|
11 | 11 | import httpx |
12 | 12 | import requests |
| 13 | +from requests.adapters import HTTPAdapter |
13 | 14 |
|
14 | 15 | from connect.client.constants import CONNECT_ENDPOINT_URL, CONNECT_SPECS_URL |
15 | 16 | from connect.client.help_formatter import DefaultFormatter |
|
24 | 25 | from connect.client.utils import get_headers |
25 | 26 |
|
26 | 27 |
|
27 | | -class _ConnectClientBase(threading.local): |
| 28 | +_SYNC_TRANSPORTS = {} |
| 29 | +_ASYNC_TRANSPORTS = {} |
| 30 | + |
| 31 | + |
| 32 | +class _ConnectClientBase: |
28 | 33 | def __init__( |
29 | 34 | self, |
30 | 35 | api_key, |
@@ -53,25 +58,11 @@ def __init__( |
53 | 58 | self.specs = None |
54 | 59 | if self._use_specs: |
55 | 60 | self.specs = OpenAPISpecs(self.specs_location) |
56 | | - self._response = None |
57 | 61 | self.logger = logger |
58 | 62 | self._help_formatter = DefaultFormatter(self.specs) |
59 | 63 | self.timeout = timeout |
60 | 64 | self.resourceset_append = resourceset_append |
61 | 65 |
|
62 | | - @property |
63 | | - def response(self) -> requests.Response: |
64 | | - """ |
65 | | - Returns the raw |
66 | | - [`requests`](https://requests.readthedocs.io/en/latest/api/#requests.Response) |
67 | | - response. |
68 | | - """ |
69 | | - return self._response |
70 | | - |
71 | | - @response.setter |
72 | | - def response(self, value: requests.Response): |
73 | | - self._response = value |
74 | | - |
75 | 66 | def __getattr__(self, name): |
76 | 67 | if '_' in name: |
77 | 68 | name = name.replace('_', '-') |
@@ -173,7 +164,7 @@ def _get_api_error_details(self): |
173 | 164 | pass |
174 | 165 |
|
175 | 166 |
|
176 | | -class ConnectClient(_ConnectClientBase, threading.local, SyncClientMixin): |
| 167 | +class ConnectClient(_ConnectClientBase, SyncClientMixin): |
177 | 168 | """ |
178 | 169 | Create a new instance of the ConnectClient. |
179 | 170 |
|
@@ -203,7 +194,33 @@ class ConnectClient(_ConnectClientBase, threading.local, SyncClientMixin): |
203 | 194 |
|
204 | 195 | def __init__(self, *args, **kwargs): |
205 | 196 | super().__init__(*args, **kwargs) |
206 | | - self._session = requests.Session() |
| 197 | + self._thread_locals = threading.local() |
| 198 | + self._thread_locals.response = None |
| 199 | + self._thread_locals.session = requests.Session() |
| 200 | + self._thread_locals.session.mount( |
| 201 | + self.endpoint, |
| 202 | + _SYNC_TRANSPORTS.setdefault( |
| 203 | + self.endpoint, |
| 204 | + HTTPAdapter(), |
| 205 | + ), |
| 206 | + ) |
| 207 | + |
| 208 | + @property |
| 209 | + def session(self): |
| 210 | + return self._thread_locals.session |
| 211 | + |
| 212 | + @property |
| 213 | + def response(self) -> requests.Response: |
| 214 | + """ |
| 215 | + Returns the raw |
| 216 | + [`requests`](https://requests.readthedocs.io/en/latest/api/#requests.Response) |
| 217 | + response. |
| 218 | + """ |
| 219 | + return self._thread_locals.response |
| 220 | + |
| 221 | + @response.setter |
| 222 | + def response(self, value: requests.Response): |
| 223 | + self._thread_locals.response = value |
207 | 224 |
|
208 | 225 | def _get_collection_class(self): |
209 | 226 | return Collection |
@@ -246,7 +263,19 @@ class AsyncConnectClient(_ConnectClientBase, AsyncClientMixin): |
246 | 263 | def __init__(self, *args, **kwargs): |
247 | 264 | super().__init__(*args, **kwargs) |
248 | 265 | self._response = contextvars.ContextVar('response', default=None) |
249 | | - self._session = httpx.AsyncClient(verify=_SSL_CONTEXT) |
| 266 | + self._session = contextvars.ContextVar( |
| 267 | + 'session', |
| 268 | + default=httpx.AsyncClient( |
| 269 | + transport=_ASYNC_TRANSPORTS.setdefault( |
| 270 | + self.endpoint, |
| 271 | + httpx.AsyncHTTPTransport(verify=_SSL_CONTEXT), |
| 272 | + ), |
| 273 | + ), |
| 274 | + ) |
| 275 | + |
| 276 | + @property |
| 277 | + def session(self): |
| 278 | + return self._session.get() |
250 | 279 |
|
251 | 280 | @property |
252 | 281 | def response(self): |
|
0 commit comments