diff --git a/src/RuntimeGeneratedFunctions.jl b/src/RuntimeGeneratedFunctions.jl index 8f2d4de..107e90a 100644 --- a/src/RuntimeGeneratedFunctions.jl +++ b/src/RuntimeGeneratedFunctions.jl @@ -10,19 +10,20 @@ export @RuntimeGeneratedFunction This type should be constructed via the macro @RuntimeGeneratedFunction. """ -struct RuntimeGeneratedFunction{argnames,moduletag,id} <: Function +struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id} <: Function body::Expr - function RuntimeGeneratedFunction(moduletag, ex) + function RuntimeGeneratedFunction(cache_tag, context_tag, ex) def = splitdef(ex) args, body = normalize_args(def[:args]), def[:body] id = expr_to_id(body) - cached_body = _cache_body(moduletag, id, body) - new{Tuple(args),moduletag,id}(cached_body) + cached_body = _cache_body(cache_tag, id, body) + new{Tuple(args), cache_tag, context_tag, id}(cached_body) end end """ @RuntimeGeneratedFunction(function_expression) + @RuntimeGeneratedFunction(context_module, function_expression) Construct a function from `function_expression` which can be called immediately without world age problems. Somewhat like using `eval(function_expression)` and @@ -35,6 +36,10 @@ then calling the resulting function. The differences are: You need to use `RuntimeGeneratedFunctions.init(your_module)` a single time at the top level of `your_module` before any other uses of the macro. +If provided, `context_module` is module in which symbols within +`function_expression` will be looked up. By default this is module in which +`@RuntimeGeneratedFunction` is expanded. + # Examples ``` RuntimeGeneratedFunctions.init(@__MODULE__) # Required at module top-level @@ -46,23 +51,33 @@ function foo() end ``` """ -macro RuntimeGeneratedFunction(ex) +macro RuntimeGeneratedFunction(code) + _RGF_constructor_code(:(@__MODULE__), esc(code)) +end +macro RuntimeGeneratedFunction(context_module, code) + _RGF_constructor_code(esc(context_module), esc(code)) +end + +function _RGF_constructor_code(context_module, code) quote - if !($(esc(:(@isdefined($_tagname))))) + code = $code + cache_module = @__MODULE__ + context_module = $context_module + if #==# !isdefined(cache_module, $(QuoteNode(_tagname))) || + !isdefined(context_module, $(QuoteNode(_tagname))) + init_mods = unique([context_module, cache_module]) error("""You must use `RuntimeGeneratedFunctions.init(@__MODULE__)` at module - top level before using runtime generated functions""") + top level before using runtime generated functions in $init_mods""") end - RuntimeGeneratedFunction( - $(esc(_tagname)), - $(esc(ex)) - ) + RuntimeGeneratedFunction(cache_module.$_tagname, context_module.$_tagname, $code) end end -function Base.show(io::IO, f::RuntimeGeneratedFunction{argnames, moduletag, id}) where {argnames,moduletag,id} - mod = parentmodule(moduletag) +function Base.show(io::IO, ::MIME"text/plain", f::RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}) where {argnames,cache_tag,context_tag,id} + cache_mod = parentmodule(cache_tag) + context_mod = parentmodule(context_tag) func_expr = Expr(:->, Expr(:tuple, argnames...), f.body) - print(io, "RuntimeGeneratedFunction(#=in $mod=#, ", repr(func_expr), ")") + print(io, "RuntimeGeneratedFunction(#=in $cache_mod=#, #=using $context_mod=#, ", repr(func_expr), ")") end (f::RuntimeGeneratedFunction)(args::Vararg{Any,N}) where N = generated_callfunc(f, args...) @@ -71,9 +86,9 @@ end # @RuntimeGeneratedFunction function generated_callfunc end -function generated_callfunc_body(argnames, moduletag, id, __args) +function generated_callfunc_body(argnames, cache_tag, id, __args) setup = (:($(argnames[i]) = @inbounds __args[$i]) for i in 1:length(argnames)) - body = _lookup_body(moduletag, id) + body = _lookup_body(cache_tag, id) @assert body !== nothing quote $(setup...) @@ -103,9 +118,9 @@ _cache_lock = Threads.SpinLock() _cachename = Symbol("#_RuntimeGeneratedFunctions_cache") _tagname = Symbol("#_RGF_ModTag") -function _cache_body(moduletag, id, body) +function _cache_body(cache_tag, id, body) lock(_cache_lock) do - cache = getfield(parentmodule(moduletag), _cachename) + cache = getfield(parentmodule(cache_tag), _cachename) # Caching is tricky when `id` is the same for different AST instances: # # Tricky case #1: If a function body with the same `id` was cached @@ -127,9 +142,9 @@ function _cache_body(moduletag, id, body) end end -function _lookup_body(moduletag, id) +function _lookup_body(cache_tag, id) lock(_cache_lock) do - cache = getfield(parentmodule(moduletag), _cachename) + cache = getfield(parentmodule(cache_tag), _cachename) cache[id].value end end @@ -159,8 +174,9 @@ function init(mod) # or so. See: # https://github.com/JuliaLang/julia/pull/32902 # https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L30 - @inline @generated function $RuntimeGeneratedFunctions.generated_callfunc(f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, $_tagname, id}, __args...) where {argnames,id} - $RuntimeGeneratedFunctions.generated_callfunc_body(argnames, $_tagname, id, __args) + @inline @generated function $RuntimeGeneratedFunctions.generated_callfunc( + f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, cache_tag, $_tagname, id}, __args...) where {argnames, cache_tag, id} + $RuntimeGeneratedFunctions.generated_callfunc_body(argnames, cache_tag, id, __args) end end) end diff --git a/test/runtests.jl b/test/runtests.jl index 9f73caa..db2efe8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -72,9 +72,9 @@ end @test no_worldage() === nothing # Test show() -@test sprint(show, @RuntimeGeneratedFunction(Base.remove_linenums!(:((x,y)->x+y+1)))) == +@test sprint(show, MIME"text/plain"(), @RuntimeGeneratedFunction(Base.remove_linenums!(:((x,y)->x+y+1)))) == """ - RuntimeGeneratedFunction(#=in $(@__MODULE__)=#, :((x, y)->begin + RuntimeGeneratedFunction(#=in $(@__MODULE__)=#, #=using $(@__MODULE__)=#, :((x, y)->begin x + y + 1 end))""" @@ -118,12 +118,15 @@ module GlobalsTest using RuntimeGeneratedFunctions RuntimeGeneratedFunctions.init(@__MODULE__) - y = 40 - f = @RuntimeGeneratedFunction(:(x->x+y)) + y_in_GlobalsTest = 40 + f = @RuntimeGeneratedFunction(:(x->x + y_in_GlobalsTest)) end @test GlobalsTest.f(2) == 42 +f_outside = @RuntimeGeneratedFunction(GlobalsTest, :(x->x + y_in_GlobalsTest)) +@test f_outside(2) == 42 + @test_throws ErrorException @eval(module NotInitTest using RuntimeGeneratedFunctions # RuntimeGeneratedFunctions.init(@__MODULE__) # <-- missing