diff --git a/executor/write.go b/executor/write.go index 999daa6424d4b..93047aa5250aa 100644 --- a/executor/write.go +++ b/executor/write.go @@ -406,7 +406,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error var line []byte var isEOF, hasStarting, reachLimit bool - cols := make([]string, 0, len(e.row)) if len(prevData) > 0 && len(curData) == 0 { isEOF = true prevData, curData = curData, prevData @@ -432,8 +431,10 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error curData = nil } - rawCols := bytes.Split(line, []byte(e.FieldsInfo.Terminated)) - cols = escapeCols(rawCols) + cols, err := GetFieldsFromLine(line, e.FieldsInfo) + if err != nil { + return nil, false, errors.Trace(err) + } e.insertData(cols) e.insertVal.currRow++ if e.insertVal.batchRows != 0 && e.insertVal.currRow%e.insertVal.batchRows == 0 { @@ -450,10 +451,31 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error return curData, reachLimit, nil } +// GetFieldsFromLine splits line according to fieldsInfo, this function is exported for testing. +func GetFieldsFromLine(line []byte, fieldsInfo *ast.FieldsClause) ([]string, error) { + var sep []byte + if fieldsInfo.Enclosed != 0 { + if line[0] != fieldsInfo.Enclosed || line[len(line)-1] != fieldsInfo.Enclosed { + return nil, errors.Errorf("line %s should begin and end with %c", string(line), fieldsInfo.Enclosed) + } + line = line[1 : len(line)-1] + sep = make([]byte, 0, len(fieldsInfo.Terminated)+2) + sep = append(sep, fieldsInfo.Enclosed) + sep = append(sep, fieldsInfo.Terminated...) + sep = append(sep, fieldsInfo.Enclosed) + } else { + sep = []byte(fieldsInfo.Terminated) + } + rawCols := bytes.Split(line, sep) + cols := escapeCols(rawCols) + return cols, nil +} + func escapeCols(strs [][]byte) []string { ret := make([]string, len(strs)) for i, v := range strs { - ret[i] = string(escape(v)) + output := escape(v) + ret[i] = string(output) } return ret } @@ -467,10 +489,8 @@ func escape(str []byte) []byte { for i := 0; i < len(str); i++ { c := str[i] if c == '\\' && i+1 < len(str) { - var ok bool - if c, ok = escapeChar(str[i+1]); ok { - i++ - } + c = escapeChar(str[i+1]) + i++ } str[pos] = c @@ -479,24 +499,24 @@ func escape(str []byte) []byte { return str[:pos] } -func escapeChar(c byte) (byte, bool) { +func escapeChar(c byte) byte { switch c { case '0': - return 0, true + return 0 case 'b': - return '\b', true + return '\b' case 'n': - return '\n', true + return '\n' case 'r': - return '\r', true + return '\r' case 't': - return '\t', true + return '\t' case 'Z': - return 26, true + return 26 case '\\': - return '\\', true + return '\\' } - return c, false + return c } func (e *LoadDataInfo) insertData(cols []string) { diff --git a/executor/write_test.go b/executor/write_test.go index 88ac13a05d159..5ea71a0e3d236 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -1022,3 +1022,52 @@ func (s *testSuite) TestNullDefault(c *C) { tk.MustExec("insert into test_null_default values ()") tk.MustQuery("select * from test_null_default").Check(testkit.Rows("", "1970-01-01 08:20:34")) } + +func (s *testSuite) TestGetFieldsFromLine(c *C) { + tests := []struct { + input string + expected []string + }{ + { + `"1","a string","100.20"`, + []string{"1", "a string", "100.20"}, + }, + { + `"2","a string containing a , comma","102.20"`, + []string{"2", "a string containing a , comma", "102.20"}, + }, + { + `"3","a string containing a \" quote","102.20"`, + []string{"3", "a string containing a \" quote", "102.20"}, + }, + { + `"4","a string containing a \", quote and comma","102.20"`, + []string{"4", "a string containing a \", quote and comma", "102.20"}, + }, + // Test some escape char. + { + `"\0\b\n\r\t\Z\\\ \c\'\""`, + []string{string([]byte{0, '\b', '\n', '\r', '\t', 26, '\\', ' ', ' ', 'c', '\'', '"'})}, + }, + } + fieldsInfo := &ast.FieldsClause{ + Enclosed: '"', + Terminated: ",", + } + + for _, test := range tests { + got, err := executor.GetFieldsFromLine([]byte(test.input), fieldsInfo) + c.Assert(err, IsNil, Commentf("failed: %s", test.input)) + assertEqualStrings(c, got, test.expected) + } + + _, err := executor.GetFieldsFromLine([]byte(`1,a string,100.20`), fieldsInfo) + c.Assert(err, NotNil) +} + +func assertEqualStrings(c *C, got []string, expect []string) { + c.Assert(len(got), Equals, len(expect)) + for i := 0; i < len(got); i++ { + c.Assert(got[i], Equals, expect[i]) + } +}