Skip to content

Commit cd6c876

Browse files
Add best-effort --force-refresh support for databricks-cli auth
When the SDK's cached CLI token is stale, try `databricks auth token --force-refresh` to get a freshly minted token from the IdP. If the installed CLI is too old to recognise the flag, fall back to regular `auth token` and remember the capability for future refreshes. Centralise unknown-flag detection in CliTokenSource._exec_cli_command() via UnsupportedCliFlagError so the same classifier is reused by both the legacy --profile fallback and the new --force-refresh downgrade path in DatabricksCliTokenSource. See: databricks/cli#4767
1 parent 78914fa commit cd6c876

3 files changed

Lines changed: 160 additions & 9 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### New Features and Improvements
66
* Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic.
77
* Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias.
8+
* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one.
89

910
### Security
1011

databricks/sdk/credentials_provider.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ def refreshed_headers() -> Dict[str, str]:
650650

651651

652652
class CliTokenSource(oauth.Refreshable):
653+
653654
def __init__(
654655
self,
655656
cmd: List[str],
@@ -899,11 +900,10 @@ def __init__(self, cfg: "Config"):
899900
cli_path = self.__class__._find_executable(cli_path)
900901

901902
fallback_cmd = None
903+
self._force_cmd = None
902904
if cfg.profile:
903-
# When profile is set, use --profile as the primary command.
904-
# The profile contains the full config (host, account_id, etc.).
905905
args = ["auth", "token", "--profile", cfg.profile]
906-
# Build a --host fallback for older CLIs that don't support --profile.
906+
self._force_cmd = [cli_path, *args, "--force-refresh"]
907907
if cfg.host:
908908
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
909909
else:
@@ -926,12 +926,21 @@ def __init__(self, cfg: "Config"):
926926
)
927927

928928
def refresh(self) -> oauth.Token:
929-
# The scope validation lives in refresh() because this is the only method that
930-
# produces new tokens (see Refreshable._token assignments). By overriding here,
931-
# every token is validated — both at initial auth and on every subsequent refresh
932-
# when the cached token expires. This catches cases where a user re-authenticates
933-
# mid-session with different scopes.
934-
token = super().refresh()
929+
if self._force_cmd is None:
930+
token = super().refresh()
931+
else:
932+
try:
933+
token = self._exec_cli_command(self._force_cmd)
934+
except IOError as e:
935+
err_msg = str(e)
936+
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
937+
logger.warning(
938+
"Databricks CLI does not support --force-refresh. "
939+
"Please upgrade your CLI to the latest version."
940+
)
941+
token = super().refresh()
942+
else:
943+
raise
935944
if self._requested_scopes:
936945
self._validate_token_scopes(token)
937946
return token

tests/test_credentials_provider.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,147 @@ def test_no_fallback_when_fallback_cmd_not_set(self, mocker):
471471
assert mock_run.call_count == 1
472472

473473

