Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented DISCARD #596

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions server/ast/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@ import (
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/doltgresql/server/node"
)

// nodeDiscard handles *tree.Discard nodes.
func nodeDiscard(node *tree.Discard) (vitess.Statement, error) {
if node == nil {
func nodeDiscard(discard *tree.Discard) (vitess.Statement, error) {
if discard == nil {
return nil, nil
}
return nil, fmt.Errorf("DISCARD is not yet supported")
if discard.Mode != tree.DiscardModeAll {
return nil, fmt.Errorf("unhandled DISCARD mode: %v", discard.Mode)
}

return vitess.InjectedStatement{
Statement: node.DiscardStatement{},
}, nil
}
48 changes: 42 additions & 6 deletions server/connection_handler.go
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/dolthub/doltgresql/postgres/parser/parser"
"github.com/dolthub/doltgresql/server/ast"
pgexprs "github.com/dolthub/doltgresql/server/expression"
"github.com/dolthub/doltgresql/server/node"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

Expand Down Expand Up @@ -347,15 +348,29 @@ func (h *ConnectionHandler) handleQuery(message messages.Query) error {
delete(h.preparedStatements, "")
delete(h.portals, "")

// The Deallocate message does not get passed to the engine, since we handle allocation / deallocation of
// prepared statements at this layer
// Certain statement types get handled directly by the handler instead of being passed to the engine
err, handled = h.handleQueryOutsideEngine(query)
if handled {
return err
}

return h.query(query)
}

// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being
// passed to the engine. Returns true if the query was handled and any error that occurred while doing so.
func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (error, bool) {
switch stmt := query.AST.(type) {
case *sqlparser.Deallocate:
// TODO: handle ALL keyword
return h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn())
return h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn()), true
case sqlparser.InjectedStatement:
switch stmt.Statement.(type) {
case node.DiscardStatement:
return h.discardAll(query, h.Conn()), true
}
}

return h.query(query)
return nil, false
}

// handleParse handles a parse message, returning any error that occurs
Expand Down Expand Up @@ -497,7 +512,13 @@ func (h *ConnectionHandler) handleExecute(message messages.Execute) error {
return connection.Send(h.Conn(), messages.EmptyQueryResponse{})
}

err := h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true))
// Certain statement types get handled directly by the handler instead of being passed to the engine
err, handled := h.handleQueryOutsideEngine(query)
if handled {
return err
}

err = h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true))
if err != nil {
return err
}
Expand Down Expand Up @@ -979,3 +1000,18 @@ func (h *ConnectionHandler) bindParams(

return plan, fields, err
}

// discardAll handles the DISCARD ALL command
func (h *ConnectionHandler) discardAll(query ConvertedQuery, conn net.Conn) error {
err := h.handler.ComResetConnection(h.mysqlConn)
if err != nil {
return err
}

commandComplete := messages.CommandComplete{
Query: query.String,
Tag: query.StatementTag,
}

return connection.Send(conn, commandComplete)
}
50 changes: 50 additions & 0 deletions server/node/discard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package node

import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/vt/sqlparser"
)

// DiscardStatement is just a marker type, since all functionality is handled by the connection handler,
// rather than the engine. It has to conform to the sql.ExecSourceRel interface to be used in the handler, but this
// functionality is all unused.
type DiscardStatement struct{}

var _ sqlparser.Injectable = DiscardStatement{}
var _ sql.ExecSourceRel = DiscardStatement{}

func (d DiscardStatement) Resolved() bool {
return true
}

func (d DiscardStatement) String() string {
return "DISCARD ALL"
}

func (d DiscardStatement) Schema() sql.Schema {
return nil
}

func (d DiscardStatement) Children() []sql.Node {
return nil
}

func (d DiscardStatement) WithChildren(children ...sql.Node) (sql.Node, error) {
return d, nil
}

func (d DiscardStatement) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return true
}

func (d DiscardStatement) IsReadOnly() bool {
return true
}

func (d DiscardStatement) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
panic("DISCARD ALL should be handled by the connection handler")
}

func (d DiscardStatement) WithResolvedChildren(children []any) (any, error) {
return d, nil
}
2 changes: 1 addition & 1 deletion testing/generation/command_docs/output/discard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import "testing"

func TestDiscard(t *testing.T) {
tests := []QueryParses{
Parses("DISCARD ALL"),
Converts("DISCARD ALL"),
Unimplemented("DISCARD PLANS"),
Unimplemented("DISCARD SEQUENCES"),
Unimplemented("DISCARD TEMPORARY"),
Expand Down
71 changes: 71 additions & 0 deletions testing/go/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package _go

import (
"testing"

"github.com/dolthub/go-mysql-server/sql"
)

func TestDiscard(t *testing.T) {
RunScripts(t, []ScriptTest{
{
Name: "Test discard",
SetUpScript: []string{
`CREATE temporary TABLE test (a INT)`,
`insert into test values (1)`,
},
Assertions: []ScriptTestAssertion{
{
Query: "select * from test",
Expected: []sql.Row{
{1},
},
},
{
Query: "DISCARD ALL",
Expected: []sql.Row{},
},
{
Query: "select * from test",
ExpectedErr: "table not found",
},
},
},
{
Name: "Test discard errors",
SetUpScript: []string{
`CREATE temporary TABLE test (a INT)`,
`insert into test values (1)`,
},
Assertions: []ScriptTestAssertion{
{
Query: "DISCARD SEQUENCES",
ExpectedErr: "unimplemented",
},
{
Query: "select * from test",
Expected: []sql.Row{
{1},
},
},
},
},
{
Name: "Test discard in transaction",
SetUpScript: []string{
`CREATE temporary TABLE test (a INT)`,
`insert into test values (1)`,
},
Assertions: []ScriptTestAssertion{
{
Query: "BEGIN",
},
{
Query: "DISCARD ALL",
ExpectedErr: "DISCARD ALL cannot run inside a transaction block",
Skip: true, // not yet implemented
},
},
},
})
}
Loading