Skip to content

Commit

Permalink
Add warning about future change in hessian sign
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt authored and ricardoV94 committed Jun 16, 2023
1 parent 77f24d7 commit dfd1640
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 16 deletions.
8 changes: 6 additions & 2 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ def compile_d2logp(
self,
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
jacobian: bool = True,
negate_output=True,
) -> PointFunc:
"""Compiled log probability density hessian function.
Expand All @@ -707,7 +708,9 @@ def compile_d2logp(
jacobian:
Whether to include jacobian terms in logprob graph. Defaults to True.
"""
return self.model.compile_fn(self.d2logp(vars=vars, jacobian=jacobian))
return self.model.compile_fn(
self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output)
)

def logp(
self,
Expand Down Expand Up @@ -830,6 +833,7 @@ def d2logp(
self,
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
jacobian: bool = True,
negate_output=True,
) -> Variable:
"""Hessian of the models log-probability w.r.t. ``vars``.
Expand Down Expand Up @@ -862,7 +866,7 @@ def d2logp(
)

cost = self.logp(jacobian=jacobian)
return hessian(cost, value_vars)
return hessian(cost, value_vars, negate_output=negate_output)

@property
def datalogp(self) -> Variable:
Expand Down
26 changes: 22 additions & 4 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,17 @@ def grad_ii(i, f, x):


@pytensor.config.change_flags(compute_test_value="ignore")
def hessian(f, vars=None):
return -jacobian(gradient(f, vars), vars)
def hessian(f, vars=None, negate_output=True):
res = jacobian(gradient(f, vars), vars)
if negate_output:
warnings.warn(
"hessian will stop negating the output in a future version of PyMC.\n"
"To suppress this warning set `negate_output=False`",
FutureWarning,
stacklevel=2,
)
res = -res
return res


@pytensor.config.change_flags(compute_test_value="ignore")
Expand All @@ -541,12 +550,21 @@ def hess_ii(i):


@pytensor.config.change_flags(compute_test_value="ignore")
def hessian_diag(f, vars=None):
def hessian_diag(f, vars=None, negate_output=True):
if vars is None:
vars = cont_inputs(f)

if vars:
return -pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
res = pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
if negate_output:
warnings.warn(
"hessian_diag will stop negating the output in a future version of PyMC.\n"
"To suppress this warning set `negate_output=False`",
FutureWarning,
stacklevel=2,
)
res = -res
return res
else:
return empty_gradient

Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def init_nuts(
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
cov = pm.find_hessian(point=start)
cov = -pm.find_hessian(point=start, negate_output=False)
initial_points = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
elif init == "adapt_full":
Expand Down
10 changes: 5 additions & 5 deletions pymc/tuning/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def fixed_hessian(point, model=None):
return rval


def find_hessian(point, vars=None, model=None):
def find_hessian(point, vars=None, model=None, negate_output=True):
"""
Returns Hessian of logp at the point passed.
Expand All @@ -55,11 +55,11 @@ def find_hessian(point, vars=None, model=None):
Variables for which Hessian is to be calculated.
"""
model = modelcontext(model)
H = model.compile_d2logp(vars)
H = model.compile_d2logp(vars, negate_output=negate_output)
return H(Point(point, filter_model_vars=True, model=model))


def find_hessian_diag(point, vars=None, model=None):
def find_hessian_diag(point, vars=None, model=None, negate_output=True):
"""
Returns Hessian of logp at the point passed.
Expand All @@ -71,14 +71,14 @@ def find_hessian_diag(point, vars=None, model=None):
Variables for which Hessian is to be calculated.
"""
model = modelcontext(model)
H = model.compile_fn(hessian_diag(model.logp(), vars))
H = model.compile_fn(hessian_diag(model.logp(), vars, negate_output=negate_output))
return H(Point(point, model=model))


def guess_scaling(point, vars=None, model=None, scaling_bound=1e-8):
model = modelcontext(model)
try:
h = find_hessian_diag(point, vars, model=model)
h = -find_hessian_diag(point, vars, model=model, negate_output=False)
except NotImplementedError:
h = fixed_hessian(point, model=model)
return adjust_scaling(h, scaling_bound)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,16 +1041,16 @@ def test_model_d2logp(jacobian):
test_vals = np.array([0.0, -1.0])
state = {"x": test_vals, "y_log__": test_vals}

expected_x_d2logp = expected_y_d2logp = np.eye(2)
expected_x_d2logp = expected_y_d2logp = -np.eye(2)

dlogps = m.compile_d2logp(jacobian=jacobian)(state)
dlogps = m.compile_d2logp(jacobian=jacobian, negate_output=False)(state)
assert np.all(np.isclose(dlogps[:2, :2], expected_x_d2logp))
assert np.all(np.isclose(dlogps[2:, 2:], expected_y_d2logp))

x_dlogp2 = m.compile_d2logp(vars=[x], jacobian=jacobian)(state)
x_dlogp2 = m.compile_d2logp(vars=[x], jacobian=jacobian, negate_output=False)(state)
assert np.all(np.isclose(x_dlogp2, expected_x_d2logp))

y_dlogp2 = m.compile_d2logp(vars=[y], jacobian=jacobian)(state)
y_dlogp2 = m.compile_d2logp(vars=[y], jacobian=jacobian, negate_output=False)(state)
assert np.all(np.isclose(y_dlogp2, expected_y_d2logp))


Expand Down
16 changes: 16 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
constant_fold,
convert_observed_data,
extract_obs_data,
hessian,
hessian_diag,
replace_rng_nodes,
replace_rvs_by_values,
reseed_rngs,
Expand Down Expand Up @@ -878,3 +880,17 @@ def replacement_fn(var, replacements):
[new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)

assert new_x.eval() > 50


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("func", (hessian, hessian_diag))
def test_hessian_sign_change_warning(func):
x = pt.vector("x")
f = (x**2).sum()
with pytest.warns(
FutureWarning,
match="will stop negating the output",
):
res_neg = func(f, vars=[x])
res = func(f, vars=[x], negate_output=False)
assert equal_computations([res_neg], [-res])

0 comments on commit dfd1640

Please sign in to comment.