Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions db/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
BiocommonsGroup,
BiocommonsUser,
GroupMembership,
GroupMembershipHistory,
GroupMembershipAuditLog,
Platform,
PlatformMembership,
PlatformMembershipHistory,
PlatformMembershipAuditLog,
)
from db.setup import get_engine

Expand Down Expand Up @@ -115,20 +115,21 @@ class GroupMembershipAdmin(ModelView, model=GroupMembership):
]


class GroupMembershipHistoryAdmin(ModelView, model=GroupMembershipHistory):
class GroupMembershipAuditLogAdmin(ModelView, model=GroupMembershipAuditLog):
can_edit = False
can_create = False
can_delete = False
column_list = [
"name",
"group_id",
"user_email",
"user_id",
"approval_status",
"updated_at",
"updated_by_email",
"action",
"action_time",
"updated_by_id",
"revocation_reason",
]
column_default_sort = ("updated_at", True)
column_default_sort = ("action_time", True)


class PlatformAdmin(ModelView, model=Platform):
Expand Down Expand Up @@ -157,19 +158,21 @@ class PlatformMembershipAdmin(ModelView, model=PlatformMembership):
column_default_sort = ("updated_at", True)


class PlatformMembershipHistoryAdmin(ModelView, model=PlatformMembershipHistory):
class PlatformMembershipAuditLogAdmin(ModelView, model=PlatformMembershipAuditLog):
can_edit = False
can_create = False
can_delete = True
can_delete = False
column_list = [
"id",
"platform_id",
"user_id",
"approval_status",
"updated_at",
"updated_by"
"action",
"action_time",
"updated_by_id",
"revocation_reason",
]
column_default_sort = ("updated_at", True)
column_default_sort = ("action_time", True)


class DatabaseAdmin:
Expand All @@ -182,10 +185,10 @@ class DatabaseAdmin:
GroupAdmin,
Auth0RoleAdmin,
GroupMembershipAdmin,
GroupMembershipHistoryAdmin,
GroupMembershipAuditLogAdmin,
PlatformAdmin,
PlatformMembershipAdmin,
PlatformMembershipHistoryAdmin,
PlatformMembershipAuditLogAdmin,
)

def __init__(self, app: FastAPI, secret_key: str):
Expand Down
158 changes: 156 additions & 2 deletions db/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from __future__ import annotations

import uuid
from datetime import datetime, timezone
from typing import Any, ClassVar

import sqlalchemy as sa
from sqlalchemy import MetaData, event, select
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session as SASession
from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.sql import expression
from sqlmodel import Field, Session, SQLModel
from sqlmodel import DateTime, Field, Session, SQLModel
from sqlmodel import Enum as DbEnum

from db.types import AuditActionEnum

naming_convention = {
"ix": "ix_%(column_0_label)s",
Expand Down Expand Up @@ -117,6 +123,106 @@ def _coerce_primary_key_map(cls, identity: Any) -> dict[str, Any]:
raise ValueError("Identity must be scalar, tuple/list, or dict matching the primary key.")


class AuditLogModel(BaseModel):
"""
Base for tables that store audit log entries.
"""
__abstract__ = True

id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
action: AuditActionEnum = Field(
sa_type=DbEnum(AuditActionEnum, name="audit_action_enum"),
description="Type of change that produced this audit record.",
)
action_time: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_type=DateTime(timezone=True),
description="Timestamp when the audit record was produced.",
)


class AuditedModel(BaseModel):
"""
Base for ORM models that should write audit log entries on create/update/delete.
"""
__abstract__ = True

# Concrete subclasses must set this to the associated audit log SQLModel.
__audit_model__: ClassVar[type[AuditLogModel] | None] = None
# Map of attribute name on the subject model to the attribute name on the audit log.
__audit_field_map__: ClassVar[dict[str, str]] = {}
# Attributes to skip when copying state into the audit log.
__audit_exclude_columns__: ClassVar[set[str]] = set()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__abstract__:
return
audit_model = getattr(cls, "__audit_model__", None)
if audit_model is None:
raise TypeError(f"{cls.__name__} must define __audit_model__ to enable auditing.")
try:
mapper = sa.orm.class_mapper(cls)
except sa.orm.exc.UnmappedClassError:
mapper = None
else:
_configure_audited_model(mapper, cls)
sa.orm.configure_mappers()

