diff --git a/dm/dm/worker/server.go b/dm/dm/worker/server.go index 7ff41340595..0d416c2b83f 100644 --- a/dm/dm/worker/server.go +++ b/dm/dm/worker/server.go @@ -790,11 +790,15 @@ func (s *Server) OperateSchema(ctx context.Context, req *pb.OperateWorkerSchemaR log.L().Info("", zap.String("request", "OperateSchema"), zap.Stringer("payload", req)) w := s.getSourceWorker(true) if w == nil { log.L().Warn("fail to call OperateSchema, because no mysql source is being handled in the worker") return makeCommonWorkerResponse(terror.ErrWorkerNoStart.Generate()), nil - } else if req.Source != w.cfg.SourceID { - log.L().Error("fail to call OperateSchema, because source mismatch", zap.String("request", req.Source), zap.String("current", w.cfg.SourceID)) + } + w.RLock() + sourceID := w.cfg.SourceID + w.RUnlock() + if req.Source != sourceID { + log.L().Error("fail to call OperateSchema, because source mismatch", zap.String("request", req.Source), zap.String("current", sourceID)) return makeCommonWorkerResponse(terror.ErrWorkerSourceNotMatch.Generate()), nil } diff --git a/dm/dm/worker/source_worker.go b/dm/dm/worker/source_worker.go index 9d43eff9406..2dd7d245c0f 100644 --- a/dm/dm/worker/source_worker.go +++ b/dm/dm/worker/source_worker.go @@ -43,9 +43,10 @@ import ( // SourceWorker manages a source(upstream) which is mainly related to subtasks and relay. type SourceWorker struct { - // ensure no other operation can be done when closing (we can use `WatGroup`/`Context` to archive this) + // ensure no other operation can be done when closing (we can use `WaitGroup`/`Context` to achieve this) // TODO: check what does it guards. Now it's used to guard relayHolder and relayPurger (maybe subTaskHolder?) since // query-status maybe access them when closing/disable functionalities + // This lock is also used to guard the source worker's source config and the subtask holder (subtask configs). sync.RWMutex wg sync.WaitGroup @@ -249,9 +250,12 @@ func (w *SourceWorker) Stop(graceful bool) { // updateSourceStatus updates w.sourceStatus.
func (w *SourceWorker) updateSourceStatus(ctx context.Context) error { w.sourceDBMu.Lock() + w.RLock() + cfg := w.cfg + w.RUnlock() if w.sourceDB == nil { var err error - w.sourceDB, err = conn.DefaultDBProvider.Apply(&w.cfg.DecryptPassword().From) + w.sourceDB, err = conn.DefaultDBProvider.Apply(&cfg.DecryptPassword().From) if err != nil { w.sourceDBMu.Unlock() return err @@ -262,7 +266,7 @@ func (w *SourceWorker) updateSourceStatus(ctx context.Context) error { var status binlog.SourceStatus ctx, cancel := context.WithTimeout(ctx, utils.DefaultDBTimeout) defer cancel() - pos, gtidSet, err := utils.GetPosAndGs(ctx, w.sourceDB.DB, w.cfg.Flavor) + pos, gtidSet, err := utils.GetPosAndGs(ctx, w.sourceDB.DB, cfg.Flavor) if err != nil { return err } @@ -1160,7 +1164,10 @@ func (w *SourceWorker) observeValidatorStage(ctx context.Context, lastUsedRev in case <-ctx.Done(): return nil case <-time.After(500 * time.Millisecond): - startRevision, err = w.getCurrentValidatorRevision(w.cfg.SourceID) + w.RLock() + sourceID := w.cfg.SourceID + w.RUnlock() + startRevision, err = w.getCurrentValidatorRevision(sourceID) if err != nil { log.L().Error("reset validator stage failed, will retry later", zap.Error(err), zap.Int("retryNum", retryNum)) } diff --git a/dm/dm/worker/status.go b/dm/dm/worker/status.go index 1fa1237af37..5e749c98f54 100644 --- a/dm/dm/worker/status.go +++ b/dm/dm/worker/status.go @@ -131,7 +131,10 @@ func (w *SourceWorker) GetValidateStatus(stName string, filterStatus pb.Stage) [ if st == nil { return res } - sourceIP := w.cfg.From.Host + ":" + strconv.Itoa(w.cfg.From.Port) + w.RLock() + cfg := w.cfg + w.RUnlock() + sourceIP := cfg.From.Host + ":" + strconv.Itoa(cfg.From.Port) tblStats := st.GetValidatorStatus() for _, stat := range tblStats { if filterStatus == pb.Stage_InvalidStage || stat.ValidationStatus == filterStatus.String() { diff --git a/dm/simulator/internal/config/config.go b/dm/simulator/internal/config/config.go new file mode 100644 index 00000000000..9f8508d21f2 --- /dev/null +++ b/dm/simulator/internal/config/config.go @@ -0,0 +1,31 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is the configuration definitions used by the simulator. +package config + +// TableConfig is the sub config for describing a simulating table in the data source. +type TableConfig struct { + TableID string `yaml:"id"` + DatabaseName string `yaml:"db"` + TableName string `yaml:"table"` + Columns []*ColumnDefinition `yaml:"columns"` + UniqueKeyColumnNames []string `yaml:"unique_keys"` +} + +// ColumnDefinition is the sub config for describing a column in a simulating table. +type ColumnDefinition struct { + ColumnName string `yaml:"name"` + DataType string `yaml:"type"` + DataLen int `yaml:"length"` +} diff --git a/dm/simulator/internal/mcp/errors.go b/dm/simulator/internal/mcp/errors.go new file mode 100644 index 00000000000..9294ea0dddf --- /dev/null +++ b/dm/simulator/internal/mcp/errors.go @@ -0,0 +1,28 @@ +// Copyright 2022 PingCAP, Inc. 
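The worker-side changes above all follow the same pattern: take a short read lock, copy what is needed out of `w.cfg`, release the lock, and only then do the slow work. A minimal stand-alone sketch of that pattern (the `Worker`/`Config` types below are illustrative stand-ins, not the real DM structs):

```go
package main

import (
	"fmt"
	"sync"
)

// Config stands in for the source config guarded by the worker's RWMutex.
type Config struct {
	SourceID string
	Flavor   string
}

// Worker mirrors the locking shape of SourceWorker: the embedded RWMutex
// guards cfg, so readers snapshot the fields they need and unlock quickly.
type Worker struct {
	sync.RWMutex
	cfg *Config
}

// sourceID copies the field under a read lock instead of holding the lock
// for the whole caller, which is what the hunks above do for OperateSchema,
// updateSourceStatus, observeValidatorStage and GetValidateStatus.
func (w *Worker) sourceID() string {
	w.RLock()
	defer w.RUnlock()
	return w.cfg.SourceID
}

func main() {
	w := &Worker{cfg: &Config{SourceID: "mysql-01", Flavor: "mysql"}}
	fmt.Println(w.sourceID())
}
```

The snapshot keeps the critical section small; the trade-off is that the caller may act on a slightly stale value, which is acceptable for logging and source-ID validation.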
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "github.com/pingcap/errors" +) + +var ( + // ErrMCPCapacityFull means the capacity of the modification candidate pool (MCP) is full. + ErrMCPCapacityFull = errors.New("the capacity of the modification candidate pool is full") + // ErrInvalidRowID means the row ID of the unique key is invalid. + // For example, when the row ID is greater than the current MCP size, this error will be triggered. + ErrInvalidRowID = errors.New("invalid row ID") + // ErrDeleteUKNotFound means the unique key to be deleted is not found in the MCP. + ErrDeleteUKNotFound = errors.New("delete UK not found") +) diff --git a/dm/simulator/internal/mcp/mcp.go b/dm/simulator/internal/mcp/mcp.go new file mode 100644 index 00000000000..ade09cafb1b --- /dev/null +++ b/dm/simulator/internal/mcp/mcp.go @@ -0,0 +1,117 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mcp defines the Modification Candidate Pool (MCP). +package mcp + +import ( + "math/rand" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tiflow/dm/pkg/log" + "go.uber.org/zap" +) + +// ModificationCandidatePool is the core container storing all the current unique keys for a table. +type ModificationCandidatePool struct { + sync.RWMutex + keyPool []*UniqueKey + theRand *rand.Rand + randLock sync.Mutex +} + +// NewModificationCandidatePool creates a new MCP. +func NewModificationCandidatePool(capacity int) *ModificationCandidatePool { + theKeyPool := make([]*UniqueKey, 0, capacity) + theRand := rand.New(rand.NewSource(time.Now().Unix())) + return &ModificationCandidatePool{ + keyPool: theKeyPool, + theRand: theRand, + } } + +// NextUK randomly picks a unique key in the MCP. +func (mcp *ModificationCandidatePool) NextUK() *UniqueKey { + mcp.RLock() + defer mcp.RUnlock() + if len(mcp.keyPool) == 0 { + return nil + } + mcp.randLock.Lock() + idx := mcp.theRand.Intn(len(mcp.keyPool)) + mcp.randLock.Unlock() + return mcp.keyPool[idx] // pass by reference +} + +// Len gets the current length of the MCP. +func (mcp *ModificationCandidatePool) Len() int { + mcp.RLock() + defer mcp.RUnlock() + return len(mcp.keyPool) +} + +// AddUK adds the unique key into the MCP. +// It has a side effect: the input UK's row ID will be changed.
+func (mcp *ModificationCandidatePool) AddUK(uk *UniqueKey) error { + mcp.Lock() + defer mcp.Unlock() + if len(mcp.keyPool)+1 > cap(mcp.keyPool) { + return errors.Trace(ErrMCPCapacityFull) + } + currentLen := len(mcp.keyPool) + uk.SetRowID(currentLen) + mcp.keyPool = append(mcp.keyPool, uk) + return nil +} + +// DeleteUK deletes the unique key from the MCP. +// It will get the row ID of the UK and delete the UK on that position. +// If the actual value is different from the input UK, the element will still be deleted. +// It has side effect: after the deletion, the input UK's row ID will be set to -1, +// to prevent deleting a dangling UK multiple times. +func (mcp *ModificationCandidatePool) DeleteUK(uk *UniqueKey) error { + var ( + deletedUK *UniqueKey + deleteIdx int + ) + if uk == nil { + return nil + } + mcp.Lock() + defer mcp.Unlock() + deleteIdx = uk.GetRowID() + if deleteIdx < 0 { + return errors.Trace(ErrDeleteUKNotFound) + } + if deleteIdx >= len(mcp.keyPool) { + log.L().Error("the delete UK row ID > MCP's total length", zap.Int("delete row ID", deleteIdx), zap.Int("current key pool length", len(mcp.keyPool))) + return errors.Trace(ErrInvalidRowID) + } + deletedUK = mcp.keyPool[deleteIdx] + curLen := len(mcp.keyPool) + lastUK := mcp.keyPool[curLen-1] + lastUK.SetRowID(deleteIdx) + mcp.keyPool[deleteIdx] = lastUK + mcp.keyPool = mcp.keyPool[:curLen-1] + deletedUK.SetRowID(-1) + return nil +} + +// Reset cleans up all the items in the MCP. +func (mcp *ModificationCandidatePool) Reset() { + mcp.Lock() + defer mcp.Unlock() + mcp.keyPool = mcp.keyPool[:0] +} diff --git a/dm/simulator/internal/mcp/mcp_test.go b/dm/simulator/internal/mcp/mcp_test.go new file mode 100644 index 00000000000..8062479583a --- /dev/null +++ b/dm/simulator/internal/mcp/mcp_test.go @@ -0,0 +1,239 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
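A quick usage sketch of the MCP API introduced above, written as if it lived inside the `mcp` package (the package is `internal`, so an external import would not compile). `DeleteUK` relies on the swap-with-last trick shown in the implementation, which is why deletion is O(1) and why the row IDs of surviving keys can change:

```go
package mcp

import "fmt"

// exampleMCPUsage walks through the typical lifecycle of the pool:
// fill it with unique keys, pick a random candidate, then delete it.
func exampleMCPUsage() error {
	pool := NewModificationCandidatePool(1024)

	// AddUK assigns the row ID (the index inside keyPool) as a side effect.
	for i := 0; i < 10; i++ {
		uk := NewUniqueKey(-1, map[string]interface{}{"id": i})
		if err := pool.AddUK(uk); err != nil {
			return err // ErrMCPCapacityFull once cap(keyPool) is reached
		}
	}

	// NextUK returns a random candidate for an UPDATE/DELETE to target.
	victim := pool.NextUK()
	fmt.Println("picked:", victim)

	// DeleteUK swaps the last element into the victim's slot and marks the
	// victim's row ID as -1, so the pool stays compact without shifting.
	if err := pool.DeleteUK(victim); err != nil {
		return err
	}
	fmt.Println("pool size after delete:", pool.Len())
	return nil
}
```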
+ +package mcp + +import ( + "math/rand" + "sync" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/pingcap/tiflow/dm/pkg/log" +) + +type testMCPSuite struct { + suite.Suite + mcp *ModificationCandidatePool +} + +func (s *testMCPSuite) SetupSuite() { + s.Require().Nil(log.InitLogger(&log.Config{})) +} + +func (s *testMCPSuite) SetupTest() { + mcp := NewModificationCandidatePool(8192) + for i := 0; i < 4096; i++ { + mcp.keyPool = append(mcp.keyPool, &UniqueKey{ + rowID: i, + value: map[string]interface{}{ + "id": i, + }, + }) + } + s.mcp = mcp +} + +func (s *testMCPSuite) TestNextUK() { + allHitRowIDs := map[int]int{} + repeatCnt := 20 + for i := 0; i < repeatCnt; i++ { + theUK := s.mcp.NextUK() + s.Require().NotNil(theUK, "the picked UK should not be nil") + theRowID := theUK.GetRowID() + if _, ok := allHitRowIDs[theRowID]; !ok { + allHitRowIDs[theRowID] = 0 + } + allHitRowIDs[theRowID]++ + s.T().Logf("next UK: %v", theUK) + } + totalOccurredTimes := 0 + totalOccurredRowIDs := 0 + for _, times := range allHitRowIDs { + totalOccurredRowIDs++ + totalOccurredTimes += times + } + s.Greater(totalOccurredRowIDs, 1, "there should be more than 1 occurred row IDs") + s.Equal(repeatCnt, totalOccurredTimes, "total occurred UKs should equal the iteration count") +} + +func (s *testMCPSuite) TestParallelNextUK() { + allHitRowIDs := map[int]int{} + repeatCnt := 20 + workerCnt := 5 + rowIDCh := make(chan int, workerCnt) + var wg sync.WaitGroup + wg.Add(workerCnt) + for i := 0; i < workerCnt; i++ { + go func() { + defer wg.Done() + for i := 0; i < repeatCnt; i++ { + theUK := s.mcp.NextUK() + if theUK != nil { + rowIDCh <- theUK.GetRowID() + } + } + }() + } + collectionFinishCh := func() <-chan struct{} { + ch := make(chan struct{}) + go func() { + defer close(ch) + for rowID := range rowIDCh { + if _, ok := allHitRowIDs[rowID]; !ok { + allHitRowIDs[rowID] = 0 + } + allHitRowIDs[rowID]++ + } + }() + return ch + }() + wg.Wait() + close(rowIDCh) + <-collectionFinishCh + + totalOccurredTimes := 0 + totalOccurredRowIDs := 0 + for _, times := range allHitRowIDs { + totalOccurredRowIDs++ + totalOccurredTimes += times + } + s.Greater(totalOccurredRowIDs, 1, "there should be more than 1 occurred row IDs") + s.Equal(repeatCnt*workerCnt, totalOccurredTimes, "total occurred UKs should equal the iteration count") +} + +func (s *testMCPSuite) TestMCPAddDeleteBasic() { + var ( + curPoolSize int + repeatCnt int + err error + ) + curPoolSize = len(s.mcp.keyPool) + startPoolSize := curPoolSize + repeatCnt = 5 + for i := 0; i < repeatCnt; i++ { + theUK := NewUniqueKey(-1, map[string]interface{}{ + "id": rand.Int(), + }) + err = s.mcp.AddUK(theUK) + s.Require().Nil(err) + s.Equal(curPoolSize+i+1, len(s.mcp.keyPool), "key pool size is not equal") + s.Equal(curPoolSize+i, s.mcp.keyPool[curPoolSize+i].GetRowID(), "the new added UK's row ID is abnormal") + s.Equal(curPoolSize+i, theUK.GetRowID(), "the input UK's row ID is not changed") + s.T().Logf("new added UK: %v\n", s.mcp.keyPool[curPoolSize+i]) + } + // test delete from bottom + curPoolSize = len(s.mcp.keyPool) + for i := 0; i < repeatCnt; i++ { + theUK := s.mcp.keyPool[curPoolSize-i-1] + err = s.mcp.DeleteUK(NewUniqueKey(curPoolSize-i-1, nil)) + s.Require().Nil(err) + s.Equal(-1, theUK.GetRowID(), "the deleted UK's row ID is not right") + s.Equal(curPoolSize-i-1, len(s.mcp.keyPool), "key pool size is not equal") + } + curPoolSize = len(s.mcp.keyPool) + s.Equal(startPoolSize, curPoolSize, "the MCP size is not right after adding & deleting") + // test 
delete from top + for i := 0; i < repeatCnt; i++ { + theDelUK := s.mcp.keyPool[i] + err = s.mcp.DeleteUK(NewUniqueKey(i, nil)) + s.Require().Nil(err) + theSwappedUK := s.mcp.keyPool[i] + swappedUKVal := theSwappedUK.GetValue() + s.Equal(i, theSwappedUK.GetRowID(), "the swapped UK's row ID is abnormal") + s.Equal(curPoolSize-i-1, swappedUKVal["id"], "the swapped UK's value is abnormal") + s.Equal(-1, theDelUK.GetRowID(), "the deleted UK's row ID is not right") + s.T().Logf("new UK after delete on the index %d: %v\n", i, s.mcp.keyPool[i]) + } + curPoolSize = len(s.mcp.keyPool) + // test delete at random position + for i := 0; i < repeatCnt; i++ { + theDelUK := s.mcp.NextUK() + deleteRowID := theDelUK.GetRowID() + err = s.mcp.DeleteUK(theDelUK) + s.Require().Nil(err) + theSwappedUK := s.mcp.keyPool[deleteRowID] + swappedUKVal := theSwappedUK.GetValue() + s.Equal(deleteRowID, theSwappedUK.GetRowID(), "the swapped UK's row ID is abnormal") + s.Equal(-1, theDelUK.GetRowID(), "the deleted UK's row ID is not right") + s.Equal(curPoolSize-i-1, swappedUKVal["id"], "the swapped UK's value is abnormal") + s.T().Logf("new UK after delete on the index %d: %v\n", deleteRowID, s.mcp.keyPool[deleteRowID]) + } + // check whether all the row ID is right + curPoolSize = len(s.mcp.keyPool) + for i := 0; i < curPoolSize; i++ { + theUK := s.mcp.keyPool[i] + s.Require().Equal(i, theUK.GetRowID(), "this UK element in the MCP has a wrong row ID") + } +} + +func (s *testMCPSuite) TestMCPAddDeleteInParallel() { + beforeLen := s.mcp.Len() + pendingCh := make(chan struct{}) + ch1 := func() <-chan error { + ch := make(chan error) + go func() { + var err error + <-pendingCh + defer func() { + ch <- err + }() + for i := 0; i < 5; i++ { + theUK := NewUniqueKey(-1, map[string]interface{}{ + "id": rand.Int(), + }) + err = s.mcp.AddUK(theUK) + if err != nil { + return + } + s.T().Logf("new added UK: %v\n", theUK) + } + }() + return ch + }() + ch2 := func() <-chan error { + ch := make(chan error) + go func() { + var err error + <-pendingCh + defer func() { + ch <- err + }() + for i := 0; i < 5; i++ { + theDelUK := s.mcp.NextUK() + deletedRowID := theDelUK.rowID + err = s.mcp.DeleteUK(theDelUK) + if err != nil { + return + } + s.mcp.RLock() + theSwappedUK := s.mcp.keyPool[deletedRowID] + s.mcp.RUnlock() + s.T().Logf("deletedUK: %v, swapped UK: %v\n", theDelUK, theSwappedUK) + } + }() + return ch + }() + close(pendingCh) + err1 := <-ch1 + err2 := <-ch2 + s.Require().Nil(err1) + s.Require().Nil(err2) + afterLen := s.mcp.Len() + s.Equal(beforeLen, afterLen, "the key pool size has changed after the parallel modification") +} + +func TestMCPSuite(t *testing.T) { + suite.Run(t, &testMCPSuite{}) +} diff --git a/dm/simulator/internal/mcp/uk.go b/dm/simulator/internal/mcp/uk.go new file mode 100644 index 00000000000..ce7ee72ef52 --- /dev/null +++ b/dm/simulator/internal/mcp/uk.go @@ -0,0 +1,140 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "fmt" + "strings" + "sync" +) + +// UniqueKey is the data structure describing a unique key. 
+type UniqueKey struct { + // It inherits a RWMutex, which is used to modify the metadata inside the UK struct. + sync.RWMutex + // rowID is an integer describing the row ID of the unique key. + // The row ID is a virtual concept, not the real row ID for a DB table. + // Usually it is used to locate the index in an MCP. + rowID int + // value is the real value of all the UK columns. + // The key is the column name, the value is the real value. + value map[string]interface{} +} + +// NewUniqueKey creates a new unique key instance. +// the map values are cloned into the new UK instance, +// so that the further changes in the value map won't affect the values inside the UK. +func NewUniqueKey(rowID int, value map[string]interface{}) *UniqueKey { + result := &UniqueKey{ + rowID: rowID, + value: make(map[string]interface{}), + } + for k, v := range value { + result.value[k] = v + } + return result +} + +// GetRowID gets the row ID of the unique key. +// The row ID is a virtual concept, not the real row ID for a DB table. +// Usually it is used to locate the index in an MCP. +func (uk *UniqueKey) GetRowID() int { + uk.RLock() + defer uk.RUnlock() + return uk.rowID +} + +// SetRowID sets the row ID of the unique key. +func (uk *UniqueKey) SetRowID(rowID int) { + uk.Lock() + defer uk.Unlock() + uk.rowID = rowID +} + +// GetValue gets the UK value map of a unique key. +// The returned value is cloned, so that further modifications won't affect the value inside the UK. +func (uk *UniqueKey) GetValue() map[string]interface{} { + uk.RLock() + defer uk.RUnlock() + result := make(map[string]interface{}) + for k, v := range uk.value { + result[k] = v + } + return result +} + +// SetValue sets the UK value map. +// The input values are cloned into the UK, +// and further modifications on the input map won't affect the values inside the UK. +func (uk *UniqueKey) SetValue(value map[string]interface{}) { + uk.Lock() + defer uk.Unlock() + uk.value = make(map[string]interface{}) + for k, v := range value { + uk.value[k] = v + } +} + +// Clone is to clone a UK into a new one. +// So that two UK objects are not interfered with each other. +func (uk *UniqueKey) Clone() *UniqueKey { + uk.RLock() + defer uk.RUnlock() + result := &UniqueKey{ + rowID: uk.rowID, + value: map[string]interface{}{}, + } + for k, v := range uk.value { + result.value[k] = v + } + return result +} + +// String returns the string representation of a UK. +func (uk *UniqueKey) String() string { + uk.RLock() + defer uk.RUnlock() + var b strings.Builder + fmt.Fprintf(&b, "%p: { RowID: %d, ", uk, uk.rowID) + fmt.Fprintf(&b, "Keys: ( ") + for k, v := range uk.value { + fmt.Fprintf(&b, "%s = %v; ", k, v) + } + fmt.Fprintf(&b, ") }") + return b.String() +} + +// IsValueEqual tests whether two UK's value parts are equal. +func (uk *UniqueKey) IsValueEqual(otherUK *UniqueKey) bool { + if uk == nil || otherUK == nil { + return false + } + uk.RLock() + defer uk.RUnlock() + otherUK.RLock() + defer otherUK.RUnlock() + if len(uk.value) != len(otherUK.value) { + return false + } + for k, v := range uk.value { + otherV, ok := otherUK.value[k] + if !ok { + return false + } + if v != otherV { + return false + } + } + return true +} diff --git a/dm/simulator/internal/mcp/uk_test.go b/dm/simulator/internal/mcp/uk_test.go new file mode 100644 index 00000000000..febc306ff50 --- /dev/null +++ b/dm/simulator/internal/mcp/uk_test.go @@ -0,0 +1,176 @@ +// Copyright 2022 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/pingcap/tiflow/dm/pkg/log" +) + +type testUniqueKeySuite struct { + suite.Suite +} + +func (s *testUniqueKeySuite) SetupSuite() { + s.Require().Nil(log.InitLogger(&log.Config{})) +} + +func (s *testUniqueKeySuite) TestUKClone() { + origUKCol1Value := 111 + origUKCol2Value := "COL1" + originalUK := &UniqueKey{ + rowID: -1, + value: map[string]interface{}{ + "col1": origUKCol1Value, + "col2": origUKCol2Value, + }, + } + newUKCol1Value := 222 + newUKCol2Value := "COL2" + clonedUK := originalUK.Clone() + clonedUK.value["col1"] = newUKCol1Value + clonedUK.value["col2"] = newUKCol2Value + + s.T().Logf("original UK: %v; cloned UK: %v\n", originalUK, clonedUK) + + s.Equalf(origUKCol1Value, originalUK.value["col1"], "original.%s value incorrect", "col1") + s.Equalf(origUKCol2Value, originalUK.value["col2"], "original.%s value incorrect", "col2") + s.Equalf(newUKCol1Value, clonedUK.value["col1"], "cloned.%s value incorrect", "col1") + s.Equalf(newUKCol2Value, clonedUK.value["col2"], "cloned.%s value incorrect", "col2") +} + +func (s *testUniqueKeySuite) TestUKChangeBasic() { + col1Value := 111 + col2Value := "aaa" + theValueMap := map[string]interface{}{ + "col1": col1Value, + "col2": col2Value, + } + theUK := NewUniqueKey(-1, theValueMap) + s.Equalf(col1Value, theUK.value["col1"], "%s value incorrect", "col1") + s.Equalf(col2Value, theUK.value["col2"], "%s value incorrect", "col2") + + theValueMap["col1"] = 222 + theValueMap["col2"] = "bbb" + s.Equalf(col1Value, theUK.value["col1"], "%s value incorrect", "col1") + s.Equalf(col2Value, theUK.value["col2"], "%s value incorrect", "col2") + + assignedValueMap := theUK.GetValue() + s.Equalf(col1Value, assignedValueMap["col1"], "%s value incorrect", "col1") + s.Equalf(col2Value, assignedValueMap["col2"], "%s value incorrect", "col2") + + newRowID := 999 + theUK.SetRowID(newRowID) + s.Equal(newRowID, theUK.GetRowID(), "row ID value incorrect") + + newCol1Value := 333 + newCol2Value := "ccc" + newValueMap := map[string]interface{}{ + "col1": newCol1Value, + "col2": newCol2Value, + } + theUK.SetValue(newValueMap) + s.Equalf(newCol1Value, theUK.value["col1"], "%s value incorrect", "col1") + s.Equalf(newCol2Value, theUK.value["col2"], "%s value incorrect", "col2") + s.Equalf(col1Value, assignedValueMap["col1"], "assigned map's %s value incorrect", "col1") + s.Equalf(col2Value, assignedValueMap["col2"], "assigned map's %s value incorrect", "col2") + + newValueMap["col1"] = 444 + newValueMap["col2"] = "ddd" + s.Equalf(newCol1Value, theUK.value["col1"], "%s value incorrect", "col1") + s.Equalf(newCol2Value, theUK.value["col2"], "%s value incorrect", "col2") +} + +func (s *testUniqueKeySuite) TestUKParallelChange() { + theUK := NewUniqueKey(-1, nil) + pendingCh := make(chan struct{}) + var wg sync.WaitGroup + targetID := 100 + workerCnt := 10 + wg.Add(workerCnt) + for i := 0; i < workerCnt; i++ { + go func() { + defer wg.Done() + <-pendingCh + for i := 1; i <= targetID; i++ { + 
theUK.SetRowID(i) + theUK.SetValue(map[string]interface{}{ + "id": i, + }) + } + }() + } + close(pendingCh) + wg.Wait() + s.Equal(targetID, theUK.GetRowID(), "row ID value incorrect") + theValue := theUK.GetValue() + s.Equal(targetID, theValue["id"], "ID column value incorrect") + s.Equal(targetID, theUK.value["id"], "ID column value in UK incorrect") +} + +func (s *testUniqueKeySuite) TestUKValueEqual() { + col1Value := 111 + col2Value := "aaa" + uk1 := &UniqueKey{ + rowID: -1, + value: map[string]interface{}{ + "col1": col1Value, + "col2": col2Value, + }, + } + uk2 := &UniqueKey{ + rowID: 100, + value: map[string]interface{}{ + "col1": col1Value, + "col2": col2Value, + }, + } + s.Equal(true, uk1.IsValueEqual(uk2), "uk1 should equal uk2 on value") + s.Equal(true, uk2.IsValueEqual(uk1), "uk2 should equal uk1 on value") + uk3 := &UniqueKey{ + rowID: 100, + value: map[string]interface{}{ + "col1": col1Value, + "col2": "bbb", + }, + } + s.Equal(false, uk1.IsValueEqual(uk3), "uk1 should not equal uk3 on value") + s.Equal(false, uk3.IsValueEqual(uk1), "uk3 should not equal uk1 on value") + uk4 := &UniqueKey{ + rowID: 100, + value: map[string]interface{}{ + "col3": 321, + }, + } + s.Equal(false, uk1.IsValueEqual(uk4), "uk1 should not equal uk4 on value") + s.Equal(false, uk4.IsValueEqual(uk1), "uk4 should not equal uk1 on value") + uk5 := &UniqueKey{ + rowID: 100, + value: map[string]interface{}{ + "col3": 321, + "col1": col1Value, + "col2": col2Value, + }, + } + s.Equal(false, uk1.IsValueEqual(uk5), "uk1 should not equal uk5 on value") + s.Equal(false, uk5.IsValueEqual(uk1), "uk5 should not equal uk1 on value") +} + +func TestUniqueKeySuite(t *testing.T) { + suite.Run(t, &testUniqueKeySuite{}) +} diff --git a/dm/simulator/internal/sqlgen/errors.go b/dm/simulator/internal/sqlgen/errors.go new file mode 100644 index 00000000000..b941e7e90da --- /dev/null +++ b/dm/simulator/internal/sqlgen/errors.go @@ -0,0 +1,28 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlgen + +import ( + "github.com/pingcap/errors" +) + +var ( + // ErrUKColValueNotProvided means that some column values of the unique key are not provided. + ErrUKColValueNotProvided = errors.New("some UK column values are not provided") + // ErrMissingUKValue means the input unique key is nil. + ErrMissingUKValue = errors.New("missing the UK values") + // ErrWhereFiltersEmpty means the WHERE clause filter conditions are empty. + // It usually happens when no filter clause is generated while building a WHERE clause. + ErrWhereFiltersEmpty = errors.New("`WHERE` condition filters are empty") +) diff --git a/dm/simulator/internal/sqlgen/impl.go b/dm/simulator/internal/sqlgen/impl.go new file mode 100644 index 00000000000..20f51c4de80 --- /dev/null +++ b/dm/simulator/internal/sqlgen/impl.go @@ -0,0 +1,272 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
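Both `mcp/errors.go` and `sqlgen/errors.go` above define sentinel errors, and the call sites wrap them with `errors.Trace`/`errors.Annotate`. A small sketch of how a caller can still compare against the sentinel after wrapping, assuming the usual `Cause` helper from `github.com/pingcap/errors` (API-compatible with `pkg/errors`); the names here are stand-ins, not part of this diff:

```go
package main

import (
	"fmt"

	"github.com/pingcap/errors"
)

// errPoolFull stands in for a sentinel such as mcp.ErrMCPCapacityFull or
// sqlgen.ErrUKColValueNotProvided defined in the files above.
var errPoolFull = errors.New("the capacity of the modification candidate pool is full")

// doAdd wraps the sentinel with a stack trace, the same way AddUK,
// DeleteUK and generateWhereClause wrap their errors.
func doAdd() error {
	return errors.Trace(errPoolFull)
}

func main() {
	err := doAdd()
	// errors.Cause unwraps the Trace annotation, so the caller can still
	// compare against the sentinel value.
	if errors.Cause(err) == errPoolFull {
		fmt.Println("pool is full, stop generating INSERTs")
	}
}
```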
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlgen + +import ( + "strings" + + "github.com/chaos-mesh/go-sqlsmith/util" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/format" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/opcode" + _ "github.com/pingcap/tidb/types/parser_driver" // import this to make the parser work + "go.uber.org/zap" + + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/dm/simulator/internal/config" + "github.com/pingcap/tiflow/dm/simulator/internal/mcp" +) + +type sqlGeneratorImpl struct { + tableConfig *config.TableConfig + columnMap map[string]*config.ColumnDefinition + ukMap map[string]struct{} +} + +// NewSQLGeneratorImpl generates a new implementation object for SQL generator. +func NewSQLGeneratorImpl(tableConfig *config.TableConfig) *sqlGeneratorImpl { + colDefMap := make(map[string]*config.ColumnDefinition) + for _, colDef := range tableConfig.Columns { + colDefMap[colDef.ColumnName] = colDef + } + ukMap := make(map[string]struct{}) + for _, ukColName := range tableConfig.UniqueKeyColumnNames { + if _, ok := colDefMap[ukColName]; ok { + ukMap[ukColName] = struct{}{} + } + } + return &sqlGeneratorImpl{ + tableConfig: tableConfig, + columnMap: colDefMap, + ukMap: ukMap, + } +} + +// outputString parses an ast node to SQL string. +func outputString(node ast.Node) (string, error) { + var sb strings.Builder + err := node.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) + if err != nil { + return "", errors.Annotate(err, "restore AST into SQL string error") + } + return sb.String(), nil +} + +// GenTruncateTable generates a TRUNCATE TABLE SQL. +// It implements the SQLGenerator interface. 
+func (g *sqlGeneratorImpl) GenTruncateTable() (string, error) { + truncateTree := &ast.TruncateTableStmt{ + Table: &ast.TableName{ + Schema: model.NewCIStr(g.tableConfig.DatabaseName), + Name: model.NewCIStr(g.tableConfig.TableName), + }, + } + return outputString(truncateTree) +} + +func (g *sqlGeneratorImpl) generateWhereClause(theUK map[string]interface{}) (ast.ExprNode, error) { + compareExprs := make([]ast.ExprNode, 0) + // iterate the existing UKs, to make sure all the uk columns has values + for ukColName := range g.ukMap { + val, ok := theUK[ukColName] + if !ok { + log.L().Error(ErrUKColValueNotProvided.Error(), zap.String("column_name", ukColName)) + return nil, errors.Trace(ErrUKColValueNotProvided) + } + var compareExpr ast.ExprNode + if val == nil { + compareExpr = &ast.IsNullExpr{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.NewCIStr(ukColName), + }, + }, + } + } else { + compareExpr = &ast.BinaryOperationExpr{ + Op: opcode.EQ, + L: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.NewCIStr(ukColName), + }, + }, + R: ast.NewValueExpr(val, "", ""), + } + } + compareExprs = append(compareExprs, compareExpr) + } + resultExpr := generateCompoundBinaryOpExpr(compareExprs) + if resultExpr == nil { + return nil, ErrWhereFiltersEmpty + } + return resultExpr, nil +} + +func generateCompoundBinaryOpExpr(compExprs []ast.ExprNode) ast.ExprNode { + switch len(compExprs) { + case 0: + return nil + case 1: + return compExprs[0] + default: + return &ast.BinaryOperationExpr{ + Op: opcode.LogicAnd, + L: compExprs[0], + R: generateCompoundBinaryOpExpr(compExprs[1:]), + } + } +} + +// GenUpdateRow generates an UPDATE SQL for the given unique key. +// It implements the SQLGenerator interface. +func (g *sqlGeneratorImpl) GenUpdateRow(theUK *mcp.UniqueKey) (string, error) { + if theUK == nil { + return "", errors.Trace(ErrMissingUKValue) + } + assignments := make([]*ast.Assignment, 0) + for _, colInfo := range g.columnMap { + if _, ok := g.ukMap[colInfo.ColumnName]; ok { + // this is a UK column, skip from modifying it + // TODO: support UK modification in the future + continue + } + assignments = append(assignments, &ast.Assignment{ + Column: &ast.ColumnName{ + Name: model.NewCIStr(colInfo.ColumnName), + }, + Expr: ast.NewValueExpr(util.GenerateDataItem(colInfo.DataType), "", ""), + }) + } + whereClause, err := g.generateWhereClause(theUK.GetValue()) + if err != nil { + return "", errors.Annotate(err, "generate where clause error") + } + updateTree := &ast.UpdateStmt{ + List: assignments, + TableRefs: &ast.TableRefsClause{ + TableRefs: &ast.Join{ + Left: &ast.TableName{ + Schema: model.NewCIStr(g.tableConfig.DatabaseName), + Name: model.NewCIStr(g.tableConfig.TableName), + }, + }, + }, + Where: whereClause, + } + return outputString(updateTree) +} + +// GenInsertRow generates an INSERT SQL. +// It implements the SQLGenerator interface. +// The new row's unique key is also provided, +// so that it can be further added into an MCP. 
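`generateCompoundBinaryOpExpr` above folds the list of per-column filters into a right-nested `AND` tree, and `generateWhereClause` maps a `nil` column value to `IS NULL` instead of `=`. A tiny stand-alone analogue of the recursion over plain strings (purely illustrative, not the real `ast.ExprNode` types) makes the nesting easy to see:

```go
package main

import "fmt"

// buildAnd mirrors the recursion in generateCompoundBinaryOpExpr, but over
// strings instead of ast.ExprNode: 0 filters -> empty (the caller turns that
// into ErrWhereFiltersEmpty), 1 filter -> itself, otherwise head AND (rest).
func buildAnd(filters []string) string {
	switch len(filters) {
	case 0:
		return ""
	case 1:
		return filters[0]
	default:
		return fmt.Sprintf("(%s AND %s)", filters[0], buildAnd(filters[1:]))
	}
}

func main() {
	// Prints: (`name`='ABCDEFG' AND `team_id` IS NULL)
	fmt.Println(buildAnd([]string{"`name`='ABCDEFG'", "`team_id` IS NULL"}))
}
```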
+func (g *sqlGeneratorImpl) GenInsertRow() (string, *mcp.UniqueKey, error) { + ukValues := make(map[string]interface{}) + columnNames := []*ast.ColumnName{} + values := []ast.ExprNode{} + for _, col := range g.columnMap { + columnNames = append(columnNames, &ast.ColumnName{ + Name: model.NewCIStr(col.ColumnName), + }) + newValue := util.GenerateDataItem(col.DataType) + values = append(values, ast.NewValueExpr(newValue, "", "")) + if _, ok := g.ukMap[col.ColumnName]; ok { + // add UK value + ukValues[col.ColumnName] = newValue + } + } + insertTree := &ast.InsertStmt{ + Table: &ast.TableRefsClause{ + TableRefs: &ast.Join{ + Left: &ast.TableName{ + Schema: model.NewCIStr(g.tableConfig.DatabaseName), + Name: model.NewCIStr(g.tableConfig.TableName), + }, + }, + }, + Lists: [][]ast.ExprNode{values}, + Columns: columnNames, + } + sql, err := outputString(insertTree) + if err != nil { + return "", nil, errors.Annotate(err, "output INSERT AST into SQL string error") + } + return sql, mcp.NewUniqueKey(-1, ukValues), nil +} + +// GenDeleteRow generates a DELETE SQL for the given unique key. +// It implements the SQLGenerator interface. +func (g *sqlGeneratorImpl) GenDeleteRow(theUK *mcp.UniqueKey) (string, error) { + if theUK == nil { + return "", errors.Trace(ErrMissingUKValue) + } + whereClause, err := g.generateWhereClause(theUK.GetValue()) + if err != nil { + return "", errors.Annotate(err, "generate where clause error") + } + updateTree := &ast.DeleteStmt{ + TableRefs: &ast.TableRefsClause{ + TableRefs: &ast.Join{ + Left: &ast.TableName{ + Schema: model.NewCIStr(g.tableConfig.DatabaseName), + Name: model.NewCIStr(g.tableConfig.TableName), + }, + }, + }, + Where: whereClause, + } + return outputString(updateTree) +} + +// GenLoadUniqueKeySQL generates a SELECT SQL fetching all the uniques of a table. +// It implements the SQLGenerator interface. +// The column definitions of the returned data is also provided, +// so that the values can be stored to different variables. +func (g *sqlGeneratorImpl) GenLoadUniqueKeySQL() (string, []*config.ColumnDefinition, error) { + selectFields := make([]*ast.SelectField, 0) + cols := make([]*config.ColumnDefinition, 0) + for ukColName := range g.ukMap { + selectFields = append(selectFields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.NewCIStr(ukColName), + }, + }, + }) + cols = append(cols, g.columnMap[ukColName]) + } + selectTree := &ast.SelectStmt{ + SelectStmtOpts: &ast.SelectStmtOpts{ + SQLCache: true, + }, + Fields: &ast.FieldList{ + Fields: selectFields, + }, + From: &ast.TableRefsClause{ + TableRefs: &ast.Join{ + Left: &ast.TableName{ + Schema: model.NewCIStr(g.tableConfig.DatabaseName), + Name: model.NewCIStr(g.tableConfig.TableName), + }, + }, + }, + } + sql, err := outputString(selectTree) + if err != nil { + return "", nil, errors.Annotate(err, "output SELECT AST into SQL string error") + } + return sql, cols, nil +} diff --git a/dm/simulator/internal/sqlgen/impl_test.go b/dm/simulator/internal/sqlgen/impl_test.go new file mode 100644 index 00000000000..363a77a8867 --- /dev/null +++ b/dm/simulator/internal/sqlgen/impl_test.go @@ -0,0 +1,353 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlgen + +import ( + "fmt" + "strings" + "testing" + + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" + "github.com/stretchr/testify/suite" + + "github.com/pingcap/tiflow/dm/pkg/log" + "github.com/pingcap/tiflow/dm/simulator/internal/config" + "github.com/pingcap/tiflow/dm/simulator/internal/mcp" +) + +type testSQLGenImplSuite struct { + suite.Suite + tableConfig *config.TableConfig + sqlParser *parser.Parser + allColNameMap map[string]struct{} +} + +func (s *testSQLGenImplSuite) SetupSuite() { + s.allColNameMap = make(map[string]struct{}) + s.Require().Nil(log.InitLogger(&log.Config{})) + s.tableConfig = &config.TableConfig{ + DatabaseName: "games", + TableName: "members", + Columns: []*config.ColumnDefinition{ + { + ColumnName: "id", + DataType: "int", + DataLen: 11, + }, + { + ColumnName: "name", + DataType: "varchar", + DataLen: 255, + }, + { + ColumnName: "age", + DataType: "int", + DataLen: 11, + }, + { + ColumnName: "team_id", + DataType: "int", + DataLen: 11, + }, + }, + UniqueKeyColumnNames: []string{"id"}, + } + for _, colInfo := range s.tableConfig.Columns { + s.allColNameMap[fmt.Sprintf("`%s`", colInfo.ColumnName)] = struct{}{} + } + s.sqlParser = parser.New() +} + +func generateUKColNameMap(ukColNames []string) map[string]struct{} { + ukColNameMap := make(map[string]struct{}) + for _, colName := range ukColNames { + ukColNameMap[fmt.Sprintf("`%s`", colName)] = struct{}{} + } + return ukColNameMap +} + +func (s *testSQLGenImplSuite) checkLoadUKsSQL(sql string, ukColNames []string) { + var err error + theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") + if !s.Nilf(err, "parse statement error: %s", sql) { + return + } + selectAST, ok := theAST.(*ast.SelectStmt) + if !ok { + s.Fail("cannot convert the AST to select AST") + return + } + s.checkTableName(selectAST.From) + if !s.Equal(len(s.tableConfig.UniqueKeyColumnNames), len(selectAST.Fields.Fields)) { + return + } + ukColNameMap := generateUKColNameMap(ukColNames) + for _, field := range selectAST.Fields.Fields { + fieldNameStr, err := outputString(field) + if !s.Nil(err) { + continue + } + if _, ok := ukColNameMap[fieldNameStr]; !ok { + s.Fail( + "the parsed column name cannot be found in the UK names", + "parsed column name: %s", fieldNameStr, + ) + } + } +} + +func (s *testSQLGenImplSuite) checkTruncateSQL(sql string) { + theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") + if !s.Nilf(err, "parse statement error: %s", sql) { + return + } + truncateAST, ok := theAST.(*ast.TruncateTableStmt) + if !ok { + s.Fail("cannot convert the AST to truncate AST") + return + } + s.checkTableName(truncateAST.Table) +} + +func (s *testSQLGenImplSuite) checkInsertSQL(sql string) { + theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") + if !s.Nilf(err, "parse statement error: %s", sql) { + return + } + insertAST, ok := theAST.(*ast.InsertStmt) + if !ok { + s.Fail("cannot convert the AST to insert AST") + return + } + s.checkTableName(insertAST.Table) + if !s.Equal(len(s.tableConfig.Columns), len(insertAST.Columns)) { + return + } + for _, col := range insertAST.Columns { + colNameStr, err := outputString(col) + if !s.Nil(err) { + continue + } + if _, ok := 
s.allColNameMap[colNameStr]; !ok { + s.Fail( + "the parsed column name cannot be found in the column names", + "parsed column name: %s", colNameStr, + ) + } + } +} + +func (s *testSQLGenImplSuite) checkUpdateSQL(sql string, ukColNames []string) { + theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") + if !s.Nilf(err, "parse statement error: %s", sql) { + return + } + updateAST, ok := theAST.(*ast.UpdateStmt) + if !ok { + s.Fail("cannot convert the AST to update AST") + return + } + s.checkTableName(updateAST.TableRefs) + s.Greater(len(updateAST.List), 0) + s.checkWhereClause(updateAST.Where, ukColNames) +} + +func (s *testSQLGenImplSuite) checkDeleteSQL(sql string, ukColNames []string) { + theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") + if !s.Nilf(err, "parse statement error: %s", sql) { + return + } + deleteAST, ok := theAST.(*ast.DeleteStmt) + if !ok { + s.Fail("cannot convert the AST to delete AST") + return + } + s.checkTableName(deleteAST.TableRefs) + s.checkWhereClause(deleteAST.Where, ukColNames) +} + +func (s *testSQLGenImplSuite) checkTableName(astNode ast.Node) { + tableNameStr, err := outputString(astNode) + if !s.Nil(err) { + return + } + s.Equal( + fmt.Sprintf("`%s`.`%s`", s.tableConfig.DatabaseName, s.tableConfig.TableName), + tableNameStr, + ) +} + +func (s *testSQLGenImplSuite) checkWhereClause(astNode ast.Node, ukColNames []string) { + whereClauseStr, err := outputString(astNode) + if !s.Nil(err) { + return + } + ukColNameMap := generateUKColNameMap(ukColNames) + for colName := range ukColNameMap { + if !s.Truef( + strings.Contains(whereClauseStr, fmt.Sprintf("%s=", colName)) || + strings.Contains(whereClauseStr, fmt.Sprintf("%s IS NULL", colName)), + "cannot find the column name in the where clause: where clause string: %s; column name: %s", + whereClauseStr, colName, + ) { + continue + } + } +} + +func (s *testSQLGenImplSuite) TestDMLBasic() { + var ( + err error + sql string + uk *mcp.UniqueKey + ) + g := NewSQLGeneratorImpl(s.tableConfig) + + sql, _, err = g.GenLoadUniqueKeySQL() + s.Nil(err, "generate load UK SQL error") + s.T().Logf("Generated SELECT SQL: %s\n", sql) + s.checkLoadUKsSQL(sql, s.tableConfig.UniqueKeyColumnNames) + + sql, err = g.GenTruncateTable() + s.Nil(err, "generate truncate table SQL error") + s.T().Logf("Generated Truncate Table SQL: %s\n", sql) + s.checkTruncateSQL(sql) + + theMCP := mcp.NewModificationCandidatePool(8192) + for i := 0; i < 4096; i++ { + s.Nil( + theMCP.AddUK(mcp.NewUniqueKey(i, map[string]interface{}{ + "id": i, + })), + ) + } + for i := 0; i < 10; i++ { + uk = theMCP.NextUK() + sql, err = g.GenUpdateRow(uk) + s.Nil(err, "generate update sql error") + s.T().Logf("Generated SQL: %s\n", sql) + s.checkUpdateSQL(sql, s.tableConfig.UniqueKeyColumnNames) + + sql, uk, err = g.GenInsertRow() + s.Nil(err, "generate insert sql error") + s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, uk) + s.checkInsertSQL(sql) + + uk = theMCP.NextUK() + sql, err = g.GenDeleteRow(uk) + s.Nil(err, "generate delete sql error") + s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, uk) + s.checkDeleteSQL(sql, s.tableConfig.UniqueKeyColumnNames) + } +} + +func (s *testSQLGenImplSuite) TestWhereNULL() { + var ( + err error + sql string + ) + theTableConfig := &config.TableConfig{ + DatabaseName: s.tableConfig.DatabaseName, + TableName: s.tableConfig.TableName, + Columns: s.tableConfig.Columns, + UniqueKeyColumnNames: []string{"name", "team_id"}, + } + g := NewSQLGeneratorImpl(theTableConfig) + theUK := mcp.NewUniqueKey(-1, 
map[string]interface{}{ + "name": "ABCDEFG", + "team_id": nil, + }) + sql, err = g.GenUpdateRow(theUK) + s.Require().Nil(err) + s.T().Logf("Generated UPDATE SQL: %s\n", sql) + s.checkUpdateSQL(sql, theTableConfig.UniqueKeyColumnNames) + + sql, err = g.GenDeleteRow(theUK) + s.Require().Nil(err) + s.T().Logf("Generated DELETE SQL: %s\n", sql) + s.checkDeleteSQL(sql, theTableConfig.UniqueKeyColumnNames) +} + +func (s *testSQLGenImplSuite) TestDMLAbnormalUK() { + var ( + sql string + err error + uk *mcp.UniqueKey + ) + g := NewSQLGeneratorImpl(s.tableConfig) + uk = mcp.NewUniqueKey(-1, map[string]interface{}{ + "abcdefg": 123, + }) + _, err = g.GenUpdateRow(uk) + s.NotNil(err) + _, err = g.GenDeleteRow(uk) + s.NotNil(err) + + uk = mcp.NewUniqueKey(-1, map[string]interface{}{ + "id": 123, + "abcdefg": 321, + }) + sql, err = g.GenUpdateRow(uk) + s.Nil(err) + s.T().Logf("Generated SQL: %s\n", sql) + s.checkUpdateSQL(sql, s.tableConfig.UniqueKeyColumnNames) + + sql, err = g.GenDeleteRow(uk) + s.Nil(err) + s.T().Logf("Generated SQL: %s\n", sql) + s.checkDeleteSQL(sql, s.tableConfig.UniqueKeyColumnNames) + + uk = mcp.NewUniqueKey(-1, map[string]interface{}{}) + _, err = g.GenUpdateRow(uk) + s.NotNil(err) +} + +func (s *testSQLGenImplSuite) TestDMLWithNoUK() { + var ( + err error + sql string + theUK *mcp.UniqueKey + ) + theTableConfig := &config.TableConfig{ + DatabaseName: s.tableConfig.DatabaseName, + TableName: s.tableConfig.TableName, + Columns: s.tableConfig.Columns, + UniqueKeyColumnNames: []string{}, + } + g := NewSQLGeneratorImpl(theTableConfig) + + sql, theUK, err = g.GenInsertRow() + s.Nil(err, "generate insert sql error") + s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, theUK) + s.checkInsertSQL(sql) + + theUK = mcp.NewUniqueKey(-1, map[string]interface{}{}) + _, err = g.GenUpdateRow(theUK) + s.NotNil(err) + _, err = g.GenDeleteRow(theUK) + s.NotNil(err) + + theUK = mcp.NewUniqueKey(-1, map[string]interface{}{ + "id": 123, // the column is filtered out by the UK configs + }) + _, err = g.GenUpdateRow(theUK) + s.NotNil(err) + _, err = g.GenDeleteRow(theUK) + s.NotNil(err) +} + +func TestSQLGenImplSuite(t *testing.T) { + suite.Run(t, &testSQLGenImplSuite{}) +} diff --git a/dm/simulator/internal/sqlgen/sqlgen.go b/dm/simulator/internal/sqlgen/sqlgen.go new file mode 100644 index 00000000000..68dd0dd5c5a --- /dev/null +++ b/dm/simulator/internal/sqlgen/sqlgen.go @@ -0,0 +1,38 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sqlgen is the logic for generating different kinds of SQL statements. +package sqlgen + +import ( + "github.com/pingcap/tiflow/dm/simulator/internal/config" + "github.com/pingcap/tiflow/dm/simulator/internal/mcp" +) + +// SQLGenerator contains all the operations for generating SQLs. +type SQLGenerator interface { + // GenTruncateTable generates a TRUNCATE TABLE SQL. + GenTruncateTable() (string, error) + // GenLoadUniqueKeySQL generates a SELECT SQL fetching all the uniques of a table. 
+ // The column definitions of the returned data are also provided, + // so that the values can be stored to different variables. + GenLoadUniqueKeySQL() (string, []*config.ColumnDefinition, error) + // GenInsertRow generates an INSERT SQL. + // The new row's unique key is also provided, + // so that it can be further added into an MCP. + GenInsertRow() (string, *mcp.UniqueKey, error) + // GenUpdateRow generates an UPDATE SQL for the given unique key. + GenUpdateRow(*mcp.UniqueKey) (string, error) + // GenDeleteRow generates a DELETE SQL for the given unique key. + GenDeleteRow(*mcp.UniqueKey) (string, error) +} diff --git a/dm/syncer/data_validator.go b/dm/syncer/data_validator.go index 2b353f7ceb3..5e5be8e7e63 100644 --- a/dm/syncer/data_validator.go +++ b/dm/syncer/data_validator.go @@ -171,7 +171,6 @@ func NewContinuousDataValidator(cfg *config.SubTaskConfig, syncerObj *Syncer, st v.workerCnt = cfg.ValidatorCfg.WorkerCount v.processedRowCounts = make([]atomic.Int64, rowChangeTypeCount) - v.workers = make([]*validateWorker, v.workerCnt) v.validateInterval = validationInterval v.persistHelper = newValidatorCheckpointHelper(v) v.tableStatus = make(map[string]*tableValidateStatus) @@ -571,6 +570,7 @@ func (v *DataValidator) Stage() pb.Stage { func (v *DataValidator) startValidateWorkers() { v.wg.Add(v.workerCnt) + v.workers = make([]*validateWorker, v.workerCnt) for i := 0; i < v.workerCnt; i++ { worker := newValidateWorker(v, i) v.workers[i] = worker diff --git a/dm/syncer/validator_checkpoint_test.go b/dm/syncer/validator_checkpoint_test.go index ce73febfd25..2d9151a9506 100644 --- a/dm/syncer/validator_checkpoint_test.go +++ b/dm/syncer/validator_checkpoint_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tiflow/dm/pkg/binlog" "github.com/pingcap/tiflow/dm/pkg/conn" tcontext "github.com/pingcap/tiflow/dm/pkg/context" + "github.com/pingcap/tiflow/dm/pkg/log" "github.com/pingcap/tiflow/dm/pkg/retry" "github.com/pingcap/tiflow/dm/pkg/schema" "github.com/pingcap/tiflow/dm/syncer/dbconn" @@ -131,3 +132,21 @@ func TestValidatorCheckpointPersist(t *testing.T) { testFunc("") testFunc("failed") } + +func TestCheckpointNotPanic(t *testing.T) { + // The validator tries to persist data before starting. + // If it visits and persists the workers, which are not initialized before starting, + // the program will panic. + // This issue is fixed by deferring the worker initialization to startValidateWorkers. + var err error + cfg := genSubtaskConfig(t) + syncerObj := NewSyncer(cfg, nil, nil) + require.Equal(t, log.InitLogger(&log.Config{}), nil) + validator := NewContinuousDataValidator(cfg, syncerObj, false) + validator.ctx, validator.cancel = context.WithCancel(context.Background()) + validator.tctx = tcontext.NewContext(validator.ctx, validator.L) + validator.persistHelper.tctx = validator.tctx + currLoc := binlog.NewLocation(cfg.Flavor) + err = validator.persistHelper.persist(currLoc) // persist while the workers are still nil + require.NotNil(t, err) // err is not nil, but the program does not panic +}
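Taken together, the simulator pieces in this PR are meant to compose: a workload driver seeds the MCP from `GenLoadUniqueKeySQL`, then keeps asking the generator for DML against keys picked from the pool. The driver itself is not part of this diff, so the sketch below is only a hypothetical illustration of how the `sqlgen.SQLGenerator` interface and the `mcp` pool fit together:

```go
package simulator

import (
	"fmt"
	"math/rand"

	"github.com/pingcap/tiflow/dm/simulator/internal/mcp"
	"github.com/pingcap/tiflow/dm/simulator/internal/sqlgen"
)

// generateRandomDML returns one random INSERT/UPDATE/DELETE statement and
// keeps the MCP in sync with the rows the statement creates or removes.
func generateRandomDML(g sqlgen.SQLGenerator, pool *mcp.ModificationCandidatePool) (string, error) {
	switch rand.Intn(3) {
	case 0:
		// INSERT: the generator also returns the new row's UK so it can join the pool.
		sql, uk, err := g.GenInsertRow()
		if err != nil {
			return "", err
		}
		if err := pool.AddUK(uk); err != nil {
			return "", err
		}
		return sql, nil
	case 1:
		// UPDATE: reuse an existing UK; non-UK columns get random new values.
		uk := pool.NextUK()
		if uk == nil {
			return "", fmt.Errorf("the MCP is empty")
		}
		return g.GenUpdateRow(uk)
	default:
		// DELETE: remove the row from both the table and the pool.
		uk := pool.NextUK()
		if uk == nil {
			return "", fmt.Errorf("the MCP is empty")
		}
		sql, err := g.GenDeleteRow(uk)
		if err != nil {
			return "", err
		}
		if err := pool.DeleteUK(uk); err != nil {
			return "", err
		}
		return sql, nil
	}
}
```

Since the MCP is internally locked and `DeleteUK` swaps row IDs around, several such drivers can run concurrently against the same pool, which is the situation `TestMCPAddDeleteInParallel` above exercises.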