Skip to content

Commit

Permalink
starlark: allow lambda expressions
Browse files Browse the repository at this point in the history
This change introduces lambda expressions, following Python,
as a shorthand for declaring anonymous functions whose body is
a single expression.

RELNOTES: Starlark now supports lambda (anonymous function) expressions.
PiperOrigin-RevId: 346583352
  • Loading branch information
adonovan authored and copybara-github committed Dec 9, 2020
1 parent 337e717 commit 50ce3f9
Show file tree
Hide file tree
Showing 13 changed files with 250 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import net.starlark.java.syntax.Identifier;
import net.starlark.java.syntax.IfStatement;
import net.starlark.java.syntax.IntLiteral;
import net.starlark.java.syntax.LambdaExpression;
import net.starlark.java.syntax.ListExpression;
import net.starlark.java.syntax.Location;
import net.starlark.java.syntax.NodeVisitor;
Expand Down Expand Up @@ -1003,7 +1004,15 @@ void recordGeneratorName(CallExpression call) {
public void visit(DefStatement node) {
error(
node.getStartLocation(),
"function definitions are not allowed in BUILD files. You may move the function to "
"functions may not be defined in BUILD files. You may move the function to "
+ "a .bzl file and load it.");
}

@Override
public void visit(LambdaExpression node) {
error(
node.getStartLocation(),
"functions may not be defined in BUILD files. You may move the function to "
+ "a .bzl file and load it.");
}

Expand Down
18 changes: 9 additions & 9 deletions src/main/java/net/starlark/java/eval/Eval.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import net.starlark.java.syntax.IfStatement;
import net.starlark.java.syntax.IndexExpression;
import net.starlark.java.syntax.IntLiteral;
import net.starlark.java.syntax.LambdaExpression;
import net.starlark.java.syntax.ListExpression;
import net.starlark.java.syntax.LoadStatement;
import net.starlark.java.syntax.Location;
Expand Down Expand Up @@ -148,10 +149,8 @@ private static TokenKind execFor(StarlarkThread.Frame fr, ForStatement node)
return TokenKind.PASS;
}

private static void execDef(StarlarkThread.Frame fr, DefStatement node)
private static StarlarkFunction newFunction(StarlarkThread.Frame fr, Resolver.Function rfn)
throws EvalException, InterruptedException {
Resolver.Function rfn = node.getResolvedFunction();

// Evaluate default value expressions of optional parameters.
// We use MANDATORY to indicate a required parameter
// (not null, because defaults must be a legal tuple value, as
Expand Down Expand Up @@ -196,11 +195,8 @@ private static void execDef(StarlarkThread.Frame fr, DefStatement node)
// Nested functions use the same globalIndex as their enclosing function,
// since both were compiled from the same Program.
StarlarkFunction fn = fn(fr);
assignIdentifier(
fr,
node.getIdentifier(),
new StarlarkFunction(
rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars)));
return new StarlarkFunction(
rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars));
}

private static TokenKind execIf(StarlarkThread.Frame fr, IfStatement node)
Expand Down Expand Up @@ -289,7 +285,9 @@ private static TokenKind exec(StarlarkThread.Frame fr, Statement st)
case FOR:
return execFor(fr, (ForStatement) st);
case DEF:
execDef(fr, (DefStatement) st);
DefStatement def = (DefStatement) st;
StarlarkFunction fn = newFunction(fr, def.getResolvedFunction());
assignIdentifier(fr, def.getIdentifier(), fn);
return TokenKind.PASS;
case IF:
return execIf(fr, (IfStatement) st);
Expand Down Expand Up @@ -481,6 +479,8 @@ private static Object eval(StarlarkThread.Frame fr, Expression expr)
}
case FLOAT_LITERAL:
return StarlarkFloat.of(((FloatLiteral) expr).getValue());
case LAMBDA:
return newFunction(fr, ((LambdaExpression) expr).getResolvedFunction());
case LIST_EXPR:
return evalList(fr, (ListExpression) expr);
case SLICE:
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/net/starlark/java/eval/StarlarkFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ public Location getLocation() {
}

