diff --git a/fasttask/plugin/builder_test.go b/fasttask/plugin/builder_test.go new file mode 100644 index 00000000000..f198906a104 --- /dev/null +++ b/fasttask/plugin/builder_test.go @@ -0,0 +1,490 @@ +package plugin + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/golang/protobuf/proto" + _struct "github.com/golang/protobuf/ptypes/struct" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + + coremocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + + "github.com/unionai/flyte/fasttask/plugin/pb" +) + +type kubeClient struct { + client.Client + + createCalls int + deleteCalls int +} + +func (k *kubeClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + k.createCalls++ + return nil +} + +func (k *kubeClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + k.deleteCalls++ + return nil +} + +type kubeCache struct { + cache.Cache + + pods []v1.Pod +} + +func (k *kubeCache) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + for _, pod := range k.pods { + if pod.Name == key.Name { + return nil + } + } + + return errors.NewNotFound(v1.Resource("pods"), key.Name) +} + +func (k *kubeCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + if podList, ok := list.(*v1.PodList); ok { + podList.Items = k.pods + } + return nil +} + +func TestCreate(t *testing.T) { + fastTaskEnvironment := &pb.FastTaskEnvironment{ + QueueId: "foo", + } + + fastTaskEnvStruct := &_struct.Struct{} + err := utils.MarshalStruct(fastTaskEnvironment, fastTaskEnvStruct) + assert.Nil(t, err) + + podTemplateSpec := 
&v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "primary", + }, + }, + }, + } + podTemplateSpecBytes, err := json.Marshal(podTemplateSpec) + assert.Nil(t, err) + + fastTaskEnvSpec := &pb.FastTaskEnvironmentSpec{ + Parallelism: 1, + PrimaryContainerName: "primary", + PodTemplateSpec: podTemplateSpecBytes, + ReplicaCount: 2, + TerminationCriteria: &pb.FastTaskEnvironmentSpec_TtlSeconds{ + TtlSeconds: 300, + }, + } + + ctx := context.TODO() + tests := []struct { + name string + environmentSpec *pb.FastTaskEnvironmentSpec + environments map[string]*environment + expectedEnvironment *_struct.Struct + expectedCreateCalls int + }{ + { + name: "Success", + environmentSpec: fastTaskEnvSpec, + environments: map[string]*environment{}, + expectedEnvironment: fastTaskEnvStruct, + expectedCreateCalls: 2, + }, + { + name: "Exists", + environmentSpec: fastTaskEnvSpec, + environments: map[string]*environment{ + "foo": &environment{ + extant: fastTaskEnvStruct, + state: HEALTHY, + }, + }, + expectedEnvironment: fastTaskEnvStruct, + expectedCreateCalls: 0, + }, + { + name: "Orphaned", + environmentSpec: fastTaskEnvSpec, + environments: map[string]*environment{ + "foo": &environment{ + extant: fastTaskEnvStruct, + replicas: []string{"bar"}, + state: ORPHANED, + }, + }, + expectedEnvironment: fastTaskEnvStruct, + expectedCreateCalls: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fastTaskEnvSpecStruct := &_struct.Struct{} + err := utils.MarshalStruct(test.environmentSpec, fastTaskEnvSpecStruct) + assert.Nil(t, err) + + // initialize InMemoryBuilder + kubeClient := &kubeClient{} + kubeCache := &kubeCache{} + + kubeClientImpl := &coremocks.KubeClient{} + kubeClientImpl.OnGetClient().Return(kubeClient) + kubeClientImpl.OnGetCache().Return(kubeCache) + + builder := NewEnvironmentBuilder(kubeClientImpl) + builder.environments = test.environments + + // call `Create` + environment, err := 
builder.Create(ctx, "foo", fastTaskEnvSpecStruct) + assert.Nil(t, err) + assert.True(t, proto.Equal(test.expectedEnvironment, environment)) + assert.Equal(t, test.expectedCreateCalls, kubeClient.createCalls) + }) + } +} + +func TestDetectOrphanedEnvironments(t *testing.T) { + pods := []v1.Pod{ + v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "bar", + Labels: map[string]string{ + EXECUTION_ENV_ID: "foo", + }, + Annotations: map[string]string{ + TTL_SECONDS: "60", + }, + }, + }, + v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "baz", + Labels: map[string]string{ + EXECUTION_ENV_ID: "foo", + }, + Annotations: map[string]string{ + TTL_SECONDS: "60", + }, + }, + }, + } + + ctx := context.TODO() + tests := []struct { + name string + environments map[string]*environment + expectedEnvironmentCount int + expectedReplicaCount int + }{ + { + name: "Noop", + environments: map[string]*environment{ + "foo": &environment{ + replicas: []string{"bar", "baz"}, + state: HEALTHY, + }, + }, + expectedEnvironmentCount: 1, + expectedReplicaCount: 2, + }, + { + name: "CreateOrphanedEnvironment", + environments: map[string]*environment{}, + expectedEnvironmentCount: 1, + expectedReplicaCount: 2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize InMemoryBuilder + kubeClient := &kubeClient{} + kubeCache := &kubeCache{ + pods: pods, + } + + kubeClientImpl := &coremocks.KubeClient{} + kubeClientImpl.OnGetClient().Return(kubeClient) + kubeClientImpl.OnGetCache().Return(kubeCache) + + builder := NewEnvironmentBuilder(kubeClientImpl) + builder.environments = test.environments + + // call `Create` + err := builder.detectOrphanedEnvironments(ctx, kubeCache) + assert.Nil(t, err) + assert.Equal(t, test.expectedEnvironmentCount, len(builder.environments)) + totalReplicas := 0 + for _, environment := range builder.environments { + totalReplicas += len(environment.replicas) + } + assert.Equal(t, test.expectedReplicaCount, totalReplicas) + }) + } +} + 
+func TestGCEnvironments(t *testing.T) { + podTemplateSpec := &v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "primary", + }, + }, + }, + } + podTemplateSpecBytes, err := json.Marshal(podTemplateSpec) + assert.Nil(t, err) + + fastTaskEnvSpec := &pb.FastTaskEnvironmentSpec{ + PodTemplateSpec: podTemplateSpecBytes, + TerminationCriteria: &pb.FastTaskEnvironmentSpec_TtlSeconds{ + TtlSeconds: 300, + }, + } + + ctx := context.TODO() + tests := []struct { + name string + environments map[string]*environment + expectedDeleteCalls int + expectedEnvironmentCount int + }{ + { + name: "Noop", + environments: map[string]*environment{ + "foo": &environment{ + lastAccessedAt: time.Now(), + replicas: []string{"bar", "baz"}, + spec: fastTaskEnvSpec, + state: HEALTHY, + }, + }, + expectedDeleteCalls: 0, + expectedEnvironmentCount: 1, + }, + { + name: "TimeoutTtl", + environments: map[string]*environment{ + "foo": &environment{ + lastAccessedAt: time.Now().Add(-301 * time.Second), + replicas: []string{"bar", "baz"}, + spec: fastTaskEnvSpec, + state: TOMBSTONED, + }, + }, + expectedDeleteCalls: 2, + expectedEnvironmentCount: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize InMemoryBuilder + kubeClient := &kubeClient{} + kubeCache := &kubeCache{} + + kubeClientImpl := &coremocks.KubeClient{} + kubeClientImpl.OnGetClient().Return(kubeClient) + kubeClientImpl.OnGetCache().Return(kubeCache) + + builder := NewEnvironmentBuilder(kubeClientImpl) + builder.environments = test.environments + + // call `Create` + err := builder.gcEnvironments(ctx) + assert.Nil(t, err) + assert.Equal(t, test.expectedDeleteCalls, kubeClient.deleteCalls) + assert.Equal(t, test.expectedEnvironmentCount, len(builder.environments)) + }) + } +} + +func TestGet(t *testing.T) { + fastTaskEnvironment := &pb.FastTaskEnvironment{ + QueueId: "foo", + } + + fastTaskEnvStruct := &_struct.Struct{} + err := 
utils.MarshalStruct(fastTaskEnvironment, fastTaskEnvStruct) + assert.Nil(t, err) + + ctx := context.TODO() + tests := []struct { + name string + environments map[string]*environment + expectedEnvironment *_struct.Struct + }{ + { + name: "Exists", + environments: map[string]*environment{ + "foo": &environment{ + extant: fastTaskEnvStruct, + state: HEALTHY, + }, + }, + expectedEnvironment: fastTaskEnvStruct, + }, + { + name: "DoesNotExist", + environments: map[string]*environment{}, + expectedEnvironment: nil, + }, + { + name: "Tombstoned", + environments: map[string]*environment{ + "foo": &environment{ + state: TOMBSTONED, + }, + }, + expectedEnvironment: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize InMemoryBuilder + kubeClient := &kubeClient{} + kubeCache := &kubeCache{} + + kubeClientImpl := &coremocks.KubeClient{} + kubeClientImpl.OnGetClient().Return(kubeClient) + kubeClientImpl.OnGetCache().Return(kubeCache) + + builder := NewEnvironmentBuilder(kubeClientImpl) + builder.environments = test.environments + + // call `Get` + environment := builder.Get(ctx, "foo") + assert.True(t, proto.Equal(test.expectedEnvironment, environment)) + }) + } +} + +func TestRepairEnvironments(t *testing.T) { + fastTaskEnvironment := &pb.FastTaskEnvironment{ + QueueId: "foo", + } + + fastTaskEnvStruct := &_struct.Struct{} + err := utils.MarshalStruct(fastTaskEnvironment, fastTaskEnvStruct) + assert.Nil(t, err) + + podTemplateSpec := &v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "primary", + }, + }, + }, + } + podTemplateSpecBytes, err := json.Marshal(podTemplateSpec) + assert.Nil(t, err) + + fastTaskEnvSpec := &pb.FastTaskEnvironmentSpec{ + Parallelism: 1, + PrimaryContainerName: "primary", + PodTemplateSpec: podTemplateSpecBytes, + ReplicaCount: 2, + TerminationCriteria: &pb.FastTaskEnvironmentSpec_TtlSeconds{ + TtlSeconds: 300, + }, + } + + ctx := context.TODO() + tests := 
[]struct { + name string + environments map[string]*environment + existingPods []v1.Pod + expectedCreateCalls int + }{ + { + name: "Noop", + environments: map[string]*environment{ + "foo": &environment{ + extant: fastTaskEnvStruct, + replicas: []string{"bar", "baz"}, + spec: fastTaskEnvSpec, + state: HEALTHY, + }, + }, + existingPods: []v1.Pod{ + v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "bar", + }, + }, + v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "baz", + }, + }, + }, + expectedCreateCalls: 0, + }, + { + name: "RepairMissingPod", + environments: map[string]*environment{ + "foo": &environment{ + extant: fastTaskEnvStruct, + replicas: []string{"bar", "baz"}, + spec: fastTaskEnvSpec, + state: REPAIRING, + }, + }, + existingPods: []v1.Pod{ + v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "bar", + }, + }, + }, + expectedCreateCalls: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize InMemoryBuilder + kubeClient := &kubeClient{} + kubeCache := &kubeCache{ + pods: test.existingPods, + } + + kubeClientImpl := &coremocks.KubeClient{} + kubeClientImpl.OnGetClient().Return(kubeClient) + kubeClientImpl.OnGetCache().Return(kubeCache) + + builder := NewEnvironmentBuilder(kubeClientImpl) + builder.environments = test.environments + + // call `Create` + err := builder.repairEnvironments(ctx) + assert.Nil(t, err) + assert.Equal(t, test.expectedCreateCalls, kubeClient.createCalls) + + // verify all environments are now healthy + for _, environment := range builder.environments { + assert.Equal(t, HEALTHY, environment.state) + } + }) + } +} diff --git a/fasttask/plugin/go.mod b/fasttask/plugin/go.mod index eac8232734b..e65c8300095 100644 --- a/fasttask/plugin/go.mod +++ b/fasttask/plugin/go.mod @@ -85,6 +85,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cobra v1.7.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect go.opencensus.io 
v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect diff --git a/fasttask/plugin/mocks/fast_task__heartbeat_client.go b/fasttask/plugin/mocks/fast_task__heartbeat_client.go new file mode 100644 index 00000000000..594ffea539c --- /dev/null +++ b/fasttask/plugin/mocks/fast_task__heartbeat_client.go @@ -0,0 +1,295 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + metadata "google.golang.org/grpc/metadata" + + pb "github.com/unionai/flyte/fasttask/plugin/pb" +) + +// FastTask_HeartbeatClient is an autogenerated mock type for the FastTask_HeartbeatClient type +type FastTask_HeartbeatClient struct { + mock.Mock +} + +type FastTask_HeartbeatClient_CloseSend struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_CloseSend) Return(_a0 error) *FastTask_HeartbeatClient_CloseSend { + return &FastTask_HeartbeatClient_CloseSend{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatClient) OnCloseSend() *FastTask_HeartbeatClient_CloseSend { + c_call := _m.On("CloseSend") + return &FastTask_HeartbeatClient_CloseSend{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnCloseSendMatch(matchers ...interface{}) *FastTask_HeartbeatClient_CloseSend { + c_call := _m.On("CloseSend", matchers...) 
+ return &FastTask_HeartbeatClient_CloseSend{Call: c_call} +} + +// CloseSend provides a mock function with given fields: +func (_m *FastTask_HeartbeatClient) CloseSend() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatClient_Context struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_Context) Return(_a0 context.Context) *FastTask_HeartbeatClient_Context { + return &FastTask_HeartbeatClient_Context{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatClient) OnContext() *FastTask_HeartbeatClient_Context { + c_call := _m.On("Context") + return &FastTask_HeartbeatClient_Context{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnContextMatch(matchers ...interface{}) *FastTask_HeartbeatClient_Context { + c_call := _m.On("Context", matchers...) + return &FastTask_HeartbeatClient_Context{Call: c_call} +} + +// Context provides a mock function with given fields: +func (_m *FastTask_HeartbeatClient) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +type FastTask_HeartbeatClient_Header struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_Header) Return(_a0 metadata.MD, _a1 error) *FastTask_HeartbeatClient_Header { + return &FastTask_HeartbeatClient_Header{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *FastTask_HeartbeatClient) OnHeader() *FastTask_HeartbeatClient_Header { + c_call := _m.On("Header") + return &FastTask_HeartbeatClient_Header{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnHeaderMatch(matchers ...interface{}) *FastTask_HeartbeatClient_Header { + c_call := _m.On("Header", matchers...) 
+ return &FastTask_HeartbeatClient_Header{Call: c_call} +} + +// Header provides a mock function with given fields: +func (_m *FastTask_HeartbeatClient) Header() (metadata.MD, error) { + ret := _m.Called() + + var r0 metadata.MD + if rf, ok := ret.Get(0).(func() metadata.MD); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(metadata.MD) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type FastTask_HeartbeatClient_Recv struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_Recv) Return(_a0 *pb.HeartbeatResponse, _a1 error) *FastTask_HeartbeatClient_Recv { + return &FastTask_HeartbeatClient_Recv{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *FastTask_HeartbeatClient) OnRecv() *FastTask_HeartbeatClient_Recv { + c_call := _m.On("Recv") + return &FastTask_HeartbeatClient_Recv{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnRecvMatch(matchers ...interface{}) *FastTask_HeartbeatClient_Recv { + c_call := _m.On("Recv", matchers...) 
+ return &FastTask_HeartbeatClient_Recv{Call: c_call} +} + +// Recv provides a mock function with given fields: +func (_m *FastTask_HeartbeatClient) Recv() (*pb.HeartbeatResponse, error) { + ret := _m.Called() + + var r0 *pb.HeartbeatResponse + if rf, ok := ret.Get(0).(func() *pb.HeartbeatResponse); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*pb.HeartbeatResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type FastTask_HeartbeatClient_RecvMsg struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_RecvMsg) Return(_a0 error) *FastTask_HeartbeatClient_RecvMsg { + return &FastTask_HeartbeatClient_RecvMsg{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatClient) OnRecvMsg(m interface{}) *FastTask_HeartbeatClient_RecvMsg { + c_call := _m.On("RecvMsg", m) + return &FastTask_HeartbeatClient_RecvMsg{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnRecvMsgMatch(matchers ...interface{}) *FastTask_HeartbeatClient_RecvMsg { + c_call := _m.On("RecvMsg", matchers...) 
+ return &FastTask_HeartbeatClient_RecvMsg{Call: c_call} +} + +// RecvMsg provides a mock function with given fields: m +func (_m *FastTask_HeartbeatClient) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatClient_Send struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_Send) Return(_a0 error) *FastTask_HeartbeatClient_Send { + return &FastTask_HeartbeatClient_Send{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatClient) OnSend(_a0 *pb.HeartbeatRequest) *FastTask_HeartbeatClient_Send { + c_call := _m.On("Send", _a0) + return &FastTask_HeartbeatClient_Send{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnSendMatch(matchers ...interface{}) *FastTask_HeartbeatClient_Send { + c_call := _m.On("Send", matchers...) + return &FastTask_HeartbeatClient_Send{Call: c_call} +} + +// Send provides a mock function with given fields: _a0 +func (_m *FastTask_HeartbeatClient) Send(_a0 *pb.HeartbeatRequest) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*pb.HeartbeatRequest) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatClient_SendMsg struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_SendMsg) Return(_a0 error) *FastTask_HeartbeatClient_SendMsg { + return &FastTask_HeartbeatClient_SendMsg{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatClient) OnSendMsg(m interface{}) *FastTask_HeartbeatClient_SendMsg { + c_call := _m.On("SendMsg", m) + return &FastTask_HeartbeatClient_SendMsg{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnSendMsgMatch(matchers ...interface{}) *FastTask_HeartbeatClient_SendMsg { + c_call := _m.On("SendMsg", matchers...) 
+ return &FastTask_HeartbeatClient_SendMsg{Call: c_call} +} + +// SendMsg provides a mock function with given fields: m +func (_m *FastTask_HeartbeatClient) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatClient_Trailer struct { + *mock.Call +} + +func (_m FastTask_HeartbeatClient_Trailer) Return(_a0 metadata.MD) *FastTask_HeartbeatClient_Trailer { + return &FastTask_HeartbeatClient_Trailer{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatClient) OnTrailer() *FastTask_HeartbeatClient_Trailer { + c_call := _m.On("Trailer") + return &FastTask_HeartbeatClient_Trailer{Call: c_call} +} + +func (_m *FastTask_HeartbeatClient) OnTrailerMatch(matchers ...interface{}) *FastTask_HeartbeatClient_Trailer { + c_call := _m.On("Trailer", matchers...) + return &FastTask_HeartbeatClient_Trailer{Call: c_call} +} + +// Trailer provides a mock function with given fields: +func (_m *FastTask_HeartbeatClient) Trailer() metadata.MD { + ret := _m.Called() + + var r0 metadata.MD + if rf, ok := ret.Get(0).(func() metadata.MD); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(metadata.MD) + } + } + + return r0 +} diff --git a/fasttask/plugin/mocks/fast_task__heartbeat_server.go b/fasttask/plugin/mocks/fast_task__heartbeat_server.go new file mode 100644 index 00000000000..8796ab23427 --- /dev/null +++ b/fasttask/plugin/mocks/fast_task__heartbeat_server.go @@ -0,0 +1,257 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + metadata "google.golang.org/grpc/metadata" + + pb "github.com/unionai/flyte/fasttask/plugin/pb" +) + +// FastTask_HeartbeatServer is an autogenerated mock type for the FastTask_HeartbeatServer type +type FastTask_HeartbeatServer struct { + mock.Mock +} + +type FastTask_HeartbeatServer_Context struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_Context) Return(_a0 context.Context) *FastTask_HeartbeatServer_Context { + return &FastTask_HeartbeatServer_Context{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatServer) OnContext() *FastTask_HeartbeatServer_Context { + c_call := _m.On("Context") + return &FastTask_HeartbeatServer_Context{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnContextMatch(matchers ...interface{}) *FastTask_HeartbeatServer_Context { + c_call := _m.On("Context", matchers...) + return &FastTask_HeartbeatServer_Context{Call: c_call} +} + +// Context provides a mock function with given fields: +func (_m *FastTask_HeartbeatServer) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +type FastTask_HeartbeatServer_Recv struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_Recv) Return(_a0 *pb.HeartbeatRequest, _a1 error) *FastTask_HeartbeatServer_Recv { + return &FastTask_HeartbeatServer_Recv{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *FastTask_HeartbeatServer) OnRecv() *FastTask_HeartbeatServer_Recv { + c_call := _m.On("Recv") + return &FastTask_HeartbeatServer_Recv{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnRecvMatch(matchers ...interface{}) *FastTask_HeartbeatServer_Recv { + c_call := _m.On("Recv", matchers...) 
+ return &FastTask_HeartbeatServer_Recv{Call: c_call} +} + +// Recv provides a mock function with given fields: +func (_m *FastTask_HeartbeatServer) Recv() (*pb.HeartbeatRequest, error) { + ret := _m.Called() + + var r0 *pb.HeartbeatRequest + if rf, ok := ret.Get(0).(func() *pb.HeartbeatRequest); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*pb.HeartbeatRequest) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type FastTask_HeartbeatServer_RecvMsg struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_RecvMsg) Return(_a0 error) *FastTask_HeartbeatServer_RecvMsg { + return &FastTask_HeartbeatServer_RecvMsg{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatServer) OnRecvMsg(m interface{}) *FastTask_HeartbeatServer_RecvMsg { + c_call := _m.On("RecvMsg", m) + return &FastTask_HeartbeatServer_RecvMsg{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnRecvMsgMatch(matchers ...interface{}) *FastTask_HeartbeatServer_RecvMsg { + c_call := _m.On("RecvMsg", matchers...) 
+ return &FastTask_HeartbeatServer_RecvMsg{Call: c_call} +} + +// RecvMsg provides a mock function with given fields: m +func (_m *FastTask_HeartbeatServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatServer_Send struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_Send) Return(_a0 error) *FastTask_HeartbeatServer_Send { + return &FastTask_HeartbeatServer_Send{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatServer) OnSend(_a0 *pb.HeartbeatResponse) *FastTask_HeartbeatServer_Send { + c_call := _m.On("Send", _a0) + return &FastTask_HeartbeatServer_Send{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnSendMatch(matchers ...interface{}) *FastTask_HeartbeatServer_Send { + c_call := _m.On("Send", matchers...) + return &FastTask_HeartbeatServer_Send{Call: c_call} +} + +// Send provides a mock function with given fields: _a0 +func (_m *FastTask_HeartbeatServer) Send(_a0 *pb.HeartbeatResponse) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*pb.HeartbeatResponse) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatServer_SendHeader struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_SendHeader) Return(_a0 error) *FastTask_HeartbeatServer_SendHeader { + return &FastTask_HeartbeatServer_SendHeader{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatServer) OnSendHeader(_a0 metadata.MD) *FastTask_HeartbeatServer_SendHeader { + c_call := _m.On("SendHeader", _a0) + return &FastTask_HeartbeatServer_SendHeader{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnSendHeaderMatch(matchers ...interface{}) *FastTask_HeartbeatServer_SendHeader { + c_call := _m.On("SendHeader", matchers...) 
+ return &FastTask_HeartbeatServer_SendHeader{Call: c_call} +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *FastTask_HeartbeatServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatServer_SendMsg struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_SendMsg) Return(_a0 error) *FastTask_HeartbeatServer_SendMsg { + return &FastTask_HeartbeatServer_SendMsg{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatServer) OnSendMsg(m interface{}) *FastTask_HeartbeatServer_SendMsg { + c_call := _m.On("SendMsg", m) + return &FastTask_HeartbeatServer_SendMsg{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnSendMsgMatch(matchers ...interface{}) *FastTask_HeartbeatServer_SendMsg { + c_call := _m.On("SendMsg", matchers...) + return &FastTask_HeartbeatServer_SendMsg{Call: c_call} +} + +// SendMsg provides a mock function with given fields: m +func (_m *FastTask_HeartbeatServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTask_HeartbeatServer_SetHeader struct { + *mock.Call +} + +func (_m FastTask_HeartbeatServer_SetHeader) Return(_a0 error) *FastTask_HeartbeatServer_SetHeader { + return &FastTask_HeartbeatServer_SetHeader{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTask_HeartbeatServer) OnSetHeader(_a0 metadata.MD) *FastTask_HeartbeatServer_SetHeader { + c_call := _m.On("SetHeader", _a0) + return &FastTask_HeartbeatServer_SetHeader{Call: c_call} +} + +func (_m *FastTask_HeartbeatServer) OnSetHeaderMatch(matchers ...interface{}) *FastTask_HeartbeatServer_SetHeader { + c_call := _m.On("SetHeader", matchers...) 
+ return &FastTask_HeartbeatServer_SetHeader{Call: c_call} +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *FastTask_HeartbeatServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *FastTask_HeartbeatServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} diff --git a/fasttask/plugin/mocks/fast_task_client.go b/fasttask/plugin/mocks/fast_task_client.go new file mode 100644 index 00000000000..86ac7a5c4b4 --- /dev/null +++ b/fasttask/plugin/mocks/fast_task_client.go @@ -0,0 +1,66 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" + + pb "github.com/unionai/flyte/fasttask/plugin/pb" +) + +// FastTaskClient is an autogenerated mock type for the FastTaskClient type +type FastTaskClient struct { + mock.Mock +} + +type FastTaskClient_Heartbeat struct { + *mock.Call +} + +func (_m FastTaskClient_Heartbeat) Return(_a0 pb.FastTask_HeartbeatClient, _a1 error) *FastTaskClient_Heartbeat { + return &FastTaskClient_Heartbeat{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *FastTaskClient) OnHeartbeat(ctx context.Context, opts ...grpc.CallOption) *FastTaskClient_Heartbeat { + c_call := _m.On("Heartbeat", ctx, opts) + return &FastTaskClient_Heartbeat{Call: c_call} +} + +func (_m *FastTaskClient) OnHeartbeatMatch(matchers ...interface{}) *FastTaskClient_Heartbeat { + c_call := _m.On("Heartbeat", matchers...) 
+ return &FastTaskClient_Heartbeat{Call: c_call} +} + +// Heartbeat provides a mock function with given fields: ctx, opts +func (_m *FastTaskClient) Heartbeat(ctx context.Context, opts ...grpc.CallOption) (pb.FastTask_HeartbeatClient, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 pb.FastTask_HeartbeatClient + if rf, ok := ret.Get(0).(func(context.Context, ...grpc.CallOption) pb.FastTask_HeartbeatClient); ok { + r0 = rf(ctx, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(pb.FastTask_HeartbeatClient) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, ...grpc.CallOption) error); ok { + r1 = rf(ctx, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/fasttask/plugin/mocks/fast_task_server.go b/fasttask/plugin/mocks/fast_task_server.go new file mode 100644 index 00000000000..6ebfa5ecb19 --- /dev/null +++ b/fasttask/plugin/mocks/fast_task_server.go @@ -0,0 +1,50 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + pb "github.com/unionai/flyte/fasttask/plugin/pb" +) + +// FastTaskServer is an autogenerated mock type for the FastTaskServer type +type FastTaskServer struct { + mock.Mock +} + +type FastTaskServer_Heartbeat struct { + *mock.Call +} + +func (_m FastTaskServer_Heartbeat) Return(_a0 error) *FastTaskServer_Heartbeat { + return &FastTaskServer_Heartbeat{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTaskServer) OnHeartbeat(_a0 pb.FastTask_HeartbeatServer) *FastTaskServer_Heartbeat { + c_call := _m.On("Heartbeat", _a0) + return &FastTaskServer_Heartbeat{Call: c_call} +} + +func (_m *FastTaskServer) OnHeartbeatMatch(matchers ...interface{}) *FastTaskServer_Heartbeat { + c_call := _m.On("Heartbeat", matchers...) 
+ return &FastTaskServer_Heartbeat{Call: c_call} +} + +// Heartbeat provides a mock function with given fields: _a0 +func (_m *FastTaskServer) Heartbeat(_a0 pb.FastTask_HeartbeatServer) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(pb.FastTask_HeartbeatServer) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// mustEmbedUnimplementedFastTaskServer provides a mock function with given fields: +func (_m *FastTaskServer) mustEmbedUnimplementedFastTaskServer() { + _m.Called() +} diff --git a/fasttask/plugin/mocks/fast_task_service.go b/fasttask/plugin/mocks/fast_task_service.go new file mode 100644 index 00000000000..54e59aa8ff6 --- /dev/null +++ b/fasttask/plugin/mocks/fast_task_service.go @@ -0,0 +1,132 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + core "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + mock "github.com/stretchr/testify/mock" +) + +// FastTaskService is an autogenerated mock type for the FastTaskService type +type FastTaskService struct { + mock.Mock +} + +type FastTaskService_CheckStatus struct { + *mock.Call +} + +func (_m FastTaskService_CheckStatus) Return(_a0 core.Phase, _a1 string, _a2 error) *FastTaskService_CheckStatus { + return &FastTaskService_CheckStatus{Call: _m.Call.Return(_a0, _a1, _a2)} +} + +func (_m *FastTaskService) OnCheckStatus(ctx context.Context, taskID string, queueID string, workerID string) *FastTaskService_CheckStatus { + c_call := _m.On("CheckStatus", ctx, taskID, queueID, workerID) + return &FastTaskService_CheckStatus{Call: c_call} +} + +func (_m *FastTaskService) OnCheckStatusMatch(matchers ...interface{}) *FastTaskService_CheckStatus { + c_call := _m.On("CheckStatus", matchers...) 
+ return &FastTaskService_CheckStatus{Call: c_call} +} + +// CheckStatus provides a mock function with given fields: ctx, taskID, queueID, workerID +func (_m *FastTaskService) CheckStatus(ctx context.Context, taskID string, queueID string, workerID string) (core.Phase, string, error) { + ret := _m.Called(ctx, taskID, queueID, workerID) + + var r0 core.Phase + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) core.Phase); ok { + r0 = rf(ctx, taskID, queueID, workerID) + } else { + r0 = ret.Get(0).(core.Phase) + } + + var r1 string + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) string); ok { + r1 = rf(ctx, taskID, queueID, workerID) + } else { + r1 = ret.Get(1).(string) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, string, string, string) error); ok { + r2 = rf(ctx, taskID, queueID, workerID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +type FastTaskService_Cleanup struct { + *mock.Call +} + +func (_m FastTaskService_Cleanup) Return(_a0 error) *FastTaskService_Cleanup { + return &FastTaskService_Cleanup{Call: _m.Call.Return(_a0)} +} + +func (_m *FastTaskService) OnCleanup(ctx context.Context, taskID string, queueID string, workerID string) *FastTaskService_Cleanup { + c_call := _m.On("Cleanup", ctx, taskID, queueID, workerID) + return &FastTaskService_Cleanup{Call: c_call} +} + +func (_m *FastTaskService) OnCleanupMatch(matchers ...interface{}) *FastTaskService_Cleanup { + c_call := _m.On("Cleanup", matchers...) 
+ return &FastTaskService_Cleanup{Call: c_call} +} + +// Cleanup provides a mock function with given fields: ctx, taskID, queueID, workerID +func (_m *FastTaskService) Cleanup(ctx context.Context, taskID string, queueID string, workerID string) error { + ret := _m.Called(ctx, taskID, queueID, workerID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, taskID, queueID, workerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type FastTaskService_OfferOnQueue struct { + *mock.Call +} + +func (_m FastTaskService_OfferOnQueue) Return(_a0 string, _a1 error) *FastTaskService_OfferOnQueue { + return &FastTaskService_OfferOnQueue{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *FastTaskService) OnOfferOnQueue(ctx context.Context, queueID string, taskID string, namespace string, workflowID string, cmd []string) *FastTaskService_OfferOnQueue { + c_call := _m.On("OfferOnQueue", ctx, queueID, taskID, namespace, workflowID, cmd) + return &FastTaskService_OfferOnQueue{Call: c_call} +} + +func (_m *FastTaskService) OnOfferOnQueueMatch(matchers ...interface{}) *FastTaskService_OfferOnQueue { + c_call := _m.On("OfferOnQueue", matchers...) 
+ return &FastTaskService_OfferOnQueue{Call: c_call} +} + +// OfferOnQueue provides a mock function with given fields: ctx, queueID, taskID, namespace, workflowID, cmd +func (_m *FastTaskService) OfferOnQueue(ctx context.Context, queueID string, taskID string, namespace string, workflowID string, cmd []string) (string, error) { + ret := _m.Called(ctx, queueID, taskID, namespace, workflowID, cmd) + + var r0 string + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, []string) string); ok { + r0 = rf(ctx, queueID, taskID, namespace, workflowID, cmd) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string, []string) error); ok { + r1 = rf(ctx, queueID, taskID, namespace, workflowID, cmd) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/fasttask/plugin/mocks/is_fast_task_environment_spec__termination_criteria.go b/fasttask/plugin/mocks/is_fast_task_environment_spec__termination_criteria.go new file mode 100644 index 00000000000..d6adf5437c2 --- /dev/null +++ b/fasttask/plugin/mocks/is_fast_task_environment_spec__termination_criteria.go @@ -0,0 +1,15 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +// isFastTaskEnvironmentSpec_TerminationCriteria is an autogenerated mock type for the isFastTaskEnvironmentSpec_TerminationCriteria type +type isFastTaskEnvironmentSpec_TerminationCriteria struct { + mock.Mock +} + +// isFastTaskEnvironmentSpec_TerminationCriteria provides a mock function with given fields: +func (_m *isFastTaskEnvironmentSpec_TerminationCriteria) isFastTaskEnvironmentSpec_TerminationCriteria() { + _m.Called() +} diff --git a/fasttask/plugin/mocks/unsafe_fast_task_server.go b/fasttask/plugin/mocks/unsafe_fast_task_server.go new file mode 100644 index 00000000000..3c4635d6111 --- /dev/null +++ b/fasttask/plugin/mocks/unsafe_fast_task_server.go @@ -0,0 +1,15 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// UnsafeFastTaskServer is an autogenerated mock type for the UnsafeFastTaskServer type +type UnsafeFastTaskServer struct { + mock.Mock +} + +// mustEmbedUnimplementedFastTaskServer provides a mock function with given fields: +func (_m *UnsafeFastTaskServer) mustEmbedUnimplementedFastTaskServer() { + _m.Called() +} diff --git a/fasttask/plugin/plugin.go b/fasttask/plugin/plugin.go index e6d374bce9b..6199f290821 100644 --- a/fasttask/plugin/plugin.go +++ b/fasttask/plugin/plugin.go @@ -19,6 +19,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/unionai/flyte/fasttask/plugin/pb" ) @@ -27,23 +28,23 @@ const fastTaskType = "fast-task" var statusUpdateNotFoundError = errors.New("StatusUpdateNotFound") -type Phase int +type SubmissionPhase int const ( - PhaseNotStarted Phase = iota - PhaseRunning + NotSubmitted SubmissionPhase = iota + Submitted ) // State maintains the 
current status of the task execution. type State struct { - Phase Phase - WorkerID string - LastUpdated time.Time + SubmissionPhase SubmissionPhase + WorkerID string + LastUpdated time.Time } // Plugin is a fast task plugin that offers task execution to a worker pool. type Plugin struct { - fastTaskService *FastTaskService + fastTaskService FastTaskService } // GetID returns the unique identifier for the plugin. @@ -164,9 +165,10 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co return core.UnknownTransition, err } + queueID := fastTaskEnvironment.GetQueueId() phaseInfo := core.PhaseInfoUndefined - switch pluginState.Phase { - case PhaseNotStarted: + switch pluginState.SubmissionPhase { + case NotSubmitted: // read task template taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { @@ -191,17 +193,18 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co // offer the work to the queue ownerID := tCtx.TaskExecutionMetadata().GetOwnerID() - workerID, err := p.fastTaskService.OfferOnQueue(ctx, fastTaskEnvironment.GetQueueId(), taskID, ownerID.Namespace, ownerID.Name, command) + workerID, err := p.fastTaskService.OfferOnQueue(ctx, queueID, taskID, ownerID.Namespace, ownerID.Name, command) if err != nil { return core.UnknownTransition, err } if len(workerID) > 0 { - pluginState.Phase = PhaseRunning + pluginState.SubmissionPhase = Submitted + pluginState.WorkerID = workerID pluginState.LastUpdated = time.Now() - phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, nil) + phaseInfo = core.PhaseInfoQueued(time.Now(), core.DefaultPhaseVersion, fmt.Sprintf("task offered to worker %s", workerID)) } else { if pluginState.LastUpdated.IsZero() { pluginState.LastUpdated = time.Now() @@ -209,12 +212,13 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co // fail if no worker available within grace period if time.Since(pluginState.LastUpdated) > 
GetConfig().GracePeriodWorkersUnavailable.Duration { - phaseInfo = core.PhaseInfoSystemFailure("unknown", "timed out waiting for worker availability", nil) + logger.Infof(ctx, "Timed out waiting for available worker for queue %s", queueID) + phaseInfo = core.PhaseInfoSystemFailure("unknown", fmt.Sprintf("timed out waiting for available worker for queue %s", queueID), nil) } else { - phaseInfo = core.PhaseInfoNotReady(time.Now(), core.DefaultPhaseVersion, "no workers available") + phaseInfo = core.PhaseInfoWaitingForResourcesInfo(time.Now(), core.DefaultPhaseVersion, "no workers available", nil) } } - case PhaseRunning: + case Submitted: // check the task status phase, reason, err := p.fastTaskService.CheckStatus(ctx, taskID, fastTaskEnvironment.GetQueueId(), pluginState.WorkerID) @@ -223,7 +227,8 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co return core.UnknownTransition, err } else if errors.Is(err, statusUpdateNotFoundError) && now.Sub(pluginState.LastUpdated) > GetConfig().GracePeriodStatusNotFound.Duration { // if task has not been updated within the grace period we should abort - return core.DoTransition(core.PhaseInfoSystemRetryableFailure("unknown", "task status update not reported within grace period", nil)), nil + logger.Infof(ctx, "Task status update not reported within grace period for queue %s and worker %s", queueID, pluginState.WorkerID) + return core.DoTransition(core.PhaseInfoSystemRetryableFailure("unknown", fmt.Sprintf("task status update not reported within grace period for queue %s and worker %s", queueID, pluginState.WorkerID), nil)), nil } else if phase == core.PhaseSuccess { taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { @@ -300,7 +305,7 @@ func init() { } // create and start grpc server - fastTaskService := NewFastTaskService(iCtx.EnqueueOwner()) + fastTaskService := newFastTaskService(iCtx.EnqueueOwner(), iCtx.MetricsScope()) go func() { grpcServer := grpc.NewServer() 
pb.RegisterFastTaskServer(grpcServer, fastTaskService) diff --git a/fasttask/plugin/plugin_test.go b/fasttask/plugin/plugin_test.go new file mode 100644 index 00000000000..07cf7f8937d --- /dev/null +++ b/fasttask/plugin/plugin_test.go @@ -0,0 +1,483 @@ +package plugin + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/proto" + _struct "github.com/golang/protobuf/ptypes/struct" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + coremocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + iomocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + + "github.com/unionai/flyte/fasttask/plugin/mocks" + "github.com/unionai/flyte/fasttask/plugin/pb" +) + +func buildFasttaskEnvironment(t *testing.T, fastTaskExtant *pb.FastTaskEnvironment, fastTaskSpec *pb.FastTaskEnvironmentSpec) *_struct.Struct { + executionEnv := &idlcore.ExecutionEnv{ + Id: "foo", + Type: "fast-task", + } + + if fastTaskExtant != nil { + extant := &_struct.Struct{} + err := utils.MarshalStruct(fastTaskExtant, extant) + assert.Nil(t, err) + executionEnv.Environment = &idlcore.ExecutionEnv_Extant{ + Extant: extant, + } + } else if fastTaskSpec != nil { + spec := &_struct.Struct{} + err := utils.MarshalStruct(fastTaskSpec, spec) + assert.Nil(t, err) + executionEnv.Environment = &idlcore.ExecutionEnv_Spec{ + Spec: spec, + } + } + + executionEnvStruct := &_struct.Struct{} + err := utils.MarshalStruct(executionEnv, executionEnvStruct) + assert.Nil(t, err) + + return executionEnvStruct +} + +func getBaseFasttaskTaskTemplate(t *testing.T) *idlcore.TaskTemplate { + executionEnv := buildFasttaskEnvironment(t, &pb.FastTaskEnvironment{ 
+ QueueId: "foo", + }, nil) + + executionEnvStruct := &_struct.Struct{} + err := utils.MarshalStruct(executionEnv, executionEnvStruct) + assert.Nil(t, err) + + return &idlcore.TaskTemplate{ + Custom: executionEnvStruct, + Target: &idlcore.TaskTemplate_Container{ + Container: &idlcore.Container{ + Command: []string{""}, + Args: []string{}, + }, + }, + } +} + +func TestFinalize(t *testing.T) { + ctx := context.TODO() + + // initialize fasttask TaskTemplate + taskTemplate := getBaseFasttaskTaskTemplate(t) + taskReader := &coremocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + + // initialize static execution context attributes + taskMetadata := &coremocks.TaskExecutionMetadata{} + taskExecutionID := &coremocks.TaskExecutionID{} + taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil) + taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID) + + // create TaskExecutionContext + tCtx := &coremocks.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(taskMetadata) + tCtx.OnTaskReader().Return(taskReader) + + arrayNodeStateInput := &State{ + SubmissionPhase: Submitted, + WorkerID: "w0", + } + pluginStateReader := &coremocks.PluginStateReader{} + pluginStateReader.On("Get", mock.Anything).Return( + func(v interface{}) uint8 { + *v.(*State) = *arrayNodeStateInput + return 0 + }, + nil, + ) + tCtx.OnPluginStateReader().Return(pluginStateReader) + + // create FastTaskService mock + fastTaskService := &mocks.FastTaskService{} + fastTaskService.OnCleanup(ctx, "task_id", "foo", "w0").Return(nil) + + // initialize plugin + plugin := &Plugin{ + fastTaskService: fastTaskService, + } + + // call handle + err := plugin.Finalize(ctx, tCtx) + assert.Nil(t, err) +} + +func TestGetExecutionEnv(t *testing.T) { + ctx := context.TODO() + + expectedExtant := &pb.FastTaskEnvironment{ + QueueId: "foo", + } + expectedExtantStruct := &_struct.Struct{} + err := utils.MarshalStruct(expectedExtant, 
expectedExtantStruct) + assert.Nil(t, err) + + tests := []struct { + name string + fastTaskExtant *pb.FastTaskEnvironment + fastTaskSpec *pb.FastTaskEnvironmentSpec + clientGetExists bool + }{ + { + name: "ExecutionExtant", + fastTaskExtant: &pb.FastTaskEnvironment{ + QueueId: "foo", + }, + }, + { + name: "ExecutionSpecExists", + fastTaskSpec: &pb.FastTaskEnvironmentSpec{}, + clientGetExists: true, + }, + { + name: "ExecutionSpecCreate", + fastTaskSpec: &pb.FastTaskEnvironmentSpec{ + PodTemplateSpec: []byte("bar"), + }, + clientGetExists: false, + }, + { + name: "ExecutionSpecInjectPodTemplateAndCreate", + fastTaskSpec: &pb.FastTaskEnvironmentSpec{}, + clientGetExists: false, + }, + } + + // initialize static execution context attributes + inputReader := &iomocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return("test-data-prefix") + inputReader.OnGetInputPath().Return("test-data-reference") + inputReader.OnGetMatch(mock.Anything).Return(&idlcore.LiteralMap{}, nil) + + outputReader := &iomocks.OutputWriter{} + outputReader.OnGetOutputPath().Return("/data/outputs.pb") + outputReader.OnGetOutputPrefixPath().Return("/data/") + outputReader.OnGetRawOutputPrefix().Return("") + outputReader.OnGetCheckpointPrefix().Return("/checkpoint") + outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") + + taskMetadata := &coremocks.TaskExecutionMetadata{} + taskMetadata.OnGetEnvironmentVariables().Return(nil) + taskMetadata.OnGetK8sServiceAccount().Return("service-account") + taskMetadata.OnGetNamespace().Return("test-namespace") + taskMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) + taskMetadata.OnIsInterruptible().Return(true) + + taskExecutionID := &coremocks.TaskExecutionID{} + taskExecutionID.OnGetIDMatch().Return(idlcore.TaskExecutionIdentifier{ + NodeExecutionId: &idlcore.NodeExecutionIdentifier{ + ExecutionId: &idlcore.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + 
taskExecutionID.OnGetGeneratedNameMatch().Return("task_id") + taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID) + taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil) + taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID) + + taskOverrides := &coremocks.TaskOverrides{} + taskOverrides.OnGetResourcesMatch().Return(&v1.ResourceRequirements{}) + taskOverrides.OnGetExtendedResourcesMatch().Return(nil) + taskMetadata.OnGetOverridesMatch().Return(taskOverrides) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // initialize fasttask TaskTemplate + executionEnvStruct := buildFasttaskEnvironment(t, test.fastTaskExtant, test.fastTaskSpec) + taskTemplate := &idlcore.TaskTemplate{ + Custom: executionEnvStruct, + Target: &idlcore.TaskTemplate_Container{ + Container: &idlcore.Container{ + Command: []string{""}, + Args: []string{}, + }, + }, + } + + // create ExecutionEnvClient mock + executionEnvClient := &coremocks.ExecutionEnvClient{} + if test.clientGetExists { + executionEnvClient.OnGetMatch(ctx, mock.Anything).Return(expectedExtantStruct) + } else { + executionEnvClient.OnGetMatch(ctx, mock.Anything).Return(nil) + } + executionEnvClient.OnCreateMatch(ctx, "foo", mock.Anything).Return(expectedExtantStruct, nil) + + // create TaskExecutionContext + tCtx := &coremocks.TaskExecutionContext{} + tCtx.OnInputReader().Return(inputReader) + tCtx.OnOutputWriter().Return(outputReader) + tCtx.OnTaskExecutionMetadata().Return(taskMetadata) + + taskReader := &coremocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + tCtx.OnTaskReader().Return(taskReader) + + tCtx.OnGetExecutionEnvClient().Return(executionEnvClient) + + // initialize plugin + plugin := &Plugin{} + + // call handle + fastTaskEnvironment, err := plugin.getExecutionEnv(ctx, tCtx) + assert.Nil(t, err) + assert.True(t, proto.Equal(expectedExtant, fastTaskEnvironment)) + }) + } +} + +func 
TestHandleNotYetStarted(t *testing.T) { + ctx := context.TODO() + tests := []struct { + name string + workerID string + lastUpdated time.Time + expectedPhase core.Phase + expectedReason string + expectedError error + }{ + { + name: "NoWorkersAvailable", + workerID: "", + expectedPhase: core.PhaseWaitingForResources, + expectedReason: "no workers available", + expectedError: nil, + }, + { + name: "NoWorkersAvailableGracePeriodFailure", + workerID: "", + lastUpdated: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), + expectedPhase: core.PhasePermanentFailure, + expectedReason: "", + expectedError: nil, + }, + { + name: "AssignedToWorker", + workerID: "w0", + expectedPhase: core.PhaseQueued, + expectedReason: "task offered to worker w0", + expectedError: nil, + }, + } + + // initialize fasttask TaskTemplate + taskTemplate := getBaseFasttaskTaskTemplate(t) + taskReader := &coremocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + + // initialize static execution context attributes + inputReader := &iomocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return("test-data-prefix") + inputReader.OnGetInputPath().Return("test-data-reference") + inputReader.OnGetMatch(mock.Anything).Return(&idlcore.LiteralMap{}, nil) + + outputReader := &iomocks.OutputWriter{} + outputReader.OnGetOutputPath().Return("/data/outputs.pb") + outputReader.OnGetOutputPrefixPath().Return("/data/") + outputReader.OnGetRawOutputPrefix().Return("") + outputReader.OnGetCheckpointPrefix().Return("/checkpoint") + outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") + + taskMetadata := &coremocks.TaskExecutionMetadata{} + taskMetadata.OnGetOwnerIDMatch().Return(types.NamespacedName{ + Namespace: "namespace", + Name: "execution_id", + }) + taskExecutionID := &coremocks.TaskExecutionID{} + taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil) + taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID) + + for _, test := 
range tests { + t.Run(test.name, func(t *testing.T) { + // create TaskExecutionContext + tCtx := &coremocks.TaskExecutionContext{} + tCtx.OnInputReader().Return(inputReader) + tCtx.OnOutputWriter().Return(outputReader) + tCtx.OnTaskExecutionMetadata().Return(taskMetadata) + tCtx.OnTaskReader().Return(taskReader) + + arrayNodeStateInput := &State{ + SubmissionPhase: NotSubmitted, + LastUpdated: test.lastUpdated, + } + pluginStateReader := &coremocks.PluginStateReader{} + pluginStateReader.On("Get", mock.Anything).Return( + func(v interface{}) uint8 { + *v.(*State) = *arrayNodeStateInput + return 0 + }, + nil, + ) + tCtx.OnPluginStateReader().Return(pluginStateReader) + + arrayNodeStateOutput := &State{} + pluginStateWriter := &coremocks.PluginStateWriter{} + pluginStateWriter.On("Put", mock.Anything, mock.Anything).Return( + func(stateVersion uint8, v interface{}) error { + *arrayNodeStateOutput = *v.(*State) + return nil + }, + ) + tCtx.OnPluginStateWriter().Return(pluginStateWriter) + + // create FastTaskService mock + fastTaskService := &mocks.FastTaskService{} + fastTaskService.OnOfferOnQueue(ctx, "foo", "task_id", "namespace", "execution_id", []string{}).Return(test.workerID, nil) + + // initialize plugin + plugin := &Plugin{ + fastTaskService: fastTaskService, + } + + // call handle + transition, err := plugin.Handle(ctx, tCtx) + assert.Equal(t, test.expectedError, err) + assert.Equal(t, test.expectedPhase, transition.Info().Phase()) + assert.Equal(t, test.expectedReason, transition.Info().Reason()) + + if len(test.workerID) > 0 { + assert.Equal(t, test.workerID, arrayNodeStateOutput.WorkerID) + } + }) + } +} + +func TestHandleRunning(t *testing.T) { + ctx := context.TODO() + tests := []struct { + name string + lastUpdated time.Time + taskStatusPhase core.Phase + taskStatusReason string + checkStatusError error + expectedPhase core.Phase + expectedReason string + expectedError error + expectedLastUpdatedInc bool + }{ + { + name: "Running", + lastUpdated: 
time.Now().Add(-5 * time.Second), + taskStatusPhase: core.PhaseRunning, + taskStatusReason: "", + checkStatusError: nil, + expectedPhase: core.PhaseRunning, + expectedReason: "", + expectedError: nil, + expectedLastUpdatedInc: true, + }, + { + name: "RetryableFailure", + lastUpdated: time.Now().Add(-5 * time.Second), + taskStatusPhase: core.PhaseRetryableFailure, + checkStatusError: nil, + expectedPhase: core.PhaseRetryableFailure, + expectedError: nil, + expectedLastUpdatedInc: false, + }, + { + name: "StatusNotFoundTimeout", + lastUpdated: time.Now().Add(-600 * time.Second), + taskStatusPhase: core.PhaseUndefined, + checkStatusError: statusUpdateNotFoundError, + expectedPhase: core.PhaseRetryableFailure, + expectedError: nil, + expectedLastUpdatedInc: false, + }, + { + name: "Success", + lastUpdated: time.Now().Add(-5 * time.Second), + taskStatusPhase: core.PhaseSuccess, + checkStatusError: nil, + expectedPhase: core.PhaseSuccess, + expectedError: nil, + expectedLastUpdatedInc: false, + }, + } + + // initialize fasttask TaskTemplate + taskTemplate := getBaseFasttaskTaskTemplate(t) + taskReader := &coremocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + + // initialize static execution context attributes + taskMetadata := &coremocks.TaskExecutionMetadata{} + taskMetadata.OnGetOwnerIDMatch().Return(types.NamespacedName{ + Namespace: "namespace", + Name: "execution_id", + }) + taskExecutionID := &coremocks.TaskExecutionID{} + taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil) + taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create TaskExecutionContext + tCtx := &coremocks.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(taskMetadata) + tCtx.OnTaskReader().Return(taskReader) + + arrayNodeStateInput := &State{ + SubmissionPhase: Submitted, + WorkerID: "w0", + LastUpdated: 
test.lastUpdated, + } + pluginStateReader := &coremocks.PluginStateReader{} + pluginStateReader.On("Get", mock.Anything).Return( + func(v interface{}) uint8 { + *v.(*State) = *arrayNodeStateInput + return 0 + }, + nil, + ) + tCtx.OnPluginStateReader().Return(pluginStateReader) + + arrayNodeStateOutput := &State{} + pluginStateWriter := &coremocks.PluginStateWriter{} + pluginStateWriter.On("Put", mock.Anything, mock.Anything).Return( + func(stateVersion uint8, v interface{}) error { + *arrayNodeStateOutput = *v.(*State) + return nil + }, + ) + tCtx.OnPluginStateWriter().Return(pluginStateWriter) + + // create FastTaskService mock + fastTaskService := &mocks.FastTaskService{} + fastTaskService.OnCheckStatusMatch(ctx, "task_id", "foo", "w0").Return(test.taskStatusPhase, "", test.checkStatusError) + + // initialize plugin + plugin := &Plugin{ + fastTaskService: fastTaskService, + } + + // call handle + transition, err := plugin.Handle(ctx, tCtx) + assert.Equal(t, test.expectedError, err) + assert.Equal(t, test.expectedPhase, transition.Info().Phase()) + + if test.expectedLastUpdatedInc { + assert.True(t, arrayNodeStateOutput.LastUpdated.After(test.lastUpdated)) + } + }) + } +} diff --git a/fasttask/plugin/service.go b/fasttask/plugin/service.go index 6edf3ad38b4..8b00a672e28 100644 --- a/fasttask/plugin/service.go +++ b/fasttask/plugin/service.go @@ -2,7 +2,6 @@ package plugin import ( "context" - "errors" "fmt" "io" "math/rand" @@ -20,9 +19,18 @@ import ( var maxPendingOwnersPerQueue = 100 -// FastTaskService is a gRPC service that manages assignment and management of task executions with -// respect to fasttask workers. 
-type FastTaskService struct { +//go:generate mockery -all -case=underscore + +// FastTaskService defines the interface for managing assignment and management of task executions +type FastTaskService interface { + CheckStatus(ctx context.Context, taskID, queueID, workerID string) (core.Phase, string, error) + Cleanup(ctx context.Context, taskID, queueID, workerID string) error + OfferOnQueue(ctx context.Context, queueID, taskID, namespace, workflowID string, cmd []string) (string, error) +} + +// fastTaskServiceImpl is a gRPC service that manages assignment and management of task executions +// with respect to fasttask workers. +type fastTaskServiceImpl struct { pb.UnimplementedFastTaskServer enqueueOwner core.EnqueueOwner @@ -34,7 +42,7 @@ type FastTaskService struct { pendingTaskOwners map[string]map[string]types.NamespacedName // map[queueID]map[taskID]ownerID pendingTaskOwnersLock sync.RWMutex - taskStatusChannels sync.Map // map[string]chan *WorkerTaskStatus + taskStatusChannels sync.Map // map[taskID]chan *WorkerTaskStatus metrics metrics } @@ -66,12 +74,22 @@ type metrics struct { workers *prometheus.Desc } -func (f *FastTaskService) Describe(ch chan<- *prometheus.Desc) { +func newMetrics(scope promutils.Scope) metrics { + return metrics{ + taskNoWorkersAvailable: scope.MustNewCounter("task_no_workers_available", "Count of task assignment attempts with no workers available"), + taskNoCapacityAvailable: scope.MustNewCounter("task_no_capacity_available", "Count of task assignment attempts with no capacity available"), + taskAssigned: scope.MustNewCounter("task_assigned", "Count of task assignments"), + queues: prometheus.NewDesc(scope.NewScopedMetricName("queue"), "Current number of queues", nil, nil), + workers: prometheus.NewDesc(scope.NewScopedMetricName("workers"), "Current number of workers", nil, nil), + } +} + +func (f *fastTaskServiceImpl) Describe(ch chan<- *prometheus.Desc) { ch <- f.metrics.queues ch <- f.metrics.workers } -func (f *FastTaskService) 
Collect(ch chan<- prometheus.Metric) { +func (f *fastTaskServiceImpl) Collect(ch chan<- prometheus.Metric) { f.queuesLock.RLock() defer f.queuesLock.RUnlock() @@ -89,7 +107,7 @@ func (f *FastTaskService) Collect(ch chan<- prometheus.Metric) { // Heartbeat is a gRPC stream that manages the heartbeat of a fasttask worker. This includes // receiving task status updates and sending task assignments. -func (f *FastTaskService) Heartbeat(stream pb.FastTask_HeartbeatServer) error { +func (f *fastTaskServiceImpl) Heartbeat(stream pb.FastTask_HeartbeatServer) error { workerID := "" // recv initial heartbeat request @@ -173,12 +191,10 @@ func (f *FastTaskService) Heartbeat(stream pb.FastTask_HeartbeatServer) error { // if taskStatus is complete then enqueueOwner for fast feedback phase := core.Phase(taskStatus.GetPhase()) if phase == core.PhaseSuccess || phase == core.PhaseRetryableFailure { - ownerID := types.NamespacedName{ + if err := f.enqueueOwner(types.NamespacedName{ Namespace: taskStatus.GetNamespace(), Name: taskStatus.GetWorkflowId(), - } - - if err := f.enqueueOwner(ownerID); err != nil { + }); err != nil { logger.Warnf(context.Background(), "failed to enqueue owner for task %s: %+v", taskStatus.GetTaskId(), err) } } @@ -188,7 +204,8 @@ func (f *FastTaskService) Heartbeat(stream pb.FastTask_HeartbeatServer) error { return nil } -func (f *FastTaskService) addWorkerToQueue(queueID string, worker *Worker) *Queue { +// addWorkerToQueue adds a worker to the queue. If the queue does not exist, it is created. +func (f *fastTaskServiceImpl) addWorkerToQueue(queueID string, worker *Worker) *Queue { f.queuesLock.Lock() defer f.queuesLock.Unlock() @@ -207,7 +224,8 @@ func (f *FastTaskService) addWorkerToQueue(queueID string, worker *Worker) *Queu return queue } -func (f *FastTaskService) removeWorkerFromQueue(queueID, workerID string) { +// removeWorkerFromQueue removes a worker from the queue. If the queue is empty, it is deleted. 
+func (f *fastTaskServiceImpl) removeWorkerFromQueue(queueID, workerID string) { f.queuesLock.Lock() defer f.queuesLock.Unlock() @@ -226,7 +244,7 @@ func (f *FastTaskService) removeWorkerFromQueue(queueID, workerID string) { } // addPendingOwner adds to the pending owners list for the queue, if not already full -func (f *FastTaskService) addPendingOwner(queueID, taskID string, ownerID types.NamespacedName) { +func (f *fastTaskServiceImpl) addPendingOwner(queueID, taskID string, ownerID types.NamespacedName) { f.pendingTaskOwnersLock.Lock() defer f.pendingTaskOwnersLock.Unlock() @@ -243,7 +261,7 @@ func (f *FastTaskService) addPendingOwner(queueID, taskID string, ownerID types. } // removePendingOwner removes the pending owner from the list if still there -func (f *FastTaskService) removePendingOwner(queueID, taskID string) { +func (f *fastTaskServiceImpl) removePendingOwner(queueID, taskID string) { f.pendingTaskOwnersLock.Lock() defer f.pendingTaskOwnersLock.Unlock() @@ -259,7 +277,7 @@ func (f *FastTaskService) removePendingOwner(queueID, taskID string) { } // enqueuePendingOwners drains the pending owners list for the queue and enqueues them for reevaluation -func (f *FastTaskService) enqueuePendingOwners(queueID string) { +func (f *fastTaskServiceImpl) enqueuePendingOwners(queueID string) { f.pendingTaskOwnersLock.Lock() defer f.pendingTaskOwnersLock.Unlock() @@ -284,7 +302,7 @@ func (f *FastTaskService) enqueuePendingOwners(queueID string) { // OfferOnQueue offers a task to a worker on a specific queue. If no workers are available, an // empty string is returned. 
-func (f *FastTaskService) OfferOnQueue(ctx context.Context, queueID, taskID, namespace, workflowID string, cmd []string) (string, error) { +func (f *fastTaskServiceImpl) OfferOnQueue(ctx context.Context, queueID, taskID, namespace, workflowID string, cmd []string) (string, error) { f.queuesLock.RLock() defer f.queuesLock.RUnlock() @@ -318,7 +336,7 @@ func (f *FastTaskService) OfferOnQueue(ctx context.Context, queueID, taskID, nam worker.capacity.BacklogCount++ } else { // No workers available. Note, we do not add to pending owners at this time as we are optimizing for the worker - // startup case. Theworker backlog should be sufficient to keep the worker busy without needing to proactively + // startup case. The worker backlog should be sufficient to keep the worker busy without needing to proactively // enqueue owners when capacity becomes available. f.metrics.taskNoCapacityAvailable.Inc() return "", nil @@ -340,13 +358,13 @@ func (f *FastTaskService) OfferOnQueue(ctx context.Context, queueID, taskID, nam } // CheckStatus checks the status of a task on a specific queue and worker. -func (f *FastTaskService) CheckStatus(ctx context.Context, taskID, queueID, workerID string) (core.Phase, string, error) { +func (f *fastTaskServiceImpl) CheckStatus(ctx context.Context, taskID, queueID, workerID string) (core.Phase, string, error) { taskStatusChannelResult, exists := f.taskStatusChannels.Load(taskID) if !exists { // if this plugin restarts then TaskContexts may not exist for tasks that are still active. we can // create a TaskContext here because we ensure it will be cleaned up when the task completes. 
f.taskStatusChannels.Store(taskID, make(chan *workerTaskStatus, GetConfig().TaskStatusBufferSize)) - return core.PhaseUndefined, "", errors.New("task context not found") + return core.PhaseUndefined, "", fmt.Errorf("task context not found") } taskStatusChannel := taskStatusChannelResult.(chan *workerTaskStatus) @@ -377,14 +395,16 @@ Loop: f.queuesLock.RLock() defer f.queuesLock.RUnlock() - queue := f.queues[queueID] - queue.lock.RLock() - defer queue.lock.RUnlock() + // if here it should be impossible for the queue not to exist, but left for safety + if queue, exists := f.queues[queueID]; exists { + queue.lock.RLock() + defer queue.lock.RUnlock() - if worker, exists := queue.workers[workerID]; exists { - worker.responseChan <- &pb.HeartbeatResponse{ - TaskId: taskID, - Operation: pb.HeartbeatResponse_ACK, + if worker, exists := queue.workers[workerID]; exists { + worker.responseChan <- &pb.HeartbeatResponse{ + TaskId: taskID, + Operation: pb.HeartbeatResponse_ACK, + } } } } @@ -394,7 +414,7 @@ Loop: // Cleanup is used to indicate a task is no longer being tracked by the worker and delete the // associated task context. -func (f *FastTaskService) Cleanup(ctx context.Context, taskID, queueID, workerID string) error { +func (f *fastTaskServiceImpl) Cleanup(ctx context.Context, taskID, queueID, workerID string) error { // send delete taskID message to worker f.queuesLock.RLock() defer f.queuesLock.RUnlock() @@ -420,21 +440,12 @@ func (f *FastTaskService) Cleanup(ctx context.Context, taskID, queueID, workerID return nil } -// NewFastTaskService creates a new FastTaskService. -func NewFastTaskService(enqueueOwner core.EnqueueOwner) *FastTaskService { - scope := promutils.NewScope("fasttask") - svc := &FastTaskService{ +// newFastTaskService creates a new fastTaskServiceImpl. 
+func newFastTaskService(enqueueOwner core.EnqueueOwner, scope promutils.Scope) *fastTaskServiceImpl { + return &fastTaskServiceImpl{ enqueueOwner: enqueueOwner, queues: make(map[string]*Queue), pendingTaskOwners: make(map[string]map[string]types.NamespacedName), - metrics: metrics{ - taskNoWorkersAvailable: scope.MustNewCounter("task_no_workers_available", "Count of task assignment attempts with no workers available"), - taskNoCapacityAvailable: scope.MustNewCounter("task_no_capacity_available", "Count of task assignment attempts with no capacity available"), - taskAssigned: scope.MustNewCounter("task_assigned", "Count of task assignments"), - queues: prometheus.NewDesc(scope.NewScopedMetricName("queue"), "Current number of queues", nil, nil), - workers: prometheus.NewDesc(scope.NewScopedMetricName("workers"), "Current number of workers", nil, nil), - }, + metrics: newMetrics(scope), } - prometheus.MustRegister(svc) - return svc } diff --git a/fasttask/plugin/service_test.go b/fasttask/plugin/service_test.go new file mode 100644 index 00000000000..2d63f0591eb --- /dev/null +++ b/fasttask/plugin/service_test.go @@ -0,0 +1,761 @@ +package plugin + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/types" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flytestdlib/promutils" + + "github.com/unionai/flyte/fasttask/plugin/pb" +) + +func TestCheckStatus(t *testing.T) { + ctx := context.TODO() + tests := []struct { + name string + taskID string + queueID string + workerID string + taskStatuses []*workerTaskStatus + expectedPhase core.Phase + expectedError error + }{ + { + name: "ChannelDoesNotExist", + taskID: "bar", + queueID: "foo", + workerID: "w1", + taskStatuses: nil, + expectedPhase: core.PhaseUndefined, + expectedError: fmt.Errorf("task context not found"), + }, + { + name: "NoUpdates", + taskID: "bar", + queueID: "foo", + workerID: "w1", + 
taskStatuses: []*workerTaskStatus{}, + expectedPhase: core.PhaseUndefined, + expectedError: fmt.Errorf("unable to find task status update: %w", statusUpdateNotFoundError), + }, + { + name: "UpdateFromDifferentWorker", + taskID: "bar", + queueID: "foo", + workerID: "w1", + taskStatuses: []*workerTaskStatus{ + &workerTaskStatus{ + workerID: "w2", + taskStatus: &pb.TaskStatus{ + TaskId: "bar", + Phase: int32(core.PhaseRunning), + }, + }, + }, + expectedPhase: core.PhaseUndefined, + expectedError: fmt.Errorf("unable to find task status update: %w", statusUpdateNotFoundError), + }, + { + name: "MultipleUpdates", + taskID: "bar", + queueID: "foo", + workerID: "w1", + taskStatuses: []*workerTaskStatus{ + &workerTaskStatus{ + workerID: "w1", + taskStatus: &pb.TaskStatus{ + TaskId: "bar", + Phase: int32(core.PhaseQueued), + }, + }, + &workerTaskStatus{ + workerID: "w1", + taskStatus: &pb.TaskStatus{ + TaskId: "bar", + Phase: int32(core.PhaseRunning), + }, + }, + }, + expectedPhase: core.PhaseRunning, + expectedError: nil, + }, + { + name: "NoAckOnSuccess", + taskID: "bar", + queueID: "foo", + workerID: "w1", + taskStatuses: []*workerTaskStatus{ + &workerTaskStatus{ + workerID: "w1", + taskStatus: &pb.TaskStatus{ + TaskId: "bar", + Phase: int32(core.PhaseSuccess), + }, + }, + }, + expectedPhase: core.PhaseSuccess, + expectedError: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create fastTaskService + enqueueOwner := func(owner types.NamespacedName) error { + return nil + } + scope := promutils.NewTestScope() + + fastTaskService := newFastTaskService(enqueueOwner, scope) + + // setup taskStatusChannels + if test.taskStatuses != nil { + taskStatusChannel := make(chan *workerTaskStatus, len(test.taskStatuses)) + for _, taskStatus := range test.taskStatuses { + taskStatusChannel <- taskStatus + } + fastTaskService.taskStatusChannels.Store(test.taskID, taskStatusChannel) + } + + // setup response channels for queue workers + queue := 
&Queue{ + workers: make(map[string]*Worker), + } + queues := map[string]*Queue{ + test.queueID: queue, + } + + responseChans := make(map[string]chan *pb.HeartbeatResponse) + if len(test.queueID) > 0 { + if len(test.workerID) > 0 { + responseChan := make(chan *pb.HeartbeatResponse, 1) + worker := &Worker{ + workerID: test.workerID, + responseChan: responseChan, + } + queue.workers[test.workerID] = worker + + responseChans[test.workerID] = responseChan + } + + for _, taskStatus := range test.taskStatuses { + if taskStatus.workerID != test.workerID { + responseChan := make(chan *pb.HeartbeatResponse, 1) + worker := &Worker{ + workerID: taskStatus.workerID, + responseChan: responseChan, + } + queue.workers[taskStatus.workerID] = worker + + responseChans[taskStatus.workerID] = responseChan + } + } + } + fastTaskService.queues = queues + + // offer on queue and validate + phase, _, err := fastTaskService.CheckStatus(ctx, test.taskID, test.queueID, test.workerID) + assert.Equal(t, test.expectedPhase, phase) + assert.Equal(t, test.expectedError, err) + + // validate ACK response + for responseWorkerID, responseChan := range responseChans { + expectAck := false + if test.expectedError == nil && responseWorkerID == test.workerID && + test.expectedPhase != core.PhaseSuccess && test.expectedPhase != core.PhaseRetryableFailure { + + expectAck = true + } + + if expectAck { + // the assigned worker should have received an ACK response + select { + case response := <-responseChan: + assert.NotNil(t, response) + assert.Equal(t, test.taskID, response.GetTaskId()) + assert.Equal(t, pb.HeartbeatResponse_ACK, response.GetOperation()) + default: + assert.Fail(t, "expected response") + } + } else { + // all other workers should have no responses + select { + case <-responseChan: + assert.Fail(t, "unexpected response") + default: + } + } + } + }) + } +} + +func TestCleanup(t *testing.T) { + ctx := context.TODO() + tests := []struct { + name string + taskID string + queueID string + 
workerID string + queues map[string]*Queue + expectedError error + pendingOwnerExists bool + taskStatusChannelExists bool + }{ + { + name: "QueueDoesNotExist", + taskID: "bar", + queueID: "foo", + workerID: "w1", + queues: map[string]*Queue{}, + expectedError: nil, + pendingOwnerExists: false, + taskStatusChannelExists: false, + }, + { + name: "WorkerDoesNostExist", + taskID: "bar", + queueID: "foo", + workerID: "w1", + queues: map[string]*Queue{ + "foo": &Queue{ + workers: map[string]*Worker{ + "w0": &Worker{ + workerID: "w0", + }, + }, + }, + }, + expectedError: nil, + pendingOwnerExists: false, + taskStatusChannelExists: false, + }, + { + name: "WorkerExists", + taskID: "bar", + queueID: "foo", + workerID: "w1", + queues: map[string]*Queue{ + "foo": &Queue{ + workers: map[string]*Worker{ + "w0": &Worker{ + workerID: "w0", + }, + "w1": &Worker{ + workerID: "w1", + }, + }, + }, + }, + expectedError: nil, + pendingOwnerExists: false, + taskStatusChannelExists: false, + }, + { + // worker exists and pendingOwner / taskStatusChannel are cleaned up + name: "WorkerExistsCleanupAll", + taskID: "bar", + queueID: "foo", + workerID: "w1", + queues: map[string]*Queue{ + "foo": &Queue{ + workers: map[string]*Worker{ + "w0": &Worker{ + workerID: "w0", + }, + "w1": &Worker{ + workerID: "w1", + }, + }, + }, + }, + expectedError: nil, + pendingOwnerExists: true, + taskStatusChannelExists: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create fastTaskService + enqueueOwner := func(owner types.NamespacedName) error { + return nil + } + scope := promutils.NewTestScope() + + fastTaskService := newFastTaskService(enqueueOwner, scope) + + // setup response channels for queue workers + responseChans := make(map[string]chan *pb.HeartbeatResponse) + if queue, exists := test.queues[test.queueID]; exists { + for workerID, worker := range queue.workers { + responseChan := make(chan *pb.HeartbeatResponse, 1) + responseChans[workerID] = 
responseChan + worker.responseChan = responseChan + } + } + fastTaskService.queues = test.queues + + // initialize pendingTaskOwners and taskStatusChannels if necessary + if test.pendingOwnerExists { + fastTaskService.addPendingOwner(test.queueID, test.taskID, + types.NamespacedName{Name: "foo"}) + } else { + _, exists := fastTaskService.pendingTaskOwners[test.queueID] + assert.False(t, exists) + } + + if test.taskStatusChannelExists { + fastTaskService.taskStatusChannels.Store(test.taskID, make(chan *pb.TaskStatus, 1)) + } else { + _, exists := fastTaskService.taskStatusChannels.Load(test.taskID) + assert.False(t, exists) + } + + // offer on queue and validate + err := fastTaskService.Cleanup(ctx, test.taskID, test.queueID, test.workerID) + assert.Equal(t, test.expectedError, err) + + // validate DELETE response + for responseWorkerID, responseChan := range responseChans { + expectDelete := false + if responseWorkerID == test.workerID { + if queue, exists := test.queues[test.queueID]; exists { + if _, exists := queue.workers[responseWorkerID]; exists { + expectDelete = true + } + } + } + + if expectDelete { + // the assigned worker should have received an DELETE response + select { + case response := <-responseChan: + assert.NotNil(t, response) + assert.Equal(t, test.taskID, response.GetTaskId()) + assert.Equal(t, pb.HeartbeatResponse_DELETE, response.GetOperation()) + default: + assert.Fail(t, "expected response") + } + } else { + // all other workers should have no responses + select { + case <-responseChan: + assert.Fail(t, "unexpected response") + default: + } + } + } + + // validate pendingTaskOwners and taskStatusChannels are cleaned up + _, exists := fastTaskService.pendingTaskOwners[test.queueID] + assert.False(t, exists) + + _, exists = fastTaskService.taskStatusChannels.Load(test.taskID) + assert.False(t, exists) + }) + } +} + +func TestOfferOnQueue(t *testing.T) { + ctx := context.TODO() + tests := []struct { + name string + queueID string + taskID 
string + namespace string + workflowID string + queues map[string]*Queue + expectedWorkerID string + expectedError error + expectedPendingOwner bool + }{ + { + name: "QueueDoesNotExist", + queueID: "foo", + taskID: "bar", + namespace: "x", + workflowID: "y", + queues: map[string]*Queue{}, + expectedWorkerID: "", + expectedError: nil, + expectedPendingOwner: true, + }, + { + name: "PreferredWorker", + queueID: "foo", + taskID: "bar", + namespace: "x", + workflowID: "y", + queues: map[string]*Queue{ + "foo": &Queue{ + workers: map[string]*Worker{ + "w0": &Worker{ + workerID: "w0", + capacity: &pb.Capacity{ + ExecutionCount: 0, + ExecutionLimit: 1, + }, + }, + "w1": &Worker{ + workerID: "w1", + capacity: &pb.Capacity{ + ExecutionCount: 1, + ExecutionLimit: 1, + }, + }, + }, + }, + }, + expectedWorkerID: "w0", + expectedError: nil, + }, + { + name: "AcceptedWorker", + queueID: "foo", + taskID: "bar", + namespace: "x", + workflowID: "y", + queues: map[string]*Queue{ + "foo": &Queue{ + workers: map[string]*Worker{ + "w0": &Worker{ + workerID: "w0", + capacity: &pb.Capacity{ + ExecutionCount: 1, + ExecutionLimit: 1, + BacklogCount: 1, + BacklogLimit: 1, + }, + }, + "w1": &Worker{ + workerID: "w1", + capacity: &pb.Capacity{ + ExecutionCount: 1, + ExecutionLimit: 1, + BacklogCount: 0, + BacklogLimit: 1, + }, + }, + }, + }, + }, + expectedWorkerID: "w1", + expectedError: nil, + }, + { + name: "NoWorkerAvailable", + queueID: "foo", + taskID: "bar", + namespace: "x", + workflowID: "y", + queues: map[string]*Queue{ + "foo": &Queue{ + workers: map[string]*Worker{ + "w0": &Worker{ + workerID: "w0", + capacity: &pb.Capacity{ + ExecutionCount: 1, + ExecutionLimit: 1, + BacklogCount: 1, + BacklogLimit: 1, + }, + }, + }, + }, + }, + expectedWorkerID: "", + expectedError: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create fastTaskService + enqueueOwner := func(owner types.NamespacedName) error { + return nil + } + scope := 
promutils.NewTestScope() + + fastTaskService := newFastTaskService(enqueueOwner, scope) + + // setup response channels for queue workers + responseChans := make(map[string]chan *pb.HeartbeatResponse) + if queue, exists := test.queues[test.queueID]; exists { + for workerID, worker := range queue.workers { + responseChan := make(chan *pb.HeartbeatResponse, 1) + responseChans[workerID] = responseChan + worker.responseChan = responseChan + } + } + fastTaskService.queues = test.queues + + // pre-execute validation - taskStatusChannel does not exist + _, exists := fastTaskService.taskStatusChannels.Load(test.taskID) + assert.False(t, exists) + + // offer on queue and validate + workerID, err := fastTaskService.OfferOnQueue(ctx, test.queueID, test.taskID, test.namespace, test.workflowID, []string{}) + assert.Equal(t, test.expectedWorkerID, workerID) + assert.Equal(t, test.expectedError, err) + + if len(workerID) > 0 { + // validate ASSIGN response + for responseWorkerID, responseChan := range responseChans { + if responseWorkerID == workerID { + // the assigned worker should have received an ASSIGN response + select { + case response := <-responseChan: + assert.NotNil(t, response) + assert.Equal(t, test.taskID, response.GetTaskId()) + assert.Equal(t, pb.HeartbeatResponse_ASSIGN, response.GetOperation()) + default: + assert.Fail(t, "expected response") + } + } else { + // all other workers should have no responses + select { + case <-responseChan: + assert.Fail(t, "unexpected response") + default: + } + } + } + + // ensure taskStatusChannel now exists + _, exists := fastTaskService.taskStatusChannels.Load(test.taskID) + assert.True(t, exists) + } + + if test.expectedPendingOwner { + pendingOwners, exists := fastTaskService.pendingTaskOwners[test.queueID] + assert.True(t, exists) + + _, exists = pendingOwners[test.taskID] + assert.True(t, exists) + } + }) + } +} + +func TestPendingOwnerManagement(t *testing.T) { + // create fastTaskService + ownerEnqueueCount := 0 + 
enqueueOwner := func(owner types.NamespacedName) error { + ownerEnqueueCount++ + return nil + } + scope := promutils.NewTestScope() + + fastTaskService := newFastTaskService(enqueueOwner, scope) + assert.Equal(t, 0, len(fastTaskService.queues)) + + // add pending owners + additions := []struct { + queueID string + taskID string + ownerIDName string + expectedQueueOwnersCount int + totalOwnerCount int + }{ + { + // add owner to new queue + queueID: "foo", + taskID: "a", + ownerIDName: "0", + expectedQueueOwnersCount: 1, + totalOwnerCount: 1, + }, + { + // add owner to existing queue + queueID: "foo", + taskID: "b", + ownerIDName: "1", + expectedQueueOwnersCount: 1, + totalOwnerCount: 2, + }, + { + // add owner to another new queue + queueID: "bar", + taskID: "c", + ownerIDName: "2", + expectedQueueOwnersCount: 2, + totalOwnerCount: 3, + }, + } + + for _, addition := range additions { + fastTaskService.addPendingOwner(addition.queueID, addition.taskID, types.NamespacedName{Name: addition.ownerIDName}) + + assert.Equal(t, addition.expectedQueueOwnersCount, len(fastTaskService.pendingTaskOwners)) + totalOwnerCount := 0 + for _, queueOwners := range fastTaskService.pendingTaskOwners { + totalOwnerCount += len(queueOwners) + } + assert.Equal(t, addition.totalOwnerCount, totalOwnerCount) + } + + // validate overflow management on addPendingOwner + overflowTestQueueID := "baz" + for i := 0; i < maxPendingOwnersPerQueue; i++ { + fastTaskService.addPendingOwner(overflowTestQueueID, fmt.Sprintf("%d", i), types.NamespacedName{Name: fmt.Sprintf("%d", i)}) + } + assert.Equal(t, maxPendingOwnersPerQueue, len(fastTaskService.pendingTaskOwners[overflowTestQueueID])) + + fastTaskService.addPendingOwner(overflowTestQueueID, "overflow", types.NamespacedName{Name: "overflow"}) + assert.Equal(t, maxPendingOwnersPerQueue, len(fastTaskService.pendingTaskOwners[overflowTestQueueID])) + + // validate enqueuePendingOwners + assert.Equal(t, 0, ownerEnqueueCount) + + 
fastTaskService.enqueuePendingOwners(overflowTestQueueID) + assert.Equal(t, maxPendingOwnersPerQueue, ownerEnqueueCount) + assert.Equal(t, 0, len(fastTaskService.pendingTaskOwners[overflowTestQueueID])) + + fastTaskService.enqueuePendingOwners(overflowTestQueueID) // call a second time to validate on empty queue + assert.Equal(t, maxPendingOwnersPerQueue, ownerEnqueueCount) + + // remove workers + removals := []struct { + queueID string + taskID string + expectedQueueOwnersCount int + totalOwnerCount int + }{ + { + // remove owner from non-existent queue + queueID: "baz", + taskID: "d", + expectedQueueOwnersCount: 2, + totalOwnerCount: 3, + }, + { + // remove worker from existing queue + queueID: "foo", + taskID: "a", + expectedQueueOwnersCount: 2, + totalOwnerCount: 2, + }, + { + // remove last worker from queue + queueID: "foo", + taskID: "b", + expectedQueueOwnersCount: 1, + totalOwnerCount: 1, + }, + } + + for _, removal := range removals { + fastTaskService.removePendingOwner(removal.queueID, removal.taskID) + + assert.Equal(t, removal.expectedQueueOwnersCount, len(fastTaskService.pendingTaskOwners)) + totalOwnerCount := 0 + for _, queueOwners := range fastTaskService.pendingTaskOwners { + totalOwnerCount += len(queueOwners) + } + assert.Equal(t, removal.totalOwnerCount, totalOwnerCount) + } +} + +func TestQueueWorkerManagement(t *testing.T) { + // create fastTaskService + enqueueOwner := func(owner types.NamespacedName) error { + return nil + } + scope := promutils.NewTestScope() + + fastTaskService := newFastTaskService(enqueueOwner, scope) + assert.Equal(t, 0, len(fastTaskService.queues)) + + // add workers + additions := []struct { + queueID string + workerID string + expectedQueueCount int + totalWorkerCount int + }{ + { + // add worker to new queue + queueID: "foo", + workerID: "a", + expectedQueueCount: 1, + totalWorkerCount: 1, + }, + { + // add worker to existing queue + queueID: "foo", + workerID: "b", + expectedQueueCount: 1, + totalWorkerCount: 2, 
+ }, + { + // add worker to another new queue + queueID: "bar", + workerID: "c", + expectedQueueCount: 2, + totalWorkerCount: 3, + }, + } + + for _, addition := range additions { + worker := &Worker{ + workerID: addition.workerID, + } + + queue := fastTaskService.addWorkerToQueue(addition.queueID, worker) + assert.NotNil(t, queue) + + assert.Equal(t, addition.expectedQueueCount, len(fastTaskService.queues)) + totalWorkers := 0 + for _, q := range fastTaskService.queues { + totalWorkers += len(q.workers) + } + assert.Equal(t, addition.totalWorkerCount, totalWorkers) + } + + // remove workers + removals := []struct { + queueID string + workerID string + expectedQueueCount int + totalWorkerCount int + }{ + { + // remove worker from non-existent queue + queueID: "baz", + workerID: "d", + expectedQueueCount: 2, + totalWorkerCount: 3, + }, + { + // remove worker from existing queue + queueID: "foo", + workerID: "a", + expectedQueueCount: 2, + totalWorkerCount: 2, + }, + { + // remove last worker from queue + queueID: "foo", + workerID: "b", + expectedQueueCount: 1, + totalWorkerCount: 1, + }, + } + + for _, removal := range removals { + fastTaskService.removeWorkerFromQueue(removal.queueID, removal.workerID) + + assert.Equal(t, removal.expectedQueueCount, len(fastTaskService.queues)) + totalWorkers := 0 + for _, q := range fastTaskService.queues { + totalWorkers += len(q.workers) + } + assert.Equal(t, removal.totalWorkerCount, totalWorkers) + } +} diff --git a/flyteadmin/auth/handlers.go b/flyteadmin/auth/handlers.go index cb6ca9030dc..6323eb3aa6e 100644 --- a/flyteadmin/auth/handlers.go +++ b/flyteadmin/auth/handlers.go @@ -262,7 +262,6 @@ func GetAuthenticationCustomMetadataInterceptor(authCtx interfaces.Authenticatio if ok { existingHeader := md.Get(authCtx.Options().GrpcAuthorizationHeader) if len(existingHeader) > 0 { - logger.Debugf(ctx, "Found existing metadata %s", existingHeader[0]) newAuthorizationMetadata := metadata.Pairs(DefaultAuthorizationHeader, 
existingHeader[0]) joinedMetadata := metadata.Join(md, newAuthorizationMetadata) newCtx := metadata.NewIncomingContext(ctx, joinedMetadata) diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 91957f58d5f..a7b4eb41a89 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -54,6 +54,7 @@ require ( github.com/wI2L/jsondiff v0.5.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 go.opentelemetry.io/otel v1.24.0 + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 golang.org/x/oauth2 v0.16.0 golang.org/x/sync v0.7.0 golang.org/x/time v0.5.0 @@ -202,7 +203,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.25.0 // indirect golang.org/x/crypto v0.22.0 // indirect - golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/net v0.24.0 // indirect golang.org/x/sys v0.19.0 // indirect golang.org/x/term v0.19.0 // indirect diff --git a/flyteadmin/scheduler/executor/executor_impl.go b/flyteadmin/scheduler/executor/executor_impl.go index 30ab7f06770..8aac10af733 100644 --- a/flyteadmin/scheduler/executor/executor_impl.go +++ b/flyteadmin/scheduler/executor/executor_impl.go @@ -55,6 +55,7 @@ func (w *executor) Execute(ctx context.Context, scheduledTime time.Time, s model // Making the identifier deterministic using the hash of the identifier and scheduled time executionIdentifier, err := identifier.GetExecutionIdentifier(ctx, core.Identifier{ + Org: s.Org, Project: s.Project, Domain: s.Domain, Name: s.Name, @@ -67,12 +68,14 @@ func (w *executor) Execute(ctx context.Context, scheduledTime time.Time, s model } executionRequest := &admin.ExecutionCreateRequest{ + Org: s.Org, Project: s.Project, Domain: s.Domain, Name: "f" + strings.ReplaceAll(executionIdentifier.String(), "-", "")[:19], Spec: &admin.ExecutionSpec{ LaunchPlan: &core.Identifier{ ResourceType: core.ResourceType_LAUNCH_PLAN, + Org: s.Org, Project: s.Project, Domain: s.Domain, Name: s.Name, diff --git 
a/flyteadmin/scheduler/identifier/identifier.go b/flyteadmin/scheduler/identifier/identifier.go index 420be2af4ee..46eda90895c 100644 --- a/flyteadmin/scheduler/identifier/identifier.go +++ b/flyteadmin/scheduler/identifier/identifier.go @@ -6,9 +6,11 @@ import ( "fmt" "hash/fnv" "strconv" + "strings" "time" "github.com/google/uuid" + "github.com/samber/lo" "github.com/flyteorg/flyte/flyteadmin/scheduler/repositories/models" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -18,8 +20,7 @@ import ( // Utility functions used by the flyte native scheduler const ( - scheduleNameInputsFormat = "%s:%s:%s:%s" - executionIDInputsFormat = scheduleNameInputsFormat + ":%d" + executionIDInputsFormat = "%s:%d" ) // GetScheduleName generate the schedule name to be used as unique identification string within the scheduler @@ -40,11 +41,19 @@ func GetExecutionIdentifier(ctx context.Context, identifier core.Identifier, sch return uuid.FromBytes(b) } +func getIdentifierString(identifier *core.Identifier) string { + fields := lo.Filter([]string{ + identifier.Org, identifier.Project, identifier.Domain, identifier.Name, identifier.Version, + }, func(item string, index int) bool { + return len(item) > 0 + }) + return strings.Join(fields, ":") +} + // hashIdentifier returns the hash of the identifier func hashIdentifier(ctx context.Context, identifier core.Identifier) uint64 { h := fnv.New64() - _, err := h.Write([]byte(fmt.Sprintf(scheduleNameInputsFormat, - identifier.Project, identifier.Domain, identifier.Name, identifier.Version))) + _, err := h.Write([]byte(getIdentifierString(&identifier))) if err != nil { // This shouldn't occur. 
logger.Errorf(ctx, @@ -55,11 +64,15 @@ func hashIdentifier(ctx context.Context, identifier core.Identifier) uint64 { return h.Sum64() } +func getExecutionIDInputsFormat(identifier *core.Identifier, scheduleTime time.Time) []byte { + return []byte(fmt.Sprintf(executionIDInputsFormat, getIdentifierString(identifier), scheduleTime.Unix())) +} + // hashScheduledTimeStamp return the hash of the identifier and the scheduledTime func hashScheduledTimeStamp(ctx context.Context, identifier core.Identifier, scheduledTime time.Time) uint64 { h := fnv.New64() - _, err := h.Write([]byte(fmt.Sprintf(executionIDInputsFormat, - identifier.Project, identifier.Domain, identifier.Name, identifier.Version, scheduledTime.Unix()))) + + _, err := h.Write(getExecutionIDInputsFormat(&identifier, scheduledTime)) if err != nil { // This shouldn't occur. logger.Errorf(ctx, diff --git a/flyteadmin/scheduler/identifier/identifier_test.go b/flyteadmin/scheduler/identifier/identifier_test.go new file mode 100644 index 00000000000..8898fc954c3 --- /dev/null +++ b/flyteadmin/scheduler/identifier/identifier_test.go @@ -0,0 +1,35 @@ +package identifier + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +func TestGetIdentifierString(t *testing.T) { + t.Run("with org", func(t *testing.T) { + identifier := &core.Identifier{ + Org: "org", + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + } + expected := "org:project:domain:name:version" + actual := getIdentifierString(identifier) + assert.Equal(t, expected, actual) + }) + t.Run("without org", func(t *testing.T) { + identifier := &core.Identifier{ + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + } + expected := "project:domain:name:version" + actual := getIdentifierString(identifier) + assert.Equal(t, expected, actual) + }) +} diff --git a/flyteadmin/tests/attributes_test.go 
b/flyteadmin/tests/attributes_test.go index bcc73c7b9ce..42ee3b00687 100644 --- a/flyteadmin/tests/attributes_test.go +++ b/flyteadmin/tests/attributes_test.go @@ -8,16 +8,16 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models" - "github.com/flyteorg/flyte/flyteadmin/pkg/runtime" - runtimeIfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" - "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" "gorm.io/gorm" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyte/flyteadmin/pkg/runtime" + runtimeIfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyte/flytestdlib/logger" ) var matchingTaskResourceAttributes = &admin.MatchingAttributes{ diff --git a/flyteadmin/tests/configuration_test.go b/flyteadmin/tests/configuration_test.go index 4a181b53d31..50d8069835b 100644 --- a/flyteadmin/tests/configuration_test.go +++ b/flyteadmin/tests/configuration_test.go @@ -8,10 +8,11 @@ import ( "sync" "testing" - "github.com/flyteorg/flyte/flyteadmin/pkg/repositories" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" ) var taskResourceAttributes = &admin.TaskResourceAttributes{ diff --git a/flyteadmin/tests/migrations_test.go b/flyteadmin/tests/migrations_test.go index c4bf1a55ed4..d835548fc3f 100644 --- a/flyteadmin/tests/migrations_test.go +++ b/flyteadmin/tests/migrations_test.go @@ -9,10 +9,11 @@ import ( "sort" "testing" - "github.com/flyteorg/flyte/flyteadmin/pkg/repositories" 
- "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" ) // This file is used to test the migration from resource to configuration and configuration to resource. diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index 8a0024b3192..bcd39142f86 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -20,7 +20,9 @@ const ProxyAuthorizationHeader = "proxy-authorization" // MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. // Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. -func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { +func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, + perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { + authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture) if err != nil { return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) @@ -42,11 +44,17 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) if err != nil { - return err + return fmt.Errorf("failed to get token source. Error: %w", err) + } + + _, err = tokenSource.Token() + if err != nil { + return fmt.Errorf("failed to issue token. 
Error: %w", err) } wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey) perRPCCredentials.Store(wrappedTokenSource) + return nil } @@ -134,6 +142,15 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture) + // If there is already a token in the cache (e.g. key-ring), we should use it immediately... + t, _ := tokenCache.GetToken() + if t != nil { + err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) + if err != nil { + return fmt.Errorf("failed to materialize credentials. Error: %v", err) + } + } + err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err) @@ -141,10 +158,32 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut if st, ok := status.FromError(err); ok { // If the error we receive from executing the request expects if shouldAttemptToAuthenticate(st.Code()) { - logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) - newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) - if newErr != nil { - return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr) + err = func() error { + if !tokenCache.TryLock() { + tokenCache.CondWait() + return nil + } + defer tokenCache.Unlock() + _, err := tokenCache.PurgeIfEquals(t) + if err != nil && !errors.Is(err, cache.ErrNotFound) { + logger.Errorf(ctx, "Failed to purge cache. Error [%v]", err) + return fmt.Errorf("failed to purge cache. 
Error: %w", err) + } + + logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) + newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) + if newErr != nil { + errString := fmt.Sprintf("authentication error! Original Error: %v, Auth Error: %v", err, newErr) + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + + tokenCache.CondBroadcast() + return nil + }() + + if err != nil { + return err } return invoker(ctx, method, req, reply, cc, opts...) @@ -167,6 +206,7 @@ func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredenti } return invoker(ctx, method, req, reply, cc, opts...) } + return err } } diff --git a/flyteidl/clients/go/admin/auth_interceptor_test.go b/flyteidl/clients/go/admin/auth_interceptor_test.go index ce99c992708..10c96625b78 100644 --- a/flyteidl/clients/go/admin/auth_interceptor_test.go +++ b/flyteidl/clients/go/admin/auth_interceptor_test.go @@ -2,13 +2,14 @@ package admin import ( "context" + "encoding/json" "errors" "fmt" "io" "net" "net/http" - "net/http/httptest" "net/url" + "os" "strings" "sync" "testing" @@ -31,10 +32,11 @@ import ( // authMetadataServer is a fake AuthMetadataServer that takes in an AuthMetadataServer implementation (usually one // initialized through mockery) and starts a local server that uses it to respond to grpc requests. 
type authMetadataServer struct { - s *httptest.Server t testing.TB - port int + grpcPort int + httpPort int grpcServer *grpc.Server + httpServer *http.Server netListener net.Listener impl service.AuthMetadataServiceServer lck *sync.RWMutex @@ -70,27 +72,49 @@ func (s authMetadataServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } +func (s *authMetadataServer) tokenHandler(w http.ResponseWriter, r *http.Request) { + tokenJSON := []byte(`{"access_token": "exampletoken", "token_type": "bearer"}`) + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(tokenJSON) + assert.NoError(s.t, err) +} + func (s *authMetadataServer) Start(_ context.Context) error { s.lck.Lock() defer s.lck.Unlock() /***** Set up the server serving channelz service. *****/ - lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.port)) + + lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.grpcPort)) if err != nil { - return fmt.Errorf("failed to listen on port [%v]: %w", s.port, err) + return fmt.Errorf("failed to listen on port [%v]: %w", s.grpcPort, err) } + s.netListener = lis grpcS := grpc.NewServer() service.RegisterAuthMetadataServiceServer(grpcS, s) go func() { + defer grpcS.Stop() _ = grpcS.Serve(lis) - //assert.NoError(s.t, err) }() - s.grpcServer = grpcS - s.netListener = lis + mux := http.NewServeMux() + // Attach the handler to the /oauth2/token path + mux.HandleFunc("/oauth2/token", s.tokenHandler) + + //nolint:gosec + s.httpServer = &http.Server{ + Addr: fmt.Sprintf("localhost:%d", s.httpPort), + Handler: mux, + } - s.s = httptest.NewServer(s) + go func() { + defer s.httpServer.Close() + err := s.httpServer.ListenAndServe() + if err != nil { + panic(err) + } + }() return nil } @@ -98,25 +122,30 @@ func (s *authMetadataServer) Start(_ context.Context) error { func (s *authMetadataServer) Close() { s.lck.RLock() defer s.lck.RUnlock() - s.grpcServer.Stop() - s.s.Close() } -func newAuthMetadataServer(t testing.TB, port int, 
impl service.AuthMetadataServiceServer) *authMetadataServer { +func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl service.AuthMetadataServiceServer) *authMetadataServer { return &authMetadataServer{ - port: port, - t: t, - impl: impl, - lck: &sync.RWMutex{}, + grpcPort: grpcPort, + httpPort: httpPort, + t: t, + impl: impl, + lck: &sync.RWMutex{}, } } func Test_newAuthInterceptor(t *testing.T) { + plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json") + var tokenData oauth2.Token + err := json.Unmarshal(plan, &tokenData) + assert.NoError(t, err) t.Run("Other Error", func(t *testing.T) { f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p) + mockTokenCache := &mocks.TokenCache{} + mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) + interceptor := NewAuthInterceptor(&Config{}, mockTokenCache, f, p) otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Canceled, "").Err() } @@ -129,35 +158,43 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - port := rand.IntnRange(10000, 60000) + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ - AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), - TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), - JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort), }, nil) + m.OnGetPublicClientConfigMatch(mock.Anything, 
mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnTryLockMatch().Return(true) + c.OnSaveTokenMatch(mock.Anything).Return(nil) + c.On("CondBroadcast").Return() + c.On("Unlock").Return() + c.OnPurgeIfEqualsMatch(mock.Anything).Return(true, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Unauthenticated, "").Err() } - err = interceptor(ctx, "POST", nil, nil, nil, unauthenticated) assert.Error(t, err) assert.Truef(t, f.IsInitialized(), "PerRPCCredentialFuture should be initialized") @@ -169,24 +206,26 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - port := rand.IntnRange(10000, 60000) + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - + c := &mocks.TokenCache{} + 
c.OnGetTokenMatch().Return(nil, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil } @@ -201,33 +240,39 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - port := rand.IntnRange(10000, 60000) + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ - AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), - TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), - JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort), }, nil) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnTryLockMatch().Return(true) + c.OnSaveTokenMatch(mock.Anything).Return(nil) + c.OnPurgeIfEqualsMatch(mock.Anything).Return(true, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: 
config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Aborted, "").Err() } @@ -239,17 +284,21 @@ func Test_newAuthInterceptor(t *testing.T) { } func TestMaterializeCredentials(t *testing.T) { - port := rand.IntnRange(10000, 60000) t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) { + httpPort := rand.IntnRange(10000, 60000) + grpcPort := rand.IntnRange(10000, 60000) + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnSaveTokenMatch(mock.Anything).Return(nil) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config")) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() @@ -259,24 +308,29 @@ func TestMaterializeCredentials(t *testing.T) { Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), + TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), Scopes: []string{"all"}, Audience: "http://localhost:30081", AuthorizationHeader: "authorization", - }, &mocks.TokenCache{}, f, p) + }, c, f, p) assert.NoError(t, err) }) t.Run("Failed to fetch client metadata", func(t *testing.T) { + httpPort := rand.IntnRange(10000, 60000) + 
grpcPort := rand.IntnRange(10000, 60000) + c := &mocks.TokenCache{} + c.OnGetTokenMatch().Return(nil, nil) + c.OnSaveTokenMatch(mock.Anything).Return(nil) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) failedPublicClientConfigLookup := errors.New("expected err") m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup) - s := newAuthMetadataServer(t, port, m) + s := newAuthMetadataServer(t, grpcPort, httpPort, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() @@ -286,9 +340,9 @@ func TestMaterializeCredentials(t *testing.T) { Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), + TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort), Scopes: []string{"all"}, - }, &mocks.TokenCache{}, f, p) + }, c, f, p) assert.EqualError(t, err, "failed to fetch client metadata. 
Error: rpc error: code = Unknown desc = expected err") }) } diff --git a/flyteidl/clients/go/admin/cache/mocks/token_cache.go b/flyteidl/clients/go/admin/cache/mocks/token_cache.go index 0af58b381f8..88a1bef81cd 100644 --- a/flyteidl/clients/go/admin/cache/mocks/token_cache.go +++ b/flyteidl/clients/go/admin/cache/mocks/token_cache.go @@ -12,6 +12,16 @@ type TokenCache struct { mock.Mock } +// CondBroadcast provides a mock function with given fields: +func (_m *TokenCache) CondBroadcast() { + _m.Called() +} + +// CondWait provides a mock function with given fields: +func (_m *TokenCache) CondWait() { + _m.Called() +} + type TokenCache_GetToken struct { *mock.Call } @@ -53,6 +63,50 @@ func (_m *TokenCache) GetToken() (*oauth2.Token, error) { return r0, r1 } +// Lock provides a mock function with given fields: +func (_m *TokenCache) Lock() { + _m.Called() +} + +type TokenCache_PurgeIfEquals struct { + *mock.Call +} + +func (_m TokenCache_PurgeIfEquals) Return(_a0 bool, _a1 error) *TokenCache_PurgeIfEquals { + return &TokenCache_PurgeIfEquals{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TokenCache) OnPurgeIfEquals(t *oauth2.Token) *TokenCache_PurgeIfEquals { + c_call := _m.On("PurgeIfEquals", t) + return &TokenCache_PurgeIfEquals{Call: c_call} +} + +func (_m *TokenCache) OnPurgeIfEqualsMatch(matchers ...interface{}) *TokenCache_PurgeIfEquals { + c_call := _m.On("PurgeIfEquals", matchers...) 
+ return &TokenCache_PurgeIfEquals{Call: c_call} +} + +// PurgeIfEquals provides a mock function with given fields: t +func (_m *TokenCache) PurgeIfEquals(t *oauth2.Token) (bool, error) { + ret := _m.Called(t) + + var r0 bool + if rf, ok := ret.Get(0).(func(*oauth2.Token) bool); ok { + r0 = rf(t) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(*oauth2.Token) error); ok { + r1 = rf(t) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type TokenCache_SaveToken struct { *mock.Call } @@ -84,3 +138,40 @@ func (_m *TokenCache) SaveToken(token *oauth2.Token) error { return r0 } + +type TokenCache_TryLock struct { + *mock.Call +} + +func (_m TokenCache_TryLock) Return(_a0 bool) *TokenCache_TryLock { + return &TokenCache_TryLock{Call: _m.Call.Return(_a0)} +} + +func (_m *TokenCache) OnTryLock() *TokenCache_TryLock { + c_call := _m.On("TryLock") + return &TokenCache_TryLock{Call: c_call} +} + +func (_m *TokenCache) OnTryLockMatch(matchers ...interface{}) *TokenCache_TryLock { + c_call := _m.On("TryLock", matchers...) 
+ return &TokenCache_TryLock{Call: c_call} +} + +// TryLock provides a mock function with given fields: +func (_m *TokenCache) TryLock() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Unlock provides a mock function with given fields: +func (_m *TokenCache) Unlock() { + _m.Called() +} diff --git a/flyteidl/clients/go/admin/cache/token_cache.go b/flyteidl/clients/go/admin/cache/token_cache.go index e4e2b7e17f2..f946c8bdea6 100644 --- a/flyteidl/clients/go/admin/cache/token_cache.go +++ b/flyteidl/clients/go/admin/cache/token_cache.go @@ -1,14 +1,39 @@ package cache -import "golang.org/x/oauth2" +import ( + "fmt" + "golang.org/x/oauth2" +) //go:generate mockery -all -case=underscore +var ( + ErrNotFound = fmt.Errorf("secret not found in keyring") +) + // TokenCache defines the interface needed to cache and retrieve oauth tokens. type TokenCache interface { // SaveToken saves the token securely to cache. SaveToken(token *oauth2.Token) error - // Retrieves the token from the cache. + // GetToken retrieves the token from the cache. GetToken() (*oauth2.Token, error) + + // PurgeIfEquals purges the token from the cache if it matches t. + PurgeIfEquals(t *oauth2.Token) (bool, error) + + // Lock the cache. + Lock() + + // TryLock tries to lock the cache. + TryLock() bool + + // Unlock the cache. + Unlock() + + // CondWait waits for the condition to be true. + CondWait() + + // CondBroadcast signals the condition.
+ CondBroadcast() } diff --git a/flyteidl/clients/go/admin/cache/token_cache_inmemory.go b/flyteidl/clients/go/admin/cache/token_cache_inmemory.go index 9c6223fc06b..477b28ace44 100644 --- a/flyteidl/clients/go/admin/cache/token_cache_inmemory.go +++ b/flyteidl/clients/go/admin/cache/token_cache_inmemory.go @@ -2,23 +2,93 @@ package cache import ( "fmt" + "sync" + "sync/atomic" "golang.org/x/oauth2" ) type TokenCacheInMemoryProvider struct { - token *oauth2.Token + token atomic.Value + mu *sync.Mutex + condLocker *NoopLocker + cond *sync.Cond } func (t *TokenCacheInMemoryProvider) SaveToken(token *oauth2.Token) error { - t.token = token + t.token.Store(token) return nil } -func (t TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) { - if t.token == nil { +func (t *TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) { + tkn := t.token.Load() + if tkn == nil { return nil, fmt.Errorf("cannot find token in cache") } + return tkn.(*oauth2.Token), nil +} + +func (t *TokenCacheInMemoryProvider) PurgeIfEquals(existing *oauth2.Token) (bool, error) { + // Add an empty token since we can't mark it nil using Compare and swap + return t.token.CompareAndSwap(existing, &oauth2.Token{}), nil +} + +func (t *TokenCacheInMemoryProvider) Lock() { + t.mu.Lock() +} + +func (t *TokenCacheInMemoryProvider) TryLock() bool { + return t.mu.TryLock() +} + +func (t *TokenCacheInMemoryProvider) Unlock() { + t.mu.Unlock() +} + +// CondWait adds the current go routine to the condition waitlist and waits for another go routine to notify using CondBroadcast +// The current usage is that one who was able to acquire the lock using TryLock is the one who gets a valid token and notifies all the waitlist requesters so that they can use the new valid token. +// It also locks the Locker in the condition variable as the semantics of Wait is that it unlocks the Locker after adding +// the consumer to the waitlist and before blocking on notification. 
+// We use the condLocker which is a no-op locker to get added to waitlist for notifications. +// The underlying notificationList doesn't need to be guarded as its implementation is atomic and is thread safe +// Refer https://go.dev/src/runtime/sema.go +// Following is the function and its comments +// notifyListAdd adds the caller to a notify list such that it can receive +// notifications. The caller must eventually call notifyListWait to wait for +// such a notification, passing the returned ticket number. +// +// func notifyListAdd(l *notifyList) uint32 { +// // This may be called concurrently, for example, when called from +// // sync.Cond.Wait while holding a RWMutex in read mode. +// return l.wait.Add(1) - 1 +// } +func (t *TokenCacheInMemoryProvider) CondWait() { + t.condLocker.Lock() + t.cond.Wait() + t.condLocker.Unlock() +} + +// NoopLocker has empty implementation of Locker interface +type NoopLocker struct { +} + +func (*NoopLocker) Lock() { + +} +func (*NoopLocker) Unlock() { +} - return t.token, nil +// CondBroadcast signals the condition.
+func (t *TokenCacheInMemoryProvider) CondBroadcast() { + t.cond.Broadcast() +} + +func NewTokenCacheInMemoryProvider() *TokenCacheInMemoryProvider { + condLocker := &NoopLocker{} + return &TokenCacheInMemoryProvider{ + mu: &sync.Mutex{}, + token: atomic.Value{}, + condLocker: condLocker, + cond: sync.NewCond(condLocker), + } } diff --git a/flyteidl/clients/go/admin/client_builder.go b/flyteidl/clients/go/admin/client_builder.go index 25b263ecf1c..0d1341bf7b8 100644 --- a/flyteidl/clients/go/admin/client_builder.go +++ b/flyteidl/clients/go/admin/client_builder.go @@ -40,7 +40,7 @@ func (cb *ClientsetBuilder) WithDialOptions(opts ...grpc.DialOption) *ClientsetB // Build the clientset using the current state of the ClientsetBuilder func (cb *ClientsetBuilder) Build(ctx context.Context) (*Clientset, error) { if cb.tokenCache == nil { - cb.tokenCache = &cache.TokenCacheInMemoryProvider{} + cb.tokenCache = cache.NewTokenCacheInMemoryProvider() } if cb.config == nil { diff --git a/flyteidl/clients/go/admin/client_builder_test.go b/flyteidl/clients/go/admin/client_builder_test.go index c871bcb326d..89bcc385505 100644 --- a/flyteidl/clients/go/admin/client_builder_test.go +++ b/flyteidl/clients/go/admin/client_builder_test.go @@ -17,9 +17,9 @@ func TestClientsetBuilder_Build(t *testing.T) { cb := NewClientsetBuilder().WithConfig(&Config{ UseInsecureConnection: true, Endpoint: config.URL{URL: *u}, - }).WithTokenCache(&cache.TokenCacheInMemoryProvider{}) + }).WithTokenCache(cache.NewTokenCacheInMemoryProvider()) ctx := context.Background() _, err := cb.Build(ctx) assert.NoError(t, err) - assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(&cache.TokenCacheInMemoryProvider{})) + assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(cache.NewTokenCacheInMemoryProvider())) } diff --git a/flyteidl/clients/go/admin/client_test.go b/flyteidl/clients/go/admin/client_test.go index eb19b76f471..042a8266921 100644 --- a/flyteidl/clients/go/admin/client_test.go +++ 
b/flyteidl/clients/go/admin/client_test.go @@ -255,6 +255,8 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) { mockAuthClient := new(mocks.AuthMetadataServiceClient) mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil) + mockTokenCache.On("Lock").Return() + mockTokenCache.On("Unlock").Return() mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil) mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) tokenSourceProvider, err := NewTokenSourceProvider(ctx, adminServiceConfig, mockTokenCache, mockAuthClient) @@ -288,7 +290,7 @@ func Test_getPkceAuthTokenSource(t *testing.T) { assert.NoError(t, err) // populate the cache - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() assert.NoError(t, tokenCache.SaveToken(&tokenData)) baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{ diff --git a/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go b/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go index 5c1dc5f2bdf..9f20fb3ef59 100644 --- a/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go @@ -23,7 +23,7 @@ import ( func TestFetchFromAuthFlow(t *testing.T) { ctx := context.Background() t.Run("fetch from auth flow", func(t *testing.T) { - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ @@ -97,7 +97,7 @@ func TestFetchFromAuthFlow(t *testing.T) { })) defer fakeServer.Close() - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() orchestrator, err := 
NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ diff --git a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go index dc1c80f63ac..ca1973ea669 100644 --- a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go +++ b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go @@ -16,7 +16,7 @@ import ( func TestFetchFromAuthFlow(t *testing.T) { ctx := context.Background() t.Run("fetch from auth flow", func(t *testing.T) { - tokenCache := &cache.TokenCacheInMemoryProvider{} + tokenCache := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ diff --git a/flyteidl/clients/go/admin/token_source_provider.go b/flyteidl/clients/go/admin/token_source_provider.go index c911f58d35c..d55f64dc46b 100644 --- a/flyteidl/clients/go/admin/token_source_provider.go +++ b/flyteidl/clients/go/admin/token_source_provider.go @@ -191,7 +191,7 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s } secret = strings.TrimSpace(secret) if tokenCache == nil { - tokenCache = &cache.TokenCacheInMemoryProvider{} + tokenCache = cache.NewTokenCacheInMemoryProvider() } return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ @@ -249,13 +249,15 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { return nil }) if err != nil { - return nil, err + logger.Warnf(s.ctx, "failed to get token: %v", err) + return nil, fmt.Errorf("failed to get token: %w", err) } + logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry) err = s.tokenCache.SaveToken(token) if err != nil { - logger.Warnf(s.ctx, "failed to cache token: %w", err) + logger.Warnf(s.ctx, "failed to cache token: %v", err) } return token, nil diff --git 
a/flyteidl/clients/go/admin/token_source_provider_test.go b/flyteidl/clients/go/admin/token_source_provider_test.go index 63fc1aa56ec..43d0fdd9280 100644 --- a/flyteidl/clients/go/admin/token_source_provider_test.go +++ b/flyteidl/clients/go/admin/token_source_provider_test.go @@ -127,7 +127,9 @@ func TestCustomTokenSource_Token(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { tokenCache := &tokenCacheMocks.TokenCache{} - tokenCache.OnGetToken().Return(test.token, nil).Once() + tokenCache.OnGetToken().Return(test.token, nil).Maybe() + tokenCache.On("Lock").Return().Maybe() + tokenCache.On("Unlock").Return().Maybe() provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") assert.NoError(t, err) source, err := provider.GetTokenSource(ctx) diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go index c4891b13ae6..a1227ec585f 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go @@ -3,7 +3,6 @@ package tokenorchestrator import ( "context" "fmt" - "time" "golang.org/x/oauth2" @@ -53,22 +52,29 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex return nil, err } - if !token.Valid() { - return nil, fmt.Errorf("token from cache is invalid") + if token.Valid() { + return token, nil } - // If token doesn't need to be refreshed, return it. 
- if time.Now().Before(token.Expiry.Add(-tokenRefreshGracePeriod.Duration)) { - logger.Infof(ctx, "found the token in the cache") + t.TokenCache.Lock() + defer t.TokenCache.Unlock() + + token, err = t.TokenCache.GetToken() + if err != nil { + return nil, err + } + + if token.Valid() { return token, nil } - token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration) token, err = t.RefreshToken(ctx, token) if err != nil { return nil, fmt.Errorf("failed to refresh token using cached token. Error: %w", err) } + token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration) + if !token.Valid() { return nil, fmt.Errorf("refreshed token is invalid") } diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go index ed4afa0ff05..0a1a9f4985a 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go @@ -26,7 +26,7 @@ func TestRefreshTheToken(t *testing.T) { ClientID: "dummyClient", }, } - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator := BaseTokenOrchestrator{ ClientConfig: clientConf, TokenCache: tokenCacheProvider, @@ -58,7 +58,7 @@ func TestFetchFromCache(t *testing.T) { mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) t.Run("no token in cache", func(t *testing.T) { - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) @@ -69,7 +69,7 @@ func TestFetchFromCache(t *testing.T) { }) t.Run("token in cache", func(t *testing.T) { - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() 
orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) fileData, _ := os.ReadFile("testdata/token.json") @@ -86,7 +86,7 @@ func TestFetchFromCache(t *testing.T) { }) t.Run("expired token in cache", func(t *testing.T) { - tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} + tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) fileData, _ := os.ReadFile("testdata/token.json") diff --git a/flyteidl/clients/go/assets/admin.swagger.json b/flyteidl/clients/go/assets/admin.swagger.json index 42c89e77db9..c6f90a9bb00 100644 --- a/flyteidl/clients/go/assets/admin.swagger.json +++ b/flyteidl/clients/go/assets/admin.swagger.json @@ -12393,6 +12393,19 @@ }, "description": "This configuration allows executing raw containers in Flyte using the Flyte CoPilot system.\nFlyte CoPilot, eliminates the needs of flytekit or sdk inside the container. Any inputs required by the users container are side-loaded in the input_path\nAny outputs generated by the user container - within output_path are automatically uploaded." }, + "coreEnumType": { + "type": "object", + "properties": { + "values": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Predefined set of enum values." + } + }, + "description": "Enables declaring enum types, with predefined string values\nFor len(values) \u003e 0, the first value in the ordered list is regarded as the default value. If you wish\nTo provide no defaults, make the first value as undefined." + }, "coreError": { "type": "object", "properties": { @@ -12797,7 +12810,7 @@ "description": "A blob might have specialized implementation details depending on associated metadata." }, "enum_type": { - "$ref": "#/definitions/flyteidlcoreEnumType", + "$ref": "#/definitions/coreEnumType", "description": "Defines an enum with pre-defined string values." 
}, "structured_dataset_type": { @@ -14321,19 +14334,6 @@ }, "title": "Metadata for a WorkflowNode" }, - "flyteidlcoreEnumType": { - "type": "object", - "properties": { - "values": { - "type": "array", - "items": { - "type": "string" - }, - "description": "Predefined set of enum values." - } - }, - "description": "Enables declaring enum types, with predefined string values\nFor len(values) \u003e 0, the first value in the ordered list is regarded as the default value. If you wish\nTo provide no defaults, make the first value as undefined." - }, "flyteidlcoreKeyValuePair": { "type": "object", "properties": { @@ -14477,11 +14477,11 @@ "properties": { "@type": { "type": "string", - "description": "A URL/resource name that uniquely identifies the type of the serialized\nprotocol buffer message. This string must contain at least\none \"/\" character. The last segment of the URL's path must represent\nthe fully qualified name of the type (as in\n`path/google.protobuf.Duration`). The name should be in a canonical form\n(e.g., leading \".\" is not accepted).\n\nIn practice, teams usually precompile into the binary all types that they\nexpect it to use in the context of Any. However, for URLs which use the\nscheme `http`, `https`, or no scheme, one can optionally set up a type\nserver that maps type URLs to message definitions as follows:\n\n* If no scheme is provided, `https` is assumed.\n* An HTTP GET on the URL must yield a [google.protobuf.Type][]\n value in binary format, or produce an error.\n* Applications are allowed to cache lookup results based on the\n URL, or have them precompiled into a binary to avoid any\n lookup. Therefore, binary compatibility needs to be preserved\n on changes to types. (Use versioned type names to manage\n breaking changes.)\n\nNote: this functionality is not currently available in the official\nprotobuf release, and it is not used for type URLs beginning with\ntype.googleapis.com. 
As of May 2023, there are no widely used type server\nimplementations and no plans to implement one.\n\nSchemes other than `http`, `https` (or the empty scheme) might be\nused with implementation specific semantics." + "description": "A URL/resource name that uniquely identifies the type of the serialized\nprotocol buffer message. This string must contain at least\none \"/\" character. The last segment of the URL's path must represent\nthe fully qualified name of the type (as in\n`path/google.protobuf.Duration`). The name should be in a canonical form\n(e.g., leading \".\" is not accepted).\n\nIn practice, teams usually precompile into the binary all types that they\nexpect it to use in the context of Any. However, for URLs which use the\nscheme `http`, `https`, or no scheme, one can optionally set up a type\nserver that maps type URLs to message definitions as follows:\n\n* If no scheme is provided, `https` is assumed.\n* An HTTP GET on the URL must yield a [google.protobuf.Type][]\n value in binary format, or produce an error.\n* Applications are allowed to cache lookup results based on the\n URL, or have them precompiled into a binary to avoid any\n lookup. Therefore, binary compatibility needs to be preserved\n on changes to types. (Use versioned type names to manage\n breaking changes.)\n\nNote: this functionality is not currently available in the official\nprotobuf release, and it is not used for type URLs beginning with\ntype.googleapis.com.\n\nSchemes other than `http`, `https` (or the empty scheme) might be\nused with implementation specific semantics." 
} }, "additionalProperties": {}, - "description": "`Any` contains an arbitrary serialized protocol buffer message along with a\nURL that describes the type of the serialized message.\n\nProtobuf library provides support to pack/unpack Any values in the form\nof utility functions or additional generated methods of the Any type.\n\nExample 1: Pack and unpack a message in C++.\n\n Foo foo = ...;\n Any any;\n any.PackFrom(foo);\n ...\n if (any.UnpackTo(\u0026foo)) {\n ...\n }\n\nExample 2: Pack and unpack a message in Java.\n\n Foo foo = ...;\n Any any = Any.pack(foo);\n ...\n if (any.is(Foo.class)) {\n foo = any.unpack(Foo.class);\n }\n // or ...\n if (any.isSameTypeAs(Foo.getDefaultInstance())) {\n foo = any.unpack(Foo.getDefaultInstance());\n }\n\n Example 3: Pack and unpack a message in Python.\n\n foo = Foo(...)\n any = Any()\n any.Pack(foo)\n ...\n if any.Is(Foo.DESCRIPTOR):\n any.Unpack(foo)\n ...\n\n Example 4: Pack and unpack a message in Go\n\n foo := \u0026pb.Foo{...}\n any, err := anypb.New(foo)\n if err != nil {\n ...\n }\n ...\n foo := \u0026pb.Foo{}\n if err := any.UnmarshalTo(foo); err != nil {\n ...\n }\n\nThe pack methods provided by protobuf library will by default use\n'type.googleapis.com/full.type.name' as the type URL and the unpack\nmethods only use the fully qualified type name after the last '/'\nin the type URL, for example \"foo.bar.com/x/y.z\" will yield type\nname \"y.z\".\n\nJSON\n====\nThe JSON representation of an `Any` value uses the regular\nrepresentation of the deserialized, embedded message, with an\nadditional field `@type` which contains the type URL. 
Example:\n\n package google.profile;\n message Person {\n string first_name = 1;\n string last_name = 2;\n }\n\n {\n \"@type\": \"type.googleapis.com/google.profile.Person\",\n \"firstName\": \u003cstring\u003e,\n \"lastName\": \u003cstring\u003e\n }\n\nIf the embedded message type is well-known and has a custom JSON\nrepresentation, that representation will be embedded adding a field\n`value` which holds the custom JSON in addition to the `@type`\nfield. Example (for message [google.protobuf.Duration][]):\n\n {\n \"@type\": \"type.googleapis.com/google.protobuf.Duration\",\n \"value\": \"1.212s\"\n }" + "description": "`Any` contains an arbitrary serialized protocol buffer message along with a\nURL that describes the type of the serialized message.\n\nProtobuf library provides support to pack/unpack Any values in the form\nof utility functions or additional generated methods of the Any type.\n\nExample 1: Pack and unpack a message in C++.\n\n Foo foo = ...;\n Any any;\n any.PackFrom(foo);\n ...\n if (any.UnpackTo(\u0026foo)) {\n ...\n }\n\nExample 2: Pack and unpack a message in Java.\n\n Foo foo = ...;\n Any any = Any.pack(foo);\n ...\n if (any.is(Foo.class)) {\n foo = any.unpack(Foo.class);\n }\n // or ...\n if (any.isSameTypeAs(Foo.getDefaultInstance())) {\n foo = any.unpack(Foo.getDefaultInstance());\n }\n\nExample 3: Pack and unpack a message in Python.\n\n foo = Foo(...)\n any = Any()\n any.Pack(foo)\n ...\n if any.Is(Foo.DESCRIPTOR):\n any.Unpack(foo)\n ...\n\nExample 4: Pack and unpack a message in Go\n\n foo := \u0026pb.Foo{...}\n any, err := anypb.New(foo)\n if err != nil {\n ...\n }\n ...\n foo := \u0026pb.Foo{}\n if err := any.UnmarshalTo(foo); err != nil {\n ...\n }\n\nThe pack methods provided by protobuf library will by default use\n'type.googleapis.com/full.type.name' as the type URL and the unpack\nmethods only use the fully qualified type name after the last '/'\nin the type URL, for example \"foo.bar.com/x/y.z\" will yield type\nname 
\"y.z\".\n\nJSON\n\nThe JSON representation of an `Any` value uses the regular\nrepresentation of the deserialized, embedded message, with an\nadditional field `@type` which contains the type URL. Example:\n\n package google.profile;\n message Person {\n string first_name = 1;\n string last_name = 2;\n }\n\n {\n \"@type\": \"type.googleapis.com/google.profile.Person\",\n \"firstName\": \u003cstring\u003e,\n \"lastName\": \u003cstring\u003e\n }\n\nIf the embedded message type is well-known and has a custom JSON\nrepresentation, that representation will be embedded adding a field\n`value` which holds the custom JSON in addition to the `@type`\nfield. Example (for message [google.protobuf.Duration][]):\n\n {\n \"@type\": \"type.googleapis.com/google.protobuf.Duration\",\n \"value\": \"1.212s\"\n }" }, "protobufNullValue": { "type": "string", @@ -14489,7 +14489,7 @@ "NULL_VALUE" ], "default": "NULL_VALUE", - "description": "`NullValue` is a singleton enumeration to represent the null value for the\n`Value` type union.\n\nThe JSON representation for `NullValue` is JSON `null`.\n\n - NULL_VALUE: Null value." + "description": "`NullValue` is a singleton enumeration to represent the null value for the\n`Value` type union.\n\n The JSON representation for `NullValue` is JSON `null`.\n\n - NULL_VALUE: Null value." 
}, "serviceAdminServiceCreateTaskBody": { "type": "object", diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go index 34c59d082de..d4dbe5489e1 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go @@ -259,6 +259,8 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, }, &core.LiteralMap{}, nil) @@ -280,6 +282,8 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, }, &core.LiteralMap{}, nil) diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go index 8dfd9932625..21ddc467962 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go @@ -159,7 +159,13 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx inte return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.RuntimeExecutionError, "failed to create unique ID", nil)), nil } - wfStatusClosure, outputs, 
err := l.launchPlan.GetStatus(ctx, childID) + launchPlanRefID := nCtx.Node().GetWorkflowNode().GetLaunchPlanRefID() + launchPlan := nCtx.ExecutionContext().FindLaunchPlan(*launchPlanRefID) + if launchPlan == nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, + handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.BadSpecificationError, fmt.Sprintf("launch plan not found [%v]", launchPlanRefID), nil)), nil + } + wfStatusClosure, outputs, err := l.launchPlan.GetStatus(ctx, childID, launchPlan, nCtx.NodeExecutionMetadata().GetOwnerID().String()) if err != nil { if launchplan.IsNotFound(err) { // NotFound errorCode, _ := errors.GetErrorCode(err) diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go index bf944ccadc4..5b2928ffac0 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go @@ -160,7 +160,7 @@ func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchCo } } - hasOutputs := launchPlan.GetInterface() != nil && launchPlan.GetInterface().GetOutputs() != nil && len(launchPlan.GetInterface().GetOutputs().GetVariables()) > 0 + hasOutputs := launchPlan.GetInterface() != nil && launchPlan.GetInterface().GetOutputs().GetVariables() != nil && len(launchPlan.GetInterface().GetOutputs().GetVariables()) > 0 _, err = a.cache.GetOrCreate(executionID.String(), executionCacheItem{ WorkflowExecutionIdentifier: *executionID, HasOutputs: hasOutputs, @@ -173,12 +173,18 @@ func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchCo return nil } -func (a *adminLaunchPlanExecutor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, *core.LiteralMap, error) { +func (a *adminLaunchPlanExecutor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, + 
launchPlan v1alpha1.ExecutableLaunchPlan, parentWorkflowID v1alpha1.WorkflowID) (*admin.ExecutionClosure, *core.LiteralMap, error) { if executionID == nil { return nil, nil, fmt.Errorf("nil executionID") } - obj, err := a.cache.GetOrCreate(executionID.String(), executionCacheItem{WorkflowExecutionIdentifier: *executionID}) + hasOutputs := launchPlan.GetInterface() != nil && launchPlan.GetInterface().GetOutputs().GetVariables() != nil && len(launchPlan.GetInterface().GetOutputs().GetVariables()) > 0 + obj, err := a.cache.GetOrCreate(executionID.String(), executionCacheItem{ + WorkflowExecutionIdentifier: *executionID, + HasOutputs: hasOutputs, + ParentWorkflowID: parentWorkflowID, + }) if err != nil { return nil, nil, err } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go index 58f01122be5..96ebcb44a26 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go @@ -26,6 +26,29 @@ import ( storageMocks "github.com/flyteorg/flyte/flytestdlib/storage/mocks" ) +var ( + launchPlanWithOutputs = &core.LaunchPlanTemplate{ + Id: &core.Identifier{}, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "foo": { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + }, + }, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "bar": { + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + }, + }, + }, + }, + } + parentWorkflowID = "parentwf" +) + func TestAdminLaunchPlanExecutor_GetStatus(t *testing.T) { ctx := context.TODO() adminConfig := defaultAdminConfig @@ -50,9 +73,22 @@ func TestAdminLaunchPlanExecutor_GetStatus(t *testing.T) { mock.MatchedBy(func(o *admin.WorkflowExecutionGetRequest) bool { return true }), 
).Return(result, nil) assert.NoError(t, err) - s, _, err := exec.GetStatus(ctx, id) + s, _, err := exec.GetStatus( + ctx, + id, + launchPlanWithOutputs, + parentWorkflowID, + ) assert.NoError(t, err) assert.Equal(t, result, s) + + item, err := exec.(*adminLaunchPlanExecutor).cache.Get(id.String()) + assert.NoError(t, err) + assert.NotNil(t, item) + assert.IsType(t, executionCacheItem{}, item) + e := item.(executionCacheItem) + assert.True(t, e.HasOutputs) + assert.Equal(t, parentWorkflowID, e.ParentWorkflowID) }) t.Run("notFound", func(t *testing.T) { @@ -87,18 +123,16 @@ func TestAdminLaunchPlanExecutor_GetStatus(t *testing.T) { }, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.NoError(t, err) // Allow for sync to be called time.Sleep(time.Second) - s, _, err := exec.GetStatus(ctx, id) + s, _, err := exec.GetStatus(ctx, id, launchPlanWithOutputs, parentWorkflowID) assert.Error(t, err) assert.Nil(t, s) assert.True(t, IsNotFound(err)) @@ -136,18 +170,16 @@ func TestAdminLaunchPlanExecutor_GetStatus(t *testing.T) { }, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.NoError(t, err) // Allow for sync to be called time.Sleep(time.Second) - s, _, err := exec.GetStatus(ctx, id) + s, _, err := exec.GetStatus(ctx, id, launchPlanWithOutputs, parentWorkflowID) assert.Error(t, err) assert.Nil(t, s) assert.False(t, IsNotFound(err)) @@ -196,11 +228,9 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { Labels: labels, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.NoError(t, err) // Ensure we haven't mutated the state of the parent workflow. 
@@ -237,11 +267,9 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { ParentNodeExecution: parentNodeExecution, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.NoError(t, err) }) @@ -289,11 +317,9 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { ParentNodeExecution: parentNodeExecution, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.NoError(t, err) assert.True(t, createCalled) @@ -320,11 +346,9 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { }, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.Error(t, err) assert.True(t, IsAlreadyExists(err)) @@ -351,11 +375,9 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { }, }, id, - &core.LaunchPlanTemplate{ - Id: &core.Identifier{}, - }, + launchPlanWithOutputs, nil, - "", + parentWorkflowID, ) assert.Error(t, err) assert.False(t, IsAlreadyExists(err)) diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go index ae89ec9ad55..dd897f1ec78 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/launchplan.go @@ -41,7 +41,8 @@ type Executor interface { launchPlan v1alpha1.ExecutableLaunchPlan, inputs *core.LiteralMap, parentWorkflowID v1alpha1.WorkflowID) error // GetStatus retrieves status of a LaunchPlan execution - GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, *core.LiteralMap, error) + GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, launchPlan v1alpha1.ExecutableLaunchPlan, + parentWorkflowID v1alpha1.WorkflowID) (*admin.ExecutionClosure, *core.LiteralMap, 
error) // Kill a remote execution Kill(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, reason string) error diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/executor.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/executor.go index 9feb899d80f..f2981b6bb2d 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/executor.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/executor.go @@ -29,8 +29,8 @@ func (_m Executor_GetStatus) Return(_a0 *admin.ExecutionClosure, _a1 *core.Liter return &Executor_GetStatus{Call: _m.Call.Return(_a0, _a1, _a2)} } -func (_m *Executor) OnGetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) *Executor_GetStatus { - c_call := _m.On("GetStatus", ctx, executionID) +func (_m *Executor) OnGetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, launchPlan v1alpha1.ExecutableLaunchPlan, parentWorkflowID string) *Executor_GetStatus { + c_call := _m.On("GetStatus", ctx, executionID, launchPlan, parentWorkflowID) return &Executor_GetStatus{Call: c_call} } @@ -39,13 +39,13 @@ func (_m *Executor) OnGetStatusMatch(matchers ...interface{}) *Executor_GetStatu return &Executor_GetStatus{Call: c_call} } -// GetStatus provides a mock function with given fields: ctx, executionID -func (_m *Executor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, *core.LiteralMap, error) { - ret := _m.Called(ctx, executionID) +// GetStatus provides a mock function with given fields: ctx, executionID, launchPlan, parentWorkflowID +func (_m *Executor) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, launchPlan v1alpha1.ExecutableLaunchPlan, parentWorkflowID string) (*admin.ExecutionClosure, *core.LiteralMap, error) { + ret := _m.Called(ctx, executionID, launchPlan, parentWorkflowID) var r0 *admin.ExecutionClosure - if rf, ok := 
ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier) *admin.ExecutionClosure); ok { - r0 = rf(ctx, executionID) + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, v1alpha1.ExecutableLaunchPlan, string) *admin.ExecutionClosure); ok { + r0 = rf(ctx, executionID, launchPlan, parentWorkflowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*admin.ExecutionClosure) @@ -53,8 +53,8 @@ func (_m *Executor) GetStatus(ctx context.Context, executionID *core.WorkflowExe } var r1 *core.LiteralMap - if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier) *core.LiteralMap); ok { - r1 = rf(ctx, executionID) + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, v1alpha1.ExecutableLaunchPlan, string) *core.LiteralMap); ok { + r1 = rf(ctx, executionID, launchPlan, parentWorkflowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*core.LiteralMap) @@ -62,8 +62,8 @@ func (_m *Executor) GetStatus(ctx context.Context, executionID *core.WorkflowExe } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, *core.WorkflowExecutionIdentifier) error); ok { - r2 = rf(ctx, executionID) + if rf, ok := ret.Get(2).(func(context.Context, *core.WorkflowExecutionIdentifier, v1alpha1.ExecutableLaunchPlan, string) error); ok { + r2 = rf(ctx, executionID, launchPlan, parentWorkflowID) } else { r2 = ret.Error(2) } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/flyte_admin.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/flyte_admin.go index b627b68de30..986c5b03e17 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/flyte_admin.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks/flyte_admin.go @@ -70,8 +70,8 @@ func (_m FlyteAdmin_GetStatus) Return(_a0 *admin.ExecutionClosure, _a1 *core.Lit return &FlyteAdmin_GetStatus{Call: _m.Call.Return(_a0, _a1, _a2)} } -func (_m *FlyteAdmin) OnGetStatus(ctx 
context.Context, executionID *core.WorkflowExecutionIdentifier) *FlyteAdmin_GetStatus { - c_call := _m.On("GetStatus", ctx, executionID) +func (_m *FlyteAdmin) OnGetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, launchPlan v1alpha1.ExecutableLaunchPlan, parentWorkflowID string) *FlyteAdmin_GetStatus { + c_call := _m.On("GetStatus", ctx, executionID, launchPlan, parentWorkflowID) return &FlyteAdmin_GetStatus{Call: c_call} } @@ -80,13 +80,13 @@ func (_m *FlyteAdmin) OnGetStatusMatch(matchers ...interface{}) *FlyteAdmin_GetS return &FlyteAdmin_GetStatus{Call: c_call} } -// GetStatus provides a mock function with given fields: ctx, executionID -func (_m *FlyteAdmin) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, *core.LiteralMap, error) { - ret := _m.Called(ctx, executionID) +// GetStatus provides a mock function with given fields: ctx, executionID, launchPlan, parentWorkflowID +func (_m *FlyteAdmin) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, launchPlan v1alpha1.ExecutableLaunchPlan, parentWorkflowID string) (*admin.ExecutionClosure, *core.LiteralMap, error) { + ret := _m.Called(ctx, executionID, launchPlan, parentWorkflowID) var r0 *admin.ExecutionClosure - if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier) *admin.ExecutionClosure); ok { - r0 = rf(ctx, executionID) + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, v1alpha1.ExecutableLaunchPlan, string) *admin.ExecutionClosure); ok { + r0 = rf(ctx, executionID, launchPlan, parentWorkflowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*admin.ExecutionClosure) @@ -94,8 +94,8 @@ func (_m *FlyteAdmin) GetStatus(ctx context.Context, executionID *core.WorkflowE } var r1 *core.LiteralMap - if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier) *core.LiteralMap); ok { - r1 = rf(ctx, executionID) + if rf, ok := 
ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, v1alpha1.ExecutableLaunchPlan, string) *core.LiteralMap); ok { + r1 = rf(ctx, executionID, launchPlan, parentWorkflowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*core.LiteralMap) @@ -103,8 +103,8 @@ func (_m *FlyteAdmin) GetStatus(ctx context.Context, executionID *core.WorkflowE } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, *core.WorkflowExecutionIdentifier) error); ok { - r2 = rf(ctx, executionID) + if rf, ok := ret.Get(2).(func(context.Context, *core.WorkflowExecutionIdentifier, v1alpha1.ExecutableLaunchPlan, string) error); ok { + r2 = rf(ctx, executionID, launchPlan, parentWorkflowID) } else { r2 = ret.Error(2) } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go index bd1e0648ebe..19cd3ebdcf8 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop.go @@ -24,7 +24,8 @@ func (failFastWorkflowLauncher) Launch(ctx context.Context, launchCtx LaunchCont return errors.Wrapf(RemoteErrorUser, fmt.Errorf("badly configured system"), "please enable admin workflow launch to use launchplans") } -func (failFastWorkflowLauncher) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier) (*admin.ExecutionClosure, *core.LiteralMap, error) { +func (failFastWorkflowLauncher) GetStatus(ctx context.Context, executionID *core.WorkflowExecutionIdentifier, + launchPlan v1alpha1.ExecutableLaunchPlan, parentWorkflowID v1alpha1.WorkflowID) (*admin.ExecutionClosure, *core.LiteralMap, error) { logger.Infof(ctx, "NOOP: Workflow Status ExecID [%s]", executionID.Name) return nil, nil, errors.Wrapf(RemoteErrorUser, fmt.Errorf("badly configured system"), "please enable admin workflow launch to use launchplans") } diff --git 
a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go index a263a6aba15..b2bbf08b777 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/noop_test.go @@ -17,7 +17,10 @@ func TestFailFastWorkflowLauncher(t *testing.T) { Project: "p", Domain: "d", Name: "n", - }) + }, &core.LaunchPlanTemplate{ + Id: &core.Identifier{}, + }, "", + ) assert.Nil(t, a) assert.Error(t, err) }) diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go index 21c340e82c4..69dee9b2d78 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -321,6 +321,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, }, &core.LiteralMap{}, nil) @@ -344,6 +346,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, }, nil, nil) @@ -379,6 +383,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o 
*core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, OutputResult: &admin.ExecutionClosure_Outputs{ @@ -421,6 +427,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, OutputResult: &admin.ExecutionClosure_Outputs{ @@ -458,6 +466,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_FAILED, OutputResult: &admin.ExecutionClosure_Error{ @@ -488,6 +498,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_FAILED, }, &core.LiteralMap{}, nil) @@ -512,6 +524,8 @@ func 
TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_ABORTED, }, &core.LiteralMap{}, nil) @@ -536,6 +550,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(nil, &core.LiteralMap{}, errors.Wrapf(launchplan.RemoteErrorNotFound, fmt.Errorf("some error"), "not found")) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) @@ -558,6 +574,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(nil, &core.LiteralMap{}, errors.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("some error"), "not found")) nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) @@ -586,6 +604,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { 
return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, OutputResult: &admin.ExecutionClosure_Outputs{ @@ -620,6 +640,8 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), + mock.MatchedBy(func(o v1alpha1.ExecutableLaunchPlan) bool { return true }), + mock.MatchedBy(func(o v1alpha1.WorkflowID) bool { return true }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, OutputResult: &admin.ExecutionClosure_Outputs{