From 632c508459f39f4dca827d7fd40cc79cba0d8b48 Mon Sep 17 00:00:00 2001 From: Gleb Kanterov Date: Fri, 4 Jun 2021 16:59:10 +0200 Subject: [PATCH] Add BigQuery plugin (#161) Signed-off-by: Gleb Kanterov --- flyteplugins/go.mod | 4 +- flyteplugins/go.sum | 4 + .../go/tasks/pluginmachinery/core/phase.go | 6 + .../go/tasks/pluginmachinery/google/config.go | 19 + .../google/default_token_source_factory.go | 18 + .../google/token_source_factory.go | 25 + .../go/tasks/pluginmachinery/registry.go | 4 + .../tasks/plugins/webapi/bigquery/config.go | 74 +++ .../plugins/webapi/bigquery/config_flags.go | 55 ++ .../webapi/bigquery/config_flags_test.go | 322 +++++++++++ .../plugins/webapi/bigquery/config_test.go | 28 + .../webapi/bigquery/integration_test.go | 98 ++++ .../tasks/plugins/webapi/bigquery/plugin.go | 503 ++++++++++++++++++ .../plugins/webapi/bigquery/plugin_test.go | 206 +++++++ .../plugins/webapi/bigquery/query_job.go | 258 +++++++++ .../plugins/webapi/bigquery/query_job_test.go | 83 +++ flyteplugins/tests/end_to_end.go | 9 +- 17 files changed, 1714 insertions(+), 2 deletions(-) create mode 100644 flyteplugins/go/tasks/pluginmachinery/google/config.go create mode 100644 flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go create mode 100644 flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/config.go create mode 100755 flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags.go create mode 100755 flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/config_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/query_job.go create 
mode 100644 flyteplugins/go/tasks/plugins/webapi/bigquery/query_job_test.go diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index 752138e031..9b9b403385 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -36,11 +36,13 @@ require ( go.uber.org/zap v1.16.0 // indirect golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 // indirect golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 - golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93 // indirect + golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93 golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 // indirect golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect + google.golang.org/api v0.40.0 google.golang.org/grpc v1.35.0 + google.golang.org/protobuf v1.25.0 gotest.tools v2.2.0+incompatible k8s.io/api v0.20.2 k8s.io/apimachinery v0.20.2 diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index 09c23b59f7..5cd006933e 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -72,6 +72,7 @@ github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= @@ -216,6 +217,7 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane 
v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607 h1:cTavhURetDkezJCvxFggiyLeP40Mrk/TtVg2+ycw1Es= github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607/go.mod h1:Cg4fM0vhYWOZdgM7RIOSTRNIc8/VT7CXClC3Ni86lu4= github.com/evanphx/json-patch v0.0.0-20200808040245-162e5629780b/go.mod h1:NAJj0yf/KaRKURN6nyi7A9IZydMivZEm9oQLWNjfKDc= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= @@ -225,6 +227,7 @@ github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLi github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= +github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/flyteorg/flyteidl v0.18.48 h1:WYTat8kFS0mDxLoTEQai2/uy4YO/cavsvh1t3/EKQCw= github.com/flyteorg/flyteidl v0.18.48/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= @@ -1212,6 +1215,7 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod 
h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.0.0-20210217171935-8e2decd92398/go.mod h1:60tmSUpHxGPFerNHbo/ayI2lKxvtrhbxFyXuEIWJd78= k8s.io/api v0.18.2/go.mod h1:SJCWI7OLzhZSvbY7U8zwNl9UA4o1fizoug34OV/2r78= diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index b6a06bb3e2..da9cc9faf7 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -172,6 +172,12 @@ func PhaseInfoQueued(t time.Time, version uint32, reason string) PhaseInfo { return pi } +func PhaseInfoQueuedWithTaskInfo(version uint32, reason string, info *TaskInfo) PhaseInfo { + pi := phaseInfo(PhaseQueued, version, nil, info) + pi.reason = reason + return pi +} + func PhaseInfoInitializing(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo { pi := phaseInfo(PhaseInitializing, version, nil, info) diff --git a/flyteplugins/go/tasks/pluginmachinery/google/config.go b/flyteplugins/go/tasks/pluginmachinery/google/config.go new file mode 100644 index 0000000000..445cb9efdf --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/google/config.go @@ -0,0 +1,19 @@ +package google + +type TokenSourceFactoryType = string + +const ( + TokenSourceTypeDefault = "default" +) + +type TokenSourceFactoryConfig struct { + // Type is type of TokenSourceFactory, possible values are 'default' or 'gke'. 
+ // - 'default' uses default credentials, see https://cloud.google.com/iam/docs/service-accounts#default + Type TokenSourceFactoryType `json:"type" pflag:",Defines type of TokenSourceFactory, possible values are 'default'"` +} + +func GetDefaultConfig() TokenSourceFactoryConfig { + return TokenSourceFactoryConfig{ + Type: "default", + } +} diff --git a/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go b/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go new file mode 100644 index 0000000000..430e208791 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go @@ -0,0 +1,18 @@ +package google + +import ( + "context" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +type defaultTokenSource struct{} + +func (m *defaultTokenSource) GetTokenSource(ctx context.Context, identity Identity) (oauth2.TokenSource, error) { + return google.DefaultTokenSource(ctx) +} + +func NewDefaultTokenSourceFactory() (TokenSourceFactory, error) { + return &defaultTokenSource{}, nil +} diff --git a/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go b/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go new file mode 100644 index 0000000000..05207e25c9 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go @@ -0,0 +1,25 @@ +package google + +import ( + "context" + + "github.com/pkg/errors" + "golang.org/x/oauth2" +) + +type Identity struct { + K8sNamespace string + K8sServiceAccount string +} + +type TokenSourceFactory interface { + GetTokenSource(ctx context.Context, identity Identity) (oauth2.TokenSource, error) +} + +func NewTokenSourceFactory(config TokenSourceFactoryConfig) (TokenSourceFactory, error) { + if config.Type == TokenSourceTypeDefault { + return NewDefaultTokenSourceFactory() + } + + return nil, errors.Errorf("unknown token source type [%v], possible values are: 'default'", config.Type) +} diff --git 
a/flyteplugins/go/tasks/pluginmachinery/registry.go b/flyteplugins/go/tasks/pluginmachinery/registry.go index a3bb49793c..76e7844bbd 100644 --- a/flyteplugins/go/tasks/pluginmachinery/registry.go +++ b/flyteplugins/go/tasks/pluginmachinery/registry.go @@ -45,6 +45,10 @@ func (p *taskPluginRegistry) RegisterRemotePlugin(info webapi.PluginEntry) { p.corePlugin = append(p.corePlugin, internalRemote.CreateRemotePlugin(info)) } +func CreateRemotePlugin(pluginEntry webapi.PluginEntry) core.PluginEntry { + return internalRemote.CreateRemotePlugin(pluginEntry) +} + // Use this method to register Kubernetes Plugins func (p *taskPluginRegistry) RegisterK8sPlugin(info k8s.PluginEntry) { if info.ID == "" { diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/config.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/config.go new file mode 100644 index 0000000000..fe1bcad30e --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/config.go @@ -0,0 +1,74 @@ +// Package bigquery implements WebAPI plugin for Google BigQuery +package bigquery + +import ( + "time" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/google" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" +) + +//go:generate pflags Config --default-var=defaultConfig + +var ( + defaultConfig = Config{ + WebAPI: webapi.PluginConfig{ + ResourceQuotas: map[core.ResourceNamespace]int{ + "default": 1000, + }, + ReadRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + WriteRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + Caching: webapi.CachingConfig{ + Size: 500000, + ResyncInterval: config.Duration{Duration: 30 * time.Second}, + Workers: 10, + MaxSystemFailures: 5, + }, + ResourceMeta: nil, + }, + ResourceConstraints: core.ResourceConstraintsSpec{ + 
ProjectScopeResourceConstraint: &core.ResourceConstraint{ + Value: 100, + }, + NamespaceScopeResourceConstraint: &core.ResourceConstraint{ + Value: 50, + }, + }, + GoogleTokenSource: google.GetDefaultConfig(), + } + + configSection = pluginsConfig.MustRegisterSubSection("bigquery", &defaultConfig) +) + +// Config is config for 'bigquery' plugin +type Config struct { + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` + + // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time + ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` + + // GoogleTokenSource configures token source for BigQuery client + GoogleTokenSource google.TokenSourceFactoryConfig `json:"googleTokenSource" pflag:",Defines Google token source"` + + // bigQueryEndpoint overrides BigQuery client endpoint, only for testing + bigQueryEndpoint string +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags.go new file mode 100755 index 0000000000..3ed4f34395 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags.go @@ -0,0 +1,55 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package bigquery + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. 
+func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.readRateLimiter.qps"), defaultConfig.WebAPI.ReadRateLimiter.QPS, "Defines the max rate of calls per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.readRateLimiter.burst"), defaultConfig.WebAPI.ReadRateLimiter.Burst, "Defines the maximum burst size.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.writeRateLimiter.qps"), defaultConfig.WebAPI.WriteRateLimiter.QPS, "Defines the max rate of calls per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.writeRateLimiter.burst"), defaultConfig.WebAPI.WriteRateLimiter.Burst, "Defines the maximum burst size.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.caching.size"), defaultConfig.WebAPI.Caching.Size, "Defines the maximum number of items to cache.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "webApi.caching.resyncInterval"), defaultConfig.WebAPI.Caching.ResyncInterval.String(), "Defines the sync interval.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.caching.workers"), defaultConfig.WebAPI.Caching.Workers, "Defines the number of workers to start up to process items.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.caching.maxSystemFailures"), 
defaultConfig.WebAPI.Caching.MaxSystemFailures, "Defines the number of failures to fetch a task before failing the task.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "googleTokenSource.type"), defaultConfig.GoogleTokenSource.Type, "Defines type of TokenSourceFactory, possible values are 'default'") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "bigQueryEndpoint"), defaultConfig.bigQueryEndpoint, "") + return cmdFlags +} diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags_test.go new file mode 100755 index 0000000000..7d969f8734 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/config_flags_test.go @@ -0,0 +1,322 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package bigquery + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. 
Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_webApi.readRateLimiter.qps", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.readRateLimiter.qps"); err == nil { + assert.Equal(t, int(defaultConfig.WebAPI.ReadRateLimiter.QPS), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t 
*testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.readRateLimiter.qps", testValue) + if vInt, err := cmdFlags.GetInt("webApi.readRateLimiter.qps"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.ReadRateLimiter.QPS) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.readRateLimiter.burst", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.readRateLimiter.burst"); err == nil { + assert.Equal(t, int(defaultConfig.WebAPI.ReadRateLimiter.Burst), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.readRateLimiter.burst", testValue) + if vInt, err := cmdFlags.GetInt("webApi.readRateLimiter.burst"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.ReadRateLimiter.Burst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.writeRateLimiter.qps", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.writeRateLimiter.qps"); err == nil { + assert.Equal(t, int(defaultConfig.WebAPI.WriteRateLimiter.QPS), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.writeRateLimiter.qps", testValue) + if vInt, err := cmdFlags.GetInt("webApi.writeRateLimiter.qps"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.WriteRateLimiter.QPS) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.writeRateLimiter.burst", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.writeRateLimiter.burst"); err == nil { + assert.Equal(t, 
int(defaultConfig.WebAPI.WriteRateLimiter.Burst), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.writeRateLimiter.burst", testValue) + if vInt, err := cmdFlags.GetInt("webApi.writeRateLimiter.burst"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.WriteRateLimiter.Burst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.caching.size", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.caching.size"); err == nil { + assert.Equal(t, int(defaultConfig.WebAPI.Caching.Size), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.caching.size", testValue) + if vInt, err := cmdFlags.GetInt("webApi.caching.size"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.Caching.Size) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.caching.resyncInterval", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("webApi.caching.resyncInterval"); err == nil { + assert.Equal(t, string(defaultConfig.WebAPI.Caching.ResyncInterval.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.WebAPI.Caching.ResyncInterval.String() + + cmdFlags.Set("webApi.caching.resyncInterval", testValue) + if vString, err := cmdFlags.GetString("webApi.caching.resyncInterval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.WebAPI.Caching.ResyncInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.caching.workers", func(t *testing.T) { + 
t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.caching.workers"); err == nil { + assert.Equal(t, int(defaultConfig.WebAPI.Caching.Workers), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.caching.workers", testValue) + if vInt, err := cmdFlags.GetInt("webApi.caching.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.Caching.Workers) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_webApi.caching.maxSystemFailures", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("webApi.caching.maxSystemFailures"); err == nil { + assert.Equal(t, int(defaultConfig.WebAPI.Caching.MaxSystemFailures), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("webApi.caching.maxSystemFailures", testValue) + if vInt, err := cmdFlags.GetInt("webApi.caching.maxSystemFailures"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WebAPI.Caching.MaxSystemFailures) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_googleTokenSource.type", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("googleTokenSource.type"); err == nil { + assert.Equal(t, string(defaultConfig.GoogleTokenSource.Type), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("googleTokenSource.type", testValue) + if vString, err := cmdFlags.GetString("googleTokenSource.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.GoogleTokenSource.Type) + + } 
else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_bigQueryEndpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("bigQueryEndpoint"); err == nil { + assert.Equal(t, string(defaultConfig.bigQueryEndpoint), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("bigQueryEndpoint", testValue) + if vString, err := cmdFlags.GetString("bigQueryEndpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.bigQueryEndpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/config_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/config_test.go new file mode 100644 index 0000000000..88de9cf16b --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/config_test.go @@ -0,0 +1,28 @@ +package bigquery + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestUnmarshalBigQueryQueryConfig(t *testing.T) { + custom := structpb.Struct{ + Fields: map[string]*structpb.Value{ + "projectId": structpb.NewStringValue("project-id"), + "location": structpb.NewStringValue("EU"), + "query": structpb.NewStringValue("SELECT 1"), + }, + } + + config, err := unmarshalQueryJobConfig(&custom) + + assert.NoError(t, err) + + assert.Equal(t, config, &QueryJobConfig{ + ProjectID: "project-id", + Location: "EU", + Query: "SELECT 1", + }) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go new file mode 100644 index 0000000000..1871506332 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/integration_test.go @@ -0,0 +1,98 @@ +package bigquery + +import ( + "context" + "encoding/json" + "net/http" + 
"net/http/httptest" + "strings" + "testing" + "time" + + "github.com/flyteorg/flyteidl/clients/go/coreutils" + + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + pluginUtils "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/tests" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/api/bigquery/v2" +) + +func TestEndToEnd(t *testing.T) { + server := newFakeBigQueryServer() + defer server.Close() + + iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { + return nil + } + + cfg := defaultConfig + cfg.bigQueryEndpoint = server.URL + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + + pluginEntry := pluginmachinery.CreateRemotePlugin(newBigQueryJobTaskPlugin()) + plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext()) + assert.NoError(t, err) + + t.Run("SELECT 1", func(t *testing.T) { + queryJobConfig := QueryJobConfig{ + ProjectID: "flyte", + Query: "SELECT 1", + } + + inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) + custom, _ := pluginUtils.MarshalObjToStruct(queryJobConfig) + template := flyteIdlCore.TaskTemplate{ + Type: bigqueryQueryJobTask, + Custom: custom, + } + + phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) + + assert.Equal(t, true, phase.Phase().IsSuccess()) + }) +} + +func newFakeBigQueryServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(writer 
http.ResponseWriter, request *http.Request) { + if request.URL.Path == "/projects/flyte/jobs" && request.Method == "POST" { + writer.WriteHeader(200) + job := bigquery.Job{Status: &bigquery.JobStatus{State: "RUNNING"}} + bytes, _ := json.Marshal(job) + _, _ = writer.Write(bytes) + return + } + + if strings.HasPrefix(request.URL.Path, "/projects/flyte/jobs/") && request.Method == "GET" { + writer.WriteHeader(200) + job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}} + bytes, _ := json.Marshal(job) + _, _ = writer.Write(bytes) + return + } + + writer.WriteHeader(500) + })) +} + +func newFakeSetupContext() *pluginCoreMocks.SetupContext { + fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{} + fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil) + labeled.SetMetricKeys(contextutils.NamespaceKey) + + fakeSetupContext := pluginCoreMocks.SetupContext{} + fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) + fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar) + + return &fakeSetupContext +} diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go new file mode 100644 index 0000000000..7a8b21d282 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go @@ -0,0 +1,503 @@ +package bigquery + +import ( + "context" + "encoding/gob" + "fmt" + "net/http" + "time" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + + "golang.org/x/oauth2" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/google" + structpb "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/api/bigquery/v2" + "google.golang.org/api/googleapi" + + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + 
"google.golang.org/api/option" + + "github.com/flyteorg/flytestdlib/logger" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" +) + +const ( + bigqueryQueryJobTask = "bigquery_query_job_task" + bigqueryConsolePath = "https://console.cloud.google.com/bigquery" +) + +type Plugin struct { + metricScope promutils.Scope + cfg *Config + googleTokenSource google.TokenSourceFactory +} + +type ResourceWrapper struct { + Status *bigquery.JobStatus + CreateError *googleapi.Error +} + +type ResourceMetaWrapper struct { + K8sServiceAccount string + Namespace string + JobReference bigquery.JobReference +} + +func (p Plugin) GetConfig() webapi.PluginConfig { + return GetConfig().WebAPI +} + +func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + + // Resource requirements are assumed to be the same. 
	return "default", p.cfg.ResourceConstraints, nil
}

