Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[key reuse] fix scan single-key consumption issue #19634

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions jax/experimental/key_reuse/_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,20 @@ def _scan_key_type_signature(eqn, args_consumed):

# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed:\n"
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
"because key constants are repeatedly consumed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")

# scan carry should only consume keys that are sourced on output.
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks
if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)}
carry_sources = {s.idx: s.mask for s in signature.sources
if 0 <= s.idx < num_carry and np.any(s.mask)}
if not set(carry_sinks).issubset(set(carry_sources)): # TODO(jakevdp): check that masks match
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
"because consumed inputs don't match sourced outputs:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
Expand Down
14 changes: 9 additions & 5 deletions jax/experimental/key_reuse/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,20 @@ def _scan_key_type_signature(eqn, args_consumed):

# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
"because key constants are repeatedly consumed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")

# scan carry should only consume keys that are sourced on output.
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry}
carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry}
if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match
raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed:\n"
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks
if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)}
carry_sources = {s.idx: s.mask for s in signature.sources
if 0 <= s.idx < num_carry and np.any(s.mask)}
if not set(carry_sinks).issubset(set(carry_sources)):
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
"because consumed inputs don't match sourced outputs:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
Expand Down
7 changes: 7 additions & 0 deletions tests/key_reuse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,13 @@ def f_scan_over_keys(key):
return jax.lax.map(jax.random.bits, keys)
self.check_key_reuse(f_scan_over_keys, jax.random.key(0))

def test_scan_consume_one(self):
def f_scan_over_keys(*keys):
def body_func(keys, x):
return tuple(jax.random.split(keys[0])), x
return jax.lax.scan(body_func, keys, xs=jnp.arange(10))
self.check_key_reuse(f_scan_over_keys, jax.random.key(0), jax.random.key(1))

def test_vmap(self):
@jax.vmap
def f_good(seed):
Expand Down
Loading