diff --git a/sqlit/domains/connections/providers/oracle/adapter.py b/sqlit/domains/connections/providers/oracle/adapter.py index c6ad39e..99dd7a9 100644 --- a/sqlit/domains/connections/providers/oracle/adapter.py +++ b/sqlit/domains/connections/providers/oracle/adapter.py @@ -72,8 +72,18 @@ def connect(self, config: ConnectionConfig) -> Any: if endpoint is None: raise ValueError("Oracle connections require a TCP-style endpoint.") port = int(endpoint.port or get_default_port("oracle")) - # Use Easy Connect string format: host:port/service_name - dsn = f"{endpoint.host}:{port}/{endpoint.database}" + + # Determine connection type: service_name (default) or sid + connection_type = config.get_option("oracle_connection_type", "service_name") + + if connection_type == "sid": + # SID format: host:port:sid (uses colon separator) + # SID is stored in oracle_sid field, fall back to database for backward compat + sid = config.get_option("oracle_sid") or endpoint.database + dsn = f"{endpoint.host}:{port}:{sid}" + else: + # Service Name format: host:port/service_name (uses slash separator) + dsn = f"{endpoint.host}:{port}/{endpoint.database}" # Determine connection mode based on oracle_role oracle_role = config.get_option("oracle_role", "normal") diff --git a/sqlit/domains/connections/providers/oracle/schema.py b/sqlit/domains/connections/providers/oracle/schema.py index 77f4976..f004a6d 100644 --- a/sqlit/domains/connections/providers/oracle/schema.py +++ b/sqlit/domains/connections/providers/oracle/schema.py @@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]: ) +def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]: + return ( + SelectOption("service_name", "Service Name"), + SelectOption("sid", "SID"), + ) + + +def _oracle_connection_type_is_service_name(values: dict) -> bool: + return values.get("oracle_connection_type", "service_name") != "sid" + + +def _oracle_connection_type_is_sid(values: dict) -> bool: + return values.get("oracle_connection_type") == "sid" + + SCHEMA = ConnectionSchema( db_type="oracle", display_name="Oracle", @@ -32,11 +47,26 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]: group="server_port", ), _port_field("1521"), + SchemaField( + name="oracle_connection_type", + label="Connection Type", + field_type=FieldType.DROPDOWN, + options=_get_oracle_connection_type_options(), + default="service_name", + ), SchemaField( name="database", label="Service Name", placeholder="ORCL or XEPDB1", required=True, + visible_when=_oracle_connection_type_is_service_name, + ), + SchemaField( + name="oracle_sid", + label="SID", + placeholder="ORCL", + required=True, + visible_when=_oracle_connection_type_is_sid, ), _username_field(), _password_field(), diff --git a/sqlit/domains/connections/providers/oracle_legacy/schema.py b/sqlit/domains/connections/providers/oracle_legacy/schema.py index 5da8d0a..4252f72 100644 --- a/sqlit/domains/connections/providers/oracle_legacy/schema.py +++ b/sqlit/domains/connections/providers/oracle_legacy/schema.py @@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]: ) +def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]: + return ( + SelectOption("service_name", "Service Name"), + SelectOption("sid", "SID"), + ) + + +def _oracle_connection_type_is_service_name(values: dict) -> bool: + return values.get("oracle_connection_type", "service_name") != "sid" + + +def _oracle_connection_type_is_sid(values: dict) -> bool: + return values.get("oracle_connection_type") == "sid" + + def _get_oracle_client_mode_options() -> tuple[SelectOption, ...]: return ( SelectOption("thick", "Thick (Instant Client)"), @@ -43,11 +58,26 @@ def _oracle_thick_mode_enabled(values: dict) -> bool: group="server_port", ), _port_field("1521"), + SchemaField( + name="oracle_connection_type", + label="Connection Type", + field_type=FieldType.DROPDOWN, + options=_get_oracle_connection_type_options(), + default="service_name", + ), SchemaField( name="database", label="Service Name", + placeholder="ORCL or XEPDB1", + required=True, + visible_when=_oracle_connection_type_is_service_name, + ), + SchemaField( + name="oracle_sid", + label="SID", placeholder="ORCL", required=True, + visible_when=_oracle_connection_type_is_sid, ), _username_field(), _password_field(), diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index 9491461..1a4f219 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -295,6 +295,8 @@ def action_select_table(self: TreeMixinHost) -> None: "name": data.name, "columns": [], } + # Stash per-result metadata so results can resolve PKs without relying on globals. + self._pending_result_table_info = self._last_query_table self._prime_last_query_table_columns(data.database, data.schema, data.name) self.query_input.text = self.current_provider.dialect.build_select_query( diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py index 6b1c4e2..417a633 100644 --- a/sqlit/domains/query/ui/mixins/query_execution.py +++ b/sqlit/domains/query/ui/mixins/query_execution.py @@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b executable_statements = [s for s in statements if not is_comment_only_statement(s)] is_multi_statement = len(executable_statements) > 1 + if is_multi_statement: + self._pending_result_table_info = None + elif executable_statements: + if getattr(self, "_pending_result_table_info", None) is None: + table_info = self._infer_result_table_info(executable_statements[0]) + if table_info is not None: + self._pending_result_table_info = table_info + prime = getattr(self, "_prime_result_table_columns", None) + if callable(prime): + prime(table_info) + try: start_time = time.perf_counter() max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS diff --git a/sqlit/domains/query/ui/mixins/query_results.py b/sqlit/domains/query/ui/mixins/query_results.py index c40223f..6be6e8b 100644 --- a/sqlit/domains/query/ui/mixins/query_results.py +++ b/sqlit/domains/query/ui/mixins/query_results.py @@ -224,6 +224,7 @@ def _render_results_table_incremental( escape: bool, row_limit: int, render_token: int, + table_info: dict[str, Any] | None = None, ) -> None: initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit) initial_rows = rows[:initial_count] if initial_count > 0 else [] @@ -269,9 +270,16 @@ def _render_results_table_incremental( pass if render_token == getattr(self, "_results_render_token", 0): self._replace_results_table_with_data(columns, rows, escape=escape) + if table_info is not None: + try: + self.results_table.result_table_info = table_info + except Exception: + pass return if render_token != getattr(self, "_results_render_token", 0): return + if table_info is not None: + table.result_table_info = table_info self._replace_results_table_with_table(table) self._schedule_results_render( table, @@ -291,6 +299,7 @@ async def _display_query_results( self._last_result_columns = columns self._last_result_rows = rows self._last_result_row_count = row_count + table_info = getattr(self, "_pending_result_table_info", None) # Switch to single result mode (in case we were showing stacked results) self._show_single_result_mode() @@ -304,12 +313,15 @@ async def _display_query_results( escape=True, row_limit=row_limit, render_token=render_token, + table_info=table_info, ) else: render_rows = rows[:row_limit] if row_limit else [] table = self._build_results_table(columns, render_rows, escape=True) if render_token != getattr(self, "_results_render_token", 0): return + if table_info is not None: + table.result_table_info = table_info self._replace_results_table_with_table(table) time_str = format_duration_ms(elapsed_ms) @@ -320,9 +332,15 @@ async def _display_query_results( ) else: self.notify(f"Query returned {row_count} rows in {time_str}") + if table_info is not None: + prime = getattr(self, "_prime_result_table_columns", None) + if callable(prime): + prime(table_info) + self._pending_result_table_info = None def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None: """Display non-query result (called on main thread).""" + self._pending_result_table_info = None self._last_result_columns = ["Result"] self._last_result_rows = [(f"{affected} row(s) affected",)] self._last_result_row_count = 1 @@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f def _display_query_error(self: QueryMixinHost, error_message: str) -> None: """Display query error (called on main thread).""" self._cancel_results_render() + self._pending_result_table_info = None # notify(severity="error") handles displaying the error in results via _show_error_in_results self.notify(f"Query error: {error_message}", severity="error") @@ -360,7 +379,17 @@ def _display_multi_statement_results( # Add each result section for i, stmt_result in enumerate(multi_result.results): - container.add_result_section(stmt_result, i, auto_collapse=auto_collapse) + table_info = self._infer_result_table_info(stmt_result.statement) + if table_info is not None: + prime = getattr(self, "_prime_result_table_columns", None) + if callable(prime): + prime(table_info) + container.add_result_section( + stmt_result, + i, + auto_collapse=auto_collapse, + table_info=table_info, + ) # Show the stacked results container, hide single result table self._show_stacked_results_mode() @@ -378,6 +407,7 @@ def _display_multi_statement_results( ) else: self.notify(f"Executed {total} statements in {time_str}") + self._pending_result_table_info = None def _get_stacked_results_container(self: QueryMixinHost) -> Any: """Get the stacked results container.""" @@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None: stacked.remove_class("active") except Exception: pass + + def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None: + """Best-effort inference of a single source table for query results.""" + from sqlit.domains.query.completion import extract_table_refs + + refs = extract_table_refs(sql) + if len(refs) != 1: + return None + ref = refs[0] + schema = ref.schema + name = ref.name + database = None + table_metadata = getattr(self, "_table_metadata", {}) or {} + key_candidates = [name.lower()] + if schema: + key_candidates.insert(0, f"{schema}.{name}".lower()) + for key in key_candidates: + metadata = table_metadata.get(key) + if metadata: + schema, name, database = metadata + break + return { + "database": database, + "schema": schema, + "name": name, + "columns": [], + } diff --git a/sqlit/domains/results/ui/mixins/results.py b/sqlit/domains/results/ui/mixins/results.py index 2f3a434..d875029 100644 --- a/sqlit/domains/results/ui/mixins/results.py +++ b/sqlit/domains/results/ui/mixins/results.py @@ -8,6 +8,8 @@ from sqlit.shared.ui.protocols import ResultsMixinHost from sqlit.shared.ui.widgets import SqlitDataTable +MIN_TIMER_DELAY_S = 0.001 + class ResultsMixin: """Mixin providing results handling functionality.""" @@ -19,6 +21,122 @@ class ResultsMixin: _tooltip_showing: bool = False _tooltip_timer: Any | None = None + def _schedule_results_timer(self: ResultsMixinHost, delay_s: float, callback: Any) -> Any | None: + set_timer = getattr(self, "set_timer", None) + if callable(set_timer): + return set_timer(delay_s, callback) + call_later = getattr(self, "call_later", None) + if callable(call_later): + try: + call_later(callback) + return None + except Exception: + pass + try: + callback() + except Exception: + pass + return None + + def _apply_result_table_columns( + self: ResultsMixinHost, + table_info: dict[str, Any], + token: int, + columns: list[Any], + ) -> None: + if table_info.get("_columns_token") != token: + return + table_info["columns"] = columns + + def _prime_result_table_columns(self: ResultsMixinHost, table_info: dict[str, Any] | None) -> None: + if not table_info: + return + if table_info.get("columns"): + return + name = table_info.get("name") + if not name: + return + database = table_info.get("database") + schema = table_info.get("schema") + token = int(table_info.get("_columns_token", 0)) + 1 + table_info["_columns_token"] = token + + async def work_async() -> None: + import asyncio + + columns: list[Any] = [] + try: + runtime = getattr(self.services, "runtime", None) + use_worker = bool(getattr(runtime, "process_worker", False)) and not bool( + getattr(getattr(runtime, "mock", None), "enabled", False) + ) + client = None + if use_worker and hasattr(self, "_get_process_worker_client_async"): + client = await self._get_process_worker_client_async() # type: ignore[attr-defined] + + if client is not None and hasattr(client, "list_columns") and self.current_config is not None: + outcome = await asyncio.to_thread( + client.list_columns, + config=self.current_config, + database=database, + schema=schema, + name=name, + ) + if getattr(outcome, "cancelled", False): + return + error = getattr(outcome, "error", None) + if error: + raise RuntimeError(error) + columns = outcome.columns or [] + else: + schema_service = getattr(self, "_get_schema_service", None) + if callable(schema_service): + service = self._get_schema_service() + if service: + columns = await asyncio.to_thread( + service.list_columns, + database, + schema, + name, + ) + except Exception: + columns = [] + + self._schedule_results_timer( + MIN_TIMER_DELAY_S, + lambda: self._apply_result_table_columns(table_info, token, columns), + ) + + self.run_worker(work_async(), name=f"prime-result-columns-{name}", exclusive=False) + + def _normalize_column_name(self: ResultsMixinHost, name: str) -> str: + trimmed = name.strip() + if len(trimmed) >= 2: + if trimmed[0] == trimmed[-1] and trimmed[0] in ("\"", "`"): + trimmed = trimmed[1:-1] + elif trimmed[0] == "[" and trimmed[-1] == "]": + trimmed = trimmed[1:-1] + if "." in trimmed and not any(q in trimmed for q in ("\"", "`", "[")): + trimmed = trimmed.split(".")[-1] + return trimmed.lower() + + def _get_active_results_table_info( + self: ResultsMixinHost, + table: SqlitDataTable | None, + stacked: bool, + ) -> dict[str, Any] | None: + if not table: + return None + if stacked: + section = self._find_results_section(table) + table_info = getattr(section, "result_table_info", None) + if table_info: + return table_info + table_info = getattr(table, "result_table_info", None) + if table_info: + return table_info + return getattr(self, "_last_query_table", None) + def _copy_text(self: ResultsMixinHost, text: str) -> bool: """Copy text to clipboard if possible, otherwise store internally.""" self._internal_clipboard = text @@ -610,21 +728,20 @@ def sql_value(v: object) -> str: # Get table name and primary key columns table_name = "" pk_column_names: set[str] = set() - - if hasattr(self, "_last_query_table") and self._last_query_table: - table_info = self._last_query_table - table_name = table_info["name"] + table_info = self._get_active_results_table_info(table, _stacked) + if table_info: + table_name = table_info.get("name", table_name) # Get PK columns from column info for col in table_info.get("columns", []): if col.is_primary_key: - pk_column_names.add(col.name) + pk_column_names.add(self._normalize_column_name(col.name)) # Build WHERE clause - prefer PK columns, fall back to all columns where_parts = [] for i, col in enumerate(columns): if i < len(row_values): # If we have PK info, only use PK columns; otherwise use all columns - if pk_column_names and col not in pk_column_names: + if pk_column_names and self._normalize_column_name(col) not in pk_column_names: continue val = row_values[i] if val is None: @@ -685,9 +802,10 @@ def action_edit_cell(self: ResultsMixinHost) -> None: column_name = columns[cursor_col] # Check if this column is a primary key - don't allow editing PKs - if hasattr(self, "_last_query_table") and self._last_query_table: - for col in self._last_query_table.get("columns", []): - if col.name == column_name and col.is_primary_key: + table_info = self._get_active_results_table_info(table, _stacked) + if table_info: + for col in table_info.get("columns", []): + if col.is_primary_key and self._normalize_column_name(col.name) == self._normalize_column_name(column_name): self.notify("Cannot edit primary key column", severity="warning") return @@ -705,21 +823,19 @@ def sql_value(v: object) -> str: # Get table name and primary key columns table_name = "
" pk_column_names: set[str] = set() - - if hasattr(self, "_last_query_table") and self._last_query_table: - table_info = self._last_query_table - table_name = table_info["name"] + if table_info: + table_name = table_info.get("name", table_name) # Get PK columns from column info for col in table_info.get("columns", []): if col.is_primary_key: - pk_column_names.add(col.name) + pk_column_names.add(self._normalize_column_name(col.name)) # Build WHERE clause - prefer PK columns, fall back to all columns where_parts = [] for i, col in enumerate(columns): if i < len(row_values): # If we have PK info, only use PK columns; otherwise use all columns - if pk_column_names and col not in pk_column_names: + if pk_column_names and self._normalize_column_name(col) not in pk_column_names: continue val = row_values[i] if val is None: diff --git a/sqlit/domains/shell/app/main.py b/sqlit/domains/shell/app/main.py index 5f9fbec..66edcde 100644 --- a/sqlit/domains/shell/app/main.py +++ b/sqlit/domains/shell/app/main.py @@ -197,6 +197,7 @@ def __init__( self._columns_loading: set[str] = set() self._state_machine = UIStateMachine() self._last_query_table: dict[str, Any] | None = None + self._pending_result_table_info: dict[str, Any] | None = None self._query_target_database: str | None = None # Target DB for auto-generated queries self._restart_requested: bool = False # Idle scheduler for background work diff --git a/sqlit/shared/ui/protocols/results.py b/sqlit/shared/ui/protocols/results.py index d6e3f5e..2124b1a 100644 --- a/sqlit/shared/ui/protocols/results.py +++ b/sqlit/shared/ui/protocols/results.py @@ -16,6 +16,7 @@ class ResultsStateProtocol(Protocol): _last_result_row_count: int _internal_clipboard: str _last_query_table: dict[str, Any] | None + _pending_result_table_info: dict[str, Any] | None _results_table_counter: int _results_filter_visible: bool _results_filter_text: str diff --git a/sqlit/shared/ui/widgets_stacked_results.py b/sqlit/shared/ui/widgets_stacked_results.py index e61c819..015cc45 100644 --- a/sqlit/shared/ui/widgets_stacked_results.py +++ b/sqlit/shared/ui/widgets_stacked_results.py @@ -104,6 +104,7 @@ def __init__( self.is_error = is_error self.result_columns: list[str] = [] self.result_rows: list[tuple] = [] + self.result_table_info: dict[str, Any] | None = None self._content = content if is_error: self.add_class("error") @@ -156,6 +157,7 @@ def add_result_section( index: int, *, auto_collapse: bool = False, + table_info: dict[str, Any] | None = None, ) -> None: """Add a result section for a statement result.""" from sqlit.domains.query.app.query_service import QueryResult @@ -169,6 +171,8 @@ def add_result_section( # SELECT result - build a DataTable result_columns, result_rows = self._get_result_table_data(stmt_result.result) content = self._build_result_table_from_rows(result_columns, result_rows, index) + if table_info is not None: + content.result_table_info = table_info else: # Non-query result (INSERT/UPDATE/DELETE) content = NonQueryDisplay(stmt_result.result.rows_affected) @@ -186,6 +190,7 @@ def add_result_section( ) section.result_columns = result_columns section.result_rows = result_rows + section.result_table_info = table_info self.mount(section) self._section_count += 1 diff --git a/tests/unit/test_oracle_adapter.py b/tests/unit/test_oracle_adapter.py index f743c78..b1fca94 100644 --- a/tests/unit/test_oracle_adapter.py +++ b/tests/unit/test_oracle_adapter.py @@ -126,3 +126,145 @@ def test_connect_default_role_when_not_set(self): mock_oracledb.connect.assert_called_once() call_kwargs = mock_oracledb.connect.call_args.kwargs assert "mode" not in call_kwargs + + +class TestOracleAdapterConnectionType: + """Test Oracle adapter handles connection type (Service Name vs SID) correctly.""" + + def test_connect_service_name_format(self): + """Test that service_name connection type uses slash separator.""" + mock_oracledb = MagicMock() + mock_oracledb.AUTH_MODE_SYSDBA = 2 + mock_oracledb.AUTH_MODE_SYSOPER = 4 + + with patch.dict("sys.modules", {"oracledb": mock_oracledb}): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + + adapter = OracleAdapter() + config = ConnectionConfig( + name="test", + db_type="oracle", + server="localhost", + port="1521", + database="XEPDB1", + username="testuser", + password="testpass", + options={"oracle_connection_type": "service_name"}, + ) + + adapter.connect(config) + + mock_oracledb.connect.assert_called_once() + call_kwargs = mock_oracledb.connect.call_args.kwargs + # Service name uses slash separator: host:port/service_name + assert call_kwargs["dsn"] == "localhost:1521/XEPDB1" + + def test_connect_sid_format(self): + """Test that sid connection type uses colon separator with oracle_sid field.""" + mock_oracledb = MagicMock() + mock_oracledb.AUTH_MODE_SYSDBA = 2 + mock_oracledb.AUTH_MODE_SYSOPER = 4 + + with patch.dict("sys.modules", {"oracledb": mock_oracledb}): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + + adapter = OracleAdapter() + config = ConnectionConfig( + name="test", + db_type="oracle", + server="localhost", + port="1521", + username="testuser", + password="testpass", + options={"oracle_connection_type": "sid", "oracle_sid": "ORCL"}, + ) + + adapter.connect(config) + + mock_oracledb.connect.assert_called_once() + call_kwargs = mock_oracledb.connect.call_args.kwargs + # SID uses colon separator: host:port:sid + assert call_kwargs["dsn"] == "localhost:1521:ORCL" + + def test_connect_sid_backward_compat_uses_database_field(self): + """Test that SID falls back to database field for backward compatibility.""" + mock_oracledb = MagicMock() + mock_oracledb.AUTH_MODE_SYSDBA = 2 + mock_oracledb.AUTH_MODE_SYSOPER = 4 + + with patch.dict("sys.modules", {"oracledb": mock_oracledb}): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + + adapter = OracleAdapter() + # Old config style: oracle_sid not set, database used instead + config = ConnectionConfig( + name="test", + db_type="oracle", + server="localhost", + port="1521", + database="LEGACY_SID", + username="testuser", + password="testpass", + options={"oracle_connection_type": "sid"}, + ) + + adapter.connect(config) + + mock_oracledb.connect.assert_called_once() + call_kwargs = mock_oracledb.connect.call_args.kwargs + # Should use database field as fallback + assert call_kwargs["dsn"] == "localhost:1521:LEGACY_SID" + + def test_connect_default_connection_type_is_service_name(self): + """Test that missing oracle_connection_type defaults to service_name format.""" + mock_oracledb = MagicMock() + mock_oracledb.AUTH_MODE_SYSDBA = 2 + mock_oracledb.AUTH_MODE_SYSOPER = 4 + + with patch.dict("sys.modules", {"oracledb": mock_oracledb}): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + + adapter = OracleAdapter() + # Create config without oracle_connection_type + config = ConnectionConfig( + name="test", + db_type="oracle", + server="localhost", + port="1521", + database="ORCL", + username="testuser", + password="testpass", + ) + + adapter.connect(config) + + mock_oracledb.connect.assert_called_once() + call_kwargs = mock_oracledb.connect.call_args.kwargs + # Should default to service name format with slash + assert call_kwargs["dsn"] == "localhost:1521/ORCL" + + def test_connect_sid_with_custom_port(self): + """Test SID format works with non-default port.""" + mock_oracledb = MagicMock() + mock_oracledb.AUTH_MODE_SYSDBA = 2 + mock_oracledb.AUTH_MODE_SYSOPER = 4 + + with patch.dict("sys.modules", {"oracledb": mock_oracledb}): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + + adapter = OracleAdapter() + config = ConnectionConfig( + name="test", + db_type="oracle", + server="db.example.com", + port="1522", + username="testuser", + password="testpass", + options={"oracle_connection_type": "sid", "oracle_sid": "PROD"}, + ) + + adapter.connect(config) + + mock_oracledb.connect.assert_called_once() + call_kwargs = mock_oracledb.connect.call_args.kwargs + assert call_kwargs["dsn"] == "db.example.com:1522:PROD"