diff --git a/pyproject.toml b/pyproject.toml index c83e0c3b86..01996bd327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [tool.mypy] -files = ["src/rez/"] +files = ["src/rez/", "src/rezplugins/"] exclude = [ '.*/rez/data/.*', '.*/rez/vendor/.*', diff --git a/src/rez/build_system.py b/src/rez/build_system.py index 08ad023a2c..0f741c794f 100644 --- a/src/rez/build_system.py +++ b/src/rez/build_system.py @@ -6,7 +6,7 @@ import argparse import os.path -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING from rez.build_process import BuildType @@ -15,16 +15,20 @@ from rez.rex_bindings import VariantBinding if TYPE_CHECKING: + from typing import TypedDict # not available until python 3.8 from rez.developer_package import DeveloperPackage from rez.resolved_context import ResolvedContext from rez.packages import Package, Variant from rez.rex import RexExecutor + # FIXME: move this out of TYPE_CHECKING block when python 3.7 support is dropped + class BuildResult(TypedDict, total=False): + success: bool + extra_files: list[str] + build_env_script: str -class BuildResult(TypedDict): - success: bool - extra_files: list[str] - build_env_script: str +else: + BuildResult = dict def get_buildsys_types(): diff --git a/src/rez/config.py b/src/rez/config.py index 52ba8510fa..56b8cf6a45 100644 --- a/src/rez/config.py +++ b/src/rez/config.py @@ -17,6 +17,7 @@ from rez.vendor.schema.schema import Schema, SchemaError, And, Or, Use from rez.vendor import yaml from rez.vendor.yaml.error import YAMLError +from rez.utils.typing import Protocol import rez.deprecations from contextlib import contextmanager from functools import lru_cache @@ -24,7 +25,7 @@ import os import re import copy -from typing import Protocol, TYPE_CHECKING +from typing import TYPE_CHECKING class Validatable(Protocol): diff --git a/src/rez/deprecations.py b/src/rez/deprecations.py index 40b3f5e315..0cd47bc456 100644 --- a/src/rez/deprecations.py +++ b/src/rez/deprecations.py @@ -20,7 +20,7 @@ def warn(message, category, pre_formatted=False, stacklevel=1, filename=None, ** original_formatwarning = warnings.formatwarning if pre_formatted: - def formatwarning(_, category, *args, **kwargs): + def formatwarning(_, category, *args, **kwargs) -> str: return "{0}{1}: {2}\n".format( "{0}: ".format(filename) if filename else "", category.__name__, message ) diff --git a/src/rez/package_filter.py b/src/rez/package_filter.py index c371cf2bfc..049867ba03 100644 --- a/src/rez/package_filter.py +++ b/src/rez/package_filter.py @@ -2,12 +2,15 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.packages import iter_packages from rez.exceptions import ConfigurationError from rez.config import config from rez.utils.data_utils import cached_property, cached_class_property from rez.version import VersionedObject, Requirement from hashlib import sha1 +from typing import Pattern import fnmatch import re @@ -328,6 +331,7 @@ class Rule(object): #: Rule name name: str + _family: str | None def match(self, package): """Apply the rule to the package. @@ -340,7 +344,7 @@ def match(self, package): """ raise NotImplementedError - def family(self): + def family(self) -> str | None: """Returns a package family string if this rule only applies to a given package family, otherwise None. @@ -412,7 +416,7 @@ def _parse_label(cls, txt): return None, txt @classmethod - def _extract_family(cls, txt): + def _extract_family(cls, txt) -> str | None: m = cls.family_re.match(txt) if m: return m.group()[:-1] @@ -426,7 +430,10 @@ def __repr__(self): class RegexRuleBase(Rule): - def match(self, package): + regex: Pattern[str] + txt: str + + def match(self, package) -> bool: return bool(self.regex.match(package.qualified_name)) def cost(self): @@ -448,7 +455,7 @@ class RegexRule(RegexRuleBase): """ name = "regex" - def __init__(self, s): + def __init__(self, s: str): """Create a regex rule. Args: @@ -466,7 +473,7 @@ class GlobRule(RegexRuleBase): """ name = "glob" - def __init__(self, s): + def __init__(self, s: str): """Create a glob rule. Args: diff --git a/src/rez/package_order.py b/src/rez/package_order.py index e6e79dd928..de1d548fbb 100644 --- a/src/rez/package_order.py +++ b/src/rez/package_order.py @@ -6,13 +6,14 @@ from inspect import isclass from hashlib import sha1 -from typing import Any, Callable, Iterable, List, Protocol, TYPE_CHECKING +from typing import Any, Callable, Iterable, List, TYPE_CHECKING from rez.config import config from rez.utils.data_utils import cached_class_property from rez.version import Version, VersionRange from rez.version._version import _Comparable, _ReversedComparable, _LowerBound, _UpperBound, _Bound from rez.packages import iter_packages, Package +from rez.utils.typing import SupportsLessThan if TYPE_CHECKING: # this is not available in typing until 3.11, but due to __future__.annotations @@ -22,11 +23,6 @@ ALL_PACKAGES = "*" -class SupportsLessThan(Protocol): - def __lt__(self, __other: Any) -> bool: - pass - - class FallbackComparable(_Comparable): """First tries to compare objects using the main_comparable, but if that fails, compares using the fallback_comparable object. diff --git a/src/rez/package_repository.py b/src/rez/package_repository.py index 11ae0c1462..1a323593e8 100644 --- a/src/rez/package_repository.py +++ b/src/rez/package_repository.py @@ -333,8 +333,10 @@ def on_variant_install_cancelled(self, variant_resource: VariantResource): """ pass - def install_variant(self, variant_resource: VariantResource, - dry_run: bool = False, overrides: dict[str, Any] = None) -> VariantResource: + def install_variant(self, + variant_resource: VariantResource, + dry_run: bool = False, + overrides: dict[str, Any] | None = None) -> VariantResource: """Install a variant into this repository. Use this function to install a variant from some other package repository diff --git a/src/rez/packages.py b/src/rez/packages.py index 2c528a15e2..087ede43eb 100644 --- a/src/rez/packages.py +++ b/src/rez/packages.py @@ -23,9 +23,10 @@ import os import sys -from typing import overload, Any, Iterator, Literal, TYPE_CHECKING +from typing import overload, Any, Iterator, TYPE_CHECKING if TYPE_CHECKING: + from typing import Literal # not available in typing module until 3.8 from rez.developer_package import DeveloperPackage from rez.version import Requirement from rez.package_repository import PackageRepository diff --git a/src/rez/resolver.py b/src/rez/resolver.py index 17c561465d..e4653f3794 100644 --- a/src/rez/resolver.py +++ b/src/rez/resolver.py @@ -363,6 +363,8 @@ def _set_cached_solve(self, solver_dict): release_times_dict = {} variant_states_dict = {} + assert self.resolved_packages_ is not None, \ + "self.resolved_packages_ is set in _set_result when status is 'solved'" for variant in self.resolved_packages_: time_ = get_last_release_time(variant.name, self.package_paths) diff --git a/src/rez/solver.py b/src/rez/solver.py index ad2047abd6..6fae203202 100644 --- a/src/rez/solver.py +++ b/src/rez/solver.py @@ -25,10 +25,11 @@ PackageFamilyNotFoundError, RezSystemError from rez.version import Version, VersionRange from rez.version import VersionedObject, Requirement, RequirementList +from rez.utils.typing import SupportsLessThan, Protocol from contextlib import contextmanager from enum import Enum from itertools import product, chain -from typing import Any, Callable, Generator, Protocol, Iterator, TypeVar, TYPE_CHECKING +from typing import Any, Callable, Generator, Iterator, TypeVar, TYPE_CHECKING import copy import time import sys @@ -43,16 +44,6 @@ T = TypeVar("T") -class SupportsLessThan(Protocol): - def __lt__(self, __other: Any) -> bool: - pass - - -class SupportsWrite(Protocol): - def write(self, __s: str) -> object: - pass - - # a hidden control for forcing to non-optimized solving mode. This is here as # first port of call for narrowing down the cause of a solver bug if we see one # diff --git a/src/rez/suite.py b/src/rez/suite.py index 7b84571e87..ef1a163aa4 100644 --- a/src/rez/suite.py +++ b/src/rez/suite.py @@ -12,12 +12,26 @@ from rez.vendor.yaml.error import YAMLError from rez.utils.yaml import dump_yaml from collections import defaultdict +from typing import TYPE_CHECKING import os import os.path import shutil import sys +if TYPE_CHECKING: + from typing import TypedDict + + # FIXME: move this out of TYPE_CHECKING block when python 3.7 support is dropped + class Tool(TypedDict): + tool_name: str + tool_alias: str + context_name: str + variant: int +else: + Tool = dict + + class Suite(object): """A collection of contexts. @@ -47,7 +61,7 @@ def __init__(self): self.next_priority = 1 self.tools = None - self.tool_conflicts = None + self.tool_conflicts: defaultdict[str, list[Tool]] | None = None self.hidden_tools = None @property @@ -725,7 +739,7 @@ def _update_tools(self): if alias is None: alias = "%s%s%s" % (prefix, tool_name, suffix) - entry = dict(tool_name=tool_name, + entry = Tool(tool_name=tool_name, tool_alias=alias, context_name=context_name, variant=variant) diff --git a/src/rez/utils/patching.py b/src/rez/utils/patching.py index 1252867bba..a422e9831b 100644 --- a/src/rez/utils/patching.py +++ b/src/rez/utils/patching.py @@ -42,8 +42,8 @@ def get_patched_request(requires, patchlist): '^': (True, True, True) } - requires = [Requirement(x) if not isinstance(x, Requirement) else x - for x in requires] + requires: list[Requirement | None] = [ + Requirement(x) if not isinstance(x, Requirement) else x for x in requires] appended = [] for patch in patchlist: diff --git a/src/rez/utils/platform_.py b/src/rez/utils/platform_.py index 0ba51315e1..5283cac79d 100644 --- a/src/rez/utils/platform_.py +++ b/src/rez/utils/platform_.py @@ -555,6 +555,7 @@ def _difftool(self): # singleton +# FIXME: is is valid for platform_ to be None? platform_ = None name = platform.system().lower() if name == "linux": diff --git a/src/rez/utils/schema.py b/src/rez/utils/schema.py index 36c22380e6..1ecefd2269 100644 --- a/src/rez/utils/schema.py +++ b/src/rez/utils/schema.py @@ -6,6 +6,7 @@ Utilities for working with dict-based schemas. """ from rez.vendor.schema.schema import Schema, Optional, Use, And +from rez.config import Validatable # an alias which just so happens to be the same number of characters as @@ -68,7 +69,7 @@ def _to(value): d[k] = _to(v) if allow_custom_keys: d[Optional(str)] = modifier or object - schema = Schema(d) + schema: Validatable = Schema(d) elif modifier: schema = And(value, modifier) else: diff --git a/src/rez/utils/typing.py b/src/rez/utils/typing.py new file mode 100644 index 0000000000..61e2b2ba6a --- /dev/null +++ b/src/rez/utils/typing.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright Contributors to the Rez Project + + +from __future__ import absolute_import, print_function + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + # FIXME: use typing.Protocol instead of this workaround when python 3.7 support is dropped + from typing import Protocol + +else: + class Protocol(object): + pass + + +class SupportsLessThan(Protocol): + def __lt__(self, __other: Any) -> bool: + pass diff --git a/src/rez/version/_version.py b/src/rez/version/_version.py index 7583b456d6..b9a2750990 100644 --- a/src/rez/version/_version.py +++ b/src/rez/version/_version.py @@ -294,8 +294,8 @@ def __init__(self, ver_str: str | None = '', make_token=AlphanumericVersionToken """ self.tokens: list[VersionToken] | None = [] self.seps = [] - self._str = None - self._hash = None + self._str: str | None = None + self._hash: int | None = None if ver_str: toks = re_token.findall(ver_str) @@ -323,6 +323,8 @@ def copy(self) -> Version: Version: """ other = Version(None) + if self.tokens is None: + raise RuntimeError("Version.inf cannot be copied") other.tokens = self.tokens[:] other.seps = self.seps[:] return other @@ -338,6 +340,8 @@ def trim(self, len_: int) -> Version: Version: """ other = Version(None) + if self.tokens is None: + raise RuntimeError("Version.inf cannot be trimmed") other.tokens = self.tokens[:len_] other.seps = self.seps[:len_ - 1] return other @@ -393,6 +397,9 @@ def as_tuple(self) -> tuple[str, ...]: Returns: tuple[str]: """ + if self.tokens is None: + # Version.inf + return () return tuple(map(str, self.tokens)) def __len__(self) -> int: @@ -1273,23 +1280,23 @@ def __contains__(self, version_or_range: Version | VersionRange) -> bool: def __len__(self) -> int: return len(self.bounds) - def __invert__(self): + def __invert__(self) -> VersionRange | None: return self.inverse() - def __and__(self, other): + def __and__(self, other) -> VersionRange | None: return self.intersection(other) - def __or__(self, other): + def __or__(self, other) -> VersionRange: return self.union(other) - def __add__(self, other): + def __add__(self, other) -> VersionRange: return self.union(other) - def __sub__(self, other): + def __sub__(self, other) -> VersionRange | None: inv = other.inverse() return None if inv is None else self.intersection(inv) - def __str__(self): + def __str__(self) -> str: if self._str is None: self._str = '|'.join(map(str, self.bounds)) return self._str @@ -1319,8 +1326,8 @@ def _union(cls, bounds: list[_Bound]) -> list[_Bound]: bounds_ = list(sorted(bounds)) new_bounds = [] - prev_bound = None - upper = None + prev_bound: _Bound | None = None + upper: _UpperBound | None = None start = 0 for i, bound in enumerate(bounds_): @@ -1351,21 +1358,21 @@ def _intersection(cls, bounds1: list[_Bound], bounds2: list[_Bound]) -> list[_Bo @classmethod def _inverse(cls, bounds: list[_Bound]) -> list[_Bound]: - lbounds = [None] - ubounds = [] + lbounds: list[_LowerBound | None] = [None] + ubounds: list[_UpperBound | None] = [] for bound in bounds: if not bound.lower.version and bound.lower.inclusive: ubounds.append(None) else: - b = _UpperBound(bound.lower.version, not bound.lower.inclusive) - ubounds.append(b) + ub = _UpperBound(bound.lower.version, not bound.lower.inclusive) + ubounds.append(ub) if bound.upper.version == Version.inf: lbounds.append(None) else: - b = _LowerBound(bound.upper.version, not bound.upper.inclusive) - lbounds.append(b) + lb = _LowerBound(bound.upper.version, not bound.upper.inclusive) + lbounds.append(lb) ubounds.append(None) new_bounds = [] @@ -1491,7 +1498,7 @@ def _next_non_intersecting(self) -> T: return value @property - def _bound(self): + def _bound(self) -> _Bound | None: if self.index < self.nbounds: return self.range_.bounds[self.index] else: