Skip to content

Commit

Permalink
support zero nnz for sparse logistic regression
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Sep 25, 2024
1 parent f818527 commit 6c90418
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ void mean_stddev(const raft::handle_t& handle,
{
auto stream = handle.get_stream();
int D = X.n;

if (X.nnz == 0) {
SimpleVec<T> meanVec(mean_vector, D);
meanVec.fill(0., stream);

SimpleVec<T> stddevVec(stddev_vector, D);
stddevVec.fill(0., stream);
return;
}

mean(handle, X, n_samples, mean_vector);

// calculate stdev.S
Expand Down
41 changes: 41 additions & 0 deletions python/cuml/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,44 @@ def make_classification_with_nnz(
)

assert lr_on.dtype == datatype


@pytest.mark.parametrize("standardization", [False, True])
@pytest.mark.parametrize("fit_intercept", [False, True])
def test_sparse_all_zeroes(standardization, fit_intercept, client):
    """Check the Dask logistic regression on an all-zero sparse input.

    A 4x2 matrix of zeros is converted to CSR, distributed over two
    partitions and fitted with the multi-GPU estimator.  The predictions,
    coefficients and intercept are then compared against scikit-learn's
    ``LogisticRegression`` fitted on the same data.

    Parameters
    ----------
    standardization : bool
        Forwarded to the Dask estimator's ``standardization`` option.
    fit_intercept : bool
        Forwarded to both estimators' ``fit_intercept`` option.
    client
        Dask client fixture handed to the data-preparation helper.
    """
    from cuml.dask.linear_model import LogisticRegression as cumlLBFGS_dask
    from sklearn.linear_model import LogisticRegression

    dtype_name = "float32"
    num_partitions = 2

    # Four samples, two features, every entry zero.
    zero_features = csr_matrix(np.zeros((4, 2), dtype=dtype_name))
    labels = np.array([1.0, 1.0, 0.0, 0.0], dtype_name)

    X_dist, y_dist = _prep_training_data_sparse(
        client, zero_features, labels, num_partitions
    )

    # Multi-GPU fit and prediction.
    dask_model = cumlLBFGS_dask(
        fit_intercept=fit_intercept,
        verbose=True,
        standardization=standardization,
    )
    dask_model.fit(X_dist, y_dist)
    dask_predictions = dask_model.predict(X_dist).compute()

    # CPU reference fit on the very same sparse matrix.
    reference_model = LogisticRegression(fit_intercept=fit_intercept)
    reference_model.fit(zero_features, labels)
    reference_predictions = reference_model.predict(zero_features)

    assert array_equal(dask_predictions, reference_predictions)

    # Learned parameters must agree with the reference, sign included.
    assert array_equal(dask_model.coef_, reference_model.coef_, with_sign=True)
    assert array_equal(
        dask_model.intercept_, reference_model.intercept_, with_sign=True
    )

0 comments on commit 6c90418

Please sign in to comment.