11import abc
22import base64
3+ import dataclasses
34import functools
45import io
56import json
@@ -650,6 +651,15 @@ def refreshed_headers() -> Dict[str, str]:
650651 return OAuthCredentialsProvider (refreshed_headers , token )
651652
652653
654+ @dataclasses .dataclass
655+ class CliCommand :
656+ """A single CLI command variant with metadata for progressive fallback."""
657+
658+ args : List [str ]
659+ flags : List [str ]
660+ warning : str
661+
662+
653663class CliTokenSource (oauth .Refreshable ):
654664 _UNKNOWN_FLAG_RE = re .compile (r"unknown flag: (--[a-z-]+)" )
655665
@@ -660,18 +670,15 @@ def __init__(
660670 access_token_field : str ,
661671 expiry_field : str ,
662672 disable_async : bool = True ,
663- fallback_cmd : Optional [List [str ]] = None ,
673+ commands : Optional [List [CliCommand ]] = None ,
664674 ):
665675 super ().__init__ (disable_async = disable_async )
666676 self ._cmd = cmd
667- # fallback_cmd is tried when the primary command fails with "unknown flag: --profile",
668- # indicating the CLI is too old to support --profile. Can be removed once support
669- # for CLI versions predating --profile is dropped.
670- # See: https://github.com/databricks/databricks-sdk-go/pull/1497
671- self ._fallback_cmd = fallback_cmd
672677 self ._token_type_field = token_type_field
673678 self ._access_token_field = access_token_field
674679 self ._expiry_field = expiry_field
680+ self ._commands = commands
681+ self ._active_command_index = 0
675682
676683 @staticmethod
677684 def _parse_expiry (expiry : str ) -> datetime :
@@ -703,22 +710,34 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
703710 raise IOError (f"cannot get access token: { message } " ) from e
704711
705712 @staticmethod
706- def _get_unsupported_flag (error : IOError ) -> Optional [str ]:
707- """Extract the flag name if the error is an 'unknown flag' CLI rejection ."""
708- match = CliTokenSource . _UNKNOWN_FLAG_RE . search ( str (error ) )
709- return match . group ( 1 ) if match else None
713+ def _is_unknown_flag_error (error : IOError , flags : List [str ]) -> bool :
714+ """Check if the error indicates the CLI rejected one of the given flags ."""
715+ msg = str (error )
716+ return any ( f"unknown flag: { flag } " in msg for flag in flags )
710717
711718 def refresh (self ) -> oauth .Token :
712- try :
713- return self ._exec_cli_command (self ._cmd )
714- except IOError as e :
715- if self ._fallback_cmd is not None and "unknown flag: --profile" in str (e ):
716- logger .warning (
717- "Databricks CLI does not support --profile flag. Falling back to --host. "
718- "Please upgrade your CLI to the latest version."
719- )
720- return self ._exec_cli_command (self ._fallback_cmd )
721- raise
719+ if self ._commands is not None :
720+ return self ._refresh_progressive ()
721+ return self ._refresh_single ()
722+
723+ def _refresh_single (self ) -> oauth .Token :
724+ return self ._exec_cli_command (self ._cmd )
725+
726+ def _refresh_progressive (self ) -> oauth .Token :
727+ last_err : Optional [IOError ] = None
728+ for i in range (self ._active_command_index , len (self ._commands )):
729+ cmd = self ._commands [i ]
730+ try :
731+ token = self ._exec_cli_command (cmd .args )
732+ self ._active_command_index = i
733+ return token
734+ except IOError as e :
735+ is_last = i == len (self ._commands ) - 1
736+ if is_last or not self ._is_unknown_flag_error (e , cmd .flags ):
737+ raise
738+ logger .warning (cmd .warning )
739+ last_err = e
740+ raise last_err
722741
723742
724743def _run_subprocess (
@@ -907,15 +926,7 @@ def __init__(self, cfg: "Config"):
907926 elif cli_path .count ("/" ) == 0 :
908927 cli_path = self .__class__ ._find_executable (cli_path )
909928
910- fallback_cmd = None
911- if cfg .profile :
912- args = ["auth" , "token" , "--profile" , cfg .profile ]
913- if cfg .host :
914- fallback_cmd = [cli_path , * self .__class__ ._build_host_args (cfg )]
915- else :
916- args = self .__class__ ._build_host_args (cfg )
917-
918- self ._force_cmd = [cli_path , * args , "--force-refresh" ]
929+ commands = self .__class__ ._build_commands (cli_path , cfg )
919930
920931 # get_scopes() defaults to ["all-apis"] when nothing is configured, which would
921932 # cause false-positive mismatches against every token that wasn't issued with
@@ -925,29 +936,57 @@ def __init__(self, cfg: "Config"):
925936 self ._host = cfg .host
926937
927938 super ().__init__ (
928- cmd = [ cli_path , * args ] ,
939+ cmd = commands [ - 1 ]. args ,
929940 token_type_field = "token_type" ,
930941 access_token_field = "access_token" ,
931942 expiry_field = "expiry" ,
932943 disable_async = cfg .disable_async_token_refresh ,
933- fallback_cmd = fallback_cmd ,
944+ commands = commands ,
934945 )
935946
936- _KNOWN_CLI_FLAGS = {"--force-refresh" , "--profile" }
947+ @staticmethod
948+ def _build_commands (cli_path : str , cfg : "Config" ) -> List [CliCommand ]:
949+ commands : List [CliCommand ] = []
950+ if cfg .profile :
951+ profile_args = [cli_path , "auth" , "token" , "--profile" , cfg .profile ]
952+ commands .append (
953+ CliCommand (
954+ args = profile_args + ["--force-refresh" ],
955+ flags = ["--force-refresh" , "--profile" ],
956+ warning = "Databricks CLI does not support --force-refresh. "
957+ "Please upgrade your CLI to the latest version." ,
958+ )
959+ )
960+ commands .append (
961+ CliCommand (
962+ args = profile_args ,
963+ flags = ["--profile" ],
964+ warning = "Databricks CLI does not support --profile flag. "
965+ "Falling back to --host. "
966+ "Please upgrade your CLI to the latest version." ,
967+ )
968+ )
969+ if cfg .host :
970+ commands .append (
971+ CliCommand (
972+ args = [cli_path , * DatabricksCliTokenSource ._build_host_args (cfg )],
973+ flags = [],
974+ warning = "" ,
975+ )
976+ )
977+ else :
978+ host_args = [cli_path , * DatabricksCliTokenSource ._build_host_args (cfg )]
979+ commands .append (
980+ CliCommand (
981+ args = host_args ,
982+ flags = [],
983+ warning = "" ,
984+ )
985+ )
986+ return commands
937987
938988 def refresh (self ) -> oauth .Token :
939- try :
940- token = self ._exec_cli_command (self ._force_cmd )
941- except IOError as e :
942- flag = self ._get_unsupported_flag (e )
943- if flag in self ._KNOWN_CLI_FLAGS :
944- logger .warning (
945- "Databricks CLI does not support %s. " "Please upgrade your CLI to the latest version." ,
946- flag ,
947- )
948- token = super ().refresh ()
949- else :
950- raise
989+ token = super ().refresh ()
951990 if self ._requested_scopes :
952991 self ._validate_token_scopes (token )
953992 return token
0 commit comments