Skip to content

Commit

Permalink
optional arguments for logistic_regression() in mediation_analysis() (#…
Browse files Browse the repository at this point in the history
…245)

* added dict with optional arguments for logistic_regression() to declaration of mediation_analysis()

* Add test to check if `logreg_kwargs` is being passed on to `LogisticRegression`

* Add more tests for `logreg_kwargs`

Co-authored-by: Julian Libiseller-Egger <[email protected]>
  • Loading branch information
julibeg and julibeg authored Mar 19, 2022
1 parent 367c935 commit b08ad14
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
24 changes: 17 additions & 7 deletions pingouin/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def logistic_regression(X, y, coef_only=False, alpha=0.05,


def _point_estimate(X_val, XM_val, M_val, y_val, idx, n_mediator,
mtype='linear'):
mtype='linear', **logreg_kwargs):
"""Point estimate of indirect effect based on bootstrap sample."""
# Mediator(s) model (M(j) ~ X + covar)
beta_m = []
Expand All @@ -926,8 +926,8 @@ def _point_estimate(X_val, XM_val, M_val, y_val, idx, n_mediator,
beta_m.append(linear_regression(X_val[idx], M_val[idx, j],
coef_only=True)[1])
else:
beta_m.append(logistic_regression(X_val[idx], M_val[idx, j],
coef_only=True)[1])
beta_m.append(logistic_regression(
X_val[idx], M_val[idx, j], coef_only=True, **logreg_kwargs)[1])

# Full model (Y ~ X + M + covar)
beta_y = linear_regression(XM_val[idx], y_val[idx],
Expand Down Expand Up @@ -985,7 +985,8 @@ def _pval_from_bootci(boot, estimate):

@pf.register_dataframe_method
def mediation_analysis(data=None, x=None, m=None, y=None, covar=None,
alpha=0.05, n_boot=500, seed=None, return_dist=False):
alpha=0.05, n_boot=500, seed=None, return_dist=False,
logreg_kwargs=None):
"""Mediation analysis using a bias-correct non-parametric bootstrap method.
Parameters
Expand Down Expand Up @@ -1013,6 +1014,9 @@ def mediation_analysis(data=None, x=None, m=None, y=None, covar=None,
estimation. The greater, the slower.
seed : int or None
Random state seed.
logreg_kwargs : dict or None
Dictionary with optional arguments passed to
:py:func:`logistic_regression()`
return_dist : bool
If True, the function also returns the indirect bootstrapped beta
samples (size = n_boot). Can be plotted for instance using
Expand Down Expand Up @@ -1182,6 +1186,9 @@ def mediation_analysis(data=None, x=None, m=None, y=None, covar=None,
# Check if mediator is binary
mtype = 'logistic' if all(data[m].nunique() == 2) else 'linear'

# Check if a dict with kwargs for logistic_regression has been passed
logreg_kwargs = {} if logreg_kwargs is None else logreg_kwargs

# Name of CI
ll_name = 'CI[%.1f%%]' % (100 * alpha / 2)
ul_name = 'CI[%.1f%%]' % (100 * (1 - alpha / 2))
Expand All @@ -1205,7 +1212,8 @@ def mediation_analysis(data=None, x=None, m=None, y=None, covar=None,
if mtype == 'linear':
sxm[j] = linear_regression(X_val, M_val[:, idx], alpha=alpha).loc[[1], cols]
else:
sxm[j] = logistic_regression(X_val, M_val[:, idx], alpha=alpha).loc[[1], cols]
sxm[j] = logistic_regression(X_val, M_val[:, idx], alpha=alpha,
**logreg_kwargs).loc[[1], cols]
sxm[j].at[1, 'names'] = '%s ~ X' % j
sxm = pd.concat(sxm, ignore_index=True)

Expand All @@ -1231,9 +1239,11 @@ def mediation_analysis(data=None, x=None, m=None, y=None, covar=None,
ab_estimates = np.zeros(shape=(n_boot, n_mediator))
for i in range(n_boot):
ab_estimates[i, :] = _point_estimate(
X_val, XM_val, M_val, y_val, idx[i, :], n_mediator, mtype)
X_val, XM_val, M_val, y_val, idx[i, :], n_mediator, mtype,
**logreg_kwargs)

ab = _point_estimate(X_val, XM_val, M_val, y_val, np.arange(n), n_mediator, mtype)
ab = _point_estimate(X_val, XM_val, M_val, y_val, np.arange(n), n_mediator,
mtype, **logreg_kwargs)
indirect = {'names': m, 'coef': ab, 'se': ab_estimates.std(ddof=1, axis=0),
'pval': [], ll_name: [], ul_name: [], 'sig': []}

Expand Down
12 changes: 12 additions & 0 deletions pingouin/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ def test_mediation_analysis(self):
assert_almost_equal(ma['CI[97.5%]'][3], 0.617, decimal=1)
assert ma['sig'][3] == 'Yes'

# Check if `logreg_kwargs` is being passed on to `LogisticRegression`
with pytest.raises(ValueError):
mediation_analysis(data=df, x='X', m='Mbin', y='Y', n_boot=2000,
logreg_kwargs=dict(max_iter=-1))
# Solve with 0 iterations and make sure that the results are different
ma = mediation_analysis(data=df, x='X', m='Mbin', y='Y', n_boot=2000,
logreg_kwargs=dict(max_iter=0))
with pytest.raises(AssertionError):
assert_almost_equal(ma['coef'][0], -0.0208, decimal=2)
with pytest.raises(AssertionError):
assert_almost_equal(ma['coef'][4], 0.0033, decimal=3)

# With multiple mediator
np.random.seed(42)
df.rename(columns={"M": "M1"}, inplace=True)
Expand Down

0 comments on commit b08ad14

Please sign in to comment.