Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSE: Allow provision of custom EventStreamUnmarshalling in EventSource #1185

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import akka.http.scaladsl.model.sse.ServerSentEvent.heartbeat
import akka.http.scaladsl.model.MediaTypes.`text/event-stream`
import akka.http.scaladsl.model.headers.`Last-Event-ID`
import akka.http.scaladsl.unmarshalling.sse.EventStreamUnmarshalling

import scala.concurrent.Future
import scala.concurrent.duration.{Duration, FiniteDuration}

Expand Down Expand Up @@ -76,16 +77,18 @@ object EventSource {
* @param send function to send a HTTP request
* @param initialLastEventId initial value for Last-Evend-ID header, `None` by default
* @param retryDelay delay for retrying after completion, `0` by default
* @param unmarshaller converts event-stream responses to a Source of `ServerSentEvent`s.
* @param mat implicit `Materializer`, needed to obtain server-sent events
* @return continuous source of server-sent events
*/
def apply(uri: Uri,
send: HttpRequest => Future[HttpResponse],
initialLastEventId: Option[String] = None,
retryDelay: FiniteDuration = Duration.Zero)(
retryDelay: FiniteDuration = Duration.Zero,
unmarshaller: EventStreamUnmarshalling = EventStreamUnmarshalling)(
implicit mat: Materializer
): EventSource = {
import EventStreamUnmarshalling._
import unmarshaller._
import mat.executionContext

val continuousEvents = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import akka.Done
import akka.actor.{Actor, ActorLogging, ActorSystem, Props, Status}
import akka.http.scaladsl.Http
import akka.http.scaladsl.marshalling.sse.EventStreamMarshalling
import akka.http.scaladsl.unmarshalling.sse.EventStreamUnmarshalling
import akka.http.scaladsl.model.MediaTypes.`text/event-stream`
import akka.http.scaladsl.model.StatusCodes.BadRequest
import akka.http.scaladsl.model.headers.`Last-Event-ID`
Expand All @@ -21,9 +22,12 @@ import akka.stream.{ActorMaterializer, ThrottleMode}
import akka.testkit.SocketUtil
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets.UTF_8

import org.scalatest.{AsyncWordSpec, BeforeAndAfterAll, Matchers}

import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import scala.util.Random

object EventSourceSpec {

Expand All @@ -32,7 +36,7 @@ object EventSourceSpec {
private final case object Bind
private final case object Unbind

private def route(size: Int, setEventId: Boolean): Route = {
private def route(size: Int, setEventId: Boolean, data: Option[String]): Route = {
import Directives._
import EventStreamMarshalling._
get {
Expand All @@ -41,7 +45,7 @@ object EventSourceSpec {
val fromSeqNo = lastEventId.map(_.trim.toInt).getOrElse(0) + 1
complete {
Source(fromSeqNo.until(fromSeqNo + size))
.map(toServerSentEvent(setEventId))
.map(toServerSentEvent(setEventId, data))
.intersperse(ServerSentEvent.heartbeat)
}
} catch {
Expand All @@ -61,7 +65,11 @@ object EventSourceSpec {
}
}

final class Server(address: String, port: Int, size: Int, shouldSetEventId: Boolean = false)
final class Server(address: String,
port: Int,
size: Int,
shouldSetEventId: Boolean = false,
eventData: Option[String] = None)
extends Actor
with ActorLogging {
import Server._
Expand All @@ -75,7 +83,7 @@ object EventSourceSpec {

private def unbound: Receive = {
case Bind =>
Http(context.system).bindAndHandle(route(size, shouldSetEventId), address, port).pipeTo(self)
Http(context.system).bindAndHandle(route(size, shouldSetEventId, eventData), address, port).pipeTo(self)
context.become(binding)
}

Expand Down Expand Up @@ -108,10 +116,10 @@ object EventSourceSpec {
}
}

private def toServerSentEvent(setEventId: Boolean)(n: Int) = {
val data = n.toString
val event = ServerSentEvent(data)
if (setEventId) event.copy(id = Some(data)) else event
private def toServerSentEvent(setEventId: Boolean, data: Option[String] = None)(n: Int) = {
val eventId = n.toString
val event = ServerSentEvent(data.getOrElse(eventId))
if (setEventId) event.copy(id = Some(eventId)) else event
}

private def hostAndPort() = {
Expand All @@ -128,7 +136,7 @@ final class EventSourceSpec extends AsyncWordSpec with Matchers with BeforeAndAf
private implicit val mat = ActorMaterializer()

"EventSource" should {
"communicate correctly with an instable HTTP server" in {
"communicate correctly with an unstable HTTP server" in {
val nrOfSamples = 20
val (host, port) = hostAndPort()
val server = system.actorOf(Props(new Server(host, port, 2, true)))
Expand All @@ -155,6 +163,23 @@ final class EventSourceSpec extends AsyncWordSpec with Matchers with BeforeAndAf
val expected = Seq.tabulate(nrOfSamples)(_ % 2 + 3).map(toServerSentEvent(false))
events.map(_ shouldBe expected).andThen { case _ => system.stop(server) }
}

"permit the provison of a custom umarshaller" in {
val nrOfSamples = 20
val (host, port) = hostAndPort()
val data = Some(Random.alphanumeric.take(6000).mkString)
val server = system.actorOf(Props(new Server(host, port, 2, true, data)))

object unmarshaller extends EventStreamUnmarshalling {
override def maxLineSize: Int = 6001
}

val eventSource = EventSource(Uri(s"http://$host:$port"), send, Some("2"), 1.second, unmarshaller)
val events =
eventSource.throttle(1, 500.milliseconds, 1, ThrottleMode.Shaping).take(nrOfSamples).runWith(Sink.seq)
val expected = Seq.tabulate(nrOfSamples)(_ + 3).map(toServerSentEvent(true, data))
events.map(_ shouldBe expected).andThen { case _ => system.stop(server) }
}
}

override protected def afterAll() = {
Expand Down