Improve Tensor performance #2255

Closed
22 changes: 7 additions & 15 deletions nncf/experimental/tensor/README.md
@@ -32,8 +32,6 @@ tenor_b = Tensor(np.array([1,2]))
tensor_a + tenor_b # Tensor(array([2, 4]))
```

**NOTE** Division operations for the numpy backend are performed with warnings disabled, so that the behavior is the same for all backends.

### Comparison operators

All math operations are overridden to operate on the wrapped object and return a `Tensor`
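For example, a comparison between two wrapped tensors might look like the following (a minimal illustrative sketch in the spirit of the arithmetic example above; the exact printed form of the result may differ):

```python
tensor_a = Tensor(np.array([1, 2]))
tensor_b = Tensor(np.array([2, 2]))
tensor_a < tensor_b  # Tensor(array([ True, False]))
```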
@@ -108,7 +106,7 @@ tensor_a[0:2] # Tensor(array([[1],[2]]))
2. Add function to [function.py](function.py)

```python
@functools.singledispatch
@tensor_dispatch()
def foo(a: TTensor, arg1: Type) -> TTensor:
"""
__description__
@@ -117,21 +115,15 @@ tensor_a[0:2] # Tensor(array([[1],[2]]))
    :param arg1: __description__
    :return: __description__
    """
    if isinstance(a, tensor.Tensor):
        return tensor.Tensor(foo(a.data, axis))
    return NotImplemented(f"Function `foo` is not implemented for {type(a)}")
```

**NOTE** When the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches by the type of the first element of the first argument.
**NOTE** To control how Tensors are handled, different wrapper types can be selected, e.g. `@tensor_dispatch(wrapper_type=WrapperType.TensorToTensor)` (see the sketch after the list):

```python
@functools.singledispatch
def foo(x: List[Tensor], axis: int = 0) -> Tensor:
    if isinstance(x, List):
        unwrapped_x = [i.data for i in x]
        return Tensor(_dispatch_list(foo, unwrapped_x, axis=axis))
    raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}")
```
- `WrapperType.TensorToTensor` (default) expects a `Tensor` as the first argument; the result is wrapped in a `Tensor`.
- `WrapperType.TensorToAny` expects a `Tensor` as the first argument; the result is not wrapped.
- `WrapperType.TensorToList` expects a `Tensor` as the first argument; each element of the resulting list is wrapped in a `Tensor`.
- `WrapperType.ListToTensor` expects a list of `Tensor`s as the first argument; the result is wrapped in a `Tensor`.
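As an illustration, a function whose result should stay unwrapped could be declared with a non-default wrapper type (a minimal sketch; `item` is used here only as an example name and is not necessarily defined this way in [function.py](function.py)):

```python
@tensor_dispatch(wrapper_type=WrapperType.TensorToAny)
def item(a: TTensor) -> Union[int, float, bool]:
    """
    Returns the single element of the tensor as a standard Python number.

    :param a: The input tensor.
    :return: The value of the single element.
    """
```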

3. Add a backend-specific implementation of the method to:

155 changes: 155 additions & 0 deletions nncf/experimental/tensor/dispatcher.py
@@ -0,0 +1,155 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
import weakref
from abc import get_cache_token
from enum import Enum
from enum import auto
from functools import _find_impl
from functools import update_wrapper
from typing import Callable, List, Optional, Type, Union

from nncf.experimental.tensor import Tensor


class WrapperType(Enum):
    TensorToTensor = auto()
    TensorToAny = auto()
    TensorToList = auto()
    ListToTensor = auto()
    OnlyDispatch = auto()


def tensor_dispatch(wrapper_type: WrapperType = WrapperType.TensorToTensor) -> Callable:
    """Custom implementation of the functools.singledispatch function decorator.

    Transforms a function into a generic function whose behaviour depends on the
    type of its first argument. Backend-specific implementations are registered
    using the register() attribute of the generic function; for types that have
    no registered implementation, a NotImplementedError is raised.

    To control how Tensors are handled, different wrapper types can be selected:
        TensorToTensor - expects a Tensor as the first argument; the result is wrapped in a Tensor.
        TensorToAny - expects a Tensor as the first argument; the result is not wrapped.
        TensorToList - expects a Tensor as the first argument; each element of the resulting list is wrapped in a Tensor.
        ListToTensor - expects a list of Tensors as the first argument; the result is wrapped in a Tensor.

    If the first argument is not wrapped in a Tensor, the backend-specific function is called directly.

    :param wrapper_type: Type of wrapper function, defaults to TensorToTensor.
    """

    def decorator(func: Callable) -> Callable:
        registry = {}
        dispatch_cache = weakref.WeakKeyDictionary()
        cache_token = None

        def dispatch(cls: Type) -> Callable:
            """generic_func.dispatch(cls) -> <function implementation>

            Runs the dispatch algorithm to return the best available implementation
            for the given *cls* registered on *generic_func*.
            """
            nonlocal cache_token
            if cache_token is not None:
                current_token = get_cache_token()
                if cache_token != current_token:
                    dispatch_cache.clear()
                    cache_token = current_token
            try:
                impl = dispatch_cache[cls]
            except KeyError:
                try:
                    impl = registry[cls]
                except KeyError:
                    impl = _find_impl(cls, registry)
                dispatch_cache[cls] = impl
            return impl

        def register(cls: Type, func: Optional[Callable] = None):
            """generic_func.register(cls, func) -> func

            Registers a new implementation for the given *cls* on a *generic_func*.

            """
            nonlocal cache_token
            if func is None:
                if isinstance(cls, type):
                    return lambda f: register(cls, f)
                ann = getattr(cls, "__annotations__", {})
                if not ann:
                    raise TypeError(
                        f"Invalid first argument to `register()`: {cls!r}. "
                        f"Use either `@register(some_class)` or plain `@register` "
                        f"on an annotated function."
                    )
                func = cls

                # only import typing if annotation parsing is necessary
                from typing import get_type_hints

                argname, cls = next(iter(get_type_hints(func).items()))
                if not isinstance(cls, type):
                    raise TypeError(f"Invalid annotation for {argname!r}. {cls!r} is not a class.")
            registry[cls] = func
            if cache_token is None and hasattr(cls, "__abstractmethods__"):
                cache_token = get_cache_token()
            dispatch_cache.clear()
            return func

        # Wrapper variants: unwrap the incoming Tensor(s), dispatch on the type of the
        # wrapped data and wrap the result according to the selected WrapperType.

        def wrapper_tensor_to_tensor(tensor: Tensor, *args, **kw):
            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
            return Tensor(dispatch(tensor.data.__class__)(tensor.data, *args, **kw))

        def wrapper_tensor_to_any(tensor: Tensor, *args, **kw):
            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
            return dispatch(tensor.data.__class__)(tensor.data, *args, **kw)

        def wrapper_tensor_to_list(tensor: Tensor, *args, **kw):
            args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
            return [Tensor(x) for x in dispatch(tensor.data.__class__)(tensor.data, *args, **kw)]

        def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw):
            list_of_tensors = [x.data for x in list_of_tensors]
            return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw))

        wrappers_map = {
            WrapperType.TensorToTensor: wrapper_tensor_to_tensor,
            WrapperType.TensorToAny: wrapper_tensor_to_any,
            WrapperType.TensorToList: wrapper_tensor_to_list,
            WrapperType.ListToTensor: wrapper_list_to_tensor,
        }

        def raise_not_implemented(data: Union[Tensor, List[Tensor]], *args, **kw):
            """
            Raises NotImplementedError for a type with no registered implementation.
            """
            if wrapper_type == WrapperType.ListToTensor:
                arg_type = type(data[0].data) if isinstance(data[0], Tensor) else type(data[0])
            else:
                arg_type = type(data.data) if isinstance(data, Tensor) else type(data)

            raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {arg_type}")

        # Fallback for unregistered types; the selected wrapper becomes the public generic function.
        registry[object] = raise_not_implemented
        wrapper = wrappers_map[wrapper_type]
        wrapper.register = register
        wrapper.dispatch = dispatch
        wrapper.registry = types.MappingProxyType(registry)
        wrapper._clear_cache = dispatch_cache.clear
        update_wrapper(wrapper, func)
        return wrapper
Comment on lines +54 to +153
Contributor

Some of your code here shows up as not tested in the coverage report; check it and add the respective tests.

Since this is almost a verbatim copy of the CPython code, you should check its license compatibility with our own Apache-2.0 license and fulfill the necessary license obligations.


    return decorator
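For reference, the intended flow can be sketched as follows (a minimal, hypothetical usage example: `absolute` is only an illustrative name, and the sketch assumes numpy is available and that `Tensor` exposes the wrapped array via `.data`, as the wrappers above rely on):

```python
import numpy as np

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor.dispatcher import WrapperType, tensor_dispatch


@tensor_dispatch(wrapper_type=WrapperType.TensorToTensor)
def absolute(a):
    """Compute the element-wise absolute value of a tensor."""


@absolute.register(np.ndarray)
def _(a: np.ndarray) -> np.ndarray:
    # numpy-specific implementation, selected by dispatch() on the type of the wrapped data
    return np.absolute(a)


t = Tensor(np.array([-1.0, 2.0]))
absolute(t)  # a Tensor wrapping array([1., 2.])
# A backend type with no registered implementation falls back to the object
# entry in the registry and raises NotImplementedError.
```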