Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 mypy pytest fastapi>=0.60.0 python-jose>=3.2.0 pydantic-settings httpx requests types-requests
python -m pip install flake8 mypy pytest fastapi>=0.60.0 PyJWT>=2.8.0 pydantic-settings httpx requests types-requests
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 mypy pytest fastapi>=0.60.0 python-jose>=3.2.0 pydantic-settings httpx requests types-requests
python -m pip install flake8 mypy pytest fastapi>=0.60.0 PyJWT>=2.8.0 pydantic-settings httpx requests types-requests
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
package_dir={'': 'src'},
package_data={'': ['py.typed']},
python_requires='>=3.7',
install_requires=['fastapi>=0.60.0', 'python-jose>=3.2.0']
install_requires=['fastapi>=0.60.0', 'PyJWT>=2.8.0']
)
64 changes: 25 additions & 39 deletions src/fastapi_auth0/auth.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import json
import logging
import os
from typing import Optional, Dict, List, Type
from typing import Optional, Dict, List, Type, Any
import urllib.parse
import urllib.request

from jose import jwt # type: ignore
import jwt
from fastapi import HTTPException, Depends, Request
from fastapi.security import SecurityScopes, HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security import OAuth2, OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer, OpenIdConnect
from fastapi.openapi.models import OAuthFlows, OAuthFlowImplicit
from pydantic import BaseModel, Field, ValidationError
from typing_extensions import TypedDict


logger = logging.getLogger('fastapi_auth0')
Expand Down Expand Up @@ -62,22 +60,10 @@ async def __call__(self, request: Request) -> Optional[str]:
return None


class JwksKeyDict(TypedDict):
kid: str
kty: str
use: str
n: str
e: str

class JwksDict(TypedDict):
keys: List[JwksKeyDict]



class Auth0:
def __init__(self, domain: str, api_audience: str, scopes: Dict[str, str]={},
auto_error: bool=True, scope_auto_error: bool=True, email_auto_error: bool=False,
auth0user_model: Type[Auth0User]=Auth0User):
auth0user_model: Type[Auth0User]=Auth0User, options: Dict[str, Any] | None = None):
self.domain = domain
self.audience = api_audience

Expand All @@ -88,8 +74,6 @@ def __init__(self, domain: str, api_audience: str, scopes: Dict[str, str]={},
self.auth0_user_model = auth0user_model

self.algorithms = ['RS256']
r = urllib.request.urlopen(f'https://{domain}/.well-known/jwks.json')
self.jwks: JwksDict = json.loads(r.read())

authorization_url_qs = urllib.parse.urlencode({'audience': api_audience})
authorization_url = f'https://{domain}/authorize?{authorization_url_qs}'
Expand All @@ -103,6 +87,14 @@ def __init__(self, domain: str, api_audience: str, scopes: Dict[str, str]={},
tokenUrl=f'https://{domain}/oauth/token',
scopes=scopes)
self.oidc_scheme = OpenIdConnect(openIdConnectUrl=f'https://{domain}/.well-known/openid-configuration')
self.options = options or dict()
self.options.setdefault("verify_signature", True)
self.options.setdefault("verify_aud", True)
self.options.setdefault("verify_iss", True)
self.options.setdefault("verify_exp", True)
self.options.setdefault("verify_iat", True)
self.options.setdefault("require", ["iss", "sub", "aud", "iat", "exp"])
self.jwks_client = jwt.PyJWKClient(f"https://{self.domain}/.well-known/jwks.json")


async def get_user(self,
Expand Down Expand Up @@ -139,27 +131,21 @@ async def get_user(self,
logger.warning(msg)
return None

rsa_key = {}
for key in self.jwks['keys']:
if key['kid'] == unverified_header['kid']:
rsa_key = {
'kty': key['kty'],
'kid': key['kid'],
'use': key['use'],
'n': key['n'],
'e': key['e']
}
break
if rsa_key:
try:
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
options = self.options.copy()
leeway = options.pop("leeway", 0)
payload = jwt.decode(
token,
rsa_key,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=f'https://{self.domain}/'
issuer=f"https://{self.domain}/",
leeway=leeway,
options=options,
)
else:
msg = 'Invalid kid header (wrong tenant or rotated public key)'
except jwt.PyJWKClientError as e:
msg = str(e)
if self.auto_error:
raise Auth0UnauthenticatedException(detail=msg)
else:
Expand All @@ -174,16 +160,16 @@ async def get_user(self,
logger.warning(msg)
return None

except jwt.JWTClaimsError:
msg = 'Invalid token claims (wrong issuer or audience)'
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError):
msg = "Invalid token claims (wrong issuer or audience)"
if self.auto_error:
raise Auth0UnauthenticatedException(detail=msg)
else:
logger.warning(msg)
return None

except jwt.JWTError:
msg = 'Malformed token'
except jwt.PyJWTError as e:
msg = f"Malformed token: {e}"
if self.auto_error:
raise Auth0UnauthenticatedException(detail=msg)
else:
Expand Down