Skip to content

Commit

Permalink
Upgrade to SDK v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mchv committed Nov 19, 2024
1 parent 6fe190d commit 89295fa
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 81 deletions.
32 changes: 18 additions & 14 deletions anghammarad/src/main/scala/com/gu/anghammarad/Config.scala
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
package com.gu.anghammarad

import com.amazonaws.auth._
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.regions.Regions
import com.amazonaws.services.s3.AmazonS3Client
import com.amazonaws.services.s3.model.{GetObjectRequest, S3ObjectInputStream}
import software.amazon.awssdk.auth._
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain
import software.amazon.awssdk.auth.credentials.{EnvironmentVariableCredentialsProvider, ProfileCredentialsProvider}
import software.amazon.awssdk.core.ResponseInputStream
import software.amazon.awssdk.core.sync.ResponseTransformer
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.s3.model.{GetObjectRequest, GetObjectResponse}
import com.gu.anghammarad.common.AnghammaradException.Fail

import scala.io.Source
import scala.util.{Success, Try}


object Config {
val credentialsProvider = new AWSCredentialsProviderChain(
new ProfileCredentialsProvider("deployTools"),
new EnvironmentVariableCredentialsProvider()
val credentialsProvider = AwsCredentialsProviderChain.of(
ProfileCredentialsProvider.create("deployTools"),
EnvironmentVariableCredentialsProvider.create()
)

private val s3Client = AmazonS3Client
private val s3Client = S3Client
.builder
.withCredentials(credentialsProvider)
.withRegion(Regions.EU_WEST_1)
.credentialsProvider(credentialsProvider)
.region(Region.EU_WEST_1)
.build()

private def fetchContent(request: GetObjectRequest): Try[S3ObjectInputStream] = {
Try(s3Client.getObject(request).getObjectContent)
private def fetchContent(request: GetObjectRequest): Try[ResponseInputStream[GetObjectResponse]] = {
Try(s3Client.getObject(request, ResponseTransformer.toInputStream()))
}

private def fetchString(request: GetObjectRequest): Try[String] = {
Expand All @@ -46,7 +49,8 @@ object Config {
val bucket = s"anghammarad-configuration"
val key = s"$stage/anghammarad-config.json"

val request = new GetObjectRequest(bucket, key)
val request = GetObjectRequest.builder().key(key).bucket(bucket).build();

fetchString(request)
}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
package com.gu.anghammarad.messages

import com.amazonaws.regions.Regions
import com.amazonaws.services.simpleemail.model.{Body, Content, Destination, SendEmailRequest, Message => AwsMessage}
import com.amazonaws.services.simpleemail.{AmazonSimpleEmailService, AmazonSimpleEmailServiceClientBuilder}
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.ses.model.{Body, Content, Destination, SendEmailRequest, Message => AwsMessage}
import software.amazon.awssdk.services.ses.SesClient
import com.gu.anghammarad.models.EmailMessage
import com.gu.anghammarad.Config

import scala.util.Try


object EmailService {
val client: AmazonSimpleEmailService = AmazonSimpleEmailServiceClientBuilder.standard().withRegion(Regions.EU_WEST_1)
.withCredentials(Config.credentialsProvider)
val client = SesClient.builder().region(Region.EU_WEST_1)
.credentialsProvider(Config.credentialsProvider)
.build()

def emailRequest(senderAddress: String, recipientAddress: String, message: EmailMessage): SendEmailRequest = {
def buildContent(data: String) = new Content().withCharset("UTF-8").withData(data)
def buildContent(data: String) = Content.builder().charset("UTF-8").data(data).build()

val awsMessage = new AwsMessage()
.withSubject(buildContent(message.subject))
.withBody(new Body()
.withHtml(buildContent(message.html))
.withText(buildContent(message.plainText))
val awsMessage = AwsMessage.builder()
.subject(buildContent(message.subject))
.body(Body.builder()
.html(buildContent(message.html))
.text(buildContent(message.plainText))
.build()
)
.build()

new SendEmailRequest()
.withDestination(new Destination().withToAddresses(recipientAddress))
.withSource(senderAddress)
.withMessage(awsMessage)
SendEmailRequest.builder()
.destination(Destination.builder().toAddresses(recipientAddress).build())
.source(senderAddress)
.message(awsMessage)
.build()
}
def sendEmail(senderAddress: String, recipientAddress: String, message: EmailMessage): Try[Unit] = {
Try(client.sendEmail(emailRequest(senderAddress, recipientAddress, message)))
Expand Down
11 changes: 6 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ inThisBuild(Seq(
licenses := Seq(License.Apache2),
))

val awsSdkVersion = "1.12.777"
val awsSdkVersion = "2.29.15"
val circeVersion = "0.14.10"
val flexmarkVersion = "0.64.8"
val scalaTestVersion = "3.2.19"
Expand Down Expand Up @@ -67,7 +67,8 @@ lazy val client = project
.settings(
name := "anghammarad-client",
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk-sns" % awsSdkVersion,
"software.amazon.awssdk" % "sns" % awsSdkVersion,

"org.json" % "json" % "20240303",
"com.typesafe.scala-logging" %% "scala-logging" % scalaLoggingVersion,
"org.scalatest" %% "scalatest" % scalaTestVersion % Test
Expand All @@ -84,9 +85,9 @@ lazy val anghammarad = project
"org.scala-lang.modules" %% "scala-collection-compat" % "2.12.0",
"com.amazonaws" % "aws-lambda-java-events" % "3.14.0",
"com.amazonaws" % "aws-lambda-java-core" % "1.2.3",
"com.amazonaws" % "aws-java-sdk-lambda" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-ses" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-s3" % awsSdkVersion,
"software.amazon.awssdk" % "lambda" % awsSdkVersion,
"software.amazon.awssdk" % "ses" % awsSdkVersion,
"software.amazon.awssdk" % "s3" % awsSdkVersion,
"io.circe" %% "circe-core" % circeVersion,
"io.circe" %% "circe-generic" % circeVersion,
"io.circe" %% "circe-parser" % circeVersion,
Expand Down
54 changes: 26 additions & 28 deletions client/src/main/scala/com/gu/anghammarad/AWS.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package com.gu.anghammarad

import com.amazonaws.AmazonWebServiceRequest
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.auth.{AWSCredentialsProviderChain, EnvironmentVariableCredentialsProvider, InstanceProfileCredentialsProvider}
import com.amazonaws.handlers.AsyncHandler
import com.amazonaws.regions.Regions
import com.amazonaws.services.sns.{AmazonSNSAsync, AmazonSNSAsyncClientBuilder}

import java.util.concurrent.CompletableFuture

import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.sns.SnsAsyncClient
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain
import software.amazon.awssdk.auth.credentials.{EnvironmentVariableCredentialsProvider, ProfileCredentialsProvider, InstanceProfileCredentialsProvider}



import scala.concurrent.{Future, Promise}

Expand All @@ -14,36 +17,31 @@ object AWS {
/**
* Use this to make an SNS client, or provide your own.
*/
def snsClient(credentialsProvider: AWSCredentialsProviderChain): AmazonSNSAsync = {
AmazonSNSAsyncClientBuilder.standard()
.withRegion(Regions.EU_WEST_1)
.withCredentials(credentialsProvider)
def snsClient(credentialsProvider: AwsCredentialsProviderChain): SnsAsyncClient = {
SnsAsyncClient.builder()
.region(Region.EU_WEST_1)
.credentialsProvider(credentialsProvider)
.build()
}

def credentialsProvider(): AWSCredentialsProviderChain = {
new AWSCredentialsProviderChain(
def credentialsProvider(): AwsCredentialsProviderChain = {
AwsCredentialsProviderChain.of(
// EC2
InstanceProfileCredentialsProvider.getInstance(),
InstanceProfileCredentialsProvider.create(),
// Lambda
new EnvironmentVariableCredentialsProvider(),
EnvironmentVariableCredentialsProvider.create(),
// local
new ProfileCredentialsProvider("deployTools")
ProfileCredentialsProvider.create("deployTools"),
)
}

private class AwsAsyncPromiseHandler[R <: AmazonWebServiceRequest, T](promise: Promise[T]) extends AsyncHandler[R, T] {
def onError(e: Exception): Unit = {
promise failure e
}
def onSuccess(r: R, t: T): Unit = {
promise success t
}
private[anghammarad] def asScala[T](cf: CompletableFuture[T]): Future[T] = {
val p = Promise[T]()
cf.whenCompleteAsync{ (result, ex) =>
if (result == null) p failure ex
else p success result
}
p.future
}

private[anghammarad] def awsToScala[R <: AmazonWebServiceRequest, T](sdkMethod: ( (R, AsyncHandler[R, T]) => java.util.concurrent.Future[T])): (R => Future[T]) = { req =>
val p = Promise[T]()
sdkMethod(req, new AwsAsyncPromiseHandler(p))
p.future
}
}
}
41 changes: 22 additions & 19 deletions client/src/main/scala/com/gu/anghammarad/Anghammarad.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.gu.anghammarad

import com.amazonaws.services.sns.AmazonSNSAsync
import com.amazonaws.services.sns.model.PublishRequest
import com.gu.anghammarad.AWS._
import com.gu.anghammarad.Json._
import com.gu.anghammarad.models._
import software.amazon.awssdk.services.sns.SnsAsyncClient
import software.amazon.awssdk.services.sns.model.PublishRequest

import scala.concurrent.{ExecutionContext, Future}

Expand All @@ -25,13 +25,14 @@ object Anghammarad {
* @return Future containing the resulting SNS Message ID
*/
def notify(subject: String, message: String, actions: List[Action], target: List[Target], channel: RequestedChannel,
sourceSystem: String, topicArn: String, client: AmazonSNSAsync = defaultClient)
sourceSystem: String, topicArn: String, client: SnsAsyncClient = defaultClient)
(implicit executionContext: ExecutionContext): Future[String] = {
val request = new PublishRequest()
.withTopicArn(topicArn)
.withSubject(subject)
.withMessage(messageJson(message, sourceSystem, channel, target, actions))
awsToScala(client.publishAsync)(request).map(_.getMessageId)
val request = PublishRequest.builder()
.topicArn(topicArn)
.subject(subject)
.message(messageJson(message, sourceSystem, channel, target, actions))
.build()
asScala(client.publish(request)).map(_.messageId)
}

/**
Expand All @@ -44,11 +45,12 @@ object Anghammarad {
*/
def notify(notification: Notification, topicArn: String)
(implicit executionContext: ExecutionContext): Future[String] = {
val request = new PublishRequest()
.withTopicArn(topicArn)
.withSubject(notification.subject)
.withMessage(messageJson(notification.message, notification.sourceSystem, notification.channel, notification.target, notification.actions))
awsToScala(defaultClient.publishAsync)(request).map(_.getMessageId)
val request = PublishRequest.builder()
.topicArn(topicArn)
.subject(notification.subject)
.message(messageJson(notification.message, notification.sourceSystem, notification.channel, notification.target, notification.actions))
.build()
asScala(defaultClient.publish(request)).map(_.messageId)
}

/**
Expand All @@ -59,12 +61,13 @@ object Anghammarad {
* @param client The SNS client used to add your notification to the topic
* @return Future containing the resulting SNS Message ID
*/
def notify(notification: Notification, topicArn: String, client: AmazonSNSAsync)
def notify(notification: Notification, topicArn: String, client: SnsAsyncClient)
(implicit executionContext: ExecutionContext): Future[String] = {
val request = new PublishRequest()
.withTopicArn(topicArn)
.withSubject(notification.subject)
.withMessage(messageJson(notification.message, notification.sourceSystem, notification.channel, notification.target, notification.actions))
awsToScala(client.publishAsync)(request).map(_.getMessageId)
val request = PublishRequest.builder()
.topicArn(topicArn)
.subject(notification.subject)
.message(messageJson(notification.message, notification.sourceSystem, notification.channel, notification.target, notification.actions))
.build()
asScala(client.publish(request)).map(_.messageId)
}
}

0 comments on commit 89295fa

Please sign in to comment.