Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: additional enhancements after switching to jackc/pgx #82101

Merged
merged 7 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pkg/cli/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,12 @@ func createAuthSessionToken(
ctx := context.Background()

// First things first. Does the user exist?
_, rows, err := sqlExecCtx.RunQuery(ctx,
_, rows, err := sqlExecCtx.RunQuery(
ctx,
sqlConn,
clisqlclient.MakeQuery(`SELECT count(username) FROM system.users WHERE username = $1 AND NOT "isRole"`, username), false)
clisqlclient.MakeQuery(`SELECT count(username) FROM system.users WHERE username = $1 AND NOT "isRole"`, username),
false, /* showMoreChars */
)
if err != nil {
return -1, nil, err
}
Expand Down
14 changes: 4 additions & 10 deletions pkg/cli/clisqlclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"context"
"database/sql/driver"
"io"
"reflect"
"time"
)

Expand Down Expand Up @@ -89,11 +88,14 @@ type Conn interface {
// The what argument is a descriptive label for the value being
// retrieved, for inclusion inside warning or error message.
// The sql argument is the SQL query to use to retrieve the value.
GetServerValue(ctx context.Context, what, sql string) (driver.Value, string, bool)
GetServerValue(ctx context.Context, what, sql string) (driver.Value, bool)

// GetDriverConn exposes the underlying driver connection object
// for use by the cli package.
GetDriverConn() DriverConn

// Cancel sends a query cancellation request to the server.
Cancel(ctx context.Context) error
}

