
[WIP] using ML optimizers such as Adam in JAX #1238

Draft · wants to merge 17 commits into main
497 changes: 497 additions & 0 deletions docs/examples/notebooks/learn/Hessians.ipynb

Large diffs are not rendered by default.

227 changes: 227 additions & 0 deletions docs/examples/notebooks/learn/minuit_errors.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/pyhf/cli/infer.py
@@ -41,7 +41,7 @@ def cli():
)
@click.option(
"--optimizer",
type=click.Choice(["scipy", "minuit"]),
type=click.Choice(["scipy", "minuit", "customjax"]),
help="The optimizer used for the calculation.",
default="scipy",
)
@@ -149,7 +149,7 @@ def fit(
)
@click.option(
"--optimizer",
type=click.Choice(["scipy", "minuit"]),
type=click.Choice(["scipy", "minuit", "customjax"]),
help="The optimizer used for the calculation.",
default="scipy",
)
1 change: 1 addition & 0 deletions src/pyhf/infer/mle.py
@@ -119,6 +119,7 @@ def fit(data, pdf, init_pars=None, par_bounds=None, fixed_params=None, **kwargs)
if is_fixed
]

kwargs['do_stitch'] = True
return opt.minimize(
twice_nll, data, pdf, init_pars, par_bounds, fixed_vals, **kwargs
)
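
For context, a minimal end-to-end sketch of how this branch would be exercised (assuming the branch is installed; the simple-model helper is named hepdata_like in pyhf of this era and uncorrelated_background in later releases):

import pyhf
from pyhf.optimize.opt_custom_jax import jaxcustom_optimizer

# JAX backend plus the experimental Adam-based optimizer added by this PR
pyhf.set_backend("jax", jaxcustom_optimizer())

# one-bin counting model with a single background uncertainty
model = pyhf.simplemodels.hepdata_like(
signal_data=[5.0], bkg_data=[10.0], bkg_uncerts=[3.0]
)
data = [15.0] + model.config.auxdata

# mle.fit in this branch always sets do_stitch=True before minimizing
bestfit_pars = pyhf.infer.mle.fit(data, model)
print(bestfit_pars)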
5 changes: 5 additions & 0 deletions src/pyhf/optimize/__init__.py
@@ -5,6 +5,11 @@

class _OptimizerRetriever:
def __getattr__(self, name):
if name == 'customjax':
from .opt_custom_jax import jaxcustom_optimizer

self.jaxcustom_optimizer = jaxcustom_optimizer
return jaxcustom_optimizer
if name == 'scipy_optimizer':
from .opt_scipy import scipy_optimizer

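As a small illustration of the lazy retrieval added above (instantiating the retriever class directly is only for demonstration; normally the lookup happens when an optimizer is requested by name):

from pyhf.optimize import _OptimizerRetriever

retriever = _OptimizerRetriever()
# the first access of "customjax" imports opt_custom_jax and returns the class
optimizer_cls = retriever.customjax
print(optimizer_cls.__name__)  # jaxcustom_optimizer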
109 changes: 72 additions & 37 deletions src/pyhf/optimize/common.py
@@ -3,7 +3,7 @@
from ..tensor.common import _TensorViewer


def _make_stitch_pars(tv=None, fixed_values=None):
def _make_post_processor(tv=None, fixed_values=None):
"""
Construct a callable to stitch fixed parameter values into the unfixed parameters. See :func:`shim`.

@@ -21,14 +21,14 @@ def _make_stitch_pars(tv=None, fixed_values=None):
if tv is None or fixed_values is None:
return lambda pars, stitch_with=None: pars

def stitch_pars(pars, stitch_with=fixed_values):
def post_processor(pars, stitch_with=fixed_values):
tb, _ = get_backend()
return tv.stitch([tb.astensor(stitch_with, dtype='float'), pars])

return stitch_pars
return post_processor
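
To make the stitching contract concrete, a small sketch of what the returned post-processor does (the indices and values are invented for illustration, and this assumes the branch with _make_post_processor is installed):

from pyhf import get_backend, set_backend
from pyhf.optimize.common import _make_post_processor
from pyhf.tensor.common import _TensorViewer

set_backend("numpy")
tb, _ = get_backend()

fixed_idx, variable_idx = [0], [1, 2]  # parameter 0 is held fixed
tv = _TensorViewer([fixed_idx, variable_idx])
post = _make_post_processor(tv, fixed_values=[2.5])

# the minimizer only ever sees the free parameters; the post-processor
# stitches the fixed value back in at index 0
print(post(tb.astensor([0.7, 1.1])))  # -> [2.5 0.7 1.1]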


def _get_tensor_shim():
def _get_internal_objective(*args, **kwargs):
"""
A shim-retriever to lazy-retrieve the necessary shims as needed.

@@ -39,25 +39,80 @@ def _get_tensor_shim():
if tensorlib.name == 'numpy':
from .opt_numpy import wrap_objective as numpy_shim

return numpy_shim
return numpy_shim(*args, **kwargs)

if tensorlib.name == 'tensorflow':
from .opt_tflow import wrap_objective as tflow_shim

return tflow_shim
return tflow_shim(*args, **kwargs)

if tensorlib.name == 'pytorch':
from .opt_pytorch import wrap_objective as pytorch_shim

return pytorch_shim
return pytorch_shim(*args, **kwargs)

if tensorlib.name == 'jax':
from .opt_jax import wrap_objective as jax_shim

return jax_shim
return jax_shim(*args, **kwargs)
raise ValueError(f'No optimizer shim for {tensorlib.name}.')


def to_inf(x, bounds):
tensorlib, _ = get_backend()
lo, hi = bounds.T
return tensorlib.arcsin(2 * (x - lo) / (hi - lo) - 1)


def to_bnd(x, bounds):
tensorlib, _ = get_backend()
lo, hi = bounds.T
return lo + 0.5 * (hi - lo) * (tensorlib.sin(x) + 1)
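
These helpers implement the usual sin/arcsin reparameterization between bounded parameters and an unbounded internal space. A quick round-trip check of the same formulas with jax.numpy (used directly here, since whether the pyhf tensorlibs expose sin/arcsin is not visible in this diff):

import jax.numpy as jnp

bounds = jnp.array([[0.0, 10.0], [0.0, 5.0]])
lo, hi = bounds.T
x = jnp.array([1.0, 2.5])

internal = jnp.arcsin(2 * (x - lo) / (hi - lo) - 1)        # to_inf
recovered = lo + 0.5 * (hi - lo) * (jnp.sin(internal) + 1)  # to_bnd
assert jnp.allclose(recovered, x)  # round-trip is the identity inside the bounds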


def _configure_internal_minimize(
init_pars, variable_idx, do_stitch, par_bounds, fixed_idx, fixed_values
):
tensorlib, _ = get_backend()
if do_stitch:
all_init = tensorlib.astensor(init_pars)
internal_init = tensorlib.gather(
all_init, tensorlib.astensor(variable_idx, dtype='int')
)

internal_bounds = [par_bounds[i] for i in variable_idx]
# stitched out the fixed values, so we don't pass any to the underlying minimizer
external_fixed_vals = []

tv = _TensorViewer([fixed_idx, variable_idx])
# NB: this is a closure, tensorlib needs to be accessed at a different point in time
post_processor = _make_post_processor(tv, fixed_values)

else:
internal_init = init_pars
internal_bounds = par_bounds
external_fixed_vals = list(zip(fixed_idx, fixed_values))  # (index, value) pairs; fixed_vals is not in scope here
post_processor = _make_post_processor()

internal_init = to_inf(
tensorlib.astensor(internal_init), tensorlib.astensor(internal_bounds)
)

def mypostprocessor(x):
x = to_bnd(x, tensorlib.astensor(internal_bounds))
return post_processor(x)

no_internal_bounds = None

kwargs = dict(
x0=internal_init,
variable_bounds=internal_bounds,
bounds=no_internal_bounds,
fixed_vals=external_fixed_vals,
)
return kwargs, mypostprocessor


def shim(
objective,
data,
@@ -110,45 +110,165 @@ def shim(
fixed_values = [x[1] for x in fixed_vals]
variable_idx = [x for x in range(pdf.config.npars) if x not in fixed_idx]

if do_stitch:
all_init = tensorlib.astensor(init_pars)
variable_init = tensorlib.tolist(
tensorlib.gather(all_init, tensorlib.astensor(variable_idx, dtype='int'))
)
variable_bounds = [par_bounds[i] for i in variable_idx]
# stitched out the fixed values, so we don't pass any to the underlying minimizer
minimizer_fixed_vals = []

tv = _TensorViewer([fixed_idx, variable_idx])
# NB: this is a closure, tensorlib needs to be accessed at a different point in time
stitch_pars = _make_stitch_pars(tv, fixed_values)

else:
variable_init = init_pars
variable_bounds = par_bounds
minimizer_fixed_vals = fixed_vals
stitch_pars = _make_stitch_pars()
minimizer_kwargs, post_processor = _configure_internal_minimize(
init_pars, variable_idx, do_stitch, par_bounds, fixed_idx, fixed_values
)

objective_and_grad = _get_tensor_shim()(
internal_objective_maybe_grad = _get_internal_objective(
objective,
tensorlib.astensor(data),
pdf,
stitch_pars,
post_processor,
do_grad=do_grad,
jit_pieces={
'fixed_idx': fixed_idx,
'variable_idx': variable_idx,
'fixed_values': fixed_values,
'do_stitch': do_stitch,
'par_bounds': tensorlib.astensor(minimizer_kwargs.pop('variable_bounds')),
},
)

minimizer_kwargs = dict(
func=objective_and_grad,
x0=variable_init,
do_grad=do_grad,
bounds=variable_bounds,
fixed_vals=minimizer_fixed_vals,
)

return minimizer_kwargs, stitch_pars
minimizer_kwargs['func'] = internal_objective_maybe_grad
minimizer_kwargs['do_grad'] = do_grad
return minimizer_kwargs, post_processor
2 changes: 1 addition & 1 deletion src/pyhf/optimize/mixins.py
@@ -31,7 +31,6 @@ def __init__(self, **kwargs):
def _internal_minimize(
self, func, x0, do_grad=False, bounds=None, fixed_vals=None, options={}
):

minimizer = self._get_minimizer(
func, x0, bounds, fixed_vals=fixed_vals, do_grad=do_grad
)
@@ -62,6 +61,7 @@ def _internal_postprocess(self, fitresult, stitch_pars):
tensorlib, _ = get_backend()

fitted_pars = stitch_pars(tensorlib.astensor(fitresult.x))

# extract number of fixed parameters
num_fixed_pars = len(fitted_pars) - len(fitresult.x)

87 changes: 87 additions & 0 deletions src/pyhf/optimize/opt_custom_jax.py
@@ -0,0 +1,87 @@
"""JAX Custom Optimizer Class."""
from .. import exceptions
from .mixins import OptimizerMixin
import scipy


class jaxcustom_optimizer(OptimizerMixin):
__slots__ = ['name']

def __init__(self, *args, **kwargs):
self.name = 'jaxcustom'
super().__init__(*args, **kwargs)

def _get_minimizer(
self, objective_and_grad, init_pars, init_bounds, fixed_vals=None, do_grad=False
):
return None

def _custom_internal_minimize(self, objective, init_pars, maxiter=1000, rtol=1e-7):
import jax.experimental.optimizers as optimizers
import jax

opt_init, opt_update, opt_getpars = optimizers.adam(step_size=1e-2)
state = opt_init(init_pars)
vold, _ = objective(init_pars)

def cond(loop_state):
delta = loop_state['delta']
i = loop_state['i']
delta_below = jax.numpy.logical_and(delta > 0, delta < rtol)
delta_below = jax.numpy.logical_and(i > 1, delta_below)
maxed_iter = i > maxiter
return ~jax.numpy.logical_or(maxed_iter, delta_below)

def body(loop_state):
i = loop_state['i']
state = loop_state['state']
pars = opt_getpars(state)
v, g = objective(pars)
newopt_state = opt_update(0, g, state)
vold = loop_state['vold']
delta = jax.numpy.abs(v - vold) / v
new_state = {}
new_state['delta'] = delta
new_state['state'] = newopt_state
new_state['vold'] = v
new_state['i'] = i + 1
return new_state

loop_state = {'delta': 0, 'i': 0, 'state': state, 'vold': vold}
# run the Adam update loop on-device until the relative change in the
# objective falls below rtol or maxiter is exceeded
loop_state = jax.lax.while_loop(cond, body, loop_state)

minimized = opt_getpars(loop_state['state'])

class Result:
pass

r = Result()
r.x = minimized
r.success = True
r.fun = objective(minimized)[0]
return r

def _minimize(
self,
minimizer,
func,
x0,
do_grad=False,
bounds=None,
fixed_vals=None,
return_uncertainties=False,
options={},
):
assert minimizer is None
assert fixed_vals == []
assert return_uncertainties is False
assert bounds is None
result = self._custom_internal_minimize(func, x0)
return result
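
The pattern above (Adam updates driven by jax.lax.while_loop with a relative-tolerance stop) can be exercised standalone on a toy objective; this is a sketch, not the PR's code, and note that jax.experimental.optimizers moved to jax.example_libraries.optimizers in newer JAX releases:

import jax
import jax.numpy as jnp
from jax.experimental import optimizers  # jax.example_libraries.optimizers in newer JAX

# toy objective returning (value, gradient), like the wrapped pyhf objective
objective = jax.value_and_grad(lambda x: jnp.sum((x - 3.0) ** 2))

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-1)
rtol, maxiter = 1e-7, 1000

def cond(s):
# stop once the relative change is small (after a couple of steps) or on maxiter
small = jnp.logical_and(s['delta'] > 0, s['delta'] < rtol)
converged = jnp.logical_and(s['i'] > 1, small)
return ~jnp.logical_or(s['i'] > maxiter, converged)

def body(s):
pars = get_params(s['state'])
v, g = objective(pars)
return {
'state': opt_update(0, g, s['state']),
'delta': jnp.abs(v - s['vold']) / v,
'vold': v,
'i': s['i'] + 1,
}

x0 = jnp.zeros(2)
v0, _ = objective(x0)
init = {'state': opt_init(x0), 'delta': jnp.array(0.0), 'vold': v0, 'i': jnp.array(0)}
final = jax.lax.while_loop(cond, body, init)
print(get_params(final['state']))  # ≈ [3. 3.]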
30 changes: 28 additions & 2 deletions src/pyhf/optimize/opt_jax.py
@@ -8,12 +8,35 @@
log = logging.getLogger(__name__)


def to_inf(x, bounds):
tensorlib, _ = get_backend()
lo, hi = bounds.T
return tensorlib.arcsin(2 * (x - lo) / (hi - lo) - 1)


def to_bnd(x, bounds):
tensorlib, _ = get_backend()
lo, hi = bounds.T
return lo + 0.5 * (hi - lo) * (tensorlib.sin(x) + 1)


def _final_objective(
pars, data, fixed_values, fixed_idx, variable_idx, do_stitch, objective, pdf
pars,
data,
fixed_values,
fixed_idx,
variable_idx,
do_stitch,
objective,
pdf,
par_bounds,
):
log.debug('jitting function')
tensorlib, _ = get_backend()
pars = tensorlib.astensor(pars)

pars = to_bnd(pars, par_bounds)

if do_stitch:
tv = _TensorViewer([fixed_idx, variable_idx])
constrained_pars = tv.stitch(
@@ -51,7 +74,7 @@ def wrap_objective(objective, data, pdf, stitch_pars, do_grad=False, jit_pieces=

def func(pars):
# need to convert to tuple to make args hashable
return _jitted_objective_and_grad(
result = _jitted_objective_and_grad(
pars,
data,
jit_pieces['fixed_values'],
@@ -60,7 +83,9 @@ def func(pars):
jit_pieces['do_stitch'],
objective,
pdf,
jit_pieces['par_bounds'],
)
return result

else:

Expand All @@ -75,6 +100,7 @@ def func(pars):
jit_pieces['do_stitch'],
objective,
pdf,
jit_pieces['par_bounds'],
)

return func
2 changes: 1 addition & 1 deletion src/pyhf/optimize/opt_minuit.py
@@ -31,7 +31,7 @@ def __init__(self, *args, **kwargs):
tolerance (:obj:`float`): tolerance for termination. See specific optimizer for detailed meaning. Default is 0.1.
"""
self.name = 'minuit'
self.errordef = kwargs.pop('errordef', 1)
self.errordef = kwargs.pop('errordef', 0.5)
self.steps = kwargs.pop('steps', 1000)
self.strategy = kwargs.pop('strategy', None)
self.tolerance = kwargs.pop('tolerance', 0.1)
2 changes: 1 addition & 1 deletion src/pyhf/pdf.py
@@ -551,7 +551,7 @@ def __init__(self, spec, batch_size=None, **config_kwargs):
self.version = config_kwargs.pop('version', None)
# run jsonschema validation of input specification against the (provided) schema
log.info(f"Validating spec against schema: {self.schema:s}")
utils.validate(self.spec, self.schema, version=self.version)
# utils.validate(self.spec, self.schema, version=self.version)
# build up our representation of the specification
self.config = _ModelConfig(self.spec, **config_kwargs)
