diff --git a/edgedb/codegen/generator.py b/edgedb/codegen/generator.py index 626e735e..ba865a43 100644 --- a/edgedb/codegen/generator.py +++ b/edgedb/codegen/generator.py @@ -67,6 +67,7 @@ "cal::relative_duration": "edgedb.RelativeDuration", "cal::date_duration": "edgedb.DateDuration", "cfg::memory": "edgedb.ConfigMemory", + "ext::pgvector::vector": "array.array", } TYPE_IMPORTS = { @@ -77,8 +78,23 @@ "cal::local_date": "datetime", "cal::local_time": "datetime", "cal::local_datetime": "datetime", + "ext::pgvector::vector": "array", } +INPUT_TYPE_MAPPING = TYPE_MAPPING.copy() +INPUT_TYPE_MAPPING.update( + { + "ext::pgvector::vector": "typing.Sequence[float]", + } +) + +INPUT_TYPE_IMPORTS = TYPE_IMPORTS.copy() +INPUT_TYPE_IMPORTS.update( + { + "ext::pgvector::vector": "typing", + } +) + PYDANTIC_MIXIN = """\ class NoPydanticValidation: @classmethod @@ -336,14 +352,18 @@ def _generate( if "".join(dr.input_type.elements.keys()).isdecimal(): for el_name, el in dr.input_type.elements.items(): args[int(el_name)] = self._generate_code_with_cardinality( - el.type, f"arg{el_name}", el.cardinality + el.type, f"arg{el_name}", el.cardinality, is_input=True ) args = {f"arg{i}": v for i, v in sorted(args.items())} else: kw_only = True for el_name, el in dr.input_type.elements.items(): args[el_name] = self._generate_code_with_cardinality( - el.type, el_name, el.cardinality, keyword_argument=True + el.type, + el_name, + el.cardinality, + keyword_argument=True, + is_input=True ) if self._async: @@ -392,22 +412,28 @@ def _generate( return buf.getvalue() def _generate_code( - self, type_: typing.Optional[describe.AnyType], name_hint: str + self, + type_: typing.Optional[describe.AnyType], + name_hint: str, + is_input: bool = False, ) -> str: if type_ is None: return "None" - if type_.desc_id in self._cache: - return self._cache[type_.desc_id] + if (type_.desc_id, is_input) in self._cache: + return self._cache[(type_.desc_id, is_input)] + + imports = INPUT_TYPE_IMPORTS if is_input else TYPE_IMPORTS + mapping = INPUT_TYPE_MAPPING if is_input else TYPE_MAPPING if isinstance(type_, describe.BaseScalarType): - if type_.name in TYPE_IMPORTS: - self._imports.add(TYPE_IMPORTS[type_.name]) - rv = TYPE_MAPPING[type_.name] + if import_str := imports.get(type_.name): + self._imports.add(import_str) + rv = mapping[type_.name] elif isinstance(type_, describe.SequenceType): el_type = self._generate_code( - type_.element_type, f"{name_hint}Item" + type_.element_type, f"{name_hint}Item", is_input ) if SYS_VERSION_INFO >= (3, 9): rv = f"list[{el_type}]" @@ -417,7 +443,7 @@ def _generate_code( elif isinstance(type_, describe.TupleType): elements = ", ".join( - self._generate_code(el_type, f"{name_hint}Item") + self._generate_code(el_type, f"{name_hint}Item", is_input) for el_type in type_.element_types ) if SYS_VERSION_INFO >= (3, 9): @@ -429,9 +455,9 @@ def _generate_code( elif isinstance(type_, describe.ScalarType): rv = self._find_name(type_.name) base_type_name = type_.base_type.name - if base_type_name in TYPE_IMPORTS: - self._imports.add(TYPE_IMPORTS[base_type_name]) - value = TYPE_MAPPING[base_type_name] + if import_str := imports.get(base_type_name): + self._imports.add(import_str) + value = mapping[base_type_name] self._aliases[rv] = f"{rv} = {value}" elif isinstance(type_, describe.ObjectType): @@ -487,7 +513,7 @@ def _generate_code( print(f"class {rv}(typing.NamedTuple):", file=buf) for el_name, el_type in type_.element_types.items(): el_code = self._generate_code( - el_type, f"{rv}{self._snake_to_camel(el_name)}" + el_type, f"{rv}{self._snake_to_camel(el_name)}", is_input ) print(f"{INDENT}{el_name}: {el_code}", file=buf) self._defs[rv] = buf.getvalue().strip() @@ -502,13 +528,13 @@ def _generate_code( self._defs[rv] = buf.getvalue().strip() elif isinstance(type_, describe.RangeType): - value = self._generate_code(type_.value_type, name_hint) + value = self._generate_code(type_.value_type, name_hint, is_input) rv = f"edgedb.Range[{value}]" else: rv = "??" - self._cache[type_.desc_id] = rv + self._cache[(type_.desc_id, is_input)] = rv return rv def _generate_code_with_cardinality( @@ -517,8 +543,9 @@ def _generate_code_with_cardinality( name_hint: str, cardinality: edgedb.Cardinality, keyword_argument: bool = False, + is_input: bool = False, ): - rv = self._generate_code(type_, name_hint) + rv = self._generate_code(type_, name_hint, is_input) if cardinality == edgedb.Cardinality.AT_MOST_ONE: if SYS_VERSION_INFO >= (3, 10): rv = f"{rv} | None" diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert b/tests/codegen/test-project2/generated_async_edgeql.py.assert index 5a314e73..c12668b3 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert @@ -12,6 +12,7 @@ from __future__ import annotations +import array import dataclasses import datetime import edgedb @@ -113,6 +114,8 @@ class MyQueryResult: az: MyScalar | None ba: MyEnum bb: MyEnum | None + bc: array.array + bd: array.array | None @dataclasses.dataclass @@ -217,6 +220,8 @@ async def my_query( av: edgedb.Range[datetime.datetime] | None = None, aw: edgedb.Range[datetime.date], ax: edgedb.Range[datetime.date] | None = None, + bc: typing.Sequence[float], + bd: typing.Sequence[float] | None = None, ) -> MyQueryResult: return await executor.query_single( """\ @@ -278,6 +283,8 @@ async def my_query( az := {}, ba := MyEnum.This, bb := {}, + bc := $bc, + bd := $bd, }\ """, a=a, @@ -330,6 +337,8 @@ async def my_query( av=av, aw=aw, ax=ax, + bc=bc, + bd=bd, ) diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql b/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql index a00f8964..8e5bd35d 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql @@ -56,4 +56,6 @@ select { az := {}, ba := MyEnum.This, bb := {}, + bc := $bc, + bd := $bd, } diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert index f58acdaf..cd7d6bb2 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert @@ -3,6 +3,7 @@ from __future__ import annotations +import array import dataclasses import datetime import edgedb @@ -88,6 +89,8 @@ class MyQueryResult(NoPydanticValidation): az: typing.Optional[MyScalar] ba: MyEnum bb: typing.Optional[MyEnum] + bc: array.array + bd: typing.Optional[array.array] async def my_query( @@ -143,6 +146,8 @@ async def my_query( av: typing.Optional[edgedb.Range[datetime.datetime]] = None, aw: edgedb.Range[datetime.date], ax: typing.Optional[edgedb.Range[datetime.date]] = None, + bc: typing.Sequence[float], + bd: typing.Optional[typing.Sequence[float]] = None, ) -> MyQueryResult: return await executor.query_single( """\ @@ -204,6 +209,8 @@ async def my_query( az := {}, ba := MyEnum.This, bb := {}, + bc := $bc, + bd := $bd, }\ """, a=a, @@ -256,4 +263,6 @@ async def my_query( av=av, aw=aw, ax=ax, + bc=bc, + bd=bd, ) diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert index 0b68b62a..c5a85ea8 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert @@ -3,6 +3,7 @@ from __future__ import annotations +import array import dataclasses import datetime import edgedb @@ -79,6 +80,8 @@ class MyQueryResult: az: typing.Optional[MyScalar] ba: MyEnum bb: typing.Optional[MyEnum] + bc: array.array + bd: typing.Optional[array.array] def my_query( @@ -134,6 +137,8 @@ def my_query( av: typing.Optional[edgedb.Range[datetime.datetime]] = None, aw: edgedb.Range[datetime.date], ax: typing.Optional[edgedb.Range[datetime.date]] = None, + bc: typing.Sequence[float], + bd: typing.Optional[typing.Sequence[float]] = None, ) -> MyQueryResult: return executor.query_single( """\ @@ -195,6 +200,8 @@ def my_query( az := {}, ba := MyEnum.This, bb := {}, + bc := $bc, + bd := $bd, }\ """, a=a, @@ -247,4 +254,6 @@ def my_query( av=av, aw=aw, ax=ax, + bc=bc, + bd=bd, ) diff --git a/tests/test_codegen.py b/tests/test_codegen.py index ae3c0257..9cec27f6 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -28,6 +28,14 @@ class TestCodegen(tb.AsyncQueryTestCase): + SETUP = ''' + create extension pgvector; + ''' + + TEARDOWN = ''' + drop extension pgvector; + ''' + async def test_codegen(self): env = os.environ.copy() env.update( @@ -36,6 +44,7 @@ async def test_codegen(self): for k, v in self.get_connect_args().items() } ) + env["EDGEDB_DATABASE"] = self.get_database_name() container = pathlib.Path(__file__).absolute().parent / "codegen" with tempfile.TemporaryDirectory() as td: td_path = pathlib.Path(td)