Skip to content

Commit

Permalink
1. update pool_decorator 2. convert lambda func to def func to avoid …
Browse files Browse the repository at this point in the history
…conflicts with pickle
  • Loading branch information
Yaoyx committed Oct 8, 2024
1 parent a4d5c6d commit 3472949
Showing 1 changed file with 1 addition and 113 deletions.
114 changes: 1 addition & 113 deletions tests/test_expected.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,116 +732,4 @@ def test_diagsum_from_array():
exp = _diagsum_symm_dense(ar, bad_bins=list(range(3, 5)))
exp1 = diagsum_from_array(ar, ignore_diags=0)
exp1["balanced.avg"] = exp1["balanced.sum"] / exp1["n_valid"]
assert np.allclose(exp, exp1["balanced.avg"].values, equal_nan=True)

from concurrent.futures import ProcessPoolExecutor
def test_multiprocessing_expected_cis(request):
# perform test:
clr = cooler.Cooler(op.join(request.fspath.dirname, "data/CN.mm9.1000kb.cool"))
# symm result - engaging diagsum_symm
with ProcessPoolExecutor(8) as p:
res_symm = cooltools.api.expected.expected_cis(
clr,
view_df=view_df,
clr_weight_name=clr_weight_name,
chunksize=chunksize,
ignore_diags=ignore_diags,
map_functor=p.map
)

# check column names
assert list(res_symm.columns) == [
"region1",
"region2",
"dist",
"dist_bp",
"contact_frequency",
"n_total",
"n_valid",
"count.sum",
"balanced.sum",
"count.avg",
"balanced.avg",
"balanced.avg.smoothed",
"balanced.avg.smoothed.agg",
]

# check results for every block
grouped = res_symm.groupby(["region1", "region2"])
for (name1, name2), group in grouped:
assert name1 == name2
matrix = clr.matrix(balance=clr_weight_name).fetch(name1)
desired_expected = _diagsum_symm_dense(matrix)
# fill nan for ignored diags
desired_expected = np.where(
group["dist"] < ignore_diags, np.nan, desired_expected
)
testing.assert_allclose(
actual=group["balanced.avg"].values,
desired=desired_expected,
equal_nan=True,
)

# check column names, when clr_weight_name = None, which is the unbalanced case
with ProcessPoolExecutor(8) as p:
res_symm = cooltools.api.expected.expected_cis(
clr,
view_df=view_df,
clr_weight_name=None,
chunksize=chunksize,
ignore_diags=ignore_diags,
map_functor=p.map
)
assert list(res_symm.columns) == [
"region1",
"region2",
"dist",
"dist_bp",
"contact_frequency",
"n_total",
"n_valid",
"count.sum",
"count.avg",
"count.avg.smoothed",
"count.avg.smoothed.agg",
]

# asymm and symm result together - engaging diagsum_pairwise
res_all = cooltools.api.expected.expected_cis(
clr,
view_df=view_df,
intra_only=False,
clr_weight_name=clr_weight_name,
chunksize=chunksize,
ignore_diags=ignore_diags,
)
# check results for every block
grouped = res_all.groupby(["region1", "region2"])
for (name1, name2), group in grouped:
matrix = clr.matrix(balance=clr_weight_name).fetch(name1, name2)
desired_expected = (
_diagsum_asymm_dense(matrix)
if (name1 != name2)
else _diagsum_symm_dense(matrix)
)
# fill nan for ignored diags
desired_expected = np.where(
group["dist"] < ignore_diags, np.nan, desired_expected
)
testing.assert_allclose(
actual=group["balanced.avg"].values,
desired=desired_expected,
equal_nan=True,
)

# check multiprocessed result
res_all_pooled = cooltools.api.expected.expected_cis(
clr,
view_df=view_df,
intra_only=False,
clr_weight_name=clr_weight_name,
chunksize=chunksize,
ignore_diags=ignore_diags,
nproc=3,
)
assert res_all.equals(res_all_pooled)
assert np.allclose(exp, exp1["balanced.avg"].values, equal_nan=True)

0 comments on commit 3472949

Please sign in to comment.