diff --git a/backend/alembic/versions/0068_save_sync.py b/backend/alembic/versions/0068_save_sync.py
new file mode 100644
index 000000000..863fbc107
--- /dev/null
+++ b/backend/alembic/versions/0068_save_sync.py
@@ -0,0 +1,102 @@
+"""Add device-based save synchronization
+
+Revision ID: 0068_save_sync
+Revises: 0067_romfile_category_enum_cheat
+Create Date: 2026-01-17
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+revision = "0068_save_sync"
+down_revision = "0067_romfile_category_enum_cheat"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_table(
+        "devices",
+        sa.Column("id", sa.String(255), primary_key=True),
+        sa.Column("user_id", sa.Integer(), nullable=False),
+        sa.Column("name", sa.String(255), nullable=True),
+        sa.Column("platform", sa.String(50), nullable=True),
+        sa.Column("client", sa.String(50), nullable=True),
+        sa.Column("client_version", sa.String(50), nullable=True),
+        sa.Column("ip_address", sa.String(45), nullable=True),
+        sa.Column("mac_address", sa.String(17), nullable=True),
+        sa.Column("hostname", sa.String(255), nullable=True),
+        sa.Column(
+            "sync_mode",
+            sa.Enum("API", "FILE_TRANSFER", "PUSH_PULL", name="syncmode"),
+            nullable=False,
+            server_default="API",
+        ),
+        sa.Column("sync_enabled", sa.Boolean(), nullable=False, server_default="1"),
+        sa.Column("last_seen", sa.TIMESTAMP(timezone=True), nullable=True),
+        sa.Column(
+            "created_at",
+            sa.TIMESTAMP(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.TIMESTAMP(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+    )
+
+    op.create_table(
+        "device_save_sync",
+        sa.Column("device_id", sa.String(255), nullable=False),
+        sa.Column("save_id", sa.Integer(), nullable=False),
+        sa.Column("last_synced_at", sa.TIMESTAMP(timezone=True), nullable=False),
+        sa.Column("is_untracked", sa.Boolean(), nullable=False, server_default="0"),
+        sa.Column(
+            "created_at",
+            sa.TIMESTAMP(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.TIMESTAMP(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["save_id"], ["saves.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("device_id", "save_id"),
+    )
+
+    with op.batch_alter_table("saves", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("slot", sa.String(255), nullable=True))
+        batch_op.add_column(sa.Column("content_hash", sa.String(32), nullable=True))
+
+    op.create_index("ix_devices_user_id", "devices", ["user_id"])
+    op.create_index("ix_devices_last_seen", "devices", ["last_seen"])
+    op.create_index("ix_device_save_sync_save_id", "device_save_sync", ["save_id"])
+    op.create_index("ix_saves_slot", "saves", ["slot"])
+    op.create_index(
+        "ix_saves_rom_user_hash", "saves", ["rom_id", "user_id", "content_hash"]
+    )
+
+
+def downgrade():
+    op.drop_index("ix_saves_rom_user_hash", "saves")
+    op.drop_index("ix_saves_slot", "saves")
+    op.drop_index("ix_device_save_sync_save_id", "device_save_sync")
+    op.drop_index("ix_devices_last_seen", "devices")
+    op.drop_index("ix_devices_user_id", "devices")
+
+    with op.batch_alter_table("saves", schema=None) as batch_op:
+        batch_op.drop_column("content_hash")
+        batch_op.drop_column("slot")
+
+    op.drop_table("device_save_sync")
+    op.drop_table("devices")
+    op.execute("DROP TYPE IF
EXISTS syncmode") diff --git a/backend/endpoints/device.py b/backend/endpoints/device.py new file mode 100644 index 000000000..8911dfbf4 --- /dev/null +++ b/backend/endpoints/device.py @@ -0,0 +1,179 @@ +import uuid +from datetime import datetime, timezone + +from fastapi import HTTPException, Request, Response, status +from pydantic import BaseModel, model_validator + +from decorators.auth import protected_route +from endpoints.responses.device import DeviceCreateResponse, DeviceSchema +from handler.auth.constants import Scope +from handler.database import db_device_handler, db_device_save_sync_handler +from logger.logger import log +from models.device import Device +from utils.router import APIRouter + +router = APIRouter( + prefix="/devices", + tags=["devices"], +) + + +class DeviceCreatePayload(BaseModel): + name: str | None = None + platform: str | None = None + client: str | None = None + client_version: str | None = None + ip_address: str | None = None + mac_address: str | None = None + hostname: str | None = None + allow_existing: bool = True + allow_duplicate: bool = False + reset_syncs: bool = False + + @model_validator(mode="after") + def _duplicate_disables_existing(self) -> "DeviceCreatePayload": + if self.allow_duplicate: + self.allow_existing = False + return self + + +class DeviceUpdatePayload(BaseModel): + name: str | None = None + platform: str | None = None + client: str | None = None + client_version: str | None = None + ip_address: str | None = None + mac_address: str | None = None + hostname: str | None = None + sync_enabled: bool | None = None + + +@protected_route(router.post, "", [Scope.DEVICES_WRITE]) +def register_device( + request: Request, + response: Response, + payload: DeviceCreatePayload, +) -> DeviceCreateResponse: + existing_device = None + if not payload.allow_duplicate: + existing_device = db_device_handler.get_device_by_fingerprint( + user_id=request.user.id, + mac_address=payload.mac_address, + hostname=payload.hostname, + platform=payload.platform, + ) + + if existing_device: + if not payload.allow_existing: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "error": "device_exists", + "message": "A device with this fingerprint already exists", + "device_id": existing_device.id, + }, + ) + + if payload.reset_syncs: + db_device_save_sync_handler.delete_syncs_for_device( + device_id=existing_device.id + ) + + db_device_handler.update_last_seen( + device_id=existing_device.id, user_id=request.user.id + ) + log.info( + f"Returned existing device {existing_device.id} for user {request.user.username}" + ) + + response.status_code = status.HTTP_200_OK + return DeviceCreateResponse( + device_id=existing_device.id, + name=existing_device.name, + created_at=existing_device.created_at, + ) + + response.status_code = status.HTTP_201_CREATED + device_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + device = Device( + id=device_id, + user_id=request.user.id, + name=payload.name, + platform=payload.platform, + client=payload.client, + client_version=payload.client_version, + ip_address=payload.ip_address, + mac_address=payload.mac_address, + hostname=payload.hostname, + last_seen=now, + ) + + db_device = db_device_handler.add_device(device) + log.info(f"Registered device {device_id} for user {request.user.username}") + + return DeviceCreateResponse( + device_id=db_device.id, + name=db_device.name, + created_at=db_device.created_at, + ) + + +@protected_route(router.get, "", [Scope.DEVICES_READ]) +def get_devices(request: 
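# --- Editorial sketch (not part of the patch) ---------------------------------
# How a sync client might drive the register_device endpoint above. The base
# URL and token are assumptions; the payload fields and the 201/200/409
# semantics (allow_existing, allow_duplicate, reset_syncs) mirror this diff.
import requests

BASE_URL = "http://localhost:8080/api"  # assumed deployment URL
TOKEN = "<token-with-devices.write-scope>"  # assumed credential


def register_or_reuse_device() -> str:
    """Return a device_id, reusing a fingerprint match when the server has one."""
    resp = requests.post(
        f"{BASE_URL}/devices",
        json={
            "name": "Living Room Handheld",  # illustrative metadata
            "platform": "android",
            "mac_address": "AA:BB:CC:DD:EE:FF",
            "hostname": "my-handheld",
            # allow_existing=True (the default) makes a fingerprint match
            # return HTTP 200 with the existing device instead of a duplicate.
            "allow_existing": True,
        },
        headers={"Authorization": f"Bearer {TOKEN}"},
        timeout=10,
    )
    if resp.status_code == 409:
        # Only reachable with allow_existing=False: the server refuses the
        # fingerprint match and reports the existing id in the detail body.
        return resp.json()["detail"]["device_id"]
    resp.raise_for_status()
    return resp.json()["device_id"]  # 201 created, or 200 reused
# ------------------------------------------------------------------------------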
Request) -> list[DeviceSchema]: + devices = db_device_handler.get_devices(user_id=request.user.id) + return [DeviceSchema.model_validate(device) for device in devices] + + +@protected_route(router.get, "/{device_id}", [Scope.DEVICES_READ]) +def get_device(request: Request, device_id: str) -> DeviceSchema: + device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {device_id} not found", + ) + return DeviceSchema.model_validate(device) + + +@protected_route(router.put, "/{device_id}", [Scope.DEVICES_WRITE]) +def update_device( + request: Request, + device_id: str, + payload: DeviceUpdatePayload, +) -> DeviceSchema: + device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {device_id} not found", + ) + + update_data = payload.model_dump(exclude_unset=True) + if update_data: + device = db_device_handler.update_device( + device_id=device_id, + user_id=request.user.id, + data=update_data, + ) + + return DeviceSchema.model_validate(device) + + +@protected_route( + router.delete, + "/{device_id}", + [Scope.DEVICES_WRITE], + status_code=status.HTTP_204_NO_CONTENT, +) +def delete_device(request: Request, device_id: str) -> None: + device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device with ID {device_id} not found", + ) + + db_device_handler.delete_device(device_id=device_id, user_id=request.user.id) + log.info(f"Deleted device {device_id} for user {request.user.username}") diff --git a/backend/endpoints/responses/assets.py b/backend/endpoints/responses/assets.py index 2256eb80f..04ad2803f 100644 --- a/backend/endpoints/responses/assets.py +++ b/backend/endpoints/responses/assets.py @@ -1,6 +1,12 @@ from datetime import datetime +from typing import Any + +from pydantic import model_validator +from sqlalchemy import inspect +from sqlalchemy.exc import InvalidRequestError from .base import BaseModel +from .device import DeviceSyncSchema class BaseAsset(BaseModel): @@ -31,7 +37,40 @@ class ScreenshotSchema(BaseAsset): class SaveSchema(BaseAsset): emulator: str | None + slot: str | None = None + content_hash: str | None = None screenshot: ScreenshotSchema | None + device_syncs: list[DeviceSyncSchema] = [] + + @model_validator(mode="before") + @classmethod + def handle_lazy_relationships(cls, data: Any) -> Any: + if isinstance(data, dict): + return data + try: + state = inspect(data) + except Exception: + return data + result = {} + for field_name in cls.model_fields: + if field_name in state.unloaded: + continue + try: + result[field_name] = getattr(data, field_name) + except InvalidRequestError: + continue + return result + + +class SlotSummarySchema(BaseModel): + slot: str | None + count: int + latest: SaveSchema + + +class SaveSummarySchema(BaseModel): + total_count: int + slots: list[SlotSummarySchema] class StateSchema(BaseAsset): diff --git a/backend/endpoints/responses/device.py b/backend/endpoints/responses/device.py new file mode 100644 index 000000000..cfed1f0ba --- /dev/null +++ b/backend/endpoints/responses/device.py @@ -0,0 +1,42 @@ +from datetime import datetime + +from models.device import SyncMode + +from .base import BaseModel + + +class DeviceSyncSchema(BaseModel): + device_id: str + 
device_name: str | None + last_synced_at: datetime + is_untracked: bool + is_current: bool + + class Config: + from_attributes = True + + +class DeviceSchema(BaseModel): + id: str + user_id: int + name: str | None + platform: str | None + client: str | None + client_version: str | None + ip_address: str | None + mac_address: str | None + hostname: str | None + sync_mode: SyncMode + sync_enabled: bool + last_seen: datetime | None + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class DeviceCreateResponse(BaseModel): + device_id: str + name: str | None + created_at: datetime diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index b45b2db5c..5f7e14c9c 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -1,21 +1,99 @@ +import os +import re from datetime import datetime, timezone from typing import Annotated from fastapi import Body, HTTPException, Request, UploadFile, status +from fastapi.responses import FileResponse from decorators.auth import protected_route -from endpoints.responses.assets import SaveSchema +from endpoints.responses.assets import SaveSchema, SaveSummarySchema, SlotSummarySchema +from endpoints.responses.device import DeviceSyncSchema from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException from handler.auth.constants import Scope -from handler.database import db_rom_handler, db_save_handler, db_screenshot_handler +from handler.database import ( + db_device_handler, + db_device_save_sync_handler, + db_rom_handler, + db_save_handler, + db_screenshot_handler, +) from handler.filesystem import fs_asset_handler from handler.scan_handler import scan_save, scan_screenshot from logger.formatter import BLUE from logger.formatter import highlight as hl from logger.logger import log from models.assets import Save +from models.device import Device +from models.device_save_sync import DeviceSaveSync +from utils.datetime import to_utc from utils.router import APIRouter + +def _build_save_schema( + save: Save, + device: Device | None = None, + sync: DeviceSaveSync | None = None, +) -> SaveSchema: + save_schema = SaveSchema.model_validate(save) + + if device: + if sync: + is_current = to_utc(sync.last_synced_at) >= to_utc(save.updated_at) + last_synced = sync.last_synced_at + is_untracked = sync.is_untracked + else: + is_current = False + last_synced = save.updated_at + is_untracked = False + + save_schema.device_syncs = [ + DeviceSyncSchema( + device_id=device.id, + device_name=device.name, + last_synced_at=last_synced, + is_untracked=is_untracked, + is_current=is_current, + ) + ] + + return save_schema + + +DATETIME_TAG_PATTERN = re.compile(r" \[\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\]") + + +def _apply_datetime_tag(filename: str) -> str: + name, ext = os.path.splitext(filename) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S") + + if DATETIME_TAG_PATTERN.search(name): + name = DATETIME_TAG_PATTERN.sub("", name) + + return f"{name} [{timestamp}]{ext}" + + +def _resolve_device( + device_id: str | None, + user_id: int, + scopes: set[str] | None = None, + required_scope: Scope | None = None, +) -> Device | None: + if not device_id: + return None + + if required_scope and scopes and required_scope not in scopes: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") + + device = db_device_handler.get_device(device_id=device_id, user_id=user_id) + if not device: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Device 
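# --- Editorial worked example (not part of the patch) -------------------------
# The slot filename tagging implemented by _apply_datetime_tag() above: a UTC
# timestamp tag is appended before the extension, and any existing tag is
# stripped first so repeated uploads never stack tags. Timestamps illustrative.
import os
import re

DATETIME_TAG_PATTERN = re.compile(r" \[\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\]")


def apply_tag(filename: str, timestamp: str) -> str:
    name, ext = os.path.splitext(filename)
    name = DATETIME_TAG_PATTERN.sub("", name)  # drop a previous tag, if any
    return f"{name} [{timestamp}]{ext}"


assert apply_tag("zelda.srm", "2026-01-17_09-30-00") == "zelda [2026-01-17_09-30-00].srm"
# Re-uploading an already-tagged file replaces the old tag instead of stacking:
assert (
    apply_tag("zelda [2026-01-17_09-30-00].srm", "2026-01-17_10-45-12")
    == "zelda [2026-01-17_10-45-12].srm"
)
# ------------------------------------------------------------------------------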
with ID {device_id} not found", + ) + return device + + router = APIRouter( prefix="/saves", tags=["saves"], @@ -27,22 +105,23 @@ async def add_save( request: Request, rom_id: int, emulator: str | None = None, + slot: str | None = None, + device_id: str | None = None, + overwrite: bool = False, + autocleanup: bool = False, + autocleanup_limit: int = 10, ) -> SaveSchema: + """Upload a save file for a ROM.""" + device = _resolve_device( + device_id, request.user.id, request.auth.scopes, Scope.DEVICES_WRITE + ) + data = await request.form() rom = db_rom_handler.get_rom(rom_id) if not rom: raise RomNotFoundInDatabaseException(rom_id) - log.info(f"Uploading save of {rom.name}") - - saves_path = fs_asset_handler.build_saves_file_path( - user=request.user, - platform_fs_slug=rom.platform.fs_slug, - rom_id=rom_id, - emulator=emulator, - ) - if "saveFile" not in data: log.error("No save file provided") raise HTTPException( @@ -57,12 +136,45 @@ async def add_save( status_code=status.HTTP_400_BAD_REQUEST, detail="Save file has no filename" ) - rom = db_rom_handler.get_rom(rom_id) - if not rom: - raise RomNotFoundInDatabaseException(rom_id) + actual_filename = saveFile.filename + if slot: + actual_filename = _apply_datetime_tag(saveFile.filename) + + db_save = db_save_handler.get_save_by_filename( + user_id=request.user.id, rom_id=rom.id, file_name=actual_filename + ) + + if device and slot and not overwrite: + slot_saves = db_save_handler.get_saves( + user_id=request.user.id, + rom_id=rom.id, + slot=slot, + order_by="updated_at", + ) + if slot_saves: + latest_in_slot = slot_saves[0] + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=latest_in_slot.id + ) + if not sync or to_utc(sync.last_synced_at) < to_utc( + latest_in_slot.updated_at + ): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Slot has a newer save since your last sync", + ) + elif device and db_save and not overwrite: + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=db_save.id + ) + if sync and to_utc(sync.last_synced_at) < to_utc(db_save.updated_at): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Save has been updated since your last sync", + ) log.info( - f"Uploading save {hl(saveFile.filename)} for {hl(str(rom.name), color=BLUE)}" + f"Uploading save {hl(actual_filename)} for {hl(str(rom.name), color=BLUE)}" ) saves_path = fs_asset_handler.build_saves_file_path( @@ -72,29 +184,72 @@ async def add_save( emulator=emulator, ) - await fs_asset_handler.write_file(file=saveFile, path=saves_path) + await fs_asset_handler.write_file( + file=saveFile, path=saves_path, filename=actual_filename + ) - # Scan or update save scanned_save = await scan_save( - file_name=saveFile.filename, + file_name=actual_filename, user=request.user, platform_fs_slug=rom.platform.fs_slug, rom_id=rom_id, emulator=emulator, ) - db_save = db_save_handler.get_save_by_filename( - user_id=request.user.id, rom_id=rom.id, file_name=saveFile.filename - ) - if db_save: - db_save = db_save_handler.update_save( - db_save.id, {"file_size_bytes": scanned_save.file_size_bytes} + + if slot and scanned_save.content_hash and not overwrite: + existing_by_hash = db_save_handler.get_save_by_content_hash( + user_id=request.user.id, + rom_id=rom.id, + content_hash=scanned_save.content_hash, ) + if existing_by_hash: + try: + await fs_asset_handler.remove_file(f"{saves_path}/{actual_filename}") + except FileNotFoundError: + pass + sync = None + if device: + sync = 
db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=existing_by_hash.id + ) + return _build_save_schema(existing_by_hash, device, sync) + + if db_save: + update_data: dict = { + "file_size_bytes": scanned_save.file_size_bytes, + "content_hash": scanned_save.content_hash, + } + if slot is not None: + update_data["slot"] = slot + db_save = db_save_handler.update_save(db_save.id, update_data) else: scanned_save.rom_id = rom.id scanned_save.user_id = request.user.id scanned_save.emulator = emulator + scanned_save.slot = slot db_save = db_save_handler.add_save(save=scanned_save) + if device: + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=db_save.id, synced_at=db_save.updated_at + ) + db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) + + if slot and autocleanup: + slot_saves = db_save_handler.get_saves( + user_id=request.user.id, + rom_id=rom.id, + slot=slot, + order_by="updated_at", + ) + if len(slot_saves) > autocleanup_limit: + for old_save in slot_saves[autocleanup_limit:]: + db_save_handler.delete_save(old_save.id) + try: + await fs_asset_handler.remove_file(old_save.full_path) + except FileNotFoundError: + log.warning(f"Could not delete old save file: {old_save.full_path}") + screenshotFile: UploadFile | None = data.get("screenshotFile", None) # type: ignore if screenshotFile and screenshotFile.filename: screenshots_path = fs_asset_handler.build_screenshots_file_path( @@ -103,7 +258,6 @@ async def add_save( await fs_asset_handler.write_file(file=screenshotFile, path=screenshots_path) - # Scan or update screenshot scanned_screenshot = await scan_screenshot( file_name=screenshotFile.filename, user=request.user, @@ -125,7 +279,6 @@ async def add_save( scanned_screenshot.user_id = request.user.id db_screenshot_handler.add_screenshot(screenshot=scanned_screenshot) - # Set the last played time for the current user rom_user = db_rom_handler.get_rom_user(rom_id=rom.id, user_id=request.user.id) if not rom_user: rom_user = db_rom_handler.add_rom_user(rom_id=rom.id, user_id=request.user.id) @@ -133,37 +286,47 @@ async def add_save( rom_user.id, {"last_played": datetime.now(timezone.utc)} ) - # Refetch the rom to get updated saves - rom = db_rom_handler.get_rom(rom_id) - if not rom: - raise RomNotFoundInDatabaseException(rom_id) - - return SaveSchema.model_validate(db_save) + sync = None + if device: + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=db_save.id + ) + return _build_save_schema(db_save, device, sync) @protected_route(router.get, "", [Scope.ASSETS_READ]) def get_saves( - request: Request, rom_id: int | None = None, platform_id: int | None = None + request: Request, + rom_id: int | None = None, + platform_id: int | None = None, + device_id: str | None = None, + slot: str | None = None, ) -> list[SaveSchema]: + """Retrieve saves for the current user.""" + device = _resolve_device( + device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ + ) + saves = db_save_handler.get_saves( - user_id=request.user.id, rom_id=rom_id, platform_id=platform_id + user_id=request.user.id, rom_id=rom_id, platform_id=platform_id, slot=slot ) - return [SaveSchema.model_validate(save) for save in saves] + if not device: + return [_build_save_schema(save) for save in saves] + syncs = db_device_save_sync_handler.get_syncs_for_device_and_saves( + device_id=device.id, save_ids=[s.id for s in saves] + ) + sync_by_save_id = {s.save_id: s for s in syncs} -@protected_route(router.get, "/identifiers", 
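# --- Editorial sketch (not part of the patch) ---------------------------------
# Driving the add_save upload above from a client, including the slot/device
# conflict protocol: the server answers 409 when the slot gained a newer save
# since this device last synced; overwrite=True bypasses the check. URL and
# token are assumptions.
import requests

BASE_URL = "http://localhost:8080/api"  # assumed
TOKEN = "<access-token>"  # assumed


def push_save(rom_id: int, device_id: str, path: str, overwrite: bool = False) -> dict:
    with open(path, "rb") as fh:
        resp = requests.post(
            f"{BASE_URL}/saves",
            params={
                "rom_id": rom_id,
                "slot": "slot1",
                "device_id": device_id,
                "overwrite": overwrite,
                "autocleanup": True,  # trim the slot to autocleanup_limit (10) saves
            },
            files={"saveFile": fh},
            headers={"Authorization": f"Bearer {TOKEN}"},
            timeout=30,
        )
    if resp.status_code == 409:
        # Another device pushed to this slot first. A client would typically
        # pull the newer save, or retry with overwrite=True if the user
        # explicitly chooses to clobber it.
        raise RuntimeError(resp.json()["detail"])
    resp.raise_for_status()
    return resp.json()  # SaveSchema, with device_syncs for this device
# ------------------------------------------------------------------------------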
[Scope.ASSETS_READ]) -def get_save_identifiers( - request: Request, -) -> list[int]: - """Get save identifiers endpoint + return [ + _build_save_schema(save, device, sync_by_save_id.get(save.id)) for save in saves + ] - Args: - request (Request): Fastapi Request object - Returns: - list[int]: List of save IDs - """ +@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ]) +def get_save_identifiers(request: Request) -> list[int]: + """Retrieve save identifiers.""" saves = db_save_handler.get_saves( user_id=request.user.id, only_fields=[Save.id], @@ -172,20 +335,121 @@ def get_save_identifiers( return [save.id for save in saves] +@protected_route(router.get, "/summary", [Scope.ASSETS_READ]) +def get_saves_summary(request: Request, rom_id: int) -> SaveSummarySchema: + """Retrieve saves summary grouped by slot.""" + summary_data = db_save_handler.get_saves_summary( + user_id=request.user.id, rom_id=rom_id + ) + + slots = [ + SlotSummarySchema( + slot=slot_data["slot"], + count=slot_data["count"], + latest=_build_save_schema(slot_data["latest"]), + ) + for slot_data in summary_data["slots"] + ] + + return SaveSummarySchema(total_count=summary_data["total_count"], slots=slots) + + @protected_route(router.get, "/{id}", [Scope.ASSETS_READ]) -def get_save(request: Request, id: int) -> SaveSchema: +def get_save(request: Request, id: int, device_id: str | None = None) -> SaveSchema: + """Retrieve a save by ID.""" + device = _resolve_device( + device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ + ) + save = db_save_handler.get_save(user_id=request.user.id, id=id) + if not save: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Save with ID {id} not found", + ) + + sync = None + if device: + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + return _build_save_schema(save, device, sync) + + +@protected_route(router.get, "/{id}/content", [Scope.ASSETS_READ]) +def download_save( + request: Request, + id: int, + device_id: str | None = None, + optimistic: bool = True, +) -> FileResponse: + """Download a save file.""" + device = _resolve_device( + device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ + ) + save = db_save_handler.get_save(user_id=request.user.id, id=id) if not save: - error = f"Save with ID {id} not found" - log.error(error) - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Save with ID {id} not found", + ) + + try: + file_path = fs_asset_handler.validate_path(save.full_path) + except ValueError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Save file not found", + ) from None + + if not file_path.exists() or not file_path.is_file(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Save file not found on disk", + ) + + if device and optimistic: + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=save.id, + synced_at=save.updated_at, + ) + db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id) - return SaveSchema.model_validate(save) + return FileResponse(path=str(file_path), filename=save.file_name) + + +@protected_route(router.post, "/{id}/downloaded", [Scope.DEVICES_WRITE]) +def confirm_download( + request: Request, + id: int, + device_id: str = Body(..., embed=True), +) -> SaveSchema: + """Confirm a save was downloaded successfully.""" + save = 
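# --- Editorial sketch (not part of the patch) ---------------------------------
# The two-phase pull enabled by the content endpoint above: with
# optimistic=False the server does not record a sync at download time, so the
# client confirms via POST /saves/{id}/downloaded (the endpoint that follows)
# only once the file is safely on disk. URL and token are assumptions.
import requests

BASE_URL = "http://localhost:8080/api"  # assumed
HEADERS = {"Authorization": "Bearer <access-token>"}  # assumed


def pull_save(save_id: int, device_id: str, dest: str) -> None:
    resp = requests.get(
        f"{BASE_URL}/saves/{save_id}/content",
        params={"device_id": device_id, "optimistic": False},
        headers=HEADERS,
        timeout=30,
    )
    resp.raise_for_status()
    with open(dest, "wb") as fh:
        fh.write(resp.content)
    # Only now record the sync point: a crash mid-write can no longer leave
    # the server believing this device already holds the latest copy.
    requests.post(
        f"{BASE_URL}/saves/{save_id}/downloaded",
        json={"device_id": device_id},
        headers=HEADERS,
        timeout=10,
    ).raise_for_status()
# ------------------------------------------------------------------------------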
db_save_handler.get_save(user_id=request.user.id, id=id) + if not save: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Save with ID {id} not found", + ) + + device = _resolve_device(device_id, request.user.id) + assert device is not None + + sync = db_device_save_sync_handler.upsert_sync( + device_id=device_id, + save_id=save.id, + synced_at=save.updated_at, + ) + db_device_handler.update_last_seen(device_id=device_id, user_id=request.user.id) + + return _build_save_schema(save, device, sync) @protected_route(router.put, "/{id}", [Scope.ASSETS_WRITE]) async def update_save(request: Request, id: int) -> SaveSchema: + """Update a save file.""" data = await request.form() db_save = db_save_handler.get_save(user_id=request.user.id, id=id) @@ -300,3 +564,51 @@ async def delete_saves( log.error(error) return saves + + +@protected_route(router.post, "/{id}/track", [Scope.DEVICES_WRITE]) +def track_save( + request: Request, + id: int, + device_id: str = Body(..., embed=True), +) -> SaveSchema: + """Re-enable sync tracking for a save on a device.""" + save = db_save_handler.get_save(user_id=request.user.id, id=id) + if not save: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Save with ID {id} not found", + ) + + device = _resolve_device(device_id, request.user.id) + assert device is not None + + sync = db_device_save_sync_handler.set_untracked( + device_id=device_id, save_id=id, untracked=False + ) + + return _build_save_schema(save, device, sync) + + +@protected_route(router.post, "/{id}/untrack", [Scope.DEVICES_WRITE]) +def untrack_save( + request: Request, + id: int, + device_id: str = Body(..., embed=True), +) -> SaveSchema: + """Disable sync tracking for a save on a device.""" + save = db_save_handler.get_save(user_id=request.user.id, id=id) + if not save: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Save with ID {id} not found", + ) + + device = _resolve_device(device_id, request.user.id) + assert device is not None + + sync = db_device_save_sync_handler.set_untracked( + device_id=device_id, save_id=id, untracked=True + ) + + return _build_save_schema(save, device, sync) diff --git a/backend/handler/auth/constants.py b/backend/handler/auth/constants.py index a22681017..e63df5a4e 100644 --- a/backend/handler/auth/constants.py +++ b/backend/handler/auth/constants.py @@ -17,6 +17,8 @@ class Scope(enum.StrEnum): PLATFORMS_WRITE = "platforms.write" ASSETS_READ = "assets.read" ASSETS_WRITE = "assets.write" + DEVICES_READ = "devices.read" + DEVICES_WRITE = "devices.write" FIRMWARE_READ = "firmware.read" FIRMWARE_WRITE = "firmware.write" COLLECTIONS_READ = "collections.read" @@ -31,6 +33,7 @@ class Scope(enum.StrEnum): Scope.ROMS_READ: "View ROMs", Scope.PLATFORMS_READ: "View platforms", Scope.ASSETS_READ: "View assets", + Scope.DEVICES_READ: "View devices", Scope.FIRMWARE_READ: "View firmware", Scope.ROMS_USER_READ: "View user-rom properties", Scope.COLLECTIONS_READ: "View collections", @@ -39,6 +42,7 @@ class Scope(enum.StrEnum): WRITE_SCOPES_MAP: Final = { Scope.ME_WRITE: "Modify your profile", Scope.ASSETS_WRITE: "Modify assets", + Scope.DEVICES_WRITE: "Modify devices", Scope.ROMS_USER_WRITE: "Modify user-rom properties", Scope.COLLECTIONS_WRITE: "Modify collections", } diff --git a/backend/handler/database/__init__.py b/backend/handler/database/__init__.py index 65816f359..a7f1a22d1 100644 --- a/backend/handler/database/__init__.py +++ b/backend/handler/database/__init__.py @@ -1,4 +1,6 @@ from 
.collections_handler import DBCollectionsHandler +from .device_save_sync_handler import DBDeviceSaveSyncHandler +from .devices_handler import DBDevicesHandler from .firmware_handler import DBFirmwareHandler from .platforms_handler import DBPlatformsHandler from .roms_handler import DBRomsHandler @@ -8,6 +10,9 @@ from .stats_handler import DBStatsHandler from .users_handler import DBUsersHandler +db_collection_handler = DBCollectionsHandler() +db_device_handler = DBDevicesHandler() +db_device_save_sync_handler = DBDeviceSaveSyncHandler() db_firmware_handler = DBFirmwareHandler() db_platform_handler = DBPlatformsHandler() db_rom_handler = DBRomsHandler() @@ -16,4 +21,3 @@ db_state_handler = DBStatesHandler() db_stats_handler = DBStatsHandler() db_user_handler = DBUsersHandler() -db_collection_handler = DBCollectionsHandler() diff --git a/backend/handler/database/device_save_sync_handler.py b/backend/handler/database/device_save_sync_handler.py new file mode 100644 index 000000000..576acdedc --- /dev/null +++ b/backend/handler/database/device_save_sync_handler.py @@ -0,0 +1,129 @@ +from collections.abc import Sequence +from datetime import datetime, timezone + +from sqlalchemy import delete, select, update +from sqlalchemy.orm import Session + +from decorators.database import begin_session +from models.device_save_sync import DeviceSaveSync + +from .base_handler import DBBaseHandler + + +class DBDeviceSaveSyncHandler(DBBaseHandler): + @begin_session + def get_sync( + self, + device_id: str, + save_id: int, + session: Session = None, # type: ignore + ) -> DeviceSaveSync | None: + return session.scalar( + select(DeviceSaveSync) + .filter_by(device_id=device_id, save_id=save_id) + .limit(1) + ) + + @begin_session + def get_syncs_for_device_and_saves( + self, + device_id: str, + save_ids: list[int], + session: Session = None, # type: ignore + ) -> Sequence[DeviceSaveSync]: + if not save_ids: + return [] + return session.scalars( + select(DeviceSaveSync).filter( + DeviceSaveSync.device_id == device_id, + DeviceSaveSync.save_id.in_(save_ids), + ) + ).all() + + @begin_session + def upsert_sync( + self, + device_id: str, + save_id: int, + synced_at: datetime | None = None, + session: Session = None, # type: ignore + ) -> DeviceSaveSync: + now = synced_at or datetime.now(timezone.utc) + existing = session.scalar( + select(DeviceSaveSync) + .filter_by(device_id=device_id, save_id=save_id) + .limit(1) + ) + if existing: + session.execute( + update(DeviceSaveSync) + .where( + DeviceSaveSync.device_id == device_id, + DeviceSaveSync.save_id == save_id, + ) + .values(last_synced_at=now, is_untracked=False) + .execution_options(synchronize_session="evaluate") + ) + existing.last_synced_at = now + existing.is_untracked = False + return existing + else: + sync = DeviceSaveSync( + device_id=device_id, + save_id=save_id, + last_synced_at=now, + is_untracked=False, + ) + session.add(sync) + session.flush() + return sync + + @begin_session + def set_untracked( + self, + device_id: str, + save_id: int, + untracked: bool, + session: Session = None, # type: ignore + ) -> DeviceSaveSync | None: + existing = session.scalar( + select(DeviceSaveSync) + .filter_by(device_id=device_id, save_id=save_id) + .limit(1) + ) + if existing: + session.execute( + update(DeviceSaveSync) + .where( + DeviceSaveSync.device_id == device_id, + DeviceSaveSync.save_id == save_id, + ) + .values(is_untracked=untracked) + .execution_options(synchronize_session="evaluate") + ) + existing.is_untracked = untracked + return existing + elif 
untracked: + now = datetime.now(timezone.utc) + sync = DeviceSaveSync( + device_id=device_id, + save_id=save_id, + last_synced_at=now, + is_untracked=True, + ) + session.add(sync) + session.flush() + return sync + return None + + @begin_session + def delete_syncs_for_device( + self, + device_id: str, + session: Session = None, # type: ignore + ) -> None: + session.execute( + delete(DeviceSaveSync) + .where(DeviceSaveSync.device_id == device_id) + .execution_options(synchronize_session="evaluate") + ) diff --git a/backend/handler/database/devices_handler.py b/backend/handler/database/devices_handler.py new file mode 100644 index 000000000..81318aa9a --- /dev/null +++ b/backend/handler/database/devices_handler.py @@ -0,0 +1,111 @@ +from collections.abc import Sequence +from datetime import datetime, timezone + +from sqlalchemy import delete, select, update +from sqlalchemy.orm import Session + +from decorators.database import begin_session +from models.device import Device + +from .base_handler import DBBaseHandler + + +class DBDevicesHandler(DBBaseHandler): + @begin_session + def add_device( + self, + device: Device, + session: Session = None, # type: ignore + ) -> Device: + return session.merge(device) + + @begin_session + def get_device( + self, + device_id: str, + user_id: int, + session: Session = None, # type: ignore + ) -> Device | None: + return session.scalar( + select(Device).filter_by(id=device_id, user_id=user_id).limit(1) + ) + + @begin_session + def get_device_by_fingerprint( + self, + user_id: int, + mac_address: str | None = None, + hostname: str | None = None, + platform: str | None = None, + session: Session = None, # type: ignore + ) -> Device | None: + if mac_address: + device = session.scalar( + select(Device) + .filter_by(user_id=user_id, mac_address=mac_address) + .limit(1) + ) + if device: + return device + + if hostname and platform: + return session.scalar( + select(Device) + .filter_by(user_id=user_id, hostname=hostname, platform=platform) + .limit(1) + ) + + return None + + @begin_session + def get_devices( + self, + user_id: int, + session: Session = None, # type: ignore + ) -> Sequence[Device]: + return session.scalars(select(Device).filter_by(user_id=user_id)).all() + + @begin_session + def update_device( + self, + device_id: str, + user_id: int, + data: dict, + session: Session = None, # type: ignore + ) -> Device | None: + session.execute( + update(Device) + .where(Device.id == device_id, Device.user_id == user_id) + .values(**data) + .execution_options(synchronize_session="evaluate") + ) + return session.scalar( + select(Device).filter_by(id=device_id, user_id=user_id).limit(1) + ) + + @begin_session + def update_last_seen( + self, + device_id: str, + user_id: int, + session: Session = None, # type: ignore + ) -> None: + session.execute( + update(Device) + .where(Device.id == device_id, Device.user_id == user_id) + .values(last_seen=datetime.now(timezone.utc)) + .execution_options(synchronize_session="evaluate") + ) + + @begin_session + def delete_device( + self, + device_id: str, + user_id: int, + session: Session = None, # type: ignore + ) -> None: + session.execute( + delete(Device) + .where(Device.id == device_id, Device.user_id == user_id) + .execution_options(synchronize_session="evaluate") + ) diff --git a/backend/handler/database/saves_handler.py b/backend/handler/database/saves_handler.py index 4a06ffeb9..66b2a6eba 100644 --- a/backend/handler/database/saves_handler.py +++ b/backend/handler/database/saves_handler.py @@ -1,6 +1,7 @@ from 
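# --- Editorial sketch (not part of the patch) ---------------------------------
# The sync-record state machine implemented by DBDeviceSaveSyncHandler above
# (runs against an initialized database, as in this diff's tests). Two
# grounded behaviors worth calling out:
#   - set_untracked(True) on a never-synced save creates a placeholder row,
#     so "ignore this save on this device" survives before any first sync;
#   - upsert_sync() always clears is_untracked, so a successful push or pull
#     automatically re-tracks the save.
from handler.database import db_device_save_sync_handler

sync = db_device_save_sync_handler.set_untracked(
    device_id="device-1", save_id=42, untracked=True  # ids are illustrative
)
assert sync is not None and sync.is_untracked is True

sync = db_device_save_sync_handler.upsert_sync(device_id="device-1", save_id=42)
assert sync.is_untracked is False  # syncing re-tracks automatically
# ------------------------------------------------------------------------------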
collections.abc import Sequence +from typing import Literal -from sqlalchemy import and_, delete, select, update +from sqlalchemy import and_, asc, delete, desc, select, update from sqlalchemy.orm import QueryableAttribute, Session, load_only from decorators.database import begin_session @@ -42,12 +43,29 @@ def get_save_by_filename( .limit(1) ).first() + @begin_session + def get_save_by_content_hash( + self, + user_id: int, + rom_id: int, + content_hash: str, + session: Session = None, # type: ignore + ) -> Save | None: + return session.scalar( + select(Save) + .filter_by(rom_id=rom_id, user_id=user_id, content_hash=content_hash) + .limit(1) + ) + @begin_session def get_saves( self, user_id: int, rom_id: int | None = None, platform_id: int | None = None, + slot: str | None = None, + order_by: Literal["updated_at", "created_at"] | None = None, + order_dir: Literal["asc", "desc"] = "desc", only_fields: Sequence[QueryableAttribute] | None = None, session: Session = None, # type: ignore ) -> Sequence[Save]: @@ -61,6 +79,14 @@ def get_saves( Rom.platform_id == platform_id ) + if slot is not None: + query = query.filter(Save.slot == slot) + + if order_by: + order_col = getattr(Save, order_by) + order_fn = asc if order_dir == "asc" else desc + query = query.order_by(order_fn(order_col)) + if only_fields: query = query.options(load_only(*only_fields)) @@ -125,3 +151,28 @@ def mark_missing_saves( ) return missing_saves + + @begin_session + def get_saves_summary( + self, + user_id: int, + rom_id: int, + session: Session = None, # type: ignore + ) -> dict: + saves = session.scalars( + select(Save) + .filter_by(user_id=user_id, rom_id=rom_id) + .order_by(desc(Save.updated_at)) + ).all() + + slots_data: dict[str | None, dict] = {} + for save in saves: + slot_key = save.slot + if slot_key not in slots_data: + slots_data[slot_key] = {"slot": slot_key, "count": 0, "latest": save} + slots_data[slot_key]["count"] += 1 + + return { + "total_count": len(saves), + "slots": list(slots_data.values()), + } diff --git a/backend/handler/filesystem/assets_handler.py b/backend/handler/filesystem/assets_handler.py index 8d8466e92..fa1114f60 100644 --- a/backend/handler/filesystem/assets_handler.py +++ b/backend/handler/filesystem/assets_handler.py @@ -1,11 +1,44 @@ +import hashlib import os +import zipfile from config import ASSETS_BASE_PATH +from logger.logger import log from models.user import User from .base_handler import FSHandler +def compute_file_hash(file_path: str) -> str: + hash_obj = hashlib.md5(usedforsecurity=False) + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + hash_obj.update(chunk) + return hash_obj.hexdigest() + + +def compute_zip_hash(zip_path: str) -> str: + with zipfile.ZipFile(zip_path, "r") as zf: + file_hashes = [] + for name in sorted(zf.namelist()): + if not name.endswith("/"): + content = zf.read(name) + file_hash = hashlib.md5(content, usedforsecurity=False).hexdigest() + file_hashes.append(f"{name}:{file_hash}") + combined = "\n".join(file_hashes) + return hashlib.md5(combined.encode(), usedforsecurity=False).hexdigest() + + +def compute_content_hash(file_path: str) -> str | None: + try: + if zipfile.is_zipfile(file_path): + return compute_zip_hash(file_path) + return compute_file_hash(file_path) + except Exception as e: + log.debug(f"Failed to compute content hash for {file_path}: {e}") + return None + + class FSAssetsHandler(FSHandler): def __init__(self) -> None: super().__init__(base_path=ASSETS_BASE_PATH) diff --git 
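# --- Editorial sketch (not part of the patch) ---------------------------------
# Why compute_zip_hash() above hashes zip *entries* rather than raw archive
# bytes: two zips with identical contents but different entry timestamps
# differ byte-for-byte, yet hash equal under the per-entry scheme, so a
# re-zipped but unchanged save still deduplicates by content_hash.
import hashlib
import io
import zipfile


def make_zip(mtime: tuple) -> bytes:
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr(zipfile.ZipInfo("game.srm", date_time=mtime), b"SAVEDATA")
    return buf.getvalue()


def zip_content_digest(blob: bytes) -> str:
    # Mirrors compute_zip_hash(): hash sorted entry names plus entry contents.
    with zipfile.ZipFile(io.BytesIO(blob)) as zf:
        parts = [
            f"{n}:{hashlib.md5(zf.read(n), usedforsecurity=False).hexdigest()}"
            for n in sorted(zf.namelist())
            if not n.endswith("/")
        ]
    return hashlib.md5("\n".join(parts).encode(), usedforsecurity=False).hexdigest()


early = make_zip((2026, 1, 1, 0, 0, 0))
late = make_zip((2026, 6, 1, 0, 0, 0))
assert early != late  # raw archive bytes differ (entry timestamps)
assert zip_content_digest(early) == zip_content_digest(late)  # content agrees
# ------------------------------------------------------------------------------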
a/backend/handler/scan_handler.py b/backend/handler/scan_handler.py index 6e1c7b26b..02deb304f 100644 --- a/backend/handler/scan_handler.py +++ b/backend/handler/scan_handler.py @@ -4,10 +4,12 @@ import socketio # type: ignore +from config import ASSETS_BASE_PATH from config.config_manager import config_manager as cm from endpoints.responses.rom import SimpleRomSchema from handler.database import db_platform_handler, db_rom_handler from handler.filesystem import fs_asset_handler, fs_firmware_handler +from handler.filesystem.assets_handler import compute_content_hash from handler.filesystem.roms_handler import FSRom from handler.metadata import ( meta_flashpoint_handler, @@ -817,11 +819,11 @@ async def fetch_sgdb_details() -> SGDBRom: return Rom(**rom_attrs) -async def _scan_asset(file_name: str, asset_path: str): +async def _scan_asset(file_name: str, asset_path: str, should_hash: bool = False): file_path = f"{asset_path}/{file_name}" file_size = await fs_asset_handler.get_file_size(file_path) - return { + result = { "file_path": asset_path, "file_name": file_name, "file_name_no_tags": fs_asset_handler.get_file_name_with_no_tags(file_name), @@ -830,6 +832,12 @@ async def _scan_asset(file_name: str, asset_path: str): "file_size_bytes": file_size, } + if should_hash: + absolute_path = f"{ASSETS_BASE_PATH}/{file_path}" + result["content_hash"] = compute_content_hash(absolute_path) + + return result + async def scan_save( file_name: str, @@ -841,7 +849,7 @@ async def scan_save( saves_path = fs_asset_handler.build_saves_file_path( user=user, platform_fs_slug=platform_fs_slug, rom_id=rom_id, emulator=emulator ) - scanned_asset = await _scan_asset(file_name, saves_path) + scanned_asset = await _scan_asset(file_name, saves_path, should_hash=True) return Save(**scanned_asset) diff --git a/backend/main.py b/backend/main.py index 1ea55e139..8dfb20b51 100644 --- a/backend/main.py +++ b/backend/main.py @@ -28,6 +28,7 @@ auth, collections, configs, + device, feeds, firmware, gamelist, @@ -122,6 +123,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]: app.include_router(heartbeat.router, prefix="/api") app.include_router(auth.router, prefix="/api") app.include_router(user.router, prefix="/api") +app.include_router(device.router, prefix="/api") app.include_router(platform.router, prefix="/api") app.include_router(rom.router, prefix="/api") app.include_router(search.router, prefix="/api") diff --git a/backend/models/assets.py b/backend/models/assets.py index 06eb8d0a1..4311c069b 100644 --- a/backend/models/assets.py +++ b/backend/models/assets.py @@ -14,6 +14,7 @@ ) if TYPE_CHECKING: + from models.device_save_sync import DeviceSaveSync from models.rom import Rom from models.user import User @@ -54,9 +55,16 @@ class Save(RomAsset): __table_args__ = {"extend_existing": True} emulator: Mapped[str | None] = mapped_column(String(length=50)) + slot: Mapped[str | None] = mapped_column(String(length=255)) + content_hash: Mapped[str | None] = mapped_column(String(length=32)) rom: Mapped[Rom] = relationship(lazy="joined", back_populates="saves") user: Mapped[User] = relationship(lazy="joined", back_populates="saves") + device_syncs: Mapped[list[DeviceSaveSync]] = relationship( + back_populates="save", + cascade="all, delete-orphan", + lazy="raise", + ) @cached_property def screenshot(self) -> Screenshot | None: diff --git a/backend/models/device.py b/backend/models/device.py new file mode 100644 index 000000000..15d24febc --- /dev/null +++ b/backend/models/device.py @@ -0,0 +1,49 @@ +from __future__ 
import annotations
+
+import enum
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from sqlalchemy import TIMESTAMP, Boolean, Enum, ForeignKey, String
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from models.base import BaseModel
+
+if TYPE_CHECKING:
+    from models.device_save_sync import DeviceSaveSync
+    from models.user import User
+
+
+class SyncMode(enum.StrEnum):
+    API = "api"
+    FILE_TRANSFER = "file_transfer"
+    PUSH_PULL = "push_pull"
+
+
+class Device(BaseModel):
+    __tablename__ = "devices"
+    __table_args__ = {"extend_existing": True}
+
+    id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
+
+    name: Mapped[str | None] = mapped_column(String(255))
+    platform: Mapped[str | None] = mapped_column(String(50))
+    client: Mapped[str | None] = mapped_column(String(50))
+    client_version: Mapped[str | None] = mapped_column(String(50))
+
+    ip_address: Mapped[str | None] = mapped_column(String(45))
+    mac_address: Mapped[str | None] = mapped_column(String(17))
+    hostname: Mapped[str | None] = mapped_column(String(255))
+
+    sync_mode: Mapped[SyncMode] = mapped_column(Enum(SyncMode), default=SyncMode.API)
+    sync_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
+
+    last_seen: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))
+
+    user: Mapped[User] = relationship(lazy="joined")
+    save_syncs: Mapped[list[DeviceSaveSync]] = relationship(
+        back_populates="device",
+        cascade="all, delete-orphan",
+        lazy="raise",
+    )
diff --git a/backend/models/device_save_sync.py b/backend/models/device_save_sync.py
new file mode 100644
index 000000000..7df0b3fa5
--- /dev/null
+++ b/backend/models/device_save_sync.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from sqlalchemy import TIMESTAMP, Boolean, ForeignKey, String
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from models.base import BaseModel
+
+if TYPE_CHECKING:
+    from models.assets import Save
+    from models.device import Device
+
+
+class DeviceSaveSync(BaseModel):
+    __tablename__ = "device_save_sync"
+    __table_args__ = {"extend_existing": True}
+
+    device_id: Mapped[str] = mapped_column(
+        String(255),
+        ForeignKey("devices.id", ondelete="CASCADE"),
+        primary_key=True,
+    )
+    save_id: Mapped[int] = mapped_column(
+        ForeignKey("saves.id", ondelete="CASCADE"),
+        primary_key=True,
+    )
+
+    last_synced_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True))
+    is_untracked: Mapped[bool] = mapped_column(Boolean, default=False)
+
+    device: Mapped[Device] = relationship(back_populates="save_syncs", lazy="raise")
+    save: Mapped[Save] = relationship(back_populates="device_syncs", lazy="raise")
diff --git a/backend/models/user.py b/backend/models/user.py
index e980ec4d4..2643f1656 100644
--- a/backend/models/user.py
+++ b/backend/models/user.py
@@ -22,6 +22,7 @@
 if TYPE_CHECKING:
     from models.assets import Save, Screenshot, State
     from models.collection import Collection, SmartCollection
+    from models.device import Device
     from models.rom import RomNote, RomUser
 
 
@@ -79,6 +80,9 @@ class User(BaseModel, SimpleUser):
     smart_collections: Mapped[list["SmartCollection"]] = relationship(
         lazy="raise", back_populates="user"
     )
+    devices: Mapped[list["Device"]] = relationship(
+        lazy="raise", back_populates="user", cascade="all, delete-orphan"
+    )
 
     @classmethod
     def kiosk_mode_user(cls) -> User:
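# --- Editorial sketch (not part of the patch) ---------------------------------
# The relationships above are declared lazy="raise": touching
# save.device_syncs outside an explicit eager load raises rather than
# silently issuing N+1 queries. Callers opt in per query, and the
# SaveSchema.handle_lazy_relationships validator earlier in this diff is the
# complementary half, omitting device_syncs when it was never loaded.
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from models.assets import Save


def saves_with_device_syncs(session: Session, user_id: int) -> list[Save]:
    # selectinload fetches the sync rows in one extra query, satisfying the
    # lazy="raise" guard before the response schema serializes them.
    stmt = (
        select(Save)
        .where(Save.user_id == user_id)
        .options(selectinload(Save.device_syncs))
    )
    return list(session.scalars(stmt).all())
# ------------------------------------------------------------------------------
diff --git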
a/backend/tests/conftest.py b/backend/tests/conftest.py index b2762387a..0f7f19f4b 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -14,6 +14,8 @@ db_user_handler, ) from models.assets import Save, Screenshot, State +from models.device import Device +from models.device_save_sync import DeviceSaveSync from models.platform import Platform from models.rom import Rom from models.user import Role, User @@ -30,6 +32,8 @@ def setup_database(): @pytest.fixture(autouse=True) def clear_database(): with session.begin() as s: + s.query(DeviceSaveSync).delete(synchronize_session="evaluate") + s.query(Device).delete(synchronize_session="evaluate") s.query(Save).delete(synchronize_session="evaluate") s.query(State).delete(synchronize_session="evaluate") s.query(Screenshot).delete(synchronize_session="evaluate") diff --git a/backend/tests/endpoints/test_device.py b/backend/tests/endpoints/test_device.py new file mode 100644 index 000000000..9cc8de583 --- /dev/null +++ b/backend/tests/endpoints/test_device.py @@ -0,0 +1,509 @@ +from datetime import timedelta + +import pytest +from fastapi import status +from fastapi.testclient import TestClient +from main import app + +from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES +from handler.auth import oauth_handler +from handler.database import db_device_handler +from handler.redis_handler import sync_cache +from models.device import Device +from models.user import User + + +@pytest.fixture +def client(): + with TestClient(app) as client: + yield client + + +@pytest.fixture(autouse=True) +def clear_cache(): + yield + sync_cache.flushall() + + +@pytest.fixture +def editor_access_token(editor_user: User): + return oauth_handler.create_oauth_token( + data={ + "sub": editor_user.username, + "iss": "romm:oauth", + "scopes": " ".join(editor_user.oauth_scopes), + "type": "access", + }, + expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), + ) + + +class TestDeviceEndpoints: + def test_register_device(self, client, access_token: str): + response = client.post( + "/api/devices", + json={ + "name": "Test Device", + "platform": "android", + "client": "argosy", + "client_version": "0.16.0", + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "Test Device" + assert "device_id" in data + assert "created_at" in data + + def test_register_device_minimal(self, client, access_token: str): + response = client.post( + "/api/devices", + json={}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] is None + assert "device_id" in data + + def test_list_devices(self, client, access_token: str, admin_user: User): + + db_device_handler.add_device( + Device( + id="test-device-1", + user_id=admin_user.id, + name="Device 1", + ) + ) + db_device_handler.add_device( + Device( + id="test-device-2", + user_id=admin_user.id, + name="Device 2", + ) + ) + + response = client.get( + "/api/devices", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 2 + names = [d["name"] for d in data] + assert "Device 1" in names + assert "Device 2" in names + + def test_get_device(self, client, access_token: str, admin_user: User): + + device = db_device_handler.add_device( + Device( + id="test-device-get", + user_id=admin_user.id, + 
name="Get Test Device", + platform="linux", + ) + ) + + response = client.get( + f"/api/devices/{device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "test-device-get" + assert data["name"] == "Get Test Device" + assert data["platform"] == "linux" + + def test_get_device_not_found(self, client, access_token: str): + response = client.get( + "/api/devices/nonexistent-device", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_update_device(self, client, access_token: str, admin_user: User): + device = db_device_handler.add_device( + Device( + id="test-device-update", + user_id=admin_user.id, + name="Original Name", + ) + ) + + response = client.put( + f"/api/devices/{device.id}", + json={ + "name": "Updated Name", + "platform": "android", + "client": "daijishou", + "client_version": "4.0.0", + "ip_address": "192.168.1.100", + "mac_address": "AA:BB:CC:DD:EE:FF", + "hostname": "my-odin3", + "sync_enabled": False, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Updated Name" + assert data["platform"] == "android" + assert data["client"] == "daijishou" + assert data["client_version"] == "4.0.0" + assert data["ip_address"] == "192.168.1.100" + assert data["mac_address"] == "AA:BB:CC:DD:EE:FF" + assert data["hostname"] == "my-odin3" + assert data["sync_enabled"] is False + + def test_delete_device(self, client, access_token: str, admin_user: User): + + device = db_device_handler.add_device( + Device( + id="test-device-delete", + user_id=admin_user.id, + name="To Delete", + ) + ) + + response = client.delete( + f"/api/devices/{device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + + get_response = client.get( + f"/api/devices/{device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert get_response.status_code == status.HTTP_404_NOT_FOUND + + +class TestDeviceUserIsolation: + def test_list_devices_only_returns_own_devices( + self, + client, + access_token: str, + editor_access_token: str, + admin_user: User, + editor_user: User, + ): + db_device_handler.add_device( + Device(id="admin-device", user_id=admin_user.id, name="Admin Device") + ) + db_device_handler.add_device( + Device(id="editor-device", user_id=editor_user.id, name="Editor Device") + ) + + admin_response = client.get( + "/api/devices", + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert admin_response.status_code == status.HTTP_200_OK + admin_devices = admin_response.json() + assert len(admin_devices) == 1 + assert admin_devices[0]["name"] == "Admin Device" + + editor_response = client.get( + "/api/devices", + headers={"Authorization": f"Bearer {editor_access_token}"}, + ) + assert editor_response.status_code == status.HTTP_200_OK + editor_devices = editor_response.json() + assert len(editor_devices) == 1 + assert editor_devices[0]["name"] == "Editor Device" + + def test_cannot_get_other_users_device( + self, + client, + editor_access_token: str, + admin_user: User, + ): + device = db_device_handler.add_device( + Device(id="admin-only-device", user_id=admin_user.id, name="Admin Only") + ) + + response = client.get( + f"/api/devices/{device.id}", + headers={"Authorization": f"Bearer 
{editor_access_token}"}, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_cannot_update_other_users_device( + self, + client, + editor_access_token: str, + admin_user: User, + ): + device = db_device_handler.add_device( + Device(id="admin-protected-device", user_id=admin_user.id, name="Protected") + ) + + response = client.put( + f"/api/devices/{device.id}", + json={"name": "Hacked Name"}, + headers={"Authorization": f"Bearer {editor_access_token}"}, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + + original = db_device_handler.get_device( + device_id=device.id, user_id=admin_user.id + ) + assert original.name == "Protected" + + def test_cannot_delete_other_users_device( + self, + client, + editor_access_token: str, + admin_user: User, + ): + device = db_device_handler.add_device( + Device(id="admin-nodelete-device", user_id=admin_user.id, name="No Delete") + ) + + response = client.delete( + f"/api/devices/{device.id}", + headers={"Authorization": f"Bearer {editor_access_token}"}, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + + still_exists = db_device_handler.get_device( + device_id=device.id, user_id=admin_user.id + ) + assert still_exists is not None + + +class TestDeviceDuplicateHandling: + def test_duplicate_mac_address_returns_existing( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="existing-mac-device", + user_id=admin_user.id, + name="Existing Device", + mac_address="AA:BB:CC:DD:EE:FF", + ) + ) + + response = client.post( + "/api/devices", + json={ + "name": "New Device", + "mac_address": "AA:BB:CC:DD:EE:FF", + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["device_id"] == "existing-mac-device" + assert data["name"] == "Existing Device" + + def test_duplicate_hostname_platform_returns_existing( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="existing-hostname-device", + user_id=admin_user.id, + name="Existing Device", + hostname="my-device", + platform="android", + ) + ) + + response = client.post( + "/api/devices", + json={ + "name": "New Device", + "hostname": "my-device", + "platform": "android", + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["device_id"] == "existing-hostname-device" + assert data["name"] == "Existing Device" + + def test_duplicate_with_allow_existing_false_returns_409( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="reject-duplicate-device", + user_id=admin_user.id, + name="Existing Device", + mac_address="FF:EE:DD:CC:BB:AA", + ) + ) + + response = client.post( + "/api/devices", + json={ + "name": "New Device", + "mac_address": "FF:EE:DD:CC:BB:AA", + "allow_existing": False, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json()["detail"] + assert data["error"] == "device_exists" + assert data["device_id"] == "reject-duplicate-device" + + def test_allow_existing_returns_existing_device( + self, client, access_token: str, admin_user: User + ): + existing = db_device_handler.add_device( + Device( + id="allow-existing-device", + user_id=admin_user.id, + name="Existing Device", + mac_address="11:22:33:44:55:66", + ) + ) + + 
response = client.post( + "/api/devices", + json={ + "name": "New Device Name", + "mac_address": "11:22:33:44:55:66", + "allow_existing": True, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["device_id"] == existing.id + assert data["name"] == "Existing Device" + + def test_allow_existing_with_reset_syncs( + self, client, access_token: str, admin_user: User, rom + ): + from handler.database import db_device_save_sync_handler, db_save_handler + from models.assets import Save + + existing = db_device_handler.add_device( + Device( + id="reset-syncs-device", + user_id=admin_user.id, + name="Device With Syncs", + mac_address="77:88:99:AA:BB:CC", + ) + ) + + save = db_save_handler.add_save( + Save( + file_name="test.sav", + file_name_no_tags="test", + file_name_no_ext="test", + file_extension="sav", + file_path="/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + ) + db_device_save_sync_handler.upsert_sync(device_id=existing.id, save_id=save.id) + + sync_before = db_device_save_sync_handler.get_sync( + device_id=existing.id, save_id=save.id + ) + assert sync_before is not None + + response = client.post( + "/api/devices", + json={ + "mac_address": "77:88:99:AA:BB:CC", + "allow_existing": True, + "reset_syncs": True, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.json()["device_id"] == existing.id + + sync_after = db_device_save_sync_handler.get_sync( + device_id=existing.id, save_id=save.id + ) + assert sync_after is None + + def test_allow_duplicate_creates_new_device( + self, client, access_token: str, admin_user: User + ): + existing = db_device_handler.add_device( + Device( + id="original-device", + user_id=admin_user.id, + name="Original Device", + mac_address="DD:EE:FF:00:11:22", + ) + ) + + response = client.post( + "/api/devices", + json={ + "name": "Duplicate Install", + "mac_address": "DD:EE:FF:00:11:22", + "allow_duplicate": True, + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["device_id"] != existing.id + assert data["name"] == "Duplicate Install" + + def test_no_conflict_without_fingerprint(self, client, access_token: str): + response1 = client.post( + "/api/devices", + json={"name": "Device 1"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert response1.status_code == status.HTTP_201_CREATED + + response2 = client.post( + "/api/devices", + json={"name": "Device 2"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert response2.status_code == status.HTTP_201_CREATED + assert response1.json()["device_id"] != response2.json()["device_id"] + + def test_hostname_only_no_conflict_without_platform( + self, client, access_token: str, admin_user: User + ): + db_device_handler.add_device( + Device( + id="hostname-only-device", + user_id=admin_user.id, + name="Existing", + hostname="my-device", + ) + ) + + response = client.post( + "/api/devices", + json={ + "name": "New Device", + "hostname": "my-device", + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_201_CREATED diff --git a/backend/tests/endpoints/test_saves.py b/backend/tests/endpoints/test_saves.py new file mode 100644 index 000000000..dcf6998d1 --- /dev/null +++ b/backend/tests/endpoints/test_saves.py @@ 
-0,0 +1,2137 @@ +from datetime import timedelta +from io import BytesIO +from unittest import mock + +import pytest +from fastapi import status +from fastapi.testclient import TestClient +from main import app + +from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES +from handler.auth import oauth_handler +from handler.auth.constants import Scope +from handler.database import db_device_handler, db_device_save_sync_handler +from handler.redis_handler import sync_cache +from models.assets import Save +from models.device import Device +from models.platform import Platform +from models.rom import Rom +from models.user import User + + +@pytest.fixture +def client(): + with TestClient(app) as client: + yield client + + +@pytest.fixture(autouse=True) +def clear_cache(): + yield + sync_cache.flushall() + + +@pytest.fixture +def device(admin_user: User): + return db_device_handler.add_device( + Device( + id="test-sync-device", + user_id=admin_user.id, + name="Sync Test Device", + ) + ) + + +@pytest.fixture +def token_without_device_scopes(admin_user: User): + scopes = [ + s + for s in admin_user.oauth_scopes + if s not in (Scope.DEVICES_READ, Scope.DEVICES_WRITE) + ] + return oauth_handler.create_oauth_token( + data={ + "sub": admin_user.username, + "iss": "romm:oauth", + "scopes": " ".join(scopes), + "type": "access", + }, + expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), + ) + + +class TestSaveSyncEndpoints: + def test_get_saves_without_device_id(self, client, access_token: str, save: Save): + response = client.get( + "/api/saves", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 1 + assert data[0]["id"] == save.id + assert data[0]["device_syncs"] == [] + + def test_get_saves_with_device_id_no_sync( + self, client, access_token: str, save: Save, device: Device + ): + response = client.get( + f"/api/saves?device_id={device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 1 + assert len(data[0]["device_syncs"]) == 1 + assert data[0]["device_syncs"][0]["device_id"] == device.id + assert data[0]["device_syncs"][0]["is_untracked"] is False + assert data[0]["device_syncs"][0]["is_current"] is False + + def test_get_saves_with_device_id_synced( + self, client, access_token: str, save: Save, device: Device + ): + db_device_save_sync_handler.upsert_sync(device_id=device.id, save_id=save.id) + + response = client.get( + f"/api/saves?device_id={device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data[0]["device_syncs"]) == 1 + assert data[0]["device_syncs"][0]["is_untracked"] is False + assert data[0]["device_syncs"][0]["is_current"] is True + + def test_get_single_save_with_device_id( + self, client, access_token: str, save: Save, device: Device + ): + response = client.get( + f"/api/saves/{save.id}?device_id={device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == save.id + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["is_untracked"] is False + + def test_track_save(self, client, access_token: str, save: Save, device: Device): + db_device_save_sync_handler.set_untracked( + device_id=device.id, save_id=save.id, 
untracked=True + ) + + response = client.post( + f"/api/saves/{save.id}/track", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["is_untracked"] is False + + def test_untrack_save(self, client, access_token: str, save: Save, device: Device): + response = client.post( + f"/api/saves/{save.id}/untrack", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["is_untracked"] is True + + def test_track_save_not_found(self, client, access_token: str, device: Device): + response = client.post( + "/api/saves/99999/track", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_track_save_device_not_found(self, client, access_token: str, save: Save): + response = client.post( + f"/api/saves/{save.id}/track", + json={"device_id": "nonexistent-device"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_saves_with_invalid_device_id_returns_404( + self, client, access_token: str, save: Save + ): + response = client.get( + "/api/saves?device_id=nonexistent-device", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "nonexistent-device" in response.json()["detail"] + + def test_get_saves_with_device_id_no_saves( + self, client, access_token: str, device: Device + ): + """Test empty save_ids path in get_syncs_for_device_and_saves.""" + response = client.get( + f"/api/saves?device_id={device.id}&rom_id=99999", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + def test_untrack_save_never_synced_creates_untracked_record( + self, client, access_token: str, save: Save, device: Device + ): + """Untracking a save that was never synced creates a new untracked record.""" + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync is None + + response = client.post( + f"/api/saves/{save.id}/untrack", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["is_untracked"] is True + + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync is not None + assert sync.is_untracked is True + + def test_track_save_never_synced_is_noop( + self, client, access_token: str, save: Save, device: Device + ): + """Tracking a save that was never synced doesn't create a DB record. + + The response still includes a synthetic sync entry (is_untracked=False) + but no actual record is created in the database. 
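+        (Contrast with the untrack test above, which does materialize a record.)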
+ """ + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync is None + + response = client.post( + f"/api/saves/{save.id}/track", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["is_untracked"] is False + + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync is None + + def test_get_single_save_with_invalid_device_id_returns_404( + self, client, access_token: str, save: Save + ): + response = client.get( + f"/api/saves/{save.id}?device_id=nonexistent-device", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "nonexistent-device" in response.json()["detail"] + + +class TestSaveUploadWithSync: + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_save_without_device_id( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + ): + mock_save = Save( + file_name="test.sav", + file_name_no_tags="test", + file_name_no_ext="test", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}", + files={"saveFile": ("test.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["device_syncs"] == [] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_save_with_device_id( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + mock_save = Save( + file_name="test.sav", + file_name_no_tags="test", + file_name_no_ext="test", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={"saveFile": ("test.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["device_id"] == device.id + assert data["device_syncs"][0]["is_untracked"] is False + + def test_upload_save_with_invalid_device_id_returns_404( + self, + client, + access_token: str, + rom: Rom, + ): + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id=nonexistent-device", + files={"saveFile": ("test.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "nonexistent-device" in 
response.json()["detail"] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_save_with_slot( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + ): + mock_save = Save( + file_name="slot1.sav", + file_name_no_tags="slot1", + file_name_no_ext="slot1", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="Slot 1", + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&slot=Slot%201", + files={"saveFile": ("slot1.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["slot"] == "Slot 1" + + +class TestSaveConflictDetection: + @pytest.fixture + def device_b(self, admin_user: User): + return db_device_handler.add_device( + Device( + id="test-sync-device-b", + user_id=admin_user.id, + name="Device B", + ) + ) + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_first_upload_from_device_no_sync_exists( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: Device, + ): + """Scenario 1: First upload from device (no sync record exists) should succeed.""" + mock_scan.return_value = save + + file_content = BytesIO(b"save data from device") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["device_id"] == device.id + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_sync_equals_updated_at_no_conflict( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: Device, + ): + """Scenario 2: Device sync timestamp equals save.updated_at should succeed.""" + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=save.updated_at + ) + + mock_scan.return_value = save + + file_content = BytesIO(b"updated save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_without_device_id_always_succeeds( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + ): + """Scenario 3: Upload without device_id bypasses conflict 
detection.""" + mock_scan.return_value = save + + file_content = BytesIO(b"updated from web ui") + response = client.post( + f"/api/saves?rom_id={rom.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["device_syncs"] == [] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_new_save_with_device_id_succeeds( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + """Scenario 4: Creating a new save with device_id always succeeds.""" + new_save = Save( + file_name="brand_new_save.sav", + file_name_no_tags="brand_new_save", + file_name_no_ext="brand_new_save", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + mock_scan.return_value = new_save + + file_content = BytesIO(b"brand new save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={ + "saveFile": ( + "brand_new_save.sav", + file_content, + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_device_b_downloads_then_uploads_no_conflict( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: Device, + device_b: Device, + ): + """Scenario 5: Device A uploads, Device B downloads (syncs), Device B uploads. + + Device B should succeed because it has the latest sync timestamp. + """ + + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=save.updated_at + ) + + db_device_save_sync_handler.upsert_sync( + device_id=device_b.id, save_id=save.id, synced_at=save.updated_at + ) + + mock_scan.return_value = save + + file_content = BytesIO(b"save from device b after download") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device_b.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["device_syncs"][0]["device_id"] == device_b.id + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_device_b_uploads_without_download_conflict( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: Device, + device_b: Device, + ): + """Scenario 6: Device A uploads, Device B uploads without downloading first. + + Device B has an old sync from before Device A's upload, so conflict. 
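+        The upload is rejected with 409 so the stale save cannot clobber Device A's version.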
+ """ + from datetime import datetime, timedelta, timezone + + old_sync_time = datetime.now(timezone.utc) - timedelta(hours=2) + db_device_save_sync_handler.upsert_sync( + device_id=device_b.id, save_id=save.id, synced_at=old_sync_time + ) + + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=save.updated_at + ) + + mock_scan.return_value = save + + file_content = BytesIO(b"stale save from device b") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device_b.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json() + assert "since your last sync" in data["detail"] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_web_ui_uploads_then_device_with_old_sync_conflict( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: Device, + ): + """Scenario 7: Web UI uploads (no device_id), device with old sync uploads. + + Device A synced the save, then web UI uploaded a new version (without device_id). + Device A tries to upload without re-downloading - should conflict. + """ + from datetime import datetime, timedelta, timezone + + old_sync_time = datetime.now(timezone.utc) - timedelta(hours=1) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=old_sync_time + ) + + mock_scan.return_value = save + + file_content = BytesIO(b"stale save from device after web update") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json() + assert "since your last sync" in data["detail"] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_conflict_bypassed_with_overwrite( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: Device, + ): + """Verify overwrite=true bypasses conflict detection.""" + from datetime import datetime, timedelta, timezone + + old_sync_time = datetime.now(timezone.utc) - timedelta(hours=1) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=old_sync_time + ) + + mock_scan.return_value = save + + file_content = BytesIO(b"forced overwrite") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}&overwrite=true", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_conflict_response_contains_details( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + save: Save, + device: 
Device, + ): + """Verify conflict response contains all necessary details for client handling.""" + from datetime import datetime, timedelta, timezone + + old_sync_time = datetime.now(timezone.utc) - timedelta(hours=1) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=old_sync_time + ) + + mock_scan.return_value = save + + file_content = BytesIO(b"conflicting save") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={ + "saveFile": (save.file_name, file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json() + assert "since your last sync" in data["detail"] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_out_of_sync_response_with_slot( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + """Verify out_of_sync response when uploading with slot (non-destructive). + + Slot conflict detection checks if device has synced the latest save in the slot, + not by exact filename (since datetime tags make each upload unique). + """ + from datetime import datetime, timedelta, timezone + + from handler.database import db_save_handler + + existing_slot_save = Save( + file_name="existing_slot_save.sav", + file_name_no_tags="existing_slot_save", + file_name_no_ext="existing_slot_save", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="TestSlot", + ) + db_slot_save = db_save_handler.add_save(existing_slot_save) + + old_sync_time = datetime.now(timezone.utc) - timedelta(hours=1) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=db_slot_save.id, synced_at=old_sync_time + ) + + mock_scan.return_value = Save( + file_name="new_upload.sav", + file_name_no_tags="new_upload", + file_name_no_ext="new_upload", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="TestSlot", + ) + + file_content = BytesIO(b"out of sync save") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}&slot=TestSlot", + files={ + "saveFile": ("new_upload.sav", file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json() + assert "newer save since your last sync" in data["detail"] + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_first_upload_to_slot_succeeds( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + """First upload to a slot (no existing saves) should succeed.""" + mock_scan.return_value = Save( + file_name="first_in_slot.sav", + file_name_no_tags="first_in_slot", + file_name_no_ext="first_in_slot", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="BrandNewSlot", + ) + + file_content = BytesIO(b"first save in slot") + response = client.post( + 
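# No save exists yet in this slot, so no conflict check applies +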
f"/api/saves?rom_id={rom.id}&device_id={device.id}&slot=BrandNewSlot", + files={ + "saveFile": ( + "first_in_slot.sav", + file_content, + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["slot"] == "BrandNewSlot" + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_to_slot_with_current_sync_succeeds( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + """Upload to slot succeeds when device has synced the latest save.""" + from handler.database import db_save_handler + + existing_slot_save = Save( + file_name="synced_save.sav", + file_name_no_tags="synced_save", + file_name_no_ext="synced_save", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="SyncedSlot", + ) + db_slot_save = db_save_handler.add_save(existing_slot_save) + + db_device_save_sync_handler.upsert_sync( + device_id=device.id, + save_id=db_slot_save.id, + synced_at=db_slot_save.updated_at, + ) + + mock_scan.return_value = Save( + file_name="next_upload.sav", + file_name_no_tags="next_upload", + file_name_no_ext="next_upload", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="SyncedSlot", + ) + + file_content = BytesIO(b"next save in slot") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}&slot=SyncedSlot", + files={ + "saveFile": ( + "next_upload.sav", + file_content, + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_out_of_sync_with_no_prior_device_sync( + self, + mock_scan, + _mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + """Device that never synced any save in slot should get out_of_sync.""" + from handler.database import db_save_handler + + existing_slot_save = Save( + file_name="never_synced.sav", + file_name_no_tags="never_synced", + file_name_no_ext="never_synced", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="NeverSyncedSlot", + ) + db_save_handler.add_save(existing_slot_save) + + mock_scan.return_value = Save( + file_name="upload_attempt.sav", + file_name_no_tags="upload_attempt", + file_name_no_ext="upload_attempt", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="NeverSyncedSlot", + ) + + file_content = BytesIO(b"upload without prior sync") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}&slot=NeverSyncedSlot", + files={ + "saveFile": ( + "upload_attempt.sav", + file_content, + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json() + assert "newer save since your last sync" in 
data["detail"] + + +class TestDeviceScopeEnforcement: + def test_get_saves_with_device_id_requires_scope( + self, client, token_without_device_scopes: str, save: Save, device: Device + ): + response = client.get( + f"/api/saves?device_id={device.id}", + headers={"Authorization": f"Bearer {token_without_device_scopes}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_get_single_save_with_device_id_requires_scope( + self, client, token_without_device_scopes: str, save: Save, device: Device + ): + response = client.get( + f"/api/saves/{save.id}?device_id={device.id}", + headers={"Authorization": f"Bearer {token_without_device_scopes}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_save_with_device_id_requires_scope( + self, + mock_scan, + _mock_write, + client, + token_without_device_scopes: str, + rom: Rom, + platform: Platform, + admin_user: User, + device: Device, + ): + mock_save = Save( + file_name="test.sav", + file_name_no_tags="test", + file_name_no_ext="test", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&device_id={device.id}", + files={"saveFile": ("test.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {token_without_device_scopes}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_track_save_requires_scope( + self, client, token_without_device_scopes: str, save: Save, device: Device + ): + response = client.post( + f"/api/saves/{save.id}/track", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {token_without_device_scopes}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_untrack_save_requires_scope( + self, client, token_without_device_scopes: str, save: Save, device: Device + ): + response = client.post( + f"/api/saves/{save.id}/untrack", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {token_without_device_scopes}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestSlotFiltering: + @pytest.fixture + def saves_with_slots( + self, admin_user: User, rom: Rom, platform: Platform + ) -> list[Save]: + from handler.database import db_save_handler + + saves = [] + for i, slot in enumerate([None, "Slot 1", "Slot 1", "Slot 2"]): + save = Save( + file_name=f"save_{i}.sav", + file_name_no_tags=f"save_{i}", + file_name_no_ext=f"save_{i}", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100 + i, + rom_id=rom.id, + user_id=admin_user.id, + slot=slot, + ) + saves.append(db_save_handler.add_save(save)) + return saves + + def test_get_saves_without_slot_filter( + self, client, access_token: str, saves_with_slots: list[Save] + ): + response = client.get( + "/api/saves", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) >= 4 + for item in data: + assert "slot" in item + assert "id" in item + assert "rom_id" in item + + def test_get_saves_with_slot_filter( + self, client, access_token: str, rom: Rom, saves_with_slots: list[Save] + ): + response = 
client.get( + f"/api/saves?rom_id={rom.id}&slot=Slot%201", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 2 + for item in data: + assert item["slot"] == "Slot 1" + + def test_get_saves_with_nonexistent_slot( + self, client, access_token: str, rom: Rom, saves_with_slots: list[Save] + ): + response = client.get( + f"/api/saves?rom_id={rom.id}&slot=NonexistentSlot", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 0 + + +class TestDatetimeTagging: + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_with_slot_applies_datetime_tag( + self, + mock_scan, + mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + ): + import re + + mock_save = Save( + file_name="test [2026-01-31_12-00-00].sav", + file_name_no_tags="test", + file_name_no_ext="test [2026-01-31_12-00-00]", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="main", + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&slot=main", + files={"saveFile": ("test.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + mock_write.assert_called_once() + call_args = mock_write.call_args + written_filename = call_args[1].get("filename") or call_args[0][2] + assert re.search(r" \[\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\]", written_filename) + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_without_slot_no_datetime_tag( + self, + mock_scan, + mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + ): + mock_save = Save( + file_name="test.sav", + file_name_no_tags="test", + file_name_no_ext="test", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}", + files={"saveFile": ("test.sav", file_content, "application/octet-stream")}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + mock_write.assert_called_once() + call_args = mock_write.call_args + written_filename = call_args[1].get("filename") or call_args[0][2] + assert written_filename == "test.sav" + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_upload_with_existing_datetime_tag_replaces_it( + self, + mock_scan, + mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + ): + import re + + mock_save = Save( + file_name="test [2026-01-31_12-00-00].sav", + file_name_no_tags="test", + file_name_no_ext="test [2026-01-31_12-00-00]", + file_extension="sav", + file_path=f"{platform.slug}/saves", + 
file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="main", + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"test save data") + response = client.post( + f"/api/saves?rom_id={rom.id}&slot=main", + files={ + "saveFile": ( + "test [2020-01-01_00-00-00].sav", + file_content, + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + mock_write.assert_called_once() + call_args = mock_write.call_args + written_filename = call_args[1].get("filename") or call_args[0][2] + datetime_matches = re.findall( + r"\[\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\]", written_filename + ) + assert len(datetime_matches) == 1 + assert "2020-01-01" not in written_filename + + +class TestAutocleanup: + @pytest.fixture + def slot_saves(self, admin_user: User, rom: Rom, platform: Platform) -> list[Save]: + from datetime import datetime, timedelta, timezone + + from handler.database import db_save_handler + + saves = [] + base_time = datetime.now(timezone.utc) - timedelta(hours=20) + for i in range(15): + save = Save( + file_name=f"autosave_{i}.sav", + file_name_no_tags=f"autosave_{i}", + file_name_no_ext=f"autosave_{i}", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100 + i, + rom_id=rom.id, + user_id=admin_user.id, + slot="autosave", + ) + created = db_save_handler.add_save(save) + db_save_handler.update_save( + created.id, {"updated_at": base_time + timedelta(hours=i)} + ) + saves.append(created) + return saves + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch( + "endpoints.saves.fs_asset_handler.remove_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_autocleanup_deletes_old_saves( + self, + mock_scan, + mock_remove, + mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + slot_saves: list[Save], + ): + from handler.database import db_save_handler + + initial_saves = db_save_handler.get_saves( + user_id=admin_user.id, rom_id=rom.id, slot="autosave" + ) + assert len(initial_saves) == 15 + + mock_save = Save( + file_name="new_autosave.sav", + file_name_no_tags="new_autosave", + file_name_no_ext="new_autosave", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="autosave", + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"new save") + response = client.post( + f"/api/saves?rom_id={rom.id}&slot=autosave&autocleanup=true&autocleanup_limit=10", + files={ + "saveFile": ( + "new_autosave.sav", + file_content, + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert mock_remove.call_count == 6 + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch( + "endpoints.saves.fs_asset_handler.remove_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_autocleanup_disabled_by_default( + self, + mock_scan, + mock_remove, + mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + slot_saves: list[Save], + ): + mock_save = Save( + file_name="new_save.sav", + file_name_no_tags="new_save", + file_name_no_ext="new_save", + 
file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + slot="autosave", + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"new save") + response = client.post( + f"/api/saves?rom_id={rom.id}&slot=autosave", + files={ + "saveFile": ("new_save.sav", file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + mock_remove.assert_not_called() + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch( + "endpoints.saves.fs_asset_handler.remove_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save", new_callable=mock.AsyncMock) + def test_autocleanup_without_slot_does_nothing( + self, + mock_scan, + mock_remove, + mock_write, + client, + access_token: str, + rom: Rom, + platform: Platform, + admin_user: User, + ): + mock_save = Save( + file_name="noslotsave.sav", + file_name_no_tags="noslotsave", + file_name_no_ext="noslotsave", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100, + rom_id=rom.id, + user_id=admin_user.id, + ) + mock_scan.return_value = mock_save + + file_content = BytesIO(b"no slot save") + response = client.post( + f"/api/saves?rom_id={rom.id}&autocleanup=true&autocleanup_limit=5", + files={ + "saveFile": ("noslotsave.sav", file_content, "application/octet-stream") + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + mock_remove.assert_not_called() + + +class TestSavesSummaryEndpoint: + @pytest.fixture + def summary_saves( + self, admin_user: User, rom: Rom, platform: Platform + ) -> list[Save]: + from datetime import datetime, timedelta, timezone + + from handler.database import db_save_handler + + saves = [] + base_time = datetime.now(timezone.utc) - timedelta(hours=10) + + configs = [ + (None, 0), + (None, 1), + (None, 2), + ("Slot A", 3), + ("Slot A", 4), + ("Slot B", 5), + ] + + for slot, offset in configs: + save = Save( + file_name=f"summary_save_{offset}.sav", + file_name_no_tags=f"summary_save_{offset}", + file_name_no_ext=f"summary_save_{offset}", + file_extension="sav", + file_path=f"{platform.slug}/saves", + file_size_bytes=100 + offset, + rom_id=rom.id, + user_id=admin_user.id, + slot=slot, + ) + created = db_save_handler.add_save(save) + db_save_handler.update_save( + created.id, {"updated_at": base_time + timedelta(hours=offset)} + ) + saves.append(created) + return saves + + def test_get_saves_summary( + self, client, access_token: str, rom: Rom, summary_saves: list[Save] + ): + response = client.get( + f"/api/saves/summary?rom_id={rom.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + assert "total_count" in data + assert "slots" in data + assert data["total_count"] == 6 + assert isinstance(data["slots"], list) + assert len(data["slots"]) == 3 + + slot_map = {s["slot"]: s for s in data["slots"]} + assert None in slot_map or "null" in str(slot_map.keys()) + assert "Slot A" in slot_map + assert "Slot B" in slot_map + + def test_get_saves_summary_validates_response_schema( + self, client, access_token: str, rom: Rom, summary_saves: list[Save] + ): + response = client.get( + f"/api/saves/summary?rom_id={rom.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == 
status.HTTP_200_OK + data = response.json() + + assert isinstance(data["total_count"], int) + assert isinstance(data["slots"], list) + + for slot_info in data["slots"]: + assert "slot" in slot_info + assert "count" in slot_info + assert "latest" in slot_info + + assert isinstance(slot_info["count"], int) + assert slot_info["count"] > 0 + + latest = slot_info["latest"] + assert "id" in latest + assert "rom_id" in latest + assert "user_id" in latest + assert "file_name" in latest + assert "created_at" in latest + assert "updated_at" in latest + + def test_get_saves_summary_latest_is_most_recent( + self, client, access_token: str, rom: Rom, summary_saves: list[Save] + ): + response = client.get( + f"/api/saves/summary?rom_id={rom.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + slot_a_info = next((s for s in data["slots"] if s["slot"] == "Slot A"), None) + assert slot_a_info is not None + assert slot_a_info["count"] == 2 + assert "summary_save_4" in slot_a_info["latest"]["file_name"] + + def test_get_saves_summary_requires_rom_id(self, client, access_token: str): + response = client.get( + "/api/saves/summary", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + data = response.json() + assert "detail" in data + assert any("rom_id" in str(err).lower() for err in data["detail"]) + + def test_get_saves_summary_empty_rom(self, client, access_token: str): + response = client.get( + "/api/saves/summary?rom_id=999999", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["total_count"] == 0 + assert data["slots"] == [] + + def test_get_saves_summary_requires_auth(self, client, rom: Rom): + response = client.get(f"/api/saves/summary?rom_id={rom.id}") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestSaveDownload: + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_save_without_device_returns_file( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + tmp_path, + ): + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"save file content") + mock_validate_path.return_value = test_file + + response = client.get( + f"/api/saves/{save.id}/content", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.content == b"save file content" + + sync = db_device_save_sync_handler.get_sync(device_id="any", save_id=save.id) + assert sync is None + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_save_with_device_returns_file( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + device: Device, + tmp_path, + ): + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"save file content") + mock_validate_path.return_value = test_file + + response = client.get( + f"/api/saves/{save.id}/content?device_id={device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.content == b"save file content" + + def test_download_save_not_found(self, client, access_token: str): + response = client.get( + "/api/saves/99999/content", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + 
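# The 404 detail should name the missing save id +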
assert "99999" in response.json()["detail"] + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_save_file_missing_on_disk( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + tmp_path, + ): + missing_file = tmp_path / "nonexistent.sav" + mock_validate_path.return_value = missing_file + + response = client.get( + f"/api/saves/{save.id}/content", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not found on disk" in response.json()["detail"] + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_save_validate_path_raises( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + ): + mock_validate_path.side_effect = ValueError("Invalid path") + + response = client.get( + f"/api/saves/{save.id}/content", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not found" in response.json()["detail"].lower() + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_with_device_id_optimistic_true_updates_sync( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + device: Device, + tmp_path, + ): + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"save content") + mock_validate_path.return_value = test_file + + sync_before = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync_before is None + + response = client.get( + f"/api/saves/{save.id}/content?device_id={device.id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + sync_after = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync_after is not None + assert sync_after.last_synced_at.replace( + microsecond=0, tzinfo=None + ) == save.updated_at.replace(microsecond=0, tzinfo=None) + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_with_device_id_optimistic_false_no_sync_update( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + device: Device, + tmp_path, + ): + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"save content") + mock_validate_path.return_value = test_file + + response = client.get( + f"/api/saves/{save.id}/content?device_id={device.id}&optimistic=false", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync is None + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_with_invalid_device_id_returns_404( + self, + mock_validate_path, + client, + access_token: str, + save: Save, + tmp_path, + ): + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"save content") + mock_validate_path.return_value = test_file + + response = client.get( + f"/api/saves/{save.id}/content?device_id=nonexistent-device", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "nonexistent-device" in response.json()["detail"] + + @mock.patch("endpoints.saves.fs_asset_handler.validate_path") + def test_download_without_device_scope_forbidden( + self, + mock_validate_path, + client, + token_without_device_scopes: str, + save: Save, + device: 
Device, + tmp_path, + ): + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"save content") + mock_validate_path.return_value = test_file + + response = client.get( + f"/api/saves/{save.id}/content?device_id={device.id}", + headers={"Authorization": f"Bearer {token_without_device_scopes}"}, + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestConfirmDownload: + def test_confirm_download_creates_sync_record( + self, + client, + access_token: str, + save: Save, + device: Device, + ): + sync_before = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync_before is None + + response = client.post( + f"/api/saves/{save.id}/downloaded", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["device_syncs"]) == 1 + assert data["device_syncs"][0]["device_id"] == device.id + + sync_after = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync_after is not None + assert sync_after.last_synced_at.replace( + microsecond=0, tzinfo=None + ) == save.updated_at.replace(microsecond=0, tzinfo=None) + + def test_confirm_download_updates_existing_sync( + self, + client, + access_token: str, + save: Save, + device: Device, + ): + from datetime import datetime, timedelta, timezone + + old_sync_time = datetime.now(timezone.utc) - timedelta(hours=5) + db_device_save_sync_handler.upsert_sync( + device_id=device.id, save_id=save.id, synced_at=old_sync_time + ) + + response = client.post( + f"/api/saves/{save.id}/downloaded", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + sync = db_device_save_sync_handler.get_sync( + device_id=device.id, save_id=save.id + ) + assert sync.last_synced_at.replace( + microsecond=0, tzinfo=None + ) == save.updated_at.replace(microsecond=0, tzinfo=None) + assert sync.last_synced_at.replace( + microsecond=0, tzinfo=None + ) != old_sync_time.replace(microsecond=0, tzinfo=None) + + def test_confirm_download_updates_device_last_seen( + self, + client, + access_token: str, + save: Save, + device: Device, + ): + original_last_seen = device.last_seen + + response = client.post( + f"/api/saves/{save.id}/downloaded", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + updated_device = db_device_handler.get_device( + device_id=device.id, user_id=device.user_id + ) + if original_last_seen: + assert updated_device.last_seen > original_last_seen + else: + assert updated_device.last_seen is not None + + def test_confirm_download_save_not_found( + self, + client, + access_token: str, + device: Device, + ): + response = client.post( + "/api/saves/99999/downloaded", + json={"device_id": device.id}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "99999" in response.json()["detail"] + + def test_confirm_download_device_not_found( + self, + client, + access_token: str, + save: Save, + ): + response = client.post( + f"/api/saves/{save.id}/downloaded", + json={"device_id": "nonexistent-device"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "nonexistent-device" in response.json()["detail"] + + +class 
TestContentHashDeduplication: + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save") + def test_slot_upload_includes_content_hash( + self, + mock_scan_save, + mock_write_file, + client, + access_token: str, + rom: Rom, + ): + from models.assets import Save as SaveModel + + mock_save = SaveModel( + id=999, + file_name="test [2026-01-31_12-00-00].sav", + file_name_no_tags="test.sav", + file_name_no_ext="test [2026-01-31_12-00-00]", + file_extension="sav", + file_path="/saves/path", + file_size_bytes=1024, + content_hash="abc123def456789012345678901234ab", + rom_id=rom.id, + user_id=1, + ) + mock_scan_save.return_value = mock_save + + response = client.post( + "/api/saves", + params={"rom_id": rom.id, "slot": "Slot1"}, + files={ + "saveFile": ( + "test.sav", + BytesIO(b"save content"), + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "content_hash" in data + assert data["content_hash"] == "abc123def456789012345678901234ab" + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch( + "endpoints.saves.fs_asset_handler.remove_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save") + def test_duplicate_hash_returns_existing_save( + self, + mock_scan_save, + mock_remove_file, + mock_write_file, + client, + access_token: str, + rom: Rom, + save: Save, + ): + from handler.database import db_save_handler + + db_save_handler.update_save( + save.id, {"content_hash": "duplicate_hash_12345678901234"} + ) + + from models.assets import Save as SaveModel + + mock_save = SaveModel( + id=None, + file_name="new [2026-01-31_12-00-00].sav", + file_name_no_tags="new.sav", + file_name_no_ext="new [2026-01-31_12-00-00]", + file_extension="sav", + file_path="/saves/path", + file_size_bytes=1024, + content_hash="duplicate_hash_12345678901234", + rom_id=rom.id, + user_id=1, + ) + mock_scan_save.return_value = mock_save + + response = client.post( + "/api/saves", + params={"rom_id": rom.id, "slot": "Slot1"}, + files={ + "saveFile": ( + "new.sav", + BytesIO(b"save content"), + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == save.id + assert data["content_hash"] == "duplicate_hash_12345678901234" + mock_remove_file.assert_called_once() + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save") + def test_duplicate_hash_with_overwrite_succeeds( + self, + mock_scan_save, + mock_write_file, + client, + access_token: str, + rom: Rom, + save: Save, + ): + from handler.database import db_save_handler + + db_save_handler.update_save( + save.id, {"content_hash": "duplicate_hash_12345678901234"} + ) + + from models.assets import Save as SaveModel + + mock_save = SaveModel( + id=None, + file_name="new [2026-01-31_12-00-00].sav", + file_name_no_tags="new.sav", + file_name_no_ext="new [2026-01-31_12-00-00]", + file_extension="sav", + file_path="/saves/path", + file_size_bytes=1024, + content_hash="duplicate_hash_12345678901234", + rom_id=rom.id, + user_id=1, + ) + mock_scan_save.return_value = mock_save + + response = client.post( + "/api/saves", + params={"rom_id": rom.id, "slot": "Slot1", "overwrite": True}, 
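+            # overwrite=True bypasses the duplicate content-hash check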
+ files={ + "saveFile": ( + "new.sav", + BytesIO(b"save content"), + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + @mock.patch( + "endpoints.saves.fs_asset_handler.write_file", new_callable=mock.AsyncMock + ) + @mock.patch("endpoints.saves.scan_save") + def test_non_slot_upload_no_dedup_block( + self, + mock_scan_save, + mock_write_file, + client, + access_token: str, + rom: Rom, + save: Save, + ): + from handler.database import db_save_handler + + db_save_handler.update_save( + save.id, {"content_hash": "duplicate_hash_12345678901234"} + ) + + from models.assets import Save as SaveModel + + mock_save = SaveModel( + id=None, + file_name="new.sav", + file_name_no_tags="new.sav", + file_name_no_ext="new", + file_extension="sav", + file_path="/saves/path", + file_size_bytes=1024, + content_hash="duplicate_hash_12345678901234", + rom_id=rom.id, + user_id=1, + ) + mock_scan_save.return_value = mock_save + + response = client.post( + "/api/saves", + params={"rom_id": rom.id}, + files={ + "saveFile": ( + "new.sav", + BytesIO(b"save content"), + "application/octet-stream", + ) + }, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + +class TestContentHashComputation: + def test_compute_file_hash(self, tmp_path): + from handler.filesystem.assets_handler import compute_file_hash + + test_file = tmp_path / "test.sav" + test_file.write_bytes(b"test content for hashing") + + hash_result = compute_file_hash(str(test_file)) + + assert hash_result is not None + assert len(hash_result) == 32 + + hash_result2 = compute_file_hash(str(test_file)) + assert hash_result == hash_result2 + + def test_same_content_produces_same_hash(self, tmp_path): + from handler.filesystem.assets_handler import compute_file_hash + + file1 = tmp_path / "save1.sav" + file2 = tmp_path / "save2.sav" + file1.write_bytes(b"identical content") + file2.write_bytes(b"identical content") + + hash1 = compute_file_hash(str(file1)) + hash2 = compute_file_hash(str(file2)) + + assert hash1 == hash2 + + def test_different_content_produces_different_hash(self, tmp_path): + from handler.filesystem.assets_handler import compute_file_hash + + file1 = tmp_path / "save1.sav" + file2 = tmp_path / "save2.sav" + file1.write_bytes(b"content A") + file2.write_bytes(b"content B") + + hash1 = compute_file_hash(str(file1)) + hash2 = compute_file_hash(str(file2)) + + assert hash1 != hash2 diff --git a/backend/tests/handler/database/test_saves_handler.py b/backend/tests/handler/database/test_saves_handler.py index 066cb35d6..c20ee35c9 100644 --- a/backend/tests/handler/database/test_saves_handler.py +++ b/backend/tests/handler/database/test_saves_handler.py @@ -166,3 +166,280 @@ def test_get_save_by_id_with_platform_context( # Verify the save is associated with the correct platform through ROM assert retrieved_save.rom.platform_id == platform.id + + +class TestDBSavesHandlerSlotFiltering: + def test_get_saves_with_slot_filter(self, admin_user: User, rom: Rom): + save1 = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="slot_test_1.sav", + file_name_no_tags="slot_test_1", + file_name_no_ext="slot_test_1", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot="Slot A", + ) + save2 = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="slot_test_2.sav", + file_name_no_tags="slot_test_2", + 
file_name_no_ext="slot_test_2", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot="Slot A", + ) + save3 = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="slot_test_3.sav", + file_name_no_tags="slot_test_3", + file_name_no_ext="slot_test_3", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot="Slot B", + ) + + db_save_handler.add_save(save1) + db_save_handler.add_save(save2) + db_save_handler.add_save(save3) + + slot_a_saves = db_save_handler.get_saves( + user_id=admin_user.id, rom_id=rom.id, slot="Slot A" + ) + assert len(slot_a_saves) == 2 + assert all(s.slot == "Slot A" for s in slot_a_saves) + + slot_b_saves = db_save_handler.get_saves( + user_id=admin_user.id, rom_id=rom.id, slot="Slot B" + ) + assert len(slot_b_saves) == 1 + assert slot_b_saves[0].slot == "Slot B" + + def test_get_saves_with_null_slot_filter(self, admin_user: User, rom: Rom): + save_with_slot = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="with_slot.sav", + file_name_no_tags="with_slot", + file_name_no_ext="with_slot", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot="Main", + ) + save_without_slot = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="without_slot.sav", + file_name_no_tags="without_slot", + file_name_no_ext="without_slot", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot=None, + ) + + db_save_handler.add_save(save_with_slot) + db_save_handler.add_save(save_without_slot) + + all_saves = db_save_handler.get_saves(user_id=admin_user.id, rom_id=rom.id) + assert len(all_saves) >= 2 + + def test_get_saves_order_by(self, admin_user: User, rom: Rom): + from datetime import datetime, timedelta, timezone + + base_time = datetime.now(timezone.utc) + + save1 = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="order_test_1.sav", + file_name_no_tags="order_test_1", + file_name_no_ext="order_test_1", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot="order_test", + ) + save2 = Save( + rom_id=rom.id, + user_id=admin_user.id, + file_name="order_test_2.sav", + file_name_no_tags="order_test_2", + file_name_no_ext="order_test_2", + file_extension="sav", + emulator="test_emu", + file_path=f"{rom.platform_slug}/saves", + file_size_bytes=100, + slot="order_test", + ) + + created1 = db_save_handler.add_save(save1) + created2 = db_save_handler.add_save(save2) + + db_save_handler.update_save( + created1.id, {"updated_at": base_time - timedelta(hours=2)} + ) + db_save_handler.update_save( + created2.id, {"updated_at": base_time - timedelta(hours=1)} + ) + + ordered_saves_desc = db_save_handler.get_saves( + user_id=admin_user.id, + rom_id=rom.id, + slot="order_test", + order_by="updated_at", + ) + + assert len(ordered_saves_desc) == 2 + assert ordered_saves_desc[0].id == created2.id + assert ordered_saves_desc[1].id == created1.id + + ordered_saves_asc = db_save_handler.get_saves( + user_id=admin_user.id, + rom_id=rom.id, + slot="order_test", + order_by="updated_at", + order_dir="asc", + ) + + assert len(ordered_saves_asc) == 2 + assert ordered_saves_asc[0].id == created1.id + assert ordered_saves_asc[1].id == created2.id + + +class TestDBSavesHandlerSummary: + def test_get_saves_summary_basic(self, admin_user: User, rom: Rom): + from datetime 
+
+
+class TestDBSavesHandlerSummary:
+    def test_get_saves_summary_basic(self, admin_user: User, rom: Rom):
+        from datetime import datetime, timedelta, timezone
+
+        base_time = datetime.now(timezone.utc)
+
+        configs = [
+            ("summary_a_1.sav", "Slot A", -3),
+            ("summary_a_2.sav", "Slot A", -1),
+            ("summary_b_1.sav", "Slot B", -2),
+            ("summary_none_1.sav", None, -4),
+        ]
+
+        for filename, slot, hours_offset in configs:
+            save = Save(
+                rom_id=rom.id,
+                user_id=admin_user.id,
+                file_name=filename,
+                file_name_no_tags=filename.replace(".sav", ""),
+                file_name_no_ext=filename.replace(".sav", ""),
+                file_extension="sav",
+                emulator="test_emu",
+                file_path=f"{rom.platform_slug}/saves",
+                file_size_bytes=100,
+                slot=slot,
+            )
+            created = db_save_handler.add_save(save)
+            db_save_handler.update_save(
+                created.id, {"updated_at": base_time + timedelta(hours=hours_offset)}
+            )
+
+        summary = db_save_handler.get_saves_summary(
+            user_id=admin_user.id, rom_id=rom.id
+        )
+
+        assert "total_count" in summary
+        assert "slots" in summary
+        assert summary["total_count"] == 4
+        assert len(summary["slots"]) == 3
+
+    def test_get_saves_summary_latest_per_slot(self, admin_user: User, rom: Rom):
+        from datetime import datetime, timedelta, timezone
+
+        base_time = datetime.now(timezone.utc)
+
+        old_save = Save(
+            rom_id=rom.id,
+            user_id=admin_user.id,
+            file_name="latest_test_old.sav",
+            file_name_no_tags="latest_test_old",
+            file_name_no_ext="latest_test_old",
+            file_extension="sav",
+            emulator="test_emu",
+            file_path=f"{rom.platform_slug}/saves",
+            file_size_bytes=100,
+            slot="latest_test",
+        )
+        new_save = Save(
+            rom_id=rom.id,
+            user_id=admin_user.id,
+            file_name="latest_test_new.sav",
+            file_name_no_tags="latest_test_new",
+            file_name_no_ext="latest_test_new",
+            file_extension="sav",
+            emulator="test_emu",
+            file_path=f"{rom.platform_slug}/saves",
+            file_size_bytes=100,
+            slot="latest_test",
+        )
+
+        old_created = db_save_handler.add_save(old_save)
+        new_created = db_save_handler.add_save(new_save)
+
+        db_save_handler.update_save(
+            old_created.id, {"updated_at": base_time - timedelta(hours=5)}
+        )
+        db_save_handler.update_save(
+            new_created.id, {"updated_at": base_time - timedelta(hours=1)}
+        )
+
+        summary = db_save_handler.get_saves_summary(
+            user_id=admin_user.id, rom_id=rom.id
+        )
+
+        latest_slot = next(
+            (s for s in summary["slots"] if s["slot"] == "latest_test"), None
+        )
+        assert latest_slot is not None
+        assert latest_slot["count"] == 2
+        assert latest_slot["latest"].file_name == "latest_test_new.sav"
+
+    def test_get_saves_summary_empty_rom(self, admin_user: User):
+        summary = db_save_handler.get_saves_summary(
+            user_id=admin_user.id, rom_id=999999
+        )
+
+        assert summary["total_count"] == 0
+        assert summary["slots"] == []
+
+    def test_get_saves_summary_count_accuracy(self, admin_user: User, rom: Rom):
+        for i in range(5):
+            save = Save(
+                rom_id=rom.id,
+                user_id=admin_user.id,
+                file_name=f"count_test_{i}.sav",
+                file_name_no_tags=f"count_test_{i}",
+                file_name_no_ext=f"count_test_{i}",
+                file_extension="sav",
+                emulator="test_emu",
+                file_path=f"{rom.platform_slug}/saves",
+                file_size_bytes=100,
+                slot="count_test",
+            )
+            db_save_handler.add_save(save)
+
+        summary = db_save_handler.get_saves_summary(
+            user_id=admin_user.id, rom_id=rom.id
+        )
+
+        count_slot = next(
+            (s for s in summary["slots"] if s["slot"] == "count_test"), None
+        )
+        assert count_slot is not None
+        assert count_slot["count"] == 5
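These tests fully pin down the return shape of `get_saves_summary`: a `total_count`, plus one `slots` entry per distinct slot value (including `None`), each carrying a `count` and the `latest` save by `updated_at`. The real handler likely does the grouping in SQL; the sketch below only states that contract in plain Python, and `summarize_saves` is an illustrative name rather than the project's API:

```python
from collections import defaultdict


def summarize_saves(saves: list) -> dict:
    """Group saves by slot; 'latest' is the most recently updated per slot."""
    by_slot: dict = defaultdict(list)
    for save in saves:
        # slot may be None; slot-less saves form their own group, matching
        # the ("summary_none_1.sav", None, -4) row in the basic test above.
        by_slot[save.slot].append(save)
    return {
        "total_count": len(saves),
        "slots": [
            {
                "slot": slot,
                "count": len(group),
                "latest": max(group, key=lambda s: s.updated_at),
            }
            for slot, group in by_slot.items()
        ],
    }
```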
diff --git a/backend/utils/datetime.py b/backend/utils/datetime.py
new file mode 100644
index 000000000..ed2cf705a
--- /dev/null
+++ b/backend/utils/datetime.py
@@ -0,0 +1,8 @@
+from datetime import datetime, timezone
+
+
+def to_utc(dt: datetime) -> datetime:
+    # Naive datetimes are assumed to already be in UTC; aware ones are converted.
+    if dt.tzinfo is None:
+        return dt.replace(tzinfo=timezone.utc)
+    return dt.astimezone(timezone.utc)
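Note that `to_utc` treats naive datetimes as already being UTC rather than local time, which matters when callers hand it timestamps parsed from client payloads. Both branches for illustration, assuming the backend package root is on `sys.path` so the new module imports as `utils.datetime`:

```python
from datetime import datetime, timedelta, timezone

from utils.datetime import to_utc

naive = datetime(2026, 1, 17, 12, 0)  # no tzinfo: tagged as UTC, value unchanged
aware = datetime(2026, 1, 17, 12, 0, tzinfo=timezone(timedelta(hours=2)))

assert to_utc(naive).isoformat() == "2026-01-17T12:00:00+00:00"
assert to_utc(aware).isoformat() == "2026-01-17T10:00:00+00:00"  # +02:00 shifted
```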