Skip to content

Commit

Permalink
Return group memberships of entity during read (#3526)
Browse files Browse the repository at this point in the history
* return group memberships of entity during read

* Add implied group memberships to read response of entity

* distinguish between all, direct and inherited group IDs of an entity

* address review feedback

* address review feedback

* s/implied/inherited in tests
  • Loading branch information
vishalnayak authored Nov 6, 2017
1 parent 447d13e commit 55c032d
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 18 deletions.
20 changes: 20 additions & 0 deletions vault/identity_store_entities.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,26 @@ func (i *IdentityStore) handleEntityReadCommon(entity *identity.Entity) (*logica
// formats
respData["aliases"] = aliasesToReturn

// Fetch the groups this entity belongs to and return their identifiers
groups, inheritedGroups, err := i.groupsByEntityID(entity.ID)
if err != nil {
return nil, err
}

groupIDs := make([]string, len(groups))
for i, group := range groups {
groupIDs[i] = group.ID
}
respData["direct_group_ids"] = groupIDs

inheritedGroupIDs := make([]string, len(inheritedGroups))
for i, group := range inheritedGroups {
inheritedGroupIDs[i] = group.ID
}
respData["inherited_group_ids"] = inheritedGroupIDs

respData["group_ids"] = append(groupIDs, inheritedGroupIDs...)

return &logical.Response{
Data: respData,
}, nil
Expand Down
80 changes: 80 additions & 0 deletions vault/identity_store_entities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,86 @@ import (
"github.com/hashicorp/vault/logical"
)

func TestIdentityStore_EntityReadGroupIDs(t *testing.T) {
var err error
var resp *logical.Response

i, _, _ := testIdentityStoreWithGithubAuth(t)

entityReq := &logical.Request{
Path: "entity",
Operation: logical.UpdateOperation,
}

resp, err = i.HandleRequest(entityReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}

entityID := resp.Data["id"].(string)

groupReq := &logical.Request{
Path: "group",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"member_entity_ids": []string{
entityID,
},
},
}

resp, err = i.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}

groupID := resp.Data["id"].(string)

// Create another group with the above created group as its subgroup

groupReq.Data = map[string]interface{}{
"member_group_ids": []string{groupID},
}
resp, err = i.HandleRequest(groupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}

inheritedGroupID := resp.Data["id"].(string)

lookupReq := &logical.Request{
Path: "lookup/entity",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"type": "id",
"id": entityID,
},
}

resp, err = i.HandleRequest(lookupReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}

expected := []string{groupID, inheritedGroupID}
actual := resp.Data["group_ids"].([]string)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: group_ids; expected: %#v\nactual: %#v\n", expected, actual)
}

expected = []string{groupID}
actual = resp.Data["direct_group_ids"].([]string)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: direct_group_ids; expected: %#v\nactual: %#v\n", expected, actual)
}

expected = []string{inheritedGroupID}
actual = resp.Data["inherited_group_ids"].([]string)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("bad: inherited_group_ids; expected: %#v\nactual: %#v\n", expected, actual)
}
}

func TestIdentityStore_EntityCreateUpdate(t *testing.T) {
var err error
var resp *logical.Response
Expand Down
33 changes: 21 additions & 12 deletions vault/identity_store_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,11 +546,11 @@ func TestIdentityStore_GroupMultiCase(t *testing.T) {

/*
Test groups hierarchy:
eng
| |
vault ops
| | | |
kube identity build deploy
------- eng(entityID3) -------
| |
----- vault ----- -- ops(entityID2) --
| | | |
kube(entityID1) identity build deploy
*/
func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
var resp *logical.Response
Expand Down Expand Up @@ -808,27 +808,36 @@ func TestIdentityStore_GroupHierarchyCases(t *testing.T) {
t.Fatalf("bad: policies; expected: 'engpolicy'\nactual:%#v", policies)
}

groups, err := is.transitiveGroupsByEntityID(entityID1)
groups, inheritedGroups, err := is.groupsByEntityID(entityID1)
if err != nil {
t.Fatal(err)
}
if len(groups) != 3 {
t.Fatalf("bad: length of groups; expected: 3, actual: %d", len(groups))
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
if len(inheritedGroups) != 2 {
t.Fatalf("bad: length of inheritedGroups; expected: 2, actual: %d", len(inheritedGroups))
}

groups, err = is.transitiveGroupsByEntityID(entityID2)
groups, inheritedGroups, err = is.groupsByEntityID(entityID2)
if err != nil {
t.Fatal(err)
}
if len(groups) != 2 {
t.Fatalf("bad: length of groups; expected: 2, actual: %d", len(groups))
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
if len(inheritedGroups) != 1 {
t.Fatalf("bad: length of inheritedGroups; expected: 1, actual: %d", len(inheritedGroups))
}

groups, err = is.transitiveGroupsByEntityID(entityID3)
groups, inheritedGroups, err = is.groupsByEntityID(entityID3)
if err != nil {
t.Fatal(err)
}
if len(groups) != 1 {
t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups))
}
if len(inheritedGroups) != 0 {
t.Fatalf("bad: length of inheritedGroups; expected: 0, actual: %d", len(inheritedGroups))
}
}
20 changes: 14 additions & 6 deletions vault/identity_store_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1970,22 +1970,22 @@ func (i *IdentityStore) groupPoliciesByEntityID(entityID string) ([]string, erro
return strutil.RemoveDuplicates(policies, false), nil
}

func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity.Group, error) {
func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) {
if entityID == "" {
return nil, fmt.Errorf("empty entity ID")
return nil, nil, fmt.Errorf("empty entity ID")
}

groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false)
groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false)
if err != nil {
return nil, err
return nil, nil, err
}

visited := make(map[string]bool)
var tGroups []*identity.Group
for _, group := range groups {
gGroups, err := i.collectGroupsReverseDFS(group, visited, nil)
if err != nil {
return nil, err
return nil, nil, err
}
tGroups = append(tGroups, gGroups...)
}
Expand All @@ -2001,7 +2001,15 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity
tGroups = append(tGroups, group)
}

return tGroups, nil
diff := diffGroups(groups, tGroups)

// For sanity
// There should not be any group that gets deleted
if len(diff.Deleted) != 0 {
return nil, nil, fmt.Errorf("failed to diff group memberships")
}

return diff.Unmodified, diff.New, nil
}

func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) {
Expand Down

0 comments on commit 55c032d

Please sign in to comment.