Skip to content

Commit 88356e1

Browse files
authored
feat: use jti to revoke token (#30)
2 parents 3f4e62c + 4723296 commit 88356e1

7 files changed

Lines changed: 109 additions & 21 deletions

File tree

fob_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .config import Config
2-
from .database import init_engine
2+
from .database import init_engine, get_session
33
from .lib.headscale import HeadScale
44
from .vpn import headscale_driver
55
from . import mail

fob_api/auth/__init__.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from fob_api.config import Config
1414
from fob_api.models.database import User
15+
from fob_api.models.database import Token as TokenDB
1516
from fob_api import engine
1617

1718
password_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -20,7 +21,7 @@
2021

2122
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
2223
jwt_secret = Config().jwt_secret_key
23-
jwt_expire_days = 1
24+
jwt_expire_days = 15
2425

2526
if not jwt_secret:
2627
raise ValueError("JWT secret not set")
@@ -33,29 +34,30 @@ def hash_password(password: str) -> str:
3334
"""
3435
return password_context.hash(password)
3536

36-
37-
def encode_token(username) -> str:
38-
return jwt.encode({
37+
def make_token_data(username: str) -> dict:
38+
return {
3939
"exp": datetime.now() + timedelta(days=jwt_expire_days),
4040
"iat": datetime.now(),
4141
"jti": str(uuid4()),
4242
"nbf": datetime.now(),
4343
"sub": str(username)
44-
}, jwt_secret, algorithm="HS256")
44+
}
45+
46+
def encode_token(token_data) -> str:
47+
return jwt.encode(token_data, jwt_secret, algorithm="HS256")
4548

4649

4750
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> User:
4851
try:
4952
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
5053
with Session(engine) as session:
51-
return session.exec(select(User).where(User.username == payload["sub"])).first()
52-
raise HTTPException(status_code=401, detail="No user matching the token")
53-
except JWTClaimsError as e:
54-
raise HTTPException(status_code=401, detail=f"JWTClaimsError: {e}")
55-
except ExpiredSignatureError as e:
56-
raise HTTPException(status_code=401, detail=f"ExpiredSignatureError: {e}")
57-
except JWTError as e:
58-
raise HTTPException(status_code=401, detail=f"JWTError: {e}")
54+
user = session.exec(select(User).where(User.username == payload["sub"])).first()
55+
token = session.exec(select(TokenDB).where(TokenDB.token_id == payload["jti"])).first()
56+
if user and token:
57+
return user
58+
raise HTTPException(status_code=401, detail="Invalid token")
59+
except (JWTClaimsError, ExpiredSignatureError, JWTError) as e:
60+
raise HTTPException(status_code=401, detail=f"JWT Error: {e}")
5961

6062

6163
def basic_auth_validator(username: str, password: str) -> User:

fob_api/database.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from fastapi import Depends, FastAPI, HTTPException, Query
2+
from sqlmodel import Field, Session, SQLModel, create_engine, select
13
from sqlalchemy import Engine
24
from sqlmodel import create_engine, SQLModel
35
from fob_api import Config
@@ -9,3 +11,9 @@ def init_engine() -> Engine:
911
"""
1012
print("Initializing database engine")
1113
return create_engine(Config().database_url, echo=False, pool_recycle=1800, pool_pre_ping=True)
14+
15+
engine = init_engine()
16+
17+
def get_session():
18+
with Session(engine) as session:
19+
yield session

fob_api/models/database/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .user import (
22
User,
3-
UserPasswordReset
3+
UserPasswordReset,
4+
Token
45
)
56
from .headscale import (
67
HeadScalePolicyACL,

fob_api/models/database/user.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,13 @@ class UserPasswordReset(SQLModel, table=True):
2424
source_ip: str
2525
created_at: datetime = Field(default=datetime.now())
2626
expires_at: datetime
27+
28+
class Token(SQLModel, table=True):
29+
"""
30+
This class represents the Token
31+
"""
32+
id: int = Field(primary_key=True)
33+
expires_at: datetime
34+
created_at: datetime
35+
token_id: str
36+
user_id: int = Field(foreign_key="user.id")

fob_api/routes/token.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,55 @@
22

33
from fastapi import APIRouter, Depends, HTTPException
44
from fastapi.security import OAuth2PasswordRequestForm
5+
from sqlmodel import Session, select
56

6-
from fob_api import auth
7+
from fob_api import auth, get_session
78
from fob_api.models.database import User
9+
from fob_api.models.database import Token as TokenDB
810
from fob_api.models.api import Token, TokenValidate
911

1012
router = APIRouter()
1113

1214
@router.post("/token", response_model=Token, tags=["token"])
13-
def get_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> str:
15+
def get_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], session: Session = Depends(get_session)) -> Token:
1416
user = auth.basic_auth_validator(form_data.username, form_data.password)
1517
if not user:
1618
raise HTTPException(status_code=401, detail="Invalid credentials")
17-
token = auth.encode_token(user.username)
19+
token_data = auth.make_token_data(user.username)
20+
token_db: TokenDB = TokenDB(
21+
expires_at=token_data["exp"],
22+
created_at=token_data["iat"],
23+
token_id=token_data["jti"],
24+
user_id=user.id,
25+
)
26+
session.add(token_db)
27+
session.commit()
28+
token = auth.encode_token(token_data)
1829
return Token(access_token=token, token_type="bearer")
1930

