diff --git a/server/ast/discard.go b/server/ast/discard.go index a301c09ea4..19eab11d12 100644 --- a/server/ast/discard.go +++ b/server/ast/discard.go @@ -20,12 +20,19 @@ import ( vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + "github.com/dolthub/doltgresql/server/node" ) // nodeDiscard handles *tree.Discard nodes. -func nodeDiscard(node *tree.Discard) (vitess.Statement, error) { - if node == nil { +func nodeDiscard(discard *tree.Discard) (vitess.Statement, error) { + if discard == nil { return nil, nil } - return nil, fmt.Errorf("DISCARD is not yet supported") + if discard.Mode != tree.DiscardModeAll { + return nil, fmt.Errorf("unhandled DISCARD mode: %v", discard.Mode) + } + + return vitess.InjectedStatement{ + Statement: node.DiscardStatement{}, + }, nil } diff --git a/server/connection_handler.go b/server/connection_handler.go old mode 100755 new mode 100644 index 381360dccb..774d015510 --- a/server/connection_handler.go +++ b/server/connection_handler.go @@ -39,6 +39,7 @@ import ( "github.com/dolthub/doltgresql/postgres/parser/parser" "github.com/dolthub/doltgresql/server/ast" pgexprs "github.com/dolthub/doltgresql/server/expression" + "github.com/dolthub/doltgresql/server/node" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -347,15 +348,29 @@ func (h *ConnectionHandler) handleQuery(message messages.Query) error { delete(h.preparedStatements, "") delete(h.portals, "") - // The Deallocate message does not get passed to the engine, since we handle allocation / deallocation of - // prepared statements at this layer + // Certain statement types get handled directly by the handler instead of being passed to the engine + err, handled = h.handleQueryOutsideEngine(query) + if handled { + return err + } + + return h.query(query) +} + +// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being +// passed to the engine. Returns true if the query was handled and any error that occurred while doing so. +func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (error, bool) { switch stmt := query.AST.(type) { case *sqlparser.Deallocate: // TODO: handle ALL keyword - return h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn()) + return h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn()), true + case sqlparser.InjectedStatement: + switch stmt.Statement.(type) { + case node.DiscardStatement: + return h.discardAll(query, h.Conn()), true + } } - - return h.query(query) + return nil, false } // handleParse handles a parse message, returning any error that occurs @@ -497,7 +512,13 @@ func (h *ConnectionHandler) handleExecute(message messages.Execute) error { return connection.Send(h.Conn(), messages.EmptyQueryResponse{}) } - err := h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true)) + // Certain statement types get handled directly by the handler instead of being passed to the engine + err, handled := h.handleQueryOutsideEngine(query) + if handled { + return err + } + + err = h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true)) if err != nil { return err } @@ -979,3 +1000,18 @@ func (h *ConnectionHandler) bindParams( return plan, fields, err } + +// discardAll handles the DISCARD ALL command +func (h *ConnectionHandler) discardAll(query ConvertedQuery, conn net.Conn) error { + err := h.handler.ComResetConnection(h.mysqlConn) + if err != nil { + return err + } + + commandComplete := messages.CommandComplete{ + Query: query.String, + Tag: query.StatementTag, + } + + return connection.Send(conn, commandComplete) +} diff --git a/server/node/discard.go b/server/node/discard.go new file mode 100755 index 0000000000..8e3dd2ccd0 --- /dev/null +++ b/server/node/discard.go @@ -0,0 +1,50 @@ +package node + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/vt/sqlparser" +) + +// DiscardStatement is just a marker type, since all functionality is handled by the connection handler, +// rather than the engine. It has to conform to the sql.ExecSourceRel interface to be used in the handler, but this +// functionality is all unused. +type DiscardStatement struct{} + +var _ sqlparser.Injectable = DiscardStatement{} +var _ sql.ExecSourceRel = DiscardStatement{} + +func (d DiscardStatement) Resolved() bool { + return true +} + +func (d DiscardStatement) String() string { + return "DISCARD ALL" +} + +func (d DiscardStatement) Schema() sql.Schema { + return nil +} + +func (d DiscardStatement) Children() []sql.Node { + return nil +} + +func (d DiscardStatement) WithChildren(children ...sql.Node) (sql.Node, error) { + return d, nil +} + +func (d DiscardStatement) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { + return true +} + +func (d DiscardStatement) IsReadOnly() bool { + return true +} + +func (d DiscardStatement) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + panic("DISCARD ALL should be handled by the connection handler") +} + +func (d DiscardStatement) WithResolvedChildren(children []any) (any, error) { + return d, nil +} diff --git a/testing/generation/command_docs/output/discard_test.go b/testing/generation/command_docs/output/discard_test.go index 94239c182c..c1d1d157fd 100644 --- a/testing/generation/command_docs/output/discard_test.go +++ b/testing/generation/command_docs/output/discard_test.go @@ -18,7 +18,7 @@ import "testing" func TestDiscard(t *testing.T) { tests := []QueryParses{ - Parses("DISCARD ALL"), + Converts("DISCARD ALL"), Unimplemented("DISCARD PLANS"), Unimplemented("DISCARD SEQUENCES"), Unimplemented("DISCARD TEMPORARY"), diff --git a/testing/go/session_test.go b/testing/go/session_test.go new file mode 100755 index 0000000000..31096ce742 --- /dev/null +++ b/testing/go/session_test.go @@ -0,0 +1,71 @@ +package _go + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/sql" +) + +func TestDiscard(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "Test discard", + SetUpScript: []string{ + `CREATE temporary TABLE test (a INT)`, + `insert into test values (1)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from test", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "DISCARD ALL", + Expected: []sql.Row{}, + }, + { + Query: "select * from test", + ExpectedErr: "table not found", + }, + }, + }, + { + Name: "Test discard errors", + SetUpScript: []string{ + `CREATE temporary TABLE test (a INT)`, + `insert into test values (1)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "DISCARD SEQUENCES", + ExpectedErr: "unimplemented", + }, + { + Query: "select * from test", + Expected: []sql.Row{ + {1}, + }, + }, + }, + }, + { + Name: "Test discard in transaction", + SetUpScript: []string{ + `CREATE temporary TABLE test (a INT)`, + `insert into test values (1)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "BEGIN", + }, + { + Query: "DISCARD ALL", + ExpectedErr: "DISCARD ALL cannot run inside a transaction block", + Skip: true, // not yet implemented + }, + }, + }, + }) +}