Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Structured Properties #102

Merged
merged 2 commits into from
Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 107 additions & 55 deletions src/google/cloud/ndb/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,15 @@ def _entity_from_protobuf(protobuf):
return _entity_from_ds_entity(ds_entity)


def _entity_to_protobuf(entity, set_key=True):
"""Serialize an entity to a protobuffer.
def _entity_to_ds_entity(entity, set_key=True):
"""Convert an NDB entity to Datastore entity.

Args:
entity (Model): The entity to be serialized.
entity (Model): The entity to be converted.

Returns:
google.cloud.datastore_v1.types.Entity: The protocol buffer
representation.
google.cloud.datastore.entity.Entity: The converted entity.
"""
# First, make a datastore entity
data = {}
for cls in type(entity).mro():
for prop in cls.__dict__.values():
Expand All @@ -376,7 +374,20 @@ def _entity_to_protobuf(entity, set_key=True):
ds_entity = entity_module.Entity()
ds_entity.update(data)

# Then, use datatore to get the protocol buffer
return ds_entity


def _entity_to_protobuf(entity, set_key=True):
"""Serialize an entity to a protocol buffer.

Args:
entity (Model): The entity to be serialized.

Returns:
google.cloud.datastore_v1.types.Entity: The protocol buffer
representation.
"""
ds_entity = _entity_to_ds_entity(entity, set_key=set_key)
return helpers.entity_to_protobuf(ds_entity)


Expand Down Expand Up @@ -3362,29 +3373,29 @@ class StructuredProperty(Property):
The values of the sub-entity are indexed and can be queried.
"""

_modelclass = None
_model_class = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In different places, it was spelled _modelclass, _model_class, and _kls. I picked one.

_kwargs = None

def __init__(self, modelclass, name=None, **kwargs):
def __init__(self, model_class, name=None, **kwargs):
super(StructuredProperty, self).__init__(name=name, **kwargs)
if self._repeated:
if modelclass._has_repeated:
if model_class._has_repeated:
raise TypeError(
"This StructuredProperty cannot use repeated=True "
"because its model class (%s) contains repeated "
"properties (directly or indirectly)."
% modelclass.__name__
% model_class.__name__
)
self._modelclass = modelclass
self._model_class = model_class

def _get_value(self, entity):
"""Override _get_value() to *not* raise UnprojectedPropertyError.

This is necessary because the projection must include both the sub-entity and
the property name that is projected (e.g. 'foo.bar' instead of only 'foo'). In
that case the original code would fail, because it only looks for the property
name ('foo'). Here we check for a value, and only call the original code if the
value is None.
This is necessary because the projection must include both the
sub-entity and the property name that is projected (e.g. 'foo.bar'
instead of only 'foo'). In that case the original code would fail,
because it only looks for the property name ('foo'). Here we check for
a value, and only call the original code if the value is None.
"""
value = self._get_user_value(entity)
if value is None and entity._projection:
Expand All @@ -3403,11 +3414,11 @@ def _get_for_dict(self, entity):
def __getattr__(self, attrname):
"""Dynamically get a subproperty."""
# Optimistically try to use the dict key.
prop = self._modelclass._properties.get(attrname)
prop = self._model_class._properties.get(attrname)
if prop is None:
raise AttributeError(
"Model subclass %s has no attribute %s"
% (self._modelclass.__name__, attrname)
% (self._model_class.__name__, attrname)
)
prop_copy = copy.copy(prop)
prop_copy._name = self._name + "." + prop_copy._name
Expand Down Expand Up @@ -3436,37 +3447,41 @@ def _comparison(self, op, value):
) # Import late to avoid circular imports.

return FilterNode(self._name, op, value)

value = self._do_validate(value)
value = self._call_to_base_type(value)
filters = []
match_keys = []
for prop in self._modelclass._properties.values():
vals = prop._get_base_value_unwrapped_as_list(value)
for prop in self._model_class._properties.values():
subvalue = prop._get_value(value)
if prop._repeated:
if vals: # pragma: no branch
if subvalue: # pragma: no branch
raise exceptions.BadFilterError(
"Cannot query for non-empty repeated property %s"
% prop._name
)
continue # pragma: NO COVER
val = vals[0]
if val is not None: # pragma: no branch

