Skip to content

Commit

Permalink
De-duplicate code, remove reflection for default case, shorten distri…
Browse files Browse the repository at this point in the history
…bution methods name, ran codegen on manifests

Signed-off-by: Akram Ben Aissi <[email protected]>

Signed-off-by: Akram Ben Aissi <[email protected]>
  • Loading branch information
akram committed Jun 5, 2023
1 parent c7ea155 commit 095cb58
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"
"math"
"reflect"
"runtime"
"time"

"github.com/argoproj/pkg/stats"
Expand Down Expand Up @@ -217,19 +215,10 @@ func getClusterFilter(kubeClient *kubernetes.Clientset, settingsMgr *settings.Se
shard, err = sharding.InferShard()
errors.CheckError(err)
}
distributionFunction := sharding.GetShardByIdUsingHashDistributionFunction()
log.Infof("Processing clusters from shard %d", shard)
db := db.NewDB(settingsMgr.GetNamespace(), settingsMgr, kubeClient)
log.Infof("Using filter function: %s", shardingAlgorithm)
switch {
case shardingAlgorithm == common.RoundRobinShardingAlgorithm:
distributionFunction = sharding.GetShardByIndexModuloReplicasCountDistributionFunction(db, shardingAlgorithm)
case shardingAlgorithm == common.LegacyShardingAlgorithm:
default:
distributionFunctionName := runtime.FuncForPC(reflect.ValueOf(distributionFunction).Pointer())
log.Warnf("No distribution function named '%s' found. Defaulting to '%s'", shardingAlgorithm, distributionFunctionName)
}

