diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 22481a2..5e3b193 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -58,6 +58,10 @@ jobs: NEW_VERSION=$(grep '^version = ' pyproject.toml | cut -d'"' -f2) echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT + - name: Update lock file + run: | + uv lock + - name: Get commits since last tag id: get_commits run: | @@ -116,10 +120,10 @@ jobs: run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - git add pyproject.toml CHANGELOG.md + git add pyproject.toml uv.lock CHANGELOG.md git commit -m "Bump version to ${{ steps.bump_version.outputs.new_version }} - Update CHANGELOG.md with commits since last release" + Update CHANGELOG.md with commits since last release and sync lock file" - name: Create release branch id: create_branch diff --git a/Makefile b/Makefile index 6e062e5..f1b5c91 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,27 @@ clean: ## Remove build artifacts and cache show-version: ## Show current version from pyproject.toml @uv version +bump-patch: ## Bump patch version (0.1.0 -> 0.1.1) + uv version --bump patch + uv lock + @echo "✓ Version bumped to $$(uv version)" + @echo "Remember to commit: git add pyproject.toml uv.lock && git commit -m 'Bump version'" + +bump-minor: ## Bump minor version (0.1.0 -> 0.2.0) + uv version --bump minor + uv lock + @echo "✓ Version bumped to $$(uv version)" + @echo "Remember to commit: git add pyproject.toml uv.lock && git commit -m 'Bump version'" + +bump-major: ## Bump major version (0.1.0 -> 1.0.0) + uv version --bump major + uv lock + @echo "✓ Version bumped to $$(uv version)" + @echo "Remember to commit: git add pyproject.toml uv.lock && git commit -m 'Bump version'" + +lock: ## Update uv.lock file + uv lock + # Docker commands docker-build: ## Build Docker image docker build -t $(DOCKER_IMAGE) . diff --git a/src/pgslice/dumper/sql_generator.py b/src/pgslice/dumper/sql_generator.py index 9a27b07..e222d60 100644 --- a/src/pgslice/dumper/sql_generator.py +++ b/src/pgslice/dumper/sql_generator.py @@ -248,6 +248,28 @@ def _get_column_types(self, schema: str, table: str) -> dict[str, tuple[str, str } return self._column_type_cache[key] + def _quote_identifier(self, identifier: str) -> str: + """ + Quote a SQL identifier safely. + + Always uses double quotes to handle reserved words and special characters. + Escapes embedded double quotes. + + Args: + identifier: SQL identifier (table, column, schema name) + + Returns: + Quoted identifier + + Example: + "users" -> '"users"' + "references" -> '"references"' + 'col"name' -> '"col""name"' (escaped quote) + """ + # Escape embedded double quotes by doubling them + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + def _is_array_type(self, data_type: str) -> bool: """ Check if a PostgreSQL data type is an array type. @@ -972,11 +994,11 @@ def _generate_insert_with_fk_remapping( SELECT map0.new_id::integer, map1.new_id::integer, - data.last_update + data."last_update" FROM (VALUES ('20', '1', '2006-02-15T10:05:03'), ... - ) AS data(old_actor_id, old_film_id, last_update) + ) AS data("old_actor_id", "old_film_id", "last_update") JOIN _pgslice_id_map map0 ... JOIN _pgslice_id_map map1 ... """ @@ -1057,13 +1079,15 @@ def _generate_insert_with_fk_remapping( values_clause = ",\n".join(values_rows) # Create column aliases for the VALUES clause - # Example: data(old_actor_id, old_film_id, description, last_update) + # Example: data("old_actor_id", "old_film_id", "description", "last_update") data_column_aliases = [] for col in columns: if col in fk_to_remap: - data_column_aliases.append(f"old_{col}") + # Quote the prefixed alias for remapped FK columns + data_column_aliases.append(self._quote_identifier(f"old_{col}")) else: - data_column_aliases.append(col) + # Quote regular column names to handle reserved keywords + data_column_aliases.append(self._quote_identifier(col)) # Get table metadata for column types table_meta = self.introspector.get_table_metadata(schema, table) @@ -1096,7 +1120,7 @@ def _generate_insert_with_fk_remapping( join_clauses.append( f" JOIN _pgslice_id_map {alias}\n" f" ON {alias}.table_name = '{target_full}'\n" - f" AND {alias}.old_id = data.old_{col}" + f" AND {alias}.old_id = data.{self._quote_identifier(f'old_{col}')}" ) join_index += 1 else: @@ -1116,10 +1140,12 @@ def _generate_insert_with_fk_remapping( element_type = self._get_array_element_type(col_meta.udt_name) pg_type = f"{element_type}[]" - select_parts.append(f"data.{col}::{pg_type}") + select_parts.append( + f"data.{self._quote_identifier(col)}::{pg_type}" + ) else: # Fallback if column metadata not found - select_parts.append(f"data.{col}") + select_parts.append(f"data.{self._quote_identifier(col)}") select_clause = ",\n ".join(select_parts) join_clause = "\n".join(join_clauses) diff --git a/src/pgslice/graph/traverser.py b/src/pgslice/graph/traverser.py index 8e26bdc..d89faa9 100644 --- a/src/pgslice/graph/traverser.py +++ b/src/pgslice/graph/traverser.py @@ -54,6 +54,9 @@ def __init__( self.timeframe_filters = {f.table_name: f for f in (timeframe_filters or [])} self.wide_mode = wide_mode self.fetch_batch_size = fetch_batch_size + self.starting_table: str | None = ( + None # Track starting table for timeframe filtering + ) def traverse( self, @@ -87,6 +90,9 @@ def traverse( Raises: RecordNotFoundError: If starting record doesn't exist """ + # Track starting table for timeframe filtering (only apply to starting table) + self.starting_table = table_name + start_id = self._create_record_identifier(schema, table_name, (pk_value,)) queue: deque[tuple[RecordIdentifier, int, bool]] = deque([(start_id, 0, True)]) results: set[RecordData] = set() @@ -257,6 +263,9 @@ def traverse_multiple( Returns: Set of all discovered RecordData objects """ + # Track starting table for timeframe filtering (only apply to starting table) + self.starting_table = table_name + # Edge case: empty pk_values if not pk_values: logger.info("No primary keys provided for traversal") @@ -511,9 +520,12 @@ def _fetch_record(self, record_id: RecordIdentifier) -> RecordData: where_parts.append(f'"{pk_col}" = %s') params.append(pk_val) - # Apply timeframe filter if applicable + # Apply timeframe filter only to starting table timeframe_clause = "" - if record_id.table_name in self.timeframe_filters: + if ( + record_id.table_name in self.timeframe_filters + and record_id.table_name == self.starting_table + ): filter_config = self.timeframe_filters[record_id.table_name] timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s' params.extend([filter_config.start_date, filter_config.end_date]) @@ -568,16 +580,33 @@ def _fetch_records_batch( logger.warning(f"Table {schema}.{table} has no primary key, skipping") continue - # Build WHERE clause: WHERE id IN (1, 2, 3) - pk_col = table_metadata.primary_keys[0] - pk_values = [rid.pk_values[0] for rid in table_record_ids] - placeholders = ", ".join(["%s"] * len(pk_values)) + pk_cols = table_metadata.primary_keys - # Apply timeframe filter if applicable + # Apply timeframe filter only to starting table timeframe_clause = "" - params: list[Any] = pk_values.copy() + params: list[Any] = [] + + # Build WHERE clause for composite or single PK + if len(pk_cols) == 1: + # Single column PK: WHERE id IN (1, 2, 3) + pk_col = pk_cols[0] + pk_values = [rid.pk_values[0] for rid in table_record_ids] + placeholders = ", ".join(["%s"] * len(pk_values)) + where_clause = f'"{pk_col}" IN ({placeholders})' + params.extend(pk_values) + else: + # Composite PK: WHERE (col1, col2) IN ((1, 2), (3, 4)) + pk_columns = ", ".join([f'"{col}"' for col in pk_cols]) + pk_tuples = [rid.pk_values for rid in table_record_ids] + tuple_placeholders = ", ".join( + ["(" + ", ".join(["%s"] * len(pk_cols)) + ")"] * len(pk_tuples) + ) + where_clause = f"({pk_columns}) IN ({tuple_placeholders})" + # Flatten the tuples into a single list of params + for pk_tuple in pk_tuples: + params.extend(pk_tuple) - if table in self.timeframe_filters: + if table in self.timeframe_filters and table == self.starting_table: filter_config = self.timeframe_filters[table] timeframe_clause = ( f' AND "{filter_config.column_name}" BETWEEN %s AND %s' @@ -586,7 +615,7 @@ def _fetch_records_batch( query = f""" SELECT * FROM "{schema}"."{table}" - WHERE "{pk_col}" IN ({placeholders}){timeframe_clause} + WHERE {where_clause}{timeframe_clause} """ with self.conn.cursor() as cur: @@ -596,9 +625,10 @@ def _fetch_records_batch( for row in rows: data = dict(zip(columns, row, strict=False)) - pk_value = data[pk_col] + # Extract all PK values for composite keys + record_pk_values = tuple(data[col] for col in pk_cols) record_id = self._create_record_identifier( - schema, table, (pk_value,) + schema, table, record_pk_values ) results[record_id] = RecordData(identifier=record_id, data=data) @@ -664,11 +694,11 @@ def _find_referencing_records( # Build query pk_columns = ", ".join(f'"{pk}"' for pk in source_table.primary_keys) - # Apply timeframe filter if applicable + # Apply timeframe filter only to starting table timeframe_clause = "" params: list[Any] = [target_pk_value] - if table in self.timeframe_filters: + if table in self.timeframe_filters and table == self.starting_table: filter_config = self.timeframe_filters[table] timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s' params.extend([filter_config.start_date, filter_config.end_date]) @@ -733,11 +763,11 @@ def _find_referencing_records_batch( pk_columns = ", ".join(f'"{pk}"' for pk in source_table.primary_keys) placeholders = ", ".join(["%s"] * len(target_pk_values)) - # Apply timeframe filter if applicable + # Apply timeframe filter only to starting table timeframe_clause = "" params: list[Any] = target_pk_values.copy() - if table in self.timeframe_filters: + if table in self.timeframe_filters and table == self.starting_table: filter_config = self.timeframe_filters[table] timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s' params.extend([filter_config.start_date, filter_config.end_date]) diff --git a/tests/unit/dumper/test_sql_generator.py b/tests/unit/dumper/test_sql_generator.py index a96ddc7..323930b 100644 --- a/tests/unit/dumper/test_sql_generator.py +++ b/tests/unit/dumper/test_sql_generator.py @@ -1906,3 +1906,226 @@ def test_create_schema_requires_database_name( assert "CREATE DATABASE" not in sql assert "CREATE SCHEMA" not in sql assert "CREATE TABLE" not in sql + + +class TestReservedKeywordColumns(TestSQLGenerator): + """Test that PostgreSQL reserved keywords are properly quoted in generated SQL.""" + + @pytest.fixture + def table_with_references_column(self) -> Table: + """Create a table with 'references' column (reserved keyword).""" + return Table( + schema_name="public", + table_name="shipments_shipment", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + Column( + name="reference_id", + data_type="character varying", + udt_name="varchar", + nullable=True, + ), + Column( + name="references", # Reserved keyword! + data_type="jsonb", + udt_name="jsonb", + nullable=True, + ), + Column( + name="state_id", + data_type="character varying", + udt_name="varchar", + nullable=True, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + def test_quote_identifier_simple(self, generator: SQLGenerator) -> None: + """_quote_identifier should properly quote identifiers.""" + assert generator._quote_identifier("column_name") == '"column_name"' + assert generator._quote_identifier("references") == '"references"' + assert generator._quote_identifier("user") == '"user"' + assert generator._quote_identifier("group") == '"group"' + + def test_quote_identifier_with_embedded_quotes( + self, generator: SQLGenerator + ) -> None: + """_quote_identifier should escape embedded double quotes.""" + assert generator._quote_identifier('col"name') == '"col""name"' + assert generator._quote_identifier('my"table"name') == '"my""table""name"' + + def test_insert_with_references_column( + self, mock_introspector: MagicMock, table_with_references_column: Table + ) -> None: + """Should properly quote 'references' column in INSERT statement.""" + mock_introspector.get_table_metadata.return_value = table_with_references_column + generator = SQLGenerator(mock_introspector) + + record = RecordData( + identifier=RecordIdentifier( + table_name="shipments_shipment", schema_name="public", pk_values=("1",) + ), + data={ + "id": 1, + "reference_id": "WNZK22", + "references": "[]", + "state_id": "quoting", + }, + dependencies=[], + ) + + sql = generator.generate_batch([record], keep_pks=True) + + # Verify the column is quoted in the INSERT statement + assert '"references"' in sql + # Verify it's in the column list + assert '("id", "reference_id", "references", "state_id")' in sql + + def test_fk_remapping_with_reserved_keyword_column( + self, mock_introspector: MagicMock + ) -> None: + """Should properly quote reserved keywords in FK remapping INSERT-SELECT.""" + # Create a table with 'user' as a FK column (reserved keyword) + table_with_user_fk = Table( + schema_name="public", + table_name="posts", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + is_auto_generated=True, + ), + Column( + name="user", # Reserved keyword as FK! + data_type="integer", + udt_name="int4", + nullable=False, + ), + Column( + name="content", + data_type="text", + udt_name="text", + nullable=False, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[ + ForeignKey( + constraint_name="fk_posts_user", + source_table="public.posts", + source_column="user", + target_table="public.users", + target_column="id", + on_delete="CASCADE", + ) + ], + foreign_keys_incoming=[], + ) + + mock_introspector.get_table_metadata.return_value = table_with_user_fk + generator = SQLGenerator(mock_introspector) + + record = RecordData( + identifier=RecordIdentifier( + table_name="posts", schema_name="public", pk_values=("1",) + ), + data={"id": 1, "user": 42, "content": "Hello World"}, + dependencies=[ + RecordIdentifier( + table_name="users", schema_name="public", pk_values=("42",) + ) + ], + ) + + # Call the private method directly with correct parameters + tables_with_remapped_ids = {("public", "users")} + sql = generator._generate_insert_with_fk_remapping( + "public", "posts", [record], tables_with_remapped_ids + ) + + # Verify reserved keyword is quoted in the AS data(...) clause + assert '"old_user"' in sql + # Verify it's quoted in JOIN condition + assert 'data."old_user"' in sql + + def test_multiple_reserved_keywords(self, mock_introspector: MagicMock) -> None: + """Should properly quote multiple reserved keyword columns.""" + table_with_multiple_keywords = Table( + schema_name="public", + table_name="test_reserved", + columns=[ + Column( + name="id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + Column( + name="user", # Reserved + data_type="character varying", + udt_name="varchar", + nullable=True, + ), + Column( + name="group", # Reserved + data_type="character varying", + udt_name="varchar", + nullable=True, + ), + Column( + name="references", # Reserved + data_type="jsonb", + udt_name="jsonb", + nullable=True, + ), + Column( + name="order", # Reserved + data_type="integer", + udt_name="int4", + nullable=True, + ), + ], + primary_keys=["id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + mock_introspector.get_table_metadata.return_value = table_with_multiple_keywords + generator = SQLGenerator(mock_introspector) + + record = RecordData( + identifier=RecordIdentifier( + table_name="test_reserved", schema_name="public", pk_values=("1",) + ), + data={ + "id": 1, + "user": "john", + "group": "admin", + "references": "{}", + "order": 5, + }, + dependencies=[], + ) + + sql = generator.generate_batch([record], keep_pks=True) + + # Verify all reserved keywords are quoted + assert '"user"' in sql + assert '"group"' in sql + assert '"references"' in sql + assert '"order"' in sql + # Verify they're in the column list + assert '("group", "id", "order", "references", "user")' in sql diff --git a/tests/unit/graph/test_traverser.py b/tests/unit/graph/test_traverser.py index d1fb1ea..a6ef7f7 100644 --- a/tests/unit/graph/test_traverser.py +++ b/tests/unit/graph/test_traverser.py @@ -955,6 +955,10 @@ def test_find_referencing_records_with_timeframe_filter( wide_mode=True, ) + # Set starting_table to "orders" so that timeframe filter applies + # (timeframe filters only apply to the starting table) + traverser.starting_table = "orders" + mock_cursor = MagicMock() mock_cursor.fetchall.return_value = [(100,)] mock_connection.cursor.return_value.__enter__.return_value = mock_cursor @@ -966,7 +970,7 @@ def test_find_referencing_records_with_timeframe_filter( traverser._find_referencing_records(target_id, fk) - # Should include timeframe filter in query + # Should include timeframe filter in query since orders is the starting table query_args = mock_cursor.execute.call_args[0] assert "BETWEEN %s AND %s" in query_args[0] # Should pass timeframe parameters @@ -1204,3 +1208,185 @@ def test_adds_outgoing_fk_as_dependency(self) -> None: fk = orders_table.foreign_keys_outgoing[0] assert fk.source_column == "user_id" assert fk.target_table == "public.users" + + +class TestTimeframeFilterOnlyAppliedToStartingTable: + """Test that timeframe filters are only applied to the starting table, not FK-related tables.""" + + @pytest.fixture + def mock_cursor(self) -> MagicMock: + """Create a mock cursor.""" + cursor = MagicMock() + cursor.execute = MagicMock() + cursor.description = [("id",), ("name",)] + return cursor + + @pytest.fixture + def mock_connection(self, mock_cursor: MagicMock) -> MagicMock: + """Create a mock connection.""" + conn = MagicMock() + cursor_cm = MagicMock() + cursor_cm.__enter__ = MagicMock(return_value=mock_cursor) + cursor_cm.__exit__ = MagicMock(return_value=False) + conn.cursor.return_value = cursor_cm + return conn + + @pytest.fixture + def mock_introspector(self) -> MagicMock: + """Create a mock SchemaIntrospector.""" + introspector = MagicMock() + return introspector + + @pytest.fixture + def visited_tracker(self) -> VisitedTracker: + """Create a VisitedTracker.""" + return VisitedTracker() + + def test_timeframe_filter_only_applies_to_starting_table( + self, + mock_connection: MagicMock, + mock_introspector: MagicMock, + mock_cursor: MagicMock, + visited_tracker: VisitedTracker, + ) -> None: + """Timeframe filters should ONLY apply to starting table, not FK-related tables.""" + # Setup film table (starting table) + film_table = Table( + schema_name="public", + table_name="film", + columns=[ + Column( + name="film_id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + Column( + name="last_update", + data_type="timestamp", + udt_name="timestamp", + nullable=False, + ), + ], + primary_keys=["film_id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[ + ForeignKey( + constraint_name="fk_film_actor_film", + source_table="public.film_actor", + source_column="film_id", + target_table="public.film", + target_column="film_id", + on_delete="CASCADE", + ) + ], + ) + + # Setup film_actor table (related via incoming FK) + film_actor_table = Table( + schema_name="public", + table_name="film_actor", + columns=[ + Column( + name="film_id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + Column( + name="actor_id", + data_type="integer", + udt_name="int4", + nullable=False, + is_primary_key=True, + ), + Column( + name="last_update", + data_type="timestamp", + udt_name="timestamp", + nullable=False, + ), + ], + primary_keys=["film_id", "actor_id"], + foreign_keys_outgoing=[], + foreign_keys_incoming=[], + ) + + # Mock introspector to return our tables + def get_table_metadata(schema: str, table: str) -> Table: + if table == "film": + return film_table + elif table == "film_actor": + return film_actor_table + raise ValueError(f"Unknown table: {table}") + + mock_introspector.get_table_metadata = MagicMock(side_effect=get_table_metadata) + + # Setup cursor mock responses + # First call: fetch film record + # Second call: fetch film_actor records (should NOT have timeframe filter) + mock_cursor.fetchone.return_value = (1, datetime(2024, 6, 1)) + mock_cursor.fetchall.side_effect = [ + [(1, datetime(2024, 6, 1))], # film fetch + [(1, 1), (1, 2), (1, 3)], # film_actor fetch (3 records) + ] + + # Create traverser with timeframe filter on film_actor + # This filter should be IGNORED because film_actor is not the starting table + timeframe_filters = [ + TimeframeFilter( + table_name="film_actor", + column_name="last_update", + start_date=datetime( + 2099, 1, 1 + ), # Future date - would exclude all records + end_date=datetime(2099, 12, 31), + ) + ] + + traverser = RelationshipTraverser( + connection=mock_connection, + schema_introspector=mock_introspector, + visited_tracker=visited_tracker, + timeframe_filters=timeframe_filters, + wide_mode=False, # Strict mode + ) + + # Traverse from film (starting table = "film") + results = traverser.traverse("film", 1, "public") + + # Verify results + record_tables = {r.identifier.table_name for r in results} + + # Film should be included + assert "film" in record_tables, "Film should be included" + + # film_actor should be included even though it has a timeframe filter + # because the timeframe should ONLY apply to the starting table ("film") + assert "film_actor" in record_tables, ( + "film_actor should be included (timeframe ignored for non-starting tables)" + ) + + # Verify the SQL query for film_actor did NOT include the timeframe filter + executed_queries = [call[0][0] for call in mock_cursor.execute.call_args_list] + + # Find the film_actor query + film_actor_query = next( + (q for q in executed_queries if "film_actor" in q.lower()), None + ) + + assert film_actor_query is not None, ( + "film_actor query should have been executed" + ) + + # Verify the query does NOT contain the timeframe filter + # (The timeframe filter would add: AND "last_update" BETWEEN %s AND %s) + assert "2099" not in film_actor_query, ( + "film_actor query should not contain future timeframe dates" + ) + assert ( + "last_update" not in film_actor_query.lower() + or "between" not in film_actor_query.lower() + ), "film_actor query should not have timeframe filter" diff --git a/uv.lock b/uv.lock index 6f848f0..a2066b3 100644 --- a/uv.lock +++ b/uv.lock @@ -371,7 +371,7 @@ wheels = [ [[package]] name = "pgslice" -version = "0.2.1" +version = "0.2.2" source = { editable = "." } dependencies = [ { name = "printy" },