Skip to content

Commit

Permalink
Merge pull request #39283 from fyqtian/fix/webacl
Browse files Browse the repository at this point in the history
fix(webaclv2): WAFv2 rule loss when making an update."
  • Loading branch information
ewbankkit authored Oct 25, 2024
2 parents 514e460 + 0275a24 commit 7d695f8
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .changelog/39283.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
resource/aws_wafv2_web_acl: Fix `decoding JSON: unexpected end of JSON input` errors when updating from using `rule_json` to using `rule`
```
20 changes: 13 additions & 7 deletions internal/service/wafv2/flex.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
package wafv2

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
Expand All @@ -14,6 +13,8 @@ import (
awstypes "github.com/aws/aws-sdk-go-v2/service/wafv2/types"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-provider-aws/internal/flex"
tfjson "github.com/hashicorp/terraform-provider-aws/internal/json"
itypes "github.com/hashicorp/terraform-provider-aws/internal/types"
"github.com/hashicorp/terraform-provider-aws/names"
)

Expand Down Expand Up @@ -984,23 +985,28 @@ func expandHeaderMatchPattern(l []interface{}) *awstypes.HeaderMatchPattern {
}

func expandWebACLRulesJSON(rawRules string) ([]awstypes.Rule, error) {
// Backwards compatibility.
if rawRules == "" {
return nil, errors.New("decoding JSON: unexpected end of JSON input")
}

var temp []any
err := json.Unmarshal([]byte(rawRules), &temp)
err := tfjson.DecodeFromBytes([]byte(rawRules), &temp)
if err != nil {
return nil, fmt.Errorf("decoding JSON: %s", err)
return nil, fmt.Errorf("decoding JSON: %w", err)
}

for _, v := range temp {
walkWebACLJSON(reflect.ValueOf(v))
}

out, err := json.Marshal(temp)
out, err := tfjson.EncodeToBytes(temp)
if err != nil {
return nil, err
}

var rules []awstypes.Rule
err = json.Unmarshal(out, &rules)
err = tfjson.DecodeFromBytes(out, &rules)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1041,7 +1047,7 @@ func walkWebACLJSON(v reflect.Value) {
case reflect.Slice, reflect.Array:
switch reflect.ValueOf(va.outputType).Type().Elem().Kind() {
case reflect.Uint8:
base64String := base64.StdEncoding.EncodeToString([]byte(str.(string)))
base64String := itypes.Base64Encode([]byte(str.(string)))
st[va.key] = base64String
default:
}
Expand Down
4 changes: 2 additions & 2 deletions internal/service/wafv2/web_acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ func resourceWebACLUpdate(ctx context.Context, d *schema.ResourceData, meta inte
rules = append(rules, findShieldRule(output.WebACL.Rules)...)
}

if d.HasChange("rule_json") {
r, err := expandWebACLRulesJSON(d.Get("rule_json").(string))
if v, ok := d.GetOk("rule_json"); ok {
r, err := expandWebACLRulesJSON(v.(string))
if err != nil {
return sdkdiag.AppendErrorf(diags, "expanding WAFv2 WebACL JSON rule (%s): %s", d.Id(), err)
}
Expand Down
112 changes: 108 additions & 4 deletions internal/service/wafv2/web_acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3060,7 +3060,7 @@ func TestAccWAFV2WebACL_ruleJSON(t *testing.T) {
CheckDestroy: testAccCheckWebACLDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccWebACLConfig_JSONrule(webACLName),
Config: testAccWebACLConfig_jsonRule(webACLName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckWebACLExists(ctx, resourceName, &v),
acctest.MatchResourceAttrRegionalARN(resourceName, names.AttrARN, "wafv2", regexache.MustCompile(`regional/webacl/.+$`)),
Expand All @@ -3075,7 +3075,7 @@ func TestAccWAFV2WebACL_ruleJSON(t *testing.T) {
ImportStateIdFunc: testAccWebACLImportStateIdFunc(resourceName),
},
{
Config: testAccWebACLConfig_JSONruleUpdate(webACLName),
Config: testAccWebACLConfig_jsonRuleUpdate(webACLName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckWebACLExists(ctx, resourceName, &v),
acctest.MatchResourceAttrRegionalARN(resourceName, names.AttrARN, "wafv2", regexache.MustCompile(`regional/webacl/.+$`)),
Expand All @@ -3086,6 +3086,62 @@ func TestAccWAFV2WebACL_ruleJSON(t *testing.T) {
})
}

func TestAccWAFV2WebACL_ruleJSONToRule(t *testing.T) {
ctx := acctest.Context(t)
var v awstypes.WebACL
webACLName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_wafv2_web_acl.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t); testAccPreCheckScopeRegional(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.WAFV2ServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckWebACLDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccWebACLConfig_jsonRule(webACLName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckWebACLExists(ctx, resourceName, &v),
acctest.MatchResourceAttrRegionalARN(resourceName, names.AttrARN, "wafv2", regexache.MustCompile(`regional/webacl/.+$`)),
resource.TestCheckResourceAttr(resourceName, acctest.CtRulePound, "0"),
resource.TestCheckResourceAttrSet(resourceName, "rule_json"),
),
},
{
Config: testAccWebACLConfig_jsonRuleEquivalent(webACLName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckWebACLExists(ctx, resourceName, &v),
acctest.MatchResourceAttrRegionalARN(resourceName, names.AttrARN, "wafv2", regexache.MustCompile(`regional/webacl/.+$`)),
resource.TestCheckResourceAttr(resourceName, acctest.CtRulePound, "1"),
resource.TestCheckTypeSetElemNestedAttrs(resourceName, "rule.*", map[string]string{
names.AttrName: "rule-1",
names.AttrPriority: "10",
"action.#": "1",
"action.0.allow.#": "0",
"action.0.block.#": "0",
"action.0.count.#": "1",
"visibility_config.#": "1",
"visibility_config.0.cloudwatch_metrics_enabled": acctest.CtFalse,
"visibility_config.0.metric_name": "friendly-rule-metric-name",
"visibility_config.0.sampled_requests_enabled": acctest.CtFalse,
"statement.#": "1",
"statement.0.rate_based_statement.#": "1",
"statement.0.rate_based_statement.0.aggregate_key_type": "IP",
"statement.0.rate_based_statement.0.evaluation_window_sec": "600",
"statement.0.rate_based_statement.0.limit": "10000",
"statement.0.rate_based_statement.0.scope_down_statement.#": "1",
"statement.0.rate_based_statement.0.scope_down_statement.0.geo_match_statement.#": "1",
"statement.0.rate_based_statement.0.scope_down_statement.0.geo_match_statement.0.country_codes.#": "2",
"statement.0.rate_based_statement.0.scope_down_statement.0.geo_match_statement.0.country_codes.0": "US",
"statement.0.rate_based_statement.0.scope_down_statement.0.geo_match_statement.0.country_codes.1": "NL",
}),
resource.TestCheckResourceAttr(resourceName, "rule_json", ""),
),
},
},
})
}

func testAccCheckWebACLDestroy(ctx context.Context) resource.TestCheckFunc {
return func(s *terraform.State) error {
conn := acctest.Provider.Meta().(*conns.AWSClient).WAFV2Client(ctx)
Expand Down Expand Up @@ -6123,7 +6179,7 @@ resource "aws_wafv2_web_acl" "test" {
`, rName)
}

