Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multipart body support in sttp stub #4117

Merged
merged 9 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ import sttp.tapir.RawBodyType
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, InputStream}
import java.nio.ByteBuffer
import scala.annotation.tailrec
import sttp.client3
import sttp.model.Part
import sttp.model.MediaType
import sttp.tapir.FileRange

class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, AnyStreams] {
override val streams: AnyStreams = AnyStreams
Expand All @@ -26,7 +30,12 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
case RawBodyType.InputStreamRangeBody => ME.unit(RawValue(InputStreamRange(() => new ByteArrayInputStream(bytes))))
case _: RawBodyType.MultipartBody => ME.error(new UnsupportedOperationException)
}
case _ => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type")
case Right(value) =>
bodyType match {
case mp: RawBodyType.MultipartBody =>
ME.unit(RawValue(extractMultipartParts(value.asInstanceOf[Seq[Part[client3.RequestBody[_]]]], mp)))
case _ => throw new IllegalArgumentException("Stream body provided while endpoint accepts raw body type")
}
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = body(serverRequest) match {
Expand All @@ -36,7 +45,6 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A

private def sttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request[_, _]]

/** Either bytes or any stream */
private def body(serverRequest: ServerRequest): Either[Array[Byte], Any] = sttpRequest(serverRequest).body match {
case NoBody => Left(Array.emptyByteArray)
case StringBody(s, encoding, _) => Left(s.getBytes(encoding))
Expand All @@ -45,8 +53,7 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
case InputStreamBody(b, _) => Left(toByteArray(b))
case FileBody(f, _) => Left(f.readAsByteArray)
case StreamBody(s) => Right(s)
case MultipartBody(_) =>
throw new IllegalArgumentException("Stub cannot handle multipart bodies")
case MultipartBody(parts) => Right(parts)
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
}

private def toByteArray(is: InputStream): Array[Byte] = {
Expand All @@ -66,4 +73,52 @@ class SttpRequestBody[F[_]](implicit ME: MonadError[F]) extends RequestBody[F, A
transfer()
os.toByteArray
}

private def extractMultipartParts(parts: Seq[Part[client3.RequestBody[_]]], bodyType: RawBodyType.MultipartBody): List[Part[Any]] = {
parts.flatMap { part =>
bodyType.partType(part.name).flatMap { partType =>
extractPartBody(part, partType).map { body =>
Part(
name = part.name,
body = body,
contentType = part.contentType.flatMap(ct => MediaType.parse(ct).toOption),
fileName = part.fileName
)
}
}
}.toList
}

private def extractPartBody[B](part: Part[client3.RequestBody[_]], bodyType: RawBodyType[B]): Option[Any] = {
part.body match {
case ByteArrayBody(b, _) =>
bodyType match {
case RawBodyType.StringBody(charset) => Some(b)
case RawBodyType.ByteArrayBody => Some(b)
case RawBodyType.ByteBufferBody => Some(ByteBuffer.wrap(b))
case RawBodyType.InputStreamBody => Some(new ByteArrayInputStream(b))
case RawBodyType.InputStreamRangeBody => Some(InputStreamRange(() => new ByteArrayInputStream(b)))
case RawBodyType.FileBody => None
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
case _: RawBodyType.MultipartBody => None
}
case FileBody(f, _) =>
bodyType match {
case RawBodyType.FileBody => Some(FileRange(new File(f.toString)))
case _ => None
}
case StringBody(s, charset, _) =>
bodyType match {
case RawBodyType.StringBody(_) => Some(s)
case RawBodyType.ByteArrayBody => Some(s.getBytes(charset))
case RawBodyType.ByteBufferBody => Some(ByteBuffer.wrap(s.getBytes(charset)))
case _ => None
}
case InputStreamBody(is, _) =>
bodyType match {
case RawBodyType.InputStreamBody => Some(is)
case _ => None
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
}
case _ => None
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import sttp.tapir.server.interceptor.exception.ExceptionHandler
import sttp.tapir.server.interceptor.reject.RejectHandler
import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor}
import sttp.tapir.server.model.ValuedEndpointOutput
import sttp.tapir.generic.auto._

class TapirStubInterpreterTest extends AnyFlatSpec with Matchers {

Expand Down Expand Up @@ -204,8 +205,90 @@ class TapirStubInterpreterTest extends AnyFlatSpec with Matchers {
response.body shouldBe Left("Internal server error")
response.code shouldBe StatusCode.InternalServerError
}

it should "handle multipart body" in {
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody)
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenEndpoint(e)
.thenRespond("success")
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(multipart("name", "abc"))
.send(server)

// then
response.body shouldBe Right("success")
}

it should "correctly process a multipart body" in {
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody)
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
.whenServerEndpointRunLogic(e.serverLogic(multipartData => {
val partOpt = multipartData.find(_.name == "name")
partOpt match {
case Some(part) =>
val data = new String(part.body)
IdMonad.unit(Right("Hello " + data))
case None =>
IdMonad.unit(Right("Part not found"))
}
}))
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(multipart("name", "abc"))
.send(server)

// then
response.body shouldBe Right("Hello abc")
}

it should "correctly handle derived multipart body" in {
// given
val e =
endpoint.post
.in("api" / "multipart")
.in(multipartBody[MultipartData])
.out(stringBody)

val server = TapirStubInterpreter(SttpBackendStub(IdMonad))
.whenServerEndpointRunLogic(e.serverLogic(multipartData => {
IdMonad.unit(Right("Hello " + multipartData.name))
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved
}))
.backend()

// when
val response = sttp.client3.basicRequest
.post(uri"http://test.com/api/multipart")
.multipartBody(
multipart("name", "abc")
)
.send(server)

// then
response.body shouldBe Right("Hello abc")
}
}

case class MultipartData(name: String)
abdelfetah18 marked this conversation as resolved.
Show resolved Hide resolved

object ProductsApi {

val getProduct: Endpoint[Unit, Unit, String, String, Any] = endpoint.get
Expand Down
Loading