Skip to content

Commit

Permalink
Get rid of block size
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Apr 24, 2024
1 parent b2b9dcf commit dd6e322
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions dpnp/dpnp_iface_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,12 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
else:
ntype = weights.dtype

# We set a block size, as this allows us to iterate over chunks when
# computing histograms, to minimize memory usage.
block_size = 65536

# The fast path uses bincount, but that only works for certain types
# of weight
# simple_weights = (
# weights is None or
# np.can_cast(weights.dtype, np.double) or
# np.can_cast(weights.dtype, complex)
# dpnp.can_cast(weights.dtype, dpnp.double) or
# dpnp.can_cast(weights.dtype, complex)
# )
# TODO: implement a fast path
simple_weights = False
Expand All @@ -317,24 +313,19 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
pass
else:
# Compute via cumulative histogram
cum_n = dpnp.zeros_like(bin_edges, dtype=ntype)
if weights is None:
for i in _range(0, len(a), block_size):
sa = dpnp.sort(a[i : i + block_size])
cum_n += _search_sorted_inclusive(sa, bin_edges)
sa = dpnp.sort(a)
cum_n = _search_sorted_inclusive(sa, bin_edges)
else:
zero = dpnp.zeros(
1, dtype=ntype, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
for i in _range(0, len(a), block_size):
tmp_a = a[i : i + block_size]
tmp_w = weights[i : i + block_size]
sorting_index = dpnp.argsort(tmp_a)
sa = tmp_a[sorting_index]
sw = tmp_w[sorting_index]
cw = dpnp.concatenate((zero, sw.cumsum(dtype=ntype)))
bin_index = _search_sorted_inclusive(sa, bin_edges)
cum_n += cw[bin_index]
sorting_index = dpnp.argsort(a)
sa = a[sorting_index]
sw = weights[sorting_index]
cw = dpnp.concatenate((zero, sw.cumsum(dtype=ntype)))
bin_index = _search_sorted_inclusive(sa, bin_edges)
cum_n = cw[bin_index]

n = dpnp.diff(cum_n)

Expand Down

0 comments on commit dd6e322

Please sign in to comment.