Skip to content

Commit

Permalink
WebSockets Next: error handlers part 2
Browse files Browse the repository at this point in the history
- call error handlers if an endpoint callback returns Uni that receives
a failure
- javadoc clarifications
  • Loading branch information
mkouba committed Mar 28, 2024
1 parent 18d873e commit 103c885
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.groups.UniCreate;
import io.smallrye.mutiny.groups.UniOnFailure;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
Expand Down Expand Up @@ -465,17 +466,19 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
TryBlock tryBlock = onErrorTryBlock(doOnOpen);
ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, ret);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret);

MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel",
ExecutionModel.class);
onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel));
}

generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations,
index);
generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations, index);
generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations, index);
index, globalErrorHandlers);
generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations, index,
globalErrorHandlers);
generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations, index,
globalErrorHandlers);

if (endpoint.onClose != null) {
Callback callback = endpoint.onClose;
Expand All @@ -488,7 +491,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
TryBlock tryBlock = onErrorTryBlock(doOnClose);
ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, ret);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret);

MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel",
ExecutionModel.class);
Expand All @@ -504,10 +507,6 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index) {
if (endpoint.onErrors.isEmpty()) {
return;
}
MethodCreator doOnError = endpointCreator.getMethodCreator("doOnError", Uni.class, Throwable.class);

Map<DotName, Callback> errors = new HashMap<>();
List<ThrowableInfo> throwableInfos = new ArrayList<>();
Expand All @@ -532,9 +531,13 @@ private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuil
}
}

if (throwableInfos.isEmpty()) {
return;
}

MethodCreator doOnError = endpointCreator.getMethodCreator("doOnError", Uni.class, Throwable.class);
// Most specific errors go first
throwableInfos.sort(Comparator.comparingInt(ThrowableInfo::level).reversed());

ResultHandle endpointThis = doOnError.getThis();

for (ThrowableInfo throwableInfo : throwableInfos) {
Expand All @@ -553,7 +556,7 @@ private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuil
endpointThis, funBytecode.load(throwableInfo.bean().getIdentifier()));
ResultHandle[] args = callback.generateArguments(endpointThis, tryBlock, transformedAnnotations, index);
ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
encodeAndReturnResult(endpointThis, tryBlock, callback, ret);
encodeAndReturnResult(endpointThis, tryBlock, callback, globalErrorHandlers, endpoint, ret);

// return doErrorExecute()
throwableMatches.returnValue(
Expand Down Expand Up @@ -604,7 +607,7 @@ record GlobalErrorHandler(BeanInfo bean, Callback callback) {

private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
IndexView index) {
IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers) {
if (callback == null) {
return;
}
Expand Down Expand Up @@ -638,7 +641,7 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu
// Call the business method
ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance,
args);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, ret);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret);

MethodCreator onMessageExecutionModel = endpointCreator.getMethodCreator("on" + messageType + "MessageExecutionModel",
ExecutionModel.class);
Expand Down Expand Up @@ -752,7 +755,30 @@ static ResultHandle decodeMessage(
}
}

private ResultHandle uniOnFailureDoOnError(ResultHandle endpointThis, BytecodeCreator method, Callback callback,
ResultHandle uni, WebSocketEndpointBuildItem endpoint, GlobalErrorHandlersBuildItem globalErrorHandlers) {
if (callback.isOnError()
|| (globalErrorHandlers.handlers.isEmpty() && (endpoint == null || endpoint.onErrors.isEmpty()))) {
// @OnError or no error handlers available
return uni;
}
// return uniMessage.onFailure().recoverWithUni(t -> {
// return doOnError(t);
// });
FunctionCreator fun = method.createFunction(Function.class);
BytecodeCreator funBytecode = fun.getBytecode();
funBytecode.returnValue(funBytecode.invokeVirtualMethod(
MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "doOnError", Uni.class, Throwable.class),
endpointThis, funBytecode.getMethodParam(0)));
ResultHandle uniOnFailure = method.invokeInterfaceMethod(
MethodDescriptor.ofMethod(Uni.class, "onFailure", UniOnFailure.class), uni);
return method.invokeVirtualMethod(
MethodDescriptor.ofMethod(UniOnFailure.class, "recoverWithUni", Uni.class, Function.class),
uniOnFailure, fun.getInstance());
}