if subvalue is not None: # pragma: no branch
altprop = getattr(self, prop._code_name)
filt = altprop._comparison(op, val)
filt = altprop._comparison(op, subvalue)
filters.append(filt)
match_keys.append(altprop._name)

if not filters:
raise exceptions.BadFilterError(
"StructuredProperty filter without any values"
)

if len(filters) == 1:
return filters[0]

if self._repeated:
raise NotImplementedError("This depends on code not yet ported.")
# pb = value._to_pb(allow_partial=True)
# pred = RepeatedStructuredPropertyPredicate(match_keys, pb,
# self._name + '.')
# filters.append(PostFilterNode(pred))

return ConjunctionNode(*filters)

def _IN(self, value):
Expand All @@ -3491,11 +3506,11 @@ def _IN(self, value):
def _validate(self, value):
if isinstance(value, dict):
# A dict is assumed to be the result of a _to_dict() call.
return self._modelclass(**value)
if not isinstance(value, self._modelclass):
return self._model_class(**value)
if not isinstance(value, self._model_class):
raise exceptions.BadValueError(
"Expected %s instance, got %s"
% (self._modelclass.__name__, value.__class__)
% (self._model_class.__name__, value.__class__)
)

def _has_value(self, entity, rest=None):
Expand All @@ -3507,27 +3522,34 @@ def _has_value(self, entity, rest=None):

Args:
entity (ndb.Model): An instance of a model.
rest (list[str]): optional list of attribute names to check in addition.
rest (list[str]): optional list of attribute names to check in
addition.

Returns:
bool: True if the entity has a value for that property.
"""
ok = super(StructuredProperty, self)._has_value(entity)
if ok and rest:
lst = self._get_base_value_unwrapped_as_list(entity)
if len(lst) != 1:
raise RuntimeError(
"Failed to retrieve sub-entity of StructuredProperty"
" %s" % self._name
)
subent = lst[0]
value = self._get_value(entity)
if self._repeated:
if len(value) != 1:
raise RuntimeError(
"Failed to retrieve sub-entity of StructuredProperty"
" %s" % self._name
)
subent = value[0]
else:
subent = value

if subent is None:
return True

subprop = subent._properties.get(rest[0])
if subprop is None:
ok = False
else:
ok = subprop._has_value(subent, rest[1:])

return ok

def _check_property(self, rest=None, require_indexed=True):
Expand All @@ -3541,15 +3563,42 @@ def _check_property(self, rest=None, require_indexed=True):
raise InvalidPropertyError(
"Structured property %s requires a subproperty" % self._name
)
self._modelclass._check_properties(
self._model_class._check_properties(
[rest], require_indexed=require_indexed
)

def _get_base_value_at_index(self, entity, index):
assert self._repeated
value = self._retrieve_value(entity, self._default)
value[index] = self._opt_call_to_base_type(value[index])
return value[index].b_val
def _to_base_type(self, value):
"""Convert a value to the "base" value type for this property.

Args:
value: The given class value to be converted.

Returns:
bytes

Raises:
TypeError: If ``value`` is not the correct ``Model`` type.
"""
if not isinstance(value, self._model_class):
raise TypeError(
"Cannot convert to protocol buffer. Expected {} value; "
"received {}".format(self._model_class.__name__, value)
)
return _entity_to_ds_entity(value)

def _from_base_type(self, value):
"""Convert a value from the "base" value type for this property.
Args:
value(~google.cloud.datastore.Entity or bytes): The value to be
converted.
Returns:
The converted value with given class.
"""
if isinstance(value, entity_module.Entity):
value = _entity_from_ds_entity(
value, model_class=self._model_class
)
return value

def _get_value_size(self, entity):
values = self._retrieve_value(entity, self._default)
Expand All @@ -3569,7 +3618,8 @@ class LocalStructuredProperty(BlobProperty):
.. automethod:: _from_base_type
.. automethod:: _validate
Args:
kls (ndb.Model): The class of the property.
model_class (type): The class of the property. (Must be subclass of
``ndb.Model``.)
name (str): The name of the property.
compressed (bool): Indicates if the value should be compressed (via
``zlib``).
Expand All @@ -3585,19 +3635,19 @@ class LocalStructuredProperty(BlobProperty):
to the datastore.
"""

