diff --git a/cmd/network/switch.go b/cmd/network/switch.go index a897ccf..04d633f 100644 --- a/cmd/network/switch.go +++ b/cmd/network/switch.go @@ -1,6 +1,7 @@ package network import ( + "encoding/json" "errors" "fmt" "net" @@ -40,6 +41,11 @@ func init() { networkCmd.AddCommand(switchCmd) } +// retainedSettings are the settings we want to keep track of when switching networks so we can swap back to them in the future +type retainedSettings struct { + DNSServers []string `json:"dns_servers"` +} + // SwitchNetwork implements the logic to swap networks func SwitchNetwork(networkName string, checkForRunningNode bool) { slogs.Logr.Info("Swapping to network", "network", networkName) @@ -91,6 +97,32 @@ func SwitchNetwork(networkName string, checkForRunningNode bool) { slogs.Logr.Fatal("error creating cache file directory for new network", "error", err, "directory", cacheFileDirNewNetwork) } + previousSettings := retainedSettings{ + DNSServers: cfg.FullNode.DNSServers, + } + marshalled, err := json.Marshal(previousSettings) + if err != nil { + slogs.Logr.Fatal("error marshalling retained settings to json", "error", err) + } + err = os.WriteFile(path.Join(cacheFileDirOldNetwork, "settings.json"), marshalled, 0644) + if err != nil { + slogs.Logr.Fatal("error writing settings for old network", "error", err) + } + + var settingsToRestore *retainedSettings + newSettingsPath := path.Join(cacheFileDirNewNetwork, "settings.json") + if _, err := os.Stat(newSettingsPath); err == nil { + settings, err := os.ReadFile(newSettingsPath) + if err != nil { + slogs.Logr.Fatal("error reading stored settings for the new network", "error", err) + } + settingsToRestore = &retainedSettings{} + err = json.Unmarshal(settings, settingsToRestore) + if err != nil { + slogs.Logr.Fatal("error unmarshalling stored settings for the new network", "error", err) + } + } + // Check if Full Node is running if checkForRunningNode { slogs.Logr.Debug("initializing websocket client to ensure chia is stopped") @@ -133,24 +165,33 @@ func SwitchNetwork(networkName string, checkForRunningNode bool) { } introducerHost := "introducer.chia.net" - dnsIntroducerHost := "dns-introducer.chia.net" + dnsIntroducerHosts := []string{"dns-introducer.chia.net"} fullNodePort := uint16(8444) peersFilePath := "db/peers.dat" walletPeersFilePath := "wallet/db/wallet_peers.dat" bootstrapPeers := []string{"node.chia.net"} if networkName != "mainnet" { introducerHost = fmt.Sprintf("introducer-%s.chia.net", networkName) - dnsIntroducerHost = fmt.Sprintf("dns-introducer-%s.chia.net", networkName) + dnsIntroducerHosts = []string{fmt.Sprintf("dns-introducer-%s.chia.net", networkName)} fullNodePort = uint16(58444) peersFilePath = fmt.Sprintf("db/peers-%s.dat", networkName) walletPeersFilePath = fmt.Sprintf("wallet/db/wallet_peers-%s.dat", networkName) bootstrapPeers = []string{fmt.Sprintf("node-%s.chia.net", networkName)} } + + // Any stored settings for the new network should be applied here, before any flags override them + if settingsToRestore != nil { + slogs.Logr.Info("restoring stored settings for this network") + if len(settingsToRestore.DNSServers) > 0 { + dnsIntroducerHosts = settingsToRestore.DNSServers + } + } + if introFlag := viper.GetString("switch-introducer"); introFlag != "" { introducerHost = introFlag } if dnsIntroFlag := viper.GetString("switch-dns-introducer"); dnsIntroFlag != "" { - dnsIntroducerHost = dnsIntroFlag + dnsIntroducerHosts = []string{dnsIntroFlag} } if bootPeer := viper.GetString("switch-bootstrap-peer"); bootPeer != "" { bootstrapPeers = []string{bootPeer} @@ -172,7 +213,7 @@ func SwitchNetwork(networkName string, checkForRunningNode bool) { }, }, "full_node.database_path": fmt.Sprintf("db/blockchain_v2_%s.sqlite", networkName), - "full_node.dns_servers": []string{dnsIntroducerHost}, + "full_node.dns_servers": dnsIntroducerHosts, "full_node.peers_file_path": peersFilePath, "full_node.port": fullNodePort, "full_node.introducer_peer.host": introducerHost, @@ -187,7 +228,7 @@ func SwitchNetwork(networkName string, checkForRunningNode bool) { Port: fullNodePort, }, }, - "wallet.dns_servers": []string{dnsIntroducerHost}, + "wallet.dns_servers": dnsIntroducerHosts, "wallet.full_node_peers": []config.Peer{ { Host: "localhost", @@ -198,14 +239,14 @@ func SwitchNetwork(networkName string, checkForRunningNode bool) { "wallet.introducer_peer.port": fullNodePort, "wallet.wallet_peers_file_path": walletPeersFilePath, } - for path, value := range pathUpdates { - pathMap := config.ParsePathsFromStrings([]string{path}, false) + for configPath, value := range pathUpdates { + pathMap := config.ParsePathsFromStrings([]string{configPath}, false) var key string var pathSlice []string for key, pathSlice = range pathMap { break } - slogs.Logr.Debug("setting config path", "path", path, "value", value) + slogs.Logr.Debug("setting config path", "path", configPath, "value", value) err = cfg.SetFieldByPath(pathSlice, value) if err != nil { slogs.Logr.Fatal("error setting path in config", "key", key, "value", value, "error", err) diff --git a/cmd/network/switch_test.go b/cmd/network/switch_test.go index 16a2ae4..3c10282 100644 --- a/cmd/network/switch_test.go +++ b/cmd/network/switch_test.go @@ -88,3 +88,29 @@ func TestNetworkSwitch(t *testing.T) { assert.Equal(t, config.Peer{Host: "introducer-unittestnet.chia.net", Port: port}, cfg.Wallet.IntroducerPeer) assert.Equal(t, "wallet/db/wallet_peers-unittestnet.dat", cfg.Wallet.WalletPeersFilePath) } + +func TestNetworkSwitch_SettingRetention(t *testing.T) { + cmd.InitLogs() + setupDefaultConfig(t) + cfg, err := config.GetChiaConfig() + assert.NoError(t, err) + assert.Equal(t, "mainnet", *cfg.SelectedNetwork) + + // Set some custom dns introducers, and ensure they are back when swapping away and back to mainnet + cfg.FullNode.DNSServers = []string{"dns-mainnet-1.example.com", "dns-mainnet-2.example.com"} + err = cfg.Save() + assert.NoError(t, err) + + // reload config from disk to ensure the dns servers were persisted + cfg, err = config.GetChiaConfig() + assert.NoError(t, err) + assert.Equal(t, []string{"dns-mainnet-1.example.com", "dns-mainnet-2.example.com"}, cfg.FullNode.DNSServers) + + network.SwitchNetwork("unittestnet", false) + network.SwitchNetwork("mainnet", false) + + // reload config from disk + cfg, err = config.GetChiaConfig() + assert.NoError(t, err) + assert.Equal(t, []string{"dns-mainnet-1.example.com", "dns-mainnet-2.example.com"}, cfg.FullNode.DNSServers) +}