diff --git a/src/main/java/org/openrewrite/java/testing/jmockit/JMockitMockUpToMockito.java b/src/main/java/org/openrewrite/java/testing/jmockit/JMockitMockUpToMockito.java
new file mode 100644
index 000000000..fb942c504
--- /dev/null
+++ b/src/main/java/org/openrewrite/java/testing/jmockit/JMockitMockUpToMockito.java
@@ -0,0 +1,532 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.java.testing.jmockit;
+
+import org.openrewrite.*;
+import org.openrewrite.java.JavaIsoVisitor;
+import org.openrewrite.java.JavaParser;
+import org.openrewrite.java.JavaTemplate;
+import org.openrewrite.java.search.UsesType;
+import org.openrewrite.java.tree.*;
+import org.openrewrite.marker.SearchResult;
+import org.openrewrite.staticanalysis.LambdaBlockToExpression;
+import org.openrewrite.staticanalysis.VariableReferences;
+
+import java.util.*;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+import static org.openrewrite.java.testing.mockito.MockitoUtils.maybeAddMethodWithAnnotation;
+import static org.openrewrite.java.tree.Flag.Private;
+import static org.openrewrite.java.tree.Flag.Static;
+
+public class JMockitMockUpToMockito extends Recipe {
+
+ private static final String JMOCKIT_MOCKUP_IMPORT = "mockit.MockUp";
+ private static final String JMOCKIT_MOCK_IMPORT = "mockit.Mock";
+
+ private static final String MOCKITO_CLASSPATH = "mockito-core-3";
+ private static final String MOCKITO_ALL_IMPORT = "org.mockito.Mockito.*";
+ private static final String MOCKITO_MATCHER_IMPORT = "org.mockito.ArgumentMatchers.*";
+ private static final String MOCKITO_DELEGATEANSWER_IMPORT = "org.mockito.AdditionalAnswers.delegatesTo";
+ private static final String MOCKITO_STATIC_PREFIX = "mockStatic";
+ private static final String MOCKITO_STATIC_IMPORT = "org.mockito.MockedStatic";
+ private static final String MOCKITO_MOCK_PREFIX = "mock";
+ private static final String MOCKITO_CONSTRUCTION_PREFIX = "mockCons";
+ private static final String MOCKITO_CONSTRUCTION_IMPORT = "org.mockito.MockedConstruction";
+
+ @Override
+ public String getDisplayName() {
+ return "Rewrite JMockit MockUp to Mockito statements";
+ }
+
+ @Override
+ public String getDescription() {
+ return "Rewrites JMockit `MockUp` blocks to Mockito statements. This recipe will not rewrite private methods in MockUp.";
+ }
+
+ @Override
+ public TreeVisitor, ExecutionContext> getVisitor() {
+ return Preconditions.check(new UsesType<>(JMOCKIT_MOCKUP_IMPORT, false), new JMockitMockUpToMockitoVisitor());
+ }
+
+ private static class JMockitMockUpToMockitoVisitor extends JavaIsoVisitor {
+ private final Map tearDownMocks = new HashMap<>();
+
+ /**
+ * Handle at class level because need to handle the case where when there is a MockUp in a setup method, and we
+ * need to close the migrated mockCons in the teardown, yet the teardown method comes before the setup method
+ */
+ @Override
+ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
+ // Handle @Before/@BeforeEach mockUp
+ Set mds = TreeVisitor.collect(
+ new JavaIsoVisitor() {
+ @Override
+ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration md, ExecutionContext ctx) {
+ if (isSetUpMethod(md)) {
+ return SearchResult.found(md);
+ }
+ return super.visitMethodDeclaration(md, ctx);
+ }
+ },
+ classDecl,
+ new HashSet<>()
+ )
+ .stream()
+ .filter(J.MethodDeclaration.class::isInstance)
+ .map(J.MethodDeclaration.class::cast)
+ .collect(Collectors.toSet());
+ if (mds.isEmpty()) {
+ return super.visitClassDeclaration(classDecl, ctx);
+ }
+
+ AtomicReference cdRef = new AtomicReference<>(classDecl);
+ mds.forEach(md -> md.getBody()
+ .getStatements()
+ .stream()
+ .filter(this::isMockUpStatement)
+ .map(J.NewClass.class::cast)
+ .forEach(newClass -> {
+ String className = ((J.ParameterizedType) newClass.getClazz()).getTypeParameters().get(0).toString();
+
+ Map mockedMethods = getMockUpMethods(newClass);
+
+ // Add mockStatic field
+ if (mockedMethods.values().stream().anyMatch(m -> m.getFlags().contains(Static))) {
+ cdRef.set(JavaTemplate.builder("private MockedStatic #{};")
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, MOCKITO_CLASSPATH))
+ .imports(MOCKITO_STATIC_IMPORT)
+ .staticImports(MOCKITO_ALL_IMPORT)
+ .build()
+ .apply(
+ new Cursor(getCursor().getParentOrThrow(), cdRef.get()),
+ cdRef.get().getBody().getCoordinates().firstStatement(),
+ MOCKITO_STATIC_PREFIX + className
+ ));
+ J.VariableDeclarations mockField = (J.VariableDeclarations) cdRef.get().getBody().getStatements().get(0);
+ J.Identifier mockFieldId = mockField.getVariables().get(0).getName();
+ tearDownMocks.put(MOCKITO_STATIC_PREFIX + className, mockFieldId);
+ }
+ // Add mockConstruction field
+ if (mockedMethods.values().stream().anyMatch(m -> !m.getFlags().contains(Static))) {
+ cdRef.set(JavaTemplate.builder("private MockedConstruction #{};")
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, MOCKITO_CLASSPATH))
+ .imports(MOCKITO_CONSTRUCTION_IMPORT)
+ .staticImports(MOCKITO_ALL_IMPORT)
+ .build()
+ .apply(
+ updateCursor(cdRef.get()),
+ cdRef.get().getBody().getCoordinates().firstStatement(),
+ MOCKITO_CONSTRUCTION_PREFIX + className
+ ));
+ J.VariableDeclarations mockField = (J.VariableDeclarations) cdRef.get().getBody().getStatements().get(0);
+ J.Identifier mockFieldId = mockField.getVariables().get(0).getName();
+ tearDownMocks.put(MOCKITO_CONSTRUCTION_PREFIX + className, mockFieldId);
+ }
+ }));
+
+ J.ClassDeclaration cd = maybeAddMethodWithAnnotation(this, cdRef.get(), ctx, true, "tearDown",
+ "@org.junit.After",
+ "@After",
+ "junit-4.13",
+ "org.junit.After",
+ "");
+
+ return super.visitClassDeclaration(cd, ctx);
+ }
+
+ @Override
+ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl, ExecutionContext ctx) {
+ J.MethodDeclaration md = methodDecl;
+ if (md.getBody() == null) {
+ return md;
+ }
+ if (isTearDownMethod(md)) {
+ for (J.Identifier id : tearDownMocks.values()) {
+ String type = TypeUtils.asFullyQualified(id.getFieldType().getType()).getFullyQualifiedName();
+ md = JavaTemplate.builder("#{any(" + type + ")}.closeOnDemand();")
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, MOCKITO_CLASSPATH))
+ .imports(MOCKITO_STATIC_IMPORT, MOCKITO_CONSTRUCTION_IMPORT)
+ .staticImports(MOCKITO_ALL_IMPORT)
+ .build()
+ .apply(
+ updateCursor(md),
+ md.getBody().getCoordinates().lastStatement(),
+ id
+ );
+ }
+ return md;
+ }
+
+ boolean isBeforeTest = isSetUpMethod(md);
+ List varDeclarationInTry = new ArrayList<>();
+ List mockStaticMethodInTry = new ArrayList<>();
+ List mockConstructionMethodInTry = new ArrayList<>();
+ List encloseStatements = new ArrayList<>();
+ List residualStatements = new ArrayList<>();
+ for (Statement statement : md.getBody().getStatements()) {
+ if (!isMockUpStatement(statement)) {
+ encloseStatements.add(statement);
+ continue;
+ }
+
+ J.NewClass newClass = (J.NewClass) statement;
+
+ // Only discard @Mock method declarations
+ residualStatements.addAll(newClass
+ .getBody()
+ .getStatements()
+ .stream()
+ .filter(s -> {
+ if (s instanceof J.MethodDeclaration) {
+ return ((J.MethodDeclaration) s).getLeadingAnnotations().stream()
+ .noneMatch(o -> TypeUtils.isOfClassType(o.getType(), JMOCKIT_MOCK_IMPORT));
+ }
+ return true;
+ })
+ .collect(toList())
+ );
+
+ JavaType mockType = ((J.ParameterizedType) newClass.getClazz()).getTypeParameters().get(0).getType();
+ String className = ((J.ParameterizedType) newClass.getClazz()).getTypeParameters().get(0).toString();
+
+ Map mockedMethods = getMockUpMethods(newClass);
+
+ // Add MockStatic
+ Map mockedPublicStaticMethods = mockedMethods
+ .entrySet()
+ .stream()
+ .filter(m -> m.getValue().getFlags().contains(Static))
+ .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
+ if (!mockedPublicStaticMethods.isEmpty()) {
+ if (isBeforeTest) {
+ String tpl = getMockStaticDeclarationInBefore(className) +
+ getMockStaticMethods((JavaType.Class) mockType, className, mockedPublicStaticMethods);
+
+ md = JavaTemplate.builder(tpl)
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, MOCKITO_CLASSPATH))
+ .imports(MOCKITO_STATIC_IMPORT)
+ .staticImports(MOCKITO_ALL_IMPORT)
+ .build()
+ .apply(
+ updateCursor(md),
+ statement.getCoordinates().after(),
+ tearDownMocks.get(MOCKITO_STATIC_PREFIX + className)
+ );
+ } else {
+ varDeclarationInTry.add(getMockStaticDeclarationInTry(className));
+ mockStaticMethodInTry.add(getMockStaticMethods((JavaType.Class) mockType, className, mockedPublicStaticMethods));
+ }
+
+ maybeAddImport(MOCKITO_STATIC_IMPORT);
+ }
+
+ // Add MockConstruction
+ Map mockedPublicMethods = mockedMethods
+ .entrySet()
+ .stream()
+ .filter(m -> !m.getValue().getFlags().contains(Static))
+ .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
+ if (!mockedPublicMethods.isEmpty()) {
+ if (isBeforeTest) {
+ String tpl = getMockConstructionMethods(className, mockedPublicMethods) +
+ getMockConstructionDeclarationInBefore(className);
+
+ md = JavaTemplate.builder(tpl)
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, MOCKITO_CLASSPATH))
+ .imports(MOCKITO_STATIC_IMPORT)
+ .staticImports(MOCKITO_ALL_IMPORT, MOCKITO_DELEGATEANSWER_IMPORT)
+ .build()
+ .apply(
+ updateCursor(md),
+ statement.getCoordinates().after(),
+ tearDownMocks.get(MOCKITO_CONSTRUCTION_PREFIX + className)
+ );
+ } else {
+ varDeclarationInTry.add(getMockConstructionDeclarationInTry(className));
+ mockConstructionMethodInTry.add(getMockConstructionMethods(className, mockedPublicMethods));
+ }
+
+ maybeAddImport(MOCKITO_CONSTRUCTION_IMPORT);
+ maybeAddImport("org.mockito.Answers", "CALLS_REAL_METHODS", false);
+ maybeAddImport("org.mockito.AdditionalAnswers", "delegatesTo", false);
+ }
+
+ List statements = md.getBody().getStatements();
+ statements.remove(statement);
+ md = md.withBody(md.getBody().withStatements(statements));
+ }
+
+ if (!varDeclarationInTry.isEmpty()) {
+ String tpl = String.join("", mockConstructionMethodInTry) +
+ "try (" +
+ String.join(";", varDeclarationInTry) +
+ ") {" +
+ String.join(";", mockStaticMethodInTry) +
+ "}";
+
+ J.MethodDeclaration residualMd = md.withBody(md.getBody().withStatements(residualStatements));
+ residualMd = JavaTemplate.builder(tpl)
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, MOCKITO_CLASSPATH))
+ .imports(MOCKITO_STATIC_IMPORT, MOCKITO_CONSTRUCTION_IMPORT)
+ .staticImports(MOCKITO_ALL_IMPORT, MOCKITO_MATCHER_IMPORT, MOCKITO_MATCHER_IMPORT, MOCKITO_DELEGATEANSWER_IMPORT)
+ .build()
+ .apply(updateCursor(residualMd), residualMd.getBody().getCoordinates().lastStatement());
+
+ List mdStatements = residualMd.getBody().getStatements();
+ J.Try try_ = (J.Try) mdStatements.get(mdStatements.size() - 1);
+
+ List tryStatements = try_.getBody().getStatements();
+ tryStatements.addAll(encloseStatements);
+ try_ = try_.withBody(try_.getBody().withStatements(tryStatements));
+
+ mdStatements.set(mdStatements.size() - 1, try_);
+ md = md.withBody(residualMd.getBody().withStatements(mdStatements));
+ }
+
+ maybeAddImport(MOCKITO_ALL_IMPORT.replace(".*", ""), "*", false);
+ maybeRemoveImport(JMOCKIT_MOCK_IMPORT);
+ maybeRemoveImport(JMOCKIT_MOCKUP_IMPORT);
+
+ doAfterVisit(new LambdaBlockToExpression().getVisitor());
+ return maybeAutoFormat(methodDecl, md, ctx);
+ }
+
+ private String getMatcher(JavaType s) {
+ maybeAddImport(MOCKITO_MATCHER_IMPORT.replace(".*", ""), "*", false);
+ if (s instanceof JavaType.Primitive) {
+ switch (s.toString()) {
+ case "int":
+ return "anyInt()";
+ case "long":
+ return "anyLong()";
+ case "double":
+ return "anyDouble()";
+ case "float":
+ return "anyFloat()";
+ case "short":
+ return "anyShort()";
+ case "byte":
+ return "anyByte()";
+ case "char":
+ return "anyChar()";
+ case "boolean":
+ return "anyBoolean()";
+ }
+ } else if (s instanceof JavaType.Array) {
+ String elem = TypeUtils.asArray(s).getElemType().toString();
+ return "nullable(" + elem + "[].class)";
+ }
+ return "nullable(" + TypeUtils.asFullyQualified(s).getClassName() + ".class)";
+ }
+
+ private String getAnswerBody(J.MethodDeclaration md) {
+ Set usedVariables = new HashSet<>();
+ new JavaIsoVisitor>() {
+ @Override
+ public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, Set ctx) {
+ Cursor scope = getCursor().dropParentUntil((is) -> is instanceof J.ClassDeclaration || is instanceof J.Block || is instanceof J.MethodDeclaration || is instanceof J.ForLoop || is instanceof J.ForEachLoop || is instanceof J.ForLoop.Control || is instanceof J.ForEachLoop.Control || is instanceof J.Case || is instanceof J.Try || is instanceof J.Try.Resource || is instanceof J.Try.Catch || is instanceof J.MultiCatch || is instanceof J.Lambda || is instanceof JavaSourceFile);
+ if (!VariableReferences.findRhsReferences(scope.getValue(), variable.getName()).isEmpty()) {
+ ctx.add(variable.getSimpleName());
+ }
+ return super.visitVariable(variable, ctx);
+ }
+ }.visit(md, usedVariables);
+
+ StringBuilder sb = new StringBuilder();
+ List parameters = md.getParameters();
+ for (int i = 0; i < parameters.size(); i++) {
+ if (!(parameters.get(i) instanceof J.VariableDeclarations)) {
+ continue;
+ }
+ J.VariableDeclarations vd = (J.VariableDeclarations) parameters.get(i);
+ String className;
+ if (vd.getType() instanceof JavaType.Primitive) {
+ className = vd.getType().toString();
+ } else {
+ className = vd.getTypeAsFullyQualified().getClassName();
+ }
+ String varName = vd.getVariables().get(0).getName().getSimpleName();
+ if (usedVariables.contains(varName)) {
+ sb.append(className).append(" ").append(varName)
+ .append(" = invocation.getArgument(").append(i).append(");");
+ }
+ }
+
+ boolean hasReturn = false;
+ for (Statement s : md.getBody().getStatements()) {
+ hasReturn |= s instanceof J.Return;
+ sb.append(s.print(getCursor())).append(";");
+ }
+ // Avoid syntax error
+ if (!hasReturn) {
+ sb.append("return null;");
+ }
+ return sb.toString();
+ }
+
+ private String getCallRealMethod(JavaType.Method m) {
+ return "(" +
+ m.getParameterTypes()
+ .stream()
+ .map(this::getMatcher)
+ .collect(Collectors.joining(", ")) +
+ ")).thenCallRealMethod();";
+ }
+
+ private String getMockStaticDeclarationInBefore(String className) {
+ return "#{any(" + MOCKITO_STATIC_IMPORT + ")}" +
+ " = mockStatic(" + className + ".class);";
+ }
+
+ private String getMockStaticDeclarationInTry(String className) {
+ return "MockedStatic " + MOCKITO_STATIC_PREFIX + className +
+ " = mockStatic(" + className + ".class)";
+ }
+
+ private String getMockStaticMethods(JavaType.Class clazz, String className, Map mockedMethods) {
+ StringBuilder tpl = new StringBuilder();
+
+ // To generate predictable method order
+ List keys = mockedMethods.keySet().stream()
+ .sorted(Comparator.comparing(o -> o.print(getCursor())))
+ .collect(toList());
+ for (J.MethodDeclaration m : keys) {
+ tpl.append("mockStatic").append(className)
+ .append(".when(() -> ").append(className).append(".").append(m.getSimpleName()).append("(")
+ .append(m.getParameters()
+ .stream()
+ .filter(J.VariableDeclarations.class::isInstance)
+ .map(J.VariableDeclarations.class::cast)
+ .map(J.VariableDeclarations::getType)
+ .map(this::getMatcher)
+ .collect(Collectors.joining(", "))
+ )
+ .append(")).thenAnswer(invocation -> {")
+ .append(getAnswerBody(m))
+ .append("});");
+ }
+
+ // Call real method for non private, static methods
+ clazz.getMethods()
+ .stream()
+ .filter(m -> !m.isConstructor())
+ .filter(m -> !m.getFlags().contains(Private))
+ .filter(m -> m.getFlags().contains(Static))
+ .filter(m -> !mockedMethods.containsValue(m))
+ .forEach(m -> tpl.append("mockStatic").append(className).append(".when(() -> ")
+ .append(className).append(".").append(m.getName())
+ .append(getCallRealMethod(m))
+ .append(");")
+ );
+
+ return tpl.toString();
+ }
+
+ private String getMockConstructionDeclarationInBefore(String className) {
+ return "#{any(" + MOCKITO_CONSTRUCTION_IMPORT + ")}" +
+ " = mockConstructionWithAnswer(" + className + ".class, delegatesTo(" + MOCKITO_MOCK_PREFIX + className + "));";
+ }
+
+ private String getMockConstructionDeclarationInTry(String className) {
+ return "MockedConstruction " + MOCKITO_CONSTRUCTION_PREFIX + className +
+ " = mockConstructionWithAnswer(" + className + ".class, delegatesTo(" + MOCKITO_MOCK_PREFIX + className + "))";
+ }
+
+ private String getMockConstructionMethods(String className, Map mockedMethods) {
+ StringBuilder tpl = new StringBuilder()
+ .append(className)
+ .append(" ")
+ .append(MOCKITO_MOCK_PREFIX).append(className)
+ .append(" = mock(").append(className).append(".class, CALLS_REAL_METHODS);");
+
+ mockedMethods
+ .keySet()
+ .stream()
+ .sorted(Comparator.comparing(o -> o.print(getCursor())))
+ .forEach(m -> tpl.append("doAnswer(invocation -> {")
+ .append(getAnswerBody(m))
+ .append("}).when(").append(MOCKITO_MOCK_PREFIX).append(className).append(").").append(m.getSimpleName()).append("(")
+ .append(m.getParameters()
+ .stream()
+ .filter(J.VariableDeclarations.class::isInstance)
+ .map(J.VariableDeclarations.class::cast)
+ .map(J.VariableDeclarations::getType)
+ .map(this::getMatcher)
+ .collect(Collectors.joining(", "))
+ )
+ .append(");"));
+
+ return tpl.toString();
+ }
+
+ private boolean isMockUpStatement(Tree tree) {
+ return tree instanceof J.NewClass &&
+ ((J.NewClass) tree).getClazz() != null &&
+ TypeUtils.isOfClassType(((J.NewClass) tree).getClazz().getType(), JMOCKIT_MOCKUP_IMPORT);
+ }
+
+ private boolean isSetUpMethod(J.MethodDeclaration md) {
+ return md
+ .getLeadingAnnotations()
+ .stream()
+ .anyMatch(o -> TypeUtils.isOfClassType(o.getType(), "org.junit.Before"));
+ }
+
+ private boolean isTearDownMethod(J.MethodDeclaration md) {
+ return md
+ .getLeadingAnnotations()
+ .stream()
+ .anyMatch(o -> TypeUtils.isOfClassType(o.getType(), "org.junit.After"));
+ }
+
+ private Map getMockUpMethods(J.NewClass newClass) {
+ JavaType mockType = ((J.ParameterizedType) newClass.getClazz()).getTypeParameters().get(0).getType();
+ return newClass.getBody()
+ .getStatements()
+ .stream()
+ .filter(J.MethodDeclaration.class::isInstance)
+ .map(J.MethodDeclaration.class::cast)
+ .filter(s -> s.getLeadingAnnotations().stream()
+ .anyMatch(o -> TypeUtils.isOfClassType(o.getType(), JMOCKIT_MOCK_IMPORT)))
+ .map(method -> {
+ Optional found = TypeUtils.findDeclaredMethod(
+ TypeUtils.asFullyQualified(mockType),
+ method.getSimpleName(),
+ method.getMethodType().getParameterTypes()
+ );
+ if (found.isPresent()) {
+ JavaType.Method m = found.get();
+ if (!m.getFlags().contains(Private)) {
+ return new AbstractMap.SimpleEntry<>(method, found.get());
+ }
+ }
+ return null;
+ })
+ .filter(Objects::nonNull)
+ .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
+ }
+ }
+}
diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoUtils.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoUtils.java
new file mode 100644
index 000000000..ad143e008
--- /dev/null
+++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoUtils.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.java.testing.mockito;
+
+import org.jspecify.annotations.Nullable;
+import org.openrewrite.Cursor;
+import org.openrewrite.ExecutionContext;
+import org.openrewrite.java.AnnotationMatcher;
+import org.openrewrite.java.JavaParser;
+import org.openrewrite.java.JavaTemplate;
+import org.openrewrite.java.JavaVisitor;
+import org.openrewrite.java.tree.J;
+import org.openrewrite.java.tree.Statement;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class MockitoUtils {
+ public static J.ClassDeclaration maybeAddMethodWithAnnotation(
+ JavaVisitor visitor,
+ J.ClassDeclaration classDecl,
+ ExecutionContext ctx,
+ boolean isPublic,
+ String methodName,
+ String methodAnnotationSignature,
+ String methodAnnotationToAdd,
+ String additionalClasspathResource,
+ String importToAdd,
+ String methodAnnotationParameters
+ ) {
+ if (hasMethodWithAnnotation(classDecl, new AnnotationMatcher(methodAnnotationSignature))) {
+ return classDecl;
+ }
+
+ J.MethodDeclaration firstTestMethod = getFirstTestMethod(
+ classDecl.getBody().getStatements().stream().filter(J.MethodDeclaration.class::isInstance)
+ .map(J.MethodDeclaration.class::cast).collect(Collectors.toList()));
+
+ visitor.maybeAddImport(importToAdd);
+ String tplStr = methodAnnotationToAdd + methodAnnotationParameters +
+ (isPublic ? " public" : "") + " void " + methodName + "() {}";
+ return JavaTemplate.builder(tplStr)
+ .contextSensitive()
+ .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, additionalClasspathResource))
+ .imports(importToAdd)
+ .build()
+ .apply(
+ new Cursor(visitor.getCursor().getParentOrThrow(), classDecl),
+ (firstTestMethod != null) ?
+ firstTestMethod.getCoordinates().before() :
+ classDecl.getBody().getCoordinates().lastStatement()
+ );
+ }
+
+ private static boolean hasMethodWithAnnotation(J.ClassDeclaration classDecl, AnnotationMatcher annotationMatcher) {
+ for (Statement statement : classDecl.getBody().getStatements()) {
+ if (statement instanceof J.MethodDeclaration) {
+ J.MethodDeclaration methodDeclaration = (J.MethodDeclaration) statement;
+ List allAnnotations = methodDeclaration.getAllAnnotations();
+ for (J.Annotation annotation : allAnnotations) {
+ if (annotationMatcher.matches(annotation)) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ private static J.@Nullable MethodDeclaration getFirstTestMethod(List methods) {
+ for (J.MethodDeclaration methodDeclaration : methods) {
+ for (J.Annotation annotation : methodDeclaration.getLeadingAnnotations()) {
+ if ("Test".equals(annotation.getSimpleName())) {
+ return methodDeclaration;
+ }
+ }
+ }
+ return null;
+ }
+}
diff --git a/src/main/java/org/openrewrite/java/testing/mockito/PowerMockitoMockStaticToMockito.java b/src/main/java/org/openrewrite/java/testing/mockito/PowerMockitoMockStaticToMockito.java
index aa7cb417d..dbc259a8e 100644
--- a/src/main/java/org/openrewrite/java/testing/mockito/PowerMockitoMockStaticToMockito.java
+++ b/src/main/java/org/openrewrite/java/testing/mockito/PowerMockitoMockStaticToMockito.java
@@ -26,6 +26,8 @@
import java.util.*;
import java.util.stream.Collectors;
+import static org.openrewrite.java.testing.mockito.MockitoUtils.maybeAddMethodWithAnnotation;
+
public class PowerMockitoMockStaticToMockito extends Recipe {
@Override
@@ -213,32 +215,6 @@ private static boolean isFieldAlreadyDefined(J.Block classBody, String fieldName
return false;
}
- private static J.@Nullable MethodDeclaration getFirstTestMethod(List methods) {
- for (J.MethodDeclaration methodDeclaration : methods) {
- for (J.Annotation annotation : methodDeclaration.getLeadingAnnotations()) {
- if ("Test".equals(annotation.getSimpleName())) {
- return methodDeclaration;
- }
- }
- }
- return null;
- }
-
- private static boolean hasMethodWithAnnotation(J.ClassDeclaration classDecl, AnnotationMatcher annotationMatcher) {
- for (Statement statement : classDecl.getBody().getStatements()) {
- if (statement instanceof J.MethodDeclaration) {
- J.MethodDeclaration methodDeclaration = (J.MethodDeclaration) statement;
- List allAnnotations = methodDeclaration.getAllAnnotations();
- for (J.Annotation annotation : allAnnotations) {
- if (annotationMatcher.matches(annotation)) {
- return true;
- }
- }
- }
- }
- return false;
- }
-
private static boolean isStaticMockAlreadyClosed(J.Identifier staticMock, J.Block methodBody) {
for (Statement statement : methodBody.getStatements()) {
if (statement instanceof J.MethodInvocation) {
@@ -465,7 +441,7 @@ private J.ClassDeclaration addFieldDeclarationForMockedTypes(J.ClassDeclaration
private J.ClassDeclaration maybeAddSetUpMethodBody(J.ClassDeclaration classDecl, ExecutionContext ctx) {
String testGroupsAsString = getTestGroupsAsString();
- return maybeAddMethodWithAnnotation(classDecl, ctx, "setUpStaticMocks",
+ return maybeAddMethodWithAnnotation(this, classDecl, ctx, false, "setUpStaticMocks",
setUpMethodAnnotationSignature, setUpMethodAnnotation,
additionalClasspathResource, setUpImportToAdd, testGroupsAsString);
}
@@ -483,38 +459,12 @@ private String getTestGroupsAsString() {
private J.ClassDeclaration maybeAddTearDownMethodBody(J.ClassDeclaration classDecl, ExecutionContext ctx) {
String testGroupsAsString = (getTestGroupsAsString().isEmpty()) ? tearDownMethodAnnotationParameters : getTestGroupsAsString();
- return maybeAddMethodWithAnnotation(classDecl, ctx, "tearDownStaticMocks",
+ return maybeAddMethodWithAnnotation(this, classDecl, ctx, false, "tearDownStaticMocks",
tearDownMethodAnnotationSignature,
tearDownMethodAnnotation,
additionalClasspathResource, tearDownImportToAdd, testGroupsAsString);
}
- private J.ClassDeclaration maybeAddMethodWithAnnotation(J.ClassDeclaration classDecl, ExecutionContext ctx,
- String methodName, String methodAnnotationSignature,
- String methodAnnotationToAdd,
- String additionalClasspathResource, String importToAdd,
- String methodAnnotationParameters) {
- if (hasMethodWithAnnotation(classDecl, new AnnotationMatcher(methodAnnotationSignature))) {
- return classDecl;
- }
-
- J.MethodDeclaration firstTestMethod = getFirstTestMethod(
- classDecl.getBody().getStatements().stream().filter(J.MethodDeclaration.class::isInstance)
- .map(J.MethodDeclaration.class::cast).collect(Collectors.toList()));
-
- maybeAddImport(importToAdd);
- return JavaTemplate.builder(methodAnnotationToAdd + methodAnnotationParameters + " void " + methodName + "() {}")
- .contextSensitive()
- .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, additionalClasspathResource))
- .imports(importToAdd)
- .build()
- .apply(
- new Cursor(getCursor().getParentOrThrow(), classDecl),
- (firstTestMethod != null) ?
- firstTestMethod.getCoordinates().before() :
- classDecl.getBody().getCoordinates().lastStatement()
- );
- }
private J.MethodInvocation modifyWhenMethodInvocation(J.MethodInvocation whenMethod) {
List methodArguments = whenMethod.getArguments();
diff --git a/src/main/resources/META-INF/rewrite/jmockit.yml b/src/main/resources/META-INF/rewrite/jmockit.yml
index a15136715..d1ee335d7 100644
--- a/src/main/resources/META-INF/rewrite/jmockit.yml
+++ b/src/main/resources/META-INF/rewrite/jmockit.yml
@@ -23,6 +23,7 @@ tags:
- jmockit
recipeList:
- org.openrewrite.java.testing.jmockit.JMockitBlockToMockito
+ - org.openrewrite.java.testing.jmockit.JMockitMockUpToMockito
- org.openrewrite.java.testing.jmockit.JMockitAnnotatedArgumentToMockito
- org.openrewrite.java.ChangeType:
oldFullyQualifiedTypeName: mockit.Mocked
diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitMockUpToMockitoTest.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitMockUpToMockitoTest.java
new file mode 100644
index 000000000..49c9ed48b
--- /dev/null
+++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitMockUpToMockitoTest.java
@@ -0,0 +1,653 @@
+/*
+ * Copyright 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.java.testing.jmockit;
+
+import org.junit.jupiter.api.Test;
+import org.openrewrite.DocumentExample;
+import org.openrewrite.test.RecipeSpec;
+import org.openrewrite.test.RewriteTest;
+import org.openrewrite.test.SourceSpec;
+import org.openrewrite.test.TypeValidation;
+
+import static org.openrewrite.java.Assertions.java;
+import static org.openrewrite.java.testing.jmockit.JMockitTestUtils.*;
+
+class JMockitMockUpToMockitoTest implements RewriteTest {
+
+ @Override
+ public void defaults(RecipeSpec spec) {
+ setParserSettings(spec, JMOCKIT_DEPENDENCY, JUNIT_4_DEPENDENCY);
+ }
+
+ @DocumentExample
+ @Test
+ void mockUpStaticMethodTest() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ import mockit.Mock;
+ import mockit.MockUp;
+ import static org.junit.Assert.assertEquals;
+ import org.junit.Test;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ new MockUp() {
+
+ @Mock
+ public int staticMethod() {
+ return 1024;
+ }
+
+ @Mock
+ public int staticMethod(int v) {
+ return 128;
+ }
+ };
+ assertEquals(1024, MyClazz.staticMethod());
+ assertEquals(128, MyClazz.staticMethod(0));
+ }
+
+ public static class MyClazz {
+ public static int staticMethod() {
+ return 0;
+ }
+
+ public static int staticMethod(int v) {
+ return 1;
+ }
+ }
+ }
+ """, """
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.ArgumentMatchers.*;
+ import static org.mockito.Mockito.*;
+
+ import org.junit.Test;
+ import org.mockito.MockedStatic;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ try (MockedStatic mockStaticMyClazz = mockStatic(MyClazz.class)) {
+ mockStaticMyClazz.when(() -> MyClazz.staticMethod()).thenAnswer(invocation -> 1024);
+ mockStaticMyClazz.when(() -> MyClazz.staticMethod(anyInt())).thenAnswer(invocation -> 128);
+ assertEquals(1024, MyClazz.staticMethod());
+ assertEquals(128, MyClazz.staticMethod(0));
+ }
+ }
+
+ public static class MyClazz {
+ public static int staticMethod() {
+ return 0;
+ }
+
+ public static int staticMethod(int v) {
+ return 1;
+ }
+ }
+ }
+ """));
+ }
+
+ @Test
+ void mockUpMultipleTest() {
+ //language=java
+ rewriteRun(
+ spec -> spec.afterTypeValidationOptions(TypeValidation.builder().identifiers(false).build()),
+ java(
+ """
+ package com.openrewrite;
+ public static class Foo {
+ public String getMsg() {
+ return "foo";
+ }
+
+ public String getMsg(String echo) {
+ return "foo" + echo;
+ }
+ }
+ """,
+ SourceSpec::skip
+ ),
+ java(
+ """
+ package com.openrewrite;
+ public static class Bar {
+ public String getMsg() {
+ return "bar";
+ }
+
+ public String getMsg(String echo) {
+ return "bar" + echo;
+ }
+ }
+ """,
+ SourceSpec::skip
+ ),
+ java(
+ """
+ import com.openrewrite.Foo;
+ import com.openrewrite.Bar;
+ import org.junit.Test;
+ import mockit.Mock;
+ import mockit.MockUp;
+ import static org.junit.Assert.assertEquals;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ new MockUp() {
+ @Mock
+ public String getMsg() {
+ return "FOO";
+ }
+ @Mock
+ public String getMsg(String echo) {
+ return "FOO" + echo;
+ }
+ };
+ new MockUp() {
+ @Mock
+ public String getMsg() {
+ return "BAR";
+ }
+ @Mock
+ public String getMsg(String echo) {
+ return "BAR" + echo;
+ }
+ };
+ assertEquals("FOO", new Foo().getMsg());
+ assertEquals("FOOecho", new Foo().getMsg("echo"));
+ assertEquals("BAR", new Bar().getMsg());
+ assertEquals("BARecho", new Bar().getMsg("echo"));
+ }
+ }
+ """, """
+ import com.openrewrite.Foo;
+ import com.openrewrite.Bar;
+ import org.junit.Test;
+ import org.mockito.MockedConstruction;
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.AdditionalAnswers.delegatesTo;
+ import static org.mockito.Answers.CALLS_REAL_METHODS;
+ import static org.mockito.ArgumentMatchers.*;
+ import static org.mockito.Mockito.*;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ Foo mockFoo = mock(Foo.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> "FOO").when(mockFoo).getMsg();
+ doAnswer(invocation -> {
+ String echo = invocation.getArgument(0);
+ return "FOO" + echo;
+ }).when(mockFoo).getMsg(nullable(String.class));
+ Bar mockBar = mock(Bar.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> "BAR").when(mockBar).getMsg();
+ doAnswer(invocation -> {
+ String echo = invocation.getArgument(0);
+ return "BAR" + echo;
+ }).when(mockBar).getMsg(nullable(String.class));
+ try (MockedConstruction mockConsFoo = mockConstructionWithAnswer(Foo.class, delegatesTo(mockFoo));MockedConstruction mockConsBar = mockConstructionWithAnswer(Bar.class, delegatesTo(mockBar))) {
+ assertEquals("FOO", new Foo().getMsg());
+ assertEquals("FOOecho", new Foo().getMsg("echo"));
+ assertEquals("BAR", new Bar().getMsg());
+ assertEquals("BARecho", new Bar().getMsg("echo"));
+ }
+ }
+ }
+ """)
+ );
+ }
+
+ @Test
+ void mockUpInnerStatementTest() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ import mockit.Mock;
+ import mockit.MockUp;
+
+ import org.junit.Test;
+ import static org.junit.Assert.assertEquals;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ new MockUp() {
+ final String msg = "newMsg";
+
+ @Mock
+ public String getMsg() {
+ return msg;
+ }
+ };
+
+ // Should ignore the newClass statement
+ new Runnable() {
+ @Override
+ public void run() {
+ System.out.println("run");
+ }
+ };
+ assertEquals("newMsg", new MyClazz().getMsg());
+ }
+
+ public static class MyClazz {
+ public String getMsg() {
+ return "msg";
+ }
+ }
+ }
+ """, """
+ import org.junit.Test;
+ import org.mockito.MockedConstruction;
+
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.AdditionalAnswers.delegatesTo;
+ import static org.mockito.Answers.CALLS_REAL_METHODS;
+ import static org.mockito.Mockito.*;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ final String msg = "newMsg";
+ MyClazz mockMyClazz = mock(MyClazz.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> msg).when(mockMyClazz).getMsg();
+ try (MockedConstruction mockConsMyClazz = mockConstructionWithAnswer(MyClazz.class, delegatesTo(mockMyClazz))) {
+
+ // Should ignore the newClass statement
+ new Runnable() {
+ @Override
+ public void run() {
+ System.out.println("run");
+ }
+ };
+ assertEquals("newMsg", new MyClazz().getMsg());
+ }
+ }
+
+ public static class MyClazz {
+ public String getMsg() {
+ return "msg";
+ }
+ }
+ }
+ """));
+ }
+
+ @Test
+ void mockUpVoidTest() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ import mockit.Mock;
+ import mockit.MockUp;
+ import static org.junit.Assert.assertEquals;
+ import org.junit.Test;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ new MockUp() {
+ @Mock
+ public void changeMsg() {
+ MockUpClass.Save.msg = "mockMsg";
+ }
+
+ @Mock
+ public void changeText(String text) {
+ MockUpClass.Save.text = "mockText";
+ }
+ };
+
+ assertEquals("mockMsg", new MockUpClass().getMsg());
+ assertEquals("mockText", new MockUpClass().getText());
+ }
+
+ public static class MockUpClass {
+ public static class Save {
+ public static String msg = "msg";
+ public static String text = "text";
+ }
+
+ public final String getMsg() {
+ changeMsg();
+ return Save.msg;
+ }
+
+ public void changeMsg() {
+ Save.msg = "newMsg";
+ }
+
+ public String getText() {
+ changeText("newText");
+ return Save.text;
+ }
+
+ public static void changeText(String text) {
+ Save.text = text;
+ }
+ }
+ }
+ """,
+ """
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.AdditionalAnswers.delegatesTo;
+ import static org.mockito.Answers.CALLS_REAL_METHODS;
+ import static org.mockito.ArgumentMatchers.*;
+ import static org.mockito.Mockito.*;
+
+ import org.junit.Test;
+ import org.mockito.MockedConstruction;
+ import org.mockito.MockedStatic;
+
+ public class MockUpTest {
+ @Test
+ public void test() {
+ MockUpClass mockMockUpClass = mock(MockUpClass.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> {
+ MockUpClass.Save.msg = "mockMsg";
+ return null;
+ }).when(mockMockUpClass).changeMsg();
+ try (MockedStatic mockStaticMockUpClass = mockStatic(MockUpClass.class);MockedConstruction mockConsMockUpClass = mockConstructionWithAnswer(MockUpClass.class, delegatesTo(mockMockUpClass))) {
+ mockStaticMockUpClass.when(() -> MockUpClass.changeText(nullable(String.class))).thenAnswer(invocation -> {
+ String text = invocation.getArgument(0);
+ MockUpClass.Save.text = "mockText";
+ return null;
+ });
+
+ assertEquals("mockMsg", new MockUpClass().getMsg());
+ assertEquals("mockText", new MockUpClass().getText());
+ }
+ }
+
+ public static class MockUpClass {
+ public static class Save {
+ public static String msg = "msg";
+ public static String text = "text";
+ }
+
+ public final String getMsg() {
+ changeMsg();
+ return Save.msg;
+ }
+
+ public void changeMsg() {
+ Save.msg = "newMsg";
+ }
+
+ public String getText() {
+ changeText("newText");
+ return Save.text;
+ }
+
+ public static void changeText(String text) {
+ Save.text = text;
+ }
+ }
+ }
+ """));
+ }
+
+ @Test
+ void mockUpAtSetUpWithoutTearDownTest() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.Before;
+ import org.junit.Test;
+ import mockit.Mock;
+ import mockit.MockUp;
+ import static org.junit.Assert.assertEquals;
+
+ public class MockUpTest {
+ @Before
+ public void setUp() {
+ new MockUp() {
+ @Mock
+ public String getMsg() {
+ return "mockMsg";
+ }
+ };
+ }
+
+ @Test
+ public void test() {
+ assertEquals("mockMsg", new MyClazz().getMsg());
+ }
+
+ public static class MyClazz {
+ public String getMsg() {
+ return "msg";
+ }
+ }
+ }
+ """,
+ """
+ import org.junit.After;
+ import org.junit.Before;
+ import org.junit.Test;
+ import org.mockito.MockedConstruction;
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.AdditionalAnswers.delegatesTo;
+ import static org.mockito.Answers.CALLS_REAL_METHODS;
+ import static org.mockito.Mockito.*;
+
+ public class MockUpTest {
+ private MockedConstruction mockConsMyClazz;
+
+ @Before
+ public void setUp() {
+ MyClazz mockMyClazz = mock(MyClazz.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> "mockMsg").when(mockMyClazz).getMsg();
+ mockConsMyClazz = mockConstructionWithAnswer(MyClazz.class, delegatesTo(mockMyClazz));
+ }
+
+ @After
+ public void tearDown() {
+ mockConsMyClazz.closeOnDemand();
+ }
+
+ @Test
+ public void test() {
+ assertEquals("mockMsg", new MyClazz().getMsg());
+ }
+
+ public static class MyClazz {
+ public String getMsg() {
+ return "msg";
+ }
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void mockUpAtSetUpWithTearDownTest() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import org.junit.Before;
+ import org.junit.After;
+ import org.junit.Test;
+ import mockit.Mock;
+ import mockit.MockUp;
+ import static org.junit.Assert.assertEquals;
+
+ public class MockUpTest {
+ @Before
+ public void setUp() {
+ new MockUp() {
+ @Mock
+ public String getMsg() {
+ return "mockMsg";
+ }
+
+ @Mock
+ public String getStaticMsg() {
+ return "mockStaticMsg";
+ }
+ };
+ }
+
+ @After
+ public void tearDown() {
+ }
+
+ @Test
+ public void test() {
+ assertEquals("mockMsg", new MyClazz().getMsg());
+ assertEquals("mockStaticMsg", MyClazz.getStaticMsg());
+ }
+
+ public static class MyClazz {
+ public String getMsg() {
+ return "msg";
+ }
+
+ public static String getStaticMsg() {
+ return "staticMsg";
+ }
+ }
+ }
+ """,
+ """
+ import org.junit.Before;
+ import org.junit.After;
+ import org.junit.Test;
+ import org.mockito.MockedConstruction;
+ import org.mockito.MockedStatic;
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.AdditionalAnswers.delegatesTo;
+ import static org.mockito.Answers.CALLS_REAL_METHODS;
+ import static org.mockito.Mockito.*;
+
+ public class MockUpTest {
+ private MockedConstruction mockConsMyClazz;
+ private MockedStatic mockStaticMyClazz;
+
+ @Before
+ public void setUp() {
+ MyClazz mockMyClazz = mock(MyClazz.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> "mockMsg").when(mockMyClazz).getMsg();
+ mockConsMyClazz = mockConstructionWithAnswer(MyClazz.class, delegatesTo(mockMyClazz));
+ mockStaticMyClazz = mockStatic(MyClazz.class);
+ mockStaticMyClazz.when(() -> MyClazz.getStaticMsg()).thenAnswer(invocation -> "mockStaticMsg");
+ }
+
+ @After
+ public void tearDown() {
+ mockConsMyClazz.closeOnDemand();
+ mockStaticMyClazz.closeOnDemand();
+ }
+
+ @Test
+ public void test() {
+ assertEquals("mockMsg", new MyClazz().getMsg());
+ assertEquals("mockStaticMsg", MyClazz.getStaticMsg());
+ }
+
+ public static class MyClazz {
+ public String getMsg() {
+ return "msg";
+ }
+
+ public static String getStaticMsg() {
+ return "staticMsg";
+ }
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void mockUpWithParamsTest() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ import mockit.Mock;
+ import mockit.MockUp;
+ import org.junit.Test;
+
+ import static org.junit.Assert.assertEquals;
+
+ public class MockUpTest {
+ @Test
+ public void init() {
+ new MockUp() {
+ @Mock
+ public String getMsg(String foo, String bar, String unused) {
+ return foo + bar;
+ }
+ };
+ assertEquals("foobar", new MyClazz().getMsg("foo", "bar", "unused"));
+ }
+
+ public static class MyClazz {
+ public String getMsg(String foo, String bar, String unused) {
+ return "msg";
+ }
+ }
+ }
+ """,
+ """
+ import org.junit.Test;
+ import org.mockito.MockedConstruction;
+
+ import static org.junit.Assert.assertEquals;
+ import static org.mockito.AdditionalAnswers.delegatesTo;
+ import static org.mockito.Answers.CALLS_REAL_METHODS;
+ import static org.mockito.ArgumentMatchers.*;
+ import static org.mockito.Mockito.*;
+
+ public class MockUpTest {
+ @Test
+ public void init() {
+ MyClazz mockMyClazz = mock(MyClazz.class, CALLS_REAL_METHODS);
+ doAnswer(invocation -> {
+ String foo = invocation.getArgument(0);
+ String bar = invocation.getArgument(1);
+ return foo + bar;
+ }).when(mockMyClazz).getMsg(nullable(String.class), nullable(String.class), nullable(String.class));
+ try (MockedConstruction mockConsMyClazz = mockConstructionWithAnswer(MyClazz.class, delegatesTo(mockMyClazz))) {
+ assertEquals("foobar", new MyClazz().getMsg("foo", "bar", "unused"));
+ }
+ }
+
+ public static class MyClazz {
+ public String getMsg(String foo, String bar, String unused) {
+ return "msg";
+ }
+ }
+ }
+ """
+ )
+ );
+ }
+}
diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java
index bc186bb48..463746a26 100644
--- a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java
+++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java
@@ -21,12 +21,9 @@
import org.openrewrite.test.RewriteTest;
import static org.openrewrite.java.Assertions.java;
-import static org.openrewrite.java.testing.jmockit.JMockitTestUtils.MOCKITO_CORE_DEPENDENCY;
-import static org.openrewrite.java.testing.jmockit.JMockitTestUtils.setParserSettings;
+import static org.openrewrite.java.testing.jmockit.JMockitTestUtils.*;
class JMockitNonStrictExpectationsToMockitoTest implements RewriteTest {
-
- private static final String JUNIT_4_DEPENDENCY = "junit-4.13.2";
private static final String LEGACY_JMOCKIT_DEPENDENCY = "jmockit-1.22";
@Override
@@ -53,14 +50,14 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -74,16 +71,16 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField()).thenReturn(null);
assertNull(myObject.getSomeField());
@@ -113,14 +110,14 @@ public int getSomeField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -139,16 +136,16 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField()).thenReturn(10);
assertEquals(10, myObject.getSomeField());
@@ -180,14 +177,14 @@ public String getSomeField(String s) {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField(anyString);
@@ -201,15 +198,15 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField(anyString())).thenReturn("foo");
assertEquals("foo", myObject.getSomeField("bar"));
@@ -239,16 +236,16 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
String expected = "expected";
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -262,18 +259,18 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
String expected = "expected";
-
+
void test() {
lenient().when(myObject.getSomeField()).thenReturn(expected);
assertEquals(expected, myObject.getSomeField());
@@ -303,14 +300,14 @@ public Object getSomeField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertNotNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -324,16 +321,16 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField()).thenReturn(new Object());
assertNotNull(myObject.getSomeField());
@@ -363,12 +360,12 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() throws RuntimeException {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -382,15 +379,15 @@ void test() throws RuntimeException {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() throws RuntimeException {
lenient().when(myObject.getSomeField()).thenThrow(new RuntimeException());
myObject.getSomeField();
@@ -420,14 +417,14 @@ public String getSomeField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() throws RuntimeException {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -442,16 +439,16 @@ void test() throws RuntimeException {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() throws RuntimeException {
lenient().when(myObject.getSomeField()).thenReturn("foo", "bar");
assertEquals("foo", myObject.getSomeField());
@@ -470,7 +467,7 @@ void whenClassArgumentMatcher() {
java(
"""
import java.util.List;
-
+
class MyObject {
public String getSomeField(List input) {
return "X";
@@ -485,19 +482,19 @@ public String getSomeOtherField(Object input) {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import mockit.NonStrictExpectations;
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField((List) any);
@@ -513,19 +510,19 @@ void test() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField(anyList())).thenReturn(null);
lenient().when(myObject.getSomeOtherField(any(Object.class))).thenReturn(null);
@@ -545,7 +542,7 @@ void whenNoArguments() {
java(
"""
import java.util.List;
-
+
class MyObject {
public String getSomeField() {
return "X";
@@ -557,19 +554,19 @@ public String getSomeField() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import mockit.NonStrictExpectations;
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeField();
@@ -582,20 +579,20 @@ void test() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField()).thenReturn(null);
assertNull(myObject.getSomeField());
@@ -614,7 +611,7 @@ void whenMixedArgumentMatcher() {
java(
"""
import java.util.List;
-
+
class MyObject {
public String getSomeField(String s, String s2, String s3, long l1) {
return "X";
@@ -626,19 +623,19 @@ public String getSomeField(String s, String s2, String s3, long l1) {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import mockit.NonStrictExpectations;
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
String bazz = "bazz";
new NonStrictExpectations() {{
@@ -652,19 +649,19 @@ void test() {
"""
import java.util.ArrayList;
import java.util.List;
-
+
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
String bazz = "bazz";
lenient().when(myObject.getSomeField(eq("foo"), anyString(), eq(bazz), eq(10L))).thenReturn(null);
@@ -683,7 +680,7 @@ void whenSetupStatements() {
java(
"""
class MyObject {
-
+
public String getSomeField(String s) {
return "X";
}
@@ -699,26 +696,26 @@ public String getString() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
String a = "a";
String s = "s";
-
+
new NonStrictExpectations() {{
myObject.getSomeField(anyString);
result = s;
-
+
myObject.getString();
result = a;
}};
-
+
assertEquals("s", myObject.getSomeField("foo"));
assertEquals("a", myObject.getString());
}
@@ -728,21 +725,21 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
String a = "a";
String s = "s";
lenient().when(myObject.getSomeField(anyString())).thenReturn(s);
lenient().when(myObject.getString()).thenReturn(a);
-
+
assertEquals("s", myObject.getSomeField("foo"));
assertEquals("a", myObject.getString());
}
@@ -771,14 +768,14 @@ public String getSomeField(String s) {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
String a = "a";
new NonStrictExpectations() {{
@@ -787,7 +784,7 @@ void test() {
String b = "b";
result = s;
}};
-
+
assertEquals("s", myObject.getSomeField("foo"));
}
}
@@ -796,21 +793,21 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
String a = "a";
String s = "s";
String b = "b";
lenient().when(myObject.getSomeField(anyString())).thenReturn(s);
-
+
assertEquals("s", myObject.getSomeField("foo"));
}
}
@@ -838,14 +835,14 @@ public String getSomeField() {
import mockit.Tested;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
-
+
@RunWith(JMockit.class)
class MyTest {
@Tested
MyObject myObject;
-
+
void test() {
new NonStrictExpectations(myObject) {{
myObject.getSomeField();
@@ -859,16 +856,16 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.InjectMocks;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@InjectMocks
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeField()).thenReturn("foo");
assertEquals("foo", myObject.getSomeField());
@@ -905,18 +902,18 @@ public void doSomething() {}
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
Object myObject;
-
+
@Mocked
MyObject myOtherObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.hashCode();
@@ -938,19 +935,19 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
Object myObject;
-
+
@Mock
MyObject myOtherObject;
-
+
void test() {
lenient().when(myObject.hashCode()).thenReturn(10);
lenient().when(myOtherObject.getSomeObjectField()).thenReturn(null);
@@ -985,15 +982,15 @@ public String getSomeStringField() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
MyObject myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.getSomeStringField();
@@ -1012,17 +1009,17 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
MyObject myObject;
-
+
void test() {
lenient().when(myObject.getSomeStringField()).thenReturn("a");
assertEquals("a", myObject.getSomeStringField());
@@ -1045,15 +1042,15 @@ void whenNoResultsNoTimes() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.wait(anyLong);
@@ -1066,15 +1063,15 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(1L);
}
@@ -1094,12 +1091,12 @@ void whenTimes() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.wait(anyLong, anyInt);
@@ -1115,14 +1112,14 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
myObject.wait(10L, 10);
myObject.wait(10L, 10);
@@ -1145,12 +1142,12 @@ void whenTimesAndResult() {
import mockit.Mocked;
import mockit.integration.junit4.JMockit;
import org.junit.runner.RunWith;
-
+
@RunWith(JMockit.class)
class MyTest {
@Mocked
Object myObject;
-
+
void test() {
new NonStrictExpectations() {{
myObject.toString();
@@ -1166,14 +1163,14 @@ void test() {
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-
+
import static org.mockito.Mockito.*;
-
+
@RunWith(MockitoJUnitRunner.class)
class MyTest {
@Mock
Object myObject;
-
+
void test() {
when(myObject.toString()).thenReturn("foo");
myObject.toString();
diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitTestUtils.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitTestUtils.java
index c66d3fbb1..9c2ac5f31 100644
--- a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitTestUtils.java
+++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitTestUtils.java
@@ -24,6 +24,7 @@ public class JMockitTestUtils {
static final String MOCKITO_CORE_DEPENDENCY = "mockito-core-3.12";
static final String JUNIT_5_JUPITER_DEPENDENCY = "junit-jupiter-api-5.9";
+ static final String JUNIT_4_DEPENDENCY = "junit-4.13.2";
static final String JMOCKIT_DEPENDENCY = "jmockit-1.49";
static final String MOCKITO_JUPITER_DEPENDENCY = "mockito-junit-jupiter-3.12";