Skip to content

Commit

Permalink
multiple custom actions validators (aws#5134)
Browse files Browse the repository at this point in the history
Support for asynchronous Resource validation and asynchronous UrlValidator usage for custom actions.
  • Loading branch information
psacc authored Apr 17, 2023
1 parent 0704826 commit 0cecd4d
Show file tree
Hide file tree
Showing 16 changed files with 762 additions and 70 deletions.
6 changes: 3 additions & 3 deletions cli/src/pcluster/config/cluster_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
SharedStorageNameValidator,
UnmanagedFsxMultiAzValidator,
)
from pcluster.validators.common import ValidatorContext
from pcluster.validators.common import ValidatorContext, get_async_timed_validator_type_for
from pcluster.validators.database_validators import DatabaseUriValidator
from pcluster.validators.directory_service_validators import (
AdditionalSssdConfigsValidator,
Expand Down Expand Up @@ -1135,15 +1135,15 @@ def _register_validators(self, context: ValidatorContext = None): # noqa: D102


class CustomAction(Resource):
    """Represent a custom action script resource.

    Holds the script URL and the optional list of arguments to pass to it.
    """

    def __init__(self, script: str, args: List[str] = None):
        super().__init__()
        self.script = Resource.init_param(script)
        self.args = Resource.init_param(args)

    def _register_validators(self, context: ValidatorContext = None):  # noqa: D102 #pylint: disable=unused-argument
        # Validate the script URL asynchronously with a 5 second timeout so that slow or
        # unreachable endpoints do not stall the overall configuration validation.
        self._register_validator(get_async_timed_validator_type_for(UrlValidator), url=self.script, timeout=5)


class CustomActions(Resource):
Expand Down
92 changes: 78 additions & 14 deletions cli/src/pcluster/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# This module contains all the classes representing the Resources objects.
# These objects are obtained from the configuration file through a conversion based on the Schema classes.
#
import asyncio
import itertools
import json
import logging
from abc import ABC, abstractmethod
from typing import List, Set

from pcluster.validators.common import FailureLevel, ValidationResult, Validator, ValidatorContext
from pcluster.validators.common import AsyncValidator, FailureLevel, ValidationResult, Validator, ValidatorContext
from pcluster.validators.iam_validators import AdditionalIamPolicyValidator
from pcluster.validators.networking_validators import LambdaFunctionsVpcConfigValidator
from pcluster.validators.s3_validators import UrlValidator
Expand Down Expand Up @@ -122,6 +124,7 @@ def __repr__(self):
def __init__(self, implied: bool = False):
# Parameters registry
self.__params = {}
self._validation_futures = []
self._validation_failures: List[ValidationResult] = []
self._validators: List = []
self.implied = implied
Expand Down Expand Up @@ -167,17 +170,40 @@ def init_param(value, default=None, update_policy=None):
"""Create a resource attribute backed by a Configuration Parameter."""
return Resource.Param(value, default=default, update_policy=update_policy)

def _validator_execute(self, validator_class, validator_args, suppressors):
@staticmethod
def _validator_execute(validator_class, validator_args, suppressors, validation_executor):
    """Instantiate the given validator class and run it through validation_executor.

    Returns an empty result list, without executing, when any of the given
    suppressors silences the validator.
    """
    instance = validator_class()

    for suppressor in suppressors or []:
        if suppressor.suppress_validator(instance):
            LOGGER.debug("Suppressing validator %s", validator_class.__name__)
            return []

    LOGGER.debug("Executing validator %s", validator_class.__name__)
    return validation_executor(validator_args, instance)

@staticmethod
def _validator_execute_sync(validator_args, validator):
    """Run a validator synchronously, turning unexpected exceptions into ERROR results."""
    try:
        return validator.execute(**validator_args)
    except Exception as e:
        # A crashing validator must not abort the whole validation run: surface it as an ERROR result.
        LOGGER.debug("Validator %s unexpected failure: %s", validator.type, e)
        return [ValidationResult(str(e), FailureLevel.ERROR, validator.type)]

@staticmethod
def _validator_execute_async(validator_args, validator):
    """Start the async execution of the validator; return an awaitable yielding its results."""
    pending = validator.execute_async(**validator_args)
    return pending

def _await_async_validators(self):
    """Block until all pending async validator futures complete and flatten their results."""
    # NOTE(review): this could be a good spot for a cascading timeout across the resource tree,
    # since get_async_timed_validator_type_for decorates a single validator at a time and does
    # not cascade to child resources.
    loop = asyncio.get_event_loop()
    per_validator_results = loop.run_until_complete(asyncio.gather(*self._validation_futures))
    flattened = itertools.chain.from_iterable(per_validator_results)
    return list(flattened)

def _nested_resources(self):
nested_resources = []
for _, value in self.__dict__.items():
Expand All @@ -188,34 +214,72 @@ def _nested_resources(self):
return nested_resources

def validate(
    self, suppressors: List[ValidatorSuppressor] = None, context: ValidatorContext = None, nested: bool = False
):
    """
    Execute registered validators.

    The "nested" parameter is used only for internal recursive calls to distinguish those from the top level
    one where the async validators results should be awaited for.

    :param suppressors: optional suppressors that can silence individual validators.
    :param context: validation context handed to the registered validators.
    :param nested: True only for internal recursive calls on child resources.
    :return: the list of failures at the top level; a (failures, futures) tuple for nested calls.
    """
    # this validation logic is a responsibility that could be completely separated from the resource tree
    # also until we need to support both sync and async validation this logic will be unnecessarily complex
    # embracing async validation completely is possible and will greatly simplify this
    self._validation_futures.clear()
    self._validation_failures.clear()

    try:
        self._validate_nested_resources(context, suppressors)
        self._validate_self(context, suppressors)
    finally:
        if nested:
            # Hand both failures and still-pending futures up to the caller, which awaits them.
            result = self._validation_failures, self._validation_futures.copy()
        else:
            # Top-level call: drain every async validator collected across the resource tree.
            self._validation_failures.extend(self._await_async_validators())
            result = self._validation_failures
        self._validation_futures.clear()

    return result

def _validate_nested_resources(self, context, suppressors):
    """Validate child resources, collecting their sync failures and pending async futures."""
    # Call validators for nested resources
    for nested_resource in self._nested_resources():
        failures, futures = nested_resource.validate(suppressors, context, nested=True)
        self._validation_futures.extend(futures)
        self._validation_failures.extend(failures)

def _validate_self(self, context, suppressors):
    """Run this resource's own validators, re-registered against the current model state.

    Sync validators append failures immediately; async validators contribute futures
    that are awaited later by _await_async_validators.
    """
    self._validators.clear()
    self._register_validators(context)
    for validator in self._validators:
        # Each entry is a (validator_class, validator_args) tuple.
        if issubclass(validator[0], AsyncValidator):
            # append instead of extend([...]): a single future is produced per async validator
            self._validation_futures.append(
                self._validator_execute(*validator, suppressors, self._validator_execute_async)
            )
        else:
            self._validation_failures.extend(
                self._validator_execute(*validator, suppressors, self._validator_execute_sync)
            )

def _register_validators(self, context: ValidatorContext = None):
    """
    Register all the validators that contribute to ensure that the resource parameters are valid.

    Method to be implemented in Resource subclasses that need to register validators by invoking the internal
    _register_validator method.

    :param context: validation context propagated to the registered validators.
    :return: None; registration happens as a side effect on the validators list.
    """
    pass

def _register_validator(self, validator_class, **validator_args):
    """Add a validator with the specified arguments.

    The validator will be executed according to the current status of the model and ordered by priority.

    :param validator_class: Validator class to be executed
    :param validator_args: Arguments to be passed to the validator
    :return: None; the (class, args) pair is appended to the internal validators list.
    """
    self._validators.append((validator_class, validator_args))

def __repr__(self):
Expand Down
86 changes: 86 additions & 0 deletions cli/src/pcluster/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import datetime
import functools
import itertools
import json
import logging
Expand All @@ -19,6 +21,7 @@
import sys
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from shlex import quote
from typing import Callable, NoReturn
Expand Down Expand Up @@ -450,3 +453,86 @@ def batch_by_property_callback(items, property_callback: Callable[..., int], bat
if item_index == len(items) - 1:
# If on the last time, yield the batch
yield current_batch


class AsyncUtils:
    """Utility class collecting helpers for async functions."""

    @staticmethod
    def async_retry(stop_max_attempt_number=5, wait_fixed=1, retry_on_exception=None):
        """
        Decorate an async coroutine function to retry its execution when an exception is raised.

        :param stop_max_attempt_number: Max number of retries.
        :param wait_fixed: Wait time (in seconds) between retries.
        :param retry_on_exception: Exception (or tuple of exceptions) to retry on; defaults to any Exception.
        :return: the retry decorator.
        """
        retriable_exceptions = (Exception,) if retry_on_exception is None else retry_on_exception

        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                outcome = None
                attempt = 0
                while attempt <= stop_max_attempt_number:
                    try:
                        outcome = await func(*args, **kwargs)
                        break
                    except retriable_exceptions:
                        # Out of retries: propagate the last failure unchanged.
                        if attempt >= stop_max_attempt_number:
                            raise
                        attempt += 1
                        await asyncio.sleep(wait_fixed)
                return outcome

            return wrapper

        return decorator

    @staticmethod
    def async_timeout_cache(timeout=10):
        """
        Decorate an async function to cache the result of a call for a given time period.

        The cache key is (function name, positional args after the first, keyword args); the first
        positional argument is excluded from the key (presumably `self` — verify against callers).
        :param timeout: Cache entry lifetime, in seconds.
        :return: the caching decorator.
        """

        def decorator(func):
            entries = {}

            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                key = (func.__name__, args[1:], frozenset(kwargs.items()))
                now = time.time()

                cached = entries.get(key)
                if cached is None or cached[1] < now:
                    # Store the in-flight future so concurrent callers share one execution.
                    entries[key] = (asyncio.ensure_future(func(*args, **kwargs)), now + timeout)

                return await entries[key][0]

            return wrapper

        return decorator

    # Shared pool used by async_from_sync to run blocking callables off the event loop.
    _thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor()

    @staticmethod
    def async_from_sync(func):
        """
        Convert a synchronous member function to an async one.

        The call is dispatched to a shared thread pool so the event loop is not blocked.
        :param func: synchronous function taking `self` as its first argument.
        :return: an awaitable wrapper around func.
        """

        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            event_loop = asyncio.get_event_loop()
            return await event_loop.run_in_executor(
                AsyncUtils._thread_pool_executor, lambda: func(self, *args, **kwargs)
            )

        return wrapper
79 changes: 75 additions & 4 deletions cli/src/pcluster/validators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# This module contains all the classes representing the Resources objects.
# These objects are obtained from the configuration file through a conversion based on the Schema classes.
#

import asyncio
import functools
from abc import ABC, abstractmethod
from enum import Enum
from typing import List

# Default timeout (seconds) used by get_async_timed_validator_type_for when the caller
# does not pass an explicit `timeout` keyword to the decorated validator.
ASYNC_TIMED_VALIDATORS_DEFAULT_TIMEOUT_SEC = 10


class FailureLevel(Enum):
Expand All @@ -41,7 +45,7 @@ def __repr__(self):


class Validator(ABC):
"""Abstract validator. The children must implement the validate method."""
"""Abstract validator. The children must implement the _validate method."""

def __init__(self):
self._failures = []
Expand All @@ -55,17 +59,84 @@ def type(self):
"""Identify the type of validator."""
return self.__class__.__name__

def execute(self, *arg, **kwargs) -> List[ValidationResult]:
    """Entry point of all validators to verify all input params are valid.

    Delegates to the subclass-provided _validate and returns the accumulated failures.
    """
    self._validate(*arg, **kwargs)
    return self._failures

@abstractmethod
def _validate(self, *args, **kwargs):
    """
    Must be implemented with specific validation logic.

    Use _add_failure to add failures to the list of failures returned by execute.
    """
    pass


class AsyncValidator(Validator):
    """Abstract validator that supports *also* async execution. Children must implement the _validate_async method."""

    def __init__(self):
        super().__init__()

    def _validate(self, *arg, **kwargs):
        # Bridge to the synchronous Validator contract by driving the coroutine to completion.
        event_loop = asyncio.get_event_loop()
        event_loop.run_until_complete(self._validate_async(*arg, **kwargs))
        return self._failures

    async def execute_async(self, *arg, **kwargs) -> List[ValidationResult]:
        """Entry point of all async validators to verify all input params are valid."""
        await self._validate_async(*arg, **kwargs)
        return self._failures

    @abstractmethod
    async def _validate_async(self, *args, **kwargs):
        """
        Must be implemented with specific validation logic.

        Use _add_failure to add failures to the list of failures returned by execute or execute_async when awaited.
        """
        pass


def get_async_timed_validator_type_for(validator_type: type) -> AsyncValidator:
"""
Return the type decorating the given validator with timeout support.
The enriched _validate_async will accept an additional timeout parameter.
If not provided will default to ASYNC_TIMED_VALIDATORS_DEFAULT_TIMEOUT_SEC.
Since validators async execution is coroutine based with preemptive multitasking,
the effective time to fail the validator for timeout may exceed the requested one.
"""
class_name = f"AsyncTimed{validator_type.__name__}"

if class_name not in globals():
class_bases = validator_type.__bases__
class_dict = dict(validator_type.__dict__)

def _async_timed_validate(original_method):
@functools.wraps(original_method)
async def _validate_async(self: AsyncValidator, *args, **kwargs):
timeout = kwargs.pop("timeout", ASYNC_TIMED_VALIDATORS_DEFAULT_TIMEOUT_SEC)
try:
await asyncio.wait_for(original_method(self, *args, **kwargs), timeout=timeout)
except asyncio.TimeoutError:
self._add_failure( # pylint: disable=protected-access
f"Validation of ({kwargs}) timed out after {timeout} seconds.", FailureLevel.WARNING
)

return _validate_async

class_dict["_validate_async"] = _async_timed_validate(class_dict["_validate_async"])

schema_class_type = type(class_name, class_bases, class_dict)
globals()[class_name] = schema_class_type
else:
schema_class_type = globals()[class_name]
return schema_class_type


class ValidatorContext:
"""Context containing information about cluster environment meant to be passed to validators."""

Expand Down
Loading

0 comments on commit 0cecd4d

Please sign in to comment.