diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index e7307078a..1dcfd7d45 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -12,7 +12,7 @@ from itertools import chain from operator import attrgetter from pathlib import Path -from typing import TYPE_CHECKING, Any, Final, Literal +from typing import TYPE_CHECKING, Any, Final, Generic, Literal, TypedDict, TypeVar from urllib import request import vl_convert as vlc @@ -317,6 +317,18 @@ def encode(self, *args: Any, {method_args}) -> Self: return copy ''' +# Enables use of ~, &, | with compositions of selection objects. +DUNDER_PREDICATE_COMPOSITION = """ + def __invert__(self) -> PredicateComposition: + return PredicateComposition({"not": self.to_dict()}) + + def __and__(self, other: SchemaBase) -> PredicateComposition: + return PredicateComposition({"and": [self.to_dict(), other.to_dict()]}) + + def __or__(self, other: SchemaBase) -> PredicateComposition: + return PredicateComposition({"or": [self.to_dict(), other.to_dict()]}) +""" + # NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar` # Revisit if this starts to become more common @@ -431,6 +443,37 @@ class {classname}({basename}): ) +class MethodSchemaGenerator(SchemaGenerator): + """Base template w/ an extra slot `{method_code}` after `{init_code}`.""" + + schema_class_template = textwrap.dedent( + ''' + class {classname}({basename}): + """{docstring}""" + _schema = {schema!r} + + {init_code} + + {method_code} + ''' + ) + + +SchGen = TypeVar("SchGen", bound=SchemaGenerator) + + +class OverridesItem(TypedDict, Generic[SchGen]): + tp: type[SchGen] + kwds: dict[str, Any] + + +CORE_OVERRIDES: dict[str, OverridesItem[SchemaGenerator]] = { + "PredicateComposition": OverridesItem( + tp=MethodSchemaGenerator, kwds={"method_code": DUNDER_PREDICATE_COMPOSITION} + ) +} + + class FieldSchemaGenerator(SchemaGenerator): schema_class_template = textwrap.dedent( ''' @@ -656,13 +699,20 @@ def generate_vegalite_schema_wrapper(fp: Path, /) -> str: defschema = {"$ref": "#/definitions/" + name} defschema_repr = {"$ref": "#/definitions/" + name} name = get_valid_identifier(name) - definitions[name] = SchemaGenerator( + if overrides := CORE_OVERRIDES.get(name): + tp = overrides["tp"] + kwds = overrides["kwds"] + else: + tp = SchemaGenerator + kwds = {} + definitions[name] = tp( name, schema=defschema, schemarepr=defschema_repr, rootschema=rootschema, basename=basename, rootschemarepr=CodeSnippet(f"{basename}._rootschema"), + **kwds, ) for name, schema in definitions.items(): graph[name] = [] diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index 47d96dcd7..37d512087 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -260,16 +260,20 @@ def schema_class(self) -> str: basename = self.basename else: basename = ", ".join(self.basename) + docstring = self.docstring(indent=4) + init_code = self.init_code(indent=4) + if type(self).haspropsetters: + method_code = self.overload_code(indent=4) + else: + method_code = self.kwargs.pop("method_code", None) return self.schema_class_template.format( classname=self.classname, basename=basename, schema=schemarepr, rootschema=rootschemarepr, - docstring=self.docstring(indent=4), - init_code=self.init_code(indent=4), - method_code=( - self.overload_code(indent=4) if type(self).haspropsetters else None - ), + docstring=docstring, + init_code=init_code, + method_code=method_code, **self.kwargs, )