diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh index 5c1cccf81..206632b7f 100755 --- a/misc/conda/make_conda_env.sh +++ b/misc/conda/make_conda_env.sh @@ -46,7 +46,7 @@ EOF ) # Requirements that cannot be installed via conda (i.e. have to use pip) NOCONDA=$(cat <<-EOF -bm3d bm4d faculty-sphinx-theme py2jn colour_demosaicing ray[tune] +flax bm3d bm4d faculty-sphinx-theme py2jn colour_demosaicing ray[tune] EOF ) diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index b469f64d5..9beeab0e3 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -29,6 +29,8 @@ BIG_INPUT_OFFSET_RANGE = (0, 3) SMALL_INPUT_OFFSET_RANGE = (0, 0.1) +device = jax.devices()[0] + def make_im(Nx, Ny, is_3d=True): x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny)) @@ -147,6 +149,7 @@ def test_adjoint( adjoint_test(A) +@pytest.mark.skipif(device.platform != "cpu", reason="test hangs on gpu") @pytest.mark.parametrize( "Nx, Ny, num_angles, num_channels, dist_source_detector, magnification", (SMALL_INPUT,) ) @@ -191,6 +194,7 @@ def test_prox( prox_test(v, f, f.prox, alpha=0.25, rtol=5e-4) +@pytest.mark.skipif(device.platform != "cpu", reason="test hangs on gpu") @pytest.mark.parametrize( "Nx, Ny, num_angles, num_channels, dist_source_detector, magnification", (SMALL_INPUT,) )