@classmethod
def _audit_after_insert(cls, mapper, connection, target) -> None:
cls._emit_audit(connection, target, AuditActionEnum.CREATED)

@classmethod
def _audit_after_update(cls, mapper, connection, target) -> None:
state = sa_inspect(target)
has_changes = any(
state.attrs[column.key].history.has_changes()
for column in mapper.column_attrs
if column.key in state.attrs
)
if has_changes:
cls._emit_audit(connection, target, AuditActionEnum.UPDATED)

@classmethod
def _audit_before_delete(cls, mapper, connection, target) -> None:
cls._emit_audit(connection, target, AuditActionEnum.DELETED)

@classmethod
def _emit_audit(cls, connection, target, action: AuditActionEnum) -> None:
audit_model = cls.__audit_model__
if audit_model is None:
return
payload = cls._collect_audit_payload(target, audit_model)
if payload is None:
return
payload["action"] = action
payload.setdefault("action_time", datetime.now(timezone.utc))
insert_stmt = audit_model.__table__.insert().values(**payload)
connection.execute(insert_stmt)

@classmethod
def _collect_audit_payload(
cls,
target: "AuditedModel",
audit_model: type[AuditLogModel],
) -> dict[str, Any]:
mapper = sa_inspect(target.__class__)
field_map = cls.__audit_field_map__
exclude = cls.__audit_exclude_columns__
audit_columns = set(audit_model.__table__.columns.keys())
payload: dict[str, Any] = {}
for column in mapper.columns:
key = column.key
if key in exclude:
continue
dest_key = field_map.get(key, key)
if dest_key is None or dest_key not in audit_columns:
continue
payload[dest_key] = getattr(target, key)
return payload or None


def _copy_column_state(source: SoftDeleteModel, target: SoftDeleteModel) -> None:
mapper = sa_inspect(source.__class__)
for attr in mapper.column_attrs:
Expand Down Expand Up @@ -184,4 +290,52 @@ def _filter_soft_deleted(execute_state) -> None:
)


__all__ = ["BaseModel", "SoftDeleteModel"]
__all__ = ["BaseModel", "SoftDeleteModel", "AuditLogModel", "AuditedModel"]


def _get_bind_connection(session: SASession, obj: AuditedModel):
return session.connection()


@event.listens_for(SASession, "before_flush")
def _audit_deleted(session: SASession, flush_context, instances) -> None:
for obj in list(session.deleted):
if not isinstance(obj, AuditedModel):
continue
connection = _get_bind_connection(session, obj)
obj.__class__._emit_audit(connection, obj, AuditActionEnum.DELETED)


@event.listens_for(SASession, "after_flush")
def _audit_new_and_dirty(session: SASession, flush_context) -> None:
new_objs = [obj for obj in session.new if isinstance(obj, AuditedModel)]
new_ids = {id(obj) for obj in new_objs}
dirty_objs = [
obj for obj in session.dirty
if isinstance(obj, AuditedModel) and id(obj) not in new_ids
]

for obj in new_objs:
connection = _get_bind_connection(session, obj)
obj.__class__._emit_audit(connection, obj, AuditActionEnum.CREATED)

for obj in dirty_objs:
state = sa_inspect(obj)
has_changes = any(
attr.history.has_changes() for attr in state.attrs
)
if not has_changes:
continue
connection = _get_bind_connection(session, obj)
obj.__class__._emit_audit(connection, obj, AuditActionEnum.UPDATED)


@event.listens_for(sa.orm.Mapper, "mapper_configured")
def _configure_audited_model(mapper, cls) -> None:
if not isinstance(cls, type) or not issubclass(cls, AuditedModel):
return
if cls.__abstract__:
return
audit_model = getattr(cls, "__audit_model__", None)
if audit_model is None:
return
Loading