Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support automatic transaction management with COPY FROM STDIN #741

Merged
merged 4 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion server/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,18 @@ type ConvertedQuery struct {
// this statement is processed, the server accepts COPY DATA messages from the client with chunks of data to load
// into a table.
type copyFromStdinState struct {
	// copyFromStdinNode stores the original CopyFrom statement that initiated the CopyData message sequence. This
	// node is used to look at what parameters were specified, such as which table to load data into, file format,
	// delimiters, etc.
	copyFromStdinNode *node.CopyFrom
	// dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the
	// target table.
	dataLoader dataloader.DataLoader
	// copyErr stores any error that was returned while processing a CopyData message and loading a chunk of data
	// to the target table. The server needs to keep track of any errors that were encountered while processing chunks
	// so that it can avoid sending a CommandComplete message if an error was encountered after the client already
	// sent a CopyDone message to the server.
	copyErr error
}

type PortalData struct {
Expand Down
55 changes: 53 additions & 2 deletions server/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"strings"
"sync/atomic"

"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
Expand Down Expand Up @@ -616,15 +617,29 @@ func makeCommandComplete(tag string, rows int32) *pgproto3.CommandComplete {
// messages are expected, and the server should tell the client that it is ready for the next query, and |err| contains
// any error that occurred while processing the COPY DATA message.
func (h *ConnectionHandler) handleCopyData(message *pgproto3.CopyData) (stop bool, endOfMessages bool, err error) {
	stop, endOfMessages, err = h.handleCopyDataHelper(message)
	// Remember the failure so that a later CopyDone message can tell that this load has already errored out.
	// The helper also returns an error when no COPY FROM STDIN operation is in progress; in that case there is
	// no state to record the error on, so the nil check avoids dereferencing a nil copyFromStdinState.
	if err != nil && h.copyFromStdinState != nil {
		h.copyFromStdinState.copyErr = err
	}
	return stop, endOfMessages, err
}

// handleCopyDataHelper is a helper function that should only be invoked by handleCopyData. handleCopyData wraps this
// function so that it can capture any returned error message and store it in the saved state.
func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (stop bool, endOfMessages bool, err error) {
if h.copyFromStdinState == nil {
return false, true, fmt.Errorf("COPY DATA message received without a COPY FROM STDIN operation in progress")
}

// Grab a sql.Context and ensure the session has a transaction started, otherwise the copied data
// won't get committed correctly.
sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "")
if err != nil {
return false, false, err
}
if err = startTransaction(sqlCtx); err != nil {
return false, false, err
}

dataLoader := h.copyFromStdinState.dataLoader
if dataLoader == nil {
Expand Down Expand Up @@ -686,6 +701,14 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end
fmt.Errorf("COPY DONE message received without a COPY FROM STDIN operation in progress")
}

// If there was a previous error returned from processing a CopyData message, then don't return an error here
// and don't send endOfMessage=true, since the CopyData error already sent endOfMessage=true. If we do send
// endOfMessage=true here, then the client gets confused about the unexpected/extra Idle message since the
// server has already reported it was idle in the last message after the returned error.
if h.copyFromStdinState.copyErr != nil {
return false, false, nil
}

dataLoader := h.copyFromStdinState.dataLoader
if dataLoader == nil {
return false, true,
Expand All @@ -702,6 +725,17 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end
return false, false, err
}

// If we aren't in an explicit/user managed transaction, we need to commit the transaction
if !sqlCtx.GetIgnoreAutoCommit() {
fulghum marked this conversation as resolved.
Show resolved Hide resolved
txSession, ok := sqlCtx.Session.(sql.TransactionSession)
if !ok {
return false, false, fmt.Errorf("session does not implement sql.TransactionSession")
}
if err = txSession.CommitTransaction(sqlCtx, txSession.GetTransaction()); err != nil {
return false, false, err
}
}

h.copyFromStdinState = nil
// We send back endOfMessage=true, since the COPY DONE message ends the COPY DATA flow and the server is ready
// to accept the next query now.
Expand All @@ -710,7 +744,7 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end
})
}

// handleCopyFail handles a COPY FAIL message by aborting the in-progress COPY DATA operation. The |stop| response
// parameter is true if the connection handler should shut down the connection, |endOfMessages| is true if no more
// COPY DATA messages are expected, and the server should tell the client that it is ready for the next query, and
// |err| contains any error that occurred while processing the COPY DATA message.
Expand All @@ -732,6 +766,23 @@ func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, end
return false, true, nil
}

// startTransaction makes sure the session attached to |ctx| has an active transaction. When the session has no
// transaction yet, a new read/write transaction is started on it; when one already exists, nothing is changed.
// Commands that modify data without being routed through the GMS engine rely on this. An error is returned if
// the session is not a *dsess.DoltSession, or if starting the transaction fails.
func startTransaction(ctx *sql.Context) error {
	sess, isDoltSession := ctx.Session.(*dsess.DoltSession)
	if !isDoltSession {
		return fmt.Errorf("unexpected session type: %T", ctx.Session)
	}

	// A transaction is already in place, so there is nothing to set up.
	if sess.GetTransaction() != nil {
		return nil
	}

	_, err := sess.StartTransaction(ctx, sql.ReadWrite)
	return err
}

