diff --git a/flyteadmin/pkg/artifacts/registry.go b/flyteadmin/pkg/artifacts/registry.go index 5b5ebfb8fa..6dc4997bff 100644 --- a/flyteadmin/pkg/artifacts/registry.go +++ b/flyteadmin/pkg/artifacts/registry.go @@ -21,10 +21,14 @@ type ArtifactRegistry struct { Client artifacts.ArtifactRegistryClient } -func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *core.Identifier, ti core.TypedInterface) { +func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *core.Identifier, ti *core.TypedInterface) error { if a == nil || a.Client == nil { logger.Debugf(ctx, "Artifact Client not configured, skipping registration for task [%+v]", id) - return + return nil + } + if ti.GetOutputs() == nil { + logger.Infof(ctx, "No outputs for [%+v], skipping registration", id) + return nil } ap := &artifacts.ArtifactProducer{ @@ -36,26 +40,37 @@ func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *cor }) if err != nil { logger.Errorf(ctx, "Failed to register artifact producer for task [%+v] with err: %v", id, err) + return err } logger.Debugf(ctx, "Registered artifact producer [%+v]", id) + + return nil } -func (a *ArtifactRegistry) RegisterArtifactConsumer(ctx context.Context, id *core.Identifier, pm core.ParameterMap) { +func (a *ArtifactRegistry) RegisterArtifactConsumer(ctx context.Context, id *core.Identifier, pm *core.ParameterMap) error { if a == nil || a.Client == nil { logger.Debugf(ctx, "Artifact Client not configured, skipping registration for consumer [%+v]", id) - return + return nil } + if pm.GetParameters() == nil { + logger.Infof(ctx, "No inputs for [%+v], skipping registration", id) + return nil + } + ac := &artifacts.ArtifactConsumer{ EntityId: id, - Inputs: &pm, + Inputs: pm, } _, err := a.Client.RegisterConsumer(ctx, &artifacts.RegisterConsumerRequest{ Consumers: []*artifacts.ArtifactConsumer{ac}, }) if err != nil { logger.Errorf(ctx, "Failed to register artifact consumer for entity [%+v] with err: %v", id, err) + return err } logger.Debugf(ctx, "Registered artifact consumer [%+v]", id) + + return nil } func (a *ArtifactRegistry) RegisterTrigger(ctx context.Context, plan *admin.LaunchPlan) error { diff --git a/flyteadmin/pkg/manager/impl/launch_plan_manager.go b/flyteadmin/pkg/manager/impl/launch_plan_manager.go index e7f8ddb3d1..18418fb9fc 100644 --- a/flyteadmin/pkg/manager/impl/launch_plan_manager.go +++ b/flyteadmin/pkg/manager/impl/launch_plan_manager.go @@ -129,16 +129,20 @@ func (m *LaunchPlanManager) CreateLaunchPlan( } m.metrics.SpecSizeBytes.Observe(float64(len(launchPlanModel.Spec))) m.metrics.ClosureSizeBytes.Observe(float64(len(launchPlanModel.Closure))) + // TODO: Artifact feature gate, remove when ready - if m.artifactRegistry.GetClient() != nil { - go func() { - ceCtx := context.TODO() - if launchPlan.Spec.DefaultInputs == nil { - logger.Debugf(ceCtx, "Insufficient fields to submit launchplan interface %v", launchPlan.Id) - return + if m.config.ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts { + // As an optimization, run through the interface, and only send to artifacts service it there are artifacts + for _, param := range launchPlan.GetSpec().GetDefaultInputs().GetParameters() { + if param.GetArtifactId() != nil || param.GetArtifactQuery() != nil { + err = m.artifactRegistry.RegisterArtifactConsumer(ctx, launchPlan.Id, launchPlan.Spec.DefaultInputs) + if err != nil { + logger.Errorf(ctx, "failed RegisterArtifactConsumer for launch plan [%+v] with err: %v", launchPlan.Id, err) + return nil, err + } + break } - m.artifactRegistry.RegisterArtifactConsumer(ceCtx, launchPlan.Id, *launchPlan.Spec.DefaultInputs) - }() + } } return &admin.LaunchPlanCreateResponse{}, nil diff --git a/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go b/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go index 73608d034b..ccdc231f51 100644 --- a/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go +++ b/flyteadmin/pkg/manager/impl/launch_plan_manager_test.go @@ -340,6 +340,60 @@ func TestCreateLaunchPlanValidateCreate(t *testing.T) { assert.True(t, proto.Equal(expectedResponse, response)) } +func TestCreateLaunchPlan_ArtifactBehavior(t *testing.T) { + // Test that enabling artifacts feature flag will not call RegisterArtifactConsumer if no artifact queries present. + repository := getMockRepositoryForLpTest() + repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( + func(input interfaces.Identifier) (models.LaunchPlan, error) { + return models.LaunchPlan{}, errors.New("foo") + }) + client := artifactMocks.ArtifactRegistryClient{} + registry := artifacts.ArtifactRegistry{ + Client: &client, + } + + client.On("RegisterConsumer", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + req := args.Get(1).(*artifactsIdl.RegisterConsumerRequest) + id := req.Consumers[0].EntityId + assert.Equal(t, "project", id.Project) + assert.Equal(t, "domain", id.Domain) + assert.Equal(t, "name", id.Name) + assert.Equal(t, "version", id.Version) + }).Return(&artifactsIdl.RegisterResponse{}, nil) + + setDefaultWorkflowCallbackForLpTest(repository) + mockConfig := getMockConfigForLpTest() + mockConfig.(*runtimeMocks.MockConfigurationProvider).ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts = true + lpManager := NewLaunchPlanManager(repository, mockConfig, mockScheduler, mockScope.NewTestScope(), ®istry) + request := testutils.GetLaunchPlanRequest() + response, err := lpManager.CreateLaunchPlan(context.Background(), request) + assert.Nil(t, err) + + expectedResponse := &admin.LaunchPlanCreateResponse{} + assert.True(t, proto.Equal(expectedResponse, response)) + client.AssertNotCalled(t, "RegisterConsumer", mock.Anything, mock.Anything, mock.Anything) + + // If the launch plan interface has a query however, then the service should be called. + aq := &core.Parameter_ArtifactQuery{ + ArtifactQuery: &core.ArtifactQuery{ + Identifier: &core.ArtifactQuery_ArtifactId{ + ArtifactId: testutils.GetArtifactID(), + }, + }, + } + request.GetSpec().GetDefaultInputs().GetParameters()["foo"] = &core.Parameter{ + Var: &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + }, + Behavior: aq, + } + response, err = lpManager.CreateLaunchPlan(context.Background(), request) + assert.Nil(t, err) + expectedResponse = &admin.LaunchPlanCreateResponse{} + assert.True(t, proto.Equal(expectedResponse, response)) + client.AssertCalled(t, "RegisterConsumer", mock.Anything, mock.Anything, mock.Anything) +} + func TestCreateLaunchPlanNoWorkflowInterface(t *testing.T) { repository := getMockRepositoryForLpTest() repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( diff --git a/flyteadmin/pkg/manager/impl/task_manager.go b/flyteadmin/pkg/manager/impl/task_manager.go index 267d717a88..9222bf2974 100644 --- a/flyteadmin/pkg/manager/impl/task_manager.go +++ b/flyteadmin/pkg/manager/impl/task_manager.go @@ -6,7 +6,6 @@ import ( "strconv" "time" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc/codes" @@ -133,17 +132,23 @@ func (t *TaskManager) CreateTask( contextWithRuntimeMeta, common.RuntimeVersionKey, finalizedRequest.Spec.Template.Metadata.Runtime.Version) t.metrics.Registered.Inc(contextWithRuntimeMeta) } + // TODO: Artifact feature gate, remove when ready - if t.artifactRegistry.GetClient() != nil { - tIfaceCopy := proto.Clone(finalizedRequest.Spec.Template.Interface).(*core.TypedInterface) - go func() { - ceCtx := context.TODO() - if finalizedRequest.Spec.Template.Interface == nil { - logger.Debugf(ceCtx, "Task [%+v] has no interface, skipping registration", finalizedRequest.Id) - return + if t.config.ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts { + if finalizedRequest.GetSpec().GetTemplate().GetInterface().GetOutputs() != nil { + // As optimization, iterate over the outputs, if any, and only register interface with artifacts service + // if there are any artifact outputs. + for _, outputVar := range finalizedRequest.GetSpec().GetTemplate().GetInterface().GetOutputs().GetVariables() { + if outputVar.GetArtifactPartialId() != nil { + err = t.artifactRegistry.RegisterArtifactProducer(ctx, finalizedRequest.Id, finalizedRequest.GetSpec().GetTemplate().GetInterface()) + if err != nil { + logger.Errorf(ctx, "Failed RegisterArtifactProducer for task [%+v] with err: %v", request.Id, err) + return nil, err + } + break + } } - t.artifactRegistry.RegisterArtifactProducer(ceCtx, finalizedRequest.Id, *tIfaceCopy) - }() + } } return &admin.TaskCreateResponse{}, nil diff --git a/flyteadmin/pkg/manager/impl/task_manager_test.go b/flyteadmin/pkg/manager/impl/task_manager_test.go index 74f9c2f5ba..85a39b3c68 100644 --- a/flyteadmin/pkg/manager/impl/task_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_manager_test.go @@ -8,9 +8,11 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "google.golang.org/grpc/codes" "github.com/flyteorg/flyte/flyteadmin/pkg/artifacts" + artifactMocks "github.com/flyteorg/flyte/flyteadmin/pkg/artifacts/mocks" "github.com/flyteorg/flyte/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/errors" "github.com/flyteorg/flyte/flyteadmin/pkg/manager/impl/testutils" @@ -22,6 +24,7 @@ import ( workflowengine "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/interfaces" workflowMocks "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/mocks" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + artifactsIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifacts" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" mockScope "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/promutils/labeled" @@ -147,6 +150,64 @@ func TestCreateTask_DatabaseError(t *testing.T) { assert.Nil(t, response) } +func TestCreateTask_ArtifactBehavior(t *testing.T) { + // Test that tasks that don't produce artifacts do not call the artifacts service at registration + mockRepository := getMockTaskRepository() + mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback( + func(input interfaces.Identifier) (models.Task, error) { + return models.Task{}, errors.New("foo") + }) + client := artifactMocks.ArtifactRegistryClient{} + client.On("RegisterProducer", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + req := args.Get(1).(*artifactsIdl.RegisterProducerRequest) + id := req.Producers[0].EntityId + assert.Equal(t, "project", id.Project) + assert.Equal(t, "domain", id.Domain) + assert.Equal(t, "name", id.Name) + assert.Equal(t, "version", id.Version) + }).Return(&artifactsIdl.RegisterResponse{}, nil) + + registry := artifacts.ArtifactRegistry{ + Client: &client, + } + var createCalled bool + mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetCreateCallback(func(input models.Task, descriptionEntity *models.DescriptionEntity) error { + createCalled = true + return nil + }) + mockRepository.DescriptionEntityRepo().(*repositoryMocks.MockDescriptionEntityRepo).SetGetCallback( + func(input interfaces.GetDescriptionEntityInput) (models.DescriptionEntity, error) { + return models.DescriptionEntity{}, adminErrors.NewFlyteAdminErrorf(codes.NotFound, "NotFound") + }) + mockConfig := getMockConfigForTaskTest() + mockConfig.(*runtimeMocks.MockConfigurationProvider).ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts = true + taskManager := NewTaskManager(mockRepository, mockConfig, getMockTaskCompiler(), + mockScope.NewTestScope(), ®istry) + request := testutils.GetValidTaskRequest() + response, err := taskManager.CreateTask(context.Background(), request) + assert.NoError(t, err) + assert.Equal(t, &admin.TaskCreateResponse{}, response) + assert.True(t, createCalled) + client.AssertNotCalled(t, "RegisterProducer", mock.Anything, mock.Anything, mock.Anything) + + withArtifact := &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + ArtifactPartialId: testutils.GetArtifactID(), + } + ti := &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "bar": withArtifact, + }, + }, + } + request.Spec.Template.Interface = ti + response, err = taskManager.CreateTask(context.Background(), request) + client.AssertCalled(t, "RegisterProducer", mock.Anything, mock.Anything, mock.Anything) + assert.NoError(t, err) + assert.Equal(t, &admin.TaskCreateResponse{}, response) +} + func TestGetTask(t *testing.T) { repository := getMockTaskRepository() taskGetFunc := func(input interfaces.Identifier) (models.Task, error) { diff --git a/flyteadmin/pkg/manager/impl/testutils/artifacts.go b/flyteadmin/pkg/manager/impl/testutils/artifacts.go new file mode 100644 index 0000000000..fdf64d9ed7 --- /dev/null +++ b/flyteadmin/pkg/manager/impl/testutils/artifacts.go @@ -0,0 +1,14 @@ +package testutils + +import "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + +func GetArtifactID() *core.ArtifactID { + return &core.ArtifactID{ + ArtifactKey: &core.ArtifactKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Version: "artf_v", + } +} diff --git a/flyteadmin/pkg/manager/impl/workflow_manager.go b/flyteadmin/pkg/manager/impl/workflow_manager.go index a20aee28ba..846f10ed53 100644 --- a/flyteadmin/pkg/manager/impl/workflow_manager.go +++ b/flyteadmin/pkg/manager/impl/workflow_manager.go @@ -6,7 +6,6 @@ import ( "strconv" "time" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/prometheus/client_golang/prometheus" "github.com/samber/lo" @@ -227,18 +226,21 @@ func (w *WorkflowManager) CreateWorkflow( // Send the interface definition to Artifact service, this is so that it can statically pick up one dimension of // lineage information - tIfaceCopy := proto.Clone(workflowClosure.CompiledWorkflow.Primary.Template.Interface).(*core.TypedInterface) + // TODO: Artifact feature gate, remove when ready - if w.artifactRegistry.GetClient() != nil { - go func() { - ceCtx := context.TODO() - if workflowClosure.CompiledWorkflow == nil || workflowClosure.CompiledWorkflow.Primary == nil { - logger.Debugf(ceCtx, "Insufficient fields to submit workflow interface %v", finalizedRequest.Id) - return + if w.config.ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts { + if iFace := workflowClosure.GetCompiledWorkflow().GetPrimary().GetTemplate().GetInterface(); iFace != nil { + for _, outputVar := range iFace.GetOutputs().GetVariables() { + if outputVar.GetArtifactPartialId() != nil { + err = w.artifactRegistry.RegisterArtifactProducer(ctx, finalizedRequest.Id, iFace) + if err != nil { + logger.Errorf(ctx, "Failed RegisterArtifactProducer for workflow [%+v] with err: %v", request.Id, err) + return nil, err + } + break + } } - - w.artifactRegistry.RegisterArtifactProducer(ceCtx, finalizedRequest.Id, *tIfaceCopy) - }() + } } return &admin.WorkflowCreateResponse{}, nil diff --git a/flyteadmin/pkg/manager/impl/workflow_manager_test.go b/flyteadmin/pkg/manager/impl/workflow_manager_test.go index 06cdb40ef0..8c4b708e1b 100644 --- a/flyteadmin/pkg/manager/impl/workflow_manager_test.go +++ b/flyteadmin/pkg/manager/impl/workflow_manager_test.go @@ -8,10 +8,12 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/flyteorg/flyte/flyteadmin/pkg/artifacts" + artifactMocks "github.com/flyteorg/flyte/flyteadmin/pkg/artifacts/mocks" "github.com/flyteorg/flyte/flyteadmin/pkg/common" commonMocks "github.com/flyteorg/flyte/flyteadmin/pkg/common/mocks" adminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/errors" @@ -25,6 +27,7 @@ import ( workflowengineInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/interfaces" workflowengineMocks "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/mocks" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + artifactsIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifacts" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler" engine "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common" @@ -272,6 +275,51 @@ func TestCreateWorkflow_DatabaseError(t *testing.T) { assert.Nil(t, response) } +func TestCreateWorkflow_ArtifactBehavior(t *testing.T) { + // Test that workflows without artifacts do not call the artifacts service upon registration. + repository := getMockRepository(false) + workflowCreateFunc := func(input models.Workflow, descriptionEntity *models.DescriptionEntity) error { + return nil + } + client := artifactMocks.ArtifactRegistryClient{} + client.On("RegisterProducer", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + req := args.Get(1).(*artifactsIdl.RegisterProducerRequest) + id := req.Producers[0].EntityId + assert.Equal(t, "project", id.Project) + assert.Equal(t, "domain", id.Domain) + assert.Equal(t, "name", id.Name) + assert.Equal(t, "version", id.Version) + }).Return(&artifactsIdl.RegisterResponse{}, nil) + registry := artifacts.ArtifactRegistry{ + Client: &client, + } + + repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetCreateCallback(workflowCreateFunc) + mockConfig := getMockWorkflowConfigProvider() + mockConfig.(*runtimeMocks.MockConfigurationProvider).ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts = true + workflowManager := NewWorkflowManager( + repository, mockConfig, getMockWorkflowCompiler(), getMockStorage(), storagePrefix, + mockScope.NewTestScope(), ®istry) + + request := testutils.GetWorkflowRequest() + ctx := context.Background() + response, err := workflowManager.CreateWorkflow(ctx, request) + assert.NoError(t, err) + assert.NotNil(t, response) + client.AssertNotCalled(t, "RegisterProducer", mock.Anything, mock.Anything, mock.Anything) + + withArtifact := &core.Variable{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + ArtifactPartialId: testutils.GetArtifactID(), + } + request.Spec.Template.Interface.Outputs.Variables["bar"] = withArtifact + + // But workflows that do have artifacts do call the service + _, err = workflowManager.CreateWorkflow(ctx, request) + assert.NoError(t, err) + client.AssertCalled(t, "RegisterProducer", mock.Anything, mock.Anything) +} + func TestGetWorkflow(t *testing.T) { repository := repositoryMocks.NewMockRepository() workflowGetFunc := func(input interfaces.Identifier) (models.Workflow, error) {