Skip to content

Commit

Permalink
Fix codegen for pgvector (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Sep 22, 2023
1 parent a9c639f commit aaf81c1
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 17 deletions.
60 changes: 43 additions & 17 deletions edgedb/codegen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"cal::relative_duration": "edgedb.RelativeDuration",
"cal::date_duration": "edgedb.DateDuration",
"cfg::memory": "edgedb.ConfigMemory",
"ext::pgvector::vector": "array.array",
}

TYPE_IMPORTS = {
Expand All @@ -77,8 +78,23 @@
"cal::local_date": "datetime",
"cal::local_time": "datetime",
"cal::local_datetime": "datetime",
"ext::pgvector::vector": "array",
}

INPUT_TYPE_MAPPING = TYPE_MAPPING.copy()
INPUT_TYPE_MAPPING.update(
{
"ext::pgvector::vector": "typing.Sequence[float]",
}
)

INPUT_TYPE_IMPORTS = TYPE_IMPORTS.copy()
INPUT_TYPE_IMPORTS.update(
{
"ext::pgvector::vector": "typing",
}
)

PYDANTIC_MIXIN = """\
class NoPydanticValidation:
@classmethod
Expand Down Expand Up @@ -336,14 +352,17 @@ def _generate(
if "".join(dr.input_type.elements.keys()).isdecimal():
for el_name, el in dr.input_type.elements.items():
args[int(el_name)] = self._generate_code_with_cardinality(
el.type, f"arg{el_name}", el.cardinality
el.type, f"arg{el_name}", el.cardinality, is_input=True
)
args = {f"arg{i}": v for i, v in sorted(args.items())}
else:
kw_only = True
for el_name, el in dr.input_type.elements.items():
args[el_name] = self._generate_code_with_cardinality(
el.type, el_name, el.cardinality
el.type,
el_name,
el.cardinality,
is_input=True
)

if self._async:
Expand Down Expand Up @@ -392,22 +411,28 @@ def _generate(
return buf.getvalue()

def _generate_code(
self, type_: typing.Optional[describe.AnyType], name_hint: str
self,
type_: typing.Optional[describe.AnyType],
name_hint: str,
is_input: bool = False,
) -> str:
if type_ is None:
return "None"

if type_.desc_id in self._cache:
return self._cache[type_.desc_id]
if (type_.desc_id, is_input) in self._cache:
return self._cache[(type_.desc_id, is_input)]

imports = INPUT_TYPE_IMPORTS if is_input else TYPE_IMPORTS
mapping = INPUT_TYPE_MAPPING if is_input else TYPE_MAPPING

if isinstance(type_, describe.BaseScalarType):
if type_.name in TYPE_IMPORTS:
self._imports.add(TYPE_IMPORTS[type_.name])
rv = TYPE_MAPPING[type_.name]
if import_str := imports.get(type_.name):
self._imports.add(import_str)
rv = mapping[type_.name]

elif isinstance(type_, describe.SequenceType):
el_type = self._generate_code(
type_.element_type, f"{name_hint}Item"
type_.element_type, f"{name_hint}Item", is_input
)
if SYS_VERSION_INFO >= (3, 9):
rv = f"list[{el_type}]"
Expand All @@ -417,7 +442,7 @@ def _generate_code(

elif isinstance(type_, describe.TupleType):
elements = ", ".join(
self._generate_code(el_type, f"{name_hint}Item")
self._generate_code(el_type, f"{name_hint}Item", is_input)
for el_type in type_.element_types
)
if SYS_VERSION_INFO >= (3, 9):
Expand All @@ -429,9 +454,9 @@ def _generate_code(
elif isinstance(type_, describe.ScalarType):
rv = self._find_name(type_.name)
base_type_name = type_.base_type.name
if base_type_name in TYPE_IMPORTS:
self._imports.add(TYPE_IMPORTS[base_type_name])
value = TYPE_MAPPING[base_type_name]
if import_str := imports.get(base_type_name):
self._imports.add(import_str)
value = mapping[base_type_name]
self._aliases[rv] = f"{rv} = {value}"

elif isinstance(type_, describe.ObjectType):
Expand Down Expand Up @@ -491,7 +516,7 @@ def _generate_code(
print(f"class {rv}(typing.NamedTuple):", file=buf)
for el_name, el_type in type_.element_types.items():
el_code = self._generate_code(
el_type, f"{rv}{self._snake_to_camel(el_name)}"
el_type, f"{rv}{self._snake_to_camel(el_name)}", is_input
)
print(f"{INDENT}{el_name}: {el_code}", file=buf)
self._defs[rv] = buf.getvalue().strip()
Expand All @@ -506,22 +531,23 @@ def _generate_code(
self._defs[rv] = buf.getvalue().strip()

elif isinstance(type_, describe.RangeType):
value = self._generate_code(type_.value_type, name_hint)
value = self._generate_code(type_.value_type, name_hint, is_input)
rv = f"edgedb.Range[{value}]"

else:
rv = "??"

self._cache[type_.desc_id] = rv
self._cache[(type_.desc_id, is_input)] = rv
return rv

def _generate_code_with_cardinality(
self,
type_: typing.Optional[describe.AnyType],
name_hint: str,
cardinality: edgedb.Cardinality,
is_input: bool = False,
):
rv = self._generate_code(type_, name_hint)
rv = self._generate_code(type_, name_hint, is_input)
if cardinality == edgedb.Cardinality.AT_MOST_ONE:
if SYS_VERSION_INFO >= (3, 10):
rv = f"{rv} | None"
Expand Down
9 changes: 9 additions & 0 deletions tests/codegen/test-project2/generated_async_edgeql.py.assert
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


from __future__ import annotations
import array
import dataclasses
import datetime
import edgedb
Expand Down Expand Up @@ -113,6 +114,8 @@ class MyQueryResult:
az: MyScalar | None
ba: MyEnum
bb: MyEnum | None
bc: array.array
bd: array.array | None


@dataclasses.dataclass
Expand Down Expand Up @@ -217,6 +220,8 @@ async def my_query(
av: edgedb.Range[datetime.datetime] | None,
aw: edgedb.Range[datetime.date],
ax: edgedb.Range[datetime.date] | None,
bc: typing.Sequence[float],
bd: typing.Sequence[float] | None,
) -> MyQueryResult:
return await executor.query_single(
"""\
Expand Down Expand Up @@ -278,6 +283,8 @@ async def my_query(
az := <optional MyScalar>{},
ba := MyEnum.This,
bb := <optional MyEnum>{},
bc := <ext::pgvector::vector>$bc,
bd := <optional ext::pgvector::vector>$bd,
}\
""",
a=a,
Expand Down Expand Up @@ -330,6 +337,8 @@ async def my_query(
av=av,
aw=aw,
ax=ax,
bc=bc,
bd=bd,
)


Expand Down
2 changes: 2 additions & 0 deletions tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ select {
az := <optional MyScalar>{},
ba := MyEnum.This,
bb := <optional MyEnum>{},
bc := <ext::pgvector::vector>$bc,
bd := <optional ext::pgvector::vector>$bd,
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


from __future__ import annotations
import array
import dataclasses
import datetime
import edgedb
Expand Down Expand Up @@ -88,6 +89,8 @@ class MyQueryResult(NoPydanticValidation):
az: typing.Optional[MyScalar]
ba: MyEnum
bb: typing.Optional[MyEnum]
bc: array.array
bd: typing.Optional[array.array]


async def my_query(
Expand Down Expand Up @@ -143,6 +146,8 @@ async def my_query(
av: typing.Optional[edgedb.Range[datetime.datetime]],
aw: edgedb.Range[datetime.date],
ax: typing.Optional[edgedb.Range[datetime.date]],
bc: typing.Sequence[float],
bd: typing.Optional[typing.Sequence[float]],
) -> MyQueryResult:
return await executor.query_single(
"""\
Expand Down Expand Up @@ -204,6 +209,8 @@ async def my_query(
az := <optional MyScalar>{},
ba := MyEnum.This,
bb := <optional MyEnum>{},
bc := <ext::pgvector::vector>$bc,
bd := <optional ext::pgvector::vector>$bd,
}\
""",
a=a,
Expand Down Expand Up @@ -256,4 +263,6 @@ async def my_query(
av=av,
aw=aw,
ax=ax,
bc=bc,
bd=bd,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


from __future__ import annotations
import array
import dataclasses
import datetime
import edgedb
Expand Down Expand Up @@ -79,6 +80,8 @@ class MyQueryResult:
az: typing.Optional[MyScalar]
ba: MyEnum
bb: typing.Optional[MyEnum]
bc: array.array
bd: typing.Optional[array.array]


def my_query(
Expand Down Expand Up @@ -134,6 +137,8 @@ def my_query(
av: typing.Optional[edgedb.Range[datetime.datetime]],
aw: edgedb.Range[datetime.date],
ax: typing.Optional[edgedb.Range[datetime.date]],
bc: typing.Sequence[float],
bd: typing.Optional[typing.Sequence[float]],
) -> MyQueryResult:
return executor.query_single(
"""\
Expand Down Expand Up @@ -195,6 +200,8 @@ def my_query(
az := <optional MyScalar>{},
ba := MyEnum.This,
bb := <optional MyEnum>{},
bc := <ext::pgvector::vector>$bc,
bd := <optional ext::pgvector::vector>$bd,
}\
""",
a=a,
Expand Down Expand Up @@ -247,4 +254,6 @@ def my_query(
av=av,
aw=aw,
ax=ax,
bc=bc,
bd=bd,
)
9 changes: 9 additions & 0 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@


class TestCodegen(tb.AsyncQueryTestCase):
SETUP = '''
create extension pgvector;
'''

TEARDOWN = '''
drop extension pgvector;
'''

async def test_codegen(self):
env = os.environ.copy()
env.update(
Expand All @@ -36,6 +44,7 @@ async def test_codegen(self):
for k, v in self.get_connect_args().items()
}
)
env["EDGEDB_DATABASE"] = self.get_database_name()
container = pathlib.Path(__file__).absolute().parent / "codegen"
with tempfile.TemporaryDirectory() as td:
td_path = pathlib.Path(td)
Expand Down

0 comments on commit aaf81c1

Please sign in to comment.