Skip to content

Commit

Permalink
cli/sql: properly format different times and floats
Browse files Browse the repository at this point in the history
Prior to this patch, the SQL shell would improperly format
variants of timestamps (time, date) and mistakenly
use the same format for 32-bit and 64-bit float.
This patch corrects this.

Release note (cli change): The SQL shell (`cockroach demo`, `cockroach
sql`) now attempts to better format values that are akin to time/date,
as well as floating-point numbers.
  • Loading branch information
knz committed May 13, 2021
1 parent a5f8630 commit dbe3f3d
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 36 deletions.
24 changes: 20 additions & 4 deletions pkg/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,16 +619,32 @@ func Example_sql_format() {

c.RunWithArgs([]string{"sql", "-e", "create database t; create table t.times (bare timestamp, withtz timestamptz)"})
c.RunWithArgs([]string{"sql", "-e", "insert into t.times values ('2016-01-25 10:10:10', '2016-01-25 10:10:10-05:00')"})
c.RunWithArgs([]string{"sql", "-e", "select * from t.times"})
c.RunWithArgs([]string{"sql", "-e", "select bare from t.times; select withtz from t.times"})
c.RunWithArgs([]string{"sql", "-e", "select '2021-03-20'::date; select '01:01'::time; select '01:01'::timetz"})
c.RunWithArgs([]string{"sql", "-e", "select (1/3.0)::real; select (1/3.0)::double precision"})

// Output:
// sql -e create database t; create table t.times (bare timestamp, withtz timestamptz)
// CREATE TABLE
// sql -e insert into t.times values ('2016-01-25 10:10:10', '2016-01-25 10:10:10-05:00')
// INSERT 1
// sql -e select * from t.times
// bare withtz
// 2016-01-25 10:10:10+00:00:00 2016-01-25 15:10:10+00:00:00
// sql -e select bare from t.times; select withtz from t.times
// bare
// 2016-01-25 10:10:10
// withtz
// 2016-01-25 15:10:10+00:00:00
// sql -e select '2021-03-20'::date; select '01:01'::time; select '01:01'::timetz
// date
// 2021-03-20
// time
// 01:01:00
// timetz
// 01:01:00+00:00:00
// sql -e select (1/3.0)::real; select (1/3.0)::double precision
// float4
// 0.33333334
// float8
// 0.3333333333333333
}

