Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features and Improvements
* Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic.
* Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias.
* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one.

### Security

Expand Down
27 changes: 18 additions & 9 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def refreshed_headers() -> Dict[str, str]:


class CliTokenSource(oauth.Refreshable):

def __init__(
self,
cmd: List[str],
Expand Down Expand Up @@ -899,11 +900,10 @@ def __init__(self, cfg: "Config"):
cli_path = self.__class__._find_executable(cli_path)

fallback_cmd = None
self._force_cmd = None
if cfg.profile:
# When profile is set, use --profile as the primary command.
# The profile contains the full config (host, account_id, etc.).
args = ["auth", "token", "--profile", cfg.profile]
# Build a --host fallback for older CLIs that don't support --profile.
self._force_cmd = [cli_path, *args, "--force-refresh"]
if cfg.host:
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
else:
Expand All @@ -926,12 +926,21 @@ def __init__(self, cfg: "Config"):
)

def refresh(self) -> oauth.Token:
# The scope validation lives in refresh() because this is the only method that
# produces new tokens (see Refreshable._token assignments). By overriding here,
# every token is validated β€” both at initial auth and on every subsequent refresh
# when the cached token expires. This catches cases where a user re-authenticates
# mid-session with different scopes.
token = super().refresh()
if self._force_cmd is None:
token = super().refresh()
else:
try:
token = self._exec_cli_command(self._force_cmd)
except IOError as e:
err_msg = str(e)
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
logger.warning(
"Databricks CLI does not support --force-refresh. "
"Please upgrade your CLI to the latest version."
)
token = super().refresh()
else:
raise
if self._requested_scopes:
self._validate_token_scopes(token)
return token
Expand Down
141 changes: 141 additions & 0 deletions tests/test_credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,147 @@ def test_no_fallback_when_fallback_cmd_not_set(self, mocker):
assert mock_run.call_count == 1


class TestDatabricksCliForceRefresh:
"""Tests for --force-refresh support in DatabricksCliTokenSource."""

@staticmethod
def _make_process_error(stderr: str, stdout: str = ""):
import subprocess

err = subprocess.CalledProcessError(1, ["databricks"])
err.stdout = stdout.encode()
err.stderr = stderr.encode()
return err

@staticmethod
def _make_token_source(
*,
profile=None,
host="https://workspace.databricks.com",
cli_path="/path/to/databricks",
):
"""Build a DatabricksCliTokenSource by mocking only the executable check."""
mock_cfg = Mock()
mock_cfg.profile = profile
mock_cfg.host = host
mock_cfg.databricks_cli_path = cli_path
mock_cfg.disable_async_token_refresh = True
mock_cfg.scopes = None
mock_cfg.get_scopes = Mock(return_value=["all-apis"])
mock_cfg.client_type = ClientType.WORKSPACE
mock_cfg.account_id = None
return credentials_provider.DatabricksCliTokenSource(mock_cfg)

def _valid_response_json(self, access_token="fresh-token"):
import json

expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry})

def test_force_refresh_tried_first_with_profile(self, mocker):
"""When profile is configured, refresh() tries --force-refresh first."""
ts = self._make_token_source(profile="my-profile")

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode())

token = ts.refresh()
assert token.access_token == "refreshed"
assert mock_run.call_count == 1

cmd = mock_run.call_args[0][0]
assert "--force-refresh" in cmd
assert "--profile" in cmd

def test_host_only_skips_force_refresh(self, mocker):
"""When only host is configured, --force-refresh is not used."""
ts = self._make_token_source()

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.return_value = Mock(stdout=self._valid_response_json("token").encode())

token = ts.refresh()
assert token.access_token == "token"
assert mock_run.call_count == 1

cmd = mock_run.call_args[0][0]
assert "--force-refresh" not in cmd
assert "--host" in cmd

def test_force_refresh_fallback_when_unsupported(self, mocker):
"""Old CLI without --force-refresh: falls back to cmd without --force-refresh."""
ts = self._make_token_source(profile="my-profile")

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = [
self._make_process_error("Error: unknown flag: --force-refresh"),
Mock(stdout=self._valid_response_json("fallback").encode()),
]

token = ts.refresh()
assert token.access_token == "fallback"
assert mock_run.call_count == 2

first_cmd = mock_run.call_args_list[0][0][0]
second_cmd = mock_run.call_args_list[1][0][0]
assert "--force-refresh" in first_cmd
assert "--force-refresh" not in second_cmd

def test_profile_fallback_when_unsupported(self, mocker):
"""Old CLI without --profile: force_cmd fails, fallback retries with --host."""
ts = self._make_token_source(profile="my-profile")

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = [
# force_cmd: --profile + --force-refresh β†’ unknown --profile
self._make_process_error("Error: unknown flag: --profile"),
# _refresh_without_force cmd: --profile β†’ unknown --profile
self._make_process_error("Error: unknown flag: --profile"),
# _refresh_without_force fallback_cmd: --host β†’ success
Mock(stdout=self._valid_response_json("host-token").encode()),
]

token = ts.refresh()
assert token.access_token == "host-token"
assert mock_run.call_count == 3
assert "--host" in mock_run.call_args_list[2][0][0]

def test_two_step_downgrade_both_flags_unsupported(self, mocker):
"""CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback."""
ts = self._make_token_source(profile="my-profile")

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = [
# 1st: force_cmd (--profile + --force-refresh) β†’ unknown --force-refresh
self._make_process_error("Error: unknown flag: --force-refresh"),
# 2nd: _refresh_without_force cmd (--profile) β†’ unknown --profile
self._make_process_error("Error: unknown flag: --profile"),
# 3rd: _refresh_without_force fallback_cmd (--host) β†’ success
Mock(stdout=self._valid_response_json("plain").encode()),
]

token = ts.refresh()
assert token.access_token == "plain"
assert mock_run.call_count == 3

cmds = [call[0][0] for call in mock_run.call_args_list]
assert "--force-refresh" in cmds[0] and "--profile" in cmds[0]
assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1]
assert "--host" in cmds[2]

def test_real_auth_error_does_not_trigger_fallback(self, mocker):
"""Real auth failures (not unknown-flag) surface immediately."""
ts = self._make_token_source()

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")

with pytest.raises(IOError) as exc_info:
ts.refresh()
assert "databricks OAuth is not configured" in str(exc_info.value)
assert mock_run.call_count == 1


# Tests for cloud-agnostic hosts and removed cloud checks
class TestCloudAgnosticHosts:
"""Tests that credential providers work with cloud-agnostic hosts after removing is_azure/is_gcp checks."""
Expand Down
Loading