diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 8f36c34fe..b833d23ad 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -6,7 +6,7 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.10.0 - github.com/flyteorg/flyteidl v0.18.50 + github.com/flyteorg/flyteidl v0.19.2 github.com/flyteorg/flyteplugins v0.5.54 github.com/flyteorg/flytestdlib v0.3.17 github.com/ghodss/yaml v1.0.0 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index daff446e9..fd1edd1da 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -72,7 +72,6 @@ github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXfwdH7IaSzBEbSQxEDz36YUmt7+CB4zoNA= @@ -231,8 +230,8 @@ github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/flyteorg/flyteidl v0.18.48/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= -github.com/flyteorg/flyteidl v0.18.50 h1:L1fMj6QEXoKin+cPQn9sfwJ1x14tlChdz1mG1WaaIW4= -github.com/flyteorg/flyteidl v0.18.50/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteidl v0.19.2 h1:jXuRrLJEzSo33N9pw7bMEd6mRYSL7LCz/vnazz5XcOg= +github.com/flyteorg/flyteidl v0.19.2/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flyteplugins v0.5.54 h1:QQRh4RRnLxW89A/D/SjPmbPETTp2ypNZnHpGGg1vA84= github.com/flyteorg/flyteplugins v0.5.54/go.mod h1:dcAWfANpOlrPemHmegNXUhrkWjVWIPvLGaX6rHPlA/E= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= @@ -1227,7 +1226,6 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.0.0-20210217171935-8e2decd92398/go.mod h1:60tmSUpHxGPFerNHbo/ayI2lKxvtrhbxFyXuEIWJd78= k8s.io/api v0.18.2/go.mod h1:SJCWI7OLzhZSvbY7U8zwNl9UA4o1fizoug34OV/2r78= diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go index ea3a983b8..9961587d7 100644 --- a/flytepropeller/pkg/compiler/validators/typing.go +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -33,6 +33,20 @@ func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { if isVoid(upstreamType) { return true } + + // If upstream is an enum, it can be consumed as a string downstream + if upstreamType.GetEnumType() != nil { + if t.literalType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } + // If t is an enum, it can be created from a string as Enums as just constrained String aliases + if t.literalType.GetEnumType() != nil { + if upstreamType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } + // Ignore metadata when comparing types. upstreamTypeCopy := *upstreamType downstreamTypeCopy := *t.literalType diff --git a/flytepropeller/pkg/compiler/validators/typing_test.go b/flytepropeller/pkg/compiler/validators/typing_test.go index 8b46646d1..0dc0288ce 100644 --- a/flytepropeller/pkg/compiler/validators/typing_test.go +++ b/flytepropeller/pkg/compiler/validators/typing_test.go @@ -74,6 +74,66 @@ func TestSimpleLiteralCasting(t *testing.T) { ) assert.True(t, castable, "Metadata should be ignored") }) + + t.Run("EnumToString", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"x", "y"}, + }}, + }, + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + }, + ) + assert.True(t, castable, "Enum should be castable to string") + }) + + t.Run("EnumToEnum", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"x", "y"}, + }}, + }, + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"x", "y"}, + }}, + }, + ) + assert.True(t, castable, "Enum should be castable to Enums if they are identical") + }) + + t.Run("EnumToEnum", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"x", "y"}, + }}, + }, + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"m", "n"}, + }}, + }, + ) + assert.False(t, castable, "Enum should not be castable to non matching enums") + }) + + t.Run("StringToEnum", func(t *testing.T) { + castable := AreTypesCastable( + &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, + }, + &core.LiteralType{ + Type: &core.LiteralType_EnumType{EnumType: &core.EnumType{ + Values: []string{"x", "y"}, + }}, + }, + ) + assert.True(t, castable, "Strings should be castable to enums - may result in runtime failure") + }) } func TestCollectionCasting(t *testing.T) {