diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py
index 2824d73..a3eefdd 100644
--- a/sklearn_numba_dpex/common/kernels.py
+++ b/sklearn_numba_dpex/common/kernels.py
@@ -36,25 +36,41 @@ def elementwise_ops(data):


 @lru_cache
-def make_initialize_to_zeros_kernel(shape, work_group_size, dtype):
+def make_fill_kernel(fill_value, shape, work_group_size, dtype):
     n_items = math.prod(shape)
     global_size = math.ceil(n_items / work_group_size) * work_group_size
-    zero = dtype(0.0)
+    fill_value = dtype(fill_value)

     @dpex.kernel
-    def initialize_to_zeros_kernel(data):
+    def fill_kernel(data):
         item_idx = dpex.get_global_id(zero_idx)

         if item_idx >= n_items:
             return

-        data[item_idx] = zero
+        data[item_idx] = fill_value

-    def initialize_to_zeros(data):
+    def fill(data):
         data = dpt.reshape(data, (-1,))
-        initialize_to_zeros_kernel[global_size, work_group_size](data)
+        fill_kernel[global_size, work_group_size](data)
+
+    return fill
+
+
+@lru_cache
+def make_range_kernel(n_items, work_group_size):
+    global_size = math.ceil(n_items / work_group_size) * work_group_size
+
+    @dpex.kernel
+    def range_kernel(data):
+        item_idx = dpex.get_global_id(zero_idx)
+
+        if item_idx >= n_items:
+            return
+
+        data[item_idx] = item_idx

-    return initialize_to_zeros
+    return range_kernel[global_size, work_group_size]


 @lru_cache
diff --git a/sklearn_numba_dpex/common/topk.py b/sklearn_numba_dpex/common/topk.py
index 46790ee..fb7d984 100644
--- a/sklearn_numba_dpex/common/topk.py
+++ b/sklearn_numba_dpex/common/topk.py
@@ -19,7 +19,7 @@
     _get_sequential_processing_device,
     check_power_of_2,
 )
-from sklearn_numba_dpex.common.kernels import make_initialize_to_zeros_kernel
+from sklearn_numba_dpex.common.kernels import make_fill_kernel, make_range_kernel
 from sklearn_numba_dpex.common.reductions import make_sum_reduction_2d_kernel

 zero_idx = np.int64(0)
@@ -137,58 +137,16 @@ def topk(array_in, k, group_sizes=None):
     The output is not deterministic: the order of the output is undefined. Successive
     calls can return the same items in different order.
     """
-    # TODO: it seems a kernel specialized for 1d arrays would show 10-20% better
-    # performance. If this case becomes specifically relevant, consider implementing
-    # this case separately rather than using the generic multirow top k for 1d arrays.
-
-    shape = array_in.shape
-
-    is_1d = len(shape) == 1
-    if is_1d:
-        n_rows = 1
-        n_cols = shape[0]
-        array_in = dpt.reshape(array_in, (1, -1))
-    else:
-        n_rows, n_cols = shape
-
-    (
-        threshold,
-        n_threshold_occurences_in_topk,
-        n_threshold_occurences_in_data,
-        work_group_size,
-        dtype,
-        device,
-    ) = _get_topk_threshold(array_in, k, group_sizes)
-
-    gather_topk_kernel = _make_gather_topk_kernel(
-        n_rows,
-        n_cols,
+    _get_topk_kernel = _make_get_topk_kernel(
         k,
-        work_group_size,
-    )
-
-    result = dpt.empty((n_rows, k), dtype=dtype, device=device)
-
-    # For each row, maintain an atomically incremented index of the next result value
-    # to be stored:.
-    # Note that the ordering of the topk is non-deteriminstic and dependents on the
-    # concurrency of the parallel work items.
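As a quick orientation for reviewers, here is a minimal usage sketch of the two factories introduced above. It assumes a machine with a default SYCL device available; the shapes and the work-group size are illustrative, not taken from this diff:

import dpctl.tensor as dpt
import numpy as np

from sklearn_numba_dpex.common.kernels import make_fill_kernel, make_range_kernel

work_group_size = 128  # a valid size depends on the target device

# `make_fill_kernel` returns a closure that fills an array of the given shape
# in place with `fill_value`.
fill_with_ones = make_fill_kernel(
    fill_value=1, shape=(4, 32), work_group_size=work_group_size, dtype=np.float32
)
data = dpt.empty((4, 32), dtype=np.float32)
fill_with_ones(data)

# `make_range_kernel` returns an already-sized kernel: calling it writes
# 0, 1, ..., n_items - 1 into the passed buffer.
write_range = make_range_kernel(n_items=128, work_group_size=work_group_size)
indices = dpt.empty(128, dtype=np.int64)
write_range(indices)

Note that `make_range_kernel` returns the sized kernel itself rather than a wrapping closure, since no reshaping step is needed for its 1d operand.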
-    result_col_idx = dpt.zeros((n_rows,), dtype=np.int32, device=device)
-
-    gather_topk_kernel(
-        array_in,
-        threshold,
-        n_threshold_occurences_in_topk,
-        n_threshold_occurences_in_data,
-        result_col_idx,
-        # OUT
-        result,
+        array_in.shape,
+        array_in.dtype.type,
+        array_in.device.sycl_device,
+        group_sizes,
+        output="values",
     )

-    if is_1d:
-        return dpt.reshape(result, (-1,))
-
-    return result
+    return _get_topk_kernel(array_in)


 def topk_idx(array_in, k, group_sizes=None):
@@ -226,60 +184,169 @@
     for this value can be different between two successive calls.
     """
-    shape = array_in.shape
+    _get_topk_kernel = _make_get_topk_kernel(
+        k,
+        array_in.shape,
+        array_in.dtype.type,
+        array_in.device.sycl_device,
+        group_sizes,
+        output="idx",
+    )
+
+    return _get_topk_kernel(array_in)
+
+
+def _make_get_topk_kernel(
+    k, shape, dtype, device, group_sizes, output, reuse_result_buffer=False
+):
+    """Returns a `_get_topk_kernel` closure.
+    The closure can be passed an array with attributes `shape`, `dtype` and `device`
+    and will perform a TopK search, returning the requested top-k items.
+
+    As long as a closure is referenced, it keeps pre-allocated buffers and
+    pre-defined kernel functions in cache. Thus, it is more efficient to perform
+    sequential calls to the same closure, since subsequent calls will not have the
+    overhead of re-defining kernels and re-allocating buffers.
+
+    For isolated calls, the top-level, user-exposed `topk` and `topk_idx` can be used
+    instead. They include definition of kernels, allocation of buffers, and cleanup
+    of said allocations afterwards.
+
+    By default, the memory allocation for the result array is not reused. This is to
+    avoid a previously computed result being erased by a subsequent call to the same
+    closure without the user noticing. Reusing the same buffer can still be enforced
+    by setting `reuse_result_buffer=True`.
+    """
+    # TODO: it seems a kernel specialized for 1d arrays would show 10-20% better
+    # performance. If this case becomes specifically relevant, consider implementing
+    # this case separately rather than using the generic multirow top k for 1d arrays.
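The docstring above motivates the refactoring, so a hedged sketch of the intended reuse pattern follows. `_make_get_topk_kernel` is a private helper, so this is illustrative only; the array content, `k` and the shapes are made up:

import dpctl.tensor as dpt
import numpy as np

from sklearn_numba_dpex.common.topk import _make_get_topk_kernel

batch = dpt.asarray(np.random.rand(100, 1000).astype(np.float32))

get_topk = _make_get_topk_kernel(
    k=10,
    shape=batch.shape,
    dtype=batch.dtype.type,
    device=batch.device.sycl_device,
    group_sizes=None,
    output="values",
)

# Kernels and buffers are built once; repeated calls on same-shaped inputs skip
# the re-definition and re-allocation overhead described in the docstring.
for _ in range(5):
    topk_values = get_topk(batch)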
     is_1d = len(shape) == 1
     if is_1d:
         n_rows = 1
         n_cols = shape[0]
-        array_in = dpt.reshape(array_in, (1, -1))
     else:
         n_rows, n_cols = shape

-    (
-        threshold,
-        n_threshold_occurences_in_topk,
-        n_threshold_occurences_in_data,
-        work_group_size,
-        dtype,
-        device,
-    ) = _get_topk_threshold(array_in, k, group_sizes)
+    work_group_size, get_topk_threshold = _make_get_topk_threshold_kernel(
+        n_rows, n_cols, k, dtype, device, group_sizes
+    )

-    gather_topk_idx_kernel = _make_gather_topk_idx_kernel(
-        n_rows,
-        n_cols,
-        k,
-        work_group_size,
+    (
+        _initialize_result,
+        _initialize_result_col_idx,
+        gather_results_kernel,
+    ) = _get_gather_results_kernels(
+        n_rows, n_cols, k, work_group_size, dtype, device, output, reuse_result_buffer
     )

-    result = dpt.empty((n_rows, k), dtype=np.int64, device=device)
-    result_col_idx = dpt.zeros((n_rows,), dtype=np.int32, device=device)
+    def _get_topk(array_in):
+        if is_1d:
+            array_in = dpt.reshape(array_in, (1, -1))
+
+        (
+            threshold,
+            n_threshold_occurences_in_topk,
+            n_threshold_occurences_in_data,
+        ) = get_topk_threshold(array_in)
+
+        result_col_idx = _initialize_result_col_idx()
+        result = _initialize_result(array_in.dtype.type)
+
+        gather_results_kernel(
+            array_in,
+            threshold,
+            n_threshold_occurences_in_topk,
+            n_threshold_occurences_in_data,
+            result_col_idx,
+            # OUT
+            result,
+        )

-    gather_topk_idx_kernel(
-        array_in,
-        threshold,
-        n_threshold_occurences_in_topk,
-        n_threshold_occurences_in_data,
-        result_col_idx,
-        result,
-    )
+        if is_1d:
+            return dpt.reshape(result, (-1,))

-    if is_1d:
-        return dpt.reshape(result, (-1,))
+        return result
+
+    return _get_topk
+
+
+@lru_cache
+def _get_gather_results_kernels(
+    n_rows, n_cols, k, work_group_size, dtype, device, output, reuse_result_buffer
+):
+    if output == "values":
+        gather_results_kernel = _make_gather_topk_kernel(
+            n_rows,
+            n_cols,
+            k,
+            work_group_size,
+        )
+        if reuse_result_buffer:
+            result = dpt.empty((n_rows, k), dtype=dtype, device=device)
+
+            def _initialize_result(dtype):
+                return result
+
+        else:
+
+            def _initialize_result(dtype):
+                return dpt.empty((n_rows, k), dtype=dtype, device=device)

-    return result
+    elif output == "idx":
+        gather_results_kernel = _make_gather_topk_idx_kernel(
+            n_rows,
+            n_cols,
+            k,
+            work_group_size,
+        )
+        if reuse_result_buffer:
+            result = dpt.empty((n_rows, k), dtype=np.int64, device=device)
+
+            def _initialize_result(dtype):
+                return result
+
+        else:
+
+            def _initialize_result(dtype):
+                return dpt.empty((n_rows, k), dtype=np.int64, device=device)
+
+    elif output == "values+idx":
+        raise NotImplementedError

-def _get_topk_threshold(array_in, k, group_sizes):
-    n_rows, n_cols = array_in.shape
+    else:
+        raise ValueError(
+            'Expected output parameter value to be equal to "values", "idx" or '
+            f'"values+idx", but got {output} instead.'
+        )
+
+    # `result_col_idx` is used to maintain an atomically incremented index of the next
+    # result value to be stored.
+    # Note that the ordering of the topk is non-deterministic and depends on the
+    # concurrency of the parallel work items.
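The comment above describes the atomic cursor that makes the gathering step single-pass. Below is a hedged, standalone sketch of that pattern, written against the same legacy numba_dpex kernel API used throughout this module; the kernel name and sizes are hypothetical, and it assumes `dpex.atomic.add` returns the value stored before the increment, as it is used elsewhere in this repository:

import numba_dpex as dpex
import numpy as np

zero_idx = np.int64(0)
count_one = np.int32(1)
n_items = 1024  # known at kernel-build time, like in the factories above


@dpex.kernel
def gather_positive_kernel(data, result_col_idx, result):
    item_idx = dpex.get_global_id(zero_idx)

    if item_idx >= n_items:
        return

    value = data[item_idx]
    if value > 0:
        # The previous value of the cursor is this work item's private slot.
        # Slots are claimed in whatever order the items happen to run, hence
        # the non-deterministic ordering of the output.
        col_idx = dpex.atomic.add(result_col_idx, zero_idx, count_one)
        result[col_idx] = value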
+    result_col_idx = dpt.empty((n_rows,), dtype=np.int32, device=device)
+    initialize_result_col_idx_kernel = make_fill_kernel(
+        fill_value=0,
+        shape=(n_rows,),
+        work_group_size=work_group_size,
+        dtype=np.int32,
+    )
+
+    def _initialize_result_col_idx():
+        initialize_result_col_idx_kernel(result_col_idx)
+        return result_col_idx
+
+    return _initialize_result, _initialize_result_col_idx, gather_results_kernel
+
+
+@lru_cache
+def _make_get_topk_threshold_kernel(n_rows, n_cols, k, dtype, device, group_sizes):
     if n_cols < k:
         raise ValueError(
             "Expected k to be less than or equal to the number of items in the "
             f"search space, but got k={k} and {n_cols} items in the search space."
         )

-    dtype = np.dtype(array_in.dtype).type
     if dtype not in uint_type_mapping:
         raise ValueError(
             f"topk currently only supports dtypes in {uint_type_mapping.keys()}, but "
@@ -288,8 +355,6 @@
     uint_type = uint_type_mapping[dtype]
     n_bits_per_item = _get_n_bits_per_item(dtype)

-    device = array_in.device.sycl_device
-
     if group_sizes is not None:
         work_group_size, sub_group_size = group_sizes
     else:
@@ -315,24 +380,6 @@
         device,
     )

-    # This kernel can only reduce 1d or 2d matrices but will be used to reduce the 3d
-    # matrix of private counts over axis 0. It is made possible by adequatly reshaping
-    # the 3d matrix before and after the kernel call to a 2d matrix.
-    n_rows_x_radix_size = n_rows * radix_size
-
-    reduce_privatized_counts = make_sum_reduction_2d_kernel(
-        shape=(n_counts_private_copies, n_rows_x_radix_size),
-        device=device,
-        dtype=np.int64,
-        work_group_size="max",
-        axis=0,
-        sub_group_size=sub_group_size,
-    )
-
-    initialize_privatized_counts = make_initialize_to_zeros_kernel(
-        (n_counts_private_copies, n_rows, radix_size), work_group_size, dtype
-    )
-
     # The kernel `check_radix_histogram` seems to be more adapted to cpu or gpu
     # depending on whether `n_rows` is large enough to leverage enough of the
     # parallelization capabilities of the gpu.
@@ -341,8 +388,8 @@
     # Some other steps in the main loop are more fitted for cpu than gpu.

     # For this purpose the following variables check availability of a cpu and whether
-    # a data transfer is required.
-
+    # a data transfer is required when checking the radix histogram and when updating
+    # the radix filtering variables.
     (check_radix_histogram_device, check_radix_histogram_on_sequential_device,) = (
         sequential_processing_device,
         sequential_processing_on_different_device,
@@ -368,9 +415,12 @@

     # In each iteration of the main loop, a decreasing number of top values is
     # searched for in a decreasing subset of data. The following variable records the
-    # amount of top values to search for at the given iteration.
-    k_in_subset = dpt.full(
-        n_rows, k, dtype=np.int32, device=check_radix_histogram_device
+    # amount of top values to search for at the given iteration (starting at k).
+    k_in_subset_ = dpt.empty(
+        n_rows, dtype=np.int32, device=check_radix_histogram_device
+    )
+    initialize_k_in_subset_kernel = make_fill_kernel(
+        fill_value=k, shape=(n_rows,), work_group_size=work_group_size, dtype=np.int32
     )

     # Depending on the data, it's possible that the search early stops before having to
@@ -380,151 +430,251 @@

     # If `n_rows > 1`, the search might terminate sooner in some rows than others.
# Let's call "active rows" the rows for which the search is still ongoing at the # current iteration. Rows that are not "active rows" are finished searching and are - # waiting for search in other rows to complete. + # waiting for the search in other rows to complete. - # Number of currently active rows - n_active_rows_ = n_rows - n_active_rows = dpt.asarray( - [n_active_rows_], dtype=np.int64, device=check_radix_histogram_device - ) - # Buffer to store the number of active rows in the next iteration - new_n_active_rows = dpt.asarray( - [0], dtype=np.int64, device=check_radix_histogram_device + # Memory allocation for the number of currently active rows + n_active_rows_ = dpt.empty(1, dtype=np.int64, device=check_radix_histogram_device) + # Memory allocation for the number of currently active rows in the next iteration + new_n_active_rows_ = dpt.empty( + 1, dtype=np.int64, device=check_radix_histogram_device ) - # List of indexes of currently active rows (at a given iteration, only slots from - # 0 to `n_active_rows` are used) - active_rows_mapping = dpt.arange(n_rows, dtype=np.int64, device=device) - # Buffer to store the mapping that will be used in the next iteration - new_active_rows_mapping = dpt.zeros( - n_rows, dtype=np.int64, device=check_radix_histogram_device - ) + # NB: The following parameters might be transfered back and forth in memories of + # different devices. In this case, pre-allocation is detrimental. Depending on + # wether two different devices are used for different steps or not, their + # respective memory buffers are cached, or re-allocated at every call. - # Position of the radix that is currently used for sorting data - radix_position = dpt.asarray( - [n_bits_per_item - radix_bits], dtype=uint_type, device=device - ) + # TODO: it would be useful if `dpctl` supports copying a buffer from a given device + # to a pre-allocated array of another device. It is currently not supported. + + # `active_rows_mapping` and `new_active_rows_mapping` are arrays that hold a list + # of the indexes of currently active rows, at the current iteration, and at the + # next iteration + # NB: at a given iteration, only slots from 0 to `n_active_rows` are used + + # `desired_mask_value` is an array that defines the mask value to search for when + # filtering data. 
+    if check_radix_histogram_on_sequential_device:  # no caching of buffers
+
+        def initialize_active_rows_mapping():
+            active_rows_mapping = dpt.arange(n_rows, dtype=np.int64, device=device)
+            new_active_rows_mapping = dpt.zeros(
+                n_rows, dtype=np.int64, device=check_radix_histogram_device
+            )
+            return active_rows_mapping, new_active_rows_mapping
+
+        def initialize_desired_masked_value():
+            return dpt.zeros((n_rows,), dtype=uint_type, device=device)
+
+    else:  # use caching
+        active_rows_mapping_ = dpt.empty(
+            n_rows, dtype=np.int64, device=check_radix_histogram_device
+        )
+        new_active_rows_mapping_ = dpt.empty(
+            n_rows, dtype=np.int64, device=check_radix_histogram_device
+        )
+        initialize_active_rows_mapping_kernel = make_range_kernel(
+            n_rows, work_group_size
+        )
+
+        def initialize_active_rows_mapping():
+            initialize_active_rows_mapping_kernel(active_rows_mapping_)
+            return active_rows_mapping_, new_active_rows_mapping_

-    # mask and value used to filter the subset of the data that is currently searched
-    # at the given iteration
-    mask_for_desired_value = dpt.zeros((1,), dtype=uint_type, device=device)
-    desired_masked_value = dpt.zeros((n_rows,), dtype=uint_type, device=device)
+        desired_masked_value_ = dpt.empty((n_rows,), dtype=uint_type, device=device)
+        initialize_desired_masked_value_kernel = make_fill_kernel(
+            fill_value=0,
+            shape=(n_rows,),
+            work_group_size=work_group_size,
+            dtype=uint_type,
+        )

-    # Buffer that stores the counts of occurences of the values at the current radix
-    # position.
-    privatized_counts = dpt.zeros(
+        def initialize_desired_masked_value():
+            initialize_desired_masked_value_kernel(desired_masked_value_)
+            return desired_masked_value_
+
+    # `radix_position` holds the position of the radix that is currently used for
+    # sorting data at a given iteration.
+    # `mask_for_desired_value` defines, along with `desired_masked_value`, the
+    # condition that is applied at each iteration to skip data that has already been
+    # filtered out by sorting on previous radixes. The remaining data is the data that
+    # will be scanned at the current iteration.
+    if sequential_processing_on_different_device:  # no caching
+
+        def initialize_radix_mask():
+            radix_position_ = dpt.asarray(
+                [n_bits_per_item - radix_bits], dtype=uint_type, device=device
+            )
+            mask_for_desired_value_ = dpt.zeros((1,), dtype=uint_type, device=device)
+            return radix_position_, mask_for_desired_value_
+
+    else:  # use caching
+        radix_position_ = dpt.empty(1, dtype=uint_type, device=device)
+        mask_for_desired_value_ = dpt.empty((1,), dtype=uint_type, device=device)
+
+        def initialize_radix_mask():
+            radix_position_[0] = n_bits_per_item - radix_bits
+            mask_for_desired_value_[0] = 0
+            return radix_position_, mask_for_desired_value_
+
+    # `privatized_counts` stores the counts of occurrences of the values at the
+    # current radix position. Several copies of it are created, to avoid conflicts on
+    # atomics during the count update step. The copies are reduced by the reduction
+    # kernel defined thereafter.
+    privatized_counts = dpt.empty(
         (n_counts_private_copies, n_rows, radix_size), dtype=np.int64, device=device
     )
+    # Our sum reduction kernel can only reduce 1d or 2d matrices but will be used to
+    # reduce the 3d matrix of private counts over axis 0. It is made possible by
+    # adequately reshaping the 3d matrix before and after the kernel call to a 2d
+    # matrix.
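Since the reduction kernel only supports 1d and 2d operands, the reshaping trick described in the comment above is what makes it applicable here. A NumPy analogue of that trick, with illustrative sizes (this is not the kernel itself):

import numpy as np

n_copies, n_rows, radix_size = 4, 3, 16
privatized_counts = np.random.randint(0, 10, size=(n_copies, n_rows, radix_size))

# View the 3d stack of private copies as 2d, reduce over the copies axis,
# then view the result back as one (n_rows, radix_size) histogram.
flat = privatized_counts.reshape(n_copies, n_rows * radix_size)
counts = flat.sum(axis=0).reshape(n_rows, radix_size)

assert np.array_equal(counts, privatized_counts.sum(axis=0))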
+    n_rows_x_radix_size = n_rows * radix_size
+    reduce_privatized_counts = make_sum_reduction_2d_kernel(
+        shape=(n_counts_private_copies, n_rows_x_radix_size),
+        device=device,
+        dtype=np.int64,
+        work_group_size="max",
+        axis=0,
+        sub_group_size=sub_group_size,
+    )
+    initialize_privatized_counts = make_fill_kernel(
+        fill_value=0,
+        shape=(n_counts_private_copies, n_rows, radix_size),
+        work_group_size=work_group_size,
+        dtype=dtype,
+    )

     # Will store the number of occurrences of the top-k threshold value in the data
-    threshold_count = dpt.zeros(
+    threshold_count_ = dpt.empty(
         (n_rows,), dtype=np.int64, device=check_radix_histogram_device
     )
-
-    # Reinterpret buffer as uint so we can use bitwise compute
-    array_in_uint = dpt.usm_ndarray(
-        shape=(n_rows, n_cols),
-        dtype=uint_type,
-        buffer=array_in,
+    initialize_threshold_count_kernel = make_fill_kernel(
+        fill_value=0, shape=(n_rows,), work_group_size=work_group_size, dtype=np.int64
     )

-    # The main loop: each iteration consists in sorting partially the data on the
-    # values of a given radix of size `radix_size`, then discarding values that are
-    # below the top k values.
-    while True:
-        create_radix_histogram_kernel(
-            array_in_uint,
+    def _get_topk_threshold(array_in):
+        # Use variables that are local to the closure, so they can be manipulated
+        # more easily in the main loop
+        k_in_subset, n_active_rows, new_n_active_rows, threshold_count = (
+            k_in_subset_,
             n_active_rows_,
-            active_rows_mapping,
-            mask_for_desired_value,
-            desired_masked_value,
-            radix_position,
-            # OUT
-            privatized_counts,
+            new_n_active_rows_,
+            threshold_count_,
         )

-        privatized_counts_ = dpt.reshape(
-            privatized_counts, (n_counts_private_copies, n_rows_x_radix_size)
+        # Initialize all buffers
+        initialize_k_in_subset_kernel(k_in_subset)
+        n_active_rows[0] = n_active_rows_scalar = n_rows
+        initialize_threshold_count_kernel(threshold_count)
+        active_rows_mapping, new_active_rows_mapping = initialize_active_rows_mapping()
+        desired_masked_value = initialize_desired_masked_value()
+        radix_position, mask_for_desired_value = initialize_radix_mask()
+
+        # Reinterpret the input as uint so we can use bitwise operations
+        array_in_uint = dpt.usm_ndarray(
+            shape=(n_rows, n_cols),
+            dtype=uint_type,
+            buffer=array_in,
         )

-        counts = dpt.reshape(
-            reduce_privatized_counts(privatized_counts_), (n_rows, radix_size)
-        )
+        # The main loop: each iteration partially sorts the data on the values of a
+        # given radix of size `radix_size`, then discards values that fall below the
+        # top k values.
+        while True:
+            initialize_privatized_counts(privatized_counts)
+
+            create_radix_histogram_kernel(
+                array_in_uint,
+                n_active_rows_scalar,
+                active_rows_mapping,
+                mask_for_desired_value,
+                desired_masked_value,
+                radix_position,
+                # OUT
+                privatized_counts,
+            )

-        if check_radix_histogram_on_sequential_device:
-            counts = counts.to_device(check_radix_histogram_device)
-            mask_for_desired_value = mask_for_desired_value.to_device(
-                check_radix_histogram_device
+            privatized_counts_ = dpt.reshape(
+                privatized_counts, (n_counts_private_copies, n_rows_x_radix_size)
             )
-            desired_masked_value = desired_masked_value.to_device(
-                check_radix_histogram_device
+
+            counts = dpt.reshape(
+                reduce_privatized_counts(privatized_counts_), (n_rows, radix_size)
             )
-            radix_position = radix_position.to_device(check_radix_histogram_device)
-            active_rows_mapping = active_rows_mapping.to_device(
-                check_radix_histogram_device
+
+            if check_radix_histogram_on_sequential_device:
+                counts = counts.to_device(check_radix_histogram_device)
+                desired_masked_value = desired_masked_value.to_device(
+                    check_radix_histogram_device
+                )
+                radix_position = radix_position.to_device(check_radix_histogram_device)
+                active_rows_mapping = active_rows_mapping.to_device(
+                    check_radix_histogram_device
+                )
+
+            new_n_active_rows[0] = 0
+
+            check_radix_histogram(
+                counts,
+                radix_position,
+                n_active_rows,
+                active_rows_mapping,
+                # INOUT
+                k_in_subset,
+                desired_masked_value,
+                # OUT
+                threshold_count,
+                new_n_active_rows,
+                new_active_rows_mapping,
             )

-        check_radix_histogram(
-            counts,
-            radix_position,
-            n_active_rows,
-            active_rows_mapping,
-            # INOUT
-            k_in_subset,
-            desired_masked_value,
-            # OUT
-            threshold_count,
-            new_n_active_rows,
-            new_active_rows_mapping,
-        )
+            # If the top k values have been found in all rows, we can exit early.
+            if (n_active_rows_scalar := int(new_n_active_rows[0])) == 0:
+                break
+
+            # Else, update `radix_position` and continue searching using the next
+            # radix
+            if sequential_processing_on_different_device:
+                mask_for_desired_value = mask_for_desired_value.to_device(
+                    sequential_processing_device
+                )
+
+            if change_device_for_radix_update:
+                radix_position = radix_position.to_device(sequential_processing_device)

-        # If the top k values have been found in all rows, can exit early.
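To summarize what the device loop above computes, here is a sequential NumPy analogue of the radix search for a single row of unsigned integer keys. This is a hedged sketch: the real kernels additionally handle the monotonic float-to-uint reinterpretation, multiple rows, early stopping and device transfers, and `radix_bits` is assumed to divide the key width:

import numpy as np


def topk_threshold_radix(keys, k, radix_bits=4):
    """Return the k-th greatest value of the unsigned integer array `keys`."""
    n_bits = keys.dtype.itemsize * 8
    radix_size = 1 << radix_bits
    k_in_subset = k
    prefix = 0  # plays the role of `desired_masked_value`
    mask = 0  # plays the role of `mask_for_desired_value`

    for radix_position in range(n_bits - radix_bits, -1, -radix_bits):
        # Keep only the values that match the prefix selected so far.
        subset = keys[(keys & mask) == prefix]
        digits = (subset >> radix_position) & (radix_size - 1)
        counts = np.bincount(digits.astype(np.intp), minlength=radix_size)
        # Scan the histogram from the highest digit down to find the bucket
        # that contains the k-th greatest value, and narrow the prefix to it.
        for digit in range(radix_size - 1, -1, -1):
            if k_in_subset <= counts[digit]:
                prefix |= digit << radix_position
                break
            k_in_subset -= counts[digit]
        mask |= (radix_size - 1) << radix_position

    return keys.dtype.type(prefix)


keys = np.array([7, 42, 13, 42, 5], dtype=np.uint32)
assert topk_threshold_radix(keys, k=2) == 42  # the two greatest values are both 42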
-        if (n_active_rows_ := int(new_n_active_rows[0])) == 0:
-            break
+            update_radix_position(radix_position, mask_for_desired_value)

-        # Else, update `radix_position` continue searching using the next radix
-        if change_device_for_radix_update:
-            radix_position = radix_position.to_device(sequential_processing_device)
-            mask_for_desired_value = mask_for_desired_value.to_device(
-                sequential_processing_device
+            if sequential_processing_on_different_device:
+                radix_position = radix_position.to_device(device)
+                mask_for_desired_value = mask_for_desired_value.to_device(device)
+
+            # Prepare next iteration
+            n_active_rows, new_n_active_rows = new_n_active_rows, n_active_rows
+            new_n_active_rows[:] = 0
+            active_rows_mapping, new_active_rows_mapping = (
+                new_active_rows_mapping,
+                active_rows_mapping,
             )

-        update_radix_position(radix_position, mask_for_desired_value)
-
-        if change_device_for_radix_update or check_radix_histogram_on_sequential_device:
-            radix_position = radix_position.to_device(device)
-            mask_for_desired_value = mask_for_desired_value.to_device(device)
-
-        # Prepare next iteration
-        n_active_rows, new_n_active_rows = new_n_active_rows, n_active_rows
-        new_n_active_rows[:] = 0
-        active_rows_mapping, new_active_rows_mapping = (
-            new_active_rows_mapping,
-            active_rows_mapping,
-        )
+            if check_radix_histogram_on_sequential_device:
+                desired_masked_value = desired_masked_value.to_device(device)
+                active_rows_mapping = active_rows_mapping.to_device(device)
+
+        # Ensure data is located on the expected device before returning
         if check_radix_histogram_on_sequential_device:
+            k_in_subset = k_in_subset.to_device(device)
+            threshold_count = threshold_count.to_device(device)
             desired_masked_value = desired_masked_value.to_device(device)
-            active_rows_mapping = active_rows_mapping.to_device(device)
-
-        initialize_privatized_counts(privatized_counts)
-
-    # Ensure data is located on the expected device before returning
-    if check_radix_histogram_on_sequential_device:
-        k_in_subset = k_in_subset.to_device(device)
-        threshold_count = threshold_count.to_device(device)
-        desired_masked_value = desired_masked_value.to_device(device)
+        # Reinterpret the threshold back to an item of the original dtype
+        threshold = dpt.usm_ndarray(
+            shape=desired_masked_value.shape, dtype=dtype, buffer=desired_masked_value
         )
-
-    # reinterpret the threshold back to a dtype item
-    threshold = dpt.usm_ndarray(
-        shape=desired_masked_value.shape, dtype=dtype, buffer=desired_masked_value
-    )
+
+        return (
+            threshold,
+            k_in_subset,
+            threshold_count,
+        )

-    return (
-        threshold,
-        k_in_subset,
-        threshold_count,
-        work_group_size,
-        dtype,
-        device,
-    )
+    return work_group_size, _get_topk_threshold


 def _get_n_bits_per_item(dtype):
diff --git a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py
index 2bf3931..f40a027 100644
--- a/sklearn_numba_dpex/kmeans/drivers.py
+++ b/sklearn_numba_dpex/kmeans/drivers.py
@@ -12,8 +12,8 @@
     make_apply_elementwise_func,
     make_broadcast_division_1d_2d_axis0_kernel,
     make_broadcast_ops_1d_2d_axis1_kernel,
+    make_fill_kernel,
     make_half_l2_norm_2d_axis0_kernel,
-    make_initialize_to_zeros_kernel,
 )
 from sklearn_numba_dpex.common.random import (
     create_xoroshiro128pp_states,
@@ -91,13 +91,15 @@ def lloyd(
         n_samples, n_features, max_work_group_size, compute_dtype
     )

-    reset_cluster_sizes_private_copies_kernel = make_initialize_to_zeros_kernel(
+    reset_cluster_sizes_private_copies_kernel = make_fill_kernel(
+        fill_value=0,
         shape=(n_centroids_private_copies, n_clusters),
         work_group_size=max_work_group_size,
         dtype=compute_dtype,
     )

-    reset_centroids_private_copies_kernel = make_initialize_to_zeros_kernel(
+    reset_centroids_private_copies_kernel = make_fill_kernel(
+        fill_value=0,
         shape=(n_centroids_private_copies, n_features, n_clusters),
         work_group_size=max_work_group_size,
         dtype=compute_dtype,
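One equivalence worth noting for these last hunks: with `fill_value=0`, the new factory reproduces the removed zero-initializer, so the driver call sites are pure renames. A hedged sketch (the buffer shape and work-group size are made up):

import numpy as np

from sklearn_numba_dpex.common.kernels import make_fill_kernel

reset = make_fill_kernel(
    fill_value=0,
    shape=(4, 8),
    work_group_size=64,
    dtype=np.float32,
)
# `reset(buffer)` zeroes `buffer` in place, exactly as the removed
# `make_initialize_to_zeros_kernel((4, 8), 64, np.float32)(buffer)` did.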