Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ repos:
- id: detect-wallet-private-key
types: [file]
exclude: .json
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
files: "\\.(py)$"
args: [--settings-path=pyproject.toml]
# - repo: https://github.com/pycqa/isort
# rev: 5.12.0
# hooks:
# - id: isort
# files: "\\.(py)$"
# args: [--settings-path=pyproject.toml]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
Expand Down
13 changes: 13 additions & 0 deletions core/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from rest_framework.exceptions import APIException


class LoginValidationFailed(APIException):
status_code = 400
default_code = "login_validation_failed"
default_detail = "Cannot Login or SignUp"


class InvalidSignature(APIException):
status_code = 400
default_code = "invalid_signature"
default_detail = "The signature is invalid"
70 changes: 69 additions & 1 deletion core/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from rest_framework import exceptions as rest_exceptions
from rest_framework import serializers

from . import models
from . import exceptions, models
from .service import CoreService


class AcceptedNFTSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -49,3 +51,69 @@ def get_listings_count(self, obj: models.AcceptedNFT) -> int:
return models.Listing.objects.filter(
token_contract_address=obj.contract_address
).count()


class UserSerializer(serializers.ModelSerializer):
"""
Convert the user model class to dict-like data for json serialization.
"""

class Meta:
model = models.User
fields = ["id", "public_key", "email"]


class SignInSerializer(serializers.Serializer):
"""
Serializer class for the sign in request call.
It validates the signstures to ensure it is the right user that is
making the signin request.
This is for signing in with the wallet (account).

Data:
signatures: A list of strings representing the signatures of
the signed login message
public_key: The public key of the signer
"""

signatures = serializers.ListField(child=serializers.CharField())
public_key = serializers.CharField()

def validate(self, attrs: dict) -> dict:
"""
Validate the request data
"""
# The signature list length must be greater than or equals to 5
# according to the starknet signature format.
if len(attrs["signatures"]) < 5:
raise exceptions.InvalidSignature

# validate the signature
check = CoreService.validate_login_request(
attrs["signatures"], attrs["public_key"]
)
if not check:
raise exceptions.LoginValidationFailed

# prevent login if account is not active
if models.User.objects.filter(
public_key=attrs["public_key"], is_active=False
).exists():
raise rest_exceptions.AuthenticationFailed

return attrs

def save(self) -> dict:
"""
Create or get the user and generate an auth token data
for requests authentications

Returns:
dict: Data of the user
"""
public_key = self.validated_data["public_key"]
user, is_new = CoreService.login_or_register_user(public_key)
data = UserSerializer(user).data
data["is_new"] = is_new
token_info = CoreService.generate_auth_token_data(user)
return {**data, **token_info}
57 changes: 57 additions & 0 deletions core/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from rest_framework_simplejwt.tokens import RefreshToken
from starknet_py.utils.typed_data import TypedData

from core import models

from .utils import SignatureUtils


class CoreService:
@classmethod
def generate_auth_token_data(user: models.User) -> dict:
"""
Create the login token for validating protected requests.
Args:
user(models.User): the user model
"""
token_data_obj = RefreshToken.for_user(user)
expiry = token_data_obj.access_token["exp"]
token_data = {
"access": str(token_data_obj.access_token),
"refresh": str(token_data_obj),
"expiry": expiry,
}
return token_data

@classmethod
def validate_login_request(
cls, signatures: list[str], public_key: list[str]
) -> bool:
"""
Validate the Signed data that verifies the login of the user.
The data is signed by the user's wallet and it's components
are sent for verifiction
Args:
signatures(list[str]): the signatures of the message
public_key: The public key of the signer.

Returns:
bool: a bool representing whether the signature is valid or not.
"""
typed_data_dict = SignatureUtils.login_typed_data_format()
typed_data = TypedData(**typed_data_dict)
return SignatureUtils.verify_signatures(typed_data, signatures, public_key)

@classmethod
def login_or_register_user(cls, public_key: str) -> tuple[models.User, bool]:
"""
Create or get the existing user model
Args:
public_key: The public key of the user

Returns:
tuple[User, bool]: The User model and a bool indicating
whether is a new model or not
"""
user, created = models.User.objects.get_or_create(public_key=public_key)
return user, created
5 changes: 5 additions & 0 deletions core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,9 @@
views.AcceptedTokenListAPIView.as_view(),
name="accepted-tokens-list-view",
),
path(
"sigin/",
views.SignInAPIView.as_view(),
name="signin",
),
]
86 changes: 85 additions & 1 deletion core/utils.py
Original file line number Diff line number Diff line change
@@ -1 +1,85 @@
# file of all the utility functions and variables
# file of all the utility functions, variables and classes
from django.conf import settings
from starknet_py.hash.utils import verify_message_signature
from starknet_py.utils.typed_data import Domain, Parameter, TypedData

DOMAIN_NAME = settings.SIG_DOMAIN_NAME
CHAIN_ID = settings.SIG_CHAIN_ID
VERSION = settings.SIG_VERSION


class SignatureUtils:
@classmethod
def login_typed_data_format(cls) -> dict:
"""
This represents the signature request of the login operation.
Read on starknet signatures to understand more

Returns:
dict: The signature request structure of the login functionality.
"""
data = {
"domain": Domain(
**{
"name": DOMAIN_NAME,
"chain_id": CHAIN_ID,
"version": VERSION,
}
),
"types": {
"StarknetDomain": [
Parameter(**{"name": "name", "type": "felt"}),
Parameter(**{"name": "chainId", "type": "felt"}),
Parameter(**{"name": "version", "type": "felt"}),
],
"Message": [
Parameter(**{"name": "name", "type": "felt"}),
Parameter(**{"name": "age", "type": "felt"}),
Parameter(**{"name": "address", "type": "felt"}),
],
},
"primary_type": "Message",
"message": {},
}
return data.copy()

@classmethod
def generate_signature_typed_data(cls, data: dict, type_format: dict) -> TypedData:
"""
Integrates the data from a signing format into a signature request.
This data is based on the request structure and the signature message type.
It generates a TypedData from the typed data format.

Args:
data(dict): the data that contains the essential details of the signature
that is integrated into the typed_data format (the signature request).
type_format(dict): The typed data format that contains the meta data
of the signature request.
Returns:
TypedData: returns that typed data that is used for generating a
message hash for signature verification.
"""
login_signature_request_format = type_format
# add the data into the message section of the dict
login_signature_request_format["message"] = data
return TypedData(**login_signature_request_format)

@classmethod
def verify_signatures(
cls, typed_data: TypedData, signatures: list[str], public_key: str
) -> bool:
"""
Verify the signature with the typed data, signature list and the public key
Args:
typed_data(TypedData): This is used for generating a message hash
for signature verification.
signatures(list[str]): This is a list of the signatures that represent
the message that is signed.
public_key(str): The public key of the signer.
"""
int_signatures = list(map(lambda x: int(x), signatures))
int_public_key = int(public_key, 16)
message_hash = typed_data.message_hash(int_public_key)
return verify_message_signature(
message_hash, [int_signatures[3], int_signatures[4]], public_key
)
14 changes: 13 additions & 1 deletion core/views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Create your views here.
from rest_framework.generics import ListAPIView
from rest_framework.generics import GenericAPIView, ListAPIView
from rest_framework.response import Response

from . import models, serializers

Expand All @@ -12,3 +13,14 @@ class AcceptedNFTListAPIView(ListAPIView):
class AcceptedTokenListAPIView(ListAPIView):
queryset = models.AcceptedToken.objects.all().order_by("name")
serializer_class = serializers.AcceptedTokenSerializer


class SignInAPIView(GenericAPIView):
serializer_class = serializers.SignInSerializer

def post(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = serializer.save()

return Response(data)
43 changes: 42 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dj-database-url = "^2.3.0"
django-cors-headers = "^4.7.0"
psycopg2-binary = "^2.9.10"
factory-boy = "3.3.0"
djangorestframework-simplejwt = "^5.5.0"

[tool.poetry.group.dev.dependencies]
black = "^23.9.1"
Expand Down
15 changes: 15 additions & 0 deletions trajectfi/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

import os
from datetime import timedelta
from pathlib import Path

import dj_database_url
Expand Down Expand Up @@ -159,5 +160,19 @@
"TIME_FORMAT": "%H:%M:%S",
}

# JWT settings
SIMPLE_JWT = {
"AUTH_HEADER_TYPES": ("Bearer",),
"USER_ID_CLAIM": "user_id",
"ACCESS_TOKEN_LIFETIME": timedelta(days=3),
"REFRESH_TOKEN_LIFETIME": timedelta(days=30),
}

# settings for generating signature request format
SIG_DOMAIN_NAME = "TRAJECTFI"
SIG_CHAIN_ID = "SN_SEPOLIA"
SIG_VERSION = "0.1.0"


# Custom settings
AUTH_USER_MODEL = "core.User"