diff --git a/ivy/functional/backends/jax/experimental/linear_algebra.py b/ivy/functional/backends/jax/experimental/linear_algebra.py
index a4e19fe4ef816..de1800e93fcf5 100644
--- a/ivy/functional/backends/jax/experimental/linear_algebra.py
+++ b/ivy/functional/backends/jax/experimental/linear_algebra.py
@@ -159,3 +159,13 @@ def lu_factor(
     out: Optional[JaxArray] = None,
 ) -> Tuple[JaxArray]:
     raise IvyNotImplementedException()
+
+
+def dot(
+    a: JaxArray,
+    b: JaxArray,
+    /,
+    *,
+    out: Optional[JaxArray] = None,
+) -> JaxArray:
+    return jnp.dot(a, b)
diff --git a/ivy/functional/backends/mxnet/experimental/linear_algebra.py b/ivy/functional/backends/mxnet/experimental/linear_algebra.py
index 567081ba53254..a4b17e927e4a8 100644
--- a/ivy/functional/backends/mxnet/experimental/linear_algebra.py
+++ b/ivy/functional/backends/mxnet/experimental/linear_algebra.py
@@ -99,3 +99,15 @@ def cond(
     out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
 ) -> Union[(None, mx.ndarray.NDArray)]:
     raise IvyNotImplementedException()
+
+def dot(
+    a: mx.ndarray.NDArray,
+    b: mx.ndarray.NDArray,
+    /,
+    *,
+    out: Optional[mx.ndarray.NDArray] = None,
+) -> mx.ndarray.NDArray:
+    return mx.nd.dot(a, b, out=out)
+
+
+dot.support_native_out = True
diff --git a/ivy/functional/backends/numpy/experimental/linear_algebra.py b/ivy/functional/backends/numpy/experimental/linear_algebra.py
index ea2e117dd75fd..8f1399e4a10c6 100644
--- a/ivy/functional/backends/numpy/experimental/linear_algebra.py
+++ b/ivy/functional/backends/numpy/experimental/linear_algebra.py
@@ -184,3 +184,15 @@ def lu_factor(
     out: Optional[np.ndarray] = None,
 ) -> Tuple[np.ndarray]:
     raise IvyNotImplementedException()
+
+def dot(
+    a: np.ndarray,
+    b: np.ndarray,
+    /,
+    *,
+    out: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    return np.dot(a, b, out=out)
+
+
+dot.support_native_out = True
diff --git a/ivy/functional/backends/paddle/experimental/linear_algebra.py b/ivy/functional/backends/paddle/experimental/linear_algebra.py
index 75038b9446a79..899d0a3dfeb34 100644
--- a/ivy/functional/backends/paddle/experimental/linear_algebra.py
+++ b/ivy/functional/backends/paddle/experimental/linear_algebra.py
@@ -104,3 +104,12 @@ def lu_factor(
     out: Optional[paddle.Tensor] = None,
 ) -> Any:
     raise IvyNotImplementedException()
+
+def dot(
+    a: paddle.Tensor,
+    b: paddle.Tensor,
+    /,
+    *,
+    out: Optional[paddle.Tensor] = None,
+) -> paddle.Tensor:
+    return paddle.dot(a, b)
diff --git a/ivy/functional/backends/tensorflow/experimental/linear_algebra.py b/ivy/functional/backends/tensorflow/experimental/linear_algebra.py
index 1689a3142ec37..1ff2159d0669e 100644
--- a/ivy/functional/backends/tensorflow/experimental/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/experimental/linear_algebra.py
@@ -201,3 +201,13 @@ def lu_factor(
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Tuple[tf.Tensor]:
     raise IvyNotImplementedException()
+
+
+def dot(
+    a: tf.Tensor,
+    b: tf.Tensor,
+    /,
+    *,
+    out: Optional[tf.Tensor] = None,
+) -> tf.Tensor:
+    return tf.experimental.numpy.dot(a, b)
diff --git a/ivy/functional/backends/torch/experimental/linear_algebra.py b/ivy/functional/backends/torch/experimental/linear_algebra.py
index 7f34abc7977be..19c3500e17198 100644
--- a/ivy/functional/backends/torch/experimental/linear_algebra.py
+++ b/ivy/functional/backends/torch/experimental/linear_algebra.py
@@ -189,3 +189,16 @@ def lu_factor(
     out: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor]:
     raise IvyNotImplementedException()
+
+
+def dot(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    /,
+    *,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.dot(a, b, out=out)
+
+
+dot.support_native_out = True
diff --git a/ivy/functional/ivy/experimental/linear_algebra.py b/ivy/functional/ivy/experimental/linear_algebra.py
index 529a723923c86..0ab488600ca66 100644
--- a/ivy/functional/ivy/experimental/linear_algebra.py
+++ b/ivy/functional/ivy/experimental/linear_algebra.py
@@ -559,3 +559,61 @@ def cond(
     ivy.array(21.0)
     """
     return current_backend(x).cond(x, p=p, out=out)
+
+
+@handle_nestable
+@handle_out_argument
+@to_native_arrays_and_back
+@handle_exceptions
+def dot(
+    a: Union[ivy.Array, ivy.NativeArray],
+    b: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    """
+    Compute the dot product between two arrays `a` and `b` using the current
+    backend's implementation. For 1-D arrays, this is the inner product of the
+    vectors; for 2-D arrays, it is equivalent to matrix multiplication.
+
+    Parameters
+    ----------
+    a
+        First input array.
+    b
+        Second input array.
+    out
+        Optional output array, for writing the result to.
+
+    Returns
+    -------
+    ret
+        The dot product of the input arrays.
+
+    Examples
+    --------
+    With :class:`ivy.Array` inputs:
+
+    >>> a = ivy.array([1, 2, 3])
+    >>> b = ivy.array([4, 5, 6])
+    >>> result = ivy.dot(a, b)
+    >>> print(result)
+    ivy.array(32)
+
+    >>> c = ivy.array([[1, 2], [3, 4]])
+    >>> d = ivy.array([[5, 6], [7, 8]])
+    >>> e = ivy.empty_like(d)
+    >>> results_matrix = ivy.dot(c, d, out=e)
+    >>> print(results_matrix)
+    ivy.array([[19, 22],
+               [43, 50]])
+
+    >>> f = ivy.array([[1.1, 2.3, -3.6]])
+    >>> g = ivy.array([[-4.8], [5.2], [6.1]])
+    >>> h = ivy.zeros((1, 1))
+    >>> result_ = ivy.dot(f, g, out=h)
+    >>> print(result_)
+    ivy.array([[-15.28]])
+    """
+    return current_backend(a, b).dot(a, b, out=out)
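
For reference, a minimal, hypothetical usage sketch of the `ivy.dot` function introduced by this diff. It assumes the change above is applied and that the NumPy backend is installed; the variable names are illustrative only and are not part of the change:

    import ivy

    # Illustrative only: exercise the dot function added by the diff above,
    # using the NumPy backend (assumed to be installed).
    ivy.set_backend("numpy")

    a = ivy.array([1.0, 2.0, 3.0])
    b = ivy.array([4.0, 5.0, 6.0])
    print(ivy.dot(a, b))  # inner product of two 1-D arrays (32.0)

    c = ivy.array([[1.0, 2.0], [3.0, 4.0]])
    d = ivy.array([[5.0, 6.0], [7.0, 8.0]])
    e = ivy.zeros((2, 2))
    ivy.dot(c, d, out=e)  # 2-D inputs behave like matrix multiplication
    print(e)              # the result is written into the provided out array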