Skip to content

Commit

Permalink
Extract payload signature in it's own middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
daddykotex committed Apr 19, 2024
1 parent 971ec28 commit 295afb7
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2021-2024 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package smithy4s.aws
package internals

import cats.effect.Concurrent
import cats.effect.Resource
import cats.syntax.all._
import fs2.Chunk
import org.http4s._
import org.http4s.client.Client
import org.typelevel.ci.CIString
import smithy4s._
import smithy4s.aws.kernel.AwsCrypto._

private[aws] sealed trait AwsPayloadSignature {
import AwsPayloadSignature._
val headerValue: String = this match {
case Sha256(v) => v
case UnsignedPayload => "UNSIGNED-PAYLOAD"
// case StreamingUnsignedPayload => "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
}
}

/**
* This is a draft API. There are many other ways to include the payload in the signature.
* Some of which are complex: using trailers and/or multiple chunks
*/
private[aws] object AwsPayloadSignature {
case class Sha256(value: String) extends AwsPayloadSignature
case object UnsignedPayload extends AwsPayloadSignature
// case object StreamingUnsignedPayload extends AwsPayloadSignature

val `X-Amz-Content-SHA256` = CIString("X-Amz-Content-SHA256")

def makeHeader(value: AwsPayloadSignature): Header.Raw =
Header.Raw(`X-Amz-Content-SHA256`, value.headerValue)


def signSingleChunk[F[_]: Concurrent]: Endpoint.Middleware[Client[F]] =
new Endpoint.Middleware[Client[F]] {
def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])(
endpoint: service.Endpoint[_, _, _, _, _]
): Client[F] => Client[F] = { client =>
Client { request =>
Resource.eval(hashSingleChunk(request)).flatMap { request =>
client.run(request)
}
}
}
}

private def hashSingleChunk[F[_]: Concurrent](
request: Request[F]
): F[Request[F]] = {
request.body.chunks.compile.to(Chunk).map(_.flatten).map { body =>
val payloadHash = sha256HexDigest(body.toArray)
val signature = AwsPayloadSignature.Sha256(payloadHash)
request.putHeaders(AwsPayloadSignature.makeHeader(signature))
}
}
}
28 changes: 17 additions & 11 deletions modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ package internals
import cats.effect.Concurrent
import cats.effect.Resource
import cats.syntax.all._
import fs2.Chunk
import org.http4s._
import org.http4s.client.Client
import org.typelevel.ci.CIString
import smithy4s._
import smithy4s.aws.kernel.AwsCrypto._
import smithy4s.aws.internals.AwsPayloadSignature.`X-Amz-Content-SHA256`

import java.net.URLEncoder
import java.nio.charset.StandardCharsets

/**
* A Client middleware that signs http requests before they are sent to AWS.
* This works by compiling the body of the request in memory in a chunk before sending
* it back, which means it is not proper to use it in the context of streaming.
*/
private[aws] object AwsSigning {

Expand Down Expand Up @@ -108,8 +106,7 @@ private[aws] object AwsSigning {
// scalafmt: { align.preset = most, danglingParentheses.preset = false, maxColumn = 240, align.tokens = [{code = ":"}]}
(request: Request[F]) => {

val bodyF = request.body.chunks.compile.to(Chunk).map(_.flatten)
val awsHeadersF = (bodyF, timestamp, credentials, region).mapN { case (body, timestamp, credentials, region) =>
val awsHeadersF = (timestamp, credentials, region).mapN { case (timestamp, credentials, region) =>
val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request"
val queryParams: Vector[(String, String)] =
request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") }
Expand All @@ -122,23 +119,32 @@ private[aws] object AwsSigning {
}
.mkString("&")

// // !\ Important: these must remain in the same order
val baseHeadersList = List(
val amzHeaders: List[(CIString, String)] = request.headers.headers
.map(h => (h.name, h.value))
.filterNot(_._2 == null)

// It is assumed that the hash value is computed before this middleware run
// via another middleware. If it is not, we use a default unsigned value
val payloadHash = amzHeaders.find(_._1 == `X-Amz-Content-SHA256`).map(_._2).getOrElse(AwsPayloadSignature.UnsignedPayload.headerValue)

val addedHeaders: List[(CIString, String)] = List(
`Content-Type` -> request.contentType.map(contentType.value(_)).orNull,
`Host` -> request.uri.host.map(_.renderString).orNull,
`X-Amz-Date` -> timestamp.conciseDateTime,
`X-Amz-Security-Token` -> credentials.sessionToken.orNull,
`X-Amz-Target` -> (serviceName + "." + operationName)
).filterNot(_._2 == null)

val canonicalHeadersString = baseHeadersList
// Headers included in the signature needs to be sorted alphabetically
val allHeaders = (addedHeaders ++ amzHeaders).sortBy(_._1)

val canonicalHeadersString = allHeaders
.map { case (key, value) =>
key.toString.toLowerCase + ":" + value.trim
}
.mkString(newline)
lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";")
lazy val signedHeadersString = allHeaders.map(_._1).map(_.toString.toLowerCase()).mkString(";")

val payloadHash = sha256HexDigest(body.toArray)
val pathString = request.uri.path.toAbsolute.renderString
val canonicalRequest = new StringBuilder()
.append(request.method.name.toUpperCase())
Expand Down Expand Up @@ -171,7 +177,7 @@ private[aws] object AwsSigning {
val signature = toHexString(hmacSha256(stringToSign, signatureKey))
val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature"
val authHeader = Headers("Authorization" -> authHeaderValue)
val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) })
val baseHeaders = Headers(addedHeaders.map { case (k, v) => Header.Raw(k, v) })
authHeader ++ baseHeaders
}

Expand Down

0 comments on commit 295afb7

Please sign in to comment.