Skip to content

Commit

Permalink
feat: Added several integration test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
willguibr committed Oct 8, 2023
1 parent d3a8653 commit b40968d
Show file tree
Hide file tree
Showing 101 changed files with 2,769 additions and 1,203 deletions.
21 changes: 21 additions & 0 deletions examples/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
- name: Policy rule
hosts: localhost

vars:
zpa_cloud:
client_id: "{{ lookup('env', 'ZPA_CLIENT_ID') }}"
client_secret: "{{ lookup('env', 'ZPA_CLIENT_SECRET') }}"
customer_id: "{{ lookup('env', 'ZPA_CUSTOMER_ID') }}"
cloud: "{{ lookup('env', 'ZPA_CLOUD') | default(omit) }}"

tasks:
- name: Gather information about specific Posture Profile UDID
zscaler.zpacloud.zpa_posture_profile_info:
# name: "CrowdStrike_ZPA_Pre-ZTA"
register: pp_crowdstrike_zta40
- debug:
msg: "{{ pp_crowdstrike_zta40}}"

- name: pp_crowdstrike_zta40
debug:
msg: "{{ pp_crowdstrike_zta40}}"
18 changes: 18 additions & 0 deletions plugins/module_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
__metaclass__ = type

import pycountry
import re


def deleteNone(_dict):
"""Delete None values recursively from all of the dictionaries, tuples, lists, sets"""
Expand All @@ -16,6 +18,10 @@ def deleteNone(_dict):
_dict = type(_dict)(deleteNone(item) for item in _dict if item is not None)
return _dict

def remove_cloud_suffix(s: str) -> str:
reg = re.compile(r"(.*)[\s]+\([a-zA-Z0-9\-_\.]*\)[\s]*$")
res = reg.sub(r"\1", s)
return res.strip()

# Function to handle application segment port conversion list
def convert_ports_list(obj_list):
Expand Down Expand Up @@ -104,6 +110,8 @@ def normalize_app(app):
"is_incomplete_dr_config",
"inspect_traffic_with_zia",
"adp_enabled",
"app_id",
"ip_anchored",
]
for attr in computed_values:
normalized.pop(attr, None)
Expand All @@ -125,6 +133,13 @@ def normalize_app(app):
normalized["common_apps_dto"]
)

# Normalizing clientless_app_ids attributes
if "clientless_app_ids" in normalized:
for clientless_app in normalized["clientless_app_ids"]:
for field in ["app_id", "id", "hidden", "portal", "path", "certificate_name", "cname", "local_domain"]:
clientless_app.pop(field, None)


return normalized


Expand All @@ -151,6 +166,7 @@ def validate_latitude(val):
return (None, ["latitude value should be a valid float number or not empty"])
return (None, None)


def validate_longitude(val):
try:
v = float(val)
Expand All @@ -160,6 +176,7 @@ def validate_longitude(val):
return (None, ["longitude value should be a valid float number or not empty"])
return (None, None)


def diff_suppress_func_coordinate(old, new):
try:
o = round(float(old) * 1000000) / 1000000
Expand All @@ -168,6 +185,7 @@ def diff_suppress_func_coordinate(old, new):
except ValueError:
return False


def validate_tcp_quick_ack(
tcp_quick_ack_app, tcp_quick_ack_assistant, tcp_quick_ack_read_assistant
):
Expand Down
82 changes: 52 additions & 30 deletions plugins/module_utils/zpa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@

def to_zscaler_sdk_cls(pkg_name, cls_name):
sdk_name = "zscaler"

try:
mod = importlib.import_module("{0}.{1}".format(sdk_name, pkg_name))
except ModuleNotFoundError:
Expand All @@ -61,8 +60,6 @@ def to_zscaler_sdk_cls(pkg_name, cls_name):


class ConnectionHelper:
"""ConnectionHelper class for managing and verifying connectivity."""

def __init__(self, min_sdk_version):
self.min_sdk_version = min_sdk_version
self.sdk_installed = self._check_sdk_installed()
Expand Down Expand Up @@ -94,58 +91,83 @@ def __init__(self, module):
self.connection_helper = ConnectionHelper(min_sdk_version=(1, 0, 0))
self.connection_helper.ensure_sdk_installed()

cloud_env = module.params.get("cloud")
if cloud_env is None:
cloud_env = "PRODUCTION"
else:
cloud_env = cloud_env.upper()
provider = module.params.get("provider") or {}

client_id = provider.get("client_id") if provider else module.params.get("client_id")
if not client_id:
raise ValueError("client_id must be provided via provider or directly")

client_secret = provider.get("client_secret") if provider else module.params.get("client_secret")
if not client_secret:
raise ValueError("client_secret must be provided via provider or directly")

customer_id = provider.get("customer_id") if provider else module.params.get("customer_id")
if not customer_id:
raise ValueError("customer_id must be provided via provider or directly")

cloud_env = (provider.get("cloud") if provider else module.params.get("cloud")) or "PRODUCTION"
cloud_env = cloud_env.upper()

if cloud_env not in VALID_ZPA_ENVIRONMENTS:
raise ValueError(
f"Invalid ZPA Cloud environment '{cloud_env}'. Supported environments are: {', '.join(VALID_ZPA_ENVIRONMENTS)}."
)

super().__init__(
client_id=module.params.get("client_id", ""),
client_secret=module.params.get("client_secret", ""),
customer_id=module.params.get("customer_id", ""),
client_id=client_id,
client_secret=client_secret,
customer_id=customer_id,
cloud=cloud_env, # using the validated cloud environment
)

super().__init__(
client_id=client_id,
client_secret=client_secret,
customer_id=customer_id,
cloud=cloud_env, # using the validated cloud environment
)

# Set the User-Agent
ansible_version = ansible.__version__ # Get the Ansible version
customer_id = module.params.get("customer_id", "")
self.user_agent = f"zpa-ansible/{ansible_version}/({platform.system().lower()} {platform.machine()})/customer_id:{customer_id}"

@staticmethod
def zpa_argument_spec():
return dict(
provider=dict(
type="dict",
options=dict(
client_id=dict(
no_log=True,
fallback=(env_fallback, ["ZPA_CLIENT_ID"]),
),
client_secret=dict(
no_log=True,
fallback=(env_fallback, ["ZPA_CLIENT_SECRET"]),
),
customer_id=dict(
no_log=True,
fallback=(env_fallback, ["ZPA_CUSTOMER_ID"]),
),
cloud=dict(
no_log=True,
fallback=(env_fallback, ["ZPA_CLOUD"]),
),
),
),
client_id=dict(
no_log=True,
fallback=(
env_fallback,
["ZPA_CLIENT_ID"],
),
fallback=(env_fallback, ["ZPA_CLIENT_ID"]),
),
client_secret=dict(
no_log=True,
fallback=(
env_fallback,
["ZPA_CLIENT_SECRET"],
),
fallback=(env_fallback, ["ZPA_CLIENT_SECRET"]),
),
customer_id=dict(
no_log=True,
fallback=(
env_fallback,
["ZPA_CUSTOMER_ID"],
),
fallback=(env_fallback, ["ZPA_CUSTOMER_ID"]),
),
cloud=dict(
no_log=True,
fallback=(
env_fallback,
["ZPA_CLOUD"],
),
fallback=(env_fallback, ["ZPA_CLOUD"]),
),
)
)
36 changes: 27 additions & 9 deletions plugins/modules/zpa_app_connector_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,25 @@ def core(module):
new_lat = group.get("latitude")
if new_lat is not None: # Check if new_lat is not None before comparing
if diff_suppress_func_coordinate(existing_lat, new_lat):
existing_group["latitude"] = existing_lat # reset to original if they're deemed equal
existing_group[
"latitude"
] = existing_lat # reset to original if they're deemed equal
else:
existing_group["latitude"] = existing_lat # If new_lat is None, keep the existing value
existing_group[
"latitude"
] = existing_lat # If new_lat is None, keep the existing value

existing_long = existing_group.get("longitude")
new_long = group.get("longitude")
if new_long is not None: # Check if new_long is not None before comparing
if diff_suppress_func_coordinate(existing_long, new_long):
existing_group["longitude"] = existing_long # reset to original if they're deemed equal
existing_group[
"longitude"
] = existing_long # reset to original if they're deemed equal
else:
existing_group["longitude"] = existing_long # If new_long is None, keep the existing value
existing_group[
"longitude"
] = existing_long # If new_long is None, keep the existing value

