Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions stainless-proxy/cmd/mint/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package main

import (
"encoding/json"
"flag"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/go-jose/go-jose/v4"
"github.com/stainless-api/stainless-proxy/internal/jwe"

"crypto/ecdsa"
)

type credFlag []string

func (f *credFlag) String() string { return strings.Join(*f, ", ") }
func (f *credFlag) Set(v string) error {
*f = append(*f, v)
return nil
}

func main() {
jwksURL := flag.String("jwks-url", "", "URL to fetch JWKS from")
expDuration := flag.String("exp", "1h", "expiration duration")
hosts := flag.String("hosts", "", "comma-separated allowed hosts")
var creds credFlag
flag.Var(&creds, "cred", "credential as Header=Value (repeatable)")
flag.Parse()

if *jwksURL == "" {
fmt.Fprintln(os.Stderr, "error: -jwks-url is required")
os.Exit(1)
}
if *hosts == "" {
fmt.Fprintln(os.Stderr, "error: -hosts is required")
os.Exit(1)
}
if len(creds) == 0 {
fmt.Fprintln(os.Stderr, "error: at least one -cred is required")
os.Exit(1)
}

duration, err := time.ParseDuration(*expDuration)
if err != nil {
fmt.Fprintf(os.Stderr, "error: invalid expiration: %v\n", err)
os.Exit(1)
}

// Fetch JWKS
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Get(*jwksURL)
if err != nil {
fmt.Fprintf(os.Stderr, "error: fetching JWKS: %v\n", err)
os.Exit(1)
}
defer resp.Body.Close()

var jwks jose.JSONWebKeySet
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
fmt.Fprintf(os.Stderr, "error: parsing JWKS: %v\n", err)
os.Exit(1)
}

if len(jwks.Keys) == 0 {
fmt.Fprintln(os.Stderr, "error: no keys in JWKS")
os.Exit(1)
}

// Use the first key
jwk := jwks.Keys[0]
pubKey, ok := jwk.Key.(*ecdsa.PublicKey)
if !ok {
fmt.Fprintln(os.Stderr, "error: first key is not an ECDSA public key")
os.Exit(1)
}

// Parse credentials
var credentials []jwe.Credential
for _, c := range creds {
eqIdx := strings.IndexByte(c, '=')
if eqIdx == -1 {
fmt.Fprintf(os.Stderr, "error: invalid credential format: %s (expected Header=Value)\n", c)
os.Exit(1)
}
credentials = append(credentials, jwe.Credential{
Header: c[:eqIdx],
Value: c[eqIdx+1:],
})
}

payload := jwe.Payload{
Exp: time.Now().Add(duration).Unix(),
AllowedHosts: strings.Split(*hosts, ","),
Credentials: credentials,
}

enc := jwe.NewEncryptor(pubKey, jwk.KeyID)
token, err := enc.Encrypt(payload)
if err != nil {
fmt.Fprintf(os.Stderr, "error: encrypting: %v\n", err)
os.Exit(1)
}

fmt.Println(token)
}
88 changes: 88 additions & 0 deletions stainless-proxy/cmd/stainless-proxy/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package main

import (
"context"
"flag"
"log/slog"
"os"

"github.com/stainless-api/stainless-proxy/internal/config"
"github.com/stainless-api/stainless-proxy/internal/jwe"
"github.com/stainless-api/stainless-proxy/internal/keystore"
"github.com/stainless-api/stainless-proxy/internal/proxy"
"github.com/stainless-api/stainless-proxy/internal/revocation"
"github.com/stainless-api/stainless-proxy/internal/server"
)

func main() {
configPath := flag.String("config", "", "path to config file")
flag.Parse()

if *configPath == "" {
slog.Error("config flag is required")
os.Exit(1)
}

cfg, err := config.Load(*configPath)
if err != nil {
slog.Error("loading config", "error", err)
os.Exit(1)
}

setupLogging(cfg)

ks, err := keystore.New(cfg.KeyDir, cfg.GenerateKeys)
if err != nil {
slog.Error("initializing keystore", "error", err)
os.Exit(1)
}

primary := ks.PrimaryKey()
slog.Info("keystore initialized",
"key_count", len(ks.Keys()),
"primary_kid", primary.KID,
)

var decryptorKeys []jwe.KeyEntry
for _, k := range ks.Keys() {
decryptorKeys = append(decryptorKeys, jwe.KeyEntry{
KID: k.KID,
PrivateKey: k.PrivateKey,
})
}

decryptor := jwe.NewMultiKeyDecryptor(decryptorKeys)
denyList := revocation.NewDenyList()
p := proxy.New(decryptor, denyList)
srv := server.New(cfg, ks, p, denyList)

if err := srv.Run(context.Background()); err != nil {
slog.Error("server error", "error", err)
os.Exit(1)
}
}

