From 3050380a02a88067c3f0ffe3dd308d2cea8034d3 Mon Sep 17 00:00:00 2001 From: Luis Medina <3936213+lu4nm3@users.noreply.github.com> Date: Tue, 24 Mar 2020 13:15:48 -0700 Subject: [PATCH] Presto plugin executor (#69) * Presto plugin executor * mockery * add unit tests * more unit tests * update to correct import * fix lint * linting * last linting? * minor changes * expanded comment on state machine logic * PR feedback changes * PR feedback changes 2 * PR feedback changes 3 * PR feedback changes 4 * add user to execute args * resource reg * update status * e2e chages * update sync period * add input interpolator * changes * changes 2 * prefix implicit inputs with __ * comments * more feedback * edit metrics --- go.mod | 2 +- go.sum | 4 +- .../presto/client/mocks/presto_client.go | 126 ++++ .../presto/client/noop_presto_client.go | 43 ++ .../plugins/presto/client/presto_client.go | 35 ++ .../plugins/presto/client/presto_status.go | 43 ++ go/tasks/plugins/presto/config/config.go | 76 +++ .../plugins/presto/config/config_flags.go | 52 ++ .../presto/config/config_flags_test.go | 256 ++++++++ go/tasks/plugins/presto/execution_state.go | 562 ++++++++++++++++++ .../plugins/presto/execution_state_test.go | 358 +++++++++++ go/tasks/plugins/presto/executions_cache.go | 161 +++++ .../plugins/presto/executions_cache_test.go | 91 +++ go/tasks/plugins/presto/executor.go | 156 +++++ go/tasks/plugins/presto/executor_metrics.go | 25 + go/tasks/plugins/presto/helpers_test.go | 113 ++++ 16 files changed, 2100 insertions(+), 3 deletions(-) create mode 100644 go/tasks/plugins/presto/client/mocks/presto_client.go create mode 100644 go/tasks/plugins/presto/client/noop_presto_client.go create mode 100644 go/tasks/plugins/presto/client/presto_client.go create mode 100644 go/tasks/plugins/presto/client/presto_status.go create mode 100644 go/tasks/plugins/presto/config/config.go create mode 100755 go/tasks/plugins/presto/config/config_flags.go create mode 100755 go/tasks/plugins/presto/config/config_flags_test.go create mode 100644 go/tasks/plugins/presto/execution_state.go create mode 100644 go/tasks/plugins/presto/execution_state_test.go create mode 100644 go/tasks/plugins/presto/executions_cache.go create mode 100644 go/tasks/plugins/presto/executions_cache_test.go create mode 100644 go/tasks/plugins/presto/executor.go create mode 100644 go/tasks/plugins/presto/executor_metrics.go create mode 100644 go/tasks/plugins/presto/helpers_test.go diff --git a/go.mod b/go.mod index 45b4db65e..610aba01e 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/golang/protobuf v1.3.3 github.com/googleapis/gnostic v0.4.1 // indirect github.com/hashicorp/golang-lru v0.5.4 - github.com/lyft/flyteidl v0.17.6 + github.com/lyft/flyteidl v0.17.9 github.com/lyft/flytestdlib v0.3.3 github.com/magiconair/properties v1.8.1 github.com/mitchellh/mapstructure v1.1.2 diff --git a/go.sum b/go.sum index f1738e912..ae9c3db9e 100644 --- a/go.sum +++ b/go.sum @@ -296,8 +296,8 @@ github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ= -github.com/lyft/flyteidl v0.17.6 h1:O0qpT6ya45e/92+E84uGOYa0ZsaFoE5ZfPoyJ6e1bEQ= -github.com/lyft/flyteidl v0.17.6/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= +github.com/lyft/flyteidl v0.17.9 h1:JXT9PovHqS9V3YN74x9zWT0kvIEL48c2uNoujF1KMes= +github.com/lyft/flyteidl v0.17.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flytestdlib v0.3.0 h1:nIkX4MlyYdcLLzaF35RI2P5BhARt+qMgHoFto8eVNzU= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.2 h1:bY6Y+Fg6Jdc7zY4GAYuR7t2hjWwynIdmRvtLcRNaGnw= diff --git a/go/tasks/plugins/presto/client/mocks/presto_client.go b/go/tasks/plugins/presto/client/mocks/presto_client.go new file mode 100644 index 000000000..ad732e91b --- /dev/null +++ b/go/tasks/plugins/presto/client/mocks/presto_client.go @@ -0,0 +1,126 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + client "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + + mock "github.com/stretchr/testify/mock" +) + +// PrestoClient is an autogenerated mock type for the PrestoClient type +type PrestoClient struct { + mock.Mock +} + +type PrestoClient_ExecuteCommand struct { + *mock.Call +} + +func (_m PrestoClient_ExecuteCommand) Return(_a0 client.PrestoExecuteResponse, _a1 error) *PrestoClient_ExecuteCommand { + return &PrestoClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *PrestoClient) OnExecuteCommand(ctx context.Context, commandStr string, executeArgs client.PrestoExecuteArgs) *PrestoClient_ExecuteCommand { + c := _m.On("ExecuteCommand", ctx, commandStr, executeArgs) + return &PrestoClient_ExecuteCommand{Call: c} +} + +func (_m *PrestoClient) OnExecuteCommandMatch(matchers ...interface{}) *PrestoClient_ExecuteCommand { + c := _m.On("ExecuteCommand", matchers...) + return &PrestoClient_ExecuteCommand{Call: c} +} + +// ExecuteCommand provides a mock function with given fields: ctx, commandStr, executeArgs +func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, executeArgs client.PrestoExecuteArgs) (client.PrestoExecuteResponse, error) { + ret := _m.Called(ctx, commandStr, executeArgs) + + var r0 client.PrestoExecuteResponse + if rf, ok := ret.Get(0).(func(context.Context, string, client.PrestoExecuteArgs) client.PrestoExecuteResponse); ok { + r0 = rf(ctx, commandStr, executeArgs) + } else { + r0 = ret.Get(0).(client.PrestoExecuteResponse) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, client.PrestoExecuteArgs) error); ok { + r1 = rf(ctx, commandStr, executeArgs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type PrestoClient_GetCommandStatus struct { + *mock.Call +} + +func (_m PrestoClient_GetCommandStatus) Return(_a0 client.PrestoStatus, _a1 error) *PrestoClient_GetCommandStatus { + return &PrestoClient_GetCommandStatus{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *PrestoClient) OnGetCommandStatus(ctx context.Context, commandID string) *PrestoClient_GetCommandStatus { + c := _m.On("GetCommandStatus", ctx, commandID) + return &PrestoClient_GetCommandStatus{Call: c} +} + +func (_m *PrestoClient) OnGetCommandStatusMatch(matchers ...interface{}) *PrestoClient_GetCommandStatus { + c := _m.On("GetCommandStatus", matchers...) + return &PrestoClient_GetCommandStatus{Call: c} +} + +// GetCommandStatus provides a mock function with given fields: ctx, commandID +func (_m *PrestoClient) GetCommandStatus(ctx context.Context, commandID string) (client.PrestoStatus, error) { + ret := _m.Called(ctx, commandID) + + var r0 client.PrestoStatus + if rf, ok := ret.Get(0).(func(context.Context, string) client.PrestoStatus); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Get(0).(client.PrestoStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, commandID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type PrestoClient_KillCommand struct { + *mock.Call +} + +func (_m PrestoClient_KillCommand) Return(_a0 error) *PrestoClient_KillCommand { + return &PrestoClient_KillCommand{Call: _m.Call.Return(_a0)} +} + +func (_m *PrestoClient) OnKillCommand(ctx context.Context, commandID string) *PrestoClient_KillCommand { + c := _m.On("KillCommand", ctx, commandID) + return &PrestoClient_KillCommand{Call: c} +} + +func (_m *PrestoClient) OnKillCommandMatch(matchers ...interface{}) *PrestoClient_KillCommand { + c := _m.On("KillCommand", matchers...) + return &PrestoClient_KillCommand{Call: c} +} + +// KillCommand provides a mock function with given fields: ctx, commandID +func (_m *PrestoClient) KillCommand(ctx context.Context, commandID string) error { + ret := _m.Called(ctx, commandID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/go/tasks/plugins/presto/client/noop_presto_client.go b/go/tasks/plugins/presto/client/noop_presto_client.go new file mode 100644 index 000000000..98facaf3d --- /dev/null +++ b/go/tasks/plugins/presto/client/noop_presto_client.go @@ -0,0 +1,43 @@ +package client + +import ( + "context" + "net/http" + "net/url" + + "time" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" +) + +const ( + httpRequestTimeoutSecs = 30 +) + +type noopPrestoClient struct { + client *http.Client + environment *url.URL +} + +func (p noopPrestoClient) ExecuteCommand( + ctx context.Context, + queryStr string, + executeArgs PrestoExecuteArgs) (PrestoExecuteResponse, error) { + + return PrestoExecuteResponse{}, nil +} + +func (p noopPrestoClient) KillCommand(ctx context.Context, commandID string) error { + return nil +} + +func (p noopPrestoClient) GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) { + return NewPrestoStatus(ctx, "UNKNOWN"), nil +} + +func NewNoopPrestoClient(cfg *config.Config) PrestoClient { + return &noopPrestoClient{ + client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, + environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), + } +} diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go new file mode 100644 index 000000000..fb8812e2c --- /dev/null +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -0,0 +1,35 @@ +package client + +import "context" + +type PrestoStatus string + +// Contains information needed to execute a Presto query +type PrestoExecuteArgs struct { + RoutingGroup string `json:"routingGroup,omitempty"` + Catalog string `json:"catalog,omitempty"` + Schema string `json:"schema,omitempty"` + Source string `json:"source,omitempty"` + User string `json:"user,omitempty"` +} + +// Representation of a response after submitting a query to Presto +type PrestoExecuteResponse struct { + ID string `json:"id,omitempty"` + Status PrestoStatus `json:"status,omitempty"` + NextURI string `json:"nextUri,omitempty"` +} + +//go:generate mockery -all -case=snake + +// Interface to interact with PrestoClient for Presto tasks +type PrestoClient interface { + // Submits a query to Presto + ExecuteCommand(ctx context.Context, commandStr string, executeArgs PrestoExecuteArgs) (PrestoExecuteResponse, error) + + // Cancels a currently running Presto query + KillCommand(ctx context.Context, commandID string) error + + // Gets the status of a Presto query + GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) +} diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go new file mode 100644 index 000000000..d5e772f6c --- /dev/null +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -0,0 +1,43 @@ +package client + +import ( + "context" + "strings" + + "github.com/lyft/flytestdlib/logger" +) + +// This type is meant only to encapsulate the response coming from Presto as a type, it is +// not meant to be stored locally. +const ( + PrestoStatusUnknown PrestoStatus = "UNKNOWN" + PrestoStatusWaiting PrestoStatus = "WAITING" + PrestoStatusRunning PrestoStatus = "RUNNING" + PrestoStatusFinished PrestoStatus = "FINISHED" + PrestoStatusFailed PrestoStatus = "FAILED" + PrestoStatusCancelled PrestoStatus = "CANCELLED" +) + +var PrestoStatuses = map[PrestoStatus]struct{}{ + PrestoStatusUnknown: {}, + PrestoStatusWaiting: {}, + PrestoStatusRunning: {}, + PrestoStatusFinished: {}, + PrestoStatusFailed: {}, + PrestoStatusCancelled: {}, +} + +func NewPrestoStatus(ctx context.Context, state string) PrestoStatus { + upperCased := strings.ToUpper(state) + + // Presto has different failure modes so this maps them all to a single Failure on the + // Flyte side + if strings.Contains(upperCased, "FAILED") { + return PrestoStatusFailed + } else if _, ok := PrestoStatuses[PrestoStatus(upperCased)]; ok { + return PrestoStatus(upperCased) + } else { + logger.Warnf(ctx, "Invalid Presto Status found: %v", state) + return PrestoStatusUnknown + } +} diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go new file mode 100644 index 000000000..21f2ff917 --- /dev/null +++ b/go/tasks/plugins/presto/config/config.go @@ -0,0 +1,76 @@ +package config + +//go:generate pflags Config --default-var=defaultConfig + +import ( + "context" + "net/url" + "time" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" + + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/config" +) + +const prestoConfigSectionKey = "presto" + +func URLMustParse(s string) config.URL { + r, err := url.Parse(s) + if err != nil { + logger.Panicf(context.TODO(), "Bad Presto URL Specified as default, error: %s", err) + } + if r == nil { + logger.Panicf(context.TODO(), "Nil Presto URL specified.", err) + } + return config.URL{URL: *r} +} + +type RoutingGroupConfig struct { + Name string `json:"name" pflag:",The name of a given Presto routing group"` + Limit int `json:"limit" pflag:",Resource quota (in the number of outstanding requests) of the routing group"` + ProjectScopeQuotaProportionCap float64 `json:"projectScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a project in the routing group"` + NamespaceScopeQuotaProportionCap float64 `json:"namespaceScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a namespace in the routing group"` +} + +type RefreshCacheConfig struct { + Name string `json:"name" pflag:",The name of the rate limiter"` + SyncPeriod config.Duration `json:"syncPeriod" pflag:",The duration to wait before the cache is refreshed again"` + Workers int `json:"workers" pflag:",Number of parallel workers to refresh the cache"` + LruCacheSize int `json:"lruCacheSize" pflag:",Size of the cache"` +} + +var ( + defaultConfig = Config{ + Environment: URLMustParse(""), + DefaultRoutingGroup: "adhoc", + DefaultUser: "flyte-default-user", + RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, + RefreshCacheConfig: RefreshCacheConfig{ + Name: "presto", + SyncPeriod: config.Duration{Duration: 5 * time.Second}, + Workers: 15, + LruCacheSize: 10000, + }, + } + + prestoConfigSection = pluginsConfig.MustRegisterSubSection(prestoConfigSectionKey, &defaultConfig) +) + +// Presto plugin configs +type Config struct { + Environment config.URL `json:"environment" pflag:",Environment endpoint for Presto to use"` + DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` + DefaultUser string `json:"defaultUser" pflag:",Default Presto user"` + RoutingGroupConfigs []RoutingGroupConfig `json:"routingGroupConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` + RefreshCacheConfig RefreshCacheConfig `json:"refreshCacheConfig" pflag:"Rate limiter config"` +} + +// Retrieves the current config value or default. +func GetPrestoConfig() *Config { + return prestoConfigSection.GetConfig().(*Config) +} + +func SetPrestoConfig(cfg *Config) error { + return prestoConfigSection.SetConfig(cfg) +} diff --git a/go/tasks/plugins/presto/config/config_flags.go b/go/tasks/plugins/presto/config/config_flags.go new file mode 100755 index 000000000..c4200c63c --- /dev/null +++ b/go/tasks/plugins/presto/config/config_flags.go @@ -0,0 +1,52 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "environment"), defaultConfig.Environment.String(), "Environment endpoint for Presto to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultUser"), defaultConfig.DefaultUser, "Default Presto user") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.name"), defaultConfig.RefreshCacheConfig.Name, "The name of the rate limiter") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.syncPeriod"), defaultConfig.RefreshCacheConfig.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.workers"), defaultConfig.RefreshCacheConfig.Workers, "Number of parallel workers to refresh the cache") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.lruCacheSize"), defaultConfig.RefreshCacheConfig.LruCacheSize, "Size of the cache") + return cmdFlags +} diff --git a/go/tasks/plugins/presto/config/config_flags_test.go b/go/tasks/plugins/presto/config/config_flags_test.go new file mode 100755 index 000000000..00820c7be --- /dev/null +++ b/go/tasks/plugins/presto/config/config_flags_test.go @@ -0,0 +1,256 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_environment", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("environment"); err == nil { + assert.Equal(t, string(defaultConfig.Environment.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.Environment.String() + + cmdFlags.Set("environment", testValue) + if vString, err := cmdFlags.GetString("environment"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Environment) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_defaultRoutingGroup", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("defaultRoutingGroup"); err == nil { + assert.Equal(t, string(defaultConfig.DefaultRoutingGroup), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("defaultRoutingGroup", testValue) + if vString, err := cmdFlags.GetString("defaultRoutingGroup"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultRoutingGroup) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_defaultUser", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("defaultUser"); err == nil { + assert.Equal(t, string(defaultConfig.DefaultUser), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("defaultUser", testValue) + if vString, err := cmdFlags.GetString("defaultUser"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultUser) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_refreshCacheConfig.name", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("refreshCacheConfig.name"); err == nil { + assert.Equal(t, string(defaultConfig.RefreshCacheConfig.Name), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("refreshCacheConfig.name", testValue) + if vString, err := cmdFlags.GetString("refreshCacheConfig.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RefreshCacheConfig.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_refreshCacheConfig.syncPeriod", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("refreshCacheConfig.syncPeriod"); err == nil { + assert.Equal(t, string(defaultConfig.RefreshCacheConfig.SyncPeriod.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.RefreshCacheConfig.SyncPeriod.String() + + cmdFlags.Set("refreshCacheConfig.syncPeriod", testValue) + if vString, err := cmdFlags.GetString("refreshCacheConfig.syncPeriod"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RefreshCacheConfig.SyncPeriod) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_refreshCacheConfig.workers", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.workers"); err == nil { + assert.Equal(t, int(defaultConfig.RefreshCacheConfig.Workers), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("refreshCacheConfig.workers", testValue) + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RefreshCacheConfig.Workers) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_refreshCacheConfig.lruCacheSize", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.lruCacheSize"); err == nil { + assert.Equal(t, int(defaultConfig.RefreshCacheConfig.LruCacheSize), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("refreshCacheConfig.lruCacheSize", testValue) + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.lruCacheSize"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RefreshCacheConfig.LruCacheSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go new file mode 100644 index 000000000..6217b4b21 --- /dev/null +++ b/go/tasks/plugins/presto/execution_state.go @@ -0,0 +1,562 @@ +package presto + +import ( + "context" + "strings" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + + "k8s.io/apimachinery/pkg/util/rand" + + "fmt" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + + "time" + + "github.com/lyft/flytestdlib/cache" + + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + + "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flytestdlib/logger" + + pb "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" +) + +type ExecutionPhase int + +const ( + PhaseNotStarted ExecutionPhase = iota + PhaseQueued // resource manager token gotten + PhaseSubmitted // Sent off to Presto + PhaseQuerySucceeded + PhaseQueryFailed +) + +func (p ExecutionPhase) String() string { + switch p { + case PhaseNotStarted: + return "PhaseNotStarted" + case PhaseQueued: + return "PhaseQueued" + case PhaseSubmitted: + return "PhaseSubmitted" + case PhaseQuerySucceeded: + return "PhaseQuerySucceeded" + case PhaseQueryFailed: + return "PhaseQueryFailed" + } + return "Bad Presto execution phase" +} + +type ExecutionState struct { + Phase ExecutionPhase + + // This will store the command ID from Presto + CommandID string `json:"commandId,omitempty"` + + // This will have the nextUri from Presto which is used to advance the query forward + URI string `json:"uri,omitempty"` + + // This is the current Presto query (out of 5) needed to complete a Presto task + CurrentPrestoQuery Query `json:"currentPrestoQuery,omitempty"` + + // This is an id to keep track of the current query. Every query's id should be unique for caching purposes + CurrentPrestoQueryUUID string `json:"currentPrestoQueryUUID,omitempty"` + + // Keeps track of which Presto query we are on. Its values range from 0-4 for the 5 queries that are needed + QueryCount int `json:"queryCount,omitempty"` + + // This number keeps track of the number of failures within the sync function. Without this, what happens in + // the sync function is entirely opaque. Note that this field is completely orthogonal to Flyte system/node/task + // level retries, just errors from hitting the Presto API, inside the sync loop + SyncFailureCount int `json:"syncFailureCount,omitempty"` + + // In kicking off the Presto command, this is the number of failures + CreationFailureCount int `json:"creationFailureCount,omitempty"` + + // The time the execution first requests for an allocation token + AllocationTokenRequestStartTime time.Time `json:"allocationTokenRequestStartTime,omitempty"` +} + +type Query struct { + Statement string `json:"statement,omitempty"` + ExecuteArgs client.PrestoExecuteArgs `json:"executeArgs,omitempty"` + TempTableName string `json:"tempTableName,omitempty"` + ExternalTableName string `json:"externalTableName,omitempty"` + ExternalLocation string `json:"externalLocation"` +} + +// This is the main state iteration +func HandleExecutionState( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + prestoClient client.PrestoClient, + executionsCache cache.AutoRefresh, + metrics ExecutorMetrics) (ExecutionState, error) { + + var transformError error + var newState ExecutionState + + switch currentState.Phase { + case PhaseNotStarted: + newState, transformError = GetAllocationToken(ctx, tCtx, currentState, metrics) + + case PhaseQueued: + prestoQuery, err := GetNextQuery(ctx, tCtx, currentState) + if err != nil { + return ExecutionState{}, err + } + currentState.CurrentPrestoQuery = prestoQuery + newState, transformError = KickOffQuery(ctx, tCtx, currentState, prestoClient, executionsCache) + + case PhaseSubmitted: + newState, transformError = MonitorQuery(ctx, tCtx, currentState, executionsCache) + + case PhaseQuerySucceeded: + if currentState.QueryCount < 4 { + // If there are still Presto statements to execute, increment the query count, reset the phase to 'queued' + // and continue executing the remaining statements. In this case, we won't request another allocation token + // as the 5 statements that get executed are all considered to be part of the same "query" + currentState.Phase = PhaseQueued + } else { + transformError = writeOutput(ctx, tCtx, currentState.CurrentPrestoQuery.ExternalLocation) + } + currentState.QueryCount++ + newState = currentState + + case PhaseQueryFailed: + newState = currentState + transformError = nil + } + + return newState, transformError +} + +func GetAllocationToken( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + metric ExecutorMetrics) (ExecutionState, error) { + + newState := ExecutionState{} + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + routingGroup, err := composeResourceNamespaceWithRoutingGroup(ctx, tCtx) + if err != nil { + return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when requesting allocation token %s", uniqueID) + } + + resourceConstraintsSpec := createResourceConstraintsSpec(ctx, tCtx, routingGroup) + + allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, routingGroup, uniqueID, resourceConstraintsSpec) + if err != nil { + logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, err) + return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error requesting allocation token %s", uniqueID) + } + logger.Infof(ctx, "Allocation result for [%s] is [%s]", uniqueID, allocationStatus) + + // Emitting the duration this execution has been waiting for a token allocation + if currentState.AllocationTokenRequestStartTime.IsZero() { + newState.AllocationTokenRequestStartTime = time.Now() + } else { + newState.AllocationTokenRequestStartTime = currentState.AllocationTokenRequestStartTime + } + + if allocationStatus == core.AllocationStatusGranted { + newState.Phase = PhaseQueued + } else if allocationStatus == core.AllocationStatusExhausted { + newState.Phase = PhaseNotStarted + } else if allocationStatus == core.AllocationStatusNamespaceQuotaExceeded { + newState.Phase = PhaseNotStarted + } else { + return newState, errors.Errorf(errors.ResourceManagerFailure, "Got bad allocation result [%s] for token [%s]", + allocationStatus, uniqueID) + } + + return newState, nil +} + +func composeResourceNamespaceWithRoutingGroup(ctx context.Context, tCtx core.TaskExecutionContext) (core.ResourceNamespace, error) { + routingGroup, _, _, _, err := GetQueryInfo(ctx, tCtx) + if err != nil { + return "", err + } + clusterPrimaryLabel := resolveRoutingGroup(ctx, routingGroup, config.GetPrestoConfig()) + return core.ResourceNamespace(clusterPrimaryLabel), nil +} + +// This function is the link between the output written by the SDK, and the execution side. It extracts the query +// out of the task template. +func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) (string, string, string, string, error) { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return "", "", "", "", err + } + + prestoQuery := plugins.PrestoQuery{} + if err := utils.UnmarshalStruct(taskTemplate.GetCustom(), &prestoQuery); err != nil { + return "", "", "", "", err + } + + if err := validatePrestoStatement(prestoQuery); err != nil { + return "", "", "", "", err + } + + outputs, err := utils.ReplaceTemplateCommandArgs(ctx, []string{ + prestoQuery.RoutingGroup, + prestoQuery.Catalog, + prestoQuery.Schema, + prestoQuery.Statement, + }, tCtx.InputReader(), tCtx.OutputWriter()) + if err != nil { + return "", "", "", "", err + } + + routingGroup := outputs[0] + catalog := outputs[1] + schema := outputs[2] + statement := outputs[3] + + logger.Debugf(ctx, "QueryInfo: query: [%v], routingGroup: [%v], catalog: [%v], schema: [%v]", statement, routingGroup, catalog, schema) + return routingGroup, catalog, schema, statement, err +} + +func validatePrestoStatement(prestoJob plugins.PrestoQuery) error { + if prestoJob.Statement == "" { + return errors.Errorf(errors.BadTaskSpecification, + "Query could not be found. Please ensure that you are at least on Flytekit version 0.3.0 or later.") + } + return nil +} + +func resolveRoutingGroup(ctx context.Context, routingGroup string, prestoCfg *config.Config) string { + if routingGroup == "" { + logger.Debugf(ctx, "Input routing group is an empty string; falling back to using the default routing group [%v]", prestoCfg.DefaultRoutingGroup) + return prestoCfg.DefaultRoutingGroup + } + + for _, routingGroupCfg := range prestoCfg.RoutingGroupConfigs { + if routingGroup == routingGroupCfg.Name { + logger.Debugf(ctx, "Found the Presto routing group: [%v]", routingGroupCfg.Name) + return routingGroup + } + } + + logger.Debugf(ctx, "Cannot find the routing group [%v] in configmap; "+ + "falling back to using the default routing group [%v]", routingGroup, prestoCfg.DefaultRoutingGroup) + return prestoCfg.DefaultRoutingGroup +} + +func createResourceConstraintsSpec(ctx context.Context, _ core.TaskExecutionContext, routingGroup core.ResourceNamespace) core.ResourceConstraintsSpec { + cfg := config.GetPrestoConfig() + constraintsSpec := core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: nil, + NamespaceScopeResourceConstraint: nil, + } + if cfg.RoutingGroupConfigs == nil { + logger.Infof(ctx, "No routing group config is found. Returning an empty resource constraints spec") + return constraintsSpec + } + for _, routingGroupCfg := range cfg.RoutingGroupConfigs { + if routingGroupCfg.Name == string(routingGroup) { + constraintsSpec.ProjectScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(routingGroupCfg.Limit) * routingGroupCfg.ProjectScopeQuotaProportionCap)} + constraintsSpec.NamespaceScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(routingGroupCfg.Limit) * routingGroupCfg.NamespaceScopeQuotaProportionCap)} + break + } + } + logger.Infof(ctx, "Created a resource constraints spec: [%v]", constraintsSpec) + return constraintsSpec +} + +func GetNextQuery( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState) (Query, error) { + + switch currentState.QueryCount { + case 0: + prestoCfg := config.GetPrestoConfig() + tempTableName := rand.String(32) + routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, tCtx) + if err != nil { + return Query{}, err + } + + statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables."%s_temp" AS %s`, tempTableName, statement) + + prestoQuery := Query{ + Statement: statement, + ExecuteArgs: client.PrestoExecuteArgs{ + RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), + Catalog: catalog, + Schema: schema, + Source: "flyte", + User: getUser(ctx, prestoCfg.DefaultUser), + }, + TempTableName: tempTableName + "_temp", + ExternalTableName: tempTableName + "_external", + } + + return prestoQuery, nil + + case 1: + // TODO + externalLocation := getExternalLocation("s3://lyft-modelbuilder/{}/", 2) + statement := fmt.Sprintf(` +CREATE TABLE hive.flyte_temporary_tables."%s" (LIKE hive.flyte_temporary_tables."%s") +WITH (format = 'PARQUET', external_location = '%s')`, + currentState.CurrentPrestoQuery.ExternalTableName, + currentState.CurrentPrestoQuery.TempTableName, + externalLocation, + ) + currentState.CurrentPrestoQuery.Statement = statement + currentState.CurrentPrestoQuery.ExternalLocation = externalLocation + return currentState.CurrentPrestoQuery, nil + + case 2: + statement := ` +INSERT INTO hive.flyte_temporary_tables."%s" +SELECT * +FROM hive.flyte_temporary_tables."%s"` + statement = fmt.Sprintf(statement, currentState.CurrentPrestoQuery.ExternalTableName, currentState.CurrentPrestoQuery.TempTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + case 3: + statement := fmt.Sprintf(`DROP TABLE hive.flyte_temporary_tables."%s"`, currentState.CurrentPrestoQuery.TempTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + case 4: + statement := fmt.Sprintf(`DROP TABLE hive.flyte_temporary_tables."%s"`, currentState.CurrentPrestoQuery.ExternalTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + default: + return currentState.CurrentPrestoQuery, nil + } +} + +func getExternalLocation(shardFormatter string, shardLength int) string { + shardCount := strings.Count(shardFormatter, "{}") + for i := 0; i < shardCount; i++ { + shardFormatter = strings.Replace(shardFormatter, "{}", rand.String(shardLength), 1) + } + + return shardFormatter + rand.String(32) + "/" +} + +func getUser(ctx context.Context, defaultUser string) string { + principalContextUser := ctx.Value("principal") + if principalContextUser != nil { + return fmt.Sprintf("%v", principalContextUser) + } + return defaultUser +} + +func KickOffQuery( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + prestoClient client.PrestoClient, + cache cache.AutoRefresh) (ExecutionState, error) { + + // For the caching id, we can't rely simply on the task execution id since we have to run 5 consecutive queries and + // the ids used for each of these has to be unique. Because of this, we append a random postfix to the task + // execution id. + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + "_" + rand.String(32) + + statement := currentState.CurrentPrestoQuery.Statement + executeArgs := currentState.CurrentPrestoQuery.ExecuteArgs + + response, err := prestoClient.ExecuteCommand(ctx, statement, executeArgs) + if err != nil { + // If we failed, we'll keep the NotStarted state + currentState.CreationFailureCount = currentState.CreationFailureCount + 1 + logger.Warnf(ctx, "Error creating Presto query for %s, failure counts %d. Error: %s", uniqueID, currentState.CreationFailureCount, err) + } else { + // If we succeed, then store the command id returned from Presto, and update our state. Also, add to the + // AutoRefreshCache so we start getting updates for its status. + commandID := response.ID + logger.Infof(ctx, "Created Presto ID [%s] for token %s", commandID, uniqueID) + currentState.CommandID = commandID + currentState.Phase = PhaseSubmitted + currentState.URI = response.NextURI + currentState.CurrentPrestoQueryUUID = uniqueID + + executionStateCacheItem := ExecutionStateCacheItem{ + ExecutionState: currentState, + Identifier: uniqueID, + } + + // The first time we put it in the cache, we know it won't have succeeded so we don't need to look at it + _, err := cache.GetOrCreate(uniqueID, executionStateCacheItem) + if err != nil { + // This means that our cache has fundamentally broken... return a system error + logger.Errorf(ctx, "Cache failed to GetOrCreate for execution [%s] cache key [%s], owner [%s]. Error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, + tCtx.TaskExecutionMetadata().GetOwnerReference(), err) + return currentState, err + } + } + + return currentState, nil +} + +func MonitorQuery( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + cache cache.AutoRefresh) (ExecutionState, error) { + + uniqueQueryID := currentState.CurrentPrestoQueryUUID + executionStateCacheItem := ExecutionStateCacheItem{ + ExecutionState: currentState, + Identifier: uniqueQueryID, + } + + cachedItem, err := cache.GetOrCreate(uniqueQueryID, executionStateCacheItem) + if err != nil { + // This means that our cache has fundamentally broken... return a system error + logger.Errorf(ctx, "Cache is broken on execution [%s] cache key [%s], owner [%s]. Error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueQueryID, + tCtx.TaskExecutionMetadata().GetOwnerReference(), err) + return currentState, errors.Wrapf(errors.CacheFailed, err, "Error when GetOrCreate while monitoring") + } + + cachedExecutionState, ok := cachedItem.(ExecutionStateCacheItem) + if !ok { + logger.Errorf(ctx, "Error casting cache object into ExecutionState") + return currentState, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", cachedItem) + } + + // If there were updates made to the state, we'll have picked them up automatically. Nothing more to do. + return cachedExecutionState.ExecutionState, nil +} + +func writeOutput(ctx context.Context, tCtx core.TaskExecutionContext, externalLocation string) error { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return err + } + + results := taskTemplate.Interface.Outputs.Variables["results"] + + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + &pb.LiteralMap{ + Literals: map[string]*pb.Literal{ + "results": { + Value: &pb.Literal_Scalar{ + Scalar: &pb.Scalar{Value: &pb.Scalar_Schema{ + Schema: &pb.Schema{ + Uri: externalLocation, + Type: results.GetType().GetSchema(), + }, + }, + }, + }, + }, + }, + }, nil)) +} + +// The 'PhaseInfoRunning' occurs 15 times (3 for each of the 5 Presto queries that get run for every Presto task) which +// are differentiated by the version (1-15) +func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { + var phaseInfo core.PhaseInfo + t := time.Now() + + switch state.Phase { + case PhaseNotStarted: + phaseInfo = core.PhaseInfoNotReady(t, core.DefaultPhaseVersion, "Haven't received allocation token") + case PhaseQueued: + if state.CreationFailureCount > 5 { + phaseInfo = core.PhaseInfoRetryableFailure("PrestoFailure", "Too many creation attempts", nil) + } else { + phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+1), ConstructTaskInfo(state)) + } + case PhaseSubmitted: + phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+2), ConstructTaskInfo(state)) + case PhaseQuerySucceeded: + if state.QueryCount < 5 { + phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+3), ConstructTaskInfo(state)) + } else { + phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) + } + case PhaseQueryFailed: + phaseInfo = core.PhaseInfoFailure(errors.DownstreamSystemError, "Query failed", ConstructTaskInfo(state)) + } + + return phaseInfo +} + +func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { + logs := make([]*idlCore.TaskLog, 0, 1) + t := time.Now() + if e.CommandID != "" { + logs = append(logs, ConstructTaskLog(e)) + return &core.TaskInfo{ + Logs: logs, + OccurredAt: &t, + } + } + + return nil +} + +func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { + return &idlCore.TaskLog{ + Name: fmt.Sprintf("Status: %s [%s]", e.Phase, e.CommandID), + MessageFormat: idlCore.TaskLog_UNKNOWN, + Uri: e.URI, + } +} + +func Abort(ctx context.Context, currentState ExecutionState, client client.PrestoClient) error { + // Cancel Presto query if non-terminal state + if !InTerminalState(currentState) && currentState.CommandID != "" { + err := client.KillCommand(ctx, currentState.CommandID) + if err != nil { + logger.Errorf(ctx, "Error terminating Presto command in Finalize [%s]", err) + return err + } + } + return nil +} + +func Finalize(ctx context.Context, tCtx core.TaskExecutionContext, _ ExecutionState) error { + // Release allocation token + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + routingGroup, err := composeResourceNamespaceWithRoutingGroup(ctx, tCtx) + if err != nil { + return errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when releasing allocation token %s", uniqueID) + } + + err = tCtx.ResourceManager().ReleaseResource(ctx, routingGroup, uniqueID) + + if err != nil { + logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", uniqueID, err) + return err + } + return nil +} + +func InTerminalState(e ExecutionState) bool { + return e.Phase == PhaseQuerySucceeded || e.Phase == PhaseQueryFailed +} + +func IsNotYetSubmitted(e ExecutionState) bool { + if e.Phase == PhaseNotStarted || e.Phase == PhaseQueued { + return true + } + return false +} diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go new file mode 100644 index 000000000..2f0facb68 --- /dev/null +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -0,0 +1,358 @@ +package presto + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client/mocks" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + mocks2 "github.com/lyft/flytestdlib/cache/mocks" + stdConfig "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" +) + +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} + +func TestInTerminalState(t *testing.T) { + var stateTests = []struct { + phase ExecutionPhase + isTerminal bool + }{ + {phase: PhaseNotStarted, isTerminal: false}, + {phase: PhaseQueued, isTerminal: false}, + {phase: PhaseSubmitted, isTerminal: false}, + {phase: PhaseQuerySucceeded, isTerminal: true}, + {phase: PhaseQueryFailed, isTerminal: true}, + } + + for _, tt := range stateTests { + t.Run(tt.phase.String(), func(t *testing.T) { + e := ExecutionState{Phase: tt.phase} + res := InTerminalState(e) + assert.Equal(t, tt.isTerminal, res) + }) + } +} + +func TestIsNotYetSubmitted(t *testing.T) { + var stateTests = []struct { + phase ExecutionPhase + isNotYetSubmitted bool + }{ + {phase: PhaseNotStarted, isNotYetSubmitted: true}, + {phase: PhaseQueued, isNotYetSubmitted: true}, + {phase: PhaseSubmitted, isNotYetSubmitted: false}, + {phase: PhaseQuerySucceeded, isNotYetSubmitted: false}, + {phase: PhaseQueryFailed, isNotYetSubmitted: false}, + } + + for _, tt := range stateTests { + t.Run(tt.phase.String(), func(t *testing.T) { + e := ExecutionState{Phase: tt.phase} + res := IsNotYetSubmitted(e) + assert.Equal(t, tt.isNotYetSubmitted, res) + }) + } +} + +func TestValidatePrestoStatement(t *testing.T) { + prestoQuery := plugins.PrestoQuery{ + RoutingGroup: "adhoc", + Catalog: "hive", + Schema: "city", + Statement: "", + } + err := validatePrestoStatement(prestoQuery) + assert.Error(t, err) +} + +func TestConstructTaskLog(t *testing.T) { + expected := "https://prestoproxy-internal.lyft.net:443" + u, err := url.Parse(expected) + assert.NoError(t, err) + taskLog := ConstructTaskLog(ExecutionState{CommandID: "123", URI: u.String()}) + assert.Equal(t, expected, taskLog.Uri) +} + +func TestConstructTaskInfo(t *testing.T) { + empty := ConstructTaskInfo(ExecutionState{}) + assert.Nil(t, empty) + + expected := "https://prestoproxy-internal.lyft.net:443" + u, err := url.Parse(expected) + assert.NoError(t, err) + + e := ExecutionState{ + Phase: PhaseQuerySucceeded, + CommandID: "123", + SyncFailureCount: 0, + URI: u.String(), + } + + taskInfo := ConstructTaskInfo(e) + assert.Equal(t, "https://prestoproxy-internal.lyft.net:443", taskInfo.Logs[0].Uri) +} + +func TestMapExecutionStateToPhaseInfo(t *testing.T) { + t.Run("NotStarted", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseNotStarted, + } + phaseInfo := MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseNotReady, phaseInfo.Phase()) + }) + + t.Run("Queued", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseQueued, + CreationFailureCount: 0, + } + phaseInfo := MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) + + e = ExecutionState{ + Phase: PhaseQueued, + CreationFailureCount: 100, + } + phaseInfo = MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseRetryableFailure, phaseInfo.Phase()) + + }) + + t.Run("Submitted", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseSubmitted, + } + phaseInfo := MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) + }) +} + +func TestGetAllocationToken(t *testing.T) { + ctx := context.Background() + + t.Run("allocation granted", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusGranted, nil) + + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, PhaseQueued, state.Phase) + }) + + t.Run("exhausted", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusExhausted, nil) + + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, PhaseNotStarted, state.Phase) + }) + + t.Run("namespace exhausted", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusNamespaceQuotaExceeded, nil) + + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, PhaseNotStarted, state.Phase) + }) + + t.Run("Request start time, if empty in current state, should be set", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusNamespaceQuotaExceeded, nil) + + mockCurrentState := ExecutionState{} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, state.AllocationTokenRequestStartTime.IsZero(), false) + }) + + t.Run("Request start time, if already set in current state, should be maintained", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusGranted, nil) + + startTime := time.Now() + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: startTime} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, state.AllocationTokenRequestStartTime.IsZero(), false) + assert.Equal(t, state.AllocationTokenRequestStartTime, startTime) + }) +} + +func TestAbort(t *testing.T) { + ctx := context.Background() + + t.Run("Terminate called when not in terminal state", func(t *testing.T) { + var x = false + + mockPresto := &prestoMocks.PrestoClient{} + mockPresto.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { + return commandId == "123456" + }), mock.Anything).Run(func(_ mock.Arguments) { + x = true + }).Return(nil) + + err := Abort(ctx, ExecutionState{Phase: PhaseSubmitted, CommandID: "123456"}, mockPresto) + assert.NoError(t, err) + assert.True(t, x) + }) + + t.Run("Terminate not called when in terminal state", func(t *testing.T) { + var x = false + + mockPresto := &prestoMocks.PrestoClient{} + mockPresto.On("KillCommand", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + x = true + }).Return(nil) + + err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandID: "123456"}, mockPresto) + assert.NoError(t, err) + assert.False(t, x) + }) +} + +func TestFinalize(t *testing.T) { + // Test that Finalize releases resources + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + state := ExecutionState{} + var called = false + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("ReleaseResource", mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + called = true + }).Return(nil) + + err := Finalize(ctx, tCtx, state) + assert.NoError(t, err) + assert.True(t, called) +} + +func TestMonitorQuery(t *testing.T) { + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + state := ExecutionState{ + Phase: PhaseSubmitted, + } + var getOrCreateCalled = false + mockCache := &mocks2.AutoRefresh{} + mockCache.OnGetOrCreateMatch(mock.AnythingOfType("string"), mock.Anything).Return(ExecutionStateCacheItem{ + ExecutionState: ExecutionState{Phase: PhaseQuerySucceeded}, + Identifier: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", + }, nil).Run(func(_ mock.Arguments) { + getOrCreateCalled = true + }) + + newState, err := MonitorQuery(ctx, tCtx, state, mockCache) + assert.NoError(t, err) + assert.True(t, getOrCreateCalled) + assert.Equal(t, PhaseQuerySucceeded, newState.Phase) +} + +func TestKickOffQuery(t *testing.T) { + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + + var prestoCalled = false + + prestoExecuteResponse := client.PrestoExecuteResponse{ + ID: "1234567", + Status: client.PrestoStatusWaiting, + } + mockPresto := &prestoMocks.PrestoClient{} + mockPresto.OnExecuteCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, + mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + prestoCalled = true + }).Return(prestoExecuteResponse, nil) + var getOrCreateCalled = false + mockCache := &mocks2.AutoRefresh{} + mockCache.OnGetOrCreate(mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + getOrCreateCalled = true + }).Return(ExecutionStateCacheItem{}, nil) + + state := ExecutionState{} + newState, err := KickOffQuery(ctx, tCtx, state, mockPresto, mockCache) + assert.NoError(t, err) + assert.Equal(t, PhaseSubmitted, newState.Phase) + assert.Equal(t, "1234567", newState.CommandID) + assert.True(t, getOrCreateCalled) + assert.True(t, prestoCalled) +} + +func createMockPrestoCfg() *config.Config { + return &config.Config{ + Environment: config.URLMustParse(""), + DefaultRoutingGroup: "adhoc", + RoutingGroupConfigs: []config.RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, + RefreshCacheConfig: config.RefreshCacheConfig{ + Name: "presto", + SyncPeriod: stdConfig.Duration{Duration: 3 * time.Second}, + Workers: 15, + LruCacheSize: 2000, + }, + } +} + +func Test_mapLabelToPrimaryLabel(t *testing.T) { + ctx := context.TODO() + mockPrestoCfg := createMockPrestoCfg() + + type args struct { + ctx context.Context + routingGroup string + prestoCfg *config.Config + } + tests := []struct { + name string + args args + want string + }{ + {name: "Routing group is found in configs", args: args{ctx: ctx, routingGroup: "etl", prestoCfg: mockPrestoCfg}, want: "etl"}, + {name: "Use routing group default when not found in configs", args: args{ctx: ctx, routingGroup: "test", prestoCfg: mockPrestoCfg}, want: mockPrestoCfg.DefaultRoutingGroup}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, resolveRoutingGroup(tt.args.ctx, tt.args.routingGroup, tt.args.prestoCfg)) + }) + } +} diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go new file mode 100644 index 000000000..a26c6b3a2 --- /dev/null +++ b/go/tasks/plugins/presto/executions_cache.go @@ -0,0 +1,161 @@ +package presto + +import ( + "context" + + "k8s.io/client-go/util/workqueue" + + "github.com/lyft/flytestdlib/cache" + + "github.com/lyft/flyteplugins/go/tasks/errors" + stdErrors "github.com/lyft/flytestdlib/errors" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +const ( + BadPrestoReturnCodeError stdErrors.ErrorCode = "PRESTO_RETURNED_UNKNOWN" +) + +type ExecutionsCache struct { + cache.AutoRefresh + prestoClient client.PrestoClient + scope promutils.Scope + cfg *config.Config +} + +func NewPrestoExecutionsCache( + ctx context.Context, + prestoClient client.PrestoClient, + cfg *config.Config, + scope promutils.Scope) (ExecutionsCache, error) { + + q := ExecutionsCache{ + prestoClient: prestoClient, + scope: scope, + cfg: cfg, + } + autoRefreshCache, err := cache.NewAutoRefreshCache(cfg.RefreshCacheConfig.Name, q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), cfg.RefreshCacheConfig.SyncPeriod.Duration, cfg.RefreshCacheConfig.Workers, cfg.RefreshCacheConfig.LruCacheSize, scope) + if err != nil { + logger.Errorf(ctx, "Could not create AutoRefreshCache in Executor. [%s]", err) + return q, errors.Wrapf(errors.CacheFailed, err, "Error creating AutoRefreshCache") + } + q.AutoRefresh = autoRefreshCache + return q, nil +} + +type ExecutionStateCacheItem struct { + ExecutionState + + // This ID is the cache key and so will need to be unique across all objects in the cache (it will probably be + // unique across all of Flyte) and needs to be deterministic. + // This will also be used as the allocation token for now. + Identifier string `json:"id"` +} + +func (e ExecutionStateCacheItem) ID() string { + return e.Identifier +} + +// This basically grab an updated status from the Presto API and stores it in the cache +// All other handling should be in the synchronous loop. +func (p *ExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch) ( + updatedBatch []cache.ItemSyncResponse, err error) { + + resp := make([]cache.ItemSyncResponse, 0, len(batch)) + for _, query := range batch { + // Cast the item back to the thing we want to work with. + executionStateCacheItem, ok := query.GetItem().(ExecutionStateCacheItem) + if !ok { + logger.Errorf(ctx, "Sync loop - Error casting cache object into ExecutionState") + return nil, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", batch[0].GetID()) + } + + if executionStateCacheItem.CommandID == "" { + logger.Warnf(ctx, "Sync loop - CommandID is blank for [%s] skipping", executionStateCacheItem.Identifier) + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: query.GetItem(), + Action: cache.Unchanged, + }) + + continue + } + + logger.Debugf(ctx, "Sync loop - processing Presto job [%s] - cache key [%s]", + executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) + + if InTerminalState(executionStateCacheItem.ExecutionState) { + logger.Debugf(ctx, "Sync loop - Presto id [%s] in terminal state [%s]", + executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) + + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: query.GetItem(), + Action: cache.Unchanged, + }) + + continue + } + + // Get an updated status from Presto + logger.Debugf(ctx, "Querying Presto for %s - %s", executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) + commandStatus, err := p.prestoClient.GetCommandStatus(ctx, executionStateCacheItem.CommandID) + if err != nil { + logger.Errorf(ctx, "Error from Presto command %s", executionStateCacheItem.CommandID) + executionStateCacheItem.SyncFailureCount++ + // Make sure we don't return nil for the first argument, because that deletes it from the cache. + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: executionStateCacheItem, + Action: cache.Update, + }) + + continue + } + + newExecutionPhase, err := StatusToExecutionPhase(commandStatus) + if err != nil { + return nil, err + } + + if newExecutionPhase > executionStateCacheItem.Phase { + logger.Infof(ctx, "Moving ExecutionPhase for %s %s from %s to %s", executionStateCacheItem.CommandID, + executionStateCacheItem.Identifier, executionStateCacheItem.Phase, newExecutionPhase) + + executionStateCacheItem.Phase = newExecutionPhase + + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: executionStateCacheItem, + Action: cache.Update, + }) + } + } + + return resp, nil +} + +// We need some way to translate results we get from Presto, into a plugin phase +func StatusToExecutionPhase(s client.PrestoStatus) (ExecutionPhase, error) { + switch s { + case client.PrestoStatusFinished: + return PhaseQuerySucceeded, nil + case client.PrestoStatusCancelled: + return PhaseQueryFailed, nil + case client.PrestoStatusFailed: + return PhaseQueryFailed, nil + case client.PrestoStatusWaiting: + return PhaseSubmitted, nil + case client.PrestoStatusRunning: + return PhaseSubmitted, nil + case client.PrestoStatusUnknown: + return PhaseQueryFailed, errors.Errorf(BadPrestoReturnCodeError, "Presto returned status Unknown") + default: + return PhaseQueryFailed, errors.Errorf(BadPrestoReturnCodeError, "default fallthrough case") + } +} diff --git a/go/tasks/plugins/presto/executions_cache_test.go b/go/tasks/plugins/presto/executions_cache_test.go new file mode 100644 index 000000000..3f6114762 --- /dev/null +++ b/go/tasks/plugins/presto/executions_cache_test.go @@ -0,0 +1,91 @@ +package presto + +import ( + "context" + "testing" + + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client/mocks" + + "github.com/lyft/flytestdlib/cache" + cacheMocks "github.com/lyft/flytestdlib/cache/mocks" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { + ctx := context.Background() + + t.Run("terminal state return unchanged", func(t *testing.T) { + mockCache := &cacheMocks.AutoRefresh{} + mockPresto := &prestoMocks.PrestoClient{} + testScope := promutils.NewTestScope() + + p := ExecutionsCache{ + AutoRefresh: mockCache, + prestoClient: mockPresto, + scope: testScope, + cfg: config.GetPrestoConfig(), + } + + state := ExecutionState{ + Phase: PhaseQuerySucceeded, + } + cacheItem := ExecutionStateCacheItem{ + ExecutionState: state, + Identifier: "some-id", + } + + iw := &cacheMocks.ItemWrapper{} + iw.OnGetItem().Return(cacheItem) + iw.OnGetID().Return("some-id") + + newCacheItem, err := p.SyncPrestoQuery(ctx, []cache.ItemWrapper{iw}) + assert.NoError(t, err) + assert.Equal(t, cache.Unchanged, newCacheItem[0].Action) + assert.Equal(t, cacheItem, newCacheItem[0].Item) + }) + + t.Run("move to success", func(t *testing.T) { + mockCache := &cacheMocks.AutoRefresh{} + mockPresto := &prestoMocks.PrestoClient{} + mockSecretManager := &mocks.SecretManager{} + mockSecretManager.OnGetMatch(mock.Anything, mock.Anything).Return("fake key", nil) + + testScope := promutils.NewTestScope() + + p := ExecutionsCache{ + AutoRefresh: mockCache, + prestoClient: mockPresto, + scope: testScope, + cfg: config.GetPrestoConfig(), + } + + state := ExecutionState{ + CommandID: "123456", + Phase: PhaseSubmitted, + } + cacheItem := ExecutionStateCacheItem{ + ExecutionState: state, + Identifier: "some-id", + } + mockPresto.OnGetCommandStatusMatch(mock.Anything, mock.MatchedBy(func(commandId string) bool { + return commandId == state.CommandID + }), mock.Anything).Return(client.PrestoStatusFinished, nil) + + iw := &cacheMocks.ItemWrapper{} + iw.OnGetItem().Return(cacheItem) + iw.OnGetID().Return("some-id") + + newCacheItem, err := p.SyncPrestoQuery(ctx, []cache.ItemWrapper{iw}) + newExecutionState := newCacheItem[0].Item.(ExecutionStateCacheItem) + assert.NoError(t, err) + assert.Equal(t, cache.Update, newCacheItem[0].Action) + assert.Equal(t, PhaseQuerySucceeded, newExecutionState.Phase) + }) +} diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go new file mode 100644 index 000000000..745ac7817 --- /dev/null +++ b/go/tasks/plugins/presto/executor.go @@ -0,0 +1,156 @@ +package presto + +import ( + "context" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + + "github.com/lyft/flytestdlib/cache" + + "github.com/lyft/flyteplugins/go/tasks/errors" + pluginMachinery "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +// This is the name of this plugin effectively. In Flyte plugin configuration, use this string to enable this plugin. +const prestoPluginID = "presto" + +// Version of the custom state this plugin stores. Useful for backwards compatibility if you one day need to update +// the structure of the stored state +const pluginStateVersion = 0 + +const prestoTaskType = "presto" // This needs to match the type defined in Flytekit constants.py + +type Executor struct { + id string + metrics ExecutorMetrics + prestoClient client.PrestoClient + executionsCache cache.AutoRefresh + cfg *config.Config +} + +func (p Executor) GetID() string { + return p.id +} + +func (p Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { + incomingState := ExecutionState{} + + // We assume here that the first time this function is called, the custom state we get back is whatever we passed in, + // namely the zero-value of our struct. + if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { + logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state when handling [%s] [%s]", + p.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return core.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, + "Failed to unmarshal custom state in Handle") + } + + // Do what needs to be done, and give this function everything it needs to do its job properly + outgoingState, transformError := HandleExecutionState(ctx, tCtx, incomingState, p.prestoClient, p.executionsCache, p.metrics) + + // Return if there was an error + if transformError != nil { + return core.UnknownTransition, transformError + } + + // If no error, then infer the new Phase from the various states + phaseInfo := MapExecutionStateToPhaseInfo(outgoingState) + + if err := tCtx.PluginStateWriter().Put(pluginStateVersion, outgoingState); err != nil { + return core.UnknownTransition, err + } + + return core.DoTransitionType(core.TransitionTypeBarrier, phaseInfo), nil +} + +func (p Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { + incomingState := ExecutionState{} + if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { + logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", + p.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state in Finalize") + } + + return Abort(ctx, incomingState, p.prestoClient) +} + +func (p Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { + incomingState := ExecutionState{} + if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { + logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", + p.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state in Finalize") + } + + return Finalize(ctx, tCtx, incomingState) +} + +func (p Executor) GetProperties() core.PluginProperties { + return core.PluginProperties{} +} + +func ExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + cfg := config.GetPrestoConfig() + return InitializePrestoExecutor(ctx, iCtx, cfg, client.NewNoopPrestoClient(cfg)) +} + +func InitializePrestoExecutor( + ctx context.Context, + iCtx core.SetupContext, + cfg *config.Config, + prestoClient client.PrestoClient) (core.Plugin, error) { + logger.Infof(ctx, "Initializing a Presto executo") + q, err := NewPrestoExecutor(ctx, cfg, prestoClient, iCtx.MetricsScope()) + if err != nil { + logger.Errorf(ctx, "Failed to create a new Executor due to error: [%v]", err) + return nil, err + } + + for _, routingGroup := range cfg.RoutingGroupConfigs { + logger.Infof(ctx, "Registering resource quota for routing group [%v]", routingGroup.Name) + if err := iCtx.ResourceRegistrar().RegisterResourceQuota(ctx, core.ResourceNamespace(routingGroup.Name), routingGroup.Limit); err != nil { + logger.Errorf(ctx, "Resource quota registration for [%v] failed due to error [%v]", routingGroup.Name, err) + return nil, err + } + } + + return q, nil +} + +func NewPrestoExecutor( + ctx context.Context, + cfg *config.Config, + prestoClient client.PrestoClient, + scope promutils.Scope) (Executor, error) { + executionsAutoRefreshCache, err := NewPrestoExecutionsCache(ctx, prestoClient, cfg, scope.NewSubScope(prestoTaskType)) + if err != nil { + logger.Errorf(ctx, "Failed to create AutoRefreshCache in Executor Setup. Error: %v", err) + return Executor{}, err + } + + err = executionsAutoRefreshCache.Start(ctx) + if err != nil { + logger.Errorf(ctx, "Failed to start AutoRefreshCache. Error: %v", err) + } + + return Executor{ + id: prestoPluginID, + cfg: cfg, + metrics: getPrestoExecutorMetrics(scope), + prestoClient: prestoClient, + executionsCache: executionsAutoRefreshCache, + }, nil +} + +func init() { + pluginMachinery.PluginRegistry().RegisterCorePlugin( + core.PluginEntry{ + ID: prestoPluginID, + RegisteredTaskTypes: []core.TaskType{prestoTaskType}, + LoadPlugin: ExecutorLoader, + IsDefault: false, + }) +} diff --git a/go/tasks/plugins/presto/executor_metrics.go b/go/tasks/plugins/presto/executor_metrics.go new file mode 100644 index 000000000..69538a757 --- /dev/null +++ b/go/tasks/plugins/presto/executor_metrics.go @@ -0,0 +1,25 @@ +package presto + +import ( + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" +) + +type ExecutorMetrics struct { + Scope promutils.Scope + ReleaseResourceFailed labeled.Counter + AllocationGranted labeled.Counter + AllocationNotGranted labeled.Counter +} + +func getPrestoExecutorMetrics(scope promutils.Scope) ExecutorMetrics { + return ExecutorMetrics{ + Scope: scope, + ReleaseResourceFailed: labeled.NewCounter("presto_released_resource_failed", + "Error releasing allocation token for Presto", scope), + AllocationGranted: labeled.NewCounter("presto_allocation_granted", + "Allocation request granted for Presto", scope), + AllocationNotGranted: labeled.NewCounter("presto_allocation_not_granted", + "Allocation request did not fail but not granted for Presto", scope), + } +} diff --git a/go/tasks/plugins/presto/helpers_test.go b/go/tasks/plugins/presto/helpers_test.go new file mode 100644 index 000000000..f5bc6a235 --- /dev/null +++ b/go/tasks/plugins/presto/helpers_test.go @@ -0,0 +1,113 @@ +package presto + +import ( + structpb "github.com/golang/protobuf/ptypes/struct" + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + coreMock "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMock "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/mock" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +func GetPrestoQueryTaskTemplate() idlCore.TaskTemplate { + prestoQuery := plugins.PrestoQuery{ + RoutingGroup: "adhoc", + Catalog: "hive", + Schema: "city", + Statement: "select * from hive.city.fact_airport_sessions limit 10", + } + stObj := &structpb.Struct{} + _ = utils.MarshalStruct(&prestoQuery, stObj) + tt := idlCore.TaskTemplate{ + Type: "presto", + Custom: stObj, + Id: &idlCore.Identifier{ + Name: "sample_presto_task_test_name", + Project: "flyteplugins", + Version: "1", + ResourceType: idlCore.ResourceType_TASK, + }, + } + + return tt +} + +var resourceRequirements = &v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, +} + +func GetMockTaskExecutionMetadata() core.TaskExecutionMetadata { + taskMetadata := &coreMock.TaskExecutionMetadata{} + taskMetadata.On("GetNamespace").Return("test-namespace") + taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + taskMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"}) + taskMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskMetadata.On("GetK8sServiceAccount").Return("service-account") + taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ + Namespace: "test-namespace", + Name: "test-owner-name", + }) + + tID := &coreMock.TaskExecutionID{} + tID.On("GetID").Return(idlCore.TaskExecutionIdentifier{ + NodeExecutionId: &idlCore.NodeExecutionIdentifier{ + ExecutionId: &idlCore.WorkflowExecutionIdentifier{ + Name: "my_wf_exec_name", + Project: "my_wf_exec_project", + Domain: "my_wf_exec_domain", + }, + }, + }) + tID.On("GetGeneratedName").Return("my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name") + taskMetadata.On("GetTaskExecutionID").Return(tID) + + to := &coreMock.TaskOverrides{} + to.On("GetResources").Return(resourceRequirements) + taskMetadata.On("GetOverrides").Return(to) + + return taskMetadata +} + +func GetMockTaskExecutionContext() core.TaskExecutionContext { + tt := GetPrestoQueryTaskTemplate() + + dummyTaskMetadata := GetMockTaskExecutionMetadata() + taskCtx := &coreMock.TaskExecutionContext{} + inputReader := &ioMock.InputReader{} + inputReader.On("GetInputPath").Return(storage.DataReference("test-data-reference")) + inputReader.On("Get", mock.Anything).Return(&idlCore.LiteralMap{}, nil) + inputReader.On("GetInputPrefixPath").Return(storage.DataReference("/data")) + taskCtx.On("InputReader").Return(inputReader) + + outputReader := &ioMock.OutputWriter{} + outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) + outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + taskCtx.On("OutputWriter").Return(outputReader) + + taskReader := &coreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(&tt, nil) + taskCtx.On("TaskReader").Return(taskReader) + + resourceManager := &coreMock.ResourceManager{} + taskCtx.On("ResourceManager").Return(resourceManager) + + taskCtx.On("TaskExecutionMetadata").Return(dummyTaskMetadata) + mockSecretManager := &coreMock.SecretManager{} + mockSecretManager.On("Get", mock.Anything, mock.Anything).Return("fake key", nil) + taskCtx.On("SecretManager").Return(mockSecretManager) + + return taskCtx +}