Skip to content

Commit

Permalink
Add DeviceRemoval, move creation of all device capabilties to table.go.
Browse files Browse the repository at this point in the history
  • Loading branch information
pwood committed Dec 31, 2023
1 parent 61623c3 commit 3a08987
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 23 deletions.
19 changes: 13 additions & 6 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type device struct {
gw da.Gateway
m *sync.RWMutex
eda *enumeratedDeviceAttachment
dr *deviceRemoval

// Mutable data, obtain lock first.
deviceId uint16
Expand All @@ -23,14 +24,16 @@ type device struct {
}

func (d device) Capability(capability da.Capability) da.BasicCapability {
if capability == capabilities.EnumerateDeviceFlag {
switch capability {
case capabilities.EnumerateDeviceFlag:
return d.eda
case capabilities.DeviceRemovalFlag:
return d.dr
default:
d.m.RLock()
defer d.m.RUnlock()
return d.capabilities[capability]
}

d.m.RLock()
defer d.m.RUnlock()

return d.capabilities[capability]
}

func (d device) Gateway() da.Gateway {
Expand All @@ -55,6 +58,10 @@ func (d device) Capabilities() []da.Capability {
caps = append(caps, capabilities.EnumerateDeviceFlag)
}

if d.dr != nil {
caps = append(caps, capabilities.DeviceRemovalFlag)
}

return caps
}

Expand Down
41 changes: 41 additions & 0 deletions device_removal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package zda

import (
"context"
"fmt"
"github.com/shimmeringbee/da"
"github.com/shimmeringbee/da/capabilities"
"github.com/shimmeringbee/logwrap"
"github.com/shimmeringbee/zigbee"
)

type deviceRemoval struct {
node *node
logger logwrap.Logger
nodeRemover zigbee.NodeRemover
}

func (z deviceRemoval) Capability() da.Capability {
return capabilities.DeviceRemovalFlag
}

func (z deviceRemoval) Name() string {
return capabilities.StandardNames[z.Capability()]
}

func (z deviceRemoval) Remove(ctx context.Context, removalType capabilities.RemovalType) error {
switch removalType {
case capabilities.Request:
z.logger.LogInfo(ctx, "Requesting removal of device from zigbee provider.", logwrap.Datum("IEEEAddress", z.node.address.String()))
return z.nodeRemover.RequestNodeLeave(ctx, z.node.address)
case capabilities.Force:
z.logger.LogInfo(ctx, "Requesting forced removal of device from zigbee provider.", logwrap.Datum("IEEEAddress", z.node.address.String()))
return z.nodeRemover.ForceNodeLeave(ctx, z.node.address)
default:
z.logger.LogError(ctx, "Request removal called with unknown removal type.", logwrap.Datum("IEEEAddress", z.node.address.String()), logwrap.Datum("removalType", removalType))
return fmt.Errorf("remove device called with unknown removal type: %v", removalType)
}
}

var _ capabilities.DeviceRemoval = (*deviceRemoval)(nil)
var _ da.BasicCapability = (*deviceRemoval)(nil)
50 changes: 50 additions & 0 deletions device_removal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package zda

import (
"context"
"github.com/shimmeringbee/da/capabilities"
lw "github.com/shimmeringbee/logwrap"
"github.com/shimmeringbee/logwrap/impl/discard"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)

func TestZigbeeDeviceRemoval_Remove(t *testing.T) {
t.Run("successfully calls RequestNodeLeave on provider with devices ieee and Request flag", func(t *testing.T) {
mockProvider := zigbee.MockProvider{}
defer mockProvider.AssertExpectations(t)

expectedIEEE := zigbee.GenerateLocalAdministeredIEEEAddress()

zed := deviceRemoval{
nodeRemover: &mockProvider,
logger: lw.New(discard.Discard()),
node: &node{address: expectedIEEE},
}

mockProvider.On("RequestNodeLeave", mock.Anything, expectedIEEE).Return(nil)

err := zed.Remove(context.Background(), capabilities.Request)
assert.NoError(t, err)
})

t.Run("successfully calls ForceNodeLeave on provider with devices ieee and Force flag", func(t *testing.T) {
mockProvider := zigbee.MockProvider{}
defer mockProvider.AssertExpectations(t)

expectedIEEE := zigbee.GenerateLocalAdministeredIEEEAddress()

zed := deviceRemoval{
nodeRemover: &mockProvider,
logger: lw.New(discard.Discard()),
node: &node{address: expectedIEEE},
}

mockProvider.On("ForceNodeLeave", mock.Anything, expectedIEEE).Return(nil)

err := zed.Remove(context.Background(), capabilities.Force)
assert.NoError(t, err)
})
}
7 changes: 0 additions & 7 deletions enumerate_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,6 @@ func (e enumerateDevice) updateNodeTable(n *node, inventoryDevices []inventoryDe
d := e.dm.createNextDevice(n)
d.m.Lock()
d.deviceId = i.deviceId
d.eda = &enumeratedDeviceAttachment{
node: n,
device: d,
ed: &e,

m: &sync.RWMutex{},
}
d.m.Unlock()
deviceIdMapping[i.deviceId] = d
}
Expand Down
9 changes: 0 additions & 9 deletions enumerate_device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,6 @@ func Test_enumerateDevice_updateNodeTable(t *testing.T) {

assert.Equal(t, d, mapping[expectedDeviceId])
assert.Equal(t, expectedDeviceId, d.deviceId)

c := d.Capability(capabilities.EnumerateDeviceFlag)
assert.NotNil(t, c)

cc, ok := c.(*enumeratedDeviceAttachment)
assert.True(t, ok)

assert.Equal(t, n, cc.node)
assert.Equal(t, d, cc.device)
})

t.Run("returns an existing on in mapping if present", func(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/shimmeringbee/da"
"github.com/shimmeringbee/da/capabilities"
"github.com/shimmeringbee/logwrap"
"github.com/shimmeringbee/zda/implcaps/factory"
"github.com/shimmeringbee/zda/rules"
"github.com/shimmeringbee/zigbee"
"log"
Expand Down Expand Up @@ -37,6 +38,21 @@ func New(baseCtx context.Context, p zigbee.Provider, r ruleExecutor) da.Gateway
events: make(chan interface{}, 0xffff),
}

gw.ed = &enumerateDevice{
gw: gw,
dm: gw,
logger: logwrap.Logger{},
nq: gw.provider,
zclReadFn: nil,
capabilityFactory: factory.Create,
}

if gw.ruleExecutor != nil {
gw.ed.runRulesFn = gw.ruleExecutor.Execute
}

gw.callbacks.Add(gw.ed.onNodeJoin)

gw.WithGoLogger(log.New(os.Stderr, "", log.LstdFlags))
return gw
}
Expand All @@ -60,6 +76,7 @@ type gateway struct {
callbacks callbacks.AdderCaller
ruleExecutor ruleExecutor

ed *enumerateDevice
events chan interface{}
}

Expand Down
18 changes: 17 additions & 1 deletion table.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,26 @@ func (g *gateway) createNextDevice(n *node) *device {

subId := n._nextDeviceSubIdentifier()

return g._createDevice(n, IEEEAddressWithSubIdentifier{
d := g._createDevice(n, IEEEAddressWithSubIdentifier{
IEEEAddress: n.address,
SubIdentifier: subId,
})

d.eda = &enumeratedDeviceAttachment{
node: n,
device: d,
ed: g.ed,

m: &sync.RWMutex{},
}

d.dr = &deviceRemoval{
node: n,
logger: g.logger,
nodeRemover: g.provider,
}

return d
}

func (g *gateway) _createDevice(n *node, addr IEEEAddressWithSubIdentifier) *device {
Expand Down
7 changes: 7 additions & 0 deletions table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zda

import (
"context"
"github.com/shimmeringbee/da/capabilities"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"testing"
Expand Down Expand Up @@ -84,6 +85,12 @@ func Test_gateway_createNextDevice(t *testing.T) {
assert.Equal(t, uint8(0), d.address.SubIdentifier)
assert.Equal(t, g, d.gw)

assert.NotNil(t, d.eda)
assert.NotNil(t, d.dr)

assert.Contains(t, d.Capabilities(), capabilities.EnumerateDeviceFlag)
assert.Contains(t, d.Capabilities(), capabilities.DeviceRemovalFlag)

d = g.createNextDevice(n)

assert.Equal(t, addr, d.address.IEEEAddress)
Expand Down

0 comments on commit 3a08987

Please sign in to comment.