Skip to content

Commit 4684f06

Browse files
Implement CLICredentials to read tokens from the local cache (#4570)
## Summary Introduces a CLI-owned credentials chain and implements `CLICredentials`, a credentials strategy that reads OAuth tokens directly from the local token cache via `u2m.PersistentAuth`, instead of shelling out to `databricks auth token` as a subprocess. ## Why The SDK authenticates by iterating through an ordered list of credential strategies. One of these — `u2mCredentials` (auth type `"databricks-cli"`) — works by spawning `databricks auth token --host <HOST>` as a child process. When the CLI itself is the running process, this is a circular dependency: the CLI shells out to a copy of itself just to read a cached token. This PR does two things to address this: 1. **Introduces a CLI-owned credentials chain.** An `init()` function in `libs/auth/credentials.go` sets `config.DefaultCredentialStrategyProvider` to a custom chain that the CLI controls. This runs on every CLI invocation — regardless of the command — because `libs/auth` is transitively imported by the rest of the CLI. Owning the chain guarantees the CLI remains stable despite the evolution of the SDK, and allows customizing individual strategies. 2. **Implements `CLICredentials`** as the replacement for the SDK's `u2mCredentials` in that chain. Instead of shelling out, it reads the token cache in-process via `u2m.PersistentAuth`, eliminating the subprocess round-trip. **Out of scope:** the token retrieval logic is now duplicated between `CLICredentials` and the `databricks auth token` command (`cmd/auth/token.go`). Sharing this via a common abstraction is not the goal of this PR and will be done as a follow-up. ## What changed ### Interface changes - **`CLICredentials.PersistentAuthOptions []u2m.PersistentAuthOption`** — new field that allows injecting test dependencies (token cache, endpoint supplier, HTTP client) into the underlying `u2m.PersistentAuth`. ### Behavioral changes Commands that previously fell through to the SDK's subprocess-based `u2mCredentials` strategy will now authenticate in-process via `CLICredentials`. The authentication result is identical — same token cache, same refresh logic — but without spawning a child process. ### Internal changes - **`init()` in `libs/auth/credentials.go`** — registers a CLI-owned credentials chain via `config.DefaultCredentialStrategyProvider`, replacing the SDK's default chain. The chain preserves the same strategy order as the SDK, with `CLICredentials` substituted for the SDK's `u2mCredentials`. - **`CLICredentials.Configure()`** — converts the SDK `config.Config` to `AuthArguments`, creates a `u2m.PersistentAuth` to access the token cache (with token refresh), wraps it in a `CachedTokenSource` with async refresh controlled by `cfg.DisableOAuthRefreshToken`, and returns an `OAuthCredentialsProvider`. - **`authArgumentsFromConfig()`** — new helper that bridges `config.Config` fields (`Host`, `AccountID`, `WorkspaceID`, `Experimental_IsUnifiedHost`) to the CLI's `AuthArguments` type. - **SDK bump** — `databricks-sdk-go` updated to `v0.110.1-0.20260221140112-be1d4d821dd1` which exposes the `NewCachedTokenSource`, `WithAsyncRefresh`, and `NewOAuthCredentialsProviderFromTokenSource` APIs needed by this implementation. ## How is this tested? Unit tests in `libs/auth/credentials_test.go`: - `TestCLICredentialsName` — asserts `Name()` returns `"databricks-cli"`. - `TestCLICredentialsConfigure` — table-driven tests with injected token cache and mock endpoint supplier covering: empty host (error), workspace host with valid token, account host with valid token, no cached token, expired token with successful refresh, expired token with failed refresh. Verified with a local build that the `databricks-cli` flow properly goes through the new credentials strategy and is able to make requests as expected. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1899134 commit 4684f06

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed

libs/auth/credentials.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"errors"
6+
7+
"github.com/databricks/databricks-sdk-go/config"
8+
"github.com/databricks/databricks-sdk-go/config/credentials"
9+
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
10+
"github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv"
11+
"github.com/databricks/databricks-sdk-go/credentials/u2m"
12+
)
13+
14+
// The credentials chain used by the CLI. It is a custom implementation
15+
// that differs from the SDK's default credentials chain. This guarantees
16+
// that the CLI remain stable despite the evolution of the SDK while
17+
// allowing the customization of some strategies such as "databricks-cli"
18+
// which has a different behavior than the SDK.
19+
//
20+
// Modifying this order could break authentication for users whose
21+
// environments are compatible with multiple strategies and who rely
22+
// on the current priority for tie-breaking.
23+
var credentialChain = []config.CredentialsStrategy{
24+
config.PatCredentials{},
25+
config.BasicCredentials{},
26+
config.M2mCredentials{},
27+
CLICredentials{}, // custom
28+
config.MetadataServiceCredentials{},
29+
// OIDC Strategies.
30+
config.GitHubOIDCCredentials{},
31+
config.AzureDevOpsOIDCCredentials{},
32+
config.EnvOIDCCredentials{},
33+
config.FileOIDCCredentials{},
34+
// Azure strategies.
35+
config.AzureGithubOIDCCredentials{},
36+
config.AzureMsiCredentials{},
37+
config.AzureClientSecretCredentials{},
38+
config.AzureCliCredentials{},
39+
// Google strategies.
40+
config.GoogleCredentials{},
41+
config.GoogleDefaultCredentials{},
42+
}
43+
44+
func init() {
45+
// Sets the credentials chain for the CLI.
46+
config.DefaultCredentialStrategyProvider = func() config.CredentialsStrategy {
47+
return &defaultCredentials{chain: config.NewCredentialsChain(credentialChain...)}
48+
}
49+
}
50+
51+
// defaultCredentials wraps the CLI credential chain and provides "default"
52+
// as the fallback name, matching the SDK's DefaultCredentials behavior.
53+
type defaultCredentials struct {
54+
chain config.CredentialsStrategy
55+
}
56+
57+
func (d *defaultCredentials) Name() string {
58+
if name := d.chain.Name(); name != "" {
59+
return name
60+
}
61+
return "default"
62+
}
63+
64+
func (d *defaultCredentials) Configure(ctx context.Context, cfg *config.Config) (credentials.CredentialsProvider, error) {
65+
return d.chain.Configure(ctx, cfg)
66+
}
67+
68+
// CLICredentials is a credentials strategy that reads OAuth tokens directly
69+
// from the local token store. It replaces the SDK's default "databricks-cli"
70+
// strategy, which shells out to `databricks auth token` as a subprocess.
71+
type CLICredentials struct {
72+
// persistentAuth is a function to override the default implementation
73+
// of the persistent auth client. It exists for testing purposes only.
74+
persistentAuthFn func(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error)
75+
}
76+
77+
// Name implements [config.CredentialsStrategy].
78+
func (c CLICredentials) Name() string {
79+
return "databricks-cli"
80+
}
81+
82+
var errNoHost = errors.New("no host provided")
83+
84+
// Configure implements [config.CredentialsStrategy].
85+
//
86+
// IMPORTANT: This credentials strategy ignores the scopes specified in the
87+
// config and purely relies on the scopes from the loaded CLI token. This can
88+
// lead to mismatches if the token was obtained with different scopes than the
89+
// ones configured in the current profile. This is a temporary limitation that
90+
// will be addressed in a future release by adding support for dynamic token
91+
// downscoping.
92+
func (c CLICredentials) Configure(ctx context.Context, cfg *config.Config) (credentials.CredentialsProvider, error) {
93+
if cfg.Host == "" {
94+
return nil, errNoHost
95+
}
96+
oauthArg, err := authArgumentsFromConfig(cfg).ToOAuthArgument()
97+
if err != nil {
98+
return nil, err
99+
}
100+
ts, err := c.persistentAuth(ctx, u2m.WithOAuthArgument(oauthArg))
101+
if err != nil {
102+
return nil, err
103+
}
104+
cp := credentials.NewOAuthCredentialsProviderFromTokenSource(
105+
auth.NewCachedTokenSource(ts, auth.WithAsyncRefresh(!cfg.DisableOAuthRefreshToken)),
106+
)
107+
return cp, nil
108+
}
109+
110+
// persistentAuth returns a token source. It is a convenience function that
111+
// overrides the default implementation of the persistent auth client if
112+
// an alternative implementation is provided for testing.
113+
func (c CLICredentials) persistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
114+
if c.persistentAuthFn != nil {
115+
return c.persistentAuthFn(ctx, opts...)
116+
}
117+
ts, err := u2m.NewPersistentAuth(ctx, opts...)
118+
if err != nil {
119+
return nil, err
120+
}
121+
return authconv.AuthTokenSource(ts), nil
122+
}
123+
124+
// authArgumentsFromConfig converts an SDK config to AuthArguments.
125+
func authArgumentsFromConfig(cfg *config.Config) AuthArguments {
126+
return AuthArguments{
127+
Host: cfg.Host,
128+
AccountID: cfg.AccountID,
129+
WorkspaceID: cfg.WorkspaceID,
130+
IsUnifiedHost: cfg.Experimental_IsUnifiedHost,
131+
Profile: cfg.Profile,
132+
}
133+
}

