Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/unipoll_api/actions/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,12 @@ async def refresh_token(authorization: str,


async def refresh_token_with_clientID(authorization: str,
body: str,
refresh_token: str,
token_db=Depends(get_access_token_db),
strategy=Depends(get_database_strategy)):
# Make sure the Authorization header is valid and extract the access token
try:
client_id = re.match(r'^Basic (\S+)$', authorization).group(1) # type: ignore
refresh_token = re.match(r'^refresh_token=(\S+)&grant_type=refresh_token$', body).group(1) # type: ignore
except Exception as e:
Debug.print_error(str(e))
raise AuthExceptions.InvalidAuthorizationHeader()
Expand All @@ -68,7 +67,7 @@ async def refresh_token_with_clientID(authorization: str,

# Make sure the access token exists in the database
if token_data is None:
raise AuthExceptions.InvalidAccessToken()
raise AuthExceptions.InvalidAccessToken
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exception class is not being instantiated. Should be raise AuthExceptions.InvalidAccessToken() to properly raise the exception.

Suggested change
raise AuthExceptions.InvalidAccessToken
raise AuthExceptions.InvalidAccessToken()

Copilot uses AI. Check for mistakes.

# Get the user from the database using the user ID in the token data
user = await Account.get(token_data.user_id)
Expand All @@ -78,14 +77,14 @@ async def refresh_token_with_clientID(authorization: str,
# Decode the client ID and make sure it matches account ID
client_id = base64.b64decode(client_id)
if PydanticObjectId(str(client_id, "utf-8")[:-1]) != user.id:
raise AuthExceptions.InvalidClientID()
raise AuthExceptions.InvalidClientID
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exception class is not being instantiated. Should be raise AuthExceptions.InvalidClientID() to properly raise the exception.

Suggested change
raise AuthExceptions.InvalidClientID
raise AuthExceptions.InvalidClientID()

Copilot uses AI. Check for mistakes.

# Check if the refresh token is the most recent one
all_tokens = await token_db.get_token_family_by_user_id(user.id)
if (await all_tokens.to_list())[0].refresh_token != refresh_token:
# If not, delete all tokens associated with the user and return an error
await strategy.destroy_token_family(user)
raise AuthExceptions.refreshTokenExpired()
raise AuthExceptions.refreshTokenExpired
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exception class is not being instantiated. Should be raise AuthExceptions.refreshTokenExpired() to properly raise the exception.

Suggested change
raise AuthExceptions.refreshTokenExpired
raise AuthExceptions.refreshTokenExpired()

Copilot uses AI. Check for mistakes.

# Login the user using the supplied strategy
# Generate new pair of access and refresh tokens
Expand Down
85 changes: 55 additions & 30 deletions src/unipoll_api/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
from typing import Literal
from bson import DBRef
from beanie import Document as BeanieDocument
from beanie import BackLink, WriteRules, after_event, Insert, Link, PydanticObjectId # BackLink
from beanie import (
BackLink,
WriteRules,
after_event,
Insert,
Link,
PydanticObjectId,
) # BackLink
from fastapi_users_db_beanie import BeanieBaseUser
from pydantic import Field
from unipoll_api.utils import colored_dbg as Debug
Expand All @@ -18,8 +25,9 @@ def get_document_type(cls) -> str:

# Create a link to the Document model
async def create_link(document: Document) -> Link:
ref = DBRef(collection=document._document_settings.name, # type: ignore
id=document.id)
ref = DBRef(
collection=document._document_settings.name, id=document.id # type: ignore
)
link = Link(ref, type(document))
return link

Expand All @@ -41,19 +49,24 @@ class AccessToken(BeanieBaseAccessToken, Document): # type: ignore
class Resource(Document):
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
name: str = Field(
title="Name", description="Name of the resource", min_length=3, max_length=50)
title="Name", description="Name of the resource", min_length=3, max_length=50
)
description: str = Field(default="", title="Description", max_length=1000)
policies: list[Link["Policy"]] = []

@after_event(Insert)
def create_group(self) -> None:
Debug.info(f'New {self.get_document_type()} "{self.id}" has been created')

async def add_policy(self, policy_holder: "Group | Member", permissions, save: bool = True) -> "Policy":
new_policy = Policy(policy_holder_type=policy_holder.get_document_type(), # type: ignore
policy_holder=(await create_link(policy_holder)),
permissions=permissions,
parent_resource=(await create_link(self))) # type: ignore
async def add_policy(
self, policy_holder: "Group | Member", permissions, save: bool = True
) -> "Policy":
new_policy = Policy(
policy_holder_type=policy_holder.get_document_type(), # type: ignore
policy_holder=(await create_link(policy_holder)),
permissions=permissions,
parent_resource=(await create_link(self)),
) # type: ignore

# Add the policy to the group
self.policies.append(new_policy) # type: ignore
Expand All @@ -62,13 +75,15 @@ async def add_policy(self, policy_holder: "Group | Member", permissions, save: b
return new_policy

async def remove_policy(self, policy: "Policy", save: bool = True) -> None:
for i, p in enumerate(self.policies):
for p in self.policies:
if policy.id == p.ref.id:
self.policies.remove(p)
if save:
await self.save(link_rule=WriteRules.WRITE) # type: ignore

async def remove_policy_by_holder(self, policy_holder: "Group | Member", save: bool = True) -> None:
async def remove_policy_by_holder(
self, policy_holder: "Group | Member", save: bool = True
) -> None:
for policy in self.policies:
if policy.policy_holder.ref.id == policy_holder.id: # type: ignore
self.policies.remove(policy)
Expand All @@ -79,15 +94,11 @@ async def remove_policy_by_holder(self, policy_holder: "Group | Member", save: b
class Account(BeanieBaseUser, Document): # type: ignore
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
first_name: str = Field(
default_factory=str,
max_length=20,
min_length=2,
pattern="^[A-Z][a-z]*$")
default_factory=str, max_length=20, min_length=2, pattern="^[A-Z][a-z]*$"
)
last_name: str = Field(
default_factory=str,
max_length=20,
min_length=2,
pattern="^[A-Z][a-z]*$")
default_factory=str, max_length=20, min_length=2, pattern="^[A-Z][a-z]*$"
)


class Workspace(Resource):
Expand All @@ -106,7 +117,9 @@ async def add_member(self, account: "Account", permissions, save: bool = True) -
await self.save(link_rule=WriteRules.WRITE) # type: ignore
return new_member

async def remove_member(self, member_to_delete: "Member", save: bool = True) -> bool:
async def remove_member(
self, member_to_delete: "Member", save: bool = True
) -> bool:
# Remove the account from the workspace
for member in self.members:
if member.id == member_to_delete.id: # type: ignore
Expand Down Expand Up @@ -136,11 +149,15 @@ class Group(Resource):
members: list[Link["Member"]] = []
groups: list[Link["Group"]] = []

async def add_member(self, member: "Member", permissions, save: bool = True) -> "Member":
async def add_member(
self, member: "Member", permissions, save: bool = True
) -> "Member":
if member.workspace.id != self.workspace.id: # type: ignore
from unipoll_api.exceptions import WorkspaceExceptions

raise WorkspaceExceptions.UserNotMember(
self.workspace, member) # type: ignore
self.workspace, member
) # type: ignore

# Add the member to the group's list of members
self.members.append(member) # type: ignore
Expand All @@ -157,7 +174,8 @@ async def remove_member(self, member: "Member", save: bool = True) -> bool:
self.members.remove(_member)
# type: ignore
Debug.info(
f"Removed member {member.id} from {self.get_document_type()} {self.id}") # type: ignore
f"Removed member {member.id} from {self.get_document_type()} {self.id}"
) # type: ignore
break

# Remove the policy from the group
Expand All @@ -182,21 +200,28 @@ class Policy(Document):
policy_holder: Link["Group"] | Link["Member"]
permissions: int

async def get_parent_resource(self, fetch_links: bool = False) -> Workspace | Group | Poll:
async def get_parent_resource(
self, fetch_links: bool = False
) -> Workspace | Group | Poll:
from unipoll_api.exceptions.resource import ResourceNotFound

collection = eval(self.parent_resource.ref.collection)
parent: Workspace | Group | Poll = await collection.get(self.parent_resource.ref.id,
fetch_links=fetch_links)
parent: Workspace | Group | Poll = await collection.get(
self.parent_resource.ref.id, fetch_links=fetch_links
)
if not parent:
ResourceNotFound(self.parent_resource.ref.collection,
self.parent_resource.ref.id)
ResourceNotFound(
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exception class is not being instantiated properly. Should be raise ResourceNotFound(...) to properly raise the exception.

Copilot uses AI. Check for mistakes.
self.parent_resource.ref.collection, self.parent_resource.ref.id
)
return parent

async def get_policy_holder(self, fetch_links: bool = False) -> "Group | Member":
from unipoll_api.exceptions.policy import PolicyHolderNotFound

collection = eval(self.policy_holder.ref.collection)
policy_holder: Group | Member = await collection.get(self.policy_holder.ref.id,
fetch_links=fetch_links)
policy_holder: Group | Member = await collection.get(
self.policy_holder.ref.id, fetch_links=fetch_links
)
if not policy_holder:
PolicyHolderNotFound(self.policy_holder.ref.id)
return policy_holder
Expand Down
17 changes: 7 additions & 10 deletions src/unipoll_api/routes/v1/authentication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Annotated
from fastapi import APIRouter, Body, Depends, HTTPException, Header, status
from typing import Annotated, Literal
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Header, status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager, models
from fastapi_users.openapi import OpenAPIResponseType
Expand All @@ -10,7 +10,7 @@

