diff --git a/flyteadmin/pkg/executioncluster/impl/random_cluster_selector.go b/flyteadmin/pkg/executioncluster/impl/random_cluster_selector.go index 9a5487a2f..10bdd7714 100644 --- a/flyteadmin/pkg/executioncluster/impl/random_cluster_selector.go +++ b/flyteadmin/pkg/executioncluster/impl/random_cluster_selector.go @@ -32,6 +32,7 @@ type RandomClusterSelector struct { equalWeightedAllClusters random.WeightedRandomList labelWeightedRandomMap map[string]random.WeightedRandomList resourceManager managerInterfaces.ResourceInterface + defaultExecutionLabel string } func getRandSource(seed string) (rand.Source, error) { @@ -111,6 +112,7 @@ func (s RandomClusterSelector) GetTarget(ctx context.Context, spec *executionclu return nil, err } } + var weightedRandomList random.WeightedRandomList if resource != nil && resource.Attributes.GetExecutionClusterLabel() != nil { label := resource.Attributes.GetExecutionClusterLabel().Value @@ -123,8 +125,17 @@ func (s RandomClusterSelector) GetTarget(ctx context.Context, spec *executionclu } else { logger.Debugf(ctx, "No override found for the spec %v", spec) } - // If there is no label associated (or) if the label is invalid, choose from all enabled clusters. - // Note that if there is a valid label with zero "Enabled" clusters, we still choose from all enabled ones. + + if weightedRandomList == nil { + if s.defaultExecutionLabel != "" { + if _, ok := s.labelWeightedRandomMap[s.defaultExecutionLabel]; ok { + weightedRandomList = s.labelWeightedRandomMap[s.defaultExecutionLabel] + } else { + logger.Warnf(ctx, "No cluster mapping found for the default execution label %s", s.defaultExecutionLabel) + } + } + } + if weightedRandomList == nil { weightedRandomList = s.equalWeightedAllClusters } @@ -148,6 +159,9 @@ func (s RandomClusterSelector) GetTarget(ctx context.Context, spec *executionclu func NewRandomClusterSelector(listTargets interfaces.ListTargetsInterface, config runtime.Configuration, db repositoryInterfaces.Repository) (interfaces.ClusterInterface, error) { + + defaultExecutionLabel := config.ClusterConfiguration().GetDefaultExecutionLabel() + equalWeightedAllClusters, err := convertToRandomWeightedList(context.Background(), listTargets.GetValidTargets()) if err != nil { return nil, err @@ -161,5 +175,6 @@ func NewRandomClusterSelector(listTargets interfaces.ListTargetsInterface, confi resourceManager: resources.NewResourceManager(db, config.ApplicationConfiguration()), equalWeightedAllClusters: equalWeightedAllClusters, ListTargetsInterface: listTargets, + defaultExecutionLabel: defaultExecutionLabel, }, nil } diff --git a/flyteadmin/pkg/executioncluster/impl/random_cluster_selector_test.go b/flyteadmin/pkg/executioncluster/impl/random_cluster_selector_test.go index b2d05c84d..e30a92f32 100644 --- a/flyteadmin/pkg/executioncluster/impl/random_cluster_selector_test.go +++ b/flyteadmin/pkg/executioncluster/impl/random_cluster_selector_test.go @@ -23,14 +23,16 @@ import ( "github.com/stretchr/testify/assert" ) -const testProject = "project" -const testDomain = "domain" -const testWorkflow = "name" - const ( - testCluster1 = "testcluster1" - testCluster2 = "testcluster2" - testCluster3 = "testcluster3" + testProject = "project" + testDomain = "domain" + testWorkflow = "name" + testCluster1 = "testcluster1" + testCluster2 = "testcluster2" + testCluster3 = "testcluster3" + clusterConfig1 = "clusters_config.yaml" + clusterConfig2 = "clusters_config2.yaml" + clusterConfig2WithDefaultLabel = "clusters_config2_default_label.yaml" ) func initTestConfig(fileName string) error { @@ -47,7 +49,7 @@ func initTestConfig(fileName string) error { } func getRandomClusterSelectorForTest(t *testing.T) interfaces2.ClusterInterface { - err := initTestConfig("clusters_config.yaml") + err := initTestConfig(clusterConfig1) assert.NoError(t, err) db := repo_mock.NewMockRepository() @@ -119,6 +121,74 @@ func getRandomClusterSelectorForTest(t *testing.T) interfaces2.ClusterInterface return randomCluster } +func getRandomClusterSelectorWithDefaultLabelForTest(t *testing.T, configFile string) interfaces2.ClusterInterface { + err := initTestConfig(configFile) + assert.NoError(t, err) + + db := repo_mock.NewMockRepository() + db.ResourceRepo().(*repo_mock.MockResourceRepo).GetFunction = func(ctx context.Context, ID repo_interface.ResourceID) (resource models.Resource, e error) { + assert.Equal(t, "EXECUTION_CLUSTER_LABEL", ID.ResourceType) + if ID.Project == "" { + return models.Resource{}, errors.NewFlyteAdminErrorf(codes.NotFound, + "Resource [%+v] not found", ID) + } + response := models.Resource{ + Project: ID.Project, + Domain: ID.Domain, + Workflow: ID.Workflow, + ResourceType: ID.ResourceType, + LaunchPlan: ID.LaunchPlan, + } + if ID.Project == testProject && ID.Domain == testDomain { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_ExecutionClusterLabel{ + ExecutionClusterLabel: &admin.ExecutionClusterLabel{ + Value: "two", + }, + }, + } + marshalledMatchingAttributes, _ := proto.Marshal(matchingAttributes) + response.Attributes = marshalledMatchingAttributes + } + return response, nil + } + configProvider := runtime.NewConfigurationProvider() + listTargetsProvider := mocks.ListTargetsInterface{} + validTargets := map[string]*executioncluster.ExecutionTarget{ + testCluster1: { + ID: testCluster1, + Enabled: true, + }, + testCluster2: { + ID: testCluster2, + Enabled: true, + }, + testCluster3: { + ID: testCluster3, + Enabled: true, + }, + } + targets := map[string]*executioncluster.ExecutionTarget{ + testCluster1: { + ID: testCluster1, + Enabled: true, + }, + testCluster2: { + ID: testCluster2, + Enabled: true, + }, + testCluster3: { + ID: testCluster3, + Enabled: true, + }, + } + listTargetsProvider.OnGetValidTargets().Return(validTargets) + listTargetsProvider.OnGetAllTargets().Return(targets) + randomCluster, err := NewRandomClusterSelector(&listTargetsProvider, configProvider, db) + assert.NoError(t, err) + return randomCluster +} + func TestRandomClusterSelectorGetTarget(t *testing.T) { cluster := getRandomClusterSelectorForTest(t) target, err := cluster.GetTarget(context.Background(), &executioncluster.ExecutionTargetSpec{TargetID: testCluster1}) @@ -198,3 +268,42 @@ func TestRandomClusterSelectorGetAllValidTargets(t *testing.T) { targets := cluster.GetValidTargets() assert.Equal(t, 2, len(targets)) } + +func TestRandomClusterSelectorGetTargetWithFallbackToDefault1(t *testing.T) { + cluster := getRandomClusterSelectorWithDefaultLabelForTest(t, clusterConfig2) + target, err := cluster.GetTarget(context.Background(), &executioncluster.ExecutionTargetSpec{ + Project: testProject, + Domain: "different", + Workflow: testWorkflow, + ExecutionID: "e3", + }) + assert.Nil(t, err) + assert.Equal(t, testCluster3, target.ID) + assert.True(t, target.Enabled) +} + +func TestRandomClusterSelectorGetTargetWithFallbackToDefault2(t *testing.T) { + cluster := getRandomClusterSelectorWithDefaultLabelForTest(t, clusterConfig2) + target, err := cluster.GetTarget(context.Background(), &executioncluster.ExecutionTargetSpec{ + Project: testProject, + Domain: testDomain, + Workflow: testWorkflow, + ExecutionID: "e3", + }) + assert.Nil(t, err) + assert.Equal(t, testCluster2, target.ID) + assert.True(t, target.Enabled) +} + +func TestRandomClusterSelectorGetTargetWithFallbackToDefault3(t *testing.T) { + cluster := getRandomClusterSelectorWithDefaultLabelForTest(t, clusterConfig2WithDefaultLabel) + target, err := cluster.GetTarget(context.Background(), &executioncluster.ExecutionTargetSpec{ + Project: testProject, + Domain: "different", + Workflow: testWorkflow, + ExecutionID: "e3", + }) + assert.Nil(t, err) + assert.Equal(t, testCluster1, target.ID) + assert.True(t, target.Enabled) +} diff --git a/flyteadmin/pkg/executioncluster/impl/testdata/clusters_config2.yaml b/flyteadmin/pkg/executioncluster/impl/testdata/clusters_config2.yaml new file mode 100644 index 000000000..6ea8cfd8b --- /dev/null +++ b/flyteadmin/pkg/executioncluster/impl/testdata/clusters_config2.yaml @@ -0,0 +1,34 @@ +clusters: + labelClusterMap: + one: + - id: testcluster1 + weight: 1 + two: + - id: testcluster2 + weight: 1 + three: + - id: testcluster3 + weight: 1 + clusterConfigs: + - name: "testcluster1" + endpoint: "testcluster1_endpoint" + enabled: true + auth: + type: "file_path" + tokenPath: "/path/to/testcluster1/token" + certPath: "/path/to/testcluster1/cert" + - name: "testcluster2" + endpoint: "testcluster2_endpoint" + enabled: true + auth: + type: "file_path" + tokenPath: "/path/to/testcluster2/token" + certPath: "/path/to/testcluster2/cert" + - name: "testcluster3" + endpoint: "testcluster2_endpoint" + enabled: true + auth: + type: "file_path" + tokenPath: "/path/to/testcluster2/token" + certPath: "/path/to/testcluster2/cert" + diff --git a/flyteadmin/pkg/executioncluster/impl/testdata/clusters_config2_default_label.yaml b/flyteadmin/pkg/executioncluster/impl/testdata/clusters_config2_default_label.yaml new file mode 100644 index 000000000..cb290e3fe --- /dev/null +++ b/flyteadmin/pkg/executioncluster/impl/testdata/clusters_config2_default_label.yaml @@ -0,0 +1,35 @@ +clusters: + defaultExecutionLabel: one + labelClusterMap: + one: + - id: testcluster1 + weight: 1 + two: + - id: testcluster2 + weight: 1 + three: + - id: testcluster3 + weight: 1 + clusterConfigs: + - name: "testcluster1" + endpoint: "testcluster1_endpoint" + enabled: true + auth: + type: "file_path" + tokenPath: "/path/to/testcluster1/token" + certPath: "/path/to/testcluster1/cert" + - name: "testcluster2" + endpoint: "testcluster2_endpoint" + enabled: true + auth: + type: "file_path" + tokenPath: "/path/to/testcluster2/token" + certPath: "/path/to/testcluster2/cert" + - name: "testcluster3" + endpoint: "testcluster2_endpoint" + enabled: true + auth: + type: "file_path" + tokenPath: "/path/to/testcluster2/token" + certPath: "/path/to/testcluster2/cert" + diff --git a/flyteadmin/pkg/runtime/cluster_config_provider.go b/flyteadmin/pkg/runtime/cluster_config_provider.go index 13c2cea7e..7de2f1759 100644 --- a/flyteadmin/pkg/runtime/cluster_config_provider.go +++ b/flyteadmin/pkg/runtime/cluster_config_provider.go @@ -35,6 +35,15 @@ func (p *ClusterConfigurationProvider) GetClusterConfigs() []interfaces.ClusterC return make([]interfaces.ClusterConfig, 0) } +func (p *ClusterConfigurationProvider) GetDefaultExecutionLabel() string { + if clusterConfig != nil { + clusters := clusterConfig.GetConfig().(*interfaces.Clusters) + return clusters.DefaultExecutionLabel + } + logger.Debug(context.Background(), "Failed to find default execution label in config. Will use random cluster if no execution label matches.") + return "" +} + func NewClusterConfigurationProvider() interfaces.ClusterConfiguration { clusterConfigProvider := ClusterConfigurationProvider{} clusterNameMap := make(map[string]bool) diff --git a/flyteadmin/pkg/runtime/interfaces/cluster_configuration.go b/flyteadmin/pkg/runtime/interfaces/cluster_configuration.go index f6e3932fa..100e6f004 100644 --- a/flyteadmin/pkg/runtime/interfaces/cluster_configuration.go +++ b/flyteadmin/pkg/runtime/interfaces/cluster_configuration.go @@ -42,8 +42,9 @@ func (auth Auth) GetToken() (string, error) { } type Clusters struct { - ClusterConfigs []ClusterConfig `json:"clusterConfigs"` - LabelClusterMap map[string][]ClusterEntity `json:"labelClusterMap"` + ClusterConfigs []ClusterConfig `json:"clusterConfigs"` + LabelClusterMap map[string][]ClusterEntity `json:"labelClusterMap"` + DefaultExecutionLabel string `json:"defaultExecutionLabel"` } //go:generate mockery -name ClusterConfiguration -case=underscore -output=../mocks -case=underscore @@ -56,4 +57,7 @@ type ClusterConfiguration interface { // Returns label cluster map for routing GetLabelClusterMap() map[string][]ClusterEntity + + // Returns default execution label used as fallback if no execution cluster was explicitly defined. + GetDefaultExecutionLabel() string } diff --git a/flyteadmin/pkg/runtime/mocks/cluster_configuration.go b/flyteadmin/pkg/runtime/mocks/cluster_configuration.go index 538d3b720..1fd8033cc 100644 --- a/flyteadmin/pkg/runtime/mocks/cluster_configuration.go +++ b/flyteadmin/pkg/runtime/mocks/cluster_configuration.go @@ -46,6 +46,38 @@ func (_m *ClusterConfiguration) GetClusterConfigs() []interfaces.ClusterConfig { return r0 } +type ClusterConfiguration_GetDefaultExecutionLabel struct { + *mock.Call +} + +func (_m ClusterConfiguration_GetDefaultExecutionLabel) Return(_a0 string) *ClusterConfiguration_GetDefaultExecutionLabel { + return &ClusterConfiguration_GetDefaultExecutionLabel{Call: _m.Call.Return(_a0)} +} + +func (_m *ClusterConfiguration) OnGetDefaultExecutionLabel() *ClusterConfiguration_GetDefaultExecutionLabel { + c_call := _m.On("GetDefaultExecutionLabel") + return &ClusterConfiguration_GetDefaultExecutionLabel{Call: c_call} +} + +func (_m *ClusterConfiguration) OnGetDefaultExecutionLabelMatch(matchers ...interface{}) *ClusterConfiguration_GetDefaultExecutionLabel { + c_call := _m.On("GetDefaultExecutionLabel", matchers...) + return &ClusterConfiguration_GetDefaultExecutionLabel{Call: c_call} +} + +// GetDefaultExecutionLabel provides a mock function with given fields: +func (_m *ClusterConfiguration) GetDefaultExecutionLabel() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + type ClusterConfiguration_GetLabelClusterMap struct { *mock.Call }