func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta,
	webapi.Resource, error) {
	return p.createImpl(ctx, taskCtx)
}

// createImpl reads the task template and inputs, builds a BigQuery job and
// inserts it. A 409 response (job already exists, e.g. a retried create) is
// treated as success: the existing job is fetched instead, which makes
// creation idempotent.
func (p Plugin) createImpl(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (*ResourceMetaWrapper,
	*ResourceWrapper, error) {

	taskTemplate, err := taskCtx.TaskReader().Read(ctx)
	jobID := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

	if err != nil {
		return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "unable to fetch task specification")
	}

	inputs, err := taskCtx.InputReader().Get(ctx)

	if err != nil {
		return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "unable to fetch task inputs")
	}

	var job *bigquery.Job

	// The client is built per call with credentials derived from the task's
	// namespace / service account identity.
	namespace := taskCtx.TaskExecutionMetadata().GetNamespace()
	k8sServiceAccount := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
	identity := google.Identity{K8sNamespace: namespace, K8sServiceAccount: k8sServiceAccount}
	client, err := p.newBigQueryClient(ctx, identity)

	if err != nil {
		return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "unable to get bigquery client")
	}

	if taskTemplate.Type == bigqueryQueryJobTask {
		job, err = createQueryJob(jobID, taskTemplate.GetCustom(), inputs)
	} else {
		err = pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "unexpected task type [%v]", taskTemplate.Type)
	}

	if err != nil {
		return nil, nil, err
	}

	job.Configuration.Labels = taskCtx.TaskExecutionMetadata().GetLabels()

	resp, err := client.Jobs.Insert(job.JobReference.ProjectId, job).Do()

	if err != nil {
		apiError, ok := err.(*googleapi.Error)
		resourceMeta := ResourceMetaWrapper{
			JobReference:      *job.JobReference,
			Namespace:         namespace,
			K8sServiceAccount: k8sServiceAccount,
		}

		// 409: a job with this ID already exists — fetch it and continue.
		if ok && apiError.Code == 409 {
			job, err := client.Jobs.Get(resourceMeta.JobReference.ProjectId, resourceMeta.JobReference.JobId).Do()

			if err != nil {
				err := pluginErrors.Wrapf(
					pluginErrors.RuntimeFailure,
					err,
					"failed to get job [%s]",
					formatJobReference(resourceMeta.JobReference))

				return nil, nil, err
			}

			resource := ResourceWrapper{Status: job.Status}

			return &resourceMeta, &resource, nil
		}

		// Other API errors are carried in the resource so Status() can map
		// them to user/system failures later.
		if ok {
			resource := ResourceWrapper{CreateError: apiError}

			return &resourceMeta, &resource, nil
		}

		return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "failed to create query job")
	}

	resource := ResourceWrapper{Status: resp.Status}
	resourceMeta := ResourceMetaWrapper{
		JobReference:      *job.JobReference,
		Namespace:         namespace,
		K8sServiceAccount: k8sServiceAccount,
	}

	return &resourceMeta, &resource, nil
}

