Skip to content

Commit

Permalink
[WIP] python test framework PICS 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
cecille committed Jul 18, 2024
1 parent 024b09b commit 37c0a12
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 43 deletions.
6 changes: 6 additions & 0 deletions scripts/py_matter_yamltests/matter_yamltests/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ def show_prompt(self,
"""
pass

def test_skipped(self, filename: str, name: str):
"""
This method is called when the test script determines that the test is not applicable for the DUT.
"""
pass


class WebSocketRunnerHooks():
def connecting(self, url: str):
Expand Down
54 changes: 23 additions & 31 deletions src/python_testing/TC_TIMESYNC_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,53 +27,45 @@

import chip.clusters as Clusters
from chip.clusters.Types import NullValue
from matter_testing_support import MatterBaseTest, async_test_body, default_matter_test_main, utc_time_in_matter_epoch
from matter_testing_support import MatterBaseTest, default_matter_test_main, utc_time_in_matter_epoch, per_endpoint_test, has_cluster, has_attribute
from mobly import asserts


class TC_TIMESYNC_2_1(MatterBaseTest):
async def read_ts_attribute_expect_success(self, endpoint, attribute):
async def read_ts_attribute_expect_success(self, attribute):
cluster = Clusters.Objects.TimeSynchronization
return await self.read_single_attribute_check_success(endpoint=endpoint, cluster=cluster, attribute=attribute)
return await self.read_single_attribute_check_success(endpoint=None, cluster=cluster, attribute=attribute)

def pics_TC_TIMESYNC_2_1(self) -> list[str]:
return ["TIMESYNC.S"]

@async_test_body
@per_endpoint_test(has_cluster(Clusters.TimeSynchronization) and has_attribute(Clusters.TimeSynchronization.Attributes.TimeSource))
async def test_TC_TIMESYNC_2_1(self):
endpoint = 0

features = await self.read_single_attribute(dev_ctrl=self.default_controller, node_id=self.dut_node_id,
endpoint=endpoint, attribute=Clusters.TimeSynchronization.Attributes.FeatureMap)
attributes = Clusters.TimeSynchronization.Attributes
features = await self.read_ts_attribute_expect_success(attribute=attributes.FeatureMap)

self.supports_time_zone = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kTimeZone)
self.supports_ntpc = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kNTPClient)
self.supports_ntps = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kNTPServer)
self.supports_trusted_time_source = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kTimeSyncClient)

time_cluster = Clusters.TimeSynchronization
timesync_attr_list = time_cluster.Attributes.AttributeList
attribute_list = await self.read_single_attribute_check_success(endpoint=endpoint, cluster=time_cluster, attribute=timesync_attr_list)
timesource_attr_id = time_cluster.Attributes.TimeSource.attribute_id
timesync_attr_list = attributes.AttributeList
attribute_list = await self.read_ts_attribute_expect_success(attribute=timesync_attr_list)
timesource_attr_id = attributes.TimeSource.attribute_id

self.print_step(1, "Commissioning, already done")
attributes = Clusters.TimeSynchronization.Attributes

self.print_step(2, "Read Granularity attribute")
granularity_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.Granularity)
granularity_dut = await self.read_ts_attribute_expect_success(attribute=attributes.Granularity)
asserts.assert_less(granularity_dut, Clusters.TimeSynchronization.Enums.GranularityEnum.kUnknownEnumValue,
"Granularity is not in valid range")

self.print_step(3, "Read TimeSource")
if timesource_attr_id in attribute_list:
time_source = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeSource)
time_source = await self.read_ts_attribute_expect_success(attribute=attributes.TimeSource)
asserts.assert_less(time_source, Clusters.TimeSynchronization.Enums.TimeSourceEnum.kUnknownEnumValue,
"TimeSource is not in valid range")

self.print_step(4, "Read TrustedTimeSource")
if self.supports_trusted_time_source:
trusted_time_source = await self.read_ts_attribute_expect_success(endpoint=endpoint,
attribute=attributes.TrustedTimeSource)
trusted_time_source = await self.read_ts_attribute_expect_success(attribute=attributes.TrustedTimeSource)
if trusted_time_source is not NullValue:
asserts.assert_less_equal(trusted_time_source.fabricIndex, 0xFE,
"FabricIndex for the TrustedTimeSource is out of range")
Expand All @@ -82,7 +74,7 @@ async def test_TC_TIMESYNC_2_1(self):

self.print_step(5, "Read DefaultNTP")
if self.supports_ntpc:
default_ntp = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DefaultNTP)
default_ntp = await self.read_ts_attribute_expect_success(attribute=attributes.DefaultNTP)
if default_ntp is not NullValue:
asserts.assert_less_equal(len(default_ntp), 128, "DefaultNTP length must be less than 128")
# Assume this is a valid web address if it has at least one . in the name
Expand All @@ -97,7 +89,7 @@ async def test_TC_TIMESYNC_2_1(self):

self.print_step(6, "Read TimeZone")
if self.supports_time_zone:
tz_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZone)
tz_dut = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZone)
asserts.assert_greater_equal(len(tz_dut), 1, "TimeZone must have at least one entry in the list")
asserts.assert_less_equal(len(tz_dut), 2, "TimeZone may have a maximum of two entries in the list")
for entry in tz_dut:
Expand All @@ -112,7 +104,7 @@ async def test_TC_TIMESYNC_2_1(self):

self.print_step(7, "Read DSTOffset")
if self.supports_time_zone:
dst_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DSTOffset)
dst_dut = await self.read_ts_attribute_expect_success(attribute=attributes.DSTOffset)
last_valid_until = -1
last_valid_starting = -1
for dst in dst_dut:
Expand All @@ -126,7 +118,7 @@ async def test_TC_TIMESYNC_2_1(self):
asserts.assert_equal(dst, dst_dut[-1], "DSTOffset list must have Null ValidUntil at the end")

self.print_step(8, "Read UTCTime")
utc_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.UTCTime)
utc_dut = await self.read_ts_attribute_expect_success(attribute=attributes.UTCTime)
if utc_dut is NullValue:
asserts.assert_equal(granularity_dut, Clusters.TimeSynchronization.Enums.GranularityEnum.kNoTimeGranularity)
else:
Expand All @@ -141,8 +133,8 @@ async def test_TC_TIMESYNC_2_1(self):

self.print_step(9, "Read LocalTime")
if self.supports_time_zone:
utc_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.UTCTime)
local_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.LocalTime)
utc_dut = await self.read_ts_attribute_expect_success(attribute=attributes.UTCTime)
local_dut = await self.read_ts_attribute_expect_success(attribute=attributes.LocalTime)
if utc_dut is NullValue:
asserts.assert_true(local_dut is NullValue, "LocalTime must be Null if UTC time is Null")
elif len(dst_dut) == 0:
Expand All @@ -156,30 +148,30 @@ async def test_TC_TIMESYNC_2_1(self):

self.print_step(10, "Read TimeZoneDatabase")
if self.supports_time_zone:
tz_db_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZoneDatabase)
tz_db_dut = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZoneDatabase)
asserts.assert_less(tz_db_dut, Clusters.TimeSynchronization.Enums.TimeZoneDatabaseEnum.kUnknownEnumValue,
"TimeZoneDatabase is not in valid range")

self.print_step(11, "Read NTPServerAvailable")
if self.supports_ntps:
# bool typechecking happens in the test read functions, so all we need to do here is do the read
await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.NTPServerAvailable)
await self.read_ts_attribute_expect_success(attribute=attributes.NTPServerAvailable)

self.print_step(12, "Read TimeZoneListMaxSize")
if self.supports_time_zone:
size = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZoneListMaxSize)
size = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZoneListMaxSize)
asserts.assert_greater_equal(size, 1, "TimeZoneListMaxSize must be at least 1")
asserts.assert_less_equal(size, 2, "TimeZoneListMaxSize must be max 2")

self.print_step(13, "Read DSTOffsetListMaxSize")
if self.supports_time_zone:
size = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DSTOffsetListMaxSize)
size = await self.read_ts_attribute_expect_success(attribute=attributes.DSTOffsetListMaxSize)
asserts.assert_greater_equal(size, 1, "DSTOffsetListMaxSize must be at least 1")

self.print_step(14, "Read SupportsDNSResolve")
# bool typechecking happens in the test read functions, so all we need to do here is do the read
if self.supports_ntpc:
await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.SupportsDNSResolve)
await self.read_ts_attribute_expect_success(attribute=attributes.SupportsDNSResolve)


if __name__ == "__main__":
Expand Down
100 changes: 93 additions & 7 deletions src/python_testing/matter_testing_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import partial
from typing import List, Optional, Tuple

from chip.tlv import float32, uint
Expand Down Expand Up @@ -339,6 +340,8 @@ def show_prompt(self,
placeholder: Optional[str] = None,
default_value: Optional[str] = None) -> None:
pass
def test_skipped(self, filename: str, name: str):
logging.info(f"Skipping test from {filename}: {name}")


@dataclass
Expand Down Expand Up @@ -771,8 +774,10 @@ def setup_class(self):

def setup_test(self):
self.current_step_index = 0
self.test_start_time = datetime.now(timezone.utc)
self.step_start_time = datetime.now(timezone.utc)
self.step_skipped = False
self.failed = False
if self.runner_hook and not self.is_commissioning:
test_name = self.current_test_info.name
steps = self.get_defined_test_steps(test_name)
Expand Down Expand Up @@ -949,12 +954,11 @@ def on_fail(self, record):
record is of type TestResultRecord
'''
self.failed = True
if self.runner_hook and not self.is_commissioning:
exception = record.termination_signal.exception
step_duration = (datetime.now(timezone.utc) - self.step_start_time) / timedelta(microseconds=1)
# This isn't QUITE the test duration because the commissioning is handled separately, but it's clsoe enough for now
# This is already given in milliseconds
test_duration = record.end_time - record.begin_time
test_duration = datetime.now(timezone.utc) - self.test_start_time
# TODO: I have no idea what logger, logs, request or received are. Hope None works because I have nothing to give
self.runner_hook.step_failure(logger=None, logs=None, duration=step_duration, request=None, received=None)
self.runner_hook.test_stop(exception=exception, duration=test_duration)
Expand All @@ -968,7 +972,7 @@ def on_pass(self, record):
# What is request? This seems like an implementation detail for the runner
# TODO: As with failure, I have no idea what logger, logs or request are meant to be
step_duration = (datetime.now(timezone.utc) - self.step_start_time) / timedelta(microseconds=1)
test_duration = record.end_time - record.begin_time
test_duration = datetime.now(timezone.utc) - self.test_start_time
self.runner_hook.step_success(logger=None, logs=None, duration=step_duration, request=None)

# TODO: this check could easily be annoying when doing dev. flag it somehow? Ditto with the in-order check
Expand All @@ -986,6 +990,18 @@ def on_pass(self, record):
if self.runner_hook and not self.is_commissioning:
self.runner_hook.test_stop(exception=None, duration=test_duration)

def on_skip(self, record):
''' Called by Mobly on test skip
record is of type TestResultRecord
'''
if self.runner_hook and not self.is_commissioning:
test_duration = record.end_time - record.begin_time
test_name = self.current_test_info.name
filename = inspect.getfile(self.__class__)
self.runner_hook.test_skipped(filename, test_name)
self.runner_hook.test_stop(exception=None, duration=test_duration)

def pics_guard(self, pics_condition: bool):
"""Checks a condition and if False marks the test step as skipped and
returns False, otherwise returns True.
Expand Down Expand Up @@ -1531,6 +1547,10 @@ def parse_matter_test_args(argv: Optional[List[str]] = None) -> MatterTestConfig

return convert_args_to_matter_config(parser.parse_known_args(argv)[0])

def _async_runner(body, self: MatterBaseTest, *args, **kwargs):
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
return asyncio.run(runner_with_timeout)

def async_test_body(body):
"""Decorator required to be applied whenever a `test_*` method is `async def`.
Expand All @@ -1541,12 +1561,78 @@ def async_test_body(body):
"""

def async_runner(self: MatterBaseTest, *args, **kwargs):
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
return asyncio.run(runner_with_timeout)
return _async_runner(body, self, *args, **kwargs)

return async_runner

def per_node_test(body):

""" Decorator to be used for PICS-free tests that apply to the entire node.
Use this decorator when your script needs to be run once to validate the whole node.
To use this decorator, the test must NOT have an associated pics_ method.
"""
def whole_node_runner(self: MatterBaseTest, *args, **kwargs):
asserts.assert_false(self.get_test_pics(self.current_test_info.name), "pics_ method supplied for per_node_test.")
return _async_runner(body, self, *args, **kwargs)

return whole_node_runner

EndpointCheckFunction = typing.Callable[[Clusters.Attribute.AsyncReadTransaction.ReadResponse, int], bool]

def _has_cluster(wildcard, endpoint, cluster: ClusterObjects.Cluster) -> bool:
try:
return cluster in wildcard.attributes[endpoint]
except KeyError:
return False

def has_cluster(cluster: ClusterObjects.ClusterObjectDescriptor) -> EndpointCheckFunction:
return partial(_has_cluster, cluster=cluster)

def _has_attribute(wildcard, endpoint, attribute: ClusterObjects.ClusterAttributeDescriptor) -> bool:
cluster = getattr(Clusters, attribute.__qualname__.split('.')[-3])
try:
attr_list = wildcard.attributes[endpoint][cluster][cluster.Attributes.AttributeList]
return attribute.attribute_id in attr_list
except KeyError:
return False

def has_attribute(attribute: ClusterObjects.ClusterAttributeDescriptor) -> EndpointCheckFunction:
return partial(_has_attribute, attribute=attribute)

async def get_accepted_endpoints_for_test(self:MatterBaseTest, accept_function: EndpointCheckFunction):
wildcard = await self.default_controller.Read(self.dut_node_id, [()])
return [e for e in wildcard.attributes.keys() if accept_function(wildcard, e)]

def per_endpoint_test(accept_function):
def per_endpoint_test_internal(body):
def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
asserts.assert_false(self.get_test_pics(self.current_test_info.name), "pics_ method supplied for per_endpoint_test.")
runner_with_timeout = asyncio.wait_for(get_accepted_endpoints_for_test(self, accept_function), timeout=5)
endpoints = asyncio.run(runner_with_timeout)
if not endpoints:
logging.info("No matching endpoints found - skipping test")
asserts.skip('No endpoints match requirements')
return
logging.info(f"Running test on the following endpoints: {endpoints}")
# setup_class is meant to be called once, but setup_test is expected to be run before
# each iteration. Mobly will run it for us the first time, but since we're running this
# more than one time, we want to make sure we reset everything as expected.
# Ditto for teardown - we want to tear down after each iteration, and we want to notify the hool that
# the test iteration is stopped. test_stop is called by on_pass or on_fail during the last iteration or
# on failure.
for e in endpoints:
logging.info(f'Running test on endpoint {e}')
if e != endpoints[0]:
self.setup_test()
self.matter_test_config.endpoint = e
_async_runner(body, self, *args, **kwargs)
if e != endpoints[-1] and not self.failed:
self.teardown_test()
self.runner_hook.test_stop(exception=None, duration=datetime.now(timezone.utc) - self.test_start_time)

return per_endpoint_runner
return per_endpoint_test_internal

class CommissionDeviceTest(MatterBaseTest):
"""Test class auto-injected at the start of test list to commission a device when requested"""
Expand Down
Loading

0 comments on commit 37c0a12

Please sign in to comment.