diff --git a/auth/claims.go b/auth/claims.go index 4a809735..2bc2c455 100644 --- a/auth/claims.go +++ b/auth/claims.go @@ -3,8 +3,9 @@ package auth import ( "encoding/base64" "encoding/json" - "golang.org/x/oauth2" "strings" + + "golang.org/x/oauth2" ) // Claims is used to unmarshall the claims from a JWT issued by the Microsoft Identity Platform. @@ -30,7 +31,7 @@ func ParseClaims(token *oauth2.Token) (claims Claims, err error) { return } jwt := strings.Split(token.AccessToken, ".") - payload, err := base64.StdEncoding.DecodeString(jwt[1]) + payload, err := base64.RawStdEncoding.DecodeString(jwt[1]) if err != nil { return } diff --git a/clients/users.go b/clients/users.go index ad94735b..0adfb2c8 100644 --- a/clients/users.go +++ b/clients/users.go @@ -135,3 +135,31 @@ func (c *UsersClient) Delete(ctx context.Context, id string) (int, error) { } return status, nil } + +// ListGroupMemberships returns a list of Groups the user is member of, optionally filtered using OData. +func (c *UsersClient) ListGroupMemberships(ctx context.Context, id string, filter string) (*[]models.Group, int, error) { + params := url.Values{} + if filter != "" { + params.Add("$filter", filter) + } + resp, status, err := c.BaseClient.Get(ctx, base.GetHttpRequestInput{ + ValidStatusCodes: []int{http.StatusOK}, + Uri: base.Uri{ + Entity: fmt.Sprintf("/users/%s/transitiveMemberOf", id), + Params: params, + HasTenantId: true, + }, + }) + if err != nil { + return nil, status, err + } + defer resp.Body.Close() + respBody, _ := ioutil.ReadAll(resp.Body) + var data struct { + Groups []models.Group `json:"value"` + } + if err := json.Unmarshal(respBody, &data); err != nil { + return nil, status, err + } + return &data.Groups, status, nil +} diff --git a/clients/users_test.go b/clients/users_test.go index 82f35e27..e6fce341 100644 --- a/clients/users_test.go +++ b/clients/users_test.go @@ -16,9 +16,10 @@ type UsersClientTest struct { } func TestUsersClient(t *testing.T) { + rs := internal.RandomString() c := UsersClientTest{ connection: internal.NewConnection(), - randomString: internal.RandomString(), + randomString: rs, } c.client = clients.NewUsersClient(c.connection.AuthConfig.TenantID) c.client.BaseClient.Authorizer = c.connection.Authorizer @@ -36,6 +37,38 @@ func TestUsersClient(t *testing.T) { user.DisplayName = internal.String(fmt.Sprintf("test-updated-user-%s", c.randomString)) testUsersClient_Update(t, c, *user) testUsersClient_List(t, c) + + g := GroupsClientTest{ + connection: internal.NewConnection(), + randomString: rs, + } + g.client = clients.NewGroupsClient(g.connection.AuthConfig.TenantID) + g.client.BaseClient.Authorizer = g.connection.Authorizer + + newGroupParent := models.Group{ + DisplayName: internal.String("Test Group Parent"), + MailEnabled: internal.Bool(false), + MailNickname: internal.String(fmt.Sprintf("test-group-parent-%s", c.randomString)), + SecurityEnabled: internal.Bool(true), + } + newGroupChild := models.Group{ + DisplayName: internal.String("Test Group Child"), + MailEnabled: internal.Bool(false), + MailNickname: internal.String(fmt.Sprintf("test-group-child-%s", c.randomString)), + SecurityEnabled: internal.Bool(true), + } + + groupParent := testGroupsClient_Create(t, g, newGroupParent) + groupChild := testGroupsClient_Create(t, g, newGroupChild) + groupParent.AppendMember(g.client.BaseClient.Endpoint, g.client.BaseClient.ApiVersion, *groupChild.ID) + testGroupsClient_AddMembers(t, g, groupParent) + groupChild.AppendMember(g.client.BaseClient.Endpoint, g.client.BaseClient.ApiVersion, *user.ID) + testGroupsClient_AddMembers(t, g, groupChild) + + testUsersClient_ListGroupMemberships(t, c, *user.ID) + testGroupsClient_Delete(t, g, *groupParent.ID) + testGroupsClient_Delete(t, g, *groupChild.ID) + testUsersClient_Delete(t, c, *user.ID) } @@ -100,3 +133,20 @@ func testUsersClient_Delete(t *testing.T, c UsersClientTest, id string) { t.Fatalf("UsersClient.Delete(): invalid status: %d", status) } } + +func testUsersClient_ListGroupMemberships(t *testing.T, c UsersClientTest, id string) (groups *[]models.Group) { + groups, _, err := c.client.ListGroupMemberships(c.connection.Context, id, "") + if err != nil { + t.Fatalf("UsersClient.ListGroupMemberships(): %v", err) + } + + if groups == nil { + t.Fatal("UsersClient.ListGroupMemberships(): groups was nil") + } + + if len(*groups) != 2 { + t.Fatalf("UsersClient.ListGroupMemberships(): expected groups length 2. was: %d", len(*groups)) + } + + return +}