Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WX-495] DRS Parallel Downloads #7214

Merged
merged 42 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
29b374f
add json dependency
THWiseman Jul 20, 2023
2a12ab8
Merge branch 'develop' into WX-495
THWiseman Jul 20, 2023
ddf0328
make it scala
THWiseman Jul 20, 2023
7d6cd8b
lots of scaffolding
THWiseman Jul 21, 2023
201c314
merge
THWiseman Aug 1, 2023
be95d22
forward progress
THWiseman Aug 3, 2023
2519727
stash
THWiseman Aug 4, 2023
81f19a5
more progress
THWiseman Aug 4, 2023
9e6b1f5
json
THWiseman Aug 8, 2023
c874202
Merge branch 'develop' into WX-495
THWiseman Aug 11, 2023
61f248e
stash
THWiseman Aug 14, 2023
125202c
expired tokens
THWiseman Aug 14, 2023
a91ea03
Merge branch 'develop' into WX-495
THWiseman Aug 15, 2023
28e1b63
think about stuff
THWiseman Aug 15, 2023
6e9b698
undo sins
THWiseman Aug 15, 2023
a9a9c38
undo one more sin
THWiseman Aug 15, 2023
0ce068c
Merge branch 'develop' into WX-495
THWiseman Aug 31, 2023
e935cfc
deck chairs
THWiseman Aug 31, 2023
294f0c3
somethin that kinda works
THWiseman Sep 1, 2023
afe84da
working test
THWiseman Sep 5, 2023
0bd1a8d
lots of good cleanup
THWiseman Sep 7, 2023
6a77152
time for tests
THWiseman Sep 8, 2023
00ed9dc
remove conf
THWiseman Sep 8, 2023
510b085
working on tests
THWiseman Sep 8, 2023
8b99bac
test progress
THWiseman Sep 11, 2023
99efe34
much cleaner
THWiseman Sep 11, 2023
776bfac
oops
THWiseman Sep 11, 2023
a05f63e
bunch of tests
THWiseman Sep 12, 2023
c66ee6f
more better tests
THWiseman Sep 12, 2023
2a4ae0d
fix hashing
THWiseman Sep 12, 2023
cb4fb5d
who doesnt love a good implicit actor system
THWiseman Sep 12, 2023
d32613f
low hanging feedback
THWiseman Sep 13, 2023
52c7055
remove unnecessary IO
THWiseman Sep 13, 2023
33670eb
stale comment
THWiseman Sep 13, 2023
d745e2b
teamwork
THWiseman Sep 13, 2023
3e1dec2
cleanup file when done
THWiseman Sep 14, 2023
f6928d7
google download retries
THWiseman Sep 14, 2023
7e3fb8f
remove debugging log
THWiseman Sep 14, 2023
256b8d1
migrate last test
THWiseman Sep 14, 2023
5e876d9
version for bee testing
THWiseman Sep 15, 2023
c236c87
includes, better retries
THWiseman Sep 15, 2023
89f573f
Merge branch 'develop' into WX-495
THWiseman Sep 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package drs.localizer

import cats.data.NonEmptyList
import cats.effect.{ExitCode, IO, IOApp}
import cats.implicits._
import cats.implicits.toTraverseOps
import cloud.nio.impl.drs.DrsPathResolver.{FatalRetryDisposition, RegularRetryDisposition}
import cloud.nio.impl.drs._
import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff}
import com.typesafe.scalalogging.StrictLogging
import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google}
import drs.localizer.downloaders.AccessUrlDownloader.Hashes
import drs.localizer.DrsLocalizerMain.{defaultNumRetries, toValidatedUriType}
import drs.localizer.downloaders._
import org.apache.commons.csv.{CSVFormat, CSVParser}

Expand All @@ -17,7 +17,10 @@ import java.nio.charset.Charset
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.language.postfixOps
import drs.localizer.URIType.URIType

