diff --git a/backend/src/librarysync/jobs/aiostreams_import.py b/backend/src/librarysync/jobs/aiostreams_import.py
index 742c66b..dd4dbe1 100644
--- a/backend/src/librarysync/jobs/aiostreams_import.py
+++ b/backend/src/librarysync/jobs/aiostreams_import.py
@@ -25,6 +25,7 @@
     EpisodeItem,
     Integration,
     MediaItem,
+    WatchedItem,
 )
 from librarysync.jobs.import_base import ImportContext, ImportResult, ImportStrategy
 from librarysync.jobs.import_pipeline import (
@@ -446,7 +447,7 @@ async def _lookup_metadata_candidate(
                 exc_info=True,
             )
             continue
-        candidate = _select_candidate_for_entry(entry, candidates)
+        candidate = await _select_candidate_for_entry(db, user_id, entry, candidates)
         if not candidate:
             continue
         if _candidate_has_useful_id(candidate):
@@ -471,8 +472,101 @@ def _build_lookup_queries(title: str, year: int | None) -> list[str]:
     return [f"{cleaned} {year}", cleaned]
 
 
-def _select_candidate_for_entry(
-    entry: ParsedEntry, candidates: list[MediaCandidate]
+async def _check_series_continuity(
+    db: AsyncSession,
+    user_id: str,
+    entry: ParsedEntry,
+    candidates: list[MediaCandidate],
+) -> MediaCandidate | None:
+    """Return the candidate for a show the user is already watching.
+
+    Disambiguates shows that share a title. For example, when importing
+    "Fallout S02E08" for a user who already watched "Fallout S02E07", the
+    candidate for that same show wins over another "Fallout" (e.g. an anime).
+
+    A watched show qualifies when its normalized title equals the entry's and
+    the entry is either later than the user's latest watched episode or within
+    3 episodes of it in the same season (to allow for rewatches).
+
+    Returns the first candidate whose IMDb, TMDB or TVDB ID matches a
+    qualifying show, or None when nothing matches.
+    """
+    if not entry.title or entry.season_number is None or entry.episode_number is None:
+        return None
+
+    title_key = _normalize_title_key(entry.title)
+    if not title_key:
+        return None
+
+    from sqlalchemy import and_, func
+
+    # One row per (show, season) with the highest watched episode of that season.
+    # Taking max(season) and max(episode) independently would pair values from
+    # different episodes (S01E10 and S02E03 would read as "S02E10").
+    result = await db.execute(
+        select(
+            MediaItem,
+            EpisodeItem.season_number,
+            func.max(EpisodeItem.episode_number).label("max_episode"),
+        )
+        .join(EpisodeItem, EpisodeItem.show_media_item_id == MediaItem.id)
+        .join(WatchedItem, and_(
+            WatchedItem.episode_item_id == EpisodeItem.id,
+            WatchedItem.user_id == user_id
+        ))
+        .where(MediaItem.media_type == "tv")
+        .group_by(MediaItem.id, EpisodeItem.season_number)
+    )
+
+    # Titles are normalized in Python, so filter here and keep only the latest
+    # watched (season, episode) pair of each same-titled show.
+    latest_by_show: dict[object, tuple[MediaItem, int, int]] = {}
+    for show_item, season, episode in result.all():
+        if season is None or episode is None:
+            continue
+        if _normalize_title_key(show_item.title or "") != title_key:
+            continue
+        known = latest_by_show.get(show_item.id)
+        if known is None or (season, episode) > known[1:]:
+            latest_by_show[show_item.id] = (show_item, season, episode)
+
+    entry_position = (entry.season_number, entry.episode_number)
+    for show_item, last_season, last_episode in latest_by_show.values():
+        # Later episode of the same season, or any episode of a later season.
+        is_continuation = entry_position > (last_season, last_episode)
+        # The same or a slightly earlier episode is probably a rewatch.
+        is_nearby = (
+            entry.season_number == last_season
+            and abs(entry.episode_number - last_episode) <= 3
+        )
+        if not is_continuation and not is_nearby:
+            continue
+
+        # Now check if this show matches one of our candidates
+        for candidate in candidates:
+            # Match by IMDb ID (most reliable)
+            if show_item.imdb_id and candidate.imdb_id:
+                if show_item.imdb_id.lower() == candidate.imdb_id.lower():
+                    return candidate
+
+            # Match by TMDB ID
+            if show_item.tmdb_id and candidate.provider == "tmdb" and candidate.provider_id:
+                if str(show_item.tmdb_id) == str(candidate.provider_id):
+                    return candidate
+
+            # Match by TVDB ID
+            if show_item.tvdb_id and candidate.provider == "tvdb" and candidate.provider_id:
+                if str(show_item.tvdb_id) == str(candidate.provider_id):
+                    return candidate
+
+    return None
+
+
+async def _select_candidate_for_entry(
+    db: AsyncSession,
+    user_id: str,
+    entry: ParsedEntry,
+    candidates: list[MediaCandidate],
 ) -> MediaCandidate | None:
     if not candidates:
         return None
@@ -483,6 +577,23 @@ def _select_candidate_for_entry(
     ]
     if not scoped:
         scoped = candidates
+
+    # For TV shows, check if user has watched a previous episode from one of the candidates
+    if entry.media_type == "tv" and entry.season_number is not None and entry.episode_number is not None:
+        continuity_candidate = await _check_series_continuity(
+            db, user_id, entry, scoped
+        )
+        if continuity_candidate:
+            logger.debug(
+                "Using series continuity: selected %s (imdb=%s) for %s S%02dE%02d",
+                continuity_candidate.title,
+                continuity_candidate.imdb_id,
+                entry.title,
+                entry.season_number,
+                entry.episode_number,
+            )
+            return continuity_candidate
+
     title_key = _normalize_title_key(entry.title or "")
     title_matches: list[MediaCandidate] = []
     if title_key:
diff --git a/backend/tests/test_aiostreams_import.py b/backend/tests/test_aiostreams_import.py
index 00760ad..f4115ce 100644
--- a/backend/tests/test_aiostreams_import.py
+++ b/backend/tests/test_aiostreams_import.py
@@ -1,7 +1,9 @@
+import asyncio
 import sys
 import unittest
 from datetime import datetime, timezone
 from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock
 
 PROJECT_ROOT = Path(__file__).resolve().parents[1]
 sys.path.append(str(PROJECT_ROOT / "src"))
@@ -11,19 +13,26 @@
 
 
 class TestAIOStreamsLookup(unittest.TestCase):
-    def _build_entry(self, title: str, year: int | None) -> aiostreams_import.ParsedEntry:
+    def _build_entry(
+        self,
+        title: str,
+        year: int | None,
+        media_type: str = "movie",
+        season_number: int | None = None,
+        episode_number: int | None = None,
+    ) -> aiostreams_import.ParsedEntry:
         now = datetime.now(timezone.utc)
         return aiostreams_import.ParsedEntry(
             raw={},
             watched_at=now,
             last_seen=now,
             duration_seconds=3600,
-            media_type="movie",
+            media_type=media_type,
             imdb_id=None,
             tmdb_id=None,
             tvdb_id=None,
-            season_number=None,
-            episode_number=None,
+            season_number=season_number,
+            episode_number=episode_number,
             title=title,
             year=year,
             filename=None,
@@ -56,7 +65,13 @@ def test_select_candidate_prefers_title_and_year(self) -> None:
                 raw={},
             ),
         ]
-        selected = aiostreams_import._select_candidate_for_entry(entry, candidates)
+        # Create a mock db session
+        db = AsyncMock()
+        db.execute = AsyncMock(return_value=MagicMock(all=MagicMock(return_value=[])))
+
+        selected = asyncio.run(
+            aiostreams_import._select_candidate_for_entry(db, "test_user", entry, candidates)
+        )
         self.assertIsNotNone(selected)
         self.assertEqual(selected.provider_id, "200")
 
@@ -84,10 +99,124 @@ def test_select_candidate_uses_title_match(self) -> None:
                 raw={},
             ),
         ]
