Skip to content

Commit

Permalink
update tree pca
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Nov 12, 2024
1 parent d3421ca commit 8b32156
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions src/rapids_singlecell/preprocessing/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def pca(
"Dask arrays are not supported for chunked PCA computation."
)
_check_gpu_X(X, allow_dask=True)
if not zero_center:
raise ValueError("Dask arrays do not support non-zero centered PCA.")
if isinstance(X._meta, cp.ndarray):
from cuml.dask.decomposition import PCA

Expand All @@ -160,7 +162,7 @@ def pca(
pca_func = pca_func.fit(X)
X_pca = pca_func.transform(X)

else:
elif zero_center:
if chunked:
from cuml.decomposition import IncrementalPCA

Expand All @@ -179,38 +181,36 @@ def pca(
if issparse(chunk) or cpissparse(chunk):
chunk = chunk.toarray()
X_pca[start_idx:stop_idx] = pca_func.transform(chunk)
elif cpissparse(X) or issparse(X):
if issparse(X):
X = sparse_scipy_to_cp(X, dtype=X.dtype)
from ._sparse_pca._sparse_pca import PCA_sparse

if not isspmatrix_csr(X):
X = X.tocsr()
pca_func = PCA_sparse(n_components=n_comps)
X_pca = pca_func.fit_transform(X)
else:
if zero_center:
if cpissparse(X) or issparse(X):
if issparse(X):
X = sparse_scipy_to_cp(X, dtype=X.dtype)
from ._sparse_pca._sparse_pca import PCA_sparse

if not isspmatrix_csr(X):
X = X.tocsr()
pca_func = PCA_sparse(n_components=n_comps)
X_pca = pca_func.fit_transform(X)
else:
from cuml.decomposition import PCA

pca_func = PCA(
n_components=n_comps,
svd_solver=svd_solver,
random_state=random_state,
output_type="numpy",
)
X_pca = pca_func.fit_transform(X)

else: # not zero_center
from cuml.decomposition import TruncatedSVD

pca_func = TruncatedSVD(
n_components=n_comps,
random_state=random_state,
algorithm=svd_solver,
output_type="numpy",
)
X_pca = pca_func.fit_transform(X)
from cuml.decomposition import PCA

pca_func = PCA(
n_components=n_comps,
svd_solver=svd_solver,
random_state=random_state,
output_type="numpy",
)
X_pca = pca_func.fit_transform(X)

else: # not zero_center
from cuml.decomposition import TruncatedSVD

pca_func = TruncatedSVD(
n_components=n_comps,
random_state=random_state,
algorithm=svd_solver,
output_type="numpy",
)
X_pca = pca_func.fit_transform(X)

if X_pca.dtype.descr != np.dtype(dtype).descr:
X_pca = X_pca.astype(dtype)
Expand Down

0 comments on commit 8b32156

Please sign in to comment.