diff --git a/project/dependencies.py b/project/dependencies.py index cb62df7..08e9a56 100644 --- a/project/dependencies.py +++ b/project/dependencies.py @@ -114,10 +114,24 @@ def get_client_id( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="JWT is malformed") +@lru_cache +def get_ssl_context( + settings: Annotated[Settings, Depends(get_settings)], +): + # see https://www.python-httpx.org/advanced/ssl/#configuring-client-instances + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if settings.extra_ca_certs is not None: + ctx.load_verify_locations(cafile=settings.extra_ca_certs) + return ctx + + ProxyMount = dict[str, httpx.HTTPTransport] | None -def get_proxy_mounts(settings: Annotated[Settings, Depends(get_settings)]): +def get_proxy_mounts( + settings: Annotated[Settings, Depends(get_settings)], + ssl_context: Annotated[ssl.SSLContext, Depends(get_ssl_context)], +): proxy = settings.proxy proxy_mounts = {} @@ -127,7 +141,7 @@ def get_proxy_mounts(settings: Annotated[Settings, Depends(get_settings)]): if http_proxy_set and https_proxy_set: # if two urls are provided, set them for each mode of transport individually proxy_mounts["http://"] = httpx.HTTPTransport(proxy=str(proxy.http_url)) - proxy_mounts["https://"] = httpx.HTTPTransport(proxy=str(proxy.https_url)) + proxy_mounts["https://"] = httpx.HTTPTransport(proxy=str(proxy.https_url), verify=ssl_context) elif not http_proxy_set and not https_proxy_set: # if no urls are provided, do nothing pass @@ -136,7 +150,7 @@ def get_proxy_mounts(settings: Annotated[Settings, Depends(get_settings)]): proxy_url = str(proxy.http_url) if http_proxy_set else str(proxy.https_url) proxy_mounts["http://"] = httpx.HTTPTransport(proxy=proxy_url) - proxy_mounts["https://"] = httpx.HTTPTransport(proxy=proxy_url) + proxy_mounts["https://"] = httpx.HTTPTransport(proxy=proxy_url, verify=ssl_context) if len(proxy_mounts) == 0: return None @@ -144,17 +158,6 @@ def get_proxy_mounts(settings: Annotated[Settings, Depends(get_settings)]): return proxy_mounts -@lru_cache -def get_ssl_context( - settings: Annotated[Settings, Depends(get_settings)], -): - # see https://www.python-httpx.org/advanced/ssl/#configuring-client-instances - ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if settings.extra_ca_certs is not None: - ctx.load_verify_locations(cafile=settings.extra_ca_certs) - return ctx - - def get_flame_hub_auth_flow( settings: Annotated[Settings, Depends(get_settings)], ssl_context: Annotated[ssl.SSLContext, Depends(get_ssl_context)],