Skip to content

Commit

Permalink
Fix handling of default cases in arrow switches
Browse files Browse the repository at this point in the history
And consolidate some reflective workarounds for switch API changes.

Fixes #4266

PiperOrigin-RevId: 609484284
  • Loading branch information
cushon authored and Error Prone Team committed Feb 22, 2024
1 parent e3725d2 commit f768b0b
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 53 deletions.
76 changes: 76 additions & 0 deletions check_api/src/main/java/com/google/errorprone/util/ASTHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Streams.stream;
import static com.google.errorprone.matchers.JUnitMatchers.JUNIT4_RUN_WITH_ANNOTATION;
import static com.google.errorprone.matchers.Matchers.isSubtypeOf;
Expand Down Expand Up @@ -2754,5 +2755,80 @@ private static boolean hasMatchingMethods(
return false;
}

private static final Method CASE_TREE_GET_LABELS = getCaseTreeGetLabelsMethod();

@Nullable
private static Method getCaseTreeGetLabelsMethod() {
try {
return CaseTree.class.getMethod("getLabels");
} catch (NoSuchMethodException e) {
return null;
}
}

@SuppressWarnings("unchecked") // reflection
private static List<? extends Tree> getCaseLabels(CaseTree caseTree) {
if (CASE_TREE_GET_LABELS == null) {
return ImmutableList.of();
}
try {
return (List<? extends Tree>) CASE_TREE_GET_LABELS.invoke(caseTree);
} catch (ReflectiveOperationException e) {
throw new LinkageError(e.getMessage(), e);
}
}

// getExpression() is being used for compatibility with earlier JDK versions
@SuppressWarnings("deprecation")
public static Optional<? extends CaseTree> getSwitchDefault(SwitchTree switchTree) {
return switchTree.getCases().stream()
.filter(
(CaseTree c) -> {
if (c.getExpression() != null) {
return false;
}
List<? extends Tree> labels = getCaseLabels(c);
return labels.isEmpty()
|| (labels.size() == 1
&& getOnlyElement(labels).getKind().name().equals("DEFAULT_CASE_LABEL"));
})
.findFirst();
}

private static final Method CASE_TREE_GET_EXPRESSIONS = getCaseTreeGetExpressionsMethod();

@Nullable
private static Method getCaseTreeGetExpressionsMethod() {
try {
return CaseTree.class.getMethod("getExpressions");
} catch (NoSuchMethodException e) {
return null;
}
}

/**
* Retrieves a stream containing all case expressions, in order, for a given {@code CaseTree}.
* This method acts as a facade to the {@code CaseTree.getExpressions()} API, falling back to
* legacy APIs when necessary.
*/
@SuppressWarnings({
"deprecation", // getExpression() is being used for compatibility with earlier JDK versions
"unchecked", // reflection
})
public static Stream<? extends ExpressionTree> getCaseExpressions(CaseTree caseTree) {
if (!RuntimeVersion.isAtLeast12()) {
// "default" case gives an empty stream
return Stream.ofNullable(caseTree.getExpression());
}
if (CASE_TREE_GET_EXPRESSIONS == null) {
return Stream.empty();
}
try {
return ((List<? extends ExpressionTree>) CASE_TREE_GET_EXPRESSIONS.invoke(caseTree)).stream();
} catch (ReflectiveOperationException e) {
throw new LinkageError(e.getMessage(), e);
}
}

private ASTHelpers() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,11 @@
import com.google.errorprone.bugpatterns.BugChecker.SwitchTreeMatcher;
import com.google.errorprone.matchers.Description;
import com.google.errorprone.util.ASTHelpers;
import com.google.errorprone.util.RuntimeVersion;
import com.sun.source.tree.CaseTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.IdentifierTree;
import com.sun.source.tree.SwitchTree;
import com.sun.tools.javac.code.Type;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.lang.model.element.ElementKind;

/** A {@link BugChecker}; see the associated {@link BugPattern} annotation for details. */
Expand All @@ -58,7 +53,7 @@ public Description matchSwitch(SwitchTree tree, VisitorState state) {
}
ImmutableSet<String> handled =
tree.getCases().stream()
.flatMap(MissingCasesInEnumSwitch::getExpressions)
.flatMap(ASTHelpers::getCaseExpressions)
.filter(IdentifierTree.class::isInstance)
.map(e -> ((IdentifierTree) e).getName().toString())
.collect(toImmutableSet());
Expand Down Expand Up @@ -93,19 +88,4 @@ private static String buildMessage(Set<String> unhandled) {
}
return message.toString();
}

