Skip to content

Commit

Permalink
Redundant PTNNCFTensor unwrapping is removed
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 29, 2023
1 parent f2769fa commit d8a8982
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 14 deletions.
7 changes: 1 addition & 6 deletions nncf/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch

from nncf.common.tensor import NNCFTensor
from nncf.torch.return_types import maybe_unwrap_from_torch_return_type


class PTNNCFTensor(NNCFTensor):
"""
A realisation of torch tensors wrapper for common NNCF algorithms.
"""

def __init__(self, tensor: Union[torch.tensor, "PTNNCFTensor", tuple]):
def __init__(self, tensor: torch.tensor):
# In case somebody attempts to wrap
# tensor twice
if isinstance(tensor, self.__class__):
tensor = tensor.tensor
else:
tensor = maybe_unwrap_from_torch_return_type(tensor)

super().__init__(tensor)

Expand Down
8 changes: 0 additions & 8 deletions tests/torch/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.enums import TensorDeviceType
from nncf.torch.tensor import PTNNCFTensor
from tests.shared.test_templates.template_test_nncf_tensor import TemplateTestNNCFTensorOperators


Expand All @@ -32,10 +31,3 @@ def to_tensor(x):
def test_device(self):
tensor = Tensor(self.to_tensor([1]))
assert tensor.device == TensorDeviceType.GPU


def test_torch_return_type_input():
return_type_input = torch.return_types.max((torch.tensor(0), torch.tensor(1)))
return_type_input.values == torch.tensor(0)
pt_tensor = PTNNCFTensor(return_type_input)
assert pt_tensor.tensor == torch.tensor(0)

0 comments on commit d8a8982

Please sign in to comment.