/** A DRS URL that has not been resolved yet, paired with the local path it should be downloaded to. */
case class UnresolvedDrsUrl(drsUrl: String, downloadDestinationPath: String)
/** A DRS resolver response, paired with its local download destination and the type of URI it will be fetched from. */
case class ResolvedDrsUrl(drsResponse: DrsResolverResponse, downloadDestinationPath: String, uriType: URIType)
object DrsLocalizerMain extends IOApp with StrictLogging {

override def run(args: List[String]): IO[ExitCode] = {
Expand All @@ -38,51 +41,102 @@ object DrsLocalizerMain extends IOApp with StrictLogging {

/** Creates a new `CommandLineParser`, exposed as a scopt `OptionParser` over `CommandLineArguments`. */
def buildParser(): scopt.OptionParser[CommandLineArguments] = new CommandLineParser()

// Default retry parameters for resolving a DRS url
val defaultNumRetries: Int = 5
val defaultBackoff: CloudNioBackoff = CloudNioSimpleExponentialBackoff(
initialInterval = 10 seconds, maxInterval = 60 seconds, multiplier = 2)
initialInterval = 1 seconds, maxInterval = 60 seconds, multiplier = 2)

val defaultDownloaderFactory: DownloaderFactory = new DownloaderFactory {
override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] =
IO.pure(AccessUrlDownloader(accessUrl, downloadLoc, hashes))
override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): Downloader =
GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption)

override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] =
IO.pure(GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption))
override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): Downloader = {
BulkAccessUrlDownloader(urlsToDownload)
}
}

/**
 * Prints the command line usage text to stderr and yields an error exit code.
 * The print is suspended inside the returned IO, so nothing is written until the IO is actually run.
 * @return an IO that, when run, prints the usage text and produces ExitCode.Error
 */
private def printUsage: IO[ExitCode] = IO {
  System.err.println(CommandLineParser.Usage)
  ExitCode.Error
}

def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials): IO[ExitCode] = {
commandLineArguments.manifestPath match {
case Some(manifestPath) =>
val manifestFile = new File(manifestPath)
val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT)
val exitCodes: IO[List[ExitCode]] = csvParser.asScala.map(record => {
val drsObject = record.get(0)
val containerPath = record.get(1)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}).toList.sequence
exitCodes.map(_.find(_ != ExitCode.Success).getOrElse(ExitCode.Success))
case None =>
val drsObject = commandLineArguments.drsObject.get
val containerPath = commandLineArguments.containerPath.get
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
/**
 * Helper function to read a CSV file as pairs of drsURL -> local download destination.
 * The file is parsed when the returned IO is run. The IO fails if the file cannot be parsed or if any
 * record has fewer than two columns.
 * @param csvManifestPath Path to a CSV file where each row is something like: drs://asdf.ghj, path/to/my/directory
 * @return one UnresolvedDrsUrl per CSV record (first column is the DRS URL, second the download destination)
 */
def loadCSVManifest(csvManifestPath: String): IO[List[UnresolvedDrsUrl]] = {
  IO {
    val manifestFile = new File(csvManifestPath)
    val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT)
    try {
      // NOTE(review): column values are used exactly as parsed; confirm whether whitespace around the
      // comma (as in the example row above) should be trimmed.
      csvParser.getRecords.asScala.map { record =>
        // Report the offending record instead of letting record.get(1) fail with a bare index error.
        if (record.size() < 2) {
          throw new IllegalArgumentException(
            s"Malformed DRS manifest $csvManifestPath: record ${record.getRecordNumber} has ${record.size()} column(s), " +
              "expected a DRS URL and a download destination")
        }
        UnresolvedDrsUrl(record.get(0), record.get(1))
      }.toList
    } finally {
      // Always release the underlying file handle, even when a record is rejected.
      csvParser.close()
    }
  }
}

private def localizeFile(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, drsCredentials, commandLineArguments.googleRequesterPaysProject).
resolveAndDownloadWithRetries(downloadRetries = 3, checksumRetries = 1, defaultDownloaderFactory, Option(defaultBackoff)).map(_.exitCode)
/**
 * Works out which DRS URLs to localize from the command line arguments, then resolves and downloads them.
 * A manifest path takes precedence; otherwise a single DRS object and container path are used.
 * @param commandLineArguments parsed command line arguments
 * @param drsCredentials credentials handed to the localizer
 * @return the exit code of the download result; the IO fails with a RuntimeException when neither a
 *         manifest nor a DRS object + container path pair was supplied
 */
