diff --git a/src/RuntimeGeneratedFunctions.jl b/src/RuntimeGeneratedFunctions.jl index 99f4cb9..6e9ec83 100644 --- a/src/RuntimeGeneratedFunctions.jl +++ b/src/RuntimeGeneratedFunctions.jl @@ -75,8 +75,19 @@ struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id, B} <: Func end end -function drop_expr(::RuntimeGeneratedFunction{A, C1, C2, ID}) where {A, C1, C2, ID} - RuntimeGeneratedFunction{A, C1, C2, ID}(nothing) +function drop_expr(::RuntimeGeneratedFunction{a, cache_tag, c, id}) where {a, cache_tag, c, + id} + # When dropping the reference to the body from an RGF, we need to upgrade + # from a weak to a strong reference in the cache to prevent the body being + # GC'd. + lock(_cache_lock) do + cache = getfield(parentmodule(cache_tag), _cachename) + body = cache[id] + if body isa WeakRef + cache[id] = body.value + end + end + RuntimeGeneratedFunction{a, cache_tag, c, id}(nothing) end function _check_rgf_initialized(mods...) @@ -119,7 +130,7 @@ function Base.show(io::IO, ::MIME"text/plain", } cache_mod = parentmodule(cache_tag) context_mod = parentmodule(context_tag) - func_expr = Expr(:->, Expr(:tuple, argnames...), f.body) + func_expr = Expr(:->, Expr(:tuple, argnames...), _lookup_body(cache_tag, id)) print(io, "RuntimeGeneratedFunction(#=in $cache_mod=#, #=using $context_mod=#, ", repr(func_expr), ")") end @@ -169,16 +180,29 @@ function _cache_body(cache_tag, id, body) cache = getfield(parentmodule(cache_tag), _cachename) # Caching is tricky when `id` is the same for different AST instances: # - # Tricky case #1: If a function body with the same `id` was cached - # previously, we need to use that older instance of the body AST as the - # canonical one rather than `body`. This ensures the lifetime of the - # body in the cache will always cover the lifetime of the parent - # `RuntimeGeneratedFunction`s when they share the same `id`. - cached_body = haskey(cache, id) ? cache[id] : nothing - cached_body = cached_body !== nothing ? cached_body : body - # We cannot use WeakRef because we might drop body to make RGF GPU - # compatible. - cache[id] = cached_body + # 1. If a function body with the same `id` was cached previously, we need + # to use that older instance of the body AST as the canonical one + # rather than `body`. This ensures the lifetime of the body in the + # cache will always cover the lifetime of all RGFs which share the same + # `id`. + # + # 2. Unless we hold a separate reference to `cache[id].value`, the GC + # can collect it (causing it to become `nothing`). So root it in a + # local variable first. + # + cached_body = get(cache, id, nothing) + if !isnothing(cached_body) + if cached_body isa WeakRef + # `value` may be nothing here if it was previously cached but GC'd + cached_body = cached_body.value + end + end + if isnothing(cached_body) + cached_body = body + # Use a WeakRef to allow `body` to be garbage collected. (After GC, the + # cache will still contain an empty entry with key `id`.) + cache[id] = WeakRef(cached_body) + end return cached_body end end @@ -186,7 +210,8 @@ end function _lookup_body(cache_tag, id) lock(_cache_lock) do cache = getfield(parentmodule(cache_tag), _cachename) - cache[id] + body = cache[id] + body isa WeakRef ? body.value : body end end diff --git a/test/runtests.jl b/test/runtests.jl index e5e114d..6c89960 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -96,6 +96,17 @@ end GC.gc() @test f_gc(1, -1) == 100001 +# Test that drop_expr works +f_drop1, f_drop2 = let + ex = Base.remove_linenums!(:(x -> x - 1)) + # Construct two identical RGFs here to test the cache deduplication code + (drop_expr(@RuntimeGeneratedFunction(ex)), + drop_expr(@RuntimeGeneratedFunction(ex))) +end +GC.gc() +@test f_drop1(1) == 0 +@test f_drop2(1) == 0 + # Test that threaded use works tasks = [] for k in 1:4