func Example_sql_column_labels() {
Expand Down
6 changes: 4 additions & 2 deletions pkg/cli/format_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ func newRowSliceIter(allRows [][]string, align string) *rowSliceIter {

type rowIter struct {
rows *sqlRows
colTypes []string
showMoreChars bool
}

func (iter *rowIter) Next() (row []string, err error) {
nextRowString, err := getNextRowStrings(iter.rows, iter.showMoreChars)
nextRowString, err := getNextRowStrings(iter.rows, iter.colTypes, iter.showMoreChars)
if err != nil {
return nil, err
}
Expand All @@ -109,7 +110,7 @@ func (iter *rowIter) Next() (row []string, err error) {
}

func (iter *rowIter) ToSlice() ([][]string, error) {
return getAllRowStrings(iter.rows, iter.showMoreChars)
return getAllRowStrings(iter.rows, iter.colTypes, iter.showMoreChars)
}

func (iter *rowIter) Align() []int {
Expand Down Expand Up @@ -137,6 +138,7 @@ func (iter *rowIter) Align() []int {
func newRowIter(rows *sqlRows, showMoreChars bool) *rowIter {
return &rowIter{
rows: rows,
colTypes: rows.getColTypes(),
showMoreChars: showMoreChars,
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/cli/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,12 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum {
func (c *cliState) refreshTransactionStatus() {
c.lastKnownTxnStatus = unknownTxnStatus

dbVal, hasVal := c.conn.getServerValue("transaction status", `SHOW TRANSACTION STATUS`)
dbVal, dbColType, hasVal := c.conn.getServerValue("transaction status", `SHOW TRANSACTION STATUS`)
if !hasVal {
return
}

txnString := formatVal(dbVal,
txnString := formatVal(dbVal, dbColType,
false /* showPrintableUnicode */, false /* shownewLinesAndTabs */)

// Change the prompt based on the response from the server.
Expand Down Expand Up @@ -854,7 +854,7 @@ func (c *cliState) refreshDatabaseName() string {
return unknownDbName
}

dbVal, hasVal := c.conn.getServerValue("database name", `SHOW DATABASE`)
dbVal, dbColType, hasVal := c.conn.getServerValue("database name", `SHOW DATABASE`)
if !hasVal {
return unknownDbName
}
Expand All @@ -865,7 +865,7 @@ func (c *cliState) refreshDatabaseName() string {
" Use SET database = <dbname> to change, CREATE DATABASE to make a new database.")
}

dbName := formatVal(dbVal.(string),
dbName := formatVal(dbVal, dbColType,
false /* showPrintableUnicode */, false /* shownewLinesAndTabs */)

// Preserve the current database name in case of reconnects.
Expand Down
81 changes: 55 additions & 26 deletions pkg/cli/sql_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (c *sqlConn) getServerMetadata() (
defer func() { _ = rows.Close() }()

// Read the node_build_info table as an array of strings.
rowVals, err := getAllRowStrings(rows, true /* showMoreChars */)
rowVals, err := getAllRowStrings(rows, rows.getColTypes(), true /* showMoreChars */)
if err != nil || len(rowVals) == 0 || len(rowVals[0]) != 3 {
return 0, "", "", errors.New("incorrect data while retrieving the server version")
}
Expand Down Expand Up @@ -352,28 +352,29 @@ func (c *sqlConn) checkServerMetadata() error {
// getServerValue retrieves the first driverValue returned by the
// given sql query. If the query fails or does not return a single
// column, `false` is returned in the second result.
func (c *sqlConn) getServerValue(what, sql string) (driver.Value, bool) {
func (c *sqlConn) getServerValue(what, sql string) (driver.Value, string, bool) {
rows, err := c.Query(sql, nil)
if err != nil {
fmt.Fprintf(stderr, "warning: error retrieving the %s: %v\n", what, err)
return nil, false
return nil, "", false
}
defer func() { _ = rows.Close() }()

if len(rows.Columns()) == 0 {
fmt.Fprintf(stderr, "warning: cannot get the %s\n", what)
return nil, false
return nil, "", false
}

dbColType := rows.ColumnTypeDatabaseTypeName(0)
dbVals := make([]driver.Value, len(rows.Columns()))

err = rows.Next(dbVals[:])
if err != nil {
fmt.Fprintf(stderr, "warning: invalid %s: %v\n", what, err)
return nil, false
return nil, "", false
}

return dbVals[0], true
return dbVals[0], dbColType, true
}

// parseLastQueryStatistics runs the "SHOW LAST QUERY STATISTICS" statements,
Expand Down Expand Up @@ -419,10 +420,10 @@ func (c *sqlConn) getLastQueryStatistics() (
return 0, 0, 0, 0, err
}

parseLatencyRaw = formatVal(row[0], false, false)
planLatencyRaw = formatVal(row[1], false, false)
execLatencyRaw = formatVal(row[2], false, false)
serviceLatencyRaw = formatVal(row[3], false, false)
parseLatencyRaw = formatVal(row[0], iter.colTypes[0], false, false)
planLatencyRaw = formatVal(row[1], iter.colTypes[1], false, false)
execLatencyRaw = formatVal(row[2], iter.colTypes[2], false, false)
serviceLatencyRaw = formatVal(row[3], iter.colTypes[3], false, false)

nRows++
}
Expand Down Expand Up @@ -630,6 +631,14 @@ func (r *sqlRows) ColumnTypeDatabaseTypeName(index int) string {
return r.rows.ColumnTypeDatabaseTypeName(index)
}

func (r *sqlRows) getColTypes() []string {
colTypes := make([]string, len(r.Columns()))
for i := range colTypes {
colTypes[i] = r.ColumnTypeDatabaseTypeName(i)
}
return colTypes
}

func makeSQLConn(url string) *sqlConn {
return &sqlConn{
url: url,
Expand Down Expand Up @@ -1047,7 +1056,7 @@ func maybeShowTimes(
// If showMoreChars is true, then more characters are not escaped.
func sqlRowsToStrings(rows *sqlRows, showMoreChars bool) ([]string, [][]string, error) {
cols := getColumnStrings(rows, showMoreChars)
allRows, err := getAllRowStrings(rows, showMoreChars)
allRows, err := getAllRowStrings(rows, rows.getColTypes(), showMoreChars)
if err != nil {
return nil, nil, err
}
Expand All @@ -1058,16 +1067,16 @@ func getColumnStrings(rows *sqlRows, showMoreChars bool) []string {
srcCols := rows.Columns()
cols := make([]string, len(srcCols))
for i, c := range srcCols {
cols[i] = formatVal(c, showMoreChars, showMoreChars)
cols[i] = formatVal(c, "NAME", showMoreChars, showMoreChars)
}
return cols
}

func getAllRowStrings(rows *sqlRows, showMoreChars bool) ([][]string, error) {
func getAllRowStrings(rows *sqlRows, colTypes []string, showMoreChars bool) ([][]string, error) {
var allRows [][]string

for {
rowStrings, err := getNextRowStrings(rows, showMoreChars)
rowStrings, err := getNextRowStrings(rows, colTypes, showMoreChars)
if err != nil {
return nil, err
}
Expand All @@ -1080,7 +1089,7 @@ func getAllRowStrings(rows *sqlRows, showMoreChars bool) ([][]string, error) {
return allRows, nil
}

func getNextRowStrings(rows *sqlRows, showMoreChars bool) ([]string, error) {
func getNextRowStrings(rows *sqlRows, colTypes []string, showMoreChars bool) ([]string, error) {
cols := rows.Columns()
var vals []driver.Value
if len(cols) > 0 {
Expand All @@ -1097,13 +1106,7 @@ func getNextRowStrings(rows *sqlRows, showMoreChars bool) ([]string, error) {

rowStrings := make([]string, len(cols))
for i, v := range vals {
databaseType := rows.ColumnTypeDatabaseTypeName(i)
if databaseType == "NAME" {
if bytes, ok := v.([]byte); ok {
v = string(bytes)
}
}
rowStrings[i] = formatVal(v, showMoreChars, showMoreChars)
rowStrings[i] = formatVal(v, colTypes[i], showMoreChars, showMoreChars)
}
return rowStrings, nil
}
Expand All @@ -1114,10 +1117,25 @@ func isNotGraphicUnicodeOrTabOrNewline(r rune) bool {
return r != '\t' && r != '\n' && !unicode.IsGraphic(r)
}

func formatVal(val driver.Value, showPrintableUnicode bool, showNewLinesAndTabs bool) string {
func formatVal(
val driver.Value, colType string, showPrintableUnicode bool, showNewLinesAndTabs bool,
) string {
if b, ok := val.([]byte); ok && colType == "NAME" {
val = string(b)
colType = "VARCHAR"
}

switch t := val.(type) {
case nil:
return "NULL"

case float64:
width := 64
if colType == "FLOAT4" {
width = 32
}
return strconv.FormatFloat(t, 'g', -1, width)

case string:
if showPrintableUnicode {
pred := isNotGraphicUnicode
Expand Down Expand Up @@ -1155,14 +1173,25 @@ func formatVal(val driver.Value, showPrintableUnicode bool, showNewLinesAndTabs
sessiondatapb.BytesEncodeEscape, false /* skipHexPrefix */)

case time.Time:
// Since we do not know whether the datum is Timestamp or TimestampTZ,
// output the full format.
return t.Format(timeutil.FullTimeFormat)
tfmt, ok := timeOutputFormats[colType]
if !ok {
// Some unknown/new time-like format.
tfmt = timeutil.FullTimeFormat
}
return t.Format(tfmt)
}

return fmt.Sprint(val)
}

var timeOutputFormats = map[string]string{
"TIMESTAMP": timeutil.TimestampWithoutTZFormat,
"TIMESTAMPTZ": timeutil.FullTimeFormat,
"TIME": timeutil.TimeWithoutTZFormat,
"TIMETZ": timeutil.TimeWithTZFormat,
"DATE": timeutil.DateFormat,
}

// parseBool parses a boolean string for use in slash commands.
func parseBool(s string) (bool, error) {
switch strings.TrimSpace(strings.ToLower(s)) {
Expand Down
15 changes: 15 additions & 0 deletions pkg/util/timeutil/timeutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,18 @@ package timeutil
// FullTimeFormat is the time format used to display any timestamp
// with date, time and time zone data.
const FullTimeFormat = "2006-01-02 15:04:05.999999-07:00:00"

// TimestampWithoutTZFormat is the time format used to display
// timestamps without a time zone offset.
const TimestampWithoutTZFormat = "2006-01-02 15:04:05.999999"

// TimeWithTZFormat is the time format used to display a time
// with a time zone offset.
const TimeWithTZFormat = "15:04:05.999999-07:00:00"

// TimeWithoutTZFormat is the time format used to display a time
// without a time zone offset.
const TimeWithoutTZFormat = "15:04:05.999999"

// DateFormat is the time format used to display a date.
const DateFormat = "2006-01-02"

0 comments on commit dbe3f3d

Please sign in to comment.