Skip to content

Commit f979a2c

Browse files
Generalize CLI token source into progressive command list
Replace the three explicit command fields (force_cmd, cmd, fallback_cmd) and manual fallback methods with a CliCommand dataclass and an optional commands list on CliTokenSource. When commands is provided, refresh() delegates to _refresh_progressive() which walks the list from activeCommandIndex, falling back on unsupported-flag errors. When commands is None, refresh() delegates to _refresh_single() which preserves the original single-command behavior with zero changes for AzureCliTokenSource. DatabricksCliTokenSource._build_commands() produces the progressive list: --profile + --force-refresh first, plain --profile second, and --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding future flags (e.g. --scopes) requires only adding entries to _build_commands().
1 parent 9e82d18 commit f979a2c

3 files changed

Lines changed: 181 additions & 163 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
### Documentation
1313

1414
### Internal Changes
15+
* Generalize CLI token source into a progressive command list for forward-compatible flag support.
1516

1617
### API Changes
1718
* Add `disable_gov_tag_creation` field for `databricks.sdk.service.settings.RestrictWorkspaceAdminsMessage`.

databricks/sdk/credentials_provider.py

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import base64
3+
import dataclasses
34
import functools
45
import io
56
import json
@@ -650,6 +651,15 @@ def refreshed_headers() -> Dict[str, str]:
650651
return OAuthCredentialsProvider(refreshed_headers, token)
651652

652653

654+
@dataclasses.dataclass
655+
class CliCommand:
656+
"""A single CLI command variant with metadata for progressive fallback."""
657+
658+
args: List[str]
659+
flags: List[str]
660+
warning: str
661+
662+
653663
class CliTokenSource(oauth.Refreshable):
654664
_UNKNOWN_FLAG_RE = re.compile(r"unknown flag: (--[a-z-]+)")
655665

@@ -660,18 +670,15 @@ def __init__(
660670
access_token_field: str,
661671
expiry_field: str,
662672
disable_async: bool = True,
663-
fallback_cmd: Optional[List[str]] = None,
673+
commands: Optional[List[CliCommand]] = None,
664674
):
665675
super().__init__(disable_async=disable_async)
666676
self._cmd = cmd
667-
# fallback_cmd is tried when the primary command fails with "unknown flag: --profile",
668-
# indicating the CLI is too old to support --profile. Can be removed once support
669-
# for CLI versions predating --profile is dropped.
670-
# See: https://github.com/databricks/databricks-sdk-go/pull/1497
671-
self._fallback_cmd = fallback_cmd
672677
self._token_type_field = token_type_field
673678
self._access_token_field = access_token_field
674679
self._expiry_field = expiry_field
680+
self._commands = commands
681+
self._active_command_index = 0
675682

676683
@staticmethod
677684
def _parse_expiry(expiry: str) -> datetime:
@@ -703,22 +710,34 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
703710
raise IOError(f"cannot get access token: {message}") from e
704711

705712
@staticmethod
706-
def _get_unsupported_flag(error: IOError) -> Optional[str]:
707-
"""Extract the flag name if the error is an 'unknown flag' CLI rejection."""
708-
match = CliTokenSource._UNKNOWN_FLAG_RE.search(str(error))
709-
return match.group(1) if match else None
713+
def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
714+
"""Check if the error indicates the CLI rejected one of the given flags."""
715+
msg = str(error)
716+
return any(f"unknown flag: {flag}" in msg for flag in flags)
710717

