Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(anta.tests)!: Implement caching #394

Merged
merged 18 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterator
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, DefaultDict, Literal, Optional, Union

import asyncssh
from aiocache import Cache
from aiocache.plugins import HitMissRatioPlugin
from aioeapi import Device, EapiCommandError
from asyncssh import SSHClientConnection, SSHClientConnectionOptions
from httpx import ConnectError, HTTPError
Expand Down Expand Up @@ -52,11 +55,21 @@ def __init__(self, name: str, tags: Optional[list[str]] = None) -> None:
self.tags: list[str] = tags if tags is not None else []
self.is_online: bool = False
self.established: bool = False
self.cache: Cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()])
carl-baillargeon marked this conversation as resolved.
Show resolved Hide resolved
self.cache_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

# Ensure tag 'all' is always set
if DEFAULT_TAG not in self.tags:
self.tags.append(DEFAULT_TAG)

@property
def statistics(self) -> dict[str, Any]:
"""
Returns the device tests statistics for logging
"""
stats = self.cache.hit_miss_ratio # pylint: disable=no-member
return {"total_tests": stats["total"], "cache_hits": stats["hits"], "cache_hit_ratio": f"{stats['hit_ratio'] * 100:.2f}%"}

def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
"""
Implements Rich Repr Protocol
Expand All @@ -75,23 +88,51 @@ def __eq__(self, other: object) -> bool:
"""

@abstractmethod
async def collect(self, command: AntaCommand) -> None:
async def _collect(self, command: AntaCommand) -> None:
"""
Collect device command output.
This abstract coroutine can be used to implement any command collection method
for a device in ANTA.

The `collect()` implementation needs to populate the `output` attribute
The `_collect()` implementation needs to populate the `output` attribute
of the `AntaCommand` object passed as argument.

If a failure occurs, the `collect()` implementation is expected to catch the
If a failure occurs, the `_collect()` implementation is expected to catch the
exception and implement proper logging, the `output` attribute of the
`AntaCommand` object passed as argument would be `None` in this case.

Args:
command: the command to collect
"""

async def collect(self, command: AntaCommand) -> None:
"""
Collects the output of a given command. If caching is enabled for the command,
this method first checks the cache. If not found or caching is off, the output is
collected and stored in the cache. The collection is done by the private method `_collect`.

Ensures thread-safety for cache access using asynchronous locks on the command's UID.

Args:
command (AntaCommand): The command to process.

Effects:
- Updates `command.output` with the cached or collected data.
- Logs when data is retrieved from cache.
"""
async with self.cache_locks[command.uid]:
cached_output = None

if command.cache:
cached_output = await self.cache.get(command.uid) # pylint: disable=no-member
carl-baillargeon marked this conversation as resolved.
Show resolved Hide resolved

if cached_output is not None:
logger.debug(f"Cache hit for {command.command} on {self.name}")
command.output = cached_output
else:
await self._collect(command=command)
await self.cache.set(command.uid, command.output) # pylint: disable=no-member

async def collect_commands(self, commands: list[AntaCommand]) -> None:
"""
Collect multiple commands.
Expand Down Expand Up @@ -210,7 +251,7 @@ def __eq__(self, other: object) -> bool:
return False
return self._session.host == other._session.host and self._session.port == other._session.port

async def collect(self, command: AntaCommand) -> None:
async def _collect(self, command: AntaCommand) -> None:
"""
Collect device command output from EOS using aio-eapi.

Expand Down
12 changes: 11 additions & 1 deletion anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from __future__ import annotations

import hashlib
import logging
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -107,9 +108,11 @@ class AntaCommand(BaseModel):
version: eAPI version - valid values are 1 or "latest" - default is "latest"
revision: eAPI revision of the command. Valid values are 1 to 99. Revision has precedence over version.
ofmt: eAPI output - json or text - default is json
output: Output of the command populated by the collect() function
template: AntaTemplate object used to render this command
params: dictionary of variables with string values to render the template
params: Dictionary of variables with string values to render the template
failed: If the command execution fails, the Exception object is stored in this field
cache: Enable or disable caching for this AntaCommand
carl-baillargeon marked this conversation as resolved.
Show resolved Hide resolved
"""

# This is required if we want to keep an Exception object in the failed field
Expand All @@ -123,6 +126,13 @@ class AntaCommand(BaseModel):
template: Optional[AntaTemplate] = None
failed: Optional[Exception] = None
params: Dict[str, Any] = {}
cache: bool = True

@property
def uid(self) -> str:
"""Generate a unique identifier for this command"""
uid_str = f"{self.command}_{self.version}_{self.revision or 'NA'}_{self.ofmt}"
return hashlib.md5(uid_str.encode()).hexdigest()

@property
def json_output(self) -> dict[str, Any]:
Expand Down
12 changes: 10 additions & 2 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ async def main(

# asyncio.gather takes an iterator of the function to run concurrently.
# we get the cross product of the devices and tests to build that iterator.

devices = inventory.get_inventory(established_only=established_only, tags=tags).values()
coros = []
for device, test in itertools.product(inventory.get_inventory(established_only=established_only, tags=tags).values(), tests):

for device, test in itertools.product(devices, tests):
test_class = test[0]
test_inputs = test[1]
try:
Expand All @@ -70,3 +71,10 @@ async def main(
anta_log_exception(r, message, logger)
else:
manager.add_test_result(r)

# Get each device statistics
for device in devices:
if hasattr(device, "statistics"):
logger.info(f"Tests statistics for {device.name}: {device.statistics}")
else:
logger.warning(f"{device.name} does not have a statistics attribute.")
6 changes: 3 additions & 3 deletions anta/tests/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class VerifyLoggingLogsGeneration(AntaTest):
categories = ["logging"]
commands = [
AntaCommand(command="send log level informational message ANTA VerifyLoggingLogsGeneration validation"),
AntaCommand(command="show logging informational last 30 seconds | grep ANTA", ofmt="text"),
AntaCommand(command="show logging informational last 30 seconds | grep ANTA", ofmt="text", cache=False),
]

@AntaTest.anta_test
Expand Down Expand Up @@ -173,7 +173,7 @@ class VerifyLoggingHostname(AntaTest):
commands = [
AntaCommand(command="show hostname"),
AntaCommand(command="send log level informational message ANTA VerifyLoggingHostname validation"),
AntaCommand(command="show logging informational last 30 seconds | grep ANTA", ofmt="text"),
AntaCommand(command="show logging informational last 30 seconds | grep ANTA", ofmt="text", cache=False),
]

@AntaTest.anta_test
Expand Down Expand Up @@ -208,7 +208,7 @@ class VerifyLoggingTimestamp(AntaTest):
categories = ["logging"]
commands = [
AntaCommand(command="send log level informational message ANTA VerifyLoggingTimestamp validation"),
AntaCommand(command="show logging informational last 30 seconds | grep ANTA", ofmt="text"),
AntaCommand(command="show logging informational last 30 seconds | grep ANTA", ofmt="text", cache=False),
]

@AntaTest.anta_test
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ maintainers = [
description = "Arista Network Test Automation (ANTA) Framework"
license = { file = "LICENSE" }
dependencies = [
"aiocache~=0.12.2",
"aio-eapi==0.6.3",
"click~=8.1.6",
"click-help-colors~=0.9",
Expand Down