From bcfaa7816195e15fa99c56e667df5cfefba1d835 Mon Sep 17 00:00:00 2001 From: William Guilherme Date: Wed, 27 Sep 2023 16:25:43 -0700 Subject: [PATCH] feat: Added access policy condition operands validation --- plugins/module_utils/utils.py | 154 +++++++--------------- plugins/modules/zpa_policy_access_rule.py | 130 ++++++++++-------- 2 files changed, 119 insertions(+), 165 deletions(-) diff --git a/plugins/module_utils/utils.py b/plugins/module_utils/utils.py index 9a88ca3..fa6cee1 100644 --- a/plugins/module_utils/utils.py +++ b/plugins/module_utils/utils.py @@ -3,9 +3,7 @@ __metaclass__ = type import pycountry -from ansible_collections.zscaler.zpacloud.plugins.module_utils.zpa_client import ( - ZPAClientHelper, -) + def map_conditions(conditions_obj): result = [] @@ -68,7 +66,6 @@ def normalize_policy(policy): return normalized - def validate_operand(operand, module): def lhsWarn(object_type, expected, got, error=None): error_msg = f"Invalid LHS for '{object_type}'. Expected {expected}, but got '{got}'" @@ -82,72 +79,51 @@ def rhsWarn(object_type, expected, got, error=None): error_msg += f". Error details: {error}" return error_msg - object_type = operand.get("objectType", "").upper() + object_type = operand.get("object_type", "").upper() lhs = operand.get("lhs") rhs = operand.get("rhs") - # Check lhs and rhs for emptiness - if lhs is None or not lhs.strip(): - return lhsWarn(object_type, "a non-empty value", "empty or None") - if rhs is None or not rhs.strip(): - return rhsWarn(object_type, "a non-empty value", "empty or None") - - lhs = lhs.strip() - rhs = rhs.strip() - - client = ZPAClientHelper(module) - - object_validations = { - "APP": { - "lhs": ["id"], - "fetch_method": client.app_segments.get_segment, - "kwargs": {"id": rhs}, - "rhs_msg": "valid application segment ID", - }, - "APP_GROUP": { - "lhs": ["id"], - "fetch_method": client.segment_groups.get_group, - "kwargs": {"id": rhs}, - "rhs_msg": "valid segment group ID", - }, - "MACHINE_GRP": { - "lhs": ["id"], - "fetch_method": client.machine_groups.get_group, - "kwargs": {"id": rhs}, - "rhs_msg": "valid machine group ID", - }, - "EDGE_CONNECTOR_GROUP": { - "lhs": ["id"], - "fetch_method": client.cloud_connector_groups.get_group, - "kwargs": {"id": rhs}, - "rhs_msg": "valid cloud connector ID", - }, - "POSTURE": { - "rhs": ["true", "false"], - "fetch_method": client.posture_profiles.get_profile_by_posture_udid, - "kwargs": {"posture_udid": lhs}, - "lhs_msg": "valid posture profile ID", - }, - "TRUSTED_NETWORK": { - "rhs": ["true", "false"], - "fetch_method": client.trusted_networks.get_by_network_id, - "kwargs": {"network_id": lhs}, - "lhs_msg": "valid trusted network ID", - }, - "PLATFORM": { - "rhs": ["true"], - "lhs": ['linux', 'android', 'windows', 'ios', 'mac'], - "lhs_msg": "one of ['linux', 'android', 'windows', 'ios', 'mac']", - }, - "COUNTRY_CODE": { - "rhs": ["true"], - "lhs": validate_iso3166_alpha2, # Using the function directly here - "lhs_msg": "valid ISO-3166 Alpha-2 country code. Please visit the following site for reference: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes", - }, - "CLIENT_TYPE": { - "lhs": ["id"], - "lhs_msg": "the string 'id'", - "rhs": [ + # Validate non-emptiness + if not object_type or not lhs or not rhs: + return "Object type, LHS, and RHS cannot be empty or None" + + # Ensure lhs and rhs are strings + if not isinstance(lhs, str): + lhs = str(lhs) + if not isinstance(rhs, str): + rhs = str(rhs) + + valid_object_types = ["APP", "APP_GROUP", "MACHINE_GRP", "EDGE_CONNECTOR_GROUP", "POSTURE", "TRUSTED_NETWORK", "PLATFORM", "COUNTRY_CODE", "CLIENT_TYPE"] + + if object_type not in valid_object_types: + return f"Invalid object type: {object_type}. Supported types are: {', '.join(valid_object_types)}" + + if object_type in ["APP", "APP_GROUP", "MACHINE_GRP", "EDGE_CONNECTOR_GROUP"]: + if lhs != 'id': + return lhsWarn(object_type, 'id', lhs) + if not rhs: + return rhsWarn(object_type, "non-empty string", rhs) + + elif object_type in ["POSTURE", "TRUSTED_NETWORK"]: + if rhs not in ['true', 'false']: + return rhsWarn(object_type, "one of ['true', 'false']", rhs) + + elif object_type == "PLATFORM": + if rhs != 'true': + return rhsWarn(object_type, 'true', rhs) + if lhs not in ['linux', 'android', 'windows', 'ios', 'mac']: + return lhsWarn(object_type, "one of ['linux', 'android', 'windows', 'ios', 'mac']", lhs) + + elif object_type == "COUNTRY_CODE": + if rhs != 'true': + return rhsWarn(object_type, 'true', rhs) + if not validate_iso3166_alpha2(lhs): + return lhsWarn(object_type, "a valid ISO-3166 Alpha-2 country code", lhs, "Please visit the following site for reference: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes") + + elif object_type == "CLIENT_TYPE": + if lhs != 'id': + return lhsWarn(object_type, 'id', lhs) + valid_client_types = [ 'zpn_client_type_exporter', 'zpn_client_type_exporter_noauth', 'zpn_client_type_browser_isolation', @@ -158,47 +134,9 @@ def rhsWarn(object_type, expected, got, error=None): 'zpn_client_type_slogger', 'zpn_client_type_zapp_partner', 'zpn_client_type_branch_connector' - ], - "rhs_msg": "one of ['zpn_client_type_exporter', zpn_client_type_exporter_noauth, zpn_client_type_browser_isolation, zpn_client_type_machine_tunnel, zpn_client_type_ip_anchoring, zpn_client_type_edge_connector, zpn_client_type_zapp, zpn_client_type_s]" - }, - } - - validation = object_validations.get(object_type) - - if validation: - # Validate LHS - if "lhs" in validation and lhs not in validation["lhs"]: - return lhsWarn(object_type, validation["lhs_msg"], lhs) - - # Validate RHS for APP, APP_GROUP, MACHINE_GRP, and EDGE_CONNECTOR_GROUP - if object_type in ["APP", "APP_GROUP", "MACHINE_GRP", "EDGE_CONNECTOR_GROUP"]: - try: - result = validation["fetch_method"](**validation["kwargs"]) - if not result or result.get('id') != rhs: - return rhsWarn(object_type, validation["rhs_msg"], rhs) - except Exception as e: - fetch_msg = f"Error retrieving {object_type} with ID '{rhs}': {str(e)}" - return fetch_msg - - # Validate RHS for other types (POSTURE, TRUSTED_NETWORK, PLATFORM, CLIENT_TYPE) - elif rhs not in validation["rhs"]: - return rhsWarn(object_type, validation["rhs_msg"], rhs) - - # Validate LHS for POSTURE and TRUSTED_NETWORK - if object_type in ["POSTURE", "TRUSTED_NETWORK"]: - try: - result = validation["fetch_method"](**validation["kwargs"]) - if not result: - return lhsWarn(object_type, validation["lhs_msg"], lhs) - except Exception as e: - fetch_msg = f"Error retrieving {object_type} with ID '{lhs}': {str(e)}" - return fetch_msg - - # Specific LHS Validation for PLATFORM and COUNTRY_CODE - if object_type == "PLATFORM" and lhs not in ['linux', 'android', 'windows', 'ios', 'mac']: - return lhsWarn(object_type, "one of ['linux', 'android', 'windows', 'ios', 'mac']", lhs) - if object_type == "COUNTRY_CODE" and not validate_iso3166_alpha2(lhs): - return lhsWarn(object_type, "a valid ISO-3166 Alpha-2 country code", lhs) + ] + if rhs not in valid_client_types: + return rhsWarn(object_type, f"one of {valid_client_types}", rhs) return None diff --git a/plugins/modules/zpa_policy_access_rule.py b/plugins/modules/zpa_policy_access_rule.py index 2b72c06..1c565e5 100644 --- a/plugins/modules/zpa_policy_access_rule.py +++ b/plugins/modules/zpa_policy_access_rule.py @@ -251,104 +251,121 @@ from ansible.module_utils._text import to_native from ansible.module_utils.basic import AnsibleModule from ansible_collections.zscaler.zpacloud.plugins.module_utils.utils import map_conditions -from ansible_collections.zscaler.zpacloud.plugins.module_utils.utils import validate_operand from ansible_collections.zscaler.zpacloud.plugins.module_utils.utils import normalize_policy +from ansible_collections.zscaler.zpacloud.plugins.module_utils.utils import validate_operand from ansible_collections.zscaler.zpacloud.plugins.module_utils.zpa_client import ( ZPAClientHelper, deleteNone, ) def core(module): - state = module.params.get("state", "present") + state = module.params.get("state", None) client = ZPAClientHelper(module) policy_rule_id = module.params.get("id", None) policy_rule_name = module.params.get("name", None) + policy = dict() params = [ "id", "name", "description", - "custom_msg," - "policy_type", "action", - "operator", "rule_order", + "policy_type", + "custom_msg", + "app_connector_group_ids", + "app_server_group_ids", + "operator", "conditions", ] + for param_name in params: + policy[param_name] = module.params.get(param_name, None) - policy = {param: module.params.get(param, None) for param in params} + conditions = module.params.get('conditions') or [] - # Validate conditions - for condition in module.params.get('conditions', []): - for operand in condition.get('operands', []): + # Validate each operand in the conditions + for condition in conditions: + operands = condition.get('operands', []) + for operand in operands: validation_result = validate_operand(operand, module) if validation_result: - module.fail_json(msg=validation_result) + module.fail_json(msg=validation_result) # Fail if validation returns a warning or error message existing_policy = None if policy_rule_id is not None: - existing_policy = client.policies.get_rule(policy_type="access", rule_id=policy_rule_id) + existing_policy = client.policies.get_rule( + policy_type="access", rule_id=policy_rule_id + ) elif policy_rule_name is not None: - rules = client.policies.list_rules(policy_type="access") + rules = client.policies.list_rules(policy_type="access").to_list() for rule in rules: if rule.get("name") == policy_rule_name: existing_policy = rule break - differences_detected = False - if existing_policy: + if existing_policy is not None: + # Normalize both policies' conditions policy['conditions'] = map_conditions(policy.get("conditions", [])) existing_policy['conditions'] = map_conditions(existing_policy.get("conditions", [])) + desired_policy = normalize_policy(policy) current_policy = normalize_policy(existing_policy) fields_to_exclude = ['id', 'policy_type'] + differences_detected = False for key, value in desired_policy.items(): if key not in fields_to_exclude and current_policy.get(key) != value: differences_detected = True module.warn(f"Difference detected in {key}. Current: {current_policy.get(key)}, Desired: {value}") + if existing_policy is not None: + id = existing_policy.get("id") + existing_policy.update(policy) + existing_policy["id"] = id if state == "present": - if existing_policy and differences_detected: + if existing_policy is not None and differences_detected: """Update""" updated_policy = { "policy_type": "access", - "rule_id": existing_policy.get("id"), - "name": existing_policy.get("name"), - "description": existing_policy.get("description"), - "action": existing_policy.get("action").upper(), + "rule_id": existing_policy.get("id", None), + "name": existing_policy.get("name", None), + "description": existing_policy.get("description", None), + "rule_order": existing_policy.get("rule_order", None), + "action": existing_policy.get("action", "").upper(), "conditions": map_conditions(existing_policy.get("conditions", [])), - "custom_msg": existing_policy.get("custom_msg"), - "rule_order": existing_policy.get("rule_order"), - "app_connector_group_ids": existing_policy.get("app_connector_group_ids"), - "app_server_group_ids": existing_policy.get("app_server_group_ids") + "custom_msg": existing_policy.get("custom_msg", None), + "app_connector_group_ids": existing_policy.get("app_connector_group_ids", None), + "app_server_group_ids": existing_policy.get("app_server_group_ids", None) } cleaned_policy = deleteNone(updated_policy) updated_policy = client.policies.update_access_rule(**cleaned_policy) module.exit_json(changed=True, data=updated_policy) - elif not existing_policy: + elif existing_policy is None: """Create""" new_policy = { - "name": policy.get("name"), - "description": policy.get("description"), - "action": policy.get("action"), - "rule_order": policy.get("rule_order"), + "name": policy.get("name", None), + "description": policy.get("description", None), + "action": policy.get("action", None), + "rule_order": policy.get("rule_order", None), "conditions": map_conditions(policy.get("conditions", [])), - "custom_msg": policy.get("custom_msg"), - "app_connector_group_ids": policy.get("app_connector_group_ids"), - "app_server_group_ids": policy.get("app_server_group_ids") + "custom_msg": policy.get("custom_msg", None), + "app_connector_group_ids": policy.get("app_connector_group_ids", None), + "app_server_group_ids": policy.get("app_server_group_ids", None) } cleaned_policy = deleteNone(new_policy) created_policy = client.policies.add_access_rule(**cleaned_policy) - module.exit_json(changed=True, data=created_policy) + module.exit_json(changed=True, data=created_policy) # Mark as changed since we are creating else: - module.exit_json(changed=False, data=existing_policy) - elif state == "absent" and existing_policy: - code = client.policies.delete_rule(policy_type="access", rule_id=existing_policy.get("id")) + module.exit_json(changed=False, data=existing_policy) # If there's no change, exit without updating + elif state == "absent" and existing_policy is not None: + code = client.policies.delete_rule( + policy_type="access", rule_id=existing_policy.get("id") + ) if code > 299: - module.fail_json(msg="Failed to delete rule", data=None) + module.exit_json(changed=False, data=None) module.exit_json(changed=True, data=existing_policy) module.exit_json(changed=False, data={}) + def main(): argument_spec = ZPAClientHelper.zpa_argument_spec() argument_spec.update( @@ -382,24 +399,7 @@ def main(): rhs=dict(type="str", required=False), object_type=dict( type="str", - required=True, - choices=[ - "APP", - "APP_GROUP", - "LOCATION", - "IDP", - "SAML", - "SCIM", - "SCIM_GROUP", - "CLIENT_TYPE", - "POSTURE", - "TRUSTED_NETWORK", - "BRANCH_CONNECTOR_GROUP", - "EDGE_CONNECTOR_GROUP", - "MACHINE_GRP", - "COUNTRY_CODE", - "PLATFORM", - ], + required=False, ), ), required=False, @@ -407,15 +407,31 @@ def main(): ), required=False, ), - new_rule_name=dict(type="str", required=False), state=dict(type="str", choices=["present", "absent"], default="present"), ) module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True) + + # Custom validation for object_type + conditions = module.params['conditions'] + if conditions: # Add this check to handle when conditions is None + for condition in conditions: + operands = condition.get('operands', []) + for operand in operands: + object_type = operand.get('object_type') + valid_object_types = [ + "APP", "APP_GROUP", "LOCATION", "IDP", "SAML", "SCIM", + "SCIM_GROUP", "CLIENT_TYPE", "POSTURE", "TRUSTED_NETWORK", + "BRANCH_CONNECTOR_GROUP", "EDGE_CONNECTOR_GROUP", "MACHINE_GRP", + "COUNTRY_CODE", "PLATFORM", + ] + if object_type is None or object_type == "": # Explicitly check for None or empty string + module.fail_json(msg=f"object_type cannot be empty or None. Must be one of: {', '.join(valid_object_types)}") + elif object_type not in valid_object_types: + module.fail_json(msg=f"Invalid object_type: {object_type}. Must be one of: {', '.join(valid_object_types)}") try: core(module) except Exception as e: module.fail_json(msg=to_native(e), exception=format_exc()) - if __name__ == "__main__": - main() + main() \ No newline at end of file