From 2f2bdc38191632f30b2c0be4f3445190dc896d90 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 22 Oct 2021 11:24:22 +0200 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Ensure=20`QueryBu?= =?UTF-8?q?ilder`=20is=20passed=20`Backend`=20(#5186)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR ensures core code always calls `QueryBuilder` with a specific `Backend`, as opposed to assuming the loaded `Backend`. This will allow for muliple backends to be used at the same time (for example export archives), for features including graph traversal and visualisation. --- aiida/cmdline/utils/common.py | 8 +++-- aiida/orm/implementation/django/comments.py | 2 +- aiida/orm/implementation/django/logs.py | 2 +- .../orm/implementation/sqlalchemy/comments.py | 2 +- aiida/orm/implementation/sqlalchemy/logs.py | 2 +- aiida/orm/nodes/data/array/bands.py | 4 +-- aiida/orm/nodes/data/cif.py | 4 +-- aiida/orm/nodes/data/code.py | 8 ++--- aiida/orm/nodes/data/upf.py | 14 ++++----- aiida/orm/nodes/node.py | 8 ++--- aiida/orm/querybuilder.py | 11 +++++-- aiida/orm/utils/links.py | 8 ++--- aiida/orm/utils/remote.py | 6 ++-- aiida/tools/graph/age_entities.py | 11 ------- aiida/tools/graph/age_rules.py | 7 +++-- aiida/tools/graph/deletions.py | 20 ++++++------ aiida/tools/graph/graph_traversers.py | 24 +++++++++----- aiida/tools/visualization/graph.py | 31 +++++++++++++++---- 18 files changed, 101 insertions(+), 71 deletions(-) diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index 36fa393170..9c57980ab9 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -11,11 +11,15 @@ import logging import os import sys +from typing import TYPE_CHECKING from tabulate import tabulate from . import echo +if TYPE_CHECKING: + from aiida.orm import WorkChainNode + __all__ = ('is_verbose',) @@ -306,7 +310,7 @@ def get_process_function_report(node): return '\n'.join(report) -def get_workchain_report(node, levelname, indent_size=4, max_depth=None): +def get_workchain_report(node: 'WorkChainNode', levelname, indent_size=4, max_depth=None): """ Return a multi line string representation of the log messages and output of a given workchain @@ -333,7 +337,7 @@ def get_subtree(uuid, level=0): Get a nested tree of work calculation nodes and their nesting level starting from this uuid. The result is a list of uuid of these nodes. """ - builder = orm.QueryBuilder() + builder = orm.QueryBuilder(backend=node.backend) builder.append(cls=orm.WorkChainNode, filters={'uuid': uuid}, tag='workcalculation') builder.append( cls=orm.WorkChainNode, diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py index be7fe71b9d..ab874b6b52 100644 --- a/aiida/orm/implementation/django/comments.py +++ b/aiida/orm/implementation/django/comments.py @@ -168,7 +168,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Comment, filters=filters, project='id').all() + builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id').all() entities_to_delete = [_[0] for _ in builder] for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/django/logs.py b/aiida/orm/implementation/django/logs.py index 7b3b725c2c..4ddd8fe10f 100644 --- a/aiida/orm/implementation/django/logs.py +++ b/aiida/orm/implementation/django/logs.py @@ -144,7 +144,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Log, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/orm/implementation/sqlalchemy/comments.py index da100140dd..618aa021bf 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/orm/implementation/sqlalchemy/comments.py @@ -171,7 +171,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Comment, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/sqlalchemy/logs.py b/aiida/orm/implementation/sqlalchemy/logs.py index b4d75ad6ac..62a973171d 100644 --- a/aiida/orm/implementation/sqlalchemy/logs.py +++ b/aiida/orm/implementation/sqlalchemy/logs.py @@ -153,7 +153,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filter must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Log, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 83484b0983..ec6adffeb3 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -1803,7 +1803,7 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""") -def get_bands_and_parents_structure(args): +def get_bands_and_parents_structure(args, backend=None): """Search for bands and return bands and the closest structure that is a parent of the instance. :returns: @@ -1817,7 +1817,7 @@ def get_bands_and_parents_structure(args): from aiida import orm from aiida.common import timezone - q_build = orm.QueryBuilder() + q_build = orm.QueryBuilder(backend=backend) if args.all_users is False: q_build.append(orm.User, tag='creator', filters={'email': orm.User.objects.get_default().email}) else: diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 5b0696e103..f0278d0724 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -329,7 +329,7 @@ def read_cif(fileobj, index=-1, **kwargs): return struct_list[index] @classmethod - def from_md5(cls, md5): + def from_md5(cls, md5, backend=None): """ Return a list of all CIF files that match a given MD5 hash. @@ -337,7 +337,7 @@ def from_md5(cls, md5): otherwise the CIF file will not be found. """ from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) diff --git a/aiida/orm/nodes/data/code.py b/aiida/orm/nodes/data/code.py index 9bd8787b13..c936e0ef16 100644 --- a/aiida/orm/nodes/data/code.py +++ b/aiida/orm/nodes/data/code.py @@ -151,7 +151,7 @@ def get_description(self): return f'{self.description}' @classmethod - def get_code_helper(cls, label, machinename=None): + def get_code_helper(cls, label, machinename=None, backend=None): """ :param label: the code label identifying the code to load :param machinename: the machine name where code is setup @@ -164,7 +164,7 @@ def get_code_helper(cls, label, machinename=None): from aiida.orm.computers import Computer from aiida.orm.querybuilder import QueryBuilder - query = QueryBuilder() + query = QueryBuilder(backend=backend) query.append(cls, filters={'label': label}, project='*', tag='code') if machinename: query.append(Computer, filters={'label': machinename}, with_node='code') @@ -249,7 +249,7 @@ def get_from_string(cls, code_string): raise MultipleObjectsError(f'{code_string} could not be uniquely resolved') @classmethod - def list_for_plugin(cls, plugin, labels=True): + def list_for_plugin(cls, plugin, labels=True, backend=None): """ Return a list of valid code strings for a given plugin. @@ -260,7 +260,7 @@ def list_for_plugin(cls, plugin, labels=True): otherwise a list of integers with the code PKs. """ from aiida.orm.querybuilder import QueryBuilder - query = QueryBuilder() + query = QueryBuilder(backend=backend) query.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) valid_codes = query.all(flat=True) diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index 1ad082dd37..b212327ba2 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -70,7 +70,7 @@ def get_pseudos_from_structure(structure, family_name): return pseudo_list -def upload_upf_family(folder, group_label, group_description, stop_if_existing=True): +def upload_upf_family(folder, group_label, group_description, stop_if_existing=True, backend=None): """Upload a set of UPF files in a given group. :param folder: a path containing all UPF files to be added. @@ -120,7 +120,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T for filename in filenames: md5sum = md5_file(filename) - builder = orm.QueryBuilder() + builder = orm.QueryBuilder(backend=backend) builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}}) existing_upf = builder.first() @@ -321,7 +321,7 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs return super().store(*args, **kwargs) @classmethod - def from_md5(cls, md5): + def from_md5(cls, md5, backend=None): """Return a list of all `UpfData` that match the given md5 hash. .. note:: assumes hash of stored `UpfData` nodes is stored in the `md5` attribute @@ -330,7 +330,7 @@ def from_md5(cls, md5): :return: list of existing `UpfData` nodes that have the same md5 hash """ from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) @@ -366,7 +366,7 @@ def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" from aiida.orm import QueryBuilder, UpfFamily - query = QueryBuilder() + query = QueryBuilder(backend=self.backend) query.append(UpfFamily, tag='group', project='label') query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group') return query.all(flat=True) @@ -448,7 +448,7 @@ def get_upf_group(cls, group_label): return UpfFamily.get(label=group_label) @classmethod - def get_upf_groups(cls, filter_elements=None, user=None): + def get_upf_groups(cls, filter_elements=None, user=None, backend=None): """Return all names of groups of type UpfFamily, possibly with some filters. :param filter_elements: A string or a list of strings. @@ -460,7 +460,7 @@ def get_upf_groups(cls, filter_elements=None, user=None): """ from aiida.orm import QueryBuilder, UpfFamily, User - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(UpfFamily, tag='group', project='*') if user: diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index 723705f109..c8694d2405 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -456,11 +456,11 @@ def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str """ from aiida.orm.utils.links import validate_link - validate_link(source, self, link_type, link_label) + validate_link(source, self, link_type, link_label, backend=self.backend) # Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]: - builder = QueryBuilder().append( + builder = QueryBuilder(backend=self.backend).append( Node, filters={'id': self.pk}, tag='parent').append( Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable if builder.count() > 0: @@ -537,7 +537,7 @@ def get_stored_link_triples( if link_label_filter: edge_filters['label'] = {'like': link_label_filter} - builder = QueryBuilder() + builder = QueryBuilder(backend=self.backend) builder.append(Node, filters=node_filters, tag='main') node_project = ['uuid'] if only_uuid else ['*'] @@ -894,7 +894,7 @@ def _iter_all_same_nodes(self, allow_before_store=False) -> Iterator['Node']: if not node_hash or not self._cachable: return iter(()) - builder = QueryBuilder() + builder = QueryBuilder(backend=self.backend) builder.append(self.__class__, filters={'extras._aiida_hash': node_hash}, project='*', subclassing=False) nodes_identical = (n[0] for n in builder.iterall()) diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 82444e5d55..d3c04ebc56 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -136,8 +136,8 @@ def __init__( :param distinct: Whether to return de-duplicated rows """ - backend = backend or get_manager().get_backend() - self._impl: BackendQueryBuilder = backend.query() + self._backend = backend or get_manager().get_backend() + self._impl: BackendQueryBuilder = self._backend.query() # SERIALISABLE ATTRIBUTES # A list storing the path being traversed by the query @@ -189,6 +189,11 @@ def __init__( if order_by: self.order_by(order_by) + @property + def backend(self) -> 'Backend': + """Return the backend used by the QueryBuilder.""" + return self._backend + def as_dict(self, copy: bool = True) -> QueryDictType: """Convert to a JSON serialisable dictionary representation of the query.""" data: QueryDictType = { @@ -225,7 +230,7 @@ def __str__(self) -> str: def __deepcopy__(self, memo) -> 'QueryBuilder': """Create deep copy of the instance.""" - return type(self)(**self.as_dict()) # type: ignore + return type(self)(backend=self.backend, **self.as_dict()) # type: ignore def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: """Returns a list of all the vertices that are being used. diff --git a/aiida/orm/utils/links.py b/aiida/orm/utils/links.py index 535ca0caa5..f79667777f 100644 --- a/aiida/orm/utils/links.py +++ b/aiida/orm/utils/links.py @@ -21,7 +21,7 @@ LinkQuadruple = namedtuple('LinkQuadruple', ['source_id', 'target_id', 'link_type', 'link_label']) -def link_triple_exists(source, target, link_type, link_label): +def link_triple_exists(source, target, link_type, link_label, backend=None): """Return whether a link with the given type and label exists between the given source and target node. :param source: node from which the link is outgoing @@ -42,7 +42,7 @@ def link_triple_exists(source, target, link_type, link_label): # Here we have two stored nodes, so we need to check if the same link already exists in the database. # Finding just a single match is sufficient so we can use the `limit` clause for efficiency - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(Node, filters={'id': source.id}, project=['id']) builder.append(Node, filters={'id': target.id}, edge_filters={'type': link_type.value, 'label': link_label}) builder.limit(1) @@ -50,7 +50,7 @@ def link_triple_exists(source, target, link_type, link_label): return builder.count() != 0 -def validate_link(source, target, link_type, link_label): +def validate_link(source, target, link_type, link_label, backend=None): """ Validate adding a link of the given type and label from a given node to ourself. @@ -153,7 +153,7 @@ def validate_link(source, target, link_type, link_label): if outdegree == 'unique_triple' or indegree == 'unique_triple': # For a `unique_triple` degree we just have to check if an identical triple already exist, either in the cache # or stored, in which case, the new proposed link is a duplicate and thus illegal - duplicate_link_triple = link_triple_exists(source, target, link_type, link_label) + duplicate_link_triple = link_triple_exists(source, target, link_type, link_label, backend) # If the outdegree is `unique` there cannot already be any other outgoing link of that type if outdegree == 'unique' and source.get_outgoing(link_type=link_type, only_uuid=True).all(): diff --git a/aiida/orm/utils/remote.py b/aiida/orm/utils/remote.py index 71f3e339d3..deb40ab874 100644 --- a/aiida/orm/utils/remote.py +++ b/aiida/orm/utils/remote.py @@ -37,13 +37,13 @@ def clean_remote(transport, path): pass -def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None): +def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None, backend=None): """ Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of calcjobs will be determined by a query with filters based on the pks, past_days, older_than, computers and user arguments. - :param pks: onlu include calcjobs with a pk in this list + :param pks: only include calcjobs with a pk in this list :param past_days: only include calcjobs created since past_days :param older_than: only include calcjobs older than :param computers: only include calcjobs that were ran on these computers @@ -74,7 +74,7 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer if pks: filters_calc['id'] = {'in': pks} - query = orm.QueryBuilder() + query = orm.QueryBuilder(backend=backend) query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) query.append(orm.User, with_node='calc', filters={'email': user.email}) diff --git a/aiida/tools/graph/age_entities.py b/aiida/tools/graph/age_entities.py index a729d58ee7..de5fcec0d1 100644 --- a/aiida/tools/graph/age_entities.py +++ b/aiida/tools/graph/age_entities.py @@ -225,17 +225,6 @@ def aiida_cls(self): """Class of nodes contained in the entity set (node or group)""" return self._aiida_cls - def get_entities(self): - """Iterator that returns the AiiDA entities""" - for entity, in orm.QueryBuilder().append( - self._aiida_cls, project='*', filters={ - self._identifier: { - 'in': self.keyset - } - } - ).iterall(): - yield entity - class DirectedEdgeSet(AbstractSetContainer): """Extension of AbstractSetContainer diff --git a/aiida/tools/graph/age_rules.py b/aiida/tools/graph/age_rules.py index f768d8d07a..973d334909 100644 --- a/aiida/tools/graph/age_rules.py +++ b/aiida/tools/graph/age_rules.py @@ -11,6 +11,7 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict +from copy import deepcopy import numpy as np @@ -65,7 +66,7 @@ class QueryRule(Operation, metaclass=ABCMeta): found in the last iteration of the query (ReplaceRule). """ - def __init__(self, querybuilder, max_iterations=1, track_edges=False): + def __init__(self, querybuilder: orm.QueryBuilder, max_iterations=1, track_edges=False): """Initialization method :param querybuilder: an instance of the QueryBuilder class from which to take the @@ -107,7 +108,7 @@ def get_spec_from_path(query_dict, idx): for pathspec in query_dict['path']: if not pathspec['entity_type']: pathspec['entity_type'] = 'node.Node.' - self._qbtemplate = orm.QueryBuilder(**query_dict) + self._qbtemplate = deepcopy(querybuilder) query_dict = self._qbtemplate.as_dict() self._first_tag = query_dict['path'][0]['tag'] self._last_tag = query_dict['path'][-1]['tag'] @@ -163,7 +164,7 @@ def _init_run(self, operational_set): # Copying qbtemplate so there's no problem if it is used again in a later run: query_dict = self._qbtemplate.as_dict() - self._querybuilder = orm.QueryBuilder.from_dict(query_dict) + self._querybuilder = deepcopy(self._qbtemplate) self._entity_to_identifier = operational_set[self._entity_to].identifier diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index d14d9c7dd5..61e0454f1d 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -71,17 +71,18 @@ def _missing_callback(_pks: Iterable[int]): for _pk in _pks: DELETE_LOGGER.warning(f'warning: node with pk<{_pk}> does not exist, skipping') - pks_set_to_delete = get_nodes_delete(pks, get_links=False, missing_callback=_missing_callback, - **traversal_rules)['nodes'] + pks_set_to_delete = get_nodes_delete( + pks, get_links=False, missing_callback=_missing_callback, backend=backend, **traversal_rules + )['nodes'] DELETE_LOGGER.report('%s Node(s) marked for deletion', len(pks_set_to_delete)) if pks_set_to_delete and DELETE_LOGGER.level == logging.DEBUG: - builder = QueryBuilder().append( - Node, filters={'id': { - 'in': pks_set_to_delete - }}, project=('uuid', 'id', 'node_type', 'label') - ) + builder = QueryBuilder( + backend=backend + ).append(Node, filters={'id': { + 'in': pks_set_to_delete + }}, project=('uuid', 'id', 'node_type', 'label')) DELETE_LOGGER.debug('Node(s) to delete:') for uuid, pk, type_string, label in builder.iterall(): try: @@ -113,6 +114,7 @@ def _missing_callback(_pks: Iterable[int]): def delete_group_nodes( pks: Iterable[int], dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + backend=None, **traversal_rules: bool ) -> Tuple[Set[int], bool]: """Delete nodes contained in a list of groups (not the groups themselves!). @@ -149,7 +151,7 @@ def delete_group_nodes( :returns: (node pks to delete, whether they were deleted) """ - group_node_query = QueryBuilder().append( + group_node_query = QueryBuilder(backend=backend).append( Group, filters={ 'id': { @@ -160,4 +162,4 @@ def delete_group_nodes( ).append(Node, project='id', with_group='groups') group_node_query.distinct() node_pks = group_node_query.all(flat=True) - return delete_nodes(node_pks, dry_run=dry_run, **traversal_rules) + return delete_nodes(node_pks, dry_run=dry_run, backend=backend, **traversal_rules) diff --git a/aiida/tools/graph/graph_traversers.py b/aiida/tools/graph/graph_traversers.py index 6468ead76e..8f9f0c0f6d 100644 --- a/aiida/tools/graph/graph_traversers.py +++ b/aiida/tools/graph/graph_traversers.py @@ -9,7 +9,7 @@ ########################################################################### """Module for functions to traverse AiiDA graphs.""" import sys -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, cast from numpy import inf @@ -20,6 +20,9 @@ from aiida.tools.graph.age_entities import Basket from aiida.tools.graph.age_rules import RuleSaveWalkers, RuleSequence, RuleSetWalkers, UpdateRule +if TYPE_CHECKING: + from aiida.orm.implementation import Backend + if sys.version_info >= (3, 8): from typing import TypedDict @@ -35,6 +38,7 @@ def get_nodes_delete( starting_pks: Iterable[int], get_links: bool = False, missing_callback: Optional[Callable[[Iterable[int]], None]] = None, + backend: Optional['Backend'] = None, **traversal_rules: bool ) -> TraverseGraphOutput: """ @@ -59,9 +63,10 @@ def get_nodes_delete( traverse_output = traverse_graph( starting_pks, get_links=get_links, + backend=backend, links_forward=traverse_links['forward'], links_backward=traverse_links['backward'], - missing_callback=missing_callback + missing_callback=missing_callback, ) function_output = { @@ -74,7 +79,10 @@ def get_nodes_delete( def get_nodes_export( - starting_pks: Iterable[int], get_links: bool = False, **traversal_rules: bool + starting_pks: Iterable[int], + get_links: bool = False, + backend: Optional['Backend'] = None, + **traversal_rules: bool ) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected @@ -99,6 +107,7 @@ def get_nodes_export( traverse_output = traverse_graph( starting_pks, get_links=get_links, + backend=backend, links_forward=traverse_links['forward'], links_backward=traverse_links['backward'] ) @@ -186,7 +195,8 @@ def traverse_graph( get_links: bool = False, links_forward: Iterable[LinkType] = (), links_backward: Iterable[LinkType] = (), - missing_callback: Optional[Callable[[Iterable[int]], None]] = None + missing_callback: Optional[Callable[[Iterable[int]], None]] = None, + backend: Optional['Backend'] = None ) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected @@ -239,7 +249,7 @@ def traverse_graph( return {'nodes': set(), 'links': set()} return {'nodes': set(), 'links': None} - query_nodes = orm.QueryBuilder() + query_nodes = orm.QueryBuilder(backend=backend) query_nodes.append(orm.Node, project=['id'], filters={'id': {'in': operational_set}}) existing_pks = set(query_nodes.all(flat=True)) missing_pks = operational_set.difference(existing_pks) @@ -266,7 +276,7 @@ def traverse_graph( rules += [RuleSaveWalkers(stash)] if links_forward: - query_outgoing = orm.QueryBuilder() + query_outgoing = orm.QueryBuilder(backend=backend) query_outgoing.append(orm.Node, tag='sources') query_outgoing.append(orm.Node, edge_filters=filters_forwards, with_incoming='sources') rule_outgoing = UpdateRule(query_outgoing, max_iterations=1, track_edges=get_links) @@ -276,7 +286,7 @@ def traverse_graph( rules += [RuleSetWalkers(stash)] if links_backward: - query_incoming = orm.QueryBuilder() + query_incoming = orm.QueryBuilder(backend=backend) query_incoming.append(orm.Node, tag='sources') query_incoming.append(orm.Node, edge_filters=filters_backwards, with_outgoing='sources') rule_incoming = UpdateRule(query_incoming, max_iterations=1, track_edges=get_links) diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index 2793ce9ce0..b864ad28bb 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -10,17 +10,21 @@ """ provides functionality to create graphs of the AiiDa data providence, *via* graphviz. """ - import os from types import MappingProxyType # pylint: disable=no-name-in-module,useless-suppression +from typing import TYPE_CHECKING, Optional from graphviz import Digraph from aiida import orm from aiida.common import LinkType +from aiida.manage.manager import get_manager from aiida.orm.utils.links import LinkPair from aiida.tools.graph.graph_traversers import traverse_graph +if TYPE_CHECKING: + from aiida.orm.implementation import Backend + __all__ = ('Graph', 'default_link_styles', 'default_node_styles', 'pstate_node_styles', 'default_node_sublabels') @@ -359,7 +363,8 @@ def __init__( link_style_fn=None, node_style_fn=None, node_sublabel_fn=None, - node_id_type='pk' + node_id_type='pk', + backend: Optional['Backend'] = None ): """a class to create graphviz graphs of the AiiDA node provenance @@ -398,10 +403,16 @@ def __init__( self._node_styles = node_style_fn or default_node_styles self._node_sublabels = node_sublabel_fn or default_node_sublabels self._node_id_type = node_id_type + self._backend = backend or get_manager().get_backend() self._ignore_node_style = _OVERRIDE_STYLES_DICT['ignore_node'] self._origin_node_style = _OVERRIDE_STYLES_DICT['origin_node'] + @property + def backend(self) -> 'Backend': + """The backend used to create the graph""" + return self._backend + @property def graphviz(self): """return a copy of the graphviz.Digraph""" @@ -539,10 +550,11 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True (node_pk,), max_iterations=1, get_links=True, + backend=self.backend, links_backward=valid_link_types, ) - traversed_nodes = orm.QueryBuilder().append( + traversed_nodes = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -595,10 +607,11 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True (node_pk,), max_iterations=1, get_links=True, + backend=self.backend, links_forward=valid_link_types, ) - traversed_nodes = orm.QueryBuilder().append( + traversed_nodes = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -664,6 +677,7 @@ def recurse_descendants( (origin_pk,), max_iterations=depth, get_links=True, + backend=self.backend, links_forward=valid_link_types, ) @@ -674,13 +688,14 @@ def recurse_descendants( traversed_graph['nodes'], max_iterations=1, get_links=True, + backend=self.backend, links_backward=[LinkType.INPUT_WORK, LinkType.INPUT_CALC] ) traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary - traversed_nodes = orm.QueryBuilder().append( + traversed_nodes = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -755,6 +770,7 @@ def recurse_ancestors( (origin_pk,), max_iterations=depth, get_links=True, + backend=self.backend, links_backward=valid_link_types, ) @@ -765,13 +781,14 @@ def recurse_ancestors( traversed_graph['nodes'], max_iterations=1, get_links=True, + backend=self.backend, links_forward=[LinkType.CREATE, LinkType.RETURN] ) traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary - traversed_nodes = orm.QueryBuilder().append( + traversed_nodes = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -842,6 +859,7 @@ def add_origin_to_targets( self.add_node(origin_node, style_override=dict(origin_style)) query = orm.QueryBuilder( + backend=self.backend, **{ 'path': [{ 'cls': origin_node.__class__, @@ -902,6 +920,7 @@ def add_origins_to_targets( origin_filters = {} query = orm.QueryBuilder( + backend=self.backend, **{'path': [{ 'cls': origin_cls, 'filters': origin_filters, From a80812c30ebb0de12f0429954508f9ceb6120d34 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sat, 23 Oct 2021 09:14:24 +0200 Subject: [PATCH 2/5] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20Delegate?= =?UTF-8?q?=20`RepositoryBackend`=20control=20to=20the=20`Backend`=20(#516?= =?UTF-8?q?9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit moves `Profile.get_repository() -> Repository`, to `Backend.get_repository() -> BackendRepository`. The original delegation, (1) did not allow different backends to implement different repository classes, and (2) made it impossible to implement/instantiate a `Backend` that was not directly dependent on the currently loaded `Profile`. Additionally, the commit improves the docstring of the `Backend`, and fixes a minor issue, that `verdi profile delete` left an empty repository folder for that profile. --- .../db/migrations/0047_migrate_repository.py | 4 +- .../1feaea71bd5a_migrate_repository.py | 4 +- aiida/backends/testbase.py | 2 +- aiida/cmdline/commands/cmd_setup.py | 2 +- aiida/cmdline/commands/cmd_status.py | 2 +- aiida/manage/configuration/config.py | 5 ++- aiida/manage/configuration/profile.py | 10 ----- aiida/manage/manager.py | 24 ++++++------ aiida/orm/implementation/backends.py | 16 +++++++- aiida/orm/implementation/sql/backends.py | 20 ++++++++-- aiida/orm/nodes/node.py | 6 +-- aiida/orm/nodes/repository.py | 3 +- aiida/repository/repository.py | 39 ++++++++++--------- aiida/tools/importexport/dbexport/main.py | 5 +-- .../importexport/dbimport/backends/common.py | 5 +-- docs/source/nitpick-exceptions | 2 + tests/cmdline/commands/test_setup.py | 4 +- tests/orm/data/test_array.py | 2 +- tests/repository/test_repository.py | 28 ++----------- tests/tools/importexport/test_repository.py | 2 +- 20 files changed, 95 insertions(+), 90 deletions(-) diff --git a/aiida/backends/djsite/db/migrations/0047_migrate_repository.py b/aiida/backends/djsite/db/migrations/0047_migrate_repository.py index 14b11695be..32c5dbd7ee 100644 --- a/aiida/backends/djsite/db/migrations/0047_migrate_repository.py +++ b/aiida/backends/djsite/db/migrations/0047_migrate_repository.py @@ -35,10 +35,12 @@ def migrate_repository(apps, schema_editor): from aiida.common import exceptions from aiida.common.progress_reporter import get_progress_reporter, set_progress_bar_tqdm, set_progress_reporter from aiida.manage.configuration import get_profile + from aiida.manage.manager import get_manager DbNode = apps.get_model('db', 'DbNode') profile = get_profile() + backend = get_manager().get_backend() node_count = DbNode.objects.count() missing_node_uuids = [] missing_repo_folder = [] @@ -107,7 +109,7 @@ def migrate_repository(apps, schema_editor): # Store the UUID of the repository container in the `DbSetting` table. Note that for new databases, the profile # setup will already have stored the UUID and so it should be skipped, or an exception for a duplicate key will be # raised. This migration step is only necessary for existing databases that are migrated. - container_id = profile.get_repository().uuid + container_id = backend.get_repository().uuid with schema_editor.connection.cursor() as cursor: cursor.execute( f""" diff --git a/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py b/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py index ab234d022d..cc33aac2ca 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py @@ -35,6 +35,7 @@ def upgrade(): from aiida.common import exceptions from aiida.common.progress_reporter import get_progress_reporter, set_progress_bar_tqdm, set_progress_reporter from aiida.manage.configuration import get_profile + from aiida.manage.manager import get_manager connection = op.get_bind() @@ -46,6 +47,7 @@ def upgrade(): ) profile = get_profile() + backend = get_manager().get_backend() node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() missing_repo_folder = [] shard_count = 256 @@ -106,7 +108,7 @@ def upgrade(): # Store the UUID of the repository container in the `DbSetting` table. Note that for new databases, the profile # setup will already have stored the UUID and so it should be skipped, or an exception for a duplicate key will be # raised. This migration step is only necessary for existing databases that are migrated. - container_id = profile.get_repository().uuid + container_id = backend.get_repository().uuid statement = text( f""" INSERT INTO db_dbsetting (key, val, description) diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index f6307df575..8fd5bc5a51 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -142,7 +142,7 @@ def initialise_repository(cls): """Initialise the repository""" from aiida.manage.configuration import get_profile profile = get_profile() - repository = profile.get_repository() + repository = cls.backend.get_repository() repository.initialise(clear=True, **profile.defaults['repository']) @classmethod diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index 28f52ef8ad..63b58c0baf 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -94,7 +94,7 @@ def setup( # with that UUID and we have to make sure that the provided repository corresponds to it. backend_manager = manager.get_backend_manager() repository_uuid_database = backend_manager.get_repository_uuid() - repository_uuid_profile = profile.get_repository().uuid + repository_uuid_profile = backend.get_repository().uuid if repository_uuid_database != repository_uuid_profile: echo.echo_critical( diff --git a/aiida/cmdline/commands/cmd_status.py b/aiida/cmdline/commands/cmd_status.py index 0117c15f14..1180c7cd0c 100644 --- a/aiida/cmdline/commands/cmd_status.py +++ b/aiida/cmdline/commands/cmd_status.py @@ -85,7 +85,7 @@ def verdi_status(print_traceback, no_rmq): # Getting the repository try: - repository = profile.get_repository() + repository = manager.get_backend().get_repository() except Exception as exc: message = 'Error with repository folder' print_status(ServiceStatus.ERROR, 'repository', message, exception=exc, print_traceback=print_traceback) diff --git a/aiida/manage/configuration/config.py b/aiida/manage/configuration/config.py index 5d6de33c90..efba5cecce 100644 --- a/aiida/manage/configuration/config.py +++ b/aiida/manage/configuration/config.py @@ -361,8 +361,9 @@ def delete_profile( profile = self.get_profile(name) if include_repository: - repository = profile.get_repository() - repository.delete() + folder = profile.repository_path + if folder.exists(): + shutil.rmtree(folder) if include_database: postgres = Postgres.from_profile(profile) diff --git a/aiida/manage/configuration/profile.py b/aiida/manage/configuration/profile.py index 1dee9e8c89..cfa2673a98 100644 --- a/aiida/manage/configuration/profile.py +++ b/aiida/manage/configuration/profile.py @@ -128,16 +128,6 @@ def __init__(self, name, attributes, from_config=False): # Currently, whether a profile is a test profile is solely determined by its name starting with 'test_' self._test_profile = bool(self.name.startswith('test_')) - def get_repository(self) -> 'Repository': - """Return the repository configured for this profile.""" - from disk_objectstore import Container - - from aiida.repository import Repository - from aiida.repository.backend import DiskObjectStoreRepositoryBackend - container = Container(self.repository_path / 'container') - backend = DiskObjectStoreRepositoryBackend(container=container) - return Repository(backend=backend) - @property def uuid(self): """Return the profile uuid. diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 51a9cd79f7..cc79e13729 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -132,11 +132,23 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True backend_manager.load_backend_environment(profile, validate_schema=schema_check) configuration.BACKEND_UUID = profile.uuid + backend_type = profile.database_backend + + # Can only import the backend classes after the backend has been loaded + if backend_type == BACKEND_DJANGO: + from aiida.orm.implementation.django.backend import DjangoBackend + self._backend = DjangoBackend() + elif backend_type == BACKEND_SQLA: + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend + self._backend = SqlaBackend() + else: + raise ValueError(f'unknown database backend type: {backend_type}') + # Perform the check on the repository compatibility. Since this is new functionality and the stability is not # yet known, we issue a warning in the case the repo and database are incompatible. In the future this might # then become an exception once we have verified that it is working reliably. if repository_check and not profile.is_test_profile: - repository_uuid_config = profile.get_repository().uuid + repository_uuid_config = self._backend.get_repository().uuid repository_uuid_database = backend_manager.get_repository_uuid() from aiida.cmdline.utils import echo @@ -149,16 +161,6 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True 'Please make sure that the configuration of your profile is correct.\n' ) - backend_type = profile.database_backend - - # Can only import the backend classes after the backend has been loaded - if backend_type == BACKEND_DJANGO: - from aiida.orm.implementation.django.backend import DjangoBackend - self._backend = DjangoBackend() - elif backend_type == BACKEND_SQLA: - from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend - self._backend = SqlaBackend() - # Reconfigure the logging with `with_orm=True` to make sure that profile specific logging configuration options # are taken into account and the `DbLogHandler` is configured. configure_logging(with_orm=True) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index b1273661d9..c1ffb8b0b1 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -25,6 +25,7 @@ BackendQueryBuilder, BackendUserCollection, ) + from aiida.repository.backend.abstract import AbstractRepositoryBackend __all__ = ('Backend',) @@ -32,7 +33,16 @@ class Backend(abc.ABC): - """The public interface that defines a backend factory that creates backend specific concrete objects.""" + """Abstraction for a backend to read/write persistent data for a profile's provenance graph. + + AiiDA splits data storage into two sources: + + - Searchable data, which is stored in the database and can be queried using the QueryBuilder + - Non-searchable data, which is stored in the repository and can be loaded using the RepositoryBackend + + The two sources are inter-linked by the ``Node.repository_metadata``. + Once stored, the leaf values of this dictionary must be valid pointers to object keys in the repository. + """ @abc.abstractmethod def migrate(self) -> None: @@ -135,3 +145,7 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): :raises: ``AssertionError`` if a transaction is not active """ + + @abc.abstractmethod + def get_repository(self) -> 'AbstractRepositoryBackend': + """Return the object repository configured for this backend.""" diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 1423ce5d22..fb8b9321e7 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -9,17 +9,20 @@ ########################################################################### """Generic backend related objects""" import abc -import typing +from typing import TYPE_CHECKING, Generic, TypeVar from .. import backends, entities +if TYPE_CHECKING: + from aiida.repository.backend import DiskObjectStoreRepositoryBackend + __all__ = ('SqlBackend',) # The template type for the base sqlalchemy/django ORM model type -ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name +ModelType = TypeVar('ModelType') # pylint: disable=invalid-name -class SqlBackend(typing.Generic[ModelType], backends.Backend): +class SqlBackend(Generic[ModelType], backends.Backend): """ A class for SQL based backends. Assumptions are that: * there is an ORM @@ -29,6 +32,17 @@ class SqlBackend(typing.Generic[ModelType], backends.Backend): if any of these assumptions do not fit then just implement a backend from :class:`aiida.orm.implementation.Backend` """ + def get_repository(self) -> 'DiskObjectStoreRepositoryBackend': + from disk_objectstore import Container + + from aiida.manage.manager import get_manager + from aiida.repository.backend import DiskObjectStoreRepositoryBackend + + profile = get_manager().get_profile() + assert profile is not None, 'profile not loaded' + container = Container(profile.repository_path / 'container') + return DiskObjectStoreRepositoryBackend(container=container) + @abc.abstractmethod def get_backend_entity(self, model: ModelType) -> entities.BackendEntity: """ diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index c8694d2405..3d7f083d53 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -709,13 +709,13 @@ def _store(self, with_transaction: bool = True, clean: bool = True) -> 'Node': :param with_transaction: if False, do not use a transaction because the caller will already have opened one. :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ + from aiida.repository import Repository from aiida.repository.backend import SandboxRepositoryBackend # Only if the backend repository is a sandbox do we have to clone its contents to the permanent repository. if isinstance(self._repository.backend, SandboxRepositoryBackend): - profile = get_manager().get_profile() - assert profile is not None, 'profile not loaded' - repository = profile.get_repository() + repository_backend = self.backend.get_repository() + repository = Repository(backend=repository_backend) repository.clone(self._repository) # Swap the sandbox repository for the new permanent repository instance which should delete the sandbox self._repository_instance = repository diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py index 11ee76401c..f8608299fb 100644 --- a/aiida/orm/nodes/repository.py +++ b/aiida/orm/nodes/repository.py @@ -50,8 +50,7 @@ def _repository(self) -> Repository: """ if self._repository_instance is None: if self.is_stored: - from aiida.manage.manager import get_manager - backend = get_manager().get_profile().get_repository().backend + backend = self.backend.get_repository() serialized = self.repository_metadata self._repository_instance = Repository.from_serialized(backend=backend, serialized=serialized) else: diff --git a/aiida/repository/repository.py b/aiida/repository/repository.py index 4c2a94b067..12392b6d3d 100644 --- a/aiida/repository/repository.py +++ b/aiida/repository/repository.py @@ -51,19 +51,12 @@ def __str__(self) -> str: @property def uuid(self) -> Optional[str]: - """Return the unique identifier of the repository or ``None`` if it doesn't have one.""" + """Return the unique identifier of the repository backend or ``None`` if it doesn't have one.""" return self.backend.uuid - def initialise(self, **kwargs: Any) -> None: - """Initialise the repository if it hasn't already been initialised. - - :param kwargs: keyword argument that will be passed to the ``initialise`` call of the backend. - """ - self.backend.initialise(**kwargs) - @property def is_initialised(self) -> bool: - """Return whether the repository has been initialised.""" + """Return whether the repository backend has been initialised.""" return self.backend.is_initialised @classmethod @@ -417,16 +410,6 @@ def delete_object(self, path: FilePath, hard_delete: bool = False) -> None: directory = self.get_directory(path.parent) directory.objects.pop(path.name) - def delete(self) -> None: - """Delete the repository. - - .. important:: This will not just delete the contents of the repository but also the repository itself and all - of its assets. For example, if the repository is stored inside a folder on disk, the folder may be deleted. - - """ - self.backend.erase() - self.reset() - def erase(self) -> None: """Delete all objects from the repository. @@ -510,3 +493,21 @@ def copy_tree(self, target: Union[str, pathlib.Path], path: FilePath = None) -> with self.open(root / filename) as handle: filepath.write_bytes(handle.read()) + + # these methods are not actually used in aiida-core, but are here for completeness + + def initialise(self, **kwargs: Any) -> None: + """Initialise the repository if it hasn't already been initialised. + + :param kwargs: keyword argument that will be passed to the ``initialise`` call of the backend. + """ + self.backend.initialise(**kwargs) + + def delete(self) -> None: + """Delete the repository. + + .. important:: This will not just delete the contents of the repository but also the repository itself and all + of its assets. For example, if the repository is stored inside a folder on disk, the folder may be deleted. + """ + self.backend.erase() + self.reset() diff --git a/aiida/tools/importexport/dbexport/main.py b/aiida/tools/importexport/dbexport/main.py index 83fbb502d6..5ddd78ff5d 100644 --- a/aiida/tools/importexport/dbexport/main.py +++ b/aiida/tools/importexport/dbexport/main.py @@ -572,9 +572,8 @@ def _write_node_repositories( container_export = Container(dirpath) container_export.init_container() - profile = get_manager().get_profile() - assert profile is not None, 'profile not loaded' - container_profile = profile.get_repository().backend.container + backend = get_manager().get_backend() + container_profile = backend.get_repository().container # This should be done more effectively, starting by not having to load the node. Either the repository # metadata should be collected earlier when the nodes themselves are already exported or a single separate diff --git a/aiida/tools/importexport/dbimport/backends/common.py b/aiida/tools/importexport/dbimport/backends/common.py index 97b09101c8..5a5620be4f 100644 --- a/aiida/tools/importexport/dbimport/backends/common.py +++ b/aiida/tools/importexport/dbimport/backends/common.py @@ -43,9 +43,8 @@ def _copy_node_repositories(*, repository_metadatas: List[Dict], reader: Archive if not container_export.is_initialised: container_export.init_container() - profile = get_manager().get_profile() - assert profile is not None, 'profile not loaded' - container_profile = profile.get_repository().backend.container + backend = get_manager().get_backend() + container_profile = backend.get_repository().container def collect_hashkeys(objects, hashkeys): for obj in objects.values(): diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index d7bfb15798..bb2f466230 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -45,6 +45,7 @@ py:class aiida.tools.importexport.dbexport.ArchiveData py:class aiida.tools.groups.paths.WalkNodeResult py:meth aiida.orm.groups.GroupCollection.delete +py:class AbstractRepositoryBackend py:class Backend py:class BackendEntity py:class BackendEntityType @@ -58,6 +59,7 @@ py:class CollectionType py:class Computer py:class Data py:class DbImporter +py:class DiskObjectStoreRepositoryBackend py:class EntityType py:class EntityTypes py:class ExitCode diff --git a/tests/cmdline/commands/test_setup.py b/tests/cmdline/commands/test_setup.py index d23853c815..1121fc413b 100644 --- a/tests/cmdline/commands/test_setup.py +++ b/tests/cmdline/commands/test_setup.py @@ -80,7 +80,7 @@ def test_quicksetup(self): # Check that the repository UUID was stored in the database manager = get_manager() backend_manager = manager.get_backend_manager() - self.assertEqual(backend_manager.get_repository_uuid(), profile.get_repository().uuid) + self.assertEqual(backend_manager.get_repository_uuid(), self.backend.get_repository().uuid) def test_quicksetup_from_config_file(self): """Test `verdi quicksetup` from configuration file.""" @@ -167,4 +167,4 @@ def test_setup(self): # Check that the repository UUID was stored in the database manager = get_manager() backend_manager = manager.get_backend_manager() - self.assertEqual(backend_manager.get_repository_uuid(), profile.get_repository().uuid) + self.assertEqual(backend_manager.get_repository_uuid(), self.backend.get_repository().uuid) diff --git a/tests/orm/data/test_array.py b/tests/orm/data/test_array.py index 20d9ce1b98..a77c393fde 100644 --- a/tests/orm/data/test_array.py +++ b/tests/orm/data/test_array.py @@ -31,7 +31,7 @@ def test_read_stored(): assert numpy.array_equal(loaded.get_array('array'), array) # Now pack all the files in the repository - container = get_manager().get_profile().get_repository().backend.container + container = get_manager().get_backend().get_repository().container container.pack_all_loose() loaded = load_node(node.uuid) diff --git a/tests/repository/test_repository.py b/tests/repository/test_repository.py index 0400e623d8..e2033623e9 100644 --- a/tests/repository/test_repository.py +++ b/tests/repository/test_repository.py @@ -37,8 +37,8 @@ def repository(request, tmp_path_factory) -> Repository: """ with request.param(tmp_path_factory.mktemp('container')) as backend: + backend.initialise() repository = Repository(backend=backend) - repository.initialise() yield repository @@ -78,12 +78,12 @@ def test_uuid(repository_uninitialised): if isinstance(repository.backend, SandboxRepositoryBackend): assert repository.uuid is None - repository.initialise() + repository.backend.initialise() assert repository.uuid is None if isinstance(repository.backend, DiskObjectStoreRepositoryBackend): assert repository.uuid is None - repository.initialise() + repository.backend.initialise() assert isinstance(repository.uuid, str) @@ -92,7 +92,7 @@ def test_initialise(repository_uninitialised): repository = repository_uninitialised assert not repository.is_initialised - repository.initialise() + repository.backend.initialise() assert repository.is_initialised @@ -508,26 +508,6 @@ def test_delete_object_hard(repository, generate_directory): assert not repository.backend.has_object(key) -def test_delete(repository, generate_directory): - """Test the ``Repository.delete`` method.""" - directory = generate_directory({ - 'file_a': b'content_a', - 'relative': { - 'file_b': b'content_b', - } - }) - - repository.put_object_from_tree(str(directory)) - - assert repository.has_object('file_a') - assert repository.has_object('relative/file_b') - - repository.delete() - - assert repository.is_empty() - assert not repository.is_initialised - - def test_erase(repository, generate_directory): """Test the ``Repository.erase`` method.""" directory = generate_directory({ diff --git a/tests/tools/importexport/test_repository.py b/tests/tools/importexport/test_repository.py index fc2d2a8358..2d8c6ddfa5 100644 --- a/tests/tools/importexport/test_repository.py +++ b/tests/tools/importexport/test_repository.py @@ -32,7 +32,7 @@ def test_export_repository(aiida_profile, tmp_path): export([node], filename=filepath) aiida_profile.reset_db() - container = get_manager().get_profile().get_repository().backend.container + container = get_manager().get_backend().get_repository().container container.init_container(clear=True) import_data(filepath, silent=True) From 6e272a3f715e60f63129c072d192eeea2e44b771 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 24 Oct 2021 21:34:09 +0200 Subject: [PATCH 3/5] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Add=20`AuthInfo`?= =?UTF-8?q?=20joins=20to=20`QueryBuilder`=20(#5195)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow for `with_user` and `with_computer` when querying `AuthInfo` --- aiida/orm/computers.py | 4 +- aiida/orm/implementation/querybuilder.py | 2 +- .../sqlalchemy/querybuilder/joiner.py | 377 ++++++++++-------- tests/orm/test_querybuilder.py | 20 + 4 files changed, 225 insertions(+), 178 deletions(-) diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 4c77d96f41..ffeba6ad8f 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -84,7 +84,7 @@ def objects(cls) -> ComputerCollection: # pylint: disable=no-self-argument def __init__( # pylint: disable=too-many-arguments self, label: str = None, - hostname: str = None, + hostname: str = '', description: str = '', transport_type: str = '', scheduler_type: str = '', @@ -137,7 +137,7 @@ def _hostname_validator(cls, hostname: str) -> None: """ Validates the hostname. """ - if not hostname.strip(): + if not (hostname or hostname.strip()): raise exceptions.ValidationError('No hostname specified') @classmethod diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index 6e2c0de340..cb55bdc0b1 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -41,7 +41,7 @@ class EntityTypes(Enum): EntityRelationships: Dict[str, Set[str]] = { - 'authinfo': set(), + 'authinfo': {'with_computer', 'with_user'}, 'comment': {'with_node', 'with_user'}, 'computer': {'with_node'}, 'group': {'with_node', 'with_user'}, diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py index 08dd41ee58..b0e407fcb8 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py @@ -33,6 +33,10 @@ class _EntityMapper(Protocol): # pylint: disable=invalid-name + @property + def AuthInfo(self) -> Type[Model]: + ... + @property def Node(self) -> Type[Model]: ... @@ -97,41 +101,213 @@ def _entity_join_map(self) -> Dict[str, Dict[str, JoinFuncType]]: and the second defines the relationship with respect to a given tag. """ mapping = { - 'node': { - 'with_log': self._join_log_node, - 'with_comment': self._join_comment_node, - 'with_incoming': self._join_outputs, - 'with_outgoing': self._join_inputs, - 'with_descendants': self._join_ancestors_recursive, - 'with_ancestors': self._join_descendants_recursive, - 'with_computer': self._join_to_computer_used, - 'with_user': self._join_created_by, - 'with_group': self._join_group_members, + 'authinfo': { + 'with_computer': self._join_computer_authinfo, + 'with_user': self._join_user_authinfo, }, - 'computer': { - 'with_node': self._join_computer, + 'comment': { + 'with_node': self._join_node_comment, + 'with_user': self._join_user_comment, }, - 'user': { - 'with_comment': self._join_comment_user, - 'with_node': self._join_creator_of, - 'with_group': self._join_group_user, + 'computer': { + 'with_node': self._join_node_computer, }, 'group': { - 'with_node': self._join_groups, + 'with_node': self._join_node_group, 'with_user': self._join_user_group, }, - 'comment': { - 'with_node': self._join_node_comment, - 'with_user': self._join_user_comment, - }, 'log': { 'with_node': self._join_node_log, - } + }, + 'node': { + 'with_log': self._join_log_node, + 'with_comment': self._join_comment_node, + 'with_incoming': self._join_node_outputs, + 'with_outgoing': self._join_node_inputs, + 'with_descendants': self._join_node_ancestors_recursive, + 'with_ancestors': self._join_node_descendants_recursive, + 'with_computer': self._join_computer_node, + 'with_user': self._join_user_node, + 'with_group': self._join_group_node, + }, + 'user': { + 'with_comment': self._join_comment_user, + 'with_node': self._join_node_user, + 'with_group': self._join_group_user, + }, } return mapping # type: ignore - def _join_outputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + def _join_computer_authinfo(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased user you want to join to + :param entity_to_join: the (aliased) node or group in the DB to join with + """ + _check_dbentities((joined_entity, self._entities.Computer), (entity_to_join, self._entities.AuthInfo), + 'with_computer') + new_query = query.join(entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_authinfo(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased user you want to join to + :param entity_to_join: the (aliased) node or group in the DB to join with + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.AuthInfo), 'with_user') + new_query = query.join(entity_to_join, entity_to_join.aiidauser_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_group_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: + The (aliased) ORMclass that is + a group in the database + :param entity_to_join: + The (aliased) ORMClass that is a node and member of the group + + **joined_entity** and **entity_to_join** + are joined via the table_groups_nodes table. + from **joined_entity** as group to **enitity_to_join** as node. + (**enitity_to_join** is *with_group* **joined_entity**) + """ + _check_dbentities((joined_entity, self._entities.Group), (entity_to_join, self._entities.Node), 'with_group') + aliased_group_nodes = aliased(self._entities.table_groups_nodes) + new_query = query.join(aliased_group_nodes, aliased_group_nodes.c.dbgroup_id == joined_entity.id).join( + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin + ) + return JoinReturn(new_query, aliased_group_nodes) + + def _join_node_group(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: The (aliased) node in the database + :param entity_to_join: The (aliased) Group + + **joined_entity** and **entity_to_join** are + joined via the table_groups_nodes table. + from **joined_entity** as node to **enitity_to_join** as group. + (**enitity_to_join** is a group *with_node* **joined_entity**) + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Group), 'with_node') + aliased_group_nodes = aliased(self._entities.table_groups_nodes) + new_query = query.join(aliased_group_nodes, aliased_group_nodes.c.dbnode_id == joined_entity.id).join( + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin + ) + return JoinReturn(new_query, aliased_group_nodes) + + def _join_node_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased node + :param entity_to_join: the aliased user to join to that node + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.User), 'with_node') + new_query = query.join(entity_to_join, entity_to_join.id == joined_entity.user_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased user you want to join to + :param entity_to_join: the (aliased) node or group in the DB to join with + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Node), 'with_user') + new_query = query.join(entity_to_join, entity_to_join.user_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_computer_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the (aliased) computer entity + :param entity_to_join: the (aliased) node entity + + """ + _check_dbentities((joined_entity, self._entities.Computer), (entity_to_join, self._entities.Node), + 'with_computer') + new_query = query.join(entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_computer(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An entity that can use a computer (eg a node) + :param entity_to_join: aliased dbcomputer entity + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Computer), 'with_node') + new_query = query.join(entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_group_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased dbgroup + :param entity_to_join: aliased dbuser + """ + _check_dbentities((joined_entity, self._entities.Group), (entity_to_join, self._entities.User), 'with_group') + new_query = query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_group(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased user + :param entity_to_join: aliased group + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Group), 'with_user') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_comment(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased node + :param entity_to_join: aliased comment + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Comment), 'with_node') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_comment_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased comment + :param entity_to_join: aliased node + """ + _check_dbentities((joined_entity, self._entities.Comment), (entity_to_join, self._entities.Node), + 'with_comment') + new_query = query.join(entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_log(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased node + :param entity_to_join: aliased log + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Log), 'with_node') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_log_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased log + :param entity_to_join: aliased node + """ + _check_dbentities((joined_entity, self._entities.Log), (entity_to_join, self._entities.Node), 'with_log') + new_query = query.join(entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_comment(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased user + :param entity_to_join: aliased comment + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Comment), 'with_user') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_comment_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased comment + :param entity_to_join: aliased user + """ + _check_dbentities((joined_entity, self._entities.Comment), (entity_to_join, self._entities.User), + 'with_comment') + new_query = query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_outputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): """ :param joined_entity: The (aliased) ORMclass that is an input :param entity_to_join: The (aliased) ORMClass that is an output. @@ -147,7 +323,7 @@ def _join_outputs(self, query: Query, joined_entity, entity_to_join, isouterjoin ).join(entity_to_join, aliased_edge.output_id == entity_to_join.id, isouter=isouterjoin) return JoinReturn(new_query, aliased_edge) - def _join_inputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + def _join_node_inputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): """ :param joined_entity: The (aliased) ORMclass that is an output :param entity_to_join: The (aliased) ORMClass that is an input. @@ -166,7 +342,7 @@ def _join_inputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: ).join(entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) return JoinReturn(new_query, aliased_edge) - def _join_descendants_recursive( + def _join_node_descendants_recursive( self, query: Query, joined_entity, @@ -231,7 +407,7 @@ def _join_descendants_recursive( ) return JoinReturn(new_query, descendants_recursive.c) - def _join_ancestors_recursive( + def _join_node_ancestors_recursive( self, query: Query, joined_entity, @@ -297,155 +473,6 @@ def _join_ancestors_recursive( ) return JoinReturn(new_query, ancestors_recursive.c) - def _join_group_members(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: - The (aliased) ORMclass that is - a group in the database - :param entity_to_join: - The (aliased) ORMClass that is a node and member of the group - - **joined_entity** and **entity_to_join** - are joined via the table_groups_nodes table. - from **joined_entity** as group to **enitity_to_join** as node. - (**enitity_to_join** is *with_group* **joined_entity**) - """ - _check_dbentities((joined_entity, self._entities.Group), (entity_to_join, self._entities.Node), 'with_group') - aliased_group_nodes = aliased(self._entities.table_groups_nodes) - new_query = query.join(aliased_group_nodes, aliased_group_nodes.c.dbgroup_id == joined_entity.id).join( - entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin - ) - return JoinReturn(new_query, aliased_group_nodes) - - def _join_groups(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: The (aliased) node in the database - :param entity_to_join: The (aliased) Group - - **joined_entity** and **entity_to_join** are - joined via the table_groups_nodes table. - from **joined_entity** as node to **enitity_to_join** as group. - (**enitity_to_join** is a group *with_node* **joined_entity**) - """ - _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Group), 'with_node') - aliased_group_nodes = aliased(self._entities.table_groups_nodes) - new_query = query.join(aliased_group_nodes, aliased_group_nodes.c.dbnode_id == joined_entity.id).join( - entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin - ) - return JoinReturn(new_query, aliased_group_nodes) - - def _join_creator_of(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: the aliased node - :param entity_to_join: the aliased user to join to that node - """ - _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.User), 'with_node') - new_query = query.join(entity_to_join, entity_to_join.id == joined_entity.user_id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_created_by(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: the aliased user you want to join to - :param entity_to_join: the (aliased) node or group in the DB to join with - """ - _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Node), 'with_user') - new_query = query.join(entity_to_join, entity_to_join.user_id == joined_entity.id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_to_computer_used(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: the (aliased) computer entity - :param entity_to_join: the (aliased) node entity - - """ - _check_dbentities((joined_entity, self._entities.Computer), (entity_to_join, self._entities.Node), - 'with_computer') - new_query = query.join(entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_computer(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An entity that can use a computer (eg a node) - :param entity_to_join: aliased dbcomputer entity - """ - _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Computer), 'with_node') - new_query = query.join(entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_group_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased dbgroup - :param entity_to_join: aliased dbuser - """ - _check_dbentities((joined_entity, self._entities.Group), (entity_to_join, self._entities.User), 'with_group') - new_query = query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_user_group(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased user - :param entity_to_join: aliased group - """ - _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Group), 'with_user') - new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_node_comment(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased node - :param entity_to_join: aliased comment - """ - _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Comment), 'with_node') - new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_comment_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased comment - :param entity_to_join: aliased node - """ - _check_dbentities((joined_entity, self._entities.Comment), (entity_to_join, self._entities.Node), - 'with_comment') - new_query = query.join(entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_node_log(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased node - :param entity_to_join: aliased log - """ - _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Log), 'with_node') - new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_log_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased log - :param entity_to_join: aliased node - """ - _check_dbentities((joined_entity, self._entities.Log), (entity_to_join, self._entities.Node), 'with_log') - new_query = query.join(entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_user_comment(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased user - :param entity_to_join: aliased comment - """ - _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Comment), 'with_user') - new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) - return JoinReturn(new_query) - - def _join_comment_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): - """ - :param joined_entity: An aliased comment - :param entity_to_join: aliased user - """ - _check_dbentities((joined_entity, self._entities.Comment), (entity_to_join, self._entities.User), - 'with_comment') - new_query = query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) - return JoinReturn(new_query) - def _check_dbentities(entities_cls_joined, entities_cls_to_join, relationship: str): """Type check for entities diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index e6e366dba0..75f626ecd1 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -1084,6 +1084,26 @@ def test_joins3_user_group(self): assert qb.count() == 1, 'The expected user that owns the selected group was not found.' + def test_joins_authinfo(self): + """Test querying for AuthInfo with specific computer/user.""" + user = orm.User(email='email@new.com').store() + computer = orm.Computer( + label='new', hostname='localhost', transport_type='core.local', scheduler_type='core.direct' + ).store() + authinfo = computer.configure(user) + + # Search for the user of the authinfo + qb = orm.QueryBuilder() + qb.append(orm.User, tag='user', filters={'id': {'==': user.id}}) + qb.append(orm.AuthInfo, with_user='user', filters={'id': {'==': authinfo.id}}) + assert qb.count() == 1, 'The expected user that owns the selected authinfo was not found.' + + # Search for the computer of the authinfo + qb = orm.QueryBuilder() + qb.append(orm.Computer, tag='computer', filters={'id': {'==': computer.id}}) + qb.append(orm.AuthInfo, with_computer='computer', filters={'id': {'==': authinfo.id}}) + assert qb.count() == 1, 'The expected computer that owns the selected authinfo was not found.' + def test_joins_group_node(self): """ This test checks that the querying for the nodes that belong to a group works correctly (using QueryBuilder). From dfbd49760922602014d4fe85c50517253d3d836b Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 26 Oct 2021 07:06:08 +0200 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Get=20session=20f?= =?UTF-8?q?rom=20backend=20in=20`SqlaModelEntity`=20(#5199)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace `get_scoped_session()` with `self.backend.get_session()`. Note currently `SqlaBackend.get_session` simply calls `get_scoped_session()`, so this does not change any logic. But eventually `SqlaBackend` instances will manage their own session, and `get_scoped_session()` will be removed. --- .../orm/implementation/sqlalchemy/authinfos.py | 6 ++---- aiida/orm/implementation/sqlalchemy/comments.py | 5 ++--- .../orm/implementation/sqlalchemy/computers.py | 10 ++++------ aiida/orm/implementation/sqlalchemy/groups.py | 17 ++++++----------- aiida/orm/implementation/sqlalchemy/logs.py | 5 ++--- aiida/orm/implementation/sqlalchemy/nodes.py | 11 +++++------ 6 files changed, 21 insertions(+), 33 deletions(-) diff --git a/aiida/orm/implementation/sqlalchemy/authinfos.py b/aiida/orm/implementation/sqlalchemy/authinfos.py index 5f8b266ca0..bad9053d29 100644 --- a/aiida/orm/implementation/sqlalchemy/authinfos.py +++ b/aiida/orm/implementation/sqlalchemy/authinfos.py @@ -8,8 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the SqlAlchemy backend implementation of the `AuthInfo` ORM class.""" - -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo from aiida.common import exceptions from aiida.common.lang import type_check @@ -123,7 +121,7 @@ def delete(self, pk): # pylint: disable=import-error,no-name-in-module from sqlalchemy.orm.exc import NoResultFound - session = get_scoped_session() + session = self.backend.get_session() try: session.query(DbAuthInfo).filter_by(id=pk).one().delete() @@ -143,7 +141,7 @@ def get(self, computer, user): # pylint: disable=import-error,no-name-in-module from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound - session = get_scoped_session() + session = self.backend.get_session() try: authinfo = session.query(DbAuthInfo).filter_by(dbcomputer_id=computer.id, aiidauser_id=user.id).one() diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/orm/implementation/sqlalchemy/comments.py index 618aa021bf..f67a008478 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/orm/implementation/sqlalchemy/comments.py @@ -14,7 +14,6 @@ from sqlalchemy.orm.exc import NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import comment as models from aiida.common import exceptions, lang @@ -125,7 +124,7 @@ def delete(self, comment_id): if not isinstance(comment_id, int): raise TypeError('comment_id must be an int') - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbComment).filter_by(id=comment_id).one().delete() @@ -140,7 +139,7 @@ def delete_all(self): :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted """ - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbComment).delete() diff --git a/aiida/orm/implementation/sqlalchemy/computers.py b/aiida/orm/implementation/sqlalchemy/computers.py index 525fae15b4..6cebd9b351 100644 --- a/aiida/orm/implementation/sqlalchemy/computers.py +++ b/aiida/orm/implementation/sqlalchemy/computers.py @@ -15,7 +15,6 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.session import make_transient -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.computer import DbComputer from aiida.common import exceptions from aiida.orm.implementation.computers import BackendComputer, BackendComputerCollection @@ -52,7 +51,7 @@ def is_stored(self): def copy(self): """Create an unstored clone of an already stored `Computer`.""" - session = get_scoped_session() + session = self.backend.get_session() if not self.is_stored: raise exceptions.InvalidOperation('You can copy a computer only after having stored it') @@ -128,14 +127,13 @@ class SqlaComputerCollection(BackendComputerCollection): ENTITY_CLASS = SqlaComputer - @staticmethod - def list_names(): - session = get_scoped_session() + def list_names(self): + session = self.backend.get_session() return session.query(DbComputer.label).all() def delete(self, pk): try: - session = get_scoped_session() + session = self.backend.get_session() session.get(DbComputer, pk).delete() session.commit() except SQLAlchemyError as exc: diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 8b8e991c3b..6c5744bcb4 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -12,7 +12,6 @@ from collections.abc import Iterable import logging -from aiida.backends import sqlalchemy as sa from aiida.backends.sqlalchemy.models.group import DbGroup, table_groups_nodes from aiida.backends.sqlalchemy.models.node import DbNode from aiida.common.exceptions import UniquenessError @@ -124,14 +123,12 @@ def count(self): :return: integer number of entities contained within the group """ - from aiida.backends.sqlalchemy import get_scoped_session - session = get_scoped_session() + session = self.backend.get_session() return session.query(self.MODEL_CLASS).join(self.MODEL_CLASS.dbnodes).filter(DbGroup.id == self.pk).count() def clear(self): """Remove all the nodes from this group.""" - from aiida.backends.sqlalchemy import get_scoped_session - session = get_scoped_session() + session = self.backend.get_session() # Note we have to call `dbmodel` and `_dbmodel` to circumvent the `ModelWrapper` self.dbmodel.dbnodes = [] session.commit() @@ -184,7 +181,6 @@ def add_nodes(self, nodes, **kwargs): from sqlalchemy.dialects.postgresql import insert # pylint: disable=import-error, no-name-in-module from sqlalchemy.exc import IntegrityError # pylint: disable=import-error, no-name-in-module - from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.base import Base from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode @@ -199,7 +195,7 @@ def check_node(given_node): if not given_node.is_stored: raise ValueError('At least one of the provided nodes is unstored, stopping...') - with utils.disable_expire_on_commit(get_scoped_session()) as session: + with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self._dbmodel.dbnodes @@ -241,7 +237,6 @@ def remove_nodes(self, nodes, **kwargs): """ from sqlalchemy import and_ - from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.base import Base from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode @@ -260,7 +255,7 @@ def check_node(node): list_nodes = [] - with utils.disable_expire_on_commit(get_scoped_session()) as session: + with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: for node in nodes: check_node(node) @@ -303,7 +298,7 @@ def query( # pylint: disable=too-many-branches from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode - session = sa.get_scoped_session() + session = self.backend.get_session() filters = [] @@ -366,7 +361,7 @@ def query( return [SqlaGroup.from_dbmodel(group, self._backend) for group in groups] # pylint: disable=no-member def delete(self, id): # pylint: disable=redefined-builtin - session = sa.get_scoped_session() + session = self.backend.get_session() session.get(DbGroup, id).delete() session.commit() diff --git a/aiida/orm/implementation/sqlalchemy/logs.py b/aiida/orm/implementation/sqlalchemy/logs.py index 62a973171d..8abc1c8e53 100644 --- a/aiida/orm/implementation/sqlalchemy/logs.py +++ b/aiida/orm/implementation/sqlalchemy/logs.py @@ -12,7 +12,6 @@ from sqlalchemy.orm.exc import NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import log as models from aiida.common import exceptions @@ -107,7 +106,7 @@ def delete(self, log_id): if not isinstance(log_id, int): raise TypeError('log_id must be an int') - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbLog).filter_by(id=log_id).one().delete() @@ -122,7 +121,7 @@ def delete_all(self): :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted """ - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbLog).delete() diff --git a/aiida/orm/implementation/sqlalchemy/nodes.py b/aiida/orm/implementation/sqlalchemy/nodes.py index db324ba280..1bcc841548 100644 --- a/aiida/orm/implementation/sqlalchemy/nodes.py +++ b/aiida/orm/implementation/sqlalchemy/nodes.py @@ -15,7 +15,6 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import node as models from aiida.common import exceptions from aiida.common.lang import type_check @@ -158,7 +157,7 @@ def add_incoming(self, source, link_type, link_label): :return: True if the proposed link is allowed, False otherwise :raise aiida.common.ModificationNotAllowed: if either source or target node is not stored """ - session = get_scoped_session() + session = self.backend.get_session() type_check(source, SqlaNode) @@ -180,7 +179,7 @@ def _add_link(self, source, link_type, link_label): """ from aiida.backends.sqlalchemy.models.node import DbLink - session = get_scoped_session() + session = self.backend.get_session() try: with session.begin_nested(): @@ -200,7 +199,7 @@ def store(self, links=None, with_transaction=True, clean=True): # pylint: disab :param with_transaction: if False, do not use a transaction because the caller will already have opened one. :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ - session = get_scoped_session() + session = self.backend.get_session() if clean: self.clean_values() @@ -231,7 +230,7 @@ def get(self, pk): :param pk: id of the node """ - session = get_scoped_session() + session = self.backend.get_session() try: return self.ENTITY_CLASS.from_dbmodel(session.query(models.DbNode).filter_by(id=pk).one(), self.backend) @@ -243,7 +242,7 @@ def delete(self, pk): :param pk: id of the node to delete """ - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbNode).filter_by(id=pk).one().delete() From 78c7b7281669b1000ab328a1f731fb0cd3a1aef3 Mon Sep 17 00:00:00 2001 From: "Jason.Yu" Date: Tue, 26 Oct 2021 07:18:56 +0200 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=93=9A=20DOCS:=20howto=20set=20PYTHON?= =?UTF-8?q?PATH=20for=20local=20work=20chains=20(#5191)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Chris Sewell --- docs/source/howto/faq.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/source/howto/faq.rst b/docs/source/howto/faq.rst index 41b8803ec3..b62778c98e 100644 --- a/docs/source/howto/faq.rst +++ b/docs/source/howto/faq.rst @@ -74,9 +74,17 @@ To determine exactly what might be going wrong, first :ref:`set the loglevel `_. -Make sure that the PYTHONPATH is correctly defined automatically when starting your shell, so for example if you are using bash, add it to your ``.bashrc``. +Make sure that the PYTHONPATH is correctly defined automatically when starting your shell, so for example if you are using bash, add it to your ``.bashrc`` and completely reset daemon. +For example, go to the directory that contains the file where you defined the process and run: + +.. code-block:: console + + $ echo "export PYTHONPATH=\$PYTHONPATH:$PWD" >> $HOME/.bashrc + $ source $HOME/.bashrc + $ verdi daemon restart --reset .. _how-to:faq:caching-not-enabled: