Skip to content

Commit 469cb44

Browse files
Add best-effort --force-refresh support for databricks-cli auth
When the SDK's cached CLI token is stale, try `databricks auth token --force-refresh` to get a freshly minted token from the IdP. If the installed CLI is too old to recognise the flag, fall back to regular `auth token` and remember the capability for future refreshes. Centralise unknown-flag detection in CliTokenSource._exec_cli_command() via UnsupportedCliFlagError so the same classifier is reused by both the legacy --profile fallback and the new --force-refresh downgrade path in DatabricksCliTokenSource. See: databricks/cli#4767
1 parent 78914fa commit 469cb44

5 files changed

Lines changed: 145 additions & 13 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### New Features and Improvements
66
* Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic.
77
* Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias.
8+
* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one. Falls back gracefully on older CLIs that do not support the flag.
89

910
### Security
1011

databricks/sdk/credentials_provider.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ def refreshed_headers() -> Dict[str, str]:
650650

651651

652652
class CliTokenSource(oauth.Refreshable):
653+
653654
def __init__(
654655
self,
655656
cmd: List[str],
@@ -900,15 +901,14 @@ def __init__(self, cfg: "Config"):
900901

901902
fallback_cmd = None
902903
if cfg.profile:
903-
# When profile is set, use --profile as the primary command.
904-
# The profile contains the full config (host, account_id, etc.).
905904
args = ["auth", "token", "--profile", cfg.profile]
906-
# Build a --host fallback for older CLIs that don't support --profile.
907905
if cfg.host:
908906
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
909907
else:
910908
args = self.__class__._build_host_args(cfg)
911909

910+
self._force_cmd = [cli_path, *args, "--force-refresh"]
911+
912912
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
913913
# cause false-positive mismatches against every token that wasn't issued with
914914
# exactly ["all-apis"]. Only validate when scopes are explicitly set (either
@@ -926,12 +926,18 @@ def __init__(self, cfg: "Config"):
926926
)
927927

928928
def refresh(self) -> oauth.Token:
929-
# The scope validation lives in refresh() because this is the only method that
930-
# produces new tokens (see Refreshable._token assignments). By overriding here,
931-
# every token is validated — both at initial auth and on every subsequent refresh
932-
# when the cached token expires. This catches cases where a user re-authenticates
933-
# mid-session with different scopes.
934-
token = super().refresh()
929+
try:
930+
token = self._exec_cli_command(self._force_cmd)
931+
except IOError as e:
932+
err_msg = str(e)
933+
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
934+
logger.warning(
935+
"Databricks CLI does not support --force-refresh. "
936+
"Please upgrade your CLI to the latest version."
937+
)
938+
token = super().refresh()
939+
else:
940+
raise
935941
if self._requested_scopes:
936942
self._validate_token_scopes(token)
937943
return token

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
143143

144144
def test_config_deep_copy(monkeypatch, mocker, tmp_path):
145145
mocker.patch(
146-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
146+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
147147
return_value=oauth.Token(
148148
access_token="token",
149149
token_type="Bearer",

tests/test_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def test_databricks_cli_credential_provider_installed_legacy(config, monkeypatch
166166

167167
def test_databricks_cli_credential_provider_installed_new(config, monkeypatch, tmp_path, mocker):
168168
get_mock = mocker.patch(
169-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
169+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
170170
return_value=Token(
171171
access_token="token",
172172
token_type="Bearer",
@@ -222,7 +222,7 @@ def test_databricks_cli_scope_validation(
222222
config, monkeypatch, tmp_path, mocker, token_claims, configured_scopes, auth_type, expect
223223
):
224224
mocker.patch(
225-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
225+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
226226
return_value=Token(access_token=_make_jwt(token_claims), token_type="Bearer", expiry=datetime(2023, 5, 22)),
227227
)
228228
write_large_dummy_executable(tmp_path)
@@ -244,7 +244,7 @@ def test_databricks_cli_scope_validation(
244244

245245
def test_databricks_cli_scope_validation_error_message(config, monkeypatch, tmp_path, mocker):
246246
mocker.patch(
247-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
247+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
248248
return_value=Token(
249249
access_token=_make_jwt({"scope": "all-apis"}), token_type="Bearer", expiry=datetime(2023, 5, 22)
250250
),

tests/test_credentials_provider.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,131 @@ def test_no_fallback_when_fallback_cmd_not_set(self, mocker):
471471
assert mock_run.call_count == 1
472472

473473

474+
class TestDatabricksCliForceRefresh:
475+
"""Tests for --force-refresh support in DatabricksCliTokenSource."""
476+
477+
@staticmethod
478+
def _make_process_error(stderr: str, stdout: str = ""):
479+
import subprocess
480+
481+
err = subprocess.CalledProcessError(1, ["databricks"])
482+
err.stdout = stdout.encode()
483+
err.stderr = stderr.encode()
484+
return err
485+
486+
@staticmethod
487+
def _make_token_source(
488+
*,
489+
profile=None,
490+
host="https://workspace.databricks.com",
491+
cli_path="/path/to/databricks",
492+
):
493+
"""Build a DatabricksCliTokenSource by mocking only the executable check."""
494+
mock_cfg = Mock()
495+
mock_cfg.profile = profile
496+
mock_cfg.host = host
497+
mock_cfg.databricks_cli_path = cli_path
498+
mock_cfg.disable_async_token_refresh = True
499+
mock_cfg.scopes = None
500+
mock_cfg.get_scopes = Mock(return_value=["all-apis"])
501+
mock_cfg.client_type = ClientType.WORKSPACE
502+
mock_cfg.account_id = None
503+
return credentials_provider.DatabricksCliTokenSource(mock_cfg)
504+
505+
def _valid_response_json(self, access_token="fresh-token"):
506+
import json
507+
508+
expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
509+
return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry})
510+
511+
def test_force_refresh_always_tried_first(self, mocker):
512+
"""refresh() always tries --force-refresh first."""
513+
ts = self._make_token_source()
514+
515+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
516+
mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode())
517+
518+
token = ts.refresh()
519+
assert token.access_token == "refreshed"
520+
assert mock_run.call_count == 1
521+
522+
cmd = mock_run.call_args[0][0]
523+
assert "--force-refresh" in cmd
524+
525+
def test_force_refresh_fallback_when_unsupported(self, mocker):
526+
"""Old CLI without --force-refresh: falls back to cmd without --force-refresh."""
527+
ts = self._make_token_source()
528+
529+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
530+
mock_run.side_effect = [
531+
self._make_process_error("Error: unknown flag: --force-refresh"),
532+
Mock(stdout=self._valid_response_json("fallback").encode()),
533+
]
534+
535+
token = ts.refresh()
536+
assert token.access_token == "fallback"
537+
assert mock_run.call_count == 2
538+
539+
first_cmd = mock_run.call_args_list[0][0][0]
540+
second_cmd = mock_run.call_args_list[1][0][0]
541+
assert "--force-refresh" in first_cmd
542+
assert "--force-refresh" not in second_cmd
543+
544+
def test_profile_fallback_when_unsupported(self, mocker):
545+
"""Old CLI without --profile: force_cmd fails, fallback retries with --host."""
546+
ts = self._make_token_source(profile="my-profile")
547+
548+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
549+
mock_run.side_effect = [
550+
# force_cmd: --profile + --force-refresh → unknown --profile
551+
self._make_process_error("Error: unknown flag: --profile"),
552+
# _refresh_without_force cmd: --profile → unknown --profile
553+
self._make_process_error("Error: unknown flag: --profile"),
554+
# _refresh_without_force fallback_cmd: --host → success
555+
Mock(stdout=self._valid_response_json("host-token").encode()),
556+
]
557+
558+
token = ts.refresh()
559+
assert token.access_token == "host-token"
560+
assert mock_run.call_count == 3
561+
assert "--host" in mock_run.call_args_list[2][0][0]
562+
563+
def test_two_step_downgrade_both_flags_unsupported(self, mocker):
564+
"""CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback."""
565+
ts = self._make_token_source(profile="my-profile")
566+
567+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
568+
mock_run.side_effect = [
569+
# 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh
570+
self._make_process_error("Error: unknown flag: --force-refresh"),
571+
# 2nd: _refresh_without_force cmd (--profile) → unknown --profile
572+
self._make_process_error("Error: unknown flag: --profile"),
573+
# 3rd: _refresh_without_force fallback_cmd (--host) → success
574+
Mock(stdout=self._valid_response_json("plain").encode()),
575+
]
576+
577+
token = ts.refresh()
578+
assert token.access_token == "plain"
579+
assert mock_run.call_count == 3
580+
581+
cmds = [call[0][0] for call in mock_run.call_args_list]
582+
assert "--force-refresh" in cmds[0] and "--profile" in cmds[0]
583+
assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1]
584+
assert "--host" in cmds[2]
585+
586+
def test_real_auth_error_does_not_trigger_fallback(self, mocker):
587+
"""Real auth failures (not unknown-flag) surface immediately."""
588+
ts = self._make_token_source()
589+
590+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
591+
mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")
592+
593+
with pytest.raises(IOError) as exc_info:
594+
ts.refresh()
595+
assert "databricks OAuth is not configured" in str(exc_info.value)
596+
assert mock_run.call_count == 1
597+
598+
474599
# Tests for cloud-agnostic hosts and removed cloud checks
475600
class TestCloudAgnosticHosts:
476601
"""Tests that credential providers work with cloud-agnostic hosts after removing is_azure/is_gcp checks."""

0 commit comments

Comments
 (0)