Skip to content

Commit

Permalink
Merge branch 'separating_tests' into restapi_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Snehal Waychal committed Dec 6, 2016
2 parents c0e2774 + fdf73e9 commit c86e470
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 51 deletions.
8 changes: 6 additions & 2 deletions aiida/backends/djsite/db/subtests/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from aiida.backends.djsite.db.testbase import AiidaTestCase
from aiida.backends.tests.query import TestQueryBuilder
from aiida.backends.tests.query import TestQueryBuilder, QueryBuilderJoinsTests



Expand All @@ -9,6 +9,9 @@
__authors__ = "The AiiDA team."
__version__ = "0.7.0"

class QueryBuilderJoinsTestsDjango(AiidaTestCase, QueryBuilderJoinsTests):
pass

class TestQueryBuilderDjango(AiidaTestCase, TestQueryBuilder):


Expand Down Expand Up @@ -90,4 +93,5 @@ def test_clsf_django(self):
):
self.assertEqual(clstype, Data._plugin_type_string)
self.assertEqual(query_type_string, Data._query_type_string)
self.assertTrue(issubclass(cls, DbNode))
self.assertTrue(issubclass(cls, DbNode))

142 changes: 109 additions & 33 deletions aiida/backends/querybuild/querybuilder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def _get_tag_from_type(self, ormclasstype):

def append(self, cls=None, type=None, tag=None,
autotag=False, filters=None, project=None, subclassing=True,
edge_tag=None, edge_filters=None, edge_project=None, **kwargs
edge_tag=None, edge_filters=None, edge_project=None,
outerjoin=False, **kwargs
):
"""
Any iterative procedure to build the path for a graph query
Expand All @@ -311,14 +312,17 @@ def append(self, cls=None, type=None, tag=None,
:param filters:
Filters to apply for this vertice.
See usage examples for details.
:param autotag:
:param bool autotag:
Whether to search for a unique tag,
(default **False**). If **True**, will find a unique tag.
Cannot be set to **True** if tag is specified.
:param subclassing:
:param bool subclassing:
Whether to include subclasses of the given class
(default **True**).
E.g. Specifying JobCalculation will include PwCalculation
:param bool outerjoin:
If True, (default is False), will do a left outerjoin
instead of an inner join
A small usage example how this can be invoked::
Expand Down Expand Up @@ -575,7 +579,7 @@ def append(self, cls=None, type=None, tag=None,

path_extension = dict(
type=ormclasstype, tag=tag, joining_keyword=joining_keyword,
joining_value=joining_value
joining_value=joining_value, outerjoin=outerjoin,
)
if aliased_edge is not None:
path_extension.update(dict(edge_tag=edge_tag))
Expand Down Expand Up @@ -1254,7 +1258,7 @@ def _join_masters(self, joined_entity, entity_to_join):
#~ call.caller_id == entity_to_join.id
#~ )

def _join_outputs(self, joined_entity, entity_to_join, aliased_edge):
def _join_outputs(self, joined_entity, entity_to_join, aliased_edge, isouterjoin):
"""
:param joined_entity: The (aliased) ORMclass that is an input
:param entity_to_join: The (aliased) ORMClass that is an output.
Expand All @@ -1268,16 +1272,19 @@ def _join_outputs(self, joined_entity, entity_to_join, aliased_edge):
(entity_to_join, self.Node),
'output_of'
)

self._query = self._query.join(
aliased_edge,
aliased_edge.input_id == joined_entity.id
).join(
aliased_edge.input_id == joined_entity.id,
isouter=isouterjoin
).join(
entity_to_join,
aliased_edge.output_id == entity_to_join.id,
isouter=self.isouter

aliased_edge.output_id == entity_to_join.id,
isouter=isouterjoin
)

def _join_inputs(self, joined_entity, entity_to_join, aliased_edge):
def _join_inputs(self, joined_entity, entity_to_join, aliased_edge,isouterjoin):
"""
:param joined_entity: The (aliased) ORMclass that is an output
:param entity_to_join: The (aliased) ORMClass that is an input.
Expand All @@ -1286,21 +1293,72 @@ def _join_inputs(self, joined_entity, entity_to_join, aliased_edge):
from **joined_entity** as output to **enitity_to_join** as input
(**enitity_to_join** is an *input_of* **joined_entity**)
"""

