Various improvements to scico.flax and related example scripts #498

Merged
merged 42 commits into main from brendt/flax on Feb 5, 2024

Commits (42)
5fba865  Rename flax data files (bwohlberg, Jan 11, 2024)
9a4a40a  Merge branch 'main' into brendt/flax (bwohlberg, Jan 11, 2024)
29f1cb1  Merge branch 'main' into brendt/flax (bwohlberg, Jan 16, 2024)
79c57dc  Minor docstring improvements (bwohlberg, Jan 18, 2024)
55b3931  Minor docstring improvement (bwohlberg, Jan 18, 2024)
752a217  Remove astra-toolbox channel (bwohlberg, Jan 22, 2024)
7bfcdaf  Minor improvement (bwohlberg, Jan 23, 2024)
f3f6c5d  Add timer to trainer class (bwohlberg, Jan 23, 2024)
a54cc22  Minor clean up (bwohlberg, Jan 23, 2024)
c00025a  Minor change to log format (bwohlberg, Jan 23, 2024)
dbf507f  Improve log formatting (bwohlberg, Jan 23, 2024)
ff80b3e  Improve log formatting (bwohlberg, Jan 23, 2024)
83199b6  Update submodule (bwohlberg, Jan 30, 2024)
ae7ba06  Merge branch 'main' into brendt/flax (bwohlberg, Jan 30, 2024)
ceb808b  Typo fix (bwohlberg, Jan 30, 2024)
ae0226a  Improve log format (bwohlberg, Jan 30, 2024)
eca873f  Docs fixes (bwohlberg, Jan 30, 2024)
1a13096  Docs consistency (bwohlberg, Jan 30, 2024)
1564103  Minor edit (bwohlberg, Jan 30, 2024)
781a59e  Docs consistency (bwohlberg, Jan 30, 2024)
6d6d19a  Fix broken cross-references (bwohlberg, Jan 31, 2024)
ce7c1d9  Docs consistency (bwohlberg, Jan 31, 2024)
eddb6dd  Improve docs (bwohlberg, Jan 31, 2024)
6ea5754  Rename functions (bwohlberg, Jan 31, 2024)
d80aa27  Update function docs (bwohlberg, Jan 31, 2024)
58603c4  Update URL (bwohlberg, Jan 31, 2024)
5c25f8e  Clean up log format (bwohlberg, Jan 31, 2024)
9a542d9  Update submodule (bwohlberg, Jan 31, 2024)
4785973  Fix overly simple regex (bwohlberg, Jan 31, 2024)
ee80129  Trivial edit (bwohlberg, Jan 31, 2024)
7f2ef9e  Clean up some scripts (bwohlberg, Jan 31, 2024)
740ea48  Update submodule (bwohlberg, Jan 31, 2024)
c4206cd  Overlooked change from recent astra PR (Jan 31, 2024)
6f795e1  Add note on GPU support test script (bwohlberg, Jan 31, 2024)
7484884  Merge branch 'brendt/flax' of github.com:lanl/scico into brendt/flax (Jan 31, 2024)
d5bbdba  Overlooked change from recent astra PR (Jan 31, 2024)
c3f4979  Add script for removing error output from notebooks (bwohlberg, Jan 31, 2024)
981b0dd  Merge branch 'brendt/flax' of github.com:lanl/scico into brendt/flax (bwohlberg, Jan 31, 2024)
11dfc6b  Update submodule (bwohlberg, Feb 1, 2024)
9f9506f  Fix tests (bwohlberg, Feb 1, 2024)
af1d2a7  Fix tests (bwohlberg, Feb 1, 2024)
28e2f3a  Update submodule (bwohlberg, Feb 5, 2024)

Files changed
CHANGES.rst (3 additions, 0 deletions)

@@ -7,6 +7,9 @@ Version 0.0.6 (unreleased)
----------------------------

• Significant changes to ``linop.xray.astra`` API.
+• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
+  ``scico.flax.save_variables`` and ``scico.flax.load_variables``
+  respectively.



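For code using the old names, the rename is mechanical. A minimal migration
sketch follows; the argument pattern shown is an assumption for illustration,
not taken from the scico documentation:

```python
from scico import flax as sflax

path = "./checkpoint"       # hypothetical save location
variables = {"params": {}}  # placeholder for a trained model's variable pytree

# Before this PR (now-removed names):
#   sflax.save_weights(variables, path)
#   variables = sflax.load_weights(path)
# After this PR, the same operations under names that presumably reflect that
# a full Flax variable collection is saved, not only the weights:
sflax.save_variables(variables, path)
variables = sflax.load_variables(path)
```
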
docs/source/install.rst (5 additions, 0 deletions)

@@ -130,6 +130,11 @@ a version with GPU support:
numbers.


+The script `misc/envinfo.py <https://github.com/lanl/scico/blob/main/misc/envinfo.py>`_
+in the source distribution is provided as an aid to debugging GPU support
+issues.



Additional Dependencies
-----------------------
examples/jnb.py (24 additions, 2 deletions)

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2022-2023 by SCICO Developers
+# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -44,7 +44,7 @@ def py_file_to_string(src):
else:
# Set flag indicating that an import statement has been seen once one has
# been encountered
-if re.match("^(import|from)", line):
+if re.match("^import|^from .* import", line):
import_seen = True
lines.append(line)
# Backtrack through list of lines to find last import statement
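
Commit 4785973 ("Fix overly simple regex") tightens this pattern: the old one
classified any line beginning with `import` or `from` as an import statement,
including ordinary prose that happens to start with "from". A self-contained
check of the difference, with invented sample lines:

```python
import re

old = "^(import|from)"
new = "^import|^from .* import"

samples = [
    "import os",                 # real import: both patterns match
    "from scico import flax",    # real import: both patterns match
    "from the training set we",  # prose: only the old pattern matches
]
for line in samples:
    print(f"{line!r}: old={bool(re.match(old, line))} new={bool(re.match(new, line))}")
```
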
@@ -221,3 +221,25 @@ def replace_markdown_cells(src, dst):
# the dst cell
if srccell[n]["cell_type"] == "markdown":
dstcell[n]["source"] = srccell[n]["source"]


+def remove_error_output(src):
+    """Remove output to stderr from all cells in `src`."""
+
+    if "cells" in src:
+        cells = src["cells"]
+    else:
+        cells = src["worksheets"][0]["cells"]
+
+    modified = False
+    for c in cells:
+        if "outputs" in c:
+            dellist = []
+            for n, out in enumerate(c["outputs"]):
+                if "name" in out and out["name"] == "stderr":
+                    dellist.append(n)
+                    modified = True
+            for n in dellist[::-1]:
+                del c["outputs"][n]
+
+    return modified
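
The new function mutates the notebook dict in place and reports whether any
stderr output was dropped. A minimal sketch of the cell structure it expects,
using a hand-built stand-in for a parsed `.ipynb` file:

```python
from jnb import remove_error_output

nb = {
    "cells": [
        {
            "cell_type": "code",
            "outputs": [
                {"name": "stdout", "output_type": "stream", "text": ["ok\n"]},
                {"name": "stderr", "output_type": "stream", "text": ["warn\n"]},
            ],
        }
    ]
}

print(remove_error_output(nb))         # True: a stderr output was found
print(len(nb["cells"][0]["outputs"]))  # 1: only the stdout entry remains
```
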
examples/removejnberr.py (new file, 18 additions)

@@ -0,0 +1,18 @@
#!/usr/bin/env python

# Remove output to stderr in notebooks. NB: use with caution!
# Run as
# python removejnberr.py

import glob
import os

from jnb import read_notebook, remove_error_output
from py2jn.tools import write_notebook

for src in glob.glob(os.path.join("notebooks", "*.ipynb")):
    nb = read_notebook(src)
    modflg = remove_error_output(nb)
    if modflg:
        print(f"Removing output to stderr from {src}")
        write_notebook(nb, src)
examples/scripts/ct_astra_modl_train_foam2.py (8 additions, 9 deletions)

@@ -54,7 +54,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform
+from scico.linop.xray.astra import XRayTransform2D

"""
Prepare parallel processing. Set an arbitrary processor count (only
@@ -81,9 +81,9 @@
Build CT projection operator.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = XRayTransform(
+A = XRayTransform2D(
    input_shape=(N, N),
-    detector_spacing=1,
+    det_spacing=1,
    det_count=N,
    angles=angles,
) # CT projection operator
@@ -138,7 +138,7 @@


"""
-Construct functionality for making sure that the learned
+Construct functionality for ensuring that the learned
regularization parameter is always positive.
"""
lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model
@@ -152,8 +152,8 @@
"""
Print configuration of distributed run.
"""
-print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
-print(f"{'JAX local devices: '}{jax.local_devices()}")
+print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
+print(f"JAX local devices: {jax.local_devices()}\n")


"""
@@ -212,9 +212,8 @@
cg_iter=model_conf["cg_iter_1"],
)
# First stage: initialization training loop.
-workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
-
-train_conf["workdir"] = workdir
+workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
+train_conf["workdir"] = workdir1
train_conf["post_lst"] = [lmbdapos]
# Construct training object
trainer = sflax.BasicFlaxTrainer(
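
The log-format commits in this PR also drop a redundant f-string idiom used in
these scripts, where constant text was wrapped in replacement fields. The two
forms produce identical strings, so the change is purely cosmetic:

```python
devices = ["cpu:0"]  # stand-in for the value of jax.local_devices()

# Old style: a string literal evaluated inside a replacement field.
old = f"{'JAX local devices: '}{devices}"
# New style: the literal written directly into the template.
new = f"JAX local devices: {devices}"

assert old == new
print(new)
```
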
examples/scripts/ct_astra_odp_train_foam2.py (8 additions, 10 deletions)

@@ -58,7 +58,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform
+from scico.linop.xray.astra import XRayTransform2D

"""
Prepare parallel processing. Set an arbitrary processor count (only
@@ -85,9 +85,9 @@
Build CT projection operator.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = XRayTransform(
+A = XRayTransform2D(
    input_shape=(N, N),
-    detector_spacing=1,
+    det_spacing=1,
    det_count=N,
    angles=angles,
) # CT projection operator
@@ -138,7 +138,7 @@


"""
-Construct functionality for making sure that the learned fidelity weight
+Construct functionality for ensuring that the learned fidelity weight
parameter is always positive.
"""
alphatrav = construct_traversal("alpha") # select alpha parameters in model
@@ -152,8 +152,8 @@
"""
Print configuration of distributed run.
"""
-print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
-print(f"{'JAX local devices: '}{jax.local_devices()}")
+print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
+print(f"JAX local devices: {jax.local_devices()}\n")


"""
@@ -185,10 +185,7 @@
train_ds,
test_ds,
)
-
-start_time = time()
modvar, stats_object = trainer.train()
-time_train = time() - start_time


"""
@@ -215,13 +212,14 @@
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
f"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}"
-    f"{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'time[s]:':10s}{trainer.train_time:>7.2f}"
)
print(
f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)


"""
Plot comparison.
"""
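
The timing deletions above reflect commit f3f6c5d ("Add timer to trainer
class"): scripts now read `trainer.train_time` after `train()` returns instead
of timing the call themselves. A toy sketch of the pattern, not scico's
implementation; apart from the `train` method and `train_time` attribute
names, everything here is invented:

```python
from time import time


class TimedTrainer:
    """Minimal stand-in for a trainer that records its own run time."""

    def __init__(self, steps: int = 3):
        self.steps = steps
        self.train_time = 0.0  # populated by train(), read by the scripts

    def train(self):
        start = time()
        for _ in range(self.steps):
            pass  # placeholder for one epoch of optimization
        self.train_time = time() - start
        return None, None  # stand-ins for (modvar, stats_object)


trainer = TimedTrainer()
modvar, stats_object = trainer.train()
print(f"{'time[s]:':10s}{trainer.train_time:>7.2f}")  # format used in the scripts
```
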
examples/scripts/ct_astra_unet_train_foam2.py (3 additions, 8 deletions)

@@ -104,21 +104,16 @@
"""
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "unet_ct_out")
train_conf["workdir"] = workdir
-print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
-print(f"{'JAX local devices: '}{jax.local_devices()}")
+print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
+print(f"JAX local devices: {jax.local_devices()}\n")


# Construct training object
trainer = sflax.BasicFlaxTrainer(
train_conf,
model,
train_ds,
test_ds,
)
-
-start_time = time()
modvar, stats_object = trainer.train()
-time_train = time() - start_time


"""
@@ -144,7 +139,7 @@
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
f"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
-    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
)
print(
f"{'UNet testing':15s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
examples/scripts/deconv_datagen_bsds.py (0 additions, 2 deletions)

@@ -30,7 +30,6 @@
blur_sigma = 5 # Gaussian blur kernel parameter

opBlur = PaddedCircularConvolve(output_size, channels, blur_shape, blur_sigma)
-
opBlur_vmap = vmap(opBlur) # for batch processing


@@ -47,7 +46,6 @@
stride = 100 # stride to sample multiple patches from each image
augment = True # augment data via rotations and flips
-

train_ds, test_ds = load_image_data(
train_nimg,
test_nimg,
examples/scripts/deconv_modl_train_foam1.py (5 additions, 7 deletions)

@@ -77,7 +77,6 @@

ishape = (output_size, output_size)
opBlur = CircularConvolve(h=psf, input_shape=ishape)
-
opBlur_vmap = jax.vmap(opBlur) # for batch processing in data generation


@@ -133,7 +132,7 @@


"""
-Construct functionality for making sure that the learned regularization
+Construct functionality for ensuring that the learned regularization
parameter is always positive.
"""
lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model
@@ -147,8 +146,8 @@
"""
Print configuration of distributed run.
"""
-print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
-print(f"{'JAX local devices: '}{jax.local_devices()}")
+print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
+print(f"JAX local devices: {jax.local_devices()}\n")


"""
@@ -204,9 +203,8 @@
cg_iter=model_conf["cg_iter"],
)
# First stage: initialization training loop.
-workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out")
-
-train_conf["workdir"] = workdir
+workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out")
+train_conf["workdir"] = workdir1
train_conf["post_lst"] = [lmbdapos]
# Construct training object
trainer = sflax.BasicFlaxTrainer(
examples/scripts/deconv_odp_train_foam1.py (5 additions, 9 deletions)

@@ -85,7 +85,6 @@

ishape = (output_size, output_size)
opBlur = CircularConvolve(h=psf, input_shape=ishape)
-
opBlur_vmap = jax.vmap(opBlur) # for batch processing in data generation


@@ -153,7 +152,7 @@


"""
-Construct functionality for making sure that the learned fidelity weight
+Construct functionality for ensuring that the learned fidelity weight
parameter is always positive.
"""
alphatrav = construct_traversal("alpha") # select alpha parameters in model
@@ -167,10 +166,10 @@
"""
Run training loop.
"""
-workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out")
-print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
-print(f"{'JAX local devices: '}{jax.local_devices()}")
+print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
+print(f"JAX local devices: {jax.local_devices()}\n")

+workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out")
train_conf["workdir"] = workdir
train_conf["post_lst"] = [alphapos]
# Construct training object
@@ -180,10 +179,7 @@
train_ds,
test_ds,
)
-
-start_time = time()
modvar, stats_object = trainer.train()
-time_train = time() - start_time


"""
@@ -210,7 +206,7 @@
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
f"{'ODPNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
-    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
)
print(
f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
examples/scripts/denoise_dncnn_train_bsds.py (4 additions, 9 deletions)

@@ -48,7 +48,6 @@
noise_range = False # Use fixed noise level
stride = 23 # Stride to sample multiple patches from each image
-

train_ds, test_ds = load_image_data(
train_nimg,
test_nimg,
@@ -105,19 +104,15 @@
"""
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "dncnn_out")
train_conf["workdir"] = workdir
-print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
-print(f"{'JAX local devices: '}{jax.local_devices()}")
+print(f"\nJAX local devices: {jax.local_devices()}\n")

trainer = sflax.BasicFlaxTrainer(
train_conf,
model,
train_ds,
test_ds,
)
-
-start_time = time()
modvar, stats_object = trainer.train()
-time_train = time() - start_time


"""
Expand All @@ -138,7 +133,7 @@
psnr_eval = metric.psnr(test_ds["label"][:test_patches], output)
print(
f"{'DnCNNNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
-    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
)
print(
f"{'DnCNNNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
@@ -147,8 +142,8 @@


"""
-Plot comparison. Note that patches have small sizes, thus, plots may
-correspond to unidentifiable fragments.
+Plot comparison. Note that plots may display unidentifiable image
+fragments due to the small patch size.
"""
np.random.seed(123)
indx = np.random.randint(0, high=test_patches)