From 01ca741ab4c338fa1be11553bd4eb3a1fa1ba6d4 Mon Sep 17 00:00:00 2001 From: no Date: Thu, 17 Oct 2024 16:41:58 -0500 Subject: [PATCH 1/4] Improve typing for BS4 element.Tag's `get` and `get_attribute_list`. Currently, - `get()` will always have `None` as a possible return type, even if the default is set. - `get()`'s typing is wrong if the default's type is not `str` or `list[str]`. - `get_attribute_list()` will never have `None` in the possible types inside the list it returns, even if the default allows for that. - `get_attribute_list()`'s typing is wrong if the default's type is not `str` or `list[str]`. This change improves the type handling for these methods by using the same idea as `stdlib/builtins.pyi` uses for `dict.get()`. This is what `dict.get()`'s typing looks like: ```python @overload # type: ignore[override] def get(self, key: _KT) -> _VT | None: ... @overload def get(self, key: _KT, default: _VT) -> _VT: ... @overload def get(self, key: _KT, default: _T) -> _VT | _T: ... ``` Since Tag.get takes a `str` key and returns a `str` or `list[str]` if the attribute exists, we can treat it like a `dict[str, str | list[str]]`: ```python @overload # type: ignore[override] def get(self, key: str) -> str | list | None: ... @overload def get(self, key: str, default: str | list) -> str | list: ... @overload def get(self, key: str, default: _T) -> str | list | _T: ... ``` Then, because `str | list | str | list` is the same type as `str | list`, the `default: str | list` overload is already covered by the generic `_T` overload, so we can simplify: ```python @overload # type: ignore[override] def get(self, key: str) -> str | list | None: ... @overload def get(self, key: str, default: _T) -> str | list | _T: ... ``` We can also do something similar for `get_attribute_list()`: ```python @overload def get_attribute_list(self, key: str) -> list[str | None]: ... @overload def get_attribute_list(self, key: str, default: _T) -> list[str | _T]: ... 
``` Except this isn't quite right -- if default is a list, the implementation returns it instead of a list[list], so we need to unwrap it: ```python @overload def get_attribute_list(self, key: str) -> list[str | None]: ... @overload def get_attribute_list(self, key: str, default: list[_T]) -> list[str | _T]: ... @overload def get_attribute_list(self, key: str, default: _T) -> list[str | _T]: ... ``` --- stubs/beautifulsoup4/bs4/element.pyi | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/stubs/beautifulsoup4/bs4/element.pyi b/stubs/beautifulsoup4/bs4/element.pyi index 6e9cf6d5f9ab..382068c78604 100644 --- a/stubs/beautifulsoup4/bs4/element.pyi +++ b/stubs/beautifulsoup4/bs4/element.pyi @@ -276,8 +276,16 @@ class Tag(PageElement): def clear(self, decompose: bool = False) -> None: ... def smooth(self) -> None: ... def index(self, element: PageElement) -> int: ... - def get(self, key: str, default: str | list[str] | None = None) -> str | list[str] | None: ... - def get_attribute_list(self, key: str, default: str | list[str] | None = None) -> list[str]: ... + @overload + def get(self, key: str) -> str | list[str] | None: ... + @overload + def get(self, key: str, default: _T) -> str | list[str] | _T: ... + @overload + def get_attribute_list(self, key: str) -> list[str | None]: ... + @overload + def get_attribute_list(self, key: str, default: list[_T]) -> list[str | _T]: ... + @overload + def get_attribute_list(self, key: str, default: _T) -> list[str | _T]: ... def has_attr(self, key: str) -> bool: ... def __hash__(self) -> int: ... def __getitem__(self, key: str) -> str | list[str]: ... 
From 76ff2748f2adafb5710b4cd0d413bd881fd6c788 Mon Sep 17 00:00:00 2001 From: no Date: Thu, 17 Oct 2024 21:20:27 -0500 Subject: [PATCH 2/4] Add the missing _T declaration --- stubs/beautifulsoup4/bs4/element.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/stubs/beautifulsoup4/bs4/element.pyi b/stubs/beautifulsoup4/bs4/element.pyi index 382068c78604..0269ced9a3cb 100644 --- a/stubs/beautifulsoup4/bs4/element.pyi +++ b/stubs/beautifulsoup4/bs4/element.pyi @@ -27,6 +27,7 @@ class ContentMetaAttributeValue(AttributeValueWithCharsetSubstitution): def __new__(cls, original_value): ... def encode(self, encoding: str) -> str: ... # type: ignore[override] # incompatible with str +_T = TypeVar("_T") _PageElementT = TypeVar("_PageElementT", bound=PageElement) _SimpleStrainable: TypeAlias = str | bool | None | bytes | Pattern[str] | Callable[[str], bool] | Callable[[Tag], bool] _Strainable: TypeAlias = _SimpleStrainable | Iterable[_SimpleStrainable] From b256ccce7961eef3d39ff1cd67484dbea9437328 Mon Sep 17 00:00:00 2001 From: no Date: Thu, 17 Oct 2024 21:23:22 -0500 Subject: [PATCH 3/4] Add overload for default: None. --- stubs/beautifulsoup4/bs4/element.pyi | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stubs/beautifulsoup4/bs4/element.pyi b/stubs/beautifulsoup4/bs4/element.pyi index 0269ced9a3cb..c18c68636d15 100644 --- a/stubs/beautifulsoup4/bs4/element.pyi +++ b/stubs/beautifulsoup4/bs4/element.pyi @@ -280,10 +280,14 @@ class Tag(PageElement): @overload def get(self, key: str) -> str | list[str] | None: ... @overload + def get(self, key: str, default: None) -> str | list[str] | None: ... + @overload def get(self, key: str, default: _T) -> str | list[str] | _T: ... @overload def get_attribute_list(self, key: str) -> list[str | None]: ... @overload + def get_attribute_list(self, key: str, default: None) -> list[str | None]: ... + @overload def get_attribute_list(self, key: str, default: list[_T]) -> list[str | _T]: ... 
@overload def get_attribute_list(self, key: str, default: _T) -> list[str | _T]: ... From 835fe9b5a0480af698ae4282843f2065e1e12abf Mon Sep 17 00:00:00 2001 From: Kevin Mustelier Date: Sat, 19 Oct 2024 00:11:28 -0500 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Sebastian Rittau --- stubs/beautifulsoup4/bs4/element.pyi | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/stubs/beautifulsoup4/bs4/element.pyi b/stubs/beautifulsoup4/bs4/element.pyi index c18c68636d15..cf5d39a4d4ef 100644 --- a/stubs/beautifulsoup4/bs4/element.pyi +++ b/stubs/beautifulsoup4/bs4/element.pyi @@ -278,15 +278,11 @@ class Tag(PageElement): def smooth(self) -> None: ... def index(self, element: PageElement) -> int: ... @overload - def get(self, key: str) -> str | list[str] | None: ... - @overload - def get(self, key: str, default: None) -> str | list[str] | None: ... + def get(self, key: str, default: None = None) -> str | list[str] | None: ... @overload def get(self, key: str, default: _T) -> str | list[str] | _T: ... @overload - def get_attribute_list(self, key: str) -> list[str | None]: ... - @overload - def get_attribute_list(self, key: str, default: None) -> list[str | None]: ... + def get_attribute_list(self, key: str, default: None = None) -> list[str | None]: ... @overload def get_attribute_list(self, key: str, default: list[_T]) -> list[str | _T]: ... @overload