From 8bdd3ade4b4ee7145bd3bbbdaabb06b63d6487b1 Mon Sep 17 00:00:00 2001 From: alice <90381261+alice-yyds@users.noreply.github.com> Date: Wed, 27 Sep 2023 19:34:16 +0800 Subject: [PATCH] chore: release v0.7.2 (#1127) Co-authored-by: QihengZhou Co-authored-by: Li2CO3 <45219850+HeyJavaBean@users.noreply.github.com> Co-authored-by: Felix021 Co-authored-by: Joway Co-authored-by: qiheng.zhou Co-authored-by: kinggo Co-authored-by: Jayant Co-authored-by: Z.Q.K --- .github/workflows/tests.yml | 28 +++ client/client.go | 28 ++- client/option_test.go | 2 +- client/service_inline.go | 4 +- client/service_inline_test.go | 35 +++- client/stream.go | 2 +- go.mod | 4 +- go.sum | 12 +- pkg/connpool/config.go | 6 +- pkg/connpool/config_test.go | 2 +- pkg/remote/codec/header_codec_test.go | 4 +- .../codec/protobuf/encoding/encoding.go | 39 ++++ pkg/remote/codec/protobuf/grpc.go | 58 ++++-- pkg/remote/codec/protobuf/grpc_compress.go | 58 +++--- pkg/remote/codec/protobuf/grpc_test.go | 15 -- .../codec/thrift/thrift_frugal_amd64.go | 4 +- .../codec/thrift/thrift_frugal_amd64_test.go | 4 +- pkg/remote/codec/thrift/thrift_others.go | 4 +- pkg/remote/compression.go | 38 ++-- pkg/remote/connpool/long_pool_test.go | 2 +- pkg/remote/trans/netpoll/bytebuf.go | 4 +- pkg/remote/trans/nphttp2/client_conn.go | 25 +-- pkg/remote/trans/nphttp2/client_conn_test.go | 26 +++ pkg/remote/trans/nphttp2/client_handler.go | 4 +- pkg/remote/trans/nphttp2/grpc/http2_server.go | 7 +- pkg/remote/trans/nphttp2/server_conn.go | 8 +- pkg/remote/trans/nphttp2/server_handler.go | 6 +- pkg/retry/backup_retryer.go | 44 +++-- pkg/retry/failure_retryer.go | 14 +- pkg/retry/percentage_limit.go | 31 +++ pkg/retry/policy_test.go | 13 +- pkg/retry/retryer.go | 142 ++++++++++++-- pkg/retry/retryer_test.go | 177 ++++++++++++++++-- pkg/retry/util.go | 34 ++-- pkg/retry/util_test.go | 64 +++++++ pkg/rpcinfo/interface.go | 1 + pkg/rpcinfo/mocks_test.go | 1 + pkg/rpcinfo/remoteinfo/remoteInfo.go | 54 +----- pkg/rpcinfo/remoteinfo/remoteInfo_test.go | 35 +--- pkg/rpcinfo/rpcstats.go | 18 ++ pkg/rpcinfo/rpcstats_test.go | 16 ++ pkg/stats/event.go | 5 + pkg/utils/strings.go | 81 +++++++- pkg/utils/strings_test.go | 60 ++++++ server/server.go | 6 +- server/server_test.go | 16 +- tool/internal_pkg/generator/type.go | 5 + .../pluginmode/thriftgo/file_tpl.go | 7 + tool/internal_pkg/tpl/service.go | 2 +- transport/keys.go | 1 - version.go | 2 +- 51 files changed, 958 insertions(+), 300 deletions(-) delete mode 100644 pkg/remote/codec/protobuf/grpc_test.go create mode 100644 pkg/retry/percentage_limit.go create mode 100644 pkg/retry/util_test.go create mode 100644 pkg/utils/strings_test.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1207d2703d..2bc5a5319d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -49,3 +49,31 @@ jobs: go-version: ${{ matrix.go }} - name: Unit Test run: go test -gcflags=-l -race -covermode=atomic ./... + + codegen-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '1.17' + - name: Prepare + run: | + go install github.com/cloudwego/thriftgo@main + go install ./tool/cmd/kitex + LOCAL_REPO=$(pwd) + cd .. + git clone https://github.com/cloudwego/kitex-tests.git + cd kitex-tests/codegen + go mod init codegen-test + go mod edit -replace=github.com/apache/thrift=github.com/apache/thrift@v0.13.0 + go mod edit -replace github.com/cloudwego/kitex=${LOCAL_REPO} + go mod tidy + bash -version + bash ./codegen_install_check.sh + - name: CodeGen + run: | + cd ../kitex-tests/codegen + tree + bash ./codegen_run.sh diff --git a/client/client.go b/client/client.go index 1d022a94b9..ee34c7ebab 100644 --- a/client/client.go +++ b/client/client.go @@ -302,8 +302,8 @@ func richMWsWithBuilder(ctx context.Context, mwBs []endpoint.MiddlewareBuilder) } // initRPCInfo initializes the RPCInfo structure and attaches it to context. -func (kc *kClient) initRPCInfo(ctx context.Context, method string, retryTimes int) (context.Context, rpcinfo.RPCInfo, *callopt.CallOptions) { - return initRPCInfo(ctx, method, kc.opt, kc.svcInfo, retryTimes) +func (kc *kClient) initRPCInfo(ctx context.Context, method string, retryTimes int, firstRI rpcinfo.RPCInfo) (context.Context, rpcinfo.RPCInfo, *callopt.CallOptions) { + return initRPCInfo(ctx, method, kc.opt, kc.svcInfo, retryTimes, firstRI) } func applyCallOptions(ctx context.Context, cfg rpcinfo.MutableRPCConfig, svr remoteinfo.RemoteInfo, opt *client.Options) (context.Context, *callopt.CallOptions) { @@ -325,10 +325,7 @@ func (kc *kClient) Call(ctx context.Context, method string, request, response in validateForCall(ctx, kc.inited, kc.closed) var ri rpcinfo.RPCInfo var callOpts *callopt.CallOptions - ctx, ri, callOpts = kc.initRPCInfo(ctx, method, 0) - if callOpts != nil && callOpts.CompressorName != "" { - ctx = remote.SetSendCompressor(ctx, callOpts.CompressorName) - } + ctx, ri, callOpts = kc.initRPCInfo(ctx, method, 0, nil) ctx = kc.opt.TracerCtl.DoStart(ctx, ri) var reportErr error @@ -364,7 +361,7 @@ func (kc *kClient) Call(ctx context.Context, method string, request, response in recycleRI = true } } else { - recycleRI, err = kc.opt.RetryContainer.WithRetryIfNeeded(ctx, callOptRetry, kc.rpcCallWithRetry(ri, method, request, response), ri, request) + ri, recycleRI, err = kc.opt.RetryContainer.WithRetryIfNeeded(ctx, callOptRetry, kc.rpcCallWithRetry(ri, method, request, response), ri, request) } // do fallback if with setup @@ -381,7 +378,7 @@ func (kc *kClient) rpcCallWithRetry(ri rpcinfo.RPCInfo, method string, request, currCallTimes := int(atomic.AddInt32(&callTimes, 1)) cRI := ri if currCallTimes > 1 { - ctx, cRI, _ = kc.initRPCInfo(ctx, method, currCallTimes-1) + ctx, cRI, _ = kc.initRPCInfo(ctx, method, currCallTimes-1, ri) ctx = metainfo.WithPersistentValue(ctx, retry.TransitKey, strconv.Itoa(currCallTimes-1)) if prevRI.Load() == nil { prevRI.Store(ri) @@ -671,12 +668,17 @@ func getFallbackPolicy(cliOptFB *fallback.Policy, callOpts *callopt.CallOptions) return nil, false } -func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInfo *serviceinfo.ServiceInfo, retryTimes int) (context.Context, rpcinfo.RPCInfo, *callopt.CallOptions) { +func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInfo *serviceinfo.ServiceInfo, retryTimes int, firstRI rpcinfo.RPCInfo) (context.Context, rpcinfo.RPCInfo, *callopt.CallOptions) { cfg := rpcinfo.AsMutableRPCConfig(opt.Configs).Clone() rmt := remoteinfo.NewRemoteInfo(opt.Svr, method) var callOpts *callopt.CallOptions ctx, callOpts = applyCallOptions(ctx, cfg, rmt, opt) - rpcStats := rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats()) + var rpcStats rpcinfo.MutableRPCStats + if firstRI != nil { + rpcStats = rpcinfo.AsMutableRPCStats(firstRI.Stats().CopyForRetry()) + } else { + rpcStats = rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats()) + } if opt.StatsLevel != nil { rpcStats.SetLevel(*opt.StatsLevel) } @@ -712,5 +714,11 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf } ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + + if callOpts != nil && callOpts.CompressorName != "" { + // set send grpc compressor at client to tell how to server decode + remote.SetSendCompressor(ri, callOpts.CompressorName) + } + return ctx, ri, callOpts } diff --git a/client/option_test.go b/client/option_test.go index d5fe530544..d10f0b424e 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -650,7 +650,7 @@ func TestWithLongConnectionOption(t *testing.T) { opts := client.NewOptions(options) test.Assert(t, opts.PoolCfg.MaxIdleTimeout == 30*time.Second) // defaultMaxIdleTimeout test.Assert(t, opts.PoolCfg.MaxIdlePerAddress == 1) // default - test.Assert(t, opts.PoolCfg.MaxIdleGlobal == 1) // default + test.Assert(t, opts.PoolCfg.MaxIdleGlobal == 1<<20) // default } func TestWithWarmingUpOption(t *testing.T) { diff --git a/client/service_inline.go b/client/service_inline.go index ed11c1aa71..a5967bcb60 100644 --- a/client/service_inline.go +++ b/client/service_inline.go @@ -27,6 +27,7 @@ import ( "github.com/cloudwego/kitex/client/callopt" "github.com/cloudwego/kitex/internal/client" internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/pkg/consts" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -125,7 +126,7 @@ func (kc *serviceInlineClient) initMiddlewares(ctx context.Context) { // initRPCInfo initializes the RPCInfo structure and attaches it to context. func (kc *serviceInlineClient) initRPCInfo(ctx context.Context, method string) (context.Context, rpcinfo.RPCInfo, *callopt.CallOptions) { - return initRPCInfo(ctx, method, kc.opt, kc.svcInfo, 0) + return initRPCInfo(ctx, method, kc.opt, kc.svcInfo, 0, nil) } // Call implements the Client interface . @@ -207,6 +208,7 @@ func (kc *serviceInlineClient) constructServerRPCInfo(svrCtx, cliCtx context.Con ink.SetServiceName(kc.svcInfo.ServiceName) } rpcinfo.AsMutableEndpointInfo(ri.To()).SetMethod(method) + svrCtx = context.WithValue(svrCtx, consts.CtxKeyMethod, method) return svrCtx, ri } diff --git a/client/service_inline_test.go b/client/service_inline_test.go index 65d2e36615..bd372378f5 100644 --- a/client/service_inline_test.go +++ b/client/service_inline_test.go @@ -32,6 +32,7 @@ import ( "github.com/cloudwego/kitex/internal/mocks" internal_server "github.com/cloudwego/kitex/internal/server" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/consts" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -39,9 +40,14 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -type serverInitialInfoImpl struct{} +type serverInitialInfoImpl struct { + EndpointsFunc func(ctx context.Context, req, resp interface{}) (err error) +} func (s serverInitialInfoImpl) Endpoints() endpoint.Endpoint { + if s.EndpointsFunc != nil { + return s.EndpointsFunc + } return func(ctx context.Context, req, resp interface{}) (err error) { return nil } @@ -67,8 +73,8 @@ func newMockServiceInlineClient(tb testing.TB, ctrl *gomock.Controller, extra .. WithDestService("destService"), } opts = append(opts, extra...) - svcInfo := mocks.ServiceInfo() + cli, err := NewServiceInlineClient(svcInfo, newMockServerInitialInfo(), opts...) test.Assert(tb, err == nil) @@ -341,3 +347,28 @@ func TestServiceInlineClientFinalizer(t *testing.T) { t.Logf("After second GC, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) test.Assert(t, secondGCHeapAlloc < firstGCHeapAlloc/2 && secondGCHeapObjects < firstGCHeapObjects/2) } + +func TestServiceInlineMethodKeyCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mtd := mocks.MockMethod + opts := []Option{ + WithTransHandlerFactory(newMockCliTransHandlerFactory(ctrl)), + WithResolver(resolver404(ctrl)), + WithDialer(newDialer(ctrl)), + WithDestService("destService"), + } + svcInfo := mocks.ServiceInfo() + s := serverInitialInfoImpl{} + s.EndpointsFunc = func(ctx context.Context, req, resp interface{}) (err error) { + test.Assert(t, ctx.Value(consts.CtxKeyMethod) == mtd) + return nil + } + cli, err := NewServiceInlineClient(svcInfo, s, opts...) + test.Assert(t, err == nil) + ctx := context.Background() + req := new(MockTStruct) + res := new(MockTStruct) + err = cli.Call(ctx, mtd, req, res) + test.Assert(t, err == nil, err) +} diff --git a/client/stream.go b/client/stream.go index affcbb118a..88806fdf78 100644 --- a/client/stream.go +++ b/client/stream.go @@ -45,7 +45,7 @@ func (kc *kClient) Stream(ctx context.Context, method string, request, response panic("ctx is nil") } var ri rpcinfo.RPCInfo - ctx, ri, _ = kc.initRPCInfo(ctx, method, 0) + ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil) rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) diff --git a/go.mod b/go.mod index 30e1499362..edf0406982 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,9 @@ require ( github.com/cloudwego/configmanager v0.2.0 github.com/cloudwego/dynamicgo v0.1.3 github.com/cloudwego/fastpb v0.0.4 - github.com/cloudwego/frugal v0.1.7 + github.com/cloudwego/frugal v0.1.8 github.com/cloudwego/localsession v0.0.2 - github.com/cloudwego/netpoll v0.4.1 + github.com/cloudwego/netpoll v0.5.0 github.com/cloudwego/thriftgo v0.3.0 github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 diff --git a/go.sum b/go.sum index 809f99eab1..bcd5830029 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,8 @@ github.com/cloudwego/fastpb v0.0.4 h1:/ROVVfoFtpfc+1pkQLzGs+azjxUbSOsAqSY4tAAx4m github.com/cloudwego/fastpb v0.0.4/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= github.com/cloudwego/frugal v0.1.3/go.mod h1:b981ViPYdhI56aFYsoMjl9kv6yeqYSO+iEz2jrhkCgI= github.com/cloudwego/frugal v0.1.6/go.mod h1:9ElktKsh5qd2zDBQ5ENhPSQV7F2dZ/mXlr1eaZGDBFs= -github.com/cloudwego/frugal v0.1.7 h1:Ggyk8mk0WrhBlM4g4RJxdOcVWJl/Hxbd8NJ19J8My6c= -github.com/cloudwego/frugal v0.1.7/go.mod h1:3VECBCSiTYwm3QApqHXjZB9NDH+8hUw7txxlr+6pPb4= +github.com/cloudwego/frugal v0.1.8 h1:MaJDRfvSnepsbUyMlQA9cySJ2+Y/we+r57tv5txx3sE= +github.com/cloudwego/frugal v0.1.8/go.mod h1:F0mLIWHymuQgh6r8N0owTA/ARv1B4SOiKa88tpOAfEU= github.com/cloudwego/kitex v0.3.2/go.mod h1:/XD07VpUD9VQWmmoepASgZ6iw//vgWikVA9MpzLC5i0= github.com/cloudwego/kitex v0.4.4/go.mod h1:3FcH5h9Qw+dhRljSzuGSpWuThttA8DvK0BsL7HUYydo= github.com/cloudwego/kitex v0.6.1/go.mod h1:zI1GBrjT0qloTikcCfQTgxg3Ws+yQMyaChEEOcGNUvA= @@ -62,8 +62,12 @@ github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiX github.com/cloudwego/netpoll v0.2.4/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= github.com/cloudwego/netpoll v0.3.1/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= github.com/cloudwego/netpoll v0.4.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= -github.com/cloudwego/netpoll v0.4.1 h1:/pGsY7Rs09KqEXEniB9fcsEWfi1iY+66bKUO3/NO6hc= -github.com/cloudwego/netpoll v0.4.1/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.4.2-0.20230913081710-1a27688e2033 h1:/VYzCYH+Brp8CW1u475U+mPS7lHv5ulKx0vFJbp3YZ0= +github.com/cloudwego/netpoll v0.4.2-0.20230913081710-1a27688e2033/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.4.2-0.20230918061532-5719b5310f34 h1:AbZPQaXr7MzOiUf1OZauww5rjmBpeLlyhM+hD7UsCn8= +github.com/cloudwego/netpoll v0.4.2-0.20230918061532-5719b5310f34/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= +github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/cloudwego/thriftgo v0.1.2/go.mod h1:LzeafuLSiHA9JTiWC8TIMIq64iadeObgRUhmVG1OC/w= github.com/cloudwego/thriftgo v0.2.4/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4= github.com/cloudwego/thriftgo v0.2.7/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4= diff --git a/pkg/connpool/config.go b/pkg/connpool/config.go index 61496138dc..53511e611f 100644 --- a/pkg/connpool/config.go +++ b/pkg/connpool/config.go @@ -30,6 +30,7 @@ const ( defaultMaxIdleTimeout = 30 * time.Second minMaxIdleTimeout = 2 * time.Second maxMinIdlePerAddress = 5 + defaultMaxIdleGlobal = 1 << 20 // no limit ) // CheckPoolConfig to check invalid param. @@ -58,9 +59,8 @@ func CheckPoolConfig(config IdleConfig) *IdleConfig { // globalIdle if config.MaxIdleGlobal <= 0 { - config.MaxIdleGlobal = 1 - } - if config.MaxIdleGlobal < config.MaxIdlePerAddress { + config.MaxIdleGlobal = defaultMaxIdleGlobal + } else if config.MaxIdleGlobal < config.MaxIdlePerAddress { config.MaxIdleGlobal = config.MaxIdlePerAddress } return &config diff --git a/pkg/connpool/config_test.go b/pkg/connpool/config_test.go index ccde0ef092..1b5bc9ccd6 100644 --- a/pkg/connpool/config_test.go +++ b/pkg/connpool/config_test.go @@ -33,7 +33,7 @@ func TestCheckPoolConfig(t *testing.T) { cfg = CheckPoolConfig(IdleConfig{MinIdlePerAddress: -1}) test.Assert(t, cfg.MinIdlePerAddress == 0) test.Assert(t, cfg.MaxIdlePerAddress == 1) - test.Assert(t, cfg.MaxIdleGlobal == 1) + test.Assert(t, cfg.MaxIdleGlobal == defaultMaxIdleGlobal) cfg = CheckPoolConfig(IdleConfig{MinIdlePerAddress: 1}) test.Assert(t, cfg.MinIdlePerAddress == 1) cfg = CheckPoolConfig(IdleConfig{MinIdlePerAddress: maxMinIdlePerAddress + 1}) diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index 1539cd0cfe..0c52ca7c5f 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -343,9 +343,9 @@ func (m *mockInst) Address() net.Addr { return m.addr } -func (m *mockInst) SetRemoteAddr(addr net.Addr) (ok bool) { +func (m *mockInst) RefreshInstanceWithAddr(addr net.Addr) discovery.Instance { m.addr = addr - return true + return m } func (m *mockInst) Weight() int { diff --git a/pkg/remote/codec/protobuf/encoding/encoding.go b/pkg/remote/codec/protobuf/encoding/encoding.go index 8ad25990a0..be03f7f88e 100644 --- a/pkg/remote/codec/protobuf/encoding/encoding.go +++ b/pkg/remote/codec/protobuf/encoding/encoding.go @@ -28,6 +28,7 @@ package encoding import ( + "fmt" "io" "strings" ) @@ -82,6 +83,44 @@ func GetCompressor(name string) Compressor { return registeredCompressor[name] } +// FindCompressorName returns the name of compressor that actually used. +// when cname is like "identity,deflate,gzip", only one compressor name should be returned. +func FindCompressorName(cname string) string { + compressor, _ := FindCompressor(cname) + if compressor != nil { + return compressor.Name() + } + return "" +} + +// FindCompressor is used to search for compressors based on a given name, where the input name can be an array of compressor names. +func FindCompressor(cname string) (compressor Compressor, err error) { + // if cname is empty, it means there's no compressor + if cname == "" { + return nil, nil + } + // cname can be an array, such as "identity,deflate,gzip", which means there should be at least one compressor registered. + // found available compressors + var hasIdentity bool + for _, name := range strings.Split(strings.TrimSuffix(cname, ";"), ",") { + name = strings.TrimSpace(name) + if name == Identity { + hasIdentity = true + } + compressor = GetCompressor(name) + if compressor != nil { + break + } + } + if compressor == nil { + if hasIdentity { + return nil, nil + } + return nil, fmt.Errorf("no kitex compressor registered found for:%v", cname) + } + return compressor, nil +} + // Codec defines the interface gRPC uses to encode and decode messages. Note // that implementations of this interface must be thread safe; a Codec's // methods can be called from concurrent goroutines. diff --git a/pkg/remote/codec/protobuf/grpc.go b/pkg/remote/codec/protobuf/grpc.go index 0774d405a7..6ef4b2e9fc 100644 --- a/pkg/remote/codec/protobuf/grpc.go +++ b/pkg/remote/codec/protobuf/grpc.go @@ -18,8 +18,10 @@ package protobuf import ( "context" + "encoding/binary" "fmt" + "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/fastpb" "google.golang.org/protobuf/proto" @@ -46,19 +48,49 @@ func NewGRPCCodec() remote.Codec { return new(grpcCodec) } +func mallocWithFirstByteZeroed(size int) []byte { + data := mcache.Malloc(size) + data[0] = 0 // compressed flag = false + return data +} + func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) { writer, ok := out.(remote.FrameWrite) if !ok { return fmt.Errorf("output buffer must implement FrameWrite") } + compressor, err := getSendCompressor(ctx) + if err != nil { + return err + } + isCompressed := compressor != nil + var payload []byte switch t := message.Data().(type) { case fastpb.Writer: - payload = make([]byte, t.Size()) + size := t.Size() + if !isCompressed { + payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) + t.FastWrite(payload[dataFrameHeaderLen:]) + binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) + return writer.WriteData(payload) + } + payload = mcache.Malloc(size) t.FastWrite(payload) case marshaler: - payload = make([]byte, t.Size()) - _, err = t.MarshalTo(payload) + size := t.Size() + if !isCompressed { + payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) + if _, err = t.MarshalTo(payload[dataFrameHeaderLen:]); err != nil { + return err + } + binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) + return writer.WriteData(payload) + } + payload = mcache.Malloc(size) + if _, err = t.MarshalTo(payload); err != nil { + return err + } case protobufV2MsgCodec: payload, err = t.XXX_Marshal(nil, true) case proto.Message: @@ -69,18 +101,22 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo if err != nil { return err } - - hdr, data, er := buildGRPCFrame(ctx, payload) - if er != nil { - return er + var header [dataFrameHeaderLen]byte + if isCompressed { + payload, err = compress(compressor, payload) + if err != nil { + return err + } + header[0] = 1 + } else { + header[0] = 0 } - - err = writer.WriteHeader(hdr) + binary.BigEndian.PutUint32(header[1:dataFrameHeaderLen], uint32(len(payload))) + err = writer.WriteHeader(header[:]) if err != nil { return err } - - return writer.WriteData(data) + return writer.WriteData(payload) } func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { diff --git a/pkg/remote/codec/protobuf/grpc_compress.go b/pkg/remote/codec/protobuf/grpc_compress.go index b53ad4b118..5c7ce822df 100644 --- a/pkg/remote/codec/protobuf/grpc_compress.go +++ b/pkg/remote/codec/protobuf/grpc_compress.go @@ -20,30 +20,29 @@ import ( "bytes" "context" "encoding/binary" - "fmt" + "errors" "io" + "github.com/cloudwego/kitex/pkg/rpcinfo" + + "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" "github.com/cloudwego/kitex/pkg/remote" ) -func buildGRPCFrame(ctx context.Context, payload []byte) ([]byte, []byte, error) { - data, isCompressed, err := compress(ctx, payload) - if err != nil { - return nil, nil, err - } - header := make([]byte, dataFrameHeaderLen) - if isCompressed { - header[0] = 1 - } else { - header[0] = 0 - } - binary.BigEndian.PutUint32(header[1:dataFrameHeaderLen], uint32(len(data))) - return header, data, nil +func getSendCompressor(ctx context.Context) (encoding.Compressor, error) { + ri := rpcinfo.GetRPCInfo(ctx) + return encoding.FindCompressor(remote.GetSendCompressor(ri)) } func decodeGRPCFrame(ctx context.Context, in remote.ByteBuffer) ([]byte, error) { + ri := rpcinfo.GetRPCInfo(ctx) + compressor, err := encoding.FindCompressor(remote.GetRecvCompressor(ri)) + if err != nil { + return nil, err + } hdr, err := in.Next(5) if err != nil { return nil, err @@ -55,40 +54,31 @@ func decodeGRPCFrame(ctx context.Context, in remote.ByteBuffer) ([]byte, error) return nil, err } if compressFlag == 1 { - return decompress(ctx, d) + if compressor == nil { + return nil, errors.New("kitex compression algorithm not found") + } + return decompress(compressor, d) } return d, nil } -func compress(ctx context.Context, data []byte) ([]byte, bool, error) { - cname := remote.GetSendCompressor(ctx) - if cname == "" { - return data, false, nil - } - compressor := encoding.GetCompressor(cname) - if compressor == nil { - return nil, false, fmt.Errorf("no compressor registered for: %s", cname) - } +func compress(compressor encoding.Compressor, data []byte) ([]byte, error) { + defer mcache.Free(data) cbuf := &bytes.Buffer{} z, err := compressor.Compress(cbuf) if err != nil { - return nil, false, err + return nil, err } if _, err = z.Write(data); err != nil { - return nil, false, err + return nil, err } if err = z.Close(); err != nil { - return nil, false, err + return nil, err } - return cbuf.Bytes(), true, nil + return cbuf.Bytes(), nil } -func decompress(ctx context.Context, data []byte) ([]byte, error) { - cname := remote.GetRecvCompressor(ctx) - compressor := encoding.GetCompressor(cname) - if compressor == nil { - return nil, fmt.Errorf("no compressor registered found for:%v", cname) - } +func decompress(compressor encoding.Compressor, data []byte) ([]byte, error) { dcReader, er := compressor.Decompress(bytes.NewReader(data)) if er != nil { return nil, er diff --git a/pkg/remote/codec/protobuf/grpc_test.go b/pkg/remote/codec/protobuf/grpc_test.go deleted file mode 100644 index 30851f90b3..0000000000 --- a/pkg/remote/codec/protobuf/grpc_test.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2023 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protobuf diff --git a/pkg/remote/codec/thrift/thrift_frugal_amd64.go b/pkg/remote/codec/thrift/thrift_frugal_amd64.go index 26cb979e93..d5778d2f1d 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_amd64.go +++ b/pkg/remote/codec/thrift/thrift_frugal_amd64.go @@ -1,5 +1,5 @@ -//go:build amd64 && !windows && go1.16 && !go1.21 -// +build amd64,!windows,go1.16,!go1.21 +//go:build amd64 && !windows && go1.16 && !go1.22 +// +build amd64,!windows,go1.16,!go1.22 /* * Copyright 2021 CloudWeGo Authors diff --git a/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go b/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go index 5ec4242328..d2da63ad99 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go @@ -1,5 +1,5 @@ -//go:build amd64 && !windows && go1.16 && !go1.21 -// +build amd64,!windows,go1.16,!go1.21 +//go:build amd64 && !windows && go1.16 && !go1.22 +// +build amd64,!windows,go1.16,!go1.22 /* * Copyright 2021 CloudWeGo Authors diff --git a/pkg/remote/codec/thrift/thrift_others.go b/pkg/remote/codec/thrift/thrift_others.go index 88e0ad8515..a91d3cdaf0 100644 --- a/pkg/remote/codec/thrift/thrift_others.go +++ b/pkg/remote/codec/thrift/thrift_others.go @@ -1,5 +1,5 @@ -//go:build !amd64 || windows || !go1.16 || go1.21 -// +build !amd64 windows !go1.16 go1.21 +//go:build !amd64 || windows || !go1.16 || go1.22 +// +build !amd64 windows !go1.16 go1.22 /* * Copyright 2021 CloudWeGo Authors diff --git a/pkg/remote/compression.go b/pkg/remote/compression.go index 8f8c9393e8..58499a21f3 100644 --- a/pkg/remote/compression.go +++ b/pkg/remote/compression.go @@ -16,7 +16,9 @@ package remote -import "context" +import ( + "github.com/cloudwego/kitex/pkg/rpcinfo" +) // CompressType tells compression type for a message. type CompressType int32 @@ -28,27 +30,37 @@ const ( GZip ) -type recvCompressorKey struct{} - -type sendCompressorKey struct{} - -func SetRecvCompressor(ctx context.Context, compressorName string) context.Context { - return context.WithValue(ctx, recvCompressorKey{}, compressorName) +func SetRecvCompressor(ri rpcinfo.RPCInfo, compressorName string) { + if ri == nil { + return + } + rpcinfo.AsMutableEndpointInfo(ri.From()).SetTag("recv-compressor", compressorName) } -func SetSendCompressor(ctx context.Context, compressorName string) context.Context { - return context.WithValue(ctx, sendCompressorKey{}, compressorName) +func SetSendCompressor(ri rpcinfo.RPCInfo, compressorName string) { + if ri == nil { + return + } + rpcinfo.AsMutableEndpointInfo(ri.From()).SetTag("send-compressor", compressorName) } -func GetSendCompressor(ctx context.Context) string { - if v, ok := ctx.Value(sendCompressorKey{}).(string); ok { +func GetSendCompressor(ri rpcinfo.RPCInfo) string { + if ri == nil { + return "" + } + v, exist := ri.From().Tag("send-compressor") + if exist { return v } return "" } -func GetRecvCompressor(ctx context.Context) string { - if v, ok := ctx.Value(recvCompressorKey{}).(string); ok { +func GetRecvCompressor(ri rpcinfo.RPCInfo) string { + if ri == nil { + return "" + } + v, exist := ri.From().Tag("recv-compressor") + if exist { return v } return "" diff --git a/pkg/remote/connpool/long_pool_test.go b/pkg/remote/connpool/long_pool_test.go index 8ffc1e477f..90c5052aa1 100644 --- a/pkg/remote/connpool/long_pool_test.go +++ b/pkg/remote/connpool/long_pool_test.go @@ -561,7 +561,7 @@ func TestLongConnPoolCloseOnIdleTimeout(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - idleTime := time.Millisecond + idleTime := time.Second p := newLongPoolForTest(0, 2, 5, idleTime) defer p.Close() diff --git a/pkg/remote/trans/netpoll/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf.go index a93d9a0ed1..a00f2cd174 100644 --- a/pkg/remote/trans/netpoll/bytebuf.go +++ b/pkg/remote/trans/netpoll/bytebuf.go @@ -255,8 +255,10 @@ func (b *netpollByteBuffer) Release(e error) (err error) { } func (b *netpollByteBuffer) zero() { - b.status = 0 b.writer = nil b.reader = nil + b.ioReader = nil + b.ioWriter = nil + b.status = 0 b.readSize = 0 } diff --git a/pkg/remote/trans/nphttp2/client_conn.go b/pkg/remote/trans/nphttp2/client_conn.go index 563c2b5835..876765db86 100644 --- a/pkg/remote/trans/nphttp2/client_conn.go +++ b/pkg/remote/trans/nphttp2/client_conn.go @@ -33,9 +33,14 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" ) +type streamDesc struct { + isStreaming bool +} + type clientConn struct { - tr grpc.ClientTransport - s *grpc.Stream + tr grpc.ClientTransport + s *grpc.Stream + desc *streamDesc } var _ GRPCConn = (*clientConn)(nil) @@ -63,7 +68,6 @@ func newClientConn(ctx context.Context, tr grpc.ClientTransport, addr string) (* } else { svcName = fmt.Sprintf("%s.%s", ri.Invocation().PackageName(), ri.Invocation().ServiceName()) } - host := ri.To().ServiceName() if rawURL, ok := ri.To().Tag(rpcinfo.HTTPURL); ok { u, err := url.Parse(rawURL) @@ -72,19 +76,20 @@ func newClientConn(ctx context.Context, tr grpc.ClientTransport, addr string) (* } host = u.Host } - + isStreaming := ri.Config().InteractionMode() == rpcinfo.Streaming s, err := tr.NewStream(ctx, &grpc.CallHdr{ Host: host, // grpc method format /package.Service/Method Method: fmt.Sprintf("/%s/%s", svcName, ri.Invocation().MethodName()), - SendCompress: remote.GetSendCompressor(ctx), + SendCompress: remote.GetSendCompressor(ri), }) if err != nil { return nil, err } return &clientConn{ - tr: tr, - s: s, + tr: tr, + s: s, + desc: &streamDesc{isStreaming: isStreaming}, }, nil } @@ -111,11 +116,7 @@ func (c *clientConn) Write(b []byte) (n int, err error) { } func (c *clientConn) WriteFrame(hdr, data []byte) (n int, err error) { - grpcConnOpt := &grpc.Options{} - // When there's no more data frame, add END_STREAM flag to this empty frame. - if hdr == nil && data == nil { - grpcConnOpt.Last = true - } + grpcConnOpt := &grpc.Options{Last: !c.desc.isStreaming} err = c.tr.Write(c.s, hdr, data, grpcConnOpt) return len(hdr) + len(data), err } diff --git a/pkg/remote/trans/nphttp2/client_conn_test.go b/pkg/remote/trans/nphttp2/client_conn_test.go index b647137202..d8bc004ed9 100644 --- a/pkg/remote/trans/nphttp2/client_conn_test.go +++ b/pkg/remote/trans/nphttp2/client_conn_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/rpcinfo" ) func TestClientConn(t *testing.T) { @@ -59,3 +60,28 @@ func TestClientConn(t *testing.T) { test.Assert(t, err != nil, err) test.Assert(t, n == 0) } + +func TestClientConnStreamDesc(t *testing.T) { + connPool := newMockConnPool() + + // streaming + ctx := newMockCtxWithRPCInfo() + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri != nil) + rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) + conn, err := connPool.Get(ctx, "tcp", mockAddr0, newMockConnOption()) + test.Assert(t, err == nil, err) + defer conn.Close() + cn := conn.(*clientConn) + test.Assert(t, cn != nil) + test.Assert(t, cn.desc.isStreaming == true) + + // pingpong + rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.PingPong) + conn, err = connPool.Get(ctx, "tcp", mockAddr0, newMockConnOption()) + test.Assert(t, err == nil, err) + defer conn.Close() + cn = conn.(*clientConn) + test.Assert(t, cn != nil) + test.Assert(t, cn.desc.isStreaming == false) +} diff --git a/pkg/remote/trans/nphttp2/client_handler.go b/pkg/remote/trans/nphttp2/client_handler.go index e098880681..add62c0e20 100644 --- a/pkg/remote/trans/nphttp2/client_handler.go +++ b/pkg/remote/trans/nphttp2/client_handler.go @@ -66,7 +66,9 @@ func (h *cliTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Me buf := newBuffer(conn.(*clientConn)) defer buf.Release(err) - ctx = remote.SetRecvCompressor(ctx, conn.(*clientConn).GetRecvCompress()) + // set recv grpc compressor at client to decode the pack from server + ri := rpcinfo.GetRPCInfo(ctx) + remote.SetRecvCompressor(ri, conn.(*clientConn).GetRecvCompress()) err = h.codec.Decode(ctx, msg, buf) if bizStatusErr, isBizErr := kerrors.FromBizStatusError(err); isBizErr { if setter, ok := msg.RPCInfo().Invocation().(rpcinfo.InvocationSetter); ok { diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index ee0a617763..6999fd9899 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -34,6 +34,8 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" + "github.com/cloudwego/netpoll" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -722,8 +724,9 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { headerFields := make([]hpack.HeaderField, 0, 3+s.header.Len()) // at least :status, content-type will be there if none else. headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) - if s.sendCompress != "" { - headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + sendCompress := encoding.FindCompressorName(s.sendCompress) + if sendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: sendCompress}) } headerFields = appendHeaderFieldsFromMD(headerFields, s.header) success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{ diff --git a/pkg/remote/trans/nphttp2/server_conn.go b/pkg/remote/trans/nphttp2/server_conn.go index 2d14d3b838..80011692d2 100644 --- a/pkg/remote/trans/nphttp2/server_conn.go +++ b/pkg/remote/trans/nphttp2/server_conn.go @@ -85,12 +85,8 @@ func (c *serverConn) Write(b []byte) (n int, err error) { } func (c *serverConn) WriteFrame(hdr, data []byte) (n int, err error) { - grpcConnOpt := &grpc.Options{} - // When there's no more data frame, add END_STREAM flag to this empty frame. - if hdr == nil && data == nil { - grpcConnOpt.Last = true - } - err = c.tr.Write(c.s, hdr, data, grpcConnOpt) + // server sets the END_STREAM flag in trailer when writeStatus + err = c.tr.Write(c.s, hdr, data, &grpc.Options{}) return len(hdr) + len(data), err } diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index 9386e2867e..0856a51a36 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -184,8 +184,10 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { ink.SetServiceName(sm[idx+1 : pos]) } - rCtx = remote.SetRecvCompressor(rCtx, s.RecvCompress()) - rCtx = remote.SetSendCompressor(rCtx, s.SendCompress()) + // set recv grpc compressor at server to decode the pack from client + remote.SetRecvCompressor(ri, s.RecvCompress()) + // set send grpc compressor at server to encode reply pack + remote.SetSendCompressor(ri, s.SendCompress()) st := NewStream(rCtx, t.svcInfo, newServerConn(tr, s), t) streamArg := &streaming.Args{Stream: st} diff --git a/pkg/retry/backup_retryer.go b/pkg/retry/backup_retryer.go index 47f4191eac..da3c80125c 100644 --- a/pkg/retry/backup_retryer.go +++ b/pkg/retry/backup_retryer.go @@ -32,6 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/utils" ) func newBackupRetryer(policy Policy, cbC *cbContainer) (Retryer, error) { @@ -83,14 +84,14 @@ func (r *backupRetryer) AllowRetry(ctx context.Context) (string, bool) { } // Do implement the Retryer interface. -func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, req interface{}) (recycleRI bool, err error) { +func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, req interface{}) (lastRI rpcinfo.RPCInfo, recycleRI bool, err error) { r.RLock() retryTimes := r.policy.StopPolicy.MaxRetryTimes retryDelay := r.retryDelay r.RUnlock() var callTimes int32 = 0 - var callCosts strings.Builder - callCosts.Grow(32) + var callCosts utils.StringBuilder + callCosts.RawStringBuilder().Grow(32) var recordCostDoing int32 = 0 var abort int32 = 0 // notice: buff num of chan is very important here, it cannot less than call times, or the below chan receive will block @@ -125,9 +126,13 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc }() ct := atomic.AddInt32(&callTimes, 1) callStart := time.Now() + if r.cbContainer.enablePercentageLimit { + // record stat before call since requests may be slow, making the limiter more accurate + recordRetryStat(cbKey, r.cbContainer.cbPanel, ct) + } cRI, _, e = rpcCall(ctx, r) recordCost(ct, callStart, &recordCostDoing, &callCosts, &abort, e) - if r.cbContainer.cbStat { + if !r.cbContainer.enablePercentageLimit && r.cbContainer.cbStat { circuitbreak.RecordStat(ctx, req, nil, e, cbKey, r.cbContainer.cbCtl, r.cbContainer.cbPanel) } }) @@ -145,8 +150,8 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc continue } atomic.StoreInt32(&abort, 1) - recordRetryInfo(firstRI, res.ri, atomic.LoadInt32(&callTimes), callCosts.String()) - return false, res.err + recordRetryInfo(res.ri, atomic.LoadInt32(&callTimes), callCosts.String()) + return res.ri, false, res.err } } } @@ -222,23 +227,26 @@ func (r *backupRetryer) Type() Type { } // record request cost, it may execute concurrent -func recordCost(ct int32, start time.Time, recordCostDoing *int32, sb *strings.Builder, abort *int32, err error) { +func recordCost(ct int32, start time.Time, recordCostDoing *int32, sb *utils.StringBuilder, abort *int32, err error) { if atomic.LoadInt32(abort) == 1 { return } for !atomic.CompareAndSwapInt32(recordCostDoing, 0, 1) { runtime.Gosched() } - if sb.Len() > 0 { - sb.WriteByte(',') - } - sb.WriteString(strconv.Itoa(int(ct))) - sb.WriteByte('-') - sb.WriteString(strconv.FormatInt(time.Since(start).Microseconds(), 10)) - if err != nil && errors.Is(err, kerrors.ErrRPCFinish) { - // ErrRPCFinish means previous call returns first but is decoding. - // Add ignore to distinguish. - sb.WriteString("(ignore)") - } + sb.WithLocked(func(b *strings.Builder) error { + if b.Len() > 0 { + b.WriteByte(',') + } + b.WriteString(strconv.Itoa(int(ct))) + b.WriteByte('-') + b.WriteString(strconv.FormatInt(time.Since(start).Microseconds(), 10)) + if err != nil && errors.Is(err, kerrors.ErrRPCFinish) { + // ErrRPCFinish means previous call returns first but is decoding. + // Add ignore to distinguish. + b.WriteString("(ignore)") + } + return nil + }) atomic.StoreInt32(recordCostDoing, 0) } diff --git a/pkg/retry/failure_retryer.go b/pkg/retry/failure_retryer.go index 3264f5bdce..078be145e5 100644 --- a/pkg/retry/failure_retryer.go +++ b/pkg/retry/failure_retryer.go @@ -84,7 +84,7 @@ func (r *failureRetryer) AllowRetry(ctx context.Context) (string, bool) { } // Do implement the Retryer interface. -func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, req interface{}) (recycleRI bool, err error) { +func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, req interface{}) (lastRI rpcinfo.RPCInfo, recycleRI bool, err error) { r.RLock() var maxDuration time.Duration if r.policy.StopPolicy.MaxDurationMS > 0 { @@ -127,10 +127,14 @@ func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rp } } callTimes++ + if r.cbContainer.enablePercentageLimit { + // record stat before call since requests may be slow, making the limiter more accurate + recordRetryStat(cbKey, r.cbContainer.cbPanel, callTimes) + } cRI, resp, err = rpcCall(ctx, r) callCosts.WriteString(strconv.FormatInt(time.Since(callStart).Microseconds(), 10)) - if r.cbContainer.cbStat { + if !r.cbContainer.enablePercentageLimit && r.cbContainer.cbStat { circuitbreak.RecordStat(ctx, req, nil, err, cbKey, r.cbContainer.cbCtl, r.cbContainer.cbPanel) } if err == nil { @@ -149,11 +153,11 @@ func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rp } } } - recordRetryInfo(firstRI, cRI, callTimes, callCosts.String()) + recordRetryInfo(cRI, callTimes, callCosts.String()) if err == nil && callTimes == 1 { - return true, nil + return cRI, true, nil } - return false, err + return cRI, false, err } // UpdatePolicy implements the Retryer interface. diff --git a/pkg/retry/percentage_limit.go b/pkg/retry/percentage_limit.go new file mode 100644 index 0000000000..d527644bc3 --- /dev/null +++ b/pkg/retry/percentage_limit.go @@ -0,0 +1,31 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import ( + "github.com/bytedance/gopkg/cloud/circuitbreaker" +) + +// treat retry as 'error' for limiting the percentage of retry requests. +// callTimes == 1 means it's the first request, not a retry. +func recordRetryStat(cbKey string, panel circuitbreaker.Panel, callTimes int32) { + if callTimes > 1 { + panel.Fail(cbKey) + } else { + panel.Succeed(cbKey) + } +} diff --git a/pkg/retry/policy_test.go b/pkg/retry/policy_test.go index ebfcd3eddc..e767c6aedd 100644 --- a/pkg/retry/policy_test.go +++ b/pkg/retry/policy_test.go @@ -25,6 +25,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" + "github.com/cloudwego/kitex/pkg/stats" ) var ( @@ -433,7 +434,17 @@ func TestPolicyNotRetryForTimeout(t *testing.T) { func genRPCInfo() rpcinfo.RPCInfo { to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Method: method}, method).ImmutableView() - ri := rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + riStats := rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats()) + riStats.SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method), rpcinfo.NewRPCConfig(), riStats.ImmutableView()) + return ri +} + +func genRPCInfoWithFirstStats(firstRI rpcinfo.RPCInfo) rpcinfo.RPCInfo { + to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Method: method}, method).ImmutableView() + riStats := rpcinfo.AsMutableRPCStats(firstRI.Stats().CopyForRetry()) + riStats.SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method), rpcinfo.NewRPCConfig(), riStats.ImmutableView()) return ri } diff --git a/pkg/retry/retryer.go b/pkg/retry/retryer.go index 8e2088bc86..dbced34ecc 100644 --- a/pkg/retry/retryer.go +++ b/pkg/retry/retryer.go @@ -47,7 +47,7 @@ type Retryer interface { UpdatePolicy(policy Policy) error // Retry policy execute func. recycleRI is to decide if the firstRI can be recycled. - Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, request interface{}) (recycleRI bool, err error) + Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpcinfo.RPCInfo, request interface{}) (lastRI rpcinfo.RPCInfo, recycleRI bool, err error) AppendErrMsgIfNeeded(err error, ri rpcinfo.RPCInfo, msg string) // Prepare to do something needed before retry call. @@ -57,7 +57,7 @@ type Retryer interface { } // NewRetryContainerWithCB build Container that doesn't do circuit breaker statistic but get statistic result. -// Which is used in case that circuit breaker is enable. +// Which is used in case that circuit breaker is enabled. // eg: // // cbs := circuitbreak.NewCBSuite(circuitbreak.RPCInfo2Key) @@ -67,9 +67,11 @@ type Retryer interface { // // enable service circuit breaker // opts = append(opts, client.WithMiddleware(cbs.ServiceCBMW())) func NewRetryContainerWithCB(cc *circuitbreak.Control, cp circuitbreaker.Panel) *Container { - return &Container{ - cbContainer: &cbContainer{cbCtl: cc, cbPanel: cp}, retryerMap: sync.Map{}, - } + return NewRetryContainer(WithContainerCBControl(cc), WithContainerCBPanel(cp)) +} + +func newCBSuite() *circuitbreak.CBSuite { + return circuitbreak.NewCBSuite(circuitbreak.RPCInfo2Key) } // NewRetryContainerWithCBStat build Container that need to do circuit breaker statistic. @@ -79,15 +81,91 @@ func NewRetryContainerWithCB(cc *circuitbreak.Control, cp circuitbreaker.Panel) // cbs := circuitbreak.NewCBSuite(YourGenServiceCBKeyFunc) // retry.NewRetryContainerWithCBStat(cbs.ServiceControl(), cbs.ServicePanel()) func NewRetryContainerWithCBStat(cc *circuitbreak.Control, cp circuitbreaker.Panel) *Container { - return &Container{ - cbContainer: &cbContainer{cbCtl: cc, cbPanel: cp, cbStat: true}, retryerMap: sync.Map{}, + return NewRetryContainer(WithContainerCBControl(cc), WithContainerCBPanel(cp), WithContainerCBStat()) +} + +// NewRetryContainerWithPercentageLimit build a Container to limiting the percentage of retry requests; +// This is the RECOMMENDED initializer if you want to control PRECISELY the percentage of retry requests. +func NewRetryContainerWithPercentageLimit() *Container { + return NewRetryContainer(WithContainerEnablePercentageLimit()) +} + +// ContainerOption is used when initializing a Container +type ContainerOption func(rc *Container) + +// WithContainerCBSuite specifies the CBSuite used in the retry circuitbreak +// retryer will use its ServiceControl and ServicePanel +// Its priority is lower than WithContainerCBControl and WithContainerCBPanel +func WithContainerCBSuite(cbs *circuitbreak.CBSuite) ContainerOption { + return func(rc *Container) { + rc.cbContainer.cbSuite = cbs + } +} + +// WithContainerCBControl is specifies the circuitbreak.Control used in the retry circuitbreaker +// It's user's responsibility to make sure it's paired with panel +func WithContainerCBControl(ctrl *circuitbreak.Control) ContainerOption { + return func(rc *Container) { + rc.cbContainer.cbCtl = ctrl + } +} + +// WithContainerCBPanel is specifies the circuitbreaker.Panel used in the retry circuitbreaker +// It's user's responsibility to make sure it's paired with control +func WithContainerCBPanel(panel circuitbreaker.Panel) ContainerOption { + return func(rc *Container) { + rc.cbContainer.cbPanel = panel + } +} + +// WithContainerCBStat instructs the circuitbreak.RecordStat is called within the retryer +func WithContainerCBStat() ContainerOption { + return func(rc *Container) { + rc.cbContainer.cbStat = true + } +} + +// WithContainerEnablePercentageLimit should be called for limiting the percentage of retry requests +func WithContainerEnablePercentageLimit() ContainerOption { + return func(rc *Container) { + rc.cbContainer.enablePercentageLimit = true } } // NewRetryContainer build Container that need to build circuit breaker and do circuit breaker statistic. -func NewRetryContainer() *Container { - cbs := circuitbreak.NewCBSuite(circuitbreak.RPCInfo2Key) - return NewRetryContainerWithCBStat(cbs.ServiceControl(), cbs.ServicePanel()) +// The caller is responsible for calling Container.Close() to release resources referenced. +func NewRetryContainer(opts ...ContainerOption) *Container { + rc := &Container{ + cbContainer: &cbContainer{ + cbSuite: nil, + }, + retryerMap: sync.Map{}, + } + for _, opt := range opts { + opt(rc) + } + + if rc.cbContainer.enablePercentageLimit { + // ignore cbSuite/cbCtl/cbPanel options + rc.cbContainer = &cbContainer{ + enablePercentageLimit: true, + cbSuite: newCBSuite(), + } + } + + container := rc.cbContainer + if container.cbCtl == nil && container.cbPanel == nil { + if container.cbSuite == nil { + container.cbSuite = newCBSuite() + container.cbStat = true + } + container.cbCtl = container.cbSuite.ServiceControl() + container.cbPanel = container.cbSuite.ServicePanel() + } + if !container.IsValid() { + panic("KITEX: invalid container") + } + return rc } // Container is a wrapper for Retryer. @@ -102,10 +180,32 @@ type Container struct { shouldResultRetry *ShouldResultRetry } +// Recommended usage: NewRetryContainerWithPercentageLimit() +// For more details, refer to the following comments for each field. type cbContainer struct { + // In NewRetryContainer, if cbCtrl & cbPanel are not set, Kitex will use cbSuite.ServiceControl() and + // cbSuite.ServicePanel(); If cbSuite is nil, Kitex will create one. + cbSuite *circuitbreak.CBSuite + + // It's more recommended to rely on the cbSuite than specifying cbCtl & cbPanel with corresponding options, + // since cbCtl & cbPanel should be correctly paired, and with the cbSuite, Kitex will ensure it by using the + // cbSuite.ServiceControl() and cbSuite.ServicePanel(). cbCtl *circuitbreak.Control cbPanel circuitbreaker.Panel - cbStat bool + + // If cbStat && !enablePercentageLimit, retryer will call `circuitbreak.RecordStat` after rpcCall to record + // rpc failures/timeouts, for cutting down on the retry requests when the error rate is beyond the threshold. + cbStat bool + + // If enabled, Kitex will always create a cbSuite and use its cbCtl & cbPanel, and retryer will call + // recordRetryStat before rpcCall, to precisely control the percentage of retry requests over all requests. + enablePercentageLimit bool +} + +// IsValid returns true when both cbCtl & cbPanel are not nil +// It's the user's responsibility to guarantee that cbCtl & cbPanel are correctly paired. +func (c *cbContainer) IsValid() bool { + return c.cbCtl != nil && c.cbPanel != nil } // InitWithPolicies to init Retryer with methodPolicies @@ -188,7 +288,7 @@ func (rc *Container) Init(mp map[string]Policy, rr *ShouldResultRetry) (err erro // WithRetryIfNeeded to check if there is a retryer can be used and if current call can retry. // When the retry condition is satisfied, use retryer to call -func (rc *Container) WithRetryIfNeeded(ctx context.Context, callOptRetry *Policy, rpcCall RPCCallFunc, ri rpcinfo.RPCInfo, request interface{}) (recycleRI bool, err error) { +func (rc *Container) WithRetryIfNeeded(ctx context.Context, callOptRetry *Policy, rpcCall RPCCallFunc, ri rpcinfo.RPCInfo, request interface{}) (lastRI rpcinfo.RPCInfo, recycleRI bool, err error) { var retryer Retryer if callOptRetry != nil && callOptRetry.Enable { // build retryer for call level if retry policy is set up with callopt @@ -202,20 +302,20 @@ func (rc *Container) WithRetryIfNeeded(ctx context.Context, callOptRetry *Policy // case 1(default): no retry policy if retryer == nil { if _, _, err = rpcCall(ctx, nil); err == nil { - return true, nil + return ri, true, nil } - return false, err + return ri, false, err } // case 2: setup retry policy, but not satisfy retry condition eg: circuit, retry times == 0, chain stop, ddl if msg, ok := retryer.AllowRetry(ctx); !ok { if _, _, err = rpcCall(ctx, retryer); err == nil { - return true, err + return ri, true, err } if msg != "" { retryer.AppendErrMsgIfNeeded(err, ri, msg) } - return false, err + return ri, false, err } // case 3: retry @@ -227,7 +327,7 @@ func (rc *Container) WithRetryIfNeeded(ctx context.Context, callOptRetry *Policy ctx = context.WithValue(ctx, CtxRespOp, &respOp) // do rpc call with retry policy - recycleRI, err = retryer.Do(ctx, rpcCall, ri, request) + lastRI, recycleRI, err = retryer.Do(ctx, rpcCall, ri, request) // the rpc call has finished, modify respOp to done state. atomic.StoreInt32(&respOp, OpDone) @@ -308,3 +408,11 @@ func (rc *Container) updateRetryer(rr *ShouldResultRetry) { }) } } + +// Close releases all possible resources referenced. +func (rc *Container) Close() (err error) { + if rc.cbContainer != nil && rc.cbContainer.cbSuite != nil { + err = rc.cbContainer.cbSuite.Close() + } + return +} diff --git a/pkg/retry/retryer_test.go b/pkg/retry/retryer_test.go index 21be13dbc6..c8f23444e8 100644 --- a/pkg/retry/retryer_test.go +++ b/pkg/retry/retryer_test.go @@ -24,10 +24,12 @@ import ( "time" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/discovery" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" + "github.com/cloudwego/kitex/pkg/stats" ) var ( @@ -345,7 +347,7 @@ func TestFailurePolicyCall(t *testing.T) { test.Assert(t, err == nil, err) ri := genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { return ri, nil, kerrors.ErrRPCTimeout }, ri, nil) test.Assert(t, err != nil, err) @@ -359,7 +361,7 @@ func TestFailurePolicyCall(t *testing.T) { FailurePolicy: failurePolicy, }}, nil) test.Assert(t, err == nil, err) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { return ri, nil, nil }, ri, nil) test.Assert(t, err == nil, err) @@ -382,7 +384,7 @@ func TestRetryWithOneTimePolicy(t *testing.T) { } ri := genRPCInfo() ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ok, err := NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err := NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { return ri, nil, kerrors.ErrRPCTimeout }, ri, nil) test.Assert(t, err != nil, err) @@ -392,7 +394,7 @@ func TestRetryWithOneTimePolicy(t *testing.T) { failurePolicy.StopPolicy.MaxDurationMS = 0 var callTimes int32 ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo()) - ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { if atomic.LoadInt32(&callTimes) == 0 { atomic.AddInt32(&callTimes, 1) return ri, nil, kerrors.ErrRPCTimeout @@ -406,7 +408,7 @@ func TestRetryWithOneTimePolicy(t *testing.T) { p = BuildBackupRequest(NewBackupPolicy(10)) ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo()) callTimes = 0 - ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { if atomic.LoadInt32(&callTimes) == 0 || atomic.LoadInt32(&callTimes) == 1 { atomic.AddInt32(&callTimes, 1) time.Sleep(time.Millisecond * 100) @@ -444,7 +446,7 @@ func TestSpecifiedErrorRetry(t *testing.T) { }} err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) test.Assert(t, err == nil, err) - ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError, ri, nil) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, !ok) v, ok := ri.To().Tag(remoteTagKey) @@ -465,7 +467,7 @@ func TestSpecifiedErrorRetry(t *testing.T) { err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry) test.Assert(t, err == nil, err) ri = genRPCInfo() - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError, ri, nil) + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError, ri, nil) test.Assert(t, err != nil, err) test.Assert(t, !ok) @@ -483,7 +485,7 @@ func TestSpecifiedErrorRetry(t *testing.T) { err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) test.Assert(t, err == nil, err) ri = genRPCInfo() - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError, ri, nil) + ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError, ri, nil) test.Assert(t, err != nil) test.Assert(t, !ok) _, ok = ri.To().Tag(remoteTagKey) @@ -494,7 +496,7 @@ func TestSpecifiedErrorRetry(t *testing.T) { rc = NewRetryContainer() p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry())) ri = genRPCInfo() - ok, err = rc.WithRetryIfNeeded(ctx, &p, retryWithTransError, ri, nil) + ri, ok, err = rc.WithRetryIfNeeded(ctx, &p, retryWithTransError, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, !ok) v, ok = ri.To().Tag(remoteTagKey) @@ -539,7 +541,7 @@ func TestSpecifiedRespRetry(t *testing.T) { }} err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) test.Assert(t, err == nil, err) - ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, retryResult.GetResult() == noRetryResp, retryResult) test.Assert(t, !ok) @@ -554,7 +556,7 @@ func TestSpecifiedRespRetry(t *testing.T) { test.Assert(t, err == nil, err) ri = genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, retryResult.GetResult() == retryResp, retryResp) test.Assert(t, !ok) @@ -574,7 +576,7 @@ func TestSpecifiedRespRetry(t *testing.T) { test.Assert(t, err == nil, err) ri = genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) + ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, retryResult.GetResult() == retryResp, retryResult) test.Assert(t, ok) @@ -628,7 +630,7 @@ func TestDifferentMethodConfig(t *testing.T) { // case1: test method do error retry ri := genRPCInfo() ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, rpcCall, ri, nil) + _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, rpcCall, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, !ok) lock.Lock() @@ -641,7 +643,7 @@ func TestDifferentMethodConfig(t *testing.T) { to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Method: method2}, method2).ImmutableView() ri = rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method2), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, rpcCall, ri, nil) + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, rpcCall, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, !ok) lock.Lock() @@ -702,7 +704,7 @@ func TestBackupPolicyCall(t *testing.T) { firstRI := genRPCInfo() secondRI := genRPCInfoWithRemoteTag(remoteTags) ctx = rpcinfo.NewCtxWithRPCInfo(ctx, firstRI) - ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { atomic.AddInt32(&callTimes, 1) if atomic.LoadInt32(&callTimes) == 1 { // mock timeout for the first request and get the response of the backup request. @@ -714,7 +716,7 @@ func TestBackupPolicyCall(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, atomic.LoadInt32(&callTimes) == 2) test.Assert(t, !ok) - v, ok := firstRI.To().Tag(remoteTagKey) + v, ok := ri.To().Tag(remoteTagKey) test.Assert(t, ok) test.Assert(t, v == remoteTagValue) } @@ -729,7 +731,7 @@ func TestPolicyNoRetryCall(t *testing.T) { var callTimes int32 ri := genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { atomic.AddInt32(&callTimes, 1) return ri, nil, nil }, ri, nil) @@ -741,7 +743,7 @@ func TestPolicyNoRetryCall(t *testing.T) { atomic.StoreInt32(&callTimes, 0) ri = genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { atomic.AddInt32(&callTimes, 1) if atomic.LoadInt32(&callTimes) == 1 { return ri, nil, kerrors.ErrRPCTimeout @@ -768,7 +770,7 @@ func TestPolicyNoRetryCall(t *testing.T) { atomic.StoreInt32(&callTimes, 0) ri = genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { atomic.AddInt32(&callTimes, 1) if atomic.LoadInt32(&callTimes) == 1 { return ri, nil, kerrors.ErrRPCTimeout @@ -791,7 +793,7 @@ func TestPolicyNoRetryCall(t *testing.T) { atomic.StoreInt32(&callTimes, 0) ri = genRPCInfo() ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) - ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + _, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { atomic.AddInt32(&callTimes, 1) time.Sleep(time.Millisecond * 100) return ri, nil, nil @@ -801,6 +803,76 @@ func TestPolicyNoRetryCall(t *testing.T) { test.Assert(t, ok) } +func retryCall(callTimes *int32, firstRI rpcinfo.RPCInfo, backup bool) RPCCallFunc { + // prevRI represents a value of rpcinfo.RPCInfo type. + var prevRI atomic.Value + return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { + currCallTimes := int(atomic.AddInt32(callTimes, 1)) + cRI := firstRI + if currCallTimes > 1 { + cRI = genRPCInfoWithFirstStats(firstRI) + cRI.Stats().Record(ctx, stats.RPCFinish, stats.StatusInfo, "") + remoteInfo := remoteinfo.AsRemoteInfo(cRI.To()) + remoteInfo.SetInstance(discovery.NewInstance("tcp", "10.20.30.40:8888", 10, nil)) + if prevRI.Load() == nil { + prevRI.Store(firstRI) + } + r.Prepare(ctx, prevRI.Load().(rpcinfo.RPCInfo), cRI) + prevRI.Store(cRI) + return cRI, nil, nil + } else { + if backup { + time.Sleep(20 * time.Millisecond) + return cRI, nil, nil + } else { + return cRI, nil, kerrors.ErrRPCTimeout + } + } + } +} + +func TestFailureRetryWithRPCInfo(t *testing.T) { + // failure retry + ctx := context.Background() + rc := NewRetryContainer() + + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + rpcinfo.Record(ctx, ri, stats.RPCStart, nil) + + // call with retry policy + var callTimes int32 + policy := BuildFailurePolicy(NewFailurePolicy()) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, false), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo) + test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo) + test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888") + test.Assert(t, atomic.LoadInt32(&callTimes) == 2) +} + +func TestBackupRetryWithRPCInfo(t *testing.T) { + // backup retry + ctx := context.Background() + rc := NewRetryContainer() + + ri := genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + rpcinfo.Record(ctx, ri, stats.RPCStart, nil) + + // call with retry policy + var callTimes int32 + policy := BuildBackupRequest(NewBackupPolicy(10)) + ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, true), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, !ok) + test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo) + test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo) + test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888") + test.Assert(t, atomic.LoadInt32(&callTimes) == 2) +} + type mockResult struct { result mockResp sync.RWMutex @@ -822,3 +894,68 @@ func (r *mockResult) SetResult(ret mockResp) { defer r.Unlock() r.result = ret } + +func TestNewRetryContainerWithOptions(t *testing.T) { + t.Run("no_option", func(t *testing.T) { + rc := NewRetryContainer() + test.Assertf(t, rc.cbContainer.cbSuite != nil, "cb_suite nil") + test.Assertf(t, rc.cbContainer.cbStat == true, "cb_stat not true") + }) + + t.Run("percentage_limit", func(t *testing.T) { + rc := NewRetryContainer(WithContainerEnablePercentageLimit()) + test.Assertf(t, rc.cbContainer.enablePercentageLimit == true, "percentage_limit not true") + }) + + t.Run("percentage_limit&&cbOptions", func(t *testing.T) { + cbSuite := newCBSuite() + rc := NewRetryContainer( + WithContainerEnablePercentageLimit(), + WithContainerCBSuite(cbSuite), + WithContainerCBControl(cbSuite.ServiceControl()), + WithContainerCBPanel(cbSuite.ServicePanel()), + ) + test.Assertf(t, rc.cbContainer.enablePercentageLimit == true, "percentage_limit not true") + test.Assertf(t, rc.cbContainer.cbSuite != cbSuite, "cbSuite not ignored") + test.Assertf(t, rc.cbContainer.cbCtl != cbSuite.ServiceControl(), "cbCtl not ignored") + test.Assertf(t, rc.cbContainer.cbPanel != cbSuite.ServicePanel(), "cbPanel not ignored") + }) + + t.Run("cb_stat", func(t *testing.T) { + rc := NewRetryContainer(WithContainerCBStat()) + test.Assertf(t, rc.cbContainer.cbStat == true, "cb_stat not true") + }) + + t.Run("cb_suite", func(t *testing.T) { + cbs := newCBSuite() + rc := NewRetryContainer(WithContainerCBSuite(cbs)) + test.Assert(t, rc.cbContainer.cbSuite == cbs, "cbSuite expected %p, got %p", cbs, rc.cbContainer.cbSuite) + }) + + t.Run("cb_control&cb_panel", func(t *testing.T) { + cbs := newCBSuite() + rc := NewRetryContainer( + WithContainerCBControl(cbs.ServiceControl()), + WithContainerCBPanel(cbs.ServicePanel())) + test.Assert(t, rc.cbContainer.cbCtl == cbs.ServiceControl(), "cbControl not match") + test.Assert(t, rc.cbContainer.cbPanel == cbs.ServicePanel(), "cbPanel not match") + }) +} + +func TestNewRetryContainerWithCBStat(t *testing.T) { + cbs := newCBSuite() + rc := NewRetryContainerWithCBStat(cbs.ServiceControl(), cbs.ServicePanel()) + test.Assert(t, rc.cbContainer.cbCtl == cbs.ServiceControl(), "cbControl not match") + test.Assert(t, rc.cbContainer.cbPanel == cbs.ServicePanel(), "cbPanel not match") + test.Assertf(t, rc.cbContainer.cbStat == true, "cb_stat not true") + rc.Close() +} + +func TestNewRetryContainerWithCB(t *testing.T) { + cbs := newCBSuite() + rc := NewRetryContainerWithCB(cbs.ServiceControl(), cbs.ServicePanel()) + test.Assert(t, rc.cbContainer.cbCtl == cbs.ServiceControl(), "cbControl not match") + test.Assert(t, rc.cbContainer.cbPanel == cbs.ServicePanel(), "cbPanel not match") + test.Assertf(t, rc.cbContainer.cbStat == false, "cb_stat not false") + rc.Close() +} diff --git a/pkg/retry/util.go b/pkg/retry/util.go index 8428fea60b..3d00717153 100644 --- a/pkg/retry/util.go +++ b/pkg/retry/util.go @@ -55,6 +55,8 @@ const ( OpDone ) +var tagValueFirstTry = "0" + // DDLStopFunc is the definition of ddlStop func type DDLStopFunc func(ctx context.Context, policy StopPolicy) (bool, string) @@ -83,7 +85,7 @@ func chainStop(ctx context.Context, policy StopPolicy) (bool, string) { if policy.DisableChainStop { return false, "" } - if _, exist := metainfo.GetPersistentValue(ctx, TransitKey); !exist { + if !IsRemoteRetryRequest(ctx) { return false, "" } return true, "chain stop retry" @@ -154,19 +156,27 @@ func appendErrMsg(err error, msg string) { } } -func recordRetryInfo(firstRI, lastRI rpcinfo.RPCInfo, callTimes int32, lastCosts string) { +func recordRetryInfo(ri rpcinfo.RPCInfo, callTimes int32, lastCosts string) { if callTimes > 1 { - if firstRe := remoteinfo.AsRemoteInfo(firstRI.To()); firstRe != nil { - // use the remoteInfo of the RPCCall that returns finally, in case the remoteInfo is modified during the call. - if lastRI != nil { - if lastRe := remoteinfo.AsRemoteInfo(lastRI.To()); lastRe != nil { - firstRe.CopyFrom(lastRe) - } - } - - firstRe.SetTag(rpcinfo.RetryTag, strconv.Itoa(int(callTimes)-1)) + if re := remoteinfo.AsRemoteInfo(ri.To()); re != nil { + re.SetTag(rpcinfo.RetryTag, strconv.Itoa(int(callTimes)-1)) // record last cost - firstRe.SetTag(rpcinfo.RetryLastCostTag, lastCosts) + re.SetTag(rpcinfo.RetryLastCostTag, lastCosts) } } } + +// IsLocalRetryRequest checks whether it's a retry request by checking the RetryTag set in rpcinfo +// It's supposed to be used in client middlewares +func IsLocalRetryRequest(ctx context.Context) bool { + ri := rpcinfo.GetRPCInfo(ctx) + retryCountStr := ri.To().DefaultTag(rpcinfo.RetryTag, tagValueFirstTry) + return retryCountStr != tagValueFirstTry +} + +// IsRemoteRetryRequest checks whether it's a retry request by checking the TransitKey in metainfo +// It's supposed to be used in server side (handler/middleware) +func IsRemoteRetryRequest(ctx context.Context) bool { + _, isRetry := metainfo.GetPersistentValue(ctx, TransitKey) + return isRetry +} diff --git a/pkg/retry/util_test.go b/pkg/retry/util_test.go new file mode 100644 index 0000000000..74414075a7 --- /dev/null +++ b/pkg/retry/util_test.go @@ -0,0 +1,64 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retry + +import ( + "context" + "testing" + + "github.com/bytedance/gopkg/cloud/metainfo" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +func mockRPCInfo(retryTag string) rpcinfo.RPCInfo { + var tags map[string]string + if retryTag != "" { + tags = map[string]string{ + rpcinfo.RetryTag: retryTag, + } + } + to := rpcinfo.NewEndpointInfo("service", "method", nil, tags) + return rpcinfo.NewRPCInfo(nil, to, nil, nil, nil) +} + +func mockContext(retryTag string) context.Context { + return rpcinfo.NewCtxWithRPCInfo(context.TODO(), mockRPCInfo(retryTag)) +} + +func TestIsLocalRetryRequest(t *testing.T) { + t.Run("no-retry-tag", func(t *testing.T) { + test.Assertf(t, !IsLocalRetryRequest(mockContext("")), "no-retry-tag") + }) + t.Run("retry-tag=0", func(t *testing.T) { + test.Assertf(t, !IsLocalRetryRequest(mockContext("0")), "retry-tag=0") + }) + t.Run("retry-tag=1", func(t *testing.T) { + test.Assertf(t, IsLocalRetryRequest(mockContext("1")), "retry-tag=1") + }) +} + +func TestIsRemoteRetryRequest(t *testing.T) { + t.Run("no-retry", func(t *testing.T) { + test.Assertf(t, !IsRemoteRetryRequest(context.Background()), "should be not retry") + }) + t.Run("retry", func(t *testing.T) { + ctx := metainfo.WithPersistentValue(context.Background(), TransitKey, "2") + test.Assertf(t, IsRemoteRetryRequest(ctx), "should be retry") + }) +} diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index 489b864617..c78d4a9b39 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -44,6 +44,7 @@ type RPCStats interface { Panicked() (bool, interface{}) GetEvent(event stats.Event) Event Level() stats.Level + CopyForRetry() RPCStats } // Event is the abstraction of an event happened at a specific time. diff --git a/pkg/rpcinfo/mocks_test.go b/pkg/rpcinfo/mocks_test.go index d76f12a4ce..0fe5129f94 100644 --- a/pkg/rpcinfo/mocks_test.go +++ b/pkg/rpcinfo/mocks_test.go @@ -93,3 +93,4 @@ func (m *MockRPCStats) Error() error func (m *MockRPCStats) Panicked() (yes bool, val interface{}) { return } func (m *MockRPCStats) GetEvent(event stats.Event) (e rpcinfo.Event) { return } func (m *MockRPCStats) Level() (lv stats.Level) { return } +func (m *MockRPCStats) CopyForRetry() rpcinfo.RPCStats { return nil } diff --git a/pkg/rpcinfo/remoteinfo/remoteInfo.go b/pkg/rpcinfo/remoteinfo/remoteInfo.go index c479f41e7b..3b067f4abe 100644 --- a/pkg/rpcinfo/remoteinfo/remoteInfo.go +++ b/pkg/rpcinfo/remoteinfo/remoteInfo.go @@ -43,13 +43,12 @@ type RemoteInfo interface { // SetRemoteAddr tries to set the network address of the discovery.Instance hold by RemoteInfo. // The result indicates whether the modification is successful. SetRemoteAddr(addr net.Addr) (ok bool) - CopyFrom(from RemoteInfo) ImmutableView() rpcinfo.EndpointInfo } -// RemoteAddrSetter is used to set remote addr. -type RemoteAddrSetter interface { - SetRemoteAddr(addr net.Addr) (ok bool) +// RefreshableInstance declares an interface which can return an instance containing the new address. +type RefreshableInstance interface { + RefreshInstanceWithAddr(addr net.Addr) (newInstance discovery.Instance) } var ( @@ -130,12 +129,13 @@ func (ri *remoteInfo) DefaultTag(key, def string) string { return def } -// SetRemoteAddr implements the RemoteAddrSetter interface. +// SetRemoteAddr implements the RemoteInfo interface. func (ri *remoteInfo) SetRemoteAddr(addr net.Addr) bool { - if ins, ok := ri.instance.(RemoteAddrSetter); ok { - ri.Lock() - defer ri.Unlock() - return ins.SetRemoteAddr(addr) + ri.Lock() + defer ri.Unlock() + if ins, ok := ri.instance.(RefreshableInstance); ok { + ri.instance = ins.RefreshInstanceWithAddr(addr) + return true } return false } @@ -175,47 +175,11 @@ func (ri *remoteInfo) ForceSetTag(key, value string) { ri.tags[key] = value } -// CopyFrom copies the input RemoteInfo. -// the `from` param may be modified, so must do deep copy to prevent race. -func (ri *remoteInfo) CopyFrom(from RemoteInfo) { - if from == nil || ri == from { - return - } - ri.Lock() - f := from.(*remoteInfo) - ri.serviceName = f.serviceName - ri.instance = f.instance - ri.method = f.method - ri.tags = f.copyTags() - ri.tagLocks = f.copyTagsLocks() - ri.Unlock() -} - // ImmutableView implements rpcinfo.MutableEndpointInfo. func (ri *remoteInfo) ImmutableView() rpcinfo.EndpointInfo { return ri } -func (ri *remoteInfo) copyTags() map[string]string { - ri.Lock() - defer ri.Unlock() - newTags := make(map[string]string, len(ri.tags)) - for k, v := range ri.tags { - newTags[k] = v - } - return newTags -} - -func (ri *remoteInfo) copyTagsLocks() map[string]struct{} { - ri.Lock() - defer ri.Unlock() - newTagLocks := make(map[string]struct{}, len(ri.tagLocks)) - for k, v := range ri.tagLocks { - newTagLocks[k] = v - } - return newTagLocks -} - func (ri *remoteInfo) zero() { ri.Lock() defer ri.Unlock() diff --git a/pkg/rpcinfo/remoteinfo/remoteInfo_test.go b/pkg/rpcinfo/remoteinfo/remoteInfo_test.go index 6d62835d45..0277fb85f0 100644 --- a/pkg/rpcinfo/remoteinfo/remoteInfo_test.go +++ b/pkg/rpcinfo/remoteinfo/remoteInfo_test.go @@ -87,7 +87,7 @@ func TestAsRemoteInfo(t *testing.T) { na := ri.Address() test.Assert(t, na.Network() == "n" && na.String() == "a") - _, ok := ins.(remoteinfo.RemoteAddrSetter) + _, ok := ins.(remoteinfo.RefreshableInstance) na = utils.NewNetAddr("nnn", "aaa") test.Assert(t, ri2.SetRemoteAddr(na) == ok) if ok { @@ -186,39 +186,6 @@ func TestGetTag(t *testing.T) { test.Assert(t, valIDC == myIDC, valIDC) } -func TestCopyFromRace(t *testing.T) { - key1, key2, val1, val2 := "key1", "key2", "val1", "val2" - ri1 := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{ServiceName: "service", Tags: map[string]string{key1: val1}}, "method1") - v, ok := ri1.Tag(key1) - test.Assert(t, ok) - test.Assert(t, v == val1) - - ri2 := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{ServiceName: "service", Tags: map[string]string{key2: val2}}, "method1") - // do copyFrom ri2 - ri1.CopyFrom(ri2) - _, ok = ri1.Tag(key1) - test.Assert(t, !ok) - v, ok = ri1.Tag(key2) - test.Assert(t, ok) - test.Assert(t, v == val2) - - // test the data race problem caused by tag modification - var wg sync.WaitGroup - wg.Add(2) - go func() { - ri2.ForceSetTag("key11", "val11") - wg.Done() - }() - go func() { - ri1.Tag("key2") - wg.Done() - }() - wg.Wait() - - // test if dead lock - ri1.CopyFrom(ri1) -} - func TestRecycleRace(t *testing.T) { ri := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{ServiceName: "service", Tags: map[string]string{"key1": "val1"}}, "method1") diff --git a/pkg/rpcinfo/rpcstats.go b/pkg/rpcinfo/rpcstats.go index 6396511d4c..363df20f62 100644 --- a/pkg/rpcinfo/rpcstats.go +++ b/pkg/rpcinfo/rpcstats.go @@ -170,6 +170,24 @@ func (r *rpcStats) Level() stats.Level { return r.level } +// CopyForRetry implements the RPCStats interface, it copies a RPCStats from the origin one +// to pass through info of the first request to retrying requests. +func (r *rpcStats) CopyForRetry() RPCStats { + // Copied rpc stats is for request retrying and cannot be reused, so no need to get from pool. + nr := newRPCStats().(*rpcStats) + r.Lock() + startIdx := int(stats.RPCStart.Index()) + userIdx := stats.PredefinedEventNum() + for i := 0; i < len(nr.eventMap); i++ { + // Ignore none RPCStart events to avoid incorrect tracing. + if i == startIdx || i >= userIdx { + nr.eventMap[i] = r.eventMap[i] + } + } + r.Unlock() + return nr +} + // SetSendSize sets send size. func (r *rpcStats) SetSendSize(size uint64) { atomic.StoreUint64(&r.sendSize, size) diff --git a/pkg/rpcinfo/rpcstats_test.go b/pkg/rpcinfo/rpcstats_test.go index 4b2567427e..febd0fb82f 100644 --- a/pkg/rpcinfo/rpcstats_test.go +++ b/pkg/rpcinfo/rpcstats_test.go @@ -63,3 +63,19 @@ func TestRPCStats(t *testing.T) { ok, err = st.Panicked() test.Assert(t, !ok && err == nil) } + +func BenchmarkCopyForRetry(b *testing.B) { + b.Run("BenchmarkNewRPCStats", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = rpcinfo.NewRPCStats() + } + }) + + s := rpcinfo.NewRPCStats() + b.Run("BenchmarkCopyForRetry", func(b *testing.B) { + s.Record(context.Background(), stats.RPCStart, stats.StatusInfo, "") + for i := 0; i < b.N; i++ { + _ = s.CopyForRetry() + } + }) +} diff --git a/pkg/stats/event.go b/pkg/stats/event.go index 0fad8d4d40..74afa0e035 100644 --- a/pkg/stats/event.go +++ b/pkg/stats/event.go @@ -133,6 +133,11 @@ func MaxEventNum() int { return maxEventNum } +// PredefinedEventNum returns the number of predefined events of kitex. +func PredefinedEventNum() int { + return int(predefinedEventNum) +} + func newEvent(idx EventIndex, level Level) Event { return event{ idx: idx, diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go index 7c4282372d..ecb276fe46 100644 --- a/pkg/utils/strings.go +++ b/pkg/utils/strings.go @@ -16,10 +16,89 @@ package utils -import "unsafe" +import ( + "strings" + "sync" + "unsafe" +) func StringDeepCopy(s string) string { buf := []byte(s) ns := (*string)(unsafe.Pointer(&buf)) return *ns } + +// StringBuilder is a concurrently safe wrapper for strings.Builder. +type StringBuilder struct { + sync.Mutex + sb strings.Builder +} + +// WithLocked encapsulates a concurrent-safe interface for batch operations on strings.Builder. +// Please note that you should avoid calling member functions of StringBuilder within the input +// function, as it may lead to program deadlock. +func (b *StringBuilder) WithLocked(f func(sb *strings.Builder) error) error { + b.Lock() + defer b.Unlock() + return f(&b.sb) +} + +// RawStringBuilder returns the inner strings.Builder of StringBuilder. It allows users to perform +// lock-free operations in scenarios without concurrency issues, thereby improving performance. +func (b *StringBuilder) RawStringBuilder() *strings.Builder { + return &b.sb +} + +func (b *StringBuilder) String() string { + b.Lock() + defer b.Unlock() + return b.sb.String() +} + +func (b *StringBuilder) Len() int { + b.Lock() + defer b.Unlock() + return b.sb.Len() +} + +func (b *StringBuilder) Cap() int { + b.Lock() + defer b.Unlock() + return b.sb.Cap() +} + +func (b *StringBuilder) Reset() { + b.Lock() + defer b.Unlock() + b.sb.Reset() +} + +func (b *StringBuilder) Grow(n int) { + b.Lock() + defer b.Unlock() + b.sb.Grow(n) +} + +func (b *StringBuilder) Write(p []byte) (int, error) { + b.Lock() + defer b.Unlock() + return b.sb.Write(p) +} + +func (b *StringBuilder) WriteByte(c byte) error { + b.Lock() + defer b.Unlock() + return b.sb.WriteByte(c) +} + +func (b *StringBuilder) WriteRune(r rune) (int, error) { + b.Lock() + defer b.Unlock() + return b.sb.WriteRune(r) +} + +func (b *StringBuilder) WriteString(s string) (int, error) { + b.Lock() + defer b.Unlock() + return b.sb.WriteString(s) +} diff --git a/pkg/utils/strings_test.go b/pkg/utils/strings_test.go new file mode 100644 index 0000000000..f253d33529 --- /dev/null +++ b/pkg/utils/strings_test.go @@ -0,0 +1,60 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "strings" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestStringBuilder(t *testing.T) { + sb := &StringBuilder{} + sb.Grow(4) + test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Len() == 0) + sb.WriteString("1") + sb.WriteByte('2') + sb.WriteRune(rune('3')) + sb.Write([]byte("4")) + test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Len() == 4) + test.Assert(t, sb.String() == "1234") + sb.Reset() + test.Assert(t, sb.String() == "") + test.Assert(t, sb.Cap() == 0) + test.Assert(t, sb.Len() == 0) + + sb.WithLocked(func(sb *strings.Builder) error { + sb.Grow(4) + test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Len() == 0) + sb.WriteString("1") + sb.WriteByte('2') + sb.WriteRune(rune('3')) + sb.Write([]byte("4")) + test.Assert(t, sb.Cap() == 4) + test.Assert(t, sb.Len() == 4) + test.Assert(t, sb.String() == "1234") + sb.Reset() + test.Assert(t, sb.String() == "") + test.Assert(t, sb.Cap() == 0) + test.Assert(t, sb.Len() == 0) + return nil + }) +} diff --git a/server/server.go b/server/server.go index f71da8d793..65c27d71a2 100644 --- a/server/server.go +++ b/server/server.go @@ -194,9 +194,6 @@ func (s *server) GetServiceInfo() *serviceinfo.ServiceInfo { // Run runs the server. func (s *server) Run() (err error) { - if s.svcInfo == nil { - return errors.New("no service, use RegisterService to set one") - } if err = s.check(); err != nil { return err } @@ -353,10 +350,9 @@ func (s *server) addBoundHandlers(opt *remote.ServerOption) { } } - // for server limiter, the handler should be added as first one limitHdlr := s.buildLimiterWithOpt() if limitHdlr != nil { - doAddBoundHandlerToHead(limitHdlr, opt) + doAddBoundHandler(limitHdlr, opt) } } diff --git a/server/server_test.go b/server/server_test.go index c3af42c750..554465b6cb 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -424,8 +424,8 @@ func TestServerBoundHandler(t *testing.T) { WithMetaHandler(noopMetahandler{}), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, false), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), + bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, false), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), @@ -436,8 +436,8 @@ func TestServerBoundHandler(t *testing.T) { WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), &limiter.DummyRateLimiter{}, nil, false), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), &limiter.DummyRateLimiter{}, nil, false), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), @@ -448,8 +448,8 @@ func TestServerBoundHandler(t *testing.T) { WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(&limiter.DummyConcurrencyLimiter{}, mockslimiter.NewMockRateLimiter(ctrl), nil, true), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(&limiter.DummyConcurrencyLimiter{}, mockslimiter.NewMockRateLimiter(ctrl), nil, true), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), @@ -461,8 +461,8 @@ func TestServerBoundHandler(t *testing.T) { WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), @@ -477,8 +477,8 @@ func TestServerBoundHandler(t *testing.T) { WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), limiter.NewQPSLimiter(interval, 10000), nil, false), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), limiter.NewQPSLimiter(interval, 10000), nil, false), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), @@ -494,8 +494,8 @@ func TestServerBoundHandler(t *testing.T) { WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), @@ -512,8 +512,8 @@ func TestServerBoundHandler(t *testing.T) { WithMuxTransport(), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), @@ -528,8 +528,8 @@ func TestServerBoundHandler(t *testing.T) { WithMuxTransport(), }, wantInbounds: []remote.InboundHandler{ - bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, true), bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), + bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, true), }, wantOutbounds: []remote.OutboundHandler{ bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 12bf5f6391..c890c71139 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -144,6 +144,7 @@ var funcs = map[string]interface{}{ "SnakeString": util.SnakeString, "HasFeature": HasFeature, "FilterImports": FilterImports, + "backquoted": BackQuoted, } var templateNames = []string{ @@ -298,3 +299,7 @@ func FilterImports(Imports map[string]map[string]bool, ms []*MethodInfo) map[str } return res } + +func BackQuoted(s string) string { + return "`" + s + "`" +} diff --git a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go index 2580af75da..7810566fd9 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go +++ b/tool/internal_pkg/pluginmode/thriftgo/file_tpl.go @@ -33,6 +33,13 @@ import ( {{if GenerateFastAPIs}} "{{ImportPathTo "pkg/protocol/bthrift"}}" {{- end}} + {{if GenerateArgsResultTypes}} + {{if Features.KeepUnknownFields}} + {{- if ne (len .Scope.Services) 0}} + unknown "github.com/cloudwego/thriftgo/generator/golang/extension/unknown" + {{- end}} + {{- end}} + {{- end}} {{- range $path, $alias := .Imports}} {{$alias }}"{{$path}}" {{- end}} diff --git a/tool/internal_pkg/tpl/service.go b/tool/internal_pkg/tpl/service.go index c490ae3682..c4d5d54ac1 100644 --- a/tool/internal_pkg/tpl/service.go +++ b/tool/internal_pkg/tpl/service.go @@ -56,7 +56,7 @@ func NewServiceInfo() *kitex.ServiceInfo { } extra := map[string]interface{}{ "PackageName": "{{.PkgInfo.PkgName}}", - "ServiceFilePath": "{{.ServiceFilePath}}", + "ServiceFilePath": {{ backquoted .ServiceFilePath }}, } {{- if gt (len .CombineServices) 0}} extra["combine_service"] = true diff --git a/transport/keys.go b/transport/keys.go index d9e426eeb5..c4c4b8afd5 100644 --- a/transport/keys.go +++ b/transport/keys.go @@ -21,7 +21,6 @@ package transport type Protocol int // Predefined transport protocols. -// Framed is suggested. const ( PurePayload Protocol = 0 diff --git a/version.go b/version.go index 3e19decae3..ab7b95bd4c 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package kitex // Name and Version info of this framework, used for statistics and debug const ( Name = "Kitex" - Version = "v0.7.1" + Version = "v0.7.2" )