diff --git a/driver/driver_test.go b/driver/driver_test.go index bbb3ae7a..3d030c50 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -567,6 +567,22 @@ func Test_ColumnTypesEnd(t *testing.T) { assert.NoError(t, conn.Close()) } +func Test_ZeroColumns(t *testing.T) { + drv, cleanup := newDriver(t) + defer cleanup() + + conn, err := drv.Open("test.db") + require.NoError(t, err) + queryer := conn.(driver.Queryer) + + rows, err := queryer.Query("CREATE TABLE foo (bar INTEGER)", []driver.Value{}) + require.NoError(t, err) + values := []driver.Value{} + require.Equal(t, io.EOF, rows.Next(values)) + + require.NoError(t, conn.Close()) +} + func newDriver(t *testing.T) (*dqlitedriver.Driver, func()) { t.Helper() diff --git a/internal/protocol/message.go b/internal/protocol/message.go index 95133995..3a9581d9 100644 --- a/internal/protocol/message.go +++ b/internal/protocol/message.go @@ -478,6 +478,12 @@ func (r *Rows) columnTypes(save bool) ([]uint8, error) { r.types = make([]uint8, len(r.Columns)) } + // If there are zero columns, no rows can be encoded or decoded, + // so we signal EOF immediately. + if len(r.types) == 0 { + return r.types, io.EOF + } + // Each column needs a 4 byte slot to store the column type. The row // header must be padded to reach word boundary. headerBits := len(r.types) * 4 diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 5d7a573e..a93c69fd 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -87,11 +87,7 @@ func (s *Shell) Process(ctx context.Context, line string) (string, error) { if strings.HasPrefix(strings.ToLower(strings.TrimLeft(line, " ")), ".reconfigure") { return s.processReconfigure(ctx, line) } - if strings.HasPrefix(strings.ToUpper(strings.TrimLeft(line, " ")), "SELECT") { - return s.processSelect(ctx, line) - } else { - return "", s.processExec(ctx, line) - } + return s.processQuery(ctx, line) } func (s *Shell) processHelp() string { @@ -317,7 +313,7 @@ func (s *Shell) processWeight(ctx context.Context, line string) (string, error) return "", nil } -func (s *Shell) processSelect(ctx context.Context, line string) (string, error) { +func (s *Shell) processQuery(ctx context.Context, line string) (string, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return "", fmt.Errorf("begin transaction: %w", err)