From 37c0a12c19cf76776c45029a90de37c58a77ab18 Mon Sep 17 00:00:00 2001
From: Cecille Freeman
Date: Wed, 17 Jul 2024 21:34:55 -0400
Subject: [PATCH] [WIP] python test framework PICS 2.0

---
 .../matter_yamltests/hooks.py                 |   6 +
 src/python_testing/TC_TIMESYNC_2_1.py         |  54 ++--
 src/python_testing/matter_testing_support.py  | 100 ++++++-
 .../test_testing/MockTestRunner.py            |  13 +-
 .../test_testing/TestDecorators.py            | 280 ++++++++++++++++++
 5 files changed, 410 insertions(+), 43 deletions(-)
 create mode 100644 src/python_testing/test_testing/TestDecorators.py

diff --git a/scripts/py_matter_yamltests/matter_yamltests/hooks.py b/scripts/py_matter_yamltests/matter_yamltests/hooks.py
index 3d25cfff9c06c1..78905826f55757 100644
--- a/scripts/py_matter_yamltests/matter_yamltests/hooks.py
+++ b/scripts/py_matter_yamltests/matter_yamltests/hooks.py
@@ -227,6 +227,12 @@ def show_prompt(self,
         """
         pass
 
+    def test_skipped(self, filename: str, name: str):
+        """
+        This method is called when the test script determines that the test is not applicable for the DUT.
+        """
+        pass
+
 
 class WebSocketRunnerHooks():
     def connecting(self, url: str):
diff --git a/src/python_testing/TC_TIMESYNC_2_1.py b/src/python_testing/TC_TIMESYNC_2_1.py
index febaea6f1e9d49..919bce46cab38d 100644
--- a/src/python_testing/TC_TIMESYNC_2_1.py
+++ b/src/python_testing/TC_TIMESYNC_2_1.py
@@ -27,53 +27,45 @@
 import chip.clusters as Clusters
 from chip.clusters.Types import NullValue
-from matter_testing_support import MatterBaseTest, async_test_body, default_matter_test_main, utc_time_in_matter_epoch
+from matter_testing_support import MatterBaseTest, default_matter_test_main, utc_time_in_matter_epoch, per_endpoint_test, has_cluster, has_attribute
 from mobly import asserts
 
 
 class TC_TIMESYNC_2_1(MatterBaseTest):
-    async def read_ts_attribute_expect_success(self, endpoint, attribute):
+    async def read_ts_attribute_expect_success(self, attribute):
         cluster = Clusters.Objects.TimeSynchronization
-        return await self.read_single_attribute_check_success(endpoint=endpoint, cluster=cluster, attribute=attribute)
+        return await self.read_single_attribute_check_success(endpoint=None, cluster=cluster, attribute=attribute)
 
-    def pics_TC_TIMESYNC_2_1(self) -> list[str]:
-        return ["TIMESYNC.S"]
-
-    @async_test_body
+    @per_endpoint_test(has_cluster(Clusters.TimeSynchronization) and has_attribute(Clusters.TimeSynchronization.Attributes.TimeSource))
     async def test_TC_TIMESYNC_2_1(self):
-        endpoint = 0
-
-        features = await self.read_single_attribute(dev_ctrl=self.default_controller, node_id=self.dut_node_id,
-                                                     endpoint=endpoint, attribute=Clusters.TimeSynchronization.Attributes.FeatureMap)
+        attributes = Clusters.TimeSynchronization.Attributes
+        features = await self.read_ts_attribute_expect_success(attribute=attributes.FeatureMap)
 
         self.supports_time_zone = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kTimeZone)
         self.supports_ntpc = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kNTPClient)
         self.supports_ntps = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kNTPServer)
         self.supports_trusted_time_source = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kTimeSyncClient)
 
-        time_cluster = Clusters.TimeSynchronization
-        timesync_attr_list = time_cluster.Attributes.AttributeList
-        attribute_list = await self.read_single_attribute_check_success(endpoint=endpoint, cluster=time_cluster, attribute=timesync_attr_list)
-        timesource_attr_id = time_cluster.Attributes.TimeSource.attribute_id
+        timesync_attr_list = attributes.AttributeList
+        attribute_list = await self.read_ts_attribute_expect_success(attribute=timesync_attr_list)
+        timesource_attr_id = attributes.TimeSource.attribute_id
 
         self.print_step(1, "Commissioning, already done")
-        attributes = Clusters.TimeSynchronization.Attributes
 
         self.print_step(2, "Read Granularity attribute")
-        granularity_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.Granularity)
+        granularity_dut = await self.read_ts_attribute_expect_success(attribute=attributes.Granularity)
         asserts.assert_less(granularity_dut, Clusters.TimeSynchronization.Enums.GranularityEnum.kUnknownEnumValue,
                             "Granularity is not in valid range")
 
         self.print_step(3, "Read TimeSource")
         if timesource_attr_id in attribute_list:
-            time_source = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeSource)
+            time_source = await self.read_ts_attribute_expect_success(attribute=attributes.TimeSource)
             asserts.assert_less(time_source, Clusters.TimeSynchronization.Enums.TimeSourceEnum.kUnknownEnumValue,
                                 "TimeSource is not in valid range")
 
         self.print_step(4, "Read TrustedTimeSource")
         if self.supports_trusted_time_source:
-            trusted_time_source = await self.read_ts_attribute_expect_success(endpoint=endpoint,
-                                                                              attribute=attributes.TrustedTimeSource)
+            trusted_time_source = await self.read_ts_attribute_expect_success(attribute=attributes.TrustedTimeSource)
             if trusted_time_source is not NullValue:
                 asserts.assert_less_equal(trusted_time_source.fabricIndex, 0xFE,
                                           "FabricIndex for the TrustedTimeSource is out of range")
@@ -82,7 +74,7 @@ async def test_TC_TIMESYNC_2_1(self):
 
         self.print_step(5, "Read DefaultNTP")
         if self.supports_ntpc:
-            default_ntp = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DefaultNTP)
+            default_ntp = await self.read_ts_attribute_expect_success(attribute=attributes.DefaultNTP)
             if default_ntp is not NullValue:
                 asserts.assert_less_equal(len(default_ntp), 128, "DefaultNTP length must be less than 128")
                 # Assume this is a valid web address if it has at least one . in the name
@@ -97,7 +89,7 @@ async def test_TC_TIMESYNC_2_1(self):
 
         self.print_step(6, "Read TimeZone")
         if self.supports_time_zone:
-            tz_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZone)
+            tz_dut = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZone)
             asserts.assert_greater_equal(len(tz_dut), 1, "TimeZone must have at least one entry in the list")
             asserts.assert_less_equal(len(tz_dut), 2, "TimeZone may have a maximum of two entries in the list")
             for entry in tz_dut:
@@ -112,7 +104,7 @@ async def test_TC_TIMESYNC_2_1(self):
 
         self.print_step(7, "Read DSTOffset")
         if self.supports_time_zone:
-            dst_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DSTOffset)
+            dst_dut = await self.read_ts_attribute_expect_success(attribute=attributes.DSTOffset)
             last_valid_until = -1
             last_valid_starting = -1
             for dst in dst_dut:
@@ -126,7 +118,7 @@ async def test_TC_TIMESYNC_2_1(self):
                 asserts.assert_equal(dst, dst_dut[-1], "DSTOffset list must have Null ValidUntil at the end")
 
         self.print_step(8, "Read UTCTime")
-        utc_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.UTCTime)
+        utc_dut = await self.read_ts_attribute_expect_success(attribute=attributes.UTCTime)
         if utc_dut is NullValue:
             asserts.assert_equal(granularity_dut, Clusters.TimeSynchronization.Enums.GranularityEnum.kNoTimeGranularity)
         else:
@@ -141,8 +133,8 @@ async def test_TC_TIMESYNC_2_1(self):
 
         self.print_step(9, "Read LocalTime")
         if self.supports_time_zone:
-            utc_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.UTCTime)
-            local_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.LocalTime)
+            utc_dut = await self.read_ts_attribute_expect_success(attribute=attributes.UTCTime)
+            local_dut = await self.read_ts_attribute_expect_success(attribute=attributes.LocalTime)
             if utc_dut is NullValue:
                 asserts.assert_true(local_dut is NullValue, "LocalTime must be Null if UTC time is Null")
             elif len(dst_dut) == 0:
@@ -156,30 +148,30 @@ async def test_TC_TIMESYNC_2_1(self):
 
         self.print_step(10, "Read TimeZoneDatabase")
         if self.supports_time_zone:
-            tz_db_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZoneDatabase)
+            tz_db_dut = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZoneDatabase)
             asserts.assert_less(tz_db_dut, Clusters.TimeSynchronization.Enums.TimeZoneDatabaseEnum.kUnknownEnumValue,
                                 "TimeZoneDatabase is not in valid range")
 
         self.print_step(11, "Read NTPServerAvailable")
         if self.supports_ntps:
             # bool typechecking happens in the test read functions, so all we need to do here is do the read
-            await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.NTPServerAvailable)
+            await self.read_ts_attribute_expect_success(attribute=attributes.NTPServerAvailable)
 
         self.print_step(12, "Read TimeZoneListMaxSize")
         if self.supports_time_zone:
-            size = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZoneListMaxSize)
+            size = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZoneListMaxSize)
             asserts.assert_greater_equal(size, 1, "TimeZoneListMaxSize must be at least 1")
             asserts.assert_less_equal(size, 2, "TimeZoneListMaxSize must be max 2")
 
         self.print_step(13, "Read DSTOffsetListMaxSize")
        if self.supports_time_zone:
-            size = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DSTOffsetListMaxSize)
+            size = await self.read_ts_attribute_expect_success(attribute=attributes.DSTOffsetListMaxSize)
             asserts.assert_greater_equal(size, 1, "DSTOffsetListMaxSize must be at least 1")
 
         self.print_step(14, "Read SupportsDNSResolve")
         # bool typechecking happens in the test read functions, so all we need to do here is do the read
         if self.supports_ntpc:
-            await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.SupportsDNSResolve)
+            await self.read_ts_attribute_expect_success(attribute=attributes.SupportsDNSResolve)
 
 
 if __name__ == "__main__":
diff --git a/src/python_testing/matter_testing_support.py b/src/python_testing/matter_testing_support.py
index b8854e23a7d2ee..80678946303927 100644
--- a/src/python_testing/matter_testing_support.py
+++ b/src/python_testing/matter_testing_support.py
@@ -34,6 +34,7 @@
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
 from enum import Enum
+from functools import partial
 from typing import List, Optional, Tuple
 
 from chip.tlv import float32, uint
@@ -339,6 +340,8 @@ def show_prompt(self,
                     placeholder: Optional[str] = None,
                     default_value: Optional[str] = None) -> None:
         pass
 
+    def test_skipped(self, filename: str, name: str):
+        logging.info(f"Skipping test from {filename}: {name}")
 
 
 @dataclass
@@ -771,8 +774,10 @@ def setup_class(self):
 
     def setup_test(self):
         self.current_step_index = 0
+        self.test_start_time = datetime.now(timezone.utc)
         self.step_start_time = datetime.now(timezone.utc)
         self.step_skipped = False
+        self.failed = False
         if self.runner_hook and not self.is_commissioning:
             test_name = self.current_test_info.name
             steps = self.get_defined_test_steps(test_name)
@@ -949,12 +954,11 @@ def on_fail(self, record):
 
             record is of type TestResultRecord
         '''
+        self.failed = True
         if self.runner_hook and not self.is_commissioning:
             exception = record.termination_signal.exception
             step_duration = (datetime.now(timezone.utc) - self.step_start_time) / timedelta(microseconds=1)
-            # This isn't QUITE the test duration because the commissioning is handled separately, but it's clsoe enough for now
-            # This is already given in milliseconds
-            test_duration = record.end_time - record.begin_time
+            test_duration = datetime.now(timezone.utc) - self.test_start_time
             # TODO: I have no idea what logger, logs, request or received are. Hope None works because I have nothing to give
             self.runner_hook.step_failure(logger=None, logs=None, duration=step_duration, request=None, received=None)
             self.runner_hook.test_stop(exception=exception, duration=test_duration)
@@ -968,7 +972,7 @@ def on_pass(self, record):
             # What is request? This seems like an implementation detail for the runner
             # TODO: As with failure, I have no idea what logger, logs or request are meant to be
             step_duration = (datetime.now(timezone.utc) - self.step_start_time) / timedelta(microseconds=1)
-            test_duration = record.end_time - record.begin_time
+            test_duration = datetime.now(timezone.utc) - self.test_start_time
             self.runner_hook.step_success(logger=None, logs=None, duration=step_duration, request=None)
 
         # TODO: this check could easily be annoying when doing dev. flag it somehow? Ditto with the in-order check
@@ -986,6 +990,18 @@ def on_pass(self, record):
         if self.runner_hook and not self.is_commissioning:
             self.runner_hook.test_stop(exception=None, duration=test_duration)
 
+    def on_skip(self, record):
+        ''' Called by Mobly on test skip
+
+            record is of type TestResultRecord
+        '''
+        if self.runner_hook and not self.is_commissioning:
+            test_duration = record.end_time - record.begin_time
+            test_name = self.current_test_info.name
+            filename = inspect.getfile(self.__class__)
+            self.runner_hook.test_skipped(filename, test_name)
+            self.runner_hook.test_stop(exception=None, duration=test_duration)
+
     def pics_guard(self, pics_condition: bool):
         """Checks a condition and if False marks the test step as skipped and
            returns False, otherwise returns True.
@@ -1531,6 +1547,10 @@ def parse_matter_test_args(argv: Optional[List[str]] = None) -> MatterTestConfi
 
     return convert_args_to_matter_config(parser.parse_known_args(argv)[0])
 
 
+def _async_runner(body, self: MatterBaseTest, *args, **kwargs):
+    timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
+    runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
+    return asyncio.run(runner_with_timeout)
 def async_test_body(body):
     """Decorator required to be applied whenever a `test_*` method is `async def`.
@@ -1541,12 +1561,78 @@ def async_test_body(body):
     def async_runner(self: MatterBaseTest, *args, **kwargs):
-        timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
-        runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
-        return asyncio.run(runner_with_timeout)
+        return _async_runner(body, self, *args, **kwargs)
 
     return async_runner
 
+
+def per_node_test(body):
+    """ Decorator to be used for PICS-free tests that apply to the entire node.
+
+        Use this decorator when your script needs to be run once to validate the whole node.
+        To use this decorator, the test must NOT have an associated pics_ method.
+    """
+    def whole_node_runner(self: MatterBaseTest, *args, **kwargs):
+        asserts.assert_false(self.get_test_pics(self.current_test_info.name), "pics_ method supplied for per_node_test.")
+        return _async_runner(body, self, *args, **kwargs)
+
+    return whole_node_runner
+
+
+EndpointCheckFunction = typing.Callable[[Clusters.Attribute.AsyncReadTransaction.ReadResponse, int], bool]
+
+
+def _has_cluster(wildcard, endpoint, cluster: ClusterObjects.Cluster) -> bool:
+    try:
+        return cluster in wildcard.attributes[endpoint]
+    except KeyError:
+        return False
+
+
+def has_cluster(cluster: ClusterObjects.ClusterObjectDescriptor) -> EndpointCheckFunction:
+    return partial(_has_cluster, cluster=cluster)
+
+
+def _has_attribute(wildcard, endpoint, attribute: ClusterObjects.ClusterAttributeDescriptor) -> bool:
+    cluster = getattr(Clusters, attribute.__qualname__.split('.')[-3])
+    try:
+        attr_list = wildcard.attributes[endpoint][cluster][cluster.Attributes.AttributeList]
+        return attribute.attribute_id in attr_list
+    except KeyError:
+        return False
+
+
+def has_attribute(attribute: ClusterObjects.ClusterAttributeDescriptor) -> EndpointCheckFunction:
+    return partial(_has_attribute, attribute=attribute)
+
+
+async def get_accepted_endpoints_for_test(self: MatterBaseTest, accept_function: EndpointCheckFunction):
+    wildcard = await self.default_controller.Read(self.dut_node_id, [()])
+    return [e for e in wildcard.attributes.keys() if accept_function(wildcard, e)]
+
+
+def per_endpoint_test(accept_function):
+    def per_endpoint_test_internal(body):
+        def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
+            asserts.assert_false(self.get_test_pics(self.current_test_info.name), "pics_ method supplied for per_endpoint_test.")
+            runner_with_timeout = asyncio.wait_for(get_accepted_endpoints_for_test(self, accept_function), timeout=5)
+            endpoints = asyncio.run(runner_with_timeout)
+            if not endpoints:
+                logging.info("No matching endpoints found - skipping test")
+                asserts.skip('No endpoints match requirements')
+                return
+            logging.info(f"Running test on the following endpoints: {endpoints}")
+            # setup_class is meant to be called once, but setup_test is expected to be run before
+            # each iteration. Mobly will run it for us the first time, but since we're running this
+            # more than one time, we want to make sure we reset everything as expected.
+            # Ditto for teardown - we want to tear down after each iteration, and we want to notify the hook that
+            # the test iteration is stopped. test_stop is called by on_pass or on_fail during the last iteration or
+            # on failure.
+            for e in endpoints:
+                logging.info(f'Running test on endpoint {e}')
+                if e != endpoints[0]:
+                    self.setup_test()
+                self.matter_test_config.endpoint = e
+                _async_runner(body, self, *args, **kwargs)
+                if e != endpoints[-1] and not self.failed:
+                    self.teardown_test()
+                    self.runner_hook.test_stop(exception=None, duration=datetime.now(timezone.utc) - self.test_start_time)
+
+        return per_endpoint_runner
+    return per_endpoint_test_internal
+
+
 class CommissionDeviceTest(MatterBaseTest):
     """Test class auto-injected at the start of test list to commission a device when requested"""
 
diff --git a/src/python_testing/test_testing/MockTestRunner.py b/src/python_testing/test_testing/MockTestRunner.py
index 5d6592b5a8cd41..58497c7004c7e4 100644
--- a/src/python_testing/test_testing/MockTestRunner.py
+++ b/src/python_testing/test_testing/MockTestRunner.py
@@ -37,21 +37,24 @@ async def __call__(self, *args, **kwargs):
 
 
 class MockTestRunner():
-    def __init__(self, filename: str, classname: str, test: str, endpoint: int, pics: dict[str, bool] = {}):
-        self.config = MatterTestConfig(
-            tests=[test], endpoint=endpoint, dut_node_ids=[1], pics=pics)
+    def __init__(self, filename: str, classname: str, test: str, endpoint: int = 0, pics: dict[str, bool] = {}):
+        self.set_test(filename, classname, test, endpoint, pics)
         self.stack = MatterStackState(self.config)
         self.default_controller = self.stack.certificate_authorities[0].adminList[0].NewController(
             nodeId=self.config.controller_node_id,
             paaTrustStorePath=str(self.config.paa_trust_store_path),
             catTags=self.config.controller_cat_tags
         )
+
+    def set_test(self, filename: str, classname: str, test: str, endpoint: int = 0, pics: dict[str, bool] = {}):
+        self.config = MatterTestConfig(
+            tests=[test], endpoint=endpoint, dut_node_ids=[1], pics=pics)
         module = importlib.import_module(Path(os.path.basename(filename)).stem)
         self.test_class = getattr(module, classname)
 
     def Shutdown(self):
         self.stack.Shutdown()
 
-    def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.ReadResponse):
+    def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.ReadResponse, hooks=None):
         self.default_controller.Read = AsyncMock(return_value=read_cache)
-        return run_tests_no_exit(self.test_class, self.config, None, self.default_controller, self.stack)
+        return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
diff --git a/src/python_testing/test_testing/TestDecorators.py b/src/python_testing/test_testing/TestDecorators.py
new file mode 100644
index 00000000000000..3ea75b3b10c595
--- /dev/null
+++ b/src/python_testing/test_testing/TestDecorators.py
@@ -0,0 +1,280 @@
+#
+#    Copyright (c) 2024 Project CHIP Authors
+#    All rights reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License");
+#    you may not use this file except in compliance with the License.
+#    You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS,
+#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#    See the License for the specific language governing permissions and
+#    limitations under the License.
+#
+
+
+# Hooks:
+# If this is a per-endpoint test:
+# - If test is run, hook object will get one test_start and one test_stop call per endpoint on which the test is run
+# - If the test is skipped, hook object will get test_start, test_skipped, test_stop
+# If this is a whole-node test:
+# - test will always be run, so hook object will get test_start, test_stop
+#
+# You will get step_* calls as appropriate in between the test_start and test_stop calls if the test is not skipped.
+
+import chip.clusters as Clusters
+import os
+import sys
+from chip.clusters import ClusterObjects as ClusterObjects
+from chip.clusters import Attribute
+
+try:
+    from matter_testing_support import MatterBaseTest, async_test_body, has_cluster, has_attribute, get_accepted_endpoints_for_test, per_endpoint_test, per_node_test
+except ImportError:
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+    from matter_testing_support import MatterBaseTest, async_test_body, has_cluster, has_attribute, get_accepted_endpoints_for_test, per_endpoint_test, per_node_test
+
+from MockTestRunner import MockTestRunner
+from typing import Optional
+from mobly import asserts
+
+
+def get_clusters(endpoints: list[int]) -> Attribute.AsyncReadTransaction.ReadResponse:
+    c = Clusters.OnOff
+    attr = c.Attributes
+    # We're JUST populating the globals here because that's all that matters for this particular test
+    feature_map = c.Bitmaps.Feature.kLighting
+    # Only supported attributes - globals and OnOff. This isn't a compliant device. Doesn't matter for this test.
+    attribute_list = [attr.FeatureMap.attribute_id, attr.AttributeList.attribute_id, attr.AcceptedCommandList.attribute_id, attr.GeneratedCommandList.attribute_id, attr.OnOff.attribute_id]
+    accepted_commands = [c.Commands.Off, c.Commands.On]
+    resp = Attribute.AsyncReadTransaction.ReadResponse({}, [], {})
+    for e in endpoints:
+        resp.attributes[e] = {c: {attr.FeatureMap: feature_map, attr.AttributeList: attribute_list, attr.AcceptedCommandList: accepted_commands}}
+    return resp
+
+
+class DecoratorTestRunnerHooks:
+    def __init__(self):
+        self.started = []
+        self.skipped = []
+        self.stopped = 0
+
+    def start(self, count: int):
+        pass
+
+    def stop(self, duration: int):
+        pass
+
+    def test_start(self, filename: str, name: str, count: int, steps: list[str] = []):
+        self.started.append(name)
+
+    def test_skipped(self, filename: str, name: str):
+        self.skipped.append(name)
+
+    def test_stop(self, exception: Exception, duration: int):
+        self.stopped += 1
+
+    def step_skipped(self, name: str, expression: str):
+        pass
+
+    def step_start(self, name: str):
+        pass
+
+    def step_success(self, logger, logs, duration: int, request):
+        pass
+
+    def step_failure(self, logger, logs, duration: int, request, received):
+        pass
+
+    def step_unknown(self):
+        pass
+
+    def show_prompt(self,
+                    msg: str,
+                    placeholder: Optional[str] = None,
+                    default_value: Optional[str] = None) -> None:
+        pass
+
+
+class TestDecorators(MatterBaseTest):
+    def test_checkers(self):
+        has_onoff = has_cluster(Clusters.OnOff)
+        has_onoff_onoff = has_attribute(Clusters.OnOff.Attributes.OnOff)
+        has_onoff_ontime = has_attribute(Clusters.OnOff.Attributes.OnTime)
+        has_timesync = has_cluster(Clusters.TimeSynchronization)
+        has_timesync_utc = has_attribute(Clusters.TimeSynchronization.Attributes.UTCTime)
+
+        wildcard = get_clusters([0, 1])
+
+        def check_endpoints(f, expect_true, expectation: str):
+            asserts.assert_equal(f(wildcard, 0), expect_true, f"Expected {expectation} == {expect_true} on EP0")
+            asserts.assert_equal(f(wildcard, 1), expect_true, f"Expected {expectation} == {expect_true} on EP1")
+            asserts.assert_false(f(wildcard, 2), f"Expected {expectation} == False on EP2")
+
+        check_endpoints(has_onoff, True, "OnOff Cluster")
+        check_endpoints(has_onoff_onoff, True, "OnOff attribute")
+        check_endpoints(has_onoff_ontime, False, "OnTime attribute")
+        check_endpoints(has_timesync, False, "TimeSynchronization Cluster")
+        check_endpoints(has_timesync_utc, False, "UTC attribute")
+
+    @async_test_body
+    async def test_endpoints(self):
+        has_onoff = has_cluster(Clusters.OnOff)
+        has_onoff_onoff = has_attribute(Clusters.OnOff.Attributes.OnOff)
+        has_onoff_ontime = has_attribute(Clusters.OnOff.Attributes.OnTime)
+        has_timesync = has_cluster(Clusters.TimeSynchronization)
+        has_timesync_utc = has_attribute(Clusters.TimeSynchronization.Attributes.UTCTime)
+
+        all_endpoints = await self.default_controller.Read(self.dut_node_id, [()])
+        all_endpoints = list(all_endpoints.attributes.keys())
+
+        msg = "Unexpected endpoint list returned"
+
+        endpoints = await get_accepted_endpoints_for_test(self, has_onoff)
+        asserts.assert_equal(endpoints, all_endpoints, msg)
+
+        endpoints = await get_accepted_endpoints_for_test(self, has_onoff_onoff)
+        asserts.assert_equal(endpoints, all_endpoints, msg)
+
+        endpoints = await get_accepted_endpoints_for_test(self, has_onoff_ontime)
+        asserts.assert_equal(endpoints, [], msg)
+
+        endpoints = await get_accepted_endpoints_for_test(self, has_timesync)
+        asserts.assert_equal(endpoints, [], msg)
+
+        endpoints = await get_accepted_endpoints_for_test(self, has_timesync_utc)
+        asserts.assert_equal(endpoints, [], msg)
+
+    # This test should cause an assertion because it has a pics_ method
+    @per_node_test
+    async def test_whole_node_with_pics(self):
+        pass
+
+    def pics_whole_node_with_pics(self):
+        return ['EXAMPLE.S']
+
+    # This test should cause an assertion because it has a pics_ method
+    @per_endpoint_test(has_cluster(Clusters.OnOff))
+    async def test_per_endpoint_with_pics(self):
+        pass
+
+    def pics_per_endpoint_with_pics(self):
+        return ['EXAMPLE.S']
+
+    # This test should be run once
+    @per_node_test
+    async def test_whole_node_ok(self):
+        pass
+
+    # This test should be run once per endpoint
+    @per_endpoint_test(has_cluster(Clusters.OnOff))
+    async def test_endpoint_cluster_yes(self):
+        pass
+
+    # This test should be skipped since this cluster isn't on any endpoint
+    @per_endpoint_test(has_cluster(Clusters.TimeSynchronization))
+    async def test_endpoint_cluster_no(self):
+        pass
+
+    # This test should be run once per endpoint
+    @per_endpoint_test(has_attribute(Clusters.OnOff.Attributes.OnOff))
+    async def test_endpoint_attribute_yes(self):
+        pass
+
+    # This test should be skipped since this attribute isn't on the supported cluster
+    @per_endpoint_test(has_attribute(Clusters.OnOff.Attributes.OffWaitTime))
+    async def test_endpoint_attribute_supported_cluster_no(self):
+        pass
+
+    # This test should be skipped since this attribute is part of an unsupported cluster
+    @per_endpoint_test(has_attribute(Clusters.TimeSynchronization.Attributes.Granularity))
+    async def test_endpoint_attribute_unsupported_cluster_no(self):
+        pass
+
+    # This test should be run since both are present
+    @per_endpoint_test(has_attribute(Clusters.OnOff.Attributes.OnOff) and has_cluster(Clusters.OnOff))
+    async def test_endpoint_boolean_yes(self):
+        pass
+
+    # This test should be skipped since we have an OnOff cluster, but no Time sync
+    @per_endpoint_test(has_cluster(Clusters.OnOff) and has_cluster(Clusters.TimeSynchronization))
+    async def test_endpoint_boolean_no(self):
+        pass
+
+    # TODO: add test for assertions on whole node, and on each endpoint of an endpoint test
+
+
+def main():
+    failures = []
+    hooks = DecoratorTestRunnerHooks()
+    test_runner = MockTestRunner('TestDecorators.py', 'TestDecorators', 'test_checkers')
+    read_resp = get_clusters([0, 1])
+    ok = test_runner.run_test_with_mock_read(read_resp, hooks)
+    if not ok:
+        failures.append("Test case failure: test_checkers")
+
+    test_runner.set_test('TestDecorators.py', 'TestDecorators', 'test_endpoints')
+    read_resp = get_clusters([0, 1])
+    ok = test_runner.run_test_with_mock_read(read_resp, hooks)
+    if not ok:
+        failures.append("Test case failure: test_endpoints")
+
+    read_resp = get_clusters([0])
+    ok = test_runner.run_test_with_mock_read(read_resp, hooks)
+    if not ok:
+        failures.append("Test case failure: test_endpoints")
+
+    test_name = 'test_whole_node_with_pics'
+    test_runner.set_test('TestDecorators.py', 'TestDecorators', test_name)
+    ok = test_runner.run_test_with_mock_read(read_resp, hooks)
+    if ok:
+        failures.append(f"Did not get expected test assertion on {test_name}")
+
+    test_name = 'test_per_endpoint_with_pics'
+    test_runner.set_test('TestDecorators.py', 'TestDecorators', test_name)
+    ok = test_runner.run_test_with_mock_read(read_resp, hooks)
+    if ok:
+        failures.append(f"Did not get expected test assertion on {test_name}")
+
+    # Test should run once for the whole node, regardless of the number of endpoints
+    def run_check(test_name: str, read_response: Attribute.AsyncReadTransaction.ReadResponse, expected_runs: int, expect_skip: bool) -> None:
+        nonlocal failures
+        test_runner.set_test('TestDecorators.py', 'TestDecorators', test_name)
+        hooks = DecoratorTestRunnerHooks()
+        ok = test_runner.run_test_with_mock_read(read_response, hooks)
+        started_ok = len(hooks.started) == expected_runs
+        skipped_ok = (hooks.skipped != []) == expect_skip
+        stopped_ok = hooks.stopped == expected_runs
+        if not ok or not started_ok or not skipped_ok or not stopped_ok:
+            failures.append(
+                f'Expected {expected_runs} run of {test_name}, skips expected: {expect_skip}. Runs: {hooks.started}, skips: {hooks.skipped} stops: {hooks.stopped}')
+
+    def check_once_per_node(test_name: str):
+        run_check(test_name, get_clusters([0]), 1, False)
+        run_check(test_name, get_clusters([0, 1]), 1, False)
+
+    def check_once_per_endpoint(test_name: str):
+        run_check(test_name, get_clusters([0]), 1, False)
+        run_check(test_name, get_clusters([0, 1]), 2, False)
+
+    def check_skipped(test_name: str):
+        run_check(test_name, get_clusters([0]), 1, True)
+        run_check(test_name, get_clusters([0, 1]), 1, True)
+
+    check_once_per_node('test_whole_node_ok')
+    check_once_per_endpoint('test_endpoint_cluster_yes')
+    check_skipped('test_endpoint_cluster_no')
+    check_once_per_endpoint('test_endpoint_attribute_yes')
+    check_skipped('test_endpoint_attribute_supported_cluster_no')
+    check_skipped('test_endpoint_attribute_unsupported_cluster_no')
+    check_once_per_endpoint('test_endpoint_boolean_yes')
+    check_skipped('test_endpoint_boolean_no')
+
+    test_runner.Shutdown()
+    print(
+        f"Test of Decorators: test response incorrect: {len(failures)}")
+    for f in failures:
+        print(f)
+
+    return 1 if failures else 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
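
Usage sketch (not part of the patch): the snippet below illustrates how a test script is expected to pick up the new decorators from matter_testing_support.py, following the same pattern as the TC_TIMESYNC_2_1.py changes above. The class name, test names, and the OnOff cluster/attribute choices are hypothetical and for illustration only.

    import chip.clusters as Clusters
    from matter_testing_support import (MatterBaseTest, default_matter_test_main, has_attribute,
                                        has_cluster, per_endpoint_test, per_node_test)


    class TC_EXAMPLE(MatterBaseTest):
        # Runs once on every endpoint whose OnOff cluster lists the OnTime attribute; if no
        # endpoint matches, the runner hooks see test_start, test_skipped, test_stop and the
        # test is skipped.
        @per_endpoint_test(has_attribute(Clusters.OnOff.Attributes.OnTime))
        async def test_TC_EXAMPLE_1_1(self):
            # The decorator sets self.matter_test_config.endpoint to the endpoint under test,
            # so reads pass endpoint=None instead of a hard-coded endpoint number.
            await self.read_single_attribute_check_success(endpoint=None, cluster=Clusters.OnOff,
                                                           attribute=Clusters.OnOff.Attributes.OnTime)

        # Runs exactly once against the whole node. Tests using either decorator must NOT
        # supply a pics_ method; the decorators assert on this.
        @per_node_test
        async def test_TC_EXAMPLE_1_2(self):
            pass


    if __name__ == "__main__":
        default_matter_test_main()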