diff --git a/pkg/compiler/validators/typing.go b/pkg/compiler/validators/typing.go index 3600f262c..9961587d7 100644 --- a/pkg/compiler/validators/typing.go +++ b/pkg/compiler/validators/typing.go @@ -40,6 +40,12 @@ func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { return true } } + // If t is an enum, it can be created from a string as Enums as just constrained String aliases + if t.literalType.GetEnumType() != nil { + if upstreamType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } // Ignore metadata when comparing types. upstreamTypeCopy := *upstreamType diff --git a/pkg/compiler/validators/typing_test.go b/pkg/compiler/validators/typing_test.go index 75f17da16..0dc0288ce 100644 --- a/pkg/compiler/validators/typing_test.go +++ b/pkg/compiler/validators/typing_test.go @@ -86,7 +86,7 @@ func TestSimpleLiteralCasting(t *testing.T) { Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, }, ) - assert.True(t, castable, "Integers should be castable to other integers") + assert.True(t, castable, "Enum should be castable to string") }) t.Run("EnumToEnum", func(t *testing.T) { @@ -102,7 +102,23 @@ func TestSimpleLiteralCasting(t *testing.T) { }}, }, ) - assert.True(t, castable, "Integers should be castable to other integers") + assert.True(t, castable, "Enum should be castable to Enums if they are identical") + }) + + t.Run("EnumToEnum", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"x", "y"}, + }}, + }, + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"m", "n"}, + }}, + }, + ) + assert.False(t, castable, "Enum should not be castable to non matching enums") }) t.Run("StringToEnum", func(t *testing.T) { @@ -116,7 +132,7 @@ func TestSimpleLiteralCasting(t *testing.T) { }}, }, ) - assert.False(t, castable, "Integers should be castable to other integers") + assert.True(t, castable, "Strings should be castable to enums - may result in runtime failure") }) }