diff --git a/aws_policy_equivalence.go b/aws_policy_equivalence.go index 7b7b370..756060d 100644 --- a/aws_policy_equivalence.go +++ b/aws_policy_equivalence.go @@ -87,19 +87,24 @@ type intermediatePolicyDocument struct { func (intermediate *intermediatePolicyDocument) document() (*policyDocument, error) { var statements []*policyStatement - switch s := intermediate.Statements.(type) { - case []interface{}: - if err := mapstructure.Decode(s, &statements); err != nil { - return nil, fmt.Errorf("parsing statement 1: %s", err) - } - case map[string]interface{}: - var singleStatement *policyStatement - if err := mapstructure.Decode(s, &singleStatement); err != nil { - return nil, fmt.Errorf("parsing statement 2: %s", err) + // Decode only non-nil statements to prevent irreversible result when setting values + // in Terraform state. + // Reference: https://github.com/hashicorp/terraform-provider-aws/issues/22944 + if intermediate.Statements != nil { + switch s := intermediate.Statements.(type) { + case []interface{}: + if err := mapstructure.Decode(s, &statements); err != nil { + return nil, fmt.Errorf("parsing statement 1: %s", err) + } + case map[string]interface{}: + var singleStatement *policyStatement + if err := mapstructure.Decode(s, &singleStatement); err != nil { + return nil, fmt.Errorf("parsing statement 2: %s", err) + } + statements = append(statements, singleStatement) + default: + return nil, errors.New("unknown statement parsing problem") } - statements = append(statements, singleStatement) - default: - return nil, errors.New("Unknown error parsing statement") } document := &policyDocument{ @@ -118,6 +123,10 @@ type policyDocument struct { } func (doc *policyDocument) equals(other *policyDocument) bool { + // Prevent panic + if doc == nil { + return other == nil + } // Check the basic fields of the document if doc.Version != other.Version { return false diff --git a/aws_policy_equivalence_test.go b/aws_policy_equivalence_test.go index 1ea6617..fac0e43 100644 --- a/aws_policy_equivalence_test.go +++ b/aws_policy_equivalence_test.go @@ -5,6 +5,7 @@ package awspolicy // file, You can obtain one at http://mozilla.org/MPL/2.0/. import ( + "encoding/json" "testing" ) @@ -228,7 +229,6 @@ func TestPolicyEquivalence(t *testing.T) { policy1: policyTest30, policy2: policyTest30, equivalent: true, - err: false, }, { name: "Incorrect Statement type", @@ -324,17 +324,19 @@ func TestPolicyEquivalence(t *testing.T) { } for _, tc := range cases { - equal, err := PoliciesAreEquivalent(tc.policy1, tc.policy2) - if !tc.err && err != nil { - t.Fatalf("Unexpected error: %s", err) - } - if tc.err && err == nil { - t.Fatal("Expected error, none produced") - } - - if equal != tc.equivalent { - t.Fatalf("Bad: %s\n Expected: %t\n Got: %t\n", tc.name, tc.equivalent, equal) - } + t.Run(tc.name, func(t *testing.T) { + equal, err := PoliciesAreEquivalent(tc.policy1, tc.policy2) + if !tc.err && err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if tc.err && err == nil { + t.Fatal("Expected error, none produced") + } + + if equal != tc.equivalent { + t.Fatalf("Bad: %s\n Expected: %t\n Got: %t\n", tc.name, tc.equivalent, equal) + } + }) } } @@ -1661,3 +1663,72 @@ func TestStringValueSlicesEqualIgnoreOrder(t *testing.T) { } } } + +func TestIntermediatePolicyDocument(t *testing.T) { + cases := []struct { + name string + inputPolicy string + expectedPolicyDocument *policyDocument + err bool + }{ + { + name: "invalid policy with null string statement", + inputPolicy: `{ + "Version": "2012-10-17", + "Statement": "null" + }`, + err: true, + }, + { + name: "policy with version and nil statement", + inputPolicy: `{ + "Version": "2012-10-17", + "Statement": null + }`, + expectedPolicyDocument: &policyDocument{ + Version: "2012-10-17", + Statements: nil, + }, + }, + { + name: "basic policy", + inputPolicy: policyTest1, + expectedPolicyDocument: &policyDocument{ + Version: "2012-10-17", + Statements: []*policyStatement{ + { + Sid: "", + Effect: "Allow", + Principals: map[string]interface{}{ + "Service": "spotfleet.amazonaws.com", + }, + Actions: "sts:AssumeRole", + }, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + policy1intermediate := &intermediatePolicyDocument{} + err := json.Unmarshal([]byte(tc.inputPolicy), policy1intermediate) + if err != nil { + t.Fatalf("Error unmarshaling policy: %s", err) + } + + actual, err := policy1intermediate.document() + + if !tc.err && err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if tc.err && err == nil { + t.Fatal("Expected error, none produced") + } + + if !actual.equals(tc.expectedPolicyDocument) { + t.Fatalf("Bad: %s\n Expected: %v\n Got: %v\n", tc.name, tc.expectedPolicyDocument, actual) + } + }) + } +}