From 7dd463e4d7cfef8843aca3dbf1464e635c757d17 Mon Sep 17 00:00:00 2001 From: Alistair Muldal Date: Fri, 6 Oct 2023 04:24:48 -0700 Subject: [PATCH] Replace `jax.linear_util` -> `jax.extend.linear_util` to suppress `DeprecationWarning`s Also suppressed a pytype error relating to conditional `tree` import. PiperOrigin-RevId: 571299183 --- haiku/_src/BUILD | 4 ++-- haiku/_src/dot.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/haiku/_src/BUILD b/haiku/_src/BUILD index 918bf3687..25d19fa59 100644 --- a/haiku/_src/BUILD +++ b/haiku/_src/BUILD @@ -342,12 +342,12 @@ hk_py_library( name = "dot", srcs = ["dot.py"], deps = [ - ":config", + ":config", # build_cleaner: keep ":data_structures", ":module", ":utils", # pip: jax - # pip: tree + # pip: jax:extend ], ) diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index bbdbae274..bd8ab1b7c 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -26,12 +26,13 @@ import jax import jax.core from jax.experimental import pjit +from jax.extend import linear_util # Import tree if available, but only throw error at runtime. # Permits us to drop dm-tree from deps. try: - import tree # pylint: disable=g-import-not-at-top + import tree # pylint: disable=g-import-not-at-top # pytype: disable=import-error except ImportError: tree = None @@ -134,7 +135,7 @@ def to_graph(fun): @functools.wraps(fun) def wrapped_fun(*args): """See `fun`.""" - f = jax.linear_util.wrap_init(fun) + f = linear_util.wrap_init(fun) args_flat, in_tree = jax.tree_util.tree_flatten((args, {})) flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree) graph = Graph.create(title=name_or_str(fun)) @@ -160,7 +161,7 @@ def method_hook(mod: module.Module, method_name: str): return wrapped_fun -@jax.linear_util.transformation +@linear_util.transformation def _interpret_subtrace(main, *in_vals): trace = DotTrace(main, jax.core.cur_sublevel()) in_tracers = [DotTracer(trace, val) for val in in_vals] @@ -202,7 +203,7 @@ def process_primitive(self, primitive, tracers, params): if primitive is pjit.pjit_p: f = jax.core.jaxpr_as_fun(params['jaxpr']) f.__name__ = params['name'] - fun = jax.linear_util.wrap_init(f) + fun = linear_util.wrap_init(f) return self.process_call(primitive, fun, tracers, params) inputs = [t.val for t in tracers]