// createQueryJob translates the task's custom proto struct and literal inputs
// into a BigQuery Job request.
func createQueryJob(jobID string, custom *structpb.Struct, inputs *flyteIdlCore.LiteralMap) (*bigquery.Job, error) {
	queryJobConfig, err := unmarshalQueryJobConfig(custom)

	if err != nil {
		return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "can't unmarshall struct to QueryJobConfig")
	}

	jobConfigurationQuery, err := getJobConfigurationQuery(queryJobConfig, inputs)

	if err != nil {
		return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "unable to fetch task inputs")
	}

	jobReference := bigquery.JobReference{
		JobId:     jobID,
		Location:  queryJobConfig.Location,
		ProjectId: queryJobConfig.ProjectID,
	}

	return &bigquery.Job{
		Configuration: &bigquery.JobConfiguration{
			Query: jobConfigurationQuery,
		},
		JobReference: &jobReference,
	}, nil
}

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
	return p.getImpl(ctx, taskCtx)
}

// getImpl fetches the current status of the remote job.
func (p Plugin) getImpl(ctx context.Context, taskCtx webapi.GetContext) (wrapper *ResourceWrapper, err error) {
	resourceMeta := taskCtx.ResourceMeta().(*ResourceMetaWrapper)

	identity :=
google.Identity{
		K8sNamespace:      resourceMeta.Namespace,
		K8sServiceAccount: resourceMeta.K8sServiceAccount,
	}
	client, err := p.newBigQueryClient(ctx, identity)

	if err != nil {
		return nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "unable to get client")
	}

	job, err := client.Jobs.Get(resourceMeta.JobReference.ProjectId, resourceMeta.JobReference.JobId).Do()

	if err != nil {
		err := pluginErrors.Wrapf(
			pluginErrors.RuntimeFailure,
			err,
			"failed to get job [%s]",
			formatJobReference(resourceMeta.JobReference))

		return nil, err
	}

	return &ResourceWrapper{
		Status: job.Status,
	}, nil
}

