Skip to content

Commit

Permalink
Merge pull request #39 from tylerjereddy/treddy_2D_reciprocal
Browse files Browse the repository at this point in the history
ENH: 2D reciprocal ufunc
  • Loading branch information
NaderAlAwar authored Jul 28, 2022
2 parents 1124b2d + 2c72b67 commit 6f6d199
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
20 changes: 18 additions & 2 deletions pykokkos/lib/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ def reciprocal_impl_1d_float(tid: int, view: pk.View1D[pk.float]):
view[tid] = 1 / view[tid] # type: ignore


@pk.workunit
def reciprocal_impl_2d_double(tid: int, view: pk.View2D[pk.double]):
for i in range(view.extent(1)): # type: ignore
view[tid][i] = 1 / view[tid][i] # type: ignore


@pk.workunit
def reciprocal_impl_2d_float(tid: int, view: pk.View2D[pk.float]):
for i in range(view.extent(1)): # type: ignore
view[tid][i] = 1 / view[tid][i] # type: ignore


def reciprocal(view):
"""
Return the reciprocal of the argument, element-wise.
Expand All @@ -33,10 +45,14 @@ def reciprocal(view):
"""
# see gh-29 for some discussion of the dispatching
# awkwardness used here
if str(view.dtype) == "DataType.double":
if str(view.dtype) == "DataType.double" and len(view.shape) == 1:
pk.parallel_for(view.shape[0], reciprocal_impl_1d_double, view=view)
elif str(view.dtype) == "DataType.float":
elif str(view.dtype) == "DataType.float" and len(view.shape) == 1:
pk.parallel_for(view.shape[0], reciprocal_impl_1d_float, view=view)
elif str(view.dtype) == "DataType.float" and len(view.shape) == 2:
pk.parallel_for(view.shape[0], reciprocal_impl_2d_float, view=view)
elif str(view.dtype) == "DataType.double" and len(view.shape) == 2:
pk.parallel_for(view.shape[0], reciprocal_impl_2d_double, view=view)
# NOTE: pretty awkward to both return the view
# and operate on it in place; the former is closer
# to NumPy semantics
Expand Down
22 changes: 22 additions & 0 deletions tests/test_ufuncs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pykokkos as pk

import numpy as np
from numpy.random import default_rng
from numpy.testing import assert_allclose
import pytest

Expand Down Expand Up @@ -349,3 +350,24 @@ def test_caching():
view[:] = np.arange(10, dtype=np.float32)
actual = pk.reciprocal(view=view)
assert_allclose(actual, expected)


@pytest.mark.parametrize("pk_ufunc, numpy_ufunc", [
(pk.reciprocal, np.reciprocal),
])
@pytest.mark.parametrize("pk_dtype, numpy_dtype", [
(pk.double, np.float64),
(pk.float, np.float32),
])
def test_2d_exposed_ufuncs_vs_numpy(pk_ufunc,
numpy_ufunc,
pk_dtype,
numpy_dtype):
rng = default_rng(123)
in_arr = rng.random((5, 5)).astype(numpy_dtype)
expected = numpy_ufunc(in_arr)

view: pk.View2d = pk.View([5, 5], pk_dtype)
view[:] = in_arr
actual = pk_ufunc(view=view)
assert_allclose(actual, expected)

0 comments on commit 6f6d199

Please sign in to comment.