From 97c07c4a00b23307560b4975337e0b98a8cc9655 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Sat, 30 Nov 2024 02:50:34 -0300 Subject: [PATCH 01/10] feat(repl): add postgres --- lib/client/db/postgres/repl/commands.go | 110 ++++++ lib/client/db/postgres/repl/commands_test.go | 146 +++++++ lib/client/db/postgres/repl/repl.go | 258 ++++++++++++ lib/client/db/postgres/repl/repl_test.go | 369 ++++++++++++++++++ .../repl/testdata/TestStart/data_type.golden | 4 + .../repl/testdata/TestStart/err.golden | 1 + .../repl/testdata/TestStart/multi.golden | 5 + .../repl/testdata/TestStart/single.golden | 4 + lib/client/db/postgres/repl/testdata/query.go | 111 ++++++ 9 files changed, 1008 insertions(+) create mode 100644 lib/client/db/postgres/repl/commands.go create mode 100644 lib/client/db/postgres/repl/commands_test.go create mode 100644 lib/client/db/postgres/repl/repl.go create mode 100644 lib/client/db/postgres/repl/repl_test.go create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/data_type.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/err.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/multi.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/single.golden create mode 100644 lib/client/db/postgres/repl/testdata/query.go diff --git a/lib/client/db/postgres/repl/commands.go b/lib/client/db/postgres/repl/commands.go new file mode 100644 index 0000000000000..4c4d8de28f39a --- /dev/null +++ b/lib/client/db/postgres/repl/commands.go @@ -0,0 +1,110 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "fmt" + "strings" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/asciitable" +) + +// processCommand receives a command call and return the reply and if the +// command terminates the session. +func (r *REPL) processCommand(line string) (string, bool) { + cmdStr, args, _ := strings.Cut(strings.TrimPrefix(line, commandPrefix), " ") + cmd, ok := r.commands[cmdStr] + if !ok { + return "Unknown command. Try \\? to show the list of supported commands." + lineBreak, false + } + + return cmd.ExecFunc(r, args) +} + +// commandType specify the command category. +type commandType string + +const ( + // commandTypeGeneral represents a general-purpose command type. + commandTypeGeneral commandType = "General" + // commandTypeInformational represents a command type used for informational + // purposes. + commandTypeInformational = "Informational" + // commandTypeConnection represents a command type related to connection + // operations. + commandTypeConnection = "Connection" +) + +// command represents a command that can be executed in the REPL. +type command struct { + // Type specifies the type of the command. + Type commandType + // Description provides a user-friendly explanation of what the command + // does. + Description string + // ExecFunc is the function to execute the command. + ExecFunc func(*REPL, string) (string, bool) +} + +func initCommands() map[string]*command { + return map[string]*command{ + "q": { + Type: commandTypeGeneral, + Description: "Terminates the session.", + ExecFunc: func(_ *REPL, _ string) (string, bool) { return "", true }, + }, + "teleport": { + Type: commandTypeGeneral, + Description: "Show Teleport interactive shell information, such as execution limitations.", + ExecFunc: func(_ *REPL, _ string) (string, bool) { + return fmt.Sprintf("Teleport PostgreSQL interactive shell (v%s)", teleport.Version), false + }, + }, + "?": { + Type: commandTypeGeneral, + Description: "Show the list of supported commands.", + ExecFunc: func(r *REPL, _ string) (string, bool) { + typesTable := make(map[commandType]*asciitable.Table) + for cmdStr, cmd := range r.commands { + if _, ok := typesTable[cmd.Type]; !ok { + table := asciitable.MakeHeadlessTable(2) + typesTable[cmd.Type] = &table + } + + typesTable[cmd.Type].AddRow([]string{cmdStr, cmd.Description}) + } + + var res strings.Builder + for cmdType, output := range typesTable { + res.WriteString(string(cmdType)) + output.AsBuffer().WriteTo(&res) + res.WriteString(lineBreak) + } + + return res.String(), false + }, + }, + "session": { + Type: commandTypeConnection, + Description: "Display information about the current session, like user, roles, and database instance.", + ExecFunc: func(r *REPL, _ string) (string, bool) { + return fmt.Sprintf("Connected to %q instance at %q database as %q user.", r.route.ServiceName, r.route.Database, r.route.Username), false + }, + }, + } +} diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go new file mode 100644 index 0000000000000..a48d9dc970904 --- /dev/null +++ b/lib/client/db/postgres/repl/commands_test.go @@ -0,0 +1,146 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + clientproto "github.com/gravitational/teleport/api/client/proto" +) + +func TestCommandExecution(t *testing.T) { + ctx := context.Background() + + for name, tt := range map[string]struct { + line string + commandResult string + expectedArgs string + expectUnknown bool + commandExit bool + }{ + "execute": {line: "\\test", commandResult: "test"}, + "execute with additional arguments": {line: "\\test a b", commandResult: "test", expectedArgs: "a b"}, + "execute with exit": {line: "\\test", commandExit: true}, + "execute with leading and trailing whitespace": {line: " \\test ", commandResult: "test"}, + "unknown command with semicolon": {line: "\\test;", expectUnknown: true}, + "unknown command": {line: "\\wrong", expectUnknown: true}, + "with special characters": {line: "\\special_chars_!@#$%^&*()}", expectUnknown: true}, + "empty command": {line: "\\", expectUnknown: true}, + } { + t.Run(name, func(t *testing.T) { + commandArgsChan := make(chan string) + instance, tc := StartWithServer(t, ctx) + + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + // Reset available commands and add a test command so we can assert + // the command execution flow without relying in commands + // implementation or test server capabilities. + instance.commands = map[string]*command{ + "test": { + ExecFunc: func(r *REPL, args string) (string, bool) { + commandArgsChan <- args + return tt.commandResult, tt.commandExit + }, + }, + } + + writeLine(t, tc, tt.line) + if tt.expectUnknown { + reply := readUntilNextLead(t, tc) + require.True(t, strings.HasPrefix(strings.ToLower(reply), "unknown command")) + return + } + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + args := <-commandArgsChan + require.Equal(t, tt.expectedArgs, args) + }, time.Second, time.Millisecond) + + // When the command exits, the REPL and the connections will be + // closed. + if tt.commandExit { + require.EventuallyWithT(t, func(collect *assert.CollectT) { + var buf []byte + _, err := tc.conn.Read(buf[0:]) + require.ErrorIs(t, err, io.EOF) + }, time.Second, time.Millisecond) + return + } + + reply := readUntilNextLead(t, tc) + require.Equal(t, tt.commandResult, reply) + }) + } +} + +func TestCommands(t *testing.T) { + availableCmds := initCommands() + for cmdName, tc := range map[string]struct { + repl *REPL + args string + expectExit bool + assertCommandReply require.ValueAssertionFunc + }{ + "q": {expectExit: true}, + "teleport": { + assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.Contains(t, val, teleport.Version, "expected \\teleport command to include current Teleport version") + }, + }, + "?": { + repl: &REPL{commands: availableCmds}, + assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { + for cmd := range availableCmds { + require.Contains(t, val, cmd, "expected \\? command to include information about \\%s", cmd) + } + }, + }, + "session": { + repl: &REPL{route: clientproto.RouteToDatabase{ + ServiceName: "service", + Username: "username", + Database: "database", + }}, + assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.Contains(t, val, "service", "expected \\session command to contain service name") + require.Contains(t, val, "username", "expected \\session command to contain username") + require.Contains(t, val, "database", "expected \\session command to contain database name") + }, + }, + } { + t.Run(cmdName, func(t *testing.T) { + cmd, ok := availableCmds[cmdName] + require.True(t, ok, "expected command %q to be available at commands", cmdName) + reply, exit := cmd.ExecFunc(tc.repl, tc.args) + if tc.expectExit { + require.True(t, exit, "expected command to exit the REPL") + return + } + tc.assertCommandReply(t, reply) + }) + } +} diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go new file mode 100644 index 0000000000000..2e597558a82ba --- /dev/null +++ b/lib/client/db/postgres/repl/repl.go @@ -0,0 +1,258 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "fmt" + "io" + "net" + "strings" + + "github.com/gravitational/trace" + "github.com/jackc/pgconn" + "golang.org/x/term" + + "github.com/gravitational/teleport" + clientproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/asciitable" +) + +type REPL struct { + ctx context.Context + cancelFunc context.CancelCauseFunc + conn *pgconn.PgConn + client io.ReadWriter + serverConn net.Conn + route clientproto.RouteToDatabase + term *term.Terminal + commands map[string]*command +} + +func Start(ctx context.Context, client io.ReadWriter, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { + config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s@%s/%s", route.Username, hostnamePlaceholder, route.Database)) + if err != nil { + return nil, trace.Wrap(err) + } + config.TLSConfig = nil + + // Provide a lookup function to avoid having the hostname placeholder to + // resolve into something else. Note that the returned value won't be used. + config.LookupFunc = func(_ context.Context, _ string) ([]string, error) { + return []string{hostnamePlaceholder}, nil + } + config.DialFunc = func(_ context.Context, _, _ string) (net.Conn, error) { + return serverConn, nil + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + if err != nil { + return nil, trace.Wrap(err) + } + + replCtx, cancelFunc := context.WithCancelCause(ctx) + r := &REPL{ + ctx: replCtx, + cancelFunc: cancelFunc, + conn: pgConn, + client: client, + serverConn: serverConn, + route: route, + term: term.NewTerminal(client, ""), + commands: initCommands(), + } + + go r.start() + return r, nil +} + +func (r *REPL) Close() { + r.close(nil) +} + +func (r *REPL) close(err error) { + r.cancelFunc(err) + r.conn.Close(r.ctx) + r.serverConn.Close() +} + +func (r *REPL) Wait(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-r.ctx.Done(): + return r.ctx.Err() + } +} + +func (r *REPL) start() { + if err := r.presentBanner(); err != nil { + r.close(err) + return + } + + // After loop is done, we always close the REPL to ensure the connections + // are cleaned. + r.close(r.loop()) +} + +// loop implements the main REPL loop. +func (r *REPL) loop() error { + var ( + multilineAcc strings.Builder + readingMultiline bool + ) + + lead := lineLeading(r.route) + leadSpacing := strings.Repeat(" ", len(lead)) + r.term.SetPrompt(lineBreak + lead) + + for { + line, err := r.term.ReadLine() + if err != nil { + return trace.Wrap(err) + } + // ReadLine should always return the line without trailing line breaks, + // but we still require to remove trailing and leading spaces. + line = strings.TrimSpace(line) + + var reply string + switch { + case strings.HasPrefix(line, commandPrefix) && !readingMultiline: + var exit bool + reply, exit = r.processCommand(line) + if exit { + return nil + } + case strings.HasSuffix(line, executionRequestSuffix): + var query string + if readingMultiline { + multilineAcc.WriteString(lineBreak + line) + query = multilineAcc.String() + } else { + query = line + } + + // Reset multiline state. + multilineAcc.Reset() + readingMultiline = false + r.term.SetPrompt(lineBreak + lead) + + reply = formatResult(r.conn.Exec(r.ctx, query).ReadAll()) + lineBreak + default: + // If there wasn't a specific execution, we assume the input is + // multi-line. In this case, we need to accumulate the contents. + + // If this isn't the first line, add the line break as the + // ReadLine function removes it. + if readingMultiline { + multilineAcc.WriteString(lineBreak) + } + + readingMultiline = true + multilineAcc.WriteString(line) + r.term.SetPrompt(leadSpacing) + } + + if reply == "" { + continue + } + + if _, err := r.term.Write([]byte(reply)); err != nil { + return trace.Wrap(err) + } + } +} + +func (r *REPL) presentBanner() error { + _, err := fmt.Fprintf( + r.term, + `Teleport PostgreSQL interactive shell (v%s) +Connected to %q instance as %q user. +Type \? for help.`, + teleport.Version, + r.route.GetServiceName(), + r.route.GetUsername()) + return trace.Wrap(err) +} + +// formatResult formats a pgconn.Exec result. +func formatResult(results []*pgconn.Result, err error) string { + if err != nil { + return errorReplyPrefix + err.Error() + } + + var sb strings.Builder + for _, res := range results { + if !res.CommandTag.Select() { + return res.CommandTag.String() + } + + // build columns + var columns []string + for _, fd := range res.FieldDescriptions { + columns = append(columns, string(fd.Name)) + } + + table := asciitable.MakeTable(columns) + for _, row := range res.Rows { + rowData := make([]string, len(columns)) + for i, data := range row { + // The PostgreSQL package is responsible for transforming the + // row data into a readable format. + rowData[i] = string(data) + } + + table.AddRow(rowData) + } + + table.AsBuffer().WriteTo(&sb) + sb.WriteString(rowsText(len(res.Rows))) + } + + return sb.String() +} + +func lineLeading(route clientproto.RouteToDatabase) string { + return fmt.Sprintf("%s=> ", route.Database) +} + +func rowsText(count int) string { + rowTxt := "row" + if count > 1 { + rowTxt = "rows" + } + + return fmt.Sprintf("(%d %s)", count, rowTxt) +} + +const ( + // hostnamePlaceholder is the hostname used when connecting to the database. + // The pgconn functions require a hostname, however, since we already have + // the connection, we just need to provide a name to suppress this + // requirement. + hostnamePlaceholder = "repl" + // lineBreak represents a line break on the REPL. + lineBreak = "\r\n" + // commandPrefix is the prefix that identifies a REPL command. + commandPrefix = "\\" + // executionRequestSuffix is the suffix that indicates the input must be + // executed. + executionRequestSuffix = ";" + // errorReplyPrefix is the prefix presented when there is a execution error. + errorReplyPrefix = "ERR " +) diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go new file mode 100644 index 0000000000000..e99ec53431fb0 --- /dev/null +++ b/lib/client/db/postgres/repl/repl_test.go @@ -0,0 +1,369 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "errors" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + clientproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/client/db/postgres/repl/testdata" + "github.com/gravitational/teleport/lib/utils/golden" +) + +func TestStart(t *testing.T) { + ctx := context.Background() + _, tc := StartWithServer(t, ctx) + + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + writeLine(t, tc, singleRowQuery) + singleRowQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "single", []byte(singleRowQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "single")), singleRowQueryResult) + + writeLine(t, tc, multiRowQuery) + multiRowQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "multi", []byte(multiRowQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "multi")), multiRowQueryResult) + + writeLine(t, tc, errorQuery) + errorQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "err", []byte(errorQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "err")), errorQueryResult) + + writeLine(t, tc, dataTypesQuery) + dataTypeQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "data_type", []byte(dataTypeQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "data_type")), dataTypeQueryResult) +} + +// TestQuery given some input lines, the REPL should execute the expected +// query on the PostgreSQL test server. +func TestQuery(t *testing.T) { + ctx := context.Background() + _, tc := StartWithServer(t, ctx, WithCustomQueries()) + + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + for name, tt := range map[string]struct { + lines []string + expectedQuery string + }{ + "query": {lines: []string{"SELECT 1;"}, expectedQuery: "SELECT 1;"}, + "multiline query": {lines: []string{"SELECT", "1", ";"}, expectedQuery: "SELECT\r\n1\r\n;"}, + "malformatted": {lines: []string{"SELECT err;"}, expectedQuery: "SELECT err;"}, + "query with special characters": {lines: []string{"SELECT 'special_chars_!@#$%^&*()';"}, expectedQuery: "SELECT 'special_chars_!@#$%^&*()';"}, + "leading and trailing whitespace": {lines: []string{" SELECT 1; "}, expectedQuery: "SELECT 1;"}, + "multiline with excessive whitespace": {lines: []string{" SELECT", " 1", " ;"}, expectedQuery: "SELECT\r\n1\r\n;"}, + // Commands should only be executed if they are at the beginning of the + // first line. + "with command in the middle": {lines: []string{"SELECT \\d 1;"}, expectedQuery: "SELECT \\d 1;"}, + "multiline with command in the middle": {lines: []string{"SELECT", "\\d", ";"}, expectedQuery: "SELECT\r\n\\d\r\n;"}, + "multiline with command in the last line": {lines: []string{"SELECT", "1", "\\d;"}, expectedQuery: "SELECT\r\n1\r\n\\d;"}, + } { + t.Run(name, func(t *testing.T) { + for _, line := range tt.lines { + writeLine(t, tc, line) + } + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + query := <-tc.QueryChan() + require.Equal(t, tt.expectedQuery, query) + }, time.Second, time.Millisecond) + + // Always expect a query reply from the server. + _ = readUntilNextLead(t, tc) + }) + } +} + +func writeLine(t *testing.T, c *testCtx, line string) { + t.Helper() + data := []byte(line + lineBreak) + + // When writing to the connection, the terminal emulator always writes back. + // If we don't consume those bytes, it will block the ReadLine call (as + // we're net.Pipe). + go func(conn net.Conn) { + buf := make([]byte, len(data)) + for { + n, err := conn.Read(buf[0:]) + if err != nil { + t.Logf("Error while terminal reply on write: %s", err) + break + } + + if string(buf[:n]) == line+lineBreak { + break + } + } + }(c.conn) + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + _, err := c.conn.Write(data) + require.NoError(t, err) + }, 5*time.Second, time.Millisecond) +} + +// readUntilNextLead reads the contents from the client connection until we +// reach the next leading prompt. +func readUntilNextLead(t *testing.T, c *testCtx) string { + t.Helper() + + var acc strings.Builder + for { + line := readLine(t, c) + if strings.HasPrefix(line, lineBreak+lineLeading(c.route)) { + break + } + + acc.WriteString(line) + } + return acc.String() +} + +func readLine(t *testing.T, c *testCtx) string { + t.Helper() + + var n int + buf := make([]byte, 1024) + require.EventuallyWithT(t, func(collect *assert.CollectT) { + var err error + n, err = c.conn.Read(buf[0:]) + require.NoError(t, err) + require.Greater(t, n, 0) + }, 5*time.Second, time.Millisecond) + return string(buf[:n]) +} + +type testCtx struct { + ctx context.Context + cancelFunc context.CancelFunc + + // conn is the connection used by tests to read/write from/to the REPL. + conn net.Conn + // clientConn is the connection passed to the REPL. + clientConn net.Conn + // serverConn is the fake database server connection (that works as a + // PostgreSQL instance). + serverConn net.Conn + + route clientproto.RouteToDatabase + pgClient *pgproto3.Backend + errChan chan error + // handleCustomQueries when set to true the PostgreSQL test server will + // accept any query sent and reply with success. + handleCustomQueries bool + // queryChan when HandleCustomQueries == true the queries received by the + // test server will be sent to this channel. + queryChan chan string +} + +// testCtxOption represents a testCtx option. +type testCtxOption func(*testCtx) + +// WithCustomQueries enables sending custom queries to the PostgreSQL test +// server. Note that when it is enabled, callers must consume the queries on the +// query channel. +func WithCustomQueries() testCtxOption { + return func(tc *testCtx) { + tc.handleCustomQueries = true + } +} + +// StartWithServer starts a REPL instance with a PostgreSQL test server capable +// of receiving and replying to queries. +func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) (*REPL, *testCtx) { + t.Helper() + + conn, clientConn := net.Pipe() + serverConn, pgConn := net.Pipe() + client := pgproto3.NewBackend(pgproto3.NewChunkReader(pgConn), pgConn) + ctx, cancelFunc := context.WithCancel(ctx) + tc := &testCtx{ + ctx: ctx, + cancelFunc: cancelFunc, + conn: conn, + clientConn: clientConn, + serverConn: serverConn, + pgClient: client, + errChan: make(chan error, 1), + queryChan: make(chan string), + } + + for _, opt := range opts { + opt(tc) + } + + t.Cleanup(func() { + tc.close() + require.EventuallyWithT(t, func(collect *assert.CollectT) { + err, _ := <-tc.errChan + assert.NoError(t, err) + }, time.Second, time.Millisecond) + }) + + go func(c *testCtx) { + defer close(c.errChan) + if err := c.processMessages(); err != nil { + c.errChan <- err + } + }(tc) + + r, err := Start(ctx, tc.clientConn, tc.serverConn, tc.route) + require.NoError(t, err) + return r, tc +} + +func (tc *testCtx) QueryChan() chan string { + return tc.queryChan +} + +func (tc *testCtx) close() { + tc.serverConn.Close() + tc.clientConn.Close() +} + +func (tc *testCtx) processMessages() error { + defer tc.close() + + startupMessage, err := tc.pgClient.ReceiveStartupMessage() + if err != nil { + return trace.Wrap(err) + } + + switch msg := startupMessage.(type) { + case *pgproto3.StartupMessage: + // Accept auth and send ready for query. + if err := tc.pgClient.Send(&pgproto3.AuthenticationOk{}); err != nil { + return trace.Wrap(err) + } + + // Values on the backend key data are not relavant since we don't + // support canceling requests. + err := tc.pgClient.Send(&pgproto3.BackendKeyData{ + ProcessID: 0, + SecretKey: 123, + }) + if err != nil { + return trace.Wrap(err) + } + + if err := tc.pgClient.Send(&pgproto3.ReadyForQuery{}); err != nil { + return trace.Wrap(err) + } + default: + return trace.BadParameter("expected *pgproto3.StartupMessage, got: %T", msg) + } + + for { + message, err := tc.pgClient.Receive() + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return nil + } + + return trace.Wrap(err) + } + + var messages []pgproto3.BackendMessage + switch msg := message.(type) { + case *pgproto3.Query: + if tc.handleCustomQueries { + select { + case tc.queryChan <- msg.String: + messages = []pgproto3.BackendMessage{ + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("INSERT 0 1")}, + &pgproto3.ReadyForQuery{}, + } + case <-tc.ctx.Done(): + return trace.Wrap(tc.ctx.Err()) + } + + break // breaks the message switch case. + } + + switch msg.String { + case singleRowQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("id")}, {Name: []byte("email")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1"), []byte("alice@example.com")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, + } + case multiRowQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("id")}, {Name: []byte("email")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1"), []byte("alice@example.com")}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("2"), []byte("bob@example.com")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, + } + case dataTypesQuery: + messages = testdata.TestDataQueryResult + case errorQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.ErrorResponse{Severity: "ERROR", Code: "42703", Message: "error"}, + &pgproto3.ReadyForQuery{}, + } + default: + return trace.BadParameter("unsupported query %q", msg.String) + + } + case *pgproto3.Terminate: + return nil + default: + return trace.BadParameter("unsupported message %#v", message) + } + + for _, message := range messages { + err := tc.pgClient.Send(message) + if err != nil { + return trace.Wrap(err) + } + } + } +} + +const ( + singleRowQuery = "SELECT * FROM users LIMIT 1;" + multiRowQuery = "SELECT * FROM users;" + dataTypesQuery = "SELECT * FROM test_data_types;" + errorQuery = "SELECT err;" +) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/data_type.golden b/lib/client/db/postgres/repl/testdata/TestStart/data_type.golden new file mode 100644 index 0000000000000..725af38776034 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/data_type.golden @@ -0,0 +1,4 @@ +serial_col int_col smallint_col bigint_col decimal_col numeric_col real_col double_col smallserial_col bigserial_col char_col varchar_col text_col boolean_col date_col time_col timetz_col timestamp_col timestamptz_col interval_col uuid_col json_col jsonb_col xml_col bytea_col inet_col cidr_col macaddr_col point_col line_col lseg_col box_col path_col polygon_col circle_col tsquery_col tsvector_col +---------- ------- ------------ ------------------- ----------- ----------- -------- ----------------- --------------- ------------- ---------- ------------------- ---------------- ----------- ---------- -------- ----------- ------------------- ---------------------- ----------------------------- ------------------------------------ ---------------- ---------------- --------------------------------------- ------------------------ ----------- -------------- ----------------- --------- -------- ------------- ----------- ------------------- ------------------- ---------- ------------- -------------------------------------------------- +1 42 32767 9223372036854775807 12345.67 98765.43210 3.14 2.718281828459045 1 1 A Sample varchar text Sample text data t 2024-11-29 12:34:56 12:34:56+03 2024-11-29 12:34:56 2024-11-29 09:34:56+00 1 year 2 mons 3 days 04:05:06 550e8400-e29b-41d4-a716-446655440000 {"key": "value"} {"key": "value"} XML content \x48656c6c6f20576f726c64 192.168.1.1 192.168.1.0/24 08:00:2b:01:02:03 (1,2) {1,-1,0} [(0,0),(1,1)] (1,1),(0,0) ((0,0),(1,1),(2,2)) ((0,0),(1,1),(1,0)) <(0,0),1> 'fat' & 'rat' 'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat' +(1 row) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/err.golden b/lib/client/db/postgres/repl/testdata/TestStart/err.golden new file mode 100644 index 0000000000000..1dd89d57178c7 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/err.golden @@ -0,0 +1 @@ +ERR ERROR: error (SQLSTATE 42703) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/multi.golden b/lib/client/db/postgres/repl/testdata/TestStart/multi.golden new file mode 100644 index 0000000000000..43b92f3157fbb --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/multi.golden @@ -0,0 +1,5 @@ +id email +-- ----------------- +1 alice@example.com +2 bob@example.com +(2 rows) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/single.golden b/lib/client/db/postgres/repl/testdata/TestStart/single.golden new file mode 100644 index 0000000000000..c6ac2ed5ce793 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/single.golden @@ -0,0 +1,4 @@ +id email +-- ----------------- +1 alice@example.com +(1 row) diff --git a/lib/client/db/postgres/repl/testdata/query.go b/lib/client/db/postgres/repl/testdata/query.go new file mode 100644 index 0000000000000..6789b128c1f37 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/query.go @@ -0,0 +1,111 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package testdata + +import ( + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" +) + +// Contains a query result with the most common fields in PostgreSQL. +// This can be used to understand how the REPL deals with different data types. +// +// Sampled from https://github.com/postgres/postgres/blob/b6612aedc53a6bf069eba5e356a8421ad6426486/src/include/catalog/pg_type.dat +// PostgreSQL version 17.2 +var TestDataQueryResult = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + // TableOID and TableAttributeNumber values omitted. + {Name: []byte("serial_col"), DataTypeOID: 23, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("int_col"), DataTypeOID: 23, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("smallint_col"), DataTypeOID: 21, DataTypeSize: 2, TypeModifier: -1, Format: 0}, + {Name: []byte("bigint_col"), DataTypeOID: 20, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("decimal_col"), DataTypeOID: 1700, DataTypeSize: -1, TypeModifier: 655366, Format: 0}, + {Name: []byte("numeric_col"), DataTypeOID: 1700, DataTypeSize: -1, TypeModifier: 983049, Format: 0}, + {Name: []byte("real_col"), DataTypeOID: 700, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("double_col"), DataTypeOID: 701, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("smallserial_col"), DataTypeOID: 21, DataTypeSize: 2, TypeModifier: -1, Format: 0}, + {Name: []byte("bigserial_col"), DataTypeOID: 20, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("char_col"), DataTypeOID: 1042, DataTypeSize: -1, TypeModifier: 14, Format: 0}, + {Name: []byte("varchar_col"), DataTypeOID: 1043, DataTypeSize: -1, TypeModifier: 54, Format: 0}, + {Name: []byte("text_col"), DataTypeOID: 25, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("boolean_col"), DataTypeOID: 16, DataTypeSize: 1, TypeModifier: -1, Format: 0}, + {Name: []byte("date_col"), DataTypeOID: 1082, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("time_col"), DataTypeOID: 1083, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("timetz_col"), DataTypeOID: 1266, DataTypeSize: 12, TypeModifier: -1, Format: 0}, + {Name: []byte("timestamp_col"), DataTypeOID: 1114, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("timestamptz_col"), DataTypeOID: 1184, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("interval_col"), DataTypeOID: 1186, DataTypeSize: 16, TypeModifier: -1, Format: 0}, + {Name: []byte("uuid_col"), DataTypeOID: 2950, DataTypeSize: 16, TypeModifier: -1, Format: 0}, + {Name: []byte("json_col"), DataTypeOID: 114, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("jsonb_col"), DataTypeOID: 3802, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("xml_col"), DataTypeOID: 142, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("bytea_col"), DataTypeOID: 17, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("inet_col"), DataTypeOID: 869, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("cidr_col"), DataTypeOID: 650, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("macaddr_col"), DataTypeOID: 829, DataTypeSize: 6, TypeModifier: -1, Format: 0}, + {Name: []byte("point_col"), DataTypeOID: 600, DataTypeSize: 16, TypeModifier: -1, Format: 0}, + {Name: []byte("line_col"), DataTypeOID: 628, DataTypeSize: 24, TypeModifier: -1, Format: 0}, + {Name: []byte("lseg_col"), DataTypeOID: 601, DataTypeSize: 32, TypeModifier: -1, Format: 0}, + {Name: []byte("box_col"), DataTypeOID: 603, DataTypeSize: 32, TypeModifier: -1, Format: 0}, + {Name: []byte("path_col"), DataTypeOID: 602, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("polygon_col"), DataTypeOID: 604, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("circle_col"), DataTypeOID: 718, DataTypeSize: 24, TypeModifier: -1, Format: 0}, + {Name: []byte("tsquery_col"), DataTypeOID: 3615, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("tsvector_col"), DataTypeOID: 3614, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + }}, + &pgproto3.DataRow{Values: [][]byte{ + []byte("1"), + []byte("42"), + []byte("32767"), + []byte("9223372036854775807"), + []byte("12345.67"), + []byte("98765.43210"), + []byte("3.14"), + []byte("2.718281828459045"), + []byte("1"), + []byte("1"), + []byte("A "), + []byte("Sample varchar text"), + []byte("Sample text data"), + []byte("t"), + []byte("2024-11-29"), + []byte("12:34:56"), + []byte("12:34:56+03"), + []byte("2024-11-29 12:34:56"), + []byte("2024-11-29 09:34:56+00"), + []byte("1 year 2 mons 3 days 04:05:06"), + []byte("550e8400-e29b-41d4-a716-446655440000"), + []byte("{\"key\": \"value\"}"), + []byte("{\"key\": \"value\"}"), + []byte("XML content"), + []byte("\\x48656c6c6f20576f726c64"), + []byte("192.168.1.1"), + []byte("192.168.1.0/24"), + []byte("08:00:2b:01:02:03"), + []byte("(1,2)"), + []byte("{1,-1,0}"), + []byte("[(0,0),(1,1)]"), + []byte("(1,1),(0,0)"), + []byte("((0,0),(1,1),(2,2))"), + []byte("((0,0),(1,1),(1,0))"), + []byte("<(0,0),1>"), + []byte("'fat' & 'rat'"), + []byte("'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat'"), + }}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, +} From bb7d4d6ec76987bc9d370d8b53354d64989edc02 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 5 Dec 2024 00:06:51 -0300 Subject: [PATCH 02/10] refactor(repl): change repl to use a single Run function --- lib/client/db/postgres/repl/commands_test.go | 25 ++++- lib/client/db/postgres/repl/repl.go | 96 +++++++++----------- lib/client/db/postgres/repl/repl_test.go | 60 +++++++++--- 3 files changed, 116 insertions(+), 65 deletions(-) diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go index a48d9dc970904..c3c8f69e42535 100644 --- a/lib/client/db/postgres/repl/commands_test.go +++ b/lib/client/db/postgres/repl/commands_test.go @@ -51,7 +51,14 @@ func TestCommandExecution(t *testing.T) { } { t.Run(name, func(t *testing.T) { commandArgsChan := make(chan string) - instance, tc := StartWithServer(t, ctx) + instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + runErrChan := make(chan error) + go func() { + runErrChan <- instance.Run(ctx) + }() // Consume the REPL banner. _ = readUntilNextLead(t, tc) @@ -88,11 +95,27 @@ func TestCommandExecution(t *testing.T) { _, err := tc.conn.Read(buf[0:]) require.ErrorIs(t, err, io.EOF) }, time.Second, time.Millisecond) + + select { + case err := <-runErrChan: + require.NoError(t, err, "expected the REPL instance exit gracefully") + case <-time.After(time.Second): + require.Fail(t, "expected REPL run to terminate but got nothing") + } return } reply := readUntilNextLead(t, tc) require.Equal(t, tt.commandResult, reply) + + // Terminate the REPL run session and wait for the Run results. + cancel() + select { + case err := <-runErrChan: + require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation") + case <-time.After(time.Second): + require.Fail(t, "expected REPL run to terminate but got nothing") + } }) } } diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index 2e597558a82ba..d1bd7b263cd2a 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -18,6 +18,7 @@ package repl import ( "context" + "errors" "fmt" "io" "net" @@ -33,17 +34,15 @@ import ( ) type REPL struct { - ctx context.Context - cancelFunc context.CancelCauseFunc - conn *pgconn.PgConn - client io.ReadWriter + connConfig *pgconn.Config + client io.ReadWriteCloser serverConn net.Conn route clientproto.RouteToDatabase term *term.Terminal commands map[string]*command } -func Start(ctx context.Context, client io.ReadWriter, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { +func New(ctx context.Context, client io.ReadWriteCloser, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s@%s/%s", route.Username, hostnamePlaceholder, route.Database)) if err != nil { return nil, trace.Wrap(err) @@ -59,59 +58,40 @@ func Start(ctx context.Context, client io.ReadWriter, serverConn net.Conn, route return serverConn, nil } - pgConn, err := pgconn.ConnectConfig(ctx, config) - if err != nil { - return nil, trace.Wrap(err) - } - - replCtx, cancelFunc := context.WithCancelCause(ctx) - r := &REPL{ - ctx: replCtx, - cancelFunc: cancelFunc, - conn: pgConn, + return &REPL{ + connConfig: config, client: client, serverConn: serverConn, route: route, term: term.NewTerminal(client, ""), commands: initCommands(), - } - - go r.start() - return r, nil -} - -func (r *REPL) Close() { - r.close(nil) -} - -func (r *REPL) close(err error) { - r.cancelFunc(err) - r.conn.Close(r.ctx) - r.serverConn.Close() -} - -func (r *REPL) Wait(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case <-r.ctx.Done(): - return r.ctx.Err() - } + }, nil } -func (r *REPL) start() { - if err := r.presentBanner(); err != nil { - r.close(err) - return +// Run starts and run the PostgreSQL REPL session. The provided context is used +// to interrupt the execution and clean up resources. +func (r *REPL) Run(ctx context.Context) error { + pgConn, err := pgconn.ConnectConfig(ctx, r.connConfig) + if err != nil { + return trace.Wrap(err) } + defer pgConn.Close(context.TODO()) + + // term.Terminal blocks reads/writes without respecting the context. The + // only thing that unblocks it is closing the underlaying connection (in + // our case r.client). On this goroutine we only watch for context + // cancelation and close the connection. This will unblocks all terminal + // reads/writes. + ctxCancelCh := make(chan struct{}) + defer close(ctxCancelCh) + go func() { + select { + case <-ctx.Done(): + _ = r.client.Close() + case <-ctxCancelCh: + } + }() - // After loop is done, we always close the REPL to ensure the connections - // are cleaned. - r.close(r.loop()) -} - -// loop implements the main REPL loop. -func (r *REPL) loop() error { var ( multilineAcc strings.Builder readingMultiline bool @@ -124,8 +104,9 @@ func (r *REPL) loop() error { for { line, err := r.term.ReadLine() if err != nil { - return trace.Wrap(err) + return trace.Wrap(formatTermError(ctx, err)) } + // ReadLine should always return the line without trailing line breaks, // but we still require to remove trailing and leading spaces. line = strings.TrimSpace(line) @@ -152,7 +133,7 @@ func (r *REPL) loop() error { readingMultiline = false r.term.SetPrompt(lineBreak + lead) - reply = formatResult(r.conn.Exec(r.ctx, query).ReadAll()) + lineBreak + reply = formatResult(pgConn.Exec(ctx, query).ReadAll()) + lineBreak default: // If there wasn't a specific execution, we assume the input is // multi-line. In this case, we need to accumulate the contents. @@ -173,11 +154,22 @@ func (r *REPL) loop() error { } if _, err := r.term.Write([]byte(reply)); err != nil { - return trace.Wrap(err) + return trace.Wrap(formatTermError(ctx, err)) } } } +// formatTermError changes the term.Terminal error to match caller expectations. +func formatTermError(ctx context.Context, err error) error { + // When context is canceled it will immediatly lead read/write errors due + // to the closed connection. For this cases we return the context error. + if ctx.Err() != nil && (errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed)) { + return ctx.Err() + } + + return err +} + func (r *REPL) presentBanner() error { _, err := fmt.Fprintf( r.term, diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index e99ec53431fb0..e28de99ea3603 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -173,6 +173,7 @@ func readLine(t *testing.T, c *testCtx) string { } type testCtx struct { + cfg *testCtxConfig ctx context.Context cancelFunc context.CancelFunc @@ -187,23 +188,35 @@ type testCtx struct { route clientproto.RouteToDatabase pgClient *pgproto3.Backend errChan chan error + // queryChan handling custom queries is enabled the queries received by the + // test server will be sent to this channel. + queryChan chan string +} + +type testCtxConfig struct { + // skipREPLRun when set to true the REPL instance won't be executed. + skipREPLRun bool // handleCustomQueries when set to true the PostgreSQL test server will // accept any query sent and reply with success. handleCustomQueries bool - // queryChan when HandleCustomQueries == true the queries received by the - // test server will be sent to this channel. - queryChan chan string } // testCtxOption represents a testCtx option. -type testCtxOption func(*testCtx) +type testCtxOption func(*testCtxConfig) // WithCustomQueries enables sending custom queries to the PostgreSQL test // server. Note that when it is enabled, callers must consume the queries on the // query channel. func WithCustomQueries() testCtxOption { - return func(tc *testCtx) { - tc.handleCustomQueries = true + return func(cfg *testCtxConfig) { + cfg.handleCustomQueries = true + } +} + +// WithSkipREPLRun disables automatically running the REPL instance. +func WithSkipREPLRun() testCtxOption { + return func(cfg *testCtxConfig) { + cfg.skipREPLRun = true } } @@ -212,11 +225,17 @@ func WithCustomQueries() testCtxOption { func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) (*REPL, *testCtx) { t.Helper() + cfg := &testCtxConfig{} + for _, opt := range opts { + opt(cfg) + } + conn, clientConn := net.Pipe() serverConn, pgConn := net.Pipe() client := pgproto3.NewBackend(pgproto3.NewChunkReader(pgConn), pgConn) ctx, cancelFunc := context.WithCancel(ctx) tc := &testCtx{ + cfg: cfg, ctx: ctx, cancelFunc: cancelFunc, conn: conn, @@ -227,10 +246,6 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( queryChan: make(chan string), } - for _, opt := range opts { - opt(tc) - } - t.Cleanup(func() { tc.close() require.EventuallyWithT(t, func(collect *assert.CollectT) { @@ -246,8 +261,29 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( } }(tc) - r, err := Start(ctx, tc.clientConn, tc.serverConn, tc.route) + r, err := New(ctx, tc.clientConn, tc.serverConn, tc.route) require.NoError(t, err) + + if !cfg.skipREPLRun { + // Start the REPL session and return to the caller a channel that will + // receive the execution result so it can assert REPL executions. + runCtx, cancelRun := context.WithCancel(ctx) + runErrChan := make(chan error, 1) + go func() { + runErrChan <- r.Run(runCtx) + }() + t.Cleanup(func() { + cancelRun() + + select { + case err := <-runErrChan: + require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation") + case <-time.After(10 * time.Second): + require.Fail(t, "timeout while waiting for REPL Run result") + } + }) + } + return r, tc } @@ -305,7 +341,7 @@ func (tc *testCtx) processMessages() error { var messages []pgproto3.BackendMessage switch msg := message.(type) { case *pgproto3.Query: - if tc.handleCustomQueries { + if tc.cfg.handleCustomQueries { select { case tc.queryChan <- msg.String: messages = []pgproto3.BackendMessage{ From 8dad87764c5955574cf8541bb7f5f716ecf1f2b3 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 5 Dec 2024 00:18:27 -0300 Subject: [PATCH 03/10] test(repl): reduce usage of require.Eventually blocks --- lib/client/db/postgres/repl/commands.go | 6 ++--- lib/client/db/postgres/repl/commands_test.go | 8 +++--- lib/client/db/postgres/repl/repl.go | 6 ++++- lib/client/db/postgres/repl/repl_test.go | 27 ++++++++++++++------ 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/lib/client/db/postgres/repl/commands.go b/lib/client/db/postgres/repl/commands.go index 4c4d8de28f39a..c63d0037ada2d 100644 --- a/lib/client/db/postgres/repl/commands.go +++ b/lib/client/db/postgres/repl/commands.go @@ -36,15 +36,13 @@ func (r *REPL) processCommand(line string) (string, bool) { return cmd.ExecFunc(r, args) } -// commandType specify the command category. +// commandType specify the command category. This is used to organize the +// commands, for example, when showing them in the help command. type commandType string const ( // commandTypeGeneral represents a general-purpose command type. commandTypeGeneral commandType = "General" - // commandTypeInformational represents a command type used for informational - // purposes. - commandTypeInformational = "Informational" // commandTypeConnection represents a command type related to connection // operations. commandTypeConnection = "Connection" diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go index c3c8f69e42535..3a9bb9ec6159a 100644 --- a/lib/client/db/postgres/repl/commands_test.go +++ b/lib/client/db/postgres/repl/commands_test.go @@ -82,10 +82,12 @@ func TestCommandExecution(t *testing.T) { return } - require.EventuallyWithT(t, func(collect *assert.CollectT) { - args := <-commandArgsChan + select { + case args := <-commandArgsChan: require.Equal(t, tt.expectedArgs, args) - }, time.Second, time.Millisecond) + case <-time.After(time.Second): + require.Fail(t, "expected to command args from test server but got nothing") + } // When the command exits, the REPL and the connections will be // closed. diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index d1bd7b263cd2a..eb1ddcaa705b0 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -92,6 +92,10 @@ func (r *REPL) Run(ctx context.Context) error { } }() + if err := r.presentBanner(); err != nil { + return trace.Wrap(err) + } + var ( multilineAcc strings.Builder readingMultiline bool @@ -161,7 +165,7 @@ func (r *REPL) Run(ctx context.Context) error { // formatTermError changes the term.Terminal error to match caller expectations. func formatTermError(ctx context.Context, err error) error { - // When context is canceled it will immediatly lead read/write errors due + // When context is canceled it will immediately lead read/write errors due // to the closed connection. For this cases we return the context error. if ctx.Err() != nil && (errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed)) { return ctx.Err() diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index e28de99ea3603..46a3f230dbb0c 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -102,10 +102,12 @@ func TestQuery(t *testing.T) { writeLine(t, tc, line) } - require.EventuallyWithT(t, func(collect *assert.CollectT) { - query := <-tc.QueryChan() + select { + case query := <-tc.QueryChan(): require.Equal(t, tt.expectedQuery, query) - }, time.Second, time.Millisecond) + case <-time.After(time.Second): + require.Fail(t, "expected to receive query but got nothing") + } // Always expect a query reply from the server. _ = readUntilNextLead(t, tc) @@ -135,10 +137,13 @@ func writeLine(t *testing.T, c *testCtx, line string) { } }(c.conn) + // Given that the test connections are piped a problem with the reader side + // would lead into blocking writing. To avoid this scenario we're using + // the Eventually just to ensure a timeout on writing into the connections. require.EventuallyWithT(t, func(collect *assert.CollectT) { _, err := c.conn.Write(data) require.NoError(t, err) - }, 5*time.Second, time.Millisecond) + }, 5*time.Second, time.Millisecond, "expected to write into the connection successfully") } // readUntilNextLead reads the contents from the client connection until we @@ -163,6 +168,9 @@ func readLine(t *testing.T, c *testCtx) string { var n int buf := make([]byte, 1024) + // Given that the test connections are piped a problem with the writer side + // would lead into blocking reading. To avoid this scenario we're using + // the Eventually just to ensure a timeout on reading from the connections. require.EventuallyWithT(t, func(collect *assert.CollectT) { var err error n, err = c.conn.Read(buf[0:]) @@ -248,10 +256,13 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( t.Cleanup(func() { tc.close() - require.EventuallyWithT(t, func(collect *assert.CollectT) { - err, _ := <-tc.errChan - assert.NoError(t, err) - }, time.Second, time.Millisecond) + + select { + case err := <-tc.errChan: + require.NoError(t, err) + case <-time.After(time.Second): + require.Fail(t, "expected to receive the test server close result but got nothing") + } }) go func(c *testCtx) { From 56e707162b5031de609c8e20b22ef62a596d8af8 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 5 Dec 2024 01:18:28 -0300 Subject: [PATCH 04/10] refactor(repl): code review suggestions --- lib/client/db/postgres/repl/commands.go | 21 ++++++++++-- lib/client/db/postgres/repl/commands_test.go | 2 +- lib/client/db/postgres/repl/repl.go | 36 +++++++++++++++++++- lib/client/db/postgres/repl/repl_test.go | 4 ++- 4 files changed, 57 insertions(+), 6 deletions(-) diff --git a/lib/client/db/postgres/repl/commands.go b/lib/client/db/postgres/repl/commands.go index c63d0037ada2d..5e82a2ef26b68 100644 --- a/lib/client/db/postgres/repl/commands.go +++ b/lib/client/db/postgres/repl/commands.go @@ -55,8 +55,10 @@ type command struct { // Description provides a user-friendly explanation of what the command // does. Description string - // ExecFunc is the function to execute the command. - ExecFunc func(*REPL, string) (string, bool) + // ExecFunc is the function to execute the command. The commands can either + // return a reply (that will be sent back to the client) as a string. Or + // it can terminates the REPL by returning bool on the second argument. + ExecFunc func(r *REPL, args string) (reply string, exit bool) } func initCommands() map[string]*command { @@ -70,7 +72,20 @@ func initCommands() map[string]*command { Type: commandTypeGeneral, Description: "Show Teleport interactive shell information, such as execution limitations.", ExecFunc: func(_ *REPL, _ string) (string, bool) { - return fmt.Sprintf("Teleport PostgreSQL interactive shell (v%s)", teleport.Version), false + // Formats limitiations in a dash list. Example: + // - hello + // multi line + // - another item + var limitations strings.Builder + for _, l := range descriptiveLimitations { + limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n ")) + } + + return fmt.Sprintf( + "Teleport PostgreSQL interactive shell (v%s)\n\nLimitations: \n%s", + teleport.Version, + limitations.String(), + ), false }, }, "?": { diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go index 3a9bb9ec6159a..16cc118b9a47a 100644 --- a/lib/client/db/postgres/repl/commands_test.go +++ b/lib/client/db/postgres/repl/commands_test.go @@ -50,7 +50,7 @@ func TestCommandExecution(t *testing.T) { "empty command": {line: "\\", expectUnknown: true}, } { t.Run(name, func(t *testing.T) { - commandArgsChan := make(chan string) + commandArgsChan := make(chan string, 1) instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index eb1ddcaa705b0..3adf1c377b2c3 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -31,6 +31,7 @@ import ( "github.com/gravitational/teleport" clientproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/lib/asciitable" + "github.com/gravitational/teleport/lib/defaults" ) type REPL struct { @@ -43,10 +44,16 @@ type REPL struct { } func New(ctx context.Context, client io.ReadWriteCloser, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { - config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s@%s/%s", route.Username, hostnamePlaceholder, route.Database)) + config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s", hostnamePlaceholder)) if err != nil { return nil, trace.Wrap(err) } + config.User = route.Username + config.Database = route.Database + config.ConnectTimeout = defaults.DatabaseConnectTimeout + config.RuntimeParams = map[string]string{ + applicationNameParamName: applicationNameParamValue, + } config.TLSConfig = nil // Provide a lookup function to avoid having the hostname placeholder to @@ -252,3 +259,30 @@ const ( // errorReplyPrefix is the prefix presented when there is a execution error. errorReplyPrefix = "ERR " ) + +const ( + // applicationNameParamName defines the application name parameter name. + // + // https://www.postgresql.org/docs/17/libpq-connect.html#LIBPQ-CONNECT-APPLICATION-NAME + applicationNameParamName = "application_name" + // applicationNameParamValue defines the application name parameter value. + applicationNameParamValue = "teleport-repl" +) + +// descriptiveLimitations defines a user-friendly text containing the REPL +// limitations. +var descriptiveLimitations = []string{ + `Query cancellation is not supported. Once a query is sent, its execution +cannot be canceled. Note that Teleport sends a terminate message to the database +when the database session terminates. This flow doesn't guarantee that any +running queries will be canceled. +See https://www.postgresql.org/docs/17/protocol-flow.html#PROTOCOL-FLOW-TERMINATION for more details on the termination flow.`, + // This limitation is due to our terminal emulator not fully supporting this + // shortcut's custom handler. Instead, it will close the terminal, leading + // to terminating the session. To avoid having users accidentally + // terminating their sessions, we're turning this off until we have a better + // solution and propose the behavior for it. + // + // This shortcut filtered out by the WebUI key handler. + "Pressing CTRL-C will have no effect on this shell.", +} diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index 46a3f230dbb0c..286c96e53b03d 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -85,7 +85,9 @@ func TestQuery(t *testing.T) { lines []string expectedQuery string }{ - "query": {lines: []string{"SELECT 1;"}, expectedQuery: "SELECT 1;"}, + "query": {lines: []string{"SELECT 1;"}, expectedQuery: "SELECT 1;"}, + "query multiple semicolons": {lines: []string{"SELECT 1; ;;"}, expectedQuery: "SELECT 1; ;;"}, + "query multiple semicolons with trailing space": {lines: []string{"SELECT 1; ;; "}, expectedQuery: "SELECT 1; ;;"}, "multiline query": {lines: []string{"SELECT", "1", ";"}, expectedQuery: "SELECT\r\n1\r\n;"}, "malformatted": {lines: []string{"SELECT err;"}, expectedQuery: "SELECT err;"}, "query with special characters": {lines: []string{"SELECT 'special_chars_!@#$%^&*()';"}, expectedQuery: "SELECT 'special_chars_!@#$%^&*()';"}, From 6d28777425d6471c76c9ca2eda8968adac689acf Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Fri, 6 Dec 2024 18:19:26 -0300 Subject: [PATCH 05/10] refactor(repl): code review suggestions --- lib/client/db/postgres/repl/commands.go | 6 +++--- lib/client/db/postgres/repl/commands_test.go | 4 ++-- lib/client/db/postgres/repl/repl.go | 4 ++-- lib/client/db/postgres/repl/repl_test.go | 14 ++++++++------ 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/lib/client/db/postgres/repl/commands.go b/lib/client/db/postgres/repl/commands.go index 5e82a2ef26b68..b7c488cffd46d 100644 --- a/lib/client/db/postgres/repl/commands.go +++ b/lib/client/db/postgres/repl/commands.go @@ -56,8 +56,8 @@ type command struct { // does. Description string // ExecFunc is the function to execute the command. The commands can either - // return a reply (that will be sent back to the client) as a string. Or - // it can terminates the REPL by returning bool on the second argument. + // return a reply (that will be sent back to the client) as a string. It can + // terminate the REPL by returning bool on the second argument. ExecFunc func(r *REPL, args string) (reply string, exit bool) } @@ -114,7 +114,7 @@ func initCommands() map[string]*command { }, "session": { Type: commandTypeConnection, - Description: "Display information about the current session, like user, roles, and database instance.", + Description: "Display information about the current session, like user, and database instance.", ExecFunc: func(r *REPL, _ string) (string, bool) { return fmt.Sprintf("Connected to %q instance at %q database as %q user.", r.route.ServiceName, r.route.Database, r.route.Username), false }, diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go index 16cc118b9a47a..ce195c31d80ae 100644 --- a/lib/client/db/postgres/repl/commands_test.go +++ b/lib/client/db/postgres/repl/commands_test.go @@ -92,10 +92,10 @@ func TestCommandExecution(t *testing.T) { // When the command exits, the REPL and the connections will be // closed. if tt.commandExit { - require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.EventuallyWithT(t, func(t *assert.CollectT) { var buf []byte _, err := tc.conn.Read(buf[0:]) - require.ErrorIs(t, err, io.EOF) + assert.ErrorIs(t, err, io.EOF) }, time.Second, time.Millisecond) select { diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index 3adf1c377b2c3..c0a6679ae333c 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -43,7 +43,7 @@ type REPL struct { commands map[string]*command } -func New(ctx context.Context, client io.ReadWriteCloser, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { +func New(client io.ReadWriteCloser, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s", hostnamePlaceholder)) if err != nil { return nil, trace.Wrap(err) @@ -284,5 +284,5 @@ See https://www.postgresql.org/docs/17/protocol-flow.html#PROTOCOL-FLOW-TERMINAT // solution and propose the behavior for it. // // This shortcut filtered out by the WebUI key handler. - "Pressing CTRL-C will have no effect on this shell.", + "Pressing CTRL-C will have no effect in this shell.", } diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index 286c96e53b03d..fc23be5c4e6fa 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -126,6 +126,8 @@ func writeLine(t *testing.T, c *testCtx, line string) { // we're net.Pipe). go func(conn net.Conn) { buf := make([]byte, len(data)) + // We need to consume any additional replies made by the terminal + // emulator until we consume the line contents. for { n, err := conn.Read(buf[0:]) if err != nil { @@ -142,9 +144,9 @@ func writeLine(t *testing.T, c *testCtx, line string) { // Given that the test connections are piped a problem with the reader side // would lead into blocking writing. To avoid this scenario we're using // the Eventually just to ensure a timeout on writing into the connections. - require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.EventuallyWithT(t, func(t *assert.CollectT) { _, err := c.conn.Write(data) - require.NoError(t, err) + assert.NoError(t, err) }, 5*time.Second, time.Millisecond, "expected to write into the connection successfully") } @@ -173,11 +175,11 @@ func readLine(t *testing.T, c *testCtx) string { // Given that the test connections are piped a problem with the writer side // would lead into blocking reading. To avoid this scenario we're using // the Eventually just to ensure a timeout on reading from the connections. - require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.EventuallyWithT(t, func(t *assert.CollectT) { var err error n, err = c.conn.Read(buf[0:]) - require.NoError(t, err) - require.Greater(t, n, 0) + assert.NoError(t, err) + assert.Greater(t, n, 0) }, 5*time.Second, time.Millisecond) return string(buf[:n]) } @@ -274,7 +276,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( } }(tc) - r, err := New(ctx, tc.clientConn, tc.serverConn, tc.route) + r, err := New(tc.clientConn, tc.serverConn, tc.route) require.NoError(t, err) if !cfg.skipREPLRun { From 2e697a750d17a302fdc3cdb61e5948fe93cb79d8 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Fri, 6 Dec 2024 18:21:21 -0300 Subject: [PATCH 06/10] test(repl): increase timeout values --- lib/client/db/postgres/repl/commands_test.go | 6 +++--- lib/client/db/postgres/repl/repl_test.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go index ce195c31d80ae..20980444fd836 100644 --- a/lib/client/db/postgres/repl/commands_test.go +++ b/lib/client/db/postgres/repl/commands_test.go @@ -96,12 +96,12 @@ func TestCommandExecution(t *testing.T) { var buf []byte _, err := tc.conn.Read(buf[0:]) assert.ErrorIs(t, err, io.EOF) - }, time.Second, time.Millisecond) + }, 5*time.Second, time.Millisecond) select { case err := <-runErrChan: require.NoError(t, err, "expected the REPL instance exit gracefully") - case <-time.After(time.Second): + case <-time.After(5 * time.Second): require.Fail(t, "expected REPL run to terminate but got nothing") } return @@ -115,7 +115,7 @@ func TestCommandExecution(t *testing.T) { select { case err := <-runErrChan: require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation") - case <-time.After(time.Second): + case <-time.After(5 * time.Second): require.Fail(t, "expected REPL run to terminate but got nothing") } }) diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index fc23be5c4e6fa..f967c34857f3a 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -107,7 +107,7 @@ func TestQuery(t *testing.T) { select { case query := <-tc.QueryChan(): require.Equal(t, tt.expectedQuery, query) - case <-time.After(time.Second): + case <-time.After(5 * time.Second): require.Fail(t, "expected to receive query but got nothing") } @@ -264,7 +264,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( select { case err := <-tc.errChan: require.NoError(t, err) - case <-time.After(time.Second): + case <-time.After(5 * time.Second): require.Fail(t, "expected to receive the test server close result but got nothing") } }) From 909d61521d908aca5e65f01a985250c31339cebb Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Tue, 10 Dec 2024 11:20:45 -0300 Subject: [PATCH 07/10] fix(repl): commands formatting --- lib/client/db/postgres/repl/commands.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/client/db/postgres/repl/commands.go b/lib/client/db/postgres/repl/commands.go index b7c488cffd46d..07d2faf7a02aa 100644 --- a/lib/client/db/postgres/repl/commands.go +++ b/lib/client/db/postgres/repl/commands.go @@ -78,7 +78,7 @@ func initCommands() map[string]*command { // - another item var limitations strings.Builder for _, l := range descriptiveLimitations { - limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n ")) + limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n ") + lineBreak) } return fmt.Sprintf( @@ -99,12 +99,12 @@ func initCommands() map[string]*command { typesTable[cmd.Type] = &table } - typesTable[cmd.Type].AddRow([]string{cmdStr, cmd.Description}) + typesTable[cmd.Type].AddRow([]string{"\\" + cmdStr, cmd.Description}) } var res strings.Builder for cmdType, output := range typesTable { - res.WriteString(string(cmdType)) + res.WriteString(string(cmdType) + lineBreak) output.AsBuffer().WriteTo(&res) res.WriteString(lineBreak) } From d0ae6f939126c24d7554f5573cd8e7fcafccdcf8 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Tue, 10 Dec 2024 11:21:02 -0300 Subject: [PATCH 08/10] refactor(repl): send close pgconn using a different context --- lib/client/db/postgres/repl/repl.go | 7 +- lib/client/db/postgres/repl/repl_test.go | 90 ++++++++++++++++++++---- 2 files changed, 82 insertions(+), 15 deletions(-) diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index c0a6679ae333c..6320e15e56be1 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -23,6 +23,7 @@ import ( "io" "net" "strings" + "time" "github.com/gravitational/trace" "github.com/jackc/pgconn" @@ -82,7 +83,11 @@ func (r *REPL) Run(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - defer pgConn.Close(context.TODO()) + defer func() { + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pgConn.Close(closeCtx) + }() // term.Terminal blocks reads/writes without respecting the context. The // only thing that unblocks it is closing the underlaying connection (in diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index f967c34857f3a..89af0ab422350 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -117,6 +117,56 @@ func TestQuery(t *testing.T) { } } +func TestClose(t *testing.T) { + for name, tt := range map[string]struct { + closeFunc func(tc *testCtx, cancelCtx context.CancelFunc) + expectTerminateMessage bool + }{ + "closed by context": { + closeFunc: func(_ *testCtx, cancelCtx context.CancelFunc) { + cancelCtx() + }, + expectTerminateMessage: true, + }, + "closed by server": { + closeFunc: func(tc *testCtx, _ context.CancelFunc) { + tc.CloseServer() + }, + expectTerminateMessage: false, + }, + } { + t.Run(name, func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + _, tc := StartWithServer(t, ctx) + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + tt.closeFunc(tc, cancelFunc) + // After closing the REPL session, we expect any read/write to + // return error. In case the close wasn't effective we need to + // execute the read on a Eventually block to avoid blocking the + // test. + require.EventuallyWithT(t, func(t *assert.CollectT) { + var buf []byte + _, err := tc.conn.Read(buf[0:]) + assert.ErrorIs(t, err, io.EOF) + }, 5*time.Second, time.Millisecond) + + if !tt.expectTerminateMessage { + return + } + + select { + case <-tc.terminateChan: + case <-time.After(5 * time.Second): + require.Fail(t, "expected REPL to send terminate message but got nothing") + } + }) + } +} + func writeLine(t *testing.T, c *testCtx, line string) { t.Helper() data := []byte(line + lineBreak) @@ -196,10 +246,13 @@ type testCtx struct { // serverConn is the fake database server connection (that works as a // PostgreSQL instance). serverConn net.Conn + // rawPgConn is the underlaying net.Conn used by pgconn client. + rawPgConn net.Conn - route clientproto.RouteToDatabase - pgClient *pgproto3.Backend - errChan chan error + route clientproto.RouteToDatabase + pgClient *pgproto3.Backend + errChan chan error + terminateChan chan struct{} // queryChan handling custom queries is enabled the queries received by the // test server will be sent to this channel. queryChan chan string @@ -247,15 +300,17 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( client := pgproto3.NewBackend(pgproto3.NewChunkReader(pgConn), pgConn) ctx, cancelFunc := context.WithCancel(ctx) tc := &testCtx{ - cfg: cfg, - ctx: ctx, - cancelFunc: cancelFunc, - conn: conn, - clientConn: clientConn, - serverConn: serverConn, - pgClient: client, - errChan: make(chan error, 1), - queryChan: make(chan string), + cfg: cfg, + ctx: ctx, + cancelFunc: cancelFunc, + conn: conn, + clientConn: clientConn, + serverConn: serverConn, + rawPgConn: pgConn, + pgClient: client, + errChan: make(chan error, 1), + terminateChan: make(chan struct{}), + queryChan: make(chan string), } t.Cleanup(func() { @@ -271,7 +326,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( go func(c *testCtx) { defer close(c.errChan) - if err := c.processMessages(); err != nil { + if err := c.processMessages(); err != nil && !errors.Is(err, io.ErrClosedPipe) { c.errChan <- err } }(tc) @@ -292,7 +347,9 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( select { case err := <-runErrChan: - require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation") + if !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) { + require.Fail(t, "expected the REPL instance to finish with context cancelation or server closed pipe but got %q", err) + } case <-time.After(10 * time.Second): require.Fail(t, "timeout while waiting for REPL Run result") } @@ -306,6 +363,10 @@ func (tc *testCtx) QueryChan() chan string { return tc.queryChan } +func (tc *testCtx) CloseServer() { + tc.rawPgConn.Close() +} + func (tc *testCtx) close() { tc.serverConn.Close() tc.clientConn.Close() @@ -398,6 +459,7 @@ func (tc *testCtx) processMessages() error { } case *pgproto3.Terminate: + close(tc.terminateChan) return nil default: return trace.BadParameter("unsupported message %#v", message) From 3df80c062b26d578e2d94a7a211faeeee7afcb0a Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Tue, 10 Dec 2024 11:36:55 -0300 Subject: [PATCH 09/10] fix(repl): add proper spacing between multi queries --- lib/client/db/postgres/repl/repl.go | 13 +++++++++++-- lib/client/db/postgres/repl/repl_test.go | 19 +++++++++++++++++++ .../repl/testdata/TestStart/multiquery.golden | 10 ++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index 6320e15e56be1..1f3c6b2cbefc0 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -204,8 +204,11 @@ func formatResult(results []*pgconn.Result, err error) string { return errorReplyPrefix + err.Error() } - var sb strings.Builder - for _, res := range results { + var ( + sb strings.Builder + resultsLen = len(results) + ) + for i, res := range results { if !res.CommandTag.Select() { return res.CommandTag.String() } @@ -230,6 +233,12 @@ func formatResult(results []*pgconn.Result, err error) string { table.AsBuffer().WriteTo(&sb) sb.WriteString(rowsText(len(res.Rows))) + + // Add line breaks to separate results. Except the last result, which + // will have line breaks added later in the reply. + if i != resultsLen-1 { + sb.WriteString(lineBreak + lineBreak) + } } return sb.String() diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index 89af0ab422350..1d571f2bdfcc9 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -70,6 +70,13 @@ func TestStart(t *testing.T) { golden.SetNamed(t, "data_type", []byte(dataTypeQueryResult)) } require.Equal(t, string(golden.GetNamed(t, "data_type")), dataTypeQueryResult) + + writeLine(t, tc, multiQuery) + multiQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "multiquery", []byte(multiQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "multiquery")), multiQueryResult) } // TestQuery given some input lines, the REPL should execute the expected @@ -449,6 +456,17 @@ func (tc *testCtx) processMessages() error { } case dataTypesQuery: messages = testdata.TestDataQueryResult + case multiQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("?column?")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("id")}, {Name: []byte("email")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1"), []byte("alice@example.com")}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("2"), []byte("bob@example.com")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, + } case errorQuery: messages = []pgproto3.BackendMessage{ &pgproto3.ErrorResponse{Severity: "ERROR", Code: "42703", Message: "error"}, @@ -477,6 +495,7 @@ func (tc *testCtx) processMessages() error { const ( singleRowQuery = "SELECT * FROM users LIMIT 1;" multiRowQuery = "SELECT * FROM users;" + multiQuery = "SELECT 1; SELECT * FROM users;" dataTypesQuery = "SELECT * FROM test_data_types;" errorQuery = "SELECT err;" ) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden b/lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden new file mode 100644 index 0000000000000..3d3724d7186be --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden @@ -0,0 +1,10 @@ +?column? +-------- +1 +(1 row) + +id email +-- ----------------- +1 alice@example.com +2 bob@example.com +(2 rows) From c52d7f481fb4d04a9764441b3263bb45035e067d Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Tue, 10 Dec 2024 12:05:17 -0300 Subject: [PATCH 10/10] test(repl): add fuzz test for processing commands --- lib/client/db/postgres/repl/commands_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go index 20980444fd836..2a974d470601f 100644 --- a/lib/client/db/postgres/repl/commands_test.go +++ b/lib/client/db/postgres/repl/commands_test.go @@ -169,3 +169,17 @@ func TestCommands(t *testing.T) { }) } } + +func FuzzCommands(f *testing.F) { + f.Add("q") + f.Add("?") + f.Add("session") + f.Add("teleport") + + repl := &REPL{commands: make(map[string]*command)} + f.Fuzz(func(t *testing.T, line string) { + require.NotPanics(t, func() { + _, _ = repl.processCommand(line) + }) + }) +}