Skip to content

Commit

Permalink
Address part of comments
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Dec 19, 2024
1 parent f60d59d commit bd17d4a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 29 deletions.
50 changes: 42 additions & 8 deletions nncf/tensor/functions/tf_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,50 @@ def _(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> tf.Tensor:
if axis is None:
axis = 0 if a._rank() == 1 else (0, 1)

if ord is None or (a._rank() == 1 and ord == "fro"):
if ord is None:
ord = "euclidean"
rank = tf.rank(a)
if rank == 2 and axis is None:
axis = (0, 1)

with tf.device(a.device):
if ord == "nuc":
s, _, _ = tf.linalg.svd(a)
return tf.reduce_sum(s)
if ord == "nuc" and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord='nuc' is only supported for 2D tensors")
s = tf.linalg.svd(a, compute_uv=False)
return tf.reduce_sum(s, axis=-1)

if ord == -1 and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord=-1 is only supported for 2D tensors")
return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)

if ord == 1 and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord=1 is only supported for 2D tensors")
return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[0]), keepdims=keepdims)

if ord == -2 and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord=-2 is only supported for 2D tensors")
s = tf.linalg.svd(a, compute_uv=False)
return tf.reduce_min(s, axis=-1)

if ord == 2 and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord=2 is only supported for 2D tensors")
s = tf.linalg.svd(a, compute_uv=False)
return tf.reduce_max(s, axis=-1)

if ord == float("inf") and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord=inf is only supported for 2D tensors")
return tf.reduce_max(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)

if ord == -float("inf") and isinstance(axis, tuple) and len(axis) != 1:
if rank != 2:
raise ValueError("ord=-inf is only supported for 2D tensors")
return tf.reduce_min(tf.reduce_sum(tf.abs(a), axis=axis[1]), keepdims=keepdims)

return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)

Expand Down Expand Up @@ -92,4 +126,4 @@ def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor:
with tf.device(a.device):
s, u, v = tf.linalg.svd(a, full_matrices=full_matrices)

return u, s, tf.transpose(v)
return u, s, tf.transpose(v, conjugate=True)
28 changes: 7 additions & 21 deletions nncf/tensor/functions/tf_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def _(a: tf.Tensor) -> TensorBackend:
@numeric.squeeze.register(tf.Tensor)
def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Tensor:
with tf.device(a.device):
if axis is None:
return tf.squeeze(a)
if isinstance(axis, Tuple) and any(a.shape[i] != 1 for i in axis):
raise ValueError("Cannot select an axis to squeeze out which has size not equal to one")
return tf.squeeze(a, axis)


Expand All @@ -67,19 +63,15 @@ def _(a: tf.Tensor) -> tf.Tensor:


@numeric.max.register(tf.Tensor)
def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor:
def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor:
with tf.device(a.device):
if axis is None:
return tf.reduce_max(a)
return tf.reduce_max(a, axis=axis, keepdims=keepdim)
return tf.reduce_max(a, axis=axis, keepdims=keepdims)


@numeric.min.register(tf.Tensor)
def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> tf.Tensor:
def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> tf.Tensor:
with tf.device(a.device):
if axis is None:
return tf.reduce_min(a)
return tf.reduce_min(a, axis=axis, keepdims=keepdim)
return tf.reduce_min(a, axis=axis, keepdims=keepdims)


@numeric.abs.register(tf.Tensor)
Expand Down Expand Up @@ -139,7 +131,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te

@numeric.isempty.register(tf.Tensor)
def _(a: tf.Tensor) -> bool:
return bool(tf.equal(tf.size(a), 0).numpy())
return bool(tf.equal(tf.size(a), 0))


@numeric.isclose.register(tf.Tensor)
Expand Down Expand Up @@ -214,6 +206,7 @@ def _(
dtype: Optional[TensorDataType] = None,
) -> tf.Tensor:
with tf.device(a.device):
a = tf.cast(a, DTYPE_MAP[dtype]) if dtype is not None else a
return tf.reduce_mean(a, axis=axis, keepdims=keepdims)


Expand Down Expand Up @@ -304,14 +297,7 @@ def _(a: tf.Tensor, data: Any) -> tf.Tensor:

@numeric.item.register(tf.Tensor)
def _(a: tf.Tensor) -> Union[int, float, bool]:
a = tf.reshape(a, [])
np_item = a.numpy()
if isinstance(np_item, np.floating):
return float(np_item)
if isinstance(np_item, np.bool_):
return bool(np_item)

return int(np_item)
return a.numpy().item()


@numeric.sum.register(tf.Tensor)
Expand Down
51 changes: 51 additions & 0 deletions tests/cross_fw/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ def test_fn_median(self, x, axis, keepdims, ref):
(1.1, 0, 1.0),
([1.1, 0.9], 0, [1.0, 1.0]),
([1.11, 0.91], 1, [1.1, 0.9]),
([5.5, 3.3], -1, [10.0, 0.0]),
),
)
def test_fn_round(self, val, decimals, ref):
Expand Down Expand Up @@ -1053,6 +1054,13 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref):
True,
[[1.53063197]],
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
"nuc",
(0, 1),
False,
[1.53063197],
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
float("inf"),
Expand All @@ -1067,6 +1075,49 @@ def test_fn_var(self, x, axis, keepdims, ddof, ref):
False,
0.9364634205074938,
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
2,
0,
False,
[0.8062258, 0.72801095, 0.22360681],
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
1,
None,
False,
0.9,
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
-1,
None,
False,
0.3,
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
-2,
None,
False,
0.59416854,
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
float("inf"),
None,
False,
1.2,
),
(
[[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]],
-float("inf"),
None,
False,
0.9,
),
([[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], None, None, False, 2.82842708),
),
)
def test_fn_linalg_norm(self, x, ord, axis, keepdims, ref):
Expand Down

0 comments on commit bd17d4a

Please sign in to comment.