From f1b68c877d52fe7e5ca985055e113c9a519bb57c Mon Sep 17 00:00:00 2001 From: bizy Date: Wed, 28 Feb 2024 00:30:33 +0700 Subject: [PATCH] Compare `any` arrays --- expr_test.go | 26 +++ vm/runtime/helpers/main.go | 46 +++++ vm/runtime/helpers[generated].go | 343 +++++++++++++++++++++++++++++++ 3 files changed, 415 insertions(+) diff --git a/expr_test.go b/expr_test.go index 74975362b..ea9213be2 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2482,3 +2482,29 @@ func TestRaceCondition_variables(t *testing.T) { wg.Wait() } + +func TestArrayComparison(t *testing.T) { + tests := []struct { + env any + code string + }{ + {[]string{"A", "B"}, "foo == ['A', 'B']"}, + {[]int{1, 2}, "foo == [1, 2]"}, + {[]uint8{1, 2}, "foo == [1, 2]"}, + {[]float64{1.1, 2.2}, "foo == [1.1, 2.2]"}, + {[]any{"A", 1, 1.1, true}, "foo == ['A', 1, 1.1, true]"}, + {[]string{"A", "B"}, "foo != [1, 2]"}, + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + env := map[string]any{"foo": tt.env} + program, err := expr.Compile(tt.code, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, true, out) + }) + } +} diff --git a/vm/runtime/helpers/main.go b/vm/runtime/helpers/main.go index b3f598a43..54a4fc235 100644 --- a/vm/runtime/helpers/main.go +++ b/vm/runtime/helpers/main.go @@ -19,6 +19,7 @@ func main() { "cases_with_duration": func(op string) string { return cases(op, uints, ints, floats, []string{"time.Duration"}) }, + "array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) }, }). Parse(helpers), ).Execute(&b, nil) @@ -89,6 +90,45 @@ func cases(op string, xs ...[]string) string { return strings.TrimRight(out, "\n") } +func arrayEqualCases(xs ...[]string) string { + var types []string + for _, x := range xs { + types = append(types, x...) + } + + _, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types) + + var out string + echo := func(s string, xs ...any) { + out += fmt.Sprintf(s, xs...) + "\n" + } + echo(`case []any:`) + echo(`switch y := b.(type) {`) + for _, a := range append(types, "any") { + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if !Equal(x[i], y[i]) { return false }`) + echo(`}`) + echo("return true") + } + echo(`}`) + for _, a := range types { + echo(`case []%v:`, a) + echo(`switch y := b.(type) {`) + echo(`case []any:`) + echo(`return Equal(y, x)`) + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if x[i] != y[i] { return false }`) + echo(`}`) + echo("return true") + echo(`}`) + } + return strings.TrimRight(out, "\n") +} + func isFloat(t string) bool { return strings.HasPrefix(t, "float") } @@ -110,6 +150,7 @@ import ( func Equal(a, b interface{}) bool { switch x := a.(type) { {{ cases "==" }} + {{ array_equal_cases }} case string: switch y := b.(type) { case string: @@ -125,6 +166,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true diff --git a/vm/runtime/helpers[generated].go b/vm/runtime/helpers[generated].go index 720feb455..d950f1111 100644 --- a/vm/runtime/helpers[generated].go +++ b/vm/runtime/helpers[generated].go @@ -334,6 +334,344 @@ func Equal(a, b interface{}) bool { case float64: return float64(x) == float64(y) } + case []any: + switch y := b.(type) { + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []any: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + } + case []string: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } case string: switch y := b.(type) { case string: @@ -349,6 +687,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true