distributionFunction := sharding.GetDistributionFunction(db, shardingAlgorithm)
clusterFilter = sharding.GetClusterFilter(distributionFunction, shard)
} else {
log.Info("Processing all cluster shards")
Expand Down
119 changes: 67 additions & 52 deletions controller/sharding/sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"hash/fnv"
"math"
"os"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
Expand All @@ -23,22 +21,13 @@ import (
// Make it overridable for testing
var osHostnameFunction = os.Hostname

func InferShard() (int, error) {
hostname, err := osHostnameFunction()
if err != nil {
return 0, err
}
parts := strings.Split(hostname, "-")
if len(parts) == 0 {
return 0, fmt.Errorf("hostname should ends with shard number separated by '-' but got: %s", hostname)
}
shard, err := strconv.Atoi(parts[len(parts)-1])
if err != nil {
return 0, fmt.Errorf("hostname should ends with shard number separated by '-' but got: %s", hostname)
}
return int(shard), nil
}
type DistributionFunction func(c *v1alpha1.Cluster) int
type ClusterFilterFunction func(c *v1alpha1.Cluster) bool

// GetClusterFilter returns a ClusterFilterFunction which is a function taking a cluster as a parameter
// and returns wheter or not the cluster should be processed by a given shard. It calls the distributionFunction
// to determine which shard will process the cluster, and if the given shard is equal to the calculated shard
// the function will return true.
func GetClusterFilter(distributionFunction DistributionFunction, shard int) ClusterFilterFunction {
replicas := env.ParseNumFromEnv(common.EnvControllerReplicas, 0, 0, math.MaxInt32)
return func(c *v1alpha1.Cluster) bool {
Expand All @@ -57,22 +46,57 @@ func GetClusterFilter(distributionFunction DistributionFunction, shard int) Clus
}
}

// GetDistributionFunction returns which DistributionFunction should be used based on the passed algorithm and
// the current datas.
func GetDistributionFunction(db db.ArgoDB, shardingAlgorithm string) DistributionFunction {
log.Infof("Using filter function: %s", shardingAlgorithm)
distributionFunction := GetShardByIdUsingHashDistributionFunction()
switch {
case shardingAlgorithm == common.RoundRobinShardingAlgorithm:
distributionFunction = GetShardByIndexModuloReplicasCountDistributionFunction(db, shardingAlgorithm)
case shardingAlgorithm == common.LegacyShardingAlgorithm:
distributionFunction = GetShardByIdUsingHashDistributionFunction()
distributionFunction := LegacyDistributionFunction()
switch shardingAlgorithm {
case common.RoundRobinShardingAlgorithm:
distributionFunction = RoundRobinDistributionFunction(db)
case common.LegacyShardingAlgorithm:
distributionFunction = LegacyDistributionFunction()
default:
distributionFunctionName := runtime.FuncForPC(reflect.ValueOf(distributionFunction).Pointer())
log.Warnf("distribution type %s is not supported, defaulting to %s", shardingAlgorithm, distributionFunctionName)
log.Warnf("distribution type %s is not supported, defaulting to %s", shardingAlgorithm, common.DefaultShardingAlgorithm)
}
return distributionFunction
}

func GetShardByIndexModuloReplicasCountDistributionFunction(db db.ArgoDB, shardingAlgorithm string) DistributionFunction {
// LegacyDistributionFunction returns a DistributionFunction using a stable distribution algorithm:
// for a given cluster the function will return the shard number based on the cluster id. This function
// is lightweight and can be distributed easily, however, it does not ensure an homogenous distribution as
// some shards may get assigned more clusters than others. It is the legacy function distribution that is
// kept for compatibility reasons
func LegacyDistributionFunction() DistributionFunction {
replicas := env.ParseNumFromEnv(common.EnvControllerReplicas, 0, 0, math.MaxInt32)
return func(c *v1alpha1.Cluster) int {
if replicas == 0 {
return -1
}
if c == nil {
return 0
}
id := c.ID
log.Debugf("Calculating cluster shard for cluster id: %s", id)
if id == "" {
return 0
} else {
h := fnv.New32a()
_, _ = h.Write([]byte(id))
shard := int32(h.Sum32() % uint32(replicas))
log.Infof("Cluster with id=%s will be processed by shard %d", id, shard)
return int(shard)
}
}
}

// RoundRobinDistributionFunction returns a DistributionFunction using an homogeneous distribution algorithm:
// for a given cluster the function will return the shard number based on the modulo of the cluster rank in
// the cluster's list sorted by uid on the shard number.
// This function ensures an homogenous distribution: each shards got assigned the same number of
// clusters +/-1 , but with the drawback of a reshuffling of clusters accross shards in case of some changes
// in the cluster list
func RoundRobinDistributionFunction(db db.ArgoDB) DistributionFunction {
replicas := env.ParseNumFromEnv(common.EnvControllerReplicas, 0, 0, math.MaxInt32)
return func(c *v1alpha1.Cluster) int {
if replicas > 0 {
Expand All @@ -95,6 +119,23 @@ func GetShardByIndexModuloReplicasCountDistributionFunction(db db.ArgoDB, shardi
}
}

// InferShard extracts the shard index based on its hostname.
func InferShard() (int, error) {
hostname, err := osHostnameFunction()
if err != nil {
return 0, err
}
parts := strings.Split(hostname, "-")
if len(parts) == 0 {
return 0, fmt.Errorf("hostname should ends with shard number separated by '-' but got: %s", hostname)
}
shard, err := strconv.Atoi(parts[len(parts)-1])
if err != nil {
return 0, fmt.Errorf("hostname should ends with shard number separated by '-' but got: %s", hostname)
}
return int(shard), nil
}

func getSortedClustersList(db db.ArgoDB) []v1alpha1.Cluster {
ctx := context.Background()
clustersList, dbErr := db.ListClusters(ctx)
Expand All @@ -121,29 +162,3 @@ func createClusterIndexByClusterIdMap(db db.ArgoDB) map[string]int {
}
return clusterIndexedByClusterId
}

func GetShardByIdUsingHashDistributionFunction() DistributionFunction {
replicas := env.ParseNumFromEnv(common.EnvControllerReplicas, 0, 0, math.MaxInt32)
return func(c *v1alpha1.Cluster) int {
if replicas == 0 {
return -1
}
if c == nil {
return 0
}
id := c.ID
log.Debugf("Calculating cluster shard for cluster id: %s", id)
if id == "" {
return 0
} else {
h := fnv.New32a()
_, _ = h.Write([]byte(id))
shard := int32(h.Sum32() % uint32(replicas))
log.Infof("Cluster with id=%s will be processed by shard %d", id, shard)
return int(shard)
}
}
}

type DistributionFunction func(c *v1alpha1.Cluster) int
type ClusterFilterFunction func(c *v1alpha1.Cluster) bool
28 changes: 14 additions & 14 deletions controller/sharding/sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@ import (

func TestGetShardByID_NotEmptyID(t *testing.T) {
os.Setenv(common.EnvControllerReplicas, "1")
assert.Equal(t, 0, GetShardByIdUsingHashDistributionFunction()(&v1alpha1.Cluster{ID: "1"}))
assert.Equal(t, 0, GetShardByIdUsingHashDistributionFunction()(&v1alpha1.Cluster{ID: "2"}))
assert.Equal(t, 0, GetShardByIdUsingHashDistributionFunction()(&v1alpha1.Cluster{ID: "3"}))
assert.Equal(t, 0, GetShardByIdUsingHashDistributionFunction()(&v1alpha1.Cluster{ID: "4"}))
assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "1"}))
assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "2"}))
assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "3"}))
assert.Equal(t, 0, LegacyDistributionFunction()(&v1alpha1.Cluster{ID: "4"}))
}