def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials) : IO[ExitCode] = {
  val urlList: IO[List[UnresolvedDrsUrl]] =
    (commandLineArguments.manifestPath, commandLineArguments.drsObject, commandLineArguments.containerPath) match {
      case (Some(manifestPath), _, _) =>
        loadCSVManifest(manifestPath)
      case (_, Some(drsObject), Some(containerPath)) =>
        IO.pure(List(UnresolvedDrsUrl(drsObject, containerPath)))
      case _ =>
        // Raise inside IO instead of throwing, so the failure surfaces when the effect is run
        // rather than escaping from this method while the IO is still being built.
        IO.raiseError(new RuntimeException("Illegal command line arguments supplied to drs localizer."))
    }
  val main = new DrsLocalizerMain(urlList, defaultDownloaderFactory, drsCredentials, commandLineArguments.googleRequesterPaysProject)
  main.resolveAndDownload().map(_.exitCode)
}

/**
 * Decides which kind of downloader a DRS response calls for.
 * An access url is preferred whenever one is present, even if a gs uri was returned as well.
 * @param accessUrl the access url from the DRS response, if any; must start with https://
 * @param gsUri the google storage uri from the DRS response, if any; must start with gs://
 * @return URIType.ACCESS or URIType.GCS
 * @throws RuntimeException if the chosen url has the wrong prefix, or if neither url is present
 */
def toValidatedUriType(accessUrl: Option[AccessUrl], gsUri: Option[String]): URIType = {
  (accessUrl, gsUri) match {
    case (Some(resolvedAccessUrl), _) =>
      if (resolvedAccessUrl.url.startsWith("https://")) URIType.ACCESS
      else throw new RuntimeException("Resolved Access URL does not start with https://")
    case (_, Some(resolvedGsUri)) =>
      if (resolvedGsUri.startsWith("gs://")) URIType.GCS
      else throw new RuntimeException("Resolved Google URL does not start with gs://")
    case _ =>
      throw new RuntimeException("DRS response did not contain any URLs")
  }
}
}

/**
 * The kinds of URI a resolved DRS object can be tagged with.
 * `DrsLocalizerMain.toValidatedUriType` only ever returns GCS or ACCESS.
 */
object URIType extends Enumeration {
// Alias so callers can write `URIType` as a type after importing `URIType.URIType`.
type URIType = Value
val GCS, ACCESS, UNKNOWN = Value
}

