Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use packaging.version.parse instead of distutils.version.LooseVersion #17173

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/contrib/msc/core/utils/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""tvm.contrib.msc.core.utils.info"""

from typing import List, Tuple, Dict, Any, Union
from distutils.version import LooseVersion
from packaging.version import parse
import numpy as np

import tvm
Expand Down Expand Up @@ -409,8 +409,8 @@ def get_version(framework: str) -> List[int]:
raw_version = "1.0.0"
except: # pylint: disable=bare-except
raw_version = "1.0.0"
raw_version = raw_version or "1.0.0"
return LooseVersion(raw_version).version
version = parse(raw_version or "1.0.0")
return [version.major, version.minor, version.micro]


def compare_version(given_version: List[int], target_version: List[int]) -> int:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def is_version_greater_than(ver):
than the one given as an argument.
"""
import torch
from distutils.version import LooseVersion
from packaging.version import parse

torch_ver = torch.__version__
# PT version numbers can include +cu[cuda version code]
# and we don't want to include that in the comparison
if "+cu" in torch_ver:
torch_ver = torch_ver.split("+cu")[0]

return LooseVersion(torch_ver) > ver
return parse(torch_ver) > parse(ver)


def getattr_attr_name(node):
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=invalid-name, unused-argument
"""Arm(R) Ethos(TM)-N NPU supported operators."""
from enum import Enum
from distutils.version import LooseVersion
from packaging.version import parse

import tvm.ir
from tvm.relay import transform
Expand Down Expand Up @@ -118,7 +118,7 @@ def partition_for_ethosn(mod, params=None, **opts):
"""
api_version = ethosn_api_version()
supported_api_versions = ["3.2.0"]
if all(api_version != LooseVersion(exp_ver) for exp_ver in supported_api_versions):
if all(parse(api_version) != parse(exp_ver) for exp_ver in supported_api_versions):
raise ValueError(
f"Driver stack version {api_version} is unsupported. "
f"Please use version in {supported_api_versions}."
Expand Down Expand Up @@ -433,7 +433,7 @@ def split(expr):
"""Check if a split is supported by Ethos-N."""
if not ethosn_available():
return False
if ethosn_api_version() == LooseVersion("3.0.1"):
if parse(ethosn_api_version()) == parse("3.0.1"):
return False
if not _ethosn.split(expr):
return False
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Common utilities for creating TFLite models"""
from distutils.version import LooseVersion
from packaging.version import parse
import numpy as np
import pytest
import tflite.Model # pylint: disable=wrong-import-position
Expand Down Expand Up @@ -134,7 +134,7 @@ def generate_reference_data(self):
assert self.serial_model is not None, "TFLite model was not created."

output_tolerance = None
if tf.__version__ < LooseVersion("2.5.0"):
if parse(tf.__version__) < parse("2.5.0"):
output_tolerance = 1
interpreter = tf.lite.Interpreter(model_content=self.serial_model)
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/python/contrib/test_arm_compute_lib/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Arm Compute Library network tests."""

from distutils.version import LooseVersion
from packaging.version import parse

import numpy as np
import pytest
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_model():
mod, params = _get_keras_model(mobilenet, inputs)
return mod, params, inputs

if keras.__version__ < LooseVersion("2.9"):
if parse(keras.__version__) < parse("2.9"):
# This can be removed after we migrate to TF/Keras >= 2.9
expected_tvm_ops = 56
expected_acl_partitions = 31
Expand Down
9 changes: 4 additions & 5 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
This article is a test script to test tensorflow operator with Relay.
"""
from __future__ import print_function
from distutils.version import LooseVersion

import threading
import platform
Expand Down Expand Up @@ -1755,7 +1754,7 @@ def _test_concat_v2(shape1, shape2, dim):


def test_forward_concat_v2():
if tf.__version__ < LooseVersion("1.4.1"):
if package_version.parse(tf.__version__) < package_version.parse("1.4.1"):
return

_test_concat_v2([2, 3], [2, 3], 0)
Expand Down Expand Up @@ -3128,7 +3127,7 @@ def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype)

def test_forward_clip_by_value():
"""test ClipByValue op"""
if tf.__version__ < LooseVersion("1.9"):
if package_version.parse(tf.__version__) < package_version.parse("1.9"):
_test_forward_clip_by_value((4,), 0.1, 5.0, "float32")
_test_forward_clip_by_value((4, 4), 1, 5, "int32")

Expand Down Expand Up @@ -4482,7 +4481,7 @@ def _test_forward_zeros_like(in_shape, dtype):


def test_forward_zeros_like():
if tf.__version__ < LooseVersion("1.2"):
if package_version.parse(tf.__version__) < package_version.parse("1.2"):
_test_forward_zeros_like((2, 3), "int32")
_test_forward_zeros_like((2, 3, 5), "int8")
_test_forward_zeros_like((2, 3, 5, 7), "uint16")
Expand Down Expand Up @@ -5566,7 +5565,7 @@ def test_forward_spop():
# This test is expected to fail in TF version >= 2.6
# as the generated graph will be considered frozen, hence
# not passing the criteria for the test below.
if tf.__version__ < LooseVersion("2.6.1"):
if package_version.parse(tf.__version__) < package_version.parse("2.6.1"):
_test_spop_resource_variables()

# Placeholder test cases
Expand Down
19 changes: 9 additions & 10 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"""
from __future__ import print_function
from functools import partial
from distutils.version import LooseVersion
import platform
import os
import tempfile
Expand Down Expand Up @@ -1054,7 +1053,7 @@ def representative_data_gen():
input_node = subgraph.Tensors(model_input).Name().decode("utf-8")

tflite_output = run_tflite_graph(tflite_model_quant, data)
if tf.__version__ < LooseVersion("2.9"):
if package_version.parse(tf.__version__) < package_version.parse("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"
Expand Down Expand Up @@ -1775,7 +1774,7 @@ def representative_data_gen():

tflite_output = run_tflite_graph(tflite_model_quant, data)

if tf.__version__ < LooseVersion("2.9"):
if package_version.parse(tf.__version__) < package_version.parse("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"
Expand Down Expand Up @@ -2219,9 +2218,9 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
tflite_output = run_tflite_graph(tflite_model_quant, data)

# TFLite 2.6.x upgrade support
if tf.__version__ < LooseVersion("2.6.1"):
if package_version.parse(tf.__version__) < package_version.parse("2.6.1"):
in_node = ["serving_default_input_int8"]
elif tf.__version__ < LooseVersion("2.9"):
elif package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = (
["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
)
Expand All @@ -2245,7 +2244,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
"""One iteration of rsqrt"""

# tensorflow version upgrade support
if tf.__version__ < LooseVersion("2.6.1") or not quantized:
if package_version.parse(tf.__version__) < package_version.parse("2.6.1") or not quantized:
return _test_unary_elemwise(
math_ops.rsqrt, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype
)
Expand All @@ -2254,7 +2253,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
if tf.__version__ < LooseVersion("2.9"):
if package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
Expand Down Expand Up @@ -2338,7 +2337,7 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
tf.math.cos, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
if tf.__version__ < LooseVersion("2.9"):
if package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
Expand Down Expand Up @@ -3396,7 +3395,7 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)

tflite_output = run_tflite_graph(tflite_model_quant, data)
if tf.__version__ < LooseVersion("2.9"):
if package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
Expand Down Expand Up @@ -3426,7 +3425,7 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)

tflite_output = run_tflite_graph(tflite_model_quant, data)
if tf.__version__ < LooseVersion("2.9"):
if package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
Expand Down
Loading