From a4eee475cc7457569aeab5581267675fc7a2d3cd Mon Sep 17 00:00:00 2001 From: The oryx Authors Date: Sun, 3 Mar 2024 15:56:13 -0800 Subject: [PATCH] Add 'default_clobber' mode that lets different loop iterations sow different values to be reaped. PiperOrigin-RevId: 612253581 --- oryx/core/interpreters/harvest.py | 315 +++++++++++++++++-------- oryx/core/interpreters/harvest_test.py | 85 ++++++- 2 files changed, 301 insertions(+), 99 deletions(-) diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 4d4c95b..e3aa2b1 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -51,11 +51,13 @@ default set to `'strict'`. The `mode` of a `sow` describes how it behaves when the same name appears multiple times. In "strict" mode, `sow` will error if the same `(tag, name)` appears more than once. Another option is `'append'`, in -which all sows of the same name will be appended into a growing array. Finally, -there is `'clobber'`, where only the final sown value for a given `(tag, name)` -will be returned. The final optional argument for `sow` is `key`, which will -automatically be tied-in to the output of `sow` to introduce a fake -data-dependence. By default, it is `None`. +which all sows of the same name will be appended into a growing array. There is +also `'clobber'`, where only the final sown value for a given `(tag, name)` will +be returned. Finally, there is `'default_clobber'`, which is like clobber but +allows sowing to be conditional, falling back on zeros if no sow took place. +The final optional argument for `sow` is `key`, which will automatically be +tied-in to the output of `sow` to introduce a fake data-dependence. By default, +it is `None`. ## `harvest` @@ -137,7 +139,8 @@ def f(x): import collections import dataclasses import functools -from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Optional, Tuple, Union +import typing +from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Literal, Optional, Tuple, Union from jax import api_util from jax import lax @@ -229,18 +232,29 @@ def sow(value, *, tag: Hashable, name: str, mode: str = 'strict', key=None): value: A JAX value to be tagged and named. tag: a string representing the tag of the sown value. name: a string representing the name to sow the value with. - mode: The mode by which to sow the value. There are three options: 1. + mode: The mode by which to sow the value. There are four options: 1. `'strict'` - if another value is sown with the same name and tag in the same context, harvest will throw an error. 2. `'clobber'` - if another is - value is sown with the same name and tag, it will replace this value 3. + value is sown with the same name and tag, it will replace this value. 3. `'append'` - sown values of the same name and tag are appended to a growing list. Append mode assumes some ordering on the values being sown - defined by data-dependence. + defined by data-dependence. 4. `'default_clobber'` - like `'clobber'`, but + sowing may be conditional, falling back on zeros if no sow took place. key: an optional JAX value that will be tied into the sown value. Returns: The original `value` that was passed in. """ + if mode == 'default_clobber': + return _sow(value, tag=tag, name=name, mode=mode, key=key, cond=True)[0] + else: + return _sow(value, tag=tag, name=name, mode=mode, key=key) + + +def _sow(value, *, tag, name, mode, key=None, cond=None): + assert (cond is not None) == (mode == 'default_clobber') + if cond is not None: + value = value, cond value = tree_util.tree_map(jax_core.raise_as_much_as_possible, value) if key is not None: value = prim.tie_in(key, value) @@ -441,7 +455,7 @@ class HarvestContext: def process_sow(self, *values, name, tag, mode, tree): """Handles a `sow` primitive in a `HarvestTrace`.""" - if mode not in {'strict', 'append', 'clobber'}: + if mode not in {'strict', 'append', 'clobber', 'default_clobber'}: raise ValueError(f'Invalid mode: {mode}') if tag != self.settings.tag: if self.settings.exclusive: @@ -493,6 +507,7 @@ def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): @dataclasses.dataclass class Reap: value: Any + cond: Any metadata: Dict[str, Any] @@ -513,8 +528,13 @@ def handle_sow(self, *values, name, tag, tree, mode): avals = tree_util.tree_unflatten( tree, [jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values]) - self.reaps[name] = Reap( - tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals)) + vals = tree_util.tree_unflatten(tree, values) + cond = None + if mode == 'default_clobber': + avals, _ = avals + vals, cond = vals + metadata = dict(mode=mode, aval=avals) + self.reaps[name] = Reap(vals, cond, metadata) return values def reap_higher_order_primitive(self, trace, call_primitive, f, tracers, @@ -537,13 +557,14 @@ def new_out_axes_thunk(): params = dict(params, out_axes_thunk=new_out_axes_thunk) out_flat = call_primitive.bind(f, *vals, name=name, **params) out_tree, metadata = aux() - out_vals, reaps = tree_util.tree_unflatten(out_tree, out_flat) + out_vals, reaps, conds = tree_util.tree_unflatten(out_tree, out_flat) out_tracers = jax_util.safe_map(trace.pure, out_vals) reap_tracers = tree_util.tree_map(trace.pure, reaps) - return out_tracers, reap_tracers, metadata + cond_tracers = tree_util.tree_map(trace.pure, conds) + return out_tracers, reap_tracers, cond_tracers, metadata def process_nest(self, trace, f, *tracers, scope, name, **params): - out_tracers, reap_tracers, _ = self.reap_higher_order_primitive( + out_tracers, reap_tracers, _, _ = self.reap_higher_order_primitive( trace, nest_p, f, tracers, dict(params, name=name, scope=scope), False) tag = self.settings.tag if reap_tracers: @@ -555,10 +576,15 @@ def process_nest(self, trace, f, *tracers, scope, name, **params): def process_higher_order_primitive(self, trace, call_primitive, f, tracers, params, is_map): - out_tracers, reap_tracers, metadata = self.reap_higher_order_primitive( - trace, call_primitive, f, tracers, params, is_map) + out_tracers, reap_tracers, cond_tracers, metadata = ( + self.reap_higher_order_primitive( + trace, call_primitive, f, tracers, params, is_map + ) + ) tag = self.settings.tag for k, v in reap_tracers.items(): + if metadata[k]['mode'] == 'default_clobber': + v = (v, cond_tracers[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, @@ -583,10 +609,14 @@ def _jvp_subtrace(main, *args): out_flat = primitive.bind(fun, jvp, *vals_in, symbolic_zeros=symbolic_zeros) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: - out, reaps = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers = tree_util.tree_map(trace.pure, (out, reaps)) + out, reaps, conds = tree_util.tree_unflatten(out_tree, out_flat) + out_tracers, reap_tracers, cond_tracers = tree_util.tree_map( + trace.pure, (out, reaps, conds) + ) tag = context.settings.tag for k, v in reap_tracers.items(): + if metadata[k]['mode'] == 'default_clobber': + v = (v, cond_tracers[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, @@ -621,10 +651,14 @@ def _fwd_subtrace(main, *args): symbolic_zeros=symbolic_zeros) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: - out, reaps = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers = tree_util.tree_map(trace.pure, (out, reaps)) + out, reaps, conds = tree_util.tree_unflatten(out_tree, out_flat) + out_tracers, reap_tracers, cond_tracers = tree_util.tree_map( + trace.pure, (out, reaps, conds) + ) tag = context.settings.tag for k, v in reap_tracers.items(): + if metadata[k]['mode'] == 'default_clobber': + v = (v, cond_tracers[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, @@ -648,7 +682,7 @@ def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): @lu.transformation def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, return_metadata: bool, args: Iterable[Any]): - """A function transformation that returns reap values.""" + """A function transformation that returns reap values and conditions.""" trace = HarvestTrace(main, jax_core.cur_sublevel()) in_tracers = jax_util.safe_map(trace.pure, args) context = ReapContext(settings, {}) @@ -657,14 +691,17 @@ def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, out_tracers = jax_util.safe_map(trace.full_raise, ans) reap_tracers = tree_util.tree_map( lambda x: tree_util.tree_map(trace.full_raise, x.value), context.reaps) + cond_tracers = tree_util.tree_map( + lambda x: trace.full_raise(x.cond), context.reaps) reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps) del main - out_values, reap_values = tree_util.tree_map(lambda x: x.val, - (out_tracers, reap_tracers)) + out_values, reap_values, cond_values = tree_util.tree_map( + lambda x: x.val, (out_tracers, reap_tracers, cond_tracers) + ) if return_metadata: - out = (out_values, reap_values, reap_metadata) + out = (out_values, reap_values, cond_values, reap_metadata) else: - out = (out_values, reap_values) + out = (out_values, reap_values, cond_values) yield out @@ -678,16 +715,16 @@ def reap_eval( @lu.transformation_with_aux def reap_wrapper(trace: HarvestTrace, *args): del trace - out, reaps, metadata = yield (args,), {} - out_flat, out_tree = tree_util.tree_flatten((out, reaps)) + out, reaps, conds, metadata = yield (args,), {} + out_flat, out_tree = tree_util.tree_flatten((out, reaps, conds)) yield out_flat, (out_tree, metadata) @lu.transformation def reap_wrapper_drop_aux(trace: HarvestTrace, *args): del trace - out, reaps, _ = yield (args,), {} - out_flat, _ = tree_util.tree_flatten((out, reaps)) + out, reaps, conds, _ = yield (args,), {} + out_flat, _ = tree_util.tree_flatten((out, reaps, conds)) yield out_flat @@ -713,6 +750,18 @@ def call_and_reap(f, A new function that executes the original and returns its sown values as an additional return value. """ + wrapped = _call_and_reap( + f, + tag=tag, + allowlist=allowlist, + blocklist=blocklist, + exclusive=exclusive, + ) + return lambda *args, **kwargs: wrapped(*args, **kwargs)[:-1] + + +def _call_and_reap(f, *, tag, allowlist, blocklist, exclusive): + """Like call_and_reap but including conditions.""" blocklist = frozenset(blocklist) if allowlist is not None: allowlist = frozenset(allowlist) @@ -724,9 +773,9 @@ def wrapped(*args, **kwargs): flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) with jax_core.new_main(HarvestTrace) as main: flat_fun = reap_function(flat_fun, main, settings, False) - out_flat, reaps = flat_fun.call_wrapped(flat_args) + out_flat, reaps, conds = flat_fun.call_wrapped(flat_args) del main - return tree_util.tree_unflatten(out_tree(), out_flat), reaps + return tree_util.tree_unflatten(out_tree(), out_flat), reaps, conds return wrapped @@ -766,8 +815,8 @@ def wrapped(*args, **kwargs): @lu.transformation_with_aux def _reap_metadata_wrapper(*args): - out, reaps, metadata = yield (args,), {} - yield (out, reaps), metadata + out, reaps, conds, metadata = yield (args,), {} + yield (out, reaps, conds), metadata def _get_harvest_metadata(closed_jaxpr, settings, *args): @@ -789,6 +838,17 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args): return metadata +def _update_clobber_carry(carry_reaps, carry_conds, name, val, conds, mode): + if mode == 'default_clobber': + carry_reaps[name], carry_conds[name] = lax.cond( + conds[name], + lambda val=val: (val, True), + lambda name=name: (carry_reaps[name], carry_conds[name]), + ) + else: + carry_reaps[name] = val + + def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): """Reaps the body of a scan to pull out `clobber` and `append` sows.""" @@ -809,6 +869,7 @@ def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, reap_modes = collections.defaultdict(set) reap_carry_avals = {} + cond_carry_avals = {} for name, meta in metadata.items(): mode = meta['mode'] aval = meta['aval'] @@ -817,38 +878,46 @@ def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, reap_modes[mode].add(name) if mode == 'clobber': reap_carry_avals[name] = aval + cond_carry_avals[name] = None + if mode == 'default_clobber': + reap_carry_avals[name] = aval + cond_carry_avals[name] = jax_core.raise_to_shaped(jax_core.get_aval(True)) + body_fun = jax_core.jaxpr_as_fun(jaxpr) - reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals) + reap_carry_flat_avals = tree_util.tree_leaves( + (reap_carry_avals, cond_carry_avals) + ) reap_carry_in_tree = tree_util.tree_structure( - ((carry_avals, reap_carry_avals), xs_avals)) + ((carry_avals, reap_carry_avals, cond_carry_avals), xs_avals)) def new_body(carry, x): - carry, _ = carry + carry, carry_reaps, carry_conds = carry all_values = const_vals + tree_util.tree_leaves((carry, x)) - out, reaps = call_and_reap( + out, reaps, conds = _call_and_reap( body_fun, tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive)(*all_values) carry_out, y = jax_util.split_list(out, [num_carry]) - carry_reaps = { - name: val - for name, val in reaps.items() - if name in reap_modes['clobber'] - } + clobber_reap_modes = reap_modes['clobber'] | reap_modes['default_clobber'] + for name, val in reaps.items(): + if name in clobber_reap_modes: + _update_clobber_carry(carry_reaps, carry_conds, name, val, conds, mode) xs_reaps = { name: val for name, val in reaps.items() if name in reap_modes['append'] } - return (carry_out, carry_reaps), (y, xs_reaps) + return (carry_out, carry_reaps, carry_conds), (y, xs_reaps) new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, reap_carry_in_tree, tuple(carry_avals + reap_carry_flat_avals + x_avals)) dummy_reap_carry_vals = tree_util.tree_map( - lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals) + lambda x: jnp.zeros(x.shape, x.dtype), + reap_carry_flat_avals, + ) out = lax.scan_p.bind( *(consts + carry_vals + dummy_reap_carry_vals + xs_vals), reverse=reverse, @@ -860,11 +929,16 @@ def new_body(carry, x): linear[len(consts):]), unroll=unroll, _split_transpose=_split_transpose) - (carry_out, - carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(out_tree, out) - (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map( - trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps))) - for k, v in {**carry_reaps, **ys_reaps}.items(): + (carry_out, carry_reaps, carry_conds), (ys, ys_reaps) = ( + tree_util.tree_unflatten(out_tree, out) + ) + (carry_out, carry_reaps, carry_conds), (ys, ys_reaps) = tree_util.tree_map( + trace.pure, ((carry_out, carry_reaps, carry_conds), (ys, ys_reaps)) + ) + for k, v in carry_reaps.items(): + mode = metadata[k]['mode'] + _sow(v, tag=settings.tag, mode=mode, name=k, cond=carry_conds[k]) + for k, v in ys_reaps.items(): sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k) return carry_out + ys @@ -874,7 +948,7 @@ def new_body(carry, x): def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): - """Reaps the body of a while loop to get the reaps of the final iteration.""" + """Reaps the body of a while loop to get the reaps of `clobber` sows.""" cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list( tracers, [cond_nconsts, body_nconsts]) _, init_avals = tree_util.tree_map(lambda x: x.aval, @@ -885,12 +959,18 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, settings = context.settings body_metadata = _get_harvest_metadata(body_jaxpr, settings, *(body_const_tracers + init_tracers)) + reap_avals = {} + cond_avals = collections.defaultdict(lambda: None) for k, meta in body_metadata.items(): mode = meta['mode'] - if mode != 'clobber': + if mode not in ['clobber', 'default_clobber']: raise ValueError( - f'Must use clobber mode for \'{k}\' inside of a `while_loop`.') - reap_avals = {k: v['aval'] for k, v in body_metadata.items()} + f"Must use clobber or default_clobber mode for '{k}' inside of a" + ' `while_loop`.' + ) + reap_avals[k] = meta['aval'] + if mode == 'default_clobber': + cond_avals[k] = jax_core.raise_to_shaped(jax_core.get_aval(True)) cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr) body_fun = jax_core.jaxpr_as_fun(body_jaxpr) @@ -900,21 +980,27 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, blocklist=settings.blocklist, exclusive=settings.exclusive) - def new_cond(carry, _): + def new_cond(carry, *_): return cond_fun(*(cond_const_vals + carry)) - def new_body(carry, _): - carry, reaps = call_and_reap(body_fun, - **reap_settings)(*(body_const_vals + carry)) - return (carry, reaps) - - new_in_avals, new_in_tree = tree_util.tree_flatten((init_avals, reap_avals)) + def new_body(carry, carry_reaps, carry_conds): + carry, reaps, conds = _call_and_reap(body_fun, **reap_settings)( + *(body_const_vals + carry) + ) + for name, val in reaps.items(): + mode = body_metadata[name]['mode'] + _update_clobber_carry(carry_reaps, carry_conds, name, val, conds, mode) + return (carry, carry_reaps, carry_conds) + + new_in_avals, new_in_tree = tree_util.tree_flatten( + (init_avals, reap_avals, cond_avals) + ) new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_cond, new_in_tree, tuple(new_in_avals)) new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, new_in_tree, tuple(new_in_avals)) dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), - reap_avals) + (reap_avals, cond_avals)) new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals)) out = lax.while_p.bind( *(cond_consts + body_consts + new_in_vals), @@ -923,29 +1009,53 @@ def new_body(carry, _): cond_jaxpr=new_cond_jaxpr, body_jaxpr=new_body_jaxpr) out = jax_util.safe_map(trace.pure, out) - out, reaps = tree_util.tree_unflatten(out_tree, out) + out, reaps, conds = tree_util.tree_unflatten(out_tree, out) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode']) + mode = body_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, cond=conds[k]) return out reap_custom_rules[lcf.while_p] = _reap_while_rule -def _check_branch_metadata(branch_metadatas): - """Checks that a set of harvest metadata are consistent with each other.""" - first_branch_meta = branch_metadatas[0] - for branch_metadata in branch_metadatas[1:]: - if len(branch_metadata) != len(first_branch_meta): - raise ValueError('Mismatching number of `sow`s between branches.') +def _combine_branch_metadata(branch_metadatas): + """Combines metadatas from branches, checking consistency.""" + metas = {} + + for branch_metadata in branch_metadatas: for name, meta in branch_metadata.items(): - if name not in first_branch_meta: - raise ValueError(f'Missing sow in branch: \'{name}\'.') - first_meta_aval = first_branch_meta[name]['aval'] - if meta['aval'].shape != first_meta_aval.shape: + if (ret_meta := metas.get(name)) is None: + metas[name] = meta + continue + ret_aval = ret_meta['aval'] + if meta['aval'].shape != ret_aval.shape: raise ValueError(f'Mismatched shape between branches: \'{name}\'.') - if meta['aval'].dtype != first_meta_aval.dtype: + if meta['aval'].dtype != ret_aval.dtype: raise ValueError(f'Mismatched dtype between branches: \'{name}\'.') + if meta['mode'] != ret_meta['mode']: + raise ValueError(f'Mismatched mode between branches: \'{name}\'.') + return metas + + +def add_missing_branch_sows(branch_funs, branch_metadatas, metadata, tag): + """Wrap branch funs that may not have all sows in metadata, so that they do.""" + + def add_missing_sows(names, f, *args, **kwargs): + for name in names: + meta = metadata[name] + aval = meta['aval'] + default = jnp.zeros(aval.shape, aval.dtype) + if meta['mode'] != 'default_clobber': + raise ValueError(f"Missing sow in branch: '{name}'") + _sow(default, tag=tag, name=name, mode='default_clobber', cond=False) + return f(*args, **kwargs) + for branch_metadata, fun in zip(branch_metadatas, branch_funs): + if len(branch_metadata) == len(metadata): + yield fun + else: + names = set(metadata) - set(branch_metadata) + yield functools.partial(add_missing_sows, names, fun) def _reap_cond_rule(trace, *tracers, branches, linear): @@ -965,10 +1075,15 @@ def _reap_cond_rule(trace, *tracers, branches, linear): branch_metadatas = tuple( _get_harvest_metadata(branch, settings, *ops_tracers) for branch in branches) - _check_branch_metadata(branch_metadatas) + metadatas = _combine_branch_metadata(branch_metadatas) branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) + branch_funs = tuple( + add_missing_branch_sows( + branch_funs, branch_metadatas, metadatas, settings.tag + ) + ) reaped_branches = tuple( - call_and_reap(f, **reap_settings) for f in branch_funs) + _call_and_reap(f, **reap_settings) for f in branch_funs) in_tree = tree_util.tree_structure(ops_avals) new_branch_jaxprs, consts, out_trees = ( lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access @@ -979,9 +1094,9 @@ def _reap_cond_rule(trace, *tracers, branches, linear): branches=tuple(new_branch_jaxprs), linear=(False,) * len(tuple(consts) + linear)) out = jax_util.safe_map(trace.pure, out) - out, reaps = tree_util.tree_unflatten(out_trees[0], out) - for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=branch_metadatas[0][k]['mode']) + out, reaps, conds = tree_util.tree_unflatten(out_trees[0], out) + for k, meta in metadatas.items(): + _sow(reaps[k], name=k, tag=settings.tag, mode=meta['mode'], cond=conds[k]) return out @@ -1002,7 +1117,7 @@ def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr) - reaped_remat_fun = call_and_reap(remat_fun, **reap_settings) + reaped_remat_fun = _call_and_reap(remat_fun, **reap_settings) reap_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access reaped_remat_fun, tree_util.tree_structure(invals), tuple(t.aval for t in tracers)) @@ -1014,9 +1129,10 @@ def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, prevent_cse=prevent_cse, differentiated=differentiated) outvals = jax_util.safe_map(trace.pure, outvals) - out, reaps = tree_util.tree_unflatten(out_tree, outvals) + out, reaps, conds = tree_util.tree_unflatten(out_tree, outvals) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=reap_metadata[k]['mode']) + mode = reap_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, cond=conds[k]) return out @@ -1073,7 +1189,7 @@ def _reap_pjit_rule(trace, *tracers, **params): closed_jaxpr = params['jaxpr'] reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr) - reaped_pjit_fun = lu.wrap_init(call_and_reap(pjit_fun, **reap_settings)) + reaped_pjit_fun = lu.wrap_init(_call_and_reap(pjit_fun, **reap_settings)) in_tree = tree_util.tree_structure(invals) flat_fun, out_tree = api_util.flatten_fun_nokwargs(reaped_pjit_fun, in_tree) @@ -1094,9 +1210,10 @@ def _reap_pjit_rule(trace, *tracers, **params): outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) outvals = jax_util.safe_map(trace.pure, outvals) - out, reaps = tree_util.tree_unflatten(out_tree(), outvals) + out, reaps, conds = tree_util.tree_unflatten(out_tree(), outvals) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=reap_metadata[k]['mode']) + mode = reap_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, cond=conds[k]) return out @@ -1124,7 +1241,10 @@ def handle_sow(self, *values, name, tag, tree, mode): raise ValueError(f'Variable has already been planted: {name}') if name in self.plants: self._already_planted.add(name) - return tree_util.tree_leaves(self.plants[name]) + if mode == 'default_clobber': + return tree_util.tree_leaves((self.plants[name], True)) + else: + return tree_util.tree_leaves(self.plants[name]) return sow_p.bind(*values, name=name, tag=tag, mode=mode, tree=tree) def process_nest(self, trace, f, *tracers, scope, name, **params): @@ -1305,10 +1425,10 @@ def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, if mode == 'append' and name in plants: plant_xs_avals[name] = aval body_fun = jax_core.jaxpr_as_fun(jaxpr) - clobber_plants = { + all_clobber_plants = { name: value for name, value in plants.items() - if name in plant_modes['clobber'] + if name in plant_modes['clobber'] | plant_modes['default_clobber'] } append_plants = { name: value @@ -1323,7 +1443,7 @@ def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, def new_body(carry, x): x, plants = x - all_plants = {**plants, **clobber_plants} + all_plants = {**plants, **all_clobber_plants} all_values = const_vals + tree_util.tree_leaves((carry, x)) out = plant( body_fun, @@ -1368,9 +1488,11 @@ def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, *(body_const_tracers + init_tracers)) for k, meta in body_metadata.items(): mode = meta['mode'] - if mode != 'clobber': + if mode not in ['clobber', 'default_clobber']: raise ValueError( - f'Must use clobber mode for \'{k}\' inside of a `while_loop`.') + f"Must use clobber or default_clobber mode for '{k}' inside of a" + ' `while_loop`.' + ) body_fun = jax_core.jaxpr_as_fun(body_jaxpr) plant_settings = dict( @@ -1416,9 +1538,14 @@ def _plant_cond_rule(trace, *tracers, branches, linear): branch_metadatas = tuple( _get_harvest_metadata(branch, settings, *ops_tracers) for branch in branches) - _check_branch_metadata(branch_metadatas) - plants = context.plants + metadatas = _combine_branch_metadata(branch_metadatas) branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) + branch_funs = tuple( + add_missing_branch_sows( + branch_funs, branch_metadatas, metadatas, settings.tag + ) + ) + plants = context.plants planted_branches = tuple( functools.partial(plant(f, **plant_settings), plants) for f in branch_funs) diff --git a/oryx/core/interpreters/harvest_test.py b/oryx/core/interpreters/harvest_test.py index e414601..5cda0f0 100644 --- a/oryx/core/interpreters/harvest_test.py +++ b/oryx/core/interpreters/harvest_test.py @@ -728,6 +728,33 @@ def f(init): self.assertListEqual(['x'], list(variables.keys())) np.testing.assert_allclose(variables['x'], true_out[-1]) + @parameterized.named_parameters(('scan', True), ('while_loop', False)) + def test_can_reap_and_plant_looped_values_in_default_clobber_mode( + self, static_length + ): + length = 5 + + def body(index, x): + x = x + index + x = jax.lax.switch( + index, + [ + lambda i=i: variable(x, name=f'x{i}', mode='default_clobber') + for i in range(length + 1) + ], + ) + return x + + def f(upper, init): + return lax.fori_loop(0, length if static_length else upper, body, init) + + out, variables = harvest_variables(f)(dict(x3=0.5), length, 1.) + np.testing.assert_allclose(out, 0.5 + 4) + default = 0. + self.assertDictEqual( + dict(x0=1., x1=2., x2=4., x4=out, x5=default), variables + ) + def test_non_clobber_mode_in_while_loop_should_error_with_reap_and_plant( self): @@ -745,12 +772,16 @@ def f(init): with self.assertRaisesRegex( ValueError, - 'Must use clobber mode for \'x\' inside of a `while_loop`.'): - reap_variables(f)((0, 0.)) + "Must use clobber or default_clobber mode for 'x' inside of a" + ' `while_loop`.', + ): + reap_variables(f)((0, 0.0)) with self.assertRaisesRegex( ValueError, - 'Must use clobber mode for \'x\' inside of a `while_loop`.'): + "Must use clobber or default_clobber mode for 'x' inside of a" + ' `while_loop`.', + ): plant_variables(f)(dict(x=4.), (0, 0.)) def test_can_reap_final_values_from_while_loop(self): @@ -844,11 +875,11 @@ def false_fun(x): return lax.cond(pred, true_fun, false_fun, x) with self.assertRaisesRegex( - ValueError, 'Mismatching number of `sow`s between branches.'): + ValueError, 'Missing sow in branch: \'y\''): reap_variables(f2)(True, 1.) with self.assertRaisesRegex( - ValueError, 'Mismatching number of `sow`s between branches.'): + ValueError, 'Missing sow in branch: \'y\''): plant_variables(f2)({}, True, 1.) def f3(pred, x): @@ -893,6 +924,29 @@ def false_fun(x): self.assertEqual(out, 6.) self.assertDictEqual(reaps, dict(x=3.)) + def test_can_reap_from_mismatching_branches_of_default_clobber_cond(self): + + def f(pred, x): + + @jax.jit + def true_fun(x): + x = variable(x, name='x', mode='default_clobber') + return x + 2. + + def false_fun(x): + x = variable(x + 2., name='y', mode='default_clobber') + return x + 3. + + return lax.cond(pred, true_fun, false_fun, x) + + out, reaps = call_and_reap_variables(f)(True, 1.) + self.assertEqual(out, 3.) + self.assertDictEqual(reaps, dict(x=1., y=0.)) + + out, reaps = call_and_reap_variables(f)(False, 1.) + self.assertEqual(out, 6.) + self.assertDictEqual(reaps, dict(y=3., x=0.)) + def test_can_plant_values_into_either_branch_of_cond(self): def f(pred, x): @@ -913,6 +967,27 @@ def false_fun(x): out = plant_variables(f)(dict(x=4.), False, 1.) self.assertEqual(out, 7.) + def test_can_plant_into_mismatching_branches_of_default_clobber_cond(self): + + def f(pred, x): + + @jax.jit + def true_fun(x): + x = variable(x, name='x', mode='default_clobber') + return x + 2. + + def false_fun(x): + x = variable(x + 2., name='y', mode='default_clobber') + return x + 3. + + return lax.cond(pred, true_fun, false_fun, x) + + out = plant_variables(f)(dict(x=4.), True, 1.) + self.assertEqual(out, 6.) + + out = plant_variables(f)(dict(y=4.), False, 1.) + self.assertEqual(out, 7.) + def test_can_reap_values_from_any_branch_in_switch(self): def f(index, x):