Skip to content

Commit

Permalink
changed groupby function to accept multiple attributes #1839 (#1846)
Browse files Browse the repository at this point in the history
* changed groupby function to accept multiple attributes #1839

* Added version number to versionadded field

* erroneously changed versionadded field instead of adding versionchanged

* Added tests for groupby function using multiple attributes as arguments

* Updated documentation of groupby function. Updated CHANGELOG and AUTHORS

* Corrected formatting of versionchanged field

* multiple functions are passed to groupby as a list. Updated tests acordingly

* updated docs and CHANGELOG to latest version of PR

* updated CHANGELOG

* Modified output of groupby to flattened dict. Added dictionary flattening function to utils

* tests updated according to changes to groupby

* Added tests for flatten_dict function

* updated CHANGELOG

* now testing for string_types in groupby. explicit cast on groupby tests

* added test on string types to groupby tests

* documentation formatting

* added check for and decode of bytes string

* added bytes string compatibility

* cleaned code and style formatting

* corrected issues with output when using one attribute
  • Loading branch information
davidercruz authored and jbarnoud committed Mar 28, 2018
1 parent 5a232a3 commit 8f950b6
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 21 deletions.
1 change: 1 addition & 0 deletions package/AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Chronological list of authors
- Navya Khare
- Johannes Zeman
- Ayush Suhane
- Davide Cruz

External code
-------------
Expand Down
7 changes: 6 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@ The rules for this file:

------------------------------------------------------------------------------
mm/dd/18 richardjgowers, palnabarun, bieniekmateusz, kain88-de, orbeckst,
xiki-tempula, navyakhare, zemanj, ayushsuhane
xiki-tempula, navyakhare, zemanj, ayushsuhane, davidercruz

* 0.17.1

