diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 4d4c95b..70b8b4a 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -57,6 +57,13 @@ automatically be tied-in to the output of `sow` to introduce a fake data-dependence. By default, it is `None`. +## `sow_cond` + +`sow_cond` is a variant of `sow`, that takes an additional positional argument, +`pred`. It supports a single `mode` `'cond_clobber'`, which is like `clobber`, +but sows values conditionally on `pred`, falling back on zeros if no sow took +place. This allows reaping values from loop iterations besides the final one. + ## `harvest` `harvest` is a function transformation that augments the behaviors of `sow`s @@ -239,8 +246,54 @@ def sow(value, *, tag: Hashable, name: str, mode: str = 'strict', key=None): key: an optional JAX value that will be tied into the sown value. Returns: - The original `value` that was passed in. + The original `value` that was passed in, or a planted value. + """ + if mode == 'cond_clobber': + raise ValueError("For 'cond_clobber' mode, use `sow_cond`.'") + return _sow(value, tag=tag, name=name, mode=mode, key=key) + + +def sow_cond( + value, + pred, + *, + tag: Hashable, + name: str, + mode: str = 'cond_clobber', + key=None, +): + """Marks a value, alongside a predicate, with a name and a tag. + + The predicate determines whether the value is to be clobbered in this loop + iteration -- if it's reaped but never clobbered, the value will be full of + zeros. + + Args: + value: A JAX value to be tagged and named. + pred: Whether to sow the value. + tag: a string representing the tag of the sown value. + name: a string representing the name to sow the value with. + mode: The mode by which to sow the value. There are three options: 1. + `'strict'` - if another value is sown with the same name and tag in the + same context, harvest will throw an error. 2. `'clobber'` - if another is + value is sown with the same name and tag, it will replace this value 3. + `'append'` - sown values of the same name and tag are appended to a + growing list. Append mode assumes some ordering on the values being sown + defined by data-dependence. + key: an optional JAX value that will be tied into the sown value. + + Returns: + The original `value` that was passed in, or a planted value. """ + if mode != 'cond_clobber': + raise ValueError("`sow_cond` only supports 'cond_clobber' mode.") + return _sow(value, tag=tag, name=name, mode=mode, key=key, pred=pred)[0] + + +def _sow(value, *, tag, name, mode, key=None, pred=None): + assert (pred is not None) == (mode == 'cond_clobber') + if pred is not None: + value = value, pred value = tree_util.tree_map(jax_core.raise_as_much_as_possible, value) if key is not None: value = prim.tie_in(key, value) @@ -441,7 +494,7 @@ class HarvestContext: def process_sow(self, *values, name, tag, mode, tree): """Handles a `sow` primitive in a `HarvestTrace`.""" - if mode not in {'strict', 'append', 'clobber'}: + if mode not in {'strict', 'append', 'clobber', 'cond_clobber'}: raise ValueError(f'Invalid mode: {mode}') if tag != self.settings.tag: if self.settings.exclusive: @@ -493,6 +546,7 @@ def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): @dataclasses.dataclass class Reap: value: Any + pred: Any metadata: Dict[str, Any] @@ -513,8 +567,13 @@ def handle_sow(self, *values, name, tag, tree, mode): avals = tree_util.tree_unflatten( tree, [jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values]) - self.reaps[name] = Reap( - tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals)) + vals = tree_util.tree_unflatten(tree, values) + pred = None + if mode == 'cond_clobber': + avals, _ = avals + vals, pred = vals + metadata = dict(mode=mode, aval=avals) + self.reaps[name] = Reap(vals, pred, metadata) return values def reap_higher_order_primitive(self, trace, call_primitive, f, tracers, @@ -537,13 +596,14 @@ def new_out_axes_thunk(): params = dict(params, out_axes_thunk=new_out_axes_thunk) out_flat = call_primitive.bind(f, *vals, name=name, **params) out_tree, metadata = aux() - out_vals, reaps = tree_util.tree_unflatten(out_tree, out_flat) + out_vals, reaps, preds = tree_util.tree_unflatten(out_tree, out_flat) out_tracers = jax_util.safe_map(trace.pure, out_vals) reap_tracers = tree_util.tree_map(trace.pure, reaps) - return out_tracers, reap_tracers, metadata + pred_tracers = tree_util.tree_map(trace.pure, preds) + return out_tracers, reap_tracers, pred_tracers, metadata def process_nest(self, trace, f, *tracers, scope, name, **params): - out_tracers, reap_tracers, _ = self.reap_higher_order_primitive( + out_tracers, reap_tracers, _, _ = self.reap_higher_order_primitive( trace, nest_p, f, tracers, dict(params, name=name, scope=scope), False) tag = self.settings.tag if reap_tracers: @@ -555,10 +615,15 @@ def process_nest(self, trace, f, *tracers, scope, name, **params): def process_higher_order_primitive(self, trace, call_primitive, f, tracers, params, is_map): - out_tracers, reap_tracers, metadata = self.reap_higher_order_primitive( - trace, call_primitive, f, tracers, params, is_map) + out_tracers, reap_tracers, pred_tracers, metadata = ( + self.reap_higher_order_primitive( + trace, call_primitive, f, tracers, params, is_map + ) + ) tag = self.settings.tag for k, v in reap_tracers.items(): + if metadata[k]['mode'] == 'cond_clobber': + v = (v, pred_tracers[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, @@ -583,10 +648,14 @@ def _jvp_subtrace(main, *args): out_flat = primitive.bind(fun, jvp, *vals_in, symbolic_zeros=symbolic_zeros) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: - out, reaps = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers = tree_util.tree_map(trace.pure, (out, reaps)) + out, reaps, preds = tree_util.tree_unflatten(out_tree, out_flat) + out_tracers, reap_tracers, pred_tracers = tree_util.tree_map( + trace.pure, (out, reaps, preds) + ) tag = context.settings.tag for k, v in reap_tracers.items(): + if metadata[k]['mode'] == 'cond_clobber': + v = (v, pred_tracers[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, @@ -621,10 +690,14 @@ def _fwd_subtrace(main, *args): symbolic_zeros=symbolic_zeros) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: - out, reaps = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers = tree_util.tree_map(trace.pure, (out, reaps)) + out, reaps, preds = tree_util.tree_unflatten(out_tree, out_flat) + out_tracers, reap_tracers, pred_tracers = tree_util.tree_map( + trace.pure, (out, reaps, preds) + ) tag = context.settings.tag for k, v in reap_tracers.items(): + if metadata[k]['mode'] == 'cond_clobber': + v = (v, pred_tracers[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, @@ -648,7 +721,7 @@ def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): @lu.transformation def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, return_metadata: bool, args: Iterable[Any]): - """A function transformation that returns reap values.""" + """A function transformation that returns reap values and predicates.""" trace = HarvestTrace(main, jax_core.cur_sublevel()) in_tracers = jax_util.safe_map(trace.pure, args) context = ReapContext(settings, {}) @@ -657,14 +730,17 @@ def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, out_tracers = jax_util.safe_map(trace.full_raise, ans) reap_tracers = tree_util.tree_map( lambda x: tree_util.tree_map(trace.full_raise, x.value), context.reaps) + pred_tracers = tree_util.tree_map( + lambda x: trace.full_raise(x.pred), context.reaps) reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps) del main - out_values, reap_values = tree_util.tree_map(lambda x: x.val, - (out_tracers, reap_tracers)) + out_values, reap_values, pred_values = tree_util.tree_map( + lambda x: x.val, (out_tracers, reap_tracers, pred_tracers) + ) if return_metadata: - out = (out_values, reap_values, reap_metadata) + out = (out_values, reap_values, pred_values, reap_metadata) else: - out = (out_values, reap_values) + out = (out_values, reap_values, pred_values) yield out @@ -678,16 +754,16 @@ def reap_eval( @lu.transformation_with_aux def reap_wrapper(trace: HarvestTrace, *args): del trace - out, reaps, metadata = yield (args,), {} - out_flat, out_tree = tree_util.tree_flatten((out, reaps)) + out, reaps, preds, metadata = yield (args,), {} + out_flat, out_tree = tree_util.tree_flatten((out, reaps, preds)) yield out_flat, (out_tree, metadata) @lu.transformation def reap_wrapper_drop_aux(trace: HarvestTrace, *args): del trace - out, reaps, _ = yield (args,), {} - out_flat, _ = tree_util.tree_flatten((out, reaps)) + out, reaps, preds, _ = yield (args,), {} + out_flat, _ = tree_util.tree_flatten((out, reaps, preds)) yield out_flat @@ -713,6 +789,18 @@ def call_and_reap(f, A new function that executes the original and returns its sown values as an additional return value. """ + wrapped = _call_and_reap( + f, + tag=tag, + allowlist=allowlist, + blocklist=blocklist, + exclusive=exclusive, + ) + return lambda *args, **kwargs: wrapped(*args, **kwargs)[:-1] + + +def _call_and_reap(f, *, tag, allowlist, blocklist, exclusive): + """Like call_and_reap but including predicates.""" blocklist = frozenset(blocklist) if allowlist is not None: allowlist = frozenset(allowlist) @@ -724,9 +812,9 @@ def wrapped(*args, **kwargs): flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) with jax_core.new_main(HarvestTrace) as main: flat_fun = reap_function(flat_fun, main, settings, False) - out_flat, reaps = flat_fun.call_wrapped(flat_args) + out_flat, reaps, preds = flat_fun.call_wrapped(flat_args) del main - return tree_util.tree_unflatten(out_tree(), out_flat), reaps + return tree_util.tree_unflatten(out_tree(), out_flat), reaps, preds return wrapped @@ -766,8 +854,8 @@ def wrapped(*args, **kwargs): @lu.transformation_with_aux def _reap_metadata_wrapper(*args): - out, reaps, metadata = yield (args,), {} - yield (out, reaps), metadata + out, reaps, preds, metadata = yield (args,), {} + yield (out, reaps, preds), metadata def _get_harvest_metadata(closed_jaxpr, settings, *args): @@ -789,6 +877,17 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args): return metadata +def _update_clobber_carry(carry_reaps, carry_preds, name, val, preds, mode): + if mode == 'cond_clobber': + carry_reaps[name], carry_preds[name] = lax.cond( + preds[name], + lambda val=val: (val, True), + lambda name=name: (carry_reaps[name], carry_preds[name]), + ) + else: + carry_reaps[name] = val + + def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): """Reaps the body of a scan to pull out `clobber` and `append` sows.""" @@ -809,6 +908,7 @@ def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, reap_modes = collections.defaultdict(set) reap_carry_avals = {} + cond_carry_avals = {} for name, meta in metadata.items(): mode = meta['mode'] aval = meta['aval'] @@ -817,38 +917,46 @@ def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, reap_modes[mode].add(name) if mode == 'clobber': reap_carry_avals[name] = aval + cond_carry_avals[name] = None + if mode == 'cond_clobber': + reap_carry_avals[name] = aval + cond_carry_avals[name] = jax_core.raise_to_shaped(jax_core.get_aval(True)) + body_fun = jax_core.jaxpr_as_fun(jaxpr) - reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals) + reap_carry_flat_avals = tree_util.tree_leaves( + (reap_carry_avals, cond_carry_avals) + ) reap_carry_in_tree = tree_util.tree_structure( - ((carry_avals, reap_carry_avals), xs_avals)) + ((carry_avals, reap_carry_avals, cond_carry_avals), xs_avals)) def new_body(carry, x): - carry, _ = carry + carry, carry_reaps, carry_preds = carry all_values = const_vals + tree_util.tree_leaves((carry, x)) - out, reaps = call_and_reap( + out, reaps, preds = _call_and_reap( body_fun, tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive)(*all_values) carry_out, y = jax_util.split_list(out, [num_carry]) - carry_reaps = { - name: val - for name, val in reaps.items() - if name in reap_modes['clobber'] - } + clobber_reap_modes = reap_modes['clobber'] | reap_modes['cond_clobber'] + for name, val in reaps.items(): + if name in clobber_reap_modes: + _update_clobber_carry(carry_reaps, carry_preds, name, val, preds, mode) xs_reaps = { name: val for name, val in reaps.items() if name in reap_modes['append'] } - return (carry_out, carry_reaps), (y, xs_reaps) + return (carry_out, carry_reaps, carry_preds), (y, xs_reaps) new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, reap_carry_in_tree, tuple(carry_avals + reap_carry_flat_avals + x_avals)) dummy_reap_carry_vals = tree_util.tree_map( - lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals) + lambda x: jnp.zeros(x.shape, x.dtype), + reap_carry_flat_avals, + ) out = lax.scan_p.bind( *(consts + carry_vals + dummy_reap_carry_vals + xs_vals), reverse=reverse, @@ -860,11 +968,16 @@ def new_body(carry, x): linear[len(consts):]), unroll=unroll, _split_transpose=_split_transpose) - (carry_out, - carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(out_tree, out) - (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map( - trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps))) - for k, v in {**carry_reaps, **ys_reaps}.items(): + (carry_out, carry_reaps, carry_preds), (ys, ys_reaps) = ( + tree_util.tree_unflatten(out_tree, out) + ) + (carry_out, carry_reaps, carry_preds), (ys, ys_reaps) = tree_util.tree_map( + trace.pure, ((carry_out, carry_reaps, carry_preds), (ys, ys_reaps)) + ) + for k, v in carry_reaps.items(): + mode = metadata[k]['mode'] + _sow(v, tag=settings.tag, mode=mode, name=k, pred=carry_preds[k]) + for k, v in ys_reaps.items(): sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k) return carry_out + ys @@ -874,7 +987,7 @@ def new_body(carry, x): def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): - """Reaps the body of a while loop to get the reaps of the final iteration.""" + """Reaps the body of a while loop to get the reaps of `clobber` sows.""" cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list( tracers, [cond_nconsts, body_nconsts]) _, init_avals = tree_util.tree_map(lambda x: x.aval, @@ -885,12 +998,18 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, settings = context.settings body_metadata = _get_harvest_metadata(body_jaxpr, settings, *(body_const_tracers + init_tracers)) + reap_avals = {} + cond_avals = collections.defaultdict(lambda: None) for k, meta in body_metadata.items(): mode = meta['mode'] - if mode != 'clobber': + if mode not in ['clobber', 'cond_clobber']: raise ValueError( - f'Must use clobber mode for \'{k}\' inside of a `while_loop`.') - reap_avals = {k: v['aval'] for k, v in body_metadata.items()} + f"Must use clobber or cond_clobber mode for '{k}' inside of a" + ' `while_loop`.' + ) + reap_avals[k] = meta['aval'] + if mode == 'cond_clobber': + cond_avals[k] = jax_core.raise_to_shaped(jax_core.get_aval(True)) cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr) body_fun = jax_core.jaxpr_as_fun(body_jaxpr) @@ -900,21 +1019,27 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, blocklist=settings.blocklist, exclusive=settings.exclusive) - def new_cond(carry, _): + def new_cond(carry, *_): return cond_fun(*(cond_const_vals + carry)) - def new_body(carry, _): - carry, reaps = call_and_reap(body_fun, - **reap_settings)(*(body_const_vals + carry)) - return (carry, reaps) - - new_in_avals, new_in_tree = tree_util.tree_flatten((init_avals, reap_avals)) + def new_body(carry, carry_reaps, carry_preds): + carry, reaps, preds = _call_and_reap(body_fun, **reap_settings)( + *(body_const_vals + carry) + ) + for name, val in reaps.items(): + mode = body_metadata[name]['mode'] + _update_clobber_carry(carry_reaps, carry_preds, name, val, preds, mode) + return (carry, carry_reaps, carry_preds) + + new_in_avals, new_in_tree = tree_util.tree_flatten( + (init_avals, reap_avals, cond_avals) + ) new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_cond, new_in_tree, tuple(new_in_avals)) new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, new_in_tree, tuple(new_in_avals)) dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), - reap_avals) + (reap_avals, cond_avals)) new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals)) out = lax.while_p.bind( *(cond_consts + body_consts + new_in_vals), @@ -923,9 +1048,10 @@ def new_body(carry, _): cond_jaxpr=new_cond_jaxpr, body_jaxpr=new_body_jaxpr) out = jax_util.safe_map(trace.pure, out) - out, reaps = tree_util.tree_unflatten(out_tree, out) + out, reaps, preds = tree_util.tree_unflatten(out_tree, out) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode']) + mode = body_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -968,7 +1094,7 @@ def _reap_cond_rule(trace, *tracers, branches, linear): _check_branch_metadata(branch_metadatas) branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) reaped_branches = tuple( - call_and_reap(f, **reap_settings) for f in branch_funs) + _call_and_reap(f, **reap_settings) for f in branch_funs) in_tree = tree_util.tree_structure(ops_avals) new_branch_jaxprs, consts, out_trees = ( lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access @@ -979,9 +1105,10 @@ def _reap_cond_rule(trace, *tracers, branches, linear): branches=tuple(new_branch_jaxprs), linear=(False,) * len(tuple(consts) + linear)) out = jax_util.safe_map(trace.pure, out) - out, reaps = tree_util.tree_unflatten(out_trees[0], out) + out, reaps, preds = tree_util.tree_unflatten(out_trees[0], out) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=branch_metadatas[0][k]['mode']) + mode = branch_metadatas[0][k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -1002,7 +1129,7 @@ def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr) - reaped_remat_fun = call_and_reap(remat_fun, **reap_settings) + reaped_remat_fun = _call_and_reap(remat_fun, **reap_settings) reap_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access reaped_remat_fun, tree_util.tree_structure(invals), tuple(t.aval for t in tracers)) @@ -1014,9 +1141,10 @@ def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, prevent_cse=prevent_cse, differentiated=differentiated) outvals = jax_util.safe_map(trace.pure, outvals) - out, reaps = tree_util.tree_unflatten(out_tree, outvals) + out, reaps, preds = tree_util.tree_unflatten(out_tree, outvals) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=reap_metadata[k]['mode']) + mode = reap_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -1073,7 +1201,7 @@ def _reap_pjit_rule(trace, *tracers, **params): closed_jaxpr = params['jaxpr'] reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr) - reaped_pjit_fun = lu.wrap_init(call_and_reap(pjit_fun, **reap_settings)) + reaped_pjit_fun = lu.wrap_init(_call_and_reap(pjit_fun, **reap_settings)) in_tree = tree_util.tree_structure(invals) flat_fun, out_tree = api_util.flatten_fun_nokwargs(reaped_pjit_fun, in_tree) @@ -1094,9 +1222,10 @@ def _reap_pjit_rule(trace, *tracers, **params): outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) outvals = jax_util.safe_map(trace.pure, outvals) - out, reaps = tree_util.tree_unflatten(out_tree(), outvals) + out, reaps, preds = tree_util.tree_unflatten(out_tree(), outvals) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=reap_metadata[k]['mode']) + mode = reap_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -1124,7 +1253,10 @@ def handle_sow(self, *values, name, tag, tree, mode): raise ValueError(f'Variable has already been planted: {name}') if name in self.plants: self._already_planted.add(name) - return tree_util.tree_leaves(self.plants[name]) + if mode == 'cond_clobber': + return tree_util.tree_leaves((self.plants[name], True)) + else: + return tree_util.tree_leaves(self.plants[name]) return sow_p.bind(*values, name=name, tag=tag, mode=mode, tree=tree) def process_nest(self, trace, f, *tracers, scope, name, **params): @@ -1305,10 +1437,10 @@ def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, if mode == 'append' and name in plants: plant_xs_avals[name] = aval body_fun = jax_core.jaxpr_as_fun(jaxpr) - clobber_plants = { + all_clobber_plants = { name: value for name, value in plants.items() - if name in plant_modes['clobber'] + if name in plant_modes['clobber'] | plant_modes['cond_clobber'] } append_plants = { name: value @@ -1323,7 +1455,7 @@ def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, def new_body(carry, x): x, plants = x - all_plants = {**plants, **clobber_plants} + all_plants = {**plants, **all_clobber_plants} all_values = const_vals + tree_util.tree_leaves((carry, x)) out = plant( body_fun, @@ -1368,9 +1500,11 @@ def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, *(body_const_tracers + init_tracers)) for k, meta in body_metadata.items(): mode = meta['mode'] - if mode != 'clobber': + if mode not in ['clobber', 'cond_clobber']: raise ValueError( - f'Must use clobber mode for \'{k}\' inside of a `while_loop`.') + f"Must use clobber or cond_clobber mode for '{k}' inside of a" + ' `while_loop`.' + ) body_fun = jax_core.jaxpr_as_fun(body_jaxpr) plant_settings = dict( diff --git a/oryx/core/interpreters/harvest_test.py b/oryx/core/interpreters/harvest_test.py index e414601..e76e9a7 100644 --- a/oryx/core/interpreters/harvest_test.py +++ b/oryx/core/interpreters/harvest_test.py @@ -48,6 +48,7 @@ config.update('jax_traceback_filtering', 'off') sow = harvest.sow +sow_cond = harvest.sow_cond reap = harvest.reap call_and_reap = harvest.call_and_reap plant = harvest.plant @@ -728,6 +729,29 @@ def f(init): self.assertListEqual(['x'], list(variables.keys())) np.testing.assert_allclose(variables['x'], true_out[-1]) + @parameterized.named_parameters(('scan', True), ('while_loop', False)) + def test_can_reap_and_plant_looped_values_in_cond_clobber_mode( + self, static_length + ): + length = 5 + + def body(index, x): + x = x + index + values = [] + for i, pred in enumerate(index == np.arange(length + 1)): + values.append(sow_cond(x, pred, name=f'x{i}', tag='variable')) + return jax.lax.select_n(index, *values) + + def f(upper, init): + return lax.fori_loop(0, length if static_length else upper, body, init) + + out, variables = harvest_variables(f)(dict(x3=0.5), length, 1.) + np.testing.assert_allclose(out, 0.5 + 4) + default = 0. + self.assertDictEqual( + dict(x0=1., x1=2., x2=4., x4=out, x5=default), variables + ) + def test_non_clobber_mode_in_while_loop_should_error_with_reap_and_plant( self): @@ -745,12 +769,16 @@ def f(init): with self.assertRaisesRegex( ValueError, - 'Must use clobber mode for \'x\' inside of a `while_loop`.'): - reap_variables(f)((0, 0.)) + "Must use clobber or cond_clobber mode for 'x' inside of a" + ' `while_loop`.', + ): + reap_variables(f)((0, 0.0)) with self.assertRaisesRegex( ValueError, - 'Must use clobber mode for \'x\' inside of a `while_loop`.'): + "Must use clobber or cond_clobber mode for 'x' inside of a" + ' `while_loop`.', + ): plant_variables(f)(dict(x=4.), (0, 0.)) def test_can_reap_final_values_from_while_loop(self):