From 017ddcd6bcf0fb838a9e5b3fc427ab7a746ddce6 Mon Sep 17 00:00:00 2001 From: yasin Date: Thu, 13 Jun 2024 15:19:43 +0200 Subject: [PATCH] Add custom coverage test --- discord/utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/discord/utils.py b/discord/utils.py index 99c7cfc94233..48960805ea9e 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -71,6 +71,7 @@ import typing import warnings import logging +from cov import test, mark import yarl @@ -480,7 +481,7 @@ def find(predicate: Callable[[T], Any], iterable: _Iter[T], /) -> Union[Optional else _find(predicate, iterable) # type: ignore ) - +@test(3) def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: # global -> local _all = all @@ -488,6 +489,7 @@ def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: # Special case the single element call if len(attrs) == 1: + mark(0) k, v = attrs.popitem() pred = attrget(k.replace('__', '.')) return next((elem for elem in iterable if pred(elem) == v), None) @@ -495,10 +497,12 @@ def _get(iterable: Iterable[T], /, **attrs: Any) -> Optional[T]: converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] for elem in iterable: if _all(pred(elem) == value for pred, value in converted): + mark(1) return elem + mark(2) return None - +@test(4) async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]: # global -> local _all = all @@ -510,14 +514,18 @@ async def _aget(iterable: AsyncIterable[T], /, **attrs: Any) -> Optional[T]: pred = attrget(k.replace('__', '.')) async for elem in iterable: if pred(elem) == v: + mark(0) return elem + mark(1) return None converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] async for elem in iterable: if _all(pred(elem) == value for pred, value in converted): + mark(2) return elem + mark(3) return None