diff --git a/notebooks/samples/203 - Breast Cancer - Tune Hyperparameters.ipynb b/notebooks/samples/203 - Breast Cancer - Tune Hyperparameters.ipynb index 2e2f3e1361..4fbd7cdcec 100644 --- a/notebooks/samples/203 - Breast Cancer - Tune Hyperparameters.ipynb +++ b/notebooks/samples/203 - Breast Cancer - Tune Hyperparameters.ipynb @@ -125,6 +125,23 @@ " paramSpace=randomSpace.space(), seed=0).fit(tune)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can view the best model's parameters and retrieve the underlying best model pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(bestModel.getBestModelInfo())\n", + "print(bestModel.getBestModel())" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/codegen/src/main/scala/WrapperGenerator.scala b/src/codegen/src/main/scala/WrapperGenerator.scala index 6f4e7e5269..865bebc5cf 100644 --- a/src/codegen/src/main/scala/WrapperGenerator.scala +++ b/src/codegen/src/main/scala/WrapperGenerator.scala @@ -23,6 +23,8 @@ abstract class WrapperGenerator { def wrapperName(myClass: Class[_]): String + def modelWrapperName(myClass: Class[_], modelName: String): String + def generateEstimatorWrapper(entryPoint: Estimator[_], entryPointName: String, entryPointQualifiedName: String, @@ -59,16 +61,18 @@ abstract class WrapperGenerator { generateTransformerTestWrapper(t, className, qualifiedClassName)) case e: Estimator[_] => val sc = iterate[Class[_]](myClass)(_.getSuperclass) - .find(c => Seq("Estimator", "Predictor").contains(c.getSuperclass.getSimpleName)) + .find(c => Seq("Estimator", "ProbabilisticClassifier", "Predictor") + .contains(c.getSuperclass.getSimpleName)) .get val typeArgs = sc.getGenericSuperclass.asInstanceOf[ParameterizedType] .getActualTypeArguments val getModelFromGenericType = (modelType: Type) => { val modelClass = modelType.getTypeName.split("<").head - (modelClass.split("\\.").last, modelClass) + (modelWrapperName(myClass, modelClass.split("\\.").last), modelClass) } val (modelClass, modelQualifiedClass) = sc.getSuperclass.getSimpleName match { case "Estimator" => getModelFromGenericType(typeArgs.head) + case "ProbabilisticClassifier" => getModelFromGenericType(typeArgs(2)) case "Predictor" => getModelFromGenericType(typeArgs(2)) } @@ -149,6 +153,11 @@ class PySparkWrapperGenerator extends WrapperGenerator { prefix + myClass.getSimpleName } + def modelWrapperName(myClass: Class[_], modelName: String): String = { + val prefix = if (needsInternalWrapper(myClass)) internalPrefix else "" + prefix + modelName + } + def generateEstimatorWrapper(entryPoint: Estimator[_], entryPointName: String, entryPointQualifiedName: String, @@ -228,8 +237,8 @@ class SparklyRWrapperGenerator extends WrapperGenerator { |export(sdf_transform) |""".stripMargin) - def wrapperName(myClass: Class[_]): String = - myClass.getSimpleName.foldLeft((true, ""))((base, c) => { + def formatWrapperName(name: String): String = + name.foldLeft((true, ""))((base, c) => { val ignoreCaps = base._1 val partialStr = base._2 if (!c.isUpper) (false, partialStr + c) @@ -237,6 +246,10 @@ class SparklyRWrapperGenerator extends WrapperGenerator { else (true, partialStr + "_" + c.toLower) })._2 + def wrapperName(myClass: Class[_]): String = formatWrapperName(myClass.getSimpleName) + + def modelWrapperName(myClass: Class[_], modelName: String): String = formatWrapperName(modelName) + def generateEstimatorWrapper(entryPoint: Estimator[_], entryPointName: String, entryPointQualifiedName: String, diff --git a/src/lightgbm/src/main/python/LightGBMClassifier.py b/src/lightgbm/src/main/python/LightGBMClassifier.py index 5607b8f1b6..906d251094 100644 --- a/src/lightgbm/src/main/python/LightGBMClassifier.py +++ b/src/lightgbm/src/main/python/LightGBMClassifier.py @@ -9,7 +9,7 @@ basestring = str from mmlspark._LightGBMClassifier import _LightGBMClassifier -from mmlspark._LightGBMClassifier import M +from mmlspark._LightGBMClassifier import _LightGBMClassificationModel from pyspark.ml.common import inherit_doc @inherit_doc @@ -21,7 +21,7 @@ def _create_model(self, java_model): return model @inherit_doc -class LightGBMClassificationModel(M): +class LightGBMClassificationModel(_LightGBMClassificationModel): def saveNativeModel(self, sparkSession, filename): """ Save the booster as string format to a local or WASB remote location. diff --git a/src/lightgbm/src/main/python/LightGBMRegressor.py b/src/lightgbm/src/main/python/LightGBMRegressor.py index 409134d544..54f6ba705e 100644 --- a/src/lightgbm/src/main/python/LightGBMRegressor.py +++ b/src/lightgbm/src/main/python/LightGBMRegressor.py @@ -9,7 +9,7 @@ basestring = str from mmlspark._LightGBMRegressor import _LightGBMRegressor -from mmlspark._LightGBMRegressor import M +from mmlspark._LightGBMRegressor import _LightGBMRegressionModel from pyspark.ml.common import inherit_doc @inherit_doc @@ -21,7 +21,7 @@ def _create_model(self, java_model): return model @inherit_doc -class LightGBMRegressionModel(M): +class LightGBMRegressionModel(_LightGBMRegressionModel): def saveNativeModel(self, sparkSession, filename): """ Save the booster as string format to a local or WASB remote location. diff --git a/src/tune-hyperparameters/src/main/python/TuneHyperparameters.py b/src/tune-hyperparameters/src/main/python/TuneHyperparameters.py new file mode 100644 index 0000000000..beb6c71d97 --- /dev/null +++ b/src/tune-hyperparameters/src/main/python/TuneHyperparameters.py @@ -0,0 +1,36 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +import sys +from pyspark import SQLContext +from pyspark import SparkContext + +if sys.version >= '3': + basestring = str + +from mmlspark._TuneHyperparameters import _TuneHyperparameters +from mmlspark._TuneHyperparameters import _TuneHyperparametersModel +from pyspark.ml.wrapper import JavaParams +from pyspark.ml.common import inherit_doc + +@inherit_doc +class TuneHyperparameters(_TuneHyperparameters): + def _create_model(self, java_model): + model = TuneHyperparametersModel() + model._java_obj = java_model + model._transfer_params_from_java() + return model + +@inherit_doc +class TuneHyperparametersModel(_TuneHyperparametersModel): + def getBestModel(self): + """ + Returns the best model. + """ + return JavaParams._from_java(self._java_obj.getBestModel()) + + def getBestModelInfo(self): + """ + Returns the best model parameter info. + """ + return self._java_obj.getBestModelInfo() diff --git a/src/tune-hyperparameters/src/main/scala/TuneHyperparameters.scala b/src/tune-hyperparameters/src/main/scala/TuneHyperparameters.scala index 2e1b35d439..6b9b412591 100644 --- a/src/tune-hyperparameters/src/main/scala/TuneHyperparameters.scala +++ b/src/tune-hyperparameters/src/main/scala/TuneHyperparameters.scala @@ -29,6 +29,7 @@ import scala.util.control.NonFatal * Allows user to specify multiple untrained models to tune using various search strategies. * Currently supports cross validation with random grid search. */ +@InternalWrapper class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperparametersModel] with Wrappable with ComplexParamsWritable with HasEvaluationMetric { @@ -191,6 +192,7 @@ class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperp object TuneHyperparameters extends ComplexParamsReadable[TuneHyperparameters] /** Model produced by [[TuneHyperparameters]]. */ +@InternalWrapper class TuneHyperparametersModel(val uid: String, val model: Transformer, val bestMetric: Double)