@SuppressWarnings("unchecked")
private static Stream<? extends ExpressionTree> getExpressions(CaseTree caseTree) {
try {
if (RuntimeVersion.isAtLeast12()) {
return ((List<? extends ExpressionTree>)
CaseTree.class.getMethod("getExpressions").invoke(caseTree))
.stream();
} else {
return Stream.of(caseTree.getExpression());
}
} catch (ReflectiveOperationException e) {
throw new LinkageError(e.getMessage(), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.collect.Iterables.getLast;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.util.ASTHelpers.getSwitchDefault;

import com.google.common.collect.Iterables;
import com.google.errorprone.BugPattern;
Expand Down Expand Up @@ -52,8 +53,7 @@ public Description matchSwitch(SwitchTree tree, VisitorState state) {
// by MissingCasesInEnumSwitch
return NO_MATCH;
}
Optional<? extends CaseTree> maybeDefault =
tree.getCases().stream().filter(c -> c.getExpression() == null).findFirst();
Optional<? extends CaseTree> maybeDefault = getSwitchDefault(tree);
if (!maybeDefault.isPresent()) {
Description.Builder description = buildDescription(tree);
if (!tree.getCases().isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.util.ASTHelpers.getCaseExpressions;
import static com.google.errorprone.util.ASTHelpers.getStartPosition;
import static com.google.errorprone.util.ASTHelpers.getSymbol;
import static com.sun.source.tree.Tree.Kind.BLOCK;
Expand All @@ -45,7 +46,6 @@
import com.google.errorprone.matchers.Matchers;
import com.google.errorprone.util.ASTHelpers;
import com.google.errorprone.util.Reachability;
import com.google.errorprone.util.RuntimeVersion;
import com.google.errorprone.util.SourceVersion;
import com.sun.source.tree.AssignmentTree;
import com.sun.source.tree.BlockTree;
Expand Down Expand Up @@ -77,7 +77,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import javax.inject.Inject;
import javax.lang.model.element.ElementKind;

Expand Down Expand Up @@ -206,11 +205,11 @@ private static AnalysisResult analyzeSwitchTree(SwitchTree switchTree, VisitorSt
// One-pass scan through each case in switch
for (int caseIndex = 0; caseIndex < cases.size(); caseIndex++) {
CaseTree caseTree = cases.get(caseIndex);
boolean isDefaultCase = (getExpressions(caseTree).count() == 0);
boolean isDefaultCase = (getCaseExpressions(caseTree).count() == 0);
hasDefaultCase |= isDefaultCase;
// Accumulate enum values included in this case
handledEnumValues.addAll(
getExpressions(caseTree)
getCaseExpressions(caseTree)
.filter(IdentifierTree.class::isInstance)
.map(expressionTree -> ((IdentifierTree) expressionTree).getName().toString())
.collect(toImmutableSet()));
Expand Down Expand Up @@ -954,30 +953,7 @@ private static String removeFallThruLines(String comments) {

/** Prints source for all expressions in a given {@code case}, separated by commas. */
private static String printCaseExpressions(CaseTree caseTree, VisitorState state) {
return getExpressions(caseTree).map(state::getSourceForNode).collect(joining(", "));
}

/**
* Retrieves a stream containing all case expressions, in order, for a given {@code CaseTree}.
* This method acts as a facade to the {@code CaseTree.getExpressions()} API, falling back to
* legacy APIs when necessary.
*/
@SuppressWarnings("unchecked")
private static Stream<? extends ExpressionTree> getExpressions(CaseTree caseTree) {
try {
if (RuntimeVersion.isAtLeast12()) {
return ((List<? extends ExpressionTree>)
CaseTree.class.getMethod("getExpressions").invoke(caseTree))
.stream();
} else {
// "default" case gives an empty stream
return caseTree.getExpression() == null
? Stream.empty()
: Stream.of(caseTree.getExpression());
}
} catch (ReflectiveOperationException e) {
throw new LinkageError(e.getMessage(), e);
}
return getCaseExpressions(caseTree).map(state::getSourceForNode).collect(joining(", "));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.google.errorprone.BugPattern.SeverityLevel.SUGGESTION;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.util.ASTHelpers.getStartPosition;
import static com.google.errorprone.util.ASTHelpers.getSwitchDefault;

import com.google.errorprone.BugPattern;
import com.google.errorprone.VisitorState;
Expand All @@ -44,8 +45,7 @@ public class SwitchDefault extends BugChecker implements SwitchTreeMatcher {

@Override
public Description matchSwitch(SwitchTree tree, VisitorState state) {
Optional<? extends CaseTree> maybeDefault =
tree.getCases().stream().filter(c -> c.getExpression() == null).findAny();
Optional<? extends CaseTree> maybeDefault = getSwitchDefault(tree);
if (!maybeDefault.isPresent()) {
return NO_MATCH;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,27 @@ public void newNotation_changeOrder() {
"}")
.doTest();
}

@Test
public void arrowSwitch_noDefault() {
assumeTrue(RuntimeVersion.isAtLeast21());
compilationHelper
.addSourceLines(
"Foo.java", //
"sealed interface Foo {",
" final class Bar implements Foo {}",
" final class Baz implements Foo {}",
"}")
.addSourceLines(
"Test.java",
"class Test {",
" void f(Foo i) {",
" switch (i) {",
" case Foo.Bar bar -> {}",
" case Foo.Baz baz -> {}",
" }",
" }",
"}")
.doTest();
}
}

0 comments on commit f768b0b

Please sign in to comment.