From b52eff3c880ddd9a2a321eef3de7e888fd913aa4 Mon Sep 17 00:00:00 2001 From: Song Guo Date: Wed, 16 Feb 2022 17:30:31 +0800 Subject: [PATCH] [Python] Cleanup data version API --- .../python/chip/clusters/Attribute.py | 27 +++++---- .../python/chip/clusters/ClusterObjects.py | 12 +++- .../test/test_scripts/cluster_objects.py | 59 +++++++++---------- .../test_scripts/network_commissioning.py | 17 +++--- 4 files changed, 60 insertions(+), 55 deletions(-) diff --git a/src/controller/python/chip/clusters/Attribute.py b/src/controller/python/chip/clusters/Attribute.py index 4214a45364791f..9ac162f7f55e11 100644 --- a/src/controller/python/chip/clusters/Attribute.py +++ b/src/controller/python/chip/clusters/Attribute.py @@ -320,6 +320,13 @@ class SubscriptionParameters: MaxReportIntervalCeilingSeconds: int +class DataVersion: + ''' + A helper class as a key for getting cluster data version when reading attributes without returnClusterObject. + ''' + pass + + @dataclass class AttributeCache: ''' A cache that stores data & errors returned in read/subscribe reports, but organizes it topologically @@ -358,23 +365,18 @@ def UpdateTLV(self, path: AttributePath, dataVersion: int, data: Union[bytes, V self.versionList[path.EndpointId] = {} endpointCache = self.attributeTLVCache[path.EndpointId] - endpoint = self.versionList[path.EndpointId] + endpointVersion = self.versionList[path.EndpointId] if (path.ClusterId not in endpointCache): endpointCache[path.ClusterId] = {} - if (path.ClusterId not in endpoint): - endpoint[path.ClusterId] = {} + # All attributes from the same cluster instance should have the same dataVersion, so we can set the dataVersion of the cluster to the dataVersion with a random attribute. + endpointVersion[path.ClusterId] = dataVersion clusterCache = endpointCache[path.ClusterId] - cluster = endpoint[path.ClusterId] if (path.AttributeId not in clusterCache): clusterCache[path.AttributeId] = None - if (path.AttributeId not in cluster): - cluster[path.AttributeId] = None - clusterCache[path.AttributeId] = data - cluster[path.AttributeId] = dataVersion def UpdateCachedData(self): ''' This converts the raw TLV data into a cluster object format. @@ -409,12 +411,16 @@ def UpdateCachedData(self): endpointCache[clusterType] = {} clusterCache = endpointCache[clusterType] + clusterDataVersion = self.versionList.get( + endpoint, {}).get(cluster, None) if (self.returnClusterObject): try: # Since the TLV data is already organized by attribute tags, we can trivially convert to a cluster object representation. endpointCache[clusterType] = clusterType.FromDict( data=clusterType.descriptor.TagDictToLabelDict([], tlvCache[endpoint][cluster])) + endpointCache[clusterType].SetDataVersion( + clusterDataVersion) except Exception as ex: logging.error( f"Error converting TLV to Cluster Object for path: Endpoint = {endpoint}, cluster = {str(clusterType)}") @@ -423,6 +429,7 @@ def UpdateCachedData(self): tlvCache[endpoint][cluster], ex) endpointCache[clusterType] = decodedValue else: + clusterCache[DataVersion] = clusterDataVersion for attribute in tlvCache[endpoint][cluster]: value = tlvCache[endpoint][cluster][attribute] @@ -680,14 +687,12 @@ def _handleDone(self): if (self._transactionType == TransactionType.READ_EVENTS): self._future.set_result(self._events) else: - self._future.set_result( - (self._cache.attributeCache, self._cache.versionList)) + self._future.set_result(self._cache.attributeCache) def handleDone(self): self._event_loop.call_soon_threadsafe(self._handleDone) def handleReportBegin(self): - self._cache.versionList.clear() pass def handleReportEnd(self): diff --git a/src/controller/python/chip/clusters/ClusterObjects.py b/src/controller/python/chip/clusters/ClusterObjects.py index bcac3fd14b838d..07e5b14246b3a4 100644 --- a/src/controller/python/chip/clusters/ClusterObjects.py +++ b/src/controller/python/chip/clusters/ClusterObjects.py @@ -219,10 +219,16 @@ def must_use_timed_invoke(cls) -> bool: class Cluster(ClusterObject): - ''' This class does nothing, but a convenient class that generated clusters can inherit from. - This gives the ability that the users can use issubclass(X, Cluster) to determine if the class represnents a Cluster. ''' - pass + When send read requests with returnClusterObject=True, we will set the dataVersion property of the object. + Otherwise the [endpoint][cluster][Clusters.DataVersion] will be set to the DataVersion of the cluster. + ''' + @ChipUtility.classproperty + def dataVersion(self) -> int: + return self._dataVersion + + def SetDataVersion(self, version: int) -> None: + self._dataVersion = version class ClusterAttributeDescriptor: diff --git a/src/controller/python/test/test_scripts/cluster_objects.py b/src/controller/python/test/test_scripts/cluster_objects.py index d962c47a57cdc6..7c2a9b00480c07 100644 --- a/src/controller/python/test/test_scripts/cluster_objects.py +++ b/src/controller/python/test/test_scripts/cluster_objects.py @@ -19,7 +19,7 @@ import chip.clusters as Clusters import chip.exceptions import logging -from chip.clusters.Attribute import AttributePath, AttributeReadResult, AttributeStatus, ValueDecodeFailure, TypedAttributePath, SubscriptionTransaction +from chip.clusters.Attribute import AttributePath, AttributeReadResult, AttributeStatus, ValueDecodeFailure, TypedAttributePath, SubscriptionTransaction, DataVersion import chip.interaction_model import asyncio import time @@ -45,31 +45,26 @@ def _IgnoreAttributeDecodeFailure(path): def VerifyDecodeSuccess(values): print(f"{values}") - for endpoint in values[0]: - for cluster in values[0][endpoint]: - for attribute in values[0][endpoint][cluster]: - v = values[0][endpoint][cluster][attribute] - print(f"EP{endpoint}/{attribute} = {v}") + for endpoint in values: + for cluster in values[endpoint]: + for attribute in values[endpoint][cluster]: + v = values[endpoint][cluster][attribute] + print(f"EP{endpoint}/{cluster}/{attribute} = {v}") if (isinstance(v, ValueDecodeFailure)): if _IgnoreAttributeDecodeFailure((endpoint, cluster, attribute)): print( - f"Ignoring attribute decode failure for path {endpoint}/{attribute}") + f"Ignoring attribute decode failure for path {endpoint}/{cluster}/{attribute}") else: raise AssertionError( - f"Cannot decode value for path {endpoint}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'") + f"Cannot decode value for path {endpoint}/{cluster}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'") - for endpoint in values[1]: - for cluster in values[1][endpoint]: - for attribute in values[1][endpoint][cluster]: - v = values[1][endpoint][cluster][attribute] - print(f"EP{endpoint}/{attribute} version = {v}") - if (isinstance(v, ValueDecodeFailure)): - if _IgnoreAttributeDecodeFailure((endpoint, cluster, attribute)): - print( - f"Ignoring attribute version decode failure for path {endpoint}/{attribute}") - else: - raise AssertionError( - f"Cannot decode value for path {endpoint}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'") + for endpoint in values: + for cluster in values[endpoint]: + v = values[endpoint][cluster].get(DataVersion, None) + print(f"EP{endpoint}/{cluster} version = {v}") + if v is None: + raise AssertionError( + f"Cannot get data version for path {endpoint}/{cluster}") def _AssumeEventsDecodeSuccess(values): @@ -200,9 +195,9 @@ async def TestReadAttributeRequests(cls, devCtrl): (0, Clusters.Basic.Attributes.HardwareVersion), ] res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req) - if ((0 not in res[0]) or (Clusters.Basic not in res[0][0]) or (len(res[0][0][Clusters.Basic]) != 3)): + if ((0 not in res) or (Clusters.Basic not in res[0]) or (len(res[0][Clusters.Basic]) != 3)): raise AssertionError( - f"Got back {len(res[0])} data items instead of 3") + f"Got back {len(res)} data items instead of 3") VerifyDecodeSuccess(res) logger.info("2: Reading Ex Cx A*") @@ -237,32 +232,32 @@ async def TestReadAttributeRequests(cls, devCtrl): res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req, returnClusterObject=True) logger.info( - f"Basic Cluster - Label: {res[0][0][Clusters.Basic].productLabel}") + f"Basic Cluster - Label: {res[0][Clusters.Basic].productLabel}") # TestCluster will be ValueDecodeError here, so we comment out the log below. # Values are not expected to be ValueDecodeError for real clusters. # logger.info( - # f"Test Cluster - Struct: {res[0][1][Clusters.TestCluster].structAttr}") - logger.info(f"Test Cluster: {res[0][1][Clusters.TestCluster]}") + # f"Test Cluster - Struct: {res[1][Clusters.TestCluster].structAttr}") + logger.info(f"Test Cluster: {res[1][Clusters.TestCluster]}") logger.info("7: Reading Chunked List") res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListLongOctetString)]) - if res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListLongOctetString] != [b'0123456789abcdef' * 32] * 4: + if res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListLongOctetString] != [b'0123456789abcdef' * 32] * 4: raise AssertionError("Unexpected read result") logger.info("*: Getting current fabric index") res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(0, Clusters.OperationalCredentials.Attributes.CurrentFabricIndex)]) - fabricIndex = res[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex] + fabricIndex = res[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex] logger.info("8: Read without fabric filter") res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListFabricScoped)], fabricFiltered=False) - if len(res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1: + if len(res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1: raise AssertionError("Expect more elements in the response") logger.info("9: Read with fabric filter") res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListFabricScoped)], fabricFiltered=True) - if len(res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1: + if len(res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1: raise AssertionError("Expect exact one element in the response") - if res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped][0].fabricIndex != fabricIndex: + if res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped][0].fabricIndex != fabricIndex: raise AssertionError( "Expect the fabric index matches the one current reading") @@ -381,7 +376,7 @@ async def TestReadWriteAttributeRequestsWithVersion(cls, devCtrl): ] res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req) VerifyDecodeSuccess(res) - data_version = res[1][0][40][1] + data_version = res[0][Clusters.Basic][DataVersion] res = await devCtrl.WriteAttribute(nodeid=NODE_ID, attributes=[ @@ -405,7 +400,7 @@ async def TestReadWriteAttributeRequestsWithVersion(cls, devCtrl): ] res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req, dataVersionFilters=[(0, Clusters.Basic, data_version)]) VerifyDecodeSuccess(res) - new_data_version = res[1][0][40][1] + new_data_version = res[0][Clusters.Basic][DataVersion] if (data_version + 1) != new_data_version: raise AssertionError("Version mistmatch happens.") diff --git a/src/controller/python/test/test_scripts/network_commissioning.py b/src/controller/python/test/test_scripts/network_commissioning.py index 6d47e897be63a5..789b0a2f187ab9 100644 --- a/src/controller/python/test/test_scripts/network_commissioning.py +++ b/src/controller/python/test/test_scripts/network_commissioning.py @@ -66,7 +66,7 @@ async def readLastNetworkingStateAttributes(self, endpointId): res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.LastConnectErrorValue), (endpointId, Clusters.NetworkCommissioning.Attributes.LastNetworkID), (endpointId, Clusters.NetworkCommissioning.Attributes.LastNetworkingStatus)], returnClusterObject=True) - values = res[0][endpointId][Clusters.NetworkCommissioning] + values = res[endpointId][Clusters.NetworkCommissioning] logger.info(f"Got values: {values}") return values @@ -97,7 +97,7 @@ async def test_wifi(self, endpointId): (endpointId, Clusters.NetworkCommissioning.Attributes.FeatureMap)], returnClusterObject=True) self.log_interface_basic_info( - res[0][endpointId][Clusters.NetworkCommissioning]) + res[endpointId][Clusters.NetworkCommissioning]) logger.info(f"Finished getting basic information of the endpoint") # Read Last* attributes @@ -126,7 +126,7 @@ async def test_wifi(self, endpointId): # Remove existing network logger.info(f"Check network list") res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True) - networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks + networkList = res[endpointId][Clusters.NetworkCommissioning].networks logger.info(f"Got network list: {networkList}") if len(networkList) != 0: logger.info(f"Removing existing network") @@ -149,7 +149,7 @@ async def test_wifi(self, endpointId): logger.info(f"Check network list") res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True) - networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks + networkList = res[endpointId][Clusters.NetworkCommissioning].networks logger.info(f"Got network list: {networkList}") if len(networkList) != 1: raise AssertionError( @@ -172,7 +172,7 @@ async def test_wifi(self, endpointId): logger.info(f"Check network is connected") res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True) - networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks + networkList = res[endpointId][Clusters.NetworkCommissioning].networks logger.info(f"Got network list: {networkList}") if len(networkList) != 1: raise AssertionError( @@ -202,7 +202,7 @@ async def test_thread(self, endpointId): (endpointId, Clusters.NetworkCommissioning.Attributes.FeatureMap)], returnClusterObject=True) self.log_interface_basic_info( - res[0][endpointId][Clusters.NetworkCommissioning]) + res[endpointId][Clusters.NetworkCommissioning]) logger.info(f"Finished getting basic information of the endpoint") # Read Last* attributes @@ -231,7 +231,7 @@ async def test_thread(self, endpointId): # Remove existing network logger.info(f"Check network list") res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True) - networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks + networkList = res[endpointId][Clusters.NetworkCommissioning].networks logger.info(f"Got network list: {networkList}") if len(networkList) != 0: logger.info(f"Removing existing network") @@ -254,7 +254,7 @@ async def test_thread(self, endpointId): logger.info(f"Check network list") res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True) - networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks + networkList = res[endpointId][Clusters.NetworkCommissioning].networks logger.info(f"Got network list: {networkList}") if len(networkList) != 1: raise AssertionError( @@ -297,7 +297,6 @@ async def run(self): try: endpoints = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(Clusters.NetworkCommissioning.Attributes.FeatureMap)], returnClusterObject=True) logger.info(endpoints) - endpoints = endpoints[0] for endpoint, obj in endpoints.items(): clus = obj[Clusters.NetworkCommissioning] if clus.featureMap == WIFI_NETWORK_FEATURE_MAP: