Skip to content

Commit

Permalink
Fixing cache lookup on large fanout map tasks (flyteorg#282)
Browse files Browse the repository at this point in the history
* waiting for cache lookup to complete before initializing external resources

Signed-off-by: Daniel Rammer <[email protected]>

* added docs

Signed-off-by: Daniel Rammer <[email protected]>

* updated array job size checks

Signed-off-by: Daniel Rammer <[email protected]>

* waiting for PhasePreLaunch to InitializeExternalResources

Signed-off-by: Daniel Rammer <[email protected]>

* fixed tests

Signed-off-by: Daniel Rammer <[email protected]>

* added unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* removed duplicated test

Signed-off-by: Daniel Rammer <[email protected]>

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Aug 31, 2022
1 parent b0cf365 commit 80bfc2e
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 35 deletions.
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/catalog/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ var defaultConfig = &Config{
ReaderWorkqueueConfig: workqueue.Config{
MaxRetries: 3,
Workers: 10,
IndexCacheMaxItems: 1000,
IndexCacheMaxItems: 10000,
},
WriterWorkqueueConfig: workqueue.Config{
MaxRetries: 3,
Workers: 10,
IndexCacheMaxItems: 1000,
IndexCacheMaxItems: 10000,
},
}

Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c

switch p {
case arrayCore.PhaseStart:
pluginState.State, err = array.DetermineDiscoverability(ctx, tCtx, pluginState.State)
pluginState.State, err = array.DetermineDiscoverability(ctx, tCtx, pluginConfig.MaxArrayJobSize, pluginState.State)

case arrayCore.PhasePreLaunch:
pluginState, err = EnsureJobDefinition(ctx, tCtx, pluginConfig, e.jobStore.Client, e.jobDefinitionCache, pluginState)
Expand Down Expand Up @@ -108,16 +108,16 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
// Always attempt to augment phase with task logs.
var logLinks []*idlCore.TaskLog
var externalResources []*core.ExternalResource
switch p {
case arrayCore.PhaseStart:

if p == arrayCore.PhasePreLaunch {
externalResources, err = arrayCore.InitializeExternalResources(ctx, tCtx, pluginState.State,
func(tCtx core.TaskExecutionContext, childIndex int) string {
// subTaskIDs for the aws_batch are generated based on the job ID, therefore
// to initialize we default to an empty string which will be updated later.
return ""
},
)
default:
} else if p != arrayCore.PhaseStart {
logLinks, externalResources, err = GetTaskLinks(ctx, tCtx.TaskExecutionMetadata(), e.jobStore, pluginState)
}

Expand Down
6 changes: 0 additions & 6 deletions flyteplugins/go/tasks/plugins/array/awsbatch/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ import (
func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchClient Client, pluginConfig *config.Config,
currentState *State, metrics ExecutorMetrics) (nextState *State, err error) {
size := currentState.GetExecutionArraySize()
if int64(currentState.GetExecutionArraySize()) > pluginConfig.MaxArrayJobSize {
ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), pluginConfig.MaxArrayJobSize)
logger.Info(ctx, ee)
currentState.State = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error())
return currentState, nil
}

jobDefinition := currentState.GetJobDefinitionArn()
if len(jobDefinition) == 0 {
Expand Down
21 changes: 20 additions & 1 deletion flyteplugins/go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const AwsBatchTaskType = "aws-batch"
// which is different than their original location. To find the original index we construct an indexLookup array.
// The subtask can find its original index value in indexLookup[JOB_ARRAY_INDEX] where JOB_ARRAY_INDEX is an
// environment variable in the pod
func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContext, state *arrayCore.State) (
func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContext, maxArrayJobSize int64, state *arrayCore.State) (
*arrayCore.State, error) {

// Check that the taskTemplate is valid
Expand Down Expand Up @@ -109,6 +109,13 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literalCollection.Literals, discoveredInputName)
}

