
Merge pull request #692 from rhayes777/feature/jax_laplace
fix jax in laplace by casting array to float
Jammy2211 authored Mar 6, 2023
2 parents dfa33e3 + d96a85c commit 436e1b6
Showing 1 changed file with 45 additions and 53 deletions.
98 changes: 45 additions & 53 deletions autofit/graphical/laplace/newton.py
@@ -20,6 +20,7 @@ def gradient_ascent(state: OptimisationState, **kwargs) -> VariableData:
def newton_direction(state: OptimisationState, **kwargs) -> VariableData:
return state.hessian.ldiv(state.gradient)


def newton_abs_direction(state: OptimisationState, d=1e-6, **kwargs) -> VariableData:
posdef = state.hessian.abs().diagonalupdate(state.parameters.full_like(d))
return posdef.ldiv(state.gradient)
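
For readers skimming the diff: newton_abs_direction solves against a positive-definite surrogate of the Hessian, its absolute value with a small constant d added to the diagonal. A dense-NumPy sketch of the same idea (the eigendecomposition route and the helper name are assumptions; the repo works through its own VariableData/Hessian types):

import numpy as np

def newton_abs_direction_dense(hessian, gradient, d=1e-6):
    # |H|: flip negative eigenvalues positive so the surrogate is
    # positive definite and the solve yields an ascent direction.
    w, v = np.linalg.eigh(hessian)
    h_abs = (v * np.abs(w)) @ v.T
    # diagonalupdate(parameters.full_like(d)) ~ adding d to the diagonal
    posdef = h_abs + d * np.eye(len(gradient))
    return np.linalg.solve(posdef, gradient)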
@@ -34,7 +35,7 @@ def newton_abs_direction(state: OptimisationState, d=1e-6, **kwargs) -> VariableData:


def sr1_update(
    state1: OptimisationState, state: OptimisationState, mintol=1e-8, **kwargs
) -> OptimisationState:
yk = VariableData.sub(state1.gradient, state.gradient)
dk = VariableData.sub(state1.parameters, state.parameters)
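
The hunk shows only the head of sr1_update; for reference, the symmetric-rank-1 update it is named after looks as follows in dense NumPy (a sketch, assuming mintol plays the usual role of skipping numerically unstable corrections):

import numpy as np

def sr1_update_dense(B, g1, g0, x1, x0, mintol=1e-8):
    yk = g1 - g0                 # gradient difference, as in the diff
    dk = x1 - x0                 # parameter step, as in the diff
    rk = yk - B @ dk             # residual of the secant equation
    denom = rk @ dk
    # Standard safeguard: apply the rank-1 correction only when the
    # denominator is not tiny relative to the vectors involved.
    if abs(denom) > mintol * np.linalg.norm(rk) * np.linalg.norm(dk):
        B = B + np.outer(rk, rk) / denom
    return B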
@@ -57,7 +58,7 @@ def sr1_update(


def diag_sr1_update(
    state1: OptimisationState, state: OptimisationState, tol=1e-8, **kwargs
) -> OptimisationState:
yk = VariableData.sub(state1.gradient, state.gradient)
dk = VariableData.sub(state1.parameters, state.parameters)
@@ -76,7 +77,7 @@ def diag_sr1_update(


def diag_sr1_update_(
    state1: OptimisationState, state: OptimisationState, tol=1e-8, **kwargs
) -> OptimisationState:
yk = VariableData.sub(state1.gradient, state.gradient)
dk = VariableData.sub(state1.parameters, state.parameters)
@@ -99,7 +100,7 @@ def diag_sr1_update_(


def diag_sr1_bfgs_update(
    state1: OptimisationState, state: OptimisationState, **kwargs
) -> OptimisationState:
yk = VariableData.sub(state1.gradient, state.gradient)
dk = VariableData.sub(state1.parameters, state.parameters)
@@ -109,9 +110,7 @@ def diag_sr1_bfgs_update(


def bfgs1_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
) -> OptimisationState:
"""
y_k = g_{k+1} - g_k
@@ -139,9 +138,7 @@ def bfgs1_update(


def bfgs_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
) -> OptimisationState:
yk = VariableData.sub(state1.gradient, state.gradient)
dk = VariableData.sub(state1.parameters, state.parameters)
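
The yk/dk pair feeds the textbook BFGS update, B+ = B + y y^T / (y^T d) - (B d)(B d)^T / (d^T B d). A dense-NumPy sketch (the repo's version applies the same algebra through its VariableData types):

import numpy as np

def bfgs_update_dense(B, g1, g0, x1, x0):
    yk = g1 - g0                 # y_k = g_{k+1} - g_k, as in the docstring
    dk = x1 - x0
    Bd = B @ dk
    curv = yk @ dk               # curvature; must be > 0 to stay posdef
    if curv > 0:
        B = B + np.outer(yk, yk) / curv - np.outer(Bd, Bd) / (dk @ Bd)
    return B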
@@ -158,9 +155,7 @@ def bfgs_update(


def quasi_deterministic_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
) -> OptimisationState:
dk = VariableData.sub(state1.parameters, state.parameters)
zk = VariableData.sub(
@@ -179,9 +174,7 @@ def quasi_deterministic_update(


def diag_quasi_deterministic_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
) -> OptimisationState:
dk = VariableData.sub(state1.parameters, state.parameters)
zk = VariableData.sub(
@@ -191,7 +184,7 @@ def diag_quasi_deterministic_update(
zk2 = zk ** 2
zk4 = (zk2 ** 2).sum()
alpha = (dk.dot(Bxk.dot(dk)) - zk.dot(Bzk.dot(zk))) / zk4
-    state1.det_hessian = Bzk.diagonalupdate(alpha * zk2)
+    state1.det_hessian = Bzk.diagonalupdate(float(alpha) * zk2)

return state1
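
This is the one substantive change in the commit: alpha is a ratio of dot products, so under JAX it arrives as a zero-dimensional jax.Array rather than a Python scalar, and multiplying NumPy-backed state by it silently moves that state onto the JAX backend. A minimal sketch of the behaviour with stand-in values (not the repo's types):

import jax.numpy as jnp
import numpy as np

alpha = jnp.asarray(0.5)         # stand-in for the dot-product ratio above
zk2 = np.array([1.0, 4.0, 9.0])  # stand-in for the squared step

print(type(alpha * zk2))         # jax.Array: NumPy state leaks onto JAX
print(type(float(alpha) * zk2))  # numpy.ndarray: the cast keeps it NumPy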

@@ -202,10 +195,7 @@ def __init__(self, quasi_newton_update, det_quasi_newton_update):
self.det_quasi_newton_update = det_quasi_newton_update

def __call__(
-            self,
-            state1: OptimisationState,
-            state: OptimisationState,
-            **kwargs,
+        self, state1: OptimisationState, state: OptimisationState, **kwargs,
) -> OptimisationState:

# Only update estimate if a step has been taken
@@ -225,28 +215,28 @@ def __call__(


def take_step(
    state: OptimisationState,
    old_state: Optional[OptimisationState] = None,
    *,
    search_direction=newton_abs_direction,
    calc_line_search=line_search,
    search_direction_kws: Optional[Dict[str, Any]] = None,
    line_search_kws: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[float], OptimisationState]:
state.search_direction = search_direction(state, **(search_direction_kws or {}))
return calc_line_search(state, old_state, **(line_search_kws or {}))


def take_quasi_newton_step(
    state: OptimisationState,
    old_state: Optional[OptimisationState] = None,
    *,
    search_direction=newton_abs_direction,
    calc_line_search=line_search,
    quasi_newton_update=full_bfgs_update,
    search_direction_kws: Optional[Dict[str, Any]] = None,
    line_search_kws: Optional[Dict[str, Any]] = None,
    quasi_newton_kws: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[float], OptimisationState]:
""" """
state.search_direction = search_direction(state, **(search_direction_kws or {}))
@@ -314,7 +304,7 @@ def ngev_condition(state, old_state, maxgev=10000, **kwargs):


def check_stop_conditions(
    stepsize, state, old_state, stop_conditions, **stop_kws
) -> Optional[Tuple[bool, str]]:
if stepsize is None:
return False, "abnormal termination of line search"
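
Only the head of check_stop_conditions is visible here; judging from ngev_condition(state, old_state, maxgev=10000, **kwargs) in the hunk header above, each stop condition maps (state, old_state, **kwargs) to None (keep iterating) or a (success, message) pair. A hypothetical condition in that shape (the gradient.values() access is an assumption about VariableData):

def gtol_condition(state, old_state, gtol=1e-5, **kwargs):
    # Hypothetical stop condition: succeed once every gradient block
    # is below tolerance; fall through to None otherwise.
    if all(abs(g).max() < gtol for g in state.gradient.values()):
        return True, "gradient tolerance reached"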
@@ -328,20 +318,20 @@ def check_stop_conditions(


def optimise_quasi_newton(
    state: OptimisationState,
    old_state: Optional[OptimisationState] = None,
    *,
    max_iter=100,
    search_direction=newton_abs_direction,
    calc_line_search=line_search,
    quasi_newton_update=bfgs_update,
    stop_conditions=stop_conditions,
    search_direction_kws: Optional[Dict[str, Any]] = None,
    line_search_kws: Optional[Dict[str, Any]] = None,
    quasi_newton_kws: Optional[Dict[str, Any]] = None,
    stop_kws: Optional[Dict[str, Any]] = None,
    callback: Optional[_OPT_CALLBACK] = None,
    **kwargs,
) -> Tuple[OptimisationState, Status]:
success = True
updated = False
@@ -356,7 +346,9 @@ def optimise_quasi_newton(
success, message = stop
break

-        with LogWarnings(logger=_log_projection_warnings, action='always') as caught_warnings:
+        with LogWarnings(
+            logger=_log_projection_warnings, action="always"
+        ) as caught_warnings:
try:
stepsize, state1 = take_quasi_newton_step(
state,
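
Putting the pieces together, a call to the top-level driver might look like the sketch below. This is a hedged usage example: the OptimisationState constructor arguments are assumptions from context and do not appear in this diff.

# Hypothetical setup; factor_approx, hessian and params would come from
# the surrounding autofit.graphical machinery.
state = OptimisationState(factor_approx, hessian, params)

final_state, status = optimise_quasi_newton(
    state,
    max_iter=100,
    search_direction=newton_abs_direction,  # default, per the signature above
    quasi_newton_update=bfgs_update,        # default, per the signature above
)
print(status)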
