Skip to content

Commit

Permalink
Merge pull request quarkusio#37241 from geoand/quarkusio#37107
Browse files Browse the repository at this point in the history
Make improvements to REST Client SSE handling
  • Loading branch information
geoand authored Nov 21, 2023
2 parents fc5b82c + 6f41d71 commit 09afc2d
Show file tree
Hide file tree
Showing 7 changed files with 487 additions and 10 deletions.
101 changes: 101 additions & 0 deletions docs/src/main/asciidoc/rest-client-reactive.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,107 @@ If you use a `CompletionStage`, you would need to call the service's method to r
This difference comes from the laziness aspect of Mutiny and its subscription protocol.
More details about this can be found in https://smallrye.io/smallrye-mutiny/latest/reference/uni-and-multi/[the Mutiny documentation].

=== Server-Sent Event (SSE) support

Consuming SSE events is possible simply by declaring the result type as a `io.smallrye.mutiny.Multi`.

The simplest example is:

[source, java]
----
package org.acme.rest.client;
import io.smallrye.mutiny.Multi;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
@Path("/sse")
@RegisterRestClient(configKey = "some-api")
public interface SseClient {
@GET
@Produces(MediaType.SERVER_SENT_EVENTS)
Multi<String> get();
}
----

[NOTE]
====
All the IO involved in streaming the SSE results is done in a non-blocking manner.
====

Results are not limited to strings - for example when the server returns JSON payload for each event, Quarkus automatically deserializes it into the generic type used in the `Multi`.

[TIP]
====
Users can also access the entire SSE event by using the `org.jboss.resteasy.reactive.client.SseEvent` type.
A simple example where the event payloads are `Long` values is the following:
[source, java]
----
package org.acme.rest.client;
import io.smallrye.mutiny.Uni;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.client.SseEvent;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.QueryParam;
@Path("/sse")
@RegisterRestClient(configKey = "some-api")
public interface SseClient {
@GET
@Produces(MediaType.SERVER_SENT_EVENTS)
Multi<SseEvent<Long>> get();
}
----
====

==== Filtering out events

On occasion, the stream of SSE events may contain some events that should not be returned by the client - an example of this is having the server send heartbeat events in order to keep the underlying TCP connection open.
The REST Client supports filtering out such events by providing the `@org.jboss.resteasy.reactive.client.SseEventFilter`.

Here is an example of filtering out heartbeat events:

[source,java]
----
package org.acme.rest.client;
import io.smallrye.mutiny.Uni;
import java.util.function.Predicate;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.QueryParam;
@Path("/sse")
@RegisterRestClient(configKey = "some-api")
public interface SseClient {
@GET
@Produces(MediaType.SERVER_SENT_EVENTS)
@SseEventFilter(HeartbeatFilter.class)
Multi<SseEvent<Long>> get();
class HeartbeatFilter implements Predicate<SessionEvent<String>> {
@Override
public boolean test(SseEvent<String> event) {
return !"heartbeat".equals(event.id());
}
}
}
----

== Custom headers support

There are a few ways in which you can specify custom headers for your REST calls:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Predicate;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.sse.OutboundSseEvent;
import jakarta.ws.rs.sse.Sse;
import jakarta.ws.rs.sse.SseEventSink;

import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.RestStreamElementType;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.server.jackson.JacksonBasicMessageBodyReader;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand Down Expand Up @@ -112,6 +120,82 @@ void shouldRestStreamElementTypeOverwriteProducesAtClassLevel() {
.containsExactly(new Dto("foo", "bar"), new Dto("chocolate", "bar")));
}

@Test
void shouldBeAbleReadEntireEvent() {
var resultList = new CopyOnWriteArrayList<>();
createClient()
.event()
.subscribe().with(new Consumer<>() {
@Override
public void accept(SseEvent<Dto> event) {
resultList.add(new EventContainer(event.id(), event.name(), event.data()));
}
});
await().atMost(5, TimeUnit.SECONDS)
.untilAsserted(
() -> assertThat(resultList).containsExactly(
new EventContainer("id0", "name0", new Dto("name0", "0")),
new EventContainer("id1", "name1", new Dto("name1", "1"))));
}

@Test
void shouldBeAbleReadEntireEventWhileAlsoBeingAbleToFilterEvents() {
var resultList = new CopyOnWriteArrayList<>();
createClient()
.eventWithFilter()
.subscribe().with(new Consumer<>() {
@Override
public void accept(SseEvent<Dto> event) {
resultList.add(new EventContainer(event.id(), event.name(), event.data()));
}
});
await().atMost(5, TimeUnit.SECONDS)
.untilAsserted(
() -> assertThat(resultList).containsExactly(
new EventContainer("id", "n0", new Dto("name0", "0")),
new EventContainer("id", "n1", new Dto("name1", "1")),
new EventContainer("id", "n2", new Dto("name2", "2"))));
}

static class EventContainer {
final String id;
final String name;
final Dto dto;

EventContainer(String id, String name, Dto dto) {
this.id = id;
this.name = name;
this.dto = dto;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
EventContainer that = (EventContainer) o;
return Objects.equals(id, that.id) && Objects.equals(name, that.name)
&& Objects.equals(dto, that.dto);
}

@Override
public int hashCode() {
return Objects.hash(id, name, dto);
}

@Override
public String toString() {
return "EventContainer{" +
"id='" + id + '\'' +
", name='" + name + '\'' +
", dto=" + dto +
'}';
}
}

