Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions sqlit/domains/connections/providers/oracle/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,18 @@ def connect(self, config: ConnectionConfig) -> Any:
if endpoint is None:
raise ValueError("Oracle connections require a TCP-style endpoint.")
port = int(endpoint.port or get_default_port("oracle"))
# Use Easy Connect string format: host:port/service_name
dsn = f"{endpoint.host}:{port}/{endpoint.database}"

# Determine connection type: service_name (default) or sid
connection_type = config.get_option("oracle_connection_type", "service_name")

if connection_type == "sid":
# SID format: host:port:sid (uses colon separator)
# SID is stored in oracle_sid field, fall back to database for backward compat
sid = config.get_option("oracle_sid") or endpoint.database
dsn = f"{endpoint.host}:{port}:{sid}"
else:
# Service Name format: host:port/service_name (uses slash separator)
dsn = f"{endpoint.host}:{port}/{endpoint.database}"

# Determine connection mode based on oracle_role
oracle_role = config.get_option("oracle_role", "normal")
Expand Down
30 changes: 30 additions & 0 deletions sqlit/domains/connections/providers/oracle/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
)


def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
return (
SelectOption("service_name", "Service Name"),
SelectOption("sid", "SID"),
)


def _oracle_connection_type_is_service_name(values: dict) -> bool:
return values.get("oracle_connection_type", "service_name") != "sid"


def _oracle_connection_type_is_sid(values: dict) -> bool:
return values.get("oracle_connection_type") == "sid"


SCHEMA = ConnectionSchema(
db_type="oracle",
display_name="Oracle",
Expand All @@ -32,11 +47,26 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
group="server_port",
),
_port_field("1521"),
SchemaField(
name="oracle_connection_type",
label="Connection Type",
field_type=FieldType.DROPDOWN,
options=_get_oracle_connection_type_options(),
default="service_name",
),
SchemaField(
name="database",
label="Service Name",
placeholder="ORCL or XEPDB1",
required=True,
visible_when=_oracle_connection_type_is_service_name,
),
SchemaField(
name="oracle_sid",
label="SID",
placeholder="ORCL",
required=True,
visible_when=_oracle_connection_type_is_sid,
),
_username_field(),
_password_field(),
Expand Down
30 changes: 30 additions & 0 deletions sqlit/domains/connections/providers/oracle_legacy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
)


def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
return (
SelectOption("service_name", "Service Name"),
SelectOption("sid", "SID"),
)


def _oracle_connection_type_is_service_name(values: dict) -> bool:
return values.get("oracle_connection_type", "service_name") != "sid"


def _oracle_connection_type_is_sid(values: dict) -> bool:
return values.get("oracle_connection_type") == "sid"


def _get_oracle_client_mode_options() -> tuple[SelectOption, ...]:
return (
SelectOption("thick", "Thick (Instant Client)"),
Expand All @@ -43,11 +58,26 @@ def _oracle_thick_mode_enabled(values: dict) -> bool:
group="server_port",
),
_port_field("1521"),
SchemaField(
name="oracle_connection_type",
label="Connection Type",
field_type=FieldType.DROPDOWN,
options=_get_oracle_connection_type_options(),
default="service_name",
),
SchemaField(
name="database",
label="Service Name",
placeholder="ORCL or XEPDB1",
required=True,
visible_when=_oracle_connection_type_is_service_name,
),
SchemaField(
name="oracle_sid",
label="SID",
placeholder="ORCL",
required=True,
visible_when=_oracle_connection_type_is_sid,
),
_username_field(),
_password_field(),
Expand Down
2 changes: 2 additions & 0 deletions sqlit/domains/explorer/ui/mixins/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def action_select_table(self: TreeMixinHost) -> None:
"name": data.name,
"columns": [],
}
# Stash per-result metadata so results can resolve PKs without relying on globals.
self._pending_result_table_info = self._last_query_table
self._prime_last_query_table_columns(data.database, data.schema, data.name)