if arrayJobSize > maxArrayJobSize {
ee := fmt.Errorf("array size > max allowed. requested [%v]. allowed [%v]", arrayJobSize, maxArrayJobSize)
logger.Info(ctx, ee)
state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error())
return state, nil
}

// If the task is not discoverable, then skip data catalog work and move directly to launch
if taskTemplate.Metadata == nil || !taskTemplate.Metadata.Discoverable {
logger.Infof(ctx, "Task is not discoverable, moving to launch phase...")
Expand All @@ -122,6 +129,18 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex

// Otherwise, run the data catalog steps - create and submit work items to the catalog processor,

// check that the cache index LRU cache (for both the reader and the writer workqueue) can hold
// at least as many items as there are jobs in the array. if not, we will never complete the
// cache lookup and be stuck in an infinite loop.
cfg := catalog.GetConfig()
if int(arrayJobSize) > cfg.WriterWorkqueueConfig.IndexCacheMaxItems || int(arrayJobSize) > cfg.ReaderWorkqueueConfig.IndexCacheMaxItems {
ee := fmt.Errorf("array size > max allowed for cache lookup. requested [%v]. writer allowed [%v] reader allowed [%v]",
arrayJobSize, cfg.WriterWorkqueueConfig.IndexCacheMaxItems, cfg.ReaderWorkqueueConfig.IndexCacheMaxItems)
logger.Error(ctx, ee)
state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error())
return state, nil
}

// build output writers
outputWriters, err := ConstructOutputWriters(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), int(arrayJobSize))
if err != nil {
Expand Down
39 changes: 31 additions & 8 deletions flyteplugins/go/tasks/plugins/array/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestCatalogBitsetToLiteralCollection(t *testing.T) {
}

func runDetermineDiscoverabilityTest(t testing.TB, taskTemplate *core.TaskTemplate, future catalog.DownloadFuture,
expectedState *arrayCore.State, expectedError error) {
expectedState *arrayCore.State, maxArrayJobSize int64, expectedError error) {

ctx := context.Background()

Expand Down Expand Up @@ -126,7 +126,7 @@ func runDetermineDiscoverabilityTest(t testing.TB, taskTemplate *core.TaskTempla
CurrentPhase: arrayCore.PhaseStart,
}

got, err := DetermineDiscoverability(ctx, tCtx, state)
got, err := DetermineDiscoverability(ctx, tCtx, maxArrayJobSize, state)
if expectedError != nil {
assert.Error(t, err)
assert.True(t, errors.Is(err, expectedError))
Expand All @@ -151,7 +151,7 @@ func TestDetermineDiscoverability(t *testing.T) {
f.OnGetResponse().Return(download, nil)

t.Run("Bad Task Spec", func(t *testing.T) {
runDetermineDiscoverabilityTest(t, template, f, nil, stdErrors.Errorf(pluginErrors.BadTaskSpecification, ""))
runDetermineDiscoverabilityTest(t, template, f, nil, 0, stdErrors.Errorf(pluginErrors.BadTaskSpecification, ""))
})

template = &core.TaskTemplate{
Expand Down Expand Up @@ -186,7 +186,7 @@ func TestDetermineDiscoverability(t *testing.T) {
OriginalMinSuccesses: 1,
IndexesToCache: toCache,
Reason: "Task is not discoverable.",
}, nil)
}, 1, nil)
})

t.Run("Not discoverable", func(t *testing.T) {
Expand All @@ -200,7 +200,7 @@ func TestDetermineDiscoverability(t *testing.T) {
OriginalMinSuccesses: 1,
IndexesToCache: toCache,
Reason: "Task is not discoverable.",
}, nil)
}, 1, nil)
})

template.Metadata = &core.TaskMetadata{
Expand All @@ -221,7 +221,7 @@ func TestDetermineDiscoverability(t *testing.T) {
OriginalMinSuccesses: 1,
IndexesToCache: toCache,
Reason: "Finished cache lookup.",
}, nil)
}, 1, nil)
})

