Skip to content

Commit

Permalink
Merge pull request #260 from cole-miller/shell
Browse files Browse the repository at this point in the history
Always use QUERY_SQL in the dqlite shell
  • Loading branch information
cole-miller authored Jul 17, 2023
2 parents 67cf13d + f7320ed commit beebd01
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
16 changes: 16 additions & 0 deletions driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,22 @@ func Test_ColumnTypesEnd(t *testing.T) {
assert.NoError(t, conn.Close())
}

func Test_ZeroColumns(t *testing.T) {
drv, cleanup := newDriver(t)
defer cleanup()

conn, err := drv.Open("test.db")
require.NoError(t, err)
queryer := conn.(driver.Queryer)

rows, err := queryer.Query("CREATE TABLE foo (bar INTEGER)", []driver.Value{})
require.NoError(t, err)
values := []driver.Value{}
require.Equal(t, io.EOF, rows.Next(values))

require.NoError(t, conn.Close())
}

func newDriver(t *testing.T) (*dqlitedriver.Driver, func()) {
t.Helper()

Expand Down
6 changes: 6 additions & 0 deletions internal/protocol/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ func (r *Rows) columnTypes(save bool) ([]uint8, error) {
r.types = make([]uint8, len(r.Columns))
}

// If there are zero columns, no rows can be encoded or decoded,
// so we signal EOF immediately.
if len(r.types) == 0 {
return r.types, io.EOF
}

// Each column needs a 4 byte slot to store the column type. The row
// header must be padded to reach word boundary.
headerBits := len(r.types) * 4
Expand Down
8 changes: 2 additions & 6 deletions internal/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,7 @@ func (s *Shell) Process(ctx context.Context, line string) (string, error) {
if strings.HasPrefix(strings.ToLower(strings.TrimLeft(line, " ")), ".reconfigure") {
return s.processReconfigure(ctx, line)
}
if strings.HasPrefix(strings.ToUpper(strings.TrimLeft(line, " ")), "SELECT") {
return s.processSelect(ctx, line)
} else {
return "", s.processExec(ctx, line)
}
return s.processQuery(ctx, line)
}

func (s *Shell) processHelp() string {
Expand Down Expand Up @@ -317,7 +313,7 @@ func (s *Shell) processWeight(ctx context.Context, line string) (string, error)
return "", nil
}

func (s *Shell) processSelect(ctx context.Context, line string) (string, error) {
func (s *Shell) processQuery(ctx context.Context, line string) (string, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return "", fmt.Errorf("begin transaction: %w", err)
Expand Down

0 comments on commit beebd01

Please sign in to comment.