From b73fed384d4e427d0641ac06e0d0f0d640c3d79e Mon Sep 17 00:00:00 2001 From: Suzanne Millstein Date: Thu, 16 Dec 2021 09:32:38 -0800 Subject: [PATCH] Compute the type of switch expressions and check them. (#4978) --- .../java17/SwitchExpressionInvariant.java | 27 +++ docs/CHANGELOG.md | 2 + .../common/basetype/BaseTypeVisitor.java | 33 ++++ .../common/basetype/messages.properties | 1 + .../type/TypeFromExpressionVisitor.java | 27 ++- .../value/java17/SwitchExpressionTyping.java | 100 ++++++++++ .../javacutil/SwitchExpressionScanner.java | 174 ++++++++++++++++++ .../checkerframework/javacutil/TreeUtils.java | 91 +++++++++ 8 files changed, 449 insertions(+), 6 deletions(-) create mode 100644 checker/tests/nullness/java17/SwitchExpressionInvariant.java create mode 100644 framework/tests/value/java17/SwitchExpressionTyping.java create mode 100644 javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java diff --git a/checker/tests/nullness/java17/SwitchExpressionInvariant.java b/checker/tests/nullness/java17/SwitchExpressionInvariant.java new file mode 100644 index 00000000000..6be0fa35fdc --- /dev/null +++ b/checker/tests/nullness/java17/SwitchExpressionInvariant.java @@ -0,0 +1,27 @@ +// @below-java17-jdk-skip-test +import java.util.List; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class SwitchExpressionInvariant { + public static boolean flag = false; + + void method( + List<@NonNull String> nonnullStrings, List<@Nullable String> nullableStrings, int fenum) { + + List<@NonNull String> list = + // :: error: (assignment) + switch (fenum) { + // :: error: (switch.expression) + case 1 -> nonnullStrings; + default -> nullableStrings; + }; + + List<@Nullable String> list2 = + switch (fenum) { + // :: error: (switch.expression) + case 1 -> nonnullStrings; + default -> nullableStrings; + }; + } +} diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 69bbb91f020..243363b8038 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -3,6 +3,8 @@ Version 3.21.0 (December 17, 2021) **User-visible changes:** +The Checker Framework now more precisely computes the type of a switch expression. + **Implementation details:** **Closed issues:** diff --git a/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java b/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java index bc066ef4962..4be4866a4c3 100644 --- a/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java +++ b/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java @@ -138,6 +138,8 @@ import org.checkerframework.javacutil.BugInCF; import org.checkerframework.javacutil.ElementUtils; import org.checkerframework.javacutil.Pair; +import org.checkerframework.javacutil.SwitchExpressionScanner; +import org.checkerframework.javacutil.SwitchExpressionScanner.FunctionalSwitchExpressionScanner; import org.checkerframework.javacutil.TreePathUtil; import org.checkerframework.javacutil.TreeUtils; import org.checkerframework.javacutil.TypesUtils; @@ -342,6 +344,10 @@ public Void scan(@Nullable Tree tree, Void p) { if (tree != null && getCurrentPath() != null) { this.atypeFactory.setVisitorTreePath(new TreePath(getCurrentPath(), tree)); } + if (tree != null && tree.getKind().name().equals("SWITCH_EXPRESSION")) { + visitSwitchExpression17(tree); + return null; + } return super.scan(tree, p); } @@ -2117,6 +2123,33 @@ public Void visitConditionalExpression(ConditionalExpressionTree node, Void p) { return super.visitConditionalExpression(node, p); } + /** + * This method validates the type of the switch expression. It issues an error if the type of a + * value that the switch expression can result is not a subtype of the switch type. + * + *

If a subclass overrides this method, it must call {@code super.scan(switchExpressionTree, + * null)} so that the blocks and statements in the cases are checked. + * + * @param switchExpressionTree a {@code SwitchExpressionTree} + */ + public void visitSwitchExpression17(Tree switchExpressionTree) { + boolean valid = validateTypeOf(switchExpressionTree); + if (valid) { + AnnotatedTypeMirror switchType = atypeFactory.getAnnotatedType(switchExpressionTree); + SwitchExpressionScanner scanner = + new FunctionalSwitchExpressionScanner<>( + (ExpressionTree valueTree, Void unused) -> { + BaseTypeVisitor.this.commonAssignmentCheck( + switchType, valueTree, "switch.expression"); + return null; + }, + (r1, r2) -> null); + + scanner.scanSwitchExpression(switchExpressionTree, null); + } + super.scan(switchExpressionTree, null); + } + // ********************************************************************** // Check for illegal re-assignment // ********************************************************************** diff --git a/framework/src/main/java/org/checkerframework/common/basetype/messages.properties b/framework/src/main/java/org/checkerframework/common/basetype/messages.properties index 8c726b6acd8..38611cce681 100644 --- a/framework/src/main/java/org/checkerframework/common/basetype/messages.properties +++ b/framework/src/main/java/org/checkerframework/common/basetype/messages.properties @@ -13,6 +13,7 @@ vector.copyinto=incompatible component type in Vector.copyinto.%nfound : %s%nr return=incompatible types in return.%ntype of expression: %s%nmethod return type: %s annotation=incompatible types in annotation.%nfound : %s%nrequired: %s conditional=incompatible types in conditional expression.%nfound : %s%nrequired: %s +switch.expression=incompatible types in switch expression.%nfound : %s%nrequired: %s type.argument=incompatible type argument for type parameter %s of %s.%nfound : %s%nrequired: %s argument=incompatible argument for parameter %s of %s.%nfound : %s%nrequired: %s varargs=incompatible types in varargs.%nfound : %s%nrequired: %s diff --git a/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java b/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java index 59e3a6d49e5..493f1dd23cd 100644 --- a/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java +++ b/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java @@ -36,6 +36,8 @@ import org.checkerframework.framework.util.AnnotatedTypes; import org.checkerframework.javacutil.BugInCF; import org.checkerframework.javacutil.ElementUtils; +import org.checkerframework.javacutil.SwitchExpressionScanner; +import org.checkerframework.javacutil.SwitchExpressionScanner.FunctionalSwitchExpressionScanner; import org.checkerframework.javacutil.TreeUtils; import org.checkerframework.javacutil.TypesUtils; @@ -176,16 +178,29 @@ public AnnotatedTypeMirror defaultAction(Tree tree, AnnotatedTypeFactory f) { /** * Compute the type of the switch expression tree. * - * @param switchExpressionTree SwitchExpressionTree; typed as Tree to be backward-compatible - * @param f AnnotatedTypeFactory + * @param switchExpressionTree a SwitchExpressionTree; typed as Tree so method signature is + * backward-compatible + * @param f an AnnotatedTypeFactory * @return the type of the switch expression */ public AnnotatedTypeMirror visitSwitchExpressionTree17( Tree switchExpressionTree, AnnotatedTypeFactory f) { - // TODO: Properly compute the type from the cases. - AnnotatedTypeMirror result = f.type(switchExpressionTree); - result.addAnnotations(f.getQualifierHierarchy().getTopAnnotations()); - return result; + TypeMirror switchTypeMirror = TreeUtils.typeOf(switchExpressionTree); + SwitchExpressionScanner luber = + new FunctionalSwitchExpressionScanner<>( + // Function applied to each result expression of the switch expression. + (valueTree, unused) -> f.getAnnotatedType(valueTree), + // Function used to combine the types of each result expression. + (type1, type2) -> { + if (type1 == null) { + return type2; + } else if (type2 == null) { + return type1; + } else { + return AnnotatedTypes.leastUpperBound(f, type1, type2, switchTypeMirror); + } + }); + return luber.scanSwitchExpression(switchExpressionTree, null); } @Override diff --git a/framework/tests/value/java17/SwitchExpressionTyping.java b/framework/tests/value/java17/SwitchExpressionTyping.java new file mode 100644 index 00000000000..94860ed4264 --- /dev/null +++ b/framework/tests/value/java17/SwitchExpressionTyping.java @@ -0,0 +1,100 @@ +// @below-java17-jdk-skip-test +import org.checkerframework.common.value.qual.IntVal; + +public class SwitchExpressionTyping { + public static boolean flag = false; + + void method0(String s) { + @IntVal({0, 1, 2, 3}) int o = + switch (s) { + case "Hello?" -> { + throw new RuntimeException(); + } + case "Hello" -> 0; + case "Bye" -> 1; + case "Later" -> 2; + case "What?" -> throw new RuntimeException(); + default -> 3; + }; + } + + void method1(String s) { + @IntVal({1, 2, 3}) int o = + switch (s) { + case "Hello?" -> 1; + case "Hello" -> 1; + case "Bye" -> 1; + case "Later" -> 1; + case "What?" -> { + if (flag) { + yield 2; + } + yield 3; + } + default -> 1; + }; + + @IntVal(1) int o2 = + // :: error: (assignment) + switch (s) { + case "Hello?" -> 1; + case "Hello" -> 1; + case "Bye" -> 1; + case "Later" -> 1; + case "What?" -> { + if (flag) { + yield 2; + } + yield 3; + } + default -> 1; + }; + } + + void method2(String s, String r) { + @IntVal({0, 1, 2, 3}) int o = + switch (s) { + case "Hello?" -> { + if (flag) { + throw new RuntimeException(); + } + yield 2; + } + case "Hello" -> { + int i = + switch (r) { + case "Hello" -> 4; + case "Bye" -> 5; + case "Later" -> 6; + default -> 42; + }; + yield 0; + } + case "Bye" -> 1; + case "Later" -> { + int i = + switch (r) { + case "Hello": + { + yield 4; + } + case "Bye": + { + yield 5; + } + case "Later": + { + yield 6; + } + default: + { + yield 42; + } + }; + yield 2; + } + case "What?" -> throw new RuntimeException(); + default -> 3; + }; + } +} diff --git a/javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java b/javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java new file mode 100644 index 00000000000..6bd94ff50a0 --- /dev/null +++ b/javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java @@ -0,0 +1,174 @@ +package org.checkerframework.javacutil; + +import com.sun.source.tree.BlockTree; +import com.sun.source.tree.CaseTree; +import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.Tree; +import com.sun.source.tree.Tree.Kind; +import com.sun.source.util.TreeScanner; +import java.util.List; +import java.util.function.BiFunction; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A class that visits each result expression of a switch expression and calls {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} on each result expression. The results of + * these method calls are combined using {@link #combineResults(Object, Object)}. Call {@link + * #scanSwitchExpression(Tree, Object)} to start scanning the switch expression. + * + *

{@link FunctionalSwitchExpressionScanner} can be used to pass functions for to use for {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} and {@link #combineResults(Object, + * Object)}. + * + * @param the type of the result of {@link #visitSwitchResultExpression(ExpressionTree, Object)} + * @param

the type of the parameter to pass to {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} + */ +public abstract class SwitchExpressionScanner extends TreeScanner { + + /** + * This method is called for each result expression of the switch expression passed in {@link + * #scanSwitchExpression(Tree, Object)}. + * + * @param resultExpressionTree a result expression of the switch expression currently being + * scanned + * @param p a parameter + * @return the result of visiting the result expression + */ + protected abstract R visitSwitchResultExpression(ExpressionTree resultExpressionTree, P p); + + /** + * This method combines the result of two calls to {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} or {@code null} and the result of one + * call to {@link #visitSwitchResultExpression(ExpressionTree, Object)}. + * + * @param r1 a possibly null result returned by {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} + * @param r2 a possibly null result returned by {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} + * @return the combination of {@code r1} and {@code r2} + */ + protected abstract R combineResults(@Nullable R r1, @Nullable R r2); + + /** + * Scans the given switch expression and calls {@link #visitSwitchResultExpression(ExpressionTree, + * Object)} on each result expression of the switch expression. {@link #combineResults(Object, + * Object)} is called to combine the results of visiting multiple switch result expressions. + * + * @param switchExpression a switch expression tree + * @param p the parameter to pass to {@link #visitSwitchResultExpression(ExpressionTree, Object)} + * @return the result of calling {@link #visitSwitchResultExpression(ExpressionTree, Object)} on + * each result expression of {@code switchExpression} and combining the results using {@link + * #combineResults(Object, Object)} + */ + public R scanSwitchExpression(Tree switchExpression, P p) { + assert switchExpression.getKind().name().equals("SWITCH_EXPRESSION"); + List caseTrees = TreeUtils.switchExpressionTreeGetCases(switchExpression); + R result = null; + for (CaseTree caseTree : caseTrees) { + if (caseTree.getStatements() != null) { + // This case is a switch labeled statement group, so scan the statements for yield + // statements. + result = combineResults(result, yieldVisitor.scan(caseTree.getStatements(), p)); + } else { + @SuppressWarnings( + "nullness:assignment") // caseTree.getStatements() == null, so the case has a body. + @NonNull Tree body = TreeUtils.caseTreeGetBody(caseTree); + // This case is a switch rule, so its body is either an expression, block, or throw. + // See https://docs.oracle.com/javase/specs/jls/se17/html/jls-15.html#jls-15.28.2. + if (body.getKind() == Kind.BLOCK) { + // Scan for yield statements. + result = combineResults(result, yieldVisitor.scan(((BlockTree) body).getStatements(), p)); + } else if (body.getKind() != Kind.THROW) { + // The expression is the result expression. + ExpressionTree expressionTree = (ExpressionTree) body; + result = combineResults(result, visitSwitchResultExpression(expressionTree, p)); + } + } + } + @SuppressWarnings( + "nullness:assignment" // switch expressions must have at least one case that results in a + // value, so {@code result} must be nonnull. + ) + @NonNull R nonNullResult = result; + return nonNullResult; + } + + /** + * A scanner that visits all the yield trees in a given tree and calls {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} on the expression in the yield trees. It + * does not descend into switch expressions. + */ + protected YieldVisitor yieldVisitor = new YieldVisitor(); + + /** + * A scanner that visits all the yield trees in a given tree and calls {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} on the expression in the yield trees. It + * does not descend into switch expressions. + */ + protected class YieldVisitor extends TreeScanner<@Nullable R, P> { + + @Override + public @Nullable R scan(Tree tree, P p) { + if (tree == null) { + return null; + } + if (tree.getKind().name().equals("SWITCH_EXPRESSION")) { + // Don't scan nested switch expressions. + return null; + } else if (tree.getKind().name().equals("YIELD")) { + ExpressionTree value = TreeUtils.yieldTreeGetValue(tree); + return visitSwitchResultExpression(value, p); + } + return super.scan(tree, p); + } + + @Override + public R reduce(R r1, R r2) { + return combineResults(r1, r2); + } + } + + /** + * An implementation of {@link SwitchExpressionScanner} that uses functions passed to the + * constructor for {@link #visitSwitchResultExpression(ExpressionTree, Object)} and {@link + * #combineResults(Object, Object)}. + * + * @param the type result of {@link #visitSwitchResultExpression(ExpressionTree, Object)} + * @param the type of the parameter to pass to {@link + * #visitSwitchResultExpression(ExpressionTree, Object)} + */ + public static class FunctionalSwitchExpressionScanner + extends SwitchExpressionScanner { + + /** The function to use for {@link #visitSwitchResultExpression(ExpressionTree, Object)}. */ + private final BiFunction switchValueExpressionFunction; + /** The function to use for {@link #visitSwitchResultExpression(ExpressionTree, Object)}. */ + private final BiFunction<@Nullable R1, @Nullable R1, R1> combineResultFunc; + + /** + * Creates a {@link FunctionalSwitchExpressionScanner} that uses the given functions. + * + * @param switchValueExpressionFunc the function called on each switch result expression + * @param combineResultFunc the function used to combine the result of multiple calls to {@code + * switchValueExpressionFunc} + */ + public FunctionalSwitchExpressionScanner( + BiFunction switchValueExpressionFunc, + BiFunction<@Nullable R1, @Nullable R1, R1> combineResultFunc) { + this.switchValueExpressionFunction = switchValueExpressionFunc; + this.combineResultFunc = combineResultFunc; + } + + @Override + protected R1 visitSwitchResultExpression(ExpressionTree resultExpressionTree, P1 p1) { + return switchValueExpressionFunction.apply(resultExpressionTree, p1); + } + + @Override + protected R1 combineResults(@Nullable R1 r1, @Nullable R1 r2) { + return combineResultFunc.apply(r1, r2); + } + } +} diff --git a/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java b/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java index 92a16c897a2..8fa37d4a7dc 100644 --- a/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java +++ b/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java @@ -1704,6 +1704,97 @@ public static List caseTreeGetExpressions(CaseTree cas } } + /** + * Returns the selector expression of {@code switchExpressionTree}. For example + * + *

+   *   switch ( expression ) { ... }
+   * 
+ * + * @param switchExpressionTree the switch expression whose selector expression is returned + * @return the selector expression of {@code switchExpressionTree} + */ + public static ExpressionTree switchExpressionTreeGetExpression(Tree switchExpressionTree) { + try { + Class switchExpressionClass = Class.forName("com.sun.source.tree.SwitchExpressionTree"); + Method getExpressionMethod = switchExpressionClass.getMethod("getExpression"); + ExpressionTree expressionTree = + (ExpressionTree) getExpressionMethod.invoke(switchExpressionTree); + if (expressionTree != null) { + return expressionTree; + } + throw new BugInCF( + "TreeUtils.switchExpressionTreeGetExpression: expression is null for tree: %s", + switchExpressionTree); + } catch (ClassNotFoundException + | NoSuchMethodException + | InvocationTargetException + | IllegalAccessException e) { + throw new BugInCF( + "TreeUtils.switchExpressionTreeGetExpression: reflection failed for tree: %s", + switchExpressionTree, e); + } + } + + /** + * Returns the cases of {@code switchExpressionTree}. For example + * + *
+   *   switch ( expression ) {
+   *     cases
+   *   }
+   * 
+ * + * @param switchExpressionTree the switch expression whose cases are returned + * @return the cases of {@code switchExpressionTree} + */ + public static List switchExpressionTreeGetCases(Tree switchExpressionTree) { + try { + Class switchExpressionClass = Class.forName("com.sun.source.tree.SwitchExpressionTree"); + Method getCasesMethod = switchExpressionClass.getMethod("getCases"); + @SuppressWarnings("unchecked") + List cases = + (List) getCasesMethod.invoke(switchExpressionTree); + if (cases != null) { + return cases; + } + throw new BugInCF( + "TreeUtils.switchExpressionTreeGetCases: cases is null for tree: %s", + switchExpressionTree); + } catch (ClassNotFoundException + | NoSuchMethodException + | InvocationTargetException + | IllegalAccessException e) { + throw new BugInCF( + "TreeUtils.switchExpressionTreeGetCases: reflection failed for tree: %s", + switchExpressionTree, e); + } + } + + /** + * Returns the value (expression) for {@code yieldTree}. + * + * @param yieldTree the yield tree + * @return the value (expression) for {@code yieldTree}. + */ + public static ExpressionTree yieldTreeGetValue(Tree yieldTree) { + try { + Class yieldTreeClass = Class.forName("com.sun.source.tree.YieldTree"); + Method getCasesMethod = yieldTreeClass.getMethod("getValue"); + ExpressionTree expressionTree = (ExpressionTree) getCasesMethod.invoke(yieldTree); + if (expressionTree != null) { + return expressionTree; + } + throw new BugInCF("TreeUtils.yieldTreeGetValue: expression is null for tree: %s", yieldTree); + } catch (ClassNotFoundException + | NoSuchMethodException + | InvocationTargetException + | IllegalAccessException e) { + throw new BugInCF( + "TreeUtils.yieldTreeGetValue: reflection failed for tree: %s", yieldTree, e); + } + } + /** * Returns true if the given method/constructor invocation is a varargs invocation. *