From 8b32156eff668d8e336ef09e07dc2f26be5ca2d6 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 12 Nov 2024 15:02:39 +0100 Subject: [PATCH] update tree pca --- src/rapids_singlecell/preprocessing/_pca.py | 64 ++++++++++----------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/rapids_singlecell/preprocessing/_pca.py b/src/rapids_singlecell/preprocessing/_pca.py index 26cc3021..89b29ed7 100644 --- a/src/rapids_singlecell/preprocessing/_pca.py +++ b/src/rapids_singlecell/preprocessing/_pca.py @@ -145,6 +145,8 @@ def pca( "Dask arrays are not supported for chunked PCA computation." ) _check_gpu_X(X, allow_dask=True) + if not zero_center: + raise ValueError("Dask arrays do not support non-zero centered PCA.") if isinstance(X._meta, cp.ndarray): from cuml.dask.decomposition import PCA @@ -160,7 +162,7 @@ def pca( pca_func = pca_func.fit(X) X_pca = pca_func.transform(X) - else: + elif zero_center: if chunked: from cuml.decomposition import IncrementalPCA @@ -179,38 +181,36 @@ def pca( if issparse(chunk) or cpissparse(chunk): chunk = chunk.toarray() X_pca[start_idx:stop_idx] = pca_func.transform(chunk) + elif cpissparse(X) or issparse(X): + if issparse(X): + X = sparse_scipy_to_cp(X, dtype=X.dtype) + from ._sparse_pca._sparse_pca import PCA_sparse + + if not isspmatrix_csr(X): + X = X.tocsr() + pca_func = PCA_sparse(n_components=n_comps) + X_pca = pca_func.fit_transform(X) else: - if zero_center: - if cpissparse(X) or issparse(X): - if issparse(X): - X = sparse_scipy_to_cp(X, dtype=X.dtype) - from ._sparse_pca._sparse_pca import PCA_sparse - - if not isspmatrix_csr(X): - X = X.tocsr() - pca_func = PCA_sparse(n_components=n_comps) - X_pca = pca_func.fit_transform(X) - else: - from cuml.decomposition import PCA - - pca_func = PCA( - n_components=n_comps, - svd_solver=svd_solver, - random_state=random_state, - output_type="numpy", - ) - X_pca = pca_func.fit_transform(X) - - else: # not zero_center - from cuml.decomposition import TruncatedSVD - - pca_func = TruncatedSVD( - n_components=n_comps, - random_state=random_state, - algorithm=svd_solver, - output_type="numpy", - ) - X_pca = pca_func.fit_transform(X) + from cuml.decomposition import PCA + + pca_func = PCA( + n_components=n_comps, + svd_solver=svd_solver, + random_state=random_state, + output_type="numpy", + ) + X_pca = pca_func.fit_transform(X) + + else: # not zero_center + from cuml.decomposition import TruncatedSVD + + pca_func = TruncatedSVD( + n_components=n_comps, + random_state=random_state, + algorithm=svd_solver, + output_type="numpy", + ) + X_pca = pca_func.fit_transform(X) if X_pca.dtype.descr != np.dtype(dtype).descr: X_pca = X_pca.astype(dtype)