Skip to content

Commit

Permalink
[Python] Provide a user friendly API for getting data versions (#15709)
Browse files Browse the repository at this point in the history
* [Python] Cleanup data version API

* Fix

* Fix
  • Loading branch information
erjiaqing authored and pull[bot] committed Jul 7, 2023
1 parent 174a1ad commit 1105232
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 61 deletions.
2 changes: 1 addition & 1 deletion src/controller/python/chip/ChipDeviceCtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def ZCLReadAttribute(self, cluster, attribute, nodeid, endpoint, groupid, blocki
nodeid, [(endpoint, attributeType)]))
path = ClusterAttribute.AttributePath(
EndpointId=endpoint, Attribute=attributeType)
return im.AttributeReadResult(path=im.AttributePath(nodeId=nodeid, endpointId=path.EndpointId, clusterId=path.ClusterId, attributeId=path.AttributeId), status=0, value=result[0][endpoint][clusterType][attributeType])
return im.AttributeReadResult(path=im.AttributePath(nodeId=nodeid, endpointId=path.EndpointId, clusterId=path.ClusterId, attributeId=path.AttributeId), status=0, value=result[endpoint][clusterType][attributeType])

def ZCLWriteAttribute(self, cluster: str, attribute: str, nodeid, endpoint, groupid, value, dataVersion=0, blocking=True):
req = None
Expand Down
27 changes: 16 additions & 11 deletions src/controller/python/chip/clusters/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ class SubscriptionParameters:
MaxReportIntervalCeilingSeconds: int


class DataVersion:
'''
A helper class as a key for getting cluster data version when reading attributes without returnClusterObject.
'''
pass


@dataclass
class AttributeCache:
''' A cache that stores data & errors returned in read/subscribe reports, but organizes it topologically
Expand Down Expand Up @@ -358,23 +365,18 @@ def UpdateTLV(self, path: AttributePath, dataVersion: int, data: Union[bytes, V
self.versionList[path.EndpointId] = {}

endpointCache = self.attributeTLVCache[path.EndpointId]
endpoint = self.versionList[path.EndpointId]
endpointVersion = self.versionList[path.EndpointId]
if (path.ClusterId not in endpointCache):
endpointCache[path.ClusterId] = {}

if (path.ClusterId not in endpoint):
endpoint[path.ClusterId] = {}
        # All attributes from the same cluster instance should have the same dataVersion, so we can set the cluster's dataVersion from any one attribute's report.
endpointVersion[path.ClusterId] = dataVersion

clusterCache = endpointCache[path.ClusterId]
cluster = endpoint[path.ClusterId]
if (path.AttributeId not in clusterCache):
clusterCache[path.AttributeId] = None

if (path.AttributeId not in cluster):
cluster[path.AttributeId] = None

clusterCache[path.AttributeId] = data
cluster[path.AttributeId] = dataVersion

def UpdateCachedData(self):
''' This converts the raw TLV data into a cluster object format.
Expand Down Expand Up @@ -409,12 +411,16 @@ def UpdateCachedData(self):
endpointCache[clusterType] = {}

clusterCache = endpointCache[clusterType]
clusterDataVersion = self.versionList.get(
endpoint, {}).get(cluster, None)

if (self.returnClusterObject):
try:
# Since the TLV data is already organized by attribute tags, we can trivially convert to a cluster object representation.
endpointCache[clusterType] = clusterType.FromDict(
data=clusterType.descriptor.TagDictToLabelDict([], tlvCache[endpoint][cluster]))
endpointCache[clusterType].SetDataVersion(
clusterDataVersion)
except Exception as ex:
logging.error(
f"Error converting TLV to Cluster Object for path: Endpoint = {endpoint}, cluster = {str(clusterType)}")
Expand All @@ -423,6 +429,7 @@ def UpdateCachedData(self):
tlvCache[endpoint][cluster], ex)
endpointCache[clusterType] = decodedValue
else:
clusterCache[DataVersion] = clusterDataVersion
for attribute in tlvCache[endpoint][cluster]:
value = tlvCache[endpoint][cluster][attribute]

Expand Down Expand Up @@ -684,14 +691,12 @@ def _handleDone(self):
if (self._transactionType == TransactionType.READ_EVENTS):
self._future.set_result(self._events)
else:
self._future.set_result(
(self._cache.attributeCache, self._cache.versionList))
self._future.set_result(self._cache.attributeCache)

def handleDone(self):
self._event_loop.call_soon_threadsafe(self._handleDone)

def handleReportBegin(self):
self._cache.versionList.clear()
pass

def handleReportEnd(self):
Expand Down
12 changes: 9 additions & 3 deletions src/controller/python/chip/clusters/ClusterObjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,16 @@ def must_use_timed_invoke(cls) -> bool:


class Cluster(ClusterObject):
''' This class does nothing, but a convenient class that generated clusters can inherit from.
    This gives users the ability to use issubclass(X, Cluster) to determine if the class represents a Cluster.
'''
pass
    When sending read requests with returnClusterObject=True, we will set the dataVersion property of the object.
Otherwise the [endpoint][cluster][Clusters.DataVersion] will be set to the DataVersion of the cluster.
'''
@ChipUtility.classproperty
def dataVersion(self) -> int:
return self._dataVersion

def SetDataVersion(self, version: int) -> None:
self._dataVersion = version


class ClusterAttributeDescriptor:
Expand Down
8 changes: 4 additions & 4 deletions src/controller/python/test/test_scripts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ async def TestMultiFabric(self, ip: str, setuppin: int, nodeid: int):
data2 = await devCtrl2.ReadAttribute(nodeid, [(Clusters.OperationalCredentials.Attributes.NOCs)], fabricFiltered=False)

# Read out noclist from each fabric, and each should contain two NOCs.
nocList1 = data1[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.NOCs]
nocList2 = data2[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.NOCs]
nocList1 = data1[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.NOCs]
nocList2 = data2[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.NOCs]

if (len(nocList1) != 2 or len(nocList2) != 2):
self.logger.error("Got back invalid nocList")
Expand All @@ -201,8 +201,8 @@ async def TestMultiFabric(self, ip: str, setuppin: int, nodeid: int):
data2 = await devCtrl2.ReadAttribute(nodeid, [(Clusters.OperationalCredentials.Attributes.CurrentFabricIndex)], fabricFiltered=False)

# Read out current fabric from each fabric, and both should be different.
currentFabric1 = data1[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]
currentFabric2 = data2[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]
currentFabric1 = data1[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]
currentFabric2 = data2[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]
if (currentFabric1 == currentFabric2):
self.logger.error(
"Got back fabric indices that match for two different fabrics!")
Expand Down
60 changes: 28 additions & 32 deletions src/controller/python/test/test_scripts/cluster_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import chip.clusters as Clusters
import chip.exceptions
import logging
from chip.clusters.Attribute import AttributePath, AttributeReadResult, AttributeStatus, ValueDecodeFailure, TypedAttributePath, SubscriptionTransaction
from chip.clusters.Attribute import AttributePath, AttributeReadResult, AttributeStatus, ValueDecodeFailure, TypedAttributePath, SubscriptionTransaction, DataVersion
import chip.interaction_model
import asyncio
import time
Expand All @@ -45,31 +45,26 @@ def _IgnoreAttributeDecodeFailure(path):

def VerifyDecodeSuccess(values):
print(f"{values}")
for endpoint in values[0]:
for cluster in values[0][endpoint]:
for attribute in values[0][endpoint][cluster]:
v = values[0][endpoint][cluster][attribute]
print(f"EP{endpoint}/{attribute} = {v}")
for endpoint in values:
for cluster in values[endpoint]:
for attribute in values[endpoint][cluster]:
v = values[endpoint][cluster][attribute]
print(f"EP{endpoint}/{cluster}/{attribute} = {v}")
if (isinstance(v, ValueDecodeFailure)):
if _IgnoreAttributeDecodeFailure((endpoint, cluster, attribute)):
print(
f"Ignoring attribute decode failure for path {endpoint}/{attribute}")
f"Ignoring attribute decode failure for path {endpoint}/{cluster}/{attribute}")
else:
raise AssertionError(
f"Cannot decode value for path {endpoint}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'")
f"Cannot decode value for path {endpoint}/{cluster}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'")

for endpoint in values[1]:
for cluster in values[1][endpoint]:
for attribute in values[1][endpoint][cluster]:
v = values[1][endpoint][cluster][attribute]
print(f"EP{endpoint}/{attribute} version = {v}")
if (isinstance(v, ValueDecodeFailure)):
if _IgnoreAttributeDecodeFailure((endpoint, cluster, attribute)):
print(
f"Ignoring attribute version decode failure for path {endpoint}/{attribute}")
else:
raise AssertionError(
f"Cannot decode value for path {endpoint}/{attribute}, got error: '{str(v.Reason)}', raw TLV data: '{v.TLVValue}'")
for endpoint in values:
for cluster in values[endpoint]:
v = values[endpoint][cluster].get(DataVersion, None)
print(f"EP{endpoint}/{cluster} version = {v}")
if v is None:
raise AssertionError(
f"Cannot get data version for path {endpoint}/{cluster}")


def _AssumeEventsDecodeSuccess(values):
Expand Down Expand Up @@ -200,9 +195,10 @@ async def TestReadAttributeRequests(cls, devCtrl):
(0, Clusters.Basic.Attributes.HardwareVersion),
]
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req)
if ((0 not in res[0]) or (Clusters.Basic not in res[0][0]) or (len(res[0][0][Clusters.Basic]) != 3)):
if ((0 not in res) or (Clusters.Basic not in res[0]) or (len(res[0][Clusters.Basic]) != 4)):
# 3 attribute data + DataVersion
raise AssertionError(
f"Got back {len(res[0])} data items instead of 3")
f"Got back {len(res)} data items instead of 3")
VerifyDecodeSuccess(res)

logger.info("2: Reading Ex Cx A*")
Expand Down Expand Up @@ -237,32 +233,32 @@ async def TestReadAttributeRequests(cls, devCtrl):

res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req, returnClusterObject=True)
logger.info(
f"Basic Cluster - Label: {res[0][0][Clusters.Basic].productLabel}")
f"Basic Cluster - Label: {res[0][Clusters.Basic].productLabel}")
# TestCluster will be ValueDecodeError here, so we comment out the log below.
# Values are not expected to be ValueDecodeError for real clusters.
# logger.info(
# f"Test Cluster - Struct: {res[0][1][Clusters.TestCluster].structAttr}")
logger.info(f"Test Cluster: {res[0][1][Clusters.TestCluster]}")
# f"Test Cluster - Struct: {res[1][Clusters.TestCluster].structAttr}")
logger.info(f"Test Cluster: {res[1][Clusters.TestCluster]}")

logger.info("7: Reading Chunked List")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListLongOctetString)])
if res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListLongOctetString] != [b'0123456789abcdef' * 32] * 4:
if res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListLongOctetString] != [b'0123456789abcdef' * 32] * 4:
raise AssertionError("Unexpected read result")

logger.info("*: Getting current fabric index")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(0, Clusters.OperationalCredentials.Attributes.CurrentFabricIndex)])
fabricIndex = res[0][0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]
fabricIndex = res[0][Clusters.OperationalCredentials][Clusters.OperationalCredentials.Attributes.CurrentFabricIndex]

logger.info("8: Read without fabric filter")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListFabricScoped)], fabricFiltered=False)
if len(res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
if len(res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
raise AssertionError("Expect more elements in the response")

logger.info("9: Read with fabric filter")
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=[(1, Clusters.TestCluster.Attributes.ListFabricScoped)], fabricFiltered=True)
if len(res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
if len(res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped]) != 1:
raise AssertionError("Expect exact one element in the response")
if res[0][1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped][0].fabricIndex != fabricIndex:
if res[1][Clusters.TestCluster][Clusters.TestCluster.Attributes.ListFabricScoped][0].fabricIndex != fabricIndex:
raise AssertionError(
"Expect the fabric index matches the one current reading")

Expand Down Expand Up @@ -381,7 +377,7 @@ async def TestReadWriteAttributeRequestsWithVersion(cls, devCtrl):
]
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req)
VerifyDecodeSuccess(res)
data_version = res[1][0][40][1]
data_version = res[0][Clusters.Basic][DataVersion]

res = await devCtrl.WriteAttribute(nodeid=NODE_ID,
attributes=[
Expand All @@ -405,7 +401,7 @@ async def TestReadWriteAttributeRequestsWithVersion(cls, devCtrl):
]
res = await devCtrl.ReadAttribute(nodeid=NODE_ID, attributes=req, dataVersionFilters=[(0, Clusters.Basic, data_version)])
VerifyDecodeSuccess(res)
new_data_version = res[1][0][40][1]
new_data_version = res[0][Clusters.Basic][DataVersion]
if (data_version + 1) != new_data_version:
        raise AssertionError("Version mismatch happened.")

Expand Down
Loading

0 comments on commit 1105232

Please sign in to comment.