Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cmd/gocached/gocached.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ func main() {
maps.Copy(globalClaims, jwtClaims)
maps.Copy(globalClaims, globalJWTClaims)

opts = append(opts,
gocached.WithJWTAuth(*jwtIssuer, jwtClaims),
gocached.WithGlobalNamespaceJWTClaims(globalClaims),
)
opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{
Issuer: *jwtIssuer,
RequiredClaims: jwtClaims,
GlobalWriteClaims: globalClaims,
}))
}

srv, err := gocached.NewServer(opts...)
Expand Down
95 changes: 64 additions & 31 deletions gocached/gocached.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import (
"time"

ijwt "github.com/bradfitz/go-tool-cache/gocached/internal/jwt"
"github.com/bradfitz/go-tool-cache/gocached/logger"
"github.com/pierrec/lz4/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
Expand Down Expand Up @@ -274,13 +275,19 @@ func (srv *Server) start() error {
srv.logf("gocached: cleaned %v", res)
}

if srv.jwtIssuer != "" {
srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, srv.jwtIssuer, gocachedAudience)
if len(srv.jwtIssuers) > 0 {
issuerURLs := make([]string, 0, len(srv.jwtIssuers))
for iss := range srv.jwtIssuers {
issuerURLs = append(issuerURLs, iss)
}
srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs)
if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil {
return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err)
}

srv.logf("gocached: using JWT issuer %q with claims %v, global claims %v", srv.jwtIssuer, srv.jwtClaims, srv.globalJWTClaims)
for iss, entry := range srv.jwtIssuers {
srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims)
}

go srv.runCleanSessionsLoop()
}
Expand Down Expand Up @@ -351,11 +358,9 @@ func WithVerbose(verbose bool) ServerOption {
}
}

type logf func(format string, args ...any)