libs/auth/credentials_test.go

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
"slices"
8+
"testing"
9+
10+
"github.com/databricks/databricks-sdk-go/config"
11+
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
12+
"github.com/databricks/databricks-sdk-go/credentials/u2m"
13+
"golang.org/x/oauth2"
14+
)
15+
16+
// TestCredentialChainOrder purely exists as an extra measure to catch
17+
// accidental change in the ordering.
18+
func TestCredentialChainOrder(t *testing.T) {
19+
names := make([]string, len(credentialChain))
20+
for i, s := range credentialChain {
21+
names[i] = s.Name()
22+
}
23+
want := []string{
24+
"pat",
25+
"basic",
26+
"oauth-m2m",
27+
"databricks-cli",
28+
"metadata-service",
29+
"github-oidc",
30+
"azure-devops-oidc",
31+
"env-oidc",
32+
"file-oidc",
33+
"github-oidc-azure",
34+
"azure-msi",
35+
"azure-client-secret",
36+
"azure-cli",
37+
"google-credentials",
38+
"google-id",
39+
}
40+
if !slices.Equal(names, want) {
41+
t.Errorf("credential chain order: want %v, got %v", want, names)
42+
}
43+
}
44+
45+
func TestCLICredentialsName(t *testing.T) {
46+
c := CLICredentials{}
47+
if got := c.Name(); got != "databricks-cli" {
48+
t.Errorf("Name(): want %q, got %q", "databricks-cli", got)
49+
}
50+
}
51+
52+
func TestAuthArgumentsFromConfig(t *testing.T) {
53+
tests := []struct {
54+
name string
55+
cfg *config.Config
56+
want AuthArguments
57+
}{
58+
{
59+
name: "empty config",
60+
cfg: &config.Config{},
61+
want: AuthArguments{},
62+
},
63+
{
64+
name: "workspace host only",
65+
cfg: &config.Config{
66+
Host: "https://myworkspace.cloud.databricks.com",
67+
},
68+
want: AuthArguments{
69+
Host: "https://myworkspace.cloud.databricks.com",
70+
},
71+
},
72+
{
73+
name: "account host with account ID",
74+
cfg: &config.Config{
75+
Host: "https://accounts.cloud.databricks.com",
76+
AccountID: "test-account-id",
77+
},
78+
want: AuthArguments{
79+
Host: "https://accounts.cloud.databricks.com",
80+
AccountID: "test-account-id",
81+
},
82+
},
83+
{
84+
name: "all fields",
85+
cfg: &config.Config{
86+
Host: "https://myhost.com",
87+
AccountID: "acc-123",
88+
WorkspaceID: "ws-456",
89+
Profile: "my-profile",
90+
Experimental_IsUnifiedHost: true,
91+
},
92+
want: AuthArguments{
93+
Host: "https://myhost.com",
94+
AccountID: "acc-123",
95+
WorkspaceID: "ws-456",
96+
Profile: "my-profile",
97+
IsUnifiedHost: true,
98+
},
99+
},
100+
}
101+
102+
for _, tt := range tests {
103+
t.Run(tt.name, func(t *testing.T) {
104+
got := authArgumentsFromConfig(tt.cfg)
105+
if got != tt.want {
106+
t.Errorf("want %v, got %v", tt.want, got)
107+
}
108+
})
109+
}
110+
}
111+
112+
func TestCLICredentialsConfigure(t *testing.T) {
113+
testErr := errors.New("test error")
114+
115+
tests := []struct {
116+
name string
117+
cfg *config.Config
118+
persistentAuthFn func(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error)
119+
wantErr error
120+
wantToken string
121+
}{
122+
{
123+
name: "empty host returns error",
124+
cfg: &config.Config{},
125+
wantErr: errNoHost,
126+
},
127+
{
128+
name: "persistentAuthFn error is propagated",
129+
cfg: &config.Config{
130+
Host: "https://myworkspace.cloud.databricks.com",
131+
},
132+
persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
133+
return nil, testErr
134+
},
135+
wantErr: testErr,
136+
},
137+
{
138+
name: "workspace host",
139+
cfg: &config.Config{
140+
Host: "https://myworkspace.cloud.databricks.com",
141+
},
142+
persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
143+
return auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) {
144+
return &oauth2.Token{AccessToken: "workspace-token"}, nil
145+
}), nil
146+
},
147+
wantToken: "workspace-token",
148+
},
149+
{
150+
name: "account host",
151+
cfg: &config.Config{
152+
Host: "https://accounts.cloud.databricks.com",
153+
AccountID: "test-account-id",
154+
},
155+
persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
156+
return auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) {
157+
return &oauth2.Token{AccessToken: "account-token"}, nil
158+
}), nil
159+
},
160+
wantToken: "account-token",
161+
},
162+
}
163+
164+
for _, tt := range tests {
165+
t.Run(tt.name, func(t *testing.T) {
166+
ctx := context.Background()
167+
c := CLICredentials{persistentAuthFn: tt.persistentAuthFn}
168+
169+
got, err := c.Configure(ctx, tt.cfg)
170+
171+
if !errors.Is(err, tt.wantErr) {
172+
t.Fatalf("want error %v, got %v", tt.wantErr, err)
173+
}
174+
if tt.wantErr != nil {
175+
return
176+
}
177+
178+
// Verify the credentials provider sets the correct Bearer token.
179+
req, err := http.NewRequest("GET", tt.cfg.Host, nil)
180+
if err != nil {
181+
t.Fatalf("creating request: %v", err)
182+
}
183+
if err := got.SetHeaders(req); err != nil {
184+
t.Fatalf("SetHeaders: want no error, got %v", err)
185+
}
186+
want := "Bearer " + tt.wantToken
187+
if gotHeader := req.Header.Get("Authorization"); gotHeader != want {
188+
t.Errorf("Authorization header: want %q, got %q", want, gotHeader)
189+
}
190+
})
191+
}
192+
}

0 commit comments

Comments
 (0)