Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support BQ plugin to write structured dataset as output (#233)
Browse files Browse the repository at this point in the history
* BQ output structured dataset

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Jan 21, 2022
1 parent f4478ac commit 590813b
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 9 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.0.0
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v0.21.11
github.com/flyteorg/flyteidl v0.21.23
github.com/flyteorg/flytestdlib v0.4.7
github.com/go-logr/zapr v0.4.0 // indirect
github.com/go-test/deep v1.0.7
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg=
github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/flyteorg/flyteidl v0.21.11 h1:oH9YPoR7scO9GFF/I8D0gCTOB+JP5HRK7b7cLUBRz90=
github.com/flyteorg/flyteidl v0.21.11/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteidl v0.21.23 h1:hzGIFNOt3VooW/NdnaicXijn3EKjNKTz1kY+tlHkED4=
github.com/flyteorg/flyteidl v0.21.23/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220=
github.com/flyteorg/flytestdlib v0.4.7 h1:SMPPXI3j/MjP7D2fqaR+lPQkTrqYS7xZbwsgJI2F8SU=
github.com/flyteorg/flytestdlib v0.4.7/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs=
Expand Down
6 changes: 5 additions & 1 deletion go/tasks/plugins/webapi/bigquery/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ func newFakeBigQueryServer() *httptest.Server {

if strings.HasPrefix(request.URL.Path, "/projects/flyte/jobs/") && request.Method == "GET" {
writer.WriteHeader(200)
job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}}
job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"},
Configuration: &bigquery.JobConfiguration{
Query: &bigquery.JobConfigurationQuery{
DestinationTable: &bigquery.TableReference{
ProjectId: "project", DatasetId: "dataset", TableId: "table"}}}}
bytes, _ := json.Marshal(job)
_, _ = writer.Write(bytes)
return
Expand Down
58 changes: 53 additions & 5 deletions go/tasks/plugins/webapi/bigquery/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/http"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"

"golang.org/x/oauth2"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
Expand Down Expand Up @@ -42,8 +44,9 @@ type Plugin struct {
}

type ResourceWrapper struct {
Status *bigquery.JobStatus
CreateError *googleapi.Error
Status *bigquery.JobStatus
CreateError *googleapi.Error
OutputLocation string
}

type ResourceMetaWrapper struct {
Expand Down Expand Up @@ -211,8 +214,12 @@ func (p Plugin) getImpl(ctx context.Context, taskCtx webapi.GetContext) (wrapper
return nil, err
}

dst := job.Configuration.Query.DestinationTable
outputLocation := fmt.Sprintf("bq://%v:%v.%v", dst.ProjectId, dst.DatasetId, dst.TableId)

return &ResourceWrapper{
Status: job.Status,
Status: job.Status,
OutputLocation: outputLocation,
}, nil
}

Expand Down Expand Up @@ -244,7 +251,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
return nil
}

func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) {
func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) {
resourceMeta := tCtx.ResourceMeta().(*ResourceMetaWrapper)
resource := tCtx.Resource().(*ResourceWrapper)
version := pluginsCore.DefaultPhaseVersion
Expand Down Expand Up @@ -273,13 +280,54 @@ func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core
resource.Status.ErrorResult.Message,
taskInfo), nil
}

err = writeOutput(ctx, tCtx, resource.OutputLocation)
if err != nil {
logger.Warnf(ctx, "Failed to write output, uri [%s], err %s", resource.OutputLocation, err.Error())
return core.PhaseInfoUndefined, err
}
return pluginsCore.PhaseInfoSuccess(taskInfo), nil
}

return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.Status.State)
}

// writeOutput publishes the BigQuery destination table as the task's
// "results" output.
//
// It reads the task template through tCtx.TaskReader(). When the task
// interface declares an output variable named "results", it writes a
// LiteralMap through tCtx.OutputWriter() holding a single StructuredDataset
// literal whose Uri is OutputLocation and whose metadata carries the
// structured dataset type declared for that variable.
//
// When the task declares no outputs, or none named "results", nothing is
// written and nil is returned. Otherwise the returned error is whatever the
// task reader or the output writer reported.
func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation string) error {
	taskTemplate, err := tCtx.TaskReader().Read(ctx)
	if err != nil {
		return err
	}

	// The generated protobuf getters tolerate nil receivers, so this single
	// chain covers a nil template, interface, output map or variable map.
	variables := taskTemplate.GetInterface().GetOutputs().GetVariables()
	if len(variables) == 0 {
		logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.")
		return nil
	}

	resultsVariable, exists := variables["results"]
	if !exists {
		// Outputs exist, just not the one this plugin knows how to fill in.
		logger.Infof(ctx, "The task declares no output named [results]. Skipping writing the outputs.")
		return nil
	}

	structuredDataset := &flyteIdlCore.StructuredDataset{
		Uri: OutputLocation,
		Metadata: &flyteIdlCore.StructuredDatasetMetadata{
			StructuredDatasetType: resultsVariable.GetType().GetStructuredDatasetType(),
		},
	}

	outputs := &flyteIdlCore.LiteralMap{
		Literals: map[string]*flyteIdlCore.Literal{
			"results": {
				Value: &flyteIdlCore.Literal_Scalar{
					Scalar: &flyteIdlCore.Scalar{
						Value: &flyteIdlCore.Scalar_StructuredDataset{
							StructuredDataset: structuredDataset,
						},
					},
				},
			},
		},
	}

	// The second argument is the execution error; a successful job has none.
	return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputs, nil))
}

func handleCreateError(createError *googleapi.Error, taskInfo *core.TaskInfo) core.PhaseInfo {
code := fmt.Sprintf("http%d", createError.Code)

Expand Down
89 changes: 89 additions & 0 deletions go/tasks/plugins/webapi/bigquery/plugin_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
package bigquery

import (
"context"
"testing"
"time"

coreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi/mocks"
"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"github.com/flyteorg/flytestdlib/storage"
"github.com/stretchr/testify/mock"
"k8s.io/apimachinery/pkg/util/rand"

flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -12,6 +24,10 @@ import (
"google.golang.org/api/googleapi"
)

// init registers contextutils.NamespaceKey as the metric key for labeled
// metrics before any test in this package runs.
// NOTE(review): presumably needed because code exercised by these tests
// creates labeled metrics and expects the keys to be set first — confirm.
func init() {
labeled.SetMetricKeys(contextutils.NamespaceKey)
}

func TestFormatJobReference(t *testing.T) {
t.Run("format job reference", func(t *testing.T) {
jobReference := bigquery.JobReference{
Expand Down Expand Up @@ -46,6 +62,79 @@ func TestCreateTaskInfo(t *testing.T) {
})
}

// TestOutputWriter exercises writeOutput in two phases:
//
//  1. With an empty task template (no declared outputs) it must return nil
//     without touching the output writer, which is not yet mocked.
//  2. With a template declaring a "results" structured dataset output it must
//     Put a literal map whose "results" entry carries the output location as
//     Uri and the declared column schema as metadata.
func TestOutputWriter(t *testing.T) {
	ctx := context.Background()
	statusContext := &mocks.StatusContext{}

	// The task reader hands out a pointer to template, so reassigning the
	// variable further down changes what subsequent reads observe.
	template := flyteIdlCore.TaskTemplate{}
	tr := &coreMocks.TaskReader{}
	tr.OnRead(ctx).Return(&template, nil)
	statusContext.OnTaskReader().Return(tr)

	// Phase 1: no outputs declared, nothing should be written.
	outputLocation := "bq://project:flyte.table"
	err := writeOutput(ctx, statusContext, outputLocation)
	assert.NoError(t, err)

	ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
	assert.NoError(t, err)

	outputWriter := &ioMocks.OutputWriter{}
	outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		or := args.Get(1).(io.OutputReader)
		literals, ee, err := or.Read(ctx)
		assert.NoError(t, err)

		// testify expects (expected, actual); keeping that order makes the
		// failure output read correctly.
		sd := literals.GetLiterals()["results"].GetScalar().GetStructuredDataset()
		if assert.NotNil(t, sd) {
			assert.Equal(t, outputLocation, sd.Uri)
			columns := sd.GetMetadata().GetStructuredDatasetType().GetColumns()
			if assert.Len(t, columns, 1) {
				assert.Equal(t, "col1", columns[0].Name)
				assert.Equal(t, flyteIdlCore.SimpleType_INTEGER, columns[0].LiteralType.GetSimple())
			}
		}

		if ee != nil {
			assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetErrorPath(), storage.Options{}, ee))
		}

		if literals != nil {
			assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetOutputPath(), storage.Options{}, literals))
		}
	})

	execID := rand.String(3)
	basePrefix := storage.DataReference("fake://bucket/prefix/" + execID)
	outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb")
	statusContext.OnOutputWriter().Return(outputWriter)

	// Phase 2: declare a "results" output with a single integer column.
	template = flyteIdlCore.TaskTemplate{
		Interface: &flyteIdlCore.TypedInterface{
			Outputs: &flyteIdlCore.VariableMap{
				Variables: map[string]*flyteIdlCore.Variable{
					"results": {
						Type: &flyteIdlCore.LiteralType{
							Type: &flyteIdlCore.LiteralType_StructuredDatasetType{
								StructuredDatasetType: &flyteIdlCore.StructuredDatasetType{
									Columns: []*flyteIdlCore.StructuredDatasetType_DatasetColumn{
										{
											Name: "col1",
											LiteralType: &flyteIdlCore.LiteralType{
												Type: &flyteIdlCore.LiteralType_Simple{
													Simple: flyteIdlCore.SimpleType_INTEGER,
												},
											},
										},
									},
								},
							},
						},
					},
				},
			},
		},
	}
	tr.OnRead(ctx).Return(&template, nil)
	statusContext.OnTaskReader().Return(tr)
	err = writeOutput(ctx, statusContext, outputLocation)
	assert.NoError(t, err)
}

func TestHandleCreateError(t *testing.T) {
occurredAt := time.Now()
taskInfo := core.TaskInfo{OccurredAt: &occurredAt}
Expand Down

0 comments on commit 590813b

Please sign in to comment.