Skip to content

Commit

Permalink
Refactor(plugins): Move jinja test code for arista.avd.contains to Py…
Browse files Browse the repository at this point in the history
…AVD (#4131)
  • Loading branch information
gmuloc authored Jun 19, 2024
1 parent fde40bd commit d6e3870
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

__metaclass__ = type

from functools import wraps
from typing import Callable
from functools import partial, wraps
from typing import Callable, Literal

from ansible.errors import AnsibleFilterError, AnsibleUndefinedVariable
from ansible.errors import AnsibleFilterError, AnsibleInternalError, AnsibleTemplateError, AnsibleUndefinedVariable
from ansible.module_utils.basic import to_native
from jinja2.exceptions import UndefinedError

Expand All @@ -27,17 +27,29 @@ def __call__(self, *args):
raise self.exception


def wrap_filter(name: str) -> Callable:
def wrap_filter_decorator(func: Callable | None) -> Callable:
def wrap_plugin(plugin_type: Literal["filter", "test"], name: str) -> Callable:
plugin_map = {
"filter": AnsibleFilterError,
"test": AnsibleTemplateError,
}

if plugin_type not in plugin_map:
raise AnsibleInternalError(f"Wrong plugin type {plugin_type} passed to wrap_plugin.")

def wrap_plugin_decorator(func: Callable) -> Callable:
@wraps(func)
def filter_wrapper(*args, **kwargs):
def plugin_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except UndefinedError as e:
raise AnsibleUndefinedVariable(f"Filter '{name}' failed: {to_native(e)}", orig_exc=e) from e
raise AnsibleUndefinedVariable(f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}", orig_exc=e) from e
except Exception as e:
raise AnsibleFilterError(f"Filter '{name}' failed: {to_native(e)}", orig_exc=e) from e
raise plugin_map[plugin_type](f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}", orig_exc=e) from e

return plugin_wrapper

return wrap_plugin_decorator

return filter_wrapper

return wrap_filter_decorator
wrap_filter = partial(wrap_plugin, "filter")
wrap_test = partial(wrap_plugin, "test")
66 changes: 17 additions & 49 deletions ansible_collections/arista/avd/plugins/test/contains.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,22 @@

__metaclass__ = type

from jinja2.runtime import Undefined
from ansible.errors import AnsibleTemplateError

from ansible_collections.arista.avd.plugins.plugin_utils.pyavd_wrappers import RaiseOnUse, wrap_test

PLUGIN_NAME = "arista.avd.contains"

try:
from pyavd.j2tests.contains import contains
except ImportError as e:
contains = RaiseOnUse(
AnsibleTemplateError(
f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error",
orig_exc=e,
)
)


DOCUMENTATION = r"""
---
Expand Down Expand Up @@ -66,53 +81,6 @@
"""


def contains(value, test_value=None):
"""
contains - Ansible test plugin to test if a list contains one or more elements
Arista.avd.contains will test value and argument if defined and is not none and return false if any one them doesn't pass.
Test value can be one value or a list of values to test for.
Example:
1. Test for one element in list
{% if switch.vlans is arista.avd.contains(123) %}
...
{% endif %}
2. Test for multiple elements in list
{% if switch.vlans is arista.avd.contains([123, 456]) %}
...
{% endif %}
Parameters
----------
value : any
List to test
test_value : single item or list of items
Value(s) to test for in value
Returns
-------
boolean
True if variable matches criteria, False in other cases.
"""
if isinstance(value, Undefined) or value is None or not isinstance(value, list):
# Invalid value - return false
return False
elif isinstance(test_value, Undefined) or value is None:
# Invalid value - return false
return False
elif isinstance(test_value, list) and not set(value).isdisjoint(test_value):
# test_value is list so test if value and test_value has any common items
return True
elif test_value in value:
# Test if test_value is in value
return True
else:
return False


class TestModule(object):
def tests(self):
return {
"contains": contains,
}
return {"contains": wrap_test(PLUGIN_NAME)(contains)}
1 change: 1 addition & 0 deletions python-avd/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ EOS_CLI_CONFIG_GEN_TEMPLATE_DIR = $(VENDOR_DIR)/templates
SCHEMAS_DIR = $(VENDOR_DIR)/schemas
EOS_DESIGNS_MODULES_DIR = $(VENDOR_DIR)/eos_designs
PYAVD_FILTER_IMPORT = $(PACKAGE_DIR).j2filters
PYAVD_TEST_IMPORT = $(PACKAGE_DIR).j2tests
EOS_DESIGNS_IMPORT = $(PACKAGE_DIR)._eos_designs
# export PYTHONPATH=$(CURRENT_DIR) # Uncomment to test from source

Expand Down
43 changes: 43 additions & 0 deletions python-avd/pyavd/j2tests/contains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""AVD Jinja2 test contains.
The test checks if a list contains any of the value(s) passed in test_value.
"""

from __future__ import annotations

from typing import Any

from jinja2.runtime import Undefined


def contains(value: list[Any], test_value: Any | list[Any] = None) -> bool:
"""The test checks if a list contains any of the value(s) passed in test_value.
If 'value' is Undefined, None or not a list then the test has failed.
Parameters
----------
value :
List to test
test_value : single item or list of items
Value(s) to test for in value
Returns
-------
boolean
True if variable matches criteria, False in other cases.
"""
# TODO - this will fail miserably if test_value is not hashable !
if isinstance(value, Undefined) or value is None or not isinstance(value, list):
# Invalid value - return false
return False
if isinstance(test_value, Undefined) or test_value is None:
# Invalid value - return false
return False
if isinstance(test_value, list) and not set(value).isdisjoint(test_value):
# test_value is list so test if value and test_value has any common items
return True
return test_value in value
2 changes: 1 addition & 1 deletion python-avd/pyavd/templater.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .j2filters.hide_passwords import hide_passwords
from .j2filters.list_compress import list_compress
from .j2filters.natural_sort import natural_sort
from .j2tests.contains import contains


class Undefined(StrictUndefined):
Expand Down Expand Up @@ -62,7 +63,6 @@ def import_filters_and_tests(self) -> None:
from .vendor.j2.filter.decrypt import decrypt
from .vendor.j2.filter.encrypt import encrypt
from .vendor.j2.filter.range_expand import range_expand
from .vendor.j2.test.contains import contains
from .vendor.j2.test.defined import defined

# pylint: enable=import-outside-toplevel
Expand Down
32 changes: 32 additions & 0 deletions python-avd/tests/pyavd/j2tests/test_contains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Unit tests for pyavd.j2tests.contains."""

from __future__ import annotations

import pytest
from jinja2.runtime import Undefined
from pyavd.j2tests.contains import contains

TEST_DATA = [
pytest.param(None, "dummy", False, id="value is None"),
pytest.param(Undefined, "dummy", False, id="value is Undefined"),
pytest.param("value_not_a_list", "dummy", False, id="value is not a list"),
pytest.param(["dummy"], None, False, id="test_value is None"),
pytest.param(["dummy"], Undefined, False, id="test_value is Undefined"),
pytest.param(["a", "b", "c"], "b", True, id="test_value single value in value"),
pytest.param(["a", "b", "c"], ["d", "b"], True, id="test_value list contained value"),
pytest.param([1, 42, 666], 42, True, id="test success with int"),
pytest.param(["a", "b", "c"], "d", False, id="test_value list not contained value"),
pytest.param(["a", "b", "c"], ["d", "e"], False, id="test_value single value not in value"),
]


class TestContainsTest:
"""Test Contains."""

@pytest.mark.parametrize(("value, test_value, expected_result"), TEST_DATA)
def test_contains(self, value, test_value, expected_result):
"""Test the contains function."""
assert contains(value, test_value) == expected_result

0 comments on commit d6e3870

Please sign in to comment.