diff --git a/model-optimizer/mo/utils/versions_checker.py b/model-optimizer/mo/utils/versions_checker.py index d98a8ddd5ec719..0b532819e19b65 100644 --- a/model-optimizer/mo/utils/versions_checker.py +++ b/model-optimizer/mo/utils/versions_checker.py @@ -44,17 +44,24 @@ def check_python_version(): return 1 -def parse_versions_list(required_fw_versions, version_list): +def parse_and_filter_versions_list(required_fw_versions, version_list, env_setup): """ Please do not add parameter type annotations (param:type). Because we import this file while checking Python version. Python 2.x will fail with no clear message on type annotations. - Parsing requirements versions + Parsing requirements versions for a dependency and filtering out requirements that + satisfy environment setup such as python version. + if environment version (python_version, etc.) is satisfied :param required_fw_versions: String with fw versions from requirements file :param version_list: List for append + :param env_setup: a dictionary with environment setup :return: list of tuples of strings like (name_of_module, sign, version) + Examples of required_fw_versions: + 'tensorflow>=1.15.2,<2.0; python_version < "3.8"' + 'tensorflow>=2.0' + Returned object is: [('tensorflow', '>=', '1.2.0'), ('networkx', '==', '2.1'), ('numpy', None, None)] """ @@ -62,26 +69,57 @@ def parse_versions_list(required_fw_versions, version_list): line = required_fw_versions.strip('\n') line = line.strip(' ') if line == '': - return [] - splited_versions_by_conditions = re.split(r"==|>=|<=|>|<", line) + return version_list + splited_requirement = line.split(";") + + # check environment marker + if len(splited_requirement) > 1: + env_req = splited_requirement[1] + splited_env_req = re.split(r"==|>=|<=|>|<", env_req) + splited_env_req = [l.strip(',') for l in splited_env_req] + env_marker = splited_env_req[0].strip(' ') + if env_marker == 'python_version' and env_marker in env_setup: + installed_python_version = env_setup['python_version'] + env_req_version_list = [] + splited_required_versions = re.split(r",", env_req) + for i, l in enumerate(splited_required_versions): + for comparison in ['==', '>=', '<=', '<', '>']: + if comparison in l: + required_version = splited_env_req[i + 1].strip(' ').replace('"', '') + env_req_version_list.append((env_marker, comparison, required_version)) + break + not_satisfied_list = [] + for name, key, required_version in env_req_version_list: + version_check(name, installed_python_version, required_version, + key, not_satisfied_list, 0) + if len(not_satisfied_list) > 0: + # this python_version requirement is not satisfied to required environment + # and requirement for a dependency will be skipped + return version_list + else: + log.error("{} is unsupported environment marker and it will be ignored".format(env_marker), + extra={'is_warning': True}) + + # parse a requirement for a dependency + requirement = splited_requirement[0] + splited_versions_by_conditions = re.split(r"==|>=|<=|>|<", requirement) splited_versions_by_conditions = [l.strip(',') for l in splited_versions_by_conditions] if len(splited_versions_by_conditions) == 0: - return [] + return version_list if len(splited_versions_by_conditions) == 1: version_list.append((splited_versions_by_conditions[0], None, None)) else: - splited_required_versions= re.split(r",", line) + splited_required_versions= re.split(r",", requirement) for i, l in enumerate(splited_required_versions): - comparisons = ['==', '>=', '<=', '<', '>'] - for comparison in comparisons: + for comparison in ['==', '>=', '<=', '<', '>']: if comparison in l: version_list.append((splited_versions_by_conditions[0], comparison, splited_versions_by_conditions[i + 1])) break return version_list -def get_module_version_list_from_file(file_name): +def get_module_version_list_from_file(file_name, env_setup): """ Please do not add parameter type annotations (param:type). Because we import this file while checking Python version. @@ -89,6 +127,7 @@ def get_module_version_list_from_file(file_name): Reads file with requirements :param file_name: Name of the requirements file + :param env_setup: a dictionary with environment setup elements :return: list of tuples of strings like (name_of_module, sign, version) File content example: @@ -102,7 +141,7 @@ def get_module_version_list_from_file(file_name): req_dict = list() with open(file_name) as f: for line in f: - req_dict = parse_versions_list(line, req_dict) + req_dict = parse_and_filter_versions_list(line, req_dict, env_setup) return req_dict @@ -113,7 +152,7 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v, exit_cod Python 2.x will fail with no clear message on type annotations. Evaluates comparison of installed and required versions according to requirements file of one module. - If installed version does not satisfy requirements appends this module to not_stisfied_v list. + If installed version does not satisfy requirements appends this module to not_satisfied_v list. :param name: module name :param installed_v: installed version of module :param required_v: required version of module @@ -146,6 +185,25 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v, exit_cod return exit_code +def get_environment_setup(): + """ + Get environment setup such as Python version, TensorFlow version + :return: a dictionary of environment variables + """ + env_setup = dict() + python_version = "{}.{}.{}".format(sys.version_info.major, + sys.version_info.minor, + sys.version_info.micro) + env_setup['python_version'] = python_version + try: + exec("import tensorflow") + env_setup['tensorflow'] = sys.modules["tensorflow"].__version__ + exec("del tensorflow") + except (AttributeError, ImportError): + pass + return env_setup + + def check_requirements(framework=None): """ Please do not add parameter type annotations (param:type). @@ -158,13 +216,20 @@ def check_requirements(framework=None): :param framework: framework name :return: exit code (0 - execution successful, 1 - error) """ + env_setup = get_environment_setup() if framework is None: framework_suffix = "" + elif framework == "tf": + if "tensorflow" in env_setup and env_setup["tensorflow"] >= LooseVersion("2.0.0"): + framework_suffix = "_tf2" + else: + framework_suffix = "_tf" else: framework_suffix = "_{}".format(framework) + file_name = "requirements{}.txt".format(framework_suffix) requirements_file = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, file_name)) - requirements_list = get_module_version_list_from_file(requirements_file) + requirements_list = get_module_version_list_from_file(requirements_file, env_setup) not_satisfied_versions = [] exit_code = 0 for name, key, required_version in requirements_list: diff --git a/model-optimizer/mo/utils/versions_checker_test.py b/model-optimizer/mo/utils/versions_checker_test.py index 227b74ee0086d8..35346d8fba4ac2 100644 --- a/model-optimizer/mo/utils/versions_checker_test.py +++ b/model-optimizer/mo/utils/versions_checker_test.py @@ -18,7 +18,7 @@ import unittest.mock as mock from unittest.mock import mock_open -from mo.utils.versions_checker import get_module_version_list_from_file, parse_versions_list +from mo.utils.versions_checker import get_module_version_list_from_file, parse_and_filter_versions_list class TestingVersionsChecker(unittest.TestCase): @@ -30,18 +30,51 @@ def test_get_module_version_list_from_file(self, mock_open): ref_list =[('mxnet', '>=', '1.0.0'), ('mxnet', '<=', '1.3.1'), ('networkx', '>=', '1.11'), ('numpy', '==', '1.12.0'), ('defusedxml', '<=', '0.5.0')] - version_list = get_module_version_list_from_file('mock_file') + version_list = get_module_version_list_from_file('mock_file', {}) self.assertEqual(len(version_list), 5) for i, version_dict in enumerate(version_list): self.assertTupleEqual(ref_list[i], version_dict) + @mock.patch('builtins.open', new_callable=mock_open, create=True) + def test_get_module_version_list_from_file2(self, mock_open): + mock_open.return_value.__enter__ = mock_open + mock_open.return_value.__iter__ = mock.Mock( + return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"', + 'tensorflow>=2.0; python_version >= "3.8"', + 'numpy==1.12.0', + 'defusedxml<=0.5.0'])) + ref_list =[('tensorflow', '>=', '1.15.2'), + ('tensorflow', '<', '2.0'), + ('numpy', '==', '1.12.0'), + ('defusedxml', '<=', '0.5.0')] + version_list = get_module_version_list_from_file('mock_file', {'python_version': '3.7.0'}) + self.assertEqual(len(version_list), 4) + for i, version_dict in enumerate(version_list): + self.assertTupleEqual(ref_list[i], version_dict) + + @mock.patch('builtins.open', new_callable=mock_open, create=True) + def test_get_module_version_list_from_file3(self, mock_open): + mock_open.return_value.__enter__ = mock_open + mock_open.return_value.__iter__ = mock.Mock( + return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"', + 'tensorflow>=2.0; python_version >= "3.8"', + 'numpy==1.12.0', + 'defusedxml<=0.5.0'])) + ref_list =[('tensorflow', '>=', '2.0'), + ('numpy', '==', '1.12.0'), + ('defusedxml', '<=', '0.5.0')] + version_list = get_module_version_list_from_file('mock_file', {'python_version': '3.8.1'}) + self.assertEqual(len(version_list), 3) + for i, version_dict in enumerate(version_list): + self.assertTupleEqual(ref_list[i], version_dict) + @mock.patch('builtins.open', new_callable=mock_open, create=True) def test_get_module_version_list_from_file_with_fw_name(self, mock_open): mock_open.return_value.__enter__ = mock_open mock_open.return_value.__iter__ = mock.Mock( return_value=iter(['mxnet'])) ref_list = [('mxnet', None, None)] - version_list = get_module_version_list_from_file('mock_file') + version_list = get_module_version_list_from_file('mock_file', {}) self.assertEqual(len(version_list), 1) for i, version_dict in enumerate(version_list): self.assertTupleEqual(ref_list[i], version_dict) @@ -49,7 +82,7 @@ def test_get_module_version_list_from_file_with_fw_name(self, mock_open): def test_append_version_list(self): v1 = 'mxnet>=1.0.0,<=1.3.1' req_list = list() - parse_versions_list(v1, req_list) + parse_and_filter_versions_list(v1, req_list, {}) ref_list = [('mxnet', '>=', '1.0.0'), ('mxnet', '<=', '1.3.1')] for i, v in enumerate(req_list): diff --git a/model-optimizer/requirements.txt b/model-optimizer/requirements.txt index e8069df734d5d8..137b4113b3c82c 100644 --- a/model-optimizer/requirements.txt +++ b/model-optimizer/requirements.txt @@ -1,4 +1,5 @@ -tensorflow>=1.15.2,<2.0 +tensorflow>=1.15.2,<2.0; python_version < "3.8" +tensorflow>=2.0; python_version >= "3.8" mxnet>=1.0.0,<=1.5.1 networkx>=1.11 numpy>=1.13.0 diff --git a/model-optimizer/requirements_tf.txt b/model-optimizer/requirements_tf.txt index ef7e24ed235ac1..a22cd69ac7b731 100644 --- a/model-optimizer/requirements_tf.txt +++ b/model-optimizer/requirements_tf.txt @@ -1,4 +1,5 @@ -tensorflow>=1.15.2,<2.0 +tensorflow>=1.15.2,<2.0; python_version < "3.8" +tensorflow>=2.0; python_version >= "3.8" networkx>=1.11 numpy>=1.13.0 test-generator==0.1.1