Skip to content

Commit

Permalink
added train() to Predictor subclasses which does not take a ParamMap.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent 57d54ab commit 58802e3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi

/**
* Same as [[fit()]], but using strong types.
*
* @param dataset Training data. WARNING: This does not yet handle instance weights.
* NOTE: This does NOT support instance weights.
* @param dataset Training data. Instance weights are ignored.
* @param paramMap Parameters for training.
* These values override any specified in this Estimator's embedded ParamMap.
*/
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LogisticRegressionModel = {
override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LogisticRegressionModel = {
val map = this.paramMap ++ paramMap
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
org.apache.spark.mllib.regression.LabeledPoint(label, features)
Expand All @@ -96,6 +96,13 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
}
lrm
}

/**
* Same as [[fit()]], but using strong types.
* NOTE: This does NOT support instance weights.
* @param dataset Training data. Instance weights are ignored.
*/
def train(dataset: RDD[LabeledPoint]): LogisticRegressionModel = train(dataset, new ParamMap())
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel

/**
* Same as [[fit()]], but using strong types.
*
* @param dataset Training data. WARNING: This does not yet handle instance weights.
* NOTE: This does NOT support instance weights.
* @param dataset Training data. Instance weights are ignored.
* @param paramMap Parameters for training.
* These values override any specified in this Estimator's embedded ParamMap.
*/
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
org.apache.spark.mllib.regression.LabeledPoint(label, features)
}
Expand All @@ -71,6 +71,13 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel
}
lrm
}

/**
* Same as [[fit()]], but using strong types.
* NOTE: This does NOT support instance weights.
* @param dataset Training data. Instance weights are ignored.
*/
def train(dataset: RDD[LabeledPoint]): LinearRegressionModel = train(dataset, new ParamMap())
}

/**
Expand Down

0 comments on commit 58802e3

Please sign in to comment.