Skip to content

Commit

Permalink
Add PostgreSQL REPL implementation (#49598)
Browse files Browse the repository at this point in the history
* feat(repl): add postgres

* refactor(repl): change repl to use a single Run function

* test(repl): reduce usage of require.Eventually blocks

* refactor(repl): code review suggestions

* refactor(repl): code review suggestions

* test(repl): increase timeout values

* fix(repl): commands formatting

* refactor(repl): send close pgconn using a different context

* fix(repl): add proper spacing between multi queries

* test(repl): add fuzz test for processing commands
  • Loading branch information
gabrielcorado authored Dec 13, 2024
1 parent 821708e commit 89ea69c
Show file tree
Hide file tree
Showing 10 changed files with 1,246 additions and 0 deletions.
123 changes: 123 additions & 0 deletions lib/client/db/postgres/repl/commands.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package repl

import (
"fmt"
"strings"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/asciitable"
)

// processCommand receives a command call and return the reply and if the
// command terminates the session.
func (r *REPL) processCommand(line string) (string, bool) {
cmdStr, args, _ := strings.Cut(strings.TrimPrefix(line, commandPrefix), " ")
cmd, ok := r.commands[cmdStr]
if !ok {
return "Unknown command. Try \\? to show the list of supported commands." + lineBreak, false
}

return cmd.ExecFunc(r, args)
}

// commandType specify the command category. This is used to organize the
// commands, for example, when showing them in the help command.
type commandType string

const (
// commandTypeGeneral represents a general-purpose command type.
commandTypeGeneral commandType = "General"
// commandTypeConnection represents a command type related to connection
// operations.
commandTypeConnection = "Connection"
)

// command represents a command that can be executed in the REPL.
type command struct {
// Type specifies the type of the command.
Type commandType
// Description provides a user-friendly explanation of what the command
// does.
Description string
// ExecFunc is the function to execute the command. The commands can either
// return a reply (that will be sent back to the client) as a string. It can
// terminate the REPL by returning bool on the second argument.
ExecFunc func(r *REPL, args string) (reply string, exit bool)
}

func initCommands() map[string]*command {
return map[string]*command{
"q": {
Type: commandTypeGeneral,
Description: "Terminates the session.",
ExecFunc: func(_ *REPL, _ string) (string, bool) { return "", true },
},
"teleport": {
Type: commandTypeGeneral,
Description: "Show Teleport interactive shell information, such as execution limitations.",
ExecFunc: func(_ *REPL, _ string) (string, bool) {
// Formats limitiations in a dash list. Example:
// - hello
// multi line
// - another item
var limitations strings.Builder
for _, l := range descriptiveLimitations {
limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n ") + lineBreak)
}

return fmt.Sprintf(
"Teleport PostgreSQL interactive shell (v%s)\n\nLimitations: \n%s",
teleport.Version,
limitations.String(),
), false
},
},
"?": {
Type: commandTypeGeneral,
Description: "Show the list of supported commands.",
ExecFunc: func(r *REPL, _ string) (string, bool) {
typesTable := make(map[commandType]*asciitable.Table)
for cmdStr, cmd := range r.commands {
if _, ok := typesTable[cmd.Type]; !ok {
table := asciitable.MakeHeadlessTable(2)
typesTable[cmd.Type] = &table
}

typesTable[cmd.Type].AddRow([]string{"\\" + cmdStr, cmd.Description})
}

var res strings.Builder
for cmdType, output := range typesTable {
res.WriteString(string(cmdType) + lineBreak)
output.AsBuffer().WriteTo(&res)
res.WriteString(lineBreak)
}

return res.String(), false
},
},
"session": {
Type: commandTypeConnection,
Description: "Display information about the current session, like user, and database instance.",
ExecFunc: func(r *REPL, _ string) (string, bool) {
return fmt.Sprintf("Connected to %q instance at %q database as %q user.", r.route.ServiceName, r.route.Database, r.route.Username), false
},
},
}
}
185 changes: 185 additions & 0 deletions lib/client/db/postgres/repl/commands_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package repl

import (
"context"
"io"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
clientproto "github.com/gravitational/teleport/api/client/proto"
)

func TestCommandExecution(t *testing.T) {
ctx := context.Background()

for name, tt := range map[string]struct {
line string
commandResult string
expectedArgs string
expectUnknown bool
commandExit bool
}{
"execute": {line: "\\test", commandResult: "test"},
"execute with additional arguments": {line: "\\test a b", commandResult: "test", expectedArgs: "a b"},
"execute with exit": {line: "\\test", commandExit: true},
"execute with leading and trailing whitespace": {line: " \\test ", commandResult: "test"},
"unknown command with semicolon": {line: "\\test;", expectUnknown: true},
"unknown command": {line: "\\wrong", expectUnknown: true},
"with special characters": {line: "\\special_chars_!@#$%^&*()}", expectUnknown: true},
"empty command": {line: "\\", expectUnknown: true},
} {
t.Run(name, func(t *testing.T) {
commandArgsChan := make(chan string, 1)
instance, tc := StartWithServer(t, ctx, WithSkipREPLRun())
ctx, cancel := context.WithCancel(ctx)
defer cancel()

runErrChan := make(chan error)
go func() {
runErrChan <- instance.Run(ctx)
}()

// Consume the REPL banner.
_ = readUntilNextLead(t, tc)

// Reset available commands and add a test command so we can assert
// the command execution flow without relying in commands
// implementation or test server capabilities.
instance.commands = map[string]*command{
"test": {
ExecFunc: func(r *REPL, args string) (string, bool) {
commandArgsChan <- args
return tt.commandResult, tt.commandExit
},
},
}

writeLine(t, tc, tt.line)
if tt.expectUnknown {
reply := readUntilNextLead(t, tc)
require.True(t, strings.HasPrefix(strings.ToLower(reply), "unknown command"))
return
}

select {
case args := <-commandArgsChan:
require.Equal(t, tt.expectedArgs, args)
case <-time.After(time.Second):
require.Fail(t, "expected to command args from test server but got nothing")
}

// When the command exits, the REPL and the connections will be
// closed.
if tt.commandExit {
require.EventuallyWithT(t, func(t *assert.CollectT) {
var buf []byte
_, err := tc.conn.Read(buf[0:])
assert.ErrorIs(t, err, io.EOF)
}, 5*time.Second, time.Millisecond)

select {
case err := <-runErrChan:
require.NoError(t, err, "expected the REPL instance exit gracefully")
case <-time.After(5 * time.Second):
require.Fail(t, "expected REPL run to terminate but got nothing")
}
return
}

reply := readUntilNextLead(t, tc)
require.Equal(t, tt.commandResult, reply)

// Terminate the REPL run session and wait for the Run results.
cancel()
select {
case err := <-runErrChan:
require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation")
case <-time.After(5 * time.Second):
require.Fail(t, "expected REPL run to terminate but got nothing")
}
})
}
}

func TestCommands(t *testing.T) {
availableCmds := initCommands()
for cmdName, tc := range map[string]struct {
repl *REPL
args string
expectExit bool
assertCommandReply require.ValueAssertionFunc
}{
"q": {expectExit: true},
"teleport": {
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.Contains(t, val, teleport.Version, "expected \\teleport command to include current Teleport version")
},
},
"?": {
repl: &REPL{commands: availableCmds},
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) {
for cmd := range availableCmds {
require.Contains(t, val, cmd, "expected \\? command to include information about \\%s", cmd)
}
},
},
"session": {
repl: &REPL{route: clientproto.RouteToDatabase{
ServiceName: "service",
Username: "username",
Database: "database",
}},
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.Contains(t, val, "service", "expected \\session command to contain service name")
require.Contains(t, val, "username", "expected \\session command to contain username")
require.Contains(t, val, "database", "expected \\session command to contain database name")
},
},
} {
t.Run(cmdName, func(t *testing.T) {
cmd, ok := availableCmds[cmdName]
require.True(t, ok, "expected command %q to be available at commands", cmdName)
reply, exit := cmd.ExecFunc(tc.repl, tc.args)
if tc.expectExit {
require.True(t, exit, "expected command to exit the REPL")
return
}
tc.assertCommandReply(t, reply)
})
}
}

func FuzzCommands(f *testing.F) {
f.Add("q")
f.Add("?")
f.Add("session")
f.Add("teleport")

repl := &REPL{commands: make(map[string]*command)}
f.Fuzz(func(t *testing.T, line string) {
require.NotPanics(t, func() {
_, _ = repl.processCommand(line)
})
})
}
Loading

0 comments on commit 89ea69c

Please sign in to comment.