self._check_dbentities(
(joined_entity, self.Node),
(entity_to_join, self.Node),
'input_of'
)
self._query = self._query.join(
aliased_edge,
aliased_edge.output_id == joined_entity.id
aliased_edge.output_id == joined_entity.id,
).join(
entity_to_join,
aliased_edge.input_id == entity_to_join.id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_descendants(self, joined_entity, entity_to_join, aliased_path):
def _join_descendants_beta(self, joined_entity, entity_to_join, aliased_path, isouterjoin):
"""
Beta version, joining descendants using the recursive functionality
"""
self._check_dbentities(
(joined_entity, self.Node),
(entity_to_join, self.Node),
'descendant_of_beta'
)

self._query = self._query.join(
aliased_path,
aliased_path.ancestor_id == joined_entity.id
).join(
entity_to_join,
aliased_path.descendant_id == entity_to_join.id,
isouter=isouterjoin
).filter(
# it is necessary to put this filter so that the
# the node does not include itself as a ancestor/descendant
aliased_path.depth > -1
)

def _join_ancestors_beta(self, joined_entity, entity_to_join, aliased_path, isouterjoin):
"""
:param joined_entity: The (aliased) ORMclass that is a descendant
:param entity_to_join: The (aliased) ORMClass that is an ancestor.
:param aliased_path: An aliased instance of DbPath
"""
self._check_dbentities(
(joined_entity, self.Node),
(entity_to_join, self.Node),
'ancestor_of_beta'
)
#~ aliased_path = aliased(self.Path)
self._query = self._query.join(
aliased_path,
aliased_path.descendant_id == joined_entity.id
).join(
entity_to_join,
aliased_path.ancestor_id == entity_to_join.id,
isouter=isouterjoin
).filter(
# it is necessary to put this filter so that the
# the node does not include itself as a ancestor/descendant
aliased_path.depth > -1
)


def _join_descendants(self, joined_entity, entity_to_join, aliased_path, isouterjoin):
"""
:param joined_entity: The (aliased) ORMclass that is an ancestor
:param entity_to_join: The (aliased) ORMClass that is a descendant.
Expand All @@ -1323,10 +1381,10 @@ def _join_descendants(self, joined_entity, entity_to_join, aliased_path):
).join(
entity_to_join,
aliased_path.child_id == entity_to_join.id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_ancestors(self, joined_entity, entity_to_join, aliased_path):
def _join_ancestors(self, joined_entity, entity_to_join, aliased_path, isouterjoin):
"""
:param joined_entity: The (aliased) ORMclass that is a descendant
:param entity_to_join: The (aliased) ORMClass that is an ancestor.
Expand All @@ -1349,10 +1407,10 @@ def _join_ancestors(self, joined_entity, entity_to_join, aliased_path):
).join(
entity_to_join,
aliased_path.parent_id == entity_to_join.id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_group_members(self, joined_entity, entity_to_join):
def _join_group_members(self, joined_entity, entity_to_join, isouterjoin):
"""
:param joined_entity:
The (aliased) ORMclass that is
Expand All @@ -1377,10 +1435,10 @@ def _join_group_members(self, joined_entity, entity_to_join):
).join(
entity_to_join,
entity_to_join.id == aliased_group_nodes.c.dbnode_id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_groups(self, joined_entity, entity_to_join):
def _join_groups(self, joined_entity, entity_to_join, isouterjoin):
"""
:param joined_entity: The (aliased) node in the database
:param entity_to_join: The (aliased) Group
Expand All @@ -1402,10 +1460,10 @@ def _join_groups(self, joined_entity, entity_to_join):
).join(
entity_to_join,
entity_to_join.id == aliased_group_nodes.c.dbgroup_id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_creator_of(self, joined_entity, entity_to_join):
def _join_creator_of(self, joined_entity, entity_to_join, isouterjoin):
"""
:param joined_entity: the aliased node
:param entity_to_join: the aliased user to join to that node
Expand All @@ -1418,9 +1476,9 @@ def _join_creator_of(self, joined_entity, entity_to_join):
self._query = self._query.join(
entity_to_join,
entity_to_join.id == joined_entity.user_id,
isouter=self.isouter
isouter=isouterjoin
)
def _join_created_by(self, joined_entity, entity_to_join):
def _join_created_by(self, joined_entity, entity_to_join, isouterjoin):
"""
:param joined_entity: the aliased user you want to join to
:param entity_to_join: the (aliased) node or group in the DB to join with
Expand All @@ -1433,10 +1491,10 @@ def _join_created_by(self, joined_entity, entity_to_join):
self._query = self._query.join(
entity_to_join,
entity_to_join.user_id == joined_entity.id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_to_computer_used(self, joined_entity, entity_to_join):
def _join_to_computer_used(self, joined_entity, entity_to_join, isouterjoin):
"""
:param joined_entity: the (aliased) computer entity
:param entity_to_join: the (aliased) node entity
Expand All @@ -1450,10 +1508,10 @@ def _join_to_computer_used(self, joined_entity, entity_to_join):
self._query = self._query.join(
entity_to_join,
entity_to_join.dbcomputer_id == joined_entity.id,
isouter=self.isouter
isouter=isouterjoin
)

def _join_computer(self, joined_entity, entity_to_join):
def _join_computer(self, joined_entity, entity_to_join, isouterjoin):
"""
:param joined_entity: An entity that can use a computer (eg a node)
:param entity_to_join: aliased dbcomputer entity
Expand All @@ -1468,7 +1526,7 @@ def _join_computer(self, joined_entity, entity_to_join):
self._query = self._query.join(
entity_to_join,
joined_entity.dbcomputer_id == entity_to_join.id,
isouter=self.isouter
isouter=isouterjoin
)

def _get_function_map(self):
Expand Down Expand Up @@ -1672,11 +1730,12 @@ def _build(self):
index, **verticespec
)
edge_tag = verticespec.get('edge_tag', None)
isouterjoin = verticespec.get('outerjoin')
if edge_tag is None:
connection_func(toconnectwith, alias)
connection_func(toconnectwith, alias, isouterjoin=isouterjoin)
else:
aliased_edge = self._tag_to_alias_map[edge_tag]
connection_func(toconnectwith, alias, aliased_edge)
connection_func(toconnectwith, alias, aliased_edge, isouterjoin=isouterjoin)

######################### FILTERS ##############################

Expand Down Expand Up @@ -1926,18 +1985,35 @@ def _yield_per(self, batch_size):
:returns: a generator
"""
return self.get_query().yield_per(batch_size)
try:
return self.get_query().yield_per(batch_size)
except Exception as e:
# exception was raised. Rollback the session
self._get_session().rollback()
raise e


def _all(self):
return self.get_query().all()
try:
return self.get_query().all()
except Exception as e:
# exception was raised. Rollback the session
self._get_session().rollback()
raise e

def _first(self):
"""
Executes query in the backend asking for one instance.
:returns: One row of aiida results
"""
return self.get_query().first()
try:
return self.get_query().first()
except Exception as e:
# exception was raised. Rollback the session
self._get_session().rollback()
raise e


def first(self):
"""
Expand Down
1 change: 0 additions & 1 deletion aiida/backends/querybuild/querybuilder_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def _get_projectable_attribute(

attrkey = '.'.join(attrpath)


exists_stmt = exists(select([1], correlate=True).select_from(
aliased_attributes
).where(and_(
Expand Down
2 changes: 1 addition & 1 deletion aiida/backends/querybuild/querybuilder_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def cast_according_to_type(path_in_json, value):
elif operator in ('>=', '=>'):
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = and_(type_filter, casted_entity >= value)
elif operator == ('<=', '=<'):
elif operator in ('<=', '=<'):
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = and_(type_filter, casted_entity <= value)
elif operator == 'of_type':
Expand Down
6 changes: 3 additions & 3 deletions aiida/backends/sqlalchemy/models/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.types import Integer, String, DateTime

from aiida.backends.sqlalchemy import session
import aiida.backends.sqlalchemy as sa
from aiida.backends.sqlalchemy.models.base import Base
from aiida.utils import timezone

Expand Down Expand Up @@ -37,7 +37,7 @@ def set_value(cls, key, value, with_transaction=True,
subspecifier_value=None, other_attribs={},
stop_if_existing=False):

setting = session.query(DbSetting).filter_by(key=key).first()
setting = sa.session.query(DbSetting).filter_by(key=key).first()
if setting is not None:
if stop_if_existing:
return
Expand Down Expand Up @@ -67,7 +67,7 @@ def get_description(self):

@classmethod
def del_value(cls, key, only_children=False, subspecifier_value=None):
setting = session.query(DbSetting).filter(key=key)
setting = sa.session.query(DbSetting).filter(key=key)
setting.val = None
setting.time = timezone.datetime.utcnow()
flag_modified(setting, "val")
Expand Down
7 changes: 6 additions & 1 deletion aiida/backends/sqlalchemy/tests/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aiida.backends.tests.query import TestQueryBuilder
from aiida.backends.tests.query import TestQueryBuilder, QueryBuilderJoinsTests
from aiida.backends.sqlalchemy.tests.testbase import SqlAlchemyTests


Expand All @@ -23,3 +23,8 @@ def test_clsf_sqla(self):

self.assertEqual(cls, ORMCls)
self.assertEqual(query_type_string, typestr)



class QueryBuilderJoinsTestsSQLA(SqlAlchemyTests, QueryBuilderJoinsTests):
pass
2 changes: 1 addition & 1 deletion aiida/backends/sqlalchemy/tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def find_classes(module_str):

def run_tests():
modules_str = [
# "aiida.backends.sqlalchemy.tests.query",
# "aiida.backends.sqlalchemy.tests.query",
# "aiida.backends.sqlalchemy.tests.nodes",
# "aiida.backends.sqlalchemy.tests.backup_script",
# "aiida.backends.sqlalchemy.tests.export_and_import",
Expand Down
Loading

0 comments on commit c86e470

Please sign in to comment.