Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Julep/Very WIP - Heap allocated immutable arrays and compiler support #31630

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,20 @@ function vect(X...)
return copyto!(Vector{T}(undef, length(X)), X)
end

size(a::Array, d::Integer) = arraysize(a, convert(Int, d))
size(a::Vector) = (arraysize(a,1),)
size(a::Matrix) = (arraysize(a,1), arraysize(a,2))
size(a::Array{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val(N)))
const ImmutableArray = Core.ImmutableArray
const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
const IMVector{T} = IMArray{T, 1}
const IMMatrix{T} = IMArray{T, 2}

asize_from(a::Array, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...)
"""
    freeze(a::Array)

Return an `ImmutableArray` produced from `a` by forwarding to the
`Core.arrayfreeze` builtin.
"""
function freeze(a::Array)
    return Core.arrayfreeze(a)
end
"""
    melt(a::ImmutableArray)

Return an `Array` produced from `a` by forwarding to the `Core.arraymelt`
builtin.
"""
function melt(a::ImmutableArray)
    return Core.arraymelt(a)
end

# `size` for both builtin array flavours (`Array` and `ImmutableArray`).
# Every method bottoms out in the `arraysize` builtin.

# Extent along dimension `d`; `d` is converted to `Int` before the builtin call.
function size(a::IMArray, d::Integer)
    return arraysize(a, convert(Int, d))
end

# One-dimensional case: a 1-tuple.
function size(a::IMVector)
    return (arraysize(a, 1),)
end

# Two-dimensional case: a 2-tuple.
function size(a::IMMatrix)
    return (arraysize(a, 1), arraysize(a, 2))
end

# General N-dimensional case: build the tuple one dimension at a time.
function size(a::IMArray{<:Any,N}) where {N}
    @_inline_meta
    return ntuple(M -> size(a, M), Val(N))
end

# Tuple of the extents of `a` for dimensions `n`, `n + 1`, ..., `ndims(a)`;
# the empty tuple once `n` exceeds `ndims(a)`.
function asize_from(a::IMArray, n)
    n > ndims(a) && return ()
    return (arraysize(a, n), asize_from(a, n + 1)...)
end

"""
Base.isbitsunion(::Type{T})
Expand Down Expand Up @@ -208,6 +216,13 @@ function isassigned(a::Array, i::Int...)
ccall(:jl_array_isassigned, Cint, (Any, UInt), a, ii) == 1
end

# Report whether the element of `a` at the cartesian index `i...` is assigned.
# The index is linearized against `size(a)` and turned into a zero-based
# unsigned offset; when bounds checking is enabled, an offset that is not
# below `length(a)` yields `false` instead of reaching the runtime query.
function isassigned(a::ImmutableArray, i::Int...)
    @_inline_meta
    offset = (_sub2ind(size(a), i...) % UInt) - 1
    @boundscheck offset < length(a) % UInt || return false
    return ccall(:jl_array_isassigned, Cint, (Any, UInt), a, offset) == 1
end

## copy ##

"""
Expand Down Expand Up @@ -728,6 +743,9 @@ function getindex end
@eval getindex(A::Array, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1)
@eval getindex(A::Array, i1::Int, i2::Int, I::Int...) = (@_inline_meta; arrayref($(Expr(:boundscheck)), A, i1, i2, I...))

# Scalar indexing for `ImmutableArray`, delegating to the `arrayref` builtin.
# `@eval` is used so that an `Expr(:boundscheck)` node can be spliced in as the
# first argument of `arrayref`. The multi-index method additionally carries
# `@_inline_meta`.
@eval getindex(A::ImmutableArray, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1)
@eval getindex(A::ImmutableArray, i1::Int, i2::Int, I::Int...) = (@_inline_meta; arrayref($(Expr(:boundscheck)), A, i1, i2, I...))

# Faster contiguous indexing using copyto! for UnitRange and Colon
function getindex(A::Array, I::UnitRange{Int})
@_inline_meta
Expand Down
7 changes: 7 additions & 0 deletions base/compiler/ssair/domtree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ function dominates(domtree::DomTree, bb1::Int, bb2::Int)
return bb1 == bb2
end