class DrsLocalizerMain(drsUrl: String,
downloadLoc: String,
class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
downloaderFactory: DownloaderFactory,
drsCredentials: DrsCredentials,
requesterPaysProjectIdOption: Option[String]) extends StrictLogging {

/**
 * End-to-end localization:
 *  - builds the downloaders (which resolves every URL first)
 *  - runs each downloader's download in list order
 *  - collapses the individual results into one.
 * @return DownloadSuccess if every download succeeded, otherwise the first result that is not DownloadSuccess.
 */
def resolveAndDownload(): IO[DownloadResult] = {
  for {
    downloaders <- buildDownloaders()
    outcomes <- downloaders.traverse(_.download)
  } yield outcomes.find(_ != DownloadSuccess).getOrElse(DownloadSuccess)
}

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = {
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
Expand All @@ -91,76 +145,86 @@ class DrsLocalizerMain(drsUrl: String,
}
}

def resolveAndDownloadWithRetries(downloadRetries: Int,
checksumRetries: Int,
downloaderFactory: DownloaderFactory,
backoff: Option[CloudNioBackoff],
downloadAttempt: Int = 0,
checksumAttempt: Int = 0): IO[DownloadResult] = {

def maybeRetryForChecksumFailure(t: Throwable): IO[DownloadResult] = {
if (checksumAttempt < checksumRetries) {
backoff foreach { b => Thread.sleep(b.backoffMillis) }
logger.warn(s"Attempting retry $checksumAttempt of $checksumRetries checksum retries to download $drsUrl", t)
// In the event of a checksum failure reset the download attempt to zero.
resolveAndDownloadWithRetries(downloadRetries, checksumRetries, downloaderFactory, backoff map { _.next }, 0, checksumAttempt + 1)
} else {
IO.raiseError(new RuntimeException(s"Exhausted $checksumRetries checksum retries to resolve, download and checksum $drsUrl", t))
}
/**
 * Resolves every URL and groups the results by URI type.
 * All ACCESS urls are handed to a single bulk downloader; every GCS url gets a downloader of its own.
 * Resolved urls of any other type are not given a downloader.
 * @return the bulk downloader (when there is at least one access url) followed by the google downloaders.
 */
def buildDownloaders() : IO[List[Downloader]] = {
  resolveUrls(toResolveAndDownload).map { resolved =>
    val viaAccessUrl = resolved.filter(_.uriType == URIType.ACCESS)
    val viaGcs = resolved.filter(_.uriType == URIType.GCS)
    // Only build a bulk downloader when there is something for it to download.
    val bulk: List[Downloader] = Option.when(viaAccessUrl.nonEmpty)(buildBulkAccessUrlDownloader(viaAccessUrl)).toList
    val perFile: List[Downloader] = if (viaGcs.nonEmpty) buildGoogleDownloaders(viaGcs) else Nil
    bulk ::: perFile
  }
}

def maybeRetryForDownloadFailure(t: Throwable): IO[DownloadResult] = {
t match {
case _: FatalRetryDisposition =>
IO.raiseError(t)
case _ if downloadAttempt < downloadRetries =>
backoff foreach { b => Thread.sleep(b.backoffMillis) }
logger.warn(s"Attempting retry $downloadAttempt of $downloadRetries download retries to download $drsUrl", t)
resolveAndDownloadWithRetries(downloadRetries, checksumRetries, downloaderFactory, backoff map { _.next }, downloadAttempt + 1, checksumAttempt)
case _ =>
IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries download retries to resolve, download and checksum $drsUrl", t))
}
/**
 * Builds one GCS downloader per resolved url via the downloader factory.
 * Each url is expected to carry a gs uri in its DRS response; `.get` throws if it is missing.
 * @param resolvedGoogleUrls resolved urls that should be downloaded from google storage
 * @return one downloader per input url, in the same order
 */
def buildGoogleDownloaders(resolvedGoogleUrls: List[ResolvedDrsUrl]) : List[Downloader] = {
  for (resolved <- resolvedGoogleUrls) yield {
    val response = resolved.drsResponse
    downloaderFactory.buildGcsUriDownloader(
      gcsPath = response.gsUri.get,
      serviceAccountJsonOption = response.googleServiceAccount.map(_.data.spaces2),
      downloadLoc = resolved.downloadDestinationPath,
      requesterPaysProjectOption = requesterPaysProjectIdOption)
  }
}
/** Asks the downloader factory for a single bulk downloader covering all of the given resolved urls. */
def buildBulkAccessUrlDownloader(resolvedUrls: List[ResolvedDrsUrl]) : Downloader =
  downloaderFactory.buildBulkAccessUrlDownloader(resolvedUrls)

resolveAndDownload(downloaderFactory).redeemWith({
maybeRetryForDownloadFailure
},
{
case f: FatalDownloadFailure =>
IO.raiseError(new RuntimeException(s"Fatal error downloading DRS object: $f"))
case r: RetryableDownloadFailure =>
maybeRetryForDownloadFailure(
new RuntimeException(s"Retryable download error: $r for $drsUrl on retry attempt $downloadAttempt of $downloadRetries") with RegularRetryDisposition)
case ChecksumFailure =>
maybeRetryForChecksumFailure(new RuntimeException(s"Checksum failure for $drsUrl on checksum retry attempt $checksumAttempt of $checksumRetries"))
case o => IO.pure(o)
})
/**
 * Resolves one DRS URL with the provided resolver, requesting the gs uri, google service account,
 * access url and hashes fields, and tags the response with its validated URI type.
 * @param resolverObject resolver used to look up the DRS URL
 * @param drsUrlToResolve the DRS URL to resolve together with its download destination
 * @return the resolver response bundled with the download destination and URI type
 */
def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver, drsUrlToResolve: UnresolvedDrsUrl): IO[ResolvedDrsUrl] = {
  val requestedFields = NonEmptyList.of(
    DrsResolverField.GsUri,
    DrsResolverField.GoogleServiceAccount,
    DrsResolverField.AccessUrl,
    DrsResolverField.Hashes)
  resolverObject.resolveDrs(drsUrlToResolve.drsUrl, requestedFields).map { response =>
    val uriType = toValidatedUriType(response.accessUrl, response.gsUri)
    ResolvedDrsUrl(response, drsUrlToResolve.downloadDestinationPath, uriType)
  }
}

private [localizer] def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = {
resolve(downloaderFactory) flatMap { _.download }

val defaultBackoff: CloudNioBackoff = CloudNioSimpleExponentialBackoff(
initialInterval = 10 seconds, maxInterval = 60 seconds, multiplier = 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems aggressive (or, not aggressive enough?) - maybe start with 1 second interval? Is there a max number of retries set elsewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what we were using before, but happy to tone it down a bit. I will also put the max retry count variable here (I went with 5) so everything is in the same place.


/**
 * Resolves every DRS url in the list, one after another, using a single resolver instance.
 * Each url is resolved with the default retry count and backoff.
 * @param unresolvedUrls effect producing the urls to resolve
 * @return the resolved urls, in the same order as the input
 */
def resolveUrls(unresolvedUrls: IO[List[UnresolvedDrsUrl]]): IO[List[ResolvedDrsUrl]] = {
  for {
    pending <- unresolvedUrls
    resolver <- getDrsPathResolver
    resolved <- pending.traverse(url => resolveWithRetries(resolver, url, defaultNumRetries, Option(defaultBackoff)))
  } yield resolved
}

private [localizer] def resolve(downloaderFactory: DownloaderFactory): IO[Downloader] = {
val fields = NonEmptyList.of(DrsResolverField.GsUri, DrsResolverField.GoogleServiceAccount, DrsResolverField.AccessUrl, DrsResolverField.Hashes)
for {
resolver <- getDrsPathResolver
drsResolverResponse <- resolver.resolveDrs(drsUrl, fields)

// Currently DRS Resolver only supports resolving DRS paths to access URLs or GCS paths.
downloader <- (drsResolverResponse.accessUrl, drsResolverResponse.gsUri) match {
case (Some(accessUrl), _) =>
downloaderFactory.buildAccessUrlDownloader(accessUrl, downloadLoc, drsResolverResponse.hashes)
case (_, Some(gcsPath)) =>
val serviceAccountJsonOption = drsResolverResponse.googleServiceAccount.map(_.data.spaces2)
downloaderFactory.buildGcsUriDownloader(
gcsPath = gcsPath,
serviceAccountJsonOption = serviceAccountJsonOption,
downloadLoc = downloadLoc,
requesterPaysProjectOption = requesterPaysProjectIdOption)
case _ =>
IO.raiseError(new RuntimeException(DrsPathResolver.ExtractUriErrorMsg))
/**
 * Resolves a single DRS URL, trying again when resolution fails.
 * @param resolverObject resolver used to perform the resolution
 * @param drsUrlToResolve the URL to resolve together with its download destination
 * @param resolutionRetries maximum number of retries after the first attempt
 * @param backoff optional backoff; when present, the current thread sleeps for its interval before each retry
 * @param resolutionAttempt number of retries already made (0 on the first call)
 * @return the resolved URL, or a failed IO wrapping the last error once the retries are exhausted
 */
def resolveWithRetries(resolverObject: DrsLocalizerDrsPathResolver,
                       drsUrlToResolve: UnresolvedDrsUrl,
                       resolutionRetries: Int,
                       backoff: Option[CloudNioBackoff],
                       resolutionAttempt: Int = 0) : IO[ResolvedDrsUrl] = {

  // Invoked when resolution raised an error: either backs off and retries, or gives up.
  def maybeRetryForResolutionFailure(t: Throwable): IO[ResolvedDrsUrl] = {
    if (resolutionAttempt < resolutionRetries) {
      backoff foreach { b => Thread.sleep(b.backoffMillis) }
      logger.warn(s"Attempting retry $resolutionAttempt of $resolutionRetries drs resolution retries to resolve ${drsUrlToResolve.drsUrl}", t)
      resolveWithRetries(resolverObject, drsUrlToResolve, resolutionRetries, backoff map { _.next }, resolutionAttempt+1)
    } else {
      // Braces are required: a bare `$drsUrlToResolve.drsUrl` would interpolate the whole case class
      // followed by a literal ".drsUrl" instead of just the URL.
      IO.raiseError(new RuntimeException(s"Exhausted $resolutionRetries resolution retries to resolve ${drsUrlToResolve.drsUrl}", t))
    }
  }

  resolveSingleUrl(resolverObject, drsUrlToResolve).redeemWith(
    recover = maybeRetryForResolutionFailure,
    bind = {
      // NOTE(review): these two cases match on the successfully resolved value, not on a Throwable,
      // and the retry below does not check resolutionRetries. Confirm whether the disposition checks
      // were meant to live in the recover path instead.
      case f: FatalRetryDisposition =>
        IO.raiseError(new RuntimeException(s"Fatal error resolving DRS URL: $f"))
      case _: RegularRetryDisposition =>
        resolveWithRetries(resolverObject, drsUrlToResolve, resolutionRetries, backoff, resolutionAttempt+1)
      case o => IO.pure(o)
    })
}
}

This file was deleted.

Loading