diff --git a/boilerplate/lyft/golang_support_tools/tools.go b/boilerplate/lyft/golang_support_tools/tools.go index 4310b39d7..88ff64523 100644 --- a/boilerplate/lyft/golang_support_tools/tools.go +++ b/boilerplate/lyft/golang_support_tools/tools.go @@ -3,8 +3,8 @@ package tools import ( + _ "github.com/alvaroloes/enumer" _ "github.com/golangci/golangci-lint/cmd/golangci-lint" _ "github.com/lyft/flytestdlib/cli/pflags" _ "github.com/vektra/mockery/cmd/mockery" - _ "github.com/alvaroloes/enumer" ) diff --git a/go.mod b/go.mod index 5aa663fe7..08e30c84d 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/go-test/deep v1.0.5 github.com/gogo/protobuf v1.3.1 github.com/golang/protobuf v1.3.5 - github.com/googleapis/gnostic v0.4.1 // indirect github.com/hashicorp/golang-lru v0.5.4 github.com/kubeflow/pytorch-operator v0.6.0 github.com/kubeflow/tf-operator v0.5.3 @@ -27,19 +26,16 @@ require ( github.com/spf13/cobra v0.0.6 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.6.1 - go.opencensus.io v0.22.3 // indirect golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb google.golang.org/grpc v1.28.0 gopkg.in/yaml.v2 v2.2.8 - gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c // indirect gotest.tools v2.2.0+incompatible k8s.io/api v0.17.3 k8s.io/apimachinery v0.17.3 k8s.io/client-go v11.0.1-0.20190918222721-c0e3722d5cf0+incompatible k8s.io/klog v1.0.0 sigs.k8s.io/controller-runtime v0.5.1 - sigs.k8s.io/yaml v1.2.0 // indirect ) // Pin the version of client-go to something that's compatible with katrogan's fork of api and apimachinery diff --git a/go.sum b/go.sum index 05cac1653..58256bfa5 100644 --- a/go.sum +++ b/go.sum @@ -156,6 +156,7 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607 h1:cTavhURetDkezJCvxFggiyLeP40Mrk/TtVg2+ycw1Es= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607/go.mod h1:Cg4fM0vhYWOZdgM7RIOSTRNIc8/VT7CXClC3Ni86lu4= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= @@ -164,6 +165,7 @@ github.com/evanphx/json-patch v4.5.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLi github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.8-0.20191012010759-4bf2d1fec783/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -382,6 +384,10 @@ github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/ github.com/lyft/flyteplugins v0.5.1/go.mod h1:8zhqFG9BzbHNQGEXzGYltTJLD+KTmQZkanxXgeFI25c= 
github.com/lyft/flytepropeller v0.4.2/go.mod h1:TIiWv/ZP1KOI0mqeUbiMqSn2XuY8O8kn8fQc5tWcaLA= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= +github.com/lyft/flytestdlib v0.3.2 h1:bY6Y+Fg6Jdc7zY4GAYuR7t2hjWwynIdmRvtLcRNaGnw= +github.com/lyft/flytestdlib v0.3.2/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= +github.com/lyft/flytestdlib v0.3.3 h1:MkWXPkwQinh6MR3Yf5siZhmRSt9r4YmsF+5kvVVVedE= +github.com/lyft/flytestdlib v0.3.3/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.9 h1:NaKp9xkeWWwhVvqTOcR/FqlASy1N2gu/kN7PVe4S7YI= github.com/lyft/flytestdlib v0.3.9/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/spark-on-k8s-operator v0.1.3 h1:rmke8lR2Oy8mvKXRhloKuEu7fgGuXepDxiBNiorVUFI= diff --git a/go/tasks/aws/config.go b/go/tasks/aws/config.go index 5b69c8c76..4010db6e2 100644 --- a/go/tasks/aws/config.go +++ b/go/tasks/aws/config.go @@ -7,8 +7,9 @@ package aws import ( "time" - pluginsConfig "github.com/lyft/flyteplugins/go/tasks/config" "github.com/lyft/flytestdlib/config" + + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/config" ) //go:generate pflags Config --default-var defaultConfig diff --git a/go/tasks/pluginmachinery/core/mocks/resource_negotiator.go b/go/tasks/pluginmachinery/core/mocks/resource_negotiator.go index b0d5ad227..9ead9b529 100644 --- a/go/tasks/pluginmachinery/core/mocks/resource_negotiator.go +++ b/go/tasks/pluginmachinery/core/mocks/resource_negotiator.go @@ -2,9 +2,12 @@ package mocks -import context "context" -import core "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" -import mock "github.com/stretchr/testify/mock" +import ( + context "context" + + core "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + mock "github.com/stretchr/testify/mock" +) // ResourceRegistrar is an autogenerated mock type for the ResourceRegistrar type type ResourceNegotiator struct { diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index 855aa633a..d5649e8b7 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -60,6 +60,10 @@ func (p Phase) IsSuccess() bool { return p == PhaseSuccess } +func (p Phase) IsWaitingForResources() bool { + return p == PhaseWaitingForResources +} + type TaskInfo struct { // log information for the task execution Logs []*core.TaskLog diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index e4c20ecbd..15849390b 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -226,6 +226,7 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus totalSuccesses := int64(0) totalFailures := int64(0) totalRunning := int64(0) + totalWaitingForResources := int64(0) for phase, count := range summary { totalCount += count if phase.IsTerminal() { @@ -238,6 +239,8 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus // TODO: preferable to auto-combine to array tasks for now. totalFailures += count } + } else if phase.IsWaitingForResources() { + totalWaitingForResources += count } else { totalRunning += count } @@ -249,12 +252,16 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus } // No chance to reach the required success numbers. 
-	if totalRunning+totalSuccesses < minSuccesses {
-		logger.Infof(ctx, "Array failed early because totalRunning[%v] + totalSuccesses[%v] < minSuccesses[%v]",
-			totalRunning, totalSuccesses, minSuccesses)
+	if totalRunning+totalSuccesses+totalWaitingForResources < minSuccesses {
+		logger.Infof(ctx, "Array failed early because totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResources[%v] < minSuccesses[%v]",
+			totalRunning, totalSuccesses, totalWaitingForResources, minSuccesses)
 		return PhaseWriteToDiscoveryThenFail
 	}
 
+	if totalWaitingForResources > 0 {
+		logger.Infof(ctx, "Array is still running and waiting for resources totalWaitingForResources[%v]", totalWaitingForResources)
+		return PhaseWaitingForResources
+	}
 	if totalSuccesses >= minSuccesses && totalRunning == 0 {
 		logger.Infof(ctx, "Array succeeded because totalSuccesses[%v] >= minSuccesses[%v]", totalSuccesses, minSuccesses)
 		return PhaseWriteToDiscovery
diff --git a/go/tasks/plugins/array/k8s/config.go b/go/tasks/plugins/array/k8s/config.go
index 33a428b90..e5d1ea04a 100644
--- a/go/tasks/plugins/array/k8s/config.go
+++ b/go/tasks/plugins/array/k8s/config.go
@@ -5,10 +5,16 @@ package k8s
 import (
+	"fmt"
+	"io/ioutil"
+
+	"github.com/pkg/errors"
 	v1 "k8s.io/api/core/v1"
+	restclient "k8s.io/client-go/rest"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	pluginsConfig "github.com/lyft/flyteplugins/go/tasks/config"
 	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue"
-	"github.com/lyft/flytestdlib/config"
 )
 
 //go:generate pflags Config --default-var=defaultConfig
@@ -31,14 +37,80 @@ var (
 		},
 	}
 
-	configSection = config.MustRegisterSection(configSectionKey, defaultConfig)
+	configSection = pluginsConfig.MustRegisterSubSection(configSectionKey, defaultConfig)
 )
 
+type ResourceConfig struct {
+	PrimaryLabel string `json:"primaryLabel" pflag:",PrimaryLabel of a given service cluster"`
+	Limit        int    `json:"limit" pflag:",Resource quota (in the number of outstanding requests) for the cluster"`
+}
+
+type ClusterConfig struct {
+	Name     string `json:"name" pflag:",Friendly name of the remote cluster"`
+	Endpoint string `json:"endpoint" pflag:",Remote K8s cluster endpoint"`
+	Auth     Auth   `json:"auth" pflag:"-,Auth setting for the cluster"`
+	Enabled  bool   `json:"enabled" pflag:",Boolean flag to enable or disable"`
+}
+
+type Auth struct {
+	Type      string `json:"type" pflag:",Authentication type"`
+	TokenPath string `json:"tokenPath" pflag:",Token path"`
+	CertPath  string `json:"certPath" pflag:",Certificate path"`
+}
+
+func (auth Auth) GetCA() ([]byte, error) {
+	cert, err := ioutil.ReadFile(auth.CertPath)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to read k8s CA cert from configured path")
+	}
+	return cert, nil
+}
+
+func (auth Auth) GetToken() (string, error) {
+	token, err := ioutil.ReadFile(auth.TokenPath)
+	if err != nil {
+		return "", errors.Wrap(err, "failed to read k8s bearer token from configured path")
+	}
+	return string(token), nil
+}
+
+// TODO: Move this logic to flytestdlib.
+// RemoteClusterConfig reads secret values from the paths specified in the config and initializes a
+// Kubernetes rest client Config for the remote cluster.
+func RemoteClusterConfig(host string, auth Auth) (*restclient.Config, error) {
+	tokenString, err := auth.GetToken()
+	if err != nil {
+		return nil, errors.New(fmt.Sprintf("Failed to get auth token: %+v", err))
+	}
+
+	caCert, err := auth.GetCA()
+	if err != nil {
+		return nil, errors.New(fmt.Sprintf("Failed to get auth CA: %+v", err))
+	}
+
+	tlsClientConfig := restclient.TLSClientConfig{}
+	tlsClientConfig.CAData = caCert
+	return &restclient.Config{
+		Host:            host,
+		TLSClientConfig: tlsClientConfig,
+		BearerToken:     tokenString,
+	}, nil
+}
+
+func GetK8sClient(config ClusterConfig) (client.Client, error) {
+	kubeConf, err := RemoteClusterConfig(config.Endpoint, config.Auth)
+	if err != nil {
+		return nil, err
+	}
+	return client.New(kubeConf, client.Options{})
+}
+
 // Defines custom config for K8s Array plugin
 type Config struct {
 	DefaultScheduler string `json:"scheduler" pflag:",Decides the scheduler to use when launching array-pods."`
-	MaxErrorStringLength int `json:"maxErrLength" pflag:",Determines the maximum length of the error string returned for the array."`
+	MaxErrorStringLength int `json:"maxErrorLength" pflag:",Determines the maximum length of the error string returned for the array."`
 	MaxArrayJobSize int64 `json:"maxArrayJobSize" pflag:",Maximum size of array job."`
+	ResourceConfig ResourceConfig `json:"resourceConfig" pflag:"-,ResourceConfiguration to limit number of resources used by k8s-array."`
+	RemoteClusterConfig ClusterConfig `json:"remoteClusterConfig" pflag:"-,Configuration of remote K8s cluster for array jobs"`
 	NodeSelector map[string]string `json:"node-selector" pflag:"-,Defines a set of node selector labels to add to the pod."`
 	Tolerations []v1.Toleration `json:"tolerations" pflag:"-,Tolerations to be applied for k8s-array pods"`
 	OutputAssembler workqueue.Config
@@ -48,3 +120,8 @@ type Config struct {
 func GetConfig() *Config {
 	return configSection.GetConfig().(*Config)
 }
+
+func IsResourceConfigSet(resourceConfig ResourceConfig) bool {
+	emptyResourceConfig := ResourceConfig{}
+	return resourceConfig != emptyResourceConfig
+}
diff --git a/go/tasks/plugins/array/k8s/config_flags.go b/go/tasks/plugins/array/k8s/config_flags.go
index 4a03aefc9..f8b40cc45 100755
--- a/go/tasks/plugins/array/k8s/config_flags.go
+++ b/go/tasks/plugins/array/k8s/config_flags.go
@@ -44,6 +44,14 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
 	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "scheduler"), defaultConfig.DefaultScheduler, "Decides the scheduler to use when launching array-pods.")
 	cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "maxErrLength"), defaultConfig.MaxErrorStringLength, "Determines the maximum length of the error string returned for the array.")
 	cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "maxArrayJobSize"), defaultConfig.MaxArrayJobSize, "Maximum size of array job.")
+	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "resourceConfig.primaryLabel"), defaultConfig.ResourceConfig.PrimaryLabel, "PrimaryLabel of a given service cluster")
+	cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "resourceConfig.limit"), defaultConfig.ResourceConfig.Limit, "Resource quota (in the number of outstanding requests) for the cluster")
+	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "")
+	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, "")
+	cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.auth.type"),
defaultConfig.RemoteClusterConfig.Auth.Type, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.auth.tokenPath"), defaultConfig.RemoteClusterConfig.Auth.TokenPath, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.auth.certPath"), defaultConfig.RemoteClusterConfig.Auth.CertPath, "") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, "") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "OutputAssembler.workers"), defaultConfig.OutputAssembler.Workers, "Number of concurrent workers to start processing the queue.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "OutputAssembler.maxRetries"), defaultConfig.OutputAssembler.MaxRetries, "Maximum number of retries per item.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "OutputAssembler.maxItems"), defaultConfig.OutputAssembler.IndexCacheMaxItems, "Maximum number of entries to keep in the index.") diff --git a/go/tasks/plugins/array/k8s/config_flags_test.go b/go/tasks/plugins/array/k8s/config_flags_test.go index df7b41909..95c7350b0 100755 --- a/go/tasks/plugins/array/k8s/config_flags_test.go +++ b/go/tasks/plugins/array/k8s/config_flags_test.go @@ -165,6 +165,182 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_resourceConfig.primaryLabel", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("resourceConfig.primaryLabel"); err == nil { + assert.Equal(t, string(defaultConfig.ResourceConfig.PrimaryLabel), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("resourceConfig.primaryLabel", testValue) + if vString, err := cmdFlags.GetString("resourceConfig.primaryLabel"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ResourceConfig.PrimaryLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_resourceConfig.limit", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("resourceConfig.limit"); err == nil { + assert.Equal(t, int(defaultConfig.ResourceConfig.Limit), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("resourceConfig.limit", testValue) + if vInt, err := cmdFlags.GetInt("resourceConfig.limit"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ResourceConfig.Limit) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.name", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("remoteClusterConfig.name"); err == nil { + assert.Equal(t, string(defaultConfig.RemoteClusterConfig.Name), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.name", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := 
cmdFlags.GetString("remoteClusterConfig.endpoint"); err == nil { + assert.Equal(t, string(defaultConfig.RemoteClusterConfig.Endpoint), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.endpoint", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.auth.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("remoteClusterConfig.auth.type"); err == nil { + assert.Equal(t, string(defaultConfig.RemoteClusterConfig.Auth.Type), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.auth.type", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.auth.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Auth.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.auth.tokenPath", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("remoteClusterConfig.auth.tokenPath"); err == nil { + assert.Equal(t, string(defaultConfig.RemoteClusterConfig.Auth.TokenPath), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.auth.tokenPath", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.auth.tokenPath"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Auth.TokenPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.auth.certPath", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("remoteClusterConfig.auth.certPath"); err == nil { + assert.Equal(t, string(defaultConfig.RemoteClusterConfig.Auth.CertPath), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.auth.certPath", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.auth.certPath"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Auth.CertPath) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("remoteClusterConfig.enabled"); err == nil { + assert.Equal(t, bool(defaultConfig.RemoteClusterConfig.Enabled), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.enabled", testValue) + if vBool, err := cmdFlags.GetBool("remoteClusterConfig.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.RemoteClusterConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) 
t.Run("Test_OutputAssembler.workers", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly diff --git a/go/tasks/plugins/array/k8s/executor.go b/go/tasks/plugins/array/k8s/executor.go index de8b46973..eea808221 100644 --- a/go/tasks/plugins/array/k8s/executor.go +++ b/go/tasks/plugins/array/k8s/executor.go @@ -3,10 +3,14 @@ package k8s import ( "context" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flyteplugins/go/tasks/plugins/array" arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" @@ -25,6 +29,24 @@ type Executor struct { errorAssembler array.OutputAssembler } +type KubeClientObj struct { + client client.Client +} + +func (k KubeClientObj) GetClient() client.Client { + return k.client +} + +func (k KubeClientObj) GetCache() cache.Cache { + return nil +} + +func NewKubeClientObj(c client.Client) core.KubeClient { + return &KubeClientObj{ + client: c, + } +} + func NewExecutor(kubeClient core.KubeClient, cfg *Config, scope promutils.Scope) (Executor, error) { outputAssembler, err := array.NewOutputAssembler(cfg.OutputAssembler, scope.NewSubScope("output_assembler")) if err != nil { @@ -77,11 +99,15 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c fallthrough case arrayCore.PhaseLaunch: - nextState, err = LaunchSubTasks(ctx, tCtx, e.kubeClient, pluginConfig, pluginState) + // In order to maintain backwards compatibility with the state transitions + // in the aws batch plugin. Forward to PhaseCheckingSubTasksExecutions where the launching + // is actually occurring. 
+ nextState = pluginState.SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, core.DefaultPhaseVersion).SetReason("Nothing to do in Launch phase.") + err = nil case arrayCore.PhaseCheckingSubTaskExecutions: - nextState, logLinks, err = CheckSubTasksState(ctx, tCtx, e.kubeClient, tCtx.DataStore(), - tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) + nextState, logLinks, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, + tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) case arrayCore.PhaseAssembleFinalOutput: nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, pluginState) @@ -128,8 +154,7 @@ func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx.TaskExecutionMetadata(), e.kubeClient, pluginConfig.MaxErrorStringLength, - pluginState) + return TerminateSubTasks(ctx, tCtx, e.kubeClient, pluginConfig, pluginState) } func (e Executor) Start(ctx context.Context) error { @@ -155,7 +180,18 @@ func init() { } func GetNewExecutorPlugin(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { - exec, err := NewExecutor(iCtx.KubeClient(), GetConfig(), iCtx.MetricsScope()) + var kubeClient core.KubeClient + remoteClusterConfig := GetConfig().RemoteClusterConfig + if remoteClusterConfig.Enabled { + client, err := GetK8sClient(remoteClusterConfig) + if err != nil { + return nil, err + } + kubeClient = NewKubeClientObj(client) + } else { + kubeClient = iCtx.KubeClient() + } + exec, err := NewExecutor(kubeClient, GetConfig(), iCtx.MetricsScope()) if err != nil { return nil, err } @@ -164,5 +200,15 @@ func GetNewExecutorPlugin(ctx context.Context, iCtx core.SetupContext) (core.Plu return nil, err } + resourceConfig := GetConfig().ResourceConfig + if IsResourceConfigSet(resourceConfig) { + primaryLabel := resourceConfig.PrimaryLabel + limit := resourceConfig.Limit + if err := iCtx.ResourceRegistrar().RegisterResourceQuota(ctx, core.ResourceNamespace(primaryLabel), limit); err != nil { + logger.Errorf(ctx, "Token Resource registration for [%v] failed due to error [%v]", primaryLabel, err) + return nil, err + } + } + return exec, nil } diff --git a/go/tasks/plugins/array/k8s/launcher.go b/go/tasks/plugins/array/k8s/launcher.go index 6689dddd1..8de45d4d0 100644 --- a/go/tasks/plugins/array/k8s/launcher.go +++ b/go/tasks/plugins/array/k8s/launcher.go @@ -3,24 +3,14 @@ package k8s import ( "context" "fmt" - "strconv" - "strings" "github.com/lyft/flyteplugins/go/tasks/plugins/array/errorcollector" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core" - arraystatus2 "github.com/lyft/flyteplugins/go/tasks/plugins/array/arraystatus" errors2 "github.com/lyft/flytestdlib/errors" - "github.com/lyft/flytestdlib/logger" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" - corev1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" ) @@ -68,113 +58,30 @@ func applyPodTolerations(_ context.Context, cfg *Config, pod *corev1.Pod) *corev return pod } -// Launches subtasks -func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, - config *Config, currentState *arrayCore.State) 
(newState *arrayCore.State, err error) { - - if int64(currentState.GetExecutionArraySize()) > config.MaxArrayJobSize { - ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), config.MaxArrayJobSize) - logger.Info(ctx, ee) - currentState = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error()) - return currentState, nil - } - - podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx) - if err != nil { - return currentState, errors2.Wrapf(ErrBuildPodTemplate, err, "Failed to convert task template to a pod template for task") - } - - var args []string - if len(podTemplate.Spec.Containers) > 0 { - args = append(podTemplate.Spec.Containers[0].Command, podTemplate.Spec.Containers[0].Args...) - podTemplate.Spec.Containers[0].Command = []string{} - } - - size := currentState.GetExecutionArraySize() - // TODO: Respect parallelism param - for i := 0; i < size; i++ { - pod := podTemplate.DeepCopy() - indexStr := strconv.Itoa(i) - pod.Name = formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) - pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, corev1.EnvVar{ - Name: FlyteK8sArrayIndexVarName, - Value: indexStr, - }) - - pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, arrayJobEnvVars...) - - pod.Spec.Containers[0].Args, err = utils.ReplaceTemplateCommandArgs(ctx, args, arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) - if err != nil { - return currentState, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args") - } - - pod = ApplyPodPolicies(ctx, config, pod) - pod = applyNodeSelectorLabels(ctx, config, pod) - pod = applyPodTolerations(ctx, config, pod) - - err = kubeClient.GetClient().Create(ctx, pod) - if err != nil && !k8serrors.IsAlreadyExists(err) { - if k8serrors.IsForbidden(err) { - if strings.Contains(err.Error(), "exceeded quota") { - // TODO: Quota errors are retried forever, it would be good to have support for backoff strategy. - logger.Infof(ctx, "Failed to launch job, resource quota exceeded. 
Err: %v", err) - currentState = currentState.SetPhase(arrayCore.PhaseWaitingForResources, 0).SetReason("Not enough resources to launch job.") - } else { - currentState = currentState.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to launch job.") - } - - currentState = currentState.SetReason(err.Error()) - return currentState, nil - } - - return currentState, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job") - } - } - - logger.Infof(ctx, "Successfully submitted Job(s) with Prefix:[%v], Count:[%v]", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), size) - - arrayStatus := arraystatus2.ArrayStatus{ - Summary: arraystatus2.ArraySummary{}, - Detailed: arrayCore.NewPhasesCompactArray(uint(size)), - } - - currentState.SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, 0).SetReason("Job launched.") - currentState.SetArrayStatus(arrayStatus) - - return currentState, nil -} - -func TerminateSubTasks(ctx context.Context, tMeta core.TaskExecutionMetadata, kubeClient core.KubeClient, - errsMaxLength int, currentState *arrayCore.State) error { +func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, + currentState *arrayCore.State) error { size := currentState.GetExecutionArraySize() errs := errorcollector.NewErrorMessageCollector() - for i := 0; i < size; i++ { - indexStr := strconv.Itoa(i) - podName := formatSubTaskName(ctx, tMeta.GetTaskExecutionID().GetGeneratedName(), indexStr) - pod := &corev1.Pod{ - TypeMeta: metav1.TypeMeta{ - Kind: PodKind, - APIVersion: metav1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: podName, - Namespace: tMeta.GetNamespace(), - }, + for childIdx := 0; childIdx < size; childIdx++ { + task := Task{ + ChildIdx: childIdx, + Config: config, + State: currentState, } - err := kubeClient.GetClient().Delete(ctx, pod) + err := task.Abort(ctx, tCtx, kubeClient) if err != nil { - if k8serrors.IsNotFound(err) { - continue - } - - errs.Collect(i, err.Error()) + errs.Collect(childIdx, err.Error()) + } + err = task.Finalize(ctx, tCtx, kubeClient) + if err != nil { + errs.Collect(childIdx, err.Error()) } } if errs.Length() > 0 { - return fmt.Errorf(errs.Summary(errsMaxLength)) + return fmt.Errorf(errs.Summary(config.MaxErrorStringLength)) } return nil diff --git a/go/tasks/plugins/array/k8s/monitor.go b/go/tasks/plugins/array/k8s/monitor.go index f406a5307..be948b1f6 100644 --- a/go/tasks/plugins/array/k8s/monitor.go +++ b/go/tasks/plugins/array/k8s/monitor.go @@ -6,10 +6,9 @@ import ( "strconv" "time" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/storage" - "github.com/lyft/flyteplugins/go/tasks/plugins/array" - arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core" "github.com/lyft/flytestdlib/bitarray" @@ -24,7 +23,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flytestdlib/errors" + errors2 "github.com/lyft/flytestdlib/errors" k8serrors "k8s.io/apimachinery/pkg/api/errors" "github.com/lyft/flyteplugins/go/tasks/logs" @@ -32,27 +31,47 @@ import ( ) const ( - ErrCheckPodStatus errors.ErrorCode = "CHECK_POD_FAILED" + ErrCheckPodStatus errors2.ErrorCode = "CHECK_POD_FAILED" ) -func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, - dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, currentState *arrayCore.State) ( +func 
LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, + config *Config, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, currentState *arrayCore.State) ( newState *arrayCore.State, logLinks []*idlCore.TaskLog, err error) { + if int64(currentState.GetExecutionArraySize()) > config.MaxArrayJobSize { + ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), config.MaxArrayJobSize) + logger.Info(ctx, ee) + currentState = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error()) + return currentState, logLinks, nil + } logLinks = make([]*idlCore.TaskLog, 0, 4) newState = currentState - msg := errorcollector.NewErrorMessageCollector() - newArrayStatus := arraystatus.ArrayStatus{ + newArrayStatus := &arraystatus.ArrayStatus{ Summary: arraystatus.ArraySummary{}, Detailed: arrayCore.NewPhasesCompactArray(uint(currentState.GetExecutionArraySize())), } + // If we have arrived at this state for the first time then currentState has not been + // initialized with number of sub tasks. + if len(currentState.GetArrayStatus().Detailed.GetItems()) == 0 { + currentState.ArrayStatus = *newArrayStatus + } + for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] + indexStr := strconv.Itoa(childIdx) + podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) if existingPhase.IsTerminal() { // If we get here it means we have already "processed" this terminal phase since we will only persist // the phase after all processing is done (e.g. check outputs/errors file, record events... etc.). + + // Since we know we have already "processed" this terminal phase we can safely deallocate resource + err = deallocateResource(ctx, tCtx, config, childIdx) + if err != nil { + logger.Errorf(ctx, "Error releasing allocation token [%s] in LaunchAndCheckSubTasks [%s]", podName, err) + return currentState, logLinks, errors2.Wrapf(ErrCheckPodStatus, err, "Error releasing allocation token.") + } newArrayStatus.Summary.Inc(existingPhase) newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(existingPhase)) @@ -60,37 +79,49 @@ func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kub continue } - phaseInfo, err := CheckPodStatus(ctx, kubeClient, - k8sTypes.NamespacedName{ - Name: formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), strconv.Itoa(childIdx)), - Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), - }) - if err != nil { - return currentState, logLinks, errors.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status") + task := &Task{ + LogLinks: logLinks, + State: newState, + NewArrayStatus: newArrayStatus, + Config: config, + ChildIdx: childIdx, + MessageCollector: &msg, } - if phaseInfo.Info() != nil { - logLinks = append(logLinks, phaseInfo.Info().Logs...) + // The first time we enter this state we will launch every subtask. On subsequent rounds, the pod + // has already been created so we return a Success value and continue with the Monitor step. 
+ var launchResult LaunchResult + launchResult, err = task.Launch(ctx, tCtx, kubeClient) + if err != nil { + logger.Errorf(ctx, "K8s array - Launch error %v", err) + return currentState, logLinks, err } - if phaseInfo.Err() != nil { - msg.Collect(childIdx, phaseInfo.Err().String()) + switch launchResult { + case LaunchSuccess: + // Continue with execution if successful + case LaunchError: + return currentState, logLinks, err + // If Resource manager is enabled and there are currently not enough resources we can skip this round + // for a subtask and wait until there are enough resources. + case LaunchWaiting: + continue + case LaunchReturnState: + return currentState, logLinks, nil } - actualPhase := phaseInfo.Phase() - if phaseInfo.Phase().IsSuccess() { - originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) - actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, childIdx, originalIdx) + var monitorResult MonitorResult + monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox) + + if monitorResult != MonitorSuccess { if err != nil { - return nil, nil, err + logger.Errorf(ctx, "K8s array - Monitor error %v", err) } + return currentState, logLinks, err } - - newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(actualPhase)) - newArrayStatus.Summary.Inc(actualPhase) } - newState = newState.SetArrayStatus(newArrayStatus) + newState = newState.SetArrayStatus(*newArrayStatus) // Check that the taskTemplate is valid taskTemplate, err := tCtx.TaskReader().Read(ctx) @@ -108,6 +139,7 @@ func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kub if phase == arrayCore.PhaseCheckingSubTaskExecutions { newPhaseVersion := uint32(0) + // For now, the only changes to PhaseVersion and PreviousSummary occur for running array jobs. 
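+		// Folding the summary into a single number gives a phase version that changes whenever the
+		// distribution of sub-task phases changes, so downstream observers see an update even though
+		// the top-level phase remains PhaseCheckingSubTaskExecutions.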
for phase, count := range newState.GetArrayStatus().Summary { newPhaseVersion += uint32(phase) * uint32(count) diff --git a/go/tasks/plugins/array/k8s/monitor_test.go b/go/tasks/plugins/array/k8s/monitor_test.go new file mode 100644 index 000000000..9ebba1a40 --- /dev/null +++ b/go/tasks/plugins/array/k8s/monitor_test.go @@ -0,0 +1,207 @@ +package k8s + +import ( + "testing" + + core2 "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + mocks2 "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/lyft/flyteplugins/go/tasks/plugins/array/arraystatus" + "github.com/lyft/flytestdlib/bitarray" + "github.com/stretchr/testify/mock" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + v12 "k8s.io/apimachinery/pkg/apis/meta/v1" + + arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" +) + +func createSampleContainerTask() *core2.Container { + return &core2.Container{ + Command: []string{"cmd"}, + Args: []string{"{{$inputPrefix}}"}, + Image: "img1", + } +} + +func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContext { + tr := &mocks.TaskReader{} + tr.OnRead(ctx).Return(&core2.TaskTemplate{ + Target: &core2.TaskTemplate_Container{ + Container: createSampleContainerTask(), + }, + }, nil) + + tID := &mocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("notfound") + tID.OnGetID().Return(core2.TaskExecutionIdentifier{ + TaskId: &core2.Identifier{ + ResourceType: core2.ResourceType_TASK, + Project: "a", + Domain: "d", + Name: "n", + Version: "abc", + }, + NodeExecutionId: &core2.NodeExecutionIdentifier{ + NodeId: "node1", + ExecutionId: &core2.WorkflowExecutionIdentifier{ + Project: "a", + Domain: "d", + Name: "exec", + }, + }, + RetryAttempt: 0, + }) + + overrides := &mocks.TaskOverrides{} + overrides.OnGetResources().Return(&v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("10"), + }, + }) + + tMeta := &mocks.TaskExecutionMetadata{} + tMeta.OnGetTaskExecutionID().Return(tID) + tMeta.OnGetOverrides().Return(overrides) + tMeta.OnIsInterruptible().Return(false) + tMeta.OnGetK8sServiceAccount().Return("s") + + tMeta.OnGetNamespace().Return("n") + tMeta.OnGetLabels().Return(nil) + tMeta.OnGetAnnotations().Return(nil) + tMeta.OnGetOwnerReference().Return(v12.OwnerReference{}) + + ow := &mocks2.OutputWriter{} + ow.OnGetOutputPrefixPath().Return("/prefix/") + ow.OnGetRawOutputPrefix().Return("/raw_prefix/") + + ir := &mocks2.InputReader{} + ir.OnGetInputPrefixPath().Return("/prefix/") + ir.OnGetInputPath().Return("/prefix/inputs.pb") + ir.OnGetMatch(mock.Anything).Return(&core2.LiteralMap{}, nil) + + tCtx := &mocks.TaskExecutionContext{} + tCtx.OnTaskReader().Return(tr) + tCtx.OnTaskExecutionMetadata().Return(tMeta) + tCtx.OnOutputWriter().Return(ow) + tCtx.OnInputReader().Return(ir) + return tCtx +} + +func TestCheckSubTasksState(t *testing.T) { + ctx := context.Background() + + tCtx := getMockTaskExecutionContext(ctx) + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil) + tCtx.OnResourceManager().Return(&resourceManager) + + t.Run("Happy case", func(t 
*testing.T) { + config := Config{MaxArrayJobSize: 100} + newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: 5, + OriginalArraySize: 10, + OriginalMinSuccesses: 5, + }) + + assert.Nil(t, err) + //assert.NotEmpty(t, logLinks) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", 0) + }) + + t.Run("Resource exhausted", func(t *testing.T) { + config := Config{ + MaxArrayJobSize: 100, + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: 10, + }, + } + + newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: 5, + OriginalArraySize: 10, + OriginalMinSuccesses: 5, + }) + + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseWaitingForResources.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", 5) + }) +} + +func TestCheckSubTasksStateResourceGranted(t *testing.T) { + ctx := context.Background() + + tCtx := getMockTaskExecutionContext(ctx) + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + tCtx.OnResourceManager().Return(&resourceManager) + + t.Run("Resource granted", func(t *testing.T) { + config := Config{ + MaxArrayJobSize: 100, + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: 10, + }, + } + + newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: 5, + OriginalArraySize: 10, + OriginalMinSuccesses: 5, + }) + + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", 5) + }) + + t.Run("All tasks success", func(t *testing.T) { + config := Config{ + MaxArrayJobSize: 100, + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: 10, + }, + } + + arrayStatus := &arraystatus.ArrayStatus{ + Summary: arraystatus.ArraySummary{}, + Detailed: arrayCore.NewPhasesCompactArray(uint(5)), + } + for childIdx := range arrayStatus.Detailed.GetItems() { + arrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhaseSuccess)) + + } + newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: 5, + OriginalArraySize: 10, + OriginalMinSuccesses: 5, + ArrayStatus: *arrayStatus, + }) + + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseWriteToDiscovery.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "ReleaseResource", 5) + }) +} diff --git a/go/tasks/plugins/array/k8s/task.go b/go/tasks/plugins/array/k8s/task.go new file mode 100644 index 000000000..a236b4871 --- /dev/null +++ 
b/go/tasks/plugins/array/k8s/task.go @@ -0,0 +1,231 @@ +package k8s + +import ( + "context" + "strconv" + "strings" + + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flyteplugins/go/tasks/plugins/array" + "github.com/lyft/flyteplugins/go/tasks/plugins/array/arraystatus" + arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core" + "github.com/lyft/flyteplugins/go/tasks/plugins/array/errorcollector" + "github.com/lyft/flytestdlib/bitarray" + errors2 "github.com/lyft/flytestdlib/errors" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/storage" + corev1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8sTypes "k8s.io/apimachinery/pkg/types" +) + +type Task struct { + LogLinks []*idlCore.TaskLog + State *arrayCore.State + NewArrayStatus *arraystatus.ArrayStatus + Config *Config + ChildIdx int + MessageCollector *errorcollector.ErrorMessageCollector +} + +type LaunchResult int8 +type MonitorResult int8 + +const ( + LaunchSuccess LaunchResult = iota + LaunchError + LaunchWaiting + LaunchReturnState +) + +const ( + MonitorSuccess MonitorResult = iota + MonitorError +) + +func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) (LaunchResult, error) { + podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx) + if err != nil { + return LaunchError, errors2.Wrapf(ErrBuildPodTemplate, err, "Failed to convert task template to a pod template for a task") + } + // Remove owner references for remote cluster execution + if t.Config.RemoteClusterConfig.Enabled { + podTemplate.OwnerReferences = nil + } + var args []string + if len(podTemplate.Spec.Containers) > 0 { + args = append(podTemplate.Spec.Containers[0].Command, podTemplate.Spec.Containers[0].Args...) + podTemplate.Spec.Containers[0].Command = []string{} + } else { + return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "No containers found in podSpec.") + } + + indexStr := strconv.Itoa(t.ChildIdx) + podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + + pod := podTemplate.DeepCopy() + pod.Name = podName + pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, corev1.EnvVar{ + Name: FlyteK8sArrayIndexVarName, + Value: indexStr, + }) + + pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, arrayJobEnvVars...) 
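+	// Everything below mirrors the per-pod work the removed LaunchSubTasks loop used to do: substitute
+	// the aggregated command/args for this subtask, apply pod policies/selectors/tolerations, and then
+	// gate pod creation on the resource manager so that an ungranted allocation parks the subtask in
+	// PhaseWaitingForResources instead of creating the pod.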
+ pod.Spec.Containers[0].Args, err = utils.ReplaceTemplateCommandArgs(ctx, args, arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) + if err != nil { + return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args") + } + + pod = ApplyPodPolicies(ctx, t.Config, pod) + pod = applyNodeSelectorLabels(ctx, t.Config, pod) + pod = applyPodTolerations(ctx, t.Config, pod) + + allocationStatus, err := allocateResource(ctx, tCtx, t.Config, podName) + if err != nil { + return LaunchError, err + } + if allocationStatus != core.AllocationStatusGranted { + t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(core.PhaseWaitingForResources)) + t.NewArrayStatus.Summary.Inc(core.PhaseWaitingForResources) + return LaunchWaiting, nil + } + + err = kubeClient.GetClient().Create(ctx, pod) + if err != nil && !k8serrors.IsAlreadyExists(err) { + if k8serrors.IsForbidden(err) { + if strings.Contains(err.Error(), "exceeded quota") { + // TODO: Quota errors are retried forever, it would be good to have support for backoff strategy. + logger.Infof(ctx, "Failed to launch job, resource quota exceeded. Err: %v", err) + t.State = t.State.SetPhase(arrayCore.PhaseWaitingForResources, 0).SetReason("Not enough resources to launch job") + } else { + t.State = t.State.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to launch job.") + } + + t.State = t.State.SetReason(err.Error()) + return LaunchReturnState, nil + } + + return LaunchError, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.") + } + + return LaunchSuccess, nil +} + +func (t Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (MonitorResult, error) { + indexStr := strconv.Itoa(t.ChildIdx) + podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + phaseInfo, err := CheckPodStatus(ctx, kubeClient, + k8sTypes.NamespacedName{ + Name: podName, + Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), + }) + if err != nil { + return MonitorError, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.") + } + + if phaseInfo.Info() != nil { + t.LogLinks = append(t.LogLinks, phaseInfo.Info().Logs...) 
+ } + + if phaseInfo.Err() != nil { + t.MessageCollector.Collect(t.ChildIdx, phaseInfo.Err().String()) + } + + actualPhase := phaseInfo.Phase() + if phaseInfo.Phase().IsSuccess() { + originalIdx := arrayCore.CalculateOriginalIndex(t.ChildIdx, t.State.GetIndexesToCache()) + actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, t.ChildIdx, originalIdx) + if err != nil { + return MonitorError, err + } + } + + t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(actualPhase)) + t.NewArrayStatus.Summary.Inc(actualPhase) + + return MonitorSuccess, nil +} + +func (t Task) Abort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error { + indexStr := strconv.Itoa(t.ChildIdx) + podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + Kind: PodKind, + APIVersion: metav1.SchemeGroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: podName, + Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), + }, + } + + err := kubeClient.GetClient().Delete(ctx, pod) + if err != nil { + if k8serrors.IsNotFound(err) { + + return nil + } + return err + } + + return nil + +} + +func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error { + indexStr := strconv.Itoa(t.ChildIdx) + podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + + // Deallocate Resource + err := deallocateResource(ctx, tCtx, t.Config, t.ChildIdx) + if err != nil { + logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err) + return err + } + + return nil + +} + +func allocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) (core.AllocationStatus, error) { + if !IsResourceConfigSet(config.ResourceConfig) { + return core.AllocationStatusGranted, nil + } + + resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel) + resourceConstraintSpec := core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: nil, + NamespaceScopeResourceConstraint: nil, + } + + allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, resourceNamespace, podName, resourceConstraintSpec) + if err != nil { + logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), podName, err) + return core.AllocationUndefined, err + } + + logger.Infof(ctx, "Allocation result for [%s] is [%s]", podName, allocationStatus) + return allocationStatus, nil +} + +func deallocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, childIdx int) error { + if !IsResourceConfigSet(config.ResourceConfig) { + return nil + } + indexStr := strconv.Itoa((childIdx)) + podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel) + + err := tCtx.ResourceManager().ReleaseResource(ctx, resourceNamespace, podName) + if err != nil { + logger.Errorf(ctx, "Error releasing token [%s]. error %s", podName, err) + return err + } + + return nil +}
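For reference, a minimal sketch of how the configuration surface added in this diff could be exercised. The endpoint, primary label, and secret paths are hypothetical placeholders; in a real deployment these values come from the k8s-array plugin config section rather than literals, and only GetK8sClient and IsResourceConfigSet as defined above are used.

package main

import (
	"fmt"

	arrayK8s "github.com/lyft/flyteplugins/go/tasks/plugins/array/k8s"
)

func main() {
	// Hypothetical remote-cluster settings; normally these are loaded from the plugin config file.
	clusterCfg := arrayK8s.ClusterConfig{
		Name:     "remote-target",
		Endpoint: "https://example-remote-cluster:6443",
		Enabled:  true,
		Auth: arrayK8s.Auth{
			TokenPath: "/var/run/secrets/remote/token",  // bearer token read by Auth.GetToken()
			CertPath:  "/var/run/secrets/remote/ca.crt", // CA bundle read by Auth.GetCA()
		},
	}

	// GetK8sClient builds a rest.Config from the token/CA paths and returns a controller-runtime
	// client pointed at the remote cluster; the executor uses such a client when
	// remoteClusterConfig.enabled is set.
	remoteClient, err := arrayK8s.GetK8sClient(clusterCfg)
	if err != nil {
		fmt.Println("failed to build remote client:", err)
		return
	}
	_ = remoteClient

	// Resource-manager gating of subtask launches only applies when a non-empty resourceConfig is given.
	resCfg := arrayK8s.ResourceConfig{PrimaryLabel: "array-tokens", Limit: 50}
	fmt.Println("resource quota enabled:", arrayK8s.IsResourceConfigSet(resCfg)) // prints: true
}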