Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ func (e *Enforcer) initialize() {
e.autoBuildRoleLinks = true
e.autoNotifyWatcher = true
e.autoNotifyDispatcher = true

if e.model["g"] != nil && len(e.model["g"]) > 1 {
e.model.AddDef("g", "*", "_, _")
}

e.initRmMap()
}

Expand Down Expand Up @@ -477,6 +482,11 @@ func (e *Enforcer) SavePolicy() error {
if e.IsFiltered() {
return errors.New("cannot save a filtered policy")
}

if e.model["g"] != nil && e.model["g"]["*"] != nil {
delete(e.model["g"], "*")
}

if err := e.adapter.SavePolicy(e.model); err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error {
for s := range sectionNameMap {
loadSection(model, cfg, s)
}

ms := make([]string, 0)
for _, rs := range requiredSections {
if !model.hasSection(rs) {
Expand Down Expand Up @@ -392,6 +393,9 @@ func (model Model) ToText() string {
if _, ok := model["g"]; ok {
s.WriteString("[role_definition]\n")
for ptype := range model["g"] {
if ptype == "*" {
continue
}
s.WriteString(fmt.Sprintf("%s = %s\n", ptype, model["g"][ptype].Value))
}
}
Expand Down
10 changes: 10 additions & 0 deletions model/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ func (model Model) AddPolicy(sec string, ptype string, rule []string) error {
assertion.Policy = append(assertion.Policy, rule)
assertion.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1

if sec == "g" && len(model["g"]) > 1 {
assert, err := model.GetAssertion(sec, "*")
if err != nil {
return err
} else {
assert.Policy = append(assert.Policy, rule)
assert.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec]["*"].Policy) - 1
}
}

hasPriority := false
if _, ok := assertion.FieldIndexMap[constant.PriorityIndex]; ok {
hasPriority = true
Expand Down
6 changes: 6 additions & 0 deletions rbac_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ func (e *Enforcer) GetImplicitRolesForUser(name string, domain ...string) ([]str
res = append(res, roles...)
}

util.ArrayRemoveDuplicates(&res)
return res, nil
}

Expand Down Expand Up @@ -311,6 +312,7 @@ func (e *Enforcer) GetImplicitUsersForRole(name string, domain ...string) ([]str
}
}

util.ArrayRemoveDuplicates(&res)
return res, nil
}

Expand All @@ -327,6 +329,10 @@ func (e *Enforcer) GetImplicitPermissionsForUser(user string, domain ...string)
return e.GetNamedImplicitPermissionsForUser("p", "g", user, domain...)
}

func (e *Enforcer) GetImplicitPermissionsForUserFromAllRoles(user string, domain ...string) ([][]string, error) {
return e.GetNamedImplicitPermissionsForUser("p", "*", user, domain...)
}

// GetNamedImplicitPermissionsForUser gets implicit permissions for a user or role by named policy.
// Compared to GetNamedPermissionsForUser(), this function retrieves permissions for inherited roles.
// For example:
Expand Down
13 changes: 13 additions & 0 deletions rbac_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package casbin

import (
"fmt"
"log"
"sort"
"testing"
Expand Down Expand Up @@ -283,6 +284,7 @@ func TestPermissionAPI(t *testing.T) {
e, _ = NewEnforcer("examples/rbac_with_multiple_policy_model.conf", "examples/rbac_with_multiple_policy_policy.csv")
testGetNamedPermissionsForUser(t, e, "p", "user", [][]string{{"user", "/data", "GET"}})
testGetNamedPermissionsForUser(t, e, "p2", "user", [][]string{{"user", "view"}})
testGetImplicitPermissionsForUserFromAllRoles(t, e, "alice", [][]string{{"user", "/data", "GET"}, {"admin", "/data", "POST"}})
}

func testGetImplicitRoles(t *testing.T, e *Enforcer, name string, res []string) {
Expand Down Expand Up @@ -334,6 +336,17 @@ func testGetImplicitPermissions(t *testing.T, e *Enforcer, name string, res [][]
}
}

func testGetImplicitPermissionsForUserFromAllRoles(t *testing.T, e *Enforcer, name string, res [][]string, domain ...string) {
t.Helper()
myRes, _ := e.GetImplicitPermissionsForUserFromAllRoles(name, domain...)
fmt.Println(e.GetNamedImplicitRolesForUser("*", name))
t.Log("Implicit permissions for ", name, "from all roles", ": ", myRes)

if !util.Set2DEquals(res, myRes) {
t.Error("Implicit permissions for ", name, ": ", myRes, ", supposed to be ", res)
}
}

func testGetImplicitPermissionsWithDomain(t *testing.T, e *Enforcer, name string, domain string, res [][]string) {
t.Helper()
myRes, _ := e.GetImplicitPermissionsForUser(name, domain)
Expand Down
Loading