From 20dfd85ae1b44de925e86c338b558e4c50b76b15 Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Tue, 3 Feb 2026 15:06:31 +0100 Subject: [PATCH 01/12] refactor: add rootmanager API and move packages under internal --- cmd/audit.go | 11 +++---- cmd/check.go | 8 ++--- cmd/delete.go | 11 +++---- cmd/enable.go | 8 ++--- cmd/recovery.go | 11 +++---- {pkg => internal/cli}/output/csv.go | 0 {pkg => internal/cli}/output/json.go | 0 {pkg => internal/cli}/output/output.go | 2 +- {pkg => internal/cli}/output/table.go | 0 .../cli/ui/account_selector.go | 24 +++++++-------- {pkg => internal/cli}/ui/selector.go | 0 {pkg => internal/infra}/aws/config.go | 0 {pkg => internal/infra}/aws/iam.go | 2 +- {pkg => internal/infra}/aws/organizations.go | 2 +- {pkg => internal/infra}/aws/sts.go | 2 +- {pkg => internal}/logger/logger.go | 0 {pkg => internal}/service/audit.go | 4 +-- {pkg => internal}/service/configuration.go | 4 +-- {pkg => internal}/service/credentials.go | 4 +-- rootmanager/api.go | 29 +++++++++++++++++++ rootmanager/errors.go | 17 +++++++++++ rootmanager/types.go | 24 +++++++++++++++ 22 files changed, 118 insertions(+), 45 deletions(-) rename {pkg => internal/cli}/output/csv.go (100%) rename {pkg => internal/cli}/output/json.go (100%) rename {pkg => internal/cli}/output/output.go (94%) rename {pkg => internal/cli}/output/table.go (100%) rename pkg/service/accounts.go => internal/cli/ui/account_selector.go (71%) rename {pkg => internal/cli}/ui/selector.go (100%) rename {pkg => internal/infra}/aws/config.go (100%) rename {pkg => internal/infra}/aws/iam.go (99%) rename {pkg => internal/infra}/aws/organizations.go (98%) rename {pkg => internal/infra}/aws/sts.go (97%) rename {pkg => internal}/logger/logger.go (100%) rename {pkg => internal}/service/audit.go (95%) rename {pkg => internal}/service/configuration.go (94%) rename {pkg => internal}/service/credentials.go (97%) create mode 100644 rootmanager/api.go create mode 100644 rootmanager/errors.go create mode 100644 rootmanager/types.go diff --git a/cmd/audit.go b/cmd/audit.go index b9b0f56..646ee03 100644 --- a/cmd/audit.go +++ b/cmd/audit.go @@ -5,10 +5,11 @@ import ( "fmt" "strings" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" - "github.com/unicrons/aws-root-manager/pkg/output" - "github.com/unicrons/aws-root-manager/pkg/service" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/cli/ui" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) @@ -28,7 +29,7 @@ var auditCmd = &cobra.Command{ return err } - auditAccounts, err := service.GetTargetAccounts(ctx, accountsFlags) + auditAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) if err != nil { logger.Error("cmd.audit", err, "failed to get accounts to audit") return err diff --git a/cmd/check.go b/cmd/check.go index 3594aaa..f39c99c 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -4,10 +4,10 @@ import ( "context" "strconv" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" - "github.com/unicrons/aws-root-manager/pkg/output" - "github.com/unicrons/aws-root-manager/pkg/service" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) diff --git a/cmd/delete.go b/cmd/delete.go index 592fc23..41ac146 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -5,10 +5,11 @@ import ( "fmt" "strings" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" - "github.com/unicrons/aws-root-manager/pkg/output" - "github.com/unicrons/aws-root-manager/pkg/service" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/cli/ui" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) @@ -116,7 +117,7 @@ func delete(accountsFlags []string, credentialType string) error { return fmt.Errorf("failed to load aws config: %w", err) } - auditAccounts, err := service.GetTargetAccounts(ctx, accountsFlags) + auditAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) if err != nil { return fmt.Errorf("failed to get accounts to audit: %w", err) } diff --git a/cmd/enable.go b/cmd/enable.go index 81e8ae2..200a0d0 100644 --- a/cmd/enable.go +++ b/cmd/enable.go @@ -4,10 +4,10 @@ import ( "context" "strconv" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" - "github.com/unicrons/aws-root-manager/pkg/output" - "github.com/unicrons/aws-root-manager/pkg/service" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) diff --git a/cmd/recovery.go b/cmd/recovery.go index 89abe33..ad82b21 100644 --- a/cmd/recovery.go +++ b/cmd/recovery.go @@ -4,10 +4,11 @@ import ( "context" "strings" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" - "github.com/unicrons/aws-root-manager/pkg/output" - "github.com/unicrons/aws-root-manager/pkg/service" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/cli/ui" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) @@ -26,7 +27,7 @@ var recoveryCmd = &cobra.Command{ return } - targetAccounts, err := service.GetTargetAccounts(ctx, accountsFlags) + targetAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) if err != nil { logger.Error("cmd.recovery", err, "failed to get target accounts") } diff --git a/pkg/output/csv.go b/internal/cli/output/csv.go similarity index 100% rename from pkg/output/csv.go rename to internal/cli/output/csv.go diff --git a/pkg/output/json.go b/internal/cli/output/json.go similarity index 100% rename from pkg/output/json.go rename to internal/cli/output/json.go diff --git a/pkg/output/output.go b/internal/cli/output/output.go similarity index 94% rename from pkg/output/output.go rename to internal/cli/output/output.go index dc55923..610a937 100644 --- a/pkg/output/output.go +++ b/internal/cli/output/output.go @@ -3,7 +3,7 @@ package output import ( "fmt" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/logger" ) // HandleOutput handles the output based on the specified format diff --git a/pkg/output/table.go b/internal/cli/output/table.go similarity index 100% rename from pkg/output/table.go rename to internal/cli/output/table.go diff --git a/pkg/service/accounts.go b/internal/cli/ui/account_selector.go similarity index 71% rename from pkg/service/accounts.go rename to internal/cli/ui/account_selector.go index f59fca8..f253d92 100644 --- a/pkg/service/accounts.go +++ b/internal/cli/ui/account_selector.go @@ -1,12 +1,11 @@ -package service +package ui import ( "context" "fmt" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" - "github.com/unicrons/aws-root-manager/pkg/ui" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" ) const ( @@ -14,13 +13,14 @@ const ( AllAccountsSelectorText = "all non management accounts" ) -// Get target AWS accounts based on input flags or user interaction -func GetTargetAccounts(ctx context.Context, accounts []string) ([]string, error) { - logger.Trace("service.GetTargetAccounts", "processing target accounts: %s", accounts) +// SelectTargetAccounts handles interactive account selection or returns accounts based on flags. +// Returns account IDs based on flags or TUI prompt. +func SelectTargetAccounts(ctx context.Context, accountsFlag []string) ([]string, error) { + logger.Trace("ui.SelectTargetAccounts", "processing target accounts: %s", accountsFlag) // if accounts are provided and "all" is not specified, return them - if len(accounts) > 0 && accounts[0] != AllAccountsOption { - return accounts, nil + if len(accountsFlag) > 0 && accountsFlag[0] != AllAccountsOption { + return accountsFlag, nil } // fetch all non-management accounts @@ -30,7 +30,7 @@ func GetTargetAccounts(ctx context.Context, accounts []string) ([]string, error) } // if "all" is specified, return all account IDs - if len(accounts) > 0 && accounts[0] == AllAccountsOption { + if len(accountsFlag) > 0 && accountsFlag[0] == AllAccountsOption { return convertAccountsToIDs(orgAccounts), nil } @@ -40,7 +40,7 @@ func GetTargetAccounts(ctx context.Context, accounts []string) ([]string, error) for _, account := range orgAccounts { selectorChoices = append(selectorChoices, fmt.Sprintf("%s - %s", account.AccountID, account.Name)) } - selectedIndexes, err := ui.Prompt("Please select the AWS accounts to audit", selectorChoices) + selectedIndexes, err := Prompt("Please select the AWS accounts to audit", selectorChoices) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func GetTargetAccounts(ctx context.Context, accounts []string) ([]string, error) // Resolve selected accounts if allSelected(selectedIndexes) { - logger.Debug("service.GetTargetAccounts", "all accounts selected") + logger.Debug("ui.SelectTargetAccounts", "all accounts selected") return convertAccountsToIDs(orgAccounts), nil } diff --git a/pkg/ui/selector.go b/internal/cli/ui/selector.go similarity index 100% rename from pkg/ui/selector.go rename to internal/cli/ui/selector.go diff --git a/pkg/aws/config.go b/internal/infra/aws/config.go similarity index 100% rename from pkg/aws/config.go rename to internal/infra/aws/config.go diff --git a/pkg/aws/iam.go b/internal/infra/aws/iam.go similarity index 99% rename from pkg/aws/iam.go rename to internal/infra/aws/iam.go index 7bda6b3..653129a 100644 --- a/pkg/aws/iam.go +++ b/internal/infra/aws/iam.go @@ -6,7 +6,7 @@ import ( "fmt" "slices" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/logger" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/iam" diff --git a/pkg/aws/organizations.go b/internal/infra/aws/organizations.go similarity index 98% rename from pkg/aws/organizations.go rename to internal/infra/aws/organizations.go index 37bfcd3..0372ace 100644 --- a/pkg/aws/organizations.go +++ b/internal/infra/aws/organizations.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/logger" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/organizations" diff --git a/pkg/aws/sts.go b/internal/infra/aws/sts.go similarity index 97% rename from pkg/aws/sts.go rename to internal/infra/aws/sts.go index 46c9294..16796e3 100644 --- a/pkg/aws/sts.go +++ b/internal/infra/aws/sts.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/logger" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts" diff --git a/pkg/logger/logger.go b/internal/logger/logger.go similarity index 100% rename from pkg/logger/logger.go rename to internal/logger/logger.go diff --git a/pkg/service/audit.go b/internal/service/audit.go similarity index 95% rename from pkg/service/audit.go rename to internal/service/audit.go index 53741d5..957a416 100644 --- a/pkg/service/audit.go +++ b/internal/service/audit.go @@ -4,8 +4,8 @@ import ( "context" "sync" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" ) // Get root credentials for a list of AWS accounts. diff --git a/pkg/service/configuration.go b/internal/service/configuration.go similarity index 94% rename from pkg/service/configuration.go rename to internal/service/configuration.go index 52b1862..9b85cd4 100644 --- a/pkg/service/configuration.go +++ b/internal/service/configuration.go @@ -3,8 +3,8 @@ package service import ( "context" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" ) func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (aws.RootAccessStatus, error) { diff --git a/pkg/service/credentials.go b/internal/service/credentials.go similarity index 97% rename from pkg/service/credentials.go rename to internal/service/credentials.go index 4708798..24d23db 100644 --- a/pkg/service/credentials.go +++ b/internal/service/credentials.go @@ -4,8 +4,8 @@ import ( "context" "sync" - "github.com/unicrons/aws-root-manager/pkg/aws" - "github.com/unicrons/aws-root-manager/pkg/logger" + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/internal/logger" ) // Delete root credentials for a list of AWS accounts diff --git a/rootmanager/api.go b/rootmanager/api.go new file mode 100644 index 0000000..3281eba --- /dev/null +++ b/rootmanager/api.go @@ -0,0 +1,29 @@ +package rootmanager + +import "context" + +// RootManager provides operations for managing AWS root credentials for an AWS organization. +type RootManager interface { + // AuditAccounts audits root credentials across the specified AWS accounts. + // It checks for login profiles, access keys, MFA devices, and signing certificates. + AuditAccounts(ctx context.Context, accountIds []string) ([]RootCredentials, error) + + // CheckRootAccess checks the status of centralized root access features in the organization. + // It verifies whether trusted access, root credentials management, and root sessions are enabled. + CheckRootAccess(ctx context.Context) (RootAccessStatus, error) + + // EnableRootAccess enables centralized root access features in the organization. + // The enableSessions parameter controls whether to enable root sessions (AssumeRoot). + // Returns the initial status, final status after enabling, and any error encountered. + EnableRootAccess(ctx context.Context, enableSessions bool) (RootAccessStatus, RootAccessStatus, error) + + // DeleteCredentials deletes root credentials for the specified accounts. + // The creds parameter should contain audit results identifying what credentials exist. + // The credentialType parameter specifies what to delete: "all", "login", "keys", "mfa", or "certificate". + DeleteCredentials(ctx context.Context, creds []RootCredentials, credentialType string) error + + // RecoverRootPassword initiates root password recovery for the specified accounts. + // This triggers AWS to send password reset emails to the account's root email address. + // Returns a map of account ID to success status, and any error encountered. + RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) +} diff --git a/rootmanager/errors.go b/rootmanager/errors.go new file mode 100644 index 0000000..e275859 --- /dev/null +++ b/rootmanager/errors.go @@ -0,0 +1,17 @@ +package rootmanager + +import "errors" + +var ( + // ErrTrustedAccessNotEnabled indicates AWS IAM does not have trusted access to the organization. + ErrTrustedAccessNotEnabled = errors.New("AWS IAM trusted access is not enabled for the organization") + + // ErrRootCredentialsManagementNotEnabled indicates centralized root credentials management is not enabled. + ErrRootCredentialsManagementNotEnabled = errors.New("centralized root credentials management is not enabled") + + // ErrRootSessionsNotEnabled indicates root sessions (AssumeRoot) are not enabled. + ErrRootSessionsNotEnabled = errors.New("root sessions are not enabled") + + // ErrEntityAlreadyExists indicates the requested entity already exists. + ErrEntityAlreadyExists = errors.New("entity already exists") +) diff --git a/rootmanager/types.go b/rootmanager/types.go new file mode 100644 index 0000000..94ba336 --- /dev/null +++ b/rootmanager/types.go @@ -0,0 +1,24 @@ +package rootmanager + +// RootAccessStatus represents the status of centralized root access features in an AWS Organization. +type RootAccessStatus struct { + TrustedAccess bool // Whether AWS IAM has trusted access to the organization + RootCredentialsManagement bool // Whether centralized root credentials management is enabled + RootSessions bool // Whether root sessions (assume root) are enabled +} + +// RootCredentials represents the root user credentials for an AWS account. +type RootCredentials struct { + AccountId string // AWS account ID + LoginProfile bool // Whether a root password exists + AccessKeys []string // List of root access key IDs + MfaDevices []string // List of root MFA device serial numbers + SigningCertificates []string // List of root signing certificate IDs + Error string // Error message if audit failed for this account +} + +// RecoveryResult represents the result of a root password recovery operation for an account. +type RecoveryResult struct { + AccountId string // AWS account ID + Success bool // Whether recovery email was successfully sent +} From d47695f2b1cea13fd62e137b0992724b4ae19759 Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Tue, 3 Feb 2026 16:13:41 +0100 Subject: [PATCH 02/12] refactor: consolidate types and errors --- cmd/audit.go | 5 +- cmd/check.go | 10 +-- cmd/delete.go | 8 +- cmd/enable.go | 6 +- cmd/recovery.go | 6 +- internal/infra/aws/iam.go | 16 ++-- internal/service/configuration.go | 8 +- internal/service/credentials.go | 4 +- internal/service/rootmanager.go | 121 ++++++++++++++++++++++++++++++ 9 files changed, 148 insertions(+), 36 deletions(-) create mode 100644 internal/service/rootmanager.go diff --git a/cmd/audit.go b/cmd/audit.go index 646ee03..fd49549 100644 --- a/cmd/audit.go +++ b/cmd/audit.go @@ -40,9 +40,8 @@ var auditCmd = &cobra.Command{ } logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) - iam := aws.NewIamClient(awscfg) - sts := aws.NewStsClient(awscfg) - audit, err := service.AuditAccounts(ctx, iam, sts, auditAccounts) + rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) + audit, err := rm.AuditAccounts(ctx, auditAccounts) if err != nil { logger.Error("cmd.audit", err, "failed to audit accounts") return err diff --git a/cmd/check.go b/cmd/check.go index f39c99c..39d1980 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -26,8 +26,8 @@ var checkCmd = &cobra.Command{ return } - iam := aws.NewIamClient(awscfg) - test, err := service.CheckRootAccess(ctx, iam) + rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, nil) + status, err := rm.CheckRootAccess(ctx) if err != nil { logger.Error("cmd.check", err, "failed to check root access configuration") return @@ -35,9 +35,9 @@ var checkCmd = &cobra.Command{ headers := []string{"Name", "Status"} data := [][]any{ - {"TrustedAccess", strconv.FormatBool(test.TrustedAccess)}, - {"RootCredentialsManagement", strconv.FormatBool(test.RootCredentialsManagement)}, - {"RootSessions", strconv.FormatBool(test.RootSessions)}, + {"TrustedAccess", strconv.FormatBool(status.TrustedAccess)}, + {"RootCredentialsManagement", strconv.FormatBool(status.RootCredentialsManagement)}, + {"RootSessions", strconv.FormatBool(status.RootSessions)}, } output.HandleOutput(outputFlag, headers, data) }, diff --git a/cmd/delete.go b/cmd/delete.go index 41ac146..afdfd1d 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -127,15 +127,13 @@ func delete(accountsFlags []string, credentialType string) error { } logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) - iam := aws.NewIamClient(awscfg) - sts := aws.NewStsClient(awscfg) - - audit, err := service.AuditAccounts(ctx, iam, sts, auditAccounts) + rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) + audit, err := rm.AuditAccounts(ctx, auditAccounts) if err != nil { return err } - if err = service.DeleteAccountsCredentials(ctx, iam, sts, audit, credentialType); err != nil { + if err = rm.DeleteCredentials(ctx, audit, credentialType); err != nil { return err } diff --git a/cmd/enable.go b/cmd/enable.go index 200a0d0..e1e0330 100644 --- a/cmd/enable.go +++ b/cmd/enable.go @@ -28,10 +28,8 @@ var enableCmd = &cobra.Command{ return } - iam := aws.NewIamClient(awscfg) - org := aws.NewOrganizationsClient(awscfg) - - initStatus, status, err := service.EnableRootAccess(ctx, iam, org, enableRootSessions) + rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, aws.NewOrganizationsClient(awscfg)) + initStatus, status, err := rm.EnableRootAccess(ctx, enableRootSessions) if err != nil { logger.Error("cmd.enable", err, "failed to enable root access") return diff --git a/cmd/recovery.go b/cmd/recovery.go index ad82b21..4351f6f 100644 --- a/cmd/recovery.go +++ b/cmd/recovery.go @@ -37,10 +37,8 @@ var recoveryCmd = &cobra.Command{ } logger.Debug("cmd.recovery", "selected accounts: %s", strings.Join(targetAccounts, ", ")) - iam := aws.NewIamClient(awscfg) - sts := aws.NewStsClient(awscfg) - - resultMap, err := service.RecoverAccountsRootPassword(ctx, iam, sts, targetAccounts) + rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) + resultMap, err := rm.RecoverRootPassword(ctx, targetAccounts) if err != nil { logger.Error("cmd.recovery", err, "failed to recover root password") return diff --git a/internal/infra/aws/iam.go b/internal/infra/aws/iam.go index 653129a..42084cf 100644 --- a/internal/infra/aws/iam.go +++ b/internal/infra/aws/iam.go @@ -7,19 +7,13 @@ import ( "slices" "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/rootmanager" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/iam/types" ) -var ( - ErrTrustedAccessNotEnabled = errors.New("trustedAccessNotEnabled") - ErrRootCredentialsManagementNotEnabled = errors.New("rootCredentialsManagementNotEnabled") - ErrRootSessionsNotEnabled = errors.New("rootSessionsNotEnabled") - ErrEntityAlreadyExists = errors.New("entityAlreadyExists") -) - type RootAccessStatus struct { TrustedAccess bool RootCredentialsManagement bool @@ -52,14 +46,14 @@ func (c *IamClient) CheckOrganizationRootAccess(ctx context.Context, rootSession if err != nil { var serviceAccessNotEnabledErr *types.ServiceAccessNotEnabledException if errors.As(err, &serviceAccessNotEnabledErr) { - return ErrTrustedAccessNotEnabled + return rootmanager.ErrTrustedAccessNotEnabled } return fmt.Errorf("aws.CheckOrganizationRootAccess: failed to list organization features: %w", err) } rootCredentialsManagement := slices.Contains(features.EnabledFeatures, "RootCredentialsManagement") if !rootCredentialsManagement { - return ErrRootCredentialsManagementNotEnabled + return rootmanager.ErrRootCredentialsManagementNotEnabled } if !rootSessionsRequired { @@ -68,7 +62,7 @@ func (c *IamClient) CheckOrganizationRootAccess(ctx context.Context, rootSession rootSessions := slices.Contains(features.EnabledFeatures, "RootSessions") if !rootSessions { - return ErrRootSessionsNotEnabled + return rootmanager.ErrRootSessionsNotEnabled } return nil @@ -244,7 +238,7 @@ func (c *IamClient) CreateLoginProfile(ctx context.Context) error { var entityAlreadyExistsErr *types.EntityAlreadyExistsException if errors.As(err, &entityAlreadyExistsErr) { logger.Debug("aws.createLoginProfile", "login profile already exists") - return ErrEntityAlreadyExists + return rootmanager.ErrEntityAlreadyExists } return fmt.Errorf("error creating login profile: %w", err) } diff --git a/internal/service/configuration.go b/internal/service/configuration.go index 9b85cd4..239be4d 100644 --- a/internal/service/configuration.go +++ b/internal/service/configuration.go @@ -2,9 +2,11 @@ package service import ( "context" + "errors" "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/rootmanager" ) func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (aws.RootAccessStatus, error) { @@ -16,17 +18,17 @@ func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (aws.RootAccessSta err := iam.CheckOrganizationRootAccess(ctx, true) if err != nil { - if err == aws.ErrTrustedAccessNotEnabled { + if errors.Is(err, rootmanager.ErrTrustedAccessNotEnabled) { return status, nil } status.TrustedAccess = true - if err == aws.ErrRootCredentialsManagementNotEnabled { + if errors.Is(err, rootmanager.ErrRootCredentialsManagementNotEnabled) { return status, nil } status.RootCredentialsManagement = true - if err == aws.ErrRootSessionsNotEnabled { + if errors.Is(err, rootmanager.ErrRootSessionsNotEnabled) { return status, nil } diff --git a/internal/service/credentials.go b/internal/service/credentials.go index 24d23db..38a5c35 100644 --- a/internal/service/credentials.go +++ b/internal/service/credentials.go @@ -2,10 +2,12 @@ package service import ( "context" + "errors" "sync" "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/rootmanager" ) // Delete root credentials for a list of AWS accounts @@ -156,7 +158,7 @@ func recoverAccountRootPassowrd(ctx context.Context, sts *aws.StsClient, account err = iamRecoverRoot.CreateLoginProfile(ctx) if err != nil { - if err == aws.ErrEntityAlreadyExists { + if errors.Is(err, rootmanager.ErrEntityAlreadyExists) { return false, nil } return false, err diff --git a/internal/service/rootmanager.go b/internal/service/rootmanager.go new file mode 100644 index 0000000..7329306 --- /dev/null +++ b/internal/service/rootmanager.go @@ -0,0 +1,121 @@ +package service + +import ( + "context" + "errors" + + "github.com/unicrons/aws-root-manager/internal/infra/aws" + "github.com/unicrons/aws-root-manager/rootmanager" +) + +// rootManagerImpl implements rootmanager.RootManager and converts between +// rootmanager types (public API) and aws types (infra) at the boundary. +type rootManagerImpl struct { + iam *aws.IamClient + sts *aws.StsClient + org *aws.OrganizationsClient +} + +// NewRootManager returns a RootManager that uses the given AWS clients. +// sts and org may be nil for callers that only use CheckRootAccess. +func NewRootManager(iam *aws.IamClient, sts *aws.StsClient, org *aws.OrganizationsClient) rootmanager.RootManager { + return &rootManagerImpl{iam: iam, sts: sts, org: org} +} + +func (r *rootManagerImpl) AuditAccounts(ctx context.Context, accountIds []string) ([]rootmanager.RootCredentials, error) { + if r.sts == nil { + return nil, errors.New("STS client required for audit") + } + creds, err := AuditAccounts(ctx, r.iam, r.sts, accountIds) + if err != nil { + return nil, err + } + return toRootCredentialsSlice(creds), nil +} + +func (r *rootManagerImpl) CheckRootAccess(ctx context.Context) (rootmanager.RootAccessStatus, error) { + status, err := CheckRootAccess(ctx, r.iam) + if err != nil { + return rootmanager.RootAccessStatus{}, err + } + return toRootAccessStatus(status), nil +} + +func (r *rootManagerImpl) EnableRootAccess(ctx context.Context, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { + if r.org == nil { + return rootmanager.RootAccessStatus{}, rootmanager.RootAccessStatus{}, errors.New("Organizations client required for enable") + } + initStatus, status, err := EnableRootAccess(ctx, r.iam, r.org, enableSessions) + if err != nil { + return toRootAccessStatus(initStatus), toRootAccessStatus(status), err + } + return toRootAccessStatus(initStatus), toRootAccessStatus(status), nil +} + +func (r *rootManagerImpl) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) error { + if r.sts == nil { + return errors.New("STS client required for delete") + } + awsCreds := fromRootCredentialsSlice(creds) + if err := DeleteAccountsCredentials(ctx, r.iam, r.sts, awsCreds, credentialType); err != nil { + return err + } + return nil +} + +func (r *rootManagerImpl) RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) { + if r.sts == nil { + return nil, errors.New("STS client required for recovery") + } + resultMap, err := RecoverAccountsRootPassword(ctx, r.iam, r.sts, accountIds) + if err != nil { + return nil, err + } + return resultMap, nil +} + +func toRootCredentials(c aws.RootCredentials) rootmanager.RootCredentials { + return rootmanager.RootCredentials{ + AccountId: c.AccountId, + LoginProfile: c.LoginProfile, + AccessKeys: append([]string(nil), c.AccessKeys...), + MfaDevices: append([]string(nil), c.MfaDevices...), + SigningCertificates: append([]string(nil), c.SigningCertificates...), + Error: c.Error, + } +} + +func toRootCredentialsSlice(creds []aws.RootCredentials) []rootmanager.RootCredentials { + out := make([]rootmanager.RootCredentials, len(creds)) + for i := range creds { + out[i] = toRootCredentials(creds[i]) + } + return out +} + +func fromRootCredentials(c rootmanager.RootCredentials) aws.RootCredentials { + return aws.RootCredentials{ + AccountId: c.AccountId, + LoginProfile: c.LoginProfile, + AccessKeys: append([]string(nil), c.AccessKeys...), + MfaDevices: append([]string(nil), c.MfaDevices...), + SigningCertificates: append([]string(nil), c.SigningCertificates...), + Error: c.Error, + } +} + +func fromRootCredentialsSlice(creds []rootmanager.RootCredentials) []aws.RootCredentials { + out := make([]aws.RootCredentials, len(creds)) + for i := range creds { + out[i] = fromRootCredentials(creds[i]) + } + return out +} + +func toRootAccessStatus(s aws.RootAccessStatus) rootmanager.RootAccessStatus { + return rootmanager.RootAccessStatus{ + TrustedAccess: s.TrustedAccess, + RootCredentialsManagement: s.RootCredentialsManagement, + RootSessions: s.RootSessions, + } +} From 1a12929f1dd197140b267f3d5b78c97d42348239 Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Tue, 3 Feb 2026 16:19:07 +0100 Subject: [PATCH 03/12] refactor: use single rootmanager type definition --- internal/infra/aws/iam.go | 15 ------- internal/service/audit.go | 15 ++++--- internal/service/configuration.go | 12 ++--- internal/service/credentials.go | 10 ++--- internal/service/rootmanager.go | 73 +++---------------------------- 5 files changed, 24 insertions(+), 101 deletions(-) diff --git a/internal/infra/aws/iam.go b/internal/infra/aws/iam.go index 42084cf..3f2b6d7 100644 --- a/internal/infra/aws/iam.go +++ b/internal/infra/aws/iam.go @@ -14,21 +14,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam/types" ) -type RootAccessStatus struct { - TrustedAccess bool - RootCredentialsManagement bool - RootSessions bool -} - -type RootCredentials struct { - AccountId string - LoginProfile bool - AccessKeys []string - MfaDevices []string - SigningCertificates []string - Error string -} - type IamClient struct { client *iam.Client } diff --git a/internal/service/audit.go b/internal/service/audit.go index 957a416..95e5e48 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -6,13 +6,14 @@ import ( "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/unicrons/aws-root-manager/rootmanager" ) // Get root credentials for a list of AWS accounts. -func AuditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accounts []string) ([]aws.RootCredentials, error) { +func AuditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accounts []string) ([]rootmanager.RootCredentials, error) { logger.Trace("service.AuditAccounts", "auditing accounts %s", accounts) - rootCredentials := make([]aws.RootCredentials, len(accounts)) + rootCredentials := make([]rootmanager.RootCredentials, len(accounts)) var wgAccounts sync.WaitGroup if err := iam.CheckOrganizationRootAccess(ctx, false); err != nil { @@ -25,7 +26,7 @@ func AuditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, defer wgAccounts.Done() if accStatus, err := auditAccount(ctx, sts, accountId); err != nil { logger.Error("service.AuditAccounts", err, "account %s: audit skipped", accountId) - rootCredentials[idx] = aws.RootCredentials{AccountId: accountId, Error: err.Error()} + rootCredentials[idx] = rootmanager.RootCredentials{AccountId: accountId, Error: err.Error()} } else { rootCredentials[idx] = accStatus } @@ -38,16 +39,16 @@ func AuditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, } // Get root credentials for a specific account -func auditAccount(ctx context.Context, sts *aws.StsClient, accountId string) (aws.RootCredentials, error) { +func auditAccount(ctx context.Context, sts *aws.StsClient, accountId string) (rootmanager.RootCredentials, error) { logger.Trace("service.auditAccount", "auditing account %s", accountId) awscfgRoot, err := sts.GetAssumeRootConfig(ctx, accountId, "IAMAuditRootUserCredentials") if err != nil { - return aws.RootCredentials{}, err + return rootmanager.RootCredentials{}, err } iamRoot := aws.NewIamClient(awscfgRoot) - var accountRootCredentials aws.RootCredentials + var accountRootCredentials rootmanager.RootCredentials loginProfile, err := iamRoot.GetLoginProfile(ctx, accountId) if err != nil { @@ -73,7 +74,7 @@ func auditAccount(ctx context.Context, sts *aws.StsClient, accountId string) (aw } logger.Debug("service.AuditAccounts", "account %s - signing_certificates: %s", accountId, certificates) - accountRootCredentials = aws.RootCredentials{ + accountRootCredentials = rootmanager.RootCredentials{ AccountId: accountId, LoginProfile: loginProfile, AccessKeys: accessKeys, diff --git a/internal/service/configuration.go b/internal/service/configuration.go index 239be4d..4c2adbf 100644 --- a/internal/service/configuration.go +++ b/internal/service/configuration.go @@ -9,8 +9,8 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (aws.RootAccessStatus, error) { - var status = aws.RootAccessStatus{ +func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (rootmanager.RootAccessStatus, error) { + var status = rootmanager.RootAccessStatus{ TrustedAccess: false, RootCredentialsManagement: false, RootSessions: false, @@ -32,10 +32,10 @@ func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (aws.RootAccessSta return status, nil } - return aws.RootAccessStatus{}, err + return rootmanager.RootAccessStatus{}, err } - status = aws.RootAccessStatus{ + status = rootmanager.RootAccessStatus{ TrustedAccess: true, RootCredentialsManagement: true, RootSessions: true, @@ -44,8 +44,8 @@ func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (aws.RootAccessSta return status, nil } -func EnableRootAccess(ctx context.Context, iam *aws.IamClient, org *aws.OrganizationsClient, enableSessions bool) (aws.RootAccessStatus, aws.RootAccessStatus, error) { - var initStatus, status aws.RootAccessStatus +func EnableRootAccess(ctx context.Context, iam *aws.IamClient, org *aws.OrganizationsClient, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { + var initStatus, status rootmanager.RootAccessStatus initStatus, err := CheckRootAccess(ctx, iam) if err != nil { diff --git a/internal/service/credentials.go b/internal/service/credentials.go index 38a5c35..a0d378f 100644 --- a/internal/service/credentials.go +++ b/internal/service/credentials.go @@ -11,7 +11,7 @@ import ( ) // Delete root credentials for a list of AWS accounts -func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, creds []aws.RootCredentials, credentialType string) error { +func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, creds []rootmanager.RootCredentials, credentialType string) error { var ( wgAccounts sync.WaitGroup errChan = make(chan error, len(creds)) @@ -23,9 +23,9 @@ func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws for _, accountCredentials := range creds { wgAccounts.Add(1) - go func(accountId aws.RootCredentials) { + go func(accountCreds rootmanager.RootCredentials) { defer wgAccounts.Done() - if err := deleteAccountCrendentials(ctx, sts, accountCredentials, credentialType); err != nil { + if err := deleteAccountCrendentials(ctx, sts, accountCreds, credentialType); err != nil { errChan <- err } }(accountCredentials) @@ -42,7 +42,7 @@ func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws } // Delete root credentials for a specific account -func deleteAccountCrendentials(ctx context.Context, sts *aws.StsClient, creds aws.RootCredentials, credentialType string) error { +func deleteAccountCrendentials(ctx context.Context, sts *aws.StsClient, creds rootmanager.RootCredentials, credentialType string) error { logger.Trace("service.deleteAccountCrendentials", "checking if account %s has %s credentials to delete", credentialType, credentialType) // Check if there are credentials to delete before assuming root @@ -89,7 +89,7 @@ func deleteAccountCrendentials(ctx context.Context, sts *aws.StsClient, creds aw } // Check if the account has credentials to delete based on the credential type -func hasCredentialsToDelete(creds aws.RootCredentials, credentialType string) bool { +func hasCredentialsToDelete(creds rootmanager.RootCredentials, credentialType string) bool { switch credentialType { case "all": return creds.LoginProfile || len(creds.AccessKeys) > 0 || len(creds.MfaDevices) > 0 || len(creds.SigningCertificates) > 0 diff --git a/internal/service/rootmanager.go b/internal/service/rootmanager.go index 7329306..e7623b5 100644 --- a/internal/service/rootmanager.go +++ b/internal/service/rootmanager.go @@ -8,8 +8,7 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -// rootManagerImpl implements rootmanager.RootManager and converts between -// rootmanager types (public API) and aws types (infra) at the boundary. +// rootManagerImpl implements rootmanager.RootManager using AWS clients. type rootManagerImpl struct { iam *aws.IamClient sts *aws.StsClient @@ -26,41 +25,25 @@ func (r *rootManagerImpl) AuditAccounts(ctx context.Context, accountIds []string if r.sts == nil { return nil, errors.New("STS client required for audit") } - creds, err := AuditAccounts(ctx, r.iam, r.sts, accountIds) - if err != nil { - return nil, err - } - return toRootCredentialsSlice(creds), nil + return AuditAccounts(ctx, r.iam, r.sts, accountIds) } func (r *rootManagerImpl) CheckRootAccess(ctx context.Context) (rootmanager.RootAccessStatus, error) { - status, err := CheckRootAccess(ctx, r.iam) - if err != nil { - return rootmanager.RootAccessStatus{}, err - } - return toRootAccessStatus(status), nil + return CheckRootAccess(ctx, r.iam) } func (r *rootManagerImpl) EnableRootAccess(ctx context.Context, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { if r.org == nil { return rootmanager.RootAccessStatus{}, rootmanager.RootAccessStatus{}, errors.New("Organizations client required for enable") } - initStatus, status, err := EnableRootAccess(ctx, r.iam, r.org, enableSessions) - if err != nil { - return toRootAccessStatus(initStatus), toRootAccessStatus(status), err - } - return toRootAccessStatus(initStatus), toRootAccessStatus(status), nil + return EnableRootAccess(ctx, r.iam, r.org, enableSessions) } func (r *rootManagerImpl) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) error { if r.sts == nil { return errors.New("STS client required for delete") } - awsCreds := fromRootCredentialsSlice(creds) - if err := DeleteAccountsCredentials(ctx, r.iam, r.sts, awsCreds, credentialType); err != nil { - return err - } - return nil + return DeleteAccountsCredentials(ctx, r.iam, r.sts, creds, credentialType) } func (r *rootManagerImpl) RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) { @@ -73,49 +56,3 @@ func (r *rootManagerImpl) RecoverRootPassword(ctx context.Context, accountIds [] } return resultMap, nil } - -func toRootCredentials(c aws.RootCredentials) rootmanager.RootCredentials { - return rootmanager.RootCredentials{ - AccountId: c.AccountId, - LoginProfile: c.LoginProfile, - AccessKeys: append([]string(nil), c.AccessKeys...), - MfaDevices: append([]string(nil), c.MfaDevices...), - SigningCertificates: append([]string(nil), c.SigningCertificates...), - Error: c.Error, - } -} - -func toRootCredentialsSlice(creds []aws.RootCredentials) []rootmanager.RootCredentials { - out := make([]rootmanager.RootCredentials, len(creds)) - for i := range creds { - out[i] = toRootCredentials(creds[i]) - } - return out -} - -func fromRootCredentials(c rootmanager.RootCredentials) aws.RootCredentials { - return aws.RootCredentials{ - AccountId: c.AccountId, - LoginProfile: c.LoginProfile, - AccessKeys: append([]string(nil), c.AccessKeys...), - MfaDevices: append([]string(nil), c.MfaDevices...), - SigningCertificates: append([]string(nil), c.SigningCertificates...), - Error: c.Error, - } -} - -func fromRootCredentialsSlice(creds []rootmanager.RootCredentials) []aws.RootCredentials { - out := make([]aws.RootCredentials, len(creds)) - for i := range creds { - out[i] = fromRootCredentials(creds[i]) - } - return out -} - -func toRootAccessStatus(s aws.RootAccessStatus) rootmanager.RootAccessStatus { - return rootmanager.RootAccessStatus{ - TrustedAccess: s.TrustedAccess, - RootCredentialsManagement: s.RootCredentialsManagement, - RootSessions: s.RootSessions, - } -} From ce7829ea952fd125b47a14447a39ff6094e2b187 Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Tue, 3 Feb 2026 16:21:45 +0100 Subject: [PATCH 04/12] refactor(service): unexport internal operation functions --- internal/service/audit.go | 16 ++++++++-------- internal/service/configuration.go | 8 ++++---- internal/service/credentials.go | 18 +++++++++--------- internal/service/rootmanager.go | 10 +++++----- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/internal/service/audit.go b/internal/service/audit.go index 95e5e48..d234f18 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -9,9 +9,9 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -// Get root credentials for a list of AWS accounts. -func AuditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accounts []string) ([]rootmanager.RootCredentials, error) { - logger.Trace("service.AuditAccounts", "auditing accounts %s", accounts) +// auditAccounts returns root credentials for a list of AWS accounts. +func auditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accounts []string) ([]rootmanager.RootCredentials, error) { + logger.Trace("service.auditAccounts", "auditing accounts %s", accounts) rootCredentials := make([]rootmanager.RootCredentials, len(accounts)) var wgAccounts sync.WaitGroup @@ -25,7 +25,7 @@ func AuditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, go func(idx int, accountId string) { defer wgAccounts.Done() if accStatus, err := auditAccount(ctx, sts, accountId); err != nil { - logger.Error("service.AuditAccounts", err, "account %s: audit skipped", accountId) + logger.Error("service.auditAccounts", err, "account %s: audit skipped", accountId) rootCredentials[idx] = rootmanager.RootCredentials{AccountId: accountId, Error: err.Error()} } else { rootCredentials[idx] = accStatus @@ -54,25 +54,25 @@ func auditAccount(ctx context.Context, sts *aws.StsClient, accountId string) (ro if err != nil { return accountRootCredentials, err } - logger.Debug("service.AuditAccounts", "account %s - login_profile: %t", accountId, loginProfile) + logger.Debug("service.auditAccounts", "account %s - login_profile: %t", accountId, loginProfile) accessKeys, err := iamRoot.ListAccessKeys(ctx, accountId) if err != nil { return accountRootCredentials, err } - logger.Debug("service.AuditAccounts", "account %s - access_keys: %s", accountId, accessKeys) + logger.Debug("service.auditAccounts", "account %s - access_keys: %s", accountId, accessKeys) mfaDevices, err := iamRoot.ListMFADevices(ctx, accountId) if err != nil { return accountRootCredentials, err } - logger.Debug("service.AuditAccounts", "account %s - mfa_devices: %s", accountId, mfaDevices) + logger.Debug("service.auditAccounts", "account %s - mfa_devices: %s", accountId, mfaDevices) certificates, err := iamRoot.ListSigningCertificates(ctx, accountId) if err != nil { return accountRootCredentials, err } - logger.Debug("service.AuditAccounts", "account %s - signing_certificates: %s", accountId, certificates) + logger.Debug("service.auditAccounts", "account %s - signing_certificates: %s", accountId, certificates) accountRootCredentials = rootmanager.RootCredentials{ AccountId: accountId, diff --git a/internal/service/configuration.go b/internal/service/configuration.go index 4c2adbf..5ed71b6 100644 --- a/internal/service/configuration.go +++ b/internal/service/configuration.go @@ -9,7 +9,7 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (rootmanager.RootAccessStatus, error) { +func checkRootAccess(ctx context.Context, iam *aws.IamClient) (rootmanager.RootAccessStatus, error) { var status = rootmanager.RootAccessStatus{ TrustedAccess: false, RootCredentialsManagement: false, @@ -44,10 +44,10 @@ func CheckRootAccess(ctx context.Context, iam *aws.IamClient) (rootmanager.RootA return status, nil } -func EnableRootAccess(ctx context.Context, iam *aws.IamClient, org *aws.OrganizationsClient, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { +func enableRootAccess(ctx context.Context, iam *aws.IamClient, org *aws.OrganizationsClient, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { var initStatus, status rootmanager.RootAccessStatus - initStatus, err := CheckRootAccess(ctx, iam) + initStatus, err := checkRootAccess(ctx, iam) if err != nil { return initStatus, status, err } @@ -77,7 +77,7 @@ func EnableRootAccess(ctx context.Context, iam *aws.IamClient, org *aws.Organiza } } - status, err = CheckRootAccess(ctx, iam) + status, err = checkRootAccess(ctx, iam) if err != nil { return initStatus, status, err } diff --git a/internal/service/credentials.go b/internal/service/credentials.go index a0d378f..f838694 100644 --- a/internal/service/credentials.go +++ b/internal/service/credentials.go @@ -10,8 +10,8 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -// Delete root credentials for a list of AWS accounts -func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, creds []rootmanager.RootCredentials, credentialType string) error { +// deleteAccountsCredentials deletes root credentials for a list of AWS accounts. +func deleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, creds []rootmanager.RootCredentials, credentialType string) error { var ( wgAccounts sync.WaitGroup errChan = make(chan error, len(creds)) @@ -25,7 +25,7 @@ func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws wgAccounts.Add(1) go func(accountCreds rootmanager.RootCredentials) { defer wgAccounts.Done() - if err := deleteAccountCrendentials(ctx, sts, accountCreds, credentialType); err != nil { + if err := deleteAccountCredentials(ctx, sts, accountCreds, credentialType); err != nil { errChan <- err } }(accountCredentials) @@ -41,13 +41,13 @@ func DeleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws return nil } -// Delete root credentials for a specific account -func deleteAccountCrendentials(ctx context.Context, sts *aws.StsClient, creds rootmanager.RootCredentials, credentialType string) error { - logger.Trace("service.deleteAccountCrendentials", "checking if account %s has %s credentials to delete", credentialType, credentialType) +// deleteAccountCredentials deletes root credentials for a specific account. +func deleteAccountCredentials(ctx context.Context, sts *aws.StsClient, creds rootmanager.RootCredentials, credentialType string) error { + logger.Trace("service.deleteAccountCredentials", "checking if account %s has %s credentials to delete", credentialType, credentialType) // Check if there are credentials to delete before assuming root if !hasCredentialsToDelete(creds, credentialType) { - logger.Info("service.deleteAccountCrendentials", "no %s credentials found for account %s", credentialType, creds.AccountId) + logger.Info("service.deleteAccountCredentials", "no %s credentials found for account %s", credentialType, creds.AccountId) return nil } @@ -106,8 +106,8 @@ func hasCredentialsToDelete(creds rootmanager.RootCredentials, credentialType st } } -// Enable the recovery process for root passwords for a list of AWS accounts -func RecoverAccountsRootPassword(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accountIds []string) (map[string]bool, error) { +// recoverAccountsRootPassword initiates root password recovery for a list of AWS accounts. +func recoverAccountsRootPassword(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accountIds []string) (map[string]bool, error) { var ( wgAccounts sync.WaitGroup results = sync.Map{} diff --git a/internal/service/rootmanager.go b/internal/service/rootmanager.go index e7623b5..311baa1 100644 --- a/internal/service/rootmanager.go +++ b/internal/service/rootmanager.go @@ -25,32 +25,32 @@ func (r *rootManagerImpl) AuditAccounts(ctx context.Context, accountIds []string if r.sts == nil { return nil, errors.New("STS client required for audit") } - return AuditAccounts(ctx, r.iam, r.sts, accountIds) + return auditAccounts(ctx, r.iam, r.sts, accountIds) } func (r *rootManagerImpl) CheckRootAccess(ctx context.Context) (rootmanager.RootAccessStatus, error) { - return CheckRootAccess(ctx, r.iam) + return checkRootAccess(ctx, r.iam) } func (r *rootManagerImpl) EnableRootAccess(ctx context.Context, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { if r.org == nil { return rootmanager.RootAccessStatus{}, rootmanager.RootAccessStatus{}, errors.New("Organizations client required for enable") } - return EnableRootAccess(ctx, r.iam, r.org, enableSessions) + return enableRootAccess(ctx, r.iam, r.org, enableSessions) } func (r *rootManagerImpl) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) error { if r.sts == nil { return errors.New("STS client required for delete") } - return DeleteAccountsCredentials(ctx, r.iam, r.sts, creds, credentialType) + return deleteAccountsCredentials(ctx, r.iam, r.sts, creds, credentialType) } func (r *rootManagerImpl) RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) { if r.sts == nil { return nil, errors.New("STS client required for recovery") } - resultMap, err := RecoverAccountsRootPassword(ctx, r.iam, r.sts, accountIds) + resultMap, err := recoverAccountsRootPassword(ctx, r.iam, r.sts, accountIds) if err != nil { return nil, err } From 1e4dfdf8a33c1792ab7facfb42d3169d6cedf75e Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Tue, 3 Feb 2026 16:24:53 +0100 Subject: [PATCH 05/12] refactor(service): rename rootManagerImpl to manager --- internal/service/rootmanager.go | 34 ++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/service/rootmanager.go b/internal/service/rootmanager.go index 311baa1..3832f15 100644 --- a/internal/service/rootmanager.go +++ b/internal/service/rootmanager.go @@ -8,8 +8,8 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -// rootManagerImpl implements rootmanager.RootManager using AWS clients. -type rootManagerImpl struct { +// manager implements rootmanager.RootManager using AWS clients. +type manager struct { iam *aws.IamClient sts *aws.StsClient org *aws.OrganizationsClient @@ -18,39 +18,39 @@ type rootManagerImpl struct { // NewRootManager returns a RootManager that uses the given AWS clients. // sts and org may be nil for callers that only use CheckRootAccess. func NewRootManager(iam *aws.IamClient, sts *aws.StsClient, org *aws.OrganizationsClient) rootmanager.RootManager { - return &rootManagerImpl{iam: iam, sts: sts, org: org} + return &manager{iam: iam, sts: sts, org: org} } -func (r *rootManagerImpl) AuditAccounts(ctx context.Context, accountIds []string) ([]rootmanager.RootCredentials, error) { - if r.sts == nil { +func (m *manager) AuditAccounts(ctx context.Context, accountIds []string) ([]rootmanager.RootCredentials, error) { + if m.sts == nil { return nil, errors.New("STS client required for audit") } - return auditAccounts(ctx, r.iam, r.sts, accountIds) + return auditAccounts(ctx, m.iam, m.sts, accountIds) } -func (r *rootManagerImpl) CheckRootAccess(ctx context.Context) (rootmanager.RootAccessStatus, error) { - return checkRootAccess(ctx, r.iam) +func (m *manager) CheckRootAccess(ctx context.Context) (rootmanager.RootAccessStatus, error) { + return checkRootAccess(ctx, m.iam) } -func (r *rootManagerImpl) EnableRootAccess(ctx context.Context, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { - if r.org == nil { +func (m *manager) EnableRootAccess(ctx context.Context, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { + if m.org == nil { return rootmanager.RootAccessStatus{}, rootmanager.RootAccessStatus{}, errors.New("Organizations client required for enable") } - return enableRootAccess(ctx, r.iam, r.org, enableSessions) + return enableRootAccess(ctx, m.iam, m.org, enableSessions) } -func (r *rootManagerImpl) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) error { - if r.sts == nil { +func (m *manager) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) error { + if m.sts == nil { return errors.New("STS client required for delete") } - return deleteAccountsCredentials(ctx, r.iam, r.sts, creds, credentialType) + return deleteAccountsCredentials(ctx, m.iam, m.sts, creds, credentialType) } -func (r *rootManagerImpl) RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) { - if r.sts == nil { +func (m *manager) RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) { + if m.sts == nil { return nil, errors.New("STS client required for recovery") } - resultMap, err := recoverAccountsRootPassword(ctx, r.iam, r.sts, accountIds) + resultMap, err := recoverAccountsRootPassword(ctx, m.iam, m.sts, accountIds) if err != nil { return nil, err } From 3f834785b5b81e5b104b345b331d55f7c14b4975 Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Tue, 3 Feb 2026 17:54:43 +0100 Subject: [PATCH 06/12] refactor(cmd): extract command logic to functions and use functions everywhere --- cmd/audit.go | 109 +++++++++++++++++++++--------------------- cmd/check.go | 67 +++++++++++++------------- cmd/delete.go | 123 ++++++++++++------------------------------------ cmd/enable.go | 74 ++++++++++++++--------------- cmd/recovery.go | 85 ++++++++++++++++----------------- cmd/root.go | 6 +++ cmd/version.go | 19 ++++---- 7 files changed, 211 insertions(+), 272 deletions(-) diff --git a/cmd/audit.go b/cmd/audit.go index fd49549..ff4df0d 100644 --- a/cmd/audit.go +++ b/cmd/audit.go @@ -14,65 +14,64 @@ import ( "github.com/spf13/cobra" ) -var auditCmd = &cobra.Command{ - Use: "audit", - Short: "Retrieve root user credentials", - Long: `Retrieve available root user credentials for all member accounts within an AWS Organization.`, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.audit", "audit called") +func Audit() *cobra.Command { + cmd := &cobra.Command{ + Use: "audit", + Short: "Retrieve root user credentials", + Long: `Retrieve available root user credentials for all member accounts within an AWS Organization.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + logger.Trace("cmd.audit", "audit called") - ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) - if err != nil { - logger.Error("cmd.audit", err, "failed to load aws config") - return err - } - - auditAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) - if err != nil { - logger.Error("cmd.audit", err, "failed to get accounts to audit") - return err - } - if len(auditAccounts) == 0 { - logger.Info("cmd.audit", "no accounts selected") - return nil - } - logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) + ctx := context.Background() + awscfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + logger.Error("cmd.audit", err, "failed to load aws config") + return err + } - rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) - audit, err := rm.AuditAccounts(ctx, auditAccounts) - if err != nil { - logger.Error("cmd.audit", err, "failed to audit accounts") - return err - } + auditAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) + if err != nil { + logger.Error("cmd.audit", err, "failed to get accounts to audit") + return err + } + if len(auditAccounts) == 0 { + logger.Info("cmd.audit", "no accounts selected") + return nil + } + logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) - var skipped int - headers := []string{"Account", "LoginProfile", "AccessKeys", "MFA Devices", "Signing Certificates"} - var data [][]any - for i, acc := range audit { - if acc.Error != "" { - skipped++ - continue + rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) + audit, err := rm.AuditAccounts(ctx, auditAccounts) + if err != nil { + logger.Error("cmd.audit", err, "failed to audit accounts") + return err } - data = append(data, []any{ - auditAccounts[i], - acc.LoginProfile, - acc.AccessKeys, - acc.MfaDevices, - acc.SigningCertificates, - }) - } - output.HandleOutput(outputFlag, headers, data) - if skipped > 0 { - return fmt.Errorf("audit skipped for %d account(s)", skipped) - } - return nil - }, -} + var skipped int + headers := []string{"Account", "LoginProfile", "AccessKeys", "MFA Devices", "Signing Certificates"} + var data [][]any + for i, acc := range audit { + if acc.Error != "" { + skipped++ + continue + } + data = append(data, []any{ + auditAccounts[i], + acc.LoginProfile, + acc.AccessKeys, + acc.MfaDevices, + acc.SigningCertificates, + }) + } + output.HandleOutput(outputFlag, headers, data) -func init() { - rootCmd.AddCommand(auditCmd) - auditCmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of AWS account IDs to audit (comma-separated). Use \"all\" to audit all accounts.") + if skipped > 0 { + return fmt.Errorf("audit skipped for %d account(s)", skipped) + } + return nil + }, + } + cmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of AWS account IDs to audit (comma-separated). Use \"all\" to audit all accounts.") + return cmd } diff --git a/cmd/check.go b/cmd/check.go index 39d1980..0662480 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -4,45 +4,44 @@ import ( "context" "strconv" + "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" - "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) -var checkCmd = &cobra.Command{ - Use: "check", - Short: "Check if centralized root access is enabled.", - Long: `Retrieve the status of centralized root access settings for an AWS Organization.`, - Run: func(cmd *cobra.Command, args []string) { - logger.Trace("cmd.check", "check called") - - ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) - if err != nil { - logger.Error("cmd.check", err, "failed to load aws config") - return - } - - rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, nil) - status, err := rm.CheckRootAccess(ctx) - if err != nil { - logger.Error("cmd.check", err, "failed to check root access configuration") - return - } - - headers := []string{"Name", "Status"} - data := [][]any{ - {"TrustedAccess", strconv.FormatBool(status.TrustedAccess)}, - {"RootCredentialsManagement", strconv.FormatBool(status.RootCredentialsManagement)}, - {"RootSessions", strconv.FormatBool(status.RootSessions)}, - } - output.HandleOutput(outputFlag, headers, data) - }, -} - -func init() { - rootCmd.AddCommand(checkCmd) +func Check() *cobra.Command { + return &cobra.Command{ + Use: "check", + Short: "Check if centralized root access is enabled.", + Long: `Retrieve the status of centralized root access settings for an AWS Organization.`, + RunE: func(cmd *cobra.Command, args []string) error { + logger.Trace("cmd.check", "check called") + + ctx := context.Background() + awscfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + logger.Error("cmd.check", err, "failed to load aws config") + return err + } + + rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, nil) + status, err := rm.CheckRootAccess(ctx) + if err != nil { + logger.Error("cmd.check", err, "failed to check root access configuration") + return err + } + + headers := []string{"Name", "Status"} + data := [][]any{ + {"TrustedAccess", strconv.FormatBool(status.TrustedAccess)}, + {"RootCredentialsManagement", strconv.FormatBool(status.RootCredentialsManagement)}, + {"RootSessions", strconv.FormatBool(status.RootSessions)}, + } + output.HandleOutput(outputFlag, headers, data) + return nil + }, + } } diff --git a/cmd/delete.go b/cmd/delete.go index afdfd1d..17bfdf7 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -14,103 +14,38 @@ import ( "github.com/spf13/cobra" ) -var deleteCmd = &cobra.Command{ - Use: "delete", - Short: "Delete root user credentials", - Long: `Delete root user credentials for specific AWS Organization member accounts.`, -} - -var deleteAllCmd = &cobra.Command{ - Use: "all", - Short: "Delete all existing root user credentials", - Long: `Delete all existing root user credentials for specific AWS Organization member accounts.`, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.deleteAll", "delete all called") - - if err := delete(accountsFlags, "all"); err != nil { - return err - } - - return nil - }, -} - -var deleteLoginCmd = &cobra.Command{ - Use: "login", - Short: "Delete root user Login Profile", - Long: `Delete existing root user Login Profile for specific AWS Organization member accounts.`, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.deleteLogin", "delete login called") - - if err := delete(accountsFlags, "login"); err != nil { - return err - } - - return nil - }, -} - -var deleteKeysCmd = &cobra.Command{ - Use: "keys", - Short: "Delete root user Access Keys", - Long: `Delete existing root user Access Keys for specific AWS Organization member accounts.`, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.deleteKeys", "delete keys called") - - if err := delete(accountsFlags, "keys"); err != nil { - return err - } - - return nil - }, -} - -var deleteMfaCmd = &cobra.Command{ - Use: "mfa", - Short: "Deactivate root user MFA Devices", - Long: `Deactivate existing root user MFA Devices for specific AWS Organization member accounts.`, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.deleteMfa", "delete mfa called") - - if err := delete(accountsFlags, "mfa"); err != nil { - return err - } - - return nil - }, -} - -var deleteCertificatesCmd = &cobra.Command{ - Use: "certificates", - Short: "Delete root user Signin Certificates", - Long: `Delete existing root user Signing Certificates for specific AWS Organization member accounts.`, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.deleteCertificates", "delete certificates called") - - if err := delete(accountsFlags, "certificate"); err != nil { - return err - } - - return nil - }, +func Delete() *cobra.Command { + cmd := &cobra.Command{ + Use: "delete", + Short: "Delete root user credentials", + Long: `Delete root user credentials for specific AWS Organization member accounts.`, + } + cmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of AWS account IDs to audit (comma-separated). Use \"all\" to audit all accounts.") + cmd.AddCommand(deleteSubcommand("all", "Delete all existing root user credentials", "Delete all existing root user credentials for specific AWS Organization member accounts.")) + cmd.AddCommand(deleteSubcommand("login", "Delete root user Login Profile", "Delete existing root user Login Profile for specific AWS Organization member accounts.")) + cmd.AddCommand(deleteSubcommand("keys", "Delete root user Access Keys", "Delete existing root user Access Keys for specific AWS Organization member accounts.")) + cmd.AddCommand(deleteSubcommand("mfa", "Deactivate root user MFA Devices", "Deactivate existing root user MFA Devices for specific AWS Organization member accounts.")) + cmd.AddCommand(deleteSubcommand("certificates", "Delete root user Signin Certificates", "Delete existing root user Signing Certificates for specific AWS Organization member accounts.")) + return cmd } -func init() { - rootCmd.AddCommand(deleteCmd) - deleteCmd.AddCommand(deleteAllCmd) - deleteCmd.AddCommand(deleteLoginCmd) - deleteCmd.AddCommand(deleteKeysCmd) - deleteCmd.AddCommand(deleteMfaCmd) - deleteCmd.AddCommand(deleteCertificatesCmd) - deleteCmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of AWS account IDs to audit (comma-separated). Use \"all\" to audit all accounts.") +func deleteSubcommand(use, short, long string) *cobra.Command { + credentialType := use + if use == "certificates" { + credentialType = "certificate" + } + return &cobra.Command{ + Use: use, + Short: short, + Long: long, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return runDelete(accountsFlags, credentialType) + }, + } } -func delete(accountsFlags []string, credentialType string) error { +func runDelete(accountsFlags []string, credentialType string) error { ctx := context.Background() awscfg, err := aws.LoadAWSConfig(ctx) if err != nil { @@ -143,7 +78,7 @@ func delete(accountsFlags []string, credentialType string) error { data = append(data, []any{ account, credentialType, - fmt.Sprintf("deleted"), // TODO: this is not real + "deleted", // TODO: this is not real }) } output.HandleOutput(outputFlag, headers, data) diff --git a/cmd/enable.go b/cmd/enable.go index e1e0330..28b5d21 100644 --- a/cmd/enable.go +++ b/cmd/enable.go @@ -4,48 +4,48 @@ import ( "context" "strconv" + "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" - "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" ) -var enableCmd = &cobra.Command{ - Use: "enable", - Short: "Enable centralized root access", - Long: `Enable centralized root access management in an AWS Organization.`, - Run: func(cmd *cobra.Command, args []string) { - logger.Trace("cmd.enable", "enable called") - - enableRootSessions, _ := cmd.Flags().GetBool("enableRootSessions") - - ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) - if err != nil { - logger.Error("cmd.enable", err, "failed to load aws config") - return - } - - rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, aws.NewOrganizationsClient(awscfg)) - initStatus, status, err := rm.EnableRootAccess(ctx, enableRootSessions) - if err != nil { - logger.Error("cmd.enable", err, "failed to enable root access") - return - } - - headers := []string{"Name", "InitialStatus", "CurrentStatus"} - data := [][]any{ - {"TrustedAccess", strconv.FormatBool(initStatus.TrustedAccess), strconv.FormatBool(status.TrustedAccess)}, - {"RootCredentialsManagement", strconv.FormatBool(initStatus.RootCredentialsManagement), strconv.FormatBool(status.RootCredentialsManagement)}, - {"RootSessions", strconv.FormatBool(initStatus.RootSessions), strconv.FormatBool(status.RootSessions)}, - } - output.HandleOutput(outputFlag, headers, data) - }, -} - -func init() { - rootCmd.AddCommand(enableCmd) - enableCmd.PersistentFlags().Bool("enableRootSessions", false, "Enable Root Sessions, required only when working with resource policies.") +func Enable() *cobra.Command { + cmd := &cobra.Command{ + Use: "enable", + Short: "Enable centralized root access", + Long: `Enable centralized root access management in an AWS Organization.`, + RunE: func(cmd *cobra.Command, args []string) error { + logger.Trace("cmd.enable", "enable called") + + enableRootSessions, _ := cmd.Flags().GetBool("enableRootSessions") + + ctx := context.Background() + awscfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + logger.Error("cmd.enable", err, "failed to load aws config") + return err + } + + rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, aws.NewOrganizationsClient(awscfg)) + initStatus, status, err := rm.EnableRootAccess(ctx, enableRootSessions) + if err != nil { + logger.Error("cmd.enable", err, "failed to enable root access") + return err + } + + headers := []string{"Name", "InitialStatus", "CurrentStatus"} + data := [][]any{ + {"TrustedAccess", strconv.FormatBool(initStatus.TrustedAccess), strconv.FormatBool(status.TrustedAccess)}, + {"RootCredentialsManagement", strconv.FormatBool(initStatus.RootCredentialsManagement), strconv.FormatBool(status.RootCredentialsManagement)}, + {"RootSessions", strconv.FormatBool(initStatus.RootSessions), strconv.FormatBool(status.RootSessions)}, + } + output.HandleOutput(outputFlag, headers, data) + return nil + }, + } + cmd.PersistentFlags().Bool("enableRootSessions", false, "Enable Root Sessions, required only when working with resource policies.") + return cmd } diff --git a/cmd/recovery.go b/cmd/recovery.go index 4351f6f..60a2e15 100644 --- a/cmd/recovery.go +++ b/cmd/recovery.go @@ -13,52 +13,53 @@ import ( "github.com/spf13/cobra" ) -var recoveryCmd = &cobra.Command{ - Use: "recovery", - Short: "Allow root password recovery", - Long: `Retrieve the status of centralized root access settings for an AWS Organization.`, - Run: func(cmd *cobra.Command, args []string) { - logger.Trace("cmd.recovery", "recovery called") +func Recovery() *cobra.Command { + cmd := &cobra.Command{ + Use: "recovery", + Short: "Allow root password recovery", + Long: `Retrieve the status of centralized root access settings for an AWS Organization.`, + RunE: func(cmd *cobra.Command, args []string) error { + logger.Trace("cmd.recovery", "recovery called") - ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) - if err != nil { - logger.Error("cmd.recovery", err, "failed to load aws config") - return - } - - targetAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) - if err != nil { - logger.Error("cmd.recovery", err, "failed to get target accounts") - } - if len(targetAccounts) == 0 { - logger.Info("cmd.recovery", "no accounts selected") - return - } - logger.Debug("cmd.recovery", "selected accounts: %s", strings.Join(targetAccounts, ", ")) + ctx := context.Background() + awscfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + logger.Error("cmd.recovery", err, "failed to load aws config") + return err + } - rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) - resultMap, err := rm.RecoverRootPassword(ctx, targetAccounts) - if err != nil { - logger.Error("cmd.recovery", err, "failed to recover root password") - return - } + targetAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) + if err != nil { + logger.Error("cmd.recovery", err, "failed to get target accounts") + return err + } + if len(targetAccounts) == 0 { + logger.Info("cmd.recovery", "no accounts selected") + return nil + } + logger.Debug("cmd.recovery", "selected accounts: %s", strings.Join(targetAccounts, ", ")) - headers := []string{"Account", "Login Profile"} - var data [][]any - for acc, success := range resultMap { - status := "recovered" - if !success { - status = "already exists" + rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) + resultMap, err := rm.RecoverRootPassword(ctx, targetAccounts) + if err != nil { + logger.Error("cmd.recovery", err, "failed to recover root password") + return err } - data = append(data, []any{acc, status}) - } - output.HandleOutput(outputFlag, headers, data) - }, -} + headers := []string{"Account", "Login Profile"} + var data [][]any + for acc, success := range resultMap { + status := "recovered" + if !success { + status = "already exists" + } + data = append(data, []any{acc, status}) + } -func init() { - rootCmd.AddCommand(recoveryCmd) - recoveryCmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of tarjet AWS account IDs (comma-separated). Use \"all\" to select all accounts.") + output.HandleOutput(outputFlag, headers, data) + return nil + }, + } + cmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of tarjet AWS account IDs (comma-separated). Use \"all\" to select all accounts.") + return cmd } diff --git a/cmd/root.go b/cmd/root.go index 6e93e84..2112ea2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -37,4 +37,10 @@ func Execute() { func init() { rootCmd.PersistentFlags().StringVarP(&outputFlag, "output", "o", "table", "Set the output format (table, json, csv)") + rootCmd.AddCommand(Audit()) + rootCmd.AddCommand(Check()) + rootCmd.AddCommand(Enable()) + rootCmd.AddCommand(Delete()) + rootCmd.AddCommand(Recovery()) + rootCmd.AddCommand(Version()) } diff --git a/cmd/version.go b/cmd/version.go index d6e3f15..e59b8a6 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -8,14 +8,13 @@ import ( var version = "dev" -var versionCmd = &cobra.Command{ - Use: "version", - Short: "Print the version", - Run: func(cmd *cobra.Command, args []string) { - fmt.Println("Version:", version) - }, -} - -func init() { - rootCmd.AddCommand(versionCmd) +func Version() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Print the version", + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Println("Version:", version) + return nil + }, + } } From 812555aed58beaabcbc9b51867f7d217b231f621 Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Thu, 12 Mar 2026 13:57:19 +0100 Subject: [PATCH 07/12] refactor: introduce interfaces and dependency injection for AWS clients Extract IamClient, StsClient, and OrganizationsClient interfaces from concrete structs to enable mocking and testability. Add IamClientFactory for goroutines that create IAM clients dynamically after AssumeRoot. Add Logger interface with backward-compatible global functions. Commands now use NewRootManagerFromConfig convenience constructor instead of manually wiring AWS config and clients. DeleteCredentials and RecoverRootPassword return typed result slices (DeletionResult, RecoveryResult) instead of raw error channels and sync.Map. --- cmd/audit.go | 6 +- cmd/check.go | 6 +- cmd/delete.go | 36 +++++++---- cmd/enable.go | 6 +- cmd/recovery.go | 32 +++++++--- internal/infra/aws/factory.go | 21 +++++++ internal/infra/aws/iam.go | 30 ++++----- internal/infra/aws/interfaces.go | 61 ++++++++++++++++++ internal/infra/aws/organizations.go | 18 +++--- internal/infra/aws/sts.go | 10 +-- internal/logger/interface.go | 21 +++++++ internal/logger/logger.go | 67 +++++++++++++++++--- internal/service/audit.go | 8 +-- internal/service/configuration.go | 4 +- internal/service/credentials.go | 98 +++++++++++++++-------------- internal/service/rootmanager.go | 43 ++++++++----- rootmanager/api.go | 7 ++- rootmanager/types.go | 9 +++ 18 files changed, 339 insertions(+), 144 deletions(-) create mode 100644 internal/infra/aws/factory.go create mode 100644 internal/infra/aws/interfaces.go create mode 100644 internal/logger/interface.go diff --git a/cmd/audit.go b/cmd/audit.go index ff4df0d..76657c3 100644 --- a/cmd/audit.go +++ b/cmd/audit.go @@ -7,7 +7,6 @@ import ( "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/cli/ui" - "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" @@ -24,9 +23,9 @@ func Audit() *cobra.Command { logger.Trace("cmd.audit", "audit called") ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) + rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.audit", err, "failed to load aws config") + logger.Error("cmd.audit", err, "failed to initialize root manager") return err } @@ -41,7 +40,6 @@ func Audit() *cobra.Command { } logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) - rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) audit, err := rm.AuditAccounts(ctx, auditAccounts) if err != nil { logger.Error("cmd.audit", err, "failed to audit accounts") diff --git a/cmd/check.go b/cmd/check.go index 0662480..091c76b 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -5,7 +5,6 @@ import ( "strconv" "github.com/unicrons/aws-root-manager/internal/cli/output" - "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" @@ -21,13 +20,12 @@ func Check() *cobra.Command { logger.Trace("cmd.check", "check called") ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) + rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.check", err, "failed to load aws config") + logger.Error("cmd.check", err, "failed to initialize root manager") return err } - rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, nil) status, err := rm.CheckRootAccess(ctx) if err != nil { logger.Error("cmd.check", err, "failed to check root access configuration") diff --git a/cmd/delete.go b/cmd/delete.go index 17bfdf7..15ef39b 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -7,7 +7,6 @@ import ( "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/cli/ui" - "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" @@ -47,9 +46,9 @@ func deleteSubcommand(use, short, long string) *cobra.Command { func runDelete(accountsFlags []string, credentialType string) error { ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) + rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - return fmt.Errorf("failed to load aws config: %w", err) + return fmt.Errorf("failed to initialize root manager: %w", err) } auditAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) @@ -57,31 +56,44 @@ func runDelete(accountsFlags []string, credentialType string) error { return fmt.Errorf("failed to get accounts to audit: %w", err) } if len(auditAccounts) == 0 { - logger.Info("cmd.audit", "no accounts selected") + logger.Info("cmd.delete", "no accounts selected") return nil } - logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) + logger.Debug("cmd.delete", "selected accounts: %s", strings.Join(auditAccounts, ", ")) - rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) audit, err := rm.AuditAccounts(ctx, auditAccounts) if err != nil { return err } - if err = rm.DeleteCredentials(ctx, audit, credentialType); err != nil { + results, err := rm.DeleteCredentials(ctx, audit, credentialType) + if err != nil { return err } - headers := []string{"Account", "CredentialType", "Status"} + headers := []string{"Account", "CredentialType", "Status", "Error"} var data [][]any - for _, account := range auditAccounts { + var failureCount int + for _, result := range results { + status := "deleted" + errorMsg := "" + if !result.Success { + status = "failed" + errorMsg = result.Error + failureCount++ + } data = append(data, []any{ - account, - credentialType, - "deleted", // TODO: this is not real + result.AccountId, + result.CredentialType, + status, + errorMsg, }) } output.HandleOutput(outputFlag, headers, data) + if failureCount > 0 { + return fmt.Errorf("deletion failed for %d account(s)", failureCount) + } + return nil } diff --git a/cmd/enable.go b/cmd/enable.go index 28b5d21..6753404 100644 --- a/cmd/enable.go +++ b/cmd/enable.go @@ -5,7 +5,6 @@ import ( "strconv" "github.com/unicrons/aws-root-manager/internal/cli/output" - "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" @@ -23,13 +22,12 @@ func Enable() *cobra.Command { enableRootSessions, _ := cmd.Flags().GetBool("enableRootSessions") ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) + rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.enable", err, "failed to load aws config") + logger.Error("cmd.enable", err, "failed to initialize root manager") return err } - rm := service.NewRootManager(aws.NewIamClient(awscfg), nil, aws.NewOrganizationsClient(awscfg)) initStatus, status, err := rm.EnableRootAccess(ctx, enableRootSessions) if err != nil { logger.Error("cmd.enable", err, "failed to enable root access") diff --git a/cmd/recovery.go b/cmd/recovery.go index 60a2e15..30a1b1d 100644 --- a/cmd/recovery.go +++ b/cmd/recovery.go @@ -2,11 +2,11 @@ package cmd import ( "context" + "fmt" "strings" "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/cli/ui" - "github.com/unicrons/aws-root-manager/internal/infra/aws" "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" @@ -22,9 +22,9 @@ func Recovery() *cobra.Command { logger.Trace("cmd.recovery", "recovery called") ctx := context.Background() - awscfg, err := aws.LoadAWSConfig(ctx) + rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.recovery", err, "failed to load aws config") + logger.Error("cmd.recovery", err, "failed to initialize root manager") return err } @@ -39,24 +39,36 @@ func Recovery() *cobra.Command { } logger.Debug("cmd.recovery", "selected accounts: %s", strings.Join(targetAccounts, ", ")) - rm := service.NewRootManager(aws.NewIamClient(awscfg), aws.NewStsClient(awscfg), aws.NewOrganizationsClient(awscfg)) - resultMap, err := rm.RecoverRootPassword(ctx, targetAccounts) + results, err := rm.RecoverRootPassword(ctx, targetAccounts) if err != nil { logger.Error("cmd.recovery", err, "failed to recover root password") return err } - headers := []string{"Account", "Login Profile"} + headers := []string{"Account", "Login Profile", "Error"} var data [][]any - for acc, success := range resultMap { + var failureCount int + for _, result := range results { status := "recovered" - if !success { - status = "already exists" + errorMsg := "" + if !result.Success { + if result.Error != "" { + status = "failed" + errorMsg = result.Error + failureCount++ + } else { + status = "already exists" + } } - data = append(data, []any{acc, status}) + data = append(data, []any{result.AccountId, status, errorMsg}) } output.HandleOutput(outputFlag, headers, data) + + if failureCount > 0 { + return fmt.Errorf("recovery failed for %d account(s)", failureCount) + } + return nil }, } diff --git a/internal/infra/aws/factory.go b/internal/infra/aws/factory.go new file mode 100644 index 0000000..1c40507 --- /dev/null +++ b/internal/infra/aws/factory.go @@ -0,0 +1,21 @@ +package aws + +import ( + awssdk "github.com/aws/aws-sdk-go-v2/aws" +) + +// IamClientFactory creates IAM clients with a given AWS config. +// This abstraction enables dependency injection of client creation logic, +// which is especially important for goroutines that create clients dynamically +// after AssumeRoot. +type IamClientFactory interface { + NewIamClient(cfg awssdk.Config) IamClient +} + +// DefaultIamClientFactory is the production implementation of IamClientFactory. +type DefaultIamClientFactory struct{} + +// NewIamClient creates a new IAM client using the concrete implementation. +func (f *DefaultIamClientFactory) NewIamClient(cfg awssdk.Config) IamClient { + return NewIamClient(cfg) +} diff --git a/internal/infra/aws/iam.go b/internal/infra/aws/iam.go index 3f2b6d7..a8e1e60 100644 --- a/internal/infra/aws/iam.go +++ b/internal/infra/aws/iam.go @@ -14,17 +14,17 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam/types" ) -type IamClient struct { +type iamClient struct { client *iam.Client } -func NewIamClient(awscfg aws.Config) *IamClient { +func NewIamClient(awscfg aws.Config) IamClient { client := iam.NewFromConfig(awscfg) - return &IamClient{client: client} + return &iamClient{client: client} } // Verifies if AWS centralized root access is enabled -func (c *IamClient) CheckOrganizationRootAccess(ctx context.Context, rootSessionsRequired bool) error { +func (c *iamClient) CheckOrganizationRootAccess(ctx context.Context, rootSessionsRequired bool) error { logger.Trace("aws.CheckOrganizationRootAccess", "checking if organization root access is enabled") features, err := c.client.ListOrganizationsFeatures(ctx, &iam.ListOrganizationsFeaturesInput{}) @@ -54,7 +54,7 @@ func (c *IamClient) CheckOrganizationRootAccess(ctx context.Context, rootSession } // Check if an account has root login profile enabled -func (c *IamClient) GetLoginProfile(ctx context.Context, accountId string) (bool, error) { +func (c *iamClient) GetLoginProfile(ctx context.Context, accountId string) (bool, error) { logger.Debug("aws.GetLoginProfile", "getting login profile for account %s", accountId) _, err := c.client.GetLoginProfile(ctx, &iam.GetLoginProfileInput{}) @@ -71,7 +71,7 @@ func (c *IamClient) GetLoginProfile(ctx context.Context, accountId string) (bool } // Delete root login profile for a specific account -func (c *IamClient) DeleteLoginProfile(ctx context.Context, accountId string) error { +func (c *iamClient) DeleteLoginProfile(ctx context.Context, accountId string) error { logger.Debug("aws.DeleteLoginProfile", "deleting login profile for account %s", accountId) _, err := c.client.DeleteLoginProfile(ctx, &iam.DeleteLoginProfileInput{}) @@ -85,7 +85,7 @@ func (c *IamClient) DeleteLoginProfile(ctx context.Context, accountId string) er } // Get a list of root access keys for a specific account -func (c *IamClient) ListAccessKeys(ctx context.Context, accountId string) ([]string, error) { +func (c *iamClient) ListAccessKeys(ctx context.Context, accountId string) ([]string, error) { logger.Debug("aws.ListAccessKeys", "listing access keys for account %s", accountId) accessKeys, err := c.client.ListAccessKeys(ctx, &iam.ListAccessKeysInput{}) @@ -102,7 +102,7 @@ func (c *IamClient) ListAccessKeys(ctx context.Context, accountId string) ([]str } // Delete a list of root access for a specific account -func (c *IamClient) DeleteAccessKeys(ctx context.Context, accountId string, accessKeyIds []string) error { +func (c *iamClient) DeleteAccessKeys(ctx context.Context, accountId string, accessKeyIds []string) error { logger.Debug("aws.DeleteAccessKeys", "deleting root access key %s for account %s", accessKeyIds, accountId) for _, accessKeyId := range accessKeyIds { @@ -120,7 +120,7 @@ func (c *IamClient) DeleteAccessKeys(ctx context.Context, accountId string, acce } // Get a list of root MFA devices for a specific account -func (c *IamClient) ListMFADevices(ctx context.Context, accountId string) ([]string, error) { +func (c *iamClient) ListMFADevices(ctx context.Context, accountId string) ([]string, error) { mfaDevices, err := c.client.ListMFADevices(ctx, &iam.ListMFADevicesInput{}) if err != nil { return nil, fmt.Errorf("error listing root mfa devices for account %s: %w", accountId, err) @@ -135,7 +135,7 @@ func (c *IamClient) ListMFADevices(ctx context.Context, accountId string) ([]str } // Deactivate a list of root MFA devices for a specific account -func (c *IamClient) DeactivateMFADevices(ctx context.Context, accountId string, mfaSerialNumbers []string) error { +func (c *iamClient) DeactivateMFADevices(ctx context.Context, accountId string, mfaSerialNumbers []string) error { logger.Debug("aws.DeactivateMFADevices", "deleting root mfa device %s for account %s", mfaSerialNumbers, accountId) for _, mfaSerialNumber := range mfaSerialNumbers { @@ -153,7 +153,7 @@ func (c *IamClient) DeactivateMFADevices(ctx context.Context, accountId string, } // Get a list of root signing certificates devices for a specific account -func (c *IamClient) ListSigningCertificates(ctx context.Context, accountId string) ([]string, error) { +func (c *iamClient) ListSigningCertificates(ctx context.Context, accountId string) ([]string, error) { certificates, err := c.client.ListSigningCertificates(ctx, &iam.ListSigningCertificatesInput{}) if err != nil { return nil, fmt.Errorf("error listing signing certificates for account %s: %w", accountId, err) @@ -168,7 +168,7 @@ func (c *IamClient) ListSigningCertificates(ctx context.Context, accountId strin } // Delete a list of root signing certificates for a specific account -func (c *IamClient) DeleteSigningCertificates(ctx context.Context, accountId string, certificates []string) error { +func (c *iamClient) DeleteSigningCertificates(ctx context.Context, accountId string, certificates []string) error { logger.Debug("aws.DeleteSigningCertificates", "deleting singin certificates %s for account %s", certificates, accountId) for _, certificate := range certificates { @@ -187,7 +187,7 @@ func (c *IamClient) DeleteSigningCertificates(ctx context.Context, accountId str } // Enable centralized root credentials management -func (c *IamClient) EnableOrganizationsRootCredentialsManagement(ctx context.Context) error { +func (c *iamClient) EnableOrganizationsRootCredentialsManagement(ctx context.Context) error { logger.Debug("aws.EnableOrganizationsRootCredentialsManagement", "enabling organization root credentials management") _, err := c.client.EnableOrganizationsRootCredentialsManagement(ctx, &iam.EnableOrganizationsRootCredentialsManagementInput{}) @@ -201,7 +201,7 @@ func (c *IamClient) EnableOrganizationsRootCredentialsManagement(ctx context.Con } // Enable centralized root sessions -func (c *IamClient) EnableOrganizationsRootSessions(ctx context.Context) error { +func (c *iamClient) EnableOrganizationsRootSessions(ctx context.Context) error { logger.Debug("aws.EnableOrganizationsRootSessions", "enabling organization root sessions") _, err := c.client.EnableOrganizationsRootSessions(ctx, &iam.EnableOrganizationsRootSessionsInput{}) @@ -215,7 +215,7 @@ func (c *IamClient) EnableOrganizationsRootSessions(ctx context.Context) error { } // Allow root password recovery -func (c *IamClient) CreateLoginProfile(ctx context.Context) error { +func (c *iamClient) CreateLoginProfile(ctx context.Context) error { logger.Debug("aws.createLoginProfile", "creating loggin profile") _, err := c.client.CreateLoginProfile(ctx, &iam.CreateLoginProfileInput{}) diff --git a/internal/infra/aws/interfaces.go b/internal/infra/aws/interfaces.go new file mode 100644 index 0000000..32f4fbd --- /dev/null +++ b/internal/infra/aws/interfaces.go @@ -0,0 +1,61 @@ +package aws + +import ( + "context" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" +) + +// IamClient defines the interface for IAM operations. +// This interface enables mocking and dependency injection for testing. +type IamClient interface { + // CheckOrganizationRootAccess verifies if AWS centralized root access is enabled + CheckOrganizationRootAccess(ctx context.Context, rootSessionsRequired bool) error + + // GetLoginProfile checks if an account has root login profile enabled + GetLoginProfile(ctx context.Context, accountId string) (bool, error) + + // DeleteLoginProfile deletes root login profile for a specific account + DeleteLoginProfile(ctx context.Context, accountId string) error + + // ListAccessKeys gets a list of root access keys for a specific account + ListAccessKeys(ctx context.Context, accountId string) ([]string, error) + + // DeleteAccessKeys deletes a list of root access keys for a specific account + DeleteAccessKeys(ctx context.Context, accountId string, accessKeyIds []string) error + + // ListMFADevices gets a list of root MFA devices for a specific account + ListMFADevices(ctx context.Context, accountId string) ([]string, error) + + // DeactivateMFADevices deactivates a list of root MFA devices for a specific account + DeactivateMFADevices(ctx context.Context, accountId string, mfaSerialNumbers []string) error + + // ListSigningCertificates gets a list of root signing certificates for a specific account + ListSigningCertificates(ctx context.Context, accountId string) ([]string, error) + + // DeleteSigningCertificates deletes a list of root signing certificates for a specific account + DeleteSigningCertificates(ctx context.Context, accountId string, certificates []string) error + + // EnableOrganizationsRootCredentialsManagement enables centralized root credentials management + EnableOrganizationsRootCredentialsManagement(ctx context.Context) error + + // EnableOrganizationsRootSessions enables centralized root sessions + EnableOrganizationsRootSessions(ctx context.Context) error + + // CreateLoginProfile allows root password recovery + CreateLoginProfile(ctx context.Context) error +} + +// StsClient defines the interface for STS operations. +// This interface enables mocking and dependency injection for testing. +type StsClient interface { + // GetAssumeRootConfig gets AWS config with assumed root credentials for a specific account and task + GetAssumeRootConfig(ctx context.Context, accountId, taskPolicyName string) (awssdk.Config, error) +} + +// OrganizationsClient defines the interface for AWS Organizations operations. +// This interface enables mocking and dependency injection for testing. +type OrganizationsClient interface { + // EnableAWSServiceAccess enables AWS service access for the organization + EnableAWSServiceAccess(ctx context.Context, service string) error +} diff --git a/internal/infra/aws/organizations.go b/internal/infra/aws/organizations.go index 0372ace..a56fd76 100644 --- a/internal/infra/aws/organizations.go +++ b/internal/infra/aws/organizations.go @@ -11,13 +11,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/organizations/types" ) -type OrganizationsClient struct { +type organizationsClient struct { client *organizations.Client } -func NewOrganizationsClient(awscfg aws.Config) *OrganizationsClient { +func NewOrganizationsClient(awscfg aws.Config) OrganizationsClient { client := organizations.NewFromConfig(awscfg) - return &OrganizationsClient{client: client} + return &organizationsClient{client: client} } type OrganizationAccount struct { @@ -34,14 +34,14 @@ func GetNonManagementOrganizationAccounts(ctx context.Context) ([]OrganizationAc return nil, fmt.Errorf("failed to load aws config: %w", err) } - organizations := NewOrganizationsClient(awscfg) + orgs := NewOrganizationsClient(awscfg) - mgmAccount, err := organizations.describeOrganization(ctx) + mgmAccount, err := orgs.(*organizationsClient).describeOrganization(ctx) if err != nil { return nil, err } - orgAccounts, err := organizations.listOrganizationAccounts() + orgAccounts, err := orgs.(*organizationsClient).listOrganizationAccounts() if err != nil { return nil, err } @@ -60,7 +60,7 @@ func GetNonManagementOrganizationAccounts(ctx context.Context) ([]OrganizationAc return nonManagementOrgAccounts, nil } -func (c *OrganizationsClient) listOrganizationAccounts() ([]types.Account, error) { +func (c *organizationsClient) listOrganizationAccounts() ([]types.Account, error) { logger.Trace("aws.listOrganizationAccounts", "listing organization accounts") params := &organizations.ListAccountsInput{} @@ -79,7 +79,7 @@ func (c *OrganizationsClient) listOrganizationAccounts() ([]types.Account, error return allAccounts, nil } -func (c *OrganizationsClient) describeOrganization(ctx context.Context) (string, error) { +func (c *organizationsClient) describeOrganization(ctx context.Context) (string, error) { logger.Trace("aws.describeOrganization", "describing organization") organization, err := c.client.DescribeOrganization(ctx, &organizations.DescribeOrganizationInput{}) @@ -90,7 +90,7 @@ func (c *OrganizationsClient) describeOrganization(ctx context.Context) (string, return *organization.Organization.MasterAccountId, nil } -func (c *OrganizationsClient) EnableAWSServiceAccess(ctx context.Context, service string) error { +func (c *organizationsClient) EnableAWSServiceAccess(ctx context.Context, service string) error { logger.Trace("aws.EnableAWSServiceAccess", "enabling %s service access", service) _, err := c.client.EnableAWSServiceAccess(ctx, &organizations.EnableAWSServiceAccessInput{ diff --git a/internal/infra/aws/sts.go b/internal/infra/aws/sts.go index 16796e3..d596644 100644 --- a/internal/infra/aws/sts.go +++ b/internal/infra/aws/sts.go @@ -13,16 +13,16 @@ import ( const rootPolicyPrefix = "arn:aws:iam::aws:policy/root-task/" -type StsClient struct { +type stsClient struct { client *sts.Client } -func NewStsClient(awscfg aws.Config) *StsClient { +func NewStsClient(awscfg aws.Config) StsClient { client := sts.NewFromConfig(awscfg) - return &StsClient{client: client} + return &stsClient{client: client} } -func (c *StsClient) GetAssumeRootConfig(ctx context.Context, accountId, taskPolicyName string) (aws.Config, error) { +func (c *stsClient) GetAssumeRootConfig(ctx context.Context, accountId, taskPolicyName string) (aws.Config, error) { logger.Trace("aws.GetAssumeRootConfig", "getting root aws.config account %s and task %s", accountId, taskPolicyName) stsCreds, err := c.assumeRoot(ctx, accountId, taskPolicyName) @@ -48,7 +48,7 @@ func (c *StsClient) GetAssumeRootConfig(ctx context.Context, accountId, taskPoli return awsrootcfg, nil } -func (c *StsClient) assumeRoot(ctx context.Context, accountId, taskPolicyName string) (types.Credentials, error) { +func (c *stsClient) assumeRoot(ctx context.Context, accountId, taskPolicyName string) (types.Credentials, error) { logger.Trace("aws.assumeRoot", "assuming root for account %s and task %s", accountId, taskPolicyName) params := &sts.AssumeRootInput{ diff --git a/internal/logger/interface.go b/internal/logger/interface.go new file mode 100644 index 0000000..1c3c8dc --- /dev/null +++ b/internal/logger/interface.go @@ -0,0 +1,21 @@ +package logger + +// Logger defines the interface for logging operations. +// This interface enables dependency injection and allows for alternative implementations +// such as no-op loggers for testing or custom loggers for different environments. +type Logger interface { + // Trace logs a trace-level message with function name and formatted string + Trace(funcName, format string, args ...any) + + // Debug logs a debug-level message with function name and formatted string + Debug(funcName, format string, args ...any) + + // Info logs an info-level message with function name and formatted string + Info(funcName, format string, args ...any) + + // Warn logs a warning-level message with function name and formatted string + Warn(funcName, format string, args ...any) + + // Error logs an error-level message with function name, error, and formatted string + Error(funcName string, err error, format string, args ...any) +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go index f57d097..08f0d93 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -6,7 +6,15 @@ import ( log "github.com/sirupsen/logrus" ) -func init() { +// logrusLogger implements the Logger interface using logrus. +type logrusLogger struct { + logger *log.Logger +} + +// New creates a new Logger instance configured from environment variables. +func New() Logger { + logger := log.New() + lvl, ok := os.LookupEnv("LOG_LEVEL") if !ok { lvl = "error" // default @@ -16,15 +24,56 @@ func init() { if err != nil { ll = log.DebugLevel } - log.SetLevel(ll) + logger.SetLevel(ll) format, ok := os.LookupEnv("LOG_FORMAT") if ok { - SetLoggerFormat(format) + setLoggerFormat(logger, format) } + + return &logrusLogger{logger: logger} +} + +func setLoggerFormat(logger *log.Logger, logFormat string) { + switch logFormat { + case "json": + logger.SetFormatter(&log.JSONFormatter{}) + default: + logger.SetFormatter(&log.TextFormatter{}) + } +} + +func (l *logrusLogger) Trace(funcName, format string, args ...any) { + l.logger.WithField("function", funcName).Tracef(format, args...) +} + +func (l *logrusLogger) Debug(funcName, format string, args ...any) { + l.logger.WithField("function", funcName).Debugf(format, args...) +} + +func (l *logrusLogger) Info(funcName, format string, args ...any) { + l.logger.WithField("function", funcName).Infof(format, args...) +} + +func (l *logrusLogger) Warn(funcName, format string, args ...any) { + l.logger.WithField("function", funcName).Warnf(format, args...) +} + +func (l *logrusLogger) Error(funcName string, err error, format string, args ...any) { + l.logger.WithField("function", funcName).WithError(err).Errorf(format, args...) +} + +// Global logger functions for backward compatibility +// These will be deprecated once all code uses dependency injection + +var defaultLogger Logger + +func init() { + defaultLogger = New() } func SetLoggerFormat(logFormat string) { + // For backward compatibility with existing code switch logFormat { case "json": log.SetFormatter(&log.JSONFormatter{}) @@ -33,23 +82,23 @@ func SetLoggerFormat(logFormat string) { } } -// Wrap logrus with function name +// Wrap logrus with function name - global functions for backward compatibility func Trace(funcName, format string, args ...any) { - log.WithField("function", funcName).Tracef(format, args...) + defaultLogger.Trace(funcName, format, args...) } func Debug(funcName, format string, args ...any) { - log.WithField("function", funcName).Debugf(format, args...) + defaultLogger.Debug(funcName, format, args...) } func Info(funcName, format string, args ...any) { - log.WithField("function", funcName).Infof(format, args...) + defaultLogger.Info(funcName, format, args...) } func Warn(funcName, format string, args ...any) { - log.WithField("function", funcName).Warnf(format, args...) + defaultLogger.Warn(funcName, format, args...) } func Error(funcName string, err error, format string, args ...any) { - log.WithField("function", funcName).WithError(err).Errorf(format, args...) + defaultLogger.Error(funcName, err, format, args...) } diff --git a/internal/service/audit.go b/internal/service/audit.go index d234f18..633a584 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -10,7 +10,7 @@ import ( ) // auditAccounts returns root credentials for a list of AWS accounts. -func auditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accounts []string) ([]rootmanager.RootCredentials, error) { +func auditAccounts(ctx context.Context, iam aws.IamClient, sts aws.StsClient, factory aws.IamClientFactory, accounts []string) ([]rootmanager.RootCredentials, error) { logger.Trace("service.auditAccounts", "auditing accounts %s", accounts) rootCredentials := make([]rootmanager.RootCredentials, len(accounts)) @@ -24,7 +24,7 @@ func auditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, wgAccounts.Add(1) go func(idx int, accountId string) { defer wgAccounts.Done() - if accStatus, err := auditAccount(ctx, sts, accountId); err != nil { + if accStatus, err := auditAccount(ctx, sts, factory, accountId); err != nil { logger.Error("service.auditAccounts", err, "account %s: audit skipped", accountId) rootCredentials[idx] = rootmanager.RootCredentials{AccountId: accountId, Error: err.Error()} } else { @@ -39,7 +39,7 @@ func auditAccounts(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, } // Get root credentials for a specific account -func auditAccount(ctx context.Context, sts *aws.StsClient, accountId string) (rootmanager.RootCredentials, error) { +func auditAccount(ctx context.Context, sts aws.StsClient, factory aws.IamClientFactory, accountId string) (rootmanager.RootCredentials, error) { logger.Trace("service.auditAccount", "auditing account %s", accountId) awscfgRoot, err := sts.GetAssumeRootConfig(ctx, accountId, "IAMAuditRootUserCredentials") @@ -47,7 +47,7 @@ func auditAccount(ctx context.Context, sts *aws.StsClient, accountId string) (ro return rootmanager.RootCredentials{}, err } - iamRoot := aws.NewIamClient(awscfgRoot) + iamRoot := factory.NewIamClient(awscfgRoot) var accountRootCredentials rootmanager.RootCredentials loginProfile, err := iamRoot.GetLoginProfile(ctx, accountId) diff --git a/internal/service/configuration.go b/internal/service/configuration.go index 5ed71b6..3c13daf 100644 --- a/internal/service/configuration.go +++ b/internal/service/configuration.go @@ -9,7 +9,7 @@ import ( "github.com/unicrons/aws-root-manager/rootmanager" ) -func checkRootAccess(ctx context.Context, iam *aws.IamClient) (rootmanager.RootAccessStatus, error) { +func checkRootAccess(ctx context.Context, iam aws.IamClient) (rootmanager.RootAccessStatus, error) { var status = rootmanager.RootAccessStatus{ TrustedAccess: false, RootCredentialsManagement: false, @@ -44,7 +44,7 @@ func checkRootAccess(ctx context.Context, iam *aws.IamClient) (rootmanager.RootA return status, nil } -func enableRootAccess(ctx context.Context, iam *aws.IamClient, org *aws.OrganizationsClient, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { +func enableRootAccess(ctx context.Context, iam aws.IamClient, org aws.OrganizationsClient, enableSessions bool) (rootmanager.RootAccessStatus, rootmanager.RootAccessStatus, error) { var initStatus, status rootmanager.RootAccessStatus initStatus, err := checkRootAccess(ctx, iam) diff --git a/internal/service/credentials.go b/internal/service/credentials.go index f838694..1ee516f 100644 --- a/internal/service/credentials.go +++ b/internal/service/credentials.go @@ -11,38 +11,45 @@ import ( ) // deleteAccountsCredentials deletes root credentials for a list of AWS accounts. -func deleteAccountsCredentials(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, creds []rootmanager.RootCredentials, credentialType string) error { - var ( - wgAccounts sync.WaitGroup - errChan = make(chan error, len(creds)) - ) - +// Returns a slice of DeletionResult containing the outcome for each account. +func deleteAccountsCredentials(ctx context.Context, iam aws.IamClient, sts aws.StsClient, factory aws.IamClientFactory, creds []rootmanager.RootCredentials, credentialType string) ([]rootmanager.DeletionResult, error) { if err := iam.CheckOrganizationRootAccess(ctx, false); err != nil { - return err + return nil, err } - for _, accountCredentials := range creds { + results := make([]rootmanager.DeletionResult, len(creds)) + var wgAccounts sync.WaitGroup + + for i, accountCredentials := range creds { wgAccounts.Add(1) - go func(accountCreds rootmanager.RootCredentials) { + go func(idx int, accountCreds rootmanager.RootCredentials) { defer wgAccounts.Done() - if err := deleteAccountCredentials(ctx, sts, accountCreds, credentialType); err != nil { - errChan <- err + if err := deleteAccountCredentials(ctx, sts, factory, accountCreds, credentialType); err != nil { + logger.Error("service.deleteAccountsCredentials", err, "account %s: deletion failed", accountCreds.AccountId) + results[idx] = rootmanager.DeletionResult{ + AccountId: accountCreds.AccountId, + CredentialType: credentialType, + Success: false, + Error: err.Error(), + } + } else { + results[idx] = rootmanager.DeletionResult{ + AccountId: accountCreds.AccountId, + CredentialType: credentialType, + Success: true, + Error: "", + } } - }(accountCredentials) + }(i, accountCredentials) } wgAccounts.Wait() - close(errChan) - - if len(errChan) > 0 { - return <-errChan - } - return nil + return results, nil } // deleteAccountCredentials deletes root credentials for a specific account. -func deleteAccountCredentials(ctx context.Context, sts *aws.StsClient, creds rootmanager.RootCredentials, credentialType string) error { +func deleteAccountCredentials(ctx context.Context, sts aws.StsClient, factory aws.IamClientFactory, creds rootmanager.RootCredentials, credentialType string) error { logger.Trace("service.deleteAccountCredentials", "checking if account %s has %s credentials to delete", credentialType, credentialType) // Check if there are credentials to delete before assuming root @@ -55,7 +62,7 @@ func deleteAccountCredentials(ctx context.Context, sts *aws.StsClient, creds roo if err != nil { return err } - iamDeleteRoot := aws.NewIamClient(awscfgDeleteRoot) + iamDeleteRoot := factory.NewIamClient(awscfgDeleteRoot) if creds.LoginProfile && (credentialType == "all" || credentialType == "login") { err = iamDeleteRoot.DeleteLoginProfile(ctx, creds.AccountId) @@ -107,54 +114,51 @@ func hasCredentialsToDelete(creds rootmanager.RootCredentials, credentialType st } // recoverAccountsRootPassword initiates root password recovery for a list of AWS accounts. -func recoverAccountsRootPassword(ctx context.Context, iam *aws.IamClient, sts *aws.StsClient, accountIds []string) (map[string]bool, error) { - var ( - wgAccounts sync.WaitGroup - results = sync.Map{} - errChan = make(chan error, len(accountIds)) - ) - +// Returns a slice of RecoveryResult containing the outcome for each account. +func recoverAccountsRootPassword(ctx context.Context, iam aws.IamClient, sts aws.StsClient, factory aws.IamClientFactory, accountIds []string) ([]rootmanager.RecoveryResult, error) { if err := iam.CheckOrganizationRootAccess(ctx, false); err != nil { return nil, err } - for _, acc := range accountIds { + results := make([]rootmanager.RecoveryResult, len(accountIds)) + var wgAccounts sync.WaitGroup + + for i, accountId := range accountIds { wgAccounts.Add(1) - go func(accountId string) { + go func(idx int, accId string) { defer wgAccounts.Done() - success, err := recoverAccountRootPassowrd(ctx, sts, acc) - results.Store(accountId, success) + success, err := recoverAccountRootPassowrd(ctx, sts, factory, accId) if err != nil { - errChan <- err + logger.Error("service.recoverAccountsRootPassword", err, "account %s: recovery failed", accId) + results[idx] = rootmanager.RecoveryResult{ + AccountId: accId, + Success: false, + Error: err.Error(), + } + } else { + results[idx] = rootmanager.RecoveryResult{ + AccountId: accId, + Success: success, + Error: "", + } } - }(acc) + }(i, accountId) } wgAccounts.Wait() - close(errChan) - - resultMap := make(map[string]bool) - results.Range(func(key, value any) bool { - resultMap[key.(string)] = value.(bool) - return true - }) - - if len(errChan) > 0 { - return resultMap, <-errChan - } - return resultMap, nil + return results, nil } // Enable the recovery process for root passwords for a specific account -func recoverAccountRootPassowrd(ctx context.Context, sts *aws.StsClient, accountId string) (bool, error) { +func recoverAccountRootPassowrd(ctx context.Context, sts aws.StsClient, factory aws.IamClientFactory, accountId string) (bool, error) { logger.Trace("service.recoverAccountRootPassowrd", "trying to recover root password for account %s ", accountId) awscfgRecoverRoot, err := sts.GetAssumeRootConfig(ctx, accountId, "IAMCreateRootUserPassword") if err != nil { return false, err } - iamRecoverRoot := aws.NewIamClient(awscfgRecoverRoot) + iamRecoverRoot := factory.NewIamClient(awscfgRecoverRoot) err = iamRecoverRoot.CreateLoginProfile(ctx) if err != nil { diff --git a/internal/service/rootmanager.go b/internal/service/rootmanager.go index 3832f15..054a859 100644 --- a/internal/service/rootmanager.go +++ b/internal/service/rootmanager.go @@ -10,22 +10,37 @@ import ( // manager implements rootmanager.RootManager using AWS clients. type manager struct { - iam *aws.IamClient - sts *aws.StsClient - org *aws.OrganizationsClient + iam aws.IamClient + sts aws.StsClient + org aws.OrganizationsClient + factory aws.IamClientFactory } -// NewRootManager returns a RootManager that uses the given AWS clients. +// NewRootManager returns a RootManager that uses the given AWS clients and factory. // sts and org may be nil for callers that only use CheckRootAccess. -func NewRootManager(iam *aws.IamClient, sts *aws.StsClient, org *aws.OrganizationsClient) rootmanager.RootManager { - return &manager{iam: iam, sts: sts, org: org} +func NewRootManager(iam aws.IamClient, sts aws.StsClient, org aws.OrganizationsClient, factory aws.IamClientFactory) rootmanager.RootManager { + return &manager{iam: iam, sts: sts, org: org, factory: factory} +} + +// NewRootManagerFromConfig loads the default AWS config and returns a ready-to-use RootManager. +func NewRootManagerFromConfig(ctx context.Context) (rootmanager.RootManager, error) { + cfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + return nil, err + } + return NewRootManager( + aws.NewIamClient(cfg), + aws.NewStsClient(cfg), + aws.NewOrganizationsClient(cfg), + &aws.DefaultIamClientFactory{}, + ), nil } func (m *manager) AuditAccounts(ctx context.Context, accountIds []string) ([]rootmanager.RootCredentials, error) { if m.sts == nil { return nil, errors.New("STS client required for audit") } - return auditAccounts(ctx, m.iam, m.sts, accountIds) + return auditAccounts(ctx, m.iam, m.sts, m.factory, accountIds) } func (m *manager) CheckRootAccess(ctx context.Context) (rootmanager.RootAccessStatus, error) { @@ -39,20 +54,16 @@ func (m *manager) EnableRootAccess(ctx context.Context, enableSessions bool) (ro return enableRootAccess(ctx, m.iam, m.org, enableSessions) } -func (m *manager) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) error { +func (m *manager) DeleteCredentials(ctx context.Context, creds []rootmanager.RootCredentials, credentialType string) ([]rootmanager.DeletionResult, error) { if m.sts == nil { - return errors.New("STS client required for delete") + return nil, errors.New("STS client required for delete") } - return deleteAccountsCredentials(ctx, m.iam, m.sts, creds, credentialType) + return deleteAccountsCredentials(ctx, m.iam, m.sts, m.factory, creds, credentialType) } -func (m *manager) RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) { +func (m *manager) RecoverRootPassword(ctx context.Context, accountIds []string) ([]rootmanager.RecoveryResult, error) { if m.sts == nil { return nil, errors.New("STS client required for recovery") } - resultMap, err := recoverAccountsRootPassword(ctx, m.iam, m.sts, accountIds) - if err != nil { - return nil, err - } - return resultMap, nil + return recoverAccountsRootPassword(ctx, m.iam, m.sts, m.factory, accountIds) } diff --git a/rootmanager/api.go b/rootmanager/api.go index 3281eba..a97584a 100644 --- a/rootmanager/api.go +++ b/rootmanager/api.go @@ -20,10 +20,11 @@ type RootManager interface { // DeleteCredentials deletes root credentials for the specified accounts. // The creds parameter should contain audit results identifying what credentials exist. // The credentialType parameter specifies what to delete: "all", "login", "keys", "mfa", or "certificate". - DeleteCredentials(ctx context.Context, creds []RootCredentials, credentialType string) error + // Returns a slice of DeletionResult showing the outcome for each account. + DeleteCredentials(ctx context.Context, creds []RootCredentials, credentialType string) ([]DeletionResult, error) // RecoverRootPassword initiates root password recovery for the specified accounts. // This triggers AWS to send password reset emails to the account's root email address. - // Returns a map of account ID to success status, and any error encountered. - RecoverRootPassword(ctx context.Context, accountIds []string) (map[string]bool, error) + // Returns a slice of RecoveryResult showing the outcome for each account. + RecoverRootPassword(ctx context.Context, accountIds []string) ([]RecoveryResult, error) } diff --git a/rootmanager/types.go b/rootmanager/types.go index 94ba336..59f5fb1 100644 --- a/rootmanager/types.go +++ b/rootmanager/types.go @@ -21,4 +21,13 @@ type RootCredentials struct { type RecoveryResult struct { AccountId string // AWS account ID Success bool // Whether recovery email was successfully sent + Error string // Error message if recovery failed (empty if Success=true) +} + +// DeletionResult represents the result of a credential deletion operation for an account. +type DeletionResult struct { + AccountId string // AWS account ID + CredentialType string // Type of credential deleted (login, keys, mfa, certificate, all) + Success bool // Whether deletion was successful + Error string // Error message if deletion failed (empty if Success=true) } From e522de7a7628f5d6757472f3e4b44ebb808e6e0a Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Thu, 12 Mar 2026 14:14:42 +0100 Subject: [PATCH 08/12] revert: remove logger interface and restore simple global logger --- internal/logger/interface.go | 21 ----------- internal/logger/logger.go | 67 +++++------------------------------- 2 files changed, 9 insertions(+), 79 deletions(-) delete mode 100644 internal/logger/interface.go diff --git a/internal/logger/interface.go b/internal/logger/interface.go deleted file mode 100644 index 1c3c8dc..0000000 --- a/internal/logger/interface.go +++ /dev/null @@ -1,21 +0,0 @@ -package logger - -// Logger defines the interface for logging operations. -// This interface enables dependency injection and allows for alternative implementations -// such as no-op loggers for testing or custom loggers for different environments. -type Logger interface { - // Trace logs a trace-level message with function name and formatted string - Trace(funcName, format string, args ...any) - - // Debug logs a debug-level message with function name and formatted string - Debug(funcName, format string, args ...any) - - // Info logs an info-level message with function name and formatted string - Info(funcName, format string, args ...any) - - // Warn logs a warning-level message with function name and formatted string - Warn(funcName, format string, args ...any) - - // Error logs an error-level message with function name, error, and formatted string - Error(funcName string, err error, format string, args ...any) -} diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 08f0d93..f57d097 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -6,15 +6,7 @@ import ( log "github.com/sirupsen/logrus" ) -// logrusLogger implements the Logger interface using logrus. -type logrusLogger struct { - logger *log.Logger -} - -// New creates a new Logger instance configured from environment variables. -func New() Logger { - logger := log.New() - +func init() { lvl, ok := os.LookupEnv("LOG_LEVEL") if !ok { lvl = "error" // default @@ -24,56 +16,15 @@ func New() Logger { if err != nil { ll = log.DebugLevel } - logger.SetLevel(ll) + log.SetLevel(ll) format, ok := os.LookupEnv("LOG_FORMAT") if ok { - setLoggerFormat(logger, format) + SetLoggerFormat(format) } - - return &logrusLogger{logger: logger} -} - -func setLoggerFormat(logger *log.Logger, logFormat string) { - switch logFormat { - case "json": - logger.SetFormatter(&log.JSONFormatter{}) - default: - logger.SetFormatter(&log.TextFormatter{}) - } -} - -func (l *logrusLogger) Trace(funcName, format string, args ...any) { - l.logger.WithField("function", funcName).Tracef(format, args...) -} - -func (l *logrusLogger) Debug(funcName, format string, args ...any) { - l.logger.WithField("function", funcName).Debugf(format, args...) -} - -func (l *logrusLogger) Info(funcName, format string, args ...any) { - l.logger.WithField("function", funcName).Infof(format, args...) -} - -func (l *logrusLogger) Warn(funcName, format string, args ...any) { - l.logger.WithField("function", funcName).Warnf(format, args...) -} - -func (l *logrusLogger) Error(funcName string, err error, format string, args ...any) { - l.logger.WithField("function", funcName).WithError(err).Errorf(format, args...) -} - -// Global logger functions for backward compatibility -// These will be deprecated once all code uses dependency injection - -var defaultLogger Logger - -func init() { - defaultLogger = New() } func SetLoggerFormat(logFormat string) { - // For backward compatibility with existing code switch logFormat { case "json": log.SetFormatter(&log.JSONFormatter{}) @@ -82,23 +33,23 @@ func SetLoggerFormat(logFormat string) { } } -// Wrap logrus with function name - global functions for backward compatibility +// Wrap logrus with function name func Trace(funcName, format string, args ...any) { - defaultLogger.Trace(funcName, format, args...) + log.WithField("function", funcName).Tracef(format, args...) } func Debug(funcName, format string, args ...any) { - defaultLogger.Debug(funcName, format, args...) + log.WithField("function", funcName).Debugf(format, args...) } func Info(funcName, format string, args ...any) { - defaultLogger.Info(funcName, format, args...) + log.WithField("function", funcName).Infof(format, args...) } func Warn(funcName, format string, args ...any) { - defaultLogger.Warn(funcName, format, args...) + log.WithField("function", funcName).Warnf(format, args...) } func Error(funcName string, err error, format string, args ...any) { - defaultLogger.Error(funcName, err, format, args...) + log.WithField("function", funcName).WithError(err).Errorf(format, args...) } From 674a3466911d35e4b20aa8de5d44cc0d38280c1d Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Thu, 12 Mar 2026 14:31:44 +0100 Subject: [PATCH 09/12] refactor(logger): replace logrus with log/slog --- cmd/audit.go | 14 +++--- cmd/check.go | 8 ++-- cmd/delete.go | 6 +-- cmd/enable.go | 8 ++-- cmd/recovery.go | 14 +++--- go.mod | 1 - go.sum | 10 ----- internal/cli/output/output.go | 9 ++-- internal/cli/ui/account_selector.go | 6 +-- internal/infra/aws/iam.go | 42 +++++++++--------- internal/infra/aws/organizations.go | 11 +++-- internal/infra/aws/sts.go | 9 ++-- internal/logger/logger.go | 66 ++++++++++++----------------- internal/service/audit.go | 16 +++---- internal/service/configuration.go | 8 ++-- internal/service/credentials.go | 12 +++--- 16 files changed, 107 insertions(+), 133 deletions(-) diff --git a/cmd/audit.go b/cmd/audit.go index 76657c3..286e510 100644 --- a/cmd/audit.go +++ b/cmd/audit.go @@ -3,11 +3,11 @@ package cmd import ( "context" "fmt" + "log/slog" "strings" "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/cli/ui" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" @@ -20,29 +20,29 @@ func Audit() *cobra.Command { Long: `Retrieve available root user credentials for all member accounts within an AWS Organization.`, SilenceUsage: true, RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.audit", "audit called") + slog.Debug("audit called") ctx := context.Background() rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.audit", err, "failed to initialize root manager") + slog.Error("failed to initialize root manager", "error", err) return err } auditAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) if err != nil { - logger.Error("cmd.audit", err, "failed to get accounts to audit") + slog.Error("failed to get accounts to audit", "error", err) return err } if len(auditAccounts) == 0 { - logger.Info("cmd.audit", "no accounts selected") + slog.Info("no accounts selected") return nil } - logger.Debug("cmd.audit", "selected accounts: %s", strings.Join(auditAccounts, ", ")) + slog.Debug("selected accounts", "accounts", strings.Join(auditAccounts, ", ")) audit, err := rm.AuditAccounts(ctx, auditAccounts) if err != nil { - logger.Error("cmd.audit", err, "failed to audit accounts") + slog.Error("failed to audit accounts", "error", err) return err } diff --git a/cmd/check.go b/cmd/check.go index 091c76b..fe7132b 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -2,10 +2,10 @@ package cmd import ( "context" + "log/slog" "strconv" "github.com/unicrons/aws-root-manager/internal/cli/output" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" @@ -17,18 +17,18 @@ func Check() *cobra.Command { Short: "Check if centralized root access is enabled.", Long: `Retrieve the status of centralized root access settings for an AWS Organization.`, RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.check", "check called") + slog.Debug("check called") ctx := context.Background() rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.check", err, "failed to initialize root manager") + slog.Error("failed to initialize root manager", "error", err) return err } status, err := rm.CheckRootAccess(ctx) if err != nil { - logger.Error("cmd.check", err, "failed to check root access configuration") + slog.Error("failed to check root access configuration", "error", err) return err } diff --git a/cmd/delete.go b/cmd/delete.go index 15ef39b..fc2a368 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -3,11 +3,11 @@ package cmd import ( "context" "fmt" + "log/slog" "strings" "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/cli/ui" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" @@ -56,10 +56,10 @@ func runDelete(accountsFlags []string, credentialType string) error { return fmt.Errorf("failed to get accounts to audit: %w", err) } if len(auditAccounts) == 0 { - logger.Info("cmd.delete", "no accounts selected") + slog.Info("no accounts selected") return nil } - logger.Debug("cmd.delete", "selected accounts: %s", strings.Join(auditAccounts, ", ")) + slog.Debug("selected accounts", "accounts", strings.Join(auditAccounts, ", ")) audit, err := rm.AuditAccounts(ctx, auditAccounts) if err != nil { diff --git a/cmd/enable.go b/cmd/enable.go index 6753404..81937aa 100644 --- a/cmd/enable.go +++ b/cmd/enable.go @@ -2,10 +2,10 @@ package cmd import ( "context" + "log/slog" "strconv" "github.com/unicrons/aws-root-manager/internal/cli/output" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" @@ -17,20 +17,20 @@ func Enable() *cobra.Command { Short: "Enable centralized root access", Long: `Enable centralized root access management in an AWS Organization.`, RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.enable", "enable called") + slog.Debug("enable called") enableRootSessions, _ := cmd.Flags().GetBool("enableRootSessions") ctx := context.Background() rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.enable", err, "failed to initialize root manager") + slog.Error("failed to initialize root manager", "error", err) return err } initStatus, status, err := rm.EnableRootAccess(ctx, enableRootSessions) if err != nil { - logger.Error("cmd.enable", err, "failed to enable root access") + slog.Error("failed to enable root access", "error", err) return err } diff --git a/cmd/recovery.go b/cmd/recovery.go index 30a1b1d..f5f902c 100644 --- a/cmd/recovery.go +++ b/cmd/recovery.go @@ -3,11 +3,11 @@ package cmd import ( "context" "fmt" + "log/slog" "strings" "github.com/unicrons/aws-root-manager/internal/cli/output" "github.com/unicrons/aws-root-manager/internal/cli/ui" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/internal/service" "github.com/spf13/cobra" @@ -19,29 +19,29 @@ func Recovery() *cobra.Command { Short: "Allow root password recovery", Long: `Retrieve the status of centralized root access settings for an AWS Organization.`, RunE: func(cmd *cobra.Command, args []string) error { - logger.Trace("cmd.recovery", "recovery called") + slog.Debug("recovery called") ctx := context.Background() rm, err := service.NewRootManagerFromConfig(ctx) if err != nil { - logger.Error("cmd.recovery", err, "failed to initialize root manager") + slog.Error("failed to initialize root manager", "error", err) return err } targetAccounts, err := ui.SelectTargetAccounts(ctx, accountsFlags) if err != nil { - logger.Error("cmd.recovery", err, "failed to get target accounts") + slog.Error("failed to get target accounts", "error", err) return err } if len(targetAccounts) == 0 { - logger.Info("cmd.recovery", "no accounts selected") + slog.Info("no accounts selected") return nil } - logger.Debug("cmd.recovery", "selected accounts: %s", strings.Join(targetAccounts, ", ")) + slog.Debug("selected accounts", "accounts", strings.Join(targetAccounts, ", ")) results, err := rm.RecoverRootPassword(ctx, targetAccounts) if err != nil { - logger.Error("cmd.recovery", err, "failed to recover root password") + slog.Error("failed to recover root password", "error", err) return err } diff --git a/go.mod b/go.mod index a39d02c..af206eb 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/x/term v0.2.2 - github.com/sirupsen/logrus v1.9.4 github.com/spf13/cobra v1.10.2 ) diff --git a/go.sum b/go.sum index cbcbb32..0fe338c 100644 --- a/go.sum +++ b/go.sum @@ -49,8 +49,6 @@ github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -69,20 +67,14 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= -github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= @@ -95,5 +87,3 @@ golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cli/output/output.go b/internal/cli/output/output.go index 610a937..f6a8954 100644 --- a/internal/cli/output/output.go +++ b/internal/cli/output/output.go @@ -2,8 +2,7 @@ package output import ( "fmt" - - "github.com/unicrons/aws-root-manager/internal/logger" + "log/slog" ) // HandleOutput handles the output based on the specified format @@ -11,16 +10,16 @@ func HandleOutput(format string, headers []string, rawData [][]any) { switch format { case "json": if err := PrintJSON(headers, rawData); err != nil { - logger.Error("output.HandleOutput", err, "error printing json") + slog.Error("error printing json", "error", err) } case "csv": if err := printCSV(headers, dataToString(rawData)); err != nil { - logger.Error("output.HandleOutput", err, "error printing csv") + slog.Error("error printing csv", "error", err) } case "table": printTable(headers, rawData) default: - logger.Error("output.HandleOutput", nil, "unsupported output format: %v", format) + slog.Error("unsupported output format", "format", format) } } diff --git a/internal/cli/ui/account_selector.go b/internal/cli/ui/account_selector.go index f253d92..0faaa83 100644 --- a/internal/cli/ui/account_selector.go +++ b/internal/cli/ui/account_selector.go @@ -3,9 +3,9 @@ package ui import ( "context" "fmt" + "log/slog" "github.com/unicrons/aws-root-manager/internal/infra/aws" - "github.com/unicrons/aws-root-manager/internal/logger" ) const ( @@ -16,7 +16,7 @@ const ( // SelectTargetAccounts handles interactive account selection or returns accounts based on flags. // Returns account IDs based on flags or TUI prompt. func SelectTargetAccounts(ctx context.Context, accountsFlag []string) ([]string, error) { - logger.Trace("ui.SelectTargetAccounts", "processing target accounts: %s", accountsFlag) + slog.Debug("processing target accounts", "accounts_flag", accountsFlag) // if accounts are provided and "all" is not specified, return them if len(accountsFlag) > 0 && accountsFlag[0] != AllAccountsOption { @@ -50,7 +50,7 @@ func SelectTargetAccounts(ctx context.Context, accountsFlag []string) ([]string, // Resolve selected accounts if allSelected(selectedIndexes) { - logger.Debug("ui.SelectTargetAccounts", "all accounts selected") + slog.Debug("all accounts selected") return convertAccountsToIDs(orgAccounts), nil } diff --git a/internal/infra/aws/iam.go b/internal/infra/aws/iam.go index a8e1e60..6b1f989 100644 --- a/internal/infra/aws/iam.go +++ b/internal/infra/aws/iam.go @@ -4,9 +4,9 @@ import ( "context" "errors" "fmt" + "log/slog" "slices" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/rootmanager" "github.com/aws/aws-sdk-go-v2/aws" @@ -25,7 +25,7 @@ func NewIamClient(awscfg aws.Config) IamClient { // Verifies if AWS centralized root access is enabled func (c *iamClient) CheckOrganizationRootAccess(ctx context.Context, rootSessionsRequired bool) error { - logger.Trace("aws.CheckOrganizationRootAccess", "checking if organization root access is enabled") + slog.Debug("checking if organization root access is enabled") features, err := c.client.ListOrganizationsFeatures(ctx, &iam.ListOrganizationsFeaturesInput{}) if err != nil { @@ -55,13 +55,13 @@ func (c *iamClient) CheckOrganizationRootAccess(ctx context.Context, rootSession // Check if an account has root login profile enabled func (c *iamClient) GetLoginProfile(ctx context.Context, accountId string) (bool, error) { - logger.Debug("aws.GetLoginProfile", "getting login profile for account %s", accountId) + slog.Debug("getting login profile", "account_id", accountId) _, err := c.client.GetLoginProfile(ctx, &iam.GetLoginProfileInput{}) if err != nil { var notFoundErr *types.NoSuchEntityException if errors.As(err, ¬FoundErr) { - logger.Debug("aws.GetLoginProfile", "account %s does not have a root login profile", accountId) + slog.Debug("account does not have a root login profile", "account_id", accountId) return false, nil } return true, fmt.Errorf("error getting root login profile for account %s: %w", accountId, err) @@ -72,21 +72,21 @@ func (c *iamClient) GetLoginProfile(ctx context.Context, accountId string) (bool // Delete root login profile for a specific account func (c *iamClient) DeleteLoginProfile(ctx context.Context, accountId string) error { - logger.Debug("aws.DeleteLoginProfile", "deleting login profile for account %s", accountId) + slog.Debug("deleting login profile", "account_id", accountId) _, err := c.client.DeleteLoginProfile(ctx, &iam.DeleteLoginProfileInput{}) if err != nil { return fmt.Errorf("error deleting root login profile for account %s: %w", accountId, err) } - logger.Info("aws.DeleteLoginProfile", "successfully deleted login profile for account %s", accountId) + slog.Info("successfully deleted login profile", "account_id", accountId) return nil } // Get a list of root access keys for a specific account func (c *iamClient) ListAccessKeys(ctx context.Context, accountId string) ([]string, error) { - logger.Debug("aws.ListAccessKeys", "listing access keys for account %s", accountId) + slog.Debug("listing access keys", "account_id", accountId) accessKeys, err := c.client.ListAccessKeys(ctx, &iam.ListAccessKeysInput{}) if err != nil { @@ -103,7 +103,7 @@ func (c *iamClient) ListAccessKeys(ctx context.Context, accountId string) ([]str // Delete a list of root access for a specific account func (c *iamClient) DeleteAccessKeys(ctx context.Context, accountId string, accessKeyIds []string) error { - logger.Debug("aws.DeleteAccessKeys", "deleting root access key %s for account %s", accessKeyIds, accountId) + slog.Debug("deleting root access keys", "account_id", accountId, "access_key_ids", accessKeyIds) for _, accessKeyId := range accessKeyIds { _, err := c.client.DeleteAccessKey(ctx, &iam.DeleteAccessKeyInput{ @@ -114,7 +114,7 @@ func (c *iamClient) DeleteAccessKeys(ctx context.Context, accountId string, acce } } - logger.Info("aws.DeleteAccessKeys", "successfully deleted access keys for account %s", accountId) + slog.Info("successfully deleted access keys", "account_id", accountId) return nil } @@ -136,7 +136,7 @@ func (c *iamClient) ListMFADevices(ctx context.Context, accountId string) ([]str // Deactivate a list of root MFA devices for a specific account func (c *iamClient) DeactivateMFADevices(ctx context.Context, accountId string, mfaSerialNumbers []string) error { - logger.Debug("aws.DeactivateMFADevices", "deleting root mfa device %s for account %s", mfaSerialNumbers, accountId) + slog.Debug("deactivating root mfa devices", "account_id", accountId, "mfa_serial_numbers", mfaSerialNumbers) for _, mfaSerialNumber := range mfaSerialNumbers { _, err := c.client.DeactivateMFADevice(ctx, &iam.DeactivateMFADeviceInput{ @@ -147,7 +147,7 @@ func (c *iamClient) DeactivateMFADevices(ctx context.Context, accountId string, } } - logger.Info("aws.DeactivateMFADevices", "successfully deactivated mfa devices for account %s", accountId) + slog.Info("successfully deactivated mfa devices", "account_id", accountId) return nil } @@ -169,11 +169,9 @@ func (c *iamClient) ListSigningCertificates(ctx context.Context, accountId strin // Delete a list of root signing certificates for a specific account func (c *iamClient) DeleteSigningCertificates(ctx context.Context, accountId string, certificates []string) error { - logger.Debug("aws.DeleteSigningCertificates", "deleting singin certificates %s for account %s", certificates, accountId) + slog.Debug("deleting signing certificates", "account_id", accountId, "certificates", certificates) for _, certificate := range certificates { - logger.Debug("aws.DeleteSigningCertificates", "deleting root Signing Certificate %s", certificate) - if _, err := c.client.DeleteSigningCertificate(ctx, &iam.DeleteSigningCertificateInput{ CertificateId: aws.String(certificate), }); err != nil { @@ -181,54 +179,54 @@ func (c *iamClient) DeleteSigningCertificates(ctx context.Context, accountId str } } - logger.Info("aws.DeleteSigningCertificates", "successfully deleted signing certificates for account %s", accountId) + slog.Info("successfully deleted signing certificates", "account_id", accountId) return nil } // Enable centralized root credentials management func (c *iamClient) EnableOrganizationsRootCredentialsManagement(ctx context.Context) error { - logger.Debug("aws.EnableOrganizationsRootCredentialsManagement", "enabling organization root credentials management") + slog.Debug("enabling organization root credentials management") _, err := c.client.EnableOrganizationsRootCredentialsManagement(ctx, &iam.EnableOrganizationsRootCredentialsManagementInput{}) if err != nil { return fmt.Errorf("error enabling organization root credentials management: %w", err) } - logger.Info("aws.EnableOrganizationsRootCredentialsManagement", "successfully enabled organization root credentials management") + slog.Info("successfully enabled organization root credentials management") return nil } // Enable centralized root sessions func (c *iamClient) EnableOrganizationsRootSessions(ctx context.Context) error { - logger.Debug("aws.EnableOrganizationsRootSessions", "enabling organization root sessions") + slog.Debug("enabling organization root sessions") _, err := c.client.EnableOrganizationsRootSessions(ctx, &iam.EnableOrganizationsRootSessionsInput{}) if err != nil { return fmt.Errorf("error enabling organization root sessions: %w", err) } - logger.Info("aws.EnableOrganizationsRootSessions", "successfully enabled organization root sessions management") + slog.Info("successfully enabled organization root sessions management") return nil } // Allow root password recovery func (c *iamClient) CreateLoginProfile(ctx context.Context) error { - logger.Debug("aws.createLoginProfile", "creating loggin profile") + slog.Debug("creating login profile") _, err := c.client.CreateLoginProfile(ctx, &iam.CreateLoginProfileInput{}) if err != nil { var entityAlreadyExistsErr *types.EntityAlreadyExistsException if errors.As(err, &entityAlreadyExistsErr) { - logger.Debug("aws.createLoginProfile", "login profile already exists") + slog.Debug("login profile already exists") return rootmanager.ErrEntityAlreadyExists } return fmt.Errorf("error creating login profile: %w", err) } - logger.Info("aws.createLoginProfile", "successfully created login profile") + slog.Info("successfully created login profile") return nil } diff --git a/internal/infra/aws/organizations.go b/internal/infra/aws/organizations.go index a56fd76..80dc414 100644 --- a/internal/infra/aws/organizations.go +++ b/internal/infra/aws/organizations.go @@ -3,8 +3,7 @@ package aws import ( "context" "fmt" - - "github.com/unicrons/aws-root-manager/internal/logger" + "log/slog" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/organizations" @@ -27,7 +26,7 @@ type OrganizationAccount struct { // Fetches AWS Organization accounts, excluding the management account func GetNonManagementOrganizationAccounts(ctx context.Context) ([]OrganizationAccount, error) { - logger.Trace("aws.GetNonManagementOrganizationAccounts", "getting organization accounts") + slog.Debug("getting organization accounts") awscfg, err := LoadAWSConfig(ctx) if err != nil { @@ -61,7 +60,7 @@ func GetNonManagementOrganizationAccounts(ctx context.Context) ([]OrganizationAc } func (c *organizationsClient) listOrganizationAccounts() ([]types.Account, error) { - logger.Trace("aws.listOrganizationAccounts", "listing organization accounts") + slog.Debug("listing organization accounts") params := &organizations.ListAccountsInput{} paginator := organizations.NewListAccountsPaginator(c.client, params) @@ -80,7 +79,7 @@ func (c *organizationsClient) listOrganizationAccounts() ([]types.Account, error } func (c *organizationsClient) describeOrganization(ctx context.Context) (string, error) { - logger.Trace("aws.describeOrganization", "describing organization") + slog.Debug("describing organization") organization, err := c.client.DescribeOrganization(ctx, &organizations.DescribeOrganizationInput{}) if err != nil { @@ -91,7 +90,7 @@ func (c *organizationsClient) describeOrganization(ctx context.Context) (string, } func (c *organizationsClient) EnableAWSServiceAccess(ctx context.Context, service string) error { - logger.Trace("aws.EnableAWSServiceAccess", "enabling %s service access", service) + slog.Debug("enabling service access", "service", service) _, err := c.client.EnableAWSServiceAccess(ctx, &organizations.EnableAWSServiceAccessInput{ ServicePrincipal: aws.String(service), diff --git a/internal/infra/aws/sts.go b/internal/infra/aws/sts.go index d596644..00bf9ce 100644 --- a/internal/infra/aws/sts.go +++ b/internal/infra/aws/sts.go @@ -3,8 +3,7 @@ package aws import ( "context" "fmt" - - "github.com/unicrons/aws-root-manager/internal/logger" + "log/slog" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -23,7 +22,7 @@ func NewStsClient(awscfg aws.Config) StsClient { } func (c *stsClient) GetAssumeRootConfig(ctx context.Context, accountId, taskPolicyName string) (aws.Config, error) { - logger.Trace("aws.GetAssumeRootConfig", "getting root aws.config account %s and task %s", accountId, taskPolicyName) + slog.Debug("getting root aws config", "account_id", accountId, "task", taskPolicyName) stsCreds, err := c.assumeRoot(ctx, accountId, taskPolicyName) if err != nil { @@ -43,13 +42,13 @@ func (c *stsClient) GetAssumeRootConfig(ctx context.Context, accountId, taskPoli return aws.Config{}, fmt.Errorf("error loading aws root config: %s", err) } - logger.Debug("aws.GetAssumeRootConfig", "successfully generated assume root credentials for account %s and task %s", accountId, taskPolicyName) + slog.Debug("successfully generated assume root credentials", "account_id", accountId, "task", taskPolicyName) return awsrootcfg, nil } func (c *stsClient) assumeRoot(ctx context.Context, accountId, taskPolicyName string) (types.Credentials, error) { - logger.Trace("aws.assumeRoot", "assuming root for account %s and task %s", accountId, taskPolicyName) + slog.Debug("assuming root", "account_id", accountId, "task", taskPolicyName) params := &sts.AssumeRootInput{ TargetPrincipal: aws.String(accountId), diff --git a/internal/logger/logger.go b/internal/logger/logger.go index f57d097..c9cca62 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -1,55 +1,45 @@ package logger import ( + "log/slog" "os" - - log "github.com/sirupsen/logrus" ) func init() { - lvl, ok := os.LookupEnv("LOG_LEVEL") - if !ok { - lvl = "error" // default - } + lvl := os.Getenv("LOG_LEVEL") + format := os.Getenv("LOG_FORMAT") + Configure(lvl, format) +} - ll, err := log.ParseLevel(lvl) - if err != nil { - ll = log.DebugLevel - } - log.SetLevel(ll) +// Configure sets up the global slog logger based on the given level and format. +// This is used by the CLI at startup; external consumers control slog via slog.SetDefault. +func Configure(level, format string) { + slogLevel := parseLevel(level) - format, ok := os.LookupEnv("LOG_FORMAT") - if ok { - SetLoggerFormat(format) - } -} + opts := &slog.HandlerOptions{Level: slogLevel, AddSource: true} -func SetLoggerFormat(logFormat string) { - switch logFormat { + var handler slog.Handler + switch format { case "json": - log.SetFormatter(&log.JSONFormatter{}) + handler = slog.NewJSONHandler(os.Stderr, opts) default: - log.SetFormatter(&log.TextFormatter{}) + handler = slog.NewTextHandler(os.Stderr, opts) } -} -// Wrap logrus with function name -func Trace(funcName, format string, args ...any) { - log.WithField("function", funcName).Tracef(format, args...) + slog.SetDefault(slog.New(handler)) } -func Debug(funcName, format string, args ...any) { - log.WithField("function", funcName).Debugf(format, args...) -} - -func Info(funcName, format string, args ...any) { - log.WithField("function", funcName).Infof(format, args...) -} - -func Warn(funcName, format string, args ...any) { - log.WithField("function", funcName).Warnf(format, args...) -} - -func Error(funcName string, err error, format string, args ...any) { - log.WithField("function", funcName).WithError(err).Errorf(format, args...) +func parseLevel(level string) slog.Level { + switch level { + case "debug": + return slog.LevelDebug + case "info": + return slog.LevelInfo + case "warn": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + return slog.LevelError + } } diff --git a/internal/service/audit.go b/internal/service/audit.go index 633a584..f6ab265 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -2,16 +2,16 @@ package service import ( "context" + "log/slog" "sync" "github.com/unicrons/aws-root-manager/internal/infra/aws" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/rootmanager" ) // auditAccounts returns root credentials for a list of AWS accounts. func auditAccounts(ctx context.Context, iam aws.IamClient, sts aws.StsClient, factory aws.IamClientFactory, accounts []string) ([]rootmanager.RootCredentials, error) { - logger.Trace("service.auditAccounts", "auditing accounts %s", accounts) + slog.Debug("auditing accounts", "accounts", accounts) rootCredentials := make([]rootmanager.RootCredentials, len(accounts)) var wgAccounts sync.WaitGroup @@ -25,7 +25,7 @@ func auditAccounts(ctx context.Context, iam aws.IamClient, sts aws.StsClient, fa go func(idx int, accountId string) { defer wgAccounts.Done() if accStatus, err := auditAccount(ctx, sts, factory, accountId); err != nil { - logger.Error("service.auditAccounts", err, "account %s: audit skipped", accountId) + slog.Error("audit skipped", "account_id", accountId, "error", err) rootCredentials[idx] = rootmanager.RootCredentials{AccountId: accountId, Error: err.Error()} } else { rootCredentials[idx] = accStatus @@ -40,7 +40,7 @@ func auditAccounts(ctx context.Context, iam aws.IamClient, sts aws.StsClient, fa // Get root credentials for a specific account func auditAccount(ctx context.Context, sts aws.StsClient, factory aws.IamClientFactory, accountId string) (rootmanager.RootCredentials, error) { - logger.Trace("service.auditAccount", "auditing account %s", accountId) + slog.Debug("auditing account", "account_id", accountId) awscfgRoot, err := sts.GetAssumeRootConfig(ctx, accountId, "IAMAuditRootUserCredentials") if err != nil { @@ -54,25 +54,25 @@ func auditAccount(ctx context.Context, sts aws.StsClient, factory aws.IamClientF if err != nil { return accountRootCredentials, err } - logger.Debug("service.auditAccounts", "account %s - login_profile: %t", accountId, loginProfile) + slog.Debug("audit result", "account_id", accountId, "login_profile", loginProfile) accessKeys, err := iamRoot.ListAccessKeys(ctx, accountId) if err != nil { return accountRootCredentials, err } - logger.Debug("service.auditAccounts", "account %s - access_keys: %s", accountId, accessKeys) + slog.Debug("audit result", "account_id", accountId, "access_keys", accessKeys) mfaDevices, err := iamRoot.ListMFADevices(ctx, accountId) if err != nil { return accountRootCredentials, err } - logger.Debug("service.auditAccounts", "account %s - mfa_devices: %s", accountId, mfaDevices) + slog.Debug("audit result", "account_id", accountId, "mfa_devices", mfaDevices) certificates, err := iamRoot.ListSigningCertificates(ctx, accountId) if err != nil { return accountRootCredentials, err } - logger.Debug("service.auditAccounts", "account %s - signing_certificates: %s", accountId, certificates) + slog.Debug("audit result", "account_id", accountId, "signing_certificates", certificates) accountRootCredentials = rootmanager.RootCredentials{ AccountId: accountId, diff --git a/internal/service/configuration.go b/internal/service/configuration.go index 3c13daf..c5fea72 100644 --- a/internal/service/configuration.go +++ b/internal/service/configuration.go @@ -3,9 +3,9 @@ package service import ( "context" "errors" + "log/slog" "github.com/unicrons/aws-root-manager/internal/infra/aws" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/rootmanager" ) @@ -53,7 +53,7 @@ func enableRootAccess(ctx context.Context, iam aws.IamClient, org aws.Organizati } if !initStatus.TrustedAccess { - logger.Debug("service.EnableRootAccess", "trusted access is disabled") + slog.Debug("trusted access is disabled") err := org.EnableAWSServiceAccess(ctx, "iam.amazonaws.com") if err != nil { return initStatus, status, err @@ -61,7 +61,7 @@ func enableRootAccess(ctx context.Context, iam aws.IamClient, org aws.Organizati } if !initStatus.RootCredentialsManagement { - logger.Debug("service.EnableRootAccess", "root credentials management is disabled") + slog.Debug("root credentials management is disabled") err = iam.EnableOrganizationsRootCredentialsManagement(ctx) if err != nil { return initStatus, status, err @@ -69,7 +69,7 @@ func enableRootAccess(ctx context.Context, iam aws.IamClient, org aws.Organizati } if !initStatus.RootSessions && enableSessions { - logger.Debug("service.EnableRootAccess", "root sessions is disabled") + slog.Debug("root sessions is disabled") err = iam.EnableOrganizationsRootSessions(ctx) if err != nil { diff --git a/internal/service/credentials.go b/internal/service/credentials.go index 1ee516f..e82cf2b 100644 --- a/internal/service/credentials.go +++ b/internal/service/credentials.go @@ -3,10 +3,10 @@ package service import ( "context" "errors" + "log/slog" "sync" "github.com/unicrons/aws-root-manager/internal/infra/aws" - "github.com/unicrons/aws-root-manager/internal/logger" "github.com/unicrons/aws-root-manager/rootmanager" ) @@ -25,7 +25,7 @@ func deleteAccountsCredentials(ctx context.Context, iam aws.IamClient, sts aws.S go func(idx int, accountCreds rootmanager.RootCredentials) { defer wgAccounts.Done() if err := deleteAccountCredentials(ctx, sts, factory, accountCreds, credentialType); err != nil { - logger.Error("service.deleteAccountsCredentials", err, "account %s: deletion failed", accountCreds.AccountId) + slog.Error("deletion failed", "account_id", accountCreds.AccountId, "error", err) results[idx] = rootmanager.DeletionResult{ AccountId: accountCreds.AccountId, CredentialType: credentialType, @@ -50,11 +50,11 @@ func deleteAccountsCredentials(ctx context.Context, iam aws.IamClient, sts aws.S // deleteAccountCredentials deletes root credentials for a specific account. func deleteAccountCredentials(ctx context.Context, sts aws.StsClient, factory aws.IamClientFactory, creds rootmanager.RootCredentials, credentialType string) error { - logger.Trace("service.deleteAccountCredentials", "checking if account %s has %s credentials to delete", credentialType, credentialType) + slog.Debug("checking credentials to delete", "account_id", creds.AccountId, "credential_type", credentialType) // Check if there are credentials to delete before assuming root if !hasCredentialsToDelete(creds, credentialType) { - logger.Info("service.deleteAccountCredentials", "no %s credentials found for account %s", credentialType, creds.AccountId) + slog.Info("no credentials found to delete", "account_id", creds.AccountId, "credential_type", credentialType) return nil } @@ -129,7 +129,7 @@ func recoverAccountsRootPassword(ctx context.Context, iam aws.IamClient, sts aws defer wgAccounts.Done() success, err := recoverAccountRootPassowrd(ctx, sts, factory, accId) if err != nil { - logger.Error("service.recoverAccountsRootPassword", err, "account %s: recovery failed", accId) + slog.Error("recovery failed", "account_id", accId, "error", err) results[idx] = rootmanager.RecoveryResult{ AccountId: accId, Success: false, @@ -152,7 +152,7 @@ func recoverAccountsRootPassword(ctx context.Context, iam aws.IamClient, sts aws // Enable the recovery process for root passwords for a specific account func recoverAccountRootPassowrd(ctx context.Context, sts aws.StsClient, factory aws.IamClientFactory, accountId string) (bool, error) { - logger.Trace("service.recoverAccountRootPassowrd", "trying to recover root password for account %s ", accountId) + slog.Debug("trying to recover root password", "account_id", accountId) awscfgRecoverRoot, err := sts.GetAssumeRootConfig(ctx, accountId, "IAMCreateRootUserPassword") if err != nil { From 79ea81f373786c5d259589d3ef2c1dcb52241b4a Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Thu, 12 Mar 2026 14:39:35 +0100 Subject: [PATCH 10/12] refactor(logger): remove Info and Error logs from internal packages --- internal/infra/aws/iam.go | 14 -------------- internal/service/audit.go | 1 - internal/service/credentials.go | 3 --- 3 files changed, 18 deletions(-) diff --git a/internal/infra/aws/iam.go b/internal/infra/aws/iam.go index 6b1f989..da153d5 100644 --- a/internal/infra/aws/iam.go +++ b/internal/infra/aws/iam.go @@ -79,8 +79,6 @@ func (c *iamClient) DeleteLoginProfile(ctx context.Context, accountId string) er return fmt.Errorf("error deleting root login profile for account %s: %w", accountId, err) } - slog.Info("successfully deleted login profile", "account_id", accountId) - return nil } @@ -114,8 +112,6 @@ func (c *iamClient) DeleteAccessKeys(ctx context.Context, accountId string, acce } } - slog.Info("successfully deleted access keys", "account_id", accountId) - return nil } @@ -147,8 +143,6 @@ func (c *iamClient) DeactivateMFADevices(ctx context.Context, accountId string, } } - slog.Info("successfully deactivated mfa devices", "account_id", accountId) - return nil } @@ -179,8 +173,6 @@ func (c *iamClient) DeleteSigningCertificates(ctx context.Context, accountId str } } - slog.Info("successfully deleted signing certificates", "account_id", accountId) - return nil } @@ -193,8 +185,6 @@ func (c *iamClient) EnableOrganizationsRootCredentialsManagement(ctx context.Con return fmt.Errorf("error enabling organization root credentials management: %w", err) } - slog.Info("successfully enabled organization root credentials management") - return nil } @@ -207,8 +197,6 @@ func (c *iamClient) EnableOrganizationsRootSessions(ctx context.Context) error { return fmt.Errorf("error enabling organization root sessions: %w", err) } - slog.Info("successfully enabled organization root sessions management") - return nil } @@ -226,7 +214,5 @@ func (c *iamClient) CreateLoginProfile(ctx context.Context) error { return fmt.Errorf("error creating login profile: %w", err) } - slog.Info("successfully created login profile") - return nil } diff --git a/internal/service/audit.go b/internal/service/audit.go index f6ab265..5d7095e 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -25,7 +25,6 @@ func auditAccounts(ctx context.Context, iam aws.IamClient, sts aws.StsClient, fa go func(idx int, accountId string) { defer wgAccounts.Done() if accStatus, err := auditAccount(ctx, sts, factory, accountId); err != nil { - slog.Error("audit skipped", "account_id", accountId, "error", err) rootCredentials[idx] = rootmanager.RootCredentials{AccountId: accountId, Error: err.Error()} } else { rootCredentials[idx] = accStatus diff --git a/internal/service/credentials.go b/internal/service/credentials.go index e82cf2b..f311f71 100644 --- a/internal/service/credentials.go +++ b/internal/service/credentials.go @@ -25,7 +25,6 @@ func deleteAccountsCredentials(ctx context.Context, iam aws.IamClient, sts aws.S go func(idx int, accountCreds rootmanager.RootCredentials) { defer wgAccounts.Done() if err := deleteAccountCredentials(ctx, sts, factory, accountCreds, credentialType); err != nil { - slog.Error("deletion failed", "account_id", accountCreds.AccountId, "error", err) results[idx] = rootmanager.DeletionResult{ AccountId: accountCreds.AccountId, CredentialType: credentialType, @@ -54,7 +53,6 @@ func deleteAccountCredentials(ctx context.Context, sts aws.StsClient, factory aw // Check if there are credentials to delete before assuming root if !hasCredentialsToDelete(creds, credentialType) { - slog.Info("no credentials found to delete", "account_id", creds.AccountId, "credential_type", credentialType) return nil } @@ -129,7 +127,6 @@ func recoverAccountsRootPassword(ctx context.Context, iam aws.IamClient, sts aws defer wgAccounts.Done() success, err := recoverAccountRootPassowrd(ctx, sts, factory, accId) if err != nil { - slog.Error("recovery failed", "account_id", accId, "error", err) results[idx] = rootmanager.RecoveryResult{ AccountId: accId, Success: false, From fd53710cf596f9220094e9a41b4c883a07eb9d7d Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Thu, 12 Mar 2026 14:55:18 +0100 Subject: [PATCH 11/12] refactor(logger): remove init() to avoid overriding external configuration --- cmd/root.go | 4 ++++ internal/logger/logger.go | 7 ------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 2112ea2..4c05189 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,6 +3,8 @@ package cmd import ( "os" + "github.com/unicrons/aws-root-manager/internal/logger" + "github.com/spf13/cobra" ) @@ -36,6 +38,8 @@ func Execute() { } func init() { + logger.Configure(os.Getenv("LOG_LEVEL"), os.Getenv("LOG_FORMAT")) + rootCmd.PersistentFlags().StringVarP(&outputFlag, "output", "o", "table", "Set the output format (table, json, csv)") rootCmd.AddCommand(Audit()) rootCmd.AddCommand(Check()) diff --git a/internal/logger/logger.go b/internal/logger/logger.go index c9cca62..7e20e6a 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -5,14 +5,7 @@ import ( "os" ) -func init() { - lvl := os.Getenv("LOG_LEVEL") - format := os.Getenv("LOG_FORMAT") - Configure(lvl, format) -} - // Configure sets up the global slog logger based on the given level and format. -// This is used by the CLI at startup; external consumers control slog via slog.SetDefault. func Configure(level, format string) { slogLevel := parseLevel(level) From f2391ea7a0654e93b6ba1e3ee9790c46c127cf2d Mon Sep 17 00:00:00 2001 From: sbldevnet Date: Fri, 13 Mar 2026 10:26:27 +0100 Subject: [PATCH 12/12] refactor(aws): add DescribeOrganization and ListAccounts to OrganizationsClient interface --- internal/cli/ui/account_selector.go | 9 +++- internal/infra/aws/interfaces.go | 6 +++ internal/infra/aws/organizations.go | 68 ++++++++++++++--------------- 3 files changed, 46 insertions(+), 37 deletions(-) diff --git a/internal/cli/ui/account_selector.go b/internal/cli/ui/account_selector.go index 0faaa83..79c9599 100644 --- a/internal/cli/ui/account_selector.go +++ b/internal/cli/ui/account_selector.go @@ -23,8 +23,15 @@ func SelectTargetAccounts(ctx context.Context, accountsFlag []string) ([]string, return accountsFlag, nil } + // create organizations client to fetch accounts + awscfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + return nil, fmt.Errorf("failed to load aws config: %w", err) + } + org := aws.NewOrganizationsClient(awscfg) + // fetch all non-management accounts - orgAccounts, err := aws.GetNonManagementOrganizationAccounts(ctx) + orgAccounts, err := aws.GetNonManagementOrganizationAccounts(ctx, org) if err != nil { return nil, fmt.Errorf("error fetching organization accounts: %w", err) } diff --git a/internal/infra/aws/interfaces.go b/internal/infra/aws/interfaces.go index 32f4fbd..53930e5 100644 --- a/internal/infra/aws/interfaces.go +++ b/internal/infra/aws/interfaces.go @@ -56,6 +56,12 @@ type StsClient interface { // OrganizationsClient defines the interface for AWS Organizations operations. // This interface enables mocking and dependency injection for testing. type OrganizationsClient interface { + // DescribeOrganization returns the management account ID of the organization + DescribeOrganization(ctx context.Context) (string, error) + + // ListAccounts returns all accounts in the organization + ListAccounts(ctx context.Context) ([]OrganizationAccount, error) + // EnableAWSServiceAccess enables AWS service access for the organization EnableAWSServiceAccess(ctx context.Context, service string) error } diff --git a/internal/infra/aws/organizations.go b/internal/infra/aws/organizations.go index 80dc414..98cc45c 100644 --- a/internal/infra/aws/organizations.go +++ b/internal/infra/aws/organizations.go @@ -24,69 +24,65 @@ type OrganizationAccount struct { AccountID string } -// Fetches AWS Organization accounts, excluding the management account -func GetNonManagementOrganizationAccounts(ctx context.Context) ([]OrganizationAccount, error) { +// GetNonManagementOrganizationAccounts fetches active organization accounts, excluding the management account. +func GetNonManagementOrganizationAccounts(ctx context.Context, org OrganizationsClient) ([]OrganizationAccount, error) { slog.Debug("getting organization accounts") - awscfg, err := LoadAWSConfig(ctx) - if err != nil { - return nil, fmt.Errorf("failed to load aws config: %w", err) - } - - orgs := NewOrganizationsClient(awscfg) - - mgmAccount, err := orgs.(*organizationsClient).describeOrganization(ctx) + mgmAccountId, err := org.DescribeOrganization(ctx) if err != nil { return nil, err } - orgAccounts, err := orgs.(*organizationsClient).listOrganizationAccounts() + allAccounts, err := org.ListAccounts(ctx) if err != nil { return nil, err } - var nonManagementOrgAccounts []OrganizationAccount - for _, acc := range orgAccounts { - if string(acc.State) == "ACTIVE" && *acc.Id != mgmAccount { - account := OrganizationAccount{ - Name: *acc.Name, - AccountID: *acc.Id, - } - nonManagementOrgAccounts = append(nonManagementOrgAccounts, account) + var nonManagementAccounts []OrganizationAccount + for _, acc := range allAccounts { + if acc.AccountID != mgmAccountId { + nonManagementAccounts = append(nonManagementAccounts, acc) } } - return nonManagementOrgAccounts, nil + return nonManagementAccounts, nil } -func (c *organizationsClient) listOrganizationAccounts() ([]types.Account, error) { +func (c *organizationsClient) DescribeOrganization(ctx context.Context) (string, error) { + slog.Debug("describing organization") + + organization, err := c.client.DescribeOrganization(ctx, &organizations.DescribeOrganizationInput{}) + if err != nil { + return "", fmt.Errorf("failed to describe organization: %w", err) + } + + return *organization.Organization.MasterAccountId, nil +} + +func (c *organizationsClient) ListAccounts(ctx context.Context) ([]OrganizationAccount, error) { slog.Debug("listing organization accounts") params := &organizations.ListAccountsInput{} paginator := organizations.NewListAccountsPaginator(c.client, params) - var allAccounts []types.Account + var accounts []OrganizationAccount for paginator.HasMorePages() { - page, err := paginator.NextPage(context.Background()) + page, err := paginator.NextPage(ctx) if err != nil { return nil, fmt.Errorf("failed to list organization accounts: %v", err) } - allAccounts = append(allAccounts, page.Accounts...) - } - - return allAccounts, nil -} - -func (c *organizationsClient) describeOrganization(ctx context.Context) (string, error) { - slog.Debug("describing organization") - - organization, err := c.client.DescribeOrganization(ctx, &organizations.DescribeOrganizationInput{}) - if err != nil { - return "", fmt.Errorf("failed to describe organization: %w", err) + for _, acc := range page.Accounts { + if acc.Status == types.AccountStatusActive { + accounts = append(accounts, OrganizationAccount{ + Name: aws.ToString(acc.Name), + AccountID: aws.ToString(acc.Id), + }) + } + } } - return *organization.Organization.MasterAccountId, nil + return accounts, nil } func (c *organizationsClient) EnableAWSServiceAccess(ctx context.Context, service string) error {