[SPARK-22922][ML][PySpark] Pyspark portion of the fit-multiple API #20058
Conversation
Test build #85318 has finished for PR 20058 at commit
Test build #85319 has finished for PR 20058 at commit
Test build #85443 has finished for PR 20058 at commit
reviewing now
Just minor comments. Thanks!
python/pyspark/ml/tests.py
Outdated
@@ -2359,6 +2359,21 @@ def test_unary_transformer_transform(self):
        self.assertEqual(res.input + shiftVal, res.output)


class TestFit(unittest.TestCase):
nit: How about EstimatorTest since this is testing part of the Estimator API?
python/pyspark/ml/base.py
Outdated
from pyspark.sql.types import StructField, StructType


class FitMutlipleIterator(object):
typo: Mutliple -> Multiple
python/pyspark/ml/base.py
Outdated
class FitMutlipleIterator(object):
    """
    Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
    iterator.
It'd be nice to document what fitSingleModel should do, plus what the iterator returns.
nit: How about renaming numModel -> numModels ?
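For readers following the thread, a minimal sketch of what such a thread-safe iterator could look like. The class and argument names come from the diff under review; the body is an illustrative reconstruction, not the merged Spark code:

```python
import threading

class FitMultipleIterator(object):
    """
    Thread-safe iterator over fitted models.

    fitSingleModel: a callable that takes an index into the param-map
        sequence and returns the model fit with that param map.
    numModels: total number of models to fit.

    next() returns (index, model) tuples; with multiple threads the
    indices may be returned out of order.
    """
    def __init__(self, fitSingleModel, numModels):
        self.fitSingleModel = fitSingleModel
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only the index bookkeeping is locked; the (possibly slow)
        # model fit itself runs outside the lock.
        with self.lock:
            index = self.counter
            if index >= self.numModels:
                raise StopIteration("Fitted all models.")
            self.counter += 1
        return index, self.fitSingleModel(index)

    next = __next__  # Python 2 compatibility
```

Because only the counter update is serialized, several threads can call `next()` concurrently and fit different models at the same time.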
python/pyspark/ml/base.py
Outdated
    using `params[index]`. Params maps may be fit in an order different than their
    order in params.

    .. note:: Experimental
Let's use .. note:: DeveloperApi too.
python/pyspark/ml/base.py
Outdated
    .. note:: Experimental
    """
    def fitSingleModel(index):
        return self.fit(dataset, params[index])
Shall we make a copy of the Estimator before defining fitSingleModel, to be extra safe in case some other thread modifies the Params in this Estimator before a call to fit()? You can do self.copy() beforehand to get a copy.
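A toy illustration of why copying first matters. `ToyEstimator` and its parameters are hypothetical stand-ins for this sketch, not Spark classes; the point is that the `self.copy()` snapshot happens eagerly, before any fitting, so later mutation of the original estimator cannot leak into the fits:

```python
import copy

class ToyEstimator(object):
    """Hypothetical stand-in for an ML Estimator; not a Spark class."""
    def __init__(self, shift=0):
        self.shift = shift

    def copy(self):
        return copy.deepcopy(self)

    def fit(self, dataset, shift_override=None):
        # The "model" here is just the shifted data.
        shift = self.shift if shift_override is None else shift_override
        return [x + shift for x in dataset]

    def fitMultiple(self, dataset, paramMaps):
        # Snapshot the estimator before any fitting begins, so a later
        # mutation of self by another thread cannot leak into the fits.
        estimator = self.copy()

        def fitSingleModel(index):
            return estimator.fit(dataset, paramMaps[index])

        return ((i, fitSingleModel(i)) for i in range(len(paramMaps)))

est = ToyEstimator(shift=1)
models = est.fitMultiple([1, 2], [None, 5])
est.shift = 100  # mutation after fitMultiple; must not affect the results
results = dict(models)
```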
python/pyspark/ml/tuning.py
Outdated
@@ -31,6 +31,17 @@
    'TrainValidationSplitModel']


def parallelFitTasks(est, train, eva, validation, epm):
How about a brief doc string?
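A sketch of what such a docstring could look like, attached to an illustrative implementation. The parameter names come from the diff; the function body is an assumption about how the tasks would be built on top of fitMultiple, not the merged code:

```python
def parallelFitTasks(est, train, eva, validation, epm):
    """
    Creates a list of callables which can be called from different threads
    to fit and evaluate an estimator in parallel.

    :param est: Estimator, the estimator to be fit.
    :param train: DataFrame, training data set used for fitting.
    :param eva: Evaluator, used to compute the metric of each fitted model.
    :param validation: DataFrame, validation data set used for evaluation.
    :param epm: Sequence of ParamMaps to fit with.
    :return: list of callables, each returning an (index, metric) pair.
    """
    modelIter = est.fitMultiple(train, epm)

    def singleTask():
        # Each task pulls the next (index, model) pair from the shared,
        # thread-safe iterator, then evaluates that model.
        index, model = next(modelIter)
        metric = eva.evaluate(model.transform(validation, epm[index]))
        return index, metric

    return [singleTask] * len(epm)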
python/pyspark/ml/base.py
Outdated
    Fits a model to the input dataset for each param map in params.

    :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
    :param params: A list/tuple of param maps.
Let's explicitly check that this is a list or tuple and throw a good error message if not.
I changed the docstring to Sequence instead of list/tuple, is that ok? Do you want to explicitly restrict the input to be a list or tuple?
Is there another Sequence type this could be other than list or tuple?
python/pyspark/ml/base.py
Outdated
from pyspark.sql.types import StructField, StructType


class FitMutlipleIterator(object):
What about making this FitMutlipleIterator class an inner class of the default fitMultiple implementation? I don't think it will have any other usage outside of that method.
I'm open to this, but I didn't initially do it this way because I've been bitten by nested classes in Python before. There are subtle issues with nested classes in Python; the one that comes to mind is serialization (which isn't an issue here), but that's not the only one.
@jkbradley @WeichenXu123 I made FitMultipleIterator a private class, is that good enough or should I make it internal to the fitMultiple method?
Test build #85481 has finished for PR 20058 at commit
Force-pushed 209278d to fe3d6bd
Test build #85483 has finished for PR 20058 at commit
I have some initial quick questions, but this looks interesting :)
python/pyspark/ml/base.py
Outdated
@@ -47,6 +86,28 @@ def _fit(self, dataset):
        """
        raise NotImplementedError()

    @since("2.3.0")
    def fitMultiple(self, dataset, params):
So in Scala Spark we use the fit function rather than separate functions. Also the params name is different than the Scala one. Any reason for the difference?
Check out the discussion on the JIRA and the linked design doc. Basically, we need the same argument types but different return types from what the current fit() method provides. (It's a somewhat long chain of discussion stemming from adding the "parallelism" Param to meta-algorithms in master.)
We couldn't use fit because it would have the same signature as the existing fit method but return a different type (Iterator[(Int, Model)] instead of Seq[Model]). I was trying to be consistent with Estimator.fit, which uses the name params, which is different from the name of the same argument in Scala :/. Happy to change it.
That's a good point that we could rename "params" to be clearer in this new API. How about "paramMaps"?
I made this change.
python/pyspark/ml/base.py
Outdated
    def fitSingleModel(index):
        return estimator.fit(dataset, params[index])

    return FitMultipleIterator(fitSingleModel, len(params))
So what's the benefit of FitMultipleIterator vs. using imap_unordered?
The idea is you should be able to do something like this:

    pool = ...
    modelIter = estimator.fitMultiple(params)
    rng = range(len(params))
    for index, model in pool.imap_unordered(lambda _: next(modelIter), rng):
        pass

That's pretty much how I've set up cross validator to use it: https://github.com/apache/spark/pull/20058/files/fe3d6bddc3e9e50febf706d7f22007b1e0d58de3#diff-cbc8c36bfdd245e4e4d5bd27f9b95359R292
The reason for setting it up this way is so that, when appropriate, Estimators can implement their own optimized fitMultiple methods that just need to return an "iterator" (a class with __iter__ and __next__). For example, models that use maxIter and maxDepth params.
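A self-contained, runnable version of that pattern, using a ThreadPool and a toy stand-in for the iterator fitMultiple would return. ModelIter is hypothetical, and real fitted models are replaced by placeholder values:

```python
import threading
from multiprocessing.pool import ThreadPool

class ModelIter(object):
    """Toy thread-safe stand-in for the iterator fitMultiple returns."""
    def __init__(self, numModels):
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __next__(self):
        with self.lock:
            if self.counter >= self.numModels:
                raise StopIteration
            index = self.counter
            self.counter += 1
        # A real iterator would fit and return a model; use a placeholder.
        return index, index * 10

    next = __next__  # Python 2 compatibility

modelIter = ModelIter(4)
pool = ThreadPool(2)
# Each worker pulls the next (index, model) pair from the shared iterator.
results = sorted(pool.imap_unordered(lambda _: next(modelIter), range(4)))
pool.close()
pool.join()
```

A ThreadPool (rather than a process pool) is what makes the shared-iterator trick work: the workers see the same `modelIter` object, so nothing needs to be pickled.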
Test build #85529 has finished for PR 20058 at commit
LGTM
What changes were proposed in this pull request?
Adds a fitMultiple API to Estimator with a default implementation, and updates the ml.tuning meta-estimators to use this API.

How was this patch tested?

Unit tests.