diff --git a/cmd/root.go b/cmd/root.go index 04a8da72..e674071a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -99,7 +99,6 @@ func InitConfig(viper *viper.Viper) { logger.FatalLog("could not create /etc/netclient dir" + err.Error()) } } - //wireguard.WriteWgConfig(Netclient(), GetNodes()) } func setupLogging(flags *viper.Viper) { diff --git a/functions/daemon.go b/functions/daemon.go index 5bf79039..e19d3a48 100644 --- a/functions/daemon.go +++ b/functions/daemon.go @@ -49,6 +49,7 @@ type cachedMessage struct { func Daemon() { slog.Info("starting netclient daemon", "version", config.Version) daemon.RemoveAllLockFiles() + go deleteAllDNS() if err := ncutils.SavePID(); err != nil { slog.Error("unable to save PID on daemon startup", "error", err) os.Exit(1) @@ -319,16 +320,6 @@ func setHostSubscription(client mqtt.Client, server string) { slog.Error("unable to subscribe to host updates", "host", hostID, "server", server, "error", token.Error()) return } - slog.Info("subscribing to dns updates for", "host", hostID, "server", server) - if token := client.Subscribe(fmt.Sprintf("dns/update/%s/%s", hostID.String(), server), 0, mqtt.MessageHandler(dnsUpdate)); token.Wait() && token.Error() != nil { - slog.Error("unable to subscribe to dns updates", "host", hostID, "server", server, "error", token.Error()) - return - } - slog.Info("subscribing to all dns updates for", "host", hostID, "server", server) - if token := client.Subscribe(fmt.Sprintf("dns/all/%s/%s", hostID.String(), server), 0, mqtt.MessageHandler(dnsAll)); token.Wait() && token.Error() != nil { - slog.Error("unable to subscribe to all dns updates", "host", hostID, "server", server, "error", token.Error()) - return - } } diff --git a/functions/dns.go b/functions/dns.go index 65815e3a..cf713e34 100644 --- a/functions/dns.go +++ b/functions/dns.go @@ -68,39 +68,3 @@ func deleteAllDNS() error { } return nil } - -func deleteNetworkDNS(network string) error { - temp := os.TempDir() - lockfile := temp + "/netclient-lock" - if err := config.Lock(lockfile); err != nil { - return err - } - defer config.Unlock(lockfile) - hosts, err := txeh.NewHostsDefault() - if err != nil { - return err - } - lines := hosts.GetHostFileLines() - addressesToRemove := []string{} - for _, line := range *lines { - if line.Comment == etcHostsComment { - if sliceContains(line.Hostnames, network) { - addressesToRemove = append(addressesToRemove, line.Address) - } - } - } - hosts.RemoveAddresses(addressesToRemove, etcHostsComment) - if err := hosts.Save(); err != nil { - return err - } - return nil -} - -func sliceContains(s []string, v string) bool { - for _, e := range s { - if strings.Contains(e, v) { - return true - } - } - return false -} diff --git a/functions/mqhandlers.go b/functions/mqhandlers.go index 8b813527..99589081 100644 --- a/functions/mqhandlers.go +++ b/functions/mqhandlers.go @@ -2,9 +2,7 @@ package functions import ( "encoding/json" - "log" "net" - "os" "strings" "time" @@ -16,7 +14,6 @@ import ( "github.com/gravitl/netclient/networking" "github.com/gravitl/netclient/wireguard" "github.com/gravitl/netmaker/models" - "github.com/gravitl/txeh" "golang.org/x/exp/slog" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -353,142 +350,6 @@ func parseServerFromTopic(topic string) string { return strings.Split(topic, "/")[3] } -// dnsUpdate - mq handler for host update dns//server -func dnsUpdate(client mqtt.Client, msg mqtt.Message) { - temp := os.TempDir() - lockfile := temp + "/netclient-lock" - if err := config.Lock(lockfile); err != nil { - slog.Error("could not create lock file", "error", err) - return - } - defer config.Unlock(lockfile) - var dns models.DNSUpdate - serverName := parseServerFromTopic(msg.Topic()) - server := config.GetServer(serverName) - if server == nil { - slog.Error("server not found in config", "server", serverName) - return - } - data, err := decryptMsg(serverName, msg.Payload()) - if err != nil { - return - } - if err := json.Unmarshal([]byte(data), &dns); err != nil { - slog.Error("error unmarshalling dns update", "error", err) - } - if config.Netclient().Debug { - log.Println("dnsUpdate received", dns) - } - var currentMessage = read("dns", lastDNSUpdate) - if currentMessage == string(data) { - slog.Info("cache hit on dns update ... skipping") - return - } - insert("dns", lastDNSUpdate, string(data)) - slog.Info("received dns update", "name", dns.Name, "address", dns.Address, "action", dns.Action) - applyDNSUpdate(dns) -} - -func applyDNSUpdate(dns models.DNSUpdate) { - if config.Netclient().Debug { - log.Println(dns) - } - hosts, err := txeh.NewHostsDefault() - if err != nil { - slog.Error("failed to read hosts file", "error", err) - return - } - switch dns.Action { - case models.DNSInsert: - // remove any existing entries - hosts.RemoveHost(dns.Name, etcHostsComment) - hosts.RemoveAddress(dns.Address, etcHostsComment) - hosts.AddHost(dns.Address, dns.Name, etcHostsComment) - case models.DNSDeleteByName: - hosts.RemoveHost(dns.Name, etcHostsComment) - case models.DNSDeleteByIP: - hosts.RemoveAddress(dns.Address, etcHostsComment) - case models.DNSReplaceName: - ok, ip, _ := hosts.HostAddressLookup(dns.Name, txeh.IPFamilyV4, etcHostsComment) - if !ok { - slog.Error("failed to find dns address for host", "host", dns.Name) - return - } - dns.Address = ip - hosts.RemoveHost(dns.Name, etcHostsComment) - hosts.AddHost(dns.Address, dns.NewName, etcHostsComment) - case models.DNSReplaceIP: - hosts.RemoveAddress(dns.Address, etcHostsComment) - hosts.AddHost(dns.NewAddress, dns.Name, etcHostsComment) - } - if err := hosts.Save(); err != nil { - slog.Error("error saving hosts file", "error", err) - return - } -} - -// dnsAll- mq handler for host update dnsall//server -func dnsAll(client mqtt.Client, msg mqtt.Message) { - temp := os.TempDir() - lockfile := temp + "/netclient-lock" - if err := config.Lock(lockfile); err != nil { - slog.Error("could not create lock file", "error", err) - return - } - defer config.Unlock(lockfile) - var dns []models.DNSUpdate - serverName := parseServerFromTopic(msg.Topic()) - server := config.GetServer(serverName) - if server == nil { - slog.Error("server not found in config", "server", serverName) - return - } - data, err := decryptMsg(serverName, msg.Payload()) - if err != nil { - return - } - if err := json.Unmarshal([]byte(data), &dns); err != nil { - slog.Error("error unmarshalling dns update", "error", err) - } - if config.Netclient().Debug { - log.Println("all dns", dns) - } - var currentMessage = read("dnsall", lastALLDNSUpdate) - slog.Info("received initial dns", "dns", dns) - if currentMessage == string(data) { - slog.Info("cache hit on all dns ... skipping") - if config.Netclient().Debug { - log.Println("dns cache", currentMessage, string(data)) - } - return - } - insert("dnsall", lastALLDNSUpdate, string(data)) - applyAllDNS(dns) -} - -func applyAllDNS(dns []models.DNSUpdate) { - hosts, err := txeh.NewHostsDefault() - if err != nil { - slog.Error("failed to read hosts file", "error", err) - return - } - for _, entry := range dns { - if entry.Action != models.DNSInsert { - slog.Info("invalid dns actions", "action", entry.Action) - continue - } - // remove any existing entries - hosts.RemoveHost(entry.Name, etcHostsComment) - hosts.RemoveAddress(entry.Address, etcHostsComment) - hosts.AddHost(entry.Address, entry.Name, etcHostsComment) - } - - if err := hosts.Save(); err != nil { - slog.Error("error saving hosts file", "error", err) - return - } -} - func getAllAllowedIPs(peers []wgtypes.PeerConfig) (cidrs []net.IPNet) { if len(peers) > 0 { // nil check for i := range peers { diff --git a/functions/uninstall.go b/functions/uninstall.go index 9f43dfe6..db781629 100644 --- a/functions/uninstall.go +++ b/functions/uninstall.go @@ -63,9 +63,6 @@ func LeaveNetwork(network string, isDaemon bool) ([]error, error) { if err := deleteLocalNetwork(&node); err != nil { faults = append(faults, fmt.Errorf("error deleting wireguard interface %w", err)) } - if err := deleteNetworkDNS(network); err != nil { - faults = append(faults, fmt.Errorf("error deleting dns entries %w", err)) - } // re-configure interface if daemon is calling leave if isDaemon { nc := wireguard.GetInterface()