Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Logistic Regression] Support fit on two classes #343

Merged
merged 7 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ then
fi
echo "use --runslow to run all tests"
pytest benchmark/test_gen_data.py
pytest -ra "$@" --durations=10 tests
# pytest -ra --runslow --durations=10 tests
pytest -ra "$@" --durations=10 --ignore=tests/test_logistic_regression.py tests
# pytest -ra --runslow --durations=10 --ignore=tests/test_logistic_regression.py tests
283 changes: 281 additions & 2 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast

from pyspark.ml.evaluation import Evaluator, MulticlassClassificationEvaluator

Expand All @@ -32,7 +32,8 @@
RandomForestClassificationSummary,
_RandomForestClassifierParams,
)
from pyspark.ml.linalg import Vector
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vector, Vectors, VectorUDT
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col
Expand Down Expand Up @@ -509,3 +510,281 @@ def _transformEvaluate(
)
scores.append(metrics.evaluate(evaluator))
return scores


from typing import Callable

from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol
from pyspark.sql.types import (
ArrayType,
DoubleType,
FloatType,
IntegerType,
StringType,
StructField,
StructType,
)

from .core import (
FitInputType,
_CumlEstimatorSupervised,
_CumlModelWithPredictionCol,
param_alias,
)
from .params import HasFeaturesCols, _CumlClass, _CumlParams
from .utils import PartitionDescriptor, _ArrayOrder, _concat_and_free


class LogisticRegressionClass(_CumlClass):
    """Spark-to-cuML metadata for LogisticRegression.

    All three mappings are intentionally empty for now: no Spark params are
    forwarded to cuML yet (pending cuML-side init-param support).
    """

    @classmethod
    def _param_mapping(cls) -> Dict[str, Optional[str]]:
        # No Spark param name maps to a cuML param name yet.
        return {}

    @classmethod
    def _param_value_mapping(
        cls,
    ) -> Dict[str, Callable[[str], Union[None, str, float, int]]]:
        # No param value conversions are required yet.
        return {}

    def _get_cuml_params_default(self) -> Dict[str, Any]:
        # No cuML defaults are surfaced yet.
        return {}


class _LogisticRegressionCumlParams(
    _CumlParams, HasFeaturesCol, HasLabelCol, HasFeaturesCols, HasPredictionCol
):
    """Shared Spark param plumbing for the LogisticRegression estimator/model."""

    def getFeaturesCol(self) -> Union[str, List[str]]:  # type:ignore
        """
        Gets the value of :py:attr:`featuresCol` or :py:attr:`featuresCols`
        """
        # The multi-column variant takes precedence when both are defined.
        if self.isDefined(self.featuresCols):
            return self.getFeaturesCols()
        if self.isDefined(self.featuresCol):
            return self.getOrDefault("featuresCol")
        raise RuntimeError("featuresCol is not set")

    def setFeaturesCol(
        self: "_LogisticRegressionCumlParams", value: Union[str, List[str]]
    ) -> "_LogisticRegressionCumlParams":
        """
        Sets the value of :py:attr:`featuresCol` or :py:attr:`featuresCols`.
        """
        # A single string targets featuresCol; a list targets featuresCols.
        target = "featuresCol" if isinstance(value, str) else "featuresCols"
        self.set_params(**{target: value})
        return self

    def setFeaturesCols(
        self: "_LogisticRegressionCumlParams", value: List[str]
    ) -> "_LogisticRegressionCumlParams":
        """
        Sets the value of :py:attr:`featuresCols`.
        """
        return self.set_params(featuresCols=value)

    def setLabelCol(
        self: "_LogisticRegressionCumlParams", value: str
    ) -> "_LogisticRegressionCumlParams":
        """
        Sets the value of :py:attr:`labelCol`.
        """
        return self.set_params(labelCol=value)

    def setPredictionCol(
        self: "_LogisticRegressionCumlParams", value: str
    ) -> "_LogisticRegressionCumlParams":
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self.set_params(predictionCol=value)


