Skip to content

Commit bdc451f

Browse files
Generalize CLI token source into progressive command list
Replace the three explicit command fields (force_cmd, cmd, fallback_cmd) and manual fallback methods with a CliCommand dataclass and an optional commands list on CliTokenSource. When commands is provided, refresh() delegates to _refresh_progressive() which walks the list from activeCommandIndex, falling back on unsupported-flag errors. When commands is None, refresh() delegates to _refresh_single() which preserves the original single-command behavior with zero changes for AzureCliTokenSource. DatabricksCliTokenSource._build_commands() produces the progressive list: --profile + --force-refresh first, plain --profile second, and --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding future flags (e.g. --scopes) requires only adding entries to _build_commands().
1 parent 4115f37 commit bdc451f

3 files changed

Lines changed: 187 additions & 153 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
### Internal Changes
2020
* Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
2121
* Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
22+
* Generalize CLI token source into a progressive command list for forward-compatible flag support.
2223

2324
### API Changes
2425
* Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.

databricks/sdk/credentials_provider.py

Lines changed: 88 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import abc
22
import base64
3+
import dataclasses
34
import functools
45
import io
56
import json
67
import logging
78
import os
89
import pathlib
910
import platform
11+
import re
1012
import subprocess
1113
import sys
1214
import threading
@@ -649,7 +651,17 @@ def refreshed_headers() -> Dict[str, str]:
649651
return OAuthCredentialsProvider(refreshed_headers, token)
650652

651653

654+
@dataclasses.dataclass
655+
class CliCommand:
656+
"""A single CLI command variant with metadata for progressive fallback."""
657+
658+
args: List[str]
659+
flags: List[str]
660+
warning: str
661+
662+
652663
class CliTokenSource(oauth.Refreshable):
664+
_UNKNOWN_FLAG_RE = re.compile(r"unknown flag: (--[a-z-]+)")
653665

654666
def __init__(
655667
self,
@@ -658,18 +670,15 @@ def __init__(
658670
access_token_field: str,
659671
expiry_field: str,
660672
disable_async: bool = True,
661-
fallback_cmd: Optional[List[str]] = None,
673+
commands: Optional[List[CliCommand]] = None,
662674
):
663675
super().__init__(disable_async=disable_async)
664676
self._cmd = cmd
665-
# fallback_cmd is tried when the primary command fails with "unknown flag: --profile",
666-
# indicating the CLI is too old to support --profile. Can be removed once support
667-
# for CLI versions predating --profile is dropped.
668-
# See: https://github.com/databricks/databricks-sdk-go/pull/1497
669-
self._fallback_cmd = fallback_cmd
670677
self._token_type_field = token_type_field
671678
self._access_token_field = access_token_field
672679
self._expiry_field = expiry_field
680+
self._commands = commands
681+
self._active_command_index = 0
673682

674683
@staticmethod
675684
def _parse_expiry(expiry: str) -> datetime:
@@ -700,17 +709,35 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
700709
message = "\n".join(filter(None, [stdout, stderr]))
701710
raise IOError(f"cannot get access token: {message}") from e
702711

712+
@staticmethod
713+
def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
714+
"""Check if the error indicates the CLI rejected one of the given flags."""
715+
msg = str(error)
716+
return any(f"unknown flag: {flag}" in msg for flag in flags)
717+
703718
def refresh(self) -> oauth.Token:
704-
try:
705-
return self._exec_cli_command(self._cmd)
706-
except IOError as e:
707-
if self._fallback_cmd is not None and "unknown flag: --profile" in str(e):
708-
logger.warning(
709-
"Databricks CLI does not support --profile flag. Falling back to --host. "
710-
"Please upgrade your CLI to the latest version."
711-
)
712-
return self._exec_cli_command(self._fallback_cmd)
713-
raise
719+
if self._commands is not None:
720+
return self._refresh_progressive()
721+
return self._refresh_single()
722+
723+
def _refresh_single(self) -> oauth.Token:
724+
return self._exec_cli_command(self._cmd)
725+
726+
def _refresh_progressive(self) -> oauth.Token:
727+
last_err: Optional[IOError] = None
728+
for i in range(self._active_command_index, len(self._commands)):
729+
cmd = self._commands[i]
730+
try:
731+
token = self._exec_cli_command(cmd.args)
732+
self._active_command_index = i
733+
return token
734+
except IOError as e:
735+
is_last = i == len(self._commands) - 1
736+
if is_last or not self._is_unknown_flag_error(e, cmd.flags):
737+
raise
738+
logger.warning(cmd.warning)
739+
last_err = e
740+
raise last_err
714741

715742

716743
def _run_subprocess(
@@ -899,15 +926,7 @@ def __init__(self, cfg: "Config"):
899926
elif cli_path.count("/") == 0:
900927
cli_path = self.__class__._find_executable(cli_path)
901928

902-
fallback_cmd = None
903-
if cfg.profile:
904-
args = ["auth", "token", "--profile", cfg.profile]
905-
if cfg.host:
906-
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
907-
else:
908-
args = self.__class__._build_host_args(cfg)
909-
910-
self._force_cmd = [cli_path, *args, "--force-refresh"]
929+
commands = self.__class__._build_commands(cli_path, cfg)
911930

912931
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
913932
# cause false-positive mismatches against every token that wasn't issued with
@@ -917,27 +936,57 @@ def __init__(self, cfg: "Config"):
917936
self._host = cfg.host
918937

919938
super().__init__(
920-
cmd=[cli_path, *args],
939+
cmd=commands[-1].args,
921940
token_type_field="token_type",
922941
access_token_field="access_token",
923942
expiry_field="expiry",
924943
disable_async=cfg.disable_async_token_refresh,
925-
fallback_cmd=fallback_cmd,
944+
commands=commands,
926945
)
927946

928-
def refresh(self) -> oauth.Token:
929-
try:
930-
token = self._exec_cli_command(self._force_cmd)
931-
except IOError as e:
932-
err_msg = str(e)
933-
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
934-
logger.warning(
935-
"Databricks CLI does not support --force-refresh. "
936-
"Please upgrade your CLI to the latest version."
947+
@staticmethod
948+
def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
949+
commands: List[CliCommand] = []
950+
if cfg.profile:
951+
profile_args = [cli_path, "auth", "token", "--profile", cfg.profile]
952+
commands.append(
953+
CliCommand(
954+
args=profile_args + ["--force-refresh"],
955+
flags=["--force-refresh", "--profile"],
956+
warning="Databricks CLI does not support --force-refresh. "
957+
"Please upgrade your CLI to the latest version.",
937958
)
938-
token = super().refresh()
939-
else:
940-
raise
959+
)
960+
commands.append(
961+
CliCommand(
962+
args=profile_args,
963+
flags=["--profile"],
964+
warning="Databricks CLI does not support --profile flag. "
965+
"Falling back to --host. "
966+
"Please upgrade your CLI to the latest version.",
967+
)
968+
)
969+
if cfg.host:
970+
commands.append(
971+
CliCommand(
972+
args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
973+
flags=[],
974+
warning="",
975+
)
976+
)
977+
else:
978+
host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
979+
commands.append(
980+
CliCommand(
981+
args=host_args,
982+
flags=[],
983+
warning="",
984+
)
985+
)
986+
return commands
987+
988+
def refresh(self) -> oauth.Token:
989+
token = super().refresh()
941990
if self._requested_scopes:
942991
self._validate_token_scopes(token)
943992
return token

0 commit comments

Comments
 (0)