Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ray.tune usage example. #209

Merged
merged 14 commits into from
Feb 4, 2022
2 changes: 1 addition & 1 deletion data
15 changes: 15 additions & 0 deletions docs/source/exampledepend.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
.. _example_dependencies:

Example Dependencies
--------------------

Some examples use additional dependencies, which are listed in `examples_requirements.txt <https://github.com/lanl/scico/blob/main/examples/examples_requirements.txt>`_.
The additional requirements should be installed via pip, with the exception of ``astra-toolbox``,
which should be installed via conda:

::

conda install -c astra-toolbox astra-toolbox
pip install -r examples/examples_requirements.txt # Installs other example requirements

The dependencies can also be installed individually as required.
19 changes: 4 additions & 15 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,7 @@ Usage Examples
.. toctree::
:maxdepth: 1

.. _example_dependencies:

Example Dependencies
--------------------

Some examples use additional dependencies, which are listed in `examples_requirements.txt <https://github.com/lanl/scico/blob/main/examples/examples_requirements.txt>`_.
The additional requirements should be installed via pip, with the exception of ``astra-toolbox``,
which should be installed via conda:

::

conda install -c astra-toolbox astra-toolbox
pip install -r examples/examples_requirements.txt # Installs other example requirements

The dependencies can also be installed individually as required.
.. include:: exampledepend.rst

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you remove the example related dependencies? I don't think these are there anywhere else.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be a continual pain until #195 is addressed. I'll fix manually for now.


Organized by Application
Expand Down Expand Up @@ -57,6 +43,7 @@ Deconvolution
examples/deconv_ppp_bm3d_pgm
examples/deconv_ppp_dncnn_admm
examples/deconv_tv_admm
examples/deconv_tv_admm_tune


