diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index b212327ba2..6896ca5a19 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -122,7 +122,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T md5sum = md5_file(filename) builder = orm.QueryBuilder(backend=backend) builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}}) - existing_upf = builder.first() + existing_upf = builder.first(flat=True) if existing_upf is None: # return the upfdata instances, not stored @@ -133,7 +133,6 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T else: if stop_if_existing: raise ValueError(f'A UPF with identical MD5 to {filename} cannot be added with stop_if_existing') - existing_upf = existing_upf[0] pseudo_and_created.append((existing_upf, False)) # check whether pseudo are unique per element diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 69582b618f..df62f60af8 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -19,6 +19,8 @@ An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance when instantiated by the user. """ +from __future__ import annotations + from copy import deepcopy from inspect import isclass as inspect_isclass from typing import ( @@ -27,6 +29,7 @@ Dict, Iterable, List, + Literal, NamedTuple, Optional, Sequence, @@ -35,6 +38,7 @@ Type, Union, cast, + overload, ) import warnings @@ -989,12 +993,23 @@ def _get_aiida_entity_res(value) -> Any: except TypeError: return value - def first(self) -> Optional[List[Any]]: - """Executes the query, asking for the first row of results. + @overload + def first(self, flat: Literal[False]) -> Optional[list[Any]]: + ... + + @overload + def first(self, flat: Literal[True]) -> Optional[Any]: + ... + + def first(self, flat: bool = False) -> Optional[list[Any] | Any]: + """Return the first result of the query. - Note, this may change if several rows are valid for the query, - as persistent ordering is not guaranteed unless explicitly specified. + Calling ``first`` results in an execution of the underlying query. + Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless + explicitly specified. + + :param flat: if True, return just the projected quantity if there is just a single projection. :returns: One row of results as a list, or None if no result returned. """ result = self._impl.first(self.as_dict()) @@ -1002,7 +1017,12 @@ def first(self) -> Optional[List[Any]]: if result is None: return None - return [self._get_aiida_entity_res(rowitem) for rowitem in result] + result = [self._get_aiida_entity_res(rowitem) for rowitem in result] + + if flat and len(result) == 1: + return result[0] + + return result def count(self) -> int: """ diff --git a/aiida/restapi/translator/nodes/node.py b/aiida/restapi/translator/nodes/node.py index 2a38586afe..3f3c920ea8 100644 --- a/aiida/restapi/translator/nodes/node.py +++ b/aiida/restapi/translator/nodes/node.py @@ -254,7 +254,7 @@ def _get_content(self): return {} # otherwise ... - node = self.qbobj.first()[0] + node = self.qbobj.first()[0] # pylint: disable=unsubscriptable-object # content/attributes if self._content_type == 'attributes': @@ -643,7 +643,7 @@ def get_node_description(node): nodes = [] if qb_obj.count() > 0: - main_node = qb_obj.first()[0] + main_node = qb_obj.first(flat=True) pk = main_node.pk uuid = main_node.uuid nodetype = main_node.node_type diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 08f5925b0e..fabcd29760 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -268,7 +268,7 @@ def test_group_uuid_hashing_for_querybuidler(self): # Search for the UUID of the stored group builder = orm.QueryBuilder() builder.append(orm.Group, project=['uuid'], filters={'label': {'==': 'test_group'}}) - [uuid] = builder.first() + uuid = builder.first(flat=True) # Look the node with the previously returned UUID builder = orm.QueryBuilder() @@ -279,7 +279,7 @@ def test_group_uuid_hashing_for_querybuidler(self): # And that the results are correct assert builder.count() == 1 - assert builder.first()[0] == group.id + assert builder.first(flat=True) == group.id @pytest.mark.usefixtures('aiida_profile_clean') diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 9ad8cdc8b5..2d950aacb6 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -649,7 +649,7 @@ def test_direction_keyword(self): assert res2 == {d2.id, d4.id} @staticmethod - def test_flat(): + def test_all_flat(): """Test the `flat` keyword for the `QueryBuilder.all()` method.""" pks = [] uuids = [] @@ -665,13 +665,26 @@ def test_flat(): assert len(result) == 10 assert result == pks - # Mutltiple projections + # Multiple projections builder = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid']).order_by({orm.Data: 'id'}) result = builder.all(flat=True) assert isinstance(result, list) assert len(result) == 20 assert result == list(chain.from_iterable(zip(pks, uuids))) + @staticmethod + def test_first_flat(): + """Test the `flat` keyword for the `QueryBuilder.first()` method.""" + node = orm.Data().store() + + # Single projected property + query = orm.QueryBuilder().append(orm.Data, project='id', filters={'id': node.pk}) + assert query.first(flat=True) == node.pk + + # Mutltiple projections + query = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid'], filters={'id': node.pk}) + assert query.first(flat=True) == [node.pk, node.uuid] + def test_query_links(self): """Test querying for links""" d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)] @@ -703,13 +716,16 @@ def test_first_multiple_projections(self): orm.Data().store() orm.Data().store() - result = orm.QueryBuilder().append(orm.User, tag='user', - project=['email']).append(orm.Data, with_user='user', project=['*']).first() + query = orm.QueryBuilder() + query.append(orm.User, tag='user', project=['email']) + query.append(orm.Data, with_user='user', project=['*']) + + result = query.first() assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], orm.Data) + assert isinstance(result[0], str) # pylint: disable=unsubscriptable-object + assert isinstance(result[1], orm.Data) # pylint: disable=unsubscriptable-object class TestRepresentations: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5fdeae14da..a0110aa6ec 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -89,7 +89,7 @@ def test_node_uuid_hashing_for_querybuidler(self): # Search for the UUID of the stored node qb = orm.QueryBuilder() qb.append(orm.Data, project=['uuid'], filters={'id': {'==': n.id}}) - [uuid] = qb.first() + uuid = qb.first(flat=True) # Look the node with the previously returned UUID qb = orm.QueryBuilder() @@ -99,7 +99,7 @@ def test_node_uuid_hashing_for_querybuidler(self): qb.all() # And that the results are correct assert qb.count() == 1 - assert qb.first()[0] == n.id + assert qb.first(flat=True) == n.id @staticmethod def create_folderdata_with_empty_file(): diff --git a/tests/tools/archive/orm/test_computers.py b/tests/tools/archive/orm/test_computers.py index 3f1513e03c..fc0a652c7d 100644 --- a/tests/tools/archive/orm/test_computers.py +++ b/tests/tools/archive/orm/test_computers.py @@ -79,17 +79,17 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost): builder = orm.QueryBuilder() builder.append(orm.CalcJobNode, project=['label']) assert builder.count() == 1, 'Only one calculation should be found.' - assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.' + assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.' # Check that the referenced computer is imported correctly. builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label', 'uuid', 'id']) assert builder.count() == 1, 'Only one computer should be found.' - assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' - assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' + assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object + assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object # Store the id of the computer - comp_id = builder.first()[2] + comp_id = builder.first()[2] # pylint: disable=unsubscriptable-object # Import the second calculation import_archive(filename2) @@ -99,9 +99,9 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost): builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label', 'uuid', 'id']) assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.' - assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' - assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' - assert builder.first()[2] == comp_id, 'The computer id is not correct.' + assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object + assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object + assert builder.first()[2] == comp_id, 'The computer id is not correct.' # pylint: disable=unsubscriptable-object # Check that now you have two calculations attached to the same # computer. @@ -175,13 +175,13 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid builder = orm.QueryBuilder() builder.append(orm.CalcJobNode, project=['label']) assert builder.count() == 1, 'Only one calculation should be found.' - assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.' + assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.' # Check that the referenced computer is imported correctly. builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label', 'uuid', 'id']) assert builder.count() == 1, 'Only one computer should be found.' - assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' + assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object # Import the second calculation import_archive(filename2) @@ -191,7 +191,7 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label']) assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.' - assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' + assert str(builder.first(flat=True)) == comp1_name, 'The computer name is not correct.' def test_different_computer_same_name_import(tmp_path, aiida_profile_clean, aiida_localhost_factory):