S3: Add multipart upload with context (#2770)
mdedetrich authored Nov 29, 2021
1 parent 5af62b2 commit 94f239a
Showing 11 changed files with 1,103 additions and 54 deletions.
@@ -14,7 +14,7 @@ import akka.http.scaladsl.model.Uri.{Authority, Query}
import akka.http.scaladsl.model.headers.{`Raw-Request-URI`, Host, RawHeader}
import akka.http.scaladsl.model.{RequestEntity, _}
import akka.stream.alpakka.s3.AccessStyle.{PathAccessStyle, VirtualHostAccessStyle}
import akka.stream.alpakka.s3.{ApiVersion, S3Settings}
import akka.stream.alpakka.s3.{ApiVersion, MultipartUpload, S3Settings}
import akka.stream.scaladsl.Source
import akka.util.ByteString
import software.amazon.awssdk.regions.Region
@@ -200,7 +200,7 @@ import scala.concurrent.{ExecutionContext, Future}
payload: Chunk,
s3Headers: Seq[HttpHeader] = Seq.empty)(implicit conf: S3Settings): HttpRequest =
s3Request(
upload.s3Location,
S3Location(upload.bucket, upload.key),
HttpMethods.PUT,
_.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
).withDefaultHeaders(s3Headers)
@@ -224,7 +224,7 @@ import scala.concurrent.{ExecutionContext, Future}
entity <- Marshal(payload).to[RequestEntity]
} yield {
s3Request(
upload.s3Location,
S3Location(upload.bucket, upload.key),
HttpMethods.POST,
_.withQuery(Query("uploadId" -> upload.uploadId))
).withEntity(entity).withDefaultHeaders(headers)
@@ -261,7 +261,7 @@ import scala.concurrent.{ExecutionContext, Future}

val allHeaders = s3Headers ++ copyHeaders

s3Request(upload.s3Location,
s3Request(S3Location(upload.bucket, upload.key),
HttpMethods.PUT,
_.withQuery(Query("partNumber" -> copyPartition.partNumber.toString, "uploadId" -> upload.uploadId)))
.withDefaultHeaders(allHeaders)
@@ -26,7 +26,7 @@ import scala.xml.NodeSeq
nodeSeqUnmarshaller(MediaTypes.`application/xml`, ContentTypes.`application/octet-stream`) map {
case NodeSeq.Empty => throw Unmarshaller.NoContentException
case x =>
MultipartUpload(S3Location((x \ "Bucket").text, (x \ "Key").text), (x \ "UploadId").text)
MultipartUpload((x \ "Bucket").text, (x \ "Key").text, (x \ "UploadId").text)
}
}
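For context, the unmarshaller above parses S3's InitiateMultipartUploadResult response; with this change the Bucket and Key elements feed the flattened MultipartUpload(bucket, key, uploadId) directly instead of going through S3Location. Below is a standalone sketch of the same extraction using only scala.xml; the sample XML values and the local MultipartUpload stand-in are illustrative, not taken from this repository.

import scala.xml.XML

object InitiateUploadXmlSketch extends App {
  // Shape of S3's InitiateMultipartUploadResult response (sample values).
  val body =
    """<InitiateMultipartUploadResult>
      |  <Bucket>my-bucket</Bucket>
      |  <Key>my/key.txt</Key>
      |  <UploadId>VXBsb2FkIElE</UploadId>
      |</InitiateMultipartUploadResult>""".stripMargin

  // Stand-in for the (now flattened) model class used by the connector.
  final case class MultipartUpload(bucket: String, key: String, uploadId: String)

  val x = XML.loadString(body)
  val upload = MultipartUpload((x \ "Bucket").text, (x \ "Key").text, (x \ "UploadId").text)
  println(upload) // MultipartUpload(my-bucket,my/key.txt,VXBsb2FkIElE)
}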

@@ -0,0 +1,63 @@
/*
* Copyright (C) since 2016 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.stream.alpakka.s3.impl

import akka.annotation.InternalApi
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.util.ByteString

import scala.collection.immutable
import scala.collection.mutable.ListBuffer

/**
* Internal Api
*
* Buffers the complete incoming stream into memory, which can then be read several times afterwards.
*
* The stage waits for the incoming stream of elements paired with a context to complete. After that, it emits a single
* Chunk item on its output together with all of the contexts collected along the way. The Chunk contains a `ByteString`
* source that can be materialized multiple times, and the total size of the buffered data.
*
* @param maxSize Maximum size to buffer
*/
@InternalApi private[impl] final class MemoryWithContext[C](maxSize: Int)
extends GraphStage[FlowShape[(ByteString, C), (Chunk, immutable.Iterable[C])]] {
val in = Inlet[(ByteString, C)]("MemoryBuffer.in")
val out = Outlet[(Chunk, immutable.Iterable[C])]("MemoryBuffer.out")
override val shape = FlowShape.of(in, out)

override def createLogic(attr: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private var buffer = ByteString.empty
private val contextBuffer: ListBuffer[C] = new ListBuffer[C]()

override def onPull(): Unit = if (isClosed(in)) emit() else pull(in)

override def onPush(): Unit = {
val (elem, context) = grab(in)
if (buffer.size + elem.size > maxSize) {
failStage(new IllegalStateException("Buffer size of " + maxSize + " bytes exceeded."))
} else {
buffer ++= elem
// This is a corner case where the context can have a sentinel value of null, which represents the initial empty
// stream. We don't want to add nulls to the final output.
if (context != null)
contextBuffer.append(context)
pull(in)
}
}

override def onUpstreamFinish(): Unit = {
if (isAvailable(out)) emit()
completeStage()
}

private def emit(): Unit = emit(out, (MemoryChunk(buffer), contextBuffer.toList), () => completeStage())

setHandlers(in, out, this)
}

}
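A minimal sketch of how this new stage behaves. The stage is private[impl], so this assumes test code living in the akka.stream.alpakka.s3.impl package; Chunk's size accessor is taken from the internal chunk types used elsewhere in this commit. Three tagged ByteStrings go in, and a single buffered chunk plus all three contexts comes out.

package akka.stream.alpakka.s3.impl

import akka.actor.ActorSystem
import akka.stream.scaladsl.{Flow, Sink, Source}
import akka.util.ByteString

import scala.concurrent.Await
import scala.concurrent.duration._

object MemoryWithContextSketch extends App {
  implicit val system: ActorSystem = ActorSystem("memory-with-context-sketch")

  // Feed (payload, context) pairs through the buffering stage and take the single output element.
  val result = Source(
    List((ByteString("a"), 1), (ByteString("bc"), 2), (ByteString("def"), 3))
  ).via(Flow.fromGraph(new MemoryWithContext[Int](maxSize = 1024)))
    .runWith(Sink.head)

  val (chunk, contexts) = Await.result(result, 3.seconds)
  println(s"buffered ${chunk.size} bytes with contexts $contexts") // buffered 6 bytes with contexts List(1, 2, 3)

  system.terminate()
}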
219 changes: 177 additions & 42 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -41,28 +41,6 @@ import scala.util.{Failure, Success, Try}
}
}

/** Internal Api */
@InternalApi private[impl] final case class MultipartUpload(s3Location: S3Location, uploadId: String)

/** Internal Api */
@InternalApi private[impl] sealed trait UploadPartResponse {
def multipartUpload: MultipartUpload

def index: Int
}

/** Internal Api */
@InternalApi private[impl] final case class SuccessfulUploadPart(multipartUpload: MultipartUpload,
index: Int,
eTag: String)
extends UploadPartResponse

/** Internal Api */
@InternalApi private[impl] final case class FailedUploadPart(multipartUpload: MultipartUpload,
index: Int,
exception: Throwable)
extends UploadPartResponse

/** Internal Api */
@InternalApi private[impl] final case class CompleteMultipartUploadResult(location: Uri,
bucket: String,
@@ -856,6 +834,21 @@ import scala.util.{Failure, Success, Try}
chunkAndRequest(s3Location, contentType, s3Headers, chunkSize)(chunkingParallelism)
.toMat(completionSink(s3Location, s3Headers.serverSideEncryption))(Keep.right)

/**
* Uploads a stream of ByteStrings paired with a context to a specified location as a multipart upload. The
* chunkUploadSink parameter allows you to act on the contexts of the elements in each chunk once that chunk has been
* uploaded to S3.
*/
def multipartUploadWithContext[C](
s3Location: S3Location,
chunkUploadSink: Sink[(UploadPartResponse, immutable.Iterable[C]), NotUsed],
contentType: ContentType = ContentTypes.`application/octet-stream`,
s3Headers: S3Headers,
chunkSize: Int = MinChunkSize,
chunkingParallelism: Int = 4
): Sink[(ByteString, C), Future[MultipartUploadResult]] =
chunkAndRequestWithContext[C](s3Location, contentType, s3Headers, chunkSize, chunkUploadSink)(chunkingParallelism)
.toMat(completionSink(s3Location, s3Headers.serverSideEncryption))(Keep.right)
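A caller-side sketch of the new API. The scaladsl entry point `S3.multipartUploadWithContext`, its parameter order, and the `akka.stream.alpakka.s3.UploadPartResponse` model import are assumptions based on the internal signature above (the public wrapper itself is not shown in this diff); bucket, key, and context values are sample data.

import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.alpakka.s3.{MultipartUploadResult, UploadPartResponse}
import akka.stream.alpakka.s3.scaladsl.S3
import akka.stream.scaladsl.{Flow, Sink, Source}
import akka.util.ByteString

import scala.collection.immutable
import scala.concurrent.Future

object MultipartUploadWithContextSketch extends App {
  implicit val system: ActorSystem = ActorSystem("s3-context-sketch")

  // Invoked once per uploaded part with the contexts of every element that ended up in that
  // part; this is the hook for committing offsets/acks for data that is already in S3.
  // Note the NotUsed materialized value required by the signature above.
  val chunkUploadSink: Sink[(UploadPartResponse, immutable.Iterable[Long]), NotUsed] =
    Flow[(UploadPartResponse, immutable.Iterable[Long])]
      .map { case (_, contexts) => println(s"part uploaded, covering contexts $contexts") }
      .to(Sink.ignore)

  val upload: Future[MultipartUploadResult] =
    Source(1L to 100000L)
      .map(i => (ByteString(s"line $i\n"), i)) // payload paired with its context
      .runWith(S3.multipartUploadWithContext[Long]("my-bucket", "data/lines.txt", chunkUploadSink))
}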

/**
* Resumes a previously created multipart upload by uploading a stream of ByteStrings to a specified location
* and uploadId
@@ -869,21 +862,46 @@ import scala.util.{Failure, Success, Try}
chunkingParallelism: Int = 4): Sink[ByteString, Future[MultipartUploadResult]] = {
val initialUpload = Some((uploadId, previousParts.size + 1))
val successfulParts = previousParts.map { part =>
SuccessfulUploadPart(MultipartUpload(s3Location, uploadId), part.partNumber, part.eTag)
SuccessfulUploadPart(MultipartUpload(s3Location.bucket, s3Location.key, uploadId), part.partNumber, part.eTag)
}
chunkAndRequest(s3Location, contentType, s3Headers, chunkSize, initialUpload)(chunkingParallelism)
.prepend(Source(successfulParts))
.toMat(completionSink(s3Location, s3Headers.serverSideEncryption))(Keep.right)
}

/**
* Resumes a previously created multipart upload by uploading a stream of ByteStrings to a specified location
* and uploadId. The chunkUploadSink parameter allows you to act on the contexts of the elements in each chunk once
* that chunk has been uploaded to S3.
*/
def resumeMultipartUploadWithContext[C](
s3Location: S3Location,
uploadId: String,
previousParts: immutable.Iterable[Part],
chunkUploadSink: Sink[(UploadPartResponse, immutable.Iterable[C]), NotUsed],
contentType: ContentType = ContentTypes.`application/octet-stream`,
s3Headers: S3Headers,
chunkSize: Int = MinChunkSize,
chunkingParallelism: Int = 4
): Sink[(ByteString, C), Future[MultipartUploadResult]] = {
val initialUpload = Some((uploadId, previousParts.size + 1))
val successfulParts = previousParts.map { part =>
SuccessfulUploadPart(MultipartUpload(s3Location.bucket, s3Location.key, uploadId), part.partNumber, part.eTag)
}
chunkAndRequestWithContext[C](s3Location, contentType, s3Headers, chunkSize, chunkUploadSink, initialUpload)(
chunkingParallelism
).prepend(Source(successfulParts))
.toMat(completionSink(s3Location, s3Headers.serverSideEncryption))(Keep.right)
}
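Resuming with a context follows the same pattern; this fragment continues the sketch above. The wrapper name `S3.resumeMultipartUploadWithContext` and its parameter order mirror the internal method here but are assumptions, as are the sample uploadId and the placeholder for the already-uploaded parts (which would come from state saved earlier or from a list-parts call).

import akka.stream.alpakka.s3.Part

// The id returned when the upload was first initiated and the parts already uploaded
// under it; both are placeholders in this sketch.
val uploadId: String = "VXBsb2FkIElE"
val previousParts: immutable.Iterable[Part] = ???

val resumed: Future[MultipartUploadResult] =
  Source(50001L to 100000L)
    .map(i => (ByteString(s"line $i\n"), i))
    .runWith(
      S3.resumeMultipartUploadWithContext[Long]("my-bucket",
                                                "data/lines.txt",
                                                uploadId,
                                                previousParts,
                                                chunkUploadSink)
    )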

def completeMultipartUpload(
s3Location: S3Location,
uploadId: String,
parts: immutable.Iterable[Part],
s3Headers: S3Headers
)(implicit mat: Materializer, attr: Attributes): Future[MultipartUploadResult] = {
val successfulParts = parts.map { part =>
SuccessfulUploadPart(MultipartUpload(s3Location, uploadId), part.partNumber, part.eTag)
SuccessfulUploadPart(MultipartUpload(s3Location.bucket, s3Location.key, uploadId), part.partNumber, part.eTag)
}
Source(successfulParts)
.toMat(completionSink(s3Location, s3Headers.serverSideEncryption).withAttributes(attr))(Keep.right)
@@ -983,7 +1001,7 @@ import scala.util.{Failure, Success, Try}

Source
.future(
completeMultipartUploadRequest(parts.head.multipartUpload, parts.map(p => p.index -> p.eTag), headers)
completeMultipartUploadRequest(parts.head.multipartUpload, parts.map(p => p.partNumber -> p.eTag), headers)
)
.flatMapConcat(signAndGetAs[CompleteMultipartUploadResult](_, populateResult(_, _)))
.runWith(Sink.head)
@@ -1067,21 +1085,6 @@ import scala.util.{Failure, Success, Try}

val chunkBufferSize = chunkSize * 2

val requestInfoOrInitialUploadState = initialUploadState match {
case Some((uploadId, initialIndex)) =>
// We are resuming from a previously aborted Multipart upload so rather than creating a new MultipartUpload
// resource we just need to set up the initial state
Source
.single(s3Location)
.flatMapConcat(_ => Source.single(MultipartUpload(s3Location, uploadId)))
.mapConcat(r => Stream.continually(r))
.zip(Source.fromIterator(() => Iterator.from(initialIndex)))
case None =>
// First step of the multi part upload process is made.
// The response is then used to construct the subsequent individual upload part requests
initiateUpload(s3Location, contentType, s3Headers.headersFor(InitiateMultipartUpload))
}

val headers = s3Headers.serverSideEncryption.toIndexedSeq.flatMap(_.headersFor(UploadPart))

Flow
@@ -1125,7 +1128,7 @@ import scala.util.{Failure, Success, Try}
.mergeSubstreamsWithParallelism(parallelism)
.filter(_.size > 0)
.via(atLeastOne)
.zip(requestInfoOrInitialUploadState)
.zip(requestInfoOrUploadState(s3Location, contentType, s3Headers, initialUploadState))
.groupBy(parallelism, { case (_, (_, chunkIndex)) => chunkIndex % parallelism })
// Allow requests that fail with transient errors to be retried, using the already buffered chunk.
.via(RetryFlow.withBackoff(minBackoff, maxBackoff, randomFactor, maxRetries, retriableFlow) {
@@ -1149,6 +1152,138 @@ import scala.util.{Failure, Success, Try}
.mapMaterializedValue(_ => NotUsed)
}

private def chunkAndRequestWithContext[C](
s3Location: S3Location,
contentType: ContentType,
s3Headers: S3Headers,
chunkSize: Int,
chunkUploadSink: Sink[(UploadPartResponse, immutable.Iterable[C]), NotUsed],
initialUploadState: Option[(String, Int)] = None
)(parallelism: Int): Flow[(ByteString, C), UploadPartResponse, NotUsed] = {

// This part of the API doesn't support disk buffering because we have no way of serializing a C to a ByteString,
// so we only store the chunks in memory.
def getChunk(bufferSize: Int) =
new MemoryWithContext[C](bufferSize)

// Multipart upload requests (except for the completion API) are created here.
// The initiate-upload request is executed within this function as well, and its response is used to
// construct the individual upload-part requests.

assert(
chunkSize >= MinChunkSize,
s"Chunk size must be at least 5 MB = $MinChunkSize bytes (was $chunkSize bytes). See http://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html"
)

val chunkBufferSize = chunkSize * 2

val headers = s3Headers.serverSideEncryption.toIndexedSeq.flatMap(_.headersFor(UploadPart))

Flow
.fromMaterializer { (mat, attr) =>
implicit val conf: S3Settings = resolveSettings(attr, mat.system)
implicit val sys: ActorSystem = mat.system
implicit val materializer: Materializer = mat

// Emits a chunk if no chunks have been emitted, ensuring that we can upload empty files.
val atLeastOne =
Flow[(Chunk, immutable.Iterable[C])]
.prefixAndTail(1)
.flatMapConcat {
case (prefix, tail) =>
if (prefix.nonEmpty) {
Source(prefix).concat(tail)
} else {
Source.single((MemoryChunk(ByteString.empty), immutable.Iterable.empty))
}
}

val retriableFlow: Flow[((Chunk, (MultipartUpload, Int)), immutable.Iterable[C]),
((Try[HttpResponse], (MultipartUpload, Int)), immutable.Iterable[C]),
NotUsed] =
Flow[((Chunk, (MultipartUpload, Int)), immutable.Iterable[C])]
.map {
case ((chunkedPayload, (uploadInfo, chunkIndex)), allContext) =>
// each of the payload requests is created
val partRequest =
uploadPartRequest(uploadInfo, chunkIndex, chunkedPayload, headers)
((partRequest, (uploadInfo, chunkIndex)), allContext)
}
.flatMapConcat {
case ((req, info), allContext) =>
Signer.signedRequest(req, signingKey, conf.signAnonymousRequests).zip(Source.single(info)).map {
case (httpRequest, data) => (httpRequest, (data, allContext))
}
}
.via(superPool[((MultipartUpload, Int), immutable.Iterable[C])])
.map {
case (response, (info, allContext)) => ((response, info), allContext)
}

import conf.multipartUploadSettings.retrySettings._

val atLeastOneByteStringAndEmptyContext: Flow[(ByteString, C), (ByteString, C), NotUsed] =
Flow[(ByteString, C)].orElse(
Source.single((ByteString.empty, null.asInstanceOf[C]))
)

SplitAfterSizeWithContext(chunkSize)(atLeastOneByteStringAndEmptyContext)
.via(getChunk(chunkBufferSize))
.mergeSubstreamsWithParallelism(parallelism)
.filter { case (chunk, _) => chunk.size > 0 }
.via(atLeastOne)
.zip(requestInfoOrUploadState(s3Location, contentType, s3Headers, initialUploadState))
.groupBy(parallelism, { case (_, (_, chunkIndex)) => chunkIndex % parallelism })
.map {
case ((chunk, allContext), info) =>
((chunk, info), allContext)
}
// Allow requests that fail with transient errors to be retried, using the already buffered chunk.
.via(RetryFlow.withBackoff(minBackoff, maxBackoff, randomFactor, maxRetries, retriableFlow) {
case ((chunkAndUploadInfo, allContext), ((Success(r), _), _)) =>
if (isTransientError(r.status)) {
r.entity.discardBytes()
Some((chunkAndUploadInfo, allContext))
} else {
None
}
case ((chunkAndUploadInfo, allContext), ((Failure(_), _), _)) =>
// Treat any exception as transient.
Some((chunkAndUploadInfo, allContext))
})
.mapAsync(1) {
case ((response, (upload, index)), allContext) =>
handleChunkResponse(response, upload, index, conf.multipartUploadSettings.retrySettings).map { result =>
(result, allContext)
}(ExecutionContexts.parasitic)
}
.alsoTo(chunkUploadSink)
.map { case (result, _) => result }
.mergeSubstreamsWithParallelism(parallelism)
}
.mapMaterializedValue(_ => NotUsed)
}

private def requestInfoOrUploadState(s3Location: S3Location,
contentType: ContentType,
s3Headers: S3Headers,
initialUploadState: Option[(String, Int)]) = {
initialUploadState match {
case Some((uploadId, initialIndex)) =>
// We are resuming from a previously interrupted multipart upload, so rather than creating a new MultipartUpload
// resource we just need to set up the initial state
Source
.single(s3Location)
.flatMapConcat(_ => Source.single(MultipartUpload(s3Location.bucket, s3Location.key, uploadId)))
.mapConcat(r => Stream.continually(r))
.zip(Source.fromIterator(() => Iterator.from(initialIndex)))
case None =>
// The first step of the multipart upload process is performed here.
// The response is then used to construct the subsequent individual upload-part requests
initiateUpload(s3Location, contentType, s3Headers.headersFor(InitiateMultipartUpload))
}
}

private def handleChunkResponse(response: Try[HttpResponse],
upload: MultipartUpload,
index: Int,
@@ -1209,7 +1344,7 @@ import scala.util.{Failure, Success, Try}
if (responses.isEmpty) {
Future.failed(new RuntimeException("No Responses"))
} else if (failures.isEmpty) {
Future.successful(successes.sortBy(_.index))
Future.successful(successes.sortBy(_.partNumber))
} else {
Future.failed(FailedUpload(failures.map(_.exception)))
}