// Rows describes a result set.
Expand All @@ -109,18 +111,10 @@ type Rows interface {
// result does not need to be constructed on each invocation.
Columns() []string

// ColumnTypeScanType returns the natural Go type of values at the
// given column index.
ColumnTypeScanType(index int) reflect.Type

// ColumnTypeDatabaseTypeName returns the database type name
// of the column at the given column index.
ColumnTypeDatabaseTypeName(index int) string

// ColumnTypeNames returns the database type names for all
// columns.
ColumnTypeNames() []string

// Tag retrieves the statement tag for the current result set.
Tag() (CommandTag, error)

Expand Down
19 changes: 10 additions & 9 deletions pkg/cli/clisqlclient/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ func (c *sqlConn) GetDriverConn() DriverConn {
return &driverConnAdapter{c}
}

func (c *sqlConn) Cancel(ctx context.Context) error {
return c.conn.PgConn().CancelRequest(ctx)
}

// SetCurrentDatabase implements the Conn interface.
func (c *sqlConn) SetCurrentDatabase(dbName string) {
c.dbName = dbName
Expand Down Expand Up @@ -442,31 +446,28 @@ func (c *sqlConn) checkServerMetadata(ctx context.Context) error {
// GetServerValue retrieves the first driverValue returned by the
// given sql query. If the query fails or does not return a single
// column, `false` is returned in the second result.
func (c *sqlConn) GetServerValue(
ctx context.Context, what, sql string,
) (driver.Value, string, bool) {
func (c *sqlConn) GetServerValue(ctx context.Context, what, sql string) (driver.Value, bool) {
rows, err := c.Query(ctx, sql)
if err != nil {
fmt.Fprintf(c.errw, "warning: error retrieving the %s: %v\n", what, err)
return nil, "", false
return nil, false
}
defer func() { _ = rows.Close() }()

if len(rows.Columns()) == 0 {
fmt.Fprintf(c.errw, "warning: cannot get the %s\n", what)
return nil, "", false
return nil, false
}

dbColType := rows.ColumnTypeDatabaseTypeName(0)
dbVals := make([]driver.Value, len(rows.Columns()))

err = rows.Next(dbVals[:])
err = rows.Next(dbVals)
if err != nil {
fmt.Fprintf(c.errw, "warning: invalid %s: %v\n", what, err)
return nil, "", false
return nil, false
}

return dbVals[0], dbColType, true
return dbVals[0], true
}

func (c *sqlConn) GetLastQueryStatistics(ctx context.Context) (results QueryStats, resErr error) {
Expand Down
9 changes: 0 additions & 9 deletions pkg/cli/clisqlclient/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"context"
"database/sql/driver"
"io"
"reflect"

"github.com/jackc/pgconn"
)
Expand Down Expand Up @@ -53,18 +52,10 @@ func (c copyFromRows) Columns() []string {
return nil
}

func (c copyFromRows) ColumnTypeScanType(index int) reflect.Type {
return nil
}

func (c copyFromRows) ColumnTypeDatabaseTypeName(index int) string {
return ""
}

func (c copyFromRows) ColumnTypeNames() []string {
return nil
}

func (c copyFromRows) Tag() (CommandTag, error) {
return c.t, nil
}
Expand Down
34 changes: 0 additions & 34 deletions pkg/cli/clisqlclient/row_type_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,11 @@
package clisqlclient

import (
"reflect"
"strings"
"time"

"github.com/jackc/pgtype"
)

// scanType returns the Go type used for the given column type.
func scanType(typeOID uint32, typeName string) reflect.Type {
if typeName == "" || strings.HasPrefix(typeName, "_") {
// User-defined types and array types are scanned into []byte.
// These are handled separately since we can't easily include all the OIDs
// for user-defined types and arrays in the switch statement.
return reflect.TypeOf([]byte(nil))
}
// This switch is copied from lib/pq, and modified with a few additional cases.
// https://github.com/lib/pq/blob/b2901c7946b69f1e7226214f9760e31620499595/rows.go#L24
switch typeOID {
case pgtype.Int8OID:
return reflect.TypeOf(int64(0))
case pgtype.Int4OID:
return reflect.TypeOf(int32(0))
case pgtype.Int2OID:
return reflect.TypeOf(int16(0))
case pgtype.VarcharOID, pgtype.TextOID, pgtype.NameOID,
pgtype.ByteaOID, pgtype.NumericOID, pgtype.RecordOID,
pgtype.QCharOID, pgtype.BPCharOID:
return reflect.TypeOf("")
case pgtype.BoolOID:
return reflect.TypeOf(false)
case pgtype.DateOID, pgtype.TimeOID, 1266, pgtype.TimestampOID, pgtype.TimestamptzOID:
// 1266 is the OID for TimeTZ.
// TODO(rafi): Add TimetzOID to pgtype.
return reflect.TypeOf(time.Time{})
default:
return reflect.TypeOf(new(interface{})).Elem()
}
}

// databaseTypeName returns the database type name for the given type OID.
func databaseTypeName(ci *pgtype.ConnInfo, typeOID uint32) string {
dataType, ok := ci.DataTypeForOID(typeOID)
Expand Down
15 changes: 0 additions & 15 deletions pkg/cli/clisqlclient/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package clisqlclient
import (
"database/sql/driver"
"io"
"reflect"

"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
Expand Down Expand Up @@ -87,21 +86,7 @@ func (r *sqlRows) NextResultSet() (bool, error) {
return false, nil
}

func (r *sqlRows) ColumnTypeScanType(index int) reflect.Type {
o := r.rows.FieldDescriptions()[index].DataTypeOID
n := r.ColumnTypeDatabaseTypeName(index)
return scanType(o, n)
}

func (r *sqlRows) ColumnTypeDatabaseTypeName(index int) string {
fieldOID := r.rows.FieldDescriptions()[index].DataTypeOID
return databaseTypeName(r.connInfo, fieldOID)
}

func (r *sqlRows) ColumnTypeNames() []string {
colTypes := make([]string, len(r.Columns()))
for i := range colTypes {
colTypes[i] = r.ColumnTypeDatabaseTypeName(i)
}
return colTypes
}
18 changes: 1 addition & 17 deletions pkg/cli/clisqlclient/rows_multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package clisqlclient
import (
"database/sql/driver"
"io"
"reflect"

"github.com/cockroachdb/errors"
"github.com/jackc/pgconn"
Expand Down Expand Up @@ -97,7 +96,7 @@ func (r *sqlRowsMultiResultSet) Next(values []driver.Value) error {
for i := range values {
rowVal := rd.Values()[i]
if rowVal == nil {
values[i] = "NULL"
values[i] = nil
continue
}
fieldOID := rd.FieldDescriptions()[i].DataTypeOID
Expand Down Expand Up @@ -135,23 +134,8 @@ func (r *sqlRowsMultiResultSet) NextResultSet() (bool, error) {
return next, nil
}

func (r *sqlRowsMultiResultSet) ColumnTypeScanType(index int) reflect.Type {
rd := r.rows.ResultReader()
o := rd.FieldDescriptions()[index].DataTypeOID
n := r.ColumnTypeDatabaseTypeName(index)
return scanType(o, n)
}

func (r *sqlRowsMultiResultSet) ColumnTypeDatabaseTypeName(index int) string {
rd := r.rows.ResultReader()
fieldOID := rd.FieldDescriptions()[index].DataTypeOID
return databaseTypeName(r.connInfo, fieldOID)
}

func (r *sqlRowsMultiResultSet) ColumnTypeNames() []string {
colTypes := make([]string, len(r.Columns()))
for i := range colTypes {
colTypes[i] = r.ColumnTypeDatabaseTypeName(i)
}
return colTypes
}
1 change: 0 additions & 1 deletion pkg/cli/clisqlexec/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ go_library(
"//pkg/util/syncutil",
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgtype//:pgtype",
"@com_github_olekukonko_tablewriter//:tablewriter",
"@com_github_spf13_pflag//:pflag",
],
Expand Down
24 changes: 12 additions & 12 deletions pkg/cli/clisqlexec/format_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"fmt"
"html"
"io"
"reflect"
"strings"
"text/tabwriter"
"time"
Expand Down Expand Up @@ -99,12 +98,11 @@ func NewRowSliceIter(allRows [][]string, align string) RowStrIter {

type rowIter struct {
rows clisqlclient.Rows
colTypes []string
showMoreChars bool
}

func (iter *rowIter) Next() (row []string, err error) {
nextRowString, err := getNextRowStrings(iter.rows, iter.colTypes, iter.showMoreChars)
nextRowString, err := getNextRowStrings(iter.rows, iter.showMoreChars)
if err != nil {
return nil, err
}
Expand All @@ -115,23 +113,26 @@ func (iter *rowIter) Next() (row []string, err error) {
}

func (iter *rowIter) ToSlice() ([][]string, error) {
return getAllRowStrings(iter.rows, iter.colTypes, iter.showMoreChars)
return getAllRowStrings(iter.rows, iter.showMoreChars)
}

func (iter *rowIter) Align() []int {
cols := iter.rows.Columns()
align := make([]int, len(cols))
for i := range align {
switch iter.rows.ColumnTypeScanType(i).Kind() {
case reflect.String:
typName := iter.rows.ColumnTypeDatabaseTypeName(i)
if typName == "" || strings.HasPrefix(typName, "_") {
// All array types begin with "_" and user-defined types may not have a
// type name available.
align[i] = tablewriter.ALIGN_LEFT
case reflect.Slice:
continue
}
switch typName {
case "TEXT", "BYTEA", "CHAR", "BPCHAR", "NAME", "UUID":
align[i] = tablewriter.ALIGN_LEFT
case reflect.Int64:
align[i] = tablewriter.ALIGN_RIGHT
case reflect.Float64:
case "INT2", "INT4", "INT8", "FLOAT4", "FLOAT8", "NUMERIC", "OID":
align[i] = tablewriter.ALIGN_RIGHT
case reflect.Bool:
case "BOOL":
align[i] = tablewriter.ALIGN_CENTER
default:
align[i] = tablewriter.ALIGN_DEFAULT
Expand All @@ -143,7 +144,6 @@ func (iter *rowIter) Align() []int {
func newRowIter(rows clisqlclient.Rows, showMoreChars bool) *rowIter {
return &rowIter{
rows: rows,
colTypes: rows.ColumnTypeNames(),
showMoreChars: showMoreChars,
}
}
Expand Down
Loading