Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sqlit/domains/connections/ui/connection_error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,19 @@ def can_handle(self, error: Exception) -> bool:

def handle(self, app: ConnectionErrorApp, error: Exception, config: ConnectionConfig) -> None:
from sqlit.domains.connections.providers.exceptions import MissingDriverError
from sqlit.shared.core.debug_events import emit_debug_event

from .restart_cache import write_pending_connection_cache
from .screens import PackageSetupScreen

# Save pending connection for auto-reconnect after driver install restart
if config.name:
write_pending_connection_cache(config.name)
emit_debug_event(
"driver_install.pending_connection_saved",
connection_name=config.name,
)

# No on_success callback - uses default "Restart to apply" behavior
app.push_screen(PackageSetupScreen(cast(MissingDriverError, error)))

Expand Down
15 changes: 15 additions & 0 deletions sqlit/domains/connections/ui/restart_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,18 @@ def clear_restart_cache() -> None:
get_restart_cache_path().unlink(missing_ok=True)
except Exception:
pass


def write_pending_connection_cache(connection_name: str) -> None:
"""Cache a pending connection name for auto-reconnect after driver install restart.

This is used when a user tries to connect to a server but the driver is missing.
After the driver is installed and the app restarts, it can auto-connect to this
connection.
"""
payload = {
"version": 2,
"type": "pending_connection",
"connection_name": connection_name,
}
write_restart_cache(payload)
14 changes: 14 additions & 0 deletions sqlit/domains/explorer/ui/mixins/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,18 @@ def action_refresh_tree(self: TreeMixinHost) -> None:
if hasattr(self, "_loading_nodes"):
self._loading_nodes.clear()
self._schema_service = None

# Reload saved connections from disk (in case added via CLI)
try:
services = getattr(self, "services", None)
if services:
store = getattr(services, "connection_store", None)
if store:
reloaded = store.load_all(load_credentials=False)
self.connections = reloaded
except Exception:
pass # Keep existing connections if reload fails

self.refresh_tree()
loader = getattr(self, "_load_schema_cache", None)
if callable(loader):
Expand Down Expand Up @@ -295,6 +307,8 @@ def action_select_table(self: TreeMixinHost) -> None:
"name": data.name,
"columns": [],
}
# Stash per-result metadata so results can resolve PKs without relying on globals.
self._pending_result_table_info = self._last_query_table
self._prime_last_query_table_columns(data.database, data.schema, data.name)

self.query_input.text = self.current_provider.dialect.build_select_query(
Expand Down
11 changes: 11 additions & 0 deletions sqlit/domains/query/ui/mixins/query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
executable_statements = [s for s in statements if not is_comment_only_statement(s)]
is_multi_statement = len(executable_statements) > 1

if is_multi_statement:
self._pending_result_table_info = None
elif executable_statements:
if getattr(self, "_pending_result_table_info", None) is None:
table_info = self._infer_result_table_info(executable_statements[0])
if table_info is not None:
self._pending_result_table_info = table_info
prime = getattr(self, "_prime_result_table_columns", None)
if callable(prime):
prime(table_info)

try:
start_time = time.perf_counter()
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS
Expand Down
59 changes: 58 additions & 1 deletion sqlit/domains/query/ui/mixins/query_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def _render_results_table_incremental(
escape: bool,
row_limit: int,
render_token: int,
table_info: dict[str, Any] | None = None,
) -> None:
initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
initial_rows = rows[:initial_count] if initial_count > 0 else []
Expand Down Expand Up @@ -269,9 +270,16 @@ def _render_results_table_incremental(
pass
if render_token == getattr(self, "_results_render_token", 0):
self._replace_results_table_with_data(columns, rows, escape=escape)
if table_info is not None:
try:
self.results_table.result_table_info = table_info
except Exception:
pass
return
if render_token != getattr(self, "_results_render_token", 0):
return
if table_info is not None:
table.result_table_info = table_info
self._replace_results_table_with_table(table)
self._schedule_results_render(
table,
Expand All @@ -291,6 +299,7 @@ async def _display_query_results(
self._last_result_columns = columns
self._last_result_rows = rows
self._last_result_row_count = row_count
table_info = getattr(self, "_pending_result_table_info", None)

# Switch to single result mode (in case we were showing stacked results)
self._show_single_result_mode()
Expand All @@ -304,12 +313,15 @@ async def _display_query_results(
escape=True,
row_limit=row_limit,
render_token=render_token,
table_info=table_info,
)
else:
render_rows = rows[:row_limit] if row_limit else []
table = self._build_results_table(columns, render_rows, escape=True)
if render_token != getattr(self, "_results_render_token", 0):
return
if table_info is not None:
table.result_table_info = table_info
self._replace_results_table_with_table(table)

time_str = format_duration_ms(elapsed_ms)
Expand All @@ -320,9 +332,15 @@ async def _display_query_results(
)
else:
self.notify(f"Query returned {row_count} rows in {time_str}")
if table_info is not None:
prime = getattr(self, "_prime_result_table_columns", None)
if callable(prime):
prime(table_info)
self._pending_result_table_info = None

def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
"""Display non-query result (called on main thread)."""
self._pending_result_table_info = None
self._last_result_columns = ["Result"]
self._last_result_rows = [(f"{affected} row(s) affected",)]
self._last_result_row_count = 1
Expand All @@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
"""Display query error (called on main thread)."""
self._cancel_results_render()
self._pending_result_table_info = None
# notify(severity="error") handles displaying the error in results via _show_error_in_results
self.notify(f"Query error: {error_message}", severity="error")

Expand All @@ -360,7 +379,17 @@ def _display_multi_statement_results(

# Add each result section
for i, stmt_result in enumerate(multi_result.results):
container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
table_info = self._infer_result_table_info(stmt_result.statement)
if table_info is not None:
prime = getattr(self, "_prime_result_table_columns", None)
if callable(prime):
prime(table_info)
container.add_result_section(
stmt_result,
i,
auto_collapse=auto_collapse,
table_info=table_info,
)

# Show the stacked results container, hide single result table
self._show_stacked_results_mode()
Expand All @@ -378,6 +407,7 @@ def _display_multi_statement_results(
)
else:
self.notify(f"Executed {total} statements in {time_str}")
self._pending_result_table_info = None

def _get_stacked_results_container(self: QueryMixinHost) -> Any:
"""Get the stacked results container."""
Expand Down Expand Up @@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
stacked.remove_class("active")
except Exception:
pass

def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
"""Best-effort inference of a single source table for query results."""
from sqlit.domains.query.completion import extract_table_refs

refs = extract_table_refs(sql)
if len(refs) != 1:
return None
ref = refs[0]
schema = ref.schema
name = ref.name
database = None
table_metadata = getattr(self, "_table_metadata", {}) or {}
key_candidates = [name.lower()]
if schema:
key_candidates.insert(0, f"{schema}.{name}".lower())
for key in key_candidates:
metadata = table_metadata.get(key)
if metadata:
schema, name, database = metadata
break
return {
"database": database,
"schema": schema,
"name": name,
"columns": [],
}
Loading
Loading