Skip to content

Commit

Permalink
[PT] check only major and minor version (#2422)
Browse files Browse the repository at this point in the history
### Changes

- Use `SpecifierSet` to check version of PT and TF
- Remove unused `BKC_TORCHVISION_VERSION`
- Fix strict check version of TF that never raised before

---------

Co-authored-by: Alexander Suslov <[email protected]>
  • Loading branch information
AlexanderDokuchaev and alexsu52 authored Mar 11, 2024
1 parent 119db6d commit cd4f4f2
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 24 deletions.
4 changes: 2 additions & 2 deletions nncf/common/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def extension_is_loading_info_log(extension_name: str):

def warn_bkc_version_mismatch(backend: str, bkc_version: str, current_version: str):
nncf_logger.warning(
f"NNCF provides best results with {backend}=={bkc_version}, "
f"NNCF provides best results with {backend}{bkc_version}, "
f"while current {backend} version is {current_version}. "
f"If you encounter issues, consider switching to {backend}=={bkc_version}"
f"If you encounter issues, consider switching to {backend}{bkc_version}"
)
20 changes: 10 additions & 10 deletions nncf/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,26 @@

import tensorflow
from packaging import version
from packaging.specifiers import SpecifierSet

import nncf
from nncf import nncf_logger
from nncf.common.logging.logger import warn_bkc_version_mismatch
from nncf.version import BKC_TF_VERSION
from nncf.version import BKC_TF_SPEC
from nncf.version import STRICT_TF_SPEC

try:
_tf_version = tensorflow.__version__
tensorflow_version = version.parse(_tf_version).base_version
tensorflow_version = version.parse(version.parse(tensorflow.__version__).base_version)
except:
nncf_logger.debug("Could not parse tensorflow version")
_tf_version = "0.0.0"
tensorflow_version = version.parse(_tf_version).base_version
tensorflow_version_major, tensorflow_version_minor = tuple(map(int, tensorflow_version.split(".")))[:2]
if not tensorflow_version.startswith(BKC_TF_VERSION[:-2]):
warn_bkc_version_mismatch("tensorflow", BKC_TF_VERSION, _tf_version)
elif not (tensorflow_version_major == 2 and 8 <= tensorflow_version_minor <= 13):
tensorflow_version = version.parse("0.0.0")

if tensorflow_version not in SpecifierSet(STRICT_TF_SPEC):
raise nncf.UnsupportedVersionError(
f"NNCF only supports 2.8.4 <= tensorflow <= 2.13.*, while current tensorflow version is {_tf_version}"
f"NNCF only supports tensorflow{STRICT_TF_SPEC}, while current tensorflow version is {tensorflow_version}"
)
if tensorflow_version not in SpecifierSet(BKC_TF_SPEC):
warn_bkc_version_mismatch("torch", BKC_TF_SPEC, tensorflow.__version__)


from nncf.common.accuracy_aware_training.training_loop import (
Expand Down
2 changes: 1 addition & 1 deletion nncf/tensorflow/tf_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from nncf.tensorflow import tensorflow_version

if version.parse(tensorflow_version) < version.parse("2.13"):
if tensorflow_version < version.parse("2.13"):
from keras import engine as keras_engine # noqa: F401
from keras.applications import imagenet_utils as imagenet_utils
from keras.engine.keras_tensor import KerasTensor as KerasTensor
Expand Down
13 changes: 6 additions & 7 deletions nncf/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,20 @@
from nncf import nncf_logger
from nncf.common.logging.logger import warn_bkc_version_mismatch

from nncf.version import BKC_TORCH_VERSION
from nncf.version import BKC_TORCH_SPEC

import torch
from packaging import version
from packaging.specifiers import SpecifierSet

try:
_torch_version = torch.__version__
torch_version = version.parse(_torch_version).base_version
_torch_version = version.parse(version.parse(torch.__version__).base_version)
except: # noqa: E722
nncf_logger.debug("Could not parse torch version")
_torch_version = "0.0.0"
torch_version = version.parse(_torch_version).base_version
_torch_version = version.parse("0.0.0")

if version.parse(BKC_TORCH_VERSION).base_version != torch_version:
warn_bkc_version_mismatch("torch", BKC_TORCH_VERSION, torch.__version__)
if _torch_version not in SpecifierSet(BKC_TORCH_SPEC):
warn_bkc_version_mismatch("torch", BKC_TORCH_SPEC, torch.__version__)


# Required for correct COMPRESSION_ALGORITHMS registry functioning
Expand Down
6 changes: 3 additions & 3 deletions nncf/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@

__version__ = "2.10.0"

BKC_TORCH_VERSION = "2.2.1"
BKC_TORCHVISION_VERSION = "0.17.1"
BKC_TF_VERSION = "2.12.*"
BKC_TORCH_SPEC = "==2.2.*"
BKC_TF_SPEC = "==2.12.*"
STRICT_TF_SPEC = ">=2.8.4,<2.14.0"
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ exclude = ["nncf/tensorflow/__init__.py"]

[per-file-ignores]
"nncf/experimental/torch/nas/bootstrapNAS/__init__.py" = ["F401"]
"nncf/torch/__init__.py" = ["F401"]
"nncf/torch/__init__.py" = ["F401", "E402"]
"tests/**/*.py" = ["F403"]
"tests/**/__init__.py" = ["F401"]
"examples/**/*.py" = ["F403"]
Expand Down

0 comments on commit cd4f4f2

Please sign in to comment.