Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor(plugins): Move jinja test code for arista.avd.contains to PyAVD #4131

Merged
merged 3 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

__metaclass__ = type

from functools import wraps
from typing import Callable
from functools import partial, wraps
from typing import Callable, Literal

from ansible.errors import AnsibleFilterError, AnsibleUndefinedVariable
from ansible.errors import AnsibleFilterError, AnsibleInternalError, AnsibleTemplateError, AnsibleUndefinedVariable
from ansible.module_utils.basic import to_native
from jinja2.exceptions import UndefinedError

Expand All @@ -27,17 +27,29 @@ def __call__(self, *args):
raise self.exception


def wrap_filter(name: str) -> Callable:
def wrap_filter_decorator(func: Callable | None) -> Callable:
def wrap_plugin(plugin_type: Literal["filter", "test"], name: str) -> Callable:
plugin_map = {
"filter": AnsibleFilterError,
"test": AnsibleTemplateError,
}

if plugin_type not in plugin_map:
raise AnsibleInternalError(f"Wrong plugin type {plugin_type} passed to wrap_plugin.")

def wrap_plugin_decorator(func: Callable) -> Callable:
@wraps(func)
def filter_wrapper(*args, **kwargs):
def plugin_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except UndefinedError as e:
raise AnsibleUndefinedVariable(f"Filter '{name}' failed: {to_native(e)}", orig_exc=e) from e
raise AnsibleUndefinedVariable(f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}", orig_exc=e) from e
except Exception as e:
raise AnsibleFilterError(f"Filter '{name}' failed: {to_native(e)}", orig_exc=e) from e
raise plugin_map[plugin_type](f"{plugin_type.capitalize()} '{name}' failed: {to_native(e)}", orig_exc=e) from e

return plugin_wrapper

return wrap_plugin_decorator

return filter_wrapper

return wrap_filter_decorator
wrap_filter = partial(wrap_plugin, "filter")
wrap_test = partial(wrap_plugin, "test")
66 changes: 17 additions & 49 deletions ansible_collections/arista/avd/plugins/test/contains.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,22 @@

__metaclass__ = type

from jinja2.runtime import Undefined
from ansible.errors import AnsibleTemplateError

from ansible_collections.arista.avd.plugins.plugin_utils.pyavd_wrappers import RaiseOnUse, wrap_test

PLUGIN_NAME = "arista.avd.contains"

try:
from pyavd.j2tests.contains import contains
except ImportError as e:
contains = RaiseOnUse(
AnsibleTemplateError(
f"The '{PLUGIN_NAME}' plugin requires the 'pyavd' Python library. Got import error",
orig_exc=e,
)
)


DOCUMENTATION = r"""
---
Expand Down Expand Up @@ -66,53 +81,6 @@
"""


def contains(value, test_value=None):
"""
contains - Ansible test plugin to test if a list contains one or more elements

Arista.avd.contains will test value and argument if defined and is not none and return false if any one them doesn't pass.
Test value can be one value or a list of values to test for.

Example:
1. Test for one element in list
{% if switch.vlans is arista.avd.contains(123) %}
...
{% endif %}
2. Test for multiple elements in list
{% if switch.vlans is arista.avd.contains([123, 456]) %}
...
{% endif %}

Parameters
----------
value : any
List to test
test_value : single item or list of items
Value(s) to test for in value

Returns
-------
boolean
True if variable matches criteria, False in other cases.
"""
if isinstance(value, Undefined) or value is None or not isinstance(value, list):
# Invalid value - return false
return False
elif isinstance(test_value, Undefined) or value is None:
# Invalid value - return false
return False
elif isinstance(test_value, list) and not set(value).isdisjoint(test_value):
# test_value is list so test if value and test_value has any common items
return True
elif test_value in value:
# Test if test_value is in value
return True
else:
return False


class TestModule(object):
def tests(self):
return {
"contains": contains,
}
return {"contains": wrap_test(PLUGIN_NAME)(contains)}
1 change: 1 addition & 0 deletions python-avd/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ EOS_CLI_CONFIG_GEN_TEMPLATE_DIR = $(VENDOR_DIR)/templates
SCHEMAS_DIR = $(VENDOR_DIR)/schemas
EOS_DESIGNS_MODULES_DIR = $(VENDOR_DIR)/eos_designs
PYAVD_FILTER_IMPORT = $(PACKAGE_DIR).j2filters
PYAVD_TEST_IMPORT = $(PACKAGE_DIR).j2tests
EOS_DESIGNS_IMPORT = $(PACKAGE_DIR)._eos_designs
# export PYTHONPATH=$(CURRENT_DIR) # Uncomment to test from source

Expand Down
43 changes: 43 additions & 0 deletions python-avd/pyavd/j2tests/contains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""AVD Jinja2 test contains.

The test checks if a list contains any of the value(s) passed in test_value.
"""

from __future__ import annotations

from typing import Any

from jinja2.runtime import Undefined


def contains(value: list[Any], test_value: Any | list[Any] = None) -> bool:
"""The test checks if a list contains any of the value(s) passed in test_value.

If 'value' is Undefined, None or not a list then the test has failed.

Parameters
----------
value :
List to test
test_value : single item or list of items
Value(s) to test for in value

Returns
-------
boolean
True if variable matches criteria, False in other cases.
"""
# TODO - this will fail miserably if test_value is not hashable !
if isinstance(value, Undefined) or value is None or not isinstance(value, list):
# Invalid value - return false
return False
if isinstance(test_value, Undefined) or test_value is None:
# Invalid value - return false
return False
if isinstance(test_value, list) and not set(value).isdisjoint(test_value):
# test_value is list so test if value and test_value has any common items
return True
return test_value in value
2 changes: 1 addition & 1 deletion python-avd/pyavd/templater.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .j2filters.hide_passwords import hide_passwords
from .j2filters.list_compress import list_compress
from .j2filters.natural_sort import natural_sort
from .j2tests.contains import contains


class Undefined(StrictUndefined):
Expand Down Expand Up @@ -62,7 +63,6 @@ def import_filters_and_tests(self) -> None:
from .vendor.j2.filter.decrypt import decrypt
from .vendor.j2.filter.encrypt import encrypt
from .vendor.j2.filter.range_expand import range_expand
from .vendor.j2.test.contains import contains
from .vendor.j2.test.defined import defined

# pylint: enable=import-outside-toplevel
Expand Down
32 changes: 32 additions & 0 deletions python-avd/tests/pyavd/j2tests/test_contains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Unit tests for pyavd.j2tests.contains."""

from __future__ import annotations

import pytest
from jinja2.runtime import Undefined
from pyavd.j2tests.contains import contains

TEST_DATA = [
gmuloc marked this conversation as resolved.
Show resolved Hide resolved
pytest.param(None, "dummy", False, id="value is None"),
pytest.param(Undefined, "dummy", False, id="value is Undefined"),
pytest.param("value_not_a_list", "dummy", False, id="value is not a list"),
pytest.param(["dummy"], None, False, id="test_value is None"),
pytest.param(["dummy"], Undefined, False, id="test_value is Undefined"),
pytest.param(["a", "b", "c"], "b", True, id="test_value single value in value"),
pytest.param(["a", "b", "c"], ["d", "b"], True, id="test_value list contained value"),
pytest.param([1, 42, 666], 42, True, id="test success with int"),
pytest.param(["a", "b", "c"], "d", False, id="test_value list not contained value"),
pytest.param(["a", "b", "c"], ["d", "e"], False, id="test_value single value not in value"),
]


class TestContainsTest:
"""Test Contains."""

@pytest.mark.parametrize(("value, test_value, expected_result"), TEST_DATA)
def test_contains(self, value, test_value, expected_result):
"""Test the contains function."""
assert contains(value, test_value) == expected_result
Loading