From f8ae8bdefa38f8ae14de03d57ef087adff3d20fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Randy=20D=C3=B6ring?= <30527984+radoering@users.noreply.github.com> Date: Sun, 13 Mar 2022 08:22:24 +0100 Subject: [PATCH] Improve union of multi markers --- src/poetry/core/version/markers.py | 127 +++++++++++++++++++++-------- tests/packages/utils/test_utils.py | 15 ++-- tests/version/test_markers.py | 107 +++++++++++++++++++++--- 3 files changed, 199 insertions(+), 50 deletions(-) diff --git a/src/poetry/core/version/markers.py b/src/poetry/core/version/markers.py index 4149b810c..ba70f35c8 100644 --- a/src/poetry/core/version/markers.py +++ b/src/poetry/core/version/markers.py @@ -460,14 +460,51 @@ def intersect(self, other: BaseMarker) -> BaseMarker: return MultiMarker.of(*new_markers) def union(self, other: BaseMarker) -> BaseMarker: - if other in self._markers: - return other - if isinstance(other, (SingleMarker, MultiMarker)): return MarkerUnion.of(self, other) return other.union(self) + def union_simplify(self, other: BaseMarker) -> BaseMarker | None: + """ + In contrast to the standard union method, which prefers to return + a MarkerUnion of MultiMarkers, this version prefers to return + a MultiMarker of MarkerUnions. + + The rationale behind this approach is to find additional simplifications. + In order to avoid endless recursions, this method returns None + if it cannot find a simplification. + """ + if isinstance(other, SingleMarker): + new_markers = [] + for marker in self._markers: + union = marker.union(other) + if not union.is_any(): + new_markers.append(union) + + if len(new_markers) == 1: + return new_markers[0] + if other in new_markers and all( + other == m or isinstance(m, MarkerUnion) and other in m.markers + for m in new_markers + ): + return other + + elif isinstance(other, MultiMarker): + markers = set(self._markers) + other_markers = set(other.markers) + common_markers = markers & other_markers + unique_markers = markers - common_markers + other_unique_markers = other_markers - common_markers + if common_markers: + unique_union = self.of(*unique_markers).union( + self.of(*other_unique_markers) + ) + if not isinstance(unique_union, MarkerUnion): + return self.of(*common_markers).intersect(unique_union) + + return None + def validate(self, environment: dict[str, Any]) -> bool: return all(m.validate(environment) for m in self._markers) @@ -543,42 +580,68 @@ def markers(self) -> list[BaseMarker]: @classmethod def of(cls, *markers: BaseMarker) -> BaseMarker: - flattened_markers = _flatten_markers(markers, MarkerUnion) + new_markers = _flatten_markers(markers, MarkerUnion) + old_markers: list[BaseMarker] = [] - new_markers: list[BaseMarker] = [] - for marker in flattened_markers: - if marker in new_markers: - continue + while old_markers != new_markers: + old_markers = new_markers + new_markers = [] + for marker in old_markers: + if marker in new_markers or marker.is_empty(): + continue - if isinstance(marker, SingleMarker): included = False - for i, mark in enumerate(new_markers): - if isinstance(mark, SingleMarker) and ( - mark.name == marker.name - or ( - mark.name in PYTHON_VERSION_MARKERS - and marker.name in PYTHON_VERSION_MARKERS - ) - ): - union = mark.constraint.union(marker.constraint) - if union == mark.constraint: - included = True - break - elif union == marker.constraint: - new_markers[i] = marker - included = True - break - elif union.is_any(): - return AnyMarker() - elif isinstance(union, VersionConstraint) and union.is_simple(): - new_markers[i] = SingleMarker(mark.name, union) + + if isinstance(marker, SingleMarker): + for i, mark in enumerate(new_markers): + if isinstance(mark, SingleMarker) and ( + mark.name == marker.name + or ( + mark.name in PYTHON_VERSION_MARKERS + and marker.name in PYTHON_VERSION_MARKERS + ) + ): + constraint_union = mark.constraint.union(marker.constraint) + if constraint_union == mark.constraint: + included = True + break + elif constraint_union == marker.constraint: + new_markers[i] = marker + included = True + break + elif constraint_union.is_any(): + return AnyMarker() + elif ( + isinstance(constraint_union, VersionConstraint) + and constraint_union.is_simple() + ): + new_markers[i] = SingleMarker( + mark.name, constraint_union + ) + included = True + break + + elif isinstance(mark, MultiMarker): + union = mark.union_simplify(marker) + if union is not None: + new_markers[i] = union + included = True + break + + elif isinstance(marker, MultiMarker): + included = False + for i, mark in enumerate(new_markers): + union = marker.union_simplify(mark) + if union is not None: + new_markers[i] = union included = True break if included: - continue - - new_markers.append(marker) + # flatten again because union_simplify may return a union + new_markers = _flatten_markers(new_markers, MarkerUnion) + else: + new_markers.append(marker) if any(m.is_any() for m in new_markers): return AnyMarker() diff --git a/tests/packages/utils/test_utils.py b/tests/packages/utils/test_utils.py index 82ffd1935..4e01f766b 100644 --- a/tests/packages/utils/test_utils.py +++ b/tests/packages/utils/test_utils.py @@ -10,22 +10,27 @@ def test_convert_markers(): marker = parse_marker( - 'sys_platform == "win32" and python_version < "3.6" or sys_platform == "win32"' + 'sys_platform == "win32" and python_version < "3.6" or sys_platform == "linux"' ' and python_version < "3.6" and python_version >= "3.3" or sys_platform ==' - ' "win32" and python_version < "3.3"' + ' "darwin" and python_version < "3.3"' ) - converted = convert_markers(marker) - assert converted["python_version"] == [ [("<", "3.6")], [("<", "3.6"), (">=", "3.3")], [("<", "3.3")], ] - marker = parse_marker('python_version == "2.7" or python_version == "2.6"') + marker = parse_marker( + 'sys_platform == "win32" and python_version < "3.6" or sys_platform == "win32"' + ' and python_version < "3.6" and python_version >= "3.3" or sys_platform ==' + ' "win32" and python_version < "3.3"' + ) converted = convert_markers(marker) + assert converted["python_version"] == [[("<", "3.6")]] + marker = parse_marker('python_version == "2.7" or python_version == "2.6"') + converted = convert_markers(marker) assert converted["python_version"] == [[("==", "2.7")], [("==", "2.6")]] diff --git a/tests/version/test_markers.py b/tests/version/test_markers.py index f9fedbd8e..2fe0319c2 100644 --- a/tests/version/test_markers.py +++ b/tests/version/test_markers.py @@ -240,6 +240,60 @@ def test_single_marker_union_with_multi_duplicate(): assert str(union) == 'sys_platform == "darwin" and python_version >= "3.6"' +@pytest.mark.parametrize( + ("single_marker", "multi_marker", "expected"), + [ + ( + 'python_version >= "3.6"', + 'python_version >= "3.7" and sys_platform == "win32"', + 'python_version >= "3.6"', + ), + ( + 'sys_platform == "linux"', + 'sys_platform != "linux" and sys_platform != "win32"', + 'sys_platform != "win32"', + ), + ], +) +def test_single_marker_union_with_multi_is_single_marker( + single_marker: str, multi_marker: str, expected: str +): + m = parse_marker(single_marker) + union = m.union(parse_marker(multi_marker)) + assert str(union) == expected + + +def test_single_marker_union_with_multi_cannot_be_simplified(): + m = parse_marker('python_version >= "3.7"') + union = m.union(parse_marker('python_version >= "3.6" and sys_platform == "win32"')) + assert ( + str(union) + == 'python_version >= "3.6" and sys_platform == "win32" or python_version >=' + ' "3.7"' + ) + + +def test_single_marker_union_with_multi_is_union_of_single_markers(): + m = parse_marker('python_version >= "3.6"') + union = m.union(parse_marker('python_version < "3.6" and sys_platform == "win32"')) + assert str(union) == 'sys_platform == "win32" or python_version >= "3.6"' + + +def test_single_marker_union_with_multi_union_is_union_of_single_markers(): + m = parse_marker('python_version >= "3.6"') + union = m.union( + parse_marker( + 'python_version < "3.6" and sys_platform == "win32" or python_version <' + ' "3.6" and sys_platform == "linux"' + ) + ) + assert ( + str(union) + == 'sys_platform == "win32" or sys_platform == "linux" or python_version >=' + ' "3.6"' + ) + + def test_single_marker_union_with_union(): m = parse_marker('sys_platform == "darwin"') @@ -367,29 +421,60 @@ def test_multi_marker_intersect_with_multi_union_leads_to_empty_in_two_steps(): def test_multi_marker_union_multi(): m = parse_marker('sys_platform == "darwin" and implementation_name == "cpython"') - intersection = m.union( - parse_marker('python_version >= "3.6" and os_name == "Windows"') - ) + union = m.union(parse_marker('python_version >= "3.6" and os_name == "Windows"')) assert ( - str(intersection) + str(union) == 'sys_platform == "darwin" and implementation_name == "cpython" ' 'or python_version >= "3.6" and os_name == "Windows"' ) +def test_multi_marker_union_multi_is_single_marker(): + m = parse_marker('python_version >= "3" and sys_platform == "win32"') + m2 = parse_marker('sys_platform != "win32" and python_version >= "3"') + assert str(m.union(m2)) == 'python_version >= "3"' + assert str(m2.union(m)) == 'python_version >= "3"' + + +def test_multi_marker_union_multi_is_multi(): + m = parse_marker('python_version >= "3" and sys_platform == "win32"') + m2 = parse_marker( + 'python_version >= "3" and sys_platform != "win32" and sys_platform != "linux"' + ) + assert str(m.union(m2)) == 'python_version >= "3" and sys_platform != "linux"' + assert str(m2.union(m)) == 'python_version >= "3" and sys_platform != "linux"' + + def test_multi_marker_union_with_union(): m = parse_marker('sys_platform == "darwin" and implementation_name == "cpython"') - intersection = m.union( - parse_marker('python_version >= "3.6" or os_name == "Windows"') - ) + union = m.union(parse_marker('python_version >= "3.6" or os_name == "Windows"')) assert ( - str(intersection) + str(union) == 'python_version >= "3.6" or os_name == "Windows"' ' or sys_platform == "darwin" and implementation_name == "cpython"' ) +def test_multi_marker_union_with_multi_union_is_single_marker(): + m = parse_marker('sys_platform == "darwin" and python_version == "3"') + m2 = parse_marker( + 'sys_platform == "darwin" and python_version < "3" or sys_platform == "darwin"' + ' and python_version > "3"' + ) + assert str(m.union(m2)) == 'sys_platform == "darwin"' + assert str(m2.union(m)) == 'sys_platform == "darwin"' + + +def test_multi_marker_union_with_union_multi_is_single_marker(): + m = parse_marker('sys_platform == "darwin" and python_version == "3"') + m2 = parse_marker( + 'sys_platform == "darwin" and (python_version < "3" or python_version > "3")' + ) + assert str(m.union(m2)) == 'sys_platform == "darwin"' + assert str(m2.union(m)) == 'sys_platform == "darwin"' + + def test_marker_union(): m = parse_marker('sys_platform == "darwin" or implementation_name == "cpython"') @@ -440,11 +525,7 @@ def test_marker_union_intersect_single_with_overlapping_constraints(): m = parse_marker('sys_platform == "darwin" or python_version < "3.4"') intersection = m.intersect(parse_marker('sys_platform == "darwin"')) - assert ( - str(intersection) - == 'sys_platform == "darwin" or python_version < "3.4" and sys_platform ==' - ' "darwin"' - ) + assert str(intersection) == 'sys_platform == "darwin"' def test_marker_union_intersect_marker_union():