diff --git a/sqlit/domains/connections/ui/connection_error_handlers.py b/sqlit/domains/connections/ui/connection_error_handlers.py
index 2268a66..12e9262 100644
--- a/sqlit/domains/connections/ui/connection_error_handlers.py
+++ b/sqlit/domains/connections/ui/connection_error_handlers.py
@@ -36,9 +36,19 @@ def can_handle(self, error: Exception) -> bool:
def handle(self, app: ConnectionErrorApp, error: Exception, config: ConnectionConfig) -> None:
from sqlit.domains.connections.providers.exceptions import MissingDriverError
+ from sqlit.shared.core.debug_events import emit_debug_event
+ from .restart_cache import write_pending_connection_cache
from .screens import PackageSetupScreen
+ # Save pending connection for auto-reconnect after driver install restart
+ if config.name:
+ write_pending_connection_cache(config.name)
+ emit_debug_event(
+ "driver_install.pending_connection_saved",
+ connection_name=config.name,
+ )
+
# No on_success callback - uses default "Restart to apply" behavior
app.push_screen(PackageSetupScreen(cast(MissingDriverError, error)))
diff --git a/sqlit/domains/connections/ui/restart_cache.py b/sqlit/domains/connections/ui/restart_cache.py
index 7f76c4b..e2daa69 100644
--- a/sqlit/domains/connections/ui/restart_cache.py
+++ b/sqlit/domains/connections/ui/restart_cache.py
@@ -28,3 +28,18 @@ def clear_restart_cache() -> None:
get_restart_cache_path().unlink(missing_ok=True)
except Exception:
pass
+
+
+def write_pending_connection_cache(connection_name: str) -> None:
+ """Cache a pending connection name for auto-reconnect after driver install restart.
+
+ This is used when a user tries to connect to a server but the driver is missing.
+ After the driver is installed and the app restarts, it can auto-connect to this
+ connection.
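+
+    Example payload, as written below via write_restart_cache:
+
+        {"version": 2, "type": "pending_connection", "connection_name": "<name>"}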
+ """
+ payload = {
+ "version": 2,
+ "type": "pending_connection",
+ "connection_name": connection_name,
+ }
+ write_restart_cache(payload)
diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py
index 9491461..9e4a0f4 100644
--- a/sqlit/domains/explorer/ui/mixins/tree.py
+++ b/sqlit/domains/explorer/ui/mixins/tree.py
@@ -216,6 +216,18 @@ def action_refresh_tree(self: TreeMixinHost) -> None:
if hasattr(self, "_loading_nodes"):
self._loading_nodes.clear()
self._schema_service = None
+
+    # Reload saved connections from disk (e.g. ones added via the CLI since launch)
+ try:
+ services = getattr(self, "services", None)
+ if services:
+ store = getattr(services, "connection_store", None)
+ if store:
+ reloaded = store.load_all(load_credentials=False)
+ self.connections = reloaded
+ except Exception:
+ pass # Keep existing connections if reload fails
+
self.refresh_tree()
loader = getattr(self, "_load_schema_cache", None)
if callable(loader):
@@ -295,6 +307,8 @@ def action_select_table(self: TreeMixinHost) -> None:
"name": data.name,
"columns": [],
}
+ # Stash per-result metadata so results can resolve PKs without relying on globals.
+ self._pending_result_table_info = self._last_query_table
self._prime_last_query_table_columns(data.database, data.schema, data.name)
self.query_input.text = self.current_provider.dialect.build_select_query(
diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py
index 6b1c4e2..417a633 100644
--- a/sqlit/domains/query/ui/mixins/query_execution.py
+++ b/sqlit/domains/query/ui/mixins/query_execution.py
@@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
executable_statements = [s for s in statements if not is_comment_only_statement(s)]
is_multi_statement = len(executable_statements) > 1
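+    # Single statements may carry inferred table metadata for PK-aware editing;
+    # multi-statement batches re-infer table info per section in _display_multi_statement_results.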
+ if is_multi_statement:
+ self._pending_result_table_info = None
+ elif executable_statements:
+ if getattr(self, "_pending_result_table_info", None) is None:
+ table_info = self._infer_result_table_info(executable_statements[0])
+ if table_info is not None:
+ self._pending_result_table_info = table_info
+ prime = getattr(self, "_prime_result_table_columns", None)
+ if callable(prime):
+ prime(table_info)
+
try:
start_time = time.perf_counter()
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS
diff --git a/sqlit/domains/query/ui/mixins/query_results.py b/sqlit/domains/query/ui/mixins/query_results.py
index c40223f..6be6e8b 100644
--- a/sqlit/domains/query/ui/mixins/query_results.py
+++ b/sqlit/domains/query/ui/mixins/query_results.py
@@ -224,6 +224,7 @@ def _render_results_table_incremental(
escape: bool,
row_limit: int,
render_token: int,
+ table_info: dict[str, Any] | None = None,
) -> None:
initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
initial_rows = rows[:initial_count] if initial_count > 0 else []
@@ -269,9 +270,16 @@ def _render_results_table_incremental(
pass
if render_token == getattr(self, "_results_render_token", 0):
self._replace_results_table_with_data(columns, rows, escape=escape)
+ if table_info is not None:
+ try:
+ self.results_table.result_table_info = table_info
+ except Exception:
+ pass
return
if render_token != getattr(self, "_results_render_token", 0):
return
+ if table_info is not None:
+ table.result_table_info = table_info
self._replace_results_table_with_table(table)
self._schedule_results_render(
table,
@@ -291,6 +299,7 @@ async def _display_query_results(
self._last_result_columns = columns
self._last_result_rows = rows
self._last_result_row_count = row_count
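+        # Table metadata staged by query execution or explorer table selection, if any.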
+ table_info = getattr(self, "_pending_result_table_info", None)
# Switch to single result mode (in case we were showing stacked results)
self._show_single_result_mode()
@@ -304,12 +313,15 @@ async def _display_query_results(
escape=True,
row_limit=row_limit,
render_token=render_token,
+ table_info=table_info,
)
else:
render_rows = rows[:row_limit] if row_limit else []
table = self._build_results_table(columns, render_rows, escape=True)
if render_token != getattr(self, "_results_render_token", 0):
return
+ if table_info is not None:
+ table.result_table_info = table_info
self._replace_results_table_with_table(table)
time_str = format_duration_ms(elapsed_ms)
@@ -320,9 +332,15 @@ async def _display_query_results(
)
else:
self.notify(f"Query returned {row_count} rows in {time_str}")
+ if table_info is not None:
+ prime = getattr(self, "_prime_result_table_columns", None)
+ if callable(prime):
+ prime(table_info)
+ self._pending_result_table_info = None
def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
"""Display non-query result (called on main thread)."""
+ self._pending_result_table_info = None
self._last_result_columns = ["Result"]
self._last_result_rows = [(f"{affected} row(s) affected",)]
self._last_result_row_count = 1
@@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
"""Display query error (called on main thread)."""
self._cancel_results_render()
+ self._pending_result_table_info = None
# notify(severity="error") handles displaying the error in results via _show_error_in_results
self.notify(f"Query error: {error_message}", severity="error")
@@ -360,7 +379,17 @@ def _display_multi_statement_results(
# Add each result section
for i, stmt_result in enumerate(multi_result.results):
- container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
+ table_info = self._infer_result_table_info(stmt_result.statement)
+ if table_info is not None:
+ prime = getattr(self, "_prime_result_table_columns", None)
+ if callable(prime):
+ prime(table_info)
+ container.add_result_section(
+ stmt_result,
+ i,
+ auto_collapse=auto_collapse,
+ table_info=table_info,
+ )
# Show the stacked results container, hide single result table
self._show_stacked_results_mode()
@@ -378,6 +407,7 @@ def _display_multi_statement_results(
)
else:
self.notify(f"Executed {total} statements in {time_str}")
+ self._pending_result_table_info = None
def _get_stacked_results_container(self: QueryMixinHost) -> Any:
"""Get the stacked results container."""
@@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
stacked.remove_class("active")
except Exception:
pass
+
+ def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
+ """Best-effort inference of a single source table for query results."""
+ from sqlit.domains.query.completion import extract_table_refs
+
+ refs = extract_table_refs(sql)
+ if len(refs) != 1:
+ return None
+ ref = refs[0]
+ schema = ref.schema
+ name = ref.name
+ database = None
+ table_metadata = getattr(self, "_table_metadata", {}) or {}
+ key_candidates = [name.lower()]
+ if schema:
+ key_candidates.insert(0, f"{schema}.{name}".lower())
+ for key in key_candidates:
+ metadata = table_metadata.get(key)
+ if metadata:
+ schema, name, database = metadata
+ break
+ return {
+ "database": database,
+ "schema": schema,
+ "name": name,
+ "columns": [],
+ }
diff --git a/sqlit/domains/results/ui/mixins/results.py b/sqlit/domains/results/ui/mixins/results.py
index 2f3a434..d875029 100644
--- a/sqlit/domains/results/ui/mixins/results.py
+++ b/sqlit/domains/results/ui/mixins/results.py
@@ -8,6 +8,8 @@
from sqlit.shared.ui.protocols import ResultsMixinHost
from sqlit.shared.ui.widgets import SqlitDataTable
+MIN_TIMER_DELAY_S = 0.001
+
class ResultsMixin:
"""Mixin providing results handling functionality."""
@@ -19,6 +21,122 @@ class ResultsMixin:
_tooltip_showing: bool = False
_tooltip_timer: Any | None = None
+ def _schedule_results_timer(self: ResultsMixinHost, delay_s: float, callback: Any) -> Any | None:
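+        """Schedule callback to run after delay_s, degrading gracefully.
+
+        Prefers set_timer when available; falls back to call_later, then to a
+        direct synchronous call when neither is available.
+        """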
+ set_timer = getattr(self, "set_timer", None)
+ if callable(set_timer):
+ return set_timer(delay_s, callback)
+ call_later = getattr(self, "call_later", None)
+ if callable(call_later):
+ try:
+ call_later(callback)
+ return None
+ except Exception:
+ pass
+ try:
+ callback()
+ except Exception:
+ pass
+ return None
+
+ def _apply_result_table_columns(
+ self: ResultsMixinHost,
+ table_info: dict[str, Any],
+ token: int,
+ columns: list[Any],
+ ) -> None:
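+        """Attach fetched columns to table_info unless a newer fetch superseded this token."""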
+ if table_info.get("_columns_token") != token:
+ return
+ table_info["columns"] = columns
+
+ def _prime_result_table_columns(self: ResultsMixinHost, table_info: dict[str, Any] | None) -> None:
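+        """Fetch column metadata (including PK flags) for table_info in the background.
+
+        Uses the process-worker client when enabled, otherwise the schema service;
+        results are applied through a token check so stale fetches are dropped.
+        """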
+ if not table_info:
+ return
+ if table_info.get("columns"):
+ return
+ name = table_info.get("name")
+ if not name:
+ return
+ database = table_info.get("database")
+ schema = table_info.get("schema")
+ token = int(table_info.get("_columns_token", 0)) + 1
+ table_info["_columns_token"] = token
+
+ async def work_async() -> None:
+ import asyncio
+
+ columns: list[Any] = []
+ try:
+ runtime = getattr(self.services, "runtime", None)
+ use_worker = bool(getattr(runtime, "process_worker", False)) and not bool(
+ getattr(getattr(runtime, "mock", None), "enabled", False)
+ )
+ client = None
+ if use_worker and hasattr(self, "_get_process_worker_client_async"):
+ client = await self._get_process_worker_client_async() # type: ignore[attr-defined]
+
+ if client is not None and hasattr(client, "list_columns") and self.current_config is not None:
+ outcome = await asyncio.to_thread(
+ client.list_columns,
+ config=self.current_config,
+ database=database,
+ schema=schema,
+ name=name,
+ )
+ if getattr(outcome, "cancelled", False):
+ return
+ error = getattr(outcome, "error", None)
+ if error:
+ raise RuntimeError(error)
+ columns = outcome.columns or []
+ else:
+                schema_service = getattr(self, "_get_schema_service", None)
+                if callable(schema_service):
+                    service = schema_service()
+ if service:
+ columns = await asyncio.to_thread(
+ service.list_columns,
+ database,
+ schema,
+ name,
+ )
+ except Exception:
+ columns = []
+
+ self._schedule_results_timer(
+ MIN_TIMER_DELAY_S,
+ lambda: self._apply_result_table_columns(table_info, token, columns),
+ )
+
+ self.run_worker(work_async(), name=f"prime-result-columns-{name}", exclusive=False)
+
+ def _normalize_column_name(self: ResultsMixinHost, name: str) -> str:
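+        """Normalize a column identifier for comparison: strip one layer of
+        double-quote/backtick/bracket quoting, drop an unquoted table qualifier,
+        and lowercase ('"Id"' -> 'id', '[UserId]' -> 'userid', 'u.name' -> 'name').
+        """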
+ trimmed = name.strip()
+ if len(trimmed) >= 2:
+ if trimmed[0] == trimmed[-1] and trimmed[0] in ("\"", "`"):
+ trimmed = trimmed[1:-1]
+ elif trimmed[0] == "[" and trimmed[-1] == "]":
+ trimmed = trimmed[1:-1]
+ if "." in trimmed and not any(q in trimmed for q in ("\"", "`", "[")):
+ trimmed = trimmed.split(".")[-1]
+ return trimmed.lower()
+
+ def _get_active_results_table_info(
+ self: ResultsMixinHost,
+ table: SqlitDataTable | None,
+ stacked: bool,
+ ) -> dict[str, Any] | None:
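+        """Resolve metadata for a results table: prefer the stacked section's
+        result_table_info, then the table widget's own, then fall back to the
+        legacy _last_query_table.
+        """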
+ if not table:
+ return None
+ if stacked:
+ section = self._find_results_section(table)
+ table_info = getattr(section, "result_table_info", None)
+ if table_info:
+ return table_info
+ table_info = getattr(table, "result_table_info", None)
+ if table_info:
+ return table_info
+ return getattr(self, "_last_query_table", None)
+
def _copy_text(self: ResultsMixinHost, text: str) -> bool:
"""Copy text to clipboard if possible, otherwise store internally."""
self._internal_clipboard = text
@@ -610,21 +728,20 @@ def sql_value(v: object) -> str:
# Get table name and primary key columns
table_name = "
"
pk_column_names: set[str] = set()
-
- if hasattr(self, "_last_query_table") and self._last_query_table:
- table_info = self._last_query_table
- table_name = table_info["name"]
+ table_info = self._get_active_results_table_info(table, _stacked)
+ if table_info:
+ table_name = table_info.get("name", table_name)
# Get PK columns from column info
for col in table_info.get("columns", []):
if col.is_primary_key:
- pk_column_names.add(col.name)
+ pk_column_names.add(self._normalize_column_name(col.name))
# Build WHERE clause - prefer PK columns, fall back to all columns
where_parts = []
for i, col in enumerate(columns):
if i < len(row_values):
# If we have PK info, only use PK columns; otherwise use all columns
- if pk_column_names and col not in pk_column_names:
+ if pk_column_names and self._normalize_column_name(col) not in pk_column_names:
continue
val = row_values[i]
if val is None:
@@ -685,9 +802,10 @@ def action_edit_cell(self: ResultsMixinHost) -> None:
column_name = columns[cursor_col]
# Check if this column is a primary key - don't allow editing PKs
- if hasattr(self, "_last_query_table") and self._last_query_table:
- for col in self._last_query_table.get("columns", []):
- if col.name == column_name and col.is_primary_key:
+ table_info = self._get_active_results_table_info(table, _stacked)
+ if table_info:
+ for col in table_info.get("columns", []):
+ if col.is_primary_key and self._normalize_column_name(col.name) == self._normalize_column_name(column_name):
self.notify("Cannot edit primary key column", severity="warning")
return
@@ -705,21 +823,19 @@ def sql_value(v: object) -> str:
# Get table name and primary key columns
table_name = ""
pk_column_names: set[str] = set()
-
- if hasattr(self, "_last_query_table") and self._last_query_table:
- table_info = self._last_query_table
- table_name = table_info["name"]
+ if table_info:
+ table_name = table_info.get("name", table_name)
# Get PK columns from column info
for col in table_info.get("columns", []):
if col.is_primary_key:
- pk_column_names.add(col.name)
+ pk_column_names.add(self._normalize_column_name(col.name))
# Build WHERE clause - prefer PK columns, fall back to all columns
where_parts = []
for i, col in enumerate(columns):
if i < len(row_values):
# If we have PK info, only use PK columns; otherwise use all columns
- if pk_column_names and col not in pk_column_names:
+ if pk_column_names and self._normalize_column_name(col) not in pk_column_names:
continue
val = row_values[i]
if val is None:
diff --git a/sqlit/domains/shell/app/commands/debug.py b/sqlit/domains/shell/app/commands/debug.py
index 7134c76..1cecf7d 100644
--- a/sqlit/domains/shell/app/commands/debug.py
+++ b/sqlit/domains/shell/app/commands/debug.py
@@ -41,6 +41,16 @@ def _set_debug_enabled(app: Any, enabled: bool) -> None:
else:
app._debug_events_enabled = bool(enabled)
+ # Persist the setting across sessions
+ try:
+ services = getattr(app, "services", None)
+ if services:
+ store = getattr(services, "settings_store", None)
+ if store:
+ store.set("debug_events_enabled", enabled)
+ except Exception:
+ pass
+
path = getattr(app, "_debug_event_log_path", None)
suffix = f" (log: {path})" if path else ""
state = "enabled" if enabled else "disabled"
diff --git a/sqlit/domains/shell/app/main.py b/sqlit/domains/shell/app/main.py
index 5f9fbec..66edcde 100644
--- a/sqlit/domains/shell/app/main.py
+++ b/sqlit/domains/shell/app/main.py
@@ -197,6 +197,7 @@ def __init__(
self._columns_loading: set[str] = set()
self._state_machine = UIStateMachine()
self._last_query_table: dict[str, Any] | None = None
+ self._pending_result_table_info: dict[str, Any] | None = None
self._query_target_database: str | None = None # Target DB for auto-generated queries
self._restart_requested: bool = False
# Idle scheduler for background work
diff --git a/sqlit/domains/shell/app/startup_flow.py b/sqlit/domains/shell/app/startup_flow.py
index 3d3bd62..84c482a 100644
--- a/sqlit/domains/shell/app/startup_flow.py
+++ b/sqlit/domains/shell/app/startup_flow.py
@@ -36,6 +36,10 @@ def run_on_mount(app: AppProtocol) -> None:
app._startup_stamp("settings_loaded")
app._expanded_paths = set(settings.get("expanded_nodes", []))
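+    # Restore persisted debug-event logging (toggled and saved by the debug command)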
+ if settings.get("debug_events_enabled"):
+ setter = getattr(app, "_set_debug_events_enabled", None)
+ if callable(setter):
+ setter(True)
if "process_worker" in settings:
app.services.runtime.process_worker = bool(settings.get("process_worker"))
if "process_worker_warm_on_idle" in settings:
@@ -83,6 +87,9 @@ def run_on_mount(app: AppProtocol) -> None:
app.object_tree.cursor_line = 0
app._update_section_labels()
maybe_restore_connection_screen(app)
+ # Auto-connect to pending connection after driver install (if not already connecting)
+ if app._startup_connect_config is None:
+ maybe_auto_connect_pending(app)
app._startup_stamp("restore_checked")
if app._debug_mode:
app.call_after_refresh(app._record_launch_ms)
@@ -224,6 +231,83 @@ def _get_restart_cache_path() -> Path:
return Path(tempfile.gettempdir()) / "sqlit-driver-install-restore.json"
+def maybe_auto_connect_pending(app: AppProtocol) -> bool:
+ """Auto-connect to a pending connection after driver install restart.
+
+ Returns True if a connection was initiated, False otherwise.
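+
+    Expects the version-2 payload written by write_pending_connection_cache:
+
+        {"version": 2, "type": "pending_connection", "connection_name": "<name>"}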
+ """
+    from sqlit.domains.connections.ui.restart_cache import (
+        clear_restart_cache,
+        get_restart_cache_path,
+    )
+    from sqlit.shared.core.debug_events import emit_debug_event
+
+ cache_path = get_restart_cache_path()
+ emit_debug_event(
+ "startup.pending_connection_check",
+ cache_path=str(cache_path),
+ exists=cache_path.exists(),
+ )
+ if not cache_path.exists():
+ return False
+
+    # Read once and guard the read: the file could vanish between the check and here.
+    try:
+        raw = cache_path.read_text(encoding="utf-8")
+        emit_debug_event("startup.pending_connection_found", contents=raw)
+        payload = json.loads(raw)
+    except Exception as e:
+        emit_debug_event("startup.pending_connection_parse_error", error=str(e))
+        clear_restart_cache()
+        return False
+
+ # Always clear cache after reading
+ clear_restart_cache()
+
+ # Check for version 2 pending_connection type
+ if not isinstance(payload, dict):
+ emit_debug_event("startup.pending_connection_invalid", reason="not a dict")
+ return False
+ if payload.get("version") != 2:
+ emit_debug_event("startup.pending_connection_invalid", reason="wrong version", version=payload.get("version"))
+ return False
+ if payload.get("type") != "pending_connection":
+ emit_debug_event("startup.pending_connection_invalid", reason="wrong type", type=payload.get("type"))
+ return False
+
+ connection_name = payload.get("connection_name")
+ if not connection_name:
+ emit_debug_event("startup.pending_connection_invalid", reason="no connection_name")
+ return False
+
+ emit_debug_event(
+ "startup.pending_connection_lookup",
+ connection_name=connection_name,
+ available_connections=[getattr(c, "name", None) for c in app.connections],
+ )
+
+ # Find the connection by name
+ config = next(
+ (c for c in app.connections if getattr(c, "name", None) == connection_name),
+ None,
+ )
+ if config is None:
+ emit_debug_event("startup.pending_connection_not_found", connection_name=connection_name)
+ return False
+
+ emit_debug_event("startup.pending_connection_connecting", connection_name=connection_name)
+
+ # Auto-connect after refresh (same pattern as startup_connect_config)
+ def _connect_pending() -> None:
+ app.connect_to_server(config)
+
+ app.call_after_refresh(_connect_pending)
+ return True
+
+
def maybe_restore_connection_screen(app: AppProtocol) -> None:
"""Restore an in-progress connection form after a driver-install restart."""
cache_path = _get_restart_cache_path()
@@ -239,14 +323,16 @@ def maybe_restore_connection_screen(app: AppProtocol) -> None:
pass
return
+    # Only handle version 1 (connection form restore); version 2 is consumed by maybe_auto_connect_pending
+ if not isinstance(payload, dict) or payload.get("version") != 1:
+ return
+
+ # Clear cache only for version 1
try:
cache_path.unlink(missing_ok=True)
except Exception:
pass
- if not isinstance(payload, dict) or payload.get("version") != 1:
- return
-
values = payload.get("values")
if not isinstance(values, dict):
return
diff --git a/sqlit/shared/ui/protocols/results.py b/sqlit/shared/ui/protocols/results.py
index d6e3f5e..2124b1a 100644
--- a/sqlit/shared/ui/protocols/results.py
+++ b/sqlit/shared/ui/protocols/results.py
@@ -16,6 +16,7 @@ class ResultsStateProtocol(Protocol):
_last_result_row_count: int
_internal_clipboard: str
_last_query_table: dict[str, Any] | None
+ _pending_result_table_info: dict[str, Any] | None
_results_table_counter: int
_results_filter_visible: bool
_results_filter_text: str
diff --git a/sqlit/shared/ui/widgets_stacked_results.py b/sqlit/shared/ui/widgets_stacked_results.py
index e61c819..015cc45 100644
--- a/sqlit/shared/ui/widgets_stacked_results.py
+++ b/sqlit/shared/ui/widgets_stacked_results.py
@@ -104,6 +104,7 @@ def __init__(
self.is_error = is_error
self.result_columns: list[str] = []
self.result_rows: list[tuple] = []
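+        # Source-table metadata (schema/name/PK columns) enabling PK-aware row editing.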
+ self.result_table_info: dict[str, Any] | None = None
self._content = content
if is_error:
self.add_class("error")
@@ -156,6 +157,7 @@ def add_result_section(
index: int,
*,
auto_collapse: bool = False,
+ table_info: dict[str, Any] | None = None,
) -> None:
"""Add a result section for a statement result."""
from sqlit.domains.query.app.query_service import QueryResult
@@ -169,6 +171,8 @@ def add_result_section(
# SELECT result - build a DataTable
result_columns, result_rows = self._get_result_table_data(stmt_result.result)
content = self._build_result_table_from_rows(result_columns, result_rows, index)
+ if table_info is not None:
+ content.result_table_info = table_info
else:
# Non-query result (INSERT/UPDATE/DELETE)
content = NonQueryDisplay(stmt_result.result.rows_affected)
@@ -186,6 +190,7 @@ def add_result_section(
)
section.result_columns = result_columns
section.result_rows = result_rows
+ section.result_table_info = table_info
self.mount(section)
self._section_count += 1
diff --git a/tests/unit/test_auto_reconnect_after_driver_install.py b/tests/unit/test_auto_reconnect_after_driver_install.py
new file mode 100644
index 0000000..7e4ad96
--- /dev/null
+++ b/tests/unit/test_auto_reconnect_after_driver_install.py
@@ -0,0 +1,100 @@
+"""Test auto-reconnect after driver installation restart."""
+
+from __future__ import annotations
+
+import json
+from unittest.mock import MagicMock
+
+from sqlit.domains.connections.domain.config import ConnectionConfig
+
+
+class TestAutoReconnectAfterDriverInstall:
+ """Test that app auto-connects after driver install restart."""
+
+ def test_pending_connection_cache_written_on_missing_driver(self):
+ """
+        When a user tries to connect but the driver is missing,
+        the connection name should be cached for auto-reconnect after restart.
+ """
+ from sqlit.domains.connections.ui.restart_cache import (
+ get_restart_cache_path,
+ write_pending_connection_cache,
+ )
+
+ config = ConnectionConfig(name="my-mssql-server", db_type="mssql")
+
+ # Write the pending connection cache
+ write_pending_connection_cache(config.name)
+
+ # Verify cache was written
+ cache_path = get_restart_cache_path()
+ assert cache_path.exists()
+
+ payload = json.loads(cache_path.read_text())
+ assert payload["version"] == 2
+ assert payload["type"] == "pending_connection"
+ assert payload["connection_name"] == "my-mssql-server"
+
+ # Cleanup
+ cache_path.unlink(missing_ok=True)
+
+ def test_startup_reads_pending_connection_and_connects(self):
+ """
+        On startup, if a pending_connection cache exists,
+        the app should auto-connect to that connection.
+ """
+ from sqlit.domains.connections.ui.restart_cache import (
+ get_restart_cache_path,
+ write_pending_connection_cache,
+ )
+ from sqlit.domains.shell.app.startup_flow import maybe_auto_connect_pending
+
+ # Setup: Write pending connection cache
+ write_pending_connection_cache("my-mssql-server")
+
+ # Mock app with the saved connection
+ mock_app = MagicMock()
+ saved_config = ConnectionConfig(name="my-mssql-server", db_type="mssql")
+ mock_app.connections = [saved_config]
+ mock_app.connect_to_server = MagicMock()
+ mock_app.call_after_refresh = MagicMock()
+
+ # Call the startup function
+ result = maybe_auto_connect_pending(mock_app)
+
+ # Should have scheduled a connection via call_after_refresh
+ assert result is True
+ mock_app.call_after_refresh.assert_called_once()
+
+ # Execute the callback to verify it calls connect_to_server
+ callback = mock_app.call_after_refresh.call_args[0][0]
+ callback()
+ mock_app.connect_to_server.assert_called_once_with(saved_config)
+
+ # Cache should be cleared
+ assert not get_restart_cache_path().exists()
+
+ def test_startup_ignores_missing_connection(self):
+ """
+ If the cached connection no longer exists, don't crash.
+ """
+ from sqlit.domains.connections.ui.restart_cache import (
+ get_restart_cache_path,
+ write_pending_connection_cache,
+ )
+ from sqlit.domains.shell.app.startup_flow import maybe_auto_connect_pending
+
+ write_pending_connection_cache("deleted-connection")
+
+ mock_app = MagicMock()
+ mock_app.connections = [] # No connections
+ mock_app.connect_to_server = MagicMock()
+
+ result = maybe_auto_connect_pending(mock_app)
+
+ # Should return False (no connection made)
+ assert result is False
+ mock_app.connect_to_server.assert_not_called()
+
+ # Cache should still be cleared
+ assert not get_restart_cache_path().exists()
diff --git a/tests/unit/test_connection_picker_refresh.py b/tests/unit/test_connection_picker_refresh.py
new file mode 100644
index 0000000..03ba272
--- /dev/null
+++ b/tests/unit/test_connection_picker_refresh.py
@@ -0,0 +1,117 @@
+"""Test that tree refresh reloads connections from disk."""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+
+class TestTreeRefresh:
+ """Test that pressing 'f' in explorer reloads saved connections."""
+
+ def test_action_refresh_tree_reloads_connections(self):
+ """
+        Bug: action_refresh_tree didn't reload saved connections from the store.
+        Fix: it now calls connection_store.load_all() and updates self.connections.
+ """
+ from sqlit.domains.connections.domain.config import ConnectionConfig
+
+ # Create a mock host with required attributes
+ mock_host = MagicMock()
+
+ # Set up initial state
+ initial_conn = ConnectionConfig(name="existing", db_type="sqlite")
+ mock_host.connections = [initial_conn]
+
+ # Mock the services and store
+ mock_store = MagicMock()
+ new_conn = ConnectionConfig(name="new-cli-conn", db_type="postgresql")
+ mock_store.load_all.return_value = [initial_conn, new_conn]
+
+ mock_services = MagicMock()
+ mock_services.connection_store = mock_store
+ mock_host.services = mock_services
+
+ # Mock other methods
+ mock_host._get_object_cache.return_value = MagicMock()
+ mock_host._schema_cache = {"columns": {}}
+ mock_host._loading_nodes = set()
+ mock_host._schema_service = None
+ mock_host.refresh_tree = MagicMock()
+
+ # Import and call the mixin method directly
+ from sqlit.domains.explorer.ui.mixins.tree import TreeMixin
+
+ # Before: 1 connection
+ assert len(mock_host.connections) == 1
+
+ # Call refresh using the mixin method bound to our mock
+ TreeMixin.action_refresh_tree(mock_host)
+
+ # Verify store.load_all was called
+ mock_store.load_all.assert_called_once_with(load_credentials=False)
+
+ # After: 2 connections (reloaded from store)
+ assert len(mock_host.connections) == 2
+ assert mock_host.connections[1].name == "new-cli-conn"
+
+ # Verify tree was rebuilt
+ mock_host.refresh_tree.assert_called_once()
+
+ def test_action_refresh_tree_handles_store_error(self):
+ """Test that refresh handles store errors gracefully."""
+ from sqlit.domains.connections.domain.config import ConnectionConfig
+
+ mock_host = MagicMock()
+ mock_host.connections = [ConnectionConfig(name="existing", db_type="sqlite")]
+
+ mock_store = MagicMock()
+ mock_store.load_all.side_effect = Exception("File not found")
+
+ mock_services = MagicMock()
+ mock_services.connection_store = mock_store
+ mock_host.services = mock_services
+
+ mock_host._get_object_cache.return_value = MagicMock()
+ mock_host._schema_cache = {"columns": {}}
+ mock_host._loading_nodes = set()
+ mock_host._schema_service = None
+ mock_host.refresh_tree = MagicMock()
+
+ from sqlit.domains.explorer.ui.mixins.tree import TreeMixin
+
+ # Should not raise, should keep existing connections
+ TreeMixin.action_refresh_tree(mock_host)
+
+ # Connections should be unchanged
+ assert len(mock_host.connections) == 1
+ assert mock_host.connections[0].name == "existing"
+
+ # Tree should still be refreshed
+ mock_host.refresh_tree.assert_called_once()
+
+ def test_action_refresh_tree_handles_missing_services(self):
+ """Test that refresh handles missing services gracefully."""
+ from sqlit.domains.connections.domain.config import ConnectionConfig
+
+ mock_host = MagicMock()
+ mock_host.connections = [ConnectionConfig(name="existing", db_type="sqlite")]
+
+ # No services attribute
+ del mock_host.services
+
+ mock_host._get_object_cache.return_value = MagicMock()
+ mock_host._schema_cache = {"columns": {}}
+ mock_host._loading_nodes = set()
+ mock_host._schema_service = None
+ mock_host.refresh_tree = MagicMock()
+
+ from sqlit.domains.explorer.ui.mixins.tree import TreeMixin
+
+ # Should not raise
+ TreeMixin.action_refresh_tree(mock_host)
+
+ # Connections should be unchanged
+ assert len(mock_host.connections) == 1
+
+ # Tree should still be refreshed
+ mock_host.refresh_tree.assert_called_once()