From 7b3ae548154625a56a6dbdb03bc07d4f599c37b2 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Wed, 19 Oct 2022 09:48:36 -0700 Subject: [PATCH] Support cast_to_${x} on numpy arrays. PiperOrigin-RevId: 482230083 --- jmp/_src/BUILD | 5 ++++- jmp/_src/policy.py | 4 +++- jmp/_src/policy_test.py | 18 +++++++++++------- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/jmp/_src/BUILD b/jmp/_src/BUILD index 4a36ab1..abf2c3b 100644 --- a/jmp/_src/BUILD +++ b/jmp/_src/BUILD @@ -32,7 +32,10 @@ jmp_py_library( name = "policy", srcs = ["policy.py"], srcs_version = "PY3", - deps = ["//third_party/py/jax"], + deps = [ + # pip: jax + # pip: numpy + ], ) jmp_py_test( diff --git a/jmp/_src/policy.py b/jmp/_src/policy.py index fb32dc3..8b7afed 100644 --- a/jmp/_src/policy.py +++ b/jmp/_src/policy.py @@ -19,13 +19,15 @@ import jax import jax.numpy as jnp +import numpy as np T = TypeVar("T") def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T: def conditional_cast(x): - if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating): + if (isinstance(x, (np.ndarray, jnp.ndarray)) and + jnp.issubdtype(x.dtype, jnp.floating)): x = x.astype(dtype) return x return jax.tree_util.tree_map(conditional_cast, tree) diff --git a/jmp/_src/policy_test.py b/jmp/_src/policy_test.py index 34a880c..d70a267 100644 --- a/jmp/_src/policy_test.py +++ b/jmp/_src/policy_test.py @@ -27,7 +27,7 @@ HALF_DTYPES = (np.float16, jnp.float16, jnp.bfloat16) FULL_DTYPES = (np.float32, jnp.float32) DTYPES = HALF_DTYPES + FULL_DTYPES -NUMPYS = np, jnp +NUMPYS = (np, jnp) def get_dtype_name(dtype): @@ -53,14 +53,18 @@ def skip_if_unsupported(dtype): class PolicyTest(parameterized.TestCase): + def assert_dtypes_equal(self, tree_a, tree_b): + jax.tree_map(lambda a, b: self.assertEqual(a.dtype, b.dtype), tree_a, + tree_b) + @parameterized.parameters(*it.product(DTYPES, NUMPYS)) def test_policy_cast_to_param(self, dtype, np_): skip_if_unsupported(dtype) policy = jmp.Policy(dtype, dtype, dtype) self.assertEqual(policy.param_dtype, dtype) tree = {"a": np_.ones([])} - self.assertEqual(policy.cast_to_param(tree), - {"a": np_.ones([], dtype)}) + self.assert_dtypes_equal(policy.cast_to_param(tree), + {"a": np_.ones([], dtype)}) @parameterized.parameters(*it.product(DTYPES, NUMPYS)) def test_policy_cast_to_compute(self, dtype, np_): @@ -68,8 +72,8 @@ def test_policy_cast_to_compute(self, dtype, np_): policy = jmp.Policy(dtype, dtype, dtype) self.assertEqual(policy.compute_dtype, dtype) tree = {"a": np_.ones([])} - self.assertEqual(policy.cast_to_compute(tree), - {"a": np_.ones([], dtype)}) + self.assert_dtypes_equal(policy.cast_to_compute(tree), + {"a": np_.ones([], dtype)}) @parameterized.parameters(*it.product(DTYPES, NUMPYS)) def test_policy_cast_to_output(self, dtype, np_): @@ -77,8 +81,8 @@ def test_policy_cast_to_output(self, dtype, np_): policy = jmp.Policy(dtype, dtype, dtype) self.assertEqual(policy.output_dtype, dtype) tree = {"a": np_.ones([])} - self.assertEqual(policy.cast_to_output(tree), - {"a": np_.ones([], dtype)}) + self.assert_dtypes_equal(policy.cast_to_output(tree), + {"a": np_.ones([], dtype)}) @parameterized.parameters(*it.product(DTYPES, NUMPYS)) def test_policy_with_output_dtype(self, dtype, np_):