Skip to content

Commit

Permalink
Support cast_to_${x} on numpy arrays.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 482230083
  • Loading branch information
tomhennigan authored and copybara-github committed Oct 19, 2022
1 parent e64e7fb commit 7b3ae54
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
5 changes: 4 additions & 1 deletion jmp/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ jmp_py_library(
name = "policy",
srcs = ["policy.py"],
srcs_version = "PY3",
deps = ["//third_party/py/jax"],
deps = [
# pip: jax
# pip: numpy
],
)

jmp_py_test(
Expand Down
4 changes: 3 additions & 1 deletion jmp/_src/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

import jax
import jax.numpy as jnp
import numpy as np

T = TypeVar("T")


def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T:
def conditional_cast(x):
if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
if (isinstance(x, (np.ndarray, jnp.ndarray)) and
jnp.issubdtype(x.dtype, jnp.floating)):
x = x.astype(dtype)
return x
return jax.tree_util.tree_map(conditional_cast, tree)
Expand Down
18 changes: 11 additions & 7 deletions jmp/_src/policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
HALF_DTYPES = (np.float16, jnp.float16, jnp.bfloat16)
FULL_DTYPES = (np.float32, jnp.float32)
DTYPES = HALF_DTYPES + FULL_DTYPES
NUMPYS = np, jnp
NUMPYS = (np, jnp)


def get_dtype_name(dtype):
Expand All @@ -53,32 +53,36 @@ def skip_if_unsupported(dtype):

class PolicyTest(parameterized.TestCase):

def assert_dtypes_equal(self, tree_a, tree_b):
jax.tree_map(lambda a, b: self.assertEqual(a.dtype, b.dtype), tree_a,
tree_b)

@parameterized.parameters(*it.product(DTYPES, NUMPYS))
def test_policy_cast_to_param(self, dtype, np_):
skip_if_unsupported(dtype)
policy = jmp.Policy(dtype, dtype, dtype)
self.assertEqual(policy.param_dtype, dtype)
tree = {"a": np_.ones([])}
self.assertEqual(policy.cast_to_param(tree),
{"a": np_.ones([], dtype)})
self.assert_dtypes_equal(policy.cast_to_param(tree),
{"a": np_.ones([], dtype)})

@parameterized.parameters(*it.product(DTYPES, NUMPYS))
def test_policy_cast_to_compute(self, dtype, np_):
skip_if_unsupported(dtype)
policy = jmp.Policy(dtype, dtype, dtype)
self.assertEqual(policy.compute_dtype, dtype)
tree = {"a": np_.ones([])}
self.assertEqual(policy.cast_to_compute(tree),
{"a": np_.ones([], dtype)})
self.assert_dtypes_equal(policy.cast_to_compute(tree),
{"a": np_.ones([], dtype)})

@parameterized.parameters(*it.product(DTYPES, NUMPYS))
def test_policy_cast_to_output(self, dtype, np_):
skip_if_unsupported(dtype)
policy = jmp.Policy(dtype, dtype, dtype)
self.assertEqual(policy.output_dtype, dtype)
tree = {"a": np_.ones([])}
self.assertEqual(policy.cast_to_output(tree),
{"a": np_.ones([], dtype)})
self.assert_dtypes_equal(policy.cast_to_output(tree),
{"a": np_.ones([], dtype)})

@parameterized.parameters(*it.product(DTYPES, NUMPYS))
def test_policy_with_output_dtype(self, dtype, np_):
Expand Down

0 comments on commit 7b3ae54

Please sign in to comment.