# Statement-level dominance query: does statement `ssa1` dominate statement
# `ssa2` in `ir`?  Within a single basic block the answer is decided by
# statement order (`ssa1 < ssa2`, so a statement does not dominate itself);
# across blocks it is decided by the block-level `dominates` query on `domtree`.
function ssadominates(ir::IRCode, domtree::DomTree, ssa1::Int, ssa2::Int)
    block1 = block_for_inst(ir.cfg, ssa1)
    block2 = block_for_inst(ir.cfg, ssa2)
    if block1 == block2
        return ssa1 < ssa2
    end
    return dominates(domtree, block1, block2)
end

bb_unreachable(domtree::DomTree, bb::Int) = bb != 1 && domtree.nodes[bb].level == 1

function update_level!(domtree::Vector{DomTreeNode}, node::Int, level::Int)
Expand Down
1 change: 1 addition & 0 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
#@Base.show ir.new_nodes
#@Base.show ("after_sroa", ir)
ir = adce_pass!(ir)
ir = memory_opt!(ir)
#@Base.show ("after_adce", ir)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
Expand Down
73 changes: 73 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1140,3 +1140,76 @@ function cfg_simplify!(ir::IRCode)
compact.active_result_bb = length(bb_starts)
return finish(compact)
end

# Return `true` when `stmt` is a `:foreigncall` expression whose callee (its
# first argument, optionally wrapped in a `QuoteNode`) is the symbol
# `:jl_alloc_array_1d`; return `false` for anything else.
function is_allocation(stmt)
    isexpr(stmt, :foreigncall) || return false
    callee = stmt.args[1]
    if isa(callee, QuoteNode)
        callee = callee.value
    end
    return callee === :jl_alloc_array_1d
end

# Rewrite `Core.arrayfreeze(a)` into `Core.mutating_arrayfreeze(a)` when `a` is
# an array allocated in this function (see `is_allocation`) that has not been
# marked as escaping and none of whose recorded uses is dominated by the freeze.
#
# The pass runs in two phases:
#   1. Walk the IR through an `IncrementalCompact`, collecting
#        - `relevant`: SSA ids of tracked allocations that have not escaped,
#        - `uses`:     for each tracked allocation, the statements that use it
#                      as the target of `arrayset` or as a returned value,
#        - `revisit`:  the `Core.arrayfreeze` calls to examine afterwards.
#   2. For every freeze in `revisit` whose argument is still `relevant`, switch
#      the callee if the freeze dominates none of the recorded uses.
#
# Returns the (compacted) `IRCode`.
function memory_opt!(ir::IRCode)
    # TODO: This is wrong if there's a use in a phi node.
    # NOTE(review): phi operands are now treated as escaping when they are
    # visible as `SSAValue`s at visit time; confirm whether operands that
    # refer to later statements are covered as well.
    compact = IncrementalCompact(ir, true)
    uses = IdDict{Int, Vector{Int}}()
    relevant = IdSet{Int}()
    revisit = Int[]
    # Mark `val` as escaping: it stops being a candidate for memory stealing.
    function mark_val(val)
        isa(val, SSAValue) || return
        val.id in relevant && pop!(relevant, val.id)
    end
    # Remember that statement `idx` uses the tracked allocation `val`.
    function record_use(val, idx)
        if isa(val, SSAValue) && val.id in relevant
            haskey(uses, val.id) || (uses[val.id] = Int[])
            push!(uses[val.id], idx)
        end
        return
    end
    for (idx, stmt) in compact
        if isa(stmt, ReturnNode)
            isdefined(stmt, :val) || continue
            record_use(stmt.val, idx)
            continue
        end
        if !(isexpr(stmt, :call) || isexpr(stmt, :foreigncall))
            # Any other kind of statement (`:invoke`, `:new`, pi/phi nodes, ...)
            # is not understood by this pass, so conservatively assume that
            # every value it references escapes.
            for ur in userefs(stmt)
                mark_val(ur[])
            end
            continue
        end
        if is_allocation(stmt)
            push!(relevant, idx)
            # TODO: Mark everything else here
            continue
        end
        # TODO: Replace this by interprocedural escape analysis
        if is_known_call(stmt, arrayset, compact)
            # The value being set escapes, everything else doesn't
            mark_val(stmt.args[4])
            record_use(stmt.args[3], idx)
        elseif is_known_call(stmt, Core.arrayfreeze, compact) &&
               length(stmt.args) == 2 && isa(stmt.args[2], SSAValue)
            push!(revisit, idx)
        else
            # For now we assume everything escapes
            for ur in userefs(stmt)
                mark_val(ur[])
            end
        end
    end
    ir = finish(compact)
    domtree = construct_domtree(ir.cfg)
    for idx in revisit
        # Make sure that the value we reference didn't escape
        id = ir.stmts[idx].args[2].id
        (id in relevant) || continue

        # We're ok to steal the memory if we don't dominate any uses.
        # An allocation that was never written to or returned has no entry in
        # `uses`; that simply means there is no use to check.
        ok = true
        recorded = get(uses, id, nothing)
        if recorded !== nothing
            for use in recorded
                if ssadominates(ir, domtree, idx, use)
                    ok = false
                    break
                end
            end
        end
        ok || continue

        ir.stmts[idx].args[1] = Core.mutating_arrayfreeze
    end
    return ir
