Skip to content

Commit

Permalink
Migrate from jax.experimental.host_callback() to jax.debug.callback() (
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanocortinovis authored Nov 1, 2024
1 parent 8f14cbd commit 823c931
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions gpjax/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
import jax
from jax import lax
from jax.experimental import host_callback as hcb
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import (
Expand Down Expand Up @@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None:

def _do_callback(_) -> int:
"""Perform the callback."""
return hcb.id_tap(func, *args, result=_dummy_result)
jax.debug.callback(func, *args)
return _dummy_result

def _not_callback(_) -> int:
"""Do nothing."""
Expand Down Expand Up @@ -113,19 +113,19 @@ def vscan(
_progress_bar = trange(_length)
_progress_bar.set_description("Compiling...", refresh=True)

def _set_running(args: Any, transform: Any) -> None:
def _set_running(*args: Any) -> None:
"""Set the tqdm progress bar to running."""
_progress_bar.set_description("Running", refresh=False)

def _update_tqdm(args: Any, transform: Any) -> None:
def _update_tqdm(*args: Any) -> None:
"""Update the tqdm progress bar with the latest objective value."""
_value, _iter_num = args
_progress_bar.update(_iter_num)
_progress_bar.update(_iter_num.item())

if log_value and _value is not None:
_progress_bar.set_postfix({"Value": f"{_value: .2f}"})

def _close_tqdm(args: Any, transform: Any) -> None:
def _close_tqdm(*args: Any) -> None:
"""Close the tqdm progress bar."""
_progress_bar.close()

Expand All @@ -145,16 +145,16 @@ def _body_fun(carry: Carry, iter_num_and_x: Tuple[ScalarInt, X]) -> Tuple[Carry,
_is_last: bool = iter_num == _length - 1

# Update progress bar, if first of log_rate.
_callback(_is_first, _set_running, (y, log_rate))
_callback(_is_first, _set_running)

# Update progress bar, if multiple of log_rate.
_callback(_is_multiple, _update_tqdm, (y, log_rate))
_callback(_is_multiple, _update_tqdm, y, log_rate)

# Update progress bar, if remainder.
_callback(_is_remainder, _update_tqdm, (y, _remainder))
_callback(_is_remainder, _update_tqdm, y, _remainder)

# Close progress bar, if last iteration.
_callback(_is_last, _close_tqdm, (y, None))
_callback(_is_last, _close_tqdm)

return carry, y

Expand Down

0 comments on commit 823c931

Please sign in to comment.