diff --git a/Makefile b/Makefile index 9be6773..6e6ddf9 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ lint: init golangci-lint run ./... test: init - go test -race -v ./... + go test -race ./... tools: curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(GOPATH)/bin v1.17.0 diff --git a/docs/decisions/2024-07-08-multiserver.md b/docs/decisions/2024-07-08-multiserver.md new file mode 100644 index 0000000..4d159c6 --- /dev/null +++ b/docs/decisions/2024-07-08-multiserver.md @@ -0,0 +1,19 @@ +--- +Date: 2024/07/08 +Authors: @Xopherus, @shezadkhan137 +Status: Accepted +--- + +# Multi-Server Fair Weighted Queues + +## Context + +ZeroFox needs to prioritize messages within SQS queues, especially when services receive data from multiple sources with varying latency requirements. Prioritization is key to meeting SLAs for high-priority messages while ensuring lower-priority messages are not starved. Additionally, we need to dynamically allocate throughput based on message priority to ensure fairness and efficiency. + +## Decisions + +1. Implement an experimental Weighted Fair Queue algorithm as a decorator for the go-msg library (receiver). +This decorator, called `WeightedFairReceiver`, will enable message prioritization based on assigned weights. + +2. Implement a `MultiServer` wrapper around `WeightedFairReceiver`. +This server will consume messages from multiple underlying servers into a single receiver, guaranteeing consumption in proportion to assigned weights. diff --git a/docs/decisions/2024-07-09-experimental-changes b/docs/decisions/2024-07-09-experimental-changes.md similarity index 100% rename from docs/decisions/2024-07-09-experimental-changes rename to docs/decisions/2024-07-09-experimental-changes.md diff --git a/go.mod b/go.mod index 8cde497..81389d6 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,29 @@ module github.com/zerofox-oss/go-msg go 1.21 require ( + github.com/JimWen/gods-generic v0.10.2 + github.com/asecurityteam/rolling v2.0.4+incompatible github.com/google/go-cmp v0.6.0 github.com/pierrec/lz4/v4 v4.1.8 + github.com/stretchr/testify v1.8.4 go.opencensus.io v0.24.0 go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/bridge/opencensus v1.24.0 go.opentelemetry.io/otel/sdk v1.24.0 go.opentelemetry.io/otel/trace v1.24.0 + golang.org/x/sync v0.0.0-20190423024810-112230192c58 + pgregory.net/rapid v1.1.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.24.0 // indirect golang.org/x/sys v0.20.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 35764e3..45f9083 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,13 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/JimWen/gods-generic v0.10.2 h1:ib/BF6W5+ANQJinlNxHYETH1BtxZASkBOV3v4mHSYYY= +github.com/JimWen/gods-generic v0.10.2/go.mod h1:ukDWk4Hb0hovQbhqitDTeOK4Hz+IK0y3q5QKQdri3as= +github.com/asecurityteam/rolling v2.0.4+incompatible h1:WOSeokINZT0IDzYGc5BVcjLlR9vPol08RvI2GAsmB0s= +github.com/asecurityteam/rolling v2.0.4+incompatible/go.mod h1:2D4ba5ZfYCWrIMleUgTvc8pmLExEuvu3PDwl+vnG58Q= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -11,8 +16,6 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -40,6 +43,10 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -52,46 +59,21 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= -go.opentelemetry.io/otel v1.25.0 h1:gldB5FfhRl7OJQbUHt/8s0a7cE8fbsPAtdpRaApKy4k= -go.opentelemetry.io/otel v1.25.0/go.mod h1:Wa2ds5NOXEMkCmUou1WA7ZBfLTHWIsp034OVD7AO+Vg= -go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= -go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= go.opentelemetry.io/otel/bridge/opencensus v1.24.0 h1:Vlhy5ee5k5R0zASpH+9AgHiJH7xnKACI3XopO1tUZfY= go.opentelemetry.io/otel/bridge/opencensus v1.24.0/go.mod h1:jRjVXV/X38jyrnHtvMGN8+9cejZB21JvXAAvooF2s+Q= -go.opentelemetry.io/otel/bridge/opencensus v1.25.0 h1:0o/9KwAgxjK+3pMV0pwIF5toYHqDsPmQhfrBvKaG6mU= -go.opentelemetry.io/otel/bridge/opencensus v1.25.0/go.mod h1:rZyTdpmRqoV+PpUn6QlruxJp/kE4765rPy0pP6mRDk8= -go.opentelemetry.io/otel/bridge/opencensus v1.26.0 h1:DZzxj9QjznMVoehskOJnFP2gsTCWtDTFBDvFhPAY7nc= -go.opentelemetry.io/otel/bridge/opencensus v1.26.0/go.mod h1:rJiX0KrF5m8Tm1XE8jLczpAv5zUaDcvhKecFG0ZoFG4= go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/metric v1.25.0 h1:LUKbS7ArpFL/I2jJHdJcqMGxkRdxpPHE0VU/D4NuEwA= -go.opentelemetry.io/otel/metric v1.25.0/go.mod h1:rkDLUSd2lC5lq2dFNrX9LGAbINP5B7WBkC78RXCpH5s= -go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= -go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucgoDw= go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg= -go.opentelemetry.io/otel/sdk v1.25.0 h1:PDryEJPC8YJZQSyLY5eqLeafHtG+X7FWnf3aXMtxbqo= -go.opentelemetry.io/otel/sdk v1.25.0/go.mod h1:oFgzCM2zdsxKzz6zwpTZYLLQsFwc+K0daArPdIhuxkw= -go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= -go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= go.opentelemetry.io/otel/sdk/metric v1.24.0 h1:yyMQrPzF+k88/DbH7o4FMAs80puqd+9osbiBrJrz/w8= go.opentelemetry.io/otel/sdk/metric v1.24.0/go.mod h1:I6Y5FjH6rvEnTTAYQz3Mmv2kl6Ek5IIrmwTLqMrrOE0= -go.opentelemetry.io/otel/sdk/metric v1.25.0 h1:7CiHOy08LbrxMAp4vWpbiPcklunUshVpAvGBrdDRlGw= -go.opentelemetry.io/otel/sdk/metric v1.25.0/go.mod h1:LzwoKptdbBBdYfvtGCzGwk6GWMA3aUzBOwtQpR6Nz7o= -go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= -go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= -go.opentelemetry.io/otel/trace v1.25.0 h1:tqukZGLwQYRIFtSQM2u2+yfMVTgGVeqRLPUYx1Dq6RM= -go.opentelemetry.io/otel/trace v1.25.0/go.mod h1:hCCs70XM/ljO+BeQkyFnbK28SBIJ/Emuha+ccrCRT7I= -go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= -go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -107,6 +89,7 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -142,8 +125,12 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= diff --git a/x/README.md b/x/README.md index 24e9d61..9b721ef 100644 --- a/x/README.md +++ b/x/README.md @@ -3,4 +3,4 @@ This directory holds experimental packages ([decision record](../docs/decisions/2024-07-09-experimental-changes)). The intent is to introduce new features which may or may not become a part of the core library. -Code in this directory may not be backwards compatible. \ No newline at end of file +Code in this directory may not be backwards compatible. diff --git a/x/multiserver/README.md b/x/multiserver/README.md new file mode 100644 index 0000000..3862a3c --- /dev/null +++ b/x/multiserver/README.md @@ -0,0 +1,65 @@ +# Multiserver + +The `multiserver` package provides a `MultiServer` that can serve messages from multiple underlying servers to a single receiver. The server will consume messages from the underlying servers in the ratio of the weights provided. + +### Example + +```go +package main + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/zerofox-oss/go-msg" + "github.com/zerofox-oss/go-msg/backends/mem" + "github.com/zerofox-oss/go-msg/x/multiserver" +) + +func main() { + // Create memory servers + server1 := mem.NewServer(make(chan *msg.Message, 100), 10) + server2 := mem.NewServer(make(chan *msg.Message, 100), 10) + + // Define server weights + serverWeights := []multiserver.ServerWeight{ + {Server: server1, Weight: 1.0}, + {Server: server2, Weight: 2.0}, + } + + // Create MultiServer + mserver, err := multiserver.NewMultiServer(10, serverWeights) + if err != nil { + fmt.Println("Error creating MultiServer:", err) + return + } + + // Start serving messages + go func() { + mserver.Serve(msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + fmt.Println("Received message:", m) + return nil + })) + }() + + // Simulate sending messages + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go func() { + for { + select { + case <-ctx.Done(): + return + case server1.C <- &msg.Message{Body: bytes.NewBuffer([]byte("message from server1"))}: + case server2.C <- &msg.Message{Body: bytes.NewBuffer([]byte("message from server2"))}: + } + } + }() + + <-ctx.Done() + mserver.Shutdown(context.Background()) +} +``` diff --git a/x/multiserver/multiserver.go b/x/multiserver/multiserver.go new file mode 100644 index 0000000..80a111c --- /dev/null +++ b/x/multiserver/multiserver.go @@ -0,0 +1,108 @@ +package multiserver + +import ( + "context" + "errors" + "time" + + msg "github.com/zerofox-oss/go-msg" + "golang.org/x/sync/errgroup" +) + +// MultiServer is a server that can serve messages from multiple underlying servers +// to as single receiver. The server will consume messages from the underlying servers +// in the ratio of the weights provided. +type MultiServer struct { + servers []msg.Server + weights []float64 + concurrency int + queueWaitTime time.Duration + wfr *WeightedFairReceiver +} + +type ServerWeight struct { + Server msg.Server + Weight float64 +} + +// MultiServerOption is a functional option for the MultiServer. +type MultiServerOption func(*MultiServer) + +// WithQueueWaitTime sets the time to wait for a message to arrive. +func WithQueueWaitTime(queueWaitTime time.Duration) MultiServerOption { + return func(m *MultiServer) { + m.queueWaitTime = queueWaitTime + } +} + +// NewMultiServer creates a new MultiServer with the given concurrency. +// The server will distribute the messages to underlying receiver from +// the given servers in the ratio of the weights provided. +func NewMultiServer(concurrency int, serverWeights []ServerWeight, opts ...MultiServerOption) (*MultiServer, error) { + if len(serverWeights) == 0 { + return nil, errors.New("serverWeights must not be empty") + } + + if concurrency <= 0 { + return nil, errors.New("concurrency must be greater than 0") + } + + servers := make([]msg.Server, 0, len(serverWeights)) + weights := make([]float64, 0, len(serverWeights)) + + for _, s := range serverWeights { + servers = append(servers, s.Server) + weights = append(weights, s.Weight) + } + + server := &MultiServer{ + concurrency: concurrency, + servers: servers, + weights: weights, + queueWaitTime: 1 * time.Millisecond, + } + + for _, opt := range opts { + opt(server) + } + + return server, nil +} + +// Serve serves messages to the underlying servers. +func (m *MultiServer) Serve(msg msg.Receiver) error { + m.wfr = NewWeightedFairReceiver( + m.weights, + m.concurrency, + m.queueWaitTime, + msg, + ) + + g := errgroup.Group{} + for i, s := range m.servers { + s := s + i := i + g.Go(func() error { + return s.Serve(m.wfr.WithPriorityReceiver(i)) + }) + } + + return g.Wait() +} + +// Shutdown shuts down the server. +func (m *MultiServer) Shutdown(ctx context.Context) error { + g := errgroup.Group{} + for _, s := range m.servers { + s := s + g.Go(func() error { + return s.Shutdown(ctx) + }) + } + err := g.Wait() + if err != nil { + return err + } + + return m.wfr.Close(ctx) +} diff --git a/x/multiserver/multiserver_test.go b/x/multiserver/multiserver_test.go new file mode 100644 index 0000000..0bbfa5d --- /dev/null +++ b/x/multiserver/multiserver_test.go @@ -0,0 +1,112 @@ +package multiserver_test + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zerofox-oss/go-msg" + "github.com/zerofox-oss/go-msg/backends/mem" + "github.com/zerofox-oss/go-msg/x/multiserver" + "pgregory.net/rapid" +) + +func sendMessages(ctx context.Context, inputChan chan *msg.Message) { + for { + select { + case <-ctx.Done(): + return + case inputChan <- &msg.Message{ + Body: bytes.NewBuffer([]byte("hello world")), + Attributes: msg.Attributes{}, + }: + } + } +} + +func TestMultiServer(t *testing.T) { + t.Parallel() + + for i := 0; i <= 10; i++ { + t.Run(fmt.Sprintf("TestMultiServer_%d", i), func(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + numServers := rapid.IntRange(1, 10).Draw(t, "numServers") + + serverConcurrency := 10 + inputChanBuffer := 100 + + counts := make([]atomic.Int32, numServers) + inputChans := make([]chan *msg.Message, numServers) + serverWeights := make([]multiserver.ServerWeight, numServers) + + for i := 0; i < numServers; i++ { + inputChan := make(chan *msg.Message, inputChanBuffer) + server := mem.NewServer(inputChan, serverConcurrency) + weight := rapid.IntRange(1, 10).Draw(t, "weight") + serverWeights[i] = multiserver.ServerWeight{ + Server: server, + Weight: float64(weight), + } + inputChans[i] = inputChan + } + + mserver, err := multiserver.NewMultiServer(serverConcurrency, serverWeights) + assert.NoError(t, err) + + go func() { + mserver.Serve(msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + msgPriority := m.Attributes.Get(multiserver.MultiServerMsgPriority) + p, err := strconv.Atoi(msgPriority) + if err != nil { + assert.Fail(t, "failed to parse priority") + } + counts[p].Add(1) + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + return nil + })) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for i := 0; i < numServers; i++ { + go sendMessages(ctx, inputChans[i]) + } + + <-ctx.Done() + mserver.Shutdown(context.Background()) + + totalWeight := 0 + totalCounts := 0 + for i := 0; i < numServers; i++ { + totalWeight += int(serverWeights[i].Weight) + totalCounts += int(counts[i].Load()) + } + + delta := float64(totalCounts) * float64(0.1) + + t.Logf("Total weight: %d\n", totalWeight) + t.Logf("Total counts: %d\n", totalCounts) + t.Logf("Total counts delta: %f\n", delta) + + for i := 0; i < numServers; i++ { + weight := int(serverWeights[i].Weight) + expectedCount := (weight * totalCounts) / totalWeight + serverCount := int(counts[i].Load()) + t.Logf("Server %d: weight %d, expected count %d, actual count %d\n", i, weight, expectedCount, serverCount) + assert.InDeltaf( + t, expectedCount, serverCount, delta, + "Server %d: weight %d, expected count %d, actual count %d\n", i, weight, expectedCount, serverCount) + } + }) + }) + } +} diff --git a/x/multiserver/receiver.go b/x/multiserver/receiver.go new file mode 100644 index 0000000..704153b --- /dev/null +++ b/x/multiserver/receiver.go @@ -0,0 +1,180 @@ +package multiserver + +import ( + "context" + "fmt" + "math" + "time" + + pq "github.com/JimWen/gods-generic/queues/priorityqueue" + "github.com/JimWen/gods-generic/utils" + "github.com/asecurityteam/rolling" + msg "github.com/zerofox-oss/go-msg" +) + +// MultiServerMsgPriority is the key used to store the priority of a message in the message attributes. +const MultiServerMsgPriority = "x-multiserver-priority" + +type weightedMessage struct { + msg *msg.Message + context context.Context + vFinish float64 + priority int + doneChan chan error +} + +// WeightedFairReceiver implements a "fair" receiver that processes messages based on their priority. +// See https://en.wikipedia.org/wiki/Fair_queuing for more information, about fairness. +// The concrete implementation is based on the Weighted Fair Queuing algorithm. +// At a high level the algorithm works as follows: +// 1. Each message is assigned a virtual finish time. The virtual finish time is defined as the +// time of the last message processed for that priority plus a weighting factor. The higher the +// the weight the smaller the virtual finish time is for that message. +// 2. The message is then enqueued in a priority queue based on the virtual finish time. +// 3. Every "tick" we select the message with the smallest virtual finish time and process it. +type WeightedFairReceiver struct { + // underlying receiver + receiver msg.Receiver + + // time to wait for a message to arrive + // the longer you wait, the better the fairness + // the shorter the better the latency. + queueWaitTime time.Duration + + // weights for each priority level + // the higher the weight, the more messages + // will be processed for that priority + weights []float64 + + // max number of concurrent messages that can + // be processed at the same time by the receiver + maxConcurrent int + + // a rough estimate time to process a message + // this should be in milliseconds. + initialEstimatedCost float64 + + receiveChan chan *weightedMessage + queue *pq.Queue[*weightedMessage] + startTime int + lastVFinish []float64 + + timePolicy *rolling.TimePolicy + closeChan chan chan error +} + +// NewWeightedFairReceiver creates a new WeightedFairReceiver. +// The receiver will process messages based on their priority level. +func NewWeightedFairReceiver( + weights []float64, + maxConcurrent int, + queueWaitTime time.Duration, + receiver msg.Receiver, +) *WeightedFairReceiver { + priorityQueue := pq.NewWith(func(a, b *weightedMessage) int { + return utils.NumberComparator(a.vFinish, b.vFinish) + }) + + wfr := &WeightedFairReceiver{ + weights: weights, + lastVFinish: make([]float64, len(weights)), + receiver: receiver, + queue: priorityQueue, + startTime: 1, + queueWaitTime: queueWaitTime, + maxConcurrent: maxConcurrent, + receiveChan: make(chan *weightedMessage), + timePolicy: rolling.NewTimePolicy(rolling.NewWindow(10000), 1*time.Millisecond), + initialEstimatedCost: 100.0, + closeChan: make(chan chan error), + } + + go wfr.dispatch() + return wfr +} + +// WithPriorityReceiver returns a new msg.Receiver that should be used to receive messages +// at a specific priority level. +func (w *WeightedFairReceiver) WithPriorityReceiver(priority int) msg.Receiver { + return msg.ReceiverFunc(func(ctx context.Context, m *msg.Message) error { + return w.Receive(ctx, m, priority) + }) +} + +// Receive receives a message with a specific priority level. +func (w *WeightedFairReceiver) Receive(ctx context.Context, m *msg.Message, priority int) error { + wm := &weightedMessage{ + msg: m, + priority: priority, + doneChan: make(chan error, 1), + context: ctx, + } + w.receiveChan <- wm + return <-wm.doneChan +} + +// Close closes the receiver. +func (w *WeightedFairReceiver) Close(ctx context.Context) error { + doneChan := make(chan error) + w.closeChan <- doneChan + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-doneChan: + return err + } +} + +func (w *WeightedFairReceiver) estimateCost() float64 { + return w.timePolicy.Reduce(rolling.Avg) +} + +func (w *WeightedFairReceiver) dispatch() { + maxConcurrentReceives := make(chan struct{}, w.maxConcurrent) + timer := time.NewTicker(w.queueWaitTime) + + doReceive := func() { + select { + case maxConcurrentReceives <- struct{}{}: + if mw, ok := w.queue.Dequeue(); ok { + go func(wm weightedMessage) { + defer func() { + <-maxConcurrentReceives + }() + st := time.Now() + wm.msg.Attributes.Set(MultiServerMsgPriority, fmt.Sprint(wm.priority)) + result := w.receiver.Receive(wm.context, wm.msg) + w.timePolicy.Append(float64(time.Since(st).Milliseconds())) + wm.doneChan <- result + }(*mw) + } else { + <-maxConcurrentReceives + } + default: + } + } + + for { + select { + case doneChan := <-w.closeChan: + doneChan <- nil + // TODO: process remaining messages + // that are queued in order to cleanly shutdown + return + case wm := <-w.receiveChan: + vStart := math.Max(float64(w.startTime), w.lastVFinish[wm.priority]) + estimatedCost := w.estimateCost() + if math.IsNaN(estimatedCost) { + estimatedCost = w.initialEstimatedCost + } + estimatedCost = math.Round(estimatedCost) + weight2 := float64(estimatedCost / w.weights[wm.priority]) + vFinish := vStart + weight2 + wm.vFinish = vFinish + w.lastVFinish[wm.priority] = vFinish + w.queue.Enqueue(wm) + case <-timer.C: + doReceive() + } + } +}