Skip to content

Commit

Permalink
Revised according to the new cuML fit wrapper
Browse files (browse the repository at this point in the history)
  • Loading branch information
lijinf2 committed Jul 25, 2023
1 parent 3f03479 commit 88aeb6f
Showing 1 changed file with 10 additions and 26 deletions.
36 changes: 10 additions & 26 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,16 +614,12 @@ def _get_cuml_fit_func(
array_order = self._fit_array_order()
num_classes = dataset.select(alias.label).distinct().count()

def _linear_regression_fit(
def _logistic_regression_fit(
dfs: FitInputType,
params: Dict[str, Any],
) -> Dict[str, Any]:
init_parameters = params[param_alias.cuml_init]

pdesc = PartitionDescriptor.build(
params[param_alias.part_sizes], params[param_alias.num_cols]
)

from cuml.linear_model.logistic_regression_mg import LogisticRegressionMG
supported_params = [
]
Expand All @@ -648,39 +644,27 @@ def _linear_regression_fit(
concated = _concat_and_free(X_list, order=array_order)
concated_y = _concat_and_free(y_list, order=array_order)

pdesc = PartitionDescriptor.build(
[concated.shape[0]], params[param_alias.num_cols]
)


print(f"DEBUG: pdesc.rank: {pdesc.rank}, num_workers: {num_workers}, pdesc.m: {pdesc.m}, num_classes: {num_classes}")
logistic_regression.fit(
concated,
concated_y,
pdesc.rank,
num_workers,
[(concated, concated_y)],
pdesc.m,
num_classes,
pdesc.n,
pdesc.parts_rank_size,
pdesc.rank,
)

print("DEBUG: showing logistic_regression.coef_: ")

print(logistic_regression.coef_)
print(type(logistic_regression.coef_))
print(logistic_regression.coef_.shape)
print(type(logistic_regression.coef_[0]))

print("DEBUG: showing logistic_regression.intercept_: ")
print(logistic_regression.intercept_)
print(logistic_regression.intercept_.shape)

print(f"DEBUG: dtype: {logistic_regression.dtype}")
print(f"DEBUG: dtype.name: {logistic_regression.dtype.name}")

return {
"coef_": [logistic_regression.coef_.tolist()],
"intercept_": [logistic_regression.intercept_.tolist()],
"n_cols": [logistic_regression.n_cols],
"dtype": [logistic_regression.dtype.name],
}

return _linear_regression_fit
return _logistic_regression_fit

def _out_schema(self) -> Union[StructType, str]:
return StructType(
Expand Down

0 comments on commit 88aeb6f

Please sign in to comment.