/**
* Returns the name of the function. Implicit functions (those not created by a def statement),
* may have names such as "<toplevel>" or "<expr>".
* Returns the name of the function, or "lambda" if anonymous. Implicit functions (those not
* created by a def statement), may have names such as "<toplevel>" or "<expr>".
*/
@Override
public String getName() {
Expand Down
1 change: 1 addition & 0 deletions src/main/java/net/starlark/java/syntax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ java_library(
"IfStatement.java",
"IndexExpression.java",
"IntLiteral.java",
"LambdaExpression.java",
"Lexer.java",
"ListExpression.java",
"LoadStatement.java",
Expand Down
1 change: 1 addition & 0 deletions src/main/java/net/starlark/java/syntax/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public enum Kind {
IDENTIFIER,
INDEX,
INT_LITERAL,
LAMBDA,
LIST_EXPR,
SLICE,
STRING_LITERAL,
Expand Down
75 changes: 75 additions & 0 deletions src/main/java/net/starlark/java/syntax/LambdaExpression.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net.starlark.java.syntax;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import javax.annotation.Nullable;

/** A LambdaExpression ({@code lambda params: body}) denotes an anonymous function. */
public final class LambdaExpression extends Expression {

private final int lambdaOffset; // offset of 'lambda' token
private final ImmutableList<Parameter> parameters;
private final Expression body;

// set by resolver
@Nullable private Resolver.Function resolved;

LambdaExpression(
FileLocations locs, int lambdaOffset, ImmutableList<Parameter> parameters, Expression body) {
super(locs);
this.lambdaOffset = lambdaOffset;
this.parameters = Preconditions.checkNotNull(parameters);
this.body = Preconditions.checkNotNull(body);
}

public ImmutableList<Parameter> getParameters() {
return parameters;
}

public Expression getBody() {
return body;
}

/** Returns information about the resolved function. Set by the resolver. */
@Nullable
public Resolver.Function getResolvedFunction() {
return resolved;
}

void setResolvedFunction(Resolver.Function resolved) {
this.resolved = resolved;
}

@Override
public int getStartOffset() {
return lambdaOffset;
}

@Override
public int getEndOffset() {
return body.getEndOffset();
}

@Override
public void accept(NodeVisitor visitor) {
visitor.visit(this);
}

@Override
public Kind kind() {
return Kind.LAMBDA;
}
}
15 changes: 15 additions & 0 deletions src/main/java/net/starlark/java/syntax/NodePrinter.java
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,21 @@ private void printExpr(Expression expr) {
break;
}

case LAMBDA:
{
LambdaExpression lambda = (LambdaExpression) expr;
buf.append("lambda");
String sep = " ";
for (Parameter param : lambda.getParameters()) {
buf.append(sep);
sep = ", ";
printParameter(param);
}
buf.append(": ");
printExpr(lambda.getBody());
break;
}

case LIST_EXPR:
{
ListExpression list = (ListExpression) expr;
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/net/starlark/java/syntax/NodeVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ public void visit(IndexExpression node) {
visit(node.getKey());
}

public void visit(LambdaExpression node) {
visitAll(node.getParameters());
visit(node.getBody());
}

public void visit(SliceExpression node) {
visit(node.getObject());
if (node.getStart() != null) {
Expand Down
35 changes: 29 additions & 6 deletions src/main/java/net/starlark/java/syntax/Parser.java
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ private int syncTo(EnumSet<TokenKind> terminatingTokens) {
TokenKind.GLOBAL,
TokenKind.IMPORT,
TokenKind.IS,
TokenKind.LAMBDA,
TokenKind.NONLOCAL,
TokenKind.RAISE,
TokenKind.TRY,
Expand All @@ -360,7 +359,6 @@ private void checkForbiddenKeywords() {
break;
case IMPORT: error = "'import' not supported, use 'load' instead"; break;
case IS: error = "'is' not supported, use '==' instead"; break;
case LAMBDA: error = "'lambda' not supported, declare a function instead"; break;
case RAISE: error = "'raise' not supported, use 'fail' instead"; break;
case TRY: error = "'try' not supported, all exceptions are fatal"; break;
case WHILE: error = "'while' not supported, use 'for' instead"; break;
Expand Down Expand Up @@ -432,7 +430,7 @@ private Argument parseArgument() {

// arg = IDENTIFIER '=' test
// | IDENTIFIER
private Parameter parseFunctionParameter() {
private Parameter parseParameter() {
// **kwargs
if (token.kind == TokenKind.STAR_STAR) {
int starStarOffset = nextToken();
Expand Down Expand Up @@ -752,7 +750,7 @@ private Expression parseComprehensionSuffix(int loffset, Node body, TokenKind cl
int ifOffset = nextToken();
// [x for x in li if 1, 2] # parse error
// [x for x in li if (1, 2)] # ok
Expression cond = parseTest(0);
Expression cond = parseTestNoCond();
clauses.add(new Comprehension.If(locs, ifOffset, cond));
} else if (token.kind == closingBracket) {
break;
Expand Down Expand Up @@ -928,6 +926,10 @@ private Expression optimizeBinOpExpression(
// Parses a non-tuple expression ("test" in Python terminology).
private Expression parseTest() {
int start = token.start;
if (token.kind == TokenKind.LAMBDA) {
return parseLambda(/*allowCond=*/ true);
}

Expression expr = parseTest(0);
if (token.kind == TokenKind.IF) {
nextToken();
Expand All @@ -954,6 +956,25 @@ private Expression parseTest(int prec) {
return parseBinOpExpression(prec);
}

// parseLambda parses a lambda expression.
// The allowCond flag allows the body to be an 'a if b else c' conditional.
private LambdaExpression parseLambda(boolean allowCond) {
int lambdaOffset = expect(TokenKind.LAMBDA);
ImmutableList<Parameter> params = parseParameters();
expect(TokenKind.COLON);
Expression body = allowCond ? parseTest() : parseTestNoCond();
return new LambdaExpression(locs, lambdaOffset, params, body);
}

// parseTestNoCond parses a a single-component expression without
// consuming a trailing 'if expr else expr'.
private Expression parseTestNoCond() {
if (token.kind == TokenKind.LAMBDA) {
return parseLambda(/*allowCond=*/ false);
}
return parseTest(0);
}

// not_expr = 'not' expr
private Expression parseNotExpression(int prec) {
int notOffset = expect(TokenKind.NOT);
Expand Down Expand Up @@ -1184,15 +1205,17 @@ private ImmutableList<Parameter> parseParameters() {
boolean hasParam = false;
ImmutableList.Builder<Parameter> list = ImmutableList.builder();

while (token.kind != TokenKind.RPAREN && token.kind != TokenKind.EOF) {
while (token.kind != TokenKind.RPAREN
&& token.kind != TokenKind.COLON
&& token.kind != TokenKind.EOF) {
if (hasParam) {
expect(TokenKind.COMMA);
// The list may end with a comma.
if (token.kind == TokenKind.RPAREN) {
break;
}
}
Parameter param = parseFunctionParameter();
Parameter param = parseParameter();
hasParam = true;
list.add(param);
}
Expand Down
21 changes: 17 additions & 4 deletions src/main/java/net/starlark/java/syntax/Resolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ public static Module moduleWithPredeclared(String... names) {

private static class Block {
@Nullable private final Block parent; // enclosing block, or null for tail of list
@Nullable Node syntax; // Comprehension, DefStatement, StarlarkFile, or null
@Nullable Node syntax; // Comprehension, DefStatement/LambdaExpression, StarlarkFile, or null
private final ArrayList<Binding> frame; // accumulated locals of enclosing function
// Accumulated CELL/FREE bindings of the enclosing function that will provide
// the values for the free variables of this function; see Function.getFreeVars.
Expand Down Expand Up @@ -554,7 +554,8 @@ private static Binding lookupLexical(String name, Block b) {
// This step may occur many times if the lookupLexical
// recursion returns through many functions.
// TODO(adonovan): make this 'DEF or LAMBDA' when we have lambda.
if (bind != null && b.syntax instanceof DefStatement) {
if (bind != null
&& (b.syntax instanceof DefStatement || b.syntax instanceof LambdaExpression)) {
Scope scope = bind.getScope();
if (scope == Scope.LOCAL || scope == Scope.FREE || scope == Scope.CELL) {
if (scope == Scope.LOCAL) {
Expand Down Expand Up @@ -717,8 +718,20 @@ public void visit(DefStatement node) {
node.getBody()));
}

@Override
public void visit(LambdaExpression expr) {
expr.setResolvedFunction(
resolveFunction(
expr,
"lambda",
expr.getStartLocation(),
expr.getParameters(),
ImmutableList.of(ReturnStatement.make(expr.getBody()))));
}

// Common code for def, lambda.
private Function resolveFunction(
DefStatement def,
Node syntax, // DefStatement or LambdaExpression
String name,
Location loc,
ImmutableList<Parameter> parameters,
Expand All @@ -734,7 +747,7 @@ private Function resolveFunction(
// Enter function block.
ArrayList<Binding> frame = new ArrayList<>();
ArrayList<Binding> freevars = new ArrayList<>();
pushLocalBlock(def, frame, freevars);
pushLocalBlock(syntax, frame, freevars);

// Check parameter order and convert to run-time order:
// positionals, keyword-only, *args, **kwargs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,14 @@ public void testGlobPatternExtractor() {
public void testDefInBuild() throws Exception {
checkBuildDialectError(
"def func(): pass", //
"function definitions are not allowed in BUILD files");
"functions may not be defined in BUILD files");
}

@Test
public void testLambdaInBuild() throws Exception {
checkBuildDialectError(
"lambda: None", //
"functions may not be defined in BUILD files");
}

@Test
Expand Down
Loading

0 comments on commit 50ce3f9

Please sign in to comment.