Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add with_globals() and friends through state over the protocol #315

Merged
merged 11 commits into from
Jun 27, 2022
19 changes: 19 additions & 0 deletions edgedb/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ class QueryContext(typing.NamedTuple):
cache: QueryCache
query_options: QueryOptions
retry_options: typing.Optional[options.RetryOptions]
session: typing.Optional[options.Session]


class ScriptContext(typing.NamedTuple):
query: QueryWithArgs
cache: QueryCache
session: typing.Optional[options.Session]


_query_opts = QueryOptions(
Expand Down Expand Up @@ -88,6 +90,9 @@ def _get_query_cache(self) -> QueryCache:
def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
return None

def _get_session(self) -> options.Session:
...


class ReadOnlyExecutor(BaseReadOnlyExecutor):
"""Subclasses can execute *at least* read-only queries"""
Expand All @@ -104,6 +109,7 @@ def query(self, query: str, *args, **kwargs) -> datatypes.Set:
cache=self._get_query_cache(),
query_options=_query_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

def query_single(
Expand All @@ -114,6 +120,7 @@ def query_single(
cache=self._get_query_cache(),
query_options=_query_single_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -122,6 +129,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
cache=self._get_query_cache(),
query_options=_query_required_single_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -130,6 +138,7 @@ def query_json(self, query: str, *args, **kwargs) -> str:
cache=self._get_query_cache(),
query_options=_query_json_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -138,6 +147,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str:
cache=self._get_query_cache(),
query_options=_query_single_json_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

def query_required_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -146,6 +156,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str:
cache=self._get_query_cache(),
query_options=_query_required_single_json_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

@abc.abstractmethod
Expand All @@ -156,6 +167,7 @@ def execute(self, query: str, *args, **kwargs) -> None:
self._execute(ScriptContext(
query=QueryWithArgs(query, args, kwargs),
cache=self._get_query_cache(),
session=self._get_session(),
))


Expand All @@ -180,6 +192,7 @@ async def query(self, query: str, *args, **kwargs) -> datatypes.Set:
cache=self._get_query_cache(),
query_options=_query_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -188,6 +201,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
cache=self._get_query_cache(),
query_options=_query_single_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

async def query_required_single(
Expand All @@ -201,6 +215,7 @@ async def query_required_single(
cache=self._get_query_cache(),
query_options=_query_required_single_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

async def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -209,6 +224,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str:
cache=self._get_query_cache(),
query_options=_query_json_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

async def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -217,6 +233,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str:
cache=self._get_query_cache(),
query_options=_query_single_json_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

async def query_required_single_json(
Expand All @@ -230,6 +247,7 @@ async def query_required_single_json(
cache=self._get_query_cache(),
query_options=_query_required_single_json_opts,
retry_options=self._get_retry_options(),
session=self._get_session(),
))

@abc.abstractmethod
Expand All @@ -240,6 +258,7 @@ async def execute(self, query: str, *args, **kwargs) -> None:
await self._execute(ScriptContext(
query=QueryWithArgs(query, args, kwargs),
cache=self._get_query_cache(),
session=self._get_session(),
))


Expand Down
11 changes: 9 additions & 2 deletions edgedb/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,11 @@ async def raw_query(self, query_context: abstract.QueryContext):
await self.connect(single_attempt=True)
if self._protocol.is_legacy:
execute = self._protocol.legacy_execute_anonymous
allow_capabilities = enums.Capability.LEGACY_EXECUTE
else:
execute = self._protocol.query
self._protocol.set_state(query_context.session)
allow_capabilities = enums.Capability.EXECUTE
return await execute(
query=query_context.query.query,
args=query_context.query.args,
Expand All @@ -222,7 +225,7 @@ async def raw_query(self, query_context: abstract.QueryContext):
output_format=query_context.query_options.output_format,
expect_one=query_context.query_options.expect_one,
required_one=query_context.query_options.required_one,
allow_capabilities=enums.Capability.EXECUTE,
allow_capabilities=allow_capabilities,
)
except errors.EdgeDBError as e:
if query_context.retry_options is None:
Expand Down Expand Up @@ -261,9 +264,10 @@ async def _execute(self, script: abstract.ScriptContext) -> None:
"Legacy protocol doesn't support arguments in execute()"
)
await self._protocol.legacy_simple_query(
script.query.query, enums.Capability.EXECUTE
script.query.query, enums.Capability.LEGACY_EXECUTE
)
else:
self._protocol.set_state(script.session)
await self._protocol.execute(
query=script.query.query,
args=script.query.args,
Expand Down Expand Up @@ -697,6 +701,9 @@ def _get_query_cache(self) -> abstract.QueryCache:
def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]:
return self._options.retry_options

def _get_session(self) -> _options.Session:
return self._options.session

@property
def max_concurrency(self) -> int:
"""Max number of connections in the pool."""
Expand Down
5 changes: 3 additions & 2 deletions edgedb/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ class Capability(enum.IntFlag):
DDL = 1 << 3 # noqa
PERSISTENT_CONFIG = 1 << 4 # noqa

ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa
EXECUTE = ALL & ~TRANSACTION # noqa
ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa
EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa
LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa


class CompilationFlag(enum.IntFlag):
Expand Down
6 changes: 6 additions & 0 deletions edgedb/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
'UnexpectedMessageError',
'InputDataError',
'ParameterTypeMismatchError',
'StateMismatchError',
'ResultCardinalityMismatchError',
'CapabilityError',
'UnsupportedCapabilityError',
Expand Down Expand Up @@ -146,6 +147,10 @@ class ParameterTypeMismatchError(InputDataError):
_code = 0x_03_02_01_00


class StateMismatchError(InputDataError):
_code = 0x_03_02_02_00


class ResultCardinalityMismatchError(ProtocolError):
_code = 0x_03_03_00_00

Expand Down Expand Up @@ -487,3 +492,4 @@ class NoDataError(ClientError):

class InternalClientError(ClientError):
_code = 0x_FF_04_00_00

121 changes: 120 additions & 1 deletion edgedb/options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import enum
import random
import typing
from collections import namedtuple

from . import errors
Expand Down Expand Up @@ -108,6 +109,82 @@ def __repr__(self):
)


class Session:
__slots__ = ['_module', '_aliases', '_config', '_globals']

def __init__(
self,
module: typing.Optional[str] = None,
aliases: typing.Mapping[str, str] = None,
config: typing.Mapping[str, typing.Any] = None,
globals_: typing.Mapping[str, typing.Any] = None,
):
self._module = module
self._aliases = {} if aliases is None else dict(aliases)
self._config = {} if config is None else dict(config)
self._globals = {} if globals_ is None else dict(globals_)

@classmethod
def defaults(cls):
return cls()

def with_aliases(self, module=..., **aliases):
new_aliases = self._aliases.copy()
new_aliases.update(aliases)
return Session(
module=self._module if module is ... else module,
aliases=new_aliases,
config=self._config,
globals_=self._globals,
)

def with_config(self, **config):
new_config = self._config.copy()
new_config.update(config)
return Session(
module=self._module,
aliases=self._aliases,
config=new_config,
globals_=self._globals,
)

def with_globals(self, **globals_):
new_globals = self._globals.copy()
new_globals.update(globals_)
return Session(
module=self._module,
aliases=self._aliases,
config=self._config,
globals_=new_globals,
)

def as_dict(self):
rv = {}
if self._module is not None:
module = rv["module"] = self._module
else:
module = 'default'
if self._aliases:
rv["aliases"] = list(self._aliases.items())
if self._config:
rv["config"] = self._config
if self._globals:
rv["globals"] = g = {}
for k, v in self._globals.items():
parts = k.split("::")
if len(parts) == 1:
g[f"{module}::{k}"] = v
elif len(parts) == 2:
mod, glob = parts
mod = self._aliases.get(mod, mod)
g[f"{mod}::{glob}"] = v
else:
raise errors.InvalidArgumentError(
f"Illegal global name: {k}"
)
return rv


class _OptionsMixin:
def __init__(self, *args, **kwargs):
self._options = _Options.defaults()
Expand Down Expand Up @@ -153,19 +230,47 @@ def with_retry_options(self, options: RetryOptions=None):
result._options = self._options.with_retry_options(options)
return result

def with_session(self, session: Session):
result = self._shallow_clone()
result._options = self._options.with_session(session)
return result

def with_aliases(self, module=None, **aliases):
result = self._shallow_clone()
result._options = self._options.with_session(
self._options.session.with_aliases(module=module, **aliases)
)
return result

def with_config(self, **config):
result = self._shallow_clone()
result._options = self._options.with_session(
self._options.session.with_config(**config)
)
return result

def with_globals(self, **globals_):
result = self._shallow_clone()
result._options = self._options.with_session(
self._options.session.with_globals(**globals_)
)
return result


class _Options:
"""Internal class for storing connection options"""

__slots__ = ['_retry_options', '_transaction_options']
__slots__ = ['_retry_options', '_transaction_options', '_session']

def __init__(
self,
retry_options: RetryOptions,
transaction_options: TransactionOptions,
session: Session,
):
self._retry_options = retry_options
self._transaction_options = transaction_options
self._session = session

@property
def retry_options(self):
Expand All @@ -175,21 +280,35 @@ def retry_options(self):
def transaction_options(self):
return self._transaction_options

@property
def session(self):
return self._session

def with_retry_options(self, options: RetryOptions):
return _Options(
options,
self._transaction_options,
self._session,
)

def with_transaction_options(self, options: TransactionOptions):
return _Options(
self._retry_options,
options,
self._session,
)

def with_session(self, session: Session):
return _Options(
self._retry_options,
self._transaction_options,
session,
)

@classmethod
def defaults(cls):
return cls(
RetryOptions.defaults(),
TransactionOptions.defaults(),
Session.defaults(),
)
Loading