Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add override_action to aws_networkfirewall_firewall_policy #25135

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/25135.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_networkfirewall_firewall_policy: Add `firewall_policy.stateful_rule_group_reference.override` argument
```
14 changes: 0 additions & 14 deletions internal/service/networkfirewall/find.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@ func FindFirewall(ctx context.Context, conn *networkfirewall.NetworkFirewall, ar
return output, nil
}

// FindFirewallPolicy returns the FirewallPolicyOutput from a call to DescribeFirewallPolicyWithContext
// given the context and FindFirewallPolicy ARN.
// Returns nil if the FindFirewallPolicy is not found.
func FindFirewallPolicy(ctx context.Context, conn *networkfirewall.NetworkFirewall, arn string) (*networkfirewall.DescribeFirewallPolicyOutput, error) {
input := &networkfirewall.DescribeFirewallPolicyInput{
FirewallPolicyArn: aws.String(arn),
}
output, err := conn.DescribeFirewallPolicyWithContext(ctx, input)
if err != nil {
return nil, err
}
return output, nil
}

// FindFirewallPolicyByNameAndARN returns the FirewallPolicyOutput from a call to DescribeFirewallPolicyWithContext
// given the context and at least one of FirewallPolicyArn and FirewallPolicyName.
func FindFirewallPolicyByNameAndARN(ctx context.Context, conn *networkfirewall.NetworkFirewall, arn string, name string) (*networkfirewall.DescribeFirewallPolicyOutput, error) {
Expand Down
190 changes: 138 additions & 52 deletions internal/service/networkfirewall/firewall_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package networkfirewall

import (
"context"
"fmt"
"log"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -70,6 +69,20 @@ func ResourceFirewallPolicy() *schema.Resource {
Optional: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"override": {
Type: schema.TypeList,
MaxItems: 1,
Optional: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"action": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringInSlice(networkfirewall.OverrideAction_Values(), false),
},
},
},
},
"priority": {
Type: schema.TypeInt,
Optional: true,
Expand Down Expand Up @@ -143,6 +156,7 @@ func resourceFirewallPolicyCreate(ctx context.Context, d *schema.ResourceData, m
conn := meta.(*conns.AWSClient).NetworkFirewallConn
defaultTagsConfig := meta.(*conns.AWSClient).DefaultTagsConfig
tags := defaultTagsConfig.MergeTags(tftags.New(d.Get("tags").(map[string]interface{})))

name := d.Get("name").(string)
input := &networkfirewall.CreateFirewallPolicyInput{
FirewallPolicy: expandFirewallPolicy(d.Get("firewall_policy").([]interface{})),
Expand All @@ -157,14 +171,11 @@ func resourceFirewallPolicyCreate(ctx context.Context, d *schema.ResourceData, m
input.Tags = Tags(tags.IgnoreAWS())
}

log.Printf("[DEBUG] Creating NetworkFirewall Firewall Policy %s", name)

log.Printf("[DEBUG] Creating NetworkFirewall Firewall Policy: %s", input)
output, err := conn.CreateFirewallPolicyWithContext(ctx, input)

if err != nil {
return diag.FromErr(fmt.Errorf("error creating NetworkFirewall Firewall Policy (%s): %w", name, err))
}
if output == nil || output.FirewallPolicyResponse == nil {
return diag.FromErr(fmt.Errorf("error creating NetworkFirewall Firewall Policy (%s): empty output", name))
return diag.Errorf("creating NetworkFirewall Firewall Policy (%s): %s", name, err)
}

d.SetId(aws.StringValue(output.FirewallPolicyResponse.FirewallPolicyArn))
Expand All @@ -177,23 +188,16 @@ func resourceFirewallPolicyRead(ctx context.Context, d *schema.ResourceData, met
defaultTagsConfig := meta.(*conns.AWSClient).DefaultTagsConfig
ignoreTagsConfig := meta.(*conns.AWSClient).IgnoreTagsConfig

log.Printf("[DEBUG] Reading NetworkFirewall Firewall Policy %s", d.Id())
output, err := FindFirewallPolicyByARN(ctx, conn, d.Id())

output, err := FindFirewallPolicy(ctx, conn, d.Id())
if !d.IsNewResource() && tfawserr.ErrCodeEquals(err, networkfirewall.ErrCodeResourceNotFoundException) {
if !d.IsNewResource() && tfresource.NotFound(err) {
log.Printf("[WARN] NetworkFirewall Firewall Policy (%s) not found, removing from state", d.Id())
d.SetId("")
return nil
}
if err != nil {
return diag.FromErr(fmt.Errorf("error reading NetworkFirewall Firewall Policy (%s): %w", d.Id(), err))
}

if output == nil {
return diag.FromErr(fmt.Errorf("error reading NetworkFirewall Firewall Policy (%s): empty output", d.Id()))
}
if output.FirewallPolicyResponse == nil {
return diag.FromErr(fmt.Errorf("error reading NetworkFirewall Firewall Policy (%s): empty output.FirewallPolicyResponse", d.Id()))
if err != nil {
return diag.Errorf("reading NetworkFirewall Firewall Policy (%s): %s", d.Id(), err)
}

resp := output.FirewallPolicyResponse
Expand All @@ -205,49 +209,51 @@ func resourceFirewallPolicyRead(ctx context.Context, d *schema.ResourceData, met
d.Set("update_token", output.UpdateToken)

if err := d.Set("firewall_policy", flattenFirewallPolicy(policy)); err != nil {
return diag.FromErr(fmt.Errorf("error setting firewall_policy: %w", err))
return diag.Errorf("setting firewall_policy: %s", err)
}

tags := KeyValueTags(resp.Tags).IgnoreAWS().IgnoreConfig(ignoreTagsConfig)

//lintignore:AWSR002
if err := d.Set("tags", tags.RemoveDefaultConfig(defaultTagsConfig).Map()); err != nil {
return diag.FromErr(fmt.Errorf("error setting tags: %w", err))
return diag.Errorf("setting tags: %s", err)
}

if err := d.Set("tags_all", tags.Map()); err != nil {
return diag.FromErr(fmt.Errorf("error setting tags_all: %w", err))
return diag.Errorf("setting tags_all: %s", err)
}

return nil
}

func resourceFirewallPolicyUpdate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
conn := meta.(*conns.AWSClient).NetworkFirewallConn
arn := d.Id()

log.Printf("[DEBUG] Updating NetworkFirewall Firewall Policy %s", arn)

if d.HasChanges("description", "firewall_policy") {
input := &networkfirewall.UpdateFirewallPolicyInput{
FirewallPolicy: expandFirewallPolicy(d.Get("firewall_policy").([]interface{})),
FirewallPolicyArn: aws.String(arn),
FirewallPolicyArn: aws.String(d.Id()),
UpdateToken: aws.String(d.Get("update_token").(string)),
}

// Only pass non-empty description values, else API request returns an InternalServiceError
if v, ok := d.GetOk("description"); ok {
input.Description = aws.String(v.(string))
}

log.Printf("[DEBUG] Updating NetworkFirewall Firewall Policy: %s", input)
_, err := conn.UpdateFirewallPolicyWithContext(ctx, input)

if err != nil {
return diag.FromErr(fmt.Errorf("error updating NetworkFirewall Firewall Policy (%s) firewall_policy: %w", arn, err))
return diag.Errorf("updating NetworkFirewall Firewall Policy (%s): %s", d.Id(), err)
}
}

if d.HasChange("tags_all") {
o, n := d.GetChange("tags_all")
if err := UpdateTags(conn, arn, o, n); err != nil {
return diag.FromErr(fmt.Errorf("error updating NetworkFirewall Firewall Policy (%s) tags: %w", arn, err))

if err := UpdateTagsWithContext(ctx, conn, d.Id(), o, n); err != nil {
return diag.Errorf("updating NetworkFirewall Firewall Policy (%s) tags: %s", d.Id(), err)
}
}

Expand All @@ -257,42 +263,84 @@ func resourceFirewallPolicyUpdate(ctx context.Context, d *schema.ResourceData, m
func resourceFirewallPolicyDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
conn := meta.(*conns.AWSClient).NetworkFirewallConn

log.Printf("[DEBUG] Deleting NetworkFirewall Firewall Policy %s", d.Id())
log.Printf("[DEBUG] Deleting NetworkFirewall Firewall Policy: %s", d.Id())
_, err := tfresource.RetryWhenAWSErrMessageContainsContext(ctx, firewallPolicyTimeout, func() (interface{}, error) {
return conn.DeleteFirewallPolicyWithContext(ctx, &networkfirewall.DeleteFirewallPolicyInput{
FirewallPolicyArn: aws.String(d.Id()),
})
}, networkfirewall.ErrCodeInvalidOperationException, "Unable to delete the object because it is still in use")

input := &networkfirewall.DeleteFirewallPolicyInput{
FirewallPolicyArn: aws.String(d.Id()),
if tfawserr.ErrCodeEquals(err, networkfirewall.ErrCodeResourceNotFoundException) {
return nil
}

err := resource.RetryContext(ctx, firewallPolicyTimeout, func() *resource.RetryError {
_, err := conn.DeleteFirewallPolicyWithContext(ctx, input)
if err != nil {
if tfawserr.ErrMessageContains(err, networkfirewall.ErrCodeInvalidOperationException, "Unable to delete the object because it is still in use") {
return resource.RetryableError(err)
}
return resource.NonRetryableError(err)
}
return nil
})
if err != nil {
return diag.Errorf("deleting NetworkFirewall Firewall Policy (%s): %s", d.Id(), err)
}

if tfresource.TimedOut(err) {
_, err = conn.DeleteFirewallPolicyWithContext(ctx, input)
if _, err := waitFirewallPolicyDeleted(ctx, conn, d.Id()); err != nil {
return diag.Errorf("waiting for NetworkFirewall Firewall Policy (%s) delete: %s", d.Id(), err)
}

if err != nil {
if tfawserr.ErrCodeEquals(err, networkfirewall.ErrCodeResourceNotFoundException) {
return nil
return nil
}

func FindFirewallPolicyByARN(ctx context.Context, conn *networkfirewall.NetworkFirewall, arn string) (*networkfirewall.DescribeFirewallPolicyOutput, error) {
input := &networkfirewall.DescribeFirewallPolicyInput{
FirewallPolicyArn: aws.String(arn),
}

output, err := conn.DescribeFirewallPolicyWithContext(ctx, input)

if tfawserr.ErrCodeEquals(err, networkfirewall.ErrCodeResourceNotFoundException) {
return nil, &resource.NotFoundError{
LastError: err,
LastRequest: input,
}
return diag.FromErr(fmt.Errorf("error deleting NetworkFirewall Firewall Policy (%s): %w", d.Id(), err))
}

if _, err := waitFirewallPolicyDeleted(ctx, conn, d.Id()); err != nil {
if tfawserr.ErrCodeEquals(err, networkfirewall.ErrCodeResourceNotFoundException) {
return nil
if err != nil {
return nil, err
}

if output == nil || output.FirewallPolicyResponse == nil {
return nil, tfresource.NewEmptyResultError(input)
}

return output, nil
}

func statusFirewallPolicy(ctx context.Context, conn *networkfirewall.NetworkFirewall, arn string) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
output, err := FindFirewallPolicyByARN(ctx, conn, arn)

if tfresource.NotFound(err) {
return nil, "", nil
}
return diag.FromErr(fmt.Errorf("error waiting for NetworkFirewall Firewall Policy (%s) to delete: %w", d.Id(), err))

if err != nil {
return nil, "", err
}

return output, aws.StringValue(output.FirewallPolicyResponse.FirewallPolicyStatus), nil
}
}

return nil
func waitFirewallPolicyDeleted(ctx context.Context, conn *networkfirewall.NetworkFirewall, arn string) (*networkfirewall.DescribeFirewallPolicyOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{networkfirewall.ResourceStatusDeleting},
Target: []string{},
Refresh: statusFirewallPolicy(ctx, conn, arn),
Timeout: firewallPolicyTimeout,
}

outputRaw, err := stateConf.WaitForStateContext(ctx)

if v, ok := outputRaw.(*networkfirewall.DescribeFirewallPolicyOutput); ok {
return v, err
}

return nil, err
}

func expandStatefulEngineOptions(l []interface{}) *networkfirewall.StatefulEngineOptions {
Expand All @@ -310,6 +358,21 @@ func expandStatefulEngineOptions(l []interface{}) *networkfirewall.StatefulEngin
return options
}

func expandStatefulRuleGroupOverride(l []interface{}) *networkfirewall.StatefulRuleGroupOverride {
if len(l) == 0 || l[0] == nil {
return nil
}

lRaw := l[0].(map[string]interface{})
override := &networkfirewall.StatefulRuleGroupOverride{}

if v, ok := lRaw["action"].(string); ok && v != "" {
override.SetAction(v)
}

return override
}

func expandStatefulRuleGroupReferences(l []interface{}) []*networkfirewall.StatefulRuleGroupReference {
if len(l) == 0 || l[0] == nil {
return nil
Expand All @@ -320,15 +383,22 @@ func expandStatefulRuleGroupReferences(l []interface{}) []*networkfirewall.State
if !ok {
continue
}

reference := &networkfirewall.StatefulRuleGroupReference{}
if v, ok := tfMap["priority"].(int); ok && v > 0 {
reference.Priority = aws.Int64(int64(v))
}
if v, ok := tfMap["resource_arn"].(string); ok && v != "" {
reference.ResourceArn = aws.String(v)
}

if v, ok := tfMap["override"].([]interface{}); ok && len(v) > 0 {
reference.Override = expandStatefulRuleGroupOverride(v)
}

references = append(references, reference)
}

return references
}

Expand Down Expand Up @@ -429,6 +499,18 @@ func flattenStatefulEngineOptions(options *networkfirewall.StatefulEngineOptions
return []interface{}{m}
}

func flattenStatefulRuleGroupOverride(override *networkfirewall.StatefulRuleGroupOverride) []interface{} {
if override == nil {
return []interface{}{}
}

m := map[string]interface{}{
"action": aws.StringValue(override.Action),
}

return []interface{}{m}
}

func flattenPolicyStatefulRuleGroupReference(l []*networkfirewall.StatefulRuleGroupReference) []interface{} {
references := make([]interface{}, 0, len(l))
for _, ref := range l {
Expand All @@ -438,6 +520,10 @@ func flattenPolicyStatefulRuleGroupReference(l []*networkfirewall.StatefulRuleGr
if ref.Priority != nil {
reference["priority"] = int(aws.Int64Value(ref.Priority))
}
if ref.Override != nil {
reference["override"] = flattenStatefulRuleGroupOverride(ref.Override)
}

references = append(references, reference)
}

Expand Down
Loading