From d01475bec6a66ed0e5efa71859ca1955b3c4ca91 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Mon, 7 Aug 2017 16:12:11 -0400 Subject: [PATCH] add `if @generated ... else ... end` inside functions to provide optional optimizers use meta nodes instead of `stagedfunction` expression head --- NEWS.md | 4 + base/boot.jl | 28 +++++ base/docs/Docs.jl | 2 +- base/expr.jl | 19 +++- base/linalg/bidiag.jl | 16 ++- base/methodshow.jl | 12 +- base/multidimensional.jl | 52 +++++---- base/reflection.jl | 18 +-- base/sysimg.jl | 32 +++--- doc/src/manual/metaprogramming.md | 76 +++++++++++-- src/ast.c | 5 +- src/ast.scm | 14 +++ src/codegen.cpp | 8 +- src/dump.c | 2 +- src/interpreter.c | 2 +- src/jltypes.c | 3 +- src/julia-syntax.scm | 147 ++++++++++++++++--------- src/julia.h | 4 +- src/julia_internal.h | 2 + src/macroexpand.scm | 5 - src/method.c | 177 +++++++++++++----------------- src/utils.scm | 13 ++- test/reflection.jl | 7 +- test/staged.jl | 28 +++++ 24 files changed, 422 insertions(+), 254 deletions(-) diff --git a/NEWS.md b/NEWS.md index e69fa63f533f31..38daf8585e7b0c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -19,6 +19,10 @@ New language features * The macro call syntax `@macroname[args]` is now available and is parsed as `@macroname([args])` ([#23519]). + * The construct `if @generated ...; else ...; end` can be used to provide both + `@generated` and normal implementations of part of a function. Surrounding code + will be common to both versions ([#23168]). + Language changes ---------------- diff --git a/base/boot.jl b/base/boot.jl index d66b8174cebf46..7d1163f24d32d7 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -430,4 +430,32 @@ show(@nospecialize a) = show(STDOUT, a) print(@nospecialize a...) = print(STDOUT, a...) println(@nospecialize a...) = println(STDOUT, a...) +struct GeneratedFunctionStub + gen + argnames::Array{Any,1} + spnames::Union{Void, Array{Any,1}} + line::Int + file::Symbol +end + +# invoke and wrap the results of @generated +function (g::GeneratedFunctionStub)(@nospecialize args...) + body = g.gen(args...) + if body isa CodeInfo + return body + end + lam = Expr(:lambda, g.argnames, + Expr(Symbol("scope-block"), + Expr(:block, + LineNumberNode(g.line, g.file), + Expr(:meta, :push_loc, g.file, Symbol("@generated body")), + Expr(:return, body), + Expr(:meta, :pop_loc)))) + if g.spnames === nothing + return lam + else + return Expr(Symbol("with-static-parameters"), lam, g.spnames...) + end +end + ccall(:jl_set_istopmod, Void, (Any, Bool), Core, true) diff --git a/base/docs/Docs.jl b/base/docs/Docs.jl index 94f0336926424c..ed2d8b71ee6d17 100644 --- a/base/docs/Docs.jl +++ b/base/docs/Docs.jl @@ -642,7 +642,7 @@ finddoc(λ, def) = false # Predicates and helpers for `docm` expression selection: -const FUNC_HEADS = [:function, :stagedfunction, :macro, :(=)] +const FUNC_HEADS = [:function, :macro, :(=)] const BINDING_HEADS = [:typealias, :const, :global, :(=)] # deprecation: remove `typealias` post-0.6 # For the special `:@mac` / `:(Base.@mac)` syntax for documenting a macro after definition. isquotedmacrocall(x) = diff --git a/base/expr.jl b/base/expr.jl index d3138e1a0ac23e..9035fdfeed04bd 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -332,10 +332,23 @@ function remove_linenums!(ex::Expr) return ex end +macro generated() + return Expr(:generated) +end + macro generated(f) - if isa(f, Expr) && (f.head === :function || is_short_function_def(f)) - f.head = :stagedfunction - return Expr(:escape, f) + if isa(f, Expr) && (f.head === :function || is_short_function_def(f)) + body = f.args[2] + lno = body.args[1] + return Expr(:escape, + Expr(f.head, f.args[1], + Expr(:block, + lno, + Expr(:if, Expr(:generated), + body, + Expr(:block, + Expr(:meta, :generated_only), + Expr(:return, nothing)))))) else error("invalid syntax; @generated must be used with a function definition") end diff --git a/base/linalg/bidiag.jl b/base/linalg/bidiag.jl index f8d426b4857fda..25f6fd8bae858f 100644 --- a/base/linalg/bidiag.jl +++ b/base/linalg/bidiag.jl @@ -577,12 +577,18 @@ _valuefields(::Type{<:AbstractTriangular}) = [:data] const SpecialArrays = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,AbstractTriangular} -@generated function fillslots!(A::SpecialArrays, x) - ex = :(xT = convert(eltype(A), x)) - for field in _valuefields(A) - ex = :($ex; fill!(A.$field, xT)) +function fillslots!(A::SpecialArrays, x) + xT = convert(eltype(A), x) + if @generated + quote + $([ :(fill!(A.$field, xT)) for field in _valuefields(A) ]...) + end + else + for field in _valuefields(A) + fill!(getfield(A, field), xT) + end end - :($ex;return A) + return A end # for historical reasons: diff --git a/base/methodshow.jl b/base/methodshow.jl index ed09e3da62d9e5..c31c6879d07837 100644 --- a/base/methodshow.jl +++ b/base/methodshow.jl @@ -42,6 +42,15 @@ function argtype_decl(env, n, sig::DataType, i::Int, nargs, isva::Bool) # -> (ar return s, string_with_env(env, t) end +function method_argnames(m::Method) + if !isdefined(m, :source) && isdefined(m, :generator) + return m.generator.argnames + end + argnames = Vector{Any}(m.nargs) + ccall(:jl_fill_argnames, Void, (Any, Any), m.source, argnames) + return argnames +end + function arg_decl_parts(m::Method) tv = Any[] sig = m.sig @@ -52,8 +61,7 @@ function arg_decl_parts(m::Method) file = m.file line = m.line if isdefined(m, :source) || isdefined(m, :generator) - argnames = Vector{Any}(m.nargs) - ccall(:jl_fill_argnames, Void, (Any, Any), isdefined(m, :source) ? m.source : m.generator.inferred, argnames) + argnames = method_argnames(m) show_env = ImmutableDict{Symbol, Any}() for t in tv show_env = ImmutableDict(show_env, :unionall_env => t) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 2811eca7a3b72f..cc5f4349d573cf 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -549,14 +549,11 @@ end @noinline throw_checksize_error(A, sz) = throw(DimensionMismatch("output array is the wrong size; expected $sz, got $(size(A))")) ## setindex! ## -@generated function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...) - N = length(I) - quote - @_inline_meta - @boundscheck checkbounds(A, I...) - _unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...) - A - end +function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...) + @_inline_meta + @boundscheck checkbounds(A, I...) + _unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...) + A end _iterable(v::AbstractArray) = v @@ -916,28 +913,29 @@ function copy!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N} dest end -@generated function copy!(dest::AbstractArray{T1,N}, - Rdest::CartesianRange{N}, - src::AbstractArray{T2,N}, - Rsrc::CartesianRange{N}) where {T1,T2,N} - quote - isempty(Rdest) && return dest - if size(Rdest) != size(Rsrc) - throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))")) +function copy!(dest::AbstractArray{T1,N}, Rdest::CartesianRange{N}, + src::AbstractArray{T2,N}, Rsrc::CartesianRange{N}) where {T1,T2,N} + isempty(Rdest) && return dest + if size(Rdest) != size(Rsrc) + throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))")) + end + @boundscheck checkbounds(dest, first(Rdest)) + @boundscheck checkbounds(dest, last(Rdest)) + @boundscheck checkbounds(src, first(Rsrc)) + @boundscheck checkbounds(src, last(Rsrc)) + ΔI = first(Rdest) - first(Rsrc) + if @generated + quote + @nloops $N i (n->Rsrc.indices[n]) begin + @inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i) + end end - @boundscheck checkbounds(dest, first(Rdest)) - @boundscheck checkbounds(dest, last(Rdest)) - @boundscheck checkbounds(src, first(Rsrc)) - @boundscheck checkbounds(src, last(Rsrc)) - ΔI = first(Rdest) - first(Rsrc) - # TODO: restore when #9080 is fixed - # for I in Rsrc - # @inbounds dest[I+ΔI] = src[I] - @nloops $N i (n->Rsrc.indices[n]) begin - @inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i) + else + for I in Rsrc + @inbounds dest[I + ΔI] = src[I] end - dest end + dest end """ diff --git a/base/reflection.jl b/base/reflection.jl index 1f803be53ce6a6..78ac1184a70f93 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -597,12 +597,13 @@ end """ code_lowered(f, types, expand_generated = true) -Return an array of lowered ASTs for the methods matching the given generic function and type signature. +Return an array of the lowered forms (IR) for the methods matching the given generic function +and type signature. -If `expand_generated` is `false`, then the `CodeInfo` instances returned for `@generated` -methods will correspond to the generators' lowered ASTs. If `expand_generated` is `true`, -these `CodeInfo` instances will correspond to the lowered ASTs of the method bodies yielded -by expanding the generators. +If `expand_generated` is `false`, the returned `CodeInfo` instances will correspond to fallback +implementations. An error is thrown if no fallback implementation exists. +If `expand_generated` is `true`, these `CodeInfo` instances will correspond to the method bodies +yielded by expanding the generators. Note that an error will be thrown if `types` are not leaf types when `expand_generated` is `true` and the corresponding method is a `@generated` method. @@ -737,7 +738,8 @@ function length(mt::MethodTable) end isempty(mt::MethodTable) = (mt.defs === nothing) -uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m, :source) ? m.source : m.generator.inferred) +uncompressed_ast(m::Method) = isdefined(m,:source) ? uncompressed_ast(m, m.source) : + error("Method is @generated; try `code_lowered` instead.") uncompressed_ast(m::Method, s::CodeInfo) = s uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo uncompressed_ast(m::Core.MethodInstance) = uncompressed_ast(m.def) @@ -851,7 +853,7 @@ code_native(::IO, ::Any, ::Symbol) = error("illegal code_native call") # resolve # give a decent error message if we try to instantiate a staged function on non-leaf types function func_for_method_checked(m::Method, @nospecialize types) - if isdefined(m,:generator) && !isdefined(m,:source) && !_isleaftype(types) + if isdefined(m,:generator) && !_isleaftype(types) error("cannot call @generated function `", m, "` ", "with abstract argument types: ", types) end @@ -861,7 +863,7 @@ end """ code_typed(f, types; optimize=true) -Returns an array of lowered and type-inferred ASTs for the methods matching the given +Returns an array of type-inferred lowered form (IR) for the methods matching the given generic function and type signature. The keyword argument `optimize` controls whether additional optimizations, such as inlining, are also applied. """ diff --git a/base/sysimg.jl b/base/sysimg.jl index f845cbb31196fb..0f1a6d2cd514f6 100644 --- a/base/sysimg.jl +++ b/base/sysimg.jl @@ -236,21 +236,27 @@ include("broadcast.jl") using .Broadcast # define the real ntuple functions -@generated function ntuple(f::F, ::Val{N}) where {F,N} - Core.typeassert(N, Int) - (N >= 0) || return :(throw($(ArgumentError(string("tuple length should be ≥0, got ", N))))) - return quote - $(Expr(:meta, :inline)) - @nexprs $N i -> t_i = f(i) - @ncall $N tuple t +@inline function ntuple(f::F, ::Val{N}) where {F,N} + N::Int + (N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N))) + if @generated + quote + @nexprs $N i -> t_i = f(i) + @ncall $N tuple t + end + else + Tuple(f(i) for i = 1:N) end end -@generated function fill_to_length(t::Tuple, val, ::Val{N}) where {N} - M = length(t.parameters) - M > N && return :(throw($(ArgumentError("input tuple of length $M, requested $N")))) - return quote - $(Expr(:meta, :inline)) - (t..., $(Any[ :val for i = (M + 1):N ]...)) +@inline function fill_to_length(t::Tuple, val, ::Val{N}) where {N} + M = length(t) + M > N && throw(ArgumentError("input tuple of length $M, requested $N")) + if @generated + quote + (t..., $(fill(:val, N-length(t.parameters))...)) + end + else + (t..., fill(val, N-M)...) end end diff --git a/doc/src/manual/metaprogramming.md b/doc/src/manual/metaprogramming.md index 8a2d534d6f1f8a..a7b4e414949e6a 100644 --- a/doc/src/manual/metaprogramming.md +++ b/doc/src/manual/metaprogramming.md @@ -1011,17 +1011,16 @@ syntax tree. A very special macro is `@generated`, which allows you to define so-called *generated functions*. These have the capability to generate specialized code depending on the types of their arguments with more flexibility and/or less code than what can be achieved with multiple dispatch. While -macros work with expressions at parsing-time and cannot access the types of their inputs, a generated +macros work with expressions at parse time and cannot access the types of their inputs, a generated function gets expanded at a time when the types of the arguments are known, but the function is not yet compiled. Instead of performing some calculation or action, a generated function declaration returns a quoted expression which then forms the body for the method corresponding to the types of the arguments. -When called, the body expression is first evaluated and compiled, then the returned expression -is compiled and run. To make this efficient, the result is often cached. And to make this inferable, -only a limited subset of the language is usable. Thus, generated functions provide a flexible -framework to move work from run-time to compile-time, at the expense of greater restrictions on -the allowable constructs. +When a generated function is called, the expression it returns is compiled and then run. +To make this efficient, the result is usually cached. And to make this inferable, only a limited +subset of the language is usable. Thus, generated functions provide a flexible way to move work from +run time to compile time, at the expense of greater restrictions on allowed constructs. When defining generated functions, there are four main differences to ordinary functions: @@ -1037,7 +1036,7 @@ When defining generated functions, there are four main differences to ordinary f This means they can only read global constants, and cannot have any side effects. In other words, they must be completely pure. Due to an implementation limitation, this also means that they currently cannot define a closure - or untyped generator. + or generator. It's easiest to illustrate this with an example. We can declare a generated function `foo` as @@ -1052,9 +1051,8 @@ foo (generic function with 1 method) Note that the body returns a quoted expression, namely `:(x * x)`, rather than just the value of `x * x`. -From the caller's perspective, they are very similar to regular functions; in fact, you don't -have to know if you're calling a regular or generated function - the syntax and result of the -call is just the same. Let's see how `foo` behaves: +From the caller's perspective, this is identical to a regular function; in fact, you don't +have to know whether you're calling a regular or generated function. Let's see how `foo` behaves: ```jldoctest generated julia> x = foo(2); # note: output is from println() statement in the body @@ -1198,7 +1196,7 @@ end and at the call site; however, *don't copy them*, for the following reasons: when, how often or how many times these side-effects will occur * the `bar` function solves a problem that is better solved with multiple dispatch - defining `bar(x) = x` and `bar(x::Integer) = x ^ 2` will do the same thing, but it is both simpler and faster. - * the `baz` function is pathologically insane + * the `baz` function is pathological Note that the set of operations that should not be attempted in a generated function is unbounded, and the runtime system can currently only detect a subset of the invalid operations. There are @@ -1316,3 +1314,59 @@ the two tuples, multiplication and addition/subtraction. All the looping is perf and we avoid looping during execution entirely. Thus, we only loop *once per type*, in this case once per `N` (except in edge cases where the function is generated more than once - see disclaimer above). + +### Optionally-generated functions + +Generated functions can achieve high efficiency at run time, but come with a compile time cost: +a new function body must be generated for every combination of concrete argument types. +Typically, Julia is able to compile "generic" versions of functions that will work for any +arguments, but with generated functions this is impossible. +This means that programs making heavy use of generated functions might be impossible to +statically compile. + +To solve this problem, the language provides syntax for writing normal, non-generated +alternative implementations of generated functions. +Applied to the `sub2ind` example above, it would look like this: + +```julia +function sub2ind_gen(dims::NTuple{N}, I::Integer...) where N + if N != length(I) + throw(ArgumentError("Number of dimensions must match number of indices.")) + end + if @generated + ex = :(I[$N] - 1) + for i = (N - 1):-1:1 + ex = :(I[$i] - 1 + dims[$i] * $ex) + end + return :($ex + 1) + else + ind = I[N] - 1 + for i = (N - 1):-1:1 + ind = I[i] - 1 + dims[i]*ind + end + return ind + 1 + end +end +``` + +Internally, this code creates two implementations of the function: a generated one where +the first block in `if @generated` is used, and a normal one where the `else` block is used. +Inside the `then` part of the `if @generated` block, code has the same semantics as other +generated functions: argument names refer to types, and the code should return an expression. +Multiple `if @generated` blocks may occur, in which case the generated implementation uses +all of the `then` blocks and the alternate implementation uses all of the `else` blocks. + +Notice that we added an error check to the top of the function. +This code will be common to both versions, and is run-time code in both versions +(it will be quoted and returned as an expression from the generated version). +That means that the values and types of local variables are not available at code generation +time --- the code-generation code can only see the types of arguments. + +In this style of definition, the code generation feature is essentially an optional +optimization. +The compiler will use it if convenient, but otherwise may choose to use the normal +implementation instead. +This style is preferred, since it allows the compiler to make more decisions and compile +programs in more ways, and since normal code is more readable than code-generating code. +However, which implementation is used depends on compiler implementation details, so it +is essential for the two implementations to behave identically. diff --git a/src/ast.c b/src/ast.c index 8b9117ff3b25c6..c2a7f90500cd5d 100644 --- a/src/ast.c +++ b/src/ast.c @@ -55,7 +55,8 @@ jl_sym_t *meta_sym; jl_sym_t *compiler_temp_sym; jl_sym_t *inert_sym; jl_sym_t *vararg_sym; jl_sym_t *unused_sym; jl_sym_t *static_parameter_sym; jl_sym_t *polly_sym; jl_sym_t *inline_sym; -jl_sym_t *propagate_inbounds_sym; +jl_sym_t *propagate_inbounds_sym; jl_sym_t *generated_sym; +jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym; jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym; jl_sym_t *hygienicscope_sym; @@ -343,6 +344,8 @@ void jl_init_frontend(void) hygienicscope_sym = jl_symbol("hygienic-scope"); gc_preserve_begin_sym = jl_symbol("gc_preserve_begin"); gc_preserve_end_sym = jl_symbol("gc_preserve_end"); + generated_sym = jl_symbol("generated"); + generated_only_sym = jl_symbol("generated_only"); } JL_DLLEXPORT void jl_lisp_prompt(void) diff --git a/src/ast.scm b/src/ast.scm index 18d389cc46037b..c8799364a93e29 100644 --- a/src/ast.scm +++ b/src/ast.scm @@ -358,6 +358,20 @@ (and (if one (length= e 3) (length> e 2)) (eq? (car e) 'meta) (eq? (cadr e) 'nospecialize))) +(define (if-generated? e) + (and (length= e 4) (eq? (car e) 'if) (equal? (cadr e) '(generated)))) + +(define (generated-meta? e) + (and (length= e 3) (eq? (car e) 'meta) (eq? (cadr e) 'generated))) + +(define (generated_only-meta? e) + (and (length= e 2) (eq? (car e) 'meta) (eq? (cadr e) 'generated_only))) + +(define (function-def? e) + (and (pair? e) (or (eq? (car e) 'function) (eq? (car e) '->) + (and (eq? (car e) '=) (length= e 3) + (eventually-call? (cadr e)))))) + ;; flatten nested expressions with the given head ;; (op (op a b) c) => (op a b c) (define (flatten-ex head e) diff --git a/src/codegen.cpp b/src/codegen.cpp index a9e3190730ade9..981baaa048a18e 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1193,8 +1193,6 @@ jl_llvm_functions_t jl_compile_linfo(jl_method_instance_t **pli, jl_code_info_t li->inferred && // and there is something to delete (test this before calling jl_ast_flag_inlineable) li->inferred != jl_nothing && - // don't delete the code for the generator - li != li->def.method->generator && // don't delete inlineable code, unless it is constant (li->jlcall_api == 2 || !jl_ast_flag_inlineable((jl_array_t*)li->inferred)) && // don't delete code when generating a precompile file @@ -3842,11 +3840,10 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr) } Value *a1 = boxed(ctx, emit_expr(ctx, args[1])); Value *a2 = boxed(ctx, emit_expr(ctx, args[2])); - Value *mdargs[4] = { + Value *mdargs[3] = { /*argdata*/a1, /*code*/a2, - /*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module), - /*isstaged*/literal_pointer_val(ctx, args[3]) + /*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module) }; ctx.builder.CreateCall(prepare_call(jlmethod_func), makeArrayRef(mdargs)); return ghostValue(jl_void_type); @@ -6360,7 +6357,6 @@ static void init_julia_llvm_env(Module *m) mdargs.push_back(T_prjlvalue); mdargs.push_back(T_prjlvalue); mdargs.push_back(T_pjlvalue); - mdargs.push_back(T_pjlvalue); jlmethod_func = Function::Create(FunctionType::get(T_void, mdargs, false), Function::ExternalLinkage, diff --git a/src/dump.c b/src/dump.c index 6e223ab4d6017a..e45087a82bd258 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1454,7 +1454,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_ m->unspecialized = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->unspecialized); if (m->unspecialized) jl_gc_wb(m, m->unspecialized); - m->generator = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->generator); + m->generator = jl_deserialize_value(s, (jl_value_t**)&m->generator); if (m->generator) jl_gc_wb(m, m->generator); m->invokes.unknown = jl_deserialize_value(s, (jl_value_t**)&m->invokes); diff --git a/src/interpreter.c b/src/interpreter.c index 7faa061f405c3f..8c2e1d0e521b03 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -322,7 +322,7 @@ static jl_value_t *eval(jl_value_t *e, interpreter_state *s) JL_GC_PUSH2(&atypes, &meth); atypes = eval(args[1], s); meth = eval(args[2], s); - jl_method_def((jl_svec_t*)atypes, (jl_code_info_t*)meth, s->module, args[3]); + jl_method_def((jl_svec_t*)atypes, (jl_code_info_t*)meth, s->module); JL_GC_POP(); return jl_nothing; } diff --git a/src/jltypes.c b/src/jltypes.c index 320c12b969936b..56209a26aedeff 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2061,7 +2061,7 @@ void jl_init_types(void) jl_simplevector_type, jl_any_type, jl_any_type, // jl_method_instance_type - jl_any_type, // jl_method_instance_type + jl_any_type, jl_array_any_type, jl_any_type, jl_int32_type, @@ -2174,7 +2174,6 @@ void jl_init_types(void) #endif jl_svecset(jl_methtable_type->types, 8, jl_int32_type); // uint32_t jl_svecset(jl_method_type->types, 10, jl_method_instance_type); - jl_svecset(jl_method_type->types, 11, jl_method_instance_type); jl_svecset(jl_method_instance_type->types, 12, jl_voidpointer_type); jl_svecset(jl_method_instance_type->types, 13, jl_voidpointer_type); jl_svecset(jl_method_instance_type->types, 14, jl_voidpointer_type); diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 9afa174c4b19f6..f57a8c90be1cd0 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -286,9 +286,35 @@ (map (lambda (x) (replace-outer-vars x renames)) (cdr e)))))) +(define (make-generator-function name sp-names arg-names body) + (let ((arg-names (append sp-names + (map (lambda (n) + (if (eq? n '|#self#|) (gensy) n)) + arg-names)))) + (let ((body (insert-after-meta body ;; don't specialize on generator arguments + `((meta nospecialize ,@arg-names))))) + `(block + (global ,name) + (function (call ,name ,@arg-names) ,body))))) + +;; select the `then` or `else` part of `if @generated` based on flag `genpart` +(define (generated-part- x genpart) + (cond ((or (atom? x) (quoted? x) (function-def? x)) x) + ((if-generated? x) + (if genpart `($ ,(caddr x)) (cadddr x))) + (else (cons (car x) + (map (lambda (e) (generated-part- e genpart)) (cdr x)))))) + +(define (generated-version body) + `(block + ,(julia-bq-macro (generated-part- body #t)))) + +(define (non-generated-version body) + (generated-part- body #f)) + ;; construct the (method ...) expression for one primitive method definition, ;; assuming optional and keyword args are already handled -(define (method-def-expr- name sparams argl body isstaged (rett '(core Any))) +(define (method-def-expr- name sparams argl body (rett '(core Any))) (if (any kwarg? argl) ;; has optional positional args @@ -307,20 +333,39 @@ (dfl (map caddr kws))) (receive (vararg req) (separate vararg? argl) - (optional-positional-defs name sparams req opt dfl body isstaged + (optional-positional-defs name sparams req opt dfl body (append req opt vararg) rett))))) ;; no optional positional args - (let ((names (map car sparams))) - (let ((anames (llist-vars argl))) - (if (has-dups (filter (lambda (x) (not (eq? x UNUSED))) anames)) - (error "function argument names not unique")) - (if (has-dups names) - (error "function static parameter names not unique")) - (if (any (lambda (x) (and (not (eq? x UNUSED)) (memq x names))) anames) - (error "function argument and static parameter names must be distinct"))) + (let ((names (map car sparams)) + (anames (llist-vars argl))) + (if (has-dups (filter (lambda (x) (not (eq? x UNUSED))) anames)) + (error "function argument names not unique")) + (if (has-dups names) + (error "function static parameter names not unique")) + (if (any (lambda (x) (and (not (eq? x UNUSED)) (memq x names))) anames) + (error "function argument and static parameter names must be distinct")) (if (or (and name (not (sym-ref? name))) (eq? name 'true) (eq? name 'false)) (error (string "invalid function name \"" (deparse name) "\""))) - (let* ((types (llist-types argl)) + (let* ((generator (if (expr-contains-p if-generated? body (lambda (x) (not (function-def? x)))) + (let* ((gen (generated-version body)) + (nongen (non-generated-version body)) + (gname (symbol (string (gensy) "#" (current-julia-module-counter)))) + (gf (make-generator-function gname names (llist-vars argl) gen)) + (loc (function-body-lineno body))) + (set! body (insert-after-meta + nongen + `((meta generated + (new (core GeneratedFunctionStub) + ,gname + ,(cons 'list anames) + ,(if (null? sparams) + 'nothing + (cons 'list (map car sparams))) + ,(if (null? loc) 0 (cadr loc)) + (inert ,(if (null? loc) 'none (caddr loc)))))))) + (list gf)) + '())) + (types (llist-types argl)) (body (method-lambda-expr argl body rett)) ;; HACK: the typevars need to be bound to ssavalues, since this code ;; might be moved to a different scope by closure-convert. @@ -329,7 +374,7 @@ (mdef (if (null? sparams) `(method ,name (call (core svec) (call (core svec) ,@(dots->vararg types)) (call (core svec))) - ,body ,isstaged) + ,body) `(method ,name (block ,@(let loop ((n names) @@ -350,10 +395,12 @@ (replace-vars ty renames)) types))) (call (core svec) ,@temps))) - ,body ,isstaged)))) + ,body)))) (if (or (symbol? name) (globalref? name)) - `(block (method ,name) ,mdef (unnecessary ,name)) ;; return the function - mdef))))) + `(block ,@generator (method ,name) ,mdef (unnecessary ,name)) ;; return the function + (if (not (null? generator)) + `(block ,@generator ,mdef) + mdef)))))) ;; wrap expr in nested scopes assigning names to vals (define (scopenest names vals expr) @@ -365,10 +412,8 @@ (define empty-vector-any '(call (core AnyVector) 0)) -(define (keywords-method-def-expr name sparams argl body isstaged rett) +(define (keywords-method-def-expr name sparams argl body rett) (let* ((kargl (cdar argl)) ;; keyword expressions (= k v) - (annotations (map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a))))) - (filter nospecialize-meta? kargl))) (kargl (map (lambda (a) (if (nospecialize-meta? a) (caddr a) a)) kargl)) @@ -403,6 +448,8 @@ keynames)) ;; list of function's initial line number and meta nodes (empty if none) (prologue (extract-method-prologue body)) + (annotations (map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a))))) + (filter nospecialize-meta? kargl))) ;; body statements (stmts (cdr body)) (positional-sparams @@ -427,7 +474,7 @@ ,(method-def-expr- name positional-sparams (append pargl vararg) `(block - ,@prologue + ,@(without-generated prologue) ,(let (;; call mangled(vals..., [rest_kw,] pargs..., [vararg]...) (ret `(return (call ,mangled ,@(if ordered-defaults keynames vals) @@ -437,8 +484,7 @@ (list `(... ,(arg-name (car vararg))))))))) (if ordered-defaults (scopenest keynames vals ret) - ret))) - #f) + ret)))) ;; call with keyword args pre-sorted - original method code goes here ,(method-def-expr- @@ -457,7 +503,7 @@ (insert-after-meta `(block ,@stmts) annotations) - isstaged rett) + rett) ;; call with unsorted keyword args. this sorts and re-dispatches. ,(method-def-expr- @@ -539,8 +585,7 @@ ,@(if (null? restkw) '() (list rkw)) ,@(map arg-name pargl) ,@(if (null? vararg) '() - (list `(... ,(arg-name (car vararg))))))))) - #f) + (list `(... ,(arg-name (car vararg)))))))))) ;; return primary function ,(if (not (symbol? name)) '(null) name))))) @@ -553,6 +598,11 @@ (cdr body)) '())) +(define (without-generated stmts) + (filter (lambda (x) (not (or (generated-meta? x) + (generated_only-meta? x)))) + stmts)) + ;; keep only sparams used by `expr` or other sparams (define (filter-sparams expr sparams) (let loop ((filtered '()) @@ -566,8 +616,8 @@ (else (loop filtered (cdr params)))))) -(define (optional-positional-defs name sparams req opt dfl body isstaged overall-argl rett) - (let ((prologue (extract-method-prologue body))) +(define (optional-positional-defs name sparams req opt dfl body overall-argl rett) + (let ((prologue (without-generated (extract-method-prologue body)))) `(block ,@(map (lambda (n) (let* ((passed (append req (list-head opt n))) @@ -596,9 +646,9 @@ `(block ,@prologue (call ,(arg-name (car req)) ,@(map arg-name (cdr passed)) ,@vals))))) - (method-def-expr- name sp passed body #f))) + (method-def-expr- name sp passed body))) (iota (length opt))) - ,(method-def-expr- name sparams overall-argl body isstaged rett)))) + ,(method-def-expr- name sparams overall-argl body rett)))) ;; strip empty (parameters ...), normalizing `f(x;)` to `f(x)`. (define (remove-empty-parameters argl) @@ -627,14 +677,14 @@ ;; definitions without keyword arguments are passed to method-def-expr-, ;; which handles optional positional arguments by adding the needed small ;; boilerplate definitions. -(define (method-def-expr name sparams argl body isstaged rett) +(define (method-def-expr name sparams argl body rett) (let ((argl (remove-empty-parameters argl))) (if (has-parameters? argl) ;; has keywords (begin (check-kw-args (cdar argl)) - (keywords-method-def-expr name sparams argl body isstaged rett)) + (keywords-method-def-expr name sparams argl body rett)) ;; no keywords - (method-def-expr- name sparams argl body isstaged rett)))) + (method-def-expr- name sparams argl body rett)))) (define (struct-def-expr name params super fields mut) (receive @@ -763,12 +813,12 @@ ,@sig) new-params))))) -(define (ctor-def keyword name Tname params bounds sig ctor-body body wheres) +(define (ctor-def name Tname params bounds sig ctor-body body wheres) (let* ((curly? (and (pair? name) (eq? (car name) 'curly))) (curlyargs (if curly? (cddr name) '())) (name (if curly? (cadr name) name))) (cond ((not (eq? name Tname)) - `(,keyword ,(with-wheres `(call ,(if curly? + `(function ,(with-wheres `(call ,(if curly? `(curly ,name ,@curlyargs) name) ,@sig) @@ -777,7 +827,7 @@ ;; new{...} inside a non-ctor inner definition. ,(ctor-body body '()))) (wheres - `(,keyword ,(with-wheres `(call ,(if curly? + `(function ,(with-wheres `(call ,(if curly? `(curly ,name ,@curlyargs) name) ,@sig) @@ -791,7 +841,7 @@ (syntax-deprecation #f (string "inner constructor " name "(...)" (linenode-string (function-body-lineno body))) (deparse `(where (call (curly ,name ,@params) ...) ,@params)))) - `(,keyword ,sig ,(ctor-body body params))))))) + `(function ,sig ,(ctor-body body params))))))) (define (function-body-lineno body) (let ((lnos (filter (lambda (e) (and (pair? e) (eq? (car e) 'line))) @@ -818,18 +868,14 @@ (pattern-set ;; definitions without `where` (pattern-lambda (function (-$ (call name . sig) (|::| (call name . sig) _t)) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body #f)) - (pattern-lambda (stagedfunction (-$ (call name . sig) (|::| (call name . sig) _t)) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body #f)) + (ctor-def name Tname params bounds sig ctor-body body #f)) (pattern-lambda (= (-$ (call name . sig) (|::| (call name . sig) _t)) body) - (ctor-def 'function name Tname params bounds sig ctor-body body #f)) + (ctor-def name Tname params bounds sig ctor-body body #f)) ;; definitions with `where` (pattern-lambda (function (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body wheres)) - (pattern-lambda (stagedfunction (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body wheres)) + (ctor-def name Tname params bounds sig ctor-body body wheres)) (pattern-lambda (= (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) - (ctor-def 'function name Tname params bounds sig ctor-body body wheres))) + (ctor-def name Tname params bounds sig ctor-body body wheres))) ;; flatten `where`s first (pattern-replace @@ -970,7 +1016,7 @@ (loop (if isseq F (cdr F)) (cdr A) stmts (list* (if isamp `(& ,ca) ca) C) (list* g GC)))))))) -(define (expand-function-def e) ;; handle function or stagedfunction +(define (expand-function-def e) ;; handle function definitions (define (just-arglist? ex) (and (pair? ex) (or (memq (car ex) '(tuple block)) @@ -1054,7 +1100,6 @@ (where where) (else '()))) (sparams (map analyze-typevar raw-typevars)) - (isstaged (eq? (car e) 'stagedfunction)) (adj-decl (lambda (n) (if (and (decl? n) (length= n 2)) `(|::| |#self#| ,(cadr n)) n))) @@ -1083,7 +1128,7 @@ (cdr argl))) ,@raw-typevars)))) (expand-forms - (method-def-expr name sparams argl body isstaged rett)))) + (method-def-expr name sparams argl body rett)))) (else (error (string "invalid assignment location \"" (deparse name) "\"")))))) @@ -1888,7 +1933,6 @@ (define expand-table (table 'function expand-function-def - 'stagedfunction expand-function-def '-> expand-arrow 'let expand-let 'macro expand-macro-def @@ -3225,8 +3269,7 @@ f(x) = yt(x) ,@top-stmts (block ,@sp-inits (method ,name ,(cl-convert sig fname lam namemap toplevel interp) - ,(julia-bq-macro newlam) - ,(last e))))))) + ,(julia-bq-macro newlam))))))) ;; local case - lift to a new type at top level (let* ((exists (get namemap name #f)) (type-name (or exists @@ -3303,8 +3346,7 @@ f(x) = yt(x) (if iskw (caddr (lam:args lam2)) (car (lam:args lam2))) - #f closure-param-names) - ,(last e))))))) + #f closure-param-names))))))) (mk-closure ;; expression to make the closure (let* ((var-exprs (map (lambda (v) (let ((cv (assq v (cadr (lam:vinfo lam))))) @@ -3750,8 +3792,7 @@ f(x) = yt(x) (if (length> e 2) (begin (emit `(method ,(or (cadr e) 'false) ,(compile (caddr e) break-labels #t #f) - ,(linearize (cadddr e)) - ,(if (car (cddddr e)) 'true 'false))) + ,(linearize (cadddr e)))) (if value (compile '(null) break-labels value tail))) (cond (tail (emit-return e)) (value e) diff --git a/src/julia.h b/src/julia.h index 95c091affb5eea..32a7d2a62f4b89 100644 --- a/src/julia.h +++ b/src/julia.h @@ -248,7 +248,7 @@ typedef struct _jl_method_t { jl_svec_t *sparam_syms; // symbols giving static parameter names jl_value_t *source; // original code template (jl_code_info_t, but may be compressed), null for builtins struct _jl_method_instance_t *unspecialized; // unspecialized executable method instance, or null - struct _jl_method_instance_t *generator; // executable code-generating function if available + jl_value_t *generator; // executable code-generating function if available jl_array_t *roots; // pointers in generated code (shared to reduce memory), or null // cache of specializations of this method for invoke(), i.e. @@ -1055,7 +1055,7 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name, jl_module_t *module, jl_value_t **bp, jl_value_t *bp_owner, jl_binding_t *bnd); -JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, jl_module_t *module, jl_value_t *isstaged); +JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, jl_module_t *module); JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo); JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src); JL_DLLEXPORT size_t jl_get_world_counter(void); diff --git a/src/julia_internal.h b/src/julia_internal.h index 4093a9ffe050de..afb14cd16f190e 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1000,6 +1000,8 @@ extern jl_sym_t *isdefined_sym; extern jl_sym_t *nospecialize_sym; extern jl_sym_t *boundscheck_sym; extern jl_sym_t *gc_preserve_begin_sym; extern jl_sym_t *gc_preserve_end_sym; +extern jl_sym_t *generated_sym; +extern jl_sym_t *generated_only_sym; struct _jl_sysimg_fptrs_t; diff --git a/src/macroexpand.scm b/src/macroexpand.scm index b5b41976015541..9f99d4e23917f5 100644 --- a/src/macroexpand.scm +++ b/src/macroexpand.scm @@ -400,11 +400,6 @@ (apply append (map decl-vars* (cdr e))) (list (decl-var* e)))) -(define (function-def? e) - (and (pair? e) (or (eq? (car e) 'function) (eq? (car e) '->) - (and (eq? (car e) '=) (length= e 3) - (eventually-call? (cadr e)))))) - ;; count hygienic / escape pairs ;; and fold together a list resulting from applying the function to ;; any block at the same hygienic scope diff --git a/src/method.c b/src/method.c index 40c027254743d6..a984cb1ad9645c 100644 --- a/src/method.c +++ b/src/method.c @@ -247,24 +247,23 @@ jl_code_info_t *jl_new_code_info_from_ast(jl_expr_t *ast) } // invoke (compiling if necessary) the jlcall function pointer for a method template -STATIC_INLINE jl_value_t *jl_call_staged(jl_svec_t *sparam_vals, jl_method_instance_t *generator, +STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, jl_svec_t *sparam_vals, jl_value_t **args, uint32_t nargs) { - jl_generic_fptr_t fptr; - fptr.fptr = generator->fptr; - fptr.jlcall_api = generator->jlcall_api; - if (__unlikely(fptr.fptr == NULL || fptr.jlcall_api == 0)) { - size_t world = generator->def.method->min_world; - const char *F = jl_compile_linfo(&generator, (jl_code_info_t*)generator->inferred, world, &jl_default_cgparams).functionObject; - fptr = jl_generate_fptr(generator, F, world); + size_t n_sparams = jl_svec_len(sparam_vals); + jl_value_t **gargs; + size_t totargs = 1 + n_sparams + nargs + def->isva; + JL_GC_PUSHARGS(gargs, totargs); + gargs[0] = generator; + memcpy(&gargs[1], jl_svec_data(sparam_vals), n_sparams * sizeof(void*)); + memcpy(&gargs[1 + n_sparams], args, nargs * sizeof(void*)); + if (def->isva) { + gargs[totargs-1] = jl_f_tuple(NULL, &gargs[1 + n_sparams + def->nargs - 1], nargs - (def->nargs - 1)); + gargs[1 + n_sparams + def->nargs - 1] = gargs[totargs - 1]; } - assert(jl_svec_len(generator->def.method->sparam_syms) == jl_svec_len(sparam_vals)); - if (fptr.jlcall_api == 1) - return fptr.fptr1(args[0], &args[1], nargs-1); - else if (fptr.jlcall_api == 3) - return fptr.fptr3(sparam_vals, args[0], &args[1], nargs-1); - else - abort(); // shouldn't have inferred any other calling convention + jl_value_t *code = jl_apply(gargs, 1 + n_sparams + def->nargs); + JL_GC_POP(); + return code; } // return a newly allocated CodeInfo for the function signature @@ -273,71 +272,34 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) { JL_TIMING(STAGED_FUNCTION); jl_tupletype_t *tt = (jl_tupletype_t*)linfo->specTypes; - jl_svec_t *env = linfo->sparam_vals; - jl_expr_t *ex = NULL; - jl_value_t *linenum = NULL; - jl_svec_t *sparam_vals = env; - jl_method_instance_t *generator = linfo->def.method->generator; + jl_method_t *def = linfo->def.method; + jl_value_t *generator = def->generator; assert(generator != NULL); - assert(linfo != generator); + assert(jl_is_method(def)); jl_code_info_t *func = NULL; - JL_GC_PUSH4(&ex, &linenum, &sparam_vals, &func); + jl_value_t *ex = NULL; + JL_GC_PUSH2(&ex, &func); jl_ptls_t ptls = jl_get_ptls_states(); int last_lineno = jl_lineno; int last_in = ptls->in_pure_callback; jl_module_t *last_m = ptls->current_module; jl_module_t *task_last_m = ptls->current_task->current_module; size_t last_age = jl_get_ptls_states()->world_age; - assert(jl_svec_len(linfo->def.method->sparam_syms) == jl_svec_len(sparam_vals)); + JL_TRY { ptls->in_pure_callback = 1; // need to eval macros in the right module ptls->current_task->current_module = ptls->current_module = linfo->def.method->module; // and the right world - ptls->world_age = generator->def.method->min_world; - - ex = jl_exprn(lambda_sym, 2); - - jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs); - jl_array_ptr_set(ex->args, 0, argnames); - jl_fill_argnames((jl_array_t*)generator->inferred, argnames); - - // build the rest of the body to pass to expand - jl_expr_t *scopeblock = jl_exprn(jl_symbol("scope-block"), 1); - jl_array_ptr_set(ex->args, 1, scopeblock); - jl_expr_t *body = jl_exprn(jl_symbol("block"), 3); - jl_array_ptr_set(((jl_expr_t*)jl_exprarg(ex, 1))->args, 0, body); - - // add location meta - linenum = jl_box_long(linfo->def.method->line); - jl_value_t *linenode = jl_new_struct(jl_linenumbernode_type, linenum, linfo->def.method->file); - jl_array_ptr_set(body->args, 0, linenode); - jl_expr_t *pushloc = jl_exprn(meta_sym, 3); - jl_array_ptr_set(body->args, 1, pushloc); - jl_array_ptr_set(pushloc->args, 0, jl_symbol("push_loc")); - jl_array_ptr_set(pushloc->args, 1, linfo->def.method->file); // file - jl_array_ptr_set(pushloc->args, 2, jl_symbol("@generated body")); // function + ptls->world_age = def->min_world; // invoke code generator - assert(jl_nparams(tt) == jl_array_len(argnames) || - (linfo->def.method->isva && (jl_nparams(tt) >= jl_array_len(argnames) - 1))); - jl_value_t *generated_body = jl_call_staged(sparam_vals, generator, jl_svec_data(tt->parameters), jl_nparams(tt)); - jl_array_ptr_set(body->args, 2, generated_body); - - if (jl_is_code_info(generated_body)) { - func = (jl_code_info_t*)generated_body; - } else { - if (linfo->def.method->sparam_syms != jl_emptysvec) { - // mark this function as having the same static parameters as the generator - size_t i, nsp = jl_svec_len(linfo->def.method->sparam_syms); - jl_expr_t *newast = jl_exprn(jl_symbol("with-static-parameters"), nsp + 1); - jl_exprarg(newast, 0) = (jl_value_t*)ex; - // (with-static-parameters func_expr sp_1 sp_2 ...) - for (i = 0; i < nsp; i++) - jl_exprarg(newast, i+1) = jl_svecref(linfo->def.method->sparam_syms, i); - ex = newast; - } + jl_value_t *ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt)); + if (jl_is_code_info(ex)) { + func = (jl_code_info_t*)ex; + } + else { func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, linfo->def.method->module); if (!jl_is_code_info(func)) { if (jl_is_expr(func) && ((jl_expr_t*)func)->head == error_sym) @@ -349,15 +311,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) size_t i, l; for (i = 0, l = jl_array_len(stmts); i < l; i++) { jl_value_t *stmt = jl_array_ptr_ref(stmts, i); - stmt = jl_resolve_globals(stmt, linfo->def.method->module, env); + stmt = jl_resolve_globals(stmt, linfo->def.method->module, linfo->sparam_vals); jl_array_ptr_set(stmts, i, stmt); } - - // add pop_loc meta - jl_array_ptr_1d_push(stmts, jl_nothing); - jl_expr_t *poploc = jl_exprn(meta_sym, 1); - jl_array_ptr_set(stmts, jl_array_len(stmts) - 1, poploc); - jl_array_ptr_set(poploc->args, 0, jl_symbol("pop_loc")); } ptls->in_pure_callback = last_in; @@ -404,6 +360,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) { uint8_t j; uint8_t called = 0; + int gen_only = 0; for (j = 1; j < m->nargs && j <= 8; j++) { jl_value_t *ai = jl_array_ptr_ref(src->slotnames, j); if (ai == (jl_value_t*)unused_sym) @@ -434,28 +391,50 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) set_lineno = 1; } } - else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == meta_sym && - jl_expr_nargs(st) > 1 && jl_exprarg(st, 0) == (jl_value_t*)nospecialize_sym) { - for (size_t j=1; j < jl_expr_nargs(st); j++) { - jl_value_t *aj = jl_exprarg(st, j); - if (jl_is_slot(aj)) { - int sn = (int)jl_slot_number(aj) - 2; - if (sn >= 0) { // @nospecialize on self is valid but currently ignored - if (sn > (m->nargs - 2)) { - jl_error("@nospecialize annotation applied to a non-argument"); - } - else if (sn >= sizeof(m->nospecialize) * 8) { - jl_printf(JL_STDERR, - "WARNING: @nospecialize annotation only supported on the first %d arguments.\n", - (int)(sizeof(m->nospecialize) * 8)); - } - else { - m->nospecialize |= (1 << sn); + else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == meta_sym) { + if (jl_expr_nargs(st) > 1 && jl_exprarg(st, 0) == (jl_value_t*)nospecialize_sym) { + for (size_t j=1; j < jl_expr_nargs(st); j++) { + jl_value_t *aj = jl_exprarg(st, j); + if (jl_is_slot(aj)) { + int sn = (int)jl_slot_number(aj) - 2; + if (sn >= 0) { // @nospecialize on self is valid but currently ignored + if (sn > (m->nargs - 2)) { + jl_error("@nospecialize annotation applied to a non-argument"); + } + else if (sn >= sizeof(m->nospecialize) * 8) { + jl_printf(JL_STDERR, + "WARNING: @nospecialize annotation only supported on the first %d arguments.\n", + (int)(sizeof(m->nospecialize) * 8)); + } + else { + m->nospecialize |= (1 << sn); + } } } } + st = jl_nothing; + } + else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generated_sym) { + m->generator = NULL; + jl_value_t *gexpr = jl_exprarg(st, 1); + if (jl_expr_nargs(gexpr) == 6) { + // expects (new (core GeneratedFunctionStub) funcname argnames sp line file) + jl_value_t *funcname = jl_exprarg(gexpr, 1); + assert(jl_is_symbol(funcname)); + if (jl_get_global(m->module, (jl_sym_t*)funcname) != NULL) { + m->generator = jl_toplevel_eval(m->module, gexpr); + jl_gc_wb(m, m->generator); + } + } + if (m->generator == NULL) { + jl_error("invalid @generated function; try placing it in global scope"); + } + st = jl_nothing; + } + else if (jl_expr_nargs(st) == 1 && jl_exprarg(st, 0) == (jl_value_t*)generated_only_sym) { + gen_only = 1; + st = jl_nothing; } - st = jl_nothing; } else { st = jl_resolve_globals(st, m->module, sparam_vars); @@ -465,7 +444,10 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) src = jl_copy_code_info(src); src->code = copy; jl_gc_wb(src, copy); - m->source = (jl_value_t*)jl_compress_ast(m, src); + if (gen_only) + m->source = NULL; + else + m->source = (jl_value_t*)jl_compress_ast(m, src); jl_gc_wb(m, m->source); JL_GC_POP(); } @@ -506,8 +488,7 @@ static jl_method_t *jl_new_method( jl_tupletype_t *sig, size_t nargs, int isva, - jl_svec_t *tvars, - int isstaged) + jl_svec_t *tvars) { size_t i, l = jl_svec_len(tvars); jl_svec_t *sparam_syms = jl_alloc_svec_uninit(l); @@ -527,13 +508,6 @@ static jl_method_t *jl_new_method( m->isva = isva; m->nargs = nargs; jl_method_set_source(m, definition); - if (isstaged) { - // create and store generator for generated functions - m->generator = jl_get_specialized(m, (jl_value_t*)jl_anytuple_type, jl_emptysvec); - jl_gc_wb(m, m->generator); - m->generator->inferred = (jl_value_t*)m->source; - m->source = NULL; - } #ifdef RECORD_METHOD_ORDER if (jl_all_methods == NULL) @@ -653,8 +627,7 @@ extern tracer_cb jl_newmeth_tracer; JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, - jl_module_t *module, - jl_value_t *isstaged) + jl_module_t *module) { // argdata is svec(svec(types...), svec(typevars...)) jl_svec_t *atypes = (jl_svec_t*)jl_svecref(argdata, 0); @@ -711,7 +684,7 @@ JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, // the result is that the closure variables get interpolated directly into the AST f = jl_new_code_info_from_ast((jl_expr_t*)f); } - m = jl_new_method(f, name, module, (jl_tupletype_t*)argtype, nargs, isva, tvars, isstaged == jl_true); + m = jl_new_method(f, name, module, (jl_tupletype_t*)argtype, nargs, isva, tvars); m->nospecialize |= nospec; if (jl_has_free_typevars(argtype)) { diff --git a/src/utils.scm b/src/utils.scm index 97842a387b5446..211d79ffef7b54 100644 --- a/src/utils.scm +++ b/src/utils.scm @@ -40,12 +40,13 @@ (cdr expr))))) ;; same as above, with predicate -(define (expr-contains-p p expr) - (or (p expr) - (and (pair? expr) - (not (quoted? expr)) - (any (lambda (y) (expr-contains-p p y)) - (cdr expr))))) +(define (expr-contains-p p expr (filt (lambda (x) #t))) + (and (filt expr) + (or (p expr) + (and (pair? expr) + (not (quoted? expr)) + (any (lambda (y) (expr-contains-p p y filt)) + (cdr expr)))))) ;; find all subexprs satisfying `p`, applying `key` to each one (define (expr-find-all p expr key (filt (lambda (x) #t))) diff --git a/test/reflection.jl b/test/reflection.jl index 306456d78389d2..ae1dc07242acd9 100644 --- a/test/reflection.jl +++ b/test/reflection.jl @@ -761,7 +761,7 @@ world = typemax(UInt) mtypes, msp, m = Base._methods_by_ftype(T22979, -1, world)[] instance = Core.Inference.code_for_method(m, mtypes, msp, world, false) cinfo_generated = Core.Inference.get_staged(instance) -cinfo_ungenerated = Base.uncompressed_ast(m) +@test_throws ErrorException Base.uncompressed_ast(m) test_similar_codeinfo(@code_lowered(f22979(x22979...)), cinfo_generated) @@ -770,7 +770,4 @@ cinfos = code_lowered(f22979, typeof.(x22979), true) cinfo = cinfos[] test_similar_codeinfo(cinfo, cinfo_generated) -cinfos = code_lowered(f22979, typeof.(x22979), false) -@test length(cinfos) == 1 -cinfo = cinfos[] -test_similar_codeinfo(cinfo, cinfo_ungenerated) +@test_throws ErrorException code_lowered(f22979, typeof.(x22979), false) diff --git a/test/staged.jl b/test/staged.jl index 7bcd41a478ced5..fca2d876ba91b1 100644 --- a/test/staged.jl +++ b/test/staged.jl @@ -250,3 +250,31 @@ end @test f22440(0.0) === f22440kernel(0.0) @test f22440(0.0f0) === f22440kernel(0.0f0) @test f22440(0) === f22440kernel(0) + +# PR #23168 + +function f23168(a, x) + push!(a, 1) + if @generated + :(y = x + x) + else + y = 2x + end + push!(a, y) + if @generated + :(y = (y, $x)) + else + y = (y, typeof(x)) + end + push!(a, 3) + return y +end + +let a = Any[] + @test f23168(a, 3) == (6, Int) + @test a == [1, 6, 3] + @test contains(string(code_lowered(f23168, (Vector{Any},Int))), "x + x") + @test contains(string(Base.uncompressed_ast(first(methods(f23168)))), "2 * x") + @test contains(string(code_lowered(f23168, (Vector{Any},Int), false)), "2 * x") + @test contains(string(code_typed(f23168, (Vector{Any},Int))), "(Base.add_int)(x, x)") +end