-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-22922][ML][PySpark] Pyspark portion of the fit-multiple API #20058
Changes from 4 commits
15e9c33
fdef9d5
49c8332
d73af1f
fe3d6bd
c44db97
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,13 +18,40 @@ | |
from abc import ABCMeta, abstractmethod | ||
|
||
import copy | ||
import threading | ||
|
||
from pyspark import since | ||
from pyspark.ml.param import Params | ||
from pyspark.ml.param.shared import * | ||
from pyspark.ml.common import inherit_doc | ||
from pyspark.sql.functions import udf | ||
from pyspark.sql.types import StructField, StructType, DoubleType | ||
from pyspark.sql.types import StructField, StructType | ||
|
||
|
||
class FitMutlipleIterator(object): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about change this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm open to this, but I didn't initially do it this way because I've been bit by nested classes in python before. There are subtle issues with nested classes in python. The one that comes to mind is serialization (which isn't an issue here) but that's not the only one. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jkbradley @WeichenXu123 I made |
||
""" | ||
Used by default implementation of Estimator.fitMultiple to produce models in a thread safe | ||
iterator. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It'd be nice to document what fitSingleModel should do, plus what the iterator returns. nit: How about renaming numModel -> numModels ? |
||
""" | ||
def __init__(self, fitSingleModel, numModel): | ||
self.fitSingleModel = fitSingleModel | ||
self.numModel = numModel | ||
self.counter = 0 | ||
self.lock = threading.Lock() | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
def __next__(self): | ||
with self.lock: | ||
index = self.counter | ||
if index >= self.numModel: | ||
raise StopIteration("No models remaining.") | ||
self.counter += 1 | ||
return index, self.fitSingleModel(index) | ||
|
||
def next(self): | ||
"""For python2 compatibility.""" | ||
return self.__next__() | ||
|
||
|
||
@inherit_doc | ||
|
@@ -47,6 +74,24 @@ def _fit(self, dataset): | |
""" | ||
raise NotImplementedError() | ||
|
||
@since("2.3.0") | ||
def fitMultiple(self, dataset, params): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So in Scala Spark we use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check out the discussion on the JIRA and the linked design doc. Basically, we need the same argument types but different return types from what the current fit() method provides. (It's a somewhat long chain of discussion stemming from adding the "parallelism" Param to meta-algorithms in master.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We couldn't use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point that we could rename "params" to be clearer in this new API. How about "paramMaps"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made this change. |
||
""" | ||
Fits a model to the input dataset for each param map in params. | ||
|
||
:param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`. | ||
:param params: A list/tuple of param maps. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's explicitly check that this is a list or tuple and throw a good error message if not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed the docstring to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there another Sequence type this could be other than list or tuple? |
||
:return: A thread safe iterable which contains one model for each param map. Each | ||
call to `next(modelIterator)` will return `(index, model)` where model was fit | ||
using `params[index]`. Params maps may be fit in an order different than their | ||
order in params. | ||
|
||
.. note:: Experimental | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use |
||
""" | ||
def fitSingleModel(index): | ||
return self.fit(dataset, params[index]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we make a copy of the Estimator before defining fitSingleModel to be extra safe (in case some other thread modifies the Params in this Estimator before a call to fit()? You can do |
||
return FitMutlipleIterator(fitSingleModel, len(params)) | ||
|
||
@since("1.3.0") | ||
def fit(self, dataset, params=None): | ||
""" | ||
|
@@ -61,7 +106,10 @@ def fit(self, dataset, params=None): | |
if params is None: | ||
params = dict() | ||
if isinstance(params, (list, tuple)): | ||
return [self.fit(dataset, paramMap) for paramMap in params] | ||
models = [None] * len(params) | ||
for index, model in self.fitMultiple(dataset, params): | ||
models[index] = model | ||
return models | ||
elif isinstance(params, dict): | ||
if params: | ||
return self.copy(params)._fit(dataset) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2359,6 +2359,21 @@ def test_unary_transformer_transform(self): | |
self.assertEqual(res.input + shiftVal, res.output) | ||
|
||
|
||
class TestFit(unittest.TestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: How about EstimatorTest since this is testing part of the Estimator API? |
||
|
||
def testDefaultFitMultiple(self): | ||
N = 4 | ||
data = MockDataset() | ||
estimator = MockEstimator() | ||
params = [{estimator.fake: i} for i in range(N)] | ||
modelIter = estimator.fitMultiple(data, params) | ||
indexList = [] | ||
for index, model in modelIter: | ||
self.assertEqual(model.getFake(), index) | ||
indexList.append(index) | ||
self.assertEqual(sorted(indexList), list(range(N))) | ||
|
||
|
||
if __name__ == "__main__": | ||
from pyspark.ml.tests import * | ||
if xmlrunner: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,17 @@ | |
'TrainValidationSplitModel'] | ||
|
||
|
||
def parallelFitTasks(est, train, eva, validation, epm): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about a brief doc string? |
||
modelIter = est.fitMultiple(train, epm) | ||
|
||
def singleTask(): | ||
index, model = next(modelIter) | ||
metric = eva.evaluate(model.transform(validation, epm[index])) | ||
return index, metric | ||
|
||
return [singleTask] * len(epm) | ||
|
||
|
||
class ParamGridBuilder(object): | ||
r""" | ||
Builder for a param grid used in grid search-based model selection. | ||
|
@@ -266,15 +277,9 @@ def _fit(self, dataset): | |
validation = df.filter(condition).cache() | ||
train = df.filter(~condition).cache() | ||
|
||
def singleTrain(paramMap): | ||
model = est.fit(train, paramMap) | ||
# TODO: duplicate evaluator to take extra params from input | ||
metric = eva.evaluate(model.transform(validation, paramMap)) | ||
return metric | ||
|
||
currentFoldMetrics = pool.map(singleTrain, epm) | ||
for j in range(numModels): | ||
metrics[j] += (currentFoldMetrics[j] / nFolds) | ||
tasks = parallelFitTasks(est, train, eva, validation, epm) | ||
for j, metric in pool.imap_unordered(lambda f: f(), tasks): | ||
metrics[j] += (metric / nFolds) | ||
validation.unpersist() | ||
train.unpersist() | ||
|
||
|
@@ -523,13 +528,11 @@ def _fit(self, dataset): | |
validation = df.filter(condition).cache() | ||
train = df.filter(~condition).cache() | ||
|
||
def singleTrain(paramMap): | ||
model = est.fit(train, paramMap) | ||
metric = eva.evaluate(model.transform(validation, paramMap)) | ||
return metric | ||
|
||
tasks = parallelFitTasks(est, train, eva, validation, epm) | ||
pool = ThreadPool(processes=min(self.getParallelism(), numModels)) | ||
metrics = pool.map(singleTrain, epm) | ||
metrics = [None] * numModels | ||
for j, metric in pool.imap_unordered(lambda f: f(), tasks): | ||
metrics[j] = metric | ||
train.unpersist() | ||
validation.unpersist() | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: Mutliple -> Multiple