Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MaxContentLength support (streaming) #3319

Merged
merged 10 commits into from
Nov 23, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ object ZTapirTest extends ZIOSpecDefault with ZTapir {
private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] {
override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ???
override def toStream(serverRequest: ServerRequest): streams.BinaryStream = ???
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
}

private val exampleToResponse: ToResponseBody[ResponseBodyType, RequestBodyType] = new ToResponseBody[ResponseBodyType, RequestBodyType] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object ZTapirTest extends DefaultRunnableSpec with ZTapir {
private val exampleRequestBody = new RequestBody[TestEffect, RequestBodyType] {
override val streams: Streams[RequestBodyType] = null.asInstanceOf[Streams[RequestBodyType]]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): TestEffect[RawValue[R]] = ???
override def toStream(serverRequest: ServerRequest): streams.BinaryStream = ???
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
}

private val exampleToResponse: ToResponseBody[ResponseBodyType, RequestBodyType] = new ToResponseBody[ResponseBodyType, RequestBodyType] {
Expand Down
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object Versions {
val helidon = "4.0.0"
val sttp = "3.9.1"
val sttpModel = "1.7.6"
val sttpShared = "1.3.16"
val sttpShared = "1.3.17"
val sttpApispec = "0.7.2"
val akkaHttp = "10.2.10"
val akkaStreams = "2.6.20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private[akkagrpc] class AkkaGrpcRequestBody(serverOptions: AkkaHttpServerOptions
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] =
toRawFromEntity(request, akkaRequestEntity(request), bodyType)

override def toStream(request: ServerRequest): streams.BinaryStream = ???
override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???

private def akkaRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im
override val streams: AkkaStreams = AkkaStreams
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] =
toRawFromEntity(request, akkeRequestEntity(request), bodyType)
override def toStream(request: ServerRequest): streams.BinaryStream = akkeRequestEntity(request).dataBytes
override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val stream = akkeRequestEntity(request).dataBytes
maxBytes.map(AkkaStreams.limitBytes(stream, _)).getOrElse(stream)
}

private def akkeRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,11 @@ class AkkaHttpServerTest extends TestSuite with EitherValues {
}
}
)
def drainAkka(stream: AkkaStreams.BinaryStream): Future[Unit] =
stream.runWith(Sink.ignore).map(_ => ())

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, AkkaStreams).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(AkkaStreams)(drainAkka) ++
new ServerWebSocketTests(createServerTest, AkkaStreams) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f)
override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ private object Fs2StreamCompatible {
dispatcher
)

