diff --git a/sqlit/domains/connections/ui/connection_error_handlers.py b/sqlit/domains/connections/ui/connection_error_handlers.py index 2268a66..12e9262 100644 --- a/sqlit/domains/connections/ui/connection_error_handlers.py +++ b/sqlit/domains/connections/ui/connection_error_handlers.py @@ -36,9 +36,19 @@ def can_handle(self, error: Exception) -> bool: def handle(self, app: ConnectionErrorApp, error: Exception, config: ConnectionConfig) -> None: from sqlit.domains.connections.providers.exceptions import MissingDriverError + from sqlit.shared.core.debug_events import emit_debug_event + from .restart_cache import write_pending_connection_cache from .screens import PackageSetupScreen + # Save pending connection for auto-reconnect after driver install restart + if config.name: + write_pending_connection_cache(config.name) + emit_debug_event( + "driver_install.pending_connection_saved", + connection_name=config.name, + ) + # No on_success callback - uses default "Restart to apply" behavior app.push_screen(PackageSetupScreen(cast(MissingDriverError, error))) diff --git a/sqlit/domains/connections/ui/restart_cache.py b/sqlit/domains/connections/ui/restart_cache.py index 7f76c4b..e2daa69 100644 --- a/sqlit/domains/connections/ui/restart_cache.py +++ b/sqlit/domains/connections/ui/restart_cache.py @@ -28,3 +28,18 @@ def clear_restart_cache() -> None: get_restart_cache_path().unlink(missing_ok=True) except Exception: pass + + +def write_pending_connection_cache(connection_name: str) -> None: + """Cache a pending connection name for auto-reconnect after driver install restart. + + This is used when a user tries to connect to a server but the driver is missing. + After the driver is installed and the app restarts, it can auto-connect to this + connection. + """ + payload = { + "version": 2, + "type": "pending_connection", + "connection_name": connection_name, + } + write_restart_cache(payload) diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index 9491461..9e4a0f4 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -216,6 +216,18 @@ def action_refresh_tree(self: TreeMixinHost) -> None: if hasattr(self, "_loading_nodes"): self._loading_nodes.clear() self._schema_service = None + + # Reload saved connections from disk (in case added via CLI) + try: + services = getattr(self, "services", None) + if services: + store = getattr(services, "connection_store", None) + if store: + reloaded = store.load_all(load_credentials=False) + self.connections = reloaded + except Exception: + pass # Keep existing connections if reload fails + self.refresh_tree() loader = getattr(self, "_load_schema_cache", None) if callable(loader): @@ -295,6 +307,8 @@ def action_select_table(self: TreeMixinHost) -> None: "name": data.name, "columns": [], } + # Stash per-result metadata so results can resolve PKs without relying on globals. + self._pending_result_table_info = self._last_query_table self._prime_last_query_table_columns(data.database, data.schema, data.name) self.query_input.text = self.current_provider.dialect.build_select_query( diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py index 6b1c4e2..417a633 100644 --- a/sqlit/domains/query/ui/mixins/query_execution.py +++ b/sqlit/domains/query/ui/mixins/query_execution.py @@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b executable_statements = [s for s in statements if not is_comment_only_statement(s)] is_multi_statement = len(executable_statements) > 1 + if is_multi_statement: + self._pending_result_table_info = None + elif executable_statements: + if getattr(self, "_pending_result_table_info", None) is None: + table_info = self._infer_result_table_info(executable_statements[0]) + if table_info is not None: + self._pending_result_table_info = table_info + prime = getattr(self, "_prime_result_table_columns", None) + if callable(prime): + prime(table_info) + try: start_time = time.perf_counter() max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS diff --git a/sqlit/domains/query/ui/mixins/query_results.py b/sqlit/domains/query/ui/mixins/query_results.py index c40223f..6be6e8b 100644 --- a/sqlit/domains/query/ui/mixins/query_results.py +++ b/sqlit/domains/query/ui/mixins/query_results.py @@ -224,6 +224,7 @@ def _render_results_table_incremental( escape: bool, row_limit: int, render_token: int, + table_info: dict[str, Any] | None = None, ) -> None: initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit) initial_rows = rows[:initial_count] if initial_count > 0 else [] @@ -269,9 +270,16 @@ def _render_results_table_incremental( pass if render_token == getattr(self, "_results_render_token", 0): self._replace_results_table_with_data(columns, rows, escape=escape) + if table_info is not None: + try: + self.results_table.result_table_info = table_info + except Exception: + pass return if render_token != getattr(self, "_results_render_token", 0): return + if table_info is not None: + table.result_table_info = table_info self._replace_results_table_with_table(table) self._schedule_results_render( table, @@ -291,6 +299,7 @@ async def _display_query_results( self._last_result_columns = columns self._last_result_rows = rows self._last_result_row_count = row_count + table_info = getattr(self, "_pending_result_table_info", None) # Switch to single result mode (in case we were showing stacked results) self._show_single_result_mode() @@ -304,12 +313,15 @@ async def _display_query_results( escape=True, row_limit=row_limit, render_token=render_token, + table_info=table_info, ) else: render_rows = rows[:row_limit] if row_limit else [] table = self._build_results_table(columns, render_rows, escape=True) if render_token != getattr(self, "_results_render_token", 0): return + if table_info is not None: + table.result_table_info = table_info self._replace_results_table_with_table(table) time_str = format_duration_ms(elapsed_ms) @@ -320,9 +332,15 @@ async def _display_query_results( ) else: self.notify(f"Query returned {row_count} rows in {time_str}") + if table_info is not None: + prime = getattr(self, "_prime_result_table_columns", None) + if callable(prime): + prime(table_info) + self._pending_result_table_info = None def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None: """Display non-query result (called on main thread).""" + self._pending_result_table_info = None self._last_result_columns = ["Result"] self._last_result_rows = [(f"{affected} row(s) affected",)] self._last_result_row_count = 1 @@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f def _display_query_error(self: QueryMixinHost, error_message: str) -> None: """Display query error (called on main thread).""" self._cancel_results_render() + self._pending_result_table_info = None # notify(severity="error") handles displaying the error in results via _show_error_in_results self.notify(f"Query error: {error_message}", severity="error") @@ -360,7 +379,17 @@ def _display_multi_statement_results( # Add each result section for i, stmt_result in enumerate(multi_result.results): - container.add_result_section(stmt_result, i, auto_collapse=auto_collapse) + table_info = self._infer_result_table_info(stmt_result.statement) + if table_info is not None: + prime = getattr(self, "_prime_result_table_columns", None) + if callable(prime): + prime(table_info) + container.add_result_section( + stmt_result, + i, + auto_collapse=auto_collapse, + table_info=table_info, + ) # Show the stacked results container, hide single result table self._show_stacked_results_mode() @@ -378,6 +407,7 @@ def _display_multi_statement_results( ) else: self.notify(f"Executed {total} statements in {time_str}") + self._pending_result_table_info = None def _get_stacked_results_container(self: QueryMixinHost) -> Any: """Get the stacked results container.""" @@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None: stacked.remove_class("active") except Exception: pass + + def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None: + """Best-effort inference of a single source table for query results.""" + from sqlit.domains.query.completion import extract_table_refs + + refs = extract_table_refs(sql) + if len(refs) != 1: + return None + ref = refs[0] + schema = ref.schema + name = ref.name + database = None + table_metadata = getattr(self, "_table_metadata", {}) or {} + key_candidates = [name.lower()] + if schema: + key_candidates.insert(0, f"{schema}.{name}".lower()) + for key in key_candidates: + metadata = table_metadata.get(key) + if metadata: + schema, name, database = metadata + break + return { + "database": database, + "schema": schema, + "name": name, + "columns": [], + } diff --git a/sqlit/domains/results/ui/mixins/results.py b/sqlit/domains/results/ui/mixins/results.py index 2f3a434..d875029 100644 --- a/sqlit/domains/results/ui/mixins/results.py +++ b/sqlit/domains/results/ui/mixins/results.py @@ -8,6 +8,8 @@ from sqlit.shared.ui.protocols import ResultsMixinHost from sqlit.shared.ui.widgets import SqlitDataTable +MIN_TIMER_DELAY_S = 0.001 + class ResultsMixin: """Mixin providing results handling functionality.""" @@ -19,6 +21,122 @@ class ResultsMixin: _tooltip_showing: bool = False _tooltip_timer: Any | None = None + def _schedule_results_timer(self: ResultsMixinHost, delay_s: float, callback: Any) -> Any | None: + set_timer = getattr(self, "set_timer", None) + if callable(set_timer): + return set_timer(delay_s, callback) + call_later = getattr(self, "call_later", None) + if callable(call_later): + try: + call_later(callback) + return None + except Exception: + pass + try: + callback() + except Exception: + pass + return None + + def _apply_result_table_columns( + self: ResultsMixinHost, + table_info: dict[str, Any], + token: int, + columns: list[Any], + ) -> None: + if table_info.get("_columns_token") != token: + return + table_info["columns"] = columns + + def _prime_result_table_columns(self: ResultsMixinHost, table_info: dict[str, Any] | None) -> None: + if not table_info: + return + if table_info.get("columns"): + return + name = table_info.get("name") + if not name: + return + database = table_info.get("database") + schema = table_info.get("schema") + token = int(table_info.get("_columns_token", 0)) + 1 + table_info["_columns_token"] = token + + async def work_async() -> None: + import asyncio + + columns: list[Any] = [] + try: + runtime = getattr(self.services, "runtime", None) + use_worker = bool(getattr(runtime, "process_worker", False)) and not bool( + getattr(getattr(runtime, "mock", None), "enabled", False) + ) + client = None + if use_worker and hasattr(self, "_get_process_worker_client_async"): + client = await self._get_process_worker_client_async() # type: ignore[attr-defined] + + if client is not None and hasattr(client, "list_columns") and self.current_config is not None: + outcome = await asyncio.to_thread( + client.list_columns, + config=self.current_config, + database=database, + schema=schema, + name=name, + ) + if getattr(outcome, "cancelled", False): + return + error = getattr(outcome, "error", None) + if error: + raise RuntimeError(error) + columns = outcome.columns or [] + else: + schema_service = getattr(self, "_get_schema_service", None) + if callable(schema_service): + service = self._get_schema_service() + if service: + columns = await asyncio.to_thread( + service.list_columns, + database, + schema, + name, + ) + except Exception: + columns = [] + + self._schedule_results_timer( + MIN_TIMER_DELAY_S, + lambda: self._apply_result_table_columns(table_info, token, columns), + ) + + self.run_worker(work_async(), name=f"prime-result-columns-{name}", exclusive=False) + + def _normalize_column_name(self: ResultsMixinHost, name: str) -> str: + trimmed = name.strip() + if len(trimmed) >= 2: + if trimmed[0] == trimmed[-1] and trimmed[0] in ("\"", "`"): + trimmed = trimmed[1:-1] + elif trimmed[0] == "[" and trimmed[-1] == "]": + trimmed = trimmed[1:-1] + if "." in trimmed and not any(q in trimmed for q in ("\"", "`", "[")): + trimmed = trimmed.split(".")[-1] + return trimmed.lower() + + def _get_active_results_table_info( + self: ResultsMixinHost, + table: SqlitDataTable | None, + stacked: bool, + ) -> dict[str, Any] | None: + if not table: + return None + if stacked: + section = self._find_results_section(table) + table_info = getattr(section, "result_table_info", None) + if table_info: + return table_info + table_info = getattr(table, "result_table_info", None) + if table_info: + return table_info + return getattr(self, "_last_query_table", None) + def _copy_text(self: ResultsMixinHost, text: str) -> bool: """Copy text to clipboard if possible, otherwise store internally.""" self._internal_clipboard = text @@ -610,21 +728,20 @@ def sql_value(v: object) -> str: # Get table name and primary key columns table_name = "" pk_column_names: set[str] = set() - - if hasattr(self, "_last_query_table") and self._last_query_table: - table_info = self._last_query_table - table_name = table_info["name"] + table_info = self._get_active_results_table_info(table, _stacked) + if table_info: + table_name = table_info.get("name", table_name) # Get PK columns from column info for col in table_info.get("columns", []): if col.is_primary_key: - pk_column_names.add(col.name) + pk_column_names.add(self._normalize_column_name(col.name)) # Build WHERE clause - prefer PK columns, fall back to all columns where_parts = [] for i, col in enumerate(columns): if i < len(row_values): # If we have PK info, only use PK columns; otherwise use all columns - if pk_column_names and col not in pk_column_names: + if pk_column_names and self._normalize_column_name(col) not in pk_column_names: continue val = row_values[i] if val is None: @@ -685,9 +802,10 @@ def action_edit_cell(self: ResultsMixinHost) -> None: column_name = columns[cursor_col] # Check if this column is a primary key - don't allow editing PKs - if hasattr(self, "_last_query_table") and self._last_query_table: - for col in self._last_query_table.get("columns", []): - if col.name == column_name and col.is_primary_key: + table_info = self._get_active_results_table_info(table, _stacked) + if table_info: + for col in table_info.get("columns", []): + if col.is_primary_key and self._normalize_column_name(col.name) == self._normalize_column_name(column_name): self.notify("Cannot edit primary key column", severity="warning") return @@ -705,21 +823,19 @@ def sql_value(v: object) -> str: # Get table name and primary key columns table_name = "
" pk_column_names: set[str] = set() - - if hasattr(self, "_last_query_table") and self._last_query_table: - table_info = self._last_query_table - table_name = table_info["name"] + if table_info: + table_name = table_info.get("name", table_name) # Get PK columns from column info for col in table_info.get("columns", []): if col.is_primary_key: - pk_column_names.add(col.name) + pk_column_names.add(self._normalize_column_name(col.name)) # Build WHERE clause - prefer PK columns, fall back to all columns where_parts = [] for i, col in enumerate(columns): if i < len(row_values): # If we have PK info, only use PK columns; otherwise use all columns - if pk_column_names and col not in pk_column_names: + if pk_column_names and self._normalize_column_name(col) not in pk_column_names: continue val = row_values[i] if val is None: diff --git a/sqlit/domains/shell/app/commands/debug.py b/sqlit/domains/shell/app/commands/debug.py index 7134c76..1cecf7d 100644 --- a/sqlit/domains/shell/app/commands/debug.py +++ b/sqlit/domains/shell/app/commands/debug.py @@ -41,6 +41,16 @@ def _set_debug_enabled(app: Any, enabled: bool) -> None: else: app._debug_events_enabled = bool(enabled) + # Persist the setting across sessions + try: + services = getattr(app, "services", None) + if services: + store = getattr(services, "settings_store", None) + if store: + store.set("debug_events_enabled", enabled) + except Exception: + pass + path = getattr(app, "_debug_event_log_path", None) suffix = f" (log: {path})" if path else "" state = "enabled" if enabled else "disabled" diff --git a/sqlit/domains/shell/app/main.py b/sqlit/domains/shell/app/main.py index 5f9fbec..66edcde 100644 --- a/sqlit/domains/shell/app/main.py +++ b/sqlit/domains/shell/app/main.py @@ -197,6 +197,7 @@ def __init__( self._columns_loading: set[str] = set() self._state_machine = UIStateMachine() self._last_query_table: dict[str, Any] | None = None + self._pending_result_table_info: dict[str, Any] | None = None self._query_target_database: str | None = None # Target DB for auto-generated queries self._restart_requested: bool = False # Idle scheduler for background work diff --git a/sqlit/domains/shell/app/startup_flow.py b/sqlit/domains/shell/app/startup_flow.py index 3d3bd62..84c482a 100644 --- a/sqlit/domains/shell/app/startup_flow.py +++ b/sqlit/domains/shell/app/startup_flow.py @@ -36,6 +36,10 @@ def run_on_mount(app: AppProtocol) -> None: app._startup_stamp("settings_loaded") app._expanded_paths = set(settings.get("expanded_nodes", [])) + if settings.get("debug_events_enabled"): + setter = getattr(app, "_set_debug_events_enabled", None) + if callable(setter): + setter(True) if "process_worker" in settings: app.services.runtime.process_worker = bool(settings.get("process_worker")) if "process_worker_warm_on_idle" in settings: @@ -83,6 +87,9 @@ def run_on_mount(app: AppProtocol) -> None: app.object_tree.cursor_line = 0 app._update_section_labels() maybe_restore_connection_screen(app) + # Auto-connect to pending connection after driver install (if not already connecting) + if app._startup_connect_config is None: + maybe_auto_connect_pending(app) app._startup_stamp("restore_checked") if app._debug_mode: app.call_after_refresh(app._record_launch_ms) @@ -224,6 +231,83 @@ def _get_restart_cache_path() -> Path: return Path(tempfile.gettempdir()) / "sqlit-driver-install-restore.json" +def maybe_auto_connect_pending(app: AppProtocol) -> bool: + """Auto-connect to a pending connection after driver install restart. + + Returns True if a connection was initiated, False otherwise. + """ + from sqlit.shared.core.debug_events import emit_debug_event + + from sqlit.domains.connections.ui.restart_cache import ( + clear_restart_cache, + get_restart_cache_path, + ) + + cache_path = get_restart_cache_path() + emit_debug_event( + "startup.pending_connection_check", + cache_path=str(cache_path), + exists=cache_path.exists(), + ) + if not cache_path.exists(): + return False + + emit_debug_event( + "startup.pending_connection_found", + contents=cache_path.read_text(), + ) + + try: + payload = json.loads(cache_path.read_text(encoding="utf-8")) + except Exception as e: + emit_debug_event("startup.pending_connection_parse_error", error=str(e)) + clear_restart_cache() + return False + + # Always clear cache after reading + clear_restart_cache() + + # Check for version 2 pending_connection type + if not isinstance(payload, dict): + emit_debug_event("startup.pending_connection_invalid", reason="not a dict") + return False + if payload.get("version") != 2: + emit_debug_event("startup.pending_connection_invalid", reason="wrong version", version=payload.get("version")) + return False + if payload.get("type") != "pending_connection": + emit_debug_event("startup.pending_connection_invalid", reason="wrong type", type=payload.get("type")) + return False + + connection_name = payload.get("connection_name") + if not connection_name: + emit_debug_event("startup.pending_connection_invalid", reason="no connection_name") + return False + + emit_debug_event( + "startup.pending_connection_lookup", + connection_name=connection_name, + available_connections=[getattr(c, "name", None) for c in app.connections], + ) + + # Find the connection by name + config = next( + (c for c in app.connections if getattr(c, "name", None) == connection_name), + None, + ) + if config is None: + emit_debug_event("startup.pending_connection_not_found", connection_name=connection_name) + return False + + emit_debug_event("startup.pending_connection_connecting", connection_name=connection_name) + + # Auto-connect after refresh (same pattern as startup_connect_config) + def _connect_pending() -> None: + app.connect_to_server(config) + + app.call_after_refresh(_connect_pending) + return True + + def maybe_restore_connection_screen(app: AppProtocol) -> None: """Restore an in-progress connection form after a driver-install restart.""" cache_path = _get_restart_cache_path() @@ -239,14 +323,16 @@ def maybe_restore_connection_screen(app: AppProtocol) -> None: pass return + # Only handle version 1 (connection form restore), leave version 2 for maybe_auto_connect_pending + if not isinstance(payload, dict) or payload.get("version") != 1: + return + + # Clear cache only for version 1 try: cache_path.unlink(missing_ok=True) except Exception: pass - if not isinstance(payload, dict) or payload.get("version") != 1: - return - values = payload.get("values") if not isinstance(values, dict): return diff --git a/sqlit/shared/ui/protocols/results.py b/sqlit/shared/ui/protocols/results.py index d6e3f5e..2124b1a 100644 --- a/sqlit/shared/ui/protocols/results.py +++ b/sqlit/shared/ui/protocols/results.py @@ -16,6 +16,7 @@ class ResultsStateProtocol(Protocol): _last_result_row_count: int _internal_clipboard: str _last_query_table: dict[str, Any] | None + _pending_result_table_info: dict[str, Any] | None _results_table_counter: int _results_filter_visible: bool _results_filter_text: str diff --git a/sqlit/shared/ui/widgets_stacked_results.py b/sqlit/shared/ui/widgets_stacked_results.py index e61c819..015cc45 100644 --- a/sqlit/shared/ui/widgets_stacked_results.py +++ b/sqlit/shared/ui/widgets_stacked_results.py @@ -104,6 +104,7 @@ def __init__( self.is_error = is_error self.result_columns: list[str] = [] self.result_rows: list[tuple] = [] + self.result_table_info: dict[str, Any] | None = None self._content = content if is_error: self.add_class("error") @@ -156,6 +157,7 @@ def add_result_section( index: int, *, auto_collapse: bool = False, + table_info: dict[str, Any] | None = None, ) -> None: """Add a result section for a statement result.""" from sqlit.domains.query.app.query_service import QueryResult @@ -169,6 +171,8 @@ def add_result_section( # SELECT result - build a DataTable result_columns, result_rows = self._get_result_table_data(stmt_result.result) content = self._build_result_table_from_rows(result_columns, result_rows, index) + if table_info is not None: + content.result_table_info = table_info else: # Non-query result (INSERT/UPDATE/DELETE) content = NonQueryDisplay(stmt_result.result.rows_affected) @@ -186,6 +190,7 @@ def add_result_section( ) section.result_columns = result_columns section.result_rows = result_rows + section.result_table_info = table_info self.mount(section) self._section_count += 1 diff --git a/tests/unit/test_auto_reconnect_after_driver_install.py b/tests/unit/test_auto_reconnect_after_driver_install.py new file mode 100644 index 0000000..7e4ad96 --- /dev/null +++ b/tests/unit/test_auto_reconnect_after_driver_install.py @@ -0,0 +1,100 @@ +"""Test auto-reconnect after driver installation restart.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +from sqlit.domains.connections.domain.config import ConnectionConfig + + +class TestAutoReconnectAfterDriverInstall: + """Test that app auto-connects after driver install restart.""" + + def test_pending_connection_cache_written_on_missing_driver(self): + """ + When user tries to connect but driver is missing, + the connection name should be cached for auto-reconnect after restart. + """ + from sqlit.domains.connections.ui.restart_cache import ( + get_restart_cache_path, + write_pending_connection_cache, + ) + + config = ConnectionConfig(name="my-mssql-server", db_type="mssql") + + # Write the pending connection cache + write_pending_connection_cache(config.name) + + # Verify cache was written + cache_path = get_restart_cache_path() + assert cache_path.exists() + + payload = json.loads(cache_path.read_text()) + assert payload["version"] == 2 + assert payload["type"] == "pending_connection" + assert payload["connection_name"] == "my-mssql-server" + + # Cleanup + cache_path.unlink(missing_ok=True) + + def test_startup_reads_pending_connection_and_connects(self): + """ + On startup, if pending_connection cache exists, + app should auto-connect to that connection. + """ + from sqlit.domains.connections.ui.restart_cache import ( + get_restart_cache_path, + write_pending_connection_cache, + ) + from sqlit.domains.shell.app.startup_flow import maybe_auto_connect_pending + + # Setup: Write pending connection cache + write_pending_connection_cache("my-mssql-server") + + # Mock app with the saved connection + mock_app = MagicMock() + saved_config = ConnectionConfig(name="my-mssql-server", db_type="mssql") + mock_app.connections = [saved_config] + mock_app.connect_to_server = MagicMock() + mock_app.call_after_refresh = MagicMock() + + # Call the startup function + result = maybe_auto_connect_pending(mock_app) + + # Should have scheduled a connection via call_after_refresh + assert result is True + mock_app.call_after_refresh.assert_called_once() + + # Execute the callback to verify it calls connect_to_server + callback = mock_app.call_after_refresh.call_args[0][0] + callback() + mock_app.connect_to_server.assert_called_once_with(saved_config) + + # Cache should be cleared + assert not get_restart_cache_path().exists() + + def test_startup_ignores_missing_connection(self): + """ + If the cached connection no longer exists, don't crash. + """ + from sqlit.domains.connections.ui.restart_cache import ( + get_restart_cache_path, + write_pending_connection_cache, + ) + from sqlit.domains.shell.app.startup_flow import maybe_auto_connect_pending + + write_pending_connection_cache("deleted-connection") + + mock_app = MagicMock() + mock_app.connections = [] # No connections + mock_app.connect_to_server = MagicMock() + + result = maybe_auto_connect_pending(mock_app) + + # Should return False (no connection made) + assert result is False + mock_app.connect_to_server.assert_not_called() + + # Cache should still be cleared + assert not get_restart_cache_path().exists() diff --git a/tests/unit/test_connection_picker_refresh.py b/tests/unit/test_connection_picker_refresh.py new file mode 100644 index 0000000..03ba272 --- /dev/null +++ b/tests/unit/test_connection_picker_refresh.py @@ -0,0 +1,117 @@ +"""Test that tree refresh reloads connections from disk.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + + +class TestTreeRefresh: + """Test that pressing 'f' in explorer reloads saved connections.""" + + def test_action_refresh_tree_reloads_connections(self): + """ + Bug: action_refresh_tree didn't reload saved connections from store. + Fix: Now it calls connection_store.load_all() and updates self.connections. + """ + from sqlit.domains.connections.domain.config import ConnectionConfig + + # Create a mock host with required attributes + mock_host = MagicMock() + + # Set up initial state + initial_conn = ConnectionConfig(name="existing", db_type="sqlite") + mock_host.connections = [initial_conn] + + # Mock the services and store + mock_store = MagicMock() + new_conn = ConnectionConfig(name="new-cli-conn", db_type="postgresql") + mock_store.load_all.return_value = [initial_conn, new_conn] + + mock_services = MagicMock() + mock_services.connection_store = mock_store + mock_host.services = mock_services + + # Mock other methods + mock_host._get_object_cache.return_value = MagicMock() + mock_host._schema_cache = {"columns": {}} + mock_host._loading_nodes = set() + mock_host._schema_service = None + mock_host.refresh_tree = MagicMock() + + # Import and call the mixin method directly + from sqlit.domains.explorer.ui.mixins.tree import TreeMixin + + # Before: 1 connection + assert len(mock_host.connections) == 1 + + # Call refresh using the mixin method bound to our mock + TreeMixin.action_refresh_tree(mock_host) + + # Verify store.load_all was called + mock_store.load_all.assert_called_once_with(load_credentials=False) + + # After: 2 connections (reloaded from store) + assert len(mock_host.connections) == 2 + assert mock_host.connections[1].name == "new-cli-conn" + + # Verify tree was rebuilt + mock_host.refresh_tree.assert_called_once() + + def test_action_refresh_tree_handles_store_error(self): + """Test that refresh handles store errors gracefully.""" + from sqlit.domains.connections.domain.config import ConnectionConfig + + mock_host = MagicMock() + mock_host.connections = [ConnectionConfig(name="existing", db_type="sqlite")] + + mock_store = MagicMock() + mock_store.load_all.side_effect = Exception("File not found") + + mock_services = MagicMock() + mock_services.connection_store = mock_store + mock_host.services = mock_services + + mock_host._get_object_cache.return_value = MagicMock() + mock_host._schema_cache = {"columns": {}} + mock_host._loading_nodes = set() + mock_host._schema_service = None + mock_host.refresh_tree = MagicMock() + + from sqlit.domains.explorer.ui.mixins.tree import TreeMixin + + # Should not raise, should keep existing connections + TreeMixin.action_refresh_tree(mock_host) + + # Connections should be unchanged + assert len(mock_host.connections) == 1 + assert mock_host.connections[0].name == "existing" + + # Tree should still be refreshed + mock_host.refresh_tree.assert_called_once() + + def test_action_refresh_tree_handles_missing_services(self): + """Test that refresh handles missing services gracefully.""" + from sqlit.domains.connections.domain.config import ConnectionConfig + + mock_host = MagicMock() + mock_host.connections = [ConnectionConfig(name="existing", db_type="sqlite")] + + # No services attribute + del mock_host.services + + mock_host._get_object_cache.return_value = MagicMock() + mock_host._schema_cache = {"columns": {}} + mock_host._loading_nodes = set() + mock_host._schema_service = None + mock_host.refresh_tree = MagicMock() + + from sqlit.domains.explorer.ui.mixins.tree import TreeMixin + + # Should not raise + TreeMixin.action_refresh_tree(mock_host) + + # Connections should be unchanged + assert len(mock_host.connections) == 1 + + # Tree should still be refreshed + mock_host.refresh_tree.assert_called_once()