Skip to content

Commit

Permalink
Ensure that directories passed to Args.add_all can be expanded insi…
Browse files Browse the repository at this point in the history
…de the `map_each` function by `DirectoryExpander` even if they are nested in other data structures.

RELNOTES: None.
PiperOrigin-RevId: 369656737
  • Loading branch information
allevato authored and copybara-github committed Apr 21, 2021
1 parent ec94022 commit 2f5c575
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import net.starlark.java.eval.EvalException;
Expand All @@ -41,9 +42,11 @@
import net.starlark.java.eval.Starlark;
import net.starlark.java.eval.StarlarkCallable;
import net.starlark.java.eval.StarlarkFunction;
import net.starlark.java.eval.StarlarkIterable;
import net.starlark.java.eval.StarlarkSemantics;
import net.starlark.java.eval.StarlarkThread;
import net.starlark.java.eval.StarlarkValue;
import net.starlark.java.eval.Structure;
import net.starlark.java.syntax.Location;

/**
Expand Down Expand Up @@ -99,7 +102,7 @@ public void debugPrint(Printer printer) {
* Returns a set of directory artifacts which will need to be expanded for evaluating the
* encapsulated arguments during execution.
*/
public abstract ImmutableSet<Artifact> getDirectoryArtifacts();
public abstract ImmutableSet<Artifact> getDirectoryArtifacts() throws EvalException;

