From 26fd1fd9467a8fcf2bab039118521549c5491815 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 8 Jul 2024 12:41:22 +0300 Subject: [PATCH] Add assignment operators and matmul (#2792) ### Changes Support `__iadd__`, `__isub__`, `__imul__`, `__ipow__`, `__itruediv__`, `__ifloordiv__`, `__matmul__` operators for Tensor --- nncf/tensor/tensor.py | 27 +++++++++++++++++++ .../template_test_nncf_tensor.py | 20 +++++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py index 2be02d8c252..52966be1ad1 100644 --- a/nncf/tensor/tensor.py +++ b/nncf/tensor/tensor.py @@ -82,36 +82,63 @@ def __add__(self, other: Union[Tensor, float]) -> Tensor: def __radd__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) + self.data) + def __iadd__(self, other: Union[Tensor, float]) -> Tensor: + self._data += unwrap_tensor_data(other) + return self + def __sub__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data - unwrap_tensor_data(other)) def __rsub__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) - self.data) + def __isub__(self, other: Union[Tensor, float]) -> Tensor: + self._data -= unwrap_tensor_data(other) + return self + def __mul__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data * unwrap_tensor_data(other)) def __rmul__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) * self.data) + def __imul__(self, other: Union[Tensor, float]) -> Tensor: + self._data *= unwrap_tensor_data(other) + return self + def __pow__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data ** unwrap_tensor_data(other)) def __rpow__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) ** self.data) + def __ipow__(self, other: Union[Tensor, float]) -> Tensor: + self._data **= unwrap_tensor_data(other) + return self + def __truediv__(self, other: Union[Tensor, float]) -> Tensor: return _call_function("_binary_op_nowarn", self, other, operator.truediv) def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor: return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv) + def __itruediv__(self, other: Union[Tensor, float]) -> Tensor: + self._data /= unwrap_tensor_data(other) + return self + def __floordiv__(self, other: Union[Tensor, float]) -> Tensor: return _call_function("_binary_op_nowarn", self, other, operator.floordiv) def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor: return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv) + def __ifloordiv__(self, other: Union[Tensor, float]) -> Tensor: + self._data /= unwrap_tensor_data(other) + return self + + def __matmul__(self, other: Union[Tensor, float]) -> Tensor: + return Tensor(self.data @ unwrap_tensor_data(other)) + def __neg__(self) -> Tensor: return Tensor(-self.data) diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 10bae2ad9d5..29b558b240b 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -38,6 +38,12 @@ "truediv": operator.truediv, "floordiv": operator.floordiv, "neg": lambda a, _: -a, + "iadd": operator.iadd, + "isub": operator.isub, + "imul": operator.imul, + "ipow": operator.ipow, + "itruediv": operator.itruediv, + "ifloordiv": operator.ifloordiv, } BINARY_OPERATORS = ["add", "sub", "pow", "mul", "truediv", "floordiv"] @@ -93,8 +99,8 @@ def test_operator_clone(self): @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys()) def test_operators_tensor(self, op_name): - tensor_a = self.to_tensor([1, 2]) - tensor_b = self.to_tensor([22, 11]) + tensor_a = self.to_tensor([1.0, 2.0]) + tensor_b = self.to_tensor([22.0, 11.0]) nncf_tensor_a = Tensor(tensor_a) nncf_tensor_b = Tensor(tensor_b) @@ -110,8 +116,8 @@ def test_operators_tensor(self, op_name): @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys()) def test_operators_int(self, op_name): - tensor_a = self.to_tensor([1, 2]) - value = 2 + tensor_a = self.to_tensor([1.0, 2.0]) + value = 2.0 nncf_tensor_a = Tensor(tensor_a) @@ -1090,6 +1096,12 @@ def test_fn_matmul(self, m1, m2, ref): assert fns.allclose(res.data, ref_tensor) assert res.device == tensor1.device + res = tensor1 @ tensor2 + + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + assert res.device == tensor1.device + @pytest.mark.parametrize( "val, axis, ref", (