func TestGetShardByID_EmptyID(t *testing.T) {
os.Setenv(common.EnvControllerReplicas, "1")
distributionFunction := GetShardByIdUsingHashDistributionFunction
distributionFunction := LegacyDistributionFunction
shard := distributionFunction()(&v1alpha1.Cluster{})
assert.Equal(t, 0, shard)
}

func TestGetShardByID_NoReplicas(t *testing.T) {
os.Setenv(common.EnvControllerReplicas, "0")
distributionFunction := GetShardByIdUsingHashDistributionFunction
distributionFunction := LegacyDistributionFunction
shard := distributionFunction()(&v1alpha1.Cluster{})
assert.Equal(t, -1, shard)
}

func TestGetShardByID_NoReplicasUsingHashDistributionFunction(t *testing.T) {
os.Setenv(common.EnvControllerReplicas, "0")
distributionFunction := GetShardByIdUsingHashDistributionFunction
distributionFunction := LegacyDistributionFunction
shard := distributionFunction()(&v1alpha1.Cluster{})
assert.Equal(t, -1, shard)
}
Expand All @@ -47,7 +47,7 @@ func TestGetShardByID_NoReplicasUsingHashDistributionFunctionWithClusters(t *tes
// Test with replicas set to 0
os.Setenv(common.EnvControllerReplicas, "0")
os.Setenv(common.EnvControllerShardingAlgorithm, common.RoundRobinShardingAlgorithm)
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(db, common.RoundRobinShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(db)
assert.Equal(t, -1, distributionFunction(nil))
assert.Equal(t, -1, distributionFunction(&cluster1))
assert.Equal(t, -1, distributionFunction(&cluster2))
Expand Down Expand Up @@ -180,7 +180,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunction2(t *testing.T) {
db, cluster1, cluster2, cluster3, cluster4, cluster5 := createTestClusters()
// Test with replicas set to 1
os.Setenv(common.EnvControllerReplicas, "1")
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(db, common.DefaultShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(db)
assert.Equal(t, 0, distributionFunction(nil))
assert.Equal(t, 0, distributionFunction(&cluster1))
assert.Equal(t, 0, distributionFunction(&cluster2))
Expand All @@ -190,7 +190,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunction2(t *testing.T) {

// Test with replicas set to 2
os.Setenv(common.EnvControllerReplicas, "2")
distributionFunction = GetShardByIndexModuloReplicasCountDistributionFunction(db, common.DefaultShardingAlgorithm)
distributionFunction = RoundRobinDistributionFunction(db)
assert.Equal(t, 0, distributionFunction(nil))
assert.Equal(t, 0, distributionFunction(&cluster1))
assert.Equal(t, 1, distributionFunction(&cluster2))
Expand All @@ -200,7 +200,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunction2(t *testing.T) {

// // Test with replicas set to 3
os.Setenv(common.EnvControllerReplicas, "3")
distributionFunction = GetShardByIndexModuloReplicasCountDistributionFunction(db, common.DefaultShardingAlgorithm)
distributionFunction = RoundRobinDistributionFunction(db)
assert.Equal(t, 0, distributionFunction(nil))
assert.Equal(t, 0, distributionFunction(&cluster1))
assert.Equal(t, 1, distributionFunction(&cluster2))
Expand All @@ -223,7 +223,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterNumber
}
db.On("ListClusters", mock.Anything).Return(clusterList, nil)
os.Setenv(common.EnvControllerReplicas, "2")
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(&db, common.DefaultShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(&db)
for i, c := range clusterList.Items {
assert.Equal(t, i%2, distributionFunction(&c))
}
Expand All @@ -243,7 +243,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterIsAdde

// Test with replicas set to 2
os.Setenv(common.EnvControllerReplicas, "2")
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(&db, common.DefaultShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(&db)
assert.Equal(t, 0, distributionFunction(nil))
assert.Equal(t, 0, distributionFunction(&cluster1))
assert.Equal(t, 1, distributionFunction(&cluster2))
Expand All @@ -265,7 +265,7 @@ func TestGetShardByIndexModuloReplicasCountDistributionFunctionWhenClusterIsAdde
func TestGetShardByIndexModuloReplicasCountDistributionFunction(t *testing.T) {
db, cluster1, cluster2, _, _, _ := createTestClusters()
os.Setenv(common.EnvControllerReplicas, "2")
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(db, common.DefaultShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(db)

// Test that the function returns the correct shard for cluster1 and cluster2
expectedShardForCluster1 := 0
Expand Down
4 changes: 2 additions & 2 deletions controller/sharding/shuffle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestLargeShuffle(t *testing.T) {
db.On("ListClusters", mock.Anything).Return(clusterList, nil)
// Test with replicas set to 256
os.Setenv(common.EnvControllerReplicas, "256")
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(&db, common.DefaultShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(&db)
for i, c := range clusterList.Items {
assert.Equal(t, i%2567, distributionFunction(&c))
}
Expand All @@ -48,7 +48,7 @@ func TestShuffle(t *testing.T) {

// Test with replicas set to 3
os.Setenv(common.EnvControllerReplicas, "3")
distributionFunction := GetShardByIndexModuloReplicasCountDistributionFunction(&db, common.DefaultShardingAlgorithm)
distributionFunction := RoundRobinDistributionFunction(&db)
assert.Equal(t, 0, distributionFunction(nil))
assert.Equal(t, 0, distributionFunction(&cluster1))
assert.Equal(t, 1, distributionFunction(&cluster2))
Expand Down
6 changes: 6 additions & 0 deletions manifests/core-install.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18899,6 +18899,12 @@ spec:
key: application.namespaces
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_CONTROLLER_SHARDING_ALGORITHM
valueFrom:
configMapKeyRef:
key: controller.sharding.algorithm
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_APPLICATION_CONTROLLER_KUBECTL_PARALLELISM_LIMIT
valueFrom:
configMapKeyRef:
Expand Down
6 changes: 6 additions & 0 deletions manifests/ha/install.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20686,6 +20686,12 @@ spec:
key: application.namespaces
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_CONTROLLER_SHARDING_ALGORITHM
valueFrom:
configMapKeyRef:
key: controller.sharding.algorithm
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_APPLICATION_CONTROLLER_KUBECTL_PARALLELISM_LIMIT
valueFrom:
configMapKeyRef:
Expand Down
6 changes: 6 additions & 0 deletions manifests/ha/namespace-install.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,12 @@ spec:
key: application.namespaces
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_CONTROLLER_SHARDING_ALGORITHM
valueFrom:
configMapKeyRef:
key: controller.sharding.algorithm
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_APPLICATION_CONTROLLER_KUBECTL_PARALLELISM_LIMIT
valueFrom:
configMapKeyRef:
Expand Down
6 changes: 6 additions & 0 deletions manifests/install.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19726,6 +19726,12 @@ spec:
key: application.namespaces
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_CONTROLLER_SHARDING_ALGORITHM
valueFrom:
configMapKeyRef:
key: controller.sharding.algorithm
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_APPLICATION_CONTROLLER_KUBECTL_PARALLELISM_LIMIT
valueFrom:
configMapKeyRef:
Expand Down
6 changes: 6 additions & 0 deletions manifests/namespace-install.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,12 @@ spec:
key: application.namespaces
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_CONTROLLER_SHARDING_ALGORITHM
valueFrom:
configMapKeyRef:
key: controller.sharding.algorithm
name: argocd-cmd-params-cm
optional: true
- name: ARGOCD_APPLICATION_CONTROLLER_KUBECTL_PARALLELISM_LIMIT
valueFrom:
configMapKeyRef:
Expand Down

0 comments on commit 095cb58

Please sign in to comment.