Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce support for canceling async types in Quarkus Rest #41710

Merged
merged 3 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jboss.resteasy.reactive.server.runtime.kotlin

import io.vertx.core.Vertx
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
Expand Down Expand Up @@ -35,24 +36,50 @@ class CoroutineInvocationHandler(
logger.trace("Handling request with dispatcher {}", dispatcher)

requestContext.suspend()
coroutineScope.launch(context = dispatcher) {
// ensure the proper CL is not lost in dev-mode
Thread.currentThread().contextClassLoader = originalTCCL
try {
val result =
invoker.invokeCoroutine(
requestContext.endpointInstance,
requestContext.parameters
)
if (result != Unit) {
requestContext.result = result
val done = AtomicBoolean()
var canceled = AtomicBoolean()

val job =
coroutineScope.launch(context = dispatcher) {
// ensure the proper CL is not lost in dev-mode
Thread.currentThread().contextClassLoader = originalTCCL
try {
val result =
invoker.invokeCoroutine(
requestContext.endpointInstance,
requestContext.parameters
)
done.set(true)
if (result != Unit) {
requestContext.result = result
}
requestContext.resume()
} catch (t: Throwable) {
done.set(true)

if (canceled.get()) {
try {
// get rid of everything related to the request since the connection has
// already gone away
requestContext.close()
} catch (ignored: Exception) {}
} else {
// passing true since the target doesn't change and we want response filters
// to
// be able to know what the resource method was
requestContext.handleException(t, true)
requestContext.resume()
}
}
} catch (t: Throwable) {
// passing true since the target doesn't change and we want response filters to be
// able to know what the resource method was
requestContext.handleException(t, true)
}
requestContext.resume()

requestContext.serverResponse().addCloseHandler {
if (!done.get()) {
try {
canceled.set(true)
job.cancel(null)
} catch (ignored: Exception) {}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package io.quarkus.resteasy.reactive.server.test;

import static io.restassured.RestAssured.when;
import static org.awaitility.Awaitility.await;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URL;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.vertx.core.Vertx;
import io.vertx.ext.web.client.WebClient;

public class CancelableCompletionStageTest {

@RegisterExtension
static QuarkusUnitTest runner = new QuarkusUnitTest()
.withApplicationRoot(jar -> jar.addClasses(Resource.class));

@BeforeEach
void setUp() {
Resource.COUNT.set(0);
}

@Inject
Vertx vertx;

@TestHTTPResource
URL url;

@Test
public void testNormal() {
when().get("test")
.then()
.statusCode(200)
.body(equalTo("Hello, world"));
}

@Test
public void testCancel() {
WebClient client = WebClient.create(vertx);

client.get(url.getPort(), url.getHost(), "/test").send();

try {
// make sure we did make the proper request
await().atMost(Duration.ofSeconds(2)).untilAtomic(Resource.COUNT, equalTo(1));

// this will effectively cancel the request
client.close();

// make sure we wait until the request could have completed
Thread.sleep(7_000);

// if the count did not increase, it means that Uni was cancelled
assertEquals(1, Resource.COUNT.get());
} catch (InterruptedException ignored) {

} finally {
try {
client.close();
} catch (Exception ignored) {

}
}

}

@Path("test")
public static class Resource {

public static final AtomicInteger COUNT = new AtomicInteger(0);

@GET
@Produces(MediaType.TEXT_PLAIN)
public CompletionStage<String> hello() {
COUNT.incrementAndGet();
return CompletableFuture.supplyAsync(
new Supplier<>() {
@Override
public String get() {
COUNT.incrementAndGet();
return "Hello, world";
}
},
CompletableFuture.delayedExecutor(5, TimeUnit.SECONDS));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.quarkus.resteasy.reactive.server.test;

import static io.restassured.RestAssured.when;
import static org.awaitility.Awaitility.await;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.jupiter.api.Assertions.*;

import java.net.URL;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;

import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Vertx;
import io.vertx.ext.web.client.WebClient;

public class CancelableUniTest {

@RegisterExtension
static QuarkusUnitTest runner = new QuarkusUnitTest()
.withApplicationRoot(jar -> jar.addClasses(Resource.class));

@BeforeEach
void setUp() {
Resource.COUNT.set(0);
}

@Inject
Vertx vertx;

@TestHTTPResource
URL url;

@Test
public void testNormal() {
when().get("test")
.then()
.statusCode(200)
.body(equalTo("Hello, world"));
}

@Test
public void testCancel() {
WebClient client = WebClient.create(vertx);

client.get(url.getPort(), url.getHost(), "/test").send();

try {
// make sure we did make the proper request
await().atMost(Duration.ofSeconds(2)).untilAtomic(Resource.COUNT, equalTo(1));

// this will effectively cancel the request
client.close();

// make sure we wait until the request could have completed
Thread.sleep(7_000);

// if the count did not increase, it means that Uni was cancelled
assertEquals(1, Resource.COUNT.get());
} catch (InterruptedException ignored) {

} finally {
try {
client.close();
} catch (Exception ignored) {

}
}

}

@Path("test")
public static class Resource {

public static final AtomicInteger COUNT = new AtomicInteger(0);

@GET
@Produces(MediaType.TEXT_PLAIN)
public Uni<String> hello() {
COUNT.incrementAndGet();
return Uni.createFrom().item("Hello, world").onItem().delayIt().by(Duration.ofSeconds(5)).onItem().invoke(
COUNT::incrementAndGet);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.jboss.resteasy.reactive.server.handlers;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicBoolean;

import org.jboss.resteasy.reactive.server.core.ResteasyReactiveRequestContext;
import org.jboss.resteasy.reactive.server.spi.ServerRestHandler;
Expand All @@ -10,19 +12,42 @@ public class CompletionStageResponseHandler implements ServerRestHandler {
@Override
public void handle(ResteasyReactiveRequestContext requestContext) throws Exception {
// FIXME: handle Response with entity being a CompletionStage
if (requestContext.getResult() instanceof CompletionStage) {
CompletionStage<?> result = (CompletionStage<?>) requestContext.getResult();
if (requestContext.getResult() instanceof CompletionStage<?> result) {
requestContext.suspend();

AtomicBoolean done = new AtomicBoolean();
AtomicBoolean canceled = new AtomicBoolean();
result.handle((v, t) -> {
if (t != null) {
requestContext.handleException(t, true);
done.set(true);
if (canceled.get()) {
try {
// get rid of everything related to the request since the connection has already gone away
requestContext.close();
} catch (Exception ignored) {

}
} else {
requestContext.setResult(v);
if (t != null) {
requestContext.handleException(t, true);
} else {
requestContext.setResult(v);
}
requestContext.resume();
}
requestContext.resume();
return null;
});

requestContext.serverResponse().addCloseHandler(new Runnable() {
@Override
public void run() {
if (!done.get()) {
if (result instanceof CompletableFuture<?> cf) {
canceled.set(true);
cf.cancel(true);
}
}
}
});
}
}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,52 @@
package org.jboss.resteasy.reactive.server.handlers;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import org.jboss.resteasy.reactive.server.core.ResteasyReactiveRequestContext;
import org.jboss.resteasy.reactive.server.spi.ServerRestHandler;

import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.subscription.Cancellable;

public class UniResponseHandler implements ServerRestHandler {

@Override
public void handle(ResteasyReactiveRequestContext requestContext) throws Exception {
// FIXME: handle Response with entity being a Uni
if (requestContext.getResult() instanceof Uni) {
Uni<?> result = (Uni<?>) requestContext.getResult();
if (requestContext.getResult() instanceof Uni<?> result) {
requestContext.suspend();

result.subscribe().with(new Consumer<Object>() {
AtomicBoolean done = new AtomicBoolean();
Cancellable cancellable = result.subscribe().with(new Consumer<Object>() {
@Override
public void accept(Object v) {
done.set(true);
requestContext.setResult(v);
requestContext.resume();
}
}, new Consumer<Throwable>() {
}, new Consumer<>() {
@Override
public void accept(Throwable t) {
done.set(true);
requestContext.resume(t, true);
}
});

requestContext.serverResponse().addCloseHandler(new Runnable() {
@Override
public void run() {
if (!done.get()) {
cancellable.cancel();
try {
// get rid of everything related to the request since the connection has already gone away
requestContext.close();
} catch (Exception ignored) {

}
}
}
});
}
}
}