Skip to content

Commit

Permalink
✨ Add ClaimsMappingPolicy
Browse files Browse the repository at this point in the history
This adds initial support for the ClaimsMappingPolicy resource.
Support for this resource type is prerequisite for adding support
to the azure ad terraform provider so it can support Azure AD
Claims Mapping polices.

Fixes: manicminer#118

Graph API Reference:
https://docs.microsoft.com/en-us/graph/api/resources/claimsmappingpolicy?view=graph-rest-1.0
  • Loading branch information
computeracer authored and dhohengassner committed Feb 23, 2022
1 parent f7a9113 commit 9f20f0b
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 0 deletions.
5 changes: 5 additions & 0 deletions internal/test/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ type Test struct {
ApplicationsClient *msgraph.ApplicationsClient
AppRoleAssignedToClient *msgraph.AppRoleAssignedToClient
AuthenticationMethodsClient *msgraph.AuthenticationMethodsClient
ClaimsMappingPolicyClient *msgraph.ClaimsMappingPolicyClient
ConditionalAccessPoliciesClient *msgraph.ConditionalAccessPoliciesClient
DelegatedPermissionGrantsClient *msgraph.DelegatedPermissionGrantsClient
DirectoryAuditReportsClient *msgraph.DirectoryAuditReportsClient
Expand Down Expand Up @@ -322,6 +323,10 @@ func NewTest(t *testing.T) (c *Test) {
c.UsersClient.BaseClient.Authorizer = c.Connection.Authorizer
c.UsersClient.BaseClient.Endpoint = c.Connection.AuthConfig.Environment.MsGraph.Endpoint
c.UsersClient.BaseClient.RetryableClient.RetryMax = retry
c.ClaimsMappingPolicyClient = msgraph.NewClaimsMappingPolicyClient(c.Connection.AuthConfig.TenantID)
c.ClaimsMappingPolicyClient.BaseClient.Authorizer = c.Connection.Authorizer
c.ClaimsMappingPolicyClient.BaseClient.Endpoint = c.Connection.AuthConfig.Environment.MsGraph.Endpoint
c.ClaimsMappingPolicyClient.BaseClient.RetryableClient.RetryMax = retry

return
}
5 changes: 5 additions & 0 deletions internal/utils/pointers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ func Int32Ptr(i int32) *int32 {
func StringPtr(s string) *string {
return &s
}

// ArrayStringPtr returns a pointer to the provided array of strings.
func ArrayStringPtr(s []string) *[]string {
return &s
}
169 changes: 169 additions & 0 deletions msgraph/claims_mapping_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package msgraph

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/manicminer/hamilton/odata"
)

type ClaimsMappingPolicyClient struct {
BaseClient Client
}

// NewClaimsMappingPolicyClient returns a new ClaimsMappingPolicyClient
func NewClaimsMappingPolicyClient(tenantId string) *ClaimsMappingPolicyClient {
return &ClaimsMappingPolicyClient{
BaseClient: NewClient(VersionBeta, tenantId),
}
}

// Create creates a new ClaimsMappingPolicy.
func (c *ClaimsMappingPolicyClient) Create(ctx context.Context, policy ClaimsMappingPolicy) (*ClaimsMappingPolicy, int, error) {
var status int

body, err := json.Marshal(policy)
if err != nil {
return nil, status, fmt.Errorf("json.Marshal(): %v", err)
}

resp, status, _, err := c.BaseClient.Post(ctx, PostHttpRequestInput{
Body: body,
ValidStatusCodes: []int{http.StatusCreated},
Uri: Uri{
Entity: "/policies/claimsMappingPolicies",
HasTenantId: false,
},
})
if err != nil {
return nil, status, fmt.Errorf("ClaimsMappingPolicyClient.BaseClient.Post(): %v", err)
}

defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("io.ReadAll(): %v", err)
}

var newPolicy ClaimsMappingPolicy
if err := json.Unmarshal(respBody, &newPolicy); err != nil {
return nil, status, fmt.Errorf("json.Unmarshal(): %v", err)
}

return &newPolicy, status, nil
}

