diff --git a/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualsToAssertNotEquals.java b/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualsToAssertNotEquals.java index 437e857d9..289582ce5 100644 --- a/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualsToAssertNotEquals.java +++ b/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualsToAssertNotEquals.java @@ -23,7 +23,6 @@ import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import java.util.List; import java.util.function.Supplier; public class AssertFalseEqualsToAssertNotEquals extends Recipe { @@ -70,8 +69,7 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) } sb.append(")"); - J.MethodInvocation methodInvocation = getMethodInvocation(method); - J.MethodInvocation s = (J.MethodInvocation)methodInvocation.getArguments().get(0); + J.MethodInvocation s = (J.MethodInvocation) method.getArguments().get(0); args = method.getArguments().size() == 2 ? new Object[]{s.getSelect(), s.getArguments().get(0), mi.getArguments().get(1)} : new Object[]{s.getSelect(), s.getArguments().get(0)}; JavaTemplate t; if (mi.getSelect() == null) { @@ -86,22 +84,15 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) return mi; } - private J.MethodInvocation getMethodInvocation(Expression expr){ - List s = expr.getSideEffects(); - return ((J.MethodInvocation) s.get(0)); - } - private boolean isEquals(Expression expr) { - List s = expr.getSideEffects(); - - if (s.isEmpty()){ + if (!(expr instanceof J.MethodInvocation)) { return false; } - J.MethodInvocation methodInvocation = getMethodInvocation(expr); - - return "equals".equals(methodInvocation.getName().getSimpleName()); + J.MethodInvocation methodInvocation = (J.MethodInvocation) expr; + return "equals".equals(methodInvocation.getName().getSimpleName()) + && methodInvocation.getArguments().size() == 1; } }; } diff --git a/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEquals.java b/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEquals.java index 721297d75..961d616ba 100644 --- a/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEquals.java +++ b/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEquals.java @@ -23,7 +23,6 @@ import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import java.util.List; import java.util.function.Supplier; public class AssertTrueEqualsToAssertEquals extends Recipe { @@ -64,8 +63,7 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) } else { sb.append("Assertions."); } - J.MethodInvocation methodInvocation = getMethodInvocation(mi); - J.MethodInvocation s = (J.MethodInvocation)methodInvocation.getArguments().get(0); + J.MethodInvocation s = (J.MethodInvocation) mi.getArguments().get(0); sb.append("assertEquals(#{any(java.lang.Object)},#{any(java.lang.Object)}"); Object[] args; if (mi.getArguments().size() == 2) { @@ -90,23 +88,15 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) return mi; } - private J.MethodInvocation getMethodInvocation(Expression expr){ - List s = expr.getSideEffects(); - return ((J.MethodInvocation) s.get(0)); - } - private boolean isEquals(Expression expr) { - List s = expr.getSideEffects(); - - if (s.isEmpty()){ + if (!(expr instanceof J.MethodInvocation)) { return false; } - J.MethodInvocation methodInvocation = getMethodInvocation(expr); + J.MethodInvocation methodInvocation = (J.MethodInvocation) expr; return "equals".equals(methodInvocation.getName().getSimpleName()) && methodInvocation.getArguments().size() == 1; - } }; } diff --git a/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualToAssertNotEqualsTest.java b/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualToAssertNotEqualsTest.java index f532f40ce..d0bdd7309 100644 --- a/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualToAssertNotEqualsTest.java +++ b/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseEqualToAssertNotEqualsTest.java @@ -99,4 +99,26 @@ void test() { ) ); } + + @SuppressWarnings("ConstantConditions") + @Test + void retainEqualsAndedWithSomethingElse() { + //language=java + rewriteRun( + java( + """ + import java.util.Arrays; + import org.junit.jupiter.api.Assertions; + + public class Test { + void test() { + String a = "a"; + String b = "b"; + Assertions.assertFalse(a.equals(b) && a.length() > 0); + } + } + """ + ) + ); + } } diff --git a/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEqualsTest.java b/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEqualsTest.java index e7cc00f23..8bb5b9690 100644 --- a/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEqualsTest.java +++ b/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueEqualsToAssertEqualsTest.java @@ -124,4 +124,26 @@ void test() { ) ); } + + @SuppressWarnings("ConstantConditions") + @Test + void retainEqualsAndedWithSomethingElse() { + //language=java + rewriteRun( + java( + """ + import java.util.Arrays; + import org.junit.jupiter.api.Assertions; + + public class Test { + void test() { + String a = "a"; + String b = "b"; + Assertions.assertTrue(a.equals(b) && a.length() > 0); + } + } + """ + ) + ); + } }