From cd6c876f1e8d8ed9e8606b8e49b5034440136df1 Mon Sep 17 00:00:00 2001 From: Mihai Mitrea Date: Wed, 1 Apr 2026 09:09:34 +0000 Subject: [PATCH 1/2] Add best-effort --force-refresh support for databricks-cli auth When the SDK's cached CLI token is stale, try `databricks auth token --force-refresh` to get a freshly minted token from the IdP. If the installed CLI is too old to recognise the flag, fall back to regular `auth token` and remember the capability for future refreshes. Centralise unknown-flag detection in CliTokenSource._exec_cli_command() via UnsupportedCliFlagError so the same classifier is reused by both the legacy --profile fallback and the new --force-refresh downgrade path in DatabricksCliTokenSource. See: https://github.com/databricks/cli/pull/4767 --- NEXT_CHANGELOG.md | 1 + databricks/sdk/credentials_provider.py | 27 +++-- tests/test_credentials_provider.py | 141 +++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 9 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 7b7310b02..3f27a8790 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### New Features and Improvements * Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic. * Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias. +* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one. ### Security diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index e49f0a7fc..c48b79617 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -650,6 +650,7 @@ def refreshed_headers() -> Dict[str, str]: class CliTokenSource(oauth.Refreshable): + def __init__( self, cmd: List[str], @@ -899,11 +900,10 @@ def __init__(self, cfg: "Config"): cli_path = self.__class__._find_executable(cli_path) fallback_cmd = None + self._force_cmd = None if cfg.profile: - # When profile is set, use --profile as the primary command. - # The profile contains the full config (host, account_id, etc.). args = ["auth", "token", "--profile", cfg.profile] - # Build a --host fallback for older CLIs that don't support --profile. + self._force_cmd = [cli_path, *args, "--force-refresh"] if cfg.host: fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)] else: @@ -926,12 +926,21 @@ def __init__(self, cfg: "Config"): ) def refresh(self) -> oauth.Token: - # The scope validation lives in refresh() because this is the only method that - # produces new tokens (see Refreshable._token assignments). By overriding here, - # every token is validated — both at initial auth and on every subsequent refresh - # when the cached token expires. This catches cases where a user re-authenticates - # mid-session with different scopes. - token = super().refresh() + if self._force_cmd is None: + token = super().refresh() + else: + try: + token = self._exec_cli_command(self._force_cmd) + except IOError as e: + err_msg = str(e) + if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg: + logger.warning( + "Databricks CLI does not support --force-refresh. " + "Please upgrade your CLI to the latest version." + ) + token = super().refresh() + else: + raise if self._requested_scopes: self._validate_token_scopes(token) return token diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index a6eae3018..a185e1a5b 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -471,6 +471,147 @@ def test_no_fallback_when_fallback_cmd_not_set(self, mocker): assert mock_run.call_count == 1 +class TestDatabricksCliForceRefresh: + """Tests for --force-refresh support in DatabricksCliTokenSource.""" + + @staticmethod + def _make_process_error(stderr: str, stdout: str = ""): + import subprocess + + err = subprocess.CalledProcessError(1, ["databricks"]) + err.stdout = stdout.encode() + err.stderr = stderr.encode() + return err + + @staticmethod + def _make_token_source( + *, + profile=None, + host="https://workspace.databricks.com", + cli_path="/path/to/databricks", + ): + """Build a DatabricksCliTokenSource by mocking only the executable check.""" + mock_cfg = Mock() + mock_cfg.profile = profile + mock_cfg.host = host + mock_cfg.databricks_cli_path = cli_path + mock_cfg.disable_async_token_refresh = True + mock_cfg.scopes = None + mock_cfg.get_scopes = Mock(return_value=["all-apis"]) + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None + return credentials_provider.DatabricksCliTokenSource(mock_cfg) + + def _valid_response_json(self, access_token="fresh-token"): + import json + + expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") + return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry}) + + def test_force_refresh_tried_first_with_profile(self, mocker): + """When profile is configured, refresh() tries --force-refresh first.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode()) + + token = ts.refresh() + assert token.access_token == "refreshed" + assert mock_run.call_count == 1 + + cmd = mock_run.call_args[0][0] + assert "--force-refresh" in cmd + assert "--profile" in cmd + + def test_host_only_skips_force_refresh(self, mocker): + """When only host is configured, --force-refresh is not used.""" + ts = self._make_token_source() + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.return_value = Mock(stdout=self._valid_response_json("token").encode()) + + token = ts.refresh() + assert token.access_token == "token" + assert mock_run.call_count == 1 + + cmd = mock_run.call_args[0][0] + assert "--force-refresh" not in cmd + assert "--host" in cmd + + def test_force_refresh_fallback_when_unsupported(self, mocker): + """Old CLI without --force-refresh: falls back to cmd without --force-refresh.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + self._make_process_error("Error: unknown flag: --force-refresh"), + Mock(stdout=self._valid_response_json("fallback").encode()), + ] + + token = ts.refresh() + assert token.access_token == "fallback" + assert mock_run.call_count == 2 + + first_cmd = mock_run.call_args_list[0][0][0] + second_cmd = mock_run.call_args_list[1][0][0] + assert "--force-refresh" in first_cmd + assert "--force-refresh" not in second_cmd + + def test_profile_fallback_when_unsupported(self, mocker): + """Old CLI without --profile: force_cmd fails, fallback retries with --host.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + # force_cmd: --profile + --force-refresh → unknown --profile + self._make_process_error("Error: unknown flag: --profile"), + # _refresh_without_force cmd: --profile → unknown --profile + self._make_process_error("Error: unknown flag: --profile"), + # _refresh_without_force fallback_cmd: --host → success + Mock(stdout=self._valid_response_json("host-token").encode()), + ] + + token = ts.refresh() + assert token.access_token == "host-token" + assert mock_run.call_count == 3 + assert "--host" in mock_run.call_args_list[2][0][0] + + def test_two_step_downgrade_both_flags_unsupported(self, mocker): + """CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + # 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh + self._make_process_error("Error: unknown flag: --force-refresh"), + # 2nd: _refresh_without_force cmd (--profile) → unknown --profile + self._make_process_error("Error: unknown flag: --profile"), + # 3rd: _refresh_without_force fallback_cmd (--host) → success + Mock(stdout=self._valid_response_json("plain").encode()), + ] + + token = ts.refresh() + assert token.access_token == "plain" + assert mock_run.call_count == 3 + + cmds = [call[0][0] for call in mock_run.call_args_list] + assert "--force-refresh" in cmds[0] and "--profile" in cmds[0] + assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1] + assert "--host" in cmds[2] + + def test_real_auth_error_does_not_trigger_fallback(self, mocker): + """Real auth failures (not unknown-flag) surface immediately.""" + ts = self._make_token_source() + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") + + with pytest.raises(IOError) as exc_info: + ts.refresh() + assert "databricks OAuth is not configured" in str(exc_info.value) + assert mock_run.call_count == 1 + + # Tests for cloud-agnostic hosts and removed cloud checks class TestCloudAgnosticHosts: """Tests that credential providers work with cloud-agnostic hosts after removing is_azure/is_gcp checks.""" From 1432f4c5348be9f69876c60578593d15e6510333 Mon Sep 17 00:00:00 2001 From: Mihai Mitrea Date: Wed, 1 Apr 2026 09:24:11 +0000 Subject: [PATCH 2/2] Generalize CLI token source into progressive command list Replace the explicit force_cmd/fallback_cmd fields with a CliCommand dataclass and an optional commands list on CliTokenSource. When commands is provided, refresh() delegates to _refresh_progressive() which walks the list from activeCommandIndex, falling back on unsupported-flag errors. When commands is None, refresh() delegates to _refresh_single() which preserves the original fallback behavior with zero changes for AzureCliTokenSource. DatabricksCliTokenSource._build_commands() produces the progressive list: --profile + --force-refresh first, plain --profile second, and --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding future flags (e.g. --scopes) requires only adding entries to _build_commands(). --- NEXT_CHANGELOG.md | 1 + databricks/sdk/credentials_provider.py | 119 ++++++++++++---- tests/test_credentials_provider.py | 179 ++++++++++--------------- 3 files changed, 162 insertions(+), 137 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 3f27a8790..0091117cc 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -19,6 +19,7 @@ ### Internal Changes * Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors. * Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages. +* Generalize CLI token source into a progressive command list for forward-compatible flag support. ### API Changes * Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index c48b79617..c69b8b582 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -1,5 +1,6 @@ import abc import base64 +import dataclasses import functools import io import json @@ -649,6 +650,15 @@ def refreshed_headers() -> Dict[str, str]: return OAuthCredentialsProvider(refreshed_headers, token) +@dataclasses.dataclass +class CliCommand: + """A single CLI command variant with metadata for progressive fallback.""" + + args: List[str] + flags: List[str] + warning: str + + class CliTokenSource(oauth.Refreshable): def __init__( @@ -659,6 +669,7 @@ def __init__( expiry_field: str, disable_async: bool = True, fallback_cmd: Optional[List[str]] = None, + commands: Optional[List[CliCommand]] = None, ): super().__init__(disable_async=disable_async) self._cmd = cmd @@ -670,6 +681,8 @@ def __init__( self._token_type_field = token_type_field self._access_token_field = access_token_field self._expiry_field = expiry_field + self._commands = commands + self._active_command_index = -1 @staticmethod def _parse_expiry(expiry: str) -> datetime: @@ -700,7 +713,18 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token: message = "\n".join(filter(None, [stdout, stderr])) raise IOError(f"cannot get access token: {message}") from e + @staticmethod + def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool: + """Check if the error indicates the CLI rejected one of the given flags.""" + msg = str(error) + return any(f"unknown flag: {flag}" in msg for flag in flags) + def refresh(self) -> oauth.Token: + if self._commands is not None: + return self._refresh_progressive() + return self._refresh_single() + + def _refresh_single(self) -> oauth.Token: try: return self._exec_cli_command(self._cmd) except IOError as e: @@ -712,6 +736,30 @@ def refresh(self) -> oauth.Token: return self._exec_cli_command(self._fallback_cmd) raise + def _refresh_progressive(self) -> oauth.Token: + idx = self._active_command_index + if idx >= 0: + return self._exec_cli_command(self._commands[idx].args) + return self._probe_and_exec() + + def _probe_and_exec(self) -> oauth.Token: + """Walk the command list to find a CLI command that succeeds. + + When a command fails with "unknown flag" for one of its flags, log a + warning and try the next. On success, store _active_command_index so + future calls skip probing. + """ + for i, cmd in enumerate(self._commands): + try: + token = self._exec_cli_command(cmd.args) + self._active_command_index = i + return token + except IOError as e: + is_last = i == len(self._commands) - 1 + if is_last or not self._is_unknown_flag_error(e, cmd.flags): + raise + logger.warning(cmd.warning) + def _run_subprocess( popenargs, @@ -899,15 +947,7 @@ def __init__(self, cfg: "Config"): elif cli_path.count("/") == 0: cli_path = self.__class__._find_executable(cli_path) - fallback_cmd = None - self._force_cmd = None - if cfg.profile: - args = ["auth", "token", "--profile", cfg.profile] - self._force_cmd = [cli_path, *args, "--force-refresh"] - if cfg.host: - fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)] - else: - args = self.__class__._build_host_args(cfg) + commands = self.__class__._build_commands(cli_path, cfg) # get_scopes() defaults to ["all-apis"] when nothing is configured, which would # cause false-positive mismatches against every token that wasn't issued with @@ -917,30 +957,57 @@ def __init__(self, cfg: "Config"): self._host = cfg.host super().__init__( - cmd=[cli_path, *args], + cmd=commands[-1].args, token_type_field="token_type", access_token_field="access_token", expiry_field="expiry", disable_async=cfg.disable_async_token_refresh, - fallback_cmd=fallback_cmd, + commands=commands, ) - def refresh(self) -> oauth.Token: - if self._force_cmd is None: - token = super().refresh() - else: - try: - token = self._exec_cli_command(self._force_cmd) - except IOError as e: - err_msg = str(e) - if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg: - logger.warning( - "Databricks CLI does not support --force-refresh. " - "Please upgrade your CLI to the latest version." + @staticmethod + def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]: + commands: List[CliCommand] = [] + if cfg.profile: + profile_args = [cli_path, "auth", "token", "--profile", cfg.profile] + commands.append( + CliCommand( + args=profile_args + ["--force-refresh"], + flags=["--force-refresh", "--profile"], + warning="Databricks CLI does not support --force-refresh. " + "Please upgrade your CLI to the latest version.", + ) + ) + commands.append( + CliCommand( + args=profile_args, + flags=["--profile"], + warning="Databricks CLI does not support --profile flag. " + "Falling back to --host. " + "Please upgrade your CLI to the latest version.", + ) + ) + if cfg.host: + commands.append( + CliCommand( + args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)], + flags=[], + warning="", ) - token = super().refresh() - else: - raise + ) + else: + host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)] + commands.append( + CliCommand( + args=host_args, + flags=[], + warning="", + ) + ) + return commands + + def refresh(self) -> oauth.Token: + token = super().refresh() if self._requested_scopes: self._validate_token_scopes(token) return token diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index a185e1a5b..55116a95b 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -337,8 +337,8 @@ def test_account_client_passes_account_id(self, mocker): assert "test-account-id" in cmd assert "--workspace-id" not in cmd - def test_profile_uses_profile_flag_with_host_fallback(self, mocker): - """When profile is set, --profile is used as primary and --host as fallback.""" + def test_profile_with_host_builds_three_commands(self, mocker): + """With profile + host: force-refresh, profile, host fallback.""" mock_init = mocker.patch.object( credentials_provider.CliTokenSource, "__init__", @@ -348,24 +348,21 @@ def test_profile_uses_profile_flag_with_host_fallback(self, mocker): mock_cfg = Mock() mock_cfg.profile = "my-profile" mock_cfg.host = "https://workspace.databricks.com" - + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None mock_cfg.databricks_cli_path = "/path/to/databricks" mock_cfg.disable_async_token_refresh = False credentials_provider.DatabricksCliTokenSource(mock_cfg) - call_kwargs = mock_init.call_args - cmd = call_kwargs.kwargs["cmd"] - host_cmd = call_kwargs.kwargs["fallback_cmd"] - - assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"] - assert host_cmd is not None - assert "--host" in host_cmd - assert "https://workspace.databricks.com" in host_cmd - assert "--profile" not in host_cmd + commands = mock_init.call_args.kwargs["commands"] + assert len(commands) == 3 + assert "--force-refresh" in commands[0].args and "--profile" in commands[0].args + assert "--profile" in commands[1].args and "--force-refresh" not in commands[1].args + assert "--host" in commands[2].args and "--force-refresh" not in commands[2].args - def test_profile_without_host_no_fallback(self, mocker): - """When profile is set but host is absent, no fallback is built.""" + def test_profile_without_host_builds_two_commands(self, mocker): + """With profile only: force-refresh and plain profile.""" mock_init = mocker.patch.object( credentials_provider.CliTokenSource, "__init__", @@ -380,95 +377,33 @@ def test_profile_without_host_no_fallback(self, mocker): credentials_provider.DatabricksCliTokenSource(mock_cfg) - call_kwargs = mock_init.call_args - cmd = call_kwargs.kwargs["cmd"] - host_cmd = call_kwargs.kwargs["fallback_cmd"] - - assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"] - assert host_cmd is None - - -# Tests for CliTokenSource fallback on unknown --profile flag -class TestCliTokenSourceFallback: - """Tests that CliTokenSource falls back to --host when CLI doesn't support --profile.""" - - def _make_token_source(self, fallback_cmd=None): - ts = credentials_provider.CliTokenSource.__new__(credentials_provider.CliTokenSource) - ts._cmd = ["databricks", "auth", "token", "--profile", "my-profile"] - ts._fallback_cmd = fallback_cmd - ts._token_type_field = "token_type" - ts._access_token_field = "access_token" - ts._expiry_field = "expiry" - return ts - - def _make_process_error(self, stderr: str, stdout: str = ""): - import subprocess - - err = subprocess.CalledProcessError(1, ["databricks"]) - err.stdout = stdout.encode() - err.stderr = stderr.encode() - return err - - def test_fallback_on_unknown_profile_flag(self, mocker): - """When --profile fails with 'unknown flag: --profile', falls back to --host command.""" - import json - - expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") - valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry}) - - mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") - mock_run.side_effect = [ - self._make_process_error("Error: unknown flag: --profile"), - Mock(stdout=valid_response.encode()), - ] - - fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] - ts = self._make_token_source(fallback_cmd=fallback_cmd) - token = ts.refresh() - assert token.access_token == "fallback-token" - assert mock_run.call_count == 2 - assert mock_run.call_args_list[1][0][0] == fallback_cmd - - def test_fallback_triggered_when_unknown_flag_in_stderr_only(self, mocker): - """Fallback triggers even when CLI also writes usage text to stdout.""" - import json - - expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") - valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry}) - - mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") - mock_run.side_effect = [ - self._make_process_error(stderr="Error: unknown flag: --profile", stdout="Usage: databricks auth token"), - Mock(stdout=valid_response.encode()), - ] - - fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] - ts = self._make_token_source(fallback_cmd=fallback_cmd) - token = ts.refresh() - assert token.access_token == "fallback-token" + commands = mock_init.call_args.kwargs["commands"] + assert len(commands) == 2 + assert "--force-refresh" in commands[0].args + assert "--force-refresh" not in commands[1].args - def test_no_fallback_on_real_auth_error(self, mocker): - """When --profile fails with a real error (not unknown flag), no fallback is attempted.""" - mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") - mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") + def test_host_only_builds_one_command(self, mocker): + """With host only: single plain host command.""" + mock_init = mocker.patch.object( + credentials_provider.CliTokenSource, + "__init__", + return_value=None, + ) - fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] - ts = self._make_token_source(fallback_cmd=fallback_cmd) - with pytest.raises(IOError) as exc_info: - ts.refresh() - assert "databricks OAuth is not configured" in str(exc_info.value) - assert mock_run.call_count == 1 + mock_cfg = Mock() + mock_cfg.profile = None + mock_cfg.host = "https://workspace.databricks.com" + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None + mock_cfg.databricks_cli_path = "/path/to/databricks" + mock_cfg.disable_async_token_refresh = False - def test_no_fallback_when_fallback_cmd_not_set(self, mocker): - """When fallback_cmd is None and --profile fails, the original error is raised.""" - mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") - mock_run.side_effect = self._make_process_error("Error: unknown flag: --profile") + credentials_provider.DatabricksCliTokenSource(mock_cfg) - ts = self._make_token_source(fallback_cmd=None) - with pytest.raises(IOError) as exc_info: - ts.refresh() - assert "unknown flag: --profile" in str(exc_info.value) - assert mock_run.call_count == 1 + commands = mock_init.call_args.kwargs["commands"] + assert len(commands) == 1 + assert "--host" in commands[0].args + assert "--force-refresh" not in commands[0].args class TestDatabricksCliForceRefresh: @@ -539,7 +474,7 @@ def test_host_only_skips_force_refresh(self, mocker): assert "--host" in cmd def test_force_refresh_fallback_when_unsupported(self, mocker): - """Old CLI without --force-refresh: falls back to cmd without --force-refresh.""" + """Old CLI without --force-refresh: falls back to plain --profile command.""" ts = self._make_token_source(profile="my-profile") mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") @@ -556,18 +491,16 @@ def test_force_refresh_fallback_when_unsupported(self, mocker): second_cmd = mock_run.call_args_list[1][0][0] assert "--force-refresh" in first_cmd assert "--force-refresh" not in second_cmd + assert "--profile" in second_cmd - def test_profile_fallback_when_unsupported(self, mocker): - """Old CLI without --profile: force_cmd fails, fallback retries with --host.""" + def test_profile_fallback_to_host(self, mocker): + """Old CLI without --profile: progressive loop walks to --host terminal command.""" ts = self._make_token_source(profile="my-profile") mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") mock_run.side_effect = [ - # force_cmd: --profile + --force-refresh → unknown --profile self._make_process_error("Error: unknown flag: --profile"), - # _refresh_without_force cmd: --profile → unknown --profile self._make_process_error("Error: unknown flag: --profile"), - # _refresh_without_force fallback_cmd: --host → success Mock(stdout=self._valid_response_json("host-token").encode()), ] @@ -576,17 +509,14 @@ def test_profile_fallback_when_unsupported(self, mocker): assert mock_run.call_count == 3 assert "--host" in mock_run.call_args_list[2][0][0] - def test_two_step_downgrade_both_flags_unsupported(self, mocker): - """CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback.""" + def test_full_fallback_chain(self, mocker): + """CLI supports neither --force-refresh nor --profile: walks the full chain.""" ts = self._make_token_source(profile="my-profile") mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") mock_run.side_effect = [ - # 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh self._make_process_error("Error: unknown flag: --force-refresh"), - # 2nd: _refresh_without_force cmd (--profile) → unknown --profile self._make_process_error("Error: unknown flag: --profile"), - # 3rd: _refresh_without_force fallback_cmd (--host) → success Mock(stdout=self._valid_response_json("plain").encode()), ] @@ -599,9 +529,28 @@ def test_two_step_downgrade_both_flags_unsupported(self, mocker): assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1] assert "--host" in cmds[2] + def test_active_command_index_caching(self, mocker): + """After fallback, subsequent refreshes start at the cached command index.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + self._make_process_error("Error: unknown flag: --force-refresh"), + Mock(stdout=self._valid_response_json("first").encode()), + Mock(stdout=self._valid_response_json("second").encode()), + ] + + token1 = ts.refresh() + assert token1.access_token == "first" + assert mock_run.call_count == 2 + + token2 = ts.refresh() + assert token2.access_token == "second" + assert mock_run.call_count == 3 + def test_real_auth_error_does_not_trigger_fallback(self, mocker): """Real auth failures (not unknown-flag) surface immediately.""" - ts = self._make_token_source() + ts = self._make_token_source(profile="my-profile") mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") @@ -611,6 +560,14 @@ def test_real_auth_error_does_not_trigger_fallback(self, mocker): assert "databricks OAuth is not configured" in str(exc_info.value) assert mock_run.call_count == 1 + def test_is_unknown_flag_error(self): + """_is_unknown_flag_error matches against specific flag list.""" + check = credentials_provider.CliTokenSource._is_unknown_flag_error + assert check(IOError("Error: unknown flag: --force-refresh"), ["--force-refresh", "--profile"]) + assert check(IOError("Error: unknown flag: --profile"), ["--profile"]) + assert not check(IOError("Error: unknown flag: --force-refresh"), ["--profile"]) + assert not check(IOError("some other error"), ["--force-refresh"]) + # Tests for cloud-agnostic hosts and removed cloud checks class TestCloudAgnosticHosts: