diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index d31b13dd75..fe4d066fbe 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -53,6 +53,7 @@
 )
 
 import torch
+import torch.utils.checkpoint
 from thunder.core.proxies import (
     DistParallelType,
     proxy,
@@ -617,6 +618,231 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
     return res
 
 
+from torch.utils.checkpoint import noop_context_fn
+
+
+@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint)
+@register_general_jit_lookaside(torch.ops.higher_order.tag_activation_checkpoint)
+def _general_jit_torch_checkpoint_lookaside(
+    function: Callable,
+    *args,
+    # use_reentrant=None,
+    # context_fn=noop_context_fn,
+    # determinism_check="default",
+    # debug=False,
+    **kwargs: Any,
+):
+    """
+    Lookaside for `torch.utils.checkpoint.checkpoint` and its Dynamo form,
+    `torch.ops.higher_order.tag_activation_checkpoint`.
+
+    The checkpointed `function` may call PyTorch functions that are not yet
+    translated to Thunder, so it is first interpreted to build a Thunder trace.
+    That trace is wrapped into an `activation_checkpoint` symbol, for which an
+    augmented forward rule (saving only the function and its inputs) is
+    registered; the corresponding backward rule, which recomputes the function
+    through its VJP, is still work in progress (see the TODO below).
+
+    Args:
+        function: The function to be checkpointed.
+        args: Arguments to the function.
+        kwargs: Keyword arguments to the function.
+
+    Returns:
+        The wrapped result of calling the `activation_checkpoint` symbol with
+        the preprocessed `function` and its arguments.
+    """
+    # from thunder.torch import checkpoint
+    from thunder.core.baseutils import check, sequencify
+    from thunder.core.transforms import augmented_forward_impls, backward_impls, VJPDual
+
+    jit_ctx: JitCtx = get_jit_ctx()
+
+    # Construct the computation trace (trace_of_checkpoint) and the forward symbol (checkpoint_fwd_sym).
+    jit_ctx.computation_trace.push_scope([])
+    func = unwrap(function)
+    result = _interpret_call(func, *args, **kwargs)
+
+    if result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return result
+
+    bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
+    unwrapped_result = unwrap(result)
+    trace_of_checkpoint = TraceCtx()
+    for bsym in bsyms:
+        trace_of_checkpoint.add_bound_symbol(bsym)
+    with tracectx(trace_of_checkpoint):
+        prims.python_return(unwrapped_result)
+
+    unwrapped_args = tree_map(lambda a: unwrap(a), args)
+
+    si = SigInfo("activation_checkpoint")
+    si.args.append(("function", None))
+    for a in unwrapped_args:
+        if isinstance(a, Proxy):
+            si.args.append((a.name, None))
+        else:
+            pa = proxy(a)
+            si.args.append((pa.name, None))
+    trace_of_checkpoint._siginfo = si
+    trace_of_checkpoint.args = (func, *unwrapped_args)
+
+    @wraps(trace_of_checkpoint.python_callable())
+    def core_of_forward(f, *args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(trace_of_checkpoint, f, *args, **kwargs)
+
+    def bind_postprocess(bsym):
+        bsym._call_ctx = {}
+
+    checkpoint_fwd_sym = Symbol(
+        name="activation_checkpoint",
+        id="activation_checkpoint",
+        meta=core_of_forward,
+        _bind_postprocess=bind_postprocess,
+    )
+    # checkpoint_fwd_sym = jit_ctx.ad_hoc_executor.register_operator(
+    #     "activation_checkpoint",
+    #     like=core_of_forward,
+    #     bind_postprocess=bind_postprocess,
+    # )
+    unwrapped_forward_result = checkpoint_fwd_sym(func, *unwrapped_args)
+    # Wrap the forward result so the interpreter can keep tracking provenance.
+    forward_result = wrap(
+        unwrapped_forward_result,
+        provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[function.provenance, result.provenance]),
+    )
+    # jit_ctx.ad_hoc_executor.register_implementation(checkpoint_fwd_sym, execution_transform=core_of_forward)
+    thunder.executors.torchex._register_implementation(
+        checkpoint_fwd_sym, core_of_forward, checker=thunder.executors.torchex._always_executable
+    )
+
+    # Construct the augmented forward trace (trace_of_augmented_fwd) and the augmented forward rule.
+    augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
+        tuple(sequencify(unwrapped_result)),
+        ((func, *sequencify(unwrapped_args)), {}),
+    )
+    trace_of_augmented_fwd = TraceCtx()
+    for bsym in bsyms:
+        trace_of_augmented_fwd.add_bound_symbol(bsym)
+    with tracectx(trace_of_augmented_fwd):
+        prims.python_return(augmented_bsym_output)
+    si = SigInfo(checkpoint_fwd_sym.name)
+    si.args.append(("function", None))
+    for a in unwrapped_args:
+        if isinstance(a, Proxy):
+            si.args.append((a.name, None))
+        else:
+            pa = proxy(a)
+            si.args.append((pa.name, None))
+    trace_of_augmented_fwd._siginfo = si
+    # TODO: support kwargs
+    trace_of_augmented_fwd.args = (func, *unwrapped_args)
+
+    @wraps(trace_of_augmented_fwd.python_callable())
+    def core_of_augmented_forward(f, *args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, f, *args, **kwargs)
+
+    @wraps(core_of_augmented_forward)
+    def augmented_custom_forward_rule(f, *args, **kwargs):
+        primal, residuals = core_of_augmented_forward(f, *args, **kwargs)
+        return VJPDual(primal=primal, residuals=residuals)
+
+    augmented_forward_impls[checkpoint_fwd_sym.name] = augmented_custom_forward_rule
+
+    # Construct the backward trace. NOTE: this part does not work yet; see the TODO below.
+    from thunder.core.transforms import vjp
+
+    def checkpoint_backward(
+        function,
+        args,
+        kwargs,
+        *grad_outputs,
+    ):
+        result, grad = vjp(function)(args, grad_outputs, **kwargs)
+        # NOTE: only the input gradients are returned, not the recomputed result.
+        return grad
+
+    grads = tree_map(
+        lambda a: a.replace_name(f"grad_{a.name}"),
+        sequencify(unwrapped_forward_result),
+    )
+    trace_of_backward = TraceCtx()
+    bwd_si = SigInfo(f"{checkpoint_fwd_sym.name}_backward")
+    bwd_si.args.append(("function", None))
+    for a in unwrapped_args + grads:
+        if isinstance(a, Proxy):
+            bwd_si.args.append((a.name, None))
+        else:
+            pa = proxy(a)
+            bwd_si.args.append((pa.name, None))
+    trace_of_backward._siginfo = bwd_si
+    trace_of_backward.args = (func, *(unwrapped_args + grads))
+
+    jit_ctx.computation_trace.push_scope([])
+    wrapped_grads = tree_map(lambda g: wrap(g, provenance=result.provenance), grads)
+
+    # pr1 = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[v.provenance for v in args])  # other inst?
+    # pr2 = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[v.provenance for v in wrapped_grads])
+    # res = _interpret_call(func, *args, **kwargs)
+    vjp_of_checkpoint_fwd = vjp(checkpoint_fwd_sym)
+    # TODO: How should vjp be called on the forward symbol? Currently this raises:
+    #   TypeError: decomposed_fn_backward_rule(decomposed_fn, args, kwargs, saved_for_backward, *grads) doesn't match
+    #   the signature of result = backward(*residuals, *cotangents)
+    # when the activation_checkpoint augmented forward is registered but the backward is not and the decomposition
+    # path is used. Are the residuals the saved_for_backward here?
+    checkpoint_backward_result = vjp_of_checkpoint_fwd((func, *unwrapped_args), grads, **kwargs)
+
+    if checkpoint_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return checkpoint_backward_result
+
+    checkpoint_bwd_sym: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
+
+    for bsym in checkpoint_bwd_sym:
+        trace_of_backward.add_bound_symbol(bsym)
+    with tracectx(trace_of_backward):
+        prims.python_return.bind(*unwrap(checkpoint_backward_result)[1], output=())
+
+    # @wraps(trace_of_backward.python_callable())
+    # def bwd_trace_callable_interface(f, *args, **kwargs):
+    #     return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, f, *args, **kwargs)
+
+    # bwd_si = SigInfo("backward_impl")
+    # bwd_si.args.append(("function", None))
+    # for a in unwrapped_args + grads:
+    #     if isinstance(a, Proxy):
+    #         bwd_si.args.append((a.name, None))
+    #     else:
+    #         pa = proxy(a)
+    #         bwd_si.args.append((pa.name, None))
+    # bwd_trace_impl = TraceCtx()
+    # for bsym in checkpoint_bwd_sym:
+    #     bwd_trace_impl.add_bound_symbol(bsym)
+    # bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*unwrap(checkpoint_backward_result)[1], output=()))
+    # bwd_trace_impl._siginfo = bwd_si
+    # bwd_trace_impl.args = tuple(func + unwrapped_args + grads)
+
+    # @wraps(bwd_trace_impl.python_callable())
+    # def bwd_impl_callable(f, *args, **kwargs):
+    #     return thunder.core.trace_interpreter.interpret_trace(bwd_trace_impl, f, *args, **kwargs)
+
+    # @wraps(bwd_trace_callable_interface)
+    # def backward_impl(f, *args, **kwargs):
+    #     # check(not kwargs, lambda: f"{kwargs} expected to be empty")
+    #     # new_args = ctx_proxy.saved_consts + args
+    #     return bwd_impl_callable(f, *args, **kwargs)
+
+    # backward_impls[checkpoint_fwd_sym.name] = backward_impl
+    return forward_result
+
+    # It should be possible to call the general_thunder_jit here to handle the
+    # conversion from torch to thunder, but it doesn't work yet.
+    # See https://github.com/Lightning-AI/lightning-thunder/issues/1126
+    # TODO: Convert the function to a Thunder function
+    # def thunder_function(*args, **kwargs):
+    #     return unwrap(function)(*args, **kwargs)
+
+    # wrapped_thunder_function = wrap_const(thunder_function)
+    # return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
+
+
 # Adds proxy methods
 # NOTE These methods map to themselves, which prevents the interpreter from looking into them
 # This is OK because these methods are written in a tracing-safe manner, and trying to
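The lookaside above re-expresses activation checkpointing as a Thunder symbol whose augmented forward saves only the function and its inputs and whose backward is meant to recompute the function. For orientation, the same recompute-in-backward contract can be written with nothing but public PyTorch APIs. The sketch below is a reference for the intended semantics only; the class name and the use of torch.autograd.Function are illustrative and not part of this patch or of Thunder's implementation:

import torch


class RecomputeCheckpoint(torch.autograd.Function):
    """Reference semantics only: save the inputs, recompute the forward in backward."""

    @staticmethod
    def forward(ctx, fn, *args):
        ctx.fn = fn
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return fn(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        saved = ctx.saved_tensors
        with torch.enable_grad():
            # Recompute the forward on detached copies so a fresh graph is built.
            detached = [a.detach().requires_grad_(a.requires_grad) for a in saved]
            out = ctx.fn(*detached)
        grads = torch.autograd.grad(out, detached, grad_outputs)
        return (None, *grads)


x = torch.randn(2, 2, requires_grad=True)
y = RecomputeCheckpoint.apply(torch.sin, x)
y.backward(torch.ones_like(y))
torch.testing.assert_close(x.grad, torch.cos(x.detach()))

Only x is kept alive between forward and backward; sin is re-evaluated when the gradient is requested, which is the saved-tensor behaviour the test added below asserts for the Thunder path.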
diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py
index 8bf2648ee5..4589789949 100644
--- a/thunder/core/pytree.py
+++ b/thunder/core/pytree.py
@@ -6,6 +6,7 @@
 import thunder.core.dtypes as dtypes
 import thunder.core.devices as devices
 from thunder.core.baseutils import ProxyInterface
+from types import FunctionType
 
 OPTREE_NAMESPACE = "thunder"
 
@@ -24,6 +25,7 @@ def tree_flatten(args, namespace=""):
     if (
         type(args)
         not in {
+            FunctionType,
             dict,
             list,
             str,
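The pytree change above whitelists plain Python functions so that values such as the checkpointed callable stored in the trace args by the lookaside can go through tree_flatten/tree_map as opaque leaves. A minimal sketch of the intended behaviour, assuming thunder.core.pytree.tree_flatten keeps optree's (leaves, treespec) return convention and rejects types outside the whitelist:

from types import FunctionType

from thunder.core.pytree import tree_flatten


def fn(x):
    return x


# With FunctionType in the whitelist, a bare Python function flattens to a
# single opaque leaf instead of being rejected as an unsupported type.
leaves, spec = tree_flatten(fn)
assert leaves == [fn]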
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 3c964f5a62..13e0abcd6b 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1673,6 +1673,49 @@ def func(a, b):
     get_saved_for_backward_tensors(execution_trace)
 
 
+def test_torch_checkpoint():
+    import torch.utils.checkpoint
+    import torch._higher_order_ops.wrap
+    from thunder.dynamo import ThunderCompiler
+
+    def fn_to_checkpoint(x):
+        # TODO: restore the composite function once the backward works: return x.sin().cos().exp()
+        return torch.sin(x)
+
+    checkpoint_fns = (
+        # torch.utils.checkpoint.checkpoint,
+        # thunder.torch.checkpoint,
+        partial(torch.utils.checkpoint.checkpoint, use_reentrant=False),
+        torch.ops.higher_order.tag_activation_checkpoint,
+    )
+
+    for checkpoint_fn in checkpoint_fns:
+
+        def f(x):
+            # use_reentrant=False is already bound by the partial above, so it is not passed again here.
+            return checkpoint_fn(fn_to_checkpoint, x)
+
+        x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
+        # backend = ThunderCompiler()
+        # jf = torch.compile(backend=backend)(f)
+        jf = thunder.jit(f)
+        out = jf(x)
+
+        # With activation checkpointing, we are saving only the original input.
+        # The intermediate values are recomputed during the backward pass.
+        assert len(out.grad_fn.saved_tensors) == 1
+        assert out.grad_fn.saved_tensors[0] is x
+
+        g = torch.ones_like(out)
+        out.backward(g)
+
+        x_ref = x.detach().requires_grad_()
+        out_ref = fn_to_checkpoint(x_ref)
+        out_ref.backward(g)
+        torch.testing.assert_close(x.grad, x_ref.grad)
+
+
 def test_inconsistent_output_length_grad_transform():
     from thunder.extend import OperatorExecutor
     from thunder.core.proxies import AnyProxy, TensorProxy
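The commented-out ThunderCompiler lines in the test point at the second route into this lookaside: under torch.compile, Dynamo rewrites torch.utils.checkpoint.checkpoint into torch.ops.higher_order.tag_activation_checkpoint before Thunder sees it. A sketch of exercising that route directly (the helper names mirror the test above; whether the saved-tensor assertions also hold on this path is part of what this patch still needs to settle):

import torch
import torch.utils.checkpoint

from thunder.dynamo import ThunderCompiler


def fn_to_checkpoint(x):
    return torch.sin(x)


def f(x):
    return torch.utils.checkpoint.checkpoint(fn_to_checkpoint, x, use_reentrant=False)


x = torch.randn(2, 2, requires_grad=True)
backend = ThunderCompiler()
jf = torch.compile(f, backend=backend)
out = jf(x)
out.backward(torch.ones_like(out))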
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 1308300be5..a065d7acb1 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -43,7 +43,7 @@
 )
 from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
 from thunder.core.symbol import Symbol
-from thunder.core.transforms import register_grad
+from thunder.core.transforms import register_grad, register_augmented_forward, register_backward
 from thunder.core.prims import get_grad, put_grad
 from thunder.core.baseutils import run_once
 import thunder
@@ -56,6 +56,8 @@
 
 # NOTE torch is a requirement
 import torch
+import torch.utils.checkpoint
+import torch._higher_order_ops.wrap
 
 import warnings
 
@@ -5152,6 +5154,71 @@ def _unwrap_if_dead(tensor):
 register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
 
 
+# @torchsymbol(
+#     torch.utils.checkpoint.checkpoint,
+#     torch.ops.higher_order.tag_activation_checkpoint,
+#     id="activation_checkpoint",
+# )
+# def checkpoint(
+#     function: Callable[..., TensorLike],
+#     *args: TensorLike,
+#     context_fn: None | Callable[..., Any] = None,
+#     debug: None | bool = None,
+#     determinism_check: None | str = None,
+#     preserve_rng_state: None | bool = None,
+#     use_reentrant: bool = False,
+#     **kwargs: Any,
+# ) -> TensorLike:
+#     utils.check(
+#         not use_reentrant,
+#         lambda: "torch.checkpoint: use_reentrant=True is not supported in Thunder",
+#     )
+#     # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments
+#     # Let's raise a warning if any of these arguments are passed
+#     if context_fn is not None:
+#         warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored")
+#     if debug is not None:
+#         warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored")
+#     if determinism_check is not None:
+#         warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored")
+#     if preserve_rng_state is not None:
+#         warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored")
+#     return function(*args, **kwargs)
+
+
+# @register_augmented_forward(
+#     "activation_checkpoint",
+# )
+# def _augmented_forward_checkpoint(
+#     function: Callable[..., TensorLike],
+#     *args: TensorLike,
+#     # context_fn: None | Callable[..., Any] = None,
+#     # debug: None | bool = None,
+#     # determinism_check: None | str = None,
+#     # preserve_rng_state: None | bool = None,
+#     # use_reentrant: bool = False,
+#     **kwargs: Any,
+# ) -> TensorLike:
+#     result = function(*args, **kwargs)
+#     saved_for_backward = (function, args, kwargs)
+#     return result, saved_for_backward
+
+
+# @register_backward(
+#     "activation_checkpoint",
+# )
+# def _backward_checkpoint(
+#     function,
+#     args,
+#     kwargs,
+#     *grad_outputs,
+# ) -> tuple[None | TensorLike, ...]:
+#     from thunder.core.transforms import vjp
+
+#     result, grad = vjp(function)(args, grad_outputs, **kwargs)
+#     return grad  # result
+
+
 #
 # Distributed operations
 #
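The commented-out _backward_checkpoint above delegates the gradient computation to a VJP of the checkpointed function: recompute the forward, then map output cotangents to input gradients. The same contract can be seen, independently of Thunder's internals, with the public torch.func.vjp API; this is only an illustration of the intended gradient rule:

import torch


def fn_to_checkpoint(x):
    return torch.sin(x)


x = torch.randn(2, 2)
g = torch.ones(2, 2)

# torch.func.vjp runs the forward and returns a function mapping output
# cotangents to input gradients -- the same shape of rule the commented-out
# _backward_checkpoint sketches with thunder's vjp.
out, vjp_fn = torch.func.vjp(fn_to_checkpoint, x)
(grad_x,) = vjp_fn(g)
torch.testing.assert_close(grad_x, torch.cos(x) * g)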