234 changes: 223 additions & 11 deletions src/pgslice/graph/traverser.py
@@ -243,9 +243,10 @@ def traverse_multiple(
max_depth: int | None = None,
) -> set[RecordData]:
"""
Traverse from multiple starting records.
Traverse from multiple starting records using unified BFS.

Efficiently handles shared relationships via visited tracking.
Optimizes multi-record traversal by batch-fetching all starting
records and running a single BFS from all starting points.

Args:
table_name: Starting table name
@@ -254,21 +255,232 @@ def traverse_multiple(
max_depth: Optional maximum traversal depth

Returns:
Set of all discovered RecordData objects (union of all traversals)
Set of all discovered RecordData objects
"""
all_records: set[RecordData] = set()
# Edge case: empty pk_values
if not pk_values:
logger.info("No primary keys provided for traversal")
return set()

# Single PK: delegate to traverse() for simplicity
if len(pk_values) == 1:
return self.traverse(table_name, pk_values[0], schema, max_depth)

logger.info(
f"Starting batch traversal from {schema}.{table_name} "
f"with {len(pk_values)} starting records"
)

results: set[RecordData] = set()

# Step 1: Create RecordIdentifiers for all starting records
start_ids: list[RecordIdentifier] = [
self._create_record_identifier(schema, table_name, (pk,))
for pk in pk_values
]

# Step 2: Filter out already-visited starting records
unvisited_start_ids: list[RecordIdentifier] = [
rid for rid in start_ids if not self.visited.is_visited(rid)
]

if not unvisited_start_ids:
logger.info("All starting records already visited")
return results

# Step 3: Mark all starting records as visited BEFORE fetching
for rid in unvisited_start_ids:
self.visited.mark_visited(rid)

# Step 4: Batch-fetch ALL starting records in one query
try:
fetched_starts = self._fetch_records_batch(unvisited_start_ids)
except Exception as e:
logger.error(f"Error batch-fetching starting records: {e}")
# Fallback to individual fetches
fetched_starts = {}
for rid in unvisited_start_ids:
try:
fetched_starts[rid] = self._fetch_record(rid)
except RecordNotFoundError:
logger.warning(f"Starting record not found: {rid}")

if not fetched_starts:
logger.warning("No starting records found")
return results

logger.debug(
f"Fetched {len(fetched_starts)}/{len(unvisited_start_ids)} starting records"
)

# Step 5: Initialize unified BFS queue with all starting records at depth 0
queue: deque[tuple[RecordIdentifier, int, bool]] = deque()

for record_id, record_data in fetched_starts.items():
results.add(record_data)

# Get table metadata and process outgoing FKs
table = self._get_table_metadata(
record_id.schema_name, record_id.table_name
)

for fk in table.foreign_keys_outgoing:
target_id = self._resolve_foreign_key_target(record_data, fk)
if target_id:
record_data.dependencies.add(target_id)
if not self.visited.is_visited(target_id):
follow_incoming = self.wide_mode
queue.append((target_id, 1, follow_incoming))

# Step 6: Batch-process incoming FKs for all starting records
self._batch_process_incoming_fks_for_records(
fetched_starts, queue, current_depth=0
)

# Step 7: Run unified BFS (reuse logic from traverse())
while queue:
current_depth = queue[0][1] if queue else 0
batch: list[tuple[RecordIdentifier, bool]] = []

# Collect batch at current depth
while queue and len(batch) < self.fetch_batch_size:
record_id, depth, follow_incoming_fks = queue.popleft()

logger.info(f"Starting multi-record traversal from {schema}.{table_name}")
logger.info(f"Primary keys: {pk_values}")
if depth != current_depth:
queue.appendleft((record_id, depth, follow_incoming_fks))
break

for pk_value in pk_values:
records = self.traverse(table_name, pk_value, schema, max_depth)
all_records.update(records)
if max_depth is not None and depth > max_depth:
logger.debug(
f"Skipping {record_id}: depth {depth} > max {max_depth}"
)
continue

if self.visited.is_visited(record_id):
logger.debug(f"Skipping {record_id}: already visited")
continue

self.visited.mark_visited(record_id)
batch.append((record_id, follow_incoming_fks))

if not batch:
continue

# Batch fetch records
record_ids = [rid for rid, _ in batch]
try:
fetched_records = self._fetch_records_batch(record_ids)
except Exception as e:
logger.error(f"Error batch fetching records: {e}")
fetched_records = {}
for record_id in record_ids:
try:
fetched_records[record_id] = self._fetch_record(record_id)
except RecordNotFoundError:
logger.warning(f"Record not found: {record_id}")

# Process fetched records
for record_id, _ in batch:
if record_id not in fetched_records:
continue

record_data = fetched_records[record_id]
results.add(record_data)

table = self._get_table_metadata(
record_id.schema_name, record_id.table_name
)

for fk in table.foreign_keys_outgoing:
target_id = self._resolve_foreign_key_target(record_data, fk)
if target_id:
record_data.dependencies.add(target_id)
if not self.visited.is_visited(target_id):
follow_incoming = self.wide_mode
queue.append(
(target_id, current_depth + 1, follow_incoming)
)

# Process incoming FKs for batch
batch_records = {
rid: fetched_records[rid]
for rid, follow in batch
if rid in fetched_records and follow
}
self._batch_process_incoming_fks_for_records(
batch_records, queue, current_depth
)

logger.info(
f"Multi-traversal complete: {len(all_records)} unique records found"
f"Batch traversal complete: {len(results)} unique records found "
f"from {len(pk_values)} starting points"
)
return all_records
return results

def _batch_process_incoming_fks_for_records(
self,
records: dict[RecordIdentifier, RecordData],
queue: deque[tuple[RecordIdentifier, int, bool]],
current_depth: int,
) -> None:
"""
Process incoming FKs for multiple records in batch.

Groups records by FK relationship for efficient batch lookups.

Args:
records: Map of record IDs to their data
queue: BFS queue to append discovered records
current_depth: Current traversal depth
"""
if not records:
return

# Group by incoming FK for batch processing
incoming_fk_lookups: dict[tuple[str, str], list[RecordIdentifier]] = {}

for record_id in records:
table = self._get_table_metadata(
record_id.schema_name, record_id.table_name
)

for fk in table.foreign_keys_incoming:
# Skip self-referencing FKs when not in wide mode
if not self.wide_mode:
source_schema, source_table = self._parse_table_name(
fk.source_table
)
if (
source_schema == record_id.schema_name
and source_table == record_id.table_name
):
continue

fk_key = (fk.source_table, fk.source_column)
if fk_key not in incoming_fk_lookups:
incoming_fk_lookups[fk_key] = []
incoming_fk_lookups[fk_key].append(record_id)

# Execute batch lookups
for (source_table, source_column), target_ids in incoming_fk_lookups.items():
fk_obj = type(
"FK",
(),
{"source_table": source_table, "source_column": source_column},
)()

try:
referencing_map = self._find_referencing_records_batch(
target_ids, fk_obj
)

for target_id in target_ids:
source_records = referencing_map.get(target_id, [])
for source_id in source_records:
if not self.visited.is_visited(source_id):
queue.append((source_id, current_depth + 1, True))
except Exception as e:
logger.error(f"Error in batch FK lookup for {source_table}: {e}")

def _fetch_record(self, record_id: RecordIdentifier) -> RecordData:
"""
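Reviewer note: the core of this change is the depth-batched BFS — records queued at the same depth are collected into batches of at most fetch_batch_size, fetched with a single query, and only then are their neighbours enqueued one level deeper. The sketch below is a minimal, self-contained illustration of that pattern; the graph dict, node names, and the fetch_batch_size default are made up for the example and are not part of the pgslice API.

from collections import deque


# Minimal, self-contained sketch of the depth-batched BFS introduced above.
# The graph dict stands in for FK edges; node names, fetch_batch_size, and
# max_depth handling are illustrative and not part of the pgslice API.
def batched_bfs(
    graph: dict[str, list[str]],
    starts: list[str],
    fetch_batch_size: int = 2,
    max_depth: int | None = None,
) -> set[str]:
    visited: set[str] = set(starts)
    results: set[str] = set()
    queue: deque[tuple[str, int]] = deque((s, 0) for s in starts)

    while queue:
        current_depth = queue[0][1]
        batch: list[str] = []

        # Collect queued nodes only at the current depth, up to the batch size,
        # so each depth level can be fetched with a single batched query.
        while queue and len(batch) < fetch_batch_size:
            node, depth = queue.popleft()
            if depth != current_depth:
                queue.appendleft((node, depth))  # next depth starts a new batch
                break
            if max_depth is not None and depth > max_depth:
                continue
            batch.append(node)

        results.update(batch)  # stands in for the batched SELECT of this level

        # Enqueue unvisited neighbours one level deeper.
        for node in batch:
            for neighbour in graph.get(node, []):
                if neighbour not in visited:
                    visited.add(neighbour)
                    queue.append((neighbour, current_depth + 1))

    return results


if __name__ == "__main__":
    fk_graph = {"u1": ["o1", "o2"], "u2": ["o2"], "o1": ["i1"], "o2": [], "i1": []}
    print(batched_bfs(fk_graph, ["u1", "u2"]))  # 5 records: u1, u2, o1, o2, i1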
74 changes: 70 additions & 4 deletions tests/unit/graph/test_traverser.py
@@ -269,14 +269,36 @@ def test_does_not_revisit_records(
class TestTraverseMultiple(TestRelationshipTraverser):
"""Tests for traverse_multiple method."""

def test_empty_pk_values_returns_empty_set(
self,
traverser: RelationshipTraverser,
) -> None:
"""Should return empty set for empty pk_values."""
results = traverser.traverse_multiple("users", [])
assert len(results) == 0

def test_single_pk_delegates_to_traverse(
self,
traverser: RelationshipTraverser,
mock_cursor: MagicMock,
) -> None:
"""Should delegate to traverse() for single PK."""
mock_cursor.fetchone.return_value = (1, "User 1")

results = traverser.traverse_multiple("users", [1])

assert len(results) == 1
record = list(results)[0]
assert record.identifier.table_name == "users"

def test_traverses_all_starting_records(
self,
traverser: RelationshipTraverser,
mock_cursor: MagicMock,
) -> None:
"""Should traverse from all starting records."""
# Different return values for each call
mock_cursor.fetchone.side_effect = [
"""Should traverse from all starting records using batch fetch."""
# Batch fetch returns all records at once
mock_cursor.fetchall.return_value = [
(1, "User 1"),
(2, "User 2"),
(3, "User 3"),
@@ -292,7 +314,7 @@ def test_combines_results(
mock_cursor: MagicMock,
) -> None:
"""Should combine results from all traversals."""
mock_cursor.fetchone.side_effect = [
mock_cursor.fetchall.return_value = [
(1, "User 1"),
(2, "User 2"),
]
@@ -303,6 +325,50 @@ def test_combines_results(
assert ("1",) in identifiers
assert ("2",) in identifiers

def test_skips_already_visited_starting_records(
self,
traverser: RelationshipTraverser,
mock_cursor: MagicMock,
visited_tracker: VisitedTracker,
) -> None:
"""Should skip starting records that are already visited."""
# Pre-mark record 1 as visited
visited_tracker.mark_visited(
RecordIdentifier(
schema_name="public",
table_name="users",
pk_values=(1,),
)
)

# Only record 2 should be fetched
mock_cursor.fetchall.return_value = [
(2, "User 2"),
]

results = traverser.traverse_multiple("users", [1, 2])

# Should only have record 2
assert len(results) == 1
record = list(results)[0]
assert record.identifier.pk_values == ("2",)

def test_respects_max_depth(
self,
traverser: RelationshipTraverser,
mock_cursor: MagicMock,
) -> None:
"""Should respect max_depth parameter."""
mock_cursor.fetchall.return_value = [
(1, "User 1"),
(2, "User 2"),
]

results = traverser.traverse_multiple("users", [1, 2], max_depth=0)

# max_depth=0 means only the starting records
assert len(results) == 2


class TestFetchRecord(TestRelationshipTraverser):
"""Tests for _fetch_record method."""
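Reviewer note: the other half of the batching strategy is _batch_process_incoming_fks_for_records, which groups parent rows by the (source_table, source_column) pair that references them so that one IN-query covers every batched parent instead of one query per row. Below is a minimal sketch of that grouping; the table names, the FK map, and build_query are illustrative assumptions, not the actual pgslice helpers or SQL.

from collections import defaultdict


# Sketch of the incoming-FK grouping idea: one query per FK edge, regardless
# of how many parent rows were batched. All names and SQL are illustrative.
def build_query(source_table: str, source_column: str, pk_values: list[int]) -> str:
    placeholders = ", ".join(["%s"] * len(pk_values))
    return f"SELECT * FROM {source_table} WHERE {source_column} IN ({placeholders})"


def plan_incoming_fk_queries(
    parents: list[tuple[str, int]],                  # (table, pk) of fetched parents
    incoming_fks: dict[str, list[tuple[str, str]]],  # table -> [(source_table, source_column)]
) -> list[str]:
    grouped: dict[tuple[str, str], list[int]] = defaultdict(list)
    for table, pk in parents:
        for source_table, source_column in incoming_fks.get(table, []):
            grouped[(source_table, source_column)].append(pk)

    return [
        build_query(source_table, source_column, pks)
        for (source_table, source_column), pks in grouped.items()
    ]


if __name__ == "__main__":
    fks = {"users": [("orders", "user_id"), ("comments", "author_id")]}
    for sql in plan_incoming_fk_queries([("users", 1), ("users", 2)], fks):
        print(sql)
    # SELECT * FROM orders WHERE user_id IN (%s, %s)
    # SELECT * FROM comments WHERE author_id IN (%s, %s)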