Skip to content

Commit

Permalink
[Python] Cleanup data version API
Browse files Browse the repository at this point in the history
  • Loading branch information
erjiaqing committed Mar 2, 2022
1 parent 6e47fae commit 9e4b81a
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 55 deletions.
27 changes: 16 additions & 11 deletions src/controller/python/chip/clusters/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ class SubscriptionParameters:
MaxReportIntervalCeilingSeconds: int


class DataVersion:
'''
A helper class as a key for getting cluster data version when reading attributes without returnClusterObject.
'''
pass


@dataclass
class AttributeCache:
''' A cache that stores data & errors returned in read/subscribe reports, but organizes it topologically
Expand Down Expand Up @@ -358,23 +365,18 @@ def UpdateTLV(self, path: AttributePath, dataVersion: int, data: Union[bytes, V
self.versionList[path.EndpointId] = {}

endpointCache = self.attributeTLVCache[path.EndpointId]
endpoint = self.versionList[path.EndpointId]
endpointVersion = self.versionList[path.EndpointId]
if (path.ClusterId not in endpointCache):
endpointCache[path.ClusterId] = {}

if (path.ClusterId not in endpoint):
endpoint[path.ClusterId] = {}
# All attributes from the same cluster instance should have the same dataVersion, so we can set the dataVersion of the cluster to the dataVersion with a random attribute.
endpointVersion[path.ClusterId] = dataVersion

clusterCache = endpointCache[path.ClusterId]
cluster = endpoint[path.ClusterId]
if (path.AttributeId not in clusterCache):
clusterCache[path.AttributeId] = None

if (path.AttributeId not in cluster):
cluster[path.AttributeId] = None

clusterCache[path.AttributeId] = data
cluster[path.AttributeId] = dataVersion

def UpdateCachedData(self):
''' This converts the raw TLV data into a cluster object format.
Expand Down Expand Up @@ -409,12 +411,16 @@ def UpdateCachedData(self):
endpointCache[clusterType] = {}

clusterCache = endpointCache[clusterType]
clusterDataVersion = self.versionList.get(
endpoint, {}).get(cluster, None)

if (self.returnClusterObject):
try:
# Since the TLV data is already organized by attribute tags, we can trivially convert to a cluster object representation.
endpointCache[clusterType] = clusterType.FromDict(
data=clusterType.descriptor.TagDictToLabelDict([], tlvCache[endpoint][cluster]))
endpointCache[clusterType].SetDataVersion(
clusterDataVersion)
except Exception as ex:
logging.error(
f"Error converting TLV to Cluster Object for path: Endpoint = {endpoint}, cluster = {str(clusterType)}")
Expand All @@ -423,6 +429,7 @@ def UpdateCachedData(self):
tlvCache[endpoint][cluster], ex)
endpointCache[clusterType] = decodedValue
else:
clusterCache[DataVersion] = clusterDataVersion
for attribute in tlvCache[endpoint][cluster]:
value = tlvCache[endpoint][cluster][attribute]

Expand Down Expand Up @@ -680,14 +687,12 @@ def _handleDone(self):
if (self._transactionType == TransactionType.READ_EVENTS):
self._future.set_result(self._events)
else:
self._future.set_result(
(self._cache.attributeCache, self._cache.versionList))
self._future.set_result(self._cache.attributeCache)

def handleDone(self):
self._event_loop.call_soon_threadsafe(self._handleDone)

def handleReportBegin(self):
self._cache.versionList.clear()
pass

def handleReportEnd(self):
Expand Down
12 changes: 9 additions & 3 deletions src/controller/python/chip/clusters/ClusterObjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,16 @@ def must_use_timed_invoke(cls) -> bool:


