diff --git a/src/main/java/org/openrewrite/staticanalysis/JavaElementFactory.java b/src/main/java/org/openrewrite/staticanalysis/JavaElementFactory.java index ce9e4d057..ed66b14c4 100644 --- a/src/main/java/org/openrewrite/staticanalysis/JavaElementFactory.java +++ b/src/main/java/org/openrewrite/staticanalysis/JavaElementFactory.java @@ -27,6 +27,18 @@ final class JavaElementFactory { + static J.Binary newLogicalExpression(J.Binary.Type operator, Expression left, Expression right) { + return new J.Binary( + randomId(), + Space.EMPTY, + Markers.EMPTY, + left, + new JLeftPadded<>(Space.SINGLE_SPACE, operator, Markers.EMPTY), + right, + JavaType.Primitive.Boolean + ); + } + static J.MemberReference newStaticMethodReference(JavaType.Method method, boolean qualified, @Nullable JavaType type) { JavaType.FullyQualified declaringType = method.getDeclaringType(); Expression containing = className(declaringType, qualified); diff --git a/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java b/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java index c7d8e67ee..0c4dd9a56 100644 --- a/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java +++ b/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java @@ -18,9 +18,9 @@ import lombok.AllArgsConstructor; import lombok.Value; import lombok.With; +import org.jetbrains.annotations.NotNull; import org.openrewrite.ExecutionContext; import org.openrewrite.Recipe; -import org.openrewrite.Tree; import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; import org.openrewrite.internal.RecipeRunException; @@ -29,6 +29,7 @@ import org.openrewrite.java.ShortenFullyQualifiedTypeReferences; import org.openrewrite.java.tree.*; import org.openrewrite.marker.Marker; +import org.openrewrite.marker.Markers; import java.time.Duration; import java.util.ArrayList; @@ -39,6 +40,7 @@ import static java.util.Collections.singleton; import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; +import static org.openrewrite.Tree.randomId; public class MinimumSwitchCases extends Recipe { @Override @@ -73,12 +75,7 @@ public TreeVisitor getVisitor() { final JavaTemplate ifElseIfString = JavaTemplate.builder("" + "if(#{any(java.lang.String)}.equals(#{any(java.lang.String)})) {\n" + "} else if(#{any(java.lang.String)}.equals(#{any(java.lang.String)})) {\n" + - "}").contextSensitive().build(); - - final JavaTemplate ifElseIfEnum = JavaTemplate.builder("" + - "if(#{any()} == #{}) {\n" + - "} else if(#{any()} == #{}) {\n" + - "}").contextSensitive().build(); + "}").build(); final JavaTemplate ifElsePrimitive = JavaTemplate.builder("" + "if(#{any()} == #{any()}) {\n" + @@ -88,12 +85,7 @@ public TreeVisitor getVisitor() { final JavaTemplate ifElseString = JavaTemplate.builder("" + "if(#{any(java.lang.String)}.equals(#{any(java.lang.String)})) {\n" + "} else {\n" + - "}").contextSensitive().build(); - - final JavaTemplate ifElseEnum = JavaTemplate.builder("" + - "if(#{any()} == #{}) {\n" + - "} else {\n" + - "}").contextSensitive().build(); + "}").build(); final JavaTemplate ifPrimitive = JavaTemplate.builder("" + "if(#{any()} == #{any()}) {\n" + @@ -101,11 +93,7 @@ public TreeVisitor getVisitor() { final JavaTemplate ifString = JavaTemplate.builder("" + "if(#{any(java.lang.String)}.equals(#{any(java.lang.String)})) {\n" + - "}").contextSensitive().build(); - - final JavaTemplate ifEnum = JavaTemplate.builder("" + - "if(#{any()} == #{}) {\n" + - "}").contextSensitive().build(); + "}").build(); @Override public J visitBlock(J.Block block, ExecutionContext executionContext) { @@ -174,18 +162,26 @@ public J visitSwitch(J.Switch switch_, ExecutionContext ctx) { generatedIf = ifElseIfString.apply(getCursor(), switch_.getCoordinates().replace(), cases[0].getPattern(), tree, cases[1].getPattern(), tree); } } else if (switchesOnEnum(switch_)) { - if (cases[1] == null) { - if (isDefault(cases[0])) { - return switch_.withMarkers(switch_.getMarkers().add(new DefaultOnly())); - } else { - generatedIf = ifEnum.apply(getCursor(), switch_.getCoordinates().replace(), tree, enumIdentToFieldAccessString(cases[0].getPattern())); + if (cases[1] == null && isDefault(cases[0])) { + return switch_.withMarkers(switch_.getMarkers().add(new DefaultOnly())); + } + + generatedIf = createIfForEnum(tree, cases[0].getPattern()); + if (cases[1] != null) { + Statement elseBody = J.Block.createEmptyBlock(); + if (!isDefault(cases[1])) { + elseBody = createIfForEnum(tree, cases[1].getPattern()); } - } else if (isDefault(cases[1])) { - generatedIf = ifElseEnum.apply(getCursor(), switch_.getCoordinates().replace(), tree, enumIdentToFieldAccessString(cases[0].getPattern())); - } else { - generatedIf = ifElseIfEnum.apply(getCursor(), switch_.getCoordinates().replace(), tree, enumIdentToFieldAccessString(cases[0].getPattern()), tree, enumIdentToFieldAccessString(cases[1].getPattern())); + generatedIf = generatedIf + .withElsePart(new J.If.Else( + randomId(), + Space.EMPTY, + Markers.EMPTY, + JRightPadded.build(elseBody) + ) + ); } - doAfterVisit(new ShortenFullyQualifiedTypeReferences().getVisitor()); + doAfterVisit(ShortenFullyQualifiedTypeReferences.modifyOnly(generatedIf)); } else { if (cases[1] == null) { if (isDefault(cases[0])) { @@ -264,6 +260,31 @@ private String enumIdentToFieldAccessString(Expression casePattern) { }; } + @NotNull + private static J.If createIfForEnum(Expression expression, Expression enumTree) { + J.If generatedIf; + if (enumTree instanceof J.Identifier) { + enumTree = new J.FieldAccess( + randomId(), + enumTree.getPrefix(), + Markers.EMPTY, + JavaElementFactory.className(enumTree.getType(), true), + JLeftPadded.build(enumTree.withPrefix(Space.EMPTY)), + enumTree.getType() + ); + } + J.Binary ifCond = JavaElementFactory.newLogicalExpression(J.Binary.Type.Equal, expression, enumTree); + generatedIf = new J.If( + randomId(), + Space.EMPTY, + Markers.EMPTY, + new J.ControlParentheses<>(randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(ifCond)), + JRightPadded.build(J.Block.createEmptyBlock()), + null + ); + return generatedIf; + } + @Value @With @AllArgsConstructor @@ -271,7 +292,7 @@ private static class DefaultOnly implements Marker { UUID id; public DefaultOnly() { - id = Tree.randomId(); + id = randomId(); } } } diff --git a/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java b/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java index 71a985064..97525eadc 100644 --- a/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/MinimumSwitchCasesTest.java @@ -494,7 +494,7 @@ void importsOnEnumImplied() { java( """ import java.time.LocalDate; - + class Test { void test(LocalDate date) { switch(date.getDayOfWeek()) { @@ -647,6 +647,38 @@ void doSomethingElse() {} ); } + @Test + void nestedEnum() { + rewriteRun( + //language=java + java( + """ + class Test { + int test(java.io.ObjectInputFilter filter) { + switch (filter.checkInput(null)) { + case ALLOWED: return 0; + default: return 1; + } + } + } + """, + """ + import java.io.ObjectInputFilter; + + class Test { + int test(java.io.ObjectInputFilter filter) { + if (filter.checkInput(null) == ObjectInputFilter.Status.ALLOWED) { + return 0; + } else { + return 1; + } + } + } + """ + ) + ); + } + @Test @Issue("https://github.com/openrewrite/rewrite/issues/3076") void multipleSwitchExpressions() {