Skip to content

Commit

Permalink
Only emit OpEqual{Int,String} for simple types
Browse files Browse the repository at this point in the history
Fixes #461
  • Loading branch information
antonmedv committed Nov 16, 2023
1 parent d27e5a3 commit 0354d1b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
19 changes: 17 additions & 2 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,16 +360,20 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
l := kind(node.Left)
r := kind(node.Right)

leftIsSimple := isSimpleType(node.Left)
rightIsSimple := isSimpleType(node.Right)
leftAndRightAreSimple := leftIsSimple && rightIsSimple

switch node.Operator {
case "==":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Left)

if l == r && l == reflect.Int {
if l == r && l == reflect.Int && leftAndRightAreSimple {
c.emit(OpEqualInt)
} else if l == r && l == reflect.String {
} else if l == r && l == reflect.String && leftAndRightAreSimple {
c.emit(OpEqualString)
} else {
c.emit(OpEqual)
Expand Down Expand Up @@ -534,6 +538,17 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
}
}

func isSimpleType(node ast.Node) bool {
if node == nil {
return false
}
t := node.Type()
if t == nil {
return false
}
return t.PkgPath() == ""
}

func (c *compiler) ChainNode(node *ast.ChainNode) {
c.chains = append(c.chains, []int{})
c.compile(node.Node)
Expand Down
80 changes: 80 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2022,3 +2022,83 @@ func TestIssue453(t *testing.T) {
_, err := expr.Compile(`foo()`, expr.Env(env))
require.Error(t, err)
}

func TestIssue461(t *testing.T) {
type EnvStr string
type EnvField struct {
S EnvStr
Str string
}
type Env struct {
S EnvStr
Str string
EnvField EnvField
}
var tests = []struct {
input string
env Env
want bool
}{
{
input: "Str == S",
env: Env{S: "string", Str: "string"},
want: false,
},
{
input: "Str == Str",
env: Env{Str: "string"},
want: true,
},
{
input: "S == S",
env: Env{Str: "string"},
want: true,
},
{
input: `Str == "string"`,
env: Env{Str: "string"},
want: true,
},
{
input: `S == "string"`,
env: Env{Str: "string"},
want: false,
},
{
input: "EnvField.Str == EnvField.S",
env: Env{EnvField: EnvField{S: "string", Str: "string"}},
want: false,
},
{
input: "EnvField.Str == EnvField.Str",
env: Env{EnvField: EnvField{Str: "string"}},
want: true,
},
{
input: "EnvField.S == EnvField.S",
env: Env{EnvField: EnvField{Str: "string"}},
want: true,
},
{
input: `EnvField.Str == "string"`,
env: Env{EnvField: EnvField{Str: "string"}},
want: true,
},
{
input: `EnvField.S == "string"`,
env: Env{EnvField: EnvField{Str: "string"}},
want: false,
},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
program, err := expr.Compile(tt.input, expr.Env(tt.env), expr.AsBool())

out, err := expr.Run(program, tt.env)
require.NoError(t, err)

require.Equal(t, tt.want, out)
})
}
}

0 comments on commit 0354d1b

Please sign in to comment.