Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor constraints so that each constraint is it's own class #23753

Merged
merged 1 commit into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/controller/python/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ chip_python_wheel_action("chip-core") {
"chip/utils/CommissioningBuildingBlocks.py",
"chip/utils/__init__.py",
"chip/yaml/__init__.py",
"chip/yaml/constraints.py",
"chip/yaml/data_model_lookup.py",
"chip/yaml/errors.py",
"chip/yaml/format_converter.py",
Expand Down
218 changes: 218 additions & 0 deletions src/controller/python/chip/yaml/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#
# Copyright (c) 2022 Project CHIP Authors
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
import chip.yaml.format_converter as Converter
from .variable_storage import VariableStorage


class ConstraintValidationError(Exception):
def __init__(self, message):
super().__init__(message)


class BaseConstraint(ABC):
'''Constrain Interface'''

@abstractmethod
def is_met(self, response) -> bool:
pass


class _LoadableConstraint(BaseConstraint):
'''Constraints where value might be stored in VariableStorage needing runtime load.'''

def __init__(self, value, field_type, variable_storage: VariableStorage):
self._variable_storage = variable_storage
# When not none _indirect_value_key is binding a name to the constraint value, and the
# actual value can only be looked-up dynamically, which is why this is a key name.
self._indirect_value_key = None
self._value = None

if value is None:
# Default values set above is all we need here.
return

if isinstance(value, str) and self._variable_storage.is_key_saved(value):
self._indirect_value_key = value
else:
self._value = Converter.convert_yaml_type(
value, field_type)

def get_value(self):
'''Gets the current value of the constraint.

This method accounts for getting the runtime saved value from DUT previous responses.
'''
if self._indirect_value_key:
return self._variable_storage.load(self._indirect_value_key)
return self._value


class _ConstraintHasValue(BaseConstraint):
def __init__(self, has_value):
self._has_value = has_value

def is_met(self, response) -> bool:
raise ConstraintValidationError('HasValue constraint currently not implemented')


class _ConstraintType(BaseConstraint):
def __init__(self, type):
self._type = type

def is_met(self, response) -> bool:
raise ConstraintValidationError('Type constraint currently not implemented')


class _ConstraintStartsWith(BaseConstraint):
def __init__(self, starts_with):
self._starts_with = starts_with

def is_met(self, response) -> bool:
return response.startswith(self._starts_with)


class _ConstraintEndsWith(BaseConstraint):
def __init__(self, ends_with):
self._ends_with = ends_with

def is_met(self, response) -> bool:
return response.endswith(self._ends_with)


class _ConstraintIsUpperCase(BaseConstraint):
def __init__(self, is_upper_case):
self._is_upper_case = is_upper_case

def is_met(self, response) -> bool:
return response.isupper() == self._is_upper_case


class _ConstraintIsLowerCase(BaseConstraint):
def __init__(self, is_lower_case):
self._is_lower_case = is_lower_case

def is_met(self, response) -> bool:
return response.islower() == self._is_lower_case


class _ConstraintMinValue(_LoadableConstraint):
def __init__(self, min_value, field_type, variable_storage: VariableStorage):
super().__init__(min_value, field_type, variable_storage)

def is_met(self, response) -> bool:
min_value = self.get_value()
return response >= min_value


class _ConstraintMaxValue(_LoadableConstraint):
def __init__(self, max_value, field_type, variable_storage: VariableStorage):
super().__init__(max_value, field_type, variable_storage)

def is_met(self, response) -> bool:
max_value = self.get_value()
return response <= max_value


class _ConstraintContains(BaseConstraint):
def __init__(self, contains):
self._contains = contains

def is_met(self, response) -> bool:
return set(self._contains).issubset(response)


class _ConstraintExcludes(BaseConstraint):
def __init__(self, excludes):
self._excludes = excludes

def is_met(self, response) -> bool:
return set(self._excludes).isdisjoint(response)


class _ConstraintHasMaskSet(BaseConstraint):
def __init__(self, has_masks_set):
self._has_masks_set = has_masks_set

def is_met(self, response) -> bool:
return all([(response & mask) == mask for mask in self._has_masks_set])


class _ConstraintHasMaskClear(BaseConstraint):
def __init__(self, has_masks_clear):
self._has_masks_clear = has_masks_clear

def is_met(self, response) -> bool:
return all([(response & mask) == 0 for mask in self._has_masks_clear])


class _ConstraintNotValue(_LoadableConstraint):
def __init__(self, not_value, field_type, variable_storage: VariableStorage):
super().__init__(not_value, field_type, variable_storage)

def is_met(self, response) -> bool:
not_value = self.get_value()
return response != not_value


def get_constraints(constraints, field_type,
variable_storage: VariableStorage) -> list[BaseConstraint]:
_constraints = []
if 'hasValue' in constraints:
_constraints.append(_ConstraintHasValue(constraints.get('hasValue')))

if 'type' in constraints:
_constraints.append(_ConstraintType(constraints.get('type')))

if 'startsWith' in constraints:
_constraints.append(_ConstraintStartsWith(constraints.get('startsWith')))

if 'endsWith' in constraints:
_constraints.append(_ConstraintEndsWith(constraints.get('endsWith')))

if 'isUpperCase' in constraints:
_constraints.append(_ConstraintIsUpperCase(constraints.get('isUpperCase')))

if 'isLowerCase' in constraints:
_constraints.append(_ConstraintIsLowerCase(constraints.get('isLowerCase')))

if 'minValue' in constraints:
_constraints.append(_ConstraintMinValue(
constraints.get('minValue'), field_type, variable_storage))

if 'maxValue' in constraints:
_constraints.append(_ConstraintMaxValue(
constraints.get('maxValue'), field_type, variable_storage))

if 'contains' in constraints:
_constraints.append(_ConstraintContains(constraints.get('contains')))

if 'excludes' in constraints:
_constraints.append(_ConstraintExcludes(constraints.get('excludes')))

if 'hasMasksSet' in constraints:
_constraints.append(_ConstraintHasMaskSet(constraints.get('hasMasksSet')))

if 'hasMasksClear' in constraints:
_constraints.append(_ConstraintHasMaskClear(constraints.get('hasMasksClear')))

if 'notValue' in constraints:
_constraints.append(_ConstraintNotValue(
constraints.get('notValue'), field_type, variable_storage))

return _constraints
116 changes: 6 additions & 110 deletions src/controller/python/chip/yaml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from chip import ChipDeviceCtrl
from chip.clusters.Types import NullValue
from chip.tlv import float32
import yaml
import stringcase
Expand All @@ -30,6 +29,7 @@
from .data_model_lookup import *
import chip.yaml.format_converter as Converter
from .variable_storage import VariableStorage
from .constraints import get_constraints

_SUCCESS_STATUS_CODE = "SUCCESS"
_NODE_ID_DEFAULT = 0x12345
Expand All @@ -50,110 +50,6 @@ class _ExecutionContext:
config_values: dict = None


class _ConstraintValue:
'''Constraints that are numeric primitive data types'''

def __init__(self, value, field_type, context: _ExecutionContext):
self._variable_storage = context.variable_storage
# When not none _indirect_value_key is binding a name to the constraint value, and the
# actual value can only be looked-up dynamically, which is why this is a key name.
self._indirect_value_key = None
self._value = None

if value is None:
# Default values set above is all we need here.
return

if isinstance(value, str) and self._variable_storage.is_key_saved(value):
self._indirect_value_key = value
else:
self._value = Converter.convert_yaml_type(
value, field_type)

def get_value(self):
'''Gets the current value of the constraint.

This method accounts for getting the runtime saved value from DUT previous responses.
'''
if self._indirect_value_key:
return self._variable_storage.load(self._indirect_value_key)
return self._value


class _Constraints:
def __init__(self, constraints: dict, field_type, context: _ExecutionContext):
self._variable_storage = context.variable_storage
self._has_value = constraints.get('hasValue')
self._type = constraints.get('type')
self._starts_with = constraints.get('startsWith')
self._ends_with = constraints.get('endsWith')
self._is_upper_case = constraints.get('isUpperCase')
self._is_lower_case = constraints.get('isLowerCase')
self._min_value = _ConstraintValue(constraints.get('minValue'), field_type,
context)
self._max_value = _ConstraintValue(constraints.get('maxValue'), field_type,
context)
self._contains = constraints.get('contains')
self._excludes = constraints.get('excludes')
self._has_masks_set = constraints.get('hasMasksSet')
self._has_masks_clear = constraints.get('hasMasksClear')
self._not_value = _ConstraintValue(constraints.get('notValue'), field_type,
context)

def are_constrains_met(self, response) -> bool:
return_value = True

if self._has_value:
logger.warn(f'HasValue constraint currently not implemented, forcing failure')
return_value = False

if self._type:
logger.warn(f'Type constraint currently not implemented, forcing failure')
return_value = False

if self._starts_with and not response.startswith(self._starts_with):
return_value = False

if self._ends_with and not response.endswith(self._ends_with):
return_value = False

if self._is_upper_case and not response.isupper():
return_value = False

if self._is_lower_case and not response.islower():
return_value = False

min_value = self._min_value.get_value()
if response is not NullValue and min_value and response < min_value:
return_value = False

max_value = self._max_value.get_value()
if response is not NullValue and max_value and response > max_value:
return_value = False

if self._contains and not set(self._contains).issubset(response):
return_value = False

if self._excludes and not set(self._excludes).isdisjoint(response):
return_value = False

if self._has_masks_set:
for mask in self._has_masks_set:
if (response & mask) != mask:
return_value = False

if self._has_masks_clear:
for mask in self._has_masks_clear:
if (response & mask) != 0:
return_value = False

not_value = self._not_value.get_value()
if not_value and response == not_value:
return_value = False

return return_value


class _VariableToSave:
def __init__(self, variable_name: str, variable_storage: VariableStorage):
self._variable_name = variable_name
Expand Down Expand Up @@ -311,7 +207,7 @@ def __init__(self, item: dict, cluster: str, context: _ExecutionContext):
'''
super().__init__(item['label'])
self._attribute_name = stringcase.pascalcase(item['attribute'])
self._constraints = None
self._constraints = []
self._cluster = cluster
self._cluster_object = None
self._request_object = None
Expand Down Expand Up @@ -362,9 +258,9 @@ def __init__(self, item: dict, cluster: str, context: _ExecutionContext):

constraints = self._expected_raw_response.get('constraints')
if constraints:
self._constraints = _Constraints(constraints,
self._request_object.attribute_type.Type,
context)
self._constraints = get_constraints(constraints,
self._request_object.attribute_type.Type,
context.variable_storage)

def run_action(self, dev_ctrl: ChipDeviceCtrl, endpoint: int, node_id: int):
try:
Expand All @@ -391,7 +287,7 @@ def run_action(self, dev_ctrl: ChipDeviceCtrl, endpoint: int, node_id: int):
if self._variable_to_save is not None:
self._variable_to_save.save_response(parsed_resp)

if self._constraints and not self._constraints.are_constrains_met(parsed_resp):
if not all([constraint.is_met(parsed_resp) for constraint in self._constraints]):
logger.error(f'Constraints check failed')
# TODO how should we fail the test here?

Expand Down