func setupLogging(cfg *config.Config) {
var level slog.Level
switch cfg.LogLevel {
case "debug":
level = slog.LevelDebug
case "warn":
level = slog.LevelWarn
case "error":
level = slog.LevelError
default:
level = slog.LevelInfo
}

opts := &slog.HandlerOptions{Level: level}

var handler slog.Handler
if cfg.LogFormat == "json" {
handler = slog.NewJSONHandler(os.Stderr, opts)
} else {
handler = slog.NewTextHandler(os.Stderr, opts)
}

slog.SetDefault(slog.New(handler))
}
9 changes: 9 additions & 0 deletions stainless-proxy/config.example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"addr": ":8443",
"keyDir": "./keys",
"generateKeys": true,
"mintEnabled": true,
"mintSecret": {"$env": "MINT_SECRET"},
"logLevel": "debug",
"logFormat": "text"
}
14 changes: 14 additions & 0 deletions stainless-proxy/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module github.com/stainless-api/stainless-proxy

go 1.26rc1

require (
github.com/go-jose/go-jose/v4 v4.1.3
github.com/stretchr/testify v1.11.1
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
12 changes: 12 additions & 0 deletions stainless-proxy/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
112 changes: 112 additions & 0 deletions stainless-proxy/internal/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package config

import (
"encoding/json"
"fmt"
"os"
)

type Secret string

func (s Secret) String() string {
if s == "" {
return ""
}
return "***"
}

func (s Secret) MarshalJSON() ([]byte, error) {
if s == "" {
return json.Marshal("")
}
return json.Marshal("***")
}

type Config struct {
Addr string `json:"addr"`
KeyDir string `json:"keyDir"`
GenerateKeys bool `json:"generateKeys"`
MintEnabled bool `json:"mintEnabled"`
MintSecret Secret `json:"mintSecret"`
LogLevel string `json:"logLevel"`
LogFormat string `json:"logFormat"`
}

func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading config: %w", err)
}

var raw rawConfig
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing config: %w", err)
}

cfg := &Config{
Addr: raw.Addr,
KeyDir: raw.KeyDir,
GenerateKeys: raw.GenerateKeys,
MintEnabled: raw.MintEnabled,
LogLevel: raw.LogLevel,
LogFormat: raw.LogFormat,
}

if raw.MintSecret != nil {
secret, err := parseConfigValue(raw.MintSecret)
if err != nil {
return nil, fmt.Errorf("parsing mintSecret: %w", err)
}
cfg.MintSecret = Secret(secret)
}

if cfg.Addr == "" {
cfg.Addr = ":8443"
}
if cfg.LogLevel == "" {
cfg.LogLevel = "info"
}
if cfg.LogFormat == "" {
cfg.LogFormat = "text"
}

return cfg, nil
}

type rawConfig struct {
Addr string `json:"addr"`
KeyDir string `json:"keyDir"`
GenerateKeys bool `json:"generateKeys"`
MintEnabled bool `json:"mintEnabled"`
MintSecret json.RawMessage `json:"mintSecret"`
LogLevel string `json:"logLevel"`
LogFormat string `json:"logFormat"`
}

func parseConfigValue(raw json.RawMessage) (string, error) {
var str string
if err := json.Unmarshal(raw, &str); err == nil {
return str, nil
}

var ref map[string]string
if err := json.Unmarshal(raw, &ref); err != nil {
return "", fmt.Errorf("config value must be string or reference object")
}

if envVar, ok := ref["$env"]; ok {
value := os.Getenv(envVar)
if value == "" {
return "", fmt.Errorf("environment variable %s not set", envVar)
}
if len(value) >= 2 {
if (value[0] == '"' && value[len(value)-1] == '"') ||
(value[0] == '\'' && value[len(value)-1] == '\'') {
value = value[1 : len(value)-1]
}
Comment on lines +103 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to trim quotes from environment variable values can be simplified and made more robust by using strings.Trim. This avoids manual index checking and handles both single and double quotes more cleanly.

value = strings.Trim(value, "\"'")

}
return value, nil
}

return "", fmt.Errorf("unknown reference type in config value")
}
31 changes: 31 additions & 0 deletions stainless-proxy/internal/hostmatch/hostmatch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package hostmatch

import (
"net"
"strings"
)

func Match(host string, patterns []string) bool {
host = stripPort(strings.ToLower(host))
for _, p := range patterns {
p = stripPort(strings.ToLower(p))
if p == host {
return true
}
if strings.HasPrefix(p, "*.") {
suffix := p[1:] // ".example.com"
if strings.HasSuffix(host, suffix) && host != suffix[1:] {
return true
}
}
}
return false
}

func stripPort(host string) string {
h, _, err := net.SplitHostPort(host)
if err != nil {
return host
}
return h
}
Loading
Loading