From 7736c43cc7b225aa1c9c6e70774800ca40cc8418 Mon Sep 17 00:00:00 2001 From: Aron Bierbaum Date: Fri, 2 Sep 2022 11:06:03 -0500 Subject: [PATCH] Improve type hints for str vs bytes --- lxml-stubs/cssselect.pyi | 2 - lxml-stubs/etree.pyi | 213 +++++++++++++++++++---------------- lxml-stubs/html/__init__.pyi | 8 +- 3 files changed, 122 insertions(+), 101 deletions(-) diff --git a/lxml-stubs/cssselect.pyi b/lxml-stubs/cssselect.pyi index eb059fa..edb3b5d 100644 --- a/lxml-stubs/cssselect.pyi +++ b/lxml-stubs/cssselect.pyi @@ -5,8 +5,6 @@ from lxml import etree # dummy for missing stubs def __getattr__(name) -> Any: ... -_DictAnyStr = Union[Dict[str, str], Dict[bytes, bytes]] - class CSSSelector(etree.XPath): def __init__( self, diff --git a/lxml-stubs/etree.pyi b/lxml-stubs/etree.pyi index b87af01..be366f6 100644 --- a/lxml-stubs/etree.pyi +++ b/lxml-stubs/etree.pyi @@ -3,9 +3,11 @@ # Any use of `Any` below means I couldn't figure out the type. from os import PathLike +import sys from typing import ( IO, Any, + AnyStr, Callable, Dict, Iterable, @@ -32,7 +34,11 @@ def __getattr__(name: str) -> Any: ... # unnecessary constraint. It seems reasonable to constrain each # List/Dict argument to use one type consistently, though, and it is # necessary in order to keep these brief. -_AnyStr = Union[str, bytes] +if sys.version_info[0] >= 3: + _StrResult = str +else: + _StrResult = Union[str, bytes] + _AnySmartStr = Union[ "_ElementUnicodeResult", "_PyElementUnicodeResult", "_ElementStringResult" ] @@ -44,21 +50,27 @@ _XPathObject = Union[ bool, float, _AnySmartStr, - _AnyStr, + _StrResult, List[ Union[ "_Element", _AnySmartStr, - _AnyStr, - Tuple[Optional[_AnyStr], Optional[_AnyStr]], + _StrResult, + Tuple[Optional[_StrResult], Optional[_StrResult]], ] ], ] _AnyParser = Union["XMLParser", "HTMLParser"] -_ListAnyStr = Union[List[str], List[bytes]] -_DictAnyStr = Union[Dict[str, str], Dict[bytes, bytes]] -_Dict_Tuple2AnyStr_Any = Union[Dict[Tuple[str, str], Any], Tuple[bytes, bytes], Any] -_xpath = Union["XPath", _AnyStr] +if sys.version_info[0] >= 3: + _ListAnyStr = List[str] + _DictAnyStr = Dict[str, str] +else: + _ListAnyStr = Union[List[str], List[bytes]] + _DictAnyStr = Union[Dict[str, str], Dict[bytes, bytes]] +_StrOrBytes = Union[str, bytes] +_ValueType = Union[str, bytes, QName] +_InputDictAnyStr = Dict[_StrOrBytes, _StrOrBytes] +_ExtensionsDict = Dict[Tuple[_StrOrBytes, _StrOrBytes], Any] # See https://github.com/python/typing/pull/273 # Due to Mapping having invariant key types, Mapping[Union[A, B], ...] @@ -81,7 +93,7 @@ _KnownEncodings = Literal[ "us-ascii", ] _ElementOrTree = Union[_Element, _ElementTree] -_FileSource = Union[_AnyStr, IO[Any], PathLike[Any]] +_FileSource = Union[_StrOrBytes, IO[AnyStr], PathLike[str]] class ElementChildIterator(Iterator["_Element"]): def __iter__(self) -> "ElementChildIterator": ... @@ -91,21 +103,21 @@ class _ElementUnicodeResult(str): is_attribute: bool is_tail: bool is_text: bool - attrname: Optional[_AnyStr] + attrname: Optional[_StrResult] def getparent(self) -> Optional["_Element"]: ... class _PyElementUnicodeResult(str): is_attribute: bool is_tail: bool is_text: bool - attrname: Optional[_AnyStr] + attrname: Optional[_StrResult] def getparent(self) -> Optional["_Element"]: ... class _ElementStringResult(bytes): is_attribute: bool is_tail: bool is_text: bool - attrname: Optional[_AnyStr] + attrname: Optional[_StrResult] def getparent(self) -> Optional["_Element"]: ... class DocInfo: @@ -152,9 +164,9 @@ class _Element(Iterable["_Element"], Sized): ) -> List["_Element"]: ... def clear(self) -> None: ... @overload - def get(self, key: _TagName) -> Optional[str]: ... + def get(self, key: _TagName) -> Optional[_StrResult]: ... @overload - def get(self, key: _TagName, default: _T) -> Union[str, _T]: ... + def get(self, key: _TagName, default: _T) -> Union[_StrResult, _T]: ... def getnext(self) -> Optional[_Element]: ... def getparent(self) -> Optional[_Element]: ... def getprevious(self) -> Optional[_Element]: ... @@ -163,7 +175,7 @@ class _Element(Iterable["_Element"], Sized): self, child: _Element, start: Optional[int] = ..., stop: Optional[int] = ... ) -> int: ... def insert(self, index: int, element: _Element) -> None: ... - def items(self) -> Sequence[Tuple[_AnyStr, _AnyStr]]: ... + def items(self) -> Sequence[Tuple[_StrResult, _StrResult]]: ... def iter( self, tag: Optional[_TagSelector] = ..., *tags: _TagSelector ) -> Iterator[_Element]: ... @@ -189,35 +201,35 @@ class _Element(Iterable["_Element"], Sized): tag: Optional[_TagSelector] = ..., with_tail: bool = False, *tags: _TagSelector, - ) -> Iterator[_AnyStr]: ... - def keys(self) -> Sequence[_AnyStr]: ... + ) -> Iterator[_StrResult]: ... + def keys(self) -> Sequence[_StrResult]: ... def makeelement( self, _tag: _TagName, - attrib: Optional[_DictAnyStr] = ..., + attrib: Optional[_InputDictAnyStr] = ..., nsmap: Optional[_NSMapArg] = ..., **_extra: Any, ) -> _Element: ... def remove(self, element: _Element) -> None: ... def replace(self, old_element: _Element, new_element: _Element) -> None: ... - def set(self, key: _TagName, value: _AnyStr) -> None: ... - def values(self) -> Sequence[_AnyStr]: ... + def set(self, key: _TagName, value: _ValueType) -> None: ... + def values(self) -> Sequence[_StrResult]: ... def xpath( self, - _path: _AnyStr, + _path: _StrOrBytes, namespaces: Optional[_NonDefaultNSMapArg] = ..., extensions: Any = ..., smart_strings: bool = ..., **_variables: _XPathObject, ) -> _XPathObject: ... - tag = ... # type: str + tag = ... # type: _StrResult attrib = ... # type: _Attrib - text = ... # type: Optional[str] - tail = ... # type: Optional[str] - prefix = ... # type: str + text = ... # type: Optional[_StrResult] + tail = ... # type: Optional[_StrResult] + prefix = ... # type: Optional[_StrResult] sourceline = ... # Optional[int] @property - def nsmap(self) -> Dict[Optional[str], str]: ... + def nsmap(self) -> Dict[Optional[_StrResult], Optional[_StrResult]]: ... base = ... # type: Optional[str] class ElementBase(_Element): ... @@ -250,34 +262,36 @@ class _ElementTree: self, source: _FileSource, parser: Optional[_AnyParser] = ..., - base_url: Optional[_AnyStr] = ..., + base_url: Optional[_StrOrBytes] = ..., ) -> _Element: ... def write( self, file: _FileSource, - encoding: _AnyStr = ..., - method: _AnyStr = ..., + encoding: Optional[_StrOrBytes] = ..., + method: str = ..., pretty_print: bool = ..., xml_declaration: Any = ..., with_tail: Any = ..., standalone: bool = ..., + doctype: _StrOrBytes = ..., compression: int = ..., exclusive: bool = ..., + inclusive_ns_prefixes: List[Union[str, bytes]] = ..., with_comments: bool = ..., - inclusive_ns_prefixes: _ListAnyStr = ..., + strip_text: bool = ..., ) -> None: ... def write_c14n( self, file: _FileSource, with_comments: bool = ..., compression: int = ..., - inclusive_ns_prefixes: Iterable[_AnyStr] = ..., + inclusive_ns_prefixes: Iterable[_StrOrBytes] = ..., ) -> None: ... def _setroot(self, root: _Element) -> None: ... def xinclude(self) -> None: ... def xpath( self, - _path: _AnyStr, + _path: _StrOrBytes, namespaces: Optional[_NonDefaultNSMapArg] = ..., extensions: Any = ..., smart_strings: bool = ..., @@ -286,7 +300,7 @@ class _ElementTree: def xslt( self, _xslt: XSLT, - extensions: Optional[_Dict_Tuple2AnyStr_Any] = ..., + extensions: Optional[_ExtensionsDict] = ..., access_control: Optional[XSLTAccessControl] = ..., **_variables: Any, ) -> _ElementTree: ... @@ -295,35 +309,41 @@ class __ContentOnlyEleement(_Element): ... class _Comment(__ContentOnlyEleement): ... class _ProcessingInstruction(__ContentOnlyEleement): - target: _AnyStr + target: _StrResult class _Attrib: - def __setitem__(self, key: _AnyStr, value: _AnyStr) -> None: ... - def __delitem__(self, key: _AnyStr) -> None: ... + def __setitem__(self, key: _TagName, value: _ValueType) -> None: ... + def __delitem__(self, key: _TagName) -> None: ... def update( self, sequence_or_dict: Union[ - _Attrib, Mapping[_AnyStr, _AnyStr], Sequence[Tuple[_AnyStr, _AnyStr]] + _Attrib, Mapping[_ValueType, _ValueType], Sequence[Tuple[_ValueType, _ValueType]] ], ) -> None: ... - def pop(self, key: _AnyStr, default: _AnyStr) -> _AnyStr: ... + @overload + def pop(self, key: _TagName) -> _StrResult: ... + @overload + def pop(self, key: _TagName, default: _T) -> Union[_StrResult, _T]: ... def clear(self) -> None: ... def __repr__(self) -> str: ... def __copy__(self) -> _DictAnyStr: ... def __deepcopy__(self, memo: Dict[Any, Any]) -> _DictAnyStr: ... - def __getitem__(self, key: _AnyStr) -> _AnyStr: ... + def __getitem__(self, key: _TagName) -> _StrResult: ... def __bool__(self) -> bool: ... def __len__(self) -> int: ... - def get(self, key: _AnyStr, default: _AnyStr = ...) -> Optional[_AnyStr]: ... + @overload + def get(self, key: _TagName) -> _StrResult: ... + @overload + def get(self, key: _TagName, default: _T) -> Union[_StrResult, _T]: ... def keys(self) -> _ListAnyStr: ... - def __iter__(self) -> Iterator[_AnyStr]: ... # actually _AttribIterator - def iterkeys(self) -> Iterator[_AnyStr]: ... + def __iter__(self) -> Iterator[_StrResult]: ... # actually _AttribIterator + def iterkeys(self) -> Iterator[_StrResult]: ... def values(self) -> _ListAnyStr: ... - def itervalues(self) -> Iterator[_AnyStr]: ... - def items(self) -> List[Tuple[_AnyStr, _AnyStr]]: ... - def iteritems(self) -> Iterator[Tuple[_AnyStr, _AnyStr]]: ... - def has_key(self, key: _AnyStr) -> bool: ... - def __contains__(self, key: _AnyStr) -> bool: ... + def itervalues(self) -> Iterator[_StrResult]: ... + def items(self) -> List[Tuple[_StrResult, _StrResult]]: ... + def iteritems(self) -> Iterator[Tuple[_StrResult, _StrResult]]: ... + def has_key(self, key: _TagName) -> bool: ... + def __contains__(self, key: _TagName) -> bool: ... def __richcmp__(self, other: _Attrib, op: int) -> bool: ... class QName: @@ -332,8 +352,8 @@ class QName: text = ... # type: str def __init__( self, - text_or_uri_element: Union[None, _AnyStr, _Element], - tag: Optional[_AnyStr] = ..., + text_or_uri_element: Union[None, _TagName, _Element], + tag: Optional[_TagName] = ..., ) -> None: ... class _XSLTResultTree(_ElementTree, SupportsBytes): @@ -343,11 +363,11 @@ class _XSLTQuotedStringParam: ... # https://lxml.de/parsing.html#the-target-parser-interface class ParserTarget(Protocol): - def comment(self, text: _AnyStr) -> None: ... + def comment(self, text: _StrResult) -> None: ... def close(self) -> Any: ... - def data(self, data: _AnyStr) -> None: ... - def end(self, tag: _AnyStr) -> None: ... - def start(self, tag: _AnyStr, attrib: Dict[_AnyStr, _AnyStr]) -> None: ... + def data(self, data: _StrResult) -> None: ... + def end(self, tag: _StrResult) -> None: ... + def start(self, tag: _StrResult, attrib: Dict[_StrResult, _StrResult]) -> None: ... class ElementClassLookup: ... @@ -367,7 +387,7 @@ class _BaseParser: def makeelement( self, _tag: _TagName, - attrib: Optional[Union[_DictAnyStr, _Attrib]] = ..., + attrib: Optional[Union[_InputDictAnyStr, _Attrib]] = ..., nsmap: Optional[_NSMapArg] = ..., **_extra: Any, ) -> _Element: ... @@ -381,12 +401,12 @@ class _BaseParser: class _FeedParser(_BaseParser): def __getattr__(self, name: str) -> Any: ... # Incomplete def close(self) -> _Element: ... - def feed(self, data: _AnyStr) -> None: ... + def feed(self, data: _StrOrBytes) -> None: ... class XMLParser(_FeedParser): def __init__( self, - encoding: Optional[_AnyStr] = ..., + encoding: Optional[_StrOrBytes] = ..., attribute_defaults: bool = ..., dtd_validation: bool = ..., load_dtd: bool = ..., @@ -409,7 +429,7 @@ class XMLParser(_FeedParser): class HTMLParser(_FeedParser): def __init__( self, - encoding: Optional[_AnyStr] = ..., + encoding: Optional[_StrOrBytes] = ..., collect_ids: bool = ..., compact: bool = ..., huge_tree: bool = ..., @@ -430,10 +450,10 @@ class _ResolverRegistry: class Resolver: def resolve(self, system_url: str, public_id: str): ... def resolve_file( - self, f: IO[Any], context: Any, *, base_url: Optional[_AnyStr], close: bool + self, f: IO[Any], context: Any, *, base_url: Optional[_StrOrBytes], close: bool ): ... def resolve_string( - self, string: _AnyStr, context: Any, *, base_url: Optional[_AnyStr] + self, string: _StrOrBytes, context: Any, *, base_url: Optional[_StrOrBytes] ): ... class XMLSchema(_Validator): @@ -450,32 +470,32 @@ class XSLT: def __init__( self, xslt_input: _ElementOrTree, - extensions: _Dict_Tuple2AnyStr_Any = ..., + extensions: Optional[_ExtensionsDict] = ..., regexp: bool = ..., - access_control: XSLTAccessControl = ..., + access_control: Optional[XSLTAccessControl] = ..., ) -> None: ... def __call__( self, _input: _ElementOrTree, profile_run: bool = ..., - **kwargs: Union[_AnyStr, _XSLTQuotedStringParam], + **kwargs: Union[_StrOrBytes, XPath, _XSLTQuotedStringParam], ) -> _XSLTResultTree: ... @staticmethod - def strparam(s: _AnyStr) -> _XSLTQuotedStringParam: ... + def strparam(s: _StrOrBytes) -> _XSLTQuotedStringParam: ... -def Comment(text: Optional[_AnyStr] = ...) -> _Comment: ... +def Comment(text: Optional[_StrOrBytes] = ...) -> _Comment: ... def Element( _tag: _TagName, - attrib: Optional[_DictAnyStr] = ..., + attrib: Optional[_InputDictAnyStr] = ..., nsmap: Optional[_NSMapArg] = ..., - **extra: _AnyStr, + **extra: _StrOrBytes, ) -> _Element: ... def SubElement( _parent: _Element, _tag: _TagName, - attrib: Optional[_DictAnyStr] = ..., + attrib: Optional[_InputDictAnyStr] = ..., nsmap: Optional[_NSMapArg] = ..., - **extra: _AnyStr, + **extra: _StrOrBytes, ) -> _Element: ... def ElementTree( element: _Element = ..., @@ -483,44 +503,44 @@ def ElementTree( parser: _AnyParser = ..., ) -> _ElementTree: ... def ProcessingInstruction( - target: _AnyStr, text: _AnyStr = ... + target: _StrOrBytes, text: Optional[_StrOrBytes] = ... ) -> _ProcessingInstruction: ... PI = ProcessingInstruction def HTML( - text: _AnyStr, + text: _StrOrBytes, parser: Optional[HTMLParser] = ..., - base_url: Optional[_AnyStr] = ..., + base_url: Optional[_StrOrBytes] = ..., ) -> _Element: ... def XML( - text: _AnyStr, + text: _StrOrBytes, parser: Optional[XMLParser] = ..., - base_url: Optional[_AnyStr] = ..., + base_url: Optional[_StrOrBytes] = ..., ) -> _Element: ... def cleanup_namespaces( tree_or_element: _ElementOrTree, top_nsmap: Optional[_NSMapArg] = ..., - keep_ns_prefixes: Optional[Iterable[_AnyStr]] = ..., + keep_ns_prefixes: Optional[Iterable[_StrOrBytes]] = ..., ) -> None: ... def parse( source: _FileSource, parser: _AnyParser = ..., - base_url: _AnyStr = ..., + base_url: _StrOrBytes = ..., ) -> Union[_ElementTree, Any]: ... @overload def fromstring( - text: _AnyStr, + text: _StrOrBytes, parser: None = ..., *, - base_url: _AnyStr = ..., + base_url: _StrOrBytes = ..., ) -> _Element: ... @overload def fromstring( - text: _AnyStr, + text: _StrOrBytes, parser: _AnyParser = ..., *, - base_url: _AnyStr = ..., + base_url: _StrOrBytes = ..., ) -> Union[_Element, Any]: ... @overload def tostring( @@ -533,8 +553,9 @@ def tostring( standalone: bool = ..., doctype: str = ..., exclusive: bool = ..., - with_comments: bool = ..., inclusive_ns_prefixes: Any = ..., + with_comments: bool = ..., + skip_text: bool = ..., ) -> str: ... @overload def tostring( @@ -548,8 +569,9 @@ def tostring( standalone: bool = ..., doctype: str = ..., exclusive: bool = ..., - with_comments: bool = ..., inclusive_ns_prefixes: Any = ..., + with_comments: bool = ..., + skip_text: bool = ..., ) -> bytes: ... @overload def tostring( @@ -562,9 +584,10 @@ def tostring( standalone: bool = ..., doctype: str = ..., exclusive: bool = ..., - with_comments: bool = ..., inclusive_ns_prefixes: Any = ..., -) -> _AnyStr: ... + with_comments: bool = ..., + skip_text: bool = ..., +) -> _StrOrBytes: ... class _ErrorLog: ... class Error(Exception): ... @@ -596,7 +619,7 @@ class _XPathEvaluatorBase: ... class XPath(_XPathEvaluatorBase): def __init__( self, - path: _AnyStr, + path: _StrOrBytes, *, namespaces: Optional[_NonDefaultNSMapArg] = ..., extensions: Any = ..., @@ -611,7 +634,7 @@ class XPath(_XPathEvaluatorBase): class ETXPath(XPath): def __init__( self, - path: _AnyStr, + path: _StrOrBytes, *, extensions: Any = ..., regexp: bool = ..., @@ -628,8 +651,8 @@ class XPathElementEvaluator(_XPathEvaluatorBase): regexp: bool = ..., smart_strings: bool = ..., ) -> None: ... - def __call__(self, _path: _AnyStr, **_variables: _XPathObject) -> _XPathObject: ... - def register_namespace(self, prefix: _AnyStr, uri: _AnyStr) -> None: ... + def __call__(self, _path: _StrOrBytes, **_variables: _XPathObject) -> _XPathObject: ... + def register_namespace(self, prefix: _StrOrBytes, uri: _StrOrBytes) -> None: ... def register_namespaces( self, namespaces: Optional[_NonDefaultNSMapArg] ) -> None: ... @@ -670,9 +693,9 @@ def XPathEvaluator( smart_strings: bool = ..., ) -> Union[XPathElementEvaluator, XPathDocumentEvaluator]: ... -_ElementFactory = Callable[[Any, Dict[_AnyStr, _AnyStr]], _Element] -_CommentFactory = Callable[[_AnyStr], _Comment] -_ProcessingInstructionFactory = Callable[[_AnyStr, _AnyStr], _ProcessingInstruction] +_ElementFactory = Callable[[_StrOrBytes, Dict[_StrOrBytes, _StrOrBytes]], _Element] +_CommentFactory = Callable[[_StrOrBytes], _Comment] +_ProcessingInstructionFactory = Callable[[_StrOrBytes, _StrOrBytes], _ProcessingInstruction] class TreeBuilder: def __init__( @@ -685,10 +708,10 @@ class TreeBuilder: insert_pis: bool = ..., ) -> None: ... def close(self) -> _Element: ... - def comment(self, text: _AnyStr) -> None: ... - def data(self, data: _AnyStr) -> None: ... - def end(self, tag: _AnyStr) -> None: ... - def pi(self, target: _AnyStr, data: Optional[_AnyStr] = ...) -> Any: ... - def start(self, tag: _AnyStr, attrib: Dict[_AnyStr, _AnyStr]) -> None: ... + def comment(self, text: _StrOrBytes) -> None: ... + def data(self, data: _StrOrBytes) -> None: ... + def end(self, tag: _StrOrBytes) -> None: ... + def pi(self, target: _StrOrBytes, data: Optional[_StrOrBytes] = ...) -> Any: ... + def start(self, tag: _StrOrBytes, attrib: Dict[_StrOrBytes, _StrOrBytes]) -> None: ... def iselement(element: Any) -> TypeGuard[_Element]: ... diff --git a/lxml-stubs/html/__init__.pyi b/lxml-stubs/html/__init__.pyi index 65da1ea..d2a76bb 100644 --- a/lxml-stubs/html/__init__.pyi +++ b/lxml-stubs/html/__init__.pyi @@ -16,7 +16,7 @@ from typing_extensions import Literal if TYPE_CHECKING: from ..etree import HTMLParser as _HTMLParser from ..etree import XMLParser as _XMLParser - from ..etree import _AnySmartStr, _AnyStr, _BaseParser, _Element + from ..etree import _AnySmartStr, _BaseParser, _Element, _StrOrBytes _HANDLE_FALURES = Literal["ignore", "discard", None] @@ -70,16 +70,16 @@ class XHTMLParser(_XMLParser): pass def document_fromstring( - html: "_AnyStr", parser: "_BaseParser" = ..., ensure_head_body: bool = ..., **kw + html: _StrOrBytes, parser: "_BaseParser" = ..., ensure_head_body: bool = ..., **kw ) -> "_Element": ... def fragments_fromstring( - html: "_AnyStr", + html: _StrOrBytes, no_leading_text: bool = ..., base_url: str = ..., parser: "_BaseParser" = ..., **kw ) -> "_Element": ... def fromstring( - html: "_AnyStr", base_url: str = ..., parser: "_BaseParser" = ..., **kw + html: _StrOrBytes, base_url: str = ..., parser: "_BaseParser" = ..., **kw ) -> "_Element": ... def __getattr__(name: str) -> Any: ... # incomplete