Skip to content

Commit

Permalink
Code review updates
Browse files Browse the repository at this point in the history
  • Loading branch information
abikouo committed Oct 9, 2024
1 parent 469202d commit 130f908
Showing 1 changed file with 80 additions and 109 deletions.
189 changes: 80 additions & 109 deletions plugins/modules/ec2_vpc_nacl.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@

RETURN = r"""
nacl_id:
description: The id of the NACL (when creating or updating an ACL)
description: The id of the NACL (when creating or updating an ACL).
returned: success
type: str
sample: acl-123456789abcdef01
sample: "acl-123456789abcdef01"
"""

from typing import Any
Expand Down Expand Up @@ -184,7 +184,7 @@ def subnets_changed(client, module: AnsibleAWSModule, nacl_id: str, subnets_ids:
vpc_id = module.params.get("vpc_id")

if not subnets_ids:
default_nacl_id = find_default_vpc_nacl(client, module, vpc_id)
default_nacl_id = find_default_vpc_nacl(client, vpc_id)
# Find subnets by Network ACL ids
network_acls = describe_network_acls(
client, Filters=[{"Name": "association.network-acl-id", "Values": [nacl_id]}]
Expand All @@ -211,7 +211,7 @@ def subnets_changed(client, module: AnsibleAWSModule, nacl_id: str, subnets_ids:
if subnets_added:
changed |= associate_nacl_to_subnets(client, module, nacl_id, subnets_added)
if subnets_removed:
default_nacl_id = find_default_vpc_nacl(client, module, vpc_id)
default_nacl_id = find_default_vpc_nacl(client, vpc_id)
changed |= associate_nacl_to_subnets(client, module, default_nacl_id, subnets_removed)

return changed
Expand All @@ -221,13 +221,13 @@ def nacls_changed(client, module: AnsibleAWSModule, nacl_info: Dict[str, Any]) -
changed = False
entries = nacl_info["Entries"]
nacl_id = nacl_info["NetworkAclId"]
current_egress_rules = [rule for rule in entries if rule["Egress"] is True and rule["RuleNumber"] < 32767]
current_ingress_rules = [rule for rule in entries if rule["Egress"] is False and rule["RuleNumber"] < 32767]
aws_egress_rules = [rule for rule in entries if rule["Egress"] is True and rule["RuleNumber"] < 32767]
aws_ingress_rules = [rule for rule in entries if rule["Egress"] is False and rule["RuleNumber"] < 32767]

# Egress Rules
changed |= rules_changed(client, module, nacl_id, current_egress_rules, True)
changed |= rules_changed(client, nacl_id, module.params.get("egress"), aws_egress_rules, True, module.check_mode)
# Ingress Rules
changed |= rules_changed(client, module, nacl_id, current_ingress_rules, False)
changed |= rules_changed(client, nacl_id, module.params.get("ingress"), aws_ingress_rules, False, module.check_mode)
return changed


Expand Down Expand Up @@ -272,54 +272,36 @@ def ansible_to_boto3_dict_rule(ansible_rule: List[Any], egress: bool) -> Dict[st
return boto3_rule


def diff_network_acl_rules(
new_rules: List[Dict[str, Any]], current_rules: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
rules_to_add = []
rules_to_remove = []

# Find all rule to add
for rule in new_rules:
_match = False
for existing in current_rules:
if existing["RuleNumber"] == rule["RuleNumber"] and existing == rule:
_match = True
break
if not _match:
rules_to_add.append(rule)

# Find all rule to remove
for existing in current_rules:
_match = False
for rule in new_rules:
if existing["RuleNumber"] == rule["RuleNumber"] and existing == rule:
_match = True
break
if not _match:
rules_to_remove.append(existing)

return rules_to_add, rules_to_remove
def find_added_rules(rules_a: List[Dict[str, Any]], rules_b: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
results = []
# A rule is considered as a new rule if either the RuleNumber does exist in the list of
# current Rules stored in AWS or if the Rule differs with the Rule stored in AWS with the same RuleNumber
for a in rules_a:
if not any(a["RuleNumber"] == b["RuleNumber"] and a == b for b in rules_b):
results.append(a)
return results


def rules_changed(
client, module: AnsibleAWSModule, nacl_id: str, current_rules: List[Dict[str, Any]], egress: bool
client,
nacl_id: str,
ansible_rules: List[List[str]],
aws_rules: List[Dict[str, Any]],
egress: bool,
check_mode: bool,
) -> bool:
changed = False
new_rules = []
# transform rules: from ansible list to boto3 dict
ansible_rules = [ansible_to_boto3_dict_rule(r, egress) for r in ansible_rules]

if egress:
new_rules += [ansible_to_boto3_dict_rule(r, egress=True) for r in module.params.get("egress")]
else:
new_rules += [ansible_to_boto3_dict_rule(r, egress=False) for r in module.params.get("ingress")]
# find added rules
added_rules = find_added_rules(ansible_rules, aws_rules)
# find removed rules
removed_rules = find_added_rules(aws_rules, ansible_rules)

added_rules, removed_rules = diff_network_acl_rules(new_rules, current_rules)
if not added_rules and not removed_rules:
return changed

# Added Rules
changed = False
for rule in added_rules:
changed = True
if not module.check_mode:
if not check_mode:
rule_number = rule.pop("RuleNumber")
protocol = rule.pop("Protocol")
rule_action = rule.pop("RuleAction")
Expand All @@ -337,7 +319,7 @@ def rules_changed(
# Removed Rules
for rule in removed_rules:
changed = True
if not module.check_mode:
if not check_mode:
delete_network_acl_entry(client, network_acl_id=nacl_id, rule_number=rule["RuleNumber"], egress=egress)

return changed
Expand All @@ -362,34 +344,22 @@ def process_rule_entry(entry: List[Any]) -> Dict[str, Any]:
return params


def construct_acl_entries(client, module: AnsibleAWSModule, nacl_id: str) -> bool:
def add_network_acl_entries(
client, nacl_id: str, ansible_entries: List[List[str]], egress: bool, check_mode: bool
) -> bool:
changed = False
# Process list entries
for entry in module.params.get("ingress"):
changed = True
if not module.check_mode:
create_network_acl_entry(
client,
network_acl_id=nacl_id,
protocol=str(PROTOCOL_NUMBERS[entry[1]]),
egress=False,
rule_action=entry[2],
rule_number=entry[0],
**process_rule_entry(entry),
)
for entry in module.params.get("egress"):
for entry in ansible_entries:
changed = True
if not module.check_mode:
if not check_mode:
create_network_acl_entry(
client,
network_acl_id=nacl_id,
protocol=str(PROTOCOL_NUMBERS[entry[1]]),
egress=True,
egress=egress,
rule_action=entry[2],
rule_number=entry[0],
**process_rule_entry(entry),
)

return changed


Expand Down Expand Up @@ -436,10 +406,15 @@ def ensure_present(client, module: AnsibleAWSModule) -> None:

# Associate Subnets to Network ACL
nacl_id = nacl["NetworkAclId"]
associate_nacl_to_subnets(client, module, nacl_id, subnets_ids)
changed |= associate_nacl_to_subnets(client, module, nacl_id, subnets_ids)

# Create Network ACL entries
construct_acl_entries(client, module, nacl_id)
# Create Network ACL entries (ingress and egress)
changed |= add_network_acl_entries(
client, nacl_id, module.params.get("ingress"), egress=False, check_mode=module.check_mode
)
changed |= add_network_acl_entries(
client, nacl_id, module.params.get("egress"), egress=True, check_mode=module.check_mode
)
else:
nacl_id = nacl["NetworkAclId"]
changed |= subnets_changed(client, module, nacl_id, subnets_ids)
Expand All @@ -462,7 +437,7 @@ def ensure_absent(client, module: AnsibleAWSModule) -> None:
assoc_ids = [a["NetworkAclAssociationId"] for a in associations]

# Find default NACL associated to the VPC
default_nacl_id = find_default_vpc_nacl(client, module, vpc_id)
default_nacl_id = find_default_vpc_nacl(client, vpc_id)
if not default_nacl_id:
module.exit_json(changed=changed, msg="Default NACL ID not found - Check the VPC ID")

Expand All @@ -484,49 +459,42 @@ def ensure_absent(client, module: AnsibleAWSModule) -> None:
def describe_network_acl(client, module: AnsibleAWSModule) -> Optional[Dict[str, Any]]:
nacl_id = module.params.get("nacl_id")
name = module.params.get("name")
try:
if nacl_id:
filters = [{"Name": "network-acl-id", "Values": [nacl_id]}]
else:
filters = [{"Name": "tag:Name", "Values": [name]}]
network_acls = describe_network_acls(client, Filters=filters)
return None if not network_acls else network_acls[0]
except AnsibleEC2Error as e:
module.fail_json_aws(e)

if nacl_id:
filters = [{"Name": "network-acl-id", "Values": [nacl_id]}]
else:
filters = [{"Name": "tag:Name", "Values": [name]}]
network_acls = describe_network_acls(client, Filters=filters)
return None if not network_acls else network_acls[0]

def find_default_vpc_nacl(client, module: AnsibleAWSModule, vpc_id: str) -> Optional[str]:

def find_default_vpc_nacl(client, vpc_id: str) -> Optional[str]:
default_nacl_id = None
try:
for nacl in describe_network_acls(client, Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]):
if nacl.get("IsDefault", False):
default_nacl_id = nacl["NetworkAclId"]
break
except AnsibleEC2Error as e:
module.fail_json_aws(e)
for nacl in describe_network_acls(client, Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]):
if nacl.get("IsDefault", False):
default_nacl_id = nacl["NetworkAclId"]
break
return default_nacl_id


def find_subnets_ids(client, module: AnsibleAWSModule, subnets_ids_or_names: List[str]) -> List[str]:
subnets_ids = []
subnets_names = []
try:
# Find Subnets by ID
subnets = describe_subnets(client, Filters=[{"Name": "subnet-id", "Values": subnets_ids_or_names}])
subnets_ids += [subnet["SubnetId"] for subnet in subnets]
subnets_names += [tag["Value"] for subnet in subnets for tag in subnet.get("Tags", []) if tag["Key"] == "Name"]

# Find Subnets by Name
subnets = describe_subnets(client, Filters=[{"Name": "tag:Name", "Values": subnets_ids_or_names}])
subnets_ids += [subnet["SubnetId"] for subnet in subnets]
subnets_names += [tag["Value"] for subnet in subnets for tag in subnet.get("Tags", []) if tag["Key"] == "Name"]

unexisting_subnets = [s for s in subnets_ids_or_names if s not in subnets_names + subnets_ids]
if unexisting_subnets:
module.fail_json(msg=f"The following subnets do not exist: {unexisting_subnets}")
return subnets_ids
except AnsibleEC2Error as e:
module.fail_json_aws(e)

# Find Subnets by ID
subnets = describe_subnets(client, Filters=[{"Name": "subnet-id", "Values": subnets_ids_or_names}])
subnets_ids += [subnet["SubnetId"] for subnet in subnets]
subnets_names += [tag["Value"] for subnet in subnets for tag in subnet.get("Tags", []) if tag["Key"] == "Name"]

# Find Subnets by Name
subnets = describe_subnets(client, Filters=[{"Name": "tag:Name", "Values": subnets_ids_or_names}])
subnets_ids += [subnet["SubnetId"] for subnet in subnets]
subnets_names += [tag["Value"] for subnet in subnets for tag in subnet.get("Tags", []) if tag["Key"] == "Name"]

unexisting_subnets = [s for s in subnets_ids_or_names if s not in subnets_names + subnets_ids]
if unexisting_subnets:
module.fail_json(msg=f"The following subnets do not exist: {unexisting_subnets}")
return subnets_ids


def main():
Expand Down Expand Up @@ -556,10 +524,13 @@ def main():

client = module.client("ec2")

if module.params.get("state") == "present":
ensure_present(client, module)
else:
ensure_absent(client, module)
try:
if module.params.get("state") == "present":
ensure_present(client, module)
else:
ensure_absent(client, module)
except AnsibleEC2Error as e:
module.fail_json_aws_error(e)


if __name__ == "__main__":
Expand Down

0 comments on commit 130f908

Please sign in to comment.