Skip to content

Commit

Permalink
fix entity-manager retrieval in spring-data-repos
Browse files Browse the repository at this point in the history
when using multiple persistence-units:
- save a detached entity (merge)
- getOne
- paginated queries
- deleteAll

closes 38319
  • Loading branch information
fladdimir committed Jan 22, 2024
1 parent 18b26d5 commit e7bd343
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class AdditionalJpaOperations {
public static PanacheQuery<?> find(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass, String query,
String countQuery, Sort sort, Map<String, Object> params) {
String findQuery = createFindQuery(entityClass, query, jpaOperations.paramCount(params));
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
Query jpaQuery = em.createQuery(sort != null ? findQuery + toOrderBy(sort) : findQuery);
JpaOperations.bindParameters(jpaQuery, params);
return new CustomCountPanacheQuery(em, jpaQuery, countQuery, params);
Expand All @@ -47,14 +47,14 @@ public static PanacheQuery<?> find(AbstractJpaOperations<?> jpaOperations, Class
public static PanacheQuery<?> find(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass, String query,
String countQuery, Sort sort, Object... params) {
String findQuery = createFindQuery(entityClass, query, jpaOperations.paramCount(params));
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
Query jpaQuery = em.createQuery(sort != null ? findQuery + toOrderBy(sort) : findQuery);
JpaOperations.bindParameters(jpaQuery, params);
return new CustomCountPanacheQuery(em, jpaQuery, countQuery, params);
}

public static long deleteAllWithCascade(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
//detecting the case where there are cascade-delete associations, and do the bulk delete query otherwise.
if (deleteOnCascadeDetected(jpaOperations, entityClass)) {
int count = 0;
Expand All @@ -77,7 +77,7 @@ public static long deleteAllWithCascade(AbstractJpaOperations<?> jpaOperations,
* @return true if cascading delete is needed. False otherwise
*/
private static boolean deleteOnCascadeDetected(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
Metamodel metamodel = em.getMetamodel();
EntityType<?> entity1 = metamodel.entity(entityClass);
Set<Attribute<?, ?>> declaredAttributes = ((EntityTypeImpl) entity1).getDeclaredAttributes();
Expand All @@ -96,7 +96,7 @@ private static boolean deleteOnCascadeDetected(AbstractJpaOperations<?> jpaOpera

public static <PanacheQueryType> long deleteWithCascade(AbstractJpaOperations<PanacheQueryType> jpaOperations,
Class<?> entityClass, String query, Object... params) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
if (deleteOnCascadeDetected(jpaOperations, entityClass)) {
int count = 0;
List<?> objects = jpaOperations.list(jpaOperations.find(entityClass, query, params));
Expand All @@ -112,7 +112,7 @@ public static <PanacheQueryType> long deleteWithCascade(AbstractJpaOperations<Pa
public static <PanacheQueryType> long deleteWithCascade(AbstractJpaOperations<PanacheQueryType> jpaOperations,
Class<?> entityClass, String query,
Map<String, Object> params) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
if (deleteOnCascadeDetected(jpaOperations, entityClass)) {
int count = 0;
List<?> objects = jpaOperations.list(jpaOperations.find(entityClass, query, params));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr
// and if so generate the implementation while also keeping the proper records

generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
allMethodsToBeImplementedToResult, entityClassFieldDescriptor);
generateSaveAndFlush(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
allMethodsToBeImplementedToResult, entityClassFieldDescriptor);
generateSaveAll(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
generateFlush(classCreator, generatedClassName, allMethodsToBeImplementedToResult);
Expand Down Expand Up @@ -121,7 +121,8 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr

private void generateSave(ClassCreator classCreator, String generatedClassName,
DotName entityDotName, String entityTypeStr,
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult) {
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult,
FieldDescriptor entityClassFieldDescriptor) {

MethodDescriptor saveDescriptor = MethodDescriptor.ofMethod(generatedClassName, "save", entityTypeStr,
entityTypeStr);
Expand All @@ -144,7 +145,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName,
entity);
BranchResult isNewBranch = save.ifTrue(isNew);
generatePersistAndReturn(entity, isNewBranch.trueBranch());
generateMergeAndReturn(entity, isNewBranch.falseBranch());
generateMergeAndReturn(entity, isNewBranch.falseBranch(), entityClassFieldDescriptor);
} else {
AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index);
ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget);
Expand All @@ -167,7 +168,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName,
versionValueTarget.get());
BranchResult versionValueIsNullBranch = save.ifNull(versionValue);
generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch());
generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch());
generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch(), entityClassFieldDescriptor);
}

BytecodeCreator idValueUnset;
Expand All @@ -192,7 +193,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName,
idValueUnset = idValueNullBranch.trueBranch();
}
generatePersistAndReturn(entity, idValueUnset);
generateMergeAndReturn(entity, idValueSet);
generateMergeAndReturn(entity, idValueSet, entityClassFieldDescriptor);
}
}
try (MethodCreator bridgeSave = classCreator.getMethodCreator(bridgeSaveDescriptor)) {
Expand Down Expand Up @@ -236,10 +237,13 @@ private void generatePersistAndReturn(ResultHandle entity, BytecodeCreator bytec
bytecodeCreator.returnValue(entity);
}

private void generateMergeAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator) {
private void generateMergeAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator,
FieldDescriptor entityClassFieldDescriptor) {
ResultHandle entityClass = bytecodeCreator.readInstanceField(entityClassFieldDescriptor, bytecodeCreator.getThis());
ResultHandle entityManager = bytecodeCreator.invokeVirtualMethod(
ofMethod(AbstractJpaOperations.class, "getEntityManager", EntityManager.class),
bytecodeCreator.readStaticField(operationsField));
ofMethod(AbstractJpaOperations.class, "getEntityManager", EntityManager.class, Class.class),
bytecodeCreator.readStaticField(operationsField),
entityClass);
entity = bytecodeCreator.invokeInterfaceMethod(
MethodDescriptor.ofMethod(EntityManager.class, "merge", Object.class, Object.class),
entityManager, entity);
Expand Down Expand Up @@ -280,7 +284,7 @@ private Type getTypeOfTarget(AnnotationTarget idAnnotationTarget) {

private void generateSaveAndFlush(ClassCreator classCreator,
String generatedClassName, DotName entityDotName, String entityTypeStr,
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult) {
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult, FieldDescriptor entityClassFieldDescriptor) {

MethodDescriptor saveAndFlushDescriptor = MethodDescriptor.ofMethod(generatedClassName, "saveAndFlush", entityTypeStr,
entityTypeStr);
Expand All @@ -298,7 +302,7 @@ private void generateSaveAndFlush(ClassCreator classCreator,
// we need to force the generation of findById since this method depends on it
allMethodsToBeImplementedToResult.put(save, false);
generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
allMethodsToBeImplementedToResult, entityClassFieldDescriptor);

try (MethodCreator saveAndFlush = classCreator.getMethodCreator(saveAndFlushDescriptor)) {
saveAndFlush.addAnnotation(Transactional.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
package io.quarkus.spring.data.deployment.multiple_pu;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.function.Supplier;

import jakarta.inject.Inject;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;

import io.quarkus.narayana.jta.QuarkusTransaction;
import io.quarkus.spring.data.deployment.multiple_pu.first.FirstEntity;
import io.quarkus.spring.data.deployment.multiple_pu.first.FirstEntityRepository;
import io.quarkus.spring.data.deployment.multiple_pu.second.SecondEntity;
Expand All @@ -21,6 +32,17 @@ public class MultiplePersistenceUnitConfigTest {
PanacheTestResource.class)
.addAsResource("application-multiple-persistence-units.properties", "application.properties"));

@Inject
private FirstEntityRepository repository1;
@Inject
private SecondEntityRepository repository2;

@BeforeEach
void beforeEach() {
repository1.deleteAll();
repository2.deleteAll();
}

@Test
public void panacheOperations() {
/**
Expand All @@ -35,4 +57,64 @@ public void panacheOperations() {
RestAssured.when().get("/persistence-unit/second/name-1").then().body(Matchers.is("1"));
RestAssured.when().get("/persistence-unit/second/name-2").then().body(Matchers.is("2"));
}

@Test
public void entityLifecycle() {
var detached = repository2.save(new SecondEntity());
assertThat(detached.id).isNotNull();
assertThat(inTx(repository2::count)).isEqualTo(1);

detached.name = "name";
repository2.save(detached);
assertThat(inTx(repository2::count)).isEqualTo(1);

inTx(() -> {
var lazyRef = repository2.getOne(detached.id);
assertThat(lazyRef.name).isEqualTo(detached.name);
return null;
});

repository2.deleteByName("otherThan" + detached.name);
assertThat(inTx(() -> repository2.findById(detached.id))).isPresent();

repository2.deleteByName(detached.name);
assertThat(inTx(() -> repository2.findById(detached.id))).isEmpty();
}

@Test
void pagedQueries() {
var newEntity = new SecondEntity();
newEntity.name = "name";
var detached = repository2.save(newEntity);

Pageable pageable = PageRequest.of(0, 10, Sort.Direction.DESC, "id");

var page = inTx(() -> repository2.findByName(detached.name, pageable));
assertThat(page.getContent()).extracting(e -> e.id).containsExactly(detached.id);

var pageIndexParam = inTx(() -> repository2.findByNameQueryIndexed(detached.name, pageable));
assertThat(pageIndexParam.getContent()).extracting(e -> e.id).containsExactly(detached.id);

var pageNamedParam = inTx(() -> repository2.findByNameQueryNamed(detached.name, pageable));
assertThat(pageNamedParam.getContent()).extracting(e -> e.id).containsExactly(detached.id);
}

@Test
void cascading() {
var newParent = new SecondEntity();
newParent.name = "parent";
var newChild = new SecondEntity();
newChild.name = "child";
newParent.child = newChild;
var detachedParent = repository2.save(newParent);

assertThat(inTx(repository2::count)).isEqualTo(2);

repository2.deleteByName(detachedParent.name);
assertThat(inTx(repository2::count)).isZero();
}

private <T> T inTx(Supplier<T> action) {
return QuarkusTransaction.requiringNew().call(action::get);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package io.quarkus.spring.data.deployment.multiple_pu.second;

import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;
import jakarta.persistence.*;

@Entity
public class SecondEntity {
Expand All @@ -12,4 +10,7 @@ public class SecondEntity {
public Long id;

public String name;

@OneToOne(cascade = CascadeType.ALL)
public SecondEntity child;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package io.quarkus.spring.data.deployment.multiple_pu.second;

import java.util.Optional;

import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

@Repository
Expand All @@ -8,4 +14,21 @@ public interface SecondEntityRepository extends org.springframework.data.reposit
SecondEntity save(SecondEntity entity);

long count();

Optional<SecondEntity> findById(Long id);

SecondEntity getOne(Long id);

void deleteAll();

void deleteByName(String name);

Page<SecondEntity> findByName(String name, Pageable pageable);

@Query(value = "SELECT se FROM SecondEntity se WHERE name=?1", countQuery = "SELECT COUNT(*) FROM SecondEntity se WHERE name=?1")
Page<SecondEntity> findByNameQueryIndexed(String name, Pageable pageable);

@Query(value = "SELECT se FROM SecondEntity se WHERE name=:name", countQuery = "SELECT COUNT(*) FROM SecondEntity se WHERE name=:name")
Page<SecondEntity> findByNameQueryNamed(@Param("name") String name, Pageable pageable);

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static void deleteAll(AbstractJpaOperations<PanacheQuery<?>> operations,
}

public static Object getOne(AbstractJpaOperations<PanacheQuery<?>> operations, Class<?> entityClass, Object id) {
return operations.getEntityManager().getReference(entityClass, id);
return operations.getEntityManager(entityClass).getReference(entityClass, id);
}

public static void clear(Class<?> clazz) {
Expand Down

0 comments on commit e7bd343

Please sign in to comment.