diff --git a/junit-jupiter-migration-support/src/main/java/org/junit/jupiter/migrationsupport/rules/ParameterizedExtension.java b/junit-jupiter-migration-support/src/main/java/org/junit/jupiter/migrationsupport/rules/ParameterizedExtension.java new file mode 100644 index 000000000000..4cf368cbdd7c --- /dev/null +++ b/junit-jupiter-migration-support/src/main/java/org/junit/jupiter/migrationsupport/rules/ParameterizedExtension.java @@ -0,0 +1,266 @@ +/* + * Copyright 2015-2017 the original author or authors. + * + * All rights reserved. This program and the accompanying materials are + * made available under the terms of the Eclipse Public License v1.0 which + * accompanies this distribution and is available at + * + * http://www.eclipse.org/legal/epl-v10.html + */ + +package org.junit.jupiter.migrationsupport.rules; + +import static java.util.Collections.singletonList; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.toList; +import static org.junit.platform.commons.meta.API.Usage.Experimental; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import org.junit.jupiter.api.extension.BeforeTestExecutionCallback; +import org.junit.jupiter.api.extension.ContainerExtensionContext; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.api.extension.ParameterResolver; +import org.junit.jupiter.api.extension.TestExtensionContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider; +import org.junit.platform.commons.meta.API; +import org.junit.platform.commons.util.ReflectionUtils; +import org.junit.runners.Parameterized; + +@API(Experimental) +public class ParameterizedExtension implements TestTemplateInvocationContextProvider, ParameterResolver { + private static ExtensionContext.Namespace parameters = ExtensionContext.Namespace.create( + ParameterizedExtension.class);; + private AtomicInteger parameterCount = new AtomicInteger(); + + /** + * Indicate whether we can provide parameterized support. + * This requires the testClass to either have a static {@code @Parameters} method + * and correct {@code @Parameter} and their corresponding values + * or to have a constructor that could be injected. + */ + @Override + public boolean supports(ContainerExtensionContext context) { + return hasParametersMethod(context) && hasCorrectParameterFields(context); + } + + @Override + public Iterator provide(ContainerExtensionContext context) { + //grabbing the parent ensures the paremeters are stored in the same store. + return context.getParent().flatMap(ParameterizedExtension::parameters).map( + ParameterizedExtension::testTemplateContextsFromParameters).orElse(Collections.emptyIterator()); + } + + /** + * Since the parameterized runner in JUnit 4 could only resolve constructor parameters + * this extension once again here only support them on the constructor and require an {@code @Parameters} method + * + * @param parameterContext the context for the parameter to be resolved; never + * {@code null} + * @param extensionContext the extension context for the {@code Executable} + * about to be invoked; never {@code null} + * @return true if the above is met otherwise false. + */ + @Override + public boolean supports(ParameterContext parameterContext, ExtensionContext extensionContext) { + return hasParametersMethod(extensionContext) + && parameterContext.getDeclaringExecutable() instanceof Constructor; + } + + @Override + public Object resolve(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + int parameterCount = parameterContext.getDeclaringExecutable().getParameterCount(); + Object[] parameters = resolveParametersForConstructor(extensionContext, parameterCount); + + int parameterIndex = parameterContext.getIndex(); + //move to the next set of parametersFields + if (lastParameterToBeResolved(parameterContext)) { + this.parameterCount.incrementAndGet(); + } + + return parameters[parameterIndex]; + } + + /** + * Retrieves the Object[] of the current iteration we are working on. + * + * @param extensionContext the extensionContext + * @param parameterCount the amount of parameters of the constructor. + * + * @return the object[] for this parameter iteration. + * @throws ParameterResolutionException If the amount of arguments of the constructor doesn't match the amount + * of arguments of the currently resolved object[] + */ + private Object[] resolveParametersForConstructor(ExtensionContext extensionContext, int parameterCount) + throws ParameterResolutionException { + return parameters(extensionContext).map(ArrayList::new).map(l -> l.get(this.parameterCount.get())).filter( + params -> params.length == parameterCount).orElseThrow( + ParameterizedExtension::unMatchedAmountOfParametersException); + } + + private static boolean hasCorrectParameterFields(ExtensionContext context) { + List fields = parametersFields(context); + boolean hasFieldInjection = !fields.isEmpty(); + + if (hasArgsConstructor(context) && hasFieldInjection) { + return false; + } + else if (hasFieldInjection) { + return areParametersFormedCorrectly(fields); + } + + return true; + } + + private static boolean areParametersFormedCorrectly(List fields) { + List parameterValues = parameterIndexes(fields); + + List duplicateIndexes = duplicatedIndexes(parameterValues); + + boolean hasAllIndexes = indexRangeComplete(parameterValues); + + return hasAllIndexes && duplicateIndexes.isEmpty(); + } + + private static List parameterIndexes(List fields) { + return fields.stream().map(f -> f.getAnnotation(Parameterized.Parameter.class)).map( + Parameterized.Parameter::value).collect(toList()); + } + + private static List duplicatedIndexes(List parameterValues) { + return parameterValues.stream().collect(groupingBy(identity())).entrySet().stream().filter( + e -> e.getValue().size() > 1).map(Map.Entry::getKey).collect(toList()); + } + + private static Boolean indexRangeComplete(List parameterValues) { + return parameterValues.stream().max(Integer::compareTo).map( + i -> parameterValues.containsAll(IntStream.range(0, i).boxed().collect(toList()))).orElse(false); + } + + private static boolean lastParameterToBeResolved(ParameterContext parameterContext) { + return parameterContext.getIndex() == parameterContext.getDeclaringExecutable().getParameterCount() - 1; + } + + private static Optional> parameters(ExtensionContext context) { + return context.getStore(parameters).getOrComputeIfAbsent("parameterMethod", + k -> new ParameterWrapper(callParameters(context)), ParameterWrapper.class).getValue(); + + } + + private static Optional> callParameters(ExtensionContext context) { + return findParametersMethod(context).map(m -> ReflectionUtils.invokeMethod(m, null)).map( + ParameterizedExtension::convertParametersMethodReturnType); + } + + private static boolean hasParametersMethod(ExtensionContext context) { + return findParametersMethod(context).isPresent(); + } + + private static Optional findParametersMethod(ExtensionContext extensionContext) { + return extensionContext.getTestClass().flatMap(ParameterizedExtension::ensureSingleParametersMethod).filter( + ReflectionUtils::isPublic); + } + + private static Optional ensureSingleParametersMethod(Class testClass) { + return ReflectionUtils.findMethods(testClass, + m -> m.isAnnotationPresent(Parameterized.Parameters.class)).stream().findFirst(); + } + + private static Iterator testTemplateContextsFromParameters(Collection o) { + return o.stream().map(ParameterizedExtension::contextFactory).iterator(); + } + + private static TestTemplateInvocationContext contextFactory(Object[] parameters) { + return new TestTemplateInvocationContext() { + @Override + public List getAdditionalExtensions() { + return singletonList(new InjectionExtension(parameters)); + } + }; + } + + private static class InjectionExtension implements BeforeTestExecutionCallback { + private final Object[] parameters; + + public InjectionExtension(Object[] parameters) { + this.parameters = parameters; + } + + @Override + public void beforeTestExecution(TestExtensionContext context) throws Exception { + List parameters = parametersFields(context); + + if (!parameters.isEmpty() && parameters.size() != this.parameters.length) { + throw unMatchedAmountOfParametersException(); + } + + for (Field param : parameters) { + Parameterized.Parameter annotation = param.getAnnotation(Parameterized.Parameter.class); + int paramIndex = annotation.value(); + param.set(context.getTestInstance(), this.parameters[paramIndex]); + } + } + } + + private static boolean hasArgsConstructor(ExtensionContext context) { + return context.getTestClass().map(ReflectionUtils::getDeclaredConstructor).filter( + c -> c.getParameterCount() > 0).isPresent(); + } + + private static List parametersFields(ExtensionContext context) { + Stream fieldStream = context.getTestClass().map(Class::getDeclaredFields).map(Stream::of).orElse( + Stream.empty()); + + return fieldStream.filter(f -> f.isAnnotationPresent(Parameterized.Parameter.class)).filter( + ReflectionUtils::isPublic).collect(toList()); + } + + private static ParameterResolutionException unMatchedAmountOfParametersException() { + return new ParameterResolutionException( + "The amount of parametersFields in the constructor doesn't match those in the provided parametersFields"); + } + + private static ParameterResolutionException wrongParametersReturnType() { + return new ParameterResolutionException("The @Parameters returns the wrong type"); + } + + @SuppressWarnings("unchecked") + private static Collection convertParametersMethodReturnType(Object o) { + if (o instanceof Collection) { + return (Collection) o; + } + else { + throw wrongParametersReturnType(); + } + } + + private static class ParameterWrapper { + private final Optional> value; + + public ParameterWrapper(Optional> value) { + this.value = value; + } + + public Optional> getValue() { + return value; + } + } +} diff --git a/junit-jupiter-migration-support/src/test/java/org/junit/jupiter/migrationsupport/ParameterizedExtentsionTests.java b/junit-jupiter-migration-support/src/test/java/org/junit/jupiter/migrationsupport/ParameterizedExtentsionTests.java new file mode 100644 index 000000000000..b2b68b93bf63 --- /dev/null +++ b/junit-jupiter-migration-support/src/test/java/org/junit/jupiter/migrationsupport/ParameterizedExtentsionTests.java @@ -0,0 +1,308 @@ +/* + * Copyright 2015-2017 the original author or authors. + * + * All rights reserved. This program and the accompanying materials are + * made available under the terms of the Eclipse Public License v1.0 which + * accompanies this distribution and is available at + * + * http://www.eclipse.org/legal/epl-v10.html + */ + +package org.junit.jupiter.migrationsupport; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass; +import static org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder.request; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ContainerExtensionContext; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.engine.JupiterTestEngine; +import org.junit.jupiter.migrationsupport.rules.ParameterizedExtension; +import org.junit.platform.engine.ExecutionRequest; +import org.junit.platform.engine.TestDescriptor; +import org.junit.platform.engine.TestExecutionResult; +import org.junit.platform.engine.UniqueId; +import org.junit.platform.engine.test.event.ExecutionEventRecorder; +import org.junit.platform.launcher.LauncherDiscoveryRequest; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +public class ParameterizedExtentsionTests { + + @Test + public void parametrizedWithParameterFieldInjection() { + ExecutionEventRecorder executionEventRecorder = executeTestsForClass(FibonacciTest.class); + assertThat(executionEventRecorder.getTestSuccessfulCount()).isEqualTo(7); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class FibonacciTest { + @Parameters + public static Collection data() { + return Arrays.asList( + new Object[][] { { 0, 0 }, { 1, 1 }, { 2, 1 }, { 3, 2 }, { 4, 3 }, { 5, 5 }, { 6, 8 } }); + } + + @Parameterized.Parameter + public int fInput; + + @Parameterized.Parameter(1) + public int fExpected; + + @TestTemplate + public void test() { + assertEquals(fExpected, compute(fInput)); + } + + private static int compute(int n) { + int result = 0; + + if (n <= 1) { + result = n; + } + else { + result = compute(n - 1) + compute(n - 2); + } + + return result; + } + } + + @Test + public void paremeterizedWithConstructorInjection() { + ExecutionEventRecorder executionEventRecorder = executeTestsForClass(ParameterizedTestWithConstructor.class); + assertThat(executionEventRecorder.getTestSuccessfulCount()).isEqualTo(7); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class ParameterizedTestWithConstructor { + @Parameters + public static Collection data() { + return Arrays.asList( + new Object[][] { { 0, 2 }, { 1, 3 }, { 2, 1 }, { 3, 2 }, { 4, 3 }, { 5, 8 }, { 6, 8 } }); + } + + private int a; + private int b; + + public ParameterizedTestWithConstructor(int a, int b) { + this.a = a; + this.b = b; + } + + @TestTemplate + public void test() { + assertNotEquals(a, b); + } + } + + @Test + void unMatchedConstructorArgumentCount() { + ExecutionEventRecorder eventRecorder = executeTestsForClass(UnMatchedConstructor.class); + assertThat(eventRecorder.getTestSuccessfulCount()).isEqualTo(0); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class UnMatchedConstructor { + @Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { { 0, 2 } }); + } + + public UnMatchedConstructor(int a) { + + } + + @TestTemplate + public void dummy() { + + } + } + + @Test + void unMatchedParameterFieldsCount() { + ExecutionEventRecorder executionEventRecorder = executeTestsForClass(WrongParameters.class); + assertThat(exceptionsThrown(executionEventRecorder)).allSatisfy( + e -> assertThat(e).isInstanceOf(ParameterResolutionException.class)); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class WrongParameters { + @Parameterized.Parameter + public int a; + @Parameterized.Parameter(1) + public int b; + + @Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { { 0 }, { 1 }, { 1 }, { 3 }, { 4 }, { 5 }, { 6 } }); + } + + @TestTemplate + public void dummy() { + + } + } + + @Test + void noInjectionMix() { + ParameterizedExtension extension = new ParameterizedExtension(); + + ContainerExtensionContext containerContext = mock(ContainerExtensionContext.class); + + when(containerContext.getTestClass()).thenReturn(Optional.of(DoubleInjection.class)); + assertFalse(extension.supports(containerContext)); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class DoubleInjection { + @Parameterized.Parameter + public int a; + + public DoubleInjection(int a) { + + } + + @Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { { 0 } }); + } + + @TestTemplate + public void dummy() { + + } + } + + @Test + void wrongReturnTypeFromParameters() { + ExecutionEventRecorder eventRecorder = executeTestsForClass(BadParameterReturnType.class); + assertThat(exceptionsThrown(eventRecorder)).allSatisfy(e -> { + assertThat(e).isInstanceOf(ParameterResolutionException.class); + assertThat(e).hasMessage("The @Parameters returns the wrong type"); + }); + } + + @ExtendWith(ParameterizedExtension.class) + private static class BadParameterReturnType { + public BadParameterReturnType(int a) { + + } + + @Parameters + public static int params() { + return 0; + } + + @TestTemplate + public void dummy() { + + } + } + + @Test + void emptyParametersList() { + ExecutionEventRecorder eventRecorder = executeTestsForClass(EmptyParameters.class); + assertThat(eventRecorder.getTestSuccessfulCount()).isEqualTo(2); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class EmptyParameters { + + public EmptyParameters() { + int a = 0; + int b = a + 3; + } + + @Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { {}, {} }); + } + + @TestTemplate + public void dummy() { + + } + } + + @Test + void duplicatedParameterFieldIndex() { + ExecutionEventRecorder eventRecorder = executeTestsForClass(DuplicatedIndex.class); + assertThat(eventRecorder.getTestSuccessfulCount()).isEqualTo(0); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class DuplicatedIndex { + @Parameterized.Parameter + public int a; + + @Parameterized.Parameter + public int b; + + @Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { {}, {} }); + } + + @TestTemplate + public void dummy() { + + } + } + + @Test + void parametersAreOnlyCalledOnce() { + ExecutionEventRecorder executionEventRecorder = executeTestsForClass(ParametersCalledOnce.class); + assertThat(executionEventRecorder.getTestSuccessfulCount()).isEqualTo(2); + } + + @ExtendWith(ParameterizedExtension.class) + protected static class ParametersCalledOnce { + private static int invocationCount = 0; + + public ParametersCalledOnce(int a) { + + } + + @Parameters + public static Collection data() { + invocationCount++; + return Arrays.asList(new Object[][] { { 3 }, { 4 } }); + } + + @TestTemplate + void dummy() { + assertEquals(invocationCount, 1); + } + } + + private ExecutionEventRecorder executeTestsForClass(Class testClass) { + LauncherDiscoveryRequest request = request().selectors(selectClass(testClass)).build(); + JupiterTestEngine engine = new JupiterTestEngine(); + TestDescriptor testDescriptor = engine.discover(request, UniqueId.forEngine(engine.getId())); + ExecutionEventRecorder eventRecorder = new ExecutionEventRecorder(); + engine.execute(new ExecutionRequest(testDescriptor, eventRecorder, request.getConfigurationParameters())); + return eventRecorder; + } + + private static List exceptionsThrown(ExecutionEventRecorder executionEventRecorder) { + return executionEventRecorder.getFailedTestFinishedEvents().stream().map( + it -> it.getPayload(TestExecutionResult.class)).map( + o -> o.flatMap(TestExecutionResult::getThrowable)).filter(Optional::isPresent).map( + Optional::get).collect(Collectors.toList()); + } +}