diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml index 519d5b8..bb7fe69 100644 --- a/.github/workflows/package.yml +++ b/.github/workflows/package.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 mypy pytest fastapi>=0.60.0 python-jose>=3.2.0 pydantic-settings httpx requests types-requests + python -m pip install flake8 mypy pytest fastapi>=0.60.0 PyJWT>=2.8.0 pydantic-settings httpx requests types-requests if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 7c6e49f..453fd53 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -19,7 +19,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 mypy pytest fastapi>=0.60.0 python-jose>=3.2.0 pydantic-settings httpx requests types-requests + python -m pip install flake8 mypy pytest fastapi>=0.60.0 PyJWT>=2.8.0 pydantic-settings httpx requests types-requests if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/setup.py b/setup.py index b7346e0..dcb7aa7 100644 --- a/setup.py +++ b/setup.py @@ -17,5 +17,5 @@ package_dir={'': 'src'}, package_data={'': ['py.typed']}, python_requires='>=3.7', - install_requires=['fastapi>=0.60.0', 'python-jose>=3.2.0'] + install_requires=['fastapi>=0.60.0', 'PyJWT>=2.8.0'] ) diff --git a/src/fastapi_auth0/auth.py b/src/fastapi_auth0/auth.py index 165ecaa..54a119f 100644 --- a/src/fastapi_auth0/auth.py +++ b/src/fastapi_auth0/auth.py @@ -1,17 +1,15 @@ -import json import logging import os -from typing import Optional, Dict, List, Type +from typing import Optional, Dict, List, Type, Any import urllib.parse import urllib.request -from jose import jwt # type: ignore +import jwt from fastapi import HTTPException, Depends, Request from fastapi.security import SecurityScopes, HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import OAuth2, OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer, OpenIdConnect from fastapi.openapi.models import OAuthFlows, OAuthFlowImplicit from pydantic import BaseModel, Field, ValidationError -from typing_extensions import TypedDict logger = logging.getLogger('fastapi_auth0') @@ -62,22 +60,10 @@ async def __call__(self, request: Request) -> Optional[str]: return None -class JwksKeyDict(TypedDict): - kid: str - kty: str - use: str - n: str - e: str - -class JwksDict(TypedDict): - keys: List[JwksKeyDict] - - - class Auth0: def __init__(self, domain: str, api_audience: str, scopes: Dict[str, str]={}, auto_error: bool=True, scope_auto_error: bool=True, email_auto_error: bool=False, - auth0user_model: Type[Auth0User]=Auth0User): + auth0user_model: Type[Auth0User]=Auth0User, options: Dict[str, Any] | None = None): self.domain = domain self.audience = api_audience @@ -88,8 +74,6 @@ def __init__(self, domain: str, api_audience: str, scopes: Dict[str, str]={}, self.auth0_user_model = auth0user_model self.algorithms = ['RS256'] - r = urllib.request.urlopen(f'https://{domain}/.well-known/jwks.json') - self.jwks: JwksDict = json.loads(r.read()) authorization_url_qs = urllib.parse.urlencode({'audience': api_audience}) authorization_url = f'https://{domain}/authorize?{authorization_url_qs}' @@ -103,6 +87,14 @@ def __init__(self, domain: str, api_audience: str, scopes: Dict[str, str]={}, tokenUrl=f'https://{domain}/oauth/token', scopes=scopes) self.oidc_scheme = OpenIdConnect(openIdConnectUrl=f'https://{domain}/.well-known/openid-configuration') + self.options = options or dict() + self.options.setdefault("verify_signature", True) + self.options.setdefault("verify_aud", True) + self.options.setdefault("verify_iss", True) + self.options.setdefault("verify_exp", True) + self.options.setdefault("verify_iat", True) + self.options.setdefault("require", ["iss", "sub", "aud", "iat", "exp"]) + self.jwks_client = jwt.PyJWKClient(f"https://{self.domain}/.well-known/jwks.json") async def get_user(self, @@ -139,27 +131,21 @@ async def get_user(self, logger.warning(msg) return None - rsa_key = {} - for key in self.jwks['keys']: - if key['kid'] == unverified_header['kid']: - rsa_key = { - 'kty': key['kty'], - 'kid': key['kid'], - 'use': key['use'], - 'n': key['n'], - 'e': key['e'] - } - break - if rsa_key: + try: + signing_key = self.jwks_client.get_signing_key_from_jwt(token) + options = self.options.copy() + leeway = options.pop("leeway", 0) payload = jwt.decode( token, - rsa_key, + signing_key.key, algorithms=self.algorithms, audience=self.audience, - issuer=f'https://{self.domain}/' + issuer=f"https://{self.domain}/", + leeway=leeway, + options=options, ) - else: - msg = 'Invalid kid header (wrong tenant or rotated public key)' + except jwt.PyJWKClientError as e: + msg = str(e) if self.auto_error: raise Auth0UnauthenticatedException(detail=msg) else: @@ -174,16 +160,16 @@ async def get_user(self, logger.warning(msg) return None - except jwt.JWTClaimsError: - msg = 'Invalid token claims (wrong issuer or audience)' + except (jwt.InvalidAudienceError, jwt.InvalidIssuerError): + msg = "Invalid token claims (wrong issuer or audience)" if self.auto_error: raise Auth0UnauthenticatedException(detail=msg) else: logger.warning(msg) return None - except jwt.JWTError: - msg = 'Malformed token' + except jwt.PyJWTError as e: + msg = f"Malformed token: {e}" if self.auto_error: raise Auth0UnauthenticatedException(detail=msg) else: