diff --git a/tenant/tenantmiddleware.go b/tenant/tenantmiddleware.go index 6ff65b8..f72e1f2 100644 --- a/tenant/tenantmiddleware.go +++ b/tenant/tenantmiddleware.go @@ -37,7 +37,7 @@ import ( "crypto/sha256" "encoding/base64" "errors" - "log" + "fmt" "net/http" "strings" ) @@ -61,7 +61,7 @@ const ( // Adds systemBaseUri and tenantId to request context. // If the headers are not present the given defaultSystemBaseUri and tenant "0" are used. // The signatureSecretKey is specific for each App and is provided by the registration process for d.velop cloud. -func AddToCtx(defaultSystemBaseUri string, signatureSecretKey []byte) func(http.Handler) http.Handler { +func AddToCtx(defaultSystemBaseUri string, signatureSecretKey []byte, logger func(ctx context.Context, message string)) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() @@ -71,19 +71,19 @@ func AddToCtx(defaultSystemBaseUri string, signatureSecretKey []byte) func(http. if systemBaseUri != "" || tenantId != "" { if signatureSecretKey == nil { - log.Printf("error validating signature for headers '%v' and '%v' because secret signature key has not been configured", systemBaseUriHeader, tenantIdHeader) + logger(req.Context(), fmt.Sprintf("validating signature for headers '%v' and '%v' because secret signature key has not been configured", systemBaseUriHeader, tenantIdHeader)) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } base64Signature := req.Header.Get("x-dv-sig-1") signature, err := base64.StdEncoding.DecodeString(base64Signature) if err != nil { - log.Printf("error decoding signature '%v' as base 64 data because: %v", base64Signature, err) + logger(req.Context(), fmt.Sprintf("decoding signature '%v' as base 64 data because: %v", base64Signature, err)) http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } if !signatureIsValid([]byte(systemBaseUri+tenantId), []byte(signature), signatureSecretKey) { - log.Printf("error signature '%v' is not valid for SystemBaseUri '%v' and TenantId '%v'", signature, systemBaseUri, tenantId) + logger(req.Context(), fmt.Sprintf("signature '%v' is not valid for SystemBaseUri '%v' and TenantId '%v'", signature, systemBaseUri, tenantId)) http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } diff --git a/tenant/tenantmiddleware_test.go b/tenant/tenantmiddleware_test.go index b9f467f..90f0474 100644 --- a/tenant/tenantmiddleware_test.go +++ b/tenant/tenantmiddleware_test.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "github.com/d-velop/dvelop-sdk-go/tenant" @@ -34,7 +35,8 @@ func TestBaseUriHeaderAndEmptyDefaultBaseUri_UsesHeader(t *testing.T) { handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + logSpy := loggerSpy{} + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -51,8 +53,9 @@ func TestNoBaseUriHeaderAndDefaultBaseUri_UsesDefaultBaseUri(t *testing.T) { } handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -72,8 +75,9 @@ func TestBaseUriHeaderAndDefaultBaseUri_UsesHeader(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -89,8 +93,9 @@ func TestNoBaseUriHeaderAndEmptyDefaultBaseUri_DoesntAddBaseUriToContext(t *test t.Fatal(err) } handlerSpy := handlerSpy{} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(httptest.NewRecorder(), req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(httptest.NewRecorder(), req) if err := handlerSpy.assertErrorReadingSystemBaseUri(); err != nil { t.Error(err) @@ -107,8 +112,9 @@ func TestTenantIdHeader_UsesHeader(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(tenantIdFromHeader, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -125,8 +131,9 @@ func TestNoTenantIdHeader_UsesTenantIdZero(t *testing.T) { } handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -147,8 +154,9 @@ func TestInitiatorSystemBaseUriHeader_UsesForwardedHeader(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(forwardedHeaderValue, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -169,8 +177,9 @@ func TestInitiatorSystemBaseUriHeader_UsesForwardedHeaderMultipleHosts(t *testin req.Header.Set(signatureHeader, base64Signature(forwardedHeaderValue, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -190,8 +199,9 @@ func TestInitiatorSystemBaseUriHeader_UsesXForwardedHeader(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(xForwardedHostValue, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -212,8 +222,9 @@ func TestInitiatorSystemBaseUriHeader_UsesXForwardedHeaderMultipleHosts(t *testi req.Header.Set(signatureHeader, base64Signature(xForwardedHostMultiValue, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -232,8 +243,9 @@ func TestInitiatorSystemBaseUriHeader_EmptyForwardedHeadersNoSystemBaseUri(t *te req.Header.Set(signatureHeader, base64Signature("", signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -253,8 +265,9 @@ func TestInitiatorSystemBaseUriHeader_EmptyForwardedHeadersWithSystemBaseUri(t * req.Header.Set(signatureHeader, base64Signature(systemBaseUri, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -273,8 +286,9 @@ func TestInitiatorSystemBaseUriHeader_EmptyForwardedHeadersWithDefaultSystemBase handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -296,8 +310,9 @@ func TestTenantIdHeaderAndBaseUriHeader_UsesHeaders(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader+tenantIdFromHeader, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -323,8 +338,9 @@ func TestTenantIdHeaderAndNoBaseUriHeader_UsesTenantIdHeaderAndDefaultSystemBase req.Header.Set(signatureHeader, base64Signature(tenantIdFromHeader, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -347,8 +363,9 @@ func TestNoHeadersButDefaultSystemBaseUri_UsesDefaultBaseUriAndTenantIdZero(t *t } handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -371,8 +388,9 @@ func TestNoHeadersButDefaultSystemBaseUriAndNoSignatureSecretKey_UsesDefaultBase } handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx(defaultSystemBaseUri, nil)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx(defaultSystemBaseUri, nil, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil { t.Error(err) @@ -400,8 +418,9 @@ func TestWrongDataSignedWithValidSignatureKey_Returns403(t *testing.T) { req.Header.Set(signatureHeader, base64Signature("wrong data", signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil { t.Error(err) @@ -409,6 +428,10 @@ func TestWrongDataSignedWithValidSignatureKey_Returns403(t *testing.T) { if handlerSpy.hasBeenCalled { t.Error("inner handler should not have been called") } + + if err := logSpy.assertLogContains("signature"); err != nil { + t.Error(err) + } } func TestNoneBase64Signature_Returns403(t *testing.T) { @@ -423,8 +446,9 @@ func TestNoneBase64Signature_Returns403(t *testing.T) { req.Header.Set(signatureHeader, "abc+(9-!") handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil { t.Error(err) @@ -432,6 +456,10 @@ func TestNoneBase64Signature_Returns403(t *testing.T) { if handlerSpy.hasBeenCalled { t.Error("inner handler should not have been called") } + + if err := logSpy.assertLogContains("illegal base64"); err != nil { + t.Error(err) + } } func TestWrongSignatureKey_Returns403(t *testing.T) { @@ -447,8 +475,9 @@ func TestWrongSignatureKey_Returns403(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader+tenantIdFromHeader, wrongSignatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil { t.Error(err) @@ -456,6 +485,10 @@ func TestWrongSignatureKey_Returns403(t *testing.T) { if handlerSpy.hasBeenCalled { t.Error("inner handler should not have been called") } + + if err := logSpy.assertLogContains("signature"); err != nil { + t.Error(err) + } } func TestHeadersWithoutSignature_Returns403(t *testing.T) { @@ -469,8 +502,9 @@ func TestHeadersWithoutSignature_Returns403(t *testing.T) { req.Header.Set(tenantIdHeader, tenantIdFromHeader) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil { t.Error(err) @@ -478,6 +512,10 @@ func TestHeadersWithoutSignature_Returns403(t *testing.T) { if handlerSpy.hasBeenCalled { t.Error("inner handler should not have been called") } + + if err := logSpy.assertLogContains("signature"); err != nil { + t.Error(err) + } } func TestHeadersAndNoSignatureSecretKey_Returns500(t *testing.T) { @@ -492,8 +530,9 @@ func TestHeadersAndNoSignatureSecretKey_Returns500(t *testing.T) { req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader+tenantIdFromHeader, signatureKey)) handlerSpy := handlerSpy{} responseSpy := responseSpy{httptest.NewRecorder()} + logSpy := loggerSpy{} - tenant.AddToCtx("", nil)(&handlerSpy).ServeHTTP(responseSpy, req) + tenant.AddToCtx("", nil, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req) if err := responseSpy.assertStatusCodeIs(http.StatusInternalServerError); err != nil { t.Error(err) @@ -501,6 +540,10 @@ func TestHeadersAndNoSignatureSecretKey_Returns500(t *testing.T) { if handlerSpy.hasBeenCalled { t.Error("inner handler should not have been called") } + + if err := logSpy.assertLogContains("secret"); err != nil { + t.Error(err) + } } func TestNoIdOnContext_SetId_ReturnsContextWithId(t *testing.T) { @@ -625,3 +668,24 @@ func (spy *responseSpy) assertStatusCodeIs(expectedStatusCode int) error { } return nil } + +type loggerSpy struct { + hasBeenCalled bool + lastMessage string +} + +func (spy *loggerSpy) logError(ctx context.Context, message string) { + spy.hasBeenCalled = true + spy.lastMessage = message + fmt.Println(message) +} + +func (spy *loggerSpy) assertLogContains(term string) error { + if !spy.hasBeenCalled { + return fmt.Errorf("log should have been written") + } + if !strings.Contains(spy.lastMessage, term) { + return fmt.Errorf("expected log to contain the term '%v'", term) + } + return nil +} \ No newline at end of file