diff --git a/src/pgslice/graph/traverser.py b/src/pgslice/graph/traverser.py index 09dd2d2..8e26bdc 100644 --- a/src/pgslice/graph/traverser.py +++ b/src/pgslice/graph/traverser.py @@ -243,9 +243,10 @@ def traverse_multiple( max_depth: int | None = None, ) -> set[RecordData]: """ - Traverse from multiple starting records. + Traverse from multiple starting records using unified BFS. - Efficiently handles shared relationships via visited tracking. + Optimizes multi-record traversal by batch-fetching all starting + records and running a single BFS from all starting points. Args: table_name: Starting table name @@ -254,21 +255,232 @@ def traverse_multiple( max_depth: Optional maximum traversal depth Returns: - Set of all discovered RecordData objects (union of all traversals) + Set of all discovered RecordData objects """ - all_records: set[RecordData] = set() + # Edge case: empty pk_values + if not pk_values: + logger.info("No primary keys provided for traversal") + return set() + + # Single PK: delegate to traverse() for simplicity + if len(pk_values) == 1: + return self.traverse(table_name, pk_values[0], schema, max_depth) + + logger.info( + f"Starting batch traversal from {schema}.{table_name} " + f"with {len(pk_values)} starting records" + ) + + results: set[RecordData] = set() + + # Step 1: Create RecordIdentifiers for all starting records + start_ids: list[RecordIdentifier] = [ + self._create_record_identifier(schema, table_name, (pk,)) + for pk in pk_values + ] + + # Step 2: Filter out already-visited starting records + unvisited_start_ids: list[RecordIdentifier] = [ + rid for rid in start_ids if not self.visited.is_visited(rid) + ] + + if not unvisited_start_ids: + logger.info("All starting records already visited") + return results + + # Step 3: Mark all starting records as visited BEFORE fetching + for rid in unvisited_start_ids: + self.visited.mark_visited(rid) + + # Step 4: Batch-fetch ALL starting records in one query + try: + fetched_starts = 
self._fetch_records_batch(unvisited_start_ids) + except Exception as e: + logger.error(f"Error batch-fetching starting records: {e}") + # Fallback to individual fetches + fetched_starts = {} + for rid in unvisited_start_ids: + try: + fetched_starts[rid] = self._fetch_record(rid) + except RecordNotFoundError: + logger.warning(f"Starting record not found: {rid}") + + if not fetched_starts: + logger.warning("No starting records found") + return results + + logger.debug( + f"Fetched {len(fetched_starts)}/{len(unvisited_start_ids)} starting records" + ) + + # Step 5: Initialize unified BFS queue with all starting records at depth 0 + queue: deque[tuple[RecordIdentifier, int, bool]] = deque() + + for record_id, record_data in fetched_starts.items(): + results.add(record_data) + + # Get table metadata and process outgoing FKs + table = self._get_table_metadata( + record_id.schema_name, record_id.table_name + ) + + for fk in table.foreign_keys_outgoing: + target_id = self._resolve_foreign_key_target(record_data, fk) + if target_id: + record_data.dependencies.add(target_id) + if not self.visited.is_visited(target_id): + follow_incoming = self.wide_mode + queue.append((target_id, 1, follow_incoming)) + + # Step 6: Batch-process incoming FKs for all starting records + self._batch_process_incoming_fks_for_records( + fetched_starts, queue, current_depth=0 + ) + + # Step 7: Run unified BFS (reuse logic from traverse()) + while queue: + current_depth = queue[0][1] if queue else 0 + batch: list[tuple[RecordIdentifier, bool]] = [] + + # Collect batch at current depth + while queue and len(batch) < self.fetch_batch_size: + record_id, depth, follow_incoming_fks = queue.popleft() - logger.info(f"Starting multi-record traversal from {schema}.{table_name}") - logger.info(f"Primary keys: {pk_values}") + if depth != current_depth: + queue.appendleft((record_id, depth, follow_incoming_fks)) + break - for pk_value in pk_values: - records = self.traverse(table_name, pk_value, schema, 
def _batch_process_incoming_fks_for_records(
    self,
    records: dict[RecordIdentifier, RecordData],
    queue: deque[tuple[RecordIdentifier, int, bool]],
    current_depth: int,
) -> None:
    """
    Enqueue the records that reference ``records`` through incoming FKs.

    Target records are grouped per incoming foreign key so that each FK
    needs a single ``_find_referencing_records_batch`` call. Every
    referencing record that is not yet visited is appended to ``queue`` at
    ``current_depth + 1`` with incoming-FK following enabled.

    When ``wide_mode`` is off, FKs whose source table is the record's own
    table (self-references) are ignored.

    Args:
        records: Map of record IDs to their data; only the keys are read
        queue: BFS queue that discovered records are appended to (mutated)
        current_depth: Depth of ``records``; discoveries go one level deeper
    """
    if not records:
        return

    # Loop invariants, computed once.
    next_depth = current_depth + 1
    skip_self_references = not self.wide_mode

    # Group target records per incoming FK, keeping the real FK metadata
    # object instead of rebuilding a stand-in later. The target table is
    # part of the key so the FK kept for a group belongs to every record in
    # that group (one source column referenced from two target tables would
    # otherwise be merged into one mixed batch).
    lookups: dict[
        tuple[str, str, str, str], tuple[object, list[RecordIdentifier]]
    ] = {}

    for record_id in records:
        table = self._get_table_metadata(
            record_id.schema_name, record_id.table_name
        )

        for fk in table.foreign_keys_incoming:
            if skip_self_references:
                source_schema, source_table = self._parse_table_name(
                    fk.source_table
                )
                if (
                    source_schema == record_id.schema_name
                    and source_table == record_id.table_name
                ):
                    continue

            group_key = (
                fk.source_table,
                fk.source_column,
                record_id.schema_name,
                record_id.table_name,
            )
            _, group_ids = lookups.setdefault(group_key, (fk, []))
            group_ids.append(record_id)

    # One batch lookup per FK group.
    for fk, target_ids in lookups.values():
        try:
            referencing_map = self._find_referencing_records_batch(
                target_ids, fk
            )

            for target_id in target_ids:
                for source_id in referencing_map.get(target_id, []):
                    # Duplicates may still be queued; the BFS loop
                    # re-checks the visited set when it dequeues.
                    if not self.visited.is_visited(source_id):
                        queue.append((source_id, next_depth, True))
        except Exception as e:
            # Best effort: a failing lookup for one FK must not abort the
            # traversal of the remaining relationships.
            logger.error(
                f"Error in batch FK lookup for {fk.source_table}: {e}"
            )
