Skip to content

Commit

Permalink
Merge pull request #63541 from knz/backport21.1-62310
Browse files Browse the repository at this point in the history
release-21.1: cli/sql: properly format different times and floats
  • Loading branch information
knz authored May 13, 2021
2 parents 3884c78 + dbe3f3d commit bbee893
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 36 deletions.
24 changes: 20 additions & 4 deletions pkg/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,16 +619,32 @@ func Example_sql_format() {

c.RunWithArgs([]string{"sql", "-e", "create database t; create table t.times (bare timestamp, withtz timestamptz)"})
c.RunWithArgs([]string{"sql", "-e", "insert into t.times values ('2016-01-25 10:10:10', '2016-01-25 10:10:10-05:00')"})
c.RunWithArgs([]string{"sql", "-e", "select * from t.times"})
c.RunWithArgs([]string{"sql", "-e", "select bare from t.times; select withtz from t.times"})
c.RunWithArgs([]string{"sql", "-e", "select '2021-03-20'::date; select '01:01'::time; select '01:01'::timetz"})
c.RunWithArgs([]string{"sql", "-e", "select (1/3.0)::real; select (1/3.0)::double precision"})

// Output:
// sql -e create database t; create table t.times (bare timestamp, withtz timestamptz)
// CREATE TABLE
// sql -e insert into t.times values ('2016-01-25 10:10:10', '2016-01-25 10:10:10-05:00')
// INSERT 1
// sql -e select * from t.times
// bare withtz
// 2016-01-25 10:10:10+00:00:00 2016-01-25 15:10:10+00:00:00
// sql -e select bare from t.times; select withtz from t.times
// bare
// 2016-01-25 10:10:10
// withtz
// 2016-01-25 15:10:10+00:00:00
// sql -e select '2021-03-20'::date; select '01:01'::time; select '01:01'::timetz
// date
// 2021-03-20
// time
// 01:01:00
// timetz
// 01:01:00+00:00:00
// sql -e select (1/3.0)::real; select (1/3.0)::double precision
// float4
// 0.33333334
// float8
// 0.3333333333333333
}

