Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sys_platform environment marker to version checker #5437

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions model-optimizer/mo/utils/versions_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def parse_and_filter_versions_list(required_fw_versions, version_list, env_setup
# check environment marker
if len(splited_requirement) > 1:
achetver marked this conversation as resolved.
Show resolved Hide resolved
env_req = splited_requirement[1]
splited_env_req = re.split(r"==|>=|<=|>|<|~=", env_req)
if any([x in splited_requirement[1] for x in [' and ', ' or ']]):
log.error("The version checker doesn't support environment marker combination and it will be ignored: {}"
"".format(splited_requirement[1]), extra={'is_warning': True})
return version_list
splited_env_req = re.split(r"==|>=|<=|>|<|~=|!=", env_req)
achetver marked this conversation as resolved.
Show resolved Hide resolved
splited_env_req = [l.strip(',') for l in splited_env_req]
env_marker = splited_env_req[0].strip(' ')
if env_marker == 'python_version' and env_marker in env_setup:
Expand All @@ -82,14 +86,28 @@ def parse_and_filter_versions_list(required_fw_versions, version_list, env_setup
# this python_version requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
elif env_marker == 'sys_platform' and env_marker in env_setup:
splited_env_req[1] = splited_env_req[1].strip(' ').replace('\'', '')
achetver marked this conversation as resolved.
Show resolved Hide resolved
achetver marked this conversation as resolved.
Show resolved Hide resolved
if '==' in env_req:
if not env_setup['sys_platform'] == splited_env_req[1]:
achetver marked this conversation as resolved.
Show resolved Hide resolved
# this sys_platform requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
elif '!=' in env_req:
if not env_setup['sys_platform'] != splited_env_req[1]:
achetver marked this conversation as resolved.
Show resolved Hide resolved
# this sys_platform requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
else:
log.error("Error during platform version check")
achetver marked this conversation as resolved.
Show resolved Hide resolved
else:
log.error("{} is unsupported environment marker and it will be ignored".format(env_marker),
extra={'is_warning': True})

# parse a requirement for a dependency
requirement = splited_requirement[0]
splited_versions_by_conditions = re.split(r"==|>=|<=|>|<|~=", requirement)
splited_versions_by_conditions = [l.strip(',') for l in splited_versions_by_conditions]
splited_versions_by_conditions = [l.strip(',').strip(' ') for l in splited_versions_by_conditions]

if len(splited_versions_by_conditions) == 0:
return version_list
Expand Down Expand Up @@ -127,6 +145,9 @@ def get_module_version_list_from_file(file_name, env_setup):
req_dict = list()
with open(file_name) as f:
for line in f:
# handle comments
line = line.split('#')[0]

req_dict = parse_and_filter_versions_list(line, req_dict, env_setup)
return req_dict

Expand Down Expand Up @@ -191,6 +212,7 @@ def get_environment_setup():
exec("del tensorflow")
except (AttributeError, ImportError):
pass
env_setup['sys_platform'] = sys.platform
achetver marked this conversation as resolved.
Show resolved Hide resolved
return env_setup


Expand Down
46 changes: 44 additions & 2 deletions model-optimizer/unit_tests/mo/utils/versions_checker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def test_get_module_version_list_from_file2(self, mock_open):
def test_get_module_version_list_from_file3(self, mock_open):
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.__iter__ = mock.Mock(
return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"',
'tensorflow>=2.0; python_version >= "3.8"',
return_value=iter(['# Commented line',
'tensorflow>=1.15.2,<2.0; python_version < "3.8"',
'tensorflow>=2.0; python_version >= "3.8" # Comment after line',
'numpy==1.12.0',
'defusedxml<=0.5.0',
'networkx~=1.11']))
Expand Down Expand Up @@ -86,6 +87,47 @@ def test_append_version_list(self):
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_neg_1(self):
v1 = "mxnet>=1.7.0 ; sys_platform != 'win32'"
req_list = list()
achetver marked this conversation as resolved.
Show resolved Hide resolved
parse_and_filter_versions_list(v1, req_list, {'sys_platform': 'darwin'})
ref_list = [('mxnet', '>=', '1.7.0')]
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_neg_2(self):
v1 = "mxnet>=1.7.0 ; sys_platform != 'win32'"
req_list = list()
parse_and_filter_versions_list(v1, req_list, {'sys_platform': 'win32'})
ref_list = []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys(self):
v1 = "mxnet>=1.7.0 ; sys_platform == 'linux'"
req_list = list()

parse_and_filter_versions_list(v1, req_list, {'sys_platform': 'linux'})
ref_list = [('mxnet', '>=', '1.7.0')]
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_python_ver_1(self):
v1 = "mxnet>=1.7.0 ; sys_platform == 'linux' or python_version >= \"3.8\""
req_list = list()
parse_and_filter_versions_list(v1, req_list, {'python_version': '3.8.1', 'sys_platform': 'linux'})
ref_list = []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_append_version_list_sys_python_ver_2(self):
v1 = "mxnet>=1.7.0 ; sys_platform == 'linux' and python_version >= \"3.8\""
req_list = list()
parse_and_filter_versions_list(v1, req_list, {'python_version': '3.7.1', 'sys_platform': 'linux'})
ref_list = []
for i, v in enumerate(req_list):
self.assertEqual(v, ref_list[i])

def test_version_check_equal(self):
modules_versions_list = [('module_1', '==', '2.0', '2.0'),
('module_2', '==', '2.0', '2.0.1'),
Expand Down