From 50ce3f973cbc96a0326560a31b736a4f0ca8dc62 Mon Sep 17 00:00:00 2001 From: adonovan Date: Wed, 9 Dec 2020 10:32:34 -0800 Subject: [PATCH] starlark: allow lambda expressions This change introduces lambda expressions, following Python, as a shorthand for declaring anonymous functions whose body is a single expression. RELNOTES: Starlark now supports lambda (anonymous function) expressions. PiperOrigin-RevId: 346583352 --- .../build/lib/packages/PackageFactory.java | 11 ++- .../java/net/starlark/java/eval/Eval.java | 18 ++--- .../starlark/java/eval/StarlarkFunction.java | 4 +- src/main/java/net/starlark/java/syntax/BUILD | 1 + .../net/starlark/java/syntax/Expression.java | 1 + .../java/syntax/LambdaExpression.java | 75 +++++++++++++++++++ .../net/starlark/java/syntax/NodePrinter.java | 15 ++++ .../net/starlark/java/syntax/NodeVisitor.java | 5 ++ .../java/net/starlark/java/syntax/Parser.java | 35 +++++++-- .../net/starlark/java/syntax/Resolver.java | 21 +++++- .../lib/packages/PackageFactoryTest.java | 9 ++- .../starlark/java/eval/testdata/function.star | 72 ++++++++++++------ .../net/starlark/java/syntax/ParserTest.java | 30 +++++++- 13 files changed, 250 insertions(+), 47 deletions(-) create mode 100644 src/main/java/net/starlark/java/syntax/LambdaExpression.java diff --git a/src/main/java/com/google/devtools/build/lib/packages/PackageFactory.java b/src/main/java/com/google/devtools/build/lib/packages/PackageFactory.java index 1d286feed4a72c..58795a58923465 100644 --- a/src/main/java/com/google/devtools/build/lib/packages/PackageFactory.java +++ b/src/main/java/com/google/devtools/build/lib/packages/PackageFactory.java @@ -75,6 +75,7 @@ import net.starlark.java.syntax.Identifier; import net.starlark.java.syntax.IfStatement; import net.starlark.java.syntax.IntLiteral; +import net.starlark.java.syntax.LambdaExpression; import net.starlark.java.syntax.ListExpression; import net.starlark.java.syntax.Location; import net.starlark.java.syntax.NodeVisitor; @@ -1003,7 +1004,15 @@ void recordGeneratorName(CallExpression call) { public void visit(DefStatement node) { error( node.getStartLocation(), - "function definitions are not allowed in BUILD files. You may move the function to " + "functions may not be defined in BUILD files. You may move the function to " + + "a .bzl file and load it."); + } + + @Override + public void visit(LambdaExpression node) { + error( + node.getStartLocation(), + "functions may not be defined in BUILD files. You may move the function to " + "a .bzl file and load it."); } diff --git a/src/main/java/net/starlark/java/eval/Eval.java b/src/main/java/net/starlark/java/eval/Eval.java index fd8a9a8c0e1cb3..d5f92aee048584 100644 --- a/src/main/java/net/starlark/java/eval/Eval.java +++ b/src/main/java/net/starlark/java/eval/Eval.java @@ -41,6 +41,7 @@ import net.starlark.java.syntax.IfStatement; import net.starlark.java.syntax.IndexExpression; import net.starlark.java.syntax.IntLiteral; +import net.starlark.java.syntax.LambdaExpression; import net.starlark.java.syntax.ListExpression; import net.starlark.java.syntax.LoadStatement; import net.starlark.java.syntax.Location; @@ -148,10 +149,8 @@ private static TokenKind execFor(StarlarkThread.Frame fr, ForStatement node) return TokenKind.PASS; } - private static void execDef(StarlarkThread.Frame fr, DefStatement node) + private static StarlarkFunction newFunction(StarlarkThread.Frame fr, Resolver.Function rfn) throws EvalException, InterruptedException { - Resolver.Function rfn = node.getResolvedFunction(); - // Evaluate default value expressions of optional parameters. // We use MANDATORY to indicate a required parameter // (not null, because defaults must be a legal tuple value, as @@ -196,11 +195,8 @@ private static void execDef(StarlarkThread.Frame fr, DefStatement node) // Nested functions use the same globalIndex as their enclosing function, // since both were compiled from the same Program. StarlarkFunction fn = fn(fr); - assignIdentifier( - fr, - node.getIdentifier(), - new StarlarkFunction( - rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars))); + return new StarlarkFunction( + rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars)); } private static TokenKind execIf(StarlarkThread.Frame fr, IfStatement node) @@ -289,7 +285,9 @@ private static TokenKind exec(StarlarkThread.Frame fr, Statement st) case FOR: return execFor(fr, (ForStatement) st); case DEF: - execDef(fr, (DefStatement) st); + DefStatement def = (DefStatement) st; + StarlarkFunction fn = newFunction(fr, def.getResolvedFunction()); + assignIdentifier(fr, def.getIdentifier(), fn); return TokenKind.PASS; case IF: return execIf(fr, (IfStatement) st); @@ -481,6 +479,8 @@ private static Object eval(StarlarkThread.Frame fr, Expression expr) } case FLOAT_LITERAL: return StarlarkFloat.of(((FloatLiteral) expr).getValue()); + case LAMBDA: + return newFunction(fr, ((LambdaExpression) expr).getResolvedFunction()); case LIST_EXPR: return evalList(fr, (ListExpression) expr); case SLICE: diff --git a/src/main/java/net/starlark/java/eval/StarlarkFunction.java b/src/main/java/net/starlark/java/eval/StarlarkFunction.java index 4277829efe9f3b..1cb4771388ba27 100644 --- a/src/main/java/net/starlark/java/eval/StarlarkFunction.java +++ b/src/main/java/net/starlark/java/eval/StarlarkFunction.java @@ -139,8 +139,8 @@ public Location getLocation() { } /** - * Returns the name of the function. Implicit functions (those not created by a def statement), - * may have names such as "" or "". + * Returns the name of the function, or "lambda" if anonymous. Implicit functions (those not + * created by a def statement), may have names such as "" or "". */ @Override public String getName() { diff --git a/src/main/java/net/starlark/java/syntax/BUILD b/src/main/java/net/starlark/java/syntax/BUILD index bfc5385ec3791c..967fb3210b414d 100644 --- a/src/main/java/net/starlark/java/syntax/BUILD +++ b/src/main/java/net/starlark/java/syntax/BUILD @@ -35,6 +35,7 @@ java_library( "IfStatement.java", "IndexExpression.java", "IntLiteral.java", + "LambdaExpression.java", "Lexer.java", "ListExpression.java", "LoadStatement.java", diff --git a/src/main/java/net/starlark/java/syntax/Expression.java b/src/main/java/net/starlark/java/syntax/Expression.java index 9a3bdf0ba282ca..069003e1ef235f 100644 --- a/src/main/java/net/starlark/java/syntax/Expression.java +++ b/src/main/java/net/starlark/java/syntax/Expression.java @@ -39,6 +39,7 @@ public enum Kind { IDENTIFIER, INDEX, INT_LITERAL, + LAMBDA, LIST_EXPR, SLICE, STRING_LITERAL, diff --git a/src/main/java/net/starlark/java/syntax/LambdaExpression.java b/src/main/java/net/starlark/java/syntax/LambdaExpression.java new file mode 100644 index 00000000000000..0723f4f23e008e --- /dev/null +++ b/src/main/java/net/starlark/java/syntax/LambdaExpression.java @@ -0,0 +1,75 @@ +// Copyright 2020 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package net.starlark.java.syntax; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import javax.annotation.Nullable; + +/** A LambdaExpression ({@code lambda params: body}) denotes an anonymous function. */ +public final class LambdaExpression extends Expression { + + private final int lambdaOffset; // offset of 'lambda' token + private final ImmutableList parameters; + private final Expression body; + + // set by resolver + @Nullable private Resolver.Function resolved; + + LambdaExpression( + FileLocations locs, int lambdaOffset, ImmutableList parameters, Expression body) { + super(locs); + this.lambdaOffset = lambdaOffset; + this.parameters = Preconditions.checkNotNull(parameters); + this.body = Preconditions.checkNotNull(body); + } + + public ImmutableList getParameters() { + return parameters; + } + + public Expression getBody() { + return body; + } + + /** Returns information about the resolved function. Set by the resolver. */ + @Nullable + public Resolver.Function getResolvedFunction() { + return resolved; + } + + void setResolvedFunction(Resolver.Function resolved) { + this.resolved = resolved; + } + + @Override + public int getStartOffset() { + return lambdaOffset; + } + + @Override + public int getEndOffset() { + return body.getEndOffset(); + } + + @Override + public void accept(NodeVisitor visitor) { + visitor.visit(this); + } + + @Override + public Kind kind() { + return Kind.LAMBDA; + } +} diff --git a/src/main/java/net/starlark/java/syntax/NodePrinter.java b/src/main/java/net/starlark/java/syntax/NodePrinter.java index 0d118aa6e54015..7d6cebf8d3f501 100644 --- a/src/main/java/net/starlark/java/syntax/NodePrinter.java +++ b/src/main/java/net/starlark/java/syntax/NodePrinter.java @@ -368,6 +368,21 @@ private void printExpr(Expression expr) { break; } + case LAMBDA: + { + LambdaExpression lambda = (LambdaExpression) expr; + buf.append("lambda"); + String sep = " "; + for (Parameter param : lambda.getParameters()) { + buf.append(sep); + sep = ", "; + printParameter(param); + } + buf.append(": "); + printExpr(lambda.getBody()); + break; + } + case LIST_EXPR: { ListExpression list = (ListExpression) expr; diff --git a/src/main/java/net/starlark/java/syntax/NodeVisitor.java b/src/main/java/net/starlark/java/syntax/NodeVisitor.java index f2f9036ebffd25..cf9ac6960ff2a4 100644 --- a/src/main/java/net/starlark/java/syntax/NodeVisitor.java +++ b/src/main/java/net/starlark/java/syntax/NodeVisitor.java @@ -170,6 +170,11 @@ public void visit(IndexExpression node) { visit(node.getKey()); } + public void visit(LambdaExpression node) { + visitAll(node.getParameters()); + visit(node.getBody()); + } + public void visit(SliceExpression node) { visit(node.getObject()); if (node.getStart() != null) { diff --git a/src/main/java/net/starlark/java/syntax/Parser.java b/src/main/java/net/starlark/java/syntax/Parser.java index d6c15a1688b7a3..3683c5f027a8f2 100644 --- a/src/main/java/net/starlark/java/syntax/Parser.java +++ b/src/main/java/net/starlark/java/syntax/Parser.java @@ -340,7 +340,6 @@ private int syncTo(EnumSet terminatingTokens) { TokenKind.GLOBAL, TokenKind.IMPORT, TokenKind.IS, - TokenKind.LAMBDA, TokenKind.NONLOCAL, TokenKind.RAISE, TokenKind.TRY, @@ -360,7 +359,6 @@ private void checkForbiddenKeywords() { break; case IMPORT: error = "'import' not supported, use 'load' instead"; break; case IS: error = "'is' not supported, use '==' instead"; break; - case LAMBDA: error = "'lambda' not supported, declare a function instead"; break; case RAISE: error = "'raise' not supported, use 'fail' instead"; break; case TRY: error = "'try' not supported, all exceptions are fatal"; break; case WHILE: error = "'while' not supported, use 'for' instead"; break; @@ -432,7 +430,7 @@ private Argument parseArgument() { // arg = IDENTIFIER '=' test // | IDENTIFIER - private Parameter parseFunctionParameter() { + private Parameter parseParameter() { // **kwargs if (token.kind == TokenKind.STAR_STAR) { int starStarOffset = nextToken(); @@ -752,7 +750,7 @@ private Expression parseComprehensionSuffix(int loffset, Node body, TokenKind cl int ifOffset = nextToken(); // [x for x in li if 1, 2] # parse error // [x for x in li if (1, 2)] # ok - Expression cond = parseTest(0); + Expression cond = parseTestNoCond(); clauses.add(new Comprehension.If(locs, ifOffset, cond)); } else if (token.kind == closingBracket) { break; @@ -928,6 +926,10 @@ private Expression optimizeBinOpExpression( // Parses a non-tuple expression ("test" in Python terminology). private Expression parseTest() { int start = token.start; + if (token.kind == TokenKind.LAMBDA) { + return parseLambda(/*allowCond=*/ true); + } + Expression expr = parseTest(0); if (token.kind == TokenKind.IF) { nextToken(); @@ -954,6 +956,25 @@ private Expression parseTest(int prec) { return parseBinOpExpression(prec); } + // parseLambda parses a lambda expression. + // The allowCond flag allows the body to be an 'a if b else c' conditional. + private LambdaExpression parseLambda(boolean allowCond) { + int lambdaOffset = expect(TokenKind.LAMBDA); + ImmutableList params = parseParameters(); + expect(TokenKind.COLON); + Expression body = allowCond ? parseTest() : parseTestNoCond(); + return new LambdaExpression(locs, lambdaOffset, params, body); + } + + // parseTestNoCond parses a a single-component expression without + // consuming a trailing 'if expr else expr'. + private Expression parseTestNoCond() { + if (token.kind == TokenKind.LAMBDA) { + return parseLambda(/*allowCond=*/ false); + } + return parseTest(0); + } + // not_expr = 'not' expr private Expression parseNotExpression(int prec) { int notOffset = expect(TokenKind.NOT); @@ -1184,7 +1205,9 @@ private ImmutableList parseParameters() { boolean hasParam = false; ImmutableList.Builder list = ImmutableList.builder(); - while (token.kind != TokenKind.RPAREN && token.kind != TokenKind.EOF) { + while (token.kind != TokenKind.RPAREN + && token.kind != TokenKind.COLON + && token.kind != TokenKind.EOF) { if (hasParam) { expect(TokenKind.COMMA); // The list may end with a comma. @@ -1192,7 +1215,7 @@ private ImmutableList parseParameters() { break; } } - Parameter param = parseFunctionParameter(); + Parameter param = parseParameter(); hasParam = true; list.add(param); } diff --git a/src/main/java/net/starlark/java/syntax/Resolver.java b/src/main/java/net/starlark/java/syntax/Resolver.java index 9369db664a0128..464545feebe642 100644 --- a/src/main/java/net/starlark/java/syntax/Resolver.java +++ b/src/main/java/net/starlark/java/syntax/Resolver.java @@ -326,7 +326,7 @@ public static Module moduleWithPredeclared(String... names) { private static class Block { @Nullable private final Block parent; // enclosing block, or null for tail of list - @Nullable Node syntax; // Comprehension, DefStatement, StarlarkFile, or null + @Nullable Node syntax; // Comprehension, DefStatement/LambdaExpression, StarlarkFile, or null private final ArrayList frame; // accumulated locals of enclosing function // Accumulated CELL/FREE bindings of the enclosing function that will provide // the values for the free variables of this function; see Function.getFreeVars. @@ -554,7 +554,8 @@ private static Binding lookupLexical(String name, Block b) { // This step may occur many times if the lookupLexical // recursion returns through many functions. // TODO(adonovan): make this 'DEF or LAMBDA' when we have lambda. - if (bind != null && b.syntax instanceof DefStatement) { + if (bind != null + && (b.syntax instanceof DefStatement || b.syntax instanceof LambdaExpression)) { Scope scope = bind.getScope(); if (scope == Scope.LOCAL || scope == Scope.FREE || scope == Scope.CELL) { if (scope == Scope.LOCAL) { @@ -717,8 +718,20 @@ public void visit(DefStatement node) { node.getBody())); } + @Override + public void visit(LambdaExpression expr) { + expr.setResolvedFunction( + resolveFunction( + expr, + "lambda", + expr.getStartLocation(), + expr.getParameters(), + ImmutableList.of(ReturnStatement.make(expr.getBody())))); + } + + // Common code for def, lambda. private Function resolveFunction( - DefStatement def, + Node syntax, // DefStatement or LambdaExpression String name, Location loc, ImmutableList parameters, @@ -734,7 +747,7 @@ private Function resolveFunction( // Enter function block. ArrayList frame = new ArrayList<>(); ArrayList freevars = new ArrayList<>(); - pushLocalBlock(def, frame, freevars); + pushLocalBlock(syntax, frame, freevars); // Check parameter order and convert to run-time order: // positionals, keyword-only, *args, **kwargs. diff --git a/src/test/java/com/google/devtools/build/lib/packages/PackageFactoryTest.java b/src/test/java/com/google/devtools/build/lib/packages/PackageFactoryTest.java index d6420863903ff1..7195f13a7aadc8 100644 --- a/src/test/java/com/google/devtools/build/lib/packages/PackageFactoryTest.java +++ b/src/test/java/com/google/devtools/build/lib/packages/PackageFactoryTest.java @@ -1154,7 +1154,14 @@ public void testGlobPatternExtractor() { public void testDefInBuild() throws Exception { checkBuildDialectError( "def func(): pass", // - "function definitions are not allowed in BUILD files"); + "functions may not be defined in BUILD files"); + } + + @Test + public void testLambdaInBuild() throws Exception { + checkBuildDialectError( + "lambda: None", // + "functions may not be defined in BUILD files"); } @Test diff --git a/src/test/java/net/starlark/java/eval/testdata/function.star b/src/test/java/net/starlark/java/eval/testdata/function.star index 99394092c60670..bc516d2d040455 100644 --- a/src/test/java/net/starlark/java/eval/testdata/function.star +++ b/src/test/java/net/starlark/java/eval/testdata/function.star @@ -81,10 +81,9 @@ x() ### 'string' object is not callable getattr(1, []) ### parameter 'name' got value of type 'list', want 'string' --- -# assert_fails. This will be more useful when we add lambda. -def divzero(): 1//0 -assert_fails(divzero, 'integer division by zero') - +# assert_fails evaluates an expression (passed unevaluated in the form of +# a lambda) and asserts that evaluation fails with the given error. +assert_fails(lambda: 1//0, 'integer division by zero') --- # Test of nested def statements. @@ -102,6 +101,14 @@ assert_eq(addlam("bda"), "lambda") assert_eq(addlam("bada"), "lambada") +# Same, with lambda +def adder2(x): + return lambda y: x+y + +assert_eq(adder2(3)(1), 4) +assert_eq(adder2("lam")("bda"), "lambda") + + # Test of stateful function values. def makerand(seed=0): "makerand returns a stateful generator of small pseudorandom numbers." @@ -130,9 +137,7 @@ assert_fails(rand3, "trying to mutate a frozen list value") def fib(x): return x if x < 2 else fib(x-1)+fib(x-2) -# TODO(adonovan): use lambda. -def fib10(): return fib(10) -assert_fails(fib10, "function 'fib' called recursively") +assert_fails(lambda: fib(10), "function 'fib' called recursively") --- # The recursion check breaks function encapsulation: @@ -141,36 +146,57 @@ assert_fails(fib10, "function 'fib' called recursively") # called from within an active call of that helper. def call(f): f() def g(): call(list) -# TODO(adonovan): use lambda. -def call_g(): call(g) -assert_fails(call_g, "function 'call' called recursively") +assert_fails(lambda: call(g), "function 'call' called recursively") --- # The recursion check is based on the syntactic equality # (same def statement), not function value equivalence. def eta(f): - # TODO(adonovan): use lambda - def call(): - f() - return call + return lambda: f() def nop(): pass -# fn1 and fn2 are both created by 'def call', +# fn1 and fn2 are both created by the same lambda, # but they are distinct and close over different values... fn1 = eta(nop) fn2 = eta(fn1) -assert_eq(str(fn1), '') -assert_eq(str(fn2), '') +assert_eq(str(fn1), '') +assert_eq(str(fn2), '') assert_(fn1 != fn2) # ...yet both cannot be called in the same thread: -assert_fails(fn2, "function 'call' called recursively") +assert_fails(fn2, "function 'lambda' called recursively") # This rule prevents users from writing the Y combinator, # which creates a new closure at each step of the recursion. -# TODO(adonovan): enable test when we have lambda. -# Y = lambda f: (lambda x: x(x))(lambda y: f(lambda *args: y(y)(*args))) -# fibgen = lambda fib: lambda x: (x if x<2 else fib(x-1)+fib(x-2)) -# fib2 = Y(fibgen) -# assert_fails(lambda: [fib2(x) for x in range(10)], "function lambda called recursively") +Y = lambda f: (lambda x: x(x))(lambda y: f(lambda *args: y(y)(*args))) +fibgen = lambda fib: lambda x: (x if x<2 else fib(x-1)+fib(x-2)) +fib2 = Y(fibgen) +assert_fails(lambda: [fib2(x) for x in range(10)], "function 'lambda' called recursively") + +--- +# Trivial test of lambda. +def map(f, list): + return [f(x) for x in list] + +assert_eq(map(lambda x: len(x), ["one", "two", "three"]), + [3, 3, 5]) + +assert_eq(type(lambda: 0), "function") +assert_eq(str(lambda: 0), "") + +--- +# builder returns a string builder: +# an opaque, stateful value with methods and open recursion. +def builder(): + chunks = [] + self = None + def append(x): + chunks.append("%s" % x) + return self + def build(): + return "".join(chunks) + self = struct(append = append, build = build) + return self + +assert_eq(builder().append(1).append(" + ").append(2).build(), "1 + 2") diff --git a/src/test/java/net/starlark/java/syntax/ParserTest.java b/src/test/java/net/starlark/java/syntax/ParserTest.java index 6bece4d94947d1..5c32809dabf4a4 100644 --- a/src/test/java/net/starlark/java/syntax/ParserTest.java +++ b/src/test/java/net/starlark/java/syntax/ParserTest.java @@ -1013,6 +1013,35 @@ public void testDefSingleLine() throws Exception { assertThat(stmt.getBody()).hasSize(2); } + @Test + public void testLambda() throws Exception { + parseExpression("lambda a, b=1, *args, **kwargs: a+b"); + parseExpression("lambda *, a, *b: 0"); + + // lambda has lower predecence than binary or. + assertThat(parseExpression("lambda: x or y").toString()).isEqualTo("lambda: (x or y)"); + + // This is a well known parsing ambiguity in Python. + // Python 2.7 accepts it but Python3 and Starlark reject it. + parseExpressionError("[x for x in lambda: True, lambda: False if x()]"); + + // ok in all dialects: + parseExpression("[x for x in (lambda: True, lambda: False) if x()]"); + + // An unparenthesized tuple is not allowed as the operand + // of an 'if' clause in a comprehension, but a lambda is ok. + assertThat(parseExpressionError("[a for b in c if 1, 2]")) + .contains("expected ']', 'for' or 'if'"); + parseExpression("[a for b in c if lambda: d]"); + // But the body of the unparenthesized lambda may not be a conditional: + parseExpression("[a for b in c if (lambda: d if e else f)]"); + assertThat(parseExpressionError("[a for b in c if lambda: d if e else f]")) + .contains("expected ']', 'for' or 'if'"); + + // A lambda is not allowed as the operand of a 'for' clause. + assertThat(parseExpressionError("[a for b in lambda: c]")).contains("syntax error at 'lambda'"); + } + @Test public void testForPass() throws Exception { List statements = parseStatements("def foo():", " pass\n"); @@ -1263,7 +1292,6 @@ public void testClassDefinitionInStarlark() throws Exception { assertContainsError("keyword 'class' not supported"); } - @Test public void testStringsAreDeduped() throws Exception { StarlarkFile file = parseFile("L1 = ['cat', 'dog', 'fish']", "L2 = ['dog', 'fish', 'cat']");