diff --git a/tools/__pycache__/generated_classes.cpython-311.pyc b/tools/__pycache__/generated_classes.cpython-311.pyc new file mode 100644 index 00000000..1ffff5c6 Binary files /dev/null and b/tools/__pycache__/generated_classes.cpython-311.pyc differ diff --git a/tools/example-api-usage.py b/tools/example-api-usage.py new file mode 100644 index 00000000..dab9a982 --- /dev/null +++ b/tools/example-api-usage.py @@ -0,0 +1,3 @@ +import python_api + + diff --git a/tools/generate_mosaic_schema_wrapper.py b/tools/generate_mosaic_schema_wrapper.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py new file mode 100644 index 00000000..e51f4b92 --- /dev/null +++ b/tools/generate_schema_wrapper.py @@ -0,0 +1,114 @@ +import json +from typing import Any, Dict, List, Union, Final, Iterable, Iterator, Literal +import sys +import yaml +import argparse +import copy +import re +import textwrap +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from urllib import request + +sys.path.insert(0, str(Path.cwd())) +from tools.schemapi import CodeSnippet, SchemaInfo, codegen +from tools.schemapi.utils import ( + TypeAliasTracer, + get_valid_identifier, + indent_docstring, + resolve_references, + rst_parse, + rst_syntax_for_class, + ruff_format_py, + ruff_write_lint_format_str, + spell_literal, +) + +def generate_class(class_name: str, class_schema: Dict[str, Any]) -> str: + + imports = "from typing import Any, Union\n" + + if 'type' in class_schema and 'properties' not in class_schema: + return f"class {class_name}:\n def __init__(self):\n pass\n" + + if '$ref' in class_schema: + ref_class_name = class_schema['$ref'].split('/')[-1] + return f"{imports}\nclass {class_name}:\n pass # This is a reference to {ref_class_name}\n" + + if 'anyOf' in class_schema: + return generate_any_of_class(class_name, class_schema['anyOf']) + + properties = class_schema.get('properties', {}) + required = class_schema.get('required', []) + + class_def = f"{imports}class {class_name}:\n" + class_def += " def __init__(self" + + for prop, prop_schema in properties.items(): + type_hint = get_type_hint(prop_schema) + if prop in required: + class_def += f", {prop}: {type_hint}" + else: + class_def += f", {prop}: {type_hint} = None" + + class_def += "):\n" + + for prop in properties: + class_def += f" self.{prop} = {prop}\n" + + return class_def + + +def generate_any_of_class(class_name: str, any_of_schemas: List[Dict[str, Any]]) -> str: + types = [get_type_hint(schema) for schema in any_of_schemas] + type_union = "Union[" + ", ".join(f'"{t}"' for t in types) + "]" + + class_def = f"class {class_name}:\n" + class_def += f" def __init__(self, value: {type_union}):\n" + class_def += " self.value = value\n" + + return class_def + + + +def get_type_hint(prop_schema: Dict[str, Any]) -> str: + """Get type hint for a property schema.""" + if 'type' in prop_schema: + if prop_schema['type'] == 'string': + return 'str' + elif prop_schema['type'] == 'boolean': + return 'bool' + elif prop_schema['type'] == 'object': + return 'Dict[str, Any]' + elif 'anyOf' in prop_schema: + types = [get_type_hint(option) for option in prop_schema['anyOf']] + return f'Union[{", ".join(types)}]' + elif '$ref' in prop_schema: + return prop_schema['$ref'].split('/')[-1] + return 'Any' + +def load_schema(schema_path: Path) -> dict: + """Load a JSON schema from the specified path.""" + with schema_path.open(encoding="utf8") as f: + return json.load(f) + +def 
generate_schema_wrapper(schema_file: Path, output_file: Path) -> None:
+    """Generate a schema wrapper for the given schema file."""
+    rootschema = load_schema(schema_file)
+
+    definitions: Dict[str, str] = {}
+
+    for name, schema in rootschema.get("definitions", {}).items():
+        class_code = generate_class(name, schema)
+        definitions[name] = class_code
+
+    generated_classes = "\n\n".join(definitions.values())
+
+    with open(output_file, 'w') as f:
+        f.write(generated_classes)
+
+if __name__ == "__main__":
+    schema_file = "tools/testingSchema.json"
+    output_file = Path("tools/generated_classes.py")
+    generate_schema_wrapper(Path(schema_file), output_file)
diff --git a/tools/generate_schema_wrapper_commented.py b/tools/generate_schema_wrapper_commented.py
new file mode 100644
index 00000000..7328d880
--- /dev/null
+++ b/tools/generate_schema_wrapper_commented.py
@@ -0,0 +1,1038 @@
+from __future__ import annotations
+"""Generate a schema wrapper from a schema."""
+
+"""
+(X) The file is organized into several key sections:
+
+1. **Constants**:
+   - This section defines constants that are used throughout the module for configuration and encoding methods.
+
+2. **Schema Generation Functions**:
+   - `generate_vegalite_schema_wrapper`: This function generates a schema wrapper for Vega-Lite based on the provided schema file.
+   - `load_schema_with_shorthand_properties`: Loads the schema and incorporates shorthand properties for easier usage.
+   - `_add_shorthand_property_to_field_encodings`: Adds shorthand properties to field encodings within the schema.
+
+3. **Utility Functions**:
+   - `copy_schemapi_util`: Copies the schemapi utility into the altair/utils directory for reuse.
+   - `recursive_dict_update`: Recursively updates a dictionary schema with new definitions, ensuring that references are resolved.
+   - `get_field_datum_value_defs`: Retrieves definitions for fields, datum, and values from a given property schema.
+   - `toposort`: Performs a topological sort on a directed acyclic graph, which is useful for managing dependencies between schema definitions.
+
+4. **Channel Wrapper Generation**:
+   - `generate_vegalite_channel_wrappers`: Generates channel wrappers for the Vega-Lite schema, allowing for the mapping of data properties to visual properties.
+
+5. **Mixin Generation**:
+   - `generate_vegalite_mark_mixin`: Creates a mixin class that defines methods for different types of marks in Vega-Lite.
+   - `generate_vegalite_config_mixin`: Generates a mixin class that provides configuration methods for the schema.
+
+6. **Main Execution Function**:
+   - `vegalite_main`: The main function that orchestrates the schema generation process, handling the loading of schemas and the creation of wrapper files.
+
+7. **Encoding Artifacts Generation**:
+   - `generate_encoding_artifacts`: Generates artifacts related to encoding, including type aliases and mixin classes for encoding methods.
+
+8. **Main Entry Point**:
+   - `main`: The entry point for the script, which processes command-line arguments and initiates the schema generation workflow.
+""" + +import yaml +import argparse +import copy +import json +import re +import sys +import textwrap +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from typing import Final, Iterable, Iterator, Literal +from urllib import request + +import vl_convert as vlc + +sys.path.insert(0, str(Path.cwd())) +###(H) SchemaInfo class imported from altair/tools/schemapi/utils.py +### It's a wrapper for inspecting JSON schema +from tools.schemapi import CodeSnippet, SchemaInfo, codegen +from tools.schemapi.utils import ( + TypeAliasTracer, + get_valid_identifier, + indent_docstring, + resolve_references, + rst_parse, + rst_syntax_for_class, + ruff_format_py, + ruff_write_lint_format_str, + spell_literal, +) + +SCHEMA_VERSION: Final = "v5.20.1" + +reLink = re.compile(r"(?<=\[)([^\]]+)(?=\]\([^\)]+\))", re.MULTILINE) +reSpecial = re.compile(r"[*_]{2,3}|`", re.MULTILINE) + +HEADER: Final = """\ +# The contents of this file are automatically written by +# tools/generate_schema_wrapper.py. Do not modify directly. +""" + +SCHEMA_URL_TEMPLATE: Final = "https://vega.github.io/schema/{library}/{version}.json" + +CHANNEL_MYPY_IGNORE_STATEMENTS: Final = """\ +# These errors need to be ignored as they come from the overload methods +# which trigger two kind of errors in mypy: +# * all of them do not have an implementation in this file +# * some of them are the only overload methods -> overloads usually only make +# sense if there are multiple ones +# However, we need these overloads due to how the propertysetter works +# mypy: disable-error-code="no-overload-impl, empty-body, misc" +""" + +BASE_SCHEMA: Final = """ +class {basename}(SchemaBase): + _rootschema = load_schema() + @classmethod + def _default_wrapper_classes(cls) -> Iterator[type[Any]]: + return _subclasses({basename}) +""" + +LOAD_SCHEMA: Final = ''' +def load_schema() -> dict: + """Load the json schema associated with this module's functions""" + schema_bytes = pkgutil.get_data(__name__, "{schemafile}") + if schema_bytes is None: + raise ValueError("Unable to load {schemafile}") + return json.loads( + schema_bytes.decode("utf-8") + ) +''' + + +CHANNEL_MIXINS: Final = """ +class FieldChannelMixin: + _encoding_name: str + def to_dict( + self, + validate: bool = True, + ignore: list[str] | None = None, + context: dict[str, Any] | None = None, + ) -> dict | list[dict]: + context = context or {} + ignore = ignore or [] + shorthand = self._get("shorthand") # type: ignore[attr-defined] + field = self._get("field") # type: ignore[attr-defined] + + if shorthand is not Undefined and field is not Undefined: + msg = f"{self.__class__.__name__} specifies both shorthand={shorthand} and field={field}. 
" + raise ValueError(msg) + + if isinstance(shorthand, (tuple, list)): + # If given a list of shorthands, then transform it to a list of classes + kwds = self._kwds.copy() # type: ignore[attr-defined] + kwds.pop("shorthand") + return [ + self.__class__(sh, **kwds).to_dict( # type: ignore[call-arg] + validate=validate, ignore=ignore, context=context + ) + for sh in shorthand + ] + + if shorthand is Undefined: + parsed = {} + elif isinstance(shorthand, str): + data: nw.DataFrame | Any = context.get("data", None) + parsed = parse_shorthand(shorthand, data=data) + type_required = "type" in self._kwds # type: ignore[attr-defined] + type_in_shorthand = "type" in parsed + type_defined_explicitly = self._get("type") is not Undefined # type: ignore[attr-defined] + if not type_required: + # Secondary field names don't require a type argument in VegaLite 3+. + # We still parse it out of the shorthand, but drop it here. + parsed.pop("type", None) + elif not (type_in_shorthand or type_defined_explicitly): + if isinstance(data, nw.DataFrame): + msg = ( + f'Unable to determine data type for the field "{shorthand}";' + " verify that the field name is not misspelled." + " If you are referencing a field from a transform," + " also confirm that the data type is specified correctly." + ) + raise ValueError(msg) + else: + msg = ( + f"{shorthand} encoding field is specified without a type; " + "the type cannot be automatically inferred because " + "the data is not specified as a pandas.DataFrame." + ) + raise ValueError(msg) + else: + # Shorthand is not a string; we pass the definition to field, + # and do not do any parsing. + parsed = {"field": shorthand} + context["parsed_shorthand"] = parsed + + return super(FieldChannelMixin, self).to_dict( + validate=validate, ignore=ignore, context=context + ) + + +class ValueChannelMixin: + _encoding_name: str + def to_dict( + self, + validate: bool = True, + ignore: list[str] | None = None, + context: dict[str, Any] | None = None, + ) -> dict: + context = context or {} + ignore = ignore or [] + condition = self._get("condition", Undefined) # type: ignore[attr-defined] + copy = self # don't copy unless we need to + if condition is not Undefined: + if isinstance(condition, core.SchemaBase): + pass + elif "field" in condition and "type" not in condition: + kwds = parse_shorthand(condition["field"], context.get("data", None)) + copy = self.copy(deep=["condition"]) # type: ignore[attr-defined] + copy["condition"].update(kwds) # type: ignore[index] + return super(ValueChannelMixin, copy).to_dict( + validate=validate, ignore=ignore, context=context + ) + + +class DatumChannelMixin: + _encoding_name: str + def to_dict( + self, + validate: bool = True, + ignore: list[str] | None = None, + context: dict[str, Any] | None = None, + ) -> dict: + context = context or {} + ignore = ignore or [] + datum = self._get("datum", Undefined) # type: ignore[attr-defined] # noqa + copy = self # don't copy unless we need to + return super(DatumChannelMixin, copy).to_dict( + validate=validate, ignore=ignore, context=context + ) +""" + +MARK_METHOD: Final = ''' +def mark_{mark}({def_arglist}) -> Self: + """Set the chart's mark to '{mark}' (see :class:`{mark_def}`) + """ + kwds = dict({dict_arglist}) + copy = self.copy(deep=False) # type: ignore[attr-defined] + if any(val is not Undefined for val in kwds.values()): + copy.mark = core.{mark_def}(type="{mark}", **kwds) + else: + copy.mark = "{mark}" + return copy +''' + +CONFIG_METHOD: Final = """ +@use_signature(core.{classname}) +def 
{method}(self, *args, **kwargs) -> Self: + copy = self.copy(deep=False) # type: ignore[attr-defined] + copy.config = core.{classname}(*args, **kwargs) + return copy +""" + +CONFIG_PROP_METHOD: Final = """ +@use_signature(core.{classname}) +def configure_{prop}(self, *args, **kwargs) -> Self: + copy = self.copy(deep=['config']) # type: ignore[attr-defined] + if copy.config is Undefined: + copy.config = core.Config() + copy.config["{prop}"] = core.{classname}(*args, **kwargs) + return copy +""" + +ENCODE_METHOD: Final = ''' +class _EncodingMixin: + def encode({method_args}) -> Self: + """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) + {docstring}""" + # Compat prep for `infer_encoding_types` signature + kwargs = locals() + kwargs.pop("self") + args = kwargs.pop("args") + if args: + kwargs = {{k: v for k, v in kwargs.items() if v is not Undefined}} + + # Convert args to kwargs based on their types. + kwargs = _infer_encoding_types(args, kwargs) + # get a copy of the dict representation of the previous encoding + # ignore type as copy method comes from SchemaBase + copy = self.copy(deep=['encoding']) # type: ignore[attr-defined] + encoding = copy._get('encoding', {{}}) + if isinstance(encoding, core.VegaLiteSchema): + encoding = {{k: v for k, v in encoding._kwds.items() if v is not Undefined}} + # update with the new encodings, and apply them to the copy + encoding.update(kwargs) + copy.encoding = core.FacetedEncoding(**encoding) + return copy +''' + +ENCODE_TYPED_DICT: Final = ''' +class EncodeKwds(TypedDict, total=False): + """Encoding channels map properties of the data to visual properties of the chart. + {docstring}""" + {channels} + +''' + +# NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar` +# Revisit if this starts to become more common +TYPING_EXTRA: Final = ''' +T = TypeVar("T") +OneOrSeq = TypeAliasType("OneOrSeq", Union[T, Sequence[T]], type_params=(T,)) +"""One of ``T`` specified type(s), or a `Sequence` of such. + +Examples +-------- +The parameters ``short``, ``long`` accept the same range of types:: + + # ruff: noqa: UP006, UP007 + + def func( + short: OneOrSeq[str | bool | float], + long: Union[str, bool, float, Sequence[Union[str, bool, float]], + ): ... 
+""" +''' + + +class SchemaGenerator(codegen.SchemaGenerator): + schema_class_template = textwrap.dedent( + ''' + class {classname}({basename}): + """{docstring}""" + _schema = {schema!r} + + {init_code} + ''' + ) + + @staticmethod + def _process_description(description: str) -> str: + return process_description(description) + + +def process_description(description: str) -> str: + # remove formatting from links + description = "".join( + [ + reSpecial.sub("", d) if i % 2 else d + for i, d in enumerate(reLink.split(description)) + ] + ) + description = rst_parse(description) + # Some entries in the Vega-Lite schema miss the second occurence of '__' + description = description.replace("__Default value: ", "__Default value:__ ") + # Fixing ambiguous unicode, RUF001 produces RUF002 in docs + description = description.replace("’", "'") # noqa: RUF001 [RIGHT SINGLE QUOTATION MARK] + description = description.replace("–", "-") # noqa: RUF001 [EN DASH] + description = description.replace(" ", " ") # noqa: RUF001 [NO-BREAK SPACE] + return description.strip() + + +class FieldSchemaGenerator(SchemaGenerator): + schema_class_template = textwrap.dedent( + ''' + @with_property_setters + class {classname}(FieldChannelMixin, core.{basename}): + """{docstring}""" + _class_is_valid_at_instantiation = False + _encoding_name = "{encodingname}" + + {method_code} + + {init_code} + ''' + ) + + +class ValueSchemaGenerator(SchemaGenerator): + schema_class_template = textwrap.dedent( + ''' + @with_property_setters + class {classname}(ValueChannelMixin, core.{basename}): + """{docstring}""" + _class_is_valid_at_instantiation = False + _encoding_name = "{encodingname}" + + {method_code} + + {init_code} + ''' + ) + + +class DatumSchemaGenerator(SchemaGenerator): + schema_class_template = textwrap.dedent( + ''' + @with_property_setters + class {classname}(DatumChannelMixin, core.{basename}): + """{docstring}""" + _class_is_valid_at_instantiation = False + _encoding_name = "{encodingname}" + + {method_code} + + {init_code} + ''' + ) + + +def schema_class(*args, **kwargs) -> str: + return SchemaGenerator(*args, **kwargs).schema_class() + + +def schema_url(version: str = SCHEMA_VERSION) -> str: + return SCHEMA_URL_TEMPLATE.format(library="vega-lite", version=version) + + +def download_schemafile( + version: str, schemapath: Path, skip_download: bool = False +) -> Path: + url = schema_url(version=version) + schemadir = Path(schemapath) + schemadir.mkdir(parents=True, exist_ok=True) + fp = schemadir / "vega-lite-schema.json" + if not skip_download: + request.urlretrieve(url, fp) + elif not fp.exists(): + msg = f"Cannot skip download: {fp!s} does not exist" + raise ValueError(msg) + return fp + + +def update_vega_themes(fp: Path, /, indent: str | int | None = 2) -> None: + themes = vlc.get_themes() + data = json.dumps(themes, indent=indent, sort_keys=True) + fp.write_text(data, encoding="utf8") + + theme_names = sorted(iter(themes)) + TypeAliasTracer.update_aliases(("VegaThemes", spell_literal(theme_names))) + + +def load_schema_with_shorthand_properties(schemapath: Path) -> dict: + with schemapath.open(encoding="utf8") as f: + schema = json.load(f) + + # At this point, schema is a python Dict + # Not sure what the below function does. 
It uses a lot of JSON logic + schema = _add_shorthand_property_to_field_encodings(schema) + return schema + + +def _add_shorthand_property_to_field_encodings(schema: dict) -> dict: + encoding_def = "FacetedEncoding" + + encoding = SchemaInfo(schema["definitions"][encoding_def], rootschema=schema) + + #print(yaml.dump(schema, default_flow_style=False)) + for _, propschema in encoding.properties.items(): + def_dict = get_field_datum_value_defs(propschema, schema) + + field_ref = def_dict.get("field") + if field_ref is not None: + defschema = {"$ref": field_ref} + defschema = copy.deepcopy(resolve_references(defschema, schema)) + # For Encoding field definitions, we patch the schema by adding the + # shorthand property. + defschema["properties"]["shorthand"] = { + "anyOf": [ + {"type": "string"}, + {"type": "array", "items": {"type": "string"}}, + {"$ref": "#/definitions/RepeatRef"}, + ], + "description": "shorthand for field, aggregate, and type", + } + if "required" not in defschema: + defschema["required"] = ["shorthand"] + elif "shorthand" not in defschema["required"]: + defschema["required"].append("shorthand") + schema["definitions"][field_ref.split("/")[-1]] = defschema + return schema + + +def copy_schemapi_util() -> None: + """Copy the schemapi utility into altair/utils/ and its test file to tests/utils/.""" + # copy the schemapi utility file + source_fp = Path(__file__).parent / "schemapi" / "schemapi.py" + destination_fp = Path(__file__).parent / ".." / "altair" / "utils" / "schemapi.py" + + print(f"Copying\n {source_fp!s}\n -> {destination_fp!s}") + with source_fp.open(encoding="utf8") as source, destination_fp.open( + "w", encoding="utf8" + ) as dest: + dest.write(HEADER) + dest.writelines(source.readlines()) + if sys.platform == "win32": + ruff_format_py(destination_fp) + + +def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None: + if "$ref" in schema: + next_schema = resolve_references(schema, root) + if "properties" in next_schema: + definition = schema["$ref"] + properties = next_schema["properties"] + for k in def_dict: + if k in properties: + def_dict[k] = definition + else: + recursive_dict_update(next_schema, root, def_dict) + elif "anyOf" in schema: + for sub_schema in schema["anyOf"]: + recursive_dict_update(sub_schema, root, def_dict) + + +def get_field_datum_value_defs(propschema: SchemaInfo, root: dict) -> dict[str, str]: + def_dict: dict[str, str | None] = dict.fromkeys(("field", "datum", "value")) + schema = propschema.schema + if propschema.is_reference() and "properties" in schema: + if "field" in schema["properties"]: + def_dict["field"] = propschema.ref + else: + msg = "Unexpected schema structure" + raise ValueError(msg) + else: + recursive_dict_update(schema, root, def_dict) + + return {i: j for i, j in def_dict.items() if j} + + +def toposort(graph: dict[str, list[str]]) -> list[str]: + """ + Topological sort of a directed acyclic graph. + + Parameters + ---------- + graph : dict of lists + Mapping of node labels to list of child node labels. + This is assumed to represent a graph with no cycles. + + Returns + ------- + order : list + topological order of input graph. + """ + # Once we drop support for Python 3.8, this can potentially be replaced + # with graphlib.TopologicalSorter from the standard library. 
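+    # A small worked example (hypothetical graph, not taken from the schema):
+    #     toposort({"A": ["B", "C"], "B": ["C"], "C": []})
+    # is expected to return ["A", "B", "C"], i.e. every node appears before the
+    # children it points to.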
+    stack: list[str] = []
+    visited: dict[str, Literal[True]] = {}
+
+    def visit(nodes):
+        for node in sorted(nodes, reverse=True):
+            if not visited.get(node):
+                visited[node] = True
+                visit(graph.get(node, []))
+                stack.insert(0, node)
+
+    visit(graph)
+    return stack
+
+### (X) Function to generate a schema wrapper for Vega-Lite.
+def generate_vegalite_schema_wrapper(schema_file: Path) -> str:
+    """Generate a schema wrapper at the given path."""
+    # TODO: generate simple tests for each wrapper
+    basename = "VegaLiteSchema"
+
+    # Not sure what the below function does. It uses a lot of JSON logic
+    # I'm thinking of it as just loading the schema
+    rootschema = load_schema_with_shorthand_properties(schema_file)
+
+    definitions: dict[str, SchemaGenerator] = {}
+
+    ### (X) Loop through the definitions in the rootschema and create a SchemaGenerator for each one.
+    # There is a schema generator object for every single lowest level key in the JSON object
+    for name in rootschema["definitions"]:
+        defschema = {"$ref": "#/definitions/" + name}
+        defschema_repr = {"$ref": "#/definitions/" + name}
+        name = get_valid_identifier(name)
+        definitions[name] = SchemaGenerator(
+            name,
+            schema=defschema,
+            schemarepr=defschema_repr,
+            rootschema=rootschema,
+            basename=basename,
+            rootschemarepr=CodeSnippet(f"{basename}._rootschema"),
+        )
+
+    #print(definitions)
+    #print("\n\n\n")
+
+    ### (X) Create a DAG of the definitions.
+    # The DAG consists of each lowest level key corresponding to an array of each in-document $ref
+    # reference in a dictionary
+    graph: dict[str, list[str]] = {}
+
+    for name, schema in definitions.items():
+        graph[name] = []
+        for child_name in schema.subclasses():
+            child_name = get_valid_identifier(child_name)
+            graph[name].append(child_name)
+            child: SchemaGenerator = definitions[child_name]
+            if child.basename == basename:
+                child.basename = [name]
+            else:
+                assert isinstance(child.basename, list)
+                child.basename.append(name)
+
+    #print(graph)
+
+    # Specify __all__ explicitly so that we can exclude the ones from the list
+    # of exported classes which are also defined in the channels or api modules which take
+    # precedence in the generated __init__.py files one and two levels up.
+    # Importing these classes from multiple modules confuses type checkers.
+    EXCLUDE = {"Color", "Text", "LookupData", "Dict", "FacetMapping"}
+    it = (c for c in definitions.keys() - EXCLUDE if not c.startswith("_"))
+    all_ = [*sorted(it), "Root", "VegaLiteSchema", "SchemaBase", "load_schema"]
+
+    contents = [
+        HEADER,
+        "from __future__ import annotations\n"
+        "from typing import Any, Literal, Union, Protocol, Sequence, List, Iterator, TYPE_CHECKING",
+        "import pkgutil",
+        "import json\n",
+        "from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe",
+        "from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses # noqa: F401\n",
+        _type_checking_only_imports(
+            "from altair import Parameter",
+            "from altair.typing import Optional",
+            "from ._typing import * # noqa: F403",
+        ),
+        "\n" f"__all__ = {all_}\n",
+        LOAD_SCHEMA.format(schemafile="vega-lite-schema.json"),
+        BASE_SCHEMA.format(basename=basename),
+        schema_class(
+            "Root",
+            schema=rootschema,
+            basename=basename,
+            schemarepr=CodeSnippet(f"{basename}._rootschema"),
+        ),
+    ]
+
+    ### (X) Append the schema classes in topological order to the contents.
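+    # (X) Ordering matters here: each subclass collected above has its parent added
+    # to `basename`, so a base class must be emitted before any class that inherits from it.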
+ # This sort puts the edges at the start of the reference chain first + for name in toposort(graph): + contents.append(definitions[name].schema_class()) + + contents.append("") # end with newline + return "\n".join(contents) + + +def _type_checking_only_imports(*imports: str) -> str: + return ( + "\n# ruff: noqa: F405\nif TYPE_CHECKING:\n" + + "\n".join(f" {s}" for s in imports) + + "\n" + ) + + +@dataclass +class ChannelInfo: + supports_arrays: bool + deep_description: str + field_class_name: str + datum_class_name: str | None = None + value_class_name: str | None = None + + @property + def is_field_only(self) -> bool: + return not (self.datum_class_name or self.value_class_name) + + @property + def all_names(self) -> Iterator[str]: + """All channels are expected to have a field class.""" + yield self.field_class_name + yield from self.non_field_names + + @property + def non_field_names(self) -> Iterator[str]: + if self.is_field_only: + yield from () + else: + if self.datum_class_name: + yield self.datum_class_name + if self.value_class_name: + yield self.value_class_name + + +def generate_vegalite_channel_wrappers( + schemafile: Path, version: str, imports: list[str] | None = None +) -> str: + schema = load_schema_with_shorthand_properties(schemafile) + + encoding_def = "FacetedEncoding" + + encoding = SchemaInfo(schema["definitions"][encoding_def], rootschema=schema) + + channel_infos: dict[str, ChannelInfo] = {} + + class_defs = [] + + for prop, propschema in encoding.properties.items(): + def_dict = get_field_datum_value_defs(propschema, schema) + + supports_arrays = any( + schema_info.is_array() for schema_info in propschema.anyOf + ) + classname: str = prop[0].upper() + prop[1:] + channel_info = ChannelInfo( + supports_arrays=supports_arrays, + deep_description=propschema.deep_description, + field_class_name=classname, + ) + + for encoding_spec, definition in def_dict.items(): + basename = definition.rsplit("/", maxsplit=1)[-1] + basename = get_valid_identifier(basename) + + gen: SchemaGenerator + defschema = {"$ref": definition} + kwds = { + "basename": basename, + "schema": defschema, + "rootschema": schema, + "encodingname": prop, + "haspropsetters": True, + } + if encoding_spec == "field": + gen = FieldSchemaGenerator(classname, nodefault=[], **kwds) + elif encoding_spec == "datum": + temp_name = f"{classname}Datum" + channel_info.datum_class_name = temp_name + gen = DatumSchemaGenerator(temp_name, nodefault=["datum"], **kwds) + elif encoding_spec == "value": + temp_name = f"{classname}Value" + channel_info.value_class_name = temp_name + gen = ValueSchemaGenerator(temp_name, nodefault=["value"], **kwds) + + class_defs.append(gen.schema_class()) + + channel_infos[prop] = channel_info + + # NOTE: See https://github.com/vega/altair/pull/3482#issuecomment-2241577342 + COMPAT_EXPORTS = ( + "DatumChannelMixin", + "FieldChannelMixin", + "ValueChannelMixin", + "with_property_setters", + ) + + it = chain.from_iterable(info.all_names for info in channel_infos.values()) + all_ = list(chain(it, COMPAT_EXPORTS)) + + imports = imports or [ + "from __future__ import annotations\n", + "from typing import Any, overload, Sequence, List, Literal, Union, TYPE_CHECKING, TypedDict", + "from typing_extensions import TypeAlias", + "import narwhals.stable.v1 as nw", + "from altair.utils.schemapi import Undefined, with_property_setters", + "from altair.utils import infer_encoding_types as _infer_encoding_types", + "from altair.utils import parse_shorthand", + "from . 
import core", + "from ._typing import * # noqa: F403", + ] + contents = [ + HEADER, + CHANNEL_MYPY_IGNORE_STATEMENTS, + *imports, + _type_checking_only_imports( + "from altair import Parameter, SchemaBase", + "from altair.typing import Optional", + "from typing_extensions import Self", + ), + "\n" f"__all__ = {sorted(all_)}\n", + CHANNEL_MIXINS, + *class_defs, + *generate_encoding_artifacts(channel_infos, ENCODE_METHOD, ENCODE_TYPED_DICT), + ] + return "\n".join(contents) + + +def generate_vegalite_mark_mixin( + schemafile: Path, markdefs: dict[str, str] +) -> tuple[list[str], str]: + with schemafile.open(encoding="utf8") as f: + schema = json.load(f) + + class_name = "MarkMethodMixin" + + imports = [ + "from typing import Any, Sequence, List, Literal, Union", + "", + "from altair.utils.schemapi import Undefined, UndefinedType", + "from . import core", + ] + + code = [ + f"class {class_name}:", + ' """A mixin class that defines mark methods"""', + ] + + for mark_enum, mark_def in markdefs.items(): + if "enum" in schema["definitions"][mark_enum]: + marks = schema["definitions"][mark_enum]["enum"] + else: + marks = [schema["definitions"][mark_enum]["const"]] + info = SchemaInfo({"$ref": f"#/definitions/{mark_def}"}, rootschema=schema) + + # adapted from SchemaInfo.init_code + arg_info = codegen.get_args(info) + arg_info.required -= {"type"} + arg_info.kwds -= {"type"} + + def_args = ["self"] + [ + f"{p}: " + + info.properties[p].get_python_type_representation( + for_type_hints=True, + additional_type_hints=["UndefinedType"], + ) + + " = Undefined" + for p in (sorted(arg_info.required) + sorted(arg_info.kwds)) + ] + dict_args = [ + f"{p}={p}" for p in (sorted(arg_info.required) + sorted(arg_info.kwds)) + ] + + if arg_info.additional or arg_info.invalid_kwds: + def_args.append("**kwds") + dict_args.append("**kwds") + + for mark in marks: + # TODO: only include args relevant to given type? + mark_method = MARK_METHOD.format( + mark=mark, + mark_def=mark_def, + def_arglist=", ".join(def_args), + dict_arglist=", ".join(dict_args), + ) + code.append("\n ".join(mark_method.splitlines())) + + return imports, "\n".join(code) + + +def generate_vegalite_config_mixin(schemafile: Path) -> tuple[list[str], str]: + imports = [ + "from . import core", + "from altair.utils import use_signature", + ] + + class_name = "ConfigMethodMixin" + + code = [ + f"class {class_name}:", + ' """A mixin class that defines config methods"""', + ] + with schemafile.open(encoding="utf8") as f: + schema = json.load(f) + info = SchemaInfo({"$ref": "#/definitions/Config"}, rootschema=schema) + + # configure() method + method = CONFIG_METHOD.format(classname="Config", method="configure") + code.append("\n ".join(method.splitlines())) + + # configure_prop() methods + for prop, prop_info in info.properties.items(): + classname = prop_info.refname + if classname and classname.endswith("Config"): + method = CONFIG_PROP_METHOD.format(classname=classname, prop=prop) + code.append("\n ".join(method.splitlines())) + return imports, "\n".join(code) + + +def vegalite_main(skip_download: bool = False) -> None: + version = SCHEMA_VERSION + ###(H) Below just gets the path to vegalite main file + vn = version.split(".")[0] + fp = (Path(__file__).parent / ".." 
/ "altair" / "vegalite" / vn).resolve() + schemapath = fp / "schema" + ###(H) They download the schema, eg: altair/altair/vegalite/v5/schema/vega-lite-schema.json + schemafile = download_schemafile( + version=version, + schemapath=schemapath, + skip_download=skip_download, + ) + + fp_themes = schemapath / "vega-themes.json" + print(f"Updating themes\n {schemafile!s}\n ->{fp_themes!s}") + update_vega_themes(fp_themes) + + # Generate __init__.py file + outfile = schemapath / "__init__.py" + print(f"Writing {outfile!s}") + # The content is written word for word as seen + content = [ + "# ruff: noqa\n", + "from .core import *\nfrom .channels import *\n", + f"SCHEMA_VERSION = '{version}'\n", + f"SCHEMA_URL = {schema_url(version)!r}\n", + ] + ###(H)ruff is a python 'linter' written in Rust, which is essentially + ###syntax formatting and checking. + ###The function below is a combination of writing, ruff checking and formatting + ruff_write_lint_format_str(outfile, content) + + # TypeAliasTracer is imported from utils.py and keeps track of all aliases for literals + TypeAliasTracer.update_aliases(("Map", "Mapping[str, Any]")) + + ###(H) Note: Path is a type imported from pathlib. Every Path added to the files + ### dictionary is eventually written to and formatted using ruff + files: dict[Path, str | Iterable[str]] = {} + + # Generate the core schema wrappers + fp_core = schemapath / "core.py" + print(f"Generating\n {schemafile!s}\n ->{fp_core!s}") + # Reminder: the schemafile here is the downloaded reference schemafile + files[fp_core] = generate_vegalite_schema_wrapper(schemafile) + + # Generate the channel wrappers + fp_channels = schemapath / "channels.py" + print(f"Generating\n {schemafile!s}\n ->{fp_channels!s}") + files[fp_channels] = generate_vegalite_channel_wrappers(schemafile, version=version) + + # generate the mark mixin + # A mixin class is one which provides functionality to other classes as a standalone class + markdefs = {k: f"{k}Def" for k in ["Mark", "BoxPlot", "ErrorBar", "ErrorBand"]} + fp_mixins = schemapath / "mixins.py" + print(f"Generating\n {schemafile!s}\n ->{fp_mixins!s}") + + # The following function dynamically creates a mixin class that can be used for 'marks' (eg. bars on bar chart, dot on scatter) + mark_imports, mark_mixin = generate_vegalite_mark_mixin(schemafile, markdefs) + config_imports, config_mixin = generate_vegalite_config_mixin(schemafile) + try_except_imports = [ + "if sys.version_info >= (3, 11):", + " from typing import Self", + "else:", + " from typing_extensions import Self", + ] + stdlib_imports = ["from __future__ import annotations\n", "import sys"] + content_mixins = [ + HEADER, + "\n".join(stdlib_imports), + "\n\n", + "\n".join(sorted({*mark_imports, *config_imports})), + "\n\n", + "\n".join(try_except_imports), + "\n\n", + _type_checking_only_imports( + "from altair import Parameter, SchemaBase", + "from altair.typing import Optional", + "from ._typing import * # noqa: F403", + ), + "\n\n\n", + mark_mixin, + "\n\n\n", + config_mixin, + ] + files[fp_mixins] = content_mixins + + # Write `_typing.py` TypeAlias, for import in generated modules + fp_typing = schemapath / "_typing.py" + msg = ( + f"Generating\n {schemafile!s}\n ->{fp_typing!s}\n" + f"Tracer cache collected {TypeAliasTracer.n_entries!r} entries." 
+ ) + print(msg) + TypeAliasTracer.write_module( + fp_typing, "OneOrSeq", header=HEADER, extra=TYPING_EXTRA + ) + # Write the pre-generated modules + for fp, contents in files.items(): + print(f"Writing\n {schemafile!s}\n ->{fp!s}") + ruff_write_lint_format_str(fp, contents) + + +def generate_encoding_artifacts( + channel_infos: dict[str, ChannelInfo], fmt_method: str, fmt_typed_dict: str +) -> Iterator[str]: + """ + Generate ``Chart.encode()`` and related typing structures. + + - `TypeAlias`(s) for each parameter to ``Chart.encode()`` + - Mixin class that provides the ``Chart.encode()`` method + - `TypedDict`, utilising/describing these structures as part of https://github.com/pola-rs/polars/pull/17995. + + Notes + ----- + - `Map`/`Dict` stands for the return types of `alt.(datum|value)`, and any encoding channel class. + - See discussions in https://github.com/vega/altair/pull/3208 + - We could be more specific about what types are accepted in the `List` + - but this translates poorly to an IDE + - `info.supports_arrays` + """ + signature_args: list[str] = ["self", "*args: Any"] + type_aliases: list[str] = [] + typed_dict_args: list[str] = [] + signature_doc_params: list[str] = ["", "Parameters", "----------"] + typed_dict_doc_params: list[str] = ["", "Parameters", "----------"] + + for channel, info in channel_infos.items(): + alias_name: str = f"Channel{channel[0].upper()}{channel[1:]}" + + it: Iterator[str] = info.all_names + it_rst_names: Iterator[str] = (rst_syntax_for_class(c) for c in info.all_names) + + docstring_types: list[str] = ["str", next(it_rst_names), "Dict"] + tp_inner: str = ", ".join(chain(("str", next(it), "Map"), it)) + tp_inner = f"Union[{tp_inner}]" + + if info.supports_arrays: + docstring_types.append("List") + tp_inner = f"OneOrSeq[{tp_inner}]" + + doc_types_flat: str = ", ".join(chain(docstring_types, it_rst_names)) + + type_aliases.append(f"{alias_name}: TypeAlias = {tp_inner}") + # We use the full type hints instead of the alias in the signatures below + # as IDEs such as VS Code would else show the name of the alias instead + # of the expanded full type hints. The later are more useful to users. + typed_dict_args.append(f"{channel}: {tp_inner}") + signature_args.append(f"{channel}: Optional[{tp_inner}] = Undefined") + + description: str = f" {process_description(info.deep_description)}" + + signature_doc_params.extend((f"{channel} : {doc_types_flat}", description)) + typed_dict_doc_params.extend((f"{channel}", description)) + + method: str = fmt_method.format( + method_args=", ".join(signature_args), + docstring=indent_docstring(signature_doc_params, indent_level=8, lstrip=False), + ) + typed_dict: str = fmt_typed_dict.format( + channels="\n ".join(typed_dict_args), + docstring=indent_docstring(typed_dict_doc_params, indent_level=4, lstrip=False), + ) + artifacts: Iterable[str] = *type_aliases, method, typed_dict + yield from artifacts + + +def main() -> None: + parser = argparse.ArgumentParser( + prog="generate_schema_wrapper.py", description="Generate the Altair package." + ) + parser.add_argument( + "--skip-download", action="store_true", help="skip downloading schema files" + ) + ###(H) I've used this library before. The below just does the actual arg parsing + args = parser.parse_args() + ###(H) Copies the schemapi.py file from schemapi to ../altair/utils + copy_schemapi_util() + + vegalite_main(args.skip_download) + + # The modules below are imported after the generation of the new schema files + # as these modules import Altair. 
This allows them to use the new changes + from tools import generate_api_docs, update_init_file + + generate_api_docs.write_api_file() + update_init_file.update__all__variable() + + +if __name__ == "__main__": + main() diff --git a/tools/generated_classes.py b/tools/generated_classes.py new file mode 100644 index 00000000..437d65d4 --- /dev/null +++ b/tools/generated_classes.py @@ -0,0 +1,37 @@ +from typing import Any, Union +class AggregateExpression: + def __init__(self, agg: str, label: str = None): + self.agg = agg + self.label = label + +class ParamRef: + def __init__(self): + pass + +class TransformField: + def __init__(self, value: Union["str", "ParamRef"]): + self.value = value + +class AggregateTransform: + def __init__(self, value: Union["Argmax", "Argmin", "Avg", "Count", "Max", "Min", "First", "Last", "Median", "Mode", "Product", "Quantile", "Stddev", "StddevPop", "Sum", "Variance", "VarPop"]): + self.value = value + +class Argmax: + def __init__(self, argmax: Any, distinct: bool = None, orderby: Union[TransformField, Any] = None, partitionby: Union[TransformField, Any] = None, range: Union[Any, ParamRef] = None, rows: Union[Any, ParamRef] = None): + self.argmax = argmax + self.distinct = distinct + self.orderby = orderby + self.partitionby = partitionby + self.range = range + self.rows = rows + +class Argmin: + def __init__(self, argmin: Any, distinct: bool = None, orderby: Union[TransformField, Any] = None, partitionby: Union[TransformField, Any] = None, range: Union[Any, ParamRef] = None, rows: Union[Any, ParamRef] = None): + self.argmin = argmin + self.distinct = distinct + self.orderby = orderby + self.partitionby = partitionby + self.range = range + self.rows = rows + + diff --git a/tools/schemapi/__init__.py b/tools/schemapi/__init__.py new file mode 100755 index 00000000..023a9a2a --- /dev/null +++ b/tools/schemapi/__init__.py @@ -0,0 +1,8 @@ +"""schemapi: tools for generating Python APIs from JSON schemas.""" + +from tools.schemapi import codegen, utils +from tools.schemapi.codegen import CodeSnippet +from tools.schemapi.schemapi import SchemaBase, Undefined +from tools.schemapi.utils import SchemaInfo + +__all__ = ["CodeSnippet", "SchemaBase", "SchemaInfo", "Undefined", "codegen", "utils"] diff --git a/tools/schemapi/__pycache__/__init__.cpython-311.pyc b/tools/schemapi/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..244c9bd4 Binary files /dev/null and b/tools/schemapi/__pycache__/__init__.cpython-311.pyc differ diff --git a/tools/schemapi/__pycache__/codegen.cpython-311.pyc b/tools/schemapi/__pycache__/codegen.cpython-311.pyc new file mode 100644 index 00000000..4f967470 Binary files /dev/null and b/tools/schemapi/__pycache__/codegen.cpython-311.pyc differ diff --git a/tools/schemapi/__pycache__/schemapi.cpython-311.pyc b/tools/schemapi/__pycache__/schemapi.cpython-311.pyc new file mode 100644 index 00000000..110df5c6 Binary files /dev/null and b/tools/schemapi/__pycache__/schemapi.cpython-311.pyc differ diff --git a/tools/schemapi/__pycache__/utils.cpython-311.pyc b/tools/schemapi/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..7f61e2d0 Binary files /dev/null and b/tools/schemapi/__pycache__/utils.cpython-311.pyc differ diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py new file mode 100755 index 00000000..cf8ea81b --- /dev/null +++ b/tools/schemapi/codegen.py @@ -0,0 +1,380 @@ +"""Code generation utilities.""" + +from __future__ import annotations + +import re +import textwrap +from dataclasses import 
dataclass +from typing import Final + +from .utils import ( + SchemaInfo, + TypeAliasTracer, + flatten, + indent_docstring, + is_valid_identifier, + jsonschema_to_python_types, + spell_literal, +) + + +class CodeSnippet: + """Object whose repr() is a string of code.""" + + def __init__(self, code: str): + self.code = code + + def __repr__(self) -> str: + return self.code + + +@dataclass +class ArgInfo: + nonkeyword: bool + required: set[str] + kwds: set[str] + invalid_kwds: set[str] + additional: bool + + +def get_args(info: SchemaInfo) -> ArgInfo: + """Return the list of args & kwds for building the __init__ function.""" + # TODO: - set additional properties correctly + # - handle patternProperties etc. + required: set[str] = set() + kwds: set[str] = set() + invalid_kwds: set[str] = set() + + # TODO: specialize for anyOf/oneOf? + + if info.is_allOf(): + # recursively call function on all children + arginfo = [get_args(child) for child in info.allOf] + nonkeyword = all(args.nonkeyword for args in arginfo) + required = set.union(set(), *(args.required for args in arginfo)) + kwds = set.union(set(), *(args.kwds for args in arginfo)) + kwds -= required + invalid_kwds = set.union(set(), *(args.invalid_kwds for args in arginfo)) + additional = all(args.additional for args in arginfo) + elif info.is_empty() or info.is_compound(): + nonkeyword = True + additional = True + elif info.is_value(): + nonkeyword = True + additional = False + elif info.is_object(): + invalid_kwds = {p for p in info.required if not is_valid_identifier(p)} | { + p for p in info.properties if not is_valid_identifier(p) + } + required = {p for p in info.required if is_valid_identifier(p)} + kwds = {p for p in info.properties if is_valid_identifier(p)} + kwds -= required + nonkeyword = False + additional = True + # additional = info.additionalProperties or info.patternProperties + else: + msg = "Schema object not understood" + raise ValueError(msg) + + return ArgInfo( + nonkeyword=nonkeyword, + required=required, + kwds=kwds, + invalid_kwds=invalid_kwds, + additional=additional, + ) + + +class SchemaGenerator: + """ + Class that defines methods for generating code from schemas. + + Parameters + ---------- + classname : string + The name of the class to generate + schema : dict + The dictionary defining the schema class + rootschema : dict (optional) + The root schema for the class + basename : string or list of strings (default: "SchemaBase") + The name(s) of the base class(es) to use in the class definition + schemarepr : CodeSnippet or object, optional + An object whose repr will be used in the place of the explicit schema. + This can be useful, for example, when the generated code should reference + a predefined schema object. The user must ensure that the schema within + the evaluated code is identical to the schema used to generate the code. + rootschemarepr : CodeSnippet or object, optional + An object whose repr will be used in the place of the explicit root + schema. + **kwargs : dict + Additional keywords for derived classes. 
+ """ + + schema_class_template = textwrap.dedent( + ''' + class {classname}({basename}): + """{docstring}""" + _schema = {schema!r} + _rootschema = {rootschema!r} + + {init_code} + ''' + ) + + init_template: Final = textwrap.dedent( + """ + def __init__({arglist}): + super({classname}, self).__init__({super_arglist}) + """ + ).lstrip() + + def _process_description(self, description: str): + return description + + def __init__( + self, + classname: str, + schema: dict, + rootschema: dict | None = None, + basename: str | list[str] = "SchemaBase", + schemarepr: object | None = None, + rootschemarepr: object | None = None, + nodefault: list[str] | None = None, + haspropsetters: bool = False, + **kwargs, + ) -> None: + self.classname = classname + self.schema = schema + self.rootschema = rootschema + self.basename = basename + self.schemarepr = schemarepr + self.rootschemarepr = rootschemarepr + self.nodefault = nodefault or () + self.haspropsetters = haspropsetters + self.kwargs = kwargs + + def subclasses(self) -> list[str]: + """Return a list of subclass names, if any.""" + info = SchemaInfo(self.schema, self.rootschema) + return [child.refname for child in info.anyOf if child.is_reference()] + + def schema_class(self) -> str: + """Generate code for a schema class.""" + rootschema: dict = ( + self.rootschema if self.rootschema is not None else self.schema + ) + schemarepr: object = ( + self.schemarepr if self.schemarepr is not None else self.schema + ) + rootschemarepr = self.rootschemarepr + if rootschemarepr is None: + if rootschema is self.schema: + rootschemarepr = CodeSnippet("_schema") + else: + rootschemarepr = rootschema + if isinstance(self.basename, str): + basename = self.basename + else: + basename = ", ".join(self.basename) + return self.schema_class_template.format( + classname=self.classname, + basename=basename, + schema=schemarepr, + rootschema=rootschemarepr, + docstring=self.docstring(indent=4), + init_code=self.init_code(indent=4), + method_code=self.method_code(indent=4), + **self.kwargs, + ) + + @property + def info(self) -> SchemaInfo: + return SchemaInfo(self.schema, self.rootschema) + + @property + def arg_info(self) -> ArgInfo: + return get_args(self.info) + + def docstring(self, indent: int = 0) -> str: + info = self.info + # https://numpydoc.readthedocs.io/en/latest/format.html#short-summary + doc = [f"{self.classname} schema wrapper"] + if info.description: + # https://numpydoc.readthedocs.io/en/latest/format.html#extended-summary + # Remove condition from description + desc: str = re.sub(r"\n\{\n(\n|.)*\n\}", "", info.description) + ext_summary: list[str] = self._process_description(desc).splitlines() + # Remove lines which contain the "raw-html" directive which cannot be processed + # by Sphinx at this level of the docstring. It works for descriptions + # of attributes which is why we do not do the same below. The removed + # lines are anyway non-descriptive for a user. + ext_summary = [line for line in ext_summary if ":raw-html:" not in line] + # Only add an extended summary if the above did not result in an empty list. 
+ if ext_summary: + doc.append("") + doc.extend(ext_summary) + + if info.properties: + arg_info = self.arg_info + doc += ["", "Parameters", "----------", ""] + for prop in ( + sorted(arg_info.required) + + sorted(arg_info.kwds) + + sorted(arg_info.invalid_kwds) + ): + propinfo = info.properties[prop] + doc += [ + f"{prop} : {propinfo.get_python_type_representation()}", + f" {self._process_description(propinfo.deep_description)}", + ] + return indent_docstring(doc, indent_level=indent, width=100, lstrip=True) + + def init_code(self, indent: int = 0) -> str: + """Return code suitable for the __init__ function of a Schema class.""" + args, super_args = self.init_args() + + initfunc = self.init_template.format( + classname=self.classname, + arglist=", ".join(args), + super_arglist=", ".join(super_args), + ) + if indent: + initfunc = ("\n" + indent * " ").join(initfunc.splitlines()) + return initfunc + + def init_args( + self, additional_types: list[str] | None = None + ) -> tuple[list[str], list[str]]: + additional_types = additional_types or [] + info = self.info + arg_info = self.arg_info + + nodefault = set(self.nodefault) + arg_info.required -= nodefault + arg_info.kwds -= nodefault + + args: list[str] = ["self"] + super_args: list[str] = [] + + self.init_kwds = sorted(arg_info.kwds) + + if nodefault: + args.extend(sorted(nodefault)) + elif arg_info.nonkeyword: + args.append("*args") + super_args.append("*args") + + args.extend( + f"{p}: Optional[Union[" + + ", ".join( + [ + *additional_types, + *info.properties[p].get_python_type_representation( + for_type_hints=True, return_as_str=False + ), + ] + ) + + "]] = Undefined" + for p in sorted(arg_info.required) + sorted(arg_info.kwds) + ) + super_args.extend( + f"{p}={p}" + for p in sorted(nodefault) + + sorted(arg_info.required) + + sorted(arg_info.kwds) + ) + + if arg_info.additional: + args.append("**kwds") + super_args.append("**kwds") + return args, super_args + + def get_args(self, si: SchemaInfo) -> list[str]: + contents = ["self"] + prop_infos: dict[str, SchemaInfo] = {} + if si.is_anyOf(): + prop_infos = {} + for si_sub in si.anyOf: + prop_infos.update(si_sub.properties) + elif si.properties: + prop_infos = dict(si.properties.items()) + + if prop_infos: + contents.extend( + [ + f"{p}: " + + info.get_python_type_representation( + for_type_hints=True, additional_type_hints=["UndefinedType"] + ) + + " = Undefined" + for p, info in prop_infos.items() + ] + ) + elif si.type: + py_type = jsonschema_to_python_types[si.type] + if py_type == "list": + # Try to get a type hint like "List[str]" which is more specific + # then just "list" + item_vl_type = si.items.get("type", None) + if item_vl_type is not None: + item_type = jsonschema_to_python_types[item_vl_type] + else: + item_si = SchemaInfo(si.items, self.rootschema) + assert item_si.is_reference() + altair_class_name = item_si.title + item_type = f"core.{altair_class_name}" + py_type = f"List[{item_type}]" + elif si.is_literal(): + # If it's an enum, we can type hint it as a Literal which tells + # a type checker that only the values in enum are acceptable + py_type = TypeAliasTracer.add_literal( + si, spell_literal(si.literal), replace=True + ) + contents.append(f"_: {py_type}") + + contents.append("**kwds") + + return contents + + def get_signature( + self, attr: str, sub_si: SchemaInfo, indent: int, has_overload: bool = False + ) -> list[str]: + lines = [] + if has_overload: + lines.append("@overload") + args = ", ".join(self.get_args(sub_si)) + lines.extend( + (f"def {attr}({args}) -> 
'{self.classname}':", indent * " " + "...\n") + ) + return lines + + def setter_hint(self, attr: str, indent: int) -> list[str]: + si = SchemaInfo(self.schema, self.rootschema).properties[attr] + if si.is_anyOf(): + return self._get_signature_any_of(si, attr, indent) + else: + return self.get_signature(attr, si, indent, has_overload=True) + + def _get_signature_any_of( + self, si: SchemaInfo, attr: str, indent: int + ) -> list[str]: + signatures = [] + for sub_si in si.anyOf: + if sub_si.is_anyOf(): + # Recursively call method again to go a level deeper + signatures.extend(self._get_signature_any_of(sub_si, attr, indent)) + else: + signatures.extend( + self.get_signature(attr, sub_si, indent, has_overload=True) + ) + return list(flatten(signatures)) + + def method_code(self, indent: int = 0) -> str | None: + """Return code to assist setter methods.""" + if not self.haspropsetters: + return None + args = self.init_kwds + type_hints = [hint for a in args for hint in self.setter_hint(a, indent)] + + return ("\n" + indent * " ").join(type_hints) diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py new file mode 100755 index 00000000..5140073a --- /dev/null +++ b/tools/schemapi/schemapi.py @@ -0,0 +1,1496 @@ +from __future__ import annotations + +import contextlib +import copy +import inspect +import json +import sys +import textwrap +from collections import defaultdict +from functools import partial +from importlib.metadata import version as importlib_version +from itertools import chain, zip_longest +from math import ceil +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Final, + Iterable, + Iterator, + List, + Literal, + Sequence, + TypeVar, + Union, + cast, + overload, +) +from typing_extensions import TypeAlias + +import jsonschema +import jsonschema.exceptions +import jsonschema.validators +import narwhals.stable.v1 as nw +from packaging.version import Version + +# This leads to circular imports with the vegalite module. Currently, this works +# but be aware that when you access it in this script, the vegalite module might +# not yet be fully instantiated in case your code is being executed during import time +from altair import vegalite + +if TYPE_CHECKING: + from types import ModuleType + from typing import ClassVar + + from referencing import Registry + + from altair.typing import ChartType + + if sys.version_info >= (3, 13): + from typing import TypeIs + else: + from typing_extensions import TypeIs + + if sys.version_info >= (3, 11): + from typing import Never, Self + else: + from typing_extensions import Never, Self + _OptionalModule: TypeAlias = "ModuleType | None" + +ValidationErrorList: TypeAlias = List[jsonschema.exceptions.ValidationError] +GroupedValidationErrors: TypeAlias = Dict[str, ValidationErrorList] + +# This URI is arbitrary and could be anything else. It just cannot be an empty +# string as we need to reference the schema registered in +# the referencing.Registry. +_VEGA_LITE_ROOT_URI: Final = "urn:vega-lite-schema" + +# Ideally, jsonschema specification would be parsed from the current Vega-Lite +# schema instead of being hardcoded here as a default value. +# However, due to circular imports between this module and the altair.vegalite +# modules, this information is not yet available at this point as altair.vegalite +# is only partially loaded. The draft version which is used is unlikely to +# change often so it's ok to keep this. There is also a test which validates +# that this value is always the same as in the Vega-Lite schema. 
+_DEFAULT_JSON_SCHEMA_DRAFT_URL: Final = "http://json-schema.org/draft-07/schema#" + + +# If DEBUG_MODE is True, then schema objects are converted to dict and +# validated at creation time. This slows things down, particularly for +# larger specs, but leads to much more useful tracebacks for the user. +# Individual schema classes can override this by setting the +# class-level _class_is_valid_at_instantiation attribute to False +DEBUG_MODE: bool = True + +jsonschema_version_str = importlib_version("jsonschema") + + +def enable_debug_mode() -> None: + global DEBUG_MODE + DEBUG_MODE = True + + +def disable_debug_mode() -> None: + global DEBUG_MODE + DEBUG_MODE = False + + +@contextlib.contextmanager +def debug_mode(arg: bool) -> Iterator[None]: + global DEBUG_MODE + original = DEBUG_MODE + DEBUG_MODE = arg + try: + yield + finally: + DEBUG_MODE = original + + +@overload +def validate_jsonschema( + spec: Any, + schema: dict[str, Any], + rootschema: dict[str, Any] | None = ..., + *, + raise_error: Literal[True] = ..., +) -> Never: ... + + +@overload +def validate_jsonschema( + spec: Any, + schema: dict[str, Any], + rootschema: dict[str, Any] | None = ..., + *, + raise_error: Literal[False], +) -> jsonschema.exceptions.ValidationError | None: ... + + +def validate_jsonschema( + spec, + schema: dict[str, Any], + rootschema: dict[str, Any] | None = None, + *, + raise_error: bool = True, +) -> jsonschema.exceptions.ValidationError | None: + """ + Validates the passed in spec against the schema in the context of the rootschema. + + If any errors are found, they are deduplicated and prioritized + and only the most relevant errors are kept. Errors are then either raised + or returned, depending on the value of `raise_error`. + """ + errors = _get_errors_from_spec(spec, schema, rootschema=rootschema) + if errors: + leaf_errors = _get_leaves_of_error_tree(errors) + grouped_errors = _group_errors_by_json_path(leaf_errors) + grouped_errors = _subset_to_most_specific_json_paths(grouped_errors) + grouped_errors = _deduplicate_errors(grouped_errors) + + # Nothing special about this first error but we need to choose one + # which can be raised + main_error: Any = next(iter(grouped_errors.values()))[0] + # All errors are then attached as a new attribute to ValidationError so that + # they can be used in SchemaValidationError to craft a more helpful + # error message. Setting a new attribute like this is not ideal as + # it then no longer matches the type ValidationError. It would be better + # to refactor this function to never raise but only return errors. + main_error._all_errors = grouped_errors + if raise_error: + raise main_error + else: + return main_error + else: + return None + + +def _get_errors_from_spec( + spec: dict[str, Any], + schema: dict[str, Any], + rootschema: dict[str, Any] | None = None, +) -> ValidationErrorList: + """ + Uses the relevant jsonschema validator to validate the passed in spec against the schema using the rootschema to resolve references. + + The schema and rootschema themselves are not validated but instead considered as valid. + """ + # We don't use jsonschema.validate as this would validate the schema itself. + # Instead, we pass the schema directly to the validator class. This is done for + # two reasons: The schema comes from Vega-Lite and is not based on the user + # input, therefore there is no need to validate it in the first place. Furthermore, + # the "uri-reference" format checker fails for some of the references as URIs in + # "$ref" are not encoded, + # e.g. 
'#/definitions/ValueDefWithCondition' would be a valid $ref in a Vega-Lite schema but + # it is not a valid URI reference due to the characters such as '<'. + + json_schema_draft_url = _get_json_schema_draft_url(rootschema or schema) + validator_cls = jsonschema.validators.validator_for( + {"$schema": json_schema_draft_url} + ) + validator_kwargs: dict[str, Any] = {} + if hasattr(validator_cls, "FORMAT_CHECKER"): + validator_kwargs["format_checker"] = validator_cls.FORMAT_CHECKER + + if _use_referencing_library(): + schema = _prepare_references_in_schema(schema) + validator_kwargs["registry"] = _get_referencing_registry( + rootschema or schema, json_schema_draft_url + ) + + else: + # No resolver is necessary if the schema is already the full schema + validator_kwargs["resolver"] = ( + jsonschema.RefResolver.from_schema(rootschema) + if rootschema is not None + else None + ) + + validator = validator_cls(schema, **validator_kwargs) + errors = list(validator.iter_errors(spec)) + return errors + + +def _get_json_schema_draft_url(schema: dict[str, Any]) -> str: + return schema.get("$schema", _DEFAULT_JSON_SCHEMA_DRAFT_URL) + + +def _use_referencing_library() -> bool: + """In version 4.18.0, the jsonschema package deprecated RefResolver in favor of the referencing library.""" + return Version(jsonschema_version_str) >= Version("4.18") + + +def _prepare_references_in_schema(schema: dict[str, Any]) -> dict[str, Any]: + # Create a copy so that $ref is not modified in the original schema in case + # that it would still reference a dictionary which might be attached to + # an Altair class _schema attribute + schema = copy.deepcopy(schema) + + def _prepare_refs(d: dict[str, Any]) -> dict[str, Any]: + """ + Add _VEGA_LITE_ROOT_URI in front of all $ref values. + + This function recursively iterates through the whole dictionary. + + $ref values can only be nested in dictionaries or lists + as the passed in `d` dictionary comes from the Vega-Lite json schema + and in json we only have arrays (-> lists in Python) and objects + (-> dictionaries in Python) which we need to iterate through. + """ + for key, value in d.items(): + if key == "$ref": + d[key] = _VEGA_LITE_ROOT_URI + d[key] + elif isinstance(value, dict): + d[key] = _prepare_refs(value) + elif isinstance(value, list): + prepared_values = [] + for v in value: + if isinstance(v, dict): + v = _prepare_refs(v) + prepared_values.append(v) + d[key] = prepared_values + return d + + schema = _prepare_refs(schema) + return schema + + +# We do not annotate the return value here as the referencing library is not always +# available and this function is only executed in those cases. +def _get_referencing_registry( + rootschema: dict[str, Any], json_schema_draft_url: str | None = None +) -> Registry: + # Referencing is a dependency of newer jsonschema versions, starting with the + # version that is specified in _use_referencing_library and we therefore + # can expect that it is installed if the function returns True. + # We ignore 'import' mypy errors which happen when the referencing library + # is not installed. That's ok as in these cases this function is not called. + # We also have to ignore 'unused-ignore' errors as mypy raises those in case + # referencing is installed. 
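+    # Rough sketch of what the registry constructed below enables (toy rootschema,
+    # purely for illustration; the real one is the full Vega-Lite schema):
+    #
+    #     root = {"definitions": {"Align": {"enum": ["left", "center", "right"]}}}
+    #     spec = referencing.jsonschema.specification_with(
+    #         "http://json-schema.org/draft-07/schema#"
+    #     )
+    #     registry = referencing.Registry().with_resource(
+    #         uri=_VEGA_LITE_ROOT_URI, resource=spec.create_resource(root)
+    #     )
+    #     registry.resolver().lookup(_VEGA_LITE_ROOT_URI + "#/definitions/Align").contents
+    #     # -> {"enum": ["left", "center", "right"]}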
+ import referencing # type: ignore[import,unused-ignore] + import referencing.jsonschema # type: ignore[import,unused-ignore] + + if json_schema_draft_url is None: + json_schema_draft_url = _get_json_schema_draft_url(rootschema) + + specification = referencing.jsonschema.specification_with(json_schema_draft_url) + resource = specification.create_resource(rootschema) + return referencing.Registry().with_resource( + uri=_VEGA_LITE_ROOT_URI, resource=resource + ) + + +def _json_path(err: jsonschema.exceptions.ValidationError) -> str: + """ + Drop in replacement for the .json_path property of the jsonschema ValidationError class. + + This is not available as property for ValidationError with jsonschema<4.0.1. + + More info, see https://github.com/vega/altair/issues/3038. + """ + path = "$" + for elem in err.absolute_path: + if isinstance(elem, int): + path += "[" + str(elem) + "]" + else: + path += "." + elem + return path + + +def _group_errors_by_json_path( + errors: ValidationErrorList, +) -> GroupedValidationErrors: + """ + Groups errors by the `json_path` attribute of the jsonschema ValidationError class. + + This attribute contains the path to the offending element within + a chart specification and can therefore be considered as an identifier of an + 'issue' in the chart that needs to be fixed. + """ + errors_by_json_path = defaultdict(list) + for err in errors: + err_key = getattr(err, "json_path", _json_path(err)) + errors_by_json_path[err_key].append(err) + return dict(errors_by_json_path) + + +def _get_leaves_of_error_tree( + errors: ValidationErrorList, +) -> ValidationErrorList: + """ + For each error in `errors`, it traverses down the "error tree" that is generated by the jsonschema library to find and return all "leaf" errors. + + These are errors which have no further errors that caused it and so they are the most specific errors + with the most specific error messages. + """ + leaves: ValidationErrorList = [] + for err in errors: + if err.context: + # This means that the error `err` was caused by errors in subschemas. + # The list of errors from the subschemas are available in the property + # `context`. + leaves.extend(_get_leaves_of_error_tree(err.context)) + else: + leaves.append(err) + return leaves + + +def _subset_to_most_specific_json_paths( + errors_by_json_path: GroupedValidationErrors, +) -> GroupedValidationErrors: + """ + Removes key (json path), value (errors) pairs where the json path is fully contained in another json path. + + For example if `errors_by_json_path` has two keys, `$.encoding.X` and `$.encoding.X.tooltip`, + then the first one will be removed and only the second one is returned. + + This is done under the assumption that more specific json paths give more helpful error messages to the user. + """ + errors_by_json_path_specific: GroupedValidationErrors = {} + for json_path, errors in errors_by_json_path.items(): + if not _contained_at_start_of_one_of_other_values( + json_path, list(errors_by_json_path.keys()) + ): + errors_by_json_path_specific[json_path] = errors + return errors_by_json_path_specific + + +def _contained_at_start_of_one_of_other_values(x: str, values: Sequence[str]) -> bool: + # Does not count as "contained at start of other value" if the values are + # the same. 
These cases should be handled separately + return any(value.startswith(x) for value in values if x != value) + + +def _deduplicate_errors( + grouped_errors: GroupedValidationErrors, +) -> GroupedValidationErrors: + """ + Some errors have very similar error messages or are just in general not helpful for a user. + + This function removes as many of these cases as possible and + can be extended over time to handle new cases that come up. + """ + grouped_errors_deduplicated: GroupedValidationErrors = {} + for json_path, element_errors in grouped_errors.items(): + errors_by_validator = _group_errors_by_validator(element_errors) + + deduplication_functions = { + "enum": _deduplicate_enum_errors, + "additionalProperties": _deduplicate_additional_properties_errors, + } + deduplicated_errors: ValidationErrorList = [] + for validator, errors in errors_by_validator.items(): + deduplication_func = deduplication_functions.get(validator) + if deduplication_func is not None: + errors = deduplication_func(errors) + deduplicated_errors.extend(_deduplicate_by_message(errors)) + + # Removes any ValidationError "'value' is a required property" as these + # errors are unlikely to be the relevant ones for the user. They come from + # validation against a schema definition where the output of `alt.value` + # would be valid. However, if a user uses `alt.value`, the `value` keyword + # is included automatically from that function and so it's unlikely + # that this was what the user intended if the keyword is not present + # in the first place. + deduplicated_errors = [ + err for err in deduplicated_errors if not _is_required_value_error(err) + ] + + grouped_errors_deduplicated[json_path] = deduplicated_errors + return grouped_errors_deduplicated + + +def _is_required_value_error(err: jsonschema.exceptions.ValidationError) -> bool: + return err.validator == "required" and err.validator_value == ["value"] + + +def _group_errors_by_validator(errors: ValidationErrorList) -> GroupedValidationErrors: + """ + Groups the errors by the json schema "validator" that casued the error. + + For example if the error is that a value is not one of an enumeration in the json schema + then the "validator" is `"enum"`, if the error is due to an unknown property that + was set although no additional properties are allowed then "validator" is + `"additionalProperties`, etc. + """ + errors_by_validator: defaultdict[str, ValidationErrorList] = defaultdict(list) + for err in errors: + # Ignore mypy error as err.validator as it wrongly sees err.validator + # as of type Optional[Validator] instead of str which it is according + # to the documentation and all tested cases + errors_by_validator[err.validator].append(err) # type: ignore[index] + return dict(errors_by_validator) + + +def _deduplicate_enum_errors(errors: ValidationErrorList) -> ValidationErrorList: + """ + Deduplicate enum errors by removing the errors where the allowed values are a subset of another error. + + For example, if `enum` contains two errors and one has `validator_value` (i.e. accepted values) ["A", "B"] and the + other one ["A", "B", "C"] then the first one is removed and the final + `enum` list only contains the error with ["A", "B", "C"]. 
+ """ + if len(errors) > 1: + # Values (and therefore `validator_value`) of an enum are always arrays, + # see https://json-schema.org/understanding-json-schema/reference/generic.html#enumerated-values + # which is why we can use join below + value_strings = [",".join(err.validator_value) for err in errors] # type: ignore + longest_enums: ValidationErrorList = [] + for value_str, err in zip(value_strings, errors): + if not _contained_at_start_of_one_of_other_values(value_str, value_strings): + longest_enums.append(err) + errors = longest_enums + return errors + + +def _deduplicate_additional_properties_errors( + errors: ValidationErrorList, +) -> ValidationErrorList: + """ + If there are multiple additional property errors it usually means that the offending element was validated against multiple schemas and its parent is a common anyOf validator. + + The error messages produced from these cases are usually + very similar and we just take the shortest one. For example, + the following 3 errors are raised for the `unknown` channel option in + `alt.X("variety", unknown=2)`: + - "Additional properties are not allowed ('unknown' was unexpected)" + - "Additional properties are not allowed ('field', 'unknown' were unexpected)" + - "Additional properties are not allowed ('field', 'type', 'unknown' were unexpected)". + """ + if len(errors) > 1: + # Test if all parent errors are the same anyOf error and only do + # the prioritization in these cases. Can't think of a chart spec where this + # would not be the case but still allow for it below to not break anything. + parent = errors[0].parent + if ( + parent is not None + and parent.validator == "anyOf" + # Use [1:] as don't have to check for first error as it was used + # above to define `parent` + and all(err.parent is parent for err in errors[1:]) + ): + errors = [min(errors, key=lambda x: len(x.message))] + return errors + + +def _deduplicate_by_message(errors: ValidationErrorList) -> ValidationErrorList: + """Deduplicate errors by message. This keeps the original order in case it was chosen intentionally.""" + return list({e.message: e for e in errors}.values()) + + +def _subclasses(cls: type[Any]) -> Iterator[type[Any]]: + """Breadth-first sequence of all classes which inherit from cls.""" + seen = set() + current_set = {cls} + while current_set: + seen |= current_set + current_set = set.union(*(set(cls.__subclasses__()) for cls in current_set)) + for cls in current_set - seen: + yield cls + + +def _from_array_like(obj: Iterable[Any], /) -> list[Any]: + try: + ser = nw.from_native(obj, strict=True, series_only=True) + return ser.to_list() + except TypeError: + return list(obj) + + +def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) -> Any: # noqa: C901 + """Convert an object to a dict representation.""" + if np_opt is not None: + np = np_opt + if isinstance(obj, np.ndarray): + return [_todict(v, context, np_opt, pd_opt) for v in obj] + elif isinstance(obj, np.number): + return float(obj) + elif isinstance(obj, np.datetime64): + result = str(obj) + if "T" not in result: + # See https://github.com/vega/altair/issues/1027 for why this is necessary. 
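+                # For example, str(np.datetime64("2019-01-15")) is just "2019-01-15";
+                # appending midnight turns it into a full ISO 8601 timestamp, while
+                # higher-precision values already contain a "T" and are left as-is.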
+ result += "T00:00:00" + return result + if isinstance(obj, SchemaBase): + return obj.to_dict(validate=False, context=context) + elif isinstance(obj, (list, tuple)): + return [_todict(v, context, np_opt, pd_opt) for v in obj] + elif isinstance(obj, dict): + return { + k: _todict(v, context, np_opt, pd_opt) + for k, v in obj.items() + if v is not Undefined + } + elif ( + hasattr(obj, "to_dict") + and (module_name := obj.__module__) + and module_name.startswith("altair") + ): + return obj.to_dict() + elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): + return pd_opt.Timestamp(obj).isoformat() + elif _is_iterable(obj, exclude=(str, bytes)): + return _todict(_from_array_like(obj), context, np_opt, pd_opt) + else: + return obj + + +def _resolve_references( + schema: dict[str, Any], rootschema: dict[str, Any] | None = None +) -> dict[str, Any]: + """Resolve schema references until there is no $ref anymore in the top-level of the dictionary.""" + if _use_referencing_library(): + registry = _get_referencing_registry(rootschema or schema) + # Using a different variable name to show that this is not the + # jsonschema.RefResolver but instead a Resolver from the referencing + # library + referencing_resolver = registry.resolver() + while "$ref" in schema: + schema = referencing_resolver.lookup( + _VEGA_LITE_ROOT_URI + schema["$ref"] + ).contents + else: + resolver = jsonschema.RefResolver.from_schema(rootschema or schema) + while "$ref" in schema: + with resolver.resolving(schema["$ref"]) as resolved: + schema = resolved + return schema + + +class SchemaValidationError(jsonschema.ValidationError): + def __init__(self, obj: SchemaBase, err: jsonschema.ValidationError) -> None: + """ + A wrapper for ``jsonschema.ValidationError`` with friendlier traceback. + + Parameters + ---------- + obj + The instance that failed ``self.validate(...)``. + err + The original ``ValidationError``. + + Notes + ----- + We do not raise `from err` as else the resulting traceback is very long + as it contains part of the Vega-Lite schema. + + It would also first show the less helpful `ValidationError` instead of + the more user friendly `SchemaValidationError`. + """ + super().__init__(**err._contents()) + self.obj = obj + self._errors: GroupedValidationErrors = getattr( + err, "_all_errors", {getattr(err, "json_path", _json_path(err)): [err]} + ) + # This is the message from err + self._original_message = self.message + self.message = self._get_message() + + def __str__(self) -> str: + return self.message + + def _get_message(self) -> str: + def indent_second_line_onwards(message: str, indent: int = 4) -> str: + modified_lines: list[str] = [] + for idx, line in enumerate(message.split("\n")): + if idx > 0 and len(line) > 0: + line = " " * indent + line + modified_lines.append(line) + return "\n".join(modified_lines) + + error_messages: list[str] = [] + # Only show a maximum of 3 errors as else the final message returned by this + # method could get very long. 
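+        # `self._errors` maps a json path (one offending location in the spec) to
+        # the list of ValidationErrors raised for it, so each iteration below
+        # formats a single location.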
+ for errors in list(self._errors.values())[:3]: + error_messages.append(self._get_message_for_errors_group(errors)) + + message = "" + if len(error_messages) > 1: + error_messages = [ + indent_second_line_onwards(f"Error {error_id}: {m}") + for error_id, m in enumerate(error_messages, start=1) + ] + message += "Multiple errors were found.\n\n" + message += "\n\n".join(error_messages) + return message + + def _get_message_for_errors_group( + self, + errors: ValidationErrorList, + ) -> str: + if errors[0].validator == "additionalProperties": + # During development, we only found cases where an additionalProperties + # error was raised if that was the only error for the offending instance + # as identifiable by the json path. Therefore, we just check here the first + # error. However, other constellations might exist in which case + # this should be adapted so that other error messages are shown as well. + message = self._get_additional_properties_error_message(errors[0]) + else: + message = self._get_default_error_message(errors=errors) + + return message.strip() + + def _get_additional_properties_error_message( + self, + error: jsonschema.exceptions.ValidationError, + ) -> str: + """Output all existing parameters when an unknown parameter is specified.""" + altair_cls = self._get_altair_class_for_error(error) + param_dict_keys = inspect.signature(altair_cls).parameters.keys() + param_names_table = self._format_params_as_table(param_dict_keys) + + # Error messages for these errors look like this: + # "Additional properties are not allowed ('unknown' was unexpected)" + # Line below extracts "unknown" from this string + parameter_name = error.message.split("('")[-1].split("'")[0] + message = f"""\ +`{altair_cls.__name__}` has no parameter named '{parameter_name}' + +Existing parameter names are: +{param_names_table} +See the help for `{altair_cls.__name__}` to read the full description of these parameters""" + return message + + def _get_altair_class_for_error( + self, error: jsonschema.exceptions.ValidationError + ) -> type[SchemaBase]: + """ + Try to get the lowest class possible in the chart hierarchy so it can be displayed in the error message. + + This should lead to more informative error messages pointing the user closer to the source of the issue. + """ + for prop_name in reversed(error.absolute_path): + # Check if str as e.g. first item can be a 0 + if isinstance(prop_name, str): + potential_class_name = prop_name[0].upper() + prop_name[1:] + cls = getattr(vegalite, potential_class_name, None) + if cls is not None: + break + else: + # Did not find a suitable class based on traversing the path so we fall + # back on the class of the top-level object which created + # the SchemaValidationError + cls = self.obj.__class__ + return cls + + @staticmethod + def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: + """Format param names into a table so that they are easier to read.""" + param_names: tuple[str, ...] + name_lengths: tuple[int, ...] 
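+        # Illustration: for keys ("self", "x2", "color") the comprehension below
+        # yields param_names == ("x2", "color") and name_lengths == (2, 5).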
+ param_names, name_lengths = zip( + *[ + (name, len(name)) + for name in param_dict_keys + if name not in {"kwds", "self"} + ] + ) + # Worst case scenario with the same longest param name in the same + # row for all columns + max_name_length = max(name_lengths) + max_column_width = 80 + # Output a square table if not too big (since it is easier to read) + num_param_names = len(param_names) + square_columns = int(ceil(num_param_names**0.5)) + columns = min(max_column_width // max_name_length, square_columns) + + # Compute roughly equal column heights to evenly divide the param names + def split_into_equal_parts(n: int, p: int) -> list[int]: + return [n // p + 1] * (n % p) + [n // p] * (p - n % p) + + column_heights = split_into_equal_parts(num_param_names, columns) + + # Section the param names into columns and compute their widths + param_names_columns: list[tuple[str, ...]] = [] + column_max_widths: list[int] = [] + last_end_idx: int = 0 + for ch in column_heights: + param_names_columns.append(param_names[last_end_idx : last_end_idx + ch]) + column_max_widths.append( + max(len(param_name) for param_name in param_names_columns[-1]) + ) + last_end_idx = ch + last_end_idx + + # Transpose the param name columns into rows to facilitate looping + param_names_rows: list[tuple[str, ...]] = [] + for li in zip_longest(*param_names_columns, fillvalue=""): + param_names_rows.append(li) + # Build the table as a string by iterating over and formatting the rows + param_names_table: str = "" + for param_names_row in param_names_rows: + for num, param_name in enumerate(param_names_row): + # Set column width based on the longest param in the column + max_name_length_column = column_max_widths[num] + column_pad = 3 + param_names_table += "{:<{}}".format( + param_name, max_name_length_column + column_pad + ) + # Insert newlines and spacing after the last element in each row + if num == (len(param_names_row) - 1): + param_names_table += "\n" + return param_names_table + + def _get_default_error_message( + self, + errors: ValidationErrorList, + ) -> str: + bullet_points: list[str] = [] + errors_by_validator = _group_errors_by_validator(errors) + if "enum" in errors_by_validator: + for error in errors_by_validator["enum"]: + bullet_points.append(f"one of {error.validator_value}") + + if "type" in errors_by_validator: + types = [f"'{err.validator_value}'" for err in errors_by_validator["type"]] + point = "of type " + if len(types) == 1: + point += types[0] + elif len(types) == 2: + point += f"{types[0]} or {types[1]}" + else: + point += ", ".join(types[:-1]) + f", or {types[-1]}" + bullet_points.append(point) + + # It should not matter which error is specifically used as they are all + # about the same offending instance (i.e. invalid value), so we can just + # take the first one + error = errors[0] + # Add a summary line when parameters are passed an invalid value + # For example: "'asdf' is an invalid value for `stack` + message = f"'{error.instance}' is an invalid value" + if error.absolute_path: + message += f" for `{error.absolute_path[-1]}`" + + # Add bullet points + if len(bullet_points) == 0: + message += ".\n\n" + elif len(bullet_points) == 1: + message += f". Valid values are {bullet_points[0]}.\n\n" + else: + # We don't use .capitalize below to make the first letter uppercase + # as that makes the rest of the message lowercase + bullet_points = [point[0].upper() + point[1:] for point in bullet_points] + message += ". 
Valid values are:\n\n" + message += "\n".join([f"- {point}" for point in bullet_points]) + message += "\n\n" + + # Add unformatted messages of any remaining errors which were not + # considered so far. This is not expected to be used but more exists + # as a fallback for cases which were not known during development. + it = ( + "\n".join(e.message for e in errors) + for validator, errors in errors_by_validator.items() + if validator not in {"enum", "type"} + ) + message += "".join(it) + return message + + +class UndefinedType: + """A singleton object for marking undefined parameters.""" + + __instance = None + + def __new__(cls, *args, **kwargs) -> Self: + if not isinstance(cls.__instance, cls): + cls.__instance = object.__new__(cls, *args, **kwargs) + return cls.__instance + + def __repr__(self) -> str: + return "Undefined" + + +Undefined = UndefinedType() +T = TypeVar("T") +Optional: TypeAlias = Union[T, UndefinedType] +"""One of ``T`` specified type(s), or the ``Undefined`` singleton. + +Examples +-------- +The parameters ``short``, ``long`` accept the same range of types:: + + # ruff: noqa: UP006, UP007 + from altair.typing import Optional + + def func_1( + short: Optional[str | bool | float | dict[str, Any] | SchemaBase] = Undefined, + long: Union[ + str, bool, float, Dict[str, Any], SchemaBase, UndefinedType + ] = Undefined, + ): ... + +This is distinct from `typing.Optional `__. + +``altair.typing.Optional`` treats ``None`` like any other type:: + + # ruff: noqa: UP006, UP007 + from altair.typing import Optional + + def func_2( + short: Optional[str | float | dict[str, Any] | None | SchemaBase] = Undefined, + long: Union[ + str, float, Dict[str, Any], None, SchemaBase, UndefinedType + ] = Undefined, + ): ... +""" + + +def is_undefined(obj: Any) -> TypeIs[UndefinedType]: + """ + Type-safe singleton check for `UndefinedType`. + + Notes + ----- + - Using `obj is Undefined` does not narrow from `UndefinedType` in a union. + - Due to the assumption that other `UndefinedType`'s could exist. + - Current [typing spec advises](https://typing.readthedocs.io/en/latest/spec/concepts.html#support-for-singleton-types-in-unions) using an `Enum`. + - Otherwise, requires an explicit guard to inform the type checker. + """ + return obj is Undefined + + +@overload +def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ... +@overload +def _shallow_copy(obj: Any) -> Any: ... +def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any: + if isinstance(obj, SchemaBase): + return obj.copy(deep=False) + elif isinstance(obj, (list, dict)): + return obj.copy() + else: + return obj + + +@overload +def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ... +@overload +def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... +def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: + copy = partial(_deep_copy, by_ref=by_ref) + if isinstance(obj, SchemaBase): + if copier := getattr(obj, "__deepcopy__", None): + with debug_mode(False): + return copier(obj) + args = (copy(arg) for arg in obj._args) + kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} + with debug_mode(False): + return obj.__class__(*args, **kwds) + elif isinstance(obj, list): + return [copy(v) for v in obj] + elif isinstance(obj, dict): + return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()} + else: + return obj + + +class SchemaBase: + """ + Base class for schema wrappers. 
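+
+    As a purely illustrative sketch (not one of the generated classes), a subclass
+    that sets the ``_schema`` attribute described below could look like::
+
+        class Align(SchemaBase):
+            _schema = {"enum": ["left", "center", "right"]}
+
+        Align("left").to_dict()  # validates against _schema and returns 'left'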
+ + Each derived class should set the _schema class attribute (and optionally + the _rootschema class attribute) which is used for validation. + """ + + _schema: ClassVar[dict[str, Any] | Any] = None + _rootschema: ClassVar[dict[str, Any] | None] = None + _class_is_valid_at_instantiation: ClassVar[bool] = True + + def __init__(self, *args: Any, **kwds: Any) -> None: + # Two valid options for initialization, which should be handled by + # derived classes: + # - a single arg with no kwds, for, e.g. {'type': 'string'} + # - zero args with zero or more kwds for {'type': 'object'} + if self._schema is None: + msg = ( + f"Cannot instantiate object of type {self.__class__}: " + "_schema class attribute is not defined." + "" + ) + raise ValueError(msg) + + if kwds: + assert len(args) == 0 + else: + assert len(args) in {0, 1} + + # use object.__setattr__ because we override setattr below. + object.__setattr__(self, "_args", args) + object.__setattr__(self, "_kwds", kwds) + + if DEBUG_MODE and self._class_is_valid_at_instantiation: + self.to_dict(validate=True) + + def copy( + self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None + ) -> Self: + """ + Return a copy of the object. + + Parameters + ---------- + deep : boolean or list, optional + If True (default) then return a deep copy of all dict, list, and + SchemaBase objects within the object structure. + If False, then only copy the top object. + If a list or iterable, then only copy the listed attributes. + ignore : list, optional + A list of keys for which the contents should not be copied, but + only stored by reference. + """ + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) + with debug_mode(False): + copy = self.__class__(*self._args, **self._kwds) + if _is_iterable(deep): + for attr in deep: + copy[attr] = _shallow_copy(copy._get(attr)) + return copy + + def _get(self, attr, default=Undefined): + """Get an attribute, returning default if not present.""" + attr = self._kwds.get(attr, Undefined) + if attr is Undefined: + attr = default + return attr + + def __getattr__(self, attr): + # reminder: getattr is called after the normal lookups + if attr == "_kwds": + raise AttributeError() + if attr in self._kwds: + return self._kwds[attr] + else: + try: + _getattr = super().__getattr__ # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + _getattr = super().__getattribute__ + return _getattr(attr) + + def __setattr__(self, item, val) -> None: + self._kwds[item] = val + + def __getitem__(self, item): + return self._kwds[item] + + def __setitem__(self, item, val) -> None: + self._kwds[item] = val + + def __repr__(self) -> str: + name = type(self).__name__ + if kwds := self._kwds: + it = (f"{k}: {v!r}" for k, v in sorted(kwds.items()) if v is not Undefined) + args = ",\n".join(it).replace("\n", "\n ") + LB, RB = "{", "}" + return f"{name}({LB}\n {args}\n{RB})" + else: + return f"{name}({self._args[0]!r})" + + def __eq__(self, other: Any) -> bool: + return ( + type(self) is type(other) + and self._args == other._args + and self._kwds == other._kwds + ) + + def to_dict( + self, + validate: bool = True, + *, + ignore: list[str] | None = None, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Return a dictionary representation of the object. + + Parameters + ---------- + validate : bool, optional + If True (default), then validate the result against the schema. + ignore : list[str], optional + A list of keys to ignore. 
+ context : dict[str, Any], optional + A context dictionary. + + Raises + ------ + SchemaValidationError : + If ``validate`` and the result does not conform to the schema. + + Notes + ----- + - ``ignore``, ``context`` are usually not needed to be specified as a user. + - *Technical*: ``ignore`` will **not** be passed to child :meth:`.to_dict()`. + """ + context = context or {} + ignore = ignore or [] + opts = _get_optional_modules(np_opt="numpy", pd_opt="pandas") + + if self._args and not self._kwds: + kwds = self._args[0] + elif not self._args: + kwds = self._kwds.copy() + exclude = {*ignore, "shorthand"} + if parsed := context.pop("parsed_shorthand", None): + kwds = _replace_parsed_shorthand(parsed, kwds) + kwds = {k: v for k, v in kwds.items() if k not in exclude} + if (mark := kwds.get("mark")) and isinstance(mark, str): + kwds["mark"] = {"type": mark} + else: + msg = f"{type(self)} instance has both a value and properties : cannot serialize to dict" + raise ValueError(msg) + result = _todict(kwds, context=context, **opts) + if validate: + # NOTE: Don't raise `from err`, see `SchemaValidationError` doc + try: + self.validate(result) + except jsonschema.ValidationError as err: + raise SchemaValidationError(self, err) from None + return result + + def to_json( + self, + validate: bool = True, + indent: int | str | None = 2, + sort_keys: bool = True, + *, + ignore: list[str] | None = None, + context: dict[str, Any] | None = None, + **kwargs, + ) -> str: + """ + Emit the JSON representation for this object as a string. + + Parameters + ---------- + validate : bool, optional + If True (default), then validate the result against the schema. + indent : int, optional + The number of spaces of indentation to use. The default is 2. + sort_keys : bool, optional + If True (default), sort keys in the output. + ignore : list[str], optional + A list of keys to ignore. + context : dict[str, Any], optional + A context dictionary. + **kwargs + Additional keyword arguments are passed to ``json.dumps()`` + + Raises + ------ + SchemaValidationError : + If ``validate`` and the result does not conform to the schema. + + Notes + ----- + - ``ignore``, ``context`` are usually not needed to be specified as a user. + - *Technical*: ``ignore`` will **not** be passed to child :meth:`.to_dict()`. + """ + if ignore is None: + ignore = [] + if context is None: + context = {} + dct = self.to_dict(validate=validate, ignore=ignore, context=context) + return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs) + + @classmethod + def _default_wrapper_classes(cls) -> Iterator[type[SchemaBase]]: + """Return the set of classes used within cls.from_dict().""" + return _subclasses(SchemaBase) + + @classmethod + def from_dict( + cls: type[TSchemaBase], dct: dict[str, Any], validate: bool = True + ) -> TSchemaBase: + """ + Construct class from a dictionary representation. + + Parameters + ---------- + dct : dictionary + The dict from which to construct the class + validate : boolean + If True (default), then validate the input against the schema. 
+ + Raises + ------ + jsonschema.ValidationError : + If ``validate`` and ``dct`` does not conform to the schema + """ + if validate: + cls.validate(dct) + converter = _FromDict(cls._default_wrapper_classes()) + return converter.from_dict(dct, cls) + + @classmethod + def from_json( + cls, + json_string: str, + validate: bool = True, + **kwargs: Any, + # Type hints for this method would get rather complicated + # if we want to provide a more specific return type + ) -> ChartType: + """ + Instantiate the object from a valid JSON string. + + Parameters + ---------- + json_string : string + The string containing a valid JSON chart specification. + validate : boolean + If True (default), then validate the input against the schema. + **kwargs : + Additional keyword arguments are passed to json.loads + + Returns + ------- + chart : Chart object + The altair Chart object built from the specification. + """ + dct: dict[str, Any] = json.loads(json_string, **kwargs) + return cls.from_dict(dct, validate=validate) # type: ignore[return-value] + + @classmethod + def validate( + cls, instance: dict[str, Any], schema: dict[str, Any] | None = None + ) -> None: + """Validate the instance against the class schema in the context of the rootschema.""" + if schema is None: + schema = cls._schema + # For the benefit of mypy + assert schema is not None + validate_jsonschema(instance, schema, rootschema=cls._rootschema or cls._schema) + + @classmethod + def resolve_references(cls, schema: dict[str, Any] | None = None) -> dict[str, Any]: + """Resolve references in the context of this object's schema or root schema.""" + schema_to_pass = schema or cls._schema + # For the benefit of mypy + assert schema_to_pass is not None + return _resolve_references( + schema=schema_to_pass, + rootschema=(cls._rootschema or cls._schema or schema), + ) + + @classmethod + def validate_property( + cls, name: str, value: Any, schema: dict[str, Any] | None = None + ) -> None: + """Validate a property against property schema in the context of the rootschema.""" + opts = _get_optional_modules(np_opt="numpy", pd_opt="pandas") + value = _todict(value, context={}, **opts) + props = cls.resolve_references(schema or cls._schema).get("properties", {}) + validate_jsonschema( + value, props.get(name, {}), rootschema=cls._rootschema or cls._schema + ) + + def __dir__(self) -> list[str]: + return sorted(chain(super().__dir__(), self._kwds)) + + +def _get_optional_modules(**modules: str) -> dict[str, _OptionalModule]: + """ + Returns packages only if they have already been imported - otherwise they return `None`. + + This is useful for `isinstance` checks. + + For example, if `pandas` has not been imported, then an object is + definitely not a `pandas.Timestamp`. + + Parameters + ---------- + **modules + Keyword-only binding from `{alias: module_name}`. + + Examples + -------- + >>> import pandas as pd # doctest: +SKIP + >>> import polars as pl # doctest: +SKIP + >>> from altair.utils.schemapi import _get_optional_modules # doctest: +SKIP + >>> + >>> _get_optional_modules(pd="pandas", pl="polars", ibis="ibis") # doctest: +SKIP + { + "pd": , + "pl": , + "ibis": None, + } + + If the user later imports ``ibis``, it would appear in subsequent calls. 
+ + >>> import ibis # doctest: +SKIP + >>> + >>> _get_optional_modules(ibis="ibis") # doctest: +SKIP + { + "ibis": , + } + """ + return {k: sys.modules.get(v) for k, v in modules.items()} + + +def _replace_parsed_shorthand( + parsed_shorthand: dict[str, Any], kwds: dict[str, Any] +) -> dict[str, Any]: + """ + `parsed_shorthand` is added by `FieldChannelMixin`. + + It's used below to replace shorthand with its long form equivalent + `parsed_shorthand` is removed from `context` if it exists so that it is + not passed to child `to_dict` function calls. + """ + # Prevent that pandas categorical data is automatically sorted + # when a non-ordinal data type is specifed manually + # or if the encoding channel does not support sorting + if "sort" in parsed_shorthand and ( + "sort" not in kwds or kwds["type"] not in {"ordinal", Undefined} + ): + parsed_shorthand.pop("sort") + + kwds.update( + (k, v) + for k, v in parsed_shorthand.items() + if kwds.get(k, Undefined) is Undefined + ) + return kwds + + +TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) + +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + + +def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: + return isinstance(obj, dict) + + +def _is_list(obj: Any | list[Any]) -> TypeIs[list[Any]]: + return isinstance(obj, list) + + +def _is_iterable( + obj: Any, *, exclude: type | tuple[type, ...] = (str, bytes) +) -> TypeIs[Iterable[Any]]: + return not isinstance(obj, exclude) and isinstance(obj, Iterable) + + +def _passthrough(*args: Any, **kwds: Any) -> Any | dict[str, Any]: + return args[0] if args else kwds + + +class _FromDict: + """ + Class used to construct SchemaBase class hierarchies from a dict. + + The primary purpose of using this class is to be able to build a hash table + that maps schemas to their wrapper classes. The candidate classes are + specified in the ``wrapper_classes`` positional-only argument to the constructor. + """ + + _hash_exclude_keys = ("definitions", "title", "description", "$schema", "id") + + def __init__(self, wrapper_classes: Iterable[type[SchemaBase]], /) -> None: + # Create a mapping of a schema hash to a list of matching classes + # This lets us quickly determine the correct class to construct + self.class_dict: dict[int, list[type[SchemaBase]]] = defaultdict(list) + for tp in wrapper_classes: + if tp._schema is not None: + self.class_dict[self.hash_schema(tp._schema)].append(tp) + + @classmethod + def hash_schema(cls, schema: dict[str, Any], use_json: bool = True) -> int: + """ + Compute a python hash for a nested dictionary which properly handles dicts, lists, sets, and tuples. + + At the top level, the function excludes from the hashed schema all keys + listed in `exclude_keys`. + + This implements two methods: one based on conversion to JSON, and one based + on recursive conversions of unhashable to hashable types; the former seems + to be slightly faster in several benchmarks. 
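+
+        A rough illustration of the key exclusion (top-level `title`/`description`
+        do not affect the hash)::
+
+            d1 = {"type": "string", "title": "T", "description": "ignored"}
+            d2 = {"type": "string"}
+            _FromDict.hash_schema(d1) == _FromDict.hash_schema(d2)  # -> True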
+ """ + if cls._hash_exclude_keys and isinstance(schema, dict): + schema = { + key: val + for key, val in schema.items() + if key not in cls._hash_exclude_keys + } + if use_json: + s = json.dumps(schema, sort_keys=True) + return hash(s) + else: + + def _freeze(val): + if isinstance(val, dict): + return frozenset((k, _freeze(v)) for k, v in val.items()) + elif isinstance(val, set): + return frozenset(map(_freeze, val)) + elif isinstance(val, (list, tuple)): + return tuple(map(_freeze, val)) + else: + return val + + return hash(_freeze(schema)) + + @overload + def from_dict( + self, + dct: TSchemaBase, + tp: None = ..., + schema: None = ..., + rootschema: None = ..., + default_class: Any = ..., + ) -> TSchemaBase: ... + @overload + def from_dict( + self, + dct: dict[str, Any] | list[dict[str, Any]], + tp: Any = ..., + schema: Any = ..., + rootschema: Any = ..., + default_class: type[TSchemaBase] = ..., # pyright: ignore[reportInvalidTypeVarUse] + ) -> TSchemaBase: ... + @overload + def from_dict( + self, + dct: dict[str, Any], + tp: None = ..., + schema: dict[str, Any] = ..., + rootschema: None = ..., + default_class: Any = ..., + ) -> SchemaBase: ... + @overload + def from_dict( + self, + dct: dict[str, Any], + tp: type[TSchemaBase], + schema: None = ..., + rootschema: None = ..., + default_class: Any = ..., + ) -> TSchemaBase: ... + @overload + def from_dict( + self, + dct: dict[str, Any] | list[dict[str, Any]], + tp: type[TSchemaBase], + schema: dict[str, Any], + rootschema: dict[str, Any] | None = ..., + default_class: Any = ..., + ) -> Never: ... + def from_dict( + self, + dct: dict[str, Any] | list[dict[str, Any]] | TSchemaBase, + tp: type[TSchemaBase] | None = None, + schema: dict[str, Any] | None = None, + rootschema: dict[str, Any] | None = None, + default_class: Any = _passthrough, + ) -> TSchemaBase | SchemaBase: + """Construct an object from a dict representation.""" + target_tp: Any + current_schema: dict[str, Any] + if isinstance(dct, SchemaBase): + return dct + elif tp is not None: + current_schema = tp._schema + root_schema: dict[str, Any] = rootschema or tp._rootschema or current_schema + target_tp = tp + elif schema is not None: + # If there are multiple matches, we use the first one in the dict. + # Our class dict is constructed breadth-first from top to bottom, + # so the first class that matches is the most general match. + current_schema = schema + root_schema = rootschema or current_schema + matches = self.class_dict[self.hash_schema(current_schema)] + target_tp = matches[0] if matches else default_class + else: + msg = "Must provide either `tp` or `schema`, but not both." + raise ValueError(msg) + + from_dict = partial(self.from_dict, rootschema=root_schema) + # Can also return a list? 
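+        # The anyOf/oneOf handling below tries each candidate subschema in turn and
+        # recurses into the first one that `dct` validates against.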
+ resolved = _resolve_references(current_schema, root_schema) + if "anyOf" in resolved or "oneOf" in resolved: + schemas = resolved.get("anyOf", []) + resolved.get("oneOf", []) + for possible in schemas: + try: + validate_jsonschema(dct, possible, rootschema=root_schema) + except jsonschema.ValidationError: + continue + else: + return from_dict(dct, schema=possible, default_class=target_tp) + + if _is_dict(dct): + # TODO: handle schemas for additionalProperties/patternProperties + props: dict[str, Any] = resolved.get("properties", {}) + kwds = { + k: (from_dict(v, schema=props[k]) if k in props else v) + for k, v in dct.items() + } + return target_tp(**kwds) + elif _is_list(dct): + item_schema: dict[str, Any] = resolved.get("items", {}) + return target_tp([from_dict(k, schema=item_schema) for k in dct]) + else: + # NOTE: Unsure what is valid here + return target_tp(dct) + + +class _PropertySetter: + def __init__(self, prop: str, schema: dict[str, Any]) -> None: + self.prop = prop + self.schema = schema + + def __get__(self, obj, cls): + self.obj = obj + self.cls = cls + # The docs from the encoding class parameter (e.g. `bin` in X, Color, + # etc); this provides a general description of the parameter. + self.__doc__ = self.schema["description"].replace("__", "**") + property_name = f"{self.prop}"[0].upper() + f"{self.prop}"[1:] + if hasattr(vegalite, property_name): + altair_prop = getattr(vegalite, property_name) + # Add the docstring from the helper class (e.g. `BinParams`) so + # that all the parameter names of the helper class are included in + # the final docstring + parameter_index = altair_prop.__doc__.find("Parameters\n") + if parameter_index > -1: + self.__doc__ = ( + altair_prop.__doc__[:parameter_index].replace(" ", "") + + self.__doc__ + + textwrap.dedent( + f"\n\n {altair_prop.__doc__[parameter_index:]}" + ) + ) + # For short docstrings such as Aggregate, Stack, et + else: + self.__doc__ = ( + altair_prop.__doc__.replace(" ", "") + "\n" + self.__doc__ + ) + # Add signatures and tab completion for the method and parameter names + self.__signature__ = inspect.signature(altair_prop) + self.__wrapped__ = inspect.getfullargspec(altair_prop) + self.__name__ = altair_prop.__name__ + else: + # It seems like bandPosition is the only parameter that doesn't + # have a helper class. 
+ pass + return self + + def __call__(self, *args: Any, **kwargs: Any): + obj = self.obj.copy() + # TODO: use schema to validate + obj[self.prop] = args[0] if args else kwargs + return obj + + +def with_property_setters(cls: type[TSchemaBase]) -> type[TSchemaBase]: + """Decorator to add property setters to a Schema class.""" + schema = cls.resolve_references() + for prop, propschema in schema.get("properties", {}).items(): + setattr(cls, prop, _PropertySetter(prop, propschema)) + return cls diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py new file mode 100755 index 00000000..3fa30492 --- /dev/null +++ b/tools/schemapi/utils.py @@ -0,0 +1,902 @@ +"""Utilities for working with schemas.""" + +from __future__ import annotations + +import keyword +import re +import subprocess +import textwrap +import urllib +from html import unescape +from itertools import chain +from operator import itemgetter +from typing import ( + TYPE_CHECKING, + Any, + Final, + Iterable, + Iterator, + Literal, + Sequence, + overload, +) + +import mistune +from mistune.renderers.rst import RSTRenderer as _RSTRenderer + +from tools.schemapi.schemapi import _resolve_references as resolve_references + +if TYPE_CHECKING: + from pathlib import Path + from typing_extensions import LiteralString + + from mistune import BlockState + +EXCLUDE_KEYS: Final = ("definitions", "title", "description", "$schema", "id") + +jsonschema_to_python_types = { + "string": "str", + "number": "float", + "integer": "int", + "object": "Map", + "boolean": "bool", + "array": "list", + "null": "None", +} + + +class _TypeAliasTracer: + """ + Recording all `enum` -> `Literal` translations. + + Rewrites as `TypeAlias` to be reused anywhere, and not clog up method definitions. + + Parameters + ---------- + fmt + A format specifier to produce the `TypeAlias` name. + + Will be provided a `SchemaInfo.title` as a single positional argument. + *ruff_check + Optional [ruff rule codes](https://docs.astral.sh/ruff/rules/), + each prefixed with `--select ` and follow a `ruff check --fix ` call. + + If not provided, uses `[tool.ruff.lint.select]` from `pyproject.toml`. + ruff_format + Optional argument list supplied to [ruff format](https://docs.astral.sh/ruff/formatter/#ruff-format) + + Attributes + ---------- + _literals: dict[str, str] + `{alias_name: literal_statement}` + _literals_invert: dict[str, str] + `{literal_statement: alias_name}` + aliases: list[tuple[str, str]] + `_literals` sorted by `alias_name` + _imports: Sequence[str] + Prefined import statements to appear at beginning of module. 
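+
+    For instance, with the default ``fmt="{}_T"``, a schema titled ``Align`` whose
+    enum is ``["left", "center", "right"]`` would (roughly) end up in the generated
+    typing module as::
+
+        Align_T: TypeAlias = Literal["left", "center", "right"]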
+ """ + + def __init__( + self, + fmt: str = "{}_T", + *ruff_check: str, + ruff_format: Sequence[str] | None = None, + ) -> None: + self.fmt: str = fmt + self._literals: dict[str, str] = {} + self._literals_invert: dict[str, str] = {} + self._aliases: dict[str, str] = {} + self._imports: Sequence[str] = ( + "from __future__ import annotations\n", + "from typing import Any, Literal, Mapping, TypeVar, Sequence, Union", + "from typing_extensions import TypeAlias, TypeAliasType", + ) + self._cmd_check: list[str] = ["--fix"] + self._cmd_format: Sequence[str] = ruff_format or () + for c in ruff_check: + self._cmd_check.extend(("--extend-select", c)) + + def _update_literals(self, name: str, tp: str, /) -> None: + """Produces an inverted index, to reuse a `Literal` when `SchemaInfo.title` is empty.""" + self._literals[name] = tp + self._literals_invert[tp] = name + + def add_literal( + self, info: SchemaInfo, tp: str, /, *, replace: bool = False + ) -> str: + """ + `replace=True` returns the eventual alias name. + + - Doing so will mean that the `_typing` module must be written first, before the source of `info`. + - Otherwise, `ruff` will raise an error during `check`/`format`, as the import will be invalid. + - Where a `title` is not found, an attempt will be made to find an existing alias definition that had one. + """ + if info.title: + alias = self.fmt.format(info.title) + if alias not in self._literals: + self._update_literals(alias, tp) + if replace: + tp = alias + elif (alias := self._literals_invert.get(tp)) and replace: + tp = alias + elif replace and info.is_union_literal(): + # Handles one very specific edge case `WindowFieldDef` + # - Has an anonymous enum union + # - One of the members is declared afterwards + # - SchemaBase needs to be first, as the union wont be internally sorted + it = ( + self.add_literal(el, spell_literal(el.literal), replace=True) + for el in info.anyOf + ) + tp = f"Union[SchemaBase, {', '.join(it)}]" + return tp + + def update_aliases(self, *name_statement: tuple[str, str]) -> None: + """ + Adds `(name, statement)` pairs to the definitions. + + These types should support annotations in generated code, but + are not required to be derived from the schema itself. + + Each tuple will appear in the generated module as:: + + name: TypeAlias = statement + + All aliases will be written in runtime-scope, therefore + externally dependent types should be declared as regular imports. + """ + self._aliases.update(name_statement) + + def generate_aliases(self) -> Iterator[str]: + """Represents a line per `TypeAlias` declaration.""" + for name, statement in self._aliases.items(): + yield f"{name}: TypeAlias = {statement}" + + def is_cached(self, tp: str, /) -> bool: + """ + Applies to both docstring and type hints. + + Currently used as a sort key, to place literals/aliases last. + """ + return tp in self._literals_invert or tp in self._literals or tp in self._aliases # fmt: skip + + def write_module( + self, fp: Path, *extra_all: str, header: LiteralString, extra: LiteralString + ) -> None: + """ + Write all collected `TypeAlias`'s to `fp`. + + Parameters + ---------- + fp + Path to new module. + *extra_all + Any manually spelled types to be exported. + header + `tools.generate_schema_wrapper.HEADER`. + extra + `tools.generate_schema_wrapper.TYPING_EXTRA`. 
+ """ + ruff_format = ["ruff", "format", fp] + if self._cmd_format: + ruff_format.extend(self._cmd_format) + commands = (["ruff", "check", fp, *self._cmd_check], ruff_format) + static = (header, "\n", *self._imports, "\n\n") + self.update_aliases(*sorted(self._literals.items(), key=itemgetter(0))) + all_ = [*iter(self._aliases), *extra_all] + it = chain( + static, + [f"__all__ = {all_}", "\n\n", extra], + self.generate_aliases(), + ) + fp.write_text("\n".join(it), encoding="utf-8") + for cmd in commands: + r = subprocess.run(cmd, check=True) + r.check_returncode() + + @property + def n_entries(self) -> int: + """Number of unique `TypeAlias` defintions collected.""" + return len(self._literals) + + +TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T", "I001", "I002") +"""An instance of `_TypeAliasTracer`. + +Collects a cache of unique `Literal` types used globally. + +These are then converted to `TypeAlias` statements, written to another module. + +Allows for a single definition to be reused multiple times, +rather than repeating long literals in every method definition. +""" + + +def get_valid_identifier( + prop: str, + replacement_character: str = "", + allow_unicode: bool = False, + url_decode: bool = True, +) -> str: + """ + Given a string property, generate a valid Python identifier. + + Parameters + ---------- + prop: string + Name of property to decode. + replacement_character: string, default '' + The character to replace invalid characters with. + allow_unicode: boolean, default False + If True, then allow Python 3-style unicode identifiers. + url_decode: boolean, default True + If True, decode URL characters in identifier names. + + Examples + -------- + >>> get_valid_identifier("my-var") + 'myvar' + + >>> get_valid_identifier("if") + 'if_' + + >>> get_valid_identifier("$schema", "_") + '_schema' + + >>> get_valid_identifier("$*#$") + '_' + + >>> get_valid_identifier("Name%3Cstring%3E") + 'Namestring' + """ + # Decode URL characters. + if url_decode: + prop = urllib.parse.unquote(prop) + + # Deal with [] + prop = prop.replace("[]", "Array") + + # First substitute-out all non-valid characters. + flags = re.UNICODE if allow_unicode else re.ASCII + valid = re.sub(r"\W", replacement_character, prop, flags=flags) + + # If nothing is left, use just an underscore + if not valid: + valid = "_" + + # first character must be a non-digit. Prefix with an underscore + # if needed + if re.match(r"^[\d\W]", valid): + valid = "_" + valid + + # if the result is a reserved keyword, then add an underscore at the end + if keyword.iskeyword(valid): + valid += "_" + return valid + + +def is_valid_identifier(var: str, allow_unicode: bool = False): + """ + Return true if var contains a valid Python identifier. + + Parameters + ---------- + val : string + identifier to check + allow_unicode : bool (default: False) + if True, then allow Python 3 style unicode identifiers. 
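+
+    Examples
+    --------
+    Illustrative calls (the return value is truthy/falsy rather than strictly boolean)::
+
+        is_valid_identifier("my_var")  # -> truthy
+        is_valid_identifier("if")      # -> False (reserved keyword)
+        is_valid_identifier("my-var")  # -> falsy (no match)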
+ """ + flags = re.UNICODE if allow_unicode else re.ASCII + is_valid = re.match(r"^[^\d\W]\w*\Z", var, flags) + return is_valid and not keyword.iskeyword(var) + + +class SchemaProperties: + """A wrapper for properties within a schema.""" + + def __init__( + self, + properties: dict[str, Any], + schema: dict, + rootschema: dict | None = None, + ) -> None: + self._properties = properties + self._schema = schema + self._rootschema = rootschema or schema + + def __bool__(self) -> bool: + return bool(self._properties) + + def __dir__(self) -> list[str]: + return list(self._properties.keys()) + + def __getattr__(self, attr): + try: + return self[attr] + except KeyError: + return super().__getattr__(attr) + + def __getitem__(self, attr): + dct = self._properties[attr] + if "definitions" in self._schema and "definitions" not in dct: + dct = dict(definitions=self._schema["definitions"], **dct) + return SchemaInfo(dct, self._rootschema) + + def __iter__(self): + return iter(self._properties) + + def items(self): + return ((key, self[key]) for key in self) + + def keys(self): + return self._properties.keys() + + def values(self): + return (self[key] for key in self) + + +class SchemaInfo: + """A wrapper for inspecting a JSON schema.""" + + def __init__( + self, schema: dict[str, Any], rootschema: dict[str, Any] | None = None + ) -> None: + if not rootschema: + rootschema = schema + self.raw_schema = schema + self.rootschema = rootschema + self.schema = resolve_references(schema, rootschema) + + def child(self, schema: dict) -> SchemaInfo: + return self.__class__(schema, rootschema=self.rootschema) + + def __repr__(self) -> str: + keys = [] + for key in sorted(self.schema.keys()): + val = self.schema[key] + rval = repr(val).replace("\n", "") + if len(rval) > 30: + rval = rval[:30] + "..." + if key == "definitions": + rval = "{...}" + elif key == "properties": + rval = "{\n " + "\n ".join(sorted(map(repr, val))) + "\n }" + keys.append(f'"{key}": {rval}') + return "SchemaInfo({\n " + "\n ".join(keys) + "\n})" + + @property + def title(self) -> str: + if self.is_reference(): + return get_valid_identifier(self.refname) + else: + return "" + + @overload + def get_python_type_representation( + self, + for_type_hints: bool = ..., + return_as_str: Literal[True] = ..., + additional_type_hints: list[str] | None = ..., + ) -> str: ... + @overload + def get_python_type_representation( + self, + for_type_hints: bool = ..., + return_as_str: Literal[False] = ..., + additional_type_hints: list[str] | None = ..., + ) -> list[str]: ... + def get_python_type_representation( # noqa: C901 + self, + for_type_hints: bool = False, + return_as_str: bool = True, + additional_type_hints: list[str] | None = None, + ) -> str | list[str]: + type_representations: list[str] = [] + """ + All types which can be used for the current `SchemaInfo`. + Including `altair` classes, standard `python` types, etc. + """ + + if self.title: + if for_type_hints: + # To keep type hints simple, we only use the SchemaBase class + # as the type hint for all classes which inherit from it. + class_names = ["SchemaBase"] + if self.title in {"ExprRef", "ParameterExtent"}: + class_names.append("Parameter") + # In these cases, a value parameter is also always accepted. + # It would be quite complex to further differentiate + # between a value and a selection parameter based on + # the type system (one could + # try to check for the type of the Parameter.param attribute + # but then we would need to write some overload signatures for + # api.param). 
+ + type_representations.extend(class_names) + else: + # use RST syntax for generated sphinx docs + type_representations.append(rst_syntax_for_class(self.title)) + + if self.is_empty(): + type_representations.append("Any") + elif self.is_literal(): + tp_str = spell_literal(self.literal) + if for_type_hints: + tp_str = TypeAliasTracer.add_literal(self, tp_str, replace=True) + type_representations.append(tp_str) + elif for_type_hints and self.is_union_literal(): + it = chain.from_iterable(el.literal for el in self.anyOf) + tp_str = TypeAliasTracer.add_literal(self, spell_literal(it), replace=True) + type_representations.append(tp_str) + elif self.is_anyOf(): + it = ( + s.get_python_type_representation( + for_type_hints=for_type_hints, return_as_str=False + ) + for s in self.anyOf + ) + type_representations.extend(maybe_rewrap_literal(chain.from_iterable(it))) + elif isinstance(self.type, list): + options = [] + subschema = SchemaInfo(dict(**self.schema)) + for typ_ in self.type: + subschema.schema["type"] = typ_ + # We always use title if possible for nested objects + options.append( + subschema.get_python_type_representation( + for_type_hints=for_type_hints + ) + ) + type_representations.extend(options) + elif self.is_object() and not for_type_hints: + type_representations.append("dict") + elif self.is_array(): + # A list is invariant in its type parameter. This means that e.g. + # List[str] is not a subtype of List[Union[core.FieldName, str]] + # and hence we would need to explicitly write out the combinations, + # so in this case: + # List[core.FieldName], List[str], List[core.FieldName, str] + # However, this can easily explode to too many combinations. + # Furthermore, we would also need to add additional entries + # for e.g. int wherever a float is accepted which would lead to very + # long code. + # As suggested in the mypy docs, + # https://mypy.readthedocs.io/en/stable/common_issues.html#variance, + # we revert to using Sequence which works as well for lists and also + # includes tuples which are also supported by the SchemaBase.to_dict + # method. However, it is not entirely accurate as some sequences + # such as e.g. a range are not supported by SchemaBase.to_dict but + # this tradeoff seems worth it. + s = self.child(self.items).get_python_type_representation( + for_type_hints=for_type_hints + ) + type_representations.append(f"Sequence[{s}]") + elif self.type in jsonschema_to_python_types: + type_representations.append(jsonschema_to_python_types[self.type]) + else: + msg = "No Python type representation available for this schema" + raise ValueError(msg) + + # Shorter types are usually the more relevant ones, e.g. `str` instead + # of `SchemaBase`. Output order from set is non-deterministic -> If + # types have same length names, order would be non-deterministic as it is + # returned from sort. Hence, we sort as well by type name as a tie-breaker, + # see https://docs.python.org/3.10/howto/sorting.html#sort-stability-and-complex-sorts + # for more infos. 
+ # Using lower as we don't want to prefer uppercase such as "None" over + it = sorted(set(flatten(type_representations)), key=str.lower) # Tertiary sort + it = sorted(it, key=len) # Secondary sort + type_representations = sorted(it, key=TypeAliasTracer.is_cached) # Primary sort + if additional_type_hints: + type_representations.extend(additional_type_hints) + + if return_as_str: + type_representations_str = ", ".join(type_representations) + # If it's not for_type_hints but instead for the docstrings, we don't want + # to include Union as it just clutters the docstrings. + if len(type_representations) > 1 and for_type_hints: + # Use parameterised `TypeAlias` instead of exposing `UndefinedType` + # `Union` is collapsed by `ruff` later + if type_representations_str.endswith(", UndefinedType"): + s = type_representations_str.replace(", UndefinedType", "") + s = f"Optional[Union[{s}]]" + else: + s = f"Union[{type_representations_str}]" + return s + return type_representations_str + else: + return type_representations + + @property + def properties(self) -> SchemaProperties: + return SchemaProperties( + self.schema.get("properties", {}), self.schema, self.rootschema + ) + + @property + def definitions(self) -> SchemaProperties: + return SchemaProperties( + self.schema.get("definitions", {}), self.schema, self.rootschema + ) + + @property + def required(self) -> list: + return self.schema.get("required", []) + + @property + def patternProperties(self) -> dict: + return self.schema.get("patternProperties", {}) + + @property + def additionalProperties(self) -> bool: + return self.schema.get("additionalProperties", True) + + @property + def type(self) -> str | list[Any] | None: + return self.schema.get("type", None) + + @property + def anyOf(self) -> list[SchemaInfo]: + return [self.child(s) for s in self.schema.get("anyOf", [])] + + @property + def oneOf(self) -> list[SchemaInfo]: + return [self.child(s) for s in self.schema.get("oneOf", [])] + + @property + def allOf(self) -> list[SchemaInfo]: + return [self.child(s) for s in self.schema.get("allOf", [])] + + @property + def not_(self) -> SchemaInfo: + return self.child(self.schema.get("not", {})) + + @property + def items(self) -> dict: + return self.schema.get("items", {}) + + @property + def enum(self) -> list[str]: + return self.schema.get("enum", []) + + @property + def const(self) -> str: + return self.schema.get("const", "") + + @property + def literal(self) -> list[str]: + return self.schema.get("enum", [self.const]) + + @property + def refname(self) -> str: + return self.raw_schema.get("$ref", "#/").split("/")[-1] + + @property + def ref(self) -> str | None: + return self.raw_schema.get("$ref", None) + + @property + def description(self) -> str: + return self._get_description(include_sublevels=False) + + @property + def deep_description(self) -> str: + return self._get_description(include_sublevels=True) + + def _get_description(self, include_sublevels: bool = False) -> str: + desc = self.raw_schema.get("description", self.schema.get("description", "")) + if not desc and include_sublevels: + for item in self.anyOf: + sub_desc = item._get_description(include_sublevels=False) + if desc and sub_desc: + raise ValueError( + "There are multiple potential descriptions which could" + + " be used for the currently inspected schema. 
You'll need to" + + " clarify which one is the correct one.\n" + + str(self.schema) + ) + if sub_desc: + desc = sub_desc + return desc + + def is_reference(self) -> bool: + return "$ref" in self.raw_schema + + def is_enum(self) -> bool: + return "enum" in self.schema + + def is_const(self) -> bool: + return "const" in self.schema + + def is_literal(self) -> bool: + return not ({"enum", "const"}.isdisjoint(self.schema)) + + def is_empty(self) -> bool: + return not (set(self.schema.keys()) - set(EXCLUDE_KEYS)) + + def is_compound(self) -> bool: + return any(key in self.schema for key in ["anyOf", "allOf", "oneOf"]) + + def is_anyOf(self) -> bool: + return "anyOf" in self.schema + + def is_allOf(self) -> bool: + return "allOf" in self.schema + + def is_oneOf(self) -> bool: + return "oneOf" in self.schema + + def is_not(self) -> bool: + return "not" in self.schema + + def is_object(self) -> bool: + if self.type == "object": + return True + elif self.type is not None: + return False + elif ( + self.properties + or self.required + or self.patternProperties + or self.additionalProperties + ): + return True + else: + msg = "Unclear whether schema.is_object() is True" + raise ValueError(msg) + + def is_value(self) -> bool: + return not self.is_object() + + def is_array(self) -> bool: + return self.type == "array" + + def is_union(self) -> bool: + """ + Candidate for ``Union`` type alias. + + Not a real class. + """ + return self.is_anyOf() and self.type is None + + def is_union_literal(self) -> bool: + """ + Candidate for reducing to a single ``Literal`` alias. + + E.g. `BinnedTimeUnit` + """ + return self.is_union() and all(el.is_literal() for el in self.anyOf) + + +class RSTRenderer(_RSTRenderer): + def __init__(self) -> None: + super().__init__() + + def inline_html(self, token: dict[str, Any], state: BlockState) -> str: + html = token["raw"] + return rf"\ :raw-html:`{html}`\ " + + +class RSTParse(mistune.Markdown): + def __init__( + self, + renderer: mistune.BaseRenderer, + block: mistune.BlockParser | None = None, + inline: mistune.InlineParser | None = None, + plugins=None, + ) -> None: + super().__init__(renderer, block, inline, plugins) + + def __call__(self, s: str) -> str: + s = super().__call__(s) + return unescape(s).replace(r"\ ,", ",").replace(r"\ ", " ") + + +rst_parse: RSTParse = RSTParse(RSTRenderer()) + + +def indent_docstring( # noqa: C901 + lines: list[str], indent_level: int, width: int = 100, lstrip=True +) -> str: + """Indent a docstring for use in generated code.""" + final_lines = [] + if len(lines) > 1: + lines += [""] + + for i, line in enumerate(lines): + stripped = line.lstrip() + if stripped: + leading_space = len(line) - len(stripped) + indent = indent_level + leading_space + wrapper = textwrap.TextWrapper( + width=width - indent, + initial_indent=indent * " ", + subsequent_indent=indent * " ", + break_long_words=False, + break_on_hyphens=False, + drop_whitespace=True, + ) + list_wrapper = textwrap.TextWrapper( + width=width - indent, + initial_indent=indent * " " + "* ", + subsequent_indent=indent * " " + " ", + break_long_words=False, + break_on_hyphens=False, + drop_whitespace=True, + ) + for line in stripped.split("\n"): + line_stripped = line.lstrip() + line_stripped = fix_docstring_issues(line_stripped) + if line_stripped == "": + final_lines.append("") + elif line_stripped.startswith("* "): + final_lines.extend(list_wrapper.wrap(line_stripped[2:])) + # Matches lines where an attribute is mentioned followed by the accepted + # types (lines starting with a 
character sequence that + # does not contain white spaces or '*' followed by ' : '). + # It therefore matches 'condition : anyOf(...' but not '**Notes** : ...' + # These lines should not be wrapped at all but appear on one line + elif re.match(r"[^\s*]+ : ", line_stripped): + final_lines.append(indent * " " + line_stripped) + else: + final_lines.extend(wrapper.wrap(line_stripped)) + + # If this is the last line, put in an indent + elif i + 1 == len(lines): + final_lines.append(indent_level * " ") + # If it's not the last line, this is a blank line that should not indent. + else: + final_lines.append("") + # Remove any trailing whitespaces on the right side + stripped_lines = [] + for i, line in enumerate(final_lines): + if i + 1 == len(final_lines): + stripped_lines.append(line) + else: + stripped_lines.append(line.rstrip()) + # Join it all together + wrapped = "\n".join(stripped_lines) + if lstrip: + wrapped = wrapped.lstrip() + return wrapped + + +def fix_docstring_issues(docstring: str) -> str: + # All lists should start with '*' followed by a whitespace. Fixes the ones + # which either do not have a whitespace or/and start with '-' by first replacing + # "-" with "*" and then adding a whitespace where necessary + docstring = re.sub( + r"^-(?=[ `\"a-z])", + "*", + docstring, + flags=re.MULTILINE, + ) + # Now add a whitespace where an asterisk is followed by one of the characters + # in the square brackets of the regex pattern + docstring = re.sub( + r"^\*(?=[`\"a-z])", + "* ", + docstring, + flags=re.MULTILINE, + ) + + # Links to the vega-lite documentation cannot be relative but instead need to + # contain the full URL. + docstring = docstring.replace( + "types#datetime", "https://vega.github.io/vega-lite/docs/datetime.html" + ) + return docstring + + +def rst_syntax_for_class(class_name: str) -> str: + return f":class:`{class_name}`" + + +def flatten(container: Iterable) -> Iterable: + """ + Flatten arbitrarily flattened list. + + From https://stackoverflow.com/a/10824420 + """ + for i in container: + if isinstance(i, (list, tuple)): + yield from flatten(i) + else: + yield i + + +def spell_literal(it: Iterable[str], /, *, quote: bool = True) -> str: + """ + Combine individual ``str`` type reprs into a single ``Literal``. + + Parameters + ---------- + it + Type representations. + quote + Call ``repr()`` on each element in ``it``. + + .. note:: + Set to ``False`` if performing a second pass. + """ + it_el: Iterable[str] = (f"{s!r}" for s in it) if quote else it + return f"Literal[{', '.join(it_el)}]" + + +def maybe_rewrap_literal(it: Iterable[str], /) -> Iterator[str]: + """ + Where `it` may contain one or more `"enum"`, `"const"`, flatten to a single `Literal[...]`. + + All other type representations are yielded unchanged. + """ + seen: set[str] = set() + for s in it: + if s.startswith("Literal["): + seen.add(unwrap_literal(s)) + else: + yield s + if seen: + yield spell_literal(sorted(seen), quote=False) + + +def unwrap_literal(tp: str, /) -> str: + """`"Literal['value']"` -> `"value"`.""" + return re.sub(r"Literal\[(.+)\]", r"\g<1>", tp) + + +def ruff_format_str(code: str | list[str]) -> str: + if isinstance(code, list): + code = "\n".join(code) + + r = subprocess.run( + # Name of the file does not seem to matter but ruff requires one + ["ruff", "format", "--stdin-filename", "placeholder.py"], + input=code.encode(), + check=True, + capture_output=True, + ) + return r.stdout.decode() + + +def ruff_format_py(fp: Path, /, *extra_args: str) -> None: + """ + Format an existing file. 
+ + Running on `win32` after writing lines will ensure "lf" is used before: + ```bash + ruff format --diff --check . + ``` + """ + cmd = ["ruff", "format", fp] + if extra_args: + cmd.extend(extra_args) + r = subprocess.run(cmd, check=True) + r.check_returncode() + + +def ruff_write_lint_format_str( + fp: Path, code: str | Iterable[str], /, *, encoding: str = "utf-8" +) -> None: + """ + Combined steps of writing, `ruff check`, `ruff format`. + + Notes + ----- + - `fp` is written to first, as the size before formatting will be the smallest + - Better utilizes `ruff` performance, rather than `python` str and io + - `code` is no longer bound to `list` + - Encoding set as default + - `I001/2` are `isort` rules, to sort imports. + """ + commands = ( + ["ruff", "check", fp, "--fix"], + ["ruff", "check", fp, "--fix", "--select", "I001", "--select", "I002"], + ) + if not isinstance(code, str): + code = "\n".join(code) + fp.write_text(code, encoding=encoding) + for cmd in commands: + r = subprocess.run(cmd, check=True) + r.check_returncode() + ruff_format_py(fp) diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 00000000..d041a952 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,11 @@ +from generated_classes import AggregateExpression, AggregateTransform + +def test_aggregate_expression(): + pass + +def test_aggregate_transform(): + pass +if __name__ == "__main__": + test_aggregate_expression() + test_aggregate_transform() + print("All tests passed!") diff --git a/tools/testingSchema.json b/tools/testingSchema.json new file mode 100644 index 00000000..822e7c1c --- /dev/null +++ b/tools/testingSchema.json @@ -0,0 +1,259 @@ +{ + "$ref": "#/definitions/Spec", + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "AggregateExpression": { + "additionalProperties": false, + "description": "A custom SQL aggregate expression.", + "properties": { + "agg": { + "description": "A SQL expression string to calculate an aggregate value. Embedded Param references, such as `SUM($param + 1)`, are supported. For expressions without aggregate functions, use *sql* instead.", + "type": "string" + }, + "label": { + "description": "A label for this expression, for example to label a plot axis.", + "type": "string" + } + }, + "required": ["agg"], + "type": "object" + }, + "AggregateTransform": { + "anyOf": [ + { + "$ref": "#/definitions/Argmax" + }, + { + "$ref": "#/definitions/Argmin" + }, + { + "$ref": "#/definitions/Avg" + }, + { + "$ref": "#/definitions/Count" + }, + { + "$ref": "#/definitions/Max" + }, + { + "$ref": "#/definitions/Min" + }, + { + "$ref": "#/definitions/First" + }, + { + "$ref": "#/definitions/Last" + }, + { + "$ref": "#/definitions/Median" + }, + { + "$ref": "#/definitions/Mode" + }, + { + "$ref": "#/definitions/Product" + }, + { + "$ref": "#/definitions/Quantile" + }, + { + "$ref": "#/definitions/Stddev" + }, + { + "$ref": "#/definitions/StddevPop" + }, + { + "$ref": "#/definitions/Sum" + }, + { + "$ref": "#/definitions/Variance" + }, + { + "$ref": "#/definitions/VarPop" + } + ], + "description": "An aggregate transform that combines multiple values." 
+ }, + "Argmax": { + "additionalProperties": false, + "properties": { + "argmax": { + "description": "Find a value of the first column that maximizes the second column.", + "items": { + "description": "A transform argument.", + "type": [ + "string", + "number", + "boolean" + ] + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }, + "distinct": { + "type": "boolean" + }, + "orderby": { + "anyOf": [ + { + "$ref": "#/definitions/TransformField" + }, + { + "items": { + "$ref": "#/definitions/TransformField" + }, + "type": "array" + } + ] + }, + "partitionby": { + "anyOf": [ + { + "$ref": "#/definitions/TransformField" + }, + { + "items": { + "$ref": "#/definitions/TransformField" + }, + "type": "array" + } + ] + }, + "range": { + "anyOf": [ + { + "items": { + "type": [ + "number", + "null" + ] + }, + "type": "array" + }, + { + "$ref": "#/definitions/ParamRef" + } + ] + }, + "rows": { + "anyOf": [ + { + "items": { + "type": [ + "number", + "null" + ] + }, + "type": "array" + }, + { + "$ref": "#/definitions/ParamRef" + } + ] + } + }, + "required": [ + "argmax" + ], + "type": "object" + }, + "Argmin": { + "additionalProperties": false, + "properties": { + "argmin": { + "description": "Find a value of the first column that minimizes the second column.", + "items": { + "description": "A transform argument.", + "type": [ + "string", + "number", + "boolean" + ] + }, + "maxItems": 2, + "minItems": 2, + "type": "array" + }, + "distinct": { + "type": "boolean" + }, + "orderby": { + "anyOf": [ + { + "$ref": "#/definitions/TransformField" + }, + { + "items": { + "$ref": "#/definitions/TransformField" + }, + "type": "array" + } + ] + }, + "partitionby": { + "anyOf": [ + { + "$ref": "#/definitions/TransformField" + }, + { + "items": { + "$ref": "#/definitions/TransformField" + }, + "type": "array" + } + ] + }, + "range": { + "anyOf": [ + { + "items": { + "type": [ + "number", + "null" + ] + }, + "type": "array" + }, + { + "$ref": "#/definitions/ParamRef" + } + ] + }, + "rows": { + "anyOf": [ + { + "items": { + "type": [ + "number", + "null" + ] + }, + "type": "array" + }, + { + "$ref": "#/definitions/ParamRef" + } + ] + } + }, + "required": [ + "argmin" + ], + "type": "object" + }, "TransformField": { + "anyOf": [ + { + "type": "string" + }, + { + "$ref": "#/definitions/ParamRef" + } + ], + "description": "A field argument to a data transform." + }, + "ParamRef": { + "type": "string" + } + } +}
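
Note on `tools/test.py` above: both test functions are empty stubs, so the script reports "All tests passed!" without exercising anything. A minimal sketch of how the stubs could be filled in is shown below. It assumes the classes generated from `tools/testingSchema.json` expose constructor arguments mirroring the schema properties (`AggregateExpression(agg, label=None)` storing them as attributes, and `AggregateTransform(value)` wrapping one of its `anyOf` variants); those signatures are an assumption about the generator output, not something this diff guarantees.

```python
# Hedged sketch only: assumes generated_classes provides constructors that
# mirror the JSON schema, i.e. AggregateExpression(agg, label=None) and
# AggregateTransform(value). Adjust to the actual generated signatures.
from generated_classes import AggregateExpression, AggregateTransform


def test_aggregate_expression():
    # "agg" is required by the schema; "label" is optional.
    expr = AggregateExpression(agg="SUM($param + 1)", label="sum")
    assert expr.agg == "SUM($param + 1)"
    assert expr.label == "sum"


def test_aggregate_transform():
    # AggregateTransform is an anyOf union; the generated wrapper is assumed
    # here to store the chosen variant on a single `value` attribute.
    expr = AggregateExpression(agg="count(*)")
    transform = AggregateTransform(expr)
    assert transform.value is expr


if __name__ == "__main__":
    test_aggregate_expression()
    test_aggregate_transform()
    print("All tests passed!")
```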