diff --git a/r/source.result b/r/source.result new file mode 100644 index 0000000..ad266c0 --- /dev/null +++ b/r/source.result @@ -0,0 +1,10 @@ +before source +Hello from the included file +SELECT 1 +1 +1 +after source +Goodbye from the included file +SELECT 3 +3 +3 diff --git a/src/main.go b/src/main.go index 9ddaf45..5e09363 100644 --- a/src/main.go +++ b/src/main.go @@ -47,9 +47,13 @@ var ( retryConnCount int collationDisable bool checkErr bool + disableSource bool ) func init() { + // Disable the `--source` command by default to avoid breaking existing tests + disableSource = true + flag.StringVar(&host, "host", "127.0.0.1", "The host of the TiDB/MySQL server.") flag.StringVar(&port, "port", "4000", "The listen port of TiDB/MySQL server.") flag.StringVar(&user, "user", "root", "The user for connecting to the database.") @@ -72,10 +76,15 @@ const ( type query struct { firstWord string Query string + File string Line int tp int } +func (q *query) location() string { + return fmt.Sprintf("%s:%d", q.File, q.Line) +} + type Conn struct { // DB might be a shared one by multiple Conn, if the connection information are the same. mdb *sql.DB @@ -325,7 +334,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt func (t *tester) Run() error { t.preProcess() defer t.postProcess() - queries, err := t.loadQueries() + queries, err := t.loadQueries(t.testFileName()) if err != nil { err = errors.Trace(err) t.addFailure(&testSuite, &err, 0) @@ -338,17 +347,33 @@ func (t *tester) Run() error { return err } - var s string defer func() { if t.resultFD != nil { t.resultFD.Close() } }() - testCnt := 0 startTime := time.Now() + testCnt, err := t.runQueries(queries) + if err != nil { + return err + } + + fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) + + if xmlPath != "" { + t.addSuccess(&testSuite, &startTime, testCnt) + } + + return t.flushResult() +} + +func (t *tester) runQueries(queries []query) (int, error) { + testCnt := 0 var concurrentQueue []query var concurrentSize int + var s string + var err error for _, q := range queries { s = q.Query switch q.tp { @@ -379,7 +404,7 @@ func (t *tester) Run() error { if err != nil { err = errors.Annotate(err, "Atoi failed") t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } } case Q_END_CONCURRENT: @@ -387,7 +412,7 @@ func (t *tester) Run() error { if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil { err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } t.expectedErrs = nil case Q_ERROR: @@ -404,9 +429,9 @@ func (t *tester) Run() error { if t.enableConcurrent { concurrentQueue = append(concurrentQueue, q) } else if err = t.execute(q); err != nil { - err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query)) + err = errors.Annotate(err, fmt.Sprintf("sql:%v line:%s", q.Query, q.location())) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } testCnt++ @@ -426,7 +451,7 @@ func (t *tester) Run() error { if err != nil { err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])}) @@ -473,7 +498,7 @@ func (t *tester) Run() error { r, err := t.executeStmtString(s) if err != nil { log.WithFields(log.Fields{ - "query": s, "line": q.Line}, + "query": s, "line": q.location()}, ).Error("failed to perform let query") return "" } @@ -484,27 +509,68 @@ func (t *tester) Run() error { case Q_REMOVE_FILE: err = os.Remove(strings.TrimSpace(q.Query)) if err != nil { - return errors.Annotate(err, "failed to remove file") + return testCnt, errors.Annotate(err, "failed to remove file") } case Q_REPLACE_REGEX: t.replaceRegex = nil regex, err := ParseReplaceRegex(q.Query) if err != nil { - return errors.Annotate(err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %d sql:%v", q.Line, q.Query)) + return testCnt, errors.Annotate( + err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %s sql:%v", + q.location(), q.Query)) } t.replaceRegex = regex - default: - log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.Line}).Warn("command not implemented") - } - } + case Q_ENABLE_SOURCE: + disableSource = false + case Q_DISABLE_SOURCE: + disableSource = true + case Q_SOURCE: + if disableSource { + log.WithFields(log.Fields{"line": q.location()}).Warn("source command disabled, add '--enable_source' to your file to enable") + break + } + fileName := strings.TrimSpace(q.Query) + cwd, err := os.Getwd() + if err != nil { + return testCnt, err + } - fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) + // For security, don't allow to include files from other locations + fullpath, err := filepath.Abs(fileName) + if err != nil { + return testCnt, err + } + if !strings.HasPrefix(fullpath, cwd) { + return testCnt, errors.Errorf("included file %s is not prefixed with %s", fullpath, cwd) + } - if xmlPath != "" { - t.addSuccess(&testSuite, &startTime, testCnt) - } + // Make sure we have a useful error message if the file can't be found or isn't a regular file + s, err := os.Stat(fileName) + if err != nil { + return testCnt, errors.Annotate(err, + fmt.Sprintf("file sourced with --source doesn't exist: line %s, file: %s", + q.location(), fileName)) + } + if !s.Mode().IsRegular() { + return testCnt, errors.Errorf("file sourced with --source isn't a regular file: line %s, file: %s", + q.location(), fileName) + } - return t.flushResult() + // Process the queries in the file + includedQueries, err := t.loadQueries(fileName) + if err != nil { + return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName)) + } + includeCnt, err := t.runQueries(includedQueries) + if err != nil { + return testCnt, err + } + testCnt += includeCnt + default: + log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented") + } + } + return testCnt, nil } func (t *tester) concurrentRun(concurrentQueue []query, concurrentSize int) error { @@ -606,8 +672,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure } } -func (t *tester) loadQueries() ([]query, error) { - data, err := os.ReadFile(t.testFileName()) +func (t *tester) loadQueries(fileName string) ([]query, error) { + data, err := os.ReadFile(fileName) if err != nil { return nil, err } @@ -623,7 +689,11 @@ func (t *tester) loadQueries() ([]query, error) { newStmt = true continue } else if strings.HasPrefix(s, "--") { - queries = append(queries, query{Query: s, Line: i + 1}) + queries = append(queries, query{ + Query: s, + Line: i + 1, + File: fileName, + }) newStmt = true continue } else if len(s) == 0 { @@ -631,10 +701,18 @@ func (t *tester) loadQueries() ([]query, error) { } if newStmt { - queries = append(queries, query{Query: s, Line: i + 1}) + queries = append(queries, query{ + Query: s, + Line: i + 1, + File: fileName, + }) } else { lastQuery := queries[len(queries)-1] - lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line} + lastQuery = query{ + Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), + Line: lastQuery.Line, + File: fileName, + } queries[len(queries)-1] = lastQuery } @@ -668,8 +746,8 @@ func (t *tester) checkExpectedError(q query, err error) error { } } if !checkErr { - log.Warnf("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", - t.name, q.Line, strings.Join(t.expectedErrs, ","), q.Query) + log.Warnf("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", + q.location(), strings.Join(t.expectedErrs, ","), q.Query) return nil } return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ",")) @@ -684,7 +762,7 @@ func (t *tester) checkExpectedError(q query, err error) error { errNo = int(innerErr.Number) } if errNo == 0 { - log.Warnf("%s:%d Could not parse mysql error: %s", t.name, q.Line, err.Error()) + log.Warnf("%s Could not parse mysql error: %s", q.location(), err.Error()) return err } for _, s := range t.expectedErrs { @@ -696,9 +774,9 @@ func (t *tester) checkExpectedError(q query, err error) error { checkErrNo = i } else { if len(t.expectedErrs) > 1 { - log.Warnf("%s:%d Unknown named error %s in --error %s", t.name, q.Line, s, strings.Join(t.expectedErrs, ",")) + log.Warnf("%s Unknown named error %s in --error %s", q.location(), s, strings.Join(t.expectedErrs, ",")) } else { - log.Warnf("%s:%d Unknown named --error %s", t.name, q.Line, s) + log.Warnf("%s Unknown named --error %s", q.location(), s) } continue } @@ -726,11 +804,11 @@ func (t *tester) checkExpectedError(q query, err error) error { } } if len(t.expectedErrs) > 1 { - log.Warnf("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", - t.name, q.Line, gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) + log.Warnf("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", + q.location(), gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) } else { - log.Warnf("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", - t.name, q.Line, gotErrCode, t.expectedErrs[0], err.Error(), q.Query) + log.Warnf("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", + q.location(), gotErrCode, t.expectedErrs[0], err.Error(), q.Query) } errStr := err.Error() for _, reg := range t.replaceRegex { diff --git a/src/query.go b/src/query.go index 6a128d8..8f4518f 100644 --- a/src/query.go +++ b/src/query.go @@ -124,6 +124,8 @@ const ( Q_COMMENT /* Comments, ignored. */ Q_COMMENT_WITH_COMMAND Q_EMPTY_LINE + Q_DISABLE_SOURCE + Q_ENABLE_SOURCE ) // ParseQueries parses an array of string into an array of query object. @@ -136,6 +138,7 @@ func ParseQueries(qs ...query) ([]query, error) { q := query{} q.tp = Q_UNKNOWN q.Line = rs.Line + q.File = rs.File // a valid query's length should be at least 3. if len(s) < 3 { continue diff --git a/src/type.go b/src/type.go index 50ea5a6..d4b4766 100644 --- a/src/type.go +++ b/src/type.go @@ -114,6 +114,8 @@ var commandMap = map[string]int{ "single_query": Q_SINGLE_QUERY, "begin_concurrent": Q_BEGIN_CONCURRENT, "end_concurrent": Q_END_CONCURRENT, + "disable_source": Q_DISABLE_SOURCE, + "enable_source": Q_ENABLE_SOURCE, } func findType(cmdName string) int { diff --git a/t/source.test b/t/source.test new file mode 100644 index 0000000..05140a3 --- /dev/null +++ b/t/source.test @@ -0,0 +1,11 @@ +--enable_source + +--echo before source +--source include/hello1.inc +--echo after source + +--disable_source +--source include/hello2.inc +--enable_source + +--source include/hello3.inc