diff --git a/go.mod b/go.mod index 1af9c5eb2..9b3b17a87 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.0.0 github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v0.21.11 + github.com/flyteorg/flyteidl v0.21.23 github.com/flyteorg/flytestdlib v0.4.7 github.com/go-logr/zapr v0.4.0 // indirect github.com/go-test/deep v1.0.7 diff --git a/go.sum b/go.sum index 56cb7daa7..0c1ba83b6 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v0.21.11 h1:oH9YPoR7scO9GFF/I8D0gCTOB+JP5HRK7b7cLUBRz90= -github.com/flyteorg/flyteidl v0.21.11/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteidl v0.21.23 h1:hzGIFNOt3VooW/NdnaicXijn3EKjNKTz1kY+tlHkED4= +github.com/flyteorg/flyteidl v0.21.23/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= github.com/flyteorg/flytestdlib v0.4.7 h1:SMPPXI3j/MjP7D2fqaR+lPQkTrqYS7xZbwsgJI2F8SU= github.com/flyteorg/flytestdlib v0.4.7/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs= diff --git a/go/tasks/plugins/webapi/bigquery/integration_test.go b/go/tasks/plugins/webapi/bigquery/integration_test.go index 18d287446..bec6f3c2f 100644 --- a/go/tasks/plugins/webapi/bigquery/integration_test.go +++ b/go/tasks/plugins/webapi/bigquery/integration_test.go @@ -75,7 +75,11 @@ func newFakeBigQueryServer() *httptest.Server { if strings.HasPrefix(request.URL.Path, "/projects/flyte/jobs/") && request.Method == "GET" { writer.WriteHeader(200) - job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}} + job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}, + Configuration: &bigquery.JobConfiguration{ + Query: &bigquery.JobConfigurationQuery{ + DestinationTable: &bigquery.TableReference{ + ProjectId: "project", DatasetId: "dataset", TableId: "table"}}}} bytes, _ := json.Marshal(job) _, _ = writer.Write(bytes) return diff --git a/go/tasks/plugins/webapi/bigquery/plugin.go b/go/tasks/plugins/webapi/bigquery/plugin.go index 5a9d7289b..8dd45b650 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/go/tasks/plugins/webapi/bigquery/plugin.go @@ -7,6 +7,8 @@ import ( "net/http" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "golang.org/x/oauth2" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" @@ -42,8 +44,9 @@ type Plugin struct { } type ResourceWrapper struct { - Status *bigquery.JobStatus - CreateError *googleapi.Error + Status *bigquery.JobStatus + CreateError *googleapi.Error + OutputLocation string } type ResourceMetaWrapper struct { @@ -211,8 +214,12 @@ func (p Plugin) getImpl(ctx context.Context, taskCtx webapi.GetContext) (wrapper return nil, err } + dst := job.Configuration.Query.DestinationTable + outputLocation := fmt.Sprintf("bq://%v:%v.%v", dst.ProjectId, dst.DatasetId, dst.TableId) + return &ResourceWrapper{ - Status: job.Status, + Status: job.Status, + OutputLocation: outputLocation, }, nil } @@ -244,7 +251,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } -func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { +func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { resourceMeta := tCtx.ResourceMeta().(*ResourceMetaWrapper) resource := tCtx.Resource().(*ResourceWrapper) version := pluginsCore.DefaultPhaseVersion @@ -273,13 +280,54 @@ func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core resource.Status.ErrorResult.Message, taskInfo), nil } - + err = writeOutput(ctx, tCtx, resource.OutputLocation) + if err != nil { + logger.Warnf(ctx, "Failed to write output, uri [%s], err %s", resource.OutputLocation, err.Error()) + return core.PhaseInfoUndefined, err + } return pluginsCore.PhaseInfoSuccess(taskInfo), nil } return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.Status.State) } +func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation string) error { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return err + } + + if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil { + logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.") + return nil + } + + resultsStructuredDatasetType, exists := taskTemplate.Interface.Outputs.Variables["results"] + if !exists { + logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.") + return nil + } + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "results": { + Value: &flyteIdlCore.Literal_Scalar{ + Scalar: &flyteIdlCore.Scalar{ + Value: &flyteIdlCore.Scalar_StructuredDataset{ + StructuredDataset: &flyteIdlCore.StructuredDataset{ + Uri: OutputLocation, + Metadata: &flyteIdlCore.StructuredDatasetMetadata{ + StructuredDatasetType: resultsStructuredDatasetType.GetType().GetStructuredDatasetType(), + }, + }, + }, + }, + }, + }, + }, + }, nil)) +} + func handleCreateError(createError *googleapi.Error, taskInfo *core.TaskInfo) core.PhaseInfo { code := fmt.Sprintf("http%d", createError.Code) diff --git a/go/tasks/plugins/webapi/bigquery/plugin_test.go b/go/tasks/plugins/webapi/bigquery/plugin_test.go index 0de1d30c5..d39d84418 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin_test.go +++ b/go/tasks/plugins/webapi/bigquery/plugin_test.go @@ -1,9 +1,21 @@ package bigquery import ( + "context" "testing" "time" + coreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + "github.com/stretchr/testify/mock" + "k8s.io/apimachinery/pkg/util/rand" + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -12,6 +24,10 @@ import ( "google.golang.org/api/googleapi" ) +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} + func TestFormatJobReference(t *testing.T) { t.Run("format job reference", func(t *testing.T) { jobReference := bigquery.JobReference{ @@ -46,6 +62,79 @@ func TestCreateTaskInfo(t *testing.T) { }) } +func TestOutputWriter(t *testing.T) { + ctx := context.Background() + statusContext := &mocks.StatusContext{} + + template := flyteIdlCore.TaskTemplate{} + tr := &coreMocks.TaskReader{} + tr.OnRead(ctx).Return(&template, nil) + statusContext.OnTaskReader().Return(tr) + + outputLocation := "bq://project:flyte.table" + err := writeOutput(ctx, statusContext, outputLocation) + assert.NoError(t, err) + + ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + outputWriter := &ioMocks.OutputWriter{} + outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + or := args.Get(1).(io.OutputReader) + literals, ee, err := or.Read(ctx) + assert.NoError(t, err) + + sd := literals.GetLiterals()["results"].GetScalar().GetStructuredDataset() + assert.Equal(t, sd.Uri, outputLocation) + assert.Equal(t, sd.Metadata.GetStructuredDatasetType().Columns[0].Name, "col1") + assert.Equal(t, sd.Metadata.GetStructuredDatasetType().Columns[0].LiteralType.GetSimple(), flyteIdlCore.SimpleType_INTEGER) + + if ee != nil { + assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetErrorPath(), storage.Options{}, ee)) + } + + if literals != nil { + assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetOutputPath(), storage.Options{}, literals)) + } + }) + + execID := rand.String(3) + basePrefix := storage.DataReference("fake://bucket/prefix/" + execID) + outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb") + statusContext.OnOutputWriter().Return(outputWriter) + + template = flyteIdlCore.TaskTemplate{ + Interface: &flyteIdlCore.TypedInterface{ + Outputs: &flyteIdlCore.VariableMap{ + Variables: map[string]*flyteIdlCore.Variable{ + "results": { + Type: &flyteIdlCore.LiteralType{ + Type: &flyteIdlCore.LiteralType_StructuredDatasetType{ + StructuredDatasetType: &flyteIdlCore.StructuredDatasetType{ + Columns: []*flyteIdlCore.StructuredDatasetType_DatasetColumn{ + { + Name: "col1", + LiteralType: &flyteIdlCore.LiteralType{ + Type: &flyteIdlCore.LiteralType_Simple{ + Simple: flyteIdlCore.SimpleType_INTEGER, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + tr.OnRead(ctx).Return(&template, nil) + statusContext.OnTaskReader().Return(tr) + err = writeOutput(ctx, statusContext, outputLocation) + assert.NoError(t, err) +} + func TestHandleCreateError(t *testing.T) { occurredAt := time.Now() taskInfo := core.TaskInfo{OccurredAt: &occurredAt}