From d1c8644c1f434f09d7855f30dbb806a625ac8854 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Dec 2020 21:25:30 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/notebooks/learn/Hessians.ipynb | 247 +++++++++--------- .../notebooks/learn/minuit_errors.ipynb | 137 +++++----- src/pyhf/cli/infer.py | 4 +- src/pyhf/optimize/__init__.py | 1 + src/pyhf/optimize/common.py | 52 ++-- src/pyhf/optimize/mixins.py | 1 - src/pyhf/optimize/opt_custom_jax.py | 39 +-- src/pyhf/optimize/opt_jax.py | 30 ++- 8 files changed, 278 insertions(+), 233 deletions(-) diff --git a/docs/examples/notebooks/learn/Hessians.ipynb b/docs/examples/notebooks/learn/Hessians.ipynb index 644c3f32d8..889800e93d 100644 --- a/docs/examples/notebooks/learn/Hessians.ipynb +++ b/docs/examples/notebooks/learn/Hessians.ipynb @@ -57,13 +57,15 @@ "import matplotlib.pyplot as plt\n", "import matplotlib.patches as patches\n", "\n", - "def to_bounded(n,bounds):\n", - " a,b = bounds\n", - " return a+0.5*(b-a)*(jnp.sin(n) + 1)\n", "\n", - "def to_inf(x,bounds):\n", - " a,b = bounds\n", - " return jnp.arcsin(2*(x-a)/(b-a)-1)" + "def to_bounded(n, bounds):\n", + " a, b = bounds\n", + " return a + 0.5 * (b - a) * (jnp.sin(n) + 1)\n", + "\n", + "\n", + "def to_inf(x, bounds):\n", + " a, b = bounds\n", + " return jnp.arcsin(2 * (x - a) / (b - a) - 1)" ] }, { @@ -85,27 +87,29 @@ ], "source": [ "def plot_trfs():\n", - " bounds = [0,5]\n", + " bounds = [0, 5]\n", "\n", - " f,axarr = plt.subplots(2,1)\n", + " f, axarr = plt.subplots(2, 1)\n", "\n", - " x = jnp.linspace(bounds[0],bounds[1],1001)\n", - " n = jax.vmap(to_inf,in_axes=(0,None))(x,bounds)\n", + " x = jnp.linspace(bounds[0], bounds[1], 1001)\n", + " n = jax.vmap(to_inf, in_axes=(0, None))(x, bounds)\n", " ax = axarr[0]\n", - " ax.plot(x,n)\n", + " ax.plot(x, n)\n", " ax.set_xlabel('x')\n", " ax.set_ylabel('n')\n", " ax.set_title(r'$x \\to n$')\n", "\n", - " n = jnp.linspace(0,10,1001)\n", - " x = jax.vmap(to_bounded,in_axes=(0,None))(n,bounds)\n", + " n = jnp.linspace(0, 10, 1001)\n", + " x = jax.vmap(to_bounded, in_axes=(0, None))(n, bounds)\n", "\n", " ax = axarr[1]\n", - " ax.plot(n,x)\n", + " ax.plot(n, x)\n", " ax.set_xlabel('n')\n", " ax.set_ylabel('x')\n", " ax.set_title(r'$n \\to x$')\n", " f.set_tight_layout(True)\n", + "\n", + "\n", "plot_trfs()" ] }, @@ -116,69 +120,76 @@ "outputs": [], "source": [ "def func(external_pars):\n", - " x,y = external_pars\n", + " x, y = external_pars\n", " # a,b = 2*x+y,x-y\n", - " a,b = x,y\n", - " ca,cb = 1,1\n", - " z = (a-ca)**2 + (b-cb)**2\n", + " a, b = x, y\n", + " ca, cb = 1, 1\n", + " z = (a - ca) ** 2 + (b - cb) ** 2\n", " return z\n", - " \n", - " \n", - "def internal_func(internal_pars,bounds):\n", - " external_pars = jax.vmap(to_bounded)(internal_pars,bounds)\n", + "\n", + "\n", + "def internal_func(internal_pars, bounds):\n", + " external_pars = jax.vmap(to_bounded)(internal_pars, bounds)\n", " return func(external_pars)\n", "\n", - "bounds = jnp.array([[-5,5],[-5,5]])\n", "\n", + "bounds = jnp.array([[-5, 5], [-5, 5]])\n", "\n", "\n", - "def plot_func(ax,func,slices,bounds = None):\n", - " grid = x,y = np.mgrid[slices[0],slices[1]]\n", - " X = jnp.swapaxes(grid,0,-1).reshape(-1,2)\n", + "def plot_func(ax, func, slices, bounds=None):\n", + " grid = x, y = np.mgrid[slices[0], slices[1]]\n", + " X = jnp.swapaxes(grid, 0, -1).reshape(-1, 2)\n", "\n", " if bounds is not 
None:\n", - " Z = jax.vmap(func,in_axes=(0,None))(X,bounds)\n", + " Z = jax.vmap(func, in_axes=(0, None))(X, bounds)\n", " else:\n", " Z = jax.vmap(func)(X)\n", - " z = jnp.swapaxes(Z.reshape(101,101),0,-1)\n", - " ax.contourf(x,y,z,levels = 100)\n", - " ax.contour(x,y,z,levels = 10, colors = 'w')\n", + " z = jnp.swapaxes(Z.reshape(101, 101), 0, -1)\n", + " ax.contourf(x, y, z, levels=100)\n", + " ax.contour(x, y, z, levels=10, colors='w')\n", " ax.set_xlabel(r'$n_1$')\n", " ax.set_xlabel(r'$n_2$')\n", " if bounds is not None:\n", - " rect = patches.Rectangle([-np.pi/2,-np.pi/2],np.pi,np.pi, alpha = 0.2, facecolor = 'k')\n", + " rect = patches.Rectangle(\n", + " [-np.pi / 2, -np.pi / 2], np.pi, np.pi, alpha=0.2, facecolor='k'\n", + " )\n", " ax.add_patch(rect)\n", "\n", + "\n", "def angle_and_lam(M):\n", - " lam,bases = jnp.linalg.eig(M)\n", - " first = bases[:,0]\n", - " sign = jnp.sign(first[2])\n", - " angle = jnp.arccos(first[0])*180/np.pi\n", - " return lam,sign*angle\n", + " lam, bases = jnp.linalg.eig(M)\n", + " first = bases[:, 0]\n", + " sign = jnp.sign(first[2])\n", + " angle = jnp.arccos(first[0]) * 180 / np.pi\n", + " return lam, sign * angle\n", + "\n", "\n", - "def draw_covariances(ax,func,slices,bounds = None,scale = 1):\n", - " grid = x,y = np.mgrid[slices[0],slices[1]]\n", - " X = np.swapaxes(grid,0,-1).reshape(-1,2)\n", + "def draw_covariances(ax, func, slices, bounds=None, scale=1):\n", + " grid = x, y = np.mgrid[slices[0], slices[1]]\n", + " X = np.swapaxes(grid, 0, -1).reshape(-1, 2)\n", "\n", " if bounds is not None:\n", - " covariance = lambda X,bounds: jnp.linalg.inv(jax.hessian(func)(X,bounds))\n", - " args = (X,bounds)\n", - " axes = (0,None)\n", + " covariance = lambda X, bounds: jnp.linalg.inv(jax.hessian(func)(X, bounds))\n", + " args = (X, bounds)\n", + " axes = (0, None)\n", " else:\n", " covariance = lambda X: jnp.linalg.inv(jax.hessian(func)(X))\n", " args = (X,)\n", " axes = (0,)\n", - " lams,angles = jax.vmap(angle_and_lam)(jax.vmap(covariance,in_axes=axes)(*args))\n", - " for i,(lam,angle) in enumerate(zip(lams,angles)):\n", + " lams, angles = jax.vmap(angle_and_lam)(jax.vmap(covariance, in_axes=axes)(*args))\n", + " for i, (lam, angle) in enumerate(zip(lams, angles)):\n", " e = patches.Ellipse(\n", - " X[i],np.sqrt(lam[0])*scale,np.sqrt(lam[1])*scale,angle,\n", - " alpha = 0.5,\n", - " facecolor = 'none',\n", - " edgecolor = 'k'\n", + " X[i],\n", + " np.sqrt(lam[0]) * scale,\n", + " np.sqrt(lam[1]) * scale,\n", + " angle,\n", + " alpha=0.5,\n", + " facecolor='none',\n", + " edgecolor='k',\n", " )\n", " ax.add_patch(e)\n", - " ax.set_xlim(slices[0].start,slices[0].stop)\n", - " ax.set_ylim(slices[0].start,slices[0].stop) \n" + " ax.set_xlim(slices[0].start, slices[0].stop)\n", + " ax.set_ylim(slices[0].start, slices[0].stop)" ] }, { @@ -199,20 +210,14 @@ } ], "source": [ - "f,ax = plt.subplots(1,1)\n", - "f.set_size_inches(5,5)\n", + "f, ax = plt.subplots(1, 1)\n", + "f.set_size_inches(5, 5)\n", "f.set_tight_layout(True)\n", - "plot_func(ax, func, slices = [\n", - " slice(-5,5,101*1j),\n", - " slice(-5,5,101*1j)\n", - " ])\n", - "\n", - "draw_covariances(ax,func,slices = [\n", - " slice(-5,5,10*1j),\n", - " slice(-5,5,10*1j)\n", - " ],\n", - " scale = 1\n", - ") \n" + "plot_func(ax, func, slices=[slice(-5, 5, 101 * 1j), slice(-5, 5, 101 * 1j)])\n", + "\n", + "draw_covariances(\n", + " ax, func, slices=[slice(-5, 5, 10 * 1j), slice(-5, 5, 10 * 1j)], scale=1\n", + ")" ] }, { @@ -233,21 +238,23 @@ } ], "source": [ - "f,ax = plt.subplots(1,1)\n", - 
"f.set_size_inches(5,5)\n", - "\n", - "plot_func(ax,internal_func,slices = [\n", - " slice(-np.pi,np.pi,101*1j),\n", - " slice(-np.pi,np.pi,101*1j)\n", - " ],bounds = bounds)\n", - "bounds = jnp.array([[-5,5],[-5,5]])\n", - "draw_covariances(ax,internal_func,slices = [\n", - " slice(-np.pi,np.pi,10*1j),\n", - " slice(-np.pi,np.pi,10*1j)\n", - " ],\n", - " bounds = bounds,\n", - " scale = 1\n", - ")\n" + "f, ax = plt.subplots(1, 1)\n", + "f.set_size_inches(5, 5)\n", + "\n", + "plot_func(\n", + " ax,\n", + " internal_func,\n", + " slices=[slice(-np.pi, np.pi, 101 * 1j), slice(-np.pi, np.pi, 101 * 1j)],\n", + " bounds=bounds,\n", + ")\n", + "bounds = jnp.array([[-5, 5], [-5, 5]])\n", + "draw_covariances(\n", + " ax,\n", + " internal_func,\n", + " slices=[slice(-np.pi, np.pi, 10 * 1j), slice(-np.pi, np.pi, 10 * 1j)],\n", + " bounds=bounds,\n", + " scale=1,\n", + ")" ] }, { @@ -312,8 +319,8 @@ "outputs": [], "source": [ "def grads_from_n(n):\n", - " x = jax.vmap(to_bounded)(n,bounds)\n", - " J = jax.jacfwd(jax.vmap(to_inf))(x,bounds)\n", + " x = jax.vmap(to_bounded)(n, bounds)\n", + " J = jax.jacfwd(jax.vmap(to_inf))(x, bounds)\n", " return J" ] }, @@ -342,20 +349,20 @@ "metadata": {}, "outputs": [], "source": [ - "def hessian_transform(extr,bounds):\n", - " intr = jax.vmap(to_inf)(extr,bounds)\n", + "def hessian_transform(extr, bounds):\n", + " intr = jax.vmap(to_inf)(extr, bounds)\n", "\n", - " first = jax.jacfwd(jax.vmap(to_inf))(extr,bounds) \n", + " first = jax.jacfwd(jax.vmap(to_inf))(extr, bounds)\n", " secnd = jax.jacfwd(grads_from_n)(intr)\n", - " third = jax.grad(internal_func)(intr,bounds)\n", + " third = jax.grad(internal_func)(intr, bounds)\n", "\n", - " J = jax.jacfwd(jax.vmap(to_inf))(extr,bounds)\n", + " J = jax.jacfwd(jax.vmap(to_inf))(extr, bounds)\n", "\n", - " a = jnp.einsum('ik,kjl,l->ij',first,secnd,third)\n", + " a = jnp.einsum('ik,kjl,l->ij', first, secnd, third)\n", "\n", - " int_hessian = jax.hessian(internal_func)(intr,bounds)\n", - " b = jnp.einsum('ik,jl,kl->ij',J,J,int_hessian)\n", - " return int_hessian,a,b,a+b" + " int_hessian = jax.hessian(internal_func)(intr, bounds)\n", + " b = jnp.einsum('ik,jl,kl->ij', J, J, int_hessian)\n", + " return int_hessian, a, b, a + b" ] }, { @@ -379,8 +386,8 @@ } ], "source": [ - "def check_point(extrn,bounds):\n", - " int_hessian,a,b,extrn_hessian = hessian_transform(extrn,bounds)\n", + "def check_point(extrn, bounds):\n", + " int_hessian, a, b, extrn_hessian = hessian_transform(extrn, bounds)\n", "\n", " print(f'internal hessian:\\n{int_hessian}')\n", " print(f'additional part:\\n{a}')\n", @@ -390,9 +397,10 @@ " direct_hessian = jax.hessian(func)(extrn)\n", " print(f'directly computed hessian:\\n{direct_hessian}')\n", "\n", - "bounds = jnp.array([[-5,5],[-5,5]])\n", - "extrn = jnp.array([1.,1.])\n", - "check_point(extrn,bounds)" + "\n", + "bounds = jnp.array([[-5, 5], [-5, 5]])\n", + "extrn = jnp.array([1.0, 1.0])\n", + "check_point(extrn, bounds)" ] }, { @@ -416,8 +424,8 @@ } ], "source": [ - "extrn = jnp.array([2.,2.])\n", - "check_point(extrn,bounds)" + "extrn = jnp.array([2.0, 2.0])\n", + "check_point(extrn, bounds)" ] }, { @@ -426,28 +434,31 @@ "metadata": {}, "outputs": [], "source": [ - "def compare(ax,scale = 2, index = -1, color = 'k'):\n", - " slices = [slice(-5,5,11j),slice(-5,5,11j)]\n", + "def compare(ax, scale=2, index=-1, color='k'):\n", + " slices = [slice(-5, 5, 11j), slice(-5, 5, 11j)]\n", "\n", - " grid = x,y = np.mgrid[slices[0],slices[1]]\n", - " X = np.swapaxes(grid,0,-1).reshape(-1,2)\n", - " covariance = 
lambda X,bounds: hessian_transform(X,bounds)[index]\n", - " args = (X,bounds)\n", - " axes = (0,None)\n", + " grid = x, y = np.mgrid[slices[0], slices[1]]\n", + " X = np.swapaxes(grid, 0, -1).reshape(-1, 2)\n", + " covariance = lambda X, bounds: hessian_transform(X, bounds)[index]\n", + " args = (X, bounds)\n", + " axes = (0, None)\n", "\n", - " covariances = jax.vmap(covariance,in_axes=axes)(*args)\n", + " covariances = jax.vmap(covariance, in_axes=axes)(*args)\n", "\n", - " lams,angles = jax.vmap(angle_and_lam)(covariances)\n", - " for i,(lam,angle) in enumerate(zip(lams,angles)):\n", + " lams, angles = jax.vmap(angle_and_lam)(covariances)\n", + " for i, (lam, angle) in enumerate(zip(lams, angles)):\n", " e = patches.Ellipse(\n", - " X[i],lam[0]*scale,lam[1]*scale,angle,\n", - " alpha = 0.5,\n", - " facecolor = 'none',\n", - " edgecolor = color\n", + " X[i],\n", + " lam[0] * scale,\n", + " lam[1] * scale,\n", + " angle,\n", + " alpha=0.5,\n", + " facecolor='none',\n", + " edgecolor=color,\n", " )\n", " ax.add_patch(e)\n", - " ax.set_xlim(slices[0].start,slices[0].stop)\n", - " ax.set_ylim(slices[0].start,slices[0].stop) " + " ax.set_xlim(slices[0].start, slices[0].stop)\n", + " ax.set_ylim(slices[0].start, slices[0].stop)" ] }, { @@ -468,11 +479,11 @@ } ], "source": [ - "f,ax = plt.subplots(1,1)\n", - "compare(ax,scale = 0.2, index = -1)\n", - "compare(ax,scale = 0.2, index = -2, color = 'r')\n", - "compare(ax,scale = 0.2, index = -3, color = 'b')\n", - "plt.gcf().set_size_inches(5,5)" + "f, ax = plt.subplots(1, 1)\n", + "compare(ax, scale=0.2, index=-1)\n", + "compare(ax, scale=0.2, index=-2, color='r')\n", + "compare(ax, scale=0.2, index=-3, color='b')\n", + "plt.gcf().set_size_inches(5, 5)" ] }, { @@ -483,4 +494,4 @@ "source": [] } ] -} \ No newline at end of file +} diff --git a/docs/examples/notebooks/learn/minuit_errors.ipynb b/docs/examples/notebooks/learn/minuit_errors.ipynb index 85888af69d..4e8754a76d 100644 --- a/docs/examples/notebooks/learn/minuit_errors.ipynb +++ b/docs/examples/notebooks/learn/minuit_errors.ipynb @@ -44,120 +44,134 @@ "from jax.config import config\n", "import pyhf\n", "import scipy.optimize\n", + "\n", "config.update('jax_enable_x64', True)\n", "pyhf.set_backend('jax')\n", "\n", - "def toinf_single(x,bounds):\n", - " lo,hi = bounds\n", - " return jax.numpy.arcsin(2*(x-lo)/(hi-lo)-1)\n", "\n", - "def tobnd_single(x,bounds):\n", - " lo,hi = bounds\n", - " return lo + 0.5*(hi-lo)*(jax.numpy.sin(x) +1)\n", + "def toinf_single(x, bounds):\n", + " lo, hi = bounds\n", + " return jax.numpy.arcsin(2 * (x - lo) / (hi - lo) - 1)\n", + "\n", + "\n", + "def tobnd_single(x, bounds):\n", + " lo, hi = bounds\n", + " return lo + 0.5 * (hi - lo) * (jax.numpy.sin(x) + 1)\n", "\n", - "def _calc_minuit_errors(infmin,minim,inf_inv_hess,barray):\n", + "\n", + "def _calc_minuit_errors(infmin, minim, inf_inv_hess, barray):\n", " '''The core routine for MINUIT-like errors from an internal Hessian'''\n", " # https://root.cern.ch/doc/master/classROOT_1_1Minuit2_1_1MnUserTransformation.html#ad900f367f4d2c5df13f899dd55bdf212\n", " errs = jnp.sqrt(jnp.diag(inf_inv_hess))\n", - " infmin = jax.vmap(toinf_single)(minim,barray)\n", - " up = jax.vmap(tobnd_single)(infmin+errs,barray)-minim\n", - " dn = jax.vmap(tobnd_single)(infmin-errs,barray)-minim\n", - " up = jnp.where(errs>1,barray[:,1]-barray[:,0],up) #it's unclear to me why this is done\n", - " fn = (jnp.abs(up)+jnp.abs(dn))*0.5\n", + " infmin = jax.vmap(toinf_single)(minim, barray)\n", + " up = jax.vmap(tobnd_single)(infmin + errs, 
barray) - minim\n", + " dn = jax.vmap(tobnd_single)(infmin - errs, barray) - minim\n", + " up = jnp.where(\n", + " errs > 1, barray[:, 1] - barray[:, 0], up\n", + " ) # it's unclear to me why this is done\n", + " fn = (jnp.abs(up) + jnp.abs(dn)) * 0.5\n", " return fn\n", "\n", - "def calc_minuit_errors(objective,minim,barray):\n", + "\n", + "def calc_minuit_errors(objective, minim, barray):\n", " '''This computes MINUIT-like errors'''\n", + "\n", " def internal_obj(x):\n", - " ext = jax.vmap(tobnd_single)(x,barray)\n", + " ext = jax.vmap(tobnd_single)(x, barray)\n", " return objective(ext)\n", - " infmin = jax.vmap(toinf_single)(minim,barray)\n", + "\n", + " infmin = jax.vmap(toinf_single)(minim, barray)\n", " internal_hessian = jax.hessian(internal_obj)(infmin)\n", " inf_inv_hess = jax.numpy.linalg.inv(internal_hessian)\n", "\n", - " minuit_errors = _calc_minuit_errors(infmin,minim,inf_inv_hess,barray)\n", + " minuit_errors = _calc_minuit_errors(infmin, minim, inf_inv_hess, barray)\n", " return minuit_errors\n", "\n", - "def run_via_pyhf_scipy(objective,data_pdf,init,barray):\n", + "\n", + "def run_via_pyhf_scipy(objective, data_pdf, init, barray):\n", " '''This version runs the standard pyhf interface but computes minuit-like errors'''\n", " minim = pyhf.infer.mle.fit(*data_pdf)\n", - " minuit_errors = calc_minuit_errors(objective,minim,barray)\n", + " minuit_errors = calc_minuit_errors(objective, minim, barray)\n", " return minuit_errors\n", "\n", - "def run_raw_scipy(objective,init,barray):\n", + "\n", + "def run_raw_scipy(objective, init, barray):\n", " '''This version runs the raw scipy optimization and computes minuit-like errors'''\n", "\n", - " minim = scipy.optimize.minimize(objective,jnp.array(init),bounds = barray).x\n", + " minim = scipy.optimize.minimize(objective, jnp.array(init), bounds=barray).x\n", "\n", - " minuit_errors = calc_minuit_errors(objective,minim,barray)\n", + " minuit_errors = calc_minuit_errors(objective, minim, barray)\n", "\n", " external_hessian = jax.hessian(objective)(minim)\n", " inv_hess = jax.numpy.linalg.inv(external_hessian)\n", " sqrt_inv_hess_err = jnp.sqrt(jnp.diag(inv_hess))\n", " return minuit_errors, sqrt_inv_hess_err\n", "\n", - "def run_pyhf_minuit(data_pdf,grad = False):\n", + "\n", + "def run_pyhf_minuit(data_pdf, grad=False):\n", " '''This version runs the raw scipy optimization and computes minuit-like errors'''\n", - " result = pyhf.infer.mle.fit(*data_pdf, return_uncertainties = True,do_grad=grad)\n", + " result = pyhf.infer.mle.fit(*data_pdf, return_uncertainties=True, do_grad=grad)\n", " return result\n", "\n", - "def raw_minuit(objective,init,barray):\n", + "\n", + "def raw_minuit(objective, init, barray):\n", " '''This version runs just raw minuit without pyhf'''\n", - " m = iminuit.Minuit(objective,\n", + " m = iminuit.Minuit(\n", + " objective,\n", " use_array_call=True,\n", - " forced_parameters = ['p1','p2'],\n", - " errordef = 0.5,\n", - " p1 = init[0],\n", - " p2 = init[1],\n", - " error_p1 = 0.01,\n", - " error_p2 = 0.01,\n", - " limit_p1 = barray[0],\n", - " limit_p2 = barray[1]\n", + " forced_parameters=['p1', 'p2'],\n", + " errordef=0.5,\n", + " p1=init[0],\n", + " p2=init[1],\n", + " error_p1=0.01,\n", + " error_p2=0.01,\n", + " limit_p1=barray[0],\n", + " limit_p2=barray[1],\n", " )\n", " m.strategy = 0\n", " m.migrad()\n", " m.hesse()\n", " return m.np_errors()\n", "\n", - "def run_error_analysis(pdf,obs_count):\n", - " data = jnp.array([obs_count]+pdf.config.auxdata)\n", + "\n", + "def run_error_analysis(pdf, 
obs_count):\n", + " data = jnp.array([obs_count] + pdf.config.auxdata)\n", + "\n", " def func(x):\n", - " return -2.0*pdf.logpdf(x,data)[0]\n", + " return -2.0 * pdf.logpdf(x, data)[0]\n", + "\n", " bounds = jnp.array(pdf.config.suggested_bounds())\n", " init = jnp.array(pdf.config.suggested_init())\n", "\n", - " pyhf.set_backend('jax','scipy')\n", + " pyhf.set_backend('jax', 'scipy')\n", "\n", - " min_errors = raw_minuit(func,init,bounds)\n", - " pyhf_scipy_minuit_errors = run_via_pyhf_scipy(func,(data,pdf),init,bounds)\n", - " scipy_minuit_errors,sqrt_inv_hess_err = run_raw_scipy(func,init,bounds)\n", + " min_errors = raw_minuit(func, init, bounds)\n", + " pyhf_scipy_minuit_errors = run_via_pyhf_scipy(func, (data, pdf), init, bounds)\n", + " scipy_minuit_errors, sqrt_inv_hess_err = run_raw_scipy(func, init, bounds)\n", "\n", - " pyhf.set_backend('jax',pyhf.optimize.minuit_optimizer(errordef = 0.5))\n", - " result_nograd = run_pyhf_minuit((data,pdf),grad = False)\n", - " result_grad = run_pyhf_minuit((data,pdf),grad = True)\n", + " pyhf.set_backend('jax', pyhf.optimize.minuit_optimizer(errordef=0.5))\n", + " result_nograd = run_pyhf_minuit((data, pdf), grad=False)\n", + " result_grad = run_pyhf_minuit((data, pdf), grad=True)\n", " return {\n", " 'sqrt inv hessian': sqrt_inv_hess_err,\n", - "\n", " 'raw scipy + AD minuit-like': scipy_minuit_errors,\n", " 'pyhf scipy + AD minuit-like': pyhf_scipy_minuit_errors,\n", - "\n", " 'raw minuit': min_errors,\n", - " 'pyhf minuit iface bo AD ': result_nograd[:,1],\n", - " 'pyhf minuit iface AD': result_grad[:,1],\n", + " 'pyhf minuit iface bo AD ': result_nograd[:, 1],\n", + " 'pyhf minuit iface AD': result_grad[:, 1],\n", " }\n", "\n", "\n", - "\n", "def run_scan(scan):\n", - " pdf = pyhf.simplemodels.hepdata_like([2.],[50.],[5.])\n", + " pdf = pyhf.simplemodels.hepdata_like([2.0], [50.0], [5.0])\n", " data = {}\n", " for o in scan:\n", - " d = run_error_analysis(pdf,o)\n", - " for k,v in d.items():\n", - " data.setdefault(k,[]).append(v)\n", - " for k,v in d.items():\n", + " d = run_error_analysis(pdf, o)\n", + " for k, v in d.items():\n", + " data.setdefault(k, []).append(v)\n", + " for k, v in d.items():\n", " data[k] = jnp.array(data[k])\n", - " return data\n" + " return data" ] }, { @@ -186,17 +200,18 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "scan = jnp.linspace(50,70,10)\n", + "\n", + "scan = jnp.linspace(50, 70, 10)\n", "data = run_scan(scan)\n", "\n", "\n", - "f,axarr = plt.subplots(1,2)\n", - "f.set_size_inches(10,5)\n", - "for k,v in data.items():\n", - " axarr[0].plot(scan,v[:,0], label = k, linestyle = 'dashed')\n", + "f, axarr = plt.subplots(1, 2)\n", + "f.set_size_inches(10, 5)\n", + "for k, v in data.items():\n", + " axarr[0].plot(scan, v[:, 0], label=k, linestyle='dashed')\n", " axarr[0].set_title('par1')\n", " axarr[0].legend()\n", - " axarr[1].plot(scan,v[:,0], label = k, linestyle = 'dashed')\n", + " axarr[1].plot(scan, v[:, 0], label=k, linestyle='dashed')\n", " axarr[1].set_title('par2')\n", " axarr[1].legend()" ] @@ -209,4 +224,4 @@ "source": [] } ] -} \ No newline at end of file +} diff --git a/src/pyhf/cli/infer.py b/src/pyhf/cli/infer.py index b25041a87d..d2ccb01f55 100644 --- a/src/pyhf/cli/infer.py +++ b/src/pyhf/cli/infer.py @@ -41,7 +41,7 @@ def cli(): ) @click.option( "--optimizer", - type=click.Choice(["scipy", "minuit","customjax"]), + type=click.Choice(["scipy", "minuit", "customjax"]), help="The optimizer used for the calculation.", default="scipy", ) @@ -149,7 +149,7 @@ def fit( ) @click.option( 
"--optimizer", - type=click.Choice(["scipy", "minuit","customjax"]), + type=click.Choice(["scipy", "minuit", "customjax"]), help="The optimizer used for the calculation.", default="scipy", ) diff --git a/src/pyhf/optimize/__init__.py b/src/pyhf/optimize/__init__.py index a16509ed9f..4c5afb7ab0 100644 --- a/src/pyhf/optimize/__init__.py +++ b/src/pyhf/optimize/__init__.py @@ -7,6 +7,7 @@ class _OptimizerRetriever: def __getattr__(self, name): if name == 'customjax': from .opt_custom_jax import jaxcustom_optimizer + self.jaxcustom_optimizer = jaxcustom_optimizer return jaxcustom_optimizer if name == 'scipy_optimizer': diff --git a/src/pyhf/optimize/common.py b/src/pyhf/optimize/common.py index bace305044..7aa4ac8719 100644 --- a/src/pyhf/optimize/common.py +++ b/src/pyhf/optimize/common.py @@ -28,7 +28,7 @@ def post_processor(pars, stitch_with=fixed_values): return post_processor -def _get_internal_objective(*args,**kwargs): +def _get_internal_objective(*args, **kwargs): """ A shim-retriever to lazy-retrieve the necessary shims as needed. @@ -39,37 +39,40 @@ def _get_internal_objective(*args,**kwargs): if tensorlib.name == 'numpy': from .opt_numpy import wrap_objective as numpy_shim - return numpy_shim(*args,**kwargs) + return numpy_shim(*args, **kwargs) if tensorlib.name == 'tensorflow': from .opt_tflow import wrap_objective as tflow_shim - return tflow_shim(*args,**kwargs) + return tflow_shim(*args, **kwargs) if tensorlib.name == 'pytorch': from .opt_pytorch import wrap_objective as pytorch_shim - return pytorch_shim(*args,**kwargs) + return pytorch_shim(*args, **kwargs) if tensorlib.name == 'jax': from .opt_jax import wrap_objective as jax_shim - return jax_shim(*args,**kwargs) + return jax_shim(*args, **kwargs) raise ValueError(f'No optimizer shim for {tensorlib.name}.') -def to_inf(x,bounds): +def to_inf(x, bounds): tensorlib, _ = get_backend() - lo,hi = bounds.T - return tensorlib.arcsin(2*(x-lo)/(hi-lo)-1) + lo, hi = bounds.T + return tensorlib.arcsin(2 * (x - lo) / (hi - lo) - 1) -def to_bnd(x,bounds): + +def to_bnd(x, bounds): tensorlib, _ = get_backend() - lo,hi = bounds.T - return lo + 0.5*(hi-lo)*(tensorlib.sin(x) +1) + lo, hi = bounds.T + return lo + 0.5 * (hi - lo) * (tensorlib.sin(x) + 1) -def _configure_internal_minimize(init_pars,variable_idx,do_stitch,par_bounds,fixed_idx,fixed_values): +def _configure_internal_minimize( + init_pars, variable_idx, do_stitch, par_bounds, fixed_idx, fixed_values +): tensorlib, _ = get_backend() if do_stitch: all_init = tensorlib.astensor(init_pars) @@ -90,22 +93,25 @@ def _configure_internal_minimize(init_pars,variable_idx,do_stitch,par_bounds,fix external_fixed_vals = fixed_vals post_processor = _make_post_processor() - internal_init = to_inf(tensorlib.astensor(internal_init),tensorlib.astensor(internal_bounds)) + internal_init = to_inf( + tensorlib.astensor(internal_init), tensorlib.astensor(internal_bounds) + ) + def mypostprocessor(x): - x = to_bnd(x,tensorlib.astensor(internal_bounds)) + x = to_bnd(x, tensorlib.astensor(internal_bounds)) return post_processor(x) no_internal_bounds = None - - kwargs = dict( - x0 = internal_init, - variable_bounds = internal_bounds, + kwargs = dict( + x0=internal_init, + variable_bounds=internal_bounds, bounds=no_internal_bounds, fixed_vals=external_fixed_vals, - ) + ) return kwargs, mypostprocessor + def shim( objective, data, @@ -158,9 +164,9 @@ def shim( fixed_values = [x[1] for x in fixed_vals] variable_idx = [x for x in range(pdf.config.npars) if x not in fixed_idx] - minimizer_kwargs,post_processor = 
_configure_internal_minimize(init_pars,variable_idx,do_stitch,par_bounds,fixed_idx,fixed_values) - - + minimizer_kwargs, post_processor = _configure_internal_minimize( + init_pars, variable_idx, do_stitch, par_bounds, fixed_idx, fixed_values + ) internal_objective_maybe_grad = _get_internal_objective( objective, @@ -173,7 +179,7 @@ def shim( 'variable_idx': variable_idx, 'fixed_values': fixed_values, 'do_stitch': do_stitch, - 'par_bounds': tensorlib.astensor(minimizer_kwargs.pop('variable_bounds')) + 'par_bounds': tensorlib.astensor(minimizer_kwargs.pop('variable_bounds')), }, ) diff --git a/src/pyhf/optimize/mixins.py b/src/pyhf/optimize/mixins.py index 7cab7615fa..e6019a4e73 100644 --- a/src/pyhf/optimize/mixins.py +++ b/src/pyhf/optimize/mixins.py @@ -60,7 +60,6 @@ def _internal_postprocess(self, fitresult, stitch_pars): """ tensorlib, _ = get_backend() - fitted_pars = stitch_pars(tensorlib.astensor(fitresult.x)) # extract number of fixed parameters diff --git a/src/pyhf/optimize/opt_custom_jax.py b/src/pyhf/optimize/opt_custom_jax.py index 941fdc5853..4dd0a949da 100644 --- a/src/pyhf/optimize/opt_custom_jax.py +++ b/src/pyhf/optimize/opt_custom_jax.py @@ -16,33 +16,37 @@ def _get_minimizer( ): return None - def _custom_internal_minimize(self,objective, init_pars, maxiter = 1000,rtol = 1e-7): + def _custom_internal_minimize(self, objective, init_pars, maxiter=1000, rtol=1e-7): import jax.experimental.optimizers as optimizers import jax - opt_init, opt_update, opt_getpars = optimizers.adam(step_size = 1e-2) + + opt_init, opt_update, opt_getpars = optimizers.adam(step_size=1e-2) state = opt_init(init_pars) - vold,_ = objective(init_pars) + vold, _ = objective(init_pars) + def cond(loop_state): delta = loop_state['delta'] i = loop_state['i'] - delta_below = jax.numpy.logical_and(loop_state['delta'] > 0,loop_state['delta'] < rtol) - delta_below = jax.numpy.logical_and(loop_state['i'] > 1, delta_below) + delta_below = jax.numpy.logical_and( + loop_state['delta'] > 0, loop_state['delta'] < rtol + ) + delta_below = jax.numpy.logical_and(loop_state['i'] > 1, delta_below) maxed_iter = loop_state['i'] > maxiter - return ~jax.numpy.logical_or(maxed_iter,delta_below) - + return ~jax.numpy.logical_or(maxed_iter, delta_below) + def body(loop_state): i = loop_state['i'] - state = loop_state['state'] + state = loop_state['state'] pars = opt_getpars(state) - v,g = objective(pars) - newopt_state = opt_update(0,g,state) + v, g = objective(pars) + newopt_state = opt_update(0, g, state) vold = loop_state['vold'] - delta = jax.numpy.abs(v-vold)/v + delta = jax.numpy.abs(v - vold) / v new_state = {} - new_state['delta'] = delta - new_state['state'] = newopt_state + new_state['delta'] = delta + new_state['state'] = newopt_state new_state['vold'] = v - new_state['i'] = i+1 + new_state['i'] = i + 1 return new_state loop_state = {'delta': 0, 'i': 0, 'state': state, 'vold': vold} @@ -50,19 +54,20 @@ def body(loop_state): # start = time.time() # # while(cond(loop_state)): # loop_state = body(loop_state) - loop_state = jax.lax.while_loop(cond,body,loop_state) + loop_state = jax.lax.while_loop(cond, body, loop_state) # print(time.time()-start) minimized = opt_getpars(loop_state['state']) + class Result: pass + r = Result() r.x = minimized r.success = True r.fun = objective(minimized)[0] return r - def _minimize( self, minimizer, @@ -78,5 +83,5 @@ def _minimize( assert fixed_vals == [] assert return_uncertainties == False assert bounds == None - result = self._custom_internal_minimize(func,x0) + result = 
self._custom_internal_minimize(func, x0) return result diff --git a/src/pyhf/optimize/opt_jax.py b/src/pyhf/optimize/opt_jax.py index 53d5d1ea5d..7fe40b193f 100644 --- a/src/pyhf/optimize/opt_jax.py +++ b/src/pyhf/optimize/opt_jax.py @@ -8,26 +8,34 @@ log = logging.getLogger(__name__) -def to_inf(x,bounds): +def to_inf(x, bounds): tensorlib, _ = get_backend() - lo,hi = bounds.T - return tensorlib.arcsin(2*(x-lo)/(hi-lo)-1) + lo, hi = bounds.T + return tensorlib.arcsin(2 * (x - lo) / (hi - lo) - 1) -def to_bnd(x,bounds): + +def to_bnd(x, bounds): tensorlib, _ = get_backend() - lo,hi = bounds.T - return lo + 0.5*(hi-lo)*(tensorlib.sin(x) +1) + lo, hi = bounds.T + return lo + 0.5 * (hi - lo) * (tensorlib.sin(x) + 1) def _final_objective( - pars, data, fixed_values, fixed_idx, variable_idx, do_stitch, objective, pdf, par_bounds + pars, + data, + fixed_values, + fixed_idx, + variable_idx, + do_stitch, + objective, + pdf, + par_bounds, ): log.debug('jitting function') tensorlib, _ = get_backend() pars = tensorlib.astensor(pars) - pars = to_bnd(pars,par_bounds) - + pars = to_bnd(pars, par_bounds) if do_stitch: tv = _TensorViewer([fixed_idx, variable_idx]) @@ -75,7 +83,7 @@ def func(pars): jit_pieces['do_stitch'], objective, pdf, - jit_pieces['par_bounds'] + jit_pieces['par_bounds'], ) return result @@ -92,7 +100,7 @@ def func(pars): jit_pieces['do_stitch'], objective, pdf, - jit_pieces['par_bounds'] + jit_pieces['par_bounds'], ) return func
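
The pattern that recurs throughout this patch (to_bounded/to_inf in the Hessians notebook, tobnd_single/toinf_single in the minuit_errors notebook, and to_inf/to_bnd in src/pyhf/optimize/common.py and opt_jax.py) is the MINUIT-style sin/arcsin mapping between a parameter bounded in [lo, hi] and an unbounded internal parameter. A minimal self-contained sketch of the round trip follows; the bounds and test point are illustrative values, not taken from the patch:

import jax.numpy as jnp

def to_inf(x, bounds):
    # bounded x in [lo, hi] -> unbounded internal parameter n
    lo, hi = bounds.T
    return jnp.arcsin(2 * (x - lo) / (hi - lo) - 1)

def to_bnd(n, bounds):
    # unbounded internal parameter n -> bounded x in [lo, hi]
    lo, hi = bounds.T
    return lo + 0.5 * (hi - lo) * (jnp.sin(n) + 1)

bounds = jnp.array([[0.0, 10.0], [-5.0, 5.0]])  # example bounds, not from the patch
x = jnp.array([2.5, 1.0])                       # example point inside the bounds
# The composition recovers x up to floating-point error.
assert jnp.allclose(to_bnd(to_inf(x, bounds), bounds), x, atol=1e-6)

Note that to_inf maps the interior of [lo, hi] onto [-pi/2, pi/2]; the gray rectangle that plot_func draws in the Hessians notebook marks exactly that image.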
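The hessian_transform cell in the Hessians notebook assembles the external-space Hessian from internal-space pieces (the a + b it returns). The identity behind it is the second-order chain rule for f(x) = g(n(x)): H_ext[i, j] = sum_kl J[k, i] H_int[k, l] J[l, j] + sum_k g_int[k] * d2n_k/(dx_i dx_j), where J[k, i] = dn_k/dx_i. Below is a numeric check of that identity with a toy quadratic objective; the objective and bounds are stand-ins, not the notebook's exact setup:

import jax
import jax.numpy as jnp

bounds = jnp.array([[-5.0, 5.0], [-5.0, 5.0]])

def to_inf(x):
    # external (bounded) -> internal (unbounded)
    return jnp.arcsin(2 * (x - bounds[:, 0]) / (bounds[:, 1] - bounds[:, 0]) - 1)

def func(x):
    # toy external objective with its minimum at (1, 1)
    return jnp.sum((x - 1.0) ** 2)

def internal_func(n):
    # the same objective expressed in the unbounded space
    lo, hi = bounds[:, 0], bounds[:, 1]
    return func(lo + 0.5 * (hi - lo) * (jnp.sin(n) + 1))

x = jnp.array([1.0, 2.0])
n = to_inf(x)
J = jax.jacfwd(to_inf)(x)                 # J[k, i] = dn_k / dx_i
H_int = jax.hessian(internal_func)(n)
g_int = jax.grad(internal_func)(n)
curv = jax.jacfwd(jax.jacfwd(to_inf))(x)  # curv[k, i, j] = d2 n_k / dx_i dx_j
H_ext = J.T @ H_int @ J + jnp.einsum('k,kij->ij', g_int, curv)
# Agrees with differentiating the external objective directly.
assert jnp.allclose(H_ext, jax.hessian(func)(x), atol=1e-4)

The first term alone (the notebook's b) matches only at a stationary point, where g_int vanishes; away from the minimum the curvature term (the notebook's a) is what makes the notebook's comparison ellipses line up.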
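Likewise, _calc_minuit_errors in the minuit_errors notebook condenses the MINUIT recipe: take sqrt(diag(H_int^-1)) in the unbounded space, map the +/-1 sigma points back through the bound transformation, and average the asymmetric excursions (see the MnUserTransformation link in the notebook). A toy sketch under the assumption of an independent Gaussian negative log-likelihood; the bounds, central values, and widths are made-up example numbers:

import jax
import jax.numpy as jnp

bounds = jnp.array([[-10.0, 10.0], [-10.0, 10.0]])  # example bounds

def to_inf(x):
    return jnp.arcsin(2 * (x - bounds[:, 0]) / (bounds[:, 1] - bounds[:, 0]) - 1)

def to_bnd(n):
    return bounds[:, 0] + 0.5 * (bounds[:, 1] - bounds[:, 0]) * (jnp.sin(n) + 1)

mu = jnp.array([1.0, 2.0])   # toy best-fit values
sig = jnp.array([0.5, 1.5])  # toy parabolic errors

def nll(x):
    # independent Gaussian NLL (errordef 0.5), so the true errors are exactly sig
    return jnp.sum(0.5 * ((x - mu) / sig) ** 2)

internal_obj = lambda n: nll(to_bnd(n))
nhat = to_inf(mu)  # the minimum, mapped to the unbounded space
errs = jnp.sqrt(jnp.diag(jnp.linalg.inv(jax.hessian(internal_obj)(nhat))))
up = to_bnd(nhat + errs) - mu
dn = to_bnd(nhat - errs) - mu
minuit_like = 0.5 * (jnp.abs(up) + jnp.abs(dn))
# Far from the bounds this reproduces the parabolic errors, roughly (0.5, 1.5).
print(minuit_like)

Close to a bound the up and down excursions become asymmetric and the averaged number starts to differ from sqrt(diag(H_ext^-1)), which is the effect the notebook's scan over observed counts is probing.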