Skip to content

Commit

Permalink
Fixed reduction kernel.
Browse files Browse the repository at this point in the history
Merged with main.
  • Loading branch information
mingjie-intel committed May 13, 2023
1 parent 9c9fb35 commit ad5a888
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 139 deletions.
107 changes: 0 additions & 107 deletions driver.py

This file was deleted.

9 changes: 4 additions & 5 deletions numba_dpex/core/passes/parfor_lowering_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,20 @@ def _submit_parfor_kernel(
)
global_range.append(stop)
else:
stop = None
if kernel_type == 1:
stop = _load_range(lowerer, 1)
else:
stop = reductionHelper.global_size_var
stop = _load_range(lowerer, stop)
stop = _load_range(lowerer, stop)
global_range.append(stop)

local_range = []
if kernel_type == 2:
local_range.append(
_load_range(lowerer, reductionHelper.work_group_size)
)

# Submit a synchronous kernel
ir_builder.submit_sync_ranged_kernel(
curr_queue,
Expand Down Expand Up @@ -311,15 +314,13 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
parfor_reddict,
)

print("-------> submit first NDrange kernel")
self._submit_parfor_kernel(
lowerer,
psrfor_kernel,
loop_ranges,
2,
reductionHelperList[0],
)
print("-----> first kernel is done ")

psrfor_kernel = create_reduction_remainder_kernel_for_parfor(
parfor,
Expand All @@ -331,15 +332,13 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
reductionHelperList,
)

print("-------> submit second range kernel")
self._submit_parfor_kernel(
lowerer,
psrfor_kernel,
loop_ranges,
1,
reductionHelperList[0],
)
print("-----> second kernel is done ")

reductionKernelVar.copy_final_sum_to_host(psrfor_kernel)
# ---------------------- End of Reduction codegen
Expand Down
8 changes: 4 additions & 4 deletions numba_dpex/core/pipelines/kernel_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def define_untyped_pipeline(state, name="dpex_kernel_untyped"):
# Add a pass to ensure that, when users allocate static constant memory,
# the size of the allocation is a constant and is not specified by a
# closure variable.
pm.add_pass(
ConstantSizeStaticLocalMemoryPass,
"dpex constant size for static local memory",
)
# pm.add_pass(
# ConstantSizeStaticLocalMemoryPass,
# "dpex constant size for static local memory",
# )

# --- End of dpex passes added to the untyped pipeline --#

Expand Down
3 changes: 3 additions & 0 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,7 @@ static PyObject *build_c_helpers_dict(void)
_declpointer("DPEXRTQueue_CreateFromFilterString",
&DPEXRTQueue_CreateFromFilterString);
_declpointer("DpexrtQueue_SubmitRange", &DpexrtQueue_SubmitRange);
_declpointer("DpexrtQueue_SubmitNDRange", &DpexrtQueue_SubmitNDRange);
_declpointer("DPEXRT_MemInfo_alloc", &DPEXRT_MemInfo_alloc);
_declpointer("DPEXRT_MemInfo_fill", &DPEXRT_MemInfo_fill);
_declpointer("NRT_ExternalAllocator_new_for_usm",
Expand Down Expand Up @@ -1350,6 +1351,8 @@ MOD_INIT(_dpexrt_python)
PyLong_FromVoidPtr(&DPEXRTQueue_CreateFromFilterString));
PyModule_AddObject(m, "DpexrtQueue_SubmitRange",
PyLong_FromVoidPtr(&DpexrtQueue_SubmitRange));
PyModule_AddObject(m, "DpexrtQueue_SubmitNDRange",
PyLong_FromVoidPtr(&DpexrtQueue_SubmitNDRange));
PyModule_AddObject(m, "DPEXRT_MemInfo_alloc",
PyLong_FromVoidPtr(&DPEXRT_MemInfo_alloc));
PyModule_AddObject(m, "DPEXRT_MemInfo_fill",
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,5 +414,5 @@ def submit_sync_ranged_kernel(
),
self.context.get_constant(types.uintp, 0),
]
print("---------------> this is NDrange kernel")

self.rtctx.submit_ndrange(self.builder, *args)
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,6 @@ def _generate_remainder_kernel_as_string(self):
for i, redvar in enumerate(self._redvars):
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += {self._partial_sum_var_name[i]}[j]\n"

gufunc_txt += f" print({self._final_sum_var_name[0]}[0])\n"
gufunc_txt += f" print({self._final_sum_var_name[1]}[0])\n"

gufunc_txt += (
f" for j in range ({self._global_size_mod_var_name[0]}) :\n"
)
Expand All @@ -280,9 +277,6 @@ def _generate_remainder_kernel_as_string(self):
+ " = "
+ f"{self._global_size_var_name[0]} + j\n"
)
gufunc_txt += f" print({self._legal_loop_indices[0]})\n"
gufunc_txt += f" print(a[{self._legal_loop_indices[0]}])\n"
gufunc_txt += f" print(b[{self._legal_loop_indices[0]}])\n"

