Skip to content

Commit

Permalink
Refactor ZCL communicator, remove Global struct, add interface for ea…
Browse files Browse the repository at this point in the history
…sy mocks, make match Ids process unique.
  • Loading branch information
pwood committed Apr 17, 2024
1 parent 732be4c commit 8400e0c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
68 changes: 30 additions & 38 deletions communicator/communicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type MessageWithSource struct {
Message zcl.Message
}

var matchId = new(uint64)

type Matcher func(address zigbee.IEEEAddress, appMsg zigbee.ApplicationMessage, zclMessage zcl.Message) bool

func AddressAndSequenceMatch(matchAddress zigbee.IEEEAddress, matchSequence uint8) Matcher {
Expand All @@ -24,40 +26,38 @@ func AddressAndSequenceMatch(matchAddress zigbee.IEEEAddress, matchSequence uint
}
}

func (c *Communicator) NewMatch(matcher Matcher, callback func(source MessageWithSource)) Match {
func NewMatch(matcher Matcher, callback func(source MessageWithSource)) Match {
return Match{
Id: atomic.AddUint64(c.matchId, 1),
Matcher: matcher,
Callback: callback,
id: atomic.AddUint64(matchId, 1),
matcher: matcher,
callback: callback,
}
}

type Match struct {
Id uint64
Matcher Matcher
Callback func(source MessageWithSource)
id uint64
matcher Matcher
callback func(source MessageWithSource)
}

type Communicator struct {
type communicator struct {
Provider zigbee.Provider
CommandRegistry *zcl.CommandRegistry

mutex *sync.RWMutex
matches map[uint64]Match
matchId *uint64
}

func NewCommunicator(provider zigbee.Provider, registry *zcl.CommandRegistry) *Communicator {
return &Communicator{
func NewCommunicator(provider zigbee.Provider, registry *zcl.CommandRegistry) Communicator {
return &communicator{
Provider: provider,
CommandRegistry: registry,
mutex: &sync.RWMutex{},
matches: map[uint64]Match{},
matchId: new(uint64),
}
}

func (c *Communicator) ProcessIncomingMessage(msg zigbee.NodeIncomingMessageEvent) error {
func (c *communicator) ProcessIncomingMessage(msg zigbee.NodeIncomingMessageEvent) error {
message, err := c.CommandRegistry.Unmarshal(msg.ApplicationMessage)

if err != nil {
Expand All @@ -68,8 +68,8 @@ func (c *Communicator) ProcessIncomingMessage(msg zigbee.NodeIncomingMessageEven
defer c.mutex.RUnlock()

for _, match := range c.matches {
if match.Matcher(msg.IEEEAddress, msg.ApplicationMessage, message) {
go match.Callback(MessageWithSource{
if match.matcher(msg.IEEEAddress, msg.ApplicationMessage, message) {
go match.callback(MessageWithSource{
SourceAddress: msg.IEEEAddress,
Message: message,
})
Expand All @@ -79,21 +79,21 @@ func (c *Communicator) ProcessIncomingMessage(msg zigbee.NodeIncomingMessageEven
return nil
}

func (c *Communicator) AddCallback(match Match) {
func (c *communicator) RegisterMatch(match Match) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.matches[match.Id] = match
c.matches[match.id] = match
}

func (c *Communicator) RemoveCallback(match Match) {
func (c *communicator) UnregisterMatch(match Match) {
c.mutex.Lock()
defer c.mutex.Unlock()

delete(c.matches, match.Id)
delete(c.matches, match.id)
}

func (c *Communicator) Request(ctx context.Context, address zigbee.IEEEAddress, requireAck bool, message zcl.Message) error {
func (c *communicator) Request(ctx context.Context, address zigbee.IEEEAddress, requireAck bool, message zcl.Message) error {
appMessage, err := c.CommandRegistry.Marshal(message)

if err != nil {
Expand All @@ -109,16 +109,16 @@ func (c *Communicator) Request(ctx context.Context, address zigbee.IEEEAddress,
return nil
}

func (c *Communicator) RequestResponse(ctx context.Context, address zigbee.IEEEAddress, requireAck bool, message zcl.Message) (zcl.Message, error) {
func (c *communicator) RequestResponse(ctx context.Context, address zigbee.IEEEAddress, requireAck bool, message zcl.Message) (zcl.Message, error) {
ch := make(chan zcl.Message, 1)

match := c.NewMatch(AddressAndSequenceMatch(address, message.TransactionSequence),
match := NewMatch(AddressAndSequenceMatch(address, message.TransactionSequence),
func(recvMessage MessageWithSource) {
ch <- recvMessage.Message
})

c.AddCallback(match)
defer c.RemoveCallback(match)
c.RegisterMatch(match)
defer c.UnregisterMatch(match)

if err := c.Request(ctx, address, requireAck, message); err != nil {
return zcl.Message{}, err
Expand All @@ -132,15 +132,7 @@ func (c *Communicator) RequestResponse(ctx context.Context, address zigbee.IEEEA
}
}

func (c *Communicator) Global() *GlobalCommunicator {
return &GlobalCommunicator{communicator: c}
}

type GlobalCommunicator struct {
communicator *Communicator
}

func (g *GlobalCommunicator) ReadAttributes(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributes []zcl.AttributeID) ([]global.ReadAttributeResponseRecord, error) {
func (c *communicator) ReadAttributes(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributes []zcl.AttributeID) ([]global.ReadAttributeResponseRecord, error) {
request := zcl.Message{
FrameType: zcl.FrameGlobal,
Direction: zcl.ClientToServer,
Expand All @@ -154,7 +146,7 @@ func (g *GlobalCommunicator) ReadAttributes(ctx context.Context, ieeeAddress zig
},
}

response, err := g.communicator.RequestResponse(ctx, ieeeAddress, requireAck, request)
response, err := c.RequestResponse(ctx, ieeeAddress, requireAck, request)

if err != nil {
return nil, err
Expand All @@ -177,7 +169,7 @@ func ReadResponsesToMap(recs []global.ReadAttributeResponseRecord) map[zcl.Attri
return m
}

func (g *GlobalCommunicator) WriteAttributes(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributes map[zcl.AttributeID]zcl.AttributeDataTypeValue) ([]global.WriteAttributesResponseRecord, error) {
func (c *communicator) WriteAttributes(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributes map[zcl.AttributeID]zcl.AttributeDataTypeValue) ([]global.WriteAttributesResponseRecord, error) {
var records []global.WriteAttributesRecord

for k, v := range attributes {
Expand All @@ -200,7 +192,7 @@ func (g *GlobalCommunicator) WriteAttributes(ctx context.Context, ieeeAddress zi
},
}

response, err := g.communicator.RequestResponse(ctx, ieeeAddress, requireAck, request)
response, err := c.RequestResponse(ctx, ieeeAddress, requireAck, request)

if err != nil {
return nil, err
Expand All @@ -222,7 +214,7 @@ func WriteResponsesToMap(recs []global.WriteAttributesResponseRecord) map[zcl.At
return m
}

func (g *GlobalCommunicator) ConfigureReporting(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributeId zcl.AttributeID, dataType zcl.AttributeDataType, minimumReportingInterval uint16, maximumReportingInterval uint16, reportableChange interface{}) error {
func (c *communicator) ConfigureReporting(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributeId zcl.AttributeID, dataType zcl.AttributeDataType, minimumReportingInterval uint16, maximumReportingInterval uint16, reportableChange interface{}) error {
request := zcl.Message{
FrameType: zcl.FrameGlobal,
Direction: zcl.ClientToServer,
Expand All @@ -246,7 +238,7 @@ func (g *GlobalCommunicator) ConfigureReporting(ctx context.Context, ieeeAddress
},
}

response, err := g.communicator.RequestResponse(ctx, ieeeAddress, requireAck, request)
response, err := c.RequestResponse(ctx, ieeeAddress, requireAck, request)

if err != nil {
return err
Expand Down
9 changes: 3 additions & 6 deletions communicator/communicator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ func TestCommunicator_GlobalReadAttributes(t *testing.T) {
ieee := zigbee.IEEEAddress(0x0102030405060708)

c := NewCommunicator(mockProvider, cr)
g := c.Global()

expectedValue := "value"
clusterId := zigbee.ClusterID(0x1223)
Expand Down Expand Up @@ -416,7 +415,7 @@ func TestCommunicator_GlobalReadAttributes(t *testing.T) {
})
})

resp, err := g.ReadAttributes(context.Background(), ieee, true, clusterId, zigbee.NoManufacturer, srcEndpoint, destEndpoint, transactionSequence, []zcl.AttributeID{0x0004, 0x0005})
resp, err := c.ReadAttributes(context.Background(), ieee, true, clusterId, zigbee.NoManufacturer, srcEndpoint, destEndpoint, transactionSequence, []zcl.AttributeID{0x0004, 0x0005})
assert.NoError(t, err)
assert.Equal(t, expectedResponse.Records, resp)
})
Expand All @@ -431,7 +430,6 @@ func TestCommunicator_GlobalWritesAttributes(t *testing.T) {
ieee := zigbee.IEEEAddress(0x0102030405060708)

c := NewCommunicator(mockProvider, cr)
g := c.Global()

clusterId := zigbee.ClusterID(0x1223)

Expand Down Expand Up @@ -511,7 +509,7 @@ func TestCommunicator_GlobalWritesAttributes(t *testing.T) {
})
})

resp, err := g.WriteAttributes(context.Background(), ieee, true, clusterId, zigbee.NoManufacturer, srcEndpoint, destEndpoint, transactionSequence, map[zcl.AttributeID]zcl.AttributeDataTypeValue{0x0004: {DataType: zcl.TypeUnsignedInt8, Value: uint(8)}})
resp, err := c.WriteAttributes(context.Background(), ieee, true, clusterId, zigbee.NoManufacturer, srcEndpoint, destEndpoint, transactionSequence, map[zcl.AttributeID]zcl.AttributeDataTypeValue{0x0004: {DataType: zcl.TypeUnsignedInt8, Value: uint(8)}})
assert.NoError(t, err)
assert.Equal(t, expectedResponse.Records, resp)
})
Expand All @@ -526,7 +524,6 @@ func TestCommunicator_GlobalConfigureReporting(t *testing.T) {
ieee := zigbee.IEEEAddress(0x0102030405060708)

c := NewCommunicator(mockProvider, cr)
g := c.Global()

clusterId := zigbee.ClusterID(0x1223)

Expand Down Expand Up @@ -615,7 +612,7 @@ func TestCommunicator_GlobalConfigureReporting(t *testing.T) {
})
})

err := g.ConfigureReporting(context.Background(), ieee, true, clusterId, zigbee.NoManufacturer, srcEndpoint, destEndpoint, transactionSequence, attributeId, dataType, minInterval, maxInterval, reportableChange)
err := c.ConfigureReporting(context.Background(), ieee, true, clusterId, zigbee.NoManufacturer, srcEndpoint, destEndpoint, transactionSequence, attributeId, dataType, minInterval, maxInterval, reportableChange)
assert.NoError(t, err)
})
}
22 changes: 22 additions & 0 deletions communicator/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package communicator

import (
"context"
"github.com/shimmeringbee/zcl"
"github.com/shimmeringbee/zcl/commands/global"
"github.com/shimmeringbee/zigbee"
)

type Communicator interface {
RegisterMatch(match Match)
UnregisterMatch(match Match)

ProcessIncomingMessage(msg zigbee.NodeIncomingMessageEvent) error

Request(ctx context.Context, address zigbee.IEEEAddress, requireAck bool, message zcl.Message) error
RequestResponse(ctx context.Context, address zigbee.IEEEAddress, requireAck bool, message zcl.Message) (zcl.Message, error)

ReadAttributes(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributes []zcl.AttributeID) ([]global.ReadAttributeResponseRecord, error)
WriteAttributes(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributes map[zcl.AttributeID]zcl.AttributeDataTypeValue) ([]global.WriteAttributesResponseRecord, error)
ConfigureReporting(ctx context.Context, ieeeAddress zigbee.IEEEAddress, requireAck bool, cluster zigbee.ClusterID, code zigbee.ManufacturerCode, sourceEndpoint zigbee.Endpoint, destEndpoint zigbee.Endpoint, transactionSequence uint8, attributeId zcl.AttributeID, dataType zcl.AttributeDataType, minimumReportingInterval uint16, maximumReportingInterval uint16, reportableChange interface{}) error
}

0 comments on commit 8400e0c

Please sign in to comment.