diff --git a/pkg/proxy/client_conn.go b/pkg/proxy/client_conn.go index da134e30d92b9..974be97e46411 100644 --- a/pkg/proxy/client_conn.go +++ b/pkg/proxy/client_conn.go @@ -107,7 +107,8 @@ type ClientConn interface { } type migration struct { - setVarStmts []string + setVarStmtMap map[string]struct{} + setVarStmts []string } // clientConn is the connection between proxy and client. @@ -231,6 +232,7 @@ func newClientConn( } c.tlsConfig = tlsConfig } + c.migration.setVarStmtMap = make(map[string]struct{}) return c, nil } @@ -408,6 +410,17 @@ func (c *clientConn) handleKill(e *killEvent, resp chan<- []byte) error { // handleSetVar handles the set variable event. func (c *clientConn) handleSetVar(e *setVarEvent) error { defer e.notify() + _, ok := c.migration.setVarStmtMap[e.stmt] + if ok { + for i := 0; i < len(c.migration.setVarStmts); i++ { + if c.migration.setVarStmts[i] == e.stmt { + c.migration.setVarStmts = append(c.migration.setVarStmts[:i], c.migration.setVarStmts[i+1:]...) + i-- + } + } + } else { + c.migration.setVarStmtMap[e.stmt] = struct{}{} + } c.migration.setVarStmts = append(c.migration.setVarStmts, e.stmt) return nil } diff --git a/pkg/proxy/client_conn_test.go b/pkg/proxy/client_conn_test.go index 556d38f160fa2..1948d6a605728 100644 --- a/pkg/proxy/client_conn_test.go +++ b/pkg/proxy/client_conn_test.go @@ -568,3 +568,41 @@ func TestClientConn_SendErrToClient(t *testing.T) { cc.SendErrToClient(moerr.NewInternalErrorNoCtx("msg1")) wg.Wait() } + +func TestHandleSetVar(t *testing.T) { + defer leaktest.AfterTest(t)() + var cc clientConn + cc.migration.setVarStmtMap = make(map[string]struct{}) + e0 := &setVarEvent{ + baseEvent: baseEvent{waitC: make(chan struct{}, 5)}, + stmt: "set autocommit=0", + } + require.NoError(t, cc.handleSetVar(e0)) + require.Equal(t, 1, len(cc.migration.setVarStmtMap)) + require.Equal(t, 1, len(cc.migration.setVarStmts)) + require.Equal(t, e0.stmt, cc.migration.setVarStmts[len(cc.migration.setVarStmts)-1]) + + require.NoError(t, cc.handleSetVar(e0)) + require.Equal(t, 1, len(cc.migration.setVarStmtMap)) + require.Equal(t, 1, len(cc.migration.setVarStmts)) + require.Equal(t, e0.stmt, cc.migration.setVarStmts[len(cc.migration.setVarStmts)-1]) + + e1 := &setVarEvent{ + baseEvent: baseEvent{waitC: make(chan struct{}, 5)}, + stmt: "set autocommit=1", + } + require.NoError(t, cc.handleSetVar(e1)) + require.Equal(t, 2, len(cc.migration.setVarStmtMap)) + require.Equal(t, 2, len(cc.migration.setVarStmts)) + require.Equal(t, e1.stmt, cc.migration.setVarStmts[len(cc.migration.setVarStmts)-1]) + + require.NoError(t, cc.handleSetVar(e0)) + require.Equal(t, 2, len(cc.migration.setVarStmtMap)) + require.Equal(t, 2, len(cc.migration.setVarStmts)) + require.Equal(t, e0.stmt, cc.migration.setVarStmts[len(cc.migration.setVarStmts)-1]) + + require.NoError(t, cc.handleSetVar(e1)) + require.Equal(t, 2, len(cc.migration.setVarStmtMap)) + require.Equal(t, 2, len(cc.migration.setVarStmts)) + require.Equal(t, e1.stmt, cc.migration.setVarStmts[len(cc.migration.setVarStmts)-1]) +}