diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 0cdf9515..e302d719 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -327,6 +327,24 @@ func TestBuiltin_allow_builtins_override(t *testing.T) { program, err := expr.Compile(fmt.Sprintf("%s()", name), fn) require.NoError(t, err) + out, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, 42, out) + }) + } + }) + t.Run("via expr.Function as pipe", func(t *testing.T) { + for _, name := range builtin.Names { + t.Run(name, func(t *testing.T) { + fn := expr.Function(name, + func(params ...any) (any, error) { + return 42, nil + }, + new(func(s string) int), + ) + program, err := expr.Compile(fmt.Sprintf("'str' | %s()", name), fn) + require.NoError(t, err) + out, err := expr.Run(program, nil) require.NoError(t, err) assert.Equal(t, 42, out) diff --git a/parser/parser.go b/parser/parser.go index caab130d..07c1be54 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -15,22 +15,31 @@ import ( "github.com/expr-lang/expr/parser/utils" ) +type arg byte + +const ( + expr arg = 1 << iota + closure +) + +const optional arg = 1 << 7 + var predicates = map[string]struct { - arity int + args []arg }{ - "all": {2}, - "none": {2}, - "any": {2}, - "one": {2}, - "filter": {2}, - "map": {2}, - "count": {2}, - "find": {2}, - "findIndex": {2}, - "findLast": {2}, - "findLastIndex": {2}, - "groupBy": {2}, - "reduce": {3}, + "all": {[]arg{expr, closure}}, + "none": {[]arg{expr, closure}}, + "any": {[]arg{expr, closure}}, + "one": {[]arg{expr, closure}}, + "filter": {[]arg{expr, closure}}, + "map": {[]arg{expr, closure}}, + "count": {[]arg{expr, closure}}, + "find": {[]arg{expr, closure}}, + "findIndex": {[]arg{expr, closure}}, + "findLast": {[]arg{expr, closure}}, + "findLastIndex": {[]arg{expr, closure}}, + "groupBy": {[]arg{expr, closure}}, + "reduce": {[]arg{expr, closure, expr | optional}}, } type parser struct { @@ -143,7 +152,9 @@ func (p *parser) parseExpression(precedence int) Node { p.next() if opToken.Value == "|" { - nodeLeft = p.parsePipe(nodeLeft) + identToken := p.current + p.expect(Identifier) + nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) goto next } @@ -279,7 +290,7 @@ func (p *parser) parsePrimary() Node { p.next() token = p.current p.expect(Identifier) - return p.parsePostfixExpression(p.parseCall(token, false)) + return p.parsePostfixExpression(p.parseCall(token, []Node{}, false)) } return p.parseSecondary() @@ -307,7 +318,12 @@ func (p *parser) parseSecondary() Node { node.SetLocation(token.Location) return node default: - node = p.parseCall(token, true) + if p.current.Is(Bracket, "(") { + node = p.parseCall(token, []Node{}, true) + } else { + node = &IdentifierNode{Value: token.Value} + node.SetLocation(token.Location) + } } case Number: @@ -386,68 +402,86 @@ func (p *parser) toFloatNode(number float64) Node { return &FloatNode{Value: number} } -func (p *parser) parseCall(token Token, checkOverrides bool) Node { +func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) Node { var node Node - if p.current.Is(Bracket, "(") { - var arguments []Node - - isOverridden := p.config.IsOverridden(token.Value) - isOverridden = isOverridden && checkOverrides - - // TODO: Refactor parser to use builtin.Builtins instead of predicates map. - if b, ok := predicates[token.Value]; ok && !isOverridden { - p.expect(Bracket, "(") - - if b.arity == 1 { - arguments = make([]Node, 1) - arguments[0] = p.parseExpression(0) - } else if b.arity == 2 { - arguments = make([]Node, 2) - arguments[0] = p.parseExpression(0) - p.expect(Operator, ",") - arguments[1] = p.parseClosure() - } - if token.Value == "reduce" { - arguments = make([]Node, 2) - arguments[0] = p.parseExpression(0) - p.expect(Operator, ",") - arguments[1] = p.parseClosure() - if p.current.Is(Operator, ",") { - p.next() - arguments = append(arguments, p.parseExpression(0)) - } - } + isOverridden := p.config.IsOverridden(token.Value) + isOverridden = isOverridden && checkOverrides + + if b, ok := predicates[token.Value]; ok && !isOverridden { + p.expect(Bracket, "(") - p.expect(Bracket, ")") + // In case of the pipe operator, the first argument is the left-hand side + // of the operator, so we do not parse it as an argument inside brackets. + args := b.args[len(arguments):] - node = &BuiltinNode{ - Name: token.Value, - Arguments: arguments, + for i, arg := range args { + if arg&optional == optional { + if p.current.Is(Bracket, ")") { + break + } + } else { + if p.current.Is(Bracket, ")") { + p.error("expected at least %d arguments", len(args)) + } } - node.SetLocation(token.Location) - } else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden { - node = &BuiltinNode{ - Name: token.Value, - Arguments: p.parseArguments(), + + if i > 0 { + p.expect(Operator, ",") } - node.SetLocation(token.Location) - } else { - callee := &IdentifierNode{Value: token.Value} - callee.SetLocation(token.Location) - node = &CallNode{ - Callee: callee, - Arguments: p.parseArguments(), + var node Node + switch { + case arg&expr == expr: + node = p.parseExpression(0) + case arg&closure == closure: + node = p.parseClosure() } - node.SetLocation(token.Location) + arguments = append(arguments, node) + } + + p.expect(Bracket, ")") + + node = &BuiltinNode{ + Name: token.Value, + Arguments: arguments, + } + node.SetLocation(token.Location) + } else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden { + node = &BuiltinNode{ + Name: token.Value, + Arguments: p.parseArguments(arguments), } + node.SetLocation(token.Location) } else { - node = &IdentifierNode{Value: token.Value} + callee := &IdentifierNode{Value: token.Value} + callee.SetLocation(token.Location) + node = &CallNode{ + Callee: callee, + Arguments: p.parseArguments(arguments), + } node.SetLocation(token.Location) } return node } +func (p *parser) parseArguments(arguments []Node) []Node { + // If pipe operator is used, the first argument is the left-hand side + // of the operator, so we do not parse it as an argument inside brackets. + offset := len(arguments) + + p.expect(Bracket, "(") + for !p.current.Is(Bracket, ")") && p.err == nil { + if len(arguments) > offset { + p.expect(Operator, ",") + } + node := p.parseExpression(0) + arguments = append(arguments, node) + } + p.expect(Bracket, ")") + + return arguments +} + func (p *parser) parseClosure() Node { startToken := p.current expectClosingBracket := false @@ -575,7 +609,7 @@ func (p *parser) parsePostfixExpression(node Node) Node { memberNode.Method = true node = &CallNode{ Callee: memberNode, - Arguments: p.parseArguments(), + Arguments: p.parseArguments([]Node{}), } node.SetLocation(propertyToken.Location) } else { @@ -641,72 +675,3 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } - -func (p *parser) parsePipe(node Node) Node { - identifier := p.current - p.expect(Identifier) - - arguments := []Node{node} - - if b, ok := predicates[identifier.Value]; ok { - p.expect(Bracket, "(") - - // TODO: Refactor parser to use builtin.Builtins instead of predicates map. - - if b.arity == 2 { - arguments = append(arguments, p.parseClosure()) - } - - if identifier.Value == "reduce" { - arguments = append(arguments, p.parseClosure()) - if p.current.Is(Operator, ",") { - p.next() - arguments = append(arguments, p.parseExpression(0)) - } - } - - p.expect(Bracket, ")") - - node = &BuiltinNode{ - Name: identifier.Value, - Arguments: arguments, - } - node.SetLocation(identifier.Location) - } else if _, ok := builtin.Index[identifier.Value]; ok { - arguments = append(arguments, p.parseArguments()...) - - node = &BuiltinNode{ - Name: identifier.Value, - Arguments: arguments, - } - node.SetLocation(identifier.Location) - } else { - callee := &IdentifierNode{Value: identifier.Value} - callee.SetLocation(identifier.Location) - - arguments = append(arguments, p.parseArguments()...) - - node = &CallNode{ - Callee: callee, - Arguments: arguments, - } - node.SetLocation(identifier.Location) - } - - return node -} - -func (p *parser) parseArguments() []Node { - p.expect(Bracket, "(") - nodes := make([]Node, 0) - for !p.current.Is(Bracket, ")") && p.err == nil { - if len(nodes) > 0 { - p.expect(Operator, ",") - } - node := p.parseExpression(0) - nodes = append(nodes, node) - } - p.expect(Bracket, ")") - - return nodes -}