2031

2132
@router.get("/token/refreshtoken", response_model=Token, tags=["token"])
22-
def refresh_token(user: Annotated[User, Depends(auth.get_current_user)]) -> str:
23-
token = auth.encode_token(user.username)
33+
def refresh_token(user: Annotated[User, Depends(auth.get_current_user)], session: Session = Depends(get_session)) -> Token:
34+
token_data = auth.make_token_data(user.username)
35+
token_db: TokenDB = TokenDB(
36+
expires_at=token_data["exp"],
37+
created_at=token_data["iat"],
38+
token_id=token_data["jti"],
39+
user_id=user.id,
40+
)
41+
session.add(token_db)
42+
session.commit()
43+
token = auth.encode_token(token_data)
2444
return Token(access_token=token, token_type="bearer")
2545

46+
@router.delete("/token/{jti}", tags=["token"])
47+
def revoke_token(jti: str, user: Annotated[User, Depends(auth.get_current_user)], session: Session = Depends(get_session)) -> None:
48+
token = session.exec(select(TokenDB).where(TokenDB.token_id == jti)).first()
49+
if not token or token.user_id != user.id:
50+
raise HTTPException(status_code=404, detail="Cant revoke token")
51+
session.delete(token)
52+
session.commit()
53+
2654
@router.get("/token/verify", response_model=TokenValidate, tags=["token"])
27-
def verify_token(user: Annotated[User, Depends(auth.get_current_user)]) -> str:
55+
def verify_token(user: Annotated[User, Depends(auth.get_current_user)]) -> TokenValidate:
2856
return TokenValidate(valid=True)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""add token register
2+
3+
Revision ID: 22f19d8927c2
4+
Revises: aa1757ace187
5+
Create Date: 2025-02-05 19:31:37.252993
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
import sqlmodel
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = '22f19d8927c2'
17+
down_revision: Union[str, None] = 'aa1757ace187'
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.create_table('token',
25+
sa.Column('id', sa.Integer(), nullable=False),
26+
sa.Column('expires_at', sa.DateTime(), nullable=False),
27+
sa.Column('created_at', sa.DateTime(), nullable=False),
28+
sa.Column('token_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
29+
sa.Column('user_id', sa.Integer(), nullable=False),
30+
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
31+
sa.PrimaryKeyConstraint('id')
32+
)
33+
# ### end Alembic commands ###
34+
35+
36+
def downgrade() -> None:
37+
# ### commands auto generated by Alembic - please adjust! ###
38+
op.drop_table('token')
39+
# ### end Alembic commands ###

0 commit comments

Comments
 (0)