diff --git a/flyteidl/clients/go/coreutils/extract_literal.go b/flyteidl/clients/go/coreutils/extract_literal.go index 4c4ae1f5ad..398821d503 100644 --- a/flyteidl/clients/go/coreutils/extract_literal.go +++ b/flyteidl/clients/go/coreutils/extract_literal.go @@ -56,6 +56,8 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) { } case *core.Scalar_Blob: return scalarValue.Blob.Uri, nil + case *core.Scalar_Generic: + return scalarValue.Generic, nil default: return nil, fmt.Errorf("unsupported literal scalar type %T", scalarValue) } diff --git a/flyteidl/clients/go/coreutils/extract_literal_test.go b/flyteidl/clients/go/coreutils/extract_literal_test.go index ed1dccbbd6..321f3041a3 100644 --- a/flyteidl/clients/go/coreutils/extract_literal_test.go +++ b/flyteidl/clients/go/coreutils/extract_literal_test.go @@ -6,6 +6,9 @@ package coreutils import ( "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" ) @@ -101,4 +104,57 @@ func TestFetchLiteral(t *testing.T) { _, err = ExtractFromLiteral(p) assert.NotNil(t, err) }) + + t.Run("Generic", func(t *testing.T) { + literalVal := map[string]interface{}{ + "x": 1, + "y": "ystringvalue", + } + var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + lit, err := MakeLiteralForType(literalType, literalVal) + assert.NoError(t, err) + extractedLiteralVal, err := ExtractFromLiteral(lit) + assert.NoError(t, err) + fieldsMap := map[string]*structpb.Value{ + "x": { + Kind: &structpb.Value_NumberValue{NumberValue: 1}, + }, + "y": { + Kind: &structpb.Value_StringValue{StringValue: "ystringvalue"}, + }, + } + expectedStructVal := &structpb.Struct{ + Fields: fieldsMap, + } + extractedStructValue := extractedLiteralVal.(*structpb.Struct) + assert.Equal(t, len(expectedStructVal.Fields), len(extractedStructValue.Fields)) + for key, val := range expectedStructVal.Fields { + assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind) + } + }) + + t.Run("Generic Passed As String", func(t *testing.T) { + literalVal := "{\"x\": 1,\"y\": \"ystringvalue\"}" + var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + lit, err := MakeLiteralForType(literalType, literalVal) + assert.NoError(t, err) + extractedLiteralVal, err := ExtractFromLiteral(lit) + assert.NoError(t, err) + fieldsMap := map[string]*structpb.Value{ + "x": { + Kind: &structpb.Value_NumberValue{NumberValue: 1}, + }, + "y": { + Kind: &structpb.Value_StringValue{StringValue: "ystringvalue"}, + }, + } + expectedStructVal := &structpb.Struct{ + Fields: fieldsMap, + } + extractedStructValue := extractedLiteralVal.(*structpb.Struct) + assert.Equal(t, len(expectedStructVal.Fields), len(extractedStructValue.Fields)) + for key, val := range expectedStructVal.Fields { + assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind) + } + }) } diff --git a/flyteidl/clients/go/coreutils/literals.go b/flyteidl/clients/go/coreutils/literals.go index 6bb30f7be6..5a23a42b16 100644 --- a/flyteidl/clients/go/coreutils/literals.go +++ b/flyteidl/clients/go/coreutils/literals.go @@ -2,6 +2,7 @@ package coreutils import ( + "encoding/json" "fmt" "reflect" "strconv" @@ -483,7 +484,17 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro } case *core.LiteralType_Simple: newT := t.Type.(*core.LiteralType_Simple) - lv, err := MakeLiteralForSimpleType(newT.Simple, fmt.Sprintf("%v", v)) + strValue := fmt.Sprintf("%v", v) + if newT.Simple == core.SimpleType_STRUCT { + if _, isValueStringType := v.(string); !isValueStringType { + byteValue, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) + } + strValue = string(byteValue) + } + } + lv, err := MakeLiteralForSimpleType(newT.Simple, strValue) if err != nil { return nil, err }