forked from aristanetworks/anta
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(anta): Add get_item function (aristanetworks#518)
- Loading branch information
1 parent
ae105b0
commit 8289071
Showing
2 changed files
with
155 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) 2023-2024 Arista Networks, Inc. | ||
# Use of this source code is governed by the Apache License 2.0 | ||
# that can be found in the LICENSE file. | ||
|
||
"""Get one dictionary from a list of dictionaries by matching the given key and value.""" | ||
from __future__ import annotations | ||
|
||
from typing import Any, Optional | ||
|
||
|
||
# pylint: disable=too-many-arguments | ||
def get_item( | ||
list_of_dicts: list[dict[Any, Any]], | ||
key: Any, | ||
value: Any, | ||
default: Optional[Any] = None, | ||
required: bool = False, | ||
case_sensitive: bool = False, | ||
var_name: Optional[str] = None, | ||
custom_error_msg: Optional[str] = None, | ||
) -> Any: | ||
"""Get one dictionary from a list of dictionaries by matching the given key and value. | ||
Returns the supplied default value or None if there is no match and "required" is False. | ||
Will return the first matching item if there are multiple matching items. | ||
Parameters | ||
---------- | ||
list_of_dicts : list(dict) | ||
List of Dictionaries to get list item from | ||
key : any | ||
Dictionary Key to match on | ||
value : any | ||
Value that must match | ||
default : any | ||
Default value returned if the key and value is not found | ||
required : bool | ||
Fail if there is no match | ||
case_sensitive : bool | ||
If the search value is a string, the comparison will ignore case by default | ||
var_name : str | ||
String used for raising exception with the full variable name | ||
custom_error_msg : str | ||
Custom error message to raise when required is True and the value is not found | ||
Returns | ||
------- | ||
any | ||
Dict or default value | ||
Raises | ||
------ | ||
ValueError | ||
If the key and value is not found and "required" == True | ||
""" | ||
if var_name is None: | ||
var_name = key | ||
|
||
if (not isinstance(list_of_dicts, list)) or list_of_dicts == [] or value is None or key is None: | ||
if required is True: | ||
raise ValueError(custom_error_msg or var_name) | ||
return default | ||
|
||
for list_item in list_of_dicts: | ||
if not isinstance(list_item, dict): | ||
# List item is not a dict as required. Skip this item | ||
continue | ||
|
||
item_value = list_item.get(key) | ||
|
||
# Perform case-insensitive comparison if value and item_value are strings and case_sensitive is False | ||
if not case_sensitive and isinstance(value, str) and isinstance(item_value, str): | ||
if item_value.casefold() == value.casefold(): | ||
return list_item | ||
elif item_value == value: | ||
# Match. Return this item | ||
return list_item | ||
|
||
# No Match | ||
if required is True: | ||
raise ValueError(custom_error_msg or var_name) | ||
return default |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) 2023-2024 Arista Networks, Inc. | ||
# Use of this source code is governed by the Apache License 2.0 | ||
# that can be found in the LICENSE file. | ||
|
||
"""Tests for `anta.tools.get_item`.""" | ||
from __future__ import annotations | ||
|
||
from contextlib import nullcontext as does_not_raise | ||
from typing import Any | ||
|
||
import pytest | ||
|
||
from anta.tools.get_item import get_item | ||
|
||
DUMMY_DATA = [ | ||
("id", 0), | ||
{ | ||
"id": 1, | ||
"name": "Alice", | ||
"age": 30, | ||
"email": "[email protected]", | ||
}, | ||
{ | ||
"id": 2, | ||
"name": "Bob", | ||
"age": 35, | ||
"email": "[email protected]", | ||
}, | ||
{ | ||
"id": 3, | ||
"name": "Charlie", | ||
"age": 40, | ||
"email": "[email protected]", | ||
}, | ||
] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"list_of_dicts, key, value, default, required, case_sensitive, var_name, custom_error_msg, expected_result, expected_raise", | ||
[ | ||
pytest.param([], "name", "Bob", None, False, False, None, None, None, does_not_raise(), id="empty list"), | ||
pytest.param([], "name", "Bob", None, True, False, None, None, None, pytest.raises(ValueError, match="name"), id="empty list and required"), | ||
pytest.param(DUMMY_DATA, "name", "Jack", None, False, False, None, None, None, does_not_raise(), id="missing item"), | ||
pytest.param(DUMMY_DATA, "name", "Alice", None, False, False, None, None, DUMMY_DATA[1], does_not_raise(), id="found item"), | ||
pytest.param(DUMMY_DATA, "name", "Jack", "default_value", False, False, None, None, "default_value", does_not_raise(), id="default value"), | ||
pytest.param(DUMMY_DATA, "name", "Jack", None, True, False, None, None, None, pytest.raises(ValueError, match="name"), id="required"), | ||
pytest.param(DUMMY_DATA, "name", "Bob", None, False, True, None, None, DUMMY_DATA[2], does_not_raise(), id="case sensitive"), | ||
pytest.param(DUMMY_DATA, "name", "charlie", None, False, False, None, None, DUMMY_DATA[3], does_not_raise(), id="case insensitive"), | ||
pytest.param( | ||
DUMMY_DATA, "name", "Jack", None, True, False, "custom_var_name", None, None, pytest.raises(ValueError, match="custom_var_name"), id="custom var_name" | ||
), | ||
pytest.param( | ||
DUMMY_DATA, "name", "Jack", None, True, False, None, "custom_error_msg", None, pytest.raises(ValueError, match="custom_error_msg"), id="custom error msg" | ||
), | ||
], | ||
) | ||
def test_get_item( | ||
list_of_dicts: list[dict[Any, Any]], | ||
key: Any, | ||
value: Any, | ||
default: Any | None, | ||
required: bool, | ||
case_sensitive: bool, | ||
var_name: str | None, | ||
custom_error_msg: str | None, | ||
expected_result: str, | ||
expected_raise: Any, | ||
) -> None: | ||
"""Test get_item.""" | ||
# pylint: disable=too-many-arguments | ||
with expected_raise: | ||
assert get_item(list_of_dicts, key, value, default, required, case_sensitive, var_name, custom_error_msg) == expected_result |