From 627e1629135af8d677b1ca4a0f85499a9cf2f8bd Mon Sep 17 00:00:00 2001 From: Mweya Date: Wed, 15 Oct 2025 13:27:44 +0200 Subject: [PATCH] test: SAML --- saml/parser.go | 21 ++ saml/saml.go | 51 +++-- saml/saml_test.go | 559 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 616 insertions(+), 15 deletions(-) create mode 100644 saml/parser.go create mode 100644 saml/saml_test.go diff --git a/saml/parser.go b/saml/parser.go new file mode 100644 index 0000000..fb3f36b --- /dev/null +++ b/saml/parser.go @@ -0,0 +1,21 @@ +package saml + +import ( + "net/http" + + "otc-auth/common" +) + +type CredentialParser interface { + Parse(resp *http.Response) (*common.TokenResponse, error) +} + +type defaultCredentialParser struct{} + +func NewDefaultCredentialParser() CredentialParser { + return &defaultCredentialParser{} +} + +func (p *defaultCredentialParser) Parse(resp *http.Response) (*common.TokenResponse, error) { + return common.GetCloudCredentialsFromResponse(resp) +} diff --git a/saml/saml.go b/saml/saml.go index 229118e..3caccdc 100644 --- a/saml/saml.go +++ b/saml/saml.go @@ -15,15 +15,35 @@ import ( "github.com/go-http-utils/headers" ) -func AuthenticateAndGetUnscopedToken(ctx context.Context, authInfo common.AuthInfo) (*common.TokenResponse, error) { - httpClient := common.NewHTTPClient(authInfo.SkipTLS) - spInitiatedRequest, err := getServiceProviderInitiatedRequest(ctx, authInfo, httpClient) +type Authenticator struct { + client common.HTTPClient + parser CredentialParser +} + +func newAuthenticator(client common.HTTPClient, parser CredentialParser) *Authenticator { + return &Authenticator{ + client: client, + parser: parser, + } +} + +func AuthenticateAndGetUnscopedToken(ctx context.Context, + authInfo common.AuthInfo, +) (*common.TokenResponse, error) { + client := common.NewHTTPClient(authInfo.SkipTLS) + parser := NewDefaultCredentialParser() + service := newAuthenticator(client, parser) + return service.Authenticate(ctx, authInfo) +} + +func (a *Authenticator) Authenticate(ctx context.Context, authInfo common.AuthInfo) (*common.TokenResponse, error) { + spInitiatedRequest, err := a.getServiceProviderInitiatedRequest(ctx, authInfo) if err != nil { return nil, fmt.Errorf("error getting sp request\ntrace: %w", err) } defer spInitiatedRequest.Body.Close() - bodyBytes, err := authenticateWithIdp(ctx, authInfo, spInitiatedRequest, httpClient) + bodyBytes, err := a.authenticateWithIdp(ctx, authInfo, spInitiatedRequest) if err != nil { return nil, fmt.Errorf("couldn't auth with idp: %w", err) } @@ -35,13 +55,13 @@ func AuthenticateAndGetUnscopedToken(ctx context.Context, authInfo common.AuthIn return nil, fmt.Errorf("fatal: error deserializing xml.\ntrace: %w", err) } - response, err := validateAuthenticationWithServiceProvider(ctx, assertionResult, bodyBytes, httpClient) + response, err := a.validateAuthenticationWithServiceProvider(ctx, assertionResult, bodyBytes) if err != nil { return nil, fmt.Errorf("couldn't validate auth with service provider: %w", err) } defer response.Body.Close() - tokenResponse, err := common.GetCloudCredentialsFromResponse(response) + tokenResponse, err := a.parser.Parse(response) if err != nil { return nil, fmt.Errorf("couldn't get cloud creds from response: %w", err) } @@ -49,8 +69,8 @@ func AuthenticateAndGetUnscopedToken(ctx context.Context, authInfo common.AuthIn return tokenResponse, nil } -func getServiceProviderInitiatedRequest(ctx context.Context, - params common.AuthInfo, client common.HTTPClient, +func (a *Authenticator) getServiceProviderInitiatedRequest(ctx context.Context, + params common.AuthInfo, ) (*http.Response, error) { request, err := common.NewRequest(ctx, http.MethodGet, endpoints.IdentityProviders(params.IdpName, string(params.AuthProtocol), params.Region), nil) @@ -60,11 +80,11 @@ func getServiceProviderInitiatedRequest(ctx context.Context, request.Header.Add(headers.Accept, headervalues.ApplicationPaos) request.Header.Add(header.Paos, headervalues.Paos) - return client.MakeRequest(request) + return a.client.MakeRequest(request) } -func authenticateWithIdp(ctx context.Context, params common.AuthInfo, - samlResponse *http.Response, client common.HTTPClient, +func (a *Authenticator) authenticateWithIdp(ctx context.Context, params common.AuthInfo, + samlResponse *http.Response, ) ([]byte, error) { request, err := common.NewRequest(ctx, http.MethodPost, params.IdpURL, samlResponse.Body) if err != nil { @@ -75,7 +95,7 @@ func authenticateWithIdp(ctx context.Context, params common.AuthInfo, request.Header.Add(headers.ContentType, headervalues.TextXML) request.SetBasicAuth(params.Username, params.Password) - response, err := client.MakeRequest(request) + response, err := a.client.MakeRequest(request) if err != nil { return nil, err } @@ -83,8 +103,9 @@ func authenticateWithIdp(ctx context.Context, params common.AuthInfo, return common.GetBodyBytesFromResponse(response) } -func validateAuthenticationWithServiceProvider(ctx context.Context, assertionResult common.SamlAssertionResponse, - responseBodyBytes []byte, client common.HTTPClient, +func (a *Authenticator) validateAuthenticationWithServiceProvider(ctx context.Context, + assertionResult common.SamlAssertionResponse, + responseBodyBytes []byte, ) (*http.Response, error) { request, err := common.NewRequest(ctx, http.MethodPost, assertionResult.Header.Response.AssertionConsumerServiceURL, bytes.NewReader(responseBodyBytes)) @@ -93,5 +114,5 @@ func validateAuthenticationWithServiceProvider(ctx context.Context, assertionRes } request.Header.Add(headers.ContentType, headervalues.ApplicationPaos) - return client.MakeRequest(request) + return a.client.MakeRequest(request) } diff --git a/saml/saml_test.go b/saml/saml_test.go new file mode 100644 index 0000000..383528f --- /dev/null +++ b/saml/saml_test.go @@ -0,0 +1,559 @@ +//nolint:testpackage // whitebox testing +package saml + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/xml" + "fmt" + "io" + "net/http" + "reflect" + "strings" + "testing" + + "otc-auth/common" + "otc-auth/common/endpoints" + "otc-auth/common/headervalues" + header "otc-auth/common/xheaders" + + "github.com/go-http-utils/headers" +) + +type mockHTTPClient struct { + T *testing.T + + ResponseToReturn *http.Response + ErrorToReturn error + + ExpectedURL string + ExpectedMethod string + ExpectedHeader http.Header + ExpectedBody []byte +} + +type mockSequencedHTTPClient struct { + T *testing.T + Responses []*http.Response // A queue of responses to return for each call. + Errors []error // A parallel queue of errors to return. + callIndex int // Tracks which call we are on. +} + +type mockCredentialParser struct { + TokenToReturn *common.TokenResponse + ErrorToReturn error +} + +func (m *mockCredentialParser) Parse(resp *http.Response) (*common.TokenResponse, error) { + return m.TokenToReturn, m.ErrorToReturn +} + +func (m *mockSequencedHTTPClient) MakeRequest(req *http.Request) (*http.Response, error) { + if m.callIndex >= len(m.Responses) || m.callIndex >= len(m.Errors) { + m.T.Fatalf("MakeRequest called more times than expected. Got call #%d", m.callIndex+1) + return nil, fmt.Errorf("unexpected call") + } + + response := m.Responses[m.callIndex] + err := m.Errors[m.callIndex] + m.callIndex++ + return response, err +} + +func (m *mockHTTPClient) MakeRequest(req *http.Request) (*http.Response, error) { + if m.ExpectedMethod != "" && req.Method != m.ExpectedMethod { + m.T.Errorf("MakeRequest() received method %q, want %q", req.Method, m.ExpectedMethod) + } + if m.ExpectedURL != "" && req.URL.String() != m.ExpectedURL { + m.T.Errorf("MakeRequest() received URL %q, want %q", req.URL.String(), m.ExpectedURL) + } + if m.ExpectedHeader != nil { + headerKey := headers.ContentType + if got := req.Header.Get(headerKey); got != m.ExpectedHeader.Get(headerKey) { + m.T.Errorf("MakeRequest() header %q = %q, want %q", headerKey, got, m.ExpectedHeader.Get(headerKey)) + } + } + if m.ExpectedBody != nil { + bodyBytes, _ := io.ReadAll(req.Body) + if !bytes.Equal(bodyBytes, m.ExpectedBody) { + m.T.Errorf("MakeRequest() body = %q, want %q", string(bodyBytes), string(m.ExpectedBody)) + } + } + + return m.ResponseToReturn, m.ErrorToReturn +} + +func TestAuthenticator_validateAuthenticationWithServiceProvider(t *testing.T) { + type fields struct { + client common.HTTPClient + } + type args struct { + ctx context.Context + assertionResult common.SamlAssertionResponse + responseBodyBytes []byte + } + ctx := context.Background() + expectedURL := "https://example.com/assertion-consumer" + requestBodyBytes := []byte("assertion") + + validAssertionResult := common.SamlAssertionResponse{ + Name: xml.Name{}, + Header: struct { + Response struct { + AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` + } `xml:"Response"` + }{ + Response: struct { + AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` + }{ + AssertionConsumerServiceURL: expectedURL, + }, + }, + } + + successResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("success")), + } + + tests := []struct { + name string + fields fields + args args + want *http.Response + wantErr bool + }{ + { + name: "Success - Happy Path", + fields: fields{ + client: &mockHTTPClient{ + T: t, + ResponseToReturn: successResponse, + ErrorToReturn: nil, + ExpectedURL: expectedURL, + ExpectedMethod: http.MethodPost, + ExpectedHeader: http.Header{headers.ContentType: []string{headervalues.ApplicationPaos}}, + ExpectedBody: requestBodyBytes, + }, + }, + args: args{ + ctx: ctx, + assertionResult: validAssertionResult, + responseBodyBytes: requestBodyBytes, + }, + want: successResponse, + wantErr: false, + }, + { + name: "Failure - Client returns an error", + fields: fields{ + client: &mockHTTPClient{ + T: t, + ErrorToReturn: fmt.Errorf("simulated network error"), + }, + }, + args: args{ + ctx: ctx, + assertionResult: validAssertionResult, + responseBodyBytes: requestBodyBytes, + }, + want: nil, + wantErr: true, + }, + { + name: "Failure - NewRequest fails due to bad URL", + fields: fields{ + client: &mockHTTPClient{T: t}, + }, + args: args{ + ctx: ctx, + assertionResult: common.SamlAssertionResponse{ + Header: struct { + Response struct { + AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` + } `xml:"Response"` + }{ + Response: struct { + AssertionConsumerServiceURL string `xml:"AssertionConsumerServiceURL,attr"` + }{ + AssertionConsumerServiceURL: "::not a valid URL", + }, + }, + }, + responseBodyBytes: requestBodyBytes, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticator{ + client: tt.fields.client, + } + got, err := a.validateAuthenticationWithServiceProvider(tt.args.ctx, + tt.args.assertionResult, tt.args.responseBodyBytes) + if (err != nil) != tt.wantErr { + t.Errorf("validateAuthenticationWithServiceProvider() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if got != tt.want { + t.Errorf("validateAuthenticationWithServiceProvider() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthenticator_authenticateWithIdp(t *testing.T) { + // --- Common test data --- + ctx := context.Background() + authParams := common.AuthInfo{ + IdpURL: "https://idp.example.com/login", + Username: "testuser", + Password: "testpassword", + } + + // Create the expected Basic Auth header value. + auth := authParams.Username + ":" + authParams.Password + expectedAuthHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) + + // This is the body of the *incoming* response, which will be the body of the *outgoing* request. + samlRequestBody := "request" + samlResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(samlRequestBody)), + } + + // This is the body of the response we expect to get back from the IdP. + expectedResponseBody := []byte("response") + idpSuccessResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(expectedResponseBody)), + } + + type fields struct { + client common.HTTPClient + } + type args struct { + ctx context.Context + params common.AuthInfo + samlResponse *http.Response + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + { + name: "Success - Happy Path", + fields: fields{ + client: &mockHTTPClient{ + T: t, + ResponseToReturn: idpSuccessResponse, + ExpectedURL: authParams.IdpURL, + ExpectedMethod: http.MethodPost, + ExpectedHeader: http.Header{ + headers.ContentType: []string{headervalues.TextXML}, + headers.Authorization: []string{expectedAuthHeader}, + }, + }, + }, + args: args{ + ctx: ctx, + params: authParams, + samlResponse: samlResponse, + }, + want: expectedResponseBody, + wantErr: false, + }, + { + name: "Failure - Client returns an error", + fields: fields{ + client: &mockHTTPClient{ + T: t, + ErrorToReturn: fmt.Errorf("simulated network error"), + }, + }, + args: args{ + ctx: ctx, + params: authParams, + samlResponse: samlResponse, + }, + want: nil, + wantErr: true, + }, + { + name: "Failure - NewRequest fails due to bad URL", + fields: fields{ + client: &mockHTTPClient{T: t}, // Client won't be called. + }, + args: args{ + ctx: ctx, + params: common.AuthInfo{ + IdpURL: "::not a valid url", // Invalid URL to make NewRequest fail. + }, + samlResponse: samlResponse, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticator{ + client: tt.fields.client, + } + got, err := a.authenticateWithIdp(tt.args.ctx, tt.args.params, tt.args.samlResponse) + if (err != nil) != tt.wantErr { + t.Errorf("authenticateWithIdp() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("authenticateWithIdp() got = %s, want %s", string(got), string(tt.want)) + } + }) + } +} + +func TestAuthenticator_getServiceProviderInitiatedRequest(t *testing.T) { + ctx := context.Background() + authParams := common.AuthInfo{ + IdpName: "my-idp", + AuthProtocol: "saml", + Region: "eu-de", + } + + expectedURL := endpoints.IdentityProviders(authParams.IdpName, string(authParams.AuthProtocol), authParams.Region) + + successResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("success")), + } + + type fields struct { + client common.HTTPClient + } + type args struct { + ctx context.Context + params common.AuthInfo + } + tests := []struct { + name string + fields fields + args args + want *http.Response + wantErr bool + }{ + { + name: "Success - Happy Path", + fields: fields{ + client: &mockHTTPClient{ + T: t, + ResponseToReturn: successResponse, + ExpectedURL: expectedURL, + ExpectedMethod: http.MethodGet, + ExpectedHeader: http.Header{ + headers.Accept: []string{headervalues.ApplicationPaos}, + header.Paos: []string{headervalues.Paos}, + }, + }, + }, + args: args{ + ctx: ctx, + params: authParams, + }, + want: successResponse, + wantErr: false, + }, + { + name: "Failure - Client returns an error", + fields: fields{ + client: &mockHTTPClient{ + T: t, + ErrorToReturn: fmt.Errorf("simulated network error"), + }, + }, + args: args{ + ctx: ctx, + params: authParams, + }, + want: nil, + wantErr: true, + }, + { + name: "Failure - NewRequest fails (e.g., bad region)", + fields: fields{ + client: &mockHTTPClient{T: t}, + }, + args: args{ + ctx: ctx, + params: common.AuthInfo{ + Region: " ", + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticator{ + client: tt.fields.client, + } + got, err := a.getServiceProviderInitiatedRequest(tt.args.ctx, tt.args.params) + if (err != nil) != tt.wantErr { + t.Errorf("getServiceProviderInitiatedRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getServiceProviderInitiatedRequest() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthenticator_Authenticate(t *testing.T) { + ctx := context.Background() + authInfo := common.AuthInfo{ + Region: "eu-de", + } + + spSuccessResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("sp-response-body")), + } + + samlXMLBody := ` +
+ +
+ ` + idpSuccessResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(samlXMLBody)), + } + + finalTokenBody := `{"token": {"id": "final-token"}}` + finalSuccessResponse := &http.Response{ + StatusCode: http.StatusCreated, + Header: http.Header{"X-Subject-Token": []string{"final-token-id"}}, + Body: io.NopCloser(strings.NewReader(finalTokenBody)), + } + + expectedTokenResponse := &common.TokenResponse{} + + type fields struct { + client common.HTTPClient + parser CredentialParser + } + type args struct { + ctx context.Context + authInfo common.AuthInfo + } + tests := []struct { + name string + fields fields + args args + setup func() // Optional setup, e.g., for mocking package-level funcs + want *common.TokenResponse + wantErr bool + }{ + { + name: "Success - Happy Path", + fields: fields{ + client: &mockSequencedHTTPClient{ + T: t, + Responses: []*http.Response{spSuccessResponse, idpSuccessResponse, finalSuccessResponse}, + Errors: []error{nil, nil, nil}, + }, + parser: &mockCredentialParser{TokenToReturn: expectedTokenResponse}, + }, + args: args{ctx: ctx, authInfo: authInfo}, + want: expectedTokenResponse, + wantErr: false, + }, + { + name: "Failure - Final credential parsing fails", + fields: fields{ + client: &mockSequencedHTTPClient{ + T: t, + Responses: []*http.Response{spSuccessResponse, idpSuccessResponse, finalSuccessResponse}, + Errors: []error{nil, nil, nil}, + }, + // Configure the mock parser to return an error. + parser: &mockCredentialParser{ErrorToReturn: fmt.Errorf("could not parse final token")}, + }, + args: args{ctx: ctx, authInfo: authInfo}, + want: nil, + wantErr: true, + }, + { + name: "Failure - First step (getServiceProviderInitiatedRequest) fails", + fields: fields{ + client: &mockSequencedHTTPClient{ + T: t, + Responses: []*http.Response{nil}, + Errors: []error{fmt.Errorf("network error on step 1")}, + }, + }, + args: args{ctx: ctx, authInfo: authInfo}, + setup: func() {}, + want: nil, + wantErr: true, + }, + { + name: "Failure - Second step (authenticateWithIdp) fails", + fields: fields{ + client: &mockSequencedHTTPClient{ + T: t, + Responses: []*http.Response{spSuccessResponse, nil}, + Errors: []error{nil, fmt.Errorf("network error on step 2")}, + }, + }, + args: args{ctx: ctx, authInfo: authInfo}, + setup: func() {}, + want: nil, + wantErr: true, + }, + { + name: "Failure - XML Unmarshal fails", + fields: fields{ + client: &mockSequencedHTTPClient{ + T: t, + Responses: []*http.Response{ + spSuccessResponse, + {StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("not valid xml"))}, + }, + Errors: []error{nil, nil}, + }, + }, + args: args{ctx: ctx, authInfo: authInfo}, + setup: func() {}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + a := &Authenticator{ + client: tt.fields.client, + parser: tt.fields.parser, + } + got, err := a.Authenticate(tt.args.ctx, tt.args.authInfo) + if (err != nil) != tt.wantErr { + t.Errorf("AuthenticateAndGetUnscopedToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AuthenticateAndGetUnscopedToken() got = %v, want %v", got, tt.want) + } + }) + } +}