Skip to content

Commit

Permalink
Make sure to deactivate the Managed context
Browse files Browse the repository at this point in the history
Signed-off-by: Phillip Kruger <[email protected]>
  • Loading branch information
phillip-kruger committed Aug 1, 2022
1 parent f1b54c0 commit 27e6a8b
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package io.quarkus.smallrye.graphql.deployment;

import static io.quarkus.smallrye.graphql.deployment.AbstractGraphQLTest.getPropertyAsString;
import static io.restassured.RestAssured.given;

import java.time.LocalDate;
import java.time.Month;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;

import org.eclipse.microprofile.graphql.GraphQLApi;
import org.eclipse.microprofile.graphql.Query;
import org.jboss.shrinkwrap.api.asset.EmptyAsset;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.security.Authenticated;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.response.Response;

public class ConcurrentAuthTest extends AbstractGraphQLTest {

static Map<String, String> PROPERTIES = new HashMap<>();
static {

PROPERTIES.put("quarkus.smallrye-graphql.error-extension-fields", "classification,code");
PROPERTIES.put("quarkus.smallrye-graphql.show-runtime-exception-message", "java.lang.SecurityException");

PROPERTIES.put("quarkus.http.auth.basic", "true");
PROPERTIES.put("quarkus.security.users.embedded.enabled", "true");
PROPERTIES.put("quarkus.security.users.embedded.plain-text", "true");
PROPERTIES.put("quarkus.security.users.embedded.users.scott", "jb0ss");
}

@RegisterExtension
static QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addClasses(FilmResource.class, Film.class, GalaxyService.class)
.addAsResource(new StringAsset(getPropertyAsString(PROPERTIES)), "application.properties")
.addAsManifestResource(EmptyAsset.INSTANCE, "beans.xml"));

private int iterations = 5000;

@Test
public void concurrentAllFilmsOnly() throws InterruptedException, ExecutionException {
ExecutorService executor = Executors.newFixedThreadPool(50);

var futures = new ArrayList<CompletableFuture<Boolean>>(iterations);
for (int i = 0; i < iterations; i++) {
futures.add(CompletableFuture.supplyAsync(this::allFilmsRequestWithAuth, executor)
.thenApply(r -> !r.getBody().asString().contains("unauthorized")));
}
Optional<Boolean> success = getTestResult(futures);
Assertions.assertTrue(success.orElse(false), "Unauthorized response codes were found");
executor.shutdown();
}

private static Optional<Boolean> getTestResult(ArrayList<CompletableFuture<Boolean>> futures)
throws InterruptedException, ExecutionException {
return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
.thenApply(v -> futures.stream()
.map(CompletableFuture::join)
.reduce(Boolean::logicalAnd))
.get();
}

private Response allFilmsRequestWithAuth() {
String requestBody = "{\"query\":" +
"\"" +
"{" +
" allFilmsSecured {" +
" title" +
" director" +
" releaseDate" +
" episodeID" +
"}" +
"}" +
"\"" +
"}";

return given()
.body(requestBody)
.auth()
.preemptive()
.basic("scott", "jb0ss")
.post("/graphql/");
}

@GraphQLApi
public static class FilmResource {

@Inject
GalaxyService service;

@Query("allFilmsSecured")
@Authenticated
public List<Film> getAllFilmsSecured() {
return service.getAllFilms();
}
}

public static class Film {

private String title;
private Integer episodeID;
private String director;
private LocalDate releaseDate;

public String getTitle() {
return title;
}

public void setTitle(String title) {
this.title = title;
}

public Integer getEpisodeID() {
return episodeID;
}

public void setEpisodeID(Integer episodeID) {
this.episodeID = episodeID;
}

public String getDirector() {
return director;
}

public void setDirector(String director) {
this.director = director;
}

public LocalDate getReleaseDate() {
return releaseDate;
}

public void setReleaseDate(LocalDate releaseDate) {
this.releaseDate = releaseDate;
}

}