func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error {
_, ok := preparedStatements[name]
if !ok {
Expand Down
34 changes: 33 additions & 1 deletion testing/bats/dataloading.bats
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ teardown() {
[[ "$output" =~ "3 | 03 | 97302 | Guyane" ]] || false
}

# Tests that we can load tabular data dump files that do not explicitly manage the session's transaction.
@test 'dataloading: tabular import, no explicit tx management' {
    # Import the data dump and assert the expected output. The script path is quoted so that a test
    # directory containing whitespace is passed to query_server as a single argument.
    run query_server -f "$BATS_TEST_DIRNAME/dataloading/tab-load-with-no-tx-control.sql"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "COPY 3" ]] || false
    [[ ! "$output" =~ "ERROR" ]] || false

    # Check the inserted rows with a second query_server invocation
    run query_server -c "SELECT * FROM test_info ORDER BY id;"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "4 | string for 4 | 1" ]] || false
    [[ "$output" =~ "5 | string for 5 | 0" ]] || false
    [[ "$output" =~ "6 | string for 6 | 0" ]] || false
}

# Tests loading in data via different CSV data files.
@test 'dataloading: csv import' {
# Import the data dump and assert the expected output
Expand Down Expand Up @@ -157,4 +173,20 @@ teardown() {
run query_server -c "SELECT count(*) from tbl1;"
[ "$status" -eq 0 ]
[[ "$output" =~ "100" ]] || false
}
}

# Tests that we can load CSV data dump files that do not explicitly manage the session's transaction.
@test 'dataloading: csv import, no explicit tx management' {
    # Import the data dump and assert the expected output. The script path is quoted so that a test
    # directory containing whitespace is passed to query_server as a single argument.
    run query_server -f "$BATS_TEST_DIRNAME/dataloading/csv-load-with-no-tx-control.sql"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "COPY 3" ]] || false
    [[ ! "$output" =~ "ERROR" ]] || false

    # Check the inserted rows with a second query_server invocation
    run query_server -c "SELECT * FROM test_info ORDER BY id;"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "4 | string for 4 | 1" ]] || false
    [[ "$output" =~ "5 | string for 5 | 0" ]] || false
    [[ "$output" =~ "6 | string for 6 | 0" ]] || false
}
11 changes: 11 additions & 0 deletions testing/bats/dataloading/csv-load-with-no-tx-control.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Parent table that test_info.test_pk references; seeded with the keys 0 and 1.
CREATE TABLE test (pk int primary key);
INSERT INTO test VALUES (0), (1);

-- Target table for the COPY below. Its test_pk column is a foreign key to test(pk).
CREATE TABLE test_info (id int, info varchar(255), test_pk int, primary key(id), foreign key (test_pk) references test(pk));

-- Load the rows that follow from stdin as CSV, with the first line treated as a header. This script issues
-- no BEGIN/COMMIT statements of its own.
COPY test_info FROM STDIN (FORMAT CSV, HEADER TRUE);
id,info,test_pk
4,string for 4,1
5,string for 5,0
6,string for 6,0
\.
11 changes: 11 additions & 0 deletions testing/bats/dataloading/tab-load-with-no-tx-control.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Parent table that test_info.test_pk references; seeded with the keys 0 and 1.
CREATE TABLE test (pk int primary key);
INSERT INTO test VALUES (0), (1);

-- Target table for the COPY below. Its test_pk column is a foreign key to test(pk).
CREATE TABLE test_info (id int, info varchar(255), test_pk int, primary key(id), foreign key (test_pk) references test(pk));

-- Load the rows that follow from stdin with the HEADER option and no explicit FORMAT option. This script
-- issues no BEGIN/COMMIT statements of its own.
COPY test_info FROM STDIN WITH (HEADER);
id info test_pk
4 string for 4 1
5 string for 5 0
6 string for 6 0
\.
8 changes: 8 additions & 0 deletions testing/go/regression/tests/triggers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -362,20 +362,27 @@ CREATE TRIGGER insert_when BEFORE INSERT ON main_table
FOR EACH STATEMENT WHEN (true) EXECUTE PROCEDURE trigger_func('insert_when');
CREATE TRIGGER delete_when AFTER DELETE ON main_table
FOR EACH STATEMENT WHEN (true) EXECUTE PROCEDURE trigger_func('delete_when');

SELECT trigger_name, event_manipulation, event_object_schema, event_object_table,
action_order, action_condition, action_orientation, action_timing,
action_reference_old_table, action_reference_new_table
FROM information_schema.triggers
WHERE event_object_table IN ('main_table')
ORDER BY trigger_name COLLATE "C", 2;

INSERT INTO main_table (a) VALUES (123), (456);

COPY main_table FROM stdin;
123 999
456 999
\.

DELETE FROM main_table WHERE a IN (123, 456);

UPDATE main_table SET a = 50, b = 60;

SELECT * FROM main_table ORDER BY a, b;

SELECT pg_get_triggerdef(oid, true) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'modified_a';
SELECT pg_get_triggerdef(oid, false) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'modified_a';
SELECT pg_get_triggerdef(oid, true) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'modified_any';
Expand Down Expand Up @@ -420,6 +427,7 @@ FOR EACH STATEMENT EXECUTE PROCEDURE trigger_func('after_upd_b_stmt');
SELECT pg_get_triggerdef(oid) FROM pg_trigger WHERE tgrelid = 'main_table'::regclass AND tgname = 'after_upd_a_b_row_trig';

UPDATE main_table SET a = 50;

UPDATE main_table SET b = 10;

--
Expand Down
Loading