Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PostgreSQL REPL implementation #49598

Merged
merged 12 commits into from
Dec 13, 2024
Merged
123 changes: 123 additions & 0 deletions lib/client/db/postgres/repl/commands.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package repl

import (
"fmt"
"strings"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/asciitable"
)

// processCommand receives a command call and return the reply and if the
// command terminates the session.
func (r *REPL) processCommand(line string) (string, bool) {
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
cmdStr, args, _ := strings.Cut(strings.TrimPrefix(line, commandPrefix), " ")
cmd, ok := r.commands[cmdStr]
if !ok {
return "Unknown command. Try \\? to show the list of supported commands." + lineBreak, false
}

return cmd.ExecFunc(r, args)
}

// commandType specify the command category. This is used to organize the
// commands, for example, when showing them in the help command.
type commandType string

const (
// commandTypeGeneral represents a general-purpose command type.
commandTypeGeneral commandType = "General"
// commandTypeConnection represents a command type related to connection
// operations.
commandTypeConnection = "Connection"
)

// command represents a command that can be executed in the REPL.
type command struct {
// Type specifies the type of the command.
Type commandType
// Description provides a user-friendly explanation of what the command
// does.
Description string
// ExecFunc is the function to execute the command. The commands can either
// return a reply (that will be sent back to the client) as a string. Or
// it can terminates the REPL by returning bool on the second argument.
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
ExecFunc func(r *REPL, args string) (reply string, exit bool)
}

func initCommands() map[string]*command {
return map[string]*command{
"q": {
Type: commandTypeGeneral,
Description: "Terminates the session.",
ExecFunc: func(_ *REPL, _ string) (string, bool) { return "", true },
},
"teleport": {
Type: commandTypeGeneral,
Description: "Show Teleport interactive shell information, such as execution limitations.",
ExecFunc: func(_ *REPL, _ string) (string, bool) {
// Formats limitiations in a dash list. Example:
// - hello
// multi line
// - another item
var limitations strings.Builder
for _, l := range descriptiveLimitations {
limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n "))
}

return fmt.Sprintf(
"Teleport PostgreSQL interactive shell (v%s)\n\nLimitations: \n%s",
teleport.Version,
limitations.String(),
), false
},
},
"?": {
Type: commandTypeGeneral,
Description: "Show the list of supported commands.",
ExecFunc: func(r *REPL, _ string) (string, bool) {
typesTable := make(map[commandType]*asciitable.Table)
for cmdStr, cmd := range r.commands {
if _, ok := typesTable[cmd.Type]; !ok {
table := asciitable.MakeHeadlessTable(2)
typesTable[cmd.Type] = &table
}

typesTable[cmd.Type].AddRow([]string{cmdStr, cmd.Description})
}

var res strings.Builder
for cmdType, output := range typesTable {
res.WriteString(string(cmdType))
output.AsBuffer().WriteTo(&res)
res.WriteString(lineBreak)
}

return res.String(), false
},
},
"session": {
Type: commandTypeConnection,
Description: "Display information about the current session, like user, roles, and database instance.",
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
ExecFunc: func(r *REPL, _ string) (string, bool) {
return fmt.Sprintf("Connected to %q instance at %q database as %q user.", r.route.ServiceName, r.route.Database, r.route.Username), false
},
},
}
}
171 changes: 171 additions & 0 deletions lib/client/db/postgres/repl/commands_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package repl

import (
"context"
"io"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
clientproto "github.com/gravitational/teleport/api/client/proto"
)

func TestCommandExecution(t *testing.T) {
ctx := context.Background()

for name, tt := range map[string]struct {
line string
commandResult string
expectedArgs string
expectUnknown bool
commandExit bool
}{
"execute": {line: "\\test", commandResult: "test"},
"execute with additional arguments": {line: "\\test a b", commandResult: "test", expectedArgs: "a b"},
"execute with exit": {line: "\\test", commandExit: true},
"execute with leading and trailing whitespace": {line: " \\test ", commandResult: "test"},
"unknown command with semicolon": {line: "\\test;", expectUnknown: true},
"unknown command": {line: "\\wrong", expectUnknown: true},
"with special characters": {line: "\\special_chars_!@#$%^&*()}", expectUnknown: true},
"empty command": {line: "\\", expectUnknown: true},
} {
t.Run(name, func(t *testing.T) {
commandArgsChan := make(chan string, 1)
instance, tc := StartWithServer(t, ctx, WithSkipREPLRun())
ctx, cancel := context.WithCancel(ctx)
defer cancel()

runErrChan := make(chan error)
go func() {
runErrChan <- instance.Run(ctx)
}()

// Consume the REPL banner.
_ = readUntilNextLead(t, tc)

// Reset available commands and add a test command so we can assert
// the command execution flow without relying in commands
// implementation or test server capabilities.
instance.commands = map[string]*command{
"test": {
ExecFunc: func(r *REPL, args string) (string, bool) {
commandArgsChan <- args
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
return tt.commandResult, tt.commandExit
},
},
}

writeLine(t, tc, tt.line)
if tt.expectUnknown {
reply := readUntilNextLead(t, tc)
require.True(t, strings.HasPrefix(strings.ToLower(reply), "unknown command"))
return
}

select {
case args := <-commandArgsChan:
require.Equal(t, tt.expectedArgs, args)
case <-time.After(time.Second):
require.Fail(t, "expected to command args from test server but got nothing")
}

// When the command exits, the REPL and the connections will be
// closed.
if tt.commandExit {
require.EventuallyWithT(t, func(collect *assert.CollectT) {
var buf []byte
_, err := tc.conn.Read(buf[0:])
require.ErrorIs(t, err, io.EOF)
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
}, time.Second, time.Millisecond)

select {
case err := <-runErrChan:
require.NoError(t, err, "expected the REPL instance exit gracefully")
case <-time.After(time.Second):
require.Fail(t, "expected REPL run to terminate but got nothing")
}
return
}

reply := readUntilNextLead(t, tc)
require.Equal(t, tt.commandResult, reply)

// Terminate the REPL run session and wait for the Run results.
cancel()
select {
case err := <-runErrChan:
require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation")
case <-time.After(time.Second):
require.Fail(t, "expected REPL run to terminate but got nothing")
}
})
}
}

func TestCommands(t *testing.T) {
availableCmds := initCommands()
for cmdName, tc := range map[string]struct {
repl *REPL
args string
expectExit bool
assertCommandReply require.ValueAssertionFunc
}{
"q": {expectExit: true},
"teleport": {
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.Contains(t, val, teleport.Version, "expected \\teleport command to include current Teleport version")
},
},
"?": {
repl: &REPL{commands: availableCmds},
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) {
for cmd := range availableCmds {
require.Contains(t, val, cmd, "expected \\? command to include information about \\%s", cmd)
}
},
},
"session": {
repl: &REPL{route: clientproto.RouteToDatabase{
ServiceName: "service",
Username: "username",
Database: "database",
}},
assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.Contains(t, val, "service", "expected \\session command to contain service name")
require.Contains(t, val, "username", "expected \\session command to contain username")
require.Contains(t, val, "database", "expected \\session command to contain database name")
},
},
} {
t.Run(cmdName, func(t *testing.T) {
cmd, ok := availableCmds[cmdName]
require.True(t, ok, "expected command %q to be available at commands", cmdName)
reply, exit := cmd.ExecFunc(tc.repl, tc.args)
if tc.expectExit {
require.True(t, exit, "expected command to exit the REPL")
return
}
tc.assertCommandReply(t, reply)
})
}
}
Loading
Loading