diff --git a/src/main/java/net/starlark/java/eval/Eval.java b/src/main/java/net/starlark/java/eval/Eval.java index 5cfa7ed4033ca7..fd8a9a8c0e1cb3 100644 --- a/src/main/java/net/starlark/java/eval/Eval.java +++ b/src/main/java/net/starlark/java/eval/Eval.java @@ -175,13 +175,32 @@ private static void execDef(StarlarkThread.Frame fr, DefStatement node) defaults = EMPTY; } + // Capture the cells of the function's + // free variables from the lexical environment. + Object[] freevars = new Object[rfn.getFreeVars().size()]; + int i = 0; + for (Resolver.Binding bind : rfn.getFreeVars()) { + // Unlike expr(Identifier), we want the cell itself, not its content. + switch (bind.getScope()) { + case FREE: + freevars[i++] = fn(fr).getFreeVar(bind.getIndex()); + break; + case CELL: + freevars[i++] = fr.locals[bind.getIndex()]; + break; + default: + throw new IllegalStateException("unexpected: " + bind); + } + } + // Nested functions use the same globalIndex as their enclosing function, // since both were compiled from the same Program. StarlarkFunction fn = fn(fr); assignIdentifier( fr, node.getIdentifier(), - new StarlarkFunction(rfn, Tuple.wrap(defaults), fn.getModule(), fn.globalIndex)); + new StarlarkFunction( + rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars))); } private static TokenKind execIf(StarlarkThread.Frame fr, IfStatement node) @@ -231,8 +250,8 @@ private static void execLoad(StarlarkThread.Frame fr, LoadStatement node) throws // loads bind file-locally. Either way, the resolver should designate // the proper scope of binding.getLocalName() and this should become // simply assign(binding.getLocalName(), value). - // Currently, we update the module but not module.exportedGlobals; - // changing it to fr.locals.put breaks a test. TODO(adonovan): find out why. + // Currently, we update the module but not module.exportedGlobals. + // Change it to a local binding now that closures are supported. fn(fr).setGlobal(binding.getLocalName().getBinding().getIndex(), value); } } @@ -328,6 +347,9 @@ private static void assignIdentifier(StarlarkThread.Frame fr, Identifier id, Obj case LOCAL: fr.locals[bind.getIndex()] = value; break; + case CELL: + ((StarlarkFunction.Cell) fr.locals[bind.getIndex()]).x = value; + break; case GLOBAL: // Updates a module binding and sets its 'exported' flag. // (Only load bindings are not exported. @@ -637,6 +659,12 @@ private static Object evalIdentifier(StarlarkThread.Frame fr, Identifier id) case LOCAL: result = fr.locals[bind.getIndex()]; break; + case CELL: + result = ((StarlarkFunction.Cell) fr.locals[bind.getIndex()]).x; + break; + case FREE: + result = fn(fr).getFreeVar(bind.getIndex()).x; + break; case GLOBAL: result = fn(fr).getGlobal(bind.getIndex()); break; diff --git a/src/main/java/net/starlark/java/eval/Starlark.java b/src/main/java/net/starlark/java/eval/Starlark.java index 027f9c6d9f3cc1..e4e69081b91d89 100644 --- a/src/main/java/net/starlark/java/eval/Starlark.java +++ b/src/main/java/net/starlark/java/eval/Starlark.java @@ -868,8 +868,6 @@ public static Object execFile( */ public static Object execFileProgram(Program prog, Module module, StarlarkThread thread) throws EvalException, InterruptedException { - Tuple defaultValues = Tuple.empty(); - Resolver.Function rfn = prog.getResolvedFunction(); // A given Module may be passed to execFileProgram multiple times in sequence, @@ -884,7 +882,13 @@ public static Object execFileProgram(Program prog, Module module, StarlarkThread // two array lookups. int[] globalIndex = module.getIndicesOfGlobals(rfn.getGlobals()); - StarlarkFunction toplevel = new StarlarkFunction(rfn, defaultValues, module, globalIndex); + StarlarkFunction toplevel = + new StarlarkFunction( + rfn, + module, + globalIndex, + /*defaultValues=*/ Tuple.empty(), + /*freevars=*/ Tuple.empty()); return Starlark.fastcall(thread, toplevel, EMPTY, EMPTY); } @@ -928,10 +932,10 @@ public static StarlarkFunction newExprFunction( ParserInput input, FileOptions options, Module module) throws SyntaxError.Exception { Expression expr = Expression.parse(input, options); Program prog = Program.compileExpr(expr, module, options); - Tuple defaultValues = Tuple.empty(); Resolver.Function rfn = prog.getResolvedFunction(); int[] globalIndex = module.getIndicesOfGlobals(rfn.getGlobals()); // see execFileProgram - return new StarlarkFunction(rfn, defaultValues, module, globalIndex); + return new StarlarkFunction( + rfn, module, globalIndex, /*defaultValues=*/ Tuple.empty(), /*freevars=*/ Tuple.empty()); } /** diff --git a/src/main/java/net/starlark/java/eval/StarlarkFunction.java b/src/main/java/net/starlark/java/eval/StarlarkFunction.java index 0889d8c261abe6..4277829efe9f3b 100644 --- a/src/main/java/net/starlark/java/eval/StarlarkFunction.java +++ b/src/main/java/net/starlark/java/eval/StarlarkFunction.java @@ -36,15 +36,33 @@ public final class StarlarkFunction implements StarlarkCallable { final Resolver.Function rfn; - final int[] globalIndex; // index in Module.globals of ith Program global (binding index) private final Module module; // a function closes over its defining module + + // Index in Module.globals of ith Program global (Resolver.Binding(GLOBAL).index). + // See explanation at Starlark.execFileProgram. + final int[] globalIndex; + + // Default values of optional parameters. + // Indices correspond to the subsequence of parameters after the initial + // required parameters and before *args/**kwargs. + // Contain MANDATORY for the required keyword-only parameters. private final Tuple defaultValues; - StarlarkFunction(Resolver.Function rfn, Tuple defaultValues, Module module, int[] globalIndex) { + // Cells (shared locals) of enclosing functions. + // Indexed by Resolver.Binding(FREE).index values. + private final Tuple freevars; + + StarlarkFunction( + Resolver.Function rfn, + Module module, + int[] globalIndex, + Tuple defaultValues, + Tuple freevars) { this.rfn = rfn; - this.globalIndex = globalIndex; this.module = module; + this.globalIndex = globalIndex; this.defaultValues = defaultValues; + this.freevars = freevars; } // Sets a global variable, given its index in this function's compiled Program. @@ -153,9 +171,6 @@ public Module getModule() { @Override public Object fastcall(StarlarkThread thread, Object[] positional, Object[] named) throws EvalException, InterruptedException { - if (thread.mutability().isFrozen()) { - throw Starlark.errorf("Trying to call in frozen environment"); - } if (!thread.isRecursionAllowed() && thread.isRecursiveCall(this)) { throw Starlark.errorf("function '%s' called recursively", getName()); } @@ -167,9 +182,19 @@ public Object fastcall(StarlarkThread thread, Object[] positional, Object[] name StarlarkThread.Frame fr = thread.frame(0); fr.locals = new Object[rfn.getLocals().size()]; System.arraycopy(arguments, 0, fr.locals, 0, rfn.getParameterNames().size()); + + // Spill indicated locals to cells. + for (int index : rfn.getCellIndices()) { + fr.locals[index] = new Cell(fr.locals[index]); + } + return Eval.execFunctionBody(fr, rfn.getBody()); } + Cell getFreeVar(int index) { + return (Cell) freevars.get(index); + } + @Override public void repr(Printer printer) { // TODO(adonovan): use the file name instead. But that's a breaking Bazel change. @@ -376,9 +401,20 @@ public boolean isImmutable() { } // The MANDATORY sentinel indicates a slot in the defaultValues - // tuple corresponding to a required parameter. It is not visible - // to Java or Starlark code. + // tuple corresponding to a required parameter. + // It is not visible to Java or Starlark code. static final Object MANDATORY = new Mandatory(); private static class Mandatory implements StarlarkValue {} + + // A Cell is a local variable shared between an inner and an outer function. + // It is a StarlarkValue because it is a stack operand and a Tuple element, + // but it is not visible to Java or Starlark code. + static final class Cell implements StarlarkValue { + Object x; + + Cell(Object x) { + this.x = x; + } + } } diff --git a/src/main/java/net/starlark/java/eval/StarlarkThread.java b/src/main/java/net/starlark/java/eval/StarlarkThread.java index 5019d003a8c7bc..635dc5c1505653 100644 --- a/src/main/java/net/starlark/java/eval/StarlarkThread.java +++ b/src/main/java/net/starlark/java/eval/StarlarkThread.java @@ -32,15 +32,17 @@ * per-thread application state (see {@link #setThreadLocal}) that passes through Starlark functions * but does not directly affect them, such as information about the BUILD file being loaded. * - *

