fix: handle long run ids in adb tests

mhamilton723 committed Oct 20, 2023
1 parent 2abfbda commit 06f1383
Showing 3 changed files with 28 additions and 26 deletions.
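
The gist of the change: Databricks job run ids returned by the Jobs API can exceed Int.MaxValue, so the test helpers now parse and track them as Long throughout; the commit also replaces the fixed BaseURL constant with a baseURL(apiVersion) helper so each request can pick its API version. A minimal sketch of why the old Int typing is unsafe (not part of the commit; the run id value is hypothetical):

object RunIdOverflowSketch extends App {
  // A hypothetical run id above Int.MaxValue (2147483647), as the
  // Databricks Jobs API can return.
  val runId: Long = 2934006257L

  // Narrowing to Int keeps only the low 32 bits and wraps negative,
  // silently corrupting the id; keeping it as Long preserves it.
  println(runId.toInt) // -1360961039 (overflow)
  println(runId)       // 2934006257
}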
@@ -11,7 +11,7 @@ import scala.language.existentials
class DatabricksCPUTests extends DatabricksTestHelper {

val clusterId: String = createClusterInPool(ClusterName, AdbRuntime, NumWorkers, PoolId, "[]")
- val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(clusterId, Libraries, CPUNotebooks)
+ val jobIdsToCancel: ListBuffer[Long] = databricksTestHelper(clusterId, Libraries, CPUNotebooks)

protected override def afterAll(): Unit = {
afterAllHelper(jobIdsToCancel, clusterId, ClusterName)
@@ -16,7 +16,7 @@ class DatabricksGPUTests extends DatabricksTestHelper {
"src", "main", "python", "horovod_installation.sh").getCanonicalFile
uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod-fix-commit/horovod_installation.sh")
val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId, GPUInitScripts)
- val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(
+ val jobIdsToCancel: ListBuffer[Long] = databricksTestHelper(
clusterId, GPULibraries, GPUNotebooks)

protected override def afterAll(): Unit = {
@@ -39,7 +39,7 @@ object DatabricksUtilities {
lazy val Token: String = sys.env.getOrElse("MML_ADB_TOKEN", Secrets.AdbToken)
lazy val AuthValue: String = "Basic " + BaseEncoding.base64()
.encode(("token:" + Token).getBytes("UTF-8"))
- val BaseURL = s"https://$Region.azuredatabricks.net/api/2.0/"

lazy val PoolId: String = getPoolIdByName(PoolName)
lazy val GpuPoolId: String = getPoolIdByName(GpuPoolName)
lazy val ClusterName = s"mmlspark-build-${LocalDateTime.now()}"
@@ -67,6 +67,8 @@ object DatabricksUtilities {
"interpret-community"
)

+ def baseURL(apiVersion: String): String = s"https://$Region.azuredatabricks.net/api/$apiVersion/"

val Libraries: String = (
List(Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository))) ++
PipPackages.map(p => Map("pypi" -> Map("package" -> p)))
@@ -98,15 +100,15 @@ object DatabricksUtilities {

val GPUNotebooks: Seq[File] = ParallelizableNotebooks.filter(_.getAbsolutePath.contains("Fine-tune"))

- def databricksGet(path: String): JsValue = {
-   val request = new HttpGet(BaseURL + path)
+ def databricksGet(path: String, apiVersion: String = "2.0"): JsValue = {
+   val request = new HttpGet(baseURL(apiVersion) + path)
request.addHeader("Authorization", AuthValue)
RESTHelpers.sendAndParseJson(request)
}

//TODO convert all this to typed code
- def databricksPost(path: String, body: String): JsValue = {
-   val request = new HttpPost(BaseURL + path)
+ def databricksPost(path: String, body: String, apiVersion: String = "2.0"): JsValue = {
+   val request = new HttpPost(baseURL(apiVersion) + path)
request.addHeader("Authorization", AuthValue)
request.setEntity(new StringEntity(body))
RESTHelpers.sendAndParseJson(request)
@@ -120,7 +122,7 @@
}

def getPoolIdByName(name: String): String = {
-   val jsonObj = databricksGet("instance-pools/list")
+   val jsonObj = databricksGet("instance-pools/list", apiVersion = "2.0")
val cluster = jsonObj.select[Array[JsValue]]("instance_pools")
.filter(_.select[String]("instance_pool_name") == name).head
cluster.select[String]("instance_pool_id")
@@ -230,7 +232,7 @@ object DatabricksUtilities {
()
}

- def submitRun(clusterId: String, notebookPath: String): Int = {
+ def submitRun(clusterId: String, notebookPath: String): Long = {
val body =
s"""
|{
@@ -244,7 +246,7 @@
| "libraries": $Libraries
|}
""".stripMargin
-   databricksPost("jobs/runs/submit", body).select[Int]("run_id")
+   databricksPost("jobs/runs/submit", body).select[Long]("run_id")
}

def isClusterActive(clusterId: String): Boolean = {
@@ -265,7 +267,7 @@
libraryStatuses.forall(_.select[String]("status") == "INSTALLED")
}

- private def getRunStatuses(runId: Int): (String, Option[String]) = {
+ private def getRunStatuses(runId: Long): (String, Option[String]) = {
val runObj = databricksGet(s"jobs/runs/get?run_id=$runId")
val stateObj = runObj.select[JsObject]("state")
val lifeCycleState = stateObj.select[String]("life_cycle_state")
@@ -277,7 +279,7 @@
}
}

- def getRunUrlAndNBName(runId: Int): (String, String) = {
+ def getRunUrlAndNBName(runId: Long): (String, String) = {
val runObj = databricksGet(s"jobs/runs/get?run_id=$runId").asJsObject()
val url = runObj.select[String]("run_page_url")
.replaceAll("westus", Region) //TODO this seems like an ADB bug
@@ -286,7 +288,7 @@
}

//scalastyle:off cyclomatic.complexity
- def monitorJob(runId: Integer,
+ def monitorJob(runId: Long,
timeout: Int,
interval: Int = 8000,
logLevel: Int = 1): Future[Unit] = {
@@ -342,28 +344,28 @@ object DatabricksUtilities {
workspaceMkDir(folderToCreate)
val destination: String = folderToCreate + notebookFile.getName
uploadNotebook(notebookFile, destination)
-   val runId: Int = submitRun(clusterId, destination)
+   val runId: Long = submitRun(clusterId, destination)
val run: DatabricksNotebookRun = DatabricksNotebookRun(runId, notebookFile.getName)
println(s"Successfully submitted job run id ${run.runId} for notebook ${run.notebookName}")
run
}

- def cancelRun(runId: Int): Unit = {
+ def cancelRun(runId: Long): Unit = {
println(s"Cancelling job $runId")
databricksPost("jobs/runs/cancel", s"""{"run_id": $runId}""")
()
}

- def listActiveJobs(clusterId: String): Vector[Int] = {
+ def listActiveJobs(clusterId: String): Vector[Long] = {
//TODO this only gets the first 1k running jobs, full solution would page results
databricksGet("jobs/runs/list?active_only=true&limit=1000")
.asJsObject.fields.get("runs").map { runs =>
-     runs.asInstanceOf[JsArray].elements.flatMap {
-       case run if clusterId == run.select[String]("cluster_instance.cluster_id") =>
-         Some(run.select[Int]("run_id"))
-       case _ => None
-     }
-   }.getOrElse(Array().toVector: Vector[Int])
+     runs.asInstanceOf[JsArray].elements.flatMap {
+       case run if clusterId == run.select[String]("cluster_instance.cluster_id") =>
+         Some(run.select[Long]("run_id"))
+       case _ => None
+     }
+   }.getOrElse(Array().toVector: Vector[Long])
}

def listInstalledLibraries(clusterId: String): Vector[JsValue] = {
@@ -400,8 +402,8 @@ abstract class DatabricksTestHelper extends TestBase {

def databricksTestHelper(clusterId: String,
libraries: String,
- notebooks: Seq[File]): mutable.ListBuffer[Int] = {
-   val jobIdsToCancel: mutable.ListBuffer[Int] = mutable.ListBuffer[Int]()
+ notebooks: Seq[File]): mutable.ListBuffer[Long] = {
+   val jobIdsToCancel: mutable.ListBuffer[Long] = mutable.ListBuffer[Long]()

println("Checking if cluster is active")
tryWithRetries(Seq.fill(60 * 15)(1000).toArray) { () =>
@@ -437,7 +439,7 @@ abstract class DatabricksTestHelper extends TestBase {
jobIdsToCancel
}

- protected def afterAllHelper(jobIdsToCancel: mutable.ListBuffer[Int],
+ protected def afterAllHelper(jobIdsToCancel: mutable.ListBuffer[Long],
clusterId: String,
clusterName: String): Unit = {
println("Suite test finished. Running afterAll procedure...")
@@ -447,7 +449,7 @@
}
}

- case class DatabricksNotebookRun(runId: Int, notebookName: String) {
+ case class DatabricksNotebookRun(runId: Long, notebookName: String) {
def monitor(logLevel: Int = 2): Future[Any] = {
monitorJob(runId, TimeoutInMillis, logLevel)
}