474+
class TestDatabricksCliForceRefresh:
475+
"""Tests for --force-refresh support in DatabricksCliTokenSource."""
476+
477+
@staticmethod
478+
def _make_process_error(stderr: str, stdout: str = ""):
479+
import subprocess
480+
481+
err = subprocess.CalledProcessError(1, ["databricks"])
482+
err.stdout = stdout.encode()
483+
err.stderr = stderr.encode()
484+
return err
485+
486+
@staticmethod
487+
def _make_token_source(
488+
*,
489+
profile=None,
490+
host="https://workspace.databricks.com",
491+
cli_path="/path/to/databricks",
492+
):
493+
"""Build a DatabricksCliTokenSource by mocking only the executable check."""
494+
mock_cfg = Mock()
495+
mock_cfg.profile = profile
496+
mock_cfg.host = host
497+
mock_cfg.databricks_cli_path = cli_path
498+
mock_cfg.disable_async_token_refresh = True
499+
mock_cfg.scopes = None
500+
mock_cfg.get_scopes = Mock(return_value=["all-apis"])
501+
mock_cfg.client_type = ClientType.WORKSPACE
502+
mock_cfg.account_id = None
503+
return credentials_provider.DatabricksCliTokenSource(mock_cfg)
504+
505+
def _valid_response_json(self, access_token="fresh-token"):
506+
import json
507+
508+
expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
509+
return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry})
510+
511+
def test_force_refresh_tried_first_with_profile(self, mocker):
512+
"""When profile is configured, refresh() tries --force-refresh first."""
513+
ts = self._make_token_source(profile="my-profile")
514+
515+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
516+
mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode())
517+
518+
token = ts.refresh()
519+
assert token.access_token == "refreshed"
520+
assert mock_run.call_count == 1
521+
522+
cmd = mock_run.call_args[0][0]
523+
assert "--force-refresh" in cmd
524+
assert "--profile" in cmd
525+
526+
def test_host_only_skips_force_refresh(self, mocker):
527+
"""When only host is configured, --force-refresh is not used."""
528+
ts = self._make_token_source()
529+
530+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
531+
mock_run.return_value = Mock(stdout=self._valid_response_json("token").encode())
532+
533+
token = ts.refresh()
534+
assert token.access_token == "token"
535+
assert mock_run.call_count == 1
536+
537+
cmd = mock_run.call_args[0][0]
538+
assert "--force-refresh" not in cmd
539+
assert "--host" in cmd
540+
541+
def test_force_refresh_fallback_when_unsupported(self, mocker):
542+
"""Old CLI without --force-refresh: falls back to cmd without --force-refresh."""
543+
ts = self._make_token_source(profile="my-profile")
544+
545+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
546+
mock_run.side_effect = [
547+
self._make_process_error("Error: unknown flag: --force-refresh"),
548+
Mock(stdout=self._valid_response_json("fallback").encode()),
549+
]
550+
551+
token = ts.refresh()
552+
assert token.access_token == "fallback"
553+
assert mock_run.call_count == 2
554+
555+
first_cmd = mock_run.call_args_list[0][0][0]
556+
second_cmd = mock_run.call_args_list[1][0][0]
557+
assert "--force-refresh" in first_cmd
558+
assert "--force-refresh" not in second_cmd
559+
560+
def test_profile_fallback_when_unsupported(self, mocker):
561+
"""Old CLI without --profile: force_cmd fails, fallback retries with --host."""
562+
ts = self._make_token_source(profile="my-profile")
563+
564+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
565+
mock_run.side_effect = [
566+
# force_cmd: --profile + --force-refresh → unknown --profile
567+
self._make_process_error("Error: unknown flag: --profile"),
568+
# _refresh_without_force cmd: --profile → unknown --profile
569+
self._make_process_error("Error: unknown flag: --profile"),
570+
# _refresh_without_force fallback_cmd: --host → success
571+
Mock(stdout=self._valid_response_json("host-token").encode()),
572+
]
573+
574+
token = ts.refresh()
575+
assert token.access_token == "host-token"
576+
assert mock_run.call_count == 3
577+
assert "--host" in mock_run.call_args_list[2][0][0]
578+
579+
def test_two_step_downgrade_both_flags_unsupported(self, mocker):
580+
"""CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback."""
581+
ts = self._make_token_source(profile="my-profile")
582+
583+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
584+
mock_run.side_effect = [
585+
# 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh
586+
self._make_process_error("Error: unknown flag: --force-refresh"),
587+
# 2nd: _refresh_without_force cmd (--profile) → unknown --profile
588+
self._make_process_error("Error: unknown flag: --profile"),
589+
# 3rd: _refresh_without_force fallback_cmd (--host) → success
590+
Mock(stdout=self._valid_response_json("plain").encode()),
591+
]
592+
593+
token = ts.refresh()
594+
assert token.access_token == "plain"
595+
assert mock_run.call_count == 3
596+
597+
cmds = [call[0][0] for call in mock_run.call_args_list]
598+
assert "--force-refresh" in cmds[0] and "--profile" in cmds[0]
599+
assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1]
600+
assert "--host" in cmds[2]
601+
602+
def test_real_auth_error_does_not_trigger_fallback(self, mocker):
603+
"""Real auth failures (not unknown-flag) surface immediately."""
604+
ts = self._make_token_source()
605+
606+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
607+
mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")
608+
609+
with pytest.raises(IOError) as exc_info:
610+
ts.refresh()
611+
assert "databricks OAuth is not configured" in str(exc_info.value)
612+
assert mock_run.call_count == 1
613+
614+
474615
# Tests for cloud-agnostic hosts and removed cloud checks
475616
class TestCloudAgnosticHosts:
476617
"""Tests that credential providers work with cloud-agnostic hosts after removing is_azure/is_gcp checks."""

0 commit comments

Comments
 (0)