Skip to content

Commit 96fcf7f

Browse files
Generalize CLI token source into progressive command list
Replace the explicit force_cmd/fallback_cmd fields with a CliCommand dataclass and an optional commands list on CliTokenSource. When commands is provided, refresh() delegates to _refresh_progressive() which walks the list from activeCommandIndex, falling back on unsupported-flag errors. When commands is None, refresh() delegates to _refresh_single() which preserves the original fallback behavior with zero changes for AzureCliTokenSource. DatabricksCliTokenSource._build_commands() produces the progressive list: --profile + --force-refresh first, plain --profile second, and --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding future flags (e.g. --scopes) requires only adding entries to _build_commands().
1 parent 469cb44 commit 96fcf7f

File tree

3 files changed

+179
-138
lines changed

3 files changed

+179
-138
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
### Internal Changes
2020
* Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
2121
* Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
22+
* Generalize CLI token source into a progressive command list for forward-compatible flag support.
2223

2324
### API Changes
2425
* Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.

databricks/sdk/credentials_provider.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import base64
3+
import dataclasses
34
import functools
45
import io
56
import json
@@ -649,6 +650,15 @@ def refreshed_headers() -> Dict[str, str]:
649650
return OAuthCredentialsProvider(refreshed_headers, token)
650651

651652

653+
@dataclasses.dataclass
654+
class CliCommand:
655+
"""A single CLI command variant with metadata for progressive fallback."""
656+
657+
args: List[str]
658+
flags: List[str]
659+
warning: str
660+
661+
652662
class CliTokenSource(oauth.Refreshable):
653663

654664
def __init__(
@@ -659,6 +669,7 @@ def __init__(
659669
expiry_field: str,
660670
disable_async: bool = True,
661671
fallback_cmd: Optional[List[str]] = None,
672+
commands: Optional[List[CliCommand]] = None,
662673
):
663674
super().__init__(disable_async=disable_async)
664675
self._cmd = cmd
@@ -670,6 +681,8 @@ def __init__(
670681
self._token_type_field = token_type_field
671682
self._access_token_field = access_token_field
672683
self._expiry_field = expiry_field
684+
self._commands = commands
685+
self._active_command_index = 0
673686

674687
@staticmethod
675688
def _parse_expiry(expiry: str) -> datetime:
@@ -700,7 +713,18 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
700713
message = "\n".join(filter(None, [stdout, stderr]))
701714
raise IOError(f"cannot get access token: {message}") from e
702715

716+
@staticmethod
717+
def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
718+
"""Check if the error indicates the CLI rejected one of the given flags."""
719+
msg = str(error)
720+
return any(f"unknown flag: {flag}" in msg for flag in flags)
721+
703722
def refresh(self) -> oauth.Token:
723+
if self._commands is not None:
724+
return self._refresh_progressive()
725+
return self._refresh_single()
726+
727+
def _refresh_single(self) -> oauth.Token:
704728
try:
705729
return self._exec_cli_command(self._cmd)
706730
except IOError as e:
@@ -712,6 +736,22 @@ def refresh(self) -> oauth.Token:
712736
return self._exec_cli_command(self._fallback_cmd)
713737
raise
714738

739+
def _refresh_progressive(self) -> oauth.Token:
740+
last_err: Optional[IOError] = None
741+
for i in range(self._active_command_index, len(self._commands)):
742+
cmd = self._commands[i]
743+
try:
744+
token = self._exec_cli_command(cmd.args)
745+
self._active_command_index = i
746+
return token
747+
except IOError as e:
748+
is_last = i == len(self._commands) - 1
749+
if is_last or not self._is_unknown_flag_error(e, cmd.flags):
750+
raise
751+
logger.warning(cmd.warning)
752+
last_err = e
753+
raise last_err
754+
715755

716756
def _run_subprocess(
717757
popenargs,
@@ -899,15 +939,7 @@ def __init__(self, cfg: "Config"):
899939
elif cli_path.count("/") == 0:
900940
cli_path = self.__class__._find_executable(cli_path)
901941

902-
fallback_cmd = None
903-
if cfg.profile:
904-
args = ["auth", "token", "--profile", cfg.profile]
905-
if cfg.host:
906-
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
907-
else:
908-
args = self.__class__._build_host_args(cfg)
909-
910-
self._force_cmd = [cli_path, *args, "--force-refresh"]
942+
commands = self.__class__._build_commands(cli_path, cfg)
911943

912944
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
913945
# cause false-positive mismatches against every token that wasn't issued with
@@ -917,27 +949,57 @@ def __init__(self, cfg: "Config"):
917949
self._host = cfg.host
918950

919951
super().__init__(
920-
cmd=[cli_path, *args],
952+
cmd=commands[-1].args,
921953
token_type_field="token_type",
922954
access_token_field="access_token",
923955
expiry_field="expiry",
924956
disable_async=cfg.disable_async_token_refresh,
925-
fallback_cmd=fallback_cmd,
957+
commands=commands,
926958
)
927959

928-
def refresh(self) -> oauth.Token:
929-
try:
930-
token = self._exec_cli_command(self._force_cmd)
931-
except IOError as e:
932-
err_msg = str(e)
933-
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
934-
logger.warning(
935-
"Databricks CLI does not support --force-refresh. "
936-
"Please upgrade your CLI to the latest version."
960+
@staticmethod
961+
def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
962+
commands: List[CliCommand] = []
963+
if cfg.profile:
964+
profile_args = [cli_path, "auth", "token", "--profile", cfg.profile]
965+
commands.append(
966+
CliCommand(
967+
args=profile_args + ["--force-refresh"],
968+
flags=["--force-refresh", "--profile"],
969+
warning="Databricks CLI does not support --force-refresh. "
970+
"Please upgrade your CLI to the latest version.",
937971
)
938-
token = super().refresh()
939-
else:
940-
raise
972+
)
973+
commands.append(
974+
CliCommand(
975+
args=profile_args,
976+
flags=["--profile"],
977+
warning="Databricks CLI does not support --profile flag. "
978+
"Falling back to --host. "
979+
"Please upgrade your CLI to the latest version.",
980+
)
981+
)
982+
if cfg.host:
983+
commands.append(
984+
CliCommand(
985+
args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
986+
flags=[],
987+
warning="",
988+
)
989+
)
990+
else:
991+
host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
992+
commands.append(
993+
CliCommand(
994+
args=host_args,
995+
flags=[],
996+
warning="",
997+
)
998+
)
999+
return commands
1000+
1001+
def refresh(self) -> oauth.Token:
1002+
token = super().refresh()
9411003
if self._requested_scopes:
9421004
self._validate_token_scopes(token)
9431005
return token

0 commit comments

Comments
 (0)