diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..c41c9d3
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,71 @@
+# Changelog
+
+All notable changes to pgslice will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased]
+
+## [0.2.0] - 2025-12-28
+
+### Added
+- **CLI-first design**: pgslice now works as a CLI tool that can dump records without entering the REPL
+  - `--table TABLE` + `--pks PK_VALUES`: Dump specific records by primary key (comma-separated)
+  - `--timeframe COLUMN:START:END`: Filter the main table by date range (alternative to `--pks`)
+  - `--truncate TABLE:COL:START:END`: Apply timeframe filters to related tables (repeatable)
+  - `--output FILE`: Write output to file (default: stdout for easy piping)
+  - `--wide`: Enable wide mode (follow self-referencing FKs)
+  - `--keep-pks`: Keep original primary key values instead of remapping
+  - `--graph`: Display table relationship graph after dump completes
+- **Schema introspection commands**:
+  - `--tables`: List all tables in the schema with formatted output
+  - `--describe TABLE`: Show table structure and relationships
+- **Schema DDL generation**: New `--create-schema` flag for dump command
+  - Generates a commented-out `CREATE DATABASE` statement (PostgreSQL has no `IF NOT EXISTS` for databases)
+  - Generates `CREATE SCHEMA IF NOT EXISTS` for all schemas
+  - Generates `CREATE TABLE IF NOT EXISTS` with complete table definitions
+  - Includes columns, primary keys, unique constraints, and foreign keys
+  - Handles circular dependencies via ALTER TABLE statements
+  - Supports all PostgreSQL data types including arrays and user-defined types
+  - Schema and table DDL uses `IF NOT EXISTS` for idempotency (can run multiple times safely)
+  - Works with both `--keep-pks` and default PK remapping modes
+- **Dependency graph visualization**: New `--graph` flag displays an ASCII art graph of table relationships
+  - Shows record counts per table
+  - Displays FK relationships between tables
+  - Highlights root table(s) in the graph
+- **Progress indicators**: Visual feedback for long-running operations
+  - Spinner animation during traversal operations
+  - Progress bar enabled in both CLI and REPL modes
+  - Automatically disabled when output is piped (not a TTY)
+- **Centralized operations module**: New `pgslice.operations` package for shared CLI/REPL logic
+  - `operations/dump_ops.py`: Shared dump execution logic
+  - `operations/parsing.py`: Timeframe/truncate filter parsing utilities
+  - `operations/schema_ops.py`: List tables and describe table operations
+
+### Changed
+- **REPL mode improvements**:
+  - Renamed the `--timeframe` flag to `--truncate` for clarity (it applies to related tables, not the main table)
+  - Enabled progress bar in REPL mode for better user feedback
+  - Updated to use centralized operations from the `operations/` module
+  - Improved help text and error messages
+- **Logging behavior**: Logging is now disabled by default unless `--log-level` is explicitly specified
+- **Code organization**:
+  - Refactored CLI to support both interactive REPL and non-interactive CLI modes
+  - Eliminated code duplication between CLI and REPL by introducing shared operations
+  - `SQLGenerator.generate_batch()` now accepts optional DDL parameters: `create_schema`, `database_name`, `schema_name`
+  - `AppConfig` dataclass includes new `create_schema: bool = False` field
+
+### Fixed
+- Removed `IF NOT EXISTS` from the `CREATE DATABASE` statement (PostgreSQL doesn't support it)
+
+### Technical Details
+- **New modules**:
+  - `pgslice.dumper.ddl_generator.DDLGenerator`: DDL generation for schema dumps
+  - `pgslice.dumper.dump_service.DumpService`: Centralized dump service
+  - `pgslice.operations`: Package with shared CLI/REPL operations
+  - `pgslice.utils.graph_visualizer`: Dependency graph visualization
+  - `pgslice.utils.spinner`: Spinner animation for progress indication
+- **Test coverage**: 8 new test modules added, maintaining >93% overall code coverage
+- **Dependency ordering**: Uses Kahn's algorithm for table dependency ordering in DDL generation
+- **Architecture**: Cleaner separation between CLI routing, REPL mode, and shared operations
diff --git a/Makefile b/Makefile
index adac641..6e062e5 100644
--- a/Makefile
+++ b/Makefile
@@ -54,41 +54,6 @@ clean: ## Remove build artifacts and cache
 show-version: ## Show current version from pyproject.toml
 	@uv version
 
-bump-patch: ## Bump patch version (0.1.1 -> 0.1.2)
-	@uv version --bump patch
-
-bump-minor: ## Bump minor version (0.1.1 -> 0.2.0)
-	@uv version --bump minor
-
-bump-major: ## Bump major version (0.1.1 -> 1.0.0)
-	@uv version --bump major
-
-# Python package building and publishing
-build-dist: clean ## Build Python distribution packages (wheel + sdist)
-	@echo "Building distribution packages..."
-	uv build
-	@echo "Build complete! Packages in dist/"
-	@ls -lh dist/
-
-install-local: build-dist ## Install package locally from built wheel
-	@echo "Installing from local build..."
-	uv pip install dist/*.whl --force-reinstall
-	@echo "Installation complete! Test with: pgslice --version"
-
-publish-test: build-dist ## Publish to TestPyPI for testing
-	@echo "Publishing to TestPyPI..."
-	uv publish --publish-url https://test.pypi.org/legacy/
-	@echo "Published to TestPyPI! Install with:"
-	@echo "  pip install --index-url https://test.pypi.org/simple/ pgslice"
-
-publish: all-checks build-dist ## Publish to production PyPI (requires confirmation)
-	@echo "WARNING: This will publish to production PyPI!"
-	@read -p "Version $$(grep '^version = ' pyproject.toml | cut -d'"' -f2) - Continue? [y/N] " confirm && \
-	[ "$$confirm" = "y" ] || [ "$$confirm" = "Y" ] || (echo "Aborted." && exit 1)
-	@echo "Publishing to PyPI..."
-	uv publish
-	@echo "Published! Install with: pip install pgslice"
-
 # Docker commands
 docker-build: ## Build Docker image
 	docker build -t $(DOCKER_IMAGE) .
@@ -123,20 +88,9 @@ uv-install: ## Install uv (one-time setup)
 sync: ## Sync dependencies with uv (local development)
 	uv sync --all-extras
 
-lock: ## Update uv.lock file
-	uv lock
-
-test-compat: ## Test compatibility across Python versions
-	@echo "Testing Python 3.10..."
-	@uv run --python 3.10 python --version || echo "Python 3.10 not available"
-	@echo "Testing Python 3.13..."
-	@uv run --python 3.13 python --version || echo "Python 3.13 not available"
-	@echo "Testing Python 3.14..."
-	@uv run --python 3.14 python --version || echo "Python 3.14 not available"
-
 setup: ## One-time local development setup
 	@echo "Copying env file..."
-	cp .env.template .env
+	cp .env.example .env
 	@echo "Setting up local development environment..."
 	@command -v uv >/dev/null 2>&1 || (echo "Installing uv..." && curl -LsSf https://astral.sh/uv/install.sh | sh)
 	@echo "Installing Python 3.14..."
diff --git a/README.md b/README.md
index 2d5e5a9..e386d72 100644
--- a/README.md
+++ b/README.md
@@ -32,12 +32,14 @@
 Extract only what you need while maintaining referential integrity.
 
## Features +- ✅ **CLI-first design**: Stream SQL to stdout for easy piping and scripting - ✅ **Bidirectional FK traversal**: Follows relationships in both directions (forward and reverse) - ✅ **Circular relationship handling**: Prevents infinite loops with visited tracking - ✅ **Multiple records**: Extract multiple records in one operation - ✅ **Timeframe filtering**: Filter specific tables by date ranges - ✅ **PK remapping**: Auto-remaps auto-generated primary keys for clean imports -- ✅ **Interactive REPL**: User-friendly command-line interface +- ✅ **DDL generation**: Optionally include CREATE DATABASE/SCHEMA/TABLE statements for self-contained dumps +- ✅ **Progress bar**: Visual progress indicator for dump operations - ✅ **Schema caching**: SQLite-based caching for improved performance - ✅ **Type-safe**: Full type hints with mypy strict mode - ✅ **Secure**: SQL injection prevention, secure password handling @@ -83,40 +85,136 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for detailed development setup instructions ## Quick Start +### CLI Mode + +The CLI mode streams SQL to stdout by default, making it easy to pipe or redirect output: + ```bash -# In REPL: -# This will dump all related records to the film with id 1 -# The generated SQL file will be placed, by default, in ~/.pgslice/dumps -# The name will be a formated string with table name, id, and timestamp -pgslice> dump "film" 1 +# Basic dump to stdout (pipe to file) +PGPASSWORD=xxx pgslice --host localhost --database mydb --table users --pks 42 > user_42.sql -# You can overwrite the output path with: -pgslice> dump "film" 1 --output film_1.sql +# Multiple records +PGPASSWORD=xxx pgslice --host localhost --database mydb --table users --pks 1,2,3 > users.sql -# Extract multiple records -pgslice> dump "actor" 1,2,3 --output multiple_actors.sql +# Output directly to file with --output flag +pgslice --host localhost --database mydb --table users --pks 42 --output user_42.sql -# Use wide mode to follow all relationships (including self-referencing FKs) -# Be cautions that this can result in larger datasets. 
So use with caution -pgslice> dump "customer" 42 --wide --output customer_42.sql +# Dump by timeframe (instead of PKs) - filters main table by date range +pgslice --host localhost --database mydb --table orders \ + --timeframe "created_at:2024-01-01:2024-12-31" > orders_2024.sql -# Apply timeframe filter -pgslice> dump "customer" 42 --timeframe "rental:rental_date:2024-01-01:2024-12-31" +# Wide mode: follow all relationships including self-referencing FKs +# Be cautious - this can result in larger datasets +pgslice --host localhost --database mydb --table customer --pks 42 --wide > customer.sql -# List all tables -pgslice> tables +# Keep original primary keys (no remapping) +pgslice --host localhost --database mydb --table film --pks 1 --keep-pks > film.sql + +# Generate self-contained SQL with DDL statements +# Includes CREATE DATABASE/SCHEMA/TABLE statements +pgslice --host localhost --database mydb --table film --pks 1 --create-schema > film_complete.sql + +# Apply truncate filter to limit related tables by date range +pgslice --host localhost --database mydb --table customer --pks 42 \ + --truncate "rental:rental_date:2024-01-01:2024-12-31" > customer.sql + +# Enable debug logging (writes to stderr) +pgslice --host localhost --database mydb --table users --pks 42 \ + --log-level DEBUG 2>debug.log > output.sql +``` + +### Schema Exploration + +```bash +# List all tables in the schema +pgslice --host localhost --database mydb --tables + +# Describe table structure and relationships +pgslice --host localhost --database mydb --describe users +``` + +### SSH Remote Execution + +Run pgslice on a remote server and capture output locally: + +```bash +# Execute on remote server, save output locally +ssh remote.server.com "PGPASSWORD=xxx pgslice --host db.internal --database mydb \ + --table users --pks 1 --create-schema" > local_dump.sql + +# With SSH tunnel for database access +ssh -f -N -L 5433:db.internal:5432 bastion.example.com +PGPASSWORD=xxx pgslice --host localhost --port 5433 --database mydb \ + --table users --pks 42 > user.sql +``` -# Show table structure and relationships +### Interactive REPL + +```bash +# Start interactive REPL +pgslice --host localhost --database mydb + +pgslice> dump "film" 1 --output film_1.sql +pgslice> tables pgslice> describe "film" +``` + +## CLI vs REPL: Output Behavior + +Understanding the difference between CLI and REPL modes: + +### CLI Mode (stdout by default) +The CLI streams SQL to **stdout** by default, perfect for piping and scripting: -# Keep original primary key values (no remapping) -# By default, we will dinamically assign ids to the new generated records -# and handle conflicts gracefully. Meaninh, you can run the same file multiple times -# and no conflicts will arise. 
-# If you want to keep the original id's run: -pgslice> dump "film" 1 --keep-pks --output film_1.sql +```bash +# Streams to stdout - redirect with > +pgslice --table users --pks 42 > user_42.sql + +# Or use --output flag +pgslice --table users --pks 42 --output user_42.sql + +# Pipe to other commands +pgslice --table users --pks 42 | gzip > user_42.sql.gz ``` +### REPL Mode (files by default) +The REPL writes to **`~/.pgslice/dumps/`** by default when `--output` is not specified: + +```bash +# In REPL: writes to ~/.pgslice/dumps/public_users_42.sql +pgslice> dump "users" 42 + +# Specify custom output path +pgslice> dump "users" 42 --output /path/to/user.sql +``` + +### Same Operations, Different Modes + +| Operation | CLI | REPL | +|-----------|-----|------| +| **List tables** | `pgslice --tables` | `pgslice> tables` | +| **Describe table** | `pgslice --describe users` | `pgslice> describe "users"` | +| **Dump to stdout** | `pgslice --table users --pks 42` | N/A (REPL always writes to file) | +| **Dump to file** | `pgslice --table users --pks 42 --output user.sql` | `pgslice> dump "users" 42 --output user.sql` | +| **Dump (default)** | Stdout | `~/.pgslice/dumps/public_users_42.sql` | +| **Multiple PKs** | `pgslice --table users --pks 1,2,3` | `pgslice> dump "users" 1,2,3` | +| **Truncate filter** | `pgslice --table users --pks 42 --truncate "orders:2024-01-01:2024-12-31"` | `pgslice> dump "users" 42 --truncate "orders:2024-01-01:2024-12-31"` | +| **Wide mode** | `pgslice --table users --pks 42 --wide` | `pgslice> dump "users" 42 --wide` | + +### When to Use Each Mode + +**Use CLI mode when:** +- Piping output to other commands +- Scripting and automation +- Remote execution via SSH +- One-off dumps + +**Use REPL mode when:** +- Exploring database schema interactively +- Running multiple dumps in a session +- You prefer persistent file output +- Testing different dump configurations + ## Configuration Key environment variables (see `.env.example` for full reference): @@ -131,7 +229,7 @@ Key environment variables (see `.env.example` for full reference): | `PGPASSWORD` | Database password (env var only) | - | | `CACHE_ENABLED` | Enable schema caching | `true` | | `CACHE_TTL_HOURS` | Cache time-to-live | `24` | -| `LOG_LEVEL` | Logging level | `INFO` | +| `LOG_LEVEL` | Logging level (disabled by default unless specified) | disabled | | `PGSLICE_OUTPUT_DIR` | Output directory | `~/.pgslice/dumps` | ## Security diff --git a/pyproject.toml b/pyproject.toml index 1e40410..3262b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "printy==3.0.0", "tabulate>=0.9.0", "python-dotenv>=1.0.0", + "tqdm>=4.66.0", ] [project.optional-dependencies] @@ -104,6 +105,10 @@ ignore_missing_imports = true module = "tabulate.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "tqdm.*" +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] python_files = "test_*.py" diff --git a/src/pgslice/cli.py b/src/pgslice/cli.py index a84b7fb..2a48cbf 100644 --- a/src/pgslice/cli.py +++ b/src/pgslice/cli.py @@ -4,18 +4,250 @@ import argparse import sys +from dataclasses import dataclass +from datetime import datetime from importlib.metadata import version as get_version -from .config import load_config +from printy import printy + +from .config import AppConfig, load_config from .db.connection import ConnectionManager +from .db.schema import SchemaIntrospector +from .dumper.dump_service import DumpService +from .dumper.writer import SQLWriter +from 
.operations import ( + describe_table, + list_tables, + parse_truncate_filters, + print_tables, +) from .repl import REPL -from .utils.exceptions import DBReverseDumpError +from .utils.exceptions import DBReverseDumpError, InvalidTimeframeError from .utils.logging_config import get_logger, setup_logging from .utils.security import SecureCredentials logger = get_logger(__name__) +@dataclass +class MainTableTimeframe: + """Timeframe filter for the main table.""" + + column_name: str + start_date: datetime + end_date: datetime + + +def parse_main_timeframe(spec: str) -> MainTableTimeframe: + """ + Parse main table timeframe specification. + + Format: column:start_date:end_date + Example: created_at:2024-01-01:2024-12-31 + + Args: + spec: Timeframe specification string + + Returns: + MainTableTimeframe object + + Raises: + InvalidTimeframeError: If specification is invalid + """ + parts = spec.split(":") + if len(parts) != 3: + raise InvalidTimeframeError( + f"Invalid timeframe format: {spec}. " + "Expected: column:start:end (e.g., created_at:2024-01-01:2024-12-31)" + ) + + column_name, start_str, end_str = parts + + try: + start_date = datetime.fromisoformat(start_str) + except ValueError as e: + raise InvalidTimeframeError(f"Invalid start date: {start_str}") from e + + try: + end_date = datetime.fromisoformat(end_str) + except ValueError as e: + raise InvalidTimeframeError(f"Invalid end date: {end_str}") from e + + return MainTableTimeframe( + column_name=column_name, + start_date=start_date, + end_date=end_date, + ) + + +def fetch_pks_by_timeframe( + conn_manager: ConnectionManager, + table: str, + schema: str, + timeframe: MainTableTimeframe, +) -> list[str]: + """ + Fetch primary key values matching the timeframe filter. + + Args: + conn_manager: Database connection manager + table: Table name + schema: Schema name + timeframe: Timeframe filter + + Returns: + List of primary key values as strings + """ + printy("[y]Warning: Fetching records by timeframe may be slow for large tables@") + + conn = conn_manager.get_connection() + introspector = SchemaIntrospector(conn) + table_meta = introspector.get_table_metadata(schema, table) + + if not table_meta.primary_keys: + raise DBReverseDumpError(f"Table {schema}.{table} has no primary key") + + # Use first primary key column for simplicity + pk_col = table_meta.primary_keys[0] + + # Build and execute query + query = f''' + SELECT "{pk_col}" + FROM "{schema}"."{table}" + WHERE "{timeframe.column_name}" BETWEEN %s AND %s + ''' + + with conn.cursor() as cur: + cur.execute(query, (timeframe.start_date, timeframe.end_date)) + rows = cur.fetchall() + + pk_values = [str(row[0]) for row in rows] + printy(f"[c]Found {len(pk_values)} records matching timeframe@") + return pk_values + + +def run_cli_dump( + args: argparse.Namespace, + config: AppConfig, + conn_manager: ConnectionManager, +) -> int: + """ + Execute dump in non-interactive CLI mode. 
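+
+    Example (illustrative; ``args`` is the argparse.Namespace built by ``main()``,
+    e.g. for ``pgslice --table users --pks 1,2 --output users.sql``):
+
+        exit_code = run_cli_dump(args, config, conn_manager)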
+ + Args: + args: Parsed command line arguments + config: Application configuration + conn_manager: Database connection manager + + Returns: + Exit code (0 for success, non-zero for error) + """ + # Parse truncate filters for related tables + try: + truncate_filters = parse_truncate_filters(args.truncate) + except InvalidTimeframeError as e: + sys.stderr.write(f"Error: {e}\n") + return 1 + + # Determine PK values - either from --pks or --timeframe + if args.pks: + pk_values = [v.strip() for v in args.pks.split(",")] + elif args.timeframe: + try: + timeframe = parse_main_timeframe(args.timeframe) + except InvalidTimeframeError as e: + sys.stderr.write(f"Error: {e}\n") + return 1 + + pk_values = fetch_pks_by_timeframe( + conn_manager, args.table, args.schema, timeframe + ) + if not pk_values: + printy("[y]No records found matching the timeframe@") + return 0 + else: + # Should not reach here due to earlier validation + sys.stderr.write("Error: --pks or --timeframe is required\n") + return 1 + + # Show progress only if stderr is a TTY (not piped) + show_progress = sys.stderr.isatty() + + # Wide mode warning + if args.wide and show_progress: + sys.stderr.write( + "\n⚠ Note: Wide mode follows ALL relationships including self-referencing FKs.\n" + ) + sys.stderr.write(" This may take longer and fetch more data.\n\n") + sys.stderr.flush() + + # Create dump service + service = DumpService(conn_manager, config, show_progress=show_progress) + + # Execute dump + result = service.dump( + table=args.table, + pk_values=pk_values, + schema=args.schema, + wide_mode=args.wide, + keep_pks=args.keep_pks, + create_schema=args.create_schema, + timeframe_filters=truncate_filters, + show_graph=args.graph, + ) + + # Output SQL + if args.output: + SQLWriter.write_to_file(result.sql_content, args.output) + printy(f"[g]Wrote {result.record_count} records to {args.output}@") + else: + SQLWriter.write_to_stdout(result.sql_content) + + return 0 + + +def run_list_tables(conn_manager: ConnectionManager, schema: str) -> int: + """ + List all tables in the specified schema. + + Args: + conn_manager: Database connection manager + schema: Schema name + + Returns: + Exit code (0 for success, non-zero for error) + """ + try: + tables = list_tables(conn_manager, schema) + print_tables(tables, schema) + return 0 + except Exception as e: + printy(f"[r]Error: {e}@") + return 1 + + +def run_describe_table( + conn_manager: ConnectionManager, schema: str, table_name: str +) -> int: + """ + Describe table structure and relationships. + + Args: + conn_manager: Database connection manager + schema: Schema name + table_name: Table name to describe + + Returns: + Exit code (0 for success, non-zero for error) + """ + try: + describe_table(conn_manager, schema, table_name) + return 0 + except Exception as e: + printy(f"[r]Error: {e}@") + return 1 + + def main() -> int: """ Main entry point for pgslice CLI. 
@@ -28,11 +260,23 @@ def main() -> int: formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Start interactive REPL - %(prog)s --host localhost --port 5432 --user postgres --database mydb + # Dump to stdout + PGPASSWORD=xxx %(prog)s --host localhost --database mydb --table users --pks 42 + + # Dump by timeframe (instead of PKs) + %(prog)s --host localhost --database mydb --table orders --timeframe "created_at:2024-01-01:2024-12-31" + + # Dump to file with truncate filter for related tables + %(prog)s --table users --pks 1 --truncate "orders:created_at:2024-01-01:2024-12-31" --output user.sql + + # List all tables + %(prog)s --host localhost --database mydb --tables + + # Describe table structure + %(prog)s --host localhost --database mydb --describe users - # Require read-only connection - %(prog)s --host prod-db --require-read-only --database mydb + # Interactive REPL + %(prog)s --host localhost --database mydb # Clear cache and exit %(prog)s --clear-cache @@ -74,13 +318,77 @@ def main() -> int: action="store_true", help="Clear schema cache and exit", ) + parser.add_argument( + "--create-schema", + action="store_true", + help="Include DDL statements (CREATE DATABASE/SCHEMA/TABLE) in SQL dumps", + ) + + # Schema information arguments + info_group = parser.add_argument_group("Schema Information") + info_group.add_argument( + "--tables", + action="store_true", + help="List all tables in the schema", + ) + info_group.add_argument( + "--describe", + metavar="TABLE", + help="Show table structure and relationships", + ) + + # Dump operation arguments (non-interactive CLI mode) + dump_group = parser.add_argument_group("Dump Operation (CLI mode)") + dump_group.add_argument( + "--table", + help="Table name to dump (enables non-interactive CLI mode)", + ) + + # --pks and --timeframe are mutually exclusive ways to select records + pk_source_group = dump_group.add_mutually_exclusive_group() + pk_source_group.add_argument( + "--pks", + help="Primary key value(s), comma-separated (e.g., '42' or '1,2,3')", + ) + pk_source_group.add_argument( + "--timeframe", + metavar="COLUMN:START:END", + help="Filter main table by timeframe (e.g., 'created_at:2024-01-01:2024-12-31'). " + "Mutually exclusive with --pks.", + ) + + dump_group.add_argument( + "--wide", + action="store_true", + help="Wide mode: follow all relationships including self-referencing FKs", + ) + dump_group.add_argument( + "--keep-pks", + action="store_true", + help="Keep original primary key values (default: remap auto-generated PKs)", + ) + dump_group.add_argument( + "--graph", + action="store_true", + help="Display table relationship graph after dump completes", + ) + dump_group.add_argument( + "--truncate", + action="append", + help="Truncate filter for related tables (format: table:column:start:end). 
Can be repeated.", + ) + dump_group.add_argument( + "--output", + "-o", + help="Output file path (default: stdout)", + ) # Other arguments parser.add_argument( "--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], - default="INFO", - help="Log level (default: INFO)", + default=None, + help="Log level (default: disabled unless specified)", ) # Get version dynamically from package metadata try: @@ -117,8 +425,18 @@ def main() -> int: config.db.schema = args.schema if args.no_cache: config.cache.enabled = False + if args.create_schema: + config.create_schema = True - config.log_level = args.log_level + if args.log_level: + config.log_level = args.log_level + + # Validate CLI dump mode arguments + if args.table and not args.pks and not args.timeframe: + sys.stderr.write( + "Error: --pks or --timeframe is required when using --table\n" + ) + return 1 # Clear cache if requested if args.clear_cache: @@ -130,7 +448,7 @@ def main() -> int: config.cache.ttl_hours, ) # Clear all caches (we don't have specific db info) - logger.info("Cache cleared") + printy("[g]Cache cleared successfully@") else: pass return 0 @@ -151,24 +469,35 @@ def main() -> int: ) # Test connection - logger.info("Testing database connection...") try: conn_manager.get_connection() except Exception as e: logger.error(f"Connection failed: {e}") raise - # Start REPL + # Route: Schema info, CLI dump mode, or REPL mode try: - repl = REPL(conn_manager, config) - repl.start() + # Handle --tables + if args.tables: + return run_list_tables(conn_manager, args.schema) + + # Handle --describe + if args.describe: + return run_describe_table(conn_manager, args.schema, args.describe) + + if args.table: + # Non-interactive CLI dump mode + return run_cli_dump(args, config, conn_manager) + else: + # Interactive REPL mode + repl = REPL(conn_manager, config) + repl.start() + return 0 finally: # Clean up conn_manager.close() credentials.clear() - return 0 - except KeyboardInterrupt: logger.info("Interrupted by user") return 130 diff --git a/src/pgslice/config.py b/src/pgslice/config.py index d0f77fa..40a6092 100644 --- a/src/pgslice/config.py +++ b/src/pgslice/config.py @@ -45,6 +45,7 @@ class AppConfig: log_level: str = "INFO" sql_batch_size: int = 100 output_dir: Path = Path.home() / ".pgslice" / "dumps" + create_schema: bool = False def load_config() -> AppConfig: diff --git a/src/pgslice/dumper/ddl_generator.py b/src/pgslice/dumper/ddl_generator.py new file mode 100644 index 0000000..1f8026f --- /dev/null +++ b/src/pgslice/dumper/ddl_generator.py @@ -0,0 +1,541 @@ +"""DDL (Data Definition Language) generation for schema creation.""" + +from __future__ import annotations + +from collections import defaultdict, deque +from typing import TYPE_CHECKING + +from ..graph.models import Column, Table +from ..utils.logging_config import get_logger + +if TYPE_CHECKING: + from ..db.schema import SchemaIntrospector + +logger = get_logger(__name__) + + +class DDLGenerator: + """Generates CREATE DATABASE, CREATE SCHEMA, and CREATE TABLE statements.""" + + def __init__(self, schema_introspector: SchemaIntrospector) -> None: + """ + Initialize DDL generator. + + Args: + schema_introspector: Schema introspection utility for table metadata + """ + self.introspector = schema_introspector + + def generate_ddl( + self, + database_name: str, + schema_name: str, + tables: set[tuple[str, str]], + ) -> str: + """ + Generate complete DDL for database, schema, and tables. 
+ + Args: + database_name: Name of the database to create + schema_name: Name of the schema to create + tables: Set of (schema, table) tuples to generate CREATE TABLE for + + Returns: + Complete DDL script with CREATE statements + + Example: + >>> generator = DDLGenerator(introspector) + >>> ddl = generator.generate_ddl("mydb", "public", {("public", "users")}) + >>> print(ddl) + -- CREATE DATABASE "mydb"; + \\c "mydb" + CREATE SCHEMA IF NOT EXISTS "public"; + ... + """ + if not tables: + return "" + + logger.info( + f"Generating DDL for database '{database_name}', " + f"schema '{schema_name}', {len(tables)} table(s)" + ) + + statements = [] + + # 1. CREATE DATABASE + quoted_db = self._quote_identifier(database_name) + statements.append( + f"-- NOTE: PostgreSQL does not support 'CREATE DATABASE IF NOT EXISTS'.\n" + f"-- If the database already exists, comment out the line below or it will fail.\n" + f"-- For a new database: Uncomment the following line\n" + f"-- CREATE DATABASE {quoted_db};" + ) + statements.append("") # Blank line + + # 2. CONNECT TO DATABASE + # Add connection command (psql-specific) + statements.append( + f"-- Connect to the database before creating schemas/tables.\n" + f"-- For psql: Use the \\c command below\n" + f"-- For other clients: Disconnect and reconnect to the database manually\n" + f"\\c {quoted_db}" + ) + statements.append("") # Blank line + + # 3. CREATE SCHEMA(S) + # Collect unique schemas from tables + unique_schemas = {schema for schema, _ in tables} + for schema in sorted(unique_schemas): + quoted_schema = self._quote_identifier(schema) + statements.append(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema};") + statements.append("") # Blank line + + # 4. CREATE TABLES (sorted by dependencies) + sorted_tables = self._sort_tables_by_dependencies(tables) + + table_statements = [] + foreign_key_statements = [] + + for schema, table in sorted_tables: + # Generate CREATE TABLE + create_table_sql = self._generate_create_table(schema, table) + table_statements.append(create_table_sql) + + # Generate ALTER TABLE statements for foreign keys + fk_sql = self._generate_foreign_key_statements(schema, table) + if fk_sql: + foreign_key_statements.append(fk_sql) + + # Add all table creation statements + statements.extend(table_statements) + + # Add blank line before foreign keys + if foreign_key_statements: + statements.append("") + statements.append("-- Add foreign key constraints") + statements.extend(foreign_key_statements) + + ddl = "\n".join(statements) + logger.debug(f"Generated DDL: {len(statements)} statements") + return ddl + + def _generate_create_table(self, schema: str, table: str) -> str: + """ + Generate CREATE TABLE statement for a single table. 
+ + Args: + schema: Schema name + table: Table name + + Returns: + CREATE TABLE IF NOT EXISTS statement with all columns and constraints + + Example: + CREATE TABLE IF NOT EXISTS "public"."users" ( + "id" SERIAL PRIMARY KEY, + "email" TEXT NOT NULL, + "created_at" TIMESTAMP DEFAULT NOW() + ); + """ + # Get table metadata + table_metadata = self.introspector.get_table_metadata(schema, table) + + quoted_schema = self._quote_identifier(schema) + quoted_table = self._quote_identifier(table) + full_table_name = f"{quoted_schema}.{quoted_table}" + + # Build column definitions + column_defs = [] + for col in table_metadata.columns: + col_def = self._format_column_definition(col) + column_defs.append(f" {col_def}") + + # Add table-level constraints + constraints = [] + + # Primary key constraint (if not already in column definition) + if table_metadata.primary_keys: + # Check if PK is already defined inline (single PK column) + has_inline_pk = len(table_metadata.primary_keys) == 1 and any( + col.is_primary_key and not col.is_auto_generated + for col in table_metadata.columns + if col.name == table_metadata.primary_keys[0] + ) + + if not has_inline_pk: + pk_constraint = self._format_primary_key_constraint(table_metadata) + if pk_constraint: + constraints.append(f" {pk_constraint}") + + # Unique constraints + for constraint_name, columns in table_metadata.unique_constraints.items(): + unique_constraint = self._format_unique_constraint(constraint_name, columns) + constraints.append(f" {unique_constraint}") + + # Combine columns and constraints + all_definitions = column_defs + constraints + definitions_sql = ",\n".join(all_definitions) + + # Build final CREATE TABLE statement + create_table = ( + f"CREATE TABLE IF NOT EXISTS {full_table_name} (\n{definitions_sql}\n);" + ) + + return create_table + + def _format_column_definition(self, col: Column) -> str: + """ + Format a single column definition. + + Args: + col: Column metadata + + Returns: + Column definition string + + Example: + "id" SERIAL PRIMARY KEY + "email" TEXT NOT NULL + "count" INTEGER DEFAULT 0 + "tags" TEXT[] + """ + quoted_name = self._quote_identifier(col.name) + + # Map data type + col_type = self._map_postgresql_type(col.data_type, col.udt_name) + + # Build definition parts + parts = [quoted_name, col_type] + + # Add NOT NULL if applicable + if not col.nullable: + parts.append("NOT NULL") + + # Add DEFAULT if specified + if col.default is not None: + # Clean up default value (remove PostgreSQL type casts if present) + default_value = col.default + # Handle nextval() for SERIAL columns - skip it as SERIAL includes it + if "nextval(" not in default_value.lower(): + parts.append(f"DEFAULT {default_value}") + + # Add PRIMARY KEY if single PK and not auto-generated + # (AUTO-generated columns like SERIAL handle PK differently) + if col.is_primary_key and not col.is_auto_generated: + parts.append("PRIMARY KEY") + + return " ".join(parts) + + def _format_primary_key_constraint(self, table: Table) -> str: + """ + Format PRIMARY KEY constraint. 
+ + Args: + table: Table metadata + + Returns: + PRIMARY KEY constraint string or empty string + + Example: + PRIMARY KEY ("id") + PRIMARY KEY ("tenant_id", "user_id") + """ + if not table.primary_keys: + return "" + + # For SERIAL/auto-generated single PK, skip (already in column def) + if len(table.primary_keys) == 1: + pk_col_name = table.primary_keys[0] + pk_col = next( + (col for col in table.columns if col.name == pk_col_name), None + ) + if pk_col and pk_col.is_auto_generated: + return "" # SERIAL already includes PRIMARY KEY + + # Format composite or explicit PK + pk_columns = ", ".join(self._quote_identifier(pk) for pk in table.primary_keys) + return f"PRIMARY KEY ({pk_columns})" + + def _format_unique_constraint(self, name: str, columns: list[str]) -> str: + """ + Format UNIQUE constraint. + + Args: + name: Constraint name + columns: List of column names + + Returns: + UNIQUE constraint string + + Example: + CONSTRAINT "users_email_key" UNIQUE ("email") + """ + quoted_name = self._quote_identifier(name) + quoted_columns = ", ".join(self._quote_identifier(col) for col in columns) + return f"CONSTRAINT {quoted_name} UNIQUE ({quoted_columns})" + + def _generate_foreign_key_statements(self, schema: str, table: str) -> str: + """ + Generate ALTER TABLE statements for foreign keys. + + Foreign keys are added via ALTER TABLE to handle circular dependencies. + + Args: + schema: Schema name + table: Table name + + Returns: + ALTER TABLE statements for all foreign keys, or empty string + + Example: + ALTER TABLE "public"."orders" + ADD CONSTRAINT "orders_user_id_fkey" + FOREIGN KEY ("user_id") + REFERENCES "public"."users"("id"); + """ + # Get table metadata + table_metadata = self.introspector.get_table_metadata(schema, table) + + if not table_metadata.foreign_keys_outgoing: + return "" + + quoted_schema = self._quote_identifier(schema) + quoted_table = self._quote_identifier(table) + full_table_name = f"{quoted_schema}.{quoted_table}" + + fk_statements = [] + + for fk in table_metadata.foreign_keys_outgoing: + # Quote identifiers + constraint_name = self._quote_identifier(fk.constraint_name) + source_col = self._quote_identifier(fk.source_column) + + # Determine target schema (assume same schema if not specified) + # Foreign key target tables might be in different schemas + target_schema = schema # Default to same schema + target_table_quoted = self._quote_identifier(fk.target_table) + target_col_quoted = self._quote_identifier(fk.target_column) + + # Build ALTER TABLE statement + alter_stmt = ( + f"ALTER TABLE {full_table_name}\n" + f" ADD CONSTRAINT {constraint_name}\n" + f" FOREIGN KEY ({source_col})\n" + f' REFERENCES "{target_schema}".{target_table_quoted}({target_col_quoted})' + ) + + # Add ON DELETE clause if not default + if fk.on_delete and fk.on_delete != "NO ACTION": + alter_stmt += f"\n ON DELETE {fk.on_delete}" + + alter_stmt += ";" + fk_statements.append(alter_stmt) + + return "\n\n".join(fk_statements) + + def _map_postgresql_type(self, data_type: str, udt_name: str) -> str: + """ + Map PostgreSQL information_schema data type to CREATE TABLE syntax. 
+ + Args: + data_type: Data type from information_schema.columns.data_type + udt_name: UDT name from information_schema.columns.udt_name + + Returns: + PostgreSQL type for CREATE TABLE statement + + Example: + ("ARRAY", "_text") -> "TEXT[]" + ("integer", "int4") -> "INTEGER" + ("character varying", "varchar") -> "TEXT" + ("USER-DEFINED", "my_enum") -> "my_enum" + """ + data_type_upper = data_type.upper() + + # Handle arrays + if data_type_upper == "ARRAY": + element_type = self._get_array_element_type(udt_name) + return f"{element_type}[]" + + # Handle user-defined types (ENUMs, domains, etc.) + if data_type_upper == "USER-DEFINED": + return udt_name + + # Map standard types + type_mapping = { + "INTEGER": "INTEGER", + "BIGINT": "BIGINT", + "SMALLINT": "SMALLINT", + "TEXT": "TEXT", + "CHARACTER VARYING": "TEXT", # Prefer TEXT over VARCHAR + "VARCHAR": "TEXT", + "CHARACTER": "CHAR", + "CHAR": "CHAR", + "BOOLEAN": "BOOLEAN", + "TIMESTAMP WITHOUT TIME ZONE": "TIMESTAMP", + "TIMESTAMP WITH TIME ZONE": "TIMESTAMPTZ", + "TIMESTAMP": "TIMESTAMP", + "TIMESTAMPTZ": "TIMESTAMPTZ", + "DATE": "DATE", + "TIME WITHOUT TIME ZONE": "TIME", + "TIME WITH TIME ZONE": "TIMETZ", + "TIME": "TIME", + "UUID": "UUID", + "JSON": "JSON", + "JSONB": "JSONB", + "NUMERIC": "NUMERIC", + "DECIMAL": "NUMERIC", + "REAL": "REAL", + "DOUBLE PRECISION": "DOUBLE PRECISION", + "BYTEA": "BYTEA", + "SERIAL": "SERIAL", + "BIGSERIAL": "BIGSERIAL", + "SMALLSERIAL": "SMALLSERIAL", + } + + mapped_type = type_mapping.get(data_type_upper) + if mapped_type: + return mapped_type + + # Fallback: use data_type as-is + logger.warning( + f"Unknown data type '{data_type}' (udt: '{udt_name}'), using as-is" + ) + return data_type.upper() + + def _get_array_element_type(self, udt_name: str) -> str: + """ + Extract element type from array udt_name. + + Args: + udt_name: PostgreSQL UDT name (e.g., "_text", "_int4") + + Returns: + Element type name + + Example: + "_text" -> "TEXT" + "_int4" -> "INTEGER" + "_varchar" -> "TEXT" + """ + # Remove leading underscore from array type names + element_udt = udt_name[1:] if udt_name.startswith("_") else udt_name + + # Map common UDT names to SQL types + udt_mapping = { + "text": "TEXT", + "varchar": "TEXT", + "char": "CHAR", + "int4": "INTEGER", + "int8": "BIGINT", + "int2": "SMALLINT", + "float4": "REAL", + "float8": "DOUBLE PRECISION", + "bool": "BOOLEAN", + "timestamp": "TIMESTAMP", + "timestamptz": "TIMESTAMPTZ", + "date": "DATE", + "time": "TIME", + "timetz": "TIMETZ", + "uuid": "UUID", + "json": "JSON", + "jsonb": "JSONB", + "numeric": "NUMERIC", + "bytea": "BYTEA", + } + + return udt_mapping.get(element_udt, element_udt.upper()) + + def _sort_tables_by_dependencies( + self, tables: set[tuple[str, str]] + ) -> list[tuple[str, str]]: + """ + Sort tables by FK dependencies using topological sort (Kahn's algorithm). + + Tables with no dependencies come first, followed by tables that depend on them. + This ensures CREATE TABLE statements are in valid dependency order. 
+ + Args: + tables: Set of (schema, table) tuples + + Returns: + List of (schema, table) tuples in dependency order + + Example: + Input: {("public", "orders"), ("public", "users")} + Output: [("public", "users"), ("public", "orders")] + (users has no deps, orders depends on users) + """ + if not tables: + return [] + + # Build dependency graph + # in_degree: how many tables this table depends on + in_degree: dict[tuple[str, str], int] = dict.fromkeys(tables, 0) + # adjacency list: tables that depend on this table + dependents: dict[tuple[str, str], list[tuple[str, str]]] = defaultdict(list) + + for schema, table in tables: + table_metadata = self.introspector.get_table_metadata(schema, table) + + # Count dependencies (outgoing FKs to other tables in the set) + for fk in table_metadata.foreign_keys_outgoing: + # Assume target is in same schema if not specified + target_table_tuple = (schema, fk.target_table) + + # Only count if target is in our table set + if target_table_tuple in tables: + in_degree[(schema, table)] += 1 + dependents[target_table_tuple].append((schema, table)) + + # Kahn's algorithm: Process tables with no dependencies first + queue: deque[tuple[str, str]] = deque() + for table_tuple in tables: + if in_degree[table_tuple] == 0: + queue.append(table_tuple) + + sorted_tables = [] + + while queue: + current = queue.popleft() + sorted_tables.append(current) + + # Reduce in-degree for dependents + for dependent in dependents[current]: + in_degree[dependent] -= 1 + if in_degree[dependent] == 0: + queue.append(dependent) + + # If we couldn't sort all tables, there's a circular dependency + # In this case, just append remaining tables (FK will be added via ALTER) + if len(sorted_tables) < len(tables): + remaining = tables - set(sorted_tables) + sorted_tables.extend(sorted(remaining)) # Sort for consistency + logger.warning( + f"Circular dependencies detected among tables: {remaining}. " + f"Foreign keys will be added via ALTER TABLE." + ) + + return sorted_tables + + def _quote_identifier(self, identifier: str) -> str: + """ + Quote a SQL identifier safely. + + Always uses double quotes to handle reserved words and special characters. + Escapes embedded double quotes. 
+ + Args: + identifier: SQL identifier (table, column, schema name) + + Returns: + Quoted identifier + + Example: + "users" -> '"users"' + "my-table" -> '"my-table"' + 'table"name' -> '"table""name"' (escaped quote) + """ + # Escape embedded double quotes by doubling them + escaped = identifier.replace('"', '""') + return f'"{escaped}"' diff --git a/src/pgslice/dumper/dump_service.py b/src/pgslice/dumper/dump_service.py new file mode 100644 index 0000000..9111590 --- /dev/null +++ b/src/pgslice/dumper/dump_service.py @@ -0,0 +1,178 @@ +"""Service for executing database dump operations.""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass, field + +from tqdm import tqdm + +from ..config import AppConfig +from ..db.connection import ConnectionManager +from ..db.schema import SchemaIntrospector +from ..graph.models import TimeframeFilter +from ..graph.traverser import RelationshipTraverser +from ..graph.visited_tracker import VisitedTracker +from ..utils.logging_config import get_logger +from ..utils.spinner import SpinnerAnimator +from .dependency_sorter import DependencySorter +from .sql_generator import SQLGenerator + +logger = get_logger(__name__) + + +@dataclass +class DumpResult: + """Result of a dump operation.""" + + sql_content: str + record_count: int + tables_involved: set[str] = field(default_factory=set) + + +class DumpService: + """Service for executing database dump operations.""" + + def __init__( + self, + connection_manager: ConnectionManager, + config: AppConfig, + show_progress: bool = False, + ) -> None: + """ + Initialize dump service. + + Args: + connection_manager: Database connection manager + config: Application configuration + show_progress: Whether to show progress bar (writes to stderr) + """ + self.conn_manager = connection_manager + self.config = config + self.show_progress = show_progress + + def dump( + self, + table: str, + pk_values: list[str], + schema: str = "public", + wide_mode: bool = False, + keep_pks: bool = False, + create_schema: bool = False, + timeframe_filters: list[TimeframeFilter] | None = None, + show_graph: bool = False, + ) -> DumpResult: + """ + Execute dump operation and return result. 
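+
+        Example (illustrative sketch; assumes an initialized ConnectionManager
+        and AppConfig from the surrounding application):
+
+            service = DumpService(conn_manager, config, show_progress=False)
+            result = service.dump(table="users", pk_values=["42"])
+            # result.sql_content holds the generated SQL script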
+ + Args: + table: Table name to dump + pk_values: List of primary key values + schema: Database schema name + wide_mode: Whether to follow all relationships including self-referencing FKs + keep_pks: Whether to keep original primary key values + create_schema: Whether to include DDL statements + timeframe_filters: Optional timeframe filters + show_graph: Whether to display relationship graph after dump + + Returns: + DumpResult with SQL content and metadata + """ + timeframe_filters = timeframe_filters or [] + + # Progress bar with 4 steps, writes to stderr + with tqdm( + total=4, + desc="Dumping", + disable=not self.show_progress, + file=sys.stderr, + bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}", + ) as pbar: + # Step 1: Setup and traverse relationships + # Create spinner animator (updates every 100ms for smooth animation) + spinner = SpinnerAnimator(update_interval=0.1) + + pbar.set_description(f"Traversing relationships {spinner.get_frame()}") + + # Define progress callback to update progress bar with animated spinner + def update_progress(count: int) -> None: + pbar.set_description( + f"Traversing relationships {spinner.get_frame()} {count} records found" + ) + + conn = self.conn_manager.get_connection() + introspector = SchemaIntrospector(conn) + visited = VisitedTracker() + traverser = RelationshipTraverser( + conn, + introspector, + visited, + timeframe_filters, + wide_mode=wide_mode, + progress_callback=update_progress, + ) + + if len(pk_values) == 1: + records = traverser.traverse( + table, pk_values[0], schema, self.config.max_depth + ) + else: + records = traverser.traverse_multiple( + table, pk_values, schema, self.config.max_depth + ) + pbar.set_description( + f"Traversing relationships ✓ {len(records)} records found" + ) + pbar.update(1) + + # Step 2: Sort by dependencies + pbar.set_description("Sorting dependencies ⠋") + sorter = DependencySorter() + sorted_records = sorter.sort(records) + pbar.set_description("Sorting dependencies ✓") + pbar.update(1) + + # Step 3: Generate SQL + pbar.set_description("Generating SQL ⠋") + generator = SQLGenerator( + introspector, batch_size=self.config.sql_batch_size + ) + sql = generator.generate_batch( + sorted_records, + keep_pks=keep_pks, + create_schema=create_schema, + database_name=self.config.db.database, + schema_name=schema, + ) + pbar.set_description("Generating SQL ✓") + pbar.update(1) + + # Step 4: Complete + pbar.set_description("Complete ✓") + pbar.update(1) + + # Display graph AFTER progress bar completes + if show_graph and self.show_progress: + from ..utils.graph_visualizer import GraphBuilder, GraphRenderer + + builder = GraphBuilder() + graph = builder.build(records, table, schema) + + renderer = GraphRenderer() + graph_output = renderer.render(graph) + + # Print to stderr with header + sys.stderr.write("\n") + sys.stderr.write("=== Relationship Graph ===\n") + sys.stderr.write(graph_output) + sys.stderr.write("\n\n") + sys.stderr.flush() + + # Collect tables involved + tables_involved = {record.identifier.table_name for record in sorted_records} + + return DumpResult( + sql_content=sql, + record_count=len(sorted_records), + tables_involved=tables_involved, + ) diff --git a/src/pgslice/dumper/sql_generator.py b/src/pgslice/dumper/sql_generator.py index 1964ca9..8227181 100644 --- a/src/pgslice/dumper/sql_generator.py +++ b/src/pgslice/dumper/sql_generator.py @@ -10,6 +10,7 @@ from ..db.schema import SchemaIntrospector from ..graph.models import RecordData, RecordIdentifier, Table from 
..utils.logging_config import get_logger +from .ddl_generator import DDLGenerator logger = get_logger(__name__) @@ -99,6 +100,9 @@ def generate_batch( records: list[RecordData], include_transaction: bool = True, keep_pks: bool = False, + create_schema: bool = False, + database_name: str | None = None, + schema_name: str = "public", ) -> str: """ Generate SQL for multiple records with proper ordering and bulk INSERTs. @@ -108,19 +112,29 @@ def generate_batch( include_transaction: Whether to wrap in BEGIN/COMMIT keep_pks: If True, keep original PK values (current behavior). If False, exclude auto-generated PKs and use PL/pgSQL remapping. + create_schema: If True, include DDL statements (CREATE DATABASE/SCHEMA/TABLE) + database_name: Database name for CREATE DATABASE statement (required if create_schema=True) + schema_name: Schema name for CREATE SCHEMA statement Returns: Complete SQL script """ if keep_pks: - return self._generate_batch_with_pks(records, include_transaction) + return self._generate_batch_with_pks( + records, include_transaction, create_schema, database_name, schema_name + ) else: return self._generate_batch_with_plpgsql_remapping( - records, include_transaction + records, include_transaction, create_schema, database_name, schema_name ) def _generate_batch_with_pks( - self, records: list[RecordData], include_transaction: bool + self, + records: list[RecordData], + include_transaction: bool, + create_schema: bool = False, + database_name: str | None = None, + schema_name: str = "public", ) -> str: """ Generate SQL with original PK values (current behavior). @@ -128,6 +142,9 @@ def _generate_batch_with_pks( Args: records: List of RecordData in dependency order include_transaction: Whether to wrap in BEGIN/COMMIT + create_schema: If True, include DDL statements + database_name: Database name for CREATE DATABASE (required if create_schema=True) + schema_name: Schema name for CREATE SCHEMA Returns: Complete SQL script with all bulk INSERT statements @@ -140,6 +157,20 @@ def _generate_batch_with_pks( sql_statements = [] + # Add DDL if requested + if create_schema and database_name: + ddl_generator = DDLGenerator(self.introspector) + + # Collect unique tables from records + unique_tables = { + (record.identifier.schema_name, record.identifier.table_name) + for record in records + } + + ddl = ddl_generator.generate_ddl(database_name, schema_name, unique_tables) + sql_statements.append(ddl) + sql_statements.append("") # Blank line separator + # Add header header = [ "-- Generated by pgslice", @@ -446,7 +477,12 @@ def _format_value( # ============================================================================ def _generate_batch_with_plpgsql_remapping( - self, records: list[RecordData], include_transaction: bool + self, + records: list[RecordData], + include_transaction: bool, + create_schema: bool = False, + database_name: str | None = None, + schema_name: str = "public", ) -> str: """ Generate PL/pgSQL script with ID remapping for auto-generated PKs. @@ -462,6 +498,13 @@ def _generate_batch_with_plpgsql_remapping( - Replace FK values with subqueries to lookup mapped IDs 6. 
Drop temp table + Args: + records: List of RecordData in dependency order + include_transaction: Whether to wrap in BEGIN/COMMIT (always True for PL/pgSQL) + create_schema: If True, include DDL statements + database_name: Database name for CREATE DATABASE (required if create_schema=True) + schema_name: Schema name for CREATE SCHEMA + Returns: PL/pgSQL script as string """ @@ -469,6 +512,22 @@ def _generate_batch_with_plpgsql_remapping( logger.info(f"Generating PL/pgSQL with ID remapping for {len(records)} records") + sql_statements = [] + + # Add DDL if requested + if create_schema and database_name: + ddl_generator = DDLGenerator(self.introspector) + + # Collect unique tables from records + unique_tables = { + (record.identifier.schema_name, record.identifier.table_name) + for record in records + } + + ddl = ddl_generator.generate_ddl(database_name, schema_name, unique_tables) + sql_statements.append(ddl) + sql_statements.append("") # Blank line separator + # 1. Deduplicate records (same as duplicate bug fix) seen_identifiers: set[RecordIdentifier] = set() unique_records: list[RecordData] = [] @@ -617,7 +676,14 @@ def _generate_batch_with_plpgsql_remapping( ] ) - result = "\n".join(sql_parts) + # Combine DDL (if any) with PL/pgSQL script + if sql_statements: + # sql_statements contains DDL, add PL/pgSQL parts after + sql_statements.extend(sql_parts) + result = "\n".join(sql_statements) + else: + result = "\n".join(sql_parts) + logger.info(f"Generated PL/pgSQL script ({len(result)} bytes)") return result diff --git a/src/pgslice/dumper/writer.py b/src/pgslice/dumper/writer.py index 4ac2a47..dd2f9f8 100644 --- a/src/pgslice/dumper/writer.py +++ b/src/pgslice/dumper/writer.py @@ -111,7 +111,13 @@ def write_to_stdout(sql_content: str) -> None: """ Write SQL content to stdout. + Uses sys.stdout directly to avoid print() buffering issues. + Ensures proper encoding for piping. + Args: sql_content: SQL script content """ - logger.debug("Writing SQL to stdout") + import sys + + sys.stdout.write(sql_content) + sys.stdout.flush() diff --git a/src/pgslice/graph/traverser.py b/src/pgslice/graph/traverser.py index 72c443b..757fe8b 100644 --- a/src/pgslice/graph/traverser.py +++ b/src/pgslice/graph/traverser.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import deque +from collections.abc import Callable from typing import Any import psycopg @@ -31,6 +32,7 @@ def __init__( visited_tracker: VisitedTracker, timeframe_filters: list[TimeframeFilter] | None = None, wide_mode: bool = False, + progress_callback: Callable[[int], None] | None = None, ) -> None: """ Initialize relationship traverser. @@ -43,6 +45,7 @@ def __init__( wide_mode: If True, follow incoming FKs from all records (wide/exploratory). If False (default), only follow incoming FKs from starting records and records reached via incoming FKs (strict mode, prevents fan-out). 
+ progress_callback: Optional callback invoked with record count after each fetch """ self.conn = connection self.introspector = schema_introspector @@ -50,6 +53,7 @@ def __init__( self.table_cache: dict[str, Table] = {} self.timeframe_filters = {f.table_name: f for f in (timeframe_filters or [])} self.wide_mode = wide_mode + self.progress_callback = progress_callback def traverse( self, @@ -121,6 +125,10 @@ def traverse( f"Fetched {record_id} at depth {depth} ({len(results)} total records)" ) + # Invoke progress callback with current record count + if self.progress_callback: + self.progress_callback(len(results)) + # Get table metadata table = self._get_table_metadata( record_id.schema_name, record_id.table_name @@ -210,6 +218,10 @@ def traverse_multiple( records = self.traverse(table_name, pk_value, schema, max_depth) all_records.update(records) + # Final progress callback with total unique records + if self.progress_callback: + self.progress_callback(len(all_records)) + logger.info( f"Multi-traversal complete: {len(all_records)} unique records found" ) diff --git a/src/pgslice/operations/__init__.py b/src/pgslice/operations/__init__.py new file mode 100644 index 0000000..df24a5c --- /dev/null +++ b/src/pgslice/operations/__init__.py @@ -0,0 +1,17 @@ +"""Shared operations for CLI and REPL.""" + +from __future__ import annotations + +from .dump_ops import DumpOptions, execute_dump +from .parsing import parse_truncate_filter, parse_truncate_filters +from .schema_ops import describe_table, list_tables, print_tables + +__all__ = [ + "parse_truncate_filter", + "parse_truncate_filters", + "list_tables", + "print_tables", + "describe_table", + "execute_dump", + "DumpOptions", +] diff --git a/src/pgslice/operations/dump_ops.py b/src/pgslice/operations/dump_ops.py new file mode 100644 index 0000000..72c4fec --- /dev/null +++ b/src/pgslice/operations/dump_ops.py @@ -0,0 +1,52 @@ +"""Dump operations shared by CLI and REPL.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from pgslice.config import AppConfig +from pgslice.db.connection import ConnectionManager +from pgslice.dumper.dump_service import DumpResult, DumpService +from pgslice.graph.models import TimeframeFilter + + +@dataclass +class DumpOptions: + """Options for dump operation.""" + + table: str + pk_values: list[str] + schema: str + wide_mode: bool = False + keep_pks: bool = False + create_schema: bool = False + timeframe_filters: list[TimeframeFilter] = field(default_factory=list) + show_progress: bool = False + + +def execute_dump( + conn_manager: ConnectionManager, + config: AppConfig, + options: DumpOptions, +) -> DumpResult: + """ + Execute dump operation. 
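+
+    Example (illustrative; ``conn_manager`` and ``config`` come from the caller):
+
+        opts = DumpOptions(table="users", pk_values=["42"], schema="public")
+        result = execute_dump(conn_manager, config, opts)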
+ + Args: + conn_manager: Database connection manager + config: Application configuration + options: Dump options + + Returns: + DumpResult with SQL content and metadata + """ + service = DumpService(conn_manager, config, show_progress=options.show_progress) + return service.dump( + table=options.table, + pk_values=options.pk_values, + schema=options.schema, + wide_mode=options.wide_mode, + keep_pks=options.keep_pks, + create_schema=options.create_schema, + timeframe_filters=options.timeframe_filters, + ) diff --git a/src/pgslice/operations/parsing.py b/src/pgslice/operations/parsing.py new file mode 100644 index 0000000..de3b69b --- /dev/null +++ b/src/pgslice/operations/parsing.py @@ -0,0 +1,75 @@ +"""Shared parsing utilities for CLI and REPL.""" + +from __future__ import annotations + +from datetime import datetime + +from pgslice.graph.models import TimeframeFilter +from pgslice.utils.exceptions import InvalidTimeframeError + + +def parse_truncate_filter(spec: str) -> TimeframeFilter: + """ + Parse truncate filter specification for related tables. + + Formats: + - table:column:start_date:end_date + - table:start_date:end_date (assumes 'created_at' column) + + Args: + spec: Truncate filter specification string + + Returns: + TimeframeFilter object + + Raises: + InvalidTimeframeError: If specification is invalid + """ + parts = spec.split(":") + + if len(parts) == 3: + # Format: table:start:end (assume created_at) + table_name, start_str, end_str = parts + column_name = "created_at" + elif len(parts) == 4: + # Format: table:column:start:end + table_name, column_name, start_str, end_str = parts + else: + raise InvalidTimeframeError( + f"Invalid truncate filter format: {spec}. " + "Expected: table:column:start:end or table:start:end" + ) + + # Parse dates + try: + start_date = datetime.fromisoformat(start_str) + except ValueError as e: + raise InvalidTimeframeError(f"Invalid start date: {start_str}") from e + + try: + end_date = datetime.fromisoformat(end_str) + except ValueError as e: + raise InvalidTimeframeError(f"Invalid end date: {end_str}") from e + + return TimeframeFilter( + table_name=table_name, + column_name=column_name, + start_date=start_date, + end_date=end_date, + ) + + +def parse_truncate_filters(specs: list[str] | None) -> list[TimeframeFilter]: + """ + Parse multiple truncate filter specifications for related tables. + + Args: + specs: List of truncate filter specification strings or None + + Returns: + List of TimeframeFilter objects + """ + if not specs: + return [] + + return [parse_truncate_filter(spec) for spec in specs] diff --git a/src/pgslice/operations/schema_ops.py b/src/pgslice/operations/schema_ops.py new file mode 100644 index 0000000..19b39a4 --- /dev/null +++ b/src/pgslice/operations/schema_ops.py @@ -0,0 +1,96 @@ +"""Schema introspection operations shared by CLI and REPL.""" + +from __future__ import annotations + +from printy import printy +from tabulate import tabulate + +from pgslice.db.connection import ConnectionManager +from pgslice.db.schema import SchemaIntrospector + + +def list_tables(conn_manager: ConnectionManager, schema: str) -> list[str]: + """ + List all tables in a schema. + + Args: + conn_manager: Database connection manager + schema: Schema name + + Returns: + List of table names + """ + conn = conn_manager.get_connection() + introspector = SchemaIntrospector(conn) + return introspector.get_all_tables(schema) + + +def print_tables(tables: list[str], schema: str) -> None: + """ + Print table list with formatting. 
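+
+    Example (illustrative):
+
+        print_tables(["actor", "film"], "public")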
+ + Args: + tables: List of table names + schema: Schema name (for display) + """ + printy(f"\n[c]Tables in schema '{schema}':@\n") + for table in tables: + printy(f" {table}") + printy(f"\n[g]Total: {len(tables)} tables@\n") + + +def describe_table( + conn_manager: ConnectionManager, schema: str, table_name: str +) -> None: + """ + Describe table structure and relationships. + + Args: + conn_manager: Database connection manager + schema: Schema name + table_name: Table to describe + """ + conn = conn_manager.get_connection() + introspector = SchemaIntrospector(conn) + table = introspector.get_table_metadata(schema, table_name) + + printy(f"\n[c]Table: {table.full_name}@\n") + + # Columns + printy("\n[cB]Columns@") + col_data = [] + for col in table.columns: + pk_indicator = "✓" if col.is_primary_key else "" + col_data.append( + [ + col.name, + col.data_type, + "YES" if col.nullable else "NO", + col.default or "", + pk_indicator, + ] + ) + table_str = tabulate( + col_data, + headers=["Name", "Type", "Nullable", "Default", "PK"], + tablefmt="simple", + ) + printy(table_str) + + # Primary keys + if table.primary_keys: + printy(f"\n[g]Primary Keys:@ {', '.join(table.primary_keys)}") + + # Foreign keys outgoing + if table.foreign_keys_outgoing: + printy("\n[y]Foreign Keys (Outgoing):@") + for fk in table.foreign_keys_outgoing: + printy(f" {fk.source_column} → {fk.target_table}.{fk.target_column}") + + # Foreign keys incoming + if table.foreign_keys_incoming: + printy("\n[b]Referenced By (Incoming):@") + for fk in table.foreign_keys_incoming: + printy(f" {fk.source_table}.{fk.source_column} → {fk.target_column}") + + printy() diff --git a/src/pgslice/repl.py b/src/pgslice/repl.py index dc354bb..cc04015 100644 --- a/src/pgslice/repl.py +++ b/src/pgslice/repl.py @@ -3,7 +3,6 @@ from __future__ import annotations import shlex -from datetime import datetime from pathlib import Path from printy import printy, raw_format @@ -15,13 +14,10 @@ from .cache.schema_cache import SchemaCache from .config import AppConfig from .db.connection import ConnectionManager -from .db.schema import SchemaIntrospector -from .dumper.dependency_sorter import DependencySorter -from .dumper.sql_generator import SQLGenerator +from .dumper.dump_service import DumpService from .dumper.writer import SQLWriter from .graph.models import TimeframeFilter -from .graph.traverser import RelationshipTraverser -from .graph.visited_tracker import VisitedTracker +from .operations import describe_table, list_tables, parse_truncate_filter, print_tables from .utils.exceptions import DBReverseDumpError, InvalidTimeframeError from .utils.logging_config import get_logger @@ -63,7 +59,7 @@ def __init__( } def start(self) -> None: - """Start the REPL.""" + """Start the interactive REPL.""" # Create prompt session with history history_file = Path.home() / ".pgslice_history" self.session = PromptSession( @@ -71,7 +67,7 @@ def start(self) -> None: completer=WordCompleter(list(self.commands.keys()), ignore_case=True), ) - printy("\n[cB]pgslice REPL@") + printy("[cB]pgslice REPL@") printy("Type 'help' for commands, 'exit' to quit\n") while True: @@ -111,20 +107,26 @@ def _cmd_dump(self, args: list[str]) -> None: """ Execute dump command. - Format: dump "table_name" pk_value[,pk_value,...] [--output file.sql] [--schema schema_name] [--timeframe "table:col:start:end"] [--wide] + Format: dump "table_name" pk_value[,pk_value,...] 
[--output file.sql] [--schema schema_name] [--truncate "table:col:start:end"] [--wide] """ if len(args) < 2: printy('[y]Usage: dump "table_name" pk_value [options]@') printy("\nOptions:") printy(" --output FILE Output file path") printy(" --schema SCHEMA Schema name (default: public)") - printy(" --timeframe SPEC Timeframe filter (table:column:start:end)") + printy( + " --truncate SPEC Truncate filter for related tables (table:column:start:end)" + ) printy( " --wide Wide mode: follow all relationships (default: strict)" ) printy( " --keep-pks Keep original primary key values (default: remap auto-generated PKs)" ) + printy( + " --create-schema Include CREATE DATABASE/SCHEMA/TABLE statements" + ) + printy(" --graph Display relationship graph after dump") return table_name = args[0] @@ -139,6 +141,8 @@ def _cmd_dump(self, args: list[str]) -> None: timeframe_specs: list[str] = [] wide_mode = False keep_pks = False # Default: remap auto-generated PKs + create_schema_ddl = self.config.create_schema # Default from config + show_graph = False i = 2 while i < len(args): @@ -148,7 +152,7 @@ def _cmd_dump(self, args: list[str]) -> None: elif args[i] == "--schema" and i + 1 < len(args): schema = args[i + 1] i += 2 - elif args[i] == "--timeframe" and i + 1 < len(args): + elif args[i] == "--truncate" and i + 1 < len(args): timeframe_specs.append(args[i + 1]) i += 2 elif args[i] == "--wide": @@ -157,71 +161,67 @@ def _cmd_dump(self, args: list[str]) -> None: elif args[i] == "--keep-pks": keep_pks = True i += 1 + elif args[i] == "--create-schema": + create_schema_ddl = True + i += 1 + elif args[i] == "--graph": + show_graph = True + i += 1 else: i += 1 - # Parse timeframe filters + # Parse timeframe filters using shared function timeframe_filters: list[TimeframeFilter] = [] for spec in timeframe_specs: try: - tf = self._parse_timeframe(spec) + tf = parse_truncate_filter(spec) timeframe_filters.append(tf) except InvalidTimeframeError as e: - printy(f"[r]Invalid timeframe: {e}@") + printy(f"[r]Invalid truncate filter: {e}@") return # Execute dump pk_display = ", ".join(str(pk) for pk in pk_values) mode_display = "wide" if wide_mode else "strict" printy( - f"\n[c]Dumping {schema}.{table_name} with PK(s): {pk_display} ({mode_display} mode)@" + f"\n [c]Dumping {schema}.{table_name} with PK(s): {pk_display} ({mode_display} mode)@\n" ) + # Wide mode warning + if wide_mode: + printy( + " [y]⚠ Note: Wide mode follows ALL relationships including self-referencing FKs.@" + ) + printy(" [y] This may take longer and fetch more data.@\n") + if timeframe_filters: - printy("\n[y]Timeframe filters:@") + printy(" [y]Truncate filters:@") for tf in timeframe_filters: - printy(f" - {tf}") + printy(f" - {tf}") + printy("") # Empty line after filters try: - # Get connection - conn = self.conn_manager.get_connection() - - # Create introspector - introspector = SchemaIntrospector(conn) - - # Create traverser - visited = VisitedTracker() - traverser = RelationshipTraverser( - conn, introspector, visited, timeframe_filters, wide_mode=wide_mode + # Use DumpService for the actual dump + # REPL always writes to files, so progress bar is safe to show + service = DumpService(self.conn_manager, self.config, show_progress=True) + result = service.dump( + table=table_name, + pk_values=pk_values, + schema=schema, + wide_mode=wide_mode, + keep_pks=keep_pks, + create_schema=create_schema_ddl, + timeframe_filters=timeframe_filters, + show_graph=show_graph, ) - # Traverse relationships - if len(pk_values) == 1: - records = traverser.traverse( - 
table_name, pk_values[0], schema, self.config.max_depth - ) - else: - records = traverser.traverse_multiple( - table_name, pk_values, schema, self.config.max_depth - ) - - printy(f"\n[g]Found {len(records)} related records@") - - # Sort by dependencies - sorter = DependencySorter() - sorted_records = sorter.sort(records) - - # Generate SQL - generator = SQLGenerator( - introspector, batch_size=self.config.sql_batch_size - ) - sql = generator.generate_batch(sorted_records, keep_pks=keep_pks) + printy(f"\n [g]✓ Found {result.record_count} related records@") # Output if output_file: - SQLWriter.write_to_file(sql, output_file) + SQLWriter.write_to_file(result.sql_content, output_file) printy( - f"[g]Wrote {len(sorted_records)} INSERT statements to {output_file}@" + f" [g]✓ Wrote {result.record_count} INSERT statements to {output_file}@\n" ) else: # Use default output path @@ -231,13 +231,13 @@ def _cmd_dump(self, args: list[str]) -> None: pk_values[0], # Use first PK for filename schema, ) - SQLWriter.write_to_file(sql, str(default_path)) + SQLWriter.write_to_file(result.sql_content, str(default_path)) printy( - f"[g]Wrote {len(sorted_records)} INSERT statements to {default_path}@" + f" [g]✓ Wrote {result.record_count} INSERT statements to {default_path}@\n" ) except DBReverseDumpError as e: - printy(f"[r]Error: {e}@") + printy(f"\n [r]Error: {e}@\n") except Exception as e: logger.exception("Error during dump") printy(f"[r]Unexpected error: {e}@") @@ -248,7 +248,7 @@ def _cmd_help(self, args: list[str]) -> None: help_data = [ [ "dump TABLE PK [options]", - "Extract a record and all related records\nOptions: --output FILE, --schema SCHEMA, --timeframe SPEC", + "Extract a record and all related records\nOptions: --output FILE, --schema SCHEMA, --truncate SPEC", ], ["tables [--schema SCHEMA]", "List all tables in the database"], ["describe TABLE [--schema]", "Show table structure and relationships"], @@ -269,7 +269,7 @@ def _cmd_help(self, args: list[str]) -> None: printy("\n[y]Examples:@") print(' dump "users" 42 --output user_42.sql') print(' dump "users" 42,123,456 --output users.sql') - print(' dump "users" 42 --timeframe "orders:created_at:2024-01-01:2024-12-31"') + print(' dump "users" 42 --truncate "orders:created_at:2024-01-01:2024-12-31"') print(" tables") print(' describe "users"') print() @@ -288,15 +288,8 @@ def _cmd_list_tables(self, args: list[str]) -> None: schema = args[1] try: - conn = self.conn_manager.get_connection() - introspector = SchemaIntrospector(conn) - tables = introspector.get_all_tables(schema) - - printy(f"\n[c]Tables in schema '{schema}':@\n") - for table in tables: - printy(f" {table}") - printy(f"\n[g]Total: {len(tables)} tables@\n") - + tables = list_tables(self.conn_manager, schema) + print_tables(tables, schema) except Exception as e: printy(f"[r]Error: {e}@") @@ -314,55 +307,7 @@ def _cmd_describe_table(self, args: list[str]) -> None: schema = args[2] try: - conn = self.conn_manager.get_connection() - introspector = SchemaIntrospector(conn) - table = introspector.get_table_metadata(schema, table_name) - - printy(f"\n[c]Table: {table.full_name}@\n") - - # Columns - printy("\n[cB]Columns@") - col_data = [] - for col in table.columns: - pk_indicator = "✓" if col.is_primary_key else "" - col_data.append( - [ - col.name, - col.data_type, - "YES" if col.nullable else "NO", - col.default or "", - pk_indicator, - ] - ) - table_str = tabulate( - col_data, - headers=["Name", "Type", "Nullable", "Default", "PK"], - tablefmt="simple", - ) - printy(table_str) - - # 
Primary keys - if table.primary_keys: - printy(f"\n[g]Primary Keys:@ {', '.join(table.primary_keys)}") - - # Foreign keys outgoing - if table.foreign_keys_outgoing: - printy("\n[y]Foreign Keys (Outgoing):@") - for fk in table.foreign_keys_outgoing: - printy( - f" {fk.source_column} → {fk.target_table}.{fk.target_column}" - ) - - # Foreign keys incoming - if table.foreign_keys_incoming: - printy("\n[b]Referenced By (Incoming):@") - for fk in table.foreign_keys_incoming: - printy( - f" {fk.source_table}.{fk.source_column} → {fk.target_column}" - ) - - printy() - + describe_table(self.conn_manager, schema, table_name) except Exception as e: printy(f"[r]Error: {e}@") @@ -378,52 +323,3 @@ def _cmd_clear_cache(self, args: list[str]) -> None: printy("[g]Cache cleared successfully@") else: printy("[y]Cache not initialized@") - - def _parse_timeframe(self, spec: str) -> TimeframeFilter: - """ - Parse timeframe specification. - - Format: table:column:start_date:end_date - Or: table:start_date:end_date (assumes 'created_at' column) - - Args: - spec: Timeframe specification string - - Returns: - TimeframeFilter object - - Raises: - InvalidTimeframeError: If specification is invalid - """ - parts = spec.split(":") - - if len(parts) == 3: - # Format: table:start:end (assume created_at) - table_name, start_str, end_str = parts - column_name = "created_at" - elif len(parts) == 4: - # Format: table:column:start:end - table_name, column_name, start_str, end_str = parts - else: - raise InvalidTimeframeError( - f"Invalid timeframe format: {spec}. " - "Expected: table:column:start:end or table:start:end" - ) - - # Parse dates - try: - start_date = datetime.fromisoformat(start_str) - except ValueError as e: - raise InvalidTimeframeError(f"Invalid start date: {start_str}") from e - - try: - end_date = datetime.fromisoformat(end_str) - except ValueError as e: - raise InvalidTimeframeError(f"Invalid end date: {end_str}") from e - - return TimeframeFilter( - table_name=table_name, - column_name=column_name, - start_date=start_date, - end_date=end_date, - ) diff --git a/src/pgslice/utils/graph_visualizer.py b/src/pgslice/utils/graph_visualizer.py new file mode 100644 index 0000000..0867914 --- /dev/null +++ b/src/pgslice/utils/graph_visualizer.py @@ -0,0 +1,267 @@ +"""Graph visualization for relationship traversal results.""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass + +from printy import raw + +from ..graph.models import RecordData + + +@dataclass +class TableNode: + """Represents a table in the relationship graph.""" + + table_name: str + schema_name: str + record_count: int + is_root: bool = False + + +@dataclass +class TableEdge: + """Represents a FK relationship between tables.""" + + source_table: str # Child table (has FK column) + target_table: str # Parent table (referenced by FK) + fk_column: str | None # FK column name (if available) + record_count: int # Number of records using this relationship + + +@dataclass +class TableGraph: + """Complete table-level graph structure.""" + + nodes: list[TableNode] + edges: list[TableEdge] + + +class GraphBuilder: + """Builds table-level graph from RecordData set.""" + + def build( + self, records: set[RecordData], root_table: str, root_schema: str + ) -> TableGraph: + """ + Build table-level graph from RecordData set. 
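+
+        Four passes over the record set: count records per (schema, table)
+        pair, aggregate each record's FK dependencies into per-edge counts,
+        emit TableNode objects (marking the root), and flatten the nested
+        edge map into TableEdge objects.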
+ + Args: + records: Set of all fetched records + root_table: Name of the starting table + root_schema: Schema of the starting table + + Returns: + TableGraph with nodes and edges + """ + # 1. Count records per table + table_counts: dict[tuple[str, str], int] = defaultdict(int) + for record in records: + key = (record.identifier.schema_name, record.identifier.table_name) + table_counts[key] += 1 + + # 2. Extract FK relationships from dependencies + edges: dict[tuple[str, str], dict[tuple[str, str], int]] = defaultdict( + lambda: defaultdict(int) + ) + for record in records: + source_key = ( + record.identifier.schema_name, + record.identifier.table_name, + ) + for dep in record.dependencies: + target_key = (dep.schema_name, dep.table_name) + # Count how many records use this FK relationship + edges[source_key][target_key] += 1 + + # 3. Create nodes (mark root) + nodes = [ + TableNode( + table_name=table, + schema_name=schema, + record_count=count, + is_root=(table == root_table and schema == root_schema), + ) + for (schema, table), count in table_counts.items() + ] + + # 4. Create edges + edge_list = [] + for (src_schema, src_table), targets in edges.items(): + for (tgt_schema, tgt_table), count in targets.items(): + edge_list.append( + TableEdge( + source_table=f"{src_schema}.{src_table}", + target_table=f"{tgt_schema}.{tgt_table}", + fk_column=None, # Can enhance later with FK column name + record_count=count, + ) + ) + + return TableGraph(nodes=nodes, edges=edge_list) + + +class GraphRenderer: + """Renders table graph as ASCII tree using Unicode box-drawing.""" + + # Unicode box-drawing characters + BRANCH = "├── " + PIPE = "│ " + LAST = "└── " + SPACE = " " + + def render(self, graph: TableGraph) -> str: + """ + Render graph as tree using Unicode box-drawing. + + Args: + graph: TableGraph to render + + Returns: + Formatted string with tree visualization + """ + # 1. Find root nodes + roots = [node for node in graph.nodes if node.is_root] + if not roots: + # Fallback: find nodes with no incoming edges + incoming = {edge.target_table for edge in graph.edges} + roots = [ + node + for node in graph.nodes + if f"{node.schema_name}.{node.table_name}" not in incoming + ] + + # Handle empty graph + if not roots: + if graph.nodes: + # No clear root, use first node + roots = [graph.nodes[0]] + else: + return "(No records found)" + + # 2. Build adjacency list (bidirectional to show full traversal) + children: dict[str, list[tuple[TableNode, TableEdge]]] = defaultdict(list) + for edge in graph.edges: + # Find both child and parent nodes + child_node = next( + ( + n + for n in graph.nodes + if f"{n.schema_name}.{n.table_name}" == edge.source_table + ), + None, + ) + parent_node = next( + ( + n + for n in graph.nodes + if f"{n.schema_name}.{n.table_name}" == edge.target_table + ), + None, + ) + + if child_node and parent_node: + # Add child to parent's list (reverse FK: parent <- child) + children[edge.target_table].append((child_node, edge)) + + # Add parent to child's list (forward FK: child -> parent) + # This allows showing full traversal tree + children[edge.source_table].append((parent_node, edge)) + + # 3. 
Render tree with DFS + lines: list[str] = [] + visited: set[str] = set() + + for root in roots: + self._render_node( + root, children, "", True, lines, visited, is_root=True, parent=None + ) + + # Handle single table with no relationships + if len(lines) == 1 and not children: + lines.append("(No related tables)") + + return "\n".join(lines) + + def _render_node( + self, + node: TableNode, + children: dict[str, list[tuple[TableNode, TableEdge]]], + prefix: str, + is_last: bool, + lines: list[str], + visited: set[str], + is_root: bool = False, + parent: TableNode | None = None, + ) -> None: + """ + Recursively render node and its children. + + Args: + node: Current node to render + children: Adjacency list of parent -> children + prefix: Current indentation prefix + is_last: Whether this is the last child of its parent + lines: Accumulator for output lines + visited: Set of already visited nodes (for cycle detection) + is_root: Whether this is a root node + parent: Parent node we came from (to avoid immediate back-references) + """ + full_name = f"{node.schema_name}.{node.table_name}" + + # Format current node with colors using printy + if is_root: + # Root nodes: cyan (bold) table name + yellow count + line = raw(f"[cB]{node.table_name}@ [y]({node.record_count} records)@") + else: + connector = self.LAST if is_last else self.BRANCH + # Tree structure: dark connectors + cyan table name + yellow count + line = ( + prefix + + raw(f"[n]{connector}@") + + raw(f"[c]{node.table_name}@ ") + + raw(f"[y]({node.record_count} records)@") + ) + + # Check if already shown (cycle detection) + if full_name in visited and not is_root: + line += raw(" [n][shown above]@") + lines.append(line) + return + + lines.append(line) + visited.add(full_name) + + # Render children (filter out immediate parent to avoid back-reference) + child_list = children.get(full_name, []) + + # Filter out the parent we just came from + if parent: + parent_name = f"{parent.schema_name}.{parent.table_name}" + child_list = [ + (child, edge) + for child, edge in child_list + if f"{child.schema_name}.{child.table_name}" != parent_name + ] + + for i, (child_node, _edge) in enumerate(child_list): + is_last_child = i == len(child_list) - 1 + + # Update prefix for child (with colored tree characters) + if is_root: + child_prefix = "" + elif is_last: + child_prefix = prefix + self.SPACE + else: + child_prefix = prefix + raw(f"[n]{self.PIPE}@") + + self._render_node( + child_node, + children, + child_prefix, + is_last_child, + lines, + visited, + parent=node, + ) diff --git a/src/pgslice/utils/logging_config.py b/src/pgslice/utils/logging_config.py index d692963..55af44e 100644 --- a/src/pgslice/utils/logging_config.py +++ b/src/pgslice/utils/logging_config.py @@ -1,18 +1,33 @@ """Logging configuration for pgslice.""" +from __future__ import annotations + import logging import sys -def setup_logging(log_level: str = "INFO") -> None: +def disable_logging() -> None: + """Disable all logging output.""" + logging.disable(logging.CRITICAL) + + +def setup_logging(log_level: str | None = None) -> None: """ Configure logging for the application. Args: - log_level: Logging level (DEBUG, INFO, WARNING, ERROR) + log_level: Logging level (DEBUG, INFO, WARNING, ERROR). + If None, logging is disabled entirely. 
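+
+    The console handler is attached to stderr so that SQL written to
+    stdout stays pipe-safe.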
""" + if log_level is None: + disable_logging() + return + level = getattr(logging, log_level.upper(), logging.INFO) + # Re-enable logging in case it was previously disabled + logging.disable(logging.NOTSET) + # Create formatter formatter = logging.Formatter( fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -27,8 +42,8 @@ def setup_logging(log_level: str = "INFO") -> None: for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - # Add console handler - console_handler = logging.StreamHandler(sys.stdout) + # Add console handler to stderr (not stdout, to avoid mixing with SQL output) + console_handler = logging.StreamHandler(sys.stderr) console_handler.setLevel(level) console_handler.setFormatter(formatter) root_logger.addHandler(console_handler) diff --git a/src/pgslice/utils/spinner.py b/src/pgslice/utils/spinner.py new file mode 100644 index 0000000..0650e49 --- /dev/null +++ b/src/pgslice/utils/spinner.py @@ -0,0 +1,48 @@ +"""Animated spinner utility for progress indication.""" + +from __future__ import annotations + +import time + + +class SpinnerAnimator: + """ + Animated spinner using Braille patterns. + + Rotates through spinner frames at a fixed rate (time-based) + to provide smooth animation regardless of callback frequency. + """ + + # Braille spinner frames for smooth rotation + FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + + def __init__(self, update_interval: float = 0.1) -> None: + """ + Initialize spinner animator. + + Args: + update_interval: Minimum time (in seconds) between frame updates. + Default: 0.1 seconds (100ms) for smooth animation. + """ + self.update_interval = update_interval + self._current_idx = 0 + self._last_update = time.time() + + def get_frame(self) -> str: + """ + Get current spinner frame and advance if enough time has passed. 
+ + Returns: + Current spinner character + """ + current_time = time.time() + if current_time - self._last_update >= self.update_interval: + self._current_idx = (self._current_idx + 1) % len(self.FRAMES) + self._last_update = current_time + + return self.FRAMES[self._current_idx] + + def reset(self) -> None: + """Reset spinner to initial state.""" + self._current_idx = 0 + self._last_update = time.time() diff --git a/tests/unit/dumper/test_ddl_generator.py b/tests/unit/dumper/test_ddl_generator.py new file mode 100644 index 0000000..c90999e --- /dev/null +++ b/tests/unit/dumper/test_ddl_generator.py @@ -0,0 +1,678 @@ +"""Tests for DDL generation.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from pgslice.dumper.ddl_generator import DDLGenerator +from pgslice.graph.models import Column, ForeignKey, Table + + +class TestDDLGenerator: + """Tests for DDLGenerator class.""" + + @pytest.fixture + def mock_introspector(self) -> MagicMock: + """Provide a mocked schema introspector.""" + return MagicMock() + + @pytest.fixture + def generator(self, mock_introspector: MagicMock) -> DDLGenerator: + """Provide a DDLGenerator instance.""" + return DDLGenerator(mock_introspector) + + +class TestGenerateDDL(TestDDLGenerator): + """Tests for generate_ddl method.""" + + def test_empty_tables_set( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should return empty string for empty table set.""" + result = generator.generate_ddl("testdb", "public", set()) + assert result == "" + + def test_generate_database_statement( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should generate commented CREATE DATABASE statement.""" + # Mock table metadata + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="users", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ) + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator.generate_ddl("mydb", "public", {("public", "users")}) + # CREATE DATABASE should be commented out by default (PostgreSQL doesn't support IF NOT EXISTS) + assert '-- CREATE DATABASE "mydb";' in result + + def test_generate_schema_statement( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should generate CREATE SCHEMA IF NOT EXISTS statement.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="users", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ) + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator.generate_ddl("mydb", "public", {("public", "users")}) + assert 'CREATE SCHEMA IF NOT EXISTS "public";' in result + + def test_multiple_schemas( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should generate CREATE SCHEMA for all unique schemas.""" + + def get_table_meta(schema: str, table: str) -> Table: + return Table( + schema_name=schema, + table_name=table, + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ) + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + 
mock_introspector.get_table_metadata.side_effect = get_table_meta + + result = generator.generate_ddl( + "mydb", "public", {("public", "users"), ("custom", "orders")} + ) + assert 'CREATE SCHEMA IF NOT EXISTS "public";' in result + assert 'CREATE SCHEMA IF NOT EXISTS "custom";' in result + + +class TestGenerateCreateTable(TestDDLGenerator): + """Tests for _generate_create_table method.""" + + def test_basic_table( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should generate basic CREATE TABLE statement.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="users", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ), + Column( + name="email", + data_type="text", + udt_name="text", + nullable=False, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator._generate_create_table("public", "users") + assert 'CREATE TABLE IF NOT EXISTS "public"."users"' in result + assert '"id" INTEGER' in result + assert '"email" TEXT NOT NULL' in result + + def test_serial_column( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should handle SERIAL columns correctly.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="users", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + default="nextval('users_id_seq'::regclass)", + ) + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator._generate_create_table("public", "users") + # SERIAL columns should not include PRIMARY KEY inline or DEFAULT nextval + assert '"id" INTEGER NOT NULL' in result + + def test_composite_primary_key( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should handle composite primary keys.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="user_roles", + columns=[ + Column( + name="user_id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + Column( + name="role_id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + ], + primary_keys=["user_id", "role_id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator._generate_create_table("public", "user_roles") + assert 'PRIMARY KEY ("user_id", "role_id")' in result + + def test_unique_constraints( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should include unique constraints.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="users", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ), + Column( + name="email", + data_type="text", + udt_name="text", + nullable=False, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + unique_constraints={"users_email_key": ["email"]}, + ) + + result = generator._generate_create_table("public", "users") + assert 'CONSTRAINT "users_email_key" UNIQUE ("email")' in result + + def test_default_values( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should include DEFAULT 
values.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="posts", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ), + Column( + name="status", + data_type="text", + udt_name="text", + nullable=False, + default="'draft'::text", + ), + Column( + name="created_at", + data_type="timestamp", + udt_name="timestamp", + nullable=False, + default="now()", + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator._generate_create_table("public", "posts") + assert "DEFAULT 'draft'::text" in result + assert "DEFAULT now()" in result + + +class TestColumnFormatting(TestDDLGenerator): + """Tests for _format_column_definition method.""" + + def test_array_types( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should format array types correctly.""" + col = Column( + name="tags", + data_type="ARRAY", + udt_name="_text", + nullable=True, + ) + + result = generator._format_column_definition(col) + assert '"tags" TEXT[]' in result + + def test_integer_array( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should format integer array types.""" + col = Column( + name="scores", + data_type="ARRAY", + udt_name="_int4", + nullable=True, + ) + + result = generator._format_column_definition(col) + assert '"scores" INTEGER[]' in result + + def test_json_types( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should handle JSON and JSONB types.""" + col_json = Column( + name="metadata", + data_type="json", + udt_name="json", + nullable=True, + ) + col_jsonb = Column( + name="data", + data_type="jsonb", + udt_name="jsonb", + nullable=True, + ) + + result_json = generator._format_column_definition(col_json) + result_jsonb = generator._format_column_definition(col_jsonb) + + assert '"metadata" JSON' in result_json + assert '"data" JSONB' in result_jsonb + + def test_numeric_types( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should handle various numeric types.""" + cols = [ + Column( + name="small_num", data_type="smallint", udt_name="int2", nullable=True + ), + Column(name="big_num", data_type="bigint", udt_name="int8", nullable=True), + Column( + name="decimal_num", + data_type="numeric", + udt_name="numeric", + nullable=True, + ), + Column(name="real_num", data_type="real", udt_name="float4", nullable=True), + Column( + name="double_num", + data_type="double precision", + udt_name="float8", + nullable=True, + ), + ] + + for col in cols: + result = generator._format_column_definition(col) + assert col.name in result + + +class TestForeignKeys(TestDDLGenerator): + """Tests for foreign key generation.""" + + def test_generate_foreign_keys( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should generate ALTER TABLE statements for foreign keys.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="orders", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ), + Column( + name="user_id", + data_type="integer", + udt_name="int4", + nullable=False, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[ + ForeignKey( + constraint_name="orders_user_id_fkey", + source_table="orders", + source_column="user_id", + 
target_table="users", + target_column="id", + on_delete="CASCADE", + ) + ], + foreign_keys_incoming=[], + ) + + result = generator._generate_foreign_key_statements("public", "orders") + assert 'ALTER TABLE "public"."orders"' in result + assert 'ADD CONSTRAINT "orders_user_id_fkey"' in result + assert 'FOREIGN KEY ("user_id")' in result + assert 'REFERENCES "public"."users"("id")' in result + assert "ON DELETE CASCADE" in result + + def test_no_foreign_keys( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should return empty string when no foreign keys.""" + mock_introspector.get_table_metadata.return_value = Table( + schema_name="public", + table_name="users", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ) + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + result = generator._generate_foreign_key_statements("public", "users") + assert result == "" + + +class TestTableDependencySorting(TestDDLGenerator): + """Tests for _sort_tables_by_dependencies method.""" + + def test_sort_simple_dependency( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should sort tables by dependencies - users before orders.""" + + def get_table_meta(schema: str, table: str) -> Table: + if table == "users": + return Table( + schema_name=schema, + table_name=table, + columns=[], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + else: # orders + return Table( + schema_name=schema, + table_name=table, + columns=[], + primary_keys=["id"], + foreign_keys_outgoing=[ + ForeignKey( + constraint_name="orders_user_id_fkey", + source_table="orders", + source_column="user_id", + target_table="users", + target_column="id", + ) + ], + foreign_keys_incoming=[], + ) + + mock_introspector.get_table_metadata.side_effect = get_table_meta + + result = generator._sort_tables_by_dependencies( + {("public", "orders"), ("public", "users")} + ) + + # Users should come before orders + users_idx = result.index(("public", "users")) + orders_idx = result.index(("public", "orders")) + assert users_idx < orders_idx + + def test_circular_dependencies( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should handle circular dependencies gracefully.""" + + def get_table_meta(schema: str, table: str) -> Table: + if table == "authors": + return Table( + schema_name=schema, + table_name=table, + columns=[], + primary_keys=["id"], + foreign_keys_outgoing=[ + ForeignKey( + constraint_name="authors_favorite_book_id_fkey", + source_table="authors", + source_column="favorite_book_id", + target_table="books", + target_column="id", + ) + ], + foreign_keys_incoming=[], + ) + else: # books + return Table( + schema_name=schema, + table_name=table, + columns=[], + primary_keys=["id"], + foreign_keys_outgoing=[ + ForeignKey( + constraint_name="books_author_id_fkey", + source_table="books", + source_column="author_id", + target_table="authors", + target_column="id", + ) + ], + foreign_keys_incoming=[], + ) + + mock_introspector.get_table_metadata.side_effect = get_table_meta + + # Should not raise an error + result = generator._sort_tables_by_dependencies( + {("public", "authors"), ("public", "books")} + ) + + # Should return all tables (order may vary) + assert len(result) == 2 + assert set(result) == {("public", "authors"), ("public", "books")} + + +class TestIdentifierQuoting(TestDDLGenerator): + 
"""Tests for _quote_identifier method.""" + + def test_simple_identifier(self, generator: DDLGenerator) -> None: + """Should quote simple identifiers.""" + assert generator._quote_identifier("users") == '"users"' + + def test_identifier_with_special_chars(self, generator: DDLGenerator) -> None: + """Should handle identifiers with special characters.""" + assert generator._quote_identifier("my-table") == '"my-table"' + assert generator._quote_identifier("table.name") == '"table.name"' + + def test_identifier_with_embedded_quotes(self, generator: DDLGenerator) -> None: + """Should escape embedded double quotes.""" + assert generator._quote_identifier('table"name') == '"table""name"' + + def test_reserved_words(self, generator: DDLGenerator) -> None: + """Should quote reserved words.""" + assert generator._quote_identifier("order") == '"order"' + assert generator._quote_identifier("select") == '"select"' + + +class TestTypeMapping(TestDDLGenerator): + """Tests for _map_postgresql_type method.""" + + def test_user_defined_types(self, generator: DDLGenerator) -> None: + """Should handle user-defined types (ENUMs).""" + result = generator._map_postgresql_type("USER-DEFINED", "status_enum") + assert result == "status_enum" + + def test_varchar_to_text(self, generator: DDLGenerator) -> None: + """Should map VARCHAR to TEXT.""" + result = generator._map_postgresql_type("character varying", "varchar") + assert result == "TEXT" + + def test_timestamp_types(self, generator: DDLGenerator) -> None: + """Should handle various timestamp types.""" + assert ( + generator._map_postgresql_type("timestamp without time zone", "timestamp") + == "TIMESTAMP" + ) + assert ( + generator._map_postgresql_type("timestamp with time zone", "timestamptz") + == "TIMESTAMPTZ" + ) + + def test_array_type_mapping(self, generator: DDLGenerator) -> None: + """Should map array types to type[].""" + result = generator._map_postgresql_type("ARRAY", "_text") + assert result == "TEXT[]" + + result = generator._map_postgresql_type("ARRAY", "_int4") + assert result == "INTEGER[]" + + +class TestIntegration(TestDDLGenerator): + """Integration tests for complete DDL generation.""" + + def test_full_ddl_generation( + self, generator: DDLGenerator, mock_introspector: MagicMock + ) -> None: + """Should generate complete DDL with database, schema, and tables.""" + + def get_table_meta(schema: str, table: str) -> Table: + if table == "users": + return Table( + schema_name=schema, + table_name=table, + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ), + Column( + name="email", + data_type="text", + udt_name="text", + nullable=False, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + else: # orders + return Table( + schema_name=schema, + table_name=table, + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_auto_generated=True, + is_primary_key=True, + ), + Column( + name="user_id", + data_type="integer", + udt_name="int4", + nullable=False, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[ + ForeignKey( + constraint_name="orders_user_id_fkey", + source_table="orders", + source_column="user_id", + target_table="users", + target_column="id", + ) + ], + foreign_keys_incoming=[], + ) + + mock_introspector.get_table_metadata.side_effect = get_table_meta + + result = generator.generate_ddl( + "testdb", "public", {("public", "users"), ("public", "orders")} 
+ ) + + # Should contain all parts + assert '-- CREATE DATABASE "testdb";' in result + assert 'CREATE SCHEMA IF NOT EXISTS "public"' in result + assert 'CREATE TABLE IF NOT EXISTS "public"."users"' in result + assert 'CREATE TABLE IF NOT EXISTS "public"."orders"' in result + assert 'ALTER TABLE "public"."orders"' in result + assert "FOREIGN KEY" in result diff --git a/tests/unit/dumper/test_sql_generator.py b/tests/unit/dumper/test_sql_generator.py index d17ecff..a96ddc7 100644 --- a/tests/unit/dumper/test_sql_generator.py +++ b/tests/unit/dumper/test_sql_generator.py @@ -1767,3 +1767,142 @@ def test_on_conflict_clause_with_fk_remapping( # Should include ON CONFLICT clause assert "ON CONFLICT" in sql assert "DO UPDATE SET" in sql + + +class TestCreateSchemaIntegration(TestSQLGenerator): + """Integration tests for create_schema flag.""" + + def test_generate_batch_with_create_schema_and_keep_pks( + self, mock_introspector: MagicMock + ) -> None: + """Should include DDL when create_schema=True with keep_pks=True.""" + generator = SQLGenerator(mock_introspector, batch_size=100) + + # Create test records + records = [ + RecordData( + identifier=RecordIdentifier( + table_name="users", schema_name="public", pk_values=("1",) + ), + data={ + "id": 1, + "name": "Alice", + "email": "alice@example.com", + "age": 30, + }, + dependencies=set(), + ), + ] + + # Generate with DDL + sql = generator.generate_batch( + records, + keep_pks=True, + create_schema=True, + database_name="testdb", + schema_name="public", + ) + + # Verify DDL statements are present + assert '-- CREATE DATABASE "testdb";' in sql + assert 'CREATE SCHEMA IF NOT EXISTS "public"' in sql + assert 'CREATE TABLE IF NOT EXISTS "public"."users"' in sql + + # Verify INSERT statements are also present + assert "INSERT INTO" in sql + assert "Alice" in sql + + def test_generate_batch_with_create_schema_and_plpgsql( + self, mock_introspector: MagicMock + ) -> None: + """Should include DDL when create_schema=True with PL/pgSQL remapping.""" + generator = SQLGenerator(mock_introspector, batch_size=100) + + # Create test records + records = [ + RecordData( + identifier=RecordIdentifier( + table_name="users", schema_name="public", pk_values=("1",) + ), + data={"name": "Bob", "email": "bob@example.com", "age": 25}, + dependencies=set(), + ), + ] + + # Generate with DDL and PL/pgSQL remapping + sql = generator.generate_batch( + records, + keep_pks=False, # Triggers PL/pgSQL mode + create_schema=True, + database_name="testdb", + schema_name="public", + ) + + # Verify DDL statements are present + assert '-- CREATE DATABASE "testdb";' in sql + assert 'CREATE SCHEMA IF NOT EXISTS "public"' in sql + assert 'CREATE TABLE IF NOT EXISTS "public"."users"' in sql + + # Verify PL/pgSQL block is also present + assert "DO $$" in sql or "CREATE TEMP TABLE" in sql + + def test_generate_batch_without_create_schema( + self, mock_introspector: MagicMock + ) -> None: + """Should NOT include DDL when create_schema=False (default).""" + generator = SQLGenerator(mock_introspector, batch_size=100) + + # Create test records + records = [ + RecordData( + identifier=RecordIdentifier( + table_name="users", schema_name="public", pk_values=("1",) + ), + data={ + "id": 1, + "name": "Charlie", + "email": "charlie@example.com", + "age": 35, + }, + dependencies=set(), + ), + ] + + # Generate without DDL + sql = generator.generate_batch(records, keep_pks=True, create_schema=False) + + # Verify DDL statements are NOT present + assert "CREATE DATABASE" not in sql + assert "CREATE SCHEMA" 
not in sql + assert "CREATE TABLE" not in sql + + # Verify INSERT statements are present + assert "INSERT INTO" in sql + assert "Charlie" in sql + + def test_create_schema_requires_database_name( + self, mock_introspector: MagicMock + ) -> None: + """Should not generate DDL if database_name is not provided.""" + generator = SQLGenerator(mock_introspector, batch_size=100) + + # Create test records + records = [ + RecordData( + identifier=RecordIdentifier( + table_name="users", schema_name="public", pk_values=("1",) + ), + data={"id": 1, "name": "Dave", "email": "dave@example.com", "age": 40}, + dependencies=set(), + ), + ] + + # Generate with create_schema=True but no database_name + sql = generator.generate_batch( + records, keep_pks=True, create_schema=True, database_name=None + ) + + # Verify DDL is not generated without database name + assert "CREATE DATABASE" not in sql + assert "CREATE SCHEMA" not in sql + assert "CREATE TABLE" not in sql diff --git a/tests/unit/graph/test_traverser_progress.py b/tests/unit/graph/test_traverser_progress.py new file mode 100644 index 0000000..5660ffe --- /dev/null +++ b/tests/unit/graph/test_traverser_progress.py @@ -0,0 +1,185 @@ +"""Tests for traverser progress callback functionality.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from pgslice.graph.models import RecordData +from pgslice.graph.traverser import RelationshipTraverser + + +class TestProgressCallback: + """Tests for progress callback parameter and invocation.""" + + def test_traverser_accepts_progress_callback(self) -> None: + """Should accept progress_callback parameter without errors.""" + mock_conn = MagicMock() + mock_introspector = MagicMock() + mock_visited = MagicMock() + mock_callback = MagicMock() + + traverser = RelationshipTraverser( + mock_conn, + mock_introspector, + mock_visited, + progress_callback=mock_callback, + ) + + assert traverser.progress_callback == mock_callback + + def test_progress_callback_none_does_not_error(self) -> None: + """Should work correctly when progress_callback is None.""" + mock_conn = MagicMock() + mock_introspector = MagicMock() + mock_visited = MagicMock() + + # Should not raise any errors + traverser = RelationshipTraverser( + mock_conn, + mock_introspector, + mock_visited, + progress_callback=None, + ) + + assert traverser.progress_callback is None + + def test_progress_callback_invoked_per_record(self, mocker: MagicMock) -> None: + """Should invoke callback after each record is fetched.""" + mock_conn = MagicMock() + mock_introspector = MagicMock() + mock_visited = MagicMock() + mock_callback = MagicMock() + + # Mock table metadata + mock_table = MagicMock() + mock_table.primary_keys = ["id"] + mock_table.foreign_keys_outgoing = [] + mock_table.foreign_keys_incoming = [] + mock_introspector.get_table_metadata.return_value = mock_table + + # Mock visited tracker + mock_visited.is_visited.return_value = False + mock_visited.mark_visited.return_value = None + + # Mock fetch_record to return a valid record + def mock_fetch(record_id): + return RecordData( + identifier=record_id, + data={"id": record_id.pk_values[0]}, + ) + + traverser = RelationshipTraverser( + mock_conn, + mock_introspector, + mock_visited, + progress_callback=mock_callback, + ) + + # Patch _fetch_record method + mocker.patch.object(traverser, "_fetch_record", side_effect=mock_fetch) + + # Traverse a single record + traverser.traverse("users", "1", "public", max_depth=0) + + # Callback should be invoked at least once + assert 
mock_callback.called + mock_callback.assert_called_with(1) + + def test_progress_callback_receives_correct_count(self, mocker: MagicMock) -> None: + """Should pass accurate record count to callback.""" + mock_conn = MagicMock() + mock_introspector = MagicMock() + mock_visited = MagicMock() + mock_callback = MagicMock() + + # Mock table metadata with no relationships + mock_table = MagicMock() + mock_table.primary_keys = ["id"] + mock_table.foreign_keys_outgoing = [] + mock_table.foreign_keys_incoming = [] + mock_introspector.get_table_metadata.return_value = mock_table + + # Track visited records + visited_records = set() + + def is_visited(record_id): + return record_id in visited_records + + def mark_visited(record_id): + visited_records.add(record_id) + + mock_visited.is_visited.side_effect = is_visited + mock_visited.mark_visited.side_effect = mark_visited + + # Mock fetch_record + def mock_fetch(record_id): + return RecordData( + identifier=record_id, + data={"id": record_id.pk_values[0]}, + ) + + traverser = RelationshipTraverser( + mock_conn, + mock_introspector, + mock_visited, + progress_callback=mock_callback, + ) + + mocker.patch.object(traverser, "_fetch_record", side_effect=mock_fetch) + + # Traverse + traverser.traverse("users", "1", "public", max_depth=0) + + # Should be called with count 1 (only the starting record) + mock_callback.assert_called_with(1) + + def test_traverse_multiple_updates_progress(self, mocker: MagicMock) -> None: + """Should invoke callback when traversing multiple starting records.""" + mock_conn = MagicMock() + mock_introspector = MagicMock() + mock_visited = MagicMock() + mock_callback = MagicMock() + + # Mock table metadata + mock_table = MagicMock() + mock_table.primary_keys = ["id"] + mock_table.foreign_keys_outgoing = [] + mock_table.foreign_keys_incoming = [] + mock_introspector.get_table_metadata.return_value = mock_table + + # Track visited records to avoid duplicates + visited_records = set() + + def is_visited(record_id): + return record_id in visited_records + + def mark_visited(record_id): + visited_records.add(record_id) + + mock_visited.is_visited.side_effect = is_visited + mock_visited.mark_visited.side_effect = mark_visited + + # Mock fetch_record + def mock_fetch(record_id): + return RecordData( + identifier=record_id, + data={"id": record_id.pk_values[0]}, + ) + + traverser = RelationshipTraverser( + mock_conn, + mock_introspector, + mock_visited, + progress_callback=mock_callback, + ) + + mocker.patch.object(traverser, "_fetch_record", side_effect=mock_fetch) + + # Traverse multiple records + traverser.traverse_multiple("users", ["1", "2", "3"], "public", max_depth=0) + + # Should be called multiple times (once per record + final callback) + assert mock_callback.call_count >= 3 + # Final call should have count of 3 unique records + final_call_arg = mock_callback.call_args[0][0] + assert final_call_arg == 3 diff --git a/tests/unit/operations/__init__.py b/tests/unit/operations/__init__.py new file mode 100644 index 0000000..8946b1c --- /dev/null +++ b/tests/unit/operations/__init__.py @@ -0,0 +1 @@ +"""Tests for shared operations module.""" diff --git a/tests/unit/operations/test_dump_ops.py b/tests/unit/operations/test_dump_ops.py new file mode 100644 index 0000000..2ff4441 --- /dev/null +++ b/tests/unit/operations/test_dump_ops.py @@ -0,0 +1,151 @@ +"""Tests for shared dump operations.""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import MagicMock, patch + +from pgslice.graph.models 
import TimeframeFilter +from pgslice.operations.dump_ops import DumpOptions, execute_dump + + +class TestDumpOptions: + """Tests for DumpOptions dataclass.""" + + def test_creates_with_required_fields(self) -> None: + """Should create options with required fields.""" + options = DumpOptions( + table="users", + pk_values=["1", "2"], + schema="public", + ) + + assert options.table == "users" + assert options.pk_values == ["1", "2"] + assert options.schema == "public" + + def test_has_correct_defaults(self) -> None: + """Should have correct default values.""" + options = DumpOptions( + table="users", + pk_values=["1"], + schema="public", + ) + + assert options.wide_mode is False + assert options.keep_pks is False + assert options.create_schema is False + assert options.timeframe_filters == [] + assert options.show_progress is False + + def test_accepts_all_options(self) -> None: + """Should accept all optional fields.""" + tf = TimeframeFilter( + table_name="orders", + column_name="created_at", + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 12, 31), + ) + + options = DumpOptions( + table="users", + pk_values=["1"], + schema="custom", + wide_mode=True, + keep_pks=True, + create_schema=True, + timeframe_filters=[tf], + show_progress=True, + ) + + assert options.wide_mode is True + assert options.keep_pks is True + assert options.create_schema is True + assert len(options.timeframe_filters) == 1 + assert options.show_progress is True + + +class TestExecuteDump: + """Tests for execute_dump function.""" + + def test_creates_dump_service_and_executes(self) -> None: + """Should create DumpService and execute dump.""" + mock_conn_manager = MagicMock() + mock_config = MagicMock() + mock_result = MagicMock() + + with patch("pgslice.operations.dump_ops.DumpService") as mock_service_class: + mock_service = MagicMock() + mock_service.dump.return_value = mock_result + mock_service_class.return_value = mock_service + + options = DumpOptions( + table="users", + pk_values=["42"], + schema="public", + ) + + result = execute_dump(mock_conn_manager, mock_config, options) + + # Verify DumpService was created correctly + mock_service_class.assert_called_once_with( + mock_conn_manager, mock_config, show_progress=False + ) + + # Verify dump was called with correct args + mock_service.dump.assert_called_once_with( + table="users", + pk_values=["42"], + schema="public", + wide_mode=False, + keep_pks=False, + create_schema=False, + timeframe_filters=[], + ) + + assert result == mock_result + + def test_passes_all_options_to_dump_service(self) -> None: + """Should pass all options to DumpService.dump().""" + mock_conn_manager = MagicMock() + mock_config = MagicMock() + + tf = TimeframeFilter( + table_name="orders", + column_name="created_at", + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 12, 31), + ) + + with patch("pgslice.operations.dump_ops.DumpService") as mock_service_class: + mock_service = MagicMock() + mock_service_class.return_value = mock_service + + options = DumpOptions( + table="users", + pk_values=["1", "2", "3"], + schema="custom", + wide_mode=True, + keep_pks=True, + create_schema=True, + timeframe_filters=[tf], + show_progress=True, + ) + + execute_dump(mock_conn_manager, mock_config, options) + + # Verify show_progress is passed to constructor + mock_service_class.assert_called_once_with( + mock_conn_manager, mock_config, show_progress=True + ) + + # Verify all options passed to dump + mock_service.dump.assert_called_once_with( + table="users", + pk_values=["1", "2", "3"], 
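+                # show_progress is consumed by the DumpService constructor
+                # above, so it must not be forwarded to dump() itself.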
+ schema="custom", + wide_mode=True, + keep_pks=True, + create_schema=True, + timeframe_filters=[tf], + ) diff --git a/tests/unit/operations/test_parsing.py b/tests/unit/operations/test_parsing.py new file mode 100644 index 0000000..8039ddf --- /dev/null +++ b/tests/unit/operations/test_parsing.py @@ -0,0 +1,100 @@ +"""Tests for shared parsing operations.""" + +from __future__ import annotations + +from datetime import datetime + +import pytest + +from pgslice.operations.parsing import parse_truncate_filter, parse_truncate_filters +from pgslice.utils.exceptions import InvalidTimeframeError + + +class TestParseTruncateFilter: + """Tests for parse_truncate_filter function.""" + + def test_parses_four_part_format(self) -> None: + """Should parse table:column:start:end format.""" + result = parse_truncate_filter("orders:created_at:2024-01-01:2024-12-31") + + assert result.table_name == "orders" + assert result.column_name == "created_at" + assert result.start_date == datetime(2024, 1, 1) + assert result.end_date == datetime(2024, 12, 31) + + def test_parses_three_part_format(self) -> None: + """Should parse table:start:end format with default column.""" + result = parse_truncate_filter("orders:2024-01-01:2024-12-31") + + assert result.table_name == "orders" + assert result.column_name == "created_at" # Default + assert result.start_date == datetime(2024, 1, 1) + assert result.end_date == datetime(2024, 12, 31) + + def test_raises_for_invalid_format(self) -> None: + """Should raise for invalid format.""" + with pytest.raises(InvalidTimeframeError) as exc_info: + parse_truncate_filter("invalid") + + assert "Invalid truncate filter format" in str(exc_info.value) + + def test_raises_for_too_many_parts(self) -> None: + """Should raise for too many parts.""" + with pytest.raises(InvalidTimeframeError) as exc_info: + parse_truncate_filter("a:b:c:d:e") + + assert "Invalid truncate filter format" in str(exc_info.value) + + def test_raises_for_invalid_start_date(self) -> None: + """Should raise for invalid start date.""" + with pytest.raises(InvalidTimeframeError) as exc_info: + parse_truncate_filter("orders:not-a-date:2024-12-31") + + assert "Invalid start date" in str(exc_info.value) + + def test_raises_for_invalid_end_date(self) -> None: + """Should raise for invalid end date.""" + with pytest.raises(InvalidTimeframeError) as exc_info: + parse_truncate_filter("orders:2024-01-01:not-a-date") + + assert "Invalid end date" in str(exc_info.value) + + +class TestParseTruncateFilters: + """Tests for parse_truncate_filters function.""" + + def test_returns_empty_for_none(self) -> None: + """Should return empty list for None.""" + result = parse_truncate_filters(None) + assert result == [] + + def test_returns_empty_for_empty_list(self) -> None: + """Should return empty list for empty list.""" + result = parse_truncate_filters([]) + assert result == [] + + def test_parses_single_filter(self) -> None: + """Should parse single filter.""" + result = parse_truncate_filters(["orders:2024-01-01:2024-12-31"]) + + assert len(result) == 1 + assert result[0].table_name == "orders" + + def test_parses_multiple_filters(self) -> None: + """Should parse multiple filters.""" + result = parse_truncate_filters( + [ + "orders:2024-01-01:2024-12-31", + "payments:paid_at:2024-06-01:2024-06-30", + ] + ) + + assert len(result) == 2 + assert result[0].table_name == "orders" + assert result[1].table_name == "payments" + assert result[1].column_name == "paid_at" + + def test_raises_for_invalid_filter_in_list(self) -> None: + 
"""Should raise if any filter is invalid.""" + with pytest.raises(InvalidTimeframeError): + parse_truncate_filters(["orders:2024-01-01:2024-12-31", "invalid"]) diff --git a/tests/unit/operations/test_schema_ops.py b/tests/unit/operations/test_schema_ops.py new file mode 100644 index 0000000..0add4ea --- /dev/null +++ b/tests/unit/operations/test_schema_ops.py @@ -0,0 +1,147 @@ +"""Tests for shared schema operations.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from pgslice.operations.schema_ops import describe_table, list_tables, print_tables + + +class TestListTables: + """Tests for list_tables function.""" + + def test_returns_table_list(self) -> None: + """Should return list of tables from introspector.""" + mock_conn_manager = MagicMock() + mock_conn = MagicMock() + mock_conn_manager.get_connection.return_value = mock_conn + + with patch("pgslice.operations.schema_ops.SchemaIntrospector") as mock_intro: + mock_introspector = MagicMock() + mock_introspector.get_all_tables.return_value = ["users", "orders"] + mock_intro.return_value = mock_introspector + + result = list_tables(mock_conn_manager, "public") + + assert result == ["users", "orders"] + mock_introspector.get_all_tables.assert_called_once_with("public") + + def test_uses_provided_schema(self) -> None: + """Should use the provided schema.""" + mock_conn_manager = MagicMock() + mock_conn = MagicMock() + mock_conn_manager.get_connection.return_value = mock_conn + + with patch("pgslice.operations.schema_ops.SchemaIntrospector") as mock_intro: + mock_introspector = MagicMock() + mock_introspector.get_all_tables.return_value = [] + mock_intro.return_value = mock_introspector + + list_tables(mock_conn_manager, "custom_schema") + + mock_introspector.get_all_tables.assert_called_once_with("custom_schema") + + +class TestPrintTables: + """Tests for print_tables function.""" + + def test_prints_formatted_output(self) -> None: + """Should print tables with formatting.""" + with patch("pgslice.operations.schema_ops.printy") as mock_printy: + print_tables(["users", "orders"], "public") + + # Check header was printed + calls = [str(call) for call in mock_printy.call_args_list] + assert any("Tables in schema 'public'" in call for call in calls) + assert any("Total: 2 tables" in call for call in calls) + + def test_handles_empty_list(self) -> None: + """Should handle empty table list.""" + with patch("pgslice.operations.schema_ops.printy") as mock_printy: + print_tables([], "public") + + calls = [str(call) for call in mock_printy.call_args_list] + assert any("Total: 0 tables" in call for call in calls) + + +class TestDescribeTable: + """Tests for describe_table function.""" + + @pytest.fixture + def mock_table(self) -> MagicMock: + """Create mock table metadata.""" + table = MagicMock() + table.full_name = "public.users" + table.primary_keys = ["id"] + + # Mock column + col = MagicMock() + col.name = "id" + col.data_type = "integer" + col.nullable = False + col.default = "nextval('users_id_seq')" + col.is_primary_key = True + table.columns = [col] + + # Mock FKs + table.foreign_keys_outgoing = [] + table.foreign_keys_incoming = [] + + return table + + def test_displays_table_structure(self, mock_table: MagicMock) -> None: + """Should display table structure.""" + mock_conn_manager = MagicMock() + mock_conn = MagicMock() + mock_conn_manager.get_connection.return_value = mock_conn + + with ( + patch("pgslice.operations.schema_ops.SchemaIntrospector") as mock_intro, + 
patch("pgslice.operations.schema_ops.printy") as mock_printy, + ): + mock_introspector = MagicMock() + mock_introspector.get_table_metadata.return_value = mock_table + mock_intro.return_value = mock_introspector + + describe_table(mock_conn_manager, "public", "users") + + # Check table name was printed + calls = [str(call) for call in mock_printy.call_args_list] + assert any("public.users" in call for call in calls) + assert any("Columns" in call for call in calls) + + def test_displays_foreign_keys(self, mock_table: MagicMock) -> None: + """Should display foreign key relationships.""" + # Add outgoing FK + fk_out = MagicMock() + fk_out.source_column = "role_id" + fk_out.target_table = "roles" + fk_out.target_column = "id" + mock_table.foreign_keys_outgoing = [fk_out] + + # Add incoming FK + fk_in = MagicMock() + fk_in.source_table = "orders" + fk_in.source_column = "user_id" + fk_in.target_column = "id" + mock_table.foreign_keys_incoming = [fk_in] + + mock_conn_manager = MagicMock() + mock_conn = MagicMock() + mock_conn_manager.get_connection.return_value = mock_conn + + with ( + patch("pgslice.operations.schema_ops.SchemaIntrospector") as mock_intro, + patch("pgslice.operations.schema_ops.printy") as mock_printy, + ): + mock_introspector = MagicMock() + mock_introspector.get_table_metadata.return_value = mock_table + mock_intro.return_value = mock_introspector + + describe_table(mock_conn_manager, "public", "users") + + calls = [str(call) for call in mock_printy.call_args_list] + assert any("Foreign Keys (Outgoing)" in call for call in calls) + assert any("Referenced By (Incoming)" in call for call in calls) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 6a0a1b0..155a484 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -3,12 +3,20 @@ from __future__ import annotations import sys +from datetime import datetime from pathlib import Path from unittest.mock import MagicMock, patch import pytest -from pgslice.cli import main +from pgslice.cli import ( + MainTableTimeframe, + main, + parse_main_timeframe, + run_describe_table, + run_list_tables, +) +from pgslice.utils.exceptions import InvalidTimeframeError class TestMain: @@ -89,6 +97,26 @@ def test_no_cache_flag(self) -> None: # Cache should be disabled assert mock_config.cache.enabled is False + def test_create_schema_flag(self) -> None: + """Should enable create_schema when --create-schema is used.""" + with ( + patch.object( + sys, "argv", ["pgslice", "--create-schema", "--host", "localhost"] + ), + patch("pgslice.cli.load_config") as mock_load, + ): + mock_config = MagicMock() + mock_config.db.host = "" + mock_config.db.user = "" + mock_config.db.database = "" + mock_config.create_schema = False + mock_load.return_value = mock_config + + main() + + # create_schema should be enabled + assert mock_config.create_schema is True + def test_cli_args_override_config(self) -> None: """CLI arguments should override config values.""" with ( @@ -271,3 +299,608 @@ def test_successful_repl_start(self) -> None: mock_cm_instance.close.assert_called_once() mock_creds_instance.clear.assert_called_once() assert exit_code == 0 + + +class TestCLIDumpMode: + """Tests for CLI dump mode (non-interactive).""" + + def test_table_without_pks_or_timeframe_fails( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Should fail when --table is provided without --pks or --timeframe.""" + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + 
"--table", + "users", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_load.return_value = mock_config + + exit_code = main() + assert exit_code == 1 + + captured = capsys.readouterr() + assert "--pks or --timeframe is required" in captured.err + + def test_cli_dump_mode_executes(self) -> None: + """Should execute dump in CLI mode when --table and --pks are provided.""" + from pgslice.dumper.dump_service import DumpResult + + mock_result = DumpResult( + sql_content="INSERT INTO users VALUES (1);", + record_count=1, + tables_involved={"users"}, + ) + + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--table", + "users", + "--pks", + "1", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch("pgslice.cli.DumpService") as mock_dump_service, + patch("pgslice.cli.SQLWriter") as mock_writer, + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_config.create_schema = False + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_cm.return_value = mock_cm_instance + + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance + + exit_code = main() + assert exit_code == 0 + + # DumpService.dump should have been called + mock_service_instance.dump.assert_called_once() + call_kwargs = mock_service_instance.dump.call_args[1] + assert call_kwargs["table"] == "users" + assert call_kwargs["pk_values"] == ["1"] + + # SQL should be written to stdout (no --output flag) + mock_writer.write_to_stdout.assert_called_once_with(mock_result.sql_content) + + def test_cli_dump_with_output_file(self, tmp_path: Path) -> None: + """Should write to file when --output is specified.""" + from pgslice.dumper.dump_service import DumpResult + + output_file = str(tmp_path / "output.sql") + mock_result = DumpResult( + sql_content="INSERT INTO users VALUES (1);", + record_count=1, + tables_involved={"users"}, + ) + + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--table", + "users", + "--pks", + "1", + "--output", + output_file, + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch("pgslice.cli.DumpService") as mock_dump_service, + patch("pgslice.cli.SQLWriter") as mock_writer, + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_config.create_schema = False + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_cm.return_value = mock_cm_instance + + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance + + 
exit_code = main() + assert exit_code == 0 + + # SQL should be written to file + mock_writer.write_to_file.assert_called_once() + call_args = mock_writer.write_to_file.call_args[0] + assert call_args[1] == output_file + + def test_cli_dump_with_flags(self) -> None: + """Should pass flags to DumpService correctly.""" + from pgslice.dumper.dump_service import DumpResult + + mock_result = DumpResult( + sql_content="INSERT...", + record_count=1, + tables_involved={"users"}, + ) + + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--table", + "users", + "--pks", + "1,2,3", + "--wide", + "--keep-pks", + "--create-schema", + "--truncate", + "orders:2024-01-01:2024-12-31", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch("pgslice.cli.DumpService") as mock_dump_service, + patch("pgslice.cli.SQLWriter"), + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_config.create_schema = False + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_cm.return_value = mock_cm_instance + + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance + + exit_code = main() + assert exit_code == 0 + + # Check that flags were passed correctly + call_kwargs = mock_service_instance.dump.call_args[1] + assert call_kwargs["pk_values"] == ["1", "2", "3"] + assert call_kwargs["wide_mode"] is True + assert call_kwargs["keep_pks"] is True + assert call_kwargs["create_schema"] is True + assert len(call_kwargs["timeframe_filters"]) == 1 + + +class TestLoggingDefault: + """Tests for logging disabled by default.""" + + def test_logging_disabled_by_default(self) -> None: + """Should disable logging when --log-level is not specified.""" + with ( + patch.object(sys, "argv", ["pgslice", "--clear-cache"]), + patch("pgslice.cli.setup_logging") as mock_setup, + patch("pgslice.cli.load_config") as mock_load, + ): + mock_config = MagicMock() + mock_config.cache.enabled = False + mock_load.return_value = mock_config + + main() + + # setup_logging should be called with None (disabled) + mock_setup.assert_called_with(None) + + +class TestParseMainTimeframe: + """Tests for parse_main_timeframe function.""" + + def test_parses_valid_format(self) -> None: + """Should parse column:start:end format.""" + result = parse_main_timeframe("created_at:2024-01-01:2024-12-31") + + assert isinstance(result, MainTableTimeframe) + assert result.column_name == "created_at" + assert result.start_date == datetime(2024, 1, 1) + assert result.end_date == datetime(2024, 12, 31) + + def test_raises_for_invalid_format(self) -> None: + """Should raise for invalid format.""" + with pytest.raises(InvalidTimeframeError, match="Invalid timeframe format"): + parse_main_timeframe("just_column") + + with pytest.raises(InvalidTimeframeError, match="Invalid timeframe format"): + parse_main_timeframe("a:b:c:d") + + def test_raises_for_invalid_start_date(self) -> None: + """Should raise for invalid start date.""" + with pytest.raises(InvalidTimeframeError, match="Invalid start date"): + parse_main_timeframe("created_at:invalid:2024-12-31") 
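
The cases in `TestParseMainTimeframe` (the end-date case follows just below) pin down the `--timeframe` parser's contract: a `column:start:end` string, ISO dates, and an `InvalidTimeframeError` with a distinct message per failure mode. As a reading aid, here is a minimal sketch consistent with these tests — a hypothetical reconstruction, not the shipped `pgslice.cli` implementation:

```python
# Hypothetical sketch; names mirror the assertions in TestParseMainTimeframe.
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime

from pgslice.utils.exceptions import InvalidTimeframeError


@dataclass
class MainTableTimeframe:
    column_name: str
    start_date: datetime
    end_date: datetime


def parse_main_timeframe(value: str) -> MainTableTimeframe:
    # Expect exactly column:start:end; assumes date-only values (no ":" inside).
    parts = value.split(":")
    if len(parts) != 3:
        raise InvalidTimeframeError(f"Invalid timeframe format: {value}")
    column, start, end = parts
    try:
        start_date = datetime.strptime(start, "%Y-%m-%d")
    except ValueError as exc:
        raise InvalidTimeframeError(f"Invalid start date: {start}") from exc
    try:
        end_date = datetime.strptime(end, "%Y-%m-%d")
    except ValueError as exc:
        raise InvalidTimeframeError(f"Invalid end date: {end}") from exc
    return MainTableTimeframe(
        column_name=column, start_date=start_date, end_date=end_date
    )
```
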
+ + def test_raises_for_invalid_end_date(self) -> None: + """Should raise for invalid end date.""" + with pytest.raises(InvalidTimeframeError, match="Invalid end date"): + parse_main_timeframe("created_at:2024-01-01:invalid") + + +class TestMainTableTimeframeCLI: + """Tests for main table timeframe CLI functionality.""" + + def test_mutual_exclusion_with_pks(self) -> None: + """Should not allow both --pks and --timeframe.""" + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--table", + "users", + "--pks", + "1", + "--timeframe", + "created_at:2024-01-01:2024-12-31", + ], + ), + pytest.raises(SystemExit) as exc_info, + ): + main() + # argparse exits with 2 for argument errors + assert exc_info.value.code == 2 + + def test_timeframe_mode_executes(self) -> None: + """Should execute dump with timeframe mode.""" + from pgslice.dumper.dump_service import DumpResult + + mock_result = DumpResult( + sql_content="INSERT INTO users VALUES (1);", + record_count=1, + tables_involved={"users"}, + ) + + mock_table_meta = MagicMock() + mock_table_meta.primary_keys = ["id"] + + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--table", + "users", + "--timeframe", + "created_at:2024-01-01:2024-12-31", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch("pgslice.cli.SchemaIntrospector") as mock_introspector, + patch("pgslice.cli.DumpService") as mock_dump_service, + patch("pgslice.cli.SQLWriter"), + patch("pgslice.cli.printy"), + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_config.create_schema = False + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [(1,), (2,), (3,)] + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + mock_cm_instance.get_connection.return_value = mock_conn + mock_cm.return_value = mock_cm_instance + + mock_introspector_instance = MagicMock() + mock_introspector_instance.get_table_metadata.return_value = mock_table_meta + mock_introspector.return_value = mock_introspector_instance + + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance + + exit_code = main() + assert exit_code == 0 + + # DumpService.dump should have been called with PKs from timeframe query + mock_service_instance.dump.assert_called_once() + call_kwargs = mock_service_instance.dump.call_args[1] + assert call_kwargs["pk_values"] == ["1", "2", "3"] + + def test_timeframe_mode_empty_result(self) -> None: + """Should handle empty result from timeframe query.""" + mock_table_meta = MagicMock() + mock_table_meta.primary_keys = ["id"] + + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--table", + "users", + "--timeframe", + "created_at:2024-01-01:2024-12-31", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch("pgslice.cli.SchemaIntrospector") as mock_introspector, + 
patch("pgslice.cli.printy") as mock_printy, + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_config.create_schema = False + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [] # Empty result + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + mock_cm_instance.get_connection.return_value = mock_conn + mock_cm.return_value = mock_cm_instance + + mock_introspector_instance = MagicMock() + mock_introspector_instance.get_table_metadata.return_value = mock_table_meta + mock_introspector.return_value = mock_introspector_instance + + exit_code = main() + assert exit_code == 0 + + # Should print warning about no records found + mock_printy.assert_any_call("[y]No records found matching the timeframe@") + + +class TestSchemaInfoFlags: + """Tests for --tables and --describe CLI flags.""" + + def test_tables_flag_lists_tables(self) -> None: + """Should list tables when --tables is used.""" + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--tables", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector, + patch("pgslice.operations.schema_ops.printy") as mock_printy, + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_cm.return_value = mock_cm_instance + + mock_introspector_instance = MagicMock() + mock_introspector_instance.get_all_tables.return_value = [ + "users", + "orders", + "products", + ] + mock_introspector.return_value = mock_introspector_instance + + exit_code = main() + assert exit_code == 0 + + # Should print each table and total + mock_printy.assert_any_call(" users") + mock_printy.assert_any_call(" orders") + mock_printy.assert_any_call(" products") + + def test_describe_flag_shows_table_info(self) -> None: + """Should describe table when --describe is used.""" + mock_table = MagicMock() + mock_table.full_name = "public.users" + mock_table.columns = [] + mock_table.primary_keys = ["id"] + mock_table.foreign_keys_outgoing = [] + mock_table.foreign_keys_incoming = [] + + with ( + patch.object( + sys, + "argv", + [ + "pgslice", + "--host", + "localhost", + "--user", + "test", + "--database", + "test", + "--describe", + "users", + ], + ), + patch("pgslice.cli.load_config") as mock_load, + patch("pgslice.cli.SecureCredentials"), + patch("pgslice.cli.ConnectionManager") as mock_cm, + patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector, + patch("pgslice.operations.schema_ops.printy") as mock_printy, + patch( + "pgslice.operations.schema_ops.tabulate", return_value="COLUMNS TABLE" + ), + ): + mock_config = MagicMock() + mock_config.db.host = "localhost" + mock_config.db.user = "test" + mock_config.db.database = "test" + 
mock_config.db.port = 5432 + mock_config.db.schema = "public" + mock_config.cache.enabled = False + mock_config.connection_ttl_minutes = 30 + mock_load.return_value = mock_config + + mock_cm_instance = MagicMock() + mock_cm.return_value = mock_cm_instance + + mock_introspector_instance = MagicMock() + mock_introspector_instance.get_table_metadata.return_value = mock_table + mock_introspector.return_value = mock_introspector_instance + + exit_code = main() + assert exit_code == 0 + + # Should print table name + mock_printy.assert_any_call("\n[c]Table: public.users@\n") + + def test_run_list_tables_function(self) -> None: + """Should return tables from introspector.""" + mock_conn_manager = MagicMock() + mock_conn = MagicMock() + mock_conn_manager.get_connection.return_value = mock_conn + + with ( + patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector, + patch("pgslice.operations.schema_ops.printy"), + ): + mock_introspector_instance = MagicMock() + mock_introspector_instance.get_all_tables.return_value = ["table1"] + mock_introspector.return_value = mock_introspector_instance + + result = run_list_tables(mock_conn_manager, "public") + assert result == 0 + + def test_run_describe_table_function(self) -> None: + """Should return table metadata from introspector.""" + mock_conn_manager = MagicMock() + mock_conn = MagicMock() + mock_conn_manager.get_connection.return_value = mock_conn + + mock_table = MagicMock() + mock_table.full_name = "public.users" + mock_table.columns = [] + mock_table.primary_keys = [] + mock_table.foreign_keys_outgoing = [] + mock_table.foreign_keys_incoming = [] + + with ( + patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector, + patch("pgslice.operations.schema_ops.printy"), + patch("pgslice.operations.schema_ops.tabulate", return_value=""), + ): + mock_introspector_instance = MagicMock() + mock_introspector_instance.get_table_metadata.return_value = mock_table + mock_introspector.return_value = mock_introspector_instance + + result = run_describe_table(mock_conn_manager, "public", "users") + assert result == 0 diff --git a/tests/unit/test_repl.py b/tests/unit/test_repl.py index b94c119..ee6fffa 100644 --- a/tests/unit/test_repl.py +++ b/tests/unit/test_repl.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections.abc import Generator -from datetime import datetime from pathlib import Path from unittest.mock import MagicMock, patch @@ -11,7 +10,6 @@ from pgslice.config import AppConfig, CacheConfig, DatabaseConfig from pgslice.repl import REPL -from pgslice.utils.exceptions import InvalidTimeframeError class TestREPL: @@ -154,12 +152,14 @@ def test_lists_tables_in_default_schema( self, repl: REPL, mock_connection_manager: MagicMock ) -> None: """Should list tables in default schema.""" - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: + with patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector: mock_instance = MagicMock() mock_instance.get_all_tables.return_value = ["users", "orders"] mock_introspector.return_value = mock_instance - with patch("pgslice.repl.printy"): + with patch("pgslice.operations.schema_ops.printy"): repl._cmd_list_tables([]) mock_instance.get_all_tables.assert_called_once_with("public") @@ -168,12 +168,14 @@ def test_lists_tables_with_custom_schema( self, repl: REPL, mock_connection_manager: MagicMock ) -> None: """Should list tables in custom schema.""" - with patch("pgslice.repl.SchemaIntrospector") as 
mock_introspector: + with patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector: mock_instance = MagicMock() mock_instance.get_all_tables.return_value = ["custom_table"] mock_introspector.return_value = mock_instance - with patch("pgslice.repl.printy"): + with patch("pgslice.operations.schema_ops.printy"): repl._cmd_list_tables(["--schema", "custom"]) mock_instance.get_all_tables.assert_called_once_with("custom") @@ -182,10 +184,15 @@ def test_handles_error( self, repl: REPL, mock_connection_manager: MagicMock ) -> None: """Should handle errors gracefully.""" - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: + with patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector: mock_introspector.side_effect = Exception("Connection error") - with patch("pgslice.repl.printy"): + with ( + patch("pgslice.repl.printy"), + patch("pgslice.operations.schema_ops.printy"), + ): # Should not raise repl._cmd_list_tables([]) @@ -246,14 +253,16 @@ def test_describes_table( ], ) - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: + with patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector: mock_instance = MagicMock() mock_instance.get_table_metadata.return_value = mock_table mock_introspector.return_value = mock_instance with ( - patch("pgslice.repl.printy"), - patch("pgslice.repl.tabulate", return_value=""), + patch("pgslice.operations.schema_ops.printy"), + patch("pgslice.operations.schema_ops.tabulate", return_value=""), ): repl._cmd_describe_table(["users"]) @@ -281,14 +290,16 @@ def test_describes_table_with_custom_schema( foreign_keys_incoming=[], ) - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: + with patch( + "pgslice.operations.schema_ops.SchemaIntrospector" + ) as mock_introspector: mock_instance = MagicMock() mock_instance.get_table_metadata.return_value = mock_table mock_introspector.return_value = mock_instance with ( - patch("pgslice.repl.printy"), - patch("pgslice.repl.tabulate", return_value=""), + patch("pgslice.operations.schema_ops.printy"), + patch("pgslice.operations.schema_ops.tabulate", return_value=""), ): repl._cmd_describe_table(["data", "--schema", "custom"]) @@ -327,251 +338,139 @@ def test_shows_usage_without_args(self, repl: REPL) -> None: def test_executes_dump( self, repl: REPL, mock_connection_manager: MagicMock, tmp_path: Path ) -> None: - """Should execute dump command.""" - from pgslice.graph.models import Column, RecordData, RecordIdentifier, Table - - mock_table = Table( - schema_name="public", - table_name="users", - columns=[ - Column( - name="id", - data_type="integer", - udt_name="int4", - nullable=False, - is_primary_key=True, - ), - ], - primary_keys=["id"], - foreign_keys_outgoing=[], - foreign_keys_incoming=[], - ) + """Should execute dump command using DumpService.""" + from pgslice.dumper.dump_service import DumpResult - mock_record = RecordData( - identifier=RecordIdentifier( - schema_name="public", - table_name="users", - pk_values=("42",), - ), - data={"id": 42}, + mock_result = DumpResult( + sql_content="INSERT INTO users (id) VALUES (42);", + record_count=1, + tables_involved={"users"}, ) - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: - mock_intro_instance = MagicMock() - mock_intro_instance.get_table_metadata.return_value = mock_table - mock_introspector.return_value = mock_intro_instance - - with patch("pgslice.repl.RelationshipTraverser") as mock_traverser: - mock_trav_instance = 
MagicMock() - mock_trav_instance.traverse.return_value = {mock_record} - mock_traverser.return_value = mock_trav_instance - - with patch("pgslice.repl.DependencySorter") as mock_sorter: - mock_sorter_instance = MagicMock() - mock_sorter_instance.sort.return_value = [mock_record] - mock_sorter.return_value = mock_sorter_instance - - with patch("pgslice.repl.SQLGenerator") as mock_generator: - mock_gen_instance = MagicMock() - mock_gen_instance.generate_batch.return_value = ( - "INSERT INTO users (id) VALUES (42);" - ) - mock_generator.return_value = mock_gen_instance + with ( + patch("pgslice.repl.DumpService") as mock_dump_service, + patch("pgslice.repl.SQLWriter") as mock_writer, + patch("pgslice.repl.printy"), + ): + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance - with patch("pgslice.repl.SQLWriter") as mock_writer: - mock_writer.get_default_output_path.return_value = ( - tmp_path / "users_42.sql" - ) + mock_writer.get_default_output_path.return_value = tmp_path / "users_42.sql" - with patch("pgslice.repl.printy"): - repl._cmd_dump(["users", "42"]) + repl._cmd_dump(["users", "42"]) - mock_trav_instance.traverse.assert_called_once() - mock_sorter_instance.sort.assert_called_once() - mock_gen_instance.generate_batch.assert_called_once() + mock_service_instance.dump.assert_called_once() + call_kwargs = mock_service_instance.dump.call_args[1] + assert call_kwargs["table"] == "users" + assert call_kwargs["pk_values"] == ["42"] def test_executes_dump_with_output_file( self, repl: REPL, mock_connection_manager: MagicMock, tmp_path: Path ) -> None: """Should execute dump with specified output file.""" - from pgslice.graph.models import Column, RecordData, RecordIdentifier, Table + from pgslice.dumper.dump_service import DumpResult - mock_table = Table( - schema_name="public", - table_name="users", - columns=[ - Column( - name="id", - data_type="integer", - udt_name="int4", - nullable=False, - is_primary_key=True, - ), - ], - primary_keys=["id"], - foreign_keys_outgoing=[], - foreign_keys_incoming=[], - ) - - mock_record = RecordData( - identifier=RecordIdentifier( - schema_name="public", - table_name="users", - pk_values=("42",), - ), - data={"id": 42}, + mock_result = DumpResult( + sql_content="INSERT INTO users (id) VALUES (42);", + record_count=1, + tables_involved={"users"}, ) output_file = str(tmp_path / "custom_output.sql") - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: - mock_intro_instance = MagicMock() - mock_intro_instance.get_table_metadata.return_value = mock_table - mock_introspector.return_value = mock_intro_instance - - with patch("pgslice.repl.RelationshipTraverser") as mock_traverser: - mock_trav_instance = MagicMock() - mock_trav_instance.traverse.return_value = {mock_record} - mock_traverser.return_value = mock_trav_instance - - with patch("pgslice.repl.DependencySorter") as mock_sorter: - mock_sorter_instance = MagicMock() - mock_sorter_instance.sort.return_value = [mock_record] - mock_sorter.return_value = mock_sorter_instance - - with patch("pgslice.repl.SQLGenerator") as mock_generator: - mock_gen_instance = MagicMock() - mock_gen_instance.generate_batch.return_value = ( - "INSERT INTO users (id) VALUES (42);" - ) - mock_generator.return_value = mock_gen_instance + with ( + patch("pgslice.repl.DumpService") as mock_dump_service, + patch("pgslice.repl.SQLWriter") as mock_writer, + patch("pgslice.repl.printy"), + ): + mock_service_instance = 
MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance - with patch("pgslice.repl.SQLWriter") as mock_writer: - with patch("pgslice.repl.printy"): - repl._cmd_dump(["users", "42", "--output", output_file]) + repl._cmd_dump(["users", "42", "--output", output_file]) - mock_writer.write_to_file.assert_called_once() - call_args = mock_writer.write_to_file.call_args - assert call_args[0][1] == output_file + mock_writer.write_to_file.assert_called_once() + call_args = mock_writer.write_to_file.call_args + assert call_args[0][1] == output_file def test_executes_dump_with_multiple_pks( self, repl: REPL, mock_connection_manager: MagicMock, tmp_path: Path ) -> None: """Should execute dump with multiple PKs.""" - from pgslice.graph.models import RecordData, RecordIdentifier + from pgslice.dumper.dump_service import DumpResult - mock_record = RecordData( - identifier=RecordIdentifier( - schema_name="public", - table_name="users", - pk_values=("42",), - ), - data={"id": 42}, + mock_result = DumpResult( + sql_content="INSERT...", + record_count=3, + tables_involved={"users"}, ) with ( - patch("pgslice.repl.SchemaIntrospector"), - patch("pgslice.repl.RelationshipTraverser") as mock_traverser, - patch("pgslice.repl.DependencySorter") as mock_sorter, - patch("pgslice.repl.SQLGenerator") as mock_generator, + patch("pgslice.repl.DumpService") as mock_dump_service, patch("pgslice.repl.SQLWriter") as mock_writer, patch("pgslice.repl.printy"), ): - mock_trav_instance = MagicMock() - mock_trav_instance.traverse_multiple.return_value = {mock_record} - mock_traverser.return_value = mock_trav_instance - - mock_sorter_instance = MagicMock() - mock_sorter_instance.sort.return_value = [mock_record] - mock_sorter.return_value = mock_sorter_instance - - mock_gen_instance = MagicMock() - mock_gen_instance.generate_batch.return_value = "INSERT..." 
- mock_generator.return_value = mock_gen_instance + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance mock_writer.get_default_output_path.return_value = tmp_path / "out.sql" repl._cmd_dump(["users", "42,43,44"]) - mock_trav_instance.traverse_multiple.assert_called_once() + call_kwargs = mock_service_instance.dump.call_args[1] + assert call_kwargs["pk_values"] == ["42", "43", "44"] def test_handles_wide_mode_flag( self, repl: REPL, mock_connection_manager: MagicMock, tmp_path: Path ) -> None: """Should handle --wide flag.""" - from pgslice.graph.models import RecordData, RecordIdentifier + from pgslice.dumper.dump_service import DumpResult - mock_record = RecordData( - identifier=RecordIdentifier( - schema_name="public", - table_name="users", - pk_values=("42",), - ), - data={"id": 42}, + mock_result = DumpResult( + sql_content="INSERT...", + record_count=1, + tables_involved={"users"}, ) with ( - patch("pgslice.repl.SchemaIntrospector"), - patch("pgslice.repl.RelationshipTraverser") as mock_traverser, - patch("pgslice.repl.DependencySorter") as mock_sorter, - patch("pgslice.repl.SQLGenerator") as mock_generator, + patch("pgslice.repl.DumpService") as mock_dump_service, patch("pgslice.repl.SQLWriter") as mock_writer, patch("pgslice.repl.printy"), ): - mock_trav_instance = MagicMock() - mock_trav_instance.traverse.return_value = {mock_record} - mock_traverser.return_value = mock_trav_instance - - mock_sorter_instance = MagicMock() - mock_sorter_instance.sort.return_value = [mock_record] - mock_sorter.return_value = mock_sorter_instance - - mock_gen_instance = MagicMock() - mock_gen_instance.generate_batch.return_value = "INSERT..." - mock_generator.return_value = mock_gen_instance + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance mock_writer.get_default_output_path.return_value = tmp_path / "out.sql" repl._cmd_dump(["users", "42", "--wide"]) - # Check that wide_mode=True was passed to traverser - call_args = mock_traverser.call_args - assert call_args[1]["wide_mode"] is True + # Check that wide_mode=True was passed to DumpService.dump() + call_kwargs = mock_service_instance.dump.call_args[1] + assert call_kwargs["wide_mode"] is True - def test_handles_timeframe_flag( + def test_handles_truncate_flag( self, repl: REPL, mock_connection_manager: MagicMock, tmp_path: Path ) -> None: - """Should handle --timeframe flag.""" - from pgslice.graph.models import RecordData, RecordIdentifier - - mock_record = RecordData( - identifier=RecordIdentifier( - schema_name="public", - table_name="users", - pk_values=("42",), - ), - data={"id": 42}, + """Should handle --truncate flag.""" + from pgslice.dumper.dump_service import DumpResult + + mock_result = DumpResult( + sql_content="INSERT...", + record_count=1, + tables_involved={"users"}, ) with ( - patch("pgslice.repl.SchemaIntrospector"), - patch("pgslice.repl.RelationshipTraverser") as mock_traverser, - patch("pgslice.repl.DependencySorter") as mock_sorter, - patch("pgslice.repl.SQLGenerator") as mock_generator, + patch("pgslice.repl.DumpService") as mock_dump_service, patch("pgslice.repl.SQLWriter") as mock_writer, patch("pgslice.repl.printy"), ): - mock_trav_instance = MagicMock() - mock_trav_instance.traverse.return_value = {mock_record} - mock_traverser.return_value = mock_trav_instance - - mock_sorter_instance = MagicMock() - 
mock_sorter_instance.sort.return_value = [mock_record] - mock_sorter.return_value = mock_sorter_instance - - mock_gen_instance = MagicMock() - mock_gen_instance.generate_batch.return_value = "INSERT..." - mock_generator.return_value = mock_gen_instance + mock_service_instance = MagicMock() + mock_service_instance.dump.return_value = mock_result + mock_dump_service.return_value = mock_service_instance mock_writer.get_default_output_path.return_value = tmp_path / "out.sql" @@ -579,20 +478,20 @@ def test_handles_timeframe_flag( [ "users", "42", - "--timeframe", + "--truncate", "orders:created_at:2024-01-01:2024-12-31", ] ) - # Check that timeframe was passed - call_args = mock_traverser.call_args - assert len(call_args[0][3]) == 1 # timeframe_filters + # Check that timeframe_filters was passed to DumpService.dump() + call_kwargs = mock_service_instance.dump.call_args[1] + assert len(call_kwargs["timeframe_filters"]) == 1 - def test_handles_invalid_timeframe(self, repl: REPL) -> None: - """Should handle invalid timeframe.""" + def test_handles_invalid_truncate(self, repl: REPL) -> None: + """Should handle invalid truncate filter.""" with patch("pgslice.repl.printy"): - # Invalid format - repl._cmd_dump(["users", "42", "--timeframe", "invalid"]) + # Invalid format - should not raise, just print error + repl._cmd_dump(["users", "42", "--truncate", "invalid"]) def test_handles_dump_error( self, repl: REPL, mock_connection_manager: MagicMock @@ -600,52 +499,16 @@ def test_handles_dump_error( """Should handle errors during dump.""" from pgslice.utils.exceptions import RecordNotFoundError - with patch("pgslice.repl.SchemaIntrospector") as mock_introspector: - mock_introspector.side_effect = RecordNotFoundError("Not found") - - with patch("pgslice.repl.printy"): - # Should not raise - repl._cmd_dump(["users", "42"]) - - -class TestParseTimeframe(TestREPL): - """Tests for _parse_timeframe method.""" - - def test_parses_four_part_format(self, repl: REPL) -> None: - """Should parse table:column:start:end format.""" - result = repl._parse_timeframe("orders:created_at:2024-01-01:2024-12-31") - - assert result.table_name == "orders" - assert result.column_name == "created_at" - assert result.start_date == datetime(2024, 1, 1) - assert result.end_date == datetime(2024, 12, 31) - - def test_parses_three_part_format(self, repl: REPL) -> None: - """Should parse table:start:end format with default column.""" - result = repl._parse_timeframe("orders:2024-01-01:2024-12-31") - - assert result.table_name == "orders" - assert result.column_name == "created_at" - assert result.start_date == datetime(2024, 1, 1) - assert result.end_date == datetime(2024, 12, 31) - - def test_raises_for_invalid_format(self, repl: REPL) -> None: - """Should raise for invalid format.""" - with pytest.raises(InvalidTimeframeError, match="Invalid timeframe format"): - repl._parse_timeframe("orders") - - with pytest.raises(InvalidTimeframeError, match="Invalid timeframe format"): - repl._parse_timeframe("a:b:c:d:e") - - def test_raises_for_invalid_start_date(self, repl: REPL) -> None: - """Should raise for invalid start date.""" - with pytest.raises(InvalidTimeframeError, match="Invalid start date"): - repl._parse_timeframe("orders:invalid:2024-12-31") + with ( + patch("pgslice.repl.DumpService") as mock_dump_service, + patch("pgslice.repl.printy"), + ): + mock_service_instance = MagicMock() + mock_service_instance.dump.side_effect = RecordNotFoundError("Not found") + mock_dump_service.return_value = mock_service_instance - def 
test_raises_for_invalid_end_date(self, repl: REPL) -> None: - """Should raise for invalid end date.""" - with pytest.raises(InvalidTimeframeError, match="Invalid end date"): - repl._parse_timeframe("orders:2024-01-01:invalid") + # Should not raise + repl._cmd_dump(["users", "42"]) class TestStart(TestREPL): diff --git a/tests/unit/utils/test_graph_visualizer.py b/tests/unit/utils/test_graph_visualizer.py new file mode 100644 index 0000000..c36472b --- /dev/null +++ b/tests/unit/utils/test_graph_visualizer.py @@ -0,0 +1,308 @@ +"""Tests for graph visualization utility.""" + +from __future__ import annotations + +from pgslice.graph.models import RecordData, RecordIdentifier +from pgslice.utils.graph_visualizer import ( + GraphBuilder, + GraphRenderer, + TableEdge, + TableGraph, + TableNode, +) + + +class TestGraphBuilder: + """Tests for GraphBuilder class.""" + + def test_single_table_no_dependencies(self) -> None: + """Should create graph with single node and no edges.""" + # Create single record with no dependencies + record_id = RecordIdentifier("users", "public", ("1",)) + record = RecordData(identifier=record_id, data={"id": 1}) + + builder = GraphBuilder() + graph = builder.build({record}, "users", "public") + + assert len(graph.nodes) == 1 + assert graph.nodes[0].table_name == "users" + assert graph.nodes[0].schema_name == "public" + assert graph.nodes[0].record_count == 1 + assert graph.nodes[0].is_root is True + assert len(graph.edges) == 0 + + def test_simple_parent_child(self) -> None: + """Should create graph with parent-child relationship.""" + # Create user record (parent) + user_id = RecordIdentifier("users", "public", ("1",)) + user = RecordData(identifier=user_id, data={"id": 1}) + + # Create order records (children) that depend on user + order1_id = RecordIdentifier("orders", "public", ("101",)) + order1 = RecordData( + identifier=order1_id, data={"id": 101, "user_id": 1}, dependencies={user_id} + ) + + order2_id = RecordIdentifier("orders", "public", ("102",)) + order2 = RecordData( + identifier=order2_id, data={"id": 102, "user_id": 1}, dependencies={user_id} + ) + + order3_id = RecordIdentifier("orders", "public", ("103",)) + order3 = RecordData( + identifier=order3_id, data={"id": 103, "user_id": 1}, dependencies={user_id} + ) + + builder = GraphBuilder() + graph = builder.build({user, order1, order2, order3}, "users", "public") + + # Should have 2 nodes (users, orders) + assert len(graph.nodes) == 2 + + # Find nodes + users_node = next(n for n in graph.nodes if n.table_name == "users") + orders_node = next(n for n in graph.nodes if n.table_name == "orders") + + assert users_node.record_count == 1 + assert users_node.is_root is True + assert orders_node.record_count == 3 + assert orders_node.is_root is False + + # Should have 1 edge (orders -> users) + assert len(graph.edges) == 1 + assert graph.edges[0].source_table == "public.orders" + assert graph.edges[0].target_table == "public.users" + + def test_record_counting(self) -> None: + """Should correctly count multiple records from same table.""" + # Create 5 user records + records = set() + for i in range(1, 6): + record_id = RecordIdentifier("users", "public", (str(i),)) + record = RecordData(identifier=record_id, data={"id": i}) + records.add(record) + + builder = GraphBuilder() + graph = builder.build(records, "users", "public") + + assert len(graph.nodes) == 1 + assert graph.nodes[0].record_count == 5 + + def test_edge_counting(self) -> None: + """Should count how many records use same FK relationship.""" + # 
Create 1 user + user_id = RecordIdentifier("users", "public", ("1",)) + user = RecordData(identifier=user_id, data={"id": 1}) + + # Create 10 orders all referencing same user + records = {user} + for i in range(1, 11): + order_id = RecordIdentifier("orders", "public", (str(i),)) + order = RecordData( + identifier=order_id, + data={"id": i, "user_id": 1}, + dependencies={user_id}, + ) + records.add(order) + + builder = GraphBuilder() + graph = builder.build(records, "users", "public") + + assert len(graph.edges) == 1 + assert graph.edges[0].record_count == 10 + + def test_multiple_tables_with_dependencies(self) -> None: + """Should handle complex graph with multiple tables.""" + # Create: users -> orders -> order_items + user_id = RecordIdentifier("users", "public", ("1",)) + user = RecordData(identifier=user_id, data={"id": 1}) + + order_id = RecordIdentifier("orders", "public", ("101",)) + order = RecordData( + identifier=order_id, data={"id": 101}, dependencies={user_id} + ) + + item_id = RecordIdentifier("order_items", "public", ("1001",)) + item = RecordData( + identifier=item_id, data={"id": 1001}, dependencies={order_id} + ) + + builder = GraphBuilder() + graph = builder.build({user, order, item}, "users", "public") + + assert len(graph.nodes) == 3 + assert len(graph.edges) == 2 + + # Verify edges + edge_sources = {e.source_table for e in graph.edges} + edge_targets = {e.target_table for e in graph.edges} + + assert "public.orders" in edge_sources + assert "public.order_items" in edge_sources + assert "public.users" in edge_targets + assert "public.orders" in edge_targets + + def test_non_root_table_when_root_not_in_results(self) -> None: + """Should mark is_root=False if root table not in results.""" + # Create order record but specify users as root + order_id = RecordIdentifier("orders", "public", ("101",)) + order = RecordData(identifier=order_id, data={"id": 101}) + + builder = GraphBuilder() + graph = builder.build({order}, "users", "public") + + assert len(graph.nodes) == 1 + assert graph.nodes[0].is_root is False # Not the specified root + + +class TestGraphRenderer: + """Tests for GraphRenderer class.""" + + def test_single_root_no_children(self) -> None: + """Should render single node without tree structure.""" + node = TableNode("users", "public", 1, is_root=True) + graph = TableGraph(nodes=[node], edges=[]) + + renderer = GraphRenderer() + output = renderer.render(graph) + + # Check for content (output now has ANSI color codes) + assert "users" in output + assert "1 records" in output + assert "(No related tables)" in output + + def test_linear_chain(self) -> None: + """Should render linear dependency chain.""" + # A -> B -> C + node_a = TableNode("table_a", "public", 1, is_root=True) + node_b = TableNode("table_b", "public", 1, is_root=False) + node_c = TableNode("table_c", "public", 1, is_root=False) + + edge_ab = TableEdge("public.table_b", "public.table_a", None, 1) + edge_bc = TableEdge("public.table_c", "public.table_b", None, 1) + + graph = TableGraph(nodes=[node_a, node_b, node_c], edges=[edge_ab, edge_bc]) + + renderer = GraphRenderer() + output = renderer.render(graph) + + lines = output.split("\n") + assert len(lines) == 3 + # Check for table names (output has color codes) + assert "table_a" in lines[0] and "1 records" in lines[0] + assert "table_b" in lines[1] and "1 records" in lines[1] + assert "table_c" in lines[2] and "1 records" in lines[2] + + def test_multiple_children(self) -> None: + """Should render multiple children with correct connectors.""" + # 
Root with 3 children + root = TableNode("users", "public", 1, is_root=True) + child1 = TableNode("orders", "public", 2, is_root=False) + child2 = TableNode("addresses", "public", 3, is_root=False) + child3 = TableNode("reviews", "public", 4, is_root=False) + + edges = [ + TableEdge("public.orders", "public.users", None, 2), + TableEdge("public.addresses", "public.users", None, 3), + TableEdge("public.reviews", "public.users", None, 4), + ] + + graph = TableGraph(nodes=[root, child1, child2, child3], edges=edges) + + renderer = GraphRenderer() + output = renderer.render(graph) + + # Should have root + 3 children (check for content) + assert "users" in output and "1 records" in output + assert "orders" in output and "2 records" in output + assert "addresses" in output and "3 records" in output + assert "reviews" in output and "4 records" in output + + # Should use branch characters (├── for non-last, └── for last) + assert "├──" in output # First two children + assert "└──" in output # Last child + + def test_circular_dependency_detection(self) -> None: + """Should detect and mark circular dependencies.""" + # A -> B -> A (circular) + node_a = TableNode("table_a", "public", 1, is_root=True) + node_b = TableNode("table_b", "public", 1, is_root=False) + + # Create circular edges + edge_ab = TableEdge("public.table_b", "public.table_a", None, 1) + edge_ba = TableEdge("public.table_a", "public.table_b", None, 1) + + graph = TableGraph(nodes=[node_a, node_b], edges=[edge_ab, edge_ba]) + + renderer = GraphRenderer() + output = renderer.render(graph) + + # Should mark cycle + assert "[shown above]" in output + + def test_unicode_characters(self) -> None: + """Should use Unicode box-drawing characters.""" + root = TableNode("users", "public", 1, is_root=True) + child = TableNode("orders", "public", 2, is_root=False) + edge = TableEdge("public.orders", "public.users", None, 2) + + graph = TableGraph(nodes=[root, child], edges=[edge]) + + renderer = GraphRenderer() + output = renderer.render(graph) + + # Should contain box-drawing chars (├──, │, └──) + # At least one should be present + has_unicode = any(char in output for char in ["├", "└", "│", "─"]) + assert has_unicode + + def test_empty_graph(self) -> None: + """Should handle empty graph gracefully.""" + graph = TableGraph(nodes=[], edges=[]) + + renderer = GraphRenderer() + output = renderer.render(graph) + + assert output == "(No records found)" + + def test_nested_hierarchy(self) -> None: + """Should render nested hierarchy with proper indentation.""" + # users -> orders -> order_items + users = TableNode("users", "public", 1, is_root=True) + orders = TableNode("orders", "public", 2, is_root=False) + items = TableNode("order_items", "public", 5, is_root=False) + + edges = [ + TableEdge("public.orders", "public.users", None, 2), + TableEdge("public.order_items", "public.orders", None, 5), + ] + + graph = TableGraph(nodes=[users, orders, items], edges=edges) + + renderer = GraphRenderer() + output = renderer.render(graph) + + lines = output.split("\n") + assert len(lines) == 3 + + # Check that all tables appear in output + assert "users" in output + assert "orders" in output + assert "order_items" in output + + # Check for tree structure characters (output has ANSI color codes) + assert "└" in output or "├" in output # Has tree structure + + def test_multiple_roots(self) -> None: + """Should handle multiple root nodes.""" + root1 = TableNode("users", "public", 1, is_root=True) + root2 = TableNode("products", "public", 2, is_root=True) + + graph 
= TableGraph(nodes=[root1, root2], edges=[]) + + renderer = GraphRenderer() + output = renderer.render(graph) + + # Both roots should appear (output has color codes) + assert "users" in output and "1 records" in output + assert "products" in output and "2 records" in output diff --git a/tests/unit/utils/test_logging_config.py b/tests/unit/utils/test_logging_config.py index 0699018..68ab3b8 100644 --- a/tests/unit/utils/test_logging_config.py +++ b/tests/unit/utils/test_logging_config.py @@ -44,11 +44,12 @@ def test_sets_error_level(self) -> None: root = logging.getLogger() assert root.level == logging.ERROR - def test_default_level_is_info(self) -> None: - """Default log level should be INFO.""" - setup_logging() - root = logging.getLogger() - assert root.level == logging.INFO + def test_default_disables_logging(self) -> None: + """Default (None) should disable logging entirely.""" + setup_logging() # None = disabled + # When disabled, logging.disable(CRITICAL) is called + # Check that log messages are suppressed + assert logging.root.manager.disable >= logging.CRITICAL def test_level_case_insensitive(self) -> None: """Log level should be case insensitive.""" @@ -68,7 +69,7 @@ def test_invalid_level_defaults_to_info(self) -> None: assert root.level == logging.INFO def test_adds_console_handler(self) -> None: - """Should add a console handler to stdout.""" + """Should add a console handler to stderr.""" setup_logging("INFO") root = logging.getLogger() @@ -76,9 +77,9 @@ def test_adds_console_handler(self) -> None: assert len(root.handlers) == 1 handler = root.handlers[0] - # Should be a StreamHandler + # Should be a StreamHandler to stderr (not stdout, to avoid mixing with SQL output) assert isinstance(handler, logging.StreamHandler) - assert handler.stream == sys.stdout + assert handler.stream == sys.stderr def test_handler_has_correct_level(self) -> None: """Handler should have the same level as configured.""" @@ -206,9 +207,10 @@ def test_logger_respects_level(self, capsys: pytest.CaptureFixture[str]) -> None logger.warning("Warning message") captured = capsys.readouterr() - assert "Debug message" not in captured.out - assert "Info message" not in captured.out - assert "Warning message" in captured.out + # Logs go to stderr now (not stdout) to avoid mixing with SQL output + assert "Debug message" not in captured.err + assert "Info message" not in captured.err + assert "Warning message" in captured.err def test_log_format_includes_timestamp( self, caplog: pytest.LogCaptureFixture diff --git a/tests/unit/utils/test_spinner.py b/tests/unit/utils/test_spinner.py new file mode 100644 index 0000000..ded4509 --- /dev/null +++ b/tests/unit/utils/test_spinner.py @@ -0,0 +1,122 @@ +"""Tests for spinner animation utility.""" + +from __future__ import annotations + +import time + +from pgslice.utils.spinner import SpinnerAnimator + + +class TestSpinnerAnimator: + """Tests for SpinnerAnimator class.""" + + def test_init_default_interval(self) -> None: + """Should initialize with default update interval.""" + spinner = SpinnerAnimator() + assert spinner.update_interval == 0.1 + + def test_init_custom_interval(self) -> None: + """Should initialize with custom update interval.""" + spinner = SpinnerAnimator(update_interval=0.05) + assert spinner.update_interval == 0.05 + + def test_get_frame_returns_valid_character(self) -> None: + """Should return a valid Braille spinner character.""" + spinner = SpinnerAnimator() + frame = spinner.get_frame() + assert frame in SpinnerAnimator.FRAMES + + def 
test_get_frame_starts_at_first_frame(self) -> None: + """Should start at the first frame.""" + spinner = SpinnerAnimator() + assert spinner.get_frame() == SpinnerAnimator.FRAMES[0] + + def test_get_frame_advances_after_interval(self) -> None: + """Should advance to next frame after update interval.""" + spinner = SpinnerAnimator(update_interval=0.01) # 10ms for fast test + first_frame = spinner.get_frame() + + # Wait for interval to pass + time.sleep(0.02) # 20ms to ensure we've passed the interval + + second_frame = spinner.get_frame() + assert second_frame != first_frame + assert second_frame == SpinnerAnimator.FRAMES[1] + + def test_get_frame_does_not_advance_before_interval(self) -> None: + """Should not advance before update interval.""" + spinner = SpinnerAnimator(update_interval=1.0) # 1 second + first_frame = spinner.get_frame() + + # Call immediately without waiting + second_frame = spinner.get_frame() + + assert second_frame == first_frame + + def test_get_frame_cycles_through_all_frames(self) -> None: + """Should cycle through all frames and wrap around.""" + spinner = SpinnerAnimator(update_interval=0.01) + + # Get first frame + frames_seen = [spinner.get_frame()] + + # Advance through all frames + a few more to test wrapping + for _ in range(len(SpinnerAnimator.FRAMES) + 2): + time.sleep(0.015) # Wait for interval + frames_seen.append(spinner.get_frame()) + + # Should have cycled through all frames + assert SpinnerAnimator.FRAMES[0] in frames_seen + assert SpinnerAnimator.FRAMES[-1] in frames_seen + # Should have wrapped around (first frame appears at start and after full cycle) + assert frames_seen.count(SpinnerAnimator.FRAMES[0]) >= 2 + + def test_reset_returns_to_first_frame(self) -> None: + """Should reset to first frame.""" + spinner = SpinnerAnimator(update_interval=0.01) + + # Advance a few frames + for _ in range(3): + time.sleep(0.015) + spinner.get_frame() + + # Reset + spinner.reset() + + # Should be back at first frame + assert spinner.get_frame() == SpinnerAnimator.FRAMES[0] + + def test_reset_updates_last_update_time(self) -> None: + """Should update last update time on reset.""" + spinner = SpinnerAnimator(update_interval=0.01) + + # Advance one frame + time.sleep(0.015) + spinner.get_frame() + + # Reset and immediately call get_frame + spinner.reset() + frame_after_reset = spinner.get_frame() + + # Should still be at first frame (not advanced immediately) + assert frame_after_reset == SpinnerAnimator.FRAMES[0] + + def test_frames_constant_contains_braille_patterns(self) -> None: + """Should have 10 Braille pattern frames.""" + assert len(SpinnerAnimator.FRAMES) == 10 + expected_frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + assert expected_frames == SpinnerAnimator.FRAMES + + def test_high_frequency_calls_respect_interval(self) -> None: + """Should respect update interval even with high frequency calls.""" + spinner = SpinnerAnimator(update_interval=0.1) + + # Call get_frame many times rapidly + frames = [] + for _ in range(50): + frames.append(spinner.get_frame()) + time.sleep(0.001) # 1ms between calls + + # Should have stayed on first frame for most calls + # (50ms total, so only about half the interval) + assert frames.count(SpinnerAnimator.FRAMES[0]) > 40 diff --git a/uv.lock b/uv.lock index 8f32897..a393558 100644 --- a/uv.lock +++ b/uv.lock @@ -379,6 +379,7 @@ dependencies = [ { name = "psycopg", extra = ["binary"] }, { name = "python-dotenv" }, { name = "tabulate" }, + { name = "tqdm" }, ] [package.optional-dependencies] @@ -412,6 +413,7 
@@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "tabulate", specifier = ">=0.9.0" }, + { name = "tqdm", specifier = ">=4.66.0" }, ] provides-extras = ["dev"] @@ -808,6 +810,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, ] +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"
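
For reference, `tests/unit/utils/test_spinner.py` above pins a precise contract for the new spinner: ten Braille frames, a default `update_interval` of 0.1s, interval-gated frame advancement, and a `reset()` that restarts both the frame index and the interval clock. A minimal sketch that satisfies those tests — hypothetical, the actual `pgslice.utils.spinner` module may differ:

```python
# Hypothetical sketch of a SpinnerAnimator consistent with the tests above.
from __future__ import annotations

import time


class SpinnerAnimator:
    """Cycle through Braille spinner frames, rate-limited by update_interval."""

    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    def __init__(self, update_interval: float = 0.1) -> None:
        self.update_interval = update_interval
        self._index = 0
        self._last_update = time.monotonic()

    def get_frame(self) -> str:
        # Advance only when the interval has elapsed, so callers can poll
        # get_frame() at any frequency without speeding up the animation.
        now = time.monotonic()
        if now - self._last_update >= self.update_interval:
            self._index = (self._index + 1) % len(self.FRAMES)
            self._last_update = now
        return self.FRAMES[self._index]

    def reset(self) -> None:
        # Return to the first frame and restart the interval clock, so the
        # next get_frame() call does not advance immediately.
        self._index = 0
        self._last_update = time.monotonic()
```

Using `time.monotonic()` rather than `time.time()` keeps frame pacing stable across wall-clock adjustments; any clock works here as long as it only moves forward.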