diff --git a/device.go b/device.go index 7e763e9..71cf750 100644 --- a/device.go +++ b/device.go @@ -15,6 +15,7 @@ type device struct { gw da.Gateway m *sync.RWMutex eda *enumeratedDeviceAttachment + dr *deviceRemoval // Mutable data, obtain lock first. deviceId uint16 @@ -23,14 +24,16 @@ type device struct { } func (d device) Capability(capability da.Capability) da.BasicCapability { - if capability == capabilities.EnumerateDeviceFlag { + switch capability { + case capabilities.EnumerateDeviceFlag: return d.eda + case capabilities.DeviceRemovalFlag: + return d.dr + default: + d.m.RLock() + defer d.m.RUnlock() + return d.capabilities[capability] } - - d.m.RLock() - defer d.m.RUnlock() - - return d.capabilities[capability] } func (d device) Gateway() da.Gateway { @@ -55,6 +58,10 @@ func (d device) Capabilities() []da.Capability { caps = append(caps, capabilities.EnumerateDeviceFlag) } + if d.dr != nil { + caps = append(caps, capabilities.DeviceRemovalFlag) + } + return caps } diff --git a/device_removal.go b/device_removal.go new file mode 100644 index 0000000..2da3135 --- /dev/null +++ b/device_removal.go @@ -0,0 +1,41 @@ +package zda + +import ( + "context" + "fmt" + "github.com/shimmeringbee/da" + "github.com/shimmeringbee/da/capabilities" + "github.com/shimmeringbee/logwrap" + "github.com/shimmeringbee/zigbee" +) + +type deviceRemoval struct { + node *node + logger logwrap.Logger + nodeRemover zigbee.NodeRemover +} + +func (z deviceRemoval) Capability() da.Capability { + return capabilities.DeviceRemovalFlag +} + +func (z deviceRemoval) Name() string { + return capabilities.StandardNames[z.Capability()] +} + +func (z deviceRemoval) Remove(ctx context.Context, removalType capabilities.RemovalType) error { + switch removalType { + case capabilities.Request: + z.logger.LogInfo(ctx, "Requesting removal of device from zigbee provider.", logwrap.Datum("IEEEAddress", z.node.address.String())) + return z.nodeRemover.RequestNodeLeave(ctx, z.node.address) + case capabilities.Force: + z.logger.LogInfo(ctx, "Requesting forced removal of device from zigbee provider.", logwrap.Datum("IEEEAddress", z.node.address.String())) + return z.nodeRemover.ForceNodeLeave(ctx, z.node.address) + default: + z.logger.LogError(ctx, "Request removal called with unknown removal type.", logwrap.Datum("IEEEAddress", z.node.address.String()), logwrap.Datum("removalType", removalType)) + return fmt.Errorf("remove device called with unknown removal type: %v", removalType) + } +} + +var _ capabilities.DeviceRemoval = (*deviceRemoval)(nil) +var _ da.BasicCapability = (*deviceRemoval)(nil) diff --git a/device_removal_test.go b/device_removal_test.go new file mode 100644 index 0000000..1df6530 --- /dev/null +++ b/device_removal_test.go @@ -0,0 +1,50 @@ +package zda + +import ( + "context" + "github.com/shimmeringbee/da/capabilities" + lw "github.com/shimmeringbee/logwrap" + "github.com/shimmeringbee/logwrap/impl/discard" + "github.com/shimmeringbee/zigbee" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" +) + +func TestZigbeeDeviceRemoval_Remove(t *testing.T) { + t.Run("successfully calls RequestNodeLeave on provider with devices ieee and Request flag", func(t *testing.T) { + mockProvider := zigbee.MockProvider{} + defer mockProvider.AssertExpectations(t) + + expectedIEEE := zigbee.GenerateLocalAdministeredIEEEAddress() + + zed := deviceRemoval{ + nodeRemover: &mockProvider, + logger: lw.New(discard.Discard()), + node: &node{address: expectedIEEE}, + } + + mockProvider.On("RequestNodeLeave", mock.Anything, expectedIEEE).Return(nil) + + err := zed.Remove(context.Background(), capabilities.Request) + assert.NoError(t, err) + }) + + t.Run("successfully calls ForceNodeLeave on provider with devices ieee and Force flag", func(t *testing.T) { + mockProvider := zigbee.MockProvider{} + defer mockProvider.AssertExpectations(t) + + expectedIEEE := zigbee.GenerateLocalAdministeredIEEEAddress() + + zed := deviceRemoval{ + nodeRemover: &mockProvider, + logger: lw.New(discard.Discard()), + node: &node{address: expectedIEEE}, + } + + mockProvider.On("ForceNodeLeave", mock.Anything, expectedIEEE).Return(nil) + + err := zed.Remove(context.Background(), capabilities.Force) + assert.NoError(t, err) + }) +} diff --git a/enumerate_device.go b/enumerate_device.go index 43a78d5..bd3c988 100644 --- a/enumerate_device.go +++ b/enumerate_device.go @@ -257,13 +257,6 @@ func (e enumerateDevice) updateNodeTable(n *node, inventoryDevices []inventoryDe d := e.dm.createNextDevice(n) d.m.Lock() d.deviceId = i.deviceId - d.eda = &enumeratedDeviceAttachment{ - node: n, - device: d, - ed: &e, - - m: &sync.RWMutex{}, - } d.m.Unlock() deviceIdMapping[i.deviceId] = d } diff --git a/enumerate_device_test.go b/enumerate_device_test.go index b2c1e7f..2caa8f3 100644 --- a/enumerate_device_test.go +++ b/enumerate_device_test.go @@ -328,15 +328,6 @@ func Test_enumerateDevice_updateNodeTable(t *testing.T) { assert.Equal(t, d, mapping[expectedDeviceId]) assert.Equal(t, expectedDeviceId, d.deviceId) - - c := d.Capability(capabilities.EnumerateDeviceFlag) - assert.NotNil(t, c) - - cc, ok := c.(*enumeratedDeviceAttachment) - assert.True(t, ok) - - assert.Equal(t, n, cc.node) - assert.Equal(t, d, cc.device) }) t.Run("returns an existing on in mapping if present", func(t *testing.T) { diff --git a/gateway.go b/gateway.go index 6bd9fd5..f91e6c9 100644 --- a/gateway.go +++ b/gateway.go @@ -6,6 +6,7 @@ import ( "github.com/shimmeringbee/da" "github.com/shimmeringbee/da/capabilities" "github.com/shimmeringbee/logwrap" + "github.com/shimmeringbee/zda/implcaps/factory" "github.com/shimmeringbee/zda/rules" "github.com/shimmeringbee/zigbee" "log" @@ -37,6 +38,21 @@ func New(baseCtx context.Context, p zigbee.Provider, r ruleExecutor) da.Gateway events: make(chan interface{}, 0xffff), } + gw.ed = &enumerateDevice{ + gw: gw, + dm: gw, + logger: logwrap.Logger{}, + nq: gw.provider, + zclReadFn: nil, + capabilityFactory: factory.Create, + } + + if gw.ruleExecutor != nil { + gw.ed.runRulesFn = gw.ruleExecutor.Execute + } + + gw.callbacks.Add(gw.ed.onNodeJoin) + gw.WithGoLogger(log.New(os.Stderr, "", log.LstdFlags)) return gw } @@ -60,6 +76,7 @@ type gateway struct { callbacks callbacks.AdderCaller ruleExecutor ruleExecutor + ed *enumerateDevice events chan interface{} } diff --git a/table.go b/table.go index e217dfd..381f172 100644 --- a/table.go +++ b/table.go @@ -53,10 +53,26 @@ func (g *gateway) createNextDevice(n *node) *device { subId := n._nextDeviceSubIdentifier() - return g._createDevice(n, IEEEAddressWithSubIdentifier{ + d := g._createDevice(n, IEEEAddressWithSubIdentifier{ IEEEAddress: n.address, SubIdentifier: subId, }) + + d.eda = &enumeratedDeviceAttachment{ + node: n, + device: d, + ed: g.ed, + + m: &sync.RWMutex{}, + } + + d.dr = &deviceRemoval{ + node: n, + logger: g.logger, + nodeRemover: g.provider, + } + + return d } func (g *gateway) _createDevice(n *node, addr IEEEAddressWithSubIdentifier) *device { diff --git a/table_test.go b/table_test.go index 4e67d28..2177674 100644 --- a/table_test.go +++ b/table_test.go @@ -2,6 +2,7 @@ package zda import ( "context" + "github.com/shimmeringbee/da/capabilities" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" "testing" @@ -84,6 +85,12 @@ func Test_gateway_createNextDevice(t *testing.T) { assert.Equal(t, uint8(0), d.address.SubIdentifier) assert.Equal(t, g, d.gw) + assert.NotNil(t, d.eda) + assert.NotNil(t, d.dr) + + assert.Contains(t, d.Capabilities(), capabilities.EnumerateDeviceFlag) + assert.Contains(t, d.Capabilities(), capabilities.DeviceRemovalFlag) + d = g.createNextDevice(n) assert.Equal(t, addr, d.address.IEEEAddress)