Skip to content

Commit

Permalink
Merge pull request #13931 from glefloch/fix/projection-spring
Browse files Browse the repository at this point in the history
Handle single result custom type in custom query method
  • Loading branch information
geoand authored Dec 17, 2020
2 parents 5ef0548 + a2e155e commit 9cdae32
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ protected void generateFindQueryResultHandling(MethodCreator methodCreator, Resu
if (limit != null) {
// create a custom page object that will limit the results by the limit size
page = methodCreator.newInstance(MethodDescriptor.ofConstructor(Page.class, int.class), methodCreator.load(limit));

} else if (pageableParameterIndex != null) {
page = methodCreator.invokeStaticMethod(
MethodDescriptor.ofMethod(TypesConverter.class, "toPanachePage", Page.class, Pageable.class),
Expand All @@ -98,8 +97,10 @@ protected void generateFindQueryResultHandling(MethodCreator methodCreator, Resu
ResultHandle singleResult = tryBlock.invokeInterfaceMethod(
MethodDescriptor.ofMethod(PanacheQuery.class, panacheQueryMethodToUse, Object.class),
panacheQuery);

ResultHandle casted = tryBlock.checkCast(singleResult, entityClassInfo.name().toString());
tryBlock.returnValue(casted);

CatchBlockCreator catchBlock = tryBlock.addCatch(NoResultException.class);
catchBlock.returnValue(catchBlock.loadNull());

Expand All @@ -116,11 +117,24 @@ protected void generateFindQueryResultHandling(MethodCreator methodCreator, Resu
ResultHandle singleResult = tryBlock.invokeInterfaceMethod(
MethodDescriptor.ofMethod(PanacheQuery.class, panacheQueryMethodToUse, Object.class),
panacheQuery);
ResultHandle casted = tryBlock.checkCast(singleResult, entityClassInfo.name().toString());
ResultHandle optional = tryBlock.invokeStaticMethod(
MethodDescriptor.ofMethod(Optional.class, "of", Optional.class, Object.class),
casted);
tryBlock.returnValue(optional);

if (customResultType == null) {
ResultHandle casted = tryBlock.checkCast(singleResult, entityClassInfo.name().toString());
ResultHandle optional = tryBlock.invokeStaticMethod(
MethodDescriptor.ofMethod(Optional.class, "of", Optional.class, Object.class),
casted);
tryBlock.returnValue(optional);
} else {
ResultHandle customResult = tryBlock.invokeStaticMethod(
MethodDescriptor.ofMethod(customResultType.toString(), "convert_" + methodName,
customResultType.toString(),
Object[].class.getName()),
singleResult);
ResultHandle optional = tryBlock.invokeStaticMethod(
MethodDescriptor.ofMethod(Optional.class, "of", Optional.class, Object.class),
customResult);
tryBlock.returnValue(optional);
}
CatchBlockCreator catchBlock = tryBlock.addCatch(NoResultException.class);
ResultHandle emptyOptional = catchBlock.invokeStaticMethod(
MethodDescriptor.ofMethod(Optional.class, "empty", Optional.class));
Expand All @@ -134,13 +148,14 @@ protected void generateFindQueryResultHandling(MethodCreator methodCreator, Resu
MethodDescriptor.ofMethod(PanacheQuery.class, "list", List.class),
panacheQuery);
} else {

ResultHandle stream = methodCreator.invokeInterfaceMethod(
MethodDescriptor.ofMethod(PanacheQuery.class, "stream", Stream.class),
panacheQuery);

// Function to convert Object[] to the custom type (using the generated static convert method)
FunctionCreator function = methodCreator.createFunction(Function.class);
BytecodeCreator funcBytecode = function.getBytecode();
FunctionCreator customResultMappingFunction = methodCreator.createFunction(Function.class);
BytecodeCreator funcBytecode = customResultMappingFunction.getBytecode();
ResultHandle obj = funcBytecode.invokeStaticMethod(
MethodDescriptor.ofMethod(customResultType.toString(), "convert_" + methodName,
customResultType.toString(),
Expand All @@ -150,7 +165,7 @@ protected void generateFindQueryResultHandling(MethodCreator methodCreator, Resu

stream = methodCreator.invokeInterfaceMethod(
MethodDescriptor.ofMethod(Stream.class, "map", Stream.class, Function.class),
stream, function.getInstance());
stream, customResultMappingFunction.getInstance());

// Re-collect the stream into a list
ResultHandle collector = methodCreator.invokeStaticMethod(
Expand Down Expand Up @@ -213,12 +228,32 @@ protected void generateFindQueryResultHandling(MethodCreator methodCreator, Resu
}

methodCreator.returnValue(sliceResult);

} else if (isIntLongOrBoolean(returnType)) {
ResultHandle singleResult = methodCreator.invokeInterfaceMethod(
MethodDescriptor.ofMethod(PanacheQuery.class, "singleResult", Object.class),
panacheQuery);
methodCreator.returnValue(singleResult);
} else if (customResultType != null) {
// when limit is specified we don't want to fail when there are multiple results, we just want to return the first one
String panacheQueryMethodToUse = (limit != null) ? "firstResult" : "singleResult";

TryBlock tryBlock = methodCreator.tryBlock();
ResultHandle singleResult = tryBlock.invokeInterfaceMethod(
MethodDescriptor.ofMethod(PanacheQuery.class, panacheQueryMethodToUse, Object.class),
panacheQuery);

ResultHandle customResult = tryBlock.invokeStaticMethod(
MethodDescriptor.ofMethod(customResultType.toString(), "convert_" + methodName,
customResultType.toString(),
Object[].class.getName()),
singleResult);

tryBlock.returnValue(customResult);

CatchBlockCreator catchBlock = tryBlock.addCatch(NoResultException.class);
catchBlock.returnValue(catchBlock.loadNull());

tryBlock.returnValue(customResult);
} else {
throw new IllegalArgumentException(
"Return type of method " + methodName + " of Repository " + repositoryClassInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.Iterator;
import java.util.List;
import java.util.Optional;

import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
Expand Down Expand Up @@ -61,9 +62,21 @@ public interface MovieRepository extends CrudRepository<Movie, Long> {
@Query("SELECT DISTINCT m.rating FROM Movie m where m.rating != null")
List<String> findAllRatings();

@Query("SELECT title, rating from Movie where title = ?1")
Optional<MovieRating> findOptionalRatingByTitle(String title);

@Query("SELECT title, rating FROM Movie WHERE title = ?1")
MovieRating findRatingByTitle(String title);

interface MovieCountByRating {
String getRating();

Long getCount();
}

interface MovieRating {
String getTitle();

String getRating();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import javax.ws.rs.GET;
import javax.ws.rs.Path;
Expand Down Expand Up @@ -114,7 +115,6 @@ public void setRatingToNullForTitle(@PathParam("title") String title) {
@Produces("application/json")
public List<MovieRepository.MovieCountByRating> countByRating() {
List<MovieRepository.MovieCountByRating> list = movieRepository.countByRating();

// #6205 - Make sure elements in list have been properly cast to the target object type.
// If the type is wrong (Object array), this will throw a ClassNotFoundException
MovieRepository.MovieCountByRating first = list.get(0);
Expand All @@ -123,6 +123,24 @@ public List<MovieRepository.MovieCountByRating> countByRating() {
return list;
}

@GET
@Path("/rating/forTitle/{title}")
@Produces("application/json")
public MovieRepository.MovieRating titleRating(@PathParam("title") String title) {
MovieRepository.MovieRating result = movieRepository.findRatingByTitle(title);
Objects.requireNonNull(result);
return result;
}

@GET
@Path("/rating/opt/forTitle/{title}")
@Produces("application/json")
public Optional<MovieRepository.MovieRating> optionalTitleRating(@PathParam("title") String title) {
Optional result = movieRepository.findOptionalRatingByTitle(title);
System.out.println(result);
return result;
}

@GET
@Path("/ratings")
@Produces("application/json")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,23 @@ void testFindAllRatings() {
.body(containsString("PG-13"));
}

@Test
void testFindRatingByTitle() {
when().get("/movie/rating/forTitle/Interstellar").then()
.statusCode(200)
.body(containsString("Interstellar"))
.body(containsString("PG-13"))
.body(not(containsString("duration")));
}

@Test
void testFindOptionalRatingByTitle() {
when().get("/movie/rating/opt/forTitle/Aladdin").then()
.statusCode(200)
.body(containsString("Aladdin"))
.body(not(containsString("duration")));
}

@Test
void testNewMovie() {
long id = 999L;
Expand Down

0 comments on commit 9cdae32

Please sign in to comment.