Improve Tensor performance #2255
Conversation
Codecov Report
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           develop    #2255      +/-   ##
===========================================
- Coverage    90.70%   90.67%    -0.04%
===========================================
  Files          485      486        +1
  Lines        43601    43600        -1
===========================================
- Hits         39549    39535       -14
- Misses        4052     4065       +13
```

Flags with carried forward coverage won't be shown.
nncf/quantization/fake_quantize.py (Outdated)
```diff
@@ -106,8 +106,9 @@ def tune_range(
     fval = -left_border * s
     qval = fns.round(fval)

-    ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border)
-    rb = fns.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border)
+    with fns.disable_error_handling(qval):
```
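For context, `fns.disable_error_handling` is the helper this PR introduces to silence divide warnings without paying the `np.errstate` cost inside every divide call (see the Changes section at the bottom). Below is a minimal sketch of what such a helper might look like for the NumPy backend; it is a hypothetical illustration under the assumption that the helper simply scopes `np.errstate` around the caller's block, not the PR's actual implementation.

```python
# Hypothetical sketch only: one way a disable_error_handling helper could be
# implemented for the NumPy backend, assuming it just scopes np.errstate around
# the caller's block instead of entering it inside every divide call.
from contextlib import contextmanager

import numpy as np


@contextmanager
def disable_error_handling(data):
    if isinstance(data, np.ndarray):
        # Suppress divide-by-zero / invalid-value warnings for the enclosed block.
        with np.errstate(invalid="ignore", divide="ignore"):
            yield
    else:
        # Other backends are assumed not to emit these warnings.
        yield
```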
Please do not enter such a context here, because the user usage scenarios are not clear.
Removed it; `tune_range` will be reworked in the next PR.
One of the reasons why the "manual dispatching" solution was originally chosen over the idiomatic OOP approach was that it was supposed to be "simpler to work with". That analysis was wrong from the start, but now you are adding yet another layer of complexity to gain unknown improvements to NNCF performance on the whole.

The `tune_range` function cannot be treated as representative of NNCF performance on the whole, PTQ or QAT, and it cannot be said that a 40% improvement in the runtime of the `tune_range` function will lead to a 40% reduction in the overall NNCF runtime. Providing our own singledispatch implementation is not justifiable if the overall NNCF runtime improvement due to this ends up amounting to 0.1%.

You should have first switched all the algos to the new tensor approach, then provided a performance measurement framework that could measure overall NNCF performance for a given algorithm, and then estimated the improvement due to the change in this PR using that framework. It would also be good if the framework measured the percentage of tensor-operation overhead that we are seemingly so keen on optimizing, so that we don't waste time optimizing the overhead if it comprises 1% of the overall algo runtime. Then, if the overall perf improvement justifies the kind of complexity increase that you are bringing with this PR, it will also be interesting for me to take https://github.com/vshampor/nncf/tree/nncf_tensor_saved with the OOP approach and maybe implement an `abc.ABC` on my own too with better performance over the regular Python built-in, in case we are allowed to optimize Python built-ins in this part of NNCF code. Maybe I will even beat the "manual dispatch" implementation in terms of raw performance and we will switch the code to the approach in https://github.com/vshampor/nncf/tree/nncf_tensor_saved as a result.
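For concreteness, here is a small stand-alone sketch of the kind of per-call dispatch-overhead comparison being discussed. It is not from the PR: it times a plain function against a stock `functools.singledispatch` wrapper on a float stand-in rather than on NNCF tensors, and only the ratio between the two timings is of interest.

```python
# Illustrative micro-benchmark (not NNCF code): raw call cost of a plain function
# vs. a functools.singledispatch wrapper. Absolute numbers are machine-dependent.
import timeit
from functools import singledispatch


def plain(x: float) -> float:
    return -x if x < 0.0 else x


@singledispatch
def dispatched(x):
    raise NotImplementedError(type(x))


@dispatched.register
def _(x: float) -> float:
    return -x if x < 0.0 else x


n = 1_000_000
t_plain = timeit.timeit("plain(-1.5)", globals=globals(), number=n)
t_disp = timeit.timeit("dispatched(-1.5)", globals=globals(), number=n)
print(f"plain call:     {t_plain:.3f} s")
print(f"singledispatch: {t_disp:.3f} s ({t_disp / t_plain:.1f}x)")
```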
```python
# Inner decorator of the custom single-dispatch factory added in this PR.
# `wrapper_type` and the decorated `func` come from the enclosing factory's scope;
# `get_cache_token` (abc), `_find_impl` and `update_wrapper` (functools), `types`,
# `weakref`, `Tensor`, `WrapperType` and the typing names are assumed to be
# imported at module level.
def decorator(func: Callable) -> Callable:
    registry = {}
    dispatch_cache = weakref.WeakKeyDictionary()
    cache_token = None

    def dispatch(cls: Type) -> Callable:
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.
        """
        nonlocal cache_token
        if cache_token is not None:
            current_token = get_cache_token()
            if cache_token != current_token:
                dispatch_cache.clear()
                cache_token = current_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls: Type, func: Optional[Callable] = None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.
        """
        nonlocal cache_token
        if func is None:
            if isinstance(cls, type):
                return lambda f: register(cls, f)
            ann = getattr(cls, "__annotations__", {})
            if not ann:
                raise TypeError(
                    f"Invalid first argument to `register()`: {cls!r}. "
                    f"Use either `@register(some_class)` or plain `@register` "
                    f"on an annotated function."
                )
            func = cls

            # only import typing if annotation parsing is necessary
            from typing import get_type_hints

            argname, cls = next(iter(get_type_hints(func).items()))
            if not isinstance(cls, type):
                raise TypeError(f"Invalid annotation for {argname!r}. {cls!r} is not a class.")
        registry[cls] = func
        if cache_token is None and hasattr(cls, "__abstractmethods__"):
            cache_token = get_cache_token()
        dispatch_cache.clear()
        return func

    # Unlike functools.singledispatch, the wrappers below dispatch on the type of the
    # wrapped Tensor's `.data` and unwrap/re-wrap the inputs and outputs accordingly.
    def wrapper_tensor_to_tensor(tensor: Tensor, *args, **kw):
        args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
        return Tensor(dispatch(tensor.data.__class__)(tensor.data, *args, **kw))

    def wrapper_tensor_to_any(tensor: Tensor, *args, **kw):
        args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
        return dispatch(tensor.data.__class__)(tensor.data, *args, **kw)

    def wrapper_tensor_to_list(tensor: Tensor, *args, **kw):
        args = tuple(x.data if isinstance(x, Tensor) else x for x in args)
        return [Tensor(x) for x in dispatch(tensor.data.__class__)(tensor.data, *args, **kw)]

    def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw):
        list_of_tensors = [x.data for x in list_of_tensors]
        return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw))

    wrappers_map = {
        WrapperType.TensorToTensor: wrapper_tensor_to_tensor,
        WrapperType.TensorToAny: wrapper_tensor_to_any,
        WrapperType.TensorToList: wrapper_tensor_to_list,
        WrapperType.ListToTensor: wrapper_list_to_tensor,
    }

    def raise_not_implemented(data: Union[Tensor, List[Tensor]], *args, **kw):
        """
        Raises NotImplementedError for a type with no registered implementation.
        """
        if wrapper_type == WrapperType.ListToTensor:
            arg_type = type(data[0].data) if isinstance(data[0], Tensor) else type(data[0])
        else:
            arg_type = type(data.data) if isinstance(data, Tensor) else type(data)

        raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {arg_type}")

    # Unregistered types fall back to the error above; the selected wrapper exposes the
    # same register/dispatch/registry attributes as functools.singledispatch.
    registry[object] = raise_not_implemented
    wrapper = wrappers_map[wrapper_type]
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = types.MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper
```
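For readers following along, here is a rough usage sketch of how a factory built around this decorator might be applied. The factory name `tensor_dispatch` and the import paths are assumptions made for illustration only; what mirrors the code above is the register/dispatch pattern on the type of `Tensor.data`.

```python
# Hypothetical usage sketch; `tensor_dispatch` and the import paths are assumed
# names, not necessarily the PR's actual API.
import numpy as np

from nncf.experimental.tensor import Tensor, WrapperType, tensor_dispatch  # assumed imports


@tensor_dispatch(WrapperType.TensorToTensor)
def absolute(a: Tensor) -> Tensor:
    """Generic entry point; the selected wrapper dispatches on type(a.data)."""


@absolute.register(np.ndarray)
def _(a: np.ndarray) -> np.ndarray:
    # Backend-specific implementation; receives and returns the raw backing array.
    return np.absolute(a)


t = Tensor(np.array([-1.0, 2.0]))
print(absolute(t))  # Tensor wrapping array([1., 2.])
```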
Some of your code here shows as not tested in the coverage report; check it and add the respective tests.
Since this is almost a verbatim copy of the CPython code, you should check its license compatibility with our own Apache-2.0 license and fulfil the necessary license obligations.
Changes
- Add a custom implementation of `functools.singledispatch` with an extra `wrapper_type` argument to select the wrapper function.
- Remove `np.errstate` from the divide function because it degrades performance; to suppress the warnings, add a `disable_error_handling` function.

Tested by running the `tune_range` function 1000 times under cProfile.
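As a point of reference, below is a minimal sketch of the benchmarking approach described above: call the function under test 1000 times under cProfile and inspect the cumulative times. A NumPy stand-in is profiled so the snippet stays self-contained; in the PR the profiled function was `tune_range`, and the input shapes shown are arbitrary.

```python
# Sketch of the benchmark methodology only: 1000 calls under cProfile, sorted by
# cumulative time. The function and inputs below are stand-ins, not NNCF code.
import cProfile

import numpy as np


def function_under_test(left: np.ndarray, right: np.ndarray) -> np.ndarray:
    # Stand-in with a similar flavour of elementwise arithmetic to tune_range.
    return np.where(right > 0.0, left / np.maximum(right, 1e-9), left)


left = np.full((64, 1), -1.0, dtype=np.float32)
right = np.full((64, 1), 1.0, dtype=np.float32)

cProfile.runctx(
    "for _ in range(1000): function_under_test(left, right)",
    globals(),
    locals(),
    sort="cumtime",
)
```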