func testAccWebACLConfig_JSONrule(rName string) string {
func testAccWebACLConfig_jsonRule(rName string) string {
return fmt.Sprintf(`
resource "aws_wafv2_web_acl" "test" {
name = %[1]q
Expand Down Expand Up @@ -6169,7 +6225,7 @@ resource "aws_wafv2_web_acl" "test" {
`, rName)
}

func testAccWebACLConfig_JSONruleUpdate(rName string) string {
func testAccWebACLConfig_jsonRuleUpdate(rName string) string {
return fmt.Sprintf(`
resource "aws_wafv2_web_acl" "test" {
name = %[1]q
Expand Down Expand Up @@ -6245,3 +6301,51 @@ resource "aws_wafv2_web_acl" "test" {
}
`, rName)
}

func testAccWebACLConfig_jsonRuleEquivalent(rName string) string {
return fmt.Sprintf(`
resource "aws_wafv2_web_acl" "test" {
name = %[1]q
description = %[1]q
scope = "REGIONAL"
default_action {
allow {}
}
visibility_config {
cloudwatch_metrics_enabled = false
metric_name = "friendly-metric-name"
sampled_requests_enabled = false
}
rule {
name = "rule-1"
priority = 10
action {
count {}
}
statement {
rate_based_statement {
limit = 10000
aggregate_key_type = "IP"
evaluation_window_sec = 600
scope_down_statement {
geo_match_statement {
country_codes = ["US", "NL"]
}
}
}
}
visibility_config {
cloudwatch_metrics_enabled = false
metric_name = "friendly-rule-metric-name"
sampled_requests_enabled = false
}
}
}
`, rName)
}

0 comments on commit 7d695f8

Please sign in to comment.