Commit

update ml/tests.py
mengxr committed May 12, 2015
1 parent 64a536c commit 13bd70a
Showing 3 changed files with 37 additions and 40 deletions.
python/pyspark/ml/param/__init__.py (15 changes: 6 additions & 9 deletions)
@@ -139,22 +139,19 @@ def hasParam(self, paramName):
         Tests whether this instance contains a param with a given
         (string) name.
         """
-        return self.params.count(paramName) != 0
+        param = self._resolveParam(paramName)
+        return param in self.params

     def getOrDefault(self, param):
         """
         Gets the value of a param in the user-supplied param map or its
         default value. Raises an error if neither is set.
         """
-        if isinstance(param, Param):
-            if param in self._paramMap:
-                return self._paramMap[param]
-            else:
-                return self._defaultParamMap[param]
-        elif isinstance(param, str):
-            return self.getOrDefault(self.getParam(param))
+        param = self._resolveParam(param)
+        if param in self._paramMap:
+            return self._paramMap[param]
         else:
-            raise KeyError("Cannot recognize %r as a param." % param)
+            return self._defaultParamMap[param]

     def extractParamMap(self, extra={}):
         """
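For context, both methods now funnel through Params._resolveParam, which accepts either a Param instance or a (string) param name and validates ownership before lookup. A minimal sketch of the resulting behavior (ToyParams and its maxIter param are illustrative, not part of this commit; assumes pyspark at this revision):

# Sketch only: ToyParams is a made-up Params subclass for illustration.
from pyspark.ml.param import Param, Params


class ToyParams(Params):

    def __init__(self):
        super(ToyParams, self).__init__()
        self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)")
        self._setDefault(maxIter=10)


toy = ToyParams()
print(toy.hasParam("maxIter"))        # True: the name resolves to a known param
print(toy.getOrDefault("maxIter"))    # 10: no user-supplied value, default wins
toy._paramMap[toy.maxIter] = 42       # simulate a user-supplied value
print(toy.getOrDefault(toy.maxIter))  # 42: user value shadows the default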
python/pyspark/ml/pipeline.py (15 changes: 12 additions & 3 deletions)
@@ -60,7 +60,10 @@ def fit(self, dataset, params={}):
         if isinstance(params, (list, tuple)):
             return [self.fit(dataset, paramMap) for paramMap in params]
         elif isinstance(params, dict):
-            return self.copy(params)._fit(dataset)
+            if params:
+                return self.copy(params)._fit(dataset)
+            else:
+                return self._fit(dataset)
         else:
             raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                              "but got %s." % type(params))
@@ -97,7 +100,10 @@ def transform(self, dataset, params={}):
         :returns: transformed dataset
         """
         if isinstance(params, dict):
-            return self.copy(params,)._transform(dataset)
+            if params:
+                return self.copy(params,)._transform(dataset)
+            else:
+                return self._transform(dataset)
         else:
             raise ValueError("Params must be either a param map but got %s." % type(params))
@@ -263,6 +269,9 @@ def evaluate(self, dataset, params={}):
         :return: metric
         """
         if isinstance(params, dict):
-            return self.copy(params)._evaluate(dataset)
+            if params:
+                return self.copy(params)._evaluate(dataset)
+            else:
+                return self._evaluate(dataset)
         else:
             raise ValueError("Params must be a param map but got %s." % type(params))
python/pyspark/ml/tests.py (47 changes: 19 additions & 28 deletions)
@@ -32,7 +32,7 @@

 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 from pyspark.sql import DataFrame
-from pyspark.ml.param import Param
+from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol
 from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer
@@ -43,44 +43,38 @@ def __init__(self):
         self.index = 0


-class MockTransformer(Transformer):
+class HasFake(Params):

     def __init__(self):
+        super(HasFake, self).__init__()
+        self.fake = Param(self, "fake", "fake param")
+
+
+class MockTransformer(Transformer, HasFake):
+
+    def __init__(self):
         super(MockTransformer, self).__init__()
-        self.fake = Param(self, "fake", "fake")
         self.dataset_index = None
-        self.fake_param_value = None

-    def transform(self, dataset, params={}):
+    def _transform(self, dataset):
         self.dataset_index = dataset.index
-        if self.fake in params:
-            self.fake_param_value = params[self.fake]
         dataset.index += 1
         return dataset


-class MockEstimator(Estimator):
+class MockEstimator(Estimator, HasFake):

     def __init__(self):
         super(MockEstimator, self).__init__()
-        self.fake = Param(self, "fake", "fake")
         self.dataset_index = None
-        self.fake_param_value = None
         self.model = None

-    def fit(self, dataset, params={}):
+    def _fit(self, dataset):
         self.dataset_index = dataset.index
-        if self.fake in params:
-            self.fake_param_value = params[self.fake]
         model = MockModel()
         self.model = model
         return model


-class MockModel(MockTransformer, Model):
-
-    def __init__(self):
-        super(MockModel, self).__init__()
+class MockModel(MockTransformer, Model, HasFake):
+    pass


 class PipelineTests(PySparkTestCase):
@@ -94,16 +88,13 @@ def test_pipeline(self):
         pipeline = Pipeline() \
             .setStages([estimator0, transformer1, estimator2, transformer3])
         pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
         self.assertEqual(0, estimator0.dataset_index)
-        self.assertEqual(0, estimator0.fake_param_value)
-        model0 = estimator0.model
+        model0, transformer1, model2, transformer3 = pipeline_model.stages
         self.assertEqual(0, model0.dataset_index)
         self.assertEqual(1, transformer1.dataset_index)
-        self.assertEqual(1, transformer1.fake_param_value)
         self.assertEqual(2, estimator2.dataset_index)
-        model2 = estimator2.model
-        self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
-                                                "not be called during fit.")
-        self.assertEqual(2, dataset.index)
+        self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
+        self.assertIsNone(transformer3.dataset_index,
+                          "The last transformer shouldn't be called in fit.")
         dataset = pipeline_model.transform(dataset)
         self.assertEqual(2, model0.dataset_index)
         self.assertEqual(3, transformer1.dataset_index)
@@ -129,7 +120,7 @@ def test_param(self):
         maxIter = testParams.maxIter
         self.assertEqual(maxIter.name, "maxIter")
         self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
-        self.assertTrue(maxIter.parent is testParams)
+        self.assertTrue(maxIter.parent == testParams.uid)

     def test_params(self):
         testParams = TestParams()
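One behavioral note on the last hunk: the assertion changes from an identity check against the Params object to an equality check against its uid, i.e. Param.parent now holds the parent's uid string rather than a reference to the parent instance. A quick sketch (TestParams is defined earlier in tests.py):

# Sketch only: Param.parent is a uid string at this revision, not the object.
testParams = TestParams()
maxIter = testParams.maxIter
print(maxIter.parent == testParams.uid)  # True
print(maxIter.parent is testParams)      # False: the old identity check fails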
