Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ module github.com/HGV/x
go 1.22

require (
github.com/coreos/go-oidc/v3 v3.11.0
github.com/jackc/pgx/v5 v5.7.2
github.com/stretchr/testify v1.10.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI=
github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
Expand All @@ -21,6 +27,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
Expand Down
150 changes: 150 additions & 0 deletions oidcx/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package oidcx

import (
"context"
"errors"
"fmt"
"net/http"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
)

type Middleware struct {
o *middlewareOptions
v *oidc.IDTokenVerifier
}

type middlewareOptions struct {
ClientID string
SkipClientIDCheck bool
Email string
SkipEmailCheck bool
InsecureSkipSignatureCheck bool
AuthFailedHandler func(error) http.HandlerFunc
}

type MiddlewareOption func(*middlewareOptions)

type idTokenContextKey struct{}

func NewMiddleware(ctx context.Context, issuer string, opts ...MiddlewareOption) *Middleware {
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
panic(err)
}

o := &middlewareOptions{
AuthFailedHandler: defaultAuthFailedHandler,
}

for _, opt := range opts {
opt(o)
}

return &Middleware{
o: o,
v: provider.VerifierContext(ctx, &oidc.Config{
ClientID: o.ClientID,
SkipClientIDCheck: o.SkipClientIDCheck,
InsecureSkipSignatureCheck: o.InsecureSkipSignatureCheck,
}),
}
}

func (mw *Middleware) Handler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

bearerToken, ok := validateAuthHeader(r.Header.Get("Authorization"), "Bearer ")
if !ok {
mw.o.AuthFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r)
return
}

idToken, err := mw.v.Verify(ctx, bearerToken)
if err != nil {
mw.o.AuthFailedHandler(err).ServeHTTP(w, r)
return
}

if !mw.o.SkipEmailCheck {
if mw.o.Email == "" {
mw.o.AuthFailedHandler(errors.New("invalid configuration, Email must be provided or SkipEmailCheck must be set")).ServeHTTP(w, r)
return
}

var claims struct {
Email string `json:"email"`
}
if err = idToken.Claims(&claims); err != nil {
mw.o.AuthFailedHandler(err).ServeHTTP(w, r)
return
}
if !strings.EqualFold(mw.o.Email, claims.Email) {
mw.o.AuthFailedHandler(fmt.Errorf("expected email %q got %q", mw.o.Email, claims.Email)).ServeHTTP(w, r)
return
}
}

ctx = context.WithValue(ctx, idTokenContextKey{}, idToken)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

func IDTokenFromContext(ctx context.Context) (*oidc.IDToken, bool) {
octx, ok := ctx.Value(idTokenContextKey{}).(*oidc.IDToken)
return octx, ok
}

func WithAuthFailedHandler(h func(error) http.HandlerFunc) MiddlewareOption {
return func(o *middlewareOptions) {
if h != nil {
o.AuthFailedHandler = h
}
}
}

func WithClientID(clientID string) MiddlewareOption {
return func(o *middlewareOptions) {
o.ClientID = clientID
}
}

func WithSkipClientIDCheck() MiddlewareOption {
return func(o *middlewareOptions) {
o.SkipClientIDCheck = true
}
}

func WithEmail(email string) MiddlewareOption {
return func(o *middlewareOptions) {
o.Email = email
}
}

func WithSkipEmailCheck() MiddlewareOption {
return func(o *middlewareOptions) {
o.SkipEmailCheck = true
}
}

func withInsecureSkipSignatureCheck() MiddlewareOption {
return func(o *middlewareOptions) {
o.InsecureSkipSignatureCheck = true
}
}

func validateAuthHeader(s, scheme string) (string, bool) {
if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) {
return s[len(scheme):], true
}
return s, false
}

func defaultAuthFailedHandler(err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}
}
119 changes: 119 additions & 0 deletions oidcx/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package oidcx

import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestValidateAuthHeader(t *testing.T) {
tests := []struct {
authHeader string
scheme string
expectedToken string
expectedOK bool
}{
{
authHeader: "", scheme: "bearer",
expectedToken: "", expectedOK: false,
},
{
authHeader: "bearer token", scheme: "bearer ",
expectedToken: "token", expectedOK: true,
},
{
authHeader: "BEARER token", scheme: "bearer ",
expectedToken: "token", expectedOK: true,
},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
token, ok := validateAuthHeader(tt.authHeader, tt.scheme)
assert.Equal(t, tt.expectedOK, ok)
assert.Equal(t, tt.expectedToken, token)
})
}
}

func TestHandler(t *testing.T) {
issuer := "https://api.accounts.hgv.it"

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
idToken, ok := IDTokenFromContext(r.Context())
assert.True(t, ok)
assert.NotNil(t, idToken)
w.WriteHeader(http.StatusTeapot)
})

makeRequest := func(h http.Handler, token string) *httptest.ResponseRecorder {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Add("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI")
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
return w
}

t.Run("unauthorized with no token", func(t *testing.T) {
h := NewMiddleware(context.Background(), issuer).Handler(next)
w := makeRequest(h, "")
assert.Equal(t, http.StatusUnauthorized, w.Code)
})

t.Run("custom error handler returns 403", func(t *testing.T) {
h := NewMiddleware(context.Background(), issuer,
WithAuthFailedHandler(func(err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}
}),
).Handler(next)
w := makeRequest(h, "")
assert.Equal(t, http.StatusForbidden, w.Code)
})

t.Run("email config required but missing", func(t *testing.T) {
h := NewMiddleware(context.Background(), issuer,
WithAuthFailedHandler(func(err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(err.Error()))
}
}),
WithSkipClientIDCheck(),
withInsecureSkipSignatureCheck(),
).Handler(next)
w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI")
b, _ := io.ReadAll(w.Body)
assert.Equal(t, "invalid configuration, Email must be provided or SkipEmailCheck must be set", string(b))
})

t.Run("email mismatch", func(t *testing.T) {
h := NewMiddleware(context.Background(), issuer,
WithAuthFailedHandler(func(err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(err.Error()))
}
}),
WithSkipClientIDCheck(),
WithEmail("test@hgv.it"),
withInsecureSkipSignatureCheck(),
).Handler(next)
w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI")
b, _ := io.ReadAll(w.Body)
assert.Equal(t, "expected email \"test@hgv.it\" got \"\"", string(b))
})

t.Run("valid expired token without email check", func(t *testing.T) {
h := NewMiddleware(context.Background(), issuer,
WithSkipClientIDCheck(),
WithSkipEmailCheck(),
withInsecureSkipSignatureCheck(),
).Handler(next)
w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI")
assert.Equal(t, http.StatusTeapot, w.Code)
})
}