[SPARK-22922][ML][PySpark] Pyspark portion of the fit-multiple API #20058

Closed
wants to merge 6 commits
Changes from 4 commits
54 changes: 51 additions & 3 deletions python/pyspark/ml/base.py
@@ -18,13 +18,40 @@
from abc import ABCMeta, abstractmethod

import copy
import threading

from pyspark import since
from pyspark.ml.param import Params
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField, StructType, DoubleType
from pyspark.sql.types import StructField, StructType


class FitMutlipleIterator(object):
Member: typo: Mutliple -> Multiple

Contributor: What about making FitMutlipleIterator an inner class of the default fitMultiple implementation? If we put it outside, I don't think it will have any other usage.

Contributor Author: I'm open to this, but I didn't initially do it this way because I've been bitten by nested classes in Python before. There are subtle issues with them; the one that comes to mind is serialization (which isn't an issue here), but it's not the only one.
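
(For illustration, a minimal sketch of the pickling pitfall mentioned above; makeIterator and InnerIterator are hypothetical names, not from this PR:)

import pickle

def makeIterator():
    class InnerIterator(object):  # class defined inside a function body
        pass
    return InnerIterator()

try:
    pickle.dumps(makeIterator())
except (pickle.PicklingError, AttributeError) as e:
    # pickle looks classes up by their module-level name, so instances of
    # nested classes fail to serialize: PicklingError on Python 2,
    # AttributeError ("Can't pickle local object ...") on Python 3.
    print("cannot pickle a nested-class instance: %s" % e)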

Contributor Author: @jkbradley @WeichenXu123 I made FitMultipleIterator a private class. Is that good enough, or should I make it internal to the fitMultiple method?

"""
Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
iterator.
Member: It'd be nice to document what fitSingleModel should do, plus what the iterator returns.
nit: How about renaming numModel -> numModels?

"""
def __init__(self, fitSingleModel, numModel):
self.fitSingleModel = fitSingleModel
self.numModel = numModel
self.counter = 0
self.lock = threading.Lock()

def __iter__(self):
return self

def __next__(self):
with self.lock:
index = self.counter
if index >= self.numModel:
raise StopIteration("No models remaining.")
self.counter += 1
return index, self.fitSingleModel(index)

def next(self):
"""For python2 compatibility."""
return self.__next__()
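
(A quick sketch, not part of the diff, of how the iterator behaves; the lambda stands in for a real fitSingleModel:)

fitter = FitMutlipleIterator(lambda index: "model-%d" % index, numModel=3)
for index, model in fitter:
    print(index, model)  # (0, 'model-0'), (1, 'model-1'), (2, 'model-2'), then StopIteration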


@inherit_doc
@@ -47,6 +74,24 @@ def _fit(self, dataset):
"""
raise NotImplementedError()

@since("2.3.0")
def fitMultiple(self, dataset, params):
Contributor: In Scala Spark we use the fit function rather than separate functions. Also, the params name is different from the Scala one. Any reason for the difference?

Member (@jkbradley, Dec 29, 2017): Check out the discussion on the JIRA and the linked design doc. Basically, we need the same argument types but different return types from what the current fit() method provides. (It's a somewhat long chain of discussion stemming from adding the "parallelism" Param to meta-algorithms in master.)

Contributor Author: We couldn't use fit because it would have the same signature as the existing fit method but return a different type (Iterator[(Int, Model)] instead of Seq[Model]). I was trying to be consistent with Estimator.fit, whose params argument is already named differently from the Scala equivalent. Happy to change it.

Member (@jkbradley, Dec 29, 2017): That's a good point that we could rename "params" to be clearer in this new API. How about "paramMaps"?

Contributor Author: I made this change.

"""
Fits a model to the input dataset for each param map in params.

:param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
:param params: A list/tuple of param maps.
Member: Let's explicitly check that this is a list or tuple and throw a good error message if not.

Contributor Author: I changed the docstring to Sequence instead of list/tuple; is that OK? Do you want to explicitly restrict the input to be a list or tuple?

Member: Is there another Sequence type this could be other than list or tuple?
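
(A sketch of the explicit check suggested above; the error message wording is illustrative:)

if not isinstance(params, (list, tuple)):
    raise TypeError("params must be a list or tuple of param maps, "
                    "but got %s." % type(params))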

:return: A thread-safe iterable which contains one model for each param map. Each
call to `next(modelIterator)` will return `(index, model)` where the model was fit
using `params[index]`. Param maps may be fit in an order different from their
order in params.

.. note:: Experimental
Member: Let's use .. note:: DeveloperApi too.

"""
def fitSingleModel(index):
return self.fit(dataset, params[index])
Member: Shall we make a copy of the Estimator before defining fitSingleModel, to be extra safe in case some other thread modifies the Params in this Estimator before a call to fit()? You can do self.copy() beforehand to get a copy.
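
(What the suggested change might look like, as a sketch: snapshot the Params with self.copy() so another thread mutating this Estimator cannot affect queued fits:)

estimator = self.copy()  # freeze the current Params before any fits run

def fitSingleModel(index):
    return estimator.fit(dataset, params[index])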

return FitMutlipleIterator(fitSingleModel, len(params))
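
(A sketch of the intended parallel use; est, df, and paramMaps are assumed to be an existing Estimator, DataFrame, and list of param maps, not names from the PR:)

from multiprocessing.pool import ThreadPool

modelIter = est.fitMultiple(df, paramMaps)

def fitNext(_):
    # safe to call from many worker threads: the iterator's lock hands
    # each index out exactly once
    return next(modelIter)

models = [None] * len(paramMaps)
pool = ThreadPool(processes=4)
for index, model in pool.imap_unordered(fitNext, range(len(paramMaps))):
    models[index] = model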

@since("1.3.0")
def fit(self, dataset, params=None):
"""
@@ -61,7 +106,10 @@ def fit(self, dataset, params=None):
if params is None:
params = dict()
if isinstance(params, (list, tuple)):
return [self.fit(dataset, paramMap) for paramMap in params]
models = [None] * len(params)
for index, model in self.fitMultiple(dataset, params):
models[index] = model
return models
elif isinstance(params, dict):
if params:
return self.copy(params)._fit(dataset)
15 changes: 15 additions & 0 deletions python/pyspark/ml/tests.py
@@ -2359,6 +2359,21 @@ def test_unary_transformer_transform(self):
self.assertEqual(res.input + shiftVal, res.output)


class TestFit(unittest.TestCase):
Member: nit: How about EstimatorTest, since this is testing part of the Estimator API?


def testDefaultFitMultiple(self):
N = 4
data = MockDataset()
estimator = MockEstimator()
params = [{estimator.fake: i} for i in range(N)]
modelIter = estimator.fitMultiple(data, params)
indexList = []
for index, model in modelIter:
self.assertEqual(model.getFake(), index)
indexList.append(index)
self.assertEqual(sorted(indexList), list(range(N)))


if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
33 changes: 18 additions & 15 deletions python/pyspark/ml/tuning.py
@@ -31,6 +31,17 @@
'TrainValidationSplitModel']


def parallelFitTasks(est, train, eva, validation, epm):
Member: How about a brief doc string?
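
A docstring along these lines (a sketch in response to the comment above, not part of this commit) would cover it:

"""
Creates a list of callables which can be called from different threads to fit and
evaluate an estimator in parallel.

:param est: Estimator, the estimator to be fit.
:param train: DataFrame, training data set, used for fitting.
:param eva: Evaluator, used to compute `metric`.
:param validation: DataFrame, validation data set, used for evaluation.
:param epm: Sequence of param maps to be used during fitting and evaluation.
:return: a list of callables; calling one fits a model and returns an `(index, metric)` pair.
"""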

modelIter = est.fitMultiple(train, epm)

def singleTask():
index, model = next(modelIter)
metric = eva.evaluate(model.transform(validation, epm[index]))
return index, metric

return [singleTask] * len(epm)


class ParamGridBuilder(object):
r"""
Builder for a param grid used in grid search-based model selection.
@@ -266,15 +277,9 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()

def singleTrain(paramMap):
model = est.fit(train, paramMap)
# TODO: duplicate evaluator to take extra params from input
metric = eva.evaluate(model.transform(validation, paramMap))
return metric

currentFoldMetrics = pool.map(singleTrain, epm)
for j in range(numModels):
metrics[j] += (currentFoldMetrics[j] / nFolds)
tasks = parallelFitTasks(est, train, eva, validation, epm)
for j, metric in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] += (metric / nFolds)
validation.unpersist()
train.unpersist()

@@ -523,13 +528,11 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()

def singleTrain(paramMap):
model = est.fit(train, paramMap)
metric = eva.evaluate(model.transform(validation, paramMap))
return metric

tasks = parallelFitTasks(est, train, eva, validation, epm)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
metrics = pool.map(singleTrain, epm)
metrics = [None] * numModels
for j, metric in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] = metric
train.unpersist()
validation.unpersist()
