diff --git a/src/unipoll_api/actions/authentication.py b/src/unipoll_api/actions/authentication.py index bd399e5..8feb6c6 100644 --- a/src/unipoll_api/actions/authentication.py +++ b/src/unipoll_api/actions/authentication.py @@ -52,13 +52,12 @@ async def refresh_token(authorization: str, async def refresh_token_with_clientID(authorization: str, - body: str, + refresh_token: str, token_db=Depends(get_access_token_db), strategy=Depends(get_database_strategy)): # Make sure the Authorization header is valid and extract the access token try: client_id = re.match(r'^Basic (\S+)$', authorization).group(1) # type: ignore - refresh_token = re.match(r'^refresh_token=(\S+)&grant_type=refresh_token$', body).group(1) # type: ignore except Exception as e: Debug.print_error(str(e)) raise AuthExceptions.InvalidAuthorizationHeader() @@ -68,7 +67,7 @@ async def refresh_token_with_clientID(authorization: str, # Make sure the access token exists in the database if token_data is None: - raise AuthExceptions.InvalidAccessToken() + raise AuthExceptions.InvalidAccessToken # Get the user from the database using the user ID in the token data user = await Account.get(token_data.user_id) @@ -78,14 +77,14 @@ async def refresh_token_with_clientID(authorization: str, # Decode the client ID and make sure it matches account ID client_id = base64.b64decode(client_id) if PydanticObjectId(str(client_id, "utf-8")[:-1]) != user.id: - raise AuthExceptions.InvalidClientID() + raise AuthExceptions.InvalidClientID # Check if the refresh token is the most recent one all_tokens = await token_db.get_token_family_by_user_id(user.id) if (await all_tokens.to_list())[0].refresh_token != refresh_token: # If not, delete all tokens associated with the user and return an error await strategy.destroy_token_family(user) - raise AuthExceptions.refreshTokenExpired() + raise AuthExceptions.refreshTokenExpired # Login the user using the supplied strategy # Generate new pair of access and refresh tokens diff --git a/src/unipoll_api/documents.py b/src/unipoll_api/documents.py index 0bf5228..72e8c68 100644 --- a/src/unipoll_api/documents.py +++ b/src/unipoll_api/documents.py @@ -2,7 +2,14 @@ from typing import Literal from bson import DBRef from beanie import Document as BeanieDocument -from beanie import BackLink, WriteRules, after_event, Insert, Link, PydanticObjectId # BackLink +from beanie import ( + BackLink, + WriteRules, + after_event, + Insert, + Link, + PydanticObjectId, +) # BackLink from fastapi_users_db_beanie import BeanieBaseUser from pydantic import Field from unipoll_api.utils import colored_dbg as Debug @@ -18,8 +25,9 @@ def get_document_type(cls) -> str: # Create a link to the Document model async def create_link(document: Document) -> Link: - ref = DBRef(collection=document._document_settings.name, # type: ignore - id=document.id) + ref = DBRef( + collection=document._document_settings.name, id=document.id # type: ignore + ) link = Link(ref, type(document)) return link @@ -41,7 +49,8 @@ class AccessToken(BeanieBaseAccessToken, Document): # type: ignore class Resource(Document): id: ResourceID = Field(default_factory=ResourceID, alias="_id") name: str = Field( - title="Name", description="Name of the resource", min_length=3, max_length=50) + title="Name", description="Name of the resource", min_length=3, max_length=50 + ) description: str = Field(default="", title="Description", max_length=1000) policies: list[Link["Policy"]] = [] @@ -49,11 +58,15 @@ class Resource(Document): def create_group(self) -> None: Debug.info(f'New {self.get_document_type()} "{self.id}" has been created') - async def add_policy(self, policy_holder: "Group | Member", permissions, save: bool = True) -> "Policy": - new_policy = Policy(policy_holder_type=policy_holder.get_document_type(), # type: ignore - policy_holder=(await create_link(policy_holder)), - permissions=permissions, - parent_resource=(await create_link(self))) # type: ignore + async def add_policy( + self, policy_holder: "Group | Member", permissions, save: bool = True + ) -> "Policy": + new_policy = Policy( + policy_holder_type=policy_holder.get_document_type(), # type: ignore + policy_holder=(await create_link(policy_holder)), + permissions=permissions, + parent_resource=(await create_link(self)), + ) # type: ignore # Add the policy to the group self.policies.append(new_policy) # type: ignore @@ -62,13 +75,15 @@ async def add_policy(self, policy_holder: "Group | Member", permissions, save: b return new_policy async def remove_policy(self, policy: "Policy", save: bool = True) -> None: - for i, p in enumerate(self.policies): + for p in self.policies: if policy.id == p.ref.id: self.policies.remove(p) if save: await self.save(link_rule=WriteRules.WRITE) # type: ignore - async def remove_policy_by_holder(self, policy_holder: "Group | Member", save: bool = True) -> None: + async def remove_policy_by_holder( + self, policy_holder: "Group | Member", save: bool = True + ) -> None: for policy in self.policies: if policy.policy_holder.ref.id == policy_holder.id: # type: ignore self.policies.remove(policy) @@ -79,15 +94,11 @@ async def remove_policy_by_holder(self, policy_holder: "Group | Member", save: b class Account(BeanieBaseUser, Document): # type: ignore id: ResourceID = Field(default_factory=ResourceID, alias="_id") first_name: str = Field( - default_factory=str, - max_length=20, - min_length=2, - pattern="^[A-Z][a-z]*$") + default_factory=str, max_length=20, min_length=2, pattern="^[A-Z][a-z]*$" + ) last_name: str = Field( - default_factory=str, - max_length=20, - min_length=2, - pattern="^[A-Z][a-z]*$") + default_factory=str, max_length=20, min_length=2, pattern="^[A-Z][a-z]*$" + ) class Workspace(Resource): @@ -106,7 +117,9 @@ async def add_member(self, account: "Account", permissions, save: bool = True) - await self.save(link_rule=WriteRules.WRITE) # type: ignore return new_member - async def remove_member(self, member_to_delete: "Member", save: bool = True) -> bool: + async def remove_member( + self, member_to_delete: "Member", save: bool = True + ) -> bool: # Remove the account from the workspace for member in self.members: if member.id == member_to_delete.id: # type: ignore @@ -136,11 +149,15 @@ class Group(Resource): members: list[Link["Member"]] = [] groups: list[Link["Group"]] = [] - async def add_member(self, member: "Member", permissions, save: bool = True) -> "Member": + async def add_member( + self, member: "Member", permissions, save: bool = True + ) -> "Member": if member.workspace.id != self.workspace.id: # type: ignore from unipoll_api.exceptions import WorkspaceExceptions + raise WorkspaceExceptions.UserNotMember( - self.workspace, member) # type: ignore + self.workspace, member + ) # type: ignore # Add the member to the group's list of members self.members.append(member) # type: ignore @@ -157,7 +174,8 @@ async def remove_member(self, member: "Member", save: bool = True) -> bool: self.members.remove(_member) # type: ignore Debug.info( - f"Removed member {member.id} from {self.get_document_type()} {self.id}") # type: ignore + f"Removed member {member.id} from {self.get_document_type()} {self.id}" + ) # type: ignore break # Remove the policy from the group @@ -182,21 +200,28 @@ class Policy(Document): policy_holder: Link["Group"] | Link["Member"] permissions: int - async def get_parent_resource(self, fetch_links: bool = False) -> Workspace | Group | Poll: + async def get_parent_resource( + self, fetch_links: bool = False + ) -> Workspace | Group | Poll: from unipoll_api.exceptions.resource import ResourceNotFound + collection = eval(self.parent_resource.ref.collection) - parent: Workspace | Group | Poll = await collection.get(self.parent_resource.ref.id, - fetch_links=fetch_links) + parent: Workspace | Group | Poll = await collection.get( + self.parent_resource.ref.id, fetch_links=fetch_links + ) if not parent: - ResourceNotFound(self.parent_resource.ref.collection, - self.parent_resource.ref.id) + ResourceNotFound( + self.parent_resource.ref.collection, self.parent_resource.ref.id + ) return parent async def get_policy_holder(self, fetch_links: bool = False) -> "Group | Member": from unipoll_api.exceptions.policy import PolicyHolderNotFound + collection = eval(self.policy_holder.ref.collection) - policy_holder: Group | Member = await collection.get(self.policy_holder.ref.id, - fetch_links=fetch_links) + policy_holder: Group | Member = await collection.get( + self.policy_holder.ref.id, fetch_links=fetch_links + ) if not policy_holder: PolicyHolderNotFound(self.policy_holder.ref.id) return policy_holder diff --git a/src/unipoll_api/routes/v1/authentication.py b/src/unipoll_api/routes/v1/authentication.py index 24f979d..f6caa72 100644 --- a/src/unipoll_api/routes/v1/authentication.py +++ b/src/unipoll_api/routes/v1/authentication.py @@ -1,5 +1,5 @@ -from typing import Annotated -from fastapi import APIRouter, Body, Depends, HTTPException, Header, status +from typing import Annotated, Literal +from fastapi import APIRouter, Body, Depends, Form, HTTPException, Header, status from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager, models from fastapi_users.openapi import OpenAPIResponseType @@ -10,7 +10,7 @@ # import fastapi_users, get_user_manager, jwt_backend, get_database_strategy, get_access_token_db from unipoll_api.actions import authentication as AuthActions -# from unipoll_api.schemas import authentication as AuthSchemas +from unipoll_api.schemas import authentication as AuthSchemas from unipoll_api.schemas import account as AccountSchemas from unipoll_api.exceptions.resource import APIException from unipoll_api.utils.token_db import BeanieAccessTokenDatabase @@ -90,7 +90,8 @@ async def refresh_jwt(authorization: Annotated[str, Header(...)], @router.post("/jwt/postman_refresh", responses=login_responses, response_model_exclude_unset=True) async def refresh_jwt_with_client_ID(authorization: Annotated[str, Header(...)], - body: Annotated[str, Body(...)], + refresh_token: Annotated[str, Form(...)], + grant_type: Literal["refresh_token"] = Form(...), token_db: BeanieAccessTokenDatabase = Depends(AccountManager.get_access_token_db), strategy: Strategy = Depends(AccountManager.get_database_strategy)): """Refresh the access token using the refresh token. @@ -99,14 +100,10 @@ async def refresh_jwt_with_client_ID(authorization: Annotated[str, Header(...)], authorization: `Authorization` header with the access token Body: refresh_token: `Refresh-Token` header with the refresh token + grant_type: `grant_type` header with the grant type (refresh_token) """ try: - # import json - # print(body.decode('utf-8')) - # body = json.loads(body.decode('utf-8')) - # print(body) - # AuthSchemas.PostmanRefreshTokenRequest(**body) - return await AuthActions.refresh_token_with_clientID(authorization, body, token_db, strategy) + return await AuthActions.refresh_token_with_clientID(authorization, refresh_token, token_db, strategy) except APIException as e: raise HTTPException(status_code=e.code, detail=e.detail)