diff --git a/build.gradle.kts b/build.gradle.kts index 6e03b6a63..4f3662b82 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -43,6 +43,7 @@ dependencies { testImplementation("org.openrewrite:rewrite-java-17") testImplementation("org.openrewrite:rewrite-groovy") + testImplementation("org.openrewrite:rewrite-kotlin:$rewriteVersion") testImplementation("org.openrewrite.gradle.tooling:model:$rewriteVersion") testRuntimeOnly("org.gradle:gradle-tooling-api:latest.release") diff --git a/src/main/java/org/openrewrite/java/testing/junit5/CleanupJUnitImports.java b/src/main/java/org/openrewrite/java/testing/junit5/CleanupJUnitImports.java index a411e92b2..77a408c1b 100644 --- a/src/main/java/org/openrewrite/java/testing/junit5/CleanupJUnitImports.java +++ b/src/main/java/org/openrewrite/java/testing/junit5/CleanupJUnitImports.java @@ -22,6 +22,7 @@ import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaSourceFile; public class CleanupJUnitImports extends Recipe { @Override @@ -44,14 +45,18 @@ public TreeVisitor getVisitor() { public static class CleanupJUnitImportsVisitor extends JavaIsoVisitor { @Override - public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) { - for (J.Import im : cu.getImports()) { - String packageName = im.getPackageName(); - if (packageName.startsWith("junit") || (packageName.startsWith("org.junit") && !packageName.contains("jupiter"))) { - maybeRemoveImport(im.getTypeName()); + public J preVisit(J tree, ExecutionContext ctx) { + stopAfterPreVisit(); + if (tree instanceof JavaSourceFile) { + JavaSourceFile c = (JavaSourceFile) tree; + for (J.Import imp : c.getImports()) { + String packageName = imp.getPackageName(); + if (packageName.startsWith("junit") || (packageName.startsWith("org.junit") && !packageName.contains("jupiter"))) { + maybeRemoveImport(imp.getTypeName()); + } } } - return cu; + return tree; } } } diff --git a/src/test/java/org/openrewrite/java/testing/junit5/CleanupJUnitImportsTest.java b/src/test/java/org/openrewrite/java/testing/junit5/CleanupJUnitImportsTest.java index f59a1cbea..30ec0e109 100644 --- a/src/test/java/org/openrewrite/java/testing/junit5/CleanupJUnitImportsTest.java +++ b/src/test/java/org/openrewrite/java/testing/junit5/CleanupJUnitImportsTest.java @@ -19,10 +19,12 @@ import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; import org.openrewrite.java.JavaParser; +import org.openrewrite.kotlin.KotlinParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.kotlin.Assertions.kotlin; class CleanupJUnitImportsTest implements RewriteTest { @@ -31,14 +33,16 @@ public void defaults(RecipeSpec spec) { spec .parser(JavaParser.fromJavaVersion() .classpathFromResources(new InMemoryExecutionContext(), "junit-4.13")) + .parser(KotlinParser.builder() + .classpathFromResources(new InMemoryExecutionContext(), "junit-4.13")) .recipe(new CleanupJUnitImports()); } @DocumentExample @Test void removesUnusedImport() { - //language=java rewriteRun( + //language=java java( """ import org.junit.Test; @@ -48,14 +52,25 @@ public class MyTest {} """ public class MyTest {} """ + ), + //language=kotlin + kotlin( + """ + import org.junit.Test + + class MyTest {} + """, + """ + class MyTest {} + """ ) ); } @Test void leavesOtherImportsAlone() { - //language=java rewriteRun( + //language=java java( """ import java.util.Arrays; @@ -65,14 +80,25 @@ void leavesOtherImportsAlone() { public class MyTest { } """ + ), + //language=kotlin + kotlin( + """ + import java.util.Arrays + import java.util.Collections + import java.util.HashSet + + class MyTest { + } + """ ) ); } @Test void leavesUsedJUnitImportAlone() { - //language=java rewriteRun( + //language=java java( """ import org.junit.Test; @@ -82,7 +108,19 @@ public class MyTest { public void foo() {} } """ + ), + //language=kotlin + kotlin( + """ + import org.junit.Test + + class MyTest { + @Test + fun foo() {} + } + """ ) ); + } }