Sparse Coding
Expand Down Expand Up @@ -117,6 +104,7 @@ Total Variation
examples/deconv_microscopy_tv_admm
examples/deconv_microscopy_allchn_tv_admm
examples/deconv_tv_admm
examples/deconv_tv_admm_tune
examples/denoise_tv_iso_admm
examples/denoise_tv_iso_pgm
examples/denoise_tv_iso_multi
Expand Down Expand Up @@ -158,6 +146,7 @@ ADMM
examples/deconv_ppp_bm3d_admm
examples/deconv_ppp_dncnn_admm
examples/deconv_tv_admm
examples/deconv_tv_admm_tune
examples/demosaic_ppp_bm3d_admm
examples/denoise_tv_iso_admm
examples/denoise_tv_iso_multi
Expand Down
8 changes: 7 additions & 1 deletion examples/makeindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,18 @@
print(".. _example_notebooks:\n", file=dstfile)
with open(src, "r") as srcfile:
for line in srcfile:
# Add toctree and include statements after main heading
if line[0:3] == "===":
print(line, end="", file=dstfile)
print("\n.. toctree::\n :maxdepth: 1", file=dstfile)
print("\n.. include:: exampledepend.rst", file=dstfile)
continue
# Detect lines containing script filenames
m = re.match(r"(\s+)- ([^\s]+).py", line)
if m:
print(" " + prfx + m.group(2), file=dstfile)
else:
print(line, end="", file=dstfile)
# Add toctree statements after section headings
# Add toctree statement after section headings
if line[0:3] == line[0] * 3 and line[0] in ["=", "-", "^"]:
print("\n.. toctree::\n :maxdepth: 1", file=dstfile)
6 changes: 6 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Deconvolution
Image Deconvolution (ADMM Plug-and-Play Priors w/ DnCNN)
`deconv_tv_admm.py <deconv_tv_admm.py>`_
Image Deconvolution (ADMM w/ Total Variation)
`deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_
Image Deconvolution Parameter Tuning


Sparse Coding
Expand Down Expand Up @@ -106,6 +108,8 @@ Total Variation
Deconvolution Microscopy (All Channels)
`deconv_tv_admm.py <deconv_tv_admm.py>`_
Image Deconvolution (ADMM w/ Total Variation)
`deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_
Image Deconvolution Parameter Tuning
`denoise_tv_iso_admm.py <denoise_tv_iso_admm.py>`_
Isotropic Total Variation (ADMM)
`denoise_tv_iso_pgm.py <denoise_tv_iso_pgm.py>`_
Expand Down Expand Up @@ -156,6 +160,8 @@ ADMM
Image Deconvolution (ADMM Plug-and-Play Priors w/ DnCNN)
`deconv_tv_admm.py <deconv_tv_admm.py>`_
Image Deconvolution (ADMM w/ Total Variation)
`deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_
Image Deconvolution Parameter Tuning
`demosaic_ppp_bm3d_admm.py <demosaic_ppp_bm3d_admm.py>`_
Image Demosaicing (ADMM Plug-and-Play Priors w/ BM3D)
`denoise_tv_iso_admm.py <denoise_tv_iso_admm.py>`_
Expand Down
31 changes: 24 additions & 7 deletions examples/scripts/deconv_tv_admm_tune.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Image Deconvolution Parameter Tuning
====================================

This example demonstrates the use of
[scico.ray.tune](../_autosummary/scico.ray.tune.rst) to tune parameters
for the companion [example script](deconv_tv_admm.rst).
"""


import numpy as np

import jax
Expand Down Expand Up @@ -43,7 +59,7 @@
"""


def eval_params(config):
def eval_params(config, reporter):
# Extract solver parameters from config dict.
λ, ρ = config["lambda"], config["rho"]
# Get main arrays from ray object store.
Expand All @@ -62,20 +78,20 @@ def eval_params(config):
C_list=[C],
rho_list=[ρ],
x0=A.adj(y),
maxiter=5,
maxiter=10,
subproblem_solver=LinearSubproblemSolver(),
)
# Perform 50 iterations, reporting performance to ray.tune every 5 iterations.
for step in range(10):
# Perform 50 iterations, reporting performance to ray.tune every 10 iterations.
for step in range(5):
x_admm = solver.solve()
tune.report(psnr=float(metric.psnr(x_gt, x_admm)))
reporter(psnr=float(metric.psnr(x_gt, x_admm)))


"""
Define parameter search space and resources per trial.
"""
config = {"lambda": tune.loguniform(1e-2, 1e0), "rho": tune.loguniform(1e-1, 1e1)}
resources = {"gpu": 0, "cpu": 1} # gpus per trial, cpus per trial
resources = {"cpu": 4, "gpu": 0} # cpus per trial, gpus per trial


"""
Expand Down Expand Up @@ -119,7 +135,7 @@ def eval_params(config):
mec="blue",
fig=fig,
)
_, ax = plot.plot(
plot.plot(
best_config["lambda"],
best_config["rho"],
ptyp="loglog",
Expand All @@ -133,6 +149,7 @@ def eval_params(config):
mec="red",
fig=fig,
)
ax = fig.axes[0]
ax.set_xlim([config["rho"].lower, config["rho"].upper])
ax.set_ylim([config["lambda"].lower, config["lambda"].upper])
fig.show()
Expand Down
3 changes: 3 additions & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Deconvolution
- deconv_ppp_bm3d_pgm.py
- deconv_ppp_dncnn_admm.py
- deconv_tv_admm.py
- deconv_tv_admm_tune.py


Sparse Coding
Expand Down Expand Up @@ -72,6 +73,7 @@ Total Variation
- deconv_microscopy_tv_admm.py
- deconv_microscopy_allchn_tv_admm.py
- deconv_tv_admm.py
- deconv_tv_admm_tune.py
- denoise_tv_iso_admm.py
- denoise_tv_iso_pgm.py
- denoise_tv_iso_multi.py
Expand Down Expand Up @@ -104,6 +106,7 @@ ADMM
- deconv_ppp_bm3d_admm.py
- deconv_ppp_dncnn_admm.py
- deconv_tv_admm.py
- deconv_tv_admm_tune.py
- demosaic_ppp_bm3d_admm.py
- denoise_tv_iso_admm.py
- denoise_tv_iso_multi.py
Expand Down
23 changes: 19 additions & 4 deletions scico/ray/tune.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021 by SCICO Developers
# Copyright (C) 2021-2022 by SCICO Developers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For other files, should the end year be the current year or the year the file was last edited on?
In the long-term, should we manually update these or should we make a script (say in scico/misc) that updates these automatically.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that the end year should correspond to the year of most recent edit. I think we can live with manual updates for now.

# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -8,6 +8,8 @@
"""Parameter tuning using :doc:`ray.tune <ray:tune/index>`."""

import datetime
import os
import tempfile
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union

import ray
Expand Down Expand Up @@ -71,6 +73,7 @@ def run(
config: Optional[Dict[str, Any]] = None,
hyperopt: bool = True,
verbose: bool = True,
local_dir: Optional[str] = None,
) -> ray.tune.ExperimentAnalysis:
"""Simplified wrapper for :func:`ray.tune.run`.

Expand All @@ -97,6 +100,10 @@ def run(
running, and terminated trials are indicated by "P:", "R:",
and "T:" respectively, followed by the current best metric
value and the parameters at which it was reported.
local_dir: Directory in which to save tuning results. Defaults to
a subdirectory "ray_results" within the path returned by
`tempfile.gettempdir()`, corresponding e.g. to
"/tmp/ray_results" under Linux.

Returns:
Result of parameter search.
Expand All @@ -114,15 +121,23 @@ def run(
else:
kwargs.update({"verbose": 0})

def _run(config, checkpoint_dir=None):
run_or_experiment(config)
if isinstance(run_or_experiment, str):
name = run_or_experiment
else:
name = run_or_experiment.__name__
name += "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

if local_dir is None:
local_dir = os.path.join(tempfile.gettempdir(), "ray_results")

return ray.tune.run(
_run,
run_or_experiment,
metric=metric,
mode=mode,
name=name,
time_budget_s=time_budget_s,
num_samples=num_samples,
local_dir=local_dir,
resources_per_trial=resources_per_trial,
max_concurrent_trials=max_concurrent_trials,
reuse_actors=True,
Expand Down
12 changes: 9 additions & 3 deletions scico/test/test_ray_tune.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import tempfile

import numpy as np

import pytest
Expand All @@ -11,11 +14,13 @@
pytest.skip("ray.tune not installed", allow_module_level=True)


def eval_params(config):
def eval_params(config, reporter):
x, y = config["x"], config["y"]
cost = x ** 2 + (y - 0.5) ** 2
tune.report(cost=cost)
reporter(cost=cost)


tune.ray.tune.register_trainable("eval_func", eval_params)

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
Expand All @@ -24,14 +29,15 @@ def eval_params(config):
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
def test_random():
analysis = tune.run(
eval_params,
"eval_func",
metric="cost",
mode="min",
num_samples=100,
config=config,
resources_per_trial=resources,
hyperopt=False,
verbose=False,
local_dir=os.path.join(tempfile.gettempdir(), "ray_test"),
)
best_config = analysis.get_best_config(metric="cost", mode="min")
assert np.abs(best_config["x"]) < 0.25
Expand Down