Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WX-495] DRS Parallel Downloads #7214

Merged
merged 42 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
29b374f
add json dependency
THWiseman Jul 20, 2023
2a12ab8
Merge branch 'develop' into WX-495
THWiseman Jul 20, 2023
ddf0328
make it scala
THWiseman Jul 20, 2023
7d6cd8b
lots of scaffolding
THWiseman Jul 21, 2023
201c314
merge
THWiseman Aug 1, 2023
be95d22
forward progress
THWiseman Aug 3, 2023
2519727
stash
THWiseman Aug 4, 2023
81f19a5
more progress
THWiseman Aug 4, 2023
9e6b1f5
json
THWiseman Aug 8, 2023
c874202
Merge branch 'develop' into WX-495
THWiseman Aug 11, 2023
61f248e
stash
THWiseman Aug 14, 2023
125202c
expired tokens
THWiseman Aug 14, 2023
a91ea03
Merge branch 'develop' into WX-495
THWiseman Aug 15, 2023
28e1b63
think about stuff
THWiseman Aug 15, 2023
6e9b698
undo sins
THWiseman Aug 15, 2023
a9a9c38
undo one more sin
THWiseman Aug 15, 2023
0ce068c
Merge branch 'develop' into WX-495
THWiseman Aug 31, 2023
e935cfc
deck chairs
THWiseman Aug 31, 2023
294f0c3
somethin that kinda works
THWiseman Sep 1, 2023
afe84da
working test
THWiseman Sep 5, 2023
0bd1a8d
lots of good cleanup
THWiseman Sep 7, 2023
6a77152
time for tests
THWiseman Sep 8, 2023
00ed9dc
remove conf
THWiseman Sep 8, 2023
510b085
working on tests
THWiseman Sep 8, 2023
8b99bac
test progress
THWiseman Sep 11, 2023
99efe34
much cleaner
THWiseman Sep 11, 2023
776bfac
oops
THWiseman Sep 11, 2023
a05f63e
bunch of tests
THWiseman Sep 12, 2023
c66ee6f
more better tests
THWiseman Sep 12, 2023
2a4ae0d
fix hashing
THWiseman Sep 12, 2023
cb4fb5d
who doesnt love a good implicit actor system
THWiseman Sep 12, 2023
d32613f
low hanging feedback
THWiseman Sep 13, 2023
52c7055
remove unnecessary IO
THWiseman Sep 13, 2023
33670eb
stale comment
THWiseman Sep 13, 2023
d745e2b
teamwork
THWiseman Sep 13, 2023
3e1dec2
cleanup file when done
THWiseman Sep 14, 2023
f6928d7
google download retries
THWiseman Sep 14, 2023
7e3fb8f
remove debugging log
THWiseman Sep 14, 2023
256b8d1
migrate last test
THWiseman Sep 14, 2023
5e876d9
version for bee testing
THWiseman Sep 15, 2023
c236c87
includes, better retries
THWiseman Sep 15, 2023
89f573f
Merge branch 'develop' into WX-495
THWiseman Sep 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package drs.localizer

import cats.data.NonEmptyList
import cats.effect.{ExitCode, IO, IOApp}
import cats.implicits._
import cloud.nio.impl.drs.DrsPathResolver.{FatalRetryDisposition, RegularRetryDisposition}
import cats.implicits.toTraverseOps
import cloud.nio.impl.drs._
import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff}
import com.typesafe.scalalogging.StrictLogging
import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google}
import drs.localizer.downloaders.AccessUrlDownloader.Hashes
import drs.localizer.DrsLocalizerMain.toValidatedUriType
import drs.localizer.downloaders._
import org.apache.commons.csv.{CSVFormat, CSVParser}

Expand All @@ -17,7 +16,9 @@ import java.nio.charset.Charset
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.language.postfixOps

import drs.localizer.URIType.URIType
/** A DRS URL that has not been resolved yet, paired with the path its contents should be downloaded to. */
case class UnresolvedDrsUrl(drsUrl: String, downloadDestinationPath: String)
/**
  * The outcome of resolving a DRS URL: the resolver's response, the requested download destination,
  * and the kind of URI (see URIType) that was selected from that response.
  */
