diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 8fa0183e4683d..8052163acd00a 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -22,7 +22,7 @@
 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _py2java, _java2py
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
 from pyspark.ml.util import *
@@ -216,6 +216,8 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.tuning import CrossValidatorModel
+    >>> import tempfile
     >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
@@ -233,6 +235,12 @@ class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollec
     3
     >>> cvModel.avgMetrics[0]
     0.5
+    >>> path = tempfile.mkdtemp()
+    >>> model_path = path + "/model"
+    >>> cvModel.write().save(model_path)
+    >>> cvModelRead = CrossValidatorModel.read().load(model_path)
+    >>> cvModelRead.avgMetrics
+    [0.5, ...
     >>> evaluator.evaluate(cvModel.transform(dataset))
     0.8333...
 
@@ -483,10 +491,12 @@ def _from_java(cls, java_stage):
         Given a Java CrossValidatorModel, create and return a Python wrapper of it.
         Used for ML persistence.
         """
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         if java_stage.hasSubModels():
@@ -505,11 +515,10 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persist average metrics as well
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                              self.uid,
                                              self.bestModel._to_java(),
-                                             _py2java(sc, []))
+                                             _py2java(sc, self.avgMetrics))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -551,6 +560,8 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis
     >>> from pyspark.ml.classification import LogisticRegression
     >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
     >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.tuning import TrainValidationSplitModel
+    >>> import tempfile
     >>> dataset = spark.createDataFrame(
     ...     [(Vectors.dense([0.0]), 0.0),
     ...      (Vectors.dense([0.4]), 1.0),
@@ -566,6 +577,14 @@ class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelis
     >>> tvsModel = tvs.fit(dataset)
     >>> tvsModel.getTrainRatio()
     0.75
+    >>> tvsModel.validationMetrics
+    [0.5, ...
+    >>> path = tempfile.mkdtemp()
+    >>> model_path = path + "/model"
+    >>> tvsModel.write().save(model_path)
+    >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
+    >>> tvsModelRead.validationMetrics
+    [0.5, ...
     >>> evaluator.evaluate(tvsModel.transform(dataset))
     0.833...
 
@@ -809,11 +828,14 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
         # Load information from java_stage to the instance.
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        validationMetrics = _java2py(sc, java_stage.validationMetrics())
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel,
+                       validationMetrics=validationMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         if java_stage.hasSubModels():
@@ -830,12 +852,11 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persst validation metrics as well
         _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
-            _py2java(sc, []))
+            _py2java(sc, self.validationMetrics))
         estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
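For reference, a minimal end-to-end sketch of the round trip this patch enables for `CrossValidatorModel`, mirroring the doctest above. It assumes an active SparkSession bound to `spark`; the toy dataset is the same one the doctest uses.

```python
import tempfile

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder

# Same toy data as the doctest: one feature, binary label.
dataset = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)

# Round-trip through disk: with this patch avgMetrics is persisted
# instead of being written out as an empty list.
model_path = tempfile.mkdtemp() + "/model"
cvModel.write().save(model_path)
loaded = CrossValidatorModel.read().load(model_path)
print(loaded.avgMetrics)  # e.g. [0.5, ...], matching cvModel.avgMetrics
```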
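And the analogous round trip for `TrainValidationSplitModel`, reusing `dataset`, `lr`, `grid`, and `evaluator` from the sketch above:

```python
from pyspark.ml.tuning import TrainValidationSplit, TrainValidationSplitModel

tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                           evaluator=evaluator, trainRatio=0.75)
tvsModel = tvs.fit(dataset)
print(tvsModel.validationMetrics)  # one metric per param map, e.g. [0.5, ...]

# validationMetrics now survives save/load as well.
model_path = tempfile.mkdtemp() + "/model"
tvsModel.write().save(model_path)
loaded = TrainValidationSplitModel.read().load(model_path)
print(loaded.validationMetrics)
```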