Skip to content

Commit

Permalink
convert to np
Browse files Browse the repository at this point in the history
  • Loading branch information
rishic3 committed Jul 22, 2023
1 parent 50cabff commit 3b096e0
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions python/src/spark_rapids_ml/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def _get_cuml_transform_func(
self, dataset: DataFrame, category: str = transform_evaluate.transform
) -> Tuple[_ConstructFunc, _TransformFunc, Optional[_EvaluateFunc],]:
cuml_alg_params = self.cuml_params.copy()
driver_embedding = self.embedding_
driver_embedding = np.array(self.embedding_, dtype=np.float32)
driver_raw_data = np.array(self.raw_data_, dtype=np.float32)

def _construct_umap() -> CumlT:
Expand All @@ -519,9 +519,7 @@ def _construct_umap() -> CumlT:
from cuml.manifold import UMAP as CumlUMAP

if is_sparse(driver_raw_data):
raw_data_cuml = SparseCumlArray(
driver_raw_data, convert_to_dtype=cp.float32, convert_format=False
)
raw_data_cuml = SparseCumlArray(driver_raw_data, convert_format=False)
else:
raw_data_cuml, _, _, _ = input_to_cuml_array(
driver_raw_data,
Expand All @@ -530,9 +528,7 @@ def _construct_umap() -> CumlT:
)

internal_model = CumlUMAP(**cuml_alg_params)
internal_model.embedding_ = cp.array(
driver_embedding, dtype=cp.float32
).data
internal_model.embedding_ = cp.array(driver_embedding).data
internal_model._raw_data = raw_data_cuml

return internal_model
Expand Down

0 comments on commit 3b096e0

Please sign in to comment.