Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QueryBuilder: add flat keyword to first method #5410

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions aiida/orm/nodes/data/upf.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T
md5sum = md5_file(filename)
builder = orm.QueryBuilder(backend=backend)
builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}})
existing_upf = builder.first()
existing_upf = builder.first(flat=True)

if existing_upf is None:
# return the upfdata instances, not stored
Expand All @@ -133,7 +133,6 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T
else:
if stop_if_existing:
raise ValueError(f'A UPF with identical MD5 to {filename} cannot be added with stop_if_existing')
existing_upf = existing_upf[0]
pseudo_and_created.append((existing_upf, False))

# check whether pseudo are unique per element
Expand Down
30 changes: 25 additions & 5 deletions aiida/orm/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance
when instantiated by the user.
"""
from __future__ import annotations

from copy import deepcopy
from inspect import isclass as inspect_isclass
from typing import (
Expand All @@ -27,6 +29,7 @@
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Expand All @@ -35,6 +38,7 @@
Type,
Union,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -989,20 +993,36 @@ def _get_aiida_entity_res(value) -> Any:
except TypeError:
return value

def first(self) -> Optional[List[Any]]:
"""Executes the query, asking for the first row of results.
@overload
def first(self, flat: Literal[False]) -> Optional[list[Any]]:
...

@overload
def first(self, flat: Literal[True]) -> Optional[Any]:
...

def first(self, flat: bool = False) -> Optional[list[Any] | Any]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def first(self, flat: bool = False) -> Optional[list[Any] | Any]:
def first(self, flat: bool = False) -> None | list[Any] | Any:

Will approve anyway, but you are kind of mixing old and new type annotations in these changes

"""Return the first result of the query.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think it is important to indicate this actually executes a query, i.e. connects to the storage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a note to the docstring, in the vein of the SQLAlchemy docs.


Note, this may change if several rows are valid for the query,
as persistent ordering is not guaranteed unless explicitly specified.
Calling ``first`` results in an execution of the underlying query.

Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless
explicitly specified.

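:param flat: if True and the query has exactly one projection, return the projected value itself instead of a
one-element list.
:returns: One row of results as a list, or the single projected value if ``flat=True`` and there is exactly one
projection, or None if the query returned no result.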
"""
result = self._impl.first(self.as_dict())

if result is None:
return None
Comment on lines 1017 to 1018
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sphuber it can return None here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I saw that, but take the first hit aiida/orm/nodes/data/upf.py:136:27:

        existing_upf = builder.first()

        if existing_upf is None:
            ...
        else:
            existing_upf = existing_upf[0]

That looks like a false positive to me. If anything this would be the perfect candidate for first(flat=True).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same goes for the next "hits"

tests/orm/test_querybuilder.py:724:26: E1136: Value 'result' is unsubscriptable (unsubscriptable-object)
tests/orm/test_querybuilder.py:725:26: E1136: Value 'result' is unsubscriptable (unsubscriptable-object)

Also false positives.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair. In that case, IMO, I would just disable unsubscriptable-object, because pylint does not seem to be doing a good job with it: pylint-dev/pylint#1498
I don't feel pylint should be trying to involve itself in type-checking, since that is what mypy is for, and mypy does it a lot better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eurgh, it could also be caused by the overload, or by the future annotations import: pylint-dev/pylint#5189, pylint-dev/pylint#3979, pylint-dev/pylint#4369
Silly pylint


return [self._get_aiida_entity_res(rowitem) for rowitem in result]
result = [self._get_aiida_entity_res(rowitem) for rowitem in result]

if flat and len(result) == 1:
return result[0]

return result

def count(self) -> int:
"""
Expand Down
4 changes: 2 additions & 2 deletions aiida/restapi/translator/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _get_content(self):
return {}

# otherwise ...
node = self.qbobj.first()[0]
node = self.qbobj.first()[0] # pylint: disable=unsubscriptable-object

# content/attributes
if self._content_type == 'attributes':
Expand Down Expand Up @@ -643,7 +643,7 @@ def get_node_description(node):
nodes = []

if qb_obj.count() > 0:
main_node = qb_obj.first()[0]
main_node = qb_obj.first(flat=True)
pk = main_node.pk
uuid = main_node.uuid
nodetype = main_node.node_type
Expand Down
4 changes: 2 additions & 2 deletions tests/orm/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_group_uuid_hashing_for_querybuidler(self):
# Search for the UUID of the stored group
builder = orm.QueryBuilder()
builder.append(orm.Group, project=['uuid'], filters={'label': {'==': 'test_group'}})
[uuid] = builder.first()
uuid = builder.first(flat=True)

# Look the node with the previously returned UUID
builder = orm.QueryBuilder()
Expand All @@ -279,7 +279,7 @@ def test_group_uuid_hashing_for_querybuidler(self):

# And that the results are correct
assert builder.count() == 1
assert builder.first()[0] == group.id
assert builder.first(flat=True) == group.id


@pytest.mark.usefixtures('aiida_profile_clean')
Expand Down
28 changes: 22 additions & 6 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def test_direction_keyword(self):
assert res2 == {d2.id, d4.id}

@staticmethod
def test_flat():
def test_all_flat():
"""Test the `flat` keyword for the `QueryBuilder.all()` method."""
pks = []
uuids = []
Expand All @@ -665,13 +665,26 @@ def test_flat():
assert len(result) == 10
assert result == pks

# Mutltiple projections
# Multiple projections
builder = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid']).order_by({orm.Data: 'id'})
result = builder.all(flat=True)
assert isinstance(result, list)
assert len(result) == 20
assert result == list(chain.from_iterable(zip(pks, uuids)))

@staticmethod
def test_first_flat():
"""Test the `flat` keyword for the `QueryBuilder.first()` method."""
node = orm.Data().store()

# Single projected property
query = orm.QueryBuilder().append(orm.Data, project='id', filters={'id': node.pk})
assert query.first(flat=True) == node.pk

# Multiple projections
query = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid'], filters={'id': node.pk})
assert query.first(flat=True) == [node.pk, node.uuid]

def test_query_links(self):
"""Test querying for links"""
d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)]
Expand Down Expand Up @@ -703,13 +716,16 @@ def test_first_multiple_projections(self):
orm.Data().store()
orm.Data().store()

result = orm.QueryBuilder().append(orm.User, tag='user',
project=['email']).append(orm.Data, with_user='user', project=['*']).first()
query = orm.QueryBuilder()
query.append(orm.User, tag='user', project=['email'])
query.append(orm.Data, with_user='user', project=['*'])

result = query.first()

assert isinstance(result, list)
assert len(result) == 2
assert isinstance(result[0], str)
assert isinstance(result[1], orm.Data)
assert isinstance(result[0], str) # pylint: disable=unsubscriptable-object
assert isinstance(result[1], orm.Data) # pylint: disable=unsubscriptable-object


class TestRepresentations:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_node_uuid_hashing_for_querybuidler(self):
# Search for the UUID of the stored node
qb = orm.QueryBuilder()
qb.append(orm.Data, project=['uuid'], filters={'id': {'==': n.id}})
[uuid] = qb.first()
uuid = qb.first(flat=True)

# Look the node with the previously returned UUID
qb = orm.QueryBuilder()
Expand All @@ -99,7 +99,7 @@ def test_node_uuid_hashing_for_querybuidler(self):
qb.all()
# And that the results are correct
assert qb.count() == 1
assert qb.first()[0] == n.id
assert qb.first(flat=True) == n.id

@staticmethod
def create_folderdata_with_empty_file():
Expand Down
20 changes: 10 additions & 10 deletions tests/tools/archive/orm/test_computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,17 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost):
builder = orm.QueryBuilder()
builder.append(orm.CalcJobNode, project=['label'])
assert builder.count() == 1, 'Only one calculation should be found.'
assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.'
assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.'

# Check that the referenced computer is imported correctly.
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label', 'uuid', 'id'])
assert builder.count() == 1, 'Only one computer should be found.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.'
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object

# Store the id of the computer
comp_id = builder.first()[2]
comp_id = builder.first()[2] # pylint: disable=unsubscriptable-object

# Import the second calculation
import_archive(filename2)
Expand All @@ -99,9 +99,9 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost):
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label', 'uuid', 'id'])
assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.'
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.'
assert builder.first()[2] == comp_id, 'The computer id is not correct.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object
assert builder.first()[2] == comp_id, 'The computer id is not correct.' # pylint: disable=unsubscriptable-object

# Check that now you have two calculations attached to the same
# computer.
Expand Down Expand Up @@ -175,13 +175,13 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid
builder = orm.QueryBuilder()
builder.append(orm.CalcJobNode, project=['label'])
assert builder.count() == 1, 'Only one calculation should be found.'
assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.'
assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.'

# Check that the referenced computer is imported correctly.
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label', 'uuid', 'id'])
assert builder.count() == 1, 'Only one computer should be found.'
assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.'
assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object

# Import the second calculation
import_archive(filename2)
Expand All @@ -191,7 +191,7 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label'])
assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.'
assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.'
assert str(builder.first(flat=True)) == comp1_name, 'The computer name is not correct.'


def test_different_computer_same_name_import(tmp_path, aiida_profile_clean, aiida_localhost_factory):
Expand Down