From 6809ea4cd33a7c4801517655bd4574105492d21e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 30 Nov 2022 12:10:18 -0800 Subject: [PATCH] nit Signed-off-by: Kevin Su --- go/tasks/plugins/webapi/databricks/config.go | 71 ----- .../plugins/webapi/databricks/config_test.go | 18 -- .../webapi/databricks/integration_test.go | 107 ------- go/tasks/plugins/webapi/databricks/plugin.go | 290 ------------------ .../plugins/webapi/databricks/plugin_test.go | 122 -------- go/tasks/plugins/webapi/snowflake/config.go | 2 +- 6 files changed, 1 insertion(+), 609 deletions(-) delete mode 100644 go/tasks/plugins/webapi/databricks/config.go delete mode 100644 go/tasks/plugins/webapi/databricks/config_test.go delete mode 100644 go/tasks/plugins/webapi/databricks/integration_test.go delete mode 100644 go/tasks/plugins/webapi/databricks/plugin.go delete mode 100644 go/tasks/plugins/webapi/databricks/plugin_test.go diff --git a/go/tasks/plugins/webapi/databricks/config.go b/go/tasks/plugins/webapi/databricks/config.go deleted file mode 100644 index 7cd0c17c2..000000000 --- a/go/tasks/plugins/webapi/databricks/config.go +++ /dev/null @@ -1,71 +0,0 @@ -package databricks - -import ( - "time" - - pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" - "github.com/flyteorg/flytestdlib/config" -) - -var ( - defaultConfig = Config{ - WebAPI: webapi.PluginConfig{ - ResourceQuotas: map[core.ResourceNamespace]int{ - "default": 1000, - }, - ReadRateLimiter: webapi.RateLimiterConfig{ - Burst: 100, - QPS: 10, - }, - WriteRateLimiter: webapi.RateLimiterConfig{ - Burst: 100, - QPS: 10, - }, - Caching: webapi.CachingConfig{ - Size: 500000, - ResyncInterval: config.Duration{Duration: 30 * time.Second}, - Workers: 10, - MaxSystemFailures: 5, - }, - ResourceMeta: nil, - }, - ResourceConstraints: core.ResourceConstraintsSpec{ - ProjectScopeResourceConstraint: 
&core.ResourceConstraint{ - Value: 100, - }, - NamespaceScopeResourceConstraint: &core.ResourceConstraint{ - Value: 50, - }, - }, - DefaultCluster: "COMPUTE_CLUSTER", - TokenKey: "FLYTE_DATABRICKS_API_TOKEN", - } - - configSection = pluginsConfig.MustRegisterSubSection("databricks", &defaultConfig) -) - -// Config is config for 'databricks' plugin -type Config struct { - // WebAPI defines config for the base WebAPI plugin - WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` - - // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time - ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` - - DefaultCluster string `json:"defaultWarehouse" pflag:",Defines the default warehouse to use when running on Databricks unless overwritten by the task."` - - TokenKey string `json:"databricksTokenKey" pflag:",Name of the key where to find Databricks token in the secret manager."` - - // databricksEndpoint overrides databricks instance endpoint, only for testing - databricksEndpoint string -} - -func GetConfig() *Config { - return configSection.GetConfig().(*Config) -} - -func SetConfig(cfg *Config) error { - return configSection.SetConfig(cfg) -} diff --git a/go/tasks/plugins/webapi/databricks/config_test.go b/go/tasks/plugins/webapi/databricks/config_test.go deleted file mode 100644 index 46cee89e2..000000000 --- a/go/tasks/plugins/webapi/databricks/config_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package databricks - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestGetAndSetConfig(t *testing.T) { - cfg := defaultConfig - cfg.DefaultCluster = "test-cluster" - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - err := SetConfig(&cfg) - 
assert.NoError(t, err) - assert.Equal(t, &cfg, GetConfig()) -} diff --git a/go/tasks/plugins/webapi/databricks/integration_test.go b/go/tasks/plugins/webapi/databricks/integration_test.go deleted file mode 100644 index 21a17962c..000000000 --- a/go/tasks/plugins/webapi/databricks/integration_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package databricks - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/flyteorg/flyteidl/clients/go/coreutils" - coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flyteplugins/tests" - "github.com/flyteorg/flytestdlib/contextutils" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/promutils/labeled" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestEndToEnd(t *testing.T) { - server := newFakeSnowflakeServer() - defer server.Close() - - iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { - return nil - } - - cfg := defaultConfig - cfg.databricksEndpoint = server.URL - cfg.DefaultCluster = "test-cluster" - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - err := SetConfig(&cfg) - assert.NoError(t, err) - - pluginEntry := pluginmachinery.CreateRemotePlugin(newSnowflakeJobTaskPlugin()) - plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext()) - assert.NoError(t, err) - - t.Run("SELECT 1", func(t *testing.T) { - config := make(map[string]string) - config["database"] = "my-database" - config["account"] = "snowflake" - config["schema"] = "my-schema" - config["warehouse"] = "my-warehouse" - - inputs, _ 
:= coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) - template := flyteIdlCore.TaskTemplate{ - Type: "snowflake", - Config: config, - Target: &coreIdl.TaskTemplate_Sql{Sql: &coreIdl.Sql{Statement: "SELECT 1", Dialect: coreIdl.Sql_ANSI}}, - } - - phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) - - assert.Equal(t, true, phase.Phase().IsSuccess()) - }) -} - -func newFakeSnowflakeServer() *httptest.Server { - statementHandle := "019e7546-0000-278c-0000-40f10001a082" - return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - if request.URL.Path == "/api/v2/statements" && request.Method == "POST" { - writer.WriteHeader(202) - bytes := []byte(fmt.Sprintf(`{ - "statementHandle": "%v", - "message": "Asynchronous execution in progress." - }`, statementHandle)) - _, _ = writer.Write(bytes) - return - } - - if request.URL.Path == "/api/v2/statements/"+statementHandle && request.Method == "GET" { - writer.WriteHeader(200) - bytes := []byte(fmt.Sprintf(`{ - "statementHandle": "%v", - "message": "Statement executed successfully." 
- }`, statementHandle)) - _, _ = writer.Write(bytes) - return - } - - if request.URL.Path == "/api/v2/statements/"+statementHandle+"/cancel" && request.Method == "POST" { - writer.WriteHeader(200) - return - } - - writer.WriteHeader(500) - })) -} - -func newFakeSetupContext() *pluginCoreMocks.SetupContext { - fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{} - fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil) - labeled.SetMetricKeys(contextutils.NamespaceKey) - - fakeSetupContext := pluginCoreMocks.SetupContext{} - fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) - fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar) - - return &fakeSetupContext -} diff --git a/go/tasks/plugins/webapi/databricks/plugin.go b/go/tasks/plugins/webapi/databricks/plugin.go deleted file mode 100644 index a52e91293..000000000 --- a/go/tasks/plugins/webapi/databricks/plugin.go +++ /dev/null @@ -1,290 +0,0 @@ -package databricks - -import ( - "bytes" - "context" - "encoding/gob" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "time" - - flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" - - "github.com/flyteorg/flytestdlib/promutils" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" -) - -const ( - ErrSystem errors.ErrorCode = "System" - post string = "POST" - get string = "GET" -) - -// for mocking/testing purposes, and we'll override this method -type 
HTTPClient interface { - Do(req *http.Request) (*http.Response, error) -} - -type Plugin struct { - metricScope promutils.Scope - cfg *Config - client HTTPClient -} - -type ResourceWrapper struct { - StatusCode int - Message string -} - -type ResourceMetaWrapper struct { - QueryID string - Account string - Token string -} - -func (p Plugin) GetConfig() webapi.PluginConfig { - return GetConfig().WebAPI -} - -type QueryInfo struct { - Account string - Warehouse string - Schema string - Database string - Statement string -} - -func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( - namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { - - // Resource requirements are assumed to be the same. - return "default", p.cfg.ResourceConstraints, nil -} - -func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, - webapi.Resource, error) { - task, err := taskCtx.TaskReader().Read(ctx) - if err != nil { - return nil, nil, err - } - - token, err := taskCtx.SecretManager().Get(ctx, p.cfg.TokenKey) - if err != nil { - return nil, nil, err - } - config := task.GetConfig() - - outputs, err := template.Render(ctx, []string{ - task.GetSql().Statement, - }, template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - }) - if err != nil { - return nil, nil, err - } - queryInfo := QueryInfo{ - Account: config["account"], - Warehouse: config["warehouse"], - Schema: config["schema"], - Database: config["database"], - Statement: outputs[0], - } - - if len(queryInfo.Warehouse) == 0 { - queryInfo.Warehouse = p.cfg.DefaultCluster - } - if len(queryInfo.Account) == 0 { - return nil, nil, errors.Errorf(errors2.BadTaskSpecification, "Account must not be empty.") - } - if len(queryInfo.Database) == 0 { - return nil, nil, 
errors.Errorf(errors2.BadTaskSpecification, "Database must not be empty.") - } - req, err := buildRequest(post, queryInfo, p.cfg.databricksEndpoint, - config["account"], token, "", false) - if err != nil { - return nil, nil, err - } - resp, err := p.client.Do(req) - if err != nil { - return nil, nil, err - } - defer resp.Body.Close() - data, err := buildResponse(resp) - if err != nil { - return nil, nil, err - } - - if data["statementHandle"] == "" { - return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, - "Unable to fetch statementHandle from http response") - } - if data["message"] == "" { - return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, - "Unable to fetch message from http response") - } - queryID := fmt.Sprintf("%v", data["statementHandle"]) - message := fmt.Sprintf("%v", data["message"]) - - return &ResourceMetaWrapper{queryID, queryInfo.Account, token}, - &ResourceWrapper{StatusCode: resp.StatusCode, Message: message}, nil -} - -func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { - exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - req, err := buildRequest(get, QueryInfo{}, p.cfg.databricksEndpoint, - exec.Account, exec.Token, exec.QueryID, false) - if err != nil { - return nil, err - } - resp, err := p.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - data, err := buildResponse(resp) - if err != nil { - return nil, err - } - message := fmt.Sprintf("%v", data["message"]) - return &ResourceWrapper{ - StatusCode: resp.StatusCode, - Message: message, - }, nil -} - -func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { - exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - req, err := buildRequest(post, QueryInfo{}, p.cfg.databricksEndpoint, - exec.Account, exec.Token, exec.QueryID, true) - if err != nil { - return err - } - resp, err := p.client.Do(req) - if err != nil { - return err - } - defer 
resp.Body.Close() - logger.Info(ctx, "Deleted query execution [%v]", resp) - - return nil -} - -func (p Plugin) Status(_ context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { - exec := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - statusCode := taskCtx.Resource().(*ResourceWrapper).StatusCode - if statusCode == 0 { - return core.PhaseInfoUndefined, errors.Errorf(ErrSystem, "No Status field set.") - } - - taskInfo := createTaskInfo(exec.QueryID, exec.Account) - switch statusCode { - case http.StatusAccepted: - return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, createTaskInfo(exec.QueryID, exec.Account)), nil - case http.StatusOK: - return pluginsCore.PhaseInfoSuccess(taskInfo), nil - case http.StatusUnprocessableEntity: - return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), "phaseReason", taskInfo), nil - } - return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", statusCode) -} - -func buildRequest(method string, queryInfo QueryInfo, snowflakeEndpoint string, account string, token string, - queryID string, isCancel bool) (*http.Request, error) { - var snowflakeURL string - // for mocking/testing purposes - if snowflakeEndpoint == "" { - snowflakeURL = "https://" + account + ".snowflakecomputing.com/api/v2/statements" - } else { - snowflakeURL = snowflakeEndpoint + "/api/v2/statements" - } - - var data []byte - if method == post && !isCancel { - snowflakeURL += "?async=true" - data = []byte(fmt.Sprintf(`{ - "statement": "%v", - "database": "%v", - "schema": "%v", - "warehouse": "%v" - }`, queryInfo.Statement, queryInfo.Database, queryInfo.Schema, queryInfo.Warehouse)) - } else { - snowflakeURL += "/" + queryID - } - if isCancel { - snowflakeURL += "/cancel" - } - - req, err := http.NewRequest(method, snowflakeURL, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - req.Header.Add("Authorization", "Bearer "+token) - 
req.Header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - return req, nil -} - -func buildResponse(response *http.Response) (map[string]interface{}, error) { - responseBody, err := ioutil.ReadAll(response.Body) - if err != nil { - return nil, err - } - var data map[string]interface{} - err = json.Unmarshal(responseBody, &data) - if err != nil { - return nil, err - } - return data, nil -} - -func createTaskInfo(queryID string, account string) *core.TaskInfo { - timeNow := time.Now() - - return &core.TaskInfo{ - OccurredAt: &timeNow, - Logs: []*flyteIdlCore.TaskLog{ - { - Uri: fmt.Sprintf("https://%v.snowflakecomputing.com/console#/monitoring/queries/detail?queryId=%v", - account, - queryID), - Name: "Snowflake Console", - }, - }, - } -} - -func newSnowflakeJobTaskPlugin() webapi.PluginEntry { - return webapi.PluginEntry{ - ID: "snowflake", - SupportedTaskTypes: []core.TaskType{"snowflake"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - client: &http.Client{}, - }, nil - }, - } -} - -func init() { - gob.Register(ResourceMetaWrapper{}) - gob.Register(ResourceWrapper{}) - - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newSnowflakeJobTaskPlugin()) -} diff --git a/go/tasks/plugins/webapi/databricks/plugin_test.go b/go/tasks/plugins/webapi/databricks/plugin_test.go deleted file mode 100644 index 10febc17f..000000000 --- a/go/tasks/plugins/webapi/databricks/plugin_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package databricks - -import ( - "context" - "encoding/json" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - 
"github.com/flyteorg/flytestdlib/promutils" - "github.com/stretchr/testify/assert" -) - -type MockClient struct { -} - -var ( - MockDo func(req *http.Request) (*http.Response, error) -) - -func (m *MockClient) Do(req *http.Request) (*http.Response, error) { - return MockDo(req) -} - -func TestPlugin(t *testing.T) { - fakeSetupContext := pluginCoreMocks.SetupContext{} - fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) - - plugin := Plugin{ - metricScope: fakeSetupContext.MetricsScope(), - cfg: GetConfig(), - client: &MockClient{}, - } - t.Run("get config", func(t *testing.T) { - cfg := defaultConfig - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - err := SetConfig(&cfg) - assert.NoError(t, err) - assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) - }) - t.Run("get ResourceRequirements", func(t *testing.T) { - namespace, constraints, err := plugin.ResourceRequirements(context.TODO(), nil) - assert.NoError(t, err) - assert.Equal(t, pluginsCore.ResourceNamespace("default"), namespace) - assert.Equal(t, plugin.cfg.ResourceConstraints, constraints) - }) -} - -func TestCreateTaskInfo(t *testing.T) { - t.Run("create task info", func(t *testing.T) { - taskInfo := createTaskInfo("d5493e36", "test-account") - - assert.Equal(t, 1, len(taskInfo.Logs)) - assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.snowflakecomputing.com/console#/monitoring/queries/detail?queryId=d5493e36") - assert.Equal(t, taskInfo.Logs[0].Name, "Snowflake Console") - }) -} - -func TestBuildRequest(t *testing.T) { - account := "test-account" - token := "test-token" - queryID := "019e70eb-0000-278b-0000-40f100012b1a" - snowflakeEndpoint := "" - snowflakeURL := "https://" + account + ".snowflakecomputing.com/api/v2/statements" - t.Run("build http request for submitting a snowflake query", func(t *testing.T) { - queryInfo := QueryInfo{ - Account: account, - Warehouse: "test-warehouse", - Schema: "test-schema", - Database: 
"test-database", - Statement: "SELECT 1", - } - - req, err := buildRequest(post, queryInfo, snowflakeEndpoint, account, token, queryID, false) - header := http.Header{} - header.Add("Authorization", "Bearer "+token) - header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") - header.Add("Content-Type", "application/json") - header.Add("Accept", "application/json") - - assert.NoError(t, err) - assert.Equal(t, header, req.Header) - assert.Equal(t, snowflakeURL+"?async=true", req.URL.String()) - assert.Equal(t, post, req.Method) - }) - t.Run("build http request for getting a snowflake query status", func(t *testing.T) { - req, err := buildRequest(get, QueryInfo{}, snowflakeEndpoint, account, token, queryID, false) - - assert.NoError(t, err) - assert.Equal(t, snowflakeURL+"/"+queryID, req.URL.String()) - assert.Equal(t, get, req.Method) - }) - t.Run("build http request for deleting a snowflake query", func(t *testing.T) { - req, err := buildRequest(post, QueryInfo{}, snowflakeEndpoint, account, token, queryID, true) - - assert.NoError(t, err) - assert.Equal(t, snowflakeURL+"/"+queryID+"/cancel", req.URL.String()) - assert.Equal(t, post, req.Method) - }) -} - -func TestBuildResponse(t *testing.T) { - t.Run("build http response", func(t *testing.T) { - bodyStr := `{"statementHandle":"019c06a4-0000", "message":"Statement executed successfully."}` - responseBody := ioutil.NopCloser(strings.NewReader(bodyStr)) - response := &http.Response{Body: responseBody} - actualData, err := buildResponse(response) - assert.NoError(t, err) - - bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr)) - assert.NoError(t, err) - var expectedData map[string]interface{} - err = json.Unmarshal(bodyByte, &expectedData) - assert.NoError(t, err) - assert.Equal(t, expectedData, actualData) - }) -} diff --git a/go/tasks/plugins/webapi/snowflake/config.go b/go/tasks/plugins/webapi/snowflake/config.go index 4a6647e8c..93160f663 100644 --- a/go/tasks/plugins/webapi/snowflake/config.go +++ 
b/go/tasks/plugins/webapi/snowflake/config.go @@ -48,7 +48,7 @@ var ( // Config is config for 'snowflake' plugin type Config struct { - // WebAPI defines config for the base WebAPI plugin + // WebAPI defines config for the base WebAPI plugin WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time