diff --git a/backend/src/librarysync/api/routes_admin.py b/backend/src/librarysync/api/routes_admin.py
index 8d6dcd9..bf098aa 100644
--- a/backend/src/librarysync/api/routes_admin.py
+++ b/backend/src/librarysync/api/routes_admin.py
@@ -1,12 +1,26 @@
 from datetime import datetime, timedelta, timezone
+from typing import Any
 
 from fastapi import APIRouter, Depends, Query
 from fastapi.responses import JSONResponse
-from sqlalchemy import delete, func, select, update
+from pydantic import BaseModel, Field
+from sqlalchemy import delete, func, or_, select, update
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from librarysync.api.deps import get_admin_api_key, get_db
-from librarysync.db.models import OutboxJob, ScheduledJob, User, WatchEvent
+from librarysync.core.watch_pipeline import _merge_media_fields
+from librarysync.db.models import (
+    EpisodeItem,
+    MediaItem,
+    OutboxJob,
+    ScheduledJob,
+    User,
+    WatchedItem,
+    WatchEvent,
+    WatchlistItem,
+    WatchlistSourceItem,
+    WatchSync,
+)
 from librarysync.jobs.merge_history import merge_history_for_user
 from librarysync.jobs.metadata_backfill import (
     METADATA_BACKFILL_FORCE_JOB,
@@ -27,6 +41,24 @@
 }
 
 
+class MediaIdUpdate(BaseModel):
+    id: str = Field(..., description="Media item ID to update")
+    imdb: str | None = Field(None, description="IMDb ID to set (omit or null to leave unchanged)")
+    tmdb: str | None = Field(None, description="TMDB ID to set (omit or null to leave unchanged)")
+    tvdb: str | None = Field(None, description="TVDB ID to set (omit or null to leave unchanged)")
+    tvmaze: str | None = Field(None, description="TVMaze ID to set (omit or null to leave unchanged)")
+    kitsu: str | None = Field(None, description="Kitsu ID to set (omit or null to leave unchanged)")
+    myanimelist: str | None = Field(None, description="MyAnimeList ID to set (omit or null to leave unchanged)")
+    anilist: str | None = Field(None, description="AniList ID to set (omit or null to leave unchanged)")
+
+
+class MediaUpdateRequest(BaseModel):
+    updates: list[MediaIdUpdate] = Field(
+        max_length=100, description="List of media items to update"
+    )
+    dry_run: bool = Field(False, description="If true, preview changes without committing")
+
+
 @router.post(
     "/reset-outbox-jobs",
     summary="Reset stuck outbox jobs",
@@ -355,3 +387,378 @@ async def merge_history(
             "user_results": user_results,
         }
     )
+
+
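+# Example request body (illustrative values; "abc123" is a made-up media item ID):
+#   {"updates": [{"id": "abc123", "imdb": "tt0137523", "tmdb": "550"}], "dry_run": true}
+# Each per-item result carries a status of "updated", "merged", "unchanged", or "error".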
+@router.post(
+    "/media/update-external-ids",
+    summary="Update external IDs for media items",
+    description=(
+        "Update external IDs for one or more media items. If updating an ID causes "
+        "a conflict (another media item already has that ID), the items will be merged: "
+        "all dependent objects will be migrated to the target item and the duplicate "
+        "will be deleted. IDs that are omitted or null are left unchanged. "
+        "Use dry_run=true to preview changes."
+    ),
+)
+async def update_media_external_ids(
+    request: MediaUpdateRequest,
+    db: AsyncSession = Depends(get_db),
+    _: str = Depends(get_admin_api_key),
+) -> JSONResponse:
+    # Load all target media items upfront
+    ids = [u.id for u in request.updates]
+    result = await db.execute(select(MediaItem).where(MediaItem.id.in_(ids)))
+    target_items = result.scalars().all()
+    target_map = {item.id: item for item in target_items}
+
+    results = []
+    total_updated = 0
+    total_merged = 0
+    total_unchanged = 0
+    total_errors = 0
+
+    for update_item in request.updates:
+        target_media = target_map.get(update_item.id)
+        if not target_media:
+            results.append(
+                {
+                    "id": update_item.id,
+                    "status": "error",
+                    "message": "Media item not found",
+                }
+            )
+            total_errors += 1
+            continue
+        try:
+            item_result = await _process_media_id_update(
+                db, update_item, target_media, dry_run=request.dry_run
+            )
+            results.append(item_result)
+
+            if item_result["status"] == "updated":
+                total_updated += 1
+            elif item_result["status"] == "merged":
+                total_merged += 1
+            elif item_result["status"] == "unchanged":
+                total_unchanged += 1
+            elif item_result["status"] == "error":
+                total_errors += 1
+
+        except Exception as e:
+            await db.rollback()
+            results.append(
+                {
+                    "id": update_item.id,
+                    "status": "error",
+                    "message": str(e)[:500],
+                }
+            )
+            total_errors += 1
+
+    return JSONResponse(
+        {
+            "results": results,
+            "total_updated": total_updated,
+            "total_merged": total_merged,
+            "total_unchanged": total_unchanged,
+            "total_errors": total_errors,
+        }
+    )
+
+
+async def _process_media_id_update(
+    db: AsyncSession, update_item: MediaIdUpdate, target_media: MediaItem, dry_run: bool
+) -> dict[str, Any]:
+    if not any(
+        [
+            update_item.imdb is not None,
+            update_item.tmdb is not None,
+            update_item.tvdb is not None,
+            update_item.tvmaze is not None,
+            update_item.kitsu is not None,
+            update_item.myanimelist is not None,
+            update_item.anilist is not None,
+        ]
+    ):
+        return {
+            "id": update_item.id,
+            "status": "unchanged",
+            "message": "No changes requested",
+        }
+
+    changes_made = False
+    merge_info: dict[str, Any] | None = None
+
+    updates = {
+        "imdb": update_item.imdb,
+        "tmdb": update_item.tmdb,
+        "tvdb": update_item.tvdb,
+        "tvmaze": update_item.tvmaze,
+        "kitsu": update_item.kitsu,
+        "myanimelist": update_item.myanimelist,
+        "anilist": update_item.anilist,
+    }
+
+    for field, value in updates.items():
+        if value is None:
+            continue
+
+        # The IMDb lookup is global; other provider lookups are scoped to the
+        # same media type.
+        if field == "imdb":
+            conflict = await _find_conflict_by_imdb(db, value, target_media.id)
+        else:
+            conflict = await _find_conflict_by_provider(
+                db, field, value, target_media.media_type, target_media.id
+            )
+
+        # IMDb IDs are stored lowercased; other provider IDs are stored as given.
+        normalized_value = value.lower() if field == "imdb" else value
+
+        if conflict:
+            if conflict.id == target_media.id:
+                current_value = getattr(target_media, f"{field}_id")
+                if current_value == normalized_value:
+                    continue
+                setattr(target_media, f"{field}_id", normalized_value)
+                changes_made = True
+            else:
+                merge_info = await _merge_media_items(db, target_media, conflict, dry_run)
+                target_media = conflict if merge_info["kept_id"] == conflict.id else target_media
+                if not dry_run:
+                    await db.commit()
+                return {
+                    "id": update_item.id,
+                    "status": "merged",
+                    "message": f"Merged with {conflict.id} due to {field} conflict",
+                    **merge_info,
+                }
+        else:
+            current_value = getattr(target_media, f"{field}_id")
+            if current_value != normalized_value:
+                setattr(target_media, f"{field}_id", normalized_value)
+                changes_made = True
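+
+    # Note: a conflict merge returns from inside the loop above, so any
+    # remaining requested IDs for that item are not applied in the same call.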
"id": update_item.id, + "status": "unchanged", + "message": "No changes made - all IDs already set to requested values", + } + + if not dry_run: + await db.commit() + + return { + "id": update_item.id, + "status": "updated", + "message": "External IDs updated successfully", + } + + +async def _find_conflict_by_imdb( + db: AsyncSession, imdb_id: str, exclude_id: str +) -> MediaItem | None: + result = await db.execute( + select(MediaItem).where( + MediaItem.imdb_id == imdb_id.lower(), + MediaItem.id != exclude_id, + ) + ) + return result.scalars().first() + + +async def _find_conflict_by_provider( + db: AsyncSession, provider: str, provider_id: str, media_type: str, exclude_id: str +) -> MediaItem | None: + result = await db.execute( + select(MediaItem).where( + getattr(MediaItem, f"{provider}_id") == provider_id, + MediaItem.media_type == media_type, + MediaItem.id != exclude_id, + ) + ) + return result.scalars().first() + + +async def _merge_media_items( + db: AsyncSession, target: MediaItem, duplicate: MediaItem, dry_run: bool +) -> dict[str, Any]: + keep = target + remove = duplicate + migrated: dict[str, int] = {} + + if dry_run: + watched_result = await db.execute( + select(func.count()) + .select_from(WatchedItem) + .where(WatchedItem.media_item_id == remove.id) + ) + migrated["watched_count"] = int(watched_result.scalar() or 0) + + # Deduplicate episodes: count non-conflicting remove episodes + keep_episode_keys_result = await db.execute( + select(EpisodeItem.season_number, EpisodeItem.episode_number).where( + EpisodeItem.show_media_item_id == keep.id + ) + ) + keep_episode_keys = { + (row.season_number, row.episode_number) for row in keep_episode_keys_result.all() + } + + remove_episode_keys_result = await db.execute( + select(EpisodeItem.season_number, EpisodeItem.episode_number).where( + EpisodeItem.show_media_item_id == remove.id + ) + ) + remove_episode_keys = [ + (row.season_number, row.episode_number) for row in remove_episode_keys_result.all() + ] + migrated["episode_count"] = len( + [k for k in remove_episode_keys if k not in keep_episode_keys] + ) + + # Deduplicate watchlist: count non-conflicting remove watchlist + keep_watchlist_users_result = await db.execute( + select(WatchlistItem.user_id).where(WatchlistItem.media_item_id == keep.id) + ) + keep_watchlist_users = {row.user_id for row in keep_watchlist_users_result.all()} + + remove_watchlist_users_result = await db.execute( + select(WatchlistItem.user_id).where(WatchlistItem.media_item_id == remove.id) + ) + remove_watchlist_users = [row.user_id for row in remove_watchlist_users_result.all()] + migrated["watchlist_count"] = len( + [u for u in remove_watchlist_users if u not in keep_watchlist_users] + ) + + watch_event_result = await db.execute( + select(func.count()) + .select_from(WatchEvent) + .where(WatchEvent.media_item_id == remove.id) + ) + migrated["watch_event_count"] = int(watch_event_result.scalar() or 0) + + watchlist_source_result = await db.execute( + select(func.count()) + .select_from(WatchlistSourceItem) + .where(WatchlistSourceItem.media_item_id == remove.id) + ) + migrated["watchlist_source_count"] = int(watchlist_source_result.scalar() or 0) + + return { + "kept_id": keep.id, + "merged_from": remove.id, + "migrated": migrated, + } + + watched_result = await db.execute( + select(WatchedItem).where(WatchedItem.media_item_id == remove.id) + ) + watched_items = watched_result.scalars().all() + + for watched in watched_items: + watched.media_item_id = keep.id + + migrated["watched_count"] = 
+    watched_result = await db.execute(
+        select(WatchedItem).where(WatchedItem.media_item_id == remove.id)
+    )
+    watched_items = watched_result.scalars().all()
+
+    for watched in watched_items:
+        watched.media_item_id = keep.id
+
+    migrated["watched_count"] = len(watched_items)
+
+    # Deduplicate episodes: get existing episode keys for keep
+    keep_episode_keys_result = await db.execute(
+        select(EpisodeItem.season_number, EpisodeItem.episode_number).where(
+            EpisodeItem.show_media_item_id == keep.id
+        )
+    )
+    keep_episode_keys = {
+        (row.season_number, row.episode_number) for row in keep_episode_keys_result.all()
+    }
+
+    episode_result = await db.execute(
+        select(EpisodeItem).where(EpisodeItem.show_media_item_id == remove.id)
+    )
+    remove_episodes = episode_result.scalars().all()
+
+    reassigned_episodes = []
+    for episode in remove_episodes:
+        if (episode.season_number, episode.episode_number) in keep_episode_keys:
+            await db.delete(episode)
+        else:
+            episode.show_media_item_id = keep.id
+            reassigned_episodes.append(episode)
+
+    migrated["episode_count"] = len(reassigned_episodes)
+
+    # Deduplicate watchlist items: get existing user_ids for keep
+    keep_watchlist_users_result = await db.execute(
+        select(WatchlistItem.user_id).where(WatchlistItem.media_item_id == keep.id)
+    )
+    keep_watchlist_users = {row.user_id for row in keep_watchlist_users_result.all()}
+
+    watchlist_result = await db.execute(
+        select(WatchlistItem).where(WatchlistItem.media_item_id == remove.id)
+    )
+    remove_watchlist_items = watchlist_result.scalars().all()
+
+    reassigned_watchlist = []
+    for wl_item in remove_watchlist_items:
+        if wl_item.user_id in keep_watchlist_users:
+            await db.delete(wl_item)
+        else:
+            wl_item.media_item_id = keep.id
+            reassigned_watchlist.append(wl_item)
+
+    migrated["watchlist_count"] = len(reassigned_watchlist)
+
+    # Reassign WatchEvents
+    watch_event_result = await db.execute(
+        select(WatchEvent).where(WatchEvent.media_item_id == remove.id)
+    )
+    watch_events = watch_event_result.scalars().all()
+
+    for event in watch_events:
+        event.media_item_id = keep.id
+
+    migrated["watch_event_count"] = len(watch_events)
+
+    # Reassign WatchlistSourceItems
+    watchlist_source_result = await db.execute(
+        select(WatchlistSourceItem).where(WatchlistSourceItem.media_item_id == remove.id)
+    )
+    watchlist_source_items = watchlist_source_result.scalars().all()
+
+    for source_item in watchlist_source_items:
+        source_item.media_item_id = keep.id
+
+    migrated["watchlist_source_count"] = len(watchlist_source_items)
+
+    watched_ids = [w.id for w in watched_items]
+
+    if watched_ids:
+        sync_result = await db.execute(
+            select(WatchSync).where(WatchSync.watched_item_id.in_(watched_ids))
+        )
+        syncs = sync_result.scalars().all()
+
+        outbox_result = await db.execute(
+            select(OutboxJob).where(
+                or_(
+                    OutboxJob.payload["watched_item_id"].as_string().in_(watched_ids),
+                    OutboxJob.payload["watch_sync_id"].as_string().in_([s.id for s in syncs]),
+                )
+            )
+        )
+        jobs = outbox_result.scalars().all()
+
+        # Watched-item and sync IDs survive the merge unchanged (only their
+        # media_item_id was reassigned), so matching outbox payloads stay valid
+        # and are rewritten with the same identifiers.
+        for job in jobs:
+            payload = dict(job.payload or {})
+            watched_id = payload.get("watched_item_id")
+            if watched_id in watched_ids:
+                payload["watched_item_id"] = watched_id
+            sync_id = payload.get("watch_sync_id")
+            if sync_id and any(s.id == sync_id for s in syncs):
+                payload["watch_sync_id"] = sync_id
+            job.payload = payload
+
+    _merge_media_fields(keep, remove)
+
+    await db.delete(remove)
+
+    return {
+        "kept_id": keep.id,
+        "merged_from": remove.id,
+        "migrated": migrated,
+    }
diff --git a/backend/src/librarysync/core/watch_pipeline.py b/backend/src/librarysync/core/watch_pipeline.py
index b402136..973eb5d 100644
--- a/backend/src/librarysync/core/watch_pipeline.py
+++ b/backend/src/librarysync/core/watch_pipeline.py
@@ -1308,3 +1308,43 @@ def _coerce_int(value: object) -> int | None:
     if cleaned.isdigit():
         return int(cleaned)
     return None
+
+
+def _merge_media_fields(target: MediaItem, source: MediaItem) -> None:
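+    # Fill-gaps policy: copy a field from source only when target has no value
+    # for it; existing target values are never overwritten.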
+    if not target.imdb_id and source.imdb_id:
+        target.imdb_id = source.imdb_id
+    if not target.tmdb_id and source.tmdb_id:
+        target.tmdb_id = source.tmdb_id
+    if not target.tvdb_id and source.tvdb_id:
+        target.tvdb_id = source.tvdb_id
+    if not target.tvmaze_id and source.tvmaze_id:
+        target.tvmaze_id = source.tvmaze_id
+    if not target.kitsu_id and source.kitsu_id:
+        target.kitsu_id = source.kitsu_id
+    if not target.myanimelist_id and source.myanimelist_id:
+        target.myanimelist_id = source.myanimelist_id
+    if not target.anilist_id and source.anilist_id:
+        target.anilist_id = source.anilist_id
+    if target.year is None and source.year is not None:
+        target.year = source.year
+    if not target.poster_url and source.poster_url:
+        target.poster_url = source.poster_url
+    if not target.title and source.title:
+        target.title = source.title
+    if not target.overview and source.overview:
+        target.overview = source.overview
+    if not target.genres and source.genres:
+        target.genres = source.genres
+    if target.runtime_in_seconds is None and source.runtime_in_seconds is not None:
+        target.runtime_in_seconds = source.runtime_in_seconds
+    if not target.release_date and source.release_date:
+        target.release_date = source.release_date
+    if not target.first_air_date and source.first_air_date:
+        target.first_air_date = source.first_air_date
+    if not target.last_air_date and source.last_air_date:
+        target.last_air_date = source.last_air_date
+    if isinstance(target.raw, dict) and isinstance(source.raw, dict):
+        for key, value in source.raw.items():
+            target.raw.setdefault(key, value)
+    elif target.raw is None and source.raw is not None:
+        target.raw = dict(source.raw)
diff --git a/backend/src/librarysync/jobs/merge_history.py b/backend/src/librarysync/jobs/merge_history.py
index 59b2364..ffc9948 100644
--- a/backend/src/librarysync/jobs/merge_history.py
+++ b/backend/src/librarysync/jobs/merge_history.py
@@ -30,6 +30,7 @@
     extend_scheduled_job,
     release_scheduled_job,
 )
+from librarysync.core.watch_pipeline import _merge_media_fields
 from librarysync.db.models import (
     EpisodeItem,
     Integration,
@@ -298,23 +299,7 @@ def _row_sort_key(row: WatchedRow, priority_map: dict[str, int]) -> tuple[int, i
 
 def _merge_media(primary: MediaItem, others: list[MediaItem]) -> None:
     for other in others:
-        if not primary.imdb_id and other.imdb_id:
-            primary.imdb_id = other.imdb_id
-        if not primary.tmdb_id and other.tmdb_id:
-            primary.tmdb_id = other.tmdb_id
-        if not primary.tvdb_id and other.tvdb_id:
-            primary.tvdb_id = other.tvdb_id
-        if primary.year is None and other.year is not None:
-            primary.year = other.year
-        if not primary.poster_url and other.poster_url:
-            primary.poster_url = other.poster_url
-        if other.title and (not primary.title or len(other.title) > len(primary.title)):
-            primary.title = other.title
-        if isinstance(primary.raw, dict) and isinstance(other.raw, dict):
-            for key, value in other.raw.items():
-                primary.raw.setdefault(key, value)
-        elif primary.raw is None and other.raw is not None:
-            primary.raw = dict(other.raw)
+        _merge_media_fields(primary, other)
 
 
 def _merge_watched(primary: WatchedItem, others: list[WatchedItem]) -> None:
diff --git a/backend/tests/test_routes_admin.py b/backend/tests/test_routes_admin.py
new file mode 100644
index 0000000..d247eab
--- /dev/null
+++ b/backend/tests/test_routes_admin.py
@@ -0,0 +1,107 @@
+import sys
+import unittest
+from pathlib import Path
+
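+# Make the src layout importable when running this file directly, without
+# requiring the package to be installed.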
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+sys.path.append(str(PROJECT_ROOT / "src"))
+
+from librarysync.api import routes_admin  # noqa: E402
+
+
+class TestAdminMediaUpdate(unittest.TestCase):
+    def test_media_id_update_with_all_fields(self) -> None:
+        model = routes_admin.MediaIdUpdate(
+            id="test-id",
+            imdb="tt1234567",
+            tmdb="12345",
+            tvdb="67890",
+            tvmaze="123",
+            kitsu="456",
+            myanimelist="789",
+            anilist="101112",
+        )
+        self.assertEqual(model.id, "test-id")
+        self.assertEqual(model.imdb, "tt1234567")
+        self.assertEqual(model.tmdb, "12345")
+        self.assertEqual(model.tvdb, "67890")
+        self.assertEqual(model.tvmaze, "123")
+        self.assertEqual(model.kitsu, "456")
+        self.assertEqual(model.myanimelist, "789")
+        self.assertEqual(model.anilist, "101112")
+
+    def test_media_id_update_with_partial_fields(self) -> None:
+        model = routes_admin.MediaIdUpdate(
+            id="test-id",
+            imdb="tt1234567",
+            tmdb="12345",
+        )
+        self.assertEqual(model.id, "test-id")
+        self.assertEqual(model.imdb, "tt1234567")
+        self.assertEqual(model.tmdb, "12345")
+        self.assertIsNone(model.tvdb)
+        self.assertIsNone(model.tvmaze)
+
+    def test_media_id_update_with_null_fields(self) -> None:
+        model = routes_admin.MediaIdUpdate(
+            id="test-id",
+            imdb=None,
+            tmdb=None,
+        )
+        self.assertEqual(model.id, "test-id")
+        self.assertIsNone(model.imdb)
+        self.assertIsNone(model.tmdb)
+
+    def test_media_update_request_with_single_item(self) -> None:
+        request = routes_admin.MediaUpdateRequest(
+            updates=[
+                routes_admin.MediaIdUpdate(
+                    id="test-id",
+                    tmdb="12345",
+                )
+            ],
+            dry_run=False,
+        )
+        self.assertEqual(len(request.updates), 1)
+        self.assertEqual(request.updates[0].id, "test-id")
+        self.assertEqual(request.updates[0].tmdb, "12345")
+        self.assertFalse(request.dry_run)
+
+    def test_media_update_request_with_multiple_items(self) -> None:
+        request = routes_admin.MediaUpdateRequest(
+            updates=[
+                routes_admin.MediaIdUpdate(
+                    id="test-id-1",
+                    tmdb="12345",
+                ),
+                routes_admin.MediaIdUpdate(
+                    id="test-id-2",
+                    imdb="tt67890",
+                ),
+            ],
+            dry_run=True,
+        )
+        self.assertEqual(len(request.updates), 2)
+        self.assertEqual(request.updates[0].id, "test-id-1")
+        self.assertEqual(request.updates[1].id, "test-id-2")
+        self.assertTrue(request.dry_run)
+
+    def test_media_update_request_default_dry_run(self) -> None:
+        request = routes_admin.MediaUpdateRequest(
+            updates=[
+                routes_admin.MediaIdUpdate(
+                    id="test-id",
+                    tmdb="12345",
+                )
+            ]
+        )
+        self.assertFalse(request.dry_run)
+
+    def test_media_id_update_requires_id(self) -> None:
+        with self.assertRaises(ValueError):
+            routes_admin.MediaIdUpdate(
+                tmdb="12345",
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()