diff --git a/CHANGELOG.md b/CHANGELOG.md index 39f47b4..e854d3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,9 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The visualisation notebook now displays the protein with py3Dmol. Some examples for creating and displaying a graph from the interaction dataframe have been added - Updated the installation instructions to show how to install a specific release +- The previous repr method of `ResidueId` was easy to confuse with a string, especially + when trying to access the `Fingerprint.ifp` results by string. The new repr method is + now more explicit. ### Deprecated ### Removed ### Fixed +- `Fingerprint.to_dataframe` is now much faster (Issue #7) ## [0.3.0] - 2020-12-23 ### Added diff --git a/prolif/residue.py b/prolif/residue.py index 1e1d2ba..191331f 100644 --- a/prolif/residue.py +++ b/prolif/residue.py @@ -36,6 +36,9 @@ def __init__(self, self.resid += f".{self.chain}" def __repr__(self): + return f"ResidueId({self.name}, {self.number}, {self.chain})" + + def __str__(self): return self.resid def __hash__(self): diff --git a/prolif/utils.py b/prolif/utils.py index 9857707..e05f86c 100644 --- a/prolif/utils.py +++ b/prolif/utils.py @@ -3,7 +3,9 @@ ======================================== """ from math import pi +from collections import defaultdict from collections.abc import Iterable +from copy import deepcopy import numpy as np import pandas as pd from scipy.spatial import cKDTree @@ -171,33 +173,53 @@ def to_dataframe(ifp, interactions, index_col="Frame", dtype=None, ... """ + ifp = deepcopy(ifp) n_interactions = len(interactions) - data = pd.DataFrame(ifp) - data.set_index(index_col, inplace=True) - # sort columns by ResidueIds and interaction - data.sort_index(axis=1, inplace=True) - data.columns = pd.MultiIndex.from_tuples(data.columns) - # check if dealing with single values or atom indices - value = data.values[0, 0][0] - is_iterable = isinstance(value, Iterable) - # replace NaNs with appropriate values empty_value = dtype(False) if dtype else False - fill_value = [None, None] if is_iterable else empty_value - data = data.applymap(lambda x: [fill_value] * n_interactions - if x is np.nan else x) - # split each bitvector in separate columns for each interaction - df = pd.DataFrame() - for l, p in data.columns: - cols = [(str(l), str(p), i) for i in interactions] - df[cols] = data[(l, p)].apply(pd.Series) - df.columns = pd.MultiIndex.from_tuples( - df.columns, names=["ligand", "protein", "interaction"]) + # residue pairs + keys = sorted(set([k for d in ifp for k in d.keys() if k != index_col])) + # check if each interaction value is a list of atom indices or smthg else + for k in keys: + if k in ifp[0].keys(): + break + is_atompair = isinstance(ifp[0][k][0], Iterable) + # create empty array for each residue pair interaction that doesn't exist + # in a particular frame + if is_atompair: + empty_arr = [[None, None]] * n_interactions + else: + empty_arr = np.array([empty_value] * n_interactions) + # sparse to dense + data = defaultdict(list) + index = [] + for d in ifp: + index.append(d.pop(index_col)) + for key in keys: + try: + data[key].append(d[key]) + except KeyError: + data[key].append(empty_arr) + # create dataframes + values = np.array([np.hstack([np.ravel(a[i]) for a in data.values()]) + for i in range(len(index))]) + if is_atompair: + columns = pd.MultiIndex.from_tuples([(str(k[0]), str(k[1]), i, a) for k in keys + for i in interactions for a in ["ligand", "protein"]], + names=["ligand", "protein", "interaction", "atom"]) + else: + columns = pd.MultiIndex.from_tuples([(str(k[0]), str(k[1]), i) for k in keys + for i in interactions], + names=["ligand", "protein", "interaction"]) + index = pd.Series(index, name=index_col) + df = pd.DataFrame(values, columns=columns, index=index) + if is_atompair: + df = df.groupby(axis=1, level=["ligand", "protein", "interaction"]).agg(tuple) if dtype: df = df.astype(dtype) if drop_empty: - if is_iterable: + if is_atompair: mask = df.apply(lambda s: - ~(s.map(tuple).isin([(None, None)]).all()), axis=0) + ~(s.isin([(None, None)]).all()), axis=0) else: mask = (df != empty_value).any(axis=0) df = df.loc[:, mask] diff --git a/tests/test_residues.py b/tests/test_residues.py index b4351cf..f78c0d6 100644 --- a/tests/test_residues.py +++ b/tests/test_residues.py @@ -115,6 +115,17 @@ def test_lt(self, res1, res2): res2 = ResidueId.from_string(res2) assert res1 < res2 + @pytest.mark.parametrize("resid_str", [ + "ALA1.A", + "DA2.B", + "HIS3", + "GLU", + ]) + def test_repr(self, resid_str): + resid = ResidueId.from_string(resid_str) + expected = f"ResidueId({resid.name}, {resid.number}, {resid.chain})" + assert repr(resid) == expected + class TestResidue(TestBaseRDKitMol): @pytest.fixture(scope="class") diff --git a/tests/test_utils.py b/tests/test_utils.py index fcc6a07..c359e3d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -133,13 +133,13 @@ def test_to_df_atom_pairs(): assert df.shape == (2, 4) assert df.index.name == "Frame" assert ("LIG", "ALA1", "A") in df.columns - assert df[("LIG", "ALA1", "A")][0] == [0, 1] + assert df[("LIG", "ALA1", "A")][0] == (0, 1) assert ("LIG", "ALA1", "B") in df.columns - assert df[("LIG", "ALA1", "B")][0] == [None, None] + assert df[("LIG", "ALA1", "B")][0] == (None, None) assert ("LIG", "ALA1", "C") not in df.columns assert ("LIG", "GLU2", "A") not in df.columns assert ("LIG", "ASP3", "B") in df.columns - assert df[("LIG", "ASP3", "B")][0] == [None, None] + assert df[("LIG", "ASP3", "B")][0] == (None, None) @pytest.mark.parametrize("dtype", [