711718
def refresh(self) -> oauth.Token:
712-
try:
713-
return self._exec_cli_command(self._cmd)
714-
except IOError as e:
715-
if self._fallback_cmd is not None and "unknown flag: --profile" in str(e):
716-
logger.warning(
717-
"Databricks CLI does not support --profile flag. Falling back to --host. "
718-
"Please upgrade your CLI to the latest version."
719-
)
720-
return self._exec_cli_command(self._fallback_cmd)
721-
raise
719+
if self._commands is not None:
720+
return self._refresh_progressive()
721+
return self._refresh_single()
722+
723+
def _refresh_single(self) -> oauth.Token:
724+
return self._exec_cli_command(self._cmd)
725+
726+
def _refresh_progressive(self) -> oauth.Token:
727+
last_err: Optional[IOError] = None
728+
for i in range(self._active_command_index, len(self._commands)):
729+
cmd = self._commands[i]
730+
try:
731+
token = self._exec_cli_command(cmd.args)
732+
self._active_command_index = i
733+
return token
734+
except IOError as e:
735+
is_last = i == len(self._commands) - 1
736+
if is_last or not self._is_unknown_flag_error(e, cmd.flags):
737+
raise
738+
logger.warning(cmd.warning)
739+
last_err = e
740+
raise last_err
722741

723742

724743
def _run_subprocess(
@@ -907,15 +926,7 @@ def __init__(self, cfg: "Config"):
907926
elif cli_path.count("/") == 0:
908927
cli_path = self.__class__._find_executable(cli_path)
909928

910-
fallback_cmd = None
911-
if cfg.profile:
912-
args = ["auth", "token", "--profile", cfg.profile]
913-
if cfg.host:
914-
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
915-
else:
916-
args = self.__class__._build_host_args(cfg)
917-
918-
self._force_cmd = [cli_path, *args, "--force-refresh"]
929+
commands = self.__class__._build_commands(cli_path, cfg)
919930

920931
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
921932
# cause false-positive mismatches against every token that wasn't issued with
@@ -925,29 +936,57 @@ def __init__(self, cfg: "Config"):
925936
self._host = cfg.host
926937

927938
super().__init__(
928-
cmd=[cli_path, *args],
939+
cmd=commands[-1].args,
929940
token_type_field="token_type",
930941
access_token_field="access_token",
931942
expiry_field="expiry",
932943
disable_async=cfg.disable_async_token_refresh,
933-
fallback_cmd=fallback_cmd,
944+
commands=commands,
934945
)
935946

936-
_KNOWN_CLI_FLAGS = {"--force-refresh", "--profile"}
947+
@staticmethod
948+
def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
949+
commands: List[CliCommand] = []
950+
if cfg.profile:
951+
profile_args = [cli_path, "auth", "token", "--profile", cfg.profile]
952+
commands.append(
953+
CliCommand(
954+
args=profile_args + ["--force-refresh"],
955+
flags=["--force-refresh", "--profile"],
956+
warning="Databricks CLI does not support --force-refresh. "
957+
"Please upgrade your CLI to the latest version.",
958+
)
959+
)
960+
commands.append(
961+
CliCommand(
962+
args=profile_args,
963+
flags=["--profile"],
964+
warning="Databricks CLI does not support --profile flag. "
965+
"Falling back to --host. "
966+
"Please upgrade your CLI to the latest version.",
967+
)
968+
)
969+
if cfg.host:
970+
commands.append(
971+
CliCommand(
972+
args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
973+
flags=[],
974+
warning="",
975+
)
976+
)
977+
else:
978+
host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
979+
commands.append(
980+
CliCommand(
981+
args=host_args,
982+
flags=[],
983+
warning="",
984+
)
985+
)
986+
return commands
937987

938988
def refresh(self) -> oauth.Token:
939-
try:
940-
token = self._exec_cli_command(self._force_cmd)
941-
except IOError as e:
942-
flag = self._get_unsupported_flag(e)
943-
if flag in self._KNOWN_CLI_FLAGS:
944-
logger.warning(
945-
"Databricks CLI does not support %s. " "Please upgrade your CLI to the latest version.",
946-
flag,
947-
)
948-
token = super().refresh()
949-
else:
950-
raise
989+
token = super().refresh()
951990
if self._requested_scopes:
952991
self._validate_token_scopes(token)
953992
return token

0 commit comments

Comments
 (0)