From 8340cb071d8cbbb9888aa2d82232b6ffb6e7de10 Mon Sep 17 00:00:00 2001 From: Max Ma Date: Tue, 12 Nov 2024 23:24:16 +0100 Subject: [PATCH] add cache for network nodes --- logic/nodes.go | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/logic/nodes.go b/logic/nodes.go index 34eebe2e4..0a3c5eafa 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "maps" "net" + "slices" "sort" "sync" "time" @@ -24,8 +26,10 @@ import ( ) var ( - nodeCacheMutex = &sync.RWMutex{} - nodesCacheMap = make(map[string]models.Node) + nodeCacheMutex = &sync.RWMutex{} + nodeNetworkCacheMutex = &sync.RWMutex{} + nodesCacheMap = make(map[string]models.Node) + nodesNetworkCacheMap = make(map[string]map[string]models.Node) ) func getNodeFromCache(nodeID string) (node models.Node, ok bool) { @@ -48,6 +52,20 @@ func deleteNodeFromCache(nodeID string) { delete(nodesCacheMap, nodeID) nodeCacheMutex.Unlock() } +func deleteNodeFromNetworkCache(nodeID string, network string) { + nodeNetworkCacheMutex.Lock() + delete(nodesNetworkCacheMap[network], nodeID) + nodeNetworkCacheMutex.Unlock() +} + +func storeNodeInNetworkCache(node models.Node, network string) { + nodeNetworkCacheMutex.Lock() + if nodesNetworkCacheMap[network] == nil { + nodesNetworkCacheMap[network] = make(map[string]models.Node) + } + nodesNetworkCacheMap[network][node.ID.String()] = node + nodeNetworkCacheMutex.Unlock() +} func storeNodeInCache(node models.Node) { nodeCacheMutex.Lock() @@ -77,6 +95,11 @@ const ( // GetNetworkNodes - gets the nodes of a network func GetNetworkNodes(network string) ([]models.Node, error) { + nodeNetworkCacheMutex.Lock() + defer nodeNetworkCacheMutex.Unlock() + if networkNodes, ok := nodesNetworkCacheMap[network]; ok { + return slices.Collect(maps.Values(networkNodes)), nil + } allnodes, err := GetAllNodes() if err != nil { return []models.Node{}, err @@ -99,6 +122,11 @@ func GetHostNodes(host *models.Host) []models.Node { // GetNetworkNodesMemory - gets all nodes belonging to a network from list in memory func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node { + nodeNetworkCacheMutex.Lock() + defer nodeNetworkCacheMutex.Unlock() + if networkNodes, ok := nodesNetworkCacheMap[network]; ok { + return slices.Collect(maps.Values(networkNodes)) + } var nodes = []models.Node{} for i := range allNodes { node := allNodes[i] @@ -123,6 +151,7 @@ func UpdateNodeCheckin(node *models.Node) error { } if servercfg.CacheEnabled() { storeNodeInCache(*node) + storeNodeInNetworkCache(*node, node.Network) } return nil } @@ -140,6 +169,7 @@ func UpsertNode(newNode *models.Node) error { } if servercfg.CacheEnabled() { storeNodeInCache(*newNode) + storeNodeInNetworkCache(*newNode, newNode.Network) } return nil } @@ -179,6 +209,7 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { } if servercfg.CacheEnabled() { storeNodeInCache(*newNode) + storeNodeInNetworkCache(*newNode, newNode.Network) } return nil } @@ -288,6 +319,7 @@ func DeleteNodeByID(node *models.Node) error { } if servercfg.CacheEnabled() { deleteNodeFromCache(node.ID.String()) + deleteNodeFromNetworkCache(node.ID.String(), node.Network) } if servercfg.IsDNSMode() { SetDNS() @@ -459,6 +491,7 @@ func GetNodeByID(uuid string) (models.Node, error) { } if servercfg.CacheEnabled() { storeNodeInCache(node) + storeNodeInNetworkCache(node, node.Network) } return node, nil } @@ -612,6 +645,7 @@ func createNode(node *models.Node) error { } if servercfg.CacheEnabled() { storeNodeInCache(*node) + storeNodeInNetworkCache(*node, node.Network) } if _, ok := allocatedIpMap[node.Network]; ok { if node.Address.IP != nil {