diff --git a/channel.go b/channel.go index 9b430f173..8bb31dc1e 100644 --- a/channel.go +++ b/channel.go @@ -6,6 +6,13 @@ import ( "github.com/bluenviron/gomavlib/v2/pkg/frame" "github.com/bluenviron/gomavlib/v2/pkg/message" + "github.com/bluenviron/gomavlib/v2/pkg/ringbuffer" +) + +const ( + // this is low in order to avoid accumulating messages + // when a channel is reconnecting + writeBufferSize = 8 ) func randomByte() (byte, error) { @@ -21,13 +28,13 @@ func randomByte() (byte, error) { type Channel struct { e Endpoint label string - rwc io.ReadWriteCloser + rwc io.Closer n *Node frw *frame.ReadWriter running bool + writeBuffer *ringbuffer.RingBuffer // in - write chan interface{} terminate chan struct{} } @@ -56,13 +63,18 @@ func newChannel(n *Node, e Endpoint, label string, rwc io.ReadWriteCloser) (*Cha return nil, err } + writeBuffer, err := ringbuffer.New(writeBufferSize) + if err != nil { + return nil, err + } + return &Channel{ e: e, label: label, rwc: rwc, n: n, frw: frw, - write: make(chan interface{}), + writeBuffer: writeBuffer, terminate: make(chan struct{}), }, nil } @@ -85,84 +97,27 @@ func (ch *Channel) run() { defer ch.n.channelsWg.Done() readerDone := make(chan struct{}) - go func() { - defer close(readerDone) - - // wait client here, in order to allow the writer goroutine to start - // and allow clients to write messages before starting listening to events - select { - case ch.n.events <- &EventChannelOpen{ch}: - case <-ch.n.terminate: - } - - for { - fr, err := ch.frw.Read() - if err != nil { - // ignore parse errors - if _, ok := err.(*frame.ReadError); ok { - select { - case ch.n.events <- &EventParseError{err, ch}: - case <-ch.n.terminate: - } - continue - } - return - } - - evt := &EventFrame{fr, ch} - - if ch.n.nodeStreamRequest != nil { - ch.n.nodeStreamRequest.onEventFrame(evt) - } - - select { - case ch.n.events <- evt: - case <-ch.n.terminate: - } - } - }() + go ch.runReader(readerDone) writerDone := make(chan struct{}) - go func() { - defer close(writerDone) - - for what := range ch.write { - switch wh := what.(type) { - case message.Message: - ch.frw.WriteMessage(wh) //nolint:errcheck - - case frame.Frame: - ch.frw.WriteFrame(wh) //nolint:errcheck - } - } - }() + go ch.runWriter(writerDone) select { case <-readerDone: - select { - case ch.n.events <- &EventChannelClose{ch}: - case <-ch.n.terminate: - } - - select { - case ch.n.channelClose <- ch: - case <-ch.n.terminate: - } + ch.n.pushEvent(&EventChannelClose{ch}) + ch.n.closeChannel(ch) <-ch.terminate - close(ch.write) + ch.writeBuffer.Close() <-writerDone ch.rwc.Close() case <-ch.terminate: - select { - case ch.n.events <- &EventChannelClose{ch}: - case <-ch.n.terminate: - } + ch.n.pushEvent(&EventChannelClose{ch}) - close(ch.write) + ch.writeBuffer.Close() <-writerDone ch.rwc.Close() @@ -170,6 +125,52 @@ func (ch *Channel) run() { } } +func (ch *Channel) runReader(readerDone chan struct{}) { + defer close(readerDone) + + // wait client here, in order to allow the writer goroutine to start + // and allow clients to write messages before starting listening to events + ch.n.pushEvent(&EventChannelOpen{ch}) + + for { + fr, err := ch.frw.Read() + if err != nil { + if _, ok := err.(*frame.ReadError); ok { + ch.n.pushEvent(&EventParseError{err, ch}) + continue + } + return + } + + evt := &EventFrame{fr, ch} + + if ch.n.nodeStreamRequest != nil { + ch.n.nodeStreamRequest.onEventFrame(evt) + } + + ch.n.pushEvent(evt) + } +} + +func (ch *Channel) runWriter(writerDone chan struct{}) { + defer close(writerDone) + + for { + what, ok := ch.writeBuffer.Pull() + if !ok { + return + } + + switch wh := what.(type) { + case message.Message: + ch.frw.WriteMessage(wh) //nolint:errcheck + + case frame.Frame: + ch.frw.WriteFrame(wh) //nolint:errcheck + } + } +} + // String implements fmt.Stringer. func (ch *Channel) String() string { return ch.label @@ -179,3 +180,7 @@ func (ch *Channel) String() string { func (ch *Channel) Endpoint() Endpoint { return ch.e } + +func (ch *Channel) write(what interface{}) { + ch.writeBuffer.Push(what) +} diff --git a/channel_accepter.go b/channel_accepter.go index eb5781d69..3132d35fe 100644 --- a/channel_accepter.go +++ b/channel_accepter.go @@ -22,10 +22,10 @@ func (ca *channelAccepter) close() { func (ca *channelAccepter) start() { ca.n.channelAcceptersWg.Add(1) - go ca.runSingle() + go ca.run() } -func (ca *channelAccepter) runSingle() { +func (ca *channelAccepter) run() { defer ca.n.channelAcceptersWg.Done() for { @@ -42,9 +42,6 @@ func (ca *channelAccepter) runSingle() { panic(fmt.Errorf("newChannel unexpected error: %s", err)) } - select { - case ca.n.channelNew <- ch: - case <-ca.n.terminate: - } + ca.n.newChannel(ch) } } diff --git a/node.go b/node.go index ae45bbf2a..c85eac81e 100644 --- a/node.go +++ b/node.go @@ -100,15 +100,15 @@ type Node struct { nodeStreamRequest *nodeStreamRequest // in - channelNew chan *Channel - channelClose chan *Channel - writeTo chan writeToReq - writeAll chan interface{} - writeExcept chan writeExceptReq + chNewChannel chan *Channel + chCloseChannel chan *Channel + chWriteTo chan writeToReq + chWriteAll chan interface{} + chWriteExcept chan writeExceptReq terminate chan struct{} // out - events chan Event + chEvent chan Event done chan struct{} } @@ -169,13 +169,13 @@ func NewNode(conf NodeConf) (*Node, error) { dialectRW: dialectRW, channelAccepters: make(map[*channelAccepter]struct{}), channels: make(map[*Channel]struct{}), - channelNew: make(chan *Channel), - channelClose: make(chan *Channel), - writeTo: make(chan writeToReq), - writeAll: make(chan interface{}), - writeExcept: make(chan writeExceptReq), + chNewChannel: make(chan *Channel), + chCloseChannel: make(chan *Channel), + chWriteTo: make(chan writeToReq), + chWriteAll: make(chan interface{}), + chWriteExcept: make(chan writeExceptReq), terminate: make(chan struct{}), - events: make(chan Event), + chEvent: make(chan Event), done: make(chan struct{}), } @@ -256,15 +256,15 @@ func (n *Node) run() { outer: for { select { - case ch := <-n.channelNew: + case ch := <-n.chNewChannel: n.channels[ch] = struct{}{} ch.start() - case ch := <-n.channelClose: + case ch := <-n.chCloseChannel: delete(n.channels, ch) ch.close() - case req := <-n.writeTo: + case req := <-n.chWriteTo: if _, ok := n.channels[req.ch]; !ok { continue } @@ -272,25 +272,25 @@ outer: var err error req.what, err = n.encodeMessage(req.what) if err == nil { - req.ch.write <- req.what + req.ch.write(req.what) } - case what := <-n.writeAll: + case what := <-n.chWriteAll: var err error what, err = n.encodeMessage(what) if err == nil { for ch := range n.channels { - ch.write <- what + ch.write(what) } } - case req := <-n.writeExcept: + case req := <-n.chWriteExcept: var err error req.what, err = n.encodeMessage(req.what) if err == nil { for ch := range n.channels { if ch != req.except { - ch.write <- req.what + ch.write(req.what) } } } @@ -318,7 +318,7 @@ outer: } n.channelsWg.Wait() - close(n.events) + close(n.chEvent) } // FixFrame recomputes the Frame checksum and signature. @@ -409,13 +409,13 @@ func (n *Node) encodeMessage(what interface{}) (interface{}, error) { // // See individual events for details. func (n *Node) Events() chan Event { - return n.events + return n.chEvent } // WriteMessageTo writes a message to given channel. func (n *Node) WriteMessageTo(channel *Channel, m message.Message) { select { - case n.writeTo <- writeToReq{channel, m}: + case n.chWriteTo <- writeToReq{channel, m}: case <-n.terminate: } } @@ -423,7 +423,7 @@ func (n *Node) WriteMessageTo(channel *Channel, m message.Message) { // WriteMessageAll writes a message to all channels. func (n *Node) WriteMessageAll(m message.Message) { select { - case n.writeAll <- m: + case n.chWriteAll <- m: case <-n.terminate: } } @@ -431,7 +431,7 @@ func (n *Node) WriteMessageAll(m message.Message) { // WriteMessageExcept writes a message to all channels except specified channel. func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) { select { - case n.writeExcept <- writeExceptReq{exceptChannel, m}: + case n.chWriteExcept <- writeExceptReq{exceptChannel, m}: case <-n.terminate: } } @@ -441,7 +441,7 @@ func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) { // since all frame fields must be filled manually. func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) { select { - case n.writeTo <- writeToReq{channel, fr}: + case n.chWriteTo <- writeToReq{channel, fr}: case <-n.terminate: } } @@ -451,7 +451,7 @@ func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) { // since all frame fields must be filled manually. func (n *Node) WriteFrameAll(fr frame.Frame) { select { - case n.writeAll <- fr: + case n.chWriteAll <- fr: case <-n.terminate: } } @@ -461,7 +461,29 @@ func (n *Node) WriteFrameAll(fr frame.Frame) { // since all frame fields must be filled manually. func (n *Node) WriteFrameExcept(exceptChannel *Channel, fr frame.Frame) { select { - case n.writeExcept <- writeExceptReq{exceptChannel, fr}: + case n.chWriteExcept <- writeExceptReq{exceptChannel, fr}: + case <-n.terminate: + } +} + +func (n *Node) pushEvent(evt Event) { + select { + case n.chEvent <- evt: + case <-n.terminate: + } +} + +func (n *Node) newChannel(ch *Channel) { + select { + case n.chNewChannel <- ch: + case <-n.terminate: + ch.close() + } +} + +func (n *Node) closeChannel(ch *Channel) { + select { + case n.chCloseChannel <- ch: case <-n.terminate: } } diff --git a/node_streamrequest.go b/node_streamrequest.go index 608397649..d5a49ea85 100644 --- a/node_streamrequest.go +++ b/node_streamrequest.go @@ -174,13 +174,10 @@ func (sr *nodeStreamRequest) onEventFrame(evt *EventFrame) { sr.n.WriteMessageTo(evt.Channel, m.Interface().(message.Message)) } - select { - case sr.n.events <- &EventStreamRequested{ + sr.n.pushEvent(&EventStreamRequested{ Channel: evt.Channel, SystemID: evt.SystemID(), ComponentID: evt.ComponentID(), - }: - case <-sr.n.terminate: - } + }) } } diff --git a/pkg/ringbuffer/event.go b/pkg/ringbuffer/event.go new file mode 100644 index 000000000..0b2bea5f1 --- /dev/null +++ b/pkg/ringbuffer/event.go @@ -0,0 +1,38 @@ +package ringbuffer + +import ( + "sync" +) + +type event struct { + mutex sync.Mutex + cond *sync.Cond + value bool +} + +func newEvent() *event { + cv := &event{} + cv.cond = sync.NewCond(&cv.mutex) + return cv +} + +func (cv *event) signal() { + func() { + cv.mutex.Lock() + defer cv.mutex.Unlock() + cv.value = true + }() + + cv.cond.Broadcast() +} + +func (cv *event) wait() { + cv.mutex.Lock() + defer cv.mutex.Unlock() + + if !cv.value { + cv.cond.Wait() + } + + cv.value = false +} diff --git a/pkg/ringbuffer/ringbuffer.go b/pkg/ringbuffer/ringbuffer.go new file mode 100644 index 000000000..8e9fb7ad7 --- /dev/null +++ b/pkg/ringbuffer/ringbuffer.go @@ -0,0 +1,77 @@ +// Package ringbuffer contains a ring buffer. +package ringbuffer + +import ( + "fmt" + "sync/atomic" + "unsafe" +) + +// RingBuffer is a ring buffer. +type RingBuffer struct { + size uint64 + readIndex uint64 + writeIndex uint64 + closed int64 + buffer []unsafe.Pointer + event *event +} + +// New allocates a RingBuffer. +func New(size uint64) (*RingBuffer, error) { + // when writeIndex overflows, if size is not a power of + // two, only a portion of the buffer is used. + if (size & (size - 1)) != 0 { + return nil, fmt.Errorf("size must be a power of two") + } + + return &RingBuffer{ + size: size, + readIndex: 1, + writeIndex: 0, + buffer: make([]unsafe.Pointer, size), + event: newEvent(), + }, nil +} + +// Close makes Pull() return false. +func (r *RingBuffer) Close() { + atomic.StoreInt64(&r.closed, 1) + r.event.signal() +} + +// Reset restores Pull() behavior after a Close(). +func (r *RingBuffer) Reset() { + for i := uint64(0); i < r.size; i++ { + atomic.SwapPointer(&r.buffer[i], nil) + } + atomic.SwapUint64(&r.writeIndex, 0) + r.readIndex = 1 + atomic.StoreInt64(&r.closed, 0) +} + +// Push pushes data at the end of the buffer. +func (r *RingBuffer) Push(data interface{}) { + writeIndex := atomic.AddUint64(&r.writeIndex, 1) + i := writeIndex % r.size + atomic.SwapPointer(&r.buffer[i], unsafe.Pointer(&data)) + r.event.signal() +} + +// Pull pulls data from the beginning of the buffer. +func (r *RingBuffer) Pull() (interface{}, bool) { + for { + i := r.readIndex % r.size + res := (*interface{})(atomic.SwapPointer(&r.buffer[i], nil)) + if res == nil { + if atomic.SwapInt64(&r.closed, 0) == 1 { + return nil, false + } + r.event.wait() + continue + } + + r.readIndex++ + return *res, true + } +} diff --git a/pkg/ringbuffer/ringbuffer_test.go b/pkg/ringbuffer/ringbuffer_test.go new file mode 100644 index 000000000..25a9635a9 --- /dev/null +++ b/pkg/ringbuffer/ringbuffer_test.go @@ -0,0 +1,147 @@ +package ringbuffer + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCreateError(t *testing.T) { + _, err := New(1000) + require.EqualError(t, err, "size must be a power of two") +} + +func TestPushBeforePull(t *testing.T) { + r, err := New(1024) + require.NoError(t, err) + defer r.Close() + + data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4) + + r.Push(data) + ret, ok := r.Pull() + require.Equal(t, true, ok) + require.Equal(t, data, ret) +} + +func TestPullBeforePush(t *testing.T) { + r, err := New(1024) + require.NoError(t, err) + defer r.Close() + + data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4) + + done := make(chan struct{}) + go func() { + defer close(done) + ret, ok := r.Pull() + require.Equal(t, true, ok) + require.Equal(t, data, ret) + }() + + time.Sleep(100 * time.Millisecond) + + r.Push(data) + <-done +} + +func TestClose(t *testing.T) { + r, err := New(1024) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + + _, ok := r.Pull() + require.Equal(t, true, ok) + + _, ok = r.Pull() + require.Equal(t, false, ok) + }() + + r.Push([]byte{0x01, 0x02, 0x03, 0x04}) + + r.Close() + <-done + + r.Reset() + + r.Push([]byte{0x05, 0x06, 0x07, 0x08}) + + _, ok := r.Pull() + require.Equal(t, true, ok) +} + +func BenchmarkPushPullContinuous(b *testing.B) { + r, _ := New(1024 * 8) + defer r.Close() + + data := make([]byte, 1024) + + for n := 0; n < b.N; n++ { + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 1024*8; i++ { + r.Push(data) + } + }() + + for i := 0; i < 1024*8; i++ { + r.Pull() + } + + <-done + } +} + +func BenchmarkPushPullPaused5(b *testing.B) { + r, _ := New(128) + defer r.Close() + + data := make([]byte, 1024) + + for n := 0; n < b.N; n++ { + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 128; i++ { + r.Push(data) + time.Sleep(5 * time.Millisecond) + } + }() + + for i := 0; i < 128; i++ { + r.Pull() + } + + <-done + } +} + +func BenchmarkPushPullPaused10(b *testing.B) { + r, _ := New(1024 * 8) + defer r.Close() + + data := make([]byte, 1024) + + for n := 0; n < b.N; n++ { + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 128; i++ { + r.Push(data) + time.Sleep(10 * time.Millisecond) + } + }() + + for i := 0; i < 128; i++ { + r.Pull() + } + + <-done + } +}