From 01a804b39068090bb120911e8c657fe666198cd3 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Tue, 31 Jan 2017 12:43:46 -0500 Subject: [PATCH] add keyword argument support to `invoke` fixes #7045 --- src/builtin_proto.h | 2 ++ src/builtins.c | 56 ++++++++++++++++++++++++++++++++++++++++++++ src/codegen.cpp | 1 + src/dump.c | 2 +- src/gf.c | 4 +--- src/julia_internal.h | 1 + test/keywordargs.jl | 8 +++++++ 7 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/builtin_proto.h b/src/builtin_proto.h index 0cd436157fc7d..0ac2a70c4b19b 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -32,6 +32,8 @@ DECLARE_BUILTIN(apply_type); DECLARE_BUILTIN(applicable); DECLARE_BUILTIN(invoke); DECLARE_BUILTIN(_expr); DECLARE_BUILTIN(typeassert); +JL_CALLABLE(jl_f_invoke_kwsorter); + #ifdef __cplusplus } #endif diff --git a/src/builtins.c b/src/builtins.c index 8713e81e34d83..f5f6d7bb6a239 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1116,6 +1116,56 @@ JL_CALLABLE(jl_f_invoke) return res; } +JL_CALLABLE(jl_f_invoke_kwsorter) +{ + JL_NARGSV(invoke, 3); + jl_value_t *kwargs = args[0]; + // args[1] is `invoke` itself + jl_value_t *func = args[2]; + jl_value_t *argtypes = args[3]; + jl_value_t *kws = jl_get_keyword_sorter(func); + JL_GC_PUSH1(&argtypes); + if (jl_is_tuple(argtypes)) { + jl_depwarn("`invoke(f, (types...), ...)` is deprecated, " + "use `invoke(f, Tuple{types...}, ...)` instead", + (jl_value_t*)jl_symbol("invoke")); + argtypes = (jl_value_t*)jl_apply_tuple_type_v((jl_value_t**)jl_data_ptr(argtypes), + jl_nfields(argtypes)); + } + if (jl_is_tuple_type(argtypes)) { + // construct a tuple type for invoking a keyword sorter by putting `Vector{Any}` + // and the type of the function at the front. + size_t i, nt = jl_nparams(argtypes) + 2; + if (nt < jl_page_size/sizeof(jl_value_t*)) { + jl_value_t **types = (jl_value_t**)alloca(nt*sizeof(jl_value_t*)); + types[0] = jl_array_any_type; types[1] = jl_typeof(func); + for(i=2; i < nt; i++) + types[i] = jl_tparam(argtypes,i-2); + argtypes = (jl_value_t*)jl_apply_tuple_type_v(types, nt); + } + else { + jl_svec_t *types = jl_alloc_svec_uninit(nt); + JL_GC_PUSH1(&types); + jl_svecset(types, 0, jl_array_any_type); + jl_svecset(types, 1, jl_typeof(func)); + for(i=2; i < nt; i++) + jl_svecset(types, i, jl_tparam(argtypes,i-2)); + argtypes = (jl_value_t*)jl_apply_tuple_type(types); + JL_GC_POP(); + } + } + else { + // invoke will throw an error + } + args[0] = kws; + args[1] = argtypes; + args[2] = kwargs; + args[3] = func; + jl_value_t *res = jl_f_invoke(NULL, args, nargs); + JL_GC_POP(); + return res; +} + // Expr constructor for internal use ------------------------------------------ jl_expr_t *jl_exprn(jl_sym_t *head, size_t n) @@ -1288,6 +1338,12 @@ void jl_init_primitives(void) // method table utils add_builtin_func("applicable", jl_f_applicable); add_builtin_func("invoke", jl_f_invoke); + jl_value_t *invokef = jl_get_global(jl_core_module, jl_symbol("invoke")); + jl_typename_t *itn = ((jl_datatype_t*)jl_typeof(invokef))->name; + jl_value_t *ikws = jl_new_generic_function_with_supertype(itn->name, jl_core_module, jl_builtin_type, 1); + itn->mt->kwsorter = ikws; + jl_gc_wb(itn->mt, ikws); + jl_mk_builtin_func((jl_datatype_t*)jl_typeof(ikws), jl_symbol_name(jl_gf_name(ikws)), jl_f_invoke_kwsorter); // internal functions add_builtin_func("apply_type", jl_f_apply_type); diff --git a/src/codegen.cpp b/src/codegen.cpp index d51e3c03f5f7b..ba1ef2500e895 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -5898,6 +5898,7 @@ static void init_julia_llvm_env(Module *m) builtin_func_map[jl_f_svec] = jlcall_func_to_llvm("jl_f_svec", &jl_f_svec, m); builtin_func_map[jl_f_applicable] = jlcall_func_to_llvm("jl_f_applicable", &jl_f_applicable, m); builtin_func_map[jl_f_invoke] = jlcall_func_to_llvm("jl_f_invoke", &jl_f_invoke, m); + builtin_func_map[jl_f_invoke_kwsorter] = jlcall_func_to_llvm("jl_f_invoke_kwsorter", &jl_f_invoke_kwsorter, m); builtin_func_map[jl_f_isdefined] = jlcall_func_to_llvm("jl_f_isdefined", &jl_f_isdefined, m); builtin_func_map[jl_f_getfield] = jlcall_func_to_llvm("jl_f_getfield", &jl_f_getfield, m); builtin_func_map[jl_f_setfield] = jlcall_func_to_llvm("jl_f_setfield", &jl_f_setfield, m); diff --git a/src/dump.c b/src/dump.c index b794a76231b32..5f6403a802b58 100644 --- a/src/dump.c +++ b/src/dump.c @@ -74,7 +74,7 @@ static const jl_fptr_t id_to_fptrs[] = { NULL, NULL, jl_f_throw, jl_f_is, jl_f_typeof, jl_f_issubtype, jl_f_isa, jl_f_typeassert, jl_f__apply, jl_f__apply_pure, jl_f_isdefined, - jl_f_tuple, jl_f_svec, jl_f_intrinsic_call, + jl_f_tuple, jl_f_svec, jl_f_intrinsic_call, jl_f_invoke_kwsorter, jl_f_getfield, jl_f_setfield, jl_f_fieldtype, jl_f_nfields, jl_f_arrayref, jl_f_arrayset, jl_f_arraysize, jl_f_apply_type, jl_f_applicable, jl_f_invoke, jl_unprotect_stack, jl_f_sizeof, jl_f__expr, diff --git a/src/gf.c b/src/gf.c index fd6abc181c2f8..d4cd133de1a6a 100644 --- a/src/gf.c +++ b/src/gf.c @@ -198,8 +198,6 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_tupletype_t // ----- MethodInstance specialization instantiation ----- // JL_DLLEXPORT jl_method_t *jl_new_method_uninit(void); -static jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, - jl_module_t *module, jl_datatype_t *st, int iskw); void jl_mk_builtin_func(jl_datatype_t *dt, const char *name, jl_fptr_t fptr) { jl_sym_t *sname = jl_symbol(name); @@ -2377,7 +2375,7 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt, } // Return value is rooted globally -static jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw) +jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw) { // type name is function name prefixed with # size_t l = strlen(jl_symbol_name(name)); diff --git a/src/julia_internal.h b/src/julia_internal.h index f7aacd972661e..f735aae68fa4b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -349,6 +349,7 @@ jl_value_t *jl_wrap_vararg(jl_value_t *t, jl_value_t *n); void jl_assign_bits(void *dest, jl_value_t *bits); jl_expr_t *jl_exprn(jl_sym_t *head, size_t n); jl_function_t *jl_new_generic_function(jl_sym_t *name, jl_module_t *module); +jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw); jl_function_t *jl_module_call_func(jl_module_t *m); int jl_is_submodule(jl_module_t *child, jl_module_t *parent); diff --git a/test/keywordargs.jl b/test/keywordargs.jl index d3dc9a55a3738..0e1057e80155d 100644 --- a/test/keywordargs.jl +++ b/test/keywordargs.jl @@ -240,3 +240,11 @@ eval(Core.Inference, quote g18396(;x=1,y=2) = x+y end) @test Core.Inference.f18396() == 3 + +# issue #7045, `invoke` with keyword args +f7045(x::Float64; y=true) = y ? 1 : invoke(f7045,Tuple{Real},x,y=y) +f7045(x::Real; y=true) = y ? 2 : 3 +@test f7045(1) === 2 +@test f7045(1.0) === 1 +@test f7045(1, y=false) === 3 +@test f7045(1.0, y=false) === 3