// Delete cancels the remote BigQuery job, if one was ever created.
func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error {
	// Nothing to cancel if Create never produced a resource.
	if taskCtx.ResourceMeta() == nil {
		return nil
	}

	resourceMeta := taskCtx.ResourceMeta().(*ResourceMetaWrapper)

	identity := google.Identity{
		K8sNamespace:      resourceMeta.Namespace,
		K8sServiceAccount: resourceMeta.K8sServiceAccount,
	}
	client, err := p.newBigQueryClient(ctx, identity)

	if err != nil {
		return err
	}

	_, err = client.Jobs.Cancel(resourceMeta.JobReference.ProjectId, resourceMeta.JobReference.JobId).Do()

	if err != nil {
		return err
	}

	// Fix: logger.Info takes plain variadic args and does not interpolate
	// format verbs — use Infof so the job reference is rendered into the
	// message instead of logging a literal "%s".
	logger.Infof(ctx, "Cancelled job [%s]", formatJobReference(resourceMeta.JobReference))

	return nil
}

// Status maps the cached remote job state onto a Flyte phase. Creation errors
// captured by createImpl take precedence over the job's own state.
func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) {
	resourceMeta := tCtx.ResourceMeta().(*ResourceMetaWrapper)
	resource := tCtx.Resource().(*ResourceWrapper)
	version := pluginsCore.DefaultPhaseVersion

	if resource == nil {
		return core.PhaseInfoUndefined, nil
	}

	taskInfo := createTaskInfo(resourceMeta)

	if resource.CreateError != nil {
		return handleCreateError(resource.CreateError, taskInfo), nil
	}

	switch resource.Status.State {
	case "PENDING":
		return core.PhaseInfoQueuedWithTaskInfo(version, "Query is PENDING", taskInfo), nil

	case "RUNNING":
		return core.PhaseInfoRunning(version,
			taskInfo), nil

	case "DONE":
		if resource.Status.ErrorResult != nil {
			return handleErrorResult(
				resource.Status.ErrorResult.Reason,
				resource.Status.ErrorResult.Message,
				taskInfo), nil
		}

		return pluginsCore.PhaseInfoSuccess(taskInfo), nil
	}

	return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.Status.State)
}

// handleCreateError converts an HTTP error from job creation into a terminal
// phase: 4xx is a permanent user failure, 5xx a retryable system failure, and
// anything else a permanent system failure.
func handleCreateError(createError *googleapi.Error, taskInfo *core.TaskInfo) core.PhaseInfo {
	code := fmt.Sprintf("http%d", createError.Code)

	userExecutionError := &flyteIdlCore.ExecutionError{
		Message: createError.Message,
		Kind:    flyteIdlCore.ExecutionError_USER,
		Code:    code,
	}

	systemExecutionError := &flyteIdlCore.ExecutionError{
		Message: createError.Message,
		Kind:    flyteIdlCore.ExecutionError_SYSTEM,
		Code:    code,
	}

	if createError.Code >= http.StatusBadRequest && createError.Code < http.StatusInternalServerError {
		return core.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, userExecutionError, taskInfo)
	}

	if createError.Code >= http.StatusInternalServerError {
		return core.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, systemExecutionError, taskInfo)
	}

	// something unexpected happened, just terminate task
	return core.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, systemExecutionError, taskInfo)
}