existing_group = deleteNone(
dict(
Expand All @@ -296,19 +304,27 @@ def core(module):
upgrade_day=existing_group.get("upgrade_day"),
connector_ids=existing_group.get("connector_ids"),
upgrade_time_in_secs=existing_group.get("upgrade_time_in_secs"),
override_version_profile=existing_group.get("override_version_profile"),
override_version_profile=existing_group.get(
"override_version_profile"
),
version_profile_id=existing_group.get("version_profile_id"),
version_profile_name=existing_group.get("version_profile_name"),
dns_query_type=existing_group.get("dns_query_type"),
tcp_quick_ack_app=existing_group.get("tcp_quick_ack_app"),
tcp_quick_ack_assistant=existing_group.get("tcp_quick_ack_assistant"),
tcp_quick_ack_read_assistant=existing_group.get("tcp_quick_ack_read_assistant"),
tcp_quick_ack_assistant=existing_group.get(
"tcp_quick_ack_assistant"
),
tcp_quick_ack_read_assistant=existing_group.get(
"tcp_quick_ack_read_assistant"
),
use_in_dr_mode=existing_group.get("use_in_dr_mode"),
pra_enabled=existing_group.get("pra_enabled"),
waf_disabled=existing_group.get("waf_disabled"),
)
)
existing_group = client.connectors.update_connector_group(**existing_group).to_dict()
existing_group = client.connectors.update_connector_group(
**existing_group
).to_dict()
module.exit_json(changed=True, data=existing_group)
else:
"""Create"""
Expand All @@ -331,7 +347,9 @@ def core(module):
dns_query_type=group.get("dns_query_type"),
tcp_quick_ack_app=group.get("tcp_quick_ack_app"),
tcp_quick_ack_assistant=group.get("tcp_quick_ack_assistant"),
tcp_quick_ack_read_assistant=group.get("tcp_quick_ack_read_assistant"),
tcp_quick_ack_read_assistant=group.get(
"tcp_quick_ack_read_assistant"
),
use_in_dr_mode=group.get("use_in_dr_mode"),
pra_enabled=group.get("pra_enabled"),
waf_disabled=group.get("waf_disabled"),
Expand Down
23 changes: 13 additions & 10 deletions plugins/modules/zpa_application_segment_browser_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
elements: str
required: True
description: "List of domains and IPs."
clientless_apps:
clientless_app_ids:
description: ""
type: list
elements: dict
Expand Down Expand Up @@ -264,7 +264,7 @@
enabled: true
health_reporting: ON_ACCESS
bypass_type: NEVER
clientless_apps:
clientless_app_ids:
- name: "crm.example.com"
application_protocol: "HTTP"
application_port: "8080"
Expand Down Expand Up @@ -294,6 +294,7 @@
# The newly created browser access application segment resource record.
"""

# Need to review resource to ensure update occurs successfully.
from traceback import format_exc

from ansible.module_utils._text import to_native
Expand All @@ -309,7 +310,6 @@
ZPAClientHelper,
)


def core(module):
state = module.params.get("state", None)
client = ZPAClientHelper(module)
Expand All @@ -320,7 +320,7 @@ def core(module):
"enabled",
"description",
"bypass_type",
"clientless_apps",
"clientless_app_ids",
"domain_names",
"double_encrypt",
"health_check_type",
Expand All @@ -341,6 +341,7 @@ def core(module):
]
for param_name in params:
app[param_name] = module.params.get(param_name)

# Usage for tcp_keep_alive
tcp_keep_alive = module.params.get("tcp_keep_alive")
converted_tcp_keep_alive = convert_bool_to_str(
Expand Down Expand Up @@ -400,8 +401,9 @@ def core(module):
existing_app.update(app)
existing_app["id"] = id

if state == "present":
if existing_app is not None:
if state == "present":
if existing_app is not None:
if differences_detected:
"""Update"""
existing_app = deleteNone(
dict(
Expand All @@ -410,7 +412,7 @@ def core(module):
description=existing_app.get("description", None),
enabled=existing_app.get("enabled", None),
bypass_type=existing_app.get("bypass_type", None),
clientless_app_ids=existing_app.get("clientless_apps", None),
clientless_app_ids=existing_app.get("clientless_app_ids", None),
domain_names=existing_app.get("domain_names", None),
double_encrypt=existing_app.get("double_encrypt", None),
health_check_type=existing_app.get("health_check_type", None),
Expand Down Expand Up @@ -443,6 +445,7 @@ def core(module):
),
)
)
module.warn("Prepared payload for update_segment: {}".format(existing_app))
app = client.app_segments.update_segment(**existing_app)
module.exit_json(changed=True, data=app)
else:
Expand All @@ -456,7 +459,7 @@ def core(module):
description=app.get("description", None),
enabled=app.get("enabled", None),
bypass_type=app.get("bypass_type", None),
clientless_app_ids=app.get("clientless_apps", None),
clientless_app_ids=app.get("clientless_app_ids", None),
domain_names=app.get("domain_names", None),
double_encrypt=app.get("double_encrypt", None),
health_check_type=app.get("health_check_type", None),
Expand All @@ -480,7 +483,7 @@ def core(module):
)
)
app = client.app_segments.add_segment(**app)
module.exit_json(changed=False, data=app)
module.exit_json(changed=True, data=app)
elif state == "absent" and existing_app is not None:
code = client.app_segments.delete_segment(
segment_id=existing_app.get("id"), force_delete=True
Expand Down Expand Up @@ -536,7 +539,7 @@ def main():
udp_port_range=dict(
type="list", elements="dict", options=port_spec, required=False
),
clientless_apps=dict(
clientless_app_ids=dict(
type="list",
elements="dict",
options=dict(
Expand Down
Loading

0 comments on commit b40968d

Please sign in to comment.