Skip to content

Commit

Permalink
Add ability to clean the DB
Browse files Browse the repository at this point in the history
- In tests this is done via @CleanDatabase
- In dev mode this is done from the dev console

Fixes #583
  • Loading branch information
stuartwdouglas committed Mar 18, 2021
1 parent 280967a commit e16231a
Show file tree
Hide file tree
Showing 22 changed files with 566 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/src/main/asciidoc/getting-started-testing.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,11 @@ TIP: It is possible to read annotations from the test class or method to control
WARNING: While it is possible to use JUnit Jupiter callback interfaces like `BeforeEachCallback`, you might run into classloading issues because Quarkus has
to run tests in a custom classloader which JUnit is not aware of.

== Reset the Database after tests

You can use the `@io.quarkus.test.ResetDatabase` annotation to reset the database after a test has run. This will drop the database,
and recreate the schema. Quarkus can use Liqibase, FlyWay or Hibernate ORM to reset the schema, depending on what is configured.

[[testing_different_profiles]]
== Testing Different Profiles

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package io.quarkus.agroal.deployment;

import static io.quarkus.deployment.annotations.ExecutionTime.STATIC_INIT;

import java.sql.Driver;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;

import javax.annotation.Priority;
import javax.enterprise.inject.Default;
import javax.inject.Singleton;
import javax.interceptor.Interceptor;
import javax.sql.XADataSource;

import org.jboss.jandex.DotName;
Expand All @@ -24,9 +29,12 @@
import io.quarkus.agroal.runtime.DataSources;
import io.quarkus.agroal.runtime.DataSourcesJdbcBuildTimeConfig;
import io.quarkus.agroal.runtime.TransactionIntegration;
import io.quarkus.agroal.runtime.schema.CleanDatabaseInterceptor;
import io.quarkus.agroal.spi.JdbcDataSourceBuildItem;
import io.quarkus.agroal.spi.JdbcDriverBuildItem;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.datasource.common.runtime.DataSourceUtil;
Expand All @@ -37,6 +45,7 @@
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
import io.quarkus.deployment.Feature;
import io.quarkus.deployment.IsTest;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
Expand All @@ -48,6 +57,10 @@
import io.quarkus.deployment.builditem.nativeimage.NativeImageResourceBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.pkg.builditem.CurateOutcomeBuildItem;
import io.quarkus.devconsole.spi.DevConsoleRouteBuildItem;
import io.quarkus.devconsole.spi.DevConsoleTemplateInfoBuildItem;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.narayana.jta.runtime.interceptor.TestTransactionInterceptor;
import io.quarkus.runtime.configuration.ConfigurationException;
import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem;

