Merge branch 'main' into brendt/new-opt-alg
bwohlberg committed Dec 7, 2021
2 parents c94e94b + 50e4935 commit 82a96ea
Showing 2 changed files with 24 additions and 195 deletions.
211 changes: 20 additions & 191 deletions scico/solver.py
@@ -5,12 +5,7 @@
 # user license can be found in the 'LICENSE' file distributed with the
 # package.
 
-"""Optimization algorithms.
-
-.. todo::
-
-    Add motivation for this module; when to choose over jax optimizers
-"""
+"""Optimization algorithms."""
 
 
 from functools import wraps
@@ -86,7 +81,7 @@ def wrapper(x, *args):
     return wrapper
 
 
-def split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
+def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
     """Split an array of shape (N,M,...) into real and imaginary parts.
 
     Args:
@@ -101,12 +96,12 @@ def split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
         BlockArray.
     """
     if isinstance(x, BlockArray):
-        return BlockArray.array([split_real_imag(_) for _ in x])
+        return BlockArray.array([_split_real_imag(_) for _ in x])
     else:
         return snp.stack((snp.real(x), snp.imag(x)))
 
 
-def join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
+def _join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
     """Join a real array of shape (2,N,M,...) into a complex array.
 
     Join a real array of shape (2,N,M,...) into a complex array of length
@@ -120,16 +115,11 @@ def join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
     and ``x[1]`` respectively.
     """
     if isinstance(x, BlockArray):
-        return BlockArray.array([join_real_imag(_) for _ in x])
+        return BlockArray.array([_join_real_imag(_) for _ in x])
     else:
         return x[0] + 1j * x[1]
 
 
-# TODO: Use jax to compute Hessian-vector products for use in Newton methods
-# see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Hessian-vector-products-using-both-forward--and-reverse-mode
-# for examples of constructing Hessians in jax
-
-
 def minimize(
     func: Callable,
     x0: Union[JaxArray, BlockArray],
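
The renamed helpers above encode a complex array as a real array with the real and imaginary parts stacked along a new leading axis, and decode it back. A minimal round-trip sketch of the same idea, using plain jax.numpy and omitting the BlockArray recursion that the scico versions perform:

```python
import jax.numpy as jnp

def split_real_imag(x):
    # (N, M, ...) complex -> (2, N, M, ...) real: stack Re and Im on a new axis
    return jnp.stack((jnp.real(x), jnp.imag(x)))

def join_real_imag(y):
    # inverse mapping: (2, N, M, ...) real -> (N, M, ...) complex
    return y[0] + 1j * y[1]

z = jnp.array([[1.0 + 2.0j, 3.0 - 4.0j]])
assert jnp.allclose(join_real_imag(split_real_imag(z)), z)
```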
@@ -154,152 +144,27 @@ def minimize(
       supported.
     - Functions mapping from complex arrays -> float are supported.
 
-    Docstring for :func:`scipy.optimize.minimize` follows. For
-    descriptions of the optimization methods and custom minimizers, refer
-    to the original docstring for :func:`scipy.optimize.minimize`.
-
-    Args:
-        func: The objective function to be minimized.
-
-            ``func(x, *args) -> float``
-
-            where ``x`` is an array and ``args`` is a tuple of the fixed parameters
-            needed to completely specify the function. Unlike
-            :func:`scipy.optimize.minimize`, ``x`` need not be a 1D array.
-        x0: Initial guess. If ``func`` is a mapping from complex arrays to floats,
-            x0 must have a complex data type.
-        args: Extra arguments passed to the objective function and `hess`.
-        method: Type of solver. Should be one of:
-
-            - 'Nelder-Mead' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-neldermead.html>`__
-            - 'Powell' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-powell.html>`__
-            - 'CG' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-cg.html>`__
-            - 'BFGS' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-bfgs.html>`__
-            - 'Newton-CG' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-newtoncg.html>`__
-            - 'L-BFGS-B' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html>`__
-            - 'TNC' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-tnc.html>`__
-            - 'COBYLA' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-cobyla.html>`__
-            - 'SLSQP' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html>`__
-            - 'trust-constr' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustconstr.html>`__
-            - 'dogleg' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-dogleg.html>`__
-            - 'trust-ncg' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustncg.html>`__
-            - 'trust-exact' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustexact.html>`__
-            - 'trust-krylov' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustkrylov.html>`__
-            - custom - a callable object (added in version SciPy 0.14.0), see :func:`scipy.optimize.minmize_scalar`.
-
-            If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
-            depending if the problem has constraints or bounds.
-        hess: Method for computing the Hessian matrix. Only for Newton-CG, dogleg,
-            trust-ncg, trust-krylov, trust-exact and trust-constr. If it is
-            callable, it should return the Hessian matrix:
-
-            ``hess(x, *args) -> {LinearOperator, spmatrix, array}, (n, n)``
-
-            where x is a (n,) ndarray and `args` is a tuple with the fixed
-            parameters. LinearOperator and sparse matrix returns are
-            allowed only for 'trust-constr' method. Alternatively, the keywords
-            {'2-point', '3-point', 'cs'} select a finite difference scheme
-            for numerical estimation. Or, objects implementing
-            `HessianUpdateStrategy` interface can be used to approximate
-            the Hessian. Available quasi-Newton methods implementing
-            this interface are:
-
-            - `BFGS`;
-            - `SR1`.
-
-            Whenever the gradient is estimated via finite-differences,
-            the Hessian cannot be estimated with options
-            {'2-point', '3-point', 'cs'} and needs to be
-            estimated using one of the quasi-Newton strategies.
-            Finite-difference options {'2-point', '3-point', 'cs'} and
-            `HessianUpdateStrategy` are available only for 'trust-constr' method.
-            NOTE: In the future, `hess` may be determined using jax.
-        hessp: Hessian of objective function times an arbitrary vector p.
-            Only for Newton-CG, trust-ncg, trust-krylov, trust-constr.
-            Only one of `hessp` or `hess` needs to be given. If `hess` is
-            provided, then `hessp` will be ignored. `hessp` must compute the
-            Hessian times an arbitrary vector:
-
-            ``hessp(x, p, *args) -> array``
-
-            where x is a ndarray, p is an arbitrary vector with
-            dimension equal to x, and `args` is a tuple with the fixed parameters.
-        bounds (None, optional): Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
-            trust-constr methods. There are two ways to specify the bounds:
-
-            1. Instance of `Bounds` class.
-            2. Sequence of ``(min, max)`` pairs for each element in `x`. None
-               is used to specify no bound.
-        constraints: Constraints definition (only for COBYLA, SLSQP and trust-constr).
-            Constraints for 'trust-constr' are defined as a single object or a
-            list of objects specifying constraints to the optimization problem.
-            Available constraints are:
-
-            - `LinearConstraint`
-            - `NonlinearConstraint`
-
-            Constraints for COBYLA, SLSQP are defined as a list of dictionaries.
-            Each dictionary with fields:
-
-                type : str
-                    Constraint type: 'eq' for equality, 'ineq' for inequality.
-                fun : callable
-                    The function defining the constraint.
-                jac : callable, optional
-                    The Jacobian of `fun` (only for SLSQP).
-                args : sequence, optional
-                    Extra arguments to be passed to the function and Jacobian.
-
-            Equality constraint means that the constraint function result is to
-            be zero whereas inequality means that it is to be non-negative.
-            Note that COBYLA only supports inequality constraints.
-        tol: Tolerance for termination. For detailed control, use solver-specific options.
-        callback: Called after each iteration. For 'trust-constr' it is a callable with
-            the signature:
-
-            ``callback(xk, OptimizeResult state) -> bool``
-
-            where ``xk`` is the current parameter vector. and ``state``
-            is an `OptimizeResult` object, with the same fields
-            as the ones from the return. If callback returns True
-            the algorithm execution is terminated.
-            For all the other methods, the signature is:
-
-            ``callback(xk)``
-
-            where ``xk`` is the current parameter vector.
-        options: A dictionary of solver options. All methods accept the following
-            generic options:
-
-                maxiter : int
-                    Maximum number of iterations to perform.
-                disp : bool
-                    Set to True to print convergence messages.
-
-            See :func:`scipy.optimize.show_options()` for solver-specific options.
+    For more detail, including descriptions of the optimization methods
+    and custom minimizers, refer to the original docs for
+    :func:`scipy.optimize.minimize`.
     """

     if snp.iscomplexobj(x0):
         # scipy minimize function requires real-valued arrays, so
         # we split x0 into a vector with real/imaginary parts stacked
-        # and compose `func` with a `join_real_imag`
+        # and compose `func` with a `_join_real_imag`
         iscomplex = True
-        func_ = lambda x: func(join_real_imag(x))
-        x0 = split_real_imag(x0)
+        func_ = lambda x: func(_join_real_imag(x))
+        x0 = _split_real_imag(x0)
     else:
         iscomplex = False
         func_ = func
 
     x0_shape = x0.shape
     x0_dtype = x0.dtype
-    x0 = x0.ravel()  # If x0 is a BlockArray it will become a DeviceArray here
+    x0 = x0.ravel()  # if x0 is a BlockArray it will become a DeviceArray here
     if isinstance(x0, jax.interpreters.xla.DeviceArray):
-        dev = x0.device_buffer.device()  # device where x0 resides; used to put result back in place
+        dev = x0.device_buffer.device()  # device for x0; used to put result back in place
         x0 = x0.copy().astype(float)
     else:
         dev = None
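
The split/join logic above exists because scipy.optimize.minimize only accepts real-valued arrays. A hypothetical standalone illustration of the same trick, minimizing a least-squares objective over a complex variable (names like loss and z_opt are invented here; this is not scico's wrapper, which additionally handles BlockArrays, dtypes, and device placement):

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 3)) + 1j * rng.standard_normal((6, 3))
b = rng.standard_normal(6) + 1j * rng.standard_normal(6)

def loss(xr):
    z = xr[:3] + 1j * xr[3:]              # join: stacked real vector -> complex
    r = A @ z - b
    return float(np.real(np.vdot(r, r)))  # squared residual norm, a real scalar

x0 = np.zeros(6)                          # split representation of z0 = 0
res = minimize(loss, x0, method="BFGS")
z_opt = res.x[:3] + 1j * res.x[3:]        # join the solution back to complex
```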
@@ -330,15 +195,15 @@ def minimize(
     # un-vectorize the output array, put on device
     res.x = snp.reshape(
         res.x, x0_shape
-    )  # If x0 was originally a BlockArray be converted back to one here
+    )  # if x0 was originally a BlockArray be converted back to one here
 
     res.x = res.x.astype(x0_dtype)
 
     if dev:
         res.x = jax.device_put(res.x, dev)
 
     if iscomplex:
-        res.x = join_real_imag(res.x)
+        res.x = _join_real_imag(res.x)
 
     return res
 
@@ -355,47 +220,11 @@ def minimize_scalar(

"""Minimization of scalar function of one variable.
Wrapper around :func:`scipy.optimize.minimize_scalar`. Docstring for
:func:`scipy.optimize.minimize_scalar` follows. For descriptions of
the optimization methods and custom minimizers, refer to the original
docstring for :func:`scipy.optimize.minimize_scalar`.
Args:
func: Objective function. Scalar function, must return a scalar.
bracket: For methods 'brent' and 'golden', `bracket` defines the bracketing
interval and can either have three items ``(a, b, c)`` so that
``a < b < c`` and ``fun(b) < fun(a), fun(c)`` or two items ``a`` and
``c`` which are assumed to be a starting interval for a downhill
bracket search (see `bracket`); it doesn't always mean that the
obtained solution will satisfy ``a <= x <= c``.
bounds: For method 'bounded', `bounds` is mandatory and must have two items
corresponding to the optimization bounds.
args: Extra arguments passed to the objective function.
method: Type of solver. Should be one of:
- 'Brent' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize_scalar-brent.html>`__
- 'Bounded' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize_scalar-bounded.html>`__
- 'Golden' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize_scalar-golden.html>`__
- custom - a callable object (added in SciPy version 0.14.0), see :func:`scipy.optimize.minmize_scalar`.
tol: Tolerance for termination. For detailed control, use solver-specific
options.
options: A dictionary of solver options.
maxiter : int
Maximum number of iterations to perform.
disp : bool
Set to True to print convergence messages.
See :func:`scipy.optimize.show_options()` for solver-specific options.
Returns:
The optimization result represented as a ``OptimizeResult`` object.
Important attributes are: ``x`` the solution array, ``success`` a
Boolean flag indicating if the optimizer exited successfully and
``message`` which describes the cause of the termination. See
:class:`scipy.optimize.OptimizeResult` for a description of other attributes.
Wrapper around :func:`scipy.optimize.minimize_scalar`.
For more detail, including descriptions of the optimization methods
and custom minimizers, refer to the original docstring for
:func:`scipy.optimize.minimize_scalar`.
"""

def f(x, *args):
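
For reference, a minimal usage sketch of the underlying scipy call (assuming only scipy is installed; the scico wrapper additionally adapts argument and return types, as the surrounding code shows):

```python
from scipy.optimize import minimize_scalar

# minimize (x - 2)^2 over scalar x; Brent's method needs no initial guess
res = minimize_scalar(lambda x: (x - 2.0) ** 2, method="brent")
print(res.x, res.fun)  # x approx. 2.0, fun approx. 0.0
```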
@@ -437,7 +266,7 @@ def cg(
         x0: Initial solution.
         tol: Relative residual stopping tolerance. Convergence occurs
             when ``norm(residual) <= max(tol * norm(b), atol)``.
-        atol : Absolute residual stopping tolerance. Convergence occurs
+        atol: Absolute residual stopping tolerance. Convergence occurs
             when ``norm(residual) <= max(tol * norm(b), atol)``.
         maxiter: Maximum iterations. Default: 1000.
         M: Preconditioner for A. The preconditioner should approximate
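The tol/atol documentation above specifies a combined relative/absolute stopping rule for cg. A small numpy sketch of that test, with a hypothetical helper name (not taken from scico's implementation):

```python
import numpy as np

def converged(residual, b, tol=1e-5, atol=0.0):
    # stop once norm(residual) <= max(tol * norm(b), atol)
    return np.linalg.norm(residual) <= max(tol * np.linalg.norm(b), atol)

b = np.array([3.0, 4.0])                     # norm(b) = 5
print(converged(np.array([0.0, 4e-5]), b))   # True: 4e-5 <= 1e-5 * 5
```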
8 changes: 4 additions & 4 deletions scico/test/test_solver.py
@@ -182,25 +182,25 @@ def f(x):

 def test_split_join_array():
     x, key = random.randn((4, 4), dtype=np.complex64)
-    x_s = solver.split_real_imag(x)
+    x_s = solver._split_real_imag(x)
     assert x_s.shape == (2, 4, 4)
     np.testing.assert_allclose(x_s[0], snp.real(x))
     np.testing.assert_allclose(x_s[1], snp.imag(x))
 
-    x_j = solver.join_real_imag(x_s)
+    x_j = solver._join_real_imag(x_s)
     np.testing.assert_allclose(x_j, x, rtol=1e-4)
 
 
 def test_split_join_blockarray():
     x, key = random.randn(((4, 4), (3,)), dtype=np.complex64)
-    x_s = solver.split_real_imag(x)
+    x_s = solver._split_real_imag(x)
     assert x_s.shape == ((2, 4, 4), (2, 3))
 
     real_block = BlockArray.array((x_s[0][0], x_s[1][0]))
     imag_block = BlockArray.array((x_s[0][1], x_s[1][1]))
     np.testing.assert_allclose(real_block.ravel(), snp.real(x).ravel(), rtol=1e-4)
     np.testing.assert_allclose(imag_block.ravel(), snp.imag(x).ravel(), rtol=1e-4)
 
-    x_j = solver.join_real_imag(x_s)
+    x_j = solver._join_real_imag(x_s)
     assert x_j.shape == x.shape
     np.testing.assert_allclose(x_j.ravel(), x.ravel(), rtol=1e-4)
