Skip to content

Commit

Permalink
improve code
Browse files: browse the repository at this point in the history
  • Loading branch information
WeichenXu123 committed Dec 8, 2017
1 parent ec50dad commit 2cc7c28
Showing 1 changed file with 4 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.tuning

import java.util.{List => JList, Locale}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.concurrent.Future
Expand Down Expand Up @@ -146,15 +147,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
logDebug(s"Train split $splitIndex with multiple sets of parameters.")

var completeFitCount = 0
val signal = new Object
val completeFitCount = new AtomicInteger(0)
// Fit models in a Future for training in parallel
val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
signal.synchronized {
completeFitCount += 1
signal.notify()
if (completeFitCount.incrementAndGet() == epm.length) {
trainingDataset.unpersist()
}

if (collectSubModelsParam) {
Expand All @@ -166,14 +165,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
metric
} (executionContext)
}
Future {
signal.synchronized {
while (completeFitCount < epm.length) {
signal.wait()
}
}
trainingDataset.unpersist()
} (executionContext)

// Wait for metrics to be calculated before unpersisting validation dataset
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
Expand Down

0 comments on commit 2cc7c28

Please sign in to comment.