Enhancements
* Added flatten_dict function that flattens nested dicts into shallow
dicts with tuples as keys.
* Can now pass multiple attributes as a list to groupby function.
Eg ag.groupby(["masses","charges"])
(Issue #1839)
* Added reading of record types (ATOM/HETATM) for PDB, PDBQT and PQR formats
(Issue #1753)
* Added Universe.copy to allow creation of an independent copy of a Universe
Expand Down
67 changes: 52 additions & 15 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,32 +1039,69 @@ def wrap(self, compound="atoms", center="com", box=None):
if not all(s == 0.0):
o.atoms.translate(s)

def groupby(self, topattr):
def groupby(self, topattrs):
"""Group together items in this group according to values of *topattr*
Parameters
----------
topattr: str
Topology attribute to group components by.
topattrs: str or list
One or more topology attribute to group components by.
Single arguments are passed as a string. Multiple arguments
are passed as a list.
Returns
-------
dict
Unique values of the topology attribute as keys, Groups as values.
Unique values of the multiple combinations of topology attributes
as keys, Groups as values.
Example
-------
To group atoms with the same mass together::
>>> ag.groupby('masses')
{12.010999999999999: <AtomGroup with 462 atoms>,
14.007: <AtomGroup with 116 atoms>,
15.999000000000001: <AtomGroup with 134 atoms>}
-------
To group atoms with the same mass together:
>>> ag.groupby('masses')
{12.010999999999999: <AtomGroup with 462 atoms>,
14.007: <AtomGroup with 116 atoms>,
15.999000000000001: <AtomGroup with 134 atoms>}
To group atoms with the same residue name and mass together:
>>> ag.groupby(['resnames', 'masses'])
{('ALA', 1.008): <AtomGroup with 95 atoms>,
('ALA', 12.011): <AtomGroup with 57 atoms>,
('ALA', 14.007): <AtomGroup with 19 atoms>,
('ALA', 15.999): <AtomGroup with 19 atoms>},
('ARG', 1.008): <AtomGroup with 169 atoms>,
...}
>>> ag.groupby(['resnames', 'masses'])('ALA', 15.999)
<AtomGroup with 19 atoms>
.. versionadded:: 0.16.0
"""
ta = getattr(self, topattr)
return {i: self[ta == i] for i in set(ta)}
.. versionchanged:: 0.18.0 The function accepts multiple attributes
"""

res = dict()

if isinstance(topattrs, (string_types, bytes)):
attr=topattrs
if isinstance(topattrs, bytes):
attr = topattrs.decode('utf-8')
ta = getattr(self, attr)

return {i: self[ta == i] for i in set(ta)}

else:
attr = topattrs[0]
ta = getattr(self, attr)
for i in set(ta):
if len(topattrs) == 1:
res[i] = self[ta == i]
else:
res[i] = self[ta == i].groupby(topattrs[1:])

return util.flatten_dict(res)


@_only_same_level
def concatenate(self, other):
Expand Down
32 changes: 32 additions & 0 deletions package/MDAnalysis/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
import re
import io
import warnings
import collections
from functools import wraps
import mmtf
import numpy as np
Expand Down Expand Up @@ -1677,3 +1678,34 @@ def ltruncate_int(value, ndigits):
1234
"""
return int(str(value)[-ndigits:])


def flatten_dict(d, parent_key=tuple()):
"""Flatten a nested dict `d` into a shallow dict with tuples as keys.
Parameters
----------
d : dict
Returns
-------
dict
Note
-----
Based on https://stackoverflow.com/a/6027615/ by user https://stackoverflow.com/users/1897/imran
.. versionadded:: 0.18.0
"""

items = []
for k, v in d.items():
if type(k) != tuple:
new_key = parent_key + (k, )
else:
new_key = parent_key + k
if isinstance(v, collections.MutableMapping):
items.extend(flatten_dict(v, new_key).items())
else:
items.append((new_key, v))
return dict(items)
50 changes: 45 additions & 5 deletions testsuite/MDAnalysisTests/core/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,8 @@ class TestGroupBy(object):
# tests for the method 'groupby'
@pytest.fixture()
def u(self):
return make_Universe(('types', 'charges', 'resids'))
return make_Universe(('segids', 'charges', 'resids'))


def test_groupby_float(self, u):
gb = u.atoms.groupby('charges')
Expand All @@ -587,14 +588,15 @@ def test_groupby_float(self, u):
assert all(g.charges == ref)
assert len(g) == 25

def test_groupby_string(self, u):
gb = u.atoms.groupby('types')
@pytest.mark.parametrize('string', ['segids', b'segids', u'segids'])
def test_groupby_string(self, u, string):
gb = u.atoms.groupby(string)

assert len(gb) == 5
for ref in ['TypeA', 'TypeB', 'TypeC', 'TypeD', 'TypeE']:
for ref in ['SegA', 'SegB', 'SegC', 'SegD', 'SegE']:
assert ref in gb
g = gb[ref]
assert all(g.types == ref)
assert all(g.segids == ref)
assert len(g) == 25

def test_groupby_int(self, u):
Expand All @@ -603,6 +605,44 @@ def test_groupby_int(self, u):
for g in gb.values():
assert len(g) == 5

# tests for multiple attributes as arguments

def test_groupby_float_string(self, u):
gb = u.atoms.groupby(['charges', 'segids'])

for ref in [-1.5, -0.5, 0.0, 0.5, 1.5]:
for subref in ['SegA','SegB','SegC','SegD','SegE']:
assert (ref, subref) in gb.keys()
a = gb[(ref, subref)]
assert len(a) == 5
assert all(a.charges == ref)
assert all(a.segids == subref)

def test_groupby_int_float(self, u):
gb = u.atoms.groupby(['resids', 'charges'])

uplim=int(len(gb)/5+1)
for ref in range(1, uplim):
for subref in [-1.5, -0.5, 0.0, 0.5, 1.5]:
assert (ref, subref) in gb.keys()
a = gb[(ref, subref)]
assert len(a) == 1
assert all(a.resids == ref)
assert all(a.charges == subref)

def test_groupby_string_int(self, u):
gb = u.atoms.groupby(['segids', 'resids'])

assert len(gb) == 25
res = 1
for ref in ['SegA','SegB','SegC','SegD','SegE']:
for subref in range(0, 5):
assert (ref, res) in gb.keys()
a = gb[(ref, res)]
assert all(a.segids == ref)
assert all(a.resids == res)
res += 1


class TestReprs(object):
@pytest.fixture()
Expand Down
18 changes: 18 additions & 0 deletions testsuite/MDAnalysisTests/lib/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,3 +1060,21 @@ class TestTruncateInteger(object):
])
def test_ltruncate_int(self, a, b):
assert util.ltruncate_int(*a) == b

class TestFlattenDict(object):
def test_flatten_dict(self):
d = {
'A' : { 1 : ('a', 'b', 'c')},
'B' : { 2 : ('c', 'd', 'e')},
'C' : { 3 : ('f', 'g', 'h')}
}
result = util.flatten_dict(d)

for k in result:
assert type(k) == tuple
assert len(k) == 2
assert k[0] in d
assert k[1] in d[k[0]]
assert result[k] in d[k[0]].values()


0 comments on commit 8f950b6

Please sign in to comment.