diff --git a/docs/src/main/asciidoc/spring-data-jpa.adoc b/docs/src/main/asciidoc/spring-data-jpa.adoc index f4cdd5aa8c701..54d4aadfda32e 100644 --- a/docs/src/main/asciidoc/spring-data-jpa.adoc +++ b/docs/src/main/asciidoc/spring-data-jpa.adoc @@ -611,8 +611,9 @@ is not used at all (since all the necessary plumbing is done at build time). Sim * Using `java.util.concurrent.Future` and classes that extend it as return types of repository methods. * Native and named queries when using `@Query` * https://github.com/spring-projects/spring-data-jpa/blob/master/src/main/asciidoc/jpa.adoc#entity-state-detection-strategies[Entity State-detection Strategies] -via `Persistable.isNew(...)` or `EntityInformation`. +via `EntityInformation`. ** As of Quarkus 1.6.0, only "Version-Property and Id-Property inspection" is implemented (which should cover most cases). +** As of Quarkus 1.7.0, `org.springframework.data.domain.Persistable` is also implemented. The Quarkus team is exploring various alternatives to bridging the gap between the JPA and Reactive worlds. diff --git a/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/DotNames.java b/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/DotNames.java index 8983e40c90b71..81e2d327c32e0 100644 --- a/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/DotNames.java +++ b/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/DotNames.java @@ -17,6 +17,7 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Persistable; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.jpa.repository.JpaRepository; @@ -63,6 +64,8 @@ public final class DotNames { .createSimple(Param.class.getName()); public static final DotName SPRING_DATA_MODIFYING = DotName .createSimple(Modifying.class.getName()); + public static final DotName SPRING_DATA_PERSISTABLE = DotName + .createSimple(Persistable.class.getName()); public static final DotName JPA_ID = DotName.createSimple(Id.class.getName()); public static final DotName VERSION = DotName.createSimple(Version.class.getName()); diff --git a/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java b/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java index 8d3113f0d280e..c517473797d6f 100644 --- a/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java +++ b/extensions/spring-data-jpa/deployment/src/main/java/io/quarkus/spring/data/deployment/generate/StockMethodsAdder.java @@ -77,9 +77,9 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr // for all Spring Data repository methods we know how to implement, check if the generated class actually needs the method // and if so generate the implementation while also keeping the proper records - generateSave(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr, + generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr, allMethodsToBeImplementedToResult); - generateSaveAndFlush(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr, + generateSaveAndFlush(classCreator, generatedClassName, entityDotName, entityTypeStr, allMethodsToBeImplementedToResult); generateSaveAll(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr, allMethodsToBeImplementedToResult); @@ -108,7 +108,7 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr handleUnimplementedMethods(classCreator, allMethodsToBeImplementedToResult); } - private void generateSave(ClassCreator classCreator, FieldDescriptor entityClassFieldDescriptor, String generatedClassName, + private void generateSave(ClassCreator classCreator, String generatedClassName, DotName entityDotName, String entityTypeStr, Map allMethodsToBeImplementedToResult) { @@ -125,52 +125,64 @@ private void generateSave(ClassCreator classCreator, FieldDescriptor entityClass save.addAnnotation(Transactional.class); ResultHandle entity = save.getMethodParam(0); - AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index); - ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget); - Type idType = getTypeOfTarget(idAnnotationTarget); - Optional versionValueTarget = getVersionAnnotationTarget(entityDotName, index); - - // the following code generated bytecode that: - // if there is a field annotated with @Version, calls 'persist' if the field is null, 'merge' otherwise - // if there is no field annotated with @Version, then if the value of the field annotated with '@Id' - // is "falsy", 'persist' is called, otherwise 'merge' is called - - if (versionValueTarget.isPresent()) { - Type versionType = getTypeOfTarget(versionValueTarget.get()); - if (versionType instanceof PrimitiveType) { - throw new IllegalArgumentException( - "The '@Version' annotation cannot be used on primitive types. Offending entity is '" - + entityDotName + "'."); - } - ResultHandle versionValue = generateObtainValue(save, entityDotName, entity, versionValueTarget.get()); - BranchResult versionValueIsNullBranch = save.ifNull(versionValue); - generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch()); - generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch()); - } - BytecodeCreator idValueUnset; - BytecodeCreator idValueSet; - if (idType instanceof PrimitiveType) { - if (!idType.name().equals(DotNames.PRIMITIVE_LONG) - && !idType.name().equals(DotNames.PRIMITIVE_INTEGER)) { - throw new IllegalArgumentException("Id type of '" + entityDotName + "' is invalid."); + // if an entity is Persistable, then all we need to do is call isNew to determine if it's new or not + if (isPersistable(entityDotName)) { + ResultHandle isNew = save.invokeVirtualMethod( + ofMethod(entityDotName.toString(), "isNew", boolean.class.toString()), + entity); + BranchResult isNewBranch = save.ifTrue(isNew); + generatePersistAndReturn(entity, isNewBranch.trueBranch()); + generateMergeAndReturn(entity, isNewBranch.falseBranch()); + } else { + AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index); + ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget); + Type idType = getTypeOfTarget(idAnnotationTarget); + Optional versionValueTarget = getVersionAnnotationTarget(entityDotName, index); + + // the following code generated bytecode that: + // if there is a field annotated with @Version, calls 'persist' if the field is null, 'merge' otherwise + // if there is no field annotated with @Version, then if the value of the field annotated with '@Id' + // is "falsy", 'persist' is called, otherwise 'merge' is called + + if (versionValueTarget.isPresent()) { + Type versionType = getTypeOfTarget(versionValueTarget.get()); + if (versionType instanceof PrimitiveType) { + throw new IllegalArgumentException( + "The '@Version' annotation cannot be used on primitive types. Offending entity is '" + + entityDotName + "'."); + } + ResultHandle versionValue = generateObtainValue(save, entityDotName, entity, + versionValueTarget.get()); + BranchResult versionValueIsNullBranch = save.ifNull(versionValue); + generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch()); + generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch()); } - if (idType.name().equals(DotNames.PRIMITIVE_LONG)) { - ResultHandle longObject = save.invokeStaticMethod( - MethodDescriptor.ofMethod(Long.class, "valueOf", Long.class, long.class), idValue); - idValue = save.invokeVirtualMethod(MethodDescriptor.ofMethod(Long.class, "intValue", int.class), - longObject); + + BytecodeCreator idValueUnset; + BytecodeCreator idValueSet; + if (idType instanceof PrimitiveType) { + if (!idType.name().equals(DotNames.PRIMITIVE_LONG) + && !idType.name().equals(DotNames.PRIMITIVE_INTEGER)) { + throw new IllegalArgumentException("Id type of '" + entityDotName + "' is invalid."); + } + if (idType.name().equals(DotNames.PRIMITIVE_LONG)) { + ResultHandle longObject = save.invokeStaticMethod( + MethodDescriptor.ofMethod(Long.class, "valueOf", Long.class, long.class), idValue); + idValue = save.invokeVirtualMethod(MethodDescriptor.ofMethod(Long.class, "intValue", int.class), + longObject); + } + BranchResult idValueNonZeroBranch = save.ifNonZero(idValue); + idValueSet = idValueNonZeroBranch.trueBranch(); + idValueUnset = idValueNonZeroBranch.falseBranch(); + } else { + BranchResult idValueNullBranch = save.ifNull(idValue); + idValueSet = idValueNullBranch.falseBranch(); + idValueUnset = idValueNullBranch.trueBranch(); } - BranchResult idValueNonZeroBranch = save.ifNonZero(idValue); - idValueSet = idValueNonZeroBranch.trueBranch(); - idValueUnset = idValueNonZeroBranch.falseBranch(); - } else { - BranchResult idValueNullBranch = save.ifNull(idValue); - idValueSet = idValueNullBranch.falseBranch(); - idValueUnset = idValueNullBranch.trueBranch(); + generatePersistAndReturn(entity, idValueUnset); + generateMergeAndReturn(entity, idValueSet); } - generatePersistAndReturn(entity, idValueUnset); - generateMergeAndReturn(entity, idValueSet); } try (MethodCreator bridgeSave = classCreator.getMethodCreator(bridgeSaveDescriptor)) { MethodDescriptor save = MethodDescriptor.ofMethod(generatedClassName, "save", entityTypeStr, @@ -187,6 +199,24 @@ private void generateSave(ClassCreator classCreator, FieldDescriptor entityClass } } + private boolean isPersistable(DotName entityDotName) { + ClassInfo classInfo = index.getClassByName(entityDotName); + if (classInfo == null) { + throw new IllegalStateException("Entity " + entityDotName + " was not part of the Quarkus index"); + } + + if (classInfo.interfaceNames().contains(DotNames.SPRING_DATA_PERSISTABLE)) { + return true; + } + + DotName superDotName = classInfo.superName(); + if (superDotName.equals(DotNames.OBJECT)) { + return false; + } + + return isPersistable(superDotName); + } + private void generatePersistAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator) { bytecodeCreator.invokeStaticMethod( MethodDescriptor.ofMethod(JpaOperations.class, "persist", void.class, Object.class), @@ -235,7 +265,7 @@ private Type getTypeOfTarget(AnnotationTarget idAnnotationTarget) { return idAnnotationTarget.asMethod().returnType(); } - private void generateSaveAndFlush(ClassCreator classCreator, FieldDescriptor entityClassFieldDescriptor, + private void generateSaveAndFlush(ClassCreator classCreator, String generatedClassName, DotName entityDotName, String entityTypeStr, Map allMethodsToBeImplementedToResult) { @@ -254,7 +284,7 @@ private void generateSaveAndFlush(ClassCreator classCreator, FieldDescriptor ent // we need to force the generation of findById since this method depends on it allMethodsToBeImplementedToResult.put(save, false); - generateSave(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr, + generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr, allMethodsToBeImplementedToResult); try (MethodCreator saveAndFlush = classCreator.getMethodCreator(saveAndFlushDescriptor)) { diff --git a/integration-tests/spring-data-jpa/src/main/java/io/quarkus/it/spring/data/jpa/Customer.java b/integration-tests/spring-data-jpa/src/main/java/io/quarkus/it/spring/data/jpa/Customer.java index 6bf0737c12df7..663b24b100bb3 100644 --- a/integration-tests/spring-data-jpa/src/main/java/io/quarkus/it/spring/data/jpa/Customer.java +++ b/integration-tests/spring-data-jpa/src/main/java/io/quarkus/it/spring/data/jpa/Customer.java @@ -9,8 +9,10 @@ import javax.persistence.OneToMany; import javax.validation.constraints.Email; +import org.springframework.data.domain.Persistable; + @Entity -public class Customer extends AbstractEntity { +public class Customer extends AbstractEntity implements Persistable { @Column(name = "first_name") private String firstName; @@ -45,6 +47,11 @@ public Customer(String firstName, String lastName, @Email String email, this.enabled = enabled; } + @Override + public boolean isNew() { + return id == null; + } + public String getFirstName() { return firstName; }