diff --git a/internal/cmd/auth_credentials.go b/internal/cmd/auth_credentials.go index 61d52356..97bc3ecd 100644 --- a/internal/cmd/auth_credentials.go +++ b/internal/cmd/auth_credentials.go @@ -15,8 +15,9 @@ import ( ) type AuthCredentialsCmd struct { - Set AuthCredentialsSetCmd `cmd:"" default:"withargs" help:"Store OAuth client credentials"` - List AuthCredentialsListCmd `cmd:"" name:"list" help:"List stored OAuth client credentials"` + Set AuthCredentialsSetCmd `cmd:"" default:"withargs" help:"Store OAuth client credentials"` + List AuthCredentialsListCmd `cmd:"" name:"list" help:"List stored OAuth client credentials"` + Remove AuthCredentialsRemoveCmd `cmd:"" name:"remove" help:"Remove stored OAuth client credentials"` } type AuthCredentialsSetCmd struct { @@ -160,3 +161,149 @@ func (c *AuthCredentialsListCmd) Run(ctx context.Context, _ *RootFlags) error { } return nil } + +type AuthCredentialsRemoveCmd struct { + Client string `arg:"" optional:"" name:"client" help:"Client name to remove (omit for default, or 'all' to remove every client)"` +} + +func (c *AuthCredentialsRemoveCmd) Run(ctx context.Context, flags *RootFlags) error { + u := ui.FromContext(ctx) + + // Determine target client(s): explicit arg > --client flag > default. + target := strings.TrimSpace(c.Client) + if target == "" { + t, err := normalizeClientForFlag(authclient.ClientOverrideFromContext(ctx)) + if err != nil { + return err + } + target = t + } + + if strings.EqualFold(target, "all") { + return c.removeAll(ctx, flags, u) + } + + client, err := config.NormalizeClientNameOrDefault(target) + if err != nil { + return err + } + + accounts := findAccountsForClient(client) + + action := fmt.Sprintf("remove OAuth credentials for client %q", client) + if len(accounts) > 0 { + action += fmt.Sprintf(" and %d associated token(s) (%s)", len(accounts), strings.Join(accounts, ", ")) + } + if err := confirmDestructive(ctx, flags, action); err != nil { + return err + } + + if err := config.DeleteClientCredentialsFor(client); err != nil { + return err + } + + tokensRemoved := removeTokensForClient(client, accounts) + domainsRemoved := removeDomainMappings(client) + + return writeResult(ctx, u, + kv("removed", true), + kv("client", client), + kv("tokens_removed", tokensRemoved), + kv("domains_removed", domainsRemoved), + ) +} + +func (c *AuthCredentialsRemoveCmd) removeAll(ctx context.Context, flags *RootFlags, u *ui.UI) error { + creds, err := config.ListClientCredentials() + if err != nil { + return err + } + if len(creds) == 0 { + return writeResult(ctx, u, kv("removed", 0)) + } + + names := make([]string, 0, len(creds)) + for _, info := range creds { + names = append(names, info.Client) + } + if err := confirmDestructive(ctx, flags, fmt.Sprintf("remove all OAuth credentials (%s)", strings.Join(names, ", "))); err != nil { + return err + } + + var allTokens []string + for _, info := range creds { + accounts := findAccountsForClient(info.Client) + if err := config.DeleteClientCredentialsFor(info.Client); err != nil { + return err + } + allTokens = append(allTokens, removeTokensForClient(info.Client, accounts)...) + removeDomainMappings(info.Client) + } + + return writeResult(ctx, u, + kv("removed", len(creds)), + kv("clients", names), + kv("tokens_removed", allTokens), + ) +} + +// findAccountsForClient returns emails that have tokens stored under the given client. +func findAccountsForClient(client string) []string { + store, err := openSecretsStore() + if err != nil { + return nil + } + tokens, err := store.ListTokens() + if err != nil { + return nil + } + var emails []string + for _, tok := range tokens { + tokClient, _ := config.NormalizeClientNameOrDefault(tok.Client) + if tokClient == client { + emails = append(emails, tok.Email) + } + } + return emails +} + +// removeTokensForClient deletes tokens for the given accounts under the specified client. +func removeTokensForClient(client string, emails []string) []string { + if len(emails) == 0 { + return nil + } + store, err := openSecretsStore() + if err != nil { + return nil + } + var removed []string + for _, email := range emails { + if err := store.DeleteToken(client, email); err == nil { + removed = append(removed, email) + } + } + return removed +} + +// removeDomainMappings deletes config domain entries that point to the given client. +func removeDomainMappings(client string) []string { + cfg, err := config.ReadConfig() + if err != nil { + return nil + } + var removed []string + for domain, mapped := range cfg.ClientDomains { + normalized, nerr := config.NormalizeClientNameOrDefault(mapped) + if nerr != nil { + continue + } + if normalized == client { + removed = append(removed, domain) + delete(cfg.ClientDomains, domain) + } + } + if len(removed) > 0 { + _ = config.WriteConfig(cfg) + } + return removed +} diff --git a/internal/config/credentials.go b/internal/config/credentials.go index 529ac8c5..67283f4d 100644 --- a/internal/config/credentials.go +++ b/internal/config/credentials.go @@ -114,6 +114,23 @@ func ReadClientCredentialsFor(client string) (ClientCredentials, error) { return c, nil } +func DeleteClientCredentialsFor(client string) error { + path, err := ClientCredentialsPathFor(client) + if err != nil { + return fmt.Errorf("resolve credentials path: %w", err) + } + + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + return &CredentialsMissingError{Path: path, Cause: err} + } + + return fmt.Errorf("delete credentials: %w", err) + } + + return nil +} + func ClientCredentialsExists(client string) (bool, error) { path, err := ClientCredentialsPathFor(client) if err != nil {