Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for deactivating managed context in GraphQL server #27040

Merged
merged 1 commit into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package io.quarkus.smallrye.graphql.deployment;

import static io.quarkus.smallrye.graphql.deployment.AbstractGraphQLTest.getPropertyAsString;
import static io.restassured.RestAssured.given;

import java.time.LocalDate;
import java.time.Month;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;

import org.eclipse.microprofile.graphql.GraphQLApi;
import org.eclipse.microprofile.graphql.Query;
import org.jboss.shrinkwrap.api.asset.EmptyAsset;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.security.Authenticated;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.response.Response;

public class ConcurrentAuthTest extends AbstractGraphQLTest {

static Map<String, String> PROPERTIES = new HashMap<>();
static {

PROPERTIES.put("quarkus.smallrye-graphql.error-extension-fields", "classification,code");
PROPERTIES.put("quarkus.smallrye-graphql.show-runtime-exception-message", "java.lang.SecurityException");

PROPERTIES.put("quarkus.http.auth.basic", "true");
PROPERTIES.put("quarkus.security.users.embedded.enabled", "true");
PROPERTIES.put("quarkus.security.users.embedded.plain-text", "true");
PROPERTIES.put("quarkus.security.users.embedded.users.scott", "jb0ss");
}

@RegisterExtension
static QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addClasses(FilmResource.class, Film.class, GalaxyService.class)
.addAsResource(new StringAsset(getPropertyAsString(PROPERTIES)), "application.properties")
.addAsManifestResource(EmptyAsset.INSTANCE, "beans.xml"));

private int iterations = 5000;

@Test
public void concurrentAllFilmsOnly() throws InterruptedException, ExecutionException {
ExecutorService executor = Executors.newFixedThreadPool(50);

var futures = new ArrayList<CompletableFuture<Boolean>>(iterations);
for (int i = 0; i < iterations; i++) {
futures.add(CompletableFuture.supplyAsync(this::allFilmsRequestWithAuth, executor)
.thenApply(r -> !r.getBody().asString().contains("unauthorized")));
}
Optional<Boolean> success = getTestResult(futures);
Assertions.assertTrue(success.orElse(false), "Unauthorized response codes were found");
executor.shutdown();
}

private static Optional<Boolean> getTestResult(ArrayList<CompletableFuture<Boolean>> futures)
throws InterruptedException, ExecutionException {
return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
.thenApply(v -> futures.stream()
.map(CompletableFuture::join)
.reduce(Boolean::logicalAnd))
.get();
}

private Response allFilmsRequestWithAuth() {
String requestBody = "{\"query\":" +
"\"" +
"{" +
" allFilmsSecured {" +
" title" +
" director" +
" releaseDate" +
" episodeID" +
"}" +
"}" +
"\"" +
"}";

return given()
.body(requestBody)
.auth()
.preemptive()
.basic("scott", "jb0ss")
.post("/graphql/");
}

@GraphQLApi
public static class FilmResource {

@Inject
GalaxyService service;

@Query("allFilmsSecured")
@Authenticated
public List<Film> getAllFilmsSecured() {
return service.getAllFilms();
}
}

public static class Film {

private String title;
private Integer episodeID;
private String director;
private LocalDate releaseDate;

public String getTitle() {
return title;
}

public void setTitle(String title) {
this.title = title;
}

public Integer getEpisodeID() {
return episodeID;
}

public void setEpisodeID(Integer episodeID) {
this.episodeID = episodeID;
}

public String getDirector() {
return director;
}

public void setDirector(String director) {
this.director = director;
}

public LocalDate getReleaseDate() {
return releaseDate;
}

public void setReleaseDate(LocalDate releaseDate) {
this.releaseDate = releaseDate;
}

}

