diff --git a/spring-test/src/main/java/org/springframework/test/context/TestContext.java b/spring-test/src/main/java/org/springframework/test/context/TestContext.java index cce8366ece80..3f6325341780 100644 --- a/spring-test/src/main/java/org/springframework/test/context/TestContext.java +++ b/spring-test/src/main/java/org/springframework/test/context/TestContext.java @@ -42,6 +42,7 @@ * override {@link #setMethodInvoker(MethodInvoker)} and {@link #getMethodInvoker()}. * * @author Sam Brannen + * @author Andreas Ahlenstorf * @since 2.5 * @see TestContextManager * @see TestExecutionListener @@ -110,6 +111,25 @@ default void publishEvent(Function even */ Object getTestInstance(); + /** + * Tests whether a test method is part of this test context. Returns + * {@code true} if this context has a current test method, {@code false} + * otherwise. + * + *

The default implementation of this method always returns {@code false}. + * Custom {@code TestContext} implementations are therefore highly encouraged + * to override this method with a more meaningful implementation. Note that + * the standard {@code TestContext} implementation in Spring overrides this + * method appropriately. + * @return {@code true} if the test execution has already entered a test + * method + * @since 6.1 + * @see #getTestMethod() + */ + default boolean hasTestMethod() { + return false; + } + /** * Get the current {@linkplain Method test method} for this test context. *

Note: this is a mutable property. diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java index 943b9edff869..6ab67c8cfa51 100644 --- a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java @@ -33,6 +33,11 @@ * *

Method-level declarations override class-level declarations by default, * but this behavior can be configured via {@link SqlMergeMode @SqlMergeMode}. + * However, this does not apply to class-level declarations that use + * {@link ExecutionPhase#BEFORE_TEST_CLASS} or + * {@link ExecutionPhase#AFTER_TEST_CLASS}. Such declarations are retained and + * scripts and statements are executed once per class in addition to any + * method-level annotations. * *

Script execution is performed by the {@link SqlScriptsTestExecutionListener}, * which is enabled by default. @@ -61,6 +66,7 @@ * modules as well as their transitive dependencies to be present on the classpath. * * @author Sam Brannen + * @author Andreas Ahlenstorf * @since 4.1 * @see SqlConfig * @see SqlMergeMode @@ -161,6 +167,18 @@ */ enum ExecutionPhase { + /** + * The configured SQL scripts and statements will be executed + * once before any test method is run. + */ + BEFORE_TEST_CLASS, + + /** + * The configured SQL scripts and statements will be executed + * once after any test method is run. + */ + AFTER_TEST_CLASS, + /** * The configured SQL scripts and statements will be executed * before the corresponding test method. diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java index c7e51561f8dd..90d98db5c4ca 100644 --- a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java @@ -67,10 +67,17 @@ * {@link Sql#scripts scripts} and inlined {@link Sql#statements statements} * configured via the {@link Sql @Sql} annotation. * - *

Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before} - * or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding - * {@linkplain java.lang.reflect.Method test method}, depending on the configured - * value of the {@link Sql#executionPhase executionPhase} flag. + *

Class-level annotations that are constrained to a class-level execution + * phase ({@link ExecutionPhase#BEFORE_TEST_CLASS} or + * {@link ExecutionPhase#AFTER_TEST_CLASS}) will be run + * {@linkplain #beforeTestClass(TestContext) once before all test methods} or + * {@linkplain #afterTestMethod(TestContext) once after all test methods}, + * respectively. All other scripts and inlined statements will be executed + * {@linkplain #beforeTestMethod(TestContext) before} or + * {@linkplain #afterTestMethod(TestContext) after} execution of the + * corresponding {@linkplain java.lang.reflect.Method test method}, depending + * on the configured value of the {@link Sql#executionPhase executionPhase} + * flag. * *

Scripts and inlined statements will be executed without a transaction, * within an existing Spring-managed transaction, or within an isolated transaction, @@ -98,6 +105,7 @@ * * @author Sam Brannen * @author Dmitry Semukhin + * @author Andreas Ahlenstorf * @since 4.1 * @see Sql * @see SqlConfig @@ -126,6 +134,26 @@ public final int getOrder() { return 5000; } + /** + * Execute SQL scripts configured via {@link Sql @Sql} for the supplied + * {@link TestContext} once per test class before any test method + * is run. + */ + @Override + public void beforeTestClass(TestContext testContext) throws Exception { + executeBeforeOrAfterClassSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_CLASS); + } + + /** + * Execute SQL scripts configured via {@link Sql @Sql} for the supplied + * {@link TestContext} once per test class after all test methods + * have been run. + */ + @Override + public void afterTestClass(TestContext testContext) throws Exception { + executeBeforeOrAfterClassSqlScripts(testContext, ExecutionPhase.AFTER_TEST_CLASS); + } + /** * Execute SQL scripts configured via {@link Sql @Sql} for the supplied * {@link TestContext} before the current test method. @@ -159,6 +187,17 @@ public void processAheadOfTime(RuntimeHints runtimeHints, Class testClass, Cl registerClasspathResources(getScripts(sql, testClass, testMethod, false), runtimeHints, classLoader))); } + /** + * Execute class-level SQL scripts configured via {@link Sql @Sql} for the + * supplied {@link TestContext} and the execution phases + * {@link ExecutionPhase#BEFORE_TEST_CLASS} and + * {@link ExecutionPhase#AFTER_TEST_CLASS}. + */ + private void executeBeforeOrAfterClassSqlScripts(TestContext testContext, ExecutionPhase executionPhase) { + Class testClass = testContext.getTestClass(); + executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true); + } + /** * Execute SQL scripts configured via {@link Sql @Sql} for the supplied * {@link TestContext} and {@link ExecutionPhase}. @@ -246,6 +285,9 @@ private void executeSqlScripts( private void executeSqlScripts( Sql sql, ExecutionPhase executionPhase, TestContext testContext, boolean classLevel) { + Assert.isTrue(classLevel || isValidMethodLevelPhase(sql.executionPhase()), + () -> "%s cannot be used on methods".formatted(sql.executionPhase())); + if (executionPhase != sql.executionPhase()) { return; } @@ -260,7 +302,12 @@ else if (logger.isDebugEnabled()) { .formatted(executionPhase, testContext.getTestClass().getName())); } - String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel); + Method testMethod = null; + if (testContext.hasTestMethod()) { + testMethod = testContext.getTestMethod(); + } + + String[] scripts = getScripts(sql, testContext.getTestClass(), testMethod, classLevel); List scriptResources = TestContextResourceUtils.convertToResourceList( testContext.getApplicationContext(), scripts); for (String stmt : sql.statements()) { @@ -354,7 +401,7 @@ private DataSource getDataSourceFromTransactionManager(PlatformTransactionManage return null; } - private String[] getScripts(Sql sql, Class testClass, Method testMethod, boolean classLevel) { + private String[] getScripts(Sql sql, Class testClass, @Nullable Method testMethod, boolean classLevel) { String[] scripts = sql.scripts(); if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) { scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)}; @@ -366,7 +413,9 @@ private String[] getScripts(Sql sql, Class testClass, Method testMethod, bool * Detect a default SQL script by implementing the algorithm defined in * {@link Sql#scripts}. */ - private String detectDefaultScript(Class testClass, Method testMethod, boolean classLevel) { + private String detectDefaultScript(Class testClass, @Nullable Method testMethod, boolean classLevel) { + Assert.state(classLevel || testMethod != null, "Method-level @Sql requires a testMethod"); + String elementType = (classLevel ? "class" : "method"); String elementName = (classLevel ? testClass.getName() : testMethod.toString()); @@ -407,4 +456,9 @@ private void registerClasspathResources(String[] paths, RuntimeHints runtimeHint .forEach(runtimeHints.resources()::registerResource); } + private static boolean isValidMethodLevelPhase(ExecutionPhase executionPhase) { + // Class-level phases cannot be used on methods. + return executionPhase == ExecutionPhase.BEFORE_TEST_METHOD || + executionPhase == ExecutionPhase.AFTER_TEST_METHOD; + } } diff --git a/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java b/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java index 03b6a25827ed..d95b2457e629 100644 --- a/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java +++ b/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java @@ -41,6 +41,7 @@ * @author Sam Brannen * @author Juergen Hoeller * @author Rob Harrop + * @author Andreas Ahlenstorf * @since 4.0 */ @SuppressWarnings("serial") @@ -166,6 +167,11 @@ public final Object getTestInstance() { return testInstance; } + @Override + public boolean hasTestMethod() { + return this.testMethod != null; + } + @Override public final Method getTestMethod() { Method testMethod = this.testMethod; diff --git a/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java b/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java index 38301afc42e7..addfd8ded263 100644 --- a/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java +++ b/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ * * @author Sam Brannen * @author Juergen Hoeller + * @author Andreas Ahlenstorf * @since 4.1 */ public abstract class TestContextTransactionUtils { @@ -227,7 +228,8 @@ private static void logBeansException(TestContext testContext, BeansException ex /** * Create a delegating {@link TransactionAttribute} for the supplied target * {@link TransactionAttribute} and {@link TestContext}, using the names of - * the test class and test method to build the name of the transaction. + * the test class and test method (if available) to build the name of the + * transaction. * @param testContext the {@code TestContext} upon which to base the name * @param targetAttribute the {@code TransactionAttribute} to delegate to * @return the delegating {@code TransactionAttribute} @@ -248,7 +250,13 @@ private static class TestContextTransactionAttribute extends DelegatingTransacti public TestContextTransactionAttribute(TransactionAttribute targetAttribute, TestContext testContext) { super(targetAttribute); - this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass()); + + if (testContext.hasTestMethod()) { + this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass()); + } + else { + this.name = testContext.getTestClass().getName(); + } } @Override diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/AfterTestClassSqlScriptsTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/AfterTestClassSqlScriptsTests.java new file mode 100644 index 000000000000..1d6633288ead --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/AfterTestClassSqlScriptsTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.jdbc; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; + +import org.springframework.core.Ordered; +import org.springframework.jdbc.BadSqlGrammarException; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.annotation.Commit; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestExecutionListener; +import org.springframework.test.context.TestExecutionListeners; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.context.transaction.TestContextTransactionUtils; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Verifies that {@link Sql @Sql} with {@link Sql.ExecutionPhase#AFTER_TEST_CLASS} is run after all tests in the class + * have been run. + * + * @author Andreas Ahlenstorf + * @since 6.1 + */ +@SpringJUnitConfig(PopulatedSchemaDatabaseConfig.class) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) +@Sql(value = {"drop-schema.sql"}, executionPhase = Sql.ExecutionPhase.AFTER_TEST_CLASS) +@TestExecutionListeners( + value = AfterTestClassSqlScriptsTests.VerifyTestExecutionListener.class, + mergeMode = TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS +) +class AfterTestClassSqlScriptsTests extends AbstractTransactionalTests { + + @Test + @Order(1) + @Sql(scripts = "data-add-catbert.sql") + @Commit + void databaseHasBeenInitialized() { + assertUsers("Catbert"); + } + + @Test + @Order(2) + @Sql(scripts = "data-add-dogbert.sql") + @Commit + void databaseIsNotWipedBetweenTests() { + assertUsers("Catbert", "Dogbert"); + } + + static class VerifyTestExecutionListener implements TestExecutionListener, Ordered { + + @Override + public void afterTestClass(TestContext testContext) throws Exception { + DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, null); + JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); + + assertThatExceptionOfType(BadSqlGrammarException.class) + .isThrownBy(() -> jdbcTemplate.queryForList("SELECT name FROM user", String.class)); + } + + @Override + public int getOrder() { + // Must run before DirtiesContextTestExecutionListener. Otherwise, the old data source will be removed and + // replaced with a new one. + return 3001; + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/BeforeTestClassSqlScriptsTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/BeforeTestClassSqlScriptsTests.java new file mode 100644 index 000000000000..acaae772cd76 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/BeforeTestClassSqlScriptsTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.jdbc; + +import org.junit.jupiter.api.Test; + +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.MERGE; +import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.OVERRIDE; + +/** + * Verifies that {@link Sql @Sql} with {@link Sql.ExecutionPhase#BEFORE_TEST_CLASS} is run before all tests in the class + * have been run. + * + * @author Andreas Ahlenstorf + * @since 6.1 + */ +@SpringJUnitConfig(classes = EmptyDatabaseConfig.class) +@DirtiesContext +@Sql(value = {"schema.sql", "data-add-catbert.sql"}, executionPhase = Sql.ExecutionPhase.BEFORE_TEST_CLASS) +class BeforeTestClassSqlScriptsTests extends AbstractTransactionalTests { + + @Test + void classLevelScriptsHaveBeenRun() { + assertUsers("Catbert"); + } + + @Test + @Sql("data-add-dogbert.sql") + @SqlMergeMode(MERGE) + void mergeDoesNotAffectClassLevelPhase() { + assertUsers("Catbert", "Dogbert"); + } + + @Test + @Sql({"data-add-dogbert.sql"}) + @SqlMergeMode(OVERRIDE) + void overrideDoesNotAffectClassLevelPhase() { + assertUsers("Dogbert", "Catbert"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java index 8d8a54b2c944..ff9f3b953f2a 100644 --- a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java @@ -25,6 +25,7 @@ import org.springframework.test.context.jdbc.SqlConfig.TransactionMode; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; @@ -34,6 +35,7 @@ * Unit tests for {@link SqlScriptsTestExecutionListener}. * * @author Sam Brannen + * @author Andreas Ahlenstorf * @since 4.1 */ class SqlScriptsTestExecutionListenerTests { @@ -56,6 +58,7 @@ void missingValueAndScriptsAndStatementsAtClassLevel() throws Exception { void missingValueAndScriptsAndStatementsAtMethodLevel() throws Exception { Class clazz = MissingValueAndScriptsAndStatementsAtMethodLevel.class; BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.hasTestMethod()).willReturn(true); given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("foo")); assertExceptionContains(clazz.getSimpleName() + ".foo" + ".sql"); @@ -102,6 +105,30 @@ void missingDataSourceAndTxMgr() throws Exception { assertExceptionContains("supply at least a DataSource or PlatformTransactionManager"); } + @Test + void beforeTestClassOnMethod() throws Exception { + Class clazz = ClassLevelExecutionPhaseOnMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.hasTestMethod()).willReturn(true); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("beforeTestClass")); + + assertThatIllegalArgumentException() + .isThrownBy(() -> listener.beforeTestMethod(testContext)) + .withMessage("BEFORE_TEST_CLASS cannot be used on methods"); + } + + @Test + void afterTestClassOnMethod() throws Exception { + Class clazz = ClassLevelExecutionPhaseOnMethod.class; + BDDMockito.> given(testContext.getTestClass()).willReturn(clazz); + given(testContext.hasTestMethod()).willReturn(true); + given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("afterTestClass")); + + assertThatIllegalArgumentException() + .isThrownBy(() -> listener.beforeTestMethod(testContext)) + .withMessage("AFTER_TEST_CLASS cannot be used on methods"); + } + private void assertExceptionContains(String msg) throws Exception { assertThatIllegalStateException().isThrownBy(() -> listener.beforeTestMethod(testContext)) @@ -146,4 +173,14 @@ public void foo() { } } + static class ClassLevelExecutionPhaseOnMethod { + + @Sql(scripts = "foo.sql", executionPhase = Sql.ExecutionPhase.BEFORE_TEST_CLASS) + public void beforeTestClass() { + } + + @Sql(scripts = "foo.sql", executionPhase = Sql.ExecutionPhase.AFTER_TEST_CLASS) + public void afterTestClass() { + } + } }