Skip to content

Commit

Permalink
Compare any arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
bizywizy committed Feb 27, 2024
1 parent ae180c8 commit f1b68c8
Show file tree
Hide file tree
Showing 3 changed files with 415 additions and 0 deletions.
26 changes: 26 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2482,3 +2482,29 @@ func TestRaceCondition_variables(t *testing.T) {

wg.Wait()
}

func TestArrayComparison(t *testing.T) {
tests := []struct {
env any
code string
}{
{[]string{"A", "B"}, "foo == ['A', 'B']"},
{[]int{1, 2}, "foo == [1, 2]"},
{[]uint8{1, 2}, "foo == [1, 2]"},
{[]float64{1.1, 2.2}, "foo == [1.1, 2.2]"},
{[]any{"A", 1, 1.1, true}, "foo == ['A', 1, 1.1, true]"},
{[]string{"A", "B"}, "foo != [1, 2]"},
}

for _, tt := range tests {
t.Run(tt.code, func(t *testing.T) {
env := map[string]any{"foo": tt.env}
program, err := expr.Compile(tt.code, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, true, out)
})
}
}
46 changes: 46 additions & 0 deletions vm/runtime/helpers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func main() {
"cases_with_duration": func(op string) string {
return cases(op, uints, ints, floats, []string{"time.Duration"})
},
"array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) },
}).
Parse(helpers),
).Execute(&b, nil)
Expand Down Expand Up @@ -89,6 +90,45 @@ func cases(op string, xs ...[]string) string {
return strings.TrimRight(out, "\n")
}

func arrayEqualCases(xs ...[]string) string {
var types []string
for _, x := range xs {
types = append(types, x...)
}

_, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types)

var out string
echo := func(s string, xs ...any) {
out += fmt.Sprintf(s, xs...) + "\n"
}
echo(`case []any:`)
echo(`switch y := b.(type) {`)
for _, a := range append(types, "any") {
echo(`case []%v:`, a)
echo(`if len(x) != len(y) { return false }`)
echo(`for i := range x {`)
echo(`if !Equal(x[i], y[i]) { return false }`)
echo(`}`)
echo("return true")
}
echo(`}`)
for _, a := range types {
echo(`case []%v:`, a)
echo(`switch y := b.(type) {`)
echo(`case []any:`)
echo(`return Equal(y, x)`)
echo(`case []%v:`, a)
echo(`if len(x) != len(y) { return false }`)
echo(`for i := range x {`)
echo(`if x[i] != y[i] { return false }`)
echo(`}`)
echo("return true")
echo(`}`)
}
return strings.TrimRight(out, "\n")
}

func isFloat(t string) bool {
return strings.HasPrefix(t, "float")
}
Expand All @@ -110,6 +150,7 @@ import (
func Equal(a, b interface{}) bool {
switch x := a.(type) {
{{ cases "==" }}
{{ array_equal_cases }}
case string:
switch y := b.(type) {
case string:
Expand All @@ -125,6 +166,11 @@ func Equal(a, b interface{}) bool {
case time.Duration:
return x == y
}
case bool:
switch y := b.(type) {
case bool:
return x == y
}
}
if IsNil(a) && IsNil(b) {
return true
Expand Down
Loading

0 comments on commit f1b68c8

Please sign in to comment.