Skip to content

Commit

Permalink
Merge pull request #24 from Teichlab/devel
Browse files Browse the repository at this point in the history
  • Loading branch information
emdann authored Dec 13, 2022
2 parents da9eb32 + ec4b868 commit 2a7cb64
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 31 deletions.
50 changes: 29 additions & 21 deletions src/multi_view_atlas/tl/MultiViewAtlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,19 @@ def __init__(
if "full" not in mdata["full"].uns["view_hierarchy"].keys():
mdata["full"].uns["view_hierarchy"] = {"full": mdata["full"].uns["view_hierarchy"]}

if "view_assign" not in mdata.obsm:
try:
mdata.obsm["view_assign"] = mdata["full"].obsm["view_assign"]
except KeyError:
view_assign = pd.DataFrame(index=mdata["full"].obs_names)
for k, v in mdata.mod.items():
view_assign[k] = view_assign.index.isin(v.obs_names)
view_assign = view_assign.astype("int")
mdata["full"].obsm["view_assign"] = view_assign
_clean_view_assignment(mdata["full"])
mdata.obsm["view_assign"] = view_assign
# Build view assignment
# if "view_assign" not in mdata.obsm:
# try:
# mdata.obsm["view_assign"] = mdata["full"].obsm["view_assign"]
# except KeyError:
view_assign = pd.DataFrame(
np.vstack([mdata.obsm[v] for v in mdata.mod.keys()]).T.astype("int"),
index=mdata.obs_names,
columns=mdata.mod.keys(),
)

mdata.obsm["view_assign"] = view_assign
_clean_view_assignment(mdata)

# Remove var and X from views
for k in mdata.mod.keys():
Expand Down Expand Up @@ -345,16 +347,22 @@ def _dict_set_nested(d, keys, value):
def _harmonize_mdata_full(mva: MultiViewAtlas):
"""Harmonize info in mdata common slots and mdata['full']"""
# Harmonize view assignment table
try:
view_assign_key_full = [x for x in mva.mdata["full"].obsm_keys() if "view_assign" in x][0]
except IndexError:
raise AssertionError("mva.mdata['full'] does not contain a view assignment table")

full_view_assign = mva.mdata["full"].obsm[view_assign_key_full].copy()
missing_cols = np.setdiff1d(mva.mdata.obsm["view_assign"].columns, full_view_assign.columns)
if len(missing_cols) > 0:
for c in missing_cols:
mva.mdata["full"].obsm[view_assign_key_full].loc[:, c] = mva.mdata.obsm["view_assign"][c].copy()
if "view_assign" in mva.mdata["full"].obsm.keys():
view_assign_key_full = "view_assign"
elif "view_assign_full" in mva.mdata["full"].obsm.keys():
view_assign_key_full = "view_assign_full"
else:
view_assign_key_full = None

if view_assign_key_full is not None:
full_view_assign = mva.mdata["full"].obsm[view_assign_key_full].copy()
missing_cols = np.setdiff1d(mva.mdata.obsm["view_assign"].columns, full_view_assign.columns)
if len(missing_cols) > 0:
for c in missing_cols:
mva.mdata["full"].obsm[view_assign_key_full].loc[:, c] = mva.mdata.obsm["view_assign"][c].copy()
else:
mva.mdata["full"].obsm["view_assign"] = mva.mdata.obsm["view_assign"].copy()
view_assign_key_full = "view_assign"

# Reorder columns
mva.mdata["full"].obsm[view_assign_key_full] = mva.mdata["full"].obsm[view_assign_key_full][
Expand Down
39 changes: 29 additions & 10 deletions src/multi_view_atlas/tl/map_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.neighbors import KNeighborsClassifier

from ..utils import check_transition_rule, get_views_from_structure
from .MultiViewAtlas import MultiViewAtlas
from .MultiViewAtlas import MultiViewAtlas, _harmonize_mdata_full

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
Expand Down Expand Up @@ -64,8 +64,17 @@ def load_query(
"""
)

mvatlas_mapped = mvatlas.copy()
mvatlas_mapped.mdata.mod["full"] = vdata_full.copy()
mdata = mvatlas.mdata.copy()
mdata.mod["full"] = vdata_full.copy()
try:
mdata.update()
except KeyError:
mdata.update()
del mdata.obsm["view_assign"]
mvatlas_mapped = MultiViewAtlas(mdata, rename_obsm=False)
mvatlas_mapped.view_transition_rule = mvatlas.view_transition_rule.copy()

_harmonize_mdata_full(mvatlas_mapped)
return mvatlas_mapped


Expand Down Expand Up @@ -102,17 +111,25 @@ def split_query(
depth = row["depth"]
current_view = row["parent_view"]
next_view = row["child_view"]
if "dataset_group" in vdata_dict[current_view].obs:
try:
n_query_current = sum(vdata_dict[current_view].obs["dataset_group"] == "query")
except KeyError:
n_query_current = 0
try:
n_query_next = sum(mvatlas_mapped.mdata[next_view].obs["dataset_group"] == "query")
except KeyError:
n_query_next = 0
# if "dataset_group" in vdata_dict[current_view].obs:
if n_query_current > 0:
adata_query = vdata_dict[current_view][vdata_dict[current_view].obs["dataset_group"] == "query"].copy()
logging.info(f"Assigning to {next_view} from {current_view} with rule {row['transition_rule']}")
# print(adata_query)
# print(vdata_dict[current_view])
# print(mvatlas_mapped.mdata[current_view])
if "dataset_group" in mvatlas_mapped.mdata[next_view].obs:
if sum(mvatlas_mapped.mdata[next_view].obs["dataset_group"] == "query") > 0:
logging.info(f"Query cells already in {next_view}")
v_assign = mvatlas_mapped.mdata.obsm["view_assign"][[next_view]]
vdata_dict[next_view] = mvatlas_mapped.mdata[next_view].copy()
if n_query_next > 0:
logging.info(f"Query cells already in {next_view}")
v_assign = mvatlas_mapped.mdata.obsm["view_assign"][[next_view]]
vdata_dict[next_view] = mvatlas_mapped.mdata[next_view].copy()
else:
adata_query_concat = AnnData(obs=adata_query.obs, obsm=adata_query.obsm, obsp=adata_query.obsp)
if depth > 0:
Expand Down Expand Up @@ -205,7 +222,9 @@ def map_next_view(
# next_view_adata = next_view_adata[next_view_adata.obs[batch_key] == batch_categories[0]].copy()
# assert "dataset_group" not in next_view_adata.obs.columns
else:
v_assign = mvatlas.mdata.obsm["view_assign"][[next_view]]
v_assign = mvatlas.mdata.obsm["view_assign"].loc[mvatlas.mdata["full"].obs[batch_key] == batch_categories[0]][
[next_view]
]
transition_rule = mvatlas.view_transition_rule[current_view][next_view]
if transition_rule is not None:
try:
Expand Down

0 comments on commit 2a7cb64

Please sign in to comment.