From 3fa349df8caa59f73b8d08679897caff572f47da Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Sun, 8 Feb 2026 22:22:22 +0100 Subject: [PATCH 1/2] Add stainless-proxy: JWE-based credential proxy prototype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stateless proxy that decrypts JWE tokens and injects raw API credentials at the last hop before the target API. Sandboxes running agent-generated code never see raw credentials — only encrypted JWEs that are host-locked to prevent exfiltration. Uses ECDH-ES+A256KW with A256GCM on P-256 keys (go-jose/v4). Includes host locking, key rotation via JWKS, revocation deny-list, per-request tracing, runtime/secret.Do integration for memory zeroing, and manual plaintext zeroing as a cross-platform fallback. 14 integration test scenarios stress-testing credential isolation, exfiltration attempts, key rotation, concurrency, SSE streaming, token expiry, path traversal, and more. Benchmarks covering throughput, latency percentiles, payload sizes, and deny-list scaling. --- stainless-proxy/cmd/mint/main.go | 108 +++ stainless-proxy/cmd/stainless-proxy/main.go | 88 +++ stainless-proxy/config.example.json | 9 + stainless-proxy/go.mod | 14 + stainless-proxy/go.sum | 12 + stainless-proxy/internal/config/config.go | 112 +++ .../internal/hostmatch/hostmatch.go | 31 + .../internal/hostmatch/hostmatch_test.go | 34 + stainless-proxy/internal/jwe/jwe.go | 152 ++++ stainless-proxy/internal/jwe/jwe_test.go | 151 ++++ stainless-proxy/internal/keystore/keystore.go | 180 +++++ stainless-proxy/internal/proxy/bench_test.go | 579 +++++++++++++++ .../internal/proxy/integration_test.go | 692 ++++++++++++++++++ stainless-proxy/internal/proxy/proxy.go | 280 +++++++ stainless-proxy/internal/proxy/proxy_test.go | 219 ++++++ stainless-proxy/internal/proxy/trace.go | 44 ++ .../internal/revocation/revocation.go | 66 ++ .../internal/revocation/revocation_test.go | 35 + .../internal/secretutil/do_default.go | 5 + .../internal/secretutil/do_experiment.go | 7 + stainless-proxy/internal/server/handlers.go | 129 ++++ stainless-proxy/internal/server/middleware.go | 90 +++ stainless-proxy/internal/server/server.go | 90 +++ 23 files changed, 3127 insertions(+) create mode 100644 stainless-proxy/cmd/mint/main.go create mode 100644 stainless-proxy/cmd/stainless-proxy/main.go create mode 100644 stainless-proxy/config.example.json create mode 100644 stainless-proxy/go.mod create mode 100644 stainless-proxy/go.sum create mode 100644 stainless-proxy/internal/config/config.go create mode 100644 stainless-proxy/internal/hostmatch/hostmatch.go create mode 100644 stainless-proxy/internal/hostmatch/hostmatch_test.go create mode 100644 stainless-proxy/internal/jwe/jwe.go create mode 100644 stainless-proxy/internal/jwe/jwe_test.go create mode 100644 stainless-proxy/internal/keystore/keystore.go create mode 100644 stainless-proxy/internal/proxy/bench_test.go create mode 100644 stainless-proxy/internal/proxy/integration_test.go create mode 100644 stainless-proxy/internal/proxy/proxy.go create mode 100644 stainless-proxy/internal/proxy/proxy_test.go create mode 100644 stainless-proxy/internal/proxy/trace.go create mode 100644 stainless-proxy/internal/revocation/revocation.go create mode 100644 stainless-proxy/internal/revocation/revocation_test.go create mode 100644 stainless-proxy/internal/secretutil/do_default.go create mode 100644 stainless-proxy/internal/secretutil/do_experiment.go create mode 100644 stainless-proxy/internal/server/handlers.go create mode 100644 stainless-proxy/internal/server/middleware.go create mode 100644 stainless-proxy/internal/server/server.go diff --git a/stainless-proxy/cmd/mint/main.go b/stainless-proxy/cmd/mint/main.go new file mode 100644 index 0000000..026dcee --- /dev/null +++ b/stainless-proxy/cmd/mint/main.go @@ -0,0 +1,108 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stainless-api/stainless-proxy/internal/jwe" + + "crypto/ecdsa" +) + +type credFlag []string + +func (f *credFlag) String() string { return strings.Join(*f, ", ") } +func (f *credFlag) Set(v string) error { + *f = append(*f, v) + return nil +} + +func main() { + jwksURL := flag.String("jwks-url", "", "URL to fetch JWKS from") + expDuration := flag.String("exp", "1h", "expiration duration") + hosts := flag.String("hosts", "", "comma-separated allowed hosts") + var creds credFlag + flag.Var(&creds, "cred", "credential as Header=Value (repeatable)") + flag.Parse() + + if *jwksURL == "" { + fmt.Fprintln(os.Stderr, "error: -jwks-url is required") + os.Exit(1) + } + if *hosts == "" { + fmt.Fprintln(os.Stderr, "error: -hosts is required") + os.Exit(1) + } + if len(creds) == 0 { + fmt.Fprintln(os.Stderr, "error: at least one -cred is required") + os.Exit(1) + } + + duration, err := time.ParseDuration(*expDuration) + if err != nil { + fmt.Fprintf(os.Stderr, "error: invalid expiration: %v\n", err) + os.Exit(1) + } + + // Fetch JWKS + resp, err := http.Get(*jwksURL) + if err != nil { + fmt.Fprintf(os.Stderr, "error: fetching JWKS: %v\n", err) + os.Exit(1) + } + defer resp.Body.Close() + + var jwks jose.JSONWebKeySet + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { + fmt.Fprintf(os.Stderr, "error: parsing JWKS: %v\n", err) + os.Exit(1) + } + + if len(jwks.Keys) == 0 { + fmt.Fprintln(os.Stderr, "error: no keys in JWKS") + os.Exit(1) + } + + // Use the first key + jwk := jwks.Keys[0] + pubKey, ok := jwk.Key.(*ecdsa.PublicKey) + if !ok { + fmt.Fprintln(os.Stderr, "error: first key is not an ECDSA public key") + os.Exit(1) + } + + // Parse credentials + var credentials []jwe.Credential + for _, c := range creds { + eqIdx := strings.IndexByte(c, '=') + if eqIdx == -1 { + fmt.Fprintf(os.Stderr, "error: invalid credential format: %s (expected Header=Value)\n", c) + os.Exit(1) + } + credentials = append(credentials, jwe.Credential{ + Header: c[:eqIdx], + Value: c[eqIdx+1:], + }) + } + + payload := jwe.Payload{ + Exp: time.Now().Add(duration).Unix(), + AllowedHosts: strings.Split(*hosts, ","), + Credentials: credentials, + } + + enc := jwe.NewEncryptor(pubKey, jwk.KeyID) + token, err := enc.Encrypt(payload) + if err != nil { + fmt.Fprintf(os.Stderr, "error: encrypting: %v\n", err) + os.Exit(1) + } + + fmt.Println(token) +} diff --git a/stainless-proxy/cmd/stainless-proxy/main.go b/stainless-proxy/cmd/stainless-proxy/main.go new file mode 100644 index 0000000..bc22669 --- /dev/null +++ b/stainless-proxy/cmd/stainless-proxy/main.go @@ -0,0 +1,88 @@ +package main + +import ( + "context" + "flag" + "log/slog" + "os" + + "github.com/stainless-api/stainless-proxy/internal/config" + "github.com/stainless-api/stainless-proxy/internal/jwe" + "github.com/stainless-api/stainless-proxy/internal/keystore" + "github.com/stainless-api/stainless-proxy/internal/proxy" + "github.com/stainless-api/stainless-proxy/internal/revocation" + "github.com/stainless-api/stainless-proxy/internal/server" +) + +func main() { + configPath := flag.String("config", "", "path to config file") + flag.Parse() + + if *configPath == "" { + slog.Error("config flag is required") + os.Exit(1) + } + + cfg, err := config.Load(*configPath) + if err != nil { + slog.Error("loading config", "error", err) + os.Exit(1) + } + + setupLogging(cfg) + + ks, err := keystore.New(cfg.KeyDir, cfg.GenerateKeys) + if err != nil { + slog.Error("initializing keystore", "error", err) + os.Exit(1) + } + + primary := ks.PrimaryKey() + slog.Info("keystore initialized", + "key_count", len(ks.Keys()), + "primary_kid", primary.KID, + ) + + var decryptorKeys []jwe.KeyEntry + for _, k := range ks.Keys() { + decryptorKeys = append(decryptorKeys, jwe.KeyEntry{ + KID: k.KID, + PrivateKey: k.PrivateKey, + }) + } + + decryptor := jwe.NewMultiKeyDecryptor(decryptorKeys) + denyList := revocation.NewDenyList() + p := proxy.New(decryptor, denyList) + srv := server.New(cfg, ks, p, denyList) + + if err := srv.Run(context.Background()); err != nil { + slog.Error("server error", "error", err) + os.Exit(1) + } +} + +func setupLogging(cfg *config.Config) { + var level slog.Level + switch cfg.LogLevel { + case "debug": + level = slog.LevelDebug + case "warn": + level = slog.LevelWarn + case "error": + level = slog.LevelError + default: + level = slog.LevelInfo + } + + opts := &slog.HandlerOptions{Level: level} + + var handler slog.Handler + if cfg.LogFormat == "json" { + handler = slog.NewJSONHandler(os.Stderr, opts) + } else { + handler = slog.NewTextHandler(os.Stderr, opts) + } + + slog.SetDefault(slog.New(handler)) +} diff --git a/stainless-proxy/config.example.json b/stainless-proxy/config.example.json new file mode 100644 index 0000000..8c8d0a2 --- /dev/null +++ b/stainless-proxy/config.example.json @@ -0,0 +1,9 @@ +{ + "addr": ":8443", + "keyDir": "./keys", + "generateKeys": true, + "mintEnabled": true, + "mintSecret": {"$env": "MINT_SECRET"}, + "logLevel": "debug", + "logFormat": "text" +} diff --git a/stainless-proxy/go.mod b/stainless-proxy/go.mod new file mode 100644 index 0000000..b1bf2d6 --- /dev/null +++ b/stainless-proxy/go.mod @@ -0,0 +1,14 @@ +module github.com/stainless-api/stainless-proxy + +go 1.26rc1 + +require ( + github.com/go-jose/go-jose/v4 v4.1.3 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/stainless-proxy/go.sum b/stainless-proxy/go.sum new file mode 100644 index 0000000..eb1f89f --- /dev/null +++ b/stainless-proxy/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/stainless-proxy/internal/config/config.go b/stainless-proxy/internal/config/config.go new file mode 100644 index 0000000..745d830 --- /dev/null +++ b/stainless-proxy/internal/config/config.go @@ -0,0 +1,112 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" +) + +type Secret string + +func (s Secret) String() string { + if s == "" { + return "" + } + return "***" +} + +func (s Secret) MarshalJSON() ([]byte, error) { + if s == "" { + return json.Marshal("") + } + return json.Marshal("***") +} + +type Config struct { + Addr string `json:"addr"` + KeyDir string `json:"keyDir"` + GenerateKeys bool `json:"generateKeys"` + MintEnabled bool `json:"mintEnabled"` + MintSecret Secret `json:"mintSecret"` + LogLevel string `json:"logLevel"` + LogFormat string `json:"logFormat"` +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading config: %w", err) + } + + var raw rawConfig + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + cfg := &Config{ + Addr: raw.Addr, + KeyDir: raw.KeyDir, + GenerateKeys: raw.GenerateKeys, + MintEnabled: raw.MintEnabled, + LogLevel: raw.LogLevel, + LogFormat: raw.LogFormat, + } + + if raw.MintSecret != nil { + secret, err := parseConfigValue(raw.MintSecret) + if err != nil { + return nil, fmt.Errorf("parsing mintSecret: %w", err) + } + cfg.MintSecret = Secret(secret) + } + + if cfg.Addr == "" { + cfg.Addr = ":8443" + } + if cfg.LogLevel == "" { + cfg.LogLevel = "info" + } + if cfg.LogFormat == "" { + cfg.LogFormat = "text" + } + + return cfg, nil +} + +type rawConfig struct { + Addr string `json:"addr"` + KeyDir string `json:"keyDir"` + GenerateKeys bool `json:"generateKeys"` + MintEnabled bool `json:"mintEnabled"` + MintSecret json.RawMessage `json:"mintSecret"` + LogLevel string `json:"logLevel"` + LogFormat string `json:"logFormat"` +} + +func parseConfigValue(raw json.RawMessage) (string, error) { + var str string + if err := json.Unmarshal(raw, &str); err == nil { + return str, nil + } + + var ref map[string]string + if err := json.Unmarshal(raw, &ref); err != nil { + return "", fmt.Errorf("config value must be string or reference object") + } + + if envVar, ok := ref["$env"]; ok { + value := os.Getenv(envVar) + if value == "" { + return "", fmt.Errorf("environment variable %s not set", envVar) + } + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + return value, nil + } + + return "", fmt.Errorf("unknown reference type in config value") +} diff --git a/stainless-proxy/internal/hostmatch/hostmatch.go b/stainless-proxy/internal/hostmatch/hostmatch.go new file mode 100644 index 0000000..ddc55cc --- /dev/null +++ b/stainless-proxy/internal/hostmatch/hostmatch.go @@ -0,0 +1,31 @@ +package hostmatch + +import ( + "net" + "strings" +) + +func Match(host string, patterns []string) bool { + host = stripPort(strings.ToLower(host)) + for _, p := range patterns { + p = stripPort(strings.ToLower(p)) + if p == host { + return true + } + if strings.HasPrefix(p, "*.") { + suffix := p[1:] // ".example.com" + if strings.HasSuffix(host, suffix) && host != suffix[1:] { + return true + } + } + } + return false +} + +func stripPort(host string) string { + h, _, err := net.SplitHostPort(host) + if err != nil { + return host + } + return h +} diff --git a/stainless-proxy/internal/hostmatch/hostmatch_test.go b/stainless-proxy/internal/hostmatch/hostmatch_test.go new file mode 100644 index 0000000..8c43b81 --- /dev/null +++ b/stainless-proxy/internal/hostmatch/hostmatch_test.go @@ -0,0 +1,34 @@ +package hostmatch + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMatch(t *testing.T) { + tests := []struct { + name string + host string + patterns []string + want bool + }{ + {"exact match", "api.example.com", []string{"api.example.com"}, true}, + {"no match", "api.example.com", []string{"other.com"}, false}, + {"wildcard match", "api.example.com", []string{"*.example.com"}, true}, + {"wildcard no match on bare domain", "example.com", []string{"*.example.com"}, false}, + {"wildcard deep subdomain", "deep.api.example.com", []string{"*.example.com"}, true}, + {"case insensitive", "API.Example.COM", []string{"api.example.com"}, true}, + {"strip port", "api.example.com:443", []string{"api.example.com"}, true}, + {"strip port from pattern", "api.example.com", []string{"api.example.com:443"}, true}, + {"empty patterns", "api.example.com", nil, false}, + {"empty host", "", []string{"api.example.com"}, false}, + {"multiple patterns", "api.example.com", []string{"other.com", "*.example.com"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, Match(tt.host, tt.patterns)) + }) + } +} diff --git a/stainless-proxy/internal/jwe/jwe.go b/stainless-proxy/internal/jwe/jwe.go new file mode 100644 index 0000000..b5bb35c --- /dev/null +++ b/stainless-proxy/internal/jwe/jwe.go @@ -0,0 +1,152 @@ +package jwe + +import ( + "crypto/ecdsa" + "encoding/json" + "fmt" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stainless-api/stainless-proxy/internal/secretutil" +) + +type Credential struct { + Header string `json:"header"` + Value string `json:"value"` +} + +type Payload struct { + Exp int64 `json:"exp"` + AllowedHosts []string `json:"allowed_hosts"` + Credentials []Credential `json:"credentials"` +} + +type Encryptor struct { + publicKey *ecdsa.PublicKey + kid string +} + +func NewEncryptor(publicKey *ecdsa.PublicKey, kid string) *Encryptor { + return &Encryptor{publicKey: publicKey, kid: kid} +} + +func (e *Encryptor) Encrypt(payload Payload) (string, error) { + plaintext, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshaling payload: %w", err) + } + + recipient := jose.Recipient{ + Algorithm: jose.ECDH_ES_A256KW, + Key: e.publicKey, + KeyID: e.kid, + } + + enc, err := jose.NewEncrypter( + jose.A256GCM, + recipient, + (&jose.EncrypterOptions{}).WithContentType("JWT"), + ) + if err != nil { + return "", fmt.Errorf("creating encrypter: %w", err) + } + + obj, err := enc.Encrypt(plaintext) + if err != nil { + return "", fmt.Errorf("encrypting: %w", err) + } + + return obj.CompactSerialize() +} + +type Decryptor interface { + Decrypt(token string) (*Payload, error) +} + +type KeyEntry struct { + KID string + PrivateKey *ecdsa.PrivateKey +} + +type MultiKeyDecryptor struct { + keys map[string]*ecdsa.PrivateKey +} + +func NewMultiKeyDecryptor(keys []KeyEntry) *MultiKeyDecryptor { + m := make(map[string]*ecdsa.PrivateKey, len(keys)) + for _, k := range keys { + m[k.KID] = k.PrivateKey + } + return &MultiKeyDecryptor{keys: m} +} + +// decryptAndValidate decrypts the JWE and validates the payload. +// This function is isolated so it can be wrapped in runtime/secret.Do() +// once Go 1.26 is stable — all plaintext credential handling happens here. +func (d *MultiKeyDecryptor) decryptAndValidate(token string) (*Payload, error) { + obj, err := jose.ParseEncrypted(token, + []jose.KeyAlgorithm{jose.ECDH_ES_A256KW}, + []jose.ContentEncryption{jose.A256GCM}, + ) + if err != nil { + return nil, fmt.Errorf("parsing JWE: %w", err) + } + + kid := obj.Header.KeyID + key, ok := d.keys[kid] + if !ok { + return nil, fmt.Errorf("unknown key ID: %s", kid) + } + + plaintext, err := obj.Decrypt(key) + if err != nil { + return nil, fmt.Errorf("decrypting: %w", err) + } + + var payload Payload + if err := json.Unmarshal(plaintext, &payload); err != nil { + zeroBytes(plaintext) + return nil, fmt.Errorf("unmarshaling payload: %w", err) + } + zeroBytes(plaintext) + + if payload.Exp > 0 && time.Now().Unix() > payload.Exp { + return nil, fmt.Errorf("token expired") + } + + if len(payload.AllowedHosts) == 0 { + return nil, fmt.Errorf("allowed_hosts is required") + } + + if len(payload.Credentials) == 0 { + return nil, fmt.Errorf("credentials is required") + } + + return &payload, nil +} + +func (d *MultiKeyDecryptor) Decrypt(token string) (*Payload, error) { + var ( + payload *Payload + decErr error + ) + // secret.Do zeros stack frames, registers, and (with GOEXPERIMENT=runtimesecret) + // marks heap allocations for zeroing on GC. This covers: + // - plaintext []byte from go-jose's obj.Decrypt() + // - go-jose's internal AES-GCM / ECDH intermediate buffers + // - json.Unmarshal's working memory + // The returned *Payload escapes the closure (still reachable), but its + // string fields will be zeroed by GC once the Payload becomes unreachable. + secretutil.Do(func() { + payload, decErr = d.decryptAndValidate(token) + }) + return payload, decErr +} + +// zeroBytes overwrites b with zeros. Called immediately after unmarshal +// so the raw JSON containing credential values doesn't linger on the heap. +// +//go:noinline +func zeroBytes(b []byte) { + clear(b) +} diff --git a/stainless-proxy/internal/jwe/jwe_test.go b/stainless-proxy/internal/jwe/jwe_test.go new file mode 100644 index 0000000..f7dd06a --- /dev/null +++ b/stainless-proxy/internal/jwe/jwe_test.go @@ -0,0 +1,151 @@ +package jwe + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateTestKey(t *testing.T) *ecdsa.PrivateKey { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return key +} + +func TestRoundTrip(t *testing.T) { + key := generateTestKey(t) + kid := "test-key-1" + + enc := NewEncryptor(&key.PublicKey, kid) + dec := NewMultiKeyDecryptor([]KeyEntry{{KID: kid, PrivateKey: key}}) + + payload := Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{"api.example.com"}, + Credentials: []Credential{ + {Header: "Authorization", Value: "Bearer secret-token"}, + {Header: "X-Api-Key", Value: "key-123"}, + }, + } + + token, err := enc.Encrypt(payload) + require.NoError(t, err) + assert.NotEmpty(t, token) + + got, err := dec.Decrypt(token) + require.NoError(t, err) + assert.Equal(t, payload.Exp, got.Exp) + assert.Equal(t, payload.AllowedHosts, got.AllowedHosts) + assert.Equal(t, payload.Credentials, got.Credentials) +} + +func TestExpiredToken(t *testing.T) { + key := generateTestKey(t) + kid := "test-key-1" + + enc := NewEncryptor(&key.PublicKey, kid) + dec := NewMultiKeyDecryptor([]KeyEntry{{KID: kid, PrivateKey: key}}) + + payload := Payload{ + Exp: time.Now().Add(-time.Hour).Unix(), + AllowedHosts: []string{"api.example.com"}, + Credentials: []Credential{{Header: "Authorization", Value: "Bearer x"}}, + } + + token, err := enc.Encrypt(payload) + require.NoError(t, err) + + _, err = dec.Decrypt(token) + assert.ErrorContains(t, err, "expired") +} + +func TestWrongKey(t *testing.T) { + key1 := generateTestKey(t) + key2 := generateTestKey(t) + + enc := NewEncryptor(&key1.PublicKey, "key1") + dec := NewMultiKeyDecryptor([]KeyEntry{{KID: "key2", PrivateKey: key2}}) + + payload := Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{"api.example.com"}, + Credentials: []Credential{{Header: "Authorization", Value: "Bearer x"}}, + } + + token, err := enc.Encrypt(payload) + require.NoError(t, err) + + _, err = dec.Decrypt(token) + assert.Error(t, err) +} + +func TestMultipleKeys(t *testing.T) { + key1 := generateTestKey(t) + key2 := generateTestKey(t) + + dec := NewMultiKeyDecryptor([]KeyEntry{ + {KID: "key1", PrivateKey: key1}, + {KID: "key2", PrivateKey: key2}, + }) + + // Encrypt with key2 + enc := NewEncryptor(&key2.PublicKey, "key2") + payload := Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{"api.example.com"}, + Credentials: []Credential{{Header: "X-Key", Value: "val"}}, + } + + token, err := enc.Encrypt(payload) + require.NoError(t, err) + + got, err := dec.Decrypt(token) + require.NoError(t, err) + assert.Equal(t, "val", got.Credentials[0].Value) +} + +func TestMissingAllowedHosts(t *testing.T) { + key := generateTestKey(t) + kid := "test-key-1" + + enc := NewEncryptor(&key.PublicKey, kid) + dec := NewMultiKeyDecryptor([]KeyEntry{{KID: kid, PrivateKey: key}}) + + payload := Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: nil, + Credentials: []Credential{{Header: "X-Key", Value: "val"}}, + } + + token, err := enc.Encrypt(payload) + require.NoError(t, err) + + _, err = dec.Decrypt(token) + assert.ErrorContains(t, err, "allowed_hosts") +} + +func TestMissingCredentials(t *testing.T) { + key := generateTestKey(t) + kid := "test-key-1" + + enc := NewEncryptor(&key.PublicKey, kid) + dec := NewMultiKeyDecryptor([]KeyEntry{{KID: kid, PrivateKey: key}}) + + payload := Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{"api.example.com"}, + Credentials: nil, + } + + token, err := enc.Encrypt(payload) + require.NoError(t, err) + + _, err = dec.Decrypt(token) + assert.ErrorContains(t, err, "credentials") +} diff --git a/stainless-proxy/internal/keystore/keystore.go b/stainless-proxy/internal/keystore/keystore.go new file mode 100644 index 0000000..d3cf3ff --- /dev/null +++ b/stainless-proxy/internal/keystore/keystore.go @@ -0,0 +1,180 @@ +package keystore + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "os" + "path/filepath" + "sort" + "time" + + "github.com/go-jose/go-jose/v4" +) + +type KeyEntry struct { + KID string + PrivateKey *ecdsa.PrivateKey + PublicKey *ecdsa.PublicKey + CreatedAt time.Time +} + +type KeyStore struct { + keys []KeyEntry +} + +func New(keyDir string, generateIfEmpty bool) (*KeyStore, error) { + ks := &KeyStore{} + + if keyDir != "" { + if err := ks.loadFromDir(keyDir); err != nil { + return nil, fmt.Errorf("loading keys: %w", err) + } + } + + if len(ks.keys) == 0 { + if !generateIfEmpty { + return nil, fmt.Errorf("no keys found and generation disabled") + } + entry, err := generateKey() + if err != nil { + return nil, fmt.Errorf("generating key: %w", err) + } + ks.keys = append(ks.keys, entry) + + if keyDir != "" { + if err := os.MkdirAll(keyDir, 0700); err != nil { + return nil, fmt.Errorf("creating key directory: %w", err) + } + path := filepath.Join(keyDir, fmt.Sprintf("key-%s.pem", time.Now().Format("2006-01-02"))) + if err := writeKeyFile(path, entry.PrivateKey); err != nil { + return nil, fmt.Errorf("writing key: %w", err) + } + } + } + + sort.Slice(ks.keys, func(i, j int) bool { + return ks.keys[i].CreatedAt.After(ks.keys[j].CreatedAt) + }) + + return ks, nil +} + +func (ks *KeyStore) JWKS() jose.JSONWebKeySet { + var keys []jose.JSONWebKey + for _, entry := range ks.keys { + keys = append(keys, jose.JSONWebKey{ + Key: entry.PublicKey, + KeyID: entry.KID, + Algorithm: string(jose.ECDH_ES_A256KW), + Use: "enc", + }) + } + return jose.JSONWebKeySet{Keys: keys} +} + +func (ks *KeyStore) PrimaryKey() KeyEntry { + return ks.keys[0] +} + +func (ks *KeyStore) Keys() []KeyEntry { + return ks.keys +} + +func (ks *KeyStore) loadFromDir(dir string) error { + matches, err := filepath.Glob(filepath.Join(dir, "*.pem")) + if err != nil { + return fmt.Errorf("globbing key files: %w", err) + } + + for _, path := range matches { + entry, err := loadKeyFile(path) + if err != nil { + return fmt.Errorf("loading %s: %w", path, err) + } + ks.keys = append(ks.keys, entry) + } + return nil +} + +func generateKey() (KeyEntry, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return KeyEntry{}, fmt.Errorf("generating ECDSA key: %w", err) + } + + kid, err := thumbprint(priv) + if err != nil { + return KeyEntry{}, fmt.Errorf("computing thumbprint: %w", err) + } + + return KeyEntry{ + KID: kid, + PrivateKey: priv, + PublicKey: &priv.PublicKey, + CreatedAt: time.Now(), + }, nil +} + +func thumbprint(key *ecdsa.PrivateKey) (string, error) { + jwk := jose.JSONWebKey{Key: &key.PublicKey} + kid, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(kid), nil +} + +func loadKeyFile(path string) (KeyEntry, error) { + data, err := os.ReadFile(path) + if err != nil { + return KeyEntry{}, err + } + + block, _ := pem.Decode(data) + if block == nil { + return KeyEntry{}, fmt.Errorf("no PEM block found") + } + + priv, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return KeyEntry{}, fmt.Errorf("parsing EC private key: %w", err) + } + + kid, err := thumbprint(priv) + if err != nil { + return KeyEntry{}, fmt.Errorf("computing thumbprint: %w", err) + } + + info, err := os.Stat(path) + if err != nil { + return KeyEntry{}, err + } + + return KeyEntry{ + KID: kid, + PrivateKey: priv, + PublicKey: &priv.PublicKey, + CreatedAt: info.ModTime(), + }, nil +} + +func writeKeyFile(path string, key *ecdsa.PrivateKey) error { + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("marshaling key: %w", err) + } + + block := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: der, + } + + return os.WriteFile(path, pem.EncodeToMemory(block), 0600) +} + diff --git a/stainless-proxy/internal/proxy/bench_test.go b/stainless-proxy/internal/proxy/bench_test.go new file mode 100644 index 0000000..bd76160 --- /dev/null +++ b/stainless-proxy/internal/proxy/bench_test.go @@ -0,0 +1,579 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "log/slog" + "math" + "net/http" + "net/http/httptest" + "os" + "sort" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stainless-api/stainless-proxy/internal/jwe" + "github.com/stainless-api/stainless-proxy/internal/revocation" +) + +type traceCollector struct { + mu sync.Mutex + traces []RequestTrace +} + +func (c *traceCollector) collect(t RequestTrace) { + c.mu.Lock() + c.traces = append(c.traces, t) + c.mu.Unlock() +} + +func (c *traceCollector) reset() { + c.mu.Lock() + c.traces = c.traces[:0] + c.mu.Unlock() +} + +type benchHarness struct { + proxy *Proxy + enc *jwe.Encryptor + backend *httptest.Server + host string + col *traceCollector +} + +func suppressLogs(tb testing.TB) { + tb.Helper() + prev := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))) + tb.Cleanup(func() { slog.SetDefault(prev) }) +} + +func newBenchHarness(tb testing.TB, handler http.Handler) *benchHarness { + tb.Helper() + suppressLogs(tb) + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + tb.Fatal(err) + } + + kid := "bench-key" + enc := jwe.NewEncryptor(&key.PublicKey, kid) + dec := jwe.NewMultiKeyDecryptor([]jwe.KeyEntry{{KID: kid, PrivateKey: key}}) + dl := revocation.NewDenyList() + p := New(dec, dl) + + col := &traceCollector{} + p.OnTrace = col.collect + + backend := httptest.NewTLSServer(handler) + p.client = backend.Client() + host := strings.TrimPrefix(backend.URL, "https://") + + tb.Cleanup(backend.Close) + + return &benchHarness{proxy: p, enc: enc, backend: backend, host: host, col: col} +} + +func (h *benchHarness) mint(tb testing.TB, creds []jwe.Credential) string { + tb.Helper() + payload := jwe.Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{h.host}, + Credentials: creds, + } + token, err := h.enc.Encrypt(payload) + if err != nil { + tb.Fatal(err) + } + return token +} + +// ============================================================================= +// Benchmark: JWE encrypt + decrypt cycle (crypto overhead in isolation) +// ============================================================================= + +func BenchmarkJWECrypto(b *testing.B) { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + enc := jwe.NewEncryptor(&key.PublicKey, "bench") + dec := jwe.NewMultiKeyDecryptor([]jwe.KeyEntry{{KID: "bench", PrivateKey: key}}) + + payload := jwe.Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{"api.example.com"}, + Credentials: []jwe.Credential{ + {Header: "DD-API-KEY", Value: "datadog-api-key-12345"}, + {Header: "DD-APPLICATION-KEY", Value: "datadog-app-key-67890"}, + }, + } + + b.Run("encrypt", func(b *testing.B) { + for b.Loop() { + _, err := enc.Encrypt(payload) + if err != nil { + b.Fatal(err) + } + } + }) + + token, _ := enc.Encrypt(payload) + b.Run("decrypt", func(b *testing.B) { + for b.Loop() { + _, err := dec.Decrypt(token) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("round_trip", func(b *testing.B) { + for b.Loop() { + t, _ := enc.Encrypt(payload) + _, err := dec.Decrypt(t) + if err != nil { + b.Fatal(err) + } + } + }) +} + +// ============================================================================= +// Benchmark: End-to-end proxy throughput (sequential) +// Simulates a sandbox making sequential API calls through the proxy. +// ============================================================================= + +func BenchmarkProxySequential(b *testing.B) { + h := newBenchHarness(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"ok"}`)) + })) + + token := h.mint(b, []jwe.Credential{ + {Header: "DD-API-KEY", Value: "datadog-api-key-12345"}, + {Header: "DD-APPLICATION-KEY", Value: "datadog-app-key-67890"}, + }) + + h.col.reset() + b.ResetTimer() + + for b.Loop() { + req := httptest.NewRequest("GET", "/https/"+h.host+"/api/v2/metrics", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + if w.Code != http.StatusOK { + b.Fatalf("unexpected status: %d", w.Code) + } + } + + b.StopTimer() + reportTraces(b, h.col.traces) +} + +// ============================================================================= +// Benchmark: End-to-end proxy throughput (parallel) +// Simulates multiple sandboxes hitting the proxy concurrently. +// ============================================================================= + +func BenchmarkProxyParallel(b *testing.B) { + h := newBenchHarness(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"ok"}`)) + })) + + token := h.mint(b, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer real-token"}, + }) + + h.col.reset() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req := httptest.NewRequest("GET", "/https/"+h.host+"/api/data", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + if w.Code != http.StatusOK { + b.Fatalf("unexpected status: %d", w.Code) + } + } + }) + + b.StopTimer() + reportTraces(b, h.col.traces) +} + +// ============================================================================= +// Benchmark: Proxy with varying payload sizes +// Measures throughput degradation as request/response bodies grow. +// ============================================================================= + +func BenchmarkProxyPayloadSizes(b *testing.B) { + sizes := []struct { + name string + size int + }{ + {"1KB", 1024}, + {"10KB", 10 * 1024}, + {"100KB", 100 * 1024}, + {"1MB", 1024 * 1024}, + {"10MB", 10 * 1024 * 1024}, + } + + for _, sz := range sizes { + b.Run(sz.name, func(b *testing.B) { + responseBody := bytes.Repeat([]byte("x"), sz.size) + + h := newBenchHarness(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(responseBody) + })) + + token := h.mint(b, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }) + + requestBody := bytes.Repeat([]byte("y"), sz.size) + + h.col.reset() + b.ResetTimer() + b.SetBytes(int64(sz.size) * 2) // request + response + + for b.Loop() { + req := httptest.NewRequest("POST", "/https/"+h.host+"/api", bytes.NewReader(requestBody)) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + if w.Code != http.StatusOK { + b.Fatalf("unexpected status: %d", w.Code) + } + } + + b.StopTimer() + reportTraces(b, h.col.traces) + }) + } +} + +// ============================================================================= +// Benchmark: Proxy with varying credential counts +// Measures impact of many credentials in a single JWE. +// ============================================================================= + +func BenchmarkProxyCredentialCount(b *testing.B) { + counts := []int{1, 5, 10, 50} + + for _, n := range counts { + b.Run(fmt.Sprintf("%d_credentials", n), func(b *testing.B) { + h := newBenchHarness(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + creds := make([]jwe.Credential, n) + for i := range n { + creds[i] = jwe.Credential{ + Header: fmt.Sprintf("X-Credential-%d", i), + Value: fmt.Sprintf("secret-value-%d-with-some-realistic-length-padding", i), + } + } + + token := h.mint(b, creds) + + h.col.reset() + b.ResetTimer() + + for b.Loop() { + req := httptest.NewRequest("GET", "/https/"+h.host+"/api", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + if w.Code != http.StatusOK { + b.Fatalf("unexpected status: %d", w.Code) + } + } + + b.StopTimer() + reportTraces(b, h.col.traces) + }) + } +} + +// ============================================================================= +// Benchmark: Deny-list impact under load +// Measures overhead when the deny-list has many entries. +// ============================================================================= + +func BenchmarkDenyListOverhead(b *testing.B) { + entryCounts := []int{0, 100, 10_000, 100_000} + + for _, n := range entryCounts { + b.Run(fmt.Sprintf("%d_entries", n), func(b *testing.B) { + h := newBenchHarness(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := range n { + h.proxy.denyList.Add(fmt.Sprintf("fake-hash-%d", i), time.Now().Add(time.Hour)) + } + + token := h.mint(b, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }) + + h.col.reset() + b.ResetTimer() + + for b.Loop() { + req := httptest.NewRequest("GET", "/https/"+h.host+"/api", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + if w.Code != http.StatusOK { + b.Fatalf("unexpected status: %d", w.Code) + } + } + + b.StopTimer() + reportTraces(b, h.col.traces) + }) + } +} + +// ============================================================================= +// Benchmark: Slow backend (realistic network latency simulation) +// Measures proxy overhead when the bottleneck is upstream latency. +// ============================================================================= + +func BenchmarkProxyWithLatency(b *testing.B) { + latencies := []time.Duration{ + 0, + 1 * time.Millisecond, + 10 * time.Millisecond, + 50 * time.Millisecond, + } + + for _, lat := range latencies { + name := "0ms" + if lat > 0 { + name = lat.String() + } + b.Run(name, func(b *testing.B) { + h := newBenchHarness(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if lat > 0 { + time.Sleep(lat) + } + w.Write([]byte(`{"ok":true}`)) + })) + + token := h.mint(b, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }) + + h.col.reset() + b.ResetTimer() + + for b.Loop() { + req := httptest.NewRequest("GET", "/https/"+h.host+"/api", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + } + + b.StopTimer() + reportTraces(b, h.col.traces) + }) + } +} + +// ============================================================================= +// Benchmark: Sustained concurrent load (throughput + latency distribution) +// Simulates realistic sustained traffic and reports percentiles. +// ============================================================================= + +func TestBenchmarkSustainedLoad(t *testing.T) { + if testing.Short() { + t.Skip("skipping sustained load test in short mode") + } + + h := newBenchHarness(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"data":[1,2,3]}`)) + })) + + token := h.mint(t, []jwe.Credential{ + {Header: "DD-API-KEY", Value: "datadog-api-key-12345"}, + {Header: "DD-APPLICATION-KEY", Value: "datadog-app-key-67890"}, + }) + + concurrency := []int{1, 10, 50, 100, 200} + duration := 3 * time.Second + + for _, c := range concurrency { + t.Run(fmt.Sprintf("%d_goroutines", c), func(t *testing.T) { + h.col.reset() + var totalRequests atomic.Int64 + var errors atomic.Int64 + + ctx, cancel := timeoutContext(duration) + defer cancel() + + var wg sync.WaitGroup + for range c { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + } + + req := httptest.NewRequest("GET", "/https/"+h.host+"/api/v2/metrics", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + + totalRequests.Add(1) + if w.Code != http.StatusOK { + errors.Add(1) + } + } + }() + } + + wg.Wait() + + total := totalRequests.Load() + errs := errors.Load() + rps := float64(total) / duration.Seconds() + + t.Logf("=== %d goroutines, %.1fs ===", c, duration.Seconds()) + t.Logf(" total requests: %d", total) + t.Logf(" errors: %d", errs) + t.Logf(" throughput: %.0f req/s", rps) + reportTracesTest(t, h.col.traces) + }) + } +} + +func timeoutContext(d time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), d) +} + +// ============================================================================= +// Trace reporting helpers +// ============================================================================= + +func reportTraces(b *testing.B, traces []RequestTrace) { + if len(traces) == 0 { + return + } + + var ( + totalDenyList time.Duration + totalDecrypt time.Duration + totalBuild time.Duration + totalUpstream time.Duration + totalStream time.Duration + totalOverall time.Duration + ) + + for _, t := range traces { + totalDenyList += t.DenyListCheck + totalDecrypt += t.JWEDecrypt + totalBuild += t.RequestBuild + totalUpstream += t.UpstreamRoundTrip + totalStream += t.ResponseStream + totalOverall += t.Total + } + + n := len(traces) + b.ReportMetric(float64(totalDenyList.Nanoseconds())/float64(n), "ns/deny-list") + b.ReportMetric(float64(totalDecrypt.Nanoseconds())/float64(n), "ns/jwe-decrypt") + b.ReportMetric(float64(totalBuild.Nanoseconds())/float64(n), "ns/req-build") + b.ReportMetric(float64(totalUpstream.Nanoseconds())/float64(n), "ns/upstream") + b.ReportMetric(float64(totalStream.Nanoseconds())/float64(n), "ns/response") + b.ReportMetric(float64(totalOverall.Nanoseconds())/float64(n), "ns/total") +} + +func reportTracesTest(t *testing.T, traces []RequestTrace) { + if len(traces) == 0 { + return + } + + totals := extractDurations(traces, func(tr RequestTrace) time.Duration { return tr.Total }) + decrypts := extractDurations(traces, func(tr RequestTrace) time.Duration { return tr.JWEDecrypt }) + upstreams := extractDurations(traces, func(tr RequestTrace) time.Duration { return tr.UpstreamRoundTrip }) + denyLists := extractDurations(traces, func(tr RequestTrace) time.Duration { return tr.DenyListCheck }) + + t.Logf(" --- total latency ---") + reportPercentiles(t, " total", totals) + t.Logf(" --- phase breakdown (avg) ---") + t.Logf(" deny_list: %s", avg(denyLists)) + t.Logf(" jwe_decrypt: %s", avg(decrypts)) + t.Logf(" upstream: %s", avg(upstreams)) + t.Logf(" --- phase breakdown (p99) ---") + t.Logf(" deny_list: %s", percentile(denyLists, 0.99)) + t.Logf(" jwe_decrypt: %s", percentile(decrypts, 0.99)) + t.Logf(" upstream: %s", percentile(upstreams, 0.99)) + + proxyOverhead := make([]time.Duration, len(traces)) + for i, tr := range traces { + proxyOverhead[i] = tr.Total - tr.UpstreamRoundTrip - tr.ResponseStream + } + sort.Slice(proxyOverhead, func(i, j int) bool { return proxyOverhead[i] < proxyOverhead[j] }) + t.Logf(" --- proxy overhead (total - upstream - response_stream) ---") + reportPercentiles(t, " overhead", proxyOverhead) +} + +func extractDurations(traces []RequestTrace, fn func(RequestTrace) time.Duration) []time.Duration { + ds := make([]time.Duration, len(traces)) + for i, t := range traces { + ds[i] = fn(t) + } + sort.Slice(ds, func(i, j int) bool { return ds[i] < ds[j] }) + return ds +} + +func avg(ds []time.Duration) time.Duration { + if len(ds) == 0 { + return 0 + } + var total time.Duration + for _, d := range ds { + total += d + } + return total / time.Duration(len(ds)) +} + +func percentile(sorted []time.Duration, p float64) time.Duration { + if len(sorted) == 0 { + return 0 + } + idx := int(math.Ceil(p*float64(len(sorted)))) - 1 + if idx < 0 { + idx = 0 + } + if idx >= len(sorted) { + idx = len(sorted) - 1 + } + return sorted[idx] +} + +func reportPercentiles(t *testing.T, prefix string, sorted []time.Duration) { + t.Logf("%s p50=%s p90=%s p99=%s max=%s", + prefix, + percentile(sorted, 0.50), + percentile(sorted, 0.90), + percentile(sorted, 0.99), + sorted[len(sorted)-1], + ) +} diff --git a/stainless-proxy/internal/proxy/integration_test.go b/stainless-proxy/internal/proxy/integration_test.go new file mode 100644 index 0000000..84d1f81 --- /dev/null +++ b/stainless-proxy/internal/proxy/integration_test.go @@ -0,0 +1,692 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stainless-api/stainless-proxy/internal/jwe" + "github.com/stainless-api/stainless-proxy/internal/revocation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper to build a full proxy + encryptor + deny list for integration tests +type testHarness struct { + proxy *Proxy + enc *jwe.Encryptor + denyList *revocation.DenyList + kid string + key *ecdsa.PrivateKey +} + +func newTestHarness(t *testing.T) *testHarness { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + kid := "integration-key" + enc := jwe.NewEncryptor(&key.PublicKey, kid) + dec := jwe.NewMultiKeyDecryptor([]jwe.KeyEntry{{KID: kid, PrivateKey: key}}) + dl := revocation.NewDenyList() + p := New(dec, dl) + + return &testHarness{proxy: p, enc: enc, denyList: dl, kid: kid, key: key} +} + +func (h *testHarness) mint(t *testing.T, hosts []string, creds []jwe.Credential, ttl time.Duration) string { + t.Helper() + payload := jwe.Payload{ + Exp: time.Now().Add(ttl).Unix(), + AllowedHosts: hosts, + Credentials: creds, + } + token, err := h.enc.Encrypt(payload) + require.NoError(t, err) + return token +} + +// ============================================================================= +// Scenario 1: Datadog-style API with two credential headers +// Verifies that multiple credentials are injected correctly and that the +// original Authorization header (carrying the JWE) is stripped. +// ============================================================================= + +func TestScenario_DatadogMultiHeaderAuth(t *testing.T) { + h := newTestHarness(t) + + var receivedHeaders http.Header + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"data":[]}`)) + })) + defer backend.Close() + h.proxy.client = backend.Client() + + backendHost := strings.TrimPrefix(backend.URL, "https://") + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "DD-API-KEY", Value: "datadog-api-key-12345"}, + {Header: "DD-APPLICATION-KEY", Value: "datadog-app-key-67890"}, + }, time.Hour) + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api/v2/metrics?from=now-1h", nil) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "datadog-api-key-12345", receivedHeaders.Get("DD-API-KEY")) + assert.Equal(t, "datadog-app-key-67890", receivedHeaders.Get("DD-APPLICATION-KEY")) + // JWE bearer token must NOT reach the backend + assert.NotContains(t, receivedHeaders.Get("Authorization"), "eyJ") + // Non-sensitive headers should be forwarded + assert.Equal(t, "application/json", receivedHeaders.Get("Accept")) +} + +// ============================================================================= +// Scenario 2: Credential exfiltration attempt +// A compromised sandbox tries to redirect the proxy to an attacker-controlled +// server. Host locking must block this. +// ============================================================================= + +func TestScenario_ExfiltrationAttempt(t *testing.T) { + h := newTestHarness(t) + + attacker := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("attacker server should never be reached") + })) + defer attacker.Close() + + // Token is only valid for the legitimate API + token := h.mint(t, []string{"api.datadoghq.com"}, []jwe.Credential{ + {Header: "DD-API-KEY", Value: "secret"}, + }, time.Hour) + + attackerHost := strings.TrimPrefix(attacker.URL, "https://") + + req := httptest.NewRequest("GET", "/https/"+attackerHost+"/exfiltrate", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Contains(t, w.Body.String(), "host not allowed") +} + +// ============================================================================= +// Scenario 3: Token replay after revocation +// A JWE is minted, used successfully, then revoked. Subsequent use must fail. +// ============================================================================= + +func TestScenario_TokenReplayAfterRevocation(t *testing.T) { + h := newTestHarness(t) + + var callCount atomic.Int32 + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount.Add(1) + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + h.proxy.client = backend.Client() + + backendHost := strings.TrimPrefix(backend.URL, "https://") + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer real-token"}, + }, time.Hour) + + // First request: should succeed + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api/data", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, int32(1), callCount.Load()) + + // Revoke the token + hash := sha256.Sum256([]byte(token)) + h.denyList.Add(hex.EncodeToString(hash[:]), time.Now().Add(time.Hour)) + + // Second request: should be rejected + req = httptest.NewRequest("GET", "/https/"+backendHost+"/api/data", nil) + req.Header.Set("Authorization", "Bearer "+token) + w = httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) + // Backend should NOT have been called again + assert.Equal(t, int32(1), callCount.Load()) +} + +// ============================================================================= +// Scenario 4: Key rotation — old JWE still works after new key is added +// Simulates a key rotation where the proxy has both old and new keys. +// JWEs minted with the old key must still decrypt. +// ============================================================================= + +func TestScenario_KeyRotation(t *testing.T) { + oldKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + newKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // Decryptor knows both keys + dec := jwe.NewMultiKeyDecryptor([]jwe.KeyEntry{ + {KID: "old-key", PrivateKey: oldKey}, + {KID: "new-key", PrivateKey: newKey}, + }) + dl := revocation.NewDenyList() + p := New(dec, dl) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok:" + r.Header.Get("X-Api-Key"))) + })) + defer backend.Close() + p.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + // Mint with old key + oldEnc := jwe.NewEncryptor(&oldKey.PublicKey, "old-key") + oldPayload := jwe.Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{backendHost}, + Credentials: []jwe.Credential{{Header: "X-Api-Key", Value: "old-secret"}}, + } + oldToken, err := oldEnc.Encrypt(oldPayload) + require.NoError(t, err) + + // Mint with new key + newEnc := jwe.NewEncryptor(&newKey.PublicKey, "new-key") + newPayload := jwe.Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: []string{backendHost}, + Credentials: []jwe.Credential{{Header: "X-Api-Key", Value: "new-secret"}}, + } + newToken, err := newEnc.Encrypt(newPayload) + require.NoError(t, err) + + // Both tokens should work + for _, tc := range []struct { + name string + token string + want string + }{ + {"old key token", oldToken, "ok:old-secret"}, + {"new key token", newToken, "ok:new-secret"}, + } { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/https/"+backendHost+"/test", nil) + req.Header.Set("Authorization", "Bearer "+tc.token) + w := httptest.NewRecorder() + p.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, tc.want, w.Body.String()) + }) + } +} + +// ============================================================================= +// Scenario 5: Concurrent requests with different credentials +// Multiple sandbox sessions using different JWEs concurrently. +// Each must get their own credentials injected — no cross-contamination. +// ============================================================================= + +func TestScenario_ConcurrentIsolation(t *testing.T) { + h := newTestHarness(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Echo the injected credential back + w.Write([]byte(r.Header.Get("X-User-Token"))) + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + const numUsers = 50 + var wg sync.WaitGroup + errors := make([]string, numUsers) + + for i := range numUsers { + wg.Add(1) + go func(userID int) { + defer wg.Done() + + expectedToken := fmt.Sprintf("user-%d-token", userID) + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "X-User-Token", Value: expectedToken}, + }, time.Hour) + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api/me", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + errors[userID] = fmt.Sprintf("user %d: status %d", userID, w.Code) + return + } + if w.Body.String() != expectedToken { + errors[userID] = fmt.Sprintf("user %d: got %q, want %q", userID, w.Body.String(), expectedToken) + } + }(i) + } + + wg.Wait() + for _, e := range errors { + if e != "" { + t.Error(e) + } + } +} + +// ============================================================================= +// Scenario 6: POST with large JSON body (code execution payload) +// Simulates the actual use case: agent sends a script to execute. +// The proxy must forward the body untouched. +// ============================================================================= + +func TestScenario_LargePostBody(t *testing.T) { + h := newTestHarness(t) + + var receivedBody []byte + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + receivedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"result":"success"}`)) + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + // Simulate a large TypeScript script + script := strings.Repeat("console.log('line');\n", 10000) + body, _ := json.Marshal(map[string]string{"script": script}) + + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer api-key-for-execution"}, + }, time.Hour) + + req := httptest.NewRequest("POST", "/https/"+backendHost+"/ai/tools/code", bytes.NewReader(body)) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, body, receivedBody) +} + +// ============================================================================= +// Scenario 7: SSE streaming response +// The target API returns a streaming SSE response. The proxy must flush +// incrementally, not buffer the entire response. +// ============================================================================= + +func TestScenario_SSEStreaming(t *testing.T) { + h := newTestHarness(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("expected flusher") + } + + for i := range 5 { + fmt.Fprintf(w, "data: event %d\n\n", i) + flusher.Flush() + } + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer stream-key"}, + }, time.Hour) + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/events", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + body := w.Body.String() + for i := range 5 { + assert.Contains(t, body, fmt.Sprintf("data: event %d", i)) + } +} + +// ============================================================================= +// Scenario 8: Token expiry mid-session +// Mint a token with very short TTL, wait for it to expire, verify rejection. +// ============================================================================= + +func TestScenario_TokenExpiryMidSession(t *testing.T) { + h := newTestHarness(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + // Token expires in 1 second + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer short-lived"}, + }, 1*time.Second) + + // Should work immediately + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Wait for expiry + time.Sleep(2 * time.Second) + + // Should now fail + req = httptest.NewRequest("GET", "/https/"+backendHost+"/api", nil) + req.Header.Set("Authorization", "Bearer "+token) + w = httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +// ============================================================================= +// Scenario 9: Wildcard host matching edge cases +// Verify that wildcard patterns work correctly and don't allow bypasses. +// ============================================================================= + +func TestScenario_WildcardHostEdgeCases(t *testing.T) { + h := newTestHarness(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + tests := []struct { + name string + hosts []string + target string + wantCode int + }{ + { + name: "wildcard allows subdomain", + hosts: []string{"*.nonexistent.test"}, + target: "/https/sub.nonexistent.test/v1/data", + wantCode: http.StatusBadGateway, // host matches wildcard, but no real backend → 502 + }, + { + name: "wildcard does not match bare domain", + hosts: []string{"*.datadoghq.com"}, + target: "/https/datadoghq.com/v1/data", + wantCode: http.StatusForbidden, + }, + { + name: "exact match on backend host", + hosts: []string{backendHost}, + target: "/https/" + backendHost + "/test", + wantCode: http.StatusOK, + }, + { + name: "multiple patterns, one matches", + hosts: []string{"other.com", backendHost}, + target: "/https/" + backendHost + "/test", + wantCode: http.StatusOK, + }, + { + name: "no pattern matches", + hosts: []string{"totally-different.com"}, + target: "/https/" + backendHost + "/test", + wantCode: http.StatusForbidden, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + token := h.mint(t, tc.hosts, []jwe.Credential{ + {Header: "X-Key", Value: "val"}, + }, time.Hour) + + req := httptest.NewRequest("GET", tc.target, nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + }) + } +} + +// ============================================================================= +// Scenario 10: Request cancellation / context timeout +// Verify that the proxy respects context cancellation (e.g., client disconnect). +// ============================================================================= + +func TestScenario_ContextCancellation(t *testing.T) { + h := newTestHarness(t) + + backendStarted := make(chan struct{}) + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(backendStarted) + // Simulate a slow backend + select { + case <-r.Context().Done(): + // Client cancelled — this is what we expect + return + case <-time.After(30 * time.Second): + w.WriteHeader(http.StatusOK) + } + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }, time.Hour) + + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("GET", "/https/"+backendHost+"/slow", nil) + req = req.WithContext(ctx) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + done := make(chan struct{}) + go func() { + h.proxy.ServeHTTP(w, req) + close(done) + }() + + <-backendStarted + cancel() + + select { + case <-done: + // Proxy returned after cancellation — good + case <-time.After(5 * time.Second): + t.Fatal("proxy did not return after context cancellation") + } +} + +// ============================================================================= +// Scenario 11: Query parameters preserved +// Verify query strings pass through correctly to the target. +// ============================================================================= + +func TestScenario_QueryParametersPreserved(t *testing.T) { + h := newTestHarness(t) + + var receivedQuery string + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "X-Key", Value: "val"}, + }, time.Hour) + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api/search?q=test%20query&limit=10&offset=0", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "q=test%20query&limit=10&offset=0", receivedQuery) +} + +// ============================================================================= +// Scenario 12: Backend returns error status codes +// Proxy must faithfully forward 4xx/5xx from the target API. +// ============================================================================= + +func TestScenario_BackendErrorPassthrough(t *testing.T) { + h := newTestHarness(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/not-found": + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"error":"not found"}`)) + case "/rate-limited": + w.Header().Set("Retry-After", "30") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"rate limited"}`)) + case "/server-error": + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"internal"}`)) + } + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }, time.Hour) + + tests := []struct { + path string + wantCode int + wantBody string + }{ + {"/not-found", 404, `{"error":"not found"}`}, + {"/rate-limited", 429, `{"error":"rate limited"}`}, + {"/server-error", 500, `{"error":"internal"}`}, + } + + for _, tc := range tests { + t.Run(tc.path, func(t *testing.T) { + req := httptest.NewRequest("GET", "/https/"+backendHost+tc.path, nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, tc.wantCode, w.Code) + assert.Equal(t, tc.wantBody, w.Body.String()) + }) + } +} + +// ============================================================================= +// Scenario 13: Credential override — JWE credential overwrites existing header +// If the sandbox code sets a header that the JWE also injects, the JWE +// credential MUST win (otherwise sandbox code could inject its own auth). +// ============================================================================= + +func TestScenario_CredentialOverridesExistingHeader(t *testing.T) { + h := newTestHarness(t) + + var received string + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + received = r.Header.Get("X-Api-Key") + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + h.proxy.client = backend.Client() + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := h.mint(t, []string{backendHost}, []jwe.Credential{ + {Header: "X-Api-Key", Value: "real-secret-from-jwe"}, + }, time.Hour) + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api", nil) + req.Header.Set("Authorization", "Bearer "+token) + // Sandbox tries to set its own X-Api-Key + req.Header.Set("X-Api-Key", "attacker-injected-value") + w := httptest.NewRecorder() + + h.proxy.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "real-secret-from-jwe", received) +} + +// ============================================================================= +// Scenario 14: Path traversal attempt +// Sandbox tries to manipulate the path to access unintended APIs. +// ============================================================================= + +func TestScenario_PathTraversal(t *testing.T) { + h := newTestHarness(t) + + token := h.mint(t, []string{"api.safe.com"}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }, time.Hour) + + paths := []string{ + "/https/api.safe.com/../../../etc/passwd", + "/https/api.safe.com/..%2F..%2Fetc%2Fpasswd", + "/https/evil.com%00api.safe.com/data", + } + + for _, path := range paths { + t.Run(path, func(t *testing.T) { + req := httptest.NewRequest("GET", path, nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + h.proxy.ServeHTTP(w, req) + + // Either blocked by host matching or forwarded with the path as-is + // (the target server handles path traversal — we just ensure the + // host lock isn't bypassed) + assert.NotEqual(t, http.StatusOK, w.Code) + }) + } +} diff --git a/stainless-proxy/internal/proxy/proxy.go b/stainless-proxy/internal/proxy/proxy.go new file mode 100644 index 0000000..3a5f501 --- /dev/null +++ b/stainless-proxy/internal/proxy/proxy.go @@ -0,0 +1,280 @@ +package proxy + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/stainless-api/stainless-proxy/internal/hostmatch" + "github.com/stainless-api/stainless-proxy/internal/jwe" + "github.com/stainless-api/stainless-proxy/internal/revocation" + "github.com/stainless-api/stainless-proxy/internal/secretutil" +) + +type Proxy struct { + client *http.Client + decryptor jwe.Decryptor + denyList *revocation.DenyList + OnTrace TraceCallback +} + +func New(decryptor jwe.Decryptor, denyList *revocation.DenyList) *Proxy { + return &Proxy{ + client: &http.Client{}, + decryptor: decryptor, + denyList: denyList, + } +} + +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + trace := &RequestTrace{ + Start: time.Now(), + Method: r.Method, + } + defer p.finishTrace(trace) + + token, ok := extractBearer(r) + if !ok { + trace.StatusCode = http.StatusUnauthorized + writeError(w, http.StatusUnauthorized, "missing or invalid Authorization header") + return + } + + t0 := time.Now() + hash := hashJWE(token) + if p.denyList.IsRevoked(hash) { + trace.DenyListCheck = time.Since(t0) + trace.StatusCode = http.StatusForbidden + writeError(w, http.StatusForbidden, "token revoked") + return + } + trace.DenyListCheck = time.Since(t0) + + t0 = time.Now() + payload, err := p.decryptor.Decrypt(token) + trace.JWEDecrypt = time.Since(t0) + if err != nil { + slog.Warn("JWE decryption failed", "error", err) + trace.StatusCode = http.StatusUnauthorized + writeError(w, http.StatusUnauthorized, "invalid token") + return + } + trace.CredentialCount = len(payload.Credentials) + + t0 = time.Now() + targetScheme, targetHost, targetPath, err := parseTargetFromPath(r.URL.Path) + trace.PathParse = time.Since(t0) + if err != nil { + trace.StatusCode = http.StatusBadRequest + writeError(w, http.StatusBadRequest, err.Error()) + return + } + trace.TargetHost = targetHost + trace.TargetPath = targetPath + + t0 = time.Now() + if !hostmatch.Match(targetHost, payload.AllowedHosts) { + trace.HostMatch = time.Since(t0) + slog.Warn("host not allowed", "host", targetHost, "allowed", payload.AllowedHosts) + trace.StatusCode = http.StatusForbidden + writeError(w, http.StatusForbidden, "target host not allowed") + return + } + trace.HostMatch = time.Since(t0) + + t0 = time.Now() + targetURL := &url.URL{ + Scheme: targetScheme, + Host: targetHost, + Path: targetPath, + RawQuery: r.URL.RawQuery, + } + + outReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL.String(), r.Body) + if err != nil { + trace.RequestBuild = time.Since(t0) + trace.StatusCode = http.StatusInternalServerError + writeError(w, http.StatusInternalServerError, "failed to create request") + return + } + + copyHeaders(outReq.Header, r.Header) + trace.RequestBuild = time.Since(t0) + + slog.Info("proxying request", + "method", r.Method, + "target", targetURL.Redacted(), + "credentials_count", len(payload.Credentials), + ) + + // secret.Do wraps credential injection + upstream call. This ensures: + // - Credential string values (from payload) are used then zeroed + // - net/http transport's internal write buffers (containing serialized + // credential headers) are marked for zeroing on GC collection + // - Stack frames and registers are zeroed on return + var resp *http.Response + secretutil.Do(func() { + for _, cred := range payload.Credentials { + outReq.Header.Set(cred.Header, cred.Value) + } + + t0 = time.Now() + resp, err = p.client.Do(outReq) + trace.UpstreamRoundTrip = time.Since(t0) + }) + if err != nil { + slog.Error("upstream request failed", "error", err) + trace.StatusCode = http.StatusBadGateway + writeError(w, http.StatusBadGateway, "upstream request failed") + return + } + defer resp.Body.Close() + trace.StatusCode = resp.StatusCode + + t0 = time.Now() + for k, v := range resp.Header { + if isHopByHop(k) { + continue + } + w.Header()[k] = v + } + w.WriteHeader(resp.StatusCode) + + var responseBytes int64 + if isStreaming(resp) { + flusher, ok := w.(http.Flusher) + buf := make([]byte, 4096) + for { + n, readErr := resp.Body.Read(buf) + if n > 0 { + w.Write(buf[:n]) + responseBytes += int64(n) + if ok { + flusher.Flush() + } + } + if readErr != nil { + break + } + } + } else { + responseBytes, _ = io.Copy(w, resp.Body) + } + trace.ResponseStream = time.Since(t0) + trace.ResponseBodyBytes = responseBytes +} + +func (p *Proxy) finishTrace(trace *RequestTrace) { + trace.Total = time.Since(trace.Start) + + slog.Info("request completed", + "method", trace.Method, + "target_host", trace.TargetHost, + "status", trace.StatusCode, + "total", trace.Total, + "deny_list_check", trace.DenyListCheck, + "jwe_decrypt", trace.JWEDecrypt, + "path_parse", trace.PathParse, + "host_match", trace.HostMatch, + "request_build", trace.RequestBuild, + "upstream_round_trip", trace.UpstreamRoundTrip, + "response_stream", trace.ResponseStream, + "credentials", trace.CredentialCount, + "response_bytes", trace.ResponseBodyBytes, + ) + + if p.OnTrace != nil { + p.OnTrace(*trace) + } +} + +// parseTargetFromPath extracts scheme, host, and path from /{scheme}/{host}/{path...} +func parseTargetFromPath(reqPath string) (scheme, host, path string, err error) { + // Strip leading slash + trimmed := strings.TrimPrefix(reqPath, "/") + if trimmed == "" { + return "", "", "", fmt.Errorf("empty path") + } + + // First segment is scheme + slashIdx := strings.IndexByte(trimmed, '/') + if slashIdx == -1 { + return "", "", "", fmt.Errorf("missing host in path") + } + scheme = trimmed[:slashIdx] + if scheme != "http" && scheme != "https" { + return "", "", "", fmt.Errorf("invalid scheme: %s", scheme) + } + + rest := trimmed[slashIdx+1:] + + // Second segment is host (may include port) + slashIdx = strings.IndexByte(rest, '/') + if slashIdx == -1 { + host = rest + path = "/" + } else { + host = rest[:slashIdx] + path = rest[slashIdx:] + } + + if host == "" { + return "", "", "", fmt.Errorf("empty host") + } + + return scheme, host, path, nil +} + +func extractBearer(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + return "", false + } + token := auth[7:] + if token == "" { + return "", false + } + return token, true +} + +func hashJWE(token string) string { + h := sha256.Sum256([]byte(token)) + return hex.EncodeToString(h[:]) +} + +func copyHeaders(dst, src http.Header) { + for k, v := range src { + if k == "Connection" || k == "Upgrade" || k == "Host" || + k == "Authorization" || k == "Cookie" { + continue + } + dst[k] = v + } +} + +func isHopByHop(header string) bool { + switch header { + case "Connection", "Keep-Alive", "Proxy-Authenticate", + "Proxy-Authorization", "Te", "Trailers", + "Transfer-Encoding", "Upgrade": + return true + } + return false +} + +func isStreaming(resp *http.Response) bool { + ct := resp.Header.Get("Content-Type") + return strings.HasPrefix(ct, "text/event-stream") +} + +func writeError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + fmt.Fprintf(w, `{"error":"%s"}`, message) +} diff --git a/stainless-proxy/internal/proxy/proxy_test.go b/stainless-proxy/internal/proxy/proxy_test.go new file mode 100644 index 0000000..a13d6be --- /dev/null +++ b/stainless-proxy/internal/proxy/proxy_test.go @@ -0,0 +1,219 @@ +package proxy + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stainless-api/stainless-proxy/internal/jwe" + "github.com/stainless-api/stainless-proxy/internal/revocation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupProxy(t *testing.T) (*Proxy, *jwe.Encryptor, string) { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + kid := "test-key" + enc := jwe.NewEncryptor(&key.PublicKey, kid) + dec := jwe.NewMultiKeyDecryptor([]jwe.KeyEntry{{KID: kid, PrivateKey: key}}) + dl := revocation.NewDenyList() + p := New(dec, dl) + return p, enc, kid +} + +func mintToken(t *testing.T, enc *jwe.Encryptor, hosts []string, creds []jwe.Credential) string { + t.Helper() + payload := jwe.Payload{ + Exp: time.Now().Add(time.Hour).Unix(), + AllowedHosts: hosts, + Credentials: creds, + } + token, err := enc.Encrypt(payload) + require.NoError(t, err) + return token +} + +func TestProxyInjectsCredentials(t *testing.T) { + p, enc, _ := setupProxy(t) + + // Backend that echoes headers + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"api_key":"` + r.Header.Get("X-Api-Key") + `","app_key":"` + r.Header.Get("X-App-Key") + `"}`)) + })) + defer backend.Close() + + // The backend URL looks like https://127.0.0.1:PORT + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := mintToken(t, enc, []string{backendHost}, []jwe.Credential{ + {Header: "X-Api-Key", Value: "secret-api-key"}, + {Header: "X-App-Key", Value: "secret-app-key"}, + }) + + // Use the backend's TLS client + p.client = backend.Client() + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/api/v1/data", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + body := w.Body.String() + assert.Contains(t, body, "secret-api-key") + assert.Contains(t, body, "secret-app-key") +} + +func TestProxyBlocksUnallowedHost(t *testing.T) { + p, enc, _ := setupProxy(t) + + token := mintToken(t, enc, []string{"allowed.com"}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }) + + req := httptest.NewRequest("GET", "/https/evil.com/steal", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestProxyRejectsNoAuth(t *testing.T) { + p, _, _ := setupProxy(t) + + req := httptest.NewRequest("GET", "/https/api.example.com/data", nil) + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestProxyRejectsRevokedToken(t *testing.T) { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + kid := "test-key" + enc := jwe.NewEncryptor(&key.PublicKey, kid) + dec := jwe.NewMultiKeyDecryptor([]jwe.KeyEntry{{KID: kid, PrivateKey: key}}) + dl := revocation.NewDenyList() + p := New(dec, dl) + + token := mintToken(t, enc, []string{"api.example.com"}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer x"}, + }) + + // Revoke the token + hash := hashJWE(token) + dl.Add(hash, time.Now().Add(time.Hour)) + + req := httptest.NewRequest("GET", "/https/api.example.com/data", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestProxyStripsAuthHeader(t *testing.T) { + p, enc, _ := setupProxy(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The original Authorization header (with JWE) should be stripped + // and replaced with the credential from the JWE + auth := r.Header.Get("Authorization") + w.Write([]byte(auth)) + })) + defer backend.Close() + + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := mintToken(t, enc, []string{backendHost}, []jwe.Credential{ + {Header: "Authorization", Value: "Bearer real-api-token"}, + }) + + p.client = backend.Client() + + req := httptest.NewRequest("GET", "/https/"+backendHost+"/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + body := w.Body.String() + assert.Equal(t, "Bearer real-api-token", body) +} + +func TestProxyForwardsBody(t *testing.T) { + p, enc, _ := setupProxy(t) + + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Write(body) + })) + defer backend.Close() + + backendHost := strings.TrimPrefix(backend.URL, "https://") + + token := mintToken(t, enc, []string{backendHost}, []jwe.Credential{ + {Header: "X-Key", Value: "val"}, + }) + + p.client = backend.Client() + + req := httptest.NewRequest("POST", "/https/"+backendHost+"/api", strings.NewReader(`{"query":"test"}`)) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + p.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, `{"query":"test"}`, w.Body.String()) +} + +func TestParseTargetFromPath(t *testing.T) { + tests := []struct { + name string + path string + wantScheme string + wantHost string + wantPath string + wantErr bool + }{ + {"full path", "/https/api.example.com/v1/users", "https", "api.example.com", "/v1/users", false}, + {"no trailing path", "/https/api.example.com", "https", "api.example.com", "/", false}, + {"with port", "/https/api.example.com:8080/data", "https", "api.example.com:8080", "/data", false}, + {"http scheme", "/http/localhost:3000/test", "http", "localhost:3000", "/test", false}, + {"invalid scheme", "/ftp/example.com/file", "", "", "", true}, + {"empty path", "/", "", "", "", true}, + {"missing host", "/https/", "", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme, host, path, err := parseTargetFromPath(tt.path) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantScheme, scheme) + assert.Equal(t, tt.wantHost, host) + assert.Equal(t, tt.wantPath, path) + }) + } +} diff --git a/stainless-proxy/internal/proxy/trace.go b/stainless-proxy/internal/proxy/trace.go new file mode 100644 index 0000000..389b6d8 --- /dev/null +++ b/stainless-proxy/internal/proxy/trace.go @@ -0,0 +1,44 @@ +package proxy + +import ( + "context" + "time" +) + +type RequestTrace struct { + Start time.Time + + // Phase durations + DenyListCheck time.Duration + JWEDecrypt time.Duration + PathParse time.Duration + HostMatch time.Duration + RequestBuild time.Duration + UpstreamRoundTrip time.Duration + ResponseStream time.Duration + + // Total wall-clock time for the entire request + Total time.Duration + + // Metadata + Method string + TargetHost string + TargetPath string + StatusCode int + CredentialCount int + RequestBodyBytes int64 + ResponseBodyBytes int64 +} + +type TraceCallback func(RequestTrace) + +type contextKey struct{} + +func traceFromContext(ctx context.Context) *RequestTrace { + t, _ := ctx.Value(contextKey{}).(*RequestTrace) + return t +} + +func contextWithTrace(ctx context.Context, t *RequestTrace) context.Context { + return context.WithValue(ctx, contextKey{}, t) +} diff --git a/stainless-proxy/internal/revocation/revocation.go b/stainless-proxy/internal/revocation/revocation.go new file mode 100644 index 0000000..2caaf7a --- /dev/null +++ b/stainless-proxy/internal/revocation/revocation.go @@ -0,0 +1,66 @@ +package revocation + +import ( + "context" + "sync" + "time" +) + +type DenyList struct { + mu sync.RWMutex + entries map[string]time.Time // JWE SHA-256 hash -> expiry +} + +func NewDenyList() *DenyList { + return &DenyList{ + entries: make(map[string]time.Time), + } +} + +func (d *DenyList) Add(hash string, expiresAt time.Time) { + d.mu.Lock() + d.entries[hash] = expiresAt + d.mu.Unlock() +} + +func (d *DenyList) IsRevoked(hash string) bool { + d.mu.RLock() + exp, ok := d.entries[hash] + d.mu.RUnlock() + if !ok { + return false + } + if time.Now().After(exp) { + d.mu.Lock() + delete(d.entries, hash) + d.mu.Unlock() + return false + } + return true +} + +func (d *DenyList) Cleanup() { + now := time.Now() + d.mu.Lock() + for hash, exp := range d.entries { + if now.After(exp) { + delete(d.entries, hash) + } + } + d.mu.Unlock() +} + +func (d *DenyList) StartCleanup(ctx context.Context, interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + d.Cleanup() + } + } + }() +} diff --git a/stainless-proxy/internal/revocation/revocation_test.go b/stainless-proxy/internal/revocation/revocation_test.go new file mode 100644 index 0000000..d127a75 --- /dev/null +++ b/stainless-proxy/internal/revocation/revocation_test.go @@ -0,0 +1,35 @@ +package revocation + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDenyList(t *testing.T) { + dl := NewDenyList() + + dl.Add("hash1", time.Now().Add(time.Hour)) + assert.True(t, dl.IsRevoked("hash1")) + assert.False(t, dl.IsRevoked("hash2")) +} + +func TestDenyListExpiry(t *testing.T) { + dl := NewDenyList() + + dl.Add("hash1", time.Now().Add(-time.Second)) + assert.False(t, dl.IsRevoked("hash1")) +} + +func TestDenyListCleanup(t *testing.T) { + dl := NewDenyList() + + dl.Add("expired", time.Now().Add(-time.Second)) + dl.Add("active", time.Now().Add(time.Hour)) + + dl.Cleanup() + + assert.False(t, dl.IsRevoked("expired")) + assert.True(t, dl.IsRevoked("active")) +} diff --git a/stainless-proxy/internal/secretutil/do_default.go b/stainless-proxy/internal/secretutil/do_default.go new file mode 100644 index 0000000..0cf28e2 --- /dev/null +++ b/stainless-proxy/internal/secretutil/do_default.go @@ -0,0 +1,5 @@ +//go:build !goexperiment.runtimesecret + +package secretutil + +func Do(f func()) { f() } diff --git a/stainless-proxy/internal/secretutil/do_experiment.go b/stainless-proxy/internal/secretutil/do_experiment.go new file mode 100644 index 0000000..efdc82f --- /dev/null +++ b/stainless-proxy/internal/secretutil/do_experiment.go @@ -0,0 +1,7 @@ +//go:build goexperiment.runtimesecret + +package secretutil + +import "runtime/secret" + +func Do(f func()) { secret.Do(f) } diff --git a/stainless-proxy/internal/server/handlers.go b/stainless-proxy/internal/server/handlers.go new file mode 100644 index 0000000..0bb90da --- /dev/null +++ b/stainless-proxy/internal/server/handlers.go @@ -0,0 +1,129 @@ +package server + +import ( + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/stainless-api/stainless-proxy/internal/jwe" + "github.com/stainless-api/stainless-proxy/internal/keystore" + "github.com/stainless-api/stainless-proxy/internal/revocation" +) + +type Handlers struct { + keyStore *keystore.KeyStore + denyList *revocation.DenyList +} + +func NewHandlers(ks *keystore.KeyStore, dl *revocation.DenyList) *Handlers { + return &Handlers{keyStore: ks, denyList: dl} +} + +func (h *Handlers) JWKS(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=3600") + json.NewEncoder(w).Encode(h.keyStore.JWKS()) +} + +func (h *Handlers) Health(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"ok"}`) +} + +type revokeRequest struct { + Hash string `json:"hash"` + ExpiresAt int64 `json:"expires_at"` +} + +func (h *Handlers) Revoke(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed) + return + } + + var req revokeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest) + return + } + + if req.Hash == "" { + http.Error(w, `{"error":"hash is required"}`, http.StatusBadRequest) + return + } + + expiresAt := time.Unix(req.ExpiresAt, 0) + if req.ExpiresAt == 0 { + expiresAt = time.Now().Add(24 * time.Hour) + } + + h.denyList.Add(req.Hash, expiresAt) + slog.Info("JWE revoked", "hash", req.Hash, "expires_at", expiresAt) + + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"revoked"}`) +} + +type mintRequest struct { + ExpDuration string `json:"exp_duration"` + AllowedHosts []string `json:"allowed_hosts"` + Credentials []jwe.Credential `json:"credentials"` +} + +type mintResponse struct { + Token string `json:"token"` + ExpiresAt string `json:"expires_at"` +} + +func (h *Handlers) Mint(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed) + return + } + + var req mintRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest) + return + } + + if len(req.AllowedHosts) == 0 { + http.Error(w, `{"error":"allowed_hosts is required"}`, http.StatusBadRequest) + return + } + if len(req.Credentials) == 0 { + http.Error(w, `{"error":"credentials is required"}`, http.StatusBadRequest) + return + } + + duration, err := time.ParseDuration(req.ExpDuration) + if err != nil { + http.Error(w, `{"error":"invalid exp_duration"}`, http.StatusBadRequest) + return + } + + expiresAt := time.Now().Add(duration) + payload := jwe.Payload{ + Exp: expiresAt.Unix(), + AllowedHosts: req.AllowedHosts, + Credentials: req.Credentials, + } + + primary := h.keyStore.PrimaryKey() + enc := jwe.NewEncryptor(primary.PublicKey, primary.KID) + + token, err := enc.Encrypt(payload) + if err != nil { + slog.Error("minting JWE failed", "error", err) + http.Error(w, `{"error":"encryption failed"}`, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mintResponse{ + Token: token, + ExpiresAt: expiresAt.UTC().Format(time.RFC3339), + }) +} diff --git a/stainless-proxy/internal/server/middleware.go b/stainless-proxy/internal/server/middleware.go new file mode 100644 index 0000000..63dd3e3 --- /dev/null +++ b/stainless-proxy/internal/server/middleware.go @@ -0,0 +1,90 @@ +package server + +import ( + "log/slog" + "net/http" + "time" + + "github.com/stainless-api/stainless-proxy/internal/config" +) + +type MiddlewareFunc func(http.Handler) http.Handler + +func ChainMiddleware(h http.Handler, middlewares ...MiddlewareFunc) http.Handler { + for _, mw := range middlewares { + h = mw(h) + } + return h +} + +func LoggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rw, r) + slog.Info("request", + "method", r.Method, + "path", r.URL.Path, + "status", rw.status, + "duration_ms", time.Since(start).Milliseconds(), + "remote_addr", r.RemoteAddr, + ) + }) +} + +func RecoverMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + slog.Error("panic recovered", "error", err) + http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} + +func MintAuthMiddleware(secret config.Secret) MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer "+string(secret) { + http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) + } +} + +type responseWriter struct { + http.ResponseWriter + status int + wroteHeader bool +} + +func (rw *responseWriter) WriteHeader(code int) { + if rw.wroteHeader { + return + } + rw.status = code + rw.wroteHeader = true + rw.ResponseWriter.WriteHeader(code) +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } + return rw.ResponseWriter.Write(b) +} + +func (rw *responseWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + +func (rw *responseWriter) Flush() { + if f, ok := rw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} diff --git a/stainless-proxy/internal/server/server.go b/stainless-proxy/internal/server/server.go new file mode 100644 index 0000000..ea3b87c --- /dev/null +++ b/stainless-proxy/internal/server/server.go @@ -0,0 +1,90 @@ +package server + +import ( + "context" + "errors" + "log/slog" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/stainless-api/stainless-proxy/internal/config" + "github.com/stainless-api/stainless-proxy/internal/keystore" + "github.com/stainless-api/stainless-proxy/internal/proxy" + "github.com/stainless-api/stainless-proxy/internal/revocation" +) + +type Server struct { + httpServer *http.Server + denyList *revocation.DenyList +} + +func New(cfg *config.Config, ks *keystore.KeyStore, p *proxy.Proxy, dl *revocation.DenyList) *Server { + mux := http.NewServeMux() + handlers := NewHandlers(ks, dl) + + mux.HandleFunc("GET /.well-known/jwks.json", handlers.JWKS) + mux.HandleFunc("GET /health", handlers.Health) + mux.Handle("POST /revoke", ChainMiddleware( + http.HandlerFunc(handlers.Revoke), + RecoverMiddleware, + LoggerMiddleware, + )) + + if cfg.MintEnabled { + mintHandler := ChainMiddleware( + http.HandlerFunc(handlers.Mint), + MintAuthMiddleware(cfg.MintSecret), + RecoverMiddleware, + LoggerMiddleware, + ) + mux.Handle("POST /mint", mintHandler) + slog.Info("mint endpoint enabled") + } + + // All other paths go to the proxy + mux.Handle("/", ChainMiddleware( + p, + RecoverMiddleware, + LoggerMiddleware, + )) + + return &Server{ + httpServer: &http.Server{ + Addr: cfg.Addr, + Handler: mux, + }, + denyList: dl, + } +} + +func (s *Server) Run(ctx context.Context) error { + s.denyList.StartCleanup(ctx, 5*time.Minute) + + errChan := make(chan error, 1) + go func() { + slog.Info("starting server", "addr", s.httpServer.Addr) + if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err + } + }() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + select { + case sig := <-sigChan: + slog.Info("received signal", "signal", sig) + case err := <-errChan: + return err + case <-ctx.Done(): + slog.Info("context cancelled") + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + return s.httpServer.Shutdown(shutdownCtx) +} From a7faa2bec7dbfc6a9a5d6bbba806f7817f85f846 Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Mon, 9 Feb 2026 23:17:38 +0100 Subject: [PATCH 2/2] Fix review findings: race condition, timing attack, hop-by-hop typo, http timeout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use constant-time comparison for mint secret auth - Make DenyList.IsRevoked read-only to fix race between RUnlock and Lock where a concurrent Add could be deleted; expired entry cleanup is left to the Cleanup goroutine - Add 30s timeout to JWKS fetch in mint CLI - Fix hop-by-hop header name: Trailers → Trailer per RFC 7230 --- stainless-proxy/cmd/mint/main.go | 3 ++- stainless-proxy/internal/proxy/proxy.go | 2 +- stainless-proxy/internal/revocation/revocation.go | 10 ++-------- stainless-proxy/internal/server/middleware.go | 4 +++- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/stainless-proxy/cmd/mint/main.go b/stainless-proxy/cmd/mint/main.go index 026dcee..bfe6b2a 100644 --- a/stainless-proxy/cmd/mint/main.go +++ b/stainless-proxy/cmd/mint/main.go @@ -51,7 +51,8 @@ func main() { } // Fetch JWKS - resp, err := http.Get(*jwksURL) + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Get(*jwksURL) if err != nil { fmt.Fprintf(os.Stderr, "error: fetching JWKS: %v\n", err) os.Exit(1) diff --git a/stainless-proxy/internal/proxy/proxy.go b/stainless-proxy/internal/proxy/proxy.go index 3a5f501..74b7d7b 100644 --- a/stainless-proxy/internal/proxy/proxy.go +++ b/stainless-proxy/internal/proxy/proxy.go @@ -261,7 +261,7 @@ func copyHeaders(dst, src http.Header) { func isHopByHop(header string) bool { switch header { case "Connection", "Keep-Alive", "Proxy-Authenticate", - "Proxy-Authorization", "Te", "Trailers", + "Proxy-Authorization", "Te", "Trailer", "Transfer-Encoding", "Upgrade": return true } diff --git a/stainless-proxy/internal/revocation/revocation.go b/stainless-proxy/internal/revocation/revocation.go index 2caaf7a..10f6682 100644 --- a/stainless-proxy/internal/revocation/revocation.go +++ b/stainless-proxy/internal/revocation/revocation.go @@ -25,18 +25,12 @@ func (d *DenyList) Add(hash string, expiresAt time.Time) { func (d *DenyList) IsRevoked(hash string) bool { d.mu.RLock() + defer d.mu.RUnlock() exp, ok := d.entries[hash] - d.mu.RUnlock() if !ok { return false } - if time.Now().After(exp) { - d.mu.Lock() - delete(d.entries, hash) - d.mu.Unlock() - return false - } - return true + return time.Now().Before(exp) } func (d *DenyList) Cleanup() { diff --git a/stainless-proxy/internal/server/middleware.go b/stainless-proxy/internal/server/middleware.go index 63dd3e3..5994989 100644 --- a/stainless-proxy/internal/server/middleware.go +++ b/stainless-proxy/internal/server/middleware.go @@ -1,6 +1,7 @@ package server import ( + "crypto/subtle" "log/slog" "net/http" "time" @@ -48,7 +49,8 @@ func MintAuthMiddleware(secret config.Secret) MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") - if auth != "Bearer "+string(secret) { + expected := "Bearer " + string(secret) + if subtle.ConstantTimeCompare([]byte(auth), []byte(expected)) != 1 { http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) return }