Skip to content

Commit

Permalink
Fix init logic of weights
Browse files: browse the repository at this point in the history
  • Loading branch information
Lewuathe committed Mar 20, 2015
1 parent 0745a30 commit 21d4fe3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/pyspark/mllib/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
first = data.first()
if not isinstance(first, LabeledPoint):
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
initial_weights = initial_weights or [0.0] * len(data.first().features)
if initial_weights == None:
initial_weights = [0.0] * len(data.first().features)
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,12 @@ def test_regression(self):
self.assertTrue(gbt_model.predict(features[2]) <= 0)
self.assertTrue(gbt_model.predict(features[3]) > 0)

try:
LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
except ValueError:
self.fail()

class StatTests(PySparkTestCase):
# SPARK-4023
Expand Down

0 comments on commit 21d4fe3

Please sign in to comment.