def test_empty_pk_values_returns_empty_set(
    self,
    traverser: RelationshipTraverser,
) -> None:
    """Should return an empty set when no primary keys are given."""
    results = traverser.traverse_multiple("users", [])

    assert len(results) == 0


def test_single_pk_delegates_to_traverse(
    self,
    traverser: RelationshipTraverser,
    mock_cursor: MagicMock,
) -> None:
    """Should return the single starting record when given one PK."""
    # The single-PK path goes through traverse(), which reads via fetchone.
    mock_cursor.fetchone.return_value = (1, "User 1")

    results = traverser.traverse_multiple("users", [1])

    assert len(results) == 1
    record = next(iter(results))
    assert record.identifier.table_name == "users"


def test_skips_already_visited_starting_records(
    self,
    traverser: RelationshipTraverser,
    mock_cursor: MagicMock,
    visited_tracker: VisitedTracker,
) -> None:
    """Should skip starting records that are already visited."""
    # Build the identifier the same way traverse_multiple() builds its
    # starting IDs, so the pre-marked entry is the one it will look up.
    # NOTE(review): a hand-built RecordIdentifier(pk_values=(1,)) may not
    # match, since the assertion below expects stringified PKs ("2",);
    # confirm whether RecordIdentifier normalises PK values itself.
    already_seen = traverser._create_record_identifier(
        "public", "users", (1,)
    )
    visited_tracker.mark_visited(already_seen)

    # Only record 2 should be fetched
    mock_cursor.fetchall.return_value = [
        (2, "User 2"),
    ]

    results = traverser.traverse_multiple("users", [1, 2])

    # Should only have record 2
    assert len(results) == 1
    record = next(iter(results))
    assert record.identifier.pk_values == ("2",)


def test_respects_max_depth(
    self,
    traverser: RelationshipTraverser,
    mock_cursor: MagicMock,
) -> None:
    """Should respect max_depth parameter."""
    mock_cursor.fetchall.return_value = [
        (1, "User 1"),
        (2, "User 2"),
    ]

    results = traverser.traverse_multiple("users", [1, 2], max_depth=0)

    # max_depth=0 means only the starting records
    assert len(results) == 2