@ApplicationScoped
public static class GalaxyService {

private List<Film> films = new ArrayList<>();

public GalaxyService() {

Film aNewHope = new Film();
aNewHope.setTitle("A New Hope");
aNewHope.setReleaseDate(LocalDate.of(1977, Month.MAY, 25));
aNewHope.setEpisodeID(4);
aNewHope.setDirector("George Lucas");

Film theEmpireStrikesBack = new Film();
theEmpireStrikesBack.setTitle("The Empire Strikes Back");
theEmpireStrikesBack.setReleaseDate(LocalDate.of(1980, Month.MAY, 21));
theEmpireStrikesBack.setEpisodeID(5);
theEmpireStrikesBack.setDirector("George Lucas");

Film returnOfTheJedi = new Film();
returnOfTheJedi.setTitle("Return Of The Jedi");
returnOfTheJedi.setReleaseDate(LocalDate.of(1983, Month.MAY, 25));
returnOfTheJedi.setEpisodeID(6);
returnOfTheJedi.setDirector("George Lucas");

films.add(aNewHope);
films.add(theEmpireStrikesBack);
films.add(returnOfTheJedi);
}

public List<Film> getAllFilms() {
return films;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import javax.json.JsonReaderFactory;

import io.quarkus.arc.Arc;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
Expand Down Expand Up @@ -67,6 +68,7 @@ public void handle(final RoutingContext ctx) {
}
try {
handleWithIdentity(ctx);
currentManagedContext.deactivate();
} catch (Throwable t) {
currentManagedContext.terminate();
throw t;
Expand Down Expand Up @@ -108,6 +110,8 @@ protected Map<String, Object> getMetaData(RoutingContext ctx) {
Map<String, Object> metaData = new ConcurrentHashMap<>();
metaData.put("runBlocking", runBlocking);
metaData.put("httpHeaders", getHeaders(ctx));
InjectableContext.ContextState state = currentManagedContext.getState();
metaData.put("state", state);
return metaData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetchingEnvironment;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
import io.smallrye.graphql.SmallRyeGraphQLServerMessages;
import io.smallrye.graphql.execution.datafetcher.AbstractDataFetcher;
import io.smallrye.graphql.schema.model.Operation;
Expand Down Expand Up @@ -72,11 +74,15 @@ protected <O> O invokeFailure(DataFetcherResult.Builder<Object> resultBuilder) {
@Override
@SuppressWarnings("unchecked")
protected CompletionStage<List<T>> invokeBatch(DataFetchingEnvironment dfe, Object[] arguments) {
ManagedContext requestContext = Arc.container().requestContext();
try {
BlockingHelper.reactivate(requestContext, dfe);
return handleUserBatchLoad(dfe, arguments)
.subscribe().asCompletionStage();
} catch (Exception ex) {
throw new RuntimeException(ex);
} finally {
requestContext.deactivate();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import java.util.concurrent.Callable;

import graphql.schema.DataFetchingEnvironment;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.smallrye.graphql.schema.model.Execute;
import io.smallrye.graphql.schema.model.Operation;
import io.vertx.core.Context;
Expand Down Expand Up @@ -31,4 +34,14 @@ public static void runBlocking(Context vc, Callable<Object> contextualCallable,
}, result);
}

public static void reactivate(ManagedContext requestContext, DataFetchingEnvironment dfe) {
if (!requestContext.isActive()) {
Object maybeState = dfe.getGraphQlContext().getOrDefault("state", null);
if (maybeState != null) {
requestContext.activate((InjectableContext.ContextState) maybeState);
} else {
requestContext.activate();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetchingEnvironment;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
import io.smallrye.context.SmallRyeThreadContext;
import io.smallrye.graphql.execution.datafetcher.DefaultDataFetcher;
import io.smallrye.graphql.schema.model.Operation;
Expand Down Expand Up @@ -44,7 +45,7 @@ public CompletionStage<List<T>> invokeBatch(DataFetchingEnvironment dfe, Object[
if (runBlocking(dfe) || BlockingHelper.blockingShouldExecuteNonBlocking(operation, vc)) {
return super.invokeBatch(dfe, arguments);
} else {
return invokeBatchBlocking(arguments, vc);
return invokeBatchBlocking(dfe, arguments, vc);
}
}

Expand Down Expand Up @@ -81,18 +82,24 @@ private <T> T invokeAndTransformBlocking(final DataFetchingEnvironment dfe, Data
}

@SuppressWarnings("unchecked")
private CompletionStage<List<T>> invokeBatchBlocking(Object[] arguments, Context vc) {
SmallRyeThreadContext threadContext = Arc.container().select(SmallRyeThreadContext.class).get();
final Promise<List<T>> result = Promise.promise();

// We need some make sure that we call given the context
Callable<Object> contextualCallable = threadContext.contextualCallable(() -> {
return (List<T>) operationInvoker.invokePrivileged(arguments);
});

// Here call blocking with context
BlockingHelper.runBlocking(vc, contextualCallable, result);
return result.future().toCompletionStage();
private CompletionStage<List<T>> invokeBatchBlocking(DataFetchingEnvironment dfe, Object[] arguments, Context vc) {
ManagedContext requestContext = Arc.container().requestContext();
try {
BlockingHelper.reactivate(requestContext, dfe);
SmallRyeThreadContext threadContext = Arc.container().select(SmallRyeThreadContext.class).get();
final Promise<List<T>> result = Promise.promise();

// We need some make sure that we call given the context
Callable<Object> contextualCallable = threadContext.contextualCallable(() -> {
return (List<T>) operationInvoker.invokePrivileged(arguments);
});

// Here call blocking with context
BlockingHelper.runBlocking(vc, contextualCallable, result);
return result.future().toCompletionStage();
} finally {
requestContext.deactivate();
}
}

private boolean runBlocking(DataFetchingEnvironment dfe) {
Expand Down

0 comments on commit 27e6a8b

Please sign in to comment.