Skip to content

Commit

Permalink
Minor rework of device data, create or destroy devices that shouldn't…
Browse files Browse the repository at this point in the history
… be present on the node.
  • Loading branch information
pwood committed Nov 28, 2023
1 parent 75a82b2 commit e0be24f
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 24 deletions.
2 changes: 2 additions & 0 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ type device struct {
m *sync.RWMutex

// Mutable data, obtain lock first.
deviceId uint16
capabilities []da.Capability
productData productData
}

func (d device) Gateway() da.Gateway {
Expand Down
82 changes: 71 additions & 11 deletions enumerate_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@ const (
EnumerationNetworkRetries = 5
)

type deviceManager interface {
createNextDevice(*node) *device
removeDevice(IEEEAddressWithSubIdentifier) bool
}

type enumerateDevice struct {
gw *gateway
dm deviceManager
logger logwrap.Logger

nq zigbee.NodeQuerier
Expand Down Expand Up @@ -86,12 +92,16 @@ func (e enumerateDevice) enumerate(pctx context.Context, n *node) {
return
}

_ = e.splitInventoryToDevices(inv)
inventoryDevices := e.groupInventoryDevices(inv)

_ = e.updateNodeTable(n, inventoryDevices)

}

func (e enumerateDevice) interrogateNode(ctx context.Context, n *node) (inventory, error) {
var inv inventory
inv.endpoints = make(map[zigbee.Endpoint]endpointDetails)
inv := inventory{
endpoints: make(map[zigbee.Endpoint]endpointDetails),
}

e.logger.LogTrace(ctx, "Enumerating node description.")
if nd, err := retry.RetryWithValue(ctx, EnumerationNetworkTimeout, EnumerationNetworkRetries, func(ctx context.Context) (zigbee.NodeDescription, error) {
Expand Down Expand Up @@ -188,24 +198,24 @@ func (e enumerateDevice) runRules(inv inventory) (inventory, error) {
}

type inventoryDevice struct {
DeviceId uint16
Endpoints []endpointDetails
deviceId uint16
endpoints []endpointDetails
}

func (e enumerateDevice) splitInventoryToDevices(inv inventory) []inventoryDevice {
func (e enumerateDevice) groupInventoryDevices(inv inventory) []inventoryDevice {
devices := map[uint16]*inventoryDevice{}

for _, ep := range inv.endpoints {
invDev := devices[ep.description.DeviceID]
if invDev == nil {
invDev = &inventoryDevice{DeviceId: ep.description.DeviceID}
invDev = &inventoryDevice{deviceId: ep.description.DeviceID}
devices[ep.description.DeviceID] = invDev
}

invDev.Endpoints = append(invDev.Endpoints, ep)
invDev.endpoints = append(invDev.endpoints, ep)

sort.Slice(invDev.Endpoints, func(i, j int) bool {
return invDev.Endpoints[i].description.Endpoint < invDev.Endpoints[j].description.Endpoint
sort.Slice(invDev.endpoints, func(i, j int) bool {
return invDev.endpoints[i].description.Endpoint < invDev.endpoints[j].description.Endpoint
})
}

Expand All @@ -215,10 +225,60 @@ func (e enumerateDevice) splitInventoryToDevices(inv inventory) []inventoryDevic
}

sort.Slice(outDevices, func(i, j int) bool {
return outDevices[i].DeviceId < outDevices[j].DeviceId
return outDevices[i].deviceId < outDevices[j].deviceId
})

return outDevices
}

func (e enumerateDevice) updateNodeTable(n *node, inventoryDevices []inventoryDevice) map[uint16]*device {
deviceIdMapping := map[uint16]*device{}

/* Find existing devices that match the deviceId. */
n.m.RLock()
for _, i := range inventoryDevices {
for _, d := range n.device {
d.m.RLock()
devId := d.deviceId
d.m.RUnlock()

if devId == i.deviceId {
deviceIdMapping[i.deviceId] = d
break
}
}
}
n.m.RUnlock()

/* Create new devices for those that are missing. */
for _, i := range inventoryDevices {
if _, found := deviceIdMapping[i.deviceId]; !found {
d := e.dm.createNextDevice(n)
d.m.Lock()
d.deviceId = i.deviceId
d.m.Unlock()
deviceIdMapping[i.deviceId] = d
}
}

/* Report devices that should no longer be present on node. */
var devicesToRemove []IEEEAddressWithSubIdentifier

n.m.RLock()
for _, d := range n.device {
d.m.RLock()
if _, found := deviceIdMapping[d.deviceId]; !found {
devicesToRemove = append(devicesToRemove, d.address)
}
d.m.RUnlock()
}
n.m.RUnlock()

for _, d := range devicesToRemove {
e.dm.removeDevice(d)
}

return deviceIdMapping
}

var _ capabilities.EnumerateDevice = (*enumerateDevice)(nil)
92 changes: 86 additions & 6 deletions enumerate_device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func Test_enumerateDevice_runRules(t *testing.T) {
})
}

func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) {
func Test_enumerateDevice_groupInventoryDevices(t *testing.T) {
t.Run("aggregates into devices and sorts endpoints and device ids", func(t *testing.T) {
inv := inventory{
endpoints: map[zigbee.Endpoint]endpointDetails{
Expand All @@ -250,8 +250,8 @@ func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) {

expected := []inventoryDevice{
{
DeviceId: 1,
Endpoints: []endpointDetails{
deviceId: 1,
endpoints: []endpointDetails{
{
description: zigbee.EndpointDescription{
Endpoint: 1,
Expand All @@ -267,8 +267,8 @@ func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) {
},
},
{
DeviceId: 2,
Endpoints: []endpointDetails{
deviceId: 2,
endpoints: []endpointDetails{
{
description: zigbee.EndpointDescription{
Endpoint: 2,
Expand All @@ -280,8 +280,88 @@ func Test_enumerateDevice_splitInventoryToDevices(t *testing.T) {
}

ed := enumerateDevice{logger: logwrap.New(discard.Discard())}
actual := ed.splitInventoryToDevices(inv)
actual := ed.groupInventoryDevices(inv)

assert.Equal(t, expected, actual)
})
}

type mockDeviceManager struct {
mock.Mock
}

func (m *mockDeviceManager) createNextDevice(n *node) *device {
args := m.Called(n)
return args.Get(0).(*device)
}

func (m *mockDeviceManager) removeDevice(i IEEEAddressWithSubIdentifier) bool {
args := m.Called(i)
return args.Bool(0)
}

func Test_enumerateDevice_updateNodeTable(t *testing.T) {
t.Run("creates new device if missing from node", func(t *testing.T) {
mdm := &mockDeviceManager{}
defer mdm.AssertExpectations(t)

ed := enumerateDevice{logger: logwrap.New(discard.Discard()), dm: mdm}
n := &node{m: &sync.RWMutex{}}
d := &device{m: &sync.RWMutex{}}

mdm.On("createNextDevice", n).Return(d)

expectedDeviceId := uint16(0x2000)

id := []inventoryDevice{
{
deviceId: expectedDeviceId,
},
}

mapping := ed.updateNodeTable(n, id)

assert.Equal(t, d, mapping[expectedDeviceId])
assert.Equal(t, expectedDeviceId, d.deviceId)
})

t.Run("returns an existing on in mapping if present", func(t *testing.T) {
mdm := &mockDeviceManager{}
defer mdm.AssertExpectations(t)

existingDeviceId := uint16(0x2000)

ed := enumerateDevice{logger: logwrap.New(discard.Discard()), dm: mdm}
d := &device{m: &sync.RWMutex{}, deviceId: existingDeviceId}
n := &node{m: &sync.RWMutex{}, device: map[uint8]*device{0: d}}

id := []inventoryDevice{
{
deviceId: existingDeviceId,
},
}

mapping := ed.updateNodeTable(n, id)

assert.Equal(t, d, mapping[existingDeviceId])
assert.Equal(t, existingDeviceId, d.deviceId)
})

t.Run("removes an device that should not be present", func(t *testing.T) {
mdm := &mockDeviceManager{}
defer mdm.AssertExpectations(t)

unwantedDeviceId := uint16(0x2000)
address := IEEEAddressWithSubIdentifier{IEEEAddress: zigbee.GenerateLocalAdministeredIEEEAddress(), SubIdentifier: 0}

mdm.On("removeDevice", address).Return(true)

ed := enumerateDevice{logger: logwrap.New(discard.Discard()), dm: mdm}
d := &device{m: &sync.RWMutex{}, deviceId: unwantedDeviceId, address: address}
n := &node{m: &sync.RWMutex{}, device: map[uint8]*device{0: d}}

mapping := ed.updateNodeTable(n, nil)

assert.Nil(t, mapping[unwantedDeviceId])
})
}
10 changes: 3 additions & 7 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,12 @@ type node struct {
m *sync.RWMutex

// Thread safe data.
sequence chan uint8
sequence chan uint8
enumerationSem *semaphore.Weighted

// Mutable data, obtain lock first.
device map[uint8]*device

device map[uint8]*device
useAPSAck bool

// Enumeration data.
enumerationSem *semaphore.Weighted
inventory inventory
}

func makeTransactionSequence() chan uint8 {
Expand Down

0 comments on commit e0be24f

Please sign in to comment.