From cd6c876f1e8d8ed9e8606b8e49b5034440136df1 Mon Sep 17 00:00:00 2001 From: Mihai Mitrea Date: Wed, 1 Apr 2026 09:09:34 +0000 Subject: [PATCH] Add best-effort --force-refresh support for databricks-cli auth When the SDK's cached CLI token is stale and a profile is configured, try `databricks auth token --force-refresh` to get a freshly minted token from the IdP. If the installed CLI is too old to recognise the flag, fall back to regular `auth token` and log a warning asking the user to upgrade the CLI. Unknown-flag detection is done in DatabricksCliTokenSource.refresh() by matching the CLI's `unknown flag:` error text, so that both the legacy --profile fallback and the new --force-refresh downgrade path are handled. See: https://github.com/databricks/cli/pull/4767 --- NEXT_CHANGELOG.md | 1 + databricks/sdk/credentials_provider.py | 27 +++-- tests/test_credentials_provider.py | 141 +++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 9 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 7b7310b02..3f27a8790 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### New Features and Improvements * Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic. * Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias. +* Pass `--force-refresh` to the Databricks CLI `auth token` command when a profile is configured, so the SDK receives a freshly minted token instead of a potentially stale cached one. Older CLIs that do not recognise the flag fall back to plain `auth token`.
### Security diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index e49f0a7fc..c48b79617 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -650,6 +650,7 @@ def refreshed_headers() -> Dict[str, str]: class CliTokenSource(oauth.Refreshable): + def __init__( self, cmd: List[str], @@ -899,11 +900,10 @@ def __init__(self, cfg: "Config"): cli_path = self.__class__._find_executable(cli_path) fallback_cmd = None + self._force_cmd = None if cfg.profile: - # When profile is set, use --profile as the primary command. - # The profile contains the full config (host, account_id, etc.). args = ["auth", "token", "--profile", cfg.profile] - # Build a --host fallback for older CLIs that don't support --profile. + self._force_cmd = [cli_path, *args, "--force-refresh"] if cfg.host: fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)] else: @@ -926,12 +926,21 @@ def __init__(self, cfg: "Config"): ) def refresh(self) -> oauth.Token: - # The scope validation lives in refresh() because this is the only method that - # produces new tokens (see Refreshable._token assignments). By overriding here, - # every token is validated — both at initial auth and on every subsequent refresh - # when the cached token expires. This catches cases where a user re-authenticates - # mid-session with different scopes. - token = super().refresh() + if self._force_cmd is None: + token = super().refresh() + else: + try: + token = self._exec_cli_command(self._force_cmd) + except IOError as e: + err_msg = str(e) + if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg: + logger.warning( + "Databricks CLI does not support --force-refresh. " + "Please upgrade your CLI to the latest version." 
+ ) + token = super().refresh() + else: + raise if self._requested_scopes: self._validate_token_scopes(token) return token diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index a6eae3018..a185e1a5b 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -471,6 +471,147 @@ def test_no_fallback_when_fallback_cmd_not_set(self, mocker): assert mock_run.call_count == 1 +class TestDatabricksCliForceRefresh: + """Tests for --force-refresh support in DatabricksCliTokenSource.""" + + @staticmethod + def _make_process_error(stderr: str, stdout: str = ""): + import subprocess + + err = subprocess.CalledProcessError(1, ["databricks"]) + err.stdout = stdout.encode() + err.stderr = stderr.encode() + return err + + @staticmethod + def _make_token_source( + *, + profile=None, + host="https://workspace.databricks.com", + cli_path="/path/to/databricks", + ): + """Build a DatabricksCliTokenSource by mocking only the executable check.""" + mock_cfg = Mock() + mock_cfg.profile = profile + mock_cfg.host = host + mock_cfg.databricks_cli_path = cli_path + mock_cfg.disable_async_token_refresh = True + mock_cfg.scopes = None + mock_cfg.get_scopes = Mock(return_value=["all-apis"]) + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None + return credentials_provider.DatabricksCliTokenSource(mock_cfg) + + def _valid_response_json(self, access_token="fresh-token"): + import json + + expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") + return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry}) + + def test_force_refresh_tried_first_with_profile(self, mocker): + """When profile is configured, refresh() tries --force-refresh first.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.return_value = 
Mock(stdout=self._valid_response_json("refreshed").encode()) + + token = ts.refresh() + assert token.access_token == "refreshed" + assert mock_run.call_count == 1 + + cmd = mock_run.call_args[0][0] + assert "--force-refresh" in cmd + assert "--profile" in cmd + + def test_host_only_skips_force_refresh(self, mocker): + """When only host is configured, --force-refresh is not used.""" + ts = self._make_token_source() + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.return_value = Mock(stdout=self._valid_response_json("token").encode()) + + token = ts.refresh() + assert token.access_token == "token" + assert mock_run.call_count == 1 + + cmd = mock_run.call_args[0][0] + assert "--force-refresh" not in cmd + assert "--host" in cmd + + def test_force_refresh_fallback_when_unsupported(self, mocker): + """Old CLI without --force-refresh: falls back to cmd without --force-refresh.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + self._make_process_error("Error: unknown flag: --force-refresh"), + Mock(stdout=self._valid_response_json("fallback").encode()), + ] + + token = ts.refresh() + assert token.access_token == "fallback" + assert mock_run.call_count == 2 + + first_cmd = mock_run.call_args_list[0][0][0] + second_cmd = mock_run.call_args_list[1][0][0] + assert "--force-refresh" in first_cmd + assert "--force-refresh" not in second_cmd + + def test_profile_fallback_when_unsupported(self, mocker): + """Old CLI without --profile: force_cmd fails, fallback retries with --host.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + # force_cmd: --profile + --force-refresh → unknown --profile + self._make_process_error("Error: unknown flag: --profile"), + # _refresh_without_force cmd: --profile → unknown 
--profile + self._make_process_error("Error: unknown flag: --profile"), + # _refresh_without_force fallback_cmd: --host → success + Mock(stdout=self._valid_response_json("host-token").encode()), + ] + + token = ts.refresh() + assert token.access_token == "host-token" + assert mock_run.call_count == 3 + assert "--host" in mock_run.call_args_list[2][0][0] + + def test_two_step_downgrade_both_flags_unsupported(self, mocker): + """CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + # 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh + self._make_process_error("Error: unknown flag: --force-refresh"), + # 2nd: _refresh_without_force cmd (--profile) → unknown --profile + self._make_process_error("Error: unknown flag: --profile"), + # 3rd: _refresh_without_force fallback_cmd (--host) → success + Mock(stdout=self._valid_response_json("plain").encode()), + ] + + token = ts.refresh() + assert token.access_token == "plain" + assert mock_run.call_count == 3 + + cmds = [call[0][0] for call in mock_run.call_args_list] + assert "--force-refresh" in cmds[0] and "--profile" in cmds[0] + assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1] + assert "--host" in cmds[2] + + def test_real_auth_error_does_not_trigger_fallback(self, mocker): + """Real auth failures (not unknown-flag) surface immediately.""" + ts = self._make_token_source() + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") + + with pytest.raises(IOError) as exc_info: + ts.refresh() + assert "databricks OAuth is not configured" in str(exc_info.value) + assert mock_run.call_count == 1 + + # Tests for cloud-agnostic hosts and removed cloud checks 
class TestCloudAgnosticHosts: """Tests that credential providers work with cloud-agnostic hosts after removing is_azure/is_gcp checks."""