diff --git a/edgedb/abstract.py b/edgedb/abstract.py index fe89f1a9..44dc4498 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -39,11 +39,13 @@ class QueryContext(typing.NamedTuple): cache: QueryCache query_options: QueryOptions retry_options: typing.Optional[options.RetryOptions] + session: typing.Optional[options.Session] class ScriptContext(typing.NamedTuple): query: QueryWithArgs cache: QueryCache + session: typing.Optional[options.Session] _query_opts = QueryOptions( @@ -88,6 +90,9 @@ def _get_query_cache(self) -> QueryCache: def _get_retry_options(self) -> typing.Optional[options.RetryOptions]: return None + def _get_session(self) -> options.Session: + ... + class ReadOnlyExecutor(BaseReadOnlyExecutor): """Subclasses can execute *at least* read-only queries""" @@ -104,6 +109,7 @@ def query(self, query: str, *args, **kwargs) -> datatypes.Set: cache=self._get_query_cache(), query_options=_query_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) def query_single( @@ -114,6 +120,7 @@ def query_single( cache=self._get_query_cache(), query_options=_query_single_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: @@ -122,6 +129,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: cache=self._get_query_cache(), query_options=_query_required_single_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) def query_json(self, query: str, *args, **kwargs) -> str: @@ -130,6 +138,7 @@ def query_json(self, query: str, *args, **kwargs) -> str: cache=self._get_query_cache(), query_options=_query_json_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) def query_single_json(self, query: str, *args, **kwargs) -> str: @@ -138,6 +147,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str: cache=self._get_query_cache(), query_options=_query_single_json_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) def query_required_single_json(self, query: str, *args, **kwargs) -> str: @@ -146,6 +156,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str: cache=self._get_query_cache(), query_options=_query_required_single_json_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) @abc.abstractmethod @@ -156,6 +167,7 @@ def execute(self, query: str, *args, **kwargs) -> None: self._execute(ScriptContext( query=QueryWithArgs(query, args, kwargs), cache=self._get_query_cache(), + session=self._get_session(), )) @@ -180,6 +192,7 @@ async def query(self, query: str, *args, **kwargs) -> datatypes.Set: cache=self._get_query_cache(), query_options=_query_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) async def query_single(self, query: str, *args, **kwargs) -> typing.Any: @@ -188,6 +201,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any: cache=self._get_query_cache(), query_options=_query_single_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) async def query_required_single( @@ -201,6 +215,7 @@ async def query_required_single( cache=self._get_query_cache(), query_options=_query_required_single_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) async def query_json(self, query: str, *args, **kwargs) -> str: @@ -209,6 +224,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str: cache=self._get_query_cache(), query_options=_query_json_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) async def query_single_json(self, query: str, *args, **kwargs) -> str: @@ -217,6 +233,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str: cache=self._get_query_cache(), query_options=_query_single_json_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) async def query_required_single_json( @@ -230,6 +247,7 @@ async def query_required_single_json( cache=self._get_query_cache(), query_options=_query_required_single_json_opts, retry_options=self._get_retry_options(), + session=self._get_session(), )) @abc.abstractmethod @@ -240,6 +258,7 @@ async def execute(self, query: str, *args, **kwargs) -> None: await self._execute(ScriptContext( query=QueryWithArgs(query, args, kwargs), cache=self._get_query_cache(), + session=self._get_session(), )) diff --git a/edgedb/base_client.py b/edgedb/base_client.py index 4562f623..c66ff631 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -211,8 +211,11 @@ async def raw_query(self, query_context: abstract.QueryContext): await self.connect(single_attempt=True) if self._protocol.is_legacy: execute = self._protocol.legacy_execute_anonymous + allow_capabilities = enums.Capability.LEGACY_EXECUTE else: execute = self._protocol.query + self._protocol.set_state(query_context.session) + allow_capabilities = enums.Capability.EXECUTE return await execute( query=query_context.query.query, args=query_context.query.args, @@ -222,7 +225,7 @@ async def raw_query(self, query_context: abstract.QueryContext): output_format=query_context.query_options.output_format, expect_one=query_context.query_options.expect_one, required_one=query_context.query_options.required_one, - allow_capabilities=enums.Capability.EXECUTE, + allow_capabilities=allow_capabilities, ) except errors.EdgeDBError as e: if query_context.retry_options is None: @@ -261,9 +264,10 @@ async def _execute(self, script: abstract.ScriptContext) -> None: "Legacy protocol doesn't support arguments in execute()" ) await self._protocol.legacy_simple_query( - script.query.query, enums.Capability.EXECUTE + script.query.query, enums.Capability.LEGACY_EXECUTE ) else: + self._protocol.set_state(script.session) await self._protocol.execute( query=script.query.query, args=script.query.args, @@ -697,6 +701,9 @@ def _get_query_cache(self) -> abstract.QueryCache: def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]: return self._options.retry_options + def _get_session(self) -> _options.Session: + return self._options.session + @property def max_concurrency(self) -> int: """Max number of connections in the pool.""" diff --git a/edgedb/enums.py b/edgedb/enums.py index 96859bde..f3306192 100644 --- a/edgedb/enums.py +++ b/edgedb/enums.py @@ -10,8 +10,9 @@ class Capability(enum.IntFlag): DDL = 1 << 3 # noqa PERSISTENT_CONFIG = 1 << 4 # noqa - ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa - EXECUTE = ALL & ~TRANSACTION # noqa + ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa + EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa + LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa class CompilationFlag(enum.IntFlag): diff --git a/edgedb/errors/__init__.py b/edgedb/errors/__init__.py index f2bd92ee..87117443 100644 --- a/edgedb/errors/__init__.py +++ b/edgedb/errors/__init__.py @@ -23,6 +23,7 @@ 'UnexpectedMessageError', 'InputDataError', 'ParameterTypeMismatchError', + 'StateMismatchError', 'ResultCardinalityMismatchError', 'CapabilityError', 'UnsupportedCapabilityError', @@ -146,6 +147,10 @@ class ParameterTypeMismatchError(InputDataError): _code = 0x_03_02_01_00 +class StateMismatchError(InputDataError): + _code = 0x_03_02_02_00 + + class ResultCardinalityMismatchError(ProtocolError): _code = 0x_03_03_00_00 @@ -487,3 +492,4 @@ class NoDataError(ClientError): class InternalClientError(ClientError): _code = 0x_FF_04_00_00 + diff --git a/edgedb/options.py b/edgedb/options.py index f1e16924..0b0df1aa 100644 --- a/edgedb/options.py +++ b/edgedb/options.py @@ -1,6 +1,7 @@ import abc import enum import random +import typing from collections import namedtuple from . import errors @@ -108,6 +109,82 @@ def __repr__(self): ) +class Session: + __slots__ = ['_module', '_aliases', '_config', '_globals'] + + def __init__( + self, + module: typing.Optional[str] = None, + aliases: typing.Mapping[str, str] = None, + config: typing.Mapping[str, typing.Any] = None, + globals_: typing.Mapping[str, typing.Any] = None, + ): + self._module = module + self._aliases = {} if aliases is None else dict(aliases) + self._config = {} if config is None else dict(config) + self._globals = {} if globals_ is None else dict(globals_) + + @classmethod + def defaults(cls): + return cls() + + def with_aliases(self, module=..., **aliases): + new_aliases = self._aliases.copy() + new_aliases.update(aliases) + return Session( + module=self._module if module is ... else module, + aliases=new_aliases, + config=self._config, + globals_=self._globals, + ) + + def with_config(self, **config): + new_config = self._config.copy() + new_config.update(config) + return Session( + module=self._module, + aliases=self._aliases, + config=new_config, + globals_=self._globals, + ) + + def with_globals(self, **globals_): + new_globals = self._globals.copy() + new_globals.update(globals_) + return Session( + module=self._module, + aliases=self._aliases, + config=self._config, + globals_=new_globals, + ) + + def as_dict(self): + rv = {} + if self._module is not None: + module = rv["module"] = self._module + else: + module = 'default' + if self._aliases: + rv["aliases"] = list(self._aliases.items()) + if self._config: + rv["config"] = self._config + if self._globals: + rv["globals"] = g = {} + for k, v in self._globals.items(): + parts = k.split("::") + if len(parts) == 1: + g[f"{module}::{k}"] = v + elif len(parts) == 2: + mod, glob = parts + mod = self._aliases.get(mod, mod) + g[f"{mod}::{glob}"] = v + else: + raise errors.InvalidArgumentError( + f"Illegal global name: {k}" + ) + return rv + + class _OptionsMixin: def __init__(self, *args, **kwargs): self._options = _Options.defaults() @@ -153,19 +230,47 @@ def with_retry_options(self, options: RetryOptions=None): result._options = self._options.with_retry_options(options) return result + def with_session(self, session: Session): + result = self._shallow_clone() + result._options = self._options.with_session(session) + return result + + def with_aliases(self, module=None, **aliases): + result = self._shallow_clone() + result._options = self._options.with_session( + self._options.session.with_aliases(module=module, **aliases) + ) + return result + + def with_config(self, **config): + result = self._shallow_clone() + result._options = self._options.with_session( + self._options.session.with_config(**config) + ) + return result + + def with_globals(self, **globals_): + result = self._shallow_clone() + result._options = self._options.with_session( + self._options.session.with_globals(**globals_) + ) + return result + class _Options: """Internal class for storing connection options""" - __slots__ = ['_retry_options', '_transaction_options'] + __slots__ = ['_retry_options', '_transaction_options', '_session'] def __init__( self, retry_options: RetryOptions, transaction_options: TransactionOptions, + session: Session, ): self._retry_options = retry_options self._transaction_options = transaction_options + self._session = session @property def retry_options(self): @@ -175,16 +280,29 @@ def retry_options(self): def transaction_options(self): return self._transaction_options + @property + def session(self): + return self._session + def with_retry_options(self, options: RetryOptions): return _Options( options, self._transaction_options, + self._session, ) def with_transaction_options(self, options: TransactionOptions): return _Options( self._retry_options, options, + self._session, + ) + + def with_session(self, session: Session): + return _Options( + self._retry_options, + self._transaction_options, + session, ) @classmethod @@ -192,4 +310,5 @@ def defaults(cls): return cls( RetryOptions.defaults(), TransactionOptions.defaults(), + Session.defaults(), ) diff --git a/edgedb/protocol/codecs/array.pyx b/edgedb/protocol/codecs/array.pyx index e2b9aa88..e8ef3e23 100644 --- a/edgedb/protocol/codecs/array.pyx +++ b/edgedb/protocol/codecs/array.pyx @@ -44,8 +44,12 @@ cdef class BaseArrayCodec(BaseCodec): Py_ssize_t objlen Py_ssize_t i - if not isinstance(self.sub_codec, ScalarCodec): - raise TypeError('only arrays of scalars are supported') + if not isinstance(self.sub_codec, (ScalarCodec, TupleCodec)): + raise TypeError( + 'only arrays of scalars are supported (got type {!r})'.format( + type(self.sub_codec).__name__ + ) + ) if not _is_array_iterable(obj): raise TypeError( diff --git a/edgedb/protocol/codecs/codecs.pyx b/edgedb/protocol/codecs/codecs.pyx index 95c80a67..cb74ee1b 100644 --- a/edgedb/protocol/codecs/codecs.pyx +++ b/edgedb/protocol/codecs/codecs.pyx @@ -44,6 +44,7 @@ DEF CTYPE_TUPLE = 4 DEF CTYPE_NAMEDTUPLE = 5 DEF CTYPE_ARRAY = 6 DEF CTYPE_ENUM = 7 +DEF CTYPE_INPUT_SHAPE = 8 DEF _CODECS_BUILD_CACHE_SIZE = 200 @@ -105,7 +106,7 @@ cdef class CodecsRegistry: if t == CTYPE_SET: frb_read(spec, 2) - elif t == CTYPE_SHAPE: + elif t == CTYPE_SHAPE or t == CTYPE_INPUT_SHAPE: els = hton.unpack_int16(frb_read(spec, 2)) for i in range(els): frb_read(spec, 4) # flags @@ -164,7 +165,7 @@ cdef class CodecsRegistry: sub_codec = codecs_list[pos] res = SetCodec.new(tid, sub_codec) - elif t == CTYPE_SHAPE: + elif t == CTYPE_SHAPE or t == CTYPE_INPUT_SHAPE: els = hton.unpack_int16(frb_read(spec, 2)) codecs = cpython.PyTuple_New(els) names = cpython.PyTuple_New(els) @@ -192,7 +193,9 @@ cdef class CodecsRegistry: cpython.Py_INCREF(cardinality) cpython.PyTuple_SetItem(cards, i, cardinality) - res = ObjectCodec.new(tid, names, flags, cards, codecs) + res = ObjectCodec.new( + tid, names, flags, cards, codecs, t == CTYPE_INPUT_SHAPE + ) elif t == CTYPE_BASE_SCALAR: if tid in self.base_codec_overrides: diff --git a/edgedb/protocol/codecs/object.pxd b/edgedb/protocol/codecs/object.pxd index 0b01f3aa..8524f55c 100644 --- a/edgedb/protocol/codecs/object.pxd +++ b/edgedb/protocol/codecs/object.pxd @@ -19,9 +19,10 @@ @cython.final cdef class ObjectCodec(BaseNamedRecordCodec): + cdef bint is_sparse cdef encode_args(self, WriteBuffer buf, dict obj) @staticmethod cdef BaseCodec new(bytes tid, tuple names, tuple flags, - tuple cards, tuple codecs) + tuple cards, tuple codecs, bint is_sparse) diff --git a/edgedb/protocol/codecs/object.pyx b/edgedb/protocol/codecs/object.pyx index 14931af5..da879ded 100644 --- a/edgedb/protocol/codecs/object.pyx +++ b/edgedb/protocol/codecs/object.pyx @@ -21,7 +21,40 @@ cdef class ObjectCodec(BaseNamedRecordCodec): cdef encode(self, WriteBuffer buf, object obj): - raise NotImplementedError + cdef: + WriteBuffer elem_data + Py_ssize_t objlen = 0 + Py_ssize_t i + BaseCodec sub_codec + descriptor = (self).descriptor + + if not self.is_sparse: + raise NotImplementedError + + elem_data = WriteBuffer.new() + for name, arg in obj.items(): + if arg is not None: + try: + i = descriptor.get_pos(name) + except LookupError: + raise self._make_missing_args_error_message(obj) from None + objlen += 1 + elem_data.write_int32(i) + + sub_codec = (self.fields_codecs[i]) + try: + sub_codec.encode(elem_data, arg) + except (TypeError, ValueError) as e: + value_repr = repr(arg) + if len(value_repr) > 40: + value_repr = value_repr[:40] + '...' + raise errors.InvalidArgumentError( + 'invalid input for session argument ' + f' {name} := {value_repr} ({e})') from e + + buf.write_int32(4 + elem_data.len()) # buffer length + buf.write_int32(objlen) + buf.write_buffer(elem_data) cdef encode_args(self, WriteBuffer buf, dict obj): cdef: @@ -31,6 +64,9 @@ cdef class ObjectCodec(BaseNamedRecordCodec): BaseCodec sub_codec descriptor = (self).descriptor + if self.is_sparse: + raise NotImplementedError + self._check_encoder() objlen = len(obj) @@ -83,15 +119,17 @@ cdef class ObjectCodec(BaseNamedRecordCodec): passed_args = set(args.keys()) missed_args = required_args - passed_args extra_args = passed_args - required_args + required = 'acceptable' if self.is_sparse else 'expected' - error_message = f'expected {required_args} arguments' + error_message = f'{required} {required_args} arguments' passed_args_repr = repr(passed_args) if passed_args else 'nothing' error_message += f', got {passed_args_repr}' - missed_args = set(required_args) - set(passed_args) - if missed_args: - error_message += f', missed {missed_args}' + if not self.is_sparse: + missed_args = set(required_args) - set(passed_args) + if missed_args: + error_message += f', missed {missed_args}' extra_args = set(passed_args) - set(required_args) if extra_args: @@ -110,6 +148,9 @@ cdef class ObjectCodec(BaseNamedRecordCodec): tuple fields_codecs = (self).fields_codecs descriptor = (self).descriptor + if self.is_sparse: + raise NotImplementedError + elem_count = hton.unpack_int32(frb_read(buf, 4)) if elem_count != len(fields_codecs): @@ -140,14 +181,18 @@ cdef class ObjectCodec(BaseNamedRecordCodec): @staticmethod cdef BaseCodec new(bytes tid, tuple names, tuple flags, tuple cards, - tuple codecs): + tuple codecs, bint is_sparse): cdef: ObjectCodec codec codec = ObjectCodec.__new__(ObjectCodec) codec.tid = tid - codec.name = 'Object' + if is_sparse: + codec.name = 'SparseObject' + else: + codec.name = 'Object' + codec.is_sparse = is_sparse codec.descriptor = datatypes.record_desc_new(names, flags, cards) codec.fields_codecs = codecs diff --git a/edgedb/protocol/protocol.pxd b/edgedb/protocol/protocol.pxd index 69ed875e..282caef9 100644 --- a/edgedb/protocol/protocol.pxd +++ b/edgedb/protocol/protocol.pxd @@ -103,6 +103,11 @@ cdef class SansIOProtocol: readonly bint is_legacy + bytes state_type_id + BaseCodec state_codec + WriteBuffer state + object user_state + cdef encode_args(self, BaseCodec in_dc, WriteBuffer buf, args, kwargs) cdef parse_data_messages(self, BaseCodec out_dc, result) diff --git a/edgedb/protocol/protocol.pyx b/edgedb/protocol/protocol.pyx index 8adb833b..1f217a0b 100644 --- a/edgedb/protocol/protocol.pyx +++ b/edgedb/protocol/protocol.pyx @@ -150,6 +150,11 @@ cdef class SansIOProtocol: self.reset_status() self.protocol_version = (PROTO_VER_MAJOR, 0) + self.state_type_id = NULL_CODEC_ID + self.state_codec = None + self.state = None + self.user_state = None + cdef reset_status(self): self.last_status = None self.last_details = None @@ -193,6 +198,21 @@ cdef class SansIOProtocol: self.buffer.read_len_prefixed_bytes() # value num_fields -= 1 + def set_state(self, user_state): + cdef WriteBuffer buf = WriteBuffer.new() + if self.user_state is user_state: + return + if user_state is None: + self.state = None + else: + # Apply async state_description for AsyncClient + while self.buffer.take_message(): + self.fallthrough() + + self.state_codec.encode(buf, user_state.as_dict()) + self.state = buf + self.user_state = user_state + cdef ensure_connected(self): if self.cancelled: raise errors.ClientConnectionClosedError( @@ -358,11 +378,18 @@ cdef class SansIOProtocol: buf.write_bytes(in_dc.get_tid()) buf.write_bytes(out_dc.get_tid()) + buf.write_bytes( + NULL_CODEC_ID if self.state is None else self.state_type_id + ) if not isinstance(in_dc, NullCodec): self.encode_args(in_dc, buf, args, kwargs) else: buf.write_bytes(EMPTY_NULL_DATA) + if self.state is not None: + buf.write_buffer(self.state) + else: + buf.write_bytes(EMPTY_NULL_DATA) buf.end_message() @@ -982,6 +1009,16 @@ cdef class SansIOProtocol: data = buf.read_len_prefixed_bytes() self.server_settings[name] = self.parse_system_config(codec, data) + elif name == 'state_description': + self.state_type_id = typedesc_id = val[:16] + typedesc = val[16 + 4:] + + if self.internal_reg.has_codec(typedesc_id): + self.state_codec = self.internal_reg.get_codec(typedesc_id) + else: + self.state_codec = self.internal_reg.build_codec( + typedesc, self.protocol_version + ) else: self.server_settings[name] = val @@ -1149,6 +1186,8 @@ cdef class SansIOProtocol: self.ignore_headers() self.last_capabilities = enums.Capability(self.buffer.read_int64()) self.last_status = self.buffer.read_len_prefixed_bytes() + self.buffer.read_bytes(16) # state type id + self.buffer.read_len_prefixed_bytes() # state self.buffer.finish_message() cdef parse_sync_message(self): diff --git a/edgedb/transaction.py b/edgedb/transaction.py index aeffa889..a305440e 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -21,6 +21,7 @@ from . import abstract from . import errors +from . import options class TransactionState(enum.Enum): @@ -178,6 +179,9 @@ async def _exit(self, extype, ex): def _get_query_cache(self) -> abstract.QueryCache: return self._client._get_query_cache() + def _get_session(self) -> options.Session: + return self._client._get_session() + async def _query(self, query_context: abstract.QueryContext): await self._ensure_transaction() return await self._connection.raw_query(query_context) @@ -190,6 +194,7 @@ async def _privileged_execute(self, query: str) -> None: await self._connection.privileged_execute(abstract.ScriptContext( query=abstract.QueryWithArgs(query, (), {}), cache=self._get_query_cache(), + session=self._get_session(), )) diff --git a/tests/test_async_query.py b/tests/test_async_query.py index c5680694..c865fca2 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -917,6 +917,7 @@ async def test_json_elements(self): required_one=False, ), retry_options=None, + session=None, ) ) self.assertEqual( diff --git a/tests/test_proto.py b/tests/test_proto.py index 702cbede..aa66b1a7 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -47,6 +47,7 @@ def test_json_elements(self): required_one=False, ), retry_options=None, + session=None, ) ) )