Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-22922][ML][PySpark] Pyspark portion of the fit-multiple API #20058

Closed
wants to merge 6 commits into from

Conversation

MrBago
Copy link
Contributor

@MrBago MrBago commented Dec 22, 2017

What changes were proposed in this pull request?

Adding fitMultiple API to Estimator with default implementation. Also update have ml.tuning meta-estimators use this API.

How was this patch tested?

Unit tests.

@SparkQA
Copy link

SparkQA commented Dec 22, 2017

Test build #85318 has finished for PR 20058 at commit fdef9d5.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Dec 22, 2017

Test build #85319 has finished for PR 20058 at commit 49c8332.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Dec 27, 2017

Test build #85443 has finished for PR 20058 at commit d73af1f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Copy link
Member

reviewing now

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just minor comments. Thanks!

@@ -2359,6 +2359,21 @@ def test_unary_transformer_transform(self):
self.assertEqual(res.input + shiftVal, res.output)


class TestFit(unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: How about EstimatorTest since this is testing part of the Estimator API?

from pyspark.sql.types import StructField, StructType


class FitMutlipleIterator(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Mutliple -> Multiple

class FitMutlipleIterator(object):
"""
Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
iterator.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to document what fitSingleModel should do, plus what the iterator returns.

nit: How about renaming numModel -> numModels ?

using `params[index]`. Params maps may be fit in an order different than their
order in params.

.. note:: Experimental
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use .. note:: DeveloperApi too.

.. note:: Experimental
"""
def fitSingleModel(index):
return self.fit(dataset, params[index])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we make a copy of the Estimator before defining fitSingleModel to be extra safe (in case some other thread modifies the Params in this Estimator before a call to fit()? You can do self.copy() beforehand to get a copy.

@@ -31,6 +31,17 @@
'TrainValidationSplitModel']


def parallelFitTasks(est, train, eva, validation, epm):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about a brief doc string?

Fits a model to the input dataset for each param map in params.

:param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
:param params: A list/tuple of param maps.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's explicitly check that this is a list or tuple and throw a good error message if not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the docstring to Sequence instead of list/tuple, is that ok? Do you want to explicitly restrict the input to be a list or tuple?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there another Sequence type this could be other than list or tuple?

from pyspark.sql.types import StructField, StructType


class FitMutlipleIterator(object):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about change this FitMutlipleIterator class to be an inner class in default implementation method fitMultiple ? I think put it outside will be no other usage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to this, but I didn't initially do it this way because I've been bit by nested classes in python before. There are subtle issues with nested classes in python. The one that comes to mind is serialization (which isn't an issue here) but that's not the only one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jkbradley @WeichenXu123 I made FitMultipleIterator a private class, is that good enough or should I make it internal to the fitMultiple method?

@SparkQA
Copy link

SparkQA commented Dec 28, 2017

Test build #85481 has finished for PR 20058 at commit 209278d.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class FitMultipleIterator(object):

@SparkQA
Copy link

SparkQA commented Dec 28, 2017

Test build #85483 has finished for PR 20058 at commit fe3d6bd.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class FitMultipleIterator(object):

Copy link
Contributor

@holdenk holdenk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some initial quick questions, but this looks interesting :)

@@ -47,6 +86,28 @@ def _fit(self, dataset):
"""
raise NotImplementedError()

@since("2.3.0")
def fitMultiple(self, dataset, params):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So in Scala Spark we use the fit function rather than separate functions. Also the params name is different than the Scala one. Any reason for the difference?

Copy link
Member

@jkbradley jkbradley Dec 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check out the discussion on the JIRA and the linked design doc. Basically, we need the same argument types but different return types from what the current fit() method provides. (It's a somewhat long chain of discussion stemming from adding the "parallelism" Param to meta-algorithms in master.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We couldn't use fit because it's going to have the same signature as the existing fit method but return a different type, (Iterator[(Int, Model)] instead of Seq[Model]). I was trying to be consistent with Estimator.fit which uses the name params which is different than the name of the same argument in Scala :/. Happy to change it.

Copy link
Member

@jkbradley jkbradley Dec 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point that we could rename "params" to be clearer in this new API. How about "paramMaps"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made this change.

def fitSingleModel(index):
return estimator.fit(dataset, params[index])

return FitMultipleIterator(fitSingleModel, len(params))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So whats the benefit of FitMultipleIterator v.s. using imap_unordered?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is you should be able to do something like this:

pool = ...
modelIter = estimator.fitMultiple(params)
rng = range(len(params))
for index, model in pool.imap_unordered(lambda _: next(modelIter), rng):
    pass

That's pretty much how I've set up corss validator to use it, https://github.com/apache/spark/pull/20058/files/fe3d6bddc3e9e50febf706d7f22007b1e0d58de3#diff-cbc8c36bfdd245e4e4d5bd27f9b95359R292

The reason for set it up this way is so that, when appropriate, Estimators can implement their own optimized fitMultiple methods that just need to return an "iterator", A class with __iter__ and __next__. For examples models that use maxIter and maxDepth params.

@MrBago MrBago changed the title [SPARK-22126][ML][PySpark] Pyspark portion of the fit-multiple API [SPARK-22922][ML][PySpark] Pyspark portion of the fit-multiple API Dec 29, 2017
@SparkQA
Copy link

SparkQA commented Dec 29, 2017

Test build #85529 has finished for PR 20058 at commit c44db97.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class _FitMultipleIterator(object):

@jkbradley
Copy link
Member

LGTM
I'll merge this b/c of the time pressure for 2.3, but @holdenk please follow up if you have more comments on this.
Thanks!

@asfgit asfgit closed this in 30fcdc0 Dec 30, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants