diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 644b70fa4..0e5a78ff5 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -537,7 +537,8 @@ func (m *ExecutionManager) getInheritedExecMetadata(ctx context.Context, request // Produces execution-time attributes for workflow execution. // Defaults to overridable execution values set in the execution create request, then looks at the launch plan values -// (if any) before defaulting to values set in the matchable resource db. +// (if any) before defaulting to values set in the matchable resource db and further if matchable resources don't +// exist then defaults to one set in application configuration func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admin.ExecutionCreateRequest, launchPlan *admin.LaunchPlan) (*admin.WorkflowExecutionConfig, error) { if request.Spec.MaxParallelism > 0 { @@ -565,7 +566,10 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi if resource != nil && resource.Attributes.GetWorkflowExecutionConfig() != nil { return resource.Attributes.GetWorkflowExecutionConfig(), nil } - return nil, nil + // Defaults to one from the application config + return &admin.WorkflowExecutionConfig{ + MaxParallelism: m.config.ApplicationConfiguration().GetTopLevelConfig().GetMaxParallelism(), + }, nil } func (m *ExecutionManager) launchSingleTaskExecution( diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 02b2edc96..c6de600cf 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -9,6 +9,7 @@ import ( managerInterfaces "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" managerMocks "github.com/flyteorg/flyteadmin/pkg/manager/mocks" + "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flyteidl/clients/go/coreutils" @@ -3300,9 +3301,10 @@ func TestGetExecutionConfig_Spec(t *testing.T) { t.Errorf("When a user specifies max parallelism in a spec, the db should not be queried") return nil, nil } - + applicationConfig := runtime.NewConfigurationProvider() executionManager := ExecutionManager{ resourceManager: &resourceManager, + config: applicationConfig, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), &admin.ExecutionCreateRequest{ Project: workflowIdentifier.Project, @@ -3329,6 +3331,26 @@ func TestGetExecutionConfig_Spec(t *testing.T) { }) assert.NoError(t, err) assert.Equal(t, execConfig.MaxParallelism, int32(50)) + + resourceManager = managerMocks.MockResourceManager{} + resourceManager.GetResourceFunc = func(ctx context.Context, + request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { + return nil, nil + } + executionManager = ExecutionManager{ + resourceManager: &resourceManager, + config: applicationConfig, + } + + execConfig, err = executionManager.getExecutionConfig(context.TODO(), &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{}, + }, &admin.LaunchPlan{ + Spec: &admin.LaunchPlanSpec{}, + }) + assert.NoError(t, err) + assert.Equal(t, execConfig.MaxParallelism, int32(25)) } func TestResolvePermissions(t *testing.T) { diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go index 95e851401..c3fa21eae 100644 --- a/pkg/rpc/adminservice/base.go +++ b/pkg/rpc/adminservice/base.go @@ -60,7 +60,7 @@ func NewAdminServer(kubeConfig, master string) *AdminService { configuration := runtime.NewConfigurationProvider() applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() - adminScope := promutils.NewScope(applicationConfiguration.MetricsScope).NewSubScope("admin") + adminScope := promutils.NewScope(applicationConfiguration.GetMetricsScope()).NewSubScope("admin") defer func() { if err := recover(); err != nil { @@ -92,10 +92,10 @@ func NewAdminServer(kubeConfig, master string) *AdminService { configuration, db) workflowExecutor := workflowengine.NewFlytePropeller( - applicationConfiguration.RoleNameKey, + applicationConfiguration.GetRoleNameKey(), execCluster, adminScope.NewSubScope("executor").NewSubScope("flytepropeller"), - configuration.NamespaceMappingConfiguration(), applicationConfiguration.EventVersion) + configuration.NamespaceMappingConfiguration(), applicationConfiguration.GetEventVersion()) logger.Info(context.Background(), "Successfully created a workflow executor engine") dataStorageClient, err := storage.NewDataStore(storeConfig, adminScope.NewSubScope("storage")) if err != nil { @@ -135,11 +135,11 @@ func NewAdminServer(kubeConfig, master string) *AdminService { }).GetRemoteURLInterface() workflowManager := manager.NewWorkflowManager( - db, configuration, workflowengine.NewCompiler(), dataStorageClient, applicationConfiguration.MetadataStoragePrefix, + db, configuration, workflowengine.NewCompiler(), dataStorageClient, applicationConfiguration.GetMetadataStoragePrefix(), adminScope.NewSubScope("workflow_manager")) namedEntityManager := manager.NewNamedEntityManager(db, configuration, adminScope.NewSubScope("named_entity_manager")) - executionEventWriter := eventWriter.NewWorkflowExecutionEventWriter(db, applicationConfiguration.AsyncEventsBufferSize) + executionEventWriter := eventWriter.NewWorkflowExecutionEventWriter(db, applicationConfiguration.GetAsyncEventsBufferSize()) go func() { executionEventWriter.Run() }() @@ -159,13 +159,13 @@ func NewAdminServer(kubeConfig, master string) *AdminService { // Serve profiling endpoints. go func() { err := profutils.StartProfilingServerWithDefaultHandlers( - context.Background(), applicationConfiguration.ProfilerPort, nil) + context.Background(), applicationConfiguration.GetProfilerPort(), nil) if err != nil { logger.Panicf(context.Background(), "Failed to Start profiling and Metrics server. Error, %v", err) } }() - nodeExecutionEventWriter := eventWriter.NewNodeExecutionEventWriter(db, applicationConfiguration.AsyncEventsBufferSize) + nodeExecutionEventWriter := eventWriter.NewNodeExecutionEventWriter(db, applicationConfiguration.GetAsyncEventsBufferSize()) go func() { nodeExecutionEventWriter.Run() }() @@ -179,7 +179,7 @@ func NewAdminServer(kubeConfig, master string) *AdminService { ExecutionManager: executionManager, NamedEntityManager: namedEntityManager, VersionManager: versionManager, - NodeExecutionManager: manager.NewNodeExecutionManager(db, configuration, applicationConfiguration.MetadataStoragePrefix, dataStorageClient, + NodeExecutionManager: manager.NewNodeExecutionManager(db, configuration, applicationConfiguration.GetMetadataStoragePrefix(), dataStorageClient, adminScope.NewSubScope("node_execution_manager"), urlData, eventPublisher, nodeExecutionEventWriter), TaskExecutionManager: manager.NewTaskExecutionManager(db, configuration, dataStorageClient, adminScope.NewSubScope("task_execution_manager"), urlData, eventPublisher), diff --git a/pkg/runtime/application_config_provider.go b/pkg/runtime/application_config_provider.go index 1bfad5aca..c98e1572a 100644 --- a/pkg/runtime/application_config_provider.go +++ b/pkg/runtime/application_config_provider.go @@ -37,7 +37,9 @@ var flyteAdminConfig = config.MustRegisterSection(flyteAdmin, &interfaces.Applic MetadataStoragePrefix: []string{"metadata", "admin"}, EventVersion: 1, AsyncEventsBufferSize: 100, + MaxParallelism: 25, }) + var schedulerConfig = config.MustRegisterSection(scheduler, &interfaces.SchedulerConfig{ ProfilerPort: config.Port{Port: 10253}, EventSchedulerConfig: interfaces.EventSchedulerConfig{ diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index dc6b11840..42c477dc6 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -57,6 +57,38 @@ type ApplicationConfig struct { EventVersion int `json:"eventVersion"` // Specifies the shared buffer size which is used to queue asynchronous event writes. AsyncEventsBufferSize int `json:"asyncEventsBufferSize"` + // Controls the maximum number of task nodes that can be run in parallel for the entire workflow. + // This is useful to achieve fairness. Note: MapTasks are regarded as one unit, + // and parallelism/concurrency of MapTasks is independent from this. + MaxParallelism int32 `json:"maxParallelism"` +} + +func (a *ApplicationConfig) GetRoleNameKey() string { + return a.RoleNameKey +} + +func (a *ApplicationConfig) GetMetricsScope() string { + return a.MetricsScope +} + +func (a *ApplicationConfig) GetProfilerPort() int { + return a.ProfilerPort +} + +func (a *ApplicationConfig) GetMetadataStoragePrefix() []string { + return a.MetadataStoragePrefix +} + +func (a *ApplicationConfig) GetEventVersion() int { + return a.EventVersion +} + +func (a *ApplicationConfig) GetAsyncEventsBufferSize() int { + return a.AsyncEventsBufferSize +} + +func (a *ApplicationConfig) GetMaxParallelism() int32 { + return a.MaxParallelism } // This section holds common config for AWS