class Cluster(ClusterObject):
''' This class does nothing, but a convenient class that generated clusters can inherit from.
This gives the ability that the users can use issubclass(X, Cluster) to determine if the class represnents a Cluster.
'''
pass
When send read requests with returnClusterObject=True, we will set the dataVersion property of the object.
Otherwise the [endpoint][cluster][Clusters.DataVersion] will be set to the DataVersion of the cluster.
'''
@ChipUtility.classproperty
def dataVersion(self) -> int:
return self._dataVersion

def SetDataVersion(self, version: int) -> None:
self._dataVersion = version


class ClusterAttributeDescriptor:
Expand Down
59 changes: 27 additions & 32 deletions src/controller/python/test/test_scripts/cluster_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import chip.clusters as Clusters
import chip.exceptions
import logging
from chip.clusters.Attribute import AttributePath, AttributeReadResult, AttributeStatus, ValueDecodeFailure, TypedAttributePath, SubscriptionTransaction
from chip.clusters.Attribute import AttributePath, AttributeReadResult, AttributeStatus, ValueDecodeFailure, TypedAttributePath, SubscriptionTransaction, DataVersion
import chip.interaction_model
import asyncio
import time
Expand All @@ -45,31 +45,26 @@ def _IgnoreAttributeDecodeFailure(path):

def VerifyDecodeSuccess(values):
print(f"{values}")
for endpoint in values[0]:
for cluster in values[0][endpoint]:
for attribute in values[0][endpoint][cluster]:
v = values[0][endpoint][cluster][attribute]
print(f"EP{endpoint}/{attribute} = {v}")
for endpoint in values:
for cluster in values[endpoint]:
for attribute in values[endpoint][cluster]:
v = values[endpoint][cluster][attribute]
print(f"EP{endpoint}/{cluster}/{attribute} = {v}")
if (isinstance(v, ValueDecodeFailure)):
if _IgnoreAttributeDecodeFailure((endpoint, cluster, attribute)):
print(
f"Ignoring attribute decode failure for path {endpoint}/{attribute}")
f"Ignoring attribute decode failure for path {endpoint}/{cluster}/{attribute}")
else:
raise AssertionError(
f"Cannot decode value for path {endpoint}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'")
f"Cannot decode value for path {endpoint}/{cluster}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'")

for endpoint in values[1]:
for cluster in values[1][endpoint]:
for attribute in values[1][endpoint][cluster]:
v = values[1][endpoint][cluster][attribute]
print(f"EP{endpoint}/{attribute} version = {v}")
if (isinstance(v, ValueDecodeFailure)):
if _IgnoreAttributeDecodeFailure((endpoint, cluster, attribute)):
print(
f"Ignoring attribute version decode failure for path {endpoint}/{attribute}")
else:
raise AssertionError(
f"Cannot decode value for path {endpoint}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'")
for endpoint in values:
for cluster in values[endpoint]:
v = values[endpoint][cluster].get(DataVersion, None)
print(f"EP{endpoint}/{cluster} version = {v}")
if v is None:
raise AssertionError(
f"Cannot get data version for path {endpoint}/{cluster}")


def _AssumeEventsDecodeSuccess(values):
Expand Down Expand Up @@ -200,9 +195,9 @@ async def TestReadAttributeRequests(cls, devCtrl):
(0, Clusters.Basic.Attributes.HardwareVersion),
]
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req)
if ((0 not in res[0]) or (Clusters.Basic not in res[0][0]) or (len(res[0][0][Clusters.Basic]) != 3)):
if ((0 not in res) or (Clusters.Basic not in res[0]) or (len(res[0][Clusters.Basic]) != 3)):
raise AssertionError(
f"Got back {len(res[0])} data items instead of 3")
f"Got back {len(res)} data items instead of 3")
VerifyDecodeSuccess(res)

logger.info("2: Reading Ex Cx A*")
Expand Down Expand Up @@ -237,32 +232,32 @@ async def TestReadAttributeRequests(cls, devCtrl):

res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req, returnClusterObject=True)
logger.info(
f"Basic Cluster - Label: {res[0][0][Clusters.Basic].productLabel}")
f"Basic Cluster - Label: {res[0][Clusters.Basic].productLabel}")
# TestCluster will be ValueDecodeError here, so we comment out the log below.
# Values are not expected to be ValueDecodeError for real clusters.
# logger.info(
# f"Test Cluster - Struct: {res[0][1][Clusters.TestCluster].structAttr}")
logger.info(f"Test Cluster: {res[0][1][Clusters.TestCluster]}")
# f"Test Cluster - Struct: {res[1][Clusters.TestCluster].structAttr}")
logger.info(f"Test Cluster: {res[1][Clusters.TestCluster]}")

logger.info("7: Reading Chunked List")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListLongOctetString)])
if res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListLongOctetString] != [b'0123456789abcdef' * 32] * 4:
if res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListLongOctetString] != [b'0123456789abcdef' * 32] * 4:
raise AssertionError("Unexpected read result")

logger.info("*: Getting current fabric index")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(0, Clusters.OperationalCredentials.Attributes.CurrentFabricIndex)])
fabricIndex = res[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]
fabricIndex = res[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]

logger.info("8: Read without fabric filter")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListFabricScoped)], fabricFiltered=False)
if len(res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
if len(res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
raise AssertionError("Expect more elements in the response")

logger.info("9: Read with fabric filter")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListFabricScoped)], fabricFiltered=True)
if len(res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
if len(res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
raise AssertionError("Expect exact one element in the response")
if res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped][0].fabricIndex != fabricIndex:
if res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped][0].fabricIndex != fabricIndex:
raise AssertionError(
"Expect the fabric index matches the one current reading")

Expand Down Expand Up @@ -381,7 +376,7 @@ async def TestReadWriteAttributeRequestsWithVersion(cls, devCtrl):
]
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req)
VerifyDecodeSuccess(res)
data_version = res[1][0][40][1]
data_version = res[0][Clusters.Basic][DataVersion]

res = await devCtrl.WriteAttribute(nodeid=NODE_ID,
attributes=[
Expand All @@ -405,7 +400,7 @@ async def TestReadWriteAttributeRequestsWithVersion(cls, devCtrl):
]
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req, dataVersionFilters=[(0, Clusters.Basic, data_version)])
VerifyDecodeSuccess(res)
new_data_version = res[1][0][40][1]
new_data_version = res[0][Clusters.Basic][DataVersion]
if (data_version + 1) != new_data_version:
raise AssertionError("Version mistmatch happens.")

Expand Down
17 changes: 8 additions & 9 deletions src/controller/python/test/test_scripts/network_commissioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def readLastNetworkingStateAttributes(self, endpointId):
res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.LastConnectErrorValue),
(endpointId, Clusters.NetworkCommissioning.Attributes.LastNetworkID),
(endpointId, Clusters.NetworkCommissioning.Attributes.LastNetworkingStatus)], returnClusterObject=True)
values = res[0][endpointId][Clusters.NetworkCommissioning]
values = res[endpointId][Clusters.NetworkCommissioning]
logger.info(f"Got values: {values}")
return values

Expand Down Expand Up @@ -97,7 +97,7 @@ async def test_wifi(self, endpointId):
(endpointId, Clusters.NetworkCommissioning.Attributes.FeatureMap)],
returnClusterObject=True)
self.log_interface_basic_info(
res[0][endpointId][Clusters.NetworkCommissioning])
res[endpointId][Clusters.NetworkCommissioning])
logger.info(f"Finished getting basic information of the endpoint")

# Read Last* attributes
Expand Down Expand Up @@ -126,7 +126,7 @@ async def test_wifi(self, endpointId):
# Remove existing network
logger.info(f"Check network list")
res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True)
networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks
networkList = res[endpointId][Clusters.NetworkCommissioning].networks
logger.info(f"Got network list: {networkList}")
if len(networkList) != 0:
logger.info(f"Removing existing network")
Expand All @@ -149,7 +149,7 @@ async def test_wifi(self, endpointId):

logger.info(f"Check network list")
res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True)
networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks
networkList = res[endpointId][Clusters.NetworkCommissioning].networks
logger.info(f"Got network list: {networkList}")
if len(networkList) != 1:
raise AssertionError(
Expand All @@ -172,7 +172,7 @@ async def test_wifi(self, endpointId):

logger.info(f"Check network is connected")
res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True)
networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks
networkList = res[endpointId][Clusters.NetworkCommissioning].networks
logger.info(f"Got network list: {networkList}")
if len(networkList) != 1:
raise AssertionError(
Expand Down Expand Up @@ -202,7 +202,7 @@ async def test_thread(self, endpointId):
(endpointId, Clusters.NetworkCommissioning.Attributes.FeatureMap)],
returnClusterObject=True)
self.log_interface_basic_info(
res[0][endpointId][Clusters.NetworkCommissioning])
res[endpointId][Clusters.NetworkCommissioning])
logger.info(f"Finished getting basic information of the endpoint")

# Read Last* attributes
Expand Down Expand Up @@ -231,7 +231,7 @@ async def test_thread(self, endpointId):
# Remove existing network
logger.info(f"Check network list")
res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True)
networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks
networkList = res[endpointId][Clusters.NetworkCommissioning].networks
logger.info(f"Got network list: {networkList}")
if len(networkList) != 0:
logger.info(f"Removing existing network")
Expand All @@ -254,7 +254,7 @@ async def test_thread(self, endpointId):

logger.info(f"Check network list")
res = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(endpointId, Clusters.NetworkCommissioning.Attributes.Networks)], returnClusterObject=True)
networkList = res[0][endpointId][Clusters.NetworkCommissioning].networks
networkList = res[endpointId][Clusters.NetworkCommissioning].networks
logger.info(f"Got network list: {networkList}")
if len(networkList) != 1:
raise AssertionError(
Expand Down Expand Up @@ -297,7 +297,6 @@ async def run(self):
try:
endpoints = await self._devCtrl.ReadAttribute(nodeid=self._nodeid, attributes=[(Clusters.NetworkCommissioning.Attributes.FeatureMap)], returnClusterObject=True)
logger.info(endpoints)
endpoints = endpoints[0]
for endpoint, obj in endpoints.items():
clus = obj[Clusters.NetworkCommissioning]
if clus.featureMap == WIFI_NETWORK_FEATURE_MAP:
Expand Down

0 comments on commit 9e4b81a

Please sign in to comment.