diff --git a/src/main/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublic.java b/src/main/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublic.java index 3f8974a21..7ea4850ce 100644 --- a/src/main/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublic.java +++ b/src/main/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublic.java @@ -17,25 +17,27 @@ import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; +import lombok.Value; import org.openrewrite.ExecutionContext; import org.openrewrite.Option; -import org.openrewrite.Recipe; +import org.openrewrite.ScanningRecipe; import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; import org.openrewrite.internal.lang.Nullable; import org.openrewrite.java.ChangeMethodAccessLevelVisitor; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.tree.*; +import org.openrewrite.java.tree.Comment; +import org.openrewrite.java.tree.Flag; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.TypeUtils; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Set; +import java.util.*; @AllArgsConstructor @EqualsAndHashCode(callSuper = false) -public class TestsShouldNotBePublic extends Recipe { +public class TestsShouldNotBePublic extends ScanningRecipe { @Option(displayName = "Remove protected modifiers", description = "Also remove protected modifiers from test methods", @@ -60,16 +62,37 @@ public Set getTags() { } @Override - public TreeVisitor getVisitor() { - return new TestsNotPublicVisitor(Boolean.TRUE.equals(removeProtectedModifiers)); + public Accumulator getInitialValue(ExecutionContext ctx) { + return new Accumulator(); } + @Override + public TreeVisitor getScanner(Accumulator acc) { + return new JavaIsoVisitor() { + @Override + public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDeclaration, ExecutionContext ctx) { + J.ClassDeclaration cd = super.visitClassDeclaration(classDeclaration, ctx); + if (cd.getExtends() != null) { + acc.extendedClasses.add(String.valueOf(cd.getExtends().getType())); + } + return cd; + } + }; + } + + @Override + public TreeVisitor getVisitor(Accumulator acc) { + return new TestsNotPublicVisitor(Boolean.TRUE.equals(removeProtectedModifiers), acc); + } + + public static class Accumulator { + Set extendedClasses = new HashSet<>(); + } + + @RequiredArgsConstructor private static final class TestsNotPublicVisitor extends JavaIsoVisitor { private final Boolean orProtected; - - private TestsNotPublicVisitor(Boolean orProtected) { - this.orProtected = orProtected; - } + private final Accumulator acc; @Override public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) { @@ -77,8 +100,8 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, Ex if (c.getKind() != J.ClassDeclaration.Kind.Type.Interface && c.getModifiers().stream().anyMatch(mod -> mod.getType() == J.Modifier.Type.Public) - && c.getModifiers().stream().noneMatch(mod -> mod.getType() == J.Modifier.Type.Abstract)) { - + && c.getModifiers().stream().noneMatch(mod -> mod.getType() == J.Modifier.Type.Abstract) + && !acc.extendedClasses.contains(String.valueOf(c.getType()))) { boolean hasTestMethods = c.getBody().getStatements().stream() .filter(org.openrewrite.java.tree.J.MethodDeclaration.class::isInstance) .map(J.MethodDeclaration.class::cast) diff --git a/src/test/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublicTest.java b/src/test/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublicTest.java index 5241651f4..3d9679549 100644 --- a/src/test/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublicTest.java +++ b/src/test/java/org/openrewrite/java/testing/cleanup/TestsShouldNotBePublicTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.Issue; import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; @@ -493,4 +494,42 @@ Collection testFactoryMethod() { ) ); } + + @Test + @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/309") + void baseclassForTestsNeedsToStayPublic() { + //language=java + rewriteRun( + spec -> spec.recipe(new TestsShouldNotBePublic(true)), + java( + // base class for tests should stay public + """ + package com.hello; + + import org.junit.jupiter.api.BeforeEach; + + public class MyTestBase { + @BeforeEach + void setUp() { + } + } + """ + ), + java( + // test class extends base class from another package + """ + package com.world; + + import com.hello.MyTestBase; + import org.junit.jupiter.api.Test; + + class MyTest extends MyTestBase { + @Test + void isWorking() { + } + } + """ + ) + ); + } }