diff --git a/nullaway/src/main/java/com/uber/nullaway/LibraryModels.java b/nullaway/src/main/java/com/uber/nullaway/LibraryModels.java index 1f8e0be782..98d479f658 100644 --- a/nullaway/src/main/java/com/uber/nullaway/LibraryModels.java +++ b/nullaway/src/main/java/com/uber/nullaway/LibraryModels.java @@ -196,7 +196,7 @@ default ImmutableList customStreamNullabilitySpecs() { * * */ - final class MethodRef { + public final class MethodRef { public final String enclosingClass; diff --git a/nullaway/src/main/java/com/uber/nullaway/NullAway.java b/nullaway/src/main/java/com/uber/nullaway/NullAway.java index ee1944ad04..ab88cf7c1d 100644 --- a/nullaway/src/main/java/com/uber/nullaway/NullAway.java +++ b/nullaway/src/main/java/com/uber/nullaway/NullAway.java @@ -463,7 +463,8 @@ private void updateEnvironmentMapping(TreePath treePath, VisitorState state) { // 2. we keep info on all locals rather than just effectively final ones for simplicity EnclosingEnvironmentNullness.instance(state.context) .addEnvironmentMapping( - treePath.getLeaf(), analysis.getNullnessInfoBeforeNewContext(treePath, state, handler)); + treePath.getLeaf(), + analysis.getNullnessInfoBeforeNestedMethodNode(treePath, state, handler)); } private Symbol.MethodSymbol getSymbolOfSuperConstructor( diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java index e68f05e8b5..2e1054a8e0 100644 --- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java @@ -211,18 +211,23 @@ public Set getNonnullStaticFieldsBefore(TreePath path, Context context) } /** - * Get nullness info for local variables (and final fields) before some node + * Get nullness info for local variables (and final fields) before some node represented a nested + * method (lambda or anonymous class) * - * @param path tree path to some AST node within a method / lambda / initializer + * @param pathToNestedMethodNode tree path to some AST node representing a nested method * @param state visitor state - * @return nullness info for local variables just before the node + * @param handler handler instance + * @return nullness info for local variables just before the leaf of the tree path */ - public NullnessStore getNullnessInfoBeforeNewContext( - TreePath path, VisitorState state, Handler handler) { - NullnessStore store = dataFlow.resultBefore(path, state.context, nullnessPropagation); + public NullnessStore getNullnessInfoBeforeNestedMethodNode( + TreePath pathToNestedMethodNode, VisitorState state, Handler handler) { + NullnessStore store = + dataFlow.resultBefore(pathToNestedMethodNode, state.context, nullnessPropagation); if (store == null) { return NullnessStore.empty(); } + Predicate handlerPredicate = + handler.getAccessPathPredicateForNestedMethod(pathToNestedMethodNode, state); return store.filterAccessPaths( (ap) -> { boolean allAPNonRootElementsAreFinalFields = true; @@ -243,7 +248,7 @@ public NullnessStore getNullnessInfoBeforeNewContext( && e.getModifiers().contains(Modifier.FINAL)); } - return handler.includeApInfoInSavedContext(ap, state); + return handlerPredicate.test(ap); }); } diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/AccessPathPredicates.java b/nullaway/src/main/java/com/uber/nullaway/handlers/AccessPathPredicates.java new file mode 100644 index 0000000000..1b239d5ac5 --- /dev/null +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/AccessPathPredicates.java @@ -0,0 +1,25 @@ +package com.uber.nullaway.handlers; + +import com.google.errorprone.VisitorState; +import com.sun.source.util.TreePath; +import com.uber.nullaway.dataflow.AccessPath; +import java.util.function.Predicate; + +/** + * {@link java.util.function.Predicate}s over {@link com.uber.nullaway.dataflow.AccessPath}s useful + * in defining handlers. + */ +public class AccessPathPredicates { + + /** + * An AccessPath predicate that always returns false. Used to optimize {@link + * CompositeHandler#getAccessPathPredicateForNestedMethod(TreePath, VisitorState)} + */ + static final Predicate FALSE_AP_PREDICATE = ap -> false; + + /** + * An AccessPath predicate that always returns true. Used to optimize {@link + * CompositeHandler#getAccessPathPredicateForNestedMethod(TreePath, VisitorState)} + */ + static final Predicate TRUE_AP_PREDICATE = ap -> true; +} diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/BaseNoOpHandler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/BaseNoOpHandler.java index 3cc0e92f81..171836276b 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/BaseNoOpHandler.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/BaseNoOpHandler.java @@ -31,6 +31,7 @@ import com.sun.source.tree.MethodInvocationTree; import com.sun.source.tree.MethodTree; import com.sun.source.tree.ReturnTree; +import com.sun.source.util.TreePath; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Types; import com.sun.tools.javac.util.Context; @@ -44,6 +45,7 @@ import com.uber.nullaway.dataflow.cfg.NullAwayCFGBuilder; import java.util.List; import java.util.Optional; +import java.util.function.Predicate; import javax.annotation.Nullable; import org.checkerframework.nullaway.dataflow.cfg.UnderlyingAST; import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode; @@ -199,8 +201,9 @@ public Optional onExpressionDereference( } @Override - public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state) { - return false; + public Predicate getAccessPathPredicateForNestedMethod( + TreePath path, VisitorState state) { + return AccessPathPredicates.FALSE_AP_PREDICATE; } @Override diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/CompositeHandler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/CompositeHandler.java index a8eec51ac7..b05128a00a 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/CompositeHandler.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/CompositeHandler.java @@ -22,6 +22,9 @@ package com.uber.nullaway.handlers; +import static com.uber.nullaway.handlers.AccessPathPredicates.FALSE_AP_PREDICATE; +import static com.uber.nullaway.handlers.AccessPathPredicates.TRUE_AP_PREDICATE; + import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.errorprone.VisitorState; @@ -32,6 +35,7 @@ import com.sun.source.tree.MethodInvocationTree; import com.sun.source.tree.MethodTree; import com.sun.source.tree.ReturnTree; +import com.sun.source.util.TreePath; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Types; import com.sun.tools.javac.util.Context; @@ -45,6 +49,7 @@ import com.uber.nullaway.dataflow.cfg.NullAwayCFGBuilder; import java.util.List; import java.util.Optional; +import java.util.function.Predicate; import javax.annotation.Nullable; import org.checkerframework.nullaway.dataflow.cfg.UnderlyingAST; import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode; @@ -253,12 +258,24 @@ public Optional onExpressionDereference( } @Override - public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state) { - boolean shouldFilter = false; + public Predicate getAccessPathPredicateForNestedMethod( + TreePath path, VisitorState state) { + Predicate filter = FALSE_AP_PREDICATE; for (Handler h : handlers) { - shouldFilter |= h.includeApInfoInSavedContext(accessPath, state); + Predicate curFilter = h.getAccessPathPredicateForNestedMethod(path, state); + // here we do some optimization, to try to avoid unnecessarily returning a deeply nested + // Predicate object (which would be more costly to test) + if (curFilter != FALSE_AP_PREDICATE) { + if (curFilter == TRUE_AP_PREDICATE) { + return curFilter; + } else if (filter == FALSE_AP_PREDICATE) { + filter = curFilter; + } else { + filter = filter.or(curFilter); + } + } } - return shouldFilter; + return filter; } @Override diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/Handler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/Handler.java index ea084c3cb2..08477a056f 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/Handler.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/Handler.java @@ -31,6 +31,7 @@ import com.sun.source.tree.MethodInvocationTree; import com.sun.source.tree.MethodTree; import com.sun.source.tree.ReturnTree; +import com.sun.source.util.TreePath; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Types; import com.sun.tools.javac.util.Context; @@ -45,6 +46,7 @@ import com.uber.nullaway.dataflow.cfg.NullAwayCFGBuilder; import java.util.List; import java.util.Optional; +import java.util.function.Predicate; import javax.annotation.Nullable; import org.checkerframework.nullaway.dataflow.cfg.UnderlyingAST; import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode; @@ -327,15 +329,16 @@ Optional onExpressionDereference( ExpressionTree expr, ExpressionTree baseExpr, VisitorState state); /** - * Called when the store access paths are filtered for local variable information before an - * expression. + * Called when determining which access path nullability information should be preserved when + * analyzing a nested method, i.e., a lambda expression or a method in an anonymous or local + * class. * - * @param accessPath The access path that needs to be checked if filtered. + * @param path The tree path to the node for the nested method. * @param state The current visitor state. - * @return true if the nullability information for this accesspath should be treated as part of - * the surrounding context when processing a lambda expression or anonymous class declaration. + * @return A predicate that determines which access paths should be preserved when analyzing the + * nested method. */ - boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state); + Predicate getAccessPathPredicateForNestedMethod(TreePath path, VisitorState state); /** * Called during dataflow analysis initialization to register structurally immutable types. diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/Handlers.java b/nullaway/src/main/java/com/uber/nullaway/handlers/Handlers.java index f82f343ffb..c9c012c3da 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/Handlers.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/Handlers.java @@ -69,6 +69,7 @@ public static Handler buildDefault(Config config) { handlerListBuilder.add(new GrpcHandler()); handlerListBuilder.add(new RequiresNonNullHandler()); handlerListBuilder.add(new EnsuresNonNullHandler()); + handlerListBuilder.add(new SynchronousCallbackHandler()); if (config.serializationIsActive() && config.getSerializationConfig().fieldInitInfoEnabled) { handlerListBuilder.add( new FieldInitializationSerializationHandler(config.getSerializationConfig())); diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java index e5818854dc..e8a70bc855 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java @@ -50,6 +50,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Consumer; +import java.util.function.Predicate; import javax.annotation.Nullable; import javax.lang.model.element.AnnotationMirror; import javax.lang.model.element.Element; @@ -164,17 +165,18 @@ private boolean isOptionalContentNullable( } @Override - public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state) { - - if (accessPath.getElements().size() == 1) { - final Element e = accessPath.getRoot(); - if (e != null) { - return e.getKind().equals(ElementKind.LOCAL_VARIABLE) - && accessPath.getElements().get(0).getJavaElement() - instanceof OptionalContentVariableElement; + public Predicate getAccessPathPredicateForNestedMethod( + TreePath path, VisitorState state) { + return ap -> { + if (ap.getElements().size() == 1) { + final Element e = ap.getRoot(); + if (e != null) { + return e.getKind().equals(ElementKind.LOCAL_VARIABLE) + && ap.getElements().get(0).getJavaElement() instanceof OptionalContentVariableElement; + } } - } - return false; + return false; + }; } private void handleTestAssertions( diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/SynchronousCallbackHandler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/SynchronousCallbackHandler.java new file mode 100644 index 0000000000..ca49de3f98 --- /dev/null +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/SynchronousCallbackHandler.java @@ -0,0 +1,94 @@ +package com.uber.nullaway.handlers; + +import static com.uber.nullaway.handlers.AccessPathPredicates.FALSE_AP_PREDICATE; +import static com.uber.nullaway.handlers.AccessPathPredicates.TRUE_AP_PREDICATE; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.VisitorState; +import com.google.errorprone.suppliers.Supplier; +import com.google.errorprone.suppliers.Suppliers; +import com.google.errorprone.util.ASTHelpers; +import com.sun.source.tree.ClassTree; +import com.sun.source.tree.LambdaExpressionTree; +import com.sun.source.tree.MethodInvocationTree; +import com.sun.source.tree.Tree; +import com.sun.source.util.TreePath; +import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Type; +import com.uber.nullaway.LibraryModels.MethodRef; +import com.uber.nullaway.dataflow.AccessPath; +import java.util.function.Predicate; + +public class SynchronousCallbackHandler extends BaseNoOpHandler { + + /** + * Maps method name to full information about the corresponding methods and what parameter is the + * relevant callback. We key on method name to quickly eliminate most cases when doing a lookup. + */ + private static final ImmutableMap> + METHOD_NAME_TO_SIG_AND_PARAM_INDEX = + ImmutableMap.of( + "forEach", + ImmutableMap.of( + MethodRef.methodRef( + "java.util.Map", + "forEach(java.util.function.BiConsumer)"), + 0, + MethodRef.methodRef( + "java.lang.Iterable", "forEach(java.util.function.Consumer)"), + 0), + "removeIf", + ImmutableMap.of( + MethodRef.methodRef( + "java.util.Collection", "removeIf(java.util.function.Predicate)"), + 0)); + + private static final Supplier STREAM_TYPE_SUPPLIER = + Suppliers.typeFromString("java.util.stream.Stream"); + + @Override + public Predicate getAccessPathPredicateForNestedMethod( + TreePath path, VisitorState state) { + Tree leafNode = path.getLeaf(); + Preconditions.checkArgument( + leafNode instanceof ClassTree || leafNode instanceof LambdaExpressionTree, + "Unexpected leaf type: %s", + leafNode.getClass()); + Tree parentNode = path.getParentPath().getLeaf(); + if (parentNode instanceof MethodInvocationTree) { + MethodInvocationTree methodInvocationTree = (MethodInvocationTree) parentNode; + Symbol.MethodSymbol symbol = ASTHelpers.getSymbol(methodInvocationTree); + if (symbol == null) { + return FALSE_AP_PREDICATE; + } + Type ownerType = symbol.owner.type; + if (ASTHelpers.isSameType(ownerType, STREAM_TYPE_SUPPLIER.get(state), state)) { + // preserve access paths for all callbacks passed to stream methods + return TRUE_AP_PREDICATE; + } + String invokedMethodName = symbol.getSimpleName().toString(); + if (METHOD_NAME_TO_SIG_AND_PARAM_INDEX.containsKey(invokedMethodName)) { + ImmutableMap entriesForMethodName = + METHOD_NAME_TO_SIG_AND_PARAM_INDEX.get(invokedMethodName); + for (MethodRef methodRef : entriesForMethodName.keySet()) { + if (symbol.toString().equals(methodRef.fullMethodSig) + && ASTHelpers.isSubtype( + ownerType, state.getTypeFromString(methodRef.enclosingClass), state)) { + int parameterIndex = -1; + for (int i = 0; i < methodInvocationTree.getArguments().size(); i++) { + if (methodInvocationTree.getArguments().get(i) == leafNode) { + parameterIndex = i; + break; + } + } + if (parameterIndex == entriesForMethodName.get(methodRef)) { + return TRUE_AP_PREDICATE; + } + } + } + } + } + return FALSE_AP_PREDICATE; + } +} diff --git a/nullaway/src/test/java/com/uber/nullaway/SyncLambdasTests.java b/nullaway/src/test/java/com/uber/nullaway/SyncLambdasTests.java new file mode 100644 index 0000000000..0189b3d3e6 --- /dev/null +++ b/nullaway/src/test/java/com/uber/nullaway/SyncLambdasTests.java @@ -0,0 +1,166 @@ +package com.uber.nullaway; + +import org.junit.Test; + +/** + * Tests for cases where lambdas or anonymous class methods are invoked nearly synchronously, so it + * is reasonable to propagate more nullability information to their bodies. + */ +public class SyncLambdasTests extends NullAwayTestsBase { + + @Test + public void forEachOnMap() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.Map;", + "import java.util.HashMap;", + "import org.jspecify.annotations.Nullable;", + "public class Test {", + " private @Nullable Map target;", + " private @Nullable Map resolved;", + " public void initialize() {", + " if (this.target == null) {", + " throw new IllegalArgumentException();", + " }", + " this.resolved = new HashMap<>();", + " this.target.forEach((key, value) -> {", + " // no error here as info gets propagated", + " this.resolved.put(key, value);", + " });", + " }", + "}") + .doTest(); + } + + @Test + public void forEachOnHashMap() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.HashMap;", + "import org.jspecify.annotations.Nullable;", + "public class Test {", + " private @Nullable HashMap target;", + " private @Nullable HashMap resolved;", + " public void initialize() {", + " if (this.target == null) {", + " throw new IllegalArgumentException();", + " }", + " this.resolved = new HashMap<>();", + " this.target.forEach((key, value) -> {", + " // no error here as info gets propagated", + " this.resolved.put(key, value);", + " });", + " }", + "}") + .doTest(); + } + + @Test + public void otherForEach() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.HashMap;", + "import java.util.function.BiConsumer;", + "import org.jspecify.annotations.Nullable;", + "public class Test {", + " private @Nullable MyMap target;", + " private @Nullable Object resolved;", + " static class MyMap {", + " public void forEach(BiConsumer consumer) {}", + " public void put(Object key, Object value) {}", + " }", + " public void initialize() {", + " if (this.target == null) {", + " throw new IllegalArgumentException();", + " }", + " this.resolved = new Object();", + " this.target.forEach((key, value) -> {", + " // error since this is a custom type, not inheriting from java.util.Map", + " // BUG: Diagnostic contains: dereferenced expression this.resolved is @Nullable", + " System.out.println(this.resolved.toString());", + " });", + " }", + "}") + .doTest(); + } + + @Test + public void forEachOnIterable() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.List;", + "import java.util.ArrayList;", + "import org.jspecify.annotations.Nullable;", + "public class Test {", + " private @Nullable Object f;", + " public void test1() {", + " if (this.f == null) {", + " throw new IllegalArgumentException();", + " }", + " List l = new ArrayList<>();", + " l.forEach(v -> System.out.println(v + this.f.toString()));", + " Iterable l2 = l;", + " l2.forEach(v -> System.out.println(v + this.f.toString()));", + " this.f = null;", + " // BUG: Diagnostic contains: dereferenced expression this.f is @Nullable", + " l2.forEach(v -> System.out.println(v + this.f.toString()));", + " }", + "}") + .doTest(); + } + + @Test + public void removeIf() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.List;", + "import java.util.ArrayList;", + "import org.jspecify.annotations.Nullable;", + "public class Test {", + " private @Nullable Object f;", + " public void test1() {", + " if (this.f == null) {", + " throw new IllegalArgumentException();", + " }", + " List l = new ArrayList<>();", + " l.removeIf(v -> this.f.toString().equals(v.toString()));", + " }", + "}") + .doTest(); + } + + @Test + public void streamMethods() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.List;", + "import java.util.ArrayList;", + "import org.jspecify.annotations.Nullable;", + "public class Test {", + " private @Nullable Object f;", + " public void test1() {", + " if (this.f == null) {", + " throw new IllegalArgumentException();", + " }", + " List l = new ArrayList<>();", + " // this.f being non-null gets propagated to all callback lambdas", + " l.stream().filter(v -> this.f.toString().equals(v.toString()))", + " .map(v -> this.f.toString())", + " .forEach(v -> System.out.println(this.f.hashCode() + v.toString()));", + " }", + "}") + .doTest(); + } +} diff --git a/nullaway/src/test/resources/com/uber/nullaway/testdata/NullAwayStreamSupportPositiveCases.java b/nullaway/src/test/resources/com/uber/nullaway/testdata/NullAwayStreamSupportPositiveCases.java index 82a12a00b9..4427260539 100644 --- a/nullaway/src/test/resources/com/uber/nullaway/testdata/NullAwayStreamSupportPositiveCases.java +++ b/nullaway/src/test/resources/com/uber/nullaway/testdata/NullAwayStreamSupportPositiveCases.java @@ -191,7 +191,8 @@ private Stream test1(Stream stream) { private Stream test2(Stream stream) { Preconditions.checkNotNull(ref); - // BUG: Diagnostic contains: dereferenced expression ref is @Nullable + // no error since we propagate nullability facts to stream callbacks, which + // in sane code are invoked soon after the stream is created return stream.filter(s -> ref.equals(s)); } }