From 6ae0fdada1f19fc9c4ac4986b021348681ec44cd Mon Sep 17 00:00:00 2001 From: beauxq Date: Wed, 28 Feb 2024 18:26:39 -0800 Subject: [PATCH] allow any iterable of strings --- Options.py | 8 ++++---- Utils.py | 11 +++++++++++ requirements.txt | 1 + 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/Options.py b/Options.py index 7409bb3172d..cb7fa5749bc 100644 --- a/Options.py +++ b/Options.py @@ -12,7 +12,7 @@ from schema import And, Optional, Or, Schema -from Utils import get_fuzzy_results +from Utils import get_fuzzy_results, is_iterable_of_str if typing.TYPE_CHECKING: from BaseClasses import PlandoOptions @@ -766,7 +766,7 @@ class VerifyKeys(metaclass=FreezeValidKeys): value: typing.Any @classmethod - def verify_keys(cls, data: typing.List[str]): + def verify_keys(cls, data: typing.Iterable[str]): if cls.valid_keys: data = set(data) dataset = set(word.casefold() for word in data) if cls.valid_keys_casefold else set(data) @@ -857,7 +857,7 @@ def from_text(cls, text: str): @classmethod def from_any(cls, data: typing.Any): - if isinstance(data, (list, set, frozenset, tuple)): + if is_iterable_of_str(data): cls.verify_keys(data) return cls(data) return cls.from_text(str(data)) @@ -883,7 +883,7 @@ def from_text(cls, text: str): @classmethod def from_any(cls, data: typing.Any): - if isinstance(data, (list, set, frozenset, tuple)): + if is_iterable_of_str(data): cls.verify_keys(data) return cls(data) return cls.from_text(str(data)) diff --git a/Utils.py b/Utils.py index da2d837ad3a..cea6405a38b 100644 --- a/Utils.py +++ b/Utils.py @@ -19,6 +19,7 @@ from argparse import Namespace from settings import Settings, get_settings from typing import BinaryIO, Coroutine, Optional, Set, Dict, Any, Union +from typing_extensions import TypeGuard from yaml import load, load_all, dump try: @@ -966,3 +967,13 @@ def __bool__(self): def __len__(self): return sum(len(iterable) for iterable in self.iterable) + + +def is_iterable_of_str(obj: object) -> TypeGuard[typing.Iterable[str]]: + """ but not a `str` (because technically, `str` is `Iterable[str]`) """ + if isinstance(obj, str): + return False + if not isinstance(obj, typing.Iterable): + return False + obj_it: typing.Iterable[object] = obj + return all(isinstance(v, str) for v in obj_it) diff --git a/requirements.txt b/requirements.txt index e2ccb67c18d..9531e3058e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ certifi>=2023.11.17 cython>=3.0.8 cymem>=2.0.8 orjson>=3.9.10 +typing-extensions>=4.7.0