Skip to content

Commit

Permalink
[ML-123][Core] Improve locality handling for native lib loading (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
xwu99 authored Aug 18, 2021
1 parent 0475721 commit 5aee77e
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 24 deletions.
10 changes: 10 additions & 0 deletions mllib-dal/src/main/java/org/apache/spark/ml/util/LibLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public final class LibLoader {

private static final Logger log = LoggerFactory.getLogger("LibLoader");

private static boolean isLoaded = false;

/**
* Get temp dir for extracting lib files
*
Expand All @@ -45,11 +47,16 @@ public static String getTempSubDir() {
* Load all native libs
*/
public static synchronized void loadLibraries() throws IOException {
    // One-time init: a previous successful call already loaded everything.
    // The method is synchronized, so concurrent first callers cannot race
    // on this check-then-act sequence.
    if (isLoaded)
        return;

    // SYCL (GPU) libraries are optional; when they are absent we log and
    // continue with the CPU-only libraries below.
    if (!loadLibSYCL()) {
        log.debug("SYCL libraries are not available, will load CPU libraries only.");
    }
    loadLibCCL();
    loadLibMLlibDAL();

    // Set the flag only after every load succeeded, so a failed attempt
    // (an IOException above) leaves the loader retryable.
    isLoaded = true;
}

/**
Expand All @@ -74,12 +81,15 @@ private static synchronized Boolean loadLibSYCL() throws IOException {
if (streamIn == null) {
return false;
}
streamIn.close();

loadFromJar(subDir, "libintlc.so.5");
loadFromJar(subDir, "libimf.so");
loadFromJar(subDir, "libirng.so");
loadFromJar(subDir, "libsvml.so");
loadFromJar(subDir, "libOpenCL.so.1");
loadFromJar(subDir, "libsycl.so.5");

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ class KMeansDALImpl(var nClusters: Int,
// Make sure there is only one result from rank 0
assert(results.length == 1)

// Release native memory for numeric tables
OneDAL.releaseNumericTables(data.sparkContext)

val centerVectors = results(0)._1
val totalCost = results(0)._2
val iterationNum = results(0)._3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ class PCADALImpl(val k: Int,
// Make sure there is only one result from rank 0
assert(results.length == 1)

// Release native memory for numeric tables
OneDAL.releaseNumericTables(data.sparkContext)

val pc = results(0)._1
val explainedVariance = results(0)._2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import com.intel.daal.services.DaalContext
import org.apache.spark.Partitioner
import org.apache.spark.internal.Logging
import org.apache.spark.ml.recommendation.ALS.Rating
import org.apache.spark.ml.util.LibLoader.loadLibraries
import org.apache.spark.ml.util.Utils.getOneCCLIPPort
import org.apache.spark.ml.util._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -84,9 +83,6 @@ class ALSDALImpl[@specialized(Int, Long) ID: ClassTag]( data: RDD[Rating[ID]],
Rating(p.item, p.user, p.rating)
}
.mapPartitionsWithIndex { (rank, iter) =>
// TODO: Use one-time init to load libraries
loadLibraries()

OneCCL.init(executorNum, rank, kvsIPPort)
val rankId = OneCCL.rankID()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import org.apache.spark.internal.Logging

object OneCCL extends Logging {

LibLoader.loadLibraries()

var cclParam = new CCLParam()

// Run on Executor
Expand Down
14 changes: 3 additions & 11 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/OneDAL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import scala.collection.mutable.ArrayBuffer

object OneDAL {

LibLoader.loadLibraries()

private val logger = Logger.getLogger("util.OneDAL")
private val logLevel = Level.INFO

Expand Down Expand Up @@ -119,19 +121,9 @@ object OneDAL {
matrix
}

/**
 * Releases the native memory backing the persisted numeric-table RDDs.
 *
 * Persisted RDDs are matched by name ("numericTables"); each element of a
 * matching RDD is a native address that is handed to OneDAL.cFreeDataMemory
 * on the executor holding it.
 *
 * @param sparkContext context used to look up the persisted RDDs
 */
def releaseNumericTables(sparkContext: SparkContext): Unit = {
  val tableRdds = sparkContext.getPersistentRDDs
    .collect { case (_, rdd) if rdd.name == "numericTables" => rdd }
  tableRdds.foreach { persisted =>
    persisted.asInstanceOf[RDD[Long]].foreach { address =>
      OneDAL.cFreeDataMemory(address)
    }
  }
}

def rddDoubleToNumericTables(doubles: RDD[Double], executorNum: Int): RDD[Long] = {
require(executorNum > 0)

val doublesTables = doubles.repartition(executorNum).mapPartitions { it: Iterator[Double] =>
val data = it.toArray
// Build DALMatrix, this will load libJavaAPI, libtbb, libtbbmalloc
Expand Down
3 changes: 0 additions & 3 deletions mllib-dal/src/main/scala/org/apache/spark/ml/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ object Utils {
// All other functions using native libraries will depend on this function to be called first
//
def checkClusterPlatformCompatibility(sc: SparkContext): Boolean = {
LibLoader.loadLibraries()

// check driver platform compatibility
if (!OneDAL.cCheckPlatformCompatibility()) {
return false
Expand All @@ -128,7 +126,6 @@ object Utils {
val executor_num = Utils.sparkExecutorNum(sc)
val data = sc.parallelize(1 to executor_num, executor_num)
val result = data.mapPartitions { p =>
LibLoader.loadLibraries()
Iterator(OneDAL.cCheckPlatformCompatibility())
}.collect()

Expand Down

0 comments on commit 5aee77e

Please sign in to comment.