// handleErrorResult maps a completed job's error reason onto a Flyte phase.
// An empty reason means the job succeeded.
func handleErrorResult(reason string, message string, taskInfo *core.TaskInfo) core.PhaseInfo {
	phaseCode := reason
	phaseReason := message

	// see https://cloud.google.com/bigquery/docs/error-messages

	// user errors are errors where users have to take action, e.g. fix their code
	// all errors with project configuration are also considered as user errors

	// system errors are errors where system doesn't work well and system owners have to take action
	// all errors internal to BigQuery are also considered as system errors

	// transient errors are retryable, if any action is needed, errors are permanent

	switch reason {
	case "":
		return pluginsCore.PhaseInfoSuccess(taskInfo)

	// This error returns when you try to access a resource such as a dataset, table, view, or job that you
	// don't have access to. This error also returns when you try to modify a read-only object.
	case "accessDenied":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when there is a temporary server failure such as a network connection problem or
	// a server overload.
	case "backendError":
		return pluginsCore.PhaseInfoSystemRetryableFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when billing isn't enabled for the project.
	case "billingNotEnabled":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when BigQuery has temporarily denylisted the operation you attempted to perform,
	// usually to prevent a service outage. This error rarely occurs.
	case "blocked":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when trying to create a job, dataset, or table that already exists. The error also
	// returns when a job's writeDisposition property is set to WRITE_EMPTY and the destination table accessed
	// by the job already exists.
	case "duplicate":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when an internal error occurs within BigQuery.
	case "internalError":
		return pluginsCore.PhaseInfoSystemRetryableFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when there is any kind of invalid input other than an invalid query, such as missing
	// required fields or an invalid table schema. Invalid queries return an invalidQuery error instead.
	case "invalid":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when you attempt to run an invalid query.
	case "invalidQuery":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when you attempt to schedule a query with invalid user credentials.
	case "invalidUser":
		return pluginsCore.PhaseInfoSystemRetryableFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when you refer to a resource (a dataset, a table, or a job) that doesn't exist.
	// This can also occur when using snapshot decorators to refer to deleted tables that have recently been
	// streamed to.
	case "notFound":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This job error returns when you try to access a feature that isn't implemented.
	case "notImplemented":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when your project exceeds a BigQuery quota, a custom quota, or when you haven't set up
	// billing and you have exceeded the free tier for queries.
	case "quotaExceeded":
		return pluginsCore.PhaseInfoRetryableFailure(phaseCode, phaseReason, taskInfo)

	case "rateLimitExceeded":
		return pluginsCore.PhaseInfoRetryableFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when you try to delete a dataset that contains tables or when you try to delete a job
	// that is currently running.
	case "resourceInUse":
		return pluginsCore.PhaseInfoSystemRetryableFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when your query uses too many resources.
	case "resourcesExceeded":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This error returns when your query's results are larger than the maximum response size. Some queries execute
	// in multiple stages, and this error returns when any stage returns a response size that is too large, even if
	// the final result is smaller than the maximum. This error commonly returns when queries use an ORDER BY
	// clause.
	case "responseTooLarge":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// This status code returns when a job is canceled.
	case "stopped":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	// Certain BigQuery tables are backed by data managed by other Google product teams. This error indicates that
	// one of these tables is unavailable.
	case "tableUnavailable":
		return pluginsCore.PhaseInfoSystemRetryableFailure(phaseCode, phaseReason, taskInfo)

	// The job timed out.
	case "timeout":
		return pluginsCore.PhaseInfoFailure(phaseCode, phaseReason, taskInfo)

	default:
		return pluginsCore.PhaseInfoSystemFailure(phaseCode, phaseReason, taskInfo)
	}
}

// createTaskInfo builds the task info shown in the UI, including a deep link
// to the job's query results in the BigQuery console.
func createTaskInfo(resourceMeta *ResourceMetaWrapper) *core.TaskInfo {
	timeNow := time.Now()
	j := formatJobReferenceForQueryParam(resourceMeta.JobReference)

	return &core.TaskInfo{
		OccurredAt: &timeNow,
		Logs: []*flyteIdlCore.TaskLog{
			{
				Uri: fmt.Sprintf("%s?project=%v&j=%v&page=queryresults",
					bigqueryConsolePath,
					resourceMeta.JobReference.ProjectId,
					j),
				Name: "BigQuery Console",
			},
		},
	}
}

// formatJobReference renders a job reference as "project:location.jobId" for
// log and error messages.
func formatJobReference(reference bigquery.JobReference) string {
	return fmt.Sprintf("%s:%s.%s", reference.ProjectId, reference.Location, reference.JobId)
}

// formatJobReferenceForQueryParam renders the "j" console query parameter,
// "bq:location:jobId".
func formatJobReferenceForQueryParam(jobReference bigquery.JobReference) string {
	return fmt.Sprintf("bq:%s:%s", jobReference.Location, jobReference.JobId)
}

func (p Plugin) newBigQueryClient(ctx context.Context, identity
google.Identity) (*bigquery.Service, error) {
	options := []option.ClientOption{
		option.WithScopes("https://www.googleapis.com/auth/bigquery"),
		// FIXME how do I access current version?
		option.WithUserAgent(fmt.Sprintf("%s/%s", "flytepropeller", "LATEST")),
	}

	// for mocking/testing purposes
	if p.cfg.bigQueryEndpoint != "" {
		options = append(options,
			option.WithEndpoint(p.cfg.bigQueryEndpoint),
			option.WithTokenSource(oauth2.StaticTokenSource(&oauth2.Token{})))
	} else {
		// Production path: mint a token for the task's workload identity.
		tokenSource, err := p.googleTokenSource.GetTokenSource(ctx, identity)

		if err != nil {
			return nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "unable to get token source")
		}

		options = append(options, option.WithTokenSource(tokenSource))
	}

	return bigquery.NewService(ctx, options...)
}

// NewPlugin constructs the BigQuery plugin with a token source factory built
// from the supplied config.
func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, error) {
	googleTokenSource, err := google.NewTokenSourceFactory(cfg.GoogleTokenSource)

	if err != nil {
		return nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
	}

	return &Plugin{
		metricScope:       metricScope,
		cfg:               cfg,
		googleTokenSource: googleTokenSource,
	}, nil
}

// newBigQueryJobTaskPlugin describes this plugin to the WebAPI machinery.
func newBigQueryJobTaskPlugin() webapi.PluginEntry {
	return webapi.PluginEntry{
		ID:                 "bigquery",
		SupportedTaskTypes: []core.TaskType{bigqueryQueryJobTask},
		PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
			cfg := GetConfig()

			return NewPlugin(cfg, iCtx.MetricsScope())
		},
	}
}

func init() {
	// gob registration is required so resource meta/state survives
	// serialization by the WebAPI cache.
	gob.Register(ResourceMetaWrapper{})
	gob.Register(ResourceWrapper{})

	pluginmachinery.PluginRegistry().RegisterRemotePlugin(newBigQueryJobTaskPlugin())
}
diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go
new file mode 100644
index 0000000000..0de1d30c5a
--- /dev/null
+++
b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin_test.go
@@ -0,0 +1,206 @@
package bigquery

import (
	"testing"
	"time"

	flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
	pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
	"github.com/stretchr/testify/assert"
	"google.golang.org/api/bigquery/v2"
	"google.golang.org/api/googleapi"
)

