diff --git a/acceptance/bin/discovery_browser.py b/acceptance/bin/discovery_browser.py new file mode 100755 index 0000000000..42099fa06d --- /dev/null +++ b/acceptance/bin/discovery_browser.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +""" +Simulates the login.databricks.com discovery flow for acceptance tests. + +When the CLI opens this "browser" with the login.databricks.com URL, +the script extracts the OAuth parameters from the destination_url, +constructs a callback to localhost with an iss parameter pointing +at the testserver, and fetches it. + +Usage: discovery_browser.py +""" + +import os +import sys +import urllib.parse +import urllib.request + +if len(sys.argv) < 2: + sys.stderr.write("Usage: discovery_browser.py \n") + sys.exit(1) + +url = sys.argv[1] +parsed = urllib.parse.urlparse(url) +top_params = urllib.parse.parse_qs(parsed.query) + +destination_url = top_params.get("destination_url", [None])[0] +if not destination_url: + sys.stderr.write(f"No destination_url found in: {url}\n") + sys.exit(1) + +dest_parsed = urllib.parse.urlparse(destination_url) +dest_params = urllib.parse.parse_qs(dest_parsed.query) + +redirect_uri = dest_params.get("redirect_uri", [None])[0] +state = dest_params.get("state", [None])[0] + +if not redirect_uri or not state: + sys.stderr.write(f"Missing redirect_uri or state in destination_url: {destination_url}\n") + sys.exit(1) + +# The testserver's host acts as the workspace issuer. +testserver_host = os.environ.get("DATABRICKS_HOST", "") +if not testserver_host: + sys.stderr.write("DATABRICKS_HOST not set\n") + sys.exit(1) + +issuer = testserver_host.rstrip("/") + "/oidc" + +# Build the callback URL with code, state, and iss (the workspace issuer). +callback_params = urllib.parse.urlencode( + { + "code": "oauth-code", + "state": state, + "iss": issuer, + } +) +callback_url = f"{redirect_uri}?{callback_params}" + +try: + response = urllib.request.urlopen(callback_url) + if response.status != 200: + sys.stderr.write(f"Callback failed: {callback_url} (status {response.status})\n") + sys.exit(1) +except Exception as e: + sys.stderr.write(f"Callback failed: {callback_url} ({e})\n") + sys.exit(1) + +sys.exit(0) diff --git a/acceptance/cmd/auth/login/discovery/out.databrickscfg b/acceptance/cmd/auth/login/discovery/out.databrickscfg new file mode 100644 index 0000000000..d6e17b7595 --- /dev/null +++ b/acceptance/cmd/auth/login/discovery/out.databrickscfg @@ -0,0 +1,10 @@ +; The profile defined in the DEFAULT section is to be used as a fallback when no profile is explicitly specified. +[DEFAULT] + +[discovery-test] +host = [DATABRICKS_URL] +workspace_id = 12345 +auth_type = databricks-cli + +[__settings__] +default_profile = discovery-test diff --git a/acceptance/cmd/auth/login/discovery/out.test.toml b/acceptance/cmd/auth/login/discovery/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/cmd/auth/login/discovery/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/auth/login/discovery/output.txt b/acceptance/cmd/auth/login/discovery/output.txt new file mode 100644 index 0000000000..c687b07fd5 --- /dev/null +++ b/acceptance/cmd/auth/login/discovery/output.txt @@ -0,0 +1,14 @@ + +>>> [CLI] auth login --profile discovery-test +Opening login.databricks.com in your browser... +Profile discovery-test was successfully saved + +>>> [CLI] auth profiles +Name Host Valid +discovery-test (Default) [DATABRICKS_URL] YES + +>>> print_requests.py --get //tokens/introspect +{ + "method": "GET", + "path": "/api/2.0/tokens/introspect" +} diff --git a/acceptance/cmd/auth/login/discovery/script b/acceptance/cmd/auth/login/discovery/script new file mode 100644 index 0000000000..4ae4c682d9 --- /dev/null +++ b/acceptance/cmd/auth/login/discovery/script @@ -0,0 +1,14 @@ +sethome "./home" + +# Use the discovery browser script that simulates login.databricks.com +export BROWSER="discovery_browser.py" + +trace $CLI auth login --profile discovery-test + +trace $CLI auth profiles + +# Verify the introspection endpoint was called (workspace_id in profile confirms this too). +trace print_requests.py --get //tokens/introspect + +# Track the .databrickscfg file that was created to surface changes. +mv "./home/.databrickscfg" "./out.databrickscfg" diff --git a/acceptance/cmd/auth/login/discovery/test.toml b/acceptance/cmd/auth/login/discovery/test.toml new file mode 100644 index 0000000000..d430a86e03 --- /dev/null +++ b/acceptance/cmd/auth/login/discovery/test.toml @@ -0,0 +1,18 @@ +Ignore = [ + "home" +] +RecordRequests = true + +# Override the introspection endpoint so we can verify it gets called. +[[Server]] +Pattern = "GET /api/2.0/tokens/introspect" +Response.Body = ''' +{ + "principal_context": { + "authentication_scope": { + "account_id": "test-account-123", + "workspace_id": 12345 + } + } +} +''' diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 1a1240d84c..379ce1a4c2 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -16,12 +16,14 @@ import ( "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/exec" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" "github.com/databricks/databricks-sdk-go/credentials/u2m" browserpkg "github.com/pkg/browser" "github.com/spf13/cobra" + "golang.org/x/oauth2" ) func promptForProfile(ctx context.Context, defaultValue string) (string, error) { @@ -45,8 +47,46 @@ const ( minimalDbConnectVersion = "13.1" defaultTimeout = 1 * time.Hour authTypeDatabricksCLI = "databricks-cli" + discoveryFallbackTip = "\n\nTip: you can specify a workspace directly with: databricks auth login --host " ) +// discoveryErr wraps an error (or creates a new one) and appends the +// discovery fallback tip so users know they can bypass login.databricks.com. +func discoveryErr(msg string, err error) error { + if err != nil { + return fmt.Errorf("%s: %w%s", msg, err, discoveryFallbackTip) + } + return fmt.Errorf("%s%s", msg, discoveryFallbackTip) +} + +type discoveryPersistentAuth interface { + Challenge() error + Token() (*oauth2.Token, error) + Close() error +} + +// discoveryClient abstracts the external dependencies of discoveryLogin so +// they can be replaced in tests without package-level variable mutation. +type discoveryClient interface { + NewOAuthArgument(profileName string) (*u2m.BasicDiscoveryOAuthArgument, error) + NewPersistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (discoveryPersistentAuth, error) + IntrospectToken(ctx context.Context, host, accessToken string) (*auth.IntrospectionResult, error) +} + +type defaultDiscoveryClient struct{} + +func (d *defaultDiscoveryClient) NewOAuthArgument(profileName string) (*u2m.BasicDiscoveryOAuthArgument, error) { + return u2m.NewBasicDiscoveryOAuthArgument(profileName) +} + +func (d *defaultDiscoveryClient) NewPersistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (discoveryPersistentAuth, error) { + return u2m.NewPersistentAuth(ctx, opts...) +} + +func (d *defaultDiscoveryClient) IntrospectToken(ctx context.Context, host, accessToken string) (*auth.IntrospectionResult, error) { + return auth.IntrospectToken(ctx, host, accessToken, nil) +} + func newLoginCommand(authArguments *auth.AuthArguments) *cobra.Command { defaultConfigPath := "~/.databrickscfg" if runtime.GOOS == "windows" { @@ -69,9 +109,11 @@ you can refer to the documentation linked below. GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html -This command requires a Databricks Host URL (using --host or as a positional argument -or implicitly inferred from the specified profile name) -and a profile name (using --profile) to be specified. If you don't specify these +If no host is provided (via --host, as a positional argument, or from an existing +profile), the CLI will open login.databricks.com where you can authenticate and +select a workspace. The workspace URL will be discovered automatically. + +A profile name (using --profile) can be specified. If you don't specify these values, you'll be prompted for values at runtime. While this command always logs you into the specified host, the runtime behaviour @@ -138,6 +180,15 @@ depends on the existing profiles you have set in your configuration file return err } + // If no host is available from any source, use the discovery flow + // via login.databricks.com. + if shouldUseDiscovery(authArguments.Host, args, existingProfile) { + if err := validateDiscoveryFlagCompatibility(cmd); err != nil { + return err + } + return discoveryLogin(ctx, &defaultDiscoveryClient{}, profileName, loginTimeout, scopes, existingProfile, getBrowserFunc(cmd)) + } + // Load unified host flags from the profile if not explicitly set via CLI flag if !cmd.Flag("experimental-is-unified-host").Changed && existingProfile != nil { authArguments.IsUnifiedHost = existingProfile.IsUnifiedHost @@ -157,15 +208,11 @@ depends on the existing profiles you have set in your configuration file switch { case scopes != "": // Explicit --scopes flag takes precedence. - for _, s := range strings.Split(scopes, ",") { - scopesList = append(scopesList, strings.TrimSpace(s)) - } + scopesList = splitScopes(scopes) case existingProfile != nil && existingProfile.Scopes != "": // Preserve scopes from the existing profile so re-login // uses the same scopes the user previously configured. - for _, s := range strings.Split(existingProfile.Scopes, ",") { - scopesList = append(scopesList, strings.TrimSpace(s)) - } + scopesList = splitScopes(existingProfile.Scopes) } oauthArgument, err := authArguments.ToOAuthArgument() @@ -400,6 +447,43 @@ func loadProfileByName(ctx context.Context, profileName string, profiler profile return nil, nil } +// shouldUseDiscovery returns true if the discovery flow should be used +// (no host available from any source). +func shouldUseDiscovery(hostFlag string, args []string, existingProfile *profile.Profile) bool { + if hostFlag != "" { + return false + } + if len(args) > 0 { + return false + } + if existingProfile != nil && existingProfile.Host != "" { + return false + } + return true +} + +// discoveryIncompatibleFlags lists flags that require --host and are incompatible +// with the discovery login flow via login.databricks.com. +var discoveryIncompatibleFlags = []string{ + "account-id", + "workspace-id", + "experimental-is-unified-host", + "configure-cluster", + "configure-serverless", +} + +// validateDiscoveryFlagCompatibility returns an error if any flags that require +// --host were explicitly set. These flags are meaningless in discovery mode +// and could lead to incorrect profile configuration. +func validateDiscoveryFlagCompatibility(cmd *cobra.Command) error { + for _, name := range discoveryIncompatibleFlags { + if cmd.Flag(name).Changed { + return fmt.Errorf("--%s requires --host to be specified", name) + } + } + return nil +} + // openURLSuppressingStderr opens a URL in the browser while suppressing stderr output. // This prevents xdg-open error messages from being displayed to the user. func openURLSuppressingStderr(url string) error { @@ -416,6 +500,123 @@ func openURLSuppressingStderr(url string) error { return browserpkg.OpenURL(url) } +// discoveryLogin runs the login.databricks.com discovery flow. The user +// authenticates in the browser, selects a workspace, and the CLI receives +// the workspace host from the OAuth callback's iss parameter. +func discoveryLogin(ctx context.Context, dc discoveryClient, profileName string, timeout time.Duration, scopes string, existingProfile *profile.Profile, browserFunc func(string) error) error { + arg, err := dc.NewOAuthArgument(profileName) + if err != nil { + return discoveryErr("setting up login.databricks.com", err) + } + + scopesList := splitScopes(scopes) + if len(scopesList) == 0 && existingProfile != nil && existingProfile.Scopes != "" { + scopesList = splitScopes(existingProfile.Scopes) + } + + opts := []u2m.PersistentAuthOption{ + u2m.WithOAuthArgument(arg), + u2m.WithBrowser(browserFunc), + u2m.WithDiscoveryLogin(), + } + if len(scopesList) > 0 { + opts = append(opts, u2m.WithScopes(scopesList)) + } + + // Apply timeout before creating PersistentAuth so Challenge() respects it. + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + persistentAuth, err := dc.NewPersistentAuth(ctx, opts...) + if err != nil { + return discoveryErr("setting up login.databricks.com", err) + } + defer persistentAuth.Close() + + cmdio.LogString(ctx, "Opening login.databricks.com in your browser...") + if err := persistentAuth.Challenge(); err != nil { + return discoveryErr("login via login.databricks.com failed", err) + } + + discoveredHost := arg.GetDiscoveredHost() + if discoveredHost == "" { + return discoveryErr("login succeeded but no workspace host was discovered", nil) + } + + // Get the token for introspection + tok, err := persistentAuth.Token() + if err != nil { + return fmt.Errorf("retrieving token after login: %w", err) + } + + // Best-effort introspection for metadata. + var workspaceID string + introspection, err := dc.IntrospectToken(ctx, discoveredHost, tok.AccessToken) + if err != nil { + log.Debugf(ctx, "token introspection failed (non-fatal): %v", err) + } else { + // TODO: Save introspection.AccountID once the SDKs are ready to use + // account_id as part of the profile/cache key. Adding it now would break + // existing auth flows that don't expect account_id on workspace profiles. + workspaceID = introspection.WorkspaceID + + // Warn if the detected account_id differs from what's already saved in the profile. + if existingProfile != nil && existingProfile.AccountID != "" && introspection.AccountID != "" && + existingProfile.AccountID != introspection.AccountID { + log.Warnf(ctx, "detected account ID %q differs from existing profile account ID %q", + introspection.AccountID, existingProfile.AccountID) + } + } + + configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE") + clearKeys := oauthLoginClearKeys() + // Discovery login always produces a workspace-level profile pointing at the + // discovered host. Any previous routing metadata (account_id, workspace_id, + // is_unified_host, cluster_id, serverless_compute_id) from a prior login to + // a different host type must be cleared so they don't leak into the new + // profile. workspace_id is re-added only when introspection succeeds. + clearKeys = append(clearKeys, + "account_id", + "workspace_id", + "experimental_is_unified_host", + "cluster_id", + "serverless_compute_id", + ) + err = databrickscfg.SaveToProfile(ctx, &config.Config{ + Profile: profileName, + Host: discoveredHost, + AuthType: authTypeDatabricksCLI, + WorkspaceID: workspaceID, + Scopes: scopesList, + ConfigFile: configFile, + }, clearKeys...) + if err != nil { + if configFile != "" { + return fmt.Errorf("saving profile %q to %s: %w", profileName, configFile, err) + } + return fmt.Errorf("saving profile %q: %w", profileName, err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Profile %s was successfully saved", profileName)) + return nil +} + +// splitScopes splits a comma-separated scopes string into a trimmed slice. +func splitScopes(scopes string) []string { + var result []string + for _, s := range strings.Split(scopes, ",") { + scope := strings.TrimSpace(s) + if scope == "" { + continue + } + result = append(result, scope) + } + if len(result) == 0 { + return nil + } + return result +} + // oauthLoginClearKeys returns profile keys that should be explicitly removed // when performing an OAuth login. Derives auth credential fields dynamically // from the SDK's ConfigAttributes to stay in sync as new auth methods are added. diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index bd135bc730..670ff7c211 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -1,17 +1,46 @@ package auth import ( + "bytes" "context" + "errors" + "log/slog" + "os" + "path/filepath" + "sync" "testing" + "time" "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" ) +// logBuffer is a thread-safe bytes.Buffer for capturing log output in tests. +type logBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (lb *logBuffer) Write(p []byte) (int, error) { + lb.mu.Lock() + defer lb.mu.Unlock() + return lb.buf.Write(p) +} + +func (lb *logBuffer) String() string { + lb.mu.Lock() + defer lb.mu.Unlock() + return lb.buf.String() +} + func loadTestProfile(t *testing.T, ctx context.Context, profileName string) *profile.Profile { profile, err := loadProfileByName(ctx, profileName, profile.DefaultProfiler) require.NoError(t, err) @@ -19,6 +48,62 @@ func loadTestProfile(t *testing.T, ctx context.Context, profileName string) *pro return profile } +type fakeDiscoveryPersistentAuth struct { + token *oauth2.Token + challengeErr error + tokenErr error +} + +func (f *fakeDiscoveryPersistentAuth) Challenge() error { + return f.challengeErr +} + +func (f *fakeDiscoveryPersistentAuth) Token() (*oauth2.Token, error) { + if f.tokenErr != nil { + return nil, f.tokenErr + } + return f.token, nil +} + +func (f *fakeDiscoveryPersistentAuth) Close() error { + return nil +} + +type fakeDiscoveryClient struct { + oauthArg *u2m.BasicDiscoveryOAuthArgument + oauthArgErr error + persistentAuth discoveryPersistentAuth + persistentAuthErr error + introspection *auth.IntrospectionResult + introspectionErr error + // For assertions + introspectHost string + introspectToken string +} + +func (f *fakeDiscoveryClient) NewOAuthArgument(profileName string) (*u2m.BasicDiscoveryOAuthArgument, error) { + if f.oauthArgErr != nil { + return nil, f.oauthArgErr + } + return f.oauthArg, nil +} + +func (f *fakeDiscoveryClient) NewPersistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (discoveryPersistentAuth, error) { + if f.persistentAuthErr != nil { + return nil, f.persistentAuthErr + } + return f.persistentAuth, nil +} + +func (f *fakeDiscoveryClient) IntrospectToken(ctx context.Context, host, accessToken string) (*auth.IntrospectionResult, error) { + f.introspectHost = host + f.introspectToken = accessToken + if f.introspectionErr != nil { + return nil, f.introspectionErr + } + return f.introspection, nil +} + func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { ctx := t.Context() ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg") @@ -267,3 +352,464 @@ func TestLoadProfileByNameAndClusterID(t *testing.T) { }) } } + +func TestShouldUseDiscovery(t *testing.T) { + tests := []struct { + name string + hostFlag string + args []string + existingProfile *profile.Profile + want bool + }{ + { + name: "no host from any source", + want: true, + }, + { + name: "host from flag", + hostFlag: "https://example.com", + want: false, + }, + { + name: "host from positional arg", + args: []string{"https://example.com"}, + want: false, + }, + { + name: "host from existing profile", + existingProfile: &profile.Profile{Host: "https://example.com"}, + want: false, + }, + { + name: "existing profile without host", + existingProfile: &profile.Profile{Name: "test"}, + want: true, + }, + { + name: "nil profile", + existingProfile: nil, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldUseDiscovery(tt.hostFlag, tt.args, tt.existingProfile) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSplitScopes(t *testing.T) { + tests := []struct { + name string + input string + output []string + }{ + { + name: "empty input", + input: "", + output: nil, + }, + { + name: "single scope", + input: "all-apis", + output: []string{"all-apis"}, + }, + { + name: "trims whitespace", + input: " all-apis , sql ", + output: []string{"all-apis", "sql"}, + }, + { + name: "drops empty entries", + input: "all-apis, ,sql,,", + output: []string{"all-apis", "sql"}, + }, + { + name: "only empty entries", + input: " , , ", + output: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.output, splitScopes(tt.input)) + }) + } +} + +func TestValidateDiscoveryFlagCompatibility(t *testing.T) { + tests := []struct { + name string + setFlag string + flagVal string + wantErr string + }{ + { + name: "account-id is incompatible", + setFlag: "account-id", + flagVal: "abc123", + wantErr: "--account-id requires --host to be specified", + }, + { + name: "workspace-id is incompatible", + setFlag: "workspace-id", + flagVal: "12345", + wantErr: "--workspace-id requires --host to be specified", + }, + { + name: "experimental-is-unified-host is incompatible", + setFlag: "experimental-is-unified-host", + flagVal: "true", + wantErr: "--experimental-is-unified-host requires --host to be specified", + }, + { + name: "configure-cluster is incompatible", + setFlag: "configure-cluster", + flagVal: "true", + wantErr: "--configure-cluster requires --host to be specified", + }, + { + name: "configure-serverless is incompatible", + setFlag: "configure-serverless", + flagVal: "true", + wantErr: "--configure-serverless requires --host to be specified", + }, + { + name: "no flags set is ok", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &cobra.Command{} + cmd.Flags().String("account-id", "", "") + cmd.Flags().String("workspace-id", "", "") + cmd.Flags().Bool("experimental-is-unified-host", false, "") + cmd.Flags().Bool("configure-cluster", false, "") + cmd.Flags().Bool("configure-serverless", false, "") + + if tt.setFlag != "" { + require.NoError(t, cmd.Flags().Set(tt.setFlag, tt.flagVal)) + } + + err := validateDiscoveryFlagCompatibility(cmd) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDiscoveryLogin_IntrospectionFailureStillSavesProfile(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspectionErr: errors.New("introspection failed"), + } + + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "all-apis, ,sql,", nil, func(string) error { return nil }) + require.NoError(t, err) + + assert.Equal(t, "https://workspace.example.com", dc.introspectHost) + assert.Equal(t, "test-token", dc.introspectToken) + + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "https://workspace.example.com", savedProfile.Host) + assert.Equal(t, "all-apis,sql", savedProfile.Scopes) + assert.Empty(t, savedProfile.AccountID) + assert.Empty(t, savedProfile.WorkspaceID) +} + +func TestDiscoveryLogin_AccountIDMismatchWarning(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspection: &auth.IntrospectionResult{ + AccountID: "new-account-id", + WorkspaceID: "12345", + }, + } + + // Set up a logger that captures log records to verify the warning. + var logBuf logBuffer + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelWarn})) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + ctx = log.NewContext(ctx, logger) + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + AccountID: "old-account-id", + } + + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + require.NoError(t, err) + + // Verify warning about mismatched account IDs was logged. + assert.Contains(t, logBuf.String(), "new-account-id") + assert.Contains(t, logBuf.String(), "old-account-id") + + // Verify the profile was saved without account_id (not overwritten). + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "https://workspace.example.com", savedProfile.Host) + assert.Equal(t, "12345", savedProfile.WorkspaceID) +} + +func TestDiscoveryLogin_NoWarningWhenAccountIDsMatch(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspection: &auth.IntrospectionResult{ + AccountID: "same-account-id", + WorkspaceID: "12345", + }, + } + + var logBuf logBuffer + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelWarn})) + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + ctx = log.NewContext(ctx, logger) + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + AccountID: "same-account-id", + } + + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + require.NoError(t, err) + + // No warning should be logged when account IDs match. + assert.Empty(t, logBuf.String()) +} + +func TestDiscoveryLogin_EmptyDiscoveredHostReturnsError(t *testing.T) { + // Return arg without calling SetDiscoveredHost, so GetDiscoveredHost returns "". + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + } + + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) + require.Error(t, err) + assert.Contains(t, err.Error(), "no workspace host was discovered") +} + +func TestDiscoveryLogin_ReloginPreservesExistingProfileScopes(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspectionErr: errors.New("introspection failed"), + } + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + Host: "https://old-workspace.example.com", + Scopes: "sql,clusters", + } + + // No --scopes flag (empty string), should fall back to existing profile scopes. + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + require.NoError(t, err) + + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "https://workspace.example.com", savedProfile.Host) + assert.Equal(t, "sql,clusters", savedProfile.Scopes) +} + +func TestDiscoveryLogin_ExplicitScopesOverrideExistingProfile(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + err := os.WriteFile(configPath, []byte(""), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://workspace.example.com") + + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspectionErr: errors.New("introspection failed"), + } + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + Host: "https://old-workspace.example.com", + Scopes: "sql,clusters", + } + + // Explicit --scopes flag should override existing profile scopes. + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "all-apis", existingProfile, func(string) error { return nil }) + require.NoError(t, err) + + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "all-apis", savedProfile.Scopes) +} + +func TestDiscoveryLogin_ClearsStaleRoutingFieldsFromUnifiedProfile(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + + // Pre-populate a profile that looks like an older hostless/unified login. + initialConfig := `[DISCOVERY] +host = https://old-unified.databricks.com +account_id = old-account +workspace_id = 999999 +experimental_is_unified_host = true +auth_type = databricks-cli +` + err := os.WriteFile(configPath, []byte(initialConfig), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://new-workspace.example.com") + + // Introspection fails, so workspace_id should be cleared (not left stale). + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspectionErr: errors.New("introspection unavailable"), + } + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + Host: "https://old-unified.databricks.com", + AccountID: "old-account", + WorkspaceID: "999999", + IsUnifiedHost: true, + } + + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + require.NoError(t, err) + + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "https://new-workspace.example.com", savedProfile.Host) + // Stale routing fields must be cleared. + assert.Empty(t, savedProfile.AccountID, "stale account_id should be cleared") + assert.Empty(t, savedProfile.WorkspaceID, "stale workspace_id should be cleared on introspection failure") + assert.False(t, savedProfile.IsUnifiedHost, "stale experimental_is_unified_host should be cleared") +} + +func TestDiscoveryLogin_IntrospectionWritesFreshWorkspaceID(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".databrickscfg") + + // Pre-populate with stale workspace_id. + initialConfig := `[DISCOVERY] +host = https://old.example.com +workspace_id = 111111 +auth_type = databricks-cli +` + err := os.WriteFile(configPath, []byte(initialConfig), 0o600) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configPath) + + oauthArg, err := u2m.NewBasicDiscoveryOAuthArgument("DISCOVERY") + require.NoError(t, err) + oauthArg.SetDiscoveredHost("https://new-workspace.example.com") + + // Introspection succeeds with a fresh workspace_id. + dc := &fakeDiscoveryClient{ + oauthArg: oauthArg, + persistentAuth: &fakeDiscoveryPersistentAuth{ + token: &oauth2.Token{AccessToken: "test-token"}, + }, + introspection: &auth.IntrospectionResult{ + AccountID: "fresh-account", + WorkspaceID: "222222", + }, + } + + existingProfile := &profile.Profile{ + Name: "DISCOVERY", + Host: "https://old.example.com", + WorkspaceID: "111111", + } + + ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) + err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + require.NoError(t, err) + + savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) + require.NoError(t, err) + require.NotNil(t, savedProfile) + assert.Equal(t, "https://new-workspace.example.com", savedProfile.Host) + assert.Equal(t, "222222", savedProfile.WorkspaceID, "workspace_id should be updated to fresh introspection value") +} diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 5f695ce6bc..79f99726be 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -421,9 +421,7 @@ func runInlineLogin(ctx context.Context, profiler profile.Profiler) (string, *pr // uses the same scopes the user previously configured. var scopesList []string if existingProfile != nil && existingProfile.Scopes != "" { - for _, s := range strings.Split(existingProfile.Scopes, ",") { - scopesList = append(scopesList, strings.TrimSpace(s)) - } + scopesList = splitScopes(existingProfile.Scopes) } oauthArgument, err := loginArgs.ToOAuthArgument() diff --git a/libs/auth/introspect.go b/libs/auth/introspect.go new file mode 100644 index 0000000000..6558801014 --- /dev/null +++ b/libs/auth/introspect.go @@ -0,0 +1,69 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" +) + +// IntrospectionResponse represents the response from the Databricks token +// introspection endpoint at /api/2.0/tokens/introspect. +type IntrospectionResponse struct { + PrincipalContext struct { + AuthenticationScope struct { + AccountID string `json:"account_id"` + WorkspaceID int64 `json:"workspace_id"` + } `json:"authentication_scope"` + } `json:"principal_context"` +} + +// IntrospectionResult contains the extracted metadata from token introspection. +type IntrospectionResult struct { + AccountID string + WorkspaceID string +} + +// IntrospectToken calls the workspace token introspection endpoint to extract +// account_id and workspace_id for the given access token. Returns an error +// if the request fails or the response cannot be parsed. Callers should treat +// errors as non-fatal (best-effort metadata enrichment). +func IntrospectToken(ctx context.Context, host, accessToken string, httpClient *http.Client) (*IntrospectionResult, error) { + if httpClient == nil { + httpClient = http.DefaultClient + } + endpoint := strings.TrimSuffix(host, "/") + "/api/2.0/tokens/introspect" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("creating introspection request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("calling introspection endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // Drain the body so the underlying TCP connection can be reused. + _, _ = io.Copy(io.Discard, resp.Body) + return nil, fmt.Errorf("introspection endpoint returned status %d", resp.StatusCode) + } + + var introspection IntrospectionResponse + if err := json.NewDecoder(resp.Body).Decode(&introspection); err != nil { + return nil, fmt.Errorf("decoding introspection response: %w", err) + } + + result := &IntrospectionResult{ + AccountID: introspection.PrincipalContext.AuthenticationScope.AccountID, + } + if introspection.PrincipalContext.AuthenticationScope.WorkspaceID != 0 { + result.WorkspaceID = strconv.FormatInt(introspection.PrincipalContext.AuthenticationScope.WorkspaceID, 10) + } + return result, nil +} diff --git a/libs/auth/introspect_test.go b/libs/auth/introspect_test.go new file mode 100644 index 0000000000..3ea5f10717 --- /dev/null +++ b/libs/auth/introspect_test.go @@ -0,0 +1,112 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIntrospectToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "principal_context": { + "authentication_scope": { + "account_id": "a1b1c234-5678-90ab-cdef-1234567890ab", + "workspace_id": 2548836972759138 + } + } + }`)) + })) + defer server.Close() + + result, err := IntrospectToken(t.Context(), server.URL, "test-token", nil) + require.NoError(t, err) + assert.Equal(t, "a1b1c234-5678-90ab-cdef-1234567890ab", result.AccountID) + assert.Equal(t, "2548836972759138", result.WorkspaceID) +} + +func TestIntrospectToken_ZeroWorkspaceID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "principal_context": { + "authentication_scope": { + "account_id": "abc-123", + "workspace_id": 0 + } + } + }`)) + })) + defer server.Close() + + result, err := IntrospectToken(t.Context(), server.URL, "test-token", nil) + require.NoError(t, err) + assert.Equal(t, "abc-123", result.AccountID) + assert.Empty(t, result.WorkspaceID) +} + +func TestIntrospectToken_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + _, err := IntrospectToken(t.Context(), server.URL, "test-token", nil) + assert.ErrorContains(t, err, "status 403") +} + +func TestIntrospectToken_MalformedJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`not json`)) + })) + defer server.Close() + + _, err := IntrospectToken(t.Context(), server.URL, "test-token", nil) + assert.ErrorContains(t, err, "decoding introspection response") +} + +func TestIntrospectToken_VerifyRequestDetails(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/2.0/tokens/introspect", r.URL.Path) + assert.Equal(t, "Bearer my-secret-token", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"principal_context":{"authentication_scope":{"account_id":"x","workspace_id":1}}}`)) + })) + defer server.Close() + + _, err := IntrospectToken(t.Context(), server.URL, "my-secret-token", nil) + require.NoError(t, err) +} + +// roundTripFunc adapts a function into an http.RoundTripper. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestIntrospectToken_UsesSuppliedHTTPClient(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"principal_context":{"authentication_scope":{"account_id":"x","workspace_id":1}}}`)) + })) + defer server.Close() + + var transportUsed bool + customClient := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + transportUsed = true + return http.DefaultTransport.RoundTrip(req) + }), + } + + result, err := IntrospectToken(t.Context(), server.URL, "test-token", customClient) + require.NoError(t, err) + assert.True(t, transportUsed, "expected IntrospectToken to use the supplied HTTP client") + assert.Equal(t, "x", result.AccountID) +}