Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ jobs:
NEW_VERSION=$(grep '^version = ' pyproject.toml | cut -d'"' -f2)
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT

- name: Update lock file
run: |
uv lock

- name: Get commits since last tag
id: get_commits
run: |
Expand Down Expand Up @@ -116,10 +120,10 @@ jobs:
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add pyproject.toml CHANGELOG.md
git add pyproject.toml uv.lock CHANGELOG.md
git commit -m "Bump version to ${{ steps.bump_version.outputs.new_version }}

Update CHANGELOG.md with commits since last release"
Update CHANGELOG.md with commits since last release and sync lock file"

- name: Create release branch
id: create_branch
Expand Down
21 changes: 21 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ clean: ## Remove build artifacts and cache
show-version: ## Show current version from pyproject.toml
@uv version

bump-patch: ## Bump patch version (0.1.0 -> 0.1.1)
uv version --bump patch
uv lock
@echo "✓ Version bumped to $$(uv version)"
@echo "Remember to commit: git add pyproject.toml uv.lock && git commit -m 'Bump version'"

bump-minor: ## Bump minor version (0.1.0 -> 0.2.0)
uv version --bump minor
uv lock
@echo "✓ Version bumped to $$(uv version)"
@echo "Remember to commit: git add pyproject.toml uv.lock && git commit -m 'Bump version'"

bump-major: ## Bump major version (0.1.0 -> 1.0.0)
uv version --bump major
uv lock
@echo "✓ Version bumped to $$(uv version)"
@echo "Remember to commit: git add pyproject.toml uv.lock && git commit -m 'Bump version'"

lock: ## Update uv.lock file
uv lock

