Fix selective activation checkpointing with random ops (fairinternal/…
fmassa authored and xFormers Bot committed Mar 5, 2024
1 parent 051b56a commit 5c8b7c9
Showing 2 changed files with 34 additions and 4 deletions.
8 changes: 7 additions & 1 deletion tests/test_checkpoint.py
@@ -238,7 +238,13 @@ def test_optimize_runtime_with_given_memory(max_memory, optimal_soln):
     memory = torch.tensor([x[2] for x in data], dtype=torch.float64)
 
     out = _optimize_runtime_with_given_memory(
-        memory, runtimes, max_memory, view_like_ops, inplace_ops, rand_ops
+        memory,
+        runtimes,
+        max_memory,
+        view_like_ops,
+        inplace_ops,
+        rand_ops,
+        force_store_random=False,
     )
     torch.testing.assert_close(optimal_soln, out)

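Note: for context, here is a hedged sketch of how the updated helper can be driven with toy inputs, mirroring the test above. The numbers are invented and the private import path is taken from the test file, so treat this as an illustration rather than part of the commit:

```python
import torch

# Private helper; the import path follows tests/test_checkpoint.py.
from xformers.checkpoint import _optimize_runtime_with_given_memory

# Three ops with (runtime, memory) costs; no view-like, in-place, or random ops.
runtimes = torch.tensor([1.0, 3.0, 2.0], dtype=torch.float64)
memory = torch.tensor([4.0, 2.0, 1.0], dtype=torch.float64)

out = _optimize_runtime_with_given_memory(
    memory=memory,
    runtimes=runtimes,
    max_memory=5.0,
    view_like_ops=[],
    inplace_ops=[],
    random_ops=[],
    force_store_random=False,
)
print(out)  # 0/1 tensor: 1 means "store this op's output", 0 means "recompute it"
```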
30 changes: 27 additions & 3 deletions xformers/checkpoint.py
@@ -87,7 +87,7 @@ def _get_default_policy(allow_list=None):
     if allow_list is None:
         allow_list = _default_allow_list
 
-    def _default_policy(func, *args, **kwargs):
+    def _default_policy(mode, func, *args, **kwargs):
         return str(func) in allow_list
 
     return _default_policy
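Note: the policy callback now receives the checkpoint mode as its first argument, as the new `_default_policy(mode, func, *args, **kwargs)` signature shows. A minimal sketch of a custom policy with the matching signature; the choice of ops to store is illustrative only:

```python
import torch

def save_matmuls_policy(mode, func, *args, **kwargs):
    # Returning True stores the op's output; False recomputes it in backward.
    # Here: keep expensive matmuls, recompute everything else.
    return func in (torch.ops.aten.mm.default, torch.ops.aten.bmm.default)
```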
@@ -183,6 +183,7 @@ def checkpoint(
         function,
         *args,
         use_reentrant=False,
+        preserve_rng_state=preserve_rng_state,
         context_fn=functools.partial(selective_checkpoint_context_fn, policy_fn),
         **kwargs,
     )
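Note: the fix forwards `preserve_rng_state` to the underlying `torch.utils.checkpoint.checkpoint` call. When enabled, the RNG state is stashed in forward and restored before recomputation, so random ops such as dropout replay the same mask. A hedged illustration against the upstream API:

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    return torch.nn.functional.dropout(x, p=0.5)

x = torch.randn(4, 4, requires_grad=True)
# With preserve_rng_state=True (the default), the dropout mask used during
# backward recomputation matches the one sampled in the forward pass.
out = checkpoint(fn, x, use_reentrant=False, preserve_rng_state=True)
out.sum().backward()
```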
@@ -232,6 +233,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         curr_idx, output_ids, inplace_info = self._get_inplace_metadata(func, out)
         is_view_like = is_view_fn(func) or is_inplace_view_fn(func)
         is_rand_op = torch.Tag.nondeterministic_seeded in func.tags
+        # sdpa is tagged as nondeterministic-seeded, but it is deterministic
+        # when no dropout is applied
+        if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
+            is_rand_op = kwargs.get("dropout_p", 0) != 0
 
         # get runtime info of func
         torch.cuda.synchronize()
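Note: the dispatch-time check above is driven by operator tags. A small sketch of what that lookup does, verified against recent PyTorch releases (tag coverage can vary by version):

```python
import torch

# Ops whose output depends on the RNG state carry this tag.
rand_tag = torch.Tag.nondeterministic_seeded
print(rand_tag in torch.ops.aten.native_dropout.default.tags)  # True
print(rand_tag in torch.ops.aten.mm.default.tags)              # False
```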
@@ -346,13 +351,17 @@ def get_optimal_checkpoint_policy(function, *args, memory_budget: float) -> Call
 
     max_memory = memory_budget * memory.sum().item()
 
+    # workaround to fix https://github.com/pytorch/pytorch/issues/121212
+    force_store_random = all([not isinstance(x, torch.Tensor) for x in args])
+
     optim_output = _optimize_runtime_with_given_memory(
         memory=memory,
         runtimes=runtimes,
         max_memory=max_memory,
         view_like_ops=view_like_ops,
         inplace_ops=inplace_ops,
         random_ops=rand_ops,
+        force_store_random=force_store_random,
     )
     return _OptimalPolicy(optim_output=optim_output)
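Note: a hedged end-to-end sketch of the public flow. It assumes, as the hunks here suggest, that `xformers.checkpoint.checkpoint` accepts a `policy_fn` keyword, and it needs a CUDA device because the policy search times each op around `torch.cuda.synchronize()`:

```python
import torch
from xformers.checkpoint import checkpoint, get_optimal_checkpoint_policy

def block(x):
    # dropout makes this function contain a random op
    return torch.nn.functional.dropout(torch.relu(x @ x), p=0.1)

x = torch.randn(256, 256, device="cuda", requires_grad=True)

# Search for the best store/recompute split under 50% of the activation memory.
policy_fn = get_optimal_checkpoint_policy(block, x, memory_budget=0.5)
out = checkpoint(block, x, policy_fn=policy_fn)
out.sum().backward()
```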

@@ -364,6 +373,7 @@ def _optimize_runtime_with_given_memory(
     view_like_ops: List[int],
     inplace_ops: List[Tuple[int, ...]],
     random_ops: List[int],
+    force_store_random: bool,
 ) -> torch.Tensor:
     """
     Given a list of operator names, their corresponding runtimes, and the maximum amount of memory available,
@@ -379,6 +389,7 @@
             This will be used to add the constraint that in-place ops need to either be
             stored in memory with the previous op, or recomputed with the previous op.
         random_ops (List[int]): Indices of the random ops, which will always be recomputed.
+        force_store_random (bool): force random ops to always be stored (instead of recomputed)
     """
 
     c = -runtimes  # type: ignore[operator]
@@ -406,16 +417,26 @@
         A[op] = 1
         constraints.append(LinearConstraint(A=A, lb=1, ub=1))

-    # random ops should always be recomputed
+    # ideally, always recompute random ops
+    # in practice, due to a bug (https://github.com/pytorch/pytorch/issues/121212),
+    # we sometimes need to store them to avoid correctness issues
     for i in random_ops:
         A = torch.zeros_like(c)
         A[i] = 1
-        constraints.append(LinearConstraint(A=A, lb=0, ub=0))
+        val = int(force_store_random)
+        constraints.append(LinearConstraint(A=A, lb=val, ub=val))

     integrality = torch.ones_like(c)
     res = milp(
         c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1)
     )
+    if not res.success:
+        raise ValueError(
+            "The problem is infeasible, probably due to a change in xformers "
+            "that makes random ops always be stored. Try passing a larger "
+            "memory_budget. This will be fixed once "
+            "https://github.com/pytorch/pytorch/issues/121212 is solved."
+        )
     x = torch.from_numpy(res.x)
     return x
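Note: the equality constraint added above pins each random op's 0/1 indicator to `val` (0 = always recompute, 1 = always store). A self-contained toy of the same mechanism in SciPy; the operator costs are invented for illustration:

```python
import numpy as np
from scipy.optimize import Bounds, LinearConstraint, milp

# x[i] = 1 means "store op i's output"; maximize stored runtime
# (minimize -runtimes @ x) subject to a memory budget.
runtimes = np.array([1.0, 3.0, 2.0])
memory = np.array([4.0, 2.0, 1.0])
constraints = [LinearConstraint(A=memory, lb=0, ub=5.0)]

# Pin op 2 (a "random" op): lb == ub forces x[2] to that exact value.
force_store_random = False
val = int(force_store_random)
A = np.zeros_like(runtimes)
A[2] = 1
constraints.append(LinearConstraint(A=A, lb=val, ub=val))

res = milp(c=-runtimes, constraints=constraints,
           integrality=np.ones_like(runtimes), bounds=Bounds(0, 1))
print(res.x)  # e.g. [0. 1. 0.]: store op 1, recompute ops 0 and 2
```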

@@ -448,6 +469,9 @@ def __init__(self, mod, memory_budget=None, policy_fn=None):
 
     @torch.compiler.disable
     def _get_policy_fn(self, *args, **kwargs):
+        if not torch.is_grad_enabled():
+            # no need to compute a policy as it won't be used
+            return []
         # if policy is not specified, initialize policy for a given memory budget
         with torch.random.fork_rng():
             policy_fn = get_optimal_checkpoint_policy(
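Note: two details in this last hunk are easy to miss. The early return skips the policy search when autograd is disabled (with no backward pass, the policy would never be consulted), and `torch.random.fork_rng()` keeps the profiling run from disturbing the model's RNG stream. A quick demonstration of the latter:

```python
import torch

torch.manual_seed(0)
with torch.random.fork_rng():
    torch.rand(100)  # consume RNG state inside the fork
a = torch.rand(3)

torch.manual_seed(0)
b = torch.rand(3)
print(torch.equal(a, b))  # True: the forked RNG state was restored on exit
```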
