diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 7b7310b02..0091117cc 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### New Features and Improvements * Add support for unified hosts. A single configuration profile can now be used for both account-level and workspace-level operations when the host supports it and both `account_id` and `workspace_id` are available. The `experimental_is_unified_host` flag has been removed; unified host detection is now automatic. * Accept `DATABRICKS_OIDC_TOKEN_FILEPATH` environment variable for consistency with other Databricks SDKs (Go, CLI, Terraform). The previous `DATABRICKS_OIDC_TOKEN_FILE` is still supported as an alias. +* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one. ### Security @@ -18,6 +19,7 @@ ### Internal Changes * Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors. * Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages. +* Generalize CLI token source into a progressive command list for forward-compatible flag support. ### API Changes * Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index e49f0a7fc..c69b8b582 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -1,5 +1,6 @@ import abc import base64 +import dataclasses import functools import io import json @@ -649,7 +650,17 @@ def refreshed_headers() -> Dict[str, str]: return OAuthCredentialsProvider(refreshed_headers, token) +@dataclasses.dataclass +class CliCommand: + """A single CLI command variant with metadata for progressive fallback.""" + + args: List[str] + flags: List[str] + warning: str + + class CliTokenSource(oauth.Refreshable): + def __init__( self, cmd: List[str], @@ -658,6 +669,7 @@ def __init__( expiry_field: str, disable_async: bool = True, fallback_cmd: Optional[List[str]] = None, + commands: Optional[List[CliCommand]] = None, ): super().__init__(disable_async=disable_async) self._cmd = cmd @@ -669,6 +681,8 @@ def __init__( self._token_type_field = token_type_field self._access_token_field = access_token_field self._expiry_field = expiry_field + self._commands = commands + self._active_command_index = -1 @staticmethod def _parse_expiry(expiry: str) -> datetime: @@ -699,7 +713,18 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token: message = "\n".join(filter(None, [stdout, stderr])) raise IOError(f"cannot get access token: {message}") from e + @staticmethod + def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool: + """Check if the error indicates the CLI rejected one of the given flags.""" + msg = str(error) + return any(f"unknown flag: {flag}" in msg for flag in flags) + def refresh(self) -> oauth.Token: + if self._commands is not None: + return self._refresh_progressive() + return self._refresh_single() + + def _refresh_single(self) -> oauth.Token: try: return self._exec_cli_command(self._cmd) except IOError as e: @@ -711,6 +736,30 @@ def refresh(self) -> oauth.Token: return self._exec_cli_command(self._fallback_cmd) raise + def _refresh_progressive(self) -> oauth.Token: + idx = self._active_command_index + if idx >= 0: + return self._exec_cli_command(self._commands[idx].args) + return self._probe_and_exec() + + def _probe_and_exec(self) -> oauth.Token: + """Walk the command list to find a CLI command that succeeds. + + When a command fails with "unknown flag" for one of its flags, log a + warning and try the next. On success, store _active_command_index so + future calls skip probing. + """ + for i, cmd in enumerate(self._commands): + try: + token = self._exec_cli_command(cmd.args) + self._active_command_index = i + return token + except IOError as e: + is_last = i == len(self._commands) - 1 + if is_last or not self._is_unknown_flag_error(e, cmd.flags): + raise + logger.warning(cmd.warning) + def _run_subprocess( popenargs, @@ -898,16 +947,7 @@ def __init__(self, cfg: "Config"): elif cli_path.count("/") == 0: cli_path = self.__class__._find_executable(cli_path) - fallback_cmd = None - if cfg.profile: - # When profile is set, use --profile as the primary command. - # The profile contains the full config (host, account_id, etc.). - args = ["auth", "token", "--profile", cfg.profile] - # Build a --host fallback for older CLIs that don't support --profile. - if cfg.host: - fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)] - else: - args = self.__class__._build_host_args(cfg) + commands = self.__class__._build_commands(cli_path, cfg) # get_scopes() defaults to ["all-apis"] when nothing is configured, which would # cause false-positive mismatches against every token that wasn't issued with @@ -917,20 +957,56 @@ def __init__(self, cfg: "Config"): self._host = cfg.host super().__init__( - cmd=[cli_path, *args], + cmd=commands[-1].args, token_type_field="token_type", access_token_field="access_token", expiry_field="expiry", disable_async=cfg.disable_async_token_refresh, - fallback_cmd=fallback_cmd, + commands=commands, ) + @staticmethod + def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]: + commands: List[CliCommand] = [] + if cfg.profile: + profile_args = [cli_path, "auth", "token", "--profile", cfg.profile] + commands.append( + CliCommand( + args=profile_args + ["--force-refresh"], + flags=["--force-refresh", "--profile"], + warning="Databricks CLI does not support --force-refresh. " + "Please upgrade your CLI to the latest version.", + ) + ) + commands.append( + CliCommand( + args=profile_args, + flags=["--profile"], + warning="Databricks CLI does not support --profile flag. " + "Falling back to --host. " + "Please upgrade your CLI to the latest version.", + ) + ) + if cfg.host: + commands.append( + CliCommand( + args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)], + flags=[], + warning="", + ) + ) + else: + host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)] + commands.append( + CliCommand( + args=host_args, + flags=[], + warning="", + ) + ) + return commands + def refresh(self) -> oauth.Token: - # The scope validation lives in refresh() because this is the only method that - # produces new tokens (see Refreshable._token assignments). By overriding here, - # every token is validated — both at initial auth and on every subsequent refresh - # when the cached token expires. This catches cases where a user re-authenticates - # mid-session with different scopes. token = super().refresh() if self._requested_scopes: self._validate_token_scopes(token) diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index a6eae3018..55116a95b 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -337,8 +337,8 @@ def test_account_client_passes_account_id(self, mocker): assert "test-account-id" in cmd assert "--workspace-id" not in cmd - def test_profile_uses_profile_flag_with_host_fallback(self, mocker): - """When profile is set, --profile is used as primary and --host as fallback.""" + def test_profile_with_host_builds_three_commands(self, mocker): + """With profile + host: force-refresh, profile, host fallback.""" mock_init = mocker.patch.object( credentials_provider.CliTokenSource, "__init__", @@ -348,24 +348,21 @@ def test_profile_uses_profile_flag_with_host_fallback(self, mocker): mock_cfg = Mock() mock_cfg.profile = "my-profile" mock_cfg.host = "https://workspace.databricks.com" - + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None mock_cfg.databricks_cli_path = "/path/to/databricks" mock_cfg.disable_async_token_refresh = False credentials_provider.DatabricksCliTokenSource(mock_cfg) - call_kwargs = mock_init.call_args - cmd = call_kwargs.kwargs["cmd"] - host_cmd = call_kwargs.kwargs["fallback_cmd"] - - assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"] - assert host_cmd is not None - assert "--host" in host_cmd - assert "https://workspace.databricks.com" in host_cmd - assert "--profile" not in host_cmd + commands = mock_init.call_args.kwargs["commands"] + assert len(commands) == 3 + assert "--force-refresh" in commands[0].args and "--profile" in commands[0].args + assert "--profile" in commands[1].args and "--force-refresh" not in commands[1].args + assert "--host" in commands[2].args and "--force-refresh" not in commands[2].args - def test_profile_without_host_no_fallback(self, mocker): - """When profile is set but host is absent, no fallback is built.""" + def test_profile_without_host_builds_two_commands(self, mocker): + """With profile only: force-refresh and plain profile.""" mock_init = mocker.patch.object( credentials_provider.CliTokenSource, "__init__", @@ -380,28 +377,40 @@ def test_profile_without_host_no_fallback(self, mocker): credentials_provider.DatabricksCliTokenSource(mock_cfg) - call_kwargs = mock_init.call_args - cmd = call_kwargs.kwargs["cmd"] - host_cmd = call_kwargs.kwargs["fallback_cmd"] + commands = mock_init.call_args.kwargs["commands"] + assert len(commands) == 2 + assert "--force-refresh" in commands[0].args + assert "--force-refresh" not in commands[1].args + + def test_host_only_builds_one_command(self, mocker): + """With host only: single plain host command.""" + mock_init = mocker.patch.object( + credentials_provider.CliTokenSource, + "__init__", + return_value=None, + ) + + mock_cfg = Mock() + mock_cfg.profile = None + mock_cfg.host = "https://workspace.databricks.com" + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None + mock_cfg.databricks_cli_path = "/path/to/databricks" + mock_cfg.disable_async_token_refresh = False - assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"] - assert host_cmd is None + credentials_provider.DatabricksCliTokenSource(mock_cfg) + commands = mock_init.call_args.kwargs["commands"] + assert len(commands) == 1 + assert "--host" in commands[0].args + assert "--force-refresh" not in commands[0].args -# Tests for CliTokenSource fallback on unknown --profile flag -class TestCliTokenSourceFallback: - """Tests that CliTokenSource falls back to --host when CLI doesn't support --profile.""" - def _make_token_source(self, fallback_cmd=None): - ts = credentials_provider.CliTokenSource.__new__(credentials_provider.CliTokenSource) - ts._cmd = ["databricks", "auth", "token", "--profile", "my-profile"] - ts._fallback_cmd = fallback_cmd - ts._token_type_field = "token_type" - ts._access_token_field = "access_token" - ts._expiry_field = "expiry" - return ts +class TestDatabricksCliForceRefresh: + """Tests for --force-refresh support in DatabricksCliTokenSource.""" - def _make_process_error(self, stderr: str, stdout: str = ""): + @staticmethod + def _make_process_error(stderr: str, stdout: str = ""): import subprocess err = subprocess.CalledProcessError(1, ["databricks"]) @@ -409,67 +418,156 @@ def _make_process_error(self, stderr: str, stdout: str = ""): err.stderr = stderr.encode() return err - def test_fallback_on_unknown_profile_flag(self, mocker): - """When --profile fails with 'unknown flag: --profile', falls back to --host command.""" + @staticmethod + def _make_token_source( + *, + profile=None, + host="https://workspace.databricks.com", + cli_path="/path/to/databricks", + ): + """Build a DatabricksCliTokenSource by mocking only the executable check.""" + mock_cfg = Mock() + mock_cfg.profile = profile + mock_cfg.host = host + mock_cfg.databricks_cli_path = cli_path + mock_cfg.disable_async_token_refresh = True + mock_cfg.scopes = None + mock_cfg.get_scopes = Mock(return_value=["all-apis"]) + mock_cfg.client_type = ClientType.WORKSPACE + mock_cfg.account_id = None + return credentials_provider.DatabricksCliTokenSource(mock_cfg) + + def _valid_response_json(self, access_token="fresh-token"): import json expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") - valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry}) + return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry}) + + def test_force_refresh_tried_first_with_profile(self, mocker): + """When profile is configured, refresh() tries --force-refresh first.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode()) + + token = ts.refresh() + assert token.access_token == "refreshed" + assert mock_run.call_count == 1 + + cmd = mock_run.call_args[0][0] + assert "--force-refresh" in cmd + assert "--profile" in cmd + + def test_host_only_skips_force_refresh(self, mocker): + """When only host is configured, --force-refresh is not used.""" + ts = self._make_token_source() + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.return_value = Mock(stdout=self._valid_response_json("token").encode()) + + token = ts.refresh() + assert token.access_token == "token" + assert mock_run.call_count == 1 + + cmd = mock_run.call_args[0][0] + assert "--force-refresh" not in cmd + assert "--host" in cmd + + def test_force_refresh_fallback_when_unsupported(self, mocker): + """Old CLI without --force-refresh: falls back to plain --profile command.""" + ts = self._make_token_source(profile="my-profile") mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") mock_run.side_effect = [ - self._make_process_error("Error: unknown flag: --profile"), - Mock(stdout=valid_response.encode()), + self._make_process_error("Error: unknown flag: --force-refresh"), + Mock(stdout=self._valid_response_json("fallback").encode()), ] - fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] - ts = self._make_token_source(fallback_cmd=fallback_cmd) token = ts.refresh() - assert token.access_token == "fallback-token" + assert token.access_token == "fallback" assert mock_run.call_count == 2 - assert mock_run.call_args_list[1][0][0] == fallback_cmd - def test_fallback_triggered_when_unknown_flag_in_stderr_only(self, mocker): - """Fallback triggers even when CLI also writes usage text to stdout.""" - import json + first_cmd = mock_run.call_args_list[0][0][0] + second_cmd = mock_run.call_args_list[1][0][0] + assert "--force-refresh" in first_cmd + assert "--force-refresh" not in second_cmd + assert "--profile" in second_cmd - expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") - valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry}) + def test_profile_fallback_to_host(self, mocker): + """Old CLI without --profile: progressive loop walks to --host terminal command.""" + ts = self._make_token_source(profile="my-profile") mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") mock_run.side_effect = [ - self._make_process_error(stderr="Error: unknown flag: --profile", stdout="Usage: databricks auth token"), - Mock(stdout=valid_response.encode()), + self._make_process_error("Error: unknown flag: --profile"), + self._make_process_error("Error: unknown flag: --profile"), + Mock(stdout=self._valid_response_json("host-token").encode()), ] - fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] - ts = self._make_token_source(fallback_cmd=fallback_cmd) token = ts.refresh() - assert token.access_token == "fallback-token" + assert token.access_token == "host-token" + assert mock_run.call_count == 3 + assert "--host" in mock_run.call_args_list[2][0][0] + + def test_full_fallback_chain(self, mocker): + """CLI supports neither --force-refresh nor --profile: walks the full chain.""" + ts = self._make_token_source(profile="my-profile") - def test_no_fallback_on_real_auth_error(self, mocker): - """When --profile fails with a real error (not unknown flag), no fallback is attempted.""" mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") - mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") + mock_run.side_effect = [ + self._make_process_error("Error: unknown flag: --force-refresh"), + self._make_process_error("Error: unknown flag: --profile"), + Mock(stdout=self._valid_response_json("plain").encode()), + ] - fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] - ts = self._make_token_source(fallback_cmd=fallback_cmd) - with pytest.raises(IOError) as exc_info: - ts.refresh() - assert "databricks OAuth is not configured" in str(exc_info.value) - assert mock_run.call_count == 1 + token = ts.refresh() + assert token.access_token == "plain" + assert mock_run.call_count == 3 + + cmds = [call[0][0] for call in mock_run.call_args_list] + assert "--force-refresh" in cmds[0] and "--profile" in cmds[0] + assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1] + assert "--host" in cmds[2] + + def test_active_command_index_caching(self, mocker): + """After fallback, subsequent refreshes start at the cached command index.""" + ts = self._make_token_source(profile="my-profile") - def test_no_fallback_when_fallback_cmd_not_set(self, mocker): - """When fallback_cmd is None and --profile fails, the original error is raised.""" mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") - mock_run.side_effect = self._make_process_error("Error: unknown flag: --profile") + mock_run.side_effect = [ + self._make_process_error("Error: unknown flag: --force-refresh"), + Mock(stdout=self._valid_response_json("first").encode()), + Mock(stdout=self._valid_response_json("second").encode()), + ] + + token1 = ts.refresh() + assert token1.access_token == "first" + assert mock_run.call_count == 2 + + token2 = ts.refresh() + assert token2.access_token == "second" + assert mock_run.call_count == 3 + + def test_real_auth_error_does_not_trigger_fallback(self, mocker): + """Real auth failures (not unknown-flag) surface immediately.""" + ts = self._make_token_source(profile="my-profile") + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") - ts = self._make_token_source(fallback_cmd=None) with pytest.raises(IOError) as exc_info: ts.refresh() - assert "unknown flag: --profile" in str(exc_info.value) + assert "databricks OAuth is not configured" in str(exc_info.value) assert mock_run.call_count == 1 + def test_is_unknown_flag_error(self): + """_is_unknown_flag_error matches against specific flag list.""" + check = credentials_provider.CliTokenSource._is_unknown_flag_error + assert check(IOError("Error: unknown flag: --force-refresh"), ["--force-refresh", "--profile"]) + assert check(IOError("Error: unknown flag: --profile"), ["--profile"]) + assert not check(IOError("Error: unknown flag: --force-refresh"), ["--profile"]) + assert not check(IOError("some other error"), ["--force-refresh"]) + # Tests for cloud-agnostic hosts and removed cloud checks class TestCloudAgnosticHosts: