[SPARK-6421][MLLIB] _regression_train_wrapper does not test initialWeights correctly

Weight parameters must be initialized correctly even when a numpy array is passed as the initial weights.

Author: lewuathe <[email protected]>

Closes apache#5101 from Lewuathe/SPARK-6421 and squashes the following commits:

7795201 [lewuathe] Fix lint-python errors
21d4fe3 [lewuathe] Fix init logic of weights
Lewuathe authored and mengxr committed Mar 20, 2015
1 parent 11e0259 commit 257cde7
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/pyspark/mllib/regression.py
@@ -163,7 +163,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
     first = data.first()
     if not isinstance(first, LabeledPoint):
         raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
-    initial_weights = initial_weights or [0.0] * len(data.first().features)
+    if initial_weights is None:
+        initial_weights = [0.0] * len(data.first().features)
     weights, intercept = train_func(data, _convert_to_vector(initial_weights))
     return modelClass(weights, intercept)
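
The change above replaces a truthiness-based default (`x or default`) with an explicit `is None` check. The following is a minimal sketch of why that matters, not code from the patch. It assumes no numpy or Spark is installed, so `ArrayLike` is a hypothetical stand-in that mimics how a multi-element numpy array raises `ValueError` when evaluated as a boolean. The helper names `default_with_or` and `default_with_is_none` are also illustrative only.

```python
class ArrayLike(object):
    """Hypothetical stand-in for a multi-element numpy array."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        # Mimics numpy: the truth value of a multi-element array is ambiguous.
        raise ValueError("truth value of a multi-element array is ambiguous")

    __nonzero__ = __bool__  # Python 2 name for the same hook


def default_with_or(initial_weights, num_features):
    # Old pattern: `or` evaluates the truth value of initial_weights first,
    # which raises for array-like inputs before the default is considered.
    return initial_weights or [0.0] * num_features


def default_with_is_none(initial_weights, num_features):
    # Patched pattern: only an explicit None triggers the zero default.
    if initial_weights is None:
        initial_weights = [0.0] * num_features
    return initial_weights
```

With these definitions, `default_with_or(ArrayLike([1.0, 1.0]), 2)` raises `ValueError`, while `default_with_is_none` returns the caller's weights untouched. Both helpers still fall back to zeros when given `None`.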

7 changes: 7 additions & 0 deletions python/pyspark/mllib/tests.py
@@ -323,6 +323,13 @@ def test_regression(self):
         self.assertTrue(gbt_model.predict(features[2]) <= 0)
         self.assertTrue(gbt_model.predict(features[3]) > 0)
 
+        try:
+            LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+            LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+            RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+        except ValueError:
+            self.fail()
+
 
 class StatTests(PySparkTestCase):
     # SPARK-4023
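
The added test uses a try/except block so that a `ValueError` from `train` is reported as a test failure. A rough standalone sketch of that idiom follows. It assumes nothing about Spark: `train_stub` is a hypothetical stand-in for the `*WithSGD.train` calls, the data is a plain list of tuples, and the test class name is made up for illustration.

```python
import unittest


def train_stub(data, initialWeights=None):
    """Hypothetical stand-in for a *WithSGD.train call; no Spark involved."""
    if initialWeights is None:
        # Default to one zero weight per feature, as the patched wrapper does.
        initialWeights = [0.0] * len(data[0])
    return list(initialWeights)


class InitialWeightsTest(unittest.TestCase):
    def test_explicit_weights_are_accepted(self):
        data = [(1.0, 2.0), (3.0, 4.0)]
        try:
            weights = train_stub(data, initialWeights=(1.0, 1.0))
        except ValueError:
            # Turn the exception into an explicit failure with a message.
            self.fail("explicit initial weights should not raise ValueError")
        self.assertEqual(weights, [1.0, 1.0])

    def test_missing_weights_default_to_zeros(self):
        self.assertEqual(train_stub([(1.0, 2.0)]), [0.0, 0.0])
```

An uncaught `ValueError` would also make the test unsuccessful, but unittest would record it as an error. Catching it and calling `self.fail()` records it as a failure, and the optional message can say what was expected.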