From 3916c76f1a9080941f3c474aa96c785e0a97b987 Mon Sep 17 00:00:00 2001 From: Shezad Khan Date: Mon, 13 May 2024 18:11:14 +0100 Subject: [PATCH 1/2] feat: Add weighted fair receiver --- decorators/wfr/receiver.go | 189 ++++++++++++++++++++++++++++++++ decorators/wfr/receiver_test.go | 1 + go.mod | 1 + go.sum | 31 +----- 4 files changed, 194 insertions(+), 28 deletions(-) create mode 100644 decorators/wfr/receiver.go create mode 100644 decorators/wfr/receiver_test.go diff --git a/decorators/wfr/receiver.go b/decorators/wfr/receiver.go new file mode 100644 index 0000000..dffc11b --- /dev/null +++ b/decorators/wfr/receiver.go @@ -0,0 +1,189 @@ +package wfr + +import ( + "context" + "time" + + pq "github.com/JimWen/gods-generic/queues/priorityqueue" + "github.com/JimWen/gods-generic/utils" + "github.com/zerofox-oss/go-msg" +) + +type weightedMessage struct { + msg *msg.Message + context context.Context + vFinish float64 + priority int + doneChan chan error +} + +// WeightedFairReceiver implements a "fair" receiver that processes messages based on their priority. +// See https://en.wikipedia.org/wiki/Fair_queuing for more information, about fairness. +// The concrete implmentation is based on the Weighted Fair Queuing algorithm. +// At a high level the algorithm works as follows: +// 1. Each message is asigned a virtual finish time. The virtual finish time is defined as the +// time of the last message processed for that priority plus a weigting factor. The higher the +// the weight the smaller the virtual finish time is for that message. +// 2. The message is then enqueued in a priority queue based on the virtual finish time. +// 3. Every "tick" we select the the message with the smallest virtual finish time and process it. +type WeightedFairReceiver struct { + + // underlying receiver + receiver msg.Receiver + + // time to wait for a message to arrive + // the longer you wait, the better the fairness + // the shorter the better the latency + // a good value should be around the mean processing + // time for your receiver. If in doubt 50ms may be good. + queueWaitTime time.Duration + + // weights for each priority level + // the higher the weight, the more messages + // will be processed for that priority + weights []float64 + + // max number of concurrent messages that can + // be processed at the same time by the receiver + maxConcurrent int + + receiveChan chan *weightedMessage + queue *pq.Queue[*weightedMessage] + startTime int + lastVFinish []float64 +} + +func (w *WeightedFairReceiver) WithPriorityReceiver(priority int) msg.Receiver { + if priority < 0 || priority >= len(w.weights) { + panic("invalid priority") + } + return msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + return w.receiveWithPriority(ctx, m, priority) + }) +} + +func (r *WeightedFairReceiver) receiveWithPriority(ctx context.Context, m *msg.Message, priority int) error { + wm := &weightedMessage{ + msg: m, + priority: priority, + doneChan: make(chan error, 1), + context: ctx, + } + r.receiveChan <- wm + return <-wm.doneChan +} + +func (r *WeightedFairReceiver) dispatch(ctx context.Context) { + maxConcurrentReceives := make(chan struct{}, r.maxConcurrent) + ticker := time.NewTicker(r.queueWaitTime) + + // New messages to be processed are enqueued from the receiveChan + // with their virtual finish time calculated based on the priority. + // Every tick we select the message with the smallest virtual finish + // time and process it. + + for { + select { + case <-ctx.Done(): + return + case wm := <-r.receiveChan: + vStart := max(float64(r.startTime), r.lastVFinish[wm.priority]) + weight := float64(1.0 / r.weights[wm.priority]) + vFinish := vStart + weight + wm.vFinish = vFinish + r.lastVFinish[wm.priority] = vFinish + r.queue.Enqueue(wm) + case <-ticker.C: + mw, ok := r.queue.Dequeue() + if !ok { + continue + } + + maxConcurrentReceives <- struct{}{} + go func() { + defer func() { + <-maxConcurrentReceives + }() + mw.doneChan <- r.receiver.Receive(mw.context, mw.msg) + }() + } + } +} + +type WeightedFairReceiverOption func(*WeightedFairReceiver) + +// WithQueueWaitTime sets the time to wait for a message to arrive. +// The longer you wait, the better the fairness. +func WithQueueWaitTime(d time.Duration) WeightedFairReceiverOption { + return func(r *WeightedFairReceiver) { + r.queueWaitTime = d + } +} + +// WithMaxConcurrent sets the maximum number of concurrent messages that can be processed +// by the receiver. If the maximum number of concurrent messages is reached, the receiver +// will block until a message is processed. +func WithMaxConcurrent(max int) WeightedFairReceiverOption { + return func(r *WeightedFairReceiver) { + r.maxConcurrent = max + } +} + +// NewWeightedFairReceiver creates a new WeightedFairReceiver with the given receiver and weights. +// The receiver will process messages based on their priority and the weights provided. +// Example: +// +// // The receiver will ensure that messages with priority 0 are processed 1/6 of the time, +// // messages with priority 1 are processed 1/3 of the time and messages with priority 2 are +// // processed 1/2 of the time. +// r := NewWeightedFairReceiver(ctx, receiver, []float64{1, 2, 3}) +// r0 = r.WithPriorityReceiver(0) // will process messages with priority 0 +// r1 = r.WithPriorityReceiver(1) // will process messages with priority 1 +// r2 = r.WithPriorityReceiver(2) // will process messages with priority 2 +// srv0, err := sqs.NewServer(c.String("sqs-queue-url-0"), 50, 100) +// if err != nil { +// return fmt.Errorf("failed to establish SQS connection: %w", err) +// } +// srv1, err := sqs.NewServer(c.String("sqs-queue-url-1"), 50, 100) +// if err != nil { +// return fmt.Errorf("failed to establish SQS connection: %w", err) +// } +// srv2, err := sqs.NewServer(c.String("sqs-queue-url-2"), 50, 100) +// if err != nil { +// return fmt.Errorf("failed to establish SQS connection: %w", err) +// } +// errGroup, _ := errgroup.WithContext(context.Background()) +// errGroup.Go(func() error { +// return srv0.Serve(lowPriorityReceiver) +// }) +// errGroup.Go(func() error { +// return srv1.Serve(mediumPriorityReceiver) +// }) +// errGroup.Go(func() error { +// return srv2.Serve(highPriorityReceiver) +// }) +// errGroup.Wait() +func NewWeightedFairReceiver(ctx context.Context, receiver msg.Receiver, weights []float64, opts ...WeightedFairReceiverOption) *WeightedFairReceiver { + + priorityQueue := pq.NewWith(func(a, b *weightedMessage) int { + return utils.NumberComparator(a.vFinish, b.vFinish) + }) + + r := &WeightedFairReceiver{ + receiver: receiver, + weights: weights, + queueWaitTime: 50 * time.Millisecond, + maxConcurrent: 10, + receiveChan: make(chan *weightedMessage), + queue: priorityQueue, + startTime: int(time.Now().Unix()), + lastVFinish: make([]float64, len(weights)), + } + + for _, opt := range opts { + opt(r) + } + + go r.dispatch(ctx) + return r +} diff --git a/decorators/wfr/receiver_test.go b/decorators/wfr/receiver_test.go new file mode 100644 index 0000000..1c444dd --- /dev/null +++ b/decorators/wfr/receiver_test.go @@ -0,0 +1 @@ +package wfr diff --git a/go.mod b/go.mod index 8cde497..662704e 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/zerofox-oss/go-msg go 1.21 require ( + github.com/JimWen/gods-generic v0.10.2 github.com/google/go-cmp v0.6.0 github.com/pierrec/lz4/v4 v4.1.8 go.opencensus.io v0.24.0 diff --git a/go.sum b/go.sum index 35764e3..6325ae4 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/JimWen/gods-generic v0.10.2 h1:ib/BF6W5+ANQJinlNxHYETH1BtxZASkBOV3v4mHSYYY= +github.com/JimWen/gods-generic v0.10.2/go.mod h1:ukDWk4Hb0hovQbhqitDTeOK4Hz+IK0y3q5QKQdri3as= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -11,8 +13,6 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -52,46 +52,21 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= -go.opentelemetry.io/otel v1.25.0 h1:gldB5FfhRl7OJQbUHt/8s0a7cE8fbsPAtdpRaApKy4k= -go.opentelemetry.io/otel v1.25.0/go.mod h1:Wa2ds5NOXEMkCmUou1WA7ZBfLTHWIsp034OVD7AO+Vg= -go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= -go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= go.opentelemetry.io/otel/bridge/opencensus v1.24.0 h1:Vlhy5ee5k5R0zASpH+9AgHiJH7xnKACI3XopO1tUZfY= go.opentelemetry.io/otel/bridge/opencensus v1.24.0/go.mod h1:jRjVXV/X38jyrnHtvMGN8+9cejZB21JvXAAvooF2s+Q= -go.opentelemetry.io/otel/bridge/opencensus v1.25.0 h1:0o/9KwAgxjK+3pMV0pwIF5toYHqDsPmQhfrBvKaG6mU= -go.opentelemetry.io/otel/bridge/opencensus v1.25.0/go.mod h1:rZyTdpmRqoV+PpUn6QlruxJp/kE4765rPy0pP6mRDk8= -go.opentelemetry.io/otel/bridge/opencensus v1.26.0 h1:DZzxj9QjznMVoehskOJnFP2gsTCWtDTFBDvFhPAY7nc= -go.opentelemetry.io/otel/bridge/opencensus v1.26.0/go.mod h1:rJiX0KrF5m8Tm1XE8jLczpAv5zUaDcvhKecFG0ZoFG4= go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/metric v1.25.0 h1:LUKbS7ArpFL/I2jJHdJcqMGxkRdxpPHE0VU/D4NuEwA= -go.opentelemetry.io/otel/metric v1.25.0/go.mod h1:rkDLUSd2lC5lq2dFNrX9LGAbINP5B7WBkC78RXCpH5s= -go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= -go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucgoDw= go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg= -go.opentelemetry.io/otel/sdk v1.25.0 h1:PDryEJPC8YJZQSyLY5eqLeafHtG+X7FWnf3aXMtxbqo= -go.opentelemetry.io/otel/sdk v1.25.0/go.mod h1:oFgzCM2zdsxKzz6zwpTZYLLQsFwc+K0daArPdIhuxkw= -go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= -go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= go.opentelemetry.io/otel/sdk/metric v1.24.0 h1:yyMQrPzF+k88/DbH7o4FMAs80puqd+9osbiBrJrz/w8= go.opentelemetry.io/otel/sdk/metric v1.24.0/go.mod h1:I6Y5FjH6rvEnTTAYQz3Mmv2kl6Ek5IIrmwTLqMrrOE0= -go.opentelemetry.io/otel/sdk/metric v1.25.0 h1:7CiHOy08LbrxMAp4vWpbiPcklunUshVpAvGBrdDRlGw= -go.opentelemetry.io/otel/sdk/metric v1.25.0/go.mod h1:LzwoKptdbBBdYfvtGCzGwk6GWMA3aUzBOwtQpR6Nz7o= -go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= -go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= -go.opentelemetry.io/otel/trace v1.25.0 h1:tqukZGLwQYRIFtSQM2u2+yfMVTgGVeqRLPUYx1Dq6RM= -go.opentelemetry.io/otel/trace v1.25.0/go.mod h1:hCCs70XM/ljO+BeQkyFnbK28SBIJ/Emuha+ccrCRT7I= -go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= -go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= From 61477852de429968c3ef05ed3100763b6978cefc Mon Sep 17 00:00:00 2001 From: Shezad Khan Date: Thu, 27 Jun 2024 10:53:14 +0100 Subject: [PATCH 2/2] feat: add multiserver implementation --- Makefile | 2 +- decorators/wfr/receiver.go | 189 ------------------ decorators/wfr/receiver_test.go | 1 - docs/decisions/2024-07-08-multiserver.md | 19 ++ ...ges => 2024-07-09-experimental-changes.md} | 0 go.mod | 8 + go.sum | 12 ++ x/README.md | 2 +- x/multiserver/README.md | 65 ++++++ x/multiserver/multiserver.go | 108 ++++++++++ x/multiserver/multiserver_test.go | 112 +++++++++++ x/multiserver/receiver.go | 180 +++++++++++++++++ 12 files changed, 506 insertions(+), 192 deletions(-) delete mode 100644 decorators/wfr/receiver.go delete mode 100644 decorators/wfr/receiver_test.go create mode 100644 docs/decisions/2024-07-08-multiserver.md rename docs/decisions/{2024-07-09-experimental-changes => 2024-07-09-experimental-changes.md} (100%) create mode 100644 x/multiserver/README.md create mode 100644 x/multiserver/multiserver.go create mode 100644 x/multiserver/multiserver_test.go create mode 100644 x/multiserver/receiver.go diff --git a/Makefile b/Makefile index 9be6773..6e6ddf9 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ lint: init golangci-lint run ./... test: init - go test -race -v ./... + go test -race ./... tools: curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(GOPATH)/bin v1.17.0 diff --git a/decorators/wfr/receiver.go b/decorators/wfr/receiver.go deleted file mode 100644 index dffc11b..0000000 --- a/decorators/wfr/receiver.go +++ /dev/null @@ -1,189 +0,0 @@ -package wfr - -import ( - "context" - "time" - - pq "github.com/JimWen/gods-generic/queues/priorityqueue" - "github.com/JimWen/gods-generic/utils" - "github.com/zerofox-oss/go-msg" -) - -type weightedMessage struct { - msg *msg.Message - context context.Context - vFinish float64 - priority int - doneChan chan error -} - -// WeightedFairReceiver implements a "fair" receiver that processes messages based on their priority. -// See https://en.wikipedia.org/wiki/Fair_queuing for more information, about fairness. -// The concrete implmentation is based on the Weighted Fair Queuing algorithm. -// At a high level the algorithm works as follows: -// 1. Each message is asigned a virtual finish time. The virtual finish time is defined as the -// time of the last message processed for that priority plus a weigting factor. The higher the -// the weight the smaller the virtual finish time is for that message. -// 2. The message is then enqueued in a priority queue based on the virtual finish time. -// 3. Every "tick" we select the the message with the smallest virtual finish time and process it. -type WeightedFairReceiver struct { - - // underlying receiver - receiver msg.Receiver - - // time to wait for a message to arrive - // the longer you wait, the better the fairness - // the shorter the better the latency - // a good value should be around the mean processing - // time for your receiver. If in doubt 50ms may be good. - queueWaitTime time.Duration - - // weights for each priority level - // the higher the weight, the more messages - // will be processed for that priority - weights []float64 - - // max number of concurrent messages that can - // be processed at the same time by the receiver - maxConcurrent int - - receiveChan chan *weightedMessage - queue *pq.Queue[*weightedMessage] - startTime int - lastVFinish []float64 -} - -func (w *WeightedFairReceiver) WithPriorityReceiver(priority int) msg.Receiver { - if priority < 0 || priority >= len(w.weights) { - panic("invalid priority") - } - return msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { - return w.receiveWithPriority(ctx, m, priority) - }) -} - -func (r *WeightedFairReceiver) receiveWithPriority(ctx context.Context, m *msg.Message, priority int) error { - wm := &weightedMessage{ - msg: m, - priority: priority, - doneChan: make(chan error, 1), - context: ctx, - } - r.receiveChan <- wm - return <-wm.doneChan -} - -func (r *WeightedFairReceiver) dispatch(ctx context.Context) { - maxConcurrentReceives := make(chan struct{}, r.maxConcurrent) - ticker := time.NewTicker(r.queueWaitTime) - - // New messages to be processed are enqueued from the receiveChan - // with their virtual finish time calculated based on the priority. - // Every tick we select the message with the smallest virtual finish - // time and process it. - - for { - select { - case <-ctx.Done(): - return - case wm := <-r.receiveChan: - vStart := max(float64(r.startTime), r.lastVFinish[wm.priority]) - weight := float64(1.0 / r.weights[wm.priority]) - vFinish := vStart + weight - wm.vFinish = vFinish - r.lastVFinish[wm.priority] = vFinish - r.queue.Enqueue(wm) - case <-ticker.C: - mw, ok := r.queue.Dequeue() - if !ok { - continue - } - - maxConcurrentReceives <- struct{}{} - go func() { - defer func() { - <-maxConcurrentReceives - }() - mw.doneChan <- r.receiver.Receive(mw.context, mw.msg) - }() - } - } -} - -type WeightedFairReceiverOption func(*WeightedFairReceiver) - -// WithQueueWaitTime sets the time to wait for a message to arrive. -// The longer you wait, the better the fairness. -func WithQueueWaitTime(d time.Duration) WeightedFairReceiverOption { - return func(r *WeightedFairReceiver) { - r.queueWaitTime = d - } -} - -// WithMaxConcurrent sets the maximum number of concurrent messages that can be processed -// by the receiver. If the maximum number of concurrent messages is reached, the receiver -// will block until a message is processed. -func WithMaxConcurrent(max int) WeightedFairReceiverOption { - return func(r *WeightedFairReceiver) { - r.maxConcurrent = max - } -} - -// NewWeightedFairReceiver creates a new WeightedFairReceiver with the given receiver and weights. -// The receiver will process messages based on their priority and the weights provided. -// Example: -// -// // The receiver will ensure that messages with priority 0 are processed 1/6 of the time, -// // messages with priority 1 are processed 1/3 of the time and messages with priority 2 are -// // processed 1/2 of the time. -// r := NewWeightedFairReceiver(ctx, receiver, []float64{1, 2, 3}) -// r0 = r.WithPriorityReceiver(0) // will process messages with priority 0 -// r1 = r.WithPriorityReceiver(1) // will process messages with priority 1 -// r2 = r.WithPriorityReceiver(2) // will process messages with priority 2 -// srv0, err := sqs.NewServer(c.String("sqs-queue-url-0"), 50, 100) -// if err != nil { -// return fmt.Errorf("failed to establish SQS connection: %w", err) -// } -// srv1, err := sqs.NewServer(c.String("sqs-queue-url-1"), 50, 100) -// if err != nil { -// return fmt.Errorf("failed to establish SQS connection: %w", err) -// } -// srv2, err := sqs.NewServer(c.String("sqs-queue-url-2"), 50, 100) -// if err != nil { -// return fmt.Errorf("failed to establish SQS connection: %w", err) -// } -// errGroup, _ := errgroup.WithContext(context.Background()) -// errGroup.Go(func() error { -// return srv0.Serve(lowPriorityReceiver) -// }) -// errGroup.Go(func() error { -// return srv1.Serve(mediumPriorityReceiver) -// }) -// errGroup.Go(func() error { -// return srv2.Serve(highPriorityReceiver) -// }) -// errGroup.Wait() -func NewWeightedFairReceiver(ctx context.Context, receiver msg.Receiver, weights []float64, opts ...WeightedFairReceiverOption) *WeightedFairReceiver { - - priorityQueue := pq.NewWith(func(a, b *weightedMessage) int { - return utils.NumberComparator(a.vFinish, b.vFinish) - }) - - r := &WeightedFairReceiver{ - receiver: receiver, - weights: weights, - queueWaitTime: 50 * time.Millisecond, - maxConcurrent: 10, - receiveChan: make(chan *weightedMessage), - queue: priorityQueue, - startTime: int(time.Now().Unix()), - lastVFinish: make([]float64, len(weights)), - } - - for _, opt := range opts { - opt(r) - } - - go r.dispatch(ctx) - return r -} diff --git a/decorators/wfr/receiver_test.go b/decorators/wfr/receiver_test.go deleted file mode 100644 index 1c444dd..0000000 --- a/decorators/wfr/receiver_test.go +++ /dev/null @@ -1 +0,0 @@ -package wfr diff --git a/docs/decisions/2024-07-08-multiserver.md b/docs/decisions/2024-07-08-multiserver.md new file mode 100644 index 0000000..4d159c6 --- /dev/null +++ b/docs/decisions/2024-07-08-multiserver.md @@ -0,0 +1,19 @@ +--- +Date: 2024/07/08 +Authors: @Xopherus, @shezadkhan137 +Status: Accepted +--- + +# Multi-Server Fair Weighted Queues + +## Context + +ZeroFox needs to prioritize messages within SQS queues, especially when services receive data from multiple sources with varying latency requirements. Prioritization is key to meeting SLAs for high-priority messages while ensuring lower-priority messages are not starved. Additionally, we need to dynamically allocate throughput based on message priority to ensure fairness and efficiency. + +## Decisions + +1. Implement an experimental Weighted Fair Queue algorithm as a decorator for the go-msg library (receiver). +This decorator, called `WeightedFairReceiver`, will enable message prioritization based on assigned weights. + +2. Implement a `MultiServer` wrapper around `WeightedFairReceiver`. +This server will consume messages from multiple underlying servers into a single receiver, guaranteeing consumption in proportion to assigned weights. diff --git a/docs/decisions/2024-07-09-experimental-changes b/docs/decisions/2024-07-09-experimental-changes.md similarity index 100% rename from docs/decisions/2024-07-09-experimental-changes rename to docs/decisions/2024-07-09-experimental-changes.md diff --git a/go.mod b/go.mod index 662704e..81389d6 100644 --- a/go.mod +++ b/go.mod @@ -4,20 +4,28 @@ go 1.21 require ( github.com/JimWen/gods-generic v0.10.2 + github.com/asecurityteam/rolling v2.0.4+incompatible github.com/google/go-cmp v0.6.0 github.com/pierrec/lz4/v4 v4.1.8 + github.com/stretchr/testify v1.8.4 go.opencensus.io v0.24.0 go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/bridge/opencensus v1.24.0 go.opentelemetry.io/otel/sdk v1.24.0 go.opentelemetry.io/otel/trace v1.24.0 + golang.org/x/sync v0.0.0-20190423024810-112230192c58 + pgregory.net/rapid v1.1.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.24.0 // indirect golang.org/x/sys v0.20.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6325ae4..45f9083 100644 --- a/go.sum +++ b/go.sum @@ -2,9 +2,12 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/JimWen/gods-generic v0.10.2 h1:ib/BF6W5+ANQJinlNxHYETH1BtxZASkBOV3v4mHSYYY= github.com/JimWen/gods-generic v0.10.2/go.mod h1:ukDWk4Hb0hovQbhqitDTeOK4Hz+IK0y3q5QKQdri3as= +github.com/asecurityteam/rolling v2.0.4+incompatible h1:WOSeokINZT0IDzYGc5BVcjLlR9vPol08RvI2GAsmB0s= +github.com/asecurityteam/rolling v2.0.4+incompatible/go.mod h1:2D4ba5ZfYCWrIMleUgTvc8pmLExEuvu3PDwl+vnG58Q= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -40,6 +43,10 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -82,6 +89,7 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -117,8 +125,12 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= diff --git a/x/README.md b/x/README.md index 24e9d61..9b721ef 100644 --- a/x/README.md +++ b/x/README.md @@ -3,4 +3,4 @@ This directory holds experimental packages ([decision record](../docs/decisions/2024-07-09-experimental-changes)). The intent is to introduce new features which may or may not become a part of the core library. -Code in this directory may not be backwards compatible. \ No newline at end of file +Code in this directory may not be backwards compatible. diff --git a/x/multiserver/README.md b/x/multiserver/README.md new file mode 100644 index 0000000..3862a3c --- /dev/null +++ b/x/multiserver/README.md @@ -0,0 +1,65 @@ +# Multiserver + +The `multiserver` package provides a `MultiServer` that can serve messages from multiple underlying servers to a single receiver. The server will consume messages from the underlying servers in the ratio of the weights provided. + +### Example + +```go +package main + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/zerofox-oss/go-msg" + "github.com/zerofox-oss/go-msg/backends/mem" + "github.com/zerofox-oss/go-msg/x/multiserver" +) + +func main() { + // Create memory servers + server1 := mem.NewServer(make(chan *msg.Message, 100), 10) + server2 := mem.NewServer(make(chan *msg.Message, 100), 10) + + // Define server weights + serverWeights := []multiserver.ServerWeight{ + {Server: server1, Weight: 1.0}, + {Server: server2, Weight: 2.0}, + } + + // Create MultiServer + mserver, err := multiserver.NewMultiServer(10, serverWeights) + if err != nil { + fmt.Println("Error creating MultiServer:", err) + return + } + + // Start serving messages + go func() { + mserver.Serve(msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + fmt.Println("Received message:", m) + return nil + })) + }() + + // Simulate sending messages + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go func() { + for { + select { + case <-ctx.Done(): + return + case server1.C <- &msg.Message{Body: bytes.NewBuffer([]byte("message from server1"))}: + case server2.C <- &msg.Message{Body: bytes.NewBuffer([]byte("message from server2"))}: + } + } + }() + + <-ctx.Done() + mserver.Shutdown(context.Background()) +} +``` diff --git a/x/multiserver/multiserver.go b/x/multiserver/multiserver.go new file mode 100644 index 0000000..80a111c --- /dev/null +++ b/x/multiserver/multiserver.go @@ -0,0 +1,108 @@ +package multiserver + +import ( + "context" + "errors" + "time" + + msg "github.com/zerofox-oss/go-msg" + "golang.org/x/sync/errgroup" +) + +// MultiServer is a server that can serve messages from multiple underlying servers +// to as single receiver. The server will consume messages from the underlying servers +// in the ratio of the weights provided. +type MultiServer struct { + servers []msg.Server + weights []float64 + concurrency int + queueWaitTime time.Duration + wfr *WeightedFairReceiver +} + +type ServerWeight struct { + Server msg.Server + Weight float64 +} + +// MultiServerOption is a functional option for the MultiServer. +type MultiServerOption func(*MultiServer) + +// WithQueueWaitTime sets the time to wait for a message to arrive. +func WithQueueWaitTime(queueWaitTime time.Duration) MultiServerOption { + return func(m *MultiServer) { + m.queueWaitTime = queueWaitTime + } +} + +// NewMultiServer creates a new MultiServer with the given concurrency. +// The server will distribute the messages to underlying receiver from +// the given servers in the ratio of the weights provided. +func NewMultiServer(concurrency int, serverWeights []ServerWeight, opts ...MultiServerOption) (*MultiServer, error) { + if len(serverWeights) == 0 { + return nil, errors.New("serverWeights must not be empty") + } + + if concurrency <= 0 { + return nil, errors.New("concurrency must be greater than 0") + } + + servers := make([]msg.Server, 0, len(serverWeights)) + weights := make([]float64, 0, len(serverWeights)) + + for _, s := range serverWeights { + servers = append(servers, s.Server) + weights = append(weights, s.Weight) + } + + server := &MultiServer{ + concurrency: concurrency, + servers: servers, + weights: weights, + queueWaitTime: 1 * time.Millisecond, + } + + for _, opt := range opts { + opt(server) + } + + return server, nil +} + +// Serve serves messages to the underlying servers. +func (m *MultiServer) Serve(msg msg.Receiver) error { + m.wfr = NewWeightedFairReceiver( + m.weights, + m.concurrency, + m.queueWaitTime, + msg, + ) + + g := errgroup.Group{} + for i, s := range m.servers { + s := s + i := i + g.Go(func() error { + return s.Serve(m.wfr.WithPriorityReceiver(i)) + }) + } + + return g.Wait() +} + +// Shutdown shuts down the server. +func (m *MultiServer) Shutdown(ctx context.Context) error { + g := errgroup.Group{} + for _, s := range m.servers { + s := s + g.Go(func() error { + return s.Shutdown(ctx) + }) + } + err := g.Wait() + if err != nil { + return err + } + + return m.wfr.Close(ctx) +} diff --git a/x/multiserver/multiserver_test.go b/x/multiserver/multiserver_test.go new file mode 100644 index 0000000..0bbfa5d --- /dev/null +++ b/x/multiserver/multiserver_test.go @@ -0,0 +1,112 @@ +package multiserver_test + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zerofox-oss/go-msg" + "github.com/zerofox-oss/go-msg/backends/mem" + "github.com/zerofox-oss/go-msg/x/multiserver" + "pgregory.net/rapid" +) + +func sendMessages(ctx context.Context, inputChan chan *msg.Message) { + for { + select { + case <-ctx.Done(): + return + case inputChan <- &msg.Message{ + Body: bytes.NewBuffer([]byte("hello world")), + Attributes: msg.Attributes{}, + }: + } + } +} + +func TestMultiServer(t *testing.T) { + t.Parallel() + + for i := 0; i <= 10; i++ { + t.Run(fmt.Sprintf("TestMultiServer_%d", i), func(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + numServers := rapid.IntRange(1, 10).Draw(t, "numServers") + + serverConcurrency := 10 + inputChanBuffer := 100 + + counts := make([]atomic.Int32, numServers) + inputChans := make([]chan *msg.Message, numServers) + serverWeights := make([]multiserver.ServerWeight, numServers) + + for i := 0; i < numServers; i++ { + inputChan := make(chan *msg.Message, inputChanBuffer) + server := mem.NewServer(inputChan, serverConcurrency) + weight := rapid.IntRange(1, 10).Draw(t, "weight") + serverWeights[i] = multiserver.ServerWeight{ + Server: server, + Weight: float64(weight), + } + inputChans[i] = inputChan + } + + mserver, err := multiserver.NewMultiServer(serverConcurrency, serverWeights) + assert.NoError(t, err) + + go func() { + mserver.Serve(msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + msgPriority := m.Attributes.Get(multiserver.MultiServerMsgPriority) + p, err := strconv.Atoi(msgPriority) + if err != nil { + assert.Fail(t, "failed to parse priority") + } + counts[p].Add(1) + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + return nil + })) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for i := 0; i < numServers; i++ { + go sendMessages(ctx, inputChans[i]) + } + + <-ctx.Done() + mserver.Shutdown(context.Background()) + + totalWeight := 0 + totalCounts := 0 + for i := 0; i < numServers; i++ { + totalWeight += int(serverWeights[i].Weight) + totalCounts += int(counts[i].Load()) + } + + delta := float64(totalCounts) * float64(0.1) + + t.Logf("Total weight: %d\n", totalWeight) + t.Logf("Total counts: %d\n", totalCounts) + t.Logf("Total counts delta: %f\n", delta) + + for i := 0; i < numServers; i++ { + weight := int(serverWeights[i].Weight) + expectedCount := (weight * totalCounts) / totalWeight + serverCount := int(counts[i].Load()) + t.Logf("Server %d: weight %d, expected count %d, actual count %d\n", i, weight, expectedCount, serverCount) + assert.InDeltaf( + t, expectedCount, serverCount, delta, + "Server %d: weight %d, expected count %d, actual count %d\n", i, weight, expectedCount, serverCount) + } + }) + }) + } +} diff --git a/x/multiserver/receiver.go b/x/multiserver/receiver.go new file mode 100644 index 0000000..704153b --- /dev/null +++ b/x/multiserver/receiver.go @@ -0,0 +1,180 @@ +package multiserver + +import ( + "context" + "fmt" + "math" + "time" + + pq "github.com/JimWen/gods-generic/queues/priorityqueue" + "github.com/JimWen/gods-generic/utils" + "github.com/asecurityteam/rolling" + msg "github.com/zerofox-oss/go-msg" +) + +// MultiServerMsgPriority is the key used to store the priority of a message in the message attributes. +const MultiServerMsgPriority = "x-multiserver-priority" + +type weightedMessage struct { + msg *msg.Message + context context.Context + vFinish float64 + priority int + doneChan chan error +} + +// WeightedFairReceiver implements a "fair" receiver that processes messages based on their priority. +// See https://en.wikipedia.org/wiki/Fair_queuing for more information, about fairness. +// The concrete implementation is based on the Weighted Fair Queuing algorithm. +// At a high level the algorithm works as follows: +// 1. Each message is assigned a virtual finish time. The virtual finish time is defined as the +// time of the last message processed for that priority plus a weighting factor. The higher the +// the weight the smaller the virtual finish time is for that message. +// 2. The message is then enqueued in a priority queue based on the virtual finish time. +// 3. Every "tick" we select the message with the smallest virtual finish time and process it. +type WeightedFairReceiver struct { + // underlying receiver + receiver msg.Receiver + + // time to wait for a message to arrive + // the longer you wait, the better the fairness + // the shorter the better the latency. + queueWaitTime time.Duration + + // weights for each priority level + // the higher the weight, the more messages + // will be processed for that priority + weights []float64 + + // max number of concurrent messages that can + // be processed at the same time by the receiver + maxConcurrent int + + // a rough estimate time to process a message + // this should be in milliseconds. + initialEstimatedCost float64 + + receiveChan chan *weightedMessage + queue *pq.Queue[*weightedMessage] + startTime int + lastVFinish []float64 + + timePolicy *rolling.TimePolicy + closeChan chan chan error +} + +// NewWeightedFairReceiver creates a new WeightedFairReceiver. +// The receiver will process messages based on their priority level. +func NewWeightedFairReceiver( + weights []float64, + maxConcurrent int, + queueWaitTime time.Duration, + receiver msg.Receiver, +) *WeightedFairReceiver { + priorityQueue := pq.NewWith(func(a, b *weightedMessage) int { + return utils.NumberComparator(a.vFinish, b.vFinish) + }) + + wfr := &WeightedFairReceiver{ + weights: weights, + lastVFinish: make([]float64, len(weights)), + receiver: receiver, + queue: priorityQueue, + startTime: 1, + queueWaitTime: queueWaitTime, + maxConcurrent: maxConcurrent, + receiveChan: make(chan *weightedMessage), + timePolicy: rolling.NewTimePolicy(rolling.NewWindow(10000), 1*time.Millisecond), + initialEstimatedCost: 100.0, + closeChan: make(chan chan error), + } + + go wfr.dispatch() + return wfr +} + +// WithPriorityReceiver returns a new msg.Receiver that should be used to receive messages +// at a specific priority level. +func (w *WeightedFairReceiver) WithPriorityReceiver(priority int) msg.Receiver { + return msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + return w.Receive(ctx, m, priority) + }) +} + +// Receive receives a message with a specific priority level. +func (w *WeightedFairReceiver) Receive(ctx context.Context, m *msg.Message, priority int) error { + wm := &weightedMessage{ + msg: m, + priority: priority, + doneChan: make(chan error, 1), + context: ctx, + } + w.receiveChan <- wm + return <-wm.doneChan +} + +// Close closes the receiver. +func (w *WeightedFairReceiver) Close(ctx context.Context) error { + doneChan := make(chan error) + w.closeChan <- doneChan + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-doneChan: + return err + } +} + +func (w *WeightedFairReceiver) estimateCost() float64 { + return w.timePolicy.Reduce(rolling.Avg) +} + +func (w *WeightedFairReceiver) dispatch() { + maxConcurrentReceives := make(chan struct{}, w.maxConcurrent) + timer := time.NewTicker(w.queueWaitTime) + + doReceive := func() { + select { + case maxConcurrentReceives <- struct{}{}: + if mw, ok := w.queue.Dequeue(); ok { + go func(wm weightedMessage) { + defer func() { + <-maxConcurrentReceives + }() + st := time.Now() + wm.msg.Attributes.Set(MultiServerMsgPriority, fmt.Sprint(wm.priority)) + result := w.receiver.Receive(wm.context, wm.msg) + w.timePolicy.Append(float64(time.Since(st).Milliseconds())) + wm.doneChan <- result + }(*mw) + } else { + <-maxConcurrentReceives + } + default: + } + } + + for { + select { + case doneChan := <-w.closeChan: + doneChan <- nil + // TODO: process remaining messages + // that are queued in order to cleanly shutdown + return + case wm := <-w.receiveChan: + vStart := math.Max(float64(w.startTime), w.lastVFinish[wm.priority]) + estimatedCost := w.estimateCost() + if math.IsNaN(estimatedCost) { + estimatedCost = w.initialEstimatedCost + } + estimatedCost = math.Round(estimatedCost) + weight2 := float64(estimatedCost / w.weights[wm.priority]) + vFinish := vStart + weight2 + wm.vFinish = vFinish + w.lastVFinish[wm.priority] = vFinish + w.queue.Enqueue(wm) + case <-timer.C: + doReceive() + } + } +}