# import fastapi_users, get_user_manager, jwt_backend, get_database_strategy, get_access_token_db
from unipoll_api.actions import authentication as AuthActions
# from unipoll_api.schemas import authentication as AuthSchemas
from unipoll_api.schemas import authentication as AuthSchemas
from unipoll_api.schemas import account as AccountSchemas
from unipoll_api.exceptions.resource import APIException
from unipoll_api.utils.token_db import BeanieAccessTokenDatabase
Expand Down Expand Up @@ -90,7 +90,8 @@ async def refresh_jwt(authorization: Annotated[str, Header(...)],

@router.post("/jwt/postman_refresh", responses=login_responses, response_model_exclude_unset=True)
async def refresh_jwt_with_client_ID(authorization: Annotated[str, Header(...)],
body: Annotated[str, Body(...)],
refresh_token: Annotated[str, Form(...)],
grant_type: Literal["refresh_token"] = Form(...),
token_db: BeanieAccessTokenDatabase = Depends(AccountManager.get_access_token_db),
strategy: Strategy = Depends(AccountManager.get_database_strategy)):
"""Refresh the access token using the refresh token.
Expand All @@ -99,14 +100,10 @@ async def refresh_jwt_with_client_ID(authorization: Annotated[str, Header(...)],
authorization: `Authorization` header with the access token
Body:
refresh_token: `Refresh-Token` header with the refresh token
grant_type: `grant_type` header with the grant type (refresh_token)
"""
try:
# import json
# print(body.decode('utf-8'))
# body = json.loads(body.decode('utf-8'))
# print(body)
# AuthSchemas.PostmanRefreshTokenRequest(**body)
return await AuthActions.refresh_token_with_clientID(authorization, body, token_db, strategy)
return await AuthActions.refresh_token_with_clientID(authorization, refresh_token, token_db, strategy)
except APIException as e:
raise HTTPException(status_code=e.code, detail=e.detail)

Expand Down
Loading