Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster unroll: do not propagate params in scan state #61

Merged
merged 1 commit into from
Oct 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions wax/unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class UnrollTransformedWithState(NamedTuple):


class ScanState(NamedTuple):
params: Any
fun_state: Any
rng: jnp.ndarray

Expand All @@ -62,23 +61,25 @@ def unroll_transform_with_state(
tfunc = cast(TransformedWithState, fun)
del fun

def scan_f(scan_state, inputs):
params, state, rng = scan_state
args_step, kwargs_step = inputs
if rng is not None:
(rng, sub_rng) = jax.random.split(rng)
else:
sub_rng = None
outputs, state = tfunc.apply(params, state, sub_rng, *args_step, **kwargs_step)
return ScanState(params, state, rng), outputs

def init(rng: jnp.ndarray, *args, **kwargs):
def init_fn(rng: jnp.ndarray, *args, **kwargs):
xs = (args, kwargs)
args_0, kwargs_0 = tree_map(lambda x: x[0], xs)
params, state = tfunc.init(rng, *args_0, **kwargs_0)
return params, state

def apply_fn(params: Any, state: Any, rng: jnp.ndarray, *args, **kwargs):
def scan_f(scan_state, inputs):
state, rng = scan_state
args_step, kwargs_step = inputs
if rng is not None:
(rng, sub_rng) = jax.random.split(rng)
else:
sub_rng = None
outputs, state = tfunc.apply(
params, state, sub_rng, *args_step, **kwargs_step
)
return ScanState(state, rng), outputs

xs = (args, kwargs)

if skip_first:
Expand All @@ -88,13 +89,11 @@ def apply_fn(params: Any, state: Any, rng: jnp.ndarray, *args, **kwargs):
scan = jax.lax.scan
else:
scan = partial(static_scan, pbar=pbar)
scan_state, output_sequence = scan(
scan_f, init=ScanState(params, state, rng), xs=xs
)
_, final_state, _ = scan_state
scan_state, output_sequence = scan(scan_f, init=ScanState(state, rng), xs=xs)
final_state, final_rng = scan_state
return output_sequence, final_state

return UnrollTransformedWithState(init, apply_fn)
return UnrollTransformedWithState(init_fn, apply_fn)


def unroll(
Expand Down