From 3bfb3ff20a3067e4603316587fd2b803e6cc8907 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Thu, 3 Aug 2023 09:38:11 -0500 Subject: [PATCH] make singular unions castable to their underlying type (#599) Signed-off-by: Daniel Rammer --- .../pkg/compiler/validators/typing.go | 12 ++++++++- .../pkg/compiler/validators/typing_test.go | 26 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go index 09268359dd..fc5cd368a6 100644 --- a/flytepropeller/pkg/compiler/validators/typing.go +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -356,5 +356,15 @@ func getTypeChecker(t *flyte.LiteralType) typeChecker { } func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool { - return getTypeChecker(downstreamType).CastsFrom(upstreamType) + typeChecker := getTypeChecker(downstreamType) + + // if upstream is a singular union we check if the downstream type is castable from the union variant + if upstreamType.GetUnionType() != nil && len(upstreamType.GetUnionType().GetVariants()) == 1 { + variants := upstreamType.GetUnionType().GetVariants() + if len(variants) == 1 && typeChecker.CastsFrom(variants[0]) { + return true + } + } + + return typeChecker.CastsFrom(upstreamType) } diff --git a/flytepropeller/pkg/compiler/validators/typing_test.go b/flytepropeller/pkg/compiler/validators/typing_test.go index 8344339f09..146d314de2 100644 --- a/flytepropeller/pkg/compiler/validators/typing_test.go +++ b/flytepropeller/pkg/compiler/validators/typing_test.go @@ -326,6 +326,32 @@ func TestUnionCasting(t *testing.T) { ) assert.False(t, castable, "Union types can only be cast to a union that contains a superset of variants") }) + + t.Run("SingularUnionToUnderlyingType", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + Structure: &core.TypeStructure{ + Tag: "string", + }, + }, + }, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + Structure: &core.TypeStructure{ + Tag: "string", + }, + }, + ) + assert.True(t, castable, "Singular unions should be castable to their underlying type") + }) } func TestCollectionCasting(t *testing.T) {