diff --git a/internal/services/network/web_application_firewall_policy_resource.go b/internal/services/network/web_application_firewall_policy_resource.go index d5a15168ee58b..a5a3b3a8d2cff 100644 --- a/internal/services/network/web_application_firewall_policy_resource.go +++ b/internal/services/network/web_application_firewall_policy_resource.go @@ -28,9 +28,9 @@ import ( func resourceWebApplicationFirewallPolicy() *pluginsdk.Resource { resource := &pluginsdk.Resource{ - Create: resourceWebApplicationFirewallPolicyCreateUpdate, + Create: resourceWebApplicationFirewallPolicyCreate, Read: resourceWebApplicationFirewallPolicyRead, - Update: resourceWebApplicationFirewallPolicyCreateUpdate, + Update: resourceWebApplicationFirewallPolicyUpdate, Delete: resourceWebApplicationFirewallPolicyDelete, Importer: pluginsdk.ImporterValidatingResourceId(func(id string) error { _, err := webapplicationfirewallpolicies.ParseApplicationGatewayWebApplicationFirewallPolicyID(id) @@ -489,25 +489,23 @@ func resourceWebApplicationFirewallPolicy() *pluginsdk.Resource { return resource } -func resourceWebApplicationFirewallPolicyCreateUpdate(d *pluginsdk.ResourceData, meta interface{}) error { +func resourceWebApplicationFirewallPolicyCreate(d *pluginsdk.ResourceData, meta interface{}) error { client := meta.(*clients.Client).Network.WebApplicationFirewallPolicies subscriptionId := meta.(*clients.Client).Account.SubscriptionId - ctx, cancel := timeouts.ForCreateUpdate(meta.(*clients.Client).StopContext, d) + ctx, cancel := timeouts.ForCreate(meta.(*clients.Client).StopContext, d) defer cancel() id := webapplicationfirewallpolicies.NewApplicationGatewayWebApplicationFirewallPolicyID(subscriptionId, d.Get("resource_group_name").(string), d.Get("name").(string)) - if d.IsNewResource() { - resp, err := client.Get(ctx, id) - if err != nil { - if !response.WasNotFound(resp.HttpResponse) { - return fmt.Errorf("checking for present of existing %s: %+v", id, err) - } - } + resp, err := client.Get(ctx, id) + if err != nil { if !response.WasNotFound(resp.HttpResponse) { - return tf.ImportAsExistsError("azurerm_web_application_firewall_policy", id.ID()) + return fmt.Errorf("checking for present of existing %s: %+v", id, err) } } + if !response.WasNotFound(resp.HttpResponse) { + return tf.ImportAsExistsError("azurerm_web_application_firewall_policy", id.ID()) + } location := azure.NormalizeLocation(d.Get("location").(string)) customRules := d.Get("custom_rules").([]interface{}) @@ -539,6 +537,57 @@ func resourceWebApplicationFirewallPolicyCreateUpdate(d *pluginsdk.ResourceData, return resourceWebApplicationFirewallPolicyRead(d, meta) } +func resourceWebApplicationFirewallPolicyUpdate(d *pluginsdk.ResourceData, meta interface{}) error { + client := meta.(*clients.Client).Network.WebApplicationFirewallPolicies + subscriptionId := meta.(*clients.Client).Account.SubscriptionId + ctx, cancel := timeouts.ForUpdate(meta.(*clients.Client).StopContext, d) + defer cancel() + + id := webapplicationfirewallpolicies.NewApplicationGatewayWebApplicationFirewallPolicyID(subscriptionId, d.Get("resource_group_name").(string), d.Get("name").(string)) + + resp, err := client.Get(ctx, id) + if err != nil { + return fmt.Errorf("retrieving %s: %+v", id, err) + } + + if resp.Model == nil { + return fmt.Errorf("retrieving %s: model was nil", id) + } + if resp.Model.Properties == nil { + return fmt.Errorf("retrieving %s: properties was nil", id) + } + + model := resp.Model + + if d.HasChange("custom_rules") { + model.Properties.CustomRules = expandWebApplicationFirewallPolicyWebApplicationFirewallCustomRule(d.Get("custom_rules").([]interface{})) + } + + if d.HasChange("policy_settings") { + model.Properties.PolicySettings = expandWebApplicationFirewallPolicyPolicySettings(d.Get("policy_settings").([]interface{})) + } + + if d.HasChange("managed_rules") { + expandedManagedRules, err := expandWebApplicationFirewallPolicyManagedRulesDefinition(d.Get("managed_rules").([]interface{}), d) + if err != nil { + return err + } + model.Properties.ManagedRules = pointer.From(expandedManagedRules) + } + + if d.HasChange("tags") { + model.Tags = tags.Expand(d.Get("tags").(map[string]interface{})) + } + + if _, err := client.CreateOrUpdate(ctx, id, *model); err != nil { + return fmt.Errorf("creating %s: %+v", id, err) + } + + d.SetId(id.ID()) + + return resourceWebApplicationFirewallPolicyRead(d, meta) +} + func resourceWebApplicationFirewallPolicyRead(d *pluginsdk.ResourceData, meta interface{}) error { client := meta.(*clients.Client).Network.WebApplicationFirewallPolicies ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)