func Example_sql_column_labels() {
Expand Down
6 changes: 4 additions & 2 deletions pkg/cli/format_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ func newRowSliceIter(allRows [][]string, align string) *rowSliceIter {

type rowIter struct {
rows *sqlRows
colTypes []string
showMoreChars bool
}

func (iter *rowIter) Next() (row []string, err error) {
nextRowString, err := getNextRowStrings(iter.rows, iter.showMoreChars)
nextRowString, err := getNextRowStrings(iter.rows, iter.colTypes, iter.showMoreChars)
if err != nil {
return nil, err
}
Expand All @@ -109,7 +110,7 @@ func (iter *rowIter) Next() (row []string, err error) {
}

func (iter *rowIter) ToSlice() ([][]string, error) {
return getAllRowStrings(iter.rows, iter.showMoreChars)
return getAllRowStrings(iter.rows, iter.colTypes, iter.showMoreChars)
}

func (iter *rowIter) Align() []int {
Expand Down Expand Up @@ -137,6 +138,7 @@ func (iter *rowIter) Align() []int {
func newRowIter(rows *sqlRows, showMoreChars bool) *rowIter {
return &rowIter{
rows: rows,
colTypes: rows.getColTypes(),
showMoreChars: showMoreChars,
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/cli/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,12 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum {
func (c *cliState) refreshTransactionStatus() {
c.lastKnownTxnStatus = unknownTxnStatus

dbVal, hasVal := c.conn.getServerValue("transaction status", `SHOW TRANSACTION STATUS`)
dbVal, dbColType, hasVal := c.conn.getServerValue("transaction status", `SHOW TRANSACTION STATUS`)
if !hasVal {
return
}

txnString := formatVal(dbVal,
txnString := formatVal(dbVal, dbColType,
false /* showPrintableUnicode */, false /* shownewLinesAndTabs */)

// Change the prompt based on the response from the server.
Expand Down Expand Up @@ -854,7 +854,7 @@ func (c *cliState) refreshDatabaseName() string {
return unknownDbName
}

dbVal, hasVal := c.conn.getServerValue("database name", `SHOW DATABASE`)
dbVal, dbColType, hasVal := c.conn.getServerValue("database name", `SHOW DATABASE`)
if !hasVal {
return unknownDbName
}
Expand All @@ -865,7 +865,7 @@ func (c *cliState) refreshDatabaseName() string {
" Use SET database = <dbname> to change, CREATE DATABASE to make a new database.")
}

dbName := formatVal(dbVal.(string),
dbName := formatVal(dbVal, dbColType,
false /* showPrintableUnicode */, false /* shownewLinesAndTabs */)

// Preserve the current database name in case of reconnects.
Expand Down
81 changes: 55 additions & 26 deletions pkg/cli/sql_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (c *sqlConn) getServerMetadata() (
defer func() { _ = rows.Close() }()

// Read the node_build_info table as an array of strings.
rowVals, err := getAllRowStrings(rows, true /* showMoreChars */)
rowVals, err := getAllRowStrings(rows, rows.getColTypes(), true /* showMoreChars */)
if err != nil || len(rowVals) == 0 || len(rowVals[0]) != 3 {
return 0, "", "", errors.New("incorrect data while retrieving the server version")
}
Expand Down Expand Up @@ -352,28 +352,29 @@ func (c *sqlConn) checkServerMetadata() error {
// getServerValue retrieves the first driverValue returned by the
// given sql query. If the query fails or does not return a single
// column, `false` is returned in the second result.
func (c *sqlConn) getServerValue(what, sql string) (driver.Value, bool) {
func (c *sqlConn) getServerValue(what, sql string) (driver.Value, string, bool) {
rows, err := c.Query(sql, nil)
if err != nil {
fmt.Fprintf(stderr, "warning: error retrieving the %s: %v\n", what, err)
return nil, false
return nil, "", false
}
defer func() { _ = rows.Close() }()

if len(rows.Columns()) == 0 {
fmt.Fprintf(stderr, "warning: cannot get the %s\n", what)
return nil, false
return nil, "", false
}

dbColType := rows.ColumnTypeDatabaseTypeName(0)
dbVals := make([]driver.Value, len(rows.Columns()))

err = rows.Next(dbVals[:])
if err != nil {
fmt.Fprintf(stderr, "warning: invalid %s: %v\n", what, err)
return nil, false
return nil, "", false
}

return dbVals[0], true
return dbVals[0], dbColType, true
}

// parseLastQueryStatistics runs the "SHOW LAST QUERY STATISTICS" statements,
Expand Down Expand Up @@ -419,10 +420,10 @@ func (c *sqlConn) getLastQueryStatistics() (
return 0, 0, 0, 0, err
}

parseLatencyRaw = formatVal(row[0], false, false)
planLatencyRaw = formatVal(row[1], false, false)
execLatencyRaw = formatVal(row[2], false, false)
serviceLatencyRaw = formatVal(row[3], false, false)
parseLatencyRaw = formatVal(row[0], iter.colTypes[0], false, false)
planLatencyRaw = formatVal(row[1], iter.colTypes[1], false, false)
execLatencyRaw = formatVal(row[2], iter.colTypes[2], false, false)
serviceLatencyRaw = formatVal(row[3], iter.colTypes[3], false, false)

nRows++
}
Expand Down Expand Up @@ -630,6 +631,14 @@ func (r *sqlRows) ColumnTypeDatabaseTypeName(index int) string {
return r.rows.ColumnTypeDatabaseTypeName(index)
}

func (r *sqlRows) getColTypes() []string {
colTypes := make([]string, len(r.Columns()))
for i := range colTypes {
colTypes[i] = r.ColumnTypeDatabaseTypeName(i)
}
return colTypes
}

func makeSQLConn(url string) *sqlConn {
return &sqlConn{
url: url,
Expand Down Expand Up @@ -1047,7 +1056,7 @@ func maybeShowTimes(
// If showMoreChars is true, then more characters are not escaped.
func sqlRowsToStrings(rows *sqlRows, showMoreChars bool) ([]string, [][]string, error) {
cols := getColumnStrings(rows, showMoreChars)
allRows, err := getAllRowStrings(rows, showMoreChars)
allRows, err := getAllRowStrings(rows, rows.getColTypes(), showMoreChars)
if err != nil {
return nil, nil, err
}
Expand All @@ -1058,16 +1067,16 @@ func getColumnStrings(rows *sqlRows, showMoreChars bool) []string {
srcCols := rows.Columns()
cols := make([]string, len(srcCols))
for i, c := range srcCols {
cols[i] = formatVal(c, showMoreChars, showMoreChars)
cols[i] = formatVal(c, "NAME", showMoreChars, showMoreChars)
}
return cols
}

func getAllRowStrings(rows *sqlRows, showMoreChars bool) ([][]string, error) {
func getAllRowStrings(rows *sqlRows, colTypes []string, showMoreChars bool) ([][]string, error) {
var allRows [][]string

for {
rowStrings, err := getNextRowStrings(rows, showMoreChars)
rowStrings, err := getNextRowStrings(rows, colTypes, showMoreChars)
if err != nil {
return nil, err
}
Expand All @@ -1080,7 +1089,7 @@ func getAllRowStrings(rows *sqlRows, showMoreChars bool) ([][]string, error) {
return allRows, nil
}

func getNextRowStrings(rows *sqlRows, showMoreChars bool) ([]string, error) {
func getNextRowStrings(rows *sqlRows, colTypes []string, showMoreChars bool) ([]string, error) {
cols := rows.Columns()
var vals []driver.Value
if len(cols) > 0 {
Expand All @@ -1097,13 +1106,7 @@ func getNextRowStrings(rows *sqlRows, showMoreChars bool) ([]string, error) {

rowStrings := make([]string, len(cols))
for i, v := range vals {
databaseType := rows.ColumnTypeDatabaseTypeName(i)
if databaseType == "NAME" {
if bytes, ok := v.([]byte); ok {
v = string(bytes)
}
}
rowStrings[i] = formatVal(v, showMoreChars, showMoreChars)
rowStrings[i] = formatVal(v, colTypes[i], showMoreChars, showMoreChars)
}
return rowStrings, nil
}
Expand All @@ -1114,10 +1117,25 @@ func isNotGraphicUnicodeOrTabOrNewline(r rune) bool {
return r != '\t' && r != '\n' && !unicode.IsGraphic(r)
}

func formatVal(val driver.Value, showPrintableUnicode bool, showNewLinesAndTabs bool) string {
func formatVal(
val driver.Value, colType string, showPrintableUnicode bool, showNewLinesAndTabs bool,
) string {
if b, ok := val.([]byte); ok && colType == "NAME" {
val = string(b)
colType = "VARCHAR"
}

switch t := val.(type) {
case nil:
return "NULL"

case float64:
width := 64
if colType == "FLOAT4" {
width = 32
}
return strconv.FormatFloat(t, 'g', -1, width)

case string:
if showPrintableUnicode {
pred := isNotGraphicUnicode
Expand Down Expand Up @@ -1155,14 +1173,25 @@ func formatVal(val driver.Value, showPrintableUnicode bool, showNewLinesAndTabs
sessiondatapb.BytesEncodeEscape, false /* skipHexPrefix */)

case time.Time:
// Since we do not know whether the datum is Timestamp or TimestampTZ,
// output the full format.
return t.Format(timeutil.FullTimeFormat)
tfmt, ok := timeOutputFormats[colType]
if !ok {
// Some unknown/new time-like format.
tfmt = timeutil.FullTimeFormat
}
return t.Format(tfmt)
}

return fmt.Sprint(val)
}

var timeOutputFormats = map[string]string{
"TIMESTAMP": timeutil.TimestampWithoutTZFormat,
"TIMESTAMPTZ": timeutil.FullTimeFormat,
"TIME": timeutil.TimeWithoutTZFormat,
"TIMETZ": timeutil.TimeWithTZFormat,
"DATE": timeutil.DateFormat,
}

// parseBool parses a boolean string for use in slash commands.
func parseBool(s string) (bool, error) {
switch strings.TrimSpace(strings.ToLower(s)) {
Expand Down
15 changes: 15 additions & 0 deletions pkg/util/timeutil/timeutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,18 @@ package timeutil
// FullTimeFormat is the time format used to display any timestamp
// with date, time and time zone data.
const FullTimeFormat = "2006-01-02 15:04:05.999999-07:00:00"

// TimestampWithoutTZFormat is the time format used to display
// timestamps without a time zone offset.
const TimestampWithoutTZFormat = "2006-01-02 15:04:05.999999"

// TimeWithTZFormat is the time format used to display a time
// with a time zone offset.
const TimeWithTZFormat = "15:04:05.999999-07:00:00"

// TimeWithoutTZFormat is the time format used to display a time
// without a time zone offset.
const TimeWithoutTZFormat = "15:04:05.999999"

// DateFormat is the time format used to display a date.
const DateFormat = "2006-01-02"

0 comments on commit bbee893

Please sign in to comment.