Skip to content

Commit

Permalink
Rework _is_child(_embedded_document) to take template as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed May 5, 2020
1 parent 545bce2 commit db60d62
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions umongo/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,27 @@ def camel_to_snake(name):
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', tmp_str).lower()


def _is_child(bases):
def _is_child(template):
"""Find if the given inheritance leeds to a child document (i.e.
a document that shares the same collection with a parent)
"""
return any(b for b in bases if issubclass(b, DocumentImplementation) and not b.opts.abstract)
return any(
b for b in template.__bases__
if issubclass(b, DocumentTemplate) and
b is not DocumentTemplate and
('Meta' not in b.__dict__ or not getattr(b.Meta, 'abstract', False))
)


def _is_child_embedded_document(bases):
def _is_child_embedded_document(template):
"""Same thing than _is_child, but for EmbeddedDocument...
"""
return any(b for b in bases if issubclass(b, EmbeddedDocumentImplementation) and
not b.opts.abstract)
return any(
b for b in template.__bases__
if issubclass(b, EmbeddedDocumentTemplate) and
b is not EmbeddedDocumentTemplate and
('Meta' not in b.__dict__ or not getattr(b.Meta, 'abstract', False))
)


def _on_need_add_id_field(bases, fields_dict):
Expand Down Expand Up @@ -104,13 +113,12 @@ def _collect_schema_attrs(template):
return nmspc, schema_fields, schema_non_fields


def _collect_indexes(meta, schema_nmspc, bases):
def _collect_indexes(meta, schema_nmspc, bases, is_child):
"""
Retrieve all indexes (custom defined in meta class, by inheritances
and unique attribut in fields)
"""
indexes = []
is_child = _is_child(bases)

# First collect parent indexes (including inherited field's unique indexes)
for base in bases:
Expand Down Expand Up @@ -148,15 +156,15 @@ def parse_field(mongo_path, path, field):
return indexes


def _build_document_opts(instance, template, name, nmspc, bases):
def _build_document_opts(instance, template, name, nmspc, bases, is_child):
kwargs = {}
meta = nmspc.get('Meta')
collection_name = getattr(meta, 'collection_name', None)
kwargs['instance'] = instance
kwargs['template'] = template
kwargs['abstract'] = getattr(meta, 'abstract', False)
kwargs['allow_inheritance'] = getattr(meta, 'allow_inheritance', None)
kwargs['is_child'] = _is_child(bases)
kwargs['is_child'] = is_child
kwargs['strict'] = getattr(meta, 'strict', True)

# Handle option inheritance and integrity checks
Expand Down Expand Up @@ -186,14 +194,14 @@ def _build_document_opts(instance, template, name, nmspc, bases):
return DocumentOpts(collection_name=collection_name, **kwargs)


def _build_embedded_document_opts(instance, template, name, nmspc, bases):
def _build_embedded_document_opts(instance, template, name, nmspc, bases, is_child):
kwargs = {}
meta = nmspc.get('Meta')
kwargs['instance'] = instance
kwargs['template'] = template
kwargs['abstract'] = getattr(meta, 'abstract', False)
kwargs['allow_inheritance'] = getattr(meta, 'allow_inheritance', True)
kwargs['is_child'] = _is_child_embedded_document(bases)
kwargs['is_child'] = is_child
kwargs['strict'] = getattr(meta, 'strict', True)

# Handle option inheritance and integrity checks
Expand Down Expand Up @@ -279,12 +287,14 @@ def build_from_template(self, template):
name = template.__name__
bases = self._convert_bases(template.__bases__)
if embedded:
is_child = _is_child_embedded_document(template)
opts = _build_embedded_document_opts(
self.instance, template, name, template.__dict__, bases
self.instance, template, name, template.__dict__, bases, is_child
)
else:
is_child = _is_child(template)
opts = _build_document_opts(
self.instance, template, name, template.__dict__, bases
self.instance, template, name, template.__dict__, bases, is_child
)
nmspc, schema_fields, schema_non_fields = _collect_schema_attrs(template)
nmspc['opts'] = opts
Expand All @@ -296,7 +306,7 @@ def build_from_template(self, template):
if not embedded:
nmspc['pk_field'] = _on_need_add_id_field(schema_bases, schema_fields)

if opts.is_child:
if is_child:
_add_child_field(name, schema_fields)
schema_cls = self._build_schema(template, schema_bases, schema_fields, schema_non_fields)
nmspc['Schema'] = schema_cls
Expand All @@ -307,7 +317,7 @@ def build_from_template(self, template):
if not embedded:
# _build_document_opts cannot determine the indexes given we need to
# visit the document's fields which weren't defined at this time
opts.indexes = _collect_indexes(nmspc.get('Meta'), schema.fields, bases)
opts.indexes = _collect_indexes(nmspc.get('Meta'), schema.fields, bases, is_child)

implementation = type(name, bases, nmspc)
self._templates_lookup[template] = implementation
Expand Down

0 comments on commit db60d62

Please sign in to comment.