Skip to content

Commit

Permalink
Call typechecker from isolated frame (#260)
Browse files Browse the repository at this point in the history
* Call typechecker from isolated frame

* Add comment

* Enable test_no_garbage for TypeGuard
  • Loading branch information
ojw28 authored Oct 23, 2024
1 parent d22a0a8 commit a831de6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
17 changes: 14 additions & 3 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def __repr__(self):
_tb_flag = True


def _apply_typechecker(typechecker, fn):
"""Calls `typechecker(fn)` in an isolated frame, returning the result.
This avoids reference cycles that can otherwise occur if `typechecker` grabs
the calling frame's locals.
"""
return typechecker(fn)


@overload
def jaxtyped(
*,
Expand Down Expand Up @@ -422,8 +431,8 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
param_fn = _make_fn_with_signature(
name, qualname, module, param_signature, output=False
)
full_fn = typechecker(full_fn)
param_fn = typechecker(param_fn)
full_fn = _apply_typechecker(typechecker, full_fn)
param_fn = _apply_typechecker(typechecker, param_fn)

def wrapped_fn_impl(args, kwargs, bound, memos):
# First type-check just the parameters before the function is
Expand Down Expand Up @@ -790,7 +799,9 @@ def _get_problem_arg(
fn = _make_fn_with_signature(
"check_single_arg", "check_single_arg", module, new_signature, output=False
)
fn = typechecker(fn) # but no `jaxtyped`; keep the same environment.
fn = _apply_typechecker(
typechecker, fn
) # but no `jaxtyped`; keep the same environment.
try:
fn(*args, **kwargs)
except Exception as e:
Expand Down
31 changes: 26 additions & 5 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import abc
import dataclasses
import sys
from typing import no_type_check

import jax.numpy as jnp
import jax.random as jr
import pytest
import typeguard

from jaxtyping import Array, Float, jaxtyped, print_bindings

Expand Down Expand Up @@ -213,10 +213,6 @@ def g(x: Float[Array, "foo bar"]):


def test_no_garbage(typecheck):
if typecheck is typeguard.typechecked:
# Currently fails due to reference cycles in typeguard.
pytest.skip()

with assert_no_garbage():

@jaxtyped(typechecker=typecheck)
Expand All @@ -236,3 +232,28 @@ class _Obj:
x: int

_Obj(x=5)


def test_no_garbage_frame_capture_typecheck():
with assert_no_garbage():
# Some typechecker implementations (e.g., typeguard 2.13.3) capture the calling
# frame's f_locals. This test checks that the calling frames in jaxtyping are
# sufficiently isolated to avoid introducing reference cycles when a
# typechecker does this.
def frame_locals_capture(fn):
locals = sys._getframe(1).f_locals

def wrapper(*args, **kwargs):
# Required to ensure wrapper holds a reference to f_locals, which is
# the scenario under test.
_ = locals
return fn(*args, **kwargs)

return wrapper

@jaxtyped(typechecker=frame_locals_capture)
@dataclasses.dataclass
class _Obj:
x: int

_Obj(x=5)

0 comments on commit a831de6

Please sign in to comment.