AWS S3: use strict http entity for in-memory chunks (#2703)
phiSgr authored Aug 19, 2021
1 parent 54f780e commit ceaa570
Showing 8 changed files with 36 additions and 17 deletions.
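
In short: the single internal `Chunk` case class becomes a sealed trait with two implementations. `DiskChunk` keeps streaming its data from a temporary file via a `Source[ByteString, NotUsed]`, while `MemoryChunk` holds a `ByteString` directly and renders as `HttpEntity.Strict`. A strict entity exposes its bytes and content length without materializing a stream, so requests built from in-memory chunks are cheaper to render and safely reusable. `uploadPartRequest` accordingly takes a `Chunk` instead of a separate payload source and size, and `DiskBuffer`, `MemoryBuffer`, `S3Stream`, and the affected specs are updated to match.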
15 changes: 14 additions & 1 deletion s3/src/main/scala/akka/stream/alpakka/s3/impl/Chunk.scala
@@ -7,9 +7,22 @@ package akka.stream.alpakka.s3.impl
 import akka.stream.scaladsl.Source
 import akka.NotUsed
 import akka.annotation.InternalApi
+import akka.http.scaladsl.model.{ContentTypes, HttpEntity, RequestEntity}
 import akka.util.ByteString
 
 /**
  * Internal Api
  */
-@InternalApi private[impl] final case class Chunk(data: Source[ByteString, NotUsed], size: Int)
+@InternalApi private[impl] sealed trait Chunk {
+  def asEntity(): RequestEntity
+  def size: Int
+}
+
+@InternalApi private[impl] final case class DiskChunk(data: Source[ByteString, NotUsed], size: Int) extends Chunk {
+  def asEntity(): RequestEntity = HttpEntity(ContentTypes.`application/octet-stream`, size, data)
+}
+
+@InternalApi private[impl] final case class MemoryChunk(data: ByteString) extends Chunk {
+  def asEntity(): RequestEntity = HttpEntity.Strict(ContentTypes.`application/octet-stream`, data)
+  def size: Int = data.size
+}
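
For context, a small self-contained sketch (not part of the commit) of the two entity shapes the new variants produce; the object name and payload bytes are invented for illustration, and only akka-http model classes already imported above are used:

    import akka.NotUsed
    import akka.http.scaladsl.model.{ContentTypes, HttpEntity, RequestEntity}
    import akka.stream.scaladsl.Source
    import akka.util.ByteString

    object ChunkEntitiesSketch extends App {
      val partPayload = ByteString("in-memory part payload")

      // MemoryChunk path: the bytes are already in memory, so the entity is
      // strict -- data and content length are available without materializing
      // a stream, and the request is cheap to render again, e.g. on retry.
      val strict: RequestEntity =
        HttpEntity.Strict(ContentTypes.`application/octet-stream`, partPayload)

      // DiskChunk path: the bytes live in a temporary file, so the entity
      // remains a streamed (default) entity with a declared content length.
      val fileData: Source[ByteString, NotUsed] = Source.single(partPayload)
      val streamed: RequestEntity =
        HttpEntity(ContentTypes.`application/octet-stream`, partPayload.length.toLong, fileData)

      println(strict.contentLengthOption)   // Some(22): length known from the bytes
      println(streamed.contentLengthOption) // Some(22): length was declared explicitly
    }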
2 changes: 1 addition & 1 deletion s3/src/main/scala/akka/stream/alpakka/s3/impl/DiskBuffer.scala
@@ -95,7 +95,7 @@ import akka.annotation.InternalApi
         }(ExecutionContexts.parasitic)
         NotUsed
       }
-      emit(out, Chunk(src, length), () => completeStage())
+      emit(out, DiskChunk(src, length), () => completeStage())
     }
     setHandlers(in, out, this)
   }
7 changes: 3 additions & 4 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/HttpRequests.scala
@@ -12,7 +12,7 @@ import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._
 import akka.http.scaladsl.marshalling.Marshal
 import akka.http.scaladsl.model.Uri.{Authority, Query}
 import akka.http.scaladsl.model.headers.{`Raw-Request-URI`, Host, RawHeader}
-import akka.http.scaladsl.model.{ContentTypes, RequestEntity, _}
+import akka.http.scaladsl.model.{RequestEntity, _}
 import akka.stream.alpakka.s3.AccessStyle.{PathAccessStyle, VirtualHostAccessStyle}
 import akka.stream.alpakka.s3.{ApiVersion, S3Settings}
 import akka.stream.scaladsl.Source
@@ -97,15 +97,14 @@ import scala.concurrent.{ExecutionContext, Future}
 
   def uploadPartRequest(upload: MultipartUpload,
                         partNumber: Int,
-                        payload: Source[ByteString, _],
-                        payloadSize: Int,
+                        payload: Chunk,
                         s3Headers: Seq[HttpHeader] = Seq.empty)(implicit conf: S3Settings): HttpRequest =
     s3Request(
       upload.s3Location,
       HttpMethods.PUT,
      _.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
     ).withDefaultHeaders(s3Headers)
-      .withEntity(HttpEntity(ContentTypes.`application/octet-stream`, payloadSize, payload))
+      .withEntity(payload.asEntity())
 
   def completeMultipartUploadRequest(upload: MultipartUpload, parts: Seq[(Int, String)], headers: Seq[HttpHeader])(
       implicit ec: ExecutionContext,
3 changes: 1 addition & 2 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/MemoryBuffer.scala
@@ -6,7 +6,6 @@ package akka.stream.alpakka.s3.impl
 
 import akka.annotation.InternalApi
 import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
-import akka.stream.scaladsl.Source
 import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
 import akka.util.ByteString
 
@@ -46,7 +45,7 @@ import akka.util.ByteString
       completeStage()
     }
 
-  private def emit(): Unit = emit(out, Chunk(Source.single(buffer), buffer.size), () => completeStage())
+  private def emit(): Unit = emit(out, MemoryChunk(buffer), () => completeStage())
 
   setHandlers(in, out, this)
 }
4 changes: 2 additions & 2 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -645,7 +645,7 @@ import scala.util.{Failure, Success, Try}
       if (prefix.nonEmpty) {
         Source(prefix).concat(tail)
       } else {
-        Source.single(Chunk(Source.empty, 0))
+        Source.single(MemoryChunk(ByteString.empty))
       }
     }
 
@@ -655,7 +655,7 @@ import scala.util.{Failure, Success, Try}
         case (chunkedPayload, (uploadInfo, chunkIndex)) =>
           //each of the payload requests are created
           val partRequest =
-            uploadPartRequest(uploadInfo, chunkIndex, chunkedPayload.data, chunkedPayload.size, headers)
+            uploadPartRequest(uploadInfo, chunkIndex, chunkedPayload, headers)
           (partRequest, (uploadInfo, chunkIndex))
       }
       .flatMapConcat { case (req, info) => Signer.signedRequest(req, signingKey).zip(Source.single(info)) }
12 changes: 9 additions & 3 deletions s3/src/test/scala/akka/stream/alpakka/s3/impl/DiskBufferSpec.scala
@@ -43,8 +43,12 @@ class DiskBufferSpec(_system: ActorSystem)
 
     result should have size (1)
     val chunk = result.head
+    chunk shouldBe a[DiskChunk]
+    val diskChunk = chunk.asInstanceOf[DiskChunk]
     chunk.size should be(14)
-    chunk.data.runWith(Sink.seq).futureValue should be(Seq(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)))
+    diskChunk.data.runWith(Sink.seq).futureValue should be(
+      Seq(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
+    )
   }
 
   it should "fail if more than maxSize bytes are fed into it" in {
@@ -63,12 +67,14 @@ class DiskBufferSpec(_system: ActorSystem)
   it should "delete its temp file after N materializations" in {
     val tmpDir = Files.createTempDirectory("DiskBufferSpec").toFile()
     val before = tmpDir.list().size
-    val source = Source(Vector(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)))
+    val chunk = Source(Vector(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)))
       .via(new DiskBuffer(2, 200, Some(tmpDir.toPath)))
       .runWith(Sink.seq)
       .futureValue
       .head
-      .data
 
+    chunk shouldBe a[DiskChunk]
+    val source = chunk.asInstanceOf[DiskChunk].data
+
     tmpDir.list().size should be(before + 1)
 
6 changes: 3 additions & 3 deletions s3/src/test/scala/akka/stream/alpakka/s3/impl/HttpRequestsSpec.scala
@@ -14,8 +14,8 @@ import akka.http.scaladsl.model.headers.{`Raw-Request-URI`, ByteRange, RawHeader}
 import akka.stream.alpakka.s3.headers.{CannedAcl, ServerSideEncryption, StorageClass}
 import akka.stream.alpakka.s3._
 import akka.stream.alpakka.testkit.scaladsl.LogCapturing
-import akka.stream.scaladsl.Source
 import akka.testkit.{SocketUtil, TestKit, TestProbe}
+import akka.util.ByteString
 import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures}
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
@@ -272,7 +272,7 @@ class HttpRequestsSpec extends AnyFlatSpec with Matchers with ScalaFutures with
     implicit val settings = getSettings(s3Region = Region.EU_WEST_1).withEndpointUrl("http://localhost:8080")
 
     val req =
-      HttpRequests.uploadPartRequest(multipartUpload, 1, Source.empty, 1)
+      HttpRequests.uploadPartRequest(multipartUpload, 1, MemoryChunk(ByteString.empty))
 
     req.uri.scheme shouldEqual "http"
     req.uri.authority.host.address shouldEqual "localhost"
@@ -284,7 +284,7 @@ class HttpRequestsSpec extends AnyFlatSpec with Matchers with ScalaFutures with
     val myKey = "my-key"
     val md5Key = "md5-key"
     val s3Headers = ServerSideEncryption.customerKeys(myKey).withMd5(md5Key).headersFor(UploadPart)
-    val req = HttpRequests.uploadPartRequest(multipartUpload, 1, Source.empty, 1, s3Headers)
+    val req = HttpRequests.uploadPartRequest(multipartUpload, 1, MemoryChunk(ByteString.empty), s3Headers)
 
     req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-algorithm", "AES256"))
     req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-key", myKey))
4 changes: 3 additions & 1 deletion s3/src/test/scala/akka/stream/alpakka/s3/impl/MemoryBufferSpec.scala
@@ -39,7 +39,9 @@ class MemoryBufferSpec(_system: ActorSystem)
     result should have size (1)
     val chunk = result.head
     chunk.size should be(14)
-    chunk.data.runWith(Sink.seq).futureValue should be(Seq(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)))
+
+    chunk shouldBe a[MemoryChunk]
+    chunk.asInstanceOf[MemoryChunk].data should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
   }
 
   it should "fail if more than maxSize bytes are fed into it" in {
