From 96d7cead1a15a4cdf361159a2f0a95b93ce4dde8 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Tue, 3 Mar 2026 16:19:07 +0000 Subject: [PATCH] gocached: allow multiple JWT issuers To allow a single gocached server to be shared between clients in different contexts, support multiple JWT issuers. For example, it could be configured to support both GitHub identity tokens and AWS IAM outbound identity federation, with distinct claims in each case. As written, this commit breaks the gocached package API, but we're not releasing proper semantic versions of the library, so I haven't made efforts not to. I'm happy to receive feedback on that if it is going to cause issues for anyone though. Updates tailscale/corp#37839 Signed-off-by: Tom Proctor --- cmd/gocached/gocached.go | 9 +- gocached/gocached.go | 95 +++++++++++++------ gocached/gocached_test.go | 176 ++++++++++++++++++++++++++++++++++- gocached/internal/jwt/jwt.go | 149 +++++++++++++++++------------ gocached/logger/logger.go | 7 ++ 5 files changed, 341 insertions(+), 95 deletions(-) create mode 100644 gocached/logger/logger.go diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index 3aea693..eddbae0 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -71,10 +71,11 @@ func main() { maps.Copy(globalClaims, jwtClaims) maps.Copy(globalClaims, globalJWTClaims) - opts = append(opts, - gocached.WithJWTAuth(*jwtIssuer, jwtClaims), - gocached.WithGlobalNamespaceJWTClaims(globalClaims), - ) + opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ + Issuer: *jwtIssuer, + RequiredClaims: jwtClaims, + GlobalWriteClaims: globalClaims, + })) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index 564a088..9f97931 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -61,6 +61,7 @@ import ( "time" ijwt "github.com/bradfitz/go-tool-cache/gocached/internal/jwt" + "github.com/bradfitz/go-tool-cache/gocached/logger" "github.com/pierrec/lz4/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" @@ -274,13 +275,19 @@ func (srv *Server) start() error { srv.logf("gocached: cleaned %v", res) } - if srv.jwtIssuer != "" { - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, srv.jwtIssuer, gocachedAudience) + if len(srv.jwtIssuers) > 0 { + issuerURLs := make([]string, 0, len(srv.jwtIssuers)) + for iss := range srv.jwtIssuers { + issuerURLs = append(issuerURLs, iss) + } + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - srv.logf("gocached: using JWT issuer %q with claims %v, global claims %v", srv.jwtIssuer, srv.jwtClaims, srv.globalJWTClaims) + for iss, entry := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + } go srv.runCleanSessionsLoop() } @@ -351,11 +358,9 @@ func WithVerbose(verbose bool) ServerOption { } } -type logf func(format string, args ...any) - // WithLogf sets a custom logging function for the server. Defaults to // [log.Printf]. -func WithLogf(logf logf) ServerOption { +func WithLogf(logf logger.Logf) ServerOption { return func(srv *Server) { srv.logf = logf } @@ -378,24 +383,40 @@ func WithMaxAge(maxAge time.Duration) ServerOption { } } -// WithJWTAuth enables JWT-based authentication for the server. The issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the provided claims to start a session. No requests are allowed -// without authentication if JWT auth is enabled. -func WithJWTAuth(issuer string, claims map[string]string) ServerOption { - return func(srv *Server) { - srv.jwtIssuer = issuer - srv.jwtClaims = claims - } +// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. +type JWTIssuerConfig struct { + // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server + // that serves its JWKS via a URL discoverable at + // /.well-known/openid-configuration. + Issuer string + + // RequiredClaims are claims that any JWT from this issuer must have to + // start a session. All key-value pairs must match exactly. + RequiredClaims map[string]string + + // GlobalWriteClaims are claims that a JWT from this issuer must have to + // write to the cache's global namespace. It should be a superset of + // RequiredClaims. + GlobalWriteClaims map[string]string } -// WithGlobalNamespaceJWTClaims sets additional claims that a JWT must have to -// write to the cache's global namespace. It should be a superset of the claims -// provided to [WithJWTAuth]. -func WithGlobalNamespaceJWTClaims(claims map[string]string) ServerOption { +// WithJWTAuth enables JWT-based authentication for the server. Each issuer must +// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at +// /.well-known/openid-configuration, and any JWT presented to the server must +// exactly match the issuer's required claims to start a session. No requests are +// allowed without authentication if JWT auth is enabled. It can be called multiple +// times; configs accumulate. +func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { return func(srv *Server) { - srv.globalJWTClaims = claims + if srv.jwtIssuers == nil { + srv.jwtIssuers = make(map[string]*jwtIssuerConfig) + } + for _, ic := range issuers { + srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ + requiredClaims: ic.RequiredClaims, + globalWriteClaims: ic.GlobalWriteClaims, + } + } } } @@ -445,7 +466,7 @@ type Server struct { db *sql.DB dir string // for SQLite DB + large blobs verbose bool - logf logf + logf logger.Logf clock func() time.Time // if non-nil, alternate time.Now for testing metricsHandler http.Handler maxSize int64 // maximum size of the cache in bytes; 0 means no limit @@ -453,10 +474,8 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuer is set - jwtIssuer string // issuer URL for JWTs - jwtClaims map[string]string // claims required for any JWT to start a session - globalJWTClaims map[string]string // additional claims required to write to global namespace + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. @@ -501,6 +520,12 @@ type Server struct { } } +// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. +type jwtIssuerConfig struct { + requiredClaims map[string]string + globalWriteClaims map[string]string +} + // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { @@ -1149,11 +1174,17 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { } func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - if missing := findMissingClaims(srv.jwtClaims, claims); len(missing) > 0 { + iss, _ := claims["iss"].(string) + cfg, ok := srv.jwtIssuers[iss] + if !ok { + return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) + } + + if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) } - if missing := findMissingClaims(srv.globalJWTClaims, claims); len(missing) == 0 { + if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { return true, nil } else if srv.verbose { srv.logf("token exchange: missing global namespace write claims: %v", missing) @@ -1702,9 +1733,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "

