-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for the source command #128
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
before source | ||
Hello from the included file | ||
SELECT 1 | ||
1 | ||
1 | ||
after source | ||
Goodbye from the included file | ||
SELECT 3 | ||
3 | ||
3 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,9 +47,13 @@ var ( | |
retryConnCount int | ||
collationDisable bool | ||
checkErr bool | ||
disableSource bool | ||
) | ||
|
||
func init() { | ||
// Disable the `--source` command by default to avoid breaking existing tests | ||
disableSource = true | ||
|
||
flag.StringVar(&host, "host", "127.0.0.1", "The host of the TiDB/MySQL server.") | ||
flag.StringVar(&port, "port", "4000", "The listen port of TiDB/MySQL server.") | ||
flag.StringVar(&user, "user", "root", "The user for connecting to the database.") | ||
|
@@ -72,10 +76,15 @@ const ( | |
type query struct { | ||
firstWord string | ||
Query string | ||
File string | ||
Line int | ||
tp int | ||
} | ||
|
||
func (q *query) location() string { | ||
return fmt.Sprintf("%s:%d", q.File, q.Line) | ||
} | ||
|
||
type Conn struct { | ||
// DB might be a shared one by multiple Conn, if the connection information are the same. | ||
mdb *sql.DB | ||
|
@@ -325,7 +334,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt | |
func (t *tester) Run() error { | ||
t.preProcess() | ||
defer t.postProcess() | ||
queries, err := t.loadQueries() | ||
queries, err := t.loadQueries(t.testFileName()) | ||
if err != nil { | ||
err = errors.Trace(err) | ||
t.addFailure(&testSuite, &err, 0) | ||
|
@@ -338,17 +347,33 @@ func (t *tester) Run() error { | |
return err | ||
} | ||
|
||
var s string | ||
defer func() { | ||
if t.resultFD != nil { | ||
t.resultFD.Close() | ||
} | ||
}() | ||
|
||
testCnt := 0 | ||
startTime := time.Now() | ||
testCnt, err := t.runQueries(queries) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) | ||
|
||
if xmlPath != "" { | ||
t.addSuccess(&testSuite, &startTime, testCnt) | ||
} | ||
|
||
return t.flushResult() | ||
} | ||
|
||
func (t *tester) runQueries(queries []query) (int, error) { | ||
testCnt := 0 | ||
var concurrentQueue []query | ||
var concurrentSize int | ||
var s string | ||
var err error | ||
for _, q := range queries { | ||
s = q.Query | ||
switch q.tp { | ||
|
@@ -379,15 +404,15 @@ func (t *tester) Run() error { | |
if err != nil { | ||
err = errors.Annotate(err, "Atoi failed") | ||
t.addFailure(&testSuite, &err, testCnt) | ||
return err | ||
return testCnt, err | ||
} | ||
} | ||
case Q_END_CONCURRENT: | ||
t.enableConcurrent = false | ||
if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil { | ||
err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name)) | ||
t.addFailure(&testSuite, &err, testCnt) | ||
return err | ||
return testCnt, err | ||
} | ||
t.expectedErrs = nil | ||
case Q_ERROR: | ||
|
@@ -404,9 +429,9 @@ func (t *tester) Run() error { | |
if t.enableConcurrent { | ||
concurrentQueue = append(concurrentQueue, q) | ||
} else if err = t.execute(q); err != nil { | ||
err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query)) | ||
err = errors.Annotate(err, fmt.Sprintf("sql:%v line:%s", q.Query, q.location())) | ||
t.addFailure(&testSuite, &err, testCnt) | ||
return err | ||
return testCnt, err | ||
} | ||
|
||
testCnt++ | ||
|
@@ -426,7 +451,7 @@ func (t *tester) Run() error { | |
if err != nil { | ||
err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query)) | ||
t.addFailure(&testSuite, &err, testCnt) | ||
return err | ||
return testCnt, err | ||
} | ||
|
||
t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])}) | ||
|
@@ -473,7 +498,7 @@ func (t *tester) Run() error { | |
r, err := t.executeStmtString(s) | ||
if err != nil { | ||
log.WithFields(log.Fields{ | ||
"query": s, "line": q.Line}, | ||
"query": s, "line": q.location()}, | ||
).Error("failed to perform let query") | ||
return "" | ||
} | ||
|
@@ -484,27 +509,68 @@ func (t *tester) Run() error { | |
case Q_REMOVE_FILE: | ||
err = os.Remove(strings.TrimSpace(q.Query)) | ||
if err != nil { | ||
return errors.Annotate(err, "failed to remove file") | ||
return testCnt, errors.Annotate(err, "failed to remove file") | ||
} | ||
case Q_REPLACE_REGEX: | ||
t.replaceRegex = nil | ||
regex, err := ParseReplaceRegex(q.Query) | ||
if err != nil { | ||
return errors.Annotate(err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %d sql:%v", q.Line, q.Query)) | ||
return testCnt, errors.Annotate( | ||
err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %s sql:%v", | ||
q.location(), q.Query)) | ||
} | ||
t.replaceRegex = regex | ||
default: | ||
log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.Line}).Warn("command not implemented") | ||
} | ||
} | ||
case Q_ENABLE_SOURCE: | ||
disableSource = false | ||
case Q_DISABLE_SOURCE: | ||
disableSource = true | ||
case Q_SOURCE: | ||
if disableSource { | ||
log.WithFields(log.Fields{"line": q.location()}).Warn("source command disabled, add '--enable_source' to your file to enable") | ||
break | ||
} | ||
fileName := strings.TrimSpace(q.Query) | ||
cwd, err := os.Getwd() | ||
if err != nil { | ||
return testCnt, err | ||
} | ||
|
||
fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) | ||
// For security, don't allow to include files from other locations | ||
fullpath, err := filepath.Abs(fileName) | ||
if err != nil { | ||
return testCnt, err | ||
} | ||
if !strings.HasPrefix(fullpath, cwd) { | ||
return testCnt, errors.Errorf("included file %s is not prefixed with %s", fullpath, cwd) | ||
} | ||
|
||
if xmlPath != "" { | ||
t.addSuccess(&testSuite, &startTime, testCnt) | ||
} | ||
// Make sure we have a useful error message if the file can't be found or isn't a regular file | ||
s, err := os.Stat(fileName) | ||
if err != nil { | ||
return testCnt, errors.Annotate(err, | ||
fmt.Sprintf("file sourced with --source doesn't exist: line %s, file: %s", | ||
q.location(), fileName)) | ||
} | ||
if !s.Mode().IsRegular() { | ||
return testCnt, errors.Errorf("file sourced with --source isn't a regular file: line %s, file: %s", | ||
q.location(), fileName) | ||
} | ||
|
||
return t.flushResult() | ||
// Process the queries in the file | ||
includedQueries, err := t.loadQueries(fileName) | ||
if err != nil { | ||
return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this lead to a full 'stack trace'? I.e. so you can follow from the originating file's line, to the current There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could add the location:
|
||
} | ||
includeCnt, err := t.runQueries(includedQueries) | ||
if err != nil { | ||
return testCnt, err | ||
} | ||
testCnt += includeCnt | ||
default: | ||
log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented") | ||
} | ||
} | ||
return testCnt, nil | ||
} | ||
|
||
func (t *tester) concurrentRun(concurrentQueue []query, concurrentSize int) error { | ||
|
@@ -606,8 +672,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure | |
} | ||
} | ||
|
||
func (t *tester) loadQueries() ([]query, error) { | ||
data, err := os.ReadFile(t.testFileName()) | ||
func (t *tester) loadQueries(fileName string) ([]query, error) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we have a new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so, the included files are still part of the same test. |
||
data, err := os.ReadFile(fileName) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
@@ -623,18 +689,30 @@ func (t *tester) loadQueries() ([]query, error) { | |
newStmt = true | ||
continue | ||
} else if strings.HasPrefix(s, "--") { | ||
queries = append(queries, query{Query: s, Line: i + 1}) | ||
queries = append(queries, query{ | ||
Query: s, | ||
Line: i + 1, | ||
File: fileName, | ||
}) | ||
newStmt = true | ||
continue | ||
} else if len(s) == 0 { | ||
continue | ||
} | ||
|
||
if newStmt { | ||
queries = append(queries, query{Query: s, Line: i + 1}) | ||
queries = append(queries, query{ | ||
Query: s, | ||
Line: i + 1, | ||
File: fileName, | ||
}) | ||
} else { | ||
lastQuery := queries[len(queries)-1] | ||
lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line} | ||
lastQuery = query{ | ||
Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), | ||
Line: lastQuery.Line, | ||
File: fileName, | ||
} | ||
queries[len(queries)-1] = lastQuery | ||
} | ||
|
||
|
@@ -668,8 +746,8 @@ func (t *tester) checkExpectedError(q query, err error) error { | |
} | ||
} | ||
if !checkErr { | ||
log.Warnf("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", | ||
t.name, q.Line, strings.Join(t.expectedErrs, ","), q.Query) | ||
log.Warnf("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", | ||
q.location(), strings.Join(t.expectedErrs, ","), q.Query) | ||
return nil | ||
} | ||
return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ",")) | ||
|
@@ -684,7 +762,7 @@ func (t *tester) checkExpectedError(q query, err error) error { | |
errNo = int(innerErr.Number) | ||
} | ||
if errNo == 0 { | ||
log.Warnf("%s:%d Could not parse mysql error: %s", t.name, q.Line, err.Error()) | ||
log.Warnf("%s Could not parse mysql error: %s", q.location(), err.Error()) | ||
return err | ||
} | ||
for _, s := range t.expectedErrs { | ||
|
@@ -696,9 +774,9 @@ func (t *tester) checkExpectedError(q query, err error) error { | |
checkErrNo = i | ||
} else { | ||
if len(t.expectedErrs) > 1 { | ||
log.Warnf("%s:%d Unknown named error %s in --error %s", t.name, q.Line, s, strings.Join(t.expectedErrs, ",")) | ||
log.Warnf("%s Unknown named error %s in --error %s", q.location(), s, strings.Join(t.expectedErrs, ",")) | ||
} else { | ||
log.Warnf("%s:%d Unknown named --error %s", t.name, q.Line, s) | ||
log.Warnf("%s Unknown named --error %s", q.location(), s) | ||
} | ||
continue | ||
} | ||
|
@@ -726,11 +804,11 @@ func (t *tester) checkExpectedError(q query, err error) error { | |
} | ||
} | ||
if len(t.expectedErrs) > 1 { | ||
log.Warnf("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", | ||
t.name, q.Line, gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) | ||
log.Warnf("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", | ||
q.location(), gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) | ||
} else { | ||
log.Warnf("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", | ||
t.name, q.Line, gotErrCode, t.expectedErrs[0], err.Error(), q.Query) | ||
log.Warnf("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", | ||
q.location(), gotErrCode, t.expectedErrs[0], err.Error(), q.Query) | ||
} | ||
errStr := err.Error() | ||
for _, reg := range t.replaceRegex { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
--enable_source | ||
|
||
--echo before source | ||
--source include/hello1.inc | ||
--echo after source | ||
|
||
--disable_source | ||
--source include/hello2.inc | ||
--enable_source | ||
|
||
--source include/hello3.inc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there be a program argument to set it enabled by default, similar to
--check-error
argument, so it would work with test files that contain--source
'd files that does not exists or have not yet been tested successfully?