Skip to content

Commit

Permalink
update registration of interfaces (#166)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Mar 19, 2024
1 parent aaa2519 commit af5f4ce
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 34 deletions.
25 changes: 20 additions & 5 deletions flyteadmin/pkg/artifacts/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ type ArtifactRegistry struct {
Client artifacts.ArtifactRegistryClient
}

func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *core.Identifier, ti core.TypedInterface) {
func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *core.Identifier, ti *core.TypedInterface) error {
if a == nil || a.Client == nil {
logger.Debugf(ctx, "Artifact Client not configured, skipping registration for task [%+v]", id)
return
return nil
}
if ti.GetOutputs() == nil {
logger.Infof(ctx, "No outputs for [%+v], skipping registration", id)
return nil
}

ap := &artifacts.ArtifactProducer{
Expand All @@ -36,26 +40,37 @@ func (a *ArtifactRegistry) RegisterArtifactProducer(ctx context.Context, id *cor
})
if err != nil {
logger.Errorf(ctx, "Failed to register artifact producer for task [%+v] with err: %v", id, err)
return err
}
logger.Debugf(ctx, "Registered artifact producer [%+v]", id)

return nil
}

func (a *ArtifactRegistry) RegisterArtifactConsumer(ctx context.Context, id *core.Identifier, pm core.ParameterMap) {
func (a *ArtifactRegistry) RegisterArtifactConsumer(ctx context.Context, id *core.Identifier, pm *core.ParameterMap) error {
if a == nil || a.Client == nil {
logger.Debugf(ctx, "Artifact Client not configured, skipping registration for consumer [%+v]", id)
return
return nil
}
if pm.GetParameters() == nil {
logger.Infof(ctx, "No inputs for [%+v], skipping registration", id)
return nil
}

ac := &artifacts.ArtifactConsumer{
EntityId: id,
Inputs: &pm,
Inputs: pm,
}
_, err := a.Client.RegisterConsumer(ctx, &artifacts.RegisterConsumerRequest{
Consumers: []*artifacts.ArtifactConsumer{ac},
})
if err != nil {
logger.Errorf(ctx, "Failed to register artifact consumer for entity [%+v] with err: %v", id, err)
return err
}
logger.Debugf(ctx, "Registered artifact consumer [%+v]", id)

return nil
}

func (a *ArtifactRegistry) RegisterTrigger(ctx context.Context, plan *admin.LaunchPlan) error {
Expand Down
20 changes: 12 additions & 8 deletions flyteadmin/pkg/manager/impl/launch_plan_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,20 @@ func (m *LaunchPlanManager) CreateLaunchPlan(
}
m.metrics.SpecSizeBytes.Observe(float64(len(launchPlanModel.Spec)))
m.metrics.ClosureSizeBytes.Observe(float64(len(launchPlanModel.Closure)))

// TODO: Artifact feature gate, remove when ready
if m.artifactRegistry.GetClient() != nil {
go func() {
ceCtx := context.TODO()
if launchPlan.Spec.DefaultInputs == nil {
logger.Debugf(ceCtx, "Insufficient fields to submit launchplan interface %v", launchPlan.Id)
return
if m.config.ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts {
// As an optimization, run through the interface, and only send to artifacts service it there are artifacts
for _, param := range launchPlan.GetSpec().GetDefaultInputs().GetParameters() {
if param.GetArtifactId() != nil || param.GetArtifactQuery() != nil {
err = m.artifactRegistry.RegisterArtifactConsumer(ctx, launchPlan.Id, launchPlan.Spec.DefaultInputs)
if err != nil {
logger.Errorf(ctx, "failed RegisterArtifactConsumer for launch plan [%+v] with err: %v", launchPlan.Id, err)
return nil, err
}
break
}
m.artifactRegistry.RegisterArtifactConsumer(ceCtx, launchPlan.Id, *launchPlan.Spec.DefaultInputs)
}()
}
}

return &admin.LaunchPlanCreateResponse{}, nil
Expand Down
54 changes: 54 additions & 0 deletions flyteadmin/pkg/manager/impl/launch_plan_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,60 @@ func TestCreateLaunchPlanValidateCreate(t *testing.T) {
assert.True(t, proto.Equal(expectedResponse, response))
}

func TestCreateLaunchPlan_ArtifactBehavior(t *testing.T) {
// Test that enabling artifacts feature flag will not call RegisterArtifactConsumer if no artifact queries present.
repository := getMockRepositoryForLpTest()
repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(
func(input interfaces.Identifier) (models.LaunchPlan, error) {
return models.LaunchPlan{}, errors.New("foo")
})
client := artifactMocks.ArtifactRegistryClient{}
registry := artifacts.ArtifactRegistry{
Client: &client,
}

client.On("RegisterConsumer", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
req := args.Get(1).(*artifactsIdl.RegisterConsumerRequest)
id := req.Consumers[0].EntityId
assert.Equal(t, "project", id.Project)
assert.Equal(t, "domain", id.Domain)
assert.Equal(t, "name", id.Name)
assert.Equal(t, "version", id.Version)
}).Return(&artifactsIdl.RegisterResponse{}, nil)

setDefaultWorkflowCallbackForLpTest(repository)
mockConfig := getMockConfigForLpTest()
mockConfig.(*runtimeMocks.MockConfigurationProvider).ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts = true
lpManager := NewLaunchPlanManager(repository, mockConfig, mockScheduler, mockScope.NewTestScope(), &registry)
request := testutils.GetLaunchPlanRequest()
response, err := lpManager.CreateLaunchPlan(context.Background(), request)
assert.Nil(t, err)

expectedResponse := &admin.LaunchPlanCreateResponse{}
assert.True(t, proto.Equal(expectedResponse, response))
client.AssertNotCalled(t, "RegisterConsumer", mock.Anything, mock.Anything, mock.Anything)

// If the launch plan interface has a query however, then the service should be called.
aq := &core.Parameter_ArtifactQuery{
ArtifactQuery: &core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: testutils.GetArtifactID(),
},
},
}
request.GetSpec().GetDefaultInputs().GetParameters()["foo"] = &core.Parameter{
Var: &core.Variable{
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}},
},
Behavior: aq,
}
response, err = lpManager.CreateLaunchPlan(context.Background(), request)
assert.Nil(t, err)
expectedResponse = &admin.LaunchPlanCreateResponse{}
assert.True(t, proto.Equal(expectedResponse, response))
client.AssertCalled(t, "RegisterConsumer", mock.Anything, mock.Anything, mock.Anything)
}

