Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suppress annoying jax device warning #385

Merged
merged 3 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/source/notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ No GPU/TPU Warning

JAX currently issues a warning when used on a platform without a
GPU. To disable this warning, set the environment variable
``JAX_PLATFORM_NAME=cpu`` before running Python.
``JAX_PLATFORM_NAME=cpu`` before running Python. This warning is
suppressed by SCICO for JAX versions after 0.3.23, making use of
the environment variable unnecessary.


Debugging
Expand Down Expand Up @@ -70,7 +72,7 @@ the generation and splitting of PRNG keys.
print(y) # [ 0.00870693 -0.04888531]

The user is responsible for passing the PRNG key to
:mod:`scico.random` functions. If no key is passed, repeated calls to
:mod:`scico.random` functions. If no key is passed, repeated calls to
:mod:`scico.random` functions will return the same random numbers:

::
Expand Down Expand Up @@ -169,7 +171,7 @@ Non-differentiable Functionals
------------------------------

:func:`scico.grad` can be applied to any function, but has undefined
behavior for non-differentiable functions. For non-differerentiable
behavior for non-differentiable functions. For non-differerentiable
functions, :func:`scico.grad` may or may not return a valid
subgradient. As an example, ``scico.grad(snp.abs)(0.) = 0``, which is
a valid subgradient. However, ``scico.grad(snp.linalg.norm)([0., 0.])
Expand All @@ -180,7 +182,7 @@ differentiable and non-differentiable function should be avoided. As
an example, :math:`f(x) = \norm{x}_2^2` can be implemented in as ``f =
lambda x: snp.linalg.norm(x)**2``. This involves first calculating the
non-squared :math:`\ell_2` norm, then squaring it. The un-squared
:math:`\ell_2` norm is not differentiable at zero. When evaluating
:math:`\ell_2` norm is not differentiable at zero. When evaluating
the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:

::
Expand Down Expand Up @@ -248,7 +250,7 @@ We recommend that input data be converted to DeviceArray via
``jax.device_put`` before calling any SCICO optimizers.

On a multi-GPU system, ``jax.device_put`` can place data on a specific
GPU. See the `JAX notes on data placement
GPU. See the `JAX notes on data placement
<https://jax.readthedocs.io/en/latest/faq.html?highlight=data%20placement#controlling-data-and-computation-placement-on-devices>`_.


Expand Down
10 changes: 9 additions & 1 deletion scico/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2021-2022 by SCICO Developers
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -10,6 +10,7 @@

__version__ = "0.0.4.dev0"

import logging
import sys

from . import _python37 # python 3.7 compatibility
Expand All @@ -23,6 +24,13 @@

from . import numpy

# Suppress jax device warning. See https://github.com/google/jax/issues/6805
# This only works for jax>0.3.23; for earlier versions, the getLogger
# argument should be "absl".
logging.getLogger("jax._src.lib.xla_bridge").addFilter(
logging.Filter("No GPU/TPU found, falling back to CPU.")
)

__all__ = [
"grad",
"value_and_grad",
Expand Down