diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index e5ea810320fe..b186bddf72b8 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -623,6 +623,16 @@ func Example_sql_format() { c.RunWithArgs([]string{"sql", "-e", "select '2021-03-20'::date; select '01:01'::time; select '01:01'::timetz; select '01:01+02:02'::timetz"}) c.RunWithArgs([]string{"sql", "-e", "select (1/3.0)::real; select (1/3.0)::double precision; select '-inf'::float8"}) + // Special characters inside arrays used to be represented as escaped bytes. + c.RunWithArgs([]string{"sql", "-e", "select array['哈哈'::TEXT], array['哈哈'::NAME], array['哈哈'::VARCHAR]"}) + c.RunWithArgs([]string{"sql", "-e", "select array['哈哈'::CHAR(2)], array['哈'::\"char\"]"}) + // Preserve quoting of arrays containing commas or double quotes. + c.RunWithArgs([]string{"sql", "-e", `select array['a,b', 'a"b', 'a\b']`, "--format=table"}) + // Infinities inside float arrays used to be represented differently from infinities as simpler scalar. + c.RunWithArgs([]string{"sql", "-e", "select array['Inf'::FLOAT4, '-Inf'::FLOAT4], array['Inf'::FLOAT8]"}) + // Sanity check for other array types. + c.RunWithArgs([]string{"sql", "-e", "select array[true, false], array['01:01'::time], array['2021-03-20'::date]"}) + c.RunWithArgs([]string{"sql", "-e", "select array[123::int2], array[123::int4], array[123::int8]"}) // Output: // sql -e create database t; create table t.times (bare timestamp, withtz timestamptz) @@ -650,6 +660,26 @@ func Example_sql_format() { // 0.3333333333333333 // float8 // -Infinity + // sql -e select array['哈哈'::TEXT], array['哈哈'::NAME], array['哈哈'::VARCHAR] + // array array array + // {哈哈} {哈哈} {哈哈} + // sql -e select array['哈哈'::CHAR(2)], array['哈'::"char"] + // array array + // {哈哈} {哈} + // sql -e select array['a,b', 'a"b', 'a\b'] --format=table + // array + // ------------------------- + // {"a,b","a\"b","a\\b"} + // (1 row) + // sql -e select array['Inf'::FLOAT4, '-Inf'::FLOAT4], array['Inf'::FLOAT8] + // array array + // {Infinity,-Infinity} {Infinity} + // sql -e select array[true, false], array['01:01'::time], array['2021-03-20'::date] + // array array array + // {true,false} {01:01:00} {2021-03-20} + // sql -e select array[123::int2], array[123::int4], array[123::int8] + // array array array + // {123} {123} {123} } func Example_sql_column_labels() { diff --git a/pkg/cli/sql_util.go b/pkg/cli/sql_util.go index 4e5522e67d0d..96cf95bdbd75 100644 --- a/pkg/cli/sql_util.go +++ b/pkg/cli/sql_util.go @@ -12,12 +12,14 @@ package cli import ( "context" + gosql "database/sql" "database/sql/driver" "fmt" "io" "math" "net/url" "reflect" + "regexp" "strconv" "strings" "time" @@ -1121,9 +1123,17 @@ func isNotGraphicUnicodeOrTabOrNewline(r rune) bool { func formatVal( val driver.Value, colType string, showPrintableUnicode bool, showNewLinesAndTabs bool, ) string { - if b, ok := val.([]byte); ok && colType == "NAME" { - val = string(b) - colType = "VARCHAR" + log.VInfof(context.Background(), 2, "value: go %T, sql %q", val, colType) + + if b, ok := val.([]byte); ok { + if strings.HasPrefix(colType, "_") && len(b) > 0 && b[0] == '{' { + return formatArray(b, colType[1:], showPrintableUnicode, showNewLinesAndTabs) + } + + if colType == "NAME" { + val = string(b) + colType = "VARCHAR" + } } switch t := val.(type) { @@ -1197,6 +1207,82 @@ func formatVal( return fmt.Sprint(val) } +func formatArray( + b []byte, colType string, showPrintableUnicode bool, showNewLinesAndTabs bool, +) string { + // backingArray is the array we're going to parse the server data + // into. + var backingArray interface{} + // parsingArray is a helper structure provided by lib/pq to parse + // arrays. + var parsingArray gosql.Scanner + + // lib.pq has different array parsers for special value types. + // + // TODO(knz): This would better use a general-purpose parser + // using the OID to look up an array parser in crdb's sql package. + // However, unfortunately the OID is hidden from us. + switch colType { + case "BOOL": + boolArray := []bool{} + backingArray = &boolArray + parsingArray = (*pq.BoolArray)(&boolArray) + case "FLOAT4", "FLOAT8": + floatArray := []float64{} + backingArray = &floatArray + parsingArray = (*pq.Float64Array)(&floatArray) + case "INT2", "INT4", "INT8", "OID": + intArray := []int64{} + backingArray = &intArray + parsingArray = (*pq.Int64Array)(&intArray) + case "TEXT", "VARCHAR", "NAME", "CHAR", "BPCHAR": + stringArray := []string{} + backingArray = &stringArray + parsingArray = (*pq.StringArray)(&stringArray) + default: + genArray := [][]byte{} + backingArray = &genArray + parsingArray = &pq.GenericArray{A: &genArray} + } + + // Now ask the pq array parser to convert the byte slice + // from the server into a Go array. + if err := parsingArray.Scan(b); err != nil { + // A parsing failure is not a catastrophe; we can still print out + // the array as a byte slice. This will do in many cases. + log.VInfof(context.Background(), 1, "unable to parse %q (sql %q) as array: %v", b, colType, err) + return formatVal(b, "BYTEA", showPrintableUnicode, showNewLinesAndTabs) + } + + // We have a go array in "backingArray". Now print it out. + var buf strings.Builder + buf.WriteByte('{') + comma := "" // delimiter + v := reflect.ValueOf(backingArray).Elem() + for i := 0; i < v.Len(); i++ { + buf.WriteString(comma) + + // Access the i-th element in the backingArray. + arrayVal := driver.Value(v.Index(i).Interface()) + // Format the value recursively into a string. + vs := formatVal(arrayVal, colType, showPrintableUnicode, showNewLinesAndTabs) + + // If the value contains special characters or a comma, enclose in double quotes. + // Also escape the special characters. + if strings.IndexByte(vs, ',') >= 0 || reArrayStringEscape.MatchString(vs) { + vs = "\"" + reArrayStringEscape.ReplaceAllString(vs, "\\$1") + "\"" + } + + // Add the string for that one value to the output array representation. + buf.WriteString(vs) + comma = "," + } + buf.WriteByte('}') + return buf.String() +} + +var reArrayStringEscape = regexp.MustCompile(`(["\\])`) + var timeOutputFormats = map[string]string{ "TIMESTAMP": timeutil.TimestampWithoutTZFormat, "TIMESTAMPTZ": timeutil.TimestampWithTZFormat,