func TestCreateLaunchPlanNoWorkflowInterface(t *testing.T) {
repository := getMockRepositoryForLpTest()
repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback(
Expand Down
25 changes: 15 additions & 10 deletions flyteadmin/pkg/manager/impl/task_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strconv"
"time"

"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -133,17 +132,23 @@ func (t *TaskManager) CreateTask(
contextWithRuntimeMeta, common.RuntimeVersionKey, finalizedRequest.Spec.Template.Metadata.Runtime.Version)
t.metrics.Registered.Inc(contextWithRuntimeMeta)
}

// TODO: Artifact feature gate, remove when ready
if t.artifactRegistry.GetClient() != nil {
tIfaceCopy := proto.Clone(finalizedRequest.Spec.Template.Interface).(*core.TypedInterface)
go func() {
ceCtx := context.TODO()
if finalizedRequest.Spec.Template.Interface == nil {
logger.Debugf(ceCtx, "Task [%+v] has no interface, skipping registration", finalizedRequest.Id)
return
if t.config.ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts {
if finalizedRequest.GetSpec().GetTemplate().GetInterface().GetOutputs() != nil {
// As optimization, iterate over the outputs, if any, and only register interface with artifacts service
// if there are any artifact outputs.
for _, outputVar := range finalizedRequest.GetSpec().GetTemplate().GetInterface().GetOutputs().GetVariables() {
if outputVar.GetArtifactPartialId() != nil {
err = t.artifactRegistry.RegisterArtifactProducer(ctx, finalizedRequest.Id, finalizedRequest.GetSpec().GetTemplate().GetInterface())
if err != nil {
logger.Errorf(ctx, "Failed RegisterArtifactProducer for task [%+v] with err: %v", request.Id, err)
return nil, err
}
break
}
}
t.artifactRegistry.RegisterArtifactProducer(ceCtx, finalizedRequest.Id, *tIfaceCopy)
}()
}
}

return &admin.TaskCreateResponse{}, nil
Expand Down
61 changes: 61 additions & 0 deletions flyteadmin/pkg/manager/impl/task_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"

"github.com/flyteorg/flyte/flyteadmin/pkg/artifacts"
artifactMocks "github.com/flyteorg/flyte/flyteadmin/pkg/artifacts/mocks"
"github.com/flyteorg/flyte/flyteadmin/pkg/common"
adminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/errors"
"github.com/flyteorg/flyte/flyteadmin/pkg/manager/impl/testutils"
Expand All @@ -22,6 +24,7 @@ import (
workflowengine "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/interfaces"
workflowMocks "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/mocks"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
artifactsIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifacts"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
mockScope "github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/promutils/labeled"
Expand Down Expand Up @@ -147,6 +150,64 @@ func TestCreateTask_DatabaseError(t *testing.T) {
assert.Nil(t, response)
}

func TestCreateTask_ArtifactBehavior(t *testing.T) {
// Test that tasks that don't produce artifacts do not call the artifacts service at registration
mockRepository := getMockTaskRepository()
mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback(
func(input interfaces.Identifier) (models.Task, error) {
return models.Task{}, errors.New("foo")
})
client := artifactMocks.ArtifactRegistryClient{}
client.On("RegisterProducer", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
req := args.Get(1).(*artifactsIdl.RegisterProducerRequest)
id := req.Producers[0].EntityId
assert.Equal(t, "project", id.Project)
assert.Equal(t, "domain", id.Domain)
assert.Equal(t, "name", id.Name)
assert.Equal(t, "version", id.Version)
}).Return(&artifactsIdl.RegisterResponse{}, nil)

registry := artifacts.ArtifactRegistry{
Client: &client,
}
var createCalled bool
mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetCreateCallback(func(input models.Task, descriptionEntity *models.DescriptionEntity) error {
createCalled = true
return nil
})
mockRepository.DescriptionEntityRepo().(*repositoryMocks.MockDescriptionEntityRepo).SetGetCallback(
func(input interfaces.GetDescriptionEntityInput) (models.DescriptionEntity, error) {
return models.DescriptionEntity{}, adminErrors.NewFlyteAdminErrorf(codes.NotFound, "NotFound")
})
mockConfig := getMockConfigForTaskTest()
mockConfig.(*runtimeMocks.MockConfigurationProvider).ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts = true
taskManager := NewTaskManager(mockRepository, mockConfig, getMockTaskCompiler(),
mockScope.NewTestScope(), &registry)
request := testutils.GetValidTaskRequest()
response, err := taskManager.CreateTask(context.Background(), request)
assert.NoError(t, err)
assert.Equal(t, &admin.TaskCreateResponse{}, response)
assert.True(t, createCalled)
client.AssertNotCalled(t, "RegisterProducer", mock.Anything, mock.Anything, mock.Anything)

withArtifact := &core.Variable{
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}},
ArtifactPartialId: testutils.GetArtifactID(),
}
ti := &core.TypedInterface{
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"bar": withArtifact,
},
},
}
request.Spec.Template.Interface = ti
response, err = taskManager.CreateTask(context.Background(), request)
client.AssertCalled(t, "RegisterProducer", mock.Anything, mock.Anything, mock.Anything)
assert.NoError(t, err)
assert.Equal(t, &admin.TaskCreateResponse{}, response)
}