case class ResolvedDrsUrl(drsResponse: DrsResolverResponse, downloadDestinationPath: String, uriType: URIType)
object DrsLocalizerMain extends IOApp with StrictLogging {

override def run(args: List[String]): IO[ExitCode] = {
Expand All @@ -42,47 +43,91 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
initialInterval = 10 seconds, maxInterval = 60 seconds, multiplier = 2)

val defaultDownloaderFactory: DownloaderFactory = new DownloaderFactory {
override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] =
IO.pure(AccessUrlDownloader(accessUrl, downloadLoc, hashes))

override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] =
IO.pure(GcsUriDownloader(gcsPath, serviceAccountJsonOption, downloadLoc, requesterPaysProjectOption))

override def buildBulkAccessUrlDownloader(urlsToDownload: List[ResolvedDrsUrl]): IO[Downloader] = {
IO.pure(BulkAccessUrlDownloader(urlsToDownload))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this an IO? Presumably to make something up-stack happier, since this is just instantiating a case class? I would move the IO.pure wrapper up to the place it's needed (same for GcsUriDownloader).

}
}

/**
  * Prints the command line usage text to stderr and yields the error exit code.
  *
  * The print is suspended inside the IO so that it happens when the effect is run,
  * rather than as a side effect of merely calling this method.
  */
private def printUsage: IO[ExitCode] =
  IO {
    System.err.println(CommandLineParser.Usage)
    ExitCode.Error
  }

def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials): IO[ExitCode] = {
commandLineArguments.manifestPath match {
/**
  * Helper function to read a CSV manifest as a list of DRS URLs and their requested download destinations.
  *
  * Column 0 of each row is used as the DRS URL and column 1 as the download destination; a row with
  * fewer than two columns makes the returned IO fail.
  *
  * @param csvManifestPath Path to a CSV file where each row is something like: drs://asdf.ghj, path/to/my/directory
  * @return one UnresolvedDrsUrl per row, in file order
  */
def loadCSVManifest(csvManifestPath: String): IO[List[UnresolvedDrsUrl]] = {
  IO {
    val csvParser = CSVParser.parse(new File(csvManifestPath), Charset.defaultCharset(), CSVFormat.DEFAULT)
    // getRecords reads every row eagerly and toList materializes the result, so the parser
    // (and the file handle it owns) can be released before the list is returned.
    try {
      csvParser.getRecords.asScala.map(record => UnresolvedDrsUrl(record.get(0), record.get(1))).toList
    } finally {
      csvParser.close()
    }
  }
}


def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials) : IO[ExitCode] = {
val urlList : IO[List[UnresolvedDrsUrl]] = commandLineArguments.manifestPath match {
case Some(manifestPath) =>
val manifestFile = new File(manifestPath)
val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT)
val exitCodes: IO[List[ExitCode]] = csvParser.asScala.map(record => {
val drsObject = record.get(0)
val containerPath = record.get(1)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}).toList.sequence
exitCodes.map(_.find(_ != ExitCode.Success).getOrElse(ExitCode.Success))
loadCSVManifest(manifestPath)
case None =>
val drsObject = commandLineArguments.drsObject.get
val containerPath = commandLineArguments.containerPath.get
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
IO.pure(List(UnresolvedDrsUrl(commandLineArguments.drsObject.get, commandLineArguments.containerPath.get)))
}
IO{
val main = new DrsLocalizerMain(urlList, defaultDownloaderFactory, drsCredentials, commandLineArguments.googleRequesterPaysProject)
main.resolveAndDownload().unsafeRunSync().exitCode
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little suspicious of an unsafeRunSync inside an IO block. Can you instead do something like

main.resolveAndDownload().map(_.exitCode)?

}
}

private def localizeFile(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, drsCredentials, commandLineArguments.googleRequesterPaysProject).
resolveAndDownloadWithRetries(downloadRetries = 3, checksumRetries = 1, defaultDownloaderFactory, Option(defaultBackoff)).map(_.exitCode)
/**
  * Picks the kind of downloader to use from the URLs present in a DRS response.
  * An access URL wins over a gs:// URI when both are present.
  *
  * @param accessUrl access URL from the DRS response, if any; must start with https://
  * @param gsUri     Google Cloud Storage URI from the DRS response, if any; must start with gs://
  * @return URIType.ACCESS or URIType.GCS
  * @throws RuntimeException if the selected URL has the wrong scheme, or if neither URL is present
  */
