Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide JAX Ops from Optional tensorflow-probability dependency #403

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ jobs:
# PyTensor next, pip installs a lower version of numpy via the PyPI.
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
Expand Down
36 changes: 35 additions & 1 deletion pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import typing
from typing import Callable, Optional

import jax
import jax.numpy as jnp
Expand All @@ -18,7 +20,21 @@
Second,
Sub,
)
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi


def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
try:
import tensorflow_probability.substrates.jax.math as tfp_jax_math
except ModuleNotFoundError:
raise NotImplementedError(
f"No JAX implementation for Op {op.name}. "
"Implementation is available if TensorFlow Probability is installed"
)

if jax_op_name is None:
jax_op_name = op.name
return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))


def check_if_inputs_scalars(node):
Expand Down Expand Up @@ -211,6 +227,24 @@ def erfinv(x):
return erfinv


@jax_funcify.register(Erfcx)
@jax_funcify.register(Erfcinv)
def jax_funcify_from_tfp(op, **kwargs):
tfp_jax_op = try_import_tfp_jax_op(op)

return tfp_jax_op


@jax_funcify.register(Iv)
def jax_funcify_Iv(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")

def iv(v, x):
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))

return iv


@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
Expand Down
29 changes: 29 additions & 0 deletions tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.scalar.basic import Composite
from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
cosh,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
iv,
log,
log1mexp,
psi,
Expand All @@ -28,6 +32,14 @@
from pytensor.link.jax.dispatch import jax_funcify


try:
pass

TFP_INSTALLED = True
except ModuleNotFoundError:
TFP_INSTALLED = False


def test_second():
a0 = scalar("a0")
b = scalar("b")
Expand Down Expand Up @@ -134,6 +146,23 @@ def test_erfinv():
compare_jax_and_py(fg, [0.95])


@pytest.mark.parametrize(
"op, test_values",
[
(erfcx, (0.7,)),
(erfcinv, (0.7,)),
(iv, (0.3, 0.7)),
],
)
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
def test_tfp_ops(op, test_values):
inputs = [as_tensor(test_value).type() for test_value in test_values]
output = op(*inputs)

fg = FunctionGraph(inputs, [output])
compare_jax_and_py(fg, test_values)


def test_psi():
x = scalar("x")
out = psi(x)
Expand Down
Loading