From 66481163f9a008b472aa04c4c7aacf92b257c249 Mon Sep 17 00:00:00 2001 From: Noam Date: Sat, 18 Mar 2023 10:16:42 +0200 Subject: [PATCH 1/2] Added basic pre-commit config. --- .pre-commit-config.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..e5db6a5b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: https://github.com/ambv/black + rev: 23.1.0 + hooks: + - id: black + language_version: python3 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: check-merge-conflict + - id: mixed-line-ending + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + language_version: python3 + args: ['--select=E9,F63,F7,F82'] From 1a7bc7a065fa203f6d8beabbf3c23d71fbbc8df6 Mon Sep 17 00:00:00 2001 From: Noam Date: Sat, 18 Mar 2023 12:52:36 +0200 Subject: [PATCH 2/2] Added djangorestframework-stubs and fixes errors. --- .github/workflows/django-package.yml | 2 +- README.rst | 2 +- docs/conf.py | 4 +- setup.cfg | 6 +++ testproject/settings.py | 2 +- testproject/tests/test_add_mfa.py | 12 +++--- testproject/tests/test_exceptions.py | 15 ++++--- testproject/tests/test_utils.py | 6 ++- testproject/tests/utils.py | 9 ++-- trench/backends/application.py | 6 ++- trench/backends/aws.py | 2 +- trench/backends/basic_mail.py | 4 +- trench/command/authenticate_second_factor.py | 12 ++---- trench/command/authenticate_user.py | 10 ++--- trench/models.py | 4 +- trench/responses.py | 27 ++++++------ trench/serializers.py | 11 +++-- trench/settings.py | 36 +++++++++++++--- trench/utils.py | 29 +++++++------ trench/views/authtoken.py | 5 ++- trench/views/base.py | 45 +++++++++++--------- trench/views/jwt.py | 5 ++- 22 files changed, 151 insertions(+), 103 deletions(-) diff --git a/.github/workflows/django-package.yml b/.github/workflows/django-package.yml index 03ff3a5c..922f086e 100644 --- a/.github/workflows/django-package.yml +++ b/.github/workflows/django-package.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install mypy flake8 pytest pytest-xdist flaky + python -m pip install mypy djangorestframework-stubs flake8 pytest pytest-xdist flaky if [ -f testproject/requirements.txt ]; then pip install -r testproject/requirements.txt; fi ln -s $(pwd)/trench/ $(pwd)/testproject/trench - name: Lint trench package with flake8 diff --git a/README.rst b/README.rst index d69400b9..bcfcccc7 100644 --- a/README.rst +++ b/README.rst @@ -114,7 +114,7 @@ Local development .. code-block:: shell - pip install black mypy + pip install black mypy djangorestframework-stubs pip install -r testproject/requirements.txt 5. Set environment variables: diff --git a/docs/conf.py b/docs/conf.py index ea1d1446..7d20d73c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,6 +15,8 @@ # import os # import sys # sys.path.insert(0, os.path.abspath('.')) +from typing import Dict + import sphinx_rtd_theme @@ -114,7 +116,7 @@ # -- Options for LaTeX output ------------------------------------------------ -latex_elements = { +latex_elements: Dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/setup.cfg b/setup.cfg index b8c38b48..75f7584b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,11 @@ [mypy] ignore_missing_imports = True +plugins = + mypy_django_plugin.main, + mypy_drf_plugin.main + +[mypy.plugins.django-stubs] +django_settings_module = "testproject.settings" [flake8] inline-quotes = " diff --git a/testproject/settings.py b/testproject/settings.py index fe2e1e40..fae05aac 100644 --- a/testproject/settings.py +++ b/testproject/settings.py @@ -16,7 +16,7 @@ ALLOWED_HOSTS = env.list("ALLOWED_HOSTS", default=["*"]) CORS_ORIGIN_ALLOW_ALL = env.bool("CORS_ORIGIN_ALLOW_ALL", default=False) DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" -STATIC_ROOT= os.path.join(BASE_DIR, 'static/') +STATIC_ROOT = os.path.join(BASE_DIR, "static/") INSTALLED_APPS = [ "django.contrib.admin", diff --git a/testproject/tests/test_add_mfa.py b/testproject/tests/test_add_mfa.py index d8ecb574..2672f246 100644 --- a/testproject/tests/test_add_mfa.py +++ b/testproject/tests/test_add_mfa.py @@ -1,7 +1,6 @@ import pytest from django.conf import settings -from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractUser from flaky import flaky @@ -12,9 +11,6 @@ from trench.command.create_secret import create_secret_command -User: AbstractUser = get_user_model() - - @pytest.mark.django_db def test_add_user_mfa(active_user): client = TrenchAPIClient() @@ -33,11 +29,13 @@ def test_add_user_mfa(active_user): @pytest.mark.django_db -def test_should_fail_on_add_user_mfa_with_invalid_source_field(active_user: User): +def test_should_fail_on_add_user_mfa_with_invalid_source_field( + active_user: AbstractUser, +): client = TrenchAPIClient() client.authenticate(user=active_user) secret = create_secret_command() - settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email_test" + settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email_test" # type: ignore[index] response = client.post( path="/auth/email/activate/", @@ -53,7 +51,7 @@ def test_should_fail_on_add_user_mfa_with_invalid_source_field(active_user: User response.data.get("error") == "Field name `email_test` is not valid for model `User`." ) - settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email" + settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email" # type: ignore[index] @flaky diff --git a/testproject/tests/test_exceptions.py b/testproject/tests/test_exceptions.py index 26f7d07d..dcfc3bef 100644 --- a/testproject/tests/test_exceptions.py +++ b/testproject/tests/test_exceptions.py @@ -30,8 +30,9 @@ def test_method_handler_missing_error(): assert settings.MFA_METHODS["method_without_handler"] is None -def test_code_missing_error(): - validator = ProtectedActionValidator(mfa_method_name="yubi", user=None) +@pytest.mark.django_db +def test_code_missing_error(active_user): + validator = ProtectedActionValidator(mfa_method_name="yubi", user=active_user) with pytest.raises(OTPCodeMissingError): validator.validate_code(value="") @@ -44,15 +45,17 @@ def test_request_body_validator(): validator.update(instance=MFAMethod(), validated_data=OrderedDict()) -def test_protected_action_validator(): - validator = ProtectedActionValidator(mfa_method_name="yubi", user=None) +@pytest.mark.django_db +def test_protected_action_validator(active_user): + validator = ProtectedActionValidator(mfa_method_name="yubi", user=active_user) with pytest.raises(NotImplementedError): validator._validate_mfa_method(mfa=MFAMethod()) -def test_mfa_method_activation_validator(): +@pytest.mark.django_db +def test_mfa_method_activation_validator(active_user): validator = MFAMethodActivationConfirmationValidator( - mfa_method_name="yubi", user=None + mfa_method_name="yubi", user=active_user ) with pytest.raises(MFAMethodAlreadyActiveError): validator._validate_mfa_method(mfa=MFAMethod(is_active=True)) diff --git a/testproject/tests/test_utils.py b/testproject/tests/test_utils.py index f5807a5d..4d03ac77 100644 --- a/testproject/tests/test_utils.py +++ b/testproject/tests/test_utils.py @@ -1,3 +1,5 @@ +from typing import cast + import pytest from trench.backends.application import ApplicationMessageDispatcher @@ -23,7 +25,9 @@ def test_invalid_token(): def test_create_qr_link(active_user_with_many_otp_methods): user, _ = active_user_with_many_otp_methods mfa_method: MFAMethod = user.mfa_methods.filter(name="app").first() - handler: ApplicationMessageDispatcher = get_mfa_handler(mfa_method) + handler: ApplicationMessageDispatcher = cast( + ApplicationMessageDispatcher, get_mfa_handler(mfa_method) + ) qr_link = handler._create_qr_link(user=user) assert type(qr_link) == str assert user.username in qr_link diff --git a/testproject/tests/utils.py b/testproject/tests/utils.py index 2f4275cb..357fccc3 100644 --- a/testproject/tests/utils.py +++ b/testproject/tests/utils.py @@ -65,13 +65,16 @@ def _second_factor_request( code: Optional[str] = None, path: str = PATH_AUTH_JWT_LOGIN_CODE, ) -> Response: - if handler is None and code is None: - raise ValueError("handler and code can't be None simultaneously") + if code is None: + if handler is None: + raise ValueError("handler and code can't be None simultaneously") + else: + code = handler.create_code() return self.post( path=path, data={ "ephemeral_token": ephemeral_token, - "code": handler.create_code() if code is None else code, # type: ignore + "code": code, }, format="json", ) diff --git a/trench/backends/application.py b/trench/backends/application.py index 879270a3..6f1d6dd3 100644 --- a/trench/backends/application.py +++ b/trench/backends/application.py @@ -1,3 +1,5 @@ +from typing import Type + from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractUser @@ -12,7 +14,7 @@ from trench.settings import trench_settings -User: AbstractUser = get_user_model() +User: Type[AbstractUser] = get_user_model() class ApplicationMessageDispatcher(AbstractMessageDispatcher): @@ -24,7 +26,7 @@ def dispatch_message(self) -> DispatchResponse: logging.error(cause, exc_info=True) # pragma: nocover return FailedDispatchResponse(details=str(cause)) # pragma: nocover - def _create_qr_link(self, user: User) -> str: + def _create_qr_link(self, user: AbstractUser) -> str: return self._get_otp().provisioning_uri( getattr(user, User.USERNAME_FIELD), trench_settings.APPLICATION_ISSUER_NAME, diff --git a/trench/backends/aws.py b/trench/backends/aws.py index 7d8cfaed..eebfc297 100644 --- a/trench/backends/aws.py +++ b/trench/backends/aws.py @@ -2,7 +2,6 @@ import logging import boto3 -import botocore.exceptions from trench.backends.base import AbstractMessageDispatcher from trench.responses import ( @@ -13,6 +12,7 @@ from trench.settings import AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION from botocore.exceptions import ClientError, EndpointConnectionError + class AWSMessageDispatcher(AbstractMessageDispatcher): _SMS_BODY = _("Your verification code is: ") _SUCCESS_DETAILS = _("SMS message with MFA code has been sent.") diff --git a/trench/backends/basic_mail.py b/trench/backends/basic_mail.py index 1f02e3eb..86315fdc 100644 --- a/trench/backends/basic_mail.py +++ b/trench/backends/basic_mail.py @@ -25,11 +25,11 @@ def dispatch_message(self) -> DispatchResponse: email_html_template = self._config[EMAIL_HTML_TEMPLATE] try: send_mail( - subject=self._config.get(EMAIL_SUBJECT), + subject=str(self._config.get(EMAIL_SUBJECT)), message=get_template(email_plain_template).render(context), html_message=get_template(email_html_template).render(context), from_email=settings.DEFAULT_FROM_EMAIL, - recipient_list=(self._to,), + recipient_list=(self._to,) if self._to else (), fail_silently=False, ) return SuccessfulDispatchResponse(details=self._SUCCESS_DETAILS) diff --git a/trench/command/authenticate_second_factor.py b/trench/command/authenticate_second_factor.py index e7540a31..a19f74be 100644 --- a/trench/command/authenticate_second_factor.py +++ b/trench/command/authenticate_second_factor.py @@ -1,8 +1,7 @@ -from django.contrib.auth import get_user_model -from django.contrib.auth.models import AbstractUser - from typing import Type +from django.contrib.auth.base_user import AbstractBaseUser + from trench.backends.provider import get_mfa_handler from trench.command.remove_backup_code import remove_backup_code_command from trench.command.validate_backup_code import validate_backup_code_command @@ -11,18 +10,15 @@ from trench.utils import get_mfa_model, user_token_generator -User: AbstractUser = get_user_model() - - class AuthenticateSecondFactorCommand: def __init__(self, mfa_model: Type[MFAMethod]) -> None: self._mfa_model = mfa_model - def execute(self, code: str, ephemeral_token: str) -> User: + def execute(self, code: str, ephemeral_token: str) -> AbstractBaseUser: user = user_token_generator.check_token(user=None, token=ephemeral_token) if user is None: raise InvalidTokenError() - self.is_authenticated(user_id=user.id, code=code) + self.is_authenticated(user_id=user.pk, code=code) return user def is_authenticated(self, user_id: int, code: str) -> None: diff --git a/trench/command/authenticate_user.py b/trench/command/authenticate_user.py index 37d23f7b..e67f10ee 100644 --- a/trench/command/authenticate_user.py +++ b/trench/command/authenticate_user.py @@ -1,17 +1,13 @@ -from django.contrib.auth import authenticate, get_user_model -from django.contrib.auth.models import AbstractUser - +from django.contrib.auth import authenticate +from django.contrib.auth.base_user import AbstractBaseUser from rest_framework.request import Request from trench.exceptions import UnauthenticatedError -User: AbstractUser = get_user_model() - - class AuthenticateUserCommand: @staticmethod - def execute(request: Request, username: str, password: str) -> User: + def execute(request: Request, username: str, password: str) -> AbstractBaseUser: user = authenticate( request=request, username=username, diff --git a/trench/models.py b/trench/models.py index 6b1e19ca..e12c5bf8 100644 --- a/trench/models.py +++ b/trench/models.py @@ -19,7 +19,7 @@ from trench.exceptions import MFAMethodDoesNotExistError -class MFAUserMethodManager(Manager): +class MFAUserMethodManager(Manager["MFAMethod"]): def get_by_name(self, user_id: Any, name: str) -> "MFAMethod": try: return self.get(user_id=user_id, name=name) @@ -96,7 +96,7 @@ class Meta: objects = MFAUserMethodManager() def __str__(self) -> str: - return f"{self.name} (User id: {self.user_id})" + return f"{self.name} (User id: {self.user_id})" # type: ignore[attr-defined] @property def backup_codes(self) -> Iterable[str]: diff --git a/trench/responses.py b/trench/responses.py index b7679f61..8190abac 100644 --- a/trench/responses.py +++ b/trench/responses.py @@ -1,3 +1,6 @@ +from typing import Union + +from django_stubs_ext import StrOrPromise from rest_framework.response import Response from rest_framework.status import ( HTTP_200_OK, @@ -14,20 +17,20 @@ class DispatchResponse(Response): class SuccessfulDispatchResponse(DispatchResponse): def __init__( - self, details: str, status: str = HTTP_200_OK, *args, **kwargs + self, details: StrOrPromise, status: int = HTTP_200_OK, *args, **kwargs ) -> None: - super().__init__( - data={self._FIELD_DETAILS: details}, status=status, *args, **kwargs - ) + super().__init__({self._FIELD_DETAILS: details}, status, *args, **kwargs) class FailedDispatchResponse(DispatchResponse): def __init__( - self, details: str, status: str = HTTP_422_UNPROCESSABLE_ENTITY, *args, **kwargs + self, + details: StrOrPromise, + status: int = HTTP_422_UNPROCESSABLE_ENTITY, + *args, + **kwargs ) -> None: - super().__init__( - data={self._FIELD_DETAILS: details}, status=status, *args, **kwargs - ) + super().__init__({self._FIELD_DETAILS: details}, status, *args, **kwargs) class ErrorResponse(Response): @@ -35,11 +38,9 @@ class ErrorResponse(Response): def __init__( self, - error: MFAValidationError, - status: str = HTTP_400_BAD_REQUEST, + error: Union[StrOrPromise, MFAValidationError], + status: int = HTTP_400_BAD_REQUEST, *args, **kwargs ) -> None: - super().__init__( - data={self._FIELD_ERROR: str(error)}, status=status, *args, **kwargs - ) + super().__init__({self._FIELD_ERROR: str(error)}, status, *args, **kwargs) diff --git a/trench/serializers.py b/trench/serializers.py index efb803a1..c4a9405e 100644 --- a/trench/serializers.py +++ b/trench/serializers.py @@ -1,4 +1,5 @@ from django.contrib.auth import get_user_model +from django.contrib.auth.base_user import AbstractBaseUser from django.contrib.auth.models import AbstractUser from django.db.models import Model @@ -6,7 +7,7 @@ from rest_framework.authtoken.models import Token from rest_framework.fields import CharField, ChoiceField from rest_framework.serializers import ModelSerializer, Serializer -from typing import Any, OrderedDict +from typing import Any, OrderedDict, Type from trench.backends.provider import get_mfa_handler from trench.command.remove_backup_code import remove_backup_code_command @@ -23,7 +24,7 @@ from trench.utils import available_method_choices, get_mfa_model -User: AbstractUser = get_user_model() +User: Type[AbstractUser] = get_user_model() class RequestBodyValidator(Serializer): @@ -46,7 +47,9 @@ def _get_validation_method_name() -> str: def _validate_mfa_method(mfa: MFAMethod) -> None: raise NotImplementedError - def __init__(self, mfa_method_name: str, user: User, *args, **kwargs) -> None: + def __init__( + self, mfa_method_name: str, user: AbstractBaseUser, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self._user = user self._mfa_method_name = mfa_method_name @@ -56,7 +59,7 @@ def validate_code(self, value: str) -> str: raise OTPCodeMissingError() mfa_model = get_mfa_model() mfa = mfa_model.objects.get_by_name( - user_id=self._user.id, name=self._mfa_method_name + user_id=self._user.pk, name=self._mfa_method_name ) self._validate_mfa_method(mfa) diff --git a/trench/settings.py b/trench/settings.py index 885fad5a..dee8e349 100644 --- a/trench/settings.py +++ b/trench/settings.py @@ -3,17 +3,39 @@ import string from rest_framework.settings import APISettings, perform_import -from typing import Any, Dict +from typing import Any, Dict, TYPE_CHECKING +from typing_extensions import Literal from trench.exceptions import MethodHandlerMissingError +if TYPE_CHECKING: + from rest_framework.settings import DefaultsSettings + + class TrenchDefaultsSettings(DefaultsSettings): + USER_MFA_MODEL: str + USER_ACTIVE_FIELD: str + BACKUP_CODES_QUANTITY: int + BACKUP_CODES_LENGTH: int + BACKUP_CODES_CHARACTERS: str + SECRET_KEY_LENGTH: int + DEFAULT_VALIDITY_PERIOD: int + CONFIRM_DISABLE_WITH_CODE: bool + CONFIRM_BACKUP_CODES_REGENERATION_WITH_CODE: bool + ALLOW_BACKUP_CODES_REGENERATION: bool + ENCRYPT_BACKUP_CODES: bool + APPLICATION_ISSUER_NAME: str + MFA_METHODS: dict + class TrenchAPISettings(APISettings): - _FIELD_USER_SETTINGS = "_user_settings" - _FIELD_TRENCH_AUTH = "TRENCH_AUTH" - _FIELD_BACKUP_CODES_CHARACTERS = "BACKUP_CODES_CHARACTERS" - _FIELD_MFA_METHODS = "MFA_METHODS" - _FIELD_HANDLER = "HANDLER" + defaults: "TrenchDefaultsSettings" + _FIELD_USER_SETTINGS: Literal["_user_settings"] = "_user_settings" + _FIELD_TRENCH_AUTH: Literal["TRENCH_AUTH"] = "TRENCH_AUTH" + _FIELD_BACKUP_CODES_CHARACTERS: Literal[ + "BACKUP_CODES_CHARACTERS" + ] = "BACKUP_CODES_CHARACTERS" + _FIELD_MFA_METHODS: Literal["MFA_METHODS"] = "MFA_METHODS" + _FIELD_HANDLER: Literal["HANDLER"] = "HANDLER" @property def user_settings(self) -> Dict[str, Any]: @@ -56,7 +78,7 @@ def __getitem__(self, attr: str) -> Any: AWS_SECRET_KEY = "AWS_SECRET_KEY" AWS_REGION = "AWS_REGION" -DEFAULTS = { +DEFAULTS: "TrenchDefaultsSettings" = { "USER_MFA_MODEL": "trench.MFAMethod", "USER_ACTIVE_FIELD": "is_active", "BACKUP_CODES_QUANTITY": 5, diff --git a/trench/utils.py b/trench/utils.py index fe22ea8f..1c37e01b 100644 --- a/trench/utils.py +++ b/trench/utils.py @@ -1,22 +1,19 @@ +from datetime import datetime +from typing import List, Optional, Tuple, Type + from django.apps import apps from django.conf import settings from django.contrib.auth import get_user_model -from django.contrib.auth.models import AbstractUser +from django.contrib.auth.base_user import AbstractBaseUser from django.contrib.auth.tokens import PasswordResetTokenGenerator from django.utils.crypto import constant_time_compare, salted_hmac from django.utils.http import base36_to_int, int_to_base36 from django.utils.translation import gettext_lazy as _ -from datetime import datetime -from typing import List, Optional, Tuple, Type - from trench.models import MFAMethod from trench.settings import VERBOSE_NAME, trench_settings -User: AbstractUser = get_user_model() - - class UserTokenGenerator(PasswordResetTokenGenerator): """ Custom token generator: @@ -29,10 +26,12 @@ class UserTokenGenerator(PasswordResetTokenGenerator): SECRET = settings.SECRET_KEY EXPIRY_TIME = 60 * 15 - def make_token(self, user: User) -> str: + def make_token(self, user: AbstractBaseUser) -> str: return self._make_token_with_timestamp(user, int(datetime.now().timestamp())) - def check_token(self, user: User, token: str) -> Optional[User]: + def check_token( # type: ignore[override] # fixing return type would be a breaking change + self, user: Optional[AbstractBaseUser], token: Optional[str] + ) -> Optional[AbstractBaseUser]: user_model = get_user_model() if not token: return None @@ -40,19 +39,23 @@ def check_token(self, user: User, token: str) -> Optional[User]: token = str(token) user_pk, ts_b36, token_hash = token.rsplit("-", 2) ts = base36_to_int(ts_b36) - user = user_model._default_manager.get(pk=user_pk) + token_user = user_model._default_manager.get(pk=user_pk) except (ValueError, TypeError, user_model.DoesNotExist): return None if (datetime.now().timestamp() - ts) > self.EXPIRY_TIME: return None # pragma: no cover - if not constant_time_compare(self._make_token_with_timestamp(user, ts), token): + if not constant_time_compare( + self._make_token_with_timestamp(token_user, ts), token + ): return None # pragma: no cover - return user + return token_user - def _make_token_with_timestamp(self, user: User, timestamp: int, **kwargs) -> str: + def _make_token_with_timestamp( # type: ignore[override] + self, user: AbstractBaseUser, timestamp: int, **kwargs + ) -> str: ts_b36 = int_to_base36(timestamp) token_hash = salted_hmac( self.KEY_SALT, diff --git a/trench/views/authtoken.py b/trench/views/authtoken.py index 5258a559..426c2925 100644 --- a/trench/views/authtoken.py +++ b/trench/views/authtoken.py @@ -1,4 +1,5 @@ from django.contrib.auth import user_logged_in, user_logged_out +from django.contrib.auth.base_user import AbstractBaseUser from rest_framework import status from rest_framework.authtoken.models import Token @@ -8,11 +9,11 @@ from rest_framework.views import APIView from trench.serializers import TokenSerializer -from trench.views import MFAFirstStepMixin, MFASecondStepMixin, MFAStepMixin, User +from trench.views import MFAFirstStepMixin, MFASecondStepMixin, MFAStepMixin class MFAAuthTokenView(MFAStepMixin): - def _successful_authentication_response(self, user: User) -> Response: + def _successful_authentication_response(self, user: AbstractBaseUser) -> Response: token, _ = Token.objects.get_or_create(user=user) user_logged_in.send(sender=user.__class__, request=self.request, user=user) return Response(data=TokenSerializer(token).data) diff --git a/trench/views/base.py b/trench/views/base.py index cd896413..9b50ca71 100644 --- a/trench/views/base.py +++ b/trench/views/base.py @@ -1,4 +1,7 @@ +from typing import Type + from django.contrib.auth import get_user_model +from django.contrib.auth.base_user import AbstractBaseUser from django.contrib.auth.models import AbstractUser from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ @@ -46,15 +49,18 @@ from trench.settings import SOURCE_FIELD, trench_settings from trench.utils import available_method_choices, get_mfa_model, user_token_generator +User: Type[AbstractUser] = get_user_model() + -User: AbstractUser = get_user_model() +class AuthenticatedRequest(Request): + user: AbstractBaseUser class MFAStepMixin(APIView, ABC): permission_classes = (AllowAny,) @abstractmethod - def _successful_authentication_response(self, user: User) -> Response: + def _successful_authentication_response(self, user: AbstractBaseUser) -> Response: raise NotImplementedError @@ -72,7 +78,7 @@ def post(self, request: Request) -> Response: return ErrorResponse(error=cause) try: mfa_model = get_mfa_model() - mfa_method = mfa_model.objects.get_primary_active(user_id=user.id) + mfa_method = mfa_model.objects.get_primary_active(user_id=user.pk) get_mfa_handler(mfa_method=mfa_method).dispatch_message() return Response( data={ @@ -111,12 +117,11 @@ def post(request: Request, method: str) -> Response: try: if source_field is not None and not hasattr(user, source_field): raise MFASourceFieldDoesNotExistError( - source_field, - user.__class__.__name__ + source_field, user.__class__.__name__ ) mfa = create_mfa_method_command( - user_id=user.id, + user_id=user.pk, name=method, ) except MFAValidationError as cause: @@ -128,7 +133,7 @@ class MFAMethodConfirmActivationView(APIView): permission_classes = (IsAuthenticated,) @staticmethod - def post(request: Request, method: str) -> Response: + def post(request: AuthenticatedRequest, method: str) -> Response: serializer = MFAMethodActivationConfirmationValidator( mfa_method_name=method, user=request.user, data=request.data ) @@ -136,7 +141,7 @@ def post(request: Request, method: str) -> Response: return Response(status=HTTP_400_BAD_REQUEST, data=serializer.errors) try: backup_codes = activate_mfa_method_command( - user_id=request.user.id, + user_id=request.user.pk, name=method, code=serializer.validated_data["code"], ) @@ -149,7 +154,7 @@ class MFAMethodDeactivationView(APIView): permission_classes = (IsAuthenticated,) @staticmethod - def post(request: Request, method: str) -> Response: + def post(request: AuthenticatedRequest, method: str) -> Response: serializer = MFAMethodDeactivationValidator( mfa_method_name=method, user=request.user, data=request.data ) @@ -157,7 +162,7 @@ def post(request: Request, method: str) -> Response: return Response(status=HTTP_400_BAD_REQUEST, data=serializer.errors) try: deactivate_mfa_method_command( - mfa_method_name=method, user_id=request.user.id + mfa_method_name=method, user_id=request.user.pk ) return Response(status=HTTP_204_NO_CONTENT) except MFAValidationError as cause: @@ -168,7 +173,7 @@ class MFAMethodBackupCodesRegenerationView(APIView): permission_classes = (IsAuthenticated,) @staticmethod - def post(request: Request, method: str) -> Response: + def post(request: AuthenticatedRequest, method: str) -> Response: if not trench_settings.ALLOW_BACKUP_CODES_REGENERATION: return ErrorResponse(error=_("Backup codes regeneration is not allowed.")) serializer = MFAMethodBackupCodesGenerationValidator( @@ -178,7 +183,7 @@ def post(request: Request, method: str) -> Response: return Response(status=HTTP_400_BAD_REQUEST, data=serializer.errors) try: backup_codes = regenerate_backup_codes_for_mfa_method_command( - user_id=request.user.id, + user_id=request.user.pk, name=method, ) return Response({"backup_codes": backup_codes}) @@ -198,8 +203,10 @@ def get(request: Request) -> Response: for method_name, method_verbose_name in available_method_choices() ], "confirm_disable_with_code": trench_settings.CONFIRM_DISABLE_WITH_CODE, # noqa - "confirm_regeneration_with_code": trench_settings.CONFIRM_BACKUP_CODES_REGENERATION_WITH_CODE, # noqa - "allow_backup_codes_regeneration": trench_settings.ALLOW_BACKUP_CODES_REGENERATION, # noqa + "confirm_regeneration_with_code": trench_settings.CONFIRM_BACKUP_CODES_REGENERATION_WITH_CODE, + # noqa + "allow_backup_codes_regeneration": trench_settings.ALLOW_BACKUP_CODES_REGENERATION, + # noqa }, ) @@ -210,7 +217,7 @@ class MFAListActiveUserMethodsView(ListAPIView): def get_queryset(self) -> QuerySet: mfa_model = get_mfa_model() - return mfa_model.objects.list_active(user_id=self.request.user.id) + return mfa_model.objects.list_active(user_id=self.request.user.pk) class MFAMethodRequestCodeView(APIView): @@ -225,9 +232,9 @@ def post(request: Request) -> Response: mfa_model = get_mfa_model() if method is None: method = mfa_model.objects.get_primary_active_name( - user_id=request.user.id + user_id=request.user.pk ) - mfa = mfa_model.objects.get_by_name(user_id=request.user.id, name=method) + mfa = mfa_model.objects.get_by_name(user_id=request.user.pk, name=method) return get_mfa_handler(mfa_method=mfa).dispatch_message() except MFAValidationError as cause: return ErrorResponse(error=cause) @@ -237,7 +244,7 @@ class MFAPrimaryMethodChangeView(APIView): permission_classes = (IsAuthenticated,) @staticmethod - def post(request: Request) -> Response: + def post(request: AuthenticatedRequest) -> Response: method_serializer = ChangePrimaryMethodValidator(data=request.data) method_serializer.is_valid(raise_exception=True) @@ -249,7 +256,7 @@ def post(request: Request) -> Response: code_serializer.is_valid(raise_exception=True) try: set_primary_mfa_method_command( - user_id=request.user.id, name=method_serializer.validated_data["method"] + user_id=request.user.pk, name=method_serializer.validated_data["method"] ) return Response(status=HTTP_204_NO_CONTENT) except MFAValidationError as cause: diff --git a/trench/views/jwt.py b/trench/views/jwt.py index f826a751..8242177e 100644 --- a/trench/views/jwt.py +++ b/trench/views/jwt.py @@ -1,11 +1,12 @@ +from django.contrib.auth.base_user import AbstractBaseUser from rest_framework.response import Response from rest_framework_simplejwt.tokens import RefreshToken -from trench.views import MFAFirstStepMixin, MFASecondStepMixin, MFAStepMixin, User +from trench.views import MFAFirstStepMixin, MFASecondStepMixin, MFAStepMixin class MFAJWTView(MFAStepMixin): - def _successful_authentication_response(self, user: User) -> Response: + def _successful_authentication_response(self, user: AbstractBaseUser) -> Response: token = RefreshToken.for_user(user=user) return Response(data={"refresh": str(token), "access": str(token.access_token)})