self.query_input.text = self.current_provider.dialect.build_select_query(
Expand Down
11 changes: 11 additions & 0 deletions sqlit/domains/query/ui/mixins/query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
executable_statements = [s for s in statements if not is_comment_only_statement(s)]
is_multi_statement = len(executable_statements) > 1

if is_multi_statement:
self._pending_result_table_info = None
elif executable_statements:
if getattr(self, "_pending_result_table_info", None) is None:
table_info = self._infer_result_table_info(executable_statements[0])
if table_info is not None:
self._pending_result_table_info = table_info
prime = getattr(self, "_prime_result_table_columns", None)
if callable(prime):
prime(table_info)

try:
start_time = time.perf_counter()
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS
Expand Down
59 changes: 58 additions & 1 deletion sqlit/domains/query/ui/mixins/query_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def _render_results_table_incremental(
escape: bool,
row_limit: int,
render_token: int,
table_info: dict[str, Any] | None = None,
) -> None:
initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
initial_rows = rows[:initial_count] if initial_count > 0 else []
Expand Down Expand Up @@ -269,9 +270,16 @@ def _render_results_table_incremental(
pass
if render_token == getattr(self, "_results_render_token", 0):
self._replace_results_table_with_data(columns, rows, escape=escape)
if table_info is not None:
try:
self.results_table.result_table_info = table_info
except Exception:
pass
return
if render_token != getattr(self, "_results_render_token", 0):
return
if table_info is not None:
table.result_table_info = table_info
self._replace_results_table_with_table(table)
self._schedule_results_render(
table,
Expand All @@ -291,6 +299,7 @@ async def _display_query_results(
self._last_result_columns = columns
self._last_result_rows = rows
self._last_result_row_count = row_count
table_info = getattr(self, "_pending_result_table_info", None)

# Switch to single result mode (in case we were showing stacked results)
self._show_single_result_mode()
Expand All @@ -304,12 +313,15 @@ async def _display_query_results(
escape=True,
row_limit=row_limit,
render_token=render_token,
table_info=table_info,
)
else:
render_rows = rows[:row_limit] if row_limit else []
table = self._build_results_table(columns, render_rows, escape=True)
if render_token != getattr(self, "_results_render_token", 0):
return
if table_info is not None:
table.result_table_info = table_info
self._replace_results_table_with_table(table)

time_str = format_duration_ms(elapsed_ms)
Expand All @@ -320,9 +332,15 @@ async def _display_query_results(
)
else:
self.notify(f"Query returned {row_count} rows in {time_str}")
if table_info is not None:
prime = getattr(self, "_prime_result_table_columns", None)
if callable(prime):
prime(table_info)
self._pending_result_table_info = None

def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
"""Display non-query result (called on main thread)."""
self._pending_result_table_info = None
self._last_result_columns = ["Result"]
self._last_result_rows = [(f"{affected} row(s) affected",)]
self._last_result_row_count = 1
Expand All @@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
"""Display query error (called on main thread)."""
self._cancel_results_render()
self._pending_result_table_info = None
# notify(severity="error") handles displaying the error in results via _show_error_in_results
self.notify(f"Query error: {error_message}", severity="error")

Expand All @@ -360,7 +379,17 @@ def _display_multi_statement_results(

# Add each result section
for i, stmt_result in enumerate(multi_result.results):
container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
table_info = self._infer_result_table_info(stmt_result.statement)
if table_info is not None:
prime = getattr(self, "_prime_result_table_columns", None)
if callable(prime):
prime(table_info)
container.add_result_section(
stmt_result,
i,
auto_collapse=auto_collapse,
table_info=table_info,
)

# Show the stacked results container, hide single result table
self._show_stacked_results_mode()
Expand All @@ -378,6 +407,7 @@ def _display_multi_statement_results(
)
else:
self.notify(f"Executed {total} statements in {time_str}")
self._pending_result_table_info = None

def _get_stacked_results_container(self: QueryMixinHost) -> Any:
"""Get the stacked results container."""
Expand Down Expand Up @@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
stacked.remove_class("active")
except Exception:
pass

def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
"""Best-effort inference of a single source table for query results."""
from sqlit.domains.query.completion import extract_table_refs

refs = extract_table_refs(sql)
if len(refs) != 1:
return None
ref = refs[0]
schema = ref.schema
name = ref.name
database = None
table_metadata = getattr(self, "_table_metadata", {}) or {}
key_candidates = [name.lower()]
if schema:
key_candidates.insert(0, f"{schema}.{name}".lower())
for key in key_candidates:
metadata = table_metadata.get(key)
if metadata:
schema, name, database = metadata
break
return {
"database": database,
"schema": schema,
"name": name,
"columns": [],
}
Loading
Loading