diff --git a/aiida/backends/djsite/db/subtests/query.py b/aiida/backends/djsite/db/subtests/query.py index 83d1e08f4d..fca6ebed3c 100644 --- a/aiida/backends/djsite/db/subtests/query.py +++ b/aiida/backends/djsite/db/subtests/query.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from aiida.backends.djsite.db.testbase import AiidaTestCase -from aiida.backends.tests.query import TestQueryBuilder +from aiida.backends.tests.query import TestQueryBuilder, QueryBuilderJoinsTests @@ -9,6 +9,9 @@ __authors__ = "The AiiDA team." __version__ = "0.7.0" +class QueryBuilderJoinsTestsDjango(AiidaTestCase, QueryBuilderJoinsTests): + pass + class TestQueryBuilderDjango(AiidaTestCase, TestQueryBuilder): @@ -90,4 +93,5 @@ def test_clsf_django(self): ): self.assertEqual(clstype, Data._plugin_type_string) self.assertEqual(query_type_string, Data._query_type_string) - self.assertTrue(issubclass(cls, DbNode)) \ No newline at end of file + self.assertTrue(issubclass(cls, DbNode)) + diff --git a/aiida/backends/querybuild/querybuilder_base.py b/aiida/backends/querybuild/querybuilder_base.py index e572113e13..db967e819c 100644 --- a/aiida/backends/querybuild/querybuilder_base.py +++ b/aiida/backends/querybuild/querybuilder_base.py @@ -297,7 +297,8 @@ def _get_tag_from_type(self, ormclasstype): def append(self, cls=None, type=None, tag=None, autotag=False, filters=None, project=None, subclassing=True, - edge_tag=None, edge_filters=None, edge_project=None, **kwargs + edge_tag=None, edge_filters=None, edge_project=None, + outerjoin=False, **kwargs ): """ Any iterative procedure to build the path for a graph query @@ -311,14 +312,17 @@ def append(self, cls=None, type=None, tag=None, :param filters: Filters to apply for this vertice. See usage examples for details. - :param autotag: + :param bool autotag: Whether to search for a unique tag, (default **False**). If **True**, will find a unique tag. Cannot be set to **True** if tag is specified. - :param subclassing: + :param bool subclassing: Whether to include subclasses of the given class (default **True**). E.g. Specifying JobCalculation will include PwCalculation + :param bool outerjoin: + If True, (default is False), will do a left outerjoin + instead of an inner join A small usage example how this can be invoked:: @@ -575,7 +579,7 @@ def append(self, cls=None, type=None, tag=None, path_extension = dict( type=ormclasstype, tag=tag, joining_keyword=joining_keyword, - joining_value=joining_value + joining_value=joining_value, outerjoin=outerjoin, ) if aliased_edge is not None: path_extension.update(dict(edge_tag=edge_tag)) @@ -1254,7 +1258,7 @@ def _join_masters(self, joined_entity, entity_to_join): #~ call.caller_id == entity_to_join.id #~ ) - def _join_outputs(self, joined_entity, entity_to_join, aliased_edge): + def _join_outputs(self, joined_entity, entity_to_join, aliased_edge, isouterjoin): """ :param joined_entity: The (aliased) ORMclass that is an input :param entity_to_join: The (aliased) ORMClass that is an output. @@ -1268,16 +1272,19 @@ def _join_outputs(self, joined_entity, entity_to_join, aliased_edge): (entity_to_join, self.Node), 'output_of' ) + self._query = self._query.join( aliased_edge, - aliased_edge.input_id == joined_entity.id - ).join( + aliased_edge.input_id == joined_entity.id, + isouter=isouterjoin + ).join( entity_to_join, - aliased_edge.output_id == entity_to_join.id, - isouter=self.isouter + + aliased_edge.output_id == entity_to_join.id, + isouter=isouterjoin ) - def _join_inputs(self, joined_entity, entity_to_join, aliased_edge): + def _join_inputs(self, joined_entity, entity_to_join, aliased_edge,isouterjoin): """ :param joined_entity: The (aliased) ORMclass that is an output :param entity_to_join: The (aliased) ORMClass that is an input. @@ -1286,6 +1293,7 @@ def _join_inputs(self, joined_entity, entity_to_join, aliased_edge): from **joined_entity** as output to **enitity_to_join** as input (**enitity_to_join** is an *input_of* **joined_entity**) """ + self._check_dbentities( (joined_entity, self.Node), (entity_to_join, self.Node), @@ -1293,14 +1301,64 @@ def _join_inputs(self, joined_entity, entity_to_join, aliased_edge): ) self._query = self._query.join( aliased_edge, - aliased_edge.output_id == joined_entity.id + aliased_edge.output_id == joined_entity.id, ).join( entity_to_join, aliased_edge.input_id == entity_to_join.id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_descendants(self, joined_entity, entity_to_join, aliased_path): + def _join_descendants_beta(self, joined_entity, entity_to_join, aliased_path, isouterjoin): + """ + Beta version, joining descendants using the recursive functionality + """ + self._check_dbentities( + (joined_entity, self.Node), + (entity_to_join, self.Node), + 'descendant_of_beta' + ) + + self._query = self._query.join( + aliased_path, + aliased_path.ancestor_id == joined_entity.id + ).join( + entity_to_join, + aliased_path.descendant_id == entity_to_join.id, + isouter=isouterjoin + ).filter( + # it is necessary to put this filter so that the + # the node does not include itself as a ancestor/descendant + aliased_path.depth > -1 + ) + + def _join_ancestors_beta(self, joined_entity, entity_to_join, aliased_path, isouterjoin): + """ + :param joined_entity: The (aliased) ORMclass that is a descendant + :param entity_to_join: The (aliased) ORMClass that is an ancestor. + :param aliased_path: An aliased instance of DbPath + + """ + self._check_dbentities( + (joined_entity, self.Node), + (entity_to_join, self.Node), + 'ancestor_of_beta' + ) + #~ aliased_path = aliased(self.Path) + self._query = self._query.join( + aliased_path, + aliased_path.descendant_id == joined_entity.id + ).join( + entity_to_join, + aliased_path.ancestor_id == entity_to_join.id, + isouter=isouterjoin + ).filter( + # it is necessary to put this filter so that the + # the node does not include itself as a ancestor/descendant + aliased_path.depth > -1 + ) + + + def _join_descendants(self, joined_entity, entity_to_join, aliased_path, isouterjoin): """ :param joined_entity: The (aliased) ORMclass that is an ancestor :param entity_to_join: The (aliased) ORMClass that is a descendant. @@ -1323,10 +1381,10 @@ def _join_descendants(self, joined_entity, entity_to_join, aliased_path): ).join( entity_to_join, aliased_path.child_id == entity_to_join.id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_ancestors(self, joined_entity, entity_to_join, aliased_path): + def _join_ancestors(self, joined_entity, entity_to_join, aliased_path, isouterjoin): """ :param joined_entity: The (aliased) ORMclass that is a descendant :param entity_to_join: The (aliased) ORMClass that is an ancestor. @@ -1349,10 +1407,10 @@ def _join_ancestors(self, joined_entity, entity_to_join, aliased_path): ).join( entity_to_join, aliased_path.parent_id == entity_to_join.id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_group_members(self, joined_entity, entity_to_join): + def _join_group_members(self, joined_entity, entity_to_join, isouterjoin): """ :param joined_entity: The (aliased) ORMclass that is @@ -1377,10 +1435,10 @@ def _join_group_members(self, joined_entity, entity_to_join): ).join( entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_groups(self, joined_entity, entity_to_join): + def _join_groups(self, joined_entity, entity_to_join, isouterjoin): """ :param joined_entity: The (aliased) node in the database :param entity_to_join: The (aliased) Group @@ -1402,10 +1460,10 @@ def _join_groups(self, joined_entity, entity_to_join): ).join( entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_creator_of(self, joined_entity, entity_to_join): + def _join_creator_of(self, joined_entity, entity_to_join, isouterjoin): """ :param joined_entity: the aliased node :param entity_to_join: the aliased user to join to that node @@ -1418,9 +1476,9 @@ def _join_creator_of(self, joined_entity, entity_to_join): self._query = self._query.join( entity_to_join, entity_to_join.id == joined_entity.user_id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_created_by(self, joined_entity, entity_to_join): + def _join_created_by(self, joined_entity, entity_to_join, isouterjoin): """ :param joined_entity: the aliased user you want to join to :param entity_to_join: the (aliased) node or group in the DB to join with @@ -1433,10 +1491,10 @@ def _join_created_by(self, joined_entity, entity_to_join): self._query = self._query.join( entity_to_join, entity_to_join.user_id == joined_entity.id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_to_computer_used(self, joined_entity, entity_to_join): + def _join_to_computer_used(self, joined_entity, entity_to_join, isouterjoin): """ :param joined_entity: the (aliased) computer entity :param entity_to_join: the (aliased) node entity @@ -1450,10 +1508,10 @@ def _join_to_computer_used(self, joined_entity, entity_to_join): self._query = self._query.join( entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, - isouter=self.isouter + isouter=isouterjoin ) - def _join_computer(self, joined_entity, entity_to_join): + def _join_computer(self, joined_entity, entity_to_join, isouterjoin): """ :param joined_entity: An entity that can use a computer (eg a node) :param entity_to_join: aliased dbcomputer entity @@ -1468,7 +1526,7 @@ def _join_computer(self, joined_entity, entity_to_join): self._query = self._query.join( entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, - isouter=self.isouter + isouter=isouterjoin ) def _get_function_map(self): @@ -1672,11 +1730,12 @@ def _build(self): index, **verticespec ) edge_tag = verticespec.get('edge_tag', None) + isouterjoin = verticespec.get('outerjoin') if edge_tag is None: - connection_func(toconnectwith, alias) + connection_func(toconnectwith, alias, isouterjoin=isouterjoin) else: aliased_edge = self._tag_to_alias_map[edge_tag] - connection_func(toconnectwith, alias, aliased_edge) + connection_func(toconnectwith, alias, aliased_edge, isouterjoin=isouterjoin) ######################### FILTERS ############################## @@ -1926,10 +1985,21 @@ def _yield_per(self, batch_size): :returns: a generator """ - return self.get_query().yield_per(batch_size) + try: + return self.get_query().yield_per(batch_size) + except Exception as e: + # exception was raised. Rollback the session + self._get_session().rollback() + raise e + def _all(self): - return self.get_query().all() + try: + return self.get_query().all() + except Exception as e: + # exception was raised. Rollback the session + self._get_session().rollback() + raise e def _first(self): """ @@ -1937,7 +2007,13 @@ def _first(self): :returns: One row of aiida results """ - return self.get_query().first() + try: + return self.get_query().first() + except Exception as e: + # exception was raised. Rollback the session + self._get_session().rollback() + raise e + def first(self): """ diff --git a/aiida/backends/querybuild/querybuilder_django.py b/aiida/backends/querybuild/querybuilder_django.py index c3af692646..acea1b9f77 100644 --- a/aiida/backends/querybuild/querybuilder_django.py +++ b/aiida/backends/querybuild/querybuilder_django.py @@ -290,7 +290,6 @@ def _get_projectable_attribute( attrkey = '.'.join(attrpath) - exists_stmt = exists(select([1], correlate=True).select_from( aliased_attributes ).where(and_( diff --git a/aiida/backends/querybuild/querybuilder_sa.py b/aiida/backends/querybuild/querybuilder_sa.py index 9cd4cd1b37..7cf87202e9 100644 --- a/aiida/backends/querybuild/querybuilder_sa.py +++ b/aiida/backends/querybuild/querybuilder_sa.py @@ -139,7 +139,7 @@ def cast_according_to_type(path_in_json, value): elif operator in ('>=', '=>'): type_filter, casted_entity = cast_according_to_type(database_entity, value) expr = and_(type_filter, casted_entity >= value) - elif operator == ('<=', '=<'): + elif operator in ('<=', '=<'): type_filter, casted_entity = cast_according_to_type(database_entity, value) expr = and_(type_filter, casted_entity <= value) elif operator == 'of_type': diff --git a/aiida/backends/sqlalchemy/models/settings.py b/aiida/backends/sqlalchemy/models/settings.py index 023c1ecb4d..3520925411 100644 --- a/aiida/backends/sqlalchemy/models/settings.py +++ b/aiida/backends/sqlalchemy/models/settings.py @@ -6,7 +6,7 @@ from sqlalchemy.schema import UniqueConstraint from sqlalchemy.types import Integer, String, DateTime -from aiida.backends.sqlalchemy import session +import aiida.backends.sqlalchemy as sa from aiida.backends.sqlalchemy.models.base import Base from aiida.utils import timezone @@ -37,7 +37,7 @@ def set_value(cls, key, value, with_transaction=True, subspecifier_value=None, other_attribs={}, stop_if_existing=False): - setting = session.query(DbSetting).filter_by(key=key).first() + setting = sa.session.query(DbSetting).filter_by(key=key).first() if setting is not None: if stop_if_existing: return @@ -67,7 +67,7 @@ def get_description(self): @classmethod def del_value(cls, key, only_children=False, subspecifier_value=None): - setting = session.query(DbSetting).filter(key=key) + setting = sa.session.query(DbSetting).filter(key=key) setting.val = None setting.time = timezone.datetime.utcnow() flag_modified(setting, "val") diff --git a/aiida/backends/sqlalchemy/tests/query.py b/aiida/backends/sqlalchemy/tests/query.py index 31ddb0f4db..5d6aebdd2e 100644 --- a/aiida/backends/sqlalchemy/tests/query.py +++ b/aiida/backends/sqlalchemy/tests/query.py @@ -1,4 +1,4 @@ -from aiida.backends.tests.query import TestQueryBuilder +from aiida.backends.tests.query import TestQueryBuilder, QueryBuilderJoinsTests from aiida.backends.sqlalchemy.tests.testbase import SqlAlchemyTests @@ -23,3 +23,8 @@ def test_clsf_sqla(self): self.assertEqual(cls, ORMCls) self.assertEqual(query_type_string, typestr) + + + +class QueryBuilderJoinsTestsSQLA(SqlAlchemyTests, QueryBuilderJoinsTests): + pass diff --git a/aiida/backends/sqlalchemy/tests/test_runner.py b/aiida/backends/sqlalchemy/tests/test_runner.py index 557d9427bb..c445bdafb2 100644 --- a/aiida/backends/sqlalchemy/tests/test_runner.py +++ b/aiida/backends/sqlalchemy/tests/test_runner.py @@ -21,7 +21,7 @@ def find_classes(module_str): def run_tests(): modules_str = [ - # "aiida.backends.sqlalchemy.tests.query", + # "aiida.backends.sqlalchemy.tests.query", # "aiida.backends.sqlalchemy.tests.nodes", # "aiida.backends.sqlalchemy.tests.backup_script", # "aiida.backends.sqlalchemy.tests.export_and_import", diff --git a/aiida/backends/tests/query.py b/aiida/backends/tests/query.py index f289543426..1dc2c7d75b 100644 --- a/aiida/backends/tests/query.py +++ b/aiida/backends/tests/query.py @@ -8,17 +8,8 @@ __authors__ = "The AiiDA team." __version__ = "0.7.0" -def is_postgres(): - from aiida.backends import settings - from aiida.common.setup import get_profile_config - profile_conf = get_profile_config(settings.AIIDADB_PROFILE) - return profile_conf['AIIDADB_ENGINE'] == 'postgresql_psycopg2' -def is_django(): - from aiida.backends import settings - return settings.BACKEND == 'django' - class TestQueryBuilder(): def test_classification(self): @@ -266,3 +257,118 @@ def test_simple_query_2(self): self.assertTrue(id(query1) != id(query2)) self.assertTrue(id(query2) == id(query3)) + + + def test_operators_eq_lt_gt(self): + from aiida.orm.querybuilder import QueryBuilder + from aiida.orm import Node + + + nodes = [Node() for _ in range(8)] + + + nodes[0]._set_attr('fa', 1) + nodes[1]._set_attr('fa', 1.0) + nodes[2]._set_attr('fa', 1.01) + nodes[3]._set_attr('fa', 1.02) + nodes[4]._set_attr('fa', 1.03) + nodes[5]._set_attr('fa', 1.04) + nodes[6]._set_attr('fa', 1.05) + nodes[7]._set_attr('fa', 1.06) + + [n.store() for n in nodes] + + self.assertEqual(QueryBuilder().append(Node, filters={'attributes.fa':{'<':1}}).count(), 0) + self.assertEqual(QueryBuilder().append(Node, filters={'attributes.fa':{'==':1}}).count(), 2) + self.assertEqual(QueryBuilder().append(Node, filters={'attributes.fa':{'<':1.02}}).count(), 3) + self.assertEqual(QueryBuilder().append(Node, filters={'attributes.fa':{'<=':1.02}}).count(), 4) + self.assertEqual(QueryBuilder().append(Node, filters={'attributes.fa':{'>':1.02}}).count(), 4) + self.assertEqual(QueryBuilder().append(Node, filters={'attributes.fa':{'>=':1.02}}).count(), 5) + + + + + +class QueryBuilderJoinsTests(): + def test_joins1(self): + from aiida.orm import Node, Data, Calculation + from aiida.orm.querybuilder import QueryBuilder + # Creating n1, who will be a parent: + parent=Node() + parent.label = 'mother' + + good_child=Node() + good_child.label='good_child' + good_child._set_attr('is_good', True) + + bad_child=Node() + bad_child.label='bad_child' + bad_child._set_attr('is_good', False) + + unrelated = Node() + unrelated.label = 'unrelated' + + for n in (good_child, bad_child, parent, unrelated): + n.store() + + good_child.add_link_from(parent, label='parent') + bad_child.add_link_from(parent, label='parent') + + # Using a standard inner join + qb = QueryBuilder() + qb.append(Node, tag='parent') + qb.append(Node, tag='children', project='label', filters={'attributes.is_good':True}) + self.assertEqual(qb.count(), 1) + + + qb = QueryBuilder() + qb.append(Node, tag='parent') + qb.append(Node, tag='children', outerjoin=True, project='label', filters={'attributes.is_good':True}) + self.assertEqual(qb.count(), 1) + + def test_joins2(self): + from aiida.orm import Node, Data, Calculation + from aiida.orm.querybuilder import QueryBuilder + # Creating n1, who will be a parent: + + students = [Node() for i in range(10)] + advisors = [Node() for i in range(3)] + for i, a in enumerate(advisors): + a.label = 'advisor {}'.format(i) + a._set_attr('advisor_id', i) + + for n in advisors+students: + n.store() + + + # advisor 0 get student 0, 1 + for i in (0,1): + students[i].add_link_from(advisors[0], label='is_advisor') + + # advisor 1 get student 3, 4 + for i in (3,4): + students[i].add_link_from(advisors[1], label='is_advisor') + + # advisor 2 get student 5, 6, 7 + for i in (5,6,7): + students[i].add_link_from(advisors[2], label='is_advisor') + + # let's add a differnt relationship than advisor: + students[9].add_link_from(advisors[2], label='lover') + + + self.assertEqual( + QueryBuilder().append( + Node + ).append( + Node, edge_filters={'label':'is_advisor'}, tag='student' + ).count(), 7) + + for adv_id, number_students in zip(range(3), (2,2,3)): + self.assertEqual(QueryBuilder().append( + Node, filters={'attributes.advisor_id':adv_id} + ).append( + Node, edge_filters={'label':'is_advisor'}, tag='student' + ).count(), number_students) + +