// WithLogf sets a custom logging function for the server. Defaults to
// [log.Printf].
func WithLogf(logf logf) ServerOption {
func WithLogf(logf logger.Logf) ServerOption {
return func(srv *Server) {
srv.logf = logf
}
Expand All @@ -378,24 +383,40 @@ func WithMaxAge(maxAge time.Duration) ServerOption {
}
}

// WithJWTAuth enables JWT-based authentication for the server. The issuer must
// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at
// /.well-known/openid-configuration, and any JWT presented to the server must
// exactly match the provided claims to start a session. No requests are allowed
// without authentication if JWT auth is enabled.
func WithJWTAuth(issuer string, claims map[string]string) ServerOption {
return func(srv *Server) {
srv.jwtIssuer = issuer
srv.jwtClaims = claims
}
// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication.
type JWTIssuerConfig struct {
// Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server
// that serves its JWKS via a URL discoverable at
// /.well-known/openid-configuration.
Issuer string

// RequiredClaims are claims that any JWT from this issuer must have to
// start a session. All key-value pairs must match exactly.
RequiredClaims map[string]string

// GlobalWriteClaims are claims that a JWT from this issuer must have to
// write to the cache's global namespace. It should be a superset of
// RequiredClaims.
GlobalWriteClaims map[string]string
}

// WithGlobalNamespaceJWTClaims sets additional claims that a JWT must have to
// write to the cache's global namespace. It should be a superset of the claims
// provided to [WithJWTAuth].
func WithGlobalNamespaceJWTClaims(claims map[string]string) ServerOption {
// WithJWTAuth enables JWT-based authentication for the server. Each issuer must
// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at
// /.well-known/openid-configuration, and any JWT presented to the server must
// exactly match the issuer's required claims to start a session. No requests are
// allowed without authentication if JWT auth is enabled. It can be called multiple
// times; configs accumulate.
func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption {
return func(srv *Server) {
srv.globalJWTClaims = claims
if srv.jwtIssuers == nil {
srv.jwtIssuers = make(map[string]*jwtIssuerConfig)
}
for _, ic := range issuers {
srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{
requiredClaims: ic.RequiredClaims,
globalWriteClaims: ic.GlobalWriteClaims,
}
}
}
}

Expand Down Expand Up @@ -445,18 +466,16 @@ type Server struct {
db *sql.DB
dir string // for SQLite DB + large blobs
verbose bool
logf logf
logf logger.Logf
clock func() time.Time // if non-nil, alternate time.Now for testing
metricsHandler http.Handler
maxSize int64 // maximum size of the cache in bytes; 0 means no limit
maxAge time.Duration // maximum age of objects; 0 means no limit
shutdownCtx context.Context
shutdownCancel context.CancelFunc

jwtValidator *ijwt.Validator // nil unless jwtIssuer is set
jwtIssuer string // issuer URL for JWTs
jwtClaims map[string]string // claims required for any JWT to start a session
globalJWTClaims map[string]string // additional claims required to write to global namespace
jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty
jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL

mu sync.RWMutex // guards following fields in this block
sessions map[string]*sessionData // maps access token -> session data.
Expand Down Expand Up @@ -501,6 +520,12 @@ type Server struct {
}
}

// jwtIssuerConfig holds per-issuer claim requirements for JWT auth.
type jwtIssuerConfig struct {
requiredClaims map[string]string
globalWriteClaims map[string]string
}

// sessionData corresponds to a specific access token, and is only used if JWT
// auth is enabled.
type sessionData struct {
Expand Down Expand Up @@ -1149,11 +1174,17 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) {
}

func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) {
if missing := findMissingClaims(srv.jwtClaims, claims); len(missing) > 0 {
iss, _ := claims["iss"].(string)
cfg, ok := srv.jwtIssuers[iss]
if !ok {
return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss)
}

if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 {
return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing)
}

if missing := findMissingClaims(srv.globalJWTClaims, claims); len(missing) == 0 {
if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 {
return true, nil
} else if srv.verbose {
srv.logf("token exchange: missing global namespace write claims: %v", missing)
Expand Down Expand Up @@ -1702,9 +1733,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) {

w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, "<html><body><h1>gocached sessions</h1>\n")
fmt.Fprintf(w, "<p>JWT issuer: %s</p>\n", srv.jwtIssuer)
fmt.Fprintf(w, "<p>JWT claims required: %v</p>\n", srv.jwtClaims)
fmt.Fprintf(w, "<p>JWT global write claims required: %v</p>\n", srv.globalJWTClaims)
for iss, cfg := range srv.jwtIssuers {
fmt.Fprintf(w, "<p>JWT issuer: %s</p>\n", iss)
fmt.Fprintf(w, "<p>JWT claims required: %v</p>\n", cfg.requiredClaims)
fmt.Fprintf(w, "<p>JWT global write claims required: %v</p>\n", cfg.globalWriteClaims)
}
fmt.Fprintf(w, "<p>Number of sessions: %d</p>\n", len(sessions))

fmt.Fprintf(w, "<table border='1' cellpadding=5>\n")
Expand Down
176 changes: 174 additions & 2 deletions gocached/gocached_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,11 @@ func TestExchangeToken(t *testing.T) {
t.Run(name, func(t *testing.T) {
issuer, createJWT := startOIDCServer(t, privateKey.Public())
st := newServerTester(t,
WithJWTAuth(issuer, wantClaims),
WithGlobalNamespaceJWTClaims(wantGlobalClaims),
WithJWTAuth(JWTIssuerConfig{
Issuer: issuer,
RequiredClaims: wantClaims,
GlobalWriteClaims: wantGlobalClaims,
}),
)

// Generate JWT.
Expand Down Expand Up @@ -837,6 +840,175 @@ func TestExchangeToken(t *testing.T) {
}
}

func TestMultiIssuerAuth(t *testing.T) {
// Generate separate keys for each issuer.
keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating key A: %v", err)
}
keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating key B: %v", err)
}
keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating key C: %v", err)
}

issuerA, createJWTA := startOIDCServer(t, keyA.Public())
issuerB, createJWTB := startOIDCServer(t, keyB.Public())
issuerC, createJWTC := startOIDCServer(t, keyC.Public())

st := newServerTester(t,
WithJWTAuth(
JWTIssuerConfig{
Issuer: issuerA,
RequiredClaims: map[string]string{"sub": "userA"},
GlobalWriteClaims: map[string]string{
"sub": "userA",
"ref": "refs/heads/main",
},
},
JWTIssuerConfig{
Issuer: issuerB,
RequiredClaims: map[string]string{"sub": "userB"},
GlobalWriteClaims: map[string]string{
"sub": "userB",
"ref": "refs/heads/main",
},
},
),
)

makeJWTBody := func(jwtString string) []byte {
body, err := json.Marshal(map[string]any{"jwt": jwtString})
if err != nil {
t.Fatalf("error marshaling request body: %v", err)
}
return body
}

exchangeToken := func(jwtBody []byte) (*http.Response, string) {
t.Helper()
req, err := http.NewRequest("POST", st.hs.URL+"/auth/exchange-token", bytes.NewReader(jwtBody))
if err != nil {
t.Fatalf("error creating request: %v", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("error making request: %v", err)
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Fatalf("error reading response body: %v", err)
}
if resp.StatusCode == http.StatusOK {
var d struct {
AccessToken string `json:"access_token"`
}
if err := json.Unmarshal(body, &d); err != nil {
t.Fatalf("error decoding response body: %v", err)
}
return resp, d.AccessToken
}
return resp, ""
}

baseClaims := func(iss string) jwt.MapClaims {
return jwt.MapClaims{
"iss": iss,
"aud": gocachedAudience,
"nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)),
"exp": jwt.NewNumericDate(time.Now().Add(time.Hour)),
}
}

// Issuer A: valid read-only token.
t.Run("issuerA_read", func(t *testing.T) {
claims := baseClaims(issuerA)
claims["sub"] = "userA"
resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA)))
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
}
if accessToken == "" {
t.Fatal("expected access token")
}

cl := st.mkClient()
cl.AccessToken = accessToken
st.wantGetMiss(cl, "aabb01")
})

