Skip to content

Commit

Permalink
get run_tests.sh passed and added docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Jul 25, 2023
1 parent 88aeb6f commit cdf5d5a
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 76 deletions.
124 changes: 90 additions & 34 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast

from pyspark.ml.evaluation import Evaluator, MulticlassClassificationEvaluator

Expand All @@ -32,7 +32,7 @@
RandomForestClassificationSummary,
_RandomForestClassifierParams,
)
from pyspark.ml.linalg import Vector
from pyspark.ml.linalg import Vector, Vectors
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col
Expand Down Expand Up @@ -511,15 +511,9 @@ def _transformEvaluate(
return scores


from .params import _CumlClass, _CumlParams, HasFeaturesCols
from typing import Callable
from .core import _CumlEstimatorSupervised, _CumlModelWithPredictionCol, FitInputType, param_alias
from pyspark.ml.param.shared import (
HasFeaturesCol,
HasLabelCol,
HasPredictionCol
)
from .utils import PartitionDescriptor, _concat_and_free, _ArrayOrder

from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol
from pyspark.sql.types import (
ArrayType,
DoubleType,
Expand All @@ -530,27 +524,34 @@ def _transformEvaluate(
StructType,
)

from .core import (
FitInputType,
_CumlEstimatorSupervised,
_CumlModelWithPredictionCol,
param_alias,
)
from .params import HasFeaturesCols, _CumlClass, _CumlParams
from .utils import PartitionDescriptor, _ArrayOrder, _concat_and_free


class LogisticRegressionClass(_CumlClass):
@classmethod
def _param_mapping(cls) -> Dict[str, Optional[str]]:
return {
}
return {}

@classmethod
def _param_value_mapping(
cls,
) -> Dict[str, Callable[[str], Union[None, str, float, int]]]:
return {
}
return {}

def _get_cuml_params_default(self) -> Dict[str, Any]:
return {
}
return {}


class _LogisticRegressionCumlParams(
_CumlParams, HasFeaturesCol, HasLabelCol, HasFeaturesCols, HasPredictionCol
):

def getFeaturesCol(self) -> Union[str, List[str]]: # type:ignore
"""
Gets the value of :py:attr:`featuresCol` or :py:attr:`featuresCols`
Expand All @@ -562,7 +563,9 @@ def getFeaturesCol(self) -> Union[str, List[str]]: # type:ignore
else:
raise RuntimeError("featuresCol is not set")

def setFeaturesCol(self: "_LogisticRegressionCumlParams", value: Union[str, List[str]]) -> "_LogisticRegressionCumlParams":
def setFeaturesCol(
self: "_LogisticRegressionCumlParams", value: Union[str, List[str]]
) -> "_LogisticRegressionCumlParams":
"""
Sets the value of :py:attr:`featuresCol` or :py:attr:`featureCols`.
"""
Expand All @@ -572,33 +575,71 @@ def setFeaturesCol(self: "_LogisticRegressionCumlParams", value: Union[str, List
self.set_params(featuresCols=value)
return self

def setFeaturesCols(self: "_LogisticRegressionCumlParams", value: List[str]) -> "_LogisticRegressionCumlParams":
def setFeaturesCols(
self: "_LogisticRegressionCumlParams", value: List[str]
) -> "_LogisticRegressionCumlParams":
"""
Sets the value of :py:attr:`featuresCols`.
"""
return self.set_params(featuresCols=value)

def setLabelCol(self: "_LogisticRegressionCumlParams", value: str) -> "_LogisticRegressionCumlParams":
def setLabelCol(
self: "_LogisticRegressionCumlParams", value: str
) -> "_LogisticRegressionCumlParams":
"""
Sets the value of :py:attr:`labelCol`.
"""
return self.set_params(labelCol=value)

def setPredictionCol(self: "_LogisticRegressionCumlParams", value: str) -> "_LogisticRegressionCumlParams":
def setPredictionCol(
self: "_LogisticRegressionCumlParams", value: str
) -> "_LogisticRegressionCumlParams":
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self.set_params(predictionCol=value)


class LogisticRegression(
LogisticRegressionClass,
_CumlEstimatorSupervised,
_LogisticRegressionCumlParams,
):
def __init__(
self,
*,
num_workers: Optional[int] = None) :
"""
Examples
--------
>>> from spark_rapids_ml.classification import LogisticRegression
>>> data = [
... ([1.0, 2.0], 1.0),
... ([1.0, 3.0], 1.0),
... ([2.0, 1.0], 0.0),
... ([3.0, 1.0], 0.0),
... ]
>>> schema = "features array<float>, label float"
>>> df = spark.createDataFrame(data, schema=schema)
>>> df.show()
+----------+-----+
| features|label|
+----------+-----+
|[1.0, 2.0]| 1.0|
|[1.0, 3.0]| 1.0|
|[2.0, 1.0]| 0.0|
|[3.0, 1.0]| 0.0|
+----------+-----+
>>> lr_estimator = LogisticRegression()
>>> lr_estimator.setFeaturesCol("features")
LogisticRegression_a757215437b0
>>> lr_estimator.setLabelCol("label")
LogisticRegression_a757215437b0
>>> lr_model = lr_estimator.fit(df)
>>> lr_model.coefficients
DenseVector([-0.7148, 0.7148])
>>> lr_model.intercept
-8.543887375367376e-09
"""

def __init__(self, *, num_workers: Optional[int] = None):
super().__init__()

def _fit_array_order(self) -> _ArrayOrder:
Expand All @@ -609,7 +650,6 @@ def _get_cuml_fit_func(
dataset: DataFrame,
extra_params: Optional[List[Dict[str, Any]]] = None,
) -> Callable[[FitInputType, Dict[str, Any]], Dict[str, Any],]:

num_workers = self.num_workers
array_order = self._fit_array_order()
num_classes = dataset.select(alias.label).distinct().count()
Expand All @@ -621,8 +661,8 @@ def _logistic_regression_fit(
init_parameters = params[param_alias.cuml_init]

from cuml.linear_model.logistic_regression_mg import LogisticRegressionMG
supported_params = [
]

supported_params: List[str] = []

# filter only supported params
init_parameters = {
Expand All @@ -648,11 +688,10 @@ def _logistic_regression_fit(
[concated.shape[0]], params[param_alias.num_cols]
)


logistic_regression.fit(
[(concated, concated_y)],
pdesc.m,
pdesc.n,
pdesc.n,
pdesc.parts_rank_size,
pdesc.rank,
)
Expand All @@ -679,6 +718,7 @@ def _out_schema(self) -> Union[StructType, str]:
def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel":
return LogisticRegressionModel.from_row(result)


class LogisticRegressionModel(
LogisticRegressionClass,
_CumlModelWithPredictionCol,
Expand All @@ -697,10 +737,26 @@ def __init__(
self.coef_ = coef_
self.intercept_ = intercept_

def _get_cuml_transform_func(
def _get_cuml_transform_func( # type:ignore
self, dataset: DataFrame, category: str = transform_evaluate.transform
) -> Tuple[_ConstructFunc, _TransformFunc, Optional[_EvaluateFunc],]:
) -> Tuple[_ConstructFunc, _TransformFunc, Optional[_EvaluateFunc],]: # type:ignore
pass

def _transform(self, dataset: DataFrame) -> DataFrame: # type:ignore
pass

def _transform(self, dataset: DataFrame) -> DataFrame:
pass
@property
def coefficients(self) -> Vector:
"""
Model coefficients.
"""
assert len(self.coef_) == 1, "multi classes not supported yet"
return Vectors.dense(cast(list, self.coef_[0]))

@property
def intercept(self) -> float:
"""
Model intercept.
"""
assert len(self.intercept_) == 1, "multi classes not supported yet"
return self.intercept_[0]
74 changes: 32 additions & 42 deletions python/tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,62 @@
from .sparksession import CleanSparkSession
from spark_rapids_ml.classification import LogisticRegression, LogisticRegressionModel
import pytest
import json
import math
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union, cast
from typing import Any, Dict, Tuple

import numpy as np
import pytest
from cuml import accuracy_score
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.param import Param
from pyspark.ml.tuning import CrossValidator as SparkCrossValidator
from pyspark.ml.tuning import CrossValidatorModel, ParamGridBuilder
from pyspark.sql.types import DoubleType
from sklearn.metrics import r2_score

from spark_rapids_ml.tuning import CrossValidator
from spark_rapids_ml.classification import LogisticRegression, LogisticRegressionModel

from .sparksession import CleanSparkSession
from .utils import (
array_equal,
assert_params,
create_pyspark_dataframe,
cuml_supported_data_types,
feature_types,
get_default_cuml_parameters,
idfn,
make_classification_dataset,
make_regression_dataset,
pyspark_supported_feature_types,
)


def test_toy_example(gpu_number: int) -> None:
# reduce the number of GPUs for toy dataset to avoid empty partition
gpu_number = min(gpu_number, 2)
data = [
([1., 2.], 1.),
([1., 3.], 1.),
([2., 1.], 0.),
([3., 1.], 0.),
([1.0, 2.0], 1.0),
([1.0, 3.0], 1.0),
([2.0, 1.0], 0.0),
([3.0, 1.0], 0.0),
]

with CleanSparkSession() as spark:
features_col = "features"
label_col = "label"
schema = features_col + " array<float>, " + label_col + " float"
schema = features_col + " array<float>, " + label_col + " float"
df = spark.createDataFrame(data, schema=schema)
df.show()
lr_estimator = LogisticRegression(num_workers=gpu_number)
lr_estimator.setFeaturesCol(features_col)
lr_estimator.setLabelCol(label_col)
lr_model = lr_estimator.fit(df)

assert len(lr_model.coef_) == 1
assert lr_model.coef_[0] == pytest.approx([-0.71483153, 0.7148315], abs=1e-6)
assert lr_model.intercept_ == pytest.approx([-2.2614916e-08], abs=1e-6)
assert lr_model.n_cols == 2
assert lr_model.dtype == "float32"

assert len(lr_model.coef_) == 1
assert lr_model.coef_[0] == pytest.approx([-0.71483153, 0.7148315], abs=1e-6)
assert lr_model.intercept_ == pytest.approx([-2.2614916e-08], abs=1e-6)

assert lr_model.coefficients.toArray() == pytest.approx(
[-0.71483153, 0.7148315], abs=1e-6
)
assert lr_model.intercept == pytest.approx(-2.2614916e-08, abs=1e-6)

#@pytest.mark.parametrize("data_shape", [(2000, 8)], ids=idfn)
#@pytest.mark.parametrize("data_type", cuml_supported_data_types)
#@pytest.mark.parametrize("max_record_batch", [100, 10000])
#@pytest.mark.parametrize("n_classes", [2, 4])
#@pytest.mark.parametrize("num_workers", num_workers)
#@pytest.mark.slow

@pytest.mark.parametrize("feature_type", ['array'])
# TODO support float64
# 'vector' will be converted to float64 so It depends on float64 support
@pytest.mark.parametrize("feature_type", ["array", "multi_cols"])
@pytest.mark.parametrize("data_shape", [(2000, 8)], ids=idfn)
@pytest.mark.parametrize("data_type", [np.float32])
@pytest.mark.parametrize("max_record_batch", [100, 10000])
@pytest.mark.parametrize("n_classes", [2])
@pytest.mark.slow
def test_classifier(
feature_type: str,
data_shape: Tuple[int, int],
Expand All @@ -93,6 +78,7 @@ def test_classifier(
)

from cuml import LogisticRegression as cuLR

cu_lr = cuLR()
cu_lr.fit(X_train, y_train)

Expand All @@ -110,12 +96,16 @@ def test_classifier(
spark_lr.setLabelCol(label_col)
spark_lr_model: LogisticRegressionModel = spark_lr.fit(train_df)

assert len(spark_lr_model.coef_) == len(cu_lr.coef_)
for i in range(len(spark_lr_model.coef_)):
assert spark_lr_model.coef_[i] == pytest.approx(cu_lr.coef_[i], tolerance)

assert spark_lr_model.intercept_ == pytest.approx(cu_lr.intercept_, tolerance)

# test coefficients and intercepts
assert spark_lr_model.n_cols == cu_lr.n_cols

assert spark_lr_model.dtype == cu_lr.dtype

assert len(spark_lr_model.coef_) == 1
assert len(cu_lr.coef_) == 1
assert array_equal(spark_lr_model.coef_[0], cu_lr.coef_[0], tolerance)
assert array_equal(spark_lr_model.intercept_, cu_lr.intercept_, tolerance)

assert array_equal(
spark_lr_model.coefficients.toArray(), cu_lr.coef_[0], tolerance
)
assert spark_lr_model.intercept == pytest.approx(cu_lr.intercept_[0], tolerance)

0 comments on commit cdf5d5a

Please sign in to comment.