From 4b41485abf6de40bd7171f2be6b8a79fa9426fb2 Mon Sep 17 00:00:00 2001 From: RakshitKumar04 Date: Sat, 6 May 2023 01:34:43 +0530 Subject: [PATCH 1/4] Added Cumprod Instance method to Jax NumPy Frontend --- ivy/functional/frontends/jax/devicearray.py | 8 +++++ .../test_jax/test_jax_devicearray.py | 36 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/ivy/functional/frontends/jax/devicearray.py b/ivy/functional/frontends/jax/devicearray.py index cecb0246f5918..0bad42fe19a57 100644 --- a/ivy/functional/frontends/jax/devicearray.py +++ b/ivy/functional/frontends/jax/devicearray.py @@ -78,6 +78,14 @@ def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, where=None): keepdims=keepdims, where=where, ) + + def cumprod(self, axis=None, dtype=None, out=None): + return jax_frontend.numpy.cumsum( + self._ivy_array, + axis=axis, + dtype=dtype, + out=out, + ) def __add__(self, other): return jax_frontend.numpy.add(self, other) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py index 4996f787d9a6d..6e731d0e93043 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py @@ -240,6 +240,42 @@ def test_jax_devicearray_mean( ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="jax.numpy.array", + method_name="cumprod", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("integer"), + force_int_axis=True, + valid_axis=True, + ), +) +def test_jax_devicearray_cumprod( + dtype_and_x, + on_device, + frontend, + frontend_method_data, + init_flags, + method_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", From 9bc5ffe8717908d83f68e440eb4b17e5583f123e Mon Sep 17 00:00:00 2001 From: RakshitKumar04 Date: Sat, 6 May 2023 22:15:51 +0530 Subject: [PATCH 2/4] Edited the code --- ivy/functional/frontends/jax/devicearray.py | 2 ++ .../test_frontends/test_jax/test_jax_devicearray.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/jax/devicearray.py b/ivy/functional/frontends/jax/devicearray.py index 0bad42fe19a57..10f9984e27451 100644 --- a/ivy/functional/frontends/jax/devicearray.py +++ b/ivy/functional/frontends/jax/devicearray.py @@ -80,6 +80,8 @@ def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, where=None): ) def cumprod(self, axis=None, dtype=None, out=None): + if dtype is None: + dtype = ivy.as_ivy_dtype(self.dtype) return jax_frontend.numpy.cumsum( self._ivy_array, axis=axis, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py index 6e731d0e93043..d267af87d364c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_devicearray.py @@ -245,9 +245,15 @@ def test_jax_devicearray_mean( init_tree="jax.numpy.array", method_name="cumprod", dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("integer"), - force_int_axis=True, + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_value=-100, + max_value=100, valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, ), ) def test_jax_devicearray_cumprod( From 8eeefadd3e1c5eeb8f6a02dbe4a5205445f7b818 Mon Sep 17 00:00:00 2001 From: RakshitKumar04 Date: Mon, 8 May 2023 21:40:15 +0530 Subject: [PATCH 3/4] Updated the implementation --- ivy/functional/frontends/jax/devicearray.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ivy/functional/frontends/jax/devicearray.py b/ivy/functional/frontends/jax/devicearray.py index 10f9984e27451..d2107c5fd72e6 100644 --- a/ivy/functional/frontends/jax/devicearray.py +++ b/ivy/functional/frontends/jax/devicearray.py @@ -80,9 +80,7 @@ def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, where=None): ) def cumprod(self, axis=None, dtype=None, out=None): - if dtype is None: - dtype = ivy.as_ivy_dtype(self.dtype) - return jax_frontend.numpy.cumsum( + return jax_frontend.numpy.cumprod( self._ivy_array, axis=axis, dtype=dtype, From 0b546d69da7fe36cf9e3ae96eb0df1121255ef46 Mon Sep 17 00:00:00 2001 From: RakshitKumar04 Date: Mon, 8 May 2023 21:51:35 +0530 Subject: [PATCH 4/4] Re-edited the implemetation --- ivy/functional/frontends/jax/devicearray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/devicearray.py b/ivy/functional/frontends/jax/devicearray.py index d2107c5fd72e6..b0fbb49e1fddd 100644 --- a/ivy/functional/frontends/jax/devicearray.py +++ b/ivy/functional/frontends/jax/devicearray.py @@ -81,7 +81,7 @@ def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, where=None): def cumprod(self, axis=None, dtype=None, out=None): return jax_frontend.numpy.cumprod( - self._ivy_array, + self, axis=axis, dtype=dtype, out=out,