diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 5b7d47b2a6..31a1714ed7 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -1174,6 +1174,8 @@ github.com/sendgrid/sendgrid-go v3.10.0+incompatible/go.mod h1:QRQt+LX/NmgVEvmdR github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516/go.mod h1:Yow6lPLSAXx2ifx470yD/nUe22Dv5vBvxK/UK9UUTVs= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index f579049aff..6b55e8909e 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -22,6 +22,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.16.0 + github.com/shamaton/msgpack/v2 v2.2.2 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 07a92b902b..dc0ccb0464 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -374,6 +374,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= diff --git a/flytepropeller/pkg/compiler/validators/bindings_test.go b/flytepropeller/pkg/compiler/validators/bindings_test.go index 7e5b388391..bcb498eebd 100644 --- a/flytepropeller/pkg/compiler/validators/bindings_test.go +++ b/flytepropeller/pkg/compiler/validators/bindings_test.go @@ -776,9 +776,9 @@ func TestValidateBindings(t *testing.T) { _, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors) assert.False(t, ok) assert.Equal(t, "MismatchingTypes", string(compileErrors.Errors().List()[0].Code())) - assert.Equal(t, "Code: MismatchingTypes, Node Id: node1, Description: Variable [x]"+ - " (type [union_type:{variants:{simple:INTEGER structure:{tag:\"int\"}}}]) doesn't match expected type"+ - " [union_type:{variants:{simple:INTEGER structure:{tag:\"int_other\"}}}].", compileErrors.Errors().List()[0].Error()) + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "Code: MismatchingTypes, Node Id: node1, Description: Variable [x]") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "(type [union_type:{variants:{simple:INTEGER") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "doesn't match expected type") }) t.Run("List of Int to List of Unions Binding", func(t *testing.T) { @@ -1210,10 +1210,9 @@ func TestValidateBindings(t *testing.T) { _, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors) assert.False(t, ok) assert.Equal(t, "MismatchingTypes", string(compileErrors.Errors().List()[0].Code())) - assert.Equal(t, "Code: MismatchingTypes, Node Id: node1, Description: The output variable 'n2.n2_out'"+ - " has type [simple:INTEGER], but it's assigned to the input variable 'n.x' which has type"+ - " type [union_type:{variants:{simple:STRING structure:{tag:\"str\"}} variants:{simple:INTEGER structure:{tag:\"int1\"}}"+ - " variants:{simple:INTEGER structure:{tag:\"int2\"}}}].", compileErrors.Errors().List()[0].Error()) + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "Code: MismatchingTypes, Node Id: node1,") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "Description: The output variable 'n2.n2_out'") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "has type [simple:INTEGER], but it's assigned to the input variable 'n.x' which has type") }) t.Run("Union Promise Union Literal", func(t *testing.T) { diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index 5f41a6e65e..fb4ba04548 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -44,7 +44,14 @@ func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { literalType = &core.LiteralType{Type: &core.LiteralType_Blob{Blob: scalar.GetBlob().GetMetadata().GetType()}} case *core.Scalar_Binary: - literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + // If the binary has a tag, treat it as a structured type (e.g., dict, dataclass, Pydantic BaseModel). + // Otherwise, treat it as raw binary data. + // Reference: https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md + if len(v.Binary.Tag) > 0 { + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + } else { + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + } case *core.Scalar_Schema: literalType = &core.LiteralType{ Type: &core.LiteralType_Schema{ diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 4a37f100dc..26e34988c3 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" @@ -16,6 +17,82 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) }) + t.Run("binary idl with raw binary data and no tag", func(t *testing.T) { + // Some arbitrary raw binary data + rawBinaryData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + + lv := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: rawBinaryData, + Tag: "", + }, + }, + }, + }, + } + lt := LiteralTypeForLiteral(lv) + assert.Equal(t, core.SimpleType_BINARY.String(), lt.GetSimple().String()) + }) + + t.Run("binary idl with messagepack input map[int]strings", func(t *testing.T) { + // Create a map[int]string and serialize it using MessagePack. + data := map[int]string{ + 1: "hello", + 2: "world", + -1: "foo", + } + // Serializing the map using MessagePack + serializedBinaryData, err := msgpack.Marshal(data) + if err != nil { + t.Fatalf("failed to serialize map: %v", err) + } + lv := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: serializedBinaryData, + Tag: "msgpack", + }, + }, + }, + }, + } + lt := LiteralTypeForLiteral(lv) + assert.Equal(t, core.SimpleType_STRUCT.String(), lt.GetSimple().String()) + }) + + t.Run("binary idl with messagepack input map[float]strings", func(t *testing.T) { + // Create a map[float]string and serialize it using MessagePack. + data := map[float64]string{ + 1.0: "hello", + 5.0: "world", + -1.0: "foo", + } + // Serializing the map using MessagePack + serializedBinaryData, err := msgpack.Marshal(data) + if err != nil { + t.Fatalf("failed to serialize map: %v", err) + } + lv := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: serializedBinaryData, + Tag: "msgpack", + }, + }, + }, + }, + } + lt := LiteralTypeForLiteral(lv) + assert.Equal(t, core.SimpleType_STRUCT.String(), lt.GetSimple().String()) + }) + t.Run("homogeneous", func(t *testing.T) { lt := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), diff --git a/go.sum b/go.sum index 68eebb1fde..63453bbd87 100644 --- a/go.sum +++ b/go.sum @@ -1210,6 +1210,8 @@ github.com/sendgrid/sendgrid-go v3.10.0+incompatible/go.mod h1:QRQt+LX/NmgVEvmdR github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516/go.mod h1:Yow6lPLSAXx2ifx470yD/nUe22Dv5vBvxK/UK9UUTVs= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ=