Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions saml/parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package saml

import (
"net/http"

"otc-auth/common"
)

type CredentialParser interface {
Parse(resp *http.Response) (*common.TokenResponse, error)
}

type defaultCredentialParser struct{}

func NewDefaultCredentialParser() CredentialParser {
return &defaultCredentialParser{}
}

func (p *defaultCredentialParser) Parse(resp *http.Response) (*common.TokenResponse, error) {
return common.GetCloudCredentialsFromResponse(resp)
}
51 changes: 36 additions & 15 deletions saml/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,35 @@ import (
"github.com/go-http-utils/headers"
)

func AuthenticateAndGetUnscopedToken(ctx context.Context, authInfo common.AuthInfo) (*common.TokenResponse, error) {
httpClient := common.NewHTTPClient(authInfo.SkipTLS)
spInitiatedRequest, err := getServiceProviderInitiatedRequest(ctx, authInfo, httpClient)
type Authenticator struct {
client common.HTTPClient
parser CredentialParser
}

func newAuthenticator(client common.HTTPClient, parser CredentialParser) *Authenticator {
return &Authenticator{
client: client,
parser: parser,
}
}

func AuthenticateAndGetUnscopedToken(ctx context.Context,
authInfo common.AuthInfo,
) (*common.TokenResponse, error) {
client := common.NewHTTPClient(authInfo.SkipTLS)
parser := NewDefaultCredentialParser()
service := newAuthenticator(client, parser)
return service.Authenticate(ctx, authInfo)
}

func (a *Authenticator) Authenticate(ctx context.Context, authInfo common.AuthInfo) (*common.TokenResponse, error) {
spInitiatedRequest, err := a.getServiceProviderInitiatedRequest(ctx, authInfo)
if err != nil {
return nil, fmt.Errorf("error getting sp request\ntrace: %w", err)
}
defer spInitiatedRequest.Body.Close()

bodyBytes, err := authenticateWithIdp(ctx, authInfo, spInitiatedRequest, httpClient)
bodyBytes, err := a.authenticateWithIdp(ctx, authInfo, spInitiatedRequest)
if err != nil {
return nil, fmt.Errorf("couldn't auth with idp: %w", err)
}
Expand All @@ -35,22 +55,22 @@ func AuthenticateAndGetUnscopedToken(ctx context.Context, authInfo common.AuthIn
return nil, fmt.Errorf("fatal: error deserializing xml.\ntrace: %w", err)
}

response, err := validateAuthenticationWithServiceProvider(ctx, assertionResult, bodyBytes, httpClient)
response, err := a.validateAuthenticationWithServiceProvider(ctx, assertionResult, bodyBytes)
if err != nil {
return nil, fmt.Errorf("couldn't validate auth with service provider: %w", err)
}
defer response.Body.Close()

tokenResponse, err := common.GetCloudCredentialsFromResponse(response)
tokenResponse, err := a.parser.Parse(response)
if err != nil {
return nil, fmt.Errorf("couldn't get cloud creds from response: %w", err)
}

return tokenResponse, nil
}

func getServiceProviderInitiatedRequest(ctx context.Context,
params common.AuthInfo, client common.HTTPClient,
func (a *Authenticator) getServiceProviderInitiatedRequest(ctx context.Context,
params common.AuthInfo,
) (*http.Response, error) {
request, err := common.NewRequest(ctx, http.MethodGet,
endpoints.IdentityProviders(params.IdpName, string(params.AuthProtocol), params.Region), nil)
Expand All @@ -60,11 +80,11 @@ func getServiceProviderInitiatedRequest(ctx context.Context,
request.Header.Add(headers.Accept, headervalues.ApplicationPaos)
request.Header.Add(header.Paos, headervalues.Paos)

return client.MakeRequest(request)
return a.client.MakeRequest(request)
}

func authenticateWithIdp(ctx context.Context, params common.AuthInfo,
samlResponse *http.Response, client common.HTTPClient,
func (a *Authenticator) authenticateWithIdp(ctx context.Context, params common.AuthInfo,
samlResponse *http.Response,
) ([]byte, error) {
request, err := common.NewRequest(ctx, http.MethodPost, params.IdpURL, samlResponse.Body)
if err != nil {
Expand All @@ -75,16 +95,17 @@ func authenticateWithIdp(ctx context.Context, params common.AuthInfo,
request.Header.Add(headers.ContentType, headervalues.TextXML)
request.SetBasicAuth(params.Username, params.Password)

response, err := client.MakeRequest(request)
response, err := a.client.MakeRequest(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
return common.GetBodyBytesFromResponse(response)
}

func validateAuthenticationWithServiceProvider(ctx context.Context, assertionResult common.SamlAssertionResponse,
responseBodyBytes []byte, client common.HTTPClient,
func (a *Authenticator) validateAuthenticationWithServiceProvider(ctx context.Context,
assertionResult common.SamlAssertionResponse,
responseBodyBytes []byte,
) (*http.Response, error) {
request, err := common.NewRequest(ctx, http.MethodPost, assertionResult.Header.Response.AssertionConsumerServiceURL,
bytes.NewReader(responseBodyBytes))
Expand All @@ -93,5 +114,5 @@ func validateAuthenticationWithServiceProvider(ctx context.Context, assertionRes
}
request.Header.Add(headers.ContentType, headervalues.ApplicationPaos)

return client.MakeRequest(request)
return a.client.MakeRequest(request)
}
Loading
Loading