def toValidatedUriType(accessUrl: Option[AccessUrl], gsUri: Option[String]): URIType =
  accessUrl match {
    case Some(access) =>
      if (access.url.startsWith("https://")) URIType.ACCESS
      else throw new RuntimeException("Resolved Access URL does not start with https://")
    case _ =>
      // No access URL: fall back to the Google URI.
      gsUri match {
        case Some(googleUri) =>
          if (googleUri.startsWith("gs://")) URIType.GCS
          else throw new RuntimeException("Resolved Google URL does not start with gs://")
        case _ =>
          throw new RuntimeException("DRS response did not contain any URLs")
      }
  }
}

/** The kinds of URI a resolved DRS object can be downloaded from. */
object URIType extends Enumeration {
  type URIType = Value

  // Declaration order fixes the ids: GCS = 0, ACCESS = 1, UNKNOWN = 2.
  val GCS: Value = Value
  val ACCESS: Value = Value
  val UNKNOWN: Value = Value
}

class DrsLocalizerMain(drsUrl: String,
downloadLoc: String,
class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
downloaderFactory: DownloaderFactory,
drsCredentials: DrsCredentials,
requesterPaysProjectIdOption: Option[String]) extends StrictLogging {

/**
  * This will:
  *   - resolve all URLs
  *   - build downloader(s) for them
  *   - invoke the downloaders, one after another, to localize the files.
  *
  * The steps are composed as a single IO instead of calling unsafeRunSync inside an IO block,
  * so no thread is blocked on a nested run.
  *
  * @return DownloadSuccess if every download succeeds, otherwise the first result that is not
  *         DownloadSuccess. A failure raised while resolving or downloading fails the returned IO.
  */
def resolveAndDownload(): IO[DownloadResult] =
  for {
    downloaders <- buildDownloaders()
    results <- downloaders.traverse(_.download)
  } yield results.find(_ != DownloadSuccess).getOrElse(DownloadSuccess)

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = {
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
Expand All @@ -91,6 +136,63 @@ class DrsLocalizerMain(drsUrl: String,
}
}

/**
  * Resolves the provided DRS URL with the provided resolver and classifies the response.
  *
  * The resolver's IO is mapped over directly rather than run with unsafeRunSync inside another IO.
  * An invalid response (see toValidatedUriType) fails the returned IO.
  *
  * @param resolverObject  resolver used to look up the DRS URL
  * @param drsUrlToResolve the DRS URL and the destination its contents should be downloaded to
  * @return the resolver response together with the download destination and the chosen URI type
  */
def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver, drsUrlToResolve: UnresolvedDrsUrl): IO[ResolvedDrsUrl] = {
  val fields = NonEmptyList.of(DrsResolverField.GsUri, DrsResolverField.GoogleServiceAccount, DrsResolverField.AccessUrl, DrsResolverField.Hashes)
  // TODO: insert retry logic here.
  resolverObject.resolveDrs(drsUrlToResolve.drsUrl, fields).map { drsResponse =>
    ResolvedDrsUrl(drsResponse, drsUrlToResolve.downloadDestinationPath, toValidatedUriType(drsResponse.accessUrl, drsResponse.gsUri))
  }
}

/**
  * Resolves every DRS URL in the provided list, one after another, using a single resolver.
  *
  * @param unresolvedUrls effect producing the DRS URLs to resolve
  * @return the resolved URLs, in the same order as the input
  */
def resolveUrls(unresolvedUrls: IO[List[UnresolvedDrsUrl]]): IO[List[ResolvedDrsUrl]] =
  for {
    pending <- unresolvedUrls
    resolver <- getDrsPathResolver
    resolved <- pending.traverse(resolveSingleUrl(resolver, _))
  } yield resolved

/**
  * Resolves all of the URLs, then splits them by URI type.
  * All access URLs are handed to a single bulk downloader as one batch.
  * Each Google URL gets its own Google downloader.
  * No downloader is built for a URI type that has no URLs.
  *
  * @return every downloader required to fulfill the request: Google downloaders first, then the bulk downloader.
  */
def buildDownloaders(): IO[List[Downloader]] =
  resolveUrls(toResolveAndDownload).flatMap { resolved =>
    val viaAccessUrl = resolved.filter(_.uriType == URIType.ACCESS)
    val viaGcs = resolved.filter(_.uriType == URIType.GCS)

    // Mapping over an empty list already yields no downloaders, so no emptiness check is needed here.
    val gcsDownloaders: List[IO[Downloader]] = buildGoogleDownloaders(viaGcs)
    val bulkDownloaders: List[IO[Downloader]] =
      if (viaAccessUrl.isEmpty) Nil else List(buildBulkAccessUrlDownloader(viaAccessUrl))

    (gcsDownloaders ++ bulkDownloaders).sequence
  }

