Skip to content

Commit

Permalink
Pass task-name and node-id in pod labels (flyteorg#87)
Browse files Browse the repository at this point in the history
* Pass task-name and node-id in pod labels

* changing receiver name
  • Loading branch information
surindersinghp authored Mar 13, 2020
1 parent 1bd3181 commit 12dca2b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
12 changes: 6 additions & 6 deletions flytepropeller/pkg/controller/nodes/node_exec_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const TaskNameLabel = "task-name"

type execMetadata struct {
v1alpha1.WorkflowMeta
nodeLabels map[string]string
}

func (e execMetadata) GetK8sServiceAccount() string {
Expand All @@ -30,6 +31,10 @@ func (e execMetadata) GetOwnerID() types.NamespacedName {
return types.NamespacedName{Name: e.GetName(), Namespace: e.GetNamespace()}
}

func (e execMetadata) GetLabels() map[string]string {
return e.nodeLabels
}

type execContext struct {
store *storage.DataStore
tr handler.TaskReader
Expand All @@ -42,7 +47,6 @@ type execContext struct {
nsm *nodeStateManager
enqueueOwner func() error
w v1alpha1.ExecutableWorkflow
nodeLabels map[string]string
}

func (e execContext) EnqueueOwnerFunc() func() error {
Expand Down Expand Up @@ -101,10 +105,6 @@ func (e execContext) MaxDatasetSizeBytes() int64 {
return e.maxDatasetSizeBytes
}

func (e execContext) GetLabels() map[string]string {
return e.nodeLabels
}

func newNodeExecContext(_ context.Context, store *storage.DataStore, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, enqueueOwner func() error) *execContext {
md := execMetadata{WorkflowMeta: w}

Expand All @@ -117,6 +117,7 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, w v1alpha1.
if tr != nil && tr.GetTaskID() != nil {
nodeLabels[TaskNameLabel] = utils.SanitizeLabelValue(tr.GetTaskID().Name)
}
md.nodeLabels = nodeLabels

return &execContext{
md: md,
Expand All @@ -130,7 +131,6 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, w v1alpha1.
nsm: nsm,
enqueueOwner: enqueueOwner,
w: w,
nodeLabels: nodeLabels,
}
}

Expand Down
51 changes: 51 additions & 0 deletions flytepropeller/pkg/controller/nodes/node_exec_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package nodes

import (
"context"
"testing"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks"
"github.com/lyft/flytestdlib/promutils"
"github.com/lyft/flytestdlib/storage"
"github.com/stretchr/testify/assert"
)

type TaskReader struct{}

func (t TaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { return nil, nil }
func (t TaskReader) GetTaskType() v1alpha1.TaskType { return "" }
func (t TaskReader) GetTaskID() *core.Identifier {
return &core.Identifier{Project: "p", Domain: "d", Name: "task-name"}
}

func Test_NodeContext(t *testing.T) {
ns := mocks.ExecutableNodeStatus{}
ns.On("GetDataDir").Return(storage.DataReference("data-dir"))
ns.On("GetPhase").Return(v1alpha1.NodePhaseNotYetStarted)

childDatadir := v1alpha1.DataReference("test")
dataStore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
w1 := &v1alpha1.FlyteWorkflow{
Status: v1alpha1.WorkflowStatus{
NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{
"childNodeID": {
DataDir: childDatadir,
},
},
},
DataReferenceConstructor: dataStore,
}

taskID := "taskID"
n := &v1alpha1.NodeSpec{
ID: "id",
TaskRef: &taskID,
Kind: v1alpha1.NodeKindTask,
}
s, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
nCtx := newNodeExecContext(context.TODO(), s, w1, n, nil, nil, 0, nil, TaskReader{}, nil, nil)
assert.Equal(t, nCtx.NodeExecutionMetadata().GetLabels()["node-id"], "id")
assert.Equal(t, nCtx.NodeExecutionMetadata().GetLabels()["task-name"], "task-name")
}

0 comments on commit 12dca2b

Please sign in to comment.