Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tenant/tenantmiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
"crypto/sha256"
"encoding/base64"
"errors"
"log"
"fmt"
"net/http"
"strings"
)
Expand All @@ -61,7 +61,7 @@ const (
// Adds systemBaseUri and tenantId to request context.
// If the headers are not present the given defaultSystemBaseUri and tenant "0" are used.
// The signatureSecretKey is specific for each App and is provided by the registration process for d.velop cloud.
func AddToCtx(defaultSystemBaseUri string, signatureSecretKey []byte) func(http.Handler) http.Handler {
func AddToCtx(defaultSystemBaseUri string, signatureSecretKey []byte, logger func(ctx context.Context, message string)) func(http.Handler) http.Handler {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use logError instead of logger for the name of the log func param to use a naming similar to the idp middleware.

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
Expand All @@ -71,19 +71,19 @@ func AddToCtx(defaultSystemBaseUri string, signatureSecretKey []byte) func(http.

if systemBaseUri != "" || tenantId != "" {
if signatureSecretKey == nil {
log.Printf("error validating signature for headers '%v' and '%v' because secret signature key has not been configured", systemBaseUriHeader, tenantIdHeader)
logger(req.Context(), fmt.Sprintf("validating signature for headers '%v' and '%v' because secret signature key has not been configured", systemBaseUriHeader, tenantIdHeader))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using the local ctx var because it is shorter.

http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
base64Signature := req.Header.Get("x-dv-sig-1")
signature, err := base64.StdEncoding.DecodeString(base64Signature)
if err != nil {
log.Printf("error decoding signature '%v' as base 64 data because: %v", base64Signature, err)
logger(req.Context(), fmt.Sprintf("decoding signature '%v' as base 64 data because: %v", base64Signature, err))
http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
if !signatureIsValid([]byte(systemBaseUri+tenantId), []byte(signature), signatureSecretKey) {
log.Printf("error signature '%v' is not valid for SystemBaseUri '%v' and TenantId '%v'", signature, systemBaseUri, tenantId)
logger(req.Context(), fmt.Sprintf("signature '%v' is not valid for SystemBaseUri '%v' and TenantId '%v'", signature, systemBaseUri, tenantId))
http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
Expand Down
108 changes: 86 additions & 22 deletions tenant/tenantmiddleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/d-velop/dvelop-sdk-go/tenant"
Expand All @@ -34,7 +35,8 @@ func TestBaseUriHeaderAndEmptyDefaultBaseUri_UsesHeader(t *testing.T) {
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
logSpy := loggerSpy{}
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -51,8 +53,9 @@ func TestNoBaseUriHeaderAndDefaultBaseUri_UsesDefaultBaseUri(t *testing.T) {
}
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -72,8 +75,9 @@ func TestBaseUriHeaderAndDefaultBaseUri_UsesHeader(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -89,8 +93,9 @@ func TestNoBaseUriHeaderAndEmptyDefaultBaseUri_DoesntAddBaseUriToContext(t *test
t.Fatal(err)
}
handlerSpy := handlerSpy{}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(httptest.NewRecorder(), req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(httptest.NewRecorder(), req)

if err := handlerSpy.assertErrorReadingSystemBaseUri(); err != nil {
t.Error(err)
Expand All @@ -107,8 +112,9 @@ func TestTenantIdHeader_UsesHeader(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(tenantIdFromHeader, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -125,8 +131,9 @@ func TestNoTenantIdHeader_UsesTenantIdZero(t *testing.T) {
}
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -147,8 +154,9 @@ func TestInitiatorSystemBaseUriHeader_UsesForwardedHeader(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(forwardedHeaderValue, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -169,8 +177,9 @@ func TestInitiatorSystemBaseUriHeader_UsesForwardedHeaderMultipleHosts(t *testin
req.Header.Set(signatureHeader, base64Signature(forwardedHeaderValue, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -190,8 +199,9 @@ func TestInitiatorSystemBaseUriHeader_UsesXForwardedHeader(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(xForwardedHostValue, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -212,8 +222,9 @@ func TestInitiatorSystemBaseUriHeader_UsesXForwardedHeaderMultipleHosts(t *testi
req.Header.Set(signatureHeader, base64Signature(xForwardedHostMultiValue, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -232,8 +243,9 @@ func TestInitiatorSystemBaseUriHeader_EmptyForwardedHeadersNoSystemBaseUri(t *te
req.Header.Set(signatureHeader, base64Signature("", signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -253,8 +265,9 @@ func TestInitiatorSystemBaseUriHeader_EmptyForwardedHeadersWithSystemBaseUri(t *
req.Header.Set(signatureHeader, base64Signature(systemBaseUri, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -273,8 +286,9 @@ func TestInitiatorSystemBaseUriHeader_EmptyForwardedHeadersWithDefaultSystemBase

handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -296,8 +310,9 @@ func TestTenantIdHeaderAndBaseUriHeader_UsesHeaders(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader+tenantIdFromHeader, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -323,8 +338,9 @@ func TestTenantIdHeaderAndNoBaseUriHeader_UsesTenantIdHeaderAndDefaultSystemBase
req.Header.Set(signatureHeader, base64Signature(tenantIdFromHeader, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -347,8 +363,9 @@ func TestNoHeadersButDefaultSystemBaseUri_UsesDefaultBaseUriAndTenantIdZero(t *t
}
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand All @@ -371,8 +388,9 @@ func TestNoHeadersButDefaultSystemBaseUriAndNoSignatureSecretKey_UsesDefaultBase
}
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx(defaultSystemBaseUri, nil)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx(defaultSystemBaseUri, nil, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusOK); err != nil {
t.Error(err)
Expand Down Expand Up @@ -400,15 +418,20 @@ func TestWrongDataSignedWithValidSignatureKey_Returns403(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature("wrong data", signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil {
t.Error(err)
}
if handlerSpy.hasBeenCalled {
t.Error("inner handler should not have been called")
}

if err := logSpy.assertLogContains("signature"); err != nil {
t.Error(err)
}
}

func TestNoneBase64Signature_Returns403(t *testing.T) {
Expand All @@ -423,15 +446,20 @@ func TestNoneBase64Signature_Returns403(t *testing.T) {
req.Header.Set(signatureHeader, "abc+(9-!")
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil {
t.Error(err)
}
if handlerSpy.hasBeenCalled {
t.Error("inner handler should not have been called")
}

if err := logSpy.assertLogContains("illegal base64"); err != nil {
t.Error(err)
}
}

func TestWrongSignatureKey_Returns403(t *testing.T) {
Expand All @@ -447,15 +475,20 @@ func TestWrongSignatureKey_Returns403(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader+tenantIdFromHeader, wrongSignatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil {
t.Error(err)
}
if handlerSpy.hasBeenCalled {
t.Error("inner handler should not have been called")
}

if err := logSpy.assertLogContains("signature"); err != nil {
t.Error(err)
}
}

func TestHeadersWithoutSignature_Returns403(t *testing.T) {
Expand All @@ -469,15 +502,20 @@ func TestHeadersWithoutSignature_Returns403(t *testing.T) {
req.Header.Set(tenantIdHeader, tenantIdFromHeader)
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", signatureKey)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", signatureKey, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusForbidden); err != nil {
t.Error(err)
}
if handlerSpy.hasBeenCalled {
t.Error("inner handler should not have been called")
}

if err := logSpy.assertLogContains("signature"); err != nil {
t.Error(err)
}
}

func TestHeadersAndNoSignatureSecretKey_Returns500(t *testing.T) {
Expand All @@ -492,15 +530,20 @@ func TestHeadersAndNoSignatureSecretKey_Returns500(t *testing.T) {
req.Header.Set(signatureHeader, base64Signature(systemBaseUriFromHeader+tenantIdFromHeader, signatureKey))
handlerSpy := handlerSpy{}
responseSpy := responseSpy{httptest.NewRecorder()}
logSpy := loggerSpy{}

tenant.AddToCtx("", nil)(&handlerSpy).ServeHTTP(responseSpy, req)
tenant.AddToCtx("", nil, logSpy.logError)(&handlerSpy).ServeHTTP(responseSpy, req)

if err := responseSpy.assertStatusCodeIs(http.StatusInternalServerError); err != nil {
t.Error(err)
}
if handlerSpy.hasBeenCalled {
t.Error("inner handler should not have been called")
}

if err := logSpy.assertLogContains("secret"); err != nil {
t.Error(err)
}
}

func TestNoIdOnContext_SetId_ReturnsContextWithId(t *testing.T) {
Expand Down Expand Up @@ -625,3 +668,24 @@ func (spy *responseSpy) assertStatusCodeIs(expectedStatusCode int) error {
}
return nil
}

type loggerSpy struct {
hasBeenCalled bool
lastMessage string
}

func (spy *loggerSpy) logError(ctx context.Context, message string) {
spy.hasBeenCalled = true
spy.lastMessage = message
fmt.Println(message)
}

func (spy *loggerSpy) assertLogContains(term string) error {
if !spy.hasBeenCalled {
return fmt.Errorf("log should have been written")
}
if !strings.Contains(spy.lastMessage, term) {
return fmt.Errorf("expected log to contain the term '%v'", term)
}
return nil
}