// Issuer A: valid write token.
t.Run("issuerA_write", func(t *testing.T) {
claims := baseClaims(issuerA)
claims["sub"] = "userA"
claims["ref"] = "refs/heads/main"
resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA)))
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
}
cl := st.mkClient()
cl.AccessToken = accessToken
st.wantPut(cl, "aabb02", "ccdd02", "hello-from-A")
st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A")
})

// Issuer B: valid read-only token.
t.Run("issuerB_read", func(t *testing.T) {
claims := baseClaims(issuerB)
claims["sub"] = "userB"
resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB)))
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
}
if accessToken == "" {
t.Fatal("expected access token")
}
cl := st.mkClient()
cl.AccessToken = accessToken
// Can read data written by issuer A.
st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A")
})

// Issuer B: valid write token.
t.Run("issuerB_write", func(t *testing.T) {
claims := baseClaims(issuerB)
claims["sub"] = "userB"
claims["ref"] = "refs/heads/main"
resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB)))
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
}
cl := st.mkClient()
cl.AccessToken = accessToken
st.wantPut(cl, "aabb03", "ccdd03", "hello-from-B")
st.wantGet(cl, "aabb03", "ccdd03", "hello-from-B")
})

// Issuer B: wrong required claims (sub doesn't match).
t.Run("issuerB_wrong_sub", func(t *testing.T) {
claims := baseClaims(issuerB)
claims["sub"] = "userA" // issuer B requires sub=userB
resp, _ := exchangeToken(makeJWTBody(createJWTB(claims, keyB)))
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
})

// Issuer C: not configured, should be rejected.
t.Run("issuerC_rejected", func(t *testing.T) {
claims := baseClaims(issuerC)
claims["sub"] = "userC"
resp, _ := exchangeToken(makeJWTBody(createJWTC(claims, keyC)))
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
})
}

func BenchmarkFlushAccessTimes(b *testing.B) {
st := newServerTester(b, WithVerbose(false))
s := st.srv
Expand Down
Loading