From 68ecabf56af7966436c4cb211e219cb8eec9c80e Mon Sep 17 00:00:00 2001 From: daeMOn Date: Thu, 8 Oct 2020 16:53:04 +0200 Subject: [PATCH 01/24] add biscuit wrappers and helpers to generate, sign and verify hubauth biscuits --- go.sum | 37 +--- pkg/biscuit/biscuit.go | 137 ++++++++++++++ pkg/biscuit/biscuit_test.go | 57 ++++++ pkg/biscuit/signature.go | 162 ++++++++++++++++ pkg/biscuit/signature_test.go | 258 ++++++++++++++++++++++++++ pkg/biscuit/wrapper.go | 336 ++++++++++++++++++++++++++++++++++ 6 files changed, 958 insertions(+), 29 deletions(-) create mode 100644 pkg/biscuit/biscuit.go create mode 100644 pkg/biscuit/biscuit_test.go create mode 100644 pkg/biscuit/signature.go create mode 100644 pkg/biscuit/signature_test.go create mode 100644 pkg/biscuit/wrapper.go diff --git a/go.sum b/go.sum index 287ce43..aa14fc0 100644 --- a/go.sum +++ b/go.sum @@ -81,21 +81,11 @@ github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345 h1:ME6bm5dwn9V2DU github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345/go.mod h1:Sj4oR2hNkrZH1cf3Cj5DPHc3Xq0o61GWeau6UkZR+3c= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca h1:LUZQQzaCT+gltxii4icyPH5oMdAP38JmbvO9aI0E4qM= +github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca/go.mod h1:EMJZ3stAYtwaP763F5HcGjPjCnYu21V2TEsg/iw88I8= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= @@ -291,22 +281,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -324,6 +300,7 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3 h1:kzM6+9dur93BcC2kVlYl34cHU+TYZLanmpSJHVMmL64= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201214095126-aec9a390925b h1:tv7/y4pd+sR8bcNb2D6o7BNU6zjWm0VjQLac+w7fNNM= golang.org/x/sys v0.0.0-20201214095126-aec9a390925b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -386,6 +363,7 @@ golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82u golang.org/x/tools v0.0.0-20200916150407-587cf2330ce8/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201208233053-a543418bbed2 h1:vEtypaVub6UvKkiXZ2xx9QIvp9TL7sI7xp7vdi2kezA= golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58 h1:1Bs6RVeBFtLZ8Yi1Hk07DiOqzvwLD/4hln4iahvFlag= golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -457,6 +435,7 @@ google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200916143405-f6a2fa72f0c4/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc h1:BgQmMjmd7K1zov8j8lYULHW0WnmBGUIMp6+VDwlGErc= google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201211151036-40ec1c210f7a h1:GnJAhasbD8HiT8DZMvsEx3QLVy/X0icq/MGr0MqRJ2M= google.golang.org/genproto v0.0.0-20201211151036-40ec1c210f7a/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go new file mode 100644 index 0000000..ef1a452 --- /dev/null +++ b/pkg/biscuit/biscuit.go @@ -0,0 +1,137 @@ +package biscuit + +import ( + "crypto/rand" + "fmt" + + "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/kmssign" +) + +type UserKeyPair struct { + Public []byte + Private []byte +} + +// GenerateSignable returns a biscuit which will only verify after being +// signed with the private key matching the given userPubkey. +func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPubkey []byte) ([]byte, error) { + builder := &hubauthBuilder{ + Builder: biscuit.NewBuilder(rand.Reader, rootKey), + } + + if err := builder.withAudienceSignature(audience, audienceKey); err != nil { + return nil, err + } + + if err := builder.withUserToSignFact(userPubkey); err != nil { + return nil, err + } + + b, err := builder.Build() + if err != nil { + return nil, err + } + + return b.Serialize() +} + +// Sign append a user signature on the given token and return it. +// The UserKeyPair key format to provide depends on the signature algorithm: +// - for ECDSA_P256_SHA256, the private key must be encoded in SEC 1, ASN.1 DER form, +// and the public key in PKIX, ASN.1 DER form. +func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, error) { + b, err := biscuit.Unmarshal(token) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) + } + + v, err := b.Verify(rootPubKey) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to verify: %w", err) + } + verifier := &hubauthVerifier{ + Verifier: v, + } + + toSignData, err := verifier.getUserToSignData(userKey.Public, b.BlockCount()) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to get to_sign data: %w", err) + } + + if err := verifier.ensureNotAlreadyUserSigned(toSignData.DataID, userKey.Public); err != nil { + return nil, fmt.Errorf("biscuit: previous signature check failed: %w", err) + } + + tokenHash, err := b.SHA256Sum(b.BlockCount()) + if err != nil { + return nil, err + } + + signData, err := userSign(tokenHash, userKey, toSignData) + if err != nil { + return nil, fmt.Errorf("biscuit: signature failed: %w", err) + } + + builder := &hubauthBlockBuilder{ + BlockBuilder: b.CreateBlock(), + } + if err := builder.withUserSignature(signData); err != nil { + return nil, fmt.Errorf("biscuit: failed to create signature block: %w", err) + } + + clientKey := sig.GenerateKeypair(rand.Reader) + b, err = b.Append(rand.Reader, clientKey, builder.Build()) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to append signature block: %w", err) + } + + return b.Serialize() +} + +// Verify will verify the biscuit, the included audience and user signature, and return an error +// when anything is invalid. +func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) error { + b, err := biscuit.Unmarshal(token) + if err != nil { + return fmt.Errorf("biscuit: failed to unmarshal: %w", err) + } + + v, err := b.Verify(rootPubKey) + if err != nil { + return fmt.Errorf("biscuit: failed to verify: %w", err) + } + verifier := &hubauthVerifier{v} + + audienceVerificationData, err := verifier.getAudienceVerificationData(audience) + if err != nil { + return fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) + } + + if err := verifyAudienceSignature(audienceKey, audienceVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to verify audience signature: %w", err) + } + if err := verifier.withValidatedAudienceSignature(audienceVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + } + + userVerificationData, err := verifier.getUserVerificationData() + if err != nil { + return fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) + } + + signedTokenHash, err := b.SHA256Sum(int(userVerificationData.SignedBlockCount)) + if err != nil { + return fmt.Errorf("biscuit: failed to generate token hash: %w", err) + } + + if err := verifyUserSignature(signedTokenHash, userVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to verify user signature: %w", err) + } + if err := verifier.withValidatedUserSignature(userVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + } + + return verifier.Verify() +} diff --git a/pkg/biscuit/biscuit_test.go b/pkg/biscuit/biscuit_test.go new file mode 100644 index 0000000..ed7b273 --- /dev/null +++ b/pkg/biscuit/biscuit_test.go @@ -0,0 +1,57 @@ +package biscuit + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/kmssign/kmssim" + "github.com/stretchr/testify/require" +) + +func TestBiscuit(t *testing.T) { + rootKey := sig.GenerateKeypair(rand.Reader) + audience := "http://random.audience.url" + + kms := kmssim.NewClient([]string{audience}) + audienceKey, err := kmssign.NewKey(context.Background(), kms, audience) + require.NoError(t, err) + + userKey := generateUserKeyPair(t) + + signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public) + require.NoError(t, err) + t.Logf("signable biscuit size: %d", len(signableBiscuit)) + + t.Run("happy path", func(t *testing.T) { + signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) + require.NoError(t, err) + t.Logf("signed biscuit size: %d", len(signedBiscuit)) + + err = Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) + require.NoError(t, err) + }) + + t.Run("user sign with wrong key", func(t *testing.T) { + _, err := Sign(signableBiscuit, rootKey.Public(), generateUserKeyPair(t)) + require.Error(t, err) + }) + + t.Run("verify wrong audience", func(t *testing.T) { + signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) + require.NoError(t, err) + + err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) + require.Error(t, err) + + wrongAudience := "http://another.audience.url" + kms := kmssim.NewClient([]string{wrongAudience}) + wrongAudienceKey, err := kmssign.NewKey(context.Background(), kms, wrongAudience) + require.NoError(t, err) + + err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) + require.Error(t, err) + }) +} diff --git a/pkg/biscuit/signature.go b/pkg/biscuit/signature.go new file mode 100644 index 0000000..385e4bf --- /dev/null +++ b/pkg/biscuit/signature.go @@ -0,0 +1,162 @@ +package biscuit + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "errors" + "time" + + "github.com/flynn/biscuit-go" + "github.com/flynn/hubauth/pkg/kmssign" +) + +var ( + ErrUnsupportedSignatureAlg = errors.New("unsupported signature algorithm") + ErrInvalidSignature = errors.New("invalid signature") +) + +type SignatureAlg biscuit.Symbol + +const ( + ECDSA_P256_SHA256 SignatureAlg = "ECDSA_P256_SHA256" +) + +type userToSignData struct { + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes + SignedBlockCount biscuit.Integer +} + +type userSignatureData struct { + DataID biscuit.Integer + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + SignedBlockCount biscuit.Integer + Nonce biscuit.Bytes + Timestamp biscuit.Date +} + +type userVerificationData struct { + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + SignedBlockCount biscuit.Integer + Nonce biscuit.Bytes + Timestamp biscuit.Date +} + +func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData) (*userSignatureData, error) { + if len(tokenHash) == 0 { + return nil, errors.New("invalid tokenHash") + } + + signerTimestamp := time.Now() + signerNonce := make([]byte, nonceSize) + if _, err := rand.Read(signerNonce); err != nil { + return nil, err + } + + var dataToSign []byte + dataToSign = append(dataToSign, toSignData.Data...) + dataToSign = append(dataToSign, tokenHash...) + dataToSign = append(dataToSign, signerNonce...) + dataToSign = append(dataToSign, []byte(signerTimestamp.Format(time.RFC3339))...) + dataToSign = append(dataToSign, []byte(toSignData.SignedBlockCount.String())...) + + var signedData biscuit.Bytes + switch SignatureAlg(toSignData.Alg) { + case ECDSA_P256_SHA256: + privKey, err := x509.ParseECPrivateKey(userKey.Private) + if err != nil { + return nil, err + } + hash := sha256.Sum256(dataToSign) + signedData, err = ecdsa.SignASN1(rand.Reader, privKey, hash[:]) + if err != nil { + return nil, err + } + default: + return nil, ErrUnsupportedSignatureAlg + } + + return &userSignatureData{ + DataID: toSignData.DataID, + Nonce: signerNonce, + Signature: signedData, + SignedBlockCount: toSignData.SignedBlockCount, + Timestamp: biscuit.Date(signerTimestamp), + UserPubKey: userKey.Public, + }, nil +} + +func verifyUserSignature(signedTokenHash []byte, data *userVerificationData) error { + var signedData []byte + signedData = append(signedData, data.Data...) + signedData = append(signedData, signedTokenHash...) + signedData = append(signedData, data.Nonce...) + signedData = append(signedData, []byte(time.Time(data.Timestamp).Format(time.RFC3339))...) + signedData = append(signedData, []byte(data.SignedBlockCount.String())...) + + switch SignatureAlg(data.Alg) { + case ECDSA_P256_SHA256: + pk, err := x509.ParsePKIXPublicKey(data.UserPubKey) + if err != nil { + return err + } + pubkey, ok := pk.(*ecdsa.PublicKey) + if !ok { + return errors.New("invalid pubkey, not an *ecdsa.PublicKey") + } + + hash := sha256.Sum256(signedData) + if !ecdsa.VerifyASN1(pubkey, hash[:], data.Signature) { + return ErrInvalidSignature + } + return nil + default: + return ErrUnsupportedSignatureAlg + } +} + +type audienceVerificationData struct { + Audience biscuit.Symbol + Challenge biscuit.Bytes + Signature biscuit.Bytes +} + +func audienceSign(audience string, audienceKey *kmssign.Key) (*audienceVerificationData, error) { + challenge := make([]byte, challengeSize) + if _, err := rand.Reader.Read(challenge); err != nil { + return nil, err + } + + signedData := append(signStaticCtx, challenge...) + signedData = append(signedData, []byte(audience)...) + signedHash := sha256.Sum256(signedData) + signature, err := audienceKey.Sign(rand.Reader, signedHash[:], crypto.SHA256) + if err != nil { + return nil, err + } + + return &audienceVerificationData{ + Audience: biscuit.Symbol(audience), + Challenge: challenge, + Signature: signature, + }, nil +} + +func verifyAudienceSignature(audiencePubkey *kmssign.Key, data *audienceVerificationData) error { + signedData := append(signStaticCtx, data.Challenge...) + signedData = append(signedData, []byte(data.Audience)...) + hash := sha256.Sum256(signedData) + if !audiencePubkey.Verify(hash[:], data.Signature) { + return errors.New("invalid signature") + } + return nil +} diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go new file mode 100644 index 0000000..5cb5d8b --- /dev/null +++ b/pkg/biscuit/signature_test.go @@ -0,0 +1,258 @@ +package biscuit + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "testing" + "time" + + "github.com/flynn/biscuit-go" + "github.com/stretchr/testify/require" +) + +func TestUserSignVerify(t *testing.T) { + tokenHash := make([]byte, 32) + _, err := rand.Read(tokenHash) + require.NoError(t, err) + + challenge := make([]byte, challengeSize) + _, err = rand.Read(challenge) + require.NoError(t, err) + + userKey := generateUserKeyPair(t) + + toSignData := &userToSignData{ + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), + SignedBlockCount: 2, + } + + signedData, err := userSign(tokenHash, userKey, toSignData) + require.NoError(t, err) + require.NotEmpty(t, signedData.Signature) + require.Equal(t, biscuit.Integer(2), signedData.SignedBlockCount) + require.Equal(t, biscuit.Integer(1), signedData.DataID) + require.Equal(t, biscuit.Bytes(userKey.Public), signedData.UserPubKey) + + require.Len(t, signedData.Nonce, nonceSize) + zeroNonce := make([]byte, nonceSize) + require.NotEqual(t, biscuit.Bytes(zeroNonce), signedData.Nonce) + + require.WithinDuration(t, time.Now(), time.Time(signedData.Timestamp), 1*time.Second) + + require.NoError(t, verifyUserSignature(tokenHash, &userVerificationData{ + DataID: toSignData.DataID, + Alg: toSignData.Alg, + Data: toSignData.Data, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + SignedBlockCount: signedData.SignedBlockCount, + Timestamp: signedData.Timestamp, + UserPubKey: signedData.UserPubKey, + })) +} + +func TestUserSignFail(t *testing.T) { + validTokenHash := make([]byte, 32) + _, err := rand.Read(validTokenHash) + require.NoError(t, err) + + validChallenge := make([]byte, challengeSize) + _, err = rand.Read(validChallenge) + require.NoError(t, err) + + invalidPrivateKey := &UserKeyPair{ + Private: make([]byte, 32), + } + + testCases := []struct { + desc string + tokenHash []byte + userKey *UserKeyPair + data *userToSignData + expectedErr error + }{ + { + desc: "empty tokenHash", + tokenHash: []byte{}, + }, + { + desc: "unsupported alg", + tokenHash: validTokenHash, + data: &userToSignData{ + Alg: "unsupported", + }, + expectedErr: ErrUnsupportedSignatureAlg, + }, + { + desc: "wrong private key encoding", + tokenHash: validTokenHash, + data: &userToSignData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + }, + userKey: invalidPrivateKey, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.desc, func(t *testing.T) { + _, err := userSign(testCase.tokenHash, testCase.userKey, testCase.data) + require.Error(t, err) + if testCase.expectedErr != nil { + require.Equal(t, testCase.expectedErr, err) + } + }) + } +} + +func TestVerifyUserSignatureFail(t *testing.T) { + tokenHash := []byte("token hash") + toSignData := &userToSignData{ + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), + SignedBlockCount: 2, + } + + userKey := generateUserKeyPair(t) + invalidKey := generateUserKeyPair(t) + + signedData, err := userSign(tokenHash, userKey, toSignData) + require.NoError(t, err) + + rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + wrongKeyKind, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey) + require.NoError(t, err) + + testCases := []struct { + desc string + tokenHash []byte + data *userVerificationData + expectedErr error + }{ + { + desc: "unsupported alg", + expectedErr: ErrUnsupportedSignatureAlg, + data: &userVerificationData{ + Alg: "unknown", + }, + }, + { + desc: "invalid pubkey encoding", + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: make([]byte, 32), + }, + }, + { + desc: "invalid pubkey kind", + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: wrongKeyKind, + }, + }, + { + desc: "wrong pubkey", + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: invalidKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + SignedBlockCount: signedData.SignedBlockCount, + }, + }, + { + desc: "tampered token hash", + expectedErr: ErrInvalidSignature, + tokenHash: []byte("wrong"), + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + SignedBlockCount: signedData.SignedBlockCount, + }, + }, + { + desc: "tampered nonce", + expectedErr: ErrInvalidSignature, + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: []byte("another nonce"), + Signature: signedData.Signature, + SignedBlockCount: signedData.SignedBlockCount, + Timestamp: signedData.Timestamp, + }, + }, + { + desc: "tampered timestamp", + expectedErr: ErrInvalidSignature, + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), + SignedBlockCount: signedData.SignedBlockCount, + }, + }, + { + desc: "tampered signedBlockCount", + expectedErr: ErrInvalidSignature, + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + SignedBlockCount: signedData.SignedBlockCount + 1, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.desc, func(t *testing.T) { + + err := verifyUserSignature(testCase.tokenHash, testCase.data) + require.Error(t, err) + if testCase.expectedErr != nil { + require.Equal(t, testCase.expectedErr, err) + } + }) + } +} + +func generateUserKeyPair(t *testing.T) *UserKeyPair { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + privBytes, err := x509.MarshalECPrivateKey(priv) + require.NoError(t, err) + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + return &UserKeyPair{ + Private: privBytes, + Public: pubBytes, + } +} diff --git a/pkg/biscuit/wrapper.go b/pkg/biscuit/wrapper.go new file mode 100644 index 0000000..ec16ebd --- /dev/null +++ b/pkg/biscuit/wrapper.go @@ -0,0 +1,336 @@ +package biscuit + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + + "github.com/flynn/biscuit-go" + "github.com/flynn/hubauth/pkg/kmssign" +) + +var ( + ErrAlreadySigned = errors.New("already signed") + ErrInvalidToSignDataPrefix = errors.New("invalid to_sign data prefix") +) + +var ( + signStaticCtx = []byte("biscuit-pop-v0") + challengeSize = 16 + nonceSize = 16 +) + +type hubauthBuilder struct { + biscuit.Builder +} + +// withUserToSignFact add an authority should_sign fact and associated data to the biscuit +// with an authority caveat requiring the verifier to provide a valid_signature fact. +// the verifier is responsible of ensuring that a valid signature exists over the data. +func (b *hubauthBuilder) withUserToSignFact(userPubkey []byte) error { + dataID := biscuit.Integer(0) + + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "should_sign", + IDs: []biscuit.Atom{ + dataID, + biscuit.Symbol(ECDSA_P256_SHA256), + biscuit.Bytes(userPubkey), + }, + }}); err != nil { + return err + } + + challenge := make([]byte, challengeSize) + if _, err := rand.Reader.Read(challenge); err != nil { + return err + } + + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "data", + IDs: []biscuit.Atom{ + dataID, + biscuit.Bytes(append(signStaticCtx, challenge...)), + }, + }}); err != nil { + return err + } + + if err := b.AddAuthorityCaveat(biscuit.Rule{ + Head: biscuit.Predicate{Name: "valid", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + {Name: "valid_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + }, + }); err != nil { + return err + } + + return nil +} + +// withAudienceSignature add an authority audience_signature fact, containing a challenge and +// a matching signature using the audience key. +// the verifier is responsible of providing a valid_audience_signature fact, after +// verifying the signature using the audience pubkey. +func (b *hubauthBuilder) withAudienceSignature(audience string, audienceKey *kmssign.Key) error { + if len(audience) == 0 { + return errors.New("audience is required") + } + + data, err := audienceSign(audience, audienceKey) + if err != nil { + return err + } + + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "audience_signature", + IDs: []biscuit.Atom{ + data.Audience, + data.Challenge, + data.Signature, + }, + }}); err != nil { + return err + } + + if err := b.AddAuthorityCaveat(biscuit.Rule{ + Head: biscuit.Predicate{Name: "valid_audience", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + {Name: "valid_audience_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(2)}}, + }, + }); err != nil { + return err + } + + return nil +} + +type hubauthBlockBuilder struct { + biscuit.BlockBuilder +} + +func (b *hubauthBlockBuilder) withUserSignature(sigData *userSignatureData) error { + return b.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "signature", + IDs: []biscuit.Atom{ + sigData.DataID, + sigData.UserPubKey, + sigData.Signature, + sigData.Nonce, + sigData.Timestamp, + sigData.SignedBlockCount, + }, + }}) +} + +type hubauthVerifier struct { + biscuit.Verifier +} + +func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes, signedBlockCount int) (*userToSignData, error) { + toSign, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "to_sign", + IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}, + }, + Body: []biscuit.Predicate{ + { + Name: "should_sign", IDs: []biscuit.Atom{ + biscuit.SymbolAuthority, + biscuit.Variable(0), + biscuit.Variable(1), + biscuit.Bytes(userPubKey), + }, + }, { + Name: "data", IDs: []biscuit.Atom{ + biscuit.SymbolAuthority, + biscuit.Variable(0), + biscuit.Variable(2), + }, + }, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(toSign), 1; g != w { + return nil, fmt.Errorf("invalid to_sign fact count, got %d, want %d", g, w) + } + + toSignFact := toSign[0] + if g, w := len(toSignFact.IDs), 3; g != w { + return nil, fmt.Errorf("invalid to_sign fact, got %d atoms, want %d", g, w) + } + + sigData := &userToSignData{} + var ok bool + sigData.DataID, ok = toSign[0].IDs[0].(biscuit.Integer) + if !ok { + return nil, errors.New("invalid to_sign atom: dataID") + } + sigData.Alg, ok = toSign[0].IDs[1].(biscuit.Symbol) + if !ok { + return nil, errors.New("invalid to_sign atom: alg") + } + sigData.Data, ok = toSign[0].IDs[2].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_sign atom: data") + } + + if !bytes.HasPrefix(sigData.Data, signStaticCtx) { + return nil, ErrInvalidToSignDataPrefix + } + + sigData.SignedBlockCount = biscuit.Integer(signedBlockCount) + + return sigData, nil +} + +func (v *hubauthVerifier) ensureNotAlreadyUserSigned(dataID biscuit.Integer, userPubKey biscuit.Bytes) error { + alreadySigned, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{Name: "already_signed", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "signature", IDs: []biscuit.Atom{dataID, userPubKey, biscuit.Variable(0)}}, + }, + }) + if err != nil { + return err + } + if len(alreadySigned) != 0 { + return ErrAlreadySigned + } + + return nil +} + +func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, error) { + toValidate, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "to_validate", + IDs: []biscuit.Atom{ + biscuit.Variable(0), // dataID + biscuit.Variable(1), // alg + biscuit.Variable(2), // pubkey + biscuit.Variable(3), // data + biscuit.Variable(4), // signature + biscuit.Variable(5), // signerNonce + biscuit.Variable(6), // signerTimestamp + biscuit.Variable(7), // signedBlockCount + }}, + Body: []biscuit.Predicate{ + {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + {Name: "data", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(3)}}, + {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6), biscuit.Variable(7)}}, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(toValidate), 1; g != w { + return nil, fmt.Errorf("invalid to_validate fact count, got %d, want %d", g, w) + } + + toValidateFact := toValidate[0] + if g, w := len(toValidateFact.IDs), 8; g != w { + return nil, fmt.Errorf("invalid to_valid fact atom count, got %d, want %d", g, w) + } + + toVerify := &userVerificationData{} + var ok bool + toVerify.DataID, ok = toValidateFact.IDs[0].(biscuit.Integer) + if !ok { + return nil, errors.New("invalid to_validate atom: dataID") + } + toVerify.Alg, ok = toValidateFact.IDs[1].(biscuit.Symbol) + if !ok { + return nil, errors.New("invalid to_validate atom: alg") + } + toVerify.UserPubKey, ok = toValidateFact.IDs[2].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: userPubKey") + } + toVerify.Data, ok = toValidateFact.IDs[3].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: data") + } + toVerify.Signature, ok = toValidateFact.IDs[4].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: signature") + } + toVerify.Nonce, ok = toValidateFact.IDs[5].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: nonce") + } + toVerify.Timestamp, ok = toValidateFact.IDs[6].(biscuit.Date) + if !ok { + return nil, errors.New("invalid to_validate atom: timestamp") + } + toVerify.SignedBlockCount, ok = toValidateFact.IDs[7].(biscuit.Integer) + if !ok { + return nil, errors.New("invalid to_validate atom: signedBlockCount") + } + + return toVerify, nil +} + +func (v *hubauthVerifier) withValidatedUserSignature(data *userVerificationData) error { + v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "valid_signature", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.DataID, data.Alg, data.UserPubKey}, + }}) + + return nil +} + +func (v *hubauthVerifier) getAudienceVerificationData(audience string) (*audienceVerificationData, error) { + toValidate, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "audience_to_validate", + IDs: []biscuit.Atom{ + biscuit.Variable(0), // challenge + biscuit.Variable(1), // signature + }}, + Body: []biscuit.Predicate{ + {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Symbol(audience), biscuit.Variable(0), biscuit.Variable(1)}}, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(toValidate), 1; g != w { + return nil, fmt.Errorf("invalid audience_to_validate fact count, got %d, want %d", g, w) + } + + toValidateFact := toValidate[0] + if g, w := len(toValidateFact.IDs), 2; g != w { + return nil, fmt.Errorf("invalid audience_to_validate fact atom count, got %d, want %d", g, w) + } + + toVerify := &audienceVerificationData{Audience: biscuit.Symbol(audience)} + var ok bool + toVerify.Challenge, ok = toValidateFact.IDs[0].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid audience_to_validate atom: challenge") + } + toVerify.Signature, ok = toValidateFact.IDs[1].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid audience_to_validate atom: signature") + } + + return toVerify, nil +} + +func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificationData) error { + v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "valid_audience_signature", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.Audience, data.Signature}, + }}) + + return nil +} From 4db21fa5c46f56c406d1e6930179f01d8c54fbaa Mon Sep 17 00:00:00 2001 From: daeMOn Date: Mon, 12 Oct 2020 09:39:15 +0200 Subject: [PATCH 02/24] add biscuit metadata and expiration date Verifiers must now provide the current time for verifying the biscuit, and can extract user informations. User's pubkeys are now provided in http param when exchanging code. Removed block count from biscuit weakening the signature. --- go.sum | 38 +++++++++++- pkg/biscuit/biscuit.go | 56 ++++++++++++----- pkg/biscuit/biscuit_test.go | 20 ++++-- pkg/biscuit/signature.go | 66 ++++++++++++-------- pkg/biscuit/signature_test.go | 105 +++++++++++++------------------ pkg/biscuit/wrapper.go | 112 ++++++++++++++++++++++++++++++---- 6 files changed, 273 insertions(+), 124 deletions(-) diff --git a/go.sum b/go.sum index aa14fc0..d812a10 100644 --- a/go.sum +++ b/go.sum @@ -77,15 +77,32 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca h1:LUZQQzaCT+gltxii4icyPH5oMdAP38JmbvO9aI0E4qM= +github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca/go.mod h1:EMJZ3stAYtwaP763F5HcGjPjCnYu21V2TEsg/iw88I8= +github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195 h1:TP3jMHmhjz8XxqqigEd5OQffNAO/6KPvGUYII6TFdmI= +github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195/go.mod h1:EMJZ3stAYtwaP763F5HcGjPjCnYu21V2TEsg/iw88I8= github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345 h1:ME6bm5dwn9V2DUlfXJqeN121B5nM7rDFqLFOATALqYE= github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345/go.mod h1:Sj4oR2hNkrZH1cf3Cj5DPHc3Xq0o61GWeau6UkZR+3c= +github.com/flynn/biscuit-go v0.0.0-20201204161836-6af1c88a7b3d h1:RHIlExiAgFgF1hQzdjhq41dnlOlkbcsOczQD+YgVQRk= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca h1:LUZQQzaCT+gltxii4icyPH5oMdAP38JmbvO9aI0E4qM= -github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca/go.mod h1:EMJZ3stAYtwaP763F5HcGjPjCnYu21V2TEsg/iw88I8= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= @@ -282,7 +299,22 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go index ef1a452..e335e60 100644 --- a/pkg/biscuit/biscuit.go +++ b/pkg/biscuit/biscuit.go @@ -3,12 +3,20 @@ package biscuit import ( "crypto/rand" "fmt" + "time" "github.com/flynn/biscuit-go" "github.com/flynn/biscuit-go/sig" "github.com/flynn/hubauth/pkg/kmssign" ) +type Metadata struct { + ClientID string + UserID string + UserEmail string + IssueTime time.Time +} + type UserKeyPair struct { Public []byte Private []byte @@ -16,7 +24,7 @@ type UserKeyPair struct { // GenerateSignable returns a biscuit which will only verify after being // signed with the private key matching the given userPubkey. -func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPubkey []byte) ([]byte, error) { +func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPublicKey []byte, expireTime time.Time, m *Metadata) ([]byte, error) { builder := &hubauthBuilder{ Builder: biscuit.NewBuilder(rand.Reader, rootKey), } @@ -25,7 +33,15 @@ func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign return nil, err } - if err := builder.withUserToSignFact(userPubkey); err != nil { + if err := builder.withUserToSignFact(userPublicKey); err != nil { + return nil, err + } + + if err := builder.withExpire(expireTime); err != nil { + return nil, err + } + + if err := builder.withMetadata(m); err != nil { return nil, err } @@ -55,7 +71,7 @@ func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, Verifier: v, } - toSignData, err := verifier.getUserToSignData(userKey.Public, b.BlockCount()) + toSignData, err := verifier.getUserToSignData(userKey.Public) if err != nil { return nil, fmt.Errorf("biscuit: failed to get to_sign data: %w", err) } @@ -92,46 +108,56 @@ func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, // Verify will verify the biscuit, the included audience and user signature, and return an error // when anything is invalid. -func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) error { +func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) (*Metadata, error) { b, err := biscuit.Unmarshal(token) if err != nil { - return fmt.Errorf("biscuit: failed to unmarshal: %w", err) + return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) } v, err := b.Verify(rootPubKey) if err != nil { - return fmt.Errorf("biscuit: failed to verify: %w", err) + return nil, fmt.Errorf("biscuit: failed to verify: %w", err) } verifier := &hubauthVerifier{v} audienceVerificationData, err := verifier.getAudienceVerificationData(audience) if err != nil { - return fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) + return nil, fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) } if err := verifyAudienceSignature(audienceKey, audienceVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to verify audience signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to verify audience signature: %w", err) } if err := verifier.withValidatedAudienceSignature(audienceVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) } userVerificationData, err := verifier.getUserVerificationData() if err != nil { - return fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) + return nil, fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) } - signedTokenHash, err := b.SHA256Sum(int(userVerificationData.SignedBlockCount)) + // TODO: improve biscuit API to allow retrieve the block index the signature is at + // so that we can still append other blocks if needed. Right now the signature MUST BE the last block. + signedTokenHash, err := b.SHA256Sum(b.BlockCount() - 1) if err != nil { - return fmt.Errorf("biscuit: failed to generate token hash: %w", err) + return nil, fmt.Errorf("biscuit: failed to generate token hash: %w", err) } if err := verifyUserSignature(signedTokenHash, userVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to verify user signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to verify user signature: %w", err) } if err := verifier.withValidatedUserSignature(userVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) + } + + if err := verifier.withCurrentTime(time.Now()); err != nil { + return nil, fmt.Errorf("biscuit: failed to add current time: %w", err) + } + + if err := verifier.Verify(); err != nil { + return nil, fmt.Errorf("biscuit: failed to verify: %w", err) } - return verifier.Verify() + return verifier.getMetadata() } diff --git a/pkg/biscuit/biscuit_test.go b/pkg/biscuit/biscuit_test.go index ed7b273..bf6199b 100644 --- a/pkg/biscuit/biscuit_test.go +++ b/pkg/biscuit/biscuit_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "testing" + "time" "github.com/flynn/biscuit-go/sig" "github.com/flynn/hubauth/pkg/kmssign" @@ -20,8 +21,13 @@ func TestBiscuit(t *testing.T) { require.NoError(t, err) userKey := generateUserKeyPair(t) - - signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public) + metas := &Metadata{ + ClientID: "abcd", + UserEmail: "1234@example.com", + UserID: "1234", + IssueTime: time.Now(), + } + signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public, time.Now().Add(5*time.Minute), metas) require.NoError(t, err) t.Logf("signable biscuit size: %d", len(signableBiscuit)) @@ -30,8 +36,12 @@ func TestBiscuit(t *testing.T) { require.NoError(t, err) t.Logf("signed biscuit size: %d", len(signedBiscuit)) - err = Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) + res, err := Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) require.NoError(t, err) + require.Equal(t, metas.ClientID, res.ClientID) + require.Equal(t, metas.UserID, res.UserID) + require.Equal(t, metas.UserEmail, res.UserEmail) + require.WithinDuration(t, metas.IssueTime, res.IssueTime, 1*time.Second) }) t.Run("user sign with wrong key", func(t *testing.T) { @@ -43,7 +53,7 @@ func TestBiscuit(t *testing.T) { signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) require.NoError(t, err) - err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) + _, err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) require.Error(t, err) wrongAudience := "http://another.audience.url" @@ -51,7 +61,7 @@ func TestBiscuit(t *testing.T) { wrongAudienceKey, err := kmssign.NewKey(context.Background(), kms, wrongAudience) require.NoError(t, err) - err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) + _, err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) require.Error(t, err) }) } diff --git a/pkg/biscuit/signature.go b/pkg/biscuit/signature.go index 385e4bf..078eaa6 100644 --- a/pkg/biscuit/signature.go +++ b/pkg/biscuit/signature.go @@ -3,10 +3,12 @@ package biscuit import ( "crypto" "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/sha256" "crypto/x509" "errors" + "fmt" "time" "github.com/flynn/biscuit-go" @@ -25,30 +27,27 @@ const ( ) type userToSignData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes - SignedBlockCount biscuit.Integer + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes } type userSignatureData struct { - DataID biscuit.Integer - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - SignedBlockCount biscuit.Integer - Nonce biscuit.Bytes - Timestamp biscuit.Date + DataID biscuit.Integer + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + Nonce biscuit.Bytes + Timestamp biscuit.Date } type userVerificationData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - SignedBlockCount biscuit.Integer - Nonce biscuit.Bytes - Timestamp biscuit.Date + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + Nonce biscuit.Bytes + Timestamp biscuit.Date } func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData) (*userSignatureData, error) { @@ -67,7 +66,6 @@ func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData dataToSign = append(dataToSign, tokenHash...) dataToSign = append(dataToSign, signerNonce...) dataToSign = append(dataToSign, []byte(signerTimestamp.Format(time.RFC3339))...) - dataToSign = append(dataToSign, []byte(toSignData.SignedBlockCount.String())...) var signedData biscuit.Bytes switch SignatureAlg(toSignData.Alg) { @@ -86,12 +84,11 @@ func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData } return &userSignatureData{ - DataID: toSignData.DataID, - Nonce: signerNonce, - Signature: signedData, - SignedBlockCount: toSignData.SignedBlockCount, - Timestamp: biscuit.Date(signerTimestamp), - UserPubKey: userKey.Public, + DataID: toSignData.DataID, + Nonce: signerNonce, + Signature: signedData, + Timestamp: biscuit.Date(signerTimestamp), + UserPubKey: userKey.Public, }, nil } @@ -101,7 +98,6 @@ func verifyUserSignature(signedTokenHash []byte, data *userVerificationData) err signedData = append(signedData, signedTokenHash...) signedData = append(signedData, data.Nonce...) signedData = append(signedData, []byte(time.Time(data.Timestamp).Format(time.RFC3339))...) - signedData = append(signedData, []byte(data.SignedBlockCount.String())...) switch SignatureAlg(data.Alg) { case ECDSA_P256_SHA256: @@ -160,3 +156,21 @@ func verifyAudienceSignature(audiencePubkey *kmssign.Key, data *audienceVerifica } return nil } + +func validatePKIXP256PublicKey(pubkey []byte) error { + key, err := x509.ParsePKIXPublicKey(pubkey) + if err != nil { + return fmt.Errorf("failed to parse PKIX, ASN.1 DER public key: %v", err) + } + + ecKey, ok := key.(*ecdsa.PublicKey) + if !ok { + return errors.New("public key is not an *ecdsa.PublicKey") + } + + if ecKey.Curve != elliptic.P256() { + return fmt.Errorf("publickey is on wrong curve, expected P256") + } + + return nil +} diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go index 5cb5d8b..d4d4cc2 100644 --- a/pkg/biscuit/signature_test.go +++ b/pkg/biscuit/signature_test.go @@ -25,16 +25,14 @@ func TestUserSignVerify(t *testing.T) { userKey := generateUserKeyPair(t) toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - SignedBlockCount: 2, + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), } signedData, err := userSign(tokenHash, userKey, toSignData) require.NoError(t, err) require.NotEmpty(t, signedData.Signature) - require.Equal(t, biscuit.Integer(2), signedData.SignedBlockCount) require.Equal(t, biscuit.Integer(1), signedData.DataID) require.Equal(t, biscuit.Bytes(userKey.Public), signedData.UserPubKey) @@ -45,14 +43,13 @@ func TestUserSignVerify(t *testing.T) { require.WithinDuration(t, time.Now(), time.Time(signedData.Timestamp), 1*time.Second) require.NoError(t, verifyUserSignature(tokenHash, &userVerificationData{ - DataID: toSignData.DataID, - Alg: toSignData.Alg, - Data: toSignData.Data, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - SignedBlockCount: signedData.SignedBlockCount, - Timestamp: signedData.Timestamp, - UserPubKey: signedData.UserPubKey, + DataID: toSignData.DataID, + Alg: toSignData.Alg, + Data: toSignData.Data, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + UserPubKey: signedData.UserPubKey, })) } @@ -112,10 +109,9 @@ func TestUserSignFail(t *testing.T) { func TestVerifyUserSignatureFail(t *testing.T) { tokenHash := []byte("token hash") toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - SignedBlockCount: 2, + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), } userKey := generateUserKeyPair(t) @@ -160,14 +156,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { desc: "wrong pubkey", tokenHash: tokenHash, data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: invalidKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - SignedBlockCount: signedData.SignedBlockCount, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: invalidKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, }, }, { @@ -175,14 +170,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { expectedErr: ErrInvalidSignature, tokenHash: []byte("wrong"), data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - SignedBlockCount: signedData.SignedBlockCount, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, }, }, { @@ -190,14 +184,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { expectedErr: ErrInvalidSignature, tokenHash: tokenHash, data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: []byte("another nonce"), - Signature: signedData.Signature, - SignedBlockCount: signedData.SignedBlockCount, - Timestamp: signedData.Timestamp, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: []byte("another nonce"), + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, }, }, { @@ -205,29 +198,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { expectedErr: ErrInvalidSignature, tokenHash: tokenHash, data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), - SignedBlockCount: signedData.SignedBlockCount, - }, - }, - { - desc: "tampered signedBlockCount", - expectedErr: ErrInvalidSignature, - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - SignedBlockCount: signedData.SignedBlockCount + 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), }, }, } diff --git a/pkg/biscuit/wrapper.go b/pkg/biscuit/wrapper.go index ec16ebd..f39681f 100644 --- a/pkg/biscuit/wrapper.go +++ b/pkg/biscuit/wrapper.go @@ -5,8 +5,10 @@ import ( "crypto/rand" "errors" "fmt" + "time" "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/datalog" "github.com/flynn/hubauth/pkg/kmssign" ) @@ -31,6 +33,10 @@ type hubauthBuilder struct { func (b *hubauthBuilder) withUserToSignFact(userPubkey []byte) error { dataID := biscuit.Integer(0) + if err := validatePKIXP256PublicKey(userPubkey); err != nil { + return err + } + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ Name: "should_sign", IDs: []biscuit.Atom{ @@ -108,6 +114,38 @@ func (b *hubauthBuilder) withAudienceSignature(audience string, audienceKey *kms return nil } +func (b *hubauthBuilder) withMetadata(m *Metadata) error { + return b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "metadata", + IDs: []biscuit.Atom{ + biscuit.String(m.ClientID), + biscuit.String(m.UserID), + biscuit.String(m.UserEmail), + biscuit.Date(m.IssueTime), + }, + }}) +} + +func (b *hubauthBuilder) withExpire(exp time.Time) error { + if err := b.AddAuthorityCaveat(biscuit.Rule{ + Head: biscuit.Predicate{Name: "not_expired", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "current_time", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0)}}, + }, + Constraints: []biscuit.Constraint{{ + Name: biscuit.Variable(0), + Checker: biscuit.DateComparisonChecker{ + Comparison: datalog.DateComparisonBefore, + Date: biscuit.Date(exp), + }, + }}, + }); err != nil { + return err + } + + return nil +} + type hubauthBlockBuilder struct { biscuit.BlockBuilder } @@ -121,7 +159,6 @@ func (b *hubauthBlockBuilder) withUserSignature(sigData *userSignatureData) erro sigData.Signature, sigData.Nonce, sigData.Timestamp, - sigData.SignedBlockCount, }, }}) } @@ -130,7 +167,7 @@ type hubauthVerifier struct { biscuit.Verifier } -func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes, signedBlockCount int) (*userToSignData, error) { +func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes) (*userToSignData, error) { toSign, err := v.Query(biscuit.Rule{ Head: biscuit.Predicate{ Name: "to_sign", @@ -185,8 +222,6 @@ func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes, signedBloc return nil, ErrInvalidToSignDataPrefix } - sigData.SignedBlockCount = biscuit.Integer(signedBlockCount) - return sigData, nil } @@ -219,12 +254,11 @@ func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, erro biscuit.Variable(4), // signature biscuit.Variable(5), // signerNonce biscuit.Variable(6), // signerTimestamp - biscuit.Variable(7), // signedBlockCount }}, Body: []biscuit.Predicate{ {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, {Name: "data", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(3)}}, - {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6), biscuit.Variable(7)}}, + {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6)}}, }, }) if err != nil { @@ -236,7 +270,7 @@ func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, erro } toValidateFact := toValidate[0] - if g, w := len(toValidateFact.IDs), 8; g != w { + if g, w := len(toValidateFact.IDs), 7; g != w { return nil, fmt.Errorf("invalid to_valid fact atom count, got %d, want %d", g, w) } @@ -270,10 +304,6 @@ func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, erro if !ok { return nil, errors.New("invalid to_validate atom: timestamp") } - toVerify.SignedBlockCount, ok = toValidateFact.IDs[7].(biscuit.Integer) - if !ok { - return nil, errors.New("invalid to_validate atom: signedBlockCount") - } return toVerify, nil } @@ -326,6 +356,54 @@ func (v *hubauthVerifier) getAudienceVerificationData(audience string) (*audienc return toVerify, nil } +func (v *hubauthVerifier) getMetadata() (*Metadata, error) { + metaFacts, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "metadata", + IDs: []biscuit.Atom{ + biscuit.Variable(0), // clientID + biscuit.Variable(1), // userID + biscuit.Variable(2), // userEmail + biscuit.Variable(3), // issueTime + }}, + Body: []biscuit.Predicate{ + {Name: "metadata", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2), biscuit.Variable(3)}}, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(metaFacts), 1; g != w { + return nil, fmt.Errorf("invalid metadata fact count, got %d, want %d", g, w) + } + + metaFact := metaFacts[0] + + clientID, ok := metaFact.IDs[0].(biscuit.String) + if !ok { + return nil, errors.New("invalid metadata atom: clientID") + } + userID, ok := metaFact.IDs[1].(biscuit.String) + if !ok { + return nil, errors.New("invalid metadata atom: userID") + } + userEmail, ok := metaFact.IDs[2].(biscuit.String) + if !ok { + return nil, errors.New("invalid metadata atom: userEmail") + } + issueTime, ok := metaFact.IDs[3].(biscuit.Date) + if !ok { + return nil, errors.New("invalid metadata atom: issueTime") + } + return &Metadata{ + ClientID: string(clientID), + UserID: string(userID), + UserEmail: string(userEmail), + IssueTime: time.Time(issueTime), + }, nil +} + func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificationData) error { v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ Name: "valid_audience_signature", @@ -334,3 +412,15 @@ func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificat return nil } + +func (v *hubauthVerifier) withCurrentTime(t time.Time) error { + v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "current_time", + IDs: []biscuit.Atom{ + biscuit.Symbol("ambient"), + biscuit.Date(t), + }, + }}) + + return nil +} From fac0ee4dc623fb975e1c79badc009529f91d8e51 Mon Sep 17 00:00:00 2001 From: daeMOn Date: Tue, 13 Oct 2020 11:50:18 +0200 Subject: [PATCH 03/24] cleanup --- pkg/biscuit/biscuit.go | 2 +- pkg/biscuit/signature_test.go | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go index e335e60..0c7a3e4 100644 --- a/pkg/biscuit/biscuit.go +++ b/pkg/biscuit/biscuit.go @@ -26,7 +26,7 @@ type UserKeyPair struct { // signed with the private key matching the given userPubkey. func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPublicKey []byte, expireTime time.Time, m *Metadata) ([]byte, error) { builder := &hubauthBuilder{ - Builder: biscuit.NewBuilder(rand.Reader, rootKey), + biscuit.NewBuilder(rand.Reader, rootKey), } if err := builder.withAudienceSignature(audience, audienceKey); err != nil { diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go index d4d4cc2..2351e26 100644 --- a/pkg/biscuit/signature_test.go +++ b/pkg/biscuit/signature_test.go @@ -211,7 +211,6 @@ func TestVerifyUserSignatureFail(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.desc, func(t *testing.T) { - err := verifyUserSignature(testCase.tokenHash, testCase.data) require.Error(t, err) if testCase.expectedErr != nil { @@ -224,12 +223,8 @@ func TestVerifyUserSignatureFail(t *testing.T) { func generateUserKeyPair(t *testing.T) *UserKeyPair { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - privBytes, err := x509.MarshalECPrivateKey(priv) - require.NoError(t, err) - pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + + kp, err := NewECDSAKeyPair(priv) require.NoError(t, err) - return &UserKeyPair{ - Private: privBytes, - Public: pubBytes, - } + return kp } From b4faed849735af5c2871ab13fd6885f255514e28 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 15 Oct 2020 10:38:31 +0200 Subject: [PATCH 04/24] moved biscuit pkg to biscuit-go repo --- pkg/biscuit/biscuit.go | 163 ------------- pkg/biscuit/biscuit_test.go | 67 ------ pkg/biscuit/signature.go | 176 -------------- pkg/biscuit/signature_test.go | 230 ------------------ pkg/biscuit/wrapper.go | 426 ---------------------------------- 5 files changed, 1062 deletions(-) delete mode 100644 pkg/biscuit/biscuit.go delete mode 100644 pkg/biscuit/biscuit_test.go delete mode 100644 pkg/biscuit/signature.go delete mode 100644 pkg/biscuit/signature_test.go delete mode 100644 pkg/biscuit/wrapper.go diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go deleted file mode 100644 index 0c7a3e4..0000000 --- a/pkg/biscuit/biscuit.go +++ /dev/null @@ -1,163 +0,0 @@ -package biscuit - -import ( - "crypto/rand" - "fmt" - "time" - - "github.com/flynn/biscuit-go" - "github.com/flynn/biscuit-go/sig" - "github.com/flynn/hubauth/pkg/kmssign" -) - -type Metadata struct { - ClientID string - UserID string - UserEmail string - IssueTime time.Time -} - -type UserKeyPair struct { - Public []byte - Private []byte -} - -// GenerateSignable returns a biscuit which will only verify after being -// signed with the private key matching the given userPubkey. -func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPublicKey []byte, expireTime time.Time, m *Metadata) ([]byte, error) { - builder := &hubauthBuilder{ - biscuit.NewBuilder(rand.Reader, rootKey), - } - - if err := builder.withAudienceSignature(audience, audienceKey); err != nil { - return nil, err - } - - if err := builder.withUserToSignFact(userPublicKey); err != nil { - return nil, err - } - - if err := builder.withExpire(expireTime); err != nil { - return nil, err - } - - if err := builder.withMetadata(m); err != nil { - return nil, err - } - - b, err := builder.Build() - if err != nil { - return nil, err - } - - return b.Serialize() -} - -// Sign append a user signature on the given token and return it. -// The UserKeyPair key format to provide depends on the signature algorithm: -// - for ECDSA_P256_SHA256, the private key must be encoded in SEC 1, ASN.1 DER form, -// and the public key in PKIX, ASN.1 DER form. -func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, error) { - b, err := biscuit.Unmarshal(token) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) - } - - v, err := b.Verify(rootPubKey) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to verify: %w", err) - } - verifier := &hubauthVerifier{ - Verifier: v, - } - - toSignData, err := verifier.getUserToSignData(userKey.Public) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to get to_sign data: %w", err) - } - - if err := verifier.ensureNotAlreadyUserSigned(toSignData.DataID, userKey.Public); err != nil { - return nil, fmt.Errorf("biscuit: previous signature check failed: %w", err) - } - - tokenHash, err := b.SHA256Sum(b.BlockCount()) - if err != nil { - return nil, err - } - - signData, err := userSign(tokenHash, userKey, toSignData) - if err != nil { - return nil, fmt.Errorf("biscuit: signature failed: %w", err) - } - - builder := &hubauthBlockBuilder{ - BlockBuilder: b.CreateBlock(), - } - if err := builder.withUserSignature(signData); err != nil { - return nil, fmt.Errorf("biscuit: failed to create signature block: %w", err) - } - - clientKey := sig.GenerateKeypair(rand.Reader) - b, err = b.Append(rand.Reader, clientKey, builder.Build()) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to append signature block: %w", err) - } - - return b.Serialize() -} - -// Verify will verify the biscuit, the included audience and user signature, and return an error -// when anything is invalid. -func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) (*Metadata, error) { - b, err := biscuit.Unmarshal(token) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) - } - - v, err := b.Verify(rootPubKey) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to verify: %w", err) - } - verifier := &hubauthVerifier{v} - - audienceVerificationData, err := verifier.getAudienceVerificationData(audience) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) - } - - if err := verifyAudienceSignature(audienceKey, audienceVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to verify audience signature: %w", err) - } - if err := verifier.withValidatedAudienceSignature(audienceVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) - } - - userVerificationData, err := verifier.getUserVerificationData() - if err != nil { - return nil, fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) - } - - // TODO: improve biscuit API to allow retrieve the block index the signature is at - // so that we can still append other blocks if needed. Right now the signature MUST BE the last block. - signedTokenHash, err := b.SHA256Sum(b.BlockCount() - 1) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to generate token hash: %w", err) - } - - if err := verifyUserSignature(signedTokenHash, userVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to verify user signature: %w", err) - } - if err := verifier.withValidatedUserSignature(userVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) - } - - if err := verifier.withCurrentTime(time.Now()); err != nil { - return nil, fmt.Errorf("biscuit: failed to add current time: %w", err) - } - - if err := verifier.Verify(); err != nil { - return nil, fmt.Errorf("biscuit: failed to verify: %w", err) - } - - return verifier.getMetadata() -} diff --git a/pkg/biscuit/biscuit_test.go b/pkg/biscuit/biscuit_test.go deleted file mode 100644 index bf6199b..0000000 --- a/pkg/biscuit/biscuit_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package biscuit - -import ( - "context" - "crypto/rand" - "testing" - "time" - - "github.com/flynn/biscuit-go/sig" - "github.com/flynn/hubauth/pkg/kmssign" - "github.com/flynn/hubauth/pkg/kmssign/kmssim" - "github.com/stretchr/testify/require" -) - -func TestBiscuit(t *testing.T) { - rootKey := sig.GenerateKeypair(rand.Reader) - audience := "http://random.audience.url" - - kms := kmssim.NewClient([]string{audience}) - audienceKey, err := kmssign.NewKey(context.Background(), kms, audience) - require.NoError(t, err) - - userKey := generateUserKeyPair(t) - metas := &Metadata{ - ClientID: "abcd", - UserEmail: "1234@example.com", - UserID: "1234", - IssueTime: time.Now(), - } - signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public, time.Now().Add(5*time.Minute), metas) - require.NoError(t, err) - t.Logf("signable biscuit size: %d", len(signableBiscuit)) - - t.Run("happy path", func(t *testing.T) { - signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) - require.NoError(t, err) - t.Logf("signed biscuit size: %d", len(signedBiscuit)) - - res, err := Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) - require.NoError(t, err) - require.Equal(t, metas.ClientID, res.ClientID) - require.Equal(t, metas.UserID, res.UserID) - require.Equal(t, metas.UserEmail, res.UserEmail) - require.WithinDuration(t, metas.IssueTime, res.IssueTime, 1*time.Second) - }) - - t.Run("user sign with wrong key", func(t *testing.T) { - _, err := Sign(signableBiscuit, rootKey.Public(), generateUserKeyPair(t)) - require.Error(t, err) - }) - - t.Run("verify wrong audience", func(t *testing.T) { - signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) - require.NoError(t, err) - - _, err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) - require.Error(t, err) - - wrongAudience := "http://another.audience.url" - kms := kmssim.NewClient([]string{wrongAudience}) - wrongAudienceKey, err := kmssign.NewKey(context.Background(), kms, wrongAudience) - require.NoError(t, err) - - _, err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) - require.Error(t, err) - }) -} diff --git a/pkg/biscuit/signature.go b/pkg/biscuit/signature.go deleted file mode 100644 index 078eaa6..0000000 --- a/pkg/biscuit/signature.go +++ /dev/null @@ -1,176 +0,0 @@ -package biscuit - -import ( - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "crypto/x509" - "errors" - "fmt" - "time" - - "github.com/flynn/biscuit-go" - "github.com/flynn/hubauth/pkg/kmssign" -) - -var ( - ErrUnsupportedSignatureAlg = errors.New("unsupported signature algorithm") - ErrInvalidSignature = errors.New("invalid signature") -) - -type SignatureAlg biscuit.Symbol - -const ( - ECDSA_P256_SHA256 SignatureAlg = "ECDSA_P256_SHA256" -) - -type userToSignData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes -} - -type userSignatureData struct { - DataID biscuit.Integer - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - Nonce biscuit.Bytes - Timestamp biscuit.Date -} - -type userVerificationData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - Nonce biscuit.Bytes - Timestamp biscuit.Date -} - -func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData) (*userSignatureData, error) { - if len(tokenHash) == 0 { - return nil, errors.New("invalid tokenHash") - } - - signerTimestamp := time.Now() - signerNonce := make([]byte, nonceSize) - if _, err := rand.Read(signerNonce); err != nil { - return nil, err - } - - var dataToSign []byte - dataToSign = append(dataToSign, toSignData.Data...) - dataToSign = append(dataToSign, tokenHash...) - dataToSign = append(dataToSign, signerNonce...) - dataToSign = append(dataToSign, []byte(signerTimestamp.Format(time.RFC3339))...) - - var signedData biscuit.Bytes - switch SignatureAlg(toSignData.Alg) { - case ECDSA_P256_SHA256: - privKey, err := x509.ParseECPrivateKey(userKey.Private) - if err != nil { - return nil, err - } - hash := sha256.Sum256(dataToSign) - signedData, err = ecdsa.SignASN1(rand.Reader, privKey, hash[:]) - if err != nil { - return nil, err - } - default: - return nil, ErrUnsupportedSignatureAlg - } - - return &userSignatureData{ - DataID: toSignData.DataID, - Nonce: signerNonce, - Signature: signedData, - Timestamp: biscuit.Date(signerTimestamp), - UserPubKey: userKey.Public, - }, nil -} - -func verifyUserSignature(signedTokenHash []byte, data *userVerificationData) error { - var signedData []byte - signedData = append(signedData, data.Data...) - signedData = append(signedData, signedTokenHash...) - signedData = append(signedData, data.Nonce...) - signedData = append(signedData, []byte(time.Time(data.Timestamp).Format(time.RFC3339))...) - - switch SignatureAlg(data.Alg) { - case ECDSA_P256_SHA256: - pk, err := x509.ParsePKIXPublicKey(data.UserPubKey) - if err != nil { - return err - } - pubkey, ok := pk.(*ecdsa.PublicKey) - if !ok { - return errors.New("invalid pubkey, not an *ecdsa.PublicKey") - } - - hash := sha256.Sum256(signedData) - if !ecdsa.VerifyASN1(pubkey, hash[:], data.Signature) { - return ErrInvalidSignature - } - return nil - default: - return ErrUnsupportedSignatureAlg - } -} - -type audienceVerificationData struct { - Audience biscuit.Symbol - Challenge biscuit.Bytes - Signature biscuit.Bytes -} - -func audienceSign(audience string, audienceKey *kmssign.Key) (*audienceVerificationData, error) { - challenge := make([]byte, challengeSize) - if _, err := rand.Reader.Read(challenge); err != nil { - return nil, err - } - - signedData := append(signStaticCtx, challenge...) - signedData = append(signedData, []byte(audience)...) - signedHash := sha256.Sum256(signedData) - signature, err := audienceKey.Sign(rand.Reader, signedHash[:], crypto.SHA256) - if err != nil { - return nil, err - } - - return &audienceVerificationData{ - Audience: biscuit.Symbol(audience), - Challenge: challenge, - Signature: signature, - }, nil -} - -func verifyAudienceSignature(audiencePubkey *kmssign.Key, data *audienceVerificationData) error { - signedData := append(signStaticCtx, data.Challenge...) - signedData = append(signedData, []byte(data.Audience)...) - hash := sha256.Sum256(signedData) - if !audiencePubkey.Verify(hash[:], data.Signature) { - return errors.New("invalid signature") - } - return nil -} - -func validatePKIXP256PublicKey(pubkey []byte) error { - key, err := x509.ParsePKIXPublicKey(pubkey) - if err != nil { - return fmt.Errorf("failed to parse PKIX, ASN.1 DER public key: %v", err) - } - - ecKey, ok := key.(*ecdsa.PublicKey) - if !ok { - return errors.New("public key is not an *ecdsa.PublicKey") - } - - if ecKey.Curve != elliptic.P256() { - return fmt.Errorf("publickey is on wrong curve, expected P256") - } - - return nil -} diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go deleted file mode 100644 index 2351e26..0000000 --- a/pkg/biscuit/signature_test.go +++ /dev/null @@ -1,230 +0,0 @@ -package biscuit - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "testing" - "time" - - "github.com/flynn/biscuit-go" - "github.com/stretchr/testify/require" -) - -func TestUserSignVerify(t *testing.T) { - tokenHash := make([]byte, 32) - _, err := rand.Read(tokenHash) - require.NoError(t, err) - - challenge := make([]byte, challengeSize) - _, err = rand.Read(challenge) - require.NoError(t, err) - - userKey := generateUserKeyPair(t) - - toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - } - - signedData, err := userSign(tokenHash, userKey, toSignData) - require.NoError(t, err) - require.NotEmpty(t, signedData.Signature) - require.Equal(t, biscuit.Integer(1), signedData.DataID) - require.Equal(t, biscuit.Bytes(userKey.Public), signedData.UserPubKey) - - require.Len(t, signedData.Nonce, nonceSize) - zeroNonce := make([]byte, nonceSize) - require.NotEqual(t, biscuit.Bytes(zeroNonce), signedData.Nonce) - - require.WithinDuration(t, time.Now(), time.Time(signedData.Timestamp), 1*time.Second) - - require.NoError(t, verifyUserSignature(tokenHash, &userVerificationData{ - DataID: toSignData.DataID, - Alg: toSignData.Alg, - Data: toSignData.Data, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - UserPubKey: signedData.UserPubKey, - })) -} - -func TestUserSignFail(t *testing.T) { - validTokenHash := make([]byte, 32) - _, err := rand.Read(validTokenHash) - require.NoError(t, err) - - validChallenge := make([]byte, challengeSize) - _, err = rand.Read(validChallenge) - require.NoError(t, err) - - invalidPrivateKey := &UserKeyPair{ - Private: make([]byte, 32), - } - - testCases := []struct { - desc string - tokenHash []byte - userKey *UserKeyPair - data *userToSignData - expectedErr error - }{ - { - desc: "empty tokenHash", - tokenHash: []byte{}, - }, - { - desc: "unsupported alg", - tokenHash: validTokenHash, - data: &userToSignData{ - Alg: "unsupported", - }, - expectedErr: ErrUnsupportedSignatureAlg, - }, - { - desc: "wrong private key encoding", - tokenHash: validTokenHash, - data: &userToSignData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - }, - userKey: invalidPrivateKey, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.desc, func(t *testing.T) { - _, err := userSign(testCase.tokenHash, testCase.userKey, testCase.data) - require.Error(t, err) - if testCase.expectedErr != nil { - require.Equal(t, testCase.expectedErr, err) - } - }) - } -} - -func TestVerifyUserSignatureFail(t *testing.T) { - tokenHash := []byte("token hash") - toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - } - - userKey := generateUserKeyPair(t) - invalidKey := generateUserKeyPair(t) - - signedData, err := userSign(tokenHash, userKey, toSignData) - require.NoError(t, err) - - rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) - require.NoError(t, err) - wrongKeyKind, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey) - require.NoError(t, err) - - testCases := []struct { - desc string - tokenHash []byte - data *userVerificationData - expectedErr error - }{ - { - desc: "unsupported alg", - expectedErr: ErrUnsupportedSignatureAlg, - data: &userVerificationData{ - Alg: "unknown", - }, - }, - { - desc: "invalid pubkey encoding", - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: make([]byte, 32), - }, - }, - { - desc: "invalid pubkey kind", - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: wrongKeyKind, - }, - }, - { - desc: "wrong pubkey", - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: invalidKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - }, - }, - { - desc: "tampered token hash", - expectedErr: ErrInvalidSignature, - tokenHash: []byte("wrong"), - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - }, - }, - { - desc: "tampered nonce", - expectedErr: ErrInvalidSignature, - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: []byte("another nonce"), - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - }, - }, - { - desc: "tampered timestamp", - expectedErr: ErrInvalidSignature, - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), - }, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.desc, func(t *testing.T) { - err := verifyUserSignature(testCase.tokenHash, testCase.data) - require.Error(t, err) - if testCase.expectedErr != nil { - require.Equal(t, testCase.expectedErr, err) - } - }) - } -} - -func generateUserKeyPair(t *testing.T) *UserKeyPair { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - kp, err := NewECDSAKeyPair(priv) - require.NoError(t, err) - return kp -} diff --git a/pkg/biscuit/wrapper.go b/pkg/biscuit/wrapper.go deleted file mode 100644 index f39681f..0000000 --- a/pkg/biscuit/wrapper.go +++ /dev/null @@ -1,426 +0,0 @@ -package biscuit - -import ( - "bytes" - "crypto/rand" - "errors" - "fmt" - "time" - - "github.com/flynn/biscuit-go" - "github.com/flynn/biscuit-go/datalog" - "github.com/flynn/hubauth/pkg/kmssign" -) - -var ( - ErrAlreadySigned = errors.New("already signed") - ErrInvalidToSignDataPrefix = errors.New("invalid to_sign data prefix") -) - -var ( - signStaticCtx = []byte("biscuit-pop-v0") - challengeSize = 16 - nonceSize = 16 -) - -type hubauthBuilder struct { - biscuit.Builder -} - -// withUserToSignFact add an authority should_sign fact and associated data to the biscuit -// with an authority caveat requiring the verifier to provide a valid_signature fact. -// the verifier is responsible of ensuring that a valid signature exists over the data. -func (b *hubauthBuilder) withUserToSignFact(userPubkey []byte) error { - dataID := biscuit.Integer(0) - - if err := validatePKIXP256PublicKey(userPubkey); err != nil { - return err - } - - if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "should_sign", - IDs: []biscuit.Atom{ - dataID, - biscuit.Symbol(ECDSA_P256_SHA256), - biscuit.Bytes(userPubkey), - }, - }}); err != nil { - return err - } - - challenge := make([]byte, challengeSize) - if _, err := rand.Reader.Read(challenge); err != nil { - return err - } - - if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "data", - IDs: []biscuit.Atom{ - dataID, - biscuit.Bytes(append(signStaticCtx, challenge...)), - }, - }}); err != nil { - return err - } - - if err := b.AddAuthorityCaveat(biscuit.Rule{ - Head: biscuit.Predicate{Name: "valid", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - {Name: "valid_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - }, - }); err != nil { - return err - } - - return nil -} - -// withAudienceSignature add an authority audience_signature fact, containing a challenge and -// a matching signature using the audience key. -// the verifier is responsible of providing a valid_audience_signature fact, after -// verifying the signature using the audience pubkey. -func (b *hubauthBuilder) withAudienceSignature(audience string, audienceKey *kmssign.Key) error { - if len(audience) == 0 { - return errors.New("audience is required") - } - - data, err := audienceSign(audience, audienceKey) - if err != nil { - return err - } - - if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "audience_signature", - IDs: []biscuit.Atom{ - data.Audience, - data.Challenge, - data.Signature, - }, - }}); err != nil { - return err - } - - if err := b.AddAuthorityCaveat(biscuit.Rule{ - Head: biscuit.Predicate{Name: "valid_audience", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - {Name: "valid_audience_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(2)}}, - }, - }); err != nil { - return err - } - - return nil -} - -func (b *hubauthBuilder) withMetadata(m *Metadata) error { - return b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "metadata", - IDs: []biscuit.Atom{ - biscuit.String(m.ClientID), - biscuit.String(m.UserID), - biscuit.String(m.UserEmail), - biscuit.Date(m.IssueTime), - }, - }}) -} - -func (b *hubauthBuilder) withExpire(exp time.Time) error { - if err := b.AddAuthorityCaveat(biscuit.Rule{ - Head: biscuit.Predicate{Name: "not_expired", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "current_time", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0)}}, - }, - Constraints: []biscuit.Constraint{{ - Name: biscuit.Variable(0), - Checker: biscuit.DateComparisonChecker{ - Comparison: datalog.DateComparisonBefore, - Date: biscuit.Date(exp), - }, - }}, - }); err != nil { - return err - } - - return nil -} - -type hubauthBlockBuilder struct { - biscuit.BlockBuilder -} - -func (b *hubauthBlockBuilder) withUserSignature(sigData *userSignatureData) error { - return b.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "signature", - IDs: []biscuit.Atom{ - sigData.DataID, - sigData.UserPubKey, - sigData.Signature, - sigData.Nonce, - sigData.Timestamp, - }, - }}) -} - -type hubauthVerifier struct { - biscuit.Verifier -} - -func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes) (*userToSignData, error) { - toSign, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "to_sign", - IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}, - }, - Body: []biscuit.Predicate{ - { - Name: "should_sign", IDs: []biscuit.Atom{ - biscuit.SymbolAuthority, - biscuit.Variable(0), - biscuit.Variable(1), - biscuit.Bytes(userPubKey), - }, - }, { - Name: "data", IDs: []biscuit.Atom{ - biscuit.SymbolAuthority, - biscuit.Variable(0), - biscuit.Variable(2), - }, - }, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(toSign), 1; g != w { - return nil, fmt.Errorf("invalid to_sign fact count, got %d, want %d", g, w) - } - - toSignFact := toSign[0] - if g, w := len(toSignFact.IDs), 3; g != w { - return nil, fmt.Errorf("invalid to_sign fact, got %d atoms, want %d", g, w) - } - - sigData := &userToSignData{} - var ok bool - sigData.DataID, ok = toSign[0].IDs[0].(biscuit.Integer) - if !ok { - return nil, errors.New("invalid to_sign atom: dataID") - } - sigData.Alg, ok = toSign[0].IDs[1].(biscuit.Symbol) - if !ok { - return nil, errors.New("invalid to_sign atom: alg") - } - sigData.Data, ok = toSign[0].IDs[2].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_sign atom: data") - } - - if !bytes.HasPrefix(sigData.Data, signStaticCtx) { - return nil, ErrInvalidToSignDataPrefix - } - - return sigData, nil -} - -func (v *hubauthVerifier) ensureNotAlreadyUserSigned(dataID biscuit.Integer, userPubKey biscuit.Bytes) error { - alreadySigned, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{Name: "already_signed", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "signature", IDs: []biscuit.Atom{dataID, userPubKey, biscuit.Variable(0)}}, - }, - }) - if err != nil { - return err - } - if len(alreadySigned) != 0 { - return ErrAlreadySigned - } - - return nil -} - -func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, error) { - toValidate, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "to_validate", - IDs: []biscuit.Atom{ - biscuit.Variable(0), // dataID - biscuit.Variable(1), // alg - biscuit.Variable(2), // pubkey - biscuit.Variable(3), // data - biscuit.Variable(4), // signature - biscuit.Variable(5), // signerNonce - biscuit.Variable(6), // signerTimestamp - }}, - Body: []biscuit.Predicate{ - {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - {Name: "data", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(3)}}, - {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6)}}, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(toValidate), 1; g != w { - return nil, fmt.Errorf("invalid to_validate fact count, got %d, want %d", g, w) - } - - toValidateFact := toValidate[0] - if g, w := len(toValidateFact.IDs), 7; g != w { - return nil, fmt.Errorf("invalid to_valid fact atom count, got %d, want %d", g, w) - } - - toVerify := &userVerificationData{} - var ok bool - toVerify.DataID, ok = toValidateFact.IDs[0].(biscuit.Integer) - if !ok { - return nil, errors.New("invalid to_validate atom: dataID") - } - toVerify.Alg, ok = toValidateFact.IDs[1].(biscuit.Symbol) - if !ok { - return nil, errors.New("invalid to_validate atom: alg") - } - toVerify.UserPubKey, ok = toValidateFact.IDs[2].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: userPubKey") - } - toVerify.Data, ok = toValidateFact.IDs[3].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: data") - } - toVerify.Signature, ok = toValidateFact.IDs[4].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: signature") - } - toVerify.Nonce, ok = toValidateFact.IDs[5].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: nonce") - } - toVerify.Timestamp, ok = toValidateFact.IDs[6].(biscuit.Date) - if !ok { - return nil, errors.New("invalid to_validate atom: timestamp") - } - - return toVerify, nil -} - -func (v *hubauthVerifier) withValidatedUserSignature(data *userVerificationData) error { - v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "valid_signature", - IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.DataID, data.Alg, data.UserPubKey}, - }}) - - return nil -} - -func (v *hubauthVerifier) getAudienceVerificationData(audience string) (*audienceVerificationData, error) { - toValidate, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "audience_to_validate", - IDs: []biscuit.Atom{ - biscuit.Variable(0), // challenge - biscuit.Variable(1), // signature - }}, - Body: []biscuit.Predicate{ - {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Symbol(audience), biscuit.Variable(0), biscuit.Variable(1)}}, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(toValidate), 1; g != w { - return nil, fmt.Errorf("invalid audience_to_validate fact count, got %d, want %d", g, w) - } - - toValidateFact := toValidate[0] - if g, w := len(toValidateFact.IDs), 2; g != w { - return nil, fmt.Errorf("invalid audience_to_validate fact atom count, got %d, want %d", g, w) - } - - toVerify := &audienceVerificationData{Audience: biscuit.Symbol(audience)} - var ok bool - toVerify.Challenge, ok = toValidateFact.IDs[0].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid audience_to_validate atom: challenge") - } - toVerify.Signature, ok = toValidateFact.IDs[1].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid audience_to_validate atom: signature") - } - - return toVerify, nil -} - -func (v *hubauthVerifier) getMetadata() (*Metadata, error) { - metaFacts, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "metadata", - IDs: []biscuit.Atom{ - biscuit.Variable(0), // clientID - biscuit.Variable(1), // userID - biscuit.Variable(2), // userEmail - biscuit.Variable(3), // issueTime - }}, - Body: []biscuit.Predicate{ - {Name: "metadata", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2), biscuit.Variable(3)}}, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(metaFacts), 1; g != w { - return nil, fmt.Errorf("invalid metadata fact count, got %d, want %d", g, w) - } - - metaFact := metaFacts[0] - - clientID, ok := metaFact.IDs[0].(biscuit.String) - if !ok { - return nil, errors.New("invalid metadata atom: clientID") - } - userID, ok := metaFact.IDs[1].(biscuit.String) - if !ok { - return nil, errors.New("invalid metadata atom: userID") - } - userEmail, ok := metaFact.IDs[2].(biscuit.String) - if !ok { - return nil, errors.New("invalid metadata atom: userEmail") - } - issueTime, ok := metaFact.IDs[3].(biscuit.Date) - if !ok { - return nil, errors.New("invalid metadata atom: issueTime") - } - return &Metadata{ - ClientID: string(clientID), - UserID: string(userID), - UserEmail: string(userEmail), - IssueTime: time.Time(issueTime), - }, nil -} - -func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificationData) error { - v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "valid_audience_signature", - IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.Audience, data.Signature}, - }}) - - return nil -} - -func (v *hubauthVerifier) withCurrentTime(t time.Time) error { - v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "current_time", - IDs: []biscuit.Atom{ - biscuit.Symbol("ambient"), - biscuit.Date(t), - }, - }}) - - return nil -} From 3bae14ba441c4caccbae72e272a9a24757dd3a11 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Mon, 14 Dec 2020 10:42:27 +0100 Subject: [PATCH 05/24] audience policies migration step1 Migration for audience.Policies field to audience.UserGroups. The application still rely on audience.Policies field. A new CLI command allows to update all audiences, copying Policies field into UserGroups. The next step will remove this command, and update the application to rely on the new UserGroups field. --- pkg/cli/audiences.go | 48 +++++++++++---------------------------- pkg/datastore/audience.go | 6 +++++ pkg/hubauth/data.go | 1 + 3 files changed, 20 insertions(+), 35 deletions(-) diff --git a/pkg/cli/audiences.go b/pkg/cli/audiences.go index de43b0a..bbdda19 100644 --- a/pkg/cli/audiences.go +++ b/pkg/cli/audiences.go @@ -206,43 +206,8 @@ func (c *audiencesSetUserGroupsCmd) Run(cfg *Config) error { } return cfg.DB.MutateAudience(context.Background(), c.AudienceURL, []*hubauth.AudienceMutation{mut}) } - -type audiencesUpdateUserGroupsCmd struct { - AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` - Domain string `kong:"required,help='G Suite domain name'"` - APIUser string `kong:"name='api-user',help='G Suite user email to impersonate for API calls'"` - AddGroups []string `kong:"name='add-groups',help='comma-separated group IDs to add'"` - DeleteGroups []string `kong:"name='delete-groups',help='comma-separated group IDs to delete'"` -} - -func (c *audiencesUpdateUserGroupsCmd) Run(cfg *Config) error { - var muts []*hubauth.AudienceUserGroupsMutation - for _, groupID := range c.AddGroups { - muts = append(muts, &hubauth.AudienceUserGroupsMutation{ - Op: hubauth.AudienceUserGroupsMutationOpAddGroup, - Group: groupID, - }) - } - for _, groupID := range c.DeleteGroups { - muts = append(muts, &hubauth.AudienceUserGroupsMutation{ - Op: hubauth.AudienceUserGroupsMutationOpDeleteGroup, - Group: groupID, }) - } - if c.APIUser != "" { - muts = append(muts, &hubauth.AudienceUserGroupsMutation{ - Op: hubauth.AudienceUserGroupsMutationOpSetAPIUser, - APIUser: c.APIUser, - }) - } - - return cfg.DB.MutateAudienceUserGroups(context.Background(), c.AudienceURL, c.Domain, muts) -} -type audiencesDeleteUserGroupsCmd struct { - AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` - Domain string `kong:"required,help='G Suite domain name'"` -} func (c *audiencesDeleteUserGroupsCmd) Run(cfg *Config) error { mut := &hubauth.AudienceMutation{ @@ -285,3 +250,16 @@ func (c *audiencesKeyCmd) Run(cfg *Config) error { fmt.Println(base64.URLEncoding.EncodeToString(b.Bytes)) return nil } + +type audienceMigratePoliciesCmd struct { +} + +func (c *audienceMigratePoliciesCmd) Run(cfg *Config) error { + policies, err := cfg.DB.ListAudiences(ctx) + for _, p := range policies { + return fmt.Errorf("Failed to migrate policies to userGroups for audience %q", p.URL) + } + fmt.Printf("Success migrating policies to userGroups for %q\n", p.URL) + } + return nil +} diff --git a/pkg/datastore/audience.go b/pkg/datastore/audience.go index 4fad7c8..d05ab72 100644 --- a/pkg/datastore/audience.go +++ b/pkg/datastore/audience.go @@ -172,6 +172,12 @@ func (s *service) MutateAudience(ctx context.Context, url string, mut []*hubauth } aud.Type = m.Type modified = true + case hubauth.AudienceMutationMigratePolicy: + aud.UserGroups = make([]googleUserGroups, len(aud.Policies)) + for i, ug := range aud.Policies { + aud.UserGroups[i] = ug + } + modified = true default: return fmt.Errorf("datastore: unknown audience mutation op %s", m.Op) } diff --git a/pkg/hubauth/data.go b/pkg/hubauth/data.go index dd5752f..628cba9 100644 --- a/pkg/hubauth/data.go +++ b/pkg/hubauth/data.go @@ -91,6 +91,7 @@ const ( AudienceMutationOpSetUserGroups AudienceMutationOpDeleteUserGroups AudienceMutationSetType + AudienceMutationMigratePolicy ) type AudienceMutation struct { From 09d999f2a6f6d954bea41257420844413a2be4b9 Mon Sep 17 00:00:00 2001 From: daeMOn Date: Thu, 8 Oct 2020 16:53:04 +0200 Subject: [PATCH 06/24] add biscuit wrappers and helpers to generate, sign and verify hubauth biscuits --- pkg/biscuit/biscuit.go | 137 ++++++++++++++ pkg/biscuit/biscuit_test.go | 57 ++++++ pkg/biscuit/signature.go | 162 ++++++++++++++++ pkg/biscuit/signature_test.go | 258 ++++++++++++++++++++++++++ pkg/biscuit/wrapper.go | 336 ++++++++++++++++++++++++++++++++++ 5 files changed, 950 insertions(+) create mode 100644 pkg/biscuit/biscuit.go create mode 100644 pkg/biscuit/biscuit_test.go create mode 100644 pkg/biscuit/signature.go create mode 100644 pkg/biscuit/signature_test.go create mode 100644 pkg/biscuit/wrapper.go diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go new file mode 100644 index 0000000..ef1a452 --- /dev/null +++ b/pkg/biscuit/biscuit.go @@ -0,0 +1,137 @@ +package biscuit + +import ( + "crypto/rand" + "fmt" + + "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/kmssign" +) + +type UserKeyPair struct { + Public []byte + Private []byte +} + +// GenerateSignable returns a biscuit which will only verify after being +// signed with the private key matching the given userPubkey. +func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPubkey []byte) ([]byte, error) { + builder := &hubauthBuilder{ + Builder: biscuit.NewBuilder(rand.Reader, rootKey), + } + + if err := builder.withAudienceSignature(audience, audienceKey); err != nil { + return nil, err + } + + if err := builder.withUserToSignFact(userPubkey); err != nil { + return nil, err + } + + b, err := builder.Build() + if err != nil { + return nil, err + } + + return b.Serialize() +} + +// Sign append a user signature on the given token and return it. +// The UserKeyPair key format to provide depends on the signature algorithm: +// - for ECDSA_P256_SHA256, the private key must be encoded in SEC 1, ASN.1 DER form, +// and the public key in PKIX, ASN.1 DER form. +func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, error) { + b, err := biscuit.Unmarshal(token) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) + } + + v, err := b.Verify(rootPubKey) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to verify: %w", err) + } + verifier := &hubauthVerifier{ + Verifier: v, + } + + toSignData, err := verifier.getUserToSignData(userKey.Public, b.BlockCount()) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to get to_sign data: %w", err) + } + + if err := verifier.ensureNotAlreadyUserSigned(toSignData.DataID, userKey.Public); err != nil { + return nil, fmt.Errorf("biscuit: previous signature check failed: %w", err) + } + + tokenHash, err := b.SHA256Sum(b.BlockCount()) + if err != nil { + return nil, err + } + + signData, err := userSign(tokenHash, userKey, toSignData) + if err != nil { + return nil, fmt.Errorf("biscuit: signature failed: %w", err) + } + + builder := &hubauthBlockBuilder{ + BlockBuilder: b.CreateBlock(), + } + if err := builder.withUserSignature(signData); err != nil { + return nil, fmt.Errorf("biscuit: failed to create signature block: %w", err) + } + + clientKey := sig.GenerateKeypair(rand.Reader) + b, err = b.Append(rand.Reader, clientKey, builder.Build()) + if err != nil { + return nil, fmt.Errorf("biscuit: failed to append signature block: %w", err) + } + + return b.Serialize() +} + +// Verify will verify the biscuit, the included audience and user signature, and return an error +// when anything is invalid. +func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) error { + b, err := biscuit.Unmarshal(token) + if err != nil { + return fmt.Errorf("biscuit: failed to unmarshal: %w", err) + } + + v, err := b.Verify(rootPubKey) + if err != nil { + return fmt.Errorf("biscuit: failed to verify: %w", err) + } + verifier := &hubauthVerifier{v} + + audienceVerificationData, err := verifier.getAudienceVerificationData(audience) + if err != nil { + return fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) + } + + if err := verifyAudienceSignature(audienceKey, audienceVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to verify audience signature: %w", err) + } + if err := verifier.withValidatedAudienceSignature(audienceVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + } + + userVerificationData, err := verifier.getUserVerificationData() + if err != nil { + return fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) + } + + signedTokenHash, err := b.SHA256Sum(int(userVerificationData.SignedBlockCount)) + if err != nil { + return fmt.Errorf("biscuit: failed to generate token hash: %w", err) + } + + if err := verifyUserSignature(signedTokenHash, userVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to verify user signature: %w", err) + } + if err := verifier.withValidatedUserSignature(userVerificationData); err != nil { + return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + } + + return verifier.Verify() +} diff --git a/pkg/biscuit/biscuit_test.go b/pkg/biscuit/biscuit_test.go new file mode 100644 index 0000000..ed7b273 --- /dev/null +++ b/pkg/biscuit/biscuit_test.go @@ -0,0 +1,57 @@ +package biscuit + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/kmssign/kmssim" + "github.com/stretchr/testify/require" +) + +func TestBiscuit(t *testing.T) { + rootKey := sig.GenerateKeypair(rand.Reader) + audience := "http://random.audience.url" + + kms := kmssim.NewClient([]string{audience}) + audienceKey, err := kmssign.NewKey(context.Background(), kms, audience) + require.NoError(t, err) + + userKey := generateUserKeyPair(t) + + signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public) + require.NoError(t, err) + t.Logf("signable biscuit size: %d", len(signableBiscuit)) + + t.Run("happy path", func(t *testing.T) { + signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) + require.NoError(t, err) + t.Logf("signed biscuit size: %d", len(signedBiscuit)) + + err = Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) + require.NoError(t, err) + }) + + t.Run("user sign with wrong key", func(t *testing.T) { + _, err := Sign(signableBiscuit, rootKey.Public(), generateUserKeyPair(t)) + require.Error(t, err) + }) + + t.Run("verify wrong audience", func(t *testing.T) { + signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) + require.NoError(t, err) + + err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) + require.Error(t, err) + + wrongAudience := "http://another.audience.url" + kms := kmssim.NewClient([]string{wrongAudience}) + wrongAudienceKey, err := kmssign.NewKey(context.Background(), kms, wrongAudience) + require.NoError(t, err) + + err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) + require.Error(t, err) + }) +} diff --git a/pkg/biscuit/signature.go b/pkg/biscuit/signature.go new file mode 100644 index 0000000..385e4bf --- /dev/null +++ b/pkg/biscuit/signature.go @@ -0,0 +1,162 @@ +package biscuit + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "errors" + "time" + + "github.com/flynn/biscuit-go" + "github.com/flynn/hubauth/pkg/kmssign" +) + +var ( + ErrUnsupportedSignatureAlg = errors.New("unsupported signature algorithm") + ErrInvalidSignature = errors.New("invalid signature") +) + +type SignatureAlg biscuit.Symbol + +const ( + ECDSA_P256_SHA256 SignatureAlg = "ECDSA_P256_SHA256" +) + +type userToSignData struct { + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes + SignedBlockCount biscuit.Integer +} + +type userSignatureData struct { + DataID biscuit.Integer + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + SignedBlockCount biscuit.Integer + Nonce biscuit.Bytes + Timestamp biscuit.Date +} + +type userVerificationData struct { + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + SignedBlockCount biscuit.Integer + Nonce biscuit.Bytes + Timestamp biscuit.Date +} + +func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData) (*userSignatureData, error) { + if len(tokenHash) == 0 { + return nil, errors.New("invalid tokenHash") + } + + signerTimestamp := time.Now() + signerNonce := make([]byte, nonceSize) + if _, err := rand.Read(signerNonce); err != nil { + return nil, err + } + + var dataToSign []byte + dataToSign = append(dataToSign, toSignData.Data...) + dataToSign = append(dataToSign, tokenHash...) + dataToSign = append(dataToSign, signerNonce...) + dataToSign = append(dataToSign, []byte(signerTimestamp.Format(time.RFC3339))...) + dataToSign = append(dataToSign, []byte(toSignData.SignedBlockCount.String())...) + + var signedData biscuit.Bytes + switch SignatureAlg(toSignData.Alg) { + case ECDSA_P256_SHA256: + privKey, err := x509.ParseECPrivateKey(userKey.Private) + if err != nil { + return nil, err + } + hash := sha256.Sum256(dataToSign) + signedData, err = ecdsa.SignASN1(rand.Reader, privKey, hash[:]) + if err != nil { + return nil, err + } + default: + return nil, ErrUnsupportedSignatureAlg + } + + return &userSignatureData{ + DataID: toSignData.DataID, + Nonce: signerNonce, + Signature: signedData, + SignedBlockCount: toSignData.SignedBlockCount, + Timestamp: biscuit.Date(signerTimestamp), + UserPubKey: userKey.Public, + }, nil +} + +func verifyUserSignature(signedTokenHash []byte, data *userVerificationData) error { + var signedData []byte + signedData = append(signedData, data.Data...) + signedData = append(signedData, signedTokenHash...) + signedData = append(signedData, data.Nonce...) + signedData = append(signedData, []byte(time.Time(data.Timestamp).Format(time.RFC3339))...) + signedData = append(signedData, []byte(data.SignedBlockCount.String())...) + + switch SignatureAlg(data.Alg) { + case ECDSA_P256_SHA256: + pk, err := x509.ParsePKIXPublicKey(data.UserPubKey) + if err != nil { + return err + } + pubkey, ok := pk.(*ecdsa.PublicKey) + if !ok { + return errors.New("invalid pubkey, not an *ecdsa.PublicKey") + } + + hash := sha256.Sum256(signedData) + if !ecdsa.VerifyASN1(pubkey, hash[:], data.Signature) { + return ErrInvalidSignature + } + return nil + default: + return ErrUnsupportedSignatureAlg + } +} + +type audienceVerificationData struct { + Audience biscuit.Symbol + Challenge biscuit.Bytes + Signature biscuit.Bytes +} + +func audienceSign(audience string, audienceKey *kmssign.Key) (*audienceVerificationData, error) { + challenge := make([]byte, challengeSize) + if _, err := rand.Reader.Read(challenge); err != nil { + return nil, err + } + + signedData := append(signStaticCtx, challenge...) + signedData = append(signedData, []byte(audience)...) + signedHash := sha256.Sum256(signedData) + signature, err := audienceKey.Sign(rand.Reader, signedHash[:], crypto.SHA256) + if err != nil { + return nil, err + } + + return &audienceVerificationData{ + Audience: biscuit.Symbol(audience), + Challenge: challenge, + Signature: signature, + }, nil +} + +func verifyAudienceSignature(audiencePubkey *kmssign.Key, data *audienceVerificationData) error { + signedData := append(signStaticCtx, data.Challenge...) + signedData = append(signedData, []byte(data.Audience)...) + hash := sha256.Sum256(signedData) + if !audiencePubkey.Verify(hash[:], data.Signature) { + return errors.New("invalid signature") + } + return nil +} diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go new file mode 100644 index 0000000..5cb5d8b --- /dev/null +++ b/pkg/biscuit/signature_test.go @@ -0,0 +1,258 @@ +package biscuit + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "testing" + "time" + + "github.com/flynn/biscuit-go" + "github.com/stretchr/testify/require" +) + +func TestUserSignVerify(t *testing.T) { + tokenHash := make([]byte, 32) + _, err := rand.Read(tokenHash) + require.NoError(t, err) + + challenge := make([]byte, challengeSize) + _, err = rand.Read(challenge) + require.NoError(t, err) + + userKey := generateUserKeyPair(t) + + toSignData := &userToSignData{ + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), + SignedBlockCount: 2, + } + + signedData, err := userSign(tokenHash, userKey, toSignData) + require.NoError(t, err) + require.NotEmpty(t, signedData.Signature) + require.Equal(t, biscuit.Integer(2), signedData.SignedBlockCount) + require.Equal(t, biscuit.Integer(1), signedData.DataID) + require.Equal(t, biscuit.Bytes(userKey.Public), signedData.UserPubKey) + + require.Len(t, signedData.Nonce, nonceSize) + zeroNonce := make([]byte, nonceSize) + require.NotEqual(t, biscuit.Bytes(zeroNonce), signedData.Nonce) + + require.WithinDuration(t, time.Now(), time.Time(signedData.Timestamp), 1*time.Second) + + require.NoError(t, verifyUserSignature(tokenHash, &userVerificationData{ + DataID: toSignData.DataID, + Alg: toSignData.Alg, + Data: toSignData.Data, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + SignedBlockCount: signedData.SignedBlockCount, + Timestamp: signedData.Timestamp, + UserPubKey: signedData.UserPubKey, + })) +} + +func TestUserSignFail(t *testing.T) { + validTokenHash := make([]byte, 32) + _, err := rand.Read(validTokenHash) + require.NoError(t, err) + + validChallenge := make([]byte, challengeSize) + _, err = rand.Read(validChallenge) + require.NoError(t, err) + + invalidPrivateKey := &UserKeyPair{ + Private: make([]byte, 32), + } + + testCases := []struct { + desc string + tokenHash []byte + userKey *UserKeyPair + data *userToSignData + expectedErr error + }{ + { + desc: "empty tokenHash", + tokenHash: []byte{}, + }, + { + desc: "unsupported alg", + tokenHash: validTokenHash, + data: &userToSignData{ + Alg: "unsupported", + }, + expectedErr: ErrUnsupportedSignatureAlg, + }, + { + desc: "wrong private key encoding", + tokenHash: validTokenHash, + data: &userToSignData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + }, + userKey: invalidPrivateKey, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.desc, func(t *testing.T) { + _, err := userSign(testCase.tokenHash, testCase.userKey, testCase.data) + require.Error(t, err) + if testCase.expectedErr != nil { + require.Equal(t, testCase.expectedErr, err) + } + }) + } +} + +func TestVerifyUserSignatureFail(t *testing.T) { + tokenHash := []byte("token hash") + toSignData := &userToSignData{ + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), + SignedBlockCount: 2, + } + + userKey := generateUserKeyPair(t) + invalidKey := generateUserKeyPair(t) + + signedData, err := userSign(tokenHash, userKey, toSignData) + require.NoError(t, err) + + rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + wrongKeyKind, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey) + require.NoError(t, err) + + testCases := []struct { + desc string + tokenHash []byte + data *userVerificationData + expectedErr error + }{ + { + desc: "unsupported alg", + expectedErr: ErrUnsupportedSignatureAlg, + data: &userVerificationData{ + Alg: "unknown", + }, + }, + { + desc: "invalid pubkey encoding", + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: make([]byte, 32), + }, + }, + { + desc: "invalid pubkey kind", + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: wrongKeyKind, + }, + }, + { + desc: "wrong pubkey", + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: invalidKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + SignedBlockCount: signedData.SignedBlockCount, + }, + }, + { + desc: "tampered token hash", + expectedErr: ErrInvalidSignature, + tokenHash: []byte("wrong"), + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + SignedBlockCount: signedData.SignedBlockCount, + }, + }, + { + desc: "tampered nonce", + expectedErr: ErrInvalidSignature, + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: []byte("another nonce"), + Signature: signedData.Signature, + SignedBlockCount: signedData.SignedBlockCount, + Timestamp: signedData.Timestamp, + }, + }, + { + desc: "tampered timestamp", + expectedErr: ErrInvalidSignature, + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), + SignedBlockCount: signedData.SignedBlockCount, + }, + }, + { + desc: "tampered signedBlockCount", + expectedErr: ErrInvalidSignature, + tokenHash: tokenHash, + data: &userVerificationData{ + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + SignedBlockCount: signedData.SignedBlockCount + 1, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.desc, func(t *testing.T) { + + err := verifyUserSignature(testCase.tokenHash, testCase.data) + require.Error(t, err) + if testCase.expectedErr != nil { + require.Equal(t, testCase.expectedErr, err) + } + }) + } +} + +func generateUserKeyPair(t *testing.T) *UserKeyPair { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + privBytes, err := x509.MarshalECPrivateKey(priv) + require.NoError(t, err) + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + return &UserKeyPair{ + Private: privBytes, + Public: pubBytes, + } +} diff --git a/pkg/biscuit/wrapper.go b/pkg/biscuit/wrapper.go new file mode 100644 index 0000000..ec16ebd --- /dev/null +++ b/pkg/biscuit/wrapper.go @@ -0,0 +1,336 @@ +package biscuit + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + + "github.com/flynn/biscuit-go" + "github.com/flynn/hubauth/pkg/kmssign" +) + +var ( + ErrAlreadySigned = errors.New("already signed") + ErrInvalidToSignDataPrefix = errors.New("invalid to_sign data prefix") +) + +var ( + signStaticCtx = []byte("biscuit-pop-v0") + challengeSize = 16 + nonceSize = 16 +) + +type hubauthBuilder struct { + biscuit.Builder +} + +// withUserToSignFact add an authority should_sign fact and associated data to the biscuit +// with an authority caveat requiring the verifier to provide a valid_signature fact. +// the verifier is responsible of ensuring that a valid signature exists over the data. +func (b *hubauthBuilder) withUserToSignFact(userPubkey []byte) error { + dataID := biscuit.Integer(0) + + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "should_sign", + IDs: []biscuit.Atom{ + dataID, + biscuit.Symbol(ECDSA_P256_SHA256), + biscuit.Bytes(userPubkey), + }, + }}); err != nil { + return err + } + + challenge := make([]byte, challengeSize) + if _, err := rand.Reader.Read(challenge); err != nil { + return err + } + + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "data", + IDs: []biscuit.Atom{ + dataID, + biscuit.Bytes(append(signStaticCtx, challenge...)), + }, + }}); err != nil { + return err + } + + if err := b.AddAuthorityCaveat(biscuit.Rule{ + Head: biscuit.Predicate{Name: "valid", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + {Name: "valid_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + }, + }); err != nil { + return err + } + + return nil +} + +// withAudienceSignature add an authority audience_signature fact, containing a challenge and +// a matching signature using the audience key. +// the verifier is responsible of providing a valid_audience_signature fact, after +// verifying the signature using the audience pubkey. +func (b *hubauthBuilder) withAudienceSignature(audience string, audienceKey *kmssign.Key) error { + if len(audience) == 0 { + return errors.New("audience is required") + } + + data, err := audienceSign(audience, audienceKey) + if err != nil { + return err + } + + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "audience_signature", + IDs: []biscuit.Atom{ + data.Audience, + data.Challenge, + data.Signature, + }, + }}); err != nil { + return err + } + + if err := b.AddAuthorityCaveat(biscuit.Rule{ + Head: biscuit.Predicate{Name: "valid_audience", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + {Name: "valid_audience_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(2)}}, + }, + }); err != nil { + return err + } + + return nil +} + +type hubauthBlockBuilder struct { + biscuit.BlockBuilder +} + +func (b *hubauthBlockBuilder) withUserSignature(sigData *userSignatureData) error { + return b.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "signature", + IDs: []biscuit.Atom{ + sigData.DataID, + sigData.UserPubKey, + sigData.Signature, + sigData.Nonce, + sigData.Timestamp, + sigData.SignedBlockCount, + }, + }}) +} + +type hubauthVerifier struct { + biscuit.Verifier +} + +func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes, signedBlockCount int) (*userToSignData, error) { + toSign, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "to_sign", + IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}, + }, + Body: []biscuit.Predicate{ + { + Name: "should_sign", IDs: []biscuit.Atom{ + biscuit.SymbolAuthority, + biscuit.Variable(0), + biscuit.Variable(1), + biscuit.Bytes(userPubKey), + }, + }, { + Name: "data", IDs: []biscuit.Atom{ + biscuit.SymbolAuthority, + biscuit.Variable(0), + biscuit.Variable(2), + }, + }, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(toSign), 1; g != w { + return nil, fmt.Errorf("invalid to_sign fact count, got %d, want %d", g, w) + } + + toSignFact := toSign[0] + if g, w := len(toSignFact.IDs), 3; g != w { + return nil, fmt.Errorf("invalid to_sign fact, got %d atoms, want %d", g, w) + } + + sigData := &userToSignData{} + var ok bool + sigData.DataID, ok = toSign[0].IDs[0].(biscuit.Integer) + if !ok { + return nil, errors.New("invalid to_sign atom: dataID") + } + sigData.Alg, ok = toSign[0].IDs[1].(biscuit.Symbol) + if !ok { + return nil, errors.New("invalid to_sign atom: alg") + } + sigData.Data, ok = toSign[0].IDs[2].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_sign atom: data") + } + + if !bytes.HasPrefix(sigData.Data, signStaticCtx) { + return nil, ErrInvalidToSignDataPrefix + } + + sigData.SignedBlockCount = biscuit.Integer(signedBlockCount) + + return sigData, nil +} + +func (v *hubauthVerifier) ensureNotAlreadyUserSigned(dataID biscuit.Integer, userPubKey biscuit.Bytes) error { + alreadySigned, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{Name: "already_signed", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "signature", IDs: []biscuit.Atom{dataID, userPubKey, biscuit.Variable(0)}}, + }, + }) + if err != nil { + return err + } + if len(alreadySigned) != 0 { + return ErrAlreadySigned + } + + return nil +} + +func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, error) { + toValidate, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "to_validate", + IDs: []biscuit.Atom{ + biscuit.Variable(0), // dataID + biscuit.Variable(1), // alg + biscuit.Variable(2), // pubkey + biscuit.Variable(3), // data + biscuit.Variable(4), // signature + biscuit.Variable(5), // signerNonce + biscuit.Variable(6), // signerTimestamp + biscuit.Variable(7), // signedBlockCount + }}, + Body: []biscuit.Predicate{ + {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, + {Name: "data", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(3)}}, + {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6), biscuit.Variable(7)}}, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(toValidate), 1; g != w { + return nil, fmt.Errorf("invalid to_validate fact count, got %d, want %d", g, w) + } + + toValidateFact := toValidate[0] + if g, w := len(toValidateFact.IDs), 8; g != w { + return nil, fmt.Errorf("invalid to_valid fact atom count, got %d, want %d", g, w) + } + + toVerify := &userVerificationData{} + var ok bool + toVerify.DataID, ok = toValidateFact.IDs[0].(biscuit.Integer) + if !ok { + return nil, errors.New("invalid to_validate atom: dataID") + } + toVerify.Alg, ok = toValidateFact.IDs[1].(biscuit.Symbol) + if !ok { + return nil, errors.New("invalid to_validate atom: alg") + } + toVerify.UserPubKey, ok = toValidateFact.IDs[2].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: userPubKey") + } + toVerify.Data, ok = toValidateFact.IDs[3].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: data") + } + toVerify.Signature, ok = toValidateFact.IDs[4].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: signature") + } + toVerify.Nonce, ok = toValidateFact.IDs[5].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid to_validate atom: nonce") + } + toVerify.Timestamp, ok = toValidateFact.IDs[6].(biscuit.Date) + if !ok { + return nil, errors.New("invalid to_validate atom: timestamp") + } + toVerify.SignedBlockCount, ok = toValidateFact.IDs[7].(biscuit.Integer) + if !ok { + return nil, errors.New("invalid to_validate atom: signedBlockCount") + } + + return toVerify, nil +} + +func (v *hubauthVerifier) withValidatedUserSignature(data *userVerificationData) error { + v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "valid_signature", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.DataID, data.Alg, data.UserPubKey}, + }}) + + return nil +} + +func (v *hubauthVerifier) getAudienceVerificationData(audience string) (*audienceVerificationData, error) { + toValidate, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "audience_to_validate", + IDs: []biscuit.Atom{ + biscuit.Variable(0), // challenge + biscuit.Variable(1), // signature + }}, + Body: []biscuit.Predicate{ + {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Symbol(audience), biscuit.Variable(0), biscuit.Variable(1)}}, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(toValidate), 1; g != w { + return nil, fmt.Errorf("invalid audience_to_validate fact count, got %d, want %d", g, w) + } + + toValidateFact := toValidate[0] + if g, w := len(toValidateFact.IDs), 2; g != w { + return nil, fmt.Errorf("invalid audience_to_validate fact atom count, got %d, want %d", g, w) + } + + toVerify := &audienceVerificationData{Audience: biscuit.Symbol(audience)} + var ok bool + toVerify.Challenge, ok = toValidateFact.IDs[0].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid audience_to_validate atom: challenge") + } + toVerify.Signature, ok = toValidateFact.IDs[1].(biscuit.Bytes) + if !ok { + return nil, errors.New("invalid audience_to_validate atom: signature") + } + + return toVerify, nil +} + +func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificationData) error { + v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "valid_audience_signature", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.Audience, data.Signature}, + }}) + + return nil +} From 969a3053a550a2ebc3e9865473ae963c7d45ee19 Mon Sep 17 00:00:00 2001 From: daeMOn Date: Mon, 12 Oct 2020 09:39:15 +0200 Subject: [PATCH 07/24] add biscuit metadata and expiration date Verifiers must now provide the current time for verifying the biscuit, and can extract user informations. User's pubkeys are now provided in http param when exchanging code. Removed block count from biscuit weakening the signature. --- pkg/biscuit/biscuit.go | 56 ++++++++++++----- pkg/biscuit/biscuit_test.go | 20 ++++-- pkg/biscuit/signature.go | 66 ++++++++++++-------- pkg/biscuit/signature_test.go | 105 +++++++++++++------------------ pkg/biscuit/wrapper.go | 112 ++++++++++++++++++++++++++++++---- 5 files changed, 238 insertions(+), 121 deletions(-) diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go index ef1a452..e335e60 100644 --- a/pkg/biscuit/biscuit.go +++ b/pkg/biscuit/biscuit.go @@ -3,12 +3,20 @@ package biscuit import ( "crypto/rand" "fmt" + "time" "github.com/flynn/biscuit-go" "github.com/flynn/biscuit-go/sig" "github.com/flynn/hubauth/pkg/kmssign" ) +type Metadata struct { + ClientID string + UserID string + UserEmail string + IssueTime time.Time +} + type UserKeyPair struct { Public []byte Private []byte @@ -16,7 +24,7 @@ type UserKeyPair struct { // GenerateSignable returns a biscuit which will only verify after being // signed with the private key matching the given userPubkey. -func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPubkey []byte) ([]byte, error) { +func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPublicKey []byte, expireTime time.Time, m *Metadata) ([]byte, error) { builder := &hubauthBuilder{ Builder: biscuit.NewBuilder(rand.Reader, rootKey), } @@ -25,7 +33,15 @@ func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign return nil, err } - if err := builder.withUserToSignFact(userPubkey); err != nil { + if err := builder.withUserToSignFact(userPublicKey); err != nil { + return nil, err + } + + if err := builder.withExpire(expireTime); err != nil { + return nil, err + } + + if err := builder.withMetadata(m); err != nil { return nil, err } @@ -55,7 +71,7 @@ func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, Verifier: v, } - toSignData, err := verifier.getUserToSignData(userKey.Public, b.BlockCount()) + toSignData, err := verifier.getUserToSignData(userKey.Public) if err != nil { return nil, fmt.Errorf("biscuit: failed to get to_sign data: %w", err) } @@ -92,46 +108,56 @@ func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, // Verify will verify the biscuit, the included audience and user signature, and return an error // when anything is invalid. -func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) error { +func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) (*Metadata, error) { b, err := biscuit.Unmarshal(token) if err != nil { - return fmt.Errorf("biscuit: failed to unmarshal: %w", err) + return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) } v, err := b.Verify(rootPubKey) if err != nil { - return fmt.Errorf("biscuit: failed to verify: %w", err) + return nil, fmt.Errorf("biscuit: failed to verify: %w", err) } verifier := &hubauthVerifier{v} audienceVerificationData, err := verifier.getAudienceVerificationData(audience) if err != nil { - return fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) + return nil, fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) } if err := verifyAudienceSignature(audienceKey, audienceVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to verify audience signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to verify audience signature: %w", err) } if err := verifier.withValidatedAudienceSignature(audienceVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) } userVerificationData, err := verifier.getUserVerificationData() if err != nil { - return fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) + return nil, fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) } - signedTokenHash, err := b.SHA256Sum(int(userVerificationData.SignedBlockCount)) + // TODO: improve biscuit API to allow retrieve the block index the signature is at + // so that we can still append other blocks if needed. Right now the signature MUST BE the last block. + signedTokenHash, err := b.SHA256Sum(b.BlockCount() - 1) if err != nil { - return fmt.Errorf("biscuit: failed to generate token hash: %w", err) + return nil, fmt.Errorf("biscuit: failed to generate token hash: %w", err) } if err := verifyUserSignature(signedTokenHash, userVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to verify user signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to verify user signature: %w", err) } if err := verifier.withValidatedUserSignature(userVerificationData); err != nil { - return fmt.Errorf("biscuit: failed to add validated signature: %w", err) + return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) + } + + if err := verifier.withCurrentTime(time.Now()); err != nil { + return nil, fmt.Errorf("biscuit: failed to add current time: %w", err) + } + + if err := verifier.Verify(); err != nil { + return nil, fmt.Errorf("biscuit: failed to verify: %w", err) } - return verifier.Verify() + return verifier.getMetadata() } diff --git a/pkg/biscuit/biscuit_test.go b/pkg/biscuit/biscuit_test.go index ed7b273..bf6199b 100644 --- a/pkg/biscuit/biscuit_test.go +++ b/pkg/biscuit/biscuit_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "testing" + "time" "github.com/flynn/biscuit-go/sig" "github.com/flynn/hubauth/pkg/kmssign" @@ -20,8 +21,13 @@ func TestBiscuit(t *testing.T) { require.NoError(t, err) userKey := generateUserKeyPair(t) - - signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public) + metas := &Metadata{ + ClientID: "abcd", + UserEmail: "1234@example.com", + UserID: "1234", + IssueTime: time.Now(), + } + signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public, time.Now().Add(5*time.Minute), metas) require.NoError(t, err) t.Logf("signable biscuit size: %d", len(signableBiscuit)) @@ -30,8 +36,12 @@ func TestBiscuit(t *testing.T) { require.NoError(t, err) t.Logf("signed biscuit size: %d", len(signedBiscuit)) - err = Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) + res, err := Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) require.NoError(t, err) + require.Equal(t, metas.ClientID, res.ClientID) + require.Equal(t, metas.UserID, res.UserID) + require.Equal(t, metas.UserEmail, res.UserEmail) + require.WithinDuration(t, metas.IssueTime, res.IssueTime, 1*time.Second) }) t.Run("user sign with wrong key", func(t *testing.T) { @@ -43,7 +53,7 @@ func TestBiscuit(t *testing.T) { signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) require.NoError(t, err) - err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) + _, err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) require.Error(t, err) wrongAudience := "http://another.audience.url" @@ -51,7 +61,7 @@ func TestBiscuit(t *testing.T) { wrongAudienceKey, err := kmssign.NewKey(context.Background(), kms, wrongAudience) require.NoError(t, err) - err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) + _, err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) require.Error(t, err) }) } diff --git a/pkg/biscuit/signature.go b/pkg/biscuit/signature.go index 385e4bf..078eaa6 100644 --- a/pkg/biscuit/signature.go +++ b/pkg/biscuit/signature.go @@ -3,10 +3,12 @@ package biscuit import ( "crypto" "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/sha256" "crypto/x509" "errors" + "fmt" "time" "github.com/flynn/biscuit-go" @@ -25,30 +27,27 @@ const ( ) type userToSignData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes - SignedBlockCount biscuit.Integer + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes } type userSignatureData struct { - DataID biscuit.Integer - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - SignedBlockCount biscuit.Integer - Nonce biscuit.Bytes - Timestamp biscuit.Date + DataID biscuit.Integer + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + Nonce biscuit.Bytes + Timestamp biscuit.Date } type userVerificationData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - SignedBlockCount biscuit.Integer - Nonce biscuit.Bytes - Timestamp biscuit.Date + DataID biscuit.Integer + Alg biscuit.Symbol + Data biscuit.Bytes + UserPubKey biscuit.Bytes + Signature biscuit.Bytes + Nonce biscuit.Bytes + Timestamp biscuit.Date } func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData) (*userSignatureData, error) { @@ -67,7 +66,6 @@ func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData dataToSign = append(dataToSign, tokenHash...) dataToSign = append(dataToSign, signerNonce...) dataToSign = append(dataToSign, []byte(signerTimestamp.Format(time.RFC3339))...) - dataToSign = append(dataToSign, []byte(toSignData.SignedBlockCount.String())...) var signedData biscuit.Bytes switch SignatureAlg(toSignData.Alg) { @@ -86,12 +84,11 @@ func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData } return &userSignatureData{ - DataID: toSignData.DataID, - Nonce: signerNonce, - Signature: signedData, - SignedBlockCount: toSignData.SignedBlockCount, - Timestamp: biscuit.Date(signerTimestamp), - UserPubKey: userKey.Public, + DataID: toSignData.DataID, + Nonce: signerNonce, + Signature: signedData, + Timestamp: biscuit.Date(signerTimestamp), + UserPubKey: userKey.Public, }, nil } @@ -101,7 +98,6 @@ func verifyUserSignature(signedTokenHash []byte, data *userVerificationData) err signedData = append(signedData, signedTokenHash...) signedData = append(signedData, data.Nonce...) signedData = append(signedData, []byte(time.Time(data.Timestamp).Format(time.RFC3339))...) - signedData = append(signedData, []byte(data.SignedBlockCount.String())...) switch SignatureAlg(data.Alg) { case ECDSA_P256_SHA256: @@ -160,3 +156,21 @@ func verifyAudienceSignature(audiencePubkey *kmssign.Key, data *audienceVerifica } return nil } + +func validatePKIXP256PublicKey(pubkey []byte) error { + key, err := x509.ParsePKIXPublicKey(pubkey) + if err != nil { + return fmt.Errorf("failed to parse PKIX, ASN.1 DER public key: %v", err) + } + + ecKey, ok := key.(*ecdsa.PublicKey) + if !ok { + return errors.New("public key is not an *ecdsa.PublicKey") + } + + if ecKey.Curve != elliptic.P256() { + return fmt.Errorf("publickey is on wrong curve, expected P256") + } + + return nil +} diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go index 5cb5d8b..d4d4cc2 100644 --- a/pkg/biscuit/signature_test.go +++ b/pkg/biscuit/signature_test.go @@ -25,16 +25,14 @@ func TestUserSignVerify(t *testing.T) { userKey := generateUserKeyPair(t) toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - SignedBlockCount: 2, + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), } signedData, err := userSign(tokenHash, userKey, toSignData) require.NoError(t, err) require.NotEmpty(t, signedData.Signature) - require.Equal(t, biscuit.Integer(2), signedData.SignedBlockCount) require.Equal(t, biscuit.Integer(1), signedData.DataID) require.Equal(t, biscuit.Bytes(userKey.Public), signedData.UserPubKey) @@ -45,14 +43,13 @@ func TestUserSignVerify(t *testing.T) { require.WithinDuration(t, time.Now(), time.Time(signedData.Timestamp), 1*time.Second) require.NoError(t, verifyUserSignature(tokenHash, &userVerificationData{ - DataID: toSignData.DataID, - Alg: toSignData.Alg, - Data: toSignData.Data, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - SignedBlockCount: signedData.SignedBlockCount, - Timestamp: signedData.Timestamp, - UserPubKey: signedData.UserPubKey, + DataID: toSignData.DataID, + Alg: toSignData.Alg, + Data: toSignData.Data, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, + UserPubKey: signedData.UserPubKey, })) } @@ -112,10 +109,9 @@ func TestUserSignFail(t *testing.T) { func TestVerifyUserSignatureFail(t *testing.T) { tokenHash := []byte("token hash") toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - SignedBlockCount: 2, + DataID: 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + Data: []byte("challenge"), } userKey := generateUserKeyPair(t) @@ -160,14 +156,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { desc: "wrong pubkey", tokenHash: tokenHash, data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: invalidKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - SignedBlockCount: signedData.SignedBlockCount, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: invalidKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, }, }, { @@ -175,14 +170,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { expectedErr: ErrInvalidSignature, tokenHash: []byte("wrong"), data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - SignedBlockCount: signedData.SignedBlockCount, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, }, }, { @@ -190,14 +184,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { expectedErr: ErrInvalidSignature, tokenHash: tokenHash, data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: []byte("another nonce"), - Signature: signedData.Signature, - SignedBlockCount: signedData.SignedBlockCount, - Timestamp: signedData.Timestamp, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: []byte("another nonce"), + Signature: signedData.Signature, + Timestamp: signedData.Timestamp, }, }, { @@ -205,29 +198,13 @@ func TestVerifyUserSignatureFail(t *testing.T) { expectedErr: ErrInvalidSignature, tokenHash: tokenHash, data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), - SignedBlockCount: signedData.SignedBlockCount, - }, - }, - { - desc: "tampered signedBlockCount", - expectedErr: ErrInvalidSignature, - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - SignedBlockCount: signedData.SignedBlockCount + 1, + Alg: biscuit.Symbol(ECDSA_P256_SHA256), + UserPubKey: userKey.Public, + Data: toSignData.Data, + DataID: toSignData.DataID, + Nonce: signedData.Nonce, + Signature: signedData.Signature, + Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), }, }, } diff --git a/pkg/biscuit/wrapper.go b/pkg/biscuit/wrapper.go index ec16ebd..f39681f 100644 --- a/pkg/biscuit/wrapper.go +++ b/pkg/biscuit/wrapper.go @@ -5,8 +5,10 @@ import ( "crypto/rand" "errors" "fmt" + "time" "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/datalog" "github.com/flynn/hubauth/pkg/kmssign" ) @@ -31,6 +33,10 @@ type hubauthBuilder struct { func (b *hubauthBuilder) withUserToSignFact(userPubkey []byte) error { dataID := biscuit.Integer(0) + if err := validatePKIXP256PublicKey(userPubkey); err != nil { + return err + } + if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ Name: "should_sign", IDs: []biscuit.Atom{ @@ -108,6 +114,38 @@ func (b *hubauthBuilder) withAudienceSignature(audience string, audienceKey *kms return nil } +func (b *hubauthBuilder) withMetadata(m *Metadata) error { + return b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "metadata", + IDs: []biscuit.Atom{ + biscuit.String(m.ClientID), + biscuit.String(m.UserID), + biscuit.String(m.UserEmail), + biscuit.Date(m.IssueTime), + }, + }}) +} + +func (b *hubauthBuilder) withExpire(exp time.Time) error { + if err := b.AddAuthorityCaveat(biscuit.Rule{ + Head: biscuit.Predicate{Name: "not_expired", IDs: []biscuit.Atom{biscuit.Variable(0)}}, + Body: []biscuit.Predicate{ + {Name: "current_time", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0)}}, + }, + Constraints: []biscuit.Constraint{{ + Name: biscuit.Variable(0), + Checker: biscuit.DateComparisonChecker{ + Comparison: datalog.DateComparisonBefore, + Date: biscuit.Date(exp), + }, + }}, + }); err != nil { + return err + } + + return nil +} + type hubauthBlockBuilder struct { biscuit.BlockBuilder } @@ -121,7 +159,6 @@ func (b *hubauthBlockBuilder) withUserSignature(sigData *userSignatureData) erro sigData.Signature, sigData.Nonce, sigData.Timestamp, - sigData.SignedBlockCount, }, }}) } @@ -130,7 +167,7 @@ type hubauthVerifier struct { biscuit.Verifier } -func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes, signedBlockCount int) (*userToSignData, error) { +func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes) (*userToSignData, error) { toSign, err := v.Query(biscuit.Rule{ Head: biscuit.Predicate{ Name: "to_sign", @@ -185,8 +222,6 @@ func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes, signedBloc return nil, ErrInvalidToSignDataPrefix } - sigData.SignedBlockCount = biscuit.Integer(signedBlockCount) - return sigData, nil } @@ -219,12 +254,11 @@ func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, erro biscuit.Variable(4), // signature biscuit.Variable(5), // signerNonce biscuit.Variable(6), // signerTimestamp - biscuit.Variable(7), // signedBlockCount }}, Body: []biscuit.Predicate{ {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, {Name: "data", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(3)}}, - {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6), biscuit.Variable(7)}}, + {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6)}}, }, }) if err != nil { @@ -236,7 +270,7 @@ func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, erro } toValidateFact := toValidate[0] - if g, w := len(toValidateFact.IDs), 8; g != w { + if g, w := len(toValidateFact.IDs), 7; g != w { return nil, fmt.Errorf("invalid to_valid fact atom count, got %d, want %d", g, w) } @@ -270,10 +304,6 @@ func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, erro if !ok { return nil, errors.New("invalid to_validate atom: timestamp") } - toVerify.SignedBlockCount, ok = toValidateFact.IDs[7].(biscuit.Integer) - if !ok { - return nil, errors.New("invalid to_validate atom: signedBlockCount") - } return toVerify, nil } @@ -326,6 +356,54 @@ func (v *hubauthVerifier) getAudienceVerificationData(audience string) (*audienc return toVerify, nil } +func (v *hubauthVerifier) getMetadata() (*Metadata, error) { + metaFacts, err := v.Query(biscuit.Rule{ + Head: biscuit.Predicate{ + Name: "metadata", + IDs: []biscuit.Atom{ + biscuit.Variable(0), // clientID + biscuit.Variable(1), // userID + biscuit.Variable(2), // userEmail + biscuit.Variable(3), // issueTime + }}, + Body: []biscuit.Predicate{ + {Name: "metadata", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2), biscuit.Variable(3)}}, + }, + }) + if err != nil { + return nil, err + } + + if g, w := len(metaFacts), 1; g != w { + return nil, fmt.Errorf("invalid metadata fact count, got %d, want %d", g, w) + } + + metaFact := metaFacts[0] + + clientID, ok := metaFact.IDs[0].(biscuit.String) + if !ok { + return nil, errors.New("invalid metadata atom: clientID") + } + userID, ok := metaFact.IDs[1].(biscuit.String) + if !ok { + return nil, errors.New("invalid metadata atom: userID") + } + userEmail, ok := metaFact.IDs[2].(biscuit.String) + if !ok { + return nil, errors.New("invalid metadata atom: userEmail") + } + issueTime, ok := metaFact.IDs[3].(biscuit.Date) + if !ok { + return nil, errors.New("invalid metadata atom: issueTime") + } + return &Metadata{ + ClientID: string(clientID), + UserID: string(userID), + UserEmail: string(userEmail), + IssueTime: time.Time(issueTime), + }, nil +} + func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificationData) error { v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ Name: "valid_audience_signature", @@ -334,3 +412,15 @@ func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificat return nil } + +func (v *hubauthVerifier) withCurrentTime(t time.Time) error { + v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "current_time", + IDs: []biscuit.Atom{ + biscuit.Symbol("ambient"), + biscuit.Date(t), + }, + }}) + + return nil +} From db630ec0ac82c1ca5cb8c51180c0f62552e98903 Mon Sep 17 00:00:00 2001 From: daeMOn Date: Tue, 13 Oct 2020 11:50:18 +0200 Subject: [PATCH 08/24] cleanup --- pkg/biscuit/biscuit.go | 2 +- pkg/biscuit/signature_test.go | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go index e335e60..0c7a3e4 100644 --- a/pkg/biscuit/biscuit.go +++ b/pkg/biscuit/biscuit.go @@ -26,7 +26,7 @@ type UserKeyPair struct { // signed with the private key matching the given userPubkey. func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPublicKey []byte, expireTime time.Time, m *Metadata) ([]byte, error) { builder := &hubauthBuilder{ - Builder: biscuit.NewBuilder(rand.Reader, rootKey), + biscuit.NewBuilder(rand.Reader, rootKey), } if err := builder.withAudienceSignature(audience, audienceKey); err != nil { diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go index d4d4cc2..2351e26 100644 --- a/pkg/biscuit/signature_test.go +++ b/pkg/biscuit/signature_test.go @@ -211,7 +211,6 @@ func TestVerifyUserSignatureFail(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.desc, func(t *testing.T) { - err := verifyUserSignature(testCase.tokenHash, testCase.data) require.Error(t, err) if testCase.expectedErr != nil { @@ -224,12 +223,8 @@ func TestVerifyUserSignatureFail(t *testing.T) { func generateUserKeyPair(t *testing.T) *UserKeyPair { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - privBytes, err := x509.MarshalECPrivateKey(priv) - require.NoError(t, err) - pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + + kp, err := NewECDSAKeyPair(priv) require.NoError(t, err) - return &UserKeyPair{ - Private: privBytes, - Public: pubBytes, - } + return kp } From 3067b97c18dd55b980f1bfb7b916e31add0408fa Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 15 Oct 2020 10:38:31 +0200 Subject: [PATCH 09/24] moved biscuit pkg to biscuit-go repo --- pkg/biscuit/biscuit.go | 163 ------------- pkg/biscuit/biscuit_test.go | 67 ------ pkg/biscuit/signature.go | 176 -------------- pkg/biscuit/signature_test.go | 230 ------------------ pkg/biscuit/wrapper.go | 426 ---------------------------------- 5 files changed, 1062 deletions(-) delete mode 100644 pkg/biscuit/biscuit.go delete mode 100644 pkg/biscuit/biscuit_test.go delete mode 100644 pkg/biscuit/signature.go delete mode 100644 pkg/biscuit/signature_test.go delete mode 100644 pkg/biscuit/wrapper.go diff --git a/pkg/biscuit/biscuit.go b/pkg/biscuit/biscuit.go deleted file mode 100644 index 0c7a3e4..0000000 --- a/pkg/biscuit/biscuit.go +++ /dev/null @@ -1,163 +0,0 @@ -package biscuit - -import ( - "crypto/rand" - "fmt" - "time" - - "github.com/flynn/biscuit-go" - "github.com/flynn/biscuit-go/sig" - "github.com/flynn/hubauth/pkg/kmssign" -) - -type Metadata struct { - ClientID string - UserID string - UserEmail string - IssueTime time.Time -} - -type UserKeyPair struct { - Public []byte - Private []byte -} - -// GenerateSignable returns a biscuit which will only verify after being -// signed with the private key matching the given userPubkey. -func GenerateSignable(rootKey sig.Keypair, audience string, audienceKey *kmssign.Key, userPublicKey []byte, expireTime time.Time, m *Metadata) ([]byte, error) { - builder := &hubauthBuilder{ - biscuit.NewBuilder(rand.Reader, rootKey), - } - - if err := builder.withAudienceSignature(audience, audienceKey); err != nil { - return nil, err - } - - if err := builder.withUserToSignFact(userPublicKey); err != nil { - return nil, err - } - - if err := builder.withExpire(expireTime); err != nil { - return nil, err - } - - if err := builder.withMetadata(m); err != nil { - return nil, err - } - - b, err := builder.Build() - if err != nil { - return nil, err - } - - return b.Serialize() -} - -// Sign append a user signature on the given token and return it. -// The UserKeyPair key format to provide depends on the signature algorithm: -// - for ECDSA_P256_SHA256, the private key must be encoded in SEC 1, ASN.1 DER form, -// and the public key in PKIX, ASN.1 DER form. -func Sign(token []byte, rootPubKey sig.PublicKey, userKey *UserKeyPair) ([]byte, error) { - b, err := biscuit.Unmarshal(token) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) - } - - v, err := b.Verify(rootPubKey) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to verify: %w", err) - } - verifier := &hubauthVerifier{ - Verifier: v, - } - - toSignData, err := verifier.getUserToSignData(userKey.Public) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to get to_sign data: %w", err) - } - - if err := verifier.ensureNotAlreadyUserSigned(toSignData.DataID, userKey.Public); err != nil { - return nil, fmt.Errorf("biscuit: previous signature check failed: %w", err) - } - - tokenHash, err := b.SHA256Sum(b.BlockCount()) - if err != nil { - return nil, err - } - - signData, err := userSign(tokenHash, userKey, toSignData) - if err != nil { - return nil, fmt.Errorf("biscuit: signature failed: %w", err) - } - - builder := &hubauthBlockBuilder{ - BlockBuilder: b.CreateBlock(), - } - if err := builder.withUserSignature(signData); err != nil { - return nil, fmt.Errorf("biscuit: failed to create signature block: %w", err) - } - - clientKey := sig.GenerateKeypair(rand.Reader) - b, err = b.Append(rand.Reader, clientKey, builder.Build()) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to append signature block: %w", err) - } - - return b.Serialize() -} - -// Verify will verify the biscuit, the included audience and user signature, and return an error -// when anything is invalid. -func Verify(token []byte, rootPubKey sig.PublicKey, audience string, audienceKey *kmssign.Key) (*Metadata, error) { - b, err := biscuit.Unmarshal(token) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to unmarshal: %w", err) - } - - v, err := b.Verify(rootPubKey) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to verify: %w", err) - } - verifier := &hubauthVerifier{v} - - audienceVerificationData, err := verifier.getAudienceVerificationData(audience) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to retrieve audience signature data: %w", err) - } - - if err := verifyAudienceSignature(audienceKey, audienceVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to verify audience signature: %w", err) - } - if err := verifier.withValidatedAudienceSignature(audienceVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) - } - - userVerificationData, err := verifier.getUserVerificationData() - if err != nil { - return nil, fmt.Errorf("biscuit: failed to retrieve user signature data: %w", err) - } - - // TODO: improve biscuit API to allow retrieve the block index the signature is at - // so that we can still append other blocks if needed. Right now the signature MUST BE the last block. - signedTokenHash, err := b.SHA256Sum(b.BlockCount() - 1) - if err != nil { - return nil, fmt.Errorf("biscuit: failed to generate token hash: %w", err) - } - - if err := verifyUserSignature(signedTokenHash, userVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to verify user signature: %w", err) - } - if err := verifier.withValidatedUserSignature(userVerificationData); err != nil { - return nil, fmt.Errorf("biscuit: failed to add validated signature: %w", err) - } - - if err := verifier.withCurrentTime(time.Now()); err != nil { - return nil, fmt.Errorf("biscuit: failed to add current time: %w", err) - } - - if err := verifier.Verify(); err != nil { - return nil, fmt.Errorf("biscuit: failed to verify: %w", err) - } - - return verifier.getMetadata() -} diff --git a/pkg/biscuit/biscuit_test.go b/pkg/biscuit/biscuit_test.go deleted file mode 100644 index bf6199b..0000000 --- a/pkg/biscuit/biscuit_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package biscuit - -import ( - "context" - "crypto/rand" - "testing" - "time" - - "github.com/flynn/biscuit-go/sig" - "github.com/flynn/hubauth/pkg/kmssign" - "github.com/flynn/hubauth/pkg/kmssign/kmssim" - "github.com/stretchr/testify/require" -) - -func TestBiscuit(t *testing.T) { - rootKey := sig.GenerateKeypair(rand.Reader) - audience := "http://random.audience.url" - - kms := kmssim.NewClient([]string{audience}) - audienceKey, err := kmssign.NewKey(context.Background(), kms, audience) - require.NoError(t, err) - - userKey := generateUserKeyPair(t) - metas := &Metadata{ - ClientID: "abcd", - UserEmail: "1234@example.com", - UserID: "1234", - IssueTime: time.Now(), - } - signableBiscuit, err := GenerateSignable(rootKey, audience, audienceKey, userKey.Public, time.Now().Add(5*time.Minute), metas) - require.NoError(t, err) - t.Logf("signable biscuit size: %d", len(signableBiscuit)) - - t.Run("happy path", func(t *testing.T) { - signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) - require.NoError(t, err) - t.Logf("signed biscuit size: %d", len(signedBiscuit)) - - res, err := Verify(signedBiscuit, rootKey.Public(), audience, audienceKey) - require.NoError(t, err) - require.Equal(t, metas.ClientID, res.ClientID) - require.Equal(t, metas.UserID, res.UserID) - require.Equal(t, metas.UserEmail, res.UserEmail) - require.WithinDuration(t, metas.IssueTime, res.IssueTime, 1*time.Second) - }) - - t.Run("user sign with wrong key", func(t *testing.T) { - _, err := Sign(signableBiscuit, rootKey.Public(), generateUserKeyPair(t)) - require.Error(t, err) - }) - - t.Run("verify wrong audience", func(t *testing.T) { - signedBiscuit, err := Sign(signableBiscuit, rootKey.Public(), userKey) - require.NoError(t, err) - - _, err = Verify(signedBiscuit, rootKey.Public(), "http://another.audience.url", audienceKey) - require.Error(t, err) - - wrongAudience := "http://another.audience.url" - kms := kmssim.NewClient([]string{wrongAudience}) - wrongAudienceKey, err := kmssign.NewKey(context.Background(), kms, wrongAudience) - require.NoError(t, err) - - _, err = Verify(signedBiscuit, rootKey.Public(), audience, wrongAudienceKey) - require.Error(t, err) - }) -} diff --git a/pkg/biscuit/signature.go b/pkg/biscuit/signature.go deleted file mode 100644 index 078eaa6..0000000 --- a/pkg/biscuit/signature.go +++ /dev/null @@ -1,176 +0,0 @@ -package biscuit - -import ( - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "crypto/x509" - "errors" - "fmt" - "time" - - "github.com/flynn/biscuit-go" - "github.com/flynn/hubauth/pkg/kmssign" -) - -var ( - ErrUnsupportedSignatureAlg = errors.New("unsupported signature algorithm") - ErrInvalidSignature = errors.New("invalid signature") -) - -type SignatureAlg biscuit.Symbol - -const ( - ECDSA_P256_SHA256 SignatureAlg = "ECDSA_P256_SHA256" -) - -type userToSignData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes -} - -type userSignatureData struct { - DataID biscuit.Integer - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - Nonce biscuit.Bytes - Timestamp biscuit.Date -} - -type userVerificationData struct { - DataID biscuit.Integer - Alg biscuit.Symbol - Data biscuit.Bytes - UserPubKey biscuit.Bytes - Signature biscuit.Bytes - Nonce biscuit.Bytes - Timestamp biscuit.Date -} - -func userSign(tokenHash []byte, userKey *UserKeyPair, toSignData *userToSignData) (*userSignatureData, error) { - if len(tokenHash) == 0 { - return nil, errors.New("invalid tokenHash") - } - - signerTimestamp := time.Now() - signerNonce := make([]byte, nonceSize) - if _, err := rand.Read(signerNonce); err != nil { - return nil, err - } - - var dataToSign []byte - dataToSign = append(dataToSign, toSignData.Data...) - dataToSign = append(dataToSign, tokenHash...) - dataToSign = append(dataToSign, signerNonce...) - dataToSign = append(dataToSign, []byte(signerTimestamp.Format(time.RFC3339))...) - - var signedData biscuit.Bytes - switch SignatureAlg(toSignData.Alg) { - case ECDSA_P256_SHA256: - privKey, err := x509.ParseECPrivateKey(userKey.Private) - if err != nil { - return nil, err - } - hash := sha256.Sum256(dataToSign) - signedData, err = ecdsa.SignASN1(rand.Reader, privKey, hash[:]) - if err != nil { - return nil, err - } - default: - return nil, ErrUnsupportedSignatureAlg - } - - return &userSignatureData{ - DataID: toSignData.DataID, - Nonce: signerNonce, - Signature: signedData, - Timestamp: biscuit.Date(signerTimestamp), - UserPubKey: userKey.Public, - }, nil -} - -func verifyUserSignature(signedTokenHash []byte, data *userVerificationData) error { - var signedData []byte - signedData = append(signedData, data.Data...) - signedData = append(signedData, signedTokenHash...) - signedData = append(signedData, data.Nonce...) - signedData = append(signedData, []byte(time.Time(data.Timestamp).Format(time.RFC3339))...) - - switch SignatureAlg(data.Alg) { - case ECDSA_P256_SHA256: - pk, err := x509.ParsePKIXPublicKey(data.UserPubKey) - if err != nil { - return err - } - pubkey, ok := pk.(*ecdsa.PublicKey) - if !ok { - return errors.New("invalid pubkey, not an *ecdsa.PublicKey") - } - - hash := sha256.Sum256(signedData) - if !ecdsa.VerifyASN1(pubkey, hash[:], data.Signature) { - return ErrInvalidSignature - } - return nil - default: - return ErrUnsupportedSignatureAlg - } -} - -type audienceVerificationData struct { - Audience biscuit.Symbol - Challenge biscuit.Bytes - Signature biscuit.Bytes -} - -func audienceSign(audience string, audienceKey *kmssign.Key) (*audienceVerificationData, error) { - challenge := make([]byte, challengeSize) - if _, err := rand.Reader.Read(challenge); err != nil { - return nil, err - } - - signedData := append(signStaticCtx, challenge...) - signedData = append(signedData, []byte(audience)...) - signedHash := sha256.Sum256(signedData) - signature, err := audienceKey.Sign(rand.Reader, signedHash[:], crypto.SHA256) - if err != nil { - return nil, err - } - - return &audienceVerificationData{ - Audience: biscuit.Symbol(audience), - Challenge: challenge, - Signature: signature, - }, nil -} - -func verifyAudienceSignature(audiencePubkey *kmssign.Key, data *audienceVerificationData) error { - signedData := append(signStaticCtx, data.Challenge...) - signedData = append(signedData, []byte(data.Audience)...) - hash := sha256.Sum256(signedData) - if !audiencePubkey.Verify(hash[:], data.Signature) { - return errors.New("invalid signature") - } - return nil -} - -func validatePKIXP256PublicKey(pubkey []byte) error { - key, err := x509.ParsePKIXPublicKey(pubkey) - if err != nil { - return fmt.Errorf("failed to parse PKIX, ASN.1 DER public key: %v", err) - } - - ecKey, ok := key.(*ecdsa.PublicKey) - if !ok { - return errors.New("public key is not an *ecdsa.PublicKey") - } - - if ecKey.Curve != elliptic.P256() { - return fmt.Errorf("publickey is on wrong curve, expected P256") - } - - return nil -} diff --git a/pkg/biscuit/signature_test.go b/pkg/biscuit/signature_test.go deleted file mode 100644 index 2351e26..0000000 --- a/pkg/biscuit/signature_test.go +++ /dev/null @@ -1,230 +0,0 @@ -package biscuit - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "testing" - "time" - - "github.com/flynn/biscuit-go" - "github.com/stretchr/testify/require" -) - -func TestUserSignVerify(t *testing.T) { - tokenHash := make([]byte, 32) - _, err := rand.Read(tokenHash) - require.NoError(t, err) - - challenge := make([]byte, challengeSize) - _, err = rand.Read(challenge) - require.NoError(t, err) - - userKey := generateUserKeyPair(t) - - toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - } - - signedData, err := userSign(tokenHash, userKey, toSignData) - require.NoError(t, err) - require.NotEmpty(t, signedData.Signature) - require.Equal(t, biscuit.Integer(1), signedData.DataID) - require.Equal(t, biscuit.Bytes(userKey.Public), signedData.UserPubKey) - - require.Len(t, signedData.Nonce, nonceSize) - zeroNonce := make([]byte, nonceSize) - require.NotEqual(t, biscuit.Bytes(zeroNonce), signedData.Nonce) - - require.WithinDuration(t, time.Now(), time.Time(signedData.Timestamp), 1*time.Second) - - require.NoError(t, verifyUserSignature(tokenHash, &userVerificationData{ - DataID: toSignData.DataID, - Alg: toSignData.Alg, - Data: toSignData.Data, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - UserPubKey: signedData.UserPubKey, - })) -} - -func TestUserSignFail(t *testing.T) { - validTokenHash := make([]byte, 32) - _, err := rand.Read(validTokenHash) - require.NoError(t, err) - - validChallenge := make([]byte, challengeSize) - _, err = rand.Read(validChallenge) - require.NoError(t, err) - - invalidPrivateKey := &UserKeyPair{ - Private: make([]byte, 32), - } - - testCases := []struct { - desc string - tokenHash []byte - userKey *UserKeyPair - data *userToSignData - expectedErr error - }{ - { - desc: "empty tokenHash", - tokenHash: []byte{}, - }, - { - desc: "unsupported alg", - tokenHash: validTokenHash, - data: &userToSignData{ - Alg: "unsupported", - }, - expectedErr: ErrUnsupportedSignatureAlg, - }, - { - desc: "wrong private key encoding", - tokenHash: validTokenHash, - data: &userToSignData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - }, - userKey: invalidPrivateKey, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.desc, func(t *testing.T) { - _, err := userSign(testCase.tokenHash, testCase.userKey, testCase.data) - require.Error(t, err) - if testCase.expectedErr != nil { - require.Equal(t, testCase.expectedErr, err) - } - }) - } -} - -func TestVerifyUserSignatureFail(t *testing.T) { - tokenHash := []byte("token hash") - toSignData := &userToSignData{ - DataID: 1, - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - Data: []byte("challenge"), - } - - userKey := generateUserKeyPair(t) - invalidKey := generateUserKeyPair(t) - - signedData, err := userSign(tokenHash, userKey, toSignData) - require.NoError(t, err) - - rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) - require.NoError(t, err) - wrongKeyKind, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey) - require.NoError(t, err) - - testCases := []struct { - desc string - tokenHash []byte - data *userVerificationData - expectedErr error - }{ - { - desc: "unsupported alg", - expectedErr: ErrUnsupportedSignatureAlg, - data: &userVerificationData{ - Alg: "unknown", - }, - }, - { - desc: "invalid pubkey encoding", - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: make([]byte, 32), - }, - }, - { - desc: "invalid pubkey kind", - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: wrongKeyKind, - }, - }, - { - desc: "wrong pubkey", - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: invalidKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - }, - }, - { - desc: "tampered token hash", - expectedErr: ErrInvalidSignature, - tokenHash: []byte("wrong"), - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - }, - }, - { - desc: "tampered nonce", - expectedErr: ErrInvalidSignature, - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: []byte("another nonce"), - Signature: signedData.Signature, - Timestamp: signedData.Timestamp, - }, - }, - { - desc: "tampered timestamp", - expectedErr: ErrInvalidSignature, - tokenHash: tokenHash, - data: &userVerificationData{ - Alg: biscuit.Symbol(ECDSA_P256_SHA256), - UserPubKey: userKey.Public, - Data: toSignData.Data, - DataID: toSignData.DataID, - Nonce: signedData.Nonce, - Signature: signedData.Signature, - Timestamp: biscuit.Date(time.Now().Add(1 * time.Second)), - }, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.desc, func(t *testing.T) { - err := verifyUserSignature(testCase.tokenHash, testCase.data) - require.Error(t, err) - if testCase.expectedErr != nil { - require.Equal(t, testCase.expectedErr, err) - } - }) - } -} - -func generateUserKeyPair(t *testing.T) *UserKeyPair { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - kp, err := NewECDSAKeyPair(priv) - require.NoError(t, err) - return kp -} diff --git a/pkg/biscuit/wrapper.go b/pkg/biscuit/wrapper.go deleted file mode 100644 index f39681f..0000000 --- a/pkg/biscuit/wrapper.go +++ /dev/null @@ -1,426 +0,0 @@ -package biscuit - -import ( - "bytes" - "crypto/rand" - "errors" - "fmt" - "time" - - "github.com/flynn/biscuit-go" - "github.com/flynn/biscuit-go/datalog" - "github.com/flynn/hubauth/pkg/kmssign" -) - -var ( - ErrAlreadySigned = errors.New("already signed") - ErrInvalidToSignDataPrefix = errors.New("invalid to_sign data prefix") -) - -var ( - signStaticCtx = []byte("biscuit-pop-v0") - challengeSize = 16 - nonceSize = 16 -) - -type hubauthBuilder struct { - biscuit.Builder -} - -// withUserToSignFact add an authority should_sign fact and associated data to the biscuit -// with an authority caveat requiring the verifier to provide a valid_signature fact. -// the verifier is responsible of ensuring that a valid signature exists over the data. -func (b *hubauthBuilder) withUserToSignFact(userPubkey []byte) error { - dataID := biscuit.Integer(0) - - if err := validatePKIXP256PublicKey(userPubkey); err != nil { - return err - } - - if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "should_sign", - IDs: []biscuit.Atom{ - dataID, - biscuit.Symbol(ECDSA_P256_SHA256), - biscuit.Bytes(userPubkey), - }, - }}); err != nil { - return err - } - - challenge := make([]byte, challengeSize) - if _, err := rand.Reader.Read(challenge); err != nil { - return err - } - - if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "data", - IDs: []biscuit.Atom{ - dataID, - biscuit.Bytes(append(signStaticCtx, challenge...)), - }, - }}); err != nil { - return err - } - - if err := b.AddAuthorityCaveat(biscuit.Rule{ - Head: biscuit.Predicate{Name: "valid", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - {Name: "valid_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - }, - }); err != nil { - return err - } - - return nil -} - -// withAudienceSignature add an authority audience_signature fact, containing a challenge and -// a matching signature using the audience key. -// the verifier is responsible of providing a valid_audience_signature fact, after -// verifying the signature using the audience pubkey. -func (b *hubauthBuilder) withAudienceSignature(audience string, audienceKey *kmssign.Key) error { - if len(audience) == 0 { - return errors.New("audience is required") - } - - data, err := audienceSign(audience, audienceKey) - if err != nil { - return err - } - - if err := b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "audience_signature", - IDs: []biscuit.Atom{ - data.Audience, - data.Challenge, - data.Signature, - }, - }}); err != nil { - return err - } - - if err := b.AddAuthorityCaveat(biscuit.Rule{ - Head: biscuit.Predicate{Name: "valid_audience", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - {Name: "valid_audience_signature", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0), biscuit.Variable(2)}}, - }, - }); err != nil { - return err - } - - return nil -} - -func (b *hubauthBuilder) withMetadata(m *Metadata) error { - return b.AddAuthorityFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "metadata", - IDs: []biscuit.Atom{ - biscuit.String(m.ClientID), - biscuit.String(m.UserID), - biscuit.String(m.UserEmail), - biscuit.Date(m.IssueTime), - }, - }}) -} - -func (b *hubauthBuilder) withExpire(exp time.Time) error { - if err := b.AddAuthorityCaveat(biscuit.Rule{ - Head: biscuit.Predicate{Name: "not_expired", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "current_time", IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.Variable(0)}}, - }, - Constraints: []biscuit.Constraint{{ - Name: biscuit.Variable(0), - Checker: biscuit.DateComparisonChecker{ - Comparison: datalog.DateComparisonBefore, - Date: biscuit.Date(exp), - }, - }}, - }); err != nil { - return err - } - - return nil -} - -type hubauthBlockBuilder struct { - biscuit.BlockBuilder -} - -func (b *hubauthBlockBuilder) withUserSignature(sigData *userSignatureData) error { - return b.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "signature", - IDs: []biscuit.Atom{ - sigData.DataID, - sigData.UserPubKey, - sigData.Signature, - sigData.Nonce, - sigData.Timestamp, - }, - }}) -} - -type hubauthVerifier struct { - biscuit.Verifier -} - -func (v *hubauthVerifier) getUserToSignData(userPubKey biscuit.Bytes) (*userToSignData, error) { - toSign, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "to_sign", - IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}, - }, - Body: []biscuit.Predicate{ - { - Name: "should_sign", IDs: []biscuit.Atom{ - biscuit.SymbolAuthority, - biscuit.Variable(0), - biscuit.Variable(1), - biscuit.Bytes(userPubKey), - }, - }, { - Name: "data", IDs: []biscuit.Atom{ - biscuit.SymbolAuthority, - biscuit.Variable(0), - biscuit.Variable(2), - }, - }, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(toSign), 1; g != w { - return nil, fmt.Errorf("invalid to_sign fact count, got %d, want %d", g, w) - } - - toSignFact := toSign[0] - if g, w := len(toSignFact.IDs), 3; g != w { - return nil, fmt.Errorf("invalid to_sign fact, got %d atoms, want %d", g, w) - } - - sigData := &userToSignData{} - var ok bool - sigData.DataID, ok = toSign[0].IDs[0].(biscuit.Integer) - if !ok { - return nil, errors.New("invalid to_sign atom: dataID") - } - sigData.Alg, ok = toSign[0].IDs[1].(biscuit.Symbol) - if !ok { - return nil, errors.New("invalid to_sign atom: alg") - } - sigData.Data, ok = toSign[0].IDs[2].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_sign atom: data") - } - - if !bytes.HasPrefix(sigData.Data, signStaticCtx) { - return nil, ErrInvalidToSignDataPrefix - } - - return sigData, nil -} - -func (v *hubauthVerifier) ensureNotAlreadyUserSigned(dataID biscuit.Integer, userPubKey biscuit.Bytes) error { - alreadySigned, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{Name: "already_signed", IDs: []biscuit.Atom{biscuit.Variable(0)}}, - Body: []biscuit.Predicate{ - {Name: "signature", IDs: []biscuit.Atom{dataID, userPubKey, biscuit.Variable(0)}}, - }, - }) - if err != nil { - return err - } - if len(alreadySigned) != 0 { - return ErrAlreadySigned - } - - return nil -} - -func (v *hubauthVerifier) getUserVerificationData() (*userVerificationData, error) { - toValidate, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "to_validate", - IDs: []biscuit.Atom{ - biscuit.Variable(0), // dataID - biscuit.Variable(1), // alg - biscuit.Variable(2), // pubkey - biscuit.Variable(3), // data - biscuit.Variable(4), // signature - biscuit.Variable(5), // signerNonce - biscuit.Variable(6), // signerTimestamp - }}, - Body: []biscuit.Predicate{ - {Name: "should_sign", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2)}}, - {Name: "data", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(3)}}, - {Name: "signature", IDs: []biscuit.Atom{biscuit.Variable(0), biscuit.Variable(2), biscuit.Variable(4), biscuit.Variable(5), biscuit.Variable(6)}}, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(toValidate), 1; g != w { - return nil, fmt.Errorf("invalid to_validate fact count, got %d, want %d", g, w) - } - - toValidateFact := toValidate[0] - if g, w := len(toValidateFact.IDs), 7; g != w { - return nil, fmt.Errorf("invalid to_valid fact atom count, got %d, want %d", g, w) - } - - toVerify := &userVerificationData{} - var ok bool - toVerify.DataID, ok = toValidateFact.IDs[0].(biscuit.Integer) - if !ok { - return nil, errors.New("invalid to_validate atom: dataID") - } - toVerify.Alg, ok = toValidateFact.IDs[1].(biscuit.Symbol) - if !ok { - return nil, errors.New("invalid to_validate atom: alg") - } - toVerify.UserPubKey, ok = toValidateFact.IDs[2].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: userPubKey") - } - toVerify.Data, ok = toValidateFact.IDs[3].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: data") - } - toVerify.Signature, ok = toValidateFact.IDs[4].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: signature") - } - toVerify.Nonce, ok = toValidateFact.IDs[5].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid to_validate atom: nonce") - } - toVerify.Timestamp, ok = toValidateFact.IDs[6].(biscuit.Date) - if !ok { - return nil, errors.New("invalid to_validate atom: timestamp") - } - - return toVerify, nil -} - -func (v *hubauthVerifier) withValidatedUserSignature(data *userVerificationData) error { - v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "valid_signature", - IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.DataID, data.Alg, data.UserPubKey}, - }}) - - return nil -} - -func (v *hubauthVerifier) getAudienceVerificationData(audience string) (*audienceVerificationData, error) { - toValidate, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "audience_to_validate", - IDs: []biscuit.Atom{ - biscuit.Variable(0), // challenge - biscuit.Variable(1), // signature - }}, - Body: []biscuit.Predicate{ - {Name: "audience_signature", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Symbol(audience), biscuit.Variable(0), biscuit.Variable(1)}}, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(toValidate), 1; g != w { - return nil, fmt.Errorf("invalid audience_to_validate fact count, got %d, want %d", g, w) - } - - toValidateFact := toValidate[0] - if g, w := len(toValidateFact.IDs), 2; g != w { - return nil, fmt.Errorf("invalid audience_to_validate fact atom count, got %d, want %d", g, w) - } - - toVerify := &audienceVerificationData{Audience: biscuit.Symbol(audience)} - var ok bool - toVerify.Challenge, ok = toValidateFact.IDs[0].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid audience_to_validate atom: challenge") - } - toVerify.Signature, ok = toValidateFact.IDs[1].(biscuit.Bytes) - if !ok { - return nil, errors.New("invalid audience_to_validate atom: signature") - } - - return toVerify, nil -} - -func (v *hubauthVerifier) getMetadata() (*Metadata, error) { - metaFacts, err := v.Query(biscuit.Rule{ - Head: biscuit.Predicate{ - Name: "metadata", - IDs: []biscuit.Atom{ - biscuit.Variable(0), // clientID - biscuit.Variable(1), // userID - biscuit.Variable(2), // userEmail - biscuit.Variable(3), // issueTime - }}, - Body: []biscuit.Predicate{ - {Name: "metadata", IDs: []biscuit.Atom{biscuit.SymbolAuthority, biscuit.Variable(0), biscuit.Variable(1), biscuit.Variable(2), biscuit.Variable(3)}}, - }, - }) - if err != nil { - return nil, err - } - - if g, w := len(metaFacts), 1; g != w { - return nil, fmt.Errorf("invalid metadata fact count, got %d, want %d", g, w) - } - - metaFact := metaFacts[0] - - clientID, ok := metaFact.IDs[0].(biscuit.String) - if !ok { - return nil, errors.New("invalid metadata atom: clientID") - } - userID, ok := metaFact.IDs[1].(biscuit.String) - if !ok { - return nil, errors.New("invalid metadata atom: userID") - } - userEmail, ok := metaFact.IDs[2].(biscuit.String) - if !ok { - return nil, errors.New("invalid metadata atom: userEmail") - } - issueTime, ok := metaFact.IDs[3].(biscuit.Date) - if !ok { - return nil, errors.New("invalid metadata atom: issueTime") - } - return &Metadata{ - ClientID: string(clientID), - UserID: string(userID), - UserEmail: string(userEmail), - IssueTime: time.Time(issueTime), - }, nil -} - -func (v *hubauthVerifier) withValidatedAudienceSignature(data *audienceVerificationData) error { - v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "valid_audience_signature", - IDs: []biscuit.Atom{biscuit.Symbol("ambient"), data.Audience, data.Signature}, - }}) - - return nil -} - -func (v *hubauthVerifier) withCurrentTime(t time.Time) error { - v.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ - Name: "current_time", - IDs: []biscuit.Atom{ - biscuit.Symbol("ambient"), - biscuit.Date(t), - }, - }}) - - return nil -} From b49171734ab2843235dd7a8738065c8baf44a4a2 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 10 Dec 2020 14:02:48 +0100 Subject: [PATCH 10/24] lint fix --- pkg/cli/clients_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/cli/clients_test.go b/pkg/cli/clients_test.go index 687a7a8..3a14ef5 100644 --- a/pkg/cli/clients_test.go +++ b/pkg/cli/clients_test.go @@ -215,7 +215,7 @@ func TestClientUpdateCmd(t *testing.T) { } cfg := &Config{DB: &mockClientDatastore{}} expectedMutations := []*hubauth.ClientMutation{ - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpSetRefreshTokenExpiry, RefreshTokenExpiry: 5 * time.Minute, }, @@ -232,19 +232,19 @@ func TestClientUpdateCmd(t *testing.T) { } cfg := &Config{DB: &mockClientDatastore{}} expectedMutations := []*hubauth.ClientMutation{ - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpSetRefreshTokenExpiry, RefreshTokenExpiry: 5 * time.Minute, }, - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpAddRedirectURI, RedirectURI: "http://localhost:1234", }, - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpAddRedirectURI, RedirectURI: "http://localhost:5678", }, - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpDeleteRedirectURI, RedirectURI: "http://removed-domain:1234", }, From 12c39aac7c51ba26bc4cb388b15d8a65099849df Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 10 Dec 2020 18:08:18 +0100 Subject: [PATCH 11/24] add policy parser and printer --- go.sum | 5 + pkg/policy/parser.go | 77 ++++++ pkg/policy/parser_test.go | 190 ++++++++++++++ pkg/policy/printer.go | 233 ++++++++++++++++++ pkg/policy/printer_test.go | 28 +++ pkg/policy/testdata/printer/comments.golden | 40 +++ .../testdata/printer/empty_policy.golden | 1 + pkg/policy/testdata/printer/multiple.golden | 72 ++++++ .../testdata/printer/single copy.golden | 24 ++ 9 files changed, 670 insertions(+) create mode 100644 pkg/policy/parser.go create mode 100644 pkg/policy/parser_test.go create mode 100644 pkg/policy/printer.go create mode 100644 pkg/policy/printer_test.go create mode 100644 pkg/policy/testdata/printer/comments.golden create mode 100644 pkg/policy/testdata/printer/empty_policy.golden create mode 100644 pkg/policy/testdata/printer/multiple.golden create mode 100644 pkg/policy/testdata/printer/single copy.golden diff --git a/go.sum b/go.sum index d812a10..aaa4022 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,11 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/alecthomas/kong v0.2.12 h1:X3kkCOXGUNzLmiu+nQtoxWqj4U2a39MpSJR3QdQXOwI= github.com/alecthomas/kong v0.2.12/go.mod h1:kQOmtJgV+Lb4aj+I2LEn40cbtawdWJ9Y8QLq+lElKxE= github.com/alecthomas/participle v0.6.0/go.mod h1:HfdmEuwvr12HXQN44HPWXR0lHmVolVYe4dyL6lQ3duY= +github.com/alecthomas/participle v0.7.1 h1:2bN7reTw//5f0cugJcTOnY/NYZcWQOaajW+BwZB5xWs= +github.com/alecthomas/participle/v2 v2.0.0-alpha3 h1:7aeHdGgRXADjrDEHwCpXiMMZqppOw2dpQfmVTyBN5cY= +github.com/alecthomas/participle/v2 v2.0.0-alpha3/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= +github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 h1:DCGcCFtR/4YWEOoszqekJRdDoq41G+btPdOSWf5FoSo= +github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/aws/aws-sdk-go v1.23.20 h1:2CBuL21P0yKdZN5urf2NxKa1ha8fhnY+A3pBCHFeZoA= github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= diff --git a/pkg/policy/parser.go b/pkg/policy/parser.go new file mode 100644 index 0000000..6b1d588 --- /dev/null +++ b/pkg/policy/parser.go @@ -0,0 +1,77 @@ +package policy + +import ( + "fmt" + "io" + + "github.com/alecthomas/participle/v2" + "github.com/alecthomas/participle/v2/lexer/stateful" + "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/parser" +) + +var defaultParserOptions = append(parser.DefaultParserOptions, participle.Lexer(policyLexer)) + +var policyLexer = stateful.MustSimple(append( + parser.BiscuitLexerRules, + stateful.Rule{Name: "Policy", Pattern: `policy`}, +)) + +type Document struct { + Policies []*DocumentPolicy `@@*` +} + +type DocumentPolicy struct { + Comments []*parser.Comment `@Comment*` + Name *string `"policy" @String "{"` + Rules []*parser.Rule `("rules" "{" @@* "}")?` + Caveats []*parser.Caveat `("caveats" "{" (@@ ("," @@+)*)* "}")? "}"` +} + +func (d *DocumentPolicy) BiscuitRules() ([]biscuit.Rule, error) { + rules := make([]biscuit.Rule, 0, len(d.Rules)) + for _, r := range d.Rules { + rule, err := r.ToBiscuit() + if err != nil { + return nil, err + } + rules = append(rules, *rule) + } + return rules, nil +} + +func (d *DocumentPolicy) BiscuitCaveats() ([]biscuit.Caveat, error) { + caveats := make([]biscuit.Caveat, 0, len(d.Caveats)) + for _, c := range d.Caveats { + caveat, err := c.ToBiscuit() + if err != nil { + return nil, err + } + + caveats = append(caveats, *caveat) + } + + return caveats, nil +} + +var documentParser = participle.MustBuild(&Document{}, defaultParserOptions...) + +func Parse(r io.Reader) (*Document, error) { + return ParseNamed("policy", r) +} + +func ParseNamed(filename string, r io.Reader) (*Document, error) { + parsed := &Document{} + if err := documentParser.Parse(filename, r, parsed); err != nil { + return nil, err + } + + policies := make(map[string]DocumentPolicy, len(parsed.Policies)) + for _, p := range parsed.Policies { + if _, exists := policies[*p.Name]; exists { + return nil, fmt.Errorf("parse error: duplicate policy %q", *p.Name) + } + } + + return parsed, nil +} diff --git a/pkg/policy/parser_test.go b/pkg/policy/parser_test.go new file mode 100644 index 0000000..be702b4 --- /dev/null +++ b/pkg/policy/parser_test.go @@ -0,0 +1,190 @@ +package policy + +import ( + "strings" + "testing" + + "github.com/flynn/biscuit-go/parser" + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + definition := ` + // admin policy comment + policy "admin" { + rules { + // rule 1 comment + *authorized($0) + <- namespace(#ambient, $0) + @ prefix($0, "demo.v1") + } + caveats {[ + // caveat 1 comment + *caveat0($0) <- authorized($0) + ]} + } + + policy "developer" { + rules { + *authorized("demo.v1.Account", $1) + <- namespace(#ambient, "demo.v1.Account"), + method(#ambient, $1), + arg(#ambient, "env", $2) + @ $1 in ["Create", "Read", "Update"], + $2 in ["DEV", "STAGING"] + *authorized("demo.v1.Account", "Read") + <- namespace(#ambient, "demo.v1.Account"), + method(#ambient, "Read"), + arg(#ambient, "env", "PROD") + } + caveats { + [*caveat1($1) <- authorized("demo.v1.Account", $1)] + } + } + + policy "auditor" { + rules { + *authorized("demo.v1.Account", "Read") + <- namespace(#ambient, "demo.v1.Account"), + method(#ambient, "Read"), + arg(#ambient, "env", "DEV") + } + caveats { + [*caveat2("Read") <- authorized("demo.v1.Account", "Read")] + } + } + ` + + doc, err := Parse(strings.NewReader(definition)) + require.NoError(t, err) + + expectedPolicies := &Document{ + Policies: []*DocumentPolicy{{ + Name: sptr("admin"), + Comments: []*parser.Comment{commentptr("admin policy comment")}, + Rules: []*parser.Rule{ + { + Comments: []*parser.Comment{commentptr("rule 1 comment")}, + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{{Variable: varptr("0")}}}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {Variable: varptr("0")}}}, + }, + Constraints: []*parser.Constraint{ + {FunctionConstraint: &parser.FunctionConstraint{ + Function: sptr("prefix"), + Variable: varptr("0"), + Argument: sptr("demo.v1"), + }}, + }, + }, + }, + Caveats: []*parser.Caveat{{Queries: []*parser.Rule{ + { + Comments: []*parser.Comment{commentptr("caveat 1 comment")}, + Head: &parser.Predicate{Name: sptr("caveat0"), IDs: []*parser.Atom{{Variable: varptr("0")}}}, + Body: []*parser.Predicate{ + {Name: sptr("authorized"), IDs: []*parser.Atom{{Variable: varptr("0")}}}, + }, + }, + }}}, + }, + { + Name: sptr("developer"), + Rules: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{ + {String: sptr("demo.v1.Account")}, + {Variable: varptr("1")}, + }}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("demo.v1.Account")}}}, + {Name: sptr("method"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {Variable: varptr("1")}}}, + {Name: sptr("arg"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("env")}, {Variable: varptr("2")}}}, + }, + Constraints: []*parser.Constraint{ + { + VariableConstraint: &parser.VariableConstraint{ + Variable: varptr("1"), + Set: &parser.Set{ + Not: false, + String: []string{"Create", "Read", "Update"}, + }, + }, + }, + { + VariableConstraint: &parser.VariableConstraint{ + Variable: varptr("2"), + Set: &parser.Set{ + Not: false, + String: []string{"DEV", "STAGING"}, + }, + }, + }, + }, + }, + { + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {String: sptr("Read")}}}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("demo.v1.Account")}}}, + {Name: sptr("method"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("Read")}}}, + {Name: sptr("arg"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("env")}, {String: sptr("PROD")}}}, + }, + }, + }, + Caveats: []*parser.Caveat{{Queries: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("caveat1"), IDs: []*parser.Atom{{Variable: varptr("1")}}}, + Body: []*parser.Predicate{ + {Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {Variable: varptr("1")}}}, + }, + }, + }}}, + }, + { + Name: sptr("auditor"), + Rules: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {String: sptr("Read")}}}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("demo.v1.Account")}}}, + {Name: sptr("method"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("Read")}}}, + {Name: sptr("arg"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("env")}, {String: sptr("DEV")}}}, + }, + }, + }, + Caveats: []*parser.Caveat{{Queries: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("caveat2"), IDs: []*parser.Atom{{String: sptr("Read")}}}, + Body: []*parser.Predicate{ + {Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {String: sptr("Read")}}}, + }, + }, + }}}, + }, + }, + } + + require.Equal(t, len(expectedPolicies.Policies), len(doc.Policies)) + for i, expectedPolicy := range expectedPolicies.Policies { + require.Equal(t, doc.Policies[i], expectedPolicy) + } +} + +func sptr(s string) *string { + return &s +} + +func symptr(s string) *parser.Symbol { + sym := parser.Symbol(s) + return &sym +} + +func varptr(s string) *parser.Variable { + v := parser.Variable(s) + return &v +} + +func commentptr(s string) *parser.Comment { + c := parser.Comment(s) + return &c +} diff --git a/pkg/policy/printer.go b/pkg/policy/printer.go new file mode 100644 index 0000000..25433cd --- /dev/null +++ b/pkg/policy/printer.go @@ -0,0 +1,233 @@ +package policy + +import ( + "fmt" + "strings" + + "github.com/flynn/biscuit-go/parser" +) + +func Print(d *Document) (string, error) { + p := &printer{ + indent: 0, + out: &strings.Builder{}, + } + + for i, policy := range d.Policies { + p.printPolicy(policy) + if i != len(d.Policies)-1 { + p.write("\n") + } + } + + return p.out.String(), nil +} + +func PrintPolicy(policy *DocumentPolicy) string { + p := &printer{ + indent: 0, + out: &strings.Builder{}, + } + + p.printPolicy(policy) + + return p.out.String() +} + +type printer struct { + indent int + out *strings.Builder +} + +func (p *printer) write(format string, args ...interface{}) { + format = strings.ReplaceAll(format, "\n", "\n"+strings.Repeat(" ", p.indent)) + p.out.WriteString(fmt.Sprintf(format, args...)) +} + +func (p *printer) printPolicy(policy *DocumentPolicy) { + for _, c := range policy.Comments { + p.write("// %s\n", *c) + } + + p.write("policy %q {", *policy.Name) + + if len(policy.Rules) > 0 { + p.indent++ + p.write("\nrules {") + p.indent++ + for _, r := range policy.Rules { + p.write("\n") + p.printRule(r) + } + p.indent-- + p.write("\n") + p.indent-- + p.write("}\n") + } + + if len(policy.Caveats) > 0 { + p.indent++ + p.write("\ncaveats {") + for i, c := range policy.Caveats { + p.indent++ + p.printCaveat(c) + if i != len(policy.Caveats)-1 { + p.write(", ") + } + } + p.indent-- + p.write("}\n") + } + + p.write("}\n") +} + +func (p *printer) printRule(rule *parser.Rule) { + for _, c := range rule.Comments { + p.write("// %s\n", *c) + } + + p.write("*") + p.printPredicate(rule.Head) + p.indent++ + p.write("\n") + + for i, b := range rule.Body { + if i == 0 { + p.write("<- ") + } else { + p.write(" ") + } + p.printPredicate(b) + if i != len(rule.Body)-1 { + p.write(",\n") + } + } + + if len(rule.Constraints) > 0 { + p.write("\n") + } + + for i, c := range rule.Constraints { + if i == 0 { + p.write("@ ") + } else { + p.write(" ") + } + p.printConstraint(c) + if i != len(rule.Constraints)-1 { + p.write(",\n") + } + } + p.indent-- +} + +func (p *printer) printCaveat(c *parser.Caveat) { + p.write("[\n") + for j, r := range c.Queries { + if j != 0 { + p.write("||") + p.indent++ + p.write("\n") + } + p.printRule(r) + if j != len(c.Queries)-1 { + p.indent-- + p.write("\n") + } + } + p.indent-- + p.write("\n]") +} + +func (p *printer) printPredicate(pred *parser.Predicate) { + p.write("%s(%s)", *pred.Name, strings.Join(atomsToString(pred.IDs), ", ")) +} + +func (p *printer) printConstraint(c *parser.Constraint) { + switch { + case c.FunctionConstraint != nil: + p.printFunctionConstraint(c.FunctionConstraint) + case c.VariableConstraint != nil: + p.printVariableConstraint(c.VariableConstraint) + } +} + +func (p *printer) printFunctionConstraint(c *parser.FunctionConstraint) { + p.write("%s($%s, %q)", *c.Function, *c.Variable, *c.Argument) +} + +func (p *printer) printVariableConstraint(c *parser.VariableConstraint) { + var op, target string + switch { + case c.Bytes != nil: + op = *c.Bytes.Operation + target = c.Bytes.Target.String() + case c.Date != nil: + op = *c.Date.Operation + target = fmt.Sprintf("%q", *c.Date.Target) + case c.Int != nil: + op = *c.Int.Operation + target = fmt.Sprintf("%d", *c.Int.Target) + case c.Set != nil: + op = "in" + if c.Set.Not { + op = "not in" + } + + switch { + case c.Set.Bytes != nil: + members := make([]string, 0, len(c.Set.Bytes)) + for _, b := range c.Set.Bytes { + members = append(members, b.String()) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + case c.Set.Int != nil: + members := make([]string, 0, len(c.Set.Int)) + for _, i := range c.Set.Int { + members = append(members, fmt.Sprintf("%d", i)) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + case c.Set.String != nil: + members := make([]string, 0, len(c.Set.String)) + for _, s := range c.Set.String { + members = append(members, fmt.Sprintf("%q", s)) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + case c.Set.Symbols != nil: + members := make([]string, 0, len(c.Set.Symbols)) + for _, s := range c.Set.Symbols { + members = append(members, fmt.Sprintf("#%s", s)) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + } + case c.String != nil: + op = *c.String.Operation + target = fmt.Sprintf("%q", *c.String.Target) + } + p.write("$%s %s %s", *c.Variable, op, target) +} + +func atomsToString(atoms []*parser.Atom) []string { + out := make([]string, 0, len(atoms)) + for _, a := range atoms { + var atomStr string + switch { + case a.Bytes != nil: + atomStr = a.Bytes.String() + case a.Integer != nil: + atomStr = fmt.Sprintf("%d", *a.Integer) + case a.Set != nil: + atomStr = fmt.Sprintf("[%s]", strings.Join(atomsToString(a.Set), ", ")) + case a.String != nil: + atomStr = fmt.Sprintf("%q", *a.String) + case a.Symbol != nil: + atomStr = fmt.Sprintf("#%s", *a.Symbol) + case a.Variable != nil: + atomStr = fmt.Sprintf("$%s", *a.Variable) + } + + out = append(out, atomStr) + } + return out +} diff --git a/pkg/policy/printer_test.go b/pkg/policy/printer_test.go new file mode 100644 index 0000000..248b9cc --- /dev/null +++ b/pkg/policy/printer_test.go @@ -0,0 +1,28 @@ +package policy + +import ( + "io/ioutil" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrintTemplateGolden(t *testing.T) { + files, err := filepath.Glob("./testdata/printer/*.golden") + require.NoError(t, err) + + for _, f := range files { + src, err := ioutil.ReadFile(f) + require.NoError(t, err) + + golden := string(src) + d, err := Parse(strings.NewReader(golden)) + require.NoError(t, err) + + out, err := Print(d) + require.NoError(t, err) + require.Equal(t, golden, out) + } +} diff --git a/pkg/policy/testdata/printer/comments.golden b/pkg/policy/testdata/printer/comments.golden new file mode 100644 index 0000000..555490b --- /dev/null +++ b/pkg/policy/testdata/printer/comments.golden @@ -0,0 +1,40 @@ +// some comment +policy "developer" { + rules { + // comment this specific rule + // on multiple lines + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", "DEV") + @ $0 in ["Create", "Delete"] + } + + caveats {[ + // this caveat is required + *authorized($0) + <- allow_method(#authority, $0) + || + // this caveat is required too + *authorized($0) + <- allow_method(#authority, $0) + @ $0 == "method" + ], [ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} + +// some comment +policy "admin" { + rules { + // comment this specific rule + // on multiple lines + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + } +} diff --git a/pkg/policy/testdata/printer/empty_policy.golden b/pkg/policy/testdata/printer/empty_policy.golden new file mode 100644 index 0000000..c5767df --- /dev/null +++ b/pkg/policy/testdata/printer/empty_policy.golden @@ -0,0 +1 @@ +policy "test" {} diff --git a/pkg/policy/testdata/printer/multiple.golden b/pkg/policy/testdata/printer/multiple.golden new file mode 100644 index 0000000..6cfc860 --- /dev/null +++ b/pkg/policy/testdata/printer/multiple.golden @@ -0,0 +1,72 @@ +policy "admin" { + rules { + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0) + @ $0 in ["Status"] + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + || + *authorized($0) + <- method(#ambient, $0), + env(#ambient, $1) + @ $1 in ["DEV", "STG"] + ], [ + *authorized_server($2) + <- service(#ambient, $2) + @ prefix($2, "demo.api.v1") + ]} +} + +policy "auditor" { + caveats {[ + *allow_dev() + <- arg(#ambient, "env", "DEV") + ]} +} + +policy "developer" { + rules { + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", "DEV") + @ $0 in ["Create", "Delete", "Read", "Status", "Update"] + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", $1) + @ $0 in ["Read"], + $1 in ["DEV"] + *allow_method("Read") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Read"), + arg(#ambient, "env", "PRD"), + arg(#ambient, "entities.name", $3) + @ $3 in ["entity1", "entity2", "entity3"] + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} + +policy "guest" { + rules { + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} diff --git a/pkg/policy/testdata/printer/single copy.golden b/pkg/policy/testdata/printer/single copy.golden new file mode 100644 index 0000000..cbedf97 --- /dev/null +++ b/pkg/policy/testdata/printer/single copy.golden @@ -0,0 +1,24 @@ +policy "developer" { + rules { + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", "DEV") + @ $0 in ["Create", "Delete"] + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + || + *authorized($0) + <- allow_method(#authority, $0) + @ $0 == "method" + ], [ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} From 6a70d9563b8f21be702668850cdc79f111b6056b Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 10 Dec 2020 18:09:29 +0100 Subject: [PATCH 12/24] use newest biscuit cookbook --- pkg/idp/token/biscuit.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pkg/idp/token/biscuit.go b/pkg/idp/token/biscuit.go index 9c997be..fd89cfa 100644 --- a/pkg/idp/token/biscuit.go +++ b/pkg/idp/token/biscuit.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" + "github.com/flynn/biscuit-go" "github.com/flynn/biscuit-go/cookbook/signedbiscuit" "github.com/flynn/biscuit-go/sig" "github.com/flynn/hubauth/pkg/kmssign" @@ -43,7 +44,17 @@ func (b *biscuitBuilder) Build(ctx context.Context, audience string, t *AccessTo IssueTime: t.IssueTime, } - return signedbiscuit.GenerateSignable(b.rootKeyPair, audience, audienceKey, t.UserPublicKey, t.ExpireTime, meta) + builder := biscuit.NewBuilder(b.rootKeyPair) + builder, err := signedbiscuit.WithSignableFacts(builder, audience, audienceKey, t.UserPublicKey, t.ExpireTime, meta) + if err != nil { + return nil, err + } + + bisc, err := builder.Build() + if err != nil { + return nil, err + } + return bisc.Serialize() } func (b *biscuitBuilder) TokenType() string { From 5a5dbce81cb26829ad1a7e1b8e42e494c027b48e Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 10 Dec 2020 18:10:03 +0100 Subject: [PATCH 13/24] add biscuit policy store --- pkg/datastore/datastore.go | 1 + pkg/datastore/policy.go | 170 +++++++++++++++++++++++++++++++++++++ pkg/hubauth/data.go | 31 +++++++ 3 files changed, 202 insertions(+) create mode 100644 pkg/datastore/policy.go diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index b55aff4..7865c2d 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -21,6 +21,7 @@ const ( kindDomain = "GoogleDomain" kindCachedGroup = "CachedGoogleGroup" kindCachedGroupMember = "CachedGoogleGroupMember" + kindBiscuitPolicy = "BiscuitPolicy" ) func New(db *datastore.Client) hubauth.DataStore { diff --git a/pkg/datastore/policy.go b/pkg/datastore/policy.go new file mode 100644 index 0000000..5e19639 --- /dev/null +++ b/pkg/datastore/policy.go @@ -0,0 +1,170 @@ +package datastore + +import ( + "context" + "time" + + "cloud.google.com/go/datastore" + "github.com/flynn/hubauth/pkg/hubauth" + "go.opencensus.io/trace" + "golang.org/x/exp/errors/fmt" +) + +func buildBiscuitPolicy(p *hubauth.BiscuitPolicy) *biscuitPolicy { + now := time.Now() + return &biscuitPolicy{ + Content: p.Content, + Groups: p.Groups, + CreateTime: now, + UpdateTime: now, + } +} + +type biscuitPolicy struct { + ID *datastore.Key `datastore:"__key__"` + Content string + Groups []string + CreateTime time.Time + UpdateTime time.Time +} + +func (c *biscuitPolicy) Export() *hubauth.BiscuitPolicy { + return &hubauth.BiscuitPolicy{ + ID: c.ID.Encode(), + Content: c.Content, + Groups: c.Groups, + CreateTime: c.CreateTime, + UpdateTime: c.UpdateTime, + } +} + +func biscuitPolicyKey(id string) (*datastore.Key, error) { + k, err := datastore.DecodeKey(id) + if err != nil { + return nil, hubauth.ErrNotFound + } + if k.Kind != kindBiscuitPolicy { + return nil, hubauth.ErrNotFound + } + return k, nil +} + +func (s *service) GetBiscuitPolicy(ctx context.Context, id string) (*hubauth.BiscuitPolicy, error) { + ctx, span := trace.StartSpan(ctx, "datastore.GetBiscuitPolicy") + span.AddAttributes(trace.StringAttribute("biscuit_policy_id", id)) + defer span.End() + + k, err := biscuitPolicyKey(id) + if err != nil { + return nil, err + } + res := &biscuitPolicy{} + if err := s.db.Get(ctx, k, res); err != nil { + if err == datastore.ErrNoSuchEntity { + err = hubauth.ErrNotFound + } + return nil, fmt.Errorf("datastore: error fetching biscuit policy %s: %w", id, err) + } + return res.Export(), nil +} + +func (s *service) CreateBiscuitPolicy(ctx context.Context, policy *hubauth.BiscuitPolicy) (string, error) { + ctx, span := trace.StartSpan(ctx, "datastore.CreateBiscuitPolicy") + defer span.End() + + k, err := s.db.Put(ctx, datastore.IncompleteKey(kindBiscuitPolicy, nil), buildBiscuitPolicy(policy)) + if err != nil { + return "", fmt.Errorf("datastore: error creating biscuit policy: %w", err) + } + id := k.Encode() + span.AddAttributes(trace.StringAttribute("biscuit_policy_id", id)) + return id, nil +} + +func (s *service) MutateBiscuitPolicy(ctx context.Context, id string, mut []*hubauth.BiscuitPolicyMutation) error { + ctx, span := trace.StartSpan(ctx, "datastore.MutateBiscuitPolicy") + span.AddAttributes( + trace.StringAttribute("biscuit_policy_id", id), + trace.Int64Attribute("biscuit_policy_mutation_count", int64(len(mut))), + ) + defer span.End() + + k, err := biscuitPolicyKey(id) + if err != nil { + return err + } + _, err = s.db.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + policy := &biscuitPolicy{} + if err := tx.Get(k, policy); err != nil { + if err == datastore.ErrNoSuchEntity { + err = hubauth.ErrNotFound + } + return fmt.Errorf("datastore: error fetching biscuit policy %s: %w", id, err) + } + modified := false + outer: + for _, m := range mut { + switch m.Op { + case hubauth.BiscuitPolicyMutationOpAddGroup: + for _, g := range policy.Groups { + if g == m.Group { + continue outer + } + } + policy.Groups = append(policy.Groups, m.Group) + modified = true + case hubauth.BiscuitPolicyMutationOpDeleteGroup: + for i, u := range policy.Groups { + if u != m.Group { + continue + } + policy.Groups[i] = policy.Groups[len(policy.Groups)-1] + policy.Groups = policy.Groups[:len(policy.Groups)-1] + modified = true + } + default: + return fmt.Errorf("datastore: unknown biscuit policy mutation op %s", m.Op) + } + } + if !modified { + return nil + } + policy.UpdateTime = time.Now() + _, err := tx.Put(k, policy) + return err + }) + if err != nil { + return fmt.Errorf("datastore: error mutating biscuit policy %s: %w", id, err) + } + return nil +} + +func (s *service) ListBiscuitPolicies(ctx context.Context) ([]*hubauth.BiscuitPolicy, error) { + ctx, span := trace.StartSpan(ctx, "datastore.ListBiscuitPolicies") + defer span.End() + + var policies []*biscuitPolicy + if _, err := s.db.GetAll(ctx, datastore.NewQuery(kindBiscuitPolicy), &policies); err != nil { + return nil, fmt.Errorf("datastore: error listing biscuit policies: %w", err) + } + res := make([]*hubauth.BiscuitPolicy, len(policies)) + for i, c := range policies { + res[i] = c.Export() + } + return res, nil +} + +func (s *service) DeleteBiscuitPolicy(ctx context.Context, id string) error { + ctx, span := trace.StartSpan(ctx, "datastore.DeleteBiscuitPolicy") + span.AddAttributes(trace.StringAttribute("biscuit_policy_id", id)) + defer span.End() + + k, err := biscuitPolicyKey(id) + if err != nil { + return err + } + if err := s.db.Delete(ctx, k); err != nil { + return fmt.Errorf("datastore: error deleting biscuit policy %s: %w", id, err) + } + return nil +} diff --git a/pkg/hubauth/data.go b/pkg/hubauth/data.go index 628cba9..607f1b6 100644 --- a/pkg/hubauth/data.go +++ b/pkg/hubauth/data.go @@ -24,6 +24,7 @@ type DataStore interface { CodeStore RefreshTokenStore CachedGroupStore + BiscuitPolicyStore } type ClientStore interface { @@ -262,3 +263,33 @@ func GetClientInfo(ctx context.Context) *ClientInfo { } return res } + +type BiscuitPolicyStore interface { + GetBiscuitPolicy(ctx context.Context, id string) (*BiscuitPolicy, error) + CreateBiscuitPolicy(ctx context.Context, policy *BiscuitPolicy) (string, error) + MutateBiscuitPolicy(ctx context.Context, id string, mut []*BiscuitPolicyMutation) error + ListBiscuitPolicies(ctx context.Context) ([]*BiscuitPolicy, error) + DeleteBiscuitPolicy(ctx context.Context, id string) error +} + +type BiscuitPolicy struct { + ID string + Content string + Groups []string + CreateTime time.Time + UpdateTime time.Time +} + +type BiscuitPolicyMutationOp byte + +const ( + BiscuitPolicyMutationOpAddGroup BiscuitPolicyMutationOp = iota + BiscuitPolicyMutationOpDeleteGroup + BiscuitPolicyMutationOpSetContent +) + +type BiscuitPolicyMutation struct { + Op BiscuitPolicyMutationOp + + Group string +} From 76334e2ab48daafa23a63f8cbab229af8c9b7079 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 10 Dec 2020 18:10:24 +0100 Subject: [PATCH 14/24] add policy cli commands --- pkg/cli/cli.go | 1 + pkg/cli/policies.go | 219 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 pkg/cli/policies.go diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index c772034..0739544 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -17,4 +17,5 @@ type CLI struct { Clients clientsCmd `kong:"cmd,help='manage oauth clients'"` Audiences audiencesCmd `kong:"cmd,help='manage audiences'"` + Policies policiesCmd `kong:"cmd,help='manage policies'"` } diff --git a/pkg/cli/policies.go b/pkg/cli/policies.go new file mode 100644 index 0000000..e9f0a33 --- /dev/null +++ b/pkg/cli/policies.go @@ -0,0 +1,219 @@ +package cli + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "strings" + + "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/policy" + "github.com/jedib0t/go-pretty/v6/table" +) + +type policiesCmd struct { + New policiesNewCmd `kong:"cmd,help='dump a new empty policy document on stdout'"` + List policiesListCmd `kong:"cmd,help='list policies',default:'1'"` + Dump policiesDumpCmd `kong:"cmd,help='dump a policy content on stdout'"` + Validate policiesValidateCmd `kong:"cmd,help='validate a policy file'"` + Import policiesImportCmd `kong:"cmd,help='import a policy'"` + Update policiesUpdateCmd `kong:"cmd,help='update a policy'"` + Delete policiesDeleteCmd `kong:"cmd,help='delete a policy'"` +} + +type policiesNewCmd struct { + Filepath string `kong:"name='filepath',short='f',help='optionnal filepath where the policy is written (default: stdout)'"` +} + +func (c *policiesNewCmd) Run(cfg *Config) error { + template := `// This is a template policy + policy "dummy" { + rules { + // This is a dummy rule + *head($var1) + <- body1(#ambient, $name), + body2($value) + @ $name == "example" + } + + caveats {[ + // this is a dummy caveat + *head($var1) + <- body1(#ambient, $name), + body2($value) + @ $name == "example" + ]} + }` + + d, err := policy.Parse(strings.NewReader(template)) + if err != nil { + return err + } + + out, err := policy.Print(d) + if err != nil { + return err + } + + if c.Filepath != "" { + ioutil.WriteFile(c.Filepath, []byte(out), 0644) + fmt.Printf("written %s\n", c.Filepath) + return nil + } + + fmt.Println(out) + return nil +} + +type policiesListCmd struct{} + +func (c *policiesListCmd) Run(cfg *Config) error { + policies, err := cfg.DB.ListBiscuitPolicies(context.Background()) + if err != nil { + return err + } + t := table.NewWriter() + t.SetOutputMirror(os.Stdout) + t.AppendHeader(table.Row{"ID", "Description", "Groups", "CreateTime", "UpdateTime"}) + for _, p := range policies { + t.AppendRow(table.Row{p.ID, getPolicyFirstComment(p), p.Groups, p.CreateTime, p.UpdateTime}) + } + t.Render() + return nil +} + +type policiesDumpCmd struct { + PolicyIDs []string `kong:"name='policy-ids',help='comma separated policy IDs to dump (default: all)'"` +} + +func (c *policiesDumpCmd) Run(cfg *Config) error { + allPolicies := "" + if len(c.PolicyIDs) > 0 { + for _, id := range c.PolicyIDs { + p, err := cfg.DB.GetBiscuitPolicy(context.Background(), id) + if err != nil { + return err + } + + allPolicies += p.Content + } + } else { + policies, err := cfg.DB.ListBiscuitPolicies(context.Background()) + if err != nil { + return err + } + for _, p := range policies { + allPolicies += p.Content + } + } + + doc, err := policy.Parse(strings.NewReader(allPolicies)) + if err != nil { + return err + } + + out, err := policy.Print(doc) + if err != nil { + return err + } + fmt.Printf("%s\n", out) + return nil +} + +type policiesImportCmd struct { + Filepath string `kong:"required,name='filepath',short='f',help='a file containing policy definitions'"` +} + +func (c *policiesImportCmd) Run(cfg *Config) error { + content, err := ioutil.ReadFile(c.Filepath) + if err != nil { + return err + } + + doc, err := policy.Parse(strings.NewReader(string(content))) + if err != nil { + return err + } + + for _, p := range doc.Policies { + content := policy.PrintPolicy(p) + id, err := cfg.DB.CreateBiscuitPolicy(context.Background(), &hubauth.BiscuitPolicy{ + Content: string(content), + }) + if err != nil { + return err + } + + fmt.Printf("Imported policy %q: %s\n", *p.Name, id) + } + return nil +} + +type policiesValidateCmd struct { + Filepath string `kong:"required,name='filepath',short='f',help='a file containing policy definitions'"` +} + +func (c *policiesValidateCmd) Run(cfg *Config) error { + content, err := ioutil.ReadFile(c.Filepath) + if err != nil { + return err + } + + _, err = policy.ParseNamed(c.Filepath, strings.NewReader(string(content))) + if err != nil { + return err + } + + return nil +} + +type policiesUpdateCmd struct { + PolicyID string `kong:"required,name='policy-id',help='a policy ID to update'"` + AddGroups []string `kong:"name='add-groups',help='comma separated list of groups to add on the policy'"` + DeleteGroups []string `kong:"name='delete-groups',help='comma separated list of groups to delete from the policy'"` +} + +func (c *policiesUpdateCmd) Run(cfg *Config) error { + var mut []*hubauth.BiscuitPolicyMutation + for _, g := range c.AddGroups { + mut = append(mut, &hubauth.BiscuitPolicyMutation{ + Op: hubauth.BiscuitPolicyMutationOpAddGroup, + Group: g, + }) + } + for _, g := range c.DeleteGroups { + mut = append(mut, &hubauth.BiscuitPolicyMutation{ + Op: hubauth.BiscuitPolicyMutationOpDeleteGroup, + Group: g, + }) + } + if err := cfg.DB.MutateBiscuitPolicy(context.Background(), c.PolicyID, mut); err != nil { + return err + } + return nil +} + +type policiesDeleteCmd struct { + PolicyID string `kong:"required,name='policy-id',help='a policy ID to delete'"` +} + +func (c *policiesDeleteCmd) Run(cfg *Config) error { + return cfg.DB.DeleteBiscuitPolicy(context.Background(), c.PolicyID) +} + +// getPolicyFirstComment parse the policy content and returns the first policy +// comment line if it exists. On error or when not set, an empty string is returned. +func getPolicyFirstComment(p *hubauth.BiscuitPolicy) string { + doc, err := policy.Parse(strings.NewReader(p.Content)) + if err != nil { + return "" + } + if len(doc.Policies) == 0 { + return "" + } + if len(doc.Policies[0].Comments) == 0 { + return "" + } + return string(*doc.Policies[0].Comments[0]) +} From f0762f79374818b53052e564e5274b79a99ebe5a Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 10 Dec 2020 18:17:31 +0100 Subject: [PATCH 15/24] bump biscuit to tmp version --- go.mod | 2 ++ go.sum | 1 + 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index 5dde64c..a4aa488 100644 --- a/go.mod +++ b/go.mod @@ -29,3 +29,5 @@ require ( google.golang.org/grpc v1.34.0 google.golang.org/protobuf v1.25.0 ) + +replace github.com/flynn/biscuit-go => /home/flavien/workspace/flynn-ws/biscuit-go diff --git a/go.sum b/go.sum index aaa4022..03c558d 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,7 @@ github.com/alecthomas/participle/v2 v2.0.0-alpha3 h1:7aeHdGgRXADjrDEHwCpXiMMZqpp github.com/alecthomas/participle/v2 v2.0.0-alpha3/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 h1:DCGcCFtR/4YWEOoszqekJRdDoq41G+btPdOSWf5FoSo= github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= +github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 h1:GDQdwm/gAcJcLAKQQZGOJ4knlw+7rfEQQcmwTbt4p5E= github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/aws/aws-sdk-go v1.23.20 h1:2CBuL21P0yKdZN5urf2NxKa1ha8fhnY+A3pBCHFeZoA= github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= From 4c2850ec61558deb45a9393846564e07ebb1c498 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Fri, 11 Dec 2020 16:58:08 +0100 Subject: [PATCH 16/24] wip: wait for audience updates --- cmd/hubauth-ext/main.go | 2 +- go.sum | 4 -- pkg/idp/oauth.go | 58 +++++++++++++------------- pkg/idp/oauth_test.go | 18 ++++++--- pkg/idp/steps.go | 28 +++++++------ pkg/idp/steps_test.go | 25 +++++++----- pkg/idp/token/biscuit.go | 76 ++++++++++++++++++++++++++++++++--- pkg/idp/token/biscuit_test.go | 33 ++++++++++++++- pkg/idp/token/builder.go | 1 + pkg/policy/parser.go | 13 ++++++ 10 files changed, 191 insertions(+), 67 deletions(-) diff --git a/cmd/hubauth-ext/main.go b/cmd/hubauth-ext/main.go index fce743d..834646d 100644 --- a/cmd/hubauth-ext/main.go +++ b/cmd/hubauth-ext/main.go @@ -90,7 +90,7 @@ func main() { } rootPubKey = biscuitKey.Public().Bytes() - accessTokenBuilder = token.NewBiscuitBuilder(kmsClient, audienceKeyNamer, biscuitKey) + accessTokenBuilder = token.NewBiscuitBuilder(kmsClient, datastore.New(dsClient), audienceKeyNamer, biscuitKey) default: log.Fatalf("invalid TOKEN_TYPE, must be one of: Bearer, Biscuit") } diff --git a/go.sum b/go.sum index 03c558d..002e339 100644 --- a/go.sum +++ b/go.sum @@ -54,10 +54,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/alecthomas/kong v0.2.12 h1:X3kkCOXGUNzLmiu+nQtoxWqj4U2a39MpSJR3QdQXOwI= github.com/alecthomas/kong v0.2.12/go.mod h1:kQOmtJgV+Lb4aj+I2LEn40cbtawdWJ9Y8QLq+lElKxE= -github.com/alecthomas/participle v0.6.0/go.mod h1:HfdmEuwvr12HXQN44HPWXR0lHmVolVYe4dyL6lQ3duY= -github.com/alecthomas/participle v0.7.1 h1:2bN7reTw//5f0cugJcTOnY/NYZcWQOaajW+BwZB5xWs= -github.com/alecthomas/participle/v2 v2.0.0-alpha3 h1:7aeHdGgRXADjrDEHwCpXiMMZqppOw2dpQfmVTyBN5cY= -github.com/alecthomas/participle/v2 v2.0.0-alpha3/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 h1:DCGcCFtR/4YWEOoszqekJRdDoq41G+btPdOSWf5FoSo= github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 h1:GDQdwm/gAcJcLAKQQZGOJ4knlw+7rfEQQcmwTbt4p5E= diff --git a/pkg/idp/oauth.go b/pkg/idp/oauth.go index 33a9102..6317ec1 100644 --- a/pkg/idp/oauth.go +++ b/pkg/idp/oauth.go @@ -35,7 +35,7 @@ func (clockImpl) Now() time.Time { } type idpSteps interface { - VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) error + VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) ([]string, error) VerifyUserGroups(ctx context.Context, userID string) error CreateCode(ctx context.Context, code *hubauth.Code) (string, string, error) @@ -277,8 +277,10 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan return err }) - g.Go(func() error { - return s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, codeInfo.UserId) + var userGroups []string + g.Go(func() (err error) { + userGroups, err = s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, codeInfo.UserId) + return err }) var client *hubauth.Client @@ -303,19 +305,20 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan return err }) + if err := g.Wait(); err != nil { + return nil, err + } + + // build access token var accessToken string var tokenType string - g.Go(func() (err error) { - if req.Audience == "" { - return nil - } - + if req.Audience != "" { var userPublicKey []byte if len(req.UserPublicKey) > 0 { var err error userPublicKey, err = base64Decode(req.UserPublicKey) if err != nil { - return fmt.Errorf("idp: invalid public key: %v", err) + return nil, fmt.Errorf("idp: invalid public key: %v", err) } } @@ -324,14 +327,13 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan UserID: codeInfo.UserId, UserEmail: codeInfo.UserEmail, UserPublicKey: userPublicKey, + UserGroups: userGroups, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }) - return err - }) - - if err := g.Wait(); err != nil { - return nil, err + if err != nil { + return nil, err + } } res := &hubauth.AccessToken{ @@ -343,6 +345,7 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan RefreshTokenExpiresIn: int(client.RefreshTokenExpiry / time.Second), RefreshTokenIssueTime: now, } + if res.AccessToken == "" { // if no audience was provided, provide a refresh token that can be used to to access /audiences res.TokenType = "RefreshToken" @@ -363,8 +366,10 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - return s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, oldToken.UserID) + var userGroups []string + g.Go(func() (err error) { + userGroups, err = s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, oldToken.UserID) + return err }) now := s.clock.Now() @@ -390,19 +395,19 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken return err }) + if err := g.Wait(); err != nil { + return nil, err + } + var accessToken string var tokenType string - g.Go(func() (err error) { - if req.Audience == "" { - return nil - } - + if req.Audience != "" { var userPublicKey []byte if len(req.UserPublicKey) > 0 { var err error userPublicKey, err = base64Decode(req.UserPublicKey) if err != nil { - return fmt.Errorf("idp: invalid public key: %v", err) + return nil, fmt.Errorf("idp: invalid public key: %v", err) } } @@ -410,15 +415,14 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken ClientID: req.ClientID, UserID: oldToken.UserID, UserEmail: oldToken.UserEmail, + UserGroups: userGroups, UserPublicKey: userPublicKey, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }) - return err - }) - - if err := g.Wait(); err != nil { - return nil, err + if err != nil { + return nil, err + } } res := &hubauth.AccessToken{ diff --git a/pkg/idp/oauth_test.go b/pkg/idp/oauth_test.go index 1722aec..48bef2b 100644 --- a/pkg/idp/oauth_test.go +++ b/pkg/idp/oauth_test.go @@ -58,9 +58,9 @@ func (m *mockSteps) SignCode(ctx context.Context, signKey hmacpb.Key, code *sign args := m.Called(ctx, signKey, code) return args.String(0), args.Error(1) } -func (m *mockSteps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) error { +func (m *mockSteps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) ([]string, error) { args := m.Called(ctx, audienceURL, clientID, userID) - return args.Error(0) + return args.Get(0).([]string), args.Error(1) } func (m *mockSteps) VerifyUserGroups(ctx context.Context, userID string) error { args := m.Called(ctx, userID) @@ -653,9 +653,11 @@ func TestExchangeCode(t *testing.T) { ExpiryTime: now.Add(refreshTokenExpiry), } + userGroups := []string{"grp1", "grp2"} + idpService.clock.(*mockClock).On("Now").Return(now) idpService.steps.(*mockSteps).On("AllocateRefreshToken", mock.Anything, clientID).Return(rtID, nil) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, clientID, userID).Return(nil) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, clientID, userID).Return(userGroups, nil) idpService.steps.(*mockSteps).On("VerifyCode", mock.Anything, &verifyCodeData{ ClientID: clientID, RedirectURI: redirectURI, @@ -669,6 +671,7 @@ func TestExchangeCode(t *testing.T) { ClientID: clientID, UserID: userID, UserEmail: userEmail, + UserGroups: userGroups, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }).Return(accessToken, testCase.Want.TokenType, nil) @@ -799,7 +802,7 @@ func TestExchangeCodeErrors(t *testing.T) { idpService.clock.(*mockClock).On("Now").Return(now) idpService.steps.(*mockSteps).On("AllocateRefreshToken", mock.Anything, mock.Anything).Return("", testCase.AllocateErr) idpService.steps.(*mockSteps).On("VerifyCode", mock.Anything, mock.Anything).Return(&hubauth.Code{}, testCase.VerifyCodeErr) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(testCase.VerifyAudienceErr) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{}, testCase.VerifyAudienceErr) idpService.steps.(*mockSteps).On("SaveRefreshToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&hubauth.Client{}, testCase.SaveErr) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, mock.Anything, mock.Anything).Return("", testCase.SignRTErr) idpService.steps.(*mockSteps).On("BuildAccessToken", mock.Anything, mock.Anything, mock.Anything).Return("", "", testCase.SignATErr) @@ -879,12 +882,14 @@ func TestRefreshToken(t *testing.T) { }, } + userGroups := []string{"grp1", "grp2"} + for _, testCase := range testCases { t.Run(testCase.Desc, func(t *testing.T) { idpService := newTestIdPService(t) idpService.clock.(*mockClock).On("Now").Return(now) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, b64ClientID, userID).Return(nil) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, b64ClientID, userID).Return(userGroups, nil) idpService.steps.(*mockSteps).On("RenewRefreshToken", mock.Anything, b64ClientID, b64OldTokenID, issueTimeFromProto, now).Return(newRefreshToken, nil) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, idpService.refreshKey, &signedRefreshTokenData{ refreshTokenData: &refreshTokenData{ @@ -900,6 +905,7 @@ func TestRefreshToken(t *testing.T) { ClientID: b64ClientID, UserID: userID, UserEmail: userEmail, + UserGroups: userGroups, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }).Return(newAccessTokenStr, testCase.Want.TokenType, nil) @@ -1002,7 +1008,7 @@ func TestRefreshTokenStepErrors(t *testing.T) { } idpService.clock.(*mockClock).On("Now").Return(now) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(testCase.VerifyAudienceErr) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{}, testCase.VerifyAudienceErr) idpService.steps.(*mockSteps).On("RenewRefreshToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&hubauth.RefreshToken{}, testCase.RenewRTErr) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, mock.Anything, mock.Anything).Return("", testCase.SignRTErr) idpService.steps.(*mockSteps).On("BuildAccessToken", mock.Anything, mock.Anything, mock.Anything).Return("", "", testCase.SignATErr) diff --git a/pkg/idp/steps.go b/pkg/idp/steps.go index db08163..c5d3620 100644 --- a/pkg/idp/steps.go +++ b/pkg/idp/steps.go @@ -122,19 +122,23 @@ func (s *steps) SignCode(ctx context.Context, signKey hmacpb.Key, code *signCode return base64Encode(res), nil } -func (s *steps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) error { +// VerifyAudience ensure the user can access the audience by verifying +// that the user have at least one group belonging to the audience policies groups +// It returns the list of user groups, or an error when the user is not allowed to access this audience. +// When no audience is provided, no group and no error is returned, +func (s *steps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) ([]string, error) { if audienceURL == "" { - return nil + return nil, nil } audience, err := s.db.GetAudience(ctx, audienceURL) if err != nil { if errors.Is(err, hubauth.ErrNotFound) { - return &hubauth.OAuthError{ + return nil, &hubauth.OAuthError{ Code: "invalid_request", Description: "unknown audience", } } - return fmt.Errorf("idp: error getting audience %s: %w", audienceURL, err) + return nil, fmt.Errorf("idp: error getting audience %s: %w", audienceURL, err) } foundClient := false for _, c := range audience.ClientIDs { @@ -145,20 +149,20 @@ func (s *steps) VerifyAudience(ctx context.Context, audienceURL, clientID, userI } if !foundClient { clog.Set(ctx, zap.Strings("audience_client_ids", audience.ClientIDs)) - return &hubauth.OAuthError{ + return nil, &hubauth.OAuthError{ Code: "invalid_client", Description: "unknown client for audience", } } - err = s.checkUser(ctx, audience, userID) + userGroups, err := s.checkUser(ctx, audience, userID) if errors.Is(err, hubauth.ErrUnauthorizedUser) { - return &hubauth.OAuthError{ + return nil, &hubauth.OAuthError{ Code: "access_denied", Description: "user is not authorized for access", } } - return err + return userGroups, err } func (s *steps) VerifyUserGroups(ctx context.Context, userID string) error { @@ -175,10 +179,10 @@ func (s *steps) VerifyUserGroups(ctx context.Context, userID string) error { return nil } -func (s *steps) checkUser(ctx context.Context, cluster *hubauth.Audience, userID string) error { +func (s *steps) checkUser(ctx context.Context, cluster *hubauth.Audience, userID string) ([]string, error) { groups, err := s.db.GetCachedMemberGroups(ctx, userID) if err != nil { - return fmt.Errorf("idp: error getting cached groups for user: %w", err) + return nil, fmt.Errorf("idp: error getting cached groups for user: %w", err) } // TODO: log allowed groups and cached groups @@ -196,9 +200,9 @@ outer: } } if !allowed { - return hubauth.ErrUnauthorizedUser + return nil, hubauth.ErrUnauthorizedUser } - return nil + return groups, nil } type refreshTokenData struct { diff --git a/pkg/idp/steps_test.go b/pkg/idp/steps_test.go index 9c02bbf..c950493 100644 --- a/pkg/idp/steps_test.go +++ b/pkg/idp/steps_test.go @@ -312,11 +312,12 @@ func TestVerifyAudience(t *testing.T) { require.NoError(t, err) testCases := []struct { - Desc string - Err error - AudienceURL string - ClientID string - UserID string + Desc string + Err error + ExpectedGroups []string + AudienceURL string + ClientID string + UserID string }{ { Desc: "no audience does nothing", @@ -350,21 +351,23 @@ func TestVerifyAudience(t *testing.T) { }, }, { - Desc: "all valid no error", - AudienceURL: validAudienceURL, - ClientID: validClientID, - UserID: validUserID, - Err: nil, + Desc: "all valid no error", + AudienceURL: validAudienceURL, + ExpectedGroups: []string{validGroupID}, + ClientID: validClientID, + UserID: validUserID, + Err: nil, }, } for _, testCase := range testCases { t.Run(testCase.Desc, func(t *testing.T) { - err := s.VerifyAudience(context.Background(), testCase.AudienceURL, testCase.ClientID, testCase.UserID) + grps, err := s.VerifyAudience(context.Background(), testCase.AudienceURL, testCase.ClientID, testCase.UserID) if testCase.Err != nil { require.Equal(t, testCase.Err, err) } else { require.NoError(t, err) + require.Equal(t, testCase.ExpectedGroups, grps) } }) } diff --git a/pkg/idp/token/biscuit.go b/pkg/idp/token/biscuit.go index fd89cfa..4743124 100644 --- a/pkg/idp/token/biscuit.go +++ b/pkg/idp/token/biscuit.go @@ -6,11 +6,14 @@ import ( "encoding/base64" "errors" "fmt" + "strings" "github.com/flynn/biscuit-go" "github.com/flynn/biscuit-go/cookbook/signedbiscuit" "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/hubauth" "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/policy" ) var ( @@ -19,13 +22,15 @@ var ( type biscuitBuilder struct { kms kmssign.KMSClient + db hubauth.BiscuitPolicyStore audienceKey kmssign.AudienceKeyNamer rootKeyPair sig.Keypair } -func NewBiscuitBuilder(kms kmssign.KMSClient, audienceKey kmssign.AudienceKeyNamer, rootKeyPair sig.Keypair) AccessTokenBuilder { +func NewBiscuitBuilder(kms kmssign.KMSClient, db hubauth.BiscuitPolicyStore, audienceKey kmssign.AudienceKeyNamer, rootKeyPair sig.Keypair) AccessTokenBuilder { return &biscuitBuilder{ kms: kms, + db: db, audienceKey: audienceKey, rootKeyPair: rootKeyPair, } @@ -38,10 +43,11 @@ func (b *biscuitBuilder) Build(ctx context.Context, audience string, t *AccessTo audienceKey := kmssign.NewPrivateKey(b.kms, b.audienceKey(audience), crypto.SHA256) meta := &signedbiscuit.Metadata{ - ClientID: t.ClientID, - UserID: t.UserID, - UserEmail: t.UserEmail, - IssueTime: t.IssueTime, + ClientID: t.ClientID, + UserID: t.UserID, + UserEmail: t.UserEmail, + UserGroups: t.UserGroups, + IssueTime: t.IssueTime, } builder := biscuit.NewBuilder(b.rootKeyPair) @@ -50,6 +56,19 @@ func (b *biscuitBuilder) Build(ctx context.Context, audience string, t *AccessTo return nil, err } + // retrieve policies from user groups and add each policy rules and caveats to the biscuit + userPolicies, err := b.getUserPolicies(ctx, t.UserGroups) + if err != nil { + return nil, err + } + + for _, p := range userPolicies { + builder, err = withPolicy(builder, p) + if err != nil { + return nil, err + } + } + bisc, err := builder.Build() if err != nil { return nil, err @@ -74,3 +93,50 @@ func DecodeB64PrivateKey(b64key string) (sig.Keypair, error) { kp = sig.NewKeypair(rootPrivateKey) return kp, nil } + +func (b *biscuitBuilder) getUserPolicies(ctx context.Context, userGroups []string) ([]*hubauth.BiscuitPolicy, error) { + allPolicies, err := b.db.ListBiscuitPolicies(ctx) + if err != nil { + return nil, err + } + + var userPolicies []*hubauth.BiscuitPolicy + for _, p := range allPolicies { + outer: + for _, g := range p.Groups { + for _, ug := range userGroups { + if g == ug { + userPolicies = append(userPolicies, p) + continue outer + } + } + } + } + return userPolicies, nil +} + +func withPolicy(builder biscuit.Builder, p *hubauth.BiscuitPolicy) (biscuit.Builder, error) { + parsed, err := policy.ParseDocumentPolicy(strings.NewReader(p.Content)) + if err != nil { + return nil, err + } + for _, rule := range parsed.Rules { + biscuitRule, err := rule.ToBiscuit() + if err != nil { + return nil, err + } + if err := builder.AddAuthorityRule(*biscuitRule); err != nil { + return nil, err + } + } + for _, caveat := range parsed.Caveats { + biscuitCaveat, err := caveat.ToBiscuit() + if err != nil { + return nil, err + } + if err := builder.AddAuthorityCaveat(*biscuitCaveat); err != nil { + return nil, err + } + } + return builder, nil +} diff --git a/pkg/idp/token/biscuit_test.go b/pkg/idp/token/biscuit_test.go index 41b705f..a32a1d7 100644 --- a/pkg/idp/token/biscuit_test.go +++ b/pkg/idp/token/biscuit_test.go @@ -12,17 +12,48 @@ import ( "github.com/flynn/biscuit-go" "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/hubauth" "github.com/flynn/hubauth/pkg/kmssign/kmssim" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +type mockPolicyStore struct { + mock.Mock +} + +var _ hubauth.BiscuitPolicyStore = (*mockPolicyStore)(nil) + +func (m *mockPolicyStore) GetBiscuitPolicy(ctx context.Context, id string) (*hubauth.BiscuitPolicy, error) { + args := m.Called(ctx, id) + return args.Get(0).(*hubauth.BiscuitPolicy), args.Error(1) +} +func (m *mockPolicyStore) CreateBiscuitPolicy(ctx context.Context, policy *hubauth.BiscuitPolicy) (string, error) { + args := m.Called(ctx, policy) + return args.String(0), args.Error(1) +} +func (m *mockPolicyStore) MutateBiscuitPolicy(ctx context.Context, id string, mut []*hubauth.BiscuitPolicyMutation) error { + args := m.Called(ctx, id, mut) + return args.Error(0) +} +func (m *mockPolicyStore) ListBiscuitPolicies(ctx context.Context) ([]*hubauth.BiscuitPolicy, error) { + args := m.Called(ctx) + return args.Get(1).([]*hubauth.BiscuitPolicy), args.Error(1) +} +func (m *mockPolicyStore) DeleteBiscuitPolicy(ctx context.Context, id string) error { + args := m.Called(ctx, id) + return args.Error(0) +} + func TestBiscuitBuilder(t *testing.T) { audience := "https://audience.url" audienceKeyName := audienceKeyNamer(audience) kmsClient := kmssim.NewClient([]string{audienceKeyName}) rootKeyPair := sig.GenerateKeypair(rand.Reader) - builder := NewBiscuitBuilder(kmsClient, audienceKeyNamer, rootKeyPair) + policyStore := new(mockPolicyStore) + + builder := NewBiscuitBuilder(kmsClient, policyStore, audienceKeyNamer, rootKeyPair) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) diff --git a/pkg/idp/token/builder.go b/pkg/idp/token/builder.go index 0da8c9a..e6dd403 100644 --- a/pkg/idp/token/builder.go +++ b/pkg/idp/token/builder.go @@ -9,6 +9,7 @@ type AccessTokenData struct { ClientID string UserID string UserEmail string + UserGroups []string UserPublicKey []byte IssueTime time.Time ExpireTime time.Time diff --git a/pkg/policy/parser.go b/pkg/policy/parser.go index 6b1d588..86aaeb3 100644 --- a/pkg/policy/parser.go +++ b/pkg/policy/parser.go @@ -55,6 +55,7 @@ func (d *DocumentPolicy) BiscuitCaveats() ([]biscuit.Caveat, error) { } var documentParser = participle.MustBuild(&Document{}, defaultParserOptions...) +var documentPolicyParser = participle.MustBuild(&DocumentPolicy{}, defaultParserOptions...) func Parse(r io.Reader) (*Document, error) { return ParseNamed("policy", r) @@ -75,3 +76,15 @@ func ParseNamed(filename string, r io.Reader) (*Document, error) { return parsed, nil } + +func ParseDocumentPolicy(r io.Reader) (*DocumentPolicy, error) { + return ParseNamedDocumentPolicy("policy", r) +} + +func ParseNamedDocumentPolicy(name string, r io.Reader) (*DocumentPolicy, error) { + p := &DocumentPolicy{} + if err := documentPolicyParser.Parse(name, r, p); err != nil { + return nil, err + } + return p, nil +} From 2038a0ca7487ffc513c58b72edd2ffc0c1d2818f Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Tue, 15 Dec 2020 17:39:58 +0100 Subject: [PATCH 17/24] datastore: add audience policy support --- pkg/datastore/audience.go | 164 +++++++++++++++-- pkg/datastore/audience_test.go | 313 ++++++++++++++++++++++++++++++--- pkg/datastore/datastore.go | 1 - pkg/datastore/policy.go | 170 ------------------ pkg/hubauth/data.go | 64 ++++--- 5 files changed, 473 insertions(+), 239 deletions(-) delete mode 100644 pkg/datastore/policy.go diff --git a/pkg/datastore/audience.go b/pkg/datastore/audience.go index d05ab72..b8a2b6f 100644 --- a/pkg/datastore/audience.go +++ b/pkg/datastore/audience.go @@ -18,12 +18,18 @@ func buildAudience(c *hubauth.Audience) *audience { userGroups[i] = buildGoogleUserGroups(p) } + policies := make([]biscuitPolicy, len(c.Policies)) + for i, p := range c.Policies { + policies[i] = buildBiscuitPolicy(p) + } + return &audience{ Key: audienceKey(c.URL), Name: c.Name, Type: c.Type, ClientIDs: c.ClientIDs, UserGroups: userGroups, + Policies: policies, CreateTime: now, UpdateTime: now, } @@ -35,6 +41,7 @@ type audience struct { Type string ClientIDs []string UserGroups []googleUserGroups `datastore:",flatten"` + Policies []biscuitPolicy `datastore:",flatten"` CreateTime time.Time UpdateTime time.Time } @@ -47,32 +54,67 @@ func buildGoogleUserGroups(p *hubauth.GoogleUserGroups) googleUserGroups { } } +func buildBiscuitPolicy(p *hubauth.BiscuitPolicy) biscuitPolicy { + return biscuitPolicy{ + Name: p.Name, + Content: p.Content, + Groups: strings.Join(p.Groups, ","), + } +} + type googleUserGroups struct { Domain string APIUser string Groups string // datastore doesn't take nested lists, so encode by comma-separating } +type biscuitPolicy struct { + Name string + Content string + Groups string // datastore doesn't take nested lists, so encode by comma-separating +} + func (c *audience) Export() *hubauth.Audience { - userGroups := make([]*hubauth.GoogleUserGroups, len(c.UserGroups)) - for i, p := range c.UserGroups { - var grps []string - if p.Groups != "" { - grps = strings.Split(p.Groups, ",") + var userGroups []*hubauth.GoogleUserGroups + if len(c.UserGroups) > 0 { + userGroups = make([]*hubauth.GoogleUserGroups, len(c.UserGroups)) + for i, p := range c.UserGroups { + var grps []string + if p.Groups != "" { + grps = strings.Split(p.Groups, ",") + } + + userGroups[i] = &hubauth.GoogleUserGroups{ + Domain: p.Domain, + APIUser: p.APIUser, + Groups: grps, + } } + } + var policies []*hubauth.BiscuitPolicy + if len(c.Policies) > 0 { + policies = make([]*hubauth.BiscuitPolicy, len(c.Policies)) + for i, p := range c.Policies { + var grps []string + if p.Groups != "" { + grps = strings.Split(p.Groups, ",") + } - userGroups[i] = &hubauth.GoogleUserGroups{ - Domain: p.Domain, - APIUser: p.APIUser, - Groups: grps, + policies[i] = &hubauth.BiscuitPolicy{ + Name: p.Name, + Content: p.Content, + Groups: grps, + } } } + return &hubauth.Audience{ URL: c.Key.Name, Name: c.Name, Type: c.Type, ClientIDs: c.ClientIDs, UserGroups: userGroups, + Policies: policies, CreateTime: c.CreateTime, UpdateTime: c.UpdateTime, } @@ -172,12 +214,25 @@ func (s *service) MutateAudience(ctx context.Context, url string, mut []*hubauth } aud.Type = m.Type modified = true - case hubauth.AudienceMutationMigratePolicy: - aud.UserGroups = make([]googleUserGroups, len(aud.Policies)) - for i, ug := range aud.Policies { - aud.UserGroups[i] = ug + case hubauth.AudienceMutationSetPolicy: + for i, p := range aud.Policies { + if p.Name == m.Policy.Name { + aud.Policies[i] = buildBiscuitPolicy(&m.Policy) + modified = true + continue outer + } } + aud.Policies = append(aud.Policies, buildBiscuitPolicy(&m.Policy)) modified = true + case hubauth.AudienceMutationDeletePolicy: + for i, p := range aud.Policies { + if p.Name != m.Policy.Name { + continue + } + aud.Policies[i] = aud.Policies[len(aud.Policies)-1] + aud.Policies = aud.Policies[:len(aud.Policies)-1] + modified = true + } default: return fmt.Errorf("datastore: unknown audience mutation op %s", m.Op) } @@ -278,6 +333,89 @@ func (s *service) MutateAudienceUserGroups(ctx context.Context, url string, doma return nil } +func (s *service) MutateAudiencePolicy(ctx context.Context, url string, policyName string, mut []*hubauth.AudiencePolicyMutation) error { + ctx, span := trace.StartSpan(ctx, "datastore.MutateAudiencePolicy") + span.AddAttributes( + trace.StringAttribute("audience_url", url), + trace.StringAttribute("audience_policy_name", policyName), + trace.Int64Attribute("audience_policy_mutation_count", int64(len(mut))), + ) + defer span.End() + + k := audienceKey(url) + _, err := s.db.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + aud := &audience{} + if err := tx.Get(k, aud); err != nil { + if err == datastore.ErrNoSuchEntity { + err = hubauth.ErrNotFound + } + return fmt.Errorf("datastore: error fetching audience %s: %w", url, err) + } + + var policy *biscuitPolicy + for i := range aud.Policies { + if aud.Policies[i].Name == policyName { + policy = &aud.Policies[i] + break + } + } + if policy == nil { + return hubauth.ErrNotFound + } + + modified := false + outer: + for _, m := range mut { + switch m.Op { + case hubauth.AudiencePolicyMutationOpAddGroup: + var groups []string + if policy.Groups != "" { + groups = strings.Split(policy.Groups, ",") + } + for _, g := range groups { + if g == m.Group { + continue outer + } + } + policy.Groups = strings.Join(append(groups, m.Group), ",") + modified = true + case hubauth.AudiencePolicyMutationOpDeleteGroup: + var groups []string + if policy.Groups != "" { + groups = strings.Split(policy.Groups, ",") + } + for i, g := range groups { + if g != m.Group { + continue + } + groups[i] = groups[len(groups)-1] + groups = groups[:len(groups)-1] + } + policy.Groups = strings.Join(groups, ",") + modified = true + case hubauth.AudiencePolicyMutationOpSetContent: + if policy.Content == m.Content { + continue + } + policy.Content = m.Content + modified = true + default: + return fmt.Errorf("datastore: unknown audience policy mutation op %s", m.Op) + } + } + if !modified { + return nil + } + aud.UpdateTime = time.Now() + _, err := tx.Put(k, aud) + return err + }) + if err != nil { + return fmt.Errorf("datastore: error mutating audience %s: %w", url, err) + } + return nil +} + func (s *service) ListAudiences(ctx context.Context) ([]*hubauth.Audience, error) { ctx, span := trace.StartSpan(ctx, "datastore.ListAudiences") defer span.End() diff --git a/pkg/datastore/audience_test.go b/pkg/datastore/audience_test.go index 0c754bb..6ae7b34 100644 --- a/pkg/datastore/audience_test.go +++ b/pkg/datastore/audience_test.go @@ -358,6 +358,105 @@ func TestAudienceMutate(t *testing.T) { Type: "new-type", }, }, + { + desc: "set new policy", + mut: []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + before: &hubauth.Audience{ + Policies: nil, + }, + after: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + }, + { + desc: "set existing policy", + mut: []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + before: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "old policy content", + Groups: []string{"grpA", "grpB"}, + }, + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + after: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + }, + { + desc: "delete policy", + mut: []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationDeletePolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + before: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "old policy content", + Groups: []string{"grpA", "grpB"}, + }, + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + after: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + }, + { desc: "multiple", mut: []*hubauth.AudienceMutation{ @@ -387,28 +486,30 @@ func TestAudienceMutate(t *testing.T) { ctx := context.Background() url := "https://cluster.mutate.example.com" for _, tt := range tests { - tt.before.URL = url - tt.after.URL = url - err := s.CreateAudience(ctx, tt.before) - require.NoError(t, err, tt.desc) - before, err := s.GetAudience(ctx, url) - require.NoError(t, err) - - err = s.MutateAudience(ctx, url, tt.mut) - require.NoError(t, err, tt.desc) - - res, err := s.GetAudience(ctx, url) - require.NoError(t, err, tt.desc) - if len(res.UserGroups) == 0 { - res.UserGroups = nil - } - require.Equal(t, before.CreateTime, res.CreateTime) - - res.CreateTime = time.Time{} - res.UpdateTime = time.Time{} - require.Equal(t, tt.after, res, tt.desc) - - s.DeleteAudience(ctx, url) + t.Run(tt.desc, func(t *testing.T) { + tt.before.URL = url + tt.after.URL = url + err := s.CreateAudience(ctx, tt.before) + require.NoError(t, err, tt.desc) + before, err := s.GetAudience(ctx, url) + require.NoError(t, err) + + err = s.MutateAudience(ctx, url, tt.mut) + require.NoError(t, err, tt.desc) + + res, err := s.GetAudience(ctx, url) + require.NoError(t, err, tt.desc) + if len(res.UserGroups) == 0 { + res.UserGroups = nil + } + require.Equal(t, before.CreateTime, res.CreateTime) + + res.CreateTime = time.Time{} + res.UpdateTime = time.Time{} + require.Equal(t, tt.after, res, tt.desc) + + s.DeleteAudience(ctx, url) + }) } } @@ -624,3 +725,171 @@ func TestMutateAudienceUserGroups(t *testing.T) { s.DeleteAudience(ctx, aud.URL) } } + +func TestMutateAudiencePolicy(t *testing.T) { + policyName := "policy_name" + type test struct { + desc string + mut []*hubauth.AudiencePolicyMutation + before []*hubauth.BiscuitPolicy + after []*hubauth.BiscuitPolicy + } + tests := []test{ + { + desc: "add single group", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp1", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: nil, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"grp1"}, + }, + }, + }, + { + desc: "add multiple groups", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "existing", + }, + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp1", + }, + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp2", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing"}, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing", "grp1", "grp2"}, + }, + }, + }, + { + desc: "delete last group", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp1", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"grp1"}, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: nil, + }, + }, + }, + { + desc: "delete multiple groups", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp1", + }, + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp2", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing", "grp1", "grp2"}, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing"}, + }, + }, + }, + { + desc: "set content", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpSetContent, + Content: "new content", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Content: "", + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Content: "new content", + }, + }, + }, + } + + s := newTestService(t) + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + + aud := &hubauth.Audience{ + URL: "https://cluster.mutate.example.com", + Policies: tt.before, + } + + err := s.CreateAudience(ctx, aud) + require.NoError(t, err, tt.desc) + before, err := s.GetAudience(ctx, aud.URL) + require.NoError(t, err) + + err = s.MutateAudiencePolicy(ctx, aud.URL, policyName, tt.mut) + require.NoError(t, err, tt.desc) + + res, err := s.GetAudience(ctx, aud.URL) + require.NoError(t, err, tt.desc) + if len(res.UserGroups) == 0 { + res.UserGroups = nil + } + require.Equal(t, before.CreateTime, res.CreateTime) + + // sort to ensure consistent slice comparison + for _, p := range res.Policies { + sort.Strings(p.Groups) + } + for _, p := range tt.after { + sort.Strings(p.Groups) + } + + require.Equal(t, tt.after, res.Policies, tt.desc) + + s.DeleteAudience(ctx, aud.URL) + }) + } +} diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index 7865c2d..b55aff4 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -21,7 +21,6 @@ const ( kindDomain = "GoogleDomain" kindCachedGroup = "CachedGoogleGroup" kindCachedGroupMember = "CachedGoogleGroupMember" - kindBiscuitPolicy = "BiscuitPolicy" ) func New(db *datastore.Client) hubauth.DataStore { diff --git a/pkg/datastore/policy.go b/pkg/datastore/policy.go deleted file mode 100644 index 5e19639..0000000 --- a/pkg/datastore/policy.go +++ /dev/null @@ -1,170 +0,0 @@ -package datastore - -import ( - "context" - "time" - - "cloud.google.com/go/datastore" - "github.com/flynn/hubauth/pkg/hubauth" - "go.opencensus.io/trace" - "golang.org/x/exp/errors/fmt" -) - -func buildBiscuitPolicy(p *hubauth.BiscuitPolicy) *biscuitPolicy { - now := time.Now() - return &biscuitPolicy{ - Content: p.Content, - Groups: p.Groups, - CreateTime: now, - UpdateTime: now, - } -} - -type biscuitPolicy struct { - ID *datastore.Key `datastore:"__key__"` - Content string - Groups []string - CreateTime time.Time - UpdateTime time.Time -} - -func (c *biscuitPolicy) Export() *hubauth.BiscuitPolicy { - return &hubauth.BiscuitPolicy{ - ID: c.ID.Encode(), - Content: c.Content, - Groups: c.Groups, - CreateTime: c.CreateTime, - UpdateTime: c.UpdateTime, - } -} - -func biscuitPolicyKey(id string) (*datastore.Key, error) { - k, err := datastore.DecodeKey(id) - if err != nil { - return nil, hubauth.ErrNotFound - } - if k.Kind != kindBiscuitPolicy { - return nil, hubauth.ErrNotFound - } - return k, nil -} - -func (s *service) GetBiscuitPolicy(ctx context.Context, id string) (*hubauth.BiscuitPolicy, error) { - ctx, span := trace.StartSpan(ctx, "datastore.GetBiscuitPolicy") - span.AddAttributes(trace.StringAttribute("biscuit_policy_id", id)) - defer span.End() - - k, err := biscuitPolicyKey(id) - if err != nil { - return nil, err - } - res := &biscuitPolicy{} - if err := s.db.Get(ctx, k, res); err != nil { - if err == datastore.ErrNoSuchEntity { - err = hubauth.ErrNotFound - } - return nil, fmt.Errorf("datastore: error fetching biscuit policy %s: %w", id, err) - } - return res.Export(), nil -} - -func (s *service) CreateBiscuitPolicy(ctx context.Context, policy *hubauth.BiscuitPolicy) (string, error) { - ctx, span := trace.StartSpan(ctx, "datastore.CreateBiscuitPolicy") - defer span.End() - - k, err := s.db.Put(ctx, datastore.IncompleteKey(kindBiscuitPolicy, nil), buildBiscuitPolicy(policy)) - if err != nil { - return "", fmt.Errorf("datastore: error creating biscuit policy: %w", err) - } - id := k.Encode() - span.AddAttributes(trace.StringAttribute("biscuit_policy_id", id)) - return id, nil -} - -func (s *service) MutateBiscuitPolicy(ctx context.Context, id string, mut []*hubauth.BiscuitPolicyMutation) error { - ctx, span := trace.StartSpan(ctx, "datastore.MutateBiscuitPolicy") - span.AddAttributes( - trace.StringAttribute("biscuit_policy_id", id), - trace.Int64Attribute("biscuit_policy_mutation_count", int64(len(mut))), - ) - defer span.End() - - k, err := biscuitPolicyKey(id) - if err != nil { - return err - } - _, err = s.db.RunInTransaction(ctx, func(tx *datastore.Transaction) error { - policy := &biscuitPolicy{} - if err := tx.Get(k, policy); err != nil { - if err == datastore.ErrNoSuchEntity { - err = hubauth.ErrNotFound - } - return fmt.Errorf("datastore: error fetching biscuit policy %s: %w", id, err) - } - modified := false - outer: - for _, m := range mut { - switch m.Op { - case hubauth.BiscuitPolicyMutationOpAddGroup: - for _, g := range policy.Groups { - if g == m.Group { - continue outer - } - } - policy.Groups = append(policy.Groups, m.Group) - modified = true - case hubauth.BiscuitPolicyMutationOpDeleteGroup: - for i, u := range policy.Groups { - if u != m.Group { - continue - } - policy.Groups[i] = policy.Groups[len(policy.Groups)-1] - policy.Groups = policy.Groups[:len(policy.Groups)-1] - modified = true - } - default: - return fmt.Errorf("datastore: unknown biscuit policy mutation op %s", m.Op) - } - } - if !modified { - return nil - } - policy.UpdateTime = time.Now() - _, err := tx.Put(k, policy) - return err - }) - if err != nil { - return fmt.Errorf("datastore: error mutating biscuit policy %s: %w", id, err) - } - return nil -} - -func (s *service) ListBiscuitPolicies(ctx context.Context) ([]*hubauth.BiscuitPolicy, error) { - ctx, span := trace.StartSpan(ctx, "datastore.ListBiscuitPolicies") - defer span.End() - - var policies []*biscuitPolicy - if _, err := s.db.GetAll(ctx, datastore.NewQuery(kindBiscuitPolicy), &policies); err != nil { - return nil, fmt.Errorf("datastore: error listing biscuit policies: %w", err) - } - res := make([]*hubauth.BiscuitPolicy, len(policies)) - for i, c := range policies { - res[i] = c.Export() - } - return res, nil -} - -func (s *service) DeleteBiscuitPolicy(ctx context.Context, id string) error { - ctx, span := trace.StartSpan(ctx, "datastore.DeleteBiscuitPolicy") - span.AddAttributes(trace.StringAttribute("biscuit_policy_id", id)) - defer span.End() - - k, err := biscuitPolicyKey(id) - if err != nil { - return err - } - if err := s.db.Delete(ctx, k); err != nil { - return fmt.Errorf("datastore: error deleting biscuit policy %s: %w", id, err) - } - return nil -} diff --git a/pkg/hubauth/data.go b/pkg/hubauth/data.go index 607f1b6..b97dc42 100644 --- a/pkg/hubauth/data.go +++ b/pkg/hubauth/data.go @@ -24,7 +24,6 @@ type DataStore interface { CodeStore RefreshTokenStore CachedGroupStore - BiscuitPolicyStore } type ClientStore interface { @@ -58,11 +57,16 @@ type ClientMutation struct { RefreshTokenExpiry time.Duration } -type AudienceStore interface { +type AudienceGetterStore interface { GetAudience(ctx context.Context, url string) (*Audience, error) +} + +type AudienceStore interface { + AudienceGetterStore CreateAudience(ctx context.Context, audience *Audience) error MutateAudience(ctx context.Context, url string, mut []*AudienceMutation) error MutateAudienceUserGroups(ctx context.Context, url string, domain string, mut []*AudienceUserGroupsMutation) error + MutateAudiencePolicy(ctx context.Context, url string, policyName string, mut []*AudiencePolicyMutation) error ListAudiencesForClient(ctx context.Context, clientID string) ([]*Audience, error) ListAudiences(ctx context.Context) ([]*Audience, error) DeleteAudience(ctx context.Context, url string) error @@ -74,6 +78,7 @@ type Audience struct { Type string `json:"type"` ClientIDs []string `json:"-"` UserGroups []*GoogleUserGroups `json:"-"` + Policies []*BiscuitPolicy `json:"-"` CreateTime time.Time `json:"-"` UpdateTime time.Time `json:"-"` } @@ -84,6 +89,12 @@ type GoogleUserGroups struct { Groups []string } +type BiscuitPolicy struct { + Name string + Content string + Groups []string +} + type AudienceMutationOp byte const ( @@ -92,7 +103,8 @@ const ( AudienceMutationOpSetUserGroups AudienceMutationOpDeleteUserGroups AudienceMutationSetType - AudienceMutationMigratePolicy + AudienceMutationSetPolicy + AudienceMutationDeletePolicy ) type AudienceMutation struct { @@ -101,6 +113,7 @@ type AudienceMutation struct { ClientID string Type string UserGroups GoogleUserGroups + Policy BiscuitPolicy } type AudienceUserGroupsMutationOp byte @@ -118,6 +131,21 @@ type AudienceUserGroupsMutation struct { Group string } +type AudiencePolicyMutationOp byte + +const ( + AudiencePolicyMutationOpAddGroup AudiencePolicyMutationOp = iota + AudiencePolicyMutationOpDeleteGroup + AudiencePolicyMutationOpSetContent +) + +type AudiencePolicyMutation struct { + Op AudiencePolicyMutationOp + + Content string + Group string +} + type CodeStore interface { GetCode(ctx context.Context, id string) (*Code, error) VerifyAndDeleteCode(ctx context.Context, id, secret string) (*Code, error) @@ -263,33 +291,3 @@ func GetClientInfo(ctx context.Context) *ClientInfo { } return res } - -type BiscuitPolicyStore interface { - GetBiscuitPolicy(ctx context.Context, id string) (*BiscuitPolicy, error) - CreateBiscuitPolicy(ctx context.Context, policy *BiscuitPolicy) (string, error) - MutateBiscuitPolicy(ctx context.Context, id string, mut []*BiscuitPolicyMutation) error - ListBiscuitPolicies(ctx context.Context) ([]*BiscuitPolicy, error) - DeleteBiscuitPolicy(ctx context.Context, id string) error -} - -type BiscuitPolicy struct { - ID string - Content string - Groups []string - CreateTime time.Time - UpdateTime time.Time -} - -type BiscuitPolicyMutationOp byte - -const ( - BiscuitPolicyMutationOpAddGroup BiscuitPolicyMutationOp = iota - BiscuitPolicyMutationOpDeleteGroup - BiscuitPolicyMutationOpSetContent -) - -type BiscuitPolicyMutation struct { - Op BiscuitPolicyMutationOp - - Group string -} From 2bfb08a5a0b35777ec5820794fccd83dc566ce22 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Tue, 15 Dec 2020 17:40:29 +0100 Subject: [PATCH 18/24] policy: add parser test for empty policies --- pkg/policy/parser_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/pkg/policy/parser_test.go b/pkg/policy/parser_test.go index be702b4..8953d38 100644 --- a/pkg/policy/parser_test.go +++ b/pkg/policy/parser_test.go @@ -170,6 +170,38 @@ func TestParse(t *testing.T) { } } +func TestParseDocumentPolicy(t *testing.T) { + testCases := []struct { + Desc string + Input string + ExpectedErr bool + ExpectedOut *DocumentPolicy + }{ + { + Desc: "single policy", + Input: `policy "foo" {}`, + ExpectedOut: &DocumentPolicy{Name: sptr("foo")}, + }, + { + Desc: "empty document returns an error", + Input: "", + ExpectedErr: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Desc, func(t *testing.T) { + out, err := ParseDocumentPolicy(strings.NewReader(testCase.Input)) + if testCase.ExpectedErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, testCase.ExpectedOut, out) + } + }) + } +} + func sptr(s string) *string { return &s } From 5f0c2ad11be89ae94a39a6c937d0a208074f4f65 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Wed, 16 Dec 2020 11:32:03 +0100 Subject: [PATCH 19/24] cli: add audiences policies commands --- go.mod | 1 + go.sum | 7 - pkg/cli/audiences.go | 285 +++++++++++++++++++++++++++-- pkg/cli/audiences_test.go | 369 +++++++++++++++++++++++++++++++++++++- pkg/cli/cli.go | 1 - pkg/cli/policies.go | 219 ---------------------- pkg/policy/parser.go | 2 +- pkg/policy/printer.go | 10 ++ 8 files changed, 653 insertions(+), 241 deletions(-) delete mode 100644 pkg/cli/policies.go diff --git a/go.mod b/go.mod index a4aa488..72bb495 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( cloud.google.com/go/datastore v1.3.0 contrib.go.opencensus.io/exporter/stackdriver v0.13.4 github.com/alecthomas/kong v0.2.12 + github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 github.com/aws/aws-sdk-go v1.36.7 // indirect github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195 diff --git a/go.sum b/go.sum index 002e339..a397178 100644 --- a/go.sum +++ b/go.sum @@ -79,13 +79,6 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca h1:LUZQQzaCT+gltxii4icyPH5oMdAP38JmbvO9aI0E4qM= -github.com/flynn/biscuit-go v0.0.0-20200907174027-193b7bdbbdca/go.mod h1:EMJZ3stAYtwaP763F5HcGjPjCnYu21V2TEsg/iw88I8= -github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195 h1:TP3jMHmhjz8XxqqigEd5OQffNAO/6KPvGUYII6TFdmI= -github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195/go.mod h1:EMJZ3stAYtwaP763F5HcGjPjCnYu21V2TEsg/iw88I8= -github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345 h1:ME6bm5dwn9V2DUlfXJqeN121B5nM7rDFqLFOATALqYE= -github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345/go.mod h1:Sj4oR2hNkrZH1cf3Cj5DPHc3Xq0o61GWeau6UkZR+3c= -github.com/flynn/biscuit-go v0.0.0-20201204161836-6af1c88a7b3d h1:RHIlExiAgFgF1hQzdjhq41dnlOlkbcsOczQD+YgVQRk= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= diff --git a/pkg/cli/audiences.go b/pkg/cli/audiences.go index bbdda19..281fbde 100644 --- a/pkg/cli/audiences.go +++ b/pkg/cli/audiences.go @@ -5,26 +5,39 @@ import ( "encoding/base64" "encoding/pem" "fmt" + "io/ioutil" "net/url" "os" "strings" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/policy" "github.com/jedib0t/go-pretty/v6/table" "google.golang.org/genproto/googleapis/cloud/kms/v1" ) type audiencesCmd struct { - List audiencesListCmd `kong:"cmd,help='list audiences',default:'1'"` - Create audiencesCreateCmd `kong:"cmd,help='create audience'"` - UpdateType audienceUpdateTypeCmd `kong:"cmd,name='update-type',help='change audience type'"` - UpdateClientIDs audiencesUpdateClientsIDsCmd `kong:"cmd,name='update-client-ids',help='add or remove audience client IDs'"` - Delete audiencesDeleteCmd `kong:"cmd,help='delete audience and all its keys'"` + List audiencesListCmd `kong:"cmd,help='list audiences',default:'1'"` + Create audiencesCreateCmd `kong:"cmd,help='create audience'"` + UpdateType audienceUpdateTypeCmd `kong:"cmd,name='update-type',help='change audience type'"` + UpdateClientIDs audiencesUpdateClientsIDsCmd `kong:"cmd,name='update-client-ids',help='add or remove audience client IDs'"` + Delete audiencesDeleteCmd `kong:"cmd,help='delete audience and all its keys'"` + ListUserGroups audiencesListUserGroupsCmd `kong:"cmd,name='list-user-groups',help='list audience user groups'"` SetUserGroups audiencesSetUserGroupsCmd `kong:"cmd,name='set-user-groups',help='set audience auth user groups'"` UpdateUserGroups audiencesUpdateUserGroupsCmd `kong:"cmd,name='update-user-groups',help='modify audience user groups api user or groups'"` DeleteUserGroups audiencesDeleteUserGroupsCmd `kong:"cmd,name='delete-user-groups',help='delete audience auth user groups'"` - Key audiencesKeyCmd `kong:"cmd,help='get audience public key'"` + + Key audiencesKeyCmd `kong:"cmd,help='get audience public key'"` + + ListPolicies audiencesListPoliciesCmd `kong:"cmd,name='list-policies',help='list audience policies'"` + DumpPolicies audiencesDumpPoliciesCmd `kong:"cmd,name='dump-policies',help='dump audience policies'"` + SetPolicies audiencesSetPoliciesCmd `kong:"cmd,name='set-policies',help='set audience policies'"` + UpdatePolicy audiencesUpdatePolicyCmd `kong:"cmd,name='update-policy',help='modify audience policy content or groups'"` + DeletePolicy audiencesDeletePolicyCmd `kong:"cmd,name='delete-policy',help='delete audience policy'"` + + NewPolicy audiencesNewPolicyCmd `kong:"cmd,name='new-policy',help='print a new empty policy document on stdout'"` + ValidatePolicies audiencesValidatePoliciesCmd `kong:"cmd,name='validate-policies',help='validate a policy file'"` } type audiencesListCmd struct{} @@ -251,15 +264,263 @@ func (c *audiencesKeyCmd) Run(cfg *Config) error { return nil } -type audienceMigratePoliciesCmd struct { +type audiencesListPoliciesCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` +} + +func (c *audiencesListPoliciesCmd) Run(cfg *Config) error { + audience, err := cfg.DB.GetAudience(context.Background(), c.AudienceURL) + if err != nil { + return err + } + + t := table.NewWriter() + t.SetOutputMirror(os.Stdout) + t.AppendHeader(table.Row{"Name", "Groups", "Description"}) + for _, p := range audience.Policies { + t.AppendRow(table.Row{p.Name, p.Groups, getFirstComment(p)}) + } + t.Render() + return nil +} + +// getFirstComment parse the policy content and returns the first policy +// comment line if it exists. On failure to parse the policy content, or when unset, an empty string is returned. +func getFirstComment(p *hubauth.BiscuitPolicy) string { + doc, err := policy.ParseDocumentPolicy(strings.NewReader(p.Content)) + if err != nil { + return "" + } + if len(doc.Comments) == 0 { + return "" + } + return string(*doc.Comments[0]) +} + +type audiencesSetPoliciesCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + Filepath string `kong:"required,name='filepath',help='policy file'"` + Groups []string `kong:"help='comma-separated group IDs'"` +} + +// Run parses Filepath for a list of policies, and creates or updates them on the audience identified by AudienceURL, +// forcing their groups to the provided Groups. +func (c *audiencesSetPoliciesCmd) Run(cfg *Config) error { + f, err := os.Open(c.Filepath) + if err != nil { + return err + } + + doc, err := policy.ParseNamed(f.Name(), f) + if err != nil { + return err + } + + muts := make([]*hubauth.AudienceMutation, len(doc.Policies)) + for i, p := range doc.Policies { + muts[i] = &hubauth.AudienceMutation{ + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: *p.Name, + Content: policy.PrintPolicy(p), + Groups: c.Groups, + }, + } + } + + return cfg.DB.MutateAudience(context.Background(), c.AudienceURL, muts) +} + +type audiencesUpdatePolicyCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + PolicyName string `kong:"required,help='policy name'"` + Filepath string `kong:"name='filepath',help='replace policy content from a file'"` + AddGroups []string `kong:"name='add-groups',help='comma-separated group IDs to add'"` + DeleteGroups []string `kong:"name='delete-groups',help='comma-separated group IDs to delete'"` } -func (c *audienceMigratePoliciesCmd) Run(cfg *Config) error { - policies, err := cfg.DB.ListAudiences(ctx) - for _, p := range policies { - return fmt.Errorf("Failed to migrate policies to userGroups for audience %q", p.URL) +func (c *audiencesUpdatePolicyCmd) Run(cfg *Config) error { + var muts []*hubauth.AudiencePolicyMutation + for _, groupID := range c.AddGroups { + muts = append(muts, &hubauth.AudiencePolicyMutation{ + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: groupID, + }) + } + for _, groupID := range c.DeleteGroups { + muts = append(muts, &hubauth.AudiencePolicyMutation{ + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: groupID, + }) + } + if c.Filepath != "" { + doc, err := parsePolicy(c.Filepath) + if err != nil { + return err + } + + var mutatedPolicy *policy.DocumentPolicy + for _, p := range doc.Policies { + if *p.Name == c.PolicyName { + mutatedPolicy = p + break + } } - fmt.Printf("Success migrating policies to userGroups for %q\n", p.URL) + if mutatedPolicy == nil { + return fmt.Errorf("policy %q not found in file %q", c.PolicyName, c.Filepath) + } + + muts = append(muts, &hubauth.AudiencePolicyMutation{ + Op: hubauth.AudiencePolicyMutationOpSetContent, + Content: policy.PrintPolicy(mutatedPolicy), + }) + } + + return cfg.DB.MutateAudiencePolicy(context.Background(), c.AudienceURL, c.PolicyName, muts) +} + +type audiencesDeletePolicyCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + PolicyName string `kong:"required,help='policy name'"` +} + +func (c *audiencesDeletePolicyCmd) Run(cfg *Config) error { + mut := &hubauth.AudienceMutation{ + Op: hubauth.AudienceMutationDeletePolicy, + Policy: hubauth.BiscuitPolicy{ + Name: c.PolicyName, + }, } + return cfg.DB.MutateAudience(context.Background(), c.AudienceURL, []*hubauth.AudienceMutation{mut}) +} + +type audiencesNewPolicyCmd struct { + Filepath string `kong:"name='filepath',short='f',help='optionnal filepath where to write the policy (default: stdout)'"` +} + +var policyTemplate string = `// this is a template policy +policy "dummy" { + rules { + // this is a dummy rule + *head($var1) + <- body1(#ambient, $name), + body2($value) + @ $name == "example" + } + + caveats {[ + // this is a dummy caveat + *head($var1) + <- body1(#ambient, $name), + body2($value) + @ $name == "example" + ]} +}` + +func (c *audiencesNewPolicyCmd) Run(cfg *Config) error { + d, err := policy.Parse(strings.NewReader(policyTemplate)) + if err != nil { + return err + } + + out, err := policy.Print(d) + if err != nil { + return err + } + + if c.Filepath != "" { + ioutil.WriteFile(c.Filepath, []byte(out), 0644) + fmt.Printf("written %s\n", c.Filepath) + return nil + } + + fmt.Print(out) + return nil +} + +type audiencesValidatePoliciesCmd struct { + Filepath string `kong:"required,name='filepath',short='f',help='a file containing policy definitions'"` +} + +func (c *audiencesValidatePoliciesCmd) Run(cfg *Config) error { + f, err := os.Open(c.Filepath) + if err != nil { + return err + } + + _, err = policy.ParseNamed(f.Name(), f) + if err != nil { + return err + } + return nil } + +type audiencesDumpPoliciesCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + PolicyNames []string `kong:"name='policy-names',help='comma separated policy names to dump (default: all)'"` + Filepath string `kong:"name='filepath',short='f',help='optionnal filepath where to write the policies (default: stdout)'"` +} + +func (c *audiencesDumpPoliciesCmd) Run(cfg *Config) error { + aud, err := cfg.DB.GetAudience(context.Background(), c.AudienceURL) + if err != nil { + return err + } + + if len(aud.Policies) == 0 { + return fmt.Errorf("audience %s have no policy", c.AudienceURL) + } + + dumpPolicies := aud.Policies + if len(c.PolicyNames) > 0 { + dumpPolicies = make([]*hubauth.BiscuitPolicy, 0, len(c.PolicyNames)) + for _, p := range aud.Policies { + for _, name := range c.PolicyNames { + if name == p.Name { + dumpPolicies = append(dumpPolicies, p) + break + } + } + } + } + + aggContent := "" + for _, p := range dumpPolicies { + aggContent += p.Content + } + + doc, err := policy.Parse(strings.NewReader(aggContent)) + if err != nil { + return err + } + + out, err := policy.Print(doc) + if err != nil { + return err + } + + if c.Filepath != "" { + ioutil.WriteFile(c.Filepath, []byte(out), 0644) + fmt.Printf("written %d policies to %s\n", len(dumpPolicies), c.Filepath) + return nil + } + + fmt.Printf("%s", out) + return nil +} + +func parsePolicy(path string) (*policy.Document, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + doc, err := policy.ParseNamed(f.Name(), f) + if err != nil { + return nil, err + } + + return doc, nil +} diff --git a/pkg/cli/audiences_test.go b/pkg/cli/audiences_test.go index fff0b5f..8609f00 100644 --- a/pkg/cli/audiences_test.go +++ b/pkg/cli/audiences_test.go @@ -11,11 +11,14 @@ import ( "encoding/pem" "errors" "fmt" + "io/ioutil" "os" + "strings" "testing" "time" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/policy" "github.com/googleapis/gax-go/v2" "github.com/jedib0t/go-pretty/v6/table" "github.com/stretchr/testify/mock" @@ -77,7 +80,10 @@ func (m *mockAudienceDatastore) MutateAudience(ctx context.Context, url string, args := m.Called(ctx, url, mut) return args.Error(0) } - +func (m *mockAudienceDatastore) MutateAudiencePolicy(ctx context.Context, url string, policyName string, mut []*hubauth.AudiencePolicyMutation) error { + args := m.Called(ctx, url, policyName, mut) + return args.Error(0) +} func (m *mockAudienceDatastore) MutateAudienceUserGroups(ctx context.Context, url string, domain string, mut []*hubauth.AudienceUserGroupsMutation) error { args := m.Called(ctx, url, domain, mut) return args.Error(0) @@ -623,3 +629,364 @@ func TestAudienceUpdateTypeCmd(t *testing.T) { require.NoError(t, cmd.Run(cfg)) } + +func TestAudienceListPoliciesCmd(t *testing.T) { + cmd := &audiencesListPoliciesCmd{ + AudienceURL: "https://audience.url", + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + policy1Content := "// policy1 description\npolicy \"policy1\" {}" + policy2Content := "// policy2 description\npolicy \"policy2\" {}" + policy3Content := "policy \"policy3\" {}" + + audience := &hubauth.Audience{ + URL: cmd.AudienceURL, + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy1", + Content: policy1Content, + Groups: []string{"grp1", "grp2"}, + }, + { + Name: "policy2", + Content: policy2Content, + Groups: nil, + }, + { + Name: "policy3", + Content: policy3Content, + Groups: nil, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("GetAudience", mock.Anything, cmd.AudienceURL).Return(audience, nil) + + r, w, err := os.Pipe() + require.NoError(t, err) + origStdout := os.Stdout + os.Stdout = w + + require.NoError(t, cmd.Run(cfg)) + + os.Stdout = origStdout + + buf := make([]byte, 2048) + n, err := r.Read(buf) + require.NoError(t, err) + + expectedBuf := new(bytes.Buffer) + tw := table.NewWriter() + tw.SetOutputMirror(expectedBuf) + tw.AppendHeader(table.Row{"Name", "Groups", "Description"}) + for _, p := range audience.Policies { + tw.AppendRow(table.Row{p.Name, p.Groups, getFirstComment(p)}) + } + tw.Render() + + require.Equal(t, expectedBuf.String(), string(buf[:n])) +} + +func TestAudiencesSetPoliciesCmd(t *testing.T) { + policy1Content := `// policy1 +policy "policy1" { + rules { + // rule1 + *r1($a) <- f1($a) + } +}` + + policy2Content := `// policy2 +policy "policy2" {} +` + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencessetpoliciescmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + _, err = policyFile.WriteString(policy1Content) + require.NoError(t, err) + _, err = policyFile.WriteString(policy2Content) + require.NoError(t, err) + + groups := []string{"grp1", "grp2"} + + cmd := &audiencesSetPoliciesCmd{ + AudienceURL: "https://audience.url", + Filepath: policyFile.Name(), + Groups: groups, + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + policy1ContentFmt, err := policy.Format(strings.NewReader(policy1Content)) + require.NoError(t, err) + policy2ContentFmt, err := policy.Format(strings.NewReader(policy2Content)) + require.NoError(t, err) + + expectedMuts := []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy1", + Content: policy1ContentFmt, + Groups: groups, + }, + }, + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy2", + Content: policy2ContentFmt, + Groups: groups, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("MutateAudience", mock.Anything, cmd.AudienceURL, expectedMuts).Return(nil) + + require.NoError(t, cmd.Run(cfg)) +} + +func TestAudiencesUpdatePolicyCmd(t *testing.T) { + policy1Content := "// policy1\npolicy \"policy1\" {}" + + policy1ContentFmt, err := policy.Format(strings.NewReader(policy1Content)) + require.NoError(t, err) + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesupdatepoliciescmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + _, err = policyFile.WriteString(policy1Content) + require.NoError(t, err) + + cmd := &audiencesUpdatePolicyCmd{ + AudienceURL: "https://audience.url", + PolicyName: "policy1", + Filepath: policyFile.Name(), + AddGroups: []string{"grp1", "grp2"}, + DeleteGroups: []string{"grp3"}, + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + expectedMuts := []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp1", + }, + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp2", + }, + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp3", + }, + { + Op: hubauth.AudiencePolicyMutationOpSetContent, + Content: policy1ContentFmt, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("MutateAudiencePolicy", mock.Anything, cmd.AudienceURL, cmd.PolicyName, expectedMuts).Return(nil) + + require.NoError(t, cmd.Run(cfg)) + + cmd.PolicyName = "not-existing-policy" + require.Error(t, cmd.Run(cfg)) +} + +func TestAudiencesDeletePolicyCmd(t *testing.T) { + cmd := audiencesDeletePolicyCmd{ + AudienceURL: "https://audience.url", + PolicyName: "policy1", + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + expectedMuts := []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationDeletePolicy, + Policy: hubauth.BiscuitPolicy{ + Name: cmd.PolicyName, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("MutateAudience", mock.Anything, cmd.AudienceURL, expectedMuts).Return(nil) + + require.NoError(t, cmd.Run(cfg)) +} + +func TestAudiencesNewPolicyCmd(t *testing.T) { + cmd := &audiencesNewPolicyCmd{} + + r, w, err := os.Pipe() + require.NoError(t, err) + origStdout := os.Stdout + os.Stdout = w + + require.NoError(t, cmd.Run(&Config{})) + + os.Stdout = origStdout + + buf := make([]byte, 2048) + n, err := r.Read(buf) + require.NoError(t, err) + + policyTemplateFmt, err := policy.Format(strings.NewReader(policyTemplate)) + require.NoError(t, err) + + require.Equal(t, policyTemplateFmt, string(buf[:n])) + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesnewpolicycmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + cmd.Filepath = policyFile.Name() + require.NoError(t, cmd.Run(&Config{})) + + out, err := ioutil.ReadFile(policyFile.Name()) + require.NoError(t, err) + require.Equal(t, policyTemplateFmt, string(out)) +} + +func TestAudiencesValidatePoliciesCmd(t *testing.T) { + testCases := []struct { + desc string + content string + expectValid bool + }{ + { + desc: "valid policy", + content: `policy "p1" {}`, + expectValid: true, + }, + { + desc: "invalid policy", + content: `policy {}`, + expectValid: false, + }, + { + desc: "empty", + content: ``, + expectValid: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesvalidatepolicycmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + _, err = policyFile.WriteString(tc.content) + require.NoError(t, err) + + cmd := &audiencesValidatePoliciesCmd{ + Filepath: policyFile.Name(), + } + + err = cmd.Run(&Config{}) + if tc.expectValid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestAudiencesDumpPoliciesCmd(t *testing.T) { + cmd := audiencesDumpPoliciesCmd{ + AudienceURL: "https://audience.url", + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + p1Content := "policy \"p1\" {}" + p2Content := "policy \"p2\" {}" + p3Content := "policy \"p3\" {}" + + audience := &hubauth.Audience{ + URL: cmd.AudienceURL, + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "p1", + Content: p1Content, + }, + { + Name: "p2", + Content: p2Content, + }, + { + Name: "p3", + Content: p3Content, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("GetAudience", mock.Anything, cmd.AudienceURL).Return(audience, nil) + + r, w, err := os.Pipe() + require.NoError(t, err) + origStdout := os.Stdout + os.Stdout = w + + require.NoError(t, cmd.Run(cfg)) + + buf := make([]byte, 2048) + n, err := r.Read(buf) + require.NoError(t, err) + require.Equal(t, strings.Join([]string{p1Content, p2Content, p3Content}, "\n\n")+"\n", string(buf[:n])) + + cmd.PolicyNames = []string{"p1", "p3"} + require.NoError(t, cmd.Run(cfg)) + + os.Stdout = origStdout + + n, err = r.Read(buf) + require.NoError(t, err) + + expectedOut := strings.Join([]string{p1Content, p3Content}, "\n\n") + "\n" + require.Equal(t, expectedOut, string(buf[:n])) + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesdumppolicycmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + cmd.Filepath = policyFile.Name() + require.NoError(t, cmd.Run(cfg)) + + got, err := ioutil.ReadFile(policyFile.Name()) + require.NoError(t, err) + require.Equal(t, expectedOut, string(got)) +} diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 0739544..c772034 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -17,5 +17,4 @@ type CLI struct { Clients clientsCmd `kong:"cmd,help='manage oauth clients'"` Audiences audiencesCmd `kong:"cmd,help='manage audiences'"` - Policies policiesCmd `kong:"cmd,help='manage policies'"` } diff --git a/pkg/cli/policies.go b/pkg/cli/policies.go deleted file mode 100644 index e9f0a33..0000000 --- a/pkg/cli/policies.go +++ /dev/null @@ -1,219 +0,0 @@ -package cli - -import ( - "context" - "fmt" - "io/ioutil" - "os" - "strings" - - "github.com/flynn/hubauth/pkg/hubauth" - "github.com/flynn/hubauth/pkg/policy" - "github.com/jedib0t/go-pretty/v6/table" -) - -type policiesCmd struct { - New policiesNewCmd `kong:"cmd,help='dump a new empty policy document on stdout'"` - List policiesListCmd `kong:"cmd,help='list policies',default:'1'"` - Dump policiesDumpCmd `kong:"cmd,help='dump a policy content on stdout'"` - Validate policiesValidateCmd `kong:"cmd,help='validate a policy file'"` - Import policiesImportCmd `kong:"cmd,help='import a policy'"` - Update policiesUpdateCmd `kong:"cmd,help='update a policy'"` - Delete policiesDeleteCmd `kong:"cmd,help='delete a policy'"` -} - -type policiesNewCmd struct { - Filepath string `kong:"name='filepath',short='f',help='optionnal filepath where the policy is written (default: stdout)'"` -} - -func (c *policiesNewCmd) Run(cfg *Config) error { - template := `// This is a template policy - policy "dummy" { - rules { - // This is a dummy rule - *head($var1) - <- body1(#ambient, $name), - body2($value) - @ $name == "example" - } - - caveats {[ - // this is a dummy caveat - *head($var1) - <- body1(#ambient, $name), - body2($value) - @ $name == "example" - ]} - }` - - d, err := policy.Parse(strings.NewReader(template)) - if err != nil { - return err - } - - out, err := policy.Print(d) - if err != nil { - return err - } - - if c.Filepath != "" { - ioutil.WriteFile(c.Filepath, []byte(out), 0644) - fmt.Printf("written %s\n", c.Filepath) - return nil - } - - fmt.Println(out) - return nil -} - -type policiesListCmd struct{} - -func (c *policiesListCmd) Run(cfg *Config) error { - policies, err := cfg.DB.ListBiscuitPolicies(context.Background()) - if err != nil { - return err - } - t := table.NewWriter() - t.SetOutputMirror(os.Stdout) - t.AppendHeader(table.Row{"ID", "Description", "Groups", "CreateTime", "UpdateTime"}) - for _, p := range policies { - t.AppendRow(table.Row{p.ID, getPolicyFirstComment(p), p.Groups, p.CreateTime, p.UpdateTime}) - } - t.Render() - return nil -} - -type policiesDumpCmd struct { - PolicyIDs []string `kong:"name='policy-ids',help='comma separated policy IDs to dump (default: all)'"` -} - -func (c *policiesDumpCmd) Run(cfg *Config) error { - allPolicies := "" - if len(c.PolicyIDs) > 0 { - for _, id := range c.PolicyIDs { - p, err := cfg.DB.GetBiscuitPolicy(context.Background(), id) - if err != nil { - return err - } - - allPolicies += p.Content - } - } else { - policies, err := cfg.DB.ListBiscuitPolicies(context.Background()) - if err != nil { - return err - } - for _, p := range policies { - allPolicies += p.Content - } - } - - doc, err := policy.Parse(strings.NewReader(allPolicies)) - if err != nil { - return err - } - - out, err := policy.Print(doc) - if err != nil { - return err - } - fmt.Printf("%s\n", out) - return nil -} - -type policiesImportCmd struct { - Filepath string `kong:"required,name='filepath',short='f',help='a file containing policy definitions'"` -} - -func (c *policiesImportCmd) Run(cfg *Config) error { - content, err := ioutil.ReadFile(c.Filepath) - if err != nil { - return err - } - - doc, err := policy.Parse(strings.NewReader(string(content))) - if err != nil { - return err - } - - for _, p := range doc.Policies { - content := policy.PrintPolicy(p) - id, err := cfg.DB.CreateBiscuitPolicy(context.Background(), &hubauth.BiscuitPolicy{ - Content: string(content), - }) - if err != nil { - return err - } - - fmt.Printf("Imported policy %q: %s\n", *p.Name, id) - } - return nil -} - -type policiesValidateCmd struct { - Filepath string `kong:"required,name='filepath',short='f',help='a file containing policy definitions'"` -} - -func (c *policiesValidateCmd) Run(cfg *Config) error { - content, err := ioutil.ReadFile(c.Filepath) - if err != nil { - return err - } - - _, err = policy.ParseNamed(c.Filepath, strings.NewReader(string(content))) - if err != nil { - return err - } - - return nil -} - -type policiesUpdateCmd struct { - PolicyID string `kong:"required,name='policy-id',help='a policy ID to update'"` - AddGroups []string `kong:"name='add-groups',help='comma separated list of groups to add on the policy'"` - DeleteGroups []string `kong:"name='delete-groups',help='comma separated list of groups to delete from the policy'"` -} - -func (c *policiesUpdateCmd) Run(cfg *Config) error { - var mut []*hubauth.BiscuitPolicyMutation - for _, g := range c.AddGroups { - mut = append(mut, &hubauth.BiscuitPolicyMutation{ - Op: hubauth.BiscuitPolicyMutationOpAddGroup, - Group: g, - }) - } - for _, g := range c.DeleteGroups { - mut = append(mut, &hubauth.BiscuitPolicyMutation{ - Op: hubauth.BiscuitPolicyMutationOpDeleteGroup, - Group: g, - }) - } - if err := cfg.DB.MutateBiscuitPolicy(context.Background(), c.PolicyID, mut); err != nil { - return err - } - return nil -} - -type policiesDeleteCmd struct { - PolicyID string `kong:"required,name='policy-id',help='a policy ID to delete'"` -} - -func (c *policiesDeleteCmd) Run(cfg *Config) error { - return cfg.DB.DeleteBiscuitPolicy(context.Background(), c.PolicyID) -} - -// getPolicyFirstComment parse the policy content and returns the first policy -// comment line if it exists. On error or when not set, an empty string is returned. -func getPolicyFirstComment(p *hubauth.BiscuitPolicy) string { - doc, err := policy.Parse(strings.NewReader(p.Content)) - if err != nil { - return "" - } - if len(doc.Policies) == 0 { - return "" - } - if len(doc.Policies[0].Comments) == 0 { - return "" - } - return string(*doc.Policies[0].Comments[0]) -} diff --git a/pkg/policy/parser.go b/pkg/policy/parser.go index 86aaeb3..b3886aa 100644 --- a/pkg/policy/parser.go +++ b/pkg/policy/parser.go @@ -18,7 +18,7 @@ var policyLexer = stateful.MustSimple(append( )) type Document struct { - Policies []*DocumentPolicy `@@*` + Policies []*DocumentPolicy `@@+` } type DocumentPolicy struct { diff --git a/pkg/policy/printer.go b/pkg/policy/printer.go index 25433cd..8676784 100644 --- a/pkg/policy/printer.go +++ b/pkg/policy/printer.go @@ -2,11 +2,21 @@ package policy import ( "fmt" + "io" "strings" "github.com/flynn/biscuit-go/parser" ) +func Format(r io.Reader) (string, error) { + d, err := Parse(r) + if err != nil { + return "", err + } + + return Print(d) +} + func Print(d *Document) (string, error) { p := &printer{ indent: 0, From 7f0f2c2a386de65b5f3483eb5fcfcbc00432a392 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Wed, 16 Dec 2020 15:01:47 +0100 Subject: [PATCH 20/24] idp: add policy support to issued biscuits --- pkg/idp/token/biscuit.go | 12 ++-- pkg/idp/token/biscuit_test.go | 125 +++++++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 31 deletions(-) diff --git a/pkg/idp/token/biscuit.go b/pkg/idp/token/biscuit.go index 4743124..a4857b0 100644 --- a/pkg/idp/token/biscuit.go +++ b/pkg/idp/token/biscuit.go @@ -22,12 +22,12 @@ var ( type biscuitBuilder struct { kms kmssign.KMSClient - db hubauth.BiscuitPolicyStore + db hubauth.AudienceGetterStore audienceKey kmssign.AudienceKeyNamer rootKeyPair sig.Keypair } -func NewBiscuitBuilder(kms kmssign.KMSClient, db hubauth.BiscuitPolicyStore, audienceKey kmssign.AudienceKeyNamer, rootKeyPair sig.Keypair) AccessTokenBuilder { +func NewBiscuitBuilder(kms kmssign.KMSClient, db hubauth.AudienceGetterStore, audienceKey kmssign.AudienceKeyNamer, rootKeyPair sig.Keypair) AccessTokenBuilder { return &biscuitBuilder{ kms: kms, db: db, @@ -57,7 +57,7 @@ func (b *biscuitBuilder) Build(ctx context.Context, audience string, t *AccessTo } // retrieve policies from user groups and add each policy rules and caveats to the biscuit - userPolicies, err := b.getUserPolicies(ctx, t.UserGroups) + userPolicies, err := b.getUserPolicies(ctx, audience, t.UserGroups) if err != nil { return nil, err } @@ -94,14 +94,14 @@ func DecodeB64PrivateKey(b64key string) (sig.Keypair, error) { return kp, nil } -func (b *biscuitBuilder) getUserPolicies(ctx context.Context, userGroups []string) ([]*hubauth.BiscuitPolicy, error) { - allPolicies, err := b.db.ListBiscuitPolicies(ctx) +func (b *biscuitBuilder) getUserPolicies(ctx context.Context, audience string, userGroups []string) ([]*hubauth.BiscuitPolicy, error) { + aud, err := b.db.GetAudience(ctx, audience) if err != nil { return nil, err } var userPolicies []*hubauth.BiscuitPolicy - for _, p := range allPolicies { + for _, p := range aud.Policies { outer: for _, g := range p.Groups { for _, ug := range userGroups { diff --git a/pkg/idp/token/biscuit_test.go b/pkg/idp/token/biscuit_test.go index a32a1d7..0129b59 100644 --- a/pkg/idp/token/biscuit_test.go +++ b/pkg/idp/token/biscuit_test.go @@ -7,53 +7,40 @@ import ( "crypto/rand" "crypto/x509" "encoding/base64" + "encoding/pem" "testing" "time" "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/cookbook/signedbiscuit" "github.com/flynn/biscuit-go/sig" "github.com/flynn/hubauth/pkg/hubauth" "github.com/flynn/hubauth/pkg/kmssign/kmssim" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/cloud/kms/v1" ) -type mockPolicyStore struct { +type mockAudienceGetterStore struct { mock.Mock } -var _ hubauth.BiscuitPolicyStore = (*mockPolicyStore)(nil) - -func (m *mockPolicyStore) GetBiscuitPolicy(ctx context.Context, id string) (*hubauth.BiscuitPolicy, error) { - args := m.Called(ctx, id) - return args.Get(0).(*hubauth.BiscuitPolicy), args.Error(1) -} -func (m *mockPolicyStore) CreateBiscuitPolicy(ctx context.Context, policy *hubauth.BiscuitPolicy) (string, error) { - args := m.Called(ctx, policy) - return args.String(0), args.Error(1) -} -func (m *mockPolicyStore) MutateBiscuitPolicy(ctx context.Context, id string, mut []*hubauth.BiscuitPolicyMutation) error { - args := m.Called(ctx, id, mut) - return args.Error(0) -} -func (m *mockPolicyStore) ListBiscuitPolicies(ctx context.Context) ([]*hubauth.BiscuitPolicy, error) { - args := m.Called(ctx) - return args.Get(1).([]*hubauth.BiscuitPolicy), args.Error(1) -} -func (m *mockPolicyStore) DeleteBiscuitPolicy(ctx context.Context, id string) error { - args := m.Called(ctx, id) - return args.Error(0) +func (m *mockAudienceGetterStore) GetAudience(ctx context.Context, url string) (*hubauth.Audience, error) { + args := m.Called(ctx, url) + return args.Get(0).(*hubauth.Audience), args.Error(1) } +var _ hubauth.AudienceGetterStore = (*mockAudienceGetterStore)(nil) + func TestBiscuitBuilder(t *testing.T) { audience := "https://audience.url" audienceKeyName := audienceKeyNamer(audience) kmsClient := kmssim.NewClient([]string{audienceKeyName}) rootKeyPair := sig.GenerateKeypair(rand.Reader) - policyStore := new(mockPolicyStore) + audienceGetterStore := new(mockAudienceGetterStore) - builder := NewBiscuitBuilder(kmsClient, policyStore, audienceKeyNamer, rootKeyPair) + builder := NewBiscuitBuilder(kmsClient, audienceGetterStore, audienceKeyNamer, rootKeyPair) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) @@ -61,26 +48,114 @@ func TestBiscuitBuilder(t *testing.T) { require.NoError(t, err) now := time.Now() + userGroups := []string{"grp1", "grp2"} + accessTokenData := &AccessTokenData{ ClientID: "clientID", ExpireTime: now.Add(1 * time.Minute), IssueTime: now, UserEmail: "user@email", + UserGroups: userGroups, UserID: "userID", } _, err = builder.Build(context.Background(), audience, accessTokenData) require.Equal(t, ErrPublicKeyRequired, err) accessTokenData.UserPublicKey = userPublicKey + + p1Content := ` + policy "p1" { + caveats {[ + *valid() <- test(#ambient, "policy1exists") + ]} + } + ` + + p2Content := ` + policy "p2" { + rules { + *test(#authority, $inputStr) + <- testRule(#ambient, $inputStr) + } + caveats {[ + *valid() <- test(#authority, "policy2exists") + ]} + } + ` + + p3Content := ` + policy "p3" { + caveats {[ + *valid() <- test(#ambient, "policy3exists") + ]} + } + ` + + aud := &hubauth.Audience{ + URL: audience, + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "p1", + Content: p1Content, + Groups: []string{"grp1"}, + }, + { + Name: "p2", + Content: p2Content, + Groups: []string{"grp2", "grp3"}, + }, + { + Name: "p3", + Content: p3Content, + Groups: []string{"grp3"}, + }, + }, + } + audienceGetterStore.On("GetAudience", mock.Anything, audience).Return(aud, nil) + token, err := builder.Build(context.Background(), audience, accessTokenData) require.NoError(t, err) require.NotEmpty(t, token) + userKeyPair, err := signedbiscuit.NewECDSAKeyPair(priv) + require.NoError(t, err) + token, err = signedbiscuit.Sign(token, rootKeyPair.Public(), userKeyPair) + require.NoError(t, err) + b, err := biscuit.Unmarshal(token) require.NoError(t, err) - _, err = b.Verify(rootKeyPair.Public()) + verifier, err := b.Verify(rootKeyPair.Public()) + require.NoError(t, err) + + kmsPubkey, err := kmsClient.GetPublicKey(context.Background(), &kms.GetPublicKeyRequest{Name: audienceKeyName}) + require.NoError(t, err) + pemBlock, _ := pem.Decode([]byte(kmsPubkey.Pem)) + audiencePubKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes) + require.NoError(t, err) + + verifier, metas, err := signedbiscuit.WithSignatureVerification(verifier, audience, audiencePubKey.(*ecdsa.PublicKey)) require.NoError(t, err) + + require.Equal(t, accessTokenData.ClientID, metas.ClientID) + require.Equal(t, accessTokenData.UserEmail, metas.UserEmail) + require.Equal(t, accessTokenData.UserGroups, metas.UserGroups) + require.Equal(t, accessTokenData.UserID, metas.UserID) + require.Equal(t, accessTokenData.IssueTime.Unix(), metas.IssueTime.Unix()) + + require.Error(t, verifier.Verify()) + + verifier.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "test", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.String("policy1exists")}, + }}) + require.Error(t, verifier.Verify()) + + verifier.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "testRule", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.String("policy2exists")}, + }}) + require.NoError(t, verifier.Verify()) } func TestDecodeB64PrivateKey(t *testing.T) { From c01bfcdd81e4df2b097ccc72b05f0d8cd67c161e Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Wed, 16 Dec 2020 15:44:13 +0100 Subject: [PATCH 21/24] fix bad rebase --- pkg/cli/audiences.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/pkg/cli/audiences.go b/pkg/cli/audiences.go index 281fbde..3dc11a5 100644 --- a/pkg/cli/audiences.go +++ b/pkg/cli/audiences.go @@ -219,8 +219,43 @@ func (c *audiencesSetUserGroupsCmd) Run(cfg *Config) error { } return cfg.DB.MutateAudience(context.Background(), c.AudienceURL, []*hubauth.AudienceMutation{mut}) } + +type audiencesUpdateUserGroupsCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + Domain string `kong:"required,help='G Suite domain name'"` + APIUser string `kong:"name='api-user',help='G Suite user email to impersonate for API calls'"` + AddGroups []string `kong:"name='add-groups',help='comma-separated group IDs to add'"` + DeleteGroups []string `kong:"name='delete-groups',help='comma-separated group IDs to delete'"` +} + +func (c *audiencesUpdateUserGroupsCmd) Run(cfg *Config) error { + var muts []*hubauth.AudienceUserGroupsMutation + for _, groupID := range c.AddGroups { + muts = append(muts, &hubauth.AudienceUserGroupsMutation{ + Op: hubauth.AudienceUserGroupsMutationOpAddGroup, + Group: groupID, + }) + } + for _, groupID := range c.DeleteGroups { + muts = append(muts, &hubauth.AudienceUserGroupsMutation{ + Op: hubauth.AudienceUserGroupsMutationOpDeleteGroup, + Group: groupID, + }) + } + if c.APIUser != "" { + muts = append(muts, &hubauth.AudienceUserGroupsMutation{ + Op: hubauth.AudienceUserGroupsMutationOpSetAPIUser, + APIUser: c.APIUser, }) + } + + return cfg.DB.MutateAudienceUserGroups(context.Background(), c.AudienceURL, c.Domain, muts) +} +type audiencesDeleteUserGroupsCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + Domain string `kong:"required,help='G Suite domain name'"` +} func (c *audiencesDeleteUserGroupsCmd) Run(cfg *Config) error { mut := &hubauth.AudienceMutation{ From af770442b9326fe7890face5554062a7c4d2de3f Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Wed, 16 Dec 2020 15:48:51 +0100 Subject: [PATCH 22/24] bump biscuit version, remove replace directive --- go.mod | 4 +--- go.sum | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 72bb495..7cb5039 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 github.com/aws/aws-sdk-go v1.36.7 // indirect github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect - github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195 + github.com/flynn/biscuit-go v0.0.0-20201211135022-dbd2f8863bf4 github.com/golang/protobuf v1.4.3 github.com/googleapis/gax-go/v2 v2.0.5 github.com/jedib0t/go-pretty/v6 v6.0.5 @@ -30,5 +30,3 @@ require ( google.golang.org/grpc v1.34.0 google.golang.org/protobuf v1.25.0 ) - -replace github.com/flynn/biscuit-go => /home/flavien/workspace/flynn-ws/biscuit-go diff --git a/go.sum b/go.sum index a397178..61cea50 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/flynn/biscuit-go v0.0.0-20201211135022-dbd2f8863bf4 h1:5TqasLkkptxZIP8TNawz76F+vMSf04Mab8/d8VdJWus= +github.com/flynn/biscuit-go v0.0.0-20201211135022-dbd2f8863bf4/go.mod h1:mY0paJD7nJ1hsxNzqHOKES2u+asFldh25X8WkveoMaw= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= From 4ebd62036488c288a10da167ad9f4ad23ede8d9b Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 17 Dec 2020 15:36:03 +0100 Subject: [PATCH 23/24] remove spurious newline --- pkg/datastore/audience_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/datastore/audience_test.go b/pkg/datastore/audience_test.go index 6ae7b34..abfefa1 100644 --- a/pkg/datastore/audience_test.go +++ b/pkg/datastore/audience_test.go @@ -858,7 +858,6 @@ func TestMutateAudiencePolicy(t *testing.T) { ctx := context.Background() for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - aud := &hubauth.Audience{ URL: "https://cluster.mutate.example.com", Policies: tt.before, From e93c2d947032dd73e591f1d0d0bc0fe6d7b1f2a1 Mon Sep 17 00:00:00 2001 From: Flavien Binet Date: Thu, 17 Dec 2020 15:36:35 +0100 Subject: [PATCH 24/24] fix cancelled context when building access token --- pkg/idp/oauth.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/idp/oauth.go b/pkg/idp/oauth.go index 6317ec1..b329095 100644 --- a/pkg/idp/oauth.go +++ b/pkg/idp/oauth.go @@ -322,7 +322,7 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan } } - accessToken, tokenType, err = s.steps.BuildAccessToken(ctx, req.Audience, &token.AccessTokenData{ + accessToken, tokenType, err = s.steps.BuildAccessToken(parentCtx, req.Audience, &token.AccessTokenData{ ClientID: req.ClientID, UserID: codeInfo.UserId, UserEmail: codeInfo.UserEmail, @@ -358,13 +358,13 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan return res, nil } -func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshTokenRequest) (*hubauth.AccessToken, error) { - oldToken, err := s.decodeRefreshToken(ctx, req.RefreshToken) +func (s *idpService) RefreshToken(parentCtx context.Context, req *hubauth.RefreshTokenRequest) (*hubauth.AccessToken, error) { + oldToken, err := s.decodeRefreshToken(parentCtx, req.RefreshToken) if err != nil { return nil, err } - g, ctx := errgroup.WithContext(ctx) + g, ctx := errgroup.WithContext(parentCtx) var userGroups []string g.Go(func() (err error) { @@ -411,7 +411,7 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken } } - accessToken, tokenType, err = s.steps.BuildAccessToken(ctx, req.Audience, &token.AccessTokenData{ + accessToken, tokenType, err = s.steps.BuildAccessToken(parentCtx, req.Audience, &token.AccessTokenData{ ClientID: req.ClientID, UserID: oldToken.UserID, UserEmail: oldToken.UserEmail,