From 66656b6b91ad403bdc8fdf65e825064c414e4967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Fri, 20 Oct 2023 21:01:53 +0200 Subject: [PATCH] optimize fanin and fanout nodes to use less reflection --- pkg/pipeline/stream/fanin.go | 56 ++++++---- pkg/pipeline/stream/fanin_select.go | 168 ++++++++++++++++++++++++++++ pkg/pipeline/stream/fanout.go | 25 +++++ pkg/pipeline/stream/message.go | 3 +- 4 files changed, 228 insertions(+), 24 deletions(-) create mode 100644 pkg/pipeline/stream/fanin_select.go diff --git a/pkg/pipeline/stream/fanin.go b/pkg/pipeline/stream/fanin.go index 494ed8cdf..0e2bd4396 100644 --- a/pkg/pipeline/stream/fanin.go +++ b/pkg/pipeline/stream/fanin.go @@ -16,7 +16,6 @@ package stream import ( "context" - "reflect" ) type FaninNode struct { @@ -49,31 +48,14 @@ func (n *FaninNode) Run(ctx context.Context) error { n.running = false }() - cases := make([]reflect.SelectCase, len(n.in)+1) - cases[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())} - for i, ch := range n.in { - cases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} - } + trigger := n.trigger(ctx) for { - chosen, value, ok := reflect.Select(cases) - // ok will be true if the channel has not been closed. - if !ok { - if chosen == 0 { - // context is done - return ctx.Err() - } - // one of the in channels is closed, remove it from select case - cases = append(cases[:chosen], cases[chosen+1:]...) - if len(cases) == 1 { - // only context is left, we're done - return nil - } - continue + msg, err := trigger() + if err != nil || msg == nil { + return err } - msg := value.Interface().(*Message) - select { case <-ctx.Done(): return msg.Nack(ctx.Err(), n.ID()) @@ -82,6 +64,36 @@ func (n *FaninNode) Run(ctx context.Context) error { } } +func (n *FaninNode) trigger(ctx context.Context) func() (*Message, error) { + in := make([]<-chan *Message, len(n.in)) + copy(in, n.in) + + f := n.chooseSelectFunc(ctx, in) + + return func() (*Message, error) { + for { + chosen, msg, ok := f() + // ok will be true if the channel has not been closed. + if !ok { + if chosen == 0 { + // context is done + return nil, ctx.Err() + } + // one of the in channels is closed, remove it from select case + in = append(in[:chosen-1], in[chosen:]...) + if len(in) == 0 { + // only context is left, we're done + return nil, nil + } + + f = n.chooseSelectFunc(ctx, in) + continue // keep selecting with new select func + } + return msg, nil + } + } +} + func (n *FaninNode) Sub(in <-chan *Message) { n.in = append(n.in, in) } diff --git a/pkg/pipeline/stream/fanin_select.go b/pkg/pipeline/stream/fanin_select.go new file mode 100644 index 000000000..1270d616c --- /dev/null +++ b/pkg/pipeline/stream/fanin_select.go @@ -0,0 +1,168 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stream + +import ( + "context" + "reflect" +) + +func (n *FaninNode) chooseSelectFunc(ctx context.Context, in []<-chan *Message) func() (int, *Message, bool) { + switch len(in) { + case 1: + return func() (int, *Message, bool) { return n.select1(ctx, in[0]) } + case 2: + return func() (int, *Message, bool) { return n.select2(ctx, in[0], in[1]) } + case 3: + return func() (int, *Message, bool) { return n.select3(ctx, in[0], in[1], in[2]) } + case 4: + return func() (int, *Message, bool) { return n.select4(ctx, in[0], in[1], in[2], in[3]) } + case 5: + return func() (int, *Message, bool) { return n.select5(ctx, in[0], in[1], in[2], in[3], in[4]) } + case 6: + return func() (int, *Message, bool) { return n.select6(ctx, in[0], in[1], in[2], in[3], in[4], in[5]) } + default: + // use reflection for more channels + cases := make([]reflect.SelectCase, len(in)+1) + cases[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())} + for i, ch := range in { + cases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} + } + return func() (int, *Message, bool) { + chosen, value, ok := reflect.Select(cases) + if !ok { // a channel was closed + return chosen, nil, ok + } + return chosen, value.Interface().(*Message), ok + } + } +} + +func (*FaninNode) select1( + ctx context.Context, + c1 <-chan *Message, +) (int, *Message, bool) { + select { + case <-ctx.Done(): + return 0, nil, false + case val, ok := <-c1: + return 1, val, ok + } +} + +func (*FaninNode) select2( + ctx context.Context, + c1 <-chan *Message, + c2 <-chan *Message, +) (int, *Message, bool) { + select { + case <-ctx.Done(): + return 0, nil, false + case val, ok := <-c1: + return 1, val, ok + case val, ok := <-c2: + return 2, val, ok + } +} + +func (*FaninNode) select3( + ctx context.Context, + c1 <-chan *Message, + c2 <-chan *Message, + c3 <-chan *Message, +) (int, *Message, bool) { + select { + case <-ctx.Done(): + return 0, nil, false + case val, ok := <-c1: + return 1, val, ok + case val, ok := <-c2: + return 2, val, ok + case val, ok := <-c3: + return 3, val, ok + } +} + +func (*FaninNode) select4( + ctx context.Context, + c1 <-chan *Message, + c2 <-chan *Message, + c3 <-chan *Message, + c4 <-chan *Message, +) (int, *Message, bool) { + select { + case <-ctx.Done(): + return 0, nil, false + case val, ok := <-c1: + return 1, val, ok + case val, ok := <-c2: + return 2, val, ok + case val, ok := <-c3: + return 3, val, ok + case val, ok := <-c4: + return 4, val, ok + } +} + +func (*FaninNode) select5( + ctx context.Context, + c1 <-chan *Message, + c2 <-chan *Message, + c3 <-chan *Message, + c4 <-chan *Message, + c5 <-chan *Message, +) (int, *Message, bool) { + select { + case <-ctx.Done(): + return 0, nil, false + case val, ok := <-c1: + return 1, val, ok + case val, ok := <-c2: + return 2, val, ok + case val, ok := <-c3: + return 3, val, ok + case val, ok := <-c4: + return 4, val, ok + case val, ok := <-c5: + return 5, val, ok + } +} + +func (*FaninNode) select6( + ctx context.Context, + c1 <-chan *Message, + c2 <-chan *Message, + c3 <-chan *Message, + c4 <-chan *Message, + c5 <-chan *Message, + c6 <-chan *Message, +) (int, *Message, bool) { + select { + case <-ctx.Done(): + return 0, nil, false + case val, ok := <-c1: + return 1, val, ok + case val, ok := <-c2: + return 2, val, ok + case val, ok := <-c3: + return 3, val, ok + case val, ok := <-c4: + return 4, val, ok + case val, ok := <-c5: + return 5, val, ok + case val, ok := <-c6: + return 6, val, ok + } +} diff --git a/pkg/pipeline/stream/fanout.go b/pkg/pipeline/stream/fanout.go index cf1328d93..b916af721 100644 --- a/pkg/pipeline/stream/fanout.go +++ b/pkg/pipeline/stream/fanout.go @@ -54,6 +54,11 @@ func (n *FanoutNode) Run(ctx context.Context) error { n.running = false }() + if len(n.out) == 1 { + // shortcut if there's only 1 destination + return n.select1(ctx) + } + var wg sync.WaitGroup for { select { @@ -141,6 +146,26 @@ func (n *FanoutNode) Run(ctx context.Context) error { } } +func (n *FanoutNode) select1(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case msg, ok := <-n.in: + if !ok { + // pipeline closed + return nil + } + select { + case <-ctx.Done(): + return msg.Nack(ctx.Err(), n.ID()) + case n.out[0] <- msg: + // all good + } + } + } +} + // wrapAckHandler modifies the ack handler, so it's called with the original // message received by FanoutNode instead of the new message created by // FanoutNode. diff --git a/pkg/pipeline/stream/message.go b/pkg/pipeline/stream/message.go index d12717bed..26f740c2a 100644 --- a/pkg/pipeline/stream/message.go +++ b/pkg/pipeline/stream/message.go @@ -18,7 +18,6 @@ package stream import ( "context" - "fmt" "sync" "github.com/conduitio/conduit/pkg/foundation/cerrors" @@ -131,7 +130,7 @@ func (m *Message) init() { // ID returns a string representing a unique ID of this message. This is meant // only for logging purposes. func (m *Message) ID() string { - return fmt.Sprintf("%s/%s", m.SourceID, m.Record.Position) + return m.SourceID + "/" + string(m.Record.Position) } func (m *Message) ControlMessageType() ControlMessageType {