_kls = None
_model_class = None
_keep_keys = False
_kwargs = None

def __init__(self, kls, **kwargs):
def __init__(self, model_class, **kwargs):
indexed = kwargs.pop("indexed", False)
if indexed:
raise NotImplementedError(
"Cannot index LocalStructuredProperty {}.".format(self._name)
)
keep_keys = kwargs.pop("keep_keys", False)
super(LocalStructuredProperty, self).__init__(**kwargs)
self._kls = kls
self._model_class = model_class
self._keep_keys = keep_keys

def _validate(self, value):
Expand All @@ -3609,11 +3659,13 @@ def _validate(self, value):
"""
if isinstance(value, dict):
# A dict is assumed to be the result of a _to_dict() call.
value = self._kls(**value)
value = self._model_class(**value)

if not isinstance(value, self._kls):
if not isinstance(value, self._model_class):
raise exceptions.BadValueError(
"Expected {}, got {!r}".format(self._kls.__name__, value)
"Expected {}, got {!r}".format(
self._model_class.__name__, value
)
)

def _to_base_type(self, value):
Expand All @@ -3623,12 +3675,12 @@ def _to_base_type(self, value):
Returns:
bytes
Raises:
TypeError: If ``value`` is not a given class.
TypeError: If ``value`` is not the correct ``Model`` type.
"""
if not isinstance(value, self._kls):
if not isinstance(value, self._model_class):
raise TypeError(
"Cannot convert to bytes expected {} value; "
"received {}".format(self._kls.__name__, value)
"received {}".format(self._model_class.__name__, value)
)
pb = _entity_to_protobuf(value, set_key=self._keep_keys)
return pb.SerializePartialToString()
Expand All @@ -3647,7 +3699,7 @@ def _from_base_type(self, value):
value = helpers.entity_from_protobuf(pb)
if not self._keep_keys and value.key:
value.key = None
return _entity_from_ds_entity(value, model_class=self._kls)
return _entity_from_ds_entity(value, model_class=self._model_class)


class GenericProperty(Property):
Expand Down Expand Up @@ -4328,7 +4380,7 @@ def _fix_up_properties(cls):
if isinstance(attr, Property):
if attr._repeated or (
isinstance(attr, StructuredProperty)
and attr._modelclass._has_repeated
and attr._model_class._has_repeated
):
cls._has_repeated = True
cls._properties[attr._name] = attr
Expand Down
1 change: 1 addition & 0 deletions tests/system/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import time

KIND = "SomeKind"
OTHER_KIND = "OtherKind"
OTHER_NAMESPACE = "other-namespace"


Expand Down
3 changes: 2 additions & 1 deletion tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from google.cloud import datastore
from google.cloud import ndb

from . import KIND, OTHER_NAMESPACE
from . import KIND, OTHER_KIND, OTHER_NAMESPACE


def all_entities(client):
return itertools.chain(
client.query(kind=KIND).fetch(),
client.query(kind=OTHER_KIND).fetch(),
client.query(namespace=OTHER_NAMESPACE).fetch(),
)

Expand Down
23 changes: 23 additions & 0 deletions tests/system/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,26 @@ def do_the_thing():

entity = ndb.transaction(do_the_thing)
assert entity.foo == 42


@pytest.mark.usefixtures("client_context")
def test_insert_entity_with_structured_property(dispose_of):
class OtherKind(ndb.Model):
one = ndb.StringProperty()
two = ndb.StringProperty()

class SomeKind(ndb.Model):
foo = ndb.IntegerProperty()
bar = ndb.StructuredProperty(OtherKind)

entity = SomeKind(foo=42, bar=OtherKind(one="hi", two="mom"))
key = entity.put()

retrieved = key.get()
assert retrieved.foo == 42
assert retrieved.bar.one == "hi"
assert retrieved.bar.two == "mom"

assert isinstance(retrieved.bar, OtherKind)

dispose_of(key._key)
Loading