From 5ca20643e54e5cff1eb2939d044f86f96861176a Mon Sep 17 00:00:00 2001 From: adonovan Date: Tue, 8 Dec 2020 11:04:36 -0800 Subject: [PATCH] starlark: allow nested def statements This change removes the restriction that def statements may not be nested. As in Python, Starlark's nested def statements are lexically scoped, and are implemented using closures, a technique first described in 1964 (for Landin's SECD machine) and employed in essentially every language since Scheme in the 1970s. The resolver computes the free variables of each function, which are in effect treated as hidden parameters implicitly supplied from the environment in which the def statement is executed. Local variables shared between outer and inner functions are indirect and called Cells. The tests now use the assert_fails(lambda: expr, "expected error") construct so that they can test failures without aborting the test chunk. (There is no lambda syntax yet, but its addition is trivial and will be done in a follow-up; see CL 345746527.) RELNOTES: Starlark now permits def statements to be nested (closures). 
PiperOrigin-RevId: 346365019 --- .../java/net/starlark/java/eval/Eval.java | 34 +++- .../java/net/starlark/java/eval/Starlark.java | 14 +- .../starlark/java/eval/StarlarkFunction.java | 52 +++++- .../starlark/java/eval/StarlarkThread.java | 37 ++-- .../net/starlark/java/syntax/Resolver.java | 160 +++++++++++++----- src/test/java/net/starlark/java/eval/BUILD | 1 + .../net/starlark/java/eval/ScriptTest.java | 55 +++++- .../starlark/java/eval/testdata/function.star | 95 +++++++++++ .../starlark/java/syntax/ResolverTest.java | 29 ++-- 9 files changed, 391 insertions(+), 86 deletions(-) diff --git a/src/main/java/net/starlark/java/eval/Eval.java b/src/main/java/net/starlark/java/eval/Eval.java index 5cfa7ed4033ca7..fd8a9a8c0e1cb3 100644 --- a/src/main/java/net/starlark/java/eval/Eval.java +++ b/src/main/java/net/starlark/java/eval/Eval.java @@ -175,13 +175,32 @@ private static void execDef(StarlarkThread.Frame fr, DefStatement node) defaults = EMPTY; } + // Capture the cells of the function's + // free variables from the lexical environment. + Object[] freevars = new Object[rfn.getFreeVars().size()]; + int i = 0; + for (Resolver.Binding bind : rfn.getFreeVars()) { + // Unlike expr(Identifier), we want the cell itself, not its content. + switch (bind.getScope()) { + case FREE: + freevars[i++] = fn(fr).getFreeVar(bind.getIndex()); + break; + case CELL: + freevars[i++] = fr.locals[bind.getIndex()]; + break; + default: + throw new IllegalStateException("unexpected: " + bind); + } + } + // Nested functions use the same globalIndex as their enclosing function, // since both were compiled from the same Program. 
StarlarkFunction fn = fn(fr); assignIdentifier( fr, node.getIdentifier(), - new StarlarkFunction(rfn, Tuple.wrap(defaults), fn.getModule(), fn.globalIndex)); + new StarlarkFunction( + rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars))); } private static TokenKind execIf(StarlarkThread.Frame fr, IfStatement node) @@ -231,8 +250,8 @@ private static void execLoad(StarlarkThread.Frame fr, LoadStatement node) throws // loads bind file-locally. Either way, the resolver should designate // the proper scope of binding.getLocalName() and this should become // simply assign(binding.getLocalName(), value). - // Currently, we update the module but not module.exportedGlobals; - // changing it to fr.locals.put breaks a test. TODO(adonovan): find out why. + // Currently, we update the module but not module.exportedGlobals. + // Change it to a local binding now that closures are supported. fn(fr).setGlobal(binding.getLocalName().getBinding().getIndex(), value); } } @@ -328,6 +347,9 @@ private static void assignIdentifier(StarlarkThread.Frame fr, Identifier id, Obj case LOCAL: fr.locals[bind.getIndex()] = value; break; + case CELL: + ((StarlarkFunction.Cell) fr.locals[bind.getIndex()]).x = value; + break; case GLOBAL: // Updates a module binding and sets its 'exported' flag. // (Only load bindings are not exported. 
@@ -637,6 +659,12 @@ private static Object evalIdentifier(StarlarkThread.Frame fr, Identifier id) case LOCAL: result = fr.locals[bind.getIndex()]; break; + case CELL: + result = ((StarlarkFunction.Cell) fr.locals[bind.getIndex()]).x; + break; + case FREE: + result = fn(fr).getFreeVar(bind.getIndex()).x; + break; case GLOBAL: result = fn(fr).getGlobal(bind.getIndex()); break; diff --git a/src/main/java/net/starlark/java/eval/Starlark.java b/src/main/java/net/starlark/java/eval/Starlark.java index 027f9c6d9f3cc1..e4e69081b91d89 100644 --- a/src/main/java/net/starlark/java/eval/Starlark.java +++ b/src/main/java/net/starlark/java/eval/Starlark.java @@ -868,8 +868,6 @@ public static Object execFile( */ public static Object execFileProgram(Program prog, Module module, StarlarkThread thread) throws EvalException, InterruptedException { - Tuple defaultValues = Tuple.empty(); - Resolver.Function rfn = prog.getResolvedFunction(); // A given Module may be passed to execFileProgram multiple times in sequence, @@ -884,7 +882,13 @@ public static Object execFileProgram(Program prog, Module module, StarlarkThread // two array lookups. 
int[] globalIndex = module.getIndicesOfGlobals(rfn.getGlobals()); - StarlarkFunction toplevel = new StarlarkFunction(rfn, defaultValues, module, globalIndex); + StarlarkFunction toplevel = + new StarlarkFunction( + rfn, + module, + globalIndex, + /*defaultValues=*/ Tuple.empty(), + /*freevars=*/ Tuple.empty()); return Starlark.fastcall(thread, toplevel, EMPTY, EMPTY); } @@ -928,10 +932,10 @@ public static StarlarkFunction newExprFunction( ParserInput input, FileOptions options, Module module) throws SyntaxError.Exception { Expression expr = Expression.parse(input, options); Program prog = Program.compileExpr(expr, module, options); - Tuple defaultValues = Tuple.empty(); Resolver.Function rfn = prog.getResolvedFunction(); int[] globalIndex = module.getIndicesOfGlobals(rfn.getGlobals()); // see execFileProgram - return new StarlarkFunction(rfn, defaultValues, module, globalIndex); + return new StarlarkFunction( + rfn, module, globalIndex, /*defaultValues=*/ Tuple.empty(), /*freevars=*/ Tuple.empty()); } /** diff --git a/src/main/java/net/starlark/java/eval/StarlarkFunction.java b/src/main/java/net/starlark/java/eval/StarlarkFunction.java index 0889d8c261abe6..4277829efe9f3b 100644 --- a/src/main/java/net/starlark/java/eval/StarlarkFunction.java +++ b/src/main/java/net/starlark/java/eval/StarlarkFunction.java @@ -36,15 +36,33 @@ public final class StarlarkFunction implements StarlarkCallable { final Resolver.Function rfn; - final int[] globalIndex; // index in Module.globals of ith Program global (binding index) private final Module module; // a function closes over its defining module + + // Index in Module.globals of ith Program global (Resolver.Binding(GLOBAL).index). + // See explanation at Starlark.execFileProgram. + final int[] globalIndex; + + // Default values of optional parameters. + // Indices correspond to the subsequence of parameters after the initial + // required parameters and before *args/**kwargs. 
+ // Contain MANDATORY for the required keyword-only parameters. private final Tuple defaultValues; - StarlarkFunction(Resolver.Function rfn, Tuple defaultValues, Module module, int[] globalIndex) { + // Cells (shared locals) of enclosing functions. + // Indexed by Resolver.Binding(FREE).index values. + private final Tuple freevars; + + StarlarkFunction( + Resolver.Function rfn, + Module module, + int[] globalIndex, + Tuple defaultValues, + Tuple freevars) { this.rfn = rfn; - this.globalIndex = globalIndex; this.module = module; + this.globalIndex = globalIndex; this.defaultValues = defaultValues; + this.freevars = freevars; } // Sets a global variable, given its index in this function's compiled Program. @@ -153,9 +171,6 @@ public Module getModule() { @Override public Object fastcall(StarlarkThread thread, Object[] positional, Object[] named) throws EvalException, InterruptedException { - if (thread.mutability().isFrozen()) { - throw Starlark.errorf("Trying to call in frozen environment"); - } if (!thread.isRecursionAllowed() && thread.isRecursiveCall(this)) { throw Starlark.errorf("function '%s' called recursively", getName()); } @@ -167,9 +182,19 @@ public Object fastcall(StarlarkThread thread, Object[] positional, Object[] name StarlarkThread.Frame fr = thread.frame(0); fr.locals = new Object[rfn.getLocals().size()]; System.arraycopy(arguments, 0, fr.locals, 0, rfn.getParameterNames().size()); + + // Spill indicated locals to cells. + for (int index : rfn.getCellIndices()) { + fr.locals[index] = new Cell(fr.locals[index]); + } + return Eval.execFunctionBody(fr, rfn.getBody()); } + Cell getFreeVar(int index) { + return (Cell) freevars.get(index); + } + @Override public void repr(Printer printer) { // TODO(adonovan): use the file name instead. But that's a breaking Bazel change. @@ -376,9 +401,20 @@ public boolean isImmutable() { } // The MANDATORY sentinel indicates a slot in the defaultValues - // tuple corresponding to a required parameter. 
It is not visible - // to Java or Starlark code. + // tuple corresponding to a required parameter. + // It is not visible to Java or Starlark code. static final Object MANDATORY = new Mandatory(); private static class Mandatory implements StarlarkValue {} + + // A Cell is a local variable shared between an inner and an outer function. + // It is a StarlarkValue because it is a stack operand and a Tuple element, + // but it is not visible to Java or Starlark code. + static final class Cell implements StarlarkValue { + Object x; + + Cell(Object x) { + this.x = x; + } + } } diff --git a/src/main/java/net/starlark/java/eval/StarlarkThread.java b/src/main/java/net/starlark/java/eval/StarlarkThread.java index 5019d003a8c7bc..635dc5c1505653 100644 --- a/src/main/java/net/starlark/java/eval/StarlarkThread.java +++ b/src/main/java/net/starlark/java/eval/StarlarkThread.java @@ -32,15 +32,17 @@ * per-thread application state (see {@link #setThreadLocal}) that passes through Starlark functions * but does not directly affect them, such as information about the BUILD file being loaded. * - *

Every {@code StarlarkThread} has a {@link Mutability} field, and must be used within a - * function that creates and closes this {@link Mutability} with the try-with-resource pattern. This - * {@link Mutability} is also used when initializing mutable objects within that {@code - * StarlarkThread}. When the {@code Mutability} is closed at the end of the computation, it freezes - * the {@code StarlarkThread} along with all of those objects. This pattern enforces the discipline - * that there should be no dangling mutable {@code StarlarkThread}, or concurrency between - * interacting {@code StarlarkThread}s. It is a Starlark-level error to attempt to mutate a frozen - * {@code StarlarkThread} or its objects, but it is a Java-level error to attempt to mutate an - * unfrozen {@code StarlarkThread} or its objects from within a different {@code StarlarkThread}. + *

StarlarkThreads are not thread-safe: they should be confined to a single Java thread. + * + *

Every StarlarkThread has an associated {@link Mutability}, which should be created for that + * thread, and closed once the thread's work is done. (A try-with-resources statement is handy for + * this purpose.) Starlark values created by the thread are associated with the thread's Mutability, + * so that when the Mutability is closed at the end of the computation, all the values created by + * the thread become frozen. This pattern ensures that all Starlark values are frozen before they + * are published to another thread, and thus that concurrently executing Starlark threads are free + * from data races. Once a thread's mutability is frozen, the thread is unlikely to be useful for + * further computation because it can no longer create mutable values. (This is occasionally + * valuable in tests.) */ public final class StarlarkThread { @@ -136,7 +138,8 @@ static final class Frame implements Debug.Frame { private boolean errorLocationSet; // The locals of this frame, if fn is a StarlarkFunction, otherwise null. - // Set by StarlarkFunction.fastcall. + // Set by StarlarkFunction.fastcall. Elements may be regular Starlark + // values, or wrapped in StarlarkFunction.Cells if shared with a nested function. @Nullable Object[] locals; @Nullable private Object profileSpan; // current span of walltime call profiler @@ -181,8 +184,12 @@ public ImmutableMap getLocals() { ImmutableMap.Builder env = ImmutableMap.builder(); if (fn instanceof StarlarkFunction) { for (int i = 0; i < locals.length; i++) { - if (locals[i] != null) { - env.put(((StarlarkFunction) fn).rfn.getLocals().get(i).getName(), locals[i]); + Object local = locals[i]; + if (local != null) { + if (local instanceof StarlarkFunction.Cell) { + local = ((StarlarkFunction.Cell) local).x; + } + env.put(((StarlarkFunction) fn).rfn.getLocals().get(i).getName(), local); } } } @@ -332,9 +339,9 @@ boolean isRecursiveCall(StarlarkFunction fn) { // Find fn buried within stack. (The top of the stack is assumed to be fn.) 
for (int i = callstack.size() - 2; i >= 0; --i) { Frame fr = callstack.get(i); - // TODO(adonovan): compare code, not closure values, otherwise - // one can defeat this check by writing the Y combinator. - if (fr.fn.equals(fn)) { + // We compare code, not closure values, otherwise one can defeat the + // check by writing the Y combinator. + if (fr.fn instanceof StarlarkFunction && ((StarlarkFunction) fr.fn).rfn.equals(fn.rfn)) { return true; } } diff --git a/src/main/java/net/starlark/java/syntax/Resolver.java b/src/main/java/net/starlark/java/syntax/Resolver.java index e3b0a62aa8608b..9369db664a0128 100644 --- a/src/main/java/net/starlark/java/syntax/Resolver.java +++ b/src/main/java/net/starlark/java/syntax/Resolver.java @@ -50,21 +50,18 @@ public final class Resolver extends NodeVisitor { // including the spec. // - move the "no if statements at top level" check to bazel's check{Build,*}Syntax // (that's a spec change), or put it behind a FileOptions flag (no spec change). - // - remove restriction on nested def: - // 1. use FREE for scope of references to outer LOCALs, which become CELLs. - // 2. implement closures in eval/. - // - make loads bind locals by default. + // - make loads bind locals by default (depends on closures). /** Scope discriminates the scope of a binding: global, local, etc. */ public enum Scope { /** Binding is local to a function, comprehension, or file (e.g. load). */ LOCAL, - /** Binding occurs outside any function or comprehension. */ + /** Binding is non-local and occurs outside any function or comprehension. */ GLOBAL, /** Binding is local to a function, comprehension, or file, but shared with nested functions. */ - CELL, // TODO(adonovan): implement nested def + CELL, /** Binding is an implicit parameter whose value is the CELL of some enclosing function. */ - FREE, // TODO(adonovan): implement nested def + FREE, /** Binding is predeclared by the application (e.g. glob in Bazel). 
*/ PREDECLARED, /** Binding is predeclared by the core (e.g. None). */ @@ -81,8 +78,8 @@ public String toString() { * Binding. */ public static final class Binding { - private final Scope scope; - private final int index; // index within function (LOCAL) or module (GLOBAL) + private Scope scope; + private final int index; // index within frame (LOCAL/CELL), freevars (FREE), or module (GLOBAL) @Nullable private final Identifier first; // first binding use, if syntactic private Binding(Scope scope, int index, @Nullable Identifier first) { @@ -102,7 +99,10 @@ public Scope getScope() { return scope; } - /** Returns the index of a binding within its function (LOCAL) or module (GLOBAL). */ + /** + * Returns the index of a binding within its function's frame (LOCAL/CELL), freevars (FREE), or + * module (GLOBAL). + */ public int getIndex() { return index; } @@ -129,10 +129,9 @@ public static final class Function { private final ImmutableList parameterNames; private final boolean isToplevel; private final ImmutableList locals; - // TODO(adonovan): move this to Program, but that requires communication - // between resolveFile and compileFile, which depends on use doing the TODO - // described at Program.compileResolvedFile and eliminating that function. - private final ImmutableList globals; + private final int[] cellIndices; + private final ImmutableList freevars; + private final ImmutableList globals; // TODO(adonovan): move to Program. private Function( String name, @@ -143,6 +142,7 @@ private Function( boolean hasKwargs, int numKeywordOnlyParams, List locals, + List freevars, List globals) { this.name = name; this.location = loc; @@ -160,7 +160,23 @@ private Function( this.isToplevel = name.equals(""); this.locals = ImmutableList.copyOf(locals); + this.freevars = ImmutableList.copyOf(freevars); this.globals = ImmutableList.copyOf(globals); + + // Create an index of the locals that are cells. 
+ int ncells = 0; + int nlocals = locals.size(); + for (int i = 0; i < nlocals; i++) { + if (locals.get(i).scope == Scope.CELL) { + ncells++; + } + } + this.cellIndices = new int[ncells]; + for (int i = 0, j = 0; i < nlocals; i++) { + if (locals.get(i).scope == Scope.CELL) { + cellIndices[j++] = i; + } + } } /** @@ -177,6 +193,14 @@ public ImmutableList getLocals() { return locals; } + /** + * Returns the indices within {@code getLocals()} of the "cells", that is, local variables of + * this function that are shared with nested functions. The caller must not modify the result. + */ + public int[] getCellIndices() { + return cellIndices; + } + /** * Returns the list of names of globals referenced by this function. The order matches the * indices used in compiled code. @@ -185,6 +209,17 @@ public ImmutableList getGlobals() { return globals; } + /** + * Returns the list of enclosing CELL or FREE bindings referenced by this function. At run time, + * these values, all of which are cells containing variables local to some enclosing function, + * will be stored in the closure. (CELL bindings in this list are local to the immediately + * enclosing function, while FREE bindings pass through one or more intermediate enclosing + * functions.) + */ + public ImmutableList getFreeVars() { + return freevars; + } + /** Returns the location of the function's identifier. */ public Location getLocation() { return location; @@ -293,15 +328,24 @@ private static class Block { @Nullable private final Block parent; // enclosing block, or null for tail of list @Nullable Node syntax; // Comprehension, DefStatement, StarlarkFile, or null private final ArrayList frame; // accumulated locals of enclosing function + // Accumulated CELL/FREE bindings of the enclosing function that will provide + // the values for the free variables of this function; see Function.getFreeVars. + // Null for toplevel functions and expressions, which have no free variables. 
+ @Nullable private final ArrayList freevars; // Bindings for names defined in this block. // Also, as an optimization, memoized lookups of enclosing bindings. private final Map bindings = new HashMap<>(); - Block(@Nullable Block parent, @Nullable Node syntax, ArrayList frame) { + Block( + @Nullable Block parent, + @Nullable Node syntax, + ArrayList frame, + @Nullable ArrayList freevars) { this.parent = parent; this.syntax = syntax; this.frame = frame; + this.freevars = freevars; } } @@ -309,7 +353,7 @@ private static class Block { private final FileOptions options; private final Module module; // List whose order defines the numbering of global variables in this program. - private final ArrayList globals = new ArrayList<>(); + private final List globals = new ArrayList<>(); // A cache of PREDECLARED, UNIVERSAL, and GLOBAL bindings queried from the module. private final Map toplevel = new HashMap<>(); // Linked list of blocks, innermost first, for functions and comprehensions and (finally) file. @@ -339,7 +383,6 @@ private void errorf(Location loc, String format, Object... args) { * are sometimes used before their definition point (e.g. functions are not necessarily declared * in order). */ - // TODO(adonovan): eliminate this first pass by using go.starlark.net one-pass approach. private void createBindingsForBlock(Iterable stmts) { for (Statement stmt : stmts) { createBindings(stmt); @@ -440,21 +483,15 @@ public void visit(Identifier id) { private Binding use(Identifier id) { String name = id.getName(); - // local (to function, comprehension, or file)? - for (Block b = locals; b != null; b = b.parent) { - Binding bind = b.bindings.get(name); - if (bind != null) { - // Optimization: memoize lookup of an outer local - // in an inner block, to avoid repeated walks. - if (b != locals) { - locals.bindings.put(name, bind); - } - return bind; - } + // Locally defined in this function, comprehension, + // or file block, or an enclosing one? 
+ Binding bind = lookupLexical(name, locals); + if (bind != null) { + return bind; } - // toplevel (global, predeclared, universal)? - Binding bind = toplevel.get(name); + // Defined at toplevel (global, predeclared, universal)? + bind = toplevel.get(name); if (bind != null) { return bind; } @@ -493,6 +530,49 @@ private Binding use(Identifier id) { return bind; } + // lookupLexical finds a lexically enclosing local binding of the name, + // plumbing it through enclosing functions as needed. + private static Binding lookupLexical(String name, Block b) { + Binding bind = b.bindings.get(name); + if (bind != null) { + return bind; + } + + if (b.parent != null) { + bind = lookupLexical(name, b.parent); + + // If a local binding was found in a parent block, + // and this block is a function, then it is a free variable + // of this function and must be plumbed through. + // Add an implicit FREE binding (a hidden parameter) to this function, + // and record the outer binding that will supply its value when + // we construct the closure. + // Also, mark the outer LOCAL as a CELL: a shared, indirect local. + // (For a comprehension block there's nothing to do, + // because it's part of the same frame as the enclosing block.) + // + // This step may occur many times if the lookupLexical + // recursion returns through many functions. + // TODO(adonovan): make this 'DEF or LAMBDA' when we have lambda. + if (bind != null && b.syntax instanceof DefStatement) { + Scope scope = bind.getScope(); + if (scope == Scope.LOCAL || scope == Scope.FREE || scope == Scope.CELL) { + if (scope == Scope.LOCAL) { + bind.scope = Scope.CELL; + } + int index = b.freevars.size(); + b.freevars.add(bind); + bind = new Binding(Scope.FREE, index, bind.first); + } + } + + // Memoize, to avoid duplicate free vars and repeated walks. 
+ b.bindings.put(name, bind); + } + + return bind; + } + @Override public void visit(ReturnStatement node) { if (locals.syntax instanceof StarlarkFile) { @@ -601,7 +681,7 @@ public void visit(Comprehension node) { // A comprehension defines a distinct lexical block // in the same function's frame. - pushLocalBlock(node, this.locals.frame); + pushLocalBlock(node, this.locals.frame, this.locals.freevars); for (Comprehension.Clause clause : clauses) { if (clause instanceof Comprehension.For) { @@ -628,9 +708,6 @@ public void visit(Comprehension node) { @Override public void visit(DefStatement node) { - if (!(locals.syntax instanceof StarlarkFile)) { - errorf(node, "nested functions are not allowed. Move the function to the top level."); - } node.setResolvedFunction( resolveFunction( node, @@ -656,7 +733,8 @@ private Function resolveFunction( // Enter function block. ArrayList frame = new ArrayList<>(); - pushLocalBlock(def, frame); + ArrayList freevars = new ArrayList<>(); + pushLocalBlock(def, frame, freevars); // Check parameter order and convert to run-time order: // positionals, keyword-only, *args, **kwargs. @@ -743,6 +821,7 @@ private Function resolveFunction( starStar != null, numKeywordOnlyParams, frame, + freevars, globals); } @@ -879,7 +958,7 @@ public static void resolveFile(StarlarkFile file, Module module) { } ArrayList frame = new ArrayList<>(); - r.pushLocalBlock(file, frame); + r.pushLocalBlock(file, frame, /*freevars=*/ null); // First pass: creating bindings for statements in this block. 
r.createBindingsForBlock(stmts); @@ -911,6 +990,7 @@ public static void resolveFile(StarlarkFile file, Module module) { /*hasKwargs=*/ false, /*numKeywordOnlyParams=*/ 0, frame, + /*freevars=*/ ImmutableList.of(), r.globals)); } @@ -925,7 +1005,7 @@ public static Function resolveExpr(Expression expr, Module module, FileOptions o Resolver r = new Resolver(errors, module, options); ArrayList frame = new ArrayList<>(); - r.pushLocalBlock(null, frame); // for bindings in list comprehensions + r.pushLocalBlock(null, frame, /*freevars=*/ null); // for bindings in list comprehensions r.visit(expr); r.popLocalBlock(); @@ -943,11 +1023,13 @@ public static Function resolveExpr(Expression expr, Module module, FileOptions o /*hasKwargs=*/ false, /*numKeywordOnlyParams=*/ 0, frame, + /*freevars=*/ ImmutableList.of(), r.globals); } - private void pushLocalBlock(Node syntax, ArrayList frame) { - locals = new Block(locals, syntax, frame); + private void pushLocalBlock( + Node syntax, ArrayList frame, @Nullable ArrayList freevars) { + locals = new Block(locals, syntax, frame, freevars); } private void popLocalBlock() { diff --git a/src/test/java/net/starlark/java/eval/BUILD b/src/test/java/net/starlark/java/eval/BUILD index 4d0972d22c223f..6091dc5983b208 100644 --- a/src/test/java/net/starlark/java/eval/BUILD +++ b/src/test/java/net/starlark/java/eval/BUILD @@ -54,6 +54,7 @@ java_test( "//src/main/java/net/starlark/java/eval", "//src/main/java/net/starlark/java/lib/json", "//src/main/java/net/starlark/java/syntax", + "//third_party:error_prone_annotations", "//third_party:guava", ], ) diff --git a/src/test/java/net/starlark/java/eval/ScriptTest.java b/src/test/java/net/starlark/java/eval/ScriptTest.java index ccf89f5e556d66..2e5e499ad6d559 100644 --- a/src/test/java/net/starlark/java/eval/ScriptTest.java +++ b/src/test/java/net/starlark/java/eval/ScriptTest.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import 
com.google.common.io.Files; +import com.google.errorprone.annotations.FormatMethod; import java.io.File; import java.util.HashMap; import java.util.List; @@ -73,7 +74,7 @@ interface Reporter { public Object assertStarlark(Object cond, String msg, StarlarkThread thread) throws EvalException { if (!Starlark.truth(cond)) { - thread.getThreadLocal(Reporter.class).reportError(thread, "assert_: " + msg); + reportErrorf(thread, "assert_: %s", msg); } return Starlark.NONE; } @@ -88,12 +89,48 @@ public Object assertStarlark(Object cond, String msg, StarlarkThread thread) useStarlarkThread = true) public Object assertEq(Object x, Object y, StarlarkThread thread) throws EvalException { if (!x.equals(y)) { - String msg = String.format("assert_eq: %s != %s", Starlark.repr(x), Starlark.repr(y)); - thread.getThreadLocal(Reporter.class).reportError(thread, msg); + reportErrorf(thread, "assert_eq: %s != %s", Starlark.repr(x), Starlark.repr(y)); } return Starlark.NONE; } + @StarlarkMethod( + name = "assert_fails", + doc = "assert_fails asserts that evaluation of f() fails with the specified error", + parameters = { + @Param(name = "f", doc = "the Starlark function to call"), + @Param( + name = "wantError", + doc = "a regular expression matching the expected error message"), + }, + useStarlarkThread = true) + public Object assertFails(StarlarkCallable f, String wantError, StarlarkThread thread) + throws EvalException, InterruptedException { + Pattern pattern; + try { + pattern = Pattern.compile(wantError); + } catch (PatternSyntaxException unused) { + throw Starlark.errorf("invalid regexp: %s", wantError); + } + + try { + Starlark.call(thread, f, ImmutableList.of(), ImmutableMap.of()); + reportErrorf(thread, "evaluation succeeded unexpectedly (want error matching %s)", wantError); + } catch (EvalException ex) { + // Verify error matches expectation. 
+ String msg = ex.getMessage(); + if (!pattern.matcher(msg).find()) { + reportErrorf(thread, "regular expression (%s) did not match error (%s)", pattern, msg); + } + } + return Starlark.NONE; + } + + @FormatMethod + private static void reportErrorf(StarlarkThread thread, String format, Object... args) { + thread.getThreadLocal(Reporter.class).reportError(thread, String.format(format, args)); + } + // Constructor for simple structs, for testing. @StarlarkMethod(name = "struct", documented = false, extraKeywords = @Param(name = "kwargs")) public Struct struct(Dict kwargs) throws EvalException { @@ -110,9 +147,15 @@ public Struct mutablestruct(Dict kwargs) throws EvalException { @StarlarkMethod( name = "freeze", - documented = false, - parameters = {@Param(name = "x")}) - public void freeze(Object x) throws EvalException { + doc = "Shallow-freezes the operand. With no argument, freezes the thread.", + parameters = {@Param(name = "x", defaultValue = "unbound")}, + useStarlarkThread = true) + public void freeze(Object x, StarlarkThread thread) throws EvalException { + if (x == Starlark.UNBOUND) { + thread.mutability().close(); + return; + } + if (x instanceof Mutability.Freezable) { ((Mutability.Freezable) x).unsafeShallowFreeze(); } else { diff --git a/src/test/java/net/starlark/java/eval/testdata/function.star b/src/test/java/net/starlark/java/eval/testdata/function.star index 3cf13df8a756bf..99394092c60670 100644 --- a/src/test/java/net/starlark/java/eval/testdata/function.star +++ b/src/test/java/net/starlark/java/eval/testdata/function.star @@ -79,3 +79,98 @@ x() ### 'string' object is not callable --- # Regression test for a type mismatch crash (b/168743413). getattr(1, []) ### parameter 'name' got value of type 'list', want 'string' + +--- +# assert_fails. This will be more useful when we add lambda. +def divzero(): 1//0 +assert_fails(divzero, 'integer division by zero') + + +--- +# Test of nested def statements. 
+def adder(x): + def add(x, y): return x + y # no free vars + def adder(y): return add(x, y) # freevars={x, add} + return adder + +add3 = adder(3) +assert_eq(add3(1), 4) +assert_eq(add3(-1), 2) + +addlam = adder("lam") +assert_eq(addlam("bda"), "lambda") +assert_eq(addlam("bada"), "lambada") + + +# Test of stateful function values. +def makerand(seed=0): + "makerand returns a stateful generator of small pseudorandom numbers." + state = [seed] + def rand(): + "rand returns the next pseudorandom number in the sequence." + state[0] = ((state[0] + 7207) * 9941) & 0xfff + return state[0] + return rand + +rand1 = makerand(123) +rand2 = makerand(123) +assert_eq([rand1() for _ in range(10)], [3786, 133, 796, 1215, 862, 1961, 3088, 4035, 1458, 3981]) +assert_eq([rand2() for _ in range(10)], [3786, 133, 796, 1215, 862, 1961, 3088, 4035, 1458, 3981]) + +# different seed +rand3 = makerand() +assert_eq([rand3() for _ in range(10)], [1651, 1570, 3261, 3508, 1335, 1846, 2657, 3880, 699, 3594]) + +# Attempt to mutate frozen closure state. +freeze() +assert_fails(rand3, "trying to mutate a frozen list value") + +--- +# recursion is disallowed +def fib(x): + return x if x < 2 else fib(x-1)+fib(x-2) + +# TODO(adonovan): use lambda. +def fib10(): return fib(10) +assert_fails(fib10, "function 'fib' called recursively") + +--- +# The recursion check breaks function encapsulation: +# A function g that internally uses a higher-order helper function +# such as 'call' (or Python's map and reduce) cannot itself be +# called from within an active call of that helper. +def call(f): f() +def g(): call(list) +# TODO(adonovan): use lambda. +def call_g(): call(g) +assert_fails(call_g, "function 'call' called recursively") + +--- +# The recursion check is based on the syntactic equality +# (same def statement), not function value equivalence. 
+def eta(f): + # TODO(adonovan): use lambda + def call(): + f() + return call + +def nop(): pass + +# fn1 and fn2 are both created by 'def call', +# but they are distinct and close over different values... +fn1 = eta(nop) +fn2 = eta(fn1) +assert_eq(str(fn1), '') +assert_eq(str(fn2), '') +assert_(fn1 != fn2) + +# ...yet both cannot be called in the same thread: +assert_fails(fn2, "function 'call' called recursively") + +# This rule prevents users from writing the Y combinator, +# which creates a new closure at each step of the recursion. +# TODO(adonovan): enable test when we have lambda. +# Y = lambda f: (lambda x: x(x))(lambda y: f(lambda *args: y(y)(*args))) +# fibgen = lambda fib: lambda x: (x if x<2 else fib(x-1)+fib(x-2)) +# fib2 = Y(fibgen) +# assert_fails(lambda: [fib2(x) for x in range(10)], "function lambda called recursively") diff --git a/src/test/java/net/starlark/java/syntax/ResolverTest.java b/src/test/java/net/starlark/java/syntax/ResolverTest.java index 784a84743a2763..feaa7a694af6ee 100644 --- a/src/test/java/net/starlark/java/syntax/ResolverTest.java +++ b/src/test/java/net/starlark/java/syntax/ResolverTest.java @@ -281,16 +281,6 @@ public void testTopLevelForFails() throws Exception { "for i in []: 0\n"); } - @Test - public void testNestedFunctionFails() throws Exception { - assertInvalid( - "nested functions are not allowed. Move the function to the top level", // - "def func(a):", - " def bar(): return 0", - " return bar()", - ""); - } - @Test public void testComprehension() throws Exception { // The operand of the first for clause is resolved outside the comprehension block. @@ -419,6 +409,25 @@ public void testBindingScopeAndIndex() throws Exception { " [(aᴸ₁, bᴳ₁) for aᴸ₁ in aᴸ₀]"); checkBindings("load('module', aᴳ₀='a', bᴳ₁='b')"); + + // Nested functions have lexical scope. 
+ checkBindings( + "def fᴳ₀(aᴸ₀, bᶜ₁):", // b is a cell: an indirect local shared with nested functions + " aᴸ₀", + " def gᴸ₂(cᴸ₀):", + " bᶠ₀, cᴸ₀"); // b is a free var: a reference to a cell of an outer function + + // Multiply nested functions. + // Load still binds globally, for now, but soon it will bind locally. + checkBindings( + "load('module', aᴳ₀='a')", // eventually: aᶜ₀ + "bᴳ₁= 0", + "def fᴳ₂(cᶜ₀):", + " aᴳ₀, bᴳ₁, cᶜ₀", // eventually: aᶠ0 + " def gᶜ₁(dᶜ₀):", + " aᴳ₀, bᴳ₁, cᶠ₀, dᶜ₀, fᴳ₂", + " def hᶜ₁(eᴸ₀):", + " aᴳ₀, bᴳ₁, cᶠ₀, dᶠ₁, eᴸ₀, fᴳ₂, gᶠ₂, hᶠ₃"); } // checkBindings verifies the binding (scope and index) of each identifier.