From 646edf548b9a76cdecb5c7e3f14a5590328cea6c Mon Sep 17 00:00:00 2001 From: Peter Wood Date: Mon, 27 Nov 2023 22:10:33 +0000 Subject: [PATCH] Minor rework of device data, create or destroy devices that shouldn't be present on the node. --- device.go | 2 + enumerate_device.go | 82 ++++++++++++++++++++++++++++++----- enumerate_device_test.go | 92 +++++++++++++++++++++++++++++++++++++--- node.go | 10 ++--- 4 files changed, 162 insertions(+), 24 deletions(-) diff --git a/device.go b/device.go index 0577a57..a7cccc4 100644 --- a/device.go +++ b/device.go @@ -14,7 +14,9 @@ type device struct { m *sync.RWMutex // Mutable data, obtain lock first. + deviceId uint16 capabilities []da.Capability + productData productData } func (d device) Gateway() da.Gateway { diff --git a/enumerate_device.go b/enumerate_device.go index 0ee1f94..9a66ba8 100644 --- a/enumerate_device.go +++ b/enumerate_device.go @@ -22,8 +22,14 @@ const ( EnumerationNetworkRetries = 5 ) +type deviceManager interface { + createNextDevice(*node) *device + removeDevice(IEEEAddressWithSubIdentifier) bool +} + type enumerateDevice struct { gw *gateway + dm deviceManager logger logwrap.Logger nq zigbee.NodeQuerier @@ -86,12 +92,16 @@ func (e enumerateDevice) enumerate(pctx context.Context, n *node) { return } - _ = e.splitInventoryToDevices(inv) + inventoryDevices := e.groupInventoryDevices(inv) + + _ = e.updateNodeTable(n, inventoryDevices) + } func (e enumerateDevice) interrogateNode(ctx context.Context, n *node) (inventory, error) { - var inv inventory - inv.endpoints = make(map[zigbee.Endpoint]endpointDetails) + inv := inventory{ + endpoints: make(map[zigbee.Endpoint]endpointDetails), + } e.logger.LogTrace(ctx, "Enumerating node description.") if nd, err := retry.RetryWithValue(ctx, EnumerationNetworkTimeout, EnumerationNetworkRetries, func(ctx context.Context) (zigbee.NodeDescription, error) { @@ -188,24 +198,24 @@ func (e enumerateDevice) runRules(inv inventory) (inventory, error) { } type inventoryDevice struct { - DeviceId uint16 - Endpoints []endpointDetails + deviceId uint16 + endpoints []endpointDetails } -func (e enumerateDevice) splitInventoryToDevices(inv inventory) []inventoryDevice { +func (e enumerateDevice) groupInventoryDevices(inv inventory) []inventoryDevice { devices := map[uint16]*inventoryDevice{} for _, ep := range inv.endpoints { invDev := devices[ep.description.DeviceID] if invDev == nil { - invDev = &inventoryDevice{DeviceId: ep.description.DeviceID} + invDev = &inventoryDevice{deviceId: ep.description.DeviceID} devices[ep.description.DeviceID] = invDev } - invDev.Endpoints = append(invDev.Endpoints, ep) + invDev.endpoints = append(invDev.endpoints, ep) - sort.Slice(invDev.Endpoints, func(i, j int) bool { - return invDev.Endpoints[i].description.Endpoint < invDev.Endpoints[j].description.Endpoint + sort.Slice(invDev.endpoints, func(i, j int) bool { + return invDev.endpoints[i].description.Endpoint < invDev.endpoints[j].description.Endpoint }) } @@ -215,10 +225,60 @@ func (e enumerateDevice) splitInventoryToDevices(inv inventory) []inventoryDevic } sort.Slice(outDevices, func(i, j int) bool { - return outDevices[i].DeviceId < outDevices[j].DeviceId + return outDevices[i].deviceId < outDevices[j].deviceId }) return outDevices } +func (e enumerateDevice) updateNodeTable(n *node, inventoryDevices []inventoryDevice) map[uint16]*device { + deviceIdMapping := map[uint16]*device{} + + /* Find existing devices that match the deviceId. */ + n.m.RLock() + for _, i := range inventoryDevices { + for _, d := range n.device { + d.m.RLock() + devId := d.deviceId + d.m.RUnlock() + + if devId == i.deviceId { + deviceIdMapping[i.deviceId] = d + break + } + } + } + n.m.RUnlock() + + /* Create new devices for those that are missing. */ + for _, i := range inventoryDevices { + if _, found := deviceIdMapping[i.deviceId]; !found { + d := e.dm.createNextDevice(n) + d.m.Lock() + d.deviceId = i.deviceId + d.m.Unlock() + deviceIdMapping[i.deviceId] = d + } + } + + /* Report devices that should no longer be present on node. */ + var devicesToRemove []IEEEAddressWithSubIdentifier + + n.m.RLock() + for _, d := range n.device { + d.m.RLock() + if _, found := deviceIdMapping[d.deviceId]; !found { + devicesToRemove = append(devicesToRemove, d.address) + } + d.m.RUnlock() + } + n.m.RUnlock() + + for _, d := range devicesToRemove { + e.dm.removeDevice(d) + } + + return deviceIdMapping +} + var _ capabilities.EnumerateDevice = (*enumerateDevice)(nil) diff --git a/enumerate_device_test.go b/enumerate_device_test.go index ff21915..a8b6977 100644 --- a/enumerate_device_test.go +++ b/enumerate_device_test.go @@ -223,7 +223,7 @@ func Test_enumerateDevice_runRules(t *testing.T) { }) } -func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) { +func Test_enumerateDevice_groupInventoryDevices(t *testing.T) { t.Run("aggregates into devices and sorts endpoints and device ids", func(t *testing.T) { inv := inventory{ endpoints: map[zigbee.Endpoint]endpointDetails{ @@ -250,8 +250,8 @@ func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) { expected := []inventoryDevice{ { - DeviceId: 1, - Endpoints: []endpointDetails{ + deviceId: 1, + endpoints: []endpointDetails{ { description: zigbee.EndpointDescription{ Endpoint: 1, @@ -267,8 +267,8 @@ func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) { }, }, { - DeviceId: 2, - Endpoints: []endpointDetails{ + deviceId: 2, + endpoints: []endpointDetails{ { description: zigbee.EndpointDescription{ Endpoint: 2, @@ -280,8 +280,88 @@ func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) { } ed := enumerateDevice{logger: logwrap.New(discard.Discard())} - actual := ed.splitInventoryToDevices(inv) + actual := ed.groupInventoryDevices(inv) assert.Equal(t, expected, actual) }) } + +type mockDeviceManager struct { + mock.Mock +} + +func (m *mockDeviceManager) createNextDevice(n *node) *device { + args := m.Called(n) + return args.Get(0).(*device) +} + +func (m *mockDeviceManager) removeDevice(i IEEEAddressWithSubIdentifier) bool { + args := m.Called(i) + return args.Bool(0) +} + +func Test_enumerateDevice_updateNodeTable(t *testing.T) { + t.Run("creates new device if missing from node", func(t *testing.T) { + mdm := &mockDeviceManager{} + defer mdm.AssertExpectations(t) + + ed := enumerateDevice{logger: logwrap.New(discard.Discard()), dm: mdm} + n := &node{m: &sync.RWMutex{}} + d := &device{m: &sync.RWMutex{}} + + mdm.On("createNextDevice", n).Return(d) + + expectedDeviceId := uint16(0x2000) + + id := []inventoryDevice{ + { + deviceId: expectedDeviceId, + }, + } + + mapping := ed.updateNodeTable(n, id) + + assert.Equal(t, d, mapping[expectedDeviceId]) + assert.Equal(t, expectedDeviceId, d.deviceId) + }) + + t.Run("returns an existing on in mapping if present", func(t *testing.T) { + mdm := &mockDeviceManager{} + defer mdm.AssertExpectations(t) + + existingDeviceId := uint16(0x2000) + + ed := enumerateDevice{logger: logwrap.New(discard.Discard()), dm: mdm} + d := &device{m: &sync.RWMutex{}, deviceId: existingDeviceId} + n := &node{m: &sync.RWMutex{}, device: map[uint8]*device{0: d}} + + id := []inventoryDevice{ + { + deviceId: existingDeviceId, + }, + } + + mapping := ed.updateNodeTable(n, id) + + assert.Equal(t, d, mapping[existingDeviceId]) + assert.Equal(t, existingDeviceId, d.deviceId) + }) + + t.Run("removes an device that should not be present", func(t *testing.T) { + mdm := &mockDeviceManager{} + defer mdm.AssertExpectations(t) + + unwantedDeviceId := uint16(0x2000) + address := IEEEAddressWithSubIdentifier{IEEEAddress: zigbee.GenerateLocalAdministeredIEEEAddress(), SubIdentifier: 0} + + mdm.On("removeDevice", address).Return(true) + + ed := enumerateDevice{logger: logwrap.New(discard.Discard()), dm: mdm} + d := &device{m: &sync.RWMutex{}, deviceId: unwantedDeviceId, address: address} + n := &node{m: &sync.RWMutex{}, device: map[uint8]*device{0: d}} + + mapping := ed.updateNodeTable(n, nil) + + assert.Nil(t, mapping[unwantedDeviceId]) + }) +} diff --git a/node.go b/node.go index ba31956..4b5d7ff 100644 --- a/node.go +++ b/node.go @@ -73,16 +73,12 @@ type node struct { m *sync.RWMutex // Thread safe data. - sequence chan uint8 + sequence chan uint8 + enumerationSem *semaphore.Weighted // Mutable data, obtain lock first. - device map[uint8]*device - + device map[uint8]*device useAPSAck bool - - // Enumeration data. - enumerationSem *semaphore.Weighted - inventory inventory } func makeTransactionSequence() chan uint8 {