override def fromArmeriaStream(publisher: Publisher[HttpData]): Stream[F, Byte] =
publisher.toStreamBuffered[F](4).flatMap(httpData => Stream.chunk(Chunk.array(httpData.array())))
override def fromArmeriaStream(publisher: Publisher[HttpData], maxBytes: Option[Long]): Stream[F, Byte] = {
val stream = publisher.toStreamBuffered[F](4).flatMap(httpData => Stream.chunk(Chunk.array(httpData.array())))
maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ class ArmeriaCatsServerTest extends TestSuite {

val interpreter = new ArmeriaCatsTestServerInterpreter(dispatcher)
val createServerTest = new DefaultCreateServerTest(backend, interpreter)
def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests()
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ private[armeria] final class ArmeriaRequestBody[F[_], S <: Streams[S]](

override val streams: Streams[S] = streamCompatible.streams

override def toStream(serverRequest: ServerRequest): streams.BinaryStream = {
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
streamCompatible
.fromArmeriaStream(armeriaCtx(serverRequest).request().filter(x => x.isInstanceOf[HttpData]).asInstanceOf[StreamMessage[HttpData]])
.fromArmeriaStream(armeriaCtx(serverRequest).request().filter(x => x.isInstanceOf[HttpData]).asInstanceOf[StreamMessage[HttpData]], maxBytes)
.asInstanceOf[streams.BinaryStream]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import sttp.capabilities.Streams
private[armeria] trait StreamCompatible[S <: Streams[S]] {
val streams: S
def asStreamMessage(s: streams.BinaryStream): Publisher[HttpData]
def fromArmeriaStream(s: Publisher[HttpData]): streams.BinaryStream
def fromArmeriaStream(s: Publisher[HttpData], maxBytes: Option[Long]): streams.BinaryStream
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private[armeria] final case class TapirFutureService(
private object ArmeriaStreamCompatible extends StreamCompatible[ArmeriaStreams] {
override val streams: ArmeriaStreams = ArmeriaStreams

override def fromArmeriaStream(s: Publisher[HttpData]): Publisher[HttpData] = s
override def fromArmeriaStream(s: Publisher[HttpData], maxBytes: Option[Long]): Publisher[HttpData] = s

override def asStreamMessage(s: Publisher[HttpData]): Publisher[HttpData] = s
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import sttp.capabilities.armeria.ArmeriaStreams
import sttp.monad.FutureMonad
import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import scala.concurrent.Future

class ArmeriaFutureServerTest extends TestSuite {

Expand All @@ -16,6 +17,6 @@ class ArmeriaFutureServerTest extends TestSuite {

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new ServerStreamingTests(createServerTest, ArmeriaStreams).tests()
new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ private object ZioStreamCompatible {
.getOrThrowFiberFailure()
)

override def fromArmeriaStream(publisher: Publisher[HttpData]): Stream[Throwable, Byte] =
publisher.toZIOStream().mapConcatChunk(httpData => Chunk.fromArray(httpData.array()))
override def fromArmeriaStream(publisher: Publisher[HttpData], maxBytes: Option[Long]): Stream[Throwable, Byte] = {
val stream = publisher.toZIOStream().mapConcatChunk(httpData => Chunk.fromArray(httpData.array()))
maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import sttp.tapir.ztapir.RIOMonadError
import zio.Task
import zio.stream.ZSink

class ArmeriaZioServerTest extends TestSuite {

Expand All @@ -16,9 +17,11 @@ class ArmeriaZioServerTest extends TestSuite {

val interpreter = new ArmeriaZioTestServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests()
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ private object ZioStreamCompatible {
override def asStreamMessage(stream: Stream[Throwable, Byte]): Publisher[HttpData] =
runtime.unsafeRun(stream.mapChunks(c => Chunk.single(HttpData.wrap(c.toArray))).toPublisher)

override def fromArmeriaStream(publisher: Publisher[HttpData]): Stream[Throwable, Byte] =
override def fromArmeriaStream(publisher: Publisher[HttpData], maxBytes: Option[Long]): Stream[Throwable, Byte] =
publisher.toStream().mapConcatChunk(httpData => Chunk.fromArray(httpData.array()))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ class ArmeriaZioServerTest extends TestSuite {

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests()
new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit)
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sttp.tapir.server.interceptor.exception

import sttp.capabilities.StreamMaxLengthExceededException
import sttp.model.StatusCode
import sttp.monad.MonadError
import sttp.tapir.server.model.ValuedEndpointOutput
Expand All @@ -25,7 +26,12 @@ object ExceptionHandler {

case class DefaultExceptionHandler[F[_]](response: (StatusCode, String) => ValuedEndpointOutput[_]) extends ExceptionHandler[F] {
override def apply(ctx: ExceptionContext)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] =
monad.unit(Some(response(StatusCode.InternalServerError, "Internal server error")))
ctx.e match {
case StreamMaxLengthExceededException(maxBytes) =>
monad.unit(Some(response(StatusCode.PayloadTooLarge, s"Payload limit (${maxBytes}B) exceeded")))
case _ =>
monad.unit(Some(response(StatusCode.InternalServerError, "Internal server error")))
}
}

object DefaultExceptionHandler {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package sttp.tapir.server.interpreter
import sttp.capabilities.Streams
import sttp.model.Part
import sttp.tapir.model.ServerRequest
import sttp.tapir.AttributeKey
import sttp.tapir.EndpointInfo
import sttp.tapir.{FileRange, RawBodyType, RawPart}

case class MaxContentLength(value: Long)

trait RequestBody[F[_], S] {
val streams: Streams[S]
def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]]
def toStream(serverRequest: ServerRequest): streams.BinaryStream
def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream

}

case class RawValue[R](value: R, createdFiles: Seq[FileRange] = Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import sttp.tapir.server.{model, _}
import sttp.tapir.server.interceptor._
import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput}
import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile}
import sttp.tapir.EndpointInfo
import sttp.tapir.AttributeKey

class ServerInterpreter[R, F[_], B, S](
serverEndpoints: ServerRequest => List[ServerEndpoint[R, F]],
Expand Down Expand Up @@ -106,7 +108,7 @@ class ServerInterpreter[R, F[_], B, S](
// index (so that the correct one is passed to the decode failure handler)
_ <- resultOrValueFrom(DecodeBasicInputsResult.higherPriorityFailure(securityBasicInputs, regularBasicInputs))
// 3. computing the security input value
securityValues <- resultOrValueFrom(decodeBody(request, securityBasicInputs))
securityValues <- resultOrValueFrom(decodeBody(request, securityBasicInputs, se.info))
securityParams <- resultOrValueFrom(InputValue(se.endpoint.securityInput, securityValues))
inputValues <- resultOrValueFrom(regularBasicInputs)
a = securityParams.asAny.asInstanceOf[A]
Expand All @@ -132,7 +134,7 @@ class ServerInterpreter[R, F[_], B, S](
case Right(u) =>
for {
// 5. decoding the body of regular inputs, computing the input value, and running the main logic
values <- resultOrValueFrom(decodeBody(request, inputValues))
values <- resultOrValueFrom(decodeBody(request, inputValues, se.endpoint.info))
params <- resultOrValueFrom(InputValue(se.endpoint.input, values))
response <- resultOrValueFrom.value(
endpointHandler(defaultSecurityFailureResponse, endpointInterceptors)
Expand All @@ -146,29 +148,33 @@ class ServerInterpreter[R, F[_], B, S](

private def decodeBody(
request: ServerRequest,
result: DecodeBasicInputsResult
result: DecodeBasicInputsResult,
endpointInfo: EndpointInfo
): F[DecodeBasicInputsResult] =
result match {
case values: DecodeBasicInputsResult.Values =>
val maxBodyLength = endpointInfo.attribute(AttributeKey[MaxContentLength]).map(_.value)
values.bodyInputWithIndex match {
case Some((Left(oneOfBodyInput), _)) =>
oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match {
case Some(Left(body)) => decodeBody(request, values, body)
case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body)
case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body, maxBodyLength)
case None => unsupportedInputMediaTypeResponse(request, oneOfBodyInput)
}
case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) => decodeStreamingBody(request, values, bodyInput)
case None => (values: DecodeBasicInputsResult).unit
case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) =>
decodeStreamingBody(request, values, bodyInput, maxBodyLength)
case None => (values: DecodeBasicInputsResult).unit
}
case failure: DecodeBasicInputsResult.Failure => (failure: DecodeBasicInputsResult).unit
}

private def decodeStreamingBody(
request: ServerRequest,
values: DecodeBasicInputsResult.Values,
bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]
bodyInput: EndpointIO.StreamBodyWrapper[Any, Any],
maxBodyLength: Option[Long]
): F[DecodeBasicInputsResult] =
(bodyInput.codec.decode(requestBody.toStream(request)) match {
(bodyInput.codec.decode(requestBody.toStream(request, maxBodyLength)) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}).unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object TestUtil {
object TestRequestBody extends RequestBody[Id, NoStreams] {
override val streams: Streams[NoStreams] = NoStreams
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Id[RawValue[R]] = ???
override def toStream(serverRequest: ServerRequest): streams.BinaryStream = ???
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = ???
}

object UnitToResponseBody extends ToResponseBody[Unit, NoStreams] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class FinatraRequestBody(serverOptions: FinatraServerOptions) extends RequestBod
.map(_.toList)
}

override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException()
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream =
throw new UnsupportedOperationException()

private def finatraRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ private[http4s] class Http4sRequestBody[F[_]: Async](
val r = http4sRequest(serverRequest)
toRawFromStream(serverRequest, r.body, bodyType, r.charset)
}
override def toStream(serverRequest: ServerRequest): streams.BinaryStream = http4sRequest(serverRequest).body
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val stream = http4sRequest(serverRequest).body
maxBytes.map(Fs2Streams.limitBytes(stream, _)).getOrElse(stream)
}

private def http4sRequest(serverRequest: ServerRequest): Request[F] = serverRequest.underlying.asInstanceOf[Request[F]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi
}
)

def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++
new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import sttp.tapir.server.http4s.Http4sServerSentEvents
import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import zio.interop.catz._
import zio.stream.ZStream
import zio.stream.{ZSink, ZStream}
import zio.{Task, ZIO}

import java.util.UUID
Expand Down Expand Up @@ -50,9 +50,11 @@ class ZHttp4sServerTest extends TestSuite with OptionValues {
.map(_.body.toOption.value shouldBe List(sse1, sse2))
}
)
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import sttp.tapir._
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import zio.{RIO, UIO}
import zio.{RIO, Task, UIO}
import zio.blocking.Blocking
import zio.clock.Clock
import zio.interop.catz._
Expand Down Expand Up @@ -53,7 +53,7 @@ class ZHttp4sServerTest extends TestSuite with OptionValues {
)

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit) ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => zio.stream.Stream.empty
Expand Down
Loading
Loading