diff --git a/go/tasks/plugins/webapi/bigquery/integration_test.go b/go/tasks/plugins/webapi/bigquery/integration_test.go index 187150633..18d287446 100644 --- a/go/tasks/plugins/webapi/bigquery/integration_test.go +++ b/go/tasks/plugins/webapi/bigquery/integration_test.go @@ -47,7 +47,6 @@ func TestEndToEnd(t *testing.T) { t.Run("SELECT 1", func(t *testing.T) { queryJobConfig := QueryJobConfig{ ProjectID: "flyte", - Query: "SELECT 1", } inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) @@ -55,6 +54,7 @@ func TestEndToEnd(t *testing.T) { template := flyteIdlCore.TaskTemplate{ Type: bigqueryQueryJobTask, Custom: custom, + Target: &flyteIdlCore.TaskTemplate_Sql{Sql: &flyteIdlCore.Sql{Statement: "SELECT 1", Dialect: flyteIdlCore.Sql_ANSI}}, } phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) diff --git a/go/tasks/plugins/webapi/bigquery/plugin.go b/go/tasks/plugins/webapi/bigquery/plugin.go index 7a8b21d28..5a9d7289b 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/go/tasks/plugins/webapi/bigquery/plugin.go @@ -7,10 +7,10 @@ import ( "net/http" "time" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "golang.org/x/oauth2" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/google" structpb "github.com/golang/protobuf/ptypes/struct" "google.golang.org/api/bigquery/v2" @@ -105,6 +105,7 @@ func (p Plugin) createImpl(ctx context.Context, taskCtx webapi.TaskExecutionCont return nil, nil, err } + job.Configuration.Query.Query = taskTemplate.GetSql().Statement job.Configuration.Labels = taskCtx.TaskExecutionMetadata().GetLabels() resp, err := client.Jobs.Insert(job.JobReference.ProjectId, job).Do() @@ -456,7 +457,8 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity) options = append(options, option.WithEndpoint(p.cfg.bigQueryEndpoint), option.WithTokenSource(oauth2.StaticTokenSource(&oauth2.Token{}))) - } else { + } else if p.cfg.GoogleTokenSource.Type != "default" { + tokenSource, err := p.googleTokenSource.GetTokenSource(ctx, identity) if err != nil { @@ -464,6 +466,8 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity) } options = append(options, option.WithTokenSource(tokenSource)) + } else { + logger.Infof(ctx, "BigQuery client read $GOOGLE_APPLICATION_CREDENTIALS by default") } return bigquery.NewService(ctx, options...) diff --git a/go/tasks/plugins/webapi/bigquery/query_job.go b/go/tasks/plugins/webapi/bigquery/query_job.go index 9c1fec1cd..ccd3610f0 100644 --- a/go/tasks/plugins/webapi/bigquery/query_job.go +++ b/go/tasks/plugins/webapi/bigquery/query_job.go @@ -161,6 +161,9 @@ func getJobConfigurationQuery(custom *QueryJobConfig, inputs *flyteIdlCore.Liter return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "unable build query parameters [%v]", err.Error()) } + // BigQuery supports query parameters to help prevent SQL injection when queries are constructed using user input. + // This feature is only available with standard SQL syntax. For more detail: https://cloud.google.com/bigquery/docs/parameterized-queries + useLegacySQL := false return &bigquery.JobConfigurationQuery{ AllowLargeResults: custom.AllowLargeResults, Clustering: custom.Clustering, @@ -178,7 +181,7 @@ func getJobConfigurationQuery(custom *QueryJobConfig, inputs *flyteIdlCore.Liter SchemaUpdateOptions: custom.SchemaUpdateOptions, TableDefinitions: custom.TableDefinitions, TimePartitioning: custom.TimePartitioning, - UseLegacySql: custom.UseLegacySQL, + UseLegacySql: &useLegacySQL, UseQueryCache: custom.UseQueryCache, UserDefinedFunctionResources: custom.UserDefinedFunctionResources, WriteDisposition: custom.WriteDisposition, diff --git a/go/tasks/plugins/webapi/bigquery/query_job_test.go b/go/tasks/plugins/webapi/bigquery/query_job_test.go index 8df93268a..b53c83840 100644 --- a/go/tasks/plugins/webapi/bigquery/query_job_test.go +++ b/go/tasks/plugins/webapi/bigquery/query_job_test.go @@ -69,10 +69,11 @@ func TestGetJobConfigurationQuery(t *testing.T) { }) jobConfigurationQuery, err := getJobConfigurationQuery(&config, inputs) + useLegacySQL := false assert.NoError(t, err) assert.Equal(t, "NAMED", jobConfigurationQuery.ParameterMode) - + assert.Equal(t, &useLegacySQL, jobConfigurationQuery.UseLegacySql) assert.Equal(t, 1, len(jobConfigurationQuery.QueryParameters)) assert.Equal(t, bigquery.QueryParameter{ Name: "integer",