class LogisticRegression(
    LogisticRegressionClass,
    _CumlEstimatorSupervised,
    _LogisticRegressionCumlParams,
):
    """
    Examples
    --------
    >>> from spark_rapids_ml.classification import LogisticRegression
    >>> data = [
    ...     ([1.0, 2.0], 1.0),
    ...     ([1.0, 3.0], 1.0),
    ...     ([2.0, 1.0], 0.0),
    ...     ([3.0, 1.0], 0.0),
    ... ]
    >>> schema = "features array<float>, label float"
    >>> df = spark.createDataFrame(data, schema=schema)
    >>> df.show()
    +----------+-----+
    |  features|label|
    +----------+-----+
    |[1.0, 2.0]|  1.0|
    |[1.0, 3.0]|  1.0|
    |[2.0, 1.0]|  0.0|
    |[3.0, 1.0]|  0.0|
    +----------+-----+

    >>> lr_estimator = LogisticRegression()
    >>> lr_estimator.setFeaturesCol("features")
    LogisticRegression_a757215437b0
    >>> lr_estimator.setLabelCol("label")
    LogisticRegression_a757215437b0
    >>> lr_model = lr_estimator.fit(df)
    >>> lr_model.coefficients
    DenseVector([-0.7148, 0.7148])
    >>> lr_model.intercept
    -8.543887375367376e-09
    """

    def __init__(self, *, num_workers: Optional[int] = None):
        # NOTE(review): num_workers is accepted but not explicitly forwarded
        # here; presumably the base class picks it up via keyword capture --
        # confirm, otherwise the argument is silently dropped.
        super().__init__()

    def _fit_array_order(self) -> _ArrayOrder:
        # Feature arrays are concatenated in C (row-major) order before being
        # handed to cuML.
        return "C"

    def _get_cuml_fit_func(
        self,
        dataset: DataFrame,
        extra_params: Optional[List[Dict[str, Any]]] = None,
    ) -> Callable[[FitInputType, Dict[str, Any]], Dict[str, Any],]:
        """Return the per-worker closure that fits the cuML MG model.

        The returned function receives the worker's partition data plus the
        shared cuML params dict and returns the fitted attributes as plain
        Python types (one-element lists, one per worker row).
        """
        array_order = self._fit_array_order()

        def _logistic_regression_fit(
            dfs: FitInputType,
            params: Dict[str, Any],
        ) -> Dict[str, Any]:
            init_parameters = params[param_alias.cuml_init]

            from cuml.linear_model.logistic_regression_mg import LogisticRegressionMG

            # TODO: no init params are supported yet, so this filter currently
            # drops every entry (pending cuML-side support; see
            # rapidsai/cuml#5516).
            supported_params: List[str] = []

            # filter only supported params
            init_parameters = {
                k: v for k, v in init_parameters.items() if k in supported_params
            }

            logistic_regression = LogisticRegressionMG(
                handle=params[param_alias.handle],
                **init_parameters,
            )

            # Each entry of dfs is a (features, label, row_number) triple.
            X_list = [x for (x, _, _) in dfs]
            y_list = [y for (_, y, _) in dfs]
            if isinstance(X_list[0], pd.DataFrame):
                concated = pd.concat(X_list)
                concated_y = pd.concat(y_list)
            else:
                # features are either cp or np arrays here
                concated = _concat_and_free(X_list, order=array_order)
                concated_y = _concat_and_free(y_list, order=array_order)

            pdesc = PartitionDescriptor.build(
                [concated.shape[0]], params[param_alias.num_cols]
            )

            logistic_regression.fit(
                [(concated, concated_y)],
                pdesc.m,
                pdesc.n,
                pdesc.parts_rank_size,
                pdesc.rank,
            )

            # Serialize fitted attributes to plain Python types so they can
            # cross the Spark boundary as a Row (see _out_schema).
            return {
                "coef_": [logistic_regression.coef_.tolist()],
                "intercept_": [logistic_regression.intercept_.tolist()],
                "n_cols": [logistic_regression.n_cols],
                "dtype": [logistic_regression.dtype.name],
            }

        return _logistic_regression_fit

    def _pre_process_data(
        self, dataset: DataFrame
    ) -> Tuple[
        List[Column], Optional[List[str]], int, Union[Type[FloatType], Type[DoubleType]]
    ]:
        """Extend the base preprocessing to handle VectorUDT feature columns."""
        (
            select_cols,
            multi_col_names,
            dimension,
            feature_type,
        ) = super()._pre_process_data(dataset)

        # if input format is vectorUDT, convert data type to float32
        # TODO: support float64
        input_col, _ = self._get_input_columns()
        label_col = self.getLabelCol()

        if input_col is not None and isinstance(
            dataset.schema[input_col].dataType, VectorUDT
        ):
            # Flatten the vector into a float32 array column for cuML.
            select_cols[0] = vector_to_array(col(input_col), dtype="float32").alias(
                alias.data
            )

            select_cols[1] = col(label_col).cast(FloatType()).alias(alias.label)
            feature_type = FloatType

        return select_cols, multi_col_names, dimension, feature_type

    def _out_schema(self) -> Union[StructType, str]:
        # Schema of the single-row result produced by _logistic_regression_fit.
        return StructType(
            [
                StructField("coef_", ArrayType(ArrayType(DoubleType()), False), False),
                StructField("intercept_", ArrayType(DoubleType()), False),
                StructField("n_cols", IntegerType(), False),
                StructField("dtype", StringType(), False),
            ]
        )

    def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel":
        return LogisticRegressionModel.from_row(result)


class LogisticRegressionModel(
    LogisticRegressionClass,
    _CumlModelWithPredictionCol,
    _LogisticRegressionCumlParams,
):
    """Model fitted by :class:`LogisticRegression`."""

    def __init__(
        self,
        coef_: List[List[float]],
        intercept_: List[float],
        n_cols: int,
        dtype: str,
    ) -> None:
        super().__init__(dtype=dtype, n_cols=n_cols, coef_=coef_, intercept_=intercept_)
        # Keep the raw fitted attributes for the property accessors below.
        self.intercept_ = intercept_
        self.coef_ = coef_

    def _get_cuml_transform_func(  # type:ignore
        self, dataset: DataFrame, category: str = transform_evaluate.transform
    ) -> Tuple[_ConstructFunc, _TransformFunc, Optional[_EvaluateFunc],]:  # type:ignore
        # Not implemented yet for logistic regression.
        pass

    def _transform(self, dataset: DataFrame) -> DataFrame:  # type:ignore
        # Not implemented yet for logistic regression.
        pass

    @property
    def coefficients(self) -> Vector:
        """
        Model coefficients.
        """
        assert len(self.coef_) == 1, "multi classes not supported yet"
        first_row = self.coef_[0]
        return Vectors.dense(cast(list, first_row))

    @property
    def intercept(self) -> float:
        """
        Model intercept.
        """
        assert len(self.intercept_) == 1, "multi classes not supported yet"
        return self.intercept_[0]
Loading
Loading