Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multirange support. #452

Merged
merged 4 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion edgedb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory
)
from edgedb.datatypes.datatypes import Set, Object, Array, Link, LinkSet
from edgedb.datatypes.range import Range
from edgedb.datatypes.range import Range, MultiRange

from .abstract import (
Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor,
Expand Down
58 changes: 48 additions & 10 deletions edgedb/datatypes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# limitations under the License.
#

from typing import Generic, Optional, TypeVar

from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator,
Sequence)

T = TypeVar("T")

Expand Down Expand Up @@ -78,22 +78,24 @@ def is_empty(self) -> bool:
def __bool__(self):
return not self.is_empty()

def __eq__(self, other):
if not isinstance(other, Range):
def __eq__(self, other) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should try to make Range and MultiRange compare equal.
In particular, if we are going to do that, we need to have __hash__ be equal when the objects are equal, which is not worth dealing with.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point about hash equality. I'll drop the equality between Range and MultiRange.

if isinstance(other, Range):
o = other
else:
return NotImplemented

return (
self._lower,
self._upper,
self._inc_lower,
self._inc_upper,
self._empty
) == (
other._lower,
other._upper,
other._inc_lower,
other._inc_upper,
self._empty,
) == (
o._lower,
o._upper,
o._inc_lower,
o._inc_upper,
o._empty,
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -125,3 +127,39 @@ def __str__(self) -> str:
return f"<Range {desc}>"

__repr__ = __str__


# TODO: maybe we should implement range and multirange operations as well as
# normalization of the sub-ranges?
class MultiRange(Iterable[T]):
    """An immutable, iterable collection of sub-ranges.

    The sub-ranges are stored as a tuple in the order they were supplied;
    no normalization (merging, sorting, de-duplication) is performed.

    Equality is order- and duplicate-insensitive: two multiranges compare
    equal when they contain the same *set* of sub-ranges.  The hash is
    derived from that same set so that equal multiranges always hash
    equal.  Both operations therefore require hashable elements.
    """

    _ranges: Sequence[T]

    def __init__(self, iterable: Optional[Iterable[T]] = None) -> None:
        """Store the items of *iterable* (or nothing when it is ``None``)."""
        self._ranges = tuple(iterable) if iterable is not None else ()

    def __len__(self) -> int:
        """Return the number of stored sub-ranges (duplicates included)."""
        return len(self._ranges)

    def __iter__(self) -> Iterator[T]:
        """Iterate over the sub-ranges in insertion order."""
        return iter(self._ranges)

    def __reversed__(self) -> Iterator[T]:
        """Iterate over the sub-ranges in reverse insertion order."""
        return reversed(self._ranges)

    def __str__(self) -> str:
        return f'<MultiRange {list(self._ranges)}>'

    __repr__ = __str__

    def __eq__(self, other: Any) -> bool:
        """Compare with another ``MultiRange`` as unordered sets of ranges.

        Returns ``NotImplemented`` for any other type so Python can fall
        back to the reflected comparison.
        """
        if not isinstance(other, MultiRange):
            return NotImplemented
        return set(self._ranges) == set(other._ranges)

    def __hash__(self) -> int:
        # __eq__ ignores order and duplicates, so the hash must too:
        # hashing the raw tuple would give equal multiranges (e.g. [a, b]
        # and [b, a]) different hashes and break their use in sets/dicts.
        return hash(frozenset(self._ranges))
5 changes: 5 additions & 0 deletions edgedb/describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,8 @@ class SparseObjectType(ObjectType):
@dataclasses.dataclass(frozen=True)
class RangeType(AnyType):
    """Immutable type description of a range type."""

    # Description of the type the range is built over.
    value_type: AnyType


@dataclasses.dataclass(frozen=True)
class MultiRangeType(AnyType):
    """Immutable type description of a multirange type."""

    # Description of the type the multirange's sub-ranges are built over.
    value_type: AnyType
3 changes: 2 additions & 1 deletion edgedb/protocol/codecs/array.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ cdef class BaseArrayCodec(BaseCodec):

if not isinstance(
self.sub_codec,
(ScalarCodec, TupleCodec, NamedTupleCodec, RangeCodec, EnumCodec)
(ScalarCodec, TupleCodec, NamedTupleCodec, EnumCodec,
RangeCodec, MultiRangeCodec)
):
raise TypeError(
'only arrays of scalars are supported (got type {!r})'.format(
Expand Down
2 changes: 1 addition & 1 deletion edgedb/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ cdef class BaseRecordCodec(BaseCodec):
if not isinstance(
codec,
(ScalarCodec, ArrayCodec, TupleCodec, NamedTupleCodec,
EnumCodec, RangeCodec),
EnumCodec, RangeCodec, MultiRangeCodec),
):
self.encoder_flags |= RECORD_ENCODER_INVALID
break
Expand Down
22 changes: 22 additions & 0 deletions edgedb/protocol/codecs/codecs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ DEF CTYPE_INPUT_SHAPE = 8
DEF CTYPE_RANGE = 9
DEF CTYPE_OBJECT = 10
DEF CTYPE_COMPOUND = 11
DEF CTYPE_MULTIRANGE = 12
DEF CTYPE_ANNO_TYPENAME = 255

DEF _CODECS_BUILD_CACHE_SIZE = 200
Expand Down Expand Up @@ -165,6 +166,9 @@ cdef class CodecsRegistry:
elif t == CTYPE_RANGE:
frb_read(spec, 2)

elif t == CTYPE_MULTIRANGE:
frb_read(spec, 2)

elif t == CTYPE_ENUM:
els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for i in range(els):
Expand Down Expand Up @@ -444,6 +448,24 @@ cdef class CodecsRegistry:
res = RangeCodec.new(tid, sub_codec)
res.type_name = type_name

elif t == CTYPE_MULTIRANGE:
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]
ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
else:
type_name = None
pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
sub_codec = <BaseCodec>codecs_list[pos]
res = MultiRangeCodec.new(tid, sub_codec)
res.type_name = type_name

elif t == CTYPE_OBJECT and protocol_version >= (2, 0):
# Ignore
frb_read(spec, desc_len)
Expand Down
16 changes: 16 additions & 0 deletions edgedb/protocol/codecs/range.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,19 @@ cdef class RangeCodec(BaseCodec):

@staticmethod
cdef BaseCodec new(bytes tid, BaseCodec sub_codec)

@staticmethod
cdef encode_range(WriteBuffer buf, object obj, BaseCodec sub_codec)

@staticmethod
cdef decode_range(FRBuffer *buf, BaseCodec sub_codec)


# Declaration of the multirange codec: holds a single sub-codec and exposes
# a static factory taking the type id and that sub-codec.
@cython.final
cdef class MultiRangeCodec(BaseCodec):

    cdef:
        # Codec passed to the range encode/decode helpers for each element.
        BaseCodec sub_codec

    @staticmethod
    cdef BaseCodec new(bytes tid, BaseCodec sub_codec)
122 changes: 115 additions & 7 deletions edgedb/protocol/codecs/range.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ cdef class RangeCodec(BaseCodec):

return codec

cdef encode(self, WriteBuffer buf, object obj):
@staticmethod
cdef encode_range(WriteBuffer buf, object obj, BaseCodec sub_codec):
cdef:
uint8_t flags = 0
WriteBuffer sub_data
Expand All @@ -56,10 +57,10 @@ cdef class RangeCodec(BaseCodec):
bint inc_upper = obj.inc_upper
bint empty = obj.is_empty()

if not isinstance(self.sub_codec, ScalarCodec):
if not isinstance(sub_codec, ScalarCodec):
raise TypeError(
'only scalar ranges are supported (got type {!r})'.format(
type(self.sub_codec).__name__
type(sub_codec).__name__
)
)

Expand All @@ -78,14 +79,14 @@ cdef class RangeCodec(BaseCodec):
sub_data = WriteBuffer.new()
if lower is not None:
try:
self.sub_codec.encode(sub_data, lower)
sub_codec.encode(sub_data, lower)
except TypeError as e:
raise ValueError(
'invalid range lower bound: {}'.format(
e.args[0])) from None
if upper is not None:
try:
self.sub_codec.encode(sub_data, upper)
sub_codec.encode(sub_data, upper)
except TypeError as e:
raise ValueError(
'invalid range upper bound: {}'.format(
Expand All @@ -95,7 +96,8 @@ cdef class RangeCodec(BaseCodec):
buf.write_byte(<int8_t>flags)
buf.write_buffer(sub_data)

cdef decode(self, FRBuffer *buf):
@staticmethod
cdef decode_range(FRBuffer *buf, BaseCodec sub_codec):
cdef:
uint8_t flags = <uint8_t>frb_read(buf, 1)[0]
bint empty = (flags & RANGE_EMPTY) != 0
Expand All @@ -107,7 +109,6 @@ cdef class RangeCodec(BaseCodec):
object upper = None
int32_t sub_len
FRBuffer sub_buf
BaseCodec sub_codec = self.sub_codec

if has_lower:
sub_len = hton.unpack_int32(frb_read(buf, 4))
Expand Down Expand Up @@ -137,6 +138,12 @@ cdef class RangeCodec(BaseCodec):
empty=empty,
)

cdef encode(self, WriteBuffer buf, object obj):
    # Instance entry point: delegate to the static encode_range helper with
    # this codec's own sub_codec.
    RangeCodec.encode_range(buf, obj, self.sub_codec)

cdef decode(self, FRBuffer *buf):
    # Instance entry point: delegate to the static decode_range helper with
    # this codec's own sub_codec and return its result.
    return RangeCodec.decode_range(buf, self.sub_codec)

cdef dump(self, int level = 0):
    # Render this codec's name indented by `level` spaces, followed on the
    # next line by the sub-codec's own dump at level + 1.
    return f'{level * " "}{self.name}\n{self.sub_codec.dump(level + 1)}'

Expand All @@ -146,3 +153,104 @@ cdef class RangeCodec(BaseCodec):
name=self.type_name,
value_type=self.sub_codec.make_type(describe_context),
)


# Codec for multirange values.  Encoding accepts a sized iterable of range
# objects; decoding produces a range_mod.MultiRange.  Each element is handled
# by RangeCodec.encode_range / RangeCodec.decode_range using self.sub_codec.
@cython.final
cdef class MultiRangeCodec(BaseCodec):

    def __cinit__(self):
        self.sub_codec = None

    @staticmethod
    cdef BaseCodec new(bytes tid, BaseCodec sub_codec):
        # Factory: create a codec for type id `tid` whose elements are
        # encoded/decoded with `sub_codec`.
        cdef:
            MultiRangeCodec codec

        codec = MultiRangeCodec.__new__(MultiRangeCodec)

        codec.tid = tid
        codec.name = 'MultiRange'
        codec.sub_codec = sub_codec

        return codec

    cdef encode(self, WriteBuffer buf, object obj):
        # Write `obj` to `buf` as: int32 datum length, int32 element count,
        # then the concatenated output of encode_range for every element.
        #
        # Raises TypeError if the sub-codec is not a ScalarCodec or `obj`
        # fails the _is_array_iterable check, ValueError if there are too
        # many elements or an element fails to encode with a TypeError, and
        # OverflowError if the encoded payload is too large.
        cdef:
            WriteBuffer elem_data
            Py_ssize_t objlen
            Py_ssize_t elem_data_len

        if not isinstance(self.sub_codec, ScalarCodec):
            raise TypeError(
                f'only scalar multiranges are supported (got type '
                f'{type(self.sub_codec).__name__!r})'
            )

        # NOTE(review): _is_array_iterable is defined elsewhere; judging by
        # the error message it accepts sized iterable containers -- confirm.
        if not _is_array_iterable(obj):
            raise TypeError(
                f'a sized iterable container expected (got type '
                f'{type(obj).__name__!r})'
            )

        # The element count is written below as an int32, so it must fit.
        # (_MAXINT32 is defined elsewhere; presumably the int32 maximum.)
        objlen = len(obj)
        if objlen > _MAXINT32:
            raise ValueError('too many elements in multirange value')

        # Encode all elements into a scratch buffer first so the total
        # length is known before the datum header is written.
        elem_data = WriteBuffer.new()
        for item in obj:
            try:
                RangeCodec.encode_range(elem_data, item, self.sub_codec)
            except TypeError as e:
                # Re-raise as ValueError with the original message only
                # (`from None` suppresses the TypeError context).
                raise ValueError(
                    f'invalid multirange element: {e.args[0]}') from None

        # The datum length below is 4 (element count field) plus the
        # payload, so the payload itself may use at most _MAXINT32 - 4.
        elem_data_len = elem_data.len()
        if elem_data_len > _MAXINT32 - 4:
            raise OverflowError(
                f'size of encoded multirange datum exceeds the maximum '
                f'allowed {_MAXINT32 - 4} bytes')

        # Datum length
        buf.write_int32(4 + <int32_t>elem_data_len)
        # Number of elements in multirange
        buf.write_int32(<int32_t>objlen)
        buf.write_buffer(elem_data)

    cdef decode(self, FRBuffer *buf):
        # Read an element count followed by that many length-prefixed range
        # elements from `buf` and return them wrapped in a MultiRange.
        # Unlike encode(), no datum length is read here -- presumably the
        # caller has already consumed it (TODO confirm).
        #
        # Raises RuntimeError on a NULL element (length -1) or when an
        # element's slice is not fully consumed by decode_range.
        cdef:
            # The 4-byte count is read as int32 and reinterpreted as
            # unsigned before widening to Py_ssize_t.
            Py_ssize_t elem_count = <Py_ssize_t><uint32_t>hton.unpack_int32(
                frb_read(buf, 4))
            object result
            Py_ssize_t i
            int32_t elem_len
            FRBuffer elem_buf

        result = cpython.PyList_New(elem_count)
        for i in range(elem_count):
            elem_len = hton.unpack_int32(frb_read(buf, 4))
            if elem_len == -1:
                raise RuntimeError(
                    'unexpected NULL element in multirange value')
            else:
                # Decode from a sub-buffer limited to this element's bytes
                # so leftover data can be detected afterwards.
                frb_slice_from(&elem_buf, buf, elem_len)
                elem = RangeCodec.decode_range(&elem_buf, self.sub_codec)
                if frb_get_len(&elem_buf):
                    raise RuntimeError(
                        f'unexpected trailing data in buffer after '
                        f'multirange element decoding: '
                        f'{frb_get_len(&elem_buf)}')

            # PyList_SET_ITEM steals a reference, so take an extra one on
            # `elem` before handing it to the list.
            cpython.Py_INCREF(elem)
            cpython.PyList_SET_ITEM(result, i, elem)

        return range_mod.MultiRange(result)

    cdef dump(self, int level = 0):
        # Render this codec's name indented by `level` spaces, followed on
        # the next line by the sub-codec's own dump at level + 1.
        return f'{level * " "}{self.name}\n{self.sub_codec.dump(level + 1)}'

    def make_type(self, describe_context):
        # Build the describe.MultiRangeType for this codec: id from the
        # codec's tid bytes, name from type_name, and value_type from the
        # sub-codec's own make_type.
        return describe.MultiRangeType(
            desc_id=uuid.UUID(bytes=self.tid),
            name=self.type_name,
            value_type=self.sub_codec.make_type(describe_context),
        )
Loading