From 182f5e1918ee7dc52fdfc3e6d3fa0bb76484f729 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 30 Oct 2024 20:44:48 +0800 Subject: [PATCH 1/2] Binary IDL Attribute Access for Map Task Signed-off-by: Future-Outlier --- .../controller/nodes/attr_path_resolver.go | 28 ++- .../nodes/attr_path_resolver_test.go | 223 +++++++----------- 2 files changed, 103 insertions(+), 148 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index 192fa1956c..222987ffa1 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -46,8 +46,7 @@ func resolveAttrPathInPromise(ctx context.Context, datastore *storage.DataStore, } currVal = currVal.GetCollection().GetLiterals()[attr.GetIntValue()] index++ - // scalar is always the leaf, so we can break here - case *core.Literal_Scalar: + default: break } } @@ -107,9 +106,7 @@ func resolveAttrPathInPbStruct(nodeID string, st *structpb.Struct, bindAttrPath } // resolveAttrPathInBinary resolves the binary idl object (e.g. dataclass, pydantic basemodel) with attribute path -func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath []*core.PromiseAttribute) (*core. - Literal, - error) { +func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) { binaryBytes := binaryIDL.GetValue() serializationFormat := binaryIDL.GetTag() @@ -165,6 +162,27 @@ func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath } } + // if currVal is list, convert it to literal collection + // This is for map task handling + if collection, ok := currVal.([]any); ok { + literals := make([]*core.Literal, len(collection)) + for i, v := range collection { + resolvedBinaryBytes, err := msgpack.Marshal(v) + if err != nil { + return nil, err + } + literals[i] = constructResolvedBinary(resolvedBinaryBytes, serializationFormat) + } + + return &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literals, + }, + }, + }, nil + } + // Marshal the current value to MessagePack bytes resolvedBinaryBytes, err := msgpack.Marshal(currVal) if err != nil { diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go index 1467fc0ea4..e8e28ac08f 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go @@ -10,13 +10,14 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/structpb" + "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" ) -// FlyteFile and FlyteDirectory represented as map[interface{}]interface{} -type FlyteFile map[interface{}]interface{} -type FlyteDirectory map[interface{}]interface{} +// FlyteFile and FlyteDirectory represented as map[any]any +type FlyteFile map[any]any +type FlyteDirectory map[any]any // InnerDC struct (equivalent to InnerDC dataclass in Python) type InnerDC struct { @@ -73,7 +74,7 @@ func NewScalarLiteral(value string) *core.Literal { } } -func NewStructFromMap(m map[string]interface{}) *structpb.Struct { +func NewStructFromMap(m map[string]any) *structpb.Struct { st, _ := structpb.NewStruct(m) return st } @@ -135,7 +136,7 @@ func TestResolveAttrPathInStruct(t *testing.T) { Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ - Generic: NewStructFromMap(map[string]interface{}{"foo": "bar"}), + Generic: NewStructFromMap(map[string]any{"foo": "bar"}), }, }, }, @@ -157,8 +158,8 @@ func TestResolveAttrPathInStruct(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ Generic: NewStructFromMap( - map[string]interface{}{ - "foo": []interface{}{"bar1", "bar2"}, + map[string]any{ + "foo": []any{"bar1", "bar2"}, }, ), }, @@ -187,8 +188,8 @@ func TestResolveAttrPathInStruct(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ Generic: NewStructFromMap( - map[string]interface{}{ - "foo": []interface{}{[]interface{}{"bar1", "bar2"}}, + map[string]any{ + "foo": []any{[]any{"bar1", "bar2"}}, }, ), }, @@ -236,7 +237,7 @@ func TestResolveAttrPathInStruct(t *testing.T) { Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ - Generic: NewStructFromMap(map[string]interface{}{"bar": "car"}), + Generic: NewStructFromMap(map[string]any{"bar": "car"}), }, }, }, @@ -276,9 +277,9 @@ func TestResolveAttrPathInStruct(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ Generic: NewStructFromMap( - map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": map[string]interface{}{ + map[string]any{ + "foo": map[string]any{ + "bar": map[string]any{ "baz": 42, }, }, @@ -306,7 +307,7 @@ func TestResolveAttrPathInStruct(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ Generic: NewStructFromMap( - map[string]interface{}{ + map[string]any{ "baz": 42, }, ), @@ -365,7 +366,7 @@ func TestResolveAttrPathInStruct(t *testing.T) { Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ - Generic: NewStructFromMap(map[string]interface{}{"foo": "bar"}), + Generic: NewStructFromMap(map[string]any{"foo": "bar"}), }, }, }, @@ -387,8 +388,8 @@ func TestResolveAttrPathInStruct(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ Generic: NewStructFromMap( - map[string]interface{}{ - "foo": []interface{}{"bar1", "bar2"}, + map[string]any{ + "foo": []any{"bar1", "bar2"}, }, ), }, @@ -495,11 +496,35 @@ func createNestedDC() DC { func TestResolveAttrPathInBinary(t *testing.T) { // Helper function to convert a map to msgpack bytes and then to BinaryIDL - toMsgpackBytes := func(m interface{}) []byte { + toMsgpackBytes := func(m any) []byte { msgpackBytes, err := msgpack.Marshal(m) assert.NoError(t, err) return msgpackBytes } + toLiteralCollectionWithMsgpackBytes := func(collection []any) *core.Literal { + literals := make([]*core.Literal, len(collection)) + for i, v := range collection { + resolvedBinaryBytes, _ := msgpack.Marshal(v) + literals[i] = constructResolvedBinary(resolvedBinaryBytes, coreutils.MESSAGEPACK) + } + return &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literals, + }, + }, + } + } + fromLiteralCollectionWithMsgpackBytes := func(lv *core.Literal) []any { + literals := lv.GetCollection().GetLiterals() + collection := make([]any, len(literals)) + for i, l := range literals { + var v any + _ = msgpack.Unmarshal(l.GetScalar().GetBinary().Value, &v) + collection[i] = v + } + return collection + } flyteFile := FlyteFile{ "path": "s3://my-s3-bucket/example.txt", @@ -630,18 +655,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]int{0, 1, 2, -1, -2}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{0, 1, 2, -1, -2}), hasError: false, }, { @@ -653,18 +667,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]FlyteFile{{"path": "s3://my-s3-bucket/example.txt"}}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{flyteFile}), hasError: false, }, { @@ -676,18 +679,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([][]int{{0}, {1}, {-1}}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{[]int{0}, []int{1}, []int{-1}}), hasError: false, }, { @@ -699,18 +691,8 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]map[int]bool{{0: false}, {1: true}, {-1: true}}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{map[int]bool{0: false}, map[int]bool{1: true}, + map[int]bool{-1: true}}), hasError: false, }, { @@ -1037,18 +1019,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]int{0, 1, 2, -1, -2}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{0, 1, 2, -1, -2}), hasError: false, }, { @@ -1065,18 +1036,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]FlyteFile{flyteFile}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{flyteFile}), hasError: false, }, { @@ -1093,18 +1053,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([][]int{{0}, {1}, {-1}}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{[]int{0}, []int{1}, []int{-1}}), hasError: false, }, { @@ -1126,18 +1075,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]int{0}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{0}), hasError: false, }, { @@ -1192,18 +1130,11 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - expected: &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: toMsgpackBytes([]map[int]bool{{0: false}, {1: true}, {-1: true}}), - Tag: "msgpack", - }, - }, - }, - }, - }, + expected: toLiteralCollectionWithMsgpackBytes([]any{ + map[int]bool{0: false}, + map[int]bool{1: true}, + map[int]bool{-1: true}, + }), hasError: false, }, { @@ -1422,10 +1353,10 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": map[string]interface{}{ + Value: toMsgpackBytes(map[string]any{ + "foo": map[string]any{ "bar": int64(42), - "baz": map[string]interface{}{ + "baz": map[string]any{ "qux": 3.14, "quux": "str", }, @@ -1465,8 +1396,8 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": []interface{}{int64(42), 3.14, "str"}, + Value: toMsgpackBytes(map[string]any{ + "foo": []any{int64(42), 3.14, "str"}, }), Tag: "msgpack", }, @@ -1499,13 +1430,13 @@ func TestResolveAttrPathInBinary(t *testing.T) { assert.Error(t, err, i) assert.ErrorContains(t, err, errors.PromiseAttributeResolveError, i) } else { - var expectedValue, actualValue interface{} + var expectedValue, actualValue any - // Helper function to unmarshal a Binary Literal into an interface{} - unmarshalBinaryLiteral := func(literal *core.Literal) (interface{}, error) { + // Helper function to unmarshal a Binary Literal into an any + unmarshalBinaryLiteral := func(literal *core.Literal) (any, error) { if scalar, ok := literal.Value.(*core.Literal_Scalar); ok { if binary, ok := scalar.Scalar.Value.(*core.Scalar_Binary); ok { - var value interface{} + var value any err := msgpack.Unmarshal(binary.Binary.Value, &value) return value, err } @@ -1513,16 +1444,22 @@ func TestResolveAttrPathInBinary(t *testing.T) { return nil, fmt.Errorf("literal is not a Binary Scalar") } - // Unmarshal the expected value - expectedValue, err := unmarshalBinaryLiteral(arg.expected) - if err != nil { - t.Fatalf("Failed to unmarshal expected value in test case %d: %v", i, err) + if arg.expected.GetCollection() != nil { + expectedValue = fromLiteralCollectionWithMsgpackBytes(arg.expected) + } else { + expectedValue, err = unmarshalBinaryLiteral(arg.expected) + if err != nil { + t.Fatalf("Failed to unmarshal expected value in test case %d: %v", i, err) + } } - // Unmarshal the resolved value - actualValue, err = unmarshalBinaryLiteral(resolved) - if err != nil { - t.Fatalf("Failed to unmarshal resolved value in test case %d: %v", i, err) + if resolved.GetCollection() != nil { + actualValue = fromLiteralCollectionWithMsgpackBytes(resolved) + } else { + actualValue, err = unmarshalBinaryLiteral(resolved) + if err != nil { + t.Fatalf("Failed to unmarshal resolved value in test case %d: %v", i, err) + } } // Deeply compare the expected and actual values, ignoring map ordering From 856e606419a3756db8c4d8c0968bee80fbb445b4 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 30 Oct 2024 23:12:30 +0800 Subject: [PATCH 2/2] add arrayNodeHandler comments Signed-off-by: Future-Outlier --- flytepropeller/pkg/controller/nodes/attr_path_resolver.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index 222987ffa1..3b4e46ce50 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -162,8 +162,9 @@ func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath } } - // if currVal is list, convert it to literal collection - // This is for map task handling + // In arrayNodeHandler, the resolved value should be a literal collection. + // If the current value is already a collection, convert it to a literal collection. + // This conversion does not affect how Flytekit processes the resolved value. if collection, ok := currVal.([]any); ok { literals := make([]*core.Literal, len(collection)) for i, v := range collection {