/** Returns the command line built by this {@link Args} object. */
public abstract CommandLine build();
Expand Down Expand Up @@ -434,13 +437,13 @@ private void addVectorArg(
if (value instanceof Depset) {
Depset starlarkNestedSet = (Depset) value;
NestedSet<?> nestedSet = starlarkNestedSet.getSet();
if (expandDirectories) {
if (expandDirectories || mapEach != null) {
potentialDirectoryArtifacts.add(nestedSet);
}
vectorArg = new StarlarkCustomCommandLine.VectorArg.Builder(nestedSet);
} else {
Sequence<?> starlarkList = (Sequence) value;
if (expandDirectories) {
if (expandDirectories || mapEach != null) {
scanForDirectories(starlarkList);
}
vectorArg = new StarlarkCustomCommandLine.VectorArg.Builder(starlarkList);
Expand Down Expand Up @@ -571,18 +574,49 @@ public Mutability mutability() {
}

@Override
public ImmutableSet<Artifact> getDirectoryArtifacts() {
public ImmutableSet<Artifact> getDirectoryArtifacts() throws EvalException {
for (NestedSet<?> collection : potentialDirectoryArtifacts) {
scanForDirectories(collection.toList());
}
potentialDirectoryArtifacts.clear();
return ImmutableSet.copyOf(directoryArtifacts);
}

private void scanForDirectories(Iterable<?> objects) {
private void scanForDirectories(Iterable<?> objects) throws EvalException {
for (Object object : objects) {
if (isDirectory(object)) {
directoryArtifacts.add((Artifact) object);
try {
scanForDirectoriesDeeply(object);
} catch (StackOverflowError unused) {
throw Starlark.errorf("nesting depth limit exceeded");
}
}
}

/**
* Walks recursively through the given object, collecting any component parts that are directory
* {@code Artifact}s.
*
* <p>At this time, the following data types are supported: dictionaries (both keys and values
* are checked), lists, tuples, and structs/Starlark providers.
*/
private void scanForDirectoriesDeeply(Object object) throws EvalException {
if (isDirectory(object)) {
directoryArtifacts.add((Artifact) object);
} else if (object instanceof Map) {
Map<?, ?> map = (Map) object;
for (Map.Entry<?, ?> entry : map.entrySet()) {
scanForDirectoriesDeeply(entry.getKey());
scanForDirectoriesDeeply(entry.getValue());
}
} else if (object instanceof StarlarkIterable) {
StarlarkIterable<?> iterable = (StarlarkIterable) object;
for (Object element : iterable) {
scanForDirectoriesDeeply(element);
}
} else if (object instanceof Structure) {
Structure struct = (Structure) object;
for (String fieldName : struct.getFieldNames()) {
scanForDirectoriesDeeply(struct.getValue(fieldName));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3219,4 +3219,92 @@ public void testCallDirectoryExpanderWithWrongType() throws Exception {
CommandLine commandLine = args.build();
assertThrows(CommandLineExpansionException.class, commandLine::arguments);
}

@Test
public void addAll_failsWhenExpandingRecursiveDataStructure() throws Exception {
setRuleContext(createRuleContext("//foo"));
EvalException e =
assertThrows(
EvalException.class,
() ->
ev.exec(
"args = ruleContext.actions.args()",
"directory = ruleContext.actions.declare_directory('dir')",
"def _expand_dirs(l, dir_expander):",
" return [f.short_path for f in dir_expander.expand(l)]",
"l = [0, 1, 2]",
"l[1] = struct(loop = l)",
"args.add_all([l], map_each=_expand_dirs)"));

assertThat(e).hasMessageThat().contains("nesting depth limit exceeded");
}

@Test
public void getDirectoryArtifacts_returnsDirectoryFromNestedList() throws Exception {
setRuleContext(createRuleContext("//foo"));
ev.exec(
"args = ruleContext.actions.args()",
"dir = ruleContext.actions.declare_directory('dir')",
"def _get_paths(map_arg, expander): return []",
"args.add_all([['foo', dir]], map_each=_get_paths)");
Args args = (Args) ev.eval("args");
Artifact dir = (Artifact) ev.eval("dir");

assertThat(args.getDirectoryArtifacts()).containsExactly(dir);
}

@Test
public void getDirectoryArtifacts_returnsDirectoryFromDict() throws Exception {
setRuleContext(createRuleContext("//foo"));
ev.exec(
"args = ruleContext.actions.args()",
"dir = ruleContext.actions.declare_directory('dir')",
"def _get_paths(map_arg, expander): return []",
"args.add_all([{'key': dir}], map_each=_get_paths)");
Args args = (Args) ev.eval("args");
Artifact dir = (Artifact) ev.eval("dir");

assertThat(args.getDirectoryArtifacts()).containsExactly(dir);
}

@Test
public void getDirectoryArtifacts_returnsDirectoryFromStructField() throws Exception {
setRuleContext(createRuleContext("//foo"));
ev.exec(
"args = ruleContext.actions.args()",
"dir = ruleContext.actions.declare_directory('dir')",
"def _get_paths(map_arg, expander): return []",
"args.add_all([struct(field=dir)], map_each=_get_paths)");
Args args = (Args) ev.eval("args");
Artifact dir = (Artifact) ev.eval("dir");

assertThat(args.getDirectoryArtifacts()).containsExactly(dir);
}

@Test
public void getDirectoryArtifacts_returnsDirectoryFromComplexNesting() throws Exception {
setRuleContext(createRuleContext("//foo"));
ev.exec(
"args = ruleContext.actions.args()",
"dir = ruleContext.actions.declare_directory('dir')",
"def _get_paths(map_arg, expander): return []",
"args.add_all([struct(field=['foo', {'key': struct(field=dir)}])], map_each=_get_paths)");
Args args = (Args) ev.eval("args");
Artifact dir = (Artifact) ev.eval("dir");

assertThat(args.getDirectoryArtifacts()).containsExactly(dir);
}

@Test
public void getDirectoryArtifacts_doesNotExpandIfNoMapFunctionAndExpandDirectoriesIsFalse()
throws Exception {
setRuleContext(createRuleContext("//foo"));
ev.exec(
"args = ruleContext.actions.args()",
"dir = ruleContext.actions.declare_directory('dir')",
"args.add_all([struct(field=dir)], expand_directories=False)");
Args args = (Args) ev.eval("args");

assertThat(args.getDirectoryArtifacts()).isEmpty();
}
}

0 comments on commit 2f5c575

Please sign in to comment.