diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index 3aea693..eddbae0 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -71,10 +71,11 @@ func main() { maps.Copy(globalClaims, jwtClaims) maps.Copy(globalClaims, globalJWTClaims) - opts = append(opts, - gocached.WithJWTAuth(*jwtIssuer, jwtClaims), - gocached.WithGlobalNamespaceJWTClaims(globalClaims), - ) + opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ + Issuer: *jwtIssuer, + RequiredClaims: jwtClaims, + GlobalWriteClaims: globalClaims, + })) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index 564a088..9f97931 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -61,6 +61,7 @@ import ( "time" ijwt "github.com/bradfitz/go-tool-cache/gocached/internal/jwt" + "github.com/bradfitz/go-tool-cache/gocached/logger" "github.com/pierrec/lz4/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" @@ -274,13 +275,19 @@ func (srv *Server) start() error { srv.logf("gocached: cleaned %v", res) } - if srv.jwtIssuer != "" { - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, srv.jwtIssuer, gocachedAudience) + if len(srv.jwtIssuers) > 0 { + issuerURLs := make([]string, 0, len(srv.jwtIssuers)) + for iss := range srv.jwtIssuers { + issuerURLs = append(issuerURLs, iss) + } + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - srv.logf("gocached: using JWT issuer %q with claims %v, global claims %v", srv.jwtIssuer, srv.jwtClaims, srv.globalJWTClaims) + for iss, entry := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + } go srv.runCleanSessionsLoop() } @@ -351,11 +358,9 @@ func WithVerbose(verbose bool) ServerOption { } } -type logf func(format string, args ...any) - // WithLogf sets a custom logging function for the server. Defaults to // [log.Printf]. -func WithLogf(logf logf) ServerOption { +func WithLogf(logf logger.Logf) ServerOption { return func(srv *Server) { srv.logf = logf } @@ -378,24 +383,40 @@ func WithMaxAge(maxAge time.Duration) ServerOption { } } -// WithJWTAuth enables JWT-based authentication for the server. The issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the provided claims to start a session. No requests are allowed -// without authentication if JWT auth is enabled. -func WithJWTAuth(issuer string, claims map[string]string) ServerOption { - return func(srv *Server) { - srv.jwtIssuer = issuer - srv.jwtClaims = claims - } +// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. +type JWTIssuerConfig struct { + // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server + // that serves its JWKS via a URL discoverable at + // /.well-known/openid-configuration. + Issuer string + + // RequiredClaims are claims that any JWT from this issuer must have to + // start a session. All key-value pairs must match exactly. + RequiredClaims map[string]string + + // GlobalWriteClaims are claims that a JWT from this issuer must have to + // write to the cache's global namespace. It should be a superset of + // RequiredClaims. + GlobalWriteClaims map[string]string } -// WithGlobalNamespaceJWTClaims sets additional claims that a JWT must have to -// write to the cache's global namespace. It should be a superset of the claims -// provided to [WithJWTAuth]. -func WithGlobalNamespaceJWTClaims(claims map[string]string) ServerOption { +// WithJWTAuth enables JWT-based authentication for the server. Each issuer must +// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at +// /.well-known/openid-configuration, and any JWT presented to the server must +// exactly match the issuer's required claims to start a session. No requests are +// allowed without authentication if JWT auth is enabled. It can be called multiple +// times; configs accumulate. +func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { return func(srv *Server) { - srv.globalJWTClaims = claims + if srv.jwtIssuers == nil { + srv.jwtIssuers = make(map[string]*jwtIssuerConfig) + } + for _, ic := range issuers { + srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ + requiredClaims: ic.RequiredClaims, + globalWriteClaims: ic.GlobalWriteClaims, + } + } } } @@ -445,7 +466,7 @@ type Server struct { db *sql.DB dir string // for SQLite DB + large blobs verbose bool - logf logf + logf logger.Logf clock func() time.Time // if non-nil, alternate time.Now for testing metricsHandler http.Handler maxSize int64 // maximum size of the cache in bytes; 0 means no limit @@ -453,10 +474,8 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuer is set - jwtIssuer string // issuer URL for JWTs - jwtClaims map[string]string // claims required for any JWT to start a session - globalJWTClaims map[string]string // additional claims required to write to global namespace + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. @@ -501,6 +520,12 @@ type Server struct { } } +// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. +type jwtIssuerConfig struct { + requiredClaims map[string]string + globalWriteClaims map[string]string +} + // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { @@ -1149,11 +1174,17 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { } func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - if missing := findMissingClaims(srv.jwtClaims, claims); len(missing) > 0 { + iss, _ := claims["iss"].(string) + cfg, ok := srv.jwtIssuers[iss] + if !ok { + return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) + } + + if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) } - if missing := findMissingClaims(srv.globalJWTClaims, claims); len(missing) == 0 { + if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { return true, nil } else if srv.verbose { srv.logf("token exchange: missing global namespace write claims: %v", missing) @@ -1702,9 +1733,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "

gocached sessions

\n") - fmt.Fprintf(w, "

JWT issuer: %s

\n", srv.jwtIssuer) - fmt.Fprintf(w, "

JWT claims required: %v

\n", srv.jwtClaims) - fmt.Fprintf(w, "

JWT global write claims required: %v

\n", srv.globalJWTClaims) + for iss, cfg := range srv.jwtIssuers { + fmt.Fprintf(w, "

JWT issuer: %s

\n", iss) + fmt.Fprintf(w, "

JWT claims required: %v

\n", cfg.requiredClaims) + fmt.Fprintf(w, "

JWT global write claims required: %v

\n", cfg.globalWriteClaims) + } fmt.Fprintf(w, "

Number of sessions: %d

\n", len(sessions)) fmt.Fprintf(w, "\n") diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index a119d05..a46d92e 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -722,8 +722,11 @@ func TestExchangeToken(t *testing.T) { t.Run(name, func(t *testing.T) { issuer, createJWT := startOIDCServer(t, privateKey.Public()) st := newServerTester(t, - WithJWTAuth(issuer, wantClaims), - WithGlobalNamespaceJWTClaims(wantGlobalClaims), + WithJWTAuth(JWTIssuerConfig{ + Issuer: issuer, + RequiredClaims: wantClaims, + GlobalWriteClaims: wantGlobalClaims, + }), ) // Generate JWT. @@ -837,6 +840,175 @@ func TestExchangeToken(t *testing.T) { } } +func TestMultiIssuerAuth(t *testing.T) { + // Generate separate keys for each issuer. + keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating key A: %v", err) + } + keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating key B: %v", err) + } + keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating key C: %v", err) + } + + issuerA, createJWTA := startOIDCServer(t, keyA.Public()) + issuerB, createJWTB := startOIDCServer(t, keyB.Public()) + issuerC, createJWTC := startOIDCServer(t, keyC.Public()) + + st := newServerTester(t, + WithJWTAuth( + JWTIssuerConfig{ + Issuer: issuerA, + RequiredClaims: map[string]string{"sub": "userA"}, + GlobalWriteClaims: map[string]string{ + "sub": "userA", + "ref": "refs/heads/main", + }, + }, + JWTIssuerConfig{ + Issuer: issuerB, + RequiredClaims: map[string]string{"sub": "userB"}, + GlobalWriteClaims: map[string]string{ + "sub": "userB", + "ref": "refs/heads/main", + }, + }, + ), + ) + + makeJWTBody := func(jwtString string) []byte { + body, err := json.Marshal(map[string]any{"jwt": jwtString}) + if err != nil { + t.Fatalf("error marshaling request body: %v", err) + } + return body + } + + exchangeToken := func(jwtBody []byte) (*http.Response, string) { + t.Helper() + req, err := http.NewRequest("POST", st.hs.URL+"/auth/exchange-token", bytes.NewReader(jwtBody)) + if err != nil { + t.Fatalf("error creating request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("error making request: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + if resp.StatusCode == http.StatusOK { + var d struct { + AccessToken string `json:"access_token"` + } + if err := json.Unmarshal(body, &d); err != nil { + t.Fatalf("error decoding response body: %v", err) + } + return resp, d.AccessToken + } + return resp, "" + } + + baseClaims := func(iss string) jwt.MapClaims { + return jwt.MapClaims{ + "iss": iss, + "aud": gocachedAudience, + "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + } + } + + // Issuer A: valid read-only token. + t.Run("issuerA_read", func(t *testing.T) { + claims := baseClaims(issuerA) + claims["sub"] = "userA" + resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + if accessToken == "" { + t.Fatal("expected access token") + } + + cl := st.mkClient() + cl.AccessToken = accessToken + st.wantGetMiss(cl, "aabb01") + }) + + // Issuer A: valid write token. + t.Run("issuerA_write", func(t *testing.T) { + claims := baseClaims(issuerA) + claims["sub"] = "userA" + claims["ref"] = "refs/heads/main" + resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + cl := st.mkClient() + cl.AccessToken = accessToken + st.wantPut(cl, "aabb02", "ccdd02", "hello-from-A") + st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A") + }) + + // Issuer B: valid read-only token. + t.Run("issuerB_read", func(t *testing.T) { + claims := baseClaims(issuerB) + claims["sub"] = "userB" + resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + if accessToken == "" { + t.Fatal("expected access token") + } + cl := st.mkClient() + cl.AccessToken = accessToken + // Can read data written by issuer A. + st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A") + }) + + // Issuer B: valid write token. + t.Run("issuerB_write", func(t *testing.T) { + claims := baseClaims(issuerB) + claims["sub"] = "userB" + claims["ref"] = "refs/heads/main" + resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + cl := st.mkClient() + cl.AccessToken = accessToken + st.wantPut(cl, "aabb03", "ccdd03", "hello-from-B") + st.wantGet(cl, "aabb03", "ccdd03", "hello-from-B") + }) + + // Issuer B: wrong required claims (sub doesn't match). + t.Run("issuerB_wrong_sub", func(t *testing.T) { + claims := baseClaims(issuerB) + claims["sub"] = "userA" // issuer B requires sub=userB + resp, _ := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + }) + + // Issuer C: not configured, should be rejected. + t.Run("issuerC_rejected", func(t *testing.T) { + claims := baseClaims(issuerC) + claims["sub"] = "userC" + resp, _ := exchangeToken(makeJWTBody(createJWTC(claims, keyC))) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + }) +} + func BenchmarkFlushAccessTimes(b *testing.B) { st := newServerTester(b, WithVerbose(false)) s := st.srv diff --git a/gocached/internal/jwt/jwt.go b/gocached/internal/jwt/jwt.go index 31f13a0..d572b81 100644 --- a/gocached/internal/jwt/jwt.go +++ b/gocached/internal/jwt/jwt.go @@ -6,6 +6,7 @@ package jwt import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,6 +16,7 @@ import ( "sync/atomic" "time" + "github.com/bradfitz/go-tool-cache/gocached/logger" "github.com/go-jose/go-jose/v4" "github.com/golang-jwt/jwt/v5" ) @@ -26,30 +28,64 @@ var ( supportedAlgorithms = []string{"HS256", "RS256", "ES256"} ) +// issuer holds the per-issuer state: its own JWT parser and signing keys. +type issuer struct { + iss string + parser *jwt.Parser + signingKeys atomic.Value // []jose.JSONWebKey +} + +// keyFunc is how github.com/golang-jwt/jwt gets the public key it needs to +// verify a JWT signature. Each issuer entry only checks its own keys. +func (ie *issuer) keyFunc(t *jwt.Token) (any, error) { + var kid string + if v, ok := t.Header["kid"]; ok { + kid, _ = v.(string) + } + if kid == "" { + return nil, fmt.Errorf("no kid found in token header") + } + + signingKeys := ie.signingKeys.Load().([]jose.JSONWebKey) + for _, k := range signingKeys { + if k.KeyID == kid { + return k.Key, nil + } + } + + return nil, fmt.Errorf("unknown key ID: %s", kid) +} + // NewJWTValidator constructs a [Validator] for validating JWTs. Must call // [RunUpdateJWKSLoop] before validating any JWTs. Every JWT must exactly match -// the provided issuer and audience values in its "iss" and "aud" claims -// respectively. The issuer must be a reachable HTTP server that serves the JWT -// public signing keys via the path defined by [oidcConfigWellKnownPath], and -// the audience should be a value specific to the trust boundary that gocached -// resides within. -func NewJWTValidator(logf func(format string, args ...any), issuer, audience string) *Validator { +// one of the provided issuers and the audience value in its "iss" and "aud" +// claims respectively. Each issuer must be a reachable HTTP server that serves +// the JWT public signing keys via the path defined by [oidcConfigWellKnownPath], +// and the audience should be a value specific to the trust boundary that +// gocached resides within. +func NewJWTValidator(logf logger.Logf, audience string, issuerURLs []string) *Validator { + var issuers []*issuer + for _, iss := range issuerURLs { + issuers = append(issuers, &issuer{ + iss: iss, + parser: jwt.NewParser( + jwt.WithValidMethods(supportedAlgorithms), + jwt.WithIssuer(iss), + jwt.WithAudience(audience), + jwt.WithLeeway(10*time.Second), + jwt.WithIssuedAt(), + ), + }) + } return &Validator{ - logf: logf, - issuer: issuer, - parser: jwt.NewParser( - jwt.WithValidMethods(supportedAlgorithms), - jwt.WithIssuer(issuer), - jwt.WithAudience(audience), - jwt.WithLeeway(10*time.Second), - jwt.WithIssuedAt(), - ), + logf: logf, + issuers: issuers, } } // RunUpdateJWKSLoop fetches the JWKS synchronously once to surface any config // errors early, and then starts a background goroutine that periodically fetches -// the JWKS from the issuer to keep the signing keys up to date. Must be called +// the JWKS from all issuers to keep the signing keys up to date. Must be called // before validating any JWTs. func (v *Validator) RunUpdateJWKSLoop(ctx context.Context) error { // Initial fetch to error early on misconfiguration. @@ -65,56 +101,43 @@ func (v *Validator) RunUpdateJWKSLoop(ctx context.Context) error { // Validator provides methods for validating JWTs. Use [NewJWTValidator] to // construct a working Validator. type Validator struct { - logf func(format string, args ...any) - issuer string - parser *jwt.Parser - - signingKeys atomic.Value // []jose.JSONWebKey + logf logger.Logf + issuers []*issuer // TODO(tomhjp): metrics } // Validate returns an error if the provided JWT fails validation for an invalid -// signature or standard claim (iss, aud, iat, nbf, exp). It returns the token's -// verified claims if validation succeeds. The caller should then make policy -// decisions based on other claims such as "sub" or other custom claims. +// signature or standard claim (iss, aud, iat, nbf, exp). It tries each +// configured issuer and returns the verified claims from the first successful +// parse. If all issuers fail, it returns the last error. func (v *Validator) Validate(ctx context.Context, jwtString string) (map[string]any, error) { - tk, err := v.parser.Parse(jwtString, v.keyFunc) - if err != nil { - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - if !tk.Valid { - return nil, fmt.Errorf("invalid token") - } + var lastErr error + for _, ie := range v.issuers { + tk, err := ie.parser.Parse(jwtString, ie.keyFunc) + if err != nil { + lastErr = err + continue + } - gotClaims, ok := tk.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("unexpected claims type: %T", tk.Claims) - } + if !tk.Valid { + lastErr = fmt.Errorf("invalid token") + continue + } - return gotClaims, nil -} + gotClaims, ok := tk.Claims.(jwt.MapClaims) + if !ok { + lastErr = fmt.Errorf("unexpected claims type: %T", tk.Claims) + continue + } -// keyFunc is how github.com/golang-jwt/jwt gets the public key it needs to -// verify a JWT signature. -func (v *Validator) keyFunc(t *jwt.Token) (any, error) { - var kid string - if v, ok := t.Header["kid"]; ok { - kid, _ = v.(string) - } - if kid == "" { - return nil, fmt.Errorf("no kid found in token header") + return gotClaims, nil } - signingKeys := v.signingKeys.Load().([]jose.JSONWebKey) - for _, k := range signingKeys { - if k.KeyID == kid { - return k.Key, nil - } + if lastErr != nil { + return nil, fmt.Errorf("failed to parse token: %w", lastErr) } - - return nil, fmt.Errorf("unknown key ID: %s", kid) + return nil, fmt.Errorf("no issuers configured") } func (v *Validator) runUpdateJWKSLoop(ctx context.Context) { @@ -135,10 +158,20 @@ func (v *Validator) runUpdateJWKSLoop(ctx context.Context) { } func (v *Validator) updateJWKS(ctx context.Context) error { - v.logf("jwt: fetching JWKS from issuer %q", v.issuer) - u, err := url.Parse(v.issuer) + var errs []error + for _, ie := range v.issuers { + if err := ie.updateJWKS(ctx, v.logf); err != nil { + errs = append(errs, fmt.Errorf("issuer %q: %w", ie.iss, err)) + } + } + return errors.Join(errs...) +} + +func (ie *issuer) updateJWKS(ctx context.Context, logf logger.Logf) error { + logf("jwt: fetching JWKS from issuer %q", ie.iss) + u, err := url.Parse(ie.iss) if err != nil { - return fmt.Errorf("failed to parse issuer URL %q: %w", v.issuer, err) + return fmt.Errorf("failed to parse issuer URL %q: %w", ie.iss, err) } u.Path = path.Join(u.Path, oidcConfigWellKnownPath) @@ -174,7 +207,7 @@ func (v *Validator) updateJWKS(ctx context.Context) error { signingKeys = append(signingKeys, k) } - v.signingKeys.Store(signingKeys) + ie.signingKeys.Store(signingKeys) return nil } diff --git a/gocached/logger/logger.go b/gocached/logger/logger.go new file mode 100644 index 0000000..69b955e --- /dev/null +++ b/gocached/logger/logger.go @@ -0,0 +1,7 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +// Logf is a logging function type. It is implemented by log.Printf. +type Logf func(format string, args ...any)