Skip to content

Commit

Permalink
ENH add adapt = False (prior_scale=1) option to lfc_shrink() (#267)
Browse files Browse the repository at this point in the history
* added prior_scale 1 option

* test: add a lfc shrinkage test without adaptative fit

* refactor: skip prio_var fitting in lfc_shrink if adapt==False

* test: rename test data file for consistency

---------

Co-authored-by: Boris MUZELLEC <[email protected]>
  • Loading branch information
awalsh17 and BorisMuzellec authored Apr 12, 2024
1 parent fe7e8f5 commit 92f1d09
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pydeseq2/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def run_wald_test(self) -> None:
self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0

def lfc_shrink(self, coeff: Optional[str] = None) -> None:
def lfc_shrink(self, coeff: Optional[str] = None, adapt: bool = True) -> None:
"""LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`.
Shrinks LFCs using a heavy-tailed Cauchy prior, leaving p-values unchanged.
Expand All @@ -343,6 +343,9 @@ def lfc_shrink(self, coeff: Optional[str] = None) -> None:
If the desired coefficient is not available, it may be set from the
:class:`pydeseq2.dds.DeseqDataSet` argument ``ref_level``.
(default: ``None``).
adapt: bool
Whether to use the MLE estimates of LFC to adapt the prior. If False, the
prior scale is set to 1. (``default=True``)
"""
if self.contrast[1] == self.contrast[2] == "":
# The factor being tested is continuous
Expand Down Expand Up @@ -390,8 +393,10 @@ def lfc_shrink(self, coeff: Optional[str] = None) -> None:

# Set priors
prior_no_shrink_scale = 15
prior_var = self._fit_prior_var(coeff_idx=coeff_idx)
prior_scale = np.minimum(np.sqrt(prior_var), 1)
prior_scale = 1
if adapt:
prior_var = self._fit_prior_var(coeff_idx=coeff_idx)
prior_scale = np.minimum(np.sqrt(prior_var), 1)

design_matrix = self.design_matrix.values

Expand Down
11 changes: 11 additions & 0 deletions tests/data/single_factor/r_test_lfc_shrink_no_apeAdapt_res.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"","baseMean","log2FoldChange","lfcSE","pvalue","padj"
"gene1",8.54131729397935,0.591975666632662,0.283593478291824,0.028605778832833,0.06413253606192
"gene2",21.2812387436367,0.528482010716568,0.149036804392133,0.000329116482202811,0.00164558241101406
"gene3",5.01012348853472,-0.589690914974718,0.290216684048141,0.03206626803096,0.06413253606192
"gene4",100.51796142035,-0.406910458525142,0.118064287344954,0.000512946548270916,0.00170982182756972
"gene5",27.1424502740787,0.570819621477817,0.153693224107584,0.000168688601657692,0.00164558241101406
"gene6",5.4130427476525,0.00154719287124319,0.297994906704438,0.996252928924723,0.996252928924723
"gene7",28.2940230404605,0.13158763131402,0.148675587725712,0.370439870361461,0.411599855957179
"gene8",40.3583444203556,-0.266280574143726,0.135444467327823,0.0472273333199321,0.0787122221998868
"gene9",37.1661826339853,-0.20919599744192,0.132384213293892,0.110392490019143,0.143143368673419
"gene10",11.5893249023836,0.366288276213398,0.239712866785749,0.114514694938736,0.143143368673419
45 changes: 45 additions & 0 deletions tests/test_pydeseq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,51 @@ def test_lfc_shrinkage(counts_df, metadata, tol=0.02):
).max() < tol


def test_lfc_shrinkage_no_apeAdapt(counts_df, metadata, tol=0.02):
"""Test that the outputs of the lfc_shrink function match those of the original
R package (starting from the same inputs), up to a tolerance in relative error.
"""

test_path = str(Path(os.path.realpath(tests.__file__)).parent.resolve())
r_res = pd.read_csv(
os.path.join(test_path, "data/single_factor/r_test_res.csv"), index_col=0
)
r_shrunk_res = pd.read_csv(
os.path.join(
test_path, "data/single_factor/r_test_lfc_shrink_no_apeAdapt_res.csv"
),
index_col=0,
)

r_size_factors = pd.read_csv(
os.path.join(test_path, "data/single_factor/r_test_size_factors.csv"),
index_col=0,
).squeeze()

r_dispersions = pd.read_csv(
os.path.join(test_path, "data/single_factor/r_test_dispersions.csv"),
index_col=0,
).squeeze()

dds = DeseqDataSet(counts=counts_df, metadata=metadata, design_factors="condition")
dds.deseq2()
dds.obsm["size_factors"] = r_size_factors.values
dds.varm["dispersions"] = r_dispersions.values
dds.varm["LFC"].iloc[:, 1] = r_res.log2FoldChange.values * np.log(2)

res = DeseqStats(dds)
res.summary()
res.SE = r_res.lfcSE * np.log(2)
res.lfc_shrink(coeff="condition_B_vs_A", adapt=False)
shrunk_res = res.results_df

# Check that the same LFC are found (up to tol)
assert (
abs(r_shrunk_res.log2FoldChange - shrunk_res.log2FoldChange)
/ abs(r_shrunk_res.log2FoldChange)
).max() < tol


def test_iterative_size_factors(counts_df, metadata, tol=0.02):
"""Test that the outputs of the iterative size factor method match those of the
original R package (starting from the same inputs), up to a tolerance in relative
Expand Down

0 comments on commit 92f1d09

Please sign in to comment.