From 27e6a8b3e899c85bf3ed2af968c9d2b47087c673 Mon Sep 17 00:00:00 2001 From: Phillip Kruger Date: Mon, 1 Aug 2022 12:58:52 +1000 Subject: [PATCH] Make sure to deactivate the Managed context Signed-off-by: Phillip Kruger --- .../deployment/ConcurrentAuthTest.java | 189 ++++++++++++++++++ .../SmallRyeGraphQLAbstractHandler.java | 4 + .../datafetcher/AbstractAsyncDataFetcher.java | 6 + .../spi/datafetcher/BlockingHelper.java | 13 ++ .../QuarkusDefaultDataFetcher.java | 33 +-- 5 files changed, 232 insertions(+), 13 deletions(-) create mode 100644 extensions/smallrye-graphql/deployment/src/test/java/io/quarkus/smallrye/graphql/deployment/ConcurrentAuthTest.java diff --git a/extensions/smallrye-graphql/deployment/src/test/java/io/quarkus/smallrye/graphql/deployment/ConcurrentAuthTest.java b/extensions/smallrye-graphql/deployment/src/test/java/io/quarkus/smallrye/graphql/deployment/ConcurrentAuthTest.java new file mode 100644 index 00000000000000..9504bfcb83ce56 --- /dev/null +++ b/extensions/smallrye-graphql/deployment/src/test/java/io/quarkus/smallrye/graphql/deployment/ConcurrentAuthTest.java @@ -0,0 +1,189 @@ +package io.quarkus.smallrye.graphql.deployment; + +import static io.quarkus.smallrye.graphql.deployment.AbstractGraphQLTest.getPropertyAsString; +import static io.restassured.RestAssured.given; + +import java.time.LocalDate; +import java.time.Month; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import javax.enterprise.context.ApplicationScoped; +import javax.inject.Inject; + +import org.eclipse.microprofile.graphql.GraphQLApi; +import org.eclipse.microprofile.graphql.Query; +import org.jboss.shrinkwrap.api.asset.EmptyAsset; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.Authenticated; +import io.quarkus.test.QuarkusUnitTest; +import io.restassured.response.Response; + +public class ConcurrentAuthTest extends AbstractGraphQLTest { + + static Map PROPERTIES = new HashMap<>(); + static { + + PROPERTIES.put("quarkus.smallrye-graphql.error-extension-fields", "classification,code"); + PROPERTIES.put("quarkus.smallrye-graphql.show-runtime-exception-message", "java.lang.SecurityException"); + + PROPERTIES.put("quarkus.http.auth.basic", "true"); + PROPERTIES.put("quarkus.security.users.embedded.enabled", "true"); + PROPERTIES.put("quarkus.security.users.embedded.plain-text", "true"); + PROPERTIES.put("quarkus.security.users.embedded.users.scott", "jb0ss"); + } + + @RegisterExtension + static QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(FilmResource.class, Film.class, GalaxyService.class) + .addAsResource(new StringAsset(getPropertyAsString(PROPERTIES)), "application.properties") + .addAsManifestResource(EmptyAsset.INSTANCE, "beans.xml")); + + private int iterations = 5000; + + @Test + public void concurrentAllFilmsOnly() throws InterruptedException, ExecutionException { + ExecutorService executor = Executors.newFixedThreadPool(50); + + var futures = new ArrayList>(iterations); + for (int i = 0; i < iterations; i++) { + futures.add(CompletableFuture.supplyAsync(this::allFilmsRequestWithAuth, executor) + .thenApply(r -> !r.getBody().asString().contains("unauthorized"))); + } + Optional success = getTestResult(futures); + Assertions.assertTrue(success.orElse(false), "Unauthorized response codes were found"); + executor.shutdown(); + } + + private static Optional getTestResult(ArrayList> futures) + throws InterruptedException, ExecutionException { + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply(v -> futures.stream() + .map(CompletableFuture::join) + .reduce(Boolean::logicalAnd)) + .get(); + } + + private Response allFilmsRequestWithAuth() { + String requestBody = "{\"query\":" + + "\"" + + "{" + + " allFilmsSecured {" + + " title" + + " director" + + " releaseDate" + + " episodeID" + + "}" + + "}" + + "\"" + + "}"; + + return given() + .body(requestBody) + .auth() + .preemptive() + .basic("scott", "jb0ss") + .post("/graphql/"); + } + + @GraphQLApi + public static class FilmResource { + + @Inject + GalaxyService service; + + @Query("allFilmsSecured") + @Authenticated + public List getAllFilmsSecured() { + return service.getAllFilms(); + } + } + + public static class Film { + + private String title; + private Integer episodeID; + private String director; + private LocalDate releaseDate; + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public Integer getEpisodeID() { + return episodeID; + } + + public void setEpisodeID(Integer episodeID) { + this.episodeID = episodeID; + } + + public String getDirector() { + return director; + } + + public void setDirector(String director) { + this.director = director; + } + + public LocalDate getReleaseDate() { + return releaseDate; + } + + public void setReleaseDate(LocalDate releaseDate) { + this.releaseDate = releaseDate; + } + + } + + @ApplicationScoped + public static class GalaxyService { + + private List films = new ArrayList<>(); + + public GalaxyService() { + + Film aNewHope = new Film(); + aNewHope.setTitle("A New Hope"); + aNewHope.setReleaseDate(LocalDate.of(1977, Month.MAY, 25)); + aNewHope.setEpisodeID(4); + aNewHope.setDirector("George Lucas"); + + Film theEmpireStrikesBack = new Film(); + theEmpireStrikesBack.setTitle("The Empire Strikes Back"); + theEmpireStrikesBack.setReleaseDate(LocalDate.of(1980, Month.MAY, 21)); + theEmpireStrikesBack.setEpisodeID(5); + theEmpireStrikesBack.setDirector("George Lucas"); + + Film returnOfTheJedi = new Film(); + returnOfTheJedi.setTitle("Return Of The Jedi"); + returnOfTheJedi.setReleaseDate(LocalDate.of(1983, Month.MAY, 25)); + returnOfTheJedi.setEpisodeID(6); + returnOfTheJedi.setDirector("George Lucas"); + + films.add(aNewHope); + films.add(theEmpireStrikesBack); + films.add(returnOfTheJedi); + } + + public List getAllFilms() { + return films; + } + } +} diff --git a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/SmallRyeGraphQLAbstractHandler.java b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/SmallRyeGraphQLAbstractHandler.java index b6f44d441a3183..6eb63a7ccef692 100644 --- a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/SmallRyeGraphQLAbstractHandler.java +++ b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/SmallRyeGraphQLAbstractHandler.java @@ -12,6 +12,7 @@ import javax.json.JsonReaderFactory; import io.quarkus.arc.Arc; +import io.quarkus.arc.InjectableContext; import io.quarkus.arc.ManagedContext; import io.quarkus.security.identity.CurrentIdentityAssociation; import io.quarkus.security.identity.SecurityIdentity; @@ -67,6 +68,7 @@ public void handle(final RoutingContext ctx) { } try { handleWithIdentity(ctx); + currentManagedContext.deactivate(); } catch (Throwable t) { currentManagedContext.terminate(); throw t; @@ -108,6 +110,8 @@ protected Map getMetaData(RoutingContext ctx) { Map metaData = new ConcurrentHashMap<>(); metaData.put("runBlocking", runBlocking); metaData.put("httpHeaders", getHeaders(ctx)); + InjectableContext.ContextState state = currentManagedContext.getState(); + metaData.put("state", state); return metaData; } diff --git a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/AbstractAsyncDataFetcher.java b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/AbstractAsyncDataFetcher.java index 80d55240d4d2fb..cbb24778d934d2 100644 --- a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/AbstractAsyncDataFetcher.java +++ b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/AbstractAsyncDataFetcher.java @@ -7,6 +7,8 @@ import graphql.execution.DataFetcherResult; import graphql.schema.DataFetchingEnvironment; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; import io.smallrye.graphql.SmallRyeGraphQLServerMessages; import io.smallrye.graphql.execution.datafetcher.AbstractDataFetcher; import io.smallrye.graphql.schema.model.Operation; @@ -72,11 +74,15 @@ protected O invokeFailure(DataFetcherResult.Builder resultBuilder) { @Override @SuppressWarnings("unchecked") protected CompletionStage> invokeBatch(DataFetchingEnvironment dfe, Object[] arguments) { + ManagedContext requestContext = Arc.container().requestContext(); try { + BlockingHelper.reactivate(requestContext, dfe); return handleUserBatchLoad(dfe, arguments) .subscribe().asCompletionStage(); } catch (Exception ex) { throw new RuntimeException(ex); + } finally { + requestContext.deactivate(); } } diff --git a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/BlockingHelper.java b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/BlockingHelper.java index 95e5994128b1e2..765925de3c4a4e 100644 --- a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/BlockingHelper.java +++ b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/BlockingHelper.java @@ -2,6 +2,9 @@ import java.util.concurrent.Callable; +import graphql.schema.DataFetchingEnvironment; +import io.quarkus.arc.InjectableContext; +import io.quarkus.arc.ManagedContext; import io.smallrye.graphql.schema.model.Execute; import io.smallrye.graphql.schema.model.Operation; import io.vertx.core.Context; @@ -31,4 +34,14 @@ public static void runBlocking(Context vc, Callable contextualCallable, }, result); } + public static void reactivate(ManagedContext requestContext, DataFetchingEnvironment dfe) { + if (!requestContext.isActive()) { + Object maybeState = dfe.getGraphQlContext().getOrDefault("state", null); + if (maybeState != null) { + requestContext.activate((InjectableContext.ContextState) maybeState); + } else { + requestContext.activate(); + } + } + } } diff --git a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/QuarkusDefaultDataFetcher.java b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/QuarkusDefaultDataFetcher.java index 1907a29fb29787..bb7b463c235c4e 100644 --- a/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/QuarkusDefaultDataFetcher.java +++ b/extensions/smallrye-graphql/runtime/src/main/java/io/quarkus/smallrye/graphql/runtime/spi/datafetcher/QuarkusDefaultDataFetcher.java @@ -10,6 +10,7 @@ import graphql.execution.DataFetcherResult; import graphql.schema.DataFetchingEnvironment; import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; import io.smallrye.context.SmallRyeThreadContext; import io.smallrye.graphql.execution.datafetcher.DefaultDataFetcher; import io.smallrye.graphql.schema.model.Operation; @@ -44,7 +45,7 @@ public CompletionStage> invokeBatch(DataFetchingEnvironment dfe, Object[ if (runBlocking(dfe) || BlockingHelper.blockingShouldExecuteNonBlocking(operation, vc)) { return super.invokeBatch(dfe, arguments); } else { - return invokeBatchBlocking(arguments, vc); + return invokeBatchBlocking(dfe, arguments, vc); } } @@ -81,18 +82,24 @@ private T invokeAndTransformBlocking(final DataFetchingEnvironment dfe, Data } @SuppressWarnings("unchecked") - private CompletionStage> invokeBatchBlocking(Object[] arguments, Context vc) { - SmallRyeThreadContext threadContext = Arc.container().select(SmallRyeThreadContext.class).get(); - final Promise> result = Promise.promise(); - - // We need some make sure that we call given the context - Callable contextualCallable = threadContext.contextualCallable(() -> { - return (List) operationInvoker.invokePrivileged(arguments); - }); - - // Here call blocking with context - BlockingHelper.runBlocking(vc, contextualCallable, result); - return result.future().toCompletionStage(); + private CompletionStage> invokeBatchBlocking(DataFetchingEnvironment dfe, Object[] arguments, Context vc) { + ManagedContext requestContext = Arc.container().requestContext(); + try { + BlockingHelper.reactivate(requestContext, dfe); + SmallRyeThreadContext threadContext = Arc.container().select(SmallRyeThreadContext.class).get(); + final Promise> result = Promise.promise(); + + // We need some make sure that we call given the context + Callable contextualCallable = threadContext.contextualCallable(() -> { + return (List) operationInvoker.invokePrivileged(arguments); + }); + + // Here call blocking with context + BlockingHelper.runBlocking(vc, contextualCallable, result); + return result.future().toCompletionStage(); + } finally { + requestContext.deactivate(); + } } private boolean runBlocking(DataFetchingEnvironment dfe) {