From e7bd3432c0b057cafdec35962822ccdc8b2f90ab Mon Sep 17 00:00:00 2001 From: Wladimir Hofmann Date: Mon, 22 Jan 2024 09:37:17 +0100 Subject: [PATCH] fix entity-manager retrieval in spring-data-repos when using multiple persistence-units: - save a detached entity (merge) - getOne - paginated queries - deleteAll closes 38319 --- .../runtime/AdditionalJpaOperations.java | 12 +-- .../generate/StockMethodsAdder.java | 26 +++--- .../MultiplePersistenceUnitConfigTest.java | 82 +++++++++++++++++++ .../multiple_pu/second/SecondEntity.java | 7 +- .../second/SecondEntityRepository.java | 23 ++++++ .../data/runtime/RepositorySupport.java | 2 +- 6 files changed, 131 insertions(+), 21 deletions(-) diff --git a/extensions/panache/hibernate-orm-panache/runtime/src/main/java/io/quarkus/hibernate/orm/panache/runtime/AdditionalJpaOperations.java b/extensions/panache/hibernate-orm-panache/runtime/src/main/java/io/quarkus/hibernate/orm/panache/runtime/AdditionalJpaOperations.java index f136bf2ec6cf4..e0c9a7c729f1a 100644 --- a/extensions/panache/hibernate-orm-panache/runtime/src/main/java/io/quarkus/hibernate/orm/panache/runtime/AdditionalJpaOperations.java +++ b/extensions/panache/hibernate-orm-panache/runtime/src/main/java/io/quarkus/hibernate/orm/panache/runtime/AdditionalJpaOperations.java @@ -32,7 +32,7 @@ public class AdditionalJpaOperations { public static PanacheQuery find(AbstractJpaOperations jpaOperations, Class entityClass, String query, String countQuery, Sort sort, Map params) { String findQuery = createFindQuery(entityClass, query, jpaOperations.paramCount(params)); - EntityManager em = jpaOperations.getEntityManager(); + EntityManager em = jpaOperations.getEntityManager(entityClass); Query jpaQuery = em.createQuery(sort != null ? findQuery + toOrderBy(sort) : findQuery); JpaOperations.bindParameters(jpaQuery, params); return new CustomCountPanacheQuery(em, jpaQuery, countQuery, params); @@ -47,14 +47,14 @@ public static PanacheQuery find(AbstractJpaOperations jpaOperations, Class public static PanacheQuery find(AbstractJpaOperations jpaOperations, Class entityClass, String query, String countQuery, Sort sort, Object... params) { String findQuery = createFindQuery(entityClass, query, jpaOperations.paramCount(params)); - EntityManager em = jpaOperations.getEntityManager(); + EntityManager em = jpaOperations.getEntityManager(entityClass); Query jpaQuery = em.createQuery(sort != null ? findQuery + toOrderBy(sort) : findQuery); JpaOperations.bindParameters(jpaQuery, params); return new CustomCountPanacheQuery(em, jpaQuery, countQuery, params); } public static long deleteAllWithCascade(AbstractJpaOperations jpaOperations, Class entityClass) { - EntityManager em = jpaOperations.getEntityManager(); + EntityManager em = jpaOperations.getEntityManager(entityClass); //detecting the case where there are cascade-delete associations, and do the bulk delete query otherwise. if (deleteOnCascadeDetected(jpaOperations, entityClass)) { int count = 0; @@ -77,7 +77,7 @@ public static long deleteAllWithCascade(AbstractJpaOperations jpaOperations, * @return true if cascading delete is needed. False otherwise */ private static boolean deleteOnCascadeDetected(AbstractJpaOperations jpaOperations, Class entityClass) { - EntityManager em = jpaOperations.getEntityManager(); + EntityManager em = jpaOperations.getEntityManager(entityClass); Metamodel metamodel = em.getMetamodel(); EntityType entity1 = metamodel.entity(entityClass); Set> declaredAttributes = ((EntityTypeImpl) entity1).getDeclaredAttributes(); @@ -96,7 +96,7 @@ private static boolean deleteOnCascadeDetected(AbstractJpaOperations jpaOpera public static long deleteWithCascade(AbstractJpaOperations jpaOperations, Class entityClass, String query, Object... params) { - EntityManager em = jpaOperations.getEntityManager(); + EntityManager em = jpaOperations.getEntityManager(entityClass); if (deleteOnCascadeDetected(jpaOperations, entityClass)) { int count = 0; List objects = jpaOperations.list(jpaOperations.find(entityClass, query, params)); @@ -112,7 +112,7 @@ public static long deleteWithCascade(AbstractJpaOperations long deleteWithCascade(AbstractJpaOperations jpaOperations, Class entityClass, String query, Map params) { - EntityManager em = jpaOperations.getEntityManager(); + EntityManager em = jpaOperations.getEntityManager(entityClass); if (deleteOnCascadeDetected(jpaOperations, entityClass)) { int count = 0; List objects = jpaOperations.list(jpaOperations.find(entityClass, query, params)); diff --git a/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java b/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java index c719ced1638ae..f8fa8fce729a0 100644 --- a/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java +++ b/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java @@ -87,9 +87,9 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr // and if so generate the implementation while also keeping the proper records generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr, - allMethodsToBeImplementedToResult); + allMethodsToBeImplementedToResult, entityClassFieldDescriptor); generateSaveAndFlush(classCreator, generatedClassName, entityDotName, entityTypeStr, - allMethodsToBeImplementedToResult); + allMethodsToBeImplementedToResult, entityClassFieldDescriptor); generateSaveAll(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr, allMethodsToBeImplementedToResult); generateFlush(classCreator, generatedClassName, allMethodsToBeImplementedToResult); @@ -121,7 +121,8 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr private void generateSave(ClassCreator classCreator, String generatedClassName, DotName entityDotName, String entityTypeStr, - Map allMethodsToBeImplementedToResult) { + Map allMethodsToBeImplementedToResult, + FieldDescriptor entityClassFieldDescriptor) { MethodDescriptor saveDescriptor = MethodDescriptor.ofMethod(generatedClassName, "save", entityTypeStr, entityTypeStr); @@ -144,7 +145,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName, entity); BranchResult isNewBranch = save.ifTrue(isNew); generatePersistAndReturn(entity, isNewBranch.trueBranch()); - generateMergeAndReturn(entity, isNewBranch.falseBranch()); + generateMergeAndReturn(entity, isNewBranch.falseBranch(), entityClassFieldDescriptor); } else { AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index); ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget); @@ -167,7 +168,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName, versionValueTarget.get()); BranchResult versionValueIsNullBranch = save.ifNull(versionValue); generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch()); - generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch()); + generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch(), entityClassFieldDescriptor); } BytecodeCreator idValueUnset; @@ -192,7 +193,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName, idValueUnset = idValueNullBranch.trueBranch(); } generatePersistAndReturn(entity, idValueUnset); - generateMergeAndReturn(entity, idValueSet); + generateMergeAndReturn(entity, idValueSet, entityClassFieldDescriptor); } } try (MethodCreator bridgeSave = classCreator.getMethodCreator(bridgeSaveDescriptor)) { @@ -236,10 +237,13 @@ private void generatePersistAndReturn(ResultHandle entity, BytecodeCreator bytec bytecodeCreator.returnValue(entity); } - private void generateMergeAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator) { + private void generateMergeAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator, + FieldDescriptor entityClassFieldDescriptor) { + ResultHandle entityClass = bytecodeCreator.readInstanceField(entityClassFieldDescriptor, bytecodeCreator.getThis()); ResultHandle entityManager = bytecodeCreator.invokeVirtualMethod( - ofMethod(AbstractJpaOperations.class, "getEntityManager", EntityManager.class), - bytecodeCreator.readStaticField(operationsField)); + ofMethod(AbstractJpaOperations.class, "getEntityManager", EntityManager.class, Class.class), + bytecodeCreator.readStaticField(operationsField), + entityClass); entity = bytecodeCreator.invokeInterfaceMethod( MethodDescriptor.ofMethod(EntityManager.class, "merge", Object.class, Object.class), entityManager, entity); @@ -280,7 +284,7 @@ private Type getTypeOfTarget(AnnotationTarget idAnnotationTarget) { private void generateSaveAndFlush(ClassCreator classCreator, String generatedClassName, DotName entityDotName, String entityTypeStr, - Map allMethodsToBeImplementedToResult) { + Map allMethodsToBeImplementedToResult, FieldDescriptor entityClassFieldDescriptor) { MethodDescriptor saveAndFlushDescriptor = MethodDescriptor.ofMethod(generatedClassName, "saveAndFlush", entityTypeStr, entityTypeStr); @@ -298,7 +302,7 @@ private void generateSaveAndFlush(ClassCreator classCreator, // we need to force the generation of findById since this method depends on it allMethodsToBeImplementedToResult.put(save, false); generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr, - allMethodsToBeImplementedToResult); + allMethodsToBeImplementedToResult, entityClassFieldDescriptor); try (MethodCreator saveAndFlush = classCreator.getMethodCreator(saveAndFlushDescriptor)) { saveAndFlush.addAnnotation(Transactional.class); diff --git a/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/MultiplePersistenceUnitConfigTest.java b/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/MultiplePersistenceUnitConfigTest.java index 03bdfddb810eb..5c9a54bd7a93c 100644 --- a/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/MultiplePersistenceUnitConfigTest.java +++ b/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/MultiplePersistenceUnitConfigTest.java @@ -1,9 +1,20 @@ package io.quarkus.spring.data.deployment.multiple_pu; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.function.Supplier; + +import jakarta.inject.Inject; + import org.hamcrest.Matchers; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import io.quarkus.narayana.jta.QuarkusTransaction; import io.quarkus.spring.data.deployment.multiple_pu.first.FirstEntity; import io.quarkus.spring.data.deployment.multiple_pu.first.FirstEntityRepository; import io.quarkus.spring.data.deployment.multiple_pu.second.SecondEntity; @@ -21,6 +32,17 @@ public class MultiplePersistenceUnitConfigTest { PanacheTestResource.class) .addAsResource("application-multiple-persistence-units.properties", "application.properties")); + @Inject + private FirstEntityRepository repository1; + @Inject + private SecondEntityRepository repository2; + + @BeforeEach + void beforeEach() { + repository1.deleteAll(); + repository2.deleteAll(); + } + @Test public void panacheOperations() { /** @@ -35,4 +57,64 @@ public void panacheOperations() { RestAssured.when().get("/persistence-unit/second/name-1").then().body(Matchers.is("1")); RestAssured.when().get("/persistence-unit/second/name-2").then().body(Matchers.is("2")); } + + @Test + public void entityLifecycle() { + var detached = repository2.save(new SecondEntity()); + assertThat(detached.id).isNotNull(); + assertThat(inTx(repository2::count)).isEqualTo(1); + + detached.name = "name"; + repository2.save(detached); + assertThat(inTx(repository2::count)).isEqualTo(1); + + inTx(() -> { + var lazyRef = repository2.getOne(detached.id); + assertThat(lazyRef.name).isEqualTo(detached.name); + return null; + }); + + repository2.deleteByName("otherThan" + detached.name); + assertThat(inTx(() -> repository2.findById(detached.id))).isPresent(); + + repository2.deleteByName(detached.name); + assertThat(inTx(() -> repository2.findById(detached.id))).isEmpty(); + } + + @Test + void pagedQueries() { + var newEntity = new SecondEntity(); + newEntity.name = "name"; + var detached = repository2.save(newEntity); + + Pageable pageable = PageRequest.of(0, 10, Sort.Direction.DESC, "id"); + + var page = inTx(() -> repository2.findByName(detached.name, pageable)); + assertThat(page.getContent()).extracting(e -> e.id).containsExactly(detached.id); + + var pageIndexParam = inTx(() -> repository2.findByNameQueryIndexed(detached.name, pageable)); + assertThat(pageIndexParam.getContent()).extracting(e -> e.id).containsExactly(detached.id); + + var pageNamedParam = inTx(() -> repository2.findByNameQueryNamed(detached.name, pageable)); + assertThat(pageNamedParam.getContent()).extracting(e -> e.id).containsExactly(detached.id); + } + + @Test + void cascading() { + var newParent = new SecondEntity(); + newParent.name = "parent"; + var newChild = new SecondEntity(); + newChild.name = "child"; + newParent.child = newChild; + var detachedParent = repository2.save(newParent); + + assertThat(inTx(repository2::count)).isEqualTo(2); + + repository2.deleteByName(detachedParent.name); + assertThat(inTx(repository2::count)).isZero(); + } + + private T inTx(Supplier action) { + return QuarkusTransaction.requiringNew().call(action::get); + } } diff --git a/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntity.java b/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntity.java index c0753077f8a0b..9bc0ce353c449 100644 --- a/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntity.java +++ b/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntity.java @@ -1,8 +1,6 @@ package io.quarkus.spring.data.deployment.multiple_pu.second; -import jakarta.persistence.Entity; -import jakarta.persistence.GeneratedValue; -import jakarta.persistence.Id; +import jakarta.persistence.*; @Entity public class SecondEntity { @@ -12,4 +10,7 @@ public class SecondEntity { public Long id; public String name; + + @OneToOne(cascade = CascadeType.ALL) + public SecondEntity child; } diff --git a/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntityRepository.java b/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntityRepository.java index d35fbaab8775f..f0ddfd7d4bfab 100644 --- a/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntityRepository.java +++ b/extensions/spring-data-jpa/deployment/src/test/java/io/quarkus/spring/data/deployment/multiple_pu/second/SecondEntityRepository.java @@ -1,5 +1,11 @@ package io.quarkus.spring.data.deployment.multiple_pu.second; +import java.util.Optional; + +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; import org.springframework.stereotype.Repository; @Repository @@ -8,4 +14,21 @@ public interface SecondEntityRepository extends org.springframework.data.reposit SecondEntity save(SecondEntity entity); long count(); + + Optional findById(Long id); + + SecondEntity getOne(Long id); + + void deleteAll(); + + void deleteByName(String name); + + Page findByName(String name, Pageable pageable); + + @Query(value = "SELECT se FROM SecondEntity se WHERE name=?1", countQuery = "SELECT COUNT(*) FROM SecondEntity se WHERE name=?1") + Page findByNameQueryIndexed(String name, Pageable pageable); + + @Query(value = "SELECT se FROM SecondEntity se WHERE name=:name", countQuery = "SELECT COUNT(*) FROM SecondEntity se WHERE name=:name") + Page findByNameQueryNamed(@Param("name") String name, Pageable pageable); + } diff --git a/extensions/spring-data-jpa/runtime/src/main/java/io/quarkus/spring/data/runtime/RepositorySupport.java b/extensions/spring-data-jpa/runtime/src/main/java/io/quarkus/spring/data/runtime/RepositorySupport.java index 53e91d906f023..2da94e173b573 100644 --- a/extensions/spring-data-jpa/runtime/src/main/java/io/quarkus/spring/data/runtime/RepositorySupport.java +++ b/extensions/spring-data-jpa/runtime/src/main/java/io/quarkus/spring/data/runtime/RepositorySupport.java @@ -42,7 +42,7 @@ public static void deleteAll(AbstractJpaOperations> operations, } public static Object getOne(AbstractJpaOperations> operations, Class entityClass, Object id) { - return operations.getEntityManager().getReference(entityClass, id); + return operations.getEntityManager(entityClass).getReference(entityClass, id); } public static void clear(Class clazz) {