diff --git a/testing/trino-testing-services/pom.xml b/testing/trino-testing-services/pom.xml
index d9f53a201609b..e97457687dd75 100644
--- a/testing/trino-testing-services/pom.xml
+++ b/testing/trino-testing-services/pom.xml
@@ -53,6 +53,13 @@
true
+
+ org.junit.jupiter
+ junit-jupiter-api
+
+ true
+
+
org.openjdk.jmh
jmh-core
diff --git a/testing/trino-testing-services/src/main/java/io/trino/junit/ReportBadJunitTestAnnotations.java b/testing/trino-testing-services/src/main/java/io/trino/junit/ReportBadJunitTestAnnotations.java
new file mode 100644
index 0000000000000..26028bc7cc157
--- /dev/null
+++ b/testing/trino-testing-services/src/main/java/io/trino/junit/ReportBadJunitTestAnnotations.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.junit;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.trino.testng.services.ReportBadTestAnnotations;
+import org.junit.jupiter.api.extension.ExtensionContext;
+import org.junit.jupiter.api.extension.TestInstanceFactoryContext;
+import org.junit.jupiter.api.extension.TestInstancePreConstructCallback;
+
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import static com.google.common.base.Throwables.getStackTraceAsString;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.testing.Listeners.reportListenerFailure;
+import static java.util.stream.Collectors.joining;
+
+public class ReportBadJunitTestAnnotations
+ implements TestInstancePreConstructCallback
+{
+ @Override
+ public void preConstructTestInstance(TestInstanceFactoryContext factoryContext, ExtensionContext context)
+ {
+ Class> testClass = factoryContext.getTestClass();
+ try {
+ reportBadTestAnnotations(testClass);
+ }
+ catch (RuntimeException | Error e) {
+ reportListenerFailure(
+ ReportBadTestAnnotations.class,
+ "Failed to process %s: \n%s",
+ testClass,
+ getStackTraceAsString(e));
+ }
+ }
+
+ private void reportBadTestAnnotations(Class> testClass)
+ {
+ List unannotatedTestMethods = findUnannotatedInheritedTestMethods(testClass);
+ if (!unannotatedTestMethods.isEmpty()) {
+ reportListenerFailure(
+ ReportBadJunitTestAnnotations.class,
+ "Test class %s has methods which are inherited but not explicitly annotated. Are they missing @Test?%s",
+ testClass.getName(),
+ unannotatedTestMethods.stream()
+ .map(Method::toString)
+ .collect(joining("\n\t\t", "\n\t\t", "")));
+ }
+ }
+
+ @VisibleForTesting
+ static List findUnannotatedInheritedTestMethods(Class> realClass)
+ {
+ return Arrays.stream(realClass.getMethods())
+ .filter(method -> method.getDeclaringClass() != Object.class)
+ .filter(method -> !Modifier.isStatic(method.getModifiers()))
+ .filter(method -> !method.isBridge())
+ .filter(method -> isUnannotated(method) && overriddenMethodHasTestAnnotation(method))
+ .collect(toImmutableList());
+ }
+
+ private static boolean isUnannotated(Method method)
+ {
+ return Arrays.stream(method.getAnnotations()).map(Annotation::annotationType)
+ .noneMatch(ReportBadJunitTestAnnotations::isJUnitAnnotation);
+ }
+
+ private static boolean isJUnitAnnotation(Class extends Annotation> clazz)
+ {
+ return clazz.getPackage().getName().startsWith("org.junit.jupiter.api");
+ }
+
+ private static boolean overriddenMethodHasTestAnnotation(Method method)
+ {
+ if (method.isAnnotationPresent(org.junit.jupiter.api.Test.class)) {
+ return true;
+ }
+
+ // Skip methods in Object class, e.g. toString()
+ if (method.getDeclaringClass() == Object.class) {
+ return false;
+ }
+
+ // The test class may override the default method of the interface
+ for (Class> interfaceClass : method.getDeclaringClass().getInterfaces()) {
+ Optional overridden = getOverridden(method, interfaceClass);
+ if (overridden.isPresent() && overridden.get().isAnnotationPresent(org.junit.jupiter.api.Test.class)) {
+ return true;
+ }
+ }
+
+ Class> superClass = method.getDeclaringClass().getSuperclass();
+ if (superClass == null) {
+ return false;
+ }
+ return getOverridden(method, superClass)
+ .map(ReportBadJunitTestAnnotations::overriddenMethodHasTestAnnotation)
+ .orElse(false);
+ }
+
+ private static Optional getOverridden(Method method, Class> base)
+ {
+ try {
+ // Simplistic override detection
+ return Optional.of(base.getMethod(method.getName(), method.getParameterTypes()));
+ }
+ catch (NoSuchMethodException ignored) {
+ return Optional.empty();
+ }
+ }
+}
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/Listeners.java b/testing/trino-testing-services/src/main/java/io/trino/testing/Listeners.java
similarity index 85%
rename from testing/trino-testing-services/src/main/java/io/trino/testng/services/Listeners.java
rename to testing/trino-testing-services/src/main/java/io/trino/testing/Listeners.java
index 8709d65089ef4..73a8f50616eaf 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/Listeners.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testing/Listeners.java
@@ -11,27 +11,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package io.trino.testng.services;
+package io.trino.testing;
import com.google.common.base.Joiner;
import com.google.errorprone.annotations.FormatMethod;
import org.testng.ITestClass;
-import org.testng.ITestNGListener;
import org.testng.ITestResult;
import static java.lang.String.format;
-final class Listeners
+public final class Listeners
{
private Listeners() {}
/**
* Print error to standard error and exit JVM.
*
- * @apiNote A TestNG listener cannot throw an exception, as this are not currently properly handled by TestNG.
+ * @apiNote A TestNG listener and JUnit extension cannot throw an exception, as this are not currently properly handled by them.
*/
@FormatMethod
- public static void reportListenerFailure(Class extends ITestNGListener> listenerClass, String format, Object... args)
+ public static void reportListenerFailure(Class> listenerClass, String format, Object... args)
{
System.err.println(format("FATAL: %s: ", listenerClass.getName()) + format(format, args));
System.err.println("JVM will be terminated");
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/FlakyAnnotationVerifier.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/FlakyAnnotationVerifier.java
index 292a02ae94ec8..692421539bfbc 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/FlakyAnnotationVerifier.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/FlakyAnnotationVerifier.java
@@ -28,7 +28,7 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableList.toImmutableList;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/LogTestDurationListener.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/LogTestDurationListener.java
index f5e0e4d726b70..d4e9daca5e6fb 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/LogTestDurationListener.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/LogTestDurationListener.java
@@ -39,8 +39,8 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.units.Duration.nanosSince;
-import static io.trino.testng.services.Listeners.formatTestName;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.formatTestName;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.lang.String.format;
import static java.lang.management.ManagementFactory.getThreadMXBean;
import static java.util.concurrent.TimeUnit.MINUTES;
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ManageTestResources.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ManageTestResources.java
index 3b3046f54f572..4501579ed2b1f 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ManageTestResources.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ManageTestResources.java
@@ -38,7 +38,7 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.MoreCollectors.toOptional;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static io.trino.testng.services.ManageTestResources.Stage.AFTER_CLASS;
import static io.trino.testng.services.ManageTestResources.Stage.BEFORE_CLASS;
import static java.lang.annotation.ElementType.FIELD;
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ProgressLoggingListener.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ProgressLoggingListener.java
index b436c82c43754..c1ddaff698f50 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ProgressLoggingListener.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ProgressLoggingListener.java
@@ -25,7 +25,7 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
-import static io.trino.testng.services.Listeners.formatTestName;
+import static io.trino.testing.Listeners.formatTestName;
import static java.lang.String.format;
public class ProgressLoggingListener
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportAfterMethodNotAlwaysRun.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportAfterMethodNotAlwaysRun.java
index 0443fe6c68b2d..34c0b01bd42c7 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportAfterMethodNotAlwaysRun.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportAfterMethodNotAlwaysRun.java
@@ -32,7 +32,7 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableList.toImmutableList;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.util.Objects.requireNonNull;
import static java.util.function.Predicate.not;
import static java.util.stream.Collectors.joining;
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportBadTestAnnotations.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportBadTestAnnotations.java
index 0c9470c50127a..ef9a8d82308e7 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportBadTestAnnotations.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportBadTestAnnotations.java
@@ -28,7 +28,7 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableList.toImmutableList;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportIllNamedTest.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportIllNamedTest.java
index f639a273c884f..d9cc99d133082 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportIllNamedTest.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportIllNamedTest.java
@@ -17,7 +17,7 @@
import org.testng.ITestClass;
import static com.google.common.base.Throwables.getStackTraceAsString;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static io.trino.testng.services.ReportBadTestAnnotations.isTemptoClass;
public class ReportIllNamedTest
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportInnerTestClasses.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportInnerTestClasses.java
index 09f4f3e870ff0..14df836219b88 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportInnerTestClasses.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportInnerTestClasses.java
@@ -19,7 +19,7 @@
import java.util.Optional;
import static com.google.common.base.Throwables.getStackTraceAsString;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
/**
* Detects test classes which are defined as inner classes
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportMultiThreadedBeforeOrAfterMethod.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportMultiThreadedBeforeOrAfterMethod.java
index 0d5ad0f11cd76..92c460a9e2f9d 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportMultiThreadedBeforeOrAfterMethod.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportMultiThreadedBeforeOrAfterMethod.java
@@ -25,7 +25,7 @@
import java.lang.reflect.Method;
import static com.google.common.base.Throwables.getStackTraceAsString;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.lang.String.format;
public class ReportMultiThreadedBeforeOrAfterMethod
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportOrphanedExecutors.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportOrphanedExecutors.java
index 84ea1658d2fe2..295bf4499ff37 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportOrphanedExecutors.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportOrphanedExecutors.java
@@ -26,7 +26,7 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.lang.String.format;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
diff --git a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportPrivateMethods.java b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportPrivateMethods.java
index 88a0dd18aaff4..3723deb733f14 100644
--- a/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportPrivateMethods.java
+++ b/testing/trino-testing-services/src/main/java/io/trino/testng/services/ReportPrivateMethods.java
@@ -28,7 +28,7 @@
import static com.google.common.base.Throwables.getStackTraceAsString;
import static com.google.common.collect.ImmutableList.toImmutableList;
-import static io.trino.testng.services.Listeners.reportListenerFailure;
+import static io.trino.testing.Listeners.reportListenerFailure;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static java.util.stream.Collectors.joining;
diff --git a/testing/trino-testing-services/src/main/resources/META-INF/services/org.junit.jupiter.api.extension.Extension b/testing/trino-testing-services/src/main/resources/META-INF/services/org.junit.jupiter.api.extension.Extension
new file mode 100644
index 0000000000000..73ce00e7550b5
--- /dev/null
+++ b/testing/trino-testing-services/src/main/resources/META-INF/services/org.junit.jupiter.api.extension.Extension
@@ -0,0 +1 @@
+io.trino.junit.ReportBadJunitTestAnnotations
diff --git a/testing/trino-testing-services/src/main/resources/junit-platform.properties b/testing/trino-testing-services/src/main/resources/junit-platform.properties
new file mode 100644
index 0000000000000..6efc0d5e85ce0
--- /dev/null
+++ b/testing/trino-testing-services/src/main/resources/junit-platform.properties
@@ -0,0 +1 @@
+junit.jupiter.extensions.autodetection.enabled=true
diff --git a/testing/trino-testing-services/src/test/java/io/trino/junit/TestReportBadJunitTestAnnotations.java b/testing/trino-testing-services/src/test/java/io/trino/junit/TestReportBadJunitTestAnnotations.java
new file mode 100644
index 0000000000000..690866423aa1c
--- /dev/null
+++ b/testing/trino-testing-services/src/test/java/io/trino/junit/TestReportBadJunitTestAnnotations.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.junit;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+
+import java.lang.reflect.Method;
+
+import static io.trino.junit.ReportBadJunitTestAnnotations.findUnannotatedInheritedTestMethods;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class TestReportBadJunitTestAnnotations
+{
+ @Test
+ public void testTest()
+ {
+ assertThat(findUnannotatedInheritedTestMethods(TestingTest.class))
+ .isEmpty();
+ assertThat(findUnannotatedInheritedTestMethods(TestingTestWithProxy.class))
+ .isEmpty();
+ assertThat(findUnannotatedInheritedTestMethods(TestingBeforeAfterAnnotations.class))
+ .isEmpty();
+ assertThat(findUnannotatedInheritedTestMethods(TestingTestWithoutTestAnnotation.class))
+ .isEmpty();
+ }
+
+ @Test
+ public void testTestWithoutTestAnnotation()
+ {
+ assertThat(findUnannotatedInheritedTestMethods(TestingTestWithoutAnnotation.class))
+ .extracting(Method::getName)
+ .containsExactly("testInInterface");
+ }
+
+ private static class TestingTest
+ implements TestingInterfaceWithTest
+ {
+ @Test
+ public void test() {}
+ }
+
+ private static class TestingTestWithoutAnnotation
+ implements TestingInterfaceWithTest
+ {
+ @Override
+ public void testInInterface()
+ {
+ TestingInterfaceWithTest.super.testInInterface();
+ }
+ }
+
+ private static class TestingTestWithProxy
+ extends TestingInterfaceWithTestProxy
+ {
+ @Test
+ public void test() {}
+ }
+
+ private static class TestingTestWithoutTestAnnotation
+ implements TestingInterface
+ {
+ public void testWithMissingTestAnnotation() {}
+
+ @Override
+ public String toString()
+ {
+ return "test override";
+ }
+ }
+
+ private static class TestingInterfaceWithTestProxy
+ implements TestingInterfaceWithTest {}
+
+ private interface TestingInterfaceWithTest
+ {
+ @Test
+ default void testInInterface() {}
+ }
+
+ private interface TestingInterface
+ {
+ default void methodInInterface() {}
+ }
+
+ private static class TestingBeforeAfterAnnotations
+ extends BaseTest {}
+
+ private static class BaseTest
+ {
+ @BeforeAll
+ @BeforeClass
+ public final void initialize() {}
+
+ @AfterAll
+ @AfterClass(alwaysRun = true)
+ public final void destroy() {}
+ }
+}