From 56923e5ddd8690d148c9a1100562515b94d39a07 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Oct 2022 11:58:51 -0700 Subject: [PATCH 01/14] Turn PhaseRetryableFailure into PhaseRetryLimitExceededFailure Signed-off-by: Kevin Su --- go/tasks/pluginmachinery/core/phase.go | 2 ++ go/tasks/plugins/array/awsbatch/executor.go | 13 +++++++++++-- go/tasks/plugins/array/awsbatch/monitor.go | 7 +++++-- go/tasks/plugins/array/awsbatch/monitor_test.go | 8 ++++---- go/tasks/plugins/array/core/state.go | 7 +++++-- go/tasks/plugins/array/core/state_test.go | 8 ++++++++ go/tasks/plugins/array/outputs.go | 2 +- go/tasks/plugins/array/subtask_phase.go | 2 +- 8 files changed, 37 insertions(+), 12 deletions(-) diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index 9cfdbe2ba..99a9809e8 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -33,6 +33,8 @@ const ( PhaseRetryableFailure // Indicate that the failure is non recoverable even if retries exist PhasePermanentFailure + // Indicate that the tasks won't be executed again because the retry limit exceeded + PhaseRetryLimitExceededFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache ) diff --git a/go/tasks/plugins/array/awsbatch/executor.go b/go/tasks/plugins/array/awsbatch/executor.go index 00614da9e..db2a879f3 100644 --- a/go/tasks/plugins/array/awsbatch/executor.go +++ b/go/tasks/plugins/array/awsbatch/executor.go @@ -80,9 +80,18 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics) case arrayCore.PhaseCheckingSubTaskExecutions: + // Check that the taskTemplate is valid + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return core.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template") + } 
else if taskTemplate == nil { + return core.UnknownTransition, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + } + retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), pluginConfig.MinRetries, pluginConfig.MaxRetries) + pluginState, err = CheckSubTasksState(ctx, tCtx.TaskExecutionMetadata(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), - e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics) + e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics, *retry.Attempts) case arrayCore.PhaseAssembleFinalOutput: pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State) @@ -94,7 +103,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalOutput, version) case arrayCore.PhaseAssembleFinalError: - pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, version, pluginState.State) + pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhasePermanentFailure, version, pluginState.State) } if err != nil { diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 846aec420..a02362431 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -2,7 +2,6 @@ package awsbatch import ( "context" - core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/storage" @@ -35,7 +34,7 @@ func createSubJobList(count int) []*Job { } func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata, outputPrefix, baseOutputSandbox storage.DataReference, jobStore *JobStore, - dataStore *storage.DataStore, cfg *config.Config, 
currentState *State, metrics ExecutorMetrics) (newState *State, err error) { + dataStore *storage.DataStore, cfg *config.Config, currentState *State, metrics ExecutorMetrics, retryLimit int64) (newState *State, err error) { newState = currentState parentState := currentState.State jobName := taskMeta.GetTaskExecutionID().GetGeneratedName() @@ -108,6 +107,10 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } else { msg.Collect(childIdx, "Job failed") } + + if subJob.Status.Phase == core.PhaseRetryableFailure && retryLimit == int64(len(subJob.Attempts)) { + actualPhase = core.PhaseRetryLimitExceededFailure + } } else if subJob.Status.Phase.IsSuccess() { actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx) if err != nil { diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index 621512b4a..5b3418a22 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -61,7 +61,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -107,7 +107,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -153,7 +153,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -201,7 +201,7 @@ func 
TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) assert.NoError(t, err) p, _ := newState.GetPhase() diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index 9a281c311..11821bc84 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -242,6 +242,7 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus totalCount := int64(0) totalSuccesses := int64(0) totalPermanentFailures := int64(0) + totalRetryLimitExceededFailures := int64(0) totalRetryableFailures := int64(0) totalRunning := int64(0) totalWaitingForResources := int64(0) @@ -253,6 +254,8 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus totalSuccesses += count case core.PhasePermanentFailure: totalPermanentFailures += count + case core.PhaseRetryLimitExceededFailure: + totalRetryLimitExceededFailures += count case core.PhaseRetryableFailure: totalRetryableFailures += count case core.PhaseWaitingForResources: @@ -283,8 +286,8 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus return PhaseWriteToDiscovery } - logger.Debugf(ctx, "Array is still running [Successes: %v, PermanentFailures: %v, RetryableFailures: %v, Total: %v, MinSuccesses: %v]", - totalSuccesses, totalPermanentFailures, totalRetryableFailures, totalCount, minSuccesses) + logger.Debugf(ctx, "Array is still running [Successes: %v, PermanentFailures: %v, RetryLimitExceededFailures: %v, RetryableFailures: %v, Total: %v, MinSuccesses: %v]", + totalSuccesses, totalPermanentFailures, totalRetryableFailures, totalRetryableFailures, totalCount, minSuccesses) return PhaseCheckingSubTaskExecutions } diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 654979431..9bf1746f7 
100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -334,6 +334,14 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseSuccess: 10, }, }, + { + "FailedToRetry", + PhaseWriteToDiscoveryThenFail, + map[core.Phase]int64{ + core.PhaseSuccess: 5, + core.PhaseRetryLimitExceededFailure: 5, + }, + }, } for _, tt := range tests { diff --git a/go/tasks/plugins/array/outputs.go b/go/tasks/plugins/array/outputs.go index ed9b08f39..afc3c8863 100644 --- a/go/tasks/plugins/array/outputs.go +++ b/go/tasks/plugins/array/outputs.go @@ -267,7 +267,7 @@ func AssembleFinalOutputs(ctx context.Context, assemblyQueue OutputAssembler, tC Message: w.Error().Error(), }) - state = state.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to assemble outputs/errors.") + state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason("Failed to assemble outputs/errors.") } return state, nil diff --git a/go/tasks/plugins/array/subtask_phase.go b/go/tasks/plugins/array/subtask_phase.go index 50569ee57..8ba1cd273 100644 --- a/go/tasks/plugins/array/subtask_phase.go +++ b/go/tasks/plugins/array/subtask_phase.go @@ -33,7 +33,7 @@ func CheckTaskOutput(ctx context.Context, dataStore *storage.DataStore, outputPr if errExists { logger.Debugf(ctx, "Found error file for sub task [%v] with original index [%v]. 
Marking as failure.", childIdx, originalIdx) - return core.PhaseRetryableFailure, nil + return core.PhasePermanentFailure, nil } } From ecd56c5430a874b5d0fb256bda9a22d586cb3aef Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Oct 2022 12:03:25 -0700 Subject: [PATCH 02/14] nit Signed-off-by: Kevin Su --- go/tasks/plugins/array/awsbatch/executor.go | 2 +- go/tasks/plugins/array/outputs.go | 2 +- go/tasks/plugins/array/subtask_phase.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/tasks/plugins/array/awsbatch/executor.go b/go/tasks/plugins/array/awsbatch/executor.go index db2a879f3..4eec7fd9d 100644 --- a/go/tasks/plugins/array/awsbatch/executor.go +++ b/go/tasks/plugins/array/awsbatch/executor.go @@ -103,7 +103,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalOutput, version) case arrayCore.PhaseAssembleFinalError: - pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhasePermanentFailure, version, pluginState.State) + pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, version, pluginState.State) } if err != nil { diff --git a/go/tasks/plugins/array/outputs.go b/go/tasks/plugins/array/outputs.go index afc3c8863..ed9b08f39 100644 --- a/go/tasks/plugins/array/outputs.go +++ b/go/tasks/plugins/array/outputs.go @@ -267,7 +267,7 @@ func AssembleFinalOutputs(ctx context.Context, assemblyQueue OutputAssembler, tC Message: w.Error().Error(), }) - state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason("Failed to assemble outputs/errors.") + state = state.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to assemble outputs/errors.") } return state, nil diff --git a/go/tasks/plugins/array/subtask_phase.go b/go/tasks/plugins/array/subtask_phase.go index 8ba1cd273..aab90ae9a 
100644 --- a/go/tasks/plugins/array/subtask_phase.go +++ b/go/tasks/plugins/array/subtask_phase.go @@ -33,7 +33,7 @@ func CheckTaskOutput(ctx context.Context, dataStore *storage.DataStore, outputPr if errExists { logger.Debugf(ctx, "Found error file for sub task [%v] with original index [%v]. Marking as failure.", childIdx, originalIdx) - return core.PhasePermanentFailure, nil + return core.PhaseRetryLimitExceededFailure, nil } } From cbaac1ea1e4c608f5227c88cd03bf6806deabc92 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Oct 2022 14:24:50 -0700 Subject: [PATCH 03/14] update Signed-off-by: Kevin Su --- go/tasks/pluginmachinery/core/phase.go | 5 +++-- go/tasks/plugins/array/awsbatch/monitor.go | 2 +- go/tasks/plugins/array/core/state.go | 7 ++----- go/tasks/plugins/array/subtask_phase.go | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index 99a9809e8..a9ed27b0e 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -33,10 +33,10 @@ const ( PhaseRetryableFailure // Indicate that the failure is non recoverable even if retries exist PhasePermanentFailure - // Indicate that the tasks won't be executed again because the retry limit exceeded - PhaseRetryLimitExceededFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache + // Indicate that the tasks won't be executed again because the retry limit exceeded + PhaseRetryLimitExceededFailure ) var Phases = []Phase{ @@ -50,6 +50,7 @@ var Phases = []Phase{ PhaseRetryableFailure, PhasePermanentFailure, PhaseWaitingForCache, + PhaseRetryLimitExceededFailure, } // Returns true if the given phase is failure, retryable failure or success diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index a02362431..cf75ba7c3 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ 
b/go/tasks/plugins/array/awsbatch/monitor.go @@ -133,7 +133,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata if phase != arrayCore.PhaseCheckingSubTaskExecutions { metrics.SubTasksSucceeded.Add(ctx, float64(newArrayStatus.Summary[core.PhaseSuccess])) - totalFailed := newArrayStatus.Summary[core.PhasePermanentFailure] + newArrayStatus.Summary[core.PhaseRetryableFailure] + totalFailed := newArrayStatus.Summary[core.PhasePermanentFailure] + newArrayStatus.Summary[core.PhaseRetryableFailure] + newArrayStatus.Summary[core.PhaseRetryLimitExceededFailure] metrics.SubTasksFailed.Add(ctx, float64(totalFailed)) } if phase == arrayCore.PhaseWriteToDiscoveryThenFail { diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index 11821bc84..c0bc00bfb 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -205,9 +205,6 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhaseAssembleFinalError: fallthrough - case PhaseWriteToDiscoveryThenFail: - fallthrough - case PhaseWriteToDiscovery: // The state version is only incremented in PhaseCheckingSubTaskExecutions when subtask // phases are updated. 
Therefore by adding the phase to the state version we ensure that @@ -225,7 +222,7 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl phaseInfo = core.PhaseInfoRetryableFailure(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) } - case PhasePermanentFailure: + case PhasePermanentFailure, PhaseWriteToDiscoveryThenFail: if state.GetExecutionErr() != nil { phaseInfo = core.PhaseInfoFailed(core.PhasePermanentFailure, state.GetExecutionErr(), nowTaskInfo) } else { @@ -287,7 +284,7 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus } logger.Debugf(ctx, "Array is still running [Successes: %v, PermanentFailures: %v, RetryLimitExceededFailures: %v, RetryableFailures: %v, Total: %v, MinSuccesses: %v]", - totalSuccesses, totalPermanentFailures, totalRetryableFailures, totalRetryableFailures, totalCount, minSuccesses) + totalSuccesses, totalPermanentFailures, totalRetryLimitExceededFailures, totalRetryableFailures, totalCount, minSuccesses) return PhaseCheckingSubTaskExecutions } diff --git a/go/tasks/plugins/array/subtask_phase.go b/go/tasks/plugins/array/subtask_phase.go index aab90ae9a..50569ee57 100644 --- a/go/tasks/plugins/array/subtask_phase.go +++ b/go/tasks/plugins/array/subtask_phase.go @@ -33,7 +33,7 @@ func CheckTaskOutput(ctx context.Context, dataStore *storage.DataStore, outputPr if errExists { logger.Debugf(ctx, "Found error file for sub task [%v] with original index [%v]. 
Marking as failure.", childIdx, originalIdx) - return core.PhaseRetryLimitExceededFailure, nil + return core.PhaseRetryableFailure, nil } } From 5ad7369c7fa448a24df2eaafa855de56c7ad4190 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Oct 2022 15:30:21 -0700 Subject: [PATCH 04/14] update test Signed-off-by: Kevin Su --- go/tasks/plugins/array/awsbatch/executor.go | 3 +- .../plugins/array/awsbatch/monitor_test.go | 46 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/go/tasks/plugins/array/awsbatch/executor.go b/go/tasks/plugins/array/awsbatch/executor.go index 4eec7fd9d..045ebdb57 100644 --- a/go/tasks/plugins/array/awsbatch/executor.go +++ b/go/tasks/plugins/array/awsbatch/executor.go @@ -81,7 +81,8 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c case arrayCore.PhaseCheckingSubTaskExecutions: // Check that the taskTemplate is valid - taskTemplate, err := tCtx.TaskReader().Read(ctx) + var taskTemplate *idlCore.TaskTemplate + taskTemplate, err = tCtx.TaskReader().Read(ctx) if err != nil { return core.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template") } else if taskTemplate == nil { diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index 5b3418a22..b5a5925be 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -206,6 +206,52 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + }) + + t.Run("retry limit exceeded", func(t *testing.T) { + mBatchClient := batchMocks.NewMockAwsBatchClient() + batchClient := NewCustomBatchClient(mBatchClient, "", "", + utils.NewRateLimiter("", 10, 20), + utils.NewRateLimiter("", 10, 20)) + + jobStore := newJobsStore(t, batchClient) + _, err := 
jobStore.GetOrCreate(tID.GetGeneratedName(), &Job{ + ID: "job-id", + Status: JobStatus{ + Phase: core.PhaseRunning, + }, + SubJobs: []*Job{ + {Status: JobStatus{Phase: core.PhaseRetryableFailure}, Attempts: []Attempt{{LogStream: "failed"}}}, + {Status: JobStatus{Phase: core.PhaseSuccess}}, + }, + }) + + assert.NoError(t, err) + + inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) + assert.NoError(t, err) + + newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + State: &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: 2, + OriginalArraySize: 2, + OriginalMinSuccesses: 2, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(2), + }, + IndexesToCache: bitarray.NewBitSet(2), + RetryAttempts: retryAttemptsArray, + }, + ExternalJobID: refStr("job-id"), + JobDefinitionArn: "", + }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + + assert.NoError(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) }) } From 46952d630cedef95ba6aa774a08d5c2939283e88 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Oct 2022 15:49:45 -0700 Subject: [PATCH 05/14] lint Signed-off-by: Kevin Su --- go/tasks/plugins/array/awsbatch/monitor.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index cf75ba7c3..6eee57db0 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -2,6 +2,7 @@ package awsbatch import ( "context" + core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/storage" From 2bae5305bcee1ecb3feffc014b4e45f9200befd4 Mon Sep 17 00:00:00 
2001 From: Kevin Su Date: Mon, 17 Oct 2022 17:02:14 -0700 Subject: [PATCH 06/14] update Signed-off-by: Kevin Su --- go/tasks/pluginmachinery/core/phase.go | 3 --- go/tasks/plugins/array/awsbatch/monitor.go | 7 ++++--- go/tasks/plugins/array/core/state.go | 16 ++++++++-------- go/tasks/plugins/array/core/state_test.go | 6 +++--- go/tasks/plugins/array/k8s/management.go | 2 +- 5 files changed, 16 insertions(+), 18 deletions(-) diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index a9ed27b0e..9cfdbe2ba 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -35,8 +35,6 @@ const ( PhasePermanentFailure // Indicates the task is waiting for the cache to be populated so it can reuse results PhaseWaitingForCache - // Indicate that the tasks won't be executed again because the retry limit exceeded - PhaseRetryLimitExceededFailure ) var Phases = []Phase{ @@ -50,7 +48,6 @@ var Phases = []Phase{ PhaseRetryableFailure, PhasePermanentFailure, PhaseWaitingForCache, - PhaseRetryLimitExceededFailure, } // Returns true if the given phase is failure, retryable failure or success diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 6eee57db0..437903845 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -69,6 +69,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } queued := 0 + totalRetryLimitExceeded := 0 for childIdx, subJob := range job.SubJobs { actualPhase := subJob.Status.Phase originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) @@ -110,7 +111,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } if subJob.Status.Phase == core.PhaseRetryableFailure && retryLimit == int64(len(subJob.Attempts)) { - actualPhase = core.PhaseRetryLimitExceededFailure + totalRetryLimitExceeded += 1 
} } else if subJob.Status.Phase.IsSuccess() { actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx) @@ -130,11 +131,11 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata parentState = parentState.SetArrayStatus(newArrayStatus) // Based on the summary produced above, deduce the overall phase of the task. - phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) + phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary, int64(totalRetryLimitExceeded)) if phase != arrayCore.PhaseCheckingSubTaskExecutions { metrics.SubTasksSucceeded.Add(ctx, float64(newArrayStatus.Summary[core.PhaseSuccess])) - totalFailed := newArrayStatus.Summary[core.PhasePermanentFailure] + newArrayStatus.Summary[core.PhaseRetryableFailure] + newArrayStatus.Summary[core.PhaseRetryLimitExceededFailure] + totalFailed := newArrayStatus.Summary[core.PhasePermanentFailure] + newArrayStatus.Summary[core.PhaseRetryableFailure] metrics.SubTasksFailed.Add(ctx, float64(totalFailed)) } if phase == arrayCore.PhaseWriteToDiscoveryThenFail { diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index c0bc00bfb..b74a5f6b3 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -205,6 +205,9 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl case PhaseAssembleFinalError: fallthrough + case PhaseWriteToDiscoveryThenFail: + fallthrough + case PhaseWriteToDiscovery: // The state version is only incremented in PhaseCheckingSubTaskExecutions when subtask // phases are updated. 
Therefore by adding the phase to the state version we ensure that @@ -222,7 +225,7 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl phaseInfo = core.PhaseInfoRetryableFailure(ErrorK8sArrayGeneric, state.GetReason(), nowTaskInfo) } - case PhasePermanentFailure, PhaseWriteToDiscoveryThenFail: + case PhasePermanentFailure: if state.GetExecutionErr() != nil { phaseInfo = core.PhaseInfoFailed(core.PhasePermanentFailure, state.GetExecutionErr(), nowTaskInfo) } else { @@ -235,11 +238,10 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl return phaseInfo, nil } -func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus.ArraySummary) Phase { +func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus.ArraySummary, totalRetryLimitExceeded int64) Phase { totalCount := int64(0) totalSuccesses := int64(0) totalPermanentFailures := int64(0) - totalRetryLimitExceededFailures := int64(0) totalRetryableFailures := int64(0) totalRunning := int64(0) totalWaitingForResources := int64(0) @@ -251,8 +253,6 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus totalSuccesses += count case core.PhasePermanentFailure: totalPermanentFailures += count - case core.PhaseRetryLimitExceededFailure: - totalRetryLimitExceededFailures += count case core.PhaseRetryableFailure: totalRetryableFailures += count case core.PhaseWaitingForResources: @@ -268,7 +268,7 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus } // No chance to reach the required success numbers. - if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses { + if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures-totalRetryLimitExceeded < minSuccesses { logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. 
Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]", minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures) return PhaseWriteToDiscoveryThenFail @@ -283,8 +283,8 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus return PhaseWriteToDiscovery } - logger.Debugf(ctx, "Array is still running [Successes: %v, PermanentFailures: %v, RetryLimitExceededFailures: %v, RetryableFailures: %v, Total: %v, MinSuccesses: %v]", - totalSuccesses, totalPermanentFailures, totalRetryLimitExceededFailures, totalRetryableFailures, totalCount, minSuccesses) + logger.Debugf(ctx, "Array is still running [Successes: %v, PermanentFailures: %v, RetryableFailures: %v, Total: %v, MinSuccesses: %v]", + totalSuccesses, totalPermanentFailures, totalRetryableFailures, totalCount, minSuccesses) return PhaseCheckingSubTaskExecutions } diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 9bf1746f7..1fee02526 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -338,15 +338,15 @@ func TestSummaryToPhase(t *testing.T) { "FailedToRetry", PhaseWriteToDiscoveryThenFail, map[core.Phase]int64{ - core.PhaseSuccess: 5, - core.PhaseRetryLimitExceededFailure: 5, + core.PhaseSuccess: 5, + core.PhaseRetryableFailure: 5, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.phase, SummaryToPhase(context.TODO(), minSuccesses, tt.summary)) + assert.Equal(t, tt.phase, SummaryToPhase(context.TODO(), minSuccesses, tt.summary, 5)) }) } } diff --git a/go/tasks/plugins/array/k8s/management.go b/go/tasks/plugins/array/k8s/management.go index a29967ac1..335eaac15 100644 --- a/go/tasks/plugins/array/k8s/management.go +++ b/go/tasks/plugins/array/k8s/management.go @@ -298,7 +298,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon 
newArrayStatus.Summary.Inc(core.Phases[phaseIdx]) } - phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) + phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary, 0) // process new state newState = newState.SetArrayStatus(*newArrayStatus) From 6052880956a3d8ad5eb35cb296e0595665406f96 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Oct 2022 17:57:56 -0700 Subject: [PATCH 07/14] update tests Signed-off-by: Kevin Su --- go/tasks/plugins/array/core/state_test.go | 25 +++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 1fee02526..218325a2b 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -286,14 +286,16 @@ func TestToArrayJob(t *testing.T) { func TestSummaryToPhase(t *testing.T) { minSuccesses := int64(10) tests := []struct { - name string - phase Phase - summary map[core.Phase]int64 + name string + phase Phase + summary map[core.Phase]int64 + totalRetryLimitExceeded int64 }{ { "FailOnTooFewTasks", PhaseWriteToDiscoveryThenFail, map[core.Phase]int64{}, + 0, }, { "ContinueOnRetryableFailures", @@ -302,6 +304,7 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseRetryableFailure: 1, core.PhaseUndefined: 9, }, + 0, }, { "FailOnToManyPermanentFailures", @@ -310,6 +313,7 @@ func TestSummaryToPhase(t *testing.T) { core.PhasePermanentFailure: 1, core.PhaseUndefined: 9, }, + 0, }, { "CheckWaitingForResources", @@ -318,6 +322,7 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseWaitingForResources: 1, core.PhaseUndefined: 9, }, + 0, }, { "WaitForAllSubtasksToComplete", @@ -326,6 +331,7 @@ func TestSummaryToPhase(t *testing.T) { 
core.PhaseUndefined: 1, core.PhaseSuccess: 9, }, + 0, }, { "SuccessfullyCompleted", @@ -333,6 +339,7 @@ func TestSummaryToPhase(t *testing.T) { map[core.Phase]int64{ core.PhaseSuccess: 10, }, + 0, }, { "FailedToRetry", @@ -341,12 +348,22 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseSuccess: 5, core.PhaseRetryableFailure: 5, }, + 5, + }, + { + "Retrying", + PhaseCheckingSubTaskExecutions, + map[core.Phase]int64{ + core.PhaseSuccess: 5, + core.PhaseRetryableFailure: 5, + }, + 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.phase, SummaryToPhase(context.TODO(), minSuccesses, tt.summary, 5)) + assert.Equal(t, tt.phase, SummaryToPhase(context.TODO(), minSuccesses, tt.summary, tt.totalRetryLimitExceeded)) }) } } From a8ab8f6f002ca548bd17c0e8e0d1c08735f237da Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 20 Oct 2022 15:29:04 -0700 Subject: [PATCH 08/14] lint Signed-off-by: Kevin Su --- go/tasks/plugins/array/awsbatch/monitor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 437903845..74e29af43 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -111,7 +111,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } if subJob.Status.Phase == core.PhaseRetryableFailure && retryLimit == int64(len(subJob.Attempts)) { - totalRetryLimitExceeded += 1 + totalRetryLimitExceeded++ } } else if subJob.Status.Phase.IsSuccess() { actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx) From 56f96f3b25c5f0a77267bdf25f8af657d11914f1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 17 Nov 2022 10:03:48 -0800 Subject: [PATCH 09/14] wip Signed-off-by: Kevin Su --- go/tasks/plugins/webapi/databricks/config.go | 71 +++++ .../plugins/webapi/databricks/config_test.go | 18 ++ 
.../webapi/databricks/integration_test.go | 107 +++++++ go/tasks/plugins/webapi/databricks/plugin.go | 290 ++++++++++++++++++ .../plugins/webapi/databricks/plugin_test.go | 122 ++++++++ go/tasks/plugins/webapi/snowflake/config.go | 2 +- 6 files changed, 609 insertions(+), 1 deletion(-) create mode 100644 go/tasks/plugins/webapi/databricks/config.go create mode 100644 go/tasks/plugins/webapi/databricks/config_test.go create mode 100644 go/tasks/plugins/webapi/databricks/integration_test.go create mode 100644 go/tasks/plugins/webapi/databricks/plugin.go create mode 100644 go/tasks/plugins/webapi/databricks/plugin_test.go diff --git a/go/tasks/plugins/webapi/databricks/config.go b/go/tasks/plugins/webapi/databricks/config.go new file mode 100644 index 000000000..7cd0c17c2 --- /dev/null +++ b/go/tasks/plugins/webapi/databricks/config.go @@ -0,0 +1,71 @@ +package databricks + +import ( + "time" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" +) + +var ( + defaultConfig = Config{ + WebAPI: webapi.PluginConfig{ + ResourceQuotas: map[core.ResourceNamespace]int{ + "default": 1000, + }, + ReadRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + WriteRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + Caching: webapi.CachingConfig{ + Size: 500000, + ResyncInterval: config.Duration{Duration: 30 * time.Second}, + Workers: 10, + MaxSystemFailures: 5, + }, + ResourceMeta: nil, + }, + ResourceConstraints: core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: &core.ResourceConstraint{ + Value: 100, + }, + NamespaceScopeResourceConstraint: &core.ResourceConstraint{ + Value: 50, + }, + }, + DefaultCluster: "COMPUTE_CLUSTER", + TokenKey: "FLYTE_DATABRICKS_API_TOKEN", + } + + configSection = pluginsConfig.MustRegisterSubSection("databricks", 
&defaultConfig) +) + +// Config is config for 'databricks' plugin +type Config struct { + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` + + // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time + ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` + + DefaultCluster string `json:"defaultWarehouse" pflag:",Defines the default warehouse to use when running on Databricks unless overwritten by the task."` + + TokenKey string `json:"databricksTokenKey" pflag:",Name of the key where to find Databricks token in the secret manager."` + + // databricksEndpoint overrides databricks instance endpoint, only for testing + databricksEndpoint string +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/go/tasks/plugins/webapi/databricks/config_test.go b/go/tasks/plugins/webapi/databricks/config_test.go new file mode 100644 index 000000000..46cee89e2 --- /dev/null +++ b/go/tasks/plugins/webapi/databricks/config_test.go @@ -0,0 +1,18 @@ +package databricks + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetAndSetConfig(t *testing.T) { + cfg := defaultConfig + cfg.DefaultCluster = "test-cluster" + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + assert.Equal(t, &cfg, GetConfig()) +} diff --git a/go/tasks/plugins/webapi/databricks/integration_test.go b/go/tasks/plugins/webapi/databricks/integration_test.go new file mode 100644 index 000000000..21a17962c --- /dev/null +++ 
b/go/tasks/plugins/webapi/databricks/integration_test.go @@ -0,0 +1,107 @@ +package databricks + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/flyteorg/flyteidl/clients/go/coreutils" + coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/tests" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestEndToEnd(t *testing.T) { + server := newFakeSnowflakeServer() + defer server.Close() + + iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { + return nil + } + + cfg := defaultConfig + cfg.databricksEndpoint = server.URL + cfg.DefaultCluster = "test-cluster" + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + + pluginEntry := pluginmachinery.CreateRemotePlugin(newSnowflakeJobTaskPlugin()) + plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext()) + assert.NoError(t, err) + + t.Run("SELECT 1", func(t *testing.T) { + config := make(map[string]string) + config["database"] = "my-database" + config["account"] = "snowflake" + config["schema"] = "my-schema" + config["warehouse"] = "my-warehouse" + + inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) + template := flyteIdlCore.TaskTemplate{ + Type: "snowflake", + Config: config, + Target: &coreIdl.TaskTemplate_Sql{Sql: &coreIdl.Sql{Statement: "SELECT 1", Dialect: coreIdl.Sql_ANSI}}, + } + + 
phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) + + assert.Equal(t, true, phase.Phase().IsSuccess()) + }) +} + +func newFakeSnowflakeServer() *httptest.Server { + statementHandle := "019e7546-0000-278c-0000-40f10001a082" + return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if request.URL.Path == "/api/v2/statements" && request.Method == "POST" { + writer.WriteHeader(202) + bytes := []byte(fmt.Sprintf(`{ + "statementHandle": "%v", + "message": "Asynchronous execution in progress." + }`, statementHandle)) + _, _ = writer.Write(bytes) + return + } + + if request.URL.Path == "/api/v2/statements/"+statementHandle && request.Method == "GET" { + writer.WriteHeader(200) + bytes := []byte(fmt.Sprintf(`{ + "statementHandle": "%v", + "message": "Statement executed successfully." + }`, statementHandle)) + _, _ = writer.Write(bytes) + return + } + + if request.URL.Path == "/api/v2/statements/"+statementHandle+"/cancel" && request.Method == "POST" { + writer.WriteHeader(200) + return + } + + writer.WriteHeader(500) + })) +} + +func newFakeSetupContext() *pluginCoreMocks.SetupContext { + fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{} + fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil) + labeled.SetMetricKeys(contextutils.NamespaceKey) + + fakeSetupContext := pluginCoreMocks.SetupContext{} + fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) + fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar) + + return &fakeSetupContext +} diff --git a/go/tasks/plugins/webapi/databricks/plugin.go b/go/tasks/plugins/webapi/databricks/plugin.go new file mode 100644 index 000000000..a52e91293 --- /dev/null +++ b/go/tasks/plugins/webapi/databricks/plugin.go @@ -0,0 +1,290 @@ +package databricks + +import ( + "bytes" + "context" + "encoding/gob" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "time" + 
+ flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" +) + +const ( + ErrSystem errors.ErrorCode = "System" + post string = "POST" + get string = "GET" +) + +// for mocking/testing purposes, and we'll override this method +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type Plugin struct { + metricScope promutils.Scope + cfg *Config + client HTTPClient +} + +type ResourceWrapper struct { + StatusCode int + Message string +} + +type ResourceMetaWrapper struct { + QueryID string + Account string + Token string +} + +func (p Plugin) GetConfig() webapi.PluginConfig { + return GetConfig().WebAPI +} + +type QueryInfo struct { + Account string + Warehouse string + Schema string + Database string + Statement string +} + +func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + + // Resource requirements are assumed to be the same. 
+ return "default", p.cfg.ResourceConstraints, nil +} + +func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, + webapi.Resource, error) { + task, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, nil, err + } + + token, err := taskCtx.SecretManager().Get(ctx, p.cfg.TokenKey) + if err != nil { + return nil, nil, err + } + config := task.GetConfig() + + outputs, err := template.Render(ctx, []string{ + task.GetSql().Statement, + }, template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + }) + if err != nil { + return nil, nil, err + } + queryInfo := QueryInfo{ + Account: config["account"], + Warehouse: config["warehouse"], + Schema: config["schema"], + Database: config["database"], + Statement: outputs[0], + } + + if len(queryInfo.Warehouse) == 0 { + queryInfo.Warehouse = p.cfg.DefaultCluster + } + if len(queryInfo.Account) == 0 { + return nil, nil, errors.Errorf(errors2.BadTaskSpecification, "Account must not be empty.") + } + if len(queryInfo.Database) == 0 { + return nil, nil, errors.Errorf(errors2.BadTaskSpecification, "Database must not be empty.") + } + req, err := buildRequest(post, queryInfo, p.cfg.databricksEndpoint, + config["account"], token, "", false) + if err != nil { + return nil, nil, err + } + resp, err := p.client.Do(req) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + data, err := buildResponse(resp) + if err != nil { + return nil, nil, err + } + + if data["statementHandle"] == "" { + return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, + "Unable to fetch statementHandle from http response") + } + if data["message"] == "" { + return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, + "Unable to fetch message from http response") + } + queryID := fmt.Sprintf("%v", data["statementHandle"]) + message := 
fmt.Sprintf("%v", data["message"]) + + return &ResourceMetaWrapper{queryID, queryInfo.Account, token}, + &ResourceWrapper{StatusCode: resp.StatusCode, Message: message}, nil +} + +func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { + exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) + req, err := buildRequest(get, QueryInfo{}, p.cfg.databricksEndpoint, + exec.Account, exec.Token, exec.QueryID, false) + if err != nil { + return nil, err + } + resp, err := p.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + data, err := buildResponse(resp) + if err != nil { + return nil, err + } + message := fmt.Sprintf("%v", data["message"]) + return &ResourceWrapper{ + StatusCode: resp.StatusCode, + Message: message, + }, nil +} + +func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { + exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) + req, err := buildRequest(post, QueryInfo{}, p.cfg.databricksEndpoint, + exec.Account, exec.Token, exec.QueryID, true) + if err != nil { + return err + } + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + logger.Info(ctx, "Deleted query execution [%v]", resp) + + return nil +} + +func (p Plugin) Status(_ context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { + exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) + statusCode := taskCtx.Resource().(*ResourceWrapper).StatusCode + if statusCode == 0 { + return core.PhaseInfoUndefined, errors.Errorf(ErrSystem, "No Status field set.") + } + + taskInfo := createTaskInfo(exec.QueryID, exec.Account) + switch statusCode { + case http.StatusAccepted: + return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, createTaskInfo(exec.QueryID, exec.Account)), nil + case http.StatusOK: + return pluginsCore.PhaseInfoSuccess(taskInfo), nil + case http.StatusUnprocessableEntity: + return 
pluginsCore.PhaseInfoFailure(string(rune(statusCode)), "phaseReason", taskInfo), nil + } + return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", statusCode) +} + +func buildRequest(method string, queryInfo QueryInfo, snowflakeEndpoint string, account string, token string, + queryID string, isCancel bool) (*http.Request, error) { + var snowflakeURL string + // for mocking/testing purposes + if snowflakeEndpoint == "" { + snowflakeURL = "https://" + account + ".snowflakecomputing.com/api/v2/statements" + } else { + snowflakeURL = snowflakeEndpoint + "/api/v2/statements" + } + + var data []byte + if method == post && !isCancel { + snowflakeURL += "?async=true" + data = []byte(fmt.Sprintf(`{ + "statement": "%v", + "database": "%v", + "schema": "%v", + "warehouse": "%v" + }`, queryInfo.Statement, queryInfo.Database, queryInfo.Schema, queryInfo.Warehouse)) + } else { + snowflakeURL += "/" + queryID + } + if isCancel { + snowflakeURL += "/cancel" + } + + req, err := http.NewRequest(method, snowflakeURL, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+token) + req.Header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + return req, nil +} + +func buildResponse(response *http.Response) (map[string]interface{}, error) { + responseBody, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, err + } + var data map[string]interface{} + err = json.Unmarshal(responseBody, &data) + if err != nil { + return nil, err + } + return data, nil +} + +func createTaskInfo(queryID string, account string) *core.TaskInfo { + timeNow := time.Now() + + return &core.TaskInfo{ + OccurredAt: &timeNow, + Logs: []*flyteIdlCore.TaskLog{ + { + Uri: fmt.Sprintf("https://%v.snowflakecomputing.com/console#/monitoring/queries/detail?queryId=%v", + account, + 
queryID), + Name: "Snowflake Console", + }, + }, + } +} + +func newSnowflakeJobTaskPlugin() webapi.PluginEntry { + return webapi.PluginEntry{ + ID: "snowflake", + SupportedTaskTypes: []core.TaskType{"snowflake"}, + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + client: &http.Client{}, + }, nil + }, + } +} + +func init() { + gob.Register(ResourceMetaWrapper{}) + gob.Register(ResourceWrapper{}) + + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newSnowflakeJobTaskPlugin()) +} diff --git a/go/tasks/plugins/webapi/databricks/plugin_test.go b/go/tasks/plugins/webapi/databricks/plugin_test.go new file mode 100644 index 000000000..10febc17f --- /dev/null +++ b/go/tasks/plugins/webapi/databricks/plugin_test.go @@ -0,0 +1,122 @@ +package databricks + +import ( + "context" + "encoding/json" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/stretchr/testify/assert" +) + +type MockClient struct { +} + +var ( + MockDo func(req *http.Request) (*http.Response, error) +) + +func (m *MockClient) Do(req *http.Request) (*http.Response, error) { + return MockDo(req) +} + +func TestPlugin(t *testing.T) { + fakeSetupContext := pluginCoreMocks.SetupContext{} + fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) + + plugin := Plugin{ + metricScope: fakeSetupContext.MetricsScope(), + cfg: GetConfig(), + client: &MockClient{}, + } + t.Run("get config", func(t *testing.T) { + cfg := defaultConfig + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) + }) + 
t.Run("get ResourceRequirements", func(t *testing.T) { + namespace, constraints, err := plugin.ResourceRequirements(context.TODO(), nil) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.ResourceNamespace("default"), namespace) + assert.Equal(t, plugin.cfg.ResourceConstraints, constraints) + }) +} + +func TestCreateTaskInfo(t *testing.T) { + t.Run("create task info", func(t *testing.T) { + taskInfo := createTaskInfo("d5493e36", "test-account") + + assert.Equal(t, 1, len(taskInfo.Logs)) + assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.snowflakecomputing.com/console#/monitoring/queries/detail?queryId=d5493e36") + assert.Equal(t, taskInfo.Logs[0].Name, "Snowflake Console") + }) +} + +func TestBuildRequest(t *testing.T) { + account := "test-account" + token := "test-token" + queryID := "019e70eb-0000-278b-0000-40f100012b1a" + snowflakeEndpoint := "" + snowflakeURL := "https://" + account + ".snowflakecomputing.com/api/v2/statements" + t.Run("build http request for submitting a snowflake query", func(t *testing.T) { + queryInfo := QueryInfo{ + Account: account, + Warehouse: "test-warehouse", + Schema: "test-schema", + Database: "test-database", + Statement: "SELECT 1", + } + + req, err := buildRequest(post, queryInfo, snowflakeEndpoint, account, token, queryID, false) + header := http.Header{} + header.Add("Authorization", "Bearer "+token) + header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") + header.Add("Content-Type", "application/json") + header.Add("Accept", "application/json") + + assert.NoError(t, err) + assert.Equal(t, header, req.Header) + assert.Equal(t, snowflakeURL+"?async=true", req.URL.String()) + assert.Equal(t, post, req.Method) + }) + t.Run("build http request for getting a snowflake query status", func(t *testing.T) { + req, err := buildRequest(get, QueryInfo{}, snowflakeEndpoint, account, token, queryID, false) + + assert.NoError(t, err) + assert.Equal(t, snowflakeURL+"/"+queryID, req.URL.String()) + assert.Equal(t, 
get, req.Method) + }) + t.Run("build http request for deleting a snowflake query", func(t *testing.T) { + req, err := buildRequest(post, QueryInfo{}, snowflakeEndpoint, account, token, queryID, true) + + assert.NoError(t, err) + assert.Equal(t, snowflakeURL+"/"+queryID+"/cancel", req.URL.String()) + assert.Equal(t, post, req.Method) + }) +} + +func TestBuildResponse(t *testing.T) { + t.Run("build http response", func(t *testing.T) { + bodyStr := `{"statementHandle":"019c06a4-0000", "message":"Statement executed successfully."}` + responseBody := ioutil.NopCloser(strings.NewReader(bodyStr)) + response := &http.Response{Body: responseBody} + actualData, err := buildResponse(response) + assert.NoError(t, err) + + bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr)) + assert.NoError(t, err) + var expectedData map[string]interface{} + err = json.Unmarshal(bodyByte, &expectedData) + assert.NoError(t, err) + assert.Equal(t, expectedData, actualData) + }) +} diff --git a/go/tasks/plugins/webapi/snowflake/config.go b/go/tasks/plugins/webapi/snowflake/config.go index 93160f663..4a6647e8c 100644 --- a/go/tasks/plugins/webapi/snowflake/config.go +++ b/go/tasks/plugins/webapi/snowflake/config.go @@ -48,7 +48,7 @@ var ( // Config is config for 'snowflake' plugin type Config struct { - // WeCreateTaskInfobAPI defines config for the base WebAPI plugin + // WebAPI defines config for the base WebAPI plugin WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time From 8e6189825df79644c79139d59779c26157b61f3d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 17 Nov 2022 10:54:26 -0800 Subject: [PATCH 10/14] udpate Signed-off-by: Kevin Su --- go/tasks/plugins/array/awsbatch/monitor.go | 5 ++--- .../plugins/array/awsbatch/monitor_test.go | 6 +++--- go/tasks/plugins/array/core/state.go | 4 ++-- 
go/tasks/plugins/array/core/state_test.go | 19 +++++-------------- go/tasks/plugins/array/k8s/management.go | 2 +- 5 files changed, 13 insertions(+), 23 deletions(-) diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 74e29af43..c4366e831 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -69,7 +69,6 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } queued := 0 - totalRetryLimitExceeded := 0 for childIdx, subJob := range job.SubJobs { actualPhase := subJob.Status.Phase originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) @@ -111,7 +110,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata } if subJob.Status.Phase == core.PhaseRetryableFailure && retryLimit == int64(len(subJob.Attempts)) { - totalRetryLimitExceeded++ + actualPhase = core.PhasePermanentFailure } } else if subJob.Status.Phase.IsSuccess() { actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx) @@ -131,7 +130,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata parentState = parentState.SetArrayStatus(newArrayStatus) // Based on the summary produced above, deduce the overall phase of the task. 
- phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary, int64(totalRetryLimitExceeded)) + phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) if phase != arrayCore.PhaseCheckingSubTaskExecutions { metrics.SubTasksSucceeded.Add(ctx, float64(newArrayStatus.Summary[core.PhaseSuccess])) diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index b5a5925be..e9cedb90b 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -236,7 +236,7 @@ func TestCheckSubTasksState(t *testing.T) { newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ State: &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail, ExecutionArraySize: 2, OriginalArraySize: 2, OriginalMinSuccesses: 2, @@ -248,10 +248,10 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 1) assert.NoError(t, err) p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p) }) } diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index b74a5f6b3..9a281c311 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -238,7 +238,7 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl return phaseInfo, nil } -func SummaryToPhase(ctx 
context.Context, minSuccesses int64, summary arraystatus.ArraySummary, totalRetryLimitExceeded int64) Phase { +func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus.ArraySummary) Phase { totalCount := int64(0) totalSuccesses := int64(0) totalPermanentFailures := int64(0) @@ -268,7 +268,7 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus } // No chance to reach the required success numbers. - if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures-totalRetryLimitExceeded < minSuccesses { + if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses { logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]", minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures) return PhaseWriteToDiscoveryThenFail diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 218325a2b..d7e9dfeed 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -286,16 +286,14 @@ func TestToArrayJob(t *testing.T) { func TestSummaryToPhase(t *testing.T) { minSuccesses := int64(10) tests := []struct { - name string - phase Phase - summary map[core.Phase]int64 - totalRetryLimitExceeded int64 + name string + phase Phase + summary map[core.Phase]int64 }{ { "FailOnTooFewTasks", PhaseWriteToDiscoveryThenFail, map[core.Phase]int64{}, - 0, }, { "ContinueOnRetryableFailures", @@ -304,7 +302,6 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseRetryableFailure: 1, core.PhaseUndefined: 9, }, - 0, }, { "FailOnToManyPermanentFailures", @@ -313,7 +310,6 @@ func TestSummaryToPhase(t *testing.T) { core.PhasePermanentFailure: 1, core.PhaseUndefined: 9, }, - 0, }, { "CheckWaitingForResources", @@ -322,7 +318,6 @@ func 
TestSummaryToPhase(t *testing.T) { core.PhaseWaitingForResources: 1, core.PhaseUndefined: 9, }, - 0, }, { "WaitForAllSubtasksToComplete", @@ -331,7 +326,6 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseUndefined: 1, core.PhaseSuccess: 9, }, - 0, }, { "SuccessfullyCompleted", @@ -339,16 +333,14 @@ func TestSummaryToPhase(t *testing.T) { map[core.Phase]int64{ core.PhaseSuccess: 10, }, - 0, }, { "FailedToRetry", PhaseWriteToDiscoveryThenFail, map[core.Phase]int64{ core.PhaseSuccess: 5, - core.PhaseRetryableFailure: 5, + core.PhasePermanentFailure: 5, }, - 5, }, { "Retrying", @@ -357,13 +349,12 @@ func TestSummaryToPhase(t *testing.T) { core.PhaseSuccess: 5, core.PhaseRetryableFailure: 5, }, - 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.phase, SummaryToPhase(context.TODO(), minSuccesses, tt.summary, tt.totalRetryLimitExceeded)) + assert.Equal(t, tt.phase, SummaryToPhase(context.TODO(), minSuccesses, tt.summary)) }) } } diff --git a/go/tasks/plugins/array/k8s/management.go b/go/tasks/plugins/array/k8s/management.go index 335eaac15..a29967ac1 100644 --- a/go/tasks/plugins/array/k8s/management.go +++ b/go/tasks/plugins/array/k8s/management.go @@ -298,7 +298,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon newArrayStatus.Summary.Inc(core.Phases[phaseIdx]) } - phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary, 0) + phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) // process new state newState = newState.SetArrayStatus(*newArrayStatus) From d70bce4c7bf764d0694e920c90209a7b561602bb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 29 Nov 2022 14:43:34 -0800 Subject: [PATCH 11/14] address comment Signed-off-by: Kevin Su --- 
go/tasks/plugins/array/awsbatch/executor.go | 14 +---------- go/tasks/plugins/array/awsbatch/monitor.go | 27 +++++++++++++++------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/go/tasks/plugins/array/awsbatch/executor.go b/go/tasks/plugins/array/awsbatch/executor.go index 045ebdb57..e5172d31f 100644 --- a/go/tasks/plugins/array/awsbatch/executor.go +++ b/go/tasks/plugins/array/awsbatch/executor.go @@ -80,19 +80,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics) case arrayCore.PhaseCheckingSubTaskExecutions: - // Check that the taskTemplate is valid - var taskTemplate *idlCore.TaskTemplate - taskTemplate, err = tCtx.TaskReader().Read(ctx) - if err != nil { - return core.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template") - } else if taskTemplate == nil { - return core.UnknownTransition, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") - } - retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), pluginConfig.MinRetries, pluginConfig.MaxRetries) - - pluginState, err = CheckSubTasksState(ctx, tCtx.TaskExecutionMetadata(), - tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), - e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics, *retry.Attempts) + pluginState, err = CheckSubTasksState(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics) case arrayCore.PhaseAssembleFinalOutput: pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State) diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index c4366e831..df9716507 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -3,9 +3,9 @@ package awsbatch 
import ( "context" - core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" + "github.com/flyteorg/flyteplugins/go/tasks/errors" + core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" @@ -34,19 +34,32 @@ func createSubJobList(count int) []*Job { return res } -func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata, outputPrefix, baseOutputSandbox storage.DataReference, jobStore *JobStore, - dataStore *storage.DataStore, cfg *config.Config, currentState *State, metrics ExecutorMetrics, retryLimit int64) (newState *State, err error) { +func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, jobStore *JobStore, + cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) { newState = currentState parentState := currentState.State - jobName := taskMeta.GetTaskExecutionID().GetGeneratedName() + jobName := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() job := jobStore.Get(jobName) + outputPrefix := tCtx.OutputWriter().GetOutputPrefixPath() + baseOutputSandbox := tCtx.OutputWriter().GetRawOutputPrefix() + dataStore := tCtx.DataStore() + // Check that the taskTemplate is valid + var taskTemplate *core2.TaskTemplate + taskTemplate, err = tCtx.TaskReader().Read(ctx) + if err != nil { + return nil, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template") + } else if taskTemplate == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + } + retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries) + // If job isn't currently being monitored (recovering from a restart?), add it to the sync-cache and return if job 
== nil { logger.Info(ctx, "Job not found in cache, adding it. [%v]", jobName) _, err = jobStore.GetOrCreate(jobName, &Job{ ID: *currentState.ExternalJobID, - OwnerReference: taskMeta.GetOwnerID(), + OwnerReference: tCtx.TaskExecutionMetadata().GetOwnerID(), SubJobs: createSubJobList(currentState.GetExecutionArraySize()), }) @@ -109,7 +122,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata msg.Collect(childIdx, "Job failed") } - if subJob.Status.Phase == core.PhaseRetryableFailure && retryLimit == int64(len(subJob.Attempts)) { + if subJob.Status.Phase == core.PhaseRetryableFailure && *retry.Attempts == int64(len(subJob.Attempts)) { actualPhase = core.PhasePermanentFailure } } else if subJob.Status.Phase.IsSuccess() { From 9ab3dfb0c24f895fd1eba2c80cdebe498e78cbac Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 29 Nov 2022 14:45:14 -0800 Subject: [PATCH 12/14] nit Signed-off-by: Kevin Su --- go/tasks/plugins/array/awsbatch/monitor.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index df9716507..4f09b911f 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -3,9 +3,9 @@ package awsbatch import ( "context" - "github.com/flyteorg/flyteplugins/go/tasks/errors" - core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" From c294ab3b6f033fdcdeb220e9a7266b3cb9c3214c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 30 Nov 2022 12:08:45 -0800 Subject: [PATCH 13/14] fix tests Signed-off-by: Kevin Su --- .../plugins/array/awsbatch/monitor_test.go | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git 
a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index e9cedb90b..a9d708377 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -3,6 +3,8 @@ package awsbatch import ( "testing" + "github.com/stretchr/testify/mock" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -11,6 +13,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" + flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/aws/aws-sdk-go/aws/request" @@ -19,6 +22,7 @@ import ( arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config" batchMocks "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/mocks" "github.com/flyteorg/flytestdlib/utils" @@ -35,15 +39,39 @@ func init() { func TestCheckSubTasksState(t *testing.T) { ctx := context.Background() + tCtx := &mocks.TaskExecutionContext{} tID := &mocks.TaskExecutionID{} tID.OnGetGeneratedName().Return("generated-name") - tMeta := &mocks.TaskExecutionMetadata{} tMeta.OnGetOwnerID().Return(types.NamespacedName{ Namespace: "domain", Name: "name", }) tMeta.OnGetTaskExecutionID().Return(tID) + inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + outputWriter := &ioMocks.OutputWriter{} + outputWriter.OnGetOutputPrefixPath().Return("") + outputWriter.OnGetRawOutputPrefix().Return("") + + taskReader := &mocks.TaskReader{} + task := &flyteIdl.TaskTemplate{ + Type: "test", + Target: &flyteIdl.TaskTemplate_Container{ + Container: &flyteIdl.Container{ + 
Command: []string{"command"}, + Args: []string{"{{.Input}}"}, + }, + }, + Metadata: &flyteIdl.TaskMetadata{Retries: &flyteIdl.RetryStrategy{Retries: 3}}, + } + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx.OnOutputWriter().Return(outputWriter) + tCtx.OnTaskReader().Return(taskReader) + tCtx.OnDataStore().Return(inMemDatastore) + tCtx.OnTaskExecutionMetadata().Return(tMeta) t.Run("Not in cache", func(t *testing.T) { mBatchClient := batchMocks.NewMockAwsBatchClient() @@ -52,7 +80,7 @@ func TestCheckSubTasksState(t *testing.T) { utils.NewRateLimiter("", 10, 20)) jobStore := newJobsStore(t, batchClient) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -61,7 +89,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -98,7 +126,7 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -107,7 +135,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -133,13 +161,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := 
storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 1, @@ -153,7 +178,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -181,13 +206,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 2, @@ -201,7 +223,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 3) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() @@ -228,13 +250,10 @@ func TestCheckSubTasksState(t *testing.T) { assert.NoError(t, err) - inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - retryAttemptsArray, err := 
bitarray.NewCompactArray(2, bitarray.Item(1)) assert.NoError(t, err) - newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ + newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail, ExecutionArraySize: 2, @@ -248,7 +267,7 @@ func TestCheckSubTasksState(t *testing.T) { }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", - }, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 1) + }, getAwsBatchExecutorMetrics(promutils.NewTestScope())) assert.NoError(t, err) p, _ := newState.GetPhase() From 6809ea4cd33a7c4801517655bd4574105492d21e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 30 Nov 2022 12:10:18 -0800 Subject: [PATCH 14/14] nit Signed-off-by: Kevin Su --- go/tasks/plugins/webapi/databricks/config.go | 71 ----- .../plugins/webapi/databricks/config_test.go | 18 -- .../webapi/databricks/integration_test.go | 107 ------- go/tasks/plugins/webapi/databricks/plugin.go | 290 ------------------ .../plugins/webapi/databricks/plugin_test.go | 122 -------- go/tasks/plugins/webapi/snowflake/config.go | 2 +- 6 files changed, 1 insertion(+), 609 deletions(-) delete mode 100644 go/tasks/plugins/webapi/databricks/config.go delete mode 100644 go/tasks/plugins/webapi/databricks/config_test.go delete mode 100644 go/tasks/plugins/webapi/databricks/integration_test.go delete mode 100644 go/tasks/plugins/webapi/databricks/plugin.go delete mode 100644 go/tasks/plugins/webapi/databricks/plugin_test.go diff --git a/go/tasks/plugins/webapi/databricks/config.go b/go/tasks/plugins/webapi/databricks/config.go deleted file mode 100644 index 7cd0c17c2..000000000 --- a/go/tasks/plugins/webapi/databricks/config.go +++ /dev/null @@ -1,71 +0,0 @@ -package databricks - -import ( - "time" - - pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - 
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" - "github.com/flyteorg/flytestdlib/config" -) - -var ( - defaultConfig = Config{ - WebAPI: webapi.PluginConfig{ - ResourceQuotas: map[core.ResourceNamespace]int{ - "default": 1000, - }, - ReadRateLimiter: webapi.RateLimiterConfig{ - Burst: 100, - QPS: 10, - }, - WriteRateLimiter: webapi.RateLimiterConfig{ - Burst: 100, - QPS: 10, - }, - Caching: webapi.CachingConfig{ - Size: 500000, - ResyncInterval: config.Duration{Duration: 30 * time.Second}, - Workers: 10, - MaxSystemFailures: 5, - }, - ResourceMeta: nil, - }, - ResourceConstraints: core.ResourceConstraintsSpec{ - ProjectScopeResourceConstraint: &core.ResourceConstraint{ - Value: 100, - }, - NamespaceScopeResourceConstraint: &core.ResourceConstraint{ - Value: 50, - }, - }, - DefaultCluster: "COMPUTE_CLUSTER", - TokenKey: "FLYTE_DATABRICKS_API_TOKEN", - } - - configSection = pluginsConfig.MustRegisterSubSection("databricks", &defaultConfig) -) - -// Config is config for 'databricks' plugin -type Config struct { - // WebAPI defines config for the base WebAPI plugin - WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` - - // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time - ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` - - DefaultCluster string `json:"defaultWarehouse" pflag:",Defines the default warehouse to use when running on Databricks unless overwritten by the task."` - - TokenKey string `json:"databricksTokenKey" pflag:",Name of the key where to find Databricks token in the secret manager."` - - // databricksEndpoint overrides databricks instance endpoint, only for testing - databricksEndpoint string -} - -func GetConfig() *Config { - return configSection.GetConfig().(*Config) -} - 
-func SetConfig(cfg *Config) error { - return configSection.SetConfig(cfg) -} diff --git a/go/tasks/plugins/webapi/databricks/config_test.go b/go/tasks/plugins/webapi/databricks/config_test.go deleted file mode 100644 index 46cee89e2..000000000 --- a/go/tasks/plugins/webapi/databricks/config_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package databricks - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestGetAndSetConfig(t *testing.T) { - cfg := defaultConfig - cfg.DefaultCluster = "test-cluster" - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - err := SetConfig(&cfg) - assert.NoError(t, err) - assert.Equal(t, &cfg, GetConfig()) -} diff --git a/go/tasks/plugins/webapi/databricks/integration_test.go b/go/tasks/plugins/webapi/databricks/integration_test.go deleted file mode 100644 index 21a17962c..000000000 --- a/go/tasks/plugins/webapi/databricks/integration_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package databricks - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/flyteorg/flyteidl/clients/go/coreutils" - coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flyteplugins/tests" - "github.com/flyteorg/flytestdlib/contextutils" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/promutils/labeled" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestEndToEnd(t *testing.T) { - server := newFakeSnowflakeServer() - defer server.Close() - - iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { - return nil - } - - cfg := 
defaultConfig - cfg.databricksEndpoint = server.URL - cfg.DefaultCluster = "test-cluster" - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - err := SetConfig(&cfg) - assert.NoError(t, err) - - pluginEntry := pluginmachinery.CreateRemotePlugin(newSnowflakeJobTaskPlugin()) - plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext()) - assert.NoError(t, err) - - t.Run("SELECT 1", func(t *testing.T) { - config := make(map[string]string) - config["database"] = "my-database" - config["account"] = "snowflake" - config["schema"] = "my-schema" - config["warehouse"] = "my-warehouse" - - inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) - template := flyteIdlCore.TaskTemplate{ - Type: "snowflake", - Config: config, - Target: &coreIdl.TaskTemplate_Sql{Sql: &coreIdl.Sql{Statement: "SELECT 1", Dialect: coreIdl.Sql_ANSI}}, - } - - phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) - - assert.Equal(t, true, phase.Phase().IsSuccess()) - }) -} - -func newFakeSnowflakeServer() *httptest.Server { - statementHandle := "019e7546-0000-278c-0000-40f10001a082" - return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - if request.URL.Path == "/api/v2/statements" && request.Method == "POST" { - writer.WriteHeader(202) - bytes := []byte(fmt.Sprintf(`{ - "statementHandle": "%v", - "message": "Asynchronous execution in progress." - }`, statementHandle)) - _, _ = writer.Write(bytes) - return - } - - if request.URL.Path == "/api/v2/statements/"+statementHandle && request.Method == "GET" { - writer.WriteHeader(200) - bytes := []byte(fmt.Sprintf(`{ - "statementHandle": "%v", - "message": "Statement executed successfully." 
- }`, statementHandle)) - _, _ = writer.Write(bytes) - return - } - - if request.URL.Path == "/api/v2/statements/"+statementHandle+"/cancel" && request.Method == "POST" { - writer.WriteHeader(200) - return - } - - writer.WriteHeader(500) - })) -} - -func newFakeSetupContext() *pluginCoreMocks.SetupContext { - fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{} - fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil) - labeled.SetMetricKeys(contextutils.NamespaceKey) - - fakeSetupContext := pluginCoreMocks.SetupContext{} - fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) - fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar) - - return &fakeSetupContext -} diff --git a/go/tasks/plugins/webapi/databricks/plugin.go b/go/tasks/plugins/webapi/databricks/plugin.go deleted file mode 100644 index a52e91293..000000000 --- a/go/tasks/plugins/webapi/databricks/plugin.go +++ /dev/null @@ -1,290 +0,0 @@ -package databricks - -import ( - "bytes" - "context" - "encoding/gob" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "time" - - flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" - - "github.com/flyteorg/flytestdlib/promutils" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" -) - -const ( - ErrSystem errors.ErrorCode = "System" - post string = "POST" - get string = "GET" -) - -// for mocking/testing purposes, and we'll override this method -type 
HTTPClient interface { - Do(req *http.Request) (*http.Response, error) -} - -type Plugin struct { - metricScope promutils.Scope - cfg *Config - client HTTPClient -} - -type ResourceWrapper struct { - StatusCode int - Message string -} - -type ResourceMetaWrapper struct { - QueryID string - Account string - Token string -} - -func (p Plugin) GetConfig() webapi.PluginConfig { - return GetConfig().WebAPI -} - -type QueryInfo struct { - Account string - Warehouse string - Schema string - Database string - Statement string -} - -func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( - namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { - - // Resource requirements are assumed to be the same. - return "default", p.cfg.ResourceConstraints, nil -} - -func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, - webapi.Resource, error) { - task, err := taskCtx.TaskReader().Read(ctx) - if err != nil { - return nil, nil, err - } - - token, err := taskCtx.SecretManager().Get(ctx, p.cfg.TokenKey) - if err != nil { - return nil, nil, err - } - config := task.GetConfig() - - outputs, err := template.Render(ctx, []string{ - task.GetSql().Statement, - }, template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - }) - if err != nil { - return nil, nil, err - } - queryInfo := QueryInfo{ - Account: config["account"], - Warehouse: config["warehouse"], - Schema: config["schema"], - Database: config["database"], - Statement: outputs[0], - } - - if len(queryInfo.Warehouse) == 0 { - queryInfo.Warehouse = p.cfg.DefaultCluster - } - if len(queryInfo.Account) == 0 { - return nil, nil, errors.Errorf(errors2.BadTaskSpecification, "Account must not be empty.") - } - if len(queryInfo.Database) == 0 { - return nil, nil, 
errors.Errorf(errors2.BadTaskSpecification, "Database must not be empty.") - } - req, err := buildRequest(post, queryInfo, p.cfg.databricksEndpoint, - config["account"], token, "", false) - if err != nil { - return nil, nil, err - } - resp, err := p.client.Do(req) - if err != nil { - return nil, nil, err - } - defer resp.Body.Close() - data, err := buildResponse(resp) - if err != nil { - return nil, nil, err - } - - if data["statementHandle"] == "" { - return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, - "Unable to fetch statementHandle from http response") - } - if data["message"] == "" { - return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, - "Unable to fetch message from http response") - } - queryID := fmt.Sprintf("%v", data["statementHandle"]) - message := fmt.Sprintf("%v", data["message"]) - - return &ResourceMetaWrapper{queryID, queryInfo.Account, token}, - &ResourceWrapper{StatusCode: resp.StatusCode, Message: message}, nil -} - -func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { - exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - req, err := buildRequest(get, QueryInfo{}, p.cfg.databricksEndpoint, - exec.Account, exec.Token, exec.QueryID, false) - if err != nil { - return nil, err - } - resp, err := p.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - data, err := buildResponse(resp) - if err != nil { - return nil, err - } - message := fmt.Sprintf("%v", data["message"]) - return &ResourceWrapper{ - StatusCode: resp.StatusCode, - Message: message, - }, nil -} - -func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { - exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - req, err := buildRequest(post, QueryInfo{}, p.cfg.databricksEndpoint, - exec.Account, exec.Token, exec.QueryID, true) - if err != nil { - return err - } - resp, err := p.client.Do(req) - if err != nil { - return err - } - defer 
resp.Body.Close() - logger.Info(ctx, "Deleted query execution [%v]", resp) - - return nil -} - -func (p Plugin) Status(_ context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { - exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - statusCode := taskCtx.Resource().(*ResourceWrapper).StatusCode - if statusCode == 0 { - return core.PhaseInfoUndefined, errors.Errorf(ErrSystem, "No Status field set.") - } - - taskInfo := createTaskInfo(exec.QueryID, exec.Account) - switch statusCode { - case http.StatusAccepted: - return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, createTaskInfo(exec.QueryID, exec.Account)), nil - case http.StatusOK: - return pluginsCore.PhaseInfoSuccess(taskInfo), nil - case http.StatusUnprocessableEntity: - return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), "phaseReason", taskInfo), nil - } - return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", statusCode) -} - -func buildRequest(method string, queryInfo QueryInfo, snowflakeEndpoint string, account string, token string, - queryID string, isCancel bool) (*http.Request, error) { - var snowflakeURL string - // for mocking/testing purposes - if snowflakeEndpoint == "" { - snowflakeURL = "https://" + account + ".snowflakecomputing.com/api/v2/statements" - } else { - snowflakeURL = snowflakeEndpoint + "/api/v2/statements" - } - - var data []byte - if method == post && !isCancel { - snowflakeURL += "?async=true" - data = []byte(fmt.Sprintf(`{ - "statement": "%v", - "database": "%v", - "schema": "%v", - "warehouse": "%v" - }`, queryInfo.Statement, queryInfo.Database, queryInfo.Schema, queryInfo.Warehouse)) - } else { - snowflakeURL += "/" + queryID - } - if isCancel { - snowflakeURL += "/cancel" - } - - req, err := http.NewRequest(method, snowflakeURL, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - req.Header.Add("Authorization", "Bearer "+token) - 
req.Header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - return req, nil -} - -func buildResponse(response *http.Response) (map[string]interface{}, error) { - responseBody, err := ioutil.ReadAll(response.Body) - if err != nil { - return nil, err - } - var data map[string]interface{} - err = json.Unmarshal(responseBody, &data) - if err != nil { - return nil, err - } - return data, nil -} - -func createTaskInfo(queryID string, account string) *core.TaskInfo { - timeNow := time.Now() - - return &core.TaskInfo{ - OccurredAt: &timeNow, - Logs: []*flyteIdlCore.TaskLog{ - { - Uri: fmt.Sprintf("https://%v.snowflakecomputing.com/console#/monitoring/queries/detail?queryId=%v", - account, - queryID), - Name: "Snowflake Console", - }, - }, - } -} - -func newSnowflakeJobTaskPlugin() webapi.PluginEntry { - return webapi.PluginEntry{ - ID: "snowflake", - SupportedTaskTypes: []core.TaskType{"snowflake"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - client: &http.Client{}, - }, nil - }, - } -} - -func init() { - gob.Register(ResourceMetaWrapper{}) - gob.Register(ResourceWrapper{}) - - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newSnowflakeJobTaskPlugin()) -} diff --git a/go/tasks/plugins/webapi/databricks/plugin_test.go b/go/tasks/plugins/webapi/databricks/plugin_test.go deleted file mode 100644 index 10febc17f..000000000 --- a/go/tasks/plugins/webapi/databricks/plugin_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package databricks - -import ( - "context" - "encoding/json" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - 
"github.com/flyteorg/flytestdlib/promutils" - "github.com/stretchr/testify/assert" -) - -type MockClient struct { -} - -var ( - MockDo func(req *http.Request) (*http.Response, error) -) - -func (m *MockClient) Do(req *http.Request) (*http.Response, error) { - return MockDo(req) -} - -func TestPlugin(t *testing.T) { - fakeSetupContext := pluginCoreMocks.SetupContext{} - fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) - - plugin := Plugin{ - metricScope: fakeSetupContext.MetricsScope(), - cfg: GetConfig(), - client: &MockClient{}, - } - t.Run("get config", func(t *testing.T) { - cfg := defaultConfig - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - err := SetConfig(&cfg) - assert.NoError(t, err) - assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) - }) - t.Run("get ResourceRequirements", func(t *testing.T) { - namespace, constraints, err := plugin.ResourceRequirements(context.TODO(), nil) - assert.NoError(t, err) - assert.Equal(t, pluginsCore.ResourceNamespace("default"), namespace) - assert.Equal(t, plugin.cfg.ResourceConstraints, constraints) - }) -} - -func TestCreateTaskInfo(t *testing.T) { - t.Run("create task info", func(t *testing.T) { - taskInfo := createTaskInfo("d5493e36", "test-account") - - assert.Equal(t, 1, len(taskInfo.Logs)) - assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.snowflakecomputing.com/console#/monitoring/queries/detail?queryId=d5493e36") - assert.Equal(t, taskInfo.Logs[0].Name, "Snowflake Console") - }) -} - -func TestBuildRequest(t *testing.T) { - account := "test-account" - token := "test-token" - queryID := "019e70eb-0000-278b-0000-40f100012b1a" - snowflakeEndpoint := "" - snowflakeURL := "https://" + account + ".snowflakecomputing.com/api/v2/statements" - t.Run("build http request for submitting a snowflake query", func(t *testing.T) { - queryInfo := QueryInfo{ - Account: account, - Warehouse: "test-warehouse", - Schema: "test-schema", - Database: 
"test-database", - Statement: "SELECT 1", - } - - req, err := buildRequest(post, queryInfo, snowflakeEndpoint, account, token, queryID, false) - header := http.Header{} - header.Add("Authorization", "Bearer "+token) - header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") - header.Add("Content-Type", "application/json") - header.Add("Accept", "application/json") - - assert.NoError(t, err) - assert.Equal(t, header, req.Header) - assert.Equal(t, snowflakeURL+"?async=true", req.URL.String()) - assert.Equal(t, post, req.Method) - }) - t.Run("build http request for getting a snowflake query status", func(t *testing.T) { - req, err := buildRequest(get, QueryInfo{}, snowflakeEndpoint, account, token, queryID, false) - - assert.NoError(t, err) - assert.Equal(t, snowflakeURL+"/"+queryID, req.URL.String()) - assert.Equal(t, get, req.Method) - }) - t.Run("build http request for deleting a snowflake query", func(t *testing.T) { - req, err := buildRequest(post, QueryInfo{}, snowflakeEndpoint, account, token, queryID, true) - - assert.NoError(t, err) - assert.Equal(t, snowflakeURL+"/"+queryID+"/cancel", req.URL.String()) - assert.Equal(t, post, req.Method) - }) -} - -func TestBuildResponse(t *testing.T) { - t.Run("build http response", func(t *testing.T) { - bodyStr := `{"statementHandle":"019c06a4-0000", "message":"Statement executed successfully."}` - responseBody := ioutil.NopCloser(strings.NewReader(bodyStr)) - response := &http.Response{Body: responseBody} - actualData, err := buildResponse(response) - assert.NoError(t, err) - - bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr)) - assert.NoError(t, err) - var expectedData map[string]interface{} - err = json.Unmarshal(bodyByte, &expectedData) - assert.NoError(t, err) - assert.Equal(t, expectedData, actualData) - }) -} diff --git a/go/tasks/plugins/webapi/snowflake/config.go b/go/tasks/plugins/webapi/snowflake/config.go index 4a6647e8c..93160f663 100644 --- a/go/tasks/plugins/webapi/snowflake/config.go +++ 
b/go/tasks/plugins/webapi/snowflake/config.go @@ -48,7 +48,7 @@ var ( // Config is config for 'snowflake' plugin type Config struct { - // WebAPI defines config for the base WebAPI plugin + // WebAPI defines config for the base WebAPI plugin WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time