end
15 changes: 15 additions & 0 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,21 @@ function builtin_tfunction(@nospecialize(f), argtypes::Array{Any,1},
end
end
return Any
elseif f === Core.arrayfreeze || f === Core.arraymelt
if length(argtypes) != 1
isva && return Any
return Bottom
end
a = widenconst(argtypes[1])
at = (f === Core.arrayfreeze ? Array : ImmutableArray)
rt = (f === Core.arrayfreeze ? ImmutableArray : Array)
if a <: at
unw = unwrap_unionall(a)
if isa(unw, DataType)
return rewrap_unionall(rt{unw.parameters[1], unw.parameters[2]}, a)
end
end
return rt
elseif f === Expr
if length(argtypes) < 1 && !isva
return Bottom
Expand Down
4 changes: 4 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,10 @@ export
rand,
randn,

# mutation
freeze,
melt,

# Macros
# parser internal
@__FILE__,
Expand Down
2 changes: 2 additions & 0 deletions src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ DECLARE_BUILTIN(getfield); DECLARE_BUILTIN(setfield);
DECLARE_BUILTIN(fieldtype); DECLARE_BUILTIN(arrayref);
DECLARE_BUILTIN(const_arrayref);
DECLARE_BUILTIN(arrayset); DECLARE_BUILTIN(arraysize);
DECLARE_BUILTIN(arrayfreeze); DECLARE_BUILTIN(arraymelt);
DECLARE_BUILTIN(mutating_arrayfreeze);
DECLARE_BUILTIN(apply_type); DECLARE_BUILTIN(applicable);
DECLARE_BUILTIN(invoke); DECLARE_BUILTIN(_expr);
DECLARE_BUILTIN(typeassert); DECLARE_BUILTIN(ifelse);
Expand Down
55 changes: 53 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,9 @@ JL_CALLABLE(jl_f__typevar)
JL_CALLABLE(jl_f_arraysize)
{
JL_NARGS(arraysize, 2, 2);
JL_TYPECHK(arraysize, array, args[0]);
if (!jl_is_arrayish(args[0])) {
jl_type_error("arraysize", (jl_value_t*)jl_array_type, args[0]);
}
jl_array_t *a = (jl_array_t*)args[0];
size_t nd = jl_array_ndims(a);
JL_TYPECHK(arraysize, long, args[1]);
Expand Down Expand Up @@ -1053,7 +1055,9 @@ JL_CALLABLE(jl_f_arrayref)
{
JL_NARGSV(arrayref, 3);
JL_TYPECHK(arrayref, bool, args[0]);
JL_TYPECHK(arrayref, array, args[1]);
if (!jl_is_arrayish(args[1])) {
jl_type_error("arrayref", (jl_value_t*)jl_array_type, args[1]);
}
jl_array_t *a = (jl_array_t*)args[1];
size_t i = array_nd_index(a, &args[2], nargs - 2, "arrayref");
return jl_arrayref(a, i);
Expand All @@ -1075,6 +1079,49 @@ JL_CALLABLE(jl_f_arrayset)
return args[1];
}

// Builtin `arrayfreeze(a)`: return a copy of the `Array` `a` whose type tag is
// `ImmutableArray{T,N}` with the same type parameters as `a`.
// Throws a type error if `a` is not an `Array`.
JL_CALLABLE(jl_f_arrayfreeze)
{
    // Exactly one argument is accepted; inference (`builtin_tfunction`)
    // already treats any other argument count as an error.
    JL_NARGS(arrayfreeze, 1, 1);
    JL_TYPECHK(arrayfreeze, array, args[0]);
    jl_array_t *a = (jl_array_t*)args[0];
    jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type,
        jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
    // The idea is to elide this copy if the compiler or runtime can prove that
    // doing so is safe to do.
    jl_array_t *na = jl_array_copy(a);
    jl_set_typeof(na, it);
    return (jl_value_t*)na;
}

JL_CALLABLE(jl_f_mutating_arrayfreeze)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first, I was thrown by what this name meant. It makes sense in the end; in this case you are sort-of mutating the type of a but not the value/data... though I did wonder if it would be better language to describe this as "taking" or "stealing" a or something like that.

{
JL_NARGSV(arrayfreeze, 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mutating_arrayfreeze?

JL_TYPECHK(arrayfreeze, array, args[0]);
jl_array_t *a = (jl_array_t*)args[0];
jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_immutable_array_type,
jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
// The idea is to elide this copy if the compiler or runtime can prove that
// doing so is safe to do.
jl_set_typeof(a, it);
return (jl_value_t*)a;
}

JL_CALLABLE(jl_f_arraymelt)
{
JL_NARGSV(arrayfreeze, 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arraymelt?

if (((jl_datatype_t*)jl_typeof(args[0]))->name != jl_immutable_array_typename) {
jl_type_error("arraymelt", (jl_value_t*)jl_immutable_array_type, args[0]);
}
jl_array_t *a = (jl_array_t*)args[0];
jl_datatype_t *it = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type,
jl_tparam0(jl_typeof(a)), jl_tparam1(jl_typeof(a)));
// The idea is to elide this copy if the compiler or runtime can prove that
// doing so is safe to do.
jl_array_t *na = jl_array_copy(a);
jl_set_typeof(na, it);
return (jl_value_t*)na;
}

// IntrinsicFunctions ---------------------------------------------------------

static void (*runtime_fp[num_intrinsics])(void);
Expand Down Expand Up @@ -1218,6 +1265,9 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin_func("const_arrayref", jl_f_arrayref);
add_builtin_func("arrayset", jl_f_arrayset);
add_builtin_func("arraysize", jl_f_arraysize);
add_builtin_func("arrayfreeze", jl_f_arrayfreeze);
add_builtin_func("mutating_arrayfreeze", jl_f_mutating_arrayfreeze);
add_builtin_func("arraymelt", jl_f_arraymelt);

// method table utils
add_builtin_func("applicable", jl_f_applicable);
Expand Down Expand Up @@ -1276,6 +1326,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("AbstractArray", (jl_value_t*)jl_abstractarray_type);
add_builtin("DenseArray", (jl_value_t*)jl_densearray_type);
add_builtin("Array", (jl_value_t*)jl_array_type);
add_builtin("ImmutableArray", (jl_value_t*)jl_immutable_array_type);

add_builtin("Expr", (jl_value_t*)jl_expr_type);
add_builtin("LineNumberNode", (jl_value_t*)jl_linenumbernode_type);
Expand Down
5 changes: 4 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ static MDNode *best_tbaa(jl_value_t *jt) {
// note that this is guaranteed to include jl_isbits
// Return true when `t` is an immutable datatype that has a layout containing
// no pointers. Array-like types (as reported by `jl_is_arrayish_type`) are
// explicitly excluded.
static bool jl_justbits(jl_value_t* t)
{
    // Only the updated condition is kept; the earlier `return` without the
    // array-like exclusion made this check unreachable.
    return jl_is_immutable_datatype(t) && !jl_is_arrayish_type(t) && ((jl_datatype_t*)t)->layout && ((jl_datatype_t*)t)->layout->npointers == 0;
}

// metadata tracking for a llvm Value* during codegen
Expand Down Expand Up @@ -7244,6 +7244,9 @@ static void init_julia_llvm_env(Module *m)
builtin_func_map[jl_f_const_arrayref] = jlcall_func_to_llvm("jl_f_const_arrayref", &jl_f_arrayref, m);
builtin_func_map[jl_f_arrayset] = jlcall_func_to_llvm("jl_f_arrayset", &jl_f_arrayset, m);
builtin_func_map[jl_f_arraysize] = jlcall_func_to_llvm("jl_f_arraysize", &jl_f_arraysize, m);
builtin_func_map[jl_f_arrayfreeze] = jlcall_func_to_llvm("jl_f_arrayfreeze", &jl_f_arrayfreeze, m);
builtin_func_map[jl_f_mutating_arrayfreeze] = jlcall_func_to_llvm("jl_f_mutating_arrayfreeze", &jl_f_mutating_arrayfreeze, m);
builtin_func_map[jl_f_arraymelt] = jlcall_func_to_llvm("jl_f_arraymelt", &jl_f_arraymelt, m);
builtin_func_map[jl_f_apply_type] = jlcall_func_to_llvm("jl_f_apply_type", &jl_f_apply_type, m);
jltuple_func = builtin_func_map[jl_f_tuple];
jlgetfield_func = builtin_func_map[jl_f_getfield];
Expand Down
3 changes: 2 additions & 1 deletion src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t)
// Decide whether datatype `d` qualifies for the "make singleton" treatment:
// it must be non-abstract, immutable, zero-sized and have a non-zero uid, and
// it must be neither the symbol type nor one of the two builtin array types
// (`Array`, `ImmutableArray`), which are identified by type name.
STATIC_INLINE int jl_is_datatype_make_singleton(jl_datatype_t *d)
{
    if (d->abstract || d->mutabl || d->uid == 0)
        return 0;
    if (d == jl_sym_type)
        return 0;
    if (d->name == jl_array_typename || d->name == jl_immutable_array_typename)
        return 0;
    return jl_datatype_size(d) == 0;
}

Expand Down Expand Up @@ -300,7 +301,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
// compute whether this type can be inlined
// based on whether its definition is self-referential
if (w->types != NULL) {
st->isbitstype = st->isconcretetype && !st->mutabl;
st->isbitstype = st->isconcretetype && !st->mutabl && st->name != jl_immutable_array_typename;
size_t i, nf = jl_field_count(st);
for (i = 0; i < nf; i++) {
jl_value_t *fld = jl_field_type(st, i);
Expand Down
1 change: 1 addition & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -3320,6 +3320,7 @@ void jl_init_serializer(void)

arraylist_new(&builtin_typenames, 0);
arraylist_push(&builtin_typenames, jl_array_typename);
arraylist_push(&builtin_typenames, jl_immutable_array_typename);
arraylist_push(&builtin_typenames, ((jl_datatype_t*)jl_ref_type->body)->name);
arraylist_push(&builtin_typenames, jl_pointer_typename);
arraylist_push(&builtin_typenames, jl_type_typename);
Expand Down
12 changes: 12 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ jl_unionall_t *jl_typetype_type;

jl_unionall_t *jl_array_type;
jl_typename_t *jl_array_typename;
jl_unionall_t *jl_immutable_array_type;
jl_typename_t *jl_immutable_array_typename;
jl_value_t *jl_array_uint8_type;
jl_value_t *jl_array_any_type;
jl_value_t *jl_array_symbol_type;
Expand Down Expand Up @@ -1963,6 +1965,16 @@ void jl_init_types(void) JL_GC_DISABLED
jl_array_uint8_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_uint8_type, jl_box_long(1));
jl_array_int32_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_int32_type, jl_box_long(1));

tv = jl_svec2(tvar("T"), tvar("N"));
jl_immutable_array_type = (jl_unionall_t*)
jl_new_datatype(jl_symbol("ImmutableArray"), core,
(jl_datatype_t*)
jl_apply_type((jl_value_t*)jl_densearray_type, jl_svec_data(tv), 2),
tv,
jl_emptysvec, jl_emptysvec, 0, 0, 0)->name->wrapper;
jl_immutable_array_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_immutable_array_type))->name;
jl_compute_field_offsets((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_immutable_array_type));

jl_expr_type =
jl_new_datatype(jl_symbol("Expr"), core,
jl_any_type, jl_emptysvec,
Expand Down
Loading