func TestFormatJobReference(t *testing.T) {
	t.Run("format job reference", func(t *testing.T) {
		jobReference := bigquery.JobReference{
			JobId:     "my-job-id",
			Location:  "EU",
			ProjectId: "flyte-test",
		}

		str := formatJobReference(jobReference)

		assert.Equal(t, "flyte-test:EU.my-job-id", str)
	})
}

func TestCreateTaskInfo(t *testing.T) {
	t.Run("create task info", func(t *testing.T) {
		resourceMeta := ResourceMetaWrapper{
			JobReference: bigquery.JobReference{
				JobId:     "my-job-id",
				Location:  "EU",
				ProjectId: "flyte-test",
			},
		}

		taskInfo := createTaskInfo(&resourceMeta)

		// Exactly one log link, pointing at the console query-results page.
		assert.Equal(t, 1, len(taskInfo.Logs))
		assert.Equal(t, flyteIdlCore.TaskLog{
			Uri:  "https://console.cloud.google.com/bigquery?project=flyte-test&j=bq:EU:my-job-id&page=queryresults",
			Name: "BigQuery Console",
		}, *taskInfo.Logs[0])
	})
}

func TestHandleCreateError(t *testing.T) {
	occurredAt := time.Now()
	taskInfo := core.TaskInfo{OccurredAt: &occurredAt}

	// 4xx → permanent USER failure.
	t.Run("handle 401", func(t *testing.T) {
		createError := googleapi.Error{
			Code:    401,
			Message: "user xxx is not authorized",
		}

		phase := handleCreateError(&createError, &taskInfo)

		assert.Equal(t, flyteIdlCore.ExecutionError{
			Code:    "http401",
			Message: "user xxx is not authorized",
			Kind:    flyteIdlCore.ExecutionError_USER,
		}, *phase.Err())
		assert.Equal(t, taskInfo, *phase.Info())
	})

	// 5xx → retryable SYSTEM failure.
	t.Run("handle 500", func(t *testing.T) {
		createError := googleapi.Error{
			Code:    500,
			Message: "oops",
		}

		phase := handleCreateError(&createError, &taskInfo)

		assert.Equal(t, flyteIdlCore.ExecutionError{
			Code:    "http500",
			Message: "oops",
			Kind:    flyteIdlCore.ExecutionError_SYSTEM,
		}, *phase.Err())
		assert.Equal(t, taskInfo, *phase.Info())
	})
}