Expand All @@ -56,6 +69,7 @@ class AgroalProcessor {

private static final Logger log = Logger.getLogger(AgroalProcessor.class);

private static final String CLEAN_DATABASE = "io.quarkus.test.ResetDatabase";
private static final DotName DATA_SOURCE = DotName.createSimple(javax.sql.DataSource.class.getName());

@BuildStep
Expand Down Expand Up @@ -333,4 +347,36 @@ HealthBuildItem addHealthCheck(DataSourcesBuildTimeConfig dataSourcesBuildTimeCo
return new HealthBuildItem("io.quarkus.agroal.runtime.health.DataSourceHealthCheck",
dataSourcesBuildTimeConfig.healthEnabled);
}

@BuildStep
public DevConsoleTemplateInfoBuildItem devConsoleInfo(
List<AggregatedDataSourceBuildTimeConfigBuildItem> dbs) {
return new DevConsoleTemplateInfoBuildItem("dbs",
dbs.stream().map(AggregatedDataSourceBuildTimeConfigBuildItem::getName)
.collect(Collectors.toList()));
}

@BuildStep
@Record(value = STATIC_INIT, optional = true)
DevConsoleRouteBuildItem devConsoleCleanDatabaseHandler(AgroalRecorder recorder) {
return new DevConsoleRouteBuildItem("clean", "POST", recorder.devConsoleCleanDatabaseHandler());
}

@BuildStep(onlyIf = IsTest.class)
void cleanDatabaseSupport(BuildProducer<GeneratedBeanBuildItem> generatedBeanBuildItemBuildProducer,
BuildProducer<AdditionalBeanBuildItem> additionalBeans) {
//generate the annotated interceptor with gizmo
//all the logic is in the parent, but we don't have access to the
//binding annotation here
try (ClassCreator c = ClassCreator.builder()
.classOutput(new GeneratedBeanGizmoAdaptor(generatedBeanBuildItemBuildProducer)).className(
CleanDatabaseInterceptor.class.getName() + "Generated")
.superClass(TestTransactionInterceptor.class).build()) {
c.addAnnotation(CLEAN_DATABASE);
c.addAnnotation(Interceptor.class.getName());
c.addAnnotation(Priority.class).addValue("value", Interceptor.Priority.PLATFORM_BEFORE + 100);
}
additionalBeans.produce(AdditionalBeanBuildItem.builder().addBeanClass(CleanDatabaseInterceptor.class)
.addBeanClass(CLEAN_DATABASE).build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{#include main}
{#title}Clean Databases{/title}
{#body}
<table class="table table-striped">
<thead class="thead-dark">
<tr>
<th scope="col">Datasource</th>
<th scope="col">Actions</th>
</tr>
</thead>
<tbody>
{#for db in info:dbs}
<tr>
<td>
{db}
</td>
<td>
<form method="post" enctype="application/x-www-form-urlencoded">
<input type="hidden" name="name" value="{db}">
<input id="invoke" type="submit" value="Invoke" class="btn btn-primary btn-sm">
</form>
</td>
{/for}
</tbody>
</table>
{/body}
{/include}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<a href="{urlbase}/clean" class="badge badge-light">
<i class="fa fa-clock fa-fw"></i>
Clean Databases</a>
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package io.quarkus.agroal.runtime;

import java.util.ServiceLoader;
import java.util.function.Supplier;

import io.agroal.api.AgroalDataSource;
import io.quarkus.agroal.runtime.schema.DatabaseSchemaProvider;
import io.quarkus.datasource.runtime.DataSourcesRuntimeConfig;
import io.quarkus.devconsole.runtime.spi.DevConsolePostHandler;
import io.quarkus.runtime.annotations.Recorder;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.ext.web.RoutingContext;

@Recorder
public class AgroalRecorder {
Expand All @@ -29,4 +35,20 @@ public AgroalDataSource get() {
};
}

public Handler<RoutingContext> devConsoleCleanDatabaseHandler() {
// the usual issue of Vert.x hanging on to the first TCCL and setting it on all its threads
final ClassLoader currentCl = Thread.currentThread().getContextClassLoader();
return new DevConsolePostHandler() {
@Override
protected void handlePost(RoutingContext event, MultiMap form) throws Exception {
String name = form.get("name");
ServiceLoader<DatabaseSchemaProvider> dbs = ServiceLoader.load(DatabaseSchemaProvider.class,
Thread.currentThread().getContextClassLoader());
for (DatabaseSchemaProvider i : dbs) {
i.resetDatabase(name);
}
flashMessage(event, "Action invoked");
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.quarkus.agroal.runtime.schema;

import java.util.ArrayList;
import java.util.List;
import java.util.ServiceLoader;

import javax.interceptor.AroundInvoke;
import javax.interceptor.InvocationContext;

public class CleanDatabaseInterceptor {

final List<DatabaseSchemaProvider> providers;

public CleanDatabaseInterceptor() {
this.providers = new ArrayList<>();
ServiceLoader<DatabaseSchemaProvider> dbs = ServiceLoader.load(DatabaseSchemaProvider.class,
Thread.currentThread().getContextClassLoader());
for (DatabaseSchemaProvider i : dbs) {
providers.add(i);
}
}

@AroundInvoke
public Object intercept(InvocationContext context) throws Exception {
try {
return context.proceed();
} finally {
for (DatabaseSchemaProvider i : providers) {
i.resetAllDatabases();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkus.agroal.runtime.schema;

/**
* A service interface that can be used to reset the database for dev and test mode.
*/
public interface DatabaseSchemaProvider {

void resetDatabase(String dbName);

void resetAllDatabases();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkus.flyway.runtime;

import io.quarkus.agroal.runtime.schema.DatabaseSchemaProvider;

public class FlywaySchemaProvider implements DatabaseSchemaProvider {
@Override
public void resetDatabase(String dbName) {
for (FlywayContainer i : FlywayRecorder.FLYWAY_CONTAINERS) {
if (i.getDataSourceName().equals(dbName)) {
i.getFlyway().clean();
i.getFlyway().migrate();
}
}
}

@Override
public void resetAllDatabases() {
for (FlywayContainer i : FlywayRecorder.FLYWAY_CONTAINERS) {
i.getFlyway().clean();
i.getFlyway().migrate();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
io.quarkus.flyway.runtime.FlywaySchemaProvider
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
import io.quarkus.hibernate.orm.runtime.dialect.QuarkusPostgreSQL10Dialect;
import io.quarkus.hibernate.orm.runtime.integration.HibernateOrmIntegrationStaticDescriptor;
import io.quarkus.hibernate.orm.runtime.proxies.PreGeneratedProxies;
import io.quarkus.hibernate.orm.runtime.schema.SchemaManagementIntegrator;
import io.quarkus.hibernate.orm.runtime.tenant.DataSourceTenantConnectionResolver;
import io.quarkus.hibernate.orm.runtime.tenant.TenantConnectionResolver;
import io.quarkus.hibernate.orm.runtime.tenant.TenantResolver;
Expand Down Expand Up @@ -390,6 +391,7 @@ public void build(RecorderContext recorderContext, HibernateOrmRecorder recorder
List<HibernateOrmIntegrationStaticConfiguredBuildItem> integrationBuildItems,
ProxyDefinitionsBuildItem proxyDefinitions,
BuildProducer<FeatureBuildItem> feature,
LaunchModeBuildItem launchModeBuildItem,
BuildProducer<BeanContainerListenerBuildItem> beanContainerListener) throws Exception {

feature.produce(new FeatureBuildItem(Feature.HIBERNATE_ORM));
Expand Down Expand Up @@ -420,6 +422,9 @@ public void build(RecorderContext recorderContext, HibernateOrmRecorder recorder
for (String integratorClassName : ServiceUtil.classNamesNamedIn(classLoader, INTEGRATOR_SERVICE_FILE)) {
integratorClasses.add((Class<? extends Integrator>) recorderContext.classProxy(integratorClassName));
}
if (launchModeBuildItem.getLaunchMode().isDevOrTest()) {
integratorClasses.add(SchemaManagementIntegrator.class);
}

Map<String, List<HibernateOrmIntegrationStaticDescriptor>> integrationStaticDescriptors = HibernateOrmIntegrationStaticConfiguredBuildItem
.collectDescriptors(integrationBuildItems);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.quarkus.hibernate.orm;

import static org.hamcrest.Matchers.is;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusDevModeTest;
import io.restassured.RestAssured;

public class HibernateSchemaRecreateDevConsoleTestCase {
@RegisterExtension
final static QuarkusDevModeTest TEST = new QuarkusDevModeTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyEntity.class, MyEntityTestResource.class)
.addAsResource("application.properties")
.addAsResource("import.sql"));

@Test
public void testCleanDatabase() {
RestAssured.when().get("/my-entity/count").then().body(is("2"));
RestAssured.when().get("/my-entity/add").then().body(is("MyEntity:added"));
RestAssured.when().get("/my-entity/count").then().body(is("3"));
RestAssured.with()
.redirects().follow(false).formParam("name", "<default>").post("q/dev/io.quarkus.quarkus-agroal/clean")
.then()
.statusCode(303);
RestAssured.when().get("/my-entity/count").then().body(is("2"));

}

private void assertBodyIs(String expectedBody) {
RestAssured.when().get("/my-entity/2").then().body(is(expectedBody));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import javax.inject.Inject;
import javax.persistence.EntityManager;
import javax.transaction.Transactional;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
Expand All @@ -25,4 +26,22 @@ public String getName(@PathParam("id") long id) {

return "no entity";
}

@GET
@Path("/add")
@Produces(MediaType.TEXT_PLAIN)
@Transactional
public String add() {
MyEntity entity = new MyEntity();
entity.setName("added");
em.persist(entity);
return entity.toString();
}

@GET
@Path("/count")
@Produces(MediaType.TEXT_PLAIN)
public int count() {
return em.createQuery("from MyEntity").getResultList().size();
}
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
INSERT INTO MyEntity(id, name) VALUES(1, 'default sql load script entity');
INSERT INTO MyEntity(id, name) VALUES(2, 'import.sql load script entity');
alter sequence myEntitySeq restart with 3;
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.quarkus.hibernate.orm.runtime.boot.QuarkusPersistenceUnitDefinition;
import io.quarkus.hibernate.orm.runtime.integration.HibernateOrmIntegrationRuntimeDescriptor;
import io.quarkus.hibernate.orm.runtime.proxies.PreGeneratedProxies;
import io.quarkus.hibernate.orm.runtime.schema.SchemaManagementIntegrator;
import io.quarkus.hibernate.orm.runtime.session.ForwardingSession;
import io.quarkus.hibernate.orm.runtime.tenant.DataSourceTenantConnectionResolver;
import io.quarkus.runtime.annotations.Recorder;
Expand Down Expand Up @@ -54,6 +55,10 @@ public void setupPersistenceProvider(HibernateOrmRuntimeConfig hibernateOrmRunti
public BeanContainerListener initMetadata(List<QuarkusPersistenceUnitDefinition> parsedPersistenceXmlDescriptors,
Scanner scanner, Collection<Class<? extends Integrator>> additionalIntegrators,
PreGeneratedProxies proxyDefinitions) {
SchemaManagementIntegrator.clearDsMap();
for (QuarkusPersistenceUnitDefinition i : parsedPersistenceXmlDescriptors) {
SchemaManagementIntegrator.mapDatasource(i.getDataSource(), i.getName());
}
return new BeanContainerListener() {
@Override
public void created(BeanContainer beanContainer) {
Expand Down Expand Up @@ -118,5 +123,4 @@ protected Session delegate() {
}
};
}

}
Loading

0 comments on commit e16231a

Please sign in to comment.