Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features and Improvements
* Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic.
* Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias.
* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one.

### Security

Expand All @@ -18,6 +19,7 @@
### Internal Changes
* Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
* Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
* Generalize CLI token source into a progressive command list for forward-compatible flag support.

### API Changes
* Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.
Expand Down
110 changes: 93 additions & 17 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import base64
import dataclasses
import functools
import io
import json
Expand Down Expand Up @@ -649,7 +650,17 @@ def refreshed_headers() -> Dict[str, str]:
return OAuthCredentialsProvider(refreshed_headers, token)


@dataclasses.dataclass
class CliCommand:
    """A single CLI command variant with metadata for progressive fallback.

    ``flags`` and ``warning`` are optional so that a terminal variant (one
    with nothing left to fall back to) can be declared with ``args`` alone.
    """

    # Full argv to execute, starting with the CLI executable path.
    args: List[str]
    # Flags this variant relies on. An "unknown flag" error naming one of
    # them means the CLI rejected the variant and the next one may be tried.
    flags: List[str] = dataclasses.field(default_factory=list)
    # Message logged when this variant is rejected and the next one is tried.
    warning: str = ""


class CliTokenSource(oauth.Refreshable):

def __init__(
self,
cmd: List[str],
Expand All @@ -658,6 +669,7 @@ def __init__(
expiry_field: str,
disable_async: bool = True,
fallback_cmd: Optional[List[str]] = None,
commands: Optional[List[CliCommand]] = None,
):
super().__init__(disable_async=disable_async)
self._cmd = cmd
Expand All @@ -669,6 +681,8 @@ def __init__(
self._token_type_field = token_type_field
self._access_token_field = access_token_field
self._expiry_field = expiry_field
self._commands = commands
self._active_command_index = -1

@staticmethod
def _parse_expiry(expiry: str) -> datetime:
Expand Down Expand Up @@ -699,7 +713,18 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
message = "\n".join(filter(None, [stdout, stderr]))
raise IOError(f"cannot get access token: {message}") from e

@staticmethod
def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
"""Check if the error indicates the CLI rejected one of the given flags."""
msg = str(error)
return any(f"unknown flag: {flag}" in msg for flag in flags)

def refresh(self) -> oauth.Token:
    """Obtain a token from the CLI.

    Uses the single-command path when no progressive command list was
    supplied, and the progressive path otherwise.
    """
    if self._commands is None:
        return self._refresh_single()
    return self._refresh_progressive()

def _refresh_single(self) -> oauth.Token:
try:
return self._exec_cli_command(self._cmd)
except IOError as e:
Expand All @@ -711,6 +736,30 @@ def refresh(self) -> oauth.Token:
return self._exec_cli_command(self._fallback_cmd)
raise

def _refresh_progressive(self) -> oauth.Token:
    """Run the remembered command, or probe the command list if none is set.

    A negative ``_active_command_index`` means no command has been selected
    yet, so the command list is probed.
    """
    active = self._active_command_index
    if active < 0:
        return self._probe_and_exec()
    chosen = self._commands[active]
    return self._exec_cli_command(chosen.args)

def _probe_and_exec(self) -> oauth.Token:
    """Walk the command list to find a CLI command that succeeds.

    When a command fails with "unknown flag" for one of its flags, log a
    warning and try the next. On success, store _active_command_index so
    future calls skip probing.

    Raises:
        IOError: If the command list is empty, if the last command fails,
            or if a command fails for a reason other than one of its own
            flags being reported as unknown.
    """
    if not self._commands:
        # Without this guard the loop body never runs and the method would
        # fall through and return None instead of a token.
        raise IOError("cannot get access token: no CLI commands configured")

    last_index = len(self._commands) - 1
    for i, cmd in enumerate(self._commands):
        try:
            token = self._exec_cli_command(cmd.args)
        except IOError as e:
            # Only an unknown-flag rejection on a non-final command falls
            # through to the next variant; anything else is re-raised.
            if i == last_index or not self._is_unknown_flag_error(e, cmd.flags):
                raise
            logger.warning(cmd.warning)
        else:
            self._active_command_index = i
            return token


def _run_subprocess(
popenargs,
Expand Down Expand Up @@ -898,16 +947,7 @@ def __init__(self, cfg: "Config"):
elif cli_path.count("/") == 0:
cli_path = self.__class__._find_executable(cli_path)

fallback_cmd = None
if cfg.profile:
# When profile is set, use --profile as the primary command.
# The profile contains the full config (host, account_id, etc.).
args = ["auth", "token", "--profile", cfg.profile]
# Build a --host fallback for older CLIs that don't support --profile.
if cfg.host:
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
else:
args = self.__class__._build_host_args(cfg)
commands = self.__class__._build_commands(cli_path, cfg)

# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
# cause false-positive mismatches against every token that wasn't issued with
Expand All @@ -917,20 +957,56 @@ def __init__(self, cfg: "Config"):
self._host = cfg.host

super().__init__(
cmd=[cli_path, *args],
cmd=commands[-1].args,
token_type_field="token_type",
access_token_field="access_token",
expiry_field="expiry",
disable_async=cfg.disable_async_token_refresh,
fallback_cmd=fallback_cmd,
commands=commands,
)

@staticmethod
def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
    """Build the ordered list of CLI command variants to try.

    With a profile configured, the list starts with ``--profile`` plus
    ``--force-refresh``, followed by plain ``--profile``. A host-based
    command is appended last when a host is configured. Without a profile,
    the host-based command is the only entry.
    """
    variants: List[CliCommand] = []
    # The host-based command is always present without a profile; with a
    # profile it is only added when a host is configured.
    include_host_command = True

    if cfg.profile:
        base = [cli_path, "auth", "token", "--profile", cfg.profile]
        with_force_refresh = CliCommand(
            args=[*base, "--force-refresh"],
            flags=["--force-refresh", "--profile"],
            warning="Databricks CLI does not support --force-refresh. "
            "Please upgrade your CLI to the latest version.",
        )
        profile_only = CliCommand(
            args=base,
            flags=["--profile"],
            warning="Databricks CLI does not support --profile flag. "
            "Falling back to --host. "
            "Please upgrade your CLI to the latest version.",
        )
        variants.extend([with_force_refresh, profile_only])
        include_host_command = bool(cfg.host)

    if include_host_command:
        host_argv = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
        variants.append(CliCommand(args=host_argv, flags=[], warning=""))

    return variants

def refresh(self) -> oauth.Token:
# The scope validation lives in refresh() because this is the only method that
# produces new tokens (see Refreshable._token assignments). By overriding here,
# every token is validated — both at initial auth and on every subsequent refresh
# when the cached token expires. This catches cases where a user re-authenticates
# mid-session with different scopes.
token = super().refresh()
if self._requested_scopes:
self._validate_token_scopes(token)
Expand Down
Loading
Loading