Skip to content

Commit

Permalink
Add test for reading headers from web socket responses (#2245)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Aug 6, 2024
1 parent 365356d commit bb151fb
Show file tree
Hide file tree
Showing 16 changed files with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ abstract class WebSocketTest[F[_]]
implicit def monad: MonadError[F]

def throwsWhenNotAWebSocket: Boolean = false
def supportsReadingWebSocketResponseHeaders: Boolean = true

it should "send and receive three messages using asWebSocketAlways" in {
basicRequest
Expand Down Expand Up @@ -207,6 +208,19 @@ abstract class WebSocketTest[F[_]]
.toFuture()
}

if (supportsReadingWebSocketResponseHeaders) {
it should "receive the extra headers set by the server" in {
basicRequest
.get(uri"$wsEndpoint/ws/header")
.response(asWebSocketAlways((ws: WebSocket[F]) => ws.close()))
.send(backend)
.map { response =>
response.header("Correlation-id") shouldBe Some("ABC-XYZ-123")
}
.toFuture()
}
}

def sendText(ws: WebSocket[F], count: Int): F[Unit] =
send(ws, count, (i: Int) => WebSocketFrame.text(s"test$i"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class FetchBackendWebSocketTest extends WebSocketTest[Future] {

implicit override def executionContext: ExecutionContext = queue
override def throwsWhenNotAWebSocket: Boolean = true
override def supportsReadingWebSocketResponseHeaders: Boolean = false

override val backend: WebSocketBackend[Future] = FetchBackend()
override implicit val convertToFuture: ConvertToFuture[Future] = ConvertToFuture.future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ class HttpClientFutureWebSocketTest[F[_]] extends WebSocketTest[Future] with Web
override implicit val monad: MonadError[Future] = new FutureMonad()

override def concurrently[T](fs: List[() => Future[T]]): Future[List[T]] = Future.sequence(fs.map(_()))

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ class HttpClientSyncWebSocketTest extends WebSocketTest[Identity] {
override implicit val monad: MonadError[Identity] = IdentityMonad

override def throwsWhenNotAWebSocket: Boolean = true
// HttpClient doesn't expose the response headers for web sockets in any way
override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import scala.scalajs.concurrent.JSExecutionContext.queue
class FetchCatsWebSocketTest extends WebSocketTest[IO] with CatsTestBase {
implicit override def executionContext: ExecutionContext = queue
override def throwsWhenNotAWebSocket: Boolean = true
override def supportsReadingWebSocketResponseHeaders: Boolean = false

override val backend: WebSocketBackend[IO] = FetchCatsBackend()
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import scala.scalajs.concurrent.JSExecutionContext.queue
class FetchCatsWebSocketTest extends WebSocketTest[IO] with CatsTestBase {
implicit override def executionContext: ExecutionContext = queue
override def throwsWhenNotAWebSocket: Boolean = true
override def supportsReadingWebSocketResponseHeaders: Boolean = false

override val backend: WebSocketBackend[IO] = FetchCatsBackend()
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ class HttpClientCatsWebSocketTest
with HttpClientCatsTestBase {

override def concurrently[T](fs: List[() => IO[T]]): IO[List[T]] = fs.map(_()).parSequence

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ class HttpClientFs2WebSocketTest
to.andThen(rest => fs2.Stream.eval(item.pure[IO]) ++ rest)

override def concurrently[T](fs: List[() => IO[T]]): IO[List[T]] = fs.map(_()).parSequence

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ class HttpClientFs2WebSocketTest
to.andThen(rest => fs2.Stream.eval(item.pure[IO]) ++ rest)

override def concurrently[T](fs: List[() => IO[T]]): IO[List[T]] = fs.map(_()).parSequence

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import scala.scalajs.concurrent.JSExecutionContext.Implicits.queue
class FetchMonixWebSocketTest extends WebSocketTest[Task] with WebSocketStreamingTest[Task, MonixStreams] {
implicit override def executionContext: ExecutionContext = queue
override def throwsWhenNotAWebSocket: Boolean = true
override def supportsReadingWebSocketResponseHeaders: Boolean = false

override val backend: WebSocketStreamBackend[Task, MonixStreams] = FetchMonixBackend()
override implicit val convertToFuture: ConvertToFuture[Task] = convertMonixTaskToFuture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,6 @@ class HttpClientMonixWebSocketTest
to.andThen(rest => Observable.now(item) ++ rest)

override def concurrently[T](fs: List[() => Task[T]]): Task[List[T]] = Task.parSequence(fs.map(_()))

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import scala.scalajs.concurrent.JSExecutionContext.Implicits.queue
class FetchZioWebSocketTest extends WebSocketTest[Task] with WebSocketStreamingTest[Task, ZioStreams] with ZioTestBase {
implicit override def executionContext: ExecutionContext = queue
override def throwsWhenNotAWebSocket: Boolean = true
override def supportsReadingWebSocketResponseHeaders: Boolean = false

override val backend: WebSocketStreamBackend[Task, ZioStreams] = FetchZioBackend()
override implicit val convertToFuture: ConvertToFuture[Task] = convertZioTaskToFuture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ class HttpClientZioWebSocketTest
to.andThen(rest => ZStream(item) ++ rest)

override def concurrently[T](fs: List[() => Task[T]]): Task[List[T]] = ZIO.collectAllPar(fs.map(_()))

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import scala.scalajs.concurrent.JSExecutionContext.Implicits.queue
class FetchZioWebSocketTest extends WebSocketTest[Task] with WebSocketStreamingTest[Task, ZioStreams] with ZioTestBase {
implicit override def executionContext: ExecutionContext = queue
override def throwsWhenNotAWebSocket: Boolean = true
override def supportsReadingWebSocketResponseHeaders: Boolean = false

override val backend: WebSocketStreamBackend[Task, ZioStreams] = FetchZioBackend()
override implicit val convertToFuture: ConvertToFuture[Task] = convertZioTaskToFuture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ class HttpClientZioWebSocketTest
to.andThen(rest => ZStream(item) ++ rest)

override def concurrently[T](fs: List[() => Task[T]]): Task[List[T]] = Task.collectAllPar(fs.map(_()))

override def supportsReadingWebSocketResponseHeaders: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.coding.Coders._
import akka.http.scaladsl.coding.DeflateNoWrap
import akka.http.scaladsl.model.HttpHeader.ParsingResult
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.CacheDirectives._
import akka.http.scaladsl.model.headers._
Expand Down Expand Up @@ -487,6 +488,15 @@ private class HttpServer(port: Int, info: String => Unit) extends AutoCloseable
)
)
}
} ~
path("header") {
respondWithHeader(HttpHeader.parse("Correlation-ID", "ABC-XYZ-123").asInstanceOf[ParsingResult.Ok].header) {
handleWebSocketMessages(Flow[Message].mapConcat {
case tm: TextMessage =>
TextMessage(Source.single("echo: ") ++ tm.textStream) :: Nil
case bm: BinaryMessage => bm :: Nil
})
}
}
} ~ path("empty_content_encoding") {
get {
Expand Down

0 comments on commit bb151fb

Please sign in to comment.