From 4371112827ccf4e3f981ce5a2bbe946b7da5f3c1 Mon Sep 17 00:00:00 2001
From: Peter Adams <18162810+Maxteabag@users.noreply.github.com>
Date: Sat, 31 Jan 2026 15:34:31 +0100
Subject: [PATCH 1/2] Attach result table metadata for PK-safe updates
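This threads a per-result table-info dict (database/schema/name plus lazily
primed columns) from the point where a query is issued -- table selection in
the explorer, or inference from a single-statement query -- through to the
rendered results table and stacked result sections, so UPDATE/DELETE
generation can resolve primary-key columns per result instead of relying on
the global _last_query_table. Column names are normalized before comparison
so quoted or qualified headers still match.

A minimal, self-contained sketch of the PK-gating idea (illustrative only,
simplified from the WHERE-clause builders in results.py below; build_where
and norm are hypothetical names, not part of the diff):

    # Illustrative, simplified: with PK info, only PK columns reach the WHERE clause.
    def build_where(columns, row_values, pk_names):
        def norm(name: str) -> str:
            return name.strip().strip('"`[]').split(".")[-1].lower()

        parts = []
        for col, val in zip(columns, row_values):
            if pk_names and norm(col) not in pk_names:
                continue
            parts.append(f"{col} IS NULL" if val is None else f"{col} = '{val}'")
        return " AND ".join(parts)

    # build_where(["ID", "Name"], [7, "Ada"], {"id"}) -> "ID = '7'"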
---
sqlit/domains/explorer/ui/mixins/tree.py | 2 +
.../query/ui/mixins/query_execution.py | 11 ++
.../domains/query/ui/mixins/query_results.py | 59 ++++++-
sqlit/domains/results/ui/mixins/results.py | 146 ++++++++++++++++--
sqlit/domains/shell/app/main.py | 1 +
sqlit/shared/ui/protocols/results.py | 1 +
sqlit/shared/ui/widgets_stacked_results.py | 5 +
7 files changed, 209 insertions(+), 16 deletions(-)
diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py
index 94914618..1a4f2192 100644
--- a/sqlit/domains/explorer/ui/mixins/tree.py
+++ b/sqlit/domains/explorer/ui/mixins/tree.py
@@ -295,6 +295,8 @@ def action_select_table(self: TreeMixinHost) -> None:
"name": data.name,
"columns": [],
}
+ # Stash per-result metadata so results can resolve PKs without relying on globals.
+ self._pending_result_table_info = self._last_query_table
self._prime_last_query_table_columns(data.database, data.schema, data.name)
self.query_input.text = self.current_provider.dialect.build_select_query(
diff --git a/sqlit/domains/query/ui/mixins/query_execution.py b/sqlit/domains/query/ui/mixins/query_execution.py
index 6b1c4e27..417a633a 100644
--- a/sqlit/domains/query/ui/mixins/query_execution.py
+++ b/sqlit/domains/query/ui/mixins/query_execution.py
@@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
executable_statements = [s for s in statements if not is_comment_only_statement(s)]
is_multi_statement = len(executable_statements) > 1
+ if is_multi_statement:
+ self._pending_result_table_info = None
+ elif executable_statements:
+ if getattr(self, "_pending_result_table_info", None) is None:
+ table_info = self._infer_result_table_info(executable_statements[0])
+ if table_info is not None:
+ self._pending_result_table_info = table_info
+ prime = getattr(self, "_prime_result_table_columns", None)
+ if callable(prime):
+ prime(table_info)
+
try:
start_time = time.perf_counter()
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS
diff --git a/sqlit/domains/query/ui/mixins/query_results.py b/sqlit/domains/query/ui/mixins/query_results.py
index c40223fb..6be6e8b3 100644
--- a/sqlit/domains/query/ui/mixins/query_results.py
+++ b/sqlit/domains/query/ui/mixins/query_results.py
@@ -224,6 +224,7 @@ def _render_results_table_incremental(
escape: bool,
row_limit: int,
render_token: int,
+ table_info: dict[str, Any] | None = None,
) -> None:
initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
initial_rows = rows[:initial_count] if initial_count > 0 else []
@@ -269,9 +270,16 @@ def _render_results_table_incremental(
pass
if render_token == getattr(self, "_results_render_token", 0):
self._replace_results_table_with_data(columns, rows, escape=escape)
+ if table_info is not None:
+ try:
+ self.results_table.result_table_info = table_info
+ except Exception:
+ pass
return
if render_token != getattr(self, "_results_render_token", 0):
return
+ if table_info is not None:
+ table.result_table_info = table_info
self._replace_results_table_with_table(table)
self._schedule_results_render(
table,
@@ -291,6 +299,7 @@ async def _display_query_results(
self._last_result_columns = columns
self._last_result_rows = rows
self._last_result_row_count = row_count
+ table_info = getattr(self, "_pending_result_table_info", None)
# Switch to single result mode (in case we were showing stacked results)
self._show_single_result_mode()
@@ -304,12 +313,15 @@ async def _display_query_results(
escape=True,
row_limit=row_limit,
render_token=render_token,
+ table_info=table_info,
)
else:
render_rows = rows[:row_limit] if row_limit else []
table = self._build_results_table(columns, render_rows, escape=True)
if render_token != getattr(self, "_results_render_token", 0):
return
+ if table_info is not None:
+ table.result_table_info = table_info
self._replace_results_table_with_table(table)
time_str = format_duration_ms(elapsed_ms)
@@ -320,9 +332,15 @@ async def _display_query_results(
)
else:
self.notify(f"Query returned {row_count} rows in {time_str}")
+ if table_info is not None:
+ prime = getattr(self, "_prime_result_table_columns", None)
+ if callable(prime):
+ prime(table_info)
+ self._pending_result_table_info = None
def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
"""Display non-query result (called on main thread)."""
+ self._pending_result_table_info = None
self._last_result_columns = ["Result"]
self._last_result_rows = [(f"{affected} row(s) affected",)]
self._last_result_row_count = 1
@@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
"""Display query error (called on main thread)."""
self._cancel_results_render()
+ self._pending_result_table_info = None
# notify(severity="error") handles displaying the error in results via _show_error_in_results
self.notify(f"Query error: {error_message}", severity="error")
@@ -360,7 +379,17 @@ def _display_multi_statement_results(
# Add each result section
for i, stmt_result in enumerate(multi_result.results):
- container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
+ table_info = self._infer_result_table_info(stmt_result.statement)
+ if table_info is not None:
+ prime = getattr(self, "_prime_result_table_columns", None)
+ if callable(prime):
+ prime(table_info)
+ container.add_result_section(
+ stmt_result,
+ i,
+ auto_collapse=auto_collapse,
+ table_info=table_info,
+ )
# Show the stacked results container, hide single result table
self._show_stacked_results_mode()
@@ -378,6 +407,7 @@ def _display_multi_statement_results(
)
else:
self.notify(f"Executed {total} statements in {time_str}")
+ self._pending_result_table_info = None
def _get_stacked_results_container(self: QueryMixinHost) -> Any:
"""Get the stacked results container."""
@@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
stacked.remove_class("active")
except Exception:
pass
+
+ def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
+ """Best-effort inference of a single source table for query results."""
+ from sqlit.domains.query.completion import extract_table_refs
+
+ refs = extract_table_refs(sql)
+ if len(refs) != 1:
+ return None
+ ref = refs[0]
+ schema = ref.schema
+ name = ref.name
+ database = None
+ table_metadata = getattr(self, "_table_metadata", {}) or {}
+ key_candidates = [name.lower()]
+ if schema:
+ key_candidates.insert(0, f"{schema}.{name}".lower())
+ for key in key_candidates:
+ metadata = table_metadata.get(key)
+ if metadata:
+ schema, name, database = metadata
+ break
+ return {
+ "database": database,
+ "schema": schema,
+ "name": name,
+ "columns": [],
+ }
diff --git a/sqlit/domains/results/ui/mixins/results.py b/sqlit/domains/results/ui/mixins/results.py
index 2f3a4340..d8750299 100644
--- a/sqlit/domains/results/ui/mixins/results.py
+++ b/sqlit/domains/results/ui/mixins/results.py
@@ -8,6 +8,8 @@
from sqlit.shared.ui.protocols import ResultsMixinHost
from sqlit.shared.ui.widgets import SqlitDataTable
+MIN_TIMER_DELAY_S = 0.001
+
class ResultsMixin:
"""Mixin providing results handling functionality."""
@@ -19,6 +21,122 @@ class ResultsMixin:
_tooltip_showing: bool = False
_tooltip_timer: Any | None = None
+ def _schedule_results_timer(self: ResultsMixinHost, delay_s: float, callback: Any) -> Any | None:
+ set_timer = getattr(self, "set_timer", None)
+ if callable(set_timer):
+ return set_timer(delay_s, callback)
+ call_later = getattr(self, "call_later", None)
+ if callable(call_later):
+ try:
+ call_later(callback)
+ return None
+ except Exception:
+ pass
+ try:
+ callback()
+ except Exception:
+ pass
+ return None
+
+ def _apply_result_table_columns(
+ self: ResultsMixinHost,
+ table_info: dict[str, Any],
+ token: int,
+ columns: list[Any],
+ ) -> None:
+ if table_info.get("_columns_token") != token:
+ return
+ table_info["columns"] = columns
+
+ def _prime_result_table_columns(self: ResultsMixinHost, table_info: dict[str, Any] | None) -> None:
+ if not table_info:
+ return
+ if table_info.get("columns"):
+ return
+ name = table_info.get("name")
+ if not name:
+ return
+ database = table_info.get("database")
+ schema = table_info.get("schema")
+ token = int(table_info.get("_columns_token", 0)) + 1
+ table_info["_columns_token"] = token
+
+ async def work_async() -> None:
+ import asyncio
+
+ columns: list[Any] = []
+ try:
+ runtime = getattr(self.services, "runtime", None)
+ use_worker = bool(getattr(runtime, "process_worker", False)) and not bool(
+ getattr(getattr(runtime, "mock", None), "enabled", False)
+ )
+ client = None
+ if use_worker and hasattr(self, "_get_process_worker_client_async"):
+ client = await self._get_process_worker_client_async() # type: ignore[attr-defined]
+
+ if client is not None and hasattr(client, "list_columns") and self.current_config is not None:
+ outcome = await asyncio.to_thread(
+ client.list_columns,
+ config=self.current_config,
+ database=database,
+ schema=schema,
+ name=name,
+ )
+ if getattr(outcome, "cancelled", False):
+ return
+ error = getattr(outcome, "error", None)
+ if error:
+ raise RuntimeError(error)
+ columns = outcome.columns or []
+ else:
+ schema_service = getattr(self, "_get_schema_service", None)
+ if callable(schema_service):
+ service = self._get_schema_service()
+ if service:
+ columns = await asyncio.to_thread(
+ service.list_columns,
+ database,
+ schema,
+ name,
+ )
+ except Exception:
+ columns = []
+
+ self._schedule_results_timer(
+ MIN_TIMER_DELAY_S,
+ lambda: self._apply_result_table_columns(table_info, token, columns),
+ )
+
+ self.run_worker(work_async(), name=f"prime-result-columns-{name}", exclusive=False)
+
+ def _normalize_column_name(self: ResultsMixinHost, name: str) -> str:
+ trimmed = name.strip()
+ if len(trimmed) >= 2:
+ if trimmed[0] == trimmed[-1] and trimmed[0] in ("\"", "`"):
+ trimmed = trimmed[1:-1]
+ elif trimmed[0] == "[" and trimmed[-1] == "]":
+ trimmed = trimmed[1:-1]
+ if "." in trimmed and not any(q in trimmed for q in ("\"", "`", "[")):
+ trimmed = trimmed.split(".")[-1]
+ return trimmed.lower()
+
+ def _get_active_results_table_info(
+ self: ResultsMixinHost,
+ table: SqlitDataTable | None,
+ stacked: bool,
+ ) -> dict[str, Any] | None:
+ if not table:
+ return None
+ if stacked:
+ section = self._find_results_section(table)
+ table_info = getattr(section, "result_table_info", None)
+ if table_info:
+ return table_info
+ table_info = getattr(table, "result_table_info", None)
+ if table_info:
+ return table_info
+ return getattr(self, "_last_query_table", None)
+
def _copy_text(self: ResultsMixinHost, text: str) -> bool:
"""Copy text to clipboard if possible, otherwise store internally."""
self._internal_clipboard = text
@@ -610,21 +728,20 @@ def sql_value(v: object) -> str:
# Get table name and primary key columns
table_name = "
"
pk_column_names: set[str] = set()
-
- if hasattr(self, "_last_query_table") and self._last_query_table:
- table_info = self._last_query_table
- table_name = table_info["name"]
+ table_info = self._get_active_results_table_info(table, _stacked)
+ if table_info:
+ table_name = table_info.get("name", table_name)
# Get PK columns from column info
for col in table_info.get("columns", []):
if col.is_primary_key:
- pk_column_names.add(col.name)
+ pk_column_names.add(self._normalize_column_name(col.name))
# Build WHERE clause - prefer PK columns, fall back to all columns
where_parts = []
for i, col in enumerate(columns):
if i < len(row_values):
# If we have PK info, only use PK columns; otherwise use all columns
- if pk_column_names and col not in pk_column_names:
+ if pk_column_names and self._normalize_column_name(col) not in pk_column_names:
continue
val = row_values[i]
if val is None:
@@ -685,9 +802,10 @@ def action_edit_cell(self: ResultsMixinHost) -> None:
column_name = columns[cursor_col]
# Check if this column is a primary key - don't allow editing PKs
- if hasattr(self, "_last_query_table") and self._last_query_table:
- for col in self._last_query_table.get("columns", []):
- if col.name == column_name and col.is_primary_key:
+ table_info = self._get_active_results_table_info(table, _stacked)
+ if table_info:
+ for col in table_info.get("columns", []):
+ if col.is_primary_key and self._normalize_column_name(col.name) == self._normalize_column_name(column_name):
self.notify("Cannot edit primary key column", severity="warning")
return
@@ -705,21 +823,19 @@ def sql_value(v: object) -> str:
# Get table name and primary key columns
table_name = ""
pk_column_names: set[str] = set()
-
- if hasattr(self, "_last_query_table") and self._last_query_table:
- table_info = self._last_query_table
- table_name = table_info["name"]
+ if table_info:
+ table_name = table_info.get("name", table_name)
# Get PK columns from column info
for col in table_info.get("columns", []):
if col.is_primary_key:
- pk_column_names.add(col.name)
+ pk_column_names.add(self._normalize_column_name(col.name))
# Build WHERE clause - prefer PK columns, fall back to all columns
where_parts = []
for i, col in enumerate(columns):
if i < len(row_values):
# If we have PK info, only use PK columns; otherwise use all columns
- if pk_column_names and col not in pk_column_names:
+ if pk_column_names and self._normalize_column_name(col) not in pk_column_names:
continue
val = row_values[i]
if val is None:
diff --git a/sqlit/domains/shell/app/main.py b/sqlit/domains/shell/app/main.py
index 5f9fbec4..66edcdec 100644
--- a/sqlit/domains/shell/app/main.py
+++ b/sqlit/domains/shell/app/main.py
@@ -197,6 +197,7 @@ def __init__(
self._columns_loading: set[str] = set()
self._state_machine = UIStateMachine()
self._last_query_table: dict[str, Any] | None = None
+ self._pending_result_table_info: dict[str, Any] | None = None
self._query_target_database: str | None = None # Target DB for auto-generated queries
self._restart_requested: bool = False
# Idle scheduler for background work
diff --git a/sqlit/shared/ui/protocols/results.py b/sqlit/shared/ui/protocols/results.py
index d6e3f5eb..2124b1aa 100644
--- a/sqlit/shared/ui/protocols/results.py
+++ b/sqlit/shared/ui/protocols/results.py
@@ -16,6 +16,7 @@ class ResultsStateProtocol(Protocol):
_last_result_row_count: int
_internal_clipboard: str
_last_query_table: dict[str, Any] | None
+ _pending_result_table_info: dict[str, Any] | None
_results_table_counter: int
_results_filter_visible: bool
_results_filter_text: str
diff --git a/sqlit/shared/ui/widgets_stacked_results.py b/sqlit/shared/ui/widgets_stacked_results.py
index e61c819b..015cc450 100644
--- a/sqlit/shared/ui/widgets_stacked_results.py
+++ b/sqlit/shared/ui/widgets_stacked_results.py
@@ -104,6 +104,7 @@ def __init__(
self.is_error = is_error
self.result_columns: list[str] = []
self.result_rows: list[tuple] = []
+ self.result_table_info: dict[str, Any] | None = None
self._content = content
if is_error:
self.add_class("error")
@@ -156,6 +157,7 @@ def add_result_section(
index: int,
*,
auto_collapse: bool = False,
+ table_info: dict[str, Any] | None = None,
) -> None:
"""Add a result section for a statement result."""
from sqlit.domains.query.app.query_service import QueryResult
@@ -169,6 +171,8 @@ def add_result_section(
# SELECT result - build a DataTable
result_columns, result_rows = self._get_result_table_data(stmt_result.result)
content = self._build_result_table_from_rows(result_columns, result_rows, index)
+ if table_info is not None:
+ content.result_table_info = table_info
else:
# Non-query result (INSERT/UPDATE/DELETE)
content = NonQueryDisplay(stmt_result.result.rows_affected)
@@ -186,6 +190,7 @@ def add_result_section(
)
section.result_columns = result_columns
section.result_rows = result_rows
+ section.result_table_info = table_info
self.mount(section)
self._section_count += 1
From dee3140f66a261ba0131ea54b8d2c237f2dc6290 Mon Sep 17 00:00:00 2001
From: Peter Adams <18162810+Maxteabag@users.noreply.github.com>
Date: Sat, 31 Jan 2026 15:45:43 +0100
Subject: [PATCH 2/2] feat: add Oracle SID connection support
Add support for connecting to Oracle databases using SID in addition
to the existing Service Name format. The UI shows a Connection Type
dropdown that conditionally displays either a Service Name or SID
input field based on the selection.
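For illustration, the two Easy Connect DSN shapes the adapter now builds
(host, port, and names below are placeholders, not new defaults):

    host, port = "db.example.com", 1521
    service_name_dsn = f"{host}:{port}/XEPDB1"  # Connection Type = Service Name
    sid_dsn = f"{host}:{port}:ORCL"             # Connection Type = SID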
Closes #106
---
.../connections/providers/oracle/adapter.py | 14 +-
.../connections/providers/oracle/schema.py | 30 ++++
.../providers/oracle_legacy/schema.py | 30 ++++
tests/unit/test_oracle_adapter.py | 142 ++++++++++++++++++
4 files changed, 214 insertions(+), 2 deletions(-)
diff --git a/sqlit/domains/connections/providers/oracle/adapter.py b/sqlit/domains/connections/providers/oracle/adapter.py
index c6ad39e4..99dd7a93 100644
--- a/sqlit/domains/connections/providers/oracle/adapter.py
+++ b/sqlit/domains/connections/providers/oracle/adapter.py
@@ -72,8 +72,18 @@ def connect(self, config: ConnectionConfig) -> Any:
if endpoint is None:
raise ValueError("Oracle connections require a TCP-style endpoint.")
port = int(endpoint.port or get_default_port("oracle"))
- # Use Easy Connect string format: host:port/service_name
- dsn = f"{endpoint.host}:{port}/{endpoint.database}"
+
+ # Determine connection type: service_name (default) or sid
+ connection_type = config.get_option("oracle_connection_type", "service_name")
+
+ if connection_type == "sid":
+ # SID format: host:port:sid (uses colon separator)
+ # SID is stored in oracle_sid field, fall back to database for backward compat
+ sid = config.get_option("oracle_sid") or endpoint.database
+ dsn = f"{endpoint.host}:{port}:{sid}"
+ else:
+ # Service Name format: host:port/service_name (uses slash separator)
+ dsn = f"{endpoint.host}:{port}/{endpoint.database}"
# Determine connection mode based on oracle_role
oracle_role = config.get_option("oracle_role", "normal")
diff --git a/sqlit/domains/connections/providers/oracle/schema.py b/sqlit/domains/connections/providers/oracle/schema.py
index 77f49764..f004a6d9 100644
--- a/sqlit/domains/connections/providers/oracle/schema.py
+++ b/sqlit/domains/connections/providers/oracle/schema.py
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
)
+def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
+ return (
+ SelectOption("service_name", "Service Name"),
+ SelectOption("sid", "SID"),
+ )
+
+
+def _oracle_connection_type_is_service_name(values: dict) -> bool:
+ return values.get("oracle_connection_type", "service_name") != "sid"
+
+
+def _oracle_connection_type_is_sid(values: dict) -> bool:
+ return values.get("oracle_connection_type") == "sid"
+
+
SCHEMA = ConnectionSchema(
db_type="oracle",
display_name="Oracle",
@@ -32,11 +47,26 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
group="server_port",
),
_port_field("1521"),
+ SchemaField(
+ name="oracle_connection_type",
+ label="Connection Type",
+ field_type=FieldType.DROPDOWN,
+ options=_get_oracle_connection_type_options(),
+ default="service_name",
+ ),
SchemaField(
name="database",
label="Service Name",
placeholder="ORCL or XEPDB1",
required=True,
+ visible_when=_oracle_connection_type_is_service_name,
+ ),
+ SchemaField(
+ name="oracle_sid",
+ label="SID",
+ placeholder="ORCL",
+ required=True,
+ visible_when=_oracle_connection_type_is_sid,
),
_username_field(),
_password_field(),
diff --git a/sqlit/domains/connections/providers/oracle_legacy/schema.py b/sqlit/domains/connections/providers/oracle_legacy/schema.py
index 5da8d0a1..4252f721 100644
--- a/sqlit/domains/connections/providers/oracle_legacy/schema.py
+++ b/sqlit/domains/connections/providers/oracle_legacy/schema.py
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
)
+def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
+ return (
+ SelectOption("service_name", "Service Name"),
+ SelectOption("sid", "SID"),
+ )
+
+
+def _oracle_connection_type_is_service_name(values: dict) -> bool:
+ return values.get("oracle_connection_type", "service_name") != "sid"
+
+
+def _oracle_connection_type_is_sid(values: dict) -> bool:
+ return values.get("oracle_connection_type") == "sid"
+
+
def _get_oracle_client_mode_options() -> tuple[SelectOption, ...]:
return (
SelectOption("thick", "Thick (Instant Client)"),
@@ -43,11 +58,26 @@ def _oracle_thick_mode_enabled(values: dict) -> bool:
group="server_port",
),
_port_field("1521"),
+ SchemaField(
+ name="oracle_connection_type",
+ label="Connection Type",
+ field_type=FieldType.DROPDOWN,
+ options=_get_oracle_connection_type_options(),
+ default="service_name",
+ ),
SchemaField(
name="database",
label="Service Name",
+ placeholder="ORCL or XEPDB1",
+ required=True,
+ visible_when=_oracle_connection_type_is_service_name,
+ ),
+ SchemaField(
+ name="oracle_sid",
+ label="SID",
placeholder="ORCL",
required=True,
+ visible_when=_oracle_connection_type_is_sid,
),
_username_field(),
_password_field(),
diff --git a/tests/unit/test_oracle_adapter.py b/tests/unit/test_oracle_adapter.py
index f743c78d..b1fca942 100644
--- a/tests/unit/test_oracle_adapter.py
+++ b/tests/unit/test_oracle_adapter.py
@@ -126,3 +126,145 @@ def test_connect_default_role_when_not_set(self):
mock_oracledb.connect.assert_called_once()
call_kwargs = mock_oracledb.connect.call_args.kwargs
assert "mode" not in call_kwargs
+
+
+class TestOracleAdapterConnectionType:
+ """Test Oracle adapter handles connection type (Service Name vs SID) correctly."""
+
+ def test_connect_service_name_format(self):
+ """Test that service_name connection type uses slash separator."""
+ mock_oracledb = MagicMock()
+ mock_oracledb.AUTH_MODE_SYSDBA = 2
+ mock_oracledb.AUTH_MODE_SYSOPER = 4
+
+ with patch.dict("sys.modules", {"oracledb": mock_oracledb}):
+ from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter
+
+ adapter = OracleAdapter()
+ config = ConnectionConfig(
+ name="test",
+ db_type="oracle",
+ server="localhost",
+ port="1521",
+ database="XEPDB1",
+ username="testuser",
+ password="testpass",
+ options={"oracle_connection_type": "service_name"},
+ )
+
+ adapter.connect(config)
+
+ mock_oracledb.connect.assert_called_once()
+ call_kwargs = mock_oracledb.connect.call_args.kwargs
+ # Service name uses slash separator: host:port/service_name
+ assert call_kwargs["dsn"] == "localhost:1521/XEPDB1"
+
+ def test_connect_sid_format(self):
+ """Test that sid connection type uses colon separator with oracle_sid field."""
+ mock_oracledb = MagicMock()
+ mock_oracledb.AUTH_MODE_SYSDBA = 2
+ mock_oracledb.AUTH_MODE_SYSOPER = 4
+
+ with patch.dict("sys.modules", {"oracledb": mock_oracledb}):
+ from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter
+
+ adapter = OracleAdapter()
+ config = ConnectionConfig(
+ name="test",
+ db_type="oracle",
+ server="localhost",
+ port="1521",
+ username="testuser",
+ password="testpass",
+ options={"oracle_connection_type": "sid", "oracle_sid": "ORCL"},
+ )
+
+ adapter.connect(config)
+
+ mock_oracledb.connect.assert_called_once()
+ call_kwargs = mock_oracledb.connect.call_args.kwargs
+ # SID uses colon separator: host:port:sid
+ assert call_kwargs["dsn"] == "localhost:1521:ORCL"
+
+ def test_connect_sid_backward_compat_uses_database_field(self):
+ """Test that SID falls back to database field for backward compatibility."""
+ mock_oracledb = MagicMock()
+ mock_oracledb.AUTH_MODE_SYSDBA = 2
+ mock_oracledb.AUTH_MODE_SYSOPER = 4
+
+ with patch.dict("sys.modules", {"oracledb": mock_oracledb}):
+ from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter
+
+ adapter = OracleAdapter()
+ # Old config style: oracle_sid not set, database used instead
+ config = ConnectionConfig(
+ name="test",
+ db_type="oracle",
+ server="localhost",
+ port="1521",
+ database="LEGACY_SID",
+ username="testuser",
+ password="testpass",
+ options={"oracle_connection_type": "sid"},
+ )
+
+ adapter.connect(config)
+
+ mock_oracledb.connect.assert_called_once()
+ call_kwargs = mock_oracledb.connect.call_args.kwargs
+ # Should use database field as fallback
+ assert call_kwargs["dsn"] == "localhost:1521:LEGACY_SID"
+
+ def test_connect_default_connection_type_is_service_name(self):
+ """Test that missing oracle_connection_type defaults to service_name format."""
+ mock_oracledb = MagicMock()
+ mock_oracledb.AUTH_MODE_SYSDBA = 2
+ mock_oracledb.AUTH_MODE_SYSOPER = 4
+
+ with patch.dict("sys.modules", {"oracledb": mock_oracledb}):
+ from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter
+
+ adapter = OracleAdapter()
+ # Create config without oracle_connection_type
+ config = ConnectionConfig(
+ name="test",
+ db_type="oracle",
+ server="localhost",
+ port="1521",
+ database="ORCL",
+ username="testuser",
+ password="testpass",
+ )
+
+ adapter.connect(config)
+
+ mock_oracledb.connect.assert_called_once()
+ call_kwargs = mock_oracledb.connect.call_args.kwargs
+ # Should default to service name format with slash
+ assert call_kwargs["dsn"] == "localhost:1521/ORCL"
+
+ def test_connect_sid_with_custom_port(self):
+ """Test SID format works with non-default port."""
+ mock_oracledb = MagicMock()
+ mock_oracledb.AUTH_MODE_SYSDBA = 2
+ mock_oracledb.AUTH_MODE_SYSOPER = 4
+
+ with patch.dict("sys.modules", {"oracledb": mock_oracledb}):
+ from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter
+
+ adapter = OracleAdapter()
+ config = ConnectionConfig(
+ name="test",
+ db_type="oracle",
+ server="db.example.com",
+ port="1522",
+ username="testuser",
+ password="testpass",
+ options={"oracle_connection_type": "sid", "oracle_sid": "PROD"},
+ )
+
+ adapter.connect(config)
+
+ mock_oracledb.connect.assert_called_once()
+ call_kwargs = mock_oracledb.connect.call_args.kwargs
+ assert call_kwargs["dsn"] == "db.example.com:1522:PROD"