Skip to content

Commit

Permalink
registering all PointN types during the Runtime initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes committed Apr 5, 2022
1 parent 2e92e78 commit dc287a6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 13 deletions.
16 changes: 15 additions & 1 deletion cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def destroy(self):
# Match these to CuNumericOpCode in cunumeric_c.h
@unique
class CuNumericOpCode(IntEnum):
ADVANCED_INDX = _cunumeric.CUNUMERIC_ADVANCED_INDEXING
ADVANCED_INDEXING = _cunumeric.CUNUMERIC_ADVANCED_INDEXING
ARANGE = _cunumeric.CUNUMERIC_ARANGE
BINARY_OP = _cunumeric.CUNUMERIC_BINARY_OP
BINARY_RED = _cunumeric.CUNUMERIC_BINARY_RED
Expand Down Expand Up @@ -211,3 +211,17 @@ class CuNumericTunable(IntEnum):
NUM_PROCS = _cunumeric.CUNUMERIC_TUNABLE_NUM_PROCS
MAX_EAGER_VOLUME = _cunumeric.CUNUMERIC_TUNABLE_MAX_EAGER_VOLUME
HAS_NUMAMEM = _cunumeric.CUNUMERIC_TUNABLE_HAS_NUMAMEM


# Match these to CuNumericTypeCOdes in cunumeric_c.h
@unique
class CuNumericTypeCodes(IntEnum):
CUNUMERIC_TYPE_POINT1 = _cunumeric.CUNUMERIC_TYPE_POINT1
CUNUMERIC_TYPE_POINT2 = _cunumeric.CUNUMERIC_TYPE_POINT2
CUNUMERIC_TYPE_POINT3 = _cunumeric.CUNUMERIC_TYPE_POINT3
CUNUMERIC_TYPE_POINT4 = _cunumeric.CUNUMERIC_TYPE_POINT4
CUNUMERIC_TYPE_POINT5 = _cunumeric.CUNUMERIC_TYPE_POINT5
CUNUMERIC_TYPE_POINT6 = _cunumeric.CUNUMERIC_TYPE_POINT6
CUNUMERIC_TYPE_POINT7 = _cunumeric.CUNUMERIC_TYPE_POINT7
CUNUMERIC_TYPE_POINT8 = _cunumeric.CUNUMERIC_TYPE_POINT8
CUNUMERIC_TYPE_POINT9 = _cunumeric.CUNUMERIC_TYPE_POINT9
6 changes: 3 additions & 3 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _zip_indices(self, start_index, arrays):
# of that dtype, so long as we don't try to convert it to a
# NumPy array.
N = self.ndim
pointN_dtype = self.runtime.add_point_type(N)
pointN_dtype = self.runtime.get_point_type(N)
store = self.context.create_store(
pointN_dtype, shape=out_shape, optimize_scalar=True
)
Expand Down Expand Up @@ -541,11 +541,11 @@ def _create_indexing_array(self, key, is_set=False):
out_dtype = self.dtype
if is_set:
N = self.ndim
out_dtype = self.runtime.add_point_type(N)
out_dtype = self.runtime.get_point_type(N)

out = self.runtime.create_unbound_thunk(out_dtype)
task = self.context.create_task(
CuNumericOpCode.ADVANCED_INDX
CuNumericOpCode.ADVANCED_INDEXING
)
task.add_output(out.base)
task.add_input(self.base)
Expand Down
28 changes: 19 additions & 9 deletions cunumeric/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
CuNumericOpCode,
CuNumericRedopCode,
CuNumericTunable,
CuNumericTypeCodes,
cunumeric_context,
cunumeric_lib,
)
Expand Down Expand Up @@ -113,6 +114,24 @@ def _register_dtypes(self):
for numpy_type, core_type in _supported_dtypes.items():
type_system.make_alias(np.dtype(numpy_type), core_type)

for n in range(1, LEGATE_MAX_DIM + 1):
self._register_point_type(n)

def _register_point_type(self, n):
type_system = self.legate_context.type_system
point_type = "" + str(n)
if point_type not in type_system:
code = CuNumericTypeCodes.CUNUMERIC_TYPE_POINT1 + n - 1
size_in_bytes = 8 * n
type_system.add_type(point_type, size_in_bytes, code)

def get_point_type(self, n):
type_system = self.legate_context.type_system
point_type = "" + str(n)
if point_type not in type_system:
raise ValueError(f"there is no point type registered fro {n}")
return point_type

def _parse_command_args(self):
try:
# Prune it out so the application does not see it
Expand Down Expand Up @@ -192,15 +211,6 @@ def get_arg_dtype(self, value_dtype):
dtype.register_reduction_op(redop, redop_id)
return arg_dtype

def add_point_type(self, n):
type_system = self.legate_context.type_system
point_type = "point" + str(n)
if point_type not in type_system:
code = type_system[ty.int64].code
size_in_bytes = 8 * n
type_system.add_type(point_type, size_in_bytes, code)
return point_type

def _report_coverage(self):
total = len(self.api_calls)
implemented = sum(int(impl) for (_, _, impl) in self.api_calls)
Expand Down
12 changes: 12 additions & 0 deletions src/cunumeric/cunumeric_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ enum CuNumericBounds {
CUNUMERIC_MAX_TASKS = 1048576,
};

enum CuNumericTypeCodes {
CUNUMERIC_TYPE_POINT1 = LEGION_TYPE_TOTAL + 1,
CUNUMERIC_TYPE_POINT2,
CUNUMERIC_TYPE_POINT3,
CUNUMERIC_TYPE_POINT4,
CUNUMERIC_TYPE_POINT5,
CUNUMERIC_TYPE_POINT6,
CUNUMERIC_TYPE_POINT7,
CUNUMERIC_TYPE_POINT8,
CUNUMERIC_TYPE_POINT9,
};

#ifdef __cplusplus
extern "C" {
#endif
Expand Down

0 comments on commit dc287a6

Please sign in to comment.