From d3795a57f327456ae1acffcd9e0a8d206511d6ff Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 15 Nov 2024 12:35:56 +0700 Subject: [PATCH] fixup! Config: Add -edit and encryption upgrade to cmd/config --- cmd/config/config.go | 5 ++ config/config_encryption.go | 30 ++++--- config/config_encryption_test.go | 148 +++++++++++++------------------ 3 files changed, 86 insertions(+), 97 deletions(-) diff --git a/cmd/config/config.go b/cmd/config/config.go index 699fafaf391..209cd454278 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -8,6 +8,7 @@ import ( "slices" "strings" + "github.com/buger/jsonparser" "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/config" ) @@ -118,6 +119,10 @@ func decryptFile(in string, key []byte) ([]byte, error) { if err != nil { return nil, fmt.Errorf("Unable to decrypt config data. Error: %w", err) } + if json, err := jsonparser.Set(outData, []byte("-1"), "encryptConfig"); err == nil { + // err probably means it didn't decrypt, but we don't tell the user that for security + outData = json + } return outData, nil } diff --git a/config/config_encryption.go b/config/config_encryption.go index 423d7a1ee7b..2cedc3f8551 100644 --- a/config/config_encryption.go +++ b/config/config_encryption.go @@ -1,6 +1,7 @@ package config import ( + "bufio" "bytes" "crypto/aes" "crypto/cipher" @@ -9,7 +10,6 @@ import ( "fmt" "io" "os" - "syscall" "github.com/buger/jsonparser" "github.com/thrasher-corp/gocryptotrader/common" @@ -22,6 +22,11 @@ const ( saltRandomLength = 12 ) +// Public errors +var ( + ErrSettingEncryptConfig = errors.New("error setting EncryptConfig during encrypt config file") +) + var ( errAESBlockSize = errors.New("config file data is too small for the AES required block size") errNoPrefix = errors.New("data does not start with Encryption Prefix") @@ -48,10 +53,7 @@ func promptForConfigEncryption() (bool, error) { // PromptForConfigKey asks for configuration key func PromptForConfigKey(confirmKey bool) ([]byte, error) { for i := 0; i < 3; i++ { - fmt.Print("Please enter encryption key: ") - key, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Println() - + key, err := getSensitiveInput("Please enter encryption key: ") if err != nil { return nil, err } @@ -64,10 +66,7 @@ func PromptForConfigKey(confirmKey bool) ([]byte, error) { return key, nil } - fmt.Print("Please re-enter key: ") - conf, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Println() - + conf, err := getSensitiveInput("Please re-enter key: ") if err != nil { return nil, err } @@ -80,6 +79,17 @@ func PromptForConfigKey(confirmKey bool) ([]byte, error) { return nil, errors.New("No key entered") } +// getSensitiveInput reads input from stdin, with echo off it stdin is a terminal +func getSensitiveInput(prompt string) ([]byte, error) { + fmt.Print(prompt) + defer fmt.Println() + if term.IsTerminal(int(os.Stdin.Fd())) { + return term.ReadPassword(int(os.Stdin.Fd())) + } + a, err := bufio.NewReader(os.Stdin).ReadBytes('\n') + return bytes.TrimRight(a, "\n\r"), err +} + // EncryptConfigFile encrypts json config data with a key func EncryptConfigFile(configData, key []byte) ([]byte, error) { sessionDK, salt, err := makeNewSessionDK(key) @@ -98,7 +108,7 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) { func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) { configData, err := jsonparser.Set(configData, []byte("1"), "encryptConfig") if err != nil { - return nil, fmt.Errorf("error setting EncryptConfig true during encrypt config file: %w", err) + return nil, fmt.Errorf("%w: %w", ErrSettingEncryptConfig, err) } block, err := aes.NewCipher(c.sessionDK) diff --git a/config/config_encryption_test.go b/config/config_encryption_test.go index e34b821cf02..16adfdd1bcc 100644 --- a/config/config_encryption_test.go +++ b/config/config_encryption_test.go @@ -2,7 +2,7 @@ package config import ( "bytes" - "fmt" + "crypto/aes" "io" "os" "path/filepath" @@ -17,56 +17,61 @@ func TestPromptForConfigEncryption(t *testing.T) { t.Parallel() confirm, err := promptForConfigEncryption() - if confirm { - t.Error("promptForConfigEncryption return incorrect bool") - } - if err == nil { - t.Error("Expected error as there is no input") - } + require.ErrorIs(t, err, io.EOF) + require.False(t, confirm) } func TestPromptForConfigKey(t *testing.T) { t.Parallel() - byteyBite, err := PromptForConfigKey(true) - if err == nil && len(byteyBite) > 1 { - t.Errorf("PromptForConfigKey: %s", err) - } + withInteractiveResponse(t, "\n\n", func() { + _, err := PromptForConfigKey(false) + require.ErrorIs(t, err, io.EOF) + }) - _, err = PromptForConfigKey(false) - if err == nil { - t.Error("Expected error") - } + withInteractiveResponse(t, "pass\n", func() { + k, err := PromptForConfigKey(false) + require.NoError(t, err) + assert.Equal(t, "pass", string(k)) + }) + + withInteractiveResponse(t, "what\nwhat\n", func() { + k, err := PromptForConfigKey(true) + require.NoError(t, err) + assert.Equal(t, "pass", string(k)) + }) + + withInteractiveResponse(t, "what\nno\n", func() { + k, err := PromptForConfigKey(true) + require.NoError(t, err) + assert.Equal(t, "pass", string(k)) + }) } func TestEncryptConfigFile(t *testing.T) { t.Parallel() _, err := EncryptConfigFile([]byte("test"), nil) - if err == nil { - t.Fatal("Expected error") - } + require.ErrorIs(t, err, errKeyIsEmpty) c := &Config{ sessionDK: []byte("a"), } - _, err = c.encryptConfigFile([]byte("test")) - if err == nil { - t.Fatal("Expected error") - } + _, err = c.encryptConfigFile([]byte(`test`)) + require.ErrorIs(t, err, ErrSettingEncryptConfig) + + _, err = c.encryptConfigFile([]byte(`{"test":1}`)) + require.Error(t, err) + require.IsType(t, aes.KeySizeError(1), err) sessDk, salt, err := makeNewSessionDK([]byte("asdf")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "makeNewSessionDK must not error") c = &Config{ sessionDK: sessDk, storedSalt: salt, } - _, err = c.encryptConfigFile([]byte("test")) - if err != nil { - t.Fatal(err) - } + _, err = c.encryptConfigFile([]byte(`{"test":1}`)) + require.NoError(t, err) } func TestDecryptConfigFile(t *testing.T) { @@ -176,39 +181,22 @@ func TestSaveAndReopenEncryptedConfig(t *testing.T) { // Save encrypted config enc := filepath.Join(tempDir, "encrypted.dat") - err := withInteractiveResponse(t, "pass\npass\n", func() error { - return c.SaveConfigToFile(enc) + withInteractiveResponse(t, "pass\npass\n", func() { + err := c.SaveConfigToFile(enc) + require.NoError(t, err, "SaveConfigToFile must not error") }) - require.NoError(t, err) readConf := &Config{} - err = withInteractiveResponse(t, "pass\n", func() error { + withInteractiveResponse(t, "pass\n", func() { // Load with no existing state, key is read from the prepared file - return readConf.ReadConfigFromFile(enc, true) + err := readConf.ReadConfigFromFile(enc, true) + require.NoError(t, err, "ReadConfigFromFile must not error") }) - require.NoError(t, err) assert.Equal(t, "myCustomName", readConf.Name, "Name must be correct") assert.Equal(t, 1, readConf.EncryptConfig, "EncryptConfig must be set correctly") } -// setAnswersFile sets the given file as the current stdin -// returns the close function to defer for reverting the stdin -func setAnswersFile(t *testing.T, answerFile string) func() { - t.Helper() - oldIn := os.Stdin - - inputFile, err := os.Open(answerFile) - if err != nil { - t.Fatalf("Problem opening temp file at %s: %s\n", answerFile, err) - } - os.Stdin = inputFile - return func() { - inputFile.Close() - os.Stdin = oldIn - } -} - func TestReadConfigWithPrompt(t *testing.T) { // Prepare temp dir tempDir := t.TempDir() @@ -221,14 +209,13 @@ func TestReadConfigWithPrompt(t *testing.T) { // Save config testConfigFile := filepath.Join(tempDir, "config.json") err := c.SaveConfigToFile(testConfigFile) - if err != nil { - t.Fatalf("Problem saving config file in %s: %s\n", tempDir, err) - } + require.NoError(t, err, "SaveConfigToFile must not error") // Run the test c = &Config{} - err = withInteractiveResponse(t, "y\npass\npass\n", func() error { - return c.ReadConfigFromFile(testConfigFile, false) + withInteractiveResponse(t, "y\npass\npass\n", func() { + err := c.ReadConfigFromFile(testConfigFile, false) + require.NoError(t, err, "ReadConfigFromFile must not error") }) if err != nil { t.Fatalf("Problem reading config file at %s: %s\n", testConfigFile, err) @@ -284,16 +271,11 @@ func TestSaveConfigToFileWithErrorInPasswordPrompt(t *testing.T) { if err != nil { t.Fatal(err) } - err = withInteractiveResponse(t, "\n\n", func() error { - err = c.SaveConfigToFile(targetFile) - if err == nil { - t.Error("Expected error") - } - return nil + withInteractiveResponse(t, "\n\n", func() { + err := c.SaveConfigToFile(targetFile) + require.NoError(t, err, "SaveConfigToFile must not error") }) - if err != nil { - t.Fatal(err) - } + data, err := os.ReadFile(targetFile) if err != nil { t.Fatal(err) @@ -303,25 +285,17 @@ func TestSaveConfigToFileWithErrorInPasswordPrompt(t *testing.T) { } } -func withInteractiveResponse(t *testing.T, response string, body func() error) error { - t.Helper() - // Answers to the prompt - responseFile, err := os.CreateTemp("", "*.in") - if err != nil { - return fmt.Errorf("problem creating temp file: %w", err) - } - _, err = responseFile.WriteString(response) - if err != nil { - return fmt.Errorf("problem writing to temp file at %s: %w", responseFile.Name(), err) - } - err = responseFile.Close() - if err != nil { - return fmt.Errorf("problem closing temp file at %s: %w", responseFile.Name(), err) - } - defer os.Remove(responseFile.Name()) - - // Temporarily replace Stdin with a custom input - cleanup := setAnswersFile(t, responseFile.Name()) - defer cleanup() - return body() +func withInteractiveResponse(tb testing.TB, response string, fn func()) { + tb.Helper() + f, err := os.CreateTemp("", "*.in") + require.NoError(tb, err, "CreateTemp must not error") + defer f.Close() + defer os.Remove(f.Name()) + _, err = f.WriteString(response) + require.NoError(tb, err, "WriteString must not error") + _, err = f.Seek(0, 0) + require.NoError(tb, err, "Seek must not error") + defer func(orig *os.File) { os.Stdin = orig }(os.Stdin) + os.Stdin = f + fn() }