import tree_map from tree_util
eserie committed Nov 18, 2022
1 parent 75b8bf0 commit 24ec9ea
Showing 10 changed files with 35 additions and 33 deletions.
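The change itself is mechanical: in each of the ten files below, calls to `jax.tree_map` are rewritten as `jax.tree_util.tree_map`. Both names refer to the same pytree-mapping function, but the top-level `jax.tree_map` alias has been deprecated in newer JAX releases, so the `jax.tree_util` spelling is the durable one. A minimal sketch of the pattern, with an illustrative `params` pytree that is not taken from the repository:

```python
import jax.numpy as jnp
from jax import tree_util

# Illustrative pytree of parameters (a stand-in for opt_info.params below).
params = {"w": jnp.array([2.0, -3.0, 0.5]), "b": jnp.array([0.1])}

# Apply a function to every leaf, e.g. clip weights to [-1, 1] as in the
# project_params helpers touched by this commit.
clipped = tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)
# {'b': Array([0.1]), 'w': Array([ 1., -1.,  0.5])}
```
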
10 changes: 5 additions & 5 deletions docs/notebooks/07_Online_Time_Series_Prediction.ipynb
@@ -542,7 +542,7 @@
}
],
"source": [
"jax.tree_map(lambda x: x[-1], opt_info.params)"
"jax.tree_util.tree_map(lambda x: x[-1], opt_info.params)"
]
},
{
@@ -774,7 +774,7 @@
"\n",
" def project_params(params: Any, opt_state: OptState = None):\n",
" del opt_state\n",
" return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)\n",
" return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)\n",
"\n",
" def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:\n",
" # print(m, n, p)\n",
@@ -1456,7 +1456,7 @@
" def fun_batch(*args, **kwargs):\n",
" res = VMap(fun)(*args, **kwargs)\n",
" if take_mean:\n",
" res = jax.tree_map(lambda x: x.mean(axis=0), res)\n",
" res = jax.tree_util.tree_map(lambda x: x.mean(axis=0), res)\n",
" return res\n",
"\n",
" return fun_batch\n",
@@ -1681,7 +1681,7 @@
"\n",
" BEST_STEP_SIZE[name] = loss.idxmin()\n",
" best_idx = loss.reset_index(drop=True).idxmin()\n",
" BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)\n",
" BEST_GYM[name] = jax.tree_util.tree_map(lambda x: x[:, best_idx], gym)\n",
"\n",
" ax = loss[loss < 0.15].plot(logx=True, logy=False, ax=ax, label=name)\n",
"plt.legend()"
@@ -1972,7 +1972,7 @@
"I_BEST_PARAM = jnp.argmin(x)\n",
"\n",
"\n",
"BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)\n",
"BEST_NEWTON_GYM = jax.tree_util.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)\n",
"print(\"Best newton parameters: \", STEP_SIZE, NEWTON_EPS)"
]
},

10 changes: 5 additions & 5 deletions docs/notebooks/07_Online_Time_Series_Prediction.md
@@ -223,7 +223,7 @@ pd.Series(opt_info.loss).expanding().mean().plot()
Let's look at the latest weights:

```python
- jax.tree_map(lambda x: x[-1], opt_info.params)
+ jax.tree_util.tree_map(lambda x: x[-1], opt_info.params)
```


@@ -348,7 +348,7 @@ def build_agent(time_series_model=None, opt=None, embargo=1):

def project_params(params: Any, opt_state: OptState = None):
del opt_state
- return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
+ return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)

def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:
# print(m, n, p)
@@ -563,7 +563,7 @@ def add_batch(fun, take_mean=True):
def fun_batch(*args, **kwargs):
res = VMap(fun)(*args, **kwargs)
if take_mean:
- res = jax.tree_map(lambda x: x.mean(axis=0), res)
+ res = jax.tree_util.tree_map(lambda x: x.mean(axis=0), res)
return res

return fun_batch
@@ -649,7 +649,7 @@ for name, (gym, info) in res.items():

BEST_STEP_SIZE[name] = loss.idxmin()
best_idx = loss.reset_index(drop=True).idxmin()
- BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)
+ BEST_GYM[name] = jax.tree_util.tree_map(lambda x: x[:, best_idx], gym)

ax = loss[loss < 0.15].plot(logx=True, logy=False, ax=ax, label=name)
plt.legend()
@@ -758,7 +758,7 @@ x = jnp.where(jnp.isnan(x), jnp.inf, x)
I_BEST_PARAM = jnp.argmin(x)


- BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
+ BEST_NEWTON_GYM = jax.tree_util.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)
```


10 changes: 5 additions & 5 deletions docs/notebooks/07_Online_Time_Series_Prediction.py
@@ -208,7 +208,7 @@ def learn(y, X=None):

# Let's look at the latest weights:

- jax.tree_map(lambda x: x[-1], opt_info.params)
+ jax.tree_util.tree_map(lambda x: x[-1], opt_info.params)


# ## Learn and Forecast
@@ -324,7 +324,7 @@ def model_with_loss(y, X=None):

def project_params(params: Any, opt_state: OptState = None):
del opt_state
- return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
+ return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)

def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:
# print(m, n, p)
@@ -520,7 +520,7 @@ def add_batch(fun, take_mean=True):
def fun_batch(*args, **kwargs):
res = VMap(fun)(*args, **kwargs)
if take_mean:
- res = jax.tree_map(lambda x: x.mean(axis=0), res)
+ res = jax.tree_util.tree_map(lambda x: x.mean(axis=0), res)
return res

return fun_batch
@@ -602,7 +602,7 @@ def scan_params(step_size):

BEST_STEP_SIZE[name] = loss.idxmin()
best_idx = loss.reset_index(drop=True).idxmin()
- BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)
+ BEST_GYM[name] = jax.tree_util.tree_map(lambda x: x[:, best_idx], gym)

ax = loss[loss < 0.15].plot(logx=True, logy=False, ax=ax, label=name)
plt.legend()
@@ -696,7 +696,7 @@ def scan_params(hparams):
I_BEST_PARAM = jnp.argmin(x)


- BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
+ BEST_NEWTON_GYM = jax.tree_util.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)

# +

@@ -151,7 +151,7 @@
"\n",
" def project_params(params: Any, opt_state: OptState = None):\n",
" del opt_state\n",
" return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)\n",
" return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)\n",
"\n",
" def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:\n",
" # print(m, n, p)\n",
@@ -278,7 +278,7 @@
" BEST_STEP_SIZE[name] = loss.idxmin()\n",
"\n",
" best_idx = jnp.argmax(gym.reward[LEARN_TIME_SLICE].mean(axis=0))\n",
" BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)\n",
" BEST_GYM[name] = jax.tree_util.tree_map(lambda x: x[:, best_idx], gym)\n",
"\n",
" ax = loss.plot(\n",
" logx=True, logy=False, ax=ax, label=name, ylim=(MIN_ERR, MAX_ERR)\n",
@@ -444,7 +444,7 @@
" x = jnp.where(jnp.isnan(x), jnp.inf, x)\n",
" I_BEST_PARAM = jnp.argmin(x)\n",
"\n",
" BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)\n",
" BEST_NEWTON_GYM = jax.tree_util.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)\n",
"\n",
" print(\"Best newton parameters: \", STEP_SIZE, NEWTON_EPS)\n",
"\n",

@@ -107,7 +107,7 @@ def build_agent(time_series_model=None, opt=None):

def project_params(params: Any, opt_state: OptState = None):
del opt_state
- return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
+ return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)

def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:
# print(m, n, p)
@@ -203,7 +203,7 @@ def scan_hparams_first_order():
BEST_STEP_SIZE[name] = loss.idxmin()

best_idx = jnp.argmax(gym.reward[LEARN_TIME_SLICE].mean(axis=0))
- BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)
+ BEST_GYM[name] = jax.tree_util.tree_map(lambda x: x[:, best_idx], gym)

ax = loss.plot(
logx=True, logy=False, ax=ax, label=name, ylim=(MIN_ERR, MAX_ERR)
@@ -324,7 +324,7 @@ def scan_hparams_newton():
x = jnp.where(jnp.isnan(x), jnp.inf, x)
I_BEST_PARAM = jnp.argmin(x)

- BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
+ BEST_NEWTON_GYM = jax.tree_util.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)

print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)


@@ -102,7 +102,7 @@ def model_with_loss(y, X=None):

def project_params(params: Any, opt_state: OptState = None):
del opt_state
- return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
+ return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)

def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:
# print(m, n, p)
@@ -195,7 +195,7 @@ def scan_params(step_size):
BEST_STEP_SIZE[name] = loss.idxmin()

best_idx = jnp.argmax(gym.reward[LEARN_TIME_SLICE].mean(axis=0))
- BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)
+ BEST_GYM[name] = jax.tree_util.tree_map(lambda x: x[:, best_idx], gym)

ax = loss.plot(
logx=True, logy=False, ax=ax, label=name, ylim=(MIN_ERR, MAX_ERR)
@@ -310,7 +310,7 @@ def scan_params(hparams):
x = jnp.where(jnp.isnan(x), jnp.inf, x)
I_BEST_PARAM = jnp.argmin(x)

- BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
+ BEST_NEWTON_GYM = jax.tree_util.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)

print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)


2 changes: 1 addition & 1 deletion wax/modules/func_optimizer.py
@@ -72,7 +72,7 @@ def __call__(self, *args, **kwargs):
if self.grads_fill_nan_inf:
grads = FillNanInf()(grads)

- # trainable_params = jax.tree_map(self.opt, trainable_params, grads)
+ # trainable_params = jax.tree_util.tree_map(self.opt, trainable_params, grads)
trainable_params = self.opt(trainable_params, grads)

hk.set_state("trainable_params", trainable_params)

2 changes: 1 addition & 1 deletion wax/modules/snarimax_test.py
@@ -119,7 +119,7 @@ def model_with_loss(y, X=None):

def project_params(params: Any, opt_state: OptState = None):
del opt_state
- return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
+ return jax.tree_util.tree_map(lambda w: jnp.clip(w, -1, 1), params)

def params_predicate(m: str, n: str, p: jnp.ndarray) -> bool:
# print(m, n, p)

2 changes: 1 addition & 1 deletion wax/modules/vmap.py
@@ -51,7 +51,7 @@ def fun_batch(*args, **kwargs):
res = vmap_lift_with_state(fun, split_rng=False)(*args, **kwargs)

if take_mean:
- res = jax.tree_map(lambda x: x.mean(axis=0), res)
+ res = jax.tree_util.tree_map(lambda x: x.mean(axis=0), res)
return res

return fun_batch

14 changes: 8 additions & 6 deletions wax/optim/newton.py
@@ -75,7 +75,7 @@ def scale_by_newton(eps: float = 1e-7) -> base.GradientTransformation:
"""

def init_fn(params):
- hessian_inv = jax.tree_map(
+ hessian_inv = jax.tree_util.tree_map(
lambda t: jnp.eye(len(t.flatten()), dtype=t.dtype) / eps, params
)
return ScaleByNewtonState(hessian_inv=hessian_inv)
@@ -89,13 +89,15 @@ class Tuple(tuple):

...

- shapes = jax.tree_map(lambda x: Tuple(x.shape), updates)
- updates = jax.tree_map(lambda x: x.flatten(), updates)
- hessian_inv = jax.tree_map(
+ shapes = jax.tree_util.tree_map(lambda x: Tuple(x.shape), updates)
+ updates = jax.tree_util.tree_map(lambda x: x.flatten(), updates)
+ hessian_inv = jax.tree_util.tree_map(
lambda u, hinv: sherman_morrison(hinv, u, u), updates, state.hessian_inv
)
- updates = jax.tree_map(lambda hinv, g: hinv @ g, hessian_inv, updates)
- updates = jax.tree_map(lambda u, shape: u.reshape(shape), updates, shapes)
+ updates = jax.tree_util.tree_map(lambda hinv, g: hinv @ g, hessian_inv, updates)
+ updates = jax.tree_util.tree_map(
+     lambda u, shape: u.reshape(shape), updates, shapes
+ )

return updates, ScaleByNewtonState(hessian_inv=hessian_inv)
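
For context on the `newton.py` hunk above: `jax.tree_util.tree_map` also accepts several pytrees with the same structure, passing one leaf from each to the mapped function, which is what the `hinv @ g` update relies on. A small self-contained sketch of that multi-tree pattern with made-up values (it does not call the module's `sherman_morrison` helper):

```python
import jax.numpy as jnp
from jax import tree_util

# Two pytrees with identical structure: per-leaf gradients and toy
# inverse-Hessian matrices (illustrative values, not from scale_by_newton).
grads = {"w": jnp.array([0.1, -0.2]), "b": jnp.array([0.05])}
hessian_inv = {"w": jnp.eye(2), "b": jnp.eye(1)}

# tree_map walks both trees in lockstep: f(hinv_leaf, grad_leaf) per leaf.
updates = tree_util.tree_map(lambda hinv, g: hinv @ g, hessian_inv, grads)
```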
