forked from flyteorg/flyte
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Snowflake backend plugin (flyteorg#202)
* Add snowflake plugin Signed-off-by: Kevin Su <[email protected]> * Fix test Signed-off-by: Kevin Su <[email protected]> * Fix test Signed-off-by: Kevin Su <[email protected]> * Fix test Signed-off-by: Kevin Su <[email protected]> * Fix test Signed-off-by: Kevin Su <[email protected]> * Remove duplicate code Signed-off-by: Kevin Su <[email protected]> * Improve test coverage Signed-off-by: Kevin Su <[email protected]> * Add integration tests Signed-off-by: Kevin Su <[email protected]> * Improve test coverage Signed-off-by: Kevin Su <[email protected]> * Improve test coverage Signed-off-by: Kevin Su <[email protected]> * Fix lint and tests Signed-off-by: Kevin Su <[email protected]> * update proto Signed-off-by: Kevin Su <[email protected]> * remove snowflake proto Signed-off-by: Kevin Su <[email protected]> * Update idl version Signed-off-by: Kevin Su <[email protected]> * fix test Signed-off-by: Kevin Su <[email protected]>
- Loading branch information
Showing
8 changed files
with
615 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package snowflake | ||
|
||
import ( | ||
"time" | ||
|
||
pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" | ||
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" | ||
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" | ||
"github.com/flyteorg/flytestdlib/config" | ||
) | ||
|
||
var ( | ||
defaultConfig = Config{ | ||
WebAPI: webapi.PluginConfig{ | ||
ResourceQuotas: map[core.ResourceNamespace]int{ | ||
"default": 1000, | ||
}, | ||
ReadRateLimiter: webapi.RateLimiterConfig{ | ||
Burst: 100, | ||
QPS: 10, | ||
}, | ||
WriteRateLimiter: webapi.RateLimiterConfig{ | ||
Burst: 100, | ||
QPS: 10, | ||
}, | ||
Caching: webapi.CachingConfig{ | ||
Size: 500000, | ||
ResyncInterval: config.Duration{Duration: 30 * time.Second}, | ||
Workers: 10, | ||
MaxSystemFailures: 5, | ||
}, | ||
ResourceMeta: nil, | ||
}, | ||
ResourceConstraints: core.ResourceConstraintsSpec{ | ||
ProjectScopeResourceConstraint: &core.ResourceConstraint{ | ||
Value: 100, | ||
}, | ||
NamespaceScopeResourceConstraint: &core.ResourceConstraint{ | ||
Value: 50, | ||
}, | ||
}, | ||
DefaultWarehouse: "COMPUTE_WH", | ||
TokenKey: "FLYTE_SNOWFLAKE_CLIENT_TOKEN", | ||
} | ||
|
||
configSection = pluginsConfig.MustRegisterSubSection("snowflake", &defaultConfig) | ||
) | ||
|
||
// Config is config for 'snowflake' plugin | ||
type Config struct { | ||
// WeCreateTaskInfobAPI defines config for the base WebAPI plugin | ||
WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` | ||
|
||
// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time | ||
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` | ||
|
||
DefaultWarehouse string `json:"defaultWarehouse" pflag:",Defines the default warehouse to use when running on Snowflake unless overwritten by the task."` | ||
|
||
TokenKey string `json:"snowflakeTokenKey" pflag:",Name of the key where to find Snowflake token in the secret manager."` | ||
|
||
// snowflakeEndpoint overrides Snowflake client endpoint, only for testing | ||
snowflakeEndpoint string | ||
} | ||
|
||
func GetConfig() *Config { | ||
return configSection.GetConfig().(*Config) | ||
} | ||
|
||
func SetConfig(cfg *Config) error { | ||
return configSection.SetConfig(cfg) | ||
} |
18 changes: 18 additions & 0 deletions
18
flyteplugins/go/tasks/plugins/webapi/snowflake/config_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package snowflake | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestGetAndSetConfig(t *testing.T) { | ||
cfg := defaultConfig | ||
cfg.DefaultWarehouse = "test-warehouse" | ||
cfg.WebAPI.Caching.Workers = 1 | ||
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second | ||
err := SetConfig(&cfg) | ||
assert.NoError(t, err) | ||
assert.Equal(t, &cfg, GetConfig()) | ||
} |
107 changes: 107 additions & 0 deletions
107
flyteplugins/go/tasks/plugins/webapi/snowflake/integration_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
package snowflake | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/flyteorg/flyteidl/clients/go/coreutils" | ||
coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" | ||
flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" | ||
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" | ||
pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" | ||
pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" | ||
"github.com/flyteorg/flyteplugins/tests" | ||
"github.com/flyteorg/flytestdlib/contextutils" | ||
"github.com/flyteorg/flytestdlib/promutils" | ||
"github.com/flyteorg/flytestdlib/promutils/labeled" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/mock" | ||
) | ||
|
||
func TestEndToEnd(t *testing.T) { | ||
server := newFakeSnowflakeServer() | ||
defer server.Close() | ||
|
||
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { | ||
return nil | ||
} | ||
|
||
cfg := defaultConfig | ||
cfg.snowflakeEndpoint = server.URL | ||
cfg.DefaultWarehouse = "test-warehouse" | ||
cfg.WebAPI.Caching.Workers = 1 | ||
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second | ||
err := SetConfig(&cfg) | ||
assert.NoError(t, err) | ||
|
||
pluginEntry := pluginmachinery.CreateRemotePlugin(newSnowflakeJobTaskPlugin()) | ||
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext()) | ||
assert.NoError(t, err) | ||
|
||
t.Run("SELECT 1", func(t *testing.T) { | ||
config := make(map[string]string) | ||
config["database"] = "my-database" | ||
config["account"] = "snowflake" | ||
config["schema"] = "my-schema" | ||
config["warehouse"] = "my-warehouse" | ||
|
||
inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) | ||
template := flyteIdlCore.TaskTemplate{ | ||
Type: "snowflake", | ||
Config: config, | ||
Target: &coreIdl.TaskTemplate_Sql{Sql: &coreIdl.Sql{Statement: "SELECT 1", Dialect: coreIdl.Sql_ANSI}}, | ||
} | ||
|
||
phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) | ||
|
||
assert.Equal(t, true, phase.Phase().IsSuccess()) | ||
}) | ||
} | ||
|
||
func newFakeSnowflakeServer() *httptest.Server { | ||
statementHandle := "019e7546-0000-278c-0000-40f10001a082" | ||
return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
if request.URL.Path == "/api/statements" && request.Method == "POST" { | ||
writer.WriteHeader(202) | ||
bytes := []byte(fmt.Sprintf(`{ | ||
"statementHandle": "%v", | ||
"message": "Asynchronous execution in progress." | ||
}`, statementHandle)) | ||
_, _ = writer.Write(bytes) | ||
return | ||
} | ||
|
||
if request.URL.Path == "/api/statements/"+statementHandle && request.Method == "GET" { | ||
writer.WriteHeader(200) | ||
bytes := []byte(fmt.Sprintf(`{ | ||
"statementHandle": "%v", | ||
"message": "Statement executed successfully." | ||
}`, statementHandle)) | ||
_, _ = writer.Write(bytes) | ||
return | ||
} | ||
|
||
if request.URL.Path == "/api/statements/"+statementHandle+"/cancel" && request.Method == "POST" { | ||
writer.WriteHeader(200) | ||
return | ||
} | ||
|
||
writer.WriteHeader(500) | ||
})) | ||
} | ||
|
||
func newFakeSetupContext() *pluginCoreMocks.SetupContext { | ||
fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{} | ||
fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil) | ||
labeled.SetMetricKeys(contextutils.NamespaceKey) | ||
|
||
fakeSetupContext := pluginCoreMocks.SetupContext{} | ||
fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) | ||
fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar) | ||
|
||
return &fakeSetupContext | ||
} |
Oops, something went wrong.