// List returns a list of ClaimsMappingPolicy, optionally queried using OData.
func (c *ClaimsMappingPolicyClient) List(ctx context.Context, query odata.Query) (*[]ClaimsMappingPolicy, int, error) {
resp, status, _, err := c.BaseClient.Get(ctx, GetHttpRequestInput{
DisablePaging: query.Top > 0,
OData: query,
ValidStatusCodes: []int{http.StatusOK},
Uri: Uri{
Entity: "/policies/claimsMappingPolicies",
HasTenantId: false,
},
})
if err != nil {
return nil, status, fmt.Errorf("ClaimsMappingPolicyClient.BaseClient.Get(): %v", err)
}

defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("io.ReadAll(): %v", err)
}

var data struct {
ClaimsMappingPolicies []ClaimsMappingPolicy `json:"value"`
}
if err := json.Unmarshal(respBody, &data); err != nil {
return nil, status, fmt.Errorf("json.Unmarshal(): %v", err)
}

return &data.ClaimsMappingPolicies, status, nil
}

// Get retrieves a ClaimsMappingPolicy.
func (c *ClaimsMappingPolicyClient) Get(ctx context.Context, id string, query odata.Query) (*ClaimsMappingPolicy, int, error) {
resp, status, _, err := c.BaseClient.Get(ctx, GetHttpRequestInput{
ConsistencyFailureFunc: RetryOn404ConsistencyFailureFunc,
OData: query,
ValidStatusCodes: []int{http.StatusOK},
Uri: Uri{
Entity: fmt.Sprintf("/policies/claimsMappingPolicies/%s", id),
HasTenantId: false,
},
})
if err != nil {
return nil, status, fmt.Errorf("ClaimsMappingPolicyClient.BaseClient.Get(): %v", err)
}

defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, status, fmt.Errorf("io.ReadAll(): %v", err)
}

var claimsMappingPolicies ClaimsMappingPolicy
if err := json.Unmarshal(respBody, &claimsMappingPolicies); err != nil {
return nil, status, fmt.Errorf("json.Unmarshal(): %v", err)
}

return &claimsMappingPolicies, status, nil
}

// Update amends an existing ClaimsMappingPolicy.
func (c *ClaimsMappingPolicyClient) Update(ctx context.Context, claimsMappingPolicy ClaimsMappingPolicy) (int, error) {
var status int

if claimsMappingPolicy.ID == nil {
return status, fmt.Errorf("cannot update ClaimsMappingPolicy with nil ID")
}

claimsMappingPolicyId := *claimsMappingPolicy.ID
claimsMappingPolicy.ID = nil

body, err := json.Marshal(claimsMappingPolicy)
if err != nil {
return status, fmt.Errorf("json.Marshal(): %v", err)
}

_, status, _, err = c.BaseClient.Patch(ctx, PatchHttpRequestInput{
Body: body,
ConsistencyFailureFunc: RetryOn404ConsistencyFailureFunc,
ValidStatusCodes: []int{
http.StatusOK,
http.StatusNoContent,
},
Uri: Uri{
Entity: fmt.Sprintf("/policies/claimsMappingPolicies/%s", claimsMappingPolicyId),
HasTenantId: true,
},
})
if err != nil {
return status, fmt.Errorf("ClaimsMappingPolicy.BaseClient.Patch(): %v", err)
}

return status, nil
}

// Delete removes a ClaimsMappingPolicy.
func (c *ClaimsMappingPolicyClient) Delete(ctx context.Context, id string) (int, error) {
_, status, _, err := c.BaseClient.Delete(ctx, DeleteHttpRequestInput{
ConsistencyFailureFunc: RetryOn404ConsistencyFailureFunc,
ValidStatusCodes: []int{http.StatusNoContent},
Uri: Uri{
Entity: fmt.Sprintf("/policies/claimsMappingPolicies/%s", id),
HasTenantId: false,
},
})
if err != nil {
return status, fmt.Errorf("ClaimsMappingPolicyClient.BaseClient.Delete(): %v", err)
}

return status, nil
}
91 changes: 91 additions & 0 deletions msgraph/claims_mapping_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package msgraph_test

import (
"fmt"
"testing"

"github.com/manicminer/hamilton/internal/test"
"github.com/manicminer/hamilton/internal/utils"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"
)

func TestClaimsMappingPolicyClient(t *testing.T) {
c := test.NewTest(t)
defer c.CancelFunc()
policy := testClaimsMappingPolicyClient_Create(t, c, msgraph.ClaimsMappingPolicy{
DisplayName: utils.StringPtr(fmt.Sprintf("test-claims-mapping-policy-%s", c.RandomString)),
Definition: utils.ArrayStringPtr(
[]string{
"{\"ClaimsMappingPolicy\":{\"Version\":1,\"IncludeBasicClaimSet\":\"true\",\"ClaimsSchema\": [{\"Source\":\"user\",\"ID\":\"employeeid\",\"SamlClaimType\":\"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name\",\"JwtClaimType\":\"name\"},{\"Source\":\"company\",\"ID\":\"tenantcountry\",\"SamlClaimType\":\"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/country\",\"JwtClaimType\":\"country\"}]}}",
},
),
})
testClaimsMappingPolicyClient_List(t, c)
testClaimsMappingPolicyClient_Get(t, c, *policy.ID)
policy.DisplayName = utils.StringPtr(fmt.Sprintf("test-claims-mapping-policy-%s", c.RandomString))
testClaimsMappingPolicyClient_Update(t, c, *policy)
testClaimsMappingPolicyClient_Delete(t, c, *policy.ID)
}

func testClaimsMappingPolicyClient_Create(t *testing.T, c *test.Test, p msgraph.ClaimsMappingPolicy) (policy *msgraph.ClaimsMappingPolicy) {
policy, status, err := c.ClaimsMappingPolicyClient.Create(c.Context, p)
if err != nil {
t.Fatalf("ClaimsMappingPolicyClient.Create(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("ClaimsMappingPolicyClient.Create(): invalid status: %d", status)
}
if policy == nil {
t.Fatal("ClaimsMappingPolicyClient.Create(): policy was nil")
}
if policy.ID == nil {
t.Fatal("ClaimsMappingPolicyClient.Create(): policy.ID was nil")
}
return
}

func testClaimsMappingPolicyClient_List(t *testing.T, c *test.Test) (policy *[]msgraph.ClaimsMappingPolicy) {
policies, _, err := c.ClaimsMappingPolicyClient.List(c.Context, odata.Query{Top: 10})
if err != nil {
t.Fatalf("ClaimsMappingPolicy.List(): %v", err)
}
if policies == nil {
t.Fatal("ClaimsMappingPolicy.List(): ClaimsMappingPolicies was nil")
}
return
}

func testClaimsMappingPolicyClient_Get(t *testing.T, c *test.Test, id string) (ClaimsMappingPolicy *msgraph.ClaimsMappingPolicy) {
policies, status, err := c.ClaimsMappingPolicyClient.Get(c.Context, id, odata.Query{})
if err != nil {
t.Fatalf("ClaimsMappingPolicyClient.Get(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("ClaimsMappingPolicyClient.Get(): invalid status: %d", status)
}
if policies == nil {
t.Fatal("ClaimsMappingPolicyClient.Get(): policies was nil")
}
return
}

func testClaimsMappingPolicyClient_Update(t *testing.T, c *test.Test, g msgraph.ClaimsMappingPolicy) {
status, err := c.ClaimsMappingPolicyClient.Update(c.Context, g)
if err != nil {
t.Fatalf("ClaimsMappingPolicyClient.Update(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("ClaimsMappingPolicyClient.Update(): invalid status: %d", status)
}
}

func testClaimsMappingPolicyClient_Delete(t *testing.T, c *test.Test, id string) {
status, err := c.ClaimsMappingPolicyClient.Delete(c.Context, id)
if err != nil {
t.Fatalf("ClaimsMappingPolicyClient.Delete(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("ClaimsMappingPolicyClient.Delete(): invalid status: %d", status)
}
}
8 changes: 8 additions & 0 deletions msgraph/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,14 @@ type BaseNamedLocation struct {
ModifiedDateTime *time.Time `json:"modifiedDateTime,omitempty"`
}

type ClaimsMappingPolicy struct {
Definition *[]string `json:"definition,omitempty"`
Description *string `json:"description,omitempty"`
DisplayName *string `json:"displayName,omitempty"`
ID *string `json:"id,omitempty"`
IsOrganizationDefault *bool `json:"isOrganizationDefault,omitempty"`
}

type CloudAppSecurityControl struct {
IsEnabled *bool `json:"isEnabled,omitempty"`
CloudAppSecurityType *ConditionalAccessCloudAppSecuritySessionControlType `json:"cloudAppSecurityType,omitempty"`
Expand Down

0 comments on commit 9f20f0b

Please sign in to comment.