diff --git a/vault/identity_store_entities.go b/vault/identity_store_entities.go index f326edc8faea..f3b3caf87cf8 100644 --- a/vault/identity_store_entities.go +++ b/vault/identity_store_entities.go @@ -446,6 +446,26 @@ func (i *IdentityStore) handleEntityReadCommon(entity *identity.Entity) (*logica // formats respData["aliases"] = aliasesToReturn + // Fetch the groups this entity belongs to and return their identifiers + groups, inheritedGroups, err := i.groupsByEntityID(entity.ID) + if err != nil { + return nil, err + } + + groupIDs := make([]string, len(groups)) + for i, group := range groups { + groupIDs[i] = group.ID + } + respData["direct_group_ids"] = groupIDs + + inheritedGroupIDs := make([]string, len(inheritedGroups)) + for i, group := range inheritedGroups { + inheritedGroupIDs[i] = group.ID + } + respData["inherited_group_ids"] = inheritedGroupIDs + + respData["group_ids"] = append(groupIDs, inheritedGroupIDs...) + return &logical.Response{ Data: respData, }, nil diff --git a/vault/identity_store_entities_test.go b/vault/identity_store_entities_test.go index 10225fd444aa..23beabc99c71 100644 --- a/vault/identity_store_entities_test.go +++ b/vault/identity_store_entities_test.go @@ -12,6 +12,86 @@ import ( "github.com/hashicorp/vault/logical" ) +func TestIdentityStore_EntityReadGroupIDs(t *testing.T) { + var err error + var resp *logical.Response + + i, _, _ := testIdentityStoreWithGithubAuth(t) + + entityReq := &logical.Request{ + Path: "entity", + Operation: logical.UpdateOperation, + } + + resp, err = i.HandleRequest(entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) + } + + entityID := resp.Data["id"].(string) + + groupReq := &logical.Request{ + Path: "group", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "member_entity_ids": []string{ + entityID, + }, + }, + } + + resp, err = i.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) + } + + groupID := resp.Data["id"].(string) + + // Create another group with the above created group as its subgroup + + groupReq.Data = map[string]interface{}{ + "member_group_ids": []string{groupID}, + } + resp, err = i.HandleRequest(groupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) + } + + inheritedGroupID := resp.Data["id"].(string) + + lookupReq := &logical.Request{ + Path: "lookup/entity", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "type": "id", + "id": entityID, + }, + } + + resp, err = i.HandleRequest(lookupReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) + } + + expected := []string{groupID, inheritedGroupID} + actual := resp.Data["group_ids"].([]string) + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("bad: group_ids; expected: %#v\nactual: %#v\n", expected, actual) + } + + expected = []string{groupID} + actual = resp.Data["direct_group_ids"].([]string) + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("bad: direct_group_ids; expected: %#v\nactual: %#v\n", expected, actual) + } + + expected = []string{inheritedGroupID} + actual = resp.Data["inherited_group_ids"].([]string) + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("bad: inherited_group_ids; expected: %#v\nactual: %#v\n", expected, actual) + } +} + func TestIdentityStore_EntityCreateUpdate(t *testing.T) { var err error var resp *logical.Response diff --git a/vault/identity_store_groups_test.go b/vault/identity_store_groups_test.go index a7a7d166793f..45287c2f3ca8 100644 --- a/vault/identity_store_groups_test.go +++ b/vault/identity_store_groups_test.go @@ -546,11 +546,11 @@ func TestIdentityStore_GroupMultiCase(t *testing.T) { /* Test groups hierarchy: - eng - | | - vault ops - | | | | - kube identity build deploy + ------- eng(entityID3) ------- + | | + ----- vault ----- -- ops(entityID2) -- + | | | | + kube(entityID1) identity build deploy */ func TestIdentityStore_GroupHierarchyCases(t *testing.T) { var resp *logical.Response @@ -808,27 +808,36 @@ func TestIdentityStore_GroupHierarchyCases(t *testing.T) { t.Fatalf("bad: policies; expected: 'engpolicy'\nactual:%#v", policies) } - groups, err := is.transitiveGroupsByEntityID(entityID1) + groups, inheritedGroups, err := is.groupsByEntityID(entityID1) if err != nil { t.Fatal(err) } - if len(groups) != 3 { - t.Fatalf("bad: length of groups; expected: 3, actual: %d", len(groups)) + if len(groups) != 1 { + t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups)) + } + if len(inheritedGroups) != 2 { + t.Fatalf("bad: length of inheritedGroups; expected: 2, actual: %d", len(inheritedGroups)) } - groups, err = is.transitiveGroupsByEntityID(entityID2) + groups, inheritedGroups, err = is.groupsByEntityID(entityID2) if err != nil { t.Fatal(err) } - if len(groups) != 2 { - t.Fatalf("bad: length of groups; expected: 2, actual: %d", len(groups)) + if len(groups) != 1 { + t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups)) + } + if len(inheritedGroups) != 1 { + t.Fatalf("bad: length of inheritedGroups; expected: 1, actual: %d", len(inheritedGroups)) } - groups, err = is.transitiveGroupsByEntityID(entityID3) + groups, inheritedGroups, err = is.groupsByEntityID(entityID3) if err != nil { t.Fatal(err) } if len(groups) != 1 { t.Fatalf("bad: length of groups; expected: 1, actual: %d", len(groups)) } + if len(inheritedGroups) != 0 { + t.Fatalf("bad: length of inheritedGroups; expected: 0, actual: %d", len(inheritedGroups)) + } } diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 4fb99e8be89d..381db6a68809 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -1970,14 +1970,14 @@ func (i *IdentityStore) groupPoliciesByEntityID(entityID string) ([]string, erro return strutil.RemoveDuplicates(policies, false), nil } -func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity.Group, error) { +func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) { if entityID == "" { - return nil, fmt.Errorf("empty entity ID") + return nil, nil, fmt.Errorf("empty entity ID") } - groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false) + groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false) if err != nil { - return nil, err + return nil, nil, err } visited := make(map[string]bool) @@ -1985,7 +1985,7 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity for _, group := range groups { gGroups, err := i.collectGroupsReverseDFS(group, visited, nil) if err != nil { - return nil, err + return nil, nil, err } tGroups = append(tGroups, gGroups...) } @@ -2001,7 +2001,15 @@ func (i *IdentityStore) transitiveGroupsByEntityID(entityID string) ([]*identity tGroups = append(tGroups, group) } - return tGroups, nil + diff := diffGroups(groups, tGroups) + + // For sanity + // There should not be any group that gets deleted + if len(diff.Deleted) != 0 { + return nil, nil, fmt.Errorf("failed to diff group memberships") + } + + return diff.Unmodified, diff.New, nil } func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) {