Skip to content

Commit

Permalink
Adding getBestModel and getBestModelInfo to TuneHyperparameters and f…
Browse files Browse the repository at this point in the history
…ixed autogenerated code for model name (#355)
  • Loading branch information
imatiach-msft authored and mhamilton723 committed Aug 18, 2018
1 parent 0351dfe commit d287118
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 8 deletions.
17 changes: 17 additions & 0 deletions notebooks/samples/203 - Breast Cancer - Tune Hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@
" paramSpace=randomSpace.space(), seed=0).fit(tune)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can view the best model's parameters and retrieve the underlying best model pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(bestModel.getBestModelInfo())\n",
"print(bestModel.getBestModel())"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
21 changes: 17 additions & 4 deletions src/codegen/src/main/scala/WrapperGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ abstract class WrapperGenerator {

def wrapperName(myClass: Class[_]): String

def modelWrapperName(myClass: Class[_], modelName: String): String

def generateEstimatorWrapper(entryPoint: Estimator[_],
entryPointName: String,
entryPointQualifiedName: String,
Expand Down Expand Up @@ -59,16 +61,18 @@ abstract class WrapperGenerator {
generateTransformerTestWrapper(t, className, qualifiedClassName))
case e: Estimator[_] =>
val sc = iterate[Class[_]](myClass)(_.getSuperclass)
.find(c => Seq("Estimator", "Predictor").contains(c.getSuperclass.getSimpleName))
.find(c => Seq("Estimator", "ProbabilisticClassifier", "Predictor")
.contains(c.getSuperclass.getSimpleName))
.get
val typeArgs = sc.getGenericSuperclass.asInstanceOf[ParameterizedType]
.getActualTypeArguments
val getModelFromGenericType = (modelType: Type) => {
val modelClass = modelType.getTypeName.split("<").head
(modelClass.split("\\.").last, modelClass)
(modelWrapperName(myClass, modelClass.split("\\.").last), modelClass)
}
val (modelClass, modelQualifiedClass) = sc.getSuperclass.getSimpleName match {
case "Estimator" => getModelFromGenericType(typeArgs.head)
case "ProbabilisticClassifier" => getModelFromGenericType(typeArgs(2))
case "Predictor" => getModelFromGenericType(typeArgs(2))
}

Expand Down Expand Up @@ -149,6 +153,11 @@ class PySparkWrapperGenerator extends WrapperGenerator {
prefix + myClass.getSimpleName
}

def modelWrapperName(myClass: Class[_], modelName: String): String = {
val prefix = if (needsInternalWrapper(myClass)) internalPrefix else ""
prefix + modelName
}

def generateEstimatorWrapper(entryPoint: Estimator[_],
entryPointName: String,
entryPointQualifiedName: String,
Expand Down Expand Up @@ -228,15 +237,19 @@ class SparklyRWrapperGenerator extends WrapperGenerator {
|export(sdf_transform)
|""".stripMargin)

def wrapperName(myClass: Class[_]): String =
myClass.getSimpleName.foldLeft((true, ""))((base, c) => {
def formatWrapperName(name: String): String =
name.foldLeft((true, ""))((base, c) => {
val ignoreCaps = base._1
val partialStr = base._2
if (!c.isUpper) (false, partialStr + c)
else if (ignoreCaps) (true, partialStr + c.toLower)
else (true, partialStr + "_" + c.toLower)
})._2

def wrapperName(myClass: Class[_]): String = formatWrapperName(myClass.getSimpleName)

def modelWrapperName(myClass: Class[_], modelName: String): String = formatWrapperName(modelName)

def generateEstimatorWrapper(entryPoint: Estimator[_],
entryPointName: String,
entryPointQualifiedName: String,
Expand Down
4 changes: 2 additions & 2 deletions src/lightgbm/src/main/python/LightGBMClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
basestring = str

from mmlspark._LightGBMClassifier import _LightGBMClassifier
from mmlspark._LightGBMClassifier import M
from mmlspark._LightGBMClassifier import _LightGBMClassificationModel
from pyspark.ml.common import inherit_doc

@inherit_doc
Expand All @@ -21,7 +21,7 @@ def _create_model(self, java_model):
return model

@inherit_doc
class LightGBMClassificationModel(M):
class LightGBMClassificationModel(_LightGBMClassificationModel):
def saveNativeModel(self, sparkSession, filename):
"""
Save the booster as string format to a local or WASB remote location.
Expand Down
4 changes: 2 additions & 2 deletions src/lightgbm/src/main/python/LightGBMRegressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
basestring = str

from mmlspark._LightGBMRegressor import _LightGBMRegressor
from mmlspark._LightGBMRegressor import M
from mmlspark._LightGBMRegressor import _LightGBMRegressionModel
from pyspark.ml.common import inherit_doc

@inherit_doc
Expand All @@ -21,7 +21,7 @@ def _create_model(self, java_model):
return model

@inherit_doc
class LightGBMRegressionModel(M):
class LightGBMRegressionModel(_LightGBMRegressionModel):
def saveNativeModel(self, sparkSession, filename):
"""
Save the booster as string format to a local or WASB remote location.
Expand Down
36 changes: 36 additions & 0 deletions src/tune-hyperparameters/src/main/python/TuneHyperparameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
from pyspark import SQLContext
from pyspark import SparkContext

if sys.version >= '3':
basestring = str

from mmlspark._TuneHyperparameters import _TuneHyperparameters
from mmlspark._TuneHyperparameters import _TuneHyperparametersModel
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc

@inherit_doc
class TuneHyperparameters(_TuneHyperparameters):
def _create_model(self, java_model):
model = TuneHyperparametersModel()
model._java_obj = java_model
model._transfer_params_from_java()
return model

@inherit_doc
class TuneHyperparametersModel(_TuneHyperparametersModel):
def getBestModel(self):
"""
Returns the best model.
"""
return JavaParams._from_java(self._java_obj.getBestModel())

def getBestModelInfo(self):
"""
Returns the best model parameter info.
"""
return self._java_obj.getBestModelInfo()
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import scala.util.control.NonFatal
* Allows user to specify multiple untrained models to tune using various search strategies.
* Currently supports cross validation with random grid search.
*/
@InternalWrapper
class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperparametersModel]
with Wrappable with ComplexParamsWritable with HasEvaluationMetric {

Expand Down Expand Up @@ -191,6 +192,7 @@ class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperp
object TuneHyperparameters extends ComplexParamsReadable[TuneHyperparameters]

/** Model produced by [[TuneHyperparameters]]. */
@InternalWrapper
class TuneHyperparametersModel(val uid: String,
val model: Transformer,
val bestMetric: Double)
Expand Down

0 comments on commit d287118

Please sign in to comment.