diff --git a/internal/provider/resource_saml_group_mapping.go b/internal/provider/resource_saml_group_mapping.go index e8db79b..0fb2d6c 100644 --- a/internal/provider/resource_saml_group_mapping.go +++ b/internal/provider/resource_saml_group_mapping.go @@ -30,13 +30,9 @@ type UpdateSAMLGroupMappingInput struct { Patch wiz.ModifySAMLGroupMappingPatch `json:"patch"` } -// SAMLGroupMappingsImport represents the structure of a SAML group mapping import. -// It includes the SAML IdP ID, provider group ID, project IDs, and role. type SAMLGroupMappingsImport struct { - SamlIdpID string - ProviderGroupID string - ProjectIDs []string - Role string + SamlIdpID string + GroupMappings []wiz.SAMLGroupDetailsInput } // UpdateSAMLGroupMappingPayload struct @@ -59,23 +55,30 @@ func resourceWizSAMLGroupMapping() *schema.Resource { Required: true, ForceNew: true, }, - "provider_group_id": { - Type: schema.TypeString, - Description: "Provider group ID", - Required: true, - ForceNew: true, - }, - "role": { - Type: schema.TypeString, - Description: "Wiz Role name", - Required: true, - }, - "projects": { - Type: schema.TypeList, - Optional: true, - Description: "Project mapping", - Elem: &schema.Schema{ - Type: schema.TypeString, + "group_mapping": { + Type: schema.TypeSet, + Required: true, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "provider_group_id": { + Type: schema.TypeString, + Description: "Provider group ID", + Required: true, + }, + "role": { + Type: schema.TypeString, + Description: "Wiz Role name", + Required: true, + }, + "projects": { + Type: schema.TypeList, + Optional: true, + Description: "Project mapping", + Elem: &schema.Schema{ + Type: schema.TypeString, + }, + }, + }, }, }, }, @@ -84,9 +87,9 @@ func resourceWizSAMLGroupMapping() *schema.Resource { UpdateContext: resourceSAMLGroupMappingUpdate, DeleteContext: resourceSAMLGroupMappingDelete, Importer: &schema.ResourceImporter{ - StateContext: func(ctx context.Context, d *schema.ResourceData, m interface{}) ([]*schema.ResourceData, error) { - // schema for import id: mapping|||| + StateContext: func(ctx context.Context, d *schema.ResourceData, m interface{}) ([]*schema.ResourceData, error) { + // schema for import id: | mappingToImport, err := extractIDsFromSamlIdpGroupMappingImportID(d.Id()) if err != nil { return nil, err @@ -97,23 +100,22 @@ func resourceWizSAMLGroupMapping() *schema.Resource { return nil, err } - err = d.Set("provider_group_id", mappingToImport.ProviderGroupID) - if err != nil { - return nil, err - } - - err = d.Set("role", mappingToImport.Role) - if err != nil { - return nil, err + var groupMappings []map[string]interface{} + for _, groupMapping := range mappingToImport.GroupMappings { + groupMappingMap := map[string]interface{}{ + "provider_group_id": groupMapping.ProviderGroupID, + "role": groupMapping.Role, + "projects": groupMapping.Projects, + } + groupMappings = append(groupMappings, groupMappingMap) } - err = d.Set("projects", mappingToImport.ProjectIDs) + err = d.Set("group_mappings", groupMappings) if err != nil { return nil, err } d.SetId(uuid.NewString()) - return []*schema.ResourceData{d}, nil }, }, @@ -122,46 +124,53 @@ func resourceWizSAMLGroupMapping() *schema.Resource { func resourceSAMLGroupMappingCreate(ctx context.Context, d *schema.ResourceData, m interface{}) (diags diag.Diagnostics) { tflog.Info(ctx, "resourceWizSAMLGroupMappingCreate called...") - samlIdpID := d.Get("saml_idp_id").(string) - providerGroupID := d.Get("provider_group_id").(string) - role := d.Get("role").(string) - projectIDs := utils.ConvertListToString(d.Get("projects").([]interface{})) - - // verify the mapping doesn't already exist - matchingNode, diags := querySAMLGroupMappings(ctx, m, samlIdpID, providerGroupID, role, projectIDs) - if len(diags) != 0 { - return diags - } + groupMappings := d.Get("group_mapping").(*schema.Set).List() + + for _, item := range groupMappings { + groupMapping := item.(map[string]interface{}) + providerGroupID := groupMapping["provider_group_id"].(string) + role := groupMapping["role"].(string) + projectIDs := utils.ConvertListToString(groupMapping["projects"].([]interface{})) + + // verify the mapping doesn't already exist + matchingNodes, diags := querySAMLGroupMappings(ctx, m, samlIdpID, groupMappings) + if len(diags) != 0 { + return diags + } - if matchingNode != nil { - return diag.Errorf("saml group mapping for group: %s and role: %s to project(s): %s already exists for saml idp provider: %s and should be imported instead", - providerGroupID, role, strings.Join(projectIDs, ", "), samlIdpID) - } + for _, matchingNode := range matchingNodes { + if matchingNode.ProviderGroupID == providerGroupID && matchingNode.Role.ID == role && slices.Equal(projectIDs, extractProjectIDs(matchingNode.Projects)) { + return diag.Errorf("saml group mapping for group: %s and role: %s to project(s): %s already exists for saml idp provider: %s and should be imported instead", + providerGroupID, role, strings.Join(projectIDs, ", "), samlIdpID) + } + } - // define the graphql query - query := `mutation SetSAMLGroupMapping ($input: ModifySAMLGroupMappingInput!) { - modifySAMLIdentityProviderGroupMappings(input: $input) { - _stub - } - }` - // populate the graphql variables - vars := &UpdateSAMLGroupMappingInput{} - vars.ID = samlIdpID - vars.Patch = wiz.ModifySAMLGroupMappingPatch{ - Upsert: &wiz.SAMLGroupDetailsInput{ - ProviderGroupID: providerGroupID, - Role: role, - Projects: projectIDs, - }, - } + // define the graphql query + query := `mutation SetSAMLGroupMapping ($input: ModifySAMLGroupMappingInput!) { + modifySAMLIdentityProviderGroupMappings(input: $input) { + _stub + } + }` + + // populate the graphql variables + vars := &UpdateSAMLGroupMappingInput{} + vars.ID = samlIdpID + vars.Patch = wiz.ModifySAMLGroupMappingPatch{ + Upsert: &wiz.SAMLGroupDetailsInput{ + ProviderGroupID: providerGroupID, + Role: role, + Projects: projectIDs, + }, + } - // process the request - data := &UpdateSAMLGroupMappingPayload{} - requestDiags := client.ProcessRequest(ctx, m, vars, data, query, "saml_group_mapping", "create") - diags = append(diags, requestDiags...) - if len(diags) > 0 { - return diags + // process the request + data := &UpdateSAMLGroupMappingPayload{} + requestDiags := client.ProcessRequest(ctx, m, vars, data, query, "saml_group_mapping", "create") + diags = append(diags, requestDiags...) + if len(diags) > 0 { + return diags + } } // set the id @@ -172,23 +181,38 @@ func resourceSAMLGroupMappingCreate(ctx context.Context, d *schema.ResourceData, func extractIDsFromSamlIdpGroupMappingImportID(id string) (SAMLGroupMappingsImport, error) { parts := strings.Split(id, "|") - if len(parts) != 5 { + + if len(parts) != 3 { return SAMLGroupMappingsImport{}, errors.New("invalid ID format") } - // if user species the mapping to be global we return an empty slice - var projectIDs []string - if parts[3] != "global" { - for _, projectID := range strings.Split(parts[3], ",") { - projectIDs = append(projectIDs, strings.TrimSpace(projectID)) + groupMappingStrings := strings.Split(parts[2], "#") + var groupMappings []wiz.SAMLGroupDetailsInput + for _, groupMappingString := range groupMappingStrings { + groupMappingParts := strings.Split(groupMappingString, ":") + if len(groupMappingParts) < 2 { + return SAMLGroupMappingsImport{}, errors.New("invalid group mapping format") + } + + providerGroupID := groupMappingParts[0] + role := groupMappingParts[1] + var projectIDs []string + if len(groupMappingParts) > 2 && groupMappingParts[2] != "" { + projectIDs = strings.Split(groupMappingParts[2], ",") + } + + groupMapping := wiz.SAMLGroupDetailsInput{ + ProviderGroupID: providerGroupID, + Role: role, + Projects: projectIDs, } + groupMappings = append(groupMappings, groupMapping) + } return SAMLGroupMappingsImport{ - SamlIdpID: parts[1], - ProviderGroupID: parts[2], - ProjectIDs: projectIDs, - Role: parts[4], + SamlIdpID: parts[1], + GroupMappings: groupMappings, }, nil } @@ -208,44 +232,48 @@ func resourceSAMLGroupMappingRead(ctx context.Context, d *schema.ResourceData, m if d.Id() == "" { return nil } + samlIdpID := d.Get("saml_idp_id").(string) - providerGroupID := d.Get("provider_group_id").(string) - role := d.Get("role").(string) - projectIDs := utils.ConvertListToString(d.Get("projects").([]interface{})) + groupMappings := d.Get("group_mapping").(*schema.Set).List() - matchingNode, diags := querySAMLGroupMappings(ctx, m, samlIdpID, providerGroupID, role, projectIDs) + var newGroupMappings []interface{} + + matchingNodes, diags := querySAMLGroupMappings(ctx, m, samlIdpID, groupMappings) if len(diags) > 0 { return diags } - // If no matching node was found, return error - if matchingNode == nil { - return diag.Errorf("saml group mapping for group: %s not found for saml idp provider: %s", providerGroupID, samlIdpID) + for _, item := range groupMappings { + groupMapping := item.(map[string]interface{}) + providerGroupID := groupMapping["provider_group_id"].(string) + role := groupMapping["role"].(string) + projectIDs := utils.ConvertListToString(groupMapping["projects"].([]interface{})) + + for _, matchingNode := range matchingNodes { + if matchingNode.ProviderGroupID == providerGroupID && matchingNode.Role.ID == role && slices.Equal(projectIDs, extractProjectIDs(matchingNode.Projects)) { + // set the resource parameters + newGroupMapping := map[string]interface{}{ + "provider_group_id": matchingNode.ProviderGroupID, + "role": matchingNode.Role.ID, + "projects": extractProjectIDs(matchingNode.Projects), + } + newGroupMappings = append(newGroupMappings, newGroupMapping) + } + } } - // set the resource parameters err := d.Set("saml_idp_id", samlIdpID) if err != nil { return append(diags, diag.FromErr(err)...) } - err = d.Set("provider_group_id", matchingNode.ProviderGroupID) - if err != nil { - return append(diags, diag.FromErr(err)...) - } - - err = d.Set("role", matchingNode.Role.ID) - if err != nil { - return append(diags, diag.FromErr(err)...) - } - - projectIDs = extractProjectIDs(matchingNode.Projects) - err = d.Set("projects", projectIDs) + err = d.Set("group_mapping", newGroupMappings) if err != nil { return append(diags, diag.FromErr(err)...) } return diags + } func resourceSAMLGroupMappingUpdate(ctx context.Context, d *schema.ResourceData, m interface{}) (diags diag.Diagnostics) { @@ -256,6 +284,9 @@ func resourceSAMLGroupMappingUpdate(ctx context.Context, d *schema.ResourceData, return nil } + samlIdpID := d.Get("saml_idp_id").(string) + groupMappings := d.Get("group_mapping").(*schema.Set).List() + // define the graphql query query := `mutation SetSAMLGroupMapping ($input: ModifySAMLGroupMappingInput!) { modifySAMLIdentityProviderGroupMappings(input: $input) { @@ -263,28 +294,30 @@ func resourceSAMLGroupMappingUpdate(ctx context.Context, d *schema.ResourceData, } }` - samlIdpID := d.Get("saml_idp_id").(string) - providerGroupID := d.Get("provider_group_id").(string) - role := d.Get("role").(string) - projects := utils.ConvertListToString(d.Get("projects").([]interface{})) - - // populate the graphql variables - vars := &UpdateSAMLGroupMappingInput{} - vars.ID = samlIdpID - vars.Patch = wiz.ModifySAMLGroupMappingPatch{ - Upsert: &wiz.SAMLGroupDetailsInput{ - ProviderGroupID: providerGroupID, - Role: role, - Projects: projects, - }, - } + for _, item := range groupMappings { + groupMapping := item.(map[string]interface{}) + providerGroupID := groupMapping["provider_group_id"].(string) + role := groupMapping["role"].(string) + projects := utils.ConvertListToString(groupMapping["projects"].([]interface{})) + + // populate the graphql variables + vars := &UpdateSAMLGroupMappingInput{} + vars.ID = samlIdpID + vars.Patch = wiz.ModifySAMLGroupMappingPatch{ + Upsert: &wiz.SAMLGroupDetailsInput{ + ProviderGroupID: providerGroupID, + Role: role, + Projects: projects, + }, + } - // process the request - data := &UpdateSAMLGroupMappingPayload{} - requestDiags := client.ProcessRequest(ctx, m, vars, data, query, "saml_group_mapping", "update") - diags = append(diags, requestDiags...) - if len(diags) > 0 { - return diags + // process the request + data := &UpdateSAMLGroupMappingPayload{} + requestDiags := client.ProcessRequest(ctx, m, vars, data, query, "saml_group_mapping", "update") + diags = append(diags, requestDiags...) + if len(diags) > 0 { + return diags + } } return resourceSAMLGroupMappingRead(ctx, d, m) @@ -298,6 +331,9 @@ func resourceSAMLGroupMappingDelete(ctx context.Context, d *schema.ResourceData, return nil } + samlIdpID := d.Get("saml_idp_id").(string) + groupMappings := d.Get("group_mapping").(*schema.Set).List() + // define the graphql query query := `mutation SetSAMLGroupMapping ($input: ModifySAMLGroupMappingInput!) { modifySAMLIdentityProviderGroupMappings(input: $input) { @@ -305,26 +341,28 @@ func resourceSAMLGroupMappingDelete(ctx context.Context, d *schema.ResourceData, } }` - samlIdpID := d.Get("saml_idp_id").(string) - providerGroupID := d.Get("provider_group_id").(string) - - // populate the graphql variables - vars := &UpdateSAMLGroupMappingInput{} - vars.ID = samlIdpID - vars.Patch.Delete = &[]string{providerGroupID} - - // process the request - data := &UpdateSAMLGroupMappingPayload{} - requestDiags := client.ProcessRequest(ctx, m, vars, data, query, "saml_group_mapping", "delete") - diags = append(diags, requestDiags...) - if len(diags) > 0 { - return diags + for _, item := range groupMappings { + groupMapping := item.(map[string]interface{}) + providerGroupID := groupMapping["provider_group_id"].(string) + + // populate the graphql variables + vars := &UpdateSAMLGroupMappingInput{} + vars.ID = samlIdpID + vars.Patch.Delete = &[]string{providerGroupID} + + // process the request + data := &UpdateSAMLGroupMappingPayload{} + requestDiags := client.ProcessRequest(ctx, m, vars, data, query, "saml_group_mapping", "delete") + diags = append(diags, requestDiags...) + if len(diags) > 0 { + return diags + } } return diags } -func querySAMLGroupMappings(ctx context.Context, m interface{}, samlIdpID string, providerGroupID string, roleID string, projectIDs []string) (*wiz.SAMLGroupMapping, diag.Diagnostics) { +func querySAMLGroupMappings(ctx context.Context, m interface{}, samlIdpID string, groupMappings []interface{}) ([]*wiz.SAMLGroupMapping, diag.Diagnostics) { // define the graphql query query := `query samlIdentityProviderGroupMappings ($id: ID!, $first: Int! $after: String){ samlIdentityProviderGroupMappings ( @@ -359,23 +397,31 @@ func querySAMLGroupMappings(ctx context.Context, m interface{}, samlIdpID string return nil, diags } - var matchingNode *wiz.SAMLGroupMapping + var matchingNodes []*wiz.SAMLGroupMapping + // Process the data... for _, data := range allData { typedData, ok := data.(*ReadSAMLGroupMappings) if !ok { return nil, diag.Errorf("data is not of type *ReadSAMLGroupMappings") } + nodes := typedData.SAMLGroupMappings.Nodes for _, node := range nodes { - nodeProjectIDs := extractProjectIDs(node.Projects) - // If we find a match, store the node and break the loop - if node.ProviderGroupID == providerGroupID && node.Role.ID == roleID && slices.Equal(projectIDs, nodeProjectIDs) { - matchingNode = node - break + for _, item := range groupMappings { + groupMapping := item.(map[string]interface{}) + providerGroupID := groupMapping["provider_group_id"].(string) + roleID := groupMapping["role"].(string) + projectIDs := utils.ConvertListToString(groupMapping["projects"].([]interface{})) + nodeProjectIDs := extractProjectIDs(node.Projects) + + // If we find a match, store the node + if node.ProviderGroupID == providerGroupID && node.Role.ID == roleID && slices.Equal(projectIDs, nodeProjectIDs) { + matchingNodes = append(matchingNodes, node) + } } } } - return matchingNode, nil + return matchingNodes, nil } diff --git a/internal/provider/resource_saml_group_mapping_test.go b/internal/provider/resource_saml_group_mapping_test.go index e73e28a..f28f88c 100644 --- a/internal/provider/resource_saml_group_mapping_test.go +++ b/internal/provider/resource_saml_group_mapping_test.go @@ -2,6 +2,8 @@ package provider import ( "reflect" + "wiz.io/hashicorp/terraform-provider-wiz/internal/wiz" + "testing" ) @@ -14,14 +16,14 @@ func TestExtractIDsFromSamlIdpGroupMappingImportID(t *testing.T) { }{ { name: "Valid ID", - input: "link|samlIdpID|providerGroupID|projectID1,projectID2|role", - expectedMapping: SAMLGroupMappingsImport{SamlIdpID: "samlIdpID", ProviderGroupID: "providerGroupID", ProjectIDs: []string{"projectID1", "projectID2"}, Role: "role"}, + input: "link|samlIdpID|providerGroupID:role:projectID1,projectID2", + expectedMapping: SAMLGroupMappingsImport{SamlIdpID: "samlIdpID", GroupMappings: []wiz.SAMLGroupDetailsInput{{ProviderGroupID: "providerGroupID", Role: "role", Projects: []string{"projectID1", "projectID2"}}}}, expectErr: false, }, { name: "Valid ID global mapping", - input: "link|samlIdpID|providerGroupID|global|role", - expectedMapping: SAMLGroupMappingsImport{SamlIdpID: "samlIdpID", ProviderGroupID: "providerGroupID", ProjectIDs: nil, Role: "role"}, + input: "link|samlIdpID|providerGroupID:role", + expectedMapping: SAMLGroupMappingsImport{SamlIdpID: "samlIdpID", GroupMappings: []wiz.SAMLGroupDetailsInput{{ProviderGroupID: "providerGroupID", Role: "role", Projects: nil}}}, expectErr: false, }, { @@ -32,7 +34,7 @@ func TestExtractIDsFromSamlIdpGroupMappingImportID(t *testing.T) { }, { name: "Invalid ID length", - input: "link|samlIdpId|providerGroupId", + input: "link|samlIdpId", expectedMapping: SAMLGroupMappingsImport{}, expectErr: true, }, @@ -44,7 +46,6 @@ func TestExtractIDsFromSamlIdpGroupMappingImportID(t *testing.T) { if (err != nil) != tc.expectErr { t.Errorf("Expected error: %v, got: %v", tc.expectErr, err) } - if !reflect.DeepEqual(mapping, tc.expectedMapping) { t.Errorf("Expected mapping: %+v, got: %+v", tc.expectedMapping, mapping) }