/**
  * Builds one GCS downloader per resolved URL.
  *
  * @param resolvedGoogleUrls resolved URLs that are expected to carry a gs:// URI
  * @return one downloader effect per input URL, in input order. If a response has no gs:// URI,
  *         the effect for that URL fails with a descriptive RuntimeException rather than an
  *         opaque NoSuchElementException from Option.get.
  */
def buildGoogleDownloaders(resolvedGoogleUrls: List[ResolvedDrsUrl]): List[IO[Downloader]] =
  resolvedGoogleUrls.map { url =>
    url.drsResponse.gsUri match {
      case Some(gcsPath) =>
        downloaderFactory.buildGcsUriDownloader(
          gcsPath = gcsPath,
          serviceAccountJsonOption = url.drsResponse.googleServiceAccount.map(_.data.spaces2),
          downloadLoc = url.downloadDestinationPath,
          requesterPaysProjectOption = requesterPaysProjectIdOption)
      case None =>
        IO.raiseError[Downloader](new RuntimeException(
          s"Cannot build a GCS downloader for ${url.downloadDestinationPath}: DRS response did not contain a gs:// URI"))
    }
  }
/**
  * Asks the downloader factory for a single bulk downloader covering all of the given URLs.
  *
  * @param resolvedUrls resolved URLs to download as one batch
  */
def buildBulkAccessUrlDownloader(resolvedUrls: List[ResolvedDrsUrl]): IO[Downloader] =
  downloaderFactory.buildBulkAccessUrlDownloader(resolvedUrls)


/*
def resolveAndDownloadWithRetries(downloadRetries: Int,
checksumRetries: Int,
downloaderFactory: DownloaderFactory,
Expand All @@ -108,7 +210,8 @@ class DrsLocalizerMain(drsUrl: String,
IO.raiseError(new RuntimeException(s"Exhausted $checksumRetries checksum retries to resolve, download and checksum $drsUrl", t))
}
}

*/
/*
def maybeRetryForDownloadFailure(t: Throwable): IO[DownloadResult] = {
t match {
case _: FatalRetryDisposition =>
Expand All @@ -121,7 +224,8 @@ class DrsLocalizerMain(drsUrl: String,
IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries download retries to resolve, download and checksum $drsUrl", t))
}
}

*/
/*
resolveAndDownload(downloaderFactory).redeemWith({
maybeRetryForDownloadFailure
},
Expand All @@ -134,33 +238,11 @@ class DrsLocalizerMain(drsUrl: String,
case ChecksumFailure =>
maybeRetryForChecksumFailure(new RuntimeException(s"Checksum failure for $drsUrl on checksum retry attempt $checksumAttempt of $checksumRetries"))
case o => IO.pure(o)
})
}
}*

private [localizer] def resolveAndDownload(downloaderFactory: DownloaderFactory): IO[DownloadResult] = {
resolve(downloaderFactory) flatMap { _.download }
}

private [localizer] def resolve(downloaderFactory: DownloaderFactory): IO[Downloader] = {
val fields = NonEmptyList.of(DrsResolverField.GsUri, DrsResolverField.GoogleServiceAccount, DrsResolverField.AccessUrl, DrsResolverField.Hashes)
for {
resolver <- getDrsPathResolver
drsResolverResponse <- resolver.resolveDrs(drsUrl, fields)

// Currently DRS Resolver only supports resolving DRS paths to access URLs or GCS paths.
downloader <- (drsResolverResponse.accessUrl, drsResolverResponse.gsUri) match {
case (Some(accessUrl), _) =>
downloaderFactory.buildAccessUrlDownloader(accessUrl, downloadLoc, drsResolverResponse.hashes)
case (_, Some(gcsPath)) =>
val serviceAccountJsonOption = drsResolverResponse.googleServiceAccount.map(_.data.spaces2)
downloaderFactory.buildGcsUriDownloader(
gcsPath = gcsPath,
serviceAccountJsonOption = serviceAccountJsonOption,
downloadLoc = downloadLoc,
requesterPaysProjectOption = requesterPaysProjectIdOption)
case _ =>
IO.raiseError(new RuntimeException(DrsPathResolver.ExtractUriErrorMsg))
}
} yield downloader
IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries download retries to resolve, download and checksum $drsUrl", t))
}
*/
}

This file was deleted.

Loading