gocached sessions

\n") - fmt.Fprintf(w, "

JWT issuer: %s

\n", srv.jwtIssuer) - fmt.Fprintf(w, "

JWT claims required: %v

\n", srv.jwtClaims) - fmt.Fprintf(w, "

JWT global write claims required: %v

\n", srv.globalJWTClaims) + for iss, cfg := range srv.jwtIssuers { + fmt.Fprintf(w, "

JWT issuer: %s

\n", iss) + fmt.Fprintf(w, "

JWT claims required: %v

\n", cfg.requiredClaims) + fmt.Fprintf(w, "

JWT global write claims required: %v

\n", cfg.globalWriteClaims) + } fmt.Fprintf(w, "

Number of sessions: %d

\n", len(sessions)) fmt.Fprintf(w, "\n") diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index a119d05..a46d92e 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -722,8 +722,11 @@ func TestExchangeToken(t *testing.T) { t.Run(name, func(t *testing.T) { issuer, createJWT := startOIDCServer(t, privateKey.Public()) st := newServerTester(t, - WithJWTAuth(issuer, wantClaims), - WithGlobalNamespaceJWTClaims(wantGlobalClaims), + WithJWTAuth(JWTIssuerConfig{ + Issuer: issuer, + RequiredClaims: wantClaims, + GlobalWriteClaims: wantGlobalClaims, + }), ) // Generate JWT. @@ -837,6 +840,175 @@ func TestExchangeToken(t *testing.T) { } } +func TestMultiIssuerAuth(t *testing.T) { + // Generate separate keys for each issuer. + keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating key A: %v", err) + } + keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating key B: %v", err) + } + keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating key C: %v", err) + } + + issuerA, createJWTA := startOIDCServer(t, keyA.Public()) + issuerB, createJWTB := startOIDCServer(t, keyB.Public()) + issuerC, createJWTC := startOIDCServer(t, keyC.Public()) + + st := newServerTester(t, + WithJWTAuth( + JWTIssuerConfig{ + Issuer: issuerA, + RequiredClaims: map[string]string{"sub": "userA"}, + GlobalWriteClaims: map[string]string{ + "sub": "userA", + "ref": "refs/heads/main", + }, + }, + JWTIssuerConfig{ + Issuer: issuerB, + RequiredClaims: map[string]string{"sub": "userB"}, + GlobalWriteClaims: map[string]string{ + "sub": "userB", + "ref": "refs/heads/main", + }, + }, + ), + ) + + makeJWTBody := func(jwtString string) []byte { + body, err := json.Marshal(map[string]any{"jwt": jwtString}) + if err != nil { + t.Fatalf("error marshaling request body: %v", err) + } + return body + } + + exchangeToken := func(jwtBody []byte) (*http.Response, string) { + t.Helper() + req, err := http.NewRequest("POST", st.hs.URL+"/auth/exchange-token", bytes.NewReader(jwtBody)) + if err != nil { + t.Fatalf("error creating request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("error making request: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + if resp.StatusCode == http.StatusOK { + var d struct { + AccessToken string `json:"access_token"` + } + if err := json.Unmarshal(body, &d); err != nil { + t.Fatalf("error decoding response body: %v", err) + } + return resp, d.AccessToken + } + return resp, "" + } + + baseClaims := func(iss string) jwt.MapClaims { + return jwt.MapClaims{ + "iss": iss, + "aud": gocachedAudience, + "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + } + } + + // Issuer A: valid read-only token. + t.Run("issuerA_read", func(t *testing.T) { + claims := baseClaims(issuerA) + claims["sub"] = "userA" + resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + if accessToken == "" { + t.Fatal("expected access token") + } + + cl := st.mkClient() + cl.AccessToken = accessToken + st.wantGetMiss(cl, "aabb01") + }) + + // Issuer A: valid write token. + t.Run("issuerA_write", func(t *testing.T) { + claims := baseClaims(issuerA) + claims["sub"] = "userA" + claims["ref"] = "refs/heads/main" + resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + cl := st.mkClient() + cl.AccessToken = accessToken + st.wantPut(cl, "aabb02", "ccdd02", "hello-from-A") + st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A") + }) + + // Issuer B: valid read-only token. + t.Run("issuerB_read", func(t *testing.T) { + claims := baseClaims(issuerB) + claims["sub"] = "userB" + resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + if accessToken == "" { + t.Fatal("expected access token") + } + cl := st.mkClient() + cl.AccessToken = accessToken + // Can read data written by issuer A. + st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A") + }) + + // Issuer B: valid write token. + t.Run("issuerB_write", func(t *testing.T) { + claims := baseClaims(issuerB) + claims["sub"] = "userB" + claims["ref"] = "refs/heads/main" + resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + cl := st.mkClient() + cl.AccessToken = accessToken + st.wantPut(cl, "aabb03", "ccdd03", "hello-from-B") + st.wantGet(cl, "aabb03", "ccdd03", "hello-from-B") + }) + + // Issuer B: wrong required claims (sub doesn't match). + t.Run("issuerB_wrong_sub", func(t *testing.T) { + claims := baseClaims(issuerB) + claims["sub"] = "userA" // issuer B requires sub=userB + resp, _ := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + }) + + // Issuer C: not configured, should be rejected. + t.Run("issuerC_rejected", func(t *testing.T) { + claims := baseClaims(issuerC) + claims["sub"] = "userC" + resp, _ := exchangeToken(makeJWTBody(createJWTC(claims, keyC))) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + }) +} + func BenchmarkFlushAccessTimes(b *testing.B) { st := newServerTester(b, WithVerbose(false)) s := st.srv diff --git a/gocached/internal/jwt/jwt.go b/gocached/internal/jwt/jwt.go index 31f13a0..d572b81 100644 --- a/gocached/internal/jwt/jwt.go +++ b/gocached/internal/jwt/jwt.go @@ -6,6 +6,7 @@ package jwt import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,6 +16,7 @@ import ( "sync/atomic" "time" + "github.com/bradfitz/go-tool-cache/gocached/logger" "github.com/go-jose/go-jose/v4" "github.com/golang-jwt/jwt/v5" ) @@ -26,30 +28,64 @@ var ( supportedAlgorithms = []string{"HS256", "RS256", "ES256"} ) +// issuer holds the per-issuer state: its own JWT parser and signing keys. +type issuer struct { + iss string + parser *jwt.Parser + signingKeys atomic.Value // []jose.JSONWebKey +} + +// keyFunc is how github.com/golang-jwt/jwt gets the public key it needs to +// verify a JWT signature. Each issuer entry only checks its own keys. +func (ie *issuer) keyFunc(t *jwt.Token) (any, error) { + var kid string + if v, ok := t.Header["kid"]; ok { + kid, _ = v.(string) + } + if kid == "" { + return nil, fmt.Errorf("no kid found in token header") + } + + signingKeys := ie.signingKeys.Load().([]jose.JSONWebKey) + for _, k := range signingKeys { + if k.KeyID == kid { + return k.Key, nil + } + } + + return nil, fmt.Errorf("unknown key ID: %s", kid) +} + // NewJWTValidator constructs a [Validator] for validating JWTs. Must call // [RunUpdateJWKSLoop] before validating any JWTs. Every JWT must exactly match -// the provided issuer and audience values in its "iss" and "aud" claims -// respectively. The issuer must be a reachable HTTP server that serves the JWT -// public signing keys via the path defined by [oidcConfigWellKnownPath], and -// the audience should be a value specific to the trust boundary that gocached -// resides within. -func NewJWTValidator(logf func(format string, args ...any), issuer, audience string) *Validator { +// one of the provided issuers and the audience value in its "iss" and "aud" +// claims respectively. Each issuer must be a reachable HTTP server that serves +// the JWT public signing keys via the path defined by [oidcConfigWellKnownPath], +// and the audience should be a value specific to the trust boundary that +// gocached resides within. +func NewJWTValidator(logf logger.Logf, audience string, issuerURLs []string) *Validator { + var issuers []*issuer + for _, iss := range issuerURLs { + issuers = append(issuers, &issuer{ + iss: iss, + parser: jwt.NewParser( + jwt.WithValidMethods(supportedAlgorithms), + jwt.WithIssuer(iss), + jwt.WithAudience(audience), + jwt.WithLeeway(10*time.Second), + jwt.WithIssuedAt(), + ), + }) + } return &Validator{ - logf: logf, - issuer: issuer, - parser: jwt.NewParser( - jwt.WithValidMethods(supportedAlgorithms), - jwt.WithIssuer(issuer), - jwt.WithAudience(audience), - jwt.WithLeeway(10*time.Second), - jwt.WithIssuedAt(), - ), + logf: logf, + issuers: issuers, } } // RunUpdateJWKSLoop fetches the JWKS synchronously once to surface any config // errors early, and then starts a background goroutine that periodically fetches -// the JWKS from the issuer to keep the signing keys up to date. Must be called +// the JWKS from all issuers to keep the signing keys up to date. Must be called // before validating any JWTs. func (v *Validator) RunUpdateJWKSLoop(ctx context.Context) error { // Initial fetch to error early on misconfiguration. @@ -65,56 +101,43 @@ func (v *Validator) RunUpdateJWKSLoop(ctx context.Context) error { // Validator provides methods for validating JWTs. Use [NewJWTValidator] to // construct a working Validator. type Validator struct { - logf func(format string, args ...any) - issuer string - parser *jwt.Parser - - signingKeys atomic.Value // []jose.JSONWebKey + logf logger.Logf + issuers []*issuer // TODO(tomhjp): metrics } // Validate returns an error if the provided JWT fails validation for an invalid -// signature or standard claim (iss, aud, iat, nbf, exp). It returns the token's -// verified claims if validation succeeds. The caller should then make policy -// decisions based on other claims such as "sub" or other custom claims. +// signature or standard claim (iss, aud, iat, nbf, exp). It tries each +// configured issuer and returns the verified claims from the first successful +// parse. If all issuers fail, it returns the last error. func (v *Validator) Validate(ctx context.Context, jwtString string) (map[string]any, error) { - tk, err := v.parser.Parse(jwtString, v.keyFunc) - if err != nil { - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - if !tk.Valid { - return nil, fmt.Errorf("invalid token") - } + var lastErr error + for _, ie := range v.issuers { + tk, err := ie.parser.Parse(jwtString, ie.keyFunc) + if err != nil { + lastErr = err + continue + } - gotClaims, ok := tk.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("unexpected claims type: %T", tk.Claims) - } + if !tk.Valid { + lastErr = fmt.Errorf("invalid token") + continue + } - return gotClaims, nil -} + gotClaims, ok := tk.Claims.(jwt.MapClaims) + if !ok { + lastErr = fmt.Errorf("unexpected claims type: %T", tk.Claims) + continue + } -// keyFunc is how github.com/golang-jwt/jwt gets the public key it needs to -// verify a JWT signature. -func (v *Validator) keyFunc(t *jwt.Token) (any, error) { - var kid string - if v, ok := t.Header["kid"]; ok { - kid, _ = v.(string) - } - if kid == "" { - return nil, fmt.Errorf("no kid found in token header") + return gotClaims, nil } - signingKeys := v.signingKeys.Load().([]jose.JSONWebKey) - for _, k := range signingKeys { - if k.KeyID == kid { - return k.Key, nil - } + if lastErr != nil { + return nil, fmt.Errorf("failed to parse token: %w", lastErr) } - - return nil, fmt.Errorf("unknown key ID: %s", kid) + return nil, fmt.Errorf("no issuers configured") } func (v *Validator) runUpdateJWKSLoop(ctx context.Context) { @@ -135,10 +158,20 @@ func (v *Validator) runUpdateJWKSLoop(ctx context.Context) { } func (v *Validator) updateJWKS(ctx context.Context) error { - v.logf("jwt: fetching JWKS from issuer %q", v.issuer) - u, err := url.Parse(v.issuer) + var errs []error + for _, ie := range v.issuers { + if err := ie.updateJWKS(ctx, v.logf); err != nil { + errs = append(errs, fmt.Errorf("issuer %q: %w", ie.iss, err)) + } + } + return errors.Join(errs...) +} + +func (ie *issuer) updateJWKS(ctx context.Context, logf logger.Logf) error { + logf("jwt: fetching JWKS from issuer %q", ie.iss) + u, err := url.Parse(ie.iss) if err != nil { - return fmt.Errorf("failed to parse issuer URL %q: %w", v.issuer, err) + return fmt.Errorf("failed to parse issuer URL %q: %w", ie.iss, err) } u.Path = path.Join(u.Path, oidcConfigWellKnownPath) @@ -174,7 +207,7 @@ func (v *Validator) updateJWKS(ctx context.Context) error { signingKeys = append(signingKeys, k) } - v.signingKeys.Store(signingKeys) + ie.signingKeys.Store(signingKeys) return nil } diff --git a/gocached/logger/logger.go b/gocached/logger/logger.go new file mode 100644 index 0000000..69b955e --- /dev/null +++ b/gocached/logger/logger.go @@ -0,0 +1,7 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +// Logf is a logging function type. It is implemented by log.Printf. +type Logf func(format string, args ...any)