private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator method, Callback callback,
GlobalErrorHandlersBuildItem globalErrorHandlers, WebSocketEndpointBuildItem endpoint,
ResultHandle value) {
if (callback.acceptsBinaryMessage()) {
// ----------------------
Expand All @@ -762,7 +788,7 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me
Type messageType = callback.returnType().asParameterizedType().arguments().get(0);
if (messageType.name().equals(WebSocketDotNames.VOID)) {
// Uni<Void>
return value;
return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers);
} else {
// return uniMessage.chain(m -> {
// Buffer buffer = encodeBuffer(m);
Expand All @@ -781,7 +807,7 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me
ResultHandle uniChain = method.invokeInterfaceMethod(
MethodDescriptor.ofMethod(Uni.class, "chain", Uni.class, Function.class), value,
fun.getInstance());
return uniChain;
return uniOnFailureDoOnError(endpointThis, method, callback, uniChain, endpoint, globalErrorHandlers);
}
} else if (callback.isReturnTypeMulti()) {
// return multiBinary(multi, broadcast, m -> {
Expand Down Expand Up @@ -817,7 +843,8 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me
Type messageType = callback.returnType().asParameterizedType().arguments().get(0);
if (messageType.name().equals(WebSocketDotNames.VOID)) {
// Uni<Void>
return value;

return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers);
} else {
// return uniMessage.chain(m -> {
// String text = encodeText(m);
Expand All @@ -835,7 +862,7 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me
ResultHandle uniChain = method.invokeInterfaceMethod(
MethodDescriptor.ofMethod(Uni.class, "chain", Uni.class, Function.class), value,
fun.getInstance());
return uniChain;
return uniOnFailureDoOnError(endpointThis, method, callback, uniChain, endpoint, globalErrorHandlers);
}
} else if (callback.isReturnTypeMulti()) {
// return multiText(multi, broadcast, m -> {
Expand Down Expand Up @@ -929,6 +956,7 @@ private ResultHandle uniVoid(BytecodeCreator method) {
}

private void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCreator method, Callback callback,
GlobalErrorHandlersBuildItem globalErrorHandlers, WebSocketEndpointBuildItem endpoint,
ResultHandle result) {
// The result must be always Uni<Void>
if (callback.isReturnTypeVoid()) {
Expand All @@ -938,7 +966,7 @@ private void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCreator me
// Skip response
BytecodeCreator isNull = method.ifNull(result).trueBranch();
isNull.returnValue(uniVoid(isNull));
method.returnValue(encodeMessage(endpointThis, method, callback, result));
method.returnValue(encodeMessage(endpointThis, method, callback, globalErrorHandlers, endpoint, result));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package io.quarkus.websockets.next.test.errors;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.RequestScoped;
import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.BinaryDecodeException;
import io.quarkus.websockets.next.BinaryEncodeException;
import io.quarkus.websockets.next.OnBinaryMessage;
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Context;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;

public class UniFailureErrorTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Echo.class, RequestBean.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("echo")
URI testUri;

@Test
void testError() throws InterruptedException {
WSClient client = WSClient.create(vertx).connect(testUri);
client.send(Buffer.buffer("1"));
client.waitForMessages(1);
assertEquals("Something went wrong", client.getLastMessage().toString());
assertTrue(RequestBean.DESTROYED_LATCH.await(5, TimeUnit.SECONDS));
}

@WebSocket(path = "/echo")
public static class Echo {

@Inject
WebSocketConnection connection;

@Inject
RequestBean requestBean;

@OnBinaryMessage
Uni<Void> process(WebSocketConnection connection, Buffer message) {
requestBean.setState("ok");
return Uni.createFrom().failure(new IllegalStateException("Something went wrong"));
}

@OnError
String encodingError(BinaryEncodeException e) {
return "Problem encoding: " + e.getEncodedObject().toString();
}

@OnError
String decodingError(BinaryDecodeException e) {
return "Problem decoding: " + e.getBytes().toString();
}

@OnError
String runtimeProblem(RuntimeException e, WebSocketConnection connection) {
assertTrue(Context.isOnWorkerThread());
assertEquals(connection.id(), this.connection.id());
// The request context from @OnBinaryMessage is reused
assertEquals("ok", requestBean.getState());
return e.getMessage();
}

@OnError
String catchAll(Throwable e) {
return "Ooops!";
}

}

@RequestScoped
public static class RequestBean {

static final CountDownLatch DESTROYED_LATCH = new CountDownLatch(1);

private volatile String state = "nok";

public String getState() {
return state;
}

public void setState(String state) {
this.state = state;
}

@PreDestroy
void destroy() {
DESTROYED_LATCH.countDown();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
/**
* A {@link WebSocket} endpoint method annotated with this annotation is invoked when an error occurs.
* <p>
* It is used when an endpoint callback throws a runtime error, or when a conversion errors occurs, or when a returned
* {@link io.smallrye.mutiny.Uni} receives a failure.
* <p>
* The method must accept exactly one "error" parameter, i.e. a parameter that is assignable from {@link java.lang.Throwable}.
* The method may also accept the following parameters:
* <ul>
Expand All @@ -20,11 +23,11 @@
* <li>{@link String} parameters annotated with {@link PathParam}</li>
* </ul>
* <p>
* An endpoint may declare multiple methods annotated with this annotation. However, each method must declare a unique error
* An endpoint may declare multiple methods annotated with this annotation. However, each method must declare a different error
* parameter. The method that declares a most-specific supertype of the actual exception is selected.
* <p>
* This annotation can be also used to declare a global error handler, i.e. a method that is not declared on a {@link WebSocket}
* endpoint. Such a method may not not accept {@link PathParam} paremeters. Error handlers declared on an endpoint take
* endpoint. Such a method may not accept {@link PathParam} paremeters. Error handlers declared on an endpoint take
* precedence over the global error handlers.
*/
@Retention(RUNTIME)
Expand Down

0 comments on commit 103c885

Please sign in to comment.