diff --git a/go.mod b/go.mod index 8380fb47d..d2410d899 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.10.0 github.com/flyteorg/flyteidl v0.18.25 - github.com/flyteorg/flyteplugins v0.5.39-0.20210323171635-da37ae79061e + github.com/flyteorg/flyteplugins v0.5.39-0.20210323225036-cc282dd4f15b github.com/flyteorg/flytestdlib v0.3.13 github.com/ghodss/yaml v1.0.0 github.com/go-redis/redis v6.15.7+incompatible diff --git a/go.sum b/go.sum index 7c3edd674..932c02100 100644 --- a/go.sum +++ b/go.sum @@ -237,6 +237,10 @@ github.com/flyteorg/flyteplugins v0.5.39-0.20210323003259-f356ddf426ec h1:mnwspO github.com/flyteorg/flyteplugins v0.5.39-0.20210323003259-f356ddf426ec/go.mod h1:ireF+bYk8xjw9BfcMbPN/hN5aZeBJpP0CoQYHkSRL+w= github.com/flyteorg/flyteplugins v0.5.39-0.20210323171635-da37ae79061e h1:3Pq58HB6Jcu+He3FGZuYt+KFAWPesRxTcQPpXQFmux0= github.com/flyteorg/flyteplugins v0.5.39-0.20210323171635-da37ae79061e/go.mod h1:ireF+bYk8xjw9BfcMbPN/hN5aZeBJpP0CoQYHkSRL+w= +github.com/flyteorg/flyteplugins v0.5.39-0.20210323223608-28f0b78870cd h1:RkBLOqilNKXRLyG0tQuR96Rl6sZmvjdtVLt69Nib7LM= +github.com/flyteorg/flyteplugins v0.5.39-0.20210323223608-28f0b78870cd/go.mod h1:ireF+bYk8xjw9BfcMbPN/hN5aZeBJpP0CoQYHkSRL+w= +github.com/flyteorg/flyteplugins v0.5.39-0.20210323225036-cc282dd4f15b h1:ql3W0H5LgjNQuvllSRArfQDyRs1RvrZ2qxNMLRKXGDU= +github.com/flyteorg/flyteplugins v0.5.39-0.20210323225036-cc282dd4f15b/go.mod h1:ireF+bYk8xjw9BfcMbPN/hN5aZeBJpP0CoQYHkSRL+w= github.com/flyteorg/flytestdlib v0.3.13 h1:5ioA/q3ixlyqkFh5kDaHgmPyTP/AHtqq1K/TIbVLUzM= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= diff --git a/pkg/controller/nodes/task/resourcemanager/resourcemanager.go b/pkg/controller/nodes/task/resourcemanager/resourcemanager.go index 278f99867..e896b9f0e 100644 --- a/pkg/controller/nodes/task/resourcemanager/resourcemanager.go +++ b/pkg/controller/nodes/task/resourcemanager/resourcemanager.go @@ -5,6 +5,8 @@ import ( "fmt" "sync" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -72,6 +74,12 @@ type Proxy struct { BaseResourceManager ResourceNamespacePrefix pluginCore.ResourceNamespace ExecutionIdentifier *core.TaskExecutionIdentifier + ResourcePoolInfo map[string]*event.ResourcePoolInfo +} + +type TaskResourceManager interface { + pluginCore.ResourceManager + GetResourcePoolInfo() []*event.ResourcePoolInfo } func (p Proxy) ComposeResourceConstraint(spec pluginCore.ResourceConstraintsSpec) []FullyQualifiedResourceConstraint { @@ -92,6 +100,13 @@ func (p Proxy) AllocateResource(ctx context.Context, namespace pluginCore.Resour p.ResourceNamespacePrefix.CreateSubNamespace(namespace), Token(allocationToken).prepend(ComposeTokenPrefix(p.ExecutionIdentifier)), composedResourceConstraintList) + if err != nil { + return status, err + } + p.ResourcePoolInfo[allocationToken] = &event.ResourcePoolInfo{ + AllocationToken: allocationToken, + Namespace: string(namespace), + } return status, err } @@ -103,12 +118,21 @@ func (p Proxy) ReleaseResource(ctx context.Context, namespace pluginCore.Resourc return err } +func (p Proxy) GetResourcePoolInfo() []*event.ResourcePoolInfo { + response := make([]*event.ResourcePoolInfo, 0, len(p.ResourcePoolInfo)) + for _, resourcePoolInfo := range p.ResourcePoolInfo { + response = append(response, resourcePoolInfo) + } + return response +} + func GetTaskResourceManager(r BaseResourceManager, resourceNamespacePrefix pluginCore.ResourceNamespace, - id *core.TaskExecutionIdentifier) pluginCore.ResourceManager { + id *core.TaskExecutionIdentifier) TaskResourceManager { return Proxy{ BaseResourceManager: r, ResourceNamespacePrefix: resourceNamespacePrefix, ExecutionIdentifier: id, + ResourcePoolInfo: make(map[string]*event.ResourcePoolInfo), } } diff --git a/pkg/controller/nodes/task/taskexec_context.go b/pkg/controller/nodes/task/taskexec_context.go index e3b15fa61..dcdda89e1 100644 --- a/pkg/controller/nodes/task/taskexec_context.go +++ b/pkg/controller/nodes/task/taskexec_context.go @@ -49,10 +49,9 @@ func (te taskExecutionID) GetGeneratedName() string { type taskExecutionMetadata struct { handler.NodeExecutionMetadata - taskExecID taskExecutionID - o pluginCore.TaskOverrides - maxAttempts uint32 - resourcePoolInfo map[string]*event.ResourcePoolInfo + taskExecID taskExecutionID + o pluginCore.TaskOverrides + maxAttempts uint32 } func (t taskExecutionMetadata) GetTaskExecutionID() pluginCore.TaskExecutionID { @@ -67,18 +66,10 @@ func (t taskExecutionMetadata) GetMaxAttempts() uint32 { return t.maxAttempts } -func (t taskExecutionMetadata) GetResourcePoolInfo() []*event.ResourcePoolInfo { - response := make([]*event.ResourcePoolInfo, 0, len(t.resourcePoolInfo)) - for _, resourcePoolInfo := range t.resourcePoolInfo { - response = append(response, resourcePoolInfo) - } - return response -} - type taskExecutionContext struct { handler.NodeExecutionContext tm taskExecutionMetadata - rm pluginCore.ResourceManager + rm resourcemanager.TaskResourceManager psm *pluginStateManager tr handler.TaskReader ow *ioutils.BufferedOutputWriter @@ -109,15 +100,7 @@ func (t taskExecutionContext) EventsRecorder() pluginCore.EventsRecorder { // During execution time, plugins can call AllocateResource() to register a token to the token pool associated with a resource with the resource manager. func (t taskExecutionContext) AllocateResource(ctx context.Context, namespace pluginCore.ResourceNamespace, allocationToken string, constraintsSpec pluginCore.ResourceConstraintsSpec) (pluginCore.AllocationStatus, error) { - allocationStatus, err := t.rm.AllocateResource(ctx, namespace, allocationToken, constraintsSpec) - if err != nil { - return allocationStatus, err - } - t.tm.resourcePoolInfo[allocationToken] = &event.ResourcePoolInfo{ - AllocationToken: allocationToken, - Namespace: string(namespace), - } - return allocationStatus, nil + return t.rm.AllocateResource(ctx, namespace, allocationToken, constraintsSpec) } // During execution time, after an outstanding request is completed, the plugin needs to use ReleaseResource() to release the allocation of the corresponding token @@ -150,6 +133,14 @@ func (t taskExecutionContext) SecretManager() pluginCore.SecretManager { return t.sm } +func (t taskExecutionContext) ResourceManager() pluginCore.ResourceManager { + return t.rm +} + +func (t taskExecutionContext) GetResourcePoolInfo() []*event.ResourcePoolInfo { + return t.rm.GetResourcePoolInfo() +} + func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.NodeExecutionContext, pluginID string) (*taskExecutionContext, error) { id := GetTaskExecutionIdentifier(nCtx) @@ -196,7 +187,6 @@ func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.Node taskExecID: taskExecutionID{execName: uniqueID, id: id}, o: nCtx.Node(), maxAttempts: maxAttempts, - resourcePoolInfo: make(map[string]*event.ResourcePoolInfo), }, rm: resourcemanager.GetTaskResourceManager( t.resourceManager, resourceNamespacePrefix, id), diff --git a/pkg/controller/nodes/task/taskexec_context_test.go b/pkg/controller/nodes/task/taskexec_context_test.go index af9abb6bd..1014bcd4c 100644 --- a/pkg/controller/nodes/task/taskexec_context_test.go +++ b/pkg/controller/nodes/task/taskexec_context_test.go @@ -151,7 +151,7 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetNodeId(), nodeID) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetExecutionId(), wfExecID) - assert.EqualValues(t, got.TaskExecutionMetadata().GetResourcePoolInfo(), make([]*event.ResourcePoolInfo, 0)) + assert.EqualValues(t, got.GetResourcePoolInfo(), make([]*event.ResourcePoolInfo, 0)) // TODO @kumare fix this test assert.NotNil(t, got.rm) @@ -163,7 +163,7 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { Namespace: "foo", AllocationToken: "token", }, - }, got.TaskExecutionMetadata().GetResourcePoolInfo()) + }, got.GetResourcePoolInfo()) assert.Nil(t, got.Catalog()) // assert.Equal(t, got.InputReader(), ir) } diff --git a/pkg/controller/nodes/task/transformer.go b/pkg/controller/nodes/task/transformer.go index 4adf73562..86d453baa 100644 --- a/pkg/controller/nodes/task/transformer.go +++ b/pkg/controller/nodes/task/transformer.go @@ -103,7 +103,7 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio } metadata.PluginIdentifier = input.PluginID metadata.GeneratedName = input.TaskExecContext.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - metadata.ResourcePoolInfo = input.TaskExecContext.TaskExecutionMetadata().GetResourcePoolInfo() + metadata.ResourcePoolInfo = input.TaskExecContext.GetResourcePoolInfo() tev := &event.TaskExecutionEvent{ TaskId: taskExecID.TaskId, ParentNodeExecutionId: nodeExecutionID, diff --git a/pkg/controller/nodes/task/transformer_test.go b/pkg/controller/nodes/task/transformer_test.go index 477c63496..09342fb47 100644 --- a/pkg/controller/nodes/task/transformer_test.go +++ b/pkg/controller/nodes/task/transformer_test.go @@ -92,16 +92,16 @@ func TestToTaskExecutionEvent(t *testing.T) { tMeta := &pluginMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) + + tCtx := &pluginMocks.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(tMeta) resourcePoolInfo := []*event.ResourcePoolInfo{ { Namespace: "ns", AllocationToken: "alloc_token", }, } - tMeta.OnGetResourcePoolInfo().Return(resourcePoolInfo) - - tCtx := &pluginMocks.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(tMeta) + tCtx.OnGetResourcePoolInfo().Return(resourcePoolInfo) tev, err := ToTaskExecutionEvent(ToTaskExecutionEventInputs{ TaskExecContext: tCtx, @@ -241,16 +241,16 @@ func TestToTaskExecutionEventWithParent(t *testing.T) { tMeta := &pluginMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) + + tCtx := &pluginMocks.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(tMeta) resourcePoolInfo := []*event.ResourcePoolInfo{ { Namespace: "ns", AllocationToken: "alloc_token", }, } - tMeta.OnGetResourcePoolInfo().Return(resourcePoolInfo) - - tCtx := &pluginMocks.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(tMeta) + tCtx.OnGetResourcePoolInfo().Return(resourcePoolInfo) tev, err := ToTaskExecutionEvent(ToTaskExecutionEventInputs{ TaskExecContext: tCtx,