From bd17d4a62d4db3083b1f09a4ced96abfe33ef08b Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Wed, 18 Dec 2024 18:39:27 -0800 Subject: [PATCH] Address part of comments --- nncf/tensor/functions/tf_linalg.py | 50 +++++++++++++++--- nncf/tensor/functions/tf_numeric.py | 28 +++------- .../template_test_nncf_tensor.py | 51 +++++++++++++++++++ 3 files changed, 100 insertions(+), 29 deletions(-) diff --git a/nncf/tensor/functions/tf_linalg.py b/nncf/tensor/functions/tf_linalg.py index 9769ac4ac32..f0a5b8db290 100644 --- a/nncf/tensor/functions/tf_linalg.py +++ b/nncf/tensor/functions/tf_linalg.py @@ -24,16 +24,50 @@ def _( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> tf.Tensor: - if axis is None: - axis = 0 if a._rank() == 1 else (0, 1) - - if ord is None or (a._rank() == 1 and ord == "fro"): + if ord is None: ord = "euclidean" + rank = tf.rank(a) + if rank == 2 and axis is None: + axis = (0, 1) with tf.device(a.device): - if ord == "nuc": - s, _, _ = tf.linalg.svd(a) - return tf.reduce_sum(s) + if ord == "nuc" and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord='nuc' is only supported for 2D tensors") + s = tf.linalg.svd(a, compute_uv=False) + return tf.reduce_sum(s, axis=-1) + + if ord == -1 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=-1 is only supported for 2D tensors") + return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + + if ord == 1 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=1 is only supported for 2D tensors") + return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims) + + if ord == -2 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=-2 is only supported for 2D tensors") + s = tf.linalg.svd(a, compute_uv=False) + return tf.reduce_min(s, axis=-1) + + if ord == 2 and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=2 is only supported for 2D tensors") + s = tf.linalg.svd(a, compute_uv=False) + return tf.reduce_max(s, axis=-1) + + if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=inf is only supported for 2D tensors") + return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) + + if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1: + if rank != 2: + raise ValueError("ord=-inf is only supported for 2D tensors") + return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims) return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) @@ -92,4 +126,4 @@ def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor: with tf.device(a.device): s, u, v = tf.linalg.svd(a, full_matrices=full_matrices) - return u, s, tf.transpose(v) + return u, s, tf.transpose(v, conjugate=True) diff --git a/nncf/tensor/functions/tf_numeric.py b/nncf/tensor/functions/tf_numeric.py index 95a98a728f6..cc7799dad62 100644 --- a/nncf/tensor/functions/tf_numeric.py +++ b/nncf/tensor/functions/tf_numeric.py @@ -53,10 +53,6 @@ def _(a: tf.Tensor) -> TensorBackend: @numeric.squeeze.register(tf.Tensor) def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor: with tf.device(a.device): - if axis is None: - return tf.squeeze(a) - if isinstance(axis, Tuple) and any(a.shape[i] != 1 for i in axis): - raise ValueError("Cannot select an axis to squeeze out which has size not equal to one") return tf.squeeze(a, axis) @@ -67,19 +63,15 @@ def _(a: tf.Tensor) -> tf.Tensor: @numeric.max.register(tf.Tensor) -def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor: +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): - if axis is None: - return tf.reduce_max(a) - return tf.reduce_max(a, axis=axis, keepdims=keepdim) + return tf.reduce_max(a, axis=axis, keepdims=keepdims) @numeric.min.register(tf.Tensor) -def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor: +def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor: with tf.device(a.device): - if axis is None: - return tf.reduce_min(a) - return tf.reduce_min(a, axis=axis, keepdims=keepdim) + return tf.reduce_min(a, axis=axis, keepdims=keepdims) @numeric.abs.register(tf.Tensor) @@ -139,7 +131,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te @numeric.isempty.register(tf.Tensor) def _(a: tf.Tensor) -> bool: - return bool(tf.equal(tf.size(a), 0).numpy()) + return bool(tf.equal(tf.size(a), 0)) @numeric.isclose.register(tf.Tensor) @@ -214,6 +206,7 @@ def _( dtype: Optional[TensorDataType] = None, ) -> tf.Tensor: with tf.device(a.device): + a = tf.cast(a, DTYPE_MAP[dtype]) if dtype is not None else a return tf.reduce_mean(a, axis=axis, keepdims=keepdims) @@ -304,14 +297,7 @@ def _(a: tf.Tensor, data: Any) -> tf.Tensor: @numeric.item.register(tf.Tensor) def _(a: tf.Tensor) -> Union[int, float, bool]: - a = tf.reshape(a, []) - np_item = a.numpy() - if isinstance(np_item, np.floating): - return float(np_item) - if isinstance(np_item, np.bool_): - return bool(np_item) - - return int(np_item) + return a.numpy().item() @numeric.sum.register(tf.Tensor) diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index ad176188970..a1b7021849e 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -812,6 +812,7 @@ def test_fn_median(self, x, axis, keepdims, ref): (1.1, 0, 1.0), ([1.1, 0.9], 0, [1.0, 1.0]), ([1.11, 0.91], 1, [1.1, 0.9]), + ([5.5, 3.3], -1, [10.0, 0.0]), ), ) def test_fn_round(self, val, decimals, ref): @@ -1053,6 +1054,13 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref): True, [[1.53063197]], ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + "nuc", + (0, 1), + False, + [1.53063197], + ), ( [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], float("inf"), @@ -1067,6 +1075,49 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref): False, 0.9364634205074938, ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 2, + 0, + False, + [0.8062258, 0.72801095, 0.22360681], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + None, + False, + 0.9, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -1, + None, + False, + 0.3, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -2, + None, + False, + 0.59416854, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + float("inf"), + None, + False, + 1.2, + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + -float("inf"), + None, + False, + 0.9, + ), + ([[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], None, None, False, 2.82842708), ), ) def test_fn_linalg_norm(self, x, ord, axis, keepdims, ref):