Skip to content

Commit

Permalink
Add more tests for filter_obs()
Browse files Browse the repository at this point in the history
  • Loading branch information
gtca committed Oct 24, 2024
1 parent cc3f551 commit db905d6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
8 changes: 5 additions & 3 deletions muon/_core/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ def func(x):
# will fail due to _validate_value()
attrm = dict(attrm)
attrp = dict(attrp)
layers = dict(data.layers)

# Subset .obs/.var
setattr(data, f"_{attr}", df[subset])
Expand Down Expand Up @@ -773,11 +774,12 @@ def func(x):
data.filename = None

# Subset layers
for layer in data.layers:
for layer in layers:
if attr == "obs":
data.layers[layer] = data.layers[layer][subset, :]
layers[layer] = layers[layer][subset, :]
else:
data.layers[layer] = data.layers[layer][:, subset]
layers[layer] = layers[layer][:, subset]
data.layers = layers

# Subset raw - only when subsetting obs
if attr == "obs" and data.raw is not None:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_muon_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,45 @@ def test_filter_obs_with_obsm_obsp(self, pbmc3k_processed):
assert_equal(mdata["A"], A_subset)
assert_equal(mdata["B"], B_subset)

def test_filter_obs_with_obsm_obsp_explicit(self, mdata):
mdata = mdata.copy()

# obsm
np.random.seed(42)
mdata["mod1"].obsm["X_normal"] = np.random.normal(size=(mdata["mod1"].n_obs, 10))
mdata["mod2"].obsm["X_normal"] = np.random.normal(size=(mdata["mod2"].n_obs, 10))
mdata.obsm["X_normal"] = np.random.normal(size=(mdata.n_obs, 10))
selection = mdata.obsm["X_normal"].sum(axis=1) > 0

# obsp
mdata["mod1"].obsp["connectivities"] = np.random.normal(
size=(mdata["mod1"].n_obs, mdata["mod1"].n_obs)
)
mdata["mod2"].obsp["connectivities"] = np.random.normal(
size=(mdata["mod2"].n_obs, mdata["mod2"].n_obs)
)
mdata.obsp["connectivities"] = np.random.normal(size=(mdata.n_obs, mdata.n_obs))

mu.pp.filter_obs(mdata, selection)
assert mdata.n_obs == selection.sum()

def test_filter_obs_anndata(self, mdata):
adata = mdata["mod1"].copy()

# layers
adata.layers["X2"] = adata.X**2

# obsm
np.random.seed(42)
adata.obsm["X_normal"] = np.random.normal(size=(adata.n_obs, 10))
selection = adata.obsm["X_normal"].sum(axis=1) > 0

# obsp
adata.obsp["connectivities"] = np.random.normal(size=(adata.n_obs, adata.n_obs))

mu.pp.filter_obs(adata, selection)
assert adata.n_obs == selection.sum()

# Variables

def test_filter_var_adata(self, mdata, filepath_h5mu):
Expand Down

0 comments on commit db905d6

Please sign in to comment.