Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials/processcreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)

// FromConfig retrieves credentials from the AWS cli config files, typically
Expand Down Expand Up @@ -105,20 +107,25 @@ func FromReader(reader io.Reader) (*aws.Credentials, error) {
// FederateUser will federate the given user credentials by calling STS
// GetFederationToken. If the given credentials are not for a user (like
// credentials for a role) then they are returned unmodified.
func FederateUser(creds *aws.Credentials, region, name, policy string, duration time.Duration, _ string) (*aws.Credentials, error) {
func FederateUser(creds *aws.Credentials, region, name, policy string, duration time.Duration, userAgent string) (*aws.Credentials, error) {
// Only federate if user credentials were given.
if creds.SessionToken != "" {
return creds, nil
}

client := sts.NewFromConfig(aws.Config{
Credentials: credentials.NewStaticCredentialsProvider(
creds.AccessKeyID,
creds.SecretAccessKey,
creds.SessionToken,
),
Region: region,
})
client := sts.NewFromConfig(
aws.Config{
Credentials: credentials.NewStaticCredentialsProvider(
creds.AccessKeyID,
creds.SecretAccessKey,
creds.SessionToken,
),
Region: region,
},
func(options *sts.Options) {
options.APIOptions = append(options.APIOptions, setUserAgent(userAgent))
},
)

input := sts.GetFederationTokenInput{
Name: aws.String(name),
Expand Down Expand Up @@ -150,3 +157,26 @@ func FederateUser(creds *aws.Credentials, region, name, policy string, duration
SessionToken: aws.ToString(result.Credentials.SessionToken),
}, nil
}

func setUserAgent(useragent string) func(stack *middleware.Stack) error {
return func(stack *middleware.Stack) error {
bm := userAgentMiddleware(useragent)
stack.Build.Remove(bm.ID()) //nolint

return stack.Build.Add(&bm, middleware.After)
}
}

type userAgentMiddleware string

func (userAgentMiddleware) ID() string {
return "UserAgent"
}

func (u userAgentMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (middleware.BuildOutput, middleware.Metadata, error) {
if req, ok := in.Request.(*smithyhttp.Request); ok {
req.Header.Set("User-Agent", string(u))
}

return next.HandleBuild(ctx, in)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.31.16
github.com/aws/aws-sdk-go-v2/credentials v1.18.20
github.com/aws/aws-sdk-go-v2/service/sts v1.39.0
github.com/aws/smithy-go v1.23.1
github.com/joshdk/buildversion v0.1.0
github.com/mattn/go-isatty v0.0.20
github.com/mattn/go-sixel v0.0.5
Expand All @@ -25,7 +26,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.12 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.0 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.4 // indirect
github.com/aws/smithy-go v1.23.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/soniakeys/quant v1.0.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
Expand Down