Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add predicate to sum() builtin #592

Merged
merged 3 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ var Builtins = []*Function{
Predicate: true,
Types: types(new(func([]any, func(any) bool) int)),
},
{
Name: "sum",
Predicate: true,
Types: types(new(func([]any, func(any) bool) int)),
},
{
Name: "groupBy",
Predicate: true,
Expand Down Expand Up @@ -387,13 +392,6 @@ var Builtins = []*Function{
return validateAggregateFunc("min", args)
},
},
{
Name: "sum",
Func: sum,
Validate: func(args []reflect.Type) (reflect.Type, error) {
return validateAggregateFunc("sum", args)
},
},
{
Name: "mean",
Func: func(args ...any) (any, error) {
Expand Down
4 changes: 0 additions & 4 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ func TestBuiltin(t *testing.T) {
{`sum([.5, 1.5, 2.5])`, 4.5},
{`sum([])`, 0},
{`sum([1, 2, 3.0, 4])`, 10.0},
{`sum(10, [1, 2, 3], 1..9)`, 61},
{`sum(-10, [1, 2, 3, 4])`, 0},
{`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1},
{`mean(1..9)`, 5.0},
{`mean([.5, 1.5, 2.5])`, 1.5},
{`mean([])`, 0.0},
Expand Down Expand Up @@ -219,7 +216,6 @@ func TestBuiltin_errors(t *testing.T) {
{`min([1, "2"])`, `invalid argument for min (type string)`},
{`median(1..9, "t")`, "invalid argument for median (type string)"},
{`mean("s", 1..9)`, "invalid argument for mean (type string)"},
{`sum("s", "h")`, "invalid argument for sum (type string)"},
{`duration("error")`, `invalid duration`},
{`date("error")`, `invalid date`},
{`get()`, `invalid number of arguments (expected 2, got 0)`},
Expand Down
39 changes: 0 additions & 39 deletions builtin/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,45 +258,6 @@ func String(arg any) any {
return fmt.Sprintf("%v", arg)
}

func sum(args ...any) (any, error) {
var total int
var fTotal float64

for _, arg := range args {
rv := reflect.ValueOf(deref.Deref(arg))

switch rv.Kind() {
case reflect.Array, reflect.Slice:
size := rv.Len()
for i := 0; i < size; i++ {
elemSum, err := sum(rv.Index(i).Interface())
if err != nil {
return nil, err
}
switch elemSum := elemSum.(type) {
case int:
total += elemSum
case float64:
fTotal += elemSum
}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
total += int(rv.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
total += int(rv.Uint())
case reflect.Float32, reflect.Float64:
fTotal += rv.Float()
default:
return nil, fmt.Errorf("invalid argument for sum (type %T)", arg)
}
}

if fTotal != 0.0 {
return fTotal + float64(total), nil
}
return total, nil
}

func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
var val any
for _, arg := range args {
Expand Down
23 changes: 23 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,29 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

case "sum":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

if len(node.Arguments) == 2 {
v.begin(collection)
closure, _ := v.visit(node.Arguments[1])
v.end()

if isFunc(closure) &&
closure.NumOut() == 1 &&
closure.NumIn() == 1 && isAny(closure.In(0)) {
return closure.Out(0), info{}
}
} else {
if isAny(collection) {
return anyType, info{}
}
return collection.Elem(), info{}
}

case "find", "findLast":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
Expand Down
19 changes: 19 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
c.emit(OpEnd)
return

case "sum":
c.compile(node.Arguments[0])
c.emit(OpBegin)
c.emit(OpInt, 0)
c.emit(OpSetAcc)
c.emitLoop(func() {
if len(node.Arguments) == 2 {
c.compile(node.Arguments[1])
} else {
c.emit(OpPointer)
}
c.emit(OpGetAcc)
c.emit(OpAdd)
c.emit(OpSetAcc)
})
c.emit(OpGetAcc)
c.emit(OpEnd)
return

case "find":
c.compile(node.Arguments[0])
c.emit(OpBegin)
Expand Down
1 change: 1 addition & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var predicates = map[string]struct {
"filter": {[]arg{expr, closure}},
"map": {[]arg{expr, closure}},
"count": {[]arg{expr, closure}},
"sum": {[]arg{expr, closure | optional}},
"find": {[]arg{expr, closure}},
"findIndex": {[]arg{expr, closure}},
"findLast": {[]arg{expr, closure}},
Expand Down
1 change: 0 additions & 1 deletion test/fuzz/fuzz_corpus.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10455,7 +10455,6 @@ max(f64, i64)
max(false ? 1 : 0.5)
max(false ? 1 : nil)
max(false ? add : ok)
max(false ? half : list)
max(false ? i : nil)
max(false ? i32 : score)
max(false ? true : 1)
Expand Down
14 changes: 0 additions & 14 deletions testdata/examples.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7419,12 +7419,6 @@ get(ok ? score : foo, String?.foo())
get(ok ? score : i64, foo)
get(reduce(list, array), i32)
get(sort(array), i32)
get(sum(array), Qux)
get(sum(array), String)
get(sum(array), f32)
get(sum(array), f64 == list)
get(sum(array), greet)
get(sum(array), i)
get(take(list, i), i64)
get(true ? "bar" : ok, score(i))
get(true ? "foo" : half, list)
Expand Down Expand Up @@ -7460,7 +7454,6 @@ greet != nil ? list : false
greet != score
greet != score != false
greet != score or ok
greet != sum(array)
greet == add
greet == add ? i : list
greet == add or ok
Expand Down Expand Up @@ -12200,7 +12193,6 @@ last(ok ? ok : 0.5)
last(reduce(array, list))
last(reduce(list, array))
last(sort(array))
last(sum(array))
last(true ? "bar" : half)
last(true ? add : list)
last(true ? foo : 1)
Expand Down Expand Up @@ -14818,7 +14810,6 @@ ok != nil ? nil : array
ok != not ok
ok != ok
ok != ok ? false : "bar"
ok != sum(array)
ok && !false
ok && !ok
ok && "foo" matches "bar"
Expand Down Expand Up @@ -16970,7 +16961,6 @@ string(groupBy(list, i))
string(half != nil)
string(half != score)
string(half == nil)
string(half == sum(array))
string(half(0.5))
string(half(1))
string(half(f64))
Expand Down Expand Up @@ -17297,18 +17287,14 @@ sum([0.5])
sum([f32])
sum(array)
sum(array) != f32
sum(array) != half
sum(array) != ok
sum(array) % i
sum(array) % i64
sum(array) - f32
sum(array) / -f64
sum(array) < i
sum(array) == div
sum(array) == i64 - i
sum(array) ^ f64
sum(array) not in array
sum(array) not in list
sum(filter(array, ok))
sum(groupBy(array, i32).String)
sum(groupBy(list, #)?.greet)
Expand Down
Loading