Every {@code StarlarkThread} has a {@link Mutability} field, and must be used within a - * function that creates and closes this {@link Mutability} with the try-with-resource pattern. This - * {@link Mutability} is also used when initializing mutable objects within that {@code - * StarlarkThread}. When the {@code Mutability} is closed at the end of the computation, it freezes - * the {@code StarlarkThread} along with all of those objects. This pattern enforces the discipline - * that there should be no dangling mutable {@code StarlarkThread}, or concurrency between - * interacting {@code StarlarkThread}s. It is a Starlark-level error to attempt to mutate a frozen - * {@code StarlarkThread} or its objects, but it is a Java-level error to attempt to mutate an - * unfrozen {@code StarlarkThread} or its objects from within a different {@code StarlarkThread}. + *

StarlarkThreads are not thread-safe: they should be confined to a single Java thread. + * + *

Every StarlarkThread has an associated {@link Mutability}, which should be created for that + * thread, and closed once the thread's work is done. (A try-with-resources statement is handy for + * this purpose.) Starlark values created by the thread are associated with the thread's Mutability, + * so that when the Mutability is closed at the end of the computation, all the values created by + * the thread become frozen. This pattern ensures that all Starlark values are frozen before they + * are published to another thread, and thus that concurrently executing Starlark threads are free + * from data races. Once a thread's mutability is frozen, the thread is unlikely to be useful for + * further computation because it can no longer create mutable values. (This is occasionally + * valuable in tests.) */ public final class StarlarkThread { @@ -136,7 +138,8 @@ static final class Frame implements Debug.Frame { private boolean errorLocationSet; // The locals of this frame, if fn is a StarlarkFunction, otherwise null. - // Set by StarlarkFunction.fastcall. + // Set by StarlarkFunction.fastcall. Elements may be regular Starlark + // values, or wrapped in StarlarkFunction.Cells if shared with a nested function. @Nullable Object[] locals; @Nullable private Object profileSpan; // current span of walltime call profiler @@ -181,8 +184,12 @@ public ImmutableMap getLocals() { ImmutableMap.Builder env = ImmutableMap.builder(); if (fn instanceof StarlarkFunction) { for (int i = 0; i < locals.length; i++) { - if (locals[i] != null) { - env.put(((StarlarkFunction) fn).rfn.getLocals().get(i).getName(), locals[i]); + Object local = locals[i]; + if (local != null) { + if (local instanceof StarlarkFunction.Cell) { + local = ((StarlarkFunction.Cell) local).x; + } + env.put(((StarlarkFunction) fn).rfn.getLocals().get(i).getName(), local); } } } @@ -332,9 +339,9 @@ boolean isRecursiveCall(StarlarkFunction fn) { // Find fn buried within stack. (The top of the stack is assumed to be fn.) for (int i = callstack.size() - 2; i >= 0; --i) { Frame fr = callstack.get(i); - // TODO(adonovan): compare code, not closure values, otherwise - // one can defeat this check by writing the Y combinator. - if (fr.fn.equals(fn)) { + // We compare code, not closure values, otherwise one can defeat the + // check by writing the Y combinator. + if (fr.fn instanceof StarlarkFunction && ((StarlarkFunction) fr.fn).rfn.equals(fn.rfn)) { return true; } } diff --git a/src/main/java/net/starlark/java/syntax/Resolver.java b/src/main/java/net/starlark/java/syntax/Resolver.java index e3b0a62aa8608b..9369db664a0128 100644 --- a/src/main/java/net/starlark/java/syntax/Resolver.java +++ b/src/main/java/net/starlark/java/syntax/Resolver.java @@ -50,21 +50,18 @@ public final class Resolver extends NodeVisitor { // including the spec. // - move the "no if statements at top level" check to bazel's check{Build,*}Syntax // (that's a spec change), or put it behind a FileOptions flag (no spec change). - // - remove restriction on nested def: - // 1. use FREE for scope of references to outer LOCALs, which become CELLs. - // 2. implement closures in eval/. - // - make loads bind locals by default. + // - make loads bind locals by default (depends on closures). /** Scope discriminates the scope of a binding: global, local, etc. */ public enum Scope { /** Binding is local to a function, comprehension, or file (e.g. load). */ LOCAL, - /** Binding occurs outside any function or comprehension. */ + /** Binding is non-local and occurs outside any function or comprehension. */ GLOBAL, /** Binding is local to a function, comprehension, or file, but shared with nested functions. */ - CELL, // TODO(adonovan): implement nested def + CELL, /** Binding is an implicit parameter whose value is the CELL of some enclosing function. */ - FREE, // TODO(adonovan): implement nested def + FREE, /** Binding is predeclared by the application (e.g. glob in Bazel). */ PREDECLARED, /** Binding is predeclared by the core (e.g. None). */ @@ -81,8 +78,8 @@ public String toString() { * Binding. */ public static final class Binding { - private final Scope scope; - private final int index; // index within function (LOCAL) or module (GLOBAL) + private Scope scope; + private final int index; // index within frame (LOCAL/CELL), freevars (FREE), or module (GLOBAL) @Nullable private final Identifier first; // first binding use, if syntactic private Binding(Scope scope, int index, @Nullable Identifier first) { @@ -102,7 +99,10 @@ public Scope getScope() { return scope; } - /** Returns the index of a binding within its function (LOCAL) or module (GLOBAL). */ + /** + * Returns the index of a binding within its function's frame (LOCAL/CELL), freevars (FREE), or + * module (GLOBAL). + */ public int getIndex() { return index; } @@ -129,10 +129,9 @@ public static final class Function { private final ImmutableList parameterNames; private final boolean isToplevel; private final ImmutableList locals; - // TODO(adonovan): move this to Program, but that requires communication - // between resolveFile and compileFile, which depends on use doing the TODO - // described at Program.compileResolvedFile and eliminating that function. - private final ImmutableList globals; + private final int[] cellIndices; + private final ImmutableList freevars; + private final ImmutableList globals; // TODO(adonovan): move to Program. private Function( String name, @@ -143,6 +142,7 @@ private Function( boolean hasKwargs, int numKeywordOnlyParams, List locals, + List freevars, List globals) { this.name = name; this.location = loc; @@ -160,7 +160,23 @@ private Function( this.isToplevel = name.equals(""); this.locals = ImmutableList.copyOf(locals); + this.freevars = ImmutableList.copyOf(freevars); this.globals = ImmutableList.copyOf(globals); + + // Create an index of the locals that are cells. + int ncells = 0; + int nlocals = locals.size(); + for (int i = 0; i < nlocals; i++) { + if (locals.get(i).scope == Scope.CELL) { + ncells++; + } + } + this.cellIndices = new int[ncells]; + for (int i = 0, j = 0; i < nlocals; i++) { + if (locals.get(i).scope == Scope.CELL) { + cellIndices[j++] = i; + } + } } /** @@ -177,6 +193,14 @@ public ImmutableList getLocals() { return locals; } + /** + * Returns the indices within {@code getLocals()} of the "cells", that is, local variables of + * thus function that are shared with nested functions. The caller must not modify the result. + */ + public int[] getCellIndices() { + return cellIndices; + } + /** * Returns the list of names of globals referenced by this function. The order matches the * indices used in compiled code. @@ -185,6 +209,17 @@ public ImmutableList getGlobals() { return globals; } + /** + * Returns the list of enclosing CELL or FREE bindings referenced by this function. At run time, + * these values, all of which are cells containing variables local to some enclosing function, + * will be stored in the closure. (CELL bindings in this list are local to the immediately + * enclosing function, while FREE bindings pass through one or more intermediate enclosing + * functions.) + */ + public ImmutableList getFreeVars() { + return freevars; + } + /** Returns the location of the function's identifier. */ public Location getLocation() { return location; @@ -293,15 +328,24 @@ private static class Block { @Nullable private final Block parent; // enclosing block, or null for tail of list @Nullable Node syntax; // Comprehension, DefStatement, StarlarkFile, or null private final ArrayList frame; // accumulated locals of enclosing function + // Accumulated CELL/FREE bindings of the enclosing function that will provide + // the values for the free variables of this function; see Function.getFreeVars. + // Null for toplevel functions and expressions, which have no free variables. + @Nullable private final ArrayList freevars; // Bindings for names defined in this block. // Also, as an optimization, memoized lookups of enclosing bindings. private final Map bindings = new HashMap<>(); - Block(@Nullable Block parent, @Nullable Node syntax, ArrayList frame) { + Block( + @Nullable Block parent, + @Nullable Node syntax, + ArrayList frame, + @Nullable ArrayList freevars) { this.parent = parent; this.syntax = syntax; this.frame = frame; + this.freevars = freevars; } } @@ -309,7 +353,7 @@ private static class Block { private final FileOptions options; private final Module module; // List whose order defines the numbering of global variables in this program. - private final ArrayList globals = new ArrayList<>(); + private final List globals = new ArrayList<>(); // A cache of PREDECLARED, UNIVERSAL, and GLOBAL bindings queried from the module. private final Map toplevel = new HashMap<>(); // Linked list of blocks, innermost first, for functions and comprehensions and (finally) file. @@ -339,7 +383,6 @@ private void errorf(Location loc, String format, Object... args) { * are sometimes used before their definition point (e.g. functions are not necessarily declared * in order). */ - // TODO(adonovan): eliminate this first pass by using go.starlark.net one-pass approach. private void createBindingsForBlock(Iterable stmts) { for (Statement stmt : stmts) { createBindings(stmt); @@ -440,21 +483,15 @@ public void visit(Identifier id) { private Binding use(Identifier id) { String name = id.getName(); - // local (to function, comprehension, or file)? - for (Block b = locals; b != null; b = b.parent) { - Binding bind = b.bindings.get(name); - if (bind != null) { - // Optimization: memoize lookup of an outer local - // in an inner block, to avoid repeated walks. - if (b != locals) { - locals.bindings.put(name, bind); - } - return bind; - } + // Locally defined in this function, comprehension, + // or file block, or an enclosing one? + Binding bind = lookupLexical(name, locals); + if (bind != null) { + return bind; } - // toplevel (global, predeclared, universal)? - Binding bind = toplevel.get(name); + // Defined at toplevel (global, predeclared, universal)? + bind = toplevel.get(name); if (bind != null) { return bind; } @@ -493,6 +530,49 @@ private Binding use(Identifier id) { return bind; } + // lookupLexical finds a lexically enclosing local binding of the name, + // plumbing it through enclosing functions as needed. + private static Binding lookupLexical(String name, Block b) { + Binding bind = b.bindings.get(name); + if (bind != null) { + return bind; + } + + if (b.parent != null) { + bind = lookupLexical(name, b.parent); + + // If a local binding was found in a parent block, + // and this block is a function, then it is a free variable + // of this function and must be plumbed through. + // Add an implicit FREE binding (a hidden parameter) to this function, + // and record the outer binding that will supply its value when + // we construct the closure. + // Also, mark the outer LOCAL as a CELL: a shared, indirect local. + // (For a comprehension block there's nothing to do, + // because it's part of the same frame as the enclosing block.) + // + // This step may occur many times if the lookupLexical + // recursion returns through many functions. + // TODO(adonovan): make this 'DEF or LAMBDA' when we have lambda. + if (bind != null && b.syntax instanceof DefStatement) { + Scope scope = bind.getScope(); + if (scope == Scope.LOCAL || scope == Scope.FREE || scope == Scope.CELL) { + if (scope == Scope.LOCAL) { + bind.scope = Scope.CELL; + } + int index = b.freevars.size(); + b.freevars.add(bind); + bind = new Binding(Scope.FREE, index, bind.first); + } + } + + // Memoize, to avoid duplicate free vars and repeated walks. + b.bindings.put(name, bind); + } + + return bind; + } + @Override public void visit(ReturnStatement node) { if (locals.syntax instanceof StarlarkFile) { @@ -601,7 +681,7 @@ public void visit(Comprehension node) { // A comprehension defines a distinct lexical block // in the same function's frame. - pushLocalBlock(node, this.locals.frame); + pushLocalBlock(node, this.locals.frame, this.locals.freevars); for (Comprehension.Clause clause : clauses) { if (clause instanceof Comprehension.For) { @@ -628,9 +708,6 @@ public void visit(Comprehension node) { @Override public void visit(DefStatement node) { - if (!(locals.syntax instanceof StarlarkFile)) { - errorf(node, "nested functions are not allowed. Move the function to the top level."); - } node.setResolvedFunction( resolveFunction( node, @@ -656,7 +733,8 @@ private Function resolveFunction( // Enter function block. ArrayList frame = new ArrayList<>(); - pushLocalBlock(def, frame); + ArrayList freevars = new ArrayList<>(); + pushLocalBlock(def, frame, freevars); // Check parameter order and convert to run-time order: // positionals, keyword-only, *args, **kwargs. @@ -743,6 +821,7 @@ private Function resolveFunction( starStar != null, numKeywordOnlyParams, frame, + freevars, globals); } @@ -879,7 +958,7 @@ public static void resolveFile(StarlarkFile file, Module module) { } ArrayList frame = new ArrayList<>(); - r.pushLocalBlock(file, frame); + r.pushLocalBlock(file, frame, /*freevars=*/ null); // First pass: creating bindings for statements in this block. r.createBindingsForBlock(stmts); @@ -911,6 +990,7 @@ public static void resolveFile(StarlarkFile file, Module module) { /*hasKwargs=*/ false, /*numKeywordOnlyParams=*/ 0, frame, + /*freevars=*/ ImmutableList.of(), r.globals)); } @@ -925,7 +1005,7 @@ public static Function resolveExpr(Expression expr, Module module, FileOptions o Resolver r = new Resolver(errors, module, options); ArrayList frame = new ArrayList<>(); - r.pushLocalBlock(null, frame); // for bindings in list comprehensions + r.pushLocalBlock(null, frame, /*freevars=*/ null); // for bindings in list comprehensions r.visit(expr); r.popLocalBlock(); @@ -943,11 +1023,13 @@ public static Function resolveExpr(Expression expr, Module module, FileOptions o /*hasKwargs=*/ false, /*numKeywordOnlyParams=*/ 0, frame, + /*freevars=*/ ImmutableList.of(), r.globals); } - private void pushLocalBlock(Node syntax, ArrayList frame) { - locals = new Block(locals, syntax, frame); + private void pushLocalBlock( + Node syntax, ArrayList frame, @Nullable ArrayList freevars) { + locals = new Block(locals, syntax, frame, freevars); } private void popLocalBlock() { diff --git a/src/test/java/net/starlark/java/eval/BUILD b/src/test/java/net/starlark/java/eval/BUILD index 4d0972d22c223f..6091dc5983b208 100644 --- a/src/test/java/net/starlark/java/eval/BUILD +++ b/src/test/java/net/starlark/java/eval/BUILD @@ -54,6 +54,7 @@ java_test( "//src/main/java/net/starlark/java/eval", "//src/main/java/net/starlark/java/lib/json", "//src/main/java/net/starlark/java/syntax", + "//third_party:error_prone_annotations", "//third_party:guava", ], ) diff --git a/src/test/java/net/starlark/java/eval/ScriptTest.java b/src/test/java/net/starlark/java/eval/ScriptTest.java index ccf89f5e556d66..2e5e499ad6d559 100644 --- a/src/test/java/net/starlark/java/eval/ScriptTest.java +++ b/src/test/java/net/starlark/java/eval/ScriptTest.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.Files; +import com.google.errorprone.annotations.FormatMethod; import java.io.File; import java.util.HashMap; import java.util.List; @@ -73,7 +74,7 @@ interface Reporter { public Object assertStarlark(Object cond, String msg, StarlarkThread thread) throws EvalException { if (!Starlark.truth(cond)) { - thread.getThreadLocal(Reporter.class).reportError(thread, "assert_: " + msg); + reportErrorf(thread, "assert_: %s", msg); } return Starlark.NONE; } @@ -88,12 +89,48 @@ public Object assertStarlark(Object cond, String msg, StarlarkThread thread) useStarlarkThread = true) public Object assertEq(Object x, Object y, StarlarkThread thread) throws EvalException { if (!x.equals(y)) { - String msg = String.format("assert_eq: %s != %s", Starlark.repr(x), Starlark.repr(y)); - thread.getThreadLocal(Reporter.class).reportError(thread, msg); + reportErrorf(thread, "assert_eq: %s != %s", Starlark.repr(x), Starlark.repr(y)); } return Starlark.NONE; } + @StarlarkMethod( + name = "assert_fails", + doc = "assert_fails asserts that evaluation of f() fails with the specified error", + parameters = { + @Param(name = "f", doc = "the Starlark function to call"), + @Param( + name = "wantError", + doc = "a regular expression matching the expected error message"), + }, + useStarlarkThread = true) + public Object assertFails(StarlarkCallable f, String wantError, StarlarkThread thread) + throws EvalException, InterruptedException { + Pattern pattern; + try { + pattern = Pattern.compile(wantError); + } catch (PatternSyntaxException unused) { + throw Starlark.errorf("invalid regexp: %s", wantError); + } + + try { + Starlark.call(thread, f, ImmutableList.of(), ImmutableMap.of()); + reportErrorf(thread, "evaluation succeeded unexpectedly (want error matching %s)", wantError); + } catch (EvalException ex) { + // Verify error matches expectation. + String msg = ex.getMessage(); + if (!pattern.matcher(msg).find()) { + reportErrorf(thread, "regular expression (%s) did not match error (%s)", pattern, msg); + } + } + return Starlark.NONE; + } + + @FormatMethod + private static void reportErrorf(StarlarkThread thread, String format, Object... args) { + thread.getThreadLocal(Reporter.class).reportError(thread, String.format(format, args)); + } + // Constructor for simple structs, for testing. @StarlarkMethod(name = "struct", documented = false, extraKeywords = @Param(name = "kwargs")) public Struct struct(Dict kwargs) throws EvalException { @@ -110,9 +147,15 @@ public Struct mutablestruct(Dict kwargs) throws EvalException { @StarlarkMethod( name = "freeze", - documented = false, - parameters = {@Param(name = "x")}) - public void freeze(Object x) throws EvalException { + doc = "Shallow-freezes the operand. With no argument, freezes the thread.", + parameters = {@Param(name = "x", defaultValue = "unbound")}, + useStarlarkThread = true) + public void freeze(Object x, StarlarkThread thread) throws EvalException { + if (x == Starlark.UNBOUND) { + thread.mutability().close(); + return; + } + if (x instanceof Mutability.Freezable) { ((Mutability.Freezable) x).unsafeShallowFreeze(); } else { diff --git a/src/test/java/net/starlark/java/eval/testdata/function.star b/src/test/java/net/starlark/java/eval/testdata/function.star index 3cf13df8a756bf..99394092c60670 100644 --- a/src/test/java/net/starlark/java/eval/testdata/function.star +++ b/src/test/java/net/starlark/java/eval/testdata/function.star @@ -79,3 +79,98 @@ x() ### 'string' object is not callable --- # Regression test for a type mismatch crash (b/168743413). getattr(1, []) ### parameter 'name' got value of type 'list', want 'string' + +--- +# assert_fails. This will be more useful when we add lambda. +def divzero(): 1//0 +assert_fails(divzero, 'integer division by zero') + + +--- +# Test of nested def statements. +def adder(x): + def add(x, y): return x + y # no free vars + def adder(y): return add(x, y) # freevars={x, add} + return adder + +add3 = adder(3) +assert_eq(add3(1), 4) +assert_eq(add3(-1), 2) + +addlam = adder("lam") +assert_eq(addlam("bda"), "lambda") +assert_eq(addlam("bada"), "lambada") + + +# Test of stateful function values. +def makerand(seed=0): + "makerand returns a stateful generator of small pseudorandom numbers." + state = [seed] + def rand(): + "rand returns the next pseudorandom number in the sequence." + state[0] = ((state[0] + 7207) * 9941) & 0xfff + return state[0] + return rand + +rand1 = makerand(123) +rand2 = makerand(123) +assert_eq([rand1() for _ in range(10)], [3786, 133, 796, 1215, 862, 1961, 3088, 4035, 1458, 3981]) +assert_eq([rand2() for _ in range(10)], [3786, 133, 796, 1215, 862, 1961, 3088, 4035, 1458, 3981]) + +# different seed +rand3 = makerand() +assert_eq([rand3() for _ in range(10)], [1651, 1570, 3261, 3508, 1335, 1846, 2657, 3880, 699, 3594]) + +# Attempt to mutate frozen closure state. +freeze() +assert_fails(rand3, "trying to mutate a frozen list value") + +--- +# recursion is disallowed +def fib(x): + return x if x < 2 else fib(x-1)+fib(x-2) + +# TODO(adonovan): use lambda. +def fib10(): return fib(10) +assert_fails(fib10, "function 'fib' called recursively") + +--- +# The recursion check breaks function encapsulation: +# A function g that internally uses a higher-order helper function +# such as 'call' (or Python's map and reduce) cannot itself be +# called from within an active call of that helper. +def call(f): f() +def g(): call(list) +# TODO(adonovan): use lambda. +def call_g(): call(g) +assert_fails(call_g, "function 'call' called recursively") + +--- +# The recursion check is based on the syntactic equality +# (same def statement), not function value equivalence. +def eta(f): + # TODO(adonovan): use lambda + def call(): + f() + return call + +def nop(): pass + +# fn1 and fn2 are both created by 'def call', +# but they are distinct and close over different values... +fn1 = eta(nop) +fn2 = eta(fn1) +assert_eq(str(fn1), '') +assert_eq(str(fn2), '') +assert_(fn1 != fn2) + +# ...yet both cannot be called in the same thread: +assert_fails(fn2, "function 'call' called recursively") + +# This rule prevents users from writing the Y combinator, +# which creates a new closure at each step of the recursion. +# TODO(adonovan): enable test when we have lambda. +# Y = lambda f: (lambda x: x(x))(lambda y: f(lambda *args: y(y)(*args))) +# fibgen = lambda fib: lambda x: (x if x<2 else fib(x-1)+fib(x-2)) +# fib2 = Y(fibgen) +# assert_fails(lambda: [fib2(x) for x in range(10)], "function lambda called recursively") diff --git a/src/test/java/net/starlark/java/syntax/ResolverTest.java b/src/test/java/net/starlark/java/syntax/ResolverTest.java index 784a84743a2763..feaa7a694af6ee 100644 --- a/src/test/java/net/starlark/java/syntax/ResolverTest.java +++ b/src/test/java/net/starlark/java/syntax/ResolverTest.java @@ -281,16 +281,6 @@ public void testTopLevelForFails() throws Exception { "for i in []: 0\n"); } - @Test - public void testNestedFunctionFails() throws Exception { - assertInvalid( - "nested functions are not allowed. Move the function to the top level", // - "def func(a):", - " def bar(): return 0", - " return bar()", - ""); - } - @Test public void testComprehension() throws Exception { // The operand of the first for clause is resolved outside the comprehension block. @@ -419,6 +409,25 @@ public void testBindingScopeAndIndex() throws Exception { " [(aᴸ₁, bᴳ₁) for aᴸ₁ in aᴸ₀]"); checkBindings("load('module', aᴳ₀='a', bᴳ₁='b')"); + + // Nested functions have lexical scope. + checkBindings( + "def fᴳ₀(aᴸ₀, bᶜ₁):", // b is a cell: an indirect local shared with nested functions + " aᴸ₀", + " def gᴸ₂(cᴸ₀):", + " bᶠ₀, cᴸ₀"); // b is a free var: a reference to a cell of an outer function + + // Multiply nested functions. + // Load still binds globally, for now, but soon it will bind locally. + checkBindings( + "load('module', aᴳ₀='a')", // eventually: aᶜ₀ + "bᴳ₁= 0", + "def fᴳ₂(cᶜ₀):", + " aᴳ₀, bᴳ₁, cᶜ₀", // eventually: aᶠ0 + " def gᶜ₁(dᶜ₀):", + " aᴳ₀, bᴳ₁, cᶠ₀, dᶜ₀, fᴳ₂", + " def hᶜ₁(eᴸ₀):", + " aᴳ₀, bᴳ₁, cᶠ₀, dᶠ₁, eᴸ₀, fᴳ₂, gᶠ₂, hᶠ₃"); } // checkBindings verifies the binding (scope and index) of each identifier.