func TestGetTask(t *testing.T) {
repository := getMockTaskRepository()
taskGetFunc := func(input interfaces.Identifier) (models.Task, error) {
Expand Down
14 changes: 14 additions & 0 deletions flyteadmin/pkg/manager/impl/testutils/artifacts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package testutils

import "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"

func GetArtifactID() *core.ArtifactID {
return &core.ArtifactID{
ArtifactKey: &core.ArtifactKey{
Project: "project",
Domain: "domain",
Name: "name",
},
Version: "artf_v",
}
}
24 changes: 13 additions & 11 deletions flyteadmin/pkg/manager/impl/workflow_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strconv"
"time"

"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/prometheus/client_golang/prometheus"
"github.com/samber/lo"
Expand Down Expand Up @@ -227,18 +226,21 @@ func (w *WorkflowManager) CreateWorkflow(

// Send the interface definition to Artifact service, this is so that it can statically pick up one dimension of
// lineage information
tIfaceCopy := proto.Clone(workflowClosure.CompiledWorkflow.Primary.Template.Interface).(*core.TypedInterface)

// TODO: Artifact feature gate, remove when ready
if w.artifactRegistry.GetClient() != nil {
go func() {
ceCtx := context.TODO()
if workflowClosure.CompiledWorkflow == nil || workflowClosure.CompiledWorkflow.Primary == nil {
logger.Debugf(ceCtx, "Insufficient fields to submit workflow interface %v", finalizedRequest.Id)
return
if w.config.ApplicationConfiguration().GetTopLevelConfig().FeatureGates.EnableArtifacts {
if iFace := workflowClosure.GetCompiledWorkflow().GetPrimary().GetTemplate().GetInterface(); iFace != nil {
for _, outputVar := range iFace.GetOutputs().GetVariables() {
if outputVar.GetArtifactPartialId() != nil {
err = w.artifactRegistry.RegisterArtifactProducer(ctx, finalizedRequest.Id, iFace)
if err != nil {
logger.Errorf(ctx, "Failed RegisterArtifactProducer for workflow [%+v] with err: %v", request.Id, err)
return nil, err
}
break
}
}

w.artifactRegistry.RegisterArtifactProducer(ceCtx, finalizedRequest.Id, *tIfaceCopy)
}()
}
}

return &admin.WorkflowCreateResponse{}, nil
Expand Down
Loading

0 comments on commit af5f4ce

Please sign in to comment.