// TestHandleErrorResult is a table test covering every documented BigQuery
// error reason and its expected phase / error kind.
func TestHandleErrorResult(t *testing.T) {
	occurredAt := time.Now()
	taskInfo := core.TaskInfo{OccurredAt: &occurredAt}

	type args struct {
		reason    string
		phase     core.Phase
		errorKind flyteIdlCore.ExecutionError_ErrorKind
	}

	tests := []args{
		{reason: "accessDenied", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "backendError", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_SYSTEM},
		{reason: "billingNotEnabled", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "blocked", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "duplicate", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "internalError", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_SYSTEM},
		{reason: "invalid", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "invalidQuery", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "invalidUser", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_SYSTEM},
		{reason: "notFound", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "notImplemented", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "quotaExceeded", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "rateLimitExceeded", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "resourceInUse", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_SYSTEM},

		{reason: "resourcesExceeded", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},

		{reason: "responseTooLarge", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "stopped", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
		{reason: "tableUnavailable", phase: pluginsCore.PhaseRetryableFailure, errorKind: flyteIdlCore.ExecutionError_SYSTEM},
		{reason: "timeout", phase: pluginsCore.PhasePermanentFailure, errorKind: flyteIdlCore.ExecutionError_USER},
	}

	for _, test := range tests {
		t.Run(test.reason, func(t *testing.T) {
			phaseInfo := handleErrorResult(test.reason, "message", &taskInfo)

			assert.Equal(t, test.phase, phaseInfo.Phase())
			assert.Equal(t, test.reason, phaseInfo.Err().Code)
			assert.Equal(t, test.errorKind, phaseInfo.Err().Kind)
			assert.Equal(t, "message", phaseInfo.Err().Message)
		})
	}
}
diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/query_job.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/query_job.go
new file mode 100644
index 0000000000..9c1fec1cd6
--- /dev/null
+++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/query_job.go
@@ -0,0 +1,258 @@
package bigquery

import (
	"strconv"

	"github.com/pkg/errors"

	flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
	pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
	pluginUtils "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
	structpb "github.com/golang/protobuf/ptypes/struct"
	"google.golang.org/api/bigquery/v2"
)

// QueryJobConfig is the task-custom configuration for a BigQuery query job.
// Field semantics mirror bigquery.JobConfigurationQuery.
type QueryJobConfig
struct {
	Location  string `json:"location"`
	ProjectID string `json:"projectId"`

	// AllowLargeResults: [Optional] If true and query uses legacy SQL
	// dialect, allows the query to produce arbitrarily large result tables
	// at a slight cost in performance. Requires destinationTable to be set.
	// For standard SQL queries, this flag is ignored and large results are
	// always allowed. However, you must still set destinationTable when
	// result size exceeds the allowed maximum response size.
	AllowLargeResults bool `json:"allowLargeResults,omitempty"`

	// Clustering: [Beta] Clustering specification for the destination
	// table. Must be specified with time-based partitioning, data in the
	// table will be first partitioned and subsequently clustered.
	Clustering *bigquery.Clustering `json:"clustering,omitempty"`

	// CreateDisposition: [Optional] Specifies whether the job is allowed to
	// create new tables. The following values are supported:
	// CREATE_IF_NEEDED: If the table does not exist, BigQuery creates the
	// table. CREATE_NEVER: The table must already exist. If it does not, a
	// 'notFound' error is returned in the job result. The default value is
	// CREATE_IF_NEEDED. Creation, truncation and append actions occur as
	// one atomic update upon job completion.
	CreateDisposition string `json:"createDisposition,omitempty"`

	// DefaultDataset: [Optional] Specifies the default dataset to use for
	// unqualified table names in the query. Note that this does not alter
	// behavior of unqualified dataset names.
	DefaultDataset *bigquery.DatasetReference `json:"defaultDataset,omitempty"`

	// DestinationEncryptionConfiguration: Custom encryption configuration
	// (e.g., Cloud KMS keys).
	DestinationEncryptionConfiguration *bigquery.EncryptionConfiguration `json:"destinationEncryptionConfiguration,omitempty"`

	// DestinationTable: [Optional] Describes the table where the query
	// results should be stored. If not present, a new table will be created
	// to store the results. This property must be set for large results
	// that exceed the maximum response size.
	DestinationTable *bigquery.TableReference `json:"destinationTable,omitempty"`

	// FlattenResults: [Optional] If true and query uses legacy SQL dialect,
	// flattens all nested and repeated fields in the query results.
	// allowLargeResults must be true if this is set to false. For standard
	// SQL queries, this flag is ignored and results are never flattened.
	//
	// Default: true
	FlattenResults *bool `json:"flattenResults,omitempty"`

	// MaximumBillingTier: [Optional] Limits the billing tier for this job.
	// Queries that have resource usage beyond this tier will fail (without
	// incurring a charge). If unspecified, this will be set to your project
	// default.
	//
	// Default: 1
	MaximumBillingTier *int64 `json:"maximumBillingTier,omitempty"`

	// MaximumBytesBilled: [Optional] Limits the bytes billed for this job.
	// Queries that will have bytes billed beyond this limit will fail
	// (without incurring a charge). If unspecified, this will be set to
	// your project default.
	MaximumBytesBilled int64 `json:"maximumBytesBilled,omitempty,string"`

	// Priority: [Optional] Specifies a priority for the query. Possible
	// values include INTERACTIVE and BATCH. The default value is
	// INTERACTIVE.
	Priority string `json:"priority,omitempty"`

	// Query: [Required] SQL query text to execute. The useLegacySql field
	// can be used to indicate whether the query uses legacy SQL or standard
	// SQL.
	Query string `json:"query,omitempty"`

	// SchemaUpdateOptions: Allows the schema of the destination table to be
	// updated as a side effect of the query job. Schema update options are
	// supported in two cases: when writeDisposition is WRITE_APPEND; when
	// writeDisposition is WRITE_TRUNCATE and the destination table is a
	// partition of a table, specified by partition decorators. For normal
	// tables, WRITE_TRUNCATE will always overwrite the schema. One or more
	// of the following values are specified: ALLOW_FIELD_ADDITION: allow
	// adding a nullable field to the schema. ALLOW_FIELD_RELAXATION: allow
	// relaxing a required field in the original schema to nullable.
	SchemaUpdateOptions []string `json:"schemaUpdateOptions,omitempty"`

	// TableDefinitions: [Optional] If querying an external data source
	// outside of BigQuery, describes the data format, location and other
	// properties of the data source. By defining these properties, the data
	// source can then be queried as if it were a standard BigQuery table.
	TableDefinitions map[string]bigquery.ExternalDataConfiguration `json:"tableDefinitions,omitempty"`

	// TimePartitioning: Time-based partitioning specification for the
	// destination table. Only one of timePartitioning and rangePartitioning
	// should be specified.
	TimePartitioning *bigquery.TimePartitioning `json:"timePartitioning,omitempty"`

	// UseLegacySQL: Specifies whether to use BigQuery's legacy SQL dialect
	// for this query. The default value is true. If set to false, the query
	// will use BigQuery's standard SQL:
	// https://cloud.google.com/bigquery/sql-reference/ When useLegacySql is
	// set to false, the value of flattenResults is ignored; query will be
	// run as if flattenResults is false.
	//
	// Default: true
	UseLegacySQL *bool `json:"useLegacySql,omitempty"`

	// UseQueryCache: [Optional] Whether to look for the result in the query
	// cache. The query cache is a best-effort cache that will be flushed
	// whenever tables in the query are modified. Moreover, the query cache
	// is only available when a query does not have a destination table
	// specified. The default value is true.
	//
	// Default: true
	UseQueryCache *bool `json:"useQueryCache,omitempty"`

	// UserDefinedFunctionResources: Describes user-defined function
	// resources used in the query.
	UserDefinedFunctionResources []*bigquery.UserDefinedFunctionResource `json:"userDefinedFunctionResources,omitempty"`

	// WriteDisposition: [Optional] Specifies the action that occurs if the
	// destination table already exists. The following values are supported:
	// WRITE_TRUNCATE: If the table already exists, BigQuery overwrites the
	// table data and uses the schema from the query result. WRITE_APPEND:
	// If the table already exists, BigQuery appends the data to the table.
	// WRITE_EMPTY: If the table already exists and contains data, a
	// 'duplicate' error is returned in the job result. The default value is
	// WRITE_EMPTY. Each action is atomic and only occurs if BigQuery is
	// able to complete the job successfully. Creation, truncation and
	// append actions occur as one atomic update upon job completion.
	WriteDisposition string `json:"writeDisposition,omitempty"`
}

// unmarshalQueryJobConfig decodes the task's custom proto struct into a
// QueryJobConfig.
func unmarshalQueryJobConfig(structObj *structpb.Struct) (*QueryJobConfig, error) {
	queryJobConfig := QueryJobConfig{}
	err := pluginUtils.UnmarshalStructToObj(structObj, &queryJobConfig)

	if err != nil {
		return nil, errors.Wrapf(err, "failed to unmarshal QueryJobConfig")
	}

	return &queryJobConfig, nil
}

// getJobConfigurationQuery combines the task config with query parameters
// derived from the task's literal inputs. Parameters are always NAMED.
func getJobConfigurationQuery(custom *QueryJobConfig, inputs *flyteIdlCore.LiteralMap) (*bigquery.JobConfigurationQuery, error) {
	queryParameters, err := getQueryParameters(inputs.Literals)

	if err != nil {
		return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "unable build query parameters [%v]", err.Error())
	}

	return &bigquery.JobConfigurationQuery{
		AllowLargeResults:                  custom.AllowLargeResults,
		Clustering:                         custom.Clustering,
		CreateDisposition:                  custom.CreateDisposition,
		DefaultDataset:                     custom.DefaultDataset,
		DestinationEncryptionConfiguration: custom.DestinationEncryptionConfiguration,
		DestinationTable:                   custom.DestinationTable,
		FlattenResults:                     custom.FlattenResults,
		MaximumBillingTier:                 custom.MaximumBillingTier,
		MaximumBytesBilled:                 custom.MaximumBytesBilled,
		ParameterMode:                      "NAMED",
		Priority:                           custom.Priority,
		Query:                              custom.Query,
		QueryParameters:                    queryParameters,
		SchemaUpdateOptions:                custom.SchemaUpdateOptions,
		TableDefinitions:                   custom.TableDefinitions,
		TimePartitioning:                   custom.TimePartitioning,
		UseLegacySql:                       custom.UseLegacySQL,
		UseQueryCache:                      custom.UseQueryCache,
		UserDefinedFunctionResources:       custom.UserDefinedFunctionResources,
		WriteDisposition:                   custom.WriteDisposition,
	}, nil
}

// getQueryParameters converts each input literal into a named BigQuery query
// parameter; iteration order over the map is unspecified.
func getQueryParameters(literalMap map[string]*flyteIdlCore.Literal) ([]*bigquery.QueryParameter, error) {
	queryParameters := make([]*bigquery.QueryParameter, len(literalMap))

	i := 0
	for name, literal := range literalMap {
		parameterType, parameterValue, err := getQueryParameter(literal)

		if err != nil {
			return nil, err
		}

		queryParameters[i] = &bigquery.QueryParameter{
			Name:           name,
			ParameterType:  parameterType,
			ParameterValue: parameterValue,
		}

		i++
	}

	return queryParameters, nil
}

