From 03c303c79900e9a2f24e7b73427f217952794702 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Sat, 22 Feb 2025 20:40:02 +0000 Subject: [PATCH 1/6] credbuilder: add workload WIT-SVID support Signed-off-by: Sorin Dumitru --- pkg/server/credtemplate/builder.go | 40 ++++++++++ pkg/server/credtemplate/builder_test.go | 101 ++++++++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/pkg/server/credtemplate/builder.go b/pkg/server/credtemplate/builder.go index 76d7cde2dd..db1f7a82a4 100644 --- a/pkg/server/credtemplate/builder.go +++ b/pkg/server/credtemplate/builder.go @@ -12,6 +12,7 @@ import ( "time" "github.com/andres-erbsen/clock" + "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/idutil" @@ -34,6 +35,10 @@ const ( // not provided in the signing request. DefaultJWTSVIDTTL = time.Minute * 5 + // DefaultWITSVIDTTL is the TTL given to WIT-SVIDs if a different TTL is + // not provided in the signing request. + DefaultWITSVIDTTL = time.Hour + // NotBeforeCushion is how much of a cushion to subtract from the current // time when determining the notBefore field of certificates to account // for clock skew. @@ -99,6 +104,13 @@ type WorkloadJWTSVIDParams struct { ExpirationCap time.Time } +type WorkloadWITSVIDParams struct { + SPIFFEID spiffeid.ID + PublicKey jose.JSONWebKey + TTL time.Duration + ExpirationCap time.Time +} + type Config struct { TrustDomain spiffeid.TrustDomain Clock clock.Clock @@ -108,6 +120,7 @@ type Config struct { X509SVIDTTL time.Duration JWTSVIDTTL time.Duration JWTIssuer string + WITSVIDTTL time.Duration AgentSVIDTTL time.Duration CredentialComposers []credentialcomposer.CredentialComposer NewSerialNumber func() (*big.Int, error) @@ -143,6 +156,9 @@ func NewBuilder(config Config) (*Builder, error) { if config.JWTSVIDTTL == 0 { config.JWTSVIDTTL = DefaultJWTSVIDTTL } + if config.WITSVIDTTL == 0 { + config.WITSVIDTTL = DefaultWITSVIDTTL + } if config.AgentSVIDTTL == 0 { // config.X509SVIDTTL should be initialized by the code above and // therefore safe to use to initialize the AgentSVIDTTL. @@ -357,6 +373,30 @@ func (b *Builder) BuildWorkloadJWTSVIDClaims(ctx context.Context, params Workloa return attributes.Claims, nil } +func (b *Builder) BuildWorkloadWITSVIDClaims(ctx context.Context, params WorkloadWITSVIDParams) (map[string]any, error) { + if params.SPIFFEID.IsZero() { + return nil, errors.New("invalid WIT-SVID ID: cannot be empty") + } + if err := api.VerifyTrustDomainMemberID(b.config.TrustDomain, params.SPIFFEID); err != nil { + return nil, fmt.Errorf("invalid WIT-SVID ID: %w", err) + } + + ttl := params.TTL + if ttl <= 0 { + ttl = b.config.WITSVIDTTL + } + + now := b.config.Clock.Now() + _, expiresAt := computeCappedLifetime(b.config.Clock, ttl, params.ExpirationCap) + + return map[string]any{ + "sub": params.SPIFFEID.String(), + "exp": jwt.NewNumericDate(expiresAt), + "iat": jwt.NewNumericDate(now), + "cnf": params.PublicKey, + }, nil +} + func (b *Builder) buildX509CATemplate(publicKey crypto.PublicKey, parentChain []*x509.Certificate, ttl time.Duration) (*x509.Certificate, error) { tmpl, err := b.buildBaseTemplate(b.x509CAID, publicKey, parentChain) if err != nil { diff --git a/pkg/server/credtemplate/builder_test.go b/pkg/server/credtemplate/builder_test.go index e7d97af67b..2bed64d10a 100644 --- a/pkg/server/credtemplate/builder_test.go +++ b/pkg/server/credtemplate/builder_test.go @@ -3,6 +3,9 @@ package credtemplate_test import ( "context" "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" "crypto/x509/pkix" "errors" @@ -13,6 +16,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "github.com/spiffe/go-spiffe/v2/spiffeid" credentialcomposerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/credentialcomposer/v1" @@ -44,6 +48,7 @@ var ( x509CANotAfter = now.Add(credtemplate.DefaultX509CATTL) x509SVIDNotAfter = now.Add(credtemplate.DefaultX509SVIDTTL) jwtSVIDNotAfter = now.Add(credtemplate.DefaultJWTSVIDTTL) + witSVIDNotAfter = now.Add(credtemplate.DefaultWITSVIDTTL) caKeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign svidKeyUsage = x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageDigitalSignature svidExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} @@ -81,6 +86,7 @@ func TestNewBuilderSetsDefaults(t *testing.T) { X509SVIDSubject: credtemplate.DefaultX509SVIDSubject(), X509SVIDTTL: credtemplate.DefaultX509SVIDTTL, JWTSVIDTTL: credtemplate.DefaultJWTSVIDTTL, + WITSVIDTTL: credtemplate.DefaultWITSVIDTTL, JWTIssuer: "", AgentSVIDTTL: credtemplate.DefaultX509SVIDTTL, }, config) @@ -94,6 +100,7 @@ func TestNewBuilderAllowsConfigOverrides(t *testing.T) { X509CATTL: 1 * time.Minute, X509SVIDTTL: 2 * time.Minute, JWTSVIDTTL: 3 * time.Minute, + WITSVIDTTL: 4 * time.Minute, JWTIssuer: "ISSUER", AgentSVIDTTL: 4 * time.Minute, } @@ -1207,6 +1214,100 @@ func TestBuildWorkloadJWTSVIDClaims(t *testing.T) { } } +func TestBuildWorkloadWITSVIDClaims(t *testing.T) { + for _, tc := range []struct { + desc string + overrideConfig func(config *credtemplate.Config) + overrideParams func(params *credtemplate.WorkloadWITSVIDParams) + overrideExpected func(expected map[string]any) + expectErr string + }{ + { + desc: "defaults", + }, + { + desc: "empty SPIFFE ID", + overrideParams: func(params *credtemplate.WorkloadWITSVIDParams) { + params.SPIFFEID = spiffeid.ID{} + }, + expectErr: "invalid WIT-SVID ID: cannot be empty", + }, + { + desc: "SPIFFE ID from another trust domain", + overrideParams: func(params *credtemplate.WorkloadWITSVIDParams) { + params.SPIFFEID = spiffeid.RequireFromString("spiffe://otherdomain.test/spire/agent/foo/foo-1") + }, + expectErr: `invalid WIT-SVID ID: "spiffe://otherdomain.test/spire/agent/foo/foo-1" is not a member of trust domain "domain.test"`, + }, + { + desc: "override WITSVIDTTL", + overrideConfig: func(config *credtemplate.Config) { + config.WITSVIDTTL = credtemplate.DefaultWITSVIDTTL * 2 + }, + overrideExpected: func(expected map[string]any) { + expected["exp"] = jwt.NewNumericDate(now.Add(credtemplate.DefaultWITSVIDTTL * 2)) + }, + }, + { + desc: "ttl capped by expiration cap", + overrideConfig: func(config *credtemplate.Config) { + config.WITSVIDTTL = parentTTL + time.Hour + }, + overrideParams: func(params *credtemplate.WorkloadWITSVIDParams) { + params.ExpirationCap = now.Add(parentTTL) + }, + overrideExpected: func(expected map[string]any) { + expected["exp"] = jwt.NewNumericDate(now.Add(parentTTL)) + }, + }, + { + desc: "with ttl", + overrideParams: func(params *credtemplate.WorkloadWITSVIDParams) { + params.TTL = credtemplate.DefaultWITSVIDTTL / 2 + }, + overrideExpected: func(expected map[string]any) { + expected["exp"] = jwt.NewNumericDate(now.Add(credtemplate.DefaultWITSVIDTTL / 2)) + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + testBuilder(t, tc.overrideConfig, func(t *testing.T, credBuilder *credtemplate.Builder) { + signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + params := credtemplate.WorkloadWITSVIDParams{ + SPIFFEID: workloadID, + PublicKey: jose.JSONWebKey{ + Key: signer.PublicKey, + }, + } + if tc.overrideParams != nil { + tc.overrideParams(¶ms) + } + template, err := credBuilder.BuildWorkloadWITSVIDClaims(ctx, params) + if tc.expectErr != "" { + require.EqualError(t, err, tc.expectErr) + return + } + require.NoError(t, err) + + expected := map[string]any{ + "exp": jwt.NewNumericDate(witSVIDNotAfter), + "iat": jwt.NewNumericDate(now), + "sub": workloadID.String(), + "cnf": jose.JSONWebKey{ + Key: signer.PublicKey, + }, + } + if tc.overrideExpected != nil { + tc.overrideExpected(expected) + } + require.Equal(t, expected, template) + }) + }) + } +} + func testBuilder(t *testing.T, overrideConfig func(config *credtemplate.Config), fn func(*testing.T, *credtemplate.Builder)) { config := credtemplate.Config{ TrustDomain: td, From 4de57b683ff1c0fe1334413931ed803afea3bd02 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Sun, 23 Feb 2025 07:29:55 +0000 Subject: [PATCH 2/6] ca: support for signing workload WIT-SVIDs Signed-off-by: Sorin Dumitru --- pkg/server/ca/ca.go | 85 +++++++++++++++++++++++++++++ pkg/server/ca/ca_test.go | 72 +++++++++++++++++++++++- test/fakes/fakeserverca/serverca.go | 36 ++++++++++++ 3 files changed, 192 insertions(+), 1 deletion(-) diff --git a/pkg/server/ca/ca.go b/pkg/server/ca/ca.go index e749c7c3a9..98d8ba4e9f 100644 --- a/pkg/server/ca/ca.go +++ b/pkg/server/ca/ca.go @@ -7,6 +7,7 @@ import ( "crypto/x509/pkix" "errors" "fmt" + "slices" "sync" "time" @@ -29,6 +30,14 @@ const ( backdate = 10 * time.Second ) +var ( + witWorkloadKeyAllowedAlgorithms = []jose.SignatureAlgorithm{ + jose.RS256, // RSA with 2048 bits key or higher + jose.ES256, + jose.ES384, + } +) + // ServerCA is an interface for Server CAs type ServerCA interface { SignDownstreamX509CA(ctx context.Context, params DownstreamX509CAParams) ([]*x509.Certificate, error) @@ -36,8 +45,10 @@ type ServerCA interface { SignAgentX509SVID(ctx context.Context, params AgentX509SVIDParams) ([]*x509.Certificate, error) SignWorkloadX509SVID(ctx context.Context, params WorkloadX509SVIDParams) ([]*x509.Certificate, error) SignWorkloadJWTSVID(ctx context.Context, params WorkloadJWTSVIDParams) (string, error) + SignWorkloadWITSVID(ctx context.Context, params WorkloadWITSVIDParams) (string, error) TaintedAuthorities() <-chan []*x509.Certificate IsJWTSVIDsDisabled() bool + IsWITSVIDsDisabled() bool } // DownstreamX509CAParams are parameters relevant to downstream X.509 CA creation @@ -98,6 +109,19 @@ type WorkloadJWTSVIDParams struct { Audience []string } +// WorkloadWITSVIDParams are parameters relevant to workload WIT-SVID creation +type WorkloadWITSVIDParams struct { + // SPIFFE ID of the SVID + SPIFFEID spiffeid.ID + + // TTL is the desired time-to-live of the SVID. Regardless of the TTL, the + // lifetime of the token will be capped to that of the signing key. + TTL time.Duration + + // PublicKey is used for the cnf claim + PublicKey jose.JSONWebKey +} + type X509CA struct { // Signer is used to sign child certificates. Signer crypto.Signer @@ -372,6 +396,39 @@ func (ca *CA) SignWorkloadJWTSVID(ctx context.Context, params WorkloadJWTSVIDPar return token, nil } +func (ca *CA) SignWorkloadWITSVID(ctx context.Context, params WorkloadWITSVIDParams) (string, error) { + witKey := ca.WITKey() + if witKey == nil { + return "", errors.New("WIT key is not available for signing") + } + + workloadKeyAlg, err := cryptoutil.JoseAlgFromPublicKey(params.PublicKey.Key) + if err != nil { + return "", fmt.Errorf("could not determined workload key algorithm: %w", err) + } + + if slices.Index(witWorkloadKeyAllowedAlgorithms, workloadKeyAlg) == -1 { + return "", fmt.Errorf("workload key type '%q' not supported", workloadKeyAlg) + } + + claims, err := ca.c.CredBuilder.BuildWorkloadWITSVIDClaims(ctx, credtemplate.WorkloadWITSVIDParams{ + SPIFFEID: params.SPIFFEID, + PublicKey: params.PublicKey, + TTL: params.TTL, + ExpirationCap: witKey.NotAfter, + }) + if err != nil { + return "", err + } + + token, err := ca.signWITSVID(witKey, claims) + if err != nil { + return "", fmt.Errorf("unable to sign JWT SVID: %w", err) + } + + return token, nil +} + func (ca *CA) getX509CA() (*X509CA, []*x509.Certificate, error) { ca.mu.RLock() defer ca.mu.RUnlock() @@ -426,6 +483,34 @@ func (ca *CA) IsWITSVIDsDisabled() bool { return ca.c.DisableWITSVIDs } +func (ca *CA) signWITSVID(witKey *WITKey, claims map[string]any) (string, error) { + alg, err := cryptoutil.JoseAlgFromPublicKey(witKey.Signer.Public()) + if err != nil { + return "", fmt.Errorf("failed to determine WIT key algorithm: %w", err) + } + + jwtSigner, err := jose.NewSigner( + jose.SigningKey{ + Algorithm: alg, + Key: jose.JSONWebKey{ + Key: cryptosigner.Opaque(witKey.Signer), + KeyID: witKey.Kid, + }, + }, + new(jose.SignerOptions).WithType("wit+jwt"), + ) + if err != nil { + return "", fmt.Errorf("failed to configure WIT signer: %w", err) + } + + signedToken, err := jwt.Signed(jwtSigner).Claims(claims).Serialize() + if err != nil { + return "", fmt.Errorf("failed to sign WIT SVID: %w", err) + } + + return signedToken, nil +} + func makeCertChain(x509CA *X509CA, leaf *x509.Certificate) []*x509.Certificate { return append([]*x509.Certificate{leaf}, x509CA.UpstreamChain...) } diff --git a/pkg/server/ca/ca_test.go b/pkg/server/ca/ca_test.go index 2070e5895c..3d7425c37f 100644 --- a/pkg/server/ca/ca_test.go +++ b/pkg/server/ca/ca_test.go @@ -3,12 +3,14 @@ package ca import ( "context" "crypto/rand" + "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "math/big" "testing" "time" + "github.com/go-jose/go-jose/v4" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/health" @@ -520,6 +522,64 @@ func (s *CATestSuite) TestSignDownstreamX509CANoCASet() { s.Require().EqualError(err, "X509 CA is not available for signing") } +func (s *CATestSuite) TestNoWITKeySet() { + s.ca.SetWITKey(nil) + _, err := s.ca.SignWorkloadWITSVID(ctx, s.createWITSVIDParams(trustDomainExample, 0)) + s.Require().EqualError(err, "WIT key is not available for signing") +} + +func (s *CATestSuite) TestSignWorkloadWITSVIDUsesDefaultTTLIfTTLUnspecified() { + token, err := s.ca.SignWorkloadWITSVID(ctx, s.createWITSVIDParams(trustDomainExample, 0)) + s.Require().NoError(err) + issuedAt, expiresAt, err := jwtsvid.GetTokenExpiry(token) + s.Require().NoError(err) + s.Require().Equal(s.clock.Now(), issuedAt) + s.Require().Equal(s.clock.Now().Add(credtemplate.DefaultWITSVIDTTL), expiresAt) +} + +func (s *CATestSuite) TestSignWorkloadWITSVIDUsesTTLIfSpecified() { + token, err := s.ca.SignWorkloadWITSVID(ctx, s.createWITSVIDParams(trustDomainExample, time.Minute+time.Second)) + s.Require().NoError(err) + issuedAt, expiresAt, err := jwtsvid.GetTokenExpiry(token) + s.Require().NoError(err) + s.Require().Equal(s.clock.Now(), issuedAt) + s.Require().Equal(s.clock.Now().Add(time.Minute+time.Second), expiresAt) +} + +func (s *CATestSuite) TestSignWorkloadWITSVIDCapsTTLToKeyExpiry() { + token, err := s.ca.SignWorkloadWITSVID(ctx, s.createWITSVIDParams(trustDomainExample, 48*time.Hour)) + s.Require().NoError(err) + issuedAt, expiresAt, err := jwtsvid.GetTokenExpiry(token) + s.Require().NoError(err) + s.Require().Equal(s.clock.Now(), issuedAt) + s.Require().Equal(s.clock.Now().Add(24*time.Hour), expiresAt) +} + +func (s *CATestSuite) TestSignWorkloadWITSVIDValidation() { + // spiffe id for wrong trust domain + _, err := s.ca.SignWorkloadWITSVID(ctx, s.createWITSVIDParams(trustDomainFoo, 0)) + s.Require().EqualError(err, `invalid WIT-SVID ID: "spiffe://foo.com/workload" is not a member of trust domain "example.org"`) + + // validates public key + params := s.createWITSVIDParams(trustDomainExample, 0) + params.PublicKey = jose.JSONWebKey{ + Key: "invalid", + } + _, err = s.ca.SignWorkloadWITSVID(ctx, params) + s.Require().EqualError(err, "could not determined workload key algorithm: unable to determine signature algorithm for public key type string") + + // Validate key algorithm + params = s.createWITSVIDParams(trustDomainExample, 0) + rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) //nolint:gosec + s.Require().NoError(err) + + params.PublicKey = jose.JSONWebKey{ + Key: &rsaKey.PublicKey, + } + _, err = s.ca.SignWorkloadWITSVID(ctx, params) + s.Require().EqualError(err, "could not determined workload key algorithm: unsupported RSA key size: 128") +} + func (s *CATestSuite) TestSignDownstreamX509CA() { svidChain, err := s.ca.SignDownstreamX509CA(ctx, s.createDownstreamX509CAParams()) s.Require().NoError(err) @@ -604,7 +664,7 @@ func (s *CATestSuite) setWITKey() { s.ca.SetWITKey(&WITKey{ Signer: testSigner, Kid: "KID", - NotAfter: s.clock.Now().Add(10 * time.Minute), + NotAfter: s.clock.Now().Add(24 * time.Hour), }) } @@ -650,6 +710,16 @@ func (s *CATestSuite) createJWTSVIDParams(trustDomain spiffeid.TrustDomain, ttl } } +func (s *CATestSuite) createWITSVIDParams(trustDomain spiffeid.TrustDomain, ttl time.Duration) WorkloadWITSVIDParams { + return WorkloadWITSVIDParams{ + SPIFFEID: spiffeid.RequireFromPath(trustDomain, "/workload"), + TTL: ttl, + PublicKey: jose.JSONWebKey{ + Key: testSigner.Public(), + }, + } +} + func (s *CATestSuite) createCACertificate(cn string, parent *x509.Certificate) *x509.Certificate { return createCACertificate(s.T(), s.clock, cn, parent) } diff --git a/test/fakes/fakeserverca/serverca.go b/test/fakes/fakeserverca/serverca.go index e8cfc87192..d6ffcd0f0d 100644 --- a/test/fakes/fakeserverca/serverca.go +++ b/test/fakes/fakeserverca/serverca.go @@ -28,7 +28,9 @@ type Options struct { AgentSVIDTTL time.Duration X509SVIDTTL time.Duration JWTSVIDTTL time.Duration + WITSVIDTTL time.Duration DisableJWTSVIDs bool + DisableWITSVIDs bool } type CA struct { @@ -39,6 +41,7 @@ type CA struct { bundle []*x509.Certificate err error disableJWTSVIDs bool + disableWITSVIDs bool } func New(t *testing.T, trustDomain spiffeid.TrustDomain, options *Options) *CA { @@ -57,6 +60,9 @@ func New(t *testing.T, trustDomain spiffeid.TrustDomain, options *Options) *CA { if options.JWTSVIDTTL == 0 { options.JWTSVIDTTL = time.Minute } + if options.WITSVIDTTL == 0 { + options.WITSVIDTTL = time.Minute + } log, _ := test.NewNullLogger() @@ -69,6 +75,7 @@ func New(t *testing.T, trustDomain spiffeid.TrustDomain, options *Options) *CA { AgentSVIDTTL: options.AgentSVIDTTL, X509SVIDTTL: options.X509SVIDTTL, JWTSVIDTTL: options.JWTSVIDTTL, + WITSVIDTTL: options.WITSVIDTTL, }) require.NoError(t, err) @@ -86,6 +93,7 @@ func New(t *testing.T, trustDomain spiffeid.TrustDomain, options *Options) *CA { TrustDomain: trustDomain, HealthChecker: healthChecker, DisableJWTSVIDs: options.DisableJWTSVIDs, + DisableWITSVIDs: options.DisableWITSVIDs, }) template, err := credBuilder.BuildSelfSignedX509CATemplate(context.Background(), credtemplate.SelfSignedX509CAParams{ @@ -105,6 +113,11 @@ func New(t *testing.T, trustDomain spiffeid.TrustDomain, options *Options) *CA { Kid: "KID", NotAfter: options.Clock.Now().Add(time.Hour), }) + serverCA.SetWITKey(&ca.WITKey{ + Signer: signer, + Kid: "KID", + NotAfter: options.Clock.Now().Add(time.Hour), + }) return &CA{ ca: serverCA, @@ -132,6 +145,10 @@ func (c *CA) SetJWTKey(jwtKey *ca.JWTKey) { c.ca.SetJWTKey(jwtKey) } +func (c *CA) SetWITKey(witKey *ca.WITKey) { + c.ca.SetWITKey(witKey) +} + func (c *CA) NotifyTaintedX509Authorities(taintedAuthorities []*x509.Certificate) { c.ca.NotifyTaintedX509Authorities(taintedAuthorities) } @@ -171,6 +188,13 @@ func (c *CA) SignWorkloadJWTSVID(ctx context.Context, params ca.WorkloadJWTSVIDP return c.ca.SignWorkloadJWTSVID(ctx, params) } +func (c *CA) SignWorkloadWITSVID(ctx context.Context, params ca.WorkloadWITSVIDParams) (string, error) { + if c.err != nil { + return "", c.err + } + return c.ca.SignWorkloadWITSVID(ctx, params) +} + func (c *CA) TaintedAuthorities() <-chan []*x509.Certificate { return c.ca.TaintedAuthorities() } @@ -199,10 +223,22 @@ func (c *CA) JWTSVIDTTL() time.Duration { return c.options.JWTSVIDTTL } +func (c *CA) WITSVIDTTL() time.Duration { + return c.options.WITSVIDTTL +} + func (c *CA) IsJWTSVIDsDisabled() bool { return c.disableJWTSVIDs } +func (c *CA) IsWITSVIDsDisabled() bool { + return c.disableWITSVIDs +} + func (c *CA) SetDisableJWTSVIDs(disableJWTSVIDs bool) { c.disableJWTSVIDs = disableJWTSVIDs } + +func (c *CA) SetDisableWITSVIDs(disableWITSVIDs bool) { + c.disableWITSVIDs = disableWITSVIDs +} From dd325ffe08db8be4830fbb419dae4fe7b5d96991 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Sun, 23 Feb 2025 07:30:53 +0000 Subject: [PATCH 3/6] svid.v1: add support for signing WIT-SVIDs Signed-off-by: Sorin Dumitru --- pkg/server/api/svid/v1/service.go | 199 +++++- pkg/server/api/svid/v1/service_test.go | 810 ++++++++++++++++++++++++- pkg/server/endpoints/middleware.go | 6 +- 3 files changed, 985 insertions(+), 30 deletions(-) diff --git a/pkg/server/api/svid/v1/service.go b/pkg/server/api/svid/v1/service.go index 6c127a0711..a41cbd3e57 100644 --- a/pkg/server/api/svid/v1/service.go +++ b/pkg/server/api/svid/v1/service.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/go-jose/go-jose/v4" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1" @@ -162,7 +163,20 @@ func (s *Service) MintJWTSVID(ctx context.Context, req *svidv1.MintJWTSVIDReques func (s *Service) MintWITSVID(ctx context.Context, req *svidv1.MintWITSVIDRequest) (*svidv1.MintWITSVIDResponse, error) { log := rpccontext.Logger(ctx) - return nil, api.MakeErr(log, codes.Unimplemented, "WIT-SVID functionality is not yet implemented", nil) + if s.isWITSVIDsDisabled() { + return nil, api.MakeErr(log, codes.Unimplemented, "WIT functionality is disabled", nil) + } + + rpccontext.AddRPCAuditFields(ctx, s.fieldsFromWITSvidParams(ctx, req.Id, req.Ttl)) + witsvid, err := s.mintWITSVID(ctx, req.Id, req.PublicKey, req.Ttl) + if err != nil { + return nil, err + } + rpccontext.AuditRPC(ctx) + + return &svidv1.MintWITSVIDResponse{ + Svid: witsvid, + }, nil } func (s *Service) BatchNewX509SVID(ctx context.Context, req *svidv1.BatchNewX509SVIDRequest) (*svidv1.BatchNewX509SVIDResponse, error) { @@ -393,7 +407,169 @@ func (s *Service) NewJWTSVID(ctx context.Context, req *svidv1.NewJWTSVIDRequest) func (s *Service) BatchNewWITSVID(ctx context.Context, req *svidv1.BatchNewWITSVIDRequest) (*svidv1.BatchNewWITSVIDResponse, error) { log := rpccontext.Logger(ctx) - return nil, api.MakeErr(log, codes.Unimplemented, "WIT-SVID functionality is not yet implemented", nil) + + if len(req.Params) == 0 { + return nil, api.MakeErr(log, codes.InvalidArgument, "missing parameters", nil) + } + + if err := rpccontext.RateLimit(ctx, len(req.Params)); err != nil { + return nil, api.MakeErr(log, status.Code(err), "rejecting request due to certificate signing rate limiting", err) + } + + requestedEntries := make(map[string]struct{}) + for _, svidParam := range req.Params { + requestedEntries[svidParam.GetEntryId()] = struct{}{} + } + + // Fetch authorized entries + entriesMap, err := s.findEntries(ctx, log, requestedEntries) + if err != nil { + return nil, err + } + + var results []*svidv1.BatchNewWITSVIDResponse_Result + for _, svidParam := range req.Params { + // Create new SVID + r := s.newWITSVID(ctx, svidParam, entriesMap) + results = append(results, r) + spiffeID := "" + if r.Svid != nil { + id, err := idutil.IDProtoString(r.Svid.Id) + if err == nil { + spiffeID = id + } + } + + rpccontext.AuditRPCWithTypesStatus(ctx, r.Status, func() logrus.Fields { + fields := logrus.Fields{ + telemetry.RegistrationID: svidParam.EntryId, + telemetry.SPIFFEID: spiffeID, + } + + if r.Svid != nil { + fields[telemetry.ExpiresAt] = r.Svid.ExpiresAt + } + + return fields + }) + } + + return &svidv1.BatchNewWITSVIDResponse{Results: results}, nil +} + +func (s *Service) mintWITSVID(ctx context.Context, protoID *types.SPIFFEID, publicKeyDer []byte, ttl int32) (*types.WITSVID, error) { + log := rpccontext.Logger(ctx) + + id, err := api.TrustDomainWorkloadIDFromProto(ctx, s.td, protoID) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "invalid SPIFFE ID", err) + } + + log = log.WithField(telemetry.SPIFFEID, id.String()) + + publicKey, err := x509.ParsePKIXPublicKey(publicKeyDer) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "invalid public key", err) + } + + token, err := s.ca.SignWorkloadWITSVID(ctx, ca.WorkloadWITSVIDParams{ + SPIFFEID: id, + TTL: time.Duration(ttl) * time.Second, + PublicKey: jose.JSONWebKey{ + Key: publicKey, + }, + }) + if err != nil { + return nil, api.MakeErr(log, codes.Internal, "failed to sign WIT-SVID", err) + } + + issuedAt, expiresAt, err := jwtsvid.GetTokenExpiry(token) + if err != nil { + return nil, api.MakeErr(log, codes.Internal, "failed to get WIT-SVID expiry", err) + } + + log.WithFields(logrus.Fields{ + telemetry.Expiration: expiresAt.Format(time.RFC3339), + }).Debug("Server CA successfully signed WIT-SVID") + + return &types.WITSVID{ + Token: token, + Id: api.ProtoFromID(id), + ExpiresAt: expiresAt.Unix(), + IssuedAt: issuedAt.Unix(), + }, nil +} + +// newWITSVID creates an WIT-SVID using data from registration entry and public key from input params +func (s *Service) newWITSVID(ctx context.Context, param *svidv1.NewWITSVIDParams, entries map[string]api.ReadOnlyEntry) *svidv1.BatchNewWITSVIDResponse_Result { + log := rpccontext.Logger(ctx) + + switch { + case param.EntryId == "": + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.InvalidArgument, "missing entry ID", nil), + } + case len(param.PublicKey) == 0: + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.InvalidArgument, "missing public key", nil), + } + } + + log = log.WithField(telemetry.RegistrationID, param.EntryId) + + entry, ok := entries[param.EntryId] + if !ok { + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.NotFound, "entry not found or not authorized", nil), + } + } + + publicKey, err := x509.ParsePKIXPublicKey(param.PublicKey) + if err != nil { + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.InvalidArgument, "malformed public key", err), + } + } + + spiffeID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, entry.GetSpiffeId()) + if err != nil { + // This shouldn't be the case unless there is invalid data in the datastore + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.Internal, "entry has malformed SPIFFE ID", err), + } + } + + log = log.WithField(telemetry.SPIFFEID, spiffeID.String()) + + witSvid, err := s.ca.SignWorkloadWITSVID(ctx, ca.WorkloadWITSVIDParams{ + SPIFFEID: spiffeID, + PublicKey: jose.JSONWebKey{ + Key: publicKey, + }, + // TODO: add its own TTL + TTL: time.Duration(entry.GetX509SvidTtl()) * time.Second, + }) + if err != nil { + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.Internal, "failed to sign WIT-SVID", err), + } + } + + issuedAt, expiresAt, err := jwtsvid.GetTokenExpiry(witSvid) + if err != nil { + return &svidv1.BatchNewWITSVIDResponse_Result{ + Status: api.MakeStatus(log, codes.Internal, "failed to get WIT-SVID expiry", err), + } + } + + return &svidv1.BatchNewWITSVIDResponse_Result{ + Svid: &types.WITSVID{ + Id: entry.GetSpiffeId(), + Token: witSvid, + ExpiresAt: expiresAt.Unix(), + IssuedAt: issuedAt.Unix(), + }, + } } func (s *Service) NewDownstreamX509CA(ctx context.Context, req *svidv1.NewDownstreamX509CARequest) (*svidv1.NewDownstreamX509CAResponse, error) { @@ -470,6 +646,10 @@ func (s *Service) isJWTSVIDsDisabled() bool { return s.ca.IsJWTSVIDsDisabled() } +func (s *Service) isWITSVIDsDisabled() bool { + return s.ca.IsWITSVIDsDisabled() +} + func (s Service) fieldsFromJWTSvidParams(ctx context.Context, protoID *types.SPIFFEID, audience []string, ttl int32) logrus.Fields { fields := logrus.Fields{ telemetry.TTL: ttl, @@ -489,6 +669,21 @@ func (s Service) fieldsFromJWTSvidParams(ctx context.Context, protoID *types.SPI return fields } +func (s Service) fieldsFromWITSvidParams(ctx context.Context, protoID *types.SPIFFEID, ttl int32) logrus.Fields { + fields := logrus.Fields{ + telemetry.TTL: ttl, + } + if protoID != nil { + // Don't care about parsing error + id, err := api.TrustDomainWorkloadIDFromProto(ctx, s.td, protoID) + if err == nil { + fields[telemetry.SPIFFEID] = id.String() + } + } + + return fields +} + func parseAndCheckCSR(ctx context.Context, csrBytes []byte) (*x509.CertificateRequest, error) { log := rpccontext.Logger(ctx) diff --git a/pkg/server/api/svid/v1/service_test.go b/pkg/server/api/svid/v1/service_test.go index 9e54e04822..6a921c93b9 100644 --- a/pkg/server/api/svid/v1/service_test.go +++ b/pkg/server/api/svid/v1/service_test.go @@ -2,6 +2,8 @@ package svid_test import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/x509" "crypto/x509/pkix" @@ -844,6 +846,231 @@ func TestServiceMintJWTSVID(t *testing.T) { } } +func TestServiceMintWITSVID(t *testing.T) { + test := setupServiceTest(t) + defer test.Cleanup() + + now := test.ca.Clock().Now().UTC() + issuedAt := now + expiresAt := now.Add(test.ca.WITSVIDTTL()) + + for _, tt := range []struct { + name string + + code codes.Code + err string + expiresAt time.Time + id spiffeid.ID + ttl time.Duration + disableWITSVIDs bool + failMinting bool + expectLogs []spiretest.LogEntry + }{ + { + name: "success", + expiresAt: expiresAt, + id: workloadID, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "success", + telemetry.Type: "audit", + telemetry.SPIFFEID: "spiffe://example.org/workload1", + telemetry.TTL: "0", + }, + }, + }, + }, + { + name: "success custom TTL", + ttl: 10 * time.Second, + expiresAt: now.Add(10 * time.Second), + id: workloadID, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "success", + telemetry.Type: "audit", + telemetry.SPIFFEID: "spiffe://example.org/workload1", + telemetry.TTL: "10", + }, + }, + }, + }, + { + name: "bad id", + code: codes.InvalidArgument, + id: spiffeid.ID{}, + err: "invalid SPIFFE ID: trust domain is missing", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: invalid SPIFFE ID", + Data: logrus.Fields{ + logrus.ErrorKey: "trust domain is missing", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: "invalid SPIFFE ID: trust domain is missing", + telemetry.TTL: "0", + }, + }, + }, + }, + { + name: "invalid trust domain", + code: codes.InvalidArgument, + id: spiffeid.RequireFromString("spiffe://invalid.test/workload1"), + err: `invalid SPIFFE ID: "spiffe://invalid.test/workload1" is not a member of trust domain "example.org"`, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: invalid SPIFFE ID", + Data: logrus.Fields{ + logrus.ErrorKey: `"spiffe://invalid.test/workload1" is not a member of trust domain "example.org"`, + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: `invalid SPIFFE ID: "spiffe://invalid.test/workload1" is not a member of trust domain "example.org"`, + telemetry.TTL: "0", + }, + }, + }, + }, + { + name: "SPIFFE ID is not for a workload in the trust domain", + code: codes.InvalidArgument, + id: spiffeid.RequireFromString("spiffe://invalid.test"), + err: `invalid SPIFFE ID: "spiffe://invalid.test" is not a member of trust domain "example.org"`, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: invalid SPIFFE ID", + Data: logrus.Fields{ + logrus.ErrorKey: `"spiffe://invalid.test" is not a member of trust domain "example.org"`, + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: `invalid SPIFFE ID: "spiffe://invalid.test" is not a member of trust domain "example.org"`, + telemetry.TTL: "0", + }, + }, + }, + }, + { + name: "fails minting", + code: codes.Internal, + err: "failed to sign WIT-SVID: oh no", + failMinting: true, + expiresAt: expiresAt, + id: workloadID, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to sign WIT-SVID", + Data: logrus.Fields{ + logrus.ErrorKey: "oh no", + telemetry.SPIFFEID: "spiffe://example.org/workload1", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "failed to sign WIT-SVID: oh no", + telemetry.SPIFFEID: "spiffe://example.org/workload1", + telemetry.TTL: "0", + }, + }, + }, + }, + { + name: "wit is disabled", + disableWITSVIDs: true, + code: codes.Unimplemented, + expiresAt: expiresAt, + id: workloadID, + err: "WIT functionality is disabled", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "WIT functionality is disabled", + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "Unimplemented", + telemetry.StatusMessage: "WIT functionality is disabled", + }, + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + test.logHook.Reset() + + test.ca.SetDisableWITSVIDs(tt.disableWITSVIDs) + if tt.failMinting { + test.ca.SetError(errors.New("oh no")) + } + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + keyDer, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + resp, err := test.client.MintWITSVID(context.Background(), &svidv1.MintWITSVIDRequest{ + Id: api.ProtoFromID(tt.id), + Ttl: int32(tt.ttl / time.Second), + PublicKey: keyDer, + }) + + spiretest.AssertLogs(t, test.logHook.AllEntries(), tt.expectLogs) + // Check for expected errors + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + require.Nil(t, resp) + + return + } + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify response + verifyWITSVIDResponse(t, resp.Svid, tt.id, issuedAt, tt.expiresAt, expiresAt, tt.ttl) + }) + } +} + func TestServiceNewJWTSVID(t *testing.T) { test := setupServiceTest(t) defer test.Cleanup() @@ -1873,40 +2100,539 @@ func TestServiceBatchNewX509SVID(t *testing.T) { } } -func TestNewDownstreamX509CA(t *testing.T) { - type downstreamCaTest struct { - name string - err string - failSigning bool - failDataStore bool - rateLimiterErr error - entry *types.Entry - csr []byte - csrTemplate *x509.CertificateRequest - code codes.Code - fetcherErr string - expectLogs func([]byte) []spiretest.LogEntry - } - - downstreamEntry1 := &types.Entry{ - Id: "downstreamCA1", - ParentId: api.ProtoFromID(agentID), - SpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: ""}, - Downstream: true, - } - +func BatchNewWITSVID(t *testing.T) { test := setupServiceTest(t) defer test.Cleanup() - _, csrErr := x509.ParseCertificateRequest([]byte{1, 2, 3}) + workloadEntry1 := &types.Entry{ + Id: "workload1", + ParentId: api.ProtoFromID(agentID), + SpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/workload1"}, + } + workloadEntry2 := &types.Entry{ + Id: "workload2", + ParentId: api.ProtoFromID(agentID), + SpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/workload2"}, + } + invalidEntry := &types.Entry{ + Id: "invalid", + SpiffeId: &types.SPIFFEID{}, + ParentId: api.ProtoFromID(agentID), + } + test.ef.entries = []*types.Entry{workloadEntry1, workloadEntry2, invalidEntry} now := test.ca.Clock().Now().UTC() - expiresAtFromCA := now.Add(test.ca.X509CATTL()).Unix() - for _, tt := range []downstreamCaTest{ + expiresAtFromCA := now.Add(test.ca.WITSVIDTTL()).Unix() + expiresAtFromCAStr := strconv.FormatInt(expiresAtFromCA, 10) + + _, invalidCsrErr := x509.ParseCertificateRequest([]byte{1, 2, 3}) + require.Error(t, invalidCsrErr) + + type expectResult struct { + entry *types.Entry + status *types.Status + } + + for _, tt := range []struct { + name string + code codes.Code + reqs []string + err string + expectLogs []spiretest.LogEntry + expectResults []*expectResult + failSigning bool + failCallerID bool + fetcherErr string + setPublicKey func() []byte + rateLimiterErr error + }{ { - name: "Malformed CSR", - rateLimiterErr: nil, + name: "success", + reqs: []string{workloadEntry1.Id}, + expectResults: []*expectResult{ + { + entry: workloadEntry1, + }, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "success", + telemetry.Type: "audit", + telemetry.RegistrationID: "workload", + telemetry.ExpiresAt: expiresAtFromCAStr, + telemetry.SPIFFEID: "spiffe://example.org/workload1", + }, + }, + }, + }, { + name: "keep request order", + reqs: []string{workloadEntry1.Id, invalidEntry.Id, workloadEntry2.Id}, + expectResults: []*expectResult{ + { + entry: workloadEntry1, + }, + { + status: &types.Status{ + Code: int32(codes.Internal), + Message: "entry has malformed SPIFFE ID", + }, + }, + { + entry: workloadEntry2, + }, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "success", + telemetry.Type: "audit", + telemetry.RegistrationID: "workload1", + telemetry.ExpiresAt: expiresAtFromCAStr, + telemetry.SPIFFEID: "spiffe://example.org/workload1", + }, + }, + { + Level: logrus.ErrorLevel, + Message: "Entry has malformed SPIFFE ID", + Data: logrus.Fields{ + telemetry.RegistrationID: "invalid", + logrus.ErrorKey: "trust domain is missing", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "entry has malformed SPIFFE ID: trust domain is missing", + telemetry.SPIFFEID: "", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "success", + telemetry.Type: "audit", + telemetry.RegistrationID: "workload2", + telemetry.ExpiresAt: expiresAtFromCAStr, + telemetry.SPIFFEID: "spiffe://example.org/workload2", + }, + }, + }, + }, { + name: "no caller id", + reqs: []string{workloadEntry1.Id}, + code: codes.Internal, + err: "caller ID missing from request context", + failCallerID: true, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Caller ID missing from request context", + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "caller ID missing from request context", + }, + }, + }, + }, { + name: "no parameters", + reqs: []string{}, + code: codes.InvalidArgument, + err: "missing parameters", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: missing parameters", + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: "missing parameters", + }, + }, + }, + }, { + name: "rate limit fails", + reqs: []string{workloadEntry1.Id}, + code: codes.Internal, + err: "rate limit error", + rateLimiterErr: status.Error(codes.Internal, "rate limit error"), + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Rejecting request due to certificate signing rate limiting", + Data: logrus.Fields{ + logrus.ErrorKey: "rpc error: code = Internal desc = rate limit error", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "rejecting request due to certificate signing rate limiting: rate limit error", + }, + }, + }, + }, { + name: "fetch entries fails", + reqs: []string{workloadEntry1.Id}, + code: codes.Internal, + err: "failed to fetch registration entries", + fetcherErr: "fetcher fails", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to fetch registration entries", + Data: logrus.Fields{ + logrus.ErrorKey: "rpc error: code = Internal desc = fetcher fails", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "failed to fetch registration entries: fetcher fails", + }, + }, + }, + }, { + name: "missing entry ID", + reqs: []string{""}, + expectResults: []*expectResult{ + { + status: &types.Status{ + Code: int32(codes.InvalidArgument), + Message: "missing entry ID", + }, + }, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: missing entry ID", + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: "missing entry ID", + telemetry.SPIFFEID: "", + }, + }, + }, + }, { + name: "missing public key", + reqs: []string{workloadEntry1.Id}, + expectResults: []*expectResult{ + { + status: &types.Status{ + Code: int32(codes.InvalidArgument), + Message: `missing public key`, + }, + }, + }, + setPublicKey: func() []byte { + return []byte{} + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: missing CSR", + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "workload", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: "missing CSR", + telemetry.SPIFFEID: "", + }, + }, + }, + }, { + name: "entry not found", + reqs: []string{"invalid entry"}, + expectResults: []*expectResult{ + { + status: &types.Status{ + Code: int32(codes.NotFound), + Message: "entry not found or not authorized", + }, + }, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Entry not found or not authorized", + Data: logrus.Fields{ + telemetry.RegistrationID: "invalid entry", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "invalid entry", + telemetry.StatusCode: "NotFound", + telemetry.StatusMessage: "entry not found or not authorized", + telemetry.SPIFFEID: "", + }, + }, + }, + }, { + name: "malformed public key", + reqs: []string{workloadEntry1.Id}, + expectResults: []*expectResult{ + { + status: &types.Status{ + Code: int32(codes.InvalidArgument), + Message: "malformed CSR: asn1:", + }, + }, + }, + setPublicKey: func() []byte { + return []byte{1, 2, 3} + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Invalid argument: malformed CSR", + Data: logrus.Fields{ + telemetry.RegistrationID: "workload", + logrus.ErrorKey: invalidCsrErr.Error(), + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "workload", + telemetry.StatusCode: "InvalidArgument", + telemetry.StatusMessage: fmt.Sprintf("malformed CSR: %v", invalidCsrErr), + telemetry.SPIFFEID: "", + }, + }, + }, + }, { + name: "malformed SPIFFE ID", + reqs: []string{invalidEntry.Id}, + expectResults: []*expectResult{ + { + status: &types.Status{ + Code: int32(codes.Internal), + Message: "entry has malformed SPIFFE ID", + }, + }, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Entry has malformed SPIFFE ID", + Data: logrus.Fields{ + telemetry.RegistrationID: "invalid", + logrus.ErrorKey: "trust domain is missing", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "invalid", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "entry has malformed SPIFFE ID: trust domain is missing", + telemetry.SPIFFEID: "", + }, + }, + }, + }, { + name: "signing fails", + reqs: []string{workloadEntry1.Id}, + expectResults: []*expectResult{ + { + status: &types.Status{ + Code: int32(codes.Internal), + Message: "failed to sign WIT-SVID: oh no", + }, + }, + }, + failSigning: true, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to sign WIT-SVID", + Data: logrus.Fields{ + telemetry.RegistrationID: "workload", + logrus.ErrorKey: "oh no", + telemetry.SPIFFEID: workloadID.String(), + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.Type: "audit", + telemetry.RegistrationID: "workload", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "failed to sign WIT-SVID: oh no", + telemetry.SPIFFEID: "", + }, + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + test.logHook.Reset() + + if tt.failSigning { + test.ca.SetError(errors.New("oh no")) + } + + ctx := context.Background() + + test.rateLimiter.count = len(tt.reqs) + test.rateLimiter.err = tt.rateLimiterErr + + test.withCallerID = !tt.failCallerID + test.ef.err = tt.fetcherErr + + var params []*svidv1.NewWITSVIDParams + for _, entryID := range tt.reqs { + key := testkey.MustEC256() + keyDer, err := key.PublicKey.Bytes() + require.NoError(t, err) + + params = append(params, &svidv1.NewWITSVIDParams{ + EntryId: entryID, + PublicKey: keyDer, + }) + if tt.setPublicKey != nil { + params[len(params)-1].PublicKey = tt.setPublicKey() + } + } + + // Batch svids + resp, err := test.client.BatchNewWITSVID(ctx, &svidv1.BatchNewWITSVIDRequest{ + Params: params, + }) + spiretest.AssertLogs(t, test.logHook.AllEntries(), tt.expectLogs) + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + require.Nil(t, resp) + + return + } + require.NoError(t, err) + require.NotNil(t, resp) + require.NotEmpty(t, resp.Results) + + for i, result := range resp.Results { + expect := tt.expectResults[i] + + if expect.status != nil { + require.Nil(t, result.Svid) + require.Equal(t, expect.status.Code, result.Status.Code) + require.Contains(t, result.Status.Message, expect.status.Message) + + continue + } + spiretest.AssertProtoEqual(t, &types.Status{Code: int32(codes.OK), Message: "OK"}, result.Status) + + require.NotNil(t, result.Svid) + + entry := expect.entry + + require.Equal(t, entry.SpiffeId.TrustDomain, result.Svid.Id.TrustDomain) + require.Equal(t, entry.SpiffeId.Path, result.Svid.Id.Path) + + svid := result.Svid + + entrySPIFFEID := idutil.RequireIDFromProto(entry.SpiffeId) + require.Equal(t, []*url.URL{entrySPIFFEID.URL()}, svid.Id) + + expiresAt := now.Add(test.ca.WITSVIDTTL()) + + require.Equal(t, expiresAt, svid.ExpiresAt) + require.Equal(t, expiresAt.UTC().Unix(), result.Svid.ExpiresAt) + } + }) + } +} + +func TestNewDownstreamX509CA(t *testing.T) { + type downstreamCaTest struct { + name string + err string + failSigning bool + failDataStore bool + rateLimiterErr error + entry *types.Entry + csr []byte + csrTemplate *x509.CertificateRequest + code codes.Code + fetcherErr string + expectLogs func([]byte) []spiretest.LogEntry + } + + downstreamEntry1 := &types.Entry{ + Id: "downstreamCA1", + ParentId: api.ProtoFromID(agentID), + SpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: ""}, + Downstream: true, + } + + test := setupServiceTest(t) + defer test.Cleanup() + + _, csrErr := x509.ParseCertificateRequest([]byte{1, 2, 3}) + + now := test.ca.Clock().Now().UTC() + expiresAtFromCA := now.Add(test.ca.X509CATTL()).Unix() + + for _, tt := range []downstreamCaTest{ + { + name: "Malformed CSR", + rateLimiterErr: nil, err: "malformed CSR: asn1: structure error", failSigning: false, csr: []byte{1, 2, 3}, @@ -2229,6 +2955,36 @@ func verifyJWTSVIDResponse(t *testing.T, svid *types.JWTSVID, id spiffeid.ID, au } } +func verifyWITSVIDResponse(t *testing.T, svid *types.WITSVID, id spiffeid.ID, issuedAt, expiresAt, defaultExpiresAt time.Time, ttl time.Duration) { + require.NotNil(t, svid) + require.NotEmpty(t, svid.Token) + + token, err := jwt.ParseSigned(svid.Token, jwtsvid.AllowedSignatureAlgorithms) + require.NoError(t, err) + + var claims jwt.Claims + err = token.UnsafeClaimsWithoutVerification(&claims) + require.NoError(t, err) + + jwtsvidID, err := api.TrustDomainWorkloadIDFromProto(context.Background(), td, svid.Id) + require.NoError(t, err) + require.Equal(t, id, jwtsvidID) + require.Equal(t, id.String(), claims.Subject) + + require.NotNil(t, claims.IssuedAt) + require.Equal(t, issuedAt.Unix(), svid.IssuedAt) + require.Equal(t, issuedAt.Unix(), int64(*claims.IssuedAt)) + + require.NotNil(t, claims.Expiry) + if ttl == 0 { + require.Equal(t, defaultExpiresAt.Unix(), svid.ExpiresAt) + require.Equal(t, defaultExpiresAt.Unix(), int64(*claims.Expiry)) + } else { + require.Equal(t, expiresAt.Unix(), svid.ExpiresAt) + require.Equal(t, expiresAt.Unix(), int64(*claims.Expiry)) + } +} + type entryFetcher struct { err string entries []*types.Entry diff --git a/pkg/server/endpoints/middleware.go b/pkg/server/endpoints/middleware.go index 67481ec80d..03ac78cda7 100644 --- a/pkg/server/endpoints/middleware.go +++ b/pkg/server/endpoints/middleware.go @@ -144,6 +144,10 @@ func RateLimits(config RateLimitConfig) map[string]api.RateLimiter { } pushKeyLimit := middleware.PerIPLimit(limits.PushKeyLimitPerIP) + wsrLimit := middleware.DisabledLimit() + if config.Signing { + wsrLimit = middleware.PerIPLimit(limits.SignLimitPerIP) + } return map[string]api.RateLimiter{ "/spire.api.server.svid.v1.SVID/MintX509SVID": noLimit, @@ -151,7 +155,7 @@ func RateLimits(config RateLimitConfig) map[string]api.RateLimiter { "/spire.api.server.svid.v1.SVID/MintWITSVID": noLimit, "/spire.api.server.svid.v1.SVID/BatchNewX509SVID": csrLimit, "/spire.api.server.svid.v1.SVID/NewJWTSVID": jsrLimit, - "/spire.api.server.svid.v1.SVID/BatchNewWITSVID": jsrLimit, + "/spire.api.server.svid.v1.SVID/BatchNewWITSVID": wsrLimit, "/spire.api.server.svid.v1.SVID/NewDownstreamX509CA": csrLimit, "/spire.api.server.bundle.v1.Bundle/GetBundle": noLimit, "/spire.api.server.bundle.v1.Bundle/AppendBundle": noLimit, From d0425e0a4296e708e849e55347b86014aec4ab54 Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Sun, 23 Feb 2025 07:33:11 +0000 Subject: [PATCH 4/6] server cli: add wit mint command Signed-off-by: Sorin Dumitru --- cmd/spire-server/cli/cli.go | 4 + cmd/spire-server/cli/run/run.go | 21 +- cmd/spire-server/cli/wit/mint.go | 205 ++++++++++ cmd/spire-server/cli/wit/mint_test.go | 413 +++++++++++++++++++++ pkg/server/plugin/keymanager/keymanager.go | 16 + 5 files changed, 641 insertions(+), 18 deletions(-) create mode 100644 cmd/spire-server/cli/wit/mint.go create mode 100644 cmd/spire-server/cli/wit/mint_test.go diff --git a/cmd/spire-server/cli/cli.go b/cmd/spire-server/cli/cli.go index e86c939305..c1d2ffcf2f 100644 --- a/cmd/spire-server/cli/cli.go +++ b/cmd/spire-server/cli/cli.go @@ -18,6 +18,7 @@ import ( "github.com/spiffe/spire/cmd/spire-server/cli/token" "github.com/spiffe/spire/cmd/spire-server/cli/upstreamauthority" "github.com/spiffe/spire/cmd/spire-server/cli/validate" + "github.com/spiffe/spire/cmd/spire-server/cli/wit" "github.com/spiffe/spire/cmd/spire-server/cli/x509" "github.com/spiffe/spire/pkg/common/log" "github.com/spiffe/spire/pkg/common/version" @@ -124,6 +125,9 @@ func (cc *CLI) Run(ctx context.Context, args []string) int { "jwt mint": func() (cli.Command, error) { return jwt.NewMintCommand(), nil }, + "wit mint": func() (cli.Command, error) { + return wit.NewMintCommand(), nil + }, "validate": func() (cli.Command, error) { return validate.NewValidateCommand(), nil }, diff --git a/cmd/spire-server/cli/run/run.go b/cmd/spire-server/cli/run/run.go index 251ab55d6e..f56677565d 100644 --- a/cmd/spire-server/cli/run/run.go +++ b/cmd/spire-server/cli/run/run.go @@ -643,7 +643,7 @@ func NewServerConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool } if c.Server.CAKeyType != "" { - keyType, err := keyTypeFromString(c.Server.CAKeyType) + keyType, err := keymanager.KeyTypeFromString(c.Server.CAKeyType) if err != nil { return nil, fmt.Errorf("error parsing ca_key_type: %w", err) } @@ -657,14 +657,14 @@ func NewServerConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool } if c.Server.JWTKeyType != "" { - sc.JWTKeyType, err = keyTypeFromString(c.Server.JWTKeyType) + sc.JWTKeyType, err = keymanager.KeyTypeFromString(c.Server.JWTKeyType) if err != nil { return nil, fmt.Errorf("error parsing jwt_key_type: %w", err) } } if c.Server.Experimental.WITKeyType != "" { - sc.WITKeyType, err = keyTypeFromString(c.Server.Experimental.WITKeyType) + sc.WITKeyType, err = keymanager.KeyTypeFromString(c.Server.Experimental.WITKeyType) if err != nil { return nil, fmt.Errorf("error parsing wit_key_type: %w", err) } @@ -1081,21 +1081,6 @@ func defaultConfig() *Config { } } -func keyTypeFromString(s string) (keymanager.KeyType, error) { - switch strings.ToLower(s) { - case "rsa-2048": - return keymanager.RSA2048, nil - case "rsa-4096": - return keymanager.RSA4096, nil - case "ec-p256": - return keymanager.ECP256, nil - case "ec-p384": - return keymanager.ECP384, nil - default: - return keymanager.KeyTypeUnset, fmt.Errorf("key type %q is unknown; must be one of [rsa-2048, rsa-4096, ec-p256, ec-p384]", s) - } -} - // hasCompatibleTTL checks if we can guarantee the configured SVID TTL given the // configured CA TTL. If we detect that a new SVID TTL may be cut short due to // a scheduled CA rotation, this function will return false. This method should diff --git a/cmd/spire-server/cli/wit/mint.go b/cmd/spire-server/cli/wit/mint.go new file mode 100644 index 0000000000..2d9534cb6d --- /dev/null +++ b/cmd/spire-server/cli/wit/mint.go @@ -0,0 +1,205 @@ +package wit + +import ( + "context" + "crypto" + "crypto/x509" + "errors" + "flag" + "fmt" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/mitchellh/cli" + "github.com/spiffe/go-spiffe/v2/spiffeid" + svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" + serverutil "github.com/spiffe/spire/cmd/spire-server/util" + commoncli "github.com/spiffe/spire/pkg/common/cli" + "github.com/spiffe/spire/pkg/common/cliprinter" + "github.com/spiffe/spire/pkg/common/diskutil" + "github.com/spiffe/spire/pkg/common/jwtsvid" + "github.com/spiffe/spire/pkg/common/util" + "github.com/spiffe/spire/pkg/server/plugin/keymanager" +) + +func NewMintCommand() cli.Command { + return newMintCommand(commoncli.DefaultEnv) +} + +func newMintCommand(env *commoncli.Env) cli.Command { + return serverutil.AdaptCommand(env, &mintCommand{env: env}) +} + +// Test helper function, to have control over the workload key that is being generated +func newMintCommandWithKeyGenerator(env *commoncli.Env, workloadKeyGenerator func() (crypto.Signer, error)) cli.Command { + return serverutil.AdaptCommand(env, &mintCommand{env: env, workloadKeyGenerator: workloadKeyGenerator}) +} + +type mintCommand struct { + spiffeID string + keyType string + ttl time.Duration + write string + env *commoncli.Env + printer cliprinter.Printer + workloadKeyGenerator func() (crypto.Signer, error) +} + +func (c *mintCommand) Name() string { + return "wit mint" +} + +func (c *mintCommand) Synopsis() string { + return "Mints a WIT-SVID" +} + +func (c *mintCommand) AppendFlags(fs *flag.FlagSet) { + fs.StringVar(&c.spiffeID, "spiffeID", "", "SPIFFE ID of the WIT-SVID") + fs.StringVar(&c.keyType, "keyType", "ec-p256", "Key type of the WIT-SVID") + fs.DurationVar(&c.ttl, "ttl", 0, "TTL of the WIT-SVID") + fs.StringVar(&c.write, "write", "", "Directory to write output to instead of stdout") + cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintMint) +} + +type mintResult struct { + PrivateKey string `json:"private_key"` + Svid *types.WITSVID `json:"svid"` +} + +func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient serverutil.ServerClient) error { + if c.spiffeID == "" { + return errors.New("spiffeID must be specified") + } + spiffeID, err := spiffeid.FromString(c.spiffeID) + if err != nil { + return err + } + ttl, err := ttlToSeconds(c.ttl) + if err != nil { + return fmt.Errorf("invalid value for TTL: %w", err) + } + + keyType, err := keymanager.KeyTypeFromString(c.keyType) + if err != nil { + return fmt.Errorf("invalid key-type: %w", err) + } + + if c.workloadKeyGenerator == nil { + c.workloadKeyGenerator = keyType.GenerateSigner + } + + signer, err := c.workloadKeyGenerator() + if err != nil { + return fmt.Errorf("could not generate public/private key pair: %w", err) + } + + publicKeyDer, err := x509.MarshalPKIXPublicKey(signer.Public()) + if err != nil { + return fmt.Errorf("could not marshal public/private key pair: %w", err) + } + + client := serverClient.NewSVIDClient() + resp, err := client.MintWITSVID(ctx, &svidv1.MintWITSVIDRequest{ + Id: &types.SPIFFEID{ + TrustDomain: spiffeID.TrustDomain().Name(), + Path: spiffeID.Path(), + }, + PublicKey: publicKeyDer, + Ttl: ttl, + }) + if err != nil { + return fmt.Errorf("unable to mint SVID: %w", err) + } + token := resp.Svid.Token + if err := c.validateToken(token, env); err != nil { + return err + } + + jwk := jose.JSONWebKey{ + Key: signer, + } + jwkJson, err := jwk.MarshalJSON() + if err != nil { + return fmt.Errorf("could not marshal private key: %w", err) + } + + // Print in stdout + if c.write == "" { + return c.printer.PrintStruct(&mintResult{ + PrivateKey: string(jwkJson), + Svid: resp.Svid, + }) + } + + tokenPath := env.JoinPath(c.write, "token") + keyPath := env.JoinPath(c.write, "key") + + if err := diskutil.WritePrivateFile(tokenPath, []byte(token)); err != nil { + return fmt.Errorf("unable to write token: %w", err) + } + if err := env.Printf("WIT-SVID written to %s\n", tokenPath); err != nil { + return err + } + if err := diskutil.WritePrivateFile(keyPath, jwkJson); err != nil { + return fmt.Errorf("unable to write key: %w", err) + } + return env.Printf("Private key written to %s\n", keyPath) +} + +func (c *mintCommand) validateToken(token string, env *commoncli.Env) error { + if token == "" { + return errors.New("server response missing token") + } + + eol, err := getWITSVIDEndOfLife(token) + if err != nil { + env.ErrPrintf("Unable to determine WIT-SVID lifetime: %v\n", err) + return nil + } + + if time.Until(eol) < c.ttl { + env.ErrPrintf("WIT-SVID lifetime was capped shorter than specified ttl; expires %q\n", eol.UTC().Format(time.RFC3339)) + } + + return nil +} + +func getWITSVIDEndOfLife(token string) (time.Time, error) { + t, err := jwt.ParseSigned(token, jwtsvid.AllowedSignatureAlgorithms) + if err != nil { + return time.Time{}, err + } + + claims := new(jwt.Claims) + if err := t.UnsafeClaimsWithoutVerification(claims); err != nil { + return time.Time{}, err + } + + if claims.Expiry == nil { + return time.Time{}, errors.New("no expiry claim") + } + + return claims.Expiry.Time(), nil +} + +// ttlToSeconds returns the number of seconds in a duration, rounded up to +// the nearest second +func ttlToSeconds(ttl time.Duration) (int32, error) { + return util.CheckedCast[int32]((ttl + time.Second - 1) / time.Second) +} + +func prettyPrintMint(env *commoncli.Env, results ...any) error { + resultInterface, ok := results[0].([]any) + if !ok { + return cliprinter.ErrInternalCustomPrettyFunc + } + + if wit, ok := resultInterface[0].(*mintResult); ok { + errToken := env.Println(wit.Svid.Token) + errKey := env.Println(wit.PrivateKey) + return errors.Join(errToken, errKey) + } + return cliprinter.ErrInternalCustomPrettyFunc +} diff --git a/cmd/spire-server/cli/wit/mint_test.go b/cmd/spire-server/cli/wit/mint_test.go new file mode 100644 index 0000000000..b70056069a --- /dev/null +++ b/cmd/spire-server/cli/wit/mint_test.go @@ -0,0 +1,413 @@ +package wit + +import ( + "bytes" + "context" + "crypto" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" + common_cli "github.com/spiffe/spire/pkg/common/cli" + "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/test/clitest" + "github.com/spiffe/spire/test/spiretest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +var ( + testKey, _ = pemutil.ParseSigner([]byte(`-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgy8ps3oQaBaSUFpfd +XM13o+VSA0tcZteyTvbOdIQNVnKhRANCAAT4dPIORBjghpL5O4h+9kyzZZUAFV9F +qNV3lKIL59N7G2B4ojbhfSNneSIIpP448uPxUnaunaQZ+/m7+x9oobIp +-----END PRIVATE KEY----- +`)) + testWorkloadKey, _ = pemutil.ParseSigner([]byte(`-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgWItUvS9niefekUhG +VVYGK3UtPzVn5+LS2bmAahx7+iyhRANCAATRxgM9udjgUixZTHwg1TOYpdFkVZAo +dlmdCTY3trEN+pXoR+kecSyZFcvjYBaND9mOPsSHCAc5AAtFPQF/j0H/ +-----END PRIVATE KEY----- +`)) + availableFormats = []string{"pretty", "json"} + expectedUsage = `Usage of wit mint: + -keyType string + Key type of the WIT-SVID (default "ec-p256")` + clitest.AddrOutputUsage + + ` -spiffeID string + SPIFFE ID of the WIT-SVID + -ttl duration + TTL of the WIT-SVID + -write string + Directory to write output to instead of stdout +` +) + +func TestMintSynopsis(t *testing.T) { + cmd := NewMintCommand() + assert.Equal(t, "Mints a WIT-SVID", cmd.Synopsis()) +} + +func TestMintHelp(t *testing.T) { + stdout := new(bytes.Buffer) + stderr := new(bytes.Buffer) + cmd := newMintCommand(&common_cli.Env{ + Stdin: new(bytes.Buffer), + Stdout: stdout, + Stderr: stderr, + }) + assert.Equal(t, "flag: help requested", cmd.Help()) + assert.Empty(t, stdout.String()) + assert.Equal(t, expectedUsage, stderr.String()) +} + +func TestMintRun(t *testing.T) { + dir := spiretest.TempDir(t) + svidPath := filepath.Join(dir, "token") + keyPath := filepath.Join(dir, "key") + server := new(fakeSVIDServer) + addr := spiretest.StartGRPCServer(t, func(s *grpc.Server) { + svidv1.RegisterSVIDServer(s, server) + }) + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.ES256, + Key: testKey, + }, nil) + require.NoError(t, err) + + expiry := time.Now().Add(30 * time.Second) + builder := jwt.Signed(signer).Claims(jwt.Claims{ + Expiry: jwt.NewNumericDate(expiry), + }) + token, err := builder.Serialize() + require.NoError(t, err) + + // Create expired token + expiredAt := time.Now().Add(-30 * time.Second) + builder = jwt.Signed(signer).Claims(jwt.Claims{ + Expiry: jwt.NewNumericDate(expiredAt), + }) + expiredToken, err := builder.Serialize() + require.NoError(t, err) + + testCases := []struct { + name string + + // flags + spiffeID string + expectID *types.SPIFFEID + ttl time.Duration + write string + extraArgs []string + + // results + code int + stdin string + expStderr string + + noRequestExpected bool + expStdoutPretty string + expStdoutJSON string + resp *svidv1.MintWITSVIDResponse + }{ + { + name: "missing spiffeID flag", + code: 1, + expStderr: "Error: spiffeID must be specified\n", + noRequestExpected: true, + }, + { + name: "invalid flag", + code: 1, + expStderr: fmt.Sprintf("flag provided but not defined: -bad\n%s", expectedUsage), + extraArgs: []string{"-bad", "flag"}, + noRequestExpected: true, + }, + { + name: "RPC fails", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 1, + expStderr: "Error: unable to mint SVID: rpc error: code = Unknown desc = response not configured in test\n", + }, + { + name: "response missing token", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 1, + expStderr: "Error: server response missing token\n", + resp: &svidv1.MintWITSVIDResponse{Svid: &types.WITSVID{}}, + }, + { + name: "malformed spiffeID", + spiffeID: "domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 1, + expStderr: "Error: scheme is missing or invalid\n", + noRequestExpected: true, + }, + { + name: "success with defaults", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 0, + resp: &svidv1.MintWITSVIDResponse{ + Svid: &types.WITSVID{ + Token: token, + Id: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + ExpiresAt: 1628600000, + IssuedAt: 1628500000, + }, + }, + expStdoutPretty: token + "\n", + expStdoutJSON: fmt.Sprintf(`[{ + "svid": { + "token": "%s", + "id": { + "trust_domain": "domain.test", + "path": "/workload" + }, + "expires_at": 1628600000, + "issued_at": 1628500000 + }, + "private_key": "{\"kty\":\"EC\",\"crv\":\"P-256\",\"x\":\"0cYDPbnY4FIsWUx8INUzmKXRZFWQKHZZnQk2N7axDfo\",\"y\":\"lehH6R5xLJkVy-NgFo0P2Y4-xIcIBzkAC0U9AX-PQf8\",\"d\":\"WItUvS9niefekUhGVVYGK3UtPzVn5-LS2bmAahx7-iw\"}" +}]`, token), + }, + { + name: "write on invalid path", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 1, + resp: &svidv1.MintWITSVIDResponse{ + Svid: &types.WITSVID{ + Token: token, + }, + }, + write: "/doesnotexist", + expStdoutPretty: token + "\n", + expStdoutJSON: `{}`, + expStderr: "Error: unable to write token", + }, + { + name: "malformed token", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 0, + resp: &svidv1.MintWITSVIDResponse{ + Svid: &types.WITSVID{ + Token: "malformed token", + }, + }, + expStdoutPretty: "malformed token\n", + expStdoutJSON: `[{ + "svid": { + "token": "malformed token" + }, + "private_key": "{\"kty\":\"EC\",\"crv\":\"P-256\",\"x\":\"0cYDPbnY4FIsWUx8INUzmKXRZFWQKHZZnQk2N7axDfo\",\"y\":\"lehH6R5xLJkVy-NgFo0P2Y4-xIcIBzkAC0U9AX-PQf8\",\"d\":\"WItUvS9niefekUhGVVYGK3UtPzVn5-LS2bmAahx7-iw\"}" +}]`, + expStderr: "Unable to determine WIT-SVID lifetime: go-jose/go-jose: compact JWS format must have three parts\n", + }, + { + name: "expired token", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + code: 0, + resp: &svidv1.MintWITSVIDResponse{ + Svid: &types.WITSVID{ + Token: expiredToken, + Id: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + ExpiresAt: 1628500000, + IssuedAt: 1628600000, + }, + }, + expStdoutPretty: expiredToken + "\n", + expStdoutJSON: fmt.Sprintf(`[{ + "svid": { + "token": "%s", + "id": { + "trust_domain": "domain.test", + "path": "/workload" + }, + "expires_at": 1628500000, + "issued_at": 1628600000 + }, + "private_key": "{\"kty\":\"EC\",\"crv\":\"P-256\",\"x\":\"0cYDPbnY4FIsWUx8INUzmKXRZFWQKHZZnQk2N7axDfo\",\"y\":\"lehH6R5xLJkVy-NgFo0P2Y4-xIcIBzkAC0U9AX-PQf8\",\"d\":\"WItUvS9niefekUhGVVYGK3UtPzVn5-LS2bmAahx7-iw\"}" +}]`, expiredToken), + expStderr: fmt.Sprintf("WIT-SVID lifetime was capped shorter than specified ttl; expires %q\n", expiredAt.UTC().Format(time.RFC3339)), + }, + { + name: "success with ttl, output to file", + spiffeID: "spiffe://domain.test/workload", + expectID: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: "/workload", + }, + ttl: time.Minute, + code: 0, + write: ".", + resp: &svidv1.MintWITSVIDResponse{ + Svid: &types.WITSVID{ + Token: token, + }, + }, + expStdoutPretty: token + "\n", + expStdoutJSON: `{}`, + expStderr: fmt.Sprintf("WIT-SVID lifetime was capped shorter than specified ttl; expires %q\n", expiry.UTC().Format(time.RFC3339)), + }, + } + + for _, testCase := range testCases { + tt := testCase + for _, format := range availableFormats { + t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { + server.setMintWITSVIDResponse(tt.resp) + server.resetMintWITSVIDRequest() + + stdout := new(bytes.Buffer) + stderr := new(bytes.Buffer) + cmd := newMintCommandWithKeyGenerator(&common_cli.Env{ + Stdin: strings.NewReader(tt.stdin), + Stdout: stdout, + Stderr: stderr, + BaseDir: dir, + }, func() (crypto.Signer, error) { + return testWorkloadKey, nil + }) + + args := []string{clitest.AddrArg, clitest.GetAddr(addr)} + if tt.spiffeID != "" { + args = append(args, "-spiffeID", tt.spiffeID) + } + if tt.ttl != 0 { + args = append(args, "-ttl", fmt.Sprint(tt.ttl)) + } + if tt.write != "" { + args = append(args, "-write", tt.write) + } + args = append(args, tt.extraArgs...) + args = append(args, "-output", format) + + code := cmd.Run(args) + + assert.Equal(t, tt.code, code, "exit code does not match") + assert.Contains(t, stderr.String(), tt.expStderr, "stderr does not match") + + req := server.lastMintWITSVIDRequest() + if tt.noRequestExpected { + assert.Nil(t, req) + return + } + + if assert.NotNil(t, req) { + assert.Equal(t, tt.expectID, req.Id) + assert.Equal(t, int32(tt.ttl/time.Second), req.Ttl) + } + + // assert output file contents + if code == 0 { + if tt.write != "" { + assert.Equal(t, fmt.Sprintf("WIT-SVID written to %s\nPrivate key written to %s\n", svidPath, keyPath), + stdout.String(), "stdout does not write output path") + assertFileData(t, svidPath, tt.resp.Svid.Token) + } else { + requireOutputBasedOnFormat(t, format, stdout.String(), tt.expStdoutPretty, tt.expStdoutJSON) + } + } + }) + } + } +} + +type fakeSVIDServer struct { + svidv1.SVIDServer + + mu sync.Mutex + req *svidv1.MintWITSVIDRequest + resp *svidv1.MintWITSVIDResponse +} + +func (f *fakeSVIDServer) resetMintWITSVIDRequest() { + f.mu.Lock() + defer f.mu.Unlock() + f.req = nil +} + +func (f *fakeSVIDServer) lastMintWITSVIDRequest() *svidv1.MintWITSVIDRequest { + f.mu.Lock() + defer f.mu.Unlock() + return f.req +} + +func (f *fakeSVIDServer) setMintWITSVIDResponse(resp *svidv1.MintWITSVIDResponse) { + f.mu.Lock() + defer f.mu.Unlock() + f.resp = resp +} + +func (f *fakeSVIDServer) MintWITSVID(_ context.Context, req *svidv1.MintWITSVIDRequest) (*svidv1.MintWITSVIDResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + + f.req = req + if f.resp == nil { + return nil, errors.New("response not configured in test") + } + return f.resp, nil +} + +func assertFileData(t *testing.T, path string, expectedData string) { + b, err := os.ReadFile(path) + if assert.NoError(t, err) { + assert.Equal(t, expectedData, string(b)) + } +} + +func requireOutputBasedOnFormat(t *testing.T, format, stdoutString string, expectedStdoutPretty, expectedStdoutJSON string) { + switch format { + case "pretty": + require.Contains(t, stdoutString, expectedStdoutPretty) + case "json": + if expectedStdoutJSON != "" { + require.JSONEq(t, expectedStdoutJSON, stdoutString) + } else { + require.Empty(t, stdoutString) + } + } +} diff --git a/pkg/server/plugin/keymanager/keymanager.go b/pkg/server/plugin/keymanager/keymanager.go index 27ebf5e287..53533ef863 100644 --- a/pkg/server/plugin/keymanager/keymanager.go +++ b/pkg/server/plugin/keymanager/keymanager.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "crypto/rsa" "fmt" + "strings" "github.com/spiffe/spire/pkg/common/catalog" ) @@ -47,6 +48,21 @@ const ( RSA4096 ) +func KeyTypeFromString(s string) (KeyType, error) { + switch strings.ToLower(s) { + case "rsa-2048": + return RSA2048, nil + case "rsa-4096": + return RSA4096, nil + case "ec-p256": + return ECP256, nil + case "ec-p384": + return ECP384, nil + default: + return KeyTypeUnset, fmt.Errorf("key type %q is unknown; must be one of [rsa-2048, rsa-4096, ec-p256, ec-p384]", s) + } +} + // GenerateSigner generates a new key for the given key type func (keyType KeyType) GenerateSigner() (crypto.Signer, error) { switch keyType { From 638dab15782daae2b13358dd354d92f43097d18c Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Mon, 29 Dec 2025 14:20:36 +0000 Subject: [PATCH 5/6] Check that we can mint WIT-SVIDs Signed-off-by: Sorin Dumitru --- .../suites/fetch-wit-svids/03-mint-svid | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100755 test/integration/suites/fetch-wit-svids/03-mint-svid diff --git a/test/integration/suites/fetch-wit-svids/03-mint-svid b/test/integration/suites/fetch-wit-svids/03-mint-svid new file mode 100755 index 0000000000..54555cde4e --- /dev/null +++ b/test/integration/suites/fetch-wit-svids/03-mint-svid @@ -0,0 +1,18 @@ +#!/bin/bash + +for format in pretty json; do + for keyType in ec-p256 ec-p384 rsa-2048 rsa-4096; do + docker compose exec -T spire-server /opt//spire/bin/spire-server \ + wit mint -spiffeID "spiffe://domain.test/workload" \ + -keyType "${keyType}" \ + -output ${format} || fail-now "could not mint WIT-SVID using key type: ${keyType}" + done +done + +# Check that we can specify a custom TTL +docker compose exec -T spire-server /opt//spire/bin/spire-server \ + wit mint -spiffeID "spiffe://domain.test/workload" -ttl 60s || fail-now "could not mint WIT-SVID with custom TTL" + +# Check that WIT-SVID can be written to a directory +docker compose exec -T spire-server /opt//spire/bin/spire-server \ + wit mint -spiffeID "spiffe://domain.test/workload" -write /tmp || fail-now "could not write WIT-SVID to /tmp" From 56c2fd4ea8af12c329ba6fe2eaa627d83e3a02da Mon Sep 17 00:00:00 2001 From: Sorin Dumitru Date: Tue, 6 Jan 2026 18:22:50 +0000 Subject: [PATCH 6/6] Address review comments Signed-off-by: Sorin Dumitru --- cmd/spire-server/cli/cli.go | 31 +++++++++++++++++-- pkg/server/api/svid/v1/service.go | 2 +- .../suites/fetch-wit-svids/03-mint-svid | 6 ++-- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/cmd/spire-server/cli/cli.go b/cmd/spire-server/cli/cli.go index c1d2ffcf2f..0e7c1636ac 100644 --- a/cmd/spire-server/cli/cli.go +++ b/cmd/spire-server/cli/cli.go @@ -3,6 +3,9 @@ package cli import ( "context" stdlog "log" + "os" + "slices" + "strings" "github.com/mitchellh/cli" "github.com/spiffe/spire/cmd/spire-server/cli/agent" @@ -20,6 +23,7 @@ import ( "github.com/spiffe/spire/cmd/spire-server/cli/validate" "github.com/spiffe/spire/cmd/spire-server/cli/wit" "github.com/spiffe/spire/cmd/spire-server/cli/x509" + "github.com/spiffe/spire/pkg/common/fflag" "github.com/spiffe/spire/pkg/common/log" "github.com/spiffe/spire/pkg/common/version" ) @@ -125,9 +129,6 @@ func (cc *CLI) Run(ctx context.Context, args []string) int { "jwt mint": func() (cli.Command, error) { return jwt.NewMintCommand(), nil }, - "wit mint": func() (cli.Command, error) { - return wit.NewMintCommand(), nil - }, "validate": func() (cli.Command, error) { return validate.NewValidateCommand(), nil }, @@ -169,9 +170,33 @@ func (cc *CLI) Run(ctx context.Context, args []string) int { }, } + addCommandsEnabledByFFlags(c.Commands) + exitStatus, err := c.Run() if err != nil { stdlog.Println(err) } return exitStatus } + +// addCommandsEnabledByFFlags adds commands that are currently available only +// through a feature flag. +// Feature flags support through the fflag package in SPIRE Server is +// designed to work only with the run command and the config file. +// Since feature flags are intended to be used by developers of a specific +// feature only, exposing them through command line arguments is not +// convenient. Instead, we use the SPIRE_SERVER_FFLAGS environment variable +// to read the configured SPIRE Server feature flags from the environment +// when other commands may be enabled through feature flags. +func addCommandsEnabledByFFlags(commands map[string]cli.CommandFactory) { + fflagsEnv := os.Getenv("SPIRE_SERVER_FFLAGS") + fflags := strings.Split(fflagsEnv, " ") + + flagWITSVID := slices.Contains(fflags, string(fflag.FlagWITSVID)) + + if flagWITSVID { + commands["wit mint"] = func() (cli.Command, error) { + return wit.NewMintCommand(), nil + } + } +} diff --git a/pkg/server/api/svid/v1/service.go b/pkg/server/api/svid/v1/service.go index a41cbd3e57..4c9ae09275 100644 --- a/pkg/server/api/svid/v1/service.go +++ b/pkg/server/api/svid/v1/service.go @@ -546,7 +546,7 @@ func (s *Service) newWITSVID(ctx context.Context, param *svidv1.NewWITSVIDParams PublicKey: jose.JSONWebKey{ Key: publicKey, }, - // TODO: add its own TTL + // TODO: add WIT specific TTL (https://github.com/spiffe/spire/issues/6535) TTL: time.Duration(entry.GetX509SvidTtl()) * time.Second, }) if err != nil { diff --git a/test/integration/suites/fetch-wit-svids/03-mint-svid b/test/integration/suites/fetch-wit-svids/03-mint-svid index 54555cde4e..0357cbdb21 100755 --- a/test/integration/suites/fetch-wit-svids/03-mint-svid +++ b/test/integration/suites/fetch-wit-svids/03-mint-svid @@ -2,7 +2,7 @@ for format in pretty json; do for keyType in ec-p256 ec-p384 rsa-2048 rsa-4096; do - docker compose exec -T spire-server /opt//spire/bin/spire-server \ + docker compose exec -T -e SPIRE_SERVER_FFLAGS="wit-svid" spire-server /opt/spire/bin/spire-server \ wit mint -spiffeID "spiffe://domain.test/workload" \ -keyType "${keyType}" \ -output ${format} || fail-now "could not mint WIT-SVID using key type: ${keyType}" @@ -10,9 +10,9 @@ for format in pretty json; do done # Check that we can specify a custom TTL -docker compose exec -T spire-server /opt//spire/bin/spire-server \ +docker compose exec -T -e SPIRE_SERVER_FFLAGS="wit-svid" spire-server /opt/spire/bin/spire-server \ wit mint -spiffeID "spiffe://domain.test/workload" -ttl 60s || fail-now "could not mint WIT-SVID with custom TTL" # Check that WIT-SVID can be written to a directory -docker compose exec -T spire-server /opt//spire/bin/spire-server \ +docker compose exec -T -e SPIRE_SERVER_FFLAGS="wit-svid" spire-server /opt/spire/bin/spire-server \ wit mint -spiffeID "spiffe://domain.test/workload" -write /tmp || fail-now "could not write WIT-SVID to /tmp"