Skip to content

Commit

Permalink
Fix for #1056: for class check in switch, handle similar to instanceof
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-milles committed Mar 7, 2020
1 parent ef61845 commit ca76113
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3138,6 +3138,168 @@ public void testInstanceOf12() {
assertType(contents, offset, offset + 5, "java.lang.Number");
}

@Test
public void testSwitchClassCase1() {
String contents =
"void test(obj) {\n" +
" switch (obj) {\n" +
" case Number:\n" +
" obj\n" +
" break\n" +
" case String:\n" +
" obj\n" +
" }\n" +
" obj\n" +
"}\n";

int offset = contents.lastIndexOf("obj)");
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Number");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.String");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");
}

@Test
public void testSwitchClassCase2() {
String contents =
"void test(obj) {\n" +
" switch (obj) {\n" +
" case Number:\n" +
" obj\n" +
" return\n" +
" case String:\n" +
" obj\n" +
" }\n" +
" obj\n" +
"}\n";

int offset = contents.lastIndexOf("obj)");
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Number");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.String");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");
}

@Test
public void testSwitchClassCase3() {
String contents =
"void test(obj) {\n" +
" switch (obj) {\n" +
" case Number:\n" +
" obj\n" +
" throw new Exception()\n" +
" case String:\n" +
" obj\n" +
" }\n" +
" obj\n" +
"}\n";

int offset = contents.lastIndexOf("obj)");
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Number");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.String");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");
}

@Test
public void testSwitchClassCase4() {
String contents =
"void test(obj) {\n" +
" for (i in 1..3) {\n" +
" switch (obj) {\n" +
" case Number:\n" +
" obj\n" +
" continue\n" +
" case String:\n" +
" obj\n" +
" }\n" +
" obj\n" +
" }\n" +
"}\n";

int offset = contents.lastIndexOf("obj)");
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Number");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.String");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");
}

@Test
public void testSwitchClassCase5() {
String contents =
"void test(obj) {\n" +
" switch (obj) {\n" +
" case Number:\n" +
" obj\n" +
" case String:\n" +
" obj\n" +
" }\n" +
" obj\n" +
"}\n";

int offset = contents.lastIndexOf("obj)");
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Number");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.io.Serializable");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");
}

@Test
public void testSwitchClassCase6() {
String contents =
"void test(obj) {\n" +
" switch (obj) {\n" +
" case Number:\n" +
" obj\n" +
" default:\n" +
" obj\n" +
" }\n" +
" obj\n" +
"}\n";

int offset = contents.lastIndexOf("obj)");
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Number");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");

offset = contents.indexOf("obj", offset + 1);
assertType(contents, offset, offset + 3, "java.lang.Object");
}

@Test
public void testThisInInnerClass() {
String contents =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,17 @@
import org.codehaus.groovy.ast.expr.UnaryPlusExpression;
import org.codehaus.groovy.ast.expr.VariableExpression;
import org.codehaus.groovy.ast.stmt.BlockStatement;
import org.codehaus.groovy.ast.stmt.BreakStatement;
import org.codehaus.groovy.ast.stmt.CaseStatement;
import org.codehaus.groovy.ast.stmt.CatchStatement;
import org.codehaus.groovy.ast.stmt.ContinueStatement;
import org.codehaus.groovy.ast.stmt.ExpressionStatement;
import org.codehaus.groovy.ast.stmt.ForStatement;
import org.codehaus.groovy.ast.stmt.IfStatement;
import org.codehaus.groovy.ast.stmt.ReturnStatement;
import org.codehaus.groovy.ast.stmt.Statement;
import org.codehaus.groovy.ast.stmt.SwitchStatement;
import org.codehaus.groovy.ast.stmt.ThrowStatement;
import org.codehaus.groovy.ast.tools.GeneralUtils;
import org.codehaus.groovy.ast.tools.GenericsUtils;
import org.codehaus.groovy.ast.tools.WideningCategories;
Expand Down Expand Up @@ -1632,6 +1637,50 @@ public void visitStaticMethodCallExpression(final StaticMethodCallExpression nod
});
}

@Override
public void visitSwitch(final SwitchStatement switchStatement) {

switchStatement.getExpression().visit(this);

VariableScope caseScope = null;
List<VariableScope> caseScopes = new ArrayList<>();
for (final CaseStatement caseStatement : switchStatement.getCaseStatements()) {

caseStatement.getExpression().visit(this);

// when the case tests for a type, apply instanceof-like flow-typing
if (switchStatement.getExpression() instanceof VariableExpression && caseStatement.getExpression() instanceof ClassExpression) {
String name = switchStatement.getExpression().getText();
ClassNode type = caseStatement.getExpression().getType();
if (caseScope == null) {
scopes.add(caseScope = new VariableScope(scopes.getLast(), caseStatement.getCode(), false));
} else {
type = WideningCategories.lowestUpperBound(type, caseScope.lookupNameInCurrentScope(name).type);
}
caseScope.updateVariableSoft(name, type);
}
// TODO: Should "case T:" fall through "case <expr>:" fall through "case U:" be LUB(T,U) or LUB(T,U,?) or something?

caseStatement.getCode().visit(this);

if (caseScope != null) {
Statement lastStmt = (caseStatement.getCode() instanceof BlockStatement && !((BlockStatement) caseStatement.getCode()).isEmpty()
? DefaultGroovyMethods.last(((BlockStatement) caseStatement.getCode()).getStatements()) : caseStatement.getCode());
if (lastStmt instanceof BreakStatement || lastStmt instanceof ContinueStatement || lastStmt instanceof ReturnStatement || lastStmt instanceof ThrowStatement) {
caseScopes.add(scopes.removeLast());
caseScope = null;
}
}
}
if (caseScope != null) {
caseScopes.add(scopes.removeLast());
}

switchStatement.getDefaultStatement().visit(this);

caseScopes.forEach(VariableScope::bubbleUpdates);
}

@Override
public void visitTernaryExpression(final TernaryExpression node) {
if (isDependentExpression(node)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,7 @@ private boolean updateVariableImpl(String name, ClassNode type, ClassNode declar
* @param type type of variable
*/
/*package*/ void updateVariableSoft(String name, ClassNode type) {
VariableInfo info = merge(parent.lookupName(name), type, null);
info = nameVariableMap.put(name, info);
assert info == null;
nameVariableMap.put(name, merge(parent.lookupName(name), type, null));
}

private static VariableInfo merge(VariableInfo base, ClassNode type, ClassNode declaringType) {
Expand Down

0 comments on commit ca76113

Please sign in to comment.