diff --git a/coredns/resolver/clusterip_service_test.go b/coredns/resolver/clusterip_service_test.go index 0fdf5da02..4259e4420 100644 --- a/coredns/resolver/clusterip_service_test.go +++ b/coredns/resolver/clusterip_service_test.go @@ -69,7 +69,7 @@ func testClusterIPServiceInOneCluster() { Context("and it becomes disconnected", func() { BeforeEach(func() { - t.clusterStatus.ConnectedClusterIDs.RemoveAll() + t.clusterStatus.DisconnectAll() }) It("should return no DNS records", func() { @@ -139,7 +139,7 @@ func testClusterIPServiceInTwoClusters() { Context("and one is the local cluster", func() { BeforeEach(func() { - t.clusterStatus.LocalClusterID.Store(clusterID1) + t.clusterStatus.SetLocalClusterID(clusterID1) }) It("should consistently return its DNS record", func() { @@ -157,7 +157,7 @@ func testClusterIPServiceInTwoClusters() { } BeforeEach(func() { - t.clusterStatus.ConnectedClusterIDs.Remove(clusterID1) + t.clusterStatus.DisconnectClusterID(clusterID1) }) Context("and no specific cluster is requested", func() { @@ -187,7 +187,7 @@ func testClusterIPServiceInTwoClusters() { Context("and both become disconnected", func() { BeforeEach(func() { - t.clusterStatus.ConnectedClusterIDs.RemoveAll() + t.clusterStatus.DisconnectAll() }) It("should return no DNS records", func() { @@ -251,7 +251,7 @@ func testClusterIPServiceInTwoClusters() { Context("and a non-existent local cluster is specified", func() { BeforeEach(func() { - t.clusterStatus.LocalClusterID.Store("non-existent") + t.clusterStatus.SetLocalClusterID("non-existent") }) It("should consistently return the DNS records round-robin", func() { @@ -301,7 +301,7 @@ func testClusterIPServiceInThreeClusters() { Context("and one becomes disconnected", func() { BeforeEach(func() { - t.clusterStatus.ConnectedClusterIDs.Remove(clusterID3) + t.clusterStatus.DisconnectClusterID(clusterID3) }) It("should consistently return the connected clusters' DNS records round-robin", func() { @@ -333,7 +333,7 @@ func testClusterIPServiceInThreeClusters() { Context("and one becomes disconnected and one becomes unhealthy", func() { BeforeEach(func() { - t.clusterStatus.ConnectedClusterIDs.Remove(clusterID2) + t.clusterStatus.DisconnectClusterID(clusterID2) t.putEndpointSlice(newClusterIPEndpointSlice(namespace1, service1, clusterID3, serviceIP3, false)) }) diff --git a/coredns/resolver/fake/cluster_status.go b/coredns/resolver/fake/cluster_status.go index 2dbaa4b95..5931fc57c 100644 --- a/coredns/resolver/fake/cluster_status.go +++ b/coredns/resolver/fake/cluster_status.go @@ -19,29 +19,53 @@ limitations under the License. package fake import ( + "sync" "sync/atomic" - "github.com/submariner-io/admiral/pkg/stringset" + "k8s.io/apimachinery/pkg/util/sets" ) type ClusterStatus struct { - ConnectedClusterIDs stringset.Interface - LocalClusterID atomic.Value + mutex sync.Mutex + connectedClusterIDs sets.Set[string] + localClusterID atomic.Value } func NewClusterStatus(localClusterID string, isConnected ...string) *ClusterStatus { c := &ClusterStatus{ - ConnectedClusterIDs: stringset.NewSynchronized(isConnected...), + connectedClusterIDs: sets.New(isConnected...), } - c.LocalClusterID.Store(localClusterID) + c.localClusterID.Store(localClusterID) + return c } func (c *ClusterStatus) IsConnected(clusterID string) bool { - return c.ConnectedClusterIDs.Contains(clusterID) + c.mutex.Lock() + defer c.mutex.Unlock() + + return c.connectedClusterIDs.Has(clusterID) +} + +func (c *ClusterStatus) SetLocalClusterID(clusterID string) { + c.localClusterID.Store(clusterID) } func (c *ClusterStatus) GetLocalClusterID() string { - return c.LocalClusterID.Load().(string) + return c.localClusterID.Load().(string) +} + +func (c *ClusterStatus) DisconnectAll() { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.connectedClusterIDs.Delete(c.connectedClusterIDs.UnsortedList()...) +} + +func (c *ClusterStatus) DisconnectClusterID(clusterID string) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.connectedClusterIDs.Delete(clusterID) } diff --git a/coredns/resolver/headless_service_test.go b/coredns/resolver/headless_service_test.go index 1505df0bc..a44dbc79a 100644 --- a/coredns/resolver/headless_service_test.go +++ b/coredns/resolver/headless_service_test.go @@ -137,7 +137,7 @@ func testHeadlessServiceInMultipleClusters() { Context("and one is on the local cluster", func() { BeforeEach(func() { - t.clusterStatus.LocalClusterID.Store(clusterID3) + t.clusterStatus.SetLocalClusterID(clusterID3) // If the local cluster EndpointSlice is created before the local K8s EndpointSlice, PutEndpointSlice should // return true to requeue. @@ -183,7 +183,7 @@ func testHeadlessServiceInMultipleClusters() { Context("and one becomes disconnected", func() { JustBeforeEach(func() { - t.clusterStatus.ConnectedClusterIDs.Remove(clusterID3) + t.clusterStatus.DisconnectClusterID(clusterID3) }) Context("and no specific cluster is requested", func() {