diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go
index 3aea693..eddbae0 100644
--- a/cmd/gocached/gocached.go
+++ b/cmd/gocached/gocached.go
@@ -71,10 +71,11 @@ func main() {
maps.Copy(globalClaims, jwtClaims)
maps.Copy(globalClaims, globalJWTClaims)
- opts = append(opts,
- gocached.WithJWTAuth(*jwtIssuer, jwtClaims),
- gocached.WithGlobalNamespaceJWTClaims(globalClaims),
- )
+ opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{
+ Issuer: *jwtIssuer,
+ RequiredClaims: jwtClaims,
+ GlobalWriteClaims: globalClaims,
+ }))
}
srv, err := gocached.NewServer(opts...)
diff --git a/gocached/gocached.go b/gocached/gocached.go
index 564a088..9f97931 100644
--- a/gocached/gocached.go
+++ b/gocached/gocached.go
@@ -61,6 +61,7 @@ import (
"time"
ijwt "github.com/bradfitz/go-tool-cache/gocached/internal/jwt"
+ "github.com/bradfitz/go-tool-cache/gocached/logger"
"github.com/pierrec/lz4/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
@@ -274,13 +275,19 @@ func (srv *Server) start() error {
srv.logf("gocached: cleaned %v", res)
}
- if srv.jwtIssuer != "" {
- srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, srv.jwtIssuer, gocachedAudience)
+ if len(srv.jwtIssuers) > 0 {
+ issuerURLs := make([]string, 0, len(srv.jwtIssuers))
+ for iss := range srv.jwtIssuers {
+ issuerURLs = append(issuerURLs, iss)
+ }
+ srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs)
if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil {
return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err)
}
- srv.logf("gocached: using JWT issuer %q with claims %v, global claims %v", srv.jwtIssuer, srv.jwtClaims, srv.globalJWTClaims)
+ for iss, entry := range srv.jwtIssuers {
+ srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims)
+ }
go srv.runCleanSessionsLoop()
}
@@ -351,11 +358,9 @@ func WithVerbose(verbose bool) ServerOption {
}
}
-type logf func(format string, args ...any)
-
// WithLogf sets a custom logging function for the server. Defaults to
// [log.Printf].
-func WithLogf(logf logf) ServerOption {
+func WithLogf(logf logger.Logf) ServerOption {
return func(srv *Server) {
srv.logf = logf
}
@@ -378,24 +383,40 @@ func WithMaxAge(maxAge time.Duration) ServerOption {
}
}
-// WithJWTAuth enables JWT-based authentication for the server. The issuer must
-// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at
-// /.well-known/openid-configuration, and any JWT presented to the server must
-// exactly match the provided claims to start a session. No requests are allowed
-// without authentication if JWT auth is enabled.
-func WithJWTAuth(issuer string, claims map[string]string) ServerOption {
- return func(srv *Server) {
- srv.jwtIssuer = issuer
- srv.jwtClaims = claims
- }
+// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication.
+type JWTIssuerConfig struct {
+ // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server
+ // that serves its JWKS via a URL discoverable at
+ // /.well-known/openid-configuration.
+ Issuer string
+
+ // RequiredClaims are claims that any JWT from this issuer must have to
+ // start a session. All key-value pairs must match exactly.
+ RequiredClaims map[string]string
+
+ // GlobalWriteClaims are claims that a JWT from this issuer must have to
+ // write to the cache's global namespace. It should be a superset of
+ // RequiredClaims.
+ GlobalWriteClaims map[string]string
}
-// WithGlobalNamespaceJWTClaims sets additional claims that a JWT must have to
-// write to the cache's global namespace. It should be a superset of the claims
-// provided to [WithJWTAuth].
-func WithGlobalNamespaceJWTClaims(claims map[string]string) ServerOption {
+// WithJWTAuth enables JWT-based authentication for the server. Each issuer must
+// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at
+// /.well-known/openid-configuration, and any JWT presented to the server must
+// exactly match the issuer's required claims to start a session. No requests are
+// allowed without authentication if JWT auth is enabled. It can be called multiple
+// times; configs accumulate.
+func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption {
return func(srv *Server) {
- srv.globalJWTClaims = claims
+ if srv.jwtIssuers == nil {
+ srv.jwtIssuers = make(map[string]*jwtIssuerConfig)
+ }
+ for _, ic := range issuers {
+ srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{
+ requiredClaims: ic.RequiredClaims,
+ globalWriteClaims: ic.GlobalWriteClaims,
+ }
+ }
}
}
@@ -445,7 +466,7 @@ type Server struct {
db *sql.DB
dir string // for SQLite DB + large blobs
verbose bool
- logf logf
+ logf logger.Logf
clock func() time.Time // if non-nil, alternate time.Now for testing
metricsHandler http.Handler
maxSize int64 // maximum size of the cache in bytes; 0 means no limit
@@ -453,10 +474,8 @@ type Server struct {
shutdownCtx context.Context
shutdownCancel context.CancelFunc
- jwtValidator *ijwt.Validator // nil unless jwtIssuer is set
- jwtIssuer string // issuer URL for JWTs
- jwtClaims map[string]string // claims required for any JWT to start a session
- globalJWTClaims map[string]string // additional claims required to write to global namespace
+ jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty
+ jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL
mu sync.RWMutex // guards following fields in this block
sessions map[string]*sessionData // maps access token -> session data.
@@ -501,6 +520,12 @@ type Server struct {
}
}
+// jwtIssuerConfig holds per-issuer claim requirements for JWT auth.
+type jwtIssuerConfig struct {
+ requiredClaims map[string]string
+ globalWriteClaims map[string]string
+}
+
// sessionData corresponds to a specific access token, and is only used if JWT
// auth is enabled.
type sessionData struct {
@@ -1149,11 +1174,17 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) {
}
func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) {
- if missing := findMissingClaims(srv.jwtClaims, claims); len(missing) > 0 {
+ iss, _ := claims["iss"].(string)
+ cfg, ok := srv.jwtIssuers[iss]
+ if !ok {
+ return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss)
+ }
+
+ if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 {
return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing)
}
- if missing := findMissingClaims(srv.globalJWTClaims, claims); len(missing) == 0 {
+ if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 {
return true, nil
} else if srv.verbose {
srv.logf("token exchange: missing global namespace write claims: %v", missing)
@@ -1702,9 +1733,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, "
gocached sessions
\n")
- fmt.Fprintf(w, "JWT issuer: %s
\n", srv.jwtIssuer)
- fmt.Fprintf(w, "JWT claims required: %v
\n", srv.jwtClaims)
- fmt.Fprintf(w, "JWT global write claims required: %v
\n", srv.globalJWTClaims)
+ for iss, cfg := range srv.jwtIssuers {
+ fmt.Fprintf(w, "JWT issuer: %s
\n", iss)
+ fmt.Fprintf(w, "JWT claims required: %v
\n", cfg.requiredClaims)
+ fmt.Fprintf(w, "JWT global write claims required: %v
\n", cfg.globalWriteClaims)
+ }
fmt.Fprintf(w, "Number of sessions: %d
\n", len(sessions))
fmt.Fprintf(w, "\n")
diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go
index a119d05..a46d92e 100644
--- a/gocached/gocached_test.go
+++ b/gocached/gocached_test.go
@@ -722,8 +722,11 @@ func TestExchangeToken(t *testing.T) {
t.Run(name, func(t *testing.T) {
issuer, createJWT := startOIDCServer(t, privateKey.Public())
st := newServerTester(t,
- WithJWTAuth(issuer, wantClaims),
- WithGlobalNamespaceJWTClaims(wantGlobalClaims),
+ WithJWTAuth(JWTIssuerConfig{
+ Issuer: issuer,
+ RequiredClaims: wantClaims,
+ GlobalWriteClaims: wantGlobalClaims,
+ }),
)
// Generate JWT.
@@ -837,6 +840,175 @@ func TestExchangeToken(t *testing.T) {
}
}
+func TestMultiIssuerAuth(t *testing.T) {
+ // Generate separate keys for each issuer.
+ keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatalf("error generating key A: %v", err)
+ }
+ keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatalf("error generating key B: %v", err)
+ }
+ keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatalf("error generating key C: %v", err)
+ }
+
+ issuerA, createJWTA := startOIDCServer(t, keyA.Public())
+ issuerB, createJWTB := startOIDCServer(t, keyB.Public())
+ issuerC, createJWTC := startOIDCServer(t, keyC.Public())
+
+ st := newServerTester(t,
+ WithJWTAuth(
+ JWTIssuerConfig{
+ Issuer: issuerA,
+ RequiredClaims: map[string]string{"sub": "userA"},
+ GlobalWriteClaims: map[string]string{
+ "sub": "userA",
+ "ref": "refs/heads/main",
+ },
+ },
+ JWTIssuerConfig{
+ Issuer: issuerB,
+ RequiredClaims: map[string]string{"sub": "userB"},
+ GlobalWriteClaims: map[string]string{
+ "sub": "userB",
+ "ref": "refs/heads/main",
+ },
+ },
+ ),
+ )
+
+ makeJWTBody := func(jwtString string) []byte {
+ body, err := json.Marshal(map[string]any{"jwt": jwtString})
+ if err != nil {
+ t.Fatalf("error marshaling request body: %v", err)
+ }
+ return body
+ }
+
+ exchangeToken := func(jwtBody []byte) (*http.Response, string) {
+ t.Helper()
+ req, err := http.NewRequest("POST", st.hs.URL+"/auth/exchange-token", bytes.NewReader(jwtBody))
+ if err != nil {
+ t.Fatalf("error creating request: %v", err)
+ }
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("error making request: %v", err)
+ }
+ body, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ t.Fatalf("error reading response body: %v", err)
+ }
+ if resp.StatusCode == http.StatusOK {
+ var d struct {
+ AccessToken string `json:"access_token"`
+ }
+ if err := json.Unmarshal(body, &d); err != nil {
+ t.Fatalf("error decoding response body: %v", err)
+ }
+ return resp, d.AccessToken
+ }
+ return resp, ""
+ }
+
+ baseClaims := func(iss string) jwt.MapClaims {
+ return jwt.MapClaims{
+ "iss": iss,
+ "aud": gocachedAudience,
+ "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)),
+ "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)),
+ }
+ }
+
+ // Issuer A: valid read-only token.
+ t.Run("issuerA_read", func(t *testing.T) {
+ claims := baseClaims(issuerA)
+ claims["sub"] = "userA"
+ resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA)))
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
+ }
+ if accessToken == "" {
+ t.Fatal("expected access token")
+ }
+
+ cl := st.mkClient()
+ cl.AccessToken = accessToken
+ st.wantGetMiss(cl, "aabb01")
+ })
+
+ // Issuer A: valid write token.
+ t.Run("issuerA_write", func(t *testing.T) {
+ claims := baseClaims(issuerA)
+ claims["sub"] = "userA"
+ claims["ref"] = "refs/heads/main"
+ resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA)))
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
+ }
+ cl := st.mkClient()
+ cl.AccessToken = accessToken
+ st.wantPut(cl, "aabb02", "ccdd02", "hello-from-A")
+ st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A")
+ })
+
+ // Issuer B: valid read-only token.
+ t.Run("issuerB_read", func(t *testing.T) {
+ claims := baseClaims(issuerB)
+ claims["sub"] = "userB"
+ resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB)))
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
+ }
+ if accessToken == "" {
+ t.Fatal("expected access token")
+ }
+ cl := st.mkClient()
+ cl.AccessToken = accessToken
+ // Can read data written by issuer A.
+ st.wantGet(cl, "aabb02", "ccdd02", "hello-from-A")
+ })
+
+ // Issuer B: valid write token.
+ t.Run("issuerB_write", func(t *testing.T) {
+ claims := baseClaims(issuerB)
+ claims["sub"] = "userB"
+ claims["ref"] = "refs/heads/main"
+ resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB)))
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode)
+ }
+ cl := st.mkClient()
+ cl.AccessToken = accessToken
+ st.wantPut(cl, "aabb03", "ccdd03", "hello-from-B")
+ st.wantGet(cl, "aabb03", "ccdd03", "hello-from-B")
+ })
+
+ // Issuer B: wrong required claims (sub doesn't match).
+ t.Run("issuerB_wrong_sub", func(t *testing.T) {
+ claims := baseClaims(issuerB)
+ claims["sub"] = "userA" // issuer B requires sub=userB
+ resp, _ := exchangeToken(makeJWTBody(createJWTB(claims, keyB)))
+ if resp.StatusCode != http.StatusUnauthorized {
+ t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode)
+ }
+ })
+
+ // Issuer C: not configured, should be rejected.
+ t.Run("issuerC_rejected", func(t *testing.T) {
+ claims := baseClaims(issuerC)
+ claims["sub"] = "userC"
+ resp, _ := exchangeToken(makeJWTBody(createJWTC(claims, keyC)))
+ if resp.StatusCode != http.StatusUnauthorized {
+ t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode)
+ }
+ })
+}
+
func BenchmarkFlushAccessTimes(b *testing.B) {
st := newServerTester(b, WithVerbose(false))
s := st.srv
diff --git a/gocached/internal/jwt/jwt.go b/gocached/internal/jwt/jwt.go
index 31f13a0..d572b81 100644
--- a/gocached/internal/jwt/jwt.go
+++ b/gocached/internal/jwt/jwt.go
@@ -6,6 +6,7 @@ package jwt
import (
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
@@ -15,6 +16,7 @@ import (
"sync/atomic"
"time"
+ "github.com/bradfitz/go-tool-cache/gocached/logger"
"github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5"
)
@@ -26,30 +28,64 @@ var (
supportedAlgorithms = []string{"HS256", "RS256", "ES256"}
)
+// issuer holds the per-issuer state: its own JWT parser and signing keys.
+type issuer struct {
+ iss string
+ parser *jwt.Parser
+ signingKeys atomic.Value // []jose.JSONWebKey
+}
+
+// keyFunc is how github.com/golang-jwt/jwt gets the public key it needs to
+// verify a JWT signature. Each issuer entry only checks its own keys.
+func (ie *issuer) keyFunc(t *jwt.Token) (any, error) {
+ var kid string
+ if v, ok := t.Header["kid"]; ok {
+ kid, _ = v.(string)
+ }
+ if kid == "" {
+ return nil, fmt.Errorf("no kid found in token header")
+ }
+
+ signingKeys := ie.signingKeys.Load().([]jose.JSONWebKey)
+ for _, k := range signingKeys {
+ if k.KeyID == kid {
+ return k.Key, nil
+ }
+ }
+
+ return nil, fmt.Errorf("unknown key ID: %s", kid)
+}
+
// NewJWTValidator constructs a [Validator] for validating JWTs. Must call
// [RunUpdateJWKSLoop] before validating any JWTs. Every JWT must exactly match
-// the provided issuer and audience values in its "iss" and "aud" claims
-// respectively. The issuer must be a reachable HTTP server that serves the JWT
-// public signing keys via the path defined by [oidcConfigWellKnownPath], and
-// the audience should be a value specific to the trust boundary that gocached
-// resides within.
-func NewJWTValidator(logf func(format string, args ...any), issuer, audience string) *Validator {
+// one of the provided issuers and the audience value in its "iss" and "aud"
+// claims respectively. Each issuer must be a reachable HTTP server that serves
+// the JWT public signing keys via the path defined by [oidcConfigWellKnownPath],
+// and the audience should be a value specific to the trust boundary that
+// gocached resides within.
+func NewJWTValidator(logf logger.Logf, audience string, issuerURLs []string) *Validator {
+ var issuers []*issuer
+ for _, iss := range issuerURLs {
+ issuers = append(issuers, &issuer{
+ iss: iss,
+ parser: jwt.NewParser(
+ jwt.WithValidMethods(supportedAlgorithms),
+ jwt.WithIssuer(iss),
+ jwt.WithAudience(audience),
+ jwt.WithLeeway(10*time.Second),
+ jwt.WithIssuedAt(),
+ ),
+ })
+ }
return &Validator{
- logf: logf,
- issuer: issuer,
- parser: jwt.NewParser(
- jwt.WithValidMethods(supportedAlgorithms),
- jwt.WithIssuer(issuer),
- jwt.WithAudience(audience),
- jwt.WithLeeway(10*time.Second),
- jwt.WithIssuedAt(),
- ),
+ logf: logf,
+ issuers: issuers,
}
}
// RunUpdateJWKSLoop fetches the JWKS synchronously once to surface any config
// errors early, and then starts a background goroutine that periodically fetches
-// the JWKS from the issuer to keep the signing keys up to date. Must be called
+// the JWKS from all issuers to keep the signing keys up to date. Must be called
// before validating any JWTs.
func (v *Validator) RunUpdateJWKSLoop(ctx context.Context) error {
// Initial fetch to error early on misconfiguration.
@@ -65,56 +101,43 @@ func (v *Validator) RunUpdateJWKSLoop(ctx context.Context) error {
// Validator provides methods for validating JWTs. Use [NewJWTValidator] to
// construct a working Validator.
type Validator struct {
- logf func(format string, args ...any)
- issuer string
- parser *jwt.Parser
-
- signingKeys atomic.Value // []jose.JSONWebKey
+ logf logger.Logf
+ issuers []*issuer
// TODO(tomhjp): metrics
}
// Validate returns an error if the provided JWT fails validation for an invalid
-// signature or standard claim (iss, aud, iat, nbf, exp). It returns the token's
-// verified claims if validation succeeds. The caller should then make policy
-// decisions based on other claims such as "sub" or other custom claims.
+// signature or standard claim (iss, aud, iat, nbf, exp). It tries each
+// configured issuer and returns the verified claims from the first successful
+// parse. If all issuers fail, it returns the last error.
func (v *Validator) Validate(ctx context.Context, jwtString string) (map[string]any, error) {
- tk, err := v.parser.Parse(jwtString, v.keyFunc)
- if err != nil {
- return nil, fmt.Errorf("failed to parse token: %w", err)
- }
-
- if !tk.Valid {
- return nil, fmt.Errorf("invalid token")
- }
+ var lastErr error
+ for _, ie := range v.issuers {
+ tk, err := ie.parser.Parse(jwtString, ie.keyFunc)
+ if err != nil {
+ lastErr = err
+ continue
+ }
- gotClaims, ok := tk.Claims.(jwt.MapClaims)
- if !ok {
- return nil, fmt.Errorf("unexpected claims type: %T", tk.Claims)
- }
+ if !tk.Valid {
+ lastErr = fmt.Errorf("invalid token")
+ continue
+ }
- return gotClaims, nil
-}
+ gotClaims, ok := tk.Claims.(jwt.MapClaims)
+ if !ok {
+ lastErr = fmt.Errorf("unexpected claims type: %T", tk.Claims)
+ continue
+ }
-// keyFunc is how github.com/golang-jwt/jwt gets the public key it needs to
-// verify a JWT signature.
-func (v *Validator) keyFunc(t *jwt.Token) (any, error) {
- var kid string
- if v, ok := t.Header["kid"]; ok {
- kid, _ = v.(string)
- }
- if kid == "" {
- return nil, fmt.Errorf("no kid found in token header")
+ return gotClaims, nil
}
- signingKeys := v.signingKeys.Load().([]jose.JSONWebKey)
- for _, k := range signingKeys {
- if k.KeyID == kid {
- return k.Key, nil
- }
+ if lastErr != nil {
+ return nil, fmt.Errorf("failed to parse token: %w", lastErr)
}
-
- return nil, fmt.Errorf("unknown key ID: %s", kid)
+ return nil, fmt.Errorf("no issuers configured")
}
func (v *Validator) runUpdateJWKSLoop(ctx context.Context) {
@@ -135,10 +158,20 @@ func (v *Validator) runUpdateJWKSLoop(ctx context.Context) {
}
func (v *Validator) updateJWKS(ctx context.Context) error {
- v.logf("jwt: fetching JWKS from issuer %q", v.issuer)
- u, err := url.Parse(v.issuer)
+ var errs []error
+ for _, ie := range v.issuers {
+ if err := ie.updateJWKS(ctx, v.logf); err != nil {
+ errs = append(errs, fmt.Errorf("issuer %q: %w", ie.iss, err))
+ }
+ }
+ return errors.Join(errs...)
+}
+
+func (ie *issuer) updateJWKS(ctx context.Context, logf logger.Logf) error {
+ logf("jwt: fetching JWKS from issuer %q", ie.iss)
+ u, err := url.Parse(ie.iss)
if err != nil {
- return fmt.Errorf("failed to parse issuer URL %q: %w", v.issuer, err)
+ return fmt.Errorf("failed to parse issuer URL %q: %w", ie.iss, err)
}
u.Path = path.Join(u.Path, oidcConfigWellKnownPath)
@@ -174,7 +207,7 @@ func (v *Validator) updateJWKS(ctx context.Context) error {
signingKeys = append(signingKeys, k)
}
- v.signingKeys.Store(signingKeys)
+ ie.signingKeys.Store(signingKeys)
return nil
}
diff --git a/gocached/logger/logger.go b/gocached/logger/logger.go
new file mode 100644
index 0000000..69b955e
--- /dev/null
+++ b/gocached/logger/logger.go
@@ -0,0 +1,7 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package logger
+
+// Logf is a logging function type. It is implemented by log.Printf.
+type Logf func(format string, args ...any)