diff --git a/.gitignore b/.gitignore index 013fe9d..4e5dee5 100644 --- a/.gitignore +++ b/.gitignore @@ -145,6 +145,7 @@ Thumbs.db .pgslice_history schema_cache.db output/ +dumps/ # AI CLAUDE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index c41c9d3..3e05084 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.1] - 2025-12-29 + +### Fixed +- Docker volume permission issues with dedicated entrypoint script + +### Changed +- Optimized graph traversal performance using batch queries for relationship lookups + ## [0.2.0] - 2025-12-28 ### Added diff --git a/DOCKER_USAGE.md b/DOCKER_USAGE.md index 984462a..47799da 100644 --- a/DOCKER_USAGE.md +++ b/DOCKER_USAGE.md @@ -39,6 +39,44 @@ docker run --rm -it \ edraobdu/pgslice:latest \ pgslice --host your.db.host --port 5432 --user your_user --database your_db ``` + +### Connecting to Localhost Database + +When your PostgreSQL database is running on your host machine (localhost), the container cannot access it using `localhost` or `127.0.0.1` because these refer to the container itself, not your host. + +**Solution 1: Use host networking (Linux, simplest)** +```bash +docker run --rm -it \ + --network host \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + -e PGPASSWORD=your_password \ + edraobdu/pgslice:latest \ + pgslice --host localhost --port 5432 --user your_user --database your_db +``` + +**Solution 2: Use host.docker.internal (Mac/Windows)** +```bash +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + -e PGPASSWORD=your_password \ + edraobdu/pgslice:latest \ + pgslice --host host.docker.internal --port 5432 --user your_user --database your_db +``` + +**Solution 3: Use Docker bridge IP (Linux alternative)** +```bash +# Find your host's Docker bridge IP (usually 172.17.0.1) +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + -e PGPASSWORD=your_password \ + edraobdu/pgslice:latest \ + pgslice --host 172.17.0.1 --port 5432 --user your_user --database your_db +``` + +**Note:** Make sure your PostgreSQL is configured to accept connections from Docker containers: +- Edit `postgresql.conf`: Set `listen_addresses = '*'` or `listen_addresses = '0.0.0.0'` +- Edit `pg_hba.conf`: Add entry like `host all all 172.17.0.0/16 md5` (for Docker bridge network) + ### Using Environment File Create a `.env` file: @@ -86,7 +124,63 @@ Mount a local directory to persist SQL dumps: -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps ``` -**Important:** The dumps directory is created inside the container with non-root user permissions (UID 1000). +#### Volume Permissions + +The container runs as non-root user `pgslice` (UID 1000) for security. When mounting local directories: + +**The entrypoint script automatically handles permissions** by: +1. Detecting mounted volumes +2. Fixing ownership to UID 1000 if needed +3. 
Providing helpful error messages if permissions can't be fixed + +**If you encounter permission errors:** + +**Option 1: Pre-fix permissions on host (recommended)** +```bash +# Create dumps directory and set ownership +mkdir -p dumps +sudo chown -R 1000:1000 dumps + +# Run container +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + edraobdu/pgslice:latest pgslice +``` + +**Option 2: Run as your user ID** +```bash +# Run container as your user (bypasses UID 1000) +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + --user $(id -u):$(id -g) \ + edraobdu/pgslice:latest pgslice +``` + +**Why UID 1000?** +- Common default UID for first user on Linux systems +- Matches most developer workstations +- If your user is different, use `--user $(id -u):$(id -g)` flag + +### Remote Server Workflow + +When running pgslice on a remote server, dumps are created as files with visible progress: + +```bash +# SSH into remote server and run dump +ssh user@remote-server "docker run --rm \ + -v /tmp/dumps:/home/pgslice/.pgslice/dumps \ + --env-file .env \ + edraobdu/pgslice:latest \ + pgslice --dump users --pks 42" + +# Copy the generated file to your local machine +scp user@remote-server:/tmp/dumps/public_users_42_*.sql ./local_dumps/ + +# Or use rsync for better performance with large files +rsync -avz user@remote-server:/tmp/dumps/public_users_42_*.sql ./local_dumps/ +``` + +Progress bars are visible during the dump, and the file is ready to transfer when complete. ## Links diff --git a/Dockerfile b/Dockerfile index 5c7c025..4c03ccf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ FROM python:3.13-alpine COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv # Install system dependencies -RUN apk add --no-cache postgresql-client +RUN apk add --no-cache postgresql-client su-exec # Install the project into `/app` WORKDIR /app @@ -31,6 +31,9 @@ COPY . /app RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --no-deps -e . +# Copy entrypoint script (must be done as root before USER directive) +COPY --chmod=755 docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh + # Set environment variables ENV PYTHONUNBUFFERED=1 @@ -39,14 +42,13 @@ RUN adduser -D -u 1000 pgslice && \ mkdir -p /home/pgslice/.cache/pgslice /home/pgslice/.pgslice/dumps && \ chown -R pgslice:pgslice /app /home/pgslice -# Switch to non-root user -USER pgslice - # Update cache directory to use pgslice's home ENV PGSLICE_CACHE_DIR=/home/pgslice/.cache/pgslice -# Reset the entrypoint, don't invoke `uv` -ENTRYPOINT [] +# Note: Container runs as root, entrypoint will drop to pgslice user after fixing permissions + +# Use custom entrypoint to fix permissions +ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] -# Default command +# Default command (passed to entrypoint) CMD ["pgslice", "--help"] diff --git a/README.md b/README.md index e386d72..bcbf5c1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ Extract only what you need while maintaining referential integrity. 
## Features -- ✅ **CLI-first design**: Stream SQL to stdout for easy piping and scripting +- ✅ **CLI-first design**: Dumps always saved to files with visible progress (matches REPL behavior) - ✅ **Bidirectional FK traversal**: Follows relationships in both directions (forward and reverse) - ✅ **Circular relationship handling**: Prevents infinite loops with visited tracking - ✅ **Multiple records**: Extract multiple records in one operation @@ -79,6 +79,65 @@ docker pull edraobdu/pgslice:0.1.1 docker pull --platform linux/amd64 edraobdu/pgslice:latest ``` +#### Connecting to Localhost Database + +When your PostgreSQL database runs on your host machine, use `--network host` (Linux) or `host.docker.internal` (Mac/Windows): + +```bash +# Linux: Use host networking +docker run --rm -it \ + --network host \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + -e PGPASSWORD=your_password \ + edraobdu/pgslice:latest \ + pgslice --host localhost --database your_db --dump users --pks 42 + +# Mac/Windows: Use special hostname +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + -e PGPASSWORD=your_password \ + edraobdu/pgslice:latest \ + pgslice --host host.docker.internal --database your_db --dump users --pks 42 +``` + +See [DOCKER_USAGE.md](DOCKER_USAGE.md#connecting-to-localhost-database) for more connection options. + +#### Docker Volume Permissions + +The pgslice container runs as user `pgslice` (UID 1000) for security. When mounting local directories as volumes, you may encounter permission issues. + +**The entrypoint script automatically fixes permissions** on mounted volumes. However, if you still encounter issues: + +```bash +# Fix permissions on host before mounting +sudo chown -R 1000:1000 ./dumps + +# Then run normally +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + edraobdu/pgslice:latest \ + pgslice --host your.db.host --database your_db --dump users --pks 42 +``` + +**Alternative:** Run container as your user: +```bash +docker run --rm -it \ + -v $(pwd)/dumps:/home/pgslice/.pgslice/dumps \ + --user $(id -u):$(id -g) \ + edraobdu/pgslice:latest \ + pgslice --host your.db.host --database your_db --dump users --pks 42 +``` + +**For remote servers:** +```bash +# Run dump on remote server +ssh user@remote-server "docker run --rm -v /tmp/dumps:/home/pgslice/.pgslice/dumps \ + edraobdu/pgslice:latest pgslice --dump users --pks 42" + +# Copy file locally +scp user@remote-server:/tmp/dumps/users_42_*.sql ./ +``` + ### From Source (Development) See [DEVELOPMENT.md](DEVELOPMENT.md) for detailed development setup instructions. 
@@ -87,40 +146,40 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for detailed development setup instructions ### CLI Mode -The CLI mode streams SQL to stdout by default, making it easy to pipe or redirect output: +Dumps are always saved to files with visible progress indicators (helpful for large datasets): ```bash -# Basic dump to stdout (pipe to file) -PGPASSWORD=xxx pgslice --host localhost --database mydb --table users --pks 42 > user_42.sql +# Basic dump (auto-generates filename like: public_users_42_TIMESTAMP.sql) +PGPASSWORD=xxx pgslice --host localhost --database mydb --dump users --pks 42 # Multiple records -PGPASSWORD=xxx pgslice --host localhost --database mydb --table users --pks 1,2,3 > users.sql +PGPASSWORD=xxx pgslice --host localhost --database mydb --dump users --pks 1,2,3 -# Output directly to file with --output flag -pgslice --host localhost --database mydb --table users --pks 42 --output user_42.sql +# Specify output file path +pgslice --host localhost --database mydb --dump users --pks 42 --output user_42.sql # Dump by timeframe (instead of PKs) - filters main table by date range -pgslice --host localhost --database mydb --table orders \ - --timeframe "created_at:2024-01-01:2024-12-31" > orders_2024.sql +pgslice --host localhost --database mydb --dump orders \ + --timeframe "created_at:2024-01-01:2024-12-31" --output orders_2024.sql # Wide mode: follow all relationships including self-referencing FKs # Be cautious - this can result in larger datasets -pgslice --host localhost --database mydb --table customer --pks 42 --wide > customer.sql +pgslice --host localhost --database mydb --dump customer --pks 42 --wide # Keep original primary keys (no remapping) -pgslice --host localhost --database mydb --table film --pks 1 --keep-pks > film.sql +pgslice --host localhost --database mydb --dump film --pks 1 --keep-pks # Generate self-contained SQL with DDL statements # Includes CREATE DATABASE/SCHEMA/TABLE statements -pgslice --host localhost --database mydb --table film --pks 1 --create-schema > film_complete.sql +pgslice --host localhost --database mydb --dump film --pks 1 --create-schema # Apply truncate filter to limit related tables by date range -pgslice --host localhost --database mydb --table customer --pks 42 \ - --truncate "rental:rental_date:2024-01-01:2024-12-31" > customer.sql +pgslice --host localhost --database mydb --dump customer --pks 42 \ + --truncate "rental:rental_date:2024-01-01:2024-12-31" # Enable debug logging (writes to stderr) -pgslice --host localhost --database mydb --table users --pks 42 \ - --log-level DEBUG 2>debug.log > output.sql +pgslice --host localhost --database mydb --dump users --pks 42 \ + --log-level DEBUG 2>debug.log ``` ### Schema Exploration @@ -140,12 +199,12 @@ Run pgslice on a remote server and capture output locally: ```bash # Execute on remote server, save output locally ssh remote.server.com "PGPASSWORD=xxx pgslice --host db.internal --database mydb \ - --table users --pks 1 --create-schema" > local_dump.sql + --dump users --pks 1 --create-schema" > local_dump.sql # With SSH tunnel for database access ssh -f -N -L 5433:db.internal:5432 bastion.example.com PGPASSWORD=xxx pgslice --host localhost --port 5433 --database mydb \ - --table users --pks 42 > user.sql + --dump users --pks 42 > user.sql ``` ### Interactive REPL @@ -163,43 +222,42 @@ pgslice> describe "film" Understanding the difference between CLI and REPL modes: -### CLI Mode (stdout by default) -The CLI streams SQL to **stdout** by default, perfect for piping and scripting: 
+### CLI Mode (files with progress) +The CLI writes to files and shows progress bars (helpful for large datasets): ```bash -# Streams to stdout - redirect with > -pgslice --table users --pks 42 > user_42.sql +# Writes to ~/.pgslice/dumps/public_users_42_TIMESTAMP.sql +pgslice --dump users --pks 42 -# Or use --output flag -pgslice --table users --pks 42 --output user_42.sql - -# Pipe to other commands -pgslice --table users --pks 42 | gzip > user_42.sql.gz +# Specify output file +pgslice --dump users --pks 42 --output user_42.sql ``` -### REPL Mode (files by default) -The REPL writes to **`~/.pgslice/dumps/`** by default when `--output` is not specified: +### REPL Mode (same behavior) +The REPL also writes to **`~/.pgslice/dumps/`** by default: ```bash -# In REPL: writes to ~/.pgslice/dumps/public_users_42.sql +# Writes to ~/.pgslice/dumps/public_users_42_TIMESTAMP.sql pgslice> dump "users" 42 # Specify custom output path pgslice> dump "users" 42 --output /path/to/user.sql ``` +Both modes now behave identically - always writing to files with visible progress. + ### Same Operations, Different Modes | Operation | CLI | REPL | |-----------|-----|------| | **List tables** | `pgslice --tables` | `pgslice> tables` | | **Describe table** | `pgslice --describe users` | `pgslice> describe "users"` | -| **Dump to stdout** | `pgslice --table users --pks 42` | N/A (REPL always writes to file) | -| **Dump to file** | `pgslice --table users --pks 42 --output user.sql` | `pgslice> dump "users" 42 --output user.sql` | -| **Dump (default)** | Stdout | `~/.pgslice/dumps/public_users_42.sql` | -| **Multiple PKs** | `pgslice --table users --pks 1,2,3` | `pgslice> dump "users" 1,2,3` | -| **Truncate filter** | `pgslice --table users --pks 42 --truncate "orders:2024-01-01:2024-12-31"` | `pgslice> dump "users" 42 --truncate "orders:2024-01-01:2024-12-31"` | -| **Wide mode** | `pgslice --table users --pks 42 --wide` | `pgslice> dump "users" 42 --wide` | +| **Dump (auto-named)** | `pgslice --dump users --pks 42` | `pgslice> dump "users" 42` | +| **Dump to file** | `pgslice --dump users --pks 42 --output user.sql` | `pgslice> dump "users" 42 --output user.sql` | +| **Dump (default path)** | `~/.pgslice/dumps/public_users_42_TIMESTAMP.sql` | `~/.pgslice/dumps/public_users_42_TIMESTAMP.sql` | +| **Multiple PKs** | `pgslice --dump users --pks 1,2,3` | `pgslice> dump "users" 1,2,3` | +| **Truncate filter** | `pgslice --dump users --pks 42 --truncate "orders:2024-01-01:2024-12-31"` | `pgslice> dump "users" 42 --truncate "orders:2024-01-01:2024-12-31"` | +| **Wide mode** | `pgslice --dump users --pks 42 --wide` | `pgslice> dump "users" 42 --wide` | ### When to Use Each Mode diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100755 index 0000000..4e39970 --- /dev/null +++ b/docker-entrypoint.sh @@ -0,0 +1,46 @@ +#!/bin/sh +set -e + +# Docker entrypoint for pgslice +# Fixes permissions on mounted volumes to allow UID 1000 (pgslice user) to write + +DUMPS_DIR="/home/pgslice/.pgslice/dumps" +CACHE_DIR="/home/pgslice/.cache/pgslice" + +# Function to check if directory is writable +is_writable() { + su-exec pgslice test -w "$1" 2>/dev/null +} + +# Function to fix permissions if needed +fix_permissions() { + local dir="$1" + + # Only fix if directory exists and is not writable by pgslice user + if [ -d "$dir" ] && ! is_writable "$dir"; then + echo "Fixing permissions on $dir..." 
+ # Change ownership to pgslice:pgslice (UID:GID 1000:1000) + chown -R pgslice:pgslice "$dir" 2>/dev/null || { + echo "Warning: Could not fix permissions on $dir. Volume may be read-only or owned by different user." + echo "To fix: Run 'sudo chown -R 1000:1000 ./dumps' on host before mounting." + } + fi +} + +# Fix permissions on dumps directory if mounted +if [ -d "$DUMPS_DIR" ]; then + fix_permissions "$DUMPS_DIR" +fi + +# Fix permissions on cache directory if mounted +if [ -d "$CACHE_DIR" ]; then + fix_permissions "$CACHE_DIR" +fi + +# If no command provided, show help +if [ $# -eq 0 ]; then + exec su-exec pgslice pgslice --help +fi + +# Execute command as pgslice user +exec su-exec pgslice "$@" diff --git a/src/pgslice/cli.py b/src/pgslice/cli.py index 2a48cbf..8a99841 100644 --- a/src/pgslice/cli.py +++ b/src/pgslice/cli.py @@ -4,6 +4,7 @@ import argparse import sys +import time from dataclasses import dataclass from datetime import datetime from importlib.metadata import version as get_version @@ -160,7 +161,7 @@ def run_cli_dump( return 1 pk_values = fetch_pks_by_timeframe( - conn_manager, args.table, args.schema, timeframe + conn_manager, args.dump, args.schema, timeframe ) if not pk_values: printy("[y]No records found matching the timeframe@") @@ -170,23 +171,26 @@ def run_cli_dump( sys.stderr.write("Error: --pks or --timeframe is required\n") return 1 - # Show progress only if stderr is a TTY (not piped) - show_progress = sys.stderr.isatty() + # Always show progress since we're writing to files (not stdout) + # Users want to see progress for large datasets + show_progress = True + + # Start timing + start_time = time.time() # Wide mode warning if args.wide and show_progress: - sys.stderr.write( - "\n⚠ Note: Wide mode follows ALL relationships including self-referencing FKs.\n" + printy( + "\n[gI]⚠ Note: Wide mode follows ALL relationships including self-referencing FKs.@" ) - sys.stderr.write(" This may take longer and fetch more data.\n\n") - sys.stderr.flush() + printy("[gI]This may take longer and fetch more data.@\n") # Create dump service service = DumpService(conn_manager, config, show_progress=show_progress) # Execute dump result = service.dump( - table=args.table, + table=args.dump, pk_values=pk_values, schema=args.schema, wide_mode=args.wide, @@ -196,12 +200,32 @@ def run_cli_dump( show_graph=args.graph, ) - # Output SQL + # Always write to file (never stdout) if args.output: - SQLWriter.write_to_file(result.sql_content, args.output) - printy(f"[g]Wrote {result.record_count} records to {args.output}@") + output_path = args.output else: - SQLWriter.write_to_stdout(result.sql_content) + # Generate default filename like REPL mode does + output_path = SQLWriter.get_default_output_path( + config.output_dir, + args.dump, # table name + pk_values[0] if pk_values else "multi", # first PK for filename + args.schema, + ) + + SQLWriter.write_to_file(result.sql_content, str(output_path)) + + # Calculate and format elapsed time + elapsed_time = time.time() - start_time + if elapsed_time >= 60: + time_str = f"{elapsed_time / 60:.1f}m" + elif elapsed_time >= 1: + time_str = f"{elapsed_time:.1f}s" + else: + time_str = f"{elapsed_time * 1000:.0f}ms" + + printy( + f"[g]✓ Wrote {result.record_count} records to {output_path} (took {time_str})@" + ) return 0 @@ -260,14 +284,14 @@ def main() -> int: formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Dump to stdout - PGPASSWORD=xxx %(prog)s --host localhost --database mydb --table users --pks 42 + # Dump to 
auto-generated file (shows progress) + PGPASSWORD=xxx %(prog)s --host localhost --database mydb --dump users --pks 42 # Dump by timeframe (instead of PKs) - %(prog)s --host localhost --database mydb --table orders --timeframe "created_at:2024-01-01:2024-12-31" + %(prog)s --host localhost --database mydb --dump orders --timeframe "created_at:2024-01-01:2024-12-31" - # Dump to file with truncate filter for related tables - %(prog)s --table users --pks 1 --truncate "orders:created_at:2024-01-01:2024-12-31" --output user.sql + # Dump to specific file with truncate filter for related tables + %(prog)s --dump users --pks 1 --truncate "orders:created_at:2024-01-01:2024-12-31" --output user.sql # List all tables %(prog)s --host localhost --database mydb --tables @@ -340,8 +364,9 @@ def main() -> int: # Dump operation arguments (non-interactive CLI mode) dump_group = parser.add_argument_group("Dump Operation (CLI mode)") dump_group.add_argument( - "--table", - help="Table name to dump (enables non-interactive CLI mode)", + "--dump", + "-d", + help="Table name to dump (same as 'dump' command in REPL mode)", ) # --pks and --timeframe are mutually exclusive ways to select records @@ -432,9 +457,9 @@ def main() -> int: config.log_level = args.log_level # Validate CLI dump mode arguments - if args.table and not args.pks and not args.timeframe: + if args.dump and not args.pks and not args.timeframe: sys.stderr.write( - "Error: --pks or --timeframe is required when using --table\n" + "Error: --pks or --timeframe is required when using --dump\n" ) return 1 @@ -485,7 +510,7 @@ def main() -> int: if args.describe: return run_describe_table(conn_manager, args.schema, args.describe) - if args.table: + if args.dump: # Non-interactive CLI dump mode return run_cli_dump(args, config, conn_manager) else: diff --git a/src/pgslice/dumper/dump_service.py b/src/pgslice/dumper/dump_service.py index 9111590..33ad2da 100644 --- a/src/pgslice/dumper/dump_service.py +++ b/src/pgslice/dumper/dump_service.py @@ -14,7 +14,7 @@ from ..graph.traverser import RelationshipTraverser from ..graph.visited_tracker import VisitedTracker from ..utils.logging_config import get_logger -from ..utils.spinner import SpinnerAnimator +from ..utils.spinner import SpinnerAnimator, animated_spinner from .dependency_sorter import DependencySorter from .sql_generator import SQLGenerator @@ -88,63 +88,55 @@ def dump( file=sys.stderr, bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}", ) as pbar: - # Step 1: Setup and traverse relationships # Create spinner animator (updates every 100ms for smooth animation) spinner = SpinnerAnimator(update_interval=0.1) - pbar.set_description(f"Traversing relationships {spinner.get_frame()}") - - # Define progress callback to update progress bar with animated spinner - def update_progress(count: int) -> None: - pbar.set_description( - f"Traversing relationships {spinner.get_frame()} {count} records found" + # Step 1: Setup and traverse relationships (using animated spinner) + with animated_spinner( + spinner, pbar.set_description, "Traversing relationships" + ): + conn = self.conn_manager.get_connection() + introspector = SchemaIntrospector(conn) + visited = VisitedTracker() + traverser = RelationshipTraverser( + conn, + introspector, + visited, + timeframe_filters, + wide_mode=wide_mode, + fetch_batch_size=self.config.sql_batch_size, ) - conn = self.conn_manager.get_connection() - introspector = SchemaIntrospector(conn) - visited = VisitedTracker() - traverser = RelationshipTraverser( - conn, - 
introspector, - visited, - timeframe_filters, - wide_mode=wide_mode, - progress_callback=update_progress, - ) - - if len(pk_values) == 1: - records = traverser.traverse( - table, pk_values[0], schema, self.config.max_depth - ) - else: - records = traverser.traverse_multiple( - table, pk_values, schema, self.config.max_depth - ) - pbar.set_description( - f"Traversing relationships ✓ {len(records)} records found" - ) + if len(pk_values) == 1: + records = traverser.traverse( + table, pk_values[0], schema, self.config.max_depth + ) + else: + records = traverser.traverse_multiple( + table, pk_values, schema, self.config.max_depth + ) pbar.update(1) - # Step 2: Sort by dependencies - pbar.set_description("Sorting dependencies ⠋") - sorter = DependencySorter() - sorted_records = sorter.sort(records) - pbar.set_description("Sorting dependencies ✓") + # Step 2: Sort by dependencies (using animated spinner) + with animated_spinner( + spinner, pbar.set_description, "Sorting dependencies" + ): + sorter = DependencySorter() + sorted_records = sorter.sort(records) pbar.update(1) - # Step 3: Generate SQL - pbar.set_description("Generating SQL ⠋") - generator = SQLGenerator( - introspector, batch_size=self.config.sql_batch_size - ) - sql = generator.generate_batch( - sorted_records, - keep_pks=keep_pks, - create_schema=create_schema, - database_name=self.config.db.database, - schema_name=schema, - ) - pbar.set_description("Generating SQL ✓") + # Step 3: Generate SQL (using animated spinner) + with animated_spinner(spinner, pbar.set_description, "Generating SQL"): + generator = SQLGenerator( + introspector, batch_size=self.config.sql_batch_size + ) + sql = generator.generate_batch( + sorted_records, + keep_pks=keep_pks, + create_schema=create_schema, + database_name=self.config.db.database, + schema_name=schema, + ) pbar.update(1) # Step 4: Complete diff --git a/src/pgslice/dumper/sql_generator.py b/src/pgslice/dumper/sql_generator.py index 8227181..9a27b07 100644 --- a/src/pgslice/dumper/sql_generator.py +++ b/src/pgslice/dumper/sql_generator.py @@ -4,6 +4,7 @@ import json from datetime import date, datetime, time +from decimal import Decimal from typing import Any from uuid import UUID @@ -413,6 +414,15 @@ def _format_value( elif isinstance(value, int): return str(value) + elif isinstance(value, Decimal): + # Handle Decimal (before float, since we want numeric output) + if value.is_nan(): + return "'NaN'" + if value.is_infinite(): + return "'Infinity'" if value > 0 else "'-Infinity'" + # Return as numeric literal (no quotes) + return str(value) + elif isinstance(value, float): # Handle special float values if value != value: # NaN diff --git a/src/pgslice/graph/traverser.py b/src/pgslice/graph/traverser.py index 757fe8b..09dd2d2 100644 --- a/src/pgslice/graph/traverser.py +++ b/src/pgslice/graph/traverser.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections import deque -from collections.abc import Callable from typing import Any import psycopg @@ -32,7 +31,7 @@ def __init__( visited_tracker: VisitedTracker, timeframe_filters: list[TimeframeFilter] | None = None, wide_mode: bool = False, - progress_callback: Callable[[int], None] | None = None, + fetch_batch_size: int = 500, ) -> None: """ Initialize relationship traverser. @@ -45,7 +44,8 @@ def __init__( wide_mode: If True, follow incoming FKs from all records (wide/exploratory). If False (default), only follow incoming FKs from starting records and records reached via incoming FKs (strict mode, prevents fan-out). 
- progress_callback: Optional callback invoked with record count after each fetch + fetch_batch_size: Number of records to fetch in a single batch query (default: 500). + Higher values reduce database round-trips but increase memory usage. """ self.conn = connection self.introspector = schema_introspector @@ -53,7 +53,7 @@ def __init__( self.table_cache: dict[str, Table] = {} self.timeframe_filters = {f.table_name: f for f in (timeframe_filters or [])} self.wide_mode = wide_mode - self.progress_callback = progress_callback + self.fetch_batch_size = fetch_batch_size def traverse( self, @@ -63,18 +63,17 @@ def traverse( max_depth: int | None = None, ) -> set[RecordData]: """ - Traverse relationships from a starting record. + Traverse relationships from a starting record using batch fetching. Algorithm: 1. Start with initial record (table + PK) - 2. Use BFS with queue of (RecordIdentifier, depth) - 3. For each record: - - Skip if already visited - - Mark as visited - - Fetch record data - - Follow outgoing FKs (forward relationships) - - Follow incoming FKs (reverse relationships) - 4. Continue until queue empty + 2. Use BFS with queue of (RecordIdentifier, depth, follow_incoming_fks) + 3. Collect records at same depth into batches (up to fetch_batch_size) + 4. For each batch: + - Batch fetch all records + - Process outgoing FKs for all records in batch + - Batch process incoming FKs + 5. Continue until queue empty Args: table_name: Starting table name @@ -89,101 +88,149 @@ def traverse( RecordNotFoundError: If starting record doesn't exist """ start_id = self._create_record_identifier(schema, table_name, (pk_value,)) - # Queue now tracks: (record_id, depth, follow_incoming_fks) - # follow_incoming_fks=True for starting records and records reached via incoming FKs - # follow_incoming_fks=False for records reached via outgoing FKs (dependencies) queue: deque[tuple[RecordIdentifier, int, bool]] = deque([(start_id, 0, True)]) results: set[RecordData] = set() logger.info(f"Starting traversal from {start_id}") while queue: - record_id, depth, follow_incoming_fks = queue.popleft() + # Collect batch: all records at current depth (up to batch_size) + current_depth = queue[0][1] if queue else 0 + batch: list[tuple[RecordIdentifier, bool]] = [] - # Check depth limit - if max_depth is not None and depth > max_depth: - logger.debug(f"Skipping {record_id}: depth {depth} > max {max_depth}") - continue + while queue and len(batch) < self.fetch_batch_size: + record_id, depth, follow_incoming_fks = queue.popleft() - # Skip if already visited - if self.visited.is_visited(record_id): - logger.debug(f"Skipping {record_id}: already visited") - continue + # If depth changed, put it back and process current batch + if depth != current_depth: + queue.appendleft((record_id, depth, follow_incoming_fks)) + break - # Mark as visited BEFORE fetching to prevent re-queueing - self.visited.mark_visited(record_id) + # Check depth limit + if max_depth is not None and depth > max_depth: + logger.debug( + f"Skipping {record_id}: depth {depth} > max {max_depth}" + ) + continue - # Fetch record data - try: - record_data = self._fetch_record(record_id) - except RecordNotFoundError: - logger.warning(f"Record not found: {record_id}") + # Skip if already visited + if self.visited.is_visited(record_id): + logger.debug(f"Skipping {record_id}: already visited") + continue + + # Mark as visited BEFORE fetching + self.visited.mark_visited(record_id) + batch.append((record_id, follow_incoming_fks)) + + # Process batch + if not batch: 
continue - results.add(record_data) - logger.debug( - f"Fetched {record_id} at depth {depth} ({len(results)} total records)" - ) + # Batch fetch all records + record_ids = [rid for rid, _ in batch] + try: + fetched_records = self._fetch_records_batch(record_ids) + except Exception as e: + logger.error(f"Error batch fetching records: {e}") + # Fall back to individual fetches + fetched_records = {} + for record_id in record_ids: + try: + fetched_records[record_id] = self._fetch_record(record_id) + except RecordNotFoundError: + logger.warning(f"Record not found: {record_id}") + + # Process each fetched record + for record_id, _ in batch: + if record_id not in fetched_records: + continue + + record_data = fetched_records[record_id] + results.add(record_data) + logger.debug( + f"Fetched {record_id} at depth {current_depth} ({len(results)} total)" + ) + + # Get table metadata + table = self._get_table_metadata( + record_id.schema_name, record_id.table_name + ) + + # Traverse outgoing FKs (forward relationships) + for fk in table.foreign_keys_outgoing: + target_id = self._resolve_foreign_key_target(record_data, fk) + if target_id: + # ALWAYS add dependency + record_data.dependencies.add(target_id) + logger.debug( + f" -> Dependency: {record_data.identifier} depends on {target_id}" + ) - # Invoke progress callback with current record count - if self.progress_callback: - self.progress_callback(len(results)) + # Only traverse if not visited + if not self.visited.is_visited(target_id): + follow_incoming = self.wide_mode + queue.append( + (target_id, current_depth + 1, follow_incoming) + ) + logger.debug(f" -> Following outgoing FK to {target_id}") - # Get table metadata - table = self._get_table_metadata( - record_id.schema_name, record_id.table_name - ) + # Process incoming FKs for batch (using batch lookup) + # Group by FK to minimize queries + incoming_fk_lookups: dict[Any, list[tuple[RecordIdentifier, bool]]] = {} + for record_id, follow_incoming_fks in batch: + if record_id not in fetched_records or not follow_incoming_fks: + continue - # Traverse outgoing FKs (forward relationships) - for fk in table.foreign_keys_outgoing: - target_id = self._resolve_foreign_key_target(record_data, fk) - if target_id: - # ALWAYS add dependency (even if target already visited) - # This ensures correct SQL ordering when inserting records - record_data.dependencies.add(target_id) - logger.debug( - f" -> Dependency: {record_data.identifier} depends on {target_id}" - ) + table = self._get_table_metadata( + record_id.schema_name, record_id.table_name + ) - # Only traverse if not visited - # In strict mode: dependencies should NOT follow incoming FKs (prevents fan-out) - # In wide mode: all records can follow incoming FKs - if not self.visited.is_visited(target_id): - follow_incoming = self.wide_mode - queue.append((target_id, depth + 1, follow_incoming)) - logger.debug(f" -> Following outgoing FK to {target_id}") - - # Traverse incoming FKs (reverse relationships) - # Only follow incoming FKs if this record allows it - if follow_incoming_fks: for fk in table.foreign_keys_incoming: - logger.debug( - f" <- Processing incoming FK: {fk.source_table}, wide_mode={self.wide_mode}" - ) - # In strict mode, skip self-referencing FKs to prevent sibling expansion - # Self-referencing FKs like users.manager_id -> users.id would find peers/siblings + # Skip self-referencing FKs in strict mode if not self.wide_mode: source_schema, source_table = self._parse_table_name( fk.source_table ) - logger.debug( - f" <- Checking FK: 
{fk.source_table} (parsed: {source_schema}.{source_table}) vs current: {record_id.schema_name}.{record_id.table_name}" - ) if ( source_schema == record_id.schema_name and source_table == record_id.table_name ): logger.debug( - f" <- Skipping self-referencing FK from {source_schema}.{source_table} (strict mode)" + " <- Skipping self-referencing FK (strict mode)" ) continue - source_records = self._find_referencing_records(record_id, fk) - for source_id in source_records: - if not self.visited.is_visited(source_id): - # Records reached via incoming FKs CAN follow incoming FKs - queue.append((source_id, depth + 1, True)) - logger.debug(f" <- Following incoming FK from {source_id}") + # Group records by FK for batch lookup + fk_key = (fk.source_table, fk.source_column) + if fk_key not in incoming_fk_lookups: + incoming_fk_lookups[fk_key] = [] + incoming_fk_lookups[fk_key].append((record_id, follow_incoming_fks)) + + # Batch process incoming FKs + for (source_table, source_column), targets in incoming_fk_lookups.items(): + # Reconstruct FK object for batch lookup + fk_obj = type( + "FK", + (), + {"source_table": source_table, "source_column": source_column}, + )() + target_ids = [tid for tid, _ in targets] + + try: + referencing_map = self._find_referencing_records_batch( + target_ids, fk_obj + ) + + for target_id in target_ids: + source_records = referencing_map.get(target_id, []) + for source_id in source_records: + if not self.visited.is_visited(source_id): + queue.append((source_id, current_depth + 1, True)) + logger.debug( + f" <- Following incoming FK from {source_id}" + ) + except Exception as e: + logger.error(f"Error in batch FK lookup: {e}") logger.info(f"Traversal complete: {len(results)} records found") return results @@ -218,10 +265,6 @@ def traverse_multiple( records = self.traverse(table_name, pk_value, schema, max_depth) all_records.update(records) - # Final progress callback with total unique records - if self.progress_callback: - self.progress_callback(len(all_records)) - logger.info( f"Multi-traversal complete: {len(all_records)} unique records found" ) @@ -281,6 +324,74 @@ def _fetch_record(self, record_id: RecordIdentifier) -> RecordData: return RecordData(identifier=record_id, data=data) + def _fetch_records_batch( + self, record_ids: list[RecordIdentifier] + ) -> dict[RecordIdentifier, RecordData]: + """ + Fetch multiple records in a single query using IN clause. + Groups records by table for efficient batching. 
+ + Args: + record_ids: List of record identifiers to fetch + + Returns: + Dictionary mapping RecordIdentifier to RecordData + """ + if not record_ids: + return {} + + results: dict[RecordIdentifier, RecordData] = {} + + # Group by (schema, table) + by_table: dict[tuple[str, str], list[RecordIdentifier]] = {} + for record_id in record_ids: + key = (record_id.schema_name, record_id.table_name) + by_table.setdefault(key, []).append(record_id) + + # Fetch each table's records in batch + for (schema, table), table_record_ids in by_table.items(): + table_metadata = self._get_table_metadata(schema, table) + + if not table_metadata.primary_keys: + logger.warning(f"Table {schema}.{table} has no primary key, skipping") + continue + + # Build WHERE clause: WHERE id IN (1, 2, 3) + pk_col = table_metadata.primary_keys[0] + pk_values = [rid.pk_values[0] for rid in table_record_ids] + placeholders = ", ".join(["%s"] * len(pk_values)) + + # Apply timeframe filter if applicable + timeframe_clause = "" + params: list[Any] = pk_values.copy() + + if table in self.timeframe_filters: + filter_config = self.timeframe_filters[table] + timeframe_clause = ( + f' AND "{filter_config.column_name}" BETWEEN %s AND %s' + ) + params.extend([filter_config.start_date, filter_config.end_date]) + + query = f""" + SELECT * FROM "{schema}"."{table}" + WHERE "{pk_col}" IN ({placeholders}){timeframe_clause} + """ + + with self.conn.cursor() as cur: + cur.execute(query, params) + rows = cur.fetchall() + columns = [desc[0] for desc in (cur.description or [])] + + for row in rows: + data = dict(zip(columns, row, strict=False)) + pk_value = data[pk_col] + record_id = self._create_record_identifier( + schema, table, (pk_value,) + ) + results[record_id] = RecordData(identifier=record_id, data=data) + + return results + def _resolve_foreign_key_target( self, record: RecordData, fk: Any ) -> RecordIdentifier | None: @@ -378,6 +489,81 @@ def _find_referencing_records( return results + def _find_referencing_records_batch( + self, target_ids: list[RecordIdentifier], fk: Any + ) -> dict[RecordIdentifier, list[RecordIdentifier]]: + """ + Find all records referencing multiple targets via single FK using IN clause. 
+ + Args: + target_ids: List of target record identifiers being referenced + fk: ForeignKey object + + Returns: + Dictionary mapping each target_id to list of RecordIdentifiers referencing it + """ + if not target_ids: + return {} + + # Parse source table + schema, table = self._parse_table_name(fk.source_table) + + # Get primary keys for source table + source_table = self._get_table_metadata(schema, table) + if not source_table.primary_keys: + logger.warning(f"Table {schema}.{table} has no primary key, skipping") + return {} + + # Get target PK values + target_pk_values = [tid.pk_values[0] for tid in target_ids] + + # Build query with IN clause + pk_columns = ", ".join(f'"{pk}"' for pk in source_table.primary_keys) + placeholders = ", ".join(["%s"] * len(target_pk_values)) + + # Apply timeframe filter if applicable + timeframe_clause = "" + params: list[Any] = target_pk_values.copy() + + if table in self.timeframe_filters: + filter_config = self.timeframe_filters[table] + timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s' + params.extend([filter_config.start_date, filter_config.end_date]) + + # Include FK column to map back to targets + query = f""" + SELECT {pk_columns}, "{fk.source_column}" + FROM "{schema}"."{table}" + WHERE "{fk.source_column}" IN ({placeholders}){timeframe_clause} + """ + + # Initialize results dict with empty lists for all targets + results: dict[RecordIdentifier, list[RecordIdentifier]] = { + tid: [] for tid in target_ids + } + + with self.conn.cursor() as cur: + cur.execute(query, params) + rows = cur.fetchall() + + for row in rows: + fk_value = row[-1] # Last column is FK value + pk_values = row[:-1] # Rest are PK values + + source_id = self._create_record_identifier( + schema, + table, + (pk_values[0],) if len(pk_values) == 1 else tuple(pk_values), + ) + + # Map to correct target + for target_id in target_ids: + if str(target_id.pk_values[0]) == str(fk_value): + results[target_id].append(source_id) + break + + return results + def _get_table_metadata(self, schema: str, table: str) -> Table: """ Get table metadata with caching. 
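Editorial aside (not part of the diff): for reviewers skimming the traversal changes above, here is a minimal, self-contained sketch of the batching pattern that `_fetch_records_batch` follows — group the pending primary keys by table and issue one `IN`-clause query per table instead of one query per record. The function name, simplified signature, and table/column names are illustrative only; the real implementation in this change also handles timeframe filters and maps results back onto `RecordIdentifier` objects.

```python
from collections import defaultdict
from typing import Any

import psycopg


def fetch_rows_batched(
    conn: psycopg.Connection,
    wanted: list[tuple[str, str, str, Any]],  # (schema, table, pk_column, pk_value)
) -> dict[tuple[str, str, Any], dict[str, Any]]:
    """Fetch many rows with one IN-clause query per table instead of one query per row."""
    # Group the requested primary keys by (schema, table, pk_column).
    by_table: dict[tuple[str, str, str], list[Any]] = defaultdict(list)
    for schema, table, pk_col, pk_value in wanted:
        by_table[(schema, table, pk_col)].append(pk_value)

    results: dict[tuple[str, str, Any], dict[str, Any]] = {}
    for (schema, table, pk_col), pk_values in by_table.items():
        # One round-trip per table, regardless of how many records were queued.
        placeholders = ", ".join(["%s"] * len(pk_values))
        query = f'SELECT * FROM "{schema}"."{table}" WHERE "{pk_col}" IN ({placeholders})'
        with conn.cursor() as cur:
            cur.execute(query, pk_values)
            columns = [desc[0] for desc in (cur.description or [])]
            for row in cur.fetchall():
                data = dict(zip(columns, row, strict=False))
                results[(schema, table, data[pk_col])] = data
    return results
```

The trade-off is the one documented on `fetch_batch_size`: larger batches mean fewer database round-trips at the cost of more memory per query.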
diff --git a/src/pgslice/repl.py b/src/pgslice/repl.py index cc04015..935b6d9 100644 --- a/src/pgslice/repl.py +++ b/src/pgslice/repl.py @@ -3,9 +3,10 @@ from __future__ import annotations import shlex +import time from pathlib import Path -from printy import printy, raw_format +from printy import printy, raw from prompt_toolkit import PromptSession from prompt_toolkit.completion import WordCompleter from prompt_toolkit.history import FileHistory @@ -190,9 +191,9 @@ def _cmd_dump(self, args: list[str]) -> None: # Wide mode warning if wide_mode: printy( - " [y]⚠ Note: Wide mode follows ALL relationships including self-referencing FKs.@" + "\n[gI]⚠ Note: Wide mode follows ALL relationships including self-referencing FKs.@" ) - printy(" [y] This may take longer and fetch more data.@\n") + printy("[gI]This may take longer and fetch more data.@\n") if timeframe_filters: printy(" [y]Truncate filters:@") @@ -201,6 +202,9 @@ def _cmd_dump(self, args: list[str]) -> None: printy("") # Empty line after filters try: + # Start timing + start_time = time.time() + # Use DumpService for the actual dump # REPL always writes to files, so progress bar is safe to show service = DumpService(self.conn_manager, self.config, show_progress=True) @@ -215,13 +219,22 @@ def _cmd_dump(self, args: list[str]) -> None: show_graph=show_graph, ) + # Calculate and format elapsed time + elapsed_time = time.time() - start_time + if elapsed_time >= 60: + time_str = f"{elapsed_time / 60:.1f}m" + elif elapsed_time >= 1: + time_str = f"{elapsed_time:.1f}s" + else: + time_str = f"{elapsed_time * 1000:.0f}ms" + printy(f"\n [g]✓ Found {result.record_count} related records@") # Output if output_file: SQLWriter.write_to_file(result.sql_content, output_file) printy( - f" [g]✓ Wrote {result.record_count} INSERT statements to {output_file}@\n" + f" [g]✓ Wrote {result.record_count} INSERT statements to {output_file} (took {time_str})@\n" ) else: # Use default output path @@ -233,7 +246,7 @@ def _cmd_dump(self, args: list[str]) -> None: ) SQLWriter.write_to_file(result.sql_content, str(default_path)) printy( - f" [g]✓ Wrote {result.record_count} INSERT statements to {default_path}@\n" + f" [g]✓ Wrote {result.record_count} INSERT statements to {default_path} (took {time_str})@\n" ) except DBReverseDumpError as e: @@ -260,8 +273,8 @@ def _cmd_help(self, args: list[str]) -> None: tabulate( help_data, headers=[ - raw_format("Command", flags="B"), - raw_format("Description", flags="B"), + raw("Command", flags="B"), + raw("Description", flags="B"), ], tablefmt="simple", ) diff --git a/src/pgslice/utils/spinner.py b/src/pgslice/utils/spinner.py index 0650e49..fd88978 100644 --- a/src/pgslice/utils/spinner.py +++ b/src/pgslice/utils/spinner.py @@ -2,7 +2,10 @@ from __future__ import annotations +import threading import time +from collections.abc import Callable, Iterator +from contextlib import contextmanager class SpinnerAnimator: @@ -46,3 +49,40 @@ def reset(self) -> None: """Reset spinner to initial state.""" self._current_idx = 0 self._last_update = time.time() + + +@contextmanager +def animated_spinner( + spinner: SpinnerAnimator, + update_fn: Callable[[str], None], + base_text: str, + interval: float = 0.1, +) -> Iterator[None]: + """ + Context manager that animates spinner in background thread. 
+ + Args: + spinner: SpinnerAnimator instance + update_fn: Function to call with updated description (e.g., pbar.set_description) + base_text: Base description text (spinner appended) + interval: Update interval in seconds + + Usage: + with animated_spinner(spinner, pbar.set_description, "Sorting dependencies"): + sorter.sort(records) # Spinner animates during this + """ + stop_event = threading.Event() + + def animate() -> None: + while not stop_event.is_set(): + update_fn(f"{base_text} {spinner.get_frame()}") + stop_event.wait(interval) + + thread = threading.Thread(target=animate, daemon=True) + thread.start() + try: + yield + finally: + stop_event.set() + thread.join(timeout=0.5) + update_fn(f"{base_text} ✓") diff --git a/tests/unit/graph/test_traverser.py b/tests/unit/graph/test_traverser.py index 37ff3b0..fe36556 100644 --- a/tests/unit/graph/test_traverser.py +++ b/tests/unit/graph/test_traverser.py @@ -30,7 +30,12 @@ def mock_cursor(self) -> MagicMock: cursor = MagicMock() cursor.execute = MagicMock() cursor.fetchone = MagicMock() - cursor.fetchall = MagicMock(return_value=[]) + # Set up fetchall to return the same data as fetchone for batch compatibility + cursor.fetchall = MagicMock( + side_effect=lambda: [cursor.fetchone.return_value] + if cursor.fetchone.return_value + else [] + ) cursor.description = [("id",), ("name",)] return cursor @@ -537,7 +542,7 @@ def test_skips_incoming_fks_when_disabled( mock_introspector.get_table_metadata.return_value = users_table mock_cursor.fetchone.return_value = (1, "Alice") - results = traverser.traverse("users", "public", (1,)) + results = traverser.traverse("users", 1, "public") # Should only find the user, not referencing orders assert len(results) == 1 @@ -631,7 +636,7 @@ def test_strict_mode_skips_self_referencing_fk( mock_introspector.get_table_metadata.return_value = users_table mock_cursor.fetchone.return_value = (1, "Employee", 2) - results = traverser.traverse("users", "public", (1,)) + results = traverser.traverse("users", 1, "public") # In strict mode, should skip self-referencing FK assert len(results) == 1 diff --git a/tests/unit/graph/test_traverser_progress.py b/tests/unit/graph/test_traverser_progress.py deleted file mode 100644 index 5660ffe..0000000 --- a/tests/unit/graph/test_traverser_progress.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Tests for traverser progress callback functionality.""" - -from __future__ import annotations - -from unittest.mock import MagicMock - -from pgslice.graph.models import RecordData -from pgslice.graph.traverser import RelationshipTraverser - - -class TestProgressCallback: - """Tests for progress callback parameter and invocation.""" - - def test_traverser_accepts_progress_callback(self) -> None: - """Should accept progress_callback parameter without errors.""" - mock_conn = MagicMock() - mock_introspector = MagicMock() - mock_visited = MagicMock() - mock_callback = MagicMock() - - traverser = RelationshipTraverser( - mock_conn, - mock_introspector, - mock_visited, - progress_callback=mock_callback, - ) - - assert traverser.progress_callback == mock_callback - - def test_progress_callback_none_does_not_error(self) -> None: - """Should work correctly when progress_callback is None.""" - mock_conn = MagicMock() - mock_introspector = MagicMock() - mock_visited = MagicMock() - - # Should not raise any errors - traverser = RelationshipTraverser( - mock_conn, - mock_introspector, - mock_visited, - progress_callback=None, - ) - - assert traverser.progress_callback is None - - def 
test_progress_callback_invoked_per_record(self, mocker: MagicMock) -> None: - """Should invoke callback after each record is fetched.""" - mock_conn = MagicMock() - mock_introspector = MagicMock() - mock_visited = MagicMock() - mock_callback = MagicMock() - - # Mock table metadata - mock_table = MagicMock() - mock_table.primary_keys = ["id"] - mock_table.foreign_keys_outgoing = [] - mock_table.foreign_keys_incoming = [] - mock_introspector.get_table_metadata.return_value = mock_table - - # Mock visited tracker - mock_visited.is_visited.return_value = False - mock_visited.mark_visited.return_value = None - - # Mock fetch_record to return a valid record - def mock_fetch(record_id): - return RecordData( - identifier=record_id, - data={"id": record_id.pk_values[0]}, - ) - - traverser = RelationshipTraverser( - mock_conn, - mock_introspector, - mock_visited, - progress_callback=mock_callback, - ) - - # Patch _fetch_record method - mocker.patch.object(traverser, "_fetch_record", side_effect=mock_fetch) - - # Traverse a single record - traverser.traverse("users", "1", "public", max_depth=0) - - # Callback should be invoked at least once - assert mock_callback.called - mock_callback.assert_called_with(1) - - def test_progress_callback_receives_correct_count(self, mocker: MagicMock) -> None: - """Should pass accurate record count to callback.""" - mock_conn = MagicMock() - mock_introspector = MagicMock() - mock_visited = MagicMock() - mock_callback = MagicMock() - - # Mock table metadata with no relationships - mock_table = MagicMock() - mock_table.primary_keys = ["id"] - mock_table.foreign_keys_outgoing = [] - mock_table.foreign_keys_incoming = [] - mock_introspector.get_table_metadata.return_value = mock_table - - # Track visited records - visited_records = set() - - def is_visited(record_id): - return record_id in visited_records - - def mark_visited(record_id): - visited_records.add(record_id) - - mock_visited.is_visited.side_effect = is_visited - mock_visited.mark_visited.side_effect = mark_visited - - # Mock fetch_record - def mock_fetch(record_id): - return RecordData( - identifier=record_id, - data={"id": record_id.pk_values[0]}, - ) - - traverser = RelationshipTraverser( - mock_conn, - mock_introspector, - mock_visited, - progress_callback=mock_callback, - ) - - mocker.patch.object(traverser, "_fetch_record", side_effect=mock_fetch) - - # Traverse - traverser.traverse("users", "1", "public", max_depth=0) - - # Should be called with count 1 (only the starting record) - mock_callback.assert_called_with(1) - - def test_traverse_multiple_updates_progress(self, mocker: MagicMock) -> None: - """Should invoke callback when traversing multiple starting records.""" - mock_conn = MagicMock() - mock_introspector = MagicMock() - mock_visited = MagicMock() - mock_callback = MagicMock() - - # Mock table metadata - mock_table = MagicMock() - mock_table.primary_keys = ["id"] - mock_table.foreign_keys_outgoing = [] - mock_table.foreign_keys_incoming = [] - mock_introspector.get_table_metadata.return_value = mock_table - - # Track visited records to avoid duplicates - visited_records = set() - - def is_visited(record_id): - return record_id in visited_records - - def mark_visited(record_id): - visited_records.add(record_id) - - mock_visited.is_visited.side_effect = is_visited - mock_visited.mark_visited.side_effect = mark_visited - - # Mock fetch_record - def mock_fetch(record_id): - return RecordData( - identifier=record_id, - data={"id": record_id.pk_values[0]}, - ) - - traverser = RelationshipTraverser( 
- mock_conn, - mock_introspector, - mock_visited, - progress_callback=mock_callback, - ) - - mocker.patch.object(traverser, "_fetch_record", side_effect=mock_fetch) - - # Traverse multiple records - traverser.traverse_multiple("users", ["1", "2", "3"], "public", max_depth=0) - - # Should be called multiple times (once per record + final callback) - assert mock_callback.call_count >= 3 - # Final call should have count of 3 unique records - final_call_arg = mock_callback.call_args[0][0] - assert final_call_arg == 3 diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 155a484..41142e3 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -320,7 +320,7 @@ def test_table_without_pks_or_timeframe_fails( "test", "--database", "test", - "--table", + "--dump", "users", ], ), @@ -360,7 +360,7 @@ def test_cli_dump_mode_executes(self) -> None: "test", "--database", "test", - "--table", + "--dump", "users", "--pks", "1", @@ -381,6 +381,7 @@ def test_cli_dump_mode_executes(self) -> None: mock_config.cache.enabled = False mock_config.connection_ttl_minutes = 30 mock_config.create_schema = False + mock_config.output_dir = Path("/tmp") mock_load.return_value = mock_config mock_cm_instance = MagicMock() @@ -390,6 +391,9 @@ def test_cli_dump_mode_executes(self) -> None: mock_service_instance.dump.return_value = mock_result mock_dump_service.return_value = mock_service_instance + # Mock get_default_output_path for auto-generated filename + mock_writer.get_default_output_path.return_value = Path("/tmp/test.sql") + exit_code = main() assert exit_code == 0 @@ -399,8 +403,8 @@ def test_cli_dump_mode_executes(self) -> None: assert call_kwargs["table"] == "users" assert call_kwargs["pk_values"] == ["1"] - # SQL should be written to stdout (no --output flag) - mock_writer.write_to_stdout.assert_called_once_with(mock_result.sql_content) + # SQL should be written to file (always writes to files now) + mock_writer.write_to_file.assert_called_once() def test_cli_dump_with_output_file(self, tmp_path: Path) -> None: """Should write to file when --output is specified.""" @@ -425,7 +429,7 @@ def test_cli_dump_with_output_file(self, tmp_path: Path) -> None: "test", "--database", "test", - "--table", + "--dump", "users", "--pks", "1", @@ -448,6 +452,7 @@ def test_cli_dump_with_output_file(self, tmp_path: Path) -> None: mock_config.cache.enabled = False mock_config.connection_ttl_minutes = 30 mock_config.create_schema = False + mock_config.output_dir = tmp_path mock_load.return_value = mock_config mock_cm_instance = MagicMock() @@ -487,7 +492,7 @@ def test_cli_dump_with_flags(self) -> None: "test", "--database", "test", - "--table", + "--dump", "users", "--pks", "1,2,3", @@ -502,7 +507,7 @@ def test_cli_dump_with_flags(self) -> None: patch("pgslice.cli.SecureCredentials"), patch("pgslice.cli.ConnectionManager") as mock_cm, patch("pgslice.cli.DumpService") as mock_dump_service, - patch("pgslice.cli.SQLWriter"), + patch("pgslice.cli.SQLWriter") as mock_writer, ): mock_config = MagicMock() mock_config.db.host = "localhost" @@ -513,6 +518,7 @@ def test_cli_dump_with_flags(self) -> None: mock_config.cache.enabled = False mock_config.connection_ttl_minutes = 30 mock_config.create_schema = False + mock_config.output_dir = Path("/tmp") mock_load.return_value = mock_config mock_cm_instance = MagicMock() @@ -522,6 +528,9 @@ def test_cli_dump_with_flags(self) -> None: mock_service_instance.dump.return_value = mock_result mock_dump_service.return_value = mock_service_instance + # Mock get_default_output_path 
for auto-generated filename + mock_writer.get_default_output_path.return_value = Path("/tmp/test.sql") + exit_code = main() assert exit_code == 0 @@ -596,7 +605,7 @@ def test_mutual_exclusion_with_pks(self) -> None: "argv", [ "pgslice", - "--table", + "--dump", "users", "--pks", "1", @@ -635,7 +644,7 @@ def test_timeframe_mode_executes(self) -> None: "test", "--database", "test", - "--table", + "--dump", "users", "--timeframe", "created_at:2024-01-01:2024-12-31", @@ -646,7 +655,7 @@ def test_timeframe_mode_executes(self) -> None: patch("pgslice.cli.ConnectionManager") as mock_cm, patch("pgslice.cli.SchemaIntrospector") as mock_introspector, patch("pgslice.cli.DumpService") as mock_dump_service, - patch("pgslice.cli.SQLWriter"), + patch("pgslice.cli.SQLWriter") as mock_writer, patch("pgslice.cli.printy"), ): mock_config = MagicMock() @@ -658,6 +667,7 @@ def test_timeframe_mode_executes(self) -> None: mock_config.cache.enabled = False mock_config.connection_ttl_minutes = 30 mock_config.create_schema = False + mock_config.output_dir = Path("/tmp") mock_load.return_value = mock_config mock_cm_instance = MagicMock() @@ -676,6 +686,9 @@ def test_timeframe_mode_executes(self) -> None: mock_service_instance.dump.return_value = mock_result mock_dump_service.return_value = mock_service_instance + # Mock get_default_output_path for auto-generated filename + mock_writer.get_default_output_path.return_value = Path("/tmp/test.sql") + exit_code = main() assert exit_code == 0 @@ -701,7 +714,7 @@ def test_timeframe_mode_empty_result(self) -> None: "test", "--database", "test", - "--table", + "--dump", "users", "--timeframe", "created_at:2024-01-01:2024-12-31", @@ -722,6 +735,7 @@ def test_timeframe_mode_empty_result(self) -> None: mock_config.cache.enabled = False mock_config.connection_ttl_minutes = 30 mock_config.create_schema = False + mock_config.output_dir = Path("/tmp") mock_load.return_value = mock_config mock_cm_instance = MagicMock()