diff --git a/config.yaml b/config.yaml index dd8203d09..ae97dbd5f 100644 --- a/config.yaml +++ b/config.yaml @@ -32,8 +32,9 @@ tasks: - container - K8S-ARRAY - qubole-hive-executor - - sagemaker_training - - sagemaker_hyperparameter_tuning +# Uncomment to enable sagemaker plugin +# - sagemaker_training +# - sagemaker_hyperparameter_tuning # Sample plugins config plugins: # All k8s plugins default configuration diff --git a/go.mod b/go.mod index 8a228db66..7d2b25305 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway v1.14.3 // indirect github.com/imdario/mergo v0.3.8 // indirect github.com/lyft/datacatalog v0.2.1 - github.com/lyft/flyteidl v0.18.0 + github.com/lyft/flyteidl v0.18.1 github.com/lyft/flyteplugins v0.4.4 github.com/lyft/flytestdlib v0.3.9 github.com/magiconair/properties v1.8.1 diff --git a/go.sum b/go.sum index fc24e5176..0fdf3efc9 100644 --- a/go.sum +++ b/go.sum @@ -398,8 +398,8 @@ github.com/lyft/datacatalog v0.2.1/go.mod h1:ktrPvzTDUwHO5Lv0hLH38zLHnOJ++rGoAO0 github.com/lyft/flyteidl v0.17.0/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flyteidl v0.18.0 h1:f4yv1MafE26wpMC6QlthM02EeTEDXpy/waL54dRDiSs= github.com/lyft/flyteidl v0.18.0/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= -github.com/lyft/flyteplugins v0.4.2 h1:DUyvi7PkJtQ+WV5ZlVypIfJOJOL3THn6QJgh5g24kG4= -github.com/lyft/flyteplugins v0.4.2/go.mod h1:8zhqFG9BzbHNQGEXzGYltTJLD+KTmQZkanxXgeFI25c= +github.com/lyft/flyteidl v0.18.1 h1:COKkZi5k6bQvUYOk5gE70+FJX9/NUn0WOQ1uMrw3Qio= +github.com/lyft/flyteidl v0.18.1/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flyteplugins v0.4.4 h1:2tFBAtcxjd81wVByI5yVSIBKJ/UECk7XQK3F1XzttNA= github.com/lyft/flyteplugins v0.4.4/go.mod h1:8zhqFG9BzbHNQGEXzGYltTJLD+KTmQZkanxXgeFI25c= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= diff --git a/pkg/apis/flyteworkflow/v1alpha1/admin.go b/pkg/apis/flyteworkflow/v1alpha1/admin.go new file mode 100644 index 000000000..3e4843d45 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/admin.go @@ -0,0 +1,15 @@ +package v1alpha1 + +import "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + +// This contains an OutputLocationPrefix. When running against AWS, this should be something of the form +// s3://my-bucket, or s3://my-bucket/ A sharding string will automatically be appended to this prefix before +// handing off to plugins/tasks. Sharding behavior may change in the future. +// Background available at https://github.com/lyft/flyte/issues/211 +type RawOutputDataConfig struct { + *admin.RawOutputDataConfig +} + +func (in *RawOutputDataConfig) DeepCopyInto(out *RawOutputDataConfig) { + *out = *in +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/admin_test.go b/pkg/apis/flyteworkflow/v1alpha1/admin_test.go new file mode 100644 index 000000000..3699ec6c8 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/admin_test.go @@ -0,0 +1,15 @@ +package v1alpha1 + +import ( + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" + + "testing" +) + +func TestRawOutputConfig(t *testing.T) { + r := RawOutputDataConfig{&admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://bucket", + }} + assert.Equal(t, "s3://bucket", r.OutputLocationPrefix) +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index 727a4a0b3..6de1a023f 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -432,6 +432,7 @@ type Meta interface { GetName() string GetServiceAccountName() string IsInterruptible() bool + GetRawOutputDataConfig() RawOutputDataConfig } type TaskDetailsGetter interface { diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go index 4f816edfb..ec0b8b15a 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go @@ -696,6 +696,38 @@ func (_m *ExecutableWorkflow) GetOwnerReference() v1.OwnerReference { return r0 } +type ExecutableWorkflow_GetRawOutputDataConfig struct { + *mock.Call +} + +func (_m ExecutableWorkflow_GetRawOutputDataConfig) Return(_a0 v1alpha1.RawOutputDataConfig) *ExecutableWorkflow_GetRawOutputDataConfig { + return &ExecutableWorkflow_GetRawOutputDataConfig{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableWorkflow) OnGetRawOutputDataConfig() *ExecutableWorkflow_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig") + return &ExecutableWorkflow_GetRawOutputDataConfig{Call: c} +} + +func (_m *ExecutableWorkflow) OnGetRawOutputDataConfigMatch(matchers ...interface{}) *ExecutableWorkflow_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig", matchers...) + return &ExecutableWorkflow_GetRawOutputDataConfig{Call: c} +} + +// GetRawOutputDataConfig provides a mock function with given fields: +func (_m *ExecutableWorkflow) GetRawOutputDataConfig() v1alpha1.RawOutputDataConfig { + ret := _m.Called() + + var r0 v1alpha1.RawOutputDataConfig + if rf, ok := ret.Get(0).(func() v1alpha1.RawOutputDataConfig); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.RawOutputDataConfig) + } + + return r0 +} + type ExecutableWorkflow_GetServiceAccountName struct { *mock.Call } diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go index 7692bd130..72048800d 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go @@ -276,6 +276,38 @@ func (_m *Meta) GetOwnerReference() v1.OwnerReference { return r0 } +type Meta_GetRawOutputDataConfig struct { + *mock.Call +} + +func (_m Meta_GetRawOutputDataConfig) Return(_a0 v1alpha1.RawOutputDataConfig) *Meta_GetRawOutputDataConfig { + return &Meta_GetRawOutputDataConfig{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetRawOutputDataConfig() *Meta_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig") + return &Meta_GetRawOutputDataConfig{Call: c} +} + +func (_m *Meta) OnGetRawOutputDataConfigMatch(matchers ...interface{}) *Meta_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig", matchers...) + return &Meta_GetRawOutputDataConfig{Call: c} +} + +// GetRawOutputDataConfig provides a mock function with given fields: +func (_m *Meta) GetRawOutputDataConfig() v1alpha1.RawOutputDataConfig { + ret := _m.Called() + + var r0 v1alpha1.RawOutputDataConfig + if rf, ok := ret.Get(0).(func() v1alpha1.RawOutputDataConfig); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.RawOutputDataConfig) + } + + return r0 +} + type Meta_GetServiceAccountName struct { *mock.Call } diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go index 282a78256..04f2499f4 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go @@ -344,6 +344,38 @@ func (_m *MetaExtended) GetOwnerReference() v1.OwnerReference { return r0 } +type MetaExtended_GetRawOutputDataConfig struct { + *mock.Call +} + +func (_m MetaExtended_GetRawOutputDataConfig) Return(_a0 v1alpha1.RawOutputDataConfig) *MetaExtended_GetRawOutputDataConfig { + return &MetaExtended_GetRawOutputDataConfig{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetRawOutputDataConfig() *MetaExtended_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig") + return &MetaExtended_GetRawOutputDataConfig{Call: c} +} + +func (_m *MetaExtended) OnGetRawOutputDataConfigMatch(matchers ...interface{}) *MetaExtended_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig", matchers...) + return &MetaExtended_GetRawOutputDataConfig{Call: c} +} + +// GetRawOutputDataConfig provides a mock function with given fields: +func (_m *MetaExtended) GetRawOutputDataConfig() v1alpha1.RawOutputDataConfig { + ret := _m.Called() + + var r0 v1alpha1.RawOutputDataConfig + if rf, ok := ret.Get(0).(func() v1alpha1.RawOutputDataConfig); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.RawOutputDataConfig) + } + + return r0 +} + type MetaExtended_GetServiceAccountName struct { *mock.Call } diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/pkg/apis/flyteworkflow/v1alpha1/workflow.go index 672d15fe8..bf327ac1a 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/workflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -44,8 +44,14 @@ type FlyteWorkflow struct { ServiceAccountName string `json:"serviceAccountName,omitempty" protobuf:"bytes,8,opt,name=serviceAccountName"` // Status is the only mutable section in the workflow. It holds all the execution information Status WorkflowStatus `json:"status,omitempty"` - - // non-Serialized fields + // RawOutputDataConfig defines the configurations to use for generating raw outputs (e.g. blobs, schemas). + RawOutputDataConfig RawOutputDataConfig `json:"rawOutputDataConfig,omitempty"` + + // non-Serialized fields (these will not get written to etcd) + // As of 2020-07, the only real implementation of this interface is a URLPathConstructor, which is just an empty + // struct. However, because this field is an interface, we create it once when the crd is hydrated from etcd, + // so that it can be used downstream without any confusion. + // This field is here because it's easier to put it here than pipe through a new object through all of propeller. DataReferenceConstructor storage.ReferenceConstructor `json:"-"` } @@ -110,6 +116,10 @@ func (in *FlyteWorkflow) IsInterruptible() bool { return in.NodeDefaults.Interruptible } +func (in *FlyteWorkflow) GetRawOutputDataConfig() RawOutputDataConfig { + return in.RawOutputDataConfig +} + type Inputs struct { *core.LiteralMap } diff --git a/pkg/controller/executors/mocks/execution_context.go b/pkg/controller/executors/mocks/execution_context.go index 9d1a585a7..5c19f7444 100644 --- a/pkg/controller/executors/mocks/execution_context.go +++ b/pkg/controller/executors/mocks/execution_context.go @@ -374,6 +374,38 @@ func (_m *ExecutionContext) GetOwnerReference() v1.OwnerReference { return r0 } +type ExecutionContext_GetRawOutputDataConfig struct { + *mock.Call +} + +func (_m ExecutionContext_GetRawOutputDataConfig) Return(_a0 v1alpha1.RawOutputDataConfig) *ExecutionContext_GetRawOutputDataConfig { + return &ExecutionContext_GetRawOutputDataConfig{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetRawOutputDataConfig() *ExecutionContext_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig") + return &ExecutionContext_GetRawOutputDataConfig{Call: c} +} + +func (_m *ExecutionContext) OnGetRawOutputDataConfigMatch(matchers ...interface{}) *ExecutionContext_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig", matchers...) + return &ExecutionContext_GetRawOutputDataConfig{Call: c} +} + +// GetRawOutputDataConfig provides a mock function with given fields: +func (_m *ExecutionContext) GetRawOutputDataConfig() v1alpha1.RawOutputDataConfig { + ret := _m.Called() + + var r0 v1alpha1.RawOutputDataConfig + if rf, ok := ret.Get(0).(func() v1alpha1.RawOutputDataConfig); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.RawOutputDataConfig) + } + + return r0 +} + type ExecutionContext_GetServiceAccountName struct { *mock.Call } diff --git a/pkg/controller/executors/mocks/immutable_execution_context.go b/pkg/controller/executors/mocks/immutable_execution_context.go index fe996fb17..2d0758b9f 100644 --- a/pkg/controller/executors/mocks/immutable_execution_context.go +++ b/pkg/controller/executors/mocks/immutable_execution_context.go @@ -340,6 +340,38 @@ func (_m *ImmutableExecutionContext) GetOwnerReference() v1.OwnerReference { return r0 } +type ImmutableExecutionContext_GetRawOutputDataConfig struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetRawOutputDataConfig) Return(_a0 v1alpha1.RawOutputDataConfig) *ImmutableExecutionContext_GetRawOutputDataConfig { + return &ImmutableExecutionContext_GetRawOutputDataConfig{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetRawOutputDataConfig() *ImmutableExecutionContext_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig") + return &ImmutableExecutionContext_GetRawOutputDataConfig{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetRawOutputDataConfigMatch(matchers ...interface{}) *ImmutableExecutionContext_GetRawOutputDataConfig { + c := _m.On("GetRawOutputDataConfig", matchers...) + return &ImmutableExecutionContext_GetRawOutputDataConfig{Call: c} +} + +// GetRawOutputDataConfig provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetRawOutputDataConfig() v1alpha1.RawOutputDataConfig { + ret := _m.Called() + + var r0 v1alpha1.RawOutputDataConfig + if rf, ok := ret.Get(0).(func() v1alpha1.RawOutputDataConfig); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.RawOutputDataConfig) + } + + return r0 +} + type ImmutableExecutionContext_GetServiceAccountName struct { *mock.Call } diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 47c312a2f..01864c14f 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/lyft/flytestdlib/storage" @@ -193,6 +195,9 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { }, }, DataReferenceConstructor: store, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, }, startNode, startNodeStatus } @@ -292,6 +297,9 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { }, }, DataReferenceConstructor: store, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, }, n, ns } @@ -377,6 +385,9 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { }, }, DataReferenceConstructor: store, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, }, n, ns } @@ -507,6 +518,9 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { }, }, DataReferenceConstructor: store, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, }, n, ns } @@ -599,6 +613,9 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf.OnGetLabels().Return(make(map[string]string)) mockWf.OnIsInterruptible().Return(false) mockWf.OnGetOnFailurePolicy().Return(v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY)) + mockWf.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }) mockWfStatus.OnGetDataDir().Return(storage.DataReference("x")) mockWfStatus.OnConstructNodeDataDirMatch(mock.Anything, mock.Anything, mock.Anything).Return("x", nil) return mockWf, mockN2Status @@ -1098,6 +1115,9 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { }, }, DataReferenceConstructor: store, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, }, n, ns } @@ -1210,6 +1230,9 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { eCtx.OnIsInterruptible().Return(true) eCtx.OnGetExecutionID().Return(v1alpha1.WorkflowExecutionIdentifier{WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{}}) eCtx.OnGetLabels().Return(nil) + eCtx.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }) branchTakenNodeID := "branchTakenNode" branchTakenNode := &mocks.ExecutableNode{} diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 225d712fb..2b5276d40 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -135,7 +135,11 @@ func (e nodeExecContext) MaxDatasetSizeBytes() int64 { return e.maxDatasetSizeBytes } -func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { +func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, + node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, + maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, + enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { + md := nodeExecMetadata{ Meta: execContext, nodeExecID: &core.NodeExecutionIdentifier{ @@ -175,7 +179,8 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext } } -func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNodeID v1alpha1.NodeID, executionContext executors.ExecutionContext, nl executors.NodeLookup) (*nodeExecContext, error) { +func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNodeID v1alpha1.NodeID, + executionContext executors.ExecutionContext, nl executors.NodeLookup) (*nodeExecContext, error) { n, ok := nl.GetNode(currentNodeID) if !ok { return nil, fmt.Errorf("failed to find node with ID [%s] in execution [%s]", currentNodeID, executionContext.GetID()) @@ -198,19 +203,24 @@ func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNod return nil } - interrutible := executionContext.IsInterruptible() + interruptible := executionContext.IsInterruptible() if n.IsInterruptible() != nil { - interrutible = *n.IsInterruptible() + interruptible = *n.IsInterruptible() } s := nl.GetNodeExecutionStatus(ctx, currentNodeID) // a node is not considered interruptible if the system failures have exceeded the configured threshold - if interrutible && s.GetSystemFailures() >= c.interruptibleFailureThreshold { - interrutible = false + if interruptible && s.GetSystemFailures() >= c.interruptibleFailureThreshold { + interruptible = false c.metrics.InterruptedThresholdHit.Inc(ctx) } + rawOutputPrefix := c.defaultDataSandbox + if executionContext.GetRawOutputDataConfig().RawOutputDataConfig != nil && len(executionContext.GetRawOutputDataConfig().OutputLocationPrefix) > 0 { + rawOutputPrefix = storage.DataReference(executionContext.GetRawOutputDataConfig().OutputLocationPrefix) + } + return newNodeExecContext(ctx, c.store, executionContext, nl, n, s, ioutils.NewCachedInputReader( ctx, @@ -224,15 +234,13 @@ func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNod ), ), ), - interrutible, + interruptible, c.maxDatasetSizeBytes, &taskEventRecorder{TaskEventRecorder: c.taskRecorder}, tr, newNodeStateManager(ctx, s), workflowEnqueuer, - // Eventually we want to replace this with per workflow sandboxes - // https://github.com/lyft/flyte/issues/211 - c.defaultDataSandbox, + rawOutputPrefix, c.shardSelector, ), nil } diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index 9b8483c2a..77120b530 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/lyft/flytestdlib/promutils" @@ -52,3 +55,60 @@ func Test_NodeContext(t *testing.T) { assert.Equal(t, "false", nCtx.NodeExecutionMetadata().GetLabels()["interruptible"]) assert.Equal(t, "task-name", nCtx.NodeExecutionMetadata().GetLabels()["task-name"]) } + +func Test_NodeContextDefault(t *testing.T) { + ctx := context.Background() + + w1 := &v1alpha1.FlyteWorkflow{ + NodeDefaults: v1alpha1.NodeDefaults{Interruptible: false}, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: ""}, + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "some.workflow", + }, + Tasks: map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + "taskID": { + TaskTemplate: &core.TaskTemplate{ + Id: &core.Identifier{ + ResourceType: 1, + Project: "proj", + Domain: "domain", + Name: "taskID", + Version: "abc", + }, + }, + }, + }, + } + dataStore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + taskID := "taskID" + n := &v1alpha1.NodeSpec{ + ID: "id", + TaskRef: &taskID, + Kind: v1alpha1.NodeKindTask, + } + nodeLookup := &mocks2.NodeLookup{} + nodeLookup.OnGetNode("node-a").Return(n, true) + nodeLookup.OnGetNodeExecutionStatus(ctx, "node-a").Return(&v1alpha1.NodeStatus{ + SystemFailures: 0, + }) + + nodeExecutor := nodeExecutor{ + interruptibleFailureThreshold: 0, + maxDatasetSizeBytes: 0, + defaultDataSandbox: "s3://bucket-a", + store: dataStore, + shardSelector: ioutils.NewConstantShardSelector([]string{"x"}), + enqueueWorkflow: func(workflowID v1alpha1.WorkflowID) {}, + } + + nodeExecContext, err := nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", w1, nodeLookup) + assert.NoError(t, err) + assert.Equal(t, "s3://bucket-a", nodeExecContext.rawOutputPrefix.String()) + + w1.RawOutputDataConfig.OutputLocationPrefix = "s3://bucket-b" + nodeExecContext, err = nodeExecutor.newNodeExecContextDefault(context.Background(), "node-a", w1, nodeLookup) + assert.NoError(t, err) + assert.Equal(t, "s3://bucket-b", nodeExecContext.rawOutputPrefix.String()) +} diff --git a/pkg/controller/nodes/resolve_test.go b/pkg/controller/nodes/resolve_test.go index 215677e16..fdee701c5 100644 --- a/pkg/controller/nodes/resolve_test.go +++ b/pkg/controller/nodes/resolve_test.go @@ -96,6 +96,10 @@ func (d *dummyBaseWorkflow) IsInterruptible() bool { return d.Interruptible } +func (d *dummyBaseWorkflow) GetRawOutputDataConfig() v1alpha1.RawOutputDataConfig { + return v1alpha1.RawOutputDataConfig{} +} + func (d *dummyBaseWorkflow) GetName() string { return d.ID } diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index d00820a3a..5da1bcba9 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils/labeled" "github.com/stretchr/testify/mock" @@ -240,7 +242,9 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") if assert.NoError(t, err) { - w := &v1alpha1.FlyteWorkflow{} + w := &v1alpha1.FlyteWorkflow{ + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{}}, + } if assert.NoError(t, json.Unmarshal(wJSON, w)) { // For benchmark workflow, we know how many rounds it needs // Number of rounds = 7 + 1 @@ -318,7 +322,9 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") if assert.NoError(t, err) { - w := &v1alpha1.FlyteWorkflow{} + w := &v1alpha1.FlyteWorkflow{ + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{}}, + } if assert.NoError(t, json.Unmarshal(wJSON, w)) { // For benchmark workflow, we know how many rounds it needs // Number of rounds = 28 @@ -464,7 +470,9 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") if assert.NoError(t, err) { - w := &v1alpha1.FlyteWorkflow{} + w := &v1alpha1.FlyteWorkflow{ + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{}}, + } if assert.NoError(t, json.Unmarshal(wJSON, w)) { // For benchmark workflow, we will run into the first failure on round 6 @@ -554,7 +562,9 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml") if assert.NoError(t, err) { - w := &v1alpha1.FlyteWorkflow{} + w := &v1alpha1.FlyteWorkflow{ + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{RawOutputDataConfig: &admin.RawOutputDataConfig{}}, + } if assert.NoError(t, json.Unmarshal(wJSON, w)) { // For benchmark workflow, we know how many rounds it needs // Number of rounds = 28 ?