Skip to content

Commit

Permalink
Merge branch 'release_v2141' into ad/r_safetensors
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximProshin authored Dec 14, 2024
2 parents 33a8fee + d316ce7 commit cb7c7ea
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 22 deletions.
9 changes: 5 additions & 4 deletions docs/api/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def collect_api_entities() -> APIInfo:
except Exception as e:
skipped_modules[modname] = str(e)

from nncf.common.utils.api_marker import api
from nncf.common.utils.api_marker import API_MARKER_ATTR
from nncf.common.utils.api_marker import CANONICAL_ALIAS_ATTR

canonical_imports_seen = set()

Expand All @@ -86,7 +87,7 @@ def collect_api_entities() -> APIInfo:
if (
objects_module == modname
and (inspect.isclass(obj) or inspect.isfunction(obj))
and hasattr(obj, api.API_MARKER_ATTR)
and hasattr(obj, API_MARKER_ATTR)
):
marked_object_name = obj._nncf_api_marker
# Check the actual name of the originally marked object
Expand All @@ -95,8 +96,8 @@ def collect_api_entities() -> APIInfo:
if marked_object_name != obj.__name__:
continue
fqn = f"{modname}.{obj_name}"
if hasattr(obj, api.CANONICAL_ALIAS_ATTR):
canonical_import_name = getattr(obj, api.CANONICAL_ALIAS_ATTR)
if hasattr(obj, CANONICAL_ALIAS_ATTR):
canonical_import_name = getattr(obj, CANONICAL_ALIAS_ATTR)
if canonical_import_name in canonical_imports_seen:
assert False, f"Duplicate canonical_alias detected: {canonical_import_name}"
retval.fqn_vs_canonical_name[fqn] = canonical_import_name
Expand Down
41 changes: 26 additions & 15 deletions nncf/common/utils/api_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from typing import Any, Callable, TypeVar, Union

class api:
API_MARKER_ATTR = "_nncf_api_marker"
CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias"
TObj = TypeVar("TObj", bound=Union[Callable[..., Any], type])

def __init__(self, canonical_alias: str = None):
self._canonical_alias = canonical_alias
API_MARKER_ATTR = "_nncf_api_marker"
CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias"

def __call__(self, obj: Any) -> Any:
# The value of the marker will be useful in determining
# whether we are handling a base class or a derived one.
setattr(obj, api.API_MARKER_ATTR, obj.__name__)
if self._canonical_alias is not None:
setattr(obj, api.CANONICAL_ALIAS_ATTR, self._canonical_alias)
return obj

def api(canonical_alias: str = None) -> Callable[[TObj], TObj]:
"""
Decorator function used to mark a object as an API.
Example:
@api(canonical_alias="alias")
class Class:
pass
@api(canonical_alias="alias")
def function():
pass
:param canonical_alias: The canonical alias for the API class.
"""

def decorator(obj: TObj) -> TObj:
setattr(obj, API_MARKER_ATTR, obj.__name__)
if canonical_alias is not None:
setattr(obj, CANONICAL_ALIAS_ATTR, canonical_alias)
return obj

def is_api(obj: Any) -> bool:
return hasattr(obj, api.API_MARKER_ATTR)
return decorator
6 changes: 3 additions & 3 deletions nncf/torch/dynamic_graph/patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import functools
import inspect
from contextlib import contextmanager
from typing import Callable, List, Union
from typing import Callable, List, Optional, Union

import torch
import torch.utils.cpp_extension
Expand Down Expand Up @@ -251,13 +251,13 @@ def get_torch_compile_wrapper():
"""

@functools.wraps(_ORIG_TORCH_COMPILE)
def wrapper(model, *args, **kwargs):
def wrapper(model: Optional[Callable] = None, **kwargs):
from nncf.torch.nncf_network import NNCFNetwork

if isinstance(model, NNCFNetwork):
raise TypeError("At the moment torch.compile() is not supported for models optimized by NNCF.")
with disable_patching():
return _ORIG_TORCH_COMPILE(model, *args, **kwargs)
return _ORIG_TORCH_COMPILE(model, **kwargs)

return wrapper

Expand Down

0 comments on commit cb7c7ea

Please sign in to comment.