Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions cmd/spire-server/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package cli
import (
"context"
stdlog "log"
"os"
"slices"
"strings"

"github.com/mitchellh/cli"
"github.com/spiffe/spire/cmd/spire-server/cli/agent"
Expand All @@ -18,7 +21,9 @@ import (
"github.com/spiffe/spire/cmd/spire-server/cli/token"
"github.com/spiffe/spire/cmd/spire-server/cli/upstreamauthority"
"github.com/spiffe/spire/cmd/spire-server/cli/validate"
"github.com/spiffe/spire/cmd/spire-server/cli/wit"
"github.com/spiffe/spire/cmd/spire-server/cli/x509"
"github.com/spiffe/spire/pkg/common/fflag"
"github.com/spiffe/spire/pkg/common/log"
"github.com/spiffe/spire/pkg/common/version"
)
Expand Down Expand Up @@ -165,9 +170,33 @@ func (cc *CLI) Run(ctx context.Context, args []string) int {
},
}

addCommandsEnabledByFFlags(c.Commands)

exitStatus, err := c.Run()
if err != nil {
stdlog.Println(err)
}
return exitStatus
}

// addCommandsEnabledByFFlags adds commands that are currently available only
// through a feature flag.
// Feature flags support through the fflag package in SPIRE Server is
// designed to work only with the run command and the config file.
// Since feature flags are intended to be used by developers of a specific
// feature only, exposing them through command line arguments is not
// convenient. Instead, we use the SPIRE_SERVER_FFLAGS environment variable
// to read the configured SPIRE Server feature flags from the environment
// when other commands may be enabled through feature flags.
func addCommandsEnabledByFFlags(commands map[string]cli.CommandFactory) {
fflagsEnv := os.Getenv("SPIRE_SERVER_FFLAGS")
fflags := strings.Split(fflagsEnv, " ")

flagWITSVID := slices.Contains(fflags, string(fflag.FlagWITSVID))

if flagWITSVID {
commands["wit mint"] = func() (cli.Command, error) {
return wit.NewMintCommand(), nil
}
}
}
21 changes: 3 additions & 18 deletions cmd/spire-server/cli/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ func NewServerConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool
}

if c.Server.CAKeyType != "" {
keyType, err := keyTypeFromString(c.Server.CAKeyType)
keyType, err := keymanager.KeyTypeFromString(c.Server.CAKeyType)
if err != nil {
return nil, fmt.Errorf("error parsing ca_key_type: %w", err)
}
Expand All @@ -657,14 +657,14 @@ func NewServerConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool
}

if c.Server.JWTKeyType != "" {
sc.JWTKeyType, err = keyTypeFromString(c.Server.JWTKeyType)
sc.JWTKeyType, err = keymanager.KeyTypeFromString(c.Server.JWTKeyType)
if err != nil {
return nil, fmt.Errorf("error parsing jwt_key_type: %w", err)
}
}

if c.Server.Experimental.WITKeyType != "" {
sc.WITKeyType, err = keyTypeFromString(c.Server.Experimental.WITKeyType)
sc.WITKeyType, err = keymanager.KeyTypeFromString(c.Server.Experimental.WITKeyType)
if err != nil {
return nil, fmt.Errorf("error parsing wit_key_type: %w", err)
}
Expand Down Expand Up @@ -1081,21 +1081,6 @@ func defaultConfig() *Config {
}
}

func keyTypeFromString(s string) (keymanager.KeyType, error) {
switch strings.ToLower(s) {
case "rsa-2048":
return keymanager.RSA2048, nil
case "rsa-4096":
return keymanager.RSA4096, nil
case "ec-p256":
return keymanager.ECP256, nil
case "ec-p384":
return keymanager.ECP384, nil
default:
return keymanager.KeyTypeUnset, fmt.Errorf("key type %q is unknown; must be one of [rsa-2048, rsa-4096, ec-p256, ec-p384]", s)
}
}

// hasCompatibleTTL checks if we can guarantee the configured SVID TTL given the
// configured CA TTL. If we detect that a new SVID TTL may be cut short due to
// a scheduled CA rotation, this function will return false. This method should
Expand Down
205 changes: 205 additions & 0 deletions cmd/spire-server/cli/wit/mint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package wit

import (
"context"
"crypto"
"crypto/x509"
"errors"
"flag"
"fmt"
"time"

"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/mitchellh/cli"
"github.com/spiffe/go-spiffe/v2/spiffeid"
svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1"
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
serverutil "github.com/spiffe/spire/cmd/spire-server/util"
commoncli "github.com/spiffe/spire/pkg/common/cli"
"github.com/spiffe/spire/pkg/common/cliprinter"
"github.com/spiffe/spire/pkg/common/diskutil"
"github.com/spiffe/spire/pkg/common/jwtsvid"
"github.com/spiffe/spire/pkg/common/util"
"github.com/spiffe/spire/pkg/server/plugin/keymanager"
)

func NewMintCommand() cli.Command {
return newMintCommand(commoncli.DefaultEnv)
}

func newMintCommand(env *commoncli.Env) cli.Command {
return serverutil.AdaptCommand(env, &mintCommand{env: env})
}

