diff --git a/src/main/scala/zio/s3/package.scala b/src/main/scala/zio/s3/package.scala index 3fe63d52..6ffbed2d 100644 --- a/src/main/scala/zio/s3/package.scala +++ b/src/main/scala/zio/s3/package.scala @@ -22,7 +22,7 @@ import software.amazon.awssdk.services.s3.S3AsyncClient import software.amazon.awssdk.services.s3.model.S3Exception import zio.nio.file.{ Path => ZPath } import zio.s3.S3Bucket.S3BucketListing -import zio.s3.providers.const +import zio.s3.providers.{ basic, const } import zio.stream.ZStream import java.net.URI @@ -32,7 +32,7 @@ package object s3 { type S3Stream[A] = ZStream[S3, S3Exception, A] def live(region: Region, credentials: AwsCredentials, uriEndpoint: Option[URI] = None): Layer[S3Exception, S3] = - liveZIO(region, const(credentials.accessKeyId, credentials.secretAccessKey), uriEndpoint) + liveZIO(region, const(credentials), uriEndpoint) def liveZIO[R]( region: Region, @@ -50,7 +50,7 @@ package object s3 { val live: ZLayer[S3Settings, ConnectionError, S3] = ZLayer.scoped( ZIO.serviceWithZIO[S3Settings](s => - Live.connect(s.s3Region, const(s.credentials.accessKeyId, s.credentials.secretAccessKey), None) + Live.connect(s.s3Region, basic(s.credentials.accessKeyId, s.credentials.secretAccessKey), None) ) ) diff --git a/src/main/scala/zio/s3/providers.scala b/src/main/scala/zio/s3/providers.scala index 0cad41be..e07a83ce 100644 --- a/src/main/scala/zio/s3/providers.scala +++ b/src/main/scala/zio/s3/providers.scala @@ -5,8 +5,14 @@ import zio.{ IO, Scope, UIO, ZIO } object providers { - def const(accessKeyId: String, secretAccessKey: String): UIO[AwsCredentialsProvider] = - ZIO.succeedNow[AwsCredentialsProvider](() => AwsBasicCredentials.create(accessKeyId, secretAccessKey)) + def const(credential: AwsCredentials): UIO[AwsCredentialsProvider] = + ZIO.succeedNow[AwsCredentialsProvider](() => credential) + + def basic(accessKeyId: String, secretAccessKey: String): UIO[AwsCredentialsProvider] = + const(AwsBasicCredentials.create(accessKeyId, secretAccessKey)) + + def session(accessKeyId: String, secretAccessKey: String, sessionToken: String): UIO[AwsCredentialsProvider] = + const(AwsSessionCredentials.create(accessKeyId, secretAccessKey, sessionToken)) val system: IO[InvalidCredentials, SystemPropertyCredentialsProvider] = ZIO diff --git a/src/test/scala/zio/s3/S3ProvidersTest.scala b/src/test/scala/zio/s3/S3ProvidersTest.scala index 9109cf4f..7aa92bcb 100644 --- a/src/test/scala/zio/s3/S3ProvidersTest.scala +++ b/src/test/scala/zio/s3/S3ProvidersTest.scala @@ -1,6 +1,6 @@ package zio.s3 -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials +import software.amazon.awssdk.auth.credentials.{ AwsBasicCredentials, AwsSessionCredentials } import software.amazon.awssdk.regions.Region import zio.s3.providers._ import zio.test.Assertion._ @@ -25,14 +25,19 @@ object S3ProvidersTest extends ZIOSpecDefault { def spec: Spec[TestEnvironment with Scope, Any] = suite("Providers")( - test("cred with const") { + test("basic credentials") { ZIO - .scoped(const("k", "v").map(_.resolveCredentials())) + .scoped(basic("k", "v").map(_.resolveCredentials())) .map(res => assertTrue(res == AwsBasicCredentials.create("k", "v"))) }, - test("cred with default fallback const") { + test("session credentials") { ZIO - .scoped((env <> const("k", "v")).map(_.resolveCredentials())) + .scoped(session("k", "v", "t").map(_.resolveCredentials())) + .map(res => assertTrue(res == AwsSessionCredentials.create("k", "v", "t"))) + }, + test("basic credentials default fallback const") { + ZIO + .scoped((env <> basic("k", "v")).map(_.resolveCredentials())) .map(res => assertTrue(res == AwsBasicCredentials.create("k", "v"))) }, test("cred in system properties") {