for redvar in self._redvars:
rtyp = str(self._typemap[redvar])
Expand All @@ -306,7 +300,6 @@ def _generate_remainder_kernel_as_string(self):
redop = self._parfor_reddict[redvar].redop
if redop == operator.iadd:
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += local_sums_{legal_redvar}[0]\n"
gufunc_txt += f" print(local_sums_{legal_redvar}[0])\n"
elif redop == operator.imul:
gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= local_sums_{legal_redvar}[0]\n"
else:
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/utils/reduction_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _allocate_partail_reduction_arrays(
types.literal(inputArrayType.usm_type),
),
)
# shape, dtype=None, order="C", usm_type=None, device=None, sycl_queue=None
# shape, dtype=None, order="C", device=None, usm_type=None, sycl_queue=None
sizeVar = pfbdr.make_tuple_variable(
[self.partial_sum_size_var], name="tuple_sizeVar"
)
Expand Down
7 changes: 4 additions & 3 deletions numba_dpex/core/utils/reduction_kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
get_unused_var_name,
legalize_names,
mk_unique_var,
remove_dels,
rename_labels,
replace_var_names,
)
from numba.core.typing import signature
Expand Down Expand Up @@ -143,8 +145,8 @@ def create_reduction_main_kernel_for_parfor(

# FIXME: Why do rename_labels and remove_dels cause the partial_sum array
# update instructions to be removed?
# kernel_ir.blocks = rename_labels(kernel_ir.blocks)
# remove_dels(kernel_ir.blocks)
kernel_ir.blocks = rename_labels(kernel_ir.blocks)
remove_dels(kernel_ir.blocks)

old_alias = flags.noalias

Expand All @@ -154,7 +156,6 @@ def create_reduction_main_kernel_for_parfor(
kernel_sig = signature(types.none, *kernel_param_types)
exec_queue = typemap[parfor_args[0]].queue

breakpoint()
sycl_kernel = _compile_kernel_parfor(
exec_queue,
kernel_name,
Expand Down
8 changes: 4 additions & 4 deletions numba_dpex/dpnp_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def impl_dpnp_zeros(
ty_shape,
ty_dtype,
ty_order,
ty_like,
# ty_like,
ty_device,
ty_usm_type,
ty_sycl_queue,
Expand Down Expand Up @@ -223,7 +223,7 @@ def impl_dpnp_zeros(
ty_shape,
ty_dtype,
ty_order,
ty_like,
# ty_like,
ty_device,
ty_usm_type,
ty_sycl_queue,
Expand All @@ -245,7 +245,7 @@ def impl_dpnp_ones(
ty_shape,
ty_dtype,
ty_order,
ty_like,
# ty_like,
ty_device,
ty_usm_type,
ty_sycl_queue,
Expand Down Expand Up @@ -283,7 +283,7 @@ def impl_dpnp_ones(
ty_shape,
ty_dtype,
ty_order,
ty_like,
# ty_like,
ty_device,
ty_usm_type,
ty_sycl_queue,
Expand Down
14 changes: 7 additions & 7 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def ol_dpnp_zeros(
shape,
dtype=None,
order="C",
like=None,
# like=None,
device=None,
usm_type="device",
sycl_queue=None,
Expand Down Expand Up @@ -355,7 +355,7 @@ def ol_dpnp_zeros(
"""

_ndim = _ty_parse_shape(shape)
_layout = _parse_layout(order)
# _layout = _parse_layout(order)
_dtype = _parse_dtype(dtype)
_layout = _parse_layout(order)
_usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device"
Expand All @@ -377,7 +377,7 @@ def impl(
shape,
dtype=None,
order="C",
like=None,
# like=None,
device=None,
usm_type="device",
sycl_queue=None,
Expand All @@ -386,7 +386,7 @@ def impl(
shape,
_dtype,
order,
like,
# like,
_device,
_usm_type,
sycl_queue,
Expand All @@ -408,7 +408,7 @@ def ol_dpnp_ones(
shape,
dtype=None,
order="C",
like=None,
# like=None,
device=None,
usm_type="device",
sycl_queue=None,
Expand Down Expand Up @@ -475,7 +475,7 @@ def impl(
shape,
dtype=None,
order="C",
like=None,
# like=None,
device=None,
usm_type="device",
sycl_queue=None,
Expand All @@ -484,7 +484,7 @@ def impl(
shape,
_dtype,
order,
like,
# like,
_device,
_usm_type,
sycl_queue,
Expand Down
Loading

0 comments on commit ad5a888

Please sign in to comment.