Skip to content

Commit 93e681e

Browse files
committed
Address Copilot review feedback (round 3)
- Fix PostgreSQL FK handling: return empty array for FK statements (PostgreSQL doesn't have a simple way to disable FK checks globally, rely on CASCADE option instead) - Add FK re-enable on error in catch block for MySQL/SQLite - Add sqlite_sequence reset for true TRUNCATE semantics (reset auto-increment) - Add docstrings to helper functions for better code documentation - Fix nested transaction issue when both cell changes and table ops exist (wrap in single transaction only when multiple operation types)
1 parent 750cbcf commit 93e681e

1 file changed

Lines changed: 62 additions & 16 deletions

File tree

TablePro/Views/Main/MainContentCoordinator.swift

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -564,20 +564,32 @@ final class MainContentCoordinator: ObservableObject {
564564

565565
var allStatements: [String] = []
566566

567+
// Wrap all operations in a single transaction when we have multiple operations
568+
let needsTransaction = hasEditedCells && hasPendingTableOps
569+
if needsTransaction {
570+
allStatements.append("BEGIN")
571+
}
572+
567573
if hasEditedCells {
568574
allStatements.append(contentsOf: changeManager.generateSQL())
569575
}
570576

571577
if hasPendingTableOps {
572578
// Generate table operation SQL with FK/cascade options
579+
// Don't wrap in transaction if we're already in one
573580
let tableOpStatements = generateTableOperationSQL(
574581
truncates: pendingTruncates,
575582
deletes: pendingDeletes,
576-
options: tableOperationOptions
583+
options: tableOperationOptions,
584+
wrapInTransaction: !needsTransaction
577585
)
578586
allStatements.append(contentsOf: tableOpStatements)
579587
}
580588

589+
if needsTransaction {
590+
allStatements.append("COMMIT")
591+
}
592+
581593
guard !allStatements.isEmpty else {
582594
if let index = tabManager.selectedTabIndex {
583595
tabManager.tabs[index].errorMessage = "Could not generate SQL for changes."
@@ -595,11 +607,18 @@ final class MainContentCoordinator: ObservableObject {
595607
)
596608
}
597609

598-
/// Generate SQL for table truncate/delete operations with FK/cascade options
610+
/// Generates SQL statements for table truncate/drop operations with FK handling.
611+
/// - Parameters:
612+
/// - truncates: Set of table names to truncate
613+
/// - deletes: Set of table names to drop
614+
/// - options: Per-table options for FK and cascade handling
615+
/// - wrapInTransaction: Whether to wrap statements in BEGIN/COMMIT
616+
/// - Returns: Array of SQL statements to execute
599617
private func generateTableOperationSQL(
600618
truncates: Set<String>,
601619
deletes: Set<String>,
602-
options: [String: TableOperationOptions]
620+
options: [String: TableOperationOptions],
621+
wrapInTransaction: Bool = true
603622
) -> [String] {
604623
var statements: [String] = []
605624
let dbType = connection.type
@@ -608,13 +627,13 @@ final class MainContentCoordinator: ObservableObject {
608627
let sortedTruncates = truncates.sorted()
609628
let sortedDeletes = deletes.sorted()
610629

611-
// Check if any operation needs FK disabled
612-
let needsDisableFK = truncates.union(deletes).contains { tableName in
630+
// Check if any operation needs FK disabled (not applicable to PostgreSQL)
631+
let needsDisableFK = dbType != .postgresql && truncates.union(deletes).contains { tableName in
613632
options[tableName]?.ignoreForeignKeys == true
614633
}
615634

616635
// Wrap in transaction for atomicity
617-
let needsTransaction = (sortedTruncates.count + sortedDeletes.count) > 1
636+
let needsTransaction = wrapInTransaction && (sortedTruncates.count + sortedDeletes.count) > 1
618637
if needsTransaction {
619638
statements.append("BEGIN")
620639
}
@@ -626,7 +645,7 @@ final class MainContentCoordinator: ObservableObject {
626645
for tableName in sortedTruncates {
627646
let quotedName = dbType.quoteIdentifier(tableName)
628647
let tableOptions = options[tableName] ?? TableOperationOptions()
629-
statements.append(truncateStatement(tableName: quotedName, options: tableOptions, dbType: dbType))
648+
statements.append(contentsOf: truncateStatements(tableName: tableName, quotedName: quotedName, options: tableOptions, dbType: dbType))
630649
}
631650

632651
for tableName in sortedDeletes {
@@ -646,39 +665,53 @@ final class MainContentCoordinator: ObservableObject {
646665
return statements
647666
}
648667

668+
/// Returns SQL statements to disable foreign key checks for the database type.
669+
/// - Note: PostgreSQL doesn't support globally disabling FK checks; use CASCADE instead.
649670
private func fkDisableStatements(for dbType: DatabaseType) -> [String] {
650671
switch dbType {
651672
case .mysql, .mariadb:
652673
return ["SET FOREIGN_KEY_CHECKS=0"]
653674
case .postgresql:
654-
// SET CONSTRAINTS works within transaction for deferrable constraints
655-
// For non-deferrable, CASCADE is the proper approach
656-
return ["SET CONSTRAINTS ALL DEFERRED"]
675+
// PostgreSQL doesn't support globally disabling non-deferrable FKs.
676+
// Use CASCADE option for reliable FK handling.
677+
return []
657678
case .sqlite:
658679
return ["PRAGMA foreign_keys = OFF"]
659680
}
660681
}
661682

683+
/// Returns SQL statements to re-enable foreign key checks for the database type.
662684
private func fkEnableStatements(for dbType: DatabaseType) -> [String] {
663685
switch dbType {
664686
case .mysql, .mariadb:
665687
return ["SET FOREIGN_KEY_CHECKS=1"]
666688
case .postgresql:
667-
// Constraints auto-check at COMMIT
668689
return []
669690
case .sqlite:
670691
return ["PRAGMA foreign_keys = ON"]
671692
}
672693
}
673694

674-
private func truncateStatement(tableName: String, options: TableOperationOptions, dbType: DatabaseType) -> String {
675-
return switch dbType {
676-
case .mysql, .mariadb: "TRUNCATE TABLE \(tableName)"
677-
case .postgresql: options.cascade ? "TRUNCATE TABLE \(tableName) CASCADE" : "TRUNCATE TABLE \(tableName)"
678-
case .sqlite: "DELETE FROM \(tableName)"
695+
/// Generates TRUNCATE/DELETE statements for a table.
696+
/// - Note: SQLite uses DELETE and resets auto-increment via sqlite_sequence.
697+
private func truncateStatements(tableName: String, quotedName: String, options: TableOperationOptions, dbType: DatabaseType) -> [String] {
698+
switch dbType {
699+
case .mysql, .mariadb:
700+
return ["TRUNCATE TABLE \(quotedName)"]
701+
case .postgresql:
702+
let cascade = options.cascade ? " CASCADE" : ""
703+
return ["TRUNCATE TABLE \(quotedName)\(cascade)"]
704+
case .sqlite:
705+
// DELETE FROM + reset auto-increment counter for true TRUNCATE semantics
706+
let escapedName = tableName.replacingOccurrences(of: "'", with: "''")
707+
return [
708+
"DELETE FROM \(quotedName)",
709+
"DELETE FROM sqlite_sequence WHERE name = '\(escapedName)'"
710+
]
679711
}
680712
}
681713

714+
/// Generates DROP TABLE statement with optional CASCADE.
682715
private func dropTableStatement(tableName: String, options: TableOperationOptions, dbType: DatabaseType) -> String {
683716
let cascade = options.cascade ? " CASCADE" : ""
684717
return switch dbType {
@@ -699,6 +732,12 @@ final class MainContentCoordinator: ObservableObject {
699732
let deletedTables = Set(pendingDeletes)
700733
let truncatedTables = Set(pendingTruncates)
701734
let conn = connection
735+
let dbType = connection.type
736+
737+
// Track if FK checks were disabled (need to re-enable on failure)
738+
let fkWasDisabled = dbType != .postgresql && deletedTables.union(truncatedTables).contains { tableName in
739+
tableOperationOptions[tableName]?.ignoreForeignKeys == true
740+
}
702741

703742
// Capture options before clearing (for potential restore on failure)
704743
var capturedOptions: [String: TableOperationOptions] = [:]
@@ -782,6 +821,13 @@ final class MainContentCoordinator: ObservableObject {
782821
} catch {
783822
let executionTime = Date().timeIntervalSince(overallStartTime)
784823

824+
// Try to re-enable FK checks if they were disabled
825+
if fkWasDisabled, let driver = DatabaseManager.shared.activeDriver {
826+
for statement in self.fkEnableStatements(for: dbType) {
827+
try? await driver.execute(query: statement)
828+
}
829+
}
830+
785831
await MainActor.run {
786832
QueryHistoryManager.shared.recordQuery(
787833
query: sql,

0 commit comments

Comments
 (0)