From a192eb3d6b1d1daf3e809fb9f330f07da26b9d6f Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 23 Nov 2023 10:54:02 +0100 Subject: [PATCH 01/35] wip --- build.sbt | 2 +- .../main/scala/sttp/tapir/DecodeResult.scala | 1 + .../scala/sttp/tapir/ztapir/ZTapirTest.scala | 4 +- .../scala/sttp/tapir/ztapir/ZTapirTest.scala | 4 +- project/Versions.scala | 4 +- project/plugins.sbt | 2 +- .../server/akkagrpc/AkkaGrpcRequestBody.scala | 2 +- .../server/akkahttp/AkkaRequestBody.scala | 5 +- .../server/akkahttp/AkkaHttpServerTest.scala | 5 +- .../armeria/cats/TapirCatsService.scala | 6 +- .../armeria/cats/ArmeriaCatsServerTest.scala | 4 +- .../server/armeria/ArmeriaRequestBody.scala | 6 +- .../server/armeria/StreamCompatible.scala | 2 +- .../server/armeria/TapirFutureService.scala | 2 +- .../armeria/ArmeriaFutureServerTest.scala | 3 +- .../server/armeria/zio/TapirZioService.scala | 6 +- .../armeria/zio/ArmeriaZioServerTest.scala | 5 +- .../server/armeria/zio/TapirZioService.scala | 2 +- .../armeria/zio/ArmeriaZioServerTest.scala | 2 +- .../decodefailure/DecodeFailureHandler.scala | 12 ++-- .../exception/ExceptionHandler.scala | 8 ++- .../server/interpreter/RequestBody.scala | 9 ++- .../RequestBodyToRawException.scala | 8 +++ .../interpreter/ServerInterpreter.scala | 55 +++++++++++------ .../scala/sttp/tapir/server/TestUtil.scala | 4 +- .../server/finatra/FinatraRequestBody.scala | 3 +- .../server/http4s/Http4sRequestBody.scala | 7 ++- .../server/http4s/Http4sServerTest.scala | 5 +- .../http4s/ztapir/ZHttp4sServerTest.scala | 6 +- .../http4s/ztapir/ZHttp4sServerTest.scala | 4 +- .../jdkhttp/internal/JdkHttpRequestBody.scala | 12 ++-- .../netty/internal/NettyCatsRequestBody.scala | 42 ++++++++----- .../netty/cats/NettyCatsServerTest.scala | 4 +- .../netty/internal/NettyRequestBody.scala | 45 +++++++++----- .../server/netty/NettyFutureServerTest.scala | 2 +- .../netty/internal/NettyZioRequestBody.scala | 48 +++++++++------ .../server/netty/zio/NettyZioServerTest.scala | 15 
++++- .../nima/internal/NimaRequestBody.scala | 5 +- .../pekkogrpc/PekkoGrpcRequestBody.scala | 2 +- .../server/pekkohttp/PekkoRequestBody.scala | 7 ++- .../pekkohttp/PekkoHttpServerTest.scala | 5 +- .../tapir/server/play/PlayRequestBody.scala | 7 ++- .../tapir/server/play/PlayServerTest.scala | 5 +- .../tapir/server/stub/SttpRequestBody.scala | 4 +- .../tapir/server/tests/ServerBasicTests.scala | 23 ++++++- .../server/tests/ServerStreamingTests.scala | 60 +++++++++++++++---- .../tapir/server/vertx/cats/streams/fs2.scala | 6 +- .../vertx/cats/CatsVertxServerTest.scala | 5 +- .../vertx/cats/streams/Fs2StreamTest.scala | 6 +- .../vertx/decoders/VertxRequestBody.scala | 6 +- .../vertx/streams/ReadStreamCompatible.scala | 2 +- .../tapir/server/vertx/streams/package.scala | 2 +- .../tapir/server/vertx/VertxServerTest.scala | 3 +- .../tapir/server/vertx/zio/streams/zio.scala | 6 +- .../server/vertx/zio/ZioVertxServerTest.scala | 17 +++--- .../vertx/zio/streams/ZStreamTest.scala | 6 +- .../tapir/server/vertx/zio/streams/zio.scala | 3 +- .../server/vertx/zio/ZioVertxServerTest.scala | 3 +- .../vertx/zio/streams/ZStreamTest.scala | 6 +- .../server/ziohttp/ZioHttpRequestBody.scala | 33 +++++++--- .../server/ziohttp/ZioHttpServerTest.scala | 9 ++- .../server/ziohttp/ZioHttpRequestBody.scala | 43 +++++++------ .../server/ziohttp/ZioHttpServerTest.scala | 2 +- .../aws/lambda/AwsRequestBody.scala | 5 +- .../scala/sttp/tapir/tests/Streaming.scala | 4 ++ 65 files changed, 436 insertions(+), 205 deletions(-) create mode 100644 server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala diff --git a/build.sbt b/build.sbt index 63304354ca..e8a97bff95 100644 --- a/build.sbt +++ b/build.sbt @@ -345,7 +345,7 @@ lazy val rootProject = (project in file(".")) testFinatra := (Test / test).all(filterProject(p => p.contains("finatra"))).value, compileScoped := Def.inputTaskDyn { val args = spaceDelimited("").parsed - Def.taskDyn((Compile / 
compile).all(filterByVersionAndPlatform(args.head, args(1)))) + Def.taskDyn((Test / compile).all(filterByVersionAndPlatform(args.head, args(1)))) }.evaluated, testScoped := Def.inputTaskDyn { val args = spaceDelimited("").parsed diff --git a/core/src/main/scala/sttp/tapir/DecodeResult.scala b/core/src/main/scala/sttp/tapir/DecodeResult.scala index ae59e3dc04..fc372409c4 100644 --- a/core/src/main/scala/sttp/tapir/DecodeResult.scala +++ b/core/src/main/scala/sttp/tapir/DecodeResult.scala @@ -35,6 +35,7 @@ object DecodeResult { } } case class Mismatch(expected: String, actual: String) extends Failure + case class BodyTooLarge(maxBytes: Long) extends Failure /** A validation error that occurred when decoding the value, that is, when some `Validator` failed. */ case class InvalidValue(errors: List[ValidationError[_]]) extends Failure diff --git a/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala b/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala index 0ae7569154..05fbba3cee 100644 --- a/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala +++ b/integrations/zio/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala @@ -28,8 +28,8 @@ object ZTapirTest extends ZIOSpecDefault with ZTapir { private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] { override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ??? - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = ??? + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): TestEffect[RawValue[R]] = ??? + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
} private val exampleToResponse: ToResponseBody[ResponseBodyType, RequestBodyType] = new ToResponseBody[ResponseBodyType, RequestBodyType] { diff --git a/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala b/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala index 3e1617c785..d5dd2f81c2 100644 --- a/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala +++ b/integrations/zio1/src/test/scala/sttp/tapir/ztapir/ZTapirTest.scala @@ -30,8 +30,8 @@ object ZTapirTest extends DefaultRunnableSpec with ZTapir { private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] { override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ??? - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = ??? + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): TestEffect[RawValue[R]] = ??? + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
} private val exampleToResponse: ToResponseBody[ResponseBodyType, RequestBodyType] = new ToResponseBody[ResponseBodyType, RequestBodyType] { diff --git a/project/Versions.scala b/project/Versions.scala index 19e6d4cc1f..e09a586733 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -10,7 +10,7 @@ object Versions { val helidon = "4.0.0" val sttp = "3.9.1" val sttpModel = "1.7.6" - val sttpShared = "1.3.16" + val sttpShared = "1.3.17" val sttpApispec = "0.7.2" val akkaHttp = "10.2.10" val akkaStreams = "2.6.20" @@ -44,7 +44,7 @@ object Versions { val vertx = "4.5.0" val jsScalaJavaTime = "2.5.0" val nativeScalaJavaTime = "2.4.0-M3" - val jwtScala = "9.4.4" + val jwtScala = "9.4.5" val derevo = "0.13.0" val newtype = "0.4.4" val monixNewtype = "0.2.3" diff --git a/project/plugins.sbt b/project/plugins.sbt index 4fa2ca88ba..19606dd502 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -7,7 +7,7 @@ addSbtPlugin("com.softwaremill.sbt-softwaremill" % "sbt-softwaremill-publish" % addSbtPlugin("com.softwaremill.sbt-softwaremill" % "sbt-softwaremill-browser-test-js" % sbtSoftwareMillVersion) //addSbtPlugin("io.spray" % "sbt-boilerplate" % "0.6.1") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "1.1.3") -addSbtPlugin("org.playframework.twirl" % "sbt-twirl" % "2.0.1") +addSbtPlugin("org.playframework.twirl" % "sbt-twirl" % "2.0.2") addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.1") addSbtPlugin("com.eed3si9n" % "sbt-projectmatrix" % "0.9.1") addSbtPlugin("org.jetbrains.scala" % "sbt-ide-settings" % "1.1.2") diff --git a/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala b/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala index e4cf0a1468..315d7d8e71 100644 --- a/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala +++ b/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala @@ -25,7 +25,7 
@@ private[akkagrpc] class AkkaGrpcRequestBody(serverOptions: AkkaHttpServerOptions override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = toRawFromEntity(request, akkaRequestEntity(request), bodyType) - override def toStream(request: ServerRequest): streams.BinaryStream = ??? + override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? private def akkaRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity diff --git a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala index 4ca5e8b517..31231631b7 100644 --- a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala +++ b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala @@ -23,7 +23,10 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im override val streams: AkkaStreams = AkkaStreams override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = toRawFromEntity(request, akkeRequestEntity(request), bodyType) - override def toStream(request: ServerRequest): streams.BinaryStream = akkeRequestEntity(request).dataBytes + override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { + val stream = akkeRequestEntity(request).dataBytes + maxBytes.map(AkkaStreams.limitBytes(stream, _)).getOrElse(stream) + } private def akkeRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index ae84c896af..209602d4df 100644 --- 
a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -152,8 +152,11 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { } } ) + def drainAkka(stream: AkkaStreams.BinaryStream): Future[Unit] = + stream.runWith(Sink.ignore).map(_ => ()) + new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, AkkaStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(AkkaStreams)(drainAkka) ++ new ServerWebSocketTests(createServerTest, AkkaStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/armeria-server/cats/src/main/scala/sttp/tapir/server/armeria/cats/TapirCatsService.scala b/server/armeria-server/cats/src/main/scala/sttp/tapir/server/armeria/cats/TapirCatsService.scala index 481865a9d8..9cdc55593f 100644 --- a/server/armeria-server/cats/src/main/scala/sttp/tapir/server/armeria/cats/TapirCatsService.scala +++ b/server/armeria-server/cats/src/main/scala/sttp/tapir/server/armeria/cats/TapirCatsService.scala @@ -84,8 +84,10 @@ private object Fs2StreamCompatible { dispatcher ) - override def fromArmeriaStream(publisher: Publisher[HttpData]): Stream[F, Byte] = - publisher.toStreamBuffered[F](4).flatMap(httpData => Stream.chunk(Chunk.array(httpData.array()))) + override def fromArmeriaStream(publisher: Publisher[HttpData], maxBytes: Option[Long]): Stream[F, Byte] = { + val stream = publisher.toStreamBuffered[F](4).flatMap(httpData => Stream.chunk(Chunk.array(httpData.array()))) + maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) + } } } } diff --git a/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala 
b/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala index cd4f904ade..cde75c2a91 100644 --- a/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala +++ b/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala @@ -13,9 +13,11 @@ class ArmeriaCatsServerTest extends TestSuite { val interpreter = new ArmeriaCatsTestServerInterpreter(dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) + def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = + stream.compile.drain.void new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ - new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) } } diff --git a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala index a0db02886b..637513d4a5 100644 --- a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala +++ b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/ArmeriaRequestBody.scala @@ -23,13 +23,13 @@ private[armeria] final class ArmeriaRequestBody[F[_], S <: Streams[S]]( override val streams: Streams[S] = streamCompatible.streams - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { streamCompatible - .fromArmeriaStream(armeriaCtx(serverRequest).request().filter(x => x.isInstanceOf[HttpData]).asInstanceOf[StreamMessage[HttpData]]) + 
.fromArmeriaStream(armeriaCtx(serverRequest).request().filter(x => x.isInstanceOf[HttpData]).asInstanceOf[StreamMessage[HttpData]], maxBytes) .asInstanceOf[streams.BinaryStream] } - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val ctx = armeriaCtx(serverRequest) val request = ctx.request() diff --git a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/StreamCompatible.scala b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/StreamCompatible.scala index 0e7c8dda82..02e82f6c1c 100644 --- a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/StreamCompatible.scala +++ b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/StreamCompatible.scala @@ -7,5 +7,5 @@ import sttp.capabilities.Streams private[armeria] trait StreamCompatible[S <: Streams[S]] { val streams: S def asStreamMessage(s: streams.BinaryStream): Publisher[HttpData] - def fromArmeriaStream(s: Publisher[HttpData]): streams.BinaryStream + def fromArmeriaStream(s: Publisher[HttpData], maxBytes: Option[Long]): streams.BinaryStream } diff --git a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/TapirFutureService.scala b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/TapirFutureService.scala index e71c587bd2..8331844beb 100644 --- a/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/TapirFutureService.scala +++ b/server/armeria-server/src/main/scala/sttp/tapir/server/armeria/TapirFutureService.scala @@ -51,7 +51,7 @@ private[armeria] final case class TapirFutureService( private object ArmeriaStreamCompatible extends StreamCompatible[ArmeriaStreams] { override val streams: ArmeriaStreams = ArmeriaStreams - override def fromArmeriaStream(s: Publisher[HttpData]): Publisher[HttpData] = s + override def fromArmeriaStream(s: Publisher[HttpData], 
maxBytes: Option[Long]): Publisher[HttpData] = s override def asStreamMessage(s: Publisher[HttpData]): Publisher[HttpData] = s } diff --git a/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala b/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala index 7310f0c801..22f546f4da 100644 --- a/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala +++ b/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala @@ -5,6 +5,7 @@ import sttp.capabilities.armeria.ArmeriaStreams import sttp.monad.FutureMonad import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} +import scala.concurrent.Future class ArmeriaFutureServerTest extends TestSuite { @@ -16,6 +17,6 @@ class ArmeriaFutureServerTest extends TestSuite { new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ - new ServerStreamingTests(createServerTest, ArmeriaStreams).tests() + new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit) } } diff --git a/server/armeria-server/zio/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala b/server/armeria-server/zio/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala index d80454cc59..218a1574d8 100644 --- a/server/armeria-server/zio/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala +++ b/server/armeria-server/zio/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala @@ -76,8 +76,10 @@ private object ZioStreamCompatible { .getOrThrowFiberFailure() ) - override def fromArmeriaStream(publisher: Publisher[HttpData]): Stream[Throwable, Byte] = - publisher.toZIOStream().mapConcatChunk(httpData => Chunk.fromArray(httpData.array())) + override def 
fromArmeriaStream(publisher: Publisher[HttpData], maxBytes: Option[Long]): Stream[Throwable, Byte] = { + val stream = publisher.toZIOStream().mapConcatChunk(httpData => Chunk.fromArray(httpData.array())) + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) + } } } } diff --git a/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala b/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala index a3c9026586..17df98f35a 100644 --- a/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala +++ b/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala @@ -7,6 +7,7 @@ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} import sttp.tapir.ztapir.RIOMonadError import zio.Task +import zio.stream.ZSink class ArmeriaZioServerTest extends TestSuite { @@ -16,9 +17,11 @@ class ArmeriaZioServerTest extends TestSuite { val interpreter = new ArmeriaZioTestServerInterpreter() val createServerTest = new DefaultCreateServerTest(backend, interpreter) + def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = + zStream.run(ZSink.drain) new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) } } diff --git a/server/armeria-server/zio1/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala b/server/armeria-server/zio1/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala index 94a52b4948..47e32b80a5 100644 --- a/server/armeria-server/zio1/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala +++ 
b/server/armeria-server/zio1/src/main/scala/sttp/tapir/server/armeria/zio/TapirZioService.scala @@ -72,7 +72,7 @@ private object ZioStreamCompatible { override def asStreamMessage(stream: Stream[Throwable, Byte]): Publisher[HttpData] = runtime.unsafeRun(stream.mapChunks(c => Chunk.single(HttpData.wrap(c.toArray))).toPublisher) - override def fromArmeriaStream(publisher: Publisher[HttpData]): Stream[Throwable, Byte] = + override def fromArmeriaStream(publisher: Publisher[HttpData], maxBytes: Option[Long]): Stream[Throwable, Byte] = publisher.toStream().mapConcatChunk(httpData => Chunk.fromArray(httpData.array())) } } diff --git a/server/armeria-server/zio1/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala b/server/armeria-server/zio1/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala index a3c9026586..7d1b356c84 100644 --- a/server/armeria-server/zio1/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala +++ b/server/armeria-server/zio1/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala @@ -19,6 +19,6 @@ class ArmeriaZioServerTest extends TestSuite { new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() + new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit) } } diff --git a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala index 72c3ac1e56..362f4b8e94 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala @@ -124,6 
+124,7 @@ object DefaultDecodeFailureHandler { respondUnsupportedMediaType case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest case (_: EndpointIO.Headers[_], _) => respondBadRequest + case (_, _: DecodeResult.BodyTooLarge) => respondPayloadTooLarge case (_: EndpointIO.Body[_, _], _) => respondBadRequest case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest @@ -143,6 +144,7 @@ object DefaultDecodeFailureHandler { } private val respondBadRequest = Some(onlyStatus(StatusCode.BadRequest)) private val respondUnsupportedMediaType = Some(onlyStatus(StatusCode.UnsupportedMediaType)) + private val respondPayloadTooLarge = Some(onlyStatus(StatusCode.PayloadTooLarge)) def respondNotFoundIfHasAuth( ctx: DecodeFailureContext, @@ -224,10 +226,12 @@ object DefaultDecodeFailureHandler { } .mkString(", ") ) - case Missing => Some("missing") - case Multiple(_) => Some("multiple values") - case Mismatch(_, _) => Some("value mismatch") - case _ => None + case Missing => Some("missing") + case Multiple(_) => Some("multiple values") + case Mismatch(_, _) => Some("value mismatch") + case BodyTooLarge(maxBytes) => Some(s"Content length limit: $maxBytes bytes") + case _: Error => None + case _: InvalidValue => None } def combineSourceAndDetail(source: String, detail: Option[String]): String = diff --git a/server/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala b/server/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala index b9c1a0a683..2c9182be4f 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.interceptor.exception +import sttp.capabilities.StreamMaxLengthExceededException import sttp.model.StatusCode import 
sttp.monad.MonadError import sttp.tapir.server.model.ValuedEndpointOutput @@ -25,7 +26,12 @@ object ExceptionHandler { case class DefaultExceptionHandler[F[_]](response: (StatusCode, String) => ValuedEndpointOutput[_]) extends ExceptionHandler[F] { override def apply(ctx: ExceptionContext)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = - monad.unit(Some(response(StatusCode.InternalServerError, "Internal server error"))) + ctx.e match { + case StreamMaxLengthExceededException(maxBytes) => + monad.unit(Some(response(StatusCode.PayloadTooLarge, s"Payload limit (${maxBytes}B) exceeded"))) + case _ => + monad.unit(Some(response(StatusCode.InternalServerError, "Internal server error"))) + } } object DefaultExceptionHandler { diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala index 2310c43172..5292a8619b 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala @@ -3,12 +3,17 @@ package sttp.tapir.server.interpreter import sttp.capabilities.Streams import sttp.model.Part import sttp.tapir.model.ServerRequest +import sttp.tapir.AttributeKey +import sttp.tapir.EndpointInfo import sttp.tapir.{FileRange, RawBodyType, RawPart} +case class MaxContentLength(value: Long) + trait RequestBody[F[_], S] { val streams: Streams[S] - def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] - def toStream(serverRequest: ServerRequest): streams.BinaryStream + def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] + def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream + } case class RawValue[R](value: R, createdFiles: Seq[FileRange] = Nil) diff --git 
a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala new file mode 100644 index 0000000000..965877aebb --- /dev/null +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala @@ -0,0 +1,8 @@ +package sttp.tapir.server.interpreter + +import sttp.tapir.DecodeResult + +/** Can be used with RequestBody.toRaw to fail its effect F and pass failures, which are treated as decoding failures that happen before + * actual decoding of the raw value. + */ +private[tapir] case class RequestBodyToRawException(failure: DecodeResult.Failure) extends Exception diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index 629bc5a16b..49bb18e9b3 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -1,14 +1,18 @@ package sttp.tapir.server.interpreter +import sttp.capabilities.StreamMaxLengthExceededException import sttp.model.{Headers, StatusCode} import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.internal.{Params, ParamsAsAny, RichOneOfBody} import sttp.tapir.model.ServerRequest -import sttp.tapir.server.{model, _} import sttp.tapir.server.interceptor._ import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput} +import sttp.tapir.server.{model, _} +import sttp.tapir.{AttributeKey, DecodeResult, EndpointIO, EndpointInfo, EndpointInput, TapirFile} import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile} +import sttp.tapir.EndpointInfo +import sttp.tapir.AttributeKey class ServerInterpreter[R, F[_], B, S]( serverEndpoints: ServerRequest => List[ServerEndpoint[R, F]], @@ -106,7 +110,7 @@ class ServerInterpreter[R, F[_], B, S]( // 
index (so that the correct one is passed to the decode failure handler) _ <- resultOrValueFrom(DecodeBasicInputsResult.higherPriorityFailure(securityBasicInputs, regularBasicInputs)) // 3. computing the security input value - securityValues <- resultOrValueFrom(decodeBody(request, securityBasicInputs)) + securityValues <- resultOrValueFrom(decodeBody(request, securityBasicInputs, se.info)) securityParams <- resultOrValueFrom(InputValue(se.endpoint.securityInput, securityValues)) inputValues <- resultOrValueFrom(regularBasicInputs) a = securityParams.asAny.asInstanceOf[A] @@ -132,7 +136,7 @@ class ServerInterpreter[R, F[_], B, S]( case Right(u) => for { // 5. decoding the body of regular inputs, computing the input value, and running the main logic - values <- resultOrValueFrom(decodeBody(request, inputValues)) + values <- resultOrValueFrom(decodeBody(request, inputValues, se.endpoint.info)) params <- resultOrValueFrom(InputValue(se.endpoint.input, values)) response <- resultOrValueFrom.value( endpointHandler(defaultSecurityFailureResponse, endpointInterceptors) @@ -146,19 +150,22 @@ class ServerInterpreter[R, F[_], B, S]( private def decodeBody( request: ServerRequest, - result: DecodeBasicInputsResult + result: DecodeBasicInputsResult, + endpointInfo: EndpointInfo ): F[DecodeBasicInputsResult] = result match { case values: DecodeBasicInputsResult.Values => + val maxBodyLength = endpointInfo.attribute(AttributeKey[MaxContentLength]).map(_.value) values.bodyInputWithIndex match { case Some((Left(oneOfBodyInput), _)) => oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match { - case Some(Left(body)) => decodeBody(request, values, body) - case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body) + case Some(Left(body)) => decodeBody(request, values, body, maxBodyLength) + case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body, maxBodyLength) case None => 
unsupportedInputMediaTypeResponse(request, oneOfBodyInput) } - case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) => decodeStreamingBody(request, values, bodyInput) - case None => (values: DecodeBasicInputsResult).unit + case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) => + decodeStreamingBody(request, values, bodyInput, maxBodyLength) + case None => (values: DecodeBasicInputsResult).unit } case failure: DecodeBasicInputsResult.Failure => (failure: DecodeBasicInputsResult).unit } @@ -166,9 +173,10 @@ class ServerInterpreter[R, F[_], B, S]( private def decodeStreamingBody( request: ServerRequest, values: DecodeBasicInputsResult.Values, - bodyInput: EndpointIO.StreamBodyWrapper[Any, Any] + bodyInput: EndpointIO.StreamBodyWrapper[Any, Any], + maxBodyLength: Option[Long] ): F[DecodeBasicInputsResult] = - (bodyInput.codec.decode(requestBody.toStream(request)) match { + (bodyInput.codec.decode(requestBody.toStream(request, maxBodyLength)) match { case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV) case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult }).unit @@ -176,17 +184,26 @@ class ServerInterpreter[R, F[_], B, S]( private def decodeBody[RAW, T]( request: ServerRequest, values: DecodeBasicInputsResult.Values, - bodyInput: EndpointIO.Body[RAW, T] + bodyInput: EndpointIO.Body[RAW, T], + maxBodyLength: Option[Long] ): F[DecodeBasicInputsResult] = { - requestBody.toRaw(request, bodyInput.bodyType).flatMap { v => - bodyInput.codec.decode(v.value) match { - case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit - case failure: DecodeResult.Failure => - v.createdFiles - .foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file))) - .map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult) + requestBody + .toRaw(request, bodyInput.bodyType, maxBodyLength) + .flatMap { v => 
+ bodyInput.codec.decode(v.value) match { + case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit + case failure: DecodeResult.Failure => + v.createdFiles + .foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file))) + .map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult) + } + } + .handleError { + case RequestBodyToRawException(failure) => + (DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult).unit + case StreamMaxLengthExceededException(maxBytes) => + (DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.BodyTooLarge(maxBytes)): DecodeBasicInputsResult).unit } - } } private def unsupportedInputMediaTypeResponse( diff --git a/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala b/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala index 6a3f80507a..4cb6f5b882 100644 --- a/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala +++ b/server/core/src/test/scala/sttp/tapir/server/TestUtil.scala @@ -14,8 +14,8 @@ import scala.util.{Success, Try} object TestUtil { object TestRequestBody extends RequestBody[Id, NoStreams] { override val streams: Streams[NoStreams] = NoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Id[RawValue[R]] = ??? - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = ??? + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Id[RawValue[R]] = ??? + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
} object UnitToResponseBody extends ToResponseBody[Unit, NoStreams] { diff --git a/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala b/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala index 7e7220ceec..9a6071bc14 100644 --- a/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala +++ b/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala @@ -114,7 +114,8 @@ class FinatraRequestBody(serverOptions: FinatraServerOptions) extends RequestBod .map(_.toList) } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException() + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() private def finatraRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] } diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala index cbd7e1fadc..eba4433a89 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala @@ -18,11 +18,14 @@ private[http4s] class Http4sRequestBody[F[_]: Async]( serverOptions: Http4sServerOptions[F] ) extends RequestBody[F, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val r = http4sRequest(serverRequest) toRawFromStream(serverRequest, r.body, bodyType, r.charset) } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = 
http4sRequest(serverRequest).body + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { + val stream = http4sRequest(serverRequest).body + maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) + } private def http4sRequest(serverRequest: ServerRequest): Request[F] = serverRequest.underlying.asInstanceOf[Request[F]] diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index bf80ef1cad..86ddfd58b4 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -133,8 +133,11 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi } ) + def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = + stream.compile.drain.void + new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty diff --git a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index 22f502dec4..7935098791 100644 --- a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -14,7 +14,7 @@ import sttp.tapir.server.http4s.Http4sServerSentEvents import sttp.tapir.server.tests._ import 
sttp.tapir.tests.{Test, TestSuite} import zio.interop.catz._ -import zio.stream.ZStream +import zio.stream.{ZSink, ZStream} import zio.{Task, ZIO} import java.util.UUID @@ -50,9 +50,11 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { .map(_.body.toOption.value shouldBe List(sse1, sse2)) } ) + def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = + zStream.run(ZSink.drain) new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/http4s-server/zio1/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio1/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index d848cdd5d3..5c15e3e06a 100644 --- a/server/http4s-server/zio1/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio1/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -10,7 +10,7 @@ import sttp.tapir._ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} -import zio.{RIO, UIO} +import zio.{RIO, Task, UIO} import zio.blocking.Blocking import zio.clock.Clock import zio.interop.catz._ @@ -53,7 +53,7 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { ) new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit) ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override 
def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => zio.stream.Stream.empty diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala index 4ebbaf4ed9..5164a29f66 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala @@ -14,10 +14,11 @@ import java.io._ import java.nio.ByteBuffer import java.nio.file.{Files, StandardCopyOption} -private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile, multipartFileThresholdBytes: Long) extends RequestBody[Id, NoStreams] { +private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile, multipartFileThresholdBytes: Long) + extends RequestBody[Id, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = { val request = jdkHttpRequest(serverRequest) toRaw(serverRequest, bodyType, request.getRequestBody) } @@ -76,9 +77,10 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile ) } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException( - "Streaming is not supported" - ) + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException( + "Streaming is not supported" + ) private def jdkHttpRequest(serverRequest: ServerRequest): HttpExchange = serverRequest.underlying.asInstanceOf[HttpExchange] diff 
--git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index 17e2ed0a2a..90d8a3a81a 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -10,33 +10,48 @@ import io.netty.buffer.ByteBufUtil import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} +import sttp.tapir.server.interpreter.{RawValue, RequestBody, RequestBodyToRawException} import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import java.io.ByteArrayInputStream import java.nio.ByteBuffer +import sttp.tapir.DecodeResult private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) extends RequestBody[F, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { + def nettyRequestBytes: F[Array[Byte]] = serverRequest.underlying match { + case req: FullHttpRequest => + val buf = req.content() + maxBytes + .map(max => + if (buf.readableBytes() > max) + monad.raiseError[Array[Byte]](RequestBodyToRawException(DecodeResult.BodyTooLarge(max))) + else + monad.delay(ByteBufUtil.getBytes(buf)) + ) + .getOrElse(monad.delay(ByteBufUtil.getBytes(buf))) + case _: StreamedHttpRequest => toStream(serverRequest, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) + case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request 
of type ${other.getClass().getName()}")) + } bodyType match { - case RawBodyType.StringBody(charset) => nettyRequestBytes(serverRequest).map(bs => RawValue(new String(bs, charset))) + case RawBodyType.StringBody(charset) => nettyRequestBytes.map(bs => RawValue(new String(bs, charset))) case RawBodyType.ByteArrayBody => - nettyRequestBytes(serverRequest).map(RawValue(_)) + nettyRequestBytes.map(RawValue(_)) case RawBodyType.ByteBufferBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(ByteBuffer.wrap(bs))) + nettyRequestBytes.map(bs => RawValue(ByteBuffer.wrap(bs))) case RawBodyType.InputStreamBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(new ByteArrayInputStream(bs))) + nettyRequestBytes.map(bs => RawValue(new ByteArrayInputStream(bs))) case RawBodyType.InputStreamRangeBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) + nettyRequestBytes.map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) case RawBodyType.FileBody => createFile(serverRequest) .flatMap(tapirFile => { - toStream(serverRequest) + toStream(serverRequest, maxBytes) .through( Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(tapirFile.toPath)) ) @@ -48,17 +63,12 @@ private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[T } } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] - fs2.Stream + val stream = fs2.Stream .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.DefaultChunkSize)) .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) - } - - private def nettyRequestBytes(serverRequest: ServerRequest): F[Array[Byte]] = serverRequest.underlying 
match { - case req: FullHttpRequest => monad.delay(ByteBufUtil.getBytes(req.content())) - case _: StreamedHttpRequest => toStream(serverRequest).compile.to(Chunk).map(_.toArray[Byte]) - case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) + maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) } } diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index cde65b3d38..dd9c9f865e 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -27,6 +27,8 @@ class NettyCatsServerTest extends TestSuite with EitherValues { val ioSleeper: Sleeper[IO] = new Sleeper[IO] { override def sleep(duration: FiniteDuration): IO[Unit] = IO.sleep(duration) } + def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = + stream.compile.drain.void val tests = new AllServerTests( createServerTest, @@ -36,7 +38,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) ) .tests() ++ - new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 6c9d251676..6f36268e35 100644 --- 
a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -8,10 +8,12 @@ import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import sttp.tapir.model.ServerRequest import sttp.monad.syntax._ import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.{RawValue, RequestBody} +import sttp.tapir.server.interpreter.{RawValue, RequestBody, RequestBodyToRawException} import java.nio.ByteBuffer import java.nio.file.Files +import io.netty.buffer.ByteBuf +import sttp.tapir.DecodeResult class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit monadError: MonadError[F] @@ -19,29 +21,44 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): F[RawValue[RAW]] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = { + + def byteBuf: F[ByteBuf] = { + val buf = nettyRequest(serverRequest).content() + maxBytes + .map(max => + if (buf.readableBytes() > max) + monadError.error[ByteBuf](RequestBodyToRawException(DecodeResult.BodyTooLarge(max))) + else + monadError.unit(buf) + ) + .getOrElse(monadError.unit(buf)) + } /** [[ByteBufUtil.getBytes(io.netty.buffer.ByteBuf)]] copies buffer without affecting reader index of the original. 
*/ - def requestContentAsByteArray = ByteBufUtil.getBytes(nettyRequest(serverRequest).content()) + def requestContentAsByteArray: F[Array[Byte]] = byteBuf.map(ByteBufUtil.getBytes) bodyType match { - case RawBodyType.StringBody(charset) => monadError.unit(RawValue(nettyRequest(serverRequest).content().toString(charset))) - case RawBodyType.ByteArrayBody => monadError.unit(RawValue(requestContentAsByteArray)) - case RawBodyType.ByteBufferBody => monadError.unit(RawValue(ByteBuffer.wrap(requestContentAsByteArray))) - case RawBodyType.InputStreamBody => monadError.unit(RawValue(new ByteBufInputStream(nettyRequest(serverRequest).content()))) + case RawBodyType.StringBody(charset) => byteBuf.map(buf => RawValue(buf.toString(charset))) + case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) + case RawBodyType.ByteBufferBody => requestContentAsByteArray.map(ba => RawValue(ByteBuffer.wrap(ba))) + case RawBodyType.InputStreamBody => byteBuf.map(buf => RawValue(new ByteBufInputStream(buf))) case RawBodyType.InputStreamRangeBody => - monadError.unit(RawValue(InputStreamRange(() => new ByteBufInputStream(nettyRequest(serverRequest).content())))) + byteBuf.map(buf => RawValue(InputStreamRange(() => new ByteBufInputStream(buf)))) case RawBodyType.FileBody => - createFile(serverRequest) - .map(file => { - Files.write(file.toPath, requestContentAsByteArray) - RawValue(FileRange(file), Seq(FileRange(file))) - }) + requestContentAsByteArray.flatMap(ba => + createFile(serverRequest) + .map(file => { + Files.write(file.toPath, ba) + RawValue(FileRange(file), Seq(FileRange(file))) + }) + ) case _: RawBodyType.MultipartBody => ??? 
} } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException() + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] } diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index b7d86c5e4a..d1485b96c5 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -21,7 +21,7 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val interpreter = new NettyFutureTestServerInterpreter(eventLoopGroup) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ + val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = Some(300)).tests() ++ new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala index eff6650e08..7ca54ce22e 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala @@ -6,7 +6,7 @@ import io.netty.handler.codec.http.FullHttpRequest import sttp.capabilities.zio.ZioStreams import sttp.tapir.RawBodyType._ import 
sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} +import sttp.tapir.server.interpreter.{RawValue, RequestBody, RequestBodyToRawException} import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import zio.interop.reactivestreams._ import zio.stream.{ZStream, _} @@ -14,28 +14,45 @@ import zio.{Chunk, RIO, ZIO} import java.io.ByteArrayInputStream import java.nio.ByteBuffer +import sttp.tapir.DecodeResult private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[Env, TapirFile]) extends RequestBody[RIO[Env, *], ZioStreams] { override val streams: ZioStreams = ZioStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): RIO[Env, RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): RIO[Env, RawValue[R]] = { + + def nettyRequestBytes: RIO[Env, Array[Byte]] = serverRequest.underlying match { + case req: FullHttpRequest => + val buf = req.content() + maxBytes + .map(max => + if (buf.readableBytes() > max) + ZIO.fail(RequestBodyToRawException(DecodeResult.BodyTooLarge(max))) + else + ZIO.succeed(ByteBufUtil.getBytes(buf)) + ) + .getOrElse(ZIO.succeed(ByteBufUtil.getBytes(buf))) + + case _: StreamedHttpRequest => toStream(serverRequest, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) + case other => ZIO.fail(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) + } bodyType match { - case StringBody(charset) => nettyRequestBytes(serverRequest).map(bs => RawValue(new String(bs, charset))) + case StringBody(charset) => nettyRequestBytes.map(bs => RawValue(new String(bs, charset))) case ByteArrayBody => - nettyRequestBytes(serverRequest).map(RawValue(_)) + nettyRequestBytes.map(RawValue(_)) case ByteBufferBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(ByteBuffer.wrap(bs))) + nettyRequestBytes.map(bs => RawValue(ByteBuffer.wrap(bs))) 
case InputStreamBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(new ByteArrayInputStream(bs))) + nettyRequestBytes.map(bs => RawValue(new ByteArrayInputStream(bs))) case InputStreamRangeBody => - nettyRequestBytes(serverRequest).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) + nettyRequestBytes.map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) case FileBody => - createFile(serverRequest) + createFile(serverRequest) .flatMap(tapirFile => { - toStream(serverRequest) + toStream(serverRequest, maxBytes) .run(ZSink.fromFile(tapirFile)) .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) }) @@ -44,18 +61,11 @@ private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[E } } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { - - serverRequest.underlying + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { + val stream = serverRequest.underlying .asInstanceOf[StreamedHttpRequest] .toZIOStream() .flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) } - - private def nettyRequestBytes(serverRequest: ServerRequest): RIO[Env, Array[Byte]] = serverRequest.underlying match { - case req: FullHttpRequest => ZIO.succeed(ByteBufUtil.getBytes(req.content())) - case _: StreamedHttpRequest => toStream(serverRequest).run(ZSink.collectAll[Byte]).map(_.toArray) - case other => ZIO.fail(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) - } - } diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index 1486c0bb6c..fc29fd4960 100644 --- 
a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -14,8 +14,12 @@ import zio.{Task, ZIO} import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration +import zio.stream.ZSink class NettyZioServerTest extends TestSuite with EitherValues { + def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = + zStream.run(ZSink.drain) + override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => Resource @@ -30,8 +34,15 @@ class NettyZioServerTest extends TestSuite with EitherValues { } val tests = - new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new AllServerTests( + createServerTest, + interpreter, + backend, + staticContent = false, + multipart = false, + maxContentLength = Some(300) + ).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() ++ new ServerGracefulShutdownTests(createServerTest, zioSleeper).tests() diff --git a/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala b/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala index 00e39b702c..a9d31f11c2 100644 --- a/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala +++ b/server/nima-server/src/main/scala/sttp/tapir/server/nima/internal/NimaRequestBody.scala @@ -14,7 +14,7 @@ import java.nio.file.{Files, StandardCopyOption} private[nima] class NimaRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[RAW](serverRequest: 
ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = { def asInputStream = nimaRequest(serverRequest).content().inputStream() def asByteArray = asInputStream.readAllBytes() @@ -32,7 +32,8 @@ private[nima] class NimaRequestBody(createFile: ServerRequest => TapirFile) exte } } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException() + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() private def nimaRequest(serverRequest: ServerRequest): JavaNimaServerRequest = serverRequest.underlying.asInstanceOf[JavaNimaServerRequest] diff --git a/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala b/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala index 3eccd49401..206aa429f7 100644 --- a/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala +++ b/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala @@ -25,7 +25,7 @@ private[pekkogrpc] class PekkoGrpcRequestBody(serverOptions: PekkoHttpServerOpti override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = toRawFromEntity(request, akkaRequestEntity(request), bodyType) - override def toStream(request: ServerRequest): streams.BinaryStream = ??? + override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
private def akkaRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity diff --git a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala index 8a341dd9b8..2f37d6e28c 100644 --- a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala +++ b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala @@ -21,9 +21,12 @@ private[pekkohttp] class PekkoRequestBody(serverOptions: PekkoHttpServerOptions) ec: ExecutionContext ) extends RequestBody[Future, PekkoStreams] { override val streams: PekkoStreams = PekkoStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkeRequestEntity(request), bodyType) - override def toStream(request: ServerRequest): streams.BinaryStream = akkeRequestEntity(request).dataBytes + override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { + val stream = akkeRequestEntity(request).dataBytes + maxBytes.map(PekkoStreams.limitBytes(stream, _)).getOrElse(stream) + } private def akkeRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index a442c74eee..ebfe952eb6 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -100,8 +100,11 @@ class PekkoHttpServerTest extends TestSuite 
with EitherValues { .unsafeToFuture() } ) + def drainPekko(stream: PekkoStreams.BinaryStream): Future[Unit] = + stream.runWith(Sink.ignore).map(_ => ()) + new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, PekkoStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(PekkoStreams)(drainPekko) ++ new ServerWebSocketTests(createServerTest, PekkoStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala b/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala index 792730c460..e3feeaaf05 100644 --- a/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala +++ b/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala @@ -23,14 +23,17 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit override val streams: PekkoStreams = PekkoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = { import mat.executionContext val request = playRequest(serverRequest) val charset = request.charset.map(Charset.forName) toRaw(request, bodyType, charset, () => request.body, None) } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = playRequest(serverRequest).body + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { + val stream = playRequest(serverRequest).body + maxBytes.map(PekkoStreams.limitBytes(stream, _)).getOrElse(stream) + } private def toRaw[R]( request: Request[PekkoStreams.BinaryStream], diff 
--git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 3f14b5b49e..5409a4c9bc 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -104,6 +104,9 @@ class PlayServerTest extends TestSuite { } ) + def drainPekko(stream: PekkoStreams.BinaryStream): Future[Unit] = + stream.runWith(Sink.ignore).map(_ => ()) + new ServerBasicTests( createServerTest, interpreter, @@ -113,7 +116,7 @@ class PlayServerTest extends TestSuite { ).tests() ++ new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++ new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++ - new ServerStreamingTests(createServerTest, PekkoStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(PekkoStreams)(drainPekko) ++ new PlayServerWithContextTest(backend).tests() ++ new ServerWebSocketTests(createServerTest, PekkoStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) diff --git a/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala b/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala index 6517603d96..c8adc730d1 100644 --- a/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala +++ b/server/sttp-stub-server/src/main/scala/sttp/tapir/server/stub/SttpRequestBody.scala @@ -14,7 +14,7 @@ import scala.annotation.tailrec class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, AnyStreams] { override val streams: AnyStreams = AnyStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = + override def toRaw[R](serverRequest: ServerRequest, 
bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = body(serverRequest) match { case Left(bytes) => bodyType match { @@ -29,7 +29,7 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A case _ => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type") } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = body(serverRequest) match { + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = body(serverRequest) match { case Right(stream) => stream case _ => throw new IllegalArgumentException("Raw body provided while endpoint accepts stream body") } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index ecda573c9c..38330e98a1 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -23,6 +23,7 @@ import sttp.tapir.tests.data.{FruitAmount, FruitError} import java.io.{ByteArrayInputStream, InputStream} import java.nio.ByteBuffer +import sttp.tapir.server.interpreter.MaxContentLength class ServerBasicTests[F[_], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], @@ -745,9 +746,25 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( ) def maxContentLengthTests(): List[Test] = List( - testServer(in_string_out_string, "returns 413 on exceeded max content length")(_ => - pureResult(List.fill(maxContentLength.getOrElse(0) + 1)('x').mkString.asRight[Unit]) - ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("irrelevant").send(backend).map(_.code.code shouldBe 413) } + { + val maxContentLength = 300 + testServer( + in_string_out_string.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxContentLength.toLong)), + "returns 413 on exceeded max 
content length (request)" + )(_ => pureResult("ok".asRight[Unit])) { (backend, baseUri) => + val tooLargeBody: String = List.fill(maxContentLength + 1)('x').mkString + basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).send(backend).map(_.code shouldBe StatusCode.PayloadTooLarge) + } + }, { + val maxContentLength = 300 + testServer( + in_string_out_string.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxContentLength.toLong)), + "returns OK on content length below or equal max (request)" + )(_ => pureResult("ok".asRight[Unit])) { (backend, baseUri) => + val fineBody: String = List.fill(maxContentLength)('x').mkString + basicRequest.post(uri"$baseUri/api/echo").body(fineBody).send(backend).map(_.code shouldBe StatusCode.Ok) + } + } ) def exceptionTests(): List[Test] = List( diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 13ee47524a..1d0043188e 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -4,27 +4,29 @@ import cats.syntax.all._ import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.Streams import sttp.client3._ -import sttp.model.{Header, HeaderNames, MediaType} +import sttp.model.{Header, MediaType, StatusCode} import sttp.monad.MonadError +import sttp.monad.syntax._ import sttp.tapir.tests.Test -import sttp.tapir.tests.Streaming.{ - in_stream_out_either_json_xml_stream, - in_stream_out_stream, - in_stream_out_stream_with_content_length, - in_string_stream_out_either_stream_string, - out_custom_content_type_stream_body -} +import sttp.tapir.tests.Streaming._ +import sttp.tapir.server.interpreter.MaxContentLength +import sttp.tapir.AttributeKey +import cats.effect.IO +import sttp.capabilities.fs2.Fs2Streams -class ServerStreamingTests[F[_], S, OPTIONS, 
ROUTE](createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], streams: Streams[S])(implicit +class ServerStreamingTests[F[_], S, OPTIONS, ROUTE]( + createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], + maxLengthSupported: Boolean +)(implicit m: MonadError[F] ) { - def tests(): List[Test] = { + def tests(streams: Streams[_ >: S])(drain: streams.BinaryStream => F[Unit]): List[Test] = { import createServerTest._ val penPineapple = "pen pineapple apple pen" - List( + val baseTests = List( testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body(penPineapple).send(backend).map(_.body shouldBe Right(penPineapple)) }, @@ -99,5 +101,41 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ } } ) + + val maxContentLengthTests = List( + { + val inputByteCount = 1024 + val maxBytes = 1024L + val inputStream = fs2.Stream.fromIterator[IO](Iterator.fill[Byte](inputByteCount)('5'.toByte), chunkSize = 256) + testServer( + in_stream_out_stream(streams).attribute(AttributeKey[MaxContentLength], MaxContentLength(maxBytes)), + "with request content length == max" + )((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => + basicRequest + .post(uri"$baseUri/api/echo") + .streamBody(Fs2Streams[IO])(inputStream) + .send(backend) + .map(resp => assert(resp.isSuccess, "Response 200 OK")) + } + }, { + val inputByteCount = 1024 + val maxBytes = 1023L + val inputStream = fs2.Stream.fromIterator[IO](Iterator.fill[Byte](inputByteCount)('5'.toByte), chunkSize = 256) + testServer( + in_stream_out_string(streams).attribute(AttributeKey[MaxContentLength], MaxContentLength(maxBytes)), + "with request content length > max" + )((s: streams.BinaryStream) => drain(s).flatMap(_ => pureResult("ok".asRight[Unit]))) { (backend, baseUri) => + basicRequest + .post(uri"$baseUri/api/echo") + .streamBody(Fs2Streams[IO])(inputStream) + 
.send(backend) + .map(_.code shouldBe (StatusCode.PayloadTooLarge)) + } + } + ) + + if (maxLengthSupported) + baseTests ++ maxContentLengthTests + else baseTests } } diff --git a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala index 1af0c770f2..df62f8e3d8 100644 --- a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala +++ b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala @@ -107,8 +107,10 @@ object fs2 { } } - override def fromReadStream(readStream: ReadStream[Buffer]): Stream[F, Byte] = - fromReadStreamInternal(readStream).map(buffer => Chunk.array(buffer.getBytes)).unchunks + override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): Stream[F, Byte] = { + val stream = fromReadStreamInternal(readStream).map(buffer => Chunk.array(buffer.getBytes)).unchunks + maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) + } private def fromReadStreamInternal[T](readStream: ReadStream[T]): Stream[F, T] = opts.dispatcher.unsafeRunSync { diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index 36c4618091..d3b532bea2 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -14,6 +14,9 @@ class CatsVertxServerTest extends TestSuite { def vertxResource: Resource[IO, Vertx] = Resource.make(IO.delay(Vertx.vertx()))(vertx => new CatsFFromVFuture[IO]().apply(vertx.close).void) + def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = + stream.compile.drain.void + override def tests: Resource[IO, List[Test]] = backendResource.flatMap { 
backend => vertxResource.map { implicit vertx => implicit val m: MonadError[IO] = VertxCatsServerInterpreter.monadError[IO] @@ -26,7 +29,7 @@ class CatsVertxServerTest extends TestSuite { partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong partOtherHeaderSupport = false ).tests() ++ - new ServerStreamingTests(createServerTest, Fs2Streams.apply[IO]).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams.apply[IO])(drainFs2) ++ new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO]) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/streams/Fs2StreamTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/streams/Fs2StreamTest.scala index db592dceb9..baadbe13f4 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/streams/Fs2StreamTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/streams/Fs2StreamTest.scala @@ -148,7 +148,7 @@ class Fs2StreamTest extends AsyncFlatSpec with Matchers with BeforeAndAfterAll { val opts = options.copy(maxQueueSizeForReadStream = 128) val count = 100 val readStream = new FakeStream() - val stream = streams.fs2.fs2ReadStreamCompatible[IO](opts)(implicitly).fromReadStream(readStream) + val stream = streams.fs2.fs2ReadStreamCompatible[IO](opts)(implicitly).fromReadStream(readStream, None) (for { resultFiber <- stream .chunkN(4) @@ -174,7 +174,7 @@ class Fs2StreamTest extends AsyncFlatSpec with Matchers with BeforeAndAfterAll { it should "drain read stream with small buffer" in { val count = 100 val readStream = new FakeStream() - val stream = streams.fs2.fs2ReadStreamCompatible[IO](options).fromReadStream(readStream) + val stream = 
streams.fs2.fs2ReadStreamCompatible[IO](options).fromReadStream(readStream, None) (for { resultFiber <- stream .chunkN(4) @@ -205,7 +205,7 @@ class Fs2StreamTest extends AsyncFlatSpec with Matchers with BeforeAndAfterAll { val ex = new Exception("!") val count = 50 val readStream = new FakeStream() - val stream = streams.fs2.fs2ReadStreamCompatible[IO](options).fromReadStream(readStream) + val stream = streams.fs2.fs2ReadStreamCompatible[IO](options).fromReadStream(readStream, None) (for { resultFiber <- stream .chunkN(4) diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala index 345e5d54fd..a2f10720e7 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala @@ -26,7 +26,7 @@ class VertxRequestBody[F[_], S <: Streams[S]]( extends RequestBody[F, S] { override val streams: Streams[S] = readStreamCompatible.streams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val rc = routingContext(serverRequest) fromVFuture(bodyType match { case RawBodyType.StringBody(defaultCharset) => @@ -95,9 +95,9 @@ class VertxRequestBody[F[_], S <: Streams[S]]( }) } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = readStreamCompatible - .fromReadStream(routingContext(serverRequest).request) + .fromReadStream(routingContext(serverRequest).request, maxBytes) .asInstanceOf[streams.BinaryStream] private def extractStringPart[B](part: String, bodyType: RawBodyType[B]): Option[Any] = { diff --git 
a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/ReadStreamCompatible.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/ReadStreamCompatible.scala index 3ab08405e7..0b38818e0a 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/ReadStreamCompatible.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/ReadStreamCompatible.scala @@ -9,7 +9,7 @@ import sttp.ws.WebSocketFrame trait ReadStreamCompatible[S <: Streams[S]] { val streams: S def asReadStream(s: streams.BinaryStream): ReadStream[Buffer] - def fromReadStream(s: ReadStream[Buffer]): streams.BinaryStream + def fromReadStream(s: ReadStream[Buffer], maxBytes: Option[Long]): streams.BinaryStream def webSocketPipe[REQ, RESP]( readStream: ReadStream[WebSocketFrame], diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala index 6935219653..e9197ac7b8 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala @@ -18,7 +18,7 @@ package object streams { override def asReadStream(readStream: ReadStream[Buffer]): ReadStream[Buffer] = readStream - override def fromReadStream(readStream: ReadStream[Buffer]): ReadStream[Buffer] = + override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): ReadStream[Buffer] = // TODO support maxBytes readStream override def webSocketPipe[REQ, RESP]( diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index b9369dc2bf..dbb2e5747e 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala 
@@ -9,6 +9,7 @@ import sttp.tapir.server.vertx.streams.VertxStreams import sttp.tapir.tests.{Test, TestSuite} import scala.concurrent.ExecutionContext +import scala.concurrent.Future class VertxServerTest extends TestSuite { def vertxResource: Resource[IO, Vertx] = @@ -26,7 +27,7 @@ class VertxServerTest extends TestSuite { createServerTest, partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong partOtherHeaderSupport = false - ).tests() ++ new ServerStreamingTests(createServerTest, VertxStreams).tests() ++ + ).tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(VertxStreams)(_ => Future.unit) ++ (new ServerWebSocketTests(createServerTest, VertxStreams) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() diff --git a/server/vertx-server/zio/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala b/server/vertx-server/zio/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala index 4675fae530..c25f55a62a 100644 --- a/server/vertx-server/zio/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala +++ b/server/vertx-server/zio/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala @@ -103,8 +103,10 @@ package object streams { }).toEither .fold(throw _, identity) - override def fromReadStream(readStream: ReadStream[Buffer]): Stream[Throwable, Byte] = - fromReadStreamInternal(readStream).mapConcatChunk(buffer => Chunk.fromArray(buffer.getBytes)) + override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): Stream[Throwable, Byte] = { + val stream = fromReadStreamInternal(readStream).mapConcatChunk(buffer => Chunk.fromArray(buffer.getBytes)) + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) + } private def fromReadStreamInternal[T](readStream: ReadStream[T]): Stream[Throwable, T] = unsafeRunSync(for 
{ diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index 908e762c26..ea448196fd 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -1,18 +1,19 @@ package sttp.tapir.server.vertx.zio +import _root_.zio.stream.ZStream +import _root_.zio.{Task, ZIO} import cats.effect.{IO, Resource} import io.vertx.core.Vertx +import org.scalatest.OptionValues +import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.zio.ZioStreams +import sttp.client3.basicRequest import sttp.monad.MonadError +import sttp.tapir._ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} -import sttp.tapir._ -import _root_.zio.{Task, ZIO} -import _root_.zio.stream.ZStream -import org.scalatest.OptionValues -import sttp.client3.basicRequest import sttp.tapir.ztapir.RIOMonadError -import org.scalatest.matchers.should.Matchers._ +import zio.stream.ZSink class ZioVertxServerTest extends TestSuite with OptionValues { def vertxResource: Resource[IO, Vertx] = @@ -32,6 +33,8 @@ class ZioVertxServerTest extends TestSuite with OptionValues { basicRequest.get(baseUri).send(backend).map(_.body.toOption.value should not include "vert.x-eventloop-thread") } ) + def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = + zStream.run(ZSink.drain) new AllServerTests(createServerTest, interpreter, backend, multipart = false, reject = false, options = false).tests() ++ new ServerMultipartTests( @@ -39,7 +42,7 @@ class ZioVertxServerTest extends TestSuite with OptionValues { partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong partOtherHeaderSupport = false ).tests() ++ additionalTests() ++ - new ServerStreamingTests(createServerTest, 
ZioStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala index d93ac3cce3..567ddf335b 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala @@ -124,7 +124,7 @@ class ZStreamTest extends AsyncFlatSpec with Matchers { val opts = options.copy(maxQueueSizeForReadStream = 128) val count = 100 val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream) + val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, None) unsafeToFuture(for { resultFiber <- ZIO .scoped( @@ -154,7 +154,7 @@ class ZStreamTest extends AsyncFlatSpec with Matchers { val opts = options.copy(maxQueueSizeForReadStream = 4) val count = 100 val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream) + val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, None) unsafeToFuture(for { resultFiber <- ZIO .scoped( @@ -188,7 +188,7 @@ class ZStreamTest extends AsyncFlatSpec with Matchers { val opts = options.copy(maxQueueSizeForReadStream = 4) val count = 50 val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream) + val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, None) unsafeToFuture(for { resultFiber <- ZIO .scoped( diff --git 
a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala b/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala index dc83380ead..87e41466d8 100644 --- a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala +++ b/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala @@ -119,8 +119,9 @@ package object streams { .toEither .fold(throw _, identity) - override def fromReadStream(readStream: ReadStream[Buffer]): Stream[Throwable, Byte] = + override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): Stream[Throwable, Byte] = { fromReadStreamInternal(readStream).mapConcatChunk(buffer => Chunk.fromArray(buffer.getBytes)) + } private def fromReadStreamInternal[T](readStream: ReadStream[T]): Stream[Throwable, T] = runtime diff --git a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index 1d511b8332..dbe692745e 100644 --- a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -10,6 +10,7 @@ import _root_.zio.RIO import _root_.zio.blocking.Blocking import sttp.tapir.ztapir.RIOMonadError import zio.stream.ZStream +import zio.Task class ZioVertxServerTest extends TestSuite { def vertxResource: Resource[IO, Vertx] = @@ -27,7 +28,7 @@ class ZioVertxServerTest extends TestSuite { partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong partOtherHeaderSupport = false ).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit) ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def 
functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala index 181ce98b65..fafa0daffc 100644 --- a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala +++ b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala @@ -131,7 +131,7 @@ class ZStreamTest extends AsyncFlatSpec with Matchers { val opts = options.copy(maxQueueSizeForReadStream = 128) val count = 100 val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream) + val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, maxBytes = None) runtime .unsafeRunToFuture(for { resultFiber <- stream @@ -160,7 +160,7 @@ class ZStreamTest extends AsyncFlatSpec with Matchers { val opts = options.copy(maxQueueSizeForReadStream = 4) val count = 100 val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream) + val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, maxBytes = None) runtime .unsafeRunToFuture(for { resultFiber <- stream @@ -193,7 +193,7 @@ class ZStreamTest extends AsyncFlatSpec with Matchers { val opts = options.copy(maxQueueSizeForReadStream = 4) val count = 50 val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream) + val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, maxBytes = None) runtime .unsafeRunToFuture(for { resultFiber <- stream diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala 
b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index 5e324cfbfb..1a14fcd57b 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -2,22 +2,21 @@ package sttp.tapir.server.ziohttp import sttp.capabilities import sttp.capabilities.zio.ZioStreams -import sttp.tapir.{FileRange, InputStreamRange} -import sttp.tapir.RawBodyType import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.RawValue -import sttp.tapir.server.interpreter.RequestBody +import sttp.tapir.server.interpreter.{BodyMaxLengthExceededException, RawValue, RequestBody} +import sttp.tapir.{FileRange, InputStreamRange, RawBodyType} import zio.http.Request -import zio.{RIO, Task, ZIO} import zio.stream.Stream +import zio.{RIO, Task, ZIO} import java.io.ByteArrayInputStream import java.nio.ByteBuffer +import sttp.capabilities.StreamMaxLengthExceededException class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): Task[RawValue[RAW]] = bodyType match { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = bodyType match { case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) @@ -29,12 +28,30 @@ class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends Requ case RawBodyType.MultipartBody(_, _) => ZIO.fail(new 
UnsupportedOperationException("Multipart is not supported")) } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = stream(serverRequest).asInstanceOf[streams.BinaryStream] + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + toZioStream(serverRequest, maxBytes).asInstanceOf[streams.BinaryStream] + + private def toZioStream(serverRequest: ServerRequest, maxBytes: Option[Long]): Stream[Throwable, Byte] = { + val inputStream = stream(serverRequest) + maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream) + } private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] = zioHttpRequest(serverRequest).body.asStream - private def asByteArray(serverRequest: ServerRequest): Task[Array[Byte]] = zioHttpRequest(serverRequest).body.asArray + private def asByteArray(serverRequest: ServerRequest): Task[Array[Byte]] = { + val maxBytes: Option[Long] = Some(30L) // TODO + val body = zioHttpRequest(serverRequest).body + println("Checking against maxBytes") + if (body.isComplete) { + maxBytes.map(limit => body.asArray.filterOrFail(_.length <= limit)(new BodyMaxLengthExceededException(limit))).getOrElse(body.asArray) + } else + toZioStream(serverRequest, maxBytes).runCollect + .catchSomeDefect { case e: StreamMaxLengthExceededException => + ZIO.fail(e) + } + .map(_.toArray) + } private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 66e87d55bb..a4df8a8c05 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -50,6 +50,7 @@ import java.nio.charset.Charset import java.time import 
scala.concurrent.Future import scala.concurrent.duration.DurationInt +import zio.stream.ZSink class ZioHttpServerTest extends TestSuite { @@ -242,13 +243,17 @@ class ZioHttpServerTest extends TestSuite { implicit val m: MonadError[Task] = new RIOMonadError[Any] + def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = + zStream.run(ZSink.drain) + new ServerBasicTests( createServerTest, interpreter, multipleValueHeaderSupport = false, supportsUrlEncodedPathSegments = false, supportsMultipleSetCookieHeaders = false, - invulnerableToUnsanitizedHeaders = false + invulnerableToUnsanitizedHeaders = false, + maxContentLength = Some(300) ).tests() ++ // TODO: re-enable static content once a newer zio http is available. Currently these tests often fail with: // Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [f2 content], parser state [STATUS_LINE] @@ -262,7 +267,7 @@ class ZioHttpServerTest extends TestSuite { file = false, options = false ).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ new ZioHttpCompositionTest(createServerTest).tests() ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) diff --git a/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index 1234bab7e8..675cafc4e2 100644 --- a/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -18,25 +18,30 @@ import java.nio.ByteBuffer class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] 
= ZioStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): Task[RawValue[RAW]] = bodyType match { - case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) - case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) - case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) - case RawBodyType.InputStreamBody => asByteArray(serverRequest).map(new ByteArrayInputStream(_)).map(RawValue(_)) - case RawBodyType.InputStreamRangeBody => - asByteArray(serverRequest).map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) - case RawBodyType.FileBody => - for { - tmpFile <- serverOptions.createFile(serverRequest) - _ <- toStream(serverRequest).asInstanceOf[Stream[Throwable, Byte]].run(ZSink.fromFile(tmpFile.toPath)).provideLayer(Blocking.live) - } yield { - val fileRange = FileRange(tmpFile) - RawValue(fileRange, Seq(fileRange)) - } - case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) - } - - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = stream(serverRequest).asInstanceOf[streams.BinaryStream] + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = + bodyType match { + case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) + case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) + case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) + case RawBodyType.InputStreamBody => asByteArray(serverRequest).map(new ByteArrayInputStream(_)).map(RawValue(_)) + case RawBodyType.InputStreamRangeBody => + asByteArray(serverRequest).map(bytes => new 
InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) + case RawBodyType.FileBody => + for { + tmpFile <- serverOptions.createFile(serverRequest) + _ <- toStream(serverRequest, None) + .asInstanceOf[Stream[Throwable, Byte]] + .run(ZSink.fromFile(tmpFile.toPath)) + .provideLayer(Blocking.live) + } yield { + val fileRange = FileRange(tmpFile) + RawValue(fileRange, Seq(fileRange)) + } + case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + stream(serverRequest).asInstanceOf[streams.BinaryStream] private def asByteArray(serverRequest: ServerRequest): Task[Array[Byte]] = zioHttpRequest(serverRequest).body.map(_.toArray) diff --git a/server/zio1-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio1-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 65e0236634..53224d17dd 100644 --- a/server/zio1-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio1-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -54,7 +54,7 @@ class ZioHttpServerTest extends TestSuite { // Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [f2 content], parser state [STATUS_LINE] new AllServerTests(createServerTest, interpreter, backend, basic = false, staticContent = false, multipart = false, file = true) .tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit) ++ new ZioHttpCompositionTest(createServerTest).tests() // ++ // TODO: only works with zio2 // additionalTests() diff --git a/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala 
b/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala index c8675ed552..077229f5ce 100644 --- a/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala +++ b/serverless/aws/lambda-core/src/main/scala/sttp/tapir/serverless/aws/lambda/AwsRequestBody.scala @@ -15,7 +15,7 @@ import java.util.Base64 private[lambda] class AwsRequestBody[F[_]: MonadError]() extends RequestBody[F, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val request = awsRequest(serverRequest) val decoded = if (request.isBase64Encoded) Left(Base64.getDecoder.decode(request.body.getOrElse(""))) else Right(request.body.getOrElse("")) @@ -33,7 +33,8 @@ private[lambda] class AwsRequestBody[F[_]: MonadError]() extends RequestBody[F, }).asInstanceOf[RawValue[R]].unit } - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException private def awsRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[AwsRequest] } diff --git a/tests/src/main/scala/sttp/tapir/tests/Streaming.scala b/tests/src/main/scala/sttp/tapir/tests/Streaming.scala index 91b5a59c18..d6013d0a35 100644 --- a/tests/src/main/scala/sttp/tapir/tests/Streaming.scala +++ b/tests/src/main/scala/sttp/tapir/tests/Streaming.scala @@ -12,6 +12,10 @@ object Streaming { endpoint.post.in("api" / "echo").in(sb).out(sb) } + def in_stream_out_string[S](s: Streams[S]): PublicEndpoint[s.BinaryStream, Unit, String, S] = { + endpoint.post.in("api" / "echo").in(streamTextBody(s)(CodecFormat.TextPlain(), 
Some(StandardCharsets.UTF_8))).out(stringBody) + } + def in_stream_out_stream_with_content_length[S]( s: Streams[S] ): PublicEndpoint[(Long, s.BinaryStream), Unit, (Long, s.BinaryStream), S] = { From 0f4b410a58ea6e81972503acffc49330b704dfca Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 23 Nov 2023 11:10:15 +0100 Subject: [PATCH 02/35] Revert changes to ZIO servers --- .../server/ziohttp/ZioHttpRequestBody.scala | 28 ++++---------- .../server/ziohttp/ZioHttpServerTest.scala | 3 +- .../server/ziohttp/ZioHttpRequestBody.scala | 38 +++++++++---------- 3 files changed, 26 insertions(+), 43 deletions(-) diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index 67099eba19..0631d4a0b9 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -2,16 +2,17 @@ package sttp.tapir.server.ziohttp import sttp.capabilities import sttp.capabilities.zio.ZioStreams +import sttp.tapir.{FileRange, InputStreamRange} +import sttp.tapir.RawBodyType import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{BodyMaxLengthExceededException, RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType} +import sttp.tapir.server.interpreter.RawValue +import sttp.tapir.server.interpreter.RequestBody import zio.http.Request -import zio.stream.Stream import zio.{RIO, Task, ZIO} +import zio.stream.Stream import java.io.ByteArrayInputStream import java.nio.ByteBuffer -import sttp.capabilities.StreamMaxLengthExceededException class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams @@ -28,28 +29,15 @@ class ZioHttpRequestBody[R](serverOptions: 
ZioHttpServerOptions[R]) extends Requ case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) } - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = - toZioStream(serverRequest, maxBytes).asInstanceOf[streams.BinaryStream] - - private def toZioStream(serverRequest: ServerRequest, maxBytes: Option[Long]): Stream[Throwable, Byte] = { + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { val inputStream = stream(serverRequest) - maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream) + maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream).asInstanceOf[streams.BinaryStream] } private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] = zioHttpRequest(serverRequest).body.asStream - private def asByteArray(serverRequest: ServerRequest): Task[Array[Byte]] = { - val body = zioHttpRequest(serverRequest).body - if (body.isComplete) { - maxBytes.map(limit => body.asArray.filterOrFail(_.length <= limit)(new BodyMaxLengthExceededException(limit))).getOrElse(body.asArray) - } else - toZioStream(serverRequest, maxBytes).runCollect - .catchSomeDefect { case e: StreamMaxLengthExceededException => - ZIO.fail(e) - } - .map(_.toArray) - } + private def asByteArray(serverRequest: ServerRequest): Task[Array[Byte]] = zioHttpRequest(serverRequest).body.asArray private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index a4df8a8c05..2d70a2934e 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ 
-252,8 +252,7 @@ class ZioHttpServerTest extends TestSuite { multipleValueHeaderSupport = false, supportsUrlEncodedPathSegments = false, supportsMultipleSetCookieHeaders = false, - invulnerableToUnsanitizedHeaders = false, - maxContentLength = Some(300) + invulnerableToUnsanitizedHeaders = false ).tests() ++ // TODO: re-enable static content once a newer zio http is available. Currently these tests often fail with: // Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [f2 content], parser state [STATUS_LINE] diff --git a/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index d1c4852e7f..4b178e546a 100644 --- a/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio1-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -18,27 +18,23 @@ import java.nio.ByteBuffer class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = - bodyType match { - case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) - case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) - case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) - case RawBodyType.InputStreamBody => asByteArray(serverRequest).map(new ByteArrayInputStream(_)).map(RawValue(_)) - case RawBodyType.InputStreamRangeBody => - asByteArray(serverRequest).map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) - case RawBodyType.FileBody => - for { - tmpFile <- 
serverOptions.createFile(serverRequest) - _ <- toStream(serverRequest, maxBytes) - .asInstanceOf[Stream[Throwable, Byte]] - .run(ZSink.fromFile(tmpFile.toPath)) - .provideLayer(Blocking.live) - } yield { - val fileRange = FileRange(tmpFile) - RawValue(fileRange, Seq(fileRange)) - } - case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) - } + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = bodyType match { + case RawBodyType.StringBody(defaultCharset) => asByteArray(serverRequest).map(new String(_, defaultCharset)).map(RawValue(_)) + case RawBodyType.ByteArrayBody => asByteArray(serverRequest).map(RawValue(_)) + case RawBodyType.ByteBufferBody => asByteArray(serverRequest).map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) + case RawBodyType.InputStreamBody => asByteArray(serverRequest).map(new ByteArrayInputStream(_)).map(RawValue(_)) + case RawBodyType.InputStreamRangeBody => + asByteArray(serverRequest).map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) + case RawBodyType.FileBody => + for { + tmpFile <- serverOptions.createFile(serverRequest) + _ <- toStream(serverRequest, None).asInstanceOf[Stream[Throwable, Byte]].run(ZSink.fromFile(tmpFile.toPath)).provideLayer(Blocking.live) + } yield { + val fileRange = FileRange(tmpFile) + RawValue(fileRange, Seq(fileRange)) + } + case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) + } override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = stream(serverRequest).asInstanceOf[streams.BinaryStream] From a14b12972ad8dd7a79a3257b22c0b9b26189af26 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 23 Nov 2023 13:16:29 +0100 Subject: [PATCH 03/35] Test for non-streaming body limiting --- .../tapir/server/tests/ServerBasicTests.scala | 52 
++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 38330e98a1..a0f42428e2 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -24,6 +24,8 @@ import sttp.tapir.tests.data.{FruitAmount, FruitError} import java.io.{ByteArrayInputStream, InputStream} import java.nio.ByteBuffer import sttp.tapir.server.interpreter.MaxContentLength +import sttp.tapir.tests.Files.in_file_out_file +import java.io.File class ServerBasicTests[F[_], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], @@ -745,26 +747,38 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } ) + val maxLength = 300 + private def limited[I, E, O, R](endpoint: PublicEndpoint[I, E, O, R], maxBytes: Int) = + endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxBytes.toLong)) + + def testPayloadTooLarge[I]( + testedEndpoint: PublicEndpoint[I, Unit, I, Any] + ) = testServer( + testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), + "returns 413 on exceeded max content length (request)" + )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => + val tooLargeBody: String = List.fill(maxLength + 1)('x').mkString + basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).send(backend).map(_.code shouldBe StatusCode.PayloadTooLarge) + } + def testPayloadWithinLimit[I]( + testedEndpoint: PublicEndpoint[I, Unit, I, Any], + ) = testServer( + testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), + "returns OK on content length below or equal max (request)", + )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => + val fineBody: String = List.fill(maxLength)('x').mkString + 
basicRequest.post(uri"$baseUri/api/echo").body(fineBody).send(backend).map(_.code shouldBe StatusCode.Ok) + } + def maxContentLengthTests(): List[Test] = List( - { - val maxContentLength = 300 - testServer( - in_string_out_string.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxContentLength.toLong)), - "returns 413 on exceeded max content length (request)" - )(_ => pureResult("ok".asRight[Unit])) { (backend, baseUri) => - val tooLargeBody: String = List.fill(maxContentLength + 1)('x').mkString - basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).send(backend).map(_.code shouldBe StatusCode.PayloadTooLarge) - } - }, { - val maxContentLength = 300 - testServer( - in_string_out_string.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxContentLength.toLong)), - "returns OK on content length below or equal max (request)" - )(_ => pureResult("ok".asRight[Unit])) { (backend, baseUri) => - val fineBody: String = List.fill(maxContentLength)('x').mkString - basicRequest.post(uri"$baseUri/api/echo").body(fineBody).send(backend).map(_.code shouldBe StatusCode.Ok) - } - } + testPayloadTooLarge(in_string_out_string), + testPayloadTooLarge(in_byte_array_out_byte_array), + testPayloadTooLarge(in_file_out_file), + testPayloadTooLarge(in_byte_buffer_out_byte_buffer), + testPayloadWithinLimit(in_string_out_string), + testPayloadWithinLimit(in_byte_array_out_byte_array), + testPayloadWithinLimit(in_file_out_file), + testPayloadWithinLimit(in_byte_buffer_out_byte_buffer) ) def exceptionTests(): List[Test] = List( From 13ff4b81956172f95d5473ddeb1965cd787597cd Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 23 Nov 2023 14:05:39 +0100 Subject: [PATCH 04/35] More refactoring in tests --- .../tapir/server/tests/ServerBasicTests.scala | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala 
b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index a0f42428e2..cc33d50c8a 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -50,7 +50,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( customiseDecodeFailureHandlerTests() ++ serverSecurityLogicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) ++ - (if (maxContentLength.nonEmpty) maxContentLengthTests() else Nil) ++ + maxContentLength.map(maxContentLengthTests).getOrElse(Nil) ++ exceptionTests() def basicTests(): List[Test] = List( @@ -747,12 +747,9 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } ) - val maxLength = 300 - private def limited[I, E, O, R](endpoint: PublicEndpoint[I, E, O, R], maxBytes: Int) = - endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxBytes.toLong)) - def testPayloadTooLarge[I]( - testedEndpoint: PublicEndpoint[I, Unit, I, Any] + testedEndpoint: PublicEndpoint[I, Unit, I, Any], + maxLength: Int, ) = testServer( testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), "returns 413 on exceeded max content length (request)" @@ -762,23 +759,24 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } def testPayloadWithinLimit[I]( testedEndpoint: PublicEndpoint[I, Unit, I, Any], + maxLength: Int, ) = testServer( testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), - "returns OK on content length below or equal max (request)", + "returns OK on content length below or equal max (request)", )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => val fineBody: String = List.fill(maxLength)('x').mkString basicRequest.post(uri"$baseUri/api/echo").body(fineBody).send(backend).map(_.code shouldBe StatusCode.Ok) } - def maxContentLengthTests(): List[Test] = List( - testPayloadTooLarge(in_string_out_string), - 
testPayloadTooLarge(in_byte_array_out_byte_array), - testPayloadTooLarge(in_file_out_file), - testPayloadTooLarge(in_byte_buffer_out_byte_buffer), - testPayloadWithinLimit(in_string_out_string), - testPayloadWithinLimit(in_byte_array_out_byte_array), - testPayloadWithinLimit(in_file_out_file), - testPayloadWithinLimit(in_byte_buffer_out_byte_buffer) + def maxContentLengthTests(maxLength: Int): List[Test] = List( + testPayloadTooLarge(in_string_out_string, maxLength), + testPayloadTooLarge(in_byte_array_out_byte_array, maxLength), + testPayloadTooLarge(in_file_out_file, maxLength), + testPayloadTooLarge(in_byte_buffer_out_byte_buffer, maxLength), + testPayloadWithinLimit(in_string_out_string, maxLength), + testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), + testPayloadWithinLimit(in_file_out_file, maxLength), + testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) ) def exceptionTests(): List[Test] = List( From c1c478918bf59c1c865e9e8c1c4b6827664f0d33 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 23 Nov 2023 14:05:53 +0100 Subject: [PATCH 05/35] Fix RequestBody impls --- .../main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala | 2 +- .../sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala index 31231631b7..19654402f7 100644 --- a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala +++ b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala @@ -21,7 +21,7 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im ec: ExecutionContext ) extends RequestBody[Future, AkkaStreams] { override val streams: AkkaStreams = AkkaStreams - override def toRaw[R](request: ServerRequest, bodyType: 
RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkeRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { val stream = akkeRequestEntity(request).dataBytes diff --git a/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala b/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala index 206aa429f7..3dac104b2e 100644 --- a/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala +++ b/server/pekko-grpc-server/src/main/scala/sttp/tapir/server/pekkogrpc/PekkoGrpcRequestBody.scala @@ -22,7 +22,7 @@ private[pekkogrpc] class PekkoGrpcRequestBody(serverOptions: PekkoHttpServerOpti private val grpcProtocol = GrpcProtocolNative.newReader(Identity) override val streams: PekkoStreams = PekkoStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkaRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
From 88f6382ff64e3536f8b830973bf71241a1e23cde Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 24 Nov 2023 09:57:18 +0100 Subject: [PATCH 06/35] Simplify error handling --- .../main/scala/sttp/tapir/DecodeResult.scala | 1 - .../decodefailure/DecodeFailureHandler.scala | 25 ++++++++++--------- .../RequestBodyToRawException.scala | 8 ------ .../interpreter/ServerInterpreter.scala | 6 ++--- .../netty/internal/NettyCatsRequestBody.scala | 5 ++-- .../netty/internal/NettyRequestBody.scala | 5 ++-- .../netty/internal/NettyZioRequestBody.scala | 7 +++--- 7 files changed, 25 insertions(+), 32 deletions(-) delete mode 100644 server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala diff --git a/core/src/main/scala/sttp/tapir/DecodeResult.scala b/core/src/main/scala/sttp/tapir/DecodeResult.scala index fc372409c4..ae59e3dc04 100644 --- a/core/src/main/scala/sttp/tapir/DecodeResult.scala +++ b/core/src/main/scala/sttp/tapir/DecodeResult.scala @@ -35,7 +35,6 @@ object DecodeResult { } } case class Mismatch(expected: String, actual: String) extends Failure - case class BodyTooLarge(maxBytes: Long) extends Failure /** A validation error that occurred when decoding the value, that is, when some `Validator` failed. 
*/ case class InvalidValue(errors: List[ValidationError[_]]) extends Failure diff --git a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala index 362f4b8e94..ab07eacc3c 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala @@ -10,6 +10,7 @@ import sttp.tapir.server.model.ValuedEndpointOutput import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, ValidationError, Validator, server, _} import scala.annotation.tailrec +import sttp.capabilities.StreamMaxLengthExceededException trait DecodeFailureHandler[F[_]] { @@ -122,12 +123,12 @@ object DefaultDecodeFailureHandler { case (_: EndpointIO.Header[_], _) => respondBadRequest case (fh: EndpointIO.FixedHeader[_], _: DecodeResult.Mismatch) if fh.h.name == HeaderNames.ContentType => respondUnsupportedMediaType - case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest - case (_: EndpointIO.Headers[_], _) => respondBadRequest - case (_, _: DecodeResult.BodyTooLarge) => respondPayloadTooLarge - case (_: EndpointIO.Body[_, _], _) => respondBadRequest - case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType - case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest + case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest + case (_: EndpointIO.Headers[_], _) => respondBadRequest + case (_, DecodeResult.Error(_, _: StreamMaxLengthExceededException)) => respondPayloadTooLarge + case (_: EndpointIO.Body[_, _], _) => respondBadRequest + case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType + case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest // we assume that the only decode failure that might happen during 
path segment decoding is an error // a non-standard path decoder might return Missing/Multiple/Mismatch, but that would be indistinguishable from // a path shape mismatch @@ -226,12 +227,12 @@ object DefaultDecodeFailureHandler { } .mkString(", ") ) - case Missing => Some("missing") - case Multiple(_) => Some("multiple values") - case Mismatch(_, _) => Some("value mismatch") - case BodyTooLarge(maxBytes) => Some(s"Content length limit: $maxBytes bytes") - case _: Error => None - case _: InvalidValue => None + case Missing => Some("missing") + case Multiple(_) => Some("multiple values") + case Mismatch(_, _) => Some("value mismatch") + case Error(_, StreamMaxLengthExceededException(maxBytes)) => Some(s"Content length limit: $maxBytes bytes") + case _: Error => None + case _: InvalidValue => None } def combineSourceAndDetail(source: String, detail: Option[String]): String = diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala deleted file mode 100644 index 965877aebb..0000000000 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBodyToRawException.scala +++ /dev/null @@ -1,8 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.tapir.DecodeResult - -/** Can be used with RequestBody.toRaw to fail its effect F and pass failures, which are treated as decoding failures that happen before - * actual decoding of the raw value. 
- */ -private[tapir] case class RequestBodyToRawException(failure: DecodeResult.Failure) extends Exception diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index 49bb18e9b3..a49ac9d6d3 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -199,10 +199,8 @@ class ServerInterpreter[R, F[_], B, S]( } } .handleError { - case RequestBodyToRawException(failure) => - (DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult).unit - case StreamMaxLengthExceededException(maxBytes) => - (DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.BodyTooLarge(maxBytes)): DecodeBasicInputsResult).unit + case e: StreamMaxLengthExceededException => + (DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.Error("", e)): DecodeBasicInputsResult).unit } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index 90d8a3a81a..3c30ed8579 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -10,12 +10,13 @@ import io.netty.buffer.ByteBufUtil import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody, RequestBodyToRawException} +import sttp.tapir.server.interpreter.{RawValue, RequestBody} import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import java.io.ByteArrayInputStream import java.nio.ByteBuffer import 
sttp.tapir.DecodeResult +import sttp.capabilities.StreamMaxLengthExceededException private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) extends RequestBody[F, Fs2Streams[F]] { @@ -30,7 +31,7 @@ private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[T maxBytes .map(max => if (buf.readableBytes() > max) - monad.raiseError[Array[Byte]](RequestBodyToRawException(DecodeResult.BodyTooLarge(max))) + monad.raiseError[Array[Byte]](StreamMaxLengthExceededException(max)) else monad.delay(ByteBufUtil.getBytes(buf)) ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 6f36268e35..82ced7e5b3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -8,12 +8,13 @@ import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import sttp.tapir.model.ServerRequest import sttp.monad.syntax._ import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.{RawValue, RequestBody, RequestBodyToRawException} +import sttp.tapir.server.interpreter.{RawValue, RequestBody} import java.nio.ByteBuffer import java.nio.file.Files import io.netty.buffer.ByteBuf import sttp.tapir.DecodeResult +import sttp.capabilities.StreamMaxLengthExceededException class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit monadError: MonadError[F] @@ -28,7 +29,7 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit maxBytes .map(max => if (buf.readableBytes() > max) - monadError.error[ByteBuf](RequestBodyToRawException(DecodeResult.BodyTooLarge(max))) + monadError.error[ByteBuf](StreamMaxLengthExceededException(max)) else monadError.unit(buf) ) diff 
--git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala index 7ca54ce22e..a80e69a652 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala @@ -6,7 +6,7 @@ import io.netty.handler.codec.http.FullHttpRequest import sttp.capabilities.zio.ZioStreams import sttp.tapir.RawBodyType._ import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody, RequestBodyToRawException} +import sttp.tapir.server.interpreter.{RawValue, RequestBody} import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import zio.interop.reactivestreams._ import zio.stream.{ZStream, _} @@ -15,6 +15,7 @@ import zio.{Chunk, RIO, ZIO} import java.io.ByteArrayInputStream import java.nio.ByteBuffer import sttp.tapir.DecodeResult +import sttp.capabilities.StreamMaxLengthExceededException private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[Env, TapirFile]) extends RequestBody[RIO[Env, *], ZioStreams] { @@ -29,7 +30,7 @@ private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[E maxBytes .map(max => if (buf.readableBytes() > max) - ZIO.fail(RequestBodyToRawException(DecodeResult.BodyTooLarge(max))) + ZIO.fail(StreamMaxLengthExceededException(max)) else ZIO.succeed(ByteBufUtil.getBytes(buf)) ) @@ -50,7 +51,7 @@ private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[E case InputStreamRangeBody => nettyRequestBytes.map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) case FileBody => - createFile(serverRequest) + createFile(serverRequest) .flatMap(tapirFile => { toStream(serverRequest, maxBytes) .run(ZSink.fromFile(tapirFile)) From 
24beb30f6cc935518519d4cbccc171cce89f8fca Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 24 Nov 2023 12:08:46 +0100 Subject: [PATCH 07/35] Fix AkkaGrpcRequestBody --- .../scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala b/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala index 315d7d8e71..9097491d57 100644 --- a/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala +++ b/server/akka-grpc-server/src/main/scala/sttp/tapir/server/akkagrpc/AkkaGrpcRequestBody.scala @@ -22,7 +22,7 @@ private[akkagrpc] class AkkaGrpcRequestBody(serverOptions: AkkaHttpServerOptions private val grpcProtocol = GrpcProtocolNative.newReader(Identity) override val streams: AkkaStreams = AkkaStreams - override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = + override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = toRawFromEntity(request, akkaRequestEntity(request), bodyType) override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ??? 
From b6f666f46e9f63f916f9bd07b496a49b23e7f6ec Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 24 Nov 2023 12:10:29 +0100 Subject: [PATCH 08/35] Fix FinatraRequestBody --- .../scala/sttp/tapir/server/finatra/FinatraRequestBody.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala b/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala index 9a6071bc14..7e3d12fa6a 100644 --- a/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala +++ b/server/finatra-server/src/main/scala/sttp/tapir/server/finatra/FinatraRequestBody.scala @@ -20,7 +20,7 @@ import scala.collection.immutable.Seq class FinatraRequestBody(serverOptions: FinatraServerOptions) extends RequestBody[Future, NoStreams] { override val streams: NoStreams = NoStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = { + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = { val request = finatraRequest(serverRequest) toRaw(request, bodyType, request.content, request.charset.map(Charset.forName)) } From 054a75fa951f59c66141d431f830260516549258 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 24 Nov 2023 17:12:04 +0100 Subject: [PATCH 09/35] wip --- .../server/netty/NettyFutureServer.scala | 4 +- .../netty/NettyFutureServerInterpreter.scala | 2 + .../internal/NettyFutureRequestBody.scala | 62 +++++++++++++ .../internal/NettyServerInterpreter.scala | 4 +- .../FileWriterSubscriber.scala | 63 +++++++++++++ .../LimitedLengthSubscriber.scala | 45 +++++++++ .../reactivestreams/PromisingSubscriber.scala | 9 ++ .../reactivestreams/SimpleSubscriber.scala | 92 +++++++++++++++++++ .../server/netty/NettyFutureServerTest.scala | 5 +- .../NettyFutureTestServerInterpreter.scala | 2 +- 10 files changed, 282 insertions(+), 6 
deletions(-) create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index eaa7a86fe8..abe9472126 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -121,10 +121,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe object NettyFutureServer { def apply()(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.defaultNoStreaming) + NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.defaultWithStreaming) def apply(serverOptions: NettyFutureServerOptions)(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) + NettyFutureServer(Vector.empty, serverOptions, NettyConfig.defaultWithStreaming) def apply(config: NettyConfig)(implicit ec: ExecutionContext): NettyFutureServer = NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, config) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala index 7c2ef53f7e..64b881652c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala @@ -6,6 +6,7 @@ import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} import scala.concurrent.{ExecutionContext, Future} +import sttp.tapir.server.netty.internal.NettyFutureRequestBody trait NettyFutureServerInterpreter { def nettyServerOptions: NettyFutureServerOptions @@ -21,6 +22,7 @@ trait NettyFutureServerInterpreter { NettyServerInterpreter.toRoute( ses, nettyServerOptions.interceptors, + Some(new NettyFutureRequestBody(nettyServerOptions.createFile)), nettyServerOptions.createFile, nettyServerOptions.deleteFile, FutureRunAsync diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala new file mode 100644 index 0000000000..405a5a72aa --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -0,0 +1,62 @@ +package sttp.tapir.server.netty.internal + +import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} +import io.netty.handler.codec.http.FullHttpRequest +import sttp.capabilities +import sttp.monad.MonadError +import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} +import sttp.tapir.model.ServerRequest +import sttp.monad.syntax._ +import sttp.tapir.capabilities.NoStreams +import sttp.tapir.server.interpreter.{RawValue, RequestBody} + +import java.nio.ByteBuffer +import java.nio.file.Files +import io.netty.buffer.ByteBuf +import sttp.tapir.DecodeResult +import sttp.capabilities.StreamMaxLengthExceededException 
+import org.playframework.netty.http.StreamedHttpRequest +import scala.concurrent.Future +import scala.concurrent.ExecutionContext +import reactivestreams._ +import java.io.ByteArrayInputStream + +class NettyFutureRequestBody(createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) extends RequestBody[Future, NoStreams] { + + override val streams: capabilities.Streams[NoStreams] = NoStreams + + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Future[RawValue[RAW]] = { + + def byteBuf: Future[ByteBuffer] = { + val subscriber = new SimpleSubscriber() // TODO add limiting bytes + nettyRequest(serverRequest).subscribe(subscriber) + subscriber.future + } + + def requestContentAsByteArray: Future[Array[Byte]] = byteBuf.map(_.array) + + bodyType match { + case RawBodyType.StringBody(charset) => requestContentAsByteArray.map(ba => RawValue(new String(ba, charset))) + case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) + case RawBodyType.ByteBufferBody => byteBuf.map(buf => RawValue(buf)) + // InputStreamBody and InputStreamRangeBody can be further optimized to avoid loading all data in memory + case RawBodyType.InputStreamBody => requestContentAsByteArray.map(ba => RawValue(new ByteArrayInputStream(ba))) + case RawBodyType.InputStreamRangeBody => + requestContentAsByteArray.map(ba => RawValue(InputStreamRange(() => new ByteArrayInputStream(ba)))) + case RawBodyType.FileBody => + createFile(serverRequest) + .flatMap(file => // TODO wrap with limiting of the stream + FileWriterSubscriber.writeAll(nettyRequest(serverRequest), file.toPath).map( + _ => RawValue(FileRange(file), Seq(FileRange(file))) + ) + ) + case _: RawBodyType.MultipartBody => ??? 
+ } + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() + + private def nettyRequest(serverRequest: ServerRequest): StreamedHttpRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] +} + diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala index eb3dd3d02f..50ca5bc293 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala @@ -10,11 +10,13 @@ import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interceptor.{Interceptor, RequestResult} import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import sttp.tapir.server.interpreter.RequestBody object NettyServerInterpreter { def toRoute[F[_]: MonadError]( ses: List[ServerEndpoint[Any, F]], interceptors: List[Interceptor[F]], + requestBody: Option[RequestBody[F, NoStreams]], createFile: ServerRequest => F[TapirFile], deleteFile: TapirFile => F[Unit], runAsync: RunAsync[F] @@ -22,7 +24,7 @@ object NettyServerInterpreter { implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) val serverInterpreter = new ServerInterpreter[Any, F, NettyResponse, NoStreams]( FilterServerEndpoints(ses), - new NettyRequestBody(createFile), + requestBody.getOrElse(new NettyRequestBody(createFile)), new NettyToResponseBody, RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala new file mode 100644 index 0000000000..58d7e4b8e0 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -0,0 +1,63 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import java.nio.ByteBuffer +import java.nio.channels.AsynchronousFileChannel +import java.nio.file.{Path, StandardOpenOption} +import org.reactivestreams.{Subscriber, Subscription} +import io.netty.handler.codec.http.HttpContent +import java.io.IOException +import scala.concurrent.Promise +import org.reactivestreams.Publisher +import scala.concurrent.Future + +class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpContent] { + private var subscription: Subscription = _ + private var fileChannel: AsynchronousFileChannel = _ + private var position: Long = 0 + private val resultPromise = Promise[Unit]() + + override def future: Future[Unit] = resultPromise.future + + override def onSubscribe(s: Subscription): Unit = { + this.subscription = s + fileChannel = AsynchronousFileChannel.open(path, StandardOpenOption.WRITE, StandardOpenOption.CREATE) + s.request(1) + } + + override def onNext(httpContent: HttpContent): Unit = { + val byteBuffer = httpContent.content().nioBuffer() + fileChannel.write( + byteBuffer, + position, + (), + new java.nio.channels.CompletionHandler[Integer, Unit] { + override def completed(result: Integer, attachment: Unit): Unit = { + position += result + subscription.request(1) + } + + override def failed(exc: Throwable, attachment: Unit): Unit = { + onError(exc) + } + } + ) + } + + override def onError(t: Throwable): Unit = { + fileChannel.close() + resultPromise.failure(t) + } + + override def onComplete(): Unit = { + fileChannel.close() + resultPromise.success(()) + } +} + +object FileWriterSubscriber { + def writeAll(publisher: Publisher[HttpContent], path: 
Path, maxBytes: Long): Future[Unit] = { + val subscriber = new LimitedLengthSubscriber(maxBytes, new FileWriterSubscriber(path)) + publisher.subscribe(subscriber) + subscriber.future + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala new file mode 100644 index 0000000000..f6d10b6100 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala @@ -0,0 +1,45 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import org.reactivestreams.{Publisher, Subscriber, Subscription} + +import scala.collection.JavaConverters._ +import scala.concurrent.Promise +import io.netty.handler.codec.http.HttpContent +import scala.concurrent.Future +import io.netty.buffer.ByteBufUtil +import sttp.capabilities.StreamMaxLengthExceededException + +// based on org.asynchttpclient.request.body.generator.ReactiveStreamsBodyGenerator.SimpleSubscriber +// Requests all data at once and loads it into memory + +private[netty] class LimitedLengthSubscriber[R](maxBytes: Long, delegate: PromisingSubscriber[R, HttpContent]) + extends PromisingSubscriber[R, HttpContent] { + private var size = 0L + + override def future: Future[R] = delegate.future + + override def onSubscribe(s: Subscription): Unit = + delegate.onSubscribe(s) + + override def onNext(content: HttpContent): Unit = { + assert(content != null) + size = size + content.content.readableBytes() + if (size > maxBytes) + onError(StreamMaxLengthExceededException(maxBytes)) + else + delegate.onNext(content) + } + + override def onError(t: Throwable): Unit = { + assert(t != null) + delegate.onError(t) + } + + override def onComplete(): Unit = { + delegate.onComplete() + } +} + +object LimitedLengthSubscriber { + def () +} diff --git 
a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala new file mode 100644 index 0000000000..5c0bef2545 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/PromisingSubscriber.scala @@ -0,0 +1,9 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import org.reactivestreams.Subscriber + +import scala.concurrent.Future + +trait PromisingSubscriber[R, A] extends Subscriber[A] { + def future: Future[R] +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala new file mode 100644 index 0000000000..db9f394334 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -0,0 +1,92 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicReference +import java.util.function.UnaryOperator + +import org.reactivestreams.{Publisher, Subscriber, Subscription} + +import scala.collection.JavaConverters._ +import scala.concurrent.Promise +import io.netty.handler.codec.http.HttpContent +import scala.concurrent.Future +import io.netty.buffer.ByteBufUtil + +// based on org.asynchttpclient.request.body.generator.ReactiveStreamsBodyGenerator.SimpleSubscriber +// Requests all data at once and loads it into memory + +private[netty] class SimpleSubscriber() extends PromisingSubscriber[ByteBuffer, HttpContent] { + // a pair of values: (is cancelled, current subscription) + private val subscription = new AtomicReference[(Boolean, Subscription)]((false, null)) + private val chunks = new 
ConcurrentLinkedQueue[Array[Byte]]() + private var size = 0 + private val resultPromise = Promise[ByteBuffer]() + + override def future: Future[ByteBuffer] = resultPromise.future + + override def onSubscribe(s: Subscription): Unit = { + assert(s != null) + + // The following can be safely run multiple times, as cancel() is idempotent + val result = subscription.updateAndGet(new UnaryOperator[(Boolean, Subscription)] { + override def apply(current: (Boolean, Subscription)): (Boolean, Subscription) = { + // If someone has made a mistake and added this Subscriber multiple times, let's handle it gracefully + if (current._2 != null) { + current._2.cancel() // Cancel the additional subscription + } + + if (current._1) { // already cancelled + s.cancel() + (true, null) + } else { // happy path + (false, s) + } + } + }) + + if (result._2 != null) { + result._2.request(Long.MaxValue) // not cancelled, we can request data + } + } + + override def onNext(content: HttpContent): Unit = { + assert(content != null) + println(content.content().readableBytes()) + println("On next, calling getBytes") + val a = ByteBufUtil.getBytes(content.content()) + println("Bytes loaded") + size += a.length + chunks.add(a) + } + + override def onError(t: Throwable): Unit = { + assert(t != null) + chunks.clear() + resultPromise.failure(t) + } + + override def onComplete(): Unit = { + println(">>>>>> onComplete") + val result = ByteBuffer.allocate(size) + chunks.asScala.foreach(result.put) + chunks.clear() + resultPromise.success(result) + } + + def cancel(): Unit = + // subscription.cancel is idempotent: + // https://github.com/reactive-streams/reactive-streams-jvm/blob/v1.0.3/README.md#specification + // so the following can be safely retried + subscription.updateAndGet(new UnaryOperator[(Boolean, Subscription)] { + override def apply(current: (Boolean, Subscription)): (Boolean, Subscription) = { + if (current._2 != null) current._2.cancel() + (true, null) + } + }) +} + +object SimpleSubscriber { 
+ def readAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]) = + new LimitedLengthSubscriber(maxBytes, new SimpleSubscriber()) +} diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index d1485b96c5..6eb116371c 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -21,8 +21,9 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val interpreter = new NettyFutureTestServerInterpreter(eventLoopGroup) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = Some(300)).tests() ++ - new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() + val tests = + new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = Some(3000)).tests() ++ + new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) }) { case (_, eventLoopGroup) => diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala index 0073136a5e..ebe4877c6f 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala @@ -24,7 +24,7 @@ class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implic gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, KillSwitch)] = { val config = - NettyConfig.defaultNoStreaming + 
NettyConfig.defaultWithStreaming .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose From 0bb3c7dc61f317f053fc2a8478a45bde4b04c221 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 27 Nov 2023 15:25:00 +0100 Subject: [PATCH 10/35] wip --- .../internal/NettyFutureRequestBody.scala | 40 ++++----- .../netty/internal/NettyToResponseBody.scala | 62 +++++--------- .../reactivestreams/FileRangePublisher.scala | 84 +++++++++++++++++++ .../FileWriterSubscriber.scala | 18 ++-- .../InputStreamPublisher.scala | 77 +++++++++++++++++ .../LimitedLengthSubscriber.scala | 42 ++++------ .../reactivestreams/SimpleSubscriber.scala | 69 ++++----------- .../server/netty/NettyFutureServerTest.scala | 2 +- 8 files changed, 247 insertions(+), 147 deletions(-) create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index 405a5a72aa..e87923f182 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -1,9 +1,7 @@ package sttp.tapir.server.netty.internal -import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} import io.netty.handler.codec.http.FullHttpRequest import sttp.capabilities -import sttp.monad.MonadError import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import sttp.tapir.model.ServerRequest import sttp.monad.syntax._ @@ -11,26 +9,31 @@ import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.{RawValue, RequestBody} import java.nio.ByteBuffer 
-import java.nio.file.Files -import io.netty.buffer.ByteBuf -import sttp.tapir.DecodeResult -import sttp.capabilities.StreamMaxLengthExceededException import org.playframework.netty.http.StreamedHttpRequest import scala.concurrent.Future import scala.concurrent.ExecutionContext import reactivestreams._ import java.io.ByteArrayInputStream -class NettyFutureRequestBody(createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) extends RequestBody[Future, NoStreams] { +class NettyFutureRequestBody(createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) + extends RequestBody[Future, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Future[RawValue[RAW]] = { - def byteBuf: Future[ByteBuffer] = { - val subscriber = new SimpleSubscriber() // TODO add limiting bytes - nettyRequest(serverRequest).subscribe(subscriber) - subscriber.future + def byteBuf: Future[ByteBuffer] = + serverRequest.underlying match { + case r: StreamedHttpRequest => SimpleSubscriber.readAll(r, maxBytes) + // This can still happen in case an EmptyHttpRequest is received + case r: FullHttpRequest => Future.successful { + val underlyingBuf = r.content().nioBuffer() + if (underlyingBuf.hasArray()) + underlyingBuf + else + ByteBuffer.wrap(new Array[Byte](0)) + } + case other => Future.failed(new UnsupportedOperationException(s"Unexpected request type: ${other.getClass.getName()}")) } def requestContentAsByteArray: Future[Array[Byte]] = byteBuf.map(_.array) @@ -40,16 +43,16 @@ class NettyFutureRequestBody(createFile: ServerRequest => Future[TapirFile])(imp case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) case RawBodyType.ByteBufferBody => byteBuf.map(buf => RawValue(buf)) // InputStreamBody and InputStreamRangeBody can be further optimized to avoid loading all data in memory - case 
RawBodyType.InputStreamBody => requestContentAsByteArray.map(ba => RawValue(new ByteArrayInputStream(ba))) + case RawBodyType.InputStreamBody => requestContentAsByteArray.map(ba => RawValue(new ByteArrayInputStream(ba))) case RawBodyType.InputStreamRangeBody => requestContentAsByteArray.map(ba => RawValue(InputStreamRange(() => new ByteArrayInputStream(ba)))) case RawBodyType.FileBody => - createFile(serverRequest) - .flatMap(file => // TODO wrap with limiting of the stream - FileWriterSubscriber.writeAll(nettyRequest(serverRequest), file.toPath).map( - _ => RawValue(FileRange(file), Seq(FileRange(file))) - ) - ) + createFile(serverRequest) + .flatMap(file => + FileWriterSubscriber + .writeAll(nettyRequest(serverRequest), file.toPath, maxBytes) + .map(_ => RawValue(FileRange(file), Seq(FileRange(file)))) + ) case _: RawBodyType.MultipartBody => ??? } } @@ -59,4 +62,3 @@ class NettyFutureRequestBody(createFile: ServerRequest => Future[TapirFile])(imp private def nettyRequest(serverRequest: ServerRequest): StreamedHttpRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] } - diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index dd39ebef31..5754eef7b5 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -2,31 +2,28 @@ package sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext -import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher import sttp.capabilities import sttp.model.HasHeaders import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody import 
sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent.{ - ByteBufNettyResponseContent, - ChunkedFileNettyResponseContent, - ChunkedStreamNettyResponseContent -} +import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} +import sttp.tapir.server.netty.internal.NettyToResponseBody.DefaultChunkSize +import sttp.tapir.server.netty.internal.reactivestreams.{FileRangePublisher, InputStreamPublisher} import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} -import java.io.{InputStream, RandomAccessFile} +import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset - -private[internal] class RangedChunkedStream(raw: InputStream, length: Long) extends ChunkedStream(raw) { - - override def isEndOfInput(): Boolean = - super.isEndOfInput || transferredBytes == length -} +import java.util.concurrent.ForkJoinPool +import scala.concurrent.ExecutionContext class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams + // TODO cleanup + lazy val blockingEc: ExecutionContext = ExecutionContext.fromExecutor(new ForkJoinPool) override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { bodyType match { @@ -42,44 +39,31 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { val byteBuffer = v.asInstanceOf[ByteBuffer] (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) - case RawBodyType.InputStreamBody => - (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + case RawBodyType.InputStreamBody => + (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) case RawBodyType.InputStreamRangeBody => - (ctx: 
ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) - case RawBodyType.FileBody => - (ctx: ChannelHandlerContext) => ChunkedFileNettyResponseContent(ctx.newPromise(), wrap(v)) + case RawBodyType.FileBody => { + (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) + + } case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } } - private def wrap(streamRange: InputStreamRange): ChunkedStream = { - streamRange.range - .map(r => new RangedChunkedStream(streamRange.inputStreamFromRangeStart(), r.contentLength)) - .getOrElse(new ChunkedStream(streamRange.inputStream())) + private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = { + new InputStreamPublisher(streamRange, DefaultChunkSize, blockingEc) } - private def wrap(content: InputStream): ChunkedStream = { - new ChunkedStream(content) + private def wrap(fileRange: FileRange): Publisher[HttpContent] = { + new FileRangePublisher(fileRange, DefaultChunkSize) } - private def wrap(content: FileRange): ChunkedFile = { - val file = content.file - val maybeRange = for { - range <- content.range - start <- range.start - end <- range.end - } yield (start, end + NettyToResponseBody.IncludingLastOffset) - - maybeRange match { - case Some((start, end)) => { - val randomAccessFile = new RandomAccessFile(file, NettyToResponseBody.ReadOnlyAccessMode) - new ChunkedFile(randomAccessFile, start, end - start, NettyToResponseBody.DefaultChunkSize) - } - case None => new ChunkedFile(file) - } + private def wrap(content: InputStream): Publisher[HttpContent] = { + wrap(InputStreamRange(() => content, range = None)) } override def fromStreamValue( diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala new file mode 100644 index 0000000000..fea3992bb7 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala @@ -0,0 +1,84 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.{Publisher, Subscriber, Subscription} +import sttp.tapir.FileRange + +import java.nio.ByteBuffer +import java.nio.channels.{AsynchronousFileChannel, CompletionHandler} +import java.nio.file.StandardOpenOption +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher[HttpContent] { + override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { + if (subscriber == null) throw new NullPointerException("Subscriber cannot be null") + val subscription = new FileRangeSubscription(subscriber, fileRange, chunkSize) + subscriber.onSubscribe(subscription) + } + + private class FileRangeSubscription(subscriber: Subscriber[_ >: HttpContent], fileRange: FileRange, chunkSize: Int) extends Subscription { + private lazy val channel: AsynchronousFileChannel = AsynchronousFileChannel.open(fileRange.file.toPath(), StandardOpenOption.READ) + private val demand = new AtomicLong(0L) + private val position = new AtomicLong(fileRange.range.flatMap(_.start).getOrElse(0L)) + private val buffer: ByteBuffer = ByteBuffer.allocate(chunkSize) + private val isCompleted = new AtomicBoolean(false) + private val readingInProgress = new AtomicBoolean(false) + + override def request(n: Long): Unit = { + if (n <= 0) subscriber.onError(new IllegalArgumentException("§3.9: n must be greater than 0")) + else { + demand.addAndGet(n) + readNextChunkIfNeeded() + } + } + + private def readNextChunkIfNeeded(): Unit = { + 
if (demand.get() > 0 && !isCompleted.get() && readingInProgress.compareAndSet(false, true)) { + val pos = position.get() + val expectedBytes: Int = fileRange.range.flatMap(_.end) match { + case Some(endPos) if pos + chunkSize > endPos => (endPos - pos + 1).toInt + case _ => chunkSize + } + buffer.clear() + channel.read( + buffer, + pos, + null, + new CompletionHandler[Integer, Void] { + override def completed(bytesRead: Integer, attachment: Void): Unit = { + if (bytesRead == -1) { + cancel() + subscriber.onComplete() + } else { + val bytesToRead = Math.min(bytesRead, expectedBytes) + buffer.flip() + val bytes = new Array[Byte](bytesToRead) + buffer.get(bytes) + position.addAndGet(bytesToRead.toLong) + subscriber.onNext(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes))) + if (bytesToRead < expectedBytes) { + cancel() + subscriber.onComplete() + } else { + demand.decrementAndGet() + readingInProgress.set(false) + readNextChunkIfNeeded() // Read next chunk if there's more demand + } + } + } + + override def failed(exc: Throwable, attachment: Void): Unit = { + subscriber.onError(exc) + } + } + ) + } + } + + override def cancel(): Unit = { + isCompleted.set(true) + channel.close() + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala index 58d7e4b8e0..db20e5a467 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -1,14 +1,11 @@ package sttp.tapir.server.netty.internal.reactivestreams -import java.nio.ByteBuffer +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.{Publisher, Subscription} + import java.nio.channels.AsynchronousFileChannel import 
java.nio.file.{Path, StandardOpenOption} -import org.reactivestreams.{Subscriber, Subscription} -import io.netty.handler.codec.http.HttpContent -import java.io.IOException -import scala.concurrent.Promise -import org.reactivestreams.Publisher -import scala.concurrent.Future +import scala.concurrent.{Future, Promise} class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpContent] { private var subscription: Subscription = _ @@ -37,6 +34,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon } override def failed(exc: Throwable, attachment: Unit): Unit = { + subscription.cancel() onError(exc) } } @@ -55,9 +53,9 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon } object FileWriterSubscriber { - def writeAll(publisher: Publisher[HttpContent], path: Path, maxBytes: Long): Future[Unit] = { - val subscriber = new LimitedLengthSubscriber(maxBytes, new FileWriterSubscriber(path)) - publisher.subscribe(subscriber) + def writeAll(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Future[Unit] = { + val subscriber = new FileWriterSubscriber(path) + publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) subscriber.future } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala new file mode 100644 index 0000000000..633fb94539 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala @@ -0,0 +1,77 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.{Publisher, Subscriber, Subscription} +import sttp.tapir.InputStreamRange + +import 
java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.concurrent.Future +import scala.util.Success +import scala.util.Failure +import scala.concurrent.ExecutionContext + +class InputStreamPublisher(range: InputStreamRange, chunkSize: Int, blockingEc: ExecutionContext) extends Publisher[HttpContent] { + override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { + if (subscriber == null) throw new NullPointerException("Subscriber cannot be null") + val subscription = new InputStreamSubscription(subscriber, range, chunkSize) + subscriber.onSubscribe(subscription) + } + + private class InputStreamSubscription(subscriber: Subscriber[_ >: HttpContent], range: InputStreamRange, chunkSize: Int) + extends Subscription { + private val stream = range.inputStreamFromRangeStart() + private val demand = new AtomicLong(0L) + private val position = new AtomicLong(range.range.flatMap(_.start).getOrElse(0L)) + private val isCompleted = new AtomicBoolean(false) + private val readingInProgress = new AtomicBoolean(false) + + override def request(n: Long): Unit = { + if (n <= 0) subscriber.onError(new IllegalArgumentException("§3.9: n must be greater than 0")) + else { + demand.addAndGet(n) + readNextChunkIfNeeded() + } + } + + private def readNextChunkIfNeeded(): Unit = { + if (demand.get() > 0 && !isCompleted.get() && readingInProgress.compareAndSet(false, true)) { + val pos = position.get() + val expectedBytes: Int = range.range.flatMap(_.end) match { + case Some(endPos) if pos + chunkSize > endPos => (endPos - pos + 1).toInt + case _ => chunkSize + } + Future { + stream.readNBytes(expectedBytes) // Blocking I/O + }(blockingEc) + .onComplete { + case Success(bytes) => + val bytesRead = bytes.length + if (bytesRead == 0) { + cancel() + subscriber.onComplete() + } else { + position.addAndGet(bytesRead.toLong) + subscriber.onNext(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes))) + if (bytesRead < expectedBytes) { + cancel() + subscriber.onComplete()
+ } else { + demand.decrementAndGet() + readingInProgress.set(false) + readNextChunkIfNeeded() + } + } + case Failure(e) => + stream.close() + subscriber.onError(e) + }(blockingEc) + } + } + + override def cancel(): Unit = { + isCompleted.set(true) + stream.close() + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala index f6d10b6100..bd0e624084 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala @@ -1,45 +1,39 @@ package sttp.tapir.server.netty.internal.reactivestreams -import org.reactivestreams.{Publisher, Subscriber, Subscription} - -import scala.collection.JavaConverters._ -import scala.concurrent.Promise import io.netty.handler.codec.http.HttpContent -import scala.concurrent.Future -import io.netty.buffer.ByteBufUtil +import org.reactivestreams.{Subscriber, Subscription} import sttp.capabilities.StreamMaxLengthExceededException +import scala.collection.JavaConverters._ + // based on org.asynchttpclient.request.body.generator.ReactiveStreamsBodyGenerator.SimpleSubscriber // Requests all data at once and loads it into memory +private[netty] class LimitedLengthSubscriber[R](maxBytes: Long, delegate: Subscriber[HttpContent]) extends Subscriber[HttpContent] { + private var subscription: Subscription = _ + private var bytesReadSoFar = 0L -private[netty] class LimitedLengthSubscriber[R](maxBytes: Long, delegate: PromisingSubscriber[R, HttpContent]) - extends PromisingSubscriber[R, HttpContent] { - private var size = 0L - - override def future: Future[R] = delegate.future - - override def onSubscribe(s: Subscription): Unit = + override def onSubscribe(s: Subscription): Unit = { 
+ subscription = s delegate.onSubscribe(s) + } override def onNext(content: HttpContent): Unit = { - assert(content != null) - size = size + content.content.readableBytes() - if (size > maxBytes) + bytesReadSoFar = bytesReadSoFar + content.content.readableBytes() + if (bytesReadSoFar > maxBytes) { + subscription.cancel() onError(StreamMaxLengthExceededException(maxBytes)) - else + subscription = null + } else delegate.onNext(content) } override def onError(t: Throwable): Unit = { - assert(t != null) - delegate.onError(t) + if (subscription != null) + delegate.onError(t) } override def onComplete(): Unit = { - delegate.onComplete() + if (subscription != null) + delegate.onComplete() } } - -object LimitedLengthSubscriber { - def () -} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index db9f394334..988a68336b 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -1,24 +1,17 @@ package sttp.tapir.server.netty.internal.reactivestreams +import io.netty.buffer.ByteBufUtil +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.{Publisher, Subscription} + import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicReference -import java.util.function.UnaryOperator - -import org.reactivestreams.{Publisher, Subscriber, Subscription} - import scala.collection.JavaConverters._ -import scala.concurrent.Promise -import io.netty.handler.codec.http.HttpContent -import scala.concurrent.Future -import io.netty.buffer.ByteBufUtil +import scala.concurrent.{Future, Promise} -// based on 
org.asynchttpclient.request.body.generator.ReactiveStreamsBodyGenerator.SimpleSubscriber // Requests all data at once and loads it into memory - private[netty] class SimpleSubscriber() extends PromisingSubscriber[ByteBuffer, HttpContent] { - // a pair of values: (is cancelled, current subscription) - private val subscription = new AtomicReference[(Boolean, Subscription)]((false, null)) + private var subscription: Subscription = _ private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() private var size = 0 private val resultPromise = Promise[ByteBuffer]() @@ -26,67 +19,35 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[ByteBuffer, override def future: Future[ByteBuffer] = resultPromise.future override def onSubscribe(s: Subscription): Unit = { - assert(s != null) - - // The following can be safely run multiple times, as cancel() is idempotent - val result = subscription.updateAndGet(new UnaryOperator[(Boolean, Subscription)] { - override def apply(current: (Boolean, Subscription)): (Boolean, Subscription) = { - // If someone has made a mistake and added this Subscriber multiple times, let's handle it gracefully - if (current._2 != null) { - current._2.cancel() // Cancel the additional subscription - } - - if (current._1) { // already cancelled - s.cancel() - (true, null) - } else { // happy path - (false, s) - } - } - }) - - if (result._2 != null) { - result._2.request(Long.MaxValue) // not cancelled, we can request data - } + subscription = s + s.request(1) } override def onNext(content: HttpContent): Unit = { - assert(content != null) - println(content.content().readableBytes()) - println("On next, calling getBytes") val a = ByteBufUtil.getBytes(content.content()) - println("Bytes loaded") size += a.length chunks.add(a) + subscription.request(1) } override def onError(t: Throwable): Unit = { - assert(t != null) chunks.clear() resultPromise.failure(t) } override def onComplete(): Unit = { - println(">>>>>> onComplete") val result = 
ByteBuffer.allocate(size) chunks.asScala.foreach(result.put) + result.flip() chunks.clear() resultPromise.success(result) } - - def cancel(): Unit = - // subscription.cancel is idempotent: - // https://github.com/reactive-streams/reactive-streams-jvm/blob/v1.0.3/README.md#specification - // so the following can be safely retried - subscription.updateAndGet(new UnaryOperator[(Boolean, Subscription)] { - override def apply(current: (Boolean, Subscription)): (Boolean, Subscription) = { - if (current._2 != null) current._2.cancel() - (true, null) - } - }) } object SimpleSubscriber { - def readAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]) = - new LimitedLengthSubscriber(maxBytes, new SimpleSubscriber()) + def readAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[ByteBuffer] = { + val subscriber = new SimpleSubscriber() + publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) + subscriber.future + } } diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index 6eb116371c..20604e82c9 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -22,7 +22,7 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val createServerTest = new DefaultCreateServerTest(backend, interpreter) val tests = - new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = Some(3000)).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = Some(300000)).tests() ++ new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) From 412e3dbb1dfee1487902c7f89ff8975d11f619b9 Mon Sep 17 00:00:00 
2001 From: kciesielski Date: Mon, 27 Nov 2023 15:40:38 +0100 Subject: [PATCH 11/35] Fix compilation for the Id server (temp solution) --- .../sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala index 1f89ab6c09..b9d3ec0cc4 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala @@ -12,6 +12,7 @@ trait NettyIdServerInterpreter { NettyServerInterpreter.toRoute[Id]( ses, nettyServerOptions.interceptors, + requestBody = None, nettyServerOptions.createFile, nettyServerOptions.deleteFile, new RunAsync[Id] { From 9b82b6d0de3a1a66c5f4ffd1244249ac57f98287 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 27 Nov 2023 16:39:47 +0100 Subject: [PATCH 12/35] Remove incorrect descriptions --- .../netty/internal/reactivestreams/LimitedLengthSubscriber.scala | 1 - .../server/netty/internal/reactivestreams/SimpleSubscriber.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala index bd0e624084..7a670ebf9e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/LimitedLengthSubscriber.scala @@ -7,7 +7,6 @@ import sttp.capabilities.StreamMaxLengthExceededException import scala.collection.JavaConverters._ // based on 
org.asynchttpclient.request.body.generator.ReactiveStreamsBodyGenerator.SimpleSubscriber -// Requests all data at once and loads it into memory private[netty] class LimitedLengthSubscriber[R](maxBytes: Long, delegate: Subscriber[HttpContent]) extends Subscriber[HttpContent] { private var subscription: Subscription = _ private var bytesReadSoFar = 0L diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index 988a68336b..056990ca9c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -9,7 +9,6 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} -// Requests all data at once and loads it into memory private[netty] class SimpleSubscriber() extends PromisingSubscriber[ByteBuffer, HttpContent] { private var subscription: Subscription = _ private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() From 8a1ae662401f88cba0360256798f51658da00975 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 27 Nov 2023 16:56:51 +0100 Subject: [PATCH 13/35] Remove the global maxContentLength setting --- .../server/netty/cats/NettyCatsServer.scala | 2 +- .../netty/cats/NettyCatsServerTest.scala | 2 +- .../cats/NettyCatsTestServerInterpreter.scala | 5 --- .../server/netty/loom/NettyIdServer.scala | 1 - .../sttp/tapir/server/netty/NettyConfig.scala | 10 +---- .../server/netty/NettyFutureServer.scala | 2 +- .../netty/internal/NettyServerHandler.scala | 42 +++---------------- .../server/netty/zio/NettyZioServer.scala | 1 - 8 files changed, 9 insertions(+), 56 deletions(-) diff --git 
a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 340ebd76dc..70f439f93b 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -73,7 +73,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index dd9c9f865e..fdea2d80fa 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -35,7 +35,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { interpreter, backend, multipart = false, - maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) + maxContentLength = Some(300000) ) .tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index 521e6a342d..d70ccc9067 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ 
b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -28,7 +28,6 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose - .maxContentLength(NettyCatsTestServerInterpreter.maxContentLength) .noGracefulShutdown val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) @@ -39,7 +38,3 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch .make(bind.map(b => (b.port, b.stop()))) { case (_, release) => release } } } - -object NettyCatsTestServerInterpreter { - val maxContentLength = 10000 -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index 8609b49a36..bdb9ce14a9 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -95,7 +95,6 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, new NettyServerHandler( route, unsafeRunF, - config.maxContentLength, channelGroup, isShuttingDown ), diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index fbadd899fe..d1f5091597 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -17,9 +17,6 @@ import scala.concurrent.duration._ /** Netty configuration, used by [[NettyFutureServer]] and other server implementations to configure the networking layer, the Netty * processing pipeline, and start & stop the server. 
* - * @param maxContentLength - * The max content length passed to the [[io.netty.handler.codec.http.HttpObjectAggregator]] handler. - * * @param maxConnections * The maximum number of concurrent connections allowed by the server. Any connections above this limit will be closed right after they * are opened. @@ -56,7 +53,6 @@ case class NettyConfig( host: String, port: Int, shutdownEventLoopGroupOnClose: Boolean, - maxContentLength: Option[Int], maxConnections: Option[Int], socketBacklog: Int, requestTimeout: Option[FiniteDuration], @@ -79,9 +75,6 @@ case class NettyConfig( def withShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = true) def withDontShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = false) - def maxContentLength(m: Int): NettyConfig = copy(maxContentLength = Some(m)) - def noMaxContentLength: NettyConfig = copy(maxContentLength = None) - def maxConnections(m: Int): NettyConfig = copy(maxConnections = Some(m)) def socketBacklog(s: Int): NettyConfig = copy(socketBacklog = s) @@ -124,7 +117,6 @@ object NettyConfig { socketTimeout = Some(60.seconds), lingerTimeout = Some(60.seconds), gracefulShutdownTimeout = Some(10.seconds), - maxContentLength = None, maxConnections = None, addLoggingHandler = false, sslContext = None, @@ -136,7 +128,7 @@ object NettyConfig { def defaultInitPipelineNoStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength.getOrElse(Integer.MAX_VALUE))) + pipeline.addLast(new HttpObjectAggregator((Integer.MAX_VALUE))) pipeline.addLast(new ChunkedWriteHandler()) pipeline.addLast(handler) () diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index abe9472126..210fd6d4ae 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -70,7 +70,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index bef744142d..745c6fc6bd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -35,7 +35,6 @@ import scala.util.{Failure, Success} class NettyServerHandler[F[_]]( route: Route[F], unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), - maxContentLength: Option[Int], channelGroup: ChannelGroup, isShuttingDown: AtomicBoolean )(implicit @@ -177,16 +176,11 @@ class NettyServerHandler[F[_]]( serverResponse.handle( ctx = ctx, byteBufHandler = (channelPromise, byteBuf) => { - - if (maxContentLength.exists(_ < byteBuf.readableBytes)) - writeEntityTooLargeResponse(ctx, req) - else { - val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) - res.handleCloseAndKeepAliveHeaders(req) - ctx.writeAndFlush(res, 
channelPromise).closeIfNeeded(req) - } + val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) }, chunkedStreamHandler = (channelPromise, chunkedStream) => { val resHeader: DefaultHttpResponse = @@ -234,32 +228,6 @@ class NettyServerHandler[F[_]]( } ) - private def writeEntityTooLargeResponse(ctx: ChannelHandlerContext, req: HttpRequest): Unit = { - - if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { - val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) - val _ = future.addListener(new ChannelFutureListener() { - override def operationComplete(future: ChannelFuture) = { - if (!future.isSuccess()) { - logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) - } - val _ = ctx.close() - } - }) - } else { - val _ = ctx - .writeAndFlush(EntityTooLarge.retainedDuplicate()) - .addListener(new ChannelFutureListener() { - override def operationComplete(future: ChannelFuture) = { - if (!future.isSuccess()) { - logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) - val _ = ctx.close() - } - } - }) - } - } - private implicit class RichServerNettyResponse(val r: ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 968e0d1013..50ceeb9a6d 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -89,7 +89,6 @@ case class NettyZioServer[R](routes: 
Vector[RIO[R, Route[RIO[R, *]]]], options: new NettyServerHandler[RIO[R, *]]( route, unsafeRunAsync(runtime), - config.maxContentLength, channelGroup, isShuttingDown ), From 233a62599655264149b0faeb69e9f6aa8aacd361 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 27 Nov 2023 17:06:48 +0100 Subject: [PATCH 14/35] Adjust tests --- .../netty/cats/NettyCatsServerTest.scala | 2 +- .../server/netty/NettyFutureServerTest.scala | 2 +- .../server/netty/zio/NettyZioServerTest.scala | 2 +- .../tapir/server/tests/AllServerTests.scala | 2 +- .../tapir/server/tests/ServerBasicTests.scala | 33 ++++++++++--------- 5 files changed, 22 insertions(+), 19 deletions(-) diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index fdea2d80fa..52d86ae1d4 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -35,7 +35,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { interpreter, backend, multipart = false, - maxContentLength = Some(300000) + maxContentLength = true ) .tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index 20604e82c9..5125a86532 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -22,7 +22,7 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val createServerTest = new DefaultCreateServerTest(backend, interpreter) val tests = - new 
AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = Some(300000)).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = true).tests() ++ new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index fc29fd4960..134de29376 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -40,7 +40,7 @@ class NettyZioServerTest extends TestSuite with EitherValues { backend, staticContent = false, multipart = false, - maxContentLength = Some(300) + maxContentLength = true ).tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala index 3e6b762346..8e6ad861ff 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala @@ -28,7 +28,7 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( oneOfBody: Boolean = true, cors: Boolean = true, options: Boolean = true, - maxContentLength: Option[Int] = None + maxContentLength: Boolean = false )(implicit m: MonadError[F] ) { diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index cc33d50c8a..2413d1f2c1 100644 --- 
a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -35,7 +35,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( supportsUrlEncodedPathSegments: Boolean = true, supportsMultipleSetCookieHeaders: Boolean = true, invulnerableToUnsanitizedHeaders: Boolean = true, - maxContentLength: Option[Int] = None + maxContentLength: Boolean = false )(implicit m: MonadError[F] ) { @@ -50,7 +50,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( customiseDecodeFailureHandlerTests() ++ serverSecurityLogicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) ++ - maxContentLength.map(maxContentLengthTests).getOrElse(Nil) ++ + (if (maxContentLength) maxContentLengthTests else Nil) ++ exceptionTests() def basicTests(): List[Test] = List( @@ -749,7 +749,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( def testPayloadTooLarge[I]( testedEndpoint: PublicEndpoint[I, Unit, I, Any], - maxLength: Int, + maxLength: Int ) = testServer( testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), "returns 413 on exceeded max content length (request)" @@ -759,25 +759,28 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } def testPayloadWithinLimit[I]( testedEndpoint: PublicEndpoint[I, Unit, I, Any], - maxLength: Int, + maxLength: Int ) = testServer( testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), - "returns OK on content length below or equal max (request)", + "returns OK on content length below or equal max (request)" )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => val fineBody: String = List.fill(maxLength)('x').mkString basicRequest.post(uri"$baseUri/api/echo").body(fineBody).send(backend).map(_.code shouldBe StatusCode.Ok) } - def maxContentLengthTests(maxLength: Int): List[Test] = List( - testPayloadTooLarge(in_string_out_string, maxLength), - 
testPayloadTooLarge(in_byte_array_out_byte_array, maxLength), - testPayloadTooLarge(in_file_out_file, maxLength), - testPayloadTooLarge(in_byte_buffer_out_byte_buffer, maxLength), - testPayloadWithinLimit(in_string_out_string, maxLength), - testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), - testPayloadWithinLimit(in_file_out_file, maxLength), - testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) - ) + def maxContentLengthTests: List[Test] = { + val maxLength = 16484 // To generate a few chunks of default size 8192 + some extra bytes + List( + testPayloadTooLarge(in_string_out_string, maxLength), + testPayloadTooLarge(in_byte_array_out_byte_array, maxLength), + testPayloadTooLarge(in_file_out_file, maxLength), + testPayloadTooLarge(in_byte_buffer_out_byte_buffer, maxLength), + testPayloadWithinLimit(in_string_out_string, maxLength), + testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), + testPayloadWithinLimit(in_file_out_file, maxLength), + testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) + ) + } def exceptionTests(): List[Test] = List( testServer(endpoint, "handle exceptions")(_ => throw new RuntimeException()) { (backend, baseUri) => From 8fbba4b599eb06ec14b749d384bd9fac0e50de28 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 27 Nov 2023 17:12:12 +0100 Subject: [PATCH 15/35] Add test for InputStream --- .../main/scala/sttp/tapir/server/tests/ServerBasicTests.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 2413d1f2c1..1be06c4338 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -25,7 +25,6 @@ import java.io.{ByteArrayInputStream, InputStream} import java.nio.ByteBuffer import 
sttp.tapir.server.interpreter.MaxContentLength import sttp.tapir.tests.Files.in_file_out_file -import java.io.File class ServerBasicTests[F[_], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], @@ -774,8 +773,10 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( testPayloadTooLarge(in_string_out_string, maxLength), testPayloadTooLarge(in_byte_array_out_byte_array, maxLength), testPayloadTooLarge(in_file_out_file, maxLength), + testPayloadTooLarge(in_input_stream_out_input_stream, maxLength), testPayloadTooLarge(in_byte_buffer_out_byte_buffer, maxLength), testPayloadWithinLimit(in_string_out_string, maxLength), + testPayloadWithinLimit(in_input_stream_out_input_stream, maxLength), testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), testPayloadWithinLimit(in_file_out_file, maxLength), testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) From 4ae0ec801d699d863dca47889506efa948e56a18 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 29 Nov 2023 09:33:41 +0100 Subject: [PATCH 16/35] Remove the distinction between defaultStreaming and defaultNoStreaming --- doc/server/netty.md | 6 ++--- generated-doc/out/server/netty.md | 6 ++--- .../server/netty/cats/NettyCatsServer.scala | 4 ++-- .../cats/NettyCatsTestServerInterpreter.scala | 2 +- .../server/netty/loom/NettyIdServer.scala | 4 ++-- .../loom/NettyIdTestServerInterpreter.scala | 2 +- .../sttp/tapir/server/netty/NettyConfig.scala | 22 +++++-------------- .../server/netty/NettyFutureServer.scala | 4 ++-- .../NettyFutureTestServerInterpreter.scala | 2 +- .../server/netty/zio/NettyZioServer.scala | 4 ++-- .../zio/NettyZioTestServerInterpreter.scala | 2 +- 11 files changed, 23 insertions(+), 35 deletions(-) diff --git a/doc/server/netty.md b/doc/server/netty.md index 64e1ea25bf..3f8eafb4a4 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -80,7 +80,7 @@ NettyFutureServer().port(9090).addEndpoints(???) 
NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options) // customise Netty config -NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256)) +NettyFutureServer(NettyConfig.default.socketBacklog(256)) ``` ## Graceful shutdown @@ -93,9 +93,9 @@ import sttp.tapir.server.netty.NettyConfig import scala.concurrent.duration._ // adjust the waiting time to your needs -val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds) +val config = NettyConfig.default.withGracefulShutdownTimeout(5.seconds) // or if you don't want the server to wait for in-flight requests -val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown +val config2 = NettyConfig.default.noGracefulShutdown ``` ## Domain socket support diff --git a/generated-doc/out/server/netty.md b/generated-doc/out/server/netty.md index 2129947b14..799bc5bc0c 100644 --- a/generated-doc/out/server/netty.md +++ b/generated-doc/out/server/netty.md @@ -80,7 +80,7 @@ NettyFutureServer().port(9090).addEndpoints(???) 
NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options) // customise Netty config -NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256)) +NettyFutureServer(NettyConfig.default.socketBacklog(256)) ``` ## Graceful shutdown @@ -93,9 +93,9 @@ import sttp.tapir.server.netty.NettyConfig import scala.concurrent.duration._ // adjust the waiting time to your needs -val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds) +val config = NettyConfig.default.withGracefulShutdownTimeout(5.seconds) // or if you don't want the server to wait for in-flight requests -val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown +val config2 = NettyConfig.default.noGracefulShutdown ``` ## Domain socket support diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 70f439f93b..b9e2b958a3 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -123,9 +123,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty object NettyCatsServer { def apply[F[_]: Async](dispatcher: Dispatcher[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.defaultWithStreaming) + NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.default) def apply[F[_]: Async](options: NettyCatsServerOptions[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, options, NettyConfig.defaultWithStreaming) + NettyCatsServer(Vector.empty, options, NettyConfig.default) def apply[F[_]: Async](dispatcher: Dispatcher[F], config: NettyConfig): NettyCatsServer[F] = NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), 
config) def apply[F[_]: Async](options: NettyCatsServerOptions[F], config: NettyConfig): NettyCatsServer[F] = diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index d70ccc9067..61be7e6f4d 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -24,7 +24,7 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch routes: NonEmptyList[Route[IO]], gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, KillSwitch)] = { - val config = NettyConfig.defaultWithStreaming + val config = NettyConfig.default .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index bdb9ce14a9..aef0b99bdc 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -141,10 +141,10 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, } object NettyIdServer { - def apply(): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, NettyConfig.defaultNoStreaming) + def apply(): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, NettyConfig.default) def apply(serverOptions: NettyIdServerOptions): NettyIdServer = - NettyIdServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) + NettyIdServer(Vector.empty, serverOptions, NettyConfig.default) def apply(config: 
NettyConfig): NettyIdServer = NettyIdServer(Vector.empty, NettyIdServerOptions.default, config) diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala index 8d6e940a68..c4535bbc61 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdTestServerInterpreter.scala @@ -22,7 +22,7 @@ class NettyIdTestServerInterpreter(eventLoopGroup: NioEventLoopGroup) gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, IO[Unit])] = { val config = - NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown + NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) val options = NettyIdServerOptions.default val bind = IO.blocking(NettyIdServer(options, customizedConfig).addRoutes(routes.toList).start()) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index d1f5091597..8ae9ac0523 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -1,15 +1,14 @@ package sttp.tapir.server.netty -import org.playframework.netty.http.HttpStreamsServerHandler import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.kqueue.{KQueue, KQueueEventLoopGroup, KQueueServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import 
io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.{ChannelHandler, ChannelPipeline, EventLoopGroup, ServerChannel} -import io.netty.handler.codec.http.{HttpObjectAggregator, HttpServerCodec} +import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext -import io.netty.handler.stream.ChunkedWriteHandler +import org.playframework.netty.http.HttpStreamsServerHandler import sttp.tapir.server.netty.NettyConfig.EventLoopConfig import scala.concurrent.duration._ @@ -106,7 +105,7 @@ case class NettyConfig( } object NettyConfig { - def defaultNoStreaming: NettyConfig = NettyConfig( + def default: NettyConfig = NettyConfig( host = "localhost", port = 8080, shutdownEventLoopGroupOnClose = true, @@ -122,19 +121,10 @@ object NettyConfig { sslContext = None, eventLoopConfig = EventLoopConfig.auto, socketConfig = NettySocketConfig.default, - initPipeline = cfg => defaultInitPipelineNoStreaming(cfg)(_, _) + initPipeline = cfg => defaultInitPipeline(cfg)(_, _) ) - def defaultInitPipelineNoStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { - cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) - pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpObjectAggregator((Integer.MAX_VALUE))) - pipeline.addLast(new ChunkedWriteHandler()) - pipeline.addLast(handler) - () - } - - def defaultInitPipelineStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) pipeline.addLast(new HttpStreamsServerHandler()) @@ -143,8 +133,6 @@ object NettyConfig { () } - def defaultWithStreaming: NettyConfig = defaultNoStreaming.copy(initPipeline = cfg => 
defaultInitPipelineStreaming(cfg)(_, _)) - case class EventLoopConfig(initEventLoopGroup: () => EventLoopGroup, serverChannel: Class[_ <: ServerChannel]) object EventLoopConfig { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 210fd6d4ae..f9d9bce64d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -121,10 +121,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe object NettyFutureServer { def apply()(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.defaultWithStreaming) + NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.default) def apply(serverOptions: NettyFutureServerOptions)(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, serverOptions, NettyConfig.defaultWithStreaming) + NettyFutureServer(Vector.empty, serverOptions, NettyConfig.default) def apply(config: NettyConfig)(implicit ec: ExecutionContext): NettyFutureServer = NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, config) diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala index ebe4877c6f..7eb0867ff1 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala @@ -24,7 +24,7 @@ class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implic gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, 
(Port, KillSwitch)] = { val config = - NettyConfig.defaultWithStreaming + NettyConfig.default .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 50ceeb9a6d..be553f0ccf 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -140,9 +140,9 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: } object NettyZioServer { - def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.defaultWithStreaming) + def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.default) def apply[R](options: NettyZioServerOptions[R]): NettyZioServer[R] = - NettyZioServer(Vector.empty, options, NettyConfig.defaultWithStreaming) + NettyZioServer(Vector.empty, options, NettyConfig.default) def apply[R](config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], config) def apply[R](options: NettyZioServerOptions[R], config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, options, config) } diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala index 2f17fdefaf..ed6ad1cb7c 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala @@ -25,7 +25,7 @@ class NettyZioTestServerInterpreter[R](eventLoopGroup: 
NioEventLoopGroup) routes: NonEmptyList[Task[Route[Task]]], gracefulShutdownTimeout: Option[FiniteDuration] = None ): Resource[IO, (Port, KillSwitch)] = { - val config = NettyConfig.defaultWithStreaming + val config = NettyConfig.default .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose From 50dde625a2047114f88c42d6a8aebd106d1f69b6 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 29 Nov 2023 11:16:35 +0100 Subject: [PATCH 17/35] Disjoin RequestBody/ToResponseBody impls for netty serverse --- .../cats/NettyCatsServerInterpreter.scala | 2 +- .../internal/NettyCatsToResponseBody.scala | 26 +++++++++++----- .../netty/loom/NettyIdServerInterpreter.scala | 6 ++-- .../netty/NettyFutureServerInterpreter.scala | 5 ++-- ....scala => NettyFutureToResponseBody.scala} | 17 +++++------ .../internal/NettyIdToResponseBody.scala | 30 +++++++++++++++++++ .../internal/NettyServerInterpreter.scala | 10 +++---- .../internal/NettyToStreamsResponseBody.scala | 30 +++++++++---------- .../InputStreamPublisher.scala | 11 ++++--- .../netty/zio/NettyZioServerInterpreter.scala | 2 +- 10 files changed, 90 insertions(+), 49 deletions(-) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{NettyToResponseBody.scala => NettyFutureToResponseBody.scala} (82%) create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index d6409f0fd0..92830b8652 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -32,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new 
ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile), - new NettyCatsToResponseBody(nettyServerOptions.dispatcher, delegate = new NettyToResponseBody), + new NettyCatsToResponseBody(nettyServerOptions.dispatcher), RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index 757053ae36..3a247f9d61 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -17,15 +17,27 @@ import sttp.tapir.server.netty.NettyResponseContent._ import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.io.InputStream +import java.nio.ByteBuffer import java.nio.charset.Charset -class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], delegate: NettyToResponseBody) - extends ToResponseBody[NettyResponse, Fs2Streams[F]] { +class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { bodyType match { + case RawBodyType.StringBody(charset) => + val bytes = v.asInstanceOf[String].getBytes(charset) + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteArrayBody => + val bytes = v.asInstanceOf[Array[Byte]] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteBufferBody => + val 
byteBuffer = v.asInstanceOf[ByteBuffer] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) + case RawBodyType.InputStreamBody => val stream = inputStreamToFs2(() => v) (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) @@ -41,22 +53,20 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], delegate: val path = Path.fromNioPath(tapirFile.file.toPath) val stream = tapirFile.range .flatMap(r => - r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, NettyToResponseBody.DefaultChunkSize, s._1, s._2)) + r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, NettyFutureToResponseBody.DefaultChunkSize, s._1, s._2)) ) - .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) + .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyFutureToResponseBody.DefaultChunkSize, Flags.Read)) (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException - - case _ => delegate.fromRawValue(v, headers, format, bodyType) } } private def inputStreamToFs2(inputStream: () => InputStream) = fs2.io.readInputStream( Sync[F].blocking(inputStream()), - NettyToResponseBody.DefaultChunkSize + NettyFutureToResponseBody.DefaultChunkSize ) private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { @@ -64,7 +74,7 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], delegate: // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( stream - .chunkLimit(NettyToResponseBody.DefaultChunkSize) + .chunkLimit(NettyFutureToResponseBody.DefaultChunkSize) .map { chunk => val bytes: Chunk.ArraySlice[Byte] = chunk.compact diff --git 
a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala index b9d3ec0cc4..fdcb42fda2 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.netty.loom import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyIdToResponseBody, NettyRequestBody, NettyServerInterpreter, RunAsync} trait NettyIdServerInterpreter { def nettyServerOptions: NettyIdServerOptions @@ -12,8 +12,8 @@ trait NettyIdServerInterpreter { NettyServerInterpreter.toRoute[Id]( ses, nettyServerOptions.interceptors, - requestBody = None, - nettyServerOptions.createFile, + new NettyRequestBody(nettyServerOptions.createFile), + new NettyIdToResponseBody, nettyServerOptions.deleteFile, new RunAsync[Id] { override def apply[T](f: => Id[T]): Unit = { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala index 64b881652c..deec6449c6 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala @@ -7,6 +7,7 @@ import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} import scala.concurrent.{ExecutionContext, Future} import sttp.tapir.server.netty.internal.NettyFutureRequestBody +import sttp.tapir.server.netty.internal.NettyFutureToResponseBody trait NettyFutureServerInterpreter { def nettyServerOptions: NettyFutureServerOptions @@ -22,8 +23,8 @@ trait 
NettyFutureServerInterpreter { NettyServerInterpreter.toRoute( ses, nettyServerOptions.interceptors, - Some(new NettyFutureRequestBody(nettyServerOptions.createFile)), - nettyServerOptions.createFile, + new NettyFutureRequestBody(nettyServerOptions.createFile), + new NettyFutureToResponseBody, nettyServerOptions.deleteFile, FutureRunAsync ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureToResponseBody.scala similarity index 82% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureToResponseBody.scala index 5754eef7b5..37f9de2af4 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureToResponseBody.scala @@ -10,20 +10,17 @@ import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} -import sttp.tapir.server.netty.internal.NettyToResponseBody.DefaultChunkSize +import sttp.tapir.server.netty.internal.NettyFutureToResponseBody.DefaultChunkSize import sttp.tapir.server.netty.internal.reactivestreams.{FileRangePublisher, InputStreamPublisher} import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset -import java.util.concurrent.ForkJoinPool import scala.concurrent.ExecutionContext -class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { +class NettyFutureToResponseBody(implicit ec: ExecutionContext) extends 
ToResponseBody[NettyResponse, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - // TODO cleanup - lazy val blockingEc: ExecutionContext = ExecutionContext.fromExecutor(new ForkJoinPool) override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { bodyType match { @@ -39,14 +36,14 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { val byteBuffer = v.asInstanceOf[ByteBuffer] (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) - case RawBodyType.InputStreamBody => + case RawBodyType.InputStreamBody => (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) case RawBodyType.InputStreamRangeBody => (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) - case RawBodyType.FileBody => { - (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) + case RawBodyType.FileBody => { (ctx: ChannelHandlerContext) => + ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) } @@ -55,7 +52,7 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { } private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = { - new InputStreamPublisher(streamRange, DefaultChunkSize, blockingEc) + new InputStreamPublisher(streamRange, DefaultChunkSize) } private def wrap(fileRange: FileRange): Publisher[HttpContent] = { @@ -79,7 +76,7 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { ): NettyResponse = throw new UnsupportedOperationException } -private[internal] object NettyToResponseBody { +private[internal] object NettyFutureToResponseBody { val DefaultChunkSize = 8192 val IncludingLastOffset = 1 val ReadOnlyAccessMode = "r" diff --git 
a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala new file mode 100644 index 0000000000..75c5b408eb --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala @@ -0,0 +1,30 @@ +package sttp.tapir.server.netty.internal + +import sttp.capabilities +import sttp.model.HasHeaders +import sttp.tapir.capabilities.NoStreams +import sttp.tapir.server.interpreter.ToResponseBody +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.{CodecFormat, RawBodyType} +import sttp.tapir.WebSocketBodyOutput +import java.nio.charset.Charset + +class NettyIdToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { + + override val streams: capabilities.Streams[NoStreams] = NoStreams + + override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { + ??? // TODO + } + override def fromStreamValue( + v: streams.BinaryStream, + headers: HasHeaders, + format: CodecFormat, + charset: Option[Charset] + ): NettyResponse = ??? 
+ + override def fromWebSocketPipe[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, NoStreams] + ): NettyResponse = throw new UnsupportedOperationException +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala index 50ca5bc293..8b87f78f47 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerInterpreter.scala @@ -4,28 +4,28 @@ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams -import sttp.tapir.model.ServerRequest import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interceptor.{Interceptor, RequestResult} import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import sttp.tapir.server.interpreter.RequestBody +import sttp.tapir.server.interpreter.ToResponseBody object NettyServerInterpreter { def toRoute[F[_]: MonadError]( ses: List[ServerEndpoint[Any, F]], interceptors: List[Interceptor[F]], - requestBody: Option[RequestBody[F, NoStreams]], - createFile: ServerRequest => F[TapirFile], + requestBody: RequestBody[F, NoStreams], + toResponseBody: ToResponseBody[NettyResponse, NoStreams], deleteFile: TapirFile => F[Unit], runAsync: RunAsync[F] ): Route[F] = { implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) val serverInterpreter = new ServerInterpreter[Any, F, NettyResponse, NoStreams]( FilterServerEndpoints(ses), - requestBody.getOrElse(new NettyRequestBody(createFile)), - new NettyToResponseBody, + requestBody, + 
toResponseBody, RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 8c45f8285c..2168cf2bdd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -2,32 +2,34 @@ package sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext -import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import sttp.capabilities import sttp.capabilities.Streams import sttp.model.HasHeaders -import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent.{ - ByteBufNettyResponseContent, - ChunkedFileNettyResponseContent, - ChunkedStreamNettyResponseContent, - ReactivePublisherNettyResponseContent -} -import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} +import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} +import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} -import java.io.{InputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.charset.Charset -class NettyToStreamsResponseBody[S <: Streams[S]](delegate: NettyToResponseBody, streamCompatible: StreamCompatible[S]) - extends ToResponseBody[NettyResponse, S] { +class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { override val streams: S = streamCompatible.streams override def fromRawValue[R](v: R, headers: HasHeaders, format: 
CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { bodyType match { + case RawBodyType.StringBody(charset) => + val bytes = v.asInstanceOf[String].getBytes(charset) + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteArrayBody => + val bytes = v.asInstanceOf[Array[Byte]] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteBufferBody => + val byteBuffer = v.asInstanceOf[ByteBuffer] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) + case RawBodyType.InputStreamBody => (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, length = None)) @@ -43,8 +45,6 @@ class NettyToStreamsResponseBody[S <: Streams[S]](delegate: NettyToResponseBody, (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException - - case _ => delegate.fromRawValue(v, headers, format, bodyType) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala index 633fb94539..f4e7ffb9a2 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala @@ -6,12 +6,13 @@ import org.reactivestreams.{Publisher, Subscriber, Subscription} import sttp.tapir.InputStreamRange import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.concurrent.blocking import 
scala.concurrent.Future import scala.util.Success import scala.util.Failure import scala.concurrent.ExecutionContext -class InputStreamPublisher(range: InputStreamRange, chunkSize: Int, blockingEc: ExecutionContext) extends Publisher[HttpContent] { +class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: ExecutionContext) extends Publisher[HttpContent] { override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { if (subscriber == null) throw new NullPointerException("Subscriber cannot be null") val subscription = new InputStreamSubscription(subscriber, range, chunkSize) @@ -42,8 +43,10 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int, blockingEc: case _ => chunkSize } Future { - stream.readNBytes(expectedBytes) // Blocking I/IO - }(blockingEc) + blocking { + stream.readNBytes(expectedBytes) + } + } .onComplete { case Success(bytes) => val bytesRead = bytes.length @@ -65,7 +68,7 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int, blockingEc: case Failure(e) => stream.close() subscriber.onError(e) - }(blockingEc) + } } } diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala index 56a16bc622..07d14aaf43 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala @@ -27,7 +27,7 @@ trait NettyZioServerInterpreter[R] { val serverInterpreter = new ServerInterpreter[ZioStreams, F, NettyResponse, ZioStreams]( FilterServerEndpoints(widenedSes), new NettyZioRequestBody(widenedServerOptions.createFile), - new NettyToStreamsResponseBody[ZioStreams](delegate = new NettyToResponseBody(), ZioStreamCompatible(runtime)), + new 
NettyToStreamsResponseBody[ZioStreams](ZioStreamCompatible(runtime)), RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes), widenedServerOptions.deleteFile ) From 2cfcb27bf887ba2984d5604e830a0933877c7765 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 29 Nov 2023 16:30:47 +0100 Subject: [PATCH 18/35] Improve error handling and add comments --- .../netty/internal/NettyServerHandler.scala | 22 +++++++------------ .../reactivestreams/FileRangePublisher.scala | 10 ++++++++- .../InputStreamPublisher.scala | 18 ++++++++------- .../tapir/server/tests/ServerBasicTests.scala | 2 +- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 745c6fc6bd..c7e32baf9b 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,7 +4,6 @@ import com.typesafe.scalalogging.Logger import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.channel.group.ChannelGroup -import io.netty.handler.codec.http.HttpHeaderNames.{CONNECTION, CONTENT_LENGTH} import io.netty.handler.codec.http._ import io.netty.handler.stream.{ChunkedFile, ChunkedStream} import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} @@ -63,19 +62,6 @@ class NettyServerHandler[F[_]]( private val logger = Logger[NettyServerHandler[F]] - private val EntityTooLarge: FullHttpResponse = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) - res.headers().set(CONTENT_LENGTH, 0) - res - } - - private val EntityTooLargeClose: FullHttpResponse = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, 
HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) - res.headers().set(CONTENT_LENGTH, 0) - res.headers().set(CONNECTION, HttpHeaderValues.CLOSE) - res - } - override def handlerAdded(ctx: ChannelHandlerContext): Unit = if (ctx.channel.isActive) { initHandler(ctx) @@ -210,6 +196,14 @@ class NettyServerHandler[F[_]]( res.setHeadersFrom(serverResponse) res.handleCloseAndKeepAliveHeaders(req) + + channelPromise.addListener((future: ChannelFuture) => { + // A reactive publisher silently closes the channel and fails the channel promise, so we need + // to listen on it and log failure details + if (!future.isSuccess()) { + logger.error("Error when streaming HTTP response", future.cause()) + } + }) ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) }, diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala index fea3992bb7..e981bec388 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala @@ -33,6 +33,9 @@ class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher } } + /** Can be called multiple times by request(n), or concurrently by channel.read() callback. The readingInProgress check ensures that + * calls are serialized. A channel.read() operation will be started only if another isn't running. This method is non-blocking. 
+ */ private def readNextChunkIfNeeded(): Unit = { if (demand.get() > 0 && !isCompleted.get() && readingInProgress.compareAndSet(false, true)) { val pos = position.get() @@ -52,6 +55,9 @@ class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher subscriber.onComplete() } else { val bytesToRead = Math.min(bytesRead, expectedBytes) + // The buffer is modified only by one thread at a time, because only one channel.read() + // is running at a time, and because buffer.clear() calls before the read are guarded + // by readingInProgress.compareAndSet. buffer.flip() val bytes = new Array[Byte](bytesToRead) buffer.get(bytes) @@ -63,7 +69,9 @@ class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher } else { demand.decrementAndGet() readingInProgress.set(false) - readNextChunkIfNeeded() // Read next chunk if there's more demand + // Either this call, or a call from request(n) will win the race to + // actually start a new read. + readNextChunkIfNeeded() } } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala index f4e7ffb9a2..517d5b11b3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala @@ -5,12 +5,10 @@ import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} import org.reactivestreams.{Publisher, Subscriber, Subscription} import sttp.tapir.InputStreamRange +import java.io.InputStream import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} -import scala.concurrent.blocking -import scala.concurrent.Future -import scala.util.Success -import scala.util.Failure -import scala.concurrent.ExecutionContext +import 
scala.concurrent.{ExecutionContext, Future, blocking} +import scala.util.{Failure, Success, Try} class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: ExecutionContext) extends Publisher[HttpContent] { override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { @@ -21,7 +19,7 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: private class InputStreamSubscription(subscriber: Subscriber[_ >: HttpContent], range: InputStreamRange, chunkSize: Int) extends Subscription { - private val stream = range.inputStreamFromRangeStart() + private lazy val stream: InputStream = range.inputStreamFromRangeStart() private val demand = new AtomicLong(0L) private val position = new AtomicLong(range.range.flatMap(_.start).getOrElse(0L)) private val isCompleted = new AtomicBoolean(false) @@ -35,6 +33,10 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: } } + /** Non-blocking by itself, starts an asynchronous operation with blocking stream.readNBytes. Can be called multiple times by + * request(n), or concurrently by onComplete callback. The readingInProgress check ensures that calls are serialized. A + * stream.readNBytes operation will be started only if another isn't running. 
+ */ private def readNextChunkIfNeeded(): Unit = { if (demand.get() > 0 && !isCompleted.get() && readingInProgress.compareAndSet(false, true)) { val pos = position.get() @@ -66,7 +68,7 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: } } case Failure(e) => - stream.close() + val _ = Try(stream.close()) subscriber.onError(e) } } @@ -74,7 +76,7 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: override def cancel(): Unit = { isCompleted.set(true) - stream.close() + val _ = Try(stream.close()) } } } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 1be06c4338..add8d4ae97 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -768,7 +768,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } def maxContentLengthTests: List[Test] = { - val maxLength = 16484 // To generate a few chunks of default size 8192 + some extra bytes + val maxLength = 17000 // To generate a few chunks of default size 8192 + some extra bytes List( testPayloadTooLarge(in_string_out_string, maxLength), testPayloadTooLarge(in_byte_array_out_byte_array, maxLength), From c60e0149189a8a11f3aafd5c885939ebf469bb21 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 30 Nov 2023 08:56:05 +0100 Subject: [PATCH 19/35] Remove unused imports --- .../main/scala/sttp/tapir/server/interpreter/RequestBody.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala index 5292a8619b..91dc1fcacb 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala @@ -3,8 
+3,6 @@ package sttp.tapir.server.interpreter import sttp.capabilities.Streams import sttp.model.Part import sttp.tapir.model.ServerRequest -import sttp.tapir.AttributeKey -import sttp.tapir.EndpointInfo import sttp.tapir.{FileRange, RawBodyType, RawPart} case class MaxContentLength(value: Long) @@ -13,7 +11,6 @@ trait RequestBody[F[_], S] { val streams: Streams[S] def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream - } case class RawValue[R](value: R, createdFiles: Seq[FileRange] = Nil) From 667c59336004403dcc7e73a0729cdc7115c7a5ee Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 30 Nov 2023 09:03:23 +0100 Subject: [PATCH 20/35] Extract MaxContentLength to its own file --- .../sttp/tapir/server/interpreter/MaxContentLength.scala | 9 +++++++++ .../sttp/tapir/server/interpreter/RequestBody.scala | 2 -- 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala new file mode 100644 index 0000000000..5b9774f8a1 --- /dev/null +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala @@ -0,0 +1,9 @@ +package sttp.tapir.server.interpreter + +/** Can be used as an endpoint attribute. 
+ * @example + * {{{ + * endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(16384L)) + * }}} + */ +case class MaxContentLength(value: Long) diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala index 91dc1fcacb..7503040a36 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala @@ -5,8 +5,6 @@ import sttp.model.Part import sttp.tapir.model.ServerRequest import sttp.tapir.{FileRange, RawBodyType, RawPart} -case class MaxContentLength(value: Long) - trait RequestBody[F[_], S] { val streams: Streams[S] def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] From 60cdc3babdfd722cf23ebc3498ff0eb3c28ee3d5 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 30 Nov 2023 09:06:11 +0100 Subject: [PATCH 21/35] Remove unused imports and reformat --- .../sttp/tapir/server/interpreter/ServerInterpreter.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index a49ac9d6d3..2b34ec1df5 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -9,7 +9,6 @@ import sttp.tapir.model.ServerRequest import sttp.tapir.server.interceptor._ import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput} import sttp.tapir.server.{model, _} -import sttp.tapir.{AttributeKey, DecodeResult, EndpointIO, EndpointInfo, EndpointInput, TapirFile} import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile} import sttp.tapir.EndpointInfo import sttp.tapir.AttributeKey @@ -198,9 +197,8 @@ 
class ServerInterpreter[R, F[_], B, S]( .map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult) } } - .handleError { - case e: StreamMaxLengthExceededException => - (DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.Error("", e)): DecodeBasicInputsResult).unit + .handleError { case e: StreamMaxLengthExceededException => + (DecodeBasicInputsResult.Failure(bodyInput, DecodeResult.Error("", e)): DecodeBasicInputsResult).unit } } From bf55c763ed1e85c5851199d3c65f63cdedfa134b Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 30 Nov 2023 22:07:02 +0100 Subject: [PATCH 22/35] Extract common logic to NettyRequestBody --- .../netty/internal/NettyCatsRequestBody.scala | 83 +++++++------------ .../netty/loom/NettyIdRequestBody.scala | 61 ++++++++++++++ .../netty/loom/NettyIdServerInterpreter.scala | 4 +- .../internal/NettyFutureRequestBody.scala | 70 +++++----------- ...estBody.scala => NettyIdRequestBody.scala} | 6 +- .../FileWriterSubscriber.scala | 2 +- .../reactivestreams/NettyRequestBody.scala | 71 ++++++++++++++++ .../reactivestreams/SimpleSubscriber.scala | 17 ++-- .../netty/internal/NettyZioRequestBody.scala | 81 ++++++------------ .../tapir/server/tests/ServerBasicTests.scala | 8 +- 10 files changed, 229 insertions(+), 174 deletions(-) create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{NettyRequestBody.scala => NettyIdRequestBody.scala} (89%) create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index 3c30ed8579..8fb820e5db 100644 --- 
a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -2,74 +2,47 @@ package sttp.tapir.server.netty.internal import cats.effect.{Async, Sync} import cats.syntax.all._ -import org.playframework.netty.http.StreamedHttpRequest import fs2.Chunk import fs2.interop.reactivestreams.StreamSubscriber import fs2.io.file.{Files, Path} -import io.netty.buffer.ByteBufUtil -import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher import sttp.capabilities.fs2.Fs2Streams +import sttp.monad.MonadError +import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} +import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody +import sttp.tapir.TapirFile -import java.io.ByteArrayInputStream -import java.nio.ByteBuffer -import sttp.tapir.DecodeResult -import sttp.capabilities.StreamMaxLengthExceededException - -private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) - extends RequestBody[F, Fs2Streams[F]] { +private[netty] class NettyCatsRequestBody[F[_]: Async](val createFile: ServerRequest => F[TapirFile]) + extends NettyRequestBody[F, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { + override implicit val monad: MonadError[F] = new CatsMonadError() - def nettyRequestBytes: F[Array[Byte]] = serverRequest.underlying match { - case req: FullHttpRequest => - val buf = req.content() - maxBytes - .map(max => - if (buf.readableBytes() > max) - 
monad.raiseError[Array[Byte]](StreamMaxLengthExceededException(max)) - else - monad.delay(ByteBufUtil.getBytes(buf)) - ) - .getOrElse(monad.delay(ByteBufUtil.getBytes(buf))) - case _: StreamedHttpRequest => toStream(serverRequest, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) - case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) - } - bodyType match { - case RawBodyType.StringBody(charset) => nettyRequestBytes.map(bs => RawValue(new String(bs, charset))) - case RawBodyType.ByteArrayBody => - nettyRequestBytes.map(RawValue(_)) - case RawBodyType.ByteBufferBody => - nettyRequestBytes.map(bs => RawValue(ByteBuffer.wrap(bs))) - case RawBodyType.InputStreamBody => - nettyRequestBytes.map(bs => RawValue(new ByteArrayInputStream(bs))) - case RawBodyType.InputStreamRangeBody => - nettyRequestBytes.map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) - case RawBodyType.FileBody => - createFile(serverRequest) - .flatMap(tapirFile => { - toStream(serverRequest, maxBytes) - .through( - Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(tapirFile.toPath)) - ) - .compile - .drain - .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) - }) - case _: RawBodyType.MultipartBody => ??? 
- } - } + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] = + publisherToStream(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] = + toStream(serverRequest, maxBytes) + .through( + Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath)) + ) + .compile + .drain - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] + override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { val stream = fs2.Stream - .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.DefaultChunkSize)) - .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) + .eval(StreamSubscriber[F, HttpContent](NettyIdRequestBody.DefaultChunkSize)) + .flatMap(s => s.sub.stream(Sync[F].delay(publisher.subscribe(s)))) .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) } + + override def failedStream(e: => Throwable): streams.BinaryStream = + fs2.Stream.raiseError(e) + + override def emptyStream: streams.BinaryStream = + fs2.Stream.empty } diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala new file mode 100644 index 0000000000..3e8fb150ce --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -0,0 +1,61 @@ +package sttp.tapir.server.netty.loom + +import io.netty.handler.codec.http.FullHttpRequest +import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} +import sttp.capabilities +import sttp.monad.MonadError 
+import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} +import sttp.tapir.model.ServerRequest +import sttp.monad.syntax._ +import sttp.tapir.capabilities.NoStreams +import sttp.tapir.server.interpreter.{RawValue, RequestBody} + +import java.nio.ByteBuffer +import java.nio.file.Files +import io.netty.buffer.ByteBuf +import sttp.capabilities.StreamMaxLengthExceededException + +class NettyIdRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] { + + override val streams: capabilities.Streams[NoStreams] = NoStreams + + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = { + + def byteBuf: ByteBuf = { + val buf = nettyRequest(serverRequest).content() + maxBytes + .map(max => + if (buf.readableBytes() > max) + monadError.error[ByteBuf](StreamMaxLengthExceededException(max)) + else + monadError.unit(buf) + ) + .getOrElse(monadError.unit(buf)) + } + + def requestContentAsByteArray: Array[Byte] = byteBuf.map(ByteBufUtil.getBytes) + + bodyType match { + case RawBodyType.StringBody(charset) => byteBuf.map(buf => RawValue(buf.toString(charset))) + case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) + case RawBodyType.ByteBufferBody => requestContentAsByteArray.map(ba => RawValue(ByteBuffer.wrap(ba))) + case RawBodyType.InputStreamBody => byteBuf.map(buf => RawValue(new ByteBufInputStream(buf))) + case RawBodyType.InputStreamRangeBody => + byteBuf.map(buf => RawValue(InputStreamRange(() => new ByteBufInputStream(buf)))) + case RawBodyType.FileBody => + requestContentAsByteArray.flatMap(ba => + createFile(serverRequest) + .map(file => { + Files.write(file.toPath, ba) + RawValue(FileRange(file), Seq(FileRange(file))) + }) + ) + case _: RawBodyType.MultipartBody => ??? 
+ } + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + throw new UnsupportedOperationException() + + private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] +} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala index fdcb42fda2..251e21e258 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.netty.loom import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.internal.{NettyIdToResponseBody, NettyRequestBody, NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyIdToResponseBody, NettyIdRequestBody, NettyServerInterpreter, RunAsync} trait NettyIdServerInterpreter { def nettyServerOptions: NettyIdServerOptions @@ -12,7 +12,7 @@ trait NettyIdServerInterpreter { NettyServerInterpreter.toRoute[Id]( ses, nettyServerOptions.interceptors, - new NettyRequestBody(nettyServerOptions.createFile), + new NettyIdRequestBody(nettyServerOptions.createFile), new NettyIdToResponseBody, nettyServerOptions.deleteFile, new RunAsync[Id] { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index e87923f182..e23f7fc05f 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -1,64 +1,40 @@ package sttp.tapir.server.netty.internal -import 
io.netty.handler.codec.http.FullHttpRequest +import io.netty.handler.codec.http.HttpContent +import org.playframework.netty.http.StreamedHttpRequest +import org.reactivestreams.Publisher import sttp.capabilities -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} -import sttp.tapir.model.ServerRequest -import sttp.monad.syntax._ +import sttp.monad.{FutureMonad, MonadError} import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.{RawValue, RequestBody} +import sttp.tapir.model.ServerRequest +import sttp.tapir.TapirFile + +import scala.concurrent.{ExecutionContext, Future} -import java.nio.ByteBuffer -import org.playframework.netty.http.StreamedHttpRequest -import scala.concurrent.Future -import scala.concurrent.ExecutionContext import reactivestreams._ -import java.io.ByteArrayInputStream -class NettyFutureRequestBody(createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) - extends RequestBody[Future, NoStreams] { +class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) + extends NettyRequestBody[Future, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Future[RawValue[RAW]] = { + override implicit val monad: MonadError[Future] = new FutureMonad() - def byteBuf: Future[ByteBuffer] = - serverRequest.underlying match { - case r: StreamedHttpRequest => SimpleSubscriber.readAll(r, maxBytes) - // This can still happen in case an EmptyHttpRequest is received - case r: FullHttpRequest => Future.successful { - val underlyingBuf = r.content().nioBuffer() - if (underlyingBuf.hasArray()) - underlyingBuf - else - ByteBuffer.wrap(new Array[Byte](0)) - } - case other => Future.failed(new UnsupportedOperationException(s"Unexpected request type: ${other.getClass.getName()}")) - } - - def requestContentAsByteArray: 
Future[Array[Byte]] = byteBuf.map(_.array) + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] = + SimpleSubscriber.processAll(publisher, maxBytes) - bodyType match { - case RawBodyType.StringBody(charset) => requestContentAsByteArray.map(ba => RawValue(new String(ba, charset))) - case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) - case RawBodyType.ByteBufferBody => byteBuf.map(buf => RawValue(buf)) - // InputStreamBody and InputStreamRangeBody can be further optimized to avoid loading all data in memory - case RawBodyType.InputStreamBody => requestContentAsByteArray.map(ba => RawValue(new ByteArrayInputStream(ba))) - case RawBodyType.InputStreamRangeBody => - requestContentAsByteArray.map(ba => RawValue(InputStreamRange(() => new ByteArrayInputStream(ba)))) - case RawBodyType.FileBody => - createFile(serverRequest) - .flatMap(file => - FileWriterSubscriber - .writeAll(nettyRequest(serverRequest), file.toPath, maxBytes) - .map(_ => RawValue(FileRange(file), Seq(FileRange(file)))) - ) - case _: RawBodyType.MultipartBody => ??? 
+ override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] = + serverRequest.underlying match { + case r: StreamedHttpRequest => FileWriterSubscriber.processAll(r, file.toPath, maxBytes) + case _ => monad.unit(()) } - } - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = throw new UnsupportedOperationException() - private def nettyRequest(serverRequest: ServerRequest): StreamedHttpRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] + override def emptyStream: streams.BinaryStream = + throw new UnsupportedOperationException() + + override def failedStream(e: => Throwable): streams.BinaryStream = + throw new UnsupportedOperationException() } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala similarity index 89% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala index 82ced7e5b3..372201ca9d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala @@ -13,10 +13,9 @@ import sttp.tapir.server.interpreter.{RawValue, RequestBody} import java.nio.ByteBuffer import java.nio.file.Files import io.netty.buffer.ByteBuf -import sttp.tapir.DecodeResult import sttp.capabilities.StreamMaxLengthExceededException -class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit +class NettyIdRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit monadError: MonadError[F] ) extends 
RequestBody[F, NoStreams] { @@ -36,7 +35,6 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit .getOrElse(monadError.unit(buf)) } - /** [[ByteBufUtil.getBytes(io.netty.buffer.ByteBuf)]] copies buffer without affecting reader index of the original. */ def requestContentAsByteArray: F[Array[Byte]] = byteBuf.map(ByteBufUtil.getBytes) bodyType match { @@ -64,6 +62,6 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] } -private[internal] object NettyRequestBody { +private[internal] object NettyIdRequestBody { val DefaultChunkSize = 8192 } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala index db20e5a467..b912a75d15 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -53,7 +53,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon } object FileWriterSubscriber { - def writeAll(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Future[Unit] = { + def processAll(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Future[Unit] = { val subscriber = new FileWriterSubscriber(path) publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) subscriber.future diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala new file mode 100644 
index 0000000000..f7ad44cacd --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala @@ -0,0 +1,71 @@ +package sttp.tapir.server.netty.internal.reactivestreams + +import io.netty.buffer.ByteBufUtil +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} +import org.playframework.netty.http.StreamedHttpRequest +import org.reactivestreams.Publisher +import sttp.capabilities.StreamMaxLengthExceededException +import sttp.monad.MonadError +import sttp.monad.syntax._ +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.RequestBody +import sttp.tapir.RawBodyType +import sttp.tapir.TapirFile +import sttp.tapir.server.interpreter.RawValue +import sttp.tapir.FileRange +import sttp.tapir.InputStreamRange +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer + +trait NettyRequestBody[F[_], S] extends RequestBody[F, S] { + + implicit def monad: MonadError[F] + def createFile: ServerRequest => F[TapirFile] + def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] + def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] + def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream + def failedStream(e: => Throwable): streams.BinaryStream + def emptyStream: streams.BinaryStream + + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = { + bodyType match { + case RawBodyType.StringBody(charset) => readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset))) + case RawBodyType.ByteArrayBody => + readAllBytes(serverRequest, maxBytes).map(RawValue(_)) + case RawBodyType.ByteBufferBody => + readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs))) + case RawBodyType.InputStreamBody => + 
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs))) + case RawBodyType.InputStreamRangeBody => + readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) + case RawBodyType.FileBody => + for { + file <- createFile(serverRequest) + _ <- writeToFile(serverRequest, file, maxBytes) + } + yield RawValue(FileRange(file), Seq(FileRange(file))) + case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException()) + } + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + serverRequest.underlying match { + case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // means EmptyHttpRequest, but that class is not public + emptyStream + case publisher: StreamedHttpRequest => + publisherToStream(publisher, maxBytes) + case other => + failedStream(new UnsupportedOperationException(s"Unexpected Netty request of type: ${other.getClass().getName()}")) + } + + // Used by different netty backends to handle raw body input + def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = + serverRequest.underlying match { + case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => + monad.unit(Array[Byte](0)) + case req: StreamedHttpRequest => + publisherToBytes(req, maxBytes) + case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index 056990ca9c..0f2b10bbb4 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -4,18 +4,17 @@ import io.netty.buffer.ByteBufUtil import io.netty.handler.codec.http.HttpContent import org.reactivestreams.{Publisher, Subscription} -import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} -private[netty] class SimpleSubscriber() extends PromisingSubscriber[ByteBuffer, HttpContent] { +private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], HttpContent] { private var subscription: Subscription = _ private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() private var size = 0 - private val resultPromise = Promise[ByteBuffer]() + private val resultPromise = Promise[Array[Byte]]() - override def future: Future[ByteBuffer] = resultPromise.future + override def future: Future[Array[Byte]] = resultPromise.future override def onSubscribe(s: Subscription): Unit = { subscription = s @@ -35,16 +34,18 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[ByteBuffer, } override def onComplete(): Unit = { - val result = ByteBuffer.allocate(size) - chunks.asScala.foreach(result.put) - result.flip() + val result = new Array[Byte](size) + chunks.asScala.foldLeft(0)((currentPosition, array) => { + System.arraycopy(array, 0, result, currentPosition, array.length) + currentPosition + array.length + }) chunks.clear() resultPromise.success(result) } } object SimpleSubscriber { - def readAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[ByteBuffer] = { + def processAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] = { val subscriber = new SimpleSubscriber() publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) subscriber.future diff --git 
a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala index a80e69a652..45b53676db 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala @@ -1,72 +1,41 @@ package sttp.tapir.server.netty.internal -import org.playframework.netty.http.StreamedHttpRequest -import io.netty.buffer.ByteBufUtil -import io.netty.handler.codec.http.FullHttpRequest +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher import sttp.capabilities.zio.ZioStreams -import sttp.tapir.RawBodyType._ +import sttp.monad.MonadError +import sttp.tapir.TapirFile import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} +import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody +import sttp.tapir.ztapir.RIOMonadError import zio.interop.reactivestreams._ import zio.stream.{ZStream, _} -import zio.{Chunk, RIO, ZIO} +import zio.{Chunk, RIO} -import java.io.ByteArrayInputStream -import java.nio.ByteBuffer -import sttp.tapir.DecodeResult -import sttp.capabilities.StreamMaxLengthExceededException - -private[netty] class NettyZioRequestBody[Env](createFile: ServerRequest => RIO[Env, TapirFile]) - extends RequestBody[RIO[Env, *], ZioStreams] { +private[netty] class NettyZioRequestBody[Env](val createFile: ServerRequest => RIO[Env, TapirFile]) + extends NettyRequestBody[RIO[Env, *], ZioStreams] { override val streams: ZioStreams = ZioStreams - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): RIO[Env, RawValue[R]] = { - - def nettyRequestBytes: RIO[Env, Array[Byte]] = serverRequest.underlying match { - 
case req: FullHttpRequest => - val buf = req.content() - maxBytes - .map(max => - if (buf.readableBytes() > max) - ZIO.fail(StreamMaxLengthExceededException(max)) - else - ZIO.succeed(ByteBufUtil.getBytes(buf)) - ) - .getOrElse(ZIO.succeed(ByteBufUtil.getBytes(buf))) + override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env] - case _: StreamedHttpRequest => toStream(serverRequest, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) - case other => ZIO.fail(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) - } - bodyType match { - case StringBody(charset) => nettyRequestBytes.map(bs => RawValue(new String(bs, charset))) + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): RIO[Env, Array[Byte]] = + publisherToStream(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) - case ByteArrayBody => - nettyRequestBytes.map(RawValue(_)) - case ByteBufferBody => - nettyRequestBytes.map(bs => RawValue(ByteBuffer.wrap(bs))) - case InputStreamBody => - nettyRequestBytes.map(bs => RawValue(new ByteArrayInputStream(bs))) - case InputStreamRangeBody => - nettyRequestBytes.map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) - case FileBody => - createFile(serverRequest) - .flatMap(tapirFile => { - toStream(serverRequest, maxBytes) - .run(ZSink.fromFile(tapirFile)) - .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) - }) - case MultipartBody(partTypes, defaultType) => - throw new java.lang.UnsupportedOperationException() - } - } + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): RIO[Env, Unit] = + toStream(serverRequest, maxBytes).run(ZSink.fromFile(file)).map(_ => ()) - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val stream = serverRequest.underlying - .asInstanceOf[StreamedHttpRequest] - .toZIOStream() - 
.flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) + override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { + val stream = + Adapters + .publisherToStream(publisher, 16) + .flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) } + + override def failedStream(e: => Throwable): streams.BinaryStream = + ZStream.fail(e) + + override def emptyStream: streams.BinaryStream = + ZStream.empty } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index add8d4ae97..a560f1b28f 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -779,7 +779,13 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( testPayloadWithinLimit(in_input_stream_out_input_stream, maxLength), testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), testPayloadWithinLimit(in_file_out_file, maxLength), - testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) + testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength), + testServer( + in_string_out_string, + "testkc" + )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => + basicRequest.post(uri"$baseUri/api/echo").body("").send(backend).map(_.code shouldBe StatusCode.Ok) + } ) } From 337ebbef4c469a0162fabd6023dd427be18706bc Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 10:23:43 +0100 Subject: [PATCH 23/35] Refactoring --- .../cats/NettyCatsServerInterpreter.scala | 3 +- .../cats/internal/Fs2StreamCompatible.scala | 61 +++++++++++ .../netty/internal/NettyCatsRequestBody.scala | 2 +- .../internal/NettyCatsToResponseBody.scala | 101 ------------------ 
.../netty/loom/NettyIdRequestBody.scala | 12 ++- .../netty/loom/NettyIdServerInterpreter.scala | 4 +- .../netty/NettyFutureServerInterpreter.scala | 4 +- .../netty/internal/NettyIdRequestBody.scala | 67 ------------ .../internal/NettyIdToResponseBody.scala | 30 ------ ...seBody.scala => NettyToResponseBody.scala} | 14 +-- .../netty/internal/StreamCompatible.scala | 3 +- .../InputStreamPublisher.scala | 49 +++++---- .../reactivestreams/NettyRequestBody.scala | 6 +- .../zio/internal/ZioStreamCompatible.scala | 3 - 14 files changed, 113 insertions(+), 246 deletions(-) create mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala delete mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala delete mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala delete mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{NettyFutureToResponseBody.scala => NettyToResponseBody.scala} (87%) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 92830b8652..71a87a4f98 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -2,6 +2,7 @@ package sttp.tapir.server.netty.cats import cats.effect.Async import cats.effect.std.Dispatcher +import internal.Fs2StreamCompatible import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.monad.syntax._ @@ -32,7 +33,7 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new 
ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile), - new NettyCatsToResponseBody(nettyServerOptions.dispatcher), + new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala new file mode 100644 index 0000000000..5f61e58a0f --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -0,0 +1,61 @@ +package sttp.tapir.server.netty.cats.internal + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.Publisher +import sttp.tapir.FileRange +import sttp.tapir.server.netty.internal._ + +import java.io.InputStream +import cats.effect.std.Dispatcher +import sttp.capabilities.fs2.Fs2Streams +import fs2.io.file.Path +import fs2.io.file.Files +import cats.effect.kernel.Async +import fs2.io.file.Flags +import fs2.interop.reactivestreams.StreamUnicastPublisher +import cats.effect.kernel.Sync +import fs2.Chunk + +private[cats] object Fs2StreamCompatible { + + def apply[F[_]: Async](dispatcher: Dispatcher[F]): StreamCompatible[Fs2Streams[F]] = { + new StreamCompatible[Fs2Streams[F]] { + override val streams: Fs2Streams[F] = Fs2Streams[F] + + override def fromFile(fileRange: FileRange): streams.BinaryStream = { + val path = Path.fromNioPath(fileRange.file.toPath) + fileRange.range + .flatMap(r => + r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, NettyToResponseBody.DefaultChunkSize, s._1, s._2)) + ) + .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) + } + + 
override def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream = + length match { + case Some(limitedLength) => inputStreamToFs2(is).take(limitedLength) + case None => inputStreamToFs2(is) + } + + override def asPublisher(stream: fs2.Stream[F, Byte]): Publisher[HttpContent] = + // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated + // dispatcher, which results in a Resource[], which is hard to afford here + StreamUnicastPublisher( + stream + .chunkLimit(NettyToResponseBody.DefaultChunkSize) + .map { chunk => + val bytes: Chunk.ArraySlice[Byte] = chunk.compact + new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) + }, + dispatcher + ) + + private def inputStreamToFs2(inputStream: () => InputStream) = + fs2.io.readInputStream( + Sync[F].blocking(inputStream()), + NettyToResponseBody.DefaultChunkSize + ) + } + } +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index 8fb820e5db..321f6c8983 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -34,7 +34,7 @@ private[netty] class NettyCatsRequestBody[F[_]: Async](val createFile: ServerReq override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { val stream = fs2.Stream - .eval(StreamSubscriber[F, HttpContent](NettyIdRequestBody.DefaultChunkSize)) + .eval(StreamSubscriber[F, HttpContent](DefaultChunkSize)) .flatMap(s => s.sub.stream(Sync[F].delay(publisher.subscribe(s)))) .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) 
diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala deleted file mode 100644 index 3a247f9d61..0000000000 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ /dev/null @@ -1,101 +0,0 @@ -package sttp.tapir.server.netty.internal - -import cats.effect.kernel.{Async, Sync} -import cats.effect.std.Dispatcher -import fs2.Chunk -import fs2.interop.reactivestreams._ -import fs2.io.file.{Files, Flags, Path} -import io.netty.buffer.Unpooled -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import org.reactivestreams.Publisher -import sttp.capabilities.fs2.Fs2Streams -import sttp.model.HasHeaders -import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent._ -import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} - -import java.io.InputStream -import java.nio.ByteBuffer -import java.nio.charset.Charset - -class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { - override val streams: Fs2Streams[F] = Fs2Streams[F] - - override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { - bodyType match { - - case RawBodyType.StringBody(charset) => - val bytes = v.asInstanceOf[String].getBytes(charset) - (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) - - case RawBodyType.ByteArrayBody => - val bytes = v.asInstanceOf[Array[Byte]] - (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) - - case RawBodyType.ByteBufferBody => - val byteBuffer = 
v.asInstanceOf[ByteBuffer] - (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) - - case RawBodyType.InputStreamBody => - val stream = inputStreamToFs2(() => v) - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) - - case RawBodyType.InputStreamRangeBody => - val stream = v.range - .map(range => inputStreamToFs2(v.inputStreamFromRangeStart).take(range.contentLength)) - .getOrElse(inputStreamToFs2(v.inputStream)) - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) - - case RawBodyType.FileBody => - val tapirFile = v - val path = Path.fromNioPath(tapirFile.file.toPath) - val stream = tapirFile.range - .flatMap(r => - r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, NettyFutureToResponseBody.DefaultChunkSize, s._1, s._2)) - ) - .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyFutureToResponseBody.DefaultChunkSize, Flags.Read)) - - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) - - case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException - } - } - - private def inputStreamToFs2(inputStream: () => InputStream) = - fs2.io.readInputStream( - Sync[F].blocking(inputStream()), - NettyFutureToResponseBody.DefaultChunkSize - ) - - private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { - // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated - // dispatcher, which results in a Resource[], which is hard to afford here - StreamUnicastPublisher( - stream - .chunkLimit(NettyFutureToResponseBody.DefaultChunkSize) - .map { chunk => - val bytes: Chunk.ArraySlice[Byte] = chunk.compact - - new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) - }, - 
dispatcher - ) - } - - override def fromStreamValue( - v: streams.BinaryStream, - headers: HasHeaders, - format: CodecFormat, - charset: Option[Charset] - ): NettyResponse = - (ctx: ChannelHandlerContext) => { - new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) - } - - override def fromWebSocketPipe[REQ, RESP]( - pipe: streams.Pipe[REQ, RESP], - o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] - ): NettyResponse = throw new UnsupportedOperationException -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala index 3e8fb150ce..8eab65b298 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -14,11 +14,21 @@ import java.nio.ByteBuffer import java.nio.file.Files import io.netty.buffer.ByteBuf import sttp.capabilities.StreamMaxLengthExceededException +import sttp.tapir.server.netty.internal.reactivestreams.SimpleSubscriber -class NettyIdRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] { +class NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { + override implicit val monad: MonadError[Id] = idMonad override val streams: capabilities.Streams[NoStreams] = NoStreams + def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = + SimpleSubscriber.processAll(publisher, maxBytes) + + def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = + Files.write(fi) + + def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], 
maxBytes: Option[Long]): RawValue[RAW] = { def byteBuf: ByteBuf = { diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala index 251e21e258..444a4c5e5d 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.netty.loom import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.internal.{NettyIdToResponseBody, NettyIdRequestBody, NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyToResponseBody, NettyIdRequestBody, NettyServerInterpreter, RunAsync} trait NettyIdServerInterpreter { def nettyServerOptions: NettyIdServerOptions @@ -13,7 +13,7 @@ trait NettyIdServerInterpreter { ses, nettyServerOptions.interceptors, new NettyIdRequestBody(nettyServerOptions.createFile), - new NettyIdToResponseBody, + new NettyToResponseBody[Id], nettyServerOptions.deleteFile, new RunAsync[Id] { override def apply[T](f: => Id[T]): Unit = { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala index deec6449c6..246c49c6bd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala @@ -7,7 +7,7 @@ import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} import scala.concurrent.{ExecutionContext, Future} import sttp.tapir.server.netty.internal.NettyFutureRequestBody -import sttp.tapir.server.netty.internal.NettyFutureToResponseBody +import 
sttp.tapir.server.netty.internal.NettyToResponseBody trait NettyFutureServerInterpreter { def nettyServerOptions: NettyFutureServerOptions @@ -24,7 +24,7 @@ trait NettyFutureServerInterpreter { ses, nettyServerOptions.interceptors, new NettyFutureRequestBody(nettyServerOptions.createFile), - new NettyFutureToResponseBody, + new NettyToResponseBody[Future](), nettyServerOptions.deleteFile, FutureRunAsync ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala deleted file mode 100644 index 372201ca9d..0000000000 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdRequestBody.scala +++ /dev/null @@ -1,67 +0,0 @@ -package sttp.tapir.server.netty.internal - -import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} -import io.netty.handler.codec.http.FullHttpRequest -import sttp.capabilities -import sttp.monad.MonadError -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} -import sttp.tapir.model.ServerRequest -import sttp.monad.syntax._ -import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.{RawValue, RequestBody} - -import java.nio.ByteBuffer -import java.nio.file.Files -import io.netty.buffer.ByteBuf -import sttp.capabilities.StreamMaxLengthExceededException - -class NettyIdRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit - monadError: MonadError[F] -) extends RequestBody[F, NoStreams] { - - override val streams: capabilities.Streams[NoStreams] = NoStreams - - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = { - - def byteBuf: F[ByteBuf] = { - val buf = nettyRequest(serverRequest).content() - maxBytes - .map(max => - if (buf.readableBytes() > max) - monadError.error[ByteBuf](StreamMaxLengthExceededException(max)) - else - monadError.unit(buf) - ) - 
.getOrElse(monadError.unit(buf)) - } - - def requestContentAsByteArray: F[Array[Byte]] = byteBuf.map(ByteBufUtil.getBytes) - - bodyType match { - case RawBodyType.StringBody(charset) => byteBuf.map(buf => RawValue(buf.toString(charset))) - case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) - case RawBodyType.ByteBufferBody => requestContentAsByteArray.map(ba => RawValue(ByteBuffer.wrap(ba))) - case RawBodyType.InputStreamBody => byteBuf.map(buf => RawValue(new ByteBufInputStream(buf))) - case RawBodyType.InputStreamRangeBody => - byteBuf.map(buf => RawValue(InputStreamRange(() => new ByteBufInputStream(buf)))) - case RawBodyType.FileBody => - requestContentAsByteArray.flatMap(ba => - createFile(serverRequest) - .map(file => { - Files.write(file.toPath, ba) - RawValue(FileRange(file), Seq(FileRange(file))) - }) - ) - case _: RawBodyType.MultipartBody => ??? - } - } - - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = - throw new UnsupportedOperationException() - - private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] -} - -private[internal] object NettyIdRequestBody { - val DefaultChunkSize = 8192 -} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala deleted file mode 100644 index 75c5b408eb..0000000000 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyIdToResponseBody.scala +++ /dev/null @@ -1,30 +0,0 @@ -package sttp.tapir.server.netty.internal - -import sttp.capabilities -import sttp.model.HasHeaders -import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.{CodecFormat, RawBodyType} -import sttp.tapir.WebSocketBodyOutput 
-import java.nio.charset.Charset - -class NettyIdToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { - - override val streams: capabilities.Streams[NoStreams] = NoStreams - - override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { - ??? // TODO - } - override def fromStreamValue( - v: streams.BinaryStream, - headers: HasHeaders, - format: CodecFormat, - charset: Option[Charset] - ): NettyResponse = ??? - - override def fromWebSocketPipe[REQ, RESP]( - pipe: streams.Pipe[REQ, RESP], - o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, NoStreams] - ): NettyResponse = throw new UnsupportedOperationException -} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala similarity index 87% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureToResponseBody.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index 37f9de2af4..0d16c48142 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -6,20 +6,20 @@ import io.netty.handler.codec.http.HttpContent import org.reactivestreams.Publisher import sttp.capabilities import sttp.model.HasHeaders +import sttp.monad.MonadError import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} -import sttp.tapir.server.netty.internal.NettyFutureToResponseBody.DefaultChunkSize +import sttp.tapir.server.netty.internal.NettyToResponseBody.DefaultChunkSize 
import sttp.tapir.server.netty.internal.reactivestreams.{FileRangePublisher, InputStreamPublisher} import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset -import scala.concurrent.ExecutionContext -class NettyFutureToResponseBody(implicit ec: ExecutionContext) extends ToResponseBody[NettyResponse, NoStreams] { +class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { @@ -44,15 +44,13 @@ class NettyFutureToResponseBody(implicit ec: ExecutionContext) extends ToRespons case RawBodyType.FileBody => { (ctx: ChannelHandlerContext) => ReactivePublisherNettyResponseContent(ctx.newPromise(), wrap(v)) - } - case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } } private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = { - new InputStreamPublisher(streamRange, DefaultChunkSize) + new InputStreamPublisher[F](streamRange, DefaultChunkSize) } private def wrap(fileRange: FileRange): Publisher[HttpContent] = { @@ -76,8 +74,6 @@ class NettyFutureToResponseBody(implicit ec: ExecutionContext) extends ToRespons ): NettyResponse = throw new UnsupportedOperationException } -private[internal] object NettyFutureToResponseBody { +object NettyToResponseBody { val DefaultChunkSize = 8192 - val IncludingLastOffset = 1 - val ReadOnlyAccessMode = "r" } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index a64d92dfca..e4ca4d72a3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -3,7 +3,7 @@ package sttp.tapir.server.netty.internal import io.netty.handler.codec.http.HttpContent import org.reactivestreams.Publisher import sttp.capabilities.Streams -import sttp.tapir.{FileRange, TapirFile} +import sttp.tapir.FileRange import java.io.InputStream @@ -11,7 +11,6 @@ private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S def fromFile(file: FileRange): streams.BinaryStream def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream - def fromNettyStream(s: Publisher[HttpContent]): streams.BinaryStream def asPublisher(s: streams.BinaryStream): Publisher[HttpContent] def publisherFromFile(file: FileRange): Publisher[HttpContent] = diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala index 517d5b11b3..7f16c0a108 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala @@ -7,10 +7,11 @@ import sttp.tapir.InputStreamRange import java.io.InputStream import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} -import scala.concurrent.{ExecutionContext, Future, blocking} -import scala.util.{Failure, Success, Try} +import scala.util.Try +import sttp.monad.MonadError +import sttp.monad.syntax._ -class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: ExecutionContext) extends Publisher[HttpContent] { +class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implicit monad: MonadError[F]) extends Publisher[HttpContent] { override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { if (subscriber == 
null) throw new NullPointerException("Subscriber cannot be null") val subscription = new InputStreamSubscription(subscriber, range, chunkSize) @@ -44,32 +45,34 @@ class InputStreamPublisher(range: InputStreamRange, chunkSize: Int)(implicit ec: case Some(endPos) if pos + chunkSize > endPos => (endPos - pos + 1).toInt case _ => chunkSize } - Future { - blocking { + + val _ = monad + .blocking( stream.readNBytes(expectedBytes) - } - } - .onComplete { - case Success(bytes) => - val bytesRead = bytes.length - if (bytesRead == 0) { + ) + .map { bytes => + val bytesRead = bytes.length + if (bytesRead == 0) { + cancel() + subscriber.onComplete() + } else { + position.addAndGet(bytesRead.toLong) + subscriber.onNext(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes))) + if (bytesRead < expectedBytes) { cancel() subscriber.onComplete() } else { - position.addAndGet(bytesRead.toLong) - subscriber.onNext(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes))) - if (bytesRead < expectedBytes) { - cancel() - subscriber.onComplete() - } else { - demand.decrementAndGet() - readingInProgress.set(false) - readNextChunkIfNeeded() - } + demand.decrementAndGet() + readingInProgress.set(false) + readNextChunkIfNeeded() } - case Failure(e) => + } + } + .handleError { + case e => { val _ = Try(stream.close()) - subscriber.onError(e) + monad.unit(subscriber.onError(e)) + } } } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala index f7ad44cacd..e075d76559 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala @@ -1,11 +1,9 @@ package sttp.tapir.server.netty.internal.reactivestreams -import io.netty.buffer.ByteBufUtil import 
io.netty.buffer.Unpooled import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} import org.playframework.netty.http.StreamedHttpRequest import org.reactivestreams.Publisher -import sttp.capabilities.StreamMaxLengthExceededException import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.model.ServerRequest @@ -20,6 +18,7 @@ import java.nio.ByteBuffer trait NettyRequestBody[F[_], S] extends RequestBody[F, S] { + val DefaultChunkSize = 8192 implicit def monad: MonadError[F] def createFile: ServerRequest => F[TapirFile] def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] @@ -43,8 +42,7 @@ trait NettyRequestBody[F[_], S] extends RequestBody[F, S] { for { file <- createFile(serverRequest) _ <- writeToFile(serverRequest, file, maxBytes) - } - yield RawValue(FileRange(file), Seq(FileRange(file))) + } yield RawValue(FileRange(file), Seq(FileRange(file))) case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException()) } } diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index c5ecd41fde..2251991aac 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -45,9 +45,6 @@ private[zio] object ZioStreamCompatible { .run(stream.mapChunks(c => Chunk.single(new DefaultHttpContent(Unpooled.wrappedBuffer(c.toArray)): HttpContent)).toPublisher) .getOrThrowFiberFailure() ) - - override def fromNettyStream(publisher: Publisher[HttpContent]): Stream[Throwable, Byte] = - publisher.toZIOStream().mapConcatChunk(httpContent => Chunk.fromByteBuffer(httpContent.content.nioBuffer())) } } } From fea96740b20c6e05c6be792252099beb7dbf19f1 Mon Sep 17 00:00:00 2001 From: 
kciesielski Date: Fri, 1 Dec 2023 11:22:05 +0100 Subject: [PATCH 24/35] More refactoring and extractions --- .../cats/NettyCatsServerInterpreter.scala | 3 +- .../cats/internal/Fs2StreamCompatible.scala | 19 +++++++- .../cats/internal/NettyCatsRequestBody.scala | 33 +++++++++++++ .../netty/internal/NettyCatsRequestBody.scala | 48 ------------------- .../internal/NettyFutureRequestBody.scala | 13 ++--- .../internal/NettyStreamingRequestBody.scala | 24 ++++++++++ .../netty/internal/StreamCompatible.scala | 4 ++ .../reactivestreams/NettyRequestBody.scala | 16 ++----- .../netty/internal/NettyZioRequestBody.scala | 41 ---------------- .../netty/zio/NettyZioServerInterpreter.scala | 4 +- .../zio/internal/NettyZioRequestBody.scala | 27 +++++++++++ .../zio/internal/ZioStreamCompatible.scala | 14 ++++++ 12 files changed, 129 insertions(+), 117 deletions(-) create mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala delete mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala delete mode 100644 server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala create mode 100644 server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 71a87a4f98..e9e223461c 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -12,6 +12,7 @@ import sttp.tapir.server.interceptor.RequestResult 
import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} +import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} trait NettyCatsServerInterpreter[F[_]] { @@ -32,7 +33,7 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), - new NettyCatsRequestBody(createFile), + new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index 5f61e58a0f..d96b6edb56 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -16,10 +16,11 @@ import fs2.io.file.Flags import fs2.interop.reactivestreams.StreamUnicastPublisher import cats.effect.kernel.Sync import fs2.Chunk +import fs2.interop.reactivestreams.StreamSubscriber -private[cats] object Fs2StreamCompatible { +object Fs2StreamCompatible { - def apply[F[_]: Async](dispatcher: Dispatcher[F]): StreamCompatible[Fs2Streams[F]] = { +private[cats] def apply[F[_]: Async](dispatcher: Dispatcher[F]): StreamCompatible[Fs2Streams[F]] = { new StreamCompatible[Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] @@ -51,6 +52,20 @@ private[cats] object Fs2StreamCompatible { 
dispatcher ) + override def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { + val stream = fs2.Stream + .eval(StreamSubscriber[F, HttpContent](NettyToResponseBody.DefaultChunkSize)) + .flatMap(s => s.sub.stream(Sync[F].delay(publisher.subscribe(s)))) + .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) + maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) + } + + override def failedStream(e: => Throwable): streams.BinaryStream = + fs2.Stream.raiseError(e) + + override def emptyStream: streams.BinaryStream = + fs2.Stream.empty + private def inputStreamToFs2(inputStream: () => InputStream) = fs2.io.readInputStream( Sync[F].blocking(inputStream()), diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala new file mode 100644 index 0000000000..e1a762ae70 --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala @@ -0,0 +1,33 @@ +package sttp.tapir.server.netty.cats.internal + +import cats.effect.Async +import cats.syntax.all._ +import fs2.Chunk +import fs2.io.file.{Files, Path} +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher +import sttp.capabilities.fs2.Fs2Streams +import sttp.monad.MonadError +import sttp.tapir.TapirFile +import sttp.tapir.integ.cats.effect.CatsMonadError +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} + +private[cats] class NettyCatsRequestBody[F[_]: Async]( + val createFile: ServerRequest => F[TapirFile], + val streamCompatible: StreamCompatible[Fs2Streams[F]] +) extends NettyStreamingRequestBody[F, Fs2Streams[F]] { + + override implicit val monad: MonadError[F] = new CatsMonadError() + + override 
def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] = + streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] = + (toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream]) + .through( + Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath)) + ) + .compile + .drain +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala deleted file mode 100644 index 321f6c8983..0000000000 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ /dev/null @@ -1,48 +0,0 @@ -package sttp.tapir.server.netty.internal - -import cats.effect.{Async, Sync} -import cats.syntax.all._ -import fs2.Chunk -import fs2.interop.reactivestreams.StreamSubscriber -import fs2.io.file.{Files, Path} -import io.netty.handler.codec.http.HttpContent -import org.reactivestreams.Publisher -import sttp.capabilities.fs2.Fs2Streams -import sttp.monad.MonadError -import sttp.tapir.integ.cats.effect.CatsMonadError -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody -import sttp.tapir.TapirFile - -private[netty] class NettyCatsRequestBody[F[_]: Async](val createFile: ServerRequest => F[TapirFile]) - extends NettyRequestBody[F, Fs2Streams[F]] { - - override val streams: Fs2Streams[F] = Fs2Streams[F] - - override implicit val monad: MonadError[F] = new CatsMonadError() - - override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] = - publisherToStream(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) - - override def writeToFile(serverRequest: ServerRequest, 
file: TapirFile, maxBytes: Option[Long]): F[Unit] = - toStream(serverRequest, maxBytes) - .through( - Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath)) - ) - .compile - .drain - - override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { - val stream = fs2.Stream - .eval(StreamSubscriber[F, HttpContent](DefaultChunkSize)) - .flatMap(s => s.sub.stream(Sync[F].delay(publisher.subscribe(s)))) - .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) - maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) - } - - override def failedStream(e: => Throwable): streams.BinaryStream = - fs2.Stream.raiseError(e) - - override def emptyStream: streams.BinaryStream = - fs2.Stream.empty -} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index e23f7fc05f..846b5bd981 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -5,9 +5,9 @@ import org.playframework.netty.http.StreamedHttpRequest import org.reactivestreams.Publisher import sttp.capabilities import sttp.monad.{FutureMonad, MonadError} +import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams import sttp.tapir.model.ServerRequest -import sttp.tapir.TapirFile import scala.concurrent.{ExecutionContext, Future} @@ -17,7 +17,6 @@ class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile]) extends NettyRequestBody[Future, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams - override implicit val monad: MonadError[Future] = new FutureMonad() override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): 
Future[Array[Byte]] = @@ -28,13 +27,7 @@ class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile]) case r: StreamedHttpRequest => FileWriterSubscriber.processAll(r, file.toPath, maxBytes) case _ => monad.unit(()) } - - override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = - throw new UnsupportedOperationException() - - override def emptyStream: streams.BinaryStream = - throw new UnsupportedOperationException() - - override def failedStream(e: => Throwable): streams.BinaryStream = + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = throw new UnsupportedOperationException() } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala new file mode 100644 index 0000000000..cdf19f1f4e --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala @@ -0,0 +1,24 @@ +package sttp.tapir.server.netty.internal + +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.FullHttpRequest +import org.playframework.netty.http.StreamedHttpRequest +import sttp.capabilities.Streams +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody + +trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[F, S] { + + val streamCompatible: StreamCompatible[S] + override val streams = streamCompatible.streams + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + (serverRequest.underlying match { + case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // means EmptyHttpRequest, but that class is not public + streamCompatible.emptyStream + case publisher: StreamedHttpRequest => + 
streamCompatible.fromPublisher(publisher, maxBytes) + case other => + streamCompatible.failedStream(new UnsupportedOperationException(s"Unexpected Netty request of type: ${other.getClass().getName()}")) + }).asInstanceOf[streams.BinaryStream] +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index e4ca4d72a3..3e703b7196 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -11,8 +11,12 @@ private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S def fromFile(file: FileRange): streams.BinaryStream def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream + def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream def asPublisher(s: streams.BinaryStream): Publisher[HttpContent] + def failedStream(e: => Throwable): streams.BinaryStream + def emptyStream: streams.BinaryStream + def publisherFromFile(file: FileRange): Publisher[HttpContent] = asPublisher(fromFile(file)) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala index e075d76559..8c6f46889f 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala @@ -10,22 +10,21 @@ import sttp.tapir.model.ServerRequest import sttp.tapir.server.interpreter.RequestBody import sttp.tapir.RawBodyType import sttp.tapir.TapirFile +import sttp.tapir.server.netty.internal.StreamCompatible import 
sttp.tapir.server.interpreter.RawValue import sttp.tapir.FileRange import sttp.tapir.InputStreamRange import java.io.ByteArrayInputStream import java.nio.ByteBuffer +import sttp.capabilities.Streams -trait NettyRequestBody[F[_], S] extends RequestBody[F, S] { +trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { val DefaultChunkSize = 8192 implicit def monad: MonadError[F] def createFile: ServerRequest => F[TapirFile] def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] - def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream - def failedStream(e: => Throwable): streams.BinaryStream - def emptyStream: streams.BinaryStream override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = { bodyType match { @@ -47,15 +46,6 @@ trait NettyRequestBody[F[_], S] extends RequestBody[F, S] { } } - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = - serverRequest.underlying match { - case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // means EmptyHttpRequest, but that class is not public - emptyStream - case publisher: StreamedHttpRequest => - publisherToStream(publisher, maxBytes) - case other => - failedStream(new UnsupportedOperationException(s"Unexpected Netty request of type: ${other.getClass().getName()}")) - } // Used by different netty backends to handle raw body input def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala deleted file mode 100644 index 45b53676db..0000000000 --- 
a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/internal/NettyZioRequestBody.scala +++ /dev/null @@ -1,41 +0,0 @@ -package sttp.tapir.server.netty.internal - -import io.netty.handler.codec.http.HttpContent -import org.reactivestreams.Publisher -import sttp.capabilities.zio.ZioStreams -import sttp.monad.MonadError -import sttp.tapir.TapirFile -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody -import sttp.tapir.ztapir.RIOMonadError -import zio.interop.reactivestreams._ -import zio.stream.{ZStream, _} -import zio.{Chunk, RIO} - -private[netty] class NettyZioRequestBody[Env](val createFile: ServerRequest => RIO[Env, TapirFile]) - extends NettyRequestBody[RIO[Env, *], ZioStreams] { - - override val streams: ZioStreams = ZioStreams - - override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env] - - override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): RIO[Env, Array[Byte]] = - publisherToStream(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) - - override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): RIO[Env, Unit] = - toStream(serverRequest, maxBytes).run(ZSink.fromFile(file)).map(_ => ()) - - override def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { - val stream = - Adapters - .publisherToStream(publisher, 16) - .flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) - maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) - } - - override def failedStream(e: => Throwable): streams.BinaryStream = - ZStream.fail(e) - - override def emptyStream: streams.BinaryStream = - ZStream.empty -} diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala 
b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala index 07d14aaf43..64a91e3cff 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServerInterpreter.scala @@ -6,7 +6,7 @@ import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} import sttp.tapir.server.netty.zio.NettyZioServerInterpreter.ZioRunAsync -import sttp.tapir.server.netty.zio.internal.ZioStreamCompatible +import sttp.tapir.server.netty.zio.internal.{NettyZioRequestBody, ZioStreamCompatible} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint, _} import zio._ @@ -26,7 +26,7 @@ trait NettyZioServerInterpreter[R] { implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) val serverInterpreter = new ServerInterpreter[ZioStreams, F, NettyResponse, ZioStreams]( FilterServerEndpoints(widenedSes), - new NettyZioRequestBody(widenedServerOptions.createFile), + new NettyZioRequestBody(widenedServerOptions.createFile, ZioStreamCompatible(runtime)), new NettyToStreamsResponseBody[ZioStreams](ZioStreamCompatible(runtime)), RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes), widenedServerOptions.deleteFile diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala new file mode 100644 index 0000000000..2e551cad81 --- /dev/null +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala @@ -0,0 +1,27 @@ 
+package sttp.tapir.server.netty.zio.internal + +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher +import sttp.capabilities.zio.ZioStreams +import sttp.monad.MonadError +import sttp.tapir.TapirFile +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} +import sttp.tapir.ztapir.RIOMonadError +import zio.RIO +import zio.stream._ + +private[zio] class NettyZioRequestBody[Env]( + val createFile: ServerRequest => RIO[Env, TapirFile], + val streamCompatible: StreamCompatible[ZioStreams] +) extends NettyStreamingRequestBody[RIO[Env, *], ZioStreams] { + + override val streams: ZioStreams = ZioStreams + override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env] + + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): RIO[Env, Array[Byte]] = + streamCompatible.fromPublisher(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) + + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): RIO[Env, Unit] = + toStream(serverRequest, maxBytes).run(ZSink.fromFile(file)).map(_ => ()) +} diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index 2251991aac..8d0655f603 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -45,6 +45,20 @@ private[zio] object ZioStreamCompatible { .run(stream.mapChunks(c => Chunk.single(new DefaultHttpContent(Unpooled.wrappedBuffer(c.toArray)): HttpContent)).toPublisher) .getOrThrowFiberFailure() ) + + override def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = 
{ + val stream = + Adapters + .publisherToStream(publisher, 16) + .flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) + } + + override def failedStream(e: => Throwable): streams.BinaryStream = + ZStream.fail(e) + + override def emptyStream: streams.BinaryStream = + ZStream.empty } } } From adb970aef7704f09c464582087659a2f7248c2d3 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 11:36:39 +0100 Subject: [PATCH 25/35] Leave id server compiling but unfinished --- .../netty/loom/NettyIdRequestBody.scala | 67 +++---------------- .../netty/loom/NettyIdServerInterpreter.scala | 2 +- 2 files changed, 12 insertions(+), 57 deletions(-) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala index 8eab65b298..e4587d18b3 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -1,71 +1,26 @@ package sttp.tapir.server.netty.loom -import io.netty.handler.codec.http.FullHttpRequest -import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} +import io.netty.handler.codec.http.HttpContent +import org.reactivestreams.Publisher import sttp.capabilities import sttp.monad.MonadError -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} -import sttp.tapir.model.ServerRequest -import sttp.monad.syntax._ +import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams -import sttp.tapir.server.interpreter.{RawValue, RequestBody} - -import java.nio.ByteBuffer -import java.nio.file.Files -import io.netty.buffer.ByteBuf -import sttp.capabilities.StreamMaxLengthExceededException -import 
sttp.tapir.server.netty.internal.reactivestreams.SimpleSubscriber +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody class NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { - override implicit val monad: MonadError[Id] = idMonad + override implicit val monad: MonadError[Id] = idMonad override val streams: capabilities.Streams[NoStreams] = NoStreams - def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = - SimpleSubscriber.processAll(publisher, maxBytes) - - def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = - Files.write(fi) - - def publisherToStream(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream + override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = + ??? // TODO + // SimpleSubscriber.processAll(publisher, maxBytes) returns Future - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = { - - def byteBuf: ByteBuf = { - val buf = nettyRequest(serverRequest).content() - maxBytes - .map(max => - if (buf.readableBytes() > max) - monadError.error[ByteBuf](StreamMaxLengthExceededException(max)) - else - monadError.unit(buf) - ) - .getOrElse(monadError.unit(buf)) - } - - def requestContentAsByteArray: Array[Byte] = byteBuf.map(ByteBufUtil.getBytes) - - bodyType match { - case RawBodyType.StringBody(charset) => byteBuf.map(buf => RawValue(buf.toString(charset))) - case RawBodyType.ByteArrayBody => requestContentAsByteArray.map(ba => RawValue(ba)) - case RawBodyType.ByteBufferBody => requestContentAsByteArray.map(ba => RawValue(ByteBuffer.wrap(ba))) - case RawBodyType.InputStreamBody => byteBuf.map(buf => RawValue(new ByteBufInputStream(buf))) - case RawBodyType.InputStreamRangeBody => - byteBuf.map(buf => 
RawValue(InputStreamRange(() => new ByteBufInputStream(buf)))) - case RawBodyType.FileBody => - requestContentAsByteArray.flatMap(ba => - createFile(serverRequest) - .map(file => { - Files.write(file.toPath, ba) - RawValue(FileRange(file), Seq(FileRange(file))) - }) - ) - case _: RawBodyType.MultipartBody => ??? - } - } + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = + ??? // TODO override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = throw new UnsupportedOperationException() - - private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] } diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala index 444a4c5e5d..4cd225e0eb 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.netty.loom import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.internal.{NettyToResponseBody, NettyIdRequestBody, NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyToResponseBody, NettyServerInterpreter, RunAsync} trait NettyIdServerInterpreter { def nettyServerOptions: NettyIdServerOptions From 945f9cf37a2b70b4274efa25dd797635d65ee97f Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 12:25:32 +0100 Subject: [PATCH 26/35] Integrate netty-loom with reactive Publishers --- .../server/netty/loom/NettyIdRequestBody.scala | 11 ++++++++--- .../tapir/server/netty/loom/NettyIdServerTest.scala | 6 ++++-- .../netty/internal/NettyFutureRequestBody.scala | 2 +- .../reactivestreams/FileWriterSubscriber.scala | 10 
++++++++++ .../internal/reactivestreams/SimpleSubscriber.scala | 13 ++++++++++++- 5 files changed, 35 insertions(+), 7 deletions(-) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala index e4587d18b3..35b953b200 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -8,6 +8,9 @@ import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody +import sttp.tapir.server.netty.internal.reactivestreams.SimpleSubscriber +import sttp.tapir.server.netty.internal.reactivestreams.FileWriterSubscriber +import org.playframework.netty.http.StreamedHttpRequest class NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { @@ -15,11 +18,13 @@ class NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends Net override val streams: capabilities.Streams[NoStreams] = NoStreams override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = - ??? // TODO - // SimpleSubscriber.processAll(publisher, maxBytes) returns Future + SimpleSubscriber.processAllBlocking(publisher, maxBytes) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = - ??? 
// TODO + serverRequest.underlying match { + case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes) + case _ => () // Empty request + } override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = throw new UnsupportedOperationException() diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala index 339a30d8a4..73fe2cdd20 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala @@ -21,8 +21,10 @@ class NettyIdServerTest extends TestSuite with EitherValues { val createServerTest = new DefaultCreateServerTest(backend, interpreter) val sleeper: Sleeper[Id] = (duration: FiniteDuration) => Thread.sleep(duration.toMillis) - val tests = new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++ - new ServerGracefulShutdownTests(createServerTest, sleeper).tests() + val tests = + new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false, maxContentLength = true) + .tests() ++ + new ServerGracefulShutdownTests(createServerTest, sleeper).tests() (tests, eventLoopGroup) }) { case (_, eventLoopGroup) => diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index 846b5bd981..0cfa268b3f 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -25,7 +25,7 @@ class NettyFutureRequestBody(val createFile: 
ServerRequest => Future[TapirFile]) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] = serverRequest.underlying match { case r: StreamedHttpRequest => FileWriterSubscriber.processAll(r, file.toPath, maxBytes) - case _ => monad.unit(()) + case _ => monad.unit(()) // Empty request } override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala index b912a75d15..72b5c22caf 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -5,6 +5,7 @@ import org.reactivestreams.{Publisher, Subscription} import java.nio.channels.AsynchronousFileChannel import java.nio.file.{Path, StandardOpenOption} +import java.util.concurrent.Semaphore import scala.concurrent.{Future, Promise} class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpContent] { @@ -12,8 +13,10 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon private var fileChannel: AsynchronousFileChannel = _ private var position: Long = 0 private val resultPromise = Promise[Unit]() + private val resultBlockingSemaphore: Semaphore = new Semaphore(0) override def future: Future[Unit] = resultPromise.future + def waitForResultBlocking(): Unit = resultBlockingSemaphore.acquire() override def onSubscribe(s: Subscription): Unit = { this.subscription = s @@ -49,6 +52,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon override def onComplete(): Unit = { fileChannel.close() resultPromise.success(()) + 
resultBlockingSemaphore.release() } } @@ -58,4 +62,10 @@ object FileWriterSubscriber { publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) subscriber.future } + + def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit = { + val subscriber = new FileWriterSubscriber(path) + publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) + subscriber.waitForResultBlocking() + } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index 0f2b10bbb4..193e36245d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -7,14 +7,18 @@ import org.reactivestreams.{Publisher, Subscription} import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], HttpContent] { private var subscription: Subscription = _ private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() private var size = 0 private val resultPromise = Promise[Array[Byte]]() + private val resultBlockingQueue: BlockingQueue[Array[Byte]] = new LinkedBlockingQueue[Array[Byte]]() override def future: Future[Array[Byte]] = resultPromise.future + def resultBlocking: Array[Byte] = resultBlockingQueue.poll() override def onSubscribe(s: Subscription): Unit = { subscription = s @@ -35,11 +39,12 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], override def onComplete(): 
Unit = { val result = new Array[Byte](size) - chunks.asScala.foldLeft(0)((currentPosition, array) => { + val _ = chunks.asScala.foldLeft(0)((currentPosition, array) => { System.arraycopy(array, 0, result, currentPosition, array.length) currentPosition + array.length }) chunks.clear() + resultBlockingQueue.add(result) resultPromise.success(result) } } @@ -50,4 +55,10 @@ object SimpleSubscriber { publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) subscriber.future } + + def processAllBlocking(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = { + val subscriber = new SimpleSubscriber() + publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) + subscriber.resultBlocking + } } From 5a5d953210f3f5faef36fccdd039a7aa68d46aa2 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 12:53:19 +0100 Subject: [PATCH 27/35] Fix creating empty byte array --- .../netty/internal/reactivestreams/NettyRequestBody.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala index 8c6f46889f..80d11cc996 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala @@ -10,7 +10,6 @@ import sttp.tapir.model.ServerRequest import sttp.tapir.server.interpreter.RequestBody import sttp.tapir.RawBodyType import sttp.tapir.TapirFile -import sttp.tapir.server.netty.internal.StreamCompatible import sttp.tapir.server.interpreter.RawValue import sttp.tapir.FileRange import sttp.tapir.InputStreamRange @@ -51,7 +50,7 @@ trait NettyRequestBody[F[_], S <: Streams[S]] 
extends RequestBody[F, S] { def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = serverRequest.underlying match { case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => - monad.unit(Array[Byte](0)) + monad.unit(Array.empty[Byte]) case req: StreamedHttpRequest => publisherToBytes(req, maxBytes) case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) From 975d1dc0e9bcd824a726c70334afc5d84d099122 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 14:20:50 +0100 Subject: [PATCH 28/35] Remove temporary test --- .../scala/sttp/tapir/server/tests/ServerBasicTests.scala | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index a560f1b28f..add8d4ae97 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -779,13 +779,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( testPayloadWithinLimit(in_input_stream_out_input_stream, maxLength), testPayloadWithinLimit(in_byte_array_out_byte_array, maxLength), testPayloadWithinLimit(in_file_out_file, maxLength), - testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength), - testServer( - in_string_out_string, - "testkc" - )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => - basicRequest.post(uri"$baseUri/api/echo").body("").send(backend).map(_.code shouldBe StatusCode.Ok) - } + testPayloadWithinLimit(in_byte_buffer_out_byte_buffer, maxLength) ) } From 97e03409375a3fffdcac45de0dac46c4a325020b Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 14:21:03 +0100 Subject: [PATCH 29/35] Fix error handling --- .../reactivestreams/FileWriterSubscriber.scala | 11 ++++++----- 
.../reactivestreams/SimpleSubscriber.scala | 14 +++++++++----- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala index 72b5c22caf..dad69b86a9 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -5,18 +5,18 @@ import org.reactivestreams.{Publisher, Subscription} import java.nio.channels.AsynchronousFileChannel import java.nio.file.{Path, StandardOpenOption} -import java.util.concurrent.Semaphore import scala.concurrent.{Future, Promise} +import java.util.concurrent.LinkedBlockingQueue class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpContent] { private var subscription: Subscription = _ private var fileChannel: AsynchronousFileChannel = _ private var position: Long = 0 private val resultPromise = Promise[Unit]() - private val resultBlockingSemaphore: Semaphore = new Semaphore(0) + private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Unit]]() override def future: Future[Unit] = resultPromise.future - def waitForResultBlocking(): Unit = resultBlockingSemaphore.acquire() + private def waitForResultBlocking(): Either[Throwable, Unit] = resultBlockingQueue.take() override def onSubscribe(s: Subscription): Unit = { this.subscription = s @@ -46,13 +46,14 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon override def onError(t: Throwable): Unit = { fileChannel.close() + resultBlockingQueue.add(Left(t)) resultPromise.failure(t) } override def onComplete(): Unit = { fileChannel.close() + val _ = resultBlockingQueue.add(Right(())) resultPromise.success(()) - 
resultBlockingSemaphore.release() } } @@ -66,6 +67,6 @@ object FileWriterSubscriber { def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit = { val subscriber = new FileWriterSubscriber(path) publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)) - subscriber.waitForResultBlocking() + subscriber.waitForResultBlocking().left.foreach(e => throw e) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index 193e36245d..e5d9bf96dd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -15,10 +15,10 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() private var size = 0 private val resultPromise = Promise[Array[Byte]]() - private val resultBlockingQueue: BlockingQueue[Array[Byte]] = new LinkedBlockingQueue[Array[Byte]]() + private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]]() override def future: Future[Array[Byte]] = resultPromise.future - def resultBlocking: Array[Byte] = resultBlockingQueue.poll() + def resultBlocking(): Either[Throwable, Array[Byte]] = resultBlockingQueue.take() override def onSubscribe(s: Subscription): Unit = { subscription = s @@ -31,9 +31,10 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], chunks.add(a) subscription.request(1) } - + override def onError(t: Throwable): Unit = { chunks.clear() + resultBlockingQueue.add(Left(t)) resultPromise.failure(t) } @@ -44,7 +45,7 @@ private[netty] class SimpleSubscriber() extends 
PromisingSubscriber[Array[Byte], currentPosition + array.length }) chunks.clear() - resultBlockingQueue.add(result) + val _ = resultBlockingQueue.add(Right(result)) resultPromise.success(result) } } @@ -59,6 +60,9 @@ object SimpleSubscriber { def processAllBlocking(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = { val subscriber = new SimpleSubscriber() publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) - subscriber.resultBlocking + subscriber.resultBlocking() match { + case Right(result) => result + case Left(e) => throw e + } } } From 19ac5b17f82ce8afe90ee89ea5734323b7434606 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 14:55:23 +0100 Subject: [PATCH 30/35] Adjust to allow compiling on Scala 2.12 --- .../netty/internal/reactivestreams/SimpleSubscriber.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index e5d9bf96dd..40138b3614 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -7,7 +7,6 @@ import org.reactivestreams.{Publisher, Subscription} import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} -import java.util.concurrent.BlockingQueue import java.util.concurrent.LinkedBlockingQueue private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], HttpContent] { @@ -45,7 +44,7 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], currentPosition + array.length }) chunks.clear() - val _ = 
resultBlockingQueue.add(Right(result)) + resultBlockingQueue.add(Right(result)) resultPromise.success(result) } } From 39ca430fd37903d3577b1690659b458d2ed06a1b Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 16:10:38 +0100 Subject: [PATCH 31/35] Improve chunking --- .../cats/internal/Fs2StreamCompatible.scala | 24 +++++++------- .../netty/NettyFutureServerInterpreter.scala | 4 +-- .../internal/NettyStreamingRequestBody.scala | 3 +- .../netty/internal/NettyToResponseBody.scala | 6 +++- .../internal/NettyToStreamsResponseBody.scala | 11 +++++-- .../netty/internal/StreamCompatible.scala | 12 +++---- .../reactivestreams/FileRangePublisher.scala | 3 ++ .../FileWriterSubscriber.scala | 10 ++++++ .../reactivestreams/NettyRequestBody.scala | 33 +++++++++++++++---- .../zio/internal/ZioStreamCompatible.scala | 15 +++++---- 10 files changed, 81 insertions(+), 40 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index d96b6edb56..6a2a1fe69b 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -24,37 +24,35 @@ private[cats] def apply[F[_]: Async](dispatcher: Dispatcher[F]): StreamCompatibl new StreamCompatible[Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] - override def fromFile(fileRange: FileRange): streams.BinaryStream = { + override def fromFile(fileRange: FileRange, chunkSize: Int): streams.BinaryStream = { val path = Path.fromNioPath(fileRange.file.toPath) fileRange.range .flatMap(r => - r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, NettyToResponseBody.DefaultChunkSize, s._1, s._2)) + r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, 
chunkSize, s._1, s._2)) ) - .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) + .getOrElse(Files[F](Files.forAsync[F]).readAll(path, chunkSize, Flags.Read)) } - override def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream = + override def fromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): streams.BinaryStream = length match { - case Some(limitedLength) => inputStreamToFs2(is).take(limitedLength) - case None => inputStreamToFs2(is) + case Some(limitedLength) => inputStreamToFs2(is, chunkSize).take(limitedLength) + case None => inputStreamToFs2(is, chunkSize) } override def asPublisher(stream: fs2.Stream[F, Byte]): Publisher[HttpContent] = // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( - stream - .chunkLimit(NettyToResponseBody.DefaultChunkSize) - .map { chunk => + stream.mapChunks { chunk => val bytes: Chunk.ArraySlice[Byte] = chunk.compact - new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) + Chunk.singleton(new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length))) }, dispatcher ) override def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { val stream = fs2.Stream - .eval(StreamSubscriber[F, HttpContent](NettyToResponseBody.DefaultChunkSize)) + .eval(StreamSubscriber[F, HttpContent](bufferSize = 2)) .flatMap(s => s.sub.stream(Sync[F].delay(publisher.subscribe(s)))) .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream) @@ -66,10 +64,10 @@ private[cats] def apply[F[_]: Async](dispatcher: Dispatcher[F]): StreamCompatibl override def emptyStream: streams.BinaryStream = fs2.Stream.empty - 
private def inputStreamToFs2(inputStream: () => InputStream) = + private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = fs2.io.readInputStream( Sync[F].blocking(inputStream()), - NettyToResponseBody.DefaultChunkSize + chunkSize ) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala index 246c49c6bd..a2255216dd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala @@ -3,11 +3,9 @@ package sttp.tapir.server.netty import sttp.monad.FutureMonad import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync -import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} +import sttp.tapir.server.netty.internal.{NettyFutureRequestBody, NettyServerInterpreter, NettyToResponseBody, RunAsync} import scala.concurrent.{ExecutionContext, Future} -import sttp.tapir.server.netty.internal.NettyFutureRequestBody -import sttp.tapir.server.netty.internal.NettyToResponseBody trait NettyFutureServerInterpreter { def nettyServerOptions: NettyFutureServerOptions diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala index cdf19f1f4e..992a79bc27 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala @@ -7,6 +7,7 @@ import sttp.capabilities.Streams import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody +/** Common logic for 
processing streaming request body in all Netty backends which support streaming. */ trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[F, S] { val streamCompatible: StreamCompatible[S] @@ -20,5 +21,5 @@ trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[ streamCompatible.fromPublisher(publisher, maxBytes) case other => streamCompatible.failedStream(new UnsupportedOperationException(s"Unexpected Netty request of type: ${other.getClass().getName()}")) - }).asInstanceOf[streams.BinaryStream] + }).asInstanceOf[streams.BinaryStream] // Scala can't figure out that it's the same type as streamCompatible.streams.BinaryStream } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index 0d16c48142..7da53552f7 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -19,7 +19,11 @@ import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset -class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] { +/** Common logic for producing response body from responses in all Netty backends that don't support streaming. These backends use our custom reactive + * Publishers to integrate responses like InputStreamBody, InputStreamRangeBody or FileBody with Netty reactive extensions. Other kinds of + * raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. 
+ */ +private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 2168cf2bdd..0243415d91 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -6,12 +6,17 @@ import sttp.capabilities.Streams import sttp.model.HasHeaders import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.internal.NettyToResponseBody._ import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.nio.ByteBuffer import java.nio.charset.Charset +/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries + * like fs2 or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. + * Other kinds of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. 
+ */ class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { override val streams: S = streamCompatible.streams @@ -32,17 +37,17 @@ class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompat case RawBodyType.InputStreamBody => (ctx: ChannelHandlerContext) => - new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, length = None)) + new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None)) case RawBodyType.InputStreamRangeBody => (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent( ctx.newPromise(), - streamCompatible.publisherFromInputStream(v.inputStreamFromRangeStart, length = v.range.map(_.contentLength)) + streamCompatible.publisherFromInputStream(v.inputStreamFromRangeStart, DefaultChunkSize, length = v.range.map(_.contentLength)) ) case RawBodyType.FileBody => - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v)) + (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index 3e703b7196..30070cd520 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -9,17 +9,17 @@ import java.io.InputStream private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S - def fromFile(file: FileRange): 
streams.BinaryStream - def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream + def fromFile(file: FileRange, chunkSize: Int): streams.BinaryStream + def fromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): streams.BinaryStream def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream def asPublisher(s: streams.BinaryStream): Publisher[HttpContent] def failedStream(e: => Throwable): streams.BinaryStream def emptyStream: streams.BinaryStream - def publisherFromFile(file: FileRange): Publisher[HttpContent] = - asPublisher(fromFile(file)) + def publisherFromFile(file: FileRange, chunkSize: Int): Publisher[HttpContent] = + asPublisher(fromFile(file, chunkSize)) - def publisherFromInputStream(is: () => InputStream, length: Option[Long]): Publisher[HttpContent] = - asPublisher(fromInputStream(is, length)) + def publisherFromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): Publisher[HttpContent] = + asPublisher(fromInputStream(is, chunkSize, length)) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala index e981bec388..9e84324b63 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileRangePublisher.scala @@ -10,6 +10,8 @@ import java.nio.channels.{AsynchronousFileChannel, CompletionHandler} import java.nio.file.StandardOpenOption import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +/** A Reactive Streams publisher which emits chunks of HttpContent read from a given file. 
+ */ class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher[HttpContent] { override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = { if (subscriber == null) throw new NullPointerException("Subscriber cannot be null") @@ -44,6 +46,7 @@ class FileRangePublisher(fileRange: FileRange, chunkSize: Int) extends Publisher case _ => chunkSize } buffer.clear() + // Async call, so readNextChunkIfNeeded() finishes immediately after firing this channel.read( buffer, pos, diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala index dad69b86a9..523eaf3ab9 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -8,11 +8,21 @@ import java.nio.file.{Path, StandardOpenOption} import scala.concurrent.{Future, Promise} import java.util.concurrent.LinkedBlockingQueue +/** A Reactive Streams subscriber which receives chunks of bytes and writes them to a file. 
+ */ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpContent] { private var subscription: Subscription = _ + + /** JDK interface to write asynchronously to a file */ private var fileChannel: AsynchronousFileChannel = _ + + /** Current position in the file */ private var position: Long = 0 + + /** Used to signal completion, so that external code can represent writing to a file as Future[Unit] */ private val resultPromise = Promise[Unit]() + + /** An alternative way to signal completion, so that non-effectful servers can await on the response (like netty-loom) */ private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Unit]]() override def future: Future[Unit] = resultPromise.future diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala index 80d11cc996..27ac91e350 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala @@ -17,12 +17,33 @@ import java.io.ByteArrayInputStream import java.nio.ByteBuffer import sttp.capabilities.Streams -trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { +/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */ +private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { - val DefaultChunkSize = 8192 implicit def monad: MonadError[F] + + /** Backend-specific implementation for creating a file. */ def createFile: ServerRequest => F[TapirFile] + + /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] into a raw array of bytes. 
+ * + * @param publisher reactive publisher emitting byte chunks. + * @param maxBytes + * optional request length limit. If exceeded, The effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] + * @return An effect which finishes with a single array of all collected bytes. + */ def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] + + /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file. + * + * @param serverRequest + * can have underlying `Publisher[HttpContent]` or an empty `FullHttpRequest` + * @param file + * an empty file where bytes should be stored. + * @param maxBytes + * optional request length limit. If exceeded, The effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] + * @return an effect which finishes when all data is written to the file. + */ def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = { @@ -33,8 +54,10 @@ trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { case RawBodyType.ByteBufferBody => readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs))) case RawBodyType.InputStreamBody => + // Possibly can be optimized to avoid loading all data eagerly into memory readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs))) case RawBodyType.InputStreamRangeBody => + // Possibly can be optimized to avoid loading all data eagerly into memory readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) case RawBodyType.FileBody => for { @@ -45,11 +68,9 @@ trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { } } - - // Used by different netty backends to handle raw body input - def 
readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = + private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = serverRequest.underlying match { - case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => + case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request monad.unit(Array.empty[Byte]) case req: StreamedHttpRequest => publisherToBytes(req, maxBytes) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index 8d0655f603..7ec3ae4fa9 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -18,12 +18,12 @@ private[zio] object ZioStreamCompatible { new StreamCompatible[ZioStreams] { override val streams: ZioStreams = ZioStreams - override def fromFile(fileRange: FileRange): streams.BinaryStream = { + override def fromFile(fileRange: FileRange, chunkSize: Int): streams.BinaryStream = { fileRange.range .flatMap(r => r.startAndEnd.map { case (fStart, _) => ZStream - .fromPath(fileRange.file.toPath) + .fromPath(fileRange.file.toPath, chunkSize) .drop(fStart.toInt) .take(r.contentLength) } @@ -33,10 +33,10 @@ private[zio] object ZioStreamCompatible { ) } - override def fromInputStream(is: () => InputStream, length: Option[Long]): streams.BinaryStream = + override def fromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): streams.BinaryStream = length match { - case Some(limitedLength) => ZStream.fromInputStream(is()).take(limitedLength.toInt) - case None => ZStream.fromInputStream(is()) + case Some(limitedLength) => ZStream.fromInputStream(is(), chunkSize).take(limitedLength.toInt) + case None => 
ZStream.fromInputStream(is(), chunkSize) } override def asPublisher(stream: Stream[Throwable, Byte]): Publisher[HttpContent] = @@ -49,8 +49,9 @@ private[zio] object ZioStreamCompatible { override def fromPublisher(publisher: Publisher[HttpContent], maxBytes: Option[Long]): streams.BinaryStream = { val stream = Adapters - .publisherToStream(publisher, 16) - .flatMap(httpContent => ZStream.fromChunk(Chunk.fromByteBuffer(httpContent.content.nioBuffer()))) + .publisherToStream(publisher, bufferSize = 2) + .map(httpContent => Chunk.fromByteBuffer(httpContent.content.nioBuffer())) + .flattenChunks maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) } From a9621dc93e159e99482e893502ec51d30fbc0e43 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 17:27:45 +0100 Subject: [PATCH 32/35] Documentation and package reorganization for public API --- doc/endpoint/security.md | 28 ++++++++++++++++ .../server/interpreter/MaxContentLength.scala | 9 ----- .../interpreter/ServerInterpreter.scala | 2 +- .../server/model/EndpointExtensions.scala | 33 +++++++++++++++++++ .../tapir/server/tests/ServerBasicTests.scala | 5 +-- .../server/tests/ServerStreamingTests.scala | 2 +- 6 files changed, 66 insertions(+), 13 deletions(-) delete mode 100644 server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala create mode 100644 server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala diff --git a/doc/endpoint/security.md b/doc/endpoint/security.md index e94f2c1fa2..f2fbfa7ea0 100644 --- a/doc/endpoint/security.md +++ b/doc/endpoint/security.md @@ -49,6 +49,34 @@ will show you a password prompt. Optional and multiple authentication inputs have some additional rules as to how hey map to documentation, see the ["Authentication inputs and security requirements"](../docs/openapi.md) section in the OpenAPI docs for details. +## Limiting request body length + +*Supported backends*: +Feature enabled only for Netty-based servers. 
More backends will be added in the near future. + +Individual endpoints can be annotated with content length limit: + +```scala mdoc:compile-only +import sttp.tapir._ +import sttp.tapir.server.model.EndpointExtensions._ + +val limitedEndpoint = endpoint.maxRequestBodyLength(maxBytes = 163484L) +``` + +The `EndpointsExtensions` utility is available in `tapir-server` core module. If you can't depend on it where your endpoint +definitions are located, you can directly put an attribute: + +```scala mdoc:compile-only +import sttp.tapir._ +import sttp.tapir.server.model.MaxContentLength + +val limitedEndpoint = endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(16384L)) +``` +Such protection would prevent loading all the input data into memory if it exceeds the limit. Instead, it will result +in a `HTTP 413` response to the client. +Please note that in case of endpoints with `streamBody` input type, the server logic receives a reference to a lazily +evaluated stream, so actual length verification will happen only when the logic performs streams processing, not earlier. + ## Next Read on about [streaming support](streaming.md). diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala deleted file mode 100644 index 5b9774f8a1..0000000000 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/MaxContentLength.scala +++ /dev/null @@ -1,9 +0,0 @@ -package sttp.tapir.server.interpreter - -/** Can be used as an endpoint attribute. 
- * @example - * {{{ - * endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(16384L)) - * }}} - */ -case class MaxContentLength(value: Long) diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index 2b34ec1df5..0656677c5c 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -7,7 +7,7 @@ import sttp.monad.syntax._ import sttp.tapir.internal.{Params, ParamsAsAny, RichOneOfBody} import sttp.tapir.model.ServerRequest import sttp.tapir.server.interceptor._ -import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput} +import sttp.tapir.server.model.{MaxContentLength, ServerResponse, ValuedEndpointOutput} import sttp.tapir.server.{model, _} import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile} import sttp.tapir.EndpointInfo diff --git a/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala b/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala new file mode 100644 index 0000000000..a9984077a8 --- /dev/null +++ b/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala @@ -0,0 +1,33 @@ +package sttp.tapir.server.model + +import sttp.tapir.EndpointInfoOps +import sttp.tapir.AttributeKey + +/** Can be used as an endpoint attribute. + * @example + * {{{ + * endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(16384L)) + * }}} + */ +case class MaxContentLength(value: Long) extends AnyVal + +object EndpointExtensions { + private val MaxContentLengthAttributeKey: AttributeKey[MaxContentLength] = AttributeKey[MaxContentLength] + + implicit class RichServerEndpoint[E <: EndpointInfoOps[_]](e: E) { + + /** Enables checks that prevent loading full request body into memory if it exceeds given limit. 
Otherwise causes endpoint to reply with + * HTTP 413 Payload Too Large. + * + * Please refer to Tapir docs to ensure which backends are supported: https://tapir.softwaremill.com/en/latest/endpoint/security.html + * @example + * {{{ + * endpoint.maxRequestBodyLength(16384L) + * }}} + * @param maxBytes + * maximum allowed size of request body in bytes. + */ + def maxRequestBodyLength(maxBytes: Long): E = + e.attribute(MaxContentLengthAttributeKey, MaxContentLength(maxBytes)).asInstanceOf[E] + } +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index add8d4ae97..10d24d60f7 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -15,6 +15,8 @@ import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum import sttp.tapir.generic.auto._ import sttp.tapir.json.circe._ import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.model.EndpointExtensions._ +import sttp.tapir.server.model._ import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler import sttp.tapir.tests.Basic._ import sttp.tapir.tests.TestUtil._ @@ -23,7 +25,6 @@ import sttp.tapir.tests.data.{FruitAmount, FruitError} import java.io.{ByteArrayInputStream, InputStream} import java.nio.ByteBuffer -import sttp.tapir.server.interpreter.MaxContentLength import sttp.tapir.tests.Files.in_file_out_file class ServerBasicTests[F[_], OPTIONS, ROUTE]( @@ -750,7 +751,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( testedEndpoint: PublicEndpoint[I, Unit, I, Any], maxLength: Int ) = testServer( - testedEndpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(maxLength.toLong)), + testedEndpoint.maxRequestBodyLength(maxLength.toLong), "returns 413 on exceeded max content length (request)" )(i => pureResult(i.asRight[Unit])) { (backend, baseUri) => val 
tooLargeBody: String = List.fill(maxLength + 1)('x').mkString diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 1d0043188e..2a66e8326e 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -9,7 +9,7 @@ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.tests.Test import sttp.tapir.tests.Streaming._ -import sttp.tapir.server.interpreter.MaxContentLength +import sttp.tapir.server.model.MaxContentLength import sttp.tapir.AttributeKey import cats.effect.IO import sttp.capabilities.fs2.Fs2Streams From 8320de45b8c29fbbcd64a208bdbc7b08b2f7693e Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 17:40:29 +0100 Subject: [PATCH 33/35] More docs and minor polishing touches --- doc/migrating.md | 4 ++++ .../tapir/server/netty/internal/NettyFutureRequestBody.scala | 2 +- .../server/netty/internal/NettyStreamingRequestBody.scala | 2 +- .../tapir/server/netty/internal/NettyToResponseBody.scala | 2 +- .../server/netty/internal/NettyToStreamsResponseBody.scala | 2 +- .../sttp/tapir/server/netty/internal/StreamCompatible.scala | 5 +++++ 6 files changed, 13 insertions(+), 4 deletions(-) diff --git a/doc/migrating.md b/doc/migrating.md index 6d6a59210d..3d61874a17 100644 --- a/doc/migrating.md +++ b/doc/migrating.md @@ -1,5 +1,9 @@ # Migrating +## From 1.9.3 to 1.9.4 + +- `NettyConfig.defaultNoStreaming` has been removed, use `NettyConfig.default`. + ## From 1.4 to 1.5 - `badRequestOnPathErrorIfPathShapeMatches` and `badRequestOnPathInvalidIfPathShapeMatches` have been removed from `DefaultDecodeFailureHandler`. These flags were causing confusion and incosistencies caused by specifics of ZIO and Play backends. 
Before tapir 1.5, keeping defaults (`false` and `true` respectively for these flags) meant that some path segment decoding failures (specifically, errors - when an exception has been thrown during decoding, but not for e.g. enumeration mismatches) were translated to a "no-match", meaning that the next endpoint was attempted. From 1.5, tapir defaults to a 400 Bad Request response to be sent instead, on all path decoding failures. diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index 0cfa268b3f..6428605bfd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -13,7 +13,7 @@ import scala.concurrent.{ExecutionContext, Future} import reactivestreams._ -class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) +private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) extends NettyRequestBody[Future, NoStreams] { override val streams: capabilities.Streams[NoStreams] = NoStreams diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala index 992a79bc27..558af56b32 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala @@ -8,7 +8,7 @@ import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody /** Common logic for processing streaming request body in all Netty backends 
which support streaming. */ -trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[F, S] { +private[netty] trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[F, S] { val streamCompatible: StreamCompatible[S] override val streams = streamCompatible.streams diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index 7da53552f7..e0b2c0b35e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -78,6 +78,6 @@ private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) exten ): NettyResponse = throw new UnsupportedOperationException } -object NettyToResponseBody { +private[netty] object NettyToResponseBody { val DefaultChunkSize = 8192 } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 0243415d91..0f335a7b14 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -17,7 +17,7 @@ import java.nio.charset.Charset * like fs2 or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. * Other kinds of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. 
*/ -class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { +private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { override val streams: S = streamCompatible.streams diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index 30070cd520..6d5da177bd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -7,6 +7,11 @@ import sttp.tapir.FileRange import java.io.InputStream +/** + * Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. + * This includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. + * We also need implementation of a failed (errored) stream, as well as an empty stream (for handling empty requests). 
+ */ private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S def fromFile(file: FileRange, chunkSize: Int): streams.BinaryStream From a4f5d21f7c1c339f8f4108ce26f503c295e6431d Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 1 Dec 2023 18:04:52 +0100 Subject: [PATCH 34/35] Fix type package --- .../tapir/server/netty/loom/NettyIdRequestBody.scala | 9 ++++----- .../netty/internal/NettyFutureRequestBody.scala | 3 +-- .../{reactivestreams => }/NettyRequestBody.scala | 11 +++++++---- .../netty/internal/NettyStreamingRequestBody.scala | 1 - 4 files changed, 12 insertions(+), 12 deletions(-) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{reactivestreams => }/NettyRequestBody.scala (92%) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala index 35b953b200..5b1aaf8980 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -1,18 +1,17 @@ package sttp.tapir.server.netty.loom import io.netty.handler.codec.http.HttpContent +import org.playframework.netty.http.StreamedHttpRequest import org.reactivestreams.Publisher import sttp.capabilities import sttp.monad.MonadError import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams import sttp.tapir.model.ServerRequest -import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody -import sttp.tapir.server.netty.internal.reactivestreams.SimpleSubscriber -import sttp.tapir.server.netty.internal.reactivestreams.FileWriterSubscriber -import org.playframework.netty.http.StreamedHttpRequest +import sttp.tapir.server.netty.internal.NettyRequestBody +import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber} -class 
NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { +private[netty] class NettyIdRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { override implicit val monad: MonadError[Id] = idMonad override val streams: capabilities.Streams[NoStreams] = NoStreams diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index 6428605bfd..c6dcbf0a9d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -8,11 +8,10 @@ import sttp.monad.{FutureMonad, MonadError} import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams import sttp.tapir.model.ServerRequest +import sttp.tapir.server.netty.internal.reactivestreams._ import scala.concurrent.{ExecutionContext, Future} -import reactivestreams._ - private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) extends NettyRequestBody[Future, NoStreams] { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala similarity index 92% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 27ac91e350..9d1375e7a5 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -1,4 +1,4 @@ 
-package sttp.tapir.server.netty.internal.reactivestreams +package sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} @@ -27,10 +27,12 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] into a raw array of bytes. * - * @param publisher reactive publisher emitting byte chunks. + * @param publisher + * reactive publisher emitting byte chunks. * @param maxBytes * optional request length limit. If exceeded, The effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] - * @return An effect which finishes with a single array of all collected bytes. + * @return + * An effect which finishes with a single array of all collected bytes. */ def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] @@ -42,7 +44,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody * an empty file where bytes should be stored. * @param maxBytes * optional request length limit. If exceeded, The effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] - * @return an effect which finishes when all data is written to the file. + * @return + * an effect which finishes when all data is written to the file. 
*/ def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala index 558af56b32..cccb1a0fce 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyStreamingRequestBody.scala @@ -5,7 +5,6 @@ import io.netty.handler.codec.http.FullHttpRequest import org.playframework.netty.http.StreamedHttpRequest import sttp.capabilities.Streams import sttp.tapir.model.ServerRequest -import sttp.tapir.server.netty.internal.reactivestreams.NettyRequestBody /** Common logic for processing streaming request body in all Netty backends which support streaming. */ private[netty] trait NettyStreamingRequestBody[F[_], S <: Streams[S]] extends NettyRequestBody[F, S] { From 36ad071941a48d9f1922a0c9ce01f845d69dccb0 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 4 Dec 2023 15:36:16 +0100 Subject: [PATCH 35/35] Review fixes --- doc/endpoint/security.md | 14 +++----------- .../tapir/server/model/EndpointExtensions.scala | 9 ++++++--- .../reactivestreams/FileWriterSubscriber.scala | 2 +- .../sttp/tapir/server/tests/AllServerTests.scala | 2 +- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/doc/endpoint/security.md b/doc/endpoint/security.md index f2fbfa7ea0..cb677720bb 100644 --- a/doc/endpoint/security.md +++ b/doc/endpoint/security.md @@ -63,17 +63,9 @@ import sttp.tapir.server.model.EndpointExtensions._ val limitedEndpoint = endpoint.maxRequestBodyLength(maxBytes = 163484L) ``` -The `EndpointsExtensions` utility is available in `tapir-server` core module. 
If you can't depend on it where your endpoint -definitions are located, you can directly put an attribute: - -```scala mdoc:compile-only -import sttp.tapir._ -import sttp.tapir.server.model.MaxContentLength - -val limitedEndpoint = endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(16384L)) -``` -Such protection would prevent loading all the input data into memory if it exceeds the limit. Instead, it will result -in a `HTTP 413` response to the client. +The `EndpointExtensions` utility is available in `tapir-server` core module. +Such protection prevents loading all the input data if it exceeds the limit. Instead, it will result in a `HTTP 413` +response to the client. Please note that in case of endpoints with `streamBody` input type, the server logic receives a reference to a lazily evaluated stream, so actual length verification will happen only when the logic performs streams processing, not earlier. diff --git a/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala b/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala index a9984077a8..d2f0b31609 100644 --- a/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala +++ b/server/core/src/main/scala/sttp/tapir/server/model/EndpointExtensions.scala @@ -6,13 +6,16 @@ import sttp.tapir.AttributeKey /** Can be used as an endpoint attribute. 
* @example * {{{ - * endpoint.attribute(AttributeKey[MaxContentLength], MaxContentLength(16384L)) + * endpoint.attribute(MaxContentLength.attributeKey, MaxContentLength(16384L)) * }}} */ case class MaxContentLength(value: Long) extends AnyVal +object MaxContentLength { + val attributeKey: AttributeKey[MaxContentLength] = AttributeKey[MaxContentLength] +} + object EndpointExtensions { - private val MaxContentLengthAttributeKey: AttributeKey[MaxContentLength] = AttributeKey[MaxContentLength] implicit class RichServerEndpoint[E <: EndpointInfoOps[_]](e: E) { @@ -28,6 +31,6 @@ object EndpointExtensions { * maximum allowed size of request body in bytes. */ def maxRequestBodyLength(maxBytes: Long): E = - e.attribute(MaxContentLengthAttributeKey, MaxContentLength(maxBytes)).asInstanceOf[E] + e.attribute(MaxContentLength.attributeKey, MaxContentLength(maxBytes)).asInstanceOf[E] } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala index 523eaf3ab9..e7c4ca0479 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala @@ -17,7 +17,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon private var fileChannel: AsynchronousFileChannel = _ /** Current position in the file */ - private var position: Long = 0 + @volatile private var position: Long = 0 /** Used to signal completion, so that external code can represent writing to a file as Future[Unit] */ private val resultPromise = Promise[Unit]() diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala index 8e6ad861ff..922e6973eb 
100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala @@ -28,7 +28,7 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( oneOfBody: Boolean = true, cors: Boolean = true, options: Boolean = true, - maxContentLength: Boolean = false + maxContentLength: Boolean = false // TODO let's work towards making this true by default )(implicit m: MonadError[F] ) {