// read more about parameterized queries: https://cloud.google.com/bigquery/docs/parameterized-queries

// getQueryParameter maps a primitive Flyte literal (integer, string, float,
// boolean) to a BigQuery parameter type/value pair. Any other literal shape
// is rejected with BadTaskSpecification.
func getQueryParameter(literal *flyteIdlCore.Literal) (*bigquery.QueryParameterType, *bigquery.QueryParameterValue, error) {
	if scalar := literal.GetScalar(); scalar != nil {
		if primitive := scalar.GetPrimitive(); primitive != nil {
			switch primitive.Value.(type) {
			case *flyteIdlCore.Primitive_Integer:
				integerType := bigquery.QueryParameterType{Type: "INT64"}
				integerValue := bigquery.QueryParameterValue{
					Value: strconv.FormatInt(primitive.GetInteger(), 10),
				}

				return &integerType, &integerValue, nil

			case *flyteIdlCore.Primitive_StringValue:
				stringType := bigquery.QueryParameterType{Type: "STRING"}
				stringValue := bigquery.QueryParameterValue{
					Value: primitive.GetStringValue(),
				}

				return &stringType, &stringValue, nil

			case *flyteIdlCore.Primitive_FloatValue:
				floatType := bigquery.QueryParameterType{Type: "FLOAT64"}
				floatValue := bigquery.QueryParameterValue{
					Value: strconv.FormatFloat(primitive.GetFloatValue(), 'f', -1, 64),
				}

				return &floatType, &floatValue, nil

			case *flyteIdlCore.Primitive_Boolean:
				boolType := bigquery.QueryParameterType{Type: "BOOL"}

				if primitive.GetBoolean() {
					return &boolType, &bigquery.QueryParameterValue{
						Value: "TRUE",
					}, nil
				}

				return &boolType, &bigquery.QueryParameterValue{
					Value: "FALSE",
				}, nil
			}
		}
	}

	return nil, nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "unsupported literal [%v]", literal)
}
diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/query_job_test.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/query_job_test.go
new file mode 100644
index 0000000000..8df93268ad
--- /dev/null
+++
b/flyteplugins/go/tasks/plugins/webapi/bigquery/query_job_test.go @@ -0,0 +1,83 @@ +package bigquery + +import ( + "testing" + + "github.com/flyteorg/flyteidl/clients/go/coreutils" + + "github.com/stretchr/testify/assert" + "google.golang.org/api/bigquery/v2" +) + +func TestGetQueryParameter(t *testing.T) { + t.Run("get integer parameter", func(t *testing.T) { + literal, _ := coreutils.MakePrimitiveLiteral(42) + + tpe, value, err := getQueryParameter(literal) + + assert.NoError(t, err) + assert.Equal(t, bigquery.QueryParameterType{Type: "INT64"}, *tpe) + assert.Equal(t, bigquery.QueryParameterValue{Value: "42"}, *value) + }) + + t.Run("get string parameter", func(t *testing.T) { + literal, _ := coreutils.MakePrimitiveLiteral("abc") + + tpe, value, err := getQueryParameter(literal) + + assert.NoError(t, err) + assert.Equal(t, bigquery.QueryParameterType{Type: "STRING"}, *tpe) + assert.Equal(t, bigquery.QueryParameterValue{Value: "abc"}, *value) + }) + + t.Run("get float parameter", func(t *testing.T) { + literal, _ := coreutils.MakePrimitiveLiteral(42.5) + + tpe, value, err := getQueryParameter(literal) + + assert.NoError(t, err) + assert.Equal(t, bigquery.QueryParameterType{Type: "FLOAT64"}, *tpe) + assert.Equal(t, bigquery.QueryParameterValue{Value: "42.5"}, *value) + }) + + t.Run("get true parameter", func(t *testing.T) { + literal, _ := coreutils.MakePrimitiveLiteral(true) + + tpe, value, err := getQueryParameter(literal) + + assert.NoError(t, err) + assert.Equal(t, bigquery.QueryParameterType{Type: "BOOL"}, *tpe) + assert.Equal(t, bigquery.QueryParameterValue{Value: "TRUE"}, *value) + }) + + t.Run("get false parameter", func(t *testing.T) { + literal, _ := coreutils.MakePrimitiveLiteral(false) + + tpe, value, err := getQueryParameter(literal) + + assert.NoError(t, err) + assert.Equal(t, bigquery.QueryParameterType{Type: "BOOL"}, *tpe) + assert.Equal(t, bigquery.QueryParameterValue{Value: "FALSE"}, *value) + }) +} + +func TestGetJobConfigurationQuery(t 
*testing.T) { + t.Run("get job configuration query", func(t *testing.T) { + config := QueryJobConfig{} + inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{ + "integer": 42, + }) + + jobConfigurationQuery, err := getJobConfigurationQuery(&config, inputs) + + assert.NoError(t, err) + assert.Equal(t, "NAMED", jobConfigurationQuery.ParameterMode) + + assert.Equal(t, 1, len(jobConfigurationQuery.QueryParameters)) + assert.Equal(t, bigquery.QueryParameter{ + Name: "integer", + ParameterType: &bigquery.QueryParameterType{Type: "INT64"}, + ParameterValue: &bigquery.QueryParameterValue{Value: "42"}, + }, *jobConfigurationQuery.QueryParameters[0]) + }) +} diff --git a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index b3eb1050da..529ddfeadc 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -65,7 +65,7 @@ func BuildTaskTemplate() *idlCore.TaskTemplate { func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *idlCore.TaskTemplate, inputs *idlCore.LiteralMap, expectedOutputs *idlCore.LiteralMap, expectedFailure *idlCore.ExecutionError, - iterationUpdate func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error) { + iterationUpdate func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error) pluginCore.PhaseInfo { ctx := context.Background() @@ -158,6 +158,11 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i tMeta.OnGetOverrides().Return(overrides) tMeta.OnGetK8sServiceAccount().Return("s") tMeta.OnGetNamespace().Return("fake-development") + tMeta.OnGetSecurityContext().Return(idlCore.SecurityContext{ + RunAs: &idlCore.Identity{ + K8SServiceAccount: "s", + }, + }) tMeta.OnGetLabels().Return(map[string]string{}) tMeta.OnGetAnnotations().Return(map[string]string{}) tMeta.OnIsInterruptible().Return(true) @@ -266,4 +271,6 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i t.Errorf("Expected != Actual. 
Diff: %v", diff) } } + + return trns.Info() }