diff --git a/checker/tests/nullness/java17/SwitchExpressionInvariant.java b/checker/tests/nullness/java17/SwitchExpressionInvariant.java
new file mode 100644
index 00000000000..6be0fa35fdc
--- /dev/null
+++ b/checker/tests/nullness/java17/SwitchExpressionInvariant.java
@@ -0,0 +1,27 @@
+// @below-java17-jdk-skip-test
+import java.util.List;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+public class SwitchExpressionInvariant {
+ public static boolean flag = false;
+
+ void method(
+ List<@NonNull String> nonnullStrings, List<@Nullable String> nullableStrings, int fenum) {
+
+ List<@NonNull String> list =
+ // :: error: (assignment)
+ switch (fenum) {
+ // :: error: (switch.expression)
+ case 1 -> nonnullStrings;
+ default -> nullableStrings;
+ };
+
+ List<@Nullable String> list2 =
+ switch (fenum) {
+ // :: error: (switch.expression)
+ case 1 -> nonnullStrings;
+ default -> nullableStrings;
+ };
+ }
+}
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index 69bbb91f020..243363b8038 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -3,6 +3,8 @@ Version 3.21.0 (December 17, 2021)
**User-visible changes:**
+The Checker Framework now more precisely computes the type of a switch expression.
+
**Implementation details:**
**Closed issues:**
diff --git a/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java b/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java
index bc066ef4962..4be4866a4c3 100644
--- a/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java
+++ b/framework/src/main/java/org/checkerframework/common/basetype/BaseTypeVisitor.java
@@ -138,6 +138,8 @@
import org.checkerframework.javacutil.BugInCF;
import org.checkerframework.javacutil.ElementUtils;
import org.checkerframework.javacutil.Pair;
+import org.checkerframework.javacutil.SwitchExpressionScanner;
+import org.checkerframework.javacutil.SwitchExpressionScanner.FunctionalSwitchExpressionScanner;
import org.checkerframework.javacutil.TreePathUtil;
import org.checkerframework.javacutil.TreeUtils;
import org.checkerframework.javacutil.TypesUtils;
@@ -342,6 +344,10 @@ public Void scan(@Nullable Tree tree, Void p) {
if (tree != null && getCurrentPath() != null) {
this.atypeFactory.setVisitorTreePath(new TreePath(getCurrentPath(), tree));
}
+ if (tree != null && tree.getKind().name().equals("SWITCH_EXPRESSION")) {
+ visitSwitchExpression17(tree);
+ return null;
+ }
return super.scan(tree, p);
}
@@ -2117,6 +2123,33 @@ public Void visitConditionalExpression(ConditionalExpressionTree node, Void p) {
return super.visitConditionalExpression(node, p);
}
+ /**
+ * This method validates the type of the switch expression. It issues an error if the type of a
+ * value that the switch expression can result is not a subtype of the switch type.
+ *
+ *
If a subclass overrides this method, it must call {@code super.scan(switchExpressionTree,
+ * null)} so that the blocks and statements in the cases are checked.
+ *
+ * @param switchExpressionTree a {@code SwitchExpressionTree}
+ */
+ public void visitSwitchExpression17(Tree switchExpressionTree) {
+ boolean valid = validateTypeOf(switchExpressionTree);
+ if (valid) {
+ AnnotatedTypeMirror switchType = atypeFactory.getAnnotatedType(switchExpressionTree);
+ SwitchExpressionScanner scanner =
+ new FunctionalSwitchExpressionScanner<>(
+ (ExpressionTree valueTree, Void unused) -> {
+ BaseTypeVisitor.this.commonAssignmentCheck(
+ switchType, valueTree, "switch.expression");
+ return null;
+ },
+ (r1, r2) -> null);
+
+ scanner.scanSwitchExpression(switchExpressionTree, null);
+ }
+ super.scan(switchExpressionTree, null);
+ }
+
// **********************************************************************
// Check for illegal re-assignment
// **********************************************************************
diff --git a/framework/src/main/java/org/checkerframework/common/basetype/messages.properties b/framework/src/main/java/org/checkerframework/common/basetype/messages.properties
index 8c726b6acd8..38611cce681 100644
--- a/framework/src/main/java/org/checkerframework/common/basetype/messages.properties
+++ b/framework/src/main/java/org/checkerframework/common/basetype/messages.properties
@@ -13,6 +13,7 @@ vector.copyinto=incompatible component type in Vector.copyinto.%nfound : %s%nr
return=incompatible types in return.%ntype of expression: %s%nmethod return type: %s
annotation=incompatible types in annotation.%nfound : %s%nrequired: %s
conditional=incompatible types in conditional expression.%nfound : %s%nrequired: %s
+switch.expression=incompatible types in switch expression.%nfound : %s%nrequired: %s
type.argument=incompatible type argument for type parameter %s of %s.%nfound : %s%nrequired: %s
argument=incompatible argument for parameter %s of %s.%nfound : %s%nrequired: %s
varargs=incompatible types in varargs.%nfound : %s%nrequired: %s
diff --git a/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java b/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java
index 59e3a6d49e5..493f1dd23cd 100644
--- a/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java
+++ b/framework/src/main/java/org/checkerframework/framework/type/TypeFromExpressionVisitor.java
@@ -36,6 +36,8 @@
import org.checkerframework.framework.util.AnnotatedTypes;
import org.checkerframework.javacutil.BugInCF;
import org.checkerframework.javacutil.ElementUtils;
+import org.checkerframework.javacutil.SwitchExpressionScanner;
+import org.checkerframework.javacutil.SwitchExpressionScanner.FunctionalSwitchExpressionScanner;
import org.checkerframework.javacutil.TreeUtils;
import org.checkerframework.javacutil.TypesUtils;
@@ -176,16 +178,29 @@ public AnnotatedTypeMirror defaultAction(Tree tree, AnnotatedTypeFactory f) {
/**
* Compute the type of the switch expression tree.
*
- * @param switchExpressionTree SwitchExpressionTree; typed as Tree to be backward-compatible
- * @param f AnnotatedTypeFactory
+ * @param switchExpressionTree a SwitchExpressionTree; typed as Tree so method signature is
+ * backward-compatible
+ * @param f an AnnotatedTypeFactory
* @return the type of the switch expression
*/
public AnnotatedTypeMirror visitSwitchExpressionTree17(
Tree switchExpressionTree, AnnotatedTypeFactory f) {
- // TODO: Properly compute the type from the cases.
- AnnotatedTypeMirror result = f.type(switchExpressionTree);
- result.addAnnotations(f.getQualifierHierarchy().getTopAnnotations());
- return result;
+ TypeMirror switchTypeMirror = TreeUtils.typeOf(switchExpressionTree);
+ SwitchExpressionScanner luber =
+ new FunctionalSwitchExpressionScanner<>(
+ // Function applied to each result expression of the switch expression.
+ (valueTree, unused) -> f.getAnnotatedType(valueTree),
+ // Function used to combine the types of each result expression.
+ (type1, type2) -> {
+ if (type1 == null) {
+ return type2;
+ } else if (type2 == null) {
+ return type1;
+ } else {
+ return AnnotatedTypes.leastUpperBound(f, type1, type2, switchTypeMirror);
+ }
+ });
+ return luber.scanSwitchExpression(switchExpressionTree, null);
}
@Override
diff --git a/framework/tests/value/java17/SwitchExpressionTyping.java b/framework/tests/value/java17/SwitchExpressionTyping.java
new file mode 100644
index 00000000000..94860ed4264
--- /dev/null
+++ b/framework/tests/value/java17/SwitchExpressionTyping.java
@@ -0,0 +1,100 @@
+// @below-java17-jdk-skip-test
+import org.checkerframework.common.value.qual.IntVal;
+
+public class SwitchExpressionTyping {
+ public static boolean flag = false;
+
+ void method0(String s) {
+ @IntVal({0, 1, 2, 3}) int o =
+ switch (s) {
+ case "Hello?" -> {
+ throw new RuntimeException();
+ }
+ case "Hello" -> 0;
+ case "Bye" -> 1;
+ case "Later" -> 2;
+ case "What?" -> throw new RuntimeException();
+ default -> 3;
+ };
+ }
+
+ void method1(String s) {
+ @IntVal({1, 2, 3}) int o =
+ switch (s) {
+ case "Hello?" -> 1;
+ case "Hello" -> 1;
+ case "Bye" -> 1;
+ case "Later" -> 1;
+ case "What?" -> {
+ if (flag) {
+ yield 2;
+ }
+ yield 3;
+ }
+ default -> 1;
+ };
+
+ @IntVal(1) int o2 =
+ // :: error: (assignment)
+ switch (s) {
+ case "Hello?" -> 1;
+ case "Hello" -> 1;
+ case "Bye" -> 1;
+ case "Later" -> 1;
+ case "What?" -> {
+ if (flag) {
+ yield 2;
+ }
+ yield 3;
+ }
+ default -> 1;
+ };
+ }
+
+ void method2(String s, String r) {
+ @IntVal({0, 1, 2, 3}) int o =
+ switch (s) {
+ case "Hello?" -> {
+ if (flag) {
+ throw new RuntimeException();
+ }
+ yield 2;
+ }
+ case "Hello" -> {
+ int i =
+ switch (r) {
+ case "Hello" -> 4;
+ case "Bye" -> 5;
+ case "Later" -> 6;
+ default -> 42;
+ };
+ yield 0;
+ }
+ case "Bye" -> 1;
+ case "Later" -> {
+ int i =
+ switch (r) {
+ case "Hello":
+ {
+ yield 4;
+ }
+ case "Bye":
+ {
+ yield 5;
+ }
+ case "Later":
+ {
+ yield 6;
+ }
+ default:
+ {
+ yield 42;
+ }
+ };
+ yield 2;
+ }
+ case "What?" -> throw new RuntimeException();
+ default -> 3;
+ };
+ }
+}
diff --git a/javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java b/javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java
new file mode 100644
index 00000000000..6bd94ff50a0
--- /dev/null
+++ b/javacutil/src/main/java/org/checkerframework/javacutil/SwitchExpressionScanner.java
@@ -0,0 +1,174 @@
+package org.checkerframework.javacutil;
+
+import com.sun.source.tree.BlockTree;
+import com.sun.source.tree.CaseTree;
+import com.sun.source.tree.ExpressionTree;
+import com.sun.source.tree.Tree;
+import com.sun.source.tree.Tree.Kind;
+import com.sun.source.util.TreeScanner;
+import java.util.List;
+import java.util.function.BiFunction;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * A class that visits each result expression of a switch expression and calls {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)} on each result expression. The results of
+ * these method calls are combined using {@link #combineResults(Object, Object)}. Call {@link
+ * #scanSwitchExpression(Tree, Object)} to start scanning the switch expression.
+ *
+ * {@link FunctionalSwitchExpressionScanner} can be used to pass functions for to use for {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)} and {@link #combineResults(Object,
+ * Object)}.
+ *
+ * @param the type of the result of {@link #visitSwitchResultExpression(ExpressionTree, Object)}
+ * @param the type of the parameter to pass to {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)}
+ */
+public abstract class SwitchExpressionScanner extends TreeScanner {
+
+ /**
+ * This method is called for each result expression of the switch expression passed in {@link
+ * #scanSwitchExpression(Tree, Object)}.
+ *
+ * @param resultExpressionTree a result expression of the switch expression currently being
+ * scanned
+ * @param p a parameter
+ * @return the result of visiting the result expression
+ */
+ protected abstract R visitSwitchResultExpression(ExpressionTree resultExpressionTree, P p);
+
+ /**
+ * This method combines the result of two calls to {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)} or {@code null} and the result of one
+ * call to {@link #visitSwitchResultExpression(ExpressionTree, Object)}.
+ *
+ * @param r1 a possibly null result returned by {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)}
+ * @param r2 a possibly null result returned by {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)}
+ * @return the combination of {@code r1} and {@code r2}
+ */
+ protected abstract R combineResults(@Nullable R r1, @Nullable R r2);
+
+ /**
+ * Scans the given switch expression and calls {@link #visitSwitchResultExpression(ExpressionTree,
+ * Object)} on each result expression of the switch expression. {@link #combineResults(Object,
+ * Object)} is called to combine the results of visiting multiple switch result expressions.
+ *
+ * @param switchExpression a switch expression tree
+ * @param p the parameter to pass to {@link #visitSwitchResultExpression(ExpressionTree, Object)}
+ * @return the result of calling {@link #visitSwitchResultExpression(ExpressionTree, Object)} on
+ * each result expression of {@code switchExpression} and combining the results using {@link
+ * #combineResults(Object, Object)}
+ */
+ public R scanSwitchExpression(Tree switchExpression, P p) {
+ assert switchExpression.getKind().name().equals("SWITCH_EXPRESSION");
+ List extends CaseTree> caseTrees = TreeUtils.switchExpressionTreeGetCases(switchExpression);
+ R result = null;
+ for (CaseTree caseTree : caseTrees) {
+ if (caseTree.getStatements() != null) {
+ // This case is a switch labeled statement group, so scan the statements for yield
+ // statements.
+ result = combineResults(result, yieldVisitor.scan(caseTree.getStatements(), p));
+ } else {
+ @SuppressWarnings(
+ "nullness:assignment") // caseTree.getStatements() == null, so the case has a body.
+ @NonNull Tree body = TreeUtils.caseTreeGetBody(caseTree);
+ // This case is a switch rule, so its body is either an expression, block, or throw.
+ // See https://docs.oracle.com/javase/specs/jls/se17/html/jls-15.html#jls-15.28.2.
+ if (body.getKind() == Kind.BLOCK) {
+ // Scan for yield statements.
+ result = combineResults(result, yieldVisitor.scan(((BlockTree) body).getStatements(), p));
+ } else if (body.getKind() != Kind.THROW) {
+ // The expression is the result expression.
+ ExpressionTree expressionTree = (ExpressionTree) body;
+ result = combineResults(result, visitSwitchResultExpression(expressionTree, p));
+ }
+ }
+ }
+ @SuppressWarnings(
+ "nullness:assignment" // switch expressions must have at least one case that results in a
+ // value, so {@code result} must be nonnull.
+ )
+ @NonNull R nonNullResult = result;
+ return nonNullResult;
+ }
+
+ /**
+ * A scanner that visits all the yield trees in a given tree and calls {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)} on the expression in the yield trees. It
+ * does not descend into switch expressions.
+ */
+ protected YieldVisitor yieldVisitor = new YieldVisitor();
+
+ /**
+ * A scanner that visits all the yield trees in a given tree and calls {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)} on the expression in the yield trees. It
+ * does not descend into switch expressions.
+ */
+ protected class YieldVisitor extends TreeScanner<@Nullable R, P> {
+
+ @Override
+ public @Nullable R scan(Tree tree, P p) {
+ if (tree == null) {
+ return null;
+ }
+ if (tree.getKind().name().equals("SWITCH_EXPRESSION")) {
+ // Don't scan nested switch expressions.
+ return null;
+ } else if (tree.getKind().name().equals("YIELD")) {
+ ExpressionTree value = TreeUtils.yieldTreeGetValue(tree);
+ return visitSwitchResultExpression(value, p);
+ }
+ return super.scan(tree, p);
+ }
+
+ @Override
+ public R reduce(R r1, R r2) {
+ return combineResults(r1, r2);
+ }
+ }
+
+ /**
+ * An implementation of {@link SwitchExpressionScanner} that uses functions passed to the
+ * constructor for {@link #visitSwitchResultExpression(ExpressionTree, Object)} and {@link
+ * #combineResults(Object, Object)}.
+ *
+ * @param the type result of {@link #visitSwitchResultExpression(ExpressionTree, Object)}
+ * @param the type of the parameter to pass to {@link
+ * #visitSwitchResultExpression(ExpressionTree, Object)}
+ */
+ public static class FunctionalSwitchExpressionScanner
+ extends SwitchExpressionScanner {
+
+ /** The function to use for {@link #visitSwitchResultExpression(ExpressionTree, Object)}. */
+ private final BiFunction switchValueExpressionFunction;
+ /** The function to use for {@link #visitSwitchResultExpression(ExpressionTree, Object)}. */
+ private final BiFunction<@Nullable R1, @Nullable R1, R1> combineResultFunc;
+
+ /**
+ * Creates a {@link FunctionalSwitchExpressionScanner} that uses the given functions.
+ *
+ * @param switchValueExpressionFunc the function called on each switch result expression
+ * @param combineResultFunc the function used to combine the result of multiple calls to {@code
+ * switchValueExpressionFunc}
+ */
+ public FunctionalSwitchExpressionScanner(
+ BiFunction switchValueExpressionFunc,
+ BiFunction<@Nullable R1, @Nullable R1, R1> combineResultFunc) {
+ this.switchValueExpressionFunction = switchValueExpressionFunc;
+ this.combineResultFunc = combineResultFunc;
+ }
+
+ @Override
+ protected R1 visitSwitchResultExpression(ExpressionTree resultExpressionTree, P1 p1) {
+ return switchValueExpressionFunction.apply(resultExpressionTree, p1);
+ }
+
+ @Override
+ protected R1 combineResults(@Nullable R1 r1, @Nullable R1 r2) {
+ return combineResultFunc.apply(r1, r2);
+ }
+ }
+}
diff --git a/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java b/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java
index 92a16c897a2..8fa37d4a7dc 100644
--- a/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java
+++ b/javacutil/src/main/java/org/checkerframework/javacutil/TreeUtils.java
@@ -1704,6 +1704,97 @@ public static List extends ExpressionTree> caseTreeGetExpressions(CaseTree cas
}
}
+ /**
+ * Returns the selector expression of {@code switchExpressionTree}. For example
+ *
+ *
+ * switch ( expression ) { ... }
+ *
+ *
+ * @param switchExpressionTree the switch expression whose selector expression is returned
+ * @return the selector expression of {@code switchExpressionTree}
+ */
+ public static ExpressionTree switchExpressionTreeGetExpression(Tree switchExpressionTree) {
+ try {
+ Class> switchExpressionClass = Class.forName("com.sun.source.tree.SwitchExpressionTree");
+ Method getExpressionMethod = switchExpressionClass.getMethod("getExpression");
+ ExpressionTree expressionTree =
+ (ExpressionTree) getExpressionMethod.invoke(switchExpressionTree);
+ if (expressionTree != null) {
+ return expressionTree;
+ }
+ throw new BugInCF(
+ "TreeUtils.switchExpressionTreeGetExpression: expression is null for tree: %s",
+ switchExpressionTree);
+ } catch (ClassNotFoundException
+ | NoSuchMethodException
+ | InvocationTargetException
+ | IllegalAccessException e) {
+ throw new BugInCF(
+ "TreeUtils.switchExpressionTreeGetExpression: reflection failed for tree: %s",
+ switchExpressionTree, e);
+ }
+ }
+
+ /**
+ * Returns the cases of {@code switchExpressionTree}. For example
+ *
+ *
+ * switch ( expression ) {
+ * cases
+ * }
+ *
+ *
+ * @param switchExpressionTree the switch expression whose cases are returned
+ * @return the cases of {@code switchExpressionTree}
+ */
+ public static List extends CaseTree> switchExpressionTreeGetCases(Tree switchExpressionTree) {
+ try {
+ Class> switchExpressionClass = Class.forName("com.sun.source.tree.SwitchExpressionTree");
+ Method getCasesMethod = switchExpressionClass.getMethod("getCases");
+ @SuppressWarnings("unchecked")
+ List extends CaseTree> cases =
+ (List extends CaseTree>) getCasesMethod.invoke(switchExpressionTree);
+ if (cases != null) {
+ return cases;
+ }
+ throw new BugInCF(
+ "TreeUtils.switchExpressionTreeGetCases: cases is null for tree: %s",
+ switchExpressionTree);
+ } catch (ClassNotFoundException
+ | NoSuchMethodException
+ | InvocationTargetException
+ | IllegalAccessException e) {
+ throw new BugInCF(
+ "TreeUtils.switchExpressionTreeGetCases: reflection failed for tree: %s",
+ switchExpressionTree, e);
+ }
+ }
+
+ /**
+ * Returns the value (expression) for {@code yieldTree}.
+ *
+ * @param yieldTree the yield tree
+ * @return the value (expression) for {@code yieldTree}.
+ */
+ public static ExpressionTree yieldTreeGetValue(Tree yieldTree) {
+ try {
+ Class> yieldTreeClass = Class.forName("com.sun.source.tree.YieldTree");
+ Method getCasesMethod = yieldTreeClass.getMethod("getValue");
+ ExpressionTree expressionTree = (ExpressionTree) getCasesMethod.invoke(yieldTree);
+ if (expressionTree != null) {
+ return expressionTree;
+ }
+ throw new BugInCF("TreeUtils.yieldTreeGetValue: expression is null for tree: %s", yieldTree);
+ } catch (ClassNotFoundException
+ | NoSuchMethodException
+ | InvocationTargetException
+ | IllegalAccessException e) {
+ throw new BugInCF(
+ "TreeUtils.yieldTreeGetValue: reflection failed for tree: %s", yieldTree, e);
+ }
+ }
+
/**
* Returns true if the given method/constructor invocation is a varargs invocation.
*