From de2dd64a77dcd729b061ec7c9fb4d6f08164fbd2 Mon Sep 17 00:00:00 2001 From: Wonchan Lee Date: Mon, 1 Aug 2022 09:06:09 -0700 Subject: [PATCH] Activate the NumPy fallback for cunumeric.random in CPU build (#485) * Activate the NumPy fallback for cunumeric.random in CPU build * Fix for mypy errors --- cunumeric/config.py | 3 +++ cunumeric/random/__init__.py | 9 ++++++--- cunumeric/runtime.py | 2 ++ src/cunumeric/cunumeric.cc | 9 +++++++++ src/cunumeric/cunumeric_c.h | 1 + src/cunumeric/unary/isnan.h | 6 +++++- 6 files changed, 26 insertions(+), 4 deletions(-) diff --git a/cunumeric/config.py b/cunumeric/config.py index 34a6cd1bb..d898588c5 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -264,6 +264,9 @@ class _CunumericSharedLib: CUNUMERIC_WRITE: int CUNUMERIC_ZIP: int + def cunumeric_has_curand(self) -> int: + ... + # Load the cuNumeric library first so we have a shard object that # we can use to initialize all these configuration enumerations diff --git a/cunumeric/random/__init__.py b/cunumeric/random/__init__.py index 3ffec0bdf..9edf79865 100644 --- a/cunumeric/random/__init__.py +++ b/cunumeric/random/__init__.py @@ -15,10 +15,13 @@ from __future__ import annotations import numpy.random as _nprandom -from cunumeric.random.random import * from cunumeric.coverage import clone_module -from cunumeric.random.bitgenerator import * -from cunumeric.random.generator import * +from cunumeric.runtime import runtime + +if runtime.has_curand: + from cunumeric.random.random import * + from cunumeric.random.bitgenerator import * + from cunumeric.random.generator import * clone_module(_nprandom, globals()) diff --git a/cunumeric/runtime.py b/cunumeric/runtime.py index 31e006ffa..64a5c2420 100644 --- a/cunumeric/runtime.py +++ b/cunumeric/runtime.py @@ -162,6 +162,8 @@ def __init__(self, legate_context: LegateContext) -> None: # Make sure that our CuNumericLib object knows about us so it can # destroy us cunumeric_lib.set_runtime(self) + assert cunumeric_lib.shared_object is not None + self.has_curand = cunumeric_lib.shared_object.cunumeric_has_curand() self._register_dtypes() self.args = parse_command_args("cunumeric", ARGS) diff --git a/src/cunumeric/cunumeric.cc b/src/cunumeric/cunumeric.cc index 84bdd1b22..370f8b9b6 100644 --- a/src/cunumeric/cunumeric.cc +++ b/src/cunumeric/cunumeric.cc @@ -83,4 +83,13 @@ void cunumeric_perform_registration(void) ctx, CUNUMERIC_TUNABLE_HAS_NUMAMEM, cunumeric::CuNumeric::mapper_id); if (fut.get_result() != 0) cunumeric::CuNumeric::has_numamem = true; } + +bool cunumeric_has_curand() +{ +#ifdef LEGATE_USE_CUDA + return true; +#else + return false; +#endif +} } diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index 9cdff1a81..f080aa260 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -328,6 +328,7 @@ extern "C" { #endif void cunumeric_perform_registration(); +bool cunumeric_has_curand(); #ifdef __cplusplus } diff --git a/src/cunumeric/unary/isnan.h b/src/cunumeric/unary/isnan.h index dbc5dcfec..1809d266d 100644 --- a/src/cunumeric/unary/isnan.h +++ b/src/cunumeric/unary/isnan.h @@ -38,6 +38,10 @@ __CUDA_HD__ bool is_nan(const complex& x) return std::isnan(x.imag()) || std::isnan(x.real()); } -__CUDA_HD__ inline bool is_nan(const __half& x) { return isnan(x); } +__CUDA_HD__ inline bool is_nan(const __half& x) +{ + using std::isnan; + return isnan(x); +} } // namespace cunumeric