Skip to content

Commit

Permalink
Add assignment operators and matmul (#2792)
Browse files Browse the repository at this point in the history
### Changes

Support `__iadd__`, `__isub__`, `__imul__`, `__ipow__`, `__itruediv__`,
`__ifloordiv__`, `__matmul__` operators for Tensor
  • Loading branch information
AlexanderDokuchaev authored Jul 8, 2024
1 parent 9fcba3e commit 26fd1fd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
27 changes: 27 additions & 0 deletions nncf/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,36 +82,63 @@ def __add__(self, other: Union[Tensor, float]) -> Tensor:
def __radd__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(unwrap_tensor_data(other) + self.data)

def __iadd__(self, other: Union[Tensor, float]) -> Tensor:
self._data += unwrap_tensor_data(other)
return self

def __sub__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(self.data - unwrap_tensor_data(other))

def __rsub__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(unwrap_tensor_data(other) - self.data)

def __isub__(self, other: Union[Tensor, float]) -> Tensor:
self._data -= unwrap_tensor_data(other)
return self

def __mul__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(self.data * unwrap_tensor_data(other))

def __rmul__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(unwrap_tensor_data(other) * self.data)

def __imul__(self, other: Union[Tensor, float]) -> Tensor:
self._data *= unwrap_tensor_data(other)
return self

def __pow__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(self.data ** unwrap_tensor_data(other))

def __rpow__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(unwrap_tensor_data(other) ** self.data)

def __ipow__(self, other: Union[Tensor, float]) -> Tensor:
self._data **= unwrap_tensor_data(other)
return self

def __truediv__(self, other: Union[Tensor, float]) -> Tensor:
return _call_function("_binary_op_nowarn", self, other, operator.truediv)

def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor:
return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv)

def __itruediv__(self, other: Union[Tensor, float]) -> Tensor:
self._data /= unwrap_tensor_data(other)
return self

def __floordiv__(self, other: Union[Tensor, float]) -> Tensor:
return _call_function("_binary_op_nowarn", self, other, operator.floordiv)

def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor:
return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)

def __ifloordiv__(self, other: Union[Tensor, float]) -> Tensor:
self._data /= unwrap_tensor_data(other)
return self

def __matmul__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(self.data @ unwrap_tensor_data(other))

def __neg__(self) -> Tensor:
return Tensor(-self.data)

Expand Down
20 changes: 16 additions & 4 deletions tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
"truediv": operator.truediv,
"floordiv": operator.floordiv,
"neg": lambda a, _: -a,
"iadd": operator.iadd,
"isub": operator.isub,
"imul": operator.imul,
"ipow": operator.ipow,
"itruediv": operator.itruediv,
"ifloordiv": operator.ifloordiv,
}
BINARY_OPERATORS = ["add", "sub", "pow", "mul", "truediv", "floordiv"]

Expand Down Expand Up @@ -93,8 +99,8 @@ def test_operator_clone(self):

@pytest.mark.parametrize("op_name", OPERATOR_MAP.keys())
def test_operators_tensor(self, op_name):
tensor_a = self.to_tensor([1, 2])
tensor_b = self.to_tensor([22, 11])
tensor_a = self.to_tensor([1.0, 2.0])
tensor_b = self.to_tensor([22.0, 11.0])

nncf_tensor_a = Tensor(tensor_a)
nncf_tensor_b = Tensor(tensor_b)
Expand All @@ -110,8 +116,8 @@ def test_operators_tensor(self, op_name):

@pytest.mark.parametrize("op_name", OPERATOR_MAP.keys())
def test_operators_int(self, op_name):
tensor_a = self.to_tensor([1, 2])
value = 2
tensor_a = self.to_tensor([1.0, 2.0])
value = 2.0

nncf_tensor_a = Tensor(tensor_a)

Expand Down Expand Up @@ -1090,6 +1096,12 @@ def test_fn_matmul(self, m1, m2, ref):
assert fns.allclose(res.data, ref_tensor)
assert res.device == tensor1.device

res = tensor1 @ tensor2

assert isinstance(res, Tensor)
assert fns.allclose(res.data, ref_tensor)
assert res.device == tensor1.device

@pytest.mark.parametrize(
"val, axis, ref",
(
Expand Down

0 comments on commit 26fd1fd

Please sign in to comment.