[SPARK-22922][ML][PySpark] Pyspark portion of the fit-multiple API #20058

Closed · wants to merge 6 commits · Changes from all commits
69 changes: 66 additions & 3 deletions python/pyspark/ml/base.py
@@ -18,13 +18,52 @@
from abc import ABCMeta, abstractmethod

import copy
import threading

from pyspark import since
from pyspark.ml.param import Params
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField, StructType, DoubleType
from pyspark.sql.types import StructField, StructType


class _FitMultipleIterator(object):
"""
Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
iterator. This class handles the simple case of fitMultiple where each param map should be
fit independently.

:param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset.
`fitSingleModel` may be called up to `numModels` times, with a unique index each time.
Each call to `fitSingleModel` with an index should return the Model associated with
that index.
:param numModel: Number of models this iterator should produce.

See Estimator.fitMultiple for more info.
"""
    def __init__(self, fitSingleModel, numModels):
        self.fitSingleModel = fitSingleModel
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

def __iter__(self):
return self

    def __next__(self):
        with self.lock:
            index = self.counter
            if index >= self.numModels:
                raise StopIteration("No models remaining.")
            self.counter += 1
        # Fit outside the lock so multiple models can be fit concurrently.
        return index, self.fitSingleModel(index)

def next(self):
"""For python2 compatibility."""
return self.__next__()
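
As an illustration, here is a minimal sketch of consuming this iterator from several worker
threads — the lock in __next__ lets workers claim indices without duplicating work, while the
actual fitting happens in parallel. The toy fitSingleModel below is hypothetical, not part of
this patch:

    from multiprocessing.pool import ThreadPool

    def fitSingleModel(index):
        return "model-%d" % index  # stand-in for a real Estimator.fit call

    modelIter = _FitMultipleIterator(fitSingleModel, numModels=4)
    pool = ThreadPool(processes=2)
    # Each worker pulls the next (index, model) pair from the shared iterator.
    results = pool.map(lambda _: next(modelIter), range(4))
    print(sorted(results))  # [(0, 'model-0'), (1, 'model-1'), ...]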


@inherit_doc
@@ -47,6 +86,27 @@ def _fit(self, dataset):
"""
raise NotImplementedError()

@since("2.3.0")
def fitMultiple(self, dataset, paramMaps):
"""
Fits a model to the input dataset for each param map in `paramMaps`.

:param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
:param paramMaps: A Sequence of param maps.
        :return: A thread-safe iterable which contains one model for each param map. Each
            call to `next(modelIterator)` will return `(index, model)` where model was fit
            using `paramMaps[index]`. `index` values may not be sequential.

.. note:: DeveloperApi
.. note:: Experimental
"""
estimator = self.copy()

def fitSingleModel(index):
return estimator.fit(dataset, paramMaps[index])

return _FitMultipleIterator(fitSingleModel, len(paramMaps))
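
A usage sketch for fitMultiple (the estimator, param values, and train_df DataFrame here are
illustrative assumptions, not part of this patch):

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression()
    paramMaps = [{lr.maxIter: n} for n in [10, 20, 30]]
    models = [None] * len(paramMaps)
    # Pairs may arrive in any order; the index ties each model back to its param map.
    for index, model in lr.fitMultiple(train_df, paramMaps):
        models[index] = model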

@since("1.3.0")
def fit(self, dataset, params=None):
"""
@@ -61,7 +121,10 @@ def fit(self, dataset, params=None):
if params is None:
params = dict()
if isinstance(params, (list, tuple)):
return [self.fit(dataset, paramMap) for paramMap in params]
models = [None] * len(params)
for index, model in self.fitMultiple(dataset, params):
models[index] = model
return models
elif isinstance(params, dict):
if params:
return self.copy(params)._fit(dataset)
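With this change, fit called with a list of param maps still returns models in input order:
the index from fitMultiple slots each model back into place even when models complete out of
order. A sketch (lr and train_df as in the hypothetical example above):

    paramMaps = [{lr.regParam: r} for r in [0.0, 0.1, 0.5]]
    models = lr.fit(train_df, paramMaps)
    # models[i] was fit with paramMaps[i], exactly as before this change.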
15 changes: 15 additions & 0 deletions python/pyspark/ml/tests.py
@@ -2359,6 +2359,21 @@ def test_unary_transformer_transform(self):
self.assertEqual(res.input + shiftVal, res.output)


class EstimatorTest(unittest.TestCase):

def testDefaultFitMultiple(self):
N = 4
data = MockDataset()
estimator = MockEstimator()
params = [{estimator.fake: i} for i in range(N)]
modelIter = estimator.fitMultiple(data, params)
indexList = []
for index, model in modelIter:
self.assertEqual(model.getFake(), index)
indexList.append(index)
self.assertEqual(sorted(indexList), list(range(N)))


if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
44 changes: 29 additions & 15 deletions python/pyspark/ml/tuning.py
@@ -31,6 +31,28 @@
'TrainValidationSplitModel']


def _parallelFitTasks(est, train, eva, validation, epm):
"""
Creates a list of callables which can be called from different threads to fit and evaluate
an estimator in parallel. Each callable returns an `(index, metric)` pair.

:param est: Estimator, the estimator to be fit.
:param train: DataFrame, training data set, used for fitting.
:param eva: Evaluator, used to compute `metric`
:param validation: DataFrame, validation data set, used for evaluation.
:param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
:return: (int, float), an index into `epm` and the associated metric value.
"""
modelIter = est.fitMultiple(train, epm)

def singleTask():
index, model = next(modelIter)
metric = eva.evaluate(model.transform(validation, epm[index]))
return index, metric

return [singleTask] * len(epm)
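
For reference, a brief sketch of how these callables are driven — this mirrors the
pool.imap_unordered pattern used by CrossValidator and TrainValidationSplit below (the pool
size here is an arbitrary assumption):

    from multiprocessing.pool import ThreadPool

    # est, train, eva, validation, epm as documented above
    tasks = _parallelFitTasks(est, train, eva, validation, epm)
    pool = ThreadPool(processes=4)
    metrics = [None] * len(epm)
    # Tasks may finish out of order; the returned index keeps metrics aligned with epm.
    for index, metric in pool.imap_unordered(lambda f: f(), tasks):
        metrics[index] = metric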


class ParamGridBuilder(object):
r"""
Builder for a param grid used in grid search-based model selection.
@@ -266,15 +288,9 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()

def singleTrain(paramMap):
model = est.fit(train, paramMap)
# TODO: duplicate evaluator to take extra params from input
metric = eva.evaluate(model.transform(validation, paramMap))
return metric

currentFoldMetrics = pool.map(singleTrain, epm)
for j in range(numModels):
metrics[j] += (currentFoldMetrics[j] / nFolds)
tasks = _parallelFitTasks(est, train, eva, validation, epm)
for j, metric in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] += (metric / nFolds)
validation.unpersist()
train.unpersist()

@@ -523,13 +539,11 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()

def singleTrain(paramMap):
model = est.fit(train, paramMap)
metric = eva.evaluate(model.transform(validation, paramMap))
return metric

tasks = _parallelFitTasks(est, train, eva, validation, epm)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
metrics = pool.map(singleTrain, epm)
metrics = [None] * numModels
for j, metric in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] = metric
train.unpersist()
validation.unpersist()
