Skip to content

Commit

Permalink
Use jsonpb AllowUnknownFields everywhere
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Jun 26, 2024
1 parent ce5eb03 commit c996dcb
Show file tree
Hide file tree
Showing 23 changed files with 162 additions and 193 deletions.
4 changes: 2 additions & 2 deletions flyteadmin/pkg/manager/impl/util/digests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (
"path/filepath"
"testing"

"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/ptypes/duration"
_struct "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytestdlib/utils"
)

var testLaunchPlanDigest = []byte{
Expand Down Expand Up @@ -92,7 +92,7 @@ func getCompiledWorkflow() (*core.CompiledWorkflowClosure, error) {
if err != nil {
return nil, err
}
err = jsonpb.UnmarshalString(string(workflowJSON), &compiledWorkflow)
err = utils.UnmarshalBytesToPb(workflowJSON, &compiledWorkflow)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/event"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/storage"
"github.com/flyteorg/flyte/flytestdlib/utils"
)

var taskEventOccurredAt = time.Now().UTC()
Expand Down Expand Up @@ -63,7 +64,7 @@ func transformMapToStructPB(t *testing.T, thing map[string]string) *structpb.Str
}

thingAsCustom := &structpb.Struct{}
if err := jsonpb.UnmarshalString(string(b), thingAsCustom); err != nil {
if err := utils.UnmarshalBytesToPb(b, thingAsCustom); err != nil {
t.Fatal(t, err)
}
return thingAsCustom
Expand Down
22 changes: 6 additions & 16 deletions flytectl/cmd/get/node_execution.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package get

import (
"bytes"
"context"
"fmt"
"sort"
"strconv"

"github.com/disiqueira/gotree"

cmdCore "github.com/flyteorg/flyte/flytectl/cmd/core"
"github.com/flyteorg/flyte/flytectl/pkg/printer"
"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/golang/protobuf/jsonpb"
"github.com/flyteorg/flyte/flytestdlib/utils"
)

var nodeExecutionColumns = []printer.Column{
Expand Down Expand Up @@ -50,18 +50,13 @@ type TaskExecution struct {

// MarshalJSON overridden method to json marshalling to use jsonpb
func (in *TaskExecution) MarshalJSON() ([]byte, error) {
var buf bytes.Buffer
marshaller := jsonpb.Marshaler{}
if err := marshaller.Marshal(&buf, in.TaskExecution); err != nil {
return nil, err
}
return buf.Bytes(), nil
return utils.MarshalPbToBytes(in.TaskExecution)
}

// UnmarshalJSON overridden method to json unmarshalling to use jsonpb
func (in *TaskExecution) UnmarshalJSON(b []byte) error {
in.TaskExecution = &admin.TaskExecution{}
return jsonpb.Unmarshal(bytes.NewReader(b), in.TaskExecution)
return utils.UnmarshalBytesToPb(b, in.TaskExecution)

Check warning on line 59 in flytectl/cmd/get/node_execution.go

View check run for this annotation

Codecov / codecov/patch

flytectl/cmd/get/node_execution.go#L59

Added line #L59 was not covered by tests
}

type NodeExecution struct {
Expand All @@ -70,18 +65,13 @@ type NodeExecution struct {

// MarshalJSON overridden method to json marshalling to use jsonpb
func (in *NodeExecution) MarshalJSON() ([]byte, error) {
var buf bytes.Buffer
marshaller := jsonpb.Marshaler{}
if err := marshaller.Marshal(&buf, in.NodeExecution); err != nil {
return nil, err
}
return buf.Bytes(), nil
return utils.MarshalPbToBytes(in.NodeExecution)
}

// UnmarshalJSON overridden method to json unmarshalling to use jsonpb
func (in *NodeExecution) UnmarshalJSON(b []byte) error {
*in = NodeExecution{}
return jsonpb.Unmarshal(bytes.NewReader(b), in)
return utils.UnmarshalBytesToPb(b, in.NodeExecution)

Check warning on line 74 in flytectl/cmd/get/node_execution.go

View check run for this annotation

Codecov / codecov/patch

flytectl/cmd/get/node_execution.go#L74

Added line #L74 was not covered by tests
}

// NodeExecutionClosure forms a wrapper around admin.NodeExecution and also fetches the childnodes , task execs
Expand Down
11 changes: 4 additions & 7 deletions flytectl/pkg/visualize/graphviz_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package visualize

import (
"bytes"
"fmt"
"io/ioutil"
"testing"

graphviz "github.com/awalterschulze/gographviz"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyte/flytectl/pkg/visualize/mocks"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytestdlib/utils"
"github.com/golang/protobuf/jsonpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestRenderWorkflowBranch(t *testing.T) {
Expand All @@ -24,10 +23,8 @@ func TestRenderWorkflowBranch(t *testing.T) {
r, err := ioutil.ReadFile(fmt.Sprintf("testdata/%s.json", s))
assert.NoError(t, err)

i := bytes.NewReader(r)

c := &core.CompiledWorkflowClosure{}
err = jsonpb.Unmarshal(i, c)
err = utils.UnmarshalBytesToPb(r, c)
assert.NoError(t, err)
b, err := RenderWorkflow(c)
fmt.Println(b)
Expand Down
3 changes: 2 additions & 1 deletion flyteidl/clients/go/coreutils/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error
switch t {
case core.SimpleType_STRUCT:
st := &structpb.Struct{}
err := jsonpb.UnmarshalString(s, st)
unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true}
err := unmarshaler.Unmarshal(strings.NewReader(s), st)
if err != nil {
return nil, errors.Wrapf(err, "failed to load generic type as json.")
}
Expand Down
6 changes: 5 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var jsonPbUnmarshaler = &jsonpb.Unmarshaler{
AllowUnknownFields: true,
}

// Deprecated: Use flytestdlib/utils.UnmarshalStructToPb instead.
func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error {
if structObj == nil {
return fmt.Errorf("nil Struct Object passed")
Expand All @@ -32,6 +33,7 @@ func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error {
return nil
}

// Deprecated: Use flytestdlib/utils.MarshalPbToStruct instead.
func MarshalStruct(in proto.Message, out *structpb.Struct) error {
if out == nil {
return fmt.Errorf("nil Struct Object passed")
Expand All @@ -49,11 +51,12 @@ func MarshalStruct(in proto.Message, out *structpb.Struct) error {
return nil
}

// Deprecated: Use flytestdlib/utils.MarshalToString instead.
func MarshalToString(msg proto.Message) (string, error) {
return jsonPbMarshaler.MarshalToString(msg)
}

// TODO: Use the stdlib version in the future, or move there if not there.
// Deprecated: Use flytestdlib/utils.MarshalObjToStruct instead.
// Don't use this if input is a proto Message.
func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) {
b, err := json.Marshal(input)
Expand All @@ -69,6 +72,7 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) {
return structObj, nil
}

// Deprecated: Use flytestdlib/utils.UnmarshalStructToObj instead.
// Don't use this if the unmarshalled obj is a proto message.
func UnmarshalStructToObj(structObj *structpb.Struct, obj interface{}) error {
if structObj == nil {
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"time"

daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1"
"github.com/golang/protobuf/jsonpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/types/known/structpb"
Expand All @@ -25,6 +24,7 @@ import (
pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)

const (
Expand Down Expand Up @@ -122,7 +122,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem
}

structObj := structpb.Struct{}
err = jsonpb.UnmarshalString(daskJobJSON, &structObj)
err = stdlibUtils.UnmarshalStringToPb(daskJobJSON, &structObj)
if err != nil {
panic(err)
}
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"
"time"

"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
mpiOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -29,6 +28,7 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)

const testImage = "image://"
Expand Down Expand Up @@ -99,7 +99,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate {

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(mpiObjJSON, &structObj)
err = stdlibUtils.UnmarshalStringToPb(mpiObjJSON, &structObj)
if err != nil {
panic(err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"
"time"

"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -29,6 +28,7 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)

const testImage = "image://"
Expand Down Expand Up @@ -105,7 +105,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(ptObjJSON, &structObj)
err = stdlibUtils.UnmarshalStringToPb(ptObjJSON, &structObj)
if err != nil {
panic(err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"
"time"

"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -29,6 +28,7 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)

const testImage = "image://"
Expand Down Expand Up @@ -100,7 +100,7 @@ func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTempl

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(tfObjJSON, &structObj)
err = stdlibUtils.UnmarshalStringToPb(tfObjJSON, &structObj)
if err != nil {
panic(err)
}
Expand Down
6 changes: 3 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -26,6 +25,7 @@ import (
pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)

const sparkMainClass = "MainClass"
Expand Down Expand Up @@ -318,7 +318,7 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(sparkJobJSON, &structObj)
err = stdlibUtils.UnmarshalStringToPb(sparkJobJSON, &structObj)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -346,7 +346,7 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(sparkJobJSON, &structObj)
err = stdlibUtils.UnmarshalStringToPb(sparkJobJSON, &structObj)
if err != nil {
panic(err)
}
Expand Down
14 changes: 3 additions & 11 deletions flytepropeller/cmd/kubectl-flyte/cmd/create.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package cmd

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"

"github.com/ghodss/yaml"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/pkg/errors"
"github.com/spf13/cobra"
Expand All @@ -20,6 +18,7 @@ import (
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common"
compilerErrors "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/errors"
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/transformers/k8s"
"github.com/flyteorg/flyte/flytestdlib/utils"
)

const (
Expand Down Expand Up @@ -89,7 +88,7 @@ func unmarshal(in []byte, format format, message proto.Message) (err error) {
case formatProto:
err = proto.Unmarshal(in, message)
case formatJSON:
err = jsonpb.Unmarshal(bytes.NewReader(in), message)
err = utils.UnmarshalBytesToPb(in, message)
if err != nil {
err = errors.Wrapf(err, "Failed to unmarshal converted Json. [%v]", string(in))
}
Expand All @@ -105,19 +104,12 @@ func unmarshal(in []byte, format format, message proto.Message) (err error) {
return
}

var jsonPbMarshaler = jsonpb.Marshaler{}

func marshal(message proto.Message, format format) (raw []byte, err error) {
switch format {
case formatProto:
return proto.Marshal(message)
case formatJSON:
b := &bytes.Buffer{}
err := jsonPbMarshaler.Marshal(b, message)
if err != nil {
return nil, errors.Wrapf(err, "Failed to marshal Json.")
}
return b.Bytes(), nil
return utils.MarshalPbToBytes(message)
case formatYaml:
b, err := marshal(message, formatJSON)
if err != nil {
Expand Down
Loading

0 comments on commit c996dcb

Please sign in to comment.