Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sys_platform environment marker to version checker #5437

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 46 additions & 25 deletions model-optimizer/mo/utils/versions_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,45 @@ def parse_and_filter_versions_list(required_fw_versions, version_list, env_setup

# check environment marker
if len(splited_requirement) > 1:
achetver marked this conversation as resolved.
Show resolved Hide resolved
env_req = splited_requirement[1]
splited_env_req = re.split(r"==|>=|<=|>|<|~=", env_req)
splited_env_req = [l.strip(',') for l in splited_env_req]
env_marker = splited_env_req[0].strip(' ')
if env_marker == 'python_version' and env_marker in env_setup:
installed_python_version = env_setup['python_version']
env_req_version_list = []
splited_required_versions = re.split(r",", env_req)
for i, l in enumerate(splited_required_versions):
for comparison in ['==', '>=', '<=', '<', '>', '~=']:
if comparison in l:
required_version = splited_env_req[i + 1].strip(' ').replace('"', '')
env_req_version_list.append((env_marker, comparison, required_version))
break
not_satisfied_list = []
for name, key, required_version in env_req_version_list:
version_check(name, installed_python_version, required_version,
key, not_satisfied_list)
if len(not_satisfied_list) > 0:
# this python_version requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
else:
log.error("{} is unsupported environment marker and it will be ignored".format(env_marker),
extra={'is_warning': True})
for env_req in splited_requirement[1:]:
splited_env_req = re.split(r"==|>=|<=|>|<|~=|!=", env_req)
splited_env_req = [l.strip(',') for l in splited_env_req]
env_marker = splited_env_req[0].strip(' ')
if env_marker == 'python_version' and env_marker in env_setup:
installed_python_version = env_setup['python_version']
env_req_version_list = []
splited_required_versions = re.split(r",", env_req)
for i, l in enumerate(splited_required_versions):
for comparison in ['==', '>=', '<=', '<', '>', '~=']:
if comparison in l:
required_version = splited_env_req[i + 1].strip(' ').replace('"', '')
env_req_version_list.append((env_marker, comparison, required_version))
break
not_satisfied_list = []
for name, key, required_version in env_req_version_list:
version_check(name, installed_python_version, required_version,
key, not_satisfied_list)
if len(not_satisfied_list) > 0:
# this python_version requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
elif env_marker == 'sys_platform' and env_marker in env_setup:
splited_env_req[1] = splited_env_req[1].strip(' ').replace('\'', '')
if '==' in env_req:
if not env_setup['sys_platform'] == splited_env_req[1]:
# this sys_platform requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
elif '!=' in env_req:
if not env_setup['sys_platform'] != splited_env_req[1]:
# this sys_platform requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
else:
log.error("Error during platform version check")
else:
log.error("{} is unsupported environment marker and it will be ignored".format(env_marker),
extra={'is_warning': True})

# parse a requirement for a dependency
requirement = splited_requirement[0]
Expand Down Expand Up @@ -127,6 +141,12 @@ def get_module_version_list_from_file(file_name, env_setup):
req_dict = list()
with open(file_name) as f:
for line in f:
if '#' in line:
# Handle comments
pos = line.find('#')
line = line[:pos]
if line == '':
continue
achetver marked this conversation as resolved.
Show resolved Hide resolved
req_dict = parse_and_filter_versions_list(line, req_dict, env_setup)
return req_dict

Expand Down Expand Up @@ -191,6 +211,7 @@ def get_environment_setup():
exec("del tensorflow")
except (AttributeError, ImportError):
pass
env_setup['sys_platform'] = sys.platform
achetver marked this conversation as resolved.
Show resolved Hide resolved
return env_setup


Expand Down
41 changes: 39 additions & 2 deletions model-optimizer/unit_tests/mo/utils/versions_checker_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest
import unittest.mock as mock
from unittest.mock import mock_open
Expand Down Expand Up @@ -52,8 +53,9 @@ def test_get_module_version_list_from_file2(self, mock_open):
def test_get_module_version_list_from_file3(self, mock_open):
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.__iter__ = mock.Mock(
return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"',
'tensorflow>=2.0; python_version >= "3.8"',
return_value=iter(['# Commented line',
'tensorflow>=1.15.2,<2.0; python_version < "3.8"',
'tensorflow>=2.0; python_version >= "3.8" # Comment after line',
'numpy==1.12.0',
'defusedxml<=0.5.0',
'networkx~=1.11']))
Expand Down Expand Up @@ -86,6 +88,41 @@ def test_append_version_list(self):
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_neg(self):
v1 = "mxnet>=1.7.0; sys_platform != 'win32'"
req_list = list()
achetver marked this conversation as resolved.
Show resolved Hide resolved
parse_and_filter_versions_list(v1, req_list, {'sys_platform': sys.platform})
ref_list = [('mxnet', '>=', '1.7.0')] if sys.platform != 'win32' else []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys(self):
v1 = "mxnet>=1.7.0; sys_platform == 'linux'"
req_list = list()
platform = sys.platform
achetver marked this conversation as resolved.
Show resolved Hide resolved
parse_and_filter_versions_list(v1, req_list, {'sys_platform': platform})
ref_list = [('mxnet', '>=', '1.7.0')] if platform == 'linux' else []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_python_ver_1(self):
v1 = "mxnet>=1.7.0; sys_platform == 'linux'; python_version >= \"3.8\""
req_list = list()
platform = sys.platform
parse_and_filter_versions_list(v1, req_list, {'python_version': '3.8.1', 'sys_platform': platform})
ref_list = [('mxnet', '>=', '1.7.0')] if platform == 'linux' else []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_python_ver_2(self):
v1 = "mxnet>=1.7.0; sys_platform == 'linux'; python_version >= \"3.8\""
req_list = list()
platform = sys.platform
parse_and_filter_versions_list(v1, req_list, {'python_version': '3.7.1', 'sys_platform': platform})
ref_list = []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_version_check_equal(self):
modules_versions_list = [('module_1', '==', '2.0', '2.0'),
('module_2', '==', '2.0', '2.0.1'),
Expand Down