Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WX-1078 ACR support #7192

Merged
merged 13 commits into from
Aug 9, 2023
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ lazy val databaseMigration = (project in file("database/migration"))

lazy val dockerHashing = project
.withLibrarySettings("cromwell-docker-hashing", dockerHashingDependencies)
.dependsOn(cloudSupport)
.dependsOn(core)
.dependsOn(core % "test->test")
.dependsOn(common % "test->test")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cromwell.filesystems.blob
package cromwell.cloudsupport.azure

import cats.implicits.catsSyntaxValidatedId
import com.azure.core.credential.TokenRequestContext
Expand All @@ -9,7 +9,6 @@ import common.validation.ErrorOr.ErrorOr

import scala.concurrent.duration._
import scala.jdk.DurationConverters._

import scala.util.{Failure, Success, Try}

/**
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,15 @@ docker {
max-retries = 3

// Supported registries (Docker Hub, Google, Quay) can have additional configuration set separately
azure {
// Worst case `ReadOps per minute` value from official docs
// https://github.com/MicrosoftDocs/azure-docs/blob/main/includes/container-registry-limits.md
throttle {
number-of-requests = 1000
per = 60 seconds
}
num-threads = 10
}
google {
// Example of how to configure throttling, available for all supported registries
throttle {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package cromwell.docker

import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry

import scala.util.{Failure, Success, Try}

sealed trait DockerImageIdentifier {
Expand All @@ -14,7 +16,14 @@ sealed trait DockerImageIdentifier {
lazy val name = repository map { r => s"$r/$image" } getOrElse image
// The name of the image with a repository prefix if a repository was specified, or with a default repository prefix of
// "library" if no repository was specified.
lazy val nameWithDefaultRepository = repository.getOrElse("library") + s"/$image"
lazy val nameWithDefaultRepository = {
// In ACR, the repository is part of the registry domain instead of the path
// e.g. `terrabatchdev.azurecr.io`
if (host.exists(_.contains(AzureContainerRegistry.domain)))
image
else
repository.getOrElse("library") + s"/$image"
}
lazy val hostAsString = host map { h => s"$h/" } getOrElse ""
// The full name of this image, including a repository prefix only if a repository was explicitly specified.
lazy val fullName = s"$hostAsString$name:$reference"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import cromwell.core.actor.StreamIntegration.{BackPressure, StreamContext}
import cromwell.core.{Dispatcher, DockerConfiguration}
import cromwell.docker.DockerInfoActor._
import cromwell.docker.registryv2.DockerRegistryV2Abstract
import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry
import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry
import cromwell.docker.registryv2.flows.google.GoogleRegistry
import cromwell.docker.registryv2.flows.quay.QuayRegistry
Expand Down Expand Up @@ -232,6 +233,7 @@ object DockerInfoActor {

// To add a new registry, simply add it to that list
List(
("azure", { c: DockerRegistryConfig => new AzureContainerRegistry(c) }),
("dockerhub", { c: DockerRegistryConfig => new DockerHubRegistry(c) }),
("google", { c: DockerRegistryConfig => new GoogleRegistry(c) }),
("quay", { c: DockerRegistryConfig => new QuayRegistry(c) })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi
}

// Execute a request. No retries because they're expected to already be handled by the client
private def executeRequest[A](request: IO[Request[IO]], handler: Response[IO] => IO[A])(implicit client: Client[IO]): IO[A] = {
protected def executeRequest[A](request: IO[Request[IO]], handler: Response[IO] => IO[A])(implicit client: Client[IO]): IO[A] = {
request.flatMap(client.run(_).use[IO, A](handler))
}

Expand Down Expand Up @@ -188,7 +188,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi
/**
* Builds the token request
*/
private def buildTokenRequest(dockerInfoContext: DockerInfoContext): IO[Request[IO]] = {
protected def buildTokenRequest(dockerInfoContext: DockerInfoContext): IO[Request[IO]] = {
val request = Method.GET(
buildTokenRequestUri(dockerInfoContext.dockerImageID),
buildTokenRequestHeaders(dockerInfoContext): _*
Expand Down Expand Up @@ -220,7 +220,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi
* Request to get the manifest, using the auth token if provided
*/
private def manifestRequest(token: Option[String], imageId: DockerImageIdentifier, manifestHeader: Accept): IO[Request[IO]] = {
val authorizationHeader = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t)))
val authorizationHeader: Option[Authorization] = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t)))
val request = Method.GET(
buildManifestUri(imageId),
List(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package cromwell.docker.registryv2.flows.azure

case class AcrAccessToken(access_token: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package cromwell.docker.registryv2.flows.azure

case class AcrRefreshToken(refresh_token: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package cromwell.docker.registryv2.flows.azure

import cats.data.Validated.{Invalid, Valid}
import cats.effect.IO
import com.typesafe.scalalogging.LazyLogging
import common.validation.ErrorOr.ErrorOr
import cromwell.cloudsupport.azure.AzureCredentials
import cromwell.docker.DockerInfoActor.DockerInfoContext
import cromwell.docker.{DockerImageIdentifier, DockerRegistryConfig}
import cromwell.docker.registryv2.DockerRegistryV2Abstract
import org.http4s.{Header, Request, Response, Status}
import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry.domain
import org.http4s.circe.jsonOf
import org.http4s.client.Client
import io.circe.generic.auto._
import org.http4s._


class AzureContainerRegistry(config: DockerRegistryConfig) extends DockerRegistryV2Abstract(config) with LazyLogging {

/**
* (e.g registry-1.docker.io)
*/
override protected def registryHostName(dockerImageIdentifier: DockerImageIdentifier): String =
dockerImageIdentifier.host.getOrElse("")

override def accepts(dockerImageIdentifier: DockerImageIdentifier): Boolean =
dockerImageIdentifier.hostAsString.contains(domain)

override protected def authorizationServerHostName(dockerImageIdentifier: DockerImageIdentifier): String =
dockerImageIdentifier.host.getOrElse("")

/**
* In Azure, service name does not exist at the registry level, it varies per repo, e.g. `terrabatchdev.azurecr.io`
*/
override def serviceName: Option[String] =
throw new Exception("ACR service name is host of user-defined registry, must derive from `DockerImageIdentifier`")

/**
* Builds the list of headers for the token request
*/
override protected def buildTokenRequestHeaders(dockerInfoContext: DockerInfoContext): List[Header] = {
List(contentTypeHeader)
}

private val contentTypeHeader: Header = {
import org.http4s.headers.`Content-Type`
import org.http4s.MediaType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for having these imports embedded in these functions rather than at the top?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personal preference really, they are extremely specific and only used in one place.


`Content-Type`(MediaType.application.`x-www-form-urlencoded`)
}

private def getRefreshToken(authServerHostname: String, defaultAccessToken: String): IO[Request[IO]] = {
import org.http4s.Uri.{Authority, Scheme}
import org.http4s.client.dsl.io._
import org.http4s._

val uri = Uri.apply(
scheme = Option(Scheme.https),
authority = Option(Authority(host = Uri.RegName(authServerHostname))),
path = "/oauth2/exchange",
query = Query.empty
)

org.http4s.Method.POST(
UrlForm(
"service" -> authServerHostname,
"access_token" -> defaultAccessToken,
"grant_type" -> "access_token"
),
uri,
List(contentTypeHeader): _*
)
}

/*
Unlike other repositories, Azure reserves `GET /oauth2/token` for Basic Authentication [0]
In order to use Oauth we must `POST /oauth2/token` [1]

[0] https://github.com/Azure/acr/blob/main/docs/Token-BasicAuth.md#using-the-token-api
[1] https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#calling-post-oauth2token-to-get-an-acr-access-token
*/
private def getDockerAccessToken(hostname: String, repository: String, refreshToken: String): IO[Request[IO]] = {
import org.http4s.Uri.{Authority, Scheme}
import org.http4s.client.dsl.io._
import org.http4s._

val uri = Uri.apply(
scheme = Option(Scheme.https),
authority = Option(Authority(host = Uri.RegName(hostname))),
path = "/oauth2/token",
query = Query.empty
)

org.http4s.Method.POST(
UrlForm(
// Tricky behavior - invalid `repository` values return a 200 with a valid-looking token.
// However, the token will cause 401s on all subsequent requests.
"scope" -> s"repository:$repository:pull",
"service" -> hostname,
"refresh_token" -> refreshToken,
"grant_type" -> "refresh_token"
),
uri,
List(contentTypeHeader): _* // http4s adds `Content-Length` which ACR does not like (400 response)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this syntax prevent the Content-Length header from being added?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, this is a comment based on old info. Removing.

)
}

override protected def getToken(dockerInfoContext: DockerInfoContext)(implicit client: Client[IO]): IO[Option[String]] = {
val hostname = authorizationServerHostName(dockerInfoContext.dockerImageID)
val maybeAadAccessToken: ErrorOr[String] = AzureCredentials.getAccessToken(None) // AAD token suitable for get-refresh-token request
val repository = dockerInfoContext.dockerImageID.image // ACR uses what we think of image name, as the repository

// Top-level flow: AAD access token -> refresh token -> ACR access token
maybeAadAccessToken match {
case Valid(accessToken) =>
(for {
refreshToken <- executeRequest(getRefreshToken(hostname, accessToken), parseRefreshToken)
dockerToken <- executeRequest(getDockerAccessToken(hostname, repository, refreshToken), parseAccessToken)
} yield dockerToken).map(Option.apply)
case Invalid(errors) =>
IO.raiseError(
new Exception(s"Could not obtain AAD token to exchange for ACR refresh token. Error(s): ${errors}")
)
}
}

implicit val refreshTokenDecoder: EntityDecoder[IO, AcrRefreshToken] = jsonOf[IO, AcrRefreshToken]
implicit val accessTokenDecoder: EntityDecoder[IO, AcrAccessToken] = jsonOf[IO, AcrAccessToken]

private def parseRefreshToken(response: Response[IO]): IO[String] = response match {
case Status.Successful(r) => r.as[AcrRefreshToken].map(_.refresh_token)
case r =>
r.as[String].flatMap(b => IO.raiseError(new Exception(s"Request failed with status ${r.status.code} and body $b")))
}

private def parseAccessToken(response: Response[IO]): IO[String] = response match {
case Status.Successful(r) => r.as[AcrAccessToken].map(_.access_token)
case r =>
r.as[String].flatMap(b => IO.raiseError(new Exception(s"Request failed with status ${r.status.code} and body $b")))
}

}

object AzureContainerRegistry {

def domain: String = "azurecr.io"

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class DockerImageIdentifierSpec extends AnyFlatSpec with CromwellTimeoutSpec wit
("broad/cromwell/submarine", None, Option("broad/cromwell"), "submarine", "latest"),
("gcr.io/google/slim", Option("gcr.io"), Option("google"), "slim", "latest"),
("us-central1-docker.pkg.dev/google/slim", Option("us-central1-docker.pkg.dev"), Option("google"), "slim", "latest"),
("terrabatchdev.azurecr.io/postgres", Option("terrabatchdev.azurecr.io"), None, "postgres", "latest"),
// With tags
("ubuntu:latest", None, None, "ubuntu", "latest"),
("ubuntu:1235-SNAP", None, None, "ubuntu", "1235-SNAP"),
("ubuntu:V3.8-5_1", None, None, "ubuntu", "V3.8-5_1"),
("index.docker.io:9999/ubuntu:170904", Option("index.docker.io:9999"), None, "ubuntu", "170904"),
("localhost:5000/capture/transwf:170904", Option("localhost:5000"), Option("capture"), "transwf", "170904"),
("quay.io/biocontainers/platypus-variant:0.8.1.1--htslib1.5_0", Option("quay.io"), Option("biocontainers"), "platypus-variant", "0.8.1.1--htslib1.5_0"),
("terrabatchdev.azurecr.io/postgres:latest", Option("terrabatchdev.azurecr.io"), None, "postgres", "latest"),
// Very long tags with trailing spaces cause problems for the re engine
("someuser/someimage:supercalifragilisticexpialidociouseventhoughthesoundofitissomethingquiteatrociousifyousayitloudenoughyoullalwayssoundprecocious ", None, Some("someuser"), "someimage", "supercalifragilisticexpialidociouseventhoughthesoundofitissomethingquiteatrociousifyousayitloudenoughyoullalwayssoundprecocious")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cromwell.docker

import cromwell.core.Tags.IntegrationTest
import cromwell.docker.DockerInfoActor._
import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry
import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry
import cromwell.docker.registryv2.flows.google.GoogleRegistry
import cromwell.docker.registryv2.flows.quay.QuayRegistry
Expand All @@ -18,7 +19,8 @@ class DockerInfoActorSpec extends DockerRegistrySpec with AnyFlatSpecLike with M
override protected lazy val registryFlows = List(
new DockerHubRegistry(DockerRegistryConfig.default),
new GoogleRegistry(DockerRegistryConfig.default),
new QuayRegistry(DockerRegistryConfig.default)
new QuayRegistry(DockerRegistryConfig.default),
new AzureContainerRegistry(DockerRegistryConfig.default)
)

it should "retrieve a public docker hash" taggedAs IntegrationTest in {
Expand Down Expand Up @@ -50,6 +52,16 @@ class DockerInfoActorSpec extends DockerRegistrySpec with AnyFlatSpecLike with M
hash should not be empty
}
}

it should "retrieve a private docker hash on acr" taggedAs IntegrationTest in {
dockerActor ! makeRequest("terrabatchdev.azurecr.io/postgres:latest")

expectMsgPF(15 second) {
case DockerInfoSuccessResponse(DockerInformation(DockerHashResult(alg, hash), _), _) =>
alg shouldBe "sha256"
hash should not be empty
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh... doesn't this depend on auth? How is this passing in GHA?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests tagged IntegrationTest don't run on GHA, only locally. Kind of strange, but it definitely comes in handy iterating locally. Presumably in the future we'll have an environment suitable for them, or maybe a more specific AzureIntegrationTest tag.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL!

}
}

it should "send image not found message back if the image does not exist" taggedAs IntegrationTest in {
val notFound = makeRequest("ubuntu:nonexistingtag")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSig
import com.typesafe.config.Config
import com.typesafe.scalalogging.LazyLogging
import common.validation.Validation._
import cromwell.cloudsupport.azure.AzureUtils
import cromwell.cloudsupport.azure.{AzureCredentials, AzureUtils}

import java.net.URI
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems}
Expand Down