Skip to content

Commit

Permalink
Updated Hessian code and added test (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndPotap authored Apr 2, 2024
1 parent 4af199e commit 9562ae1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
15 changes: 3 additions & 12 deletions cola/ops/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,23 +484,14 @@ class Hessian(LinearOperator):
>>> op = Hessian(f, x)
"""
def __init__(self, f, x):
    """Create the Hessian operator of ``f`` at the point ``x``.

    Args:
        f: Callable stored on the instance; ``_matmat`` differentiates it
            through the backend's ``grad``.
        x: One-dimensional array at which the Hessian is taken. Its dtype
            becomes the operator dtype and its length ``n`` fixes the
            operator shape to ``(n, n)``.

    Raises:
        AssertionError: If ``x`` is not one-dimensional.
    """
    # Single assignment; the attributes were previously set twice in a row.
    self.f, self.x = f, x
    assert len(x.shape) == 1, "x must be a vector"
    super().__init__(dtype=x.dtype, shape=(x.shape[0], x.shape[0]))

def _matmat(self, X):
    """Multiply the Hessian of ``self.f`` at ``self.x`` by the matrix ``X``.

    Args:
        X: Array whose columns are the vectors to multiply.
            NOTE(review): presumably of shape ``(n, k)`` with
            ``n == self.x.shape[0]`` -- confirm against callers.

    Returns:
        The transpose of the batched Hessian-vector products, i.e. one
        output column per column of ``X``.
    """
    xnp = self.xnp
    # A Hessian-vector product is the JVP of the gradient of f, taken at
    # self.x along the supplied tangent vector.
    hvp = partial(xnp.jvp_derivs, xnp.grad(self.f), (self.x, ), create_graph=False)
    # X is transposed going in and coming out so that each column of X is
    # handed to ``hvp`` as one batch entry (assumes the backend's vmap maps
    # over the leading axis -- TODO confirm).
    return xnp.vmap(hvp)((X.T, )).T

def __str__(self):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,34 @@
from cola.ops import Jacobian
from cola.ops import LinearOperator
from cola.ops import Kernel
from cola.ops import Hessian
from cola.linalg.decompositions.arnoldi import get_householder_vec
from cola.utils.test_utils import get_xnp, parametrize, relative_error
from cola.backends import all_backends, tracing_backends
from linalg.operator_market import op_names, get_test_operator

_tol = 1e-6


@parametrize(tracing_backends)
def test_Hessian(backend):
    """Compare ``Hessian(...).to_dense()`` with a closed-form Hessian.

    The objective is ``f(z) = sum_i w_i * z_i**2`` with ``w_i = i``, whose
    Hessian is ``2 * diag(w)`` at every point.
    """
    xnp = get_xnp(backend)
    dtype = xnp.float32
    dim = 27
    weights = xnp.array(list(range(dim)), dtype=dtype, device=None)
    point = xnp.ones(shape=(dim, ), device=None, dtype=dtype)

    def weighted_squares(z):
        return xnp.sum(weights * z**2.)

    expected = 2 * xnp.diag(weights)
    computed = Hessian(weighted_squares, point).to_dense()
    assert relative_error(computed, expected, xnp=xnp) < _tol


_exclude = (slice(None), slice(None), ['square_fft'])


Expand Down

0 comments on commit 9562ae1

Please sign in to comment.