-        selected = aiostreams_import._select_candidate_for_entry(entry, candidates)
+        # Create a mock db session
+        db = AsyncMock()
+        db.execute = AsyncMock(return_value=MagicMock(all=MagicMock(return_value=[])))
+
+        selected = asyncio.run(
+            aiostreams_import._select_candidate_for_entry(db, "test_user", entry, candidates)
+        )
         self.assertIsNotNone(selected)
         self.assertEqual(selected.provider_id, "tt0000002")
 
+    def test_series_continuity_prefers_previously_watched_show(self) -> None:
+        """Test that if user watched Fallout S02E07, watching S02E08 prefers the same show."""
+        entry = self._build_entry(
+            title="Fallout",
+            year=None,
+            media_type="tv",
+            season_number=2,
+            episode_number=8,
+        )
+
+        # Two candidates: one is the correct Fallout TV show, another is an anime
+        fallout_tv_candidate = MediaCandidate(
+            provider="tmdb",
+            provider_id="12345",
+            media_type="tv",
+            title="Fallout",
+            year=2024,
+            poster_url=None,
+            imdb_id="tt12345678",
+            raw={},
+        )
+        fallout_anime_candidate = MediaCandidate(
+            provider="tmdb",
+            provider_id="99999",
+            media_type="tv",
+            title="Fallout",
+            year=2008,
+            poster_url=None,
+            imdb_id="tt99999999",
+            raw={},
+        )
+
+        candidates = [
+            fallout_anime_candidate,  # Wrong one first in list
+            fallout_tv_candidate,  # Correct one second
+        ]
+
+        # Mock database to return a show with matching imdb_id that user has watched before
+        mock_result = MagicMock()
+
+        # Create mock MediaItem for the TV show the user has watched
+        mock_media_item = MagicMock()
+        mock_media_item.id = "show_123"
+        mock_media_item.media_type = "tv"
+        mock_media_item.title = "Fallout"
+        mock_media_item.imdb_id = "tt12345678"
+        mock_media_item.tmdb_id = None
+        mock_media_item.tvdb_id = None
+
+        # User has watched S02E07 (episode 7 before current episode 8)
+        mock_result.all = MagicMock(return_value=[
+            (mock_media_item, 2, 7)  # (MediaItem, season_number, max_episode)
+        ])
+
+        db = AsyncMock()
+        db.execute = AsyncMock(return_value=mock_result)
+
+        selected = asyncio.run(
+            aiostreams_import._select_candidate_for_entry(db, "test_user", entry, candidates)
+        )
+
+        # Should select the TV show, not the anime, because user watched previous episode
+        self.assertIsNotNone(selected)
+        self.assertEqual(selected.imdb_id, "tt12345678")
+        self.assertEqual(selected.year, 2024)
+
+    def test_series_continuity_uses_latest_watched_episode(self) -> None:
+        """S01E10 and S02E03 watched: S02E04 continues the show (latest is S02E03)."""
+        entry = self._build_entry(
+            title="Fallout",
+            year=None,
+            media_type="tv",
+            season_number=2,
+            episode_number=4,
+        )
+        candidate = MediaCandidate(
+            provider="tmdb",
+            provider_id="12345",
+            media_type="tv",
+            title="Fallout",
+            year=2024,
+            poster_url=None,
+            imdb_id="tt12345678",
+            raw={},
+        )
+
+        watched_show = MagicMock()
+        watched_show.id = "show_123"
+        watched_show.title = "Fallout"
+        watched_show.imdb_id = "tt12345678"
+        watched_show.tmdb_id = None
+        watched_show.tvdb_id = None
+
+        # One row per watched season: (MediaItem, season_number, max_episode)
+        mock_result = MagicMock()
+        mock_result.all = MagicMock(return_value=[
+            (watched_show, 1, 10),
+            (watched_show, 2, 3),
+        ])
+        db = AsyncMock()
+        db.execute = AsyncMock(return_value=mock_result)
+
+        selected = asyncio.run(
+            aiostreams_import._check_series_continuity(db, "test_user", entry, [candidate])
+        )
+
+        self.assertIs(selected, candidate)
+
 
 if __name__ == "__main__":
     unittest.main()