Skip to content

Commit

Permalink
Support Spring Data JPA's isPersistable
Browse files Browse the repository at this point in the history
Fixes: #10231
  • Loading branch information
geoand committed Jul 14, 2020
1 parent 5ffccca commit 6059d01
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 49 deletions.
3 changes: 2 additions & 1 deletion docs/src/main/asciidoc/spring-data-jpa.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,9 @@ is not used at all (since all the necessary plumbing is done at build time). Sim
* Using `java.util.concurrent.Future` and classes that extend it as return types of repository methods.
* Native and named queries when using `@Query`
* https://github.com/spring-projects/spring-data-jpa/blob/master/src/main/asciidoc/jpa.adoc#entity-state-detection-strategies[Entity State-detection Strategies]
via `Persistable.isNew(...)` or `EntityInformation`.
via `EntityInformation`.
** As of Quarkus 1.6.0, only "Version-Property and Id-Property inspection" is implemented (which should cover most cases).
** As of Quarkus 1.7.0, `org.springframework.data.domain.Persistable` is also implemented.

The Quarkus team is exploring various alternatives to bridging the gap between the JPA and Reactive worlds.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Persistable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.JpaRepository;
Expand Down Expand Up @@ -63,6 +64,8 @@ public final class DotNames {
.createSimple(Param.class.getName());
public static final DotName SPRING_DATA_MODIFYING = DotName
.createSimple(Modifying.class.getName());
public static final DotName SPRING_DATA_PERSISTABLE = DotName
.createSimple(Persistable.class.getName());

public static final DotName JPA_ID = DotName.createSimple(Id.class.getName());
public static final DotName VERSION = DotName.createSimple(Version.class.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr
// for all Spring Data repository methods we know how to implement, check if the generated class actually needs the method
// and if so generate the implementation while also keeping the proper records

generateSave(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr,
generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
generateSaveAndFlush(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr,
generateSaveAndFlush(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
generateSaveAll(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
Expand Down Expand Up @@ -108,7 +108,7 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr
handleUnimplementedMethods(classCreator, allMethodsToBeImplementedToResult);
}

private void generateSave(ClassCreator classCreator, FieldDescriptor entityClassFieldDescriptor, String generatedClassName,
private void generateSave(ClassCreator classCreator, String generatedClassName,
DotName entityDotName, String entityTypeStr,
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult) {

Expand All @@ -125,52 +125,64 @@ private void generateSave(ClassCreator classCreator, FieldDescriptor entityClass
save.addAnnotation(Transactional.class);

ResultHandle entity = save.getMethodParam(0);
AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index);
ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget);
Type idType = getTypeOfTarget(idAnnotationTarget);
Optional<AnnotationTarget> versionValueTarget = getVersionAnnotationTarget(entityDotName, index);

// the following code generated bytecode that:
// if there is a field annotated with @Version, calls 'persist' if the field is null, 'merge' otherwise
// if there is no field annotated with @Version, then if the value of the field annotated with '@Id'
// is "falsy", 'persist' is called, otherwise 'merge' is called

if (versionValueTarget.isPresent()) {
Type versionType = getTypeOfTarget(versionValueTarget.get());
if (versionType instanceof PrimitiveType) {
throw new IllegalArgumentException(
"The '@Version' annotation cannot be used on primitive types. Offending entity is '"
+ entityDotName + "'.");
}
ResultHandle versionValue = generateObtainValue(save, entityDotName, entity, versionValueTarget.get());
BranchResult versionValueIsNullBranch = save.ifNull(versionValue);
generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch());
generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch());
}

BytecodeCreator idValueUnset;
BytecodeCreator idValueSet;
if (idType instanceof PrimitiveType) {
if (!idType.name().equals(DotNames.PRIMITIVE_LONG)
&& !idType.name().equals(DotNames.PRIMITIVE_INTEGER)) {
throw new IllegalArgumentException("Id type of '" + entityDotName + "' is invalid.");
// if an entity is Persistable, then all we need to do is call isNew to determine if it's new or not
if (isPersistable(entityDotName)) {
ResultHandle isNew = save.invokeVirtualMethod(
ofMethod(entityDotName.toString(), "isNew", boolean.class.toString()),
entity);
BranchResult isNewBranch = save.ifTrue(isNew);
generatePersistAndReturn(entity, isNewBranch.trueBranch());
generateMergeAndReturn(entity, isNewBranch.falseBranch());
} else {
AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index);
ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget);
Type idType = getTypeOfTarget(idAnnotationTarget);
Optional<AnnotationTarget> versionValueTarget = getVersionAnnotationTarget(entityDotName, index);

// the following code generated bytecode that:
// if there is a field annotated with @Version, calls 'persist' if the field is null, 'merge' otherwise
// if there is no field annotated with @Version, then if the value of the field annotated with '@Id'
// is "falsy", 'persist' is called, otherwise 'merge' is called

if (versionValueTarget.isPresent()) {
Type versionType = getTypeOfTarget(versionValueTarget.get());
if (versionType instanceof PrimitiveType) {
throw new IllegalArgumentException(
"The '@Version' annotation cannot be used on primitive types. Offending entity is '"
+ entityDotName + "'.");
}
ResultHandle versionValue = generateObtainValue(save, entityDotName, entity,
versionValueTarget.get());
BranchResult versionValueIsNullBranch = save.ifNull(versionValue);
generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch());
generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch());
}
if (idType.name().equals(DotNames.PRIMITIVE_LONG)) {
ResultHandle longObject = save.invokeStaticMethod(
MethodDescriptor.ofMethod(Long.class, "valueOf", Long.class, long.class), idValue);
idValue = save.invokeVirtualMethod(MethodDescriptor.ofMethod(Long.class, "intValue", int.class),
longObject);

BytecodeCreator idValueUnset;
BytecodeCreator idValueSet;
if (idType instanceof PrimitiveType) {
if (!idType.name().equals(DotNames.PRIMITIVE_LONG)
&& !idType.name().equals(DotNames.PRIMITIVE_INTEGER)) {
throw new IllegalArgumentException("Id type of '" + entityDotName + "' is invalid.");
}
if (idType.name().equals(DotNames.PRIMITIVE_LONG)) {
ResultHandle longObject = save.invokeStaticMethod(
MethodDescriptor.ofMethod(Long.class, "valueOf", Long.class, long.class), idValue);
idValue = save.invokeVirtualMethod(MethodDescriptor.ofMethod(Long.class, "intValue", int.class),
longObject);
}
BranchResult idValueNonZeroBranch = save.ifNonZero(idValue);
idValueSet = idValueNonZeroBranch.trueBranch();
idValueUnset = idValueNonZeroBranch.falseBranch();
} else {
BranchResult idValueNullBranch = save.ifNull(idValue);
idValueSet = idValueNullBranch.falseBranch();
idValueUnset = idValueNullBranch.trueBranch();
}
BranchResult idValueNonZeroBranch = save.ifNonZero(idValue);
idValueSet = idValueNonZeroBranch.trueBranch();
idValueUnset = idValueNonZeroBranch.falseBranch();
} else {
BranchResult idValueNullBranch = save.ifNull(idValue);
idValueSet = idValueNullBranch.falseBranch();
idValueUnset = idValueNullBranch.trueBranch();
generatePersistAndReturn(entity, idValueUnset);
generateMergeAndReturn(entity, idValueSet);
}
generatePersistAndReturn(entity, idValueUnset);
generateMergeAndReturn(entity, idValueSet);
}
try (MethodCreator bridgeSave = classCreator.getMethodCreator(bridgeSaveDescriptor)) {
MethodDescriptor save = MethodDescriptor.ofMethod(generatedClassName, "save", entityTypeStr,
Expand All @@ -187,6 +199,24 @@ private void generateSave(ClassCreator classCreator, FieldDescriptor entityClass
}
}

private boolean isPersistable(DotName entityDotName) {
ClassInfo classInfo = index.getClassByName(entityDotName);
if (classInfo == null) {
throw new IllegalStateException("Entity " + entityDotName + " was not part of the Quarkus index");
}

if (classInfo.interfaceNames().contains(DotNames.SPRING_DATA_PERSISTABLE)) {
return true;
}

DotName superDotName = classInfo.superName();
if (superDotName.equals(DotNames.OBJECT)) {
return false;
}

return isPersistable(superDotName);
}

private void generatePersistAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator) {
bytecodeCreator.invokeStaticMethod(
MethodDescriptor.ofMethod(JpaOperations.class, "persist", void.class, Object.class),
Expand Down Expand Up @@ -235,7 +265,7 @@ private Type getTypeOfTarget(AnnotationTarget idAnnotationTarget) {
return idAnnotationTarget.asMethod().returnType();
}

private void generateSaveAndFlush(ClassCreator classCreator, FieldDescriptor entityClassFieldDescriptor,
private void generateSaveAndFlush(ClassCreator classCreator,
String generatedClassName, DotName entityDotName, String entityTypeStr,
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult) {

Expand All @@ -254,7 +284,7 @@ private void generateSaveAndFlush(ClassCreator classCreator, FieldDescriptor ent

// we need to force the generation of findById since this method depends on it
allMethodsToBeImplementedToResult.put(save, false);
generateSave(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr,
generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);

try (MethodCreator saveAndFlush = classCreator.getMethodCreator(saveAndFlushDescriptor)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import javax.persistence.OneToMany;
import javax.validation.constraints.Email;

import org.springframework.data.domain.Persistable;

@Entity
public class Customer extends AbstractEntity {
public class Customer extends AbstractEntity implements Persistable<Long> {

@Column(name = "first_name")
private String firstName;
Expand Down Expand Up @@ -45,6 +47,11 @@ public Customer(String firstName, String lastName, @Email String email,
this.enabled = enabled;
}

@Override
public boolean isNew() {
return id == null;
}

public String getFirstName() {
return firstName;
}
Expand Down

0 comments on commit 6059d01

Please sign in to comment.