From 8c84ad5bb068ec6e956bf69e52a58a813c548c4f Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 5 Dec 2025 12:59:02 +1300 Subject: [PATCH 01/15] move identities out of state package - "go vet" passes --- client/identities.go | 12 +- internals/daemon/access.go | 10 +- internals/daemon/access_test.go | 26 +- internals/daemon/api_identities.go | 19 +- internals/daemon/api_identities_test.go | 99 +++-- internals/daemon/api_notices.go | 7 +- internals/daemon/api_notices_test.go | 55 +-- internals/daemon/daemon.go | 14 +- internals/daemon/daemon_test.go | 69 +-- .../{state => identities}/identities.go | 190 +++++--- .../{state => identities}/identities_test.go | 414 ++++++++++-------- internals/overlord/overlord.go | 38 +- internals/overlord/pairingstate/manager.go | 23 +- .../overlord/pairingstate/manager_test.go | 126 +++--- .../overlord/pairingstate/package_test.go | 13 +- internals/overlord/state/state.go | 104 +---- 16 files changed, 655 insertions(+), 564 deletions(-) rename internals/overlord/{state => identities}/identities.go (71%) rename internals/overlord/{state => identities}/identities_test.go (61%) diff --git a/client/identities.go b/client/identities.go index a3aedf221..78a1b0beb 100644 --- a/client/identities.go +++ b/client/identities.go @@ -23,7 +23,7 @@ import ( // Identity holds the configuration of a single identity. type Identity struct { - Access IdentityAccess `json:"access" yaml:"access"` + Access Access `json:"access" yaml:"access"` // One or more of the following type-specific configuration fields must be // non-nil. @@ -32,13 +32,13 @@ type Identity struct { Cert *CertIdentity `json:"cert,omitempty" yaml:"cert,omitempty"` } -// IdentityAccess defines the access level for an identity. -type IdentityAccess string +// Access defines the access level for an identity. +type Access string const ( - AdminAccess IdentityAccess = "admin" - ReadAccess IdentityAccess = "read" - UntrustedAccess IdentityAccess = "untrusted" + AdminAccess Access = "admin" + ReadAccess Access = "read" + UntrustedAccess Access = "untrusted" ) // LocalIdentity holds identity configuration specific to the "local" type diff --git a/internals/daemon/access.go b/internals/daemon/access.go index 556998c2c..56f257053 100644 --- a/internals/daemon/access.go +++ b/internals/daemon/access.go @@ -17,7 +17,7 @@ package daemon import ( "net/http" - "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/overlord/identities" ) const ( @@ -52,7 +52,7 @@ func (ac AdminAccess) CheckAccess(d *Daemon, r *http.Request, user *UserState) R // Not Unix Domain Socket or HTTPS. return Unauthorized(accessDenied) } - if user.Access == state.AdminAccess { + if user.Access == identities.AdminAccess { return nil } // An identity explicitly set to "access: read" or "access: untrusted" isn't allowed. @@ -73,7 +73,7 @@ func (ac UserAccess) CheckAccess(d *Daemon, r *http.Request, user *UserState) Re return Unauthorized(accessDenied) } switch user.Access { - case state.ReadAccess, state.AdminAccess: + case identities.ReadAccess, identities.AdminAccess: return nil } // An identity explicitly set to "access: untrusted" isn't allowed. @@ -95,7 +95,7 @@ func (ac MetricsAccess) CheckAccess(d *Daemon, r *http.Request, user *UserState) // HTTP access (only basic auth is possible here, so no need to // check with identity type). transport := RequestTransportType(r) - if transport == TransportTypeHTTP && user.Access == state.MetricsAccess { + if transport == TransportTypeHTTP && user.Access == identities.MetricsAccess { return nil } if !transport.IsConcealed() { @@ -103,7 +103,7 @@ func (ac MetricsAccess) CheckAccess(d *Daemon, r *http.Request, user *UserState) return Unauthorized(accessDenied) } switch user.Access { - case state.MetricsAccess, state.ReadAccess, state.AdminAccess: + case identities.MetricsAccess, identities.ReadAccess, identities.AdminAccess: return nil default: // All other access levels, including "access: untrusted", are denied. diff --git a/internals/daemon/access_test.go b/internals/daemon/access_test.go index 6bbd9fd87..24378c883 100644 --- a/internals/daemon/access_test.go +++ b/internals/daemon/access_test.go @@ -22,7 +22,7 @@ import ( . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/daemon" - "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/overlord/identities" ) type accessSuite struct{} @@ -54,7 +54,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: UntrustedAccess apiSource: daemon.TransportTypeUnixSocket, - user: &daemon.UserState{Access: state.UntrustedAccess}, + user: &daemon.UserState{Access: identities.UntrustedAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -63,7 +63,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: MetricsAccess apiSource: daemon.TransportTypeUnixSocket, - user: &daemon.UserState{Access: state.MetricsAccess}, + user: &daemon.UserState{Access: identities.MetricsAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -72,7 +72,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: ReadAccess apiSource: daemon.TransportTypeUnixSocket, - user: &daemon.UserState{Access: state.ReadAccess}, + user: &daemon.UserState{Access: identities.ReadAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: nil, @@ -81,7 +81,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: AdminAccess apiSource: daemon.TransportTypeUnixSocket, - user: &daemon.UserState{Access: state.AdminAccess}, + user: &daemon.UserState{Access: identities.AdminAccess}, openCheckErr: nil, adminCheckErr: nil, userCheckErr: nil, @@ -101,7 +101,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: UntrustedAccess apiSource: daemon.TransportTypeHTTP, - user: &daemon.UserState{Access: state.UntrustedAccess}, + user: &daemon.UserState{Access: identities.UntrustedAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -110,7 +110,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: MetricsAccess apiSource: daemon.TransportTypeHTTP, - user: &daemon.UserState{Access: state.MetricsAccess}, + user: &daemon.UserState{Access: identities.MetricsAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -119,7 +119,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: ReadAccess apiSource: daemon.TransportTypeHTTP, - user: &daemon.UserState{Access: state.ReadAccess}, + user: &daemon.UserState{Access: identities.ReadAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -128,7 +128,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: AdminAccess apiSource: daemon.TransportTypeHTTP, - user: &daemon.UserState{Access: state.AdminAccess}, + user: &daemon.UserState{Access: identities.AdminAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -148,7 +148,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: UntrustedAccess apiSource: daemon.TransportTypeHTTPS, - user: &daemon.UserState{Access: state.UntrustedAccess}, + user: &daemon.UserState{Access: identities.UntrustedAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -157,7 +157,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: MetricsAccess apiSource: daemon.TransportTypeHTTPS, - user: &daemon.UserState{Access: state.MetricsAccess}, + user: &daemon.UserState{Access: identities.MetricsAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: errUnauthorized, @@ -166,7 +166,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: ReadAccess apiSource: daemon.TransportTypeHTTPS, - user: &daemon.UserState{Access: state.ReadAccess}, + user: &daemon.UserState{Access: identities.ReadAccess}, openCheckErr: nil, adminCheckErr: errUnauthorized, userCheckErr: nil, @@ -175,7 +175,7 @@ func (s *accessSuite) TestAccess(c *C) { }, { // User access: AdminAccess apiSource: daemon.TransportTypeHTTPS, - user: &daemon.UserState{Access: state.AdminAccess}, + user: &daemon.UserState{Access: identities.AdminAccess}, openCheckErr: nil, adminCheckErr: nil, userCheckErr: nil, diff --git a/internals/daemon/api_identities.go b/internals/daemon/api_identities.go index 36feeced9..eb432a1a7 100644 --- a/internals/daemon/api_identities.go +++ b/internals/daemon/api_identities.go @@ -20,7 +20,7 @@ import ( "net/http" "github.com/canonical/pebble/internals/logger" - "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/overlord/identities" ) func v1GetIdentities(c *Command, r *http.Request, _ *UserState) Response { @@ -28,14 +28,15 @@ func v1GetIdentities(c *Command, r *http.Request, _ *UserState) Response { st.Lock() defer st.Unlock() - identities := st.Identities() + identitiesMgr := c.d.overlord.IdentitiesManager() + identities := identitiesMgr.Identities() return SyncResponse(identities) } func v1PostIdentities(c *Command, r *http.Request, user *UserState) Response { var payload struct { - Action string `json:"action"` - Identities map[string]*state.Identity `json:"identities"` + Action string `json:"action"` + Identities map[string]*identities.Identity `json:"identities"` } decoder := json.NewDecoder(r.Body) if err := decoder.Decode(&payload); err != nil { @@ -68,6 +69,8 @@ func v1PostIdentities(c *Command, r *http.Request, user *UserState) Response { st.Lock() defer st.Unlock() + identitiesMgr := c.d.overlord.IdentitiesManager() + var err error switch payload.Action { case "add": @@ -76,14 +79,14 @@ func v1PostIdentities(c *Command, r *http.Request, user *UserState) Response { fmt.Sprintf("%s,%s,%s", userString(user), name, identity.Access), fmt.Sprintf("Creating %s user %s", identity.Access, name)) } - err = st.AddIdentities(payload.Identities) + err = identitiesMgr.AddIdentities(payload.Identities) case "update": for name, identity := range payload.Identities { logger.SecurityWarn(logger.SecurityUserUpdated, fmt.Sprintf("%s,%s,%s", userString(user), name, identity.Access), fmt.Sprintf("Updating %s user %s", identity.Access, name)) } - err = st.UpdateIdentities(payload.Identities) + err = identitiesMgr.UpdateIdentities(payload.Identities) case "replace": for name, identity := range payload.Identities { if identity == nil { @@ -96,14 +99,14 @@ func v1PostIdentities(c *Command, r *http.Request, user *UserState) Response { fmt.Sprintf("Updating %s user %s", identity.Access, name)) } } - err = st.ReplaceIdentities(payload.Identities) + err = identitiesMgr.ReplaceIdentities(payload.Identities) case "remove": for name := range payload.Identities { logger.SecurityWarn(logger.SecurityUserDeleted, fmt.Sprintf("%s,%s", userString(user), name), fmt.Sprintf("Deleting user %s", name)) } - err = st.RemoveIdentities(identityNames) + err = identitiesMgr.RemoveIdentities(identityNames) } if err != nil { return BadRequest("%v", err) diff --git a/internals/daemon/api_identities_test.go b/internals/daemon/api_identities_test.go index 8942b2d40..23067d723 100644 --- a/internals/daemon/api_identities_test.go +++ b/internals/daemon/api_identities_test.go @@ -22,22 +22,25 @@ import ( . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/logger" - "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/overlord/identities" ) func (s *apiSuite) TestIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() + identitiesMgr, err := identities.NewManager(st) + c.Assert(err, IsNil) st.Lock() - err := st.AddIdentities(map[string]*state.Identity{ + + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, IsNil) @@ -51,7 +54,7 @@ func (s *apiSuite) TestIdentities(c *C) { c.Check(rsp.Type, Equals, ResponseTypeSync) c.Check(rsp.Status, Equals, http.StatusOK) - identities, ok := rsp.Result.(map[string]*state.Identity) + identities, ok := rsp.Result.(map[string]*identities.Identity) c.Assert(ok, Equals, true) data, err := json.MarshalIndent(identities, "", " ") @@ -102,18 +105,20 @@ func (s *apiSuite) TestAddIdentities(c *C) { c.Check(rsp.Status, Equals, http.StatusOK) st := s.d.overlord.State() + identitiesMgr, err := identities.NewManager(st) + c.Assert(err, IsNil) st.Lock() - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := identitiesMgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bob": { Name: "bob", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { Name: "mary", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) st.Unlock() @@ -147,15 +152,17 @@ func (s *apiSuite) TestUpdateIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() + identitiesMgr, err := identities.NewManager(st) + c.Assert(err, IsNil) st.Lock() - err := st.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, IsNil) @@ -184,17 +191,17 @@ func (s *apiSuite) TestUpdateIdentities(c *C) { c.Check(rsp.Status, Equals, http.StatusOK) st.Lock() - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := identitiesMgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bob": { Name: "bob", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "mary": { Name: "mary", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, }) st.Unlock() @@ -228,15 +235,17 @@ func (s *apiSuite) TestReplaceIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() + identitiesMgr, err := identities.NewManager(st) + c.Assert(err, IsNil) st.Lock() - err := st.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, IsNil) @@ -266,17 +275,17 @@ func (s *apiSuite) TestReplaceIdentities(c *C) { c.Check(rsp.Status, Equals, http.StatusOK) st.Lock() - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := identitiesMgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "mary": { Name: "mary", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "newguy": { Name: "newguy", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 44}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 44}, }, }) st.Unlock() @@ -293,15 +302,17 @@ func (s *apiSuite) TestRemoveIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() + identitiesMgr, err := identities.NewManager(st) + c.Assert(err, IsNil) st.Lock() - err := st.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, IsNil) @@ -319,12 +330,12 @@ func (s *apiSuite) TestRemoveIdentities(c *C) { c.Check(rsp.Status, Equals, http.StatusOK) st.Lock() - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := identitiesMgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "mary": { Name: "mary", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) st.Unlock() diff --git a/internals/daemon/api_notices.go b/internals/daemon/api_notices.go index 8cd274508..2a4ccb496 100644 --- a/internals/daemon/api_notices.go +++ b/internals/daemon/api_notices.go @@ -26,6 +26,7 @@ import ( "github.com/canonical/x-go/strutil" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/state" ) @@ -54,7 +55,7 @@ func v1GetNotices(c *Command, r *http.Request, user *UserState) Response { query := r.URL.Query() if len(query["user-id"]) > 0 { - if user.Access != state.AdminAccess { + if user.Access != identities.AdminAccess { return Forbidden(`only admins may use the "user-id" filter`) } var err error @@ -65,7 +66,7 @@ func v1GetNotices(c *Command, r *http.Request, user *UserState) Response { } if len(query["users"]) > 0 { - if user.Access != state.AdminAccess { + if user.Access != identities.AdminAccess { return Forbidden(`only admins may use the "users" filter`) } if len(query["user-id"]) > 0 { @@ -263,7 +264,7 @@ func noticeViewableByUser(notice *state.Notice, user *UserState) bool { // Notice has no UID, so it's viewable by any user (with a UID). return true } - if user.Access == state.AdminAccess { + if user.Access == identities.AdminAccess { // User is admin, they can view anything. return true } diff --git a/internals/daemon/api_notices_test.go b/internals/daemon/api_notices_test.go index 607e9f869..8b958baea 100644 --- a/internals/daemon/api_notices_test.go +++ b/internals/daemon/api_notices_test.go @@ -27,6 +27,7 @@ import ( . "gopkg.in/check.v1" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/state" ) @@ -86,7 +87,7 @@ func (s *apiSuite) testNoticesFilter(c *C, makeQuery func(after time.Time) url.V req, err := http.NewRequest("GET", "/v1/notices?"+query.Encode(), nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -135,7 +136,7 @@ func (s *apiSuite) TestNoticesFilterMultipleTypes(c *C) { req, err := http.NewRequest("GET", "/v1/notices?types=change-update&types=warning,warning", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -164,7 +165,7 @@ func (s *apiSuite) TestNoticesFilterMultipleKeys(c *C) { req, err := http.NewRequest("GET", "/v1/notices?keys=a.b/x&keys=danger", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -192,7 +193,7 @@ func (s *apiSuite) TestNoticesFilterInvalidTypes(c *C) { req, err := http.NewRequest("GET", "/v1/notices?types=foo&types=warning&types=bar,baz", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -208,7 +209,7 @@ func (s *apiSuite) TestNoticesFilterInvalidTypes(c *C) { req, err = http.NewRequest("GET", "/v1/notices?types=foo&types=bar,baz", nil) c.Assert(err, IsNil) noticesCmd = apiCmd("/v1/notices") - rsp, ok = noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok = noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -240,7 +241,7 @@ func (s *apiSuite) TestNoticesUserIDAdminDefault(c *C) { // Test that admin user sees their own and all public notices if no filter is specified req, err := http.NewRequest("GET", "/v1/notices", nil) c.Assert(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -282,7 +283,7 @@ func (s *apiSuite) TestNoticesUserIDAdminFilter(c *C) { reqUrl := fmt.Sprintf("/v1/notices?%s", userIDValues.Encode()) req, err := http.NewRequest("GET", reqUrl, nil) c.Assert(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -319,7 +320,7 @@ func (s *apiSuite) TestNoticesUserIDNonAdminDefault(c *C) { // Test that non-admin user by default only sees their notices and public notices. req, err := http.NewRequest("GET", "/v1/notices", nil) c.Assert(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -350,7 +351,7 @@ func (s *apiSuite) TestNoticesUserIDNonAdminFilter(c *C) { reqUrl := "/v1/notices?user-id=1000" req, err := http.NewRequest("GET", reqUrl, nil) c.Assert(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -382,7 +383,7 @@ func (s *apiSuite) TestNoticesUsersAdminFilter(c *C) { reqUrl := "/v1/notices?users=all" req, err := http.NewRequest("GET", reqUrl, nil) c.Check(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -419,7 +420,7 @@ func (s *apiSuite) TestNoticesUsersNonAdminFilter(c *C) { reqUrl := "/v2/notices?users=all" req, err := http.NewRequest("GET", reqUrl, nil) c.Check(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -441,7 +442,7 @@ func (s *apiSuite) TestNoticesUnknownRequestUID(c *C) { // Test that a connection with unknown UID is forbidden from receiving notices req, err := http.NewRequest("GET", "/v1/notices", nil) c.Assert(err, IsNil) - rsp, ok := noticesCmd.GET(noticesCmd, req, &UserState{Access: state.ReadAccess}).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, &UserState{Access: identities.ReadAccess}).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -464,7 +465,7 @@ func (s *apiSuite) TestNoticesWait(c *C) { req, err := http.NewRequest("GET", "/v1/notices?timeout=1s", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -484,7 +485,7 @@ func (s *apiSuite) TestNoticesTimeout(c *C) { req, err := http.NewRequest("GET", "/v1/notices?timeout=1ms", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -508,7 +509,7 @@ func (s *apiSuite) TestNoticesRequestCancelled(c *C) { req, err := http.NewRequestWithContext(ctx, "GET", "/v1/notices?timeout=1s", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -560,7 +561,7 @@ func (s *apiSuite) testNoticesBadRequest(c *C, query, errorMatch string) { req, err := http.NewRequest("GET", "/v1/notices?"+query, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -584,7 +585,7 @@ func (s *apiSuite) TestAddNotice(c *C) { req, err := http.NewRequest("POST", "/v1/notices", bytes.NewReader(body)) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.POST(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.POST(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -638,7 +639,7 @@ func (s *apiSuite) TestAddNoticeMinimal(c *C) { req, err := http.NewRequest("POST", "/v1/notices", bytes.NewReader(body)) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.POST(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.POST(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -680,7 +681,7 @@ func (s *apiSuite) TestAddNoticeInvalidRequestUid(c *C) { req, err := http.NewRequest("POST", "/v1/notices", bytes.NewReader(body)) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.POST(noticesCmd, req, &UserState{Access: state.ReadAccess}).(*resp) + rsp, ok := noticesCmd.POST(noticesCmd, req, &UserState{Access: identities.ReadAccess}).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -737,7 +738,7 @@ func (s *apiSuite) testAddNoticeBadRequest(c *C, body, errorMatch string) { req, err := http.NewRequest("POST", "/v1/notices", strings.NewReader(body)) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices") - rsp, ok := noticesCmd.POST(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.POST(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -765,7 +766,7 @@ func (s *apiSuite) TestNotice(c *C) { c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") s.vars = map[string]string{"id": noticeIDPublic} - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -781,7 +782,7 @@ func (s *apiSuite) TestNotice(c *C) { c.Assert(err, IsNil) noticesCmd = apiCmd("/v1/notices/{id}") s.vars = map[string]string{"id": noticeIDPrivate} - rsp, ok = noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok = noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -801,7 +802,7 @@ func (s *apiSuite) TestNoticeNotFound(c *C) { c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") s.vars = map[string]string{"id": "1234"} - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -817,7 +818,7 @@ func (s *apiSuite) TestNoticeUnknownRequestUID(c *C) { c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") s.vars = map[string]string{"id": "1234"} - rsp, ok := noticesCmd.GET(noticesCmd, req, &UserState{Access: state.ReadAccess}).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, &UserState{Access: identities.ReadAccess}).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -840,7 +841,7 @@ func (s *apiSuite) TestNoticeAdminAllowed(c *C) { c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") s.vars = map[string]string{"id": noticeID} - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeSync) @@ -867,7 +868,7 @@ func (s *apiSuite) TestNoticeNonAdminNotAllowed(c *C) { c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") s.vars = map[string]string{"id": noticeID} - rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1001)).(*resp) + rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1001)).(*resp) c.Assert(ok, Equals, true) c.Check(rsp.Type, Equals, ResponseTypeError) @@ -890,6 +891,6 @@ func addNotice(c *C, st *state.State, userID *uint32, noticeType state.NoticeTyp c.Assert(err, IsNil) } -func userState(access state.IdentityAccess, uid uint32) *UserState { +func userState(access identities.IdentityAccess, uid uint32) *UserState { return &UserState{Access: access, UID: &uid} } diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index d14f9e69c..b6f1cfa5f 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -42,6 +42,7 @@ import ( "github.com/canonical/pebble/internals/osutil" "github.com/canonical/pebble/internals/overlord" "github.com/canonical/pebble/internals/overlord/checkstate" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/restart" "github.com/canonical/pebble/internals/overlord/servstate" "github.com/canonical/pebble/internals/overlord/standby" @@ -194,7 +195,7 @@ type Daemon struct { // UserState represents the state of an authenticated API user. type UserState struct { - Access state.IdentityAccess + Access identities.IdentityAccess UID *uint32 Username string } @@ -218,7 +219,7 @@ type Command struct { d *Daemon } -func userFromRequest(st *state.State, r *http.Request, ucred *Ucrednet) *UserState { +func userFromRequest(st *state.State, identitiesMgr *identities.Manager, r *http.Request, ucred *Ucrednet) *UserState { // Does the connection include a single mTLS client identity // certificate? var clientCert *x509.Certificate @@ -242,7 +243,7 @@ func userFromRequest(st *state.State, r *http.Request, ucred *Ucrednet) *UserSta } st.Lock() - identity := st.IdentityFromInputs(userID, username, password, clientCert) + identity := identitiesMgr.IdentityFromInputs(userID, username, password, clientCert) st.Unlock() if identity != nil { @@ -321,17 +322,18 @@ func (c *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) { // not good: https://github.com/canonical/pebble/pull/369 var user *UserState if _, isOpen := access.(OpenAccess); !isOpen { - user = userFromRequest(c.d.state, r, ucred) + identitiesMgr := c.d.Overlord().IdentitiesManager() + user = userFromRequest(c.d.state, identitiesMgr, r, ucred) } // If we don't have a named-identity user, use ucred UID to see if we have a default. if user == nil && ucred != nil { if ucred.Uid == 0 || ucred.Uid == uint32(os.Getuid()) { // Admin if UID is 0 (root) or the UID the daemon is running as. - user = &UserState{Access: state.AdminAccess, UID: &ucred.Uid} + user = &UserState{Access: identities.AdminAccess, UID: &ucred.Uid} } else { // Regular read access if any other local UID. - user = &UserState{Access: state.ReadAccess, UID: &ucred.Uid} + user = &UserState{Access: identities.ReadAccess, UID: &ucred.Uid} } } diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index 830ae688e..aef44dade 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -49,6 +49,7 @@ import ( "github.com/canonical/pebble/internals/logger" "github.com/canonical/pebble/internals/osutil" "github.com/canonical/pebble/internals/overlord" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/pairingstate" "github.com/canonical/pebble/internals/overlord/patch" "github.com/canonical/pebble/internals/overlord/restart" @@ -393,19 +394,21 @@ func (s *daemonSuite) testAccessChecker(c *C, tests []accessCheckerTestCase, rem d := s.newDaemon(c) // Add some named identities for testing with. + identitiesMgr, err := identities.NewManager(d.overlord.State()) + c.Assert(err, IsNil) d.state.Lock() - err := d.state.ReplaceIdentities(map[string]*state.Identity{ + err = identitiesMgr.ReplaceIdentities(map[string]*identities.Identity{ "adminuser": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1}, }, "readuser": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 2}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 2}, }, "untrusteduser": { - Access: state.UntrustedAccess, - Local: &state.LocalIdentity{UserID: 3}, + Access: identities.UntrustedAccess, + Local: &identities.LocalIdentity{UserID: 3}, }, }) d.state.Unlock() @@ -597,7 +600,7 @@ func (s *daemonSuite) TestDefaultUcredUsers(c *C) { cmd.ServeHTTP(rec, req) c.Check(rec.Code, Equals, http.StatusOK) c.Assert(userSeen, NotNil) - c.Check(userSeen.Access, Equals, state.AdminAccess) + c.Check(userSeen.Access, Equals, identities.AdminAccess) c.Assert(userSeen.UID, NotNil) c.Check(*userSeen.UID, Equals, uint32(0)) @@ -611,7 +614,7 @@ func (s *daemonSuite) TestDefaultUcredUsers(c *C) { cmd.ServeHTTP(rec, req) c.Check(rec.Code, Equals, http.StatusOK) c.Assert(userSeen, NotNil) - c.Check(userSeen.Access, Equals, state.AdminAccess) + c.Check(userSeen.Access, Equals, identities.AdminAccess) c.Assert(userSeen.UID, NotNil) c.Check(*userSeen.UID, Equals, uint32(os.Getuid())) @@ -625,7 +628,7 @@ func (s *daemonSuite) TestDefaultUcredUsers(c *C) { cmd.ServeHTTP(rec, req) c.Check(rec.Code, Equals, http.StatusOK) c.Assert(userSeen, NotNil) - c.Check(userSeen.Access, Equals, state.ReadAccess) + c.Check(userSeen.Access, Equals, identities.ReadAccess) c.Assert(userSeen.UID, NotNil) c.Check(*userSeen.UID, Equals, uint32(os.Getuid()+1)) } @@ -1888,11 +1891,13 @@ func (s *daemonSuite) TestServeHTTPUserStateLocal(c *C) { d := s.newDaemon(c) // Set up a Local identity. + identitiesMgr, err := identities.NewManager(d.overlord.State()) + c.Assert(err, IsNil) d.state.Lock() - err := d.state.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "localuser": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) d.state.Unlock() @@ -1922,7 +1927,7 @@ func (s *daemonSuite) TestServeHTTPUserStateLocal(c *C) { // Verify UserState for Local identity. c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "localuser") - c.Assert(capturedUser.Access, Equals, state.AdminAccess) + c.Assert(capturedUser.Access, Equals, identities.AdminAccess) c.Assert(capturedUser.UID, NotNil) // This specific expectation is only a temporary workaround to @@ -1959,7 +1964,7 @@ func (s *daemonSuite) TestServeHTTPUserStateUIDOnly(c *C) { // Verify UserState for UID-only (no named identity) c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "") - c.Assert(capturedUser.Access, Equals, state.ReadAccess) + c.Assert(capturedUser.Access, Equals, identities.ReadAccess) c.Assert(capturedUser.UID, NotNil) c.Assert(*capturedUser.UID, Equals, uint32(5000)) } @@ -1992,7 +1997,7 @@ func (s *daemonSuite) TestServeHTTPUserStateUIDOnlyRoot(c *C) { // Verify UserState for root UID. c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "") - c.Assert(capturedUser.Access, Equals, state.AdminAccess) + c.Assert(capturedUser.Access, Equals, identities.AdminAccess) c.Assert(capturedUser.UID, NotNil) c.Assert(*capturedUser.UID, Equals, uint32(0)) } @@ -2026,7 +2031,7 @@ func (s *daemonSuite) TestServeHTTPUserStateUIDOnlyDaemonUID(c *C) { // Verify UserState for daemon UID. c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "") - c.Assert(capturedUser.Access, Equals, state.AdminAccess) + c.Assert(capturedUser.Access, Equals, identities.AdminAccess) c.Assert(capturedUser.UID, NotNil) c.Assert(*capturedUser.UID, Equals, daemonUID) } @@ -2040,11 +2045,13 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicUnixSocket(c *C) { hashedPassword, err := crypt.Generate([]byte("test"), nil) c.Assert(err, IsNil) + identitiesMgr, err := identities.NewManager(d.overlord.State()) + c.Assert(err, IsNil) d.state.Lock() - err = d.state.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "basicuser": { - Access: state.ReadAccess, - Basic: &state.BasicIdentity{Password: hashedPassword}, + Access: identities.ReadAccess, + Basic: &identities.BasicIdentity{Password: hashedPassword}, }, }) d.state.Unlock() @@ -2075,7 +2082,7 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicUnixSocket(c *C) { // Verify UserState for Basic identity over Unix Socket. c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "basicuser") - c.Assert(capturedUser.Access, Equals, state.ReadAccess) + c.Assert(capturedUser.Access, Equals, identities.ReadAccess) c.Assert(capturedUser.UID, IsNil) } @@ -2088,11 +2095,13 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicHTTP(c *C) { hashedPassword, err := crypt.Generate([]byte("test"), nil) c.Assert(err, IsNil) + identitiesMgr, err := identities.NewManager(d.overlord.State()) + c.Assert(err, IsNil) d.state.Lock() - err = d.state.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "basicuser": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: hashedPassword}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: hashedPassword}, }, }) d.state.Unlock() @@ -2123,7 +2132,7 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicHTTP(c *C) { // Verify UserState for Basic identity over HTTP. c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "basicuser") - c.Assert(capturedUser.Access, Equals, state.MetricsAccess) + c.Assert(capturedUser.Access, Equals, identities.MetricsAccess) c.Assert(capturedUser.UID, IsNil) } @@ -2147,11 +2156,13 @@ jwXVTUH4HLpbhK0RAaEPOL4h5jm36CrWTkxzpbdCrIu4NgPLQKJ6Cw== c.Assert(err, IsNil) // Set up a Cert identity. + identitiesMgr, err := identities.NewManager(d.overlord.State()) + c.Assert(err, IsNil) d.state.Lock() - err = d.state.AddIdentities(map[string]*state.Identity{ + err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "certuser1": { - Access: state.AdminAccess, - Cert: &state.CertIdentity{X509: cert}, + Access: identities.AdminAccess, + Cert: &identities.CertIdentity{X509: cert}, }, }) d.state.Unlock() @@ -2184,6 +2195,6 @@ jwXVTUH4HLpbhK0RAaEPOL4h5jm36CrWTkxzpbdCrIu4NgPLQKJ6Cw== // Verify UserState for Cert identity. c.Assert(capturedUser, NotNil) c.Assert(capturedUser.Username, Equals, "certuser1") - c.Assert(capturedUser.Access, Equals, state.AdminAccess) + c.Assert(capturedUser.Access, Equals, identities.AdminAccess) c.Assert(capturedUser.UID, IsNil) } diff --git a/internals/overlord/state/identities.go b/internals/overlord/identities/identities.go similarity index 71% rename from internals/overlord/state/identities.go rename to internals/overlord/identities/identities.go index 3d64b0837..5e71d2766 100644 --- a/internals/overlord/state/identities.go +++ b/internals/overlord/identities/identities.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Canonical Ltd +// Copyright (c) 2025 Canonical Ltd // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License version 3 as @@ -12,7 +12,80 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package state +/* +TODO: from state package -- what to do with these? +// marshalledIdentity is used specifically for marshalling to the state +// database file. Unlike apiIdentity, it should include secrets. +type marshalledIdentity struct { + Access string `json:"access"` + Local *marshalledLocalIdentity `json:"local,omitempty"` + Basic *marshalledBasicIdentity `json:"basic,omitempty"` + Cert *marshalledCertIdentity `json:"cert,omitempty"` +} + +type marshalledLocalIdentity struct { + UserID uint32 `json:"user-id"` +} + +type marshalledBasicIdentity struct { + Password string `json:"password"` +} + +type marshalledCertIdentity struct { + PEM string `json:"pem"` +} + +func (s *State) marshalledIdentities() map[string]*marshalledIdentity { + marshalled := make(map[string]*marshalledIdentity, len(s.identities)) + for name, identity := range s.identities { + marshalled[name] = &marshalledIdentity{ + Access: string(identity.Access), + } + if identity.Local != nil { + marshalled[name].Local = &marshalledLocalIdentity{UserID: identity.Local.UserID} + } + if identity.Basic != nil { + marshalled[name].Basic = &marshalledBasicIdentity{Password: identity.Basic.Password} + } + if identity.Cert != nil { + pemBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: identity.Cert.X509.Raw, + } + marshalled[name].Cert = &marshalledCertIdentity{PEM: string(pem.EncodeToMemory(pemBlock))} + } + } + return marshalled +} + + +func (s *State) unmarshalIdentities(marshalled map[string]*marshalledIdentity) error { + s.identities = make(map[string]*Identity, len(marshalled)) + for name, mi := range marshalled { + s.identities[name] = &Identity{ + Name: name, + Access: IdentityAccess(mi.Access), + } + if mi.Local != nil { + s.identities[name].Local = &LocalIdentity{UserID: mi.Local.UserID} + } + if mi.Basic != nil { + s.identities[name].Basic = &BasicIdentity{Password: mi.Basic.Password} + } + if mi.Cert != nil { + block, _ := pem.Decode([]byte(mi.Cert.PEM)) + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("cannot parse certificate from cert identity: %w", err) + } + s.identities[name].Cert = &CertIdentity{X509: cert} + } + } + return nil +} +*/ + +package identities import ( "crypto/x509" @@ -20,13 +93,42 @@ import ( "encoding/pem" "errors" "fmt" + "maps" "regexp" "sort" "strings" "github.com/GehirnInc/crypt/sha512_crypt" + + "github.com/canonical/pebble/internals/overlord/state" ) +const ( + identitiesKey = "identities" +) + +type Manager struct { + state *state.State + + // Keep a local copy to avoid having to deserialize from state each time. + identities map[string]*Identity +} + +func NewManager(st *state.State) (*Manager, error) { + m := &Manager{ + state: st, + } + err := st.Get(identitiesKey, &m.identities) + if err != nil && !errors.Is(err, state.ErrNoState) { + return nil, err + } + return m, nil +} + +func (m *Manager) Ensure() error { + return nil +} + // Identity holds the configuration of a single identity. type Identity struct { Name string @@ -212,13 +314,13 @@ func (d *Identity) UnmarshalJSON(data []byte) error { // AddIdentities adds the given identities to the system. It's an error if any // of the named identities already exist. -func (s *State) AddIdentities(identities map[string]*Identity) error { - s.reading() - +// +// The state lock must be held for the duration of this call. +func (m *Manager) AddIdentities(identities map[string]*Identity) error { // If any of the named identities already exist, return an error. var existing []string for name, identity := range identities { - if _, ok := s.identities[name]; ok { + if _, ok := m.identities[name]; ok { existing = append(existing, name) } err := identity.validate(name) @@ -231,7 +333,7 @@ func (s *State) AddIdentities(identities map[string]*Identity) error { return fmt.Errorf("identities already exist: %s", strings.Join(existing, ", ")) } - newIdentities := s.cloneIdentities() + newIdentities := maps.Clone(m.identities) for name, identity := range identities { identity.Name = name newIdentities[name] = identity @@ -242,20 +344,20 @@ func (s *State) AddIdentities(identities map[string]*Identity) error { return err } - s.writing() - s.identities = newIdentities + m.identities = newIdentities + m.state.Set(identitiesKey, newIdentities) return nil } // UpdateIdentities updates the given identities in the system. It's an error // if any of the named identities do not exist. -func (s *State) UpdateIdentities(identities map[string]*Identity) error { - s.reading() - +// +// The state lock must be held for the duration of this call. +func (m *Manager) UpdateIdentities(identities map[string]*Identity) error { // If any of the named identities don't exist, return an error. var missing []string for name, identity := range identities { - if _, ok := s.identities[name]; !ok { + if _, ok := m.identities[name]; !ok { missing = append(missing, name) } err := identity.validate(name) @@ -268,7 +370,7 @@ func (s *State) UpdateIdentities(identities map[string]*Identity) error { return fmt.Errorf("identities do not exist: %s", strings.Join(missing, ", ")) } - newIdentities := s.cloneIdentities() + newIdentities := maps.Clone(m.identities) for name, identity := range identities { identity.Name = name newIdentities[name] = identity @@ -279,17 +381,17 @@ func (s *State) UpdateIdentities(identities map[string]*Identity) error { return err } - s.writing() - s.identities = newIdentities + m.identities = newIdentities + m.state.Set(identitiesKey, newIdentities) return nil } // ReplaceIdentities replaces the named identities in the system with the // given identities (adding those that don't exist), or removes them if the // map value is nil. -func (s *State) ReplaceIdentities(identities map[string]*Identity) error { - s.reading() - +// +// The state lock must be held for the duration of this call. +func (m *Manager) ReplaceIdentities(identities map[string]*Identity) error { for name, identity := range identities { if identity != nil { err := identity.validate(name) @@ -299,7 +401,7 @@ func (s *State) ReplaceIdentities(identities map[string]*Identity) error { } } - newIdentities := s.cloneIdentities() + newIdentities := maps.Clone(m.identities) for name, identity := range identities { if identity == nil { delete(newIdentities, name) @@ -314,20 +416,20 @@ func (s *State) ReplaceIdentities(identities map[string]*Identity) error { return err } - s.writing() - s.identities = newIdentities + m.identities = newIdentities + m.state.Set(identitiesKey, newIdentities) return nil } // RemoveIdentities removes the named identities from the system. It's an // error if any of the named identities do not exist. -func (s *State) RemoveIdentities(identities map[string]struct{}) error { - s.reading() - +// +// The state lock must be held for the duration of this call. +func (m *Manager) RemoveIdentities(identities map[string]struct{}) error { // If any of the named identities don't exist, return an error. var missing []string for name := range identities { - if _, ok := s.identities[name]; !ok { + if _, ok := m.identities[name]; !ok { missing = append(missing, name) } } @@ -336,23 +438,19 @@ func (s *State) RemoveIdentities(identities map[string]struct{}) error { return fmt.Errorf("identities do not exist: %s", strings.Join(missing, ", ")) } - s.writing() for name := range identities { - delete(s.identities, name) + delete(m.identities, name) } + m.state.Set(identitiesKey, m.identities) return nil } // Identities returns all the identities in the system. The returned map is a // shallow clone, so map mutations won't affect state. -func (s *State) Identities() map[string]*Identity { - s.reading() - - result := make(map[string]*Identity, len(s.identities)) - for name, identity := range s.identities { - result[name] = identity - } - return result +// +// The state lock must be held for the duration of this call. +func (m *Manager) Identities() map[string]*Identity { + return maps.Clone(m.identities) } // IdentityFromInputs returns an identity matching the given inputs. @@ -361,12 +459,12 @@ func (s *State) Identities() map[string]*Identity { // because they are intentionally setup by the client. // // If no matching identity is found for the given inputs, nil is returned. -func (s *State) IdentityFromInputs(userID *uint32, username, password string, clientCert *x509.Certificate) *Identity { - s.reading() - +// +// The state lock must be held for the duration of this call. +func (m *Manager) IdentityFromInputs(userID *uint32, username, password string, clientCert *x509.Certificate) *Identity { switch { case clientCert != nil: - for _, identity := range s.identities { + for _, identity := range m.identities { if identity.Cert != nil && identity.Cert.X509.Equal(clientCert) { // Certificate identities can be added // manually, so we still need to verify @@ -390,7 +488,7 @@ func (s *State) IdentityFromInputs(userID *uint32, username, password string, cl case username != "" || password != "": passwordBytes := []byte(password) - for _, identity := range s.identities { + for _, identity := range m.identities { if identity.Basic == nil || identity.Name != username { continue } @@ -406,7 +504,7 @@ func (s *State) IdentityFromInputs(userID *uint32, username, password string, cl return nil case userID != nil: - for _, identity := range s.identities { + for _, identity := range m.identities { if identity.Local != nil && identity.Local.UserID == *userID { return identity } @@ -418,14 +516,6 @@ func (s *State) IdentityFromInputs(userID *uint32, username, password string, cl return nil } -func (s *State) cloneIdentities() map[string]*Identity { - newIdentities := make(map[string]*Identity, len(s.identities)) - for name, identity := range s.identities { - newIdentities[name] = identity - } - return newIdentities -} - func verifyUniqueUserIDs(identities map[string]*Identity) error { userIDs := make(map[uint32][]string) // maps user ID to identity names for name, identity := range identities { diff --git a/internals/overlord/state/identities_test.go b/internals/overlord/identities/identities_test.go similarity index 61% rename from internals/overlord/state/identities_test.go rename to internals/overlord/identities/identities_test.go index 4d514ce2c..e2d8c5fea 100644 --- a/internals/overlord/state/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -12,7 +12,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -package state_test +package identities_test import ( "crypto/x509" @@ -22,6 +22,7 @@ import ( . "gopkg.in/check.v1" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/state" ) @@ -64,30 +65,33 @@ GV6pXv511MycDg== // IMPORTANT NOTE: be sure secrets aren't included when adding to this! func (s *identitiesSuite) TestMarshalAPI(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - err := st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, "olivia": { - Access: state.ReadAccess, - Cert: &state.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, + Access: identities.ReadAccess, + Cert: &identities.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, }, }) c.Assert(err, IsNil) - identities := st.Identities() + identities := mgr.Identities() data, err := json.MarshalIndent(identities, "", " ") c.Assert(err, IsNil) c.Assert(string(data), Equals, ` @@ -149,25 +153,25 @@ func (s *identitiesSuite) TestUnmarshalAPI(c *C) { } } }`, jsonCert) - var identities map[string]*state.Identity - err = json.Unmarshal(data, &identities) + var idents map[string]*identities.Identity + err = json.Unmarshal(data, &idents) c.Assert(err, IsNil) - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, "olivia": { - Access: state.ReadAccess, - Cert: &state.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, + Access: identities.ReadAccess, + Cert: &identities.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, }, }) } @@ -213,7 +217,7 @@ func (s *identitiesSuite) TestUnmarshalAPIErrors(c *C) { }} for _, test := range tests { c.Logf("Input data: %s", test.data) - var identities map[string]*state.Identity + var identities map[string]*identities.Identity err := json.Unmarshal([]byte(test.data), &identities) c.Check(err, ErrorMatches, test.error) } @@ -221,22 +225,25 @@ func (s *identitiesSuite) TestUnmarshalAPIErrors(c *C) { func (s *identitiesSuite) TestMarshalState(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - err := st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, IsNil) - // Marshal entire state, then pull out just the "identities" key to test that. + // Marshal entire identities, then pull out just the "identities" key to test that. data, err := json.Marshal(st) c.Assert(err, IsNil) var unmarshalled map[string]any @@ -263,6 +270,9 @@ func (s *identitiesSuite) TestMarshalState(c *C) { func (s *identitiesSuite) TestUnmarshalState(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() @@ -283,107 +293,110 @@ func (s *identitiesSuite) TestUnmarshalState(c *C) { } } }`) - err := json.Unmarshal(data, &st) + err = json.Unmarshal(data, &st) c.Assert(err, IsNil) - c.Assert(st.Identities(), DeepEquals, map[string]*state.Identity{ + c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ "bob": { Name: "bob", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { Name: "mary", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) } func (s *identitiesSuite) TestAddIdentities(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - original := map[string]*state.Identity{ + original := map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, "olivia": { - Access: state.ReadAccess, - Cert: &state.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, + Access: identities.ReadAccess, + Cert: &identities.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, }, } - err := st.AddIdentities(original) + err = mgr.AddIdentities(original) c.Assert(err, IsNil) // Ensure they were added correctly (and Name fields have been set). - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := mgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bob": { Name: "bob", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { Name: "mary", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { Name: "nancy", - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, "olivia": { Name: "olivia", - Access: state.ReadAccess, - Cert: &state.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, + Access: identities.ReadAccess, + Cert: &identities.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, }, }) // Can't add identity names that already exist. - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bill": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, ErrorMatches, "identities already exist: bob, mary") // Can't add a nil identity. - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bill": nil, }) c.Assert(err, ErrorMatches, `identity "bill" invalid: identity must not be nil`) // Access value must be valid. - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bill": { Access: "bar", - Local: &state.LocalIdentity{UserID: 43}, + Local: &identities.LocalIdentity{UserID: 43}, }, }) c.Assert(err, ErrorMatches, `identity "bill" invalid: invalid access value "bar", must be "admin", "read", "metrics", or "untrusted"`) // Must have at least one type. - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bill": { Access: "admin", }, @@ -391,37 +404,37 @@ func (s *identitiesSuite) TestAddIdentities(c *C) { c.Assert(err, ErrorMatches, `identity "bill" invalid: identity must have at least one type \("local", "basic", or "cert"\)`) // May have two types. - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "peter": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, - Local: &state.LocalIdentity{UserID: 1001}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, + Local: &identities.LocalIdentity{UserID: 1001}, }, }) c.Assert(err, IsNil) // Ensure user IDs are unique with existing users. - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bill": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, ErrorMatches, `cannot have multiple identities with user ID 1000 \(bill, mary\)`) // Ensure user IDs are unique among the ones being added (and test >2 with same UID). - err = st.AddIdentities(map[string]*state.Identity{ + err = mgr.AddIdentities(map[string]*identities.Identity{ "bill": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 2000}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 2000}, }, "bale": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 2000}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 2000}, }, "boll": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 2000}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 2000}, }, }) c.Assert(err, ErrorMatches, `cannot have multiple identities with user ID 2000 \(bale, bill, boll\)`) @@ -429,90 +442,93 @@ func (s *identitiesSuite) TestAddIdentities(c *C) { func (s *identitiesSuite) TestUpdateIdentities(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - original := map[string]*state.Identity{ + original := map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, } - err := st.AddIdentities(original) + err = mgr.AddIdentities(original) c.Assert(err, IsNil) - err = st.UpdateIdentities(map[string]*state.Identity{ + err = mgr.UpdateIdentities(map[string]*identities.Identity{ "bob": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "mary": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "new hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "new hash"}, }, }) c.Assert(err, IsNil) // Ensure they were updated correctly. - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := mgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bob": { Name: "bob", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "mary": { Name: "mary", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "nancy": { Name: "nancy", - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "new hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "new hash"}, }, }) // Can't update identity names that don't exist. - err = st.UpdateIdentities(map[string]*state.Identity{ + err = mgr.UpdateIdentities(map[string]*identities.Identity{ "bill": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "bale": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, }) c.Assert(err, ErrorMatches, "identities do not exist: bale, bill") // Ensure validation is being done (full testing done in AddIdentity). - err = st.UpdateIdentities(map[string]*state.Identity{ + err = mgr.UpdateIdentities(map[string]*identities.Identity{ "bob": nil, }) c.Assert(err, ErrorMatches, `identity "bob" invalid: identity must not be nil`) // Ensure unique user ID testing is being done (full testing done in AddIdentity). - err = st.UpdateIdentities(map[string]*state.Identity{ + err = mgr.UpdateIdentities(map[string]*identities.Identity{ "bob": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, }) c.Assert(err, ErrorMatches, `cannot have multiple identities with user ID 42 \(bob, mary\)`) @@ -520,52 +536,55 @@ func (s *identitiesSuite) TestUpdateIdentities(c *C) { func (s *identitiesSuite) TestReplaceIdentities(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - original := map[string]*state.Identity{ + original := map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, } - err := st.AddIdentities(original) + err = mgr.AddIdentities(original) c.Assert(err, IsNil) - err = st.ReplaceIdentities(map[string]*state.Identity{ + err = mgr.ReplaceIdentities(map[string]*identities.Identity{ "bob": nil, // nil means remove it "mary": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "newguy": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 44}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 44}, }, }) c.Assert(err, IsNil) // Ensure they were added/updated/deleted correctly. - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := mgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "mary": { Name: "mary", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "newguy": { Name: "newguy", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 44}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 44}, }, }) // Ensure validation is being done (full testing done in AddIdentity). - err = st.ReplaceIdentities(map[string]*state.Identity{ + err = mgr.ReplaceIdentities(map[string]*identities.Identity{ "bill": { Access: "admin", }, @@ -573,10 +592,10 @@ func (s *identitiesSuite) TestReplaceIdentities(c *C) { c.Assert(err, ErrorMatches, `identity "bill" invalid: identity must have at least one type \("local", "basic", or "cert"\)`) // Ensure unique user ID testing is being done (full testing done in AddIdentity). - err = st.ReplaceIdentities(map[string]*state.Identity{ + err = mgr.ReplaceIdentities(map[string]*identities.Identity{ "bob": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, }) c.Assert(err, ErrorMatches, `cannot have multiple identities with user ID 43 \(bob, mary\)`) @@ -584,35 +603,38 @@ func (s *identitiesSuite) TestReplaceIdentities(c *C) { func (s *identitiesSuite) TestRemoveIdentities(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - original := map[string]*state.Identity{ + original := map[string]*identities.Identity{ "bill": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, "queen": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1001}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1001}, }, } - err := st.AddIdentities(original) + err = mgr.AddIdentities(original) c.Assert(err, IsNil) - err = st.RemoveIdentities(map[string]struct{}{ + err = mgr.RemoveIdentities(map[string]struct{}{ "bob": {}, "mary": {}, "nancy": {}, @@ -620,22 +642,22 @@ func (s *identitiesSuite) TestRemoveIdentities(c *C) { c.Assert(err, IsNil) // Ensure they were removed correctly. - identities := st.Identities() - c.Assert(identities, DeepEquals, map[string]*state.Identity{ + idents := mgr.Identities() + c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bill": { Name: "bill", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 43}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 43}, }, "queen": { Name: "queen", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1001}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1001}, }, }) // Can't remove identity names that don't exist. - err = st.RemoveIdentities(map[string]struct{}{ + err = mgr.RemoveIdentities(map[string]struct{}{ "bill": {}, "bale": {}, "mary": {}, @@ -645,77 +667,83 @@ func (s *identitiesSuite) TestRemoveIdentities(c *C) { func (s *identitiesSuite) TestIdentities(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - original := map[string]*state.Identity{ + original := map[string]*identities.Identity{ "bob": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, } - err := st.AddIdentities(original) + err = mgr.AddIdentities(original) c.Assert(err, IsNil) // Ensure it returns correct results. - identities := st.Identities() - expected := map[string]*state.Identity{ + idents := mgr.Identities() + expected := map[string]*identities.Identity{ "bob": { Name: "bob", - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "mary": { Name: "mary", - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "nancy": { Name: "nancy", - Access: state.MetricsAccess, - Basic: &state.BasicIdentity{Password: "hash"}, + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{Password: "hash"}, }, } - c.Assert(identities, DeepEquals, expected) + c.Assert(idents, DeepEquals, expected) // Ensure the map was cloned (mutations to first map won't affect second). - identities2 := st.Identities() - c.Assert(identities2, DeepEquals, expected) - identities["changed"] = &state.Identity{} - c.Assert(identities2, DeepEquals, expected) + idents2 := mgr.Identities() + c.Assert(idents2, DeepEquals, expected) + idents2["changed"] = &identities.Identity{} + c.Assert(idents2, DeepEquals, expected) } func (s *identitiesSuite) TestIdentityFromInputs(c *C) { st := state.New(nil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + st.Lock() defer st.Unlock() - ids := map[string]*state.Identity{ + ids := map[string]*identities.Identity{ "uid": { - Access: state.MetricsAccess, - Local: &state.LocalIdentity{UserID: 42}, + Access: identities.MetricsAccess, + Local: &identities.LocalIdentity{UserID: 42}, }, "basic": { - Access: state.ReadAccess, - Basic: &state.BasicIdentity{ + Access: identities.ReadAccess, + Basic: &identities.BasicIdentity{ // password: test Password: "$6$F9cFSVEKyO4gB1Wh$8S1BSKsNkF.jBAixGc4W7l80OpfCNk65LZBDHBng3NAmbcHuMj4RIm7992rrJ8YA.SJ0hvm.vGk2z483am4Ym1", }, }, "cert": { - Access: state.AdminAccess, - Cert: &state.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, + Access: identities.AdminAccess, + Cert: &identities.CertIdentity{X509: parseCert(c, validPEMX509Cert)}, }, } - err := st.AddIdentities(ids) + err = mgr.AddIdentities(ids) c.Assert(err, IsNil) validCert := parseCert(c, validPEMX509Cert) @@ -728,7 +756,7 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { basicPass string cert *x509.Certificate expectedUser string - expectedAccess state.IdentityAccess + expectedAccess identities.IdentityAccess }{{ name: "no inputs", expectedUser: "", @@ -737,7 +765,7 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { name: "valid cert", cert: validCert, expectedUser: "cert", - expectedAccess: state.AdminAccess, + expectedAccess: identities.AdminAccess, }, { name: "invalid cert", cert: invalidCert, @@ -749,13 +777,13 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { basicUser: "basic", basicPass: "test", expectedUser: "cert", - expectedAccess: state.AdminAccess, + expectedAccess: identities.AdminAccess, }, { name: "cert with uid ignored", cert: validCert, userID: ptr(uint32(42)), expectedUser: "cert", - expectedAccess: state.AdminAccess, + expectedAccess: identities.AdminAccess, }, { name: "cert with both basic and uid ignored", cert: validCert, @@ -763,14 +791,14 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { basicPass: "test", userID: ptr(uint32(42)), expectedUser: "cert", - expectedAccess: state.AdminAccess, + expectedAccess: identities.AdminAccess, }, { // Basic authentication tests (medium priority) name: "valid basic auth", basicUser: "basic", basicPass: "test", expectedUser: "basic", - expectedAccess: state.ReadAccess, + expectedAccess: identities.ReadAccess, }, { name: "valid user invalid password", basicUser: "basic", @@ -808,7 +836,7 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { basicPass: "test", userID: ptr(uint32(42)), expectedUser: "basic", - expectedAccess: state.ReadAccess, + expectedAccess: identities.ReadAccess, }, { name: "invalid basic auth with valid uid ignored", basicUser: "basic", @@ -820,7 +848,7 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { name: "valid uid", userID: ptr(uint32(42)), expectedUser: "uid", - expectedAccess: state.MetricsAccess, + expectedAccess: identities.MetricsAccess, }, { name: "invalid uid", userID: ptr(uint32(100)), @@ -834,7 +862,7 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { for _, test := range tests { c.Logf("Running test: %s", test.name) - identity := st.IdentityFromInputs(test.userID, test.basicUser, test.basicPass, test.cert) + identity := mgr.IdentityFromInputs(test.userID, test.basicUser, test.basicPass, test.cert) if test.expectedUser != "" { c.Assert(identity, NotNil) diff --git a/internals/overlord/overlord.go b/internals/overlord/overlord.go index 1e21c844b..19b4aa2e4 100644 --- a/internals/overlord/overlord.go +++ b/internals/overlord/overlord.go @@ -32,6 +32,7 @@ import ( "github.com/canonical/pebble/internals/osutil" "github.com/canonical/pebble/internals/overlord/checkstate" "github.com/canonical/pebble/internals/overlord/cmdstate" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/logstate" "github.com/canonical/pebble/internals/overlord/pairingstate" "github.com/canonical/pebble/internals/overlord/patch" @@ -119,17 +120,18 @@ type Overlord struct { startOfOperationTime time.Time // managers - inited bool - startedUp bool - runner *state.TaskRunner - restartMgr *restart.RestartManager - planMgr *planstate.PlanManager - serviceMgr *servstate.ServiceManager - commandMgr *cmdstate.CommandManager - checkMgr *checkstate.CheckManager - logMgr *logstate.LogManager - tlsMgr *tlsstate.TLSManager - pairingMgr *pairingstate.PairingManager + inited bool + startedUp bool + runner *state.TaskRunner + restartMgr *restart.RestartManager + planMgr *planstate.PlanManager + serviceMgr *servstate.ServiceManager + commandMgr *cmdstate.CommandManager + checkMgr *checkstate.CheckManager + logMgr *logstate.LogManager + tlsMgr *tlsstate.TLSManager + identitiesMgr *identities.Manager + pairingMgr *pairingstate.PairingManager extension Extension } @@ -211,7 +213,13 @@ func New(opts *Options) (*Overlord, error) { o.tlsMgr = tlsstate.NewManager(tlsDir, opts.IDSigner) o.stateEng.AddManager(o.tlsMgr) - o.pairingMgr, err = pairingstate.NewManager(s) + o.identitiesMgr, err = identities.NewManager(s) + if err != nil { + return nil, fmt.Errorf("cannot create identities manager: %w", err) + } + o.stateEng.AddManager(o.identitiesMgr) + + o.pairingMgr, err = pairingstate.NewManager(s, o.identitiesMgr) if err != nil { return nil, fmt.Errorf("cannot create pairing manager: %w", err) } @@ -651,6 +659,12 @@ func (o *Overlord) TLSManager() *tlsstate.TLSManager { return o.tlsMgr } +// IdentitiesManager returns the manager responsible for managing client +// identities. +func (o *Overlord) IdentitiesManager() *identities.Manager { + return o.identitiesMgr +} + // PairingManager returns the manager that handles client pairing. func (o *Overlord) PairingManager() *pairingstate.PairingManager { return o.pairingMgr diff --git a/internals/overlord/pairingstate/manager.go b/internals/overlord/pairingstate/manager.go index 7b6570808..6c855b9c0 100644 --- a/internals/overlord/pairingstate/manager.go +++ b/internals/overlord/pairingstate/manager.go @@ -22,6 +22,7 @@ import ( "sync" "time" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/state" "github.com/canonical/pebble/internals/plan" ) @@ -87,8 +88,9 @@ func (c *pairingConfig) Combine(other *pairingConfig) { } type PairingManager struct { - state *state.State - mu sync.Mutex + state *state.State + identitiesMgr *identities.Manager + mu sync.Mutex // Plan config of the pairing manager. config *pairingConfig // Persisted state of the pairing manager. @@ -100,9 +102,10 @@ type PairingManager struct { expiry time.Time } -func NewManager(st *state.State) (*PairingManager, error) { +func NewManager(st *state.State, identitiesMgr *identities.Manager) (*PairingManager, error) { m := &PairingManager{ - state: st, + state: st, + identitiesMgr: identitiesMgr, config: &pairingConfig{ Mode: ModeUnset, }, @@ -233,7 +236,7 @@ func (m *PairingManager) PairMTLS(clientCert *x509.Certificate) error { m.state.Lock() defer m.state.Unlock() - existingIdentities := m.state.Identities() + existingIdentities := m.identitiesMgr.Identities() for _, identity := range existingIdentities { if identity.Cert == nil || identity.Cert.X509 == nil { @@ -257,12 +260,12 @@ func (m *PairingManager) PairMTLS(clientCert *x509.Certificate) error { return fmt.Errorf("cannot create new identity username: %w", err) } - newIdentity := &state.Identity{ - Access: state.AdminAccess, - Cert: &state.CertIdentity{X509: clientCert}, + newIdentity := &identities.Identity{ + Access: identities.AdminAccess, + Cert: &identities.CertIdentity{X509: clientCert}, } - err = m.state.AddIdentities(map[string]*state.Identity{ + err = m.identitiesMgr.AddIdentities(map[string]*identities.Identity{ username: newIdentity, }) if err != nil { @@ -278,7 +281,7 @@ func (m *PairingManager) PairMTLS(clientCert *x509.Certificate) error { // generateUniqueUsername finds the first unique username following the pattern // "user-x" where x starts at 1 and monotonically increments. Usernames not // following this pattern will simply not be considered. -func generateUniqueUsername(existingIdentities map[string]*state.Identity) (string, error) { +func generateUniqueUsername(existingIdentities map[string]*identities.Identity) (string, error) { for i := 1; i <= maxUsernameSuffix; i++ { username := fmt.Sprintf("user-%d", i) diff --git a/internals/overlord/pairingstate/manager_test.go b/internals/overlord/pairingstate/manager_test.go index ee8b89e90..e6b31a2d3 100644 --- a/internals/overlord/pairingstate/manager_test.go +++ b/internals/overlord/pairingstate/manager_test.go @@ -20,8 +20,8 @@ import ( . "gopkg.in/check.v1" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/pairingstate" - "github.com/canonical/pebble/internals/overlord/state" ) // testWindowDuration is a carefully selected pairing window duration that is @@ -186,15 +186,15 @@ func (ps *pairingSuite) TestPairMTLSSuccess(c *C) { pairingDetails := ps.PairingDetails() ps.state.Lock() - identities := ps.state.Identities() + idents := ps.identitiesMgr.Identities() ps.state.Unlock() c.Assert(pairingDetails.Paired, Equals, true) - c.Assert(len(identities), Equals, 1) + c.Assert(len(idents), Equals, 1) - identity, exists := identities["user-1"] + identity, exists := idents["user-1"] c.Assert(exists, Equals, true) - c.Assert(identity.Access, Equals, state.AdminAccess) + c.Assert(identity.Access, Equals, identities.AdminAccess) c.Assert(identity.Cert, NotNil) c.Assert(identity.Cert.X509, NotNil) @@ -214,11 +214,11 @@ func (ps *pairingSuite) TestPairMTLSNotOpen(c *C) { pairingDetails := ps.PairingDetails() ps.state.Lock() - identities := ps.state.Identities() + idents := ps.identitiesMgr.Identities() ps.state.Unlock() c.Assert(pairingDetails.Paired, Equals, false) - c.Assert(len(identities), Equals, 0) + c.Assert(len(idents), Equals, 0) } // TestPairMTLSDuplicateCertificate verifies that identities already added @@ -231,10 +231,10 @@ func (ps *pairingSuite) TestPairMTLSDuplicateCertificate(c *C) { ps.updatePlan(pairingstate.ModeMultiple) ps.state.Lock() - ps.state.AddIdentities(map[string]*state.Identity{ + ps.identitiesMgr.AddIdentities(map[string]*identities.Identity{ "existing-user": { - Access: state.AdminAccess, - Cert: &state.CertIdentity{X509: clientCert}, + Access: identities.AdminAccess, + Cert: &identities.CertIdentity{X509: clientCert}, }, }) ps.state.Unlock() @@ -249,10 +249,10 @@ func (ps *pairingSuite) TestPairMTLSDuplicateCertificate(c *C) { pairingDetails := ps.PairingDetails() ps.state.Lock() - identities := ps.state.Identities() + idents := ps.identitiesMgr.Identities() ps.state.Unlock() - c.Assert(len(identities), Equals, 1) + c.Assert(len(idents), Equals, 1) c.Assert(pairingDetails.Paired, Equals, true) } @@ -262,18 +262,18 @@ func (ps *pairingSuite) TestPairMTLSUsernameIncrementing(c *C) { ps.newManager(c, nil) ps.state.Lock() - ps.state.AddIdentities(map[string]*state.Identity{ + ps.identitiesMgr.AddIdentities(map[string]*identities.Identity{ "user-3": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1000}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, }, "user-1": { - Access: state.ReadAccess, - Local: &state.LocalIdentity{UserID: 1001}, + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 1001}, }, "other-user": { - Access: state.AdminAccess, - Local: &state.LocalIdentity{UserID: 1002}, + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1002}, }, }) ps.state.Unlock() @@ -289,10 +289,10 @@ func (ps *pairingSuite) TestPairMTLSUsernameIncrementing(c *C) { c.Assert(ps.manager.PairingEnabled(), Equals, false) ps.state.Lock() - identities := ps.state.Identities() + idents := ps.identitiesMgr.Identities() ps.state.Unlock() - _, exists := identities["user-2"] + _, exists := idents["user-2"] c.Assert(exists, Equals, true) } @@ -317,81 +317,81 @@ func (ps *pairingSuite) TestPlanChangedDisablesPairingWindow(c *C) { func (ps *pairingSuite) TestGenerateUniqueUsername(c *C) { testCases := []struct { name string - existingIdentities map[string]*state.Identity + existingIdentities map[string]*identities.Identity expectedUsername string expectedError string }{{ name: "empty identities should return user-1", - existingIdentities: map[string]*state.Identity{}, + existingIdentities: map[string]*identities.Identity{}, expectedUsername: "user-1", }, { name: "single user-1 should return user-2", - existingIdentities: map[string]*state.Identity{ - "user-1": {Access: state.AdminAccess}, + existingIdentities: map[string]*identities.Identity{ + "user-1": {Access: identities.AdminAccess}, }, expectedUsername: "user-2", }, { name: "non-sequential users should fill gaps", - existingIdentities: map[string]*state.Identity{ - "user-1": {Access: state.AdminAccess}, - "user-3": {Access: state.AdminAccess}, - "user-5": {Access: state.AdminAccess}, + existingIdentities: map[string]*identities.Identity{ + "user-1": {Access: identities.AdminAccess}, + "user-3": {Access: identities.AdminAccess}, + "user-5": {Access: identities.AdminAccess}, }, expectedUsername: "user-2", }, { name: "non-user prefixed usernames should be ignored", - existingIdentities: map[string]*state.Identity{ - "admin-1": {Access: state.AdminAccess}, - "other-user": {Access: state.ReadAccess}, - "user1": {Access: state.AdminAccess}, - "usertest": {Access: state.AdminAccess}, + existingIdentities: map[string]*identities.Identity{ + "admin-1": {Access: identities.AdminAccess}, + "other-user": {Access: identities.ReadAccess}, + "user1": {Access: identities.AdminAccess}, + "usertest": {Access: identities.AdminAccess}, }, expectedUsername: "user-1", }, { name: "invalid user suffixes should be ignored", - existingIdentities: map[string]*state.Identity{ - "user-": {Access: state.AdminAccess}, - "user-abc": {Access: state.AdminAccess}, - "user-1.5": {Access: state.AdminAccess}, - "user-0": {Access: state.AdminAccess}, - "user--1": {Access: state.AdminAccess}, - "user-1-2": {Access: state.AdminAccess}, + existingIdentities: map[string]*identities.Identity{ + "user-": {Access: identities.AdminAccess}, + "user-abc": {Access: identities.AdminAccess}, + "user-1.5": {Access: identities.AdminAccess}, + "user-0": {Access: identities.AdminAccess}, + "user--1": {Access: identities.AdminAccess}, + "user-1-2": {Access: identities.AdminAccess}, }, expectedUsername: "user-1", }, { name: "sequential users from 1 to 10", - existingIdentities: map[string]*state.Identity{ - "user-1": {Access: state.AdminAccess}, - "user-2": {Access: state.AdminAccess}, - "user-3": {Access: state.AdminAccess}, - "user-4": {Access: state.AdminAccess}, - "user-5": {Access: state.AdminAccess}, - "user-6": {Access: state.AdminAccess}, - "user-7": {Access: state.AdminAccess}, - "user-8": {Access: state.AdminAccess}, - "user-9": {Access: state.AdminAccess}, - "user-10": {Access: state.AdminAccess}, + existingIdentities: map[string]*identities.Identity{ + "user-1": {Access: identities.AdminAccess}, + "user-2": {Access: identities.AdminAccess}, + "user-3": {Access: identities.AdminAccess}, + "user-4": {Access: identities.AdminAccess}, + "user-5": {Access: identities.AdminAccess}, + "user-6": {Access: identities.AdminAccess}, + "user-7": {Access: identities.AdminAccess}, + "user-8": {Access: identities.AdminAccess}, + "user-9": {Access: identities.AdminAccess}, + "user-10": {Access: identities.AdminAccess}, }, expectedUsername: "user-11", }, { name: "mixed valid and invalid usernames", - existingIdentities: map[string]*state.Identity{ - "user-1": {Access: state.AdminAccess}, - "user-abc": {Access: state.AdminAccess}, - "user-3": {Access: state.AdminAccess}, - "admin-user": {Access: state.AdminAccess}, - "user-": {Access: state.AdminAccess}, - "user-5": {Access: state.AdminAccess}, + existingIdentities: map[string]*identities.Identity{ + "user-1": {Access: identities.AdminAccess}, + "user-abc": {Access: identities.AdminAccess}, + "user-3": {Access: identities.AdminAccess}, + "admin-user": {Access: identities.AdminAccess}, + "user-": {Access: identities.AdminAccess}, + "user-5": {Access: identities.AdminAccess}, }, expectedUsername: "user-2", }, { name: "limit exceeded should return error", - existingIdentities: func() map[string]*state.Identity { - identities := make(map[string]*state.Identity) + existingIdentities: func() map[string]*identities.Identity { + idents := make(map[string]*identities.Identity) for i := 1; i <= 1000; i++ { - identities[fmt.Sprintf("user-%d", i)] = &state.Identity{Access: state.AdminAccess} + idents[fmt.Sprintf("user-%d", i)] = &identities.Identity{Access: identities.AdminAccess} } - return identities + return idents }(), expectedError: "user allocation limit 1000 reached", }} diff --git a/internals/overlord/pairingstate/package_test.go b/internals/overlord/pairingstate/package_test.go index e077e1d95..f7ea9cb24 100644 --- a/internals/overlord/pairingstate/package_test.go +++ b/internals/overlord/pairingstate/package_test.go @@ -29,6 +29,7 @@ import ( "gopkg.in/yaml.v3" "github.com/canonical/pebble/internals/overlord" + "github.com/canonical/pebble/internals/overlord/identities" "github.com/canonical/pebble/internals/overlord/pairingstate" "github.com/canonical/pebble/internals/overlord/state" "github.com/canonical/pebble/internals/plan" @@ -38,9 +39,10 @@ import ( func Test(t *testing.T) { TestingT(t) } type pairingSuite struct { - overlord *overlord.Overlord - state *state.State - manager *pairingstate.PairingManager + overlord *overlord.Overlord + state *state.State + manager *pairingstate.PairingManager + identitiesMgr *identities.Manager } var _ = Suite(&pairingSuite{}) @@ -69,7 +71,10 @@ func (ps *pairingSuite) newManager(c *C, s *pairingstate.PairingDetails) { } var err error - ps.manager, err = pairingstate.NewManager(ps.state) + ps.identitiesMgr, err = identities.NewManager(ps.state) + c.Assert(err, IsNil) + + ps.manager, err = pairingstate.NewManager(ps.state, ps.identitiesMgr) c.Assert(err, IsNil) ps.overlord.AddManager(ps.manager) diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index b4fbfe93c..22c5b7ad3 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -16,9 +16,7 @@ package state import ( - "crypto/x509" "encoding/json" - "encoding/pem" "errors" "fmt" "io" @@ -91,12 +89,11 @@ type State struct { // for registering runtime callbacks lastHandlerId int - backend Backend - data customData - changes map[string]*Change - tasks map[string]*Task - notices map[noticeKey]*Notice - identities map[string]*Identity + backend Backend + data customData + changes map[string]*Change + tasks map[string]*Task + notices map[noticeKey]*Notice noticeCond *sync.Cond latestWarningTime atomic.Pointer[time.Time] @@ -120,7 +117,6 @@ func New(backend Backend) *State { changes: make(map[string]*Change), tasks: make(map[string]*Task), notices: make(map[noticeKey]*Notice), - identities: make(map[string]*Identity), modified: true, cache: make(map[any]any), pendingChangeByAttr: make(map[string]func(*Change) bool), @@ -161,11 +157,10 @@ func (s *State) unlock() { } type marshalledState struct { - Data map[string]*json.RawMessage `json:"data"` - Changes map[string]*Change `json:"changes"` - Tasks map[string]*Task `json:"tasks"` - Notices []*Notice `json:"notices,omitempty"` - Identities map[string]*marshalledIdentity `json:"identities,omitempty"` + Data map[string]*json.RawMessage `json:"data"` + Changes map[string]*Change `json:"changes"` + Tasks map[string]*Task `json:"tasks"` + Notices []*Notice `json:"notices,omitempty"` LastChangeId int `json:"last-change-id"` LastTaskId int `json:"last-task-id"` @@ -173,36 +168,14 @@ type marshalledState struct { LastNoticeId int `json:"last-notice-id"` } -// marshalledIdentity is used specifically for marshalling to the state -// database file. Unlike apiIdentity, it should include secrets. -type marshalledIdentity struct { - Access string `json:"access"` - Local *marshalledLocalIdentity `json:"local,omitempty"` - Basic *marshalledBasicIdentity `json:"basic,omitempty"` - Cert *marshalledCertIdentity `json:"cert,omitempty"` -} - -type marshalledLocalIdentity struct { - UserID uint32 `json:"user-id"` -} - -type marshalledBasicIdentity struct { - Password string `json:"password"` -} - -type marshalledCertIdentity struct { - PEM string `json:"pem"` -} - // MarshalJSON makes State a json.Marshaller func (s *State) MarshalJSON() ([]byte, error) { s.reading() return json.Marshal(marshalledState{ - Data: s.data, - Changes: s.changes, - Tasks: s.tasks, - Notices: s.flattenNotices(nil), - Identities: s.marshalledIdentities(), + Data: s.data, + Changes: s.changes, + Tasks: s.tasks, + Notices: s.flattenNotices(nil), LastTaskId: s.lastTaskId, LastChangeId: s.lastChangeId, @@ -211,29 +184,6 @@ func (s *State) MarshalJSON() ([]byte, error) { }) } -func (s *State) marshalledIdentities() map[string]*marshalledIdentity { - marshalled := make(map[string]*marshalledIdentity, len(s.identities)) - for name, identity := range s.identities { - marshalled[name] = &marshalledIdentity{ - Access: string(identity.Access), - } - if identity.Local != nil { - marshalled[name].Local = &marshalledLocalIdentity{UserID: identity.Local.UserID} - } - if identity.Basic != nil { - marshalled[name].Basic = &marshalledBasicIdentity{Password: identity.Basic.Password} - } - if identity.Cert != nil { - pemBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: identity.Cert.X509.Raw, - } - marshalled[name].Cert = &marshalledCertIdentity{PEM: string(pem.EncodeToMemory(pemBlock))} - } - } - return marshalled -} - // UnmarshalJSON makes State a json.Unmarshaller func (s *State) UnmarshalJSON(data []byte) error { s.writing() @@ -246,9 +196,6 @@ func (s *State) UnmarshalJSON(data []byte) error { s.changes = unmarshalled.Changes s.tasks = unmarshalled.Tasks s.unflattenNotices(unmarshalled.Notices) - if err := s.unmarshalIdentities(unmarshalled.Identities); err != nil { - return err - } s.lastChangeId = unmarshalled.LastChangeId s.lastTaskId = unmarshalled.LastTaskId s.lastLaneId = unmarshalled.LastLaneId @@ -264,31 +211,6 @@ func (s *State) UnmarshalJSON(data []byte) error { return nil } -func (s *State) unmarshalIdentities(marshalled map[string]*marshalledIdentity) error { - s.identities = make(map[string]*Identity, len(marshalled)) - for name, mi := range marshalled { - s.identities[name] = &Identity{ - Name: name, - Access: IdentityAccess(mi.Access), - } - if mi.Local != nil { - s.identities[name].Local = &LocalIdentity{UserID: mi.Local.UserID} - } - if mi.Basic != nil { - s.identities[name].Basic = &BasicIdentity{Password: mi.Basic.Password} - } - if mi.Cert != nil { - block, _ := pem.Decode([]byte(mi.Cert.PEM)) - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return fmt.Errorf("cannot parse certificate from cert identity: %w", err) - } - s.identities[name].Cert = &CertIdentity{X509: cert} - } - } - return nil -} - func (s *State) checkpointData() []byte { data, err := json.Marshal(s) if err != nil { From d42438f89b7dcc2f59a9c2162580c06671a73b5f Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 5 Dec 2025 13:29:15 +1300 Subject: [PATCH 02/15] fix a couple of tests --- internals/overlord/identities/identities.go | 13 ++++- .../overlord/identities/identities_test.go | 55 +++++++++++-------- internals/overlord/state/state_test.go | 1 - 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/internals/overlord/identities/identities.go b/internals/overlord/identities/identities.go index 5e71d2766..1fa0e046b 100644 --- a/internals/overlord/identities/identities.go +++ b/internals/overlord/identities/identities.go @@ -116,12 +116,23 @@ type Manager struct { func NewManager(st *state.State) (*Manager, error) { m := &Manager{ - state: st, + state: st, + identities: make(map[string]*Identity), } + + m.state.Lock() + defer m.state.Unlock() + + // Read existing identities from state, if any. err := st.Get(identitiesKey, &m.identities) if err != nil && !errors.Is(err, state.ErrNoState) { return nil, err } + + // TODO: is this the right place to do this? See also commented out stuff above. + for name, identity := range m.identities { + identity.Name = name + } return m, nil } diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index e2d8c5fea..38f883279 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -15,10 +15,12 @@ package identities_test import ( + "bytes" "crypto/x509" "encoding/json" "encoding/pem" "fmt" + "testing" . "gopkg.in/check.v1" @@ -26,6 +28,8 @@ import ( "github.com/canonical/pebble/internals/overlord/state" ) +func TestIdentities(t *testing.T) { TestingT(t) } + type identitiesSuite struct{} var _ = Suite(&identitiesSuite{}) @@ -243,13 +247,16 @@ func (s *identitiesSuite) TestMarshalState(c *C) { }) c.Assert(err, IsNil) - // Marshal entire identities, then pull out just the "identities" key to test that. + // Marshal entire state, then pull out just the "identities" key to test that. data, err := json.Marshal(st) c.Assert(err, IsNil) + var unmarshalled map[string]any err = json.Unmarshal(data, &unmarshalled) c.Assert(err, IsNil) - data, err = json.MarshalIndent(unmarshalled["identities"], "", " ") + customData := unmarshalled["data"].(map[string]any) + + data, err = json.MarshalIndent(customData["identities"], "", " ") c.Assert(err, IsNil) c.Assert(string(data), Equals, ` { @@ -269,32 +276,34 @@ func (s *identitiesSuite) TestMarshalState(c *C) { } func (s *identitiesSuite) TestUnmarshalState(c *C) { - st := state.New(nil) + data := []byte(` +{ + "data": { + "identities": { + "bob": { + "access": "read", + "local": { + "user-id": 42 + } + }, + "mary": { + "access": "admin", + "local": { + "user-id": 1000 + } + } + } + } +}`) + + st, err := state.ReadState(nil, bytes.NewReader(data)) + c.Assert(err, IsNil) mgr, err := identities.NewManager(st) c.Assert(err, IsNil) st.Lock() defer st.Unlock() - data := []byte(` -{ - "identities": { - "bob": { - "access": "read", - "local": { - "user-id": 42 - } - }, - "mary": { - "access": "admin", - "local": { - "user-id": 1000 - } - } - } -}`) - err = json.Unmarshal(data, &st) - c.Assert(err, IsNil) c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ "bob": { Name: "bob", @@ -714,7 +723,7 @@ func (s *identitiesSuite) TestIdentities(c *C) { // Ensure the map was cloned (mutations to first map won't affect second). idents2 := mgr.Identities() c.Assert(idents2, DeepEquals, expected) - idents2["changed"] = &identities.Identity{} + idents["changed"] = &identities.Identity{} c.Assert(idents2, DeepEquals, expected) } diff --git a/internals/overlord/state/state_test.go b/internals/overlord/state/state_test.go index 7ee5d6653..1bd6d6b8d 100644 --- a/internals/overlord/state/state_test.go +++ b/internals/overlord/state/state_test.go @@ -607,7 +607,6 @@ func (ss *stateSuite) TestEmptyStateDataAndCheckpointReadAndSet(c *C) { "changes", "tasks", "notices", - "identities", "cache", "pendingChangeByAttr", "taskHandlers", From c6c6a9782fdbfbe8c7e430a8939c90e27a9bac39 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Mon, 8 Dec 2025 14:44:07 +1300 Subject: [PATCH 03/15] fix the rest of the tests --- internals/daemon/api_identities_test.go | 24 +++++++++--------------- internals/daemon/daemon_test.go | 18 +++++++----------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/internals/daemon/api_identities_test.go b/internals/daemon/api_identities_test.go index 23067d723..a6fd2f9c3 100644 --- a/internals/daemon/api_identities_test.go +++ b/internals/daemon/api_identities_test.go @@ -29,11 +29,9 @@ func (s *apiSuite) TestIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() - identitiesMgr, err := identities.NewManager(st) - c.Assert(err, IsNil) st.Lock() - - err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ + identitiesMgr := s.d.overlord.IdentitiesManager() + err := identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { Access: identities.ReadAccess, Local: &identities.LocalIdentity{UserID: 42}, @@ -105,9 +103,8 @@ func (s *apiSuite) TestAddIdentities(c *C) { c.Check(rsp.Status, Equals, http.StatusOK) st := s.d.overlord.State() - identitiesMgr, err := identities.NewManager(st) - c.Assert(err, IsNil) st.Lock() + identitiesMgr := s.d.overlord.IdentitiesManager() idents := identitiesMgr.Identities() c.Assert(idents, DeepEquals, map[string]*identities.Identity{ "bob": { @@ -152,10 +149,9 @@ func (s *apiSuite) TestUpdateIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() - identitiesMgr, err := identities.NewManager(st) - c.Assert(err, IsNil) st.Lock() - err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ + identitiesMgr := s.d.overlord.IdentitiesManager() + err := identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { Access: identities.ReadAccess, Local: &identities.LocalIdentity{UserID: 42}, @@ -235,10 +231,9 @@ func (s *apiSuite) TestReplaceIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() - identitiesMgr, err := identities.NewManager(st) - c.Assert(err, IsNil) st.Lock() - err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ + identitiesMgr := s.d.overlord.IdentitiesManager() + err := identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { Access: identities.ReadAccess, Local: &identities.LocalIdentity{UserID: 42}, @@ -302,10 +297,9 @@ func (s *apiSuite) TestRemoveIdentities(c *C) { s.daemon(c) st := s.d.overlord.State() - identitiesMgr, err := identities.NewManager(st) - c.Assert(err, IsNil) st.Lock() - err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ + identitiesMgr := s.d.overlord.IdentitiesManager() + err := identitiesMgr.AddIdentities(map[string]*identities.Identity{ "bob": { Access: identities.ReadAccess, Local: &identities.LocalIdentity{UserID: 42}, diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index aef44dade..810df2748 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -394,10 +394,9 @@ func (s *daemonSuite) testAccessChecker(c *C, tests []accessCheckerTestCase, rem d := s.newDaemon(c) // Add some named identities for testing with. - identitiesMgr, err := identities.NewManager(d.overlord.State()) - c.Assert(err, IsNil) + identitiesMgr := d.overlord.IdentitiesManager() d.state.Lock() - err = identitiesMgr.ReplaceIdentities(map[string]*identities.Identity{ + err := identitiesMgr.ReplaceIdentities(map[string]*identities.Identity{ "adminuser": { Access: identities.AdminAccess, Local: &identities.LocalIdentity{UserID: 1}, @@ -1891,10 +1890,9 @@ func (s *daemonSuite) TestServeHTTPUserStateLocal(c *C) { d := s.newDaemon(c) // Set up a Local identity. - identitiesMgr, err := identities.NewManager(d.overlord.State()) - c.Assert(err, IsNil) + identitiesMgr := d.overlord.IdentitiesManager() d.state.Lock() - err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ + err := identitiesMgr.AddIdentities(map[string]*identities.Identity{ "localuser": { Access: identities.AdminAccess, Local: &identities.LocalIdentity{UserID: 1000}, @@ -2045,7 +2043,7 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicUnixSocket(c *C) { hashedPassword, err := crypt.Generate([]byte("test"), nil) c.Assert(err, IsNil) - identitiesMgr, err := identities.NewManager(d.overlord.State()) + identitiesMgr := d.overlord.IdentitiesManager() c.Assert(err, IsNil) d.state.Lock() err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ @@ -2095,8 +2093,7 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicHTTP(c *C) { hashedPassword, err := crypt.Generate([]byte("test"), nil) c.Assert(err, IsNil) - identitiesMgr, err := identities.NewManager(d.overlord.State()) - c.Assert(err, IsNil) + identitiesMgr := d.overlord.IdentitiesManager() d.state.Lock() err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "basicuser": { @@ -2156,8 +2153,7 @@ jwXVTUH4HLpbhK0RAaEPOL4h5jm36CrWTkxzpbdCrIu4NgPLQKJ6Cw== c.Assert(err, IsNil) // Set up a Cert identity. - identitiesMgr, err := identities.NewManager(d.overlord.State()) - c.Assert(err, IsNil) + identitiesMgr := d.overlord.IdentitiesManager() d.state.Lock() err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "certuser1": { From cca77058b2401d50bf981577ed0ba918d9508f9d Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 11:42:11 +1300 Subject: [PATCH 04/15] get loading from and storing to state working --- internals/overlord/identities/identities.go | 92 +++------------------ internals/overlord/identities/state.go | 76 +++++++++++++++++ 2 files changed, 86 insertions(+), 82 deletions(-) create mode 100644 internals/overlord/identities/state.go diff --git a/internals/overlord/identities/identities.go b/internals/overlord/identities/identities.go index 1fa0e046b..c08e4d478 100644 --- a/internals/overlord/identities/identities.go +++ b/internals/overlord/identities/identities.go @@ -12,79 +12,6 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -/* -TODO: from state package -- what to do with these? -// marshalledIdentity is used specifically for marshalling to the state -// database file. Unlike apiIdentity, it should include secrets. -type marshalledIdentity struct { - Access string `json:"access"` - Local *marshalledLocalIdentity `json:"local,omitempty"` - Basic *marshalledBasicIdentity `json:"basic,omitempty"` - Cert *marshalledCertIdentity `json:"cert,omitempty"` -} - -type marshalledLocalIdentity struct { - UserID uint32 `json:"user-id"` -} - -type marshalledBasicIdentity struct { - Password string `json:"password"` -} - -type marshalledCertIdentity struct { - PEM string `json:"pem"` -} - -func (s *State) marshalledIdentities() map[string]*marshalledIdentity { - marshalled := make(map[string]*marshalledIdentity, len(s.identities)) - for name, identity := range s.identities { - marshalled[name] = &marshalledIdentity{ - Access: string(identity.Access), - } - if identity.Local != nil { - marshalled[name].Local = &marshalledLocalIdentity{UserID: identity.Local.UserID} - } - if identity.Basic != nil { - marshalled[name].Basic = &marshalledBasicIdentity{Password: identity.Basic.Password} - } - if identity.Cert != nil { - pemBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: identity.Cert.X509.Raw, - } - marshalled[name].Cert = &marshalledCertIdentity{PEM: string(pem.EncodeToMemory(pemBlock))} - } - } - return marshalled -} - - -func (s *State) unmarshalIdentities(marshalled map[string]*marshalledIdentity) error { - s.identities = make(map[string]*Identity, len(marshalled)) - for name, mi := range marshalled { - s.identities[name] = &Identity{ - Name: name, - Access: IdentityAccess(mi.Access), - } - if mi.Local != nil { - s.identities[name].Local = &LocalIdentity{UserID: mi.Local.UserID} - } - if mi.Basic != nil { - s.identities[name].Basic = &BasicIdentity{Password: mi.Basic.Password} - } - if mi.Cert != nil { - block, _ := pem.Decode([]byte(mi.Cert.PEM)) - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return fmt.Errorf("cannot parse certificate from cert identity: %w", err) - } - s.identities[name].Cert = &CertIdentity{X509: cert} - } - } - return nil -} -*/ - package identities import ( @@ -124,15 +51,16 @@ func NewManager(st *state.State) (*Manager, error) { defer m.state.Unlock() // Read existing identities from state, if any. - err := st.Get(identitiesKey, &m.identities) + var marshalled map[string]*marshalledIdentity + err := st.Get(identitiesKey, &marshalled) if err != nil && !errors.Is(err, state.ErrNoState) { return nil, err } - - // TODO: is this the right place to do this? See also commented out stuff above. - for name, identity := range m.identities { - identity.Name = name + m.identities, err = unmarshalIdentities(marshalled) + if err != nil { + return nil, err } + return m, nil } @@ -356,7 +284,7 @@ func (m *Manager) AddIdentities(identities map[string]*Identity) error { } m.identities = newIdentities - m.state.Set(identitiesKey, newIdentities) + m.state.Set(identitiesKey, marshalledIdentities(newIdentities)) return nil } @@ -393,7 +321,7 @@ func (m *Manager) UpdateIdentities(identities map[string]*Identity) error { } m.identities = newIdentities - m.state.Set(identitiesKey, newIdentities) + m.state.Set(identitiesKey, marshalledIdentities(newIdentities)) return nil } @@ -428,7 +356,7 @@ func (m *Manager) ReplaceIdentities(identities map[string]*Identity) error { } m.identities = newIdentities - m.state.Set(identitiesKey, newIdentities) + m.state.Set(identitiesKey, marshalledIdentities(newIdentities)) return nil } @@ -452,7 +380,7 @@ func (m *Manager) RemoveIdentities(identities map[string]struct{}) error { for name := range identities { delete(m.identities, name) } - m.state.Set(identitiesKey, m.identities) + m.state.Set(identitiesKey, marshalledIdentities(m.identities)) return nil } diff --git a/internals/overlord/identities/state.go b/internals/overlord/identities/state.go new file mode 100644 index 000000000..371e5cd55 --- /dev/null +++ b/internals/overlord/identities/state.go @@ -0,0 +1,76 @@ +package identities + +import ( + "crypto/x509" + "encoding/pem" + "fmt" +) + +// marshalledIdentity is used specifically for marshalling to the state +// database file. Unlike apiIdentity, it should include secrets. +type marshalledIdentity struct { + Access string `json:"access"` + Local *marshalledLocalIdentity `json:"local,omitempty"` + Basic *marshalledBasicIdentity `json:"basic,omitempty"` + Cert *marshalledCertIdentity `json:"cert,omitempty"` +} + +type marshalledLocalIdentity struct { + UserID uint32 `json:"user-id"` +} + +type marshalledBasicIdentity struct { + Password string `json:"password"` +} + +type marshalledCertIdentity struct { + PEM string `json:"pem"` +} + +func marshalledIdentities(identities map[string]*Identity) map[string]*marshalledIdentity { + marshalled := make(map[string]*marshalledIdentity, len(identities)) + for name, identity := range identities { + marshalled[name] = &marshalledIdentity{ + Access: string(identity.Access), + } + if identity.Local != nil { + marshalled[name].Local = &marshalledLocalIdentity{UserID: identity.Local.UserID} + } + if identity.Basic != nil { + marshalled[name].Basic = &marshalledBasicIdentity{Password: identity.Basic.Password} + } + if identity.Cert != nil { + pemBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: identity.Cert.X509.Raw, + } + marshalled[name].Cert = &marshalledCertIdentity{PEM: string(pem.EncodeToMemory(pemBlock))} + } + } + return marshalled +} + +func unmarshalIdentities(marshalled map[string]*marshalledIdentity) (map[string]*Identity, error) { + identities := make(map[string]*Identity, len(marshalled)) + for name, mi := range marshalled { + identities[name] = &Identity{ + Name: name, + Access: IdentityAccess(mi.Access), + } + if mi.Local != nil { + identities[name].Local = &LocalIdentity{UserID: mi.Local.UserID} + } + if mi.Basic != nil { + identities[name].Basic = &BasicIdentity{Password: mi.Basic.Password} + } + if mi.Cert != nil { + block, _ := pem.Decode([]byte(mi.Cert.PEM)) + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("cannot parse certificate from cert identity: %w", err) + } + identities[name].Cert = &CertIdentity{X509: cert} + } + } + return identities, nil +} From d97e02626b69e29398e9fb9f55d01296d0c07852 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 11:54:08 +1300 Subject: [PATCH 05/15] revert changes to client/identities.go --- client/identities.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/client/identities.go b/client/identities.go index 78a1b0beb..a3aedf221 100644 --- a/client/identities.go +++ b/client/identities.go @@ -23,7 +23,7 @@ import ( // Identity holds the configuration of a single identity. type Identity struct { - Access Access `json:"access" yaml:"access"` + Access IdentityAccess `json:"access" yaml:"access"` // One or more of the following type-specific configuration fields must be // non-nil. @@ -32,13 +32,13 @@ type Identity struct { Cert *CertIdentity `json:"cert,omitempty" yaml:"cert,omitempty"` } -// Access defines the access level for an identity. -type Access string +// IdentityAccess defines the access level for an identity. +type IdentityAccess string const ( - AdminAccess Access = "admin" - ReadAccess Access = "read" - UntrustedAccess Access = "untrusted" + AdminAccess IdentityAccess = "admin" + ReadAccess IdentityAccess = "read" + UntrustedAccess IdentityAccess = "untrusted" ) // LocalIdentity holds identity configuration specific to the "local" type From 11a1390ceeef759abc2d6048556080188e8e4ce1 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 11:54:45 +1300 Subject: [PATCH 06/15] test that API *****'s passwords (I'm surprised we weren't testing this) --- internals/daemon/api_identities_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/internals/daemon/api_identities_test.go b/internals/daemon/api_identities_test.go index a6fd2f9c3..40599fc3f 100644 --- a/internals/daemon/api_identities_test.go +++ b/internals/daemon/api_identities_test.go @@ -40,6 +40,12 @@ func (s *apiSuite) TestIdentities(c *C) { Access: identities.AdminAccess, Local: &identities.LocalIdentity{UserID: 1000}, }, + "nancy": { + Access: identities.MetricsAccess, + Basic: &identities.BasicIdentity{ + Password: "$6$F9cFSVEKyO4gB1Wh$8S1BSKsNkF.jBAixGc4W7l80OpfCNk65LZBDHBng3NAmbcHuMj4RIm7992rrJ8YA.SJ0hvm.vGk2z483am4Ym1", // "test" + }, + }, }) c.Assert(err, IsNil) st.Unlock() @@ -70,6 +76,12 @@ func (s *apiSuite) TestIdentities(c *C) { "local": { "user-id": 1000 } + }, + "nancy": { + "access": "metrics", + "basic": { + "password": "*****" + } } }`[1:]) } From d4877900129d61d5625ae769e1685845c7d97cc8 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 12:02:44 +1300 Subject: [PATCH 07/15] rename identities.IdentityAccess to identities.Access --- internals/daemon/api_identities_test.go | 2 ++ internals/daemon/api_notices_test.go | 2 +- internals/daemon/daemon.go | 2 +- internals/overlord/identities/identities.go | 20 +++++++++++-------- .../overlord/identities/identities_test.go | 2 +- internals/overlord/identities/state.go | 2 +- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/internals/daemon/api_identities_test.go b/internals/daemon/api_identities_test.go index 40599fc3f..a65faf0ae 100644 --- a/internals/daemon/api_identities_test.go +++ b/internals/daemon/api_identities_test.go @@ -12,6 +12,8 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +// TODO: add test of adding identity with password too + package daemon import ( diff --git a/internals/daemon/api_notices_test.go b/internals/daemon/api_notices_test.go index 8b958baea..58162f5e6 100644 --- a/internals/daemon/api_notices_test.go +++ b/internals/daemon/api_notices_test.go @@ -891,6 +891,6 @@ func addNotice(c *C, st *state.State, userID *uint32, noticeType state.NoticeTyp c.Assert(err, IsNil) } -func userState(access identities.IdentityAccess, uid uint32) *UserState { +func userState(access identities.Access, uid uint32) *UserState { return &UserState{Access: access, UID: &uid} } diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index b6f1cfa5f..982d3f8b3 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -195,7 +195,7 @@ type Daemon struct { // UserState represents the state of an authenticated API user. type UserState struct { - Access identities.IdentityAccess + Access identities.Access UID *uint32 Username string } diff --git a/internals/overlord/identities/identities.go b/internals/overlord/identities/identities.go index c08e4d478..696cd1999 100644 --- a/internals/overlord/identities/identities.go +++ b/internals/overlord/identities/identities.go @@ -12,6 +12,10 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +// TODO: refactor to avoid three types: Identity, apiIdentity, marshalledIdentity +// do api stuff in api_identities.go instead? +// TODO: load from top-level "identities" key in state for migration path + package identities import ( @@ -71,7 +75,7 @@ func (m *Manager) Ensure() error { // Identity holds the configuration of a single identity. type Identity struct { Name string - Access IdentityAccess + Access Access // One or more of the following type-specific configuration fields must be // non-nil. @@ -80,14 +84,14 @@ type Identity struct { Cert *CertIdentity } -// IdentityAccess defines the access level for an identity. -type IdentityAccess string +// Access defines the access level for an identity. +type Access string const ( - AdminAccess IdentityAccess = "admin" - ReadAccess IdentityAccess = "read" - MetricsAccess IdentityAccess = "metrics" - UntrustedAccess IdentityAccess = "untrusted" + AdminAccess Access = "admin" + ReadAccess Access = "read" + MetricsAccess Access = "metrics" + UntrustedAccess Access = "untrusted" ) // LocalIdentity holds identity configuration specific to the "local" type @@ -214,7 +218,7 @@ func (d *Identity) UnmarshalJSON(data []byte) error { } identity := Identity{ - Access: IdentityAccess(ai.Access), + Access: Access(ai.Access), } if ai.Local != nil { diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index 38f883279..946d9fa7f 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -765,7 +765,7 @@ func (s *identitiesSuite) TestIdentityFromInputs(c *C) { basicPass string cert *x509.Certificate expectedUser string - expectedAccess identities.IdentityAccess + expectedAccess identities.Access }{{ name: "no inputs", expectedUser: "", diff --git a/internals/overlord/identities/state.go b/internals/overlord/identities/state.go index 371e5cd55..36ab323de 100644 --- a/internals/overlord/identities/state.go +++ b/internals/overlord/identities/state.go @@ -55,7 +55,7 @@ func unmarshalIdentities(marshalled map[string]*marshalledIdentity) (map[string] for name, mi := range marshalled { identities[name] = &Identity{ Name: name, - Access: IdentityAccess(mi.Access), + Access: Access(mi.Access), } if mi.Local != nil { identities[name].Local = &LocalIdentity{UserID: mi.Local.UserID} From 6eb8a935e5a6dd942534a3fabdcc97bb1a9cad99 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 12:36:32 +1300 Subject: [PATCH 08/15] ensure identities are loaded from top-level "identities" key (legacy) --- internals/daemon/api_identities_test.go | 2 - internals/overlord/identities/identities.go | 1 - .../overlord/identities/identities_test.go | 44 +++++++++++++++++++ internals/overlord/state/state.go | 14 ++++++ 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/internals/daemon/api_identities_test.go b/internals/daemon/api_identities_test.go index a65faf0ae..40599fc3f 100644 --- a/internals/daemon/api_identities_test.go +++ b/internals/daemon/api_identities_test.go @@ -12,8 +12,6 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -// TODO: add test of adding identity with password too - package daemon import ( diff --git a/internals/overlord/identities/identities.go b/internals/overlord/identities/identities.go index 696cd1999..ea4b8c400 100644 --- a/internals/overlord/identities/identities.go +++ b/internals/overlord/identities/identities.go @@ -14,7 +14,6 @@ // TODO: refactor to avoid three types: Identity, apiIdentity, marshalledIdentity // do api stuff in api_identities.go instead? -// TODO: load from top-level "identities" key in state for migration path package identities diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index 946d9fa7f..3a7dfb31c 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -273,6 +273,9 @@ func (s *identitiesSuite) TestMarshalState(c *C) { } } }`[1:]) + + _, hasLegacyIdentities := unmarshalled["identities"].(any) + c.Assert(hasLegacyIdentities, Equals, false) } func (s *identitiesSuite) TestUnmarshalState(c *C) { @@ -318,6 +321,47 @@ func (s *identitiesSuite) TestUnmarshalState(c *C) { }) } +func (s *identitiesSuite) TestUnmarshalStateLegacy(c *C) { + data := []byte(` +{ + "identities": { + "bob": { + "access": "read", + "local": { + "user-id": 42 + } + }, + "mary": { + "access": "admin", + "local": { + "user-id": 1000 + } + } + } +}`) + + st, err := state.ReadState(nil, bytes.NewReader(data)) + c.Assert(err, IsNil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + + st.Lock() + defer st.Unlock() + + c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ + "bob": { + Name: "bob", + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, + }, + "mary": { + Name: "mary", + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, + }, + }) +} + func (s *identitiesSuite) TestAddIdentities(c *C) { st := state.New(nil) mgr, err := identities.NewManager(st) diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index 22c5b7ad3..29bd7c8bb 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -162,6 +162,10 @@ type marshalledState struct { Tasks map[string]*Task `json:"tasks"` Notices []*Notice `json:"notices,omitempty"` + // The "identities" key used to be stored directly on state at the top level, + // so be sure to read them from old state files. + LegacyIdentities json.RawMessage `json:"identities,omitempty"` + LastChangeId int `json:"last-change-id"` LastTaskId int `json:"last-task-id"` LastLaneId int `json:"last-lane-id"` @@ -193,6 +197,16 @@ func (s *State) UnmarshalJSON(data []byte) error { return err } s.data = unmarshalled.Data + + // Load legacy identities if new identities are not in Get/Set data. + if !s.data.has("identities") && len(unmarshalled.LegacyIdentities) > 0 { + logger.Noticef("Loaded legacy identities from state file") + if s.data == nil { + s.data = make(customData) + } + s.data["identities"] = &unmarshalled.LegacyIdentities + } + s.changes = unmarshalled.Changes s.tasks = unmarshalled.Tasks s.unflattenNotices(unmarshalled.Notices) From 0c71a2009268a06985e4a7e42a9bdeec28fd5f67 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 12:37:52 +1300 Subject: [PATCH 09/15] remove TODO comment; we'll do that in a follow-up PR --- internals/overlord/identities/identities.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internals/overlord/identities/identities.go b/internals/overlord/identities/identities.go index ea4b8c400..d2725989f 100644 --- a/internals/overlord/identities/identities.go +++ b/internals/overlord/identities/identities.go @@ -12,9 +12,6 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . -// TODO: refactor to avoid three types: Identity, apiIdentity, marshalledIdentity -// do api stuff in api_identities.go instead? - package identities import ( From f53d625d43d3054af72c9c47095c1fefb3022910 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 12:41:29 +1300 Subject: [PATCH 10/15] fix linter error --- internals/overlord/identities/identities_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index 3a7dfb31c..8e6a3c0e5 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -274,7 +274,7 @@ func (s *identitiesSuite) TestMarshalState(c *C) { } }`[1:]) - _, hasLegacyIdentities := unmarshalled["identities"].(any) + _, hasLegacyIdentities := unmarshalled["identities"] c.Assert(hasLegacyIdentities, Equals, false) } From 02d664624ec01ddb0db43e6afe411b6a0da05db6 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 12:57:23 +1300 Subject: [PATCH 11/15] couple of minor changes from self review --- internals/overlord/identities/identities.go | 3 +- .../overlord/identities/identities_test.go | 28 +++++++++---------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/internals/overlord/identities/identities.go b/internals/overlord/identities/identities.go index d2725989f..bc3512019 100644 --- a/internals/overlord/identities/identities.go +++ b/internals/overlord/identities/identities.go @@ -37,7 +37,8 @@ const ( type Manager struct { state *state.State - // Keep a local copy to avoid having to deserialize from state each time. + // Keep a local copy to avoid having to deserialize from state each time + // Get is called. identities map[string]*Identity } diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index 8e6a3c0e5..d0a1d3cca 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -324,20 +324,20 @@ func (s *identitiesSuite) TestUnmarshalState(c *C) { func (s *identitiesSuite) TestUnmarshalStateLegacy(c *C) { data := []byte(` { - "identities": { - "bob": { - "access": "read", - "local": { - "user-id": 42 - } - }, - "mary": { - "access": "admin", - "local": { - "user-id": 1000 - } - } - } + "identities": { + "bob": { + "access": "read", + "local": { + "user-id": 42 + } + }, + "mary": { + "access": "admin", + "local": { + "user-id": 1000 + } + } + } }`) st, err := state.ReadState(nil, bytes.NewReader(data)) From dcf305dc5f19eb59ad22e1f1abf4c6f9d1997a50 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 12 Dec 2025 13:11:16 +1300 Subject: [PATCH 12/15] remove unneeded assert --- internals/daemon/daemon_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index 37cadba50..56d11c9b8 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -2044,7 +2044,6 @@ func (s *daemonSuite) TestServeHTTPUserStateBasicUnixSocket(c *C) { c.Assert(err, IsNil) identitiesMgr := d.overlord.IdentitiesManager() - c.Assert(err, IsNil) d.state.Lock() err = identitiesMgr.AddIdentities(map[string]*identities.Identity{ "basicuser": { From 0d055e30facd9ce69d6b95c2728793f0e1852b72 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Tue, 16 Dec 2025 11:19:40 +1300 Subject: [PATCH 13/15] add warning log if both new and legacy identities present; test this --- .../overlord/identities/identities_test.go | 42 +++++++++++++++++++ internals/overlord/state/state.go | 16 ++++--- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index d0a1d3cca..ef3b45ccf 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -362,6 +362,48 @@ func (s *identitiesSuite) TestUnmarshalStateLegacy(c *C) { }) } +func (s *identitiesSuite) TestUnmarshalStateNewAndLegacy(c *C) { + // If both new and legacy are present, it should prefer the new + // (and emit a warning log, but we don't test for that). + data := []byte(` +{ + "data": { + "identities": { + "bob": { + "access": "read", + "local": { + "user-id": 42 + } + } + } + }, + "identities": { + "mary": { + "access": "admin", + "local": { + "user-id": 1000 + } + } + } +}`) + + st, err := state.ReadState(nil, bytes.NewReader(data)) + c.Assert(err, IsNil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + + st.Lock() + defer st.Unlock() + + c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ + "bob": { + Name: "bob", + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, + }, + }) +} + func (s *identitiesSuite) TestAddIdentities(c *C) { st := state.New(nil) mgr, err := identities.NewManager(st) diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index 29bd7c8bb..d6ac538a7 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -198,13 +198,17 @@ func (s *State) UnmarshalJSON(data []byte) error { } s.data = unmarshalled.Data - // Load legacy identities if new identities are not in Get/Set data. - if !s.data.has("identities") && len(unmarshalled.LegacyIdentities) > 0 { - logger.Noticef("Loaded legacy identities from state file") - if s.data == nil { - s.data = make(customData) + // Load legacy identities if present (and new identities are not in Get/Set data). + if len(unmarshalled.LegacyIdentities) > 0 { + if s.data.has("identities") { + logger.Noticef("WARNING: both new and legacy identities found in state file, ignoring legacy") + } else { + logger.Noticef("Loaded legacy identities from state file") + if s.data == nil { + s.data = make(customData) + } + s.data["identities"] = &unmarshalled.LegacyIdentities } - s.data["identities"] = &unmarshalled.LegacyIdentities } s.changes = unmarshalled.Changes From 0cc6bbf6b8dc5c2966e2b5253cf0179349a3799f Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Mon, 5 Jan 2026 13:56:28 +1300 Subject: [PATCH 14/15] Use "patch" feature instead of ad-hoc migration --- .../overlord/identities/identities_test.go | 6 +++ internals/overlord/patch/patch.go | 2 +- internals/overlord/patch/patch2.go | 49 +++++++++++++++++++ internals/overlord/state/state.go | 21 +++----- 4 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 internals/overlord/patch/patch2.go diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index ef3b45ccf..a9d5ee379 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -25,6 +25,7 @@ import ( . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/overlord/identities" + "github.com/canonical/pebble/internals/overlord/patch" "github.com/canonical/pebble/internals/overlord/state" ) @@ -324,6 +325,7 @@ func (s *identitiesSuite) TestUnmarshalState(c *C) { func (s *identitiesSuite) TestUnmarshalStateLegacy(c *C) { data := []byte(` { + "data": {}, "identities": { "bob": { "access": "read", @@ -342,6 +344,8 @@ func (s *identitiesSuite) TestUnmarshalStateLegacy(c *C) { st, err := state.ReadState(nil, bytes.NewReader(data)) c.Assert(err, IsNil) + err = patch.Apply(st) + c.Assert(err, IsNil) mgr, err := identities.NewManager(st) c.Assert(err, IsNil) @@ -389,6 +393,8 @@ func (s *identitiesSuite) TestUnmarshalStateNewAndLegacy(c *C) { st, err := state.ReadState(nil, bytes.NewReader(data)) c.Assert(err, IsNil) + err = patch.Apply(st) + c.Assert(err, IsNil) mgr, err := identities.NewManager(st) c.Assert(err, IsNil) diff --git a/internals/overlord/patch/patch.go b/internals/overlord/patch/patch.go index 58514a8a4..c40d90ee3 100644 --- a/internals/overlord/patch/patch.go +++ b/internals/overlord/patch/patch.go @@ -28,7 +28,7 @@ import ( ) // Level is the current implemented patch level of the state format and content. -var Level = 1 +var Level = 2 // Sublevel is the current implemented sublevel for the Level. // Sublevel 0 is the first patch for the new Level, rollback below x.0 is not possible. diff --git a/internals/overlord/patch/patch2.go b/internals/overlord/patch/patch2.go new file mode 100644 index 000000000..7cf058f29 --- /dev/null +++ b/internals/overlord/patch/patch2.go @@ -0,0 +1,49 @@ +// Copyright (c) 2025 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package patch + +import ( + "encoding/json" + "fmt" + + "github.com/canonical/pebble/internals/logger" + "github.com/canonical/pebble/internals/overlord/identities" + "github.com/canonical/pebble/internals/overlord/state" +) + +func init() { + patches[2] = []PatchFunc{patch2} +} + +// Load legacy identities if present (and new identities are not in Get/Set data). +func patch2(s *state.State) error { + legacy := s.LegacyIdentities() + if len(legacy) == 0 { + return nil + } + if s.Has("identities") { + logger.Noticef("WARNING: both new and legacy identities found in state file, ignoring legacy") + return nil + } + + var idents map[string]*identities.Identity + err := json.Unmarshal(legacy, &idents) + if err != nil { + return fmt.Errorf("cannot unmarshal legacy identities: %v", err) + } + s.Set("identities", idents) + logger.Noticef("Loaded legacy identities from state file") + return nil +} diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index d6ac538a7..32bb24917 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -95,6 +95,8 @@ type State struct { tasks map[string]*Task notices map[noticeKey]*Notice + legacyIdentities json.RawMessage + noticeCond *sync.Cond latestWarningTime atomic.Pointer[time.Time] @@ -198,22 +200,10 @@ func (s *State) UnmarshalJSON(data []byte) error { } s.data = unmarshalled.Data - // Load legacy identities if present (and new identities are not in Get/Set data). - if len(unmarshalled.LegacyIdentities) > 0 { - if s.data.has("identities") { - logger.Noticef("WARNING: both new and legacy identities found in state file, ignoring legacy") - } else { - logger.Noticef("Loaded legacy identities from state file") - if s.data == nil { - s.data = make(customData) - } - s.data["identities"] = &unmarshalled.LegacyIdentities - } - } - s.changes = unmarshalled.Changes s.tasks = unmarshalled.Tasks s.unflattenNotices(unmarshalled.Notices) + s.legacyIdentities = unmarshalled.LegacyIdentities s.lastChangeId = unmarshalled.LastChangeId s.lastTaskId = unmarshalled.LastTaskId s.lastLaneId = unmarshalled.LastLaneId @@ -229,6 +219,11 @@ func (s *State) UnmarshalJSON(data []byte) error { return nil } +// LegacyIdentities is exported for use in patch2.go to perform the migration. +func (s *State) LegacyIdentities() json.RawMessage { + return s.legacyIdentities +} + func (s *State) checkpointData() []byte { data, err := json.Marshal(s) if err != nil { From c07fff64cd7999f79b3c1bf7fb966ee374773761 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Mon, 5 Jan 2026 15:47:12 +1300 Subject: [PATCH 15/15] move migration tests to patch/patch2_test.go --- .../overlord/identities/identities_test.go | 89 ------------ internals/overlord/patch/patch2_test.go | 133 ++++++++++++++++++ internals/overlord/state/state.go | 1 + 3 files changed, 134 insertions(+), 89 deletions(-) create mode 100644 internals/overlord/patch/patch2_test.go diff --git a/internals/overlord/identities/identities_test.go b/internals/overlord/identities/identities_test.go index a9d5ee379..fef11b4a6 100644 --- a/internals/overlord/identities/identities_test.go +++ b/internals/overlord/identities/identities_test.go @@ -25,7 +25,6 @@ import ( . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/overlord/identities" - "github.com/canonical/pebble/internals/overlord/patch" "github.com/canonical/pebble/internals/overlord/state" ) @@ -322,94 +321,6 @@ func (s *identitiesSuite) TestUnmarshalState(c *C) { }) } -func (s *identitiesSuite) TestUnmarshalStateLegacy(c *C) { - data := []byte(` -{ - "data": {}, - "identities": { - "bob": { - "access": "read", - "local": { - "user-id": 42 - } - }, - "mary": { - "access": "admin", - "local": { - "user-id": 1000 - } - } - } -}`) - - st, err := state.ReadState(nil, bytes.NewReader(data)) - c.Assert(err, IsNil) - err = patch.Apply(st) - c.Assert(err, IsNil) - mgr, err := identities.NewManager(st) - c.Assert(err, IsNil) - - st.Lock() - defer st.Unlock() - - c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ - "bob": { - Name: "bob", - Access: identities.ReadAccess, - Local: &identities.LocalIdentity{UserID: 42}, - }, - "mary": { - Name: "mary", - Access: identities.AdminAccess, - Local: &identities.LocalIdentity{UserID: 1000}, - }, - }) -} - -func (s *identitiesSuite) TestUnmarshalStateNewAndLegacy(c *C) { - // If both new and legacy are present, it should prefer the new - // (and emit a warning log, but we don't test for that). - data := []byte(` -{ - "data": { - "identities": { - "bob": { - "access": "read", - "local": { - "user-id": 42 - } - } - } - }, - "identities": { - "mary": { - "access": "admin", - "local": { - "user-id": 1000 - } - } - } -}`) - - st, err := state.ReadState(nil, bytes.NewReader(data)) - c.Assert(err, IsNil) - err = patch.Apply(st) - c.Assert(err, IsNil) - mgr, err := identities.NewManager(st) - c.Assert(err, IsNil) - - st.Lock() - defer st.Unlock() - - c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ - "bob": { - Name: "bob", - Access: identities.ReadAccess, - Local: &identities.LocalIdentity{UserID: 42}, - }, - }) -} - func (s *identitiesSuite) TestAddIdentities(c *C) { st := state.New(nil) mgr, err := identities.NewManager(st) diff --git a/internals/overlord/patch/patch2_test.go b/internals/overlord/patch/patch2_test.go new file mode 100644 index 000000000..034f06edd --- /dev/null +++ b/internals/overlord/patch/patch2_test.go @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package patch_test + +import ( + "bytes" + + . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/overlord/identities" + "github.com/canonical/pebble/internals/overlord/patch" + "github.com/canonical/pebble/internals/overlord/state" +) + +type patch2Suite struct{} + +var _ = Suite(&patch2Suite{}) + +func (s *patch2Suite) TestLegacyIdentities(c *C) { + restore := patch.FakeLevel(2, 1) + defer restore() + + data := []byte(` +{ + "data": {"patch-level": 1}, + "identities": { + "bob": { + "access": "read", + "local": { + "user-id": 42 + } + }, + "mary": { + "access": "admin", + "local": { + "user-id": 1000 + } + } + } +}`) + + st, err := state.ReadState(nil, bytes.NewReader(data)) + c.Assert(err, IsNil) + err = patch.Apply(st) + c.Assert(err, IsNil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + + st.Lock() + defer st.Unlock() + + c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ + "bob": { + Name: "bob", + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, + }, + "mary": { + Name: "mary", + Access: identities.AdminAccess, + Local: &identities.LocalIdentity{UserID: 1000}, + }, + }) + + // ensure we moved forward to patch-level 2 (sublevel 0) + var patchLevel int + err = st.Get("patch-level", &patchLevel) + c.Assert(err, IsNil) + c.Assert(patchLevel, Equals, 2) + err = st.Get("patch-sublevel", &patchLevel) + c.Assert(err, IsNil) + c.Assert(patchLevel, Equals, 0) +} + +func (s *patch2Suite) TestNewAndLegacyIdentities(c *C) { + restore := patch.FakeLevel(2, 1) + defer restore() + + // If both new and legacy are present, it should prefer the new + // (and emit a warning log, but we don't test for that). + data := []byte(` +{ + "data": { + "patch-level": 1, + "identities": { + "bob": { + "access": "read", + "local": { + "user-id": 42 + } + } + } + }, + "identities": { + "mary": { + "access": "admin", + "local": { + "user-id": 1000 + } + } + } +}`) + + st, err := state.ReadState(nil, bytes.NewReader(data)) + c.Assert(err, IsNil) + err = patch.Apply(st) + c.Assert(err, IsNil) + mgr, err := identities.NewManager(st) + c.Assert(err, IsNil) + + st.Lock() + defer st.Unlock() + + c.Assert(mgr.Identities(), DeepEquals, map[string]*identities.Identity{ + "bob": { + Name: "bob", + Access: identities.ReadAccess, + Local: &identities.LocalIdentity{UserID: 42}, + }, + }) +} diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index 32bb24917..8d02883f6 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -221,6 +221,7 @@ func (s *State) UnmarshalJSON(data []byte) error { // LegacyIdentities is exported for use in patch2.go to perform the migration. func (s *State) LegacyIdentities() json.RawMessage { + s.reading() return s.legacyIdentities }