From 5685979d72325a3f2c8e08f9db5650bf4f647806 Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:05:19 +0100 Subject: [PATCH 1/6] Auto-refresh explorer after schema-changing queries --- sqlit/domains/explorer/ui/mixins/tree.py | 14 ++- .../query/ui/mixins/query_execution.py | 61 +++++++++++++ tests/unit/test_autorefresh_schema_changes.py | 85 +++++++++++++++++++ 3 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_autorefresh_schema_changes.py diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index 9491461..f06091a 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -208,8 +208,7 @@ def on_tree_node_highlighted(self: TreeMixinHost, event: Tree.NodeHighlighted) - """Update footer when tree selection changes.""" self._update_footer_bindings() - def action_refresh_tree(self: TreeMixinHost) -> None: - """Refresh the explorer.""" + def _refresh_tree_common(self: TreeMixinHost, *, notify: bool) -> None: self._get_object_cache().clear() if hasattr(self, "_schema_cache") and "columns" in self._schema_cache: self._schema_cache["columns"] = {} @@ -241,7 +240,16 @@ def run_loader() -> None: ) else: self._schedule_timer(MIN_TIMER_DELAY_S, run_loader) - self.notify("Refreshed") + if notify: + self.notify("Refreshed") + + def _refresh_tree_after_query(self: TreeMixinHost) -> None: + """Refresh the explorer without user-facing notifications.""" + self._refresh_tree_common(notify=False) + + def action_refresh_tree(self: TreeMixinHost) -> None: + """Refresh the explorer.""" + self._refresh_tree_common(notify=True) def refresh_tree(self: TreeMixinHost) -> None: tree_builder.refresh_tree_chunked(self) diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py index 6b1c4e2..c8a7229 100644 --- a/sqlit/domains/query/ui/mixins/query_execution.py +++ b/sqlit/domains/query/ui/mixins/query_execution.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Any, Callable from sqlit.domains.explorer.ui.tree import db_switching as tree_db_switching @@ -21,6 +22,43 @@ from sqlit.domains.query.app.transaction import TransactionExecutor +_SCHEMA_CHANGE_KEYWORDS = ( + "CREATE", + "ALTER", + "DROP", + "TRUNCATE", + "RENAME", + "COMMENT", +) +_SCHEMA_CHANGE_RE = re.compile(r"\b(?:%s)\b" % "|".join(_SCHEMA_CHANGE_KEYWORDS), re.IGNORECASE) +_SINGLE_QUOTE_RE = re.compile(r"'[^']*'") +_DOUBLE_QUOTE_RE = re.compile(r'"[^"]*"') +_BACKTICK_RE = re.compile(r"`[^`]*`") +_BRACKET_RE = re.compile(r"\[[^\]]*\]") + + +def _strip_literals(sql: str) -> str: + cleaned = _SINGLE_QUOTE_RE.sub("''", sql) + cleaned = _DOUBLE_QUOTE_RE.sub('""', cleaned) + cleaned = _BACKTICK_RE.sub("``", cleaned) + cleaned = _BRACKET_RE.sub("[]", cleaned) + return cleaned + + +def _query_changes_schema(sql: str) -> bool: + if not sql or not sql.strip(): + return False + from sqlit.domains.query.app.multi_statement import split_statements + from sqlit.domains.query.editing.comments import strip_all_comments + + for statement in split_statements(sql): + cleaned = strip_all_comments(statement) + cleaned = _strip_literals(cleaned) + if _SCHEMA_CHANGE_RE.search(cleaned): + return True + return False + + class QueryExecutionMixin(ProcessWorkerLifecycleMixin): """Mixin providing query execution actions.""" @@ -243,6 +281,25 @@ def _stop_query_spinner(self: QueryMixinHost) -> None: 
except Exception: pass + def _maybe_refresh_explorer_after_query(self: QueryMixinHost, query: str) -> None: + if not _query_changes_schema(query): + return + + def run_refresh() -> None: + refresher = getattr(self, "_refresh_tree_after_query", None) + if callable(refresher): + refresher() + return + action_refresh = getattr(self, "action_refresh_tree", None) + if callable(action_refresh): + action_refresh() + + call_after_refresh = getattr(self, "call_after_refresh", None) + if callable(call_after_refresh): + call_after_refresh(run_refresh) + else: + run_refresh() + def _get_history_store(self: QueryMixinHost) -> Any: store = getattr(self, "_history_store", None) if store is not None: @@ -470,6 +527,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b ) else: self._display_non_query_result(result.rows_affected, elapsed_ms) + self._maybe_refresh_explorer_after_query(query) if keep_insert_mode: self._restore_insert_mode() return @@ -489,6 +547,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b except Exception: pass self._display_multi_statement_results(multi_result, elapsed_ms) + self._maybe_refresh_explorer_after_query(query) else: # Single statement - existing behavior result = await asyncio.to_thread( @@ -509,6 +568,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b ) else: self._display_non_query_result(result.rows_affected, elapsed_ms) + self._maybe_refresh_explorer_after_query(query) if keep_insert_mode: self._restore_insert_mode() @@ -581,6 +641,7 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None: else: self._display_non_query_result(result.rows_affected, elapsed_ms) self.notify("Query executed atomically (committed)", severity="information") + self._maybe_refresh_explorer_after_query(query) except Exception as e: self._display_query_error(f"Transaction rolled back: {e}") diff --git a/tests/unit/test_autorefresh_schema_changes.py b/tests/unit/test_autorefresh_schema_changes.py new file mode 100644 index 0000000..8927bf1 --- /dev/null +++ b/tests/unit/test_autorefresh_schema_changes.py @@ -0,0 +1,85 @@ +"""Tests for auto-refreshing explorer after schema-changing queries.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from sqlit.domains.query.app.query_service import NonQueryResult, QueryResult +from sqlit.domains.query.ui.mixins.query_execution import QueryExecutionMixin + + +class MockExecutor: + def __init__(self, result) -> None: + self._result = result + + def execute(self, query: str, max_rows: int | None = None): + _ = query, max_rows + return self._result + + +class MockHost(QueryExecutionMixin): + def __init__(self) -> None: + self.current_connection = object() + self.current_provider = MagicMock() + self.current_provider.apply_database_override = lambda config, db: config + self.current_provider.metadata = MagicMock(db_type="mock") + self.current_config = MagicMock() + self.current_config.tcp_endpoint = MagicMock(database="") + self.query_input = MagicMock() + self.services = MagicMock() + self.services.runtime.max_rows = 100 + self.services.runtime.query_alert_mode = 0 + self._query_worker = None + self._query_spinner = None + self._query_target_database = None + self._executor = MockExecutor(NonQueryResult(rows_affected=0)) + self._refresh_called = False + + def _use_process_worker(self, provider) -> bool: + _ = provider + return False + + def _get_transaction_executor(self, config, 
provider): + _ = config, provider + return self._executor + + async def _display_query_results(self, columns, rows, row_count, truncated, elapsed_ms) -> None: + _ = columns, rows, row_count, truncated, elapsed_ms + + def _display_non_query_result(self, affected, elapsed_ms) -> None: + _ = affected, elapsed_ms + + def _display_multi_statement_results(self, multi_result, elapsed_ms) -> None: + _ = multi_result, elapsed_ms + + def _display_query_error(self, error_message: str) -> None: + _ = error_message + + def _stop_query_spinner(self) -> None: + pass + + def _get_effective_database(self): + return None + + def call_after_refresh(self, callback) -> None: + callback() + + def _refresh_tree_after_query(self) -> None: + self._refresh_called = True + + +@pytest.mark.asyncio +async def test_schema_change_query_triggers_refresh() -> None: + host = MockHost() + await host._run_query_async("CREATE TABLE test_users(id INT)", keep_insert_mode=False) + assert host._refresh_called is True + + +@pytest.mark.asyncio +async def test_select_query_does_not_trigger_refresh() -> None: + host = MockHost() + host._executor = MockExecutor(QueryResult(columns=[], rows=[], row_count=0, truncated=False)) + await host._run_query_async("SELECT * FROM test_users", keep_insert_mode=False) + assert host._refresh_called is False From f87896eae248f75d20d775c9672be9e2b9fe17d5 Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:10:36 +0100 Subject: [PATCH 2/6] Auto-refresh explorer after schema changes --- sqlit/domains/explorer/ui/tree/builder.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sqlit/domains/explorer/ui/tree/builder.py b/sqlit/domains/explorer/ui/tree/builder.py index 76d9272..f47dadf 100644 --- a/sqlit/domains/explorer/ui/tree/builder.py +++ b/sqlit/domains/explorer/ui/tree/builder.py @@ -549,6 +549,12 @@ def update_connection_state( # Update new connected node and populate it if new_config is not None and host.current_connection is not None: populate_connected_tree(host) + try: + from . import loaders as tree_loaders + except Exception: + tree_loaders = None + if tree_loaders is not None: + tree_loaders.ensure_expanded_nodes_loaded(host, host.object_tree.root) def remove_connection_nodes(host: TreeMixinHost, names: set[str]) -> None: From 29f4f2aa2e226420ee6ff984356ddef935c86cc5 Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:12:52 +0100 Subject: [PATCH 3/6] Revert "Auto-refresh explorer after schema changes" This reverts commit f87896eae248f75d20d775c9672be9e2b9fe17d5. --- sqlit/domains/explorer/ui/tree/builder.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sqlit/domains/explorer/ui/tree/builder.py b/sqlit/domains/explorer/ui/tree/builder.py index f47dadf..76d9272 100644 --- a/sqlit/domains/explorer/ui/tree/builder.py +++ b/sqlit/domains/explorer/ui/tree/builder.py @@ -549,12 +549,6 @@ def update_connection_state( # Update new connected node and populate it if new_config is not None and host.current_connection is not None: populate_connected_tree(host) - try: - from . 
import loaders as tree_loaders - except Exception: - tree_loaders = None - if tree_loaders is not None: - tree_loaders.ensure_expanded_nodes_loaded(host, host.object_tree.root) def remove_connection_nodes(host: TreeMixinHost, names: set[str]) -> None: From e2edf5a9df83a92ebedb2c9fb6f12693013ba7fe Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:12:52 +0100 Subject: [PATCH 4/6] Revert "Auto-refresh explorer after schema-changing queries" This reverts commit 5685979d72325a3f2c8e08f9db5650bf4f647806. --- sqlit/domains/explorer/ui/mixins/tree.py | 14 +-- .../query/ui/mixins/query_execution.py | 61 ------------- tests/unit/test_autorefresh_schema_changes.py | 85 ------------------- 3 files changed, 3 insertions(+), 157 deletions(-) delete mode 100644 tests/unit/test_autorefresh_schema_changes.py diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index f06091a..9491461 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -208,7 +208,8 @@ def on_tree_node_highlighted(self: TreeMixinHost, event: Tree.NodeHighlighted) - """Update footer when tree selection changes.""" self._update_footer_bindings() - def _refresh_tree_common(self: TreeMixinHost, *, notify: bool) -> None: + def action_refresh_tree(self: TreeMixinHost) -> None: + """Refresh the explorer.""" self._get_object_cache().clear() if hasattr(self, "_schema_cache") and "columns" in self._schema_cache: self._schema_cache["columns"] = {} @@ -240,16 +241,7 @@ def run_loader() -> None: ) else: self._schedule_timer(MIN_TIMER_DELAY_S, run_loader) - if notify: - self.notify("Refreshed") - - def _refresh_tree_after_query(self: TreeMixinHost) -> None: - """Refresh the explorer without user-facing notifications.""" - self._refresh_tree_common(notify=False) - - def action_refresh_tree(self: TreeMixinHost) -> None: - """Refresh the explorer.""" - self._refresh_tree_common(notify=True) + self.notify("Refreshed") def refresh_tree(self: TreeMixinHost) -> None: tree_builder.refresh_tree_chunked(self) diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py index c8a7229..6b1c4e2 100644 --- a/sqlit/domains/query/ui/mixins/query_execution.py +++ b/sqlit/domains/query/ui/mixins/query_execution.py @@ -2,7 +2,6 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING, Any, Callable from sqlit.domains.explorer.ui.tree import db_switching as tree_db_switching @@ -22,43 +21,6 @@ from sqlit.domains.query.app.transaction import TransactionExecutor -_SCHEMA_CHANGE_KEYWORDS = ( - "CREATE", - "ALTER", - "DROP", - "TRUNCATE", - "RENAME", - "COMMENT", -) -_SCHEMA_CHANGE_RE = re.compile(r"\b(?:%s)\b" % "|".join(_SCHEMA_CHANGE_KEYWORDS), re.IGNORECASE) -_SINGLE_QUOTE_RE = re.compile(r"'[^']*'") -_DOUBLE_QUOTE_RE = re.compile(r'"[^"]*"') -_BACKTICK_RE = re.compile(r"`[^`]*`") -_BRACKET_RE = re.compile(r"\[[^\]]*\]") - - -def _strip_literals(sql: str) -> str: - cleaned = _SINGLE_QUOTE_RE.sub("''", sql) - cleaned = _DOUBLE_QUOTE_RE.sub('""', cleaned) - cleaned = _BACKTICK_RE.sub("``", cleaned) - cleaned = _BRACKET_RE.sub("[]", cleaned) - return cleaned - - -def _query_changes_schema(sql: str) -> bool: - if not sql or not sql.strip(): - return False - from sqlit.domains.query.app.multi_statement import split_statements - from sqlit.domains.query.editing.comments import strip_all_comments - - for statement in split_statements(sql): - 
cleaned = strip_all_comments(statement) - cleaned = _strip_literals(cleaned) - if _SCHEMA_CHANGE_RE.search(cleaned): - return True - return False - - class QueryExecutionMixin(ProcessWorkerLifecycleMixin): """Mixin providing query execution actions.""" @@ -281,25 +243,6 @@ def _stop_query_spinner(self: QueryMixinHost) -> None: except Exception: pass - def _maybe_refresh_explorer_after_query(self: QueryMixinHost, query: str) -> None: - if not _query_changes_schema(query): - return - - def run_refresh() -> None: - refresher = getattr(self, "_refresh_tree_after_query", None) - if callable(refresher): - refresher() - return - action_refresh = getattr(self, "action_refresh_tree", None) - if callable(action_refresh): - action_refresh() - - call_after_refresh = getattr(self, "call_after_refresh", None) - if callable(call_after_refresh): - call_after_refresh(run_refresh) - else: - run_refresh() - def _get_history_store(self: QueryMixinHost) -> Any: store = getattr(self, "_history_store", None) if store is not None: @@ -527,7 +470,6 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b ) else: self._display_non_query_result(result.rows_affected, elapsed_ms) - self._maybe_refresh_explorer_after_query(query) if keep_insert_mode: self._restore_insert_mode() return @@ -547,7 +489,6 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b except Exception: pass self._display_multi_statement_results(multi_result, elapsed_ms) - self._maybe_refresh_explorer_after_query(query) else: # Single statement - existing behavior result = await asyncio.to_thread( @@ -568,7 +509,6 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b ) else: self._display_non_query_result(result.rows_affected, elapsed_ms) - self._maybe_refresh_explorer_after_query(query) if keep_insert_mode: self._restore_insert_mode() @@ -641,7 +581,6 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None: else: self._display_non_query_result(result.rows_affected, elapsed_ms) self.notify("Query executed atomically (committed)", severity="information") - self._maybe_refresh_explorer_after_query(query) except Exception as e: self._display_query_error(f"Transaction rolled back: {e}") diff --git a/tests/unit/test_autorefresh_schema_changes.py b/tests/unit/test_autorefresh_schema_changes.py deleted file mode 100644 index 8927bf1..0000000 --- a/tests/unit/test_autorefresh_schema_changes.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Tests for auto-refreshing explorer after schema-changing queries.""" - -from __future__ import annotations - -from unittest.mock import MagicMock - -import pytest - -from sqlit.domains.query.app.query_service import NonQueryResult, QueryResult -from sqlit.domains.query.ui.mixins.query_execution import QueryExecutionMixin - - -class MockExecutor: - def __init__(self, result) -> None: - self._result = result - - def execute(self, query: str, max_rows: int | None = None): - _ = query, max_rows - return self._result - - -class MockHost(QueryExecutionMixin): - def __init__(self) -> None: - self.current_connection = object() - self.current_provider = MagicMock() - self.current_provider.apply_database_override = lambda config, db: config - self.current_provider.metadata = MagicMock(db_type="mock") - self.current_config = MagicMock() - self.current_config.tcp_endpoint = MagicMock(database="") - self.query_input = MagicMock() - self.services = MagicMock() - self.services.runtime.max_rows = 100 - self.services.runtime.query_alert_mode = 
0 - self._query_worker = None - self._query_spinner = None - self._query_target_database = None - self._executor = MockExecutor(NonQueryResult(rows_affected=0)) - self._refresh_called = False - - def _use_process_worker(self, provider) -> bool: - _ = provider - return False - - def _get_transaction_executor(self, config, provider): - _ = config, provider - return self._executor - - async def _display_query_results(self, columns, rows, row_count, truncated, elapsed_ms) -> None: - _ = columns, rows, row_count, truncated, elapsed_ms - - def _display_non_query_result(self, affected, elapsed_ms) -> None: - _ = affected, elapsed_ms - - def _display_multi_statement_results(self, multi_result, elapsed_ms) -> None: - _ = multi_result, elapsed_ms - - def _display_query_error(self, error_message: str) -> None: - _ = error_message - - def _stop_query_spinner(self) -> None: - pass - - def _get_effective_database(self): - return None - - def call_after_refresh(self, callback) -> None: - callback() - - def _refresh_tree_after_query(self) -> None: - self._refresh_called = True - - -@pytest.mark.asyncio -async def test_schema_change_query_triggers_refresh() -> None: - host = MockHost() - await host._run_query_async("CREATE TABLE test_users(id INT)", keep_insert_mode=False) - assert host._refresh_called is True - - -@pytest.mark.asyncio -async def test_select_query_does_not_trigger_refresh() -> None: - host = MockHost() - host._executor = MockExecutor(QueryResult(columns=[], rows=[], row_count=0, truncated=False)) - await host._run_query_async("SELECT * FROM test_users", keep_insert_mode=False) - assert host._refresh_called is False From 568c9ba7edb073646dfa4475e42b312f5c60b81a Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:15:20 +0100 Subject: [PATCH 5/6] Fix DuckDB explorer table load and add refresh tests --- sqlit/domains/explorer/app/schema_service.py | 8 +- .../test_explorer_refresh_duckdb_cursor.py | 338 ++++++++++++++++++ 2 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_explorer_refresh_duckdb_cursor.py diff --git a/sqlit/domains/explorer/app/schema_service.py b/sqlit/domains/explorer/app/schema_service.py index 3c8ab7b..1252b77 100644 --- a/sqlit/domains/explorer/app/schema_service.py +++ b/sqlit/domains/explorer/app/schema_service.py @@ -94,9 +94,11 @@ def list_folder_items(self, folder_type: str, database: str | None) -> list[Any] cache_key = database or "__default__" obj_cache = self.object_cache - def cached(key: str, loader: Callable[[], Any]) -> Any: + def cached(key: str, loader: Callable[[], Any], *, allow_empty: bool = True) -> Any: if cache_key in obj_cache and key in obj_cache[cache_key]: - return obj_cache[cache_key][key] + data = obj_cache[cache_key][key] + if allow_empty or data: + return data data = loader() if cache_key not in obj_cache: obj_cache[cache_key] = {} @@ -110,6 +112,7 @@ def cached(key: str, loader: Callable[[], Any]) -> Any: lambda: inspector.get_tables(self.session.connection, db_arg), database, ), + allow_empty=self.session.provider.metadata.db_type != "duckdb", ) return [("table", schema, name) for schema, name in raw_data] if folder_type == "views": @@ -119,6 +122,7 @@ def cached(key: str, loader: Callable[[], Any]) -> Any: lambda: inspector.get_views(self.session.connection, db_arg), database, ), + allow_empty=self.session.provider.metadata.db_type != "duckdb", ) return [("view", schema, name) for schema, name in raw_data] if folder_type == "databases": 
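For reference, a minimal standalone sketch (not part of the patch itself) of the caching rule this hunk adds: a cached entry is reused only when it is non-empty or the caller explicitly trusts empty results (allow_empty=True), so DuckDB table/view lists that were cached while still empty get reloaded on the next request. The helper and the assertions below are illustrative only and assume a plain dict cache.

def cached(key, loader, cache, *, allow_empty=True):
    """Reuse a cached value unless it is empty and empty results are not trusted."""
    if key in cache:
        data = cache[key]
        if allow_empty or data:
            return data
    data = loader()
    cache[key] = data
    return data

store: dict = {}
assert cached("tables", lambda: [], store, allow_empty=False) == []                 # first load caches []
assert cached("tables", lambda: ["users"], store, allow_empty=False) == ["users"]   # empty cached entry is retried
assert cached("tables", lambda: ["stale"], store, allow_empty=False) == ["users"]   # non-empty entry is reused, loader not called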
diff --git a/tests/integration/test_explorer_refresh_duckdb_cursor.py b/tests/integration/test_explorer_refresh_duckdb_cursor.py new file mode 100644 index 0000000..2dbedd8 --- /dev/null +++ b/tests/integration/test_explorer_refresh_duckdb_cursor.py @@ -0,0 +1,338 @@ +"""Integration tests for explorer refresh cursor behavior with DuckDB.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from sqlit.domains.explorer.domain.tree_nodes import ColumnNode, TableNode +from sqlit.domains.shell.app.main import SSMSTUI +from tests.helpers import ConnectionConfig +from tests.integration.browsing_base import ( + find_connection_node, + find_folder_node, + find_table_node, + has_loading_children, + wait_for_condition, +) + + +def _build_duckdb_db(path: Path) -> None: + try: + import duckdb # type: ignore + except ImportError: + pytest.skip("duckdb is not installed") + + conn = duckdb.connect(str(path)) + conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR)") + conn.close() + + +def _find_column_node(parent: Any, column_name: str) -> Any | None: + for child in parent.children: + data = getattr(child, "data", None) + if isinstance(data, ColumnNode) and data.name == column_name: + return child + return None + + +async def _wait_for_folder_loaded(pilot: Any, node: Any, description: str) -> None: + await wait_for_condition( + pilot, + lambda: not has_loading_children(node) and len(node.children) > 0, + timeout_seconds=10.0, + description=description, + ) + + +async def _wait_for_columns_loaded(pilot: Any, node: Any) -> None: + await wait_for_condition( + pilot, + lambda: not has_loading_children(node) and _find_column_node(node, "id") is not None, + timeout_seconds=10.0, + description="columns to load", + ) + + +async def _refresh_tree(pilot: Any, app: SSMSTUI) -> None: + before_token = getattr(app, "_tree_refresh_token", None) + await pilot.press("f") + await wait_for_condition( + pilot, + lambda: getattr(app, "_tree_refresh_token", None) is not before_token, + timeout_seconds=5.0, + description="tree refresh to start", + ) + + +async def _wait_for_folder_loaded_or_refresh( + pilot: Any, + app: SSMSTUI, + node: Any, + description: str, + *, + allow_refresh: bool, +) -> None: + try: + await _wait_for_folder_loaded(pilot, node, description) + return + except AssertionError: + if not allow_refresh: + raise + await _refresh_tree(pilot, app) + await pilot.pause(0.3) + await _wait_for_folder_loaded(pilot, node, description) + + +def _set_auto_expanded_paths(app: SSMSTUI, config_name: str) -> None: + app._expanded_paths = { + f"conn:{config_name}", + f"conn:{config_name}/folder:tables", + f"conn:{config_name}/folder:tables/table:main.users", + f"conn:{config_name}/folder:tables/table:main.users/column:main.users.id", + } + + +async def _connect_and_expand( + pilot: Any, + app: SSMSTUI, + config: ConnectionConfig, + *, + auto_expand: bool, + allow_refresh_on_load: bool, +) -> tuple[Any, Any, Any]: + app.connections = [config] + app.refresh_tree() + await pilot.pause(0.1) + + await wait_for_condition( + pilot, + lambda: len(app.object_tree.root.children) > 0, + timeout_seconds=5.0, + description="tree to be populated with connections", + ) + + app.connect_to_server(config) + await pilot.pause(0.5) + + await wait_for_condition( + pilot, + lambda: app.current_connection is not None, + timeout_seconds=15.0, + description="connection to be established", + ) + + connected_node = find_connection_node(app.object_tree.root, config.name) + 
assert connected_node is not None + + tables_folder = find_folder_node(connected_node, "tables") + assert tables_folder is not None + + if auto_expand: + _set_auto_expanded_paths(app, config.name) + app.refresh_tree() + await pilot.pause(0.3) + connected_node = find_connection_node(app.object_tree.root, config.name) + assert connected_node is not None + tables_folder = find_folder_node(connected_node, "tables") + assert tables_folder is not None + else: + tables_folder.expand() + await pilot.pause(0.2) + await _wait_for_folder_loaded_or_refresh( + pilot, + app, + tables_folder, + "tables to load", + allow_refresh=allow_refresh_on_load, + ) + + table_node = find_table_node(tables_folder, "users") + assert table_node is not None + + if auto_expand: + assert table_node.is_expanded + else: + table_node.expand() + await pilot.pause(0.2) + await _wait_for_columns_loaded(pilot, table_node) + + column_node = _find_column_node(table_node, "id") + assert column_node is not None + + return tables_folder, table_node, column_node + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("auto_expand", [False, True]) +async def test_duckdb_tables_load_without_manual_refresh(tmp_path: Path, auto_expand: bool) -> None: + db_path = tmp_path / "duckdb_initial_tables.db" + _build_duckdb_db(db_path) + + config = ConnectionConfig( + name="duckdb-initial-tables", + db_type="duckdb", + file_path=str(db_path), + ) + + app = SSMSTUI() + + async with app.run_test(size=(120, 40)) as pilot: + await pilot.pause(0.1) + + app.connections = [config] + app.refresh_tree() + await pilot.pause(0.1) + + await wait_for_condition( + pilot, + lambda: len(app.object_tree.root.children) > 0, + timeout_seconds=5.0, + description="tree to be populated with connections", + ) + + app.connect_to_server(config) + await pilot.pause(0.5) + + await wait_for_condition( + pilot, + lambda: app.current_connection is not None, + timeout_seconds=15.0, + description="connection to be established", + ) + + connected_node = find_connection_node(app.object_tree.root, config.name) + assert connected_node is not None + + tables_folder = find_folder_node(connected_node, "tables") + assert tables_folder is not None + + if auto_expand: + _set_auto_expanded_paths(app, config.name) + app.refresh_tree() + await pilot.pause(0.3) + connected_node = find_connection_node(app.object_tree.root, config.name) + assert connected_node is not None + tables_folder = find_folder_node(connected_node, "tables") + assert tables_folder is not None + assert tables_folder.is_expanded + else: + tables_folder.expand() + await pilot.pause(0.2) + + await _wait_for_folder_loaded( + pilot, + tables_folder, + "tables to load without manual refresh", + ) + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("auto_expand", [False, True]) +async def test_duckdb_refresh_keeps_cursor_on_table(tmp_path: Path, auto_expand: bool) -> None: + db_path = tmp_path / "duckdb_refresh_table.db" + _build_duckdb_db(db_path) + + config = ConnectionConfig( + name="duckdb-refresh-table", + db_type="duckdb", + file_path=str(db_path), + ) + + app = SSMSTUI() + + async with app.run_test(size=(120, 40)) as pilot: + await pilot.pause(0.1) + + _, table_node, _ = await _connect_and_expand( + pilot, + app, + config, + auto_expand=auto_expand, + allow_refresh_on_load=True, + ) + + app.action_focus_explorer() + await pilot.pause(0.05) + app.object_tree.move_cursor(table_node) + await pilot.pause(0.05) + assert app.object_tree.cursor_node == table_node + + await 
_refresh_tree(pilot, app) + await pilot.pause(0.5) + + refreshed_connection = find_connection_node(app.object_tree.root, config.name) + assert refreshed_connection is not None + refreshed_tables = find_folder_node(refreshed_connection, "tables") + assert refreshed_tables is not None + assert refreshed_tables.is_expanded + + await _wait_for_folder_loaded(pilot, refreshed_tables, "tables to reload") + refreshed_table = find_table_node(refreshed_tables, "users") + assert refreshed_table is not None + + cursor = app.object_tree.cursor_node + assert cursor is not None + assert isinstance(cursor.data, TableNode) + assert cursor.data.name == "users" + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("auto_expand", [False, True]) +async def test_duckdb_refresh_keeps_cursor_on_column(tmp_path: Path, auto_expand: bool) -> None: + db_path = tmp_path / "duckdb_refresh_column.db" + _build_duckdb_db(db_path) + + config = ConnectionConfig( + name="duckdb-refresh-column", + db_type="duckdb", + file_path=str(db_path), + ) + + app = SSMSTUI() + + async with app.run_test(size=(120, 40)) as pilot: + await pilot.pause(0.1) + + _, table_node, column_node = await _connect_and_expand( + pilot, + app, + config, + auto_expand=auto_expand, + allow_refresh_on_load=True, + ) + + app.action_focus_explorer() + await pilot.pause(0.05) + app.object_tree.move_cursor(column_node) + await pilot.pause(0.05) + assert app.object_tree.cursor_node == column_node + + await _refresh_tree(pilot, app) + await pilot.pause(0.7) + + refreshed_connection = find_connection_node(app.object_tree.root, config.name) + assert refreshed_connection is not None + refreshed_tables = find_folder_node(refreshed_connection, "tables") + assert refreshed_tables is not None + assert refreshed_tables.is_expanded + + await _wait_for_folder_loaded(pilot, refreshed_tables, "tables to reload") + refreshed_table = find_table_node(refreshed_tables, "users") + assert refreshed_table is not None + if not refreshed_table.is_expanded: + refreshed_table.expand() + await pilot.pause(0.2) + + await _wait_for_columns_loaded(pilot, refreshed_table) + refreshed_column = _find_column_node(refreshed_table, "id") + assert refreshed_column is not None + + cursor = app.object_tree.cursor_node + assert cursor is not None + assert isinstance(cursor.data, ColumnNode) + assert cursor.data.name == "id" From c11358cdb4568278d2392d5614932fd1636c6510 Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:18:07 +0100 Subject: [PATCH 6/6] Auto-refresh explorer after DDL queries --- sqlit/domains/explorer/ui/mixins/tree.py | 17 ++++- .../query/ui/mixins/query_execution.py | 36 +++++++++++ .../test_explorer_refresh_duckdb_cursor.py | 64 +++++++++++++++++++ 3 files changed, 115 insertions(+), 2 deletions(-) diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index 9491461..2591146 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -210,9 +210,21 @@ def on_tree_node_highlighted(self: TreeMixinHost, event: Tree.NodeHighlighted) - def action_refresh_tree(self: TreeMixinHost) -> None: """Refresh the explorer.""" + self._refresh_tree_common(notify=True) + + def _refresh_tree_after_schema_change(self: TreeMixinHost) -> None: + """Refresh tree after DDL without showing a notification.""" + self._refresh_tree_common(notify=False) + + def _refresh_tree_common(self: TreeMixinHost, *, notify: bool) -> None: 
self._get_object_cache().clear() - if hasattr(self, "_schema_cache") and "columns" in self._schema_cache: + if hasattr(self, "_schema_cache") and isinstance(self._schema_cache, dict): self._schema_cache["columns"] = {} + self._schema_cache["tables"] = [] + self._schema_cache["views"] = [] + self._schema_cache["procedures"] = [] + if hasattr(self, "_db_object_cache"): + self._db_object_cache = {} if hasattr(self, "_loading_nodes"): self._loading_nodes.clear() self._schema_service = None @@ -241,7 +253,8 @@ def run_loader() -> None: ) else: self._schedule_timer(MIN_TIMER_DELAY_S, run_loader) - self.notify("Refreshed") + if notify: + self.notify("Refreshed") def refresh_tree(self: TreeMixinHost) -> None: tree_builder.refresh_tree_chunked(self) diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py index 6b1c4e2..9873203 100644 --- a/sqlit/domains/query/ui/mixins/query_execution.py +++ b/sqlit/domains/query/ui/mixins/query_execution.py @@ -2,6 +2,8 @@ from __future__ import annotations +import re + from typing import TYPE_CHECKING, Any, Callable from sqlit.domains.explorer.ui.tree import db_switching as tree_db_switching @@ -21,6 +23,19 @@ from sqlit.domains.query.app.transaction import TransactionExecutor +_SCHEMA_CHANGE_RE = re.compile( + r"\b(create|alter|drop|truncate|rename|comment|grant|revoke)\b", + re.IGNORECASE, +) +_SQL_COMMENT_RE = re.compile(r"(--[^\n]*|/\*.*?\*/)", re.DOTALL) +_SQL_LITERAL_RE = re.compile(r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`|\[[^\]]*\])", re.DOTALL) + + +def _strip_sql_comments_and_literals(sql: str) -> str: + sql = _SQL_COMMENT_RE.sub(" ", sql) + return _SQL_LITERAL_RE.sub(" ", sql) + + class QueryExecutionMixin(ProcessWorkerLifecycleMixin): """Mixin providing query execution actions.""" @@ -216,6 +231,21 @@ def _on_result(confirmed: bool | None) -> None: _on_result, ) + def _query_changes_schema(self: QueryMixinHost, query: str) -> bool: + cleaned = _strip_sql_comments_and_literals(query) + return bool(_SCHEMA_CHANGE_RE.search(cleaned)) + + def _maybe_refresh_explorer_after_query(self: QueryMixinHost, query: str) -> None: + if not self._query_changes_schema(query): + return + refresh = getattr(self, "_refresh_tree_after_schema_change", None) + if callable(refresh): + refresh() + return + action = getattr(self, "action_refresh_tree", None) + if callable(action): + action() + def _start_query_spinner(self: QueryMixinHost) -> None: """Start the query execution spinner animation.""" import time @@ -470,6 +500,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b ) else: self._display_non_query_result(result.rows_affected, elapsed_ms) + self._maybe_refresh_explorer_after_query(query) if keep_insert_mode: self._restore_insert_mode() return @@ -489,6 +520,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b except Exception: pass self._display_multi_statement_results(multi_result, elapsed_ms) + self._maybe_refresh_explorer_after_query(query) else: # Single statement - existing behavior result = await asyncio.to_thread( @@ -509,6 +541,7 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b ) else: self._display_non_query_result(result.rows_affected, elapsed_ms) + self._maybe_refresh_explorer_after_query(query) if keep_insert_mode: self._restore_insert_mode() @@ -573,14 +606,17 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None: self.notify("Transaction rolled back (error in statement)", 
severity="error") else: self.notify("Query executed atomically (committed)", severity="information") + self._maybe_refresh_explorer_after_query(query) elif isinstance(result, QueryResult): await self._display_query_results( result.columns, result.rows, result.row_count, result.truncated, elapsed_ms ) self.notify("Query executed atomically (committed)", severity="information") + self._maybe_refresh_explorer_after_query(query) else: self._display_non_query_result(result.rows_affected, elapsed_ms) self.notify("Query executed atomically (committed)", severity="information") + self._maybe_refresh_explorer_after_query(query) except Exception as e: self._display_query_error(f"Transaction rolled back: {e}") diff --git a/tests/integration/test_explorer_refresh_duckdb_cursor.py b/tests/integration/test_explorer_refresh_duckdb_cursor.py index 2dbedd8..e183ab8 100644 --- a/tests/integration/test_explorer_refresh_duckdb_cursor.py +++ b/tests/integration/test_explorer_refresh_duckdb_cursor.py @@ -95,6 +95,16 @@ def _set_auto_expanded_paths(app: SSMSTUI, config_name: str) -> None: } +def _find_table_in_tree(app: SSMSTUI, config_name: str, table_name: str) -> Any | None: + connected_node = find_connection_node(app.object_tree.root, config_name) + if connected_node is None: + return None + tables_folder = find_folder_node(connected_node, "tables") + if tables_folder is None: + return None + return find_table_node(tables_folder, table_name) + + async def _connect_and_expand( pilot: Any, app: SSMSTUI, @@ -336,3 +346,57 @@ async def test_duckdb_refresh_keeps_cursor_on_column(tmp_path: Path, auto_expand assert cursor is not None assert isinstance(cursor.data, ColumnNode) assert cursor.data.name == "id" + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_duckdb_auto_refresh_after_create_table(tmp_path: Path) -> None: + db_path = tmp_path / "duckdb_auto_refresh.db" + _build_duckdb_db(db_path) + + config = ConnectionConfig( + name="duckdb-auto-refresh", + db_type="duckdb", + file_path=str(db_path), + ) + + app = SSMSTUI() + + async with app.run_test(size=(120, 40)) as pilot: + await pilot.pause(0.1) + + tables_folder, _, _ = await _connect_and_expand( + pilot, + app, + config, + auto_expand=False, + allow_refresh_on_load=True, + ) + + assert find_table_node(tables_folder, "users") is not None + assert _find_table_in_tree(app, config.name, "users3") is None + + before_token = getattr(app, "_tree_refresh_token", None) + app.query_input.text = "CREATE TABLE users3 (id INTEGER)" + app.action_execute_query() + + await wait_for_condition( + pilot, + lambda: not getattr(app, "query_executing", False), + timeout_seconds=15.0, + description="query execution to finish", + ) + + await wait_for_condition( + pilot, + lambda: getattr(app, "_tree_refresh_token", None) is not before_token, + timeout_seconds=10.0, + description="tree refresh after DDL", + ) + + await wait_for_condition( + pilot, + lambda: _find_table_in_tree(app, config.name, "users3") is not None, + timeout_seconds=10.0, + description="new table to appear after auto refresh", + )