# Docker commands
docker-build: ## Build Docker image
docker build -t $(DOCKER_IMAGE) .
Expand Down
42 changes: 34 additions & 8 deletions src/pgslice/dumper/sql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,28 @@ def _get_column_types(self, schema: str, table: str) -> dict[str, tuple[str, str
}
return self._column_type_cache[key]

def _quote_identifier(self, identifier: str) -> str:
"""
Quote a SQL identifier safely.

Always uses double quotes to handle reserved words and special characters.
Escapes embedded double quotes.

Args:
identifier: SQL identifier (table, column, schema name)

Returns:
Quoted identifier

Example:
"users" -> '"users"'
"references" -> '"references"'
'col"name' -> '"col""name"' (escaped quote)
"""
# Escape embedded double quotes by doubling them
escaped = identifier.replace('"', '""')
return f'"{escaped}"'

def _is_array_type(self, data_type: str) -> bool:
"""
Check if a PostgreSQL data type is an array type.
Expand Down Expand Up @@ -972,11 +994,11 @@ def _generate_insert_with_fk_remapping(
SELECT
map0.new_id::integer,
map1.new_id::integer,
data.last_update
data."last_update"
FROM (VALUES
('20', '1', '2006-02-15T10:05:03'),
...
) AS data(old_actor_id, old_film_id, last_update)
) AS data("old_actor_id", "old_film_id", "last_update")
JOIN _pgslice_id_map map0 ...
JOIN _pgslice_id_map map1 ...
"""
Expand Down Expand Up @@ -1057,13 +1079,15 @@ def _generate_insert_with_fk_remapping(
values_clause = ",\n".join(values_rows)

# Create column aliases for the VALUES clause
# Example: data(old_actor_id, old_film_id, description, last_update)
# Example: data("old_actor_id", "old_film_id", "description", "last_update")
data_column_aliases = []
for col in columns:
if col in fk_to_remap:
data_column_aliases.append(f"old_{col}")
# Quote the prefixed alias for remapped FK columns
data_column_aliases.append(self._quote_identifier(f"old_{col}"))
else:
data_column_aliases.append(col)
# Quote regular column names to handle reserved keywords
data_column_aliases.append(self._quote_identifier(col))

# Get table metadata for column types
table_meta = self.introspector.get_table_metadata(schema, table)
Expand Down Expand Up @@ -1096,7 +1120,7 @@ def _generate_insert_with_fk_remapping(
join_clauses.append(
f" JOIN _pgslice_id_map {alias}\n"
f" ON {alias}.table_name = '{target_full}'\n"
f" AND {alias}.old_id = data.old_{col}"
f" AND {alias}.old_id = data.{self._quote_identifier(f'old_{col}')}"
)
join_index += 1
else:
Expand All @@ -1116,10 +1140,12 @@ def _generate_insert_with_fk_remapping(
element_type = self._get_array_element_type(col_meta.udt_name)
pg_type = f"{element_type}[]"

select_parts.append(f"data.{col}::{pg_type}")
select_parts.append(
f"data.{self._quote_identifier(col)}::{pg_type}"
)
else:
# Fallback if column metadata not found
select_parts.append(f"data.{col}")
select_parts.append(f"data.{self._quote_identifier(col)}")

select_clause = ",\n ".join(select_parts)
join_clause = "\n".join(join_clauses)
Expand Down
62 changes: 46 additions & 16 deletions src/pgslice/graph/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __init__(
self.timeframe_filters = {f.table_name: f for f in (timeframe_filters or [])}
self.wide_mode = wide_mode
self.fetch_batch_size = fetch_batch_size
self.starting_table: str | None = (
None # Track starting table for timeframe filtering
)

def traverse(
self,
Expand Down Expand Up @@ -87,6 +90,9 @@ def traverse(
Raises:
RecordNotFoundError: If starting record doesn't exist
"""
# Track starting table for timeframe filtering (only apply to starting table)
self.starting_table = table_name

start_id = self._create_record_identifier(schema, table_name, (pk_value,))
queue: deque[tuple[RecordIdentifier, int, bool]] = deque([(start_id, 0, True)])
results: set[RecordData] = set()
Expand Down Expand Up @@ -257,6 +263,9 @@ def traverse_multiple(
Returns:
Set of all discovered RecordData objects
"""
# Track starting table for timeframe filtering (only apply to starting table)
self.starting_table = table_name

# Edge case: empty pk_values
if not pk_values:
logger.info("No primary keys provided for traversal")
Expand Down Expand Up @@ -511,9 +520,12 @@ def _fetch_record(self, record_id: RecordIdentifier) -> RecordData:
where_parts.append(f'"{pk_col}" = %s')
params.append(pk_val)

# Apply timeframe filter if applicable
# Apply timeframe filter only to starting table
timeframe_clause = ""
if record_id.table_name in self.timeframe_filters:
if (
record_id.table_name in self.timeframe_filters
and record_id.table_name == self.starting_table
):
filter_config = self.timeframe_filters[record_id.table_name]
timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s'
params.extend([filter_config.start_date, filter_config.end_date])
Expand Down Expand Up @@ -568,16 +580,33 @@ def _fetch_records_batch(
logger.warning(f"Table {schema}.{table} has no primary key, skipping")
continue

# Build WHERE clause: WHERE id IN (1, 2, 3)
pk_col = table_metadata.primary_keys[0]
pk_values = [rid.pk_values[0] for rid in table_record_ids]
placeholders = ", ".join(["%s"] * len(pk_values))
pk_cols = table_metadata.primary_keys

# Apply timeframe filter if applicable
# Apply timeframe filter only to starting table
timeframe_clause = ""
params: list[Any] = pk_values.copy()
params: list[Any] = []

# Build WHERE clause for composite or single PK
if len(pk_cols) == 1:
# Single column PK: WHERE id IN (1, 2, 3)
pk_col = pk_cols[0]
pk_values = [rid.pk_values[0] for rid in table_record_ids]
placeholders = ", ".join(["%s"] * len(pk_values))
where_clause = f'"{pk_col}" IN ({placeholders})'
params.extend(pk_values)
else:
# Composite PK: WHERE (col1, col2) IN ((1, 2), (3, 4))
pk_columns = ", ".join([f'"{col}"' for col in pk_cols])
pk_tuples = [rid.pk_values for rid in table_record_ids]
tuple_placeholders = ", ".join(
["(" + ", ".join(["%s"] * len(pk_cols)) + ")"] * len(pk_tuples)
)
where_clause = f"({pk_columns}) IN ({tuple_placeholders})"
# Flatten the tuples into a single list of params
for pk_tuple in pk_tuples:
params.extend(pk_tuple)

if table in self.timeframe_filters:
if table in self.timeframe_filters and table == self.starting_table:
filter_config = self.timeframe_filters[table]
timeframe_clause = (
f' AND "{filter_config.column_name}" BETWEEN %s AND %s'
Expand All @@ -586,7 +615,7 @@ def _fetch_records_batch(

query = f"""
SELECT * FROM "{schema}"."{table}"
WHERE "{pk_col}" IN ({placeholders}){timeframe_clause}
WHERE {where_clause}{timeframe_clause}
"""

with self.conn.cursor() as cur:
Expand All @@ -596,9 +625,10 @@ def _fetch_records_batch(

for row in rows:
data = dict(zip(columns, row, strict=False))
pk_value = data[pk_col]
# Extract all PK values for composite keys
record_pk_values = tuple(data[col] for col in pk_cols)
record_id = self._create_record_identifier(
schema, table, (pk_value,)
schema, table, record_pk_values
)
results[record_id] = RecordData(identifier=record_id, data=data)

Expand Down Expand Up @@ -664,11 +694,11 @@ def _find_referencing_records(
# Build query
pk_columns = ", ".join(f'"{pk}"' for pk in source_table.primary_keys)

# Apply timeframe filter if applicable
# Apply timeframe filter only to starting table
timeframe_clause = ""
params: list[Any] = [target_pk_value]

if table in self.timeframe_filters:
if table in self.timeframe_filters and table == self.starting_table:
filter_config = self.timeframe_filters[table]
timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s'
params.extend([filter_config.start_date, filter_config.end_date])
Expand Down Expand Up @@ -733,11 +763,11 @@ def _find_referencing_records_batch(
pk_columns = ", ".join(f'"{pk}"' for pk in source_table.primary_keys)
placeholders = ", ".join(["%s"] * len(target_pk_values))

# Apply timeframe filter if applicable
# Apply timeframe filter only to starting table
timeframe_clause = ""
params: list[Any] = target_pk_values.copy()

if table in self.timeframe_filters:
if table in self.timeframe_filters and table == self.starting_table:
filter_config = self.timeframe_filters[table]
timeframe_clause = f' AND "{filter_config.column_name}" BETWEEN %s AND %s'
params.extend([filter_config.start_date, filter_config.end_date])
Expand Down
Loading