diff --git a/sqlit/domains/connections/providers/mssql/adapter.py b/sqlit/domains/connections/providers/mssql/adapter.py
index 5ddb16c..0ae3574 100644
--- a/sqlit/domains/connections/providers/mssql/adapter.py
+++ b/sqlit/domains/connections/providers/mssql/adapter.py
@@ -180,7 +180,10 @@ def connect(self, config: ConnectionConfig) -> Any:
         )
 
         conn_str = self._build_connection_string(config)
-        return mssql_python.connect(conn_str)
+        conn = mssql_python.connect(conn_str)
+        # Enable autocommit to allow DDL statements like CREATE DATABASE
+        conn.autocommit = True
+        return conn
 
     def get_databases(self, conn: Any) -> list[str]:
         """Get list of databases from SQL Server."""
diff --git a/sqlit/domains/connections/providers/oracle/adapter.py b/sqlit/domains/connections/providers/oracle/adapter.py
index c6ad39e..99dd7a9 100644
--- a/sqlit/domains/connections/providers/oracle/adapter.py
+++ b/sqlit/domains/connections/providers/oracle/adapter.py
@@ -72,8 +72,18 @@ def connect(self, config: ConnectionConfig) -> Any:
         if endpoint is None:
             raise ValueError("Oracle connections require a TCP-style endpoint.")
         port = int(endpoint.port or get_default_port("oracle"))
-        # Use Easy Connect string format: host:port/service_name
-        dsn = f"{endpoint.host}:{port}/{endpoint.database}"
+
+        # Determine connection type: service_name (default) or sid
+        connection_type = config.get_option("oracle_connection_type", "service_name")
+
+        if connection_type == "sid":
+            # SID format: host:port:sid (uses colon separator)
+            # SID is stored in oracle_sid field, fall back to database for backward compat
+            sid = config.get_option("oracle_sid") or endpoint.database
+            dsn = f"{endpoint.host}:{port}:{sid}"
+        else:
+            # Service Name format: host:port/service_name (uses slash separator)
+            dsn = f"{endpoint.host}:{port}/{endpoint.database}"
 
         # Determine connection mode based on oracle_role
         oracle_role = config.get_option("oracle_role", "normal")
diff --git a/sqlit/domains/connections/providers/oracle/schema.py b/sqlit/domains/connections/providers/oracle/schema.py
index 77f4976..f004a6d 100644
--- a/sqlit/domains/connections/providers/oracle/schema.py
+++ b/sqlit/domains/connections/providers/oracle/schema.py
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
     )
 
 
+def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
+    return (
+        SelectOption("service_name", "Service Name"),
+        SelectOption("sid", "SID"),
+    )
+
+
+def _oracle_connection_type_is_service_name(values: dict) -> bool:
+    return values.get("oracle_connection_type", "service_name") != "sid"
+
+
+def _oracle_connection_type_is_sid(values: dict) -> bool:
+    return values.get("oracle_connection_type") == "sid"
+
+
 SCHEMA = ConnectionSchema(
     db_type="oracle",
     display_name="Oracle",
@@ -32,11 +47,26 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
             group="server_port",
         ),
         _port_field("1521"),
+        SchemaField(
+            name="oracle_connection_type",
+            label="Connection Type",
+            field_type=FieldType.DROPDOWN,
+            options=_get_oracle_connection_type_options(),
+            default="service_name",
+        ),
         SchemaField(
             name="database",
             label="Service Name",
             placeholder="ORCL or XEPDB1",
             required=True,
+            visible_when=_oracle_connection_type_is_service_name,
+        ),
+        SchemaField(
+            name="oracle_sid",
+            label="SID",
+            placeholder="ORCL",
+            required=True,
+            visible_when=_oracle_connection_type_is_sid,
         ),
         _username_field(),
         _password_field(),
diff --git a/sqlit/domains/connections/providers/oracle_legacy/schema.py b/sqlit/domains/connections/providers/oracle_legacy/schema.py
index 5da8d0a..4252f72 100644
--- a/sqlit/domains/connections/providers/oracle_legacy/schema.py
+++ b/sqlit/domains/connections/providers/oracle_legacy/schema.py
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
     )
 
 