t.Run("Discoverable and cached", func(t *testing.T) {
Expand All @@ -239,7 +239,30 @@ func TestDetermineDiscoverability(t *testing.T) {
OriginalMinSuccesses: 1,
IndexesToCache: toCache,
Reason: "Finished cache lookup.",
}, nil)
}, 1, nil)
})

t.Run("DiscoveryNotYetComplete ", func(t *testing.T) {
future := &catalogMocks.DownloadFuture{}
future.OnGetResponseStatus().Return(catalog.ResponseStatusNotReady)
future.On("OnReady", mock.Anything).Return(func(_ context.Context, _ catalog.Future) {})

runDetermineDiscoverabilityTest(t, template, future, &arrayCore.State{
CurrentPhase: arrayCore.PhaseStart,
PhaseVersion: core2.DefaultPhaseVersion,
OriginalArraySize: 1,
OriginalMinSuccesses: 1,
}, 1, nil)
})

t.Run("MaxArrayJobSizeFailure", func(t *testing.T) {
runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{
CurrentPhase: arrayCore.PhasePermanentFailure,
PhaseVersion: core2.DefaultPhaseVersion,
OriginalArraySize: 1,
OriginalMinSuccesses: 1,
Reason: "array size > max allowed. requested [1]. allowed [0]",
}, 0, nil)
})
}

Expand Down Expand Up @@ -301,6 +324,6 @@ func TestDiscoverabilityTaskType1(t *testing.T) {
OriginalMinSuccesses: 2,
IndexesToCache: toCache,
Reason: "Task is not discoverable.",
}, nil)
}, 3, nil)
})
}
15 changes: 7 additions & 8 deletions flyteplugins/go/tasks/plugins/array/k8s/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,21 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c

switch p, version := pluginState.GetPhase(); p {
case arrayCore.PhaseStart:
nextState, err = array.DetermineDiscoverability(ctx, tCtx, pluginState)
if err != nil {
return core.UnknownTransition, err
}
nextState, err = array.DetermineDiscoverability(ctx, tCtx, pluginConfig.MaxArrayJobSize, pluginState)

case arrayCore.PhasePreLaunch:
nextState = pluginState.SetPhase(arrayCore.PhaseLaunch, core.DefaultPhaseVersion).SetReason("Nothing to do in PreLaunch phase.")

// we wait for PhasePreLaunch to InitializeExternalResources because then the array job
// configuration has been validated and all of the metadata necessary to report subtask
// status (i.e. cache hit, etc.) is available.
externalResources, err = arrayCore.InitializeExternalResources(ctx, tCtx, pluginState,
func(tCtx core.TaskExecutionContext, childIndex int) string {
subTaskExecutionID := NewSubTaskExecutionID(tCtx.TaskExecutionMetadata().GetTaskExecutionID(), childIndex, 0)
return subTaskExecutionID.GetGeneratedName()
},
)

case arrayCore.PhasePreLaunch:
nextState = pluginState.SetPhase(arrayCore.PhaseLaunch, core.DefaultPhaseVersion).SetReason("Nothing to do in PreLaunch phase.")
err = nil

case arrayCore.PhaseWaitingForResources:
fallthrough

Expand Down
6 changes: 0 additions & 6 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ func deallocateResource(ctx context.Context, tCtx core.TaskExecutionContext, con
func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient,
config *Config, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, currentState *arrayCore.State) (
newState *arrayCore.State, externalResources []*core.ExternalResource, err error) {
if int64(currentState.GetExecutionArraySize()) > config.MaxArrayJobSize {
ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), config.MaxArrayJobSize)
logger.Info(ctx, ee)
currentState = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error())
return currentState, externalResources, nil
}

newState = currentState
messageCollector := errorcollector.NewErrorMessageCollector()
Expand Down

0 comments on commit 80bfc2e

Please sign in to comment.