From d96a85cf8782b088f8d9e153e8bb0a11a20ed615 Mon Sep 17 00:00:00 2001
From: Richard
Date: Mon, 6 Mar 2023 14:15:09 +0000
Subject: [PATCH] fix jax in laplace by casting array to float

---
 autofit/graphical/laplace/newton.py | 98 +++++++++++++----------------
 1 file changed, 45 insertions(+), 53 deletions(-)

diff --git a/autofit/graphical/laplace/newton.py b/autofit/graphical/laplace/newton.py
index d667ca2eb..b65f6eea2 100644
--- a/autofit/graphical/laplace/newton.py
+++ b/autofit/graphical/laplace/newton.py
@@ -20,6 +20,7 @@ def gradient_ascent(state: OptimisationState, **kwargs) -> VariableData:
 def newton_direction(state: OptimisationState, **kwargs) -> VariableData:
     return state.hessian.ldiv(state.gradient)
 
+
 def newton_abs_direction(state: OptimisationState, d=1e-6, **kwargs) -> VariableData:
     posdef = state.hessian.abs().diagonalupdate(state.parameters.full_like(d))
     return posdef.ldiv(state.gradient)
@@ -34,7 +35,7 @@ def newton_abs_direction(state: OptimisationState, d=1e-6, **kwargs) -> Variable
 
 
 def sr1_update(
-        state1: OptimisationState, state: OptimisationState, mintol=1e-8, **kwargs
+    state1: OptimisationState, state: OptimisationState, mintol=1e-8, **kwargs
 ) -> OptimisationState:
     yk = VariableData.sub(state1.gradient, state.gradient)
     dk = VariableData.sub(state1.parameters, state.parameters)
@@ -57,7 +58,7 @@ def sr1_update(
 
 
 def diag_sr1_update(
-        state1: OptimisationState, state: OptimisationState, tol=1e-8, **kwargs
+    state1: OptimisationState, state: OptimisationState, tol=1e-8, **kwargs
 ) -> OptimisationState:
     yk = VariableData.sub(state1.gradient, state.gradient)
     dk = VariableData.sub(state1.parameters, state.parameters)
@@ -76,7 +77,7 @@ def diag_sr1_update(
 
 
 def diag_sr1_update_(
-        state1: OptimisationState, state: OptimisationState, tol=1e-8, **kwargs
+    state1: OptimisationState, state: OptimisationState, tol=1e-8, **kwargs
 ) -> OptimisationState:
     yk = VariableData.sub(state1.gradient, state.gradient)
     dk = VariableData.sub(state1.parameters, state.parameters)
@@ -99,7 +100,7 @@ def diag_sr1_update_(
 
 
 def diag_sr1_bfgs_update(
-        state1: OptimisationState, state: OptimisationState, **kwargs
+    state1: OptimisationState, state: OptimisationState, **kwargs
 ) -> OptimisationState:
     yk = VariableData.sub(state1.gradient, state.gradient)
     dk = VariableData.sub(state1.parameters, state.parameters)
@@ -109,9 +110,7 @@ def diag_sr1_bfgs_update(
 
 
 def bfgs1_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
 ) -> OptimisationState:
     """
     y_k = g_{k+1} - g{k}
@@ -139,9 +138,7 @@ def bfgs1_update(
 
 
 def bfgs_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
 ) -> OptimisationState:
     yk = VariableData.sub(state1.gradient, state.gradient)
     dk = VariableData.sub(state1.parameters, state.parameters)
@@ -158,9 +155,7 @@ def bfgs_update(
 
 
 def quasi_deterministic_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
 ) -> OptimisationState:
     dk = VariableData.sub(state1.parameters, state.parameters)
     zk = VariableData.sub(
@@ -179,9 +174,7 @@ def quasi_deterministic_update(
 
 
 def diag_quasi_deterministic_update(
-        state1: OptimisationState,
-        state: OptimisationState,
-        **kwargs,
+    state1: OptimisationState, state: OptimisationState, **kwargs,
 ) -> OptimisationState:
     dk = VariableData.sub(state1.parameters, state.parameters)
     zk = VariableData.sub(
@@ -191,7 +184,7 @@ def diag_quasi_deterministic_update(
     zk2 = zk ** 2
     zk4 = (zk2 ** 2).sum()
     alpha = (dk.dot(Bxk.dot(dk)) - zk.dot(Bzk.dot(zk))) / zk4
-    state1.det_hessian = Bzk.diagonalupdate(alpha * zk2)
+    state1.det_hessian = Bzk.diagonalupdate(float(alpha) * zk2)
     return state1
 
 
@@ -202,10 +195,7 @@ def __init__(self, quasi_newton_update, det_quasi_newton_update):
         self.det_quasi_newton_update = det_quasi_newton_update
 
     def __call__(
-            self,
-            state1: OptimisationState,
-            state: OptimisationState,
-            **kwargs,
+        self, state1: OptimisationState, state: OptimisationState, **kwargs,
     ) -> OptimisationState:
 
         # Only update estimate if a step has been taken
@@ -225,28 +215,28 @@ def __call__(
 
 
 def take_step(
-        state: OptimisationState,
-        old_state: Optional[OptimisationState] = None,
-        *,
-        search_direction=newton_abs_direction,
-        calc_line_search=line_search,
-        search_direction_kws: Optional[Dict[str, Any]] = None,
-        line_search_kws: Optional[Dict[str, Any]] = None,
+    state: OptimisationState,
+    old_state: Optional[OptimisationState] = None,
+    *,
+    search_direction=newton_abs_direction,
+    calc_line_search=line_search,
+    search_direction_kws: Optional[Dict[str, Any]] = None,
+    line_search_kws: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Optional[float], OptimisationState]:
     state.search_direction = search_direction(state, **(search_direction_kws or {}))
     return calc_line_search(state, old_state, **(line_search_kws or {}))
 
 
 def take_quasi_newton_step(
-        state: OptimisationState,
-        old_state: Optional[OptimisationState] = None,
-        *,
-        search_direction=newton_abs_direction,
-        calc_line_search=line_search,
-        quasi_newton_update=full_bfgs_update,
-        search_direction_kws: Optional[Dict[str, Any]] = None,
-        line_search_kws: Optional[Dict[str, Any]] = None,
-        quasi_newton_kws: Optional[Dict[str, Any]] = None,
+    state: OptimisationState,
+    old_state: Optional[OptimisationState] = None,
+    *,
+    search_direction=newton_abs_direction,
+    calc_line_search=line_search,
+    quasi_newton_update=full_bfgs_update,
+    search_direction_kws: Optional[Dict[str, Any]] = None,
+    line_search_kws: Optional[Dict[str, Any]] = None,
+    quasi_newton_kws: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Optional[float], OptimisationState]:
     """ """
     state.search_direction = search_direction(state, **(search_direction_kws or {}))
@@ -314,7 +304,7 @@ def ngev_condition(state, old_state, maxgev=10000, **kwargs):
 
 
 def check_stop_conditions(
-        stepsize, state, old_state, stop_conditions, **stop_kws
+    stepsize, state, old_state, stop_conditions, **stop_kws
 ) -> Optional[Tuple[bool, str]]:
     if stepsize is None:
         return False, "abnormal termination of line search"
@@ -328,20 +318,20 @@ def check_stop_conditions(
 
 
 def optimise_quasi_newton(
-        state: OptimisationState,
-        old_state: Optional[OptimisationState] = None,
-        *,
-        max_iter=100,
-        search_direction=newton_abs_direction,
-        calc_line_search=line_search,
-        quasi_newton_update=bfgs_update,
-        stop_conditions=stop_conditions,
-        search_direction_kws: Optional[Dict[str, Any]] = None,
-        line_search_kws: Optional[Dict[str, Any]] = None,
-        quasi_newton_kws: Optional[Dict[str, Any]] = None,
-        stop_kws: Optional[Dict[str, Any]] = None,
-        callback: Optional[_OPT_CALLBACK] = None,
-        **kwargs,
+    state: OptimisationState,
+    old_state: Optional[OptimisationState] = None,
+    *,
+    max_iter=100,
+    search_direction=newton_abs_direction,
+    calc_line_search=line_search,
+    quasi_newton_update=bfgs_update,
+    stop_conditions=stop_conditions,
+    search_direction_kws: Optional[Dict[str, Any]] = None,
+    line_search_kws: Optional[Dict[str, Any]] = None,
+    quasi_newton_kws: Optional[Dict[str, Any]] = None,
+    stop_kws: Optional[Dict[str, Any]] = None,
+    callback: Optional[_OPT_CALLBACK] = None,
+    **kwargs,
 ) -> Tuple[OptimisationState, Status]:
     success = True
     updated = False
@@ -356,7 +346,9 @@ def optimise_quasi_newton(
             success, message = stop
             break
 
-        with LogWarnings(logger=_log_projection_warnings, action='always') as caught_warnings:
+        with LogWarnings(
+            logger=_log_projection_warnings, action="always"
+        ) as caught_warnings:
             try:
                 stepsize, state1 = take_quasi_newton_step(
                     state,
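Note (reviewer sketch, not part of the patch): in the hunks shown above the only behavioural change is Bzk.diagonalupdate(float(alpha) * zk2) in diag_quasi_deterministic_update; the remaining lines are formatting. The snippet below is a hedged illustration, under the usual JAX/NumPy promotion rules, of why the cast helps when alpha is produced by JAX: a 0-d JAX array promotes NumPy operands to JAX arrays, whereas a Python float leaves them as NumPy. The zk2/alpha values here are stand-ins, not the real VariableData objects.

# Hedged illustration only; assumes alpha is a concrete 0-d JAX array
# (i.e. outside jit tracing, so float(alpha) is allowed).
import numpy as np
import jax.numpy as jnp

zk2 = np.array([0.5, 1.5, 2.0])   # stand-in for the NumPy-backed zk ** 2
alpha = jnp.asarray(3.0)          # stand-in for the JAX scalar computed upstream

print(type(alpha * zk2))          # JAX array: the NumPy data is silently promoted
print(type(float(alpha) * zk2))   # numpy.ndarray: stays in plain NumPy

# Casting to a Python float keeps the diagonal update in NumPy, so
# diagonalupdate receives the array types the rest of the code expects.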