Implement codegen using describe() API
fantix committed Sep 7, 2022
1 parent 6245671 commit f0cf862
Showing 2 changed files with 314 additions and 19 deletions.
6 changes: 6 additions & 0 deletions edgedb/codegen/cli.py
@@ -40,6 +40,12 @@
    choices=["default", "strict", "no_host_verification", "insecure"],
)
parser.add_argument("--file", action="store_true")
parser.add_argument(
    "--target",
    choices=["blocking", "async", "pydantic"],
    nargs="*",
    default=["async", "pydantic"],
)


def main():
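
For context, a minimal sketch of how the new --target option in the hunk above behaves; only the argument definition comes from the diff, the parser wiring here is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--target",
    choices=["blocking", "async", "pydantic"],
    nargs="*",
    default=["async", "pydantic"],
)

# nargs="*" allows several code-generation targets in one run:
print(parser.parse_args(["--target", "blocking", "pydantic"]).target)
# -> ['blocking', 'pydantic']
print(parser.parse_args([]).target)
# -> ['async', 'pydantic'] (the default)
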
327 changes: 308 additions & 19 deletions edgedb/codegen/generator.py
@@ -18,14 +18,54 @@

import argparse
import getpass
import io
import os
import pathlib
import sys
import textwrap

import edgedb
from edgedb.con_utils import find_edgedb_project_dir


-FILE_MODE_OUTPUT_FILE = "generated_edgeql.py"
INDENT = "    "
SUFFIXES = [
    ("async", "_async_edgeql.py", True),
    ("blocking", "_edgeql.py", False),
]
FILE_MODE_OUTPUT_FILE = "generated"
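
To make the SUFFIXES table concrete, here is a sketch of how it maps a query file to its generated siblings; the path is hypothetical, and _generate_files further down applies the suffix the same way:

import pathlib

SUFFIXES = [
    ("async", "_async_edgeql.py", True),
    ("blocking", "_edgeql.py", False),
]

source = pathlib.Path("app/queries/get_user.edgeql")  # hypothetical file
for target, suffix, is_async in SUFFIXES:
    print(target, "->", source.parent / f"{source.stem}{suffix}")
# async -> app/queries/get_user_async_edgeql.py
# blocking -> app/queries/get_user_edgeql.py
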

TYPE_MAPPING = {
    "std::str": "str",
    "std::float32": "float",
    "std::float64": "float",
    "std::int16": "int",
    "std::int32": "int",
    "std::int64": "int",
    "std::bigint": "int",
    "std::bool": "bool",
    "std::uuid": "uuid.UUID",
    "std::bytes": "bytes",
    "std::decimal": "decimal.Decimal",
    "std::datetime": "datetime.datetime",
    "std::duration": "datetime.timedelta",
    "cal::local_date": "datetime.date",
    "cal::local_time": "datetime.time",
    "cal::local_datetime": "datetime.datetime",
    "cal::relative_duration": "edgedb.RelativeDuration",
    "cal::date_duration": "edgedb.DateDuration",
    "cfg::memory": "edgedb.ConfigMemory",
}

TYPE_IMPORTS = {
    "std::uuid": "uuid",
    "std::decimal": "decimal",
    "std::datetime": "datetime",
    "std::duration": "datetime",
    "cal::local_date": "datetime",
    "cal::local_time": "datetime",
    "cal::local_datetime": "datetime",
}
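
The two tables are used as a pair: TYPE_MAPPING supplies the Python annotation for an EdgeDB scalar, and TYPE_IMPORTS names the module that annotation requires, if any. A small sketch (annotation_for is a hypothetical helper, assuming the two dicts above are in scope):

def annotation_for(edgedb_type: str):
    # Returns (python_annotation, required_import_or_None).
    return TYPE_MAPPING[edgedb_type], TYPE_IMPORTS.get(edgedb_type)

print(annotation_for("std::datetime"))  # ('datetime.datetime', 'datetime')
print(annotation_for("std::str"))       # ('str', None)
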


def _get_conn_args(args: argparse.Namespace):
@@ -61,6 +101,9 @@ def _get_conn_args(args: argparse.Namespace):

class Generator:
    def __init__(self, args: argparse.Namespace):
        self._default_module = "default"
        self._targets = args.target
        self._async = False
        try:
            self._project_dir = pathlib.Path(find_edgedb_project_dir())
        except edgedb.ClientConnectionError:
@@ -73,7 +116,19 @@ def __init__(self, args: argparse.Namespace):
        self._file_mode = args.file
        self._method_names = set()
        self._describe_results = []
-        self._output = []

        self._cache = {}
        self._imports = set()
        self._aliases = {}
        self._defs = {}
        self._names = set()

    def _new_file(self):
        self._cache.clear()
        self._imports.clear()
        self._aliases.clear()
        self._defs.clear()
        self._names.clear()

    def run(self):
        try:
@@ -83,10 +138,14 @@ def run(self):
            sys.exit(61)
        with self._client:
            self._process_dir(self._project_dir)
-        if self._file_mode:
-            self._generate_single_file()
-        else:
-            self._generate_files()
        for target, suffix, is_async in SUFFIXES:
            if target in self._targets:
                self._async = is_async
                if self._file_mode:
                    self._generate_single_file(suffix)
                else:
                    self._generate_files(suffix)
                self._new_file()

    def _process_dir(self, dir_: pathlib.Path):
        for file_or_dir in dir_.iterdir():
@@ -112,23 +171,253 @@ def _process_file(self, source: pathlib.Path):
                sys.exit(17)
            self._method_names.add(name)
        dr = self._client.describe(query)
-        self._describe_results.append((name, source, dr))
        self._describe_results.append((name, source, query, dr))

-    def _generate_files(self):
-        for name, source, dr in self._describe_results:
-            target = source.with_stem(f"{name}_edgeql").with_suffix(".py")
    def _generate_files(self, suffix: str):
        for name, source, query, dr in self._describe_results:
            target = source.parent / f"{name}{suffix}"
            print(f"Generating {target}", file=sys.stderr)
-            content = self._generate(name, dr)
            self._new_file()
            content = self._generate(name, query, dr)
            buf = io.StringIO()
            self._write_definitions(buf)
            buf.write(content)
            with target.open("w") as f:
-                f.write(content)
                f.write(buf.getvalue())

-    def _generate_single_file(self):
-        target = self._project_dir / FILE_MODE_OUTPUT_FILE
    def _generate_single_file(self, suffix: str):
        target = self._project_dir / f"{FILE_MODE_OUTPUT_FILE}{suffix}"
        print(f"Generating {target}", file=sys.stderr)
-        for name, _, dr in self._describe_results:
-            self._output.append(self._generate(name, dr))
        buf = io.StringIO()
        output = []
        for name, _, query, dr in self._describe_results:
            output.append(self._generate(name, query, dr))
        self._write_definitions(buf)
        buf.write(f"{os.linesep}{os.linesep}".join(output))
        with target.open("w") as f:
-            f.writelines(self._output)
            f.write(buf.getvalue())

    def _write_definitions(self, f: io.TextIOBase):
        print("from __future__ import annotations", file=f)
        for m in sorted(self._imports):
            print(f"import {m}", file=f)
        print(file=f)
        print(file=f)

        if self._aliases:
            for _, a in sorted(self._aliases.items()):
                print(a, file=f)
            print(file=f)
            print(file=f)

        if "pydantic" in self._targets:
            print(
                textwrap.dedent(
                    """
                    class NoPydanticValidation:
                        @classmethod
                        def __get_validators__(cls):
                            from pydantic import dataclasses as dc, json as js
                            js.ENCODERS_BY_TYPE[edgedb.EnumValue] = str
                            dc.dataclass(cls)
                            cls.__pydantic_model__.__get_validators__ = lambda: []
                            return []\
                    """
                ).strip(),
                file=f,
            )
            print(file=f)
            print(file=f)

        for _, d in sorted(self._defs.items()):
            print(d, file=f)
            print(file=f)
            print(file=f)
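
To make the write order concrete: for a run that collected a datetime import, one scalar alias, and one object dataclass under the "pydantic" target, the emitted preamble would be shaped roughly like this (all names hypothetical, shim body elided):

from __future__ import annotations
import dataclasses
import datetime
import edgedb


MyScalar = datetime.datetime  # a collected alias


class NoPydanticValidation:
    ...  # the shim above, emitted only when "pydantic" is targeted


@dataclasses.dataclass
class GetUserResult(NoPydanticValidation):
    name: str
    created_at: datetime.datetime
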

    def _generate(
        self, name: str, query: str, dr: edgedb.DescribeResult
    ) -> str:
        buf = io.StringIO()

        out_type = self._generate_code(
            dr.output_codec, f"{name.title()}Result"
        )
        if dr.output_cardinality == edgedb.Cardinality.MANY:
            self._imports.add("typing")
            out_type = f"typing.Sequence[{out_type}]"

        args = {}
        kw_only = False
        if dr.input_codec.codec_type == edgedb.CodecType.OBJECT:
            if dr.input_codec.sub_types[0].name.isdecimal():
                for sub_type in dr.input_codec.sub_types:
                    args[int(sub_type.name)] = self._generate_code(
                        sub_type.codec, f"arg{sub_type.name}"
                    )
                args = {f"arg{i}": v for i, v in sorted(args.items())}
            else:
                kw_only = True
                for sub_type in dr.input_codec.sub_types:
                    args[sub_type.name] = self._generate_code(
                        sub_type.codec, sub_type.name
                    )

        if self._async:
            print(f"async def {name}(", file=buf)
        else:
            print(f"def {name}(", file=buf)
        self._imports.add("edgedb")
        if self._async:
            print(f"{INDENT}client: edgedb.AsyncIOClient,", file=buf)
        else:
            print(f"{INDENT}client: edgedb.Client,", file=buf)
        if kw_only:
            print(f"{INDENT}*,", file=buf)
        for name, arg in args.items():
            print(f"{INDENT}{name}: {arg},", file=buf)
        print(f") -> {out_type}:", file=buf)
        if dr.output_cardinality == edgedb.Cardinality.MANY:
            method = "query"
            rt = "return "
        elif dr.output_cardinality == edgedb.Cardinality.NO_RESULT:
            method = "execute"
            rt = ""
        else:
            method = "query_single"
            rt = "return "

-    def _generate(self, name: str, dr: edgedb.DescribeResult) -> str:
-        return f"{name}: {dr}\n"
        if self._async:
            print(f"{INDENT}{rt}await client.{method}(", file=buf)
        else:
            print(f"{INDENT}{rt}client.{method}(", file=buf)
        print(f'{INDENT}{INDENT}"""\\', file=buf)
        print(
            textwrap.indent(
                textwrap.dedent(query).strip(), f"{INDENT}{INDENT}"
            )
            + "\\",
            file=buf,
        )
        print(f'{INDENT}{INDENT}""",', file=buf)
        for name in args:
            if kw_only:
                print(f"{INDENT}{INDENT}{name}={name},", file=buf)
            else:
                print(f"{INDENT}{INDENT}{name},", file=buf)
        print(f"{INDENT})", file=buf)
        return buf.getvalue()
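
End to end, a hypothetical queries/get_users.edgeql containing `select User { name } filter .name = <str>$name` (object output, many cardinality, one named parameter, assuming a schema type User) would be rendered by this method roughly as:

async def get_users(
    client: edgedb.AsyncIOClient,
    *,
    name: str,
) -> typing.Sequence[User]:
    return await client.query(
        """\
        select User { name } filter .name = <str>$name\
        """,
        name=name,
    )

The blocking variant differs only in `def`, `edgedb.Client`, and the missing `await`; the User dataclass itself is emitted separately by _write_definitions.
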

    def _generate_code(
        self, codec: edgedb.AbstractCodec, name_hint: str
    ) -> str:
        if codec.type_id in self._cache:
            return self._cache[codec.type_id]

        if codec.codec_type is None:
            rv = "None"

        elif codec.codec_type == edgedb.CodecType.BASE_SCALAR:
            if codec.type_name in TYPE_IMPORTS:
                self._imports.add(TYPE_IMPORTS[codec.type_name])
            rv = TYPE_MAPPING[codec.type_name]

        elif codec.codec_type in [
            edgedb.CodecType.ARRAY,
            edgedb.CodecType.SET,
        ]:
            sub_type = self._generate_code(
                codec.sub_types[0].codec, f"{name_hint}Item"
            )
            self._imports.add("typing")
            rv = f"typing.Sequence[{sub_type}]"

        elif codec.codec_type == edgedb.CodecType.TUPLE:
            content = ", ".join(
                self._generate_code(sub_type.codec, f"{name_hint}Item")
                for sub_type in codec.sub_types
            )
            self._imports.add("typing")
            rv = f"typing.Tuple[{content}]"

        elif codec.codec_type == edgedb.CodecType.SCALAR:
            rv = self._find_name(codec.type_name)
            if codec.base_type_name in TYPE_IMPORTS:
                self._imports.add(TYPE_IMPORTS[codec.base_type_name])
            value = TYPE_MAPPING[codec.base_type_name]
            self._aliases[rv] = f"{rv} = {value}"

        elif codec.codec_type == edgedb.CodecType.OBJECT:
            name = codec.type_name
            if not name or name == "std::FreeObject":
                name = name_hint
            rv = self._find_name(name)
            buf = io.StringIO()
            self._imports.add("dataclasses")
            print("@dataclasses.dataclass", file=buf)
            if "pydantic" in self._targets:
                print(f"class {rv}(NoPydanticValidation):", file=buf)
            else:
                print(f"class {rv}:", file=buf)
            for sub_type in codec.sub_types:
                if sub_type.is_implicit and sub_type.name != "id":
                    continue
                name = f"{rv}{sub_type.name.title()}"
                # if sub_type.cardinality == edgedb.Cardinality.MANY:
                #     name = f"{name}Item"
                #     st_code = self._generate_code(sub_type.codec, name)
                #     self._imports.add("typing")
                #     st_code = f"typing.Sequence[{st_code}]"
                # else:
                st_code = self._generate_code(sub_type.codec, name)
                print(f"{INDENT}{sub_type.name}: {st_code}", file=buf)
            self._defs[rv] = buf.getvalue().strip()

        elif codec.codec_type == edgedb.CodecType.NAMED_TUPLE:
            rv = self._find_name(name_hint)
            buf = io.StringIO()
            self._imports.add("typing")
            print(f"class {rv}(typing.NamedTuple):", file=buf)
            for sub_type in codec.sub_types:
                name = f"{rv}{sub_type.name.title()}"
                st_code = self._generate_code(sub_type.codec, name)
                print(f"{INDENT}{sub_type.name}: {st_code}", file=buf)
            self._defs[rv] = buf.getvalue().strip()

        elif codec.codec_type == edgedb.CodecType.ENUM:
            rv = self._find_name(codec.type_name or name_hint)
            buf = io.StringIO()
            self._imports.add("enum")
            print(f"class {rv}(enum.Enum):", file=buf)
            for label in codec.enum_labels:
                print(f'{INDENT}{label.upper()} = "{label}"', file=buf)
            self._defs[rv] = buf.getvalue().strip()

        else:
            rv = "??"

        self._cache[codec.type_id] = rv
        return rv
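
Two concrete cases for the branches above, assuming hypothetical schema types: a user-defined scalar becomes an alias collected in self._aliases, and an enum becomes a class stored in self._defs:

import enum

# scalar type my_id extending std::str  ->  alias "MyId":
MyId = str

# scalar type Color extending enum<Red, Green, Blue>  ->  definition "Color":
class Color(enum.Enum):
    RED = "Red"
    GREEN = "Green"
    BLUE = "Blue"
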

    def _find_name(self, name: str) -> str:
        default_prefix = f"{self._default_module}::"
        if name.startswith(default_prefix):
            name = name[len(default_prefix) :]
        mod, _, name = name.rpartition("::")
        parts = name.split("_")
        if len(parts) > 1:
            name = "".join(map(str.title, parts))
        name = mod.title() + name
        if name in self._names:
            for i in range(2, 100):
                new = f"{name}{i:02d}"
                if new not in self._names:
                    name = new
                    break
            else:
                print(
                    f"Failed to find a unique name for: {name}",
                    file=sys.stderr,
                )
                sys.exit(17)
        self._names.add(name)
        return name
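
A standalone sketch of this naming scheme (collision suffixing omitted), with a few worked inputs:

def mangle(name: str, default_module: str = "default") -> str:
    # Mirrors _find_name above, minus the collision counter: strip the
    # default module prefix, CamelCase snake_case names, and prepend the
    # title-cased module for non-default modules.
    prefix = f"{default_module}::"
    if name.startswith(prefix):
        name = name[len(prefix):]
    mod, _, name = name.rpartition("::")
    parts = name.split("_")
    if len(parts) > 1:
        name = "".join(map(str.title, parts))
    return mod.title() + name

print(mangle("default::my_type"))  # MyType
print(mangle("default::User"))     # User
print(mangle("other::my_type"))    # OtherMyType
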