@ApplicationScoped
public static class GalaxyService {

private List<Film> films = new ArrayList<>();

public GalaxyService() {

Film aNewHope = new Film();
aNewHope.setTitle("A New Hope");
aNewHope.setReleaseDate(LocalDate.of(1977, Month.MAY, 25));
aNewHope.setEpisodeID(4);
aNewHope.setDirector("George Lucas");

Film theEmpireStrikesBack = new Film();
theEmpireStrikesBack.setTitle("The Empire Strikes Back");
theEmpireStrikesBack.setReleaseDate(LocalDate.of(1980, Month.MAY, 21));
theEmpireStrikesBack.setEpisodeID(5);
theEmpireStrikesBack.setDirector("George Lucas");

Film returnOfTheJedi = new Film();
returnOfTheJedi.setTitle("Return Of The Jedi");
returnOfTheJedi.setReleaseDate(LocalDate.of(1983, Month.MAY, 25));
returnOfTheJedi.setEpisodeID(6);
returnOfTheJedi.setDirector("George Lucas");

films.add(aNewHope);
films.add(theEmpireStrikesBack);
films.add(returnOfTheJedi);
}

public List<Film> getAllFilms() {
return films;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import javax.json.JsonReaderFactory;

import io.quarkus.arc.Arc;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
Expand Down Expand Up @@ -67,6 +68,7 @@ public void handle(final RoutingContext ctx) {
}
try {
handleWithIdentity(ctx);
currentManagedContext.deactivate();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this one in a finally? Why deactivate in one case and terminate in the case of an exception?

Also what's the purpose of this.currentManagedContextTerminationHandler now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this.currentManagedContextTerminationHandler will terminate (so deactivate and destroy) once we are done, while the deactivate after the handleWithIdentity will only deactivate (so that it does not leak)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not need to be in a finally as the catch with terminate

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can understand that but my concern was more that we are not consistent. Typically here: https://github.com/quarkusio/quarkus/pull/27040/files#diff-5b5cda47e6a5418c9345a577d468f058e05c412dc9b64fce6a6ef5aec496023bR75-R85 where we simply deactivate. Or maybe the exception from the latter is caught below?

Also what do you mean by leaking? What I don't understand is how we can leak something between this and the end of the response? Or is it that the thread is somehow reused between the end of this and the end of the response?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thread could be reused, in a high traffic scenario.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you're sure that all the response handlers we added earlier are always executed on the right thread at the right time?
That's what gets me a bit worried.

Note that I have no idea how this all works so I'm asking naive questions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sure. Maybe @stuartwdouglas can explain better. By leaking I mean that a new request that use the same thread can get the context that still float around (because it's not been deactivated)

} catch (Throwable t) {
currentManagedContext.terminate();
throw t;
Expand Down Expand Up @@ -108,6 +110,8 @@ protected Map<String, Object> getMetaData(RoutingContext ctx) {
Map<String, Object> metaData = new ConcurrentHashMap<>();
metaData.put("runBlocking", runBlocking);
metaData.put("httpHeaders", getHeaders(ctx));
InjectableContext.ContextState state = currentManagedContext.getState();
metaData.put("state", state);
return metaData;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkus.smallrye.graphql.runtime;

import java.util.Map;

import org.jboss.logging.Logger;

import io.quarkus.security.identity.CurrentIdentityAssociation;
Expand All @@ -25,7 +27,9 @@ public SmallRyeGraphQLOverWebSocketHandler(CurrentIdentityAssociation currentIde

@Override
protected void doHandle(final RoutingContext ctx) {

if (ctx.request().headers().contains(HttpHeaders.UPGRADE, HttpHeaders.WEBSOCKET, true) && !ctx.request().isEnded()) {
Map<String, Object> metaData = getMetaData(ctx);
ctx.request().toWebSocket(event -> {
if (event.succeeded()) {
ServerWebSocket serverWebSocket = event.result();
Expand All @@ -34,11 +38,11 @@ protected void doHandle(final RoutingContext ctx) {
switch (subprotocol) {
case "graphql-transport-ws":
handler = new GraphQLTransportWSSubprotocolHandler(
new QuarkusVertxWebSocketSession(serverWebSocket), getMetaData(ctx));
new QuarkusVertxWebSocketSession(serverWebSocket), metaData);
break;
case "graphql-ws":
handler = new GraphQLWSSubprotocolHandler(
new QuarkusVertxWebSocketSession(serverWebSocket), getMetaData(ctx));
new QuarkusVertxWebSocketSession(serverWebSocket), metaData);
break;
default:
log.warn("Unknown graphql-over-websocket protocol: " + subprotocol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetchingEnvironment;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
import io.smallrye.graphql.SmallRyeGraphQLServerMessages;
import io.smallrye.graphql.execution.datafetcher.AbstractDataFetcher;
import io.smallrye.graphql.schema.model.Operation;
Expand Down Expand Up @@ -72,11 +74,15 @@ protected <O> O invokeFailure(DataFetcherResult.Builder<Object> resultBuilder) {
@Override
@SuppressWarnings("unchecked")
protected CompletionStage<List<T>> invokeBatch(DataFetchingEnvironment dfe, Object[] arguments) {
ManagedContext requestContext = Arc.container().requestContext();
try {
BlockingHelper.reactivate(requestContext, dfe);
return handleUserBatchLoad(dfe, arguments)
.subscribe().asCompletionStage();
} catch (Exception ex) {
throw new RuntimeException(ex);
} finally {
requestContext.deactivate();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import java.util.concurrent.Callable;

import graphql.schema.DataFetchingEnvironment;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.smallrye.graphql.schema.model.Execute;
import io.smallrye.graphql.schema.model.Operation;
import io.vertx.core.Context;
Expand Down Expand Up @@ -31,4 +34,14 @@ public static void runBlocking(Context vc, Callable<Object> contextualCallable,
}, result);
}

public static void reactivate(ManagedContext requestContext, DataFetchingEnvironment dfe) {
if (!requestContext.isActive()) {
Object maybeState = dfe.getGraphQlContext().getOrDefault("state", null);
if (maybeState != null) {
requestContext.activate((InjectableContext.ContextState) maybeState);
} else {
requestContext.activate();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetchingEnvironment;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
import io.smallrye.context.SmallRyeThreadContext;
import io.smallrye.graphql.execution.datafetcher.DefaultDataFetcher;
import io.smallrye.graphql.schema.model.Operation;
Expand Down Expand Up @@ -44,7 +45,7 @@ public CompletionStage<List<T>> invokeBatch(DataFetchingEnvironment dfe, Object[
if (runBlocking(dfe) || BlockingHelper.blockingShouldExecuteNonBlocking(operation, vc)) {
return super.invokeBatch(dfe, arguments);
} else {
return invokeBatchBlocking(arguments, vc);
return invokeBatchBlocking(dfe, arguments, vc);
}
}

Expand Down Expand Up @@ -81,18 +82,24 @@ private <T> T invokeAndTransformBlocking(final DataFetchingEnvironment dfe, Data
}

@SuppressWarnings("unchecked")
private CompletionStage<List<T>> invokeBatchBlocking(Object[] arguments, Context vc) {
SmallRyeThreadContext threadContext = Arc.container().select(SmallRyeThreadContext.class).get();
final Promise<List<T>> result = Promise.promise();

// We need some make sure that we call given the context
Callable<Object> contextualCallable = threadContext.contextualCallable(() -> {
return (List<T>) operationInvoker.invokePrivileged(arguments);
});

// Here call blocking with context
BlockingHelper.runBlocking(vc, contextualCallable, result);
return result.future().toCompletionStage();
private CompletionStage<List<T>> invokeBatchBlocking(DataFetchingEnvironment dfe, Object[] arguments, Context vc) {
ManagedContext requestContext = Arc.container().requestContext();
try {
BlockingHelper.reactivate(requestContext, dfe);
SmallRyeThreadContext threadContext = Arc.container().select(SmallRyeThreadContext.class).get();
final Promise<List<T>> result = Promise.promise();

// We need some make sure that we call given the context
Callable<Object> contextualCallable = threadContext.contextualCallable(() -> {
return (List<T>) operationInvoker.invokePrivileged(arguments);
});

// Here call blocking with context
BlockingHelper.runBlocking(vc, contextualCallable, result);
return result.future().toCompletionStage();
} finally {
requestContext.deactivate();
}
}

private boolean runBlocking(DataFetchingEnvironment dfe) {
Expand Down