From c2bca18cb74de20752123dbe0f2c31bfe11c5337 Mon Sep 17 00:00:00 2001 From: The oryx Authors Date: Sun, 3 Mar 2024 15:56:13 -0800 Subject: [PATCH] Add 'default_clobber' mode that lets different loop iterations sow different values to be reaped. PiperOrigin-RevId: 612253581 --- oryx/__init__.py | 2 +- oryx/bijectors/__init__.py | 2 +- oryx/core/__init__.py | 2 +- oryx/core/interpreters/__init__.py | 2 +- oryx/core/interpreters/harvest.py | 332 ++++++++++++------ oryx/core/interpreters/harvest_test.py | 77 +++- oryx/core/interpreters/inverse/__init__.py | 2 +- .../inverse/bijector_extensions.py | 2 +- .../inverse/bijector_extensions_test.py | 2 +- oryx/core/interpreters/inverse/core.py | 2 +- .../interpreters/inverse/custom_inverse.py | 2 +- .../inverse/custom_inverse_test.py | 2 +- .../core/interpreters/inverse/inverse_test.py | 2 +- oryx/core/interpreters/inverse/rules.py | 2 +- oryx/core/interpreters/inverse/slice.py | 2 +- oryx/core/interpreters/inverse/slice_test.py | 2 +- oryx/core/interpreters/log_prob.py | 2 +- oryx/core/interpreters/log_prob_test.py | 2 +- oryx/core/interpreters/propagate.py | 2 +- oryx/core/interpreters/propagate_test.py | 2 +- oryx/core/kwargs_util.py | 2 +- oryx/core/kwargs_util_test.py | 2 +- oryx/core/ppl/__init__.py | 2 +- oryx/core/ppl/effect_handler.py | 2 +- oryx/core/ppl/effect_handler_test.py | 2 +- oryx/core/ppl/plate_util.py | 2 +- oryx/core/ppl/transformations.py | 2 +- oryx/core/ppl/transformations_test.py | 2 +- oryx/core/primitive.py | 2 +- oryx/core/primitive_test.py | 2 +- oryx/core/pytree.py | 2 +- oryx/core/serialize.py | 2 +- oryx/core/serialize_test.py | 2 +- oryx/core/state/__init__.py | 2 +- oryx/core/state/api.py | 2 +- oryx/core/state/function.py | 2 +- oryx/core/state/function_test.py | 2 +- oryx/core/state/module.py | 2 +- oryx/core/state/registrations.py | 2 +- oryx/core/state/registrations_test.py | 2 +- oryx/core/trace_util.py | 2 +- oryx/distributions/__init__.py | 2 +- oryx/distributions/distribution_extensions.py | 2 +- .../distribution_extensions_test.py | 2 +- oryx/experimental/__init__.py | 2 +- oryx/experimental/autoconj/addn.py | 2 +- oryx/experimental/autoconj/addn_test.py | 2 +- oryx/experimental/autoconj/canonicalize.py | 2 +- .../autoconj/canonicalize_test.py | 2 +- oryx/experimental/autoconj/einsum.py | 2 +- oryx/experimental/autoconj/einsum_test.py | 2 +- oryx/experimental/matching/__init__.py | 2 +- oryx/experimental/matching/jax_rewrite.py | 2 +- .../experimental/matching/jax_rewrite_test.py | 2 +- oryx/experimental/matching/matcher.py | 2 +- oryx/experimental/matching/matcher_test.py | 2 +- oryx/experimental/matching/rules.py | 2 +- oryx/experimental/matching/rules_test.py | 2 +- oryx/experimental/mcmc/__init__.py | 2 +- oryx/experimental/mcmc/kernels.py | 2 +- oryx/experimental/mcmc/kernels_test.py | 2 +- oryx/experimental/mcmc/utils.py | 2 +- oryx/experimental/nn/__init__.py | 2 +- oryx/experimental/nn/base.py | 2 +- oryx/experimental/nn/base_test.py | 2 +- oryx/experimental/nn/combinator.py | 2 +- oryx/experimental/nn/combinator_test.py | 2 +- oryx/experimental/nn/convolution.py | 2 +- oryx/experimental/nn/convolution_test.py | 2 +- oryx/experimental/nn/core.py | 2 +- oryx/experimental/nn/core_test.py | 2 +- oryx/experimental/nn/function.py | 2 +- oryx/experimental/nn/function_test.py | 2 +- oryx/experimental/nn/normalization.py | 2 +- oryx/experimental/nn/normalization_test.py | 2 +- oryx/experimental/nn/pooling.py | 2 +- oryx/experimental/nn/pooling_test.py | 2 +- oryx/experimental/nn/reshape.py | 2 +- oryx/experimental/nn/reshape_test.py | 2 +- oryx/experimental/optimizers/__init__.py | 2 +- oryx/experimental/optimizers/optix.py | 2 +- oryx/experimental/optimizers/optix_test.py | 2 +- oryx/internal/__init__.py | 2 +- oryx/internal/test_util.py | 2 +- oryx/tools/build_oryx_docs.py | 2 +- oryx/util/__init__.py | 2 +- oryx/util/summary.py | 2 +- oryx/util/summary_test.py | 2 +- oryx/version.py | 2 +- 89 files changed, 395 insertions(+), 188 deletions(-) diff --git a/oryx/__init__.py b/oryx/__init__.py index 0706bd8..900e157 100644 --- a/oryx/__init__.py +++ b/oryx/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/bijectors/__init__.py b/oryx/bijectors/__init__.py index 67f64c6..f7e2d84 100644 --- a/oryx/bijectors/__init__.py +++ b/oryx/bijectors/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/__init__.py b/oryx/core/__init__.py index 94b5b9a..4649cb0 100644 --- a/oryx/core/__init__.py +++ b/oryx/core/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/__init__.py b/oryx/core/interpreters/__init__.py index b5b4a3f..255a997 100644 --- a/oryx/core/interpreters/__init__.py +++ b/oryx/core/interpreters/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 14b1d71..3450971 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -51,11 +51,13 @@ default set to `'strict'`. The `mode` of a `sow` describes how it behaves when the same name appears multiple times. In "strict" mode, `sow` will error if the same `(tag, name)` appears more than once. Another option is `'append'`, in -which all sows of the same name will be appended into a growing array. Finally, -there is `'clobber'`, where only the final sown value for a given `(tag, name)` -will be returned. The final optional argument for `sow` is `key`, which will -automatically be tied-in to the output of `sow` to introduce a fake -data-dependence. By default, it is `None`. +which all sows of the same name will be appended into a growing array. There is +also `'clobber'`, where only the final sown value for a given `(tag, name)` will +be returned. Finally, there is `'default_clobber'`, which is like clobber but +allows sowing to be conditional, falling back on zeros if no sow took place. +The final optional argument for `sow` is `key`, which will automatically be +tied-in to the output of `sow` to introduce a fake data-dependence. By default, +it is `None`. ## `harvest` @@ -137,7 +139,8 @@ def f(x): import collections import dataclasses import functools -from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Optional, Tuple, Union +import typing +from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Literal, Optional, Tuple, Union from jax import api_util from jax import lax @@ -229,23 +232,35 @@ def sow(value, *, tag: Hashable, name: str, mode: str = 'strict', key=None): value: A JAX value to be tagged and named. tag: a string representing the tag of the sown value. name: a string representing the name to sow the value with. - mode: The mode by which to sow the value. There are three options: 1. + mode: The mode by which to sow the value. There are four options: 1. `'strict'` - if another value is sown with the same name and tag in the same context, harvest will throw an error. 2. `'clobber'` - if another is - value is sown with the same name and tag, it will replace this value 3. + value is sown with the same name and tag, it will replace this value. 3. `'append'` - sown values of the same name and tag are appended to a growing list. Append mode assumes some ordering on the values being sown - defined by data-dependence. + defined by data-dependence. 4. `'default_clobber'` - like `'clobber'`, but + sowing may be conditional, falling back on zeros if no sow took place. key: an optional JAX value that will be tied into the sown value. Returns: The original `value` that was passed in. """ + if mode == 'default_clobber': + mode = 'clobber' + cond = True + else: + cond = None + return _sow(value, tag=tag, name=name, mode=mode, key=key, cond=cond) + + +def _sow(value, *, tag, name, mode, key=None, cond=None): value = tree_util.tree_map(jax_core.raise_as_much_as_possible, value) if key is not None: value = prim.tie_in(key, value) flat_args, in_tree = tree_util.tree_flatten(value) - out_flat = sow_p.bind(*flat_args, name=name, tag=tag, mode=mode, tree=in_tree) + out_flat = sow_p.bind( + *flat_args, name=name, tag=tag, mode=mode, tree=in_tree, cond=cond + ) return tree_util.tree_unflatten(in_tree, out_flat) @@ -439,25 +454,31 @@ class HarvestContext: """A context that handles `sow`s and `nest`s in a `HarvestTrace`.""" settings: HarvestSettings - def process_sow(self, *values, name, tag, mode, tree): + def process_sow(self, *values, name, tag, mode, tree, cond): """Handles a `sow` primitive in a `HarvestTrace`.""" if mode not in {'strict', 'append', 'clobber'}: raise ValueError(f'Invalid mode: {mode}') + if cond is not None and mode != 'clobber': + raise ValueError(f'Mode {mode} does not support conditions.') if tag != self.settings.tag: if self.settings.exclusive: return values - return sow_p.bind(*values, name=name, tag=tag, tree=tree, mode=mode) + return sow_p.bind( + *values, name=name, tag=tag, tree=tree, mode=mode, cond=cond + ) if (self.settings.allowlist is not None and name not in self.settings.allowlist): return values if name in self.settings.blocklist: return values - return self.handle_sow(*values, name=name, tag=tag, tree=tree, mode=mode) + return self.handle_sow( + *values, name=name, tag=tag, tree=tree, mode=mode, cond=cond + ) def get_custom_rule(self, primitive): raise NotImplementedError - def handle_sow(self, *values, name, tag, mode, tree): + def handle_sow(self, *values, name, tag, mode, tree, cond): raise NotImplementedError def process_nest(self, trace, f, *tracers, scope, name): @@ -493,6 +514,7 @@ def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): @dataclasses.dataclass class Reap: value: Any + cond: Any metadata: Dict[str, Any] @@ -505,7 +527,7 @@ class ReapContext(HarvestContext): def get_custom_rule(self, primitive): return reap_custom_rules.get(primitive) - def handle_sow(self, *values, name, tag, tree, mode): + def handle_sow(self, *values, name, tag, tree, mode, cond): """Stores a sow in the reaps dictionary.""" del tag if name in self.reaps: @@ -513,8 +535,9 @@ def handle_sow(self, *values, name, tag, tree, mode): avals = tree_util.tree_unflatten( tree, [jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values]) - self.reaps[name] = Reap( - tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals)) + vals = tree_util.tree_unflatten(tree, values) + metadata = dict(mode=mode, aval=avals, has_cond=cond is not None) + self.reaps[name] = Reap(vals, cond, metadata) return values def reap_higher_order_primitive(self, trace, call_primitive, f, tracers, @@ -537,32 +560,38 @@ def new_out_axes_thunk(): params = dict(params, out_axes_thunk=new_out_axes_thunk) out_flat = call_primitive.bind(f, *vals, name=name, **params) out_tree, metadata = aux() - out_vals, reaps = tree_util.tree_unflatten(out_tree, out_flat) + out_vals, reaps, conds = tree_util.tree_unflatten(out_tree, out_flat) out_tracers = jax_util.safe_map(trace.pure, out_vals) reap_tracers = tree_util.tree_map(trace.pure, reaps) - return out_tracers, reap_tracers, metadata + cond_tracers = tree_util.tree_map(trace.pure, conds) + return out_tracers, reap_tracers, cond_tracers, metadata def process_nest(self, trace, f, *tracers, scope, name, **params): - out_tracers, reap_tracers, _ = self.reap_higher_order_primitive( + out_tracers, reap_tracers, _, _ = self.reap_higher_order_primitive( trace, nest_p, f, tracers, dict(params, name=name, scope=scope), False) tag = self.settings.tag if reap_tracers: flat_reap_tracers, reap_tree = tree_util.tree_flatten(reap_tracers) trace.process_primitive( sow_p, flat_reap_tracers, - dict(name=scope, tag=tag, tree=reap_tree, mode='strict')) + dict(name=scope, tag=tag, tree=reap_tree, mode='strict', cond=None)) return out_tracers def process_higher_order_primitive(self, trace, call_primitive, f, tracers, params, is_map): - out_tracers, reap_tracers, metadata = self.reap_higher_order_primitive( - trace, call_primitive, f, tracers, params, is_map) + out_tracers, reap_tracers, cond_tracers, metadata = ( + self.reap_higher_order_primitive( + trace, call_primitive, f, tracers, params, is_map + ) + ) tag = self.settings.tag for k, v in reap_tracers.items(): flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) + mode = metadata[k]['mode'] + cond = cond_tracers[k] trace.process_primitive( sow_p, flat_reap_tracers, - dict(name=k, tag=tag, tree=reap_tree, mode=metadata[k]['mode'])) + dict(name=k, tag=tag, tree=reap_tree, mode=mode, cond=cond)) return out_tracers def process_custom_jvp_call(self, trace, primitive, fun, jvp, tracers, *, @@ -583,14 +612,18 @@ def _jvp_subtrace(main, *args): out_flat = primitive.bind(fun, jvp, *vals_in, symbolic_zeros=symbolic_zeros) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: - out, reaps = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers = tree_util.tree_map(trace.pure, (out, reaps)) + out, reaps, conds = tree_util.tree_unflatten(out_tree, out_flat) + out_tracers, reap_tracers, cond_tracers = tree_util.tree_map( + trace.pure, (out, reaps, conds) + ) tag = context.settings.tag for k, v in reap_tracers.items(): flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) + mode = metadata[k]['mode'] + cond = cond_tracers[k] trace.process_primitive( sow_p, flat_reap_tracers, - dict(name=k, tag=tag, tree=reap_tree, mode=metadata[k]['mode'])) + dict(name=k, tag=tag, tree=reap_tree, mode=mode, cond=cond)) else: out_tracers = jax_util.safe_map(trace.pure, out_flat) return out_tracers @@ -621,14 +654,19 @@ def _fwd_subtrace(main, *args): symbolic_zeros=symbolic_zeros) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: - out, reaps = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers = tree_util.tree_map(trace.pure, (out, reaps)) + out, reaps, conds = tree_util.tree_unflatten(out_tree, out_flat) + out_tracers, reap_tracers, cond_tracers = tree_util.tree_map( + trace.pure, (out, reaps, conds) + ) tag = context.settings.tag for k, v in reap_tracers.items(): flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) + mode = metadata[k]['mode'] + cond = cond_tracers[k] trace.process_primitive( sow_p, flat_reap_tracers, - dict(name=k, tag=tag, tree=reap_tree, mode=metadata[k]['mode'])) + dict(name=k, tag=tag, tree=reap_tree, mode=mode, cond=cond), + ) else: out_tracers = jax_util.safe_map(trace.pure, out_flat) return out_tracers @@ -648,7 +686,7 @@ def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): @lu.transformation def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, return_metadata: bool, args: Iterable[Any]): - """A function transformation that returns reap values.""" + """A function transformation that returns reap values and conditions.""" trace = HarvestTrace(main, jax_core.cur_sublevel()) in_tracers = jax_util.safe_map(trace.pure, args) context = ReapContext(settings, {}) @@ -657,14 +695,17 @@ def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, out_tracers = jax_util.safe_map(trace.full_raise, ans) reap_tracers = tree_util.tree_map( lambda x: tree_util.tree_map(trace.full_raise, x.value), context.reaps) + cond_tracers = tree_util.tree_map( + lambda x: tree_util.tree_map(trace.full_raise, x.cond), context.reaps) reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps) del main - out_values, reap_values = tree_util.tree_map(lambda x: x.val, - (out_tracers, reap_tracers)) + out_values, reap_values, cond_values = tree_util.tree_map( + lambda x: x.val, (out_tracers, reap_tracers, cond_tracers) + ) if return_metadata: - out = (out_values, reap_values, reap_metadata) + out = (out_values, reap_values, cond_values, reap_metadata) else: - out = (out_values, reap_values) + out = (out_values, reap_values, cond_values) yield out @@ -678,16 +719,16 @@ def reap_eval( @lu.transformation_with_aux def reap_wrapper(trace: HarvestTrace, *args): del trace - out, reaps, metadata = yield (args,), {} - out_flat, out_tree = tree_util.tree_flatten((out, reaps)) + out, reaps, conds, metadata = yield (args,), {} + out_flat, out_tree = tree_util.tree_flatten((out, reaps, conds)) yield out_flat, (out_tree, metadata) @lu.transformation def reap_wrapper_drop_aux(trace: HarvestTrace, *args): del trace - out, reaps, _ = yield (args,), {} - out_flat, _ = tree_util.tree_flatten((out, reaps)) + out, reaps, conds, _ = yield (args,), {} + out_flat, _ = tree_util.tree_flatten((out, reaps, conds)) yield out_flat @@ -713,6 +754,18 @@ def call_and_reap(f, A new function that executes the original and returns its sown values as an additional return value. """ + wrapped = _call_and_reap( + f, + tag=tag, + allowlist=allowlist, + blocklist=blocklist, + exclusive=exclusive, + ) + return lambda *args, **kwargs: wrapped(*args, **kwargs)[:-1] + + +def _call_and_reap(f, *, tag, allowlist, blocklist, exclusive): + """Like call_and_reap but including conditions.""" blocklist = frozenset(blocklist) if allowlist is not None: allowlist = frozenset(allowlist) @@ -724,9 +777,9 @@ def wrapped(*args, **kwargs): flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) with jax_core.new_main(HarvestTrace) as main: flat_fun = reap_function(flat_fun, main, settings, False) - out_flat, reaps = flat_fun.call_wrapped(flat_args) + out_flat, reaps, conds = flat_fun.call_wrapped(flat_args) del main - return tree_util.tree_unflatten(out_tree(), out_flat), reaps + return tree_util.tree_unflatten(out_tree(), out_flat), reaps, conds return wrapped @@ -766,8 +819,8 @@ def wrapped(*args, **kwargs): @lu.transformation_with_aux def _reap_metadata_wrapper(*args): - out, reaps, metadata = yield (args,), {} - yield (out, reaps), metadata + out, reaps, conds, metadata = yield (args,), {} + yield (out, reaps, conds), metadata def _get_harvest_metadata(closed_jaxpr, settings, *args): @@ -789,6 +842,17 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args): return metadata +def _update_clobber_carry(carry_reaps, carry_conds, name, val, conds, metadata): + if metadata[name]['has_cond']: + carry_reaps[name], carry_conds[name] = lax.cond( + conds[name], + lambda val=val: (val, True), + lambda name=name: (carry_reaps[name], carry_conds[name]), + ) + else: + carry_reaps[name], carry_conds[name] = val, carry_conds[name] + + def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, num_consts, num_carry, linear, unroll): """Reaps the body of a scan to pull out `clobber` and `append` sows.""" @@ -809,6 +873,7 @@ def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, reap_modes = collections.defaultdict(set) reap_carry_avals = {} + cond_carry_avals = {} for name, meta in metadata.items(): mode = meta['mode'] aval = meta['aval'] @@ -817,38 +882,48 @@ def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, reap_modes[mode].add(name) if mode == 'clobber': reap_carry_avals[name] = aval + cond_carry_avals[name] = ( + jax_core.raise_to_shaped(jax_core.get_aval(True)) + if meta['has_cond'] + else None + ) + body_fun = jax_core.jaxpr_as_fun(jaxpr) - reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals) + reap_carry_flat_avals = tree_util.tree_leaves( + (reap_carry_avals, cond_carry_avals) + ) reap_carry_in_tree = tree_util.tree_structure( - ((carry_avals, reap_carry_avals), xs_avals)) + ((carry_avals, reap_carry_avals, cond_carry_avals), xs_avals)) def new_body(carry, x): - carry, _ = carry + carry, carry_reaps, carry_conds = carry all_values = const_vals + tree_util.tree_leaves((carry, x)) - out, reaps = call_and_reap( + out, reaps, conds = _call_and_reap( body_fun, tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive)(*all_values) carry_out, y = jax_util.split_list(out, [num_carry]) - carry_reaps = { - name: val - for name, val in reaps.items() - if name in reap_modes['clobber'] - } + for name, val in reaps.items(): + if name in reap_modes['clobber']: + _update_clobber_carry( + carry_reaps, carry_conds, name, val, conds, metadata + ) xs_reaps = { name: val for name, val in reaps.items() if name in reap_modes['append'] } - return (carry_out, carry_reaps), (y, xs_reaps) + return (carry_out, carry_reaps, carry_conds), (y, xs_reaps) new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, reap_carry_in_tree, tuple(carry_avals + reap_carry_flat_avals + x_avals)) dummy_reap_carry_vals = tree_util.tree_map( - lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals) + lambda x: jnp.zeros(x.shape, x.dtype), + reap_carry_flat_avals, + ) out = lax.scan_p.bind( *(consts + carry_vals + dummy_reap_carry_vals + xs_vals), reverse=reverse, @@ -859,11 +934,16 @@ def new_body(carry, x): linear=(linear[:len(consts)] + (False,) * len(dummy_reap_carry_vals) + linear[len(consts):]), unroll=unroll) - (carry_out, - carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(out_tree, out) - (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map( - trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps))) - for k, v in {**carry_reaps, **ys_reaps}.items(): + (carry_out, carry_reaps, carry_conds), (ys, ys_reaps) = ( + tree_util.tree_unflatten(out_tree, out) + ) + (carry_out, carry_reaps, carry_conds), (ys, ys_reaps) = tree_util.tree_map( + trace.pure, ((carry_out, carry_reaps, carry_conds), (ys, ys_reaps)) + ) + for k, v in carry_reaps.items(): + mode = metadata[k]['mode'] + _sow(v, tag=settings.tag, mode=mode, name=k, cond=carry_conds[k]) + for k, v in ys_reaps.items(): sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k) return carry_out + ys @@ -873,7 +953,7 @@ def new_body(carry, x): def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): - """Reaps the body of a while loop to get the reaps of the final iteration.""" + """Reaps the body of a while loop to get the reaps of `clobber` sows.""" cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list( tracers, [cond_nconsts, body_nconsts]) _, init_avals = tree_util.tree_map(lambda x: x.aval, @@ -884,12 +964,19 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, settings = context.settings body_metadata = _get_harvest_metadata(body_jaxpr, settings, *(body_const_tracers + init_tracers)) + reap_avals = {} + cond_avals = {} for k, meta in body_metadata.items(): mode = meta['mode'] if mode != 'clobber': raise ValueError( f'Must use clobber mode for \'{k}\' inside of a `while_loop`.') - reap_avals = {k: v['aval'] for k, v in body_metadata.items()} + reap_avals[k] = meta['aval'] + cond_avals[k] = ( + jax_core.raise_to_shaped(jax_core.get_aval(True)) + if meta['has_cond'] + else None + ) cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr) body_fun = jax_core.jaxpr_as_fun(body_jaxpr) @@ -899,21 +986,28 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, blocklist=settings.blocklist, exclusive=settings.exclusive) - def new_cond(carry, _): + def new_cond(carry, *_): return cond_fun(*(cond_const_vals + carry)) - def new_body(carry, _): - carry, reaps = call_and_reap(body_fun, - **reap_settings)(*(body_const_vals + carry)) - return (carry, reaps) - - new_in_avals, new_in_tree = tree_util.tree_flatten((init_avals, reap_avals)) + def new_body(carry, carry_reaps, carry_conds): + carry, reaps, conds = _call_and_reap( + body_fun, **reap_settings # pylint: disable=unbalanced-tuple-unpacking + )(*(body_const_vals + carry)) + for name, val in reaps.items(): + _update_clobber_carry( + carry_reaps, carry_conds, name, val, conds, body_metadata + ) + return (carry, carry_reaps, carry_conds) + + new_in_avals, new_in_tree = tree_util.tree_flatten( + (init_avals, reap_avals, cond_avals) + ) new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_cond, new_in_tree, tuple(new_in_avals)) new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, new_in_tree, tuple(new_in_avals)) dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), - reap_avals) + (reap_avals, cond_avals)) new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals)) out = lax.while_p.bind( *(cond_consts + body_consts + new_in_vals), @@ -922,29 +1016,55 @@ def new_body(carry, _): cond_jaxpr=new_cond_jaxpr, body_jaxpr=new_body_jaxpr) out = jax_util.safe_map(trace.pure, out) - out, reaps = tree_util.tree_unflatten(out_tree, out) + out, reaps, conds = tree_util.tree_unflatten(out_tree, out) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode']) + mode = body_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, cond=conds[k]) return out reap_custom_rules[lcf.while_p] = _reap_while_rule -def _check_branch_metadata(branch_metadatas): - """Checks that a set of harvest metadata are consistent with each other.""" - first_branch_meta = branch_metadatas[0] - for branch_metadata in branch_metadatas[1:]: - if len(branch_metadata) != len(first_branch_meta): - raise ValueError('Mismatching number of `sow`s between branches.') +def _combine_branch_metadata(branch_metadatas): + """Combines metadatas from branches, checking consistency.""" + metas = {} + + for branch_metadata in branch_metadatas: for name, meta in branch_metadata.items(): - if name not in first_branch_meta: - raise ValueError(f'Missing sow in branch: \'{name}\'.') - first_meta_aval = first_branch_meta[name]['aval'] - if meta['aval'].shape != first_meta_aval.shape: + if (ret_meta := metas.get(name)) is None: + metas[name] = meta + continue + ret_aval = ret_meta['aval'] + if meta['aval'].shape != ret_aval.shape: raise ValueError(f'Mismatched shape between branches: \'{name}\'.') - if meta['aval'].dtype != first_meta_aval.dtype: + if meta['aval'].dtype != ret_aval.dtype: raise ValueError(f'Mismatched dtype between branches: \'{name}\'.') + if meta['mode'] != ret_meta['mode']: + raise ValueError(f'Mismatched mode between branches: \'{name}\'.') + if meta['has_cond'] != ret_meta['has_cond']: + raise ValueError(f'Mismatched cond between branches: \'{name}\'.') + return metas + + +def add_missing_branch_sows(branch_funs, branch_metadatas, metadata, tag): + """Wrap branch funs that may not have all sows in metadata, so that they do.""" + + def add_missing_sows(names, f, *args, **kwargs): + for name in names: + meta = metadata[name] + aval = meta['aval'] + default = jnp.zeros(aval.shape, aval.dtype) + if meta['mode'] != 'clobber' or not meta['has_cond']: + raise ValueError(f"Missing sow in branch: '{name}'") + _sow(default, tag=tag, name=name, mode='clobber', cond=False) + return f(*args, **kwargs) + for branch_metadata, fun in zip(branch_metadatas, branch_funs): + if len(branch_metadata) == len(metadata): + yield fun + else: + names = set(metadata) - set(branch_metadata) + yield functools.partial(add_missing_sows, names, fun) def _reap_cond_rule(trace, *tracers, branches, linear): @@ -964,10 +1084,15 @@ def _reap_cond_rule(trace, *tracers, branches, linear): branch_metadatas = tuple( _get_harvest_metadata(branch, settings, *ops_tracers) for branch in branches) - _check_branch_metadata(branch_metadatas) + metadatas = _combine_branch_metadata(branch_metadatas) branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) + branch_funs = tuple( + add_missing_branch_sows( + branch_funs, branch_metadatas, metadatas, settings.tag + ) + ) reaped_branches = tuple( - call_and_reap(f, **reap_settings) for f in branch_funs) + _call_and_reap(f, **reap_settings) for f in branch_funs) in_tree = tree_util.tree_structure(ops_avals) new_branch_jaxprs, consts, out_trees = ( lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access @@ -978,9 +1103,11 @@ def _reap_cond_rule(trace, *tracers, branches, linear): branches=tuple(new_branch_jaxprs), linear=(False,) * len(tuple(consts) + linear)) out = jax_util.safe_map(trace.pure, out) - out, reaps = tree_util.tree_unflatten(out_trees[0], out) - for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=branch_metadatas[0][k]['mode']) + out, reaps, conds = tree_util.tree_unflatten(out_trees[0], out) + for k, meta in metadatas.items(): + _sow( + reaps[k], name=k, tag=settings.tag, mode=meta['mode'], cond=conds[k] + ) return out @@ -1001,7 +1128,7 @@ def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr) - reaped_remat_fun = call_and_reap(remat_fun, **reap_settings) + reaped_remat_fun = _call_and_reap(remat_fun, **reap_settings) reap_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access reaped_remat_fun, tree_util.tree_structure(invals), tuple(t.aval for t in tracers)) @@ -1013,9 +1140,10 @@ def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, prevent_cse=prevent_cse, differentiated=differentiated) outvals = jax_util.safe_map(trace.pure, outvals) - out, reaps = tree_util.tree_unflatten(out_tree, outvals) + out, reaps, conds = tree_util.tree_unflatten(out_tree, outvals) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=reap_metadata[k]['mode']) + mode = reap_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, cond=conds[k]) return out @@ -1071,7 +1199,7 @@ def _reap_pjit_rule(trace, *tracers, **params): closed_jaxpr = params['jaxpr'] reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr) - reaped_pjit_fun = lu.wrap_init(call_and_reap(pjit_fun, **reap_settings)) + reaped_pjit_fun = lu.wrap_init(_call_and_reap(pjit_fun, **reap_settings)) in_tree = tree_util.tree_structure(invals) flat_fun, out_tree = api_util.flatten_fun_nokwargs(reaped_pjit_fun, in_tree) @@ -1088,9 +1216,10 @@ def _reap_pjit_rule(trace, *tracers, **params): outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) outvals = jax_util.safe_map(trace.pure, outvals) - out, reaps = tree_util.tree_unflatten(out_tree(), outvals) + out, reaps, conds = tree_util.tree_unflatten(out_tree(), outvals) for k, v in reaps.items(): - sow(v, name=k, tag=settings.tag, mode=reap_metadata[k]['mode']) + mode = reap_metadata[k]['mode'] + _sow(v, name=k, tag=settings.tag, mode=mode, cond=conds[k]) return out @@ -1112,14 +1241,16 @@ def __post_init__(self): def get_custom_rule(self, primitive): return plant_custom_rules.get(primitive) - def handle_sow(self, *values, name, tag, tree, mode): + def handle_sow(self, *values, name, tag, tree, mode, cond): """Returns the value stored in the plants dictionary.""" if name in self._already_planted: raise ValueError(f'Variable has already been planted: {name}') if name in self.plants: self._already_planted.add(name) return tree_util.tree_leaves(self.plants[name]) - return sow_p.bind(*values, name=name, tag=tag, mode=mode, tree=tree) + return sow_p.bind( + *values, name=name, tag=tag, mode=mode, tree=tree, cond=cond + ) def process_nest(self, trace, f, *tracers, scope, name, **params): return self.process_higher_order_primitive( @@ -1409,9 +1540,14 @@ def _plant_cond_rule(trace, *tracers, branches, linear): branch_metadatas = tuple( _get_harvest_metadata(branch, settings, *ops_tracers) for branch in branches) - _check_branch_metadata(branch_metadatas) - plants = context.plants + metadatas = _combine_branch_metadata(branch_metadatas) branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) + branch_funs = tuple( + add_missing_branch_sows( + branch_funs, branch_metadatas, metadatas, settings.tag + ) + ) + plants = context.plants planted_branches = tuple( functools.partial(plant(f, **plant_settings), plants) for f in branch_funs) diff --git a/oryx/core/interpreters/harvest_test.py b/oryx/core/interpreters/harvest_test.py index f5f8c54..efaa87a 100644 --- a/oryx/core/interpreters/harvest_test.py +++ b/oryx/core/interpreters/harvest_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -728,6 +728,33 @@ def f(init): self.assertListEqual(['x'], list(variables.keys())) np.testing.assert_allclose(variables['x'], true_out[-1]) + @parameterized.named_parameters(('scan', True), ('while_loop', False)) + def test_can_reap_and_plant_looped_values_in_default_clobber_mode( + self, static_length + ): + length = 5 + + def body(index, x): + x = x + index + x = jax.lax.switch( + index, + [ + lambda i=i: variable(x, name=f'x{i}', mode='default_clobber') + for i in range(length + 1) + ], + ) + return x + + def f(upper, init): + return lax.fori_loop(0, length if static_length else upper, body, init) + + out, variables = harvest_variables(f)(dict(x3=0.5), length, 1.) + np.testing.assert_allclose(out, 0.5 + 4) + default = 0. + self.assertDictEqual( + dict(x0=1., x1=2., x2=4., x4=out, x5=default), variables + ) + def test_non_clobber_mode_in_while_loop_should_error_with_reap_and_plant( self): @@ -844,11 +871,11 @@ def false_fun(x): return lax.cond(pred, true_fun, false_fun, x) with self.assertRaisesRegex( - ValueError, 'Mismatching number of `sow`s between branches.'): + ValueError, 'Missing sow in branch: \'y\''): reap_variables(f2)(True, 1.) with self.assertRaisesRegex( - ValueError, 'Mismatching number of `sow`s between branches.'): + ValueError, 'Missing sow in branch: \'y\''): plant_variables(f2)({}, True, 1.) def f3(pred, x): @@ -893,6 +920,29 @@ def false_fun(x): self.assertEqual(out, 6.) self.assertDictEqual(reaps, dict(x=3.)) + def test_can_reap_from_mismatching_branches_of_default_clobber_cond(self): + + def f(pred, x): + + @jax.jit + def true_fun(x): + x = variable(x, name='x', mode='default_clobber') + return x + 2. + + def false_fun(x): + x = variable(x + 2., name='y', mode='default_clobber') + return x + 3. + + return lax.cond(pred, true_fun, false_fun, x) + + out, reaps = call_and_reap_variables(f)(True, 1.) + self.assertEqual(out, 3.) + self.assertDictEqual(reaps, dict(x=1., y=0.)) + + out, reaps = call_and_reap_variables(f)(False, 1.) + self.assertEqual(out, 6.) + self.assertDictEqual(reaps, dict(y=3., x=0.)) + def test_can_plant_values_into_either_branch_of_cond(self): def f(pred, x): @@ -913,6 +963,27 @@ def false_fun(x): out = plant_variables(f)(dict(x=4.), False, 1.) self.assertEqual(out, 7.) + def test_can_plant_into_mismatching_branches_of_default_clobber_cond(self): + + def f(pred, x): + + @jax.jit + def true_fun(x): + x = variable(x, name='x', mode='default_clobber') + return x + 2. + + def false_fun(x): + x = variable(x + 2., name='y', mode='default_clobber') + return x + 3. + + return lax.cond(pred, true_fun, false_fun, x) + + out = plant_variables(f)(dict(x=4.), True, 1.) + self.assertEqual(out, 6.) + + out = plant_variables(f)(dict(y=4.), False, 1.) + self.assertEqual(out, 7.) + def test_can_reap_values_from_any_branch_in_switch(self): def f(index, x): diff --git a/oryx/core/interpreters/inverse/__init__.py b/oryx/core/interpreters/inverse/__init__.py index 947a04e..63a5aa8 100644 --- a/oryx/core/interpreters/inverse/__init__.py +++ b/oryx/core/interpreters/inverse/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/bijector_extensions.py b/oryx/core/interpreters/inverse/bijector_extensions.py index 212cb28..0440ece 100644 --- a/oryx/core/interpreters/inverse/bijector_extensions.py +++ b/oryx/core/interpreters/inverse/bijector_extensions.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/bijector_extensions_test.py b/oryx/core/interpreters/inverse/bijector_extensions_test.py index 04fd0b3..9a33bc2 100644 --- a/oryx/core/interpreters/inverse/bijector_extensions_test.py +++ b/oryx/core/interpreters/inverse/bijector_extensions_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/core.py b/oryx/core/interpreters/inverse/core.py index 0407c74..eecfcbe 100644 --- a/oryx/core/interpreters/inverse/core.py +++ b/oryx/core/interpreters/inverse/core.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/custom_inverse.py b/oryx/core/interpreters/inverse/custom_inverse.py index 3d751e0..4bb6774 100644 --- a/oryx/core/interpreters/inverse/custom_inverse.py +++ b/oryx/core/interpreters/inverse/custom_inverse.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/custom_inverse_test.py b/oryx/core/interpreters/inverse/custom_inverse_test.py index 08ca2cf..62397e6 100644 --- a/oryx/core/interpreters/inverse/custom_inverse_test.py +++ b/oryx/core/interpreters/inverse/custom_inverse_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/inverse_test.py b/oryx/core/interpreters/inverse/inverse_test.py index 308fec9..cde8af0 100644 --- a/oryx/core/interpreters/inverse/inverse_test.py +++ b/oryx/core/interpreters/inverse/inverse_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/rules.py b/oryx/core/interpreters/inverse/rules.py index b36624e..2071fd9 100644 --- a/oryx/core/interpreters/inverse/rules.py +++ b/oryx/core/interpreters/inverse/rules.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/slice.py b/oryx/core/interpreters/inverse/slice.py index 42d4992..efe1c20 100644 --- a/oryx/core/interpreters/inverse/slice.py +++ b/oryx/core/interpreters/inverse/slice.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/inverse/slice_test.py b/oryx/core/interpreters/inverse/slice_test.py index 24d10c5..607044b 100644 --- a/oryx/core/interpreters/inverse/slice_test.py +++ b/oryx/core/interpreters/inverse/slice_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/log_prob.py b/oryx/core/interpreters/log_prob.py index 26ac171..f1f659b 100644 --- a/oryx/core/interpreters/log_prob.py +++ b/oryx/core/interpreters/log_prob.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/log_prob_test.py b/oryx/core/interpreters/log_prob_test.py index b73cabd..aa18fb9 100644 --- a/oryx/core/interpreters/log_prob_test.py +++ b/oryx/core/interpreters/log_prob_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/propagate.py b/oryx/core/interpreters/propagate.py index 7a5e6cb..1a91075 100644 --- a/oryx/core/interpreters/propagate.py +++ b/oryx/core/interpreters/propagate.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/interpreters/propagate_test.py b/oryx/core/interpreters/propagate_test.py index 279bab7..dc65cc3 100644 --- a/oryx/core/interpreters/propagate_test.py +++ b/oryx/core/interpreters/propagate_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/kwargs_util.py b/oryx/core/kwargs_util.py index acf8af4..4efd499 100644 --- a/oryx/core/kwargs_util.py +++ b/oryx/core/kwargs_util.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/kwargs_util_test.py b/oryx/core/kwargs_util_test.py index 8b46d88..6fac2c2 100644 --- a/oryx/core/kwargs_util_test.py +++ b/oryx/core/kwargs_util_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/ppl/__init__.py b/oryx/core/ppl/__init__.py index f4af52a..295d487 100644 --- a/oryx/core/ppl/__init__.py +++ b/oryx/core/ppl/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/ppl/effect_handler.py b/oryx/core/ppl/effect_handler.py index 36fc916..905d5d4 100644 --- a/oryx/core/ppl/effect_handler.py +++ b/oryx/core/ppl/effect_handler.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/ppl/effect_handler_test.py b/oryx/core/ppl/effect_handler_test.py index fb00502..7a26cd8 100644 --- a/oryx/core/ppl/effect_handler_test.py +++ b/oryx/core/ppl/effect_handler_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/ppl/plate_util.py b/oryx/core/ppl/plate_util.py index fbb655f..c35c65e 100644 --- a/oryx/core/ppl/plate_util.py +++ b/oryx/core/ppl/plate_util.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/ppl/transformations.py b/oryx/core/ppl/transformations.py index c3e2a61..ad73d34 100644 --- a/oryx/core/ppl/transformations.py +++ b/oryx/core/ppl/transformations.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/ppl/transformations_test.py b/oryx/core/ppl/transformations_test.py index 7bfc8f9..5278983 100644 --- a/oryx/core/ppl/transformations_test.py +++ b/oryx/core/ppl/transformations_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/primitive.py b/oryx/core/primitive.py index 700cace..2ae0b2a 100644 --- a/oryx/core/primitive.py +++ b/oryx/core/primitive.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/primitive_test.py b/oryx/core/primitive_test.py index 447a418..0621dd4 100644 --- a/oryx/core/primitive_test.py +++ b/oryx/core/primitive_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/pytree.py b/oryx/core/pytree.py index ec991df..1e7e367 100644 --- a/oryx/core/pytree.py +++ b/oryx/core/pytree.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/serialize.py b/oryx/core/serialize.py index 99f4dad..52717a9 100644 --- a/oryx/core/serialize.py +++ b/oryx/core/serialize.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/serialize_test.py b/oryx/core/serialize_test.py index 5425a65..39b7c82 100644 --- a/oryx/core/serialize_test.py +++ b/oryx/core/serialize_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/__init__.py b/oryx/core/state/__init__.py index ecbcda6..3d33f02 100644 --- a/oryx/core/state/__init__.py +++ b/oryx/core/state/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/api.py b/oryx/core/state/api.py index 160719f..f301463 100644 --- a/oryx/core/state/api.py +++ b/oryx/core/state/api.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/function.py b/oryx/core/state/function.py index 101fce7..394825d 100644 --- a/oryx/core/state/function.py +++ b/oryx/core/state/function.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/function_test.py b/oryx/core/state/function_test.py index 23f11a5..6543e20 100644 --- a/oryx/core/state/function_test.py +++ b/oryx/core/state/function_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/module.py b/oryx/core/state/module.py index cd70903..df4d244 100644 --- a/oryx/core/state/module.py +++ b/oryx/core/state/module.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/registrations.py b/oryx/core/state/registrations.py index 33df07f..0bf99c3 100644 --- a/oryx/core/state/registrations.py +++ b/oryx/core/state/registrations.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/state/registrations_test.py b/oryx/core/state/registrations_test.py index b5e51c8..64184c2 100644 --- a/oryx/core/state/registrations_test.py +++ b/oryx/core/state/registrations_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/core/trace_util.py b/oryx/core/trace_util.py index 095d9c6..0d55533 100644 --- a/oryx/core/trace_util.py +++ b/oryx/core/trace_util.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/distributions/__init__.py b/oryx/distributions/__init__.py index 067a689..9f54257 100644 --- a/oryx/distributions/__init__.py +++ b/oryx/distributions/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/distributions/distribution_extensions.py b/oryx/distributions/distribution_extensions.py index e184372..1df2c35 100644 --- a/oryx/distributions/distribution_extensions.py +++ b/oryx/distributions/distribution_extensions.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/distributions/distribution_extensions_test.py b/oryx/distributions/distribution_extensions_test.py index ca03e50..f27e4fa 100644 --- a/oryx/distributions/distribution_extensions_test.py +++ b/oryx/distributions/distribution_extensions_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/__init__.py b/oryx/experimental/__init__.py index 6fbc3be..aaddfc6 100644 --- a/oryx/experimental/__init__.py +++ b/oryx/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/autoconj/addn.py b/oryx/experimental/autoconj/addn.py index bdb89b4..8a480dd 100644 --- a/oryx/experimental/autoconj/addn.py +++ b/oryx/experimental/autoconj/addn.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/autoconj/addn_test.py b/oryx/experimental/autoconj/addn_test.py index f5b1ca5..2e72133 100644 --- a/oryx/experimental/autoconj/addn_test.py +++ b/oryx/experimental/autoconj/addn_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/autoconj/canonicalize.py b/oryx/experimental/autoconj/canonicalize.py index e9b02ab..df8d9c5 100644 --- a/oryx/experimental/autoconj/canonicalize.py +++ b/oryx/experimental/autoconj/canonicalize.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/autoconj/canonicalize_test.py b/oryx/experimental/autoconj/canonicalize_test.py index 747ab30..2893594 100644 --- a/oryx/experimental/autoconj/canonicalize_test.py +++ b/oryx/experimental/autoconj/canonicalize_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/autoconj/einsum.py b/oryx/experimental/autoconj/einsum.py index 6dd824b..0ac7dbe 100644 --- a/oryx/experimental/autoconj/einsum.py +++ b/oryx/experimental/autoconj/einsum.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/autoconj/einsum_test.py b/oryx/experimental/autoconj/einsum_test.py index 470cfa4..8c79cc2 100644 --- a/oryx/experimental/autoconj/einsum_test.py +++ b/oryx/experimental/autoconj/einsum_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/__init__.py b/oryx/experimental/matching/__init__.py index 80f6e77..070c433 100644 --- a/oryx/experimental/matching/__init__.py +++ b/oryx/experimental/matching/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/jax_rewrite.py b/oryx/experimental/matching/jax_rewrite.py index ec5ecdc..9a22f3f 100644 --- a/oryx/experimental/matching/jax_rewrite.py +++ b/oryx/experimental/matching/jax_rewrite.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/jax_rewrite_test.py b/oryx/experimental/matching/jax_rewrite_test.py index fb2b39d..bebe5a0 100644 --- a/oryx/experimental/matching/jax_rewrite_test.py +++ b/oryx/experimental/matching/jax_rewrite_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/matcher.py b/oryx/experimental/matching/matcher.py index 1a09bca..c750d18 100644 --- a/oryx/experimental/matching/matcher.py +++ b/oryx/experimental/matching/matcher.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/matcher_test.py b/oryx/experimental/matching/matcher_test.py index 1c3082f..b1331b4 100644 --- a/oryx/experimental/matching/matcher_test.py +++ b/oryx/experimental/matching/matcher_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/rules.py b/oryx/experimental/matching/rules.py index ad76027..d17505b 100644 --- a/oryx/experimental/matching/rules.py +++ b/oryx/experimental/matching/rules.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/matching/rules_test.py b/oryx/experimental/matching/rules_test.py index f97003d..c46b47e 100644 --- a/oryx/experimental/matching/rules_test.py +++ b/oryx/experimental/matching/rules_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/mcmc/__init__.py b/oryx/experimental/mcmc/__init__.py index 6a94096..ed03676 100644 --- a/oryx/experimental/mcmc/__init__.py +++ b/oryx/experimental/mcmc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/mcmc/kernels.py b/oryx/experimental/mcmc/kernels.py index 1d9c0eb..6aa6340 100644 --- a/oryx/experimental/mcmc/kernels.py +++ b/oryx/experimental/mcmc/kernels.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/mcmc/kernels_test.py b/oryx/experimental/mcmc/kernels_test.py index 88962e1..14ae527 100644 --- a/oryx/experimental/mcmc/kernels_test.py +++ b/oryx/experimental/mcmc/kernels_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/mcmc/utils.py b/oryx/experimental/mcmc/utils.py index 22b569f..dcefe97 100644 --- a/oryx/experimental/mcmc/utils.py +++ b/oryx/experimental/mcmc/utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/__init__.py b/oryx/experimental/nn/__init__.py index 75da02f..37d7e0a 100644 --- a/oryx/experimental/nn/__init__.py +++ b/oryx/experimental/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/base.py b/oryx/experimental/nn/base.py index fde83a9..1dbbddb 100644 --- a/oryx/experimental/nn/base.py +++ b/oryx/experimental/nn/base.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/base_test.py b/oryx/experimental/nn/base_test.py index 2e86d48..6f588ef 100644 --- a/oryx/experimental/nn/base_test.py +++ b/oryx/experimental/nn/base_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/combinator.py b/oryx/experimental/nn/combinator.py index 37c3f2c..640cab3 100644 --- a/oryx/experimental/nn/combinator.py +++ b/oryx/experimental/nn/combinator.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/combinator_test.py b/oryx/experimental/nn/combinator_test.py index 97575f2..b3ebf37 100644 --- a/oryx/experimental/nn/combinator_test.py +++ b/oryx/experimental/nn/combinator_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/convolution.py b/oryx/experimental/nn/convolution.py index 99b43a4..df9f235 100644 --- a/oryx/experimental/nn/convolution.py +++ b/oryx/experimental/nn/convolution.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/convolution_test.py b/oryx/experimental/nn/convolution_test.py index df35cd5..38ac820 100644 --- a/oryx/experimental/nn/convolution_test.py +++ b/oryx/experimental/nn/convolution_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/core.py b/oryx/experimental/nn/core.py index 0dce0f2..cf3b160 100644 --- a/oryx/experimental/nn/core.py +++ b/oryx/experimental/nn/core.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/core_test.py b/oryx/experimental/nn/core_test.py index 495c8be..62d80c6 100644 --- a/oryx/experimental/nn/core_test.py +++ b/oryx/experimental/nn/core_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/function.py b/oryx/experimental/nn/function.py index 1209599..5a6d49f 100644 --- a/oryx/experimental/nn/function.py +++ b/oryx/experimental/nn/function.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/function_test.py b/oryx/experimental/nn/function_test.py index 17dbad3..770670a 100644 --- a/oryx/experimental/nn/function_test.py +++ b/oryx/experimental/nn/function_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/normalization.py b/oryx/experimental/nn/normalization.py index 2cbcbe3..26eedbd 100644 --- a/oryx/experimental/nn/normalization.py +++ b/oryx/experimental/nn/normalization.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/normalization_test.py b/oryx/experimental/nn/normalization_test.py index cb4fe63..59937cf 100644 --- a/oryx/experimental/nn/normalization_test.py +++ b/oryx/experimental/nn/normalization_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/pooling.py b/oryx/experimental/nn/pooling.py index 47b73ad..faaa090 100644 --- a/oryx/experimental/nn/pooling.py +++ b/oryx/experimental/nn/pooling.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/pooling_test.py b/oryx/experimental/nn/pooling_test.py index 9e33c88..d9b12cf 100644 --- a/oryx/experimental/nn/pooling_test.py +++ b/oryx/experimental/nn/pooling_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/reshape.py b/oryx/experimental/nn/reshape.py index de9b53e..1fd1fee 100644 --- a/oryx/experimental/nn/reshape.py +++ b/oryx/experimental/nn/reshape.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/nn/reshape_test.py b/oryx/experimental/nn/reshape_test.py index 0d4ffa7..d03a259 100644 --- a/oryx/experimental/nn/reshape_test.py +++ b/oryx/experimental/nn/reshape_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/optimizers/__init__.py b/oryx/experimental/optimizers/__init__.py index 4200654..beb96b9 100644 --- a/oryx/experimental/optimizers/__init__.py +++ b/oryx/experimental/optimizers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/optimizers/optix.py b/oryx/experimental/optimizers/optix.py index f65a4a3..d43c014 100644 --- a/oryx/experimental/optimizers/optix.py +++ b/oryx/experimental/optimizers/optix.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/experimental/optimizers/optix_test.py b/oryx/experimental/optimizers/optix_test.py index e8b1e4b..213c6dd 100644 --- a/oryx/experimental/optimizers/optix_test.py +++ b/oryx/experimental/optimizers/optix_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/internal/__init__.py b/oryx/internal/__init__.py index ff7ca34..8dec396 100644 --- a/oryx/internal/__init__.py +++ b/oryx/internal/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/internal/test_util.py b/oryx/internal/test_util.py index 2c3f49c..91906a0 100644 --- a/oryx/internal/test_util.py +++ b/oryx/internal/test_util.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/tools/build_oryx_docs.py b/oryx/tools/build_oryx_docs.py index 8af823a..7e909cf 100644 --- a/oryx/tools/build_oryx_docs.py +++ b/oryx/tools/build_oryx_docs.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/util/__init__.py b/oryx/util/__init__.py index e067c44..bedfcfc 100644 --- a/oryx/util/__init__.py +++ b/oryx/util/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/util/summary.py b/oryx/util/summary.py index 2ae7ff0..ea8c86e 100644 --- a/oryx/util/summary.py +++ b/oryx/util/summary.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/util/summary_test.py b/oryx/util/summary_test.py index d6e2e96..449f605 100644 --- a/oryx/util/summary_test.py +++ b/oryx/util/summary_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/oryx/version.py b/oryx/version.py index 5043225..9b5f38a 100644 --- a/oryx/version.py +++ b/oryx/version.py @@ -1,4 +1,4 @@ -# Copyright 2023 The oryx Authors. +# Copyright 2024 The oryx Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.