From 048bced6fe4d54952c79d52c7656971fb298e320 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 11 Oct 2024 12:18:20 +0700 Subject: [PATCH] fixup! Config: Update cmd/config to allow upgrades --- cmd/config/config.go | 168 ++++++++++++++++++++++++++++--------------- 1 file changed, 109 insertions(+), 59 deletions(-) diff --git a/cmd/config/config.go b/cmd/config/config.go index 4204c87e0b1..d0b84c8da75 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -4,78 +4,45 @@ import ( "flag" "fmt" "os" + "slices" + "strings" "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/config" ) +var commands = []string{"upgrade", "encrypt", "decrypt"} + func main() { - var inFile, outFile, key string - var encrypt, decrypt, upgrade bool + fmt.Println("GoCryptoTrader: config-helper tool") + defaultCfgFile := config.DefaultFilePath() - flag.StringVar(&inFile, "infile", defaultCfgFile, "The config input file to process.") - flag.StringVar(&outFile, "outfile", "", "The config output file.") - flag.BoolVar(&encrypt, "encrypt", false, "Encrypt the config file") - flag.BoolVar(&decrypt, "decrypt", false, "Decrypt the config file") - flag.BoolVar(&upgrade, "upgrade", false, "Upgrade the config file") - flag.StringVar(&key, "key", "", "The key to use for AES encryption.") - flag.Parse() - - if outFile == "" { - outFile = inFile + ".out" - } - fmt.Println("GoCryptoTrader: config-helper tool.") + var in, out, key string + fs := flag.NewFlagSet("config", flag.ExitOnError) + fs.Usage = func() { usage(fs) } + fs.StringVar(&in, "infile", defaultCfgFile, "The config input file to process") + fs.StringVar(&out, "outfile", "[infile].out", "The config output file") + fs.StringVar(&key, "key", "", "The key to use for AES encryption") + _ = fs.Parse(os.Args[1:]) - if !(encrypt || decrypt || upgrade) { - fatal("Must provide one of -encrypt, -decrypt or -upgrade") + if out == "[infile].out" { + out = in + ".out" } - if upgrade { - doUpgrade(inFile, outFile) - return - } - - if key == "" { - result, err := config.PromptForConfigKey(false) - if err != nil { - fatal("Unable to obtain encryption/decryption key: " + err.Error()) - } - key = string(result) + switch parseCommand(fs) { + case "upgrade": + upgradeFile(in, out) + case "decrypt": + encryptWrapper(in, out, key, decryptFile) + case "encrypt": + encryptWrapper(in, out, key, encryptFile) } - fileData, err := os.ReadFile(inFile) - if err != nil { - fatal("Unable to read input file " + inFile + "; Error: " + err.Error()) - } - - switch { - case encrypt: - if config.IsEncrypted(fileData) { - fatal("File is already encrypted") - } - if fileData, err = config.EncryptConfigFile(fileData, []byte(key)); err != nil { - fatal("Unable to encrypt config data. Error: " + err.Error()) - } - fmt.Println("Encrypted config file") - case decrypt: - if !config.IsEncrypted(fileData) { - fatal("File is already decrypted") - } - if fileData, err = config.DecryptConfigFile(fileData, []byte(key)); err != nil { - fatal("Unable to decrypt config data. Error: " + err.Error()) - } - fmt.Println("Decrypted config file") - } - - if err = file.Write(outFile, fileData); err != nil { - fatal("Unable to write output file " + outFile + "; Error: " + err.Error()) - } - - fmt.Println("Success! File written to " + outFile) + fmt.Println("Success! File written to " + out) } -func doUpgrade(in, out string) { +func upgradeFile(in, out string) { if config.IsFileEncrypted(in) { fatal("Cannot upgrade an encrypted file. Please decrypt first") } @@ -87,10 +54,93 @@ func doUpgrade(in, out string) { if err != nil { fatal(err.Error()) } - fmt.Println("Success! File written to ", out) +} + +func encryptWrapper(in, out, key string, fn func(in string, key []byte) []byte) { + if key == "" { + key = getKey() + } + outData := fn(in, []byte(key)) + if err := file.Write(out, outData); err != nil { + fatal("Unable to write output file " + out + "; Error: " + err.Error()) + } +} + +func encryptFile(in string, key []byte) []byte { + if config.IsFileEncrypted(in) { + fatal("File is already encrypted") + } + outData, err := config.EncryptConfigFile(readFile(in), key) + if err != nil { + fatal("Unable to encrypt config data. Error: " + err.Error()) + } + return outData +} + +func decryptFile(in string, key []byte) []byte { + if !config.IsFileEncrypted(in) { + fatal("File is already decrypted") + } + outData, err := config.DecryptConfigFile(readFile(in), key) + if err != nil { + fatal("Unable to decrypt config data. Error: " + err.Error()) + } + return outData +} + +func readFile(in string) []byte { + fileData, err := os.ReadFile(in) + if err != nil { + fatal("Unable to read input file " + in + "; Error: " + err.Error()) + } + return fileData +} + +func getKey() string { + result, err := config.PromptForConfigKey(false) + if err != nil { + fatal("Unable to obtain encryption/decryption key: " + err.Error()) + } + return string(result) } func fatal(msg string) { fmt.Fprintln(os.Stderr, msg) - os.Exit(1) + os.Exit(2) +} + +// parseCommand will return the single non-flag parameter +// If none is provided, too many, or unrecognised, usage() will be called and exit 1 +func parseCommand(fs *flag.FlagSet) string { + switch fs.NArg() { + case 0: + fmt.Fprintln(os.Stderr, "No command provided") + case 1: + command := fs.Arg(0) + if slices.Contains(commands, command) { + return command + } + fmt.Fprintln(os.Stderr, "Unknown command provided: "+command) + default: + fmt.Fprintln(os.Stderr, "Too many commands porvided: "+strings.Join(fs.Args(), " ")) + } + usage(fs) + os.Exit(2) + return "" +} + +// usage prints command usage and exits 1 +func usage(fs *flag.FlagSet) { + //nolint:dupword // deliberate duplication of commands + fmt.Fprintln(os.Stderr, ` +Usage: +config [arguments] + +The commands are: + encrypt encrypt infile and write to outfile + decrypt decrypt infile and write to outfile + upgrade upgrade the version of a decrypted config file + +The arguments are:`) + fs.PrintDefaults() }