Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix @InjectMock and @InjectSpy handling of @Nested tests #19574

Merged
merged 2 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.quarkus.it.mockbean;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.is;

import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.mockito.InjectMock;

@QuarkusTest
public class NestedTest {

@InjectMock
MessageService messageService;

@Nested
public class ActualTest {

@InjectMock
SuffixService suffixService;

@Test
public void testGreet() {
Mockito.when(messageService.getMessage()).thenReturn("hi");
Mockito.when(suffixService.getSuffix()).thenReturn("!");

given()
.when().get("/greeting")
.then()
.statusCode(200)
.body(is("HI!"));
}

@Test
public void testGreetAgain() {
Mockito.when(messageService.getMessage()).thenReturn("yolo");
Mockito.when(suffixService.getSuffix()).thenReturn("!!!");

given()
.when().get("/greeting")
.then()
.statusCode(200)
.body(is("YOLO!!!"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkus.test.junit.mockito.internal;

import io.quarkus.test.junit.callback.QuarkusTestAfterAllCallback;
import io.quarkus.test.junit.callback.QuarkusTestContext;

public class ResetOuterMockitoMocksCallback implements QuarkusTestAfterAllCallback {

@Override
public void afterAll(QuarkusTestContext context) {
if (context.getOuterInstance() != null) {
MockitoMocksTracker.reset(context.getOuterInstance());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public class SetMockitoMockAsBeanMockCallback implements QuarkusTestBeforeEachCa
@Override
public void beforeEach(QuarkusTestMethodContext context) {
MockitoMocksTracker.getMocks(context.getTestInstance()).forEach(this::installMock);
if (context.getOuterInstance() != null) {
MockitoMocksTracker.getMocks(context.getOuterInstance()).forEach(this::installMock);
}
}

private void installMock(MockitoMocksTracker.Mocked mocked) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
io.quarkus.test.junit.mockito.internal.ResetOuterMockitoMocksCallback
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@
import io.quarkus.test.common.http.TestHTTPEndpoint;
import io.quarkus.test.common.http.TestHTTPResourceManager;
import io.quarkus.test.junit.buildchain.TestBuildChainCustomizerProducer;
import io.quarkus.test.junit.callback.QuarkusTestAfterAllCallback;
import io.quarkus.test.junit.callback.QuarkusTestAfterConstructCallback;
import io.quarkus.test.junit.callback.QuarkusTestAfterEachCallback;
import io.quarkus.test.junit.callback.QuarkusTestBeforeClassCallback;
import io.quarkus.test.junit.callback.QuarkusTestBeforeEachCallback;
import io.quarkus.test.junit.callback.QuarkusTestContext;
import io.quarkus.test.junit.callback.QuarkusTestMethodContext;
import io.quarkus.test.junit.internal.DeepClone;
import io.quarkus.test.junit.internal.SerializationWithXStreamFallbackDeepClone;
Expand All @@ -131,16 +133,18 @@ public class QuarkusTestExtension

private static Class<?> actualTestClass;
private static Object actualTestInstance;
// needed for @Nested
private static Object outerInstance;
private static ClassLoader originalCl;
private static RunningQuarkusApplication runningQuarkusApplication;
private static Pattern clonePattern;
private static Throwable firstException; //if this is set then it will be thrown from the very first test that is run, the rest are aborted

private static List<Object> beforeClassCallbacks;
private static List<Object> afterConstructCallbacks;
private static List<Object> legacyAfterConstructCallbacks;
private static List<Object> beforeEachCallbacks;
private static List<Object> afterEachCallbacks;
private static List<Object> afterAllCallbacks;
private static Class<?> quarkusTestMethodContextClass;
private static Class<? extends QuarkusTestProfile> quarkusTestProfile;
private static boolean hasPerTestResources;
Expand Down Expand Up @@ -473,9 +477,9 @@ private void populateCallbacks(ClassLoader classLoader) throws ClassNotFoundExce
quarkusTestMethodContextClass = null;
beforeClassCallbacks = new ArrayList<>();
afterConstructCallbacks = new ArrayList<>();
legacyAfterConstructCallbacks = new ArrayList<>();
beforeEachCallbacks = new ArrayList<>();
afterEachCallbacks = new ArrayList<>();
afterAllCallbacks = new ArrayList<>();

ServiceLoader<?> quarkusTestBeforeClassLoader = ServiceLoader
.load(Class.forName(QuarkusTestBeforeClassCallback.class.getName(), false, classLoader), classLoader);
Expand All @@ -497,6 +501,11 @@ private void populateCallbacks(ClassLoader classLoader) throws ClassNotFoundExce
for (Object quarkusTestAfterEach : quarkusTestAfterEachLoader) {
afterEachCallbacks.add(quarkusTestAfterEach);
}
ServiceLoader<?> quarkusTestAfterAllLoader = ServiceLoader
.load(Class.forName(QuarkusTestAfterAllCallback.class.getName(), false, classLoader), classLoader);
for (Object quarkusTestAfterAll : quarkusTestAfterAllLoader) {
afterAllCallbacks.add(quarkusTestAfterAll);
}
}

private void populateTestMethodInvokers(ClassLoader quarkusClassLoader) {
Expand Down Expand Up @@ -642,9 +651,9 @@ public void afterEach(ExtensionContext context) throws Exception {
throw new RuntimeException("Could not find method " + originalTestMethod + " on test class");
}

Constructor<?> constructor = quarkusTestMethodContextClass.getConstructor(Object.class, Method.class);
Constructor<?> constructor = quarkusTestMethodContextClass.getConstructor(Object.class, Object.class, Method.class);
return new AbstractMap.SimpleEntry<>(quarkusTestMethodContextClass,
constructor.newInstance(actualTestInstance, actualTestMethod));
constructor.newInstance(actualTestInstance, outerInstance, actualTestMethod));
}

private boolean isNativeOrIntegrationTest(Class<?> clazz) {
Expand Down Expand Up @@ -851,12 +860,13 @@ private void initTestState(ExtensionContext extensionContext, ExtensionState sta
Class<?> previousActualTestClass = actualTestClass;
actualTestClass = Class.forName(extensionContext.getRequiredTestClass().getName(), true,
Thread.currentThread().getContextClassLoader());
outerInstance = null;
if (extensionContext.getRequiredTestClass().isAnnotationPresent(Nested.class)) {
Class<?> parent = actualTestClass.getEnclosingClass();
Object parentInstance = runningQuarkusApplication.instance(parent);
Constructor<?> declaredConstructor = actualTestClass.getDeclaredConstructor(parent);
Class<?> outerClass = actualTestClass.getEnclosingClass();
outerInstance = runningQuarkusApplication.instance(outerClass);
Constructor<?> declaredConstructor = actualTestClass.getDeclaredConstructor(outerClass);
declaredConstructor.setAccessible(true);
actualTestInstance = declaredConstructor.newInstance(parentInstance);
actualTestInstance = declaredConstructor.newInstance(outerInstance);
} else {
actualTestInstance = runningQuarkusApplication.instance(actualTestClass);
}
Expand All @@ -870,9 +880,11 @@ private void initTestState(ExtensionContext extensionContext, ExtensionState sta
afterConstructCallback.getClass().getMethod("afterConstruct", Object.class).invoke(afterConstructCallback,
actualTestInstance);
}
for (Object legacyAfterConstructCallback : legacyAfterConstructCallbacks) {
legacyAfterConstructCallback.getClass().getMethod("beforeAll", Object.class)
.invoke(legacyAfterConstructCallback, actualTestInstance);
if (outerInstance != null) {
for (Object afterConstructCallback : afterConstructCallbacks) {
afterConstructCallback.getClass().getMethod("afterConstruct", Object.class).invoke(afterConstructCallback,
outerInstance);
}
}
} catch (Exception e) {
throw new TestInstantiationException("Failed to create test instance", e);
Expand Down Expand Up @@ -1108,6 +1120,7 @@ private Method determineTCCLExtensionMethod(ReflectiveInvocationContext<Method>
@Override
public void afterAll(ExtensionContext context) throws Exception {
resetHangTimeout();
runAfterAllCallbacks(context);
try {
if (!isNativeOrIntegrationTest(context.getRequiredTestClass()) && (runningQuarkusApplication != null)) {
popMockContext();
Expand All @@ -1117,6 +1130,31 @@ public void afterAll(ExtensionContext context) throws Exception {
}
} finally {
currentTestClassStack.pop();
outerInstance = null;
}
}

private void runAfterAllCallbacks(ExtensionContext context) throws Exception {
if (isNativeOrIntegrationTest(context.getRequiredTestClass())) {
return;
}
if (afterAllCallbacks.isEmpty()) {
return;
}

Class<?> quarkusTestContextClass = Class.forName(QuarkusTestContext.class.getName(), true,
runningQuarkusApplication.getClassLoader());
Object quarkusTestContextInstance = quarkusTestContextClass.getConstructor(Object.class, Object.class)
.newInstance(actualTestInstance, outerInstance);

ClassLoader original = setCCL(runningQuarkusApplication.getClassLoader());
try {
for (Object afterAllCallback : afterAllCallbacks) {
afterAllCallback.getClass().getMethod("afterAll", quarkusTestContextClass)
.invoke(afterAllCallback, quarkusTestContextInstance);
}
} finally {
setCCL(original);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkus.test.junit.callback;

/**
* Can be implemented by classes that shall be called after all test methods in a {@code @QuarkusTest} have been run.
* <p>
* The implementing class has to be {@linkplain java.util.ServiceLoader deployed as service provider on the class path}.
*/
public interface QuarkusTestAfterAllCallback {

void afterAll(QuarkusTestContext context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkus.test.junit.callback;

/**
* Context object passed to {@link QuarkusTestAfterAllCallback}
*/
public class QuarkusTestContext {

private final Object testInstance;
private final Object outerInstance;

public QuarkusTestContext(Object testInstance, Object outerInstance) {
this.testInstance = testInstance;
this.outerInstance = outerInstance;
}

public Object getTestInstance() {
return testInstance;
}

public Object getOuterInstance() {
return outerInstance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,15 @@
/**
* Context object passed to {@link QuarkusTestBeforeEachCallback} and {@link QuarkusTestAfterEachCallback}
*/
public final class QuarkusTestMethodContext {
public final class QuarkusTestMethodContext extends QuarkusTestContext {

private final Object testInstance;
private final Method testMethod;

public QuarkusTestMethodContext(Object testInstance, Method testMethod) {
this.testInstance = testInstance;
public QuarkusTestMethodContext(Object testInstance, Object outerInstance, Method testMethod) {
super(testInstance, outerInstance);
this.testMethod = testMethod;
}

public Object getTestInstance() {
return testInstance;
}

public Method getTestMethod() {
return testMethod;
}
Expand Down