diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go index 09268359dd..fc5cd368a6 100644 --- a/flytepropeller/pkg/compiler/validators/typing.go +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -356,5 +356,15 @@ func getTypeChecker(t *flyte.LiteralType) typeChecker { } func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool { - return getTypeChecker(downstreamType).CastsFrom(upstreamType) + typeChecker := getTypeChecker(downstreamType) + + // if upstream is a singular union we check if the downstream type is castable from the union variant + if upstreamType.GetUnionType() != nil && len(upstreamType.GetUnionType().GetVariants()) == 1 { + variants := upstreamType.GetUnionType().GetVariants() + if len(variants) == 1 && typeChecker.CastsFrom(variants[0]) { + return true + } + } + + return typeChecker.CastsFrom(upstreamType) } diff --git a/flytepropeller/pkg/compiler/validators/typing_test.go b/flytepropeller/pkg/compiler/validators/typing_test.go index 8344339f09..146d314de2 100644 --- a/flytepropeller/pkg/compiler/validators/typing_test.go +++ b/flytepropeller/pkg/compiler/validators/typing_test.go @@ -326,6 +326,32 @@ func TestUnionCasting(t *testing.T) { ) assert.False(t, castable, "Union types can only be cast to a union that contains a superset of variants") }) + + t.Run("SingularUnionToUnderlyingType", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + Structure: &core.TypeStructure{ + Tag: "string", + }, + }, + }, + }, + }, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + Structure: &core.TypeStructure{ + Tag: "string", + }, + }, + ) + assert.True(t, castable, "Singular unions should be castable to their underlying type") + }) } func TestCollectionCasting(t *testing.T) {