+def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
+    return (
+        SelectOption("service_name", "Service Name"),
+        SelectOption("sid", "SID"),
+    )
+
+
+def _oracle_connection_type_is_service_name(values: dict) -> bool:
+    return values.get("oracle_connection_type", "service_name") != "sid"
+
+
+def _oracle_connection_type_is_sid(values: dict) -> bool:
+    return values.get("oracle_connection_type") == "sid"
+
+
 def _get_oracle_client_mode_options() -> tuple[SelectOption, ...]:
     return (
         SelectOption("thick", "Thick (Instant Client)"),
@@ -43,11 +58,26 @@ def _oracle_thick_mode_enabled(values: dict) -> bool:
             group="server_port",
         ),
         _port_field("1521"),
+        SchemaField(
+            name="oracle_connection_type",
+            label="Connection Type",
+            field_type=FieldType.DROPDOWN,
+            options=_get_oracle_connection_type_options(),
+            default="service_name",
+        ),
         SchemaField(
             name="database",
             label="Service Name",
+            placeholder="ORCL or XEPDB1",
+            required=True,
+            visible_when=_oracle_connection_type_is_service_name,
+        ),
+        SchemaField(
+            name="oracle_sid",
+            label="SID",
             placeholder="ORCL",
             required=True,
+            visible_when=_oracle_connection_type_is_sid,
         ),
         _username_field(),
         _password_field(),
diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py
index 9491461..1a4f219 100644
--- a/sqlit/domains/explorer/ui/mixins/tree.py
+++ b/sqlit/domains/explorer/ui/mixins/tree.py
@@ -295,6 +295,8 @@ def action_select_table(self: TreeMixinHost) -> None:
             "name": data.name,
             "columns": [],
         }
+        # Stash per-result metadata so results can resolve PKs without relying on globals.
+        self._pending_result_table_info = self._last_query_table
         self._prime_last_query_table_columns(data.database, data.schema, data.name)
 
         self.query_input.text = self.current_provider.dialect.build_select_query(
diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py
index 6b1c4e2..417a633 100644
--- a/sqlit/domains/query/ui/mixins/query_execution.py
+++ b/sqlit/domains/query/ui/mixins/query_execution.py
@@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
         executable_statements = [s for s in statements if not is_comment_only_statement(s)]
         is_multi_statement = len(executable_statements) > 1
 
+        if is_multi_statement:
+            self._pending_result_table_info = None
+        elif executable_statements:
+            if getattr(self, "_pending_result_table_info", None) is None:
+                table_info = self._infer_result_table_info(executable_statements[0])
+                if table_info is not None:
+                    self._pending_result_table_info = table_info
+                    prime = getattr(self, "_prime_result_table_columns", None)
+                    if callable(prime):
+                        prime(table_info)
+
         try:
             start_time = time.perf_counter()
             max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS
diff --git a/sqlit/domains/query/ui/mixins/query_results.py b/sqlit/domains/query/ui/mixins/query_results.py
index c40223f..6be6e8b 100644
--- a/sqlit/domains/query/ui/mixins/query_results.py
+++ b/sqlit/domains/query/ui/mixins/query_results.py
@@ -224,6 +224,7 @@ def _render_results_table_incremental(
         escape: bool,
         row_limit: int,
         render_token: int,
+        table_info: dict[str, Any] | None = None,
     ) -> None:
         initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
         initial_rows = rows[:initial_count] if initial_count > 0 else []
@@ -269,9 +270,16 @@ def _render_results_table_incremental(
                 pass
            if render_token == getattr(self, "_results_render_token", 0):
                self._replace_results_table_with_data(columns, rows, escape=escape)
+               if table_info is not None:
+                   try:
+                       self.results_table.result_table_info = table_info
+                   except Exception:
+                       pass
            return
        if render_token != getattr(self, "_results_render_token", 0):
            return
+       if table_info is not None:
+           table.result_table_info = table_info
        self._replace_results_table_with_table(table)
        self._schedule_results_render(
            table,
@@ -291,6 +299,7 @@ async def _display_query_results(
         self._last_result_columns = columns
         self._last_result_rows = rows
         self._last_result_row_count = row_count
+        table_info = getattr(self, "_pending_result_table_info", None)
 
         # Switch to single result mode (in case we were showing stacked results)
         self._show_single_result_mode()
@@ -304,12 +313,15 @@
                 escape=True,
                 row_limit=row_limit,
                 render_token=render_token,
+                table_info=table_info,
             )
         else:
             render_rows = rows[:row_limit] if row_limit else []
             table = self._build_results_table(columns, render_rows, escape=True)
             if render_token != getattr(self, "_results_render_token", 0):
                 return
+            if table_info is not None:
+                table.result_table_info = table_info
             self._replace_results_table_with_table(table)
 
         time_str = format_duration_ms(elapsed_ms)
@@ -320,9 +332,15 @@
             )
         else:
             self.notify(f"Query returned {row_count} rows in {time_str}")
+        if table_info is not None:
+            prime = getattr(self, "_prime_result_table_columns", None)
+            if callable(prime):
+                prime(table_info)
+        self._pending_result_table_info = None
 
     def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
         """Display non-query result (called on main thread)."""
+        self._pending_result_table_info = None
         self._last_result_columns = ["Result"]
         self._last_result_rows = [(f"{affected} row(s) affected",)]
         self._last_result_row_count = 1
@@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
 
     def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
         """Display query error (called on main thread)."""
         self._cancel_results_render()
+        self._pending_result_table_info = None
         # notify(severity="error") handles displaying the error in results via _show_error_in_results
         self.notify(f"Query error: {error_message}", severity="error")
@@ -360,7 +379,17 @@ def _display_multi_statement_results(
 
         # Add each result section
         for i, stmt_result in enumerate(multi_result.results):
-            container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
+            table_info = self._infer_result_table_info(stmt_result.statement)
+            if table_info is not None:
+                prime = getattr(self, "_prime_result_table_columns", None)
+                if callable(prime):
+                    prime(table_info)
+            container.add_result_section(
+                stmt_result,
+                i,
+                auto_collapse=auto_collapse,
+                table_info=table_info,
+            )
 
         # Show the stacked results container, hide single result table
         self._show_stacked_results_mode()
@@ -378,6 +407,7 @@
             )
         else:
             self.notify(f"Executed {total} statements in {time_str}")
+        self._pending_result_table_info = None
 
     def _get_stacked_results_container(self: QueryMixinHost) -> Any:
         """Get the stacked results container."""
@@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
             stacked.remove_class("active")
         except Exception:
             pass
+
+    def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
+        """Best-effort inference of a single source table for query results."""
+        from sqlit.domains.query.completion import extract_table_refs
+
+        refs = extract_table_refs(sql)
+        if len(refs) != 1:
+            return None
+        ref = refs[0]
+        schema = ref.schema
+        name = ref.name
+        database = None
+        table_metadata = getattr(self, "_table_metadata", {}) or {}
+        key_candidates = [name.lower()]
+        if schema:
+            key_candidates.insert(0, f"{schema}.{name}".lower())
+        for key in key_candidates:
+            metadata = table_metadata.get(key)
+            if metadata:
+                schema, name, database = metadata
+                break
+        return {
+            "database": database,
+            "schema": schema,
+            "name": name,
+            "columns": [],
+        }
diff --git a/sqlit/domains/results/ui/mixins/results.py b/sqlit/domains/results/ui/mixins/results.py
index 2f3a434..d875029 100644
--- a/sqlit/domains/results/ui/mixins/results.py
+++ b/sqlit/domains/results/ui/mixins/results.py
@@ -8,6 +8,8 @@
 from sqlit.shared.ui.protocols import ResultsMixinHost
 from sqlit.shared.ui.widgets import SqlitDataTable
 
+MIN_TIMER_DELAY_S = 0.001
+
 
 class ResultsMixin:
     """Mixin providing results handling functionality."""
@@ -19,6 +21,122 @@ class ResultsMixin:
     _tooltip_showing: bool = False
     _tooltip_timer: Any | None = None
 
+    def _schedule_results_timer(self: ResultsMixinHost, delay_s: float, callback: Any) -> Any | None:
+        set_timer = getattr(self, "set_timer", None)
+        if callable(set_timer):
+            return set_timer(delay_s, callback)
+        call_later = getattr(self, "call_later", None)
+        if callable(call_later):
+            try:
+                call_later(callback)
+                return None
+            except Exception:
+                pass
+        try:
+            callback()
+        except Exception:
+            pass
+        return None
+
+    def _apply_result_table_columns(
+        self: ResultsMixinHost,
+        table_info: dict[str, Any],
+        token: int,
+        columns: list[Any],
+    ) -> None:
+        if table_info.get("_columns_token") != token:
+            return
+        table_info["columns"] = columns
+
+    def _prime_result_table_columns(self: ResultsMixinHost, table_info: dict[str, Any] | None) -> None:
+        if not table_info:
+            return
+        if table_info.get("columns"):
+            return
+        name = table_info.get("name")
+        if not name:
+            return
+        database = table_info.get("database")
+        schema = table_info.get("schema")
+        token = int(table_info.get("_columns_token", 0)) + 1
+        table_info["_columns_token"] = token
+
+        async def work_async() -> None:
+            import asyncio
+
+            columns: list[Any] = []
+            try:
+                runtime = getattr(self.services, "runtime", None)
+                use_worker = bool(getattr(runtime, "process_worker", False)) and not bool(
+                    getattr(getattr(runtime, "mock", None), "enabled", False)
+                )
+                client = None
+                if use_worker and hasattr(self, "_get_process_worker_client_async"):
+                    client = await self._get_process_worker_client_async()  # type: ignore[attr-defined]
+
+                if client is not None and hasattr(client, "list_columns") and self.current_config is not None:
+                    outcome = await asyncio.to_thread(
+                        client.list_columns,
+                        config=self.current_config,
+                        database=database,
+                        schema=schema,
+                        name=name,
+                    )
+                    if getattr(outcome, "cancelled", False):
+                        return
+                    error = getattr(outcome, "error", None)
+                    if error:
+                        raise RuntimeError(error)
+                    columns = outcome.columns or []
+                else:
+                    schema_service = getattr(self, "_get_schema_service", None)
+                    if callable(schema_service):
+                        service = self._get_schema_service()
+                        if service:
+                            columns = await asyncio.to_thread(
+                                service.list_columns,
+                                database,
+                                schema,
+                                name,
+                            )
+            except Exception:
+                columns = []
+
+            self._schedule_results_timer(
+                MIN_TIMER_DELAY_S,
+                lambda: self._apply_result_table_columns(table_info, token, columns),
+            )
+
+        self.run_worker(work_async(), name=f"prime-result-columns-{name}", exclusive=False)
+
+    def _normalize_column_name(self: ResultsMixinHost, name: str) -> str:
+        trimmed = name.strip()
+        if len(trimmed) >= 2:
+            if trimmed[0] == trimmed[-1] and trimmed[0] in ("\"", "`"):
+                trimmed = trimmed[1:-1]
+            elif trimmed[0] == "[" and trimmed[-1] == "]":
+                trimmed = trimmed[1:-1]
+        if "." in trimmed and not any(q in trimmed for q in ("\"", "`", "[")):
+            trimmed = trimmed.split(".")[-1]
+        return trimmed.lower()
+
+    def _get_active_results_table_info(
+        self: ResultsMixinHost,
+        table: SqlitDataTable | None,
+        stacked: bool,
+    ) -> dict[str, Any] | None:
+        if not table:
+            return None
+        if stacked:
+            section = self._find_results_section(table)
+            table_info = getattr(section, "result_table_info", None)
+            if table_info:
+                return table_info
+        table_info = getattr(table, "result_table_info", None)
+        if table_info:
+            return table_info
+        return getattr(self, "_last_query_table", None)
+
     def _copy_text(self: ResultsMixinHost, text: str) -> bool:
         """Copy text to clipboard if possible, otherwise store internally."""
         self._internal_clipboard = text
@@ -610,21 +728,20 @@ def sql_value(v: object) -> str:
 
         # Get table name and primary key columns
         table_name = "