private SseClient createClient() {
return QuarkusRestClientBuilder.newBuilder()
.baseUri(uri)
Expand Down Expand Up @@ -144,6 +228,31 @@ public interface SseClient {
@Produces(MediaType.SERVER_SENT_EVENTS)
@Path("/with-entity-json")
Multi<Map<String, String>> postAndReadAsMap(String entity);

@GET
@Path("/event")
@Produces(MediaType.SERVER_SENT_EVENTS)
Multi<SseEvent<Dto>> event();

@GET
@Path("/event-with-filter")
@Produces(MediaType.SERVER_SENT_EVENTS)
@SseEventFilter(CustomFilter.class)
Multi<SseEvent<Dto>> eventWithFilter();
}

public static class CustomFilter implements Predicate<SseEvent<String>> {

@Override
public boolean test(SseEvent<String> event) {
if ("heartbeat".equals(event.id())) {
return false;
}
if ("END".equals(event.data())) {
return false;
}
return true;
}
}

@Path("/sse")
Expand Down Expand Up @@ -175,6 +284,68 @@ public Multi<String> post(String entity) {
public Multi<Dto> postAndReadAsMap(String entity) {
return Multi.createBy().repeating().supplier(() -> new Dto("foo", entity)).atMost(3);
}

@GET
@Path("/event")
@Produces(MediaType.SERVER_SENT_EVENTS)
public void event(@Context SseEventSink sink, @Context Sse sse) {
// send a stream of few events
try (sink) {
for (int i = 0; i < 2; i++) {
final OutboundSseEvent.Builder builder = sse.newEventBuilder();
builder.id("id" + i)
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name" + i, String.valueOf(i)))
.name("name" + i);

sink.send(builder.build());
}
}
}

@GET
@Path("/event-with-filter")
@Produces(MediaType.SERVER_SENT_EVENTS)
public void eventWithFilter(@Context SseEventSink sink, @Context Sse sse) {
try (sink) {
sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name0", "0"))
.name("n0")
.build());

sink.send(sse.newEventBuilder()
.id("heartbeat")
.comment("heartbeat")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.build());

sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name1", "1"))
.name("n1")
.build());

sink.send(sse.newEventBuilder()
.id("heartbeat")
.comment("heartbeat")
.build());

sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name2", "2"))
.name("n2")
.build());

sink.send(sse.newEventBuilder()
.id("end")
.data("END")
.build());
}
}
}

@Path("/sse-rest-stream-element-type")
Expand Down Expand Up @@ -226,5 +397,13 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(name, value);
}

@Override
public String toString() {
return "Dto{" +
"name='" + name + '\'' +
", value='" + value + '\'' +
'}';
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.eclipse.microprofile.rest.client.annotation.RegisterProviders;
import org.eclipse.microprofile.rest.client.ext.ResponseExceptionMapper;
import org.jboss.jandex.DotName;
import org.jboss.resteasy.reactive.client.SseEventFilter;

import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.quarkus.rest.client.reactive.ClientFormParam;
Expand Down Expand Up @@ -41,6 +42,8 @@ public class DotNames {

static final DotName METHOD = DotName.createSimple(Method.class.getName());

public static final DotName SSE_EVENT_FILTER = DotName.createSimple(SseEventFilter.class);

private DotNames() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.jboss.resteasy.reactive.common.util.QuarkusMultivaluedHashMap;

import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
Expand Down Expand Up @@ -371,6 +372,42 @@ void registerCompressionInterceptors(BuildProducer<ReflectiveClassBuildItem> ref
}
}

@BuildStep
void handleSseEventFilter(BuildProducer<ReflectiveClassBuildItem> reflectiveClasses,
BeanArchiveIndexBuildItem beanArchiveIndexBuildItem) {
var index = beanArchiveIndexBuildItem.getIndex();
Collection<AnnotationInstance> instances = index.getAnnotations(DotNames.SSE_EVENT_FILTER);
if (instances.isEmpty()) {
return;
}

List<String> filterClassNames = new ArrayList<>(instances.size());
for (AnnotationInstance instance : instances) {
if (instance.target().kind() != AnnotationTarget.Kind.METHOD) {
continue;
}
if (instance.value() == null) {
continue; // can't happen
}
Type filterType = instance.value().asClass();
DotName filterClassName = filterType.name();
ClassInfo filterClassInfo = index.getClassByName(filterClassName.toString());
if (filterClassInfo == null) {
log.warn("Unable to find class '" + filterType.name() + "' in index");
} else if (!filterClassInfo.hasNoArgsConstructor()) {
throw new RestClientDefinitionException(
"Classes used in @SseEventFilter must have a no-args constructor. Offending class is '"
+ filterClassName + "'");
} else {
filterClassNames.add(filterClassName.toString());
}
}
reflectiveClasses.produce(ReflectiveClassBuildItem
.builder(filterClassNames.toArray(new String[0]))
.constructors(true)
.build());
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
void addRestClientBeans(Capabilities capabilities,
Expand Down
Loading

0 comments on commit 09afc2d

Please sign in to comment.