// Test helper function, to have control over the workload key that is being generated
func newMintCommandWithKeyGenerator(env *commoncli.Env, workloadKeyGenerator func() (crypto.Signer, error)) cli.Command {
return serverutil.AdaptCommand(env, &mintCommand{env: env, workloadKeyGenerator: workloadKeyGenerator})
}

type mintCommand struct {
spiffeID string
keyType string
ttl time.Duration
write string
env *commoncli.Env
printer cliprinter.Printer
workloadKeyGenerator func() (crypto.Signer, error)
}

func (c *mintCommand) Name() string {
return "wit mint"
}

func (c *mintCommand) Synopsis() string {
return "Mints a WIT-SVID"
}

func (c *mintCommand) AppendFlags(fs *flag.FlagSet) {
fs.StringVar(&c.spiffeID, "spiffeID", "", "SPIFFE ID of the WIT-SVID")
fs.StringVar(&c.keyType, "keyType", "ec-p256", "Key type of the WIT-SVID")
fs.DurationVar(&c.ttl, "ttl", 0, "TTL of the WIT-SVID")
fs.StringVar(&c.write, "write", "", "Directory to write output to instead of stdout")
cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintMint)
}

type mintResult struct {
PrivateKey string `json:"private_key"`
Svid *types.WITSVID `json:"svid"`
}

func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient serverutil.ServerClient) error {
if c.spiffeID == "" {
return errors.New("spiffeID must be specified")
}
spiffeID, err := spiffeid.FromString(c.spiffeID)
if err != nil {
return err
}
ttl, err := ttlToSeconds(c.ttl)
if err != nil {
return fmt.Errorf("invalid value for TTL: %w", err)
}

keyType, err := keymanager.KeyTypeFromString(c.keyType)
if err != nil {
return fmt.Errorf("invalid key-type: %w", err)
}

if c.workloadKeyGenerator == nil {
c.workloadKeyGenerator = keyType.GenerateSigner
}

signer, err := c.workloadKeyGenerator()
if err != nil {
return fmt.Errorf("could not generate public/private key pair: %w", err)
}

publicKeyDer, err := x509.MarshalPKIXPublicKey(signer.Public())
if err != nil {
return fmt.Errorf("could not marshal public/private key pair: %w", err)
}

client := serverClient.NewSVIDClient()
resp, err := client.MintWITSVID(ctx, &svidv1.MintWITSVIDRequest{
Id: &types.SPIFFEID{
TrustDomain: spiffeID.TrustDomain().Name(),
Path: spiffeID.Path(),
},
PublicKey: publicKeyDer,
Ttl: ttl,
})
if err != nil {
return fmt.Errorf("unable to mint SVID: %w", err)
}
token := resp.Svid.Token
if err := c.validateToken(token, env); err != nil {
return err
}

jwk := jose.JSONWebKey{
Key: signer,
}
jwkJson, err := jwk.MarshalJSON()
if err != nil {
return fmt.Errorf("could not marshal private key: %w", err)
}

// Print in stdout
if c.write == "" {
return c.printer.PrintStruct(&mintResult{
PrivateKey: string(jwkJson),
Svid: resp.Svid,
})
}

tokenPath := env.JoinPath(c.write, "token")
keyPath := env.JoinPath(c.write, "key")

if err := diskutil.WritePrivateFile(tokenPath, []byte(token)); err != nil {
return fmt.Errorf("unable to write token: %w", err)
}
if err := env.Printf("WIT-SVID written to %s\n", tokenPath); err != nil {
return err
}
if err := diskutil.WritePrivateFile(keyPath, jwkJson); err != nil {
return fmt.Errorf("unable to write key: %w", err)
}
return env.Printf("Private key written to %s\n", keyPath)
}

func (c *mintCommand) validateToken(token string, env *commoncli.Env) error {
if token == "" {
return errors.New("server response missing token")
}

eol, err := getWITSVIDEndOfLife(token)
if err != nil {
env.ErrPrintf("Unable to determine WIT-SVID lifetime: %v\n", err)
return nil
}

if time.Until(eol) < c.ttl {
env.ErrPrintf("WIT-SVID lifetime was capped shorter than specified ttl; expires %q\n", eol.UTC().Format(time.RFC3339))
}

return nil
}

func getWITSVIDEndOfLife(token string) (time.Time, error) {
t, err := jwt.ParseSigned(token, jwtsvid.AllowedSignatureAlgorithms)
if err != nil {
return time.Time{}, err
}

claims := new(jwt.Claims)
if err := t.UnsafeClaimsWithoutVerification(claims); err != nil {
return time.Time{}, err
}

if claims.Expiry == nil {
return time.Time{}, errors.New("no expiry claim")
}

return claims.Expiry.Time(), nil
}

// ttlToSeconds returns the number of seconds in a duration, rounded up to
// the nearest second
func ttlToSeconds(ttl time.Duration) (int32, error) {
return util.CheckedCast[int32]((ttl + time.Second - 1) / time.Second)
}

func prettyPrintMint(env *commoncli.Env, results ...any) error {
resultInterface, ok := results[0].([]any)
if !ok {
return cliprinter.ErrInternalCustomPrettyFunc
}

if wit, ok := resultInterface[0].(*mintResult); ok {
errToken := env.Println(wit.Svid.Token)
errKey := env.Println(wit.PrivateKey)
return errors.Join(errToken, errKey)
}
return cliprinter.ErrInternalCustomPrettyFunc
}
Loading
Loading