From daf142e9ea6611fe9b1c53948dcf302e11d7a90c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 21 Jul 2024 17:37:27 +0100 Subject: [PATCH] refactor: Rename and move `OneOrSeq` Planned in https://github.com/vega/altair/pull/3427#discussion_r1683242081 Will allow for more reuse --- altair/vegalite/v5/api.py | 23 ++--------------------- altair/vegalite/v5/schema/_typing.py | 21 +++++++++++++++++++-- tools/generate_schema_wrapper.py | 24 +++++++++++++++++++++++- tools/schemapi/utils.py | 17 ++++++++++------- 4 files changed, 54 insertions(+), 31 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 2320f2b94..667446704 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -14,7 +14,6 @@ Union, TYPE_CHECKING, TypeVar, - Sequence, Protocol, ) from typing_extensions import TypeAlias @@ -45,10 +44,6 @@ from typing import TypedDict else: from typing_extensions import TypedDict -if sys.version_info >= (3, 12): - from typing import TypeAliasType -else: - from typing_extensions import TypeAliasType if TYPE_CHECKING: from ...utils.core import DataFrameLike @@ -125,26 +120,12 @@ AggregateOp_T, MultiTimeUnit_T, SingleTimeUnit_T, + OneOrSeq, ) ChartDataType: TypeAlias = Optional[Union[DataType, core.Data, str, core.Generator]] _TSchemaBase = TypeVar("_TSchemaBase", bound=core.SchemaBase) -_T = TypeVar("_T") -_OneOrSeq = TypeAliasType("_OneOrSeq", Union[_T, Sequence[_T]], type_params=(_T,)) -"""One of ``_T`` specified type(s), or a `Sequence` of such. - -Examples --------- -The parameters ``short``, ``long`` accept the same range of types:: - - # ruff: noqa: UP006, UP007 - - def func( - short: _OneOrSeq[str | bool | float], - long: Union[str, bool, float, Sequence[Union[str, bool, float]], - ): ... -""" # ------------------------------------------------------------------------ @@ -571,7 +552,7 @@ class _ConditionExtra(TypedDict, closed=True, total=False): # type: ignore[call param: Parameter | str test: _TestPredicateType value: Any - __extra_items__: _StatementType | _OneOrSeq[_LiteralValue] + __extra_items__: _StatementType | OneOrSeq[_LiteralValue] _Condition: TypeAlias = _ConditionExtra diff --git a/altair/vegalite/v5/schema/_typing.py b/altair/vegalite/v5/schema/_typing.py index fa7a2d8dc..80467c45b 100644 --- a/altair/vegalite/v5/schema/_typing.py +++ b/altair/vegalite/v5/schema/_typing.py @@ -4,9 +4,9 @@ from __future__ import annotations -from typing import Any, Literal, Mapping +from typing import Any, Literal, Mapping, Sequence, TypeVar, Union -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypeAliasType __all__ = [ "AggregateOp_T", @@ -32,6 +32,7 @@ "Mark_T", "MultiTimeUnit_T", "NonArgAggregateOp_T", + "OneOrSeq", "Orient_T", "Orientation_T", "ProjectionType_T", @@ -60,6 +61,22 @@ ] +T = TypeVar("T") +OneOrSeq = TypeAliasType("OneOrSeq", Union[T, Sequence[T]], type_params=(T,)) +"""One of ``T`` specified type(s), or a `Sequence` of such. + +Examples +-------- +The parameters ``short``, ``long`` accept the same range of types:: + + # ruff: noqa: UP006, UP007 + + def func( + short: OneOrSeq[str | bool | float], + long: Union[str, bool, float, Sequence[Union[str, bool, float]], + ): ... +""" + Map: TypeAlias = Mapping[str, Any] AggregateOp_T: TypeAlias = Literal[ "argmax", diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 9244c3261..545cc4083 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -232,6 +232,26 @@ def encode({encode_method_args}) -> Self: return copy ''' +# NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar` +# Revisit if this starts to become more common +TYPING_EXTRA: Final = ''' +T = TypeVar("T") +OneOrSeq = TypeAliasType("OneOrSeq", Union[T, Sequence[T]], type_params=(T,)) +"""One of ``T`` specified type(s), or a `Sequence` of such. + +Examples +-------- +The parameters ``short``, ``long`` accept the same range of types:: + + # ruff: noqa: UP006, UP007 + + def func( + short: OneOrSeq[str | bool | float], + long: Union[str, bool, float, Sequence[Union[str, bool, float]], + ): ... +""" +''' + class SchemaGenerator(codegen.SchemaGenerator): schema_class_template = textwrap.dedent( @@ -815,7 +835,9 @@ def vegalite_main(skip_download: bool = False) -> None: ) print(msg) TypeAliasTracer.update_aliases(("Map", "Mapping[str, Any]")) - TypeAliasTracer.write_module(fp_typing, header=HEADER) + TypeAliasTracer.write_module( + fp_typing, "OneOrSeq", header=HEADER, extra=TYPING_EXTRA + ) # Write the pre-generated modules for fp, contents in files.items(): print(f"Writing\n {schemafile!s}\n ->{fp!s}") diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index 44c4c4573..e504cf299 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -71,8 +71,8 @@ def __init__( self._aliases: dict[str, str] = {} self._imports: Sequence[str] = ( "from __future__ import annotations\n", - "from typing import Literal, Mapping, Any", - "from typing_extensions import TypeAlias", + "from typing import Any, Literal, Mapping, TypeVar, Sequence, Union", + "from typing_extensions import TypeAlias, TypeAliasType", ) self._cmd_check: list[str] = ["--fix"] self._cmd_format: Sequence[str] = ruff_format or () @@ -141,7 +141,7 @@ def is_cached(self, tp: str, /) -> bool: return tp in self._literals_invert or tp in self._literals def write_module( - self, fp: Path, *extra_imports: str, header: LiteralString + self, fp: Path, *extra_all: str, header: LiteralString, extra: LiteralString ) -> None: """Write all collected `TypeAlias`'s to `fp`. @@ -149,20 +149,23 @@ def write_module( ---------- fp Path to new module. - *extra_imports - Follows `self._imports` block. + *extra_all + Any manually spelled types to be exported. header `tools.generate_schema_wrapper.HEADER`. + extra + `tools.generate_schema_wrapper.TYPING_EXTRA`. """ ruff_format = ["ruff", "format", fp] if self._cmd_format: ruff_format.extend(self._cmd_format) commands = (["ruff", "check", fp, *self._cmd_check], ruff_format) - static = (header, "\n", *self._imports, *extra_imports, "\n\n") + static = (header, "\n", *self._imports, "\n\n") self.update_aliases(*sorted(self._literals.items(), key=itemgetter(0))) + all_ = [*iter(self._aliases), *extra_all] it = chain( static, - [f"__all__ = {list(self._aliases)}", "\n\n"], + [f"__all__ = {all_}", "\n\n", extra], self.generate_aliases(), ) fp.write_text("\n".join(it), encoding="utf-8")