diff --git a/pkg/proxy/plugin.go b/pkg/proxy/plugin.go index 062ed54fdf124..6061dfc628f1d 100644 --- a/pkg/proxy/plugin.go +++ b/pkg/proxy/plugin.go @@ -42,7 +42,7 @@ func newPluginRouter(r Router, p Plugin) *pluginRouter { // Route implements Router.Route. func (r *pluginRouter) Route( - ctx context.Context, ci clientInfo, filter func(uuid string) bool, + ctx context.Context, ci clientInfo, filter func(prevAddr string) bool, ) (*CNServer, error) { re, err := r.plugin.RecommendCN(ctx, ci) if err != nil { @@ -59,6 +59,10 @@ func (r *pluginRouter) Route( if re.CN == nil { return nil, moerr.NewInternalErrorNoCtx("no CN server selected") } + // selected CN should be filtered out, fall back to the delegated router + if filter != nil && filter(re.CN.SQLAddress) { + return r.Router.Route(ctx, ci, filter) + } hash, err := ci.labelInfo.getHash() if err != nil { return nil, err diff --git a/pkg/proxy/plugin_test.go b/pkg/proxy/plugin_test.go index de09914757a7c..6d213303e6121 100644 --- a/pkg/proxy/plugin_test.go +++ b/pkg/proxy/plugin_test.go @@ -80,6 +80,7 @@ func TestPluginRouter_Route(t *testing.T) { expectErr bool expectUUID string expectRefresh int + filter func(prevAddr string) bool }{{ name: "recommend select CN", mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) { @@ -149,6 +150,26 @@ func TestPluginRouter_Route(t *testing.T) { }, expectUUID: "cn0", expectRefresh: 1, + }, { + name: "filter out current CN", + mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) { + return &plugin.Recommendation{ + Action: plugin.Select, + CN: &metadata.CNService{ + ServiceID: "cn0", + SQLAddress: "8.8.8.8:6001", + }, + Updated: true, + }, nil + }, + mockRouteFn: func(ctx context.Context, ci clientInfo) (*CNServer, error) { + return &CNServer{uuid: "cn1"}, nil + }, + expectRefresh: 1, + expectUUID: "cn1", + filter: func(addr string) bool { + return addr == "8.8.8.8:6001" + }, }} for _, tt := range tests { @@ -156,13 +177,13 @@ func TestPluginRouter_Route(t *testing.T) { p := &mockPlugin{mockRecommendCNFn: tt.mockRecommendCNFn} r := &mockRouter{mockRouteFn: tt.mockRouteFn} pr := newPluginRouter(r, p) - cn, err := pr.Route(context.TODO(), clientInfo{}, nil) + cn, err := pr.Route(context.TODO(), clientInfo{}, tt.filter) if tt.expectErr { require.Error(t, err) require.Nil(t, cn) } else { require.NotNil(t, cn) - require.Equal(t, cn.uuid, tt.expectUUID) + require.